From 3abb56738eb2eb7c035374a6c712d80c320e8e5e Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Sun, 22 Feb 2026 21:39:38 +0900 Subject: [PATCH 1/2] Expand ruff lint rules and rename make chk to make lint MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add 11 new ruff lint rule categories (UP, PIE, PERF, T20, FLY, ISC, RSE, RUF, PGH, G, PT) and fix all violations across the codebase. Key changes: - Rename `make chk` to `make lint` in Makefile, CLAUDE.md, and docs - Modernize type annotations (Union → X|Y, Optional → X|None) - Convert f-string logging to lazy % formatting (G004/G003) - Replace blanket `# type: ignore` with specific error codes (PGH003) - Add ClassVar annotations for mutable class attributes (RUF012) - Add match parameters to pytest.raises(ValueError) calls (PT011) - Extract setup from pytest.raises blocks (PT012) - Use tuple for pytest.mark.parametrize names (PT006) - Replace yield-for-loop with yield from (UP028) - Use collection unpacking instead of concatenation (RUF005) - Add per-file-ignores for SQLAlchemy naming conventions (N802, N801) - Suppress PEP 249 required names that conflict with lint rules (B027, N818) - Exclude benchmarks directory from ruff checking Co-Authored-By: Claude Opus 4.6 --- CLAUDE.md | 4 +- Makefile | 6 +- docs/testing.md | 4 +- pyathena/__init__.py | 23 +- pyathena/aio/__init__.py | 1 - pyathena/aio/arrow/__init__.py | 1 - pyathena/aio/arrow/cursor.py | 55 ++-- pyathena/aio/common.py | 109 ++++--- pyathena/aio/connection.py | 5 +- pyathena/aio/cursor.py | 41 ++- pyathena/aio/pandas/__init__.py | 1 - pyathena/aio/pandas/cursor.py | 62 ++-- pyathena/aio/polars/__init__.py | 1 - pyathena/aio/polars/cursor.py | 57 ++-- pyathena/aio/result_set.py | 32 +- pyathena/aio/s3fs/__init__.py | 1 - pyathena/aio/s3fs/cursor.py | 49 ++- pyathena/aio/spark/__init__.py | 1 - pyathena/aio/spark/cursor.py | 53 ++-- pyathena/aio/sqlalchemy/__init__.py | 1 - pyathena/aio/sqlalchemy/arrow.py | 1 - 
pyathena/aio/sqlalchemy/base.py | 34 +-- pyathena/aio/sqlalchemy/pandas.py | 3 +- pyathena/aio/sqlalchemy/polars.py | 1 - pyathena/aio/sqlalchemy/rest.py | 1 - pyathena/aio/sqlalchemy/s3fs.py | 1 - pyathena/aio/util.py | 6 +- pyathena/arrow/__init__.py | 1 - pyathena/arrow/async_cursor.py | 45 ++- pyathena/arrow/converter.py | 14 +- pyathena/arrow/cursor.py | 48 +-- pyathena/arrow/result_set.py | 58 ++-- pyathena/arrow/util.py | 7 +- pyathena/async_cursor.py | 47 ++- pyathena/common.py | 193 ++++++------ pyathena/connection.py | 171 +++++------ pyathena/converter.py | 70 ++--- pyathena/cursor.py | 34 +-- pyathena/error.py | 29 +- pyathena/filesystem/__init__.py | 1 - pyathena/filesystem/s3.py | 167 ++++++----- pyathena/filesystem/s3_async.py | 61 ++-- pyathena/filesystem/s3_executor.py | 8 +- pyathena/filesystem/s3_object.py | 192 ++++++------ pyathena/formatter.py | 42 ++- pyathena/model.py | 314 ++++++++++---------- pyathena/pandas/__init__.py | 1 - pyathena/pandas/async_cursor.py | 48 +-- pyathena/pandas/converter.py | 14 +- pyathena/pandas/cursor.py | 57 ++-- pyathena/pandas/result_set.py | 117 ++++---- pyathena/pandas/util.py | 95 +++--- pyathena/polars/__init__.py | 1 - pyathena/polars/async_cursor.py | 45 ++- pyathena/polars/converter.py | 12 +- pyathena/polars/cursor.py | 56 ++-- pyathena/polars/result_set.py | 100 +++---- pyathena/polars/util.py | 9 +- pyathena/result_set.py | 270 ++++++++--------- pyathena/s3fs/__init__.py | 1 - pyathena/s3fs/async_cursor.py | 39 ++- pyathena/s3fs/converter.py | 5 +- pyathena/s3fs/cursor.py | 42 +-- pyathena/s3fs/reader.py | 31 +- pyathena/s3fs/result_set.py | 27 +- pyathena/spark/__init__.py | 1 - pyathena/spark/async_cursor.py | 31 +- pyathena/spark/common.py | 63 ++-- pyathena/spark/cursor.py | 31 +- pyathena/sqlalchemy/__init__.py | 1 - pyathena/sqlalchemy/arrow.py | 1 - pyathena/sqlalchemy/base.py | 110 +++---- pyathena/sqlalchemy/compiler.py | 166 +++++------ pyathena/sqlalchemy/constants.py | 9 +- 
pyathena/sqlalchemy/pandas.py | 3 +- pyathena/sqlalchemy/polars.py | 1 - pyathena/sqlalchemy/preparer.py | 9 +- pyathena/sqlalchemy/requirements.py | 1 - pyathena/sqlalchemy/rest.py | 1 - pyathena/sqlalchemy/s3fs.py | 1 - pyathena/sqlalchemy/types.py | 27 +- pyathena/sqlalchemy/util.py | 5 +- pyathena/util.py | 21 +- pyproject.toml | 27 +- tests/__init__.py | 1 - tests/pyathena/__init__.py | 1 - tests/pyathena/aio/__init__.py | 1 - tests/pyathena/aio/arrow/__init__.py | 1 - tests/pyathena/aio/arrow/test_cursor.py | 1 - tests/pyathena/aio/conftest.py | 15 +- tests/pyathena/aio/pandas/__init__.py | 1 - tests/pyathena/aio/pandas/test_cursor.py | 1 - tests/pyathena/aio/polars/__init__.py | 1 - tests/pyathena/aio/polars/test_cursor.py | 1 - tests/pyathena/aio/s3fs/__init__.py | 1 - tests/pyathena/aio/s3fs/test_cursor.py | 5 +- tests/pyathena/aio/spark/__init__.py | 1 - tests/pyathena/aio/spark/test_cursor.py | 1 - tests/pyathena/aio/sqlalchemy/__init__.py | 1 - tests/pyathena/aio/sqlalchemy/test_base.py | 1 - tests/pyathena/aio/test_cursor.py | 5 +- tests/pyathena/arrow/__init__.py | 1 - tests/pyathena/arrow/test_async_cursor.py | 1 - tests/pyathena/arrow/test_cursor.py | 1 - tests/pyathena/arrow/test_util.py | 1 - tests/pyathena/conftest.py | 11 +- tests/pyathena/filesystem/__init__.py | 1 - tests/pyathena/filesystem/test_s3.py | 45 ++- tests/pyathena/filesystem/test_s3_async.py | 45 ++- tests/pyathena/filesystem/test_s3_object.py | 1 - tests/pyathena/pandas/__init__.py | 1 - tests/pyathena/pandas/test_async_cursor.py | 43 ++- tests/pyathena/pandas/test_cursor.py | 47 ++- tests/pyathena/pandas/test_util.py | 29 +- tests/pyathena/polars/__init__.py | 1 - tests/pyathena/polars/test_async_cursor.py | 1 - tests/pyathena/polars/test_cursor.py | 1 - tests/pyathena/s3fs/__init__.py | 1 - tests/pyathena/s3fs/test_async_cursor.py | 1 - tests/pyathena/s3fs/test_cursor.py | 3 +- tests/pyathena/s3fs/test_reader.py | 1 - tests/pyathena/spark/__init__.py | 1 - 
tests/pyathena/spark/test_async_cursor.py | 1 - tests/pyathena/spark/test_spark_cursor.py | 1 - tests/pyathena/sqlalchemy/__init__.py | 1 - tests/pyathena/sqlalchemy/test_base.py | 6 +- tests/pyathena/sqlalchemy/test_compiler.py | 10 +- tests/pyathena/sqlalchemy/test_types.py | 3 +- tests/pyathena/test_async_cursor.py | 1 - tests/pyathena/test_converter.py | 20 +- tests/pyathena/test_cursor.py | 100 ++++--- tests/pyathena/test_formatter.py | 1 - tests/pyathena/test_model.py | 1 - tests/pyathena/test_util.py | 1 - tests/pyathena/util.py | 1 - tests/sqlalchemy/__init__.py | 1 - tests/sqlalchemy/conftest.py | 3 +- tests/sqlalchemy/test_suite.py | 41 ++- 138 files changed, 1915 insertions(+), 2088 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index dee95022..54bc4452 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -17,12 +17,12 @@ PyAthena is a Python DB API 2.0 (PEP 249) compliant client for Amazon Athena. Se ### Code Quality — Always Run Before Committing ```bash make fmt # Auto-fix formatting and imports -make chk # Lint + format check + mypy +make lint # Lint + format check + mypy ``` ### Testing ```bash -# ALWAYS run `make chk` first — tests will fail if lint doesn't pass +# ALWAYS run `make lint` first — tests will fail if lint doesn't pass make test # Unit tests (runs chk first) make test-sqla # SQLAlchemy dialect tests ``` diff --git a/Makefile b/Makefile index 4b2a6261..465e8431 100644 --- a/Makefile +++ b/Makefile @@ -7,14 +7,14 @@ fmt: uvx ruff@$(RUFF_VERSION) check --select I --fix . uvx ruff@$(RUFF_VERSION) format . -.PHONY: chk -chk: +.PHONY: lint +lint: uvx ruff@$(RUFF_VERSION) check . uvx ruff@$(RUFF_VERSION) format --check . uv run mypy . 
.PHONY: test -test: chk +test: lint uv run pytest -n 8 --cov pyathena --cov-report html --cov-report term tests/pyathena/ .PHONY: test-sqla diff --git a/docs/testing.md b/docs/testing.md index 2ebb4696..43fe5794 100644 --- a/docs/testing.md +++ b/docs/testing.md @@ -54,10 +54,10 @@ The code formatting uses [ruff](https://github.com/astral-sh/ruff). $ make fmt ``` -### Check format +### Lint and check format ```bash -$ make chk +$ make lint ``` ## GitHub Actions diff --git a/pyathena/__init__.py b/pyathena/__init__.py index 7e44fbf0..c3ad35c1 100644 --- a/pyathena/__init__.py +++ b/pyathena/__init__.py @@ -1,10 +1,9 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import datetime -from typing import TYPE_CHECKING, Any, FrozenSet, Type, overload +from typing import TYPE_CHECKING, Any, overload -from pyathena.error import * # noqa +from pyathena.error import * # noqa: F403 if TYPE_CHECKING: from pyathena.aio.connection import AioConnection @@ -28,7 +27,7 @@ paramstyle: str = "pyformat" -class DBAPITypeObject(FrozenSet[str]): +class DBAPITypeObject(frozenset[str]): """Type Objects and Constructors https://www.python.org/dev/peps/pep-0249/#type-objects-and-constructors @@ -60,22 +59,22 @@ def __hash__(self): DATETIME: DBAPITypeObject = DBAPITypeObject(("timestamp", "timestamp with time zone")) JSON: DBAPITypeObject = DBAPITypeObject(("json",)) -Date: Type[datetime.date] = datetime.date -Time: Type[datetime.time] = datetime.time -Timestamp: Type[datetime.datetime] = datetime.datetime +Date: type[datetime.date] = datetime.date +Time: type[datetime.time] = datetime.time +Timestamp: type[datetime.datetime] = datetime.datetime @overload -def connect(*args, cursor_class: None = ..., **kwargs) -> "Connection[Cursor]": ... +def connect(*args, cursor_class: None = ..., **kwargs) -> Connection[Cursor]: ... @overload def connect( - *args, cursor_class: Type[ConnectionCursor], **kwargs -) -> "Connection[ConnectionCursor]": ... 
+ *args, cursor_class: type[ConnectionCursor], **kwargs +) -> Connection[ConnectionCursor]: ... -def connect(*args, **kwargs) -> "Connection[Any]": +def connect(*args, **kwargs) -> Connection[Any]: """Create a new database connection to Amazon Athena. This function provides the main entry point for establishing connections @@ -131,7 +130,7 @@ def connect(*args, **kwargs) -> "Connection[Any]": return Connection(*args, **kwargs) -async def aio_connect(*args, **kwargs) -> "AioConnection": +async def aio_connect(*args, **kwargs) -> AioConnection: """Create a new async database connection to Amazon Athena. This is the async counterpart of :func:`connect`. It returns an diff --git a/pyathena/aio/__init__.py b/pyathena/aio/__init__.py index 40a96afc..e69de29b 100644 --- a/pyathena/aio/__init__.py +++ b/pyathena/aio/__init__.py @@ -1 +0,0 @@ -# -*- coding: utf-8 -*- diff --git a/pyathena/aio/arrow/__init__.py b/pyathena/aio/arrow/__init__.py index 40a96afc..e69de29b 100644 --- a/pyathena/aio/arrow/__init__.py +++ b/pyathena/aio/arrow/__init__.py @@ -1 +0,0 @@ -# -*- coding: utf-8 -*- diff --git a/pyathena/aio/arrow/cursor.py b/pyathena/aio/arrow/cursor.py index ae43531d..9baebc77 100644 --- a/pyathena/aio/arrow/cursor.py +++ b/pyathena/aio/arrow/cursor.py @@ -1,9 +1,8 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import asyncio import logging -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Any, cast from pyathena.aio.common import WithAsyncFetch from pyathena.arrow.converter import ( @@ -19,7 +18,7 @@ import polars as pl from pyarrow import Table -_logger = logging.getLogger(__name__) # type: ignore +_logger = logging.getLogger(__name__) class AioArrowCursor(WithAsyncFetch): @@ -37,19 +36,19 @@ class AioArrowCursor(WithAsyncFetch): def __init__( self, - s3_staging_dir: Optional[str] = None, - schema_name: Optional[str] = None, - catalog_name: Optional[str] = None, - work_group: 
Optional[str] = None, + s3_staging_dir: str | None = None, + schema_name: str | None = None, + catalog_name: str | None = None, + work_group: str | None = None, poll_interval: float = 1, - encryption_option: Optional[str] = None, - kms_key: Optional[str] = None, + encryption_option: str | None = None, + kms_key: str | None = None, kill_on_interrupt: bool = True, unload: bool = False, result_reuse_enable: bool = False, result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES, - connect_timeout: Optional[float] = None, - request_timeout: Optional[float] = None, + connect_timeout: float | None = None, + request_timeout: float | None = None, **kwargs, ) -> None: super().__init__( @@ -68,12 +67,12 @@ def __init__( self._unload = unload self._connect_timeout = connect_timeout self._request_timeout = request_timeout - self._result_set: Optional[AthenaArrowResultSet] = None + self._result_set: AthenaArrowResultSet | None = None @staticmethod def get_default_converter( unload: bool = False, - ) -> Union[DefaultArrowTypeConverter, DefaultArrowUnloadTypeConverter, Any]: + ) -> DefaultArrowTypeConverter | DefaultArrowUnloadTypeConverter | Any: if unload: return DefaultArrowUnloadTypeConverter() return DefaultArrowTypeConverter() @@ -81,16 +80,16 @@ def get_default_converter( async def execute( # type: ignore[override] self, operation: str, - parameters: Optional[Union[Dict[str, Any], List[str]]] = None, - work_group: Optional[str] = None, - s3_staging_dir: Optional[str] = None, - cache_size: Optional[int] = 0, - cache_expiration_time: Optional[int] = 0, - result_reuse_enable: Optional[bool] = None, - result_reuse_minutes: Optional[int] = None, - paramstyle: Optional[str] = None, + parameters: dict[str, Any] | list[str] | None = None, + work_group: str | None = None, + s3_staging_dir: str | None = None, + cache_size: int | None = 0, + cache_expiration_time: int | None = 0, + result_reuse_enable: bool | None = None, + result_reuse_minutes: int | None = None, + 
paramstyle: str | None = None, **kwargs, - ) -> "AioArrowCursor": + ) -> AioArrowCursor: """Execute a SQL query asynchronously and return results as Arrow Tables. Args: @@ -143,7 +142,7 @@ async def execute( # type: ignore[override] async def fetchone( # type: ignore[override] self, - ) -> Optional[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + ) -> tuple[Any | None, ...] | dict[Any, Any | None] | None: """Fetch the next row of the result set. Wraps the synchronous fetch in ``asyncio.to_thread`` to avoid @@ -161,8 +160,8 @@ async def fetchone( # type: ignore[override] return await asyncio.to_thread(result_set.fetchone) async def fetchmany( # type: ignore[override] - self, size: Optional[int] = None - ) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + self, size: int | None = None + ) -> list[tuple[Any | None, ...] | dict[Any, Any | None]]: """Fetch multiple rows from the result set. Wraps the synchronous fetch in ``asyncio.to_thread`` to avoid @@ -184,7 +183,7 @@ async def fetchmany( # type: ignore[override] async def fetchall( # type: ignore[override] self, - ) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + ) -> list[tuple[Any | None, ...] | dict[Any, Any | None]]: """Fetch all remaining rows from the result set. Wraps the synchronous fetch in ``asyncio.to_thread`` to avoid @@ -207,7 +206,7 @@ async def __anext__(self): raise StopAsyncIteration return row - def as_arrow(self) -> "Table": + def as_arrow(self) -> Table: """Return query results as an Apache Arrow Table. Returns: @@ -218,7 +217,7 @@ def as_arrow(self) -> "Table": result_set = cast(AthenaArrowResultSet, self.result_set) return result_set.as_arrow() - def as_polars(self) -> "pl.DataFrame": + def as_polars(self) -> pl.DataFrame: """Return query results as a Polars DataFrame. 
Returns: diff --git a/pyathena/aio/common.py b/pyathena/aio/common.py index 1ec92d9a..ec593297 100644 --- a/pyathena/aio/common.py +++ b/pyathena/aio/common.py @@ -1,11 +1,10 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import asyncio import logging import sys from datetime import datetime, timedelta, timezone -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from typing import Any, cast from pyathena.aio.util import async_retry_api_call from pyathena.common import BaseCursor, CursorIterator @@ -13,7 +12,7 @@ from pyathena.model import AthenaDatabase, AthenaQueryExecution, AthenaTableMetadata from pyathena.result_set import AthenaResultSet, WithResultSet -_logger = logging.getLogger(__name__) # type: ignore +_logger = logging.getLogger(__name__) class AioBaseCursor(BaseCursor): @@ -27,14 +26,14 @@ class AioBaseCursor(BaseCursor): async def _execute( # type: ignore[override] self, operation: str, - parameters: Optional[Union[Dict[str, Any], List[str]]] = None, - work_group: Optional[str] = None, - s3_staging_dir: Optional[str] = None, - cache_size: Optional[int] = 0, - cache_expiration_time: Optional[int] = 0, - result_reuse_enable: Optional[bool] = None, - result_reuse_minutes: Optional[int] = None, - paramstyle: Optional[str] = None, + parameters: dict[str, Any] | list[str] | None = None, + work_group: str | None = None, + s3_staging_dir: str | None = None, + cache_size: int | None = 0, + cache_expiration_time: int | None = 0, + result_reuse_enable: bool | None = None, + result_reuse_minutes: int | None = None, + paramstyle: str | None = None, ) -> str: query, execution_parameters = self._prepare_query(operation, parameters, paramstyle) @@ -118,8 +117,8 @@ async def _cancel(self, query_id: str) -> None: # type: ignore[override] raise OperationalError(*e.args) from e async def _batch_get_query_execution( # type: ignore[override] - self, query_ids: List[str] - ) -> List[AthenaQueryExecution]: + self, query_ids: list[str] + ) -> 
list[AthenaQueryExecution]: try: response = await async_retry_api_call( self.connection._client.batch_get_query_execution, @@ -138,10 +137,10 @@ async def _batch_get_query_execution( # type: ignore[override] async def _list_query_executions( # type: ignore[override] self, - work_group: Optional[str] = None, - next_token: Optional[str] = None, - max_results: Optional[int] = None, - ) -> Tuple[Optional[str], List[AthenaQueryExecution]]: + work_group: str | None = None, + next_token: str | None = None, + max_results: int | None = None, + ) -> tuple[str | None, list[AthenaQueryExecution]]: request = self._build_list_query_executions_request( work_group=work_group, next_token=next_token, max_results=max_results ) @@ -165,10 +164,10 @@ async def _list_query_executions( # type: ignore[override] async def _find_previous_query_id( # type: ignore[override] self, query: str, - work_group: Optional[str], + work_group: str | None, cache_size: int = 0, cache_expiration_time: int = 0, - ) -> Optional[str]: + ) -> str | None: query_id = None if cache_size == 0 and cache_expiration_time > 0: cache_size = sys.maxsize @@ -191,7 +190,7 @@ async def _find_previous_query_id( # type: ignore[override] if e.state == AthenaQueryExecution.STATE_SUCCEEDED and e.statement_type == AthenaQueryExecution.STATEMENT_TYPE_DML ), - key=lambda e: e.completion_date_time, # type: ignore + key=lambda e: e.completion_date_time, # type: ignore[arg-type, return-value] reverse=True, ): if ( @@ -213,10 +212,10 @@ async def _find_previous_query_id( # type: ignore[override] async def _list_databases( # type: ignore[override] self, - catalog_name: Optional[str], - next_token: Optional[str] = None, - max_results: Optional[int] = None, - ) -> Tuple[Optional[str], List[AthenaDatabase]]: + catalog_name: str | None, + next_token: str | None = None, + max_results: int | None = None, + ) -> tuple[str | None, list[AthenaDatabase]]: request = self._build_list_databases_request( catalog_name=catalog_name, 
next_token=next_token, @@ -239,10 +238,10 @@ async def _list_databases( # type: ignore[override] async def list_databases( # type: ignore[override] self, - catalog_name: Optional[str], - max_results: Optional[int] = None, - ) -> List[AthenaDatabase]: - databases: List[AthenaDatabase] = [] + catalog_name: str | None, + max_results: int | None = None, + ) -> list[AthenaDatabase]: + databases: list[AthenaDatabase] = [] next_token = None while True: next_token, response = await self._list_databases( @@ -258,8 +257,8 @@ async def list_databases( # type: ignore[override] async def _get_table_metadata( # type: ignore[override] self, table_name: str, - catalog_name: Optional[str] = None, - schema_name: Optional[str] = None, + catalog_name: str | None = None, + schema_name: str | None = None, logging_: bool = True, ) -> AthenaTableMetadata: request = self._build_get_table_metadata_request( @@ -284,8 +283,8 @@ async def _get_table_metadata( # type: ignore[override] async def get_table_metadata( # type: ignore[override] self, table_name: str, - catalog_name: Optional[str] = None, - schema_name: Optional[str] = None, + catalog_name: str | None = None, + schema_name: str | None = None, logging_: bool = True, ) -> AthenaTableMetadata: return await self._get_table_metadata( @@ -297,12 +296,12 @@ async def get_table_metadata( # type: ignore[override] async def _list_table_metadata( # type: ignore[override] self, - catalog_name: Optional[str] = None, - schema_name: Optional[str] = None, - expression: Optional[str] = None, - next_token: Optional[str] = None, - max_results: Optional[int] = None, - ) -> Tuple[Optional[str], List[AthenaTableMetadata]]: + catalog_name: str | None = None, + schema_name: str | None = None, + expression: str | None = None, + next_token: str | None = None, + max_results: int | None = None, + ) -> tuple[str | None, list[AthenaTableMetadata]]: request = self._build_list_table_metadata_request( catalog_name=catalog_name, schema_name=schema_name, @@ -328,12 
+327,12 @@ async def _list_table_metadata( # type: ignore[override] async def list_table_metadata( # type: ignore[override] self, - catalog_name: Optional[str] = None, - schema_name: Optional[str] = None, - expression: Optional[str] = None, - max_results: Optional[int] = None, - ) -> List[AthenaTableMetadata]: - metadata: List[AthenaTableMetadata] = [] + catalog_name: str | None = None, + schema_name: str | None = None, + expression: str | None = None, + max_results: int | None = None, + ) -> list[AthenaTableMetadata]: + metadata: list[AthenaTableMetadata] = [] next_token = None while True: next_token, response = await self._list_table_metadata( @@ -363,8 +362,8 @@ class WithAsyncFetch(AioBaseCursor, CursorIterator, WithResultSet): def __init__(self, **kwargs) -> None: super().__init__(**kwargs) - self._query_id: Optional[str] = None - self._result_set: Optional[AthenaResultSet] = None + self._query_id: str | None = None + self._result_set: AthenaResultSet | None = None @property def arraysize(self) -> int: @@ -376,8 +375,8 @@ def arraysize(self, value: int) -> None: raise ProgrammingError("arraysize must be a positive integer value.") self._arraysize = value - @property # type: ignore - def result_set(self) -> Optional[AthenaResultSet]: + @property # type: ignore[override] + def result_set(self) -> AthenaResultSet | None: return self._result_set @result_set.setter @@ -385,7 +384,7 @@ def result_set(self, val) -> None: self._result_set = val @property - def query_id(self) -> Optional[str]: + def query_id(self) -> str | None: return self._query_id @query_id.setter @@ -393,7 +392,7 @@ def query_id(self, val) -> None: self._query_id = val @property - def rownumber(self) -> Optional[int]: + def rownumber(self) -> int | None: return self.result_set.rownumber if self.result_set else None @property @@ -408,7 +407,7 @@ def close(self) -> None: async def executemany( # type: ignore[override] self, operation: str, - seq_of_parameters: List[Optional[Union[Dict[str, Any], 
List[str]]]], + seq_of_parameters: list[dict[str, Any] | list[str] | None], **kwargs, ) -> None: """Execute a SQL query multiple times with different parameters. @@ -435,7 +434,7 @@ async def cancel(self) -> None: def fetchone( self, - ) -> Optional[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + ) -> tuple[Any | None, ...] | dict[Any, Any | None] | None: """Fetch the next row of the result set. Returns: @@ -450,8 +449,8 @@ def fetchone( return result_set.fetchone() def fetchmany( - self, size: Optional[int] = None - ) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + self, size: int | None = None + ) -> list[tuple[Any | None, ...] | dict[Any, Any | None]]: """Fetch multiple rows from the result set. Args: @@ -470,7 +469,7 @@ def fetchmany( def fetchall( self, - ) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + ) -> list[tuple[Any | None, ...] | dict[Any, Any | None]]: """Fetch all remaining rows from the result set. Returns: diff --git a/pyathena/aio/connection.py b/pyathena/aio/connection.py index bbfba721..1a25dc8b 100644 --- a/pyathena/aio/connection.py +++ b/pyathena/aio/connection.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import asyncio @@ -33,7 +32,7 @@ def __init__(self, **kwargs: Any) -> None: async def create( cls, **kwargs: Any, - ) -> "AioConnection": + ) -> AioConnection: """Async factory for creating an ``AioConnection``. 
Runs the (potentially blocking) ``__init__`` in a thread so that @@ -47,7 +46,7 @@ async def create( """ return await asyncio.to_thread(cls, **kwargs) - async def __aenter__(self) -> "AioConnection": + async def __aenter__(self) -> AioConnection: return self async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: diff --git a/pyathena/aio/cursor.py b/pyathena/aio/cursor.py index 792758be..30738f8f 100644 --- a/pyathena/aio/cursor.py +++ b/pyathena/aio/cursor.py @@ -1,8 +1,7 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import logging -from typing import Any, Dict, List, Optional, Union, cast +from typing import Any, cast from pyathena.aio.common import WithAsyncFetch from pyathena.aio.result_set import AthenaAioDictResultSet, AthenaAioResultSet @@ -10,7 +9,7 @@ from pyathena.error import OperationalError, ProgrammingError from pyathena.model import AthenaQueryExecution -_logger = logging.getLogger(__name__) # type: ignore +_logger = logging.getLogger(__name__) class AioCursor(WithAsyncFetch): @@ -29,13 +28,13 @@ class AioCursor(WithAsyncFetch): def __init__( self, - s3_staging_dir: Optional[str] = None, - schema_name: Optional[str] = None, - catalog_name: Optional[str] = None, - work_group: Optional[str] = None, + s3_staging_dir: str | None = None, + schema_name: str | None = None, + catalog_name: str | None = None, + work_group: str | None = None, poll_interval: float = 1, - encryption_option: Optional[str] = None, - kms_key: Optional[str] = None, + encryption_option: str | None = None, + kms_key: str | None = None, kill_on_interrupt: bool = True, result_reuse_enable: bool = False, result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES, @@ -54,7 +53,7 @@ def __init__( result_reuse_minutes=result_reuse_minutes, **kwargs, ) - self._result_set: Optional[AthenaAioResultSet] = None + self._result_set: AthenaAioResultSet | None = None self._result_set_class = AthenaAioResultSet @property @@ -72,16 +71,16 @@ def arraysize(self, 
value: int) -> None: async def execute( # type: ignore[override] self, operation: str, - parameters: Optional[Union[Dict[str, Any], List[str]]] = None, - work_group: Optional[str] = None, - s3_staging_dir: Optional[str] = None, + parameters: dict[str, Any] | list[str] | None = None, + work_group: str | None = None, + s3_staging_dir: str | None = None, cache_size: int = 0, cache_expiration_time: int = 0, - result_reuse_enable: Optional[bool] = None, - result_reuse_minutes: Optional[int] = None, - paramstyle: Optional[str] = None, + result_reuse_enable: bool | None = None, + result_reuse_minutes: int | None = None, + paramstyle: str | None = None, **kwargs, - ) -> "AioCursor": + ) -> AioCursor: """Execute a SQL query asynchronously. Args: @@ -127,7 +126,7 @@ async def execute( # type: ignore[override] async def fetchone( # type: ignore[override] self, - ) -> Optional[Union[Any, Dict[Any, Optional[Any]]]]: + ) -> Any | dict[Any, Any | None] | None: """Fetch the next row of a query result set. Returns: @@ -143,8 +142,8 @@ async def fetchone( # type: ignore[override] return await result_set.fetchone() async def fetchmany( # type: ignore[override] - self, size: Optional[int] = None - ) -> List[Union[Any, Dict[Any, Optional[Any]]]]: + self, size: int | None = None + ) -> list[Any | dict[Any, Any | None]]: """Fetch multiple rows from a query result set. Args: @@ -164,7 +163,7 @@ async def fetchmany( # type: ignore[override] async def fetchall( # type: ignore[override] self, - ) -> List[Union[Any, Dict[Any, Optional[Any]]]]: + ) -> list[Any | dict[Any, Any | None]]: """Fetch all remaining rows from a query result set. 
Returns: diff --git a/pyathena/aio/pandas/__init__.py b/pyathena/aio/pandas/__init__.py index 40a96afc..e69de29b 100644 --- a/pyathena/aio/pandas/__init__.py +++ b/pyathena/aio/pandas/__init__.py @@ -1 +0,0 @@ -# -*- coding: utf-8 -*- diff --git a/pyathena/aio/pandas/cursor.py b/pyathena/aio/pandas/cursor.py index 7a6de9a2..88b0f1c6 100644 --- a/pyathena/aio/pandas/cursor.py +++ b/pyathena/aio/pandas/cursor.py @@ -1,18 +1,12 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import asyncio import logging +from collections.abc import Iterable from multiprocessing import cpu_count from typing import ( TYPE_CHECKING, Any, - Dict, - Iterable, - List, - Optional, - Tuple, - Union, cast, ) @@ -29,7 +23,7 @@ if TYPE_CHECKING: from pandas import DataFrame -_logger = logging.getLogger(__name__) # type: ignore +_logger = logging.getLogger(__name__) class AioPandasCursor(WithAsyncFetch): @@ -48,19 +42,19 @@ class AioPandasCursor(WithAsyncFetch): def __init__( self, - s3_staging_dir: Optional[str] = None, - schema_name: Optional[str] = None, - catalog_name: Optional[str] = None, - work_group: Optional[str] = None, + s3_staging_dir: str | None = None, + schema_name: str | None = None, + catalog_name: str | None = None, + work_group: str | None = None, poll_interval: float = 1, - encryption_option: Optional[str] = None, - kms_key: Optional[str] = None, + encryption_option: str | None = None, + kms_key: str | None = None, kill_on_interrupt: bool = True, unload: bool = False, engine: str = "auto", - chunksize: Optional[int] = None, - block_size: Optional[int] = None, - cache_type: Optional[str] = None, + chunksize: int | None = None, + block_size: int | None = None, + cache_type: str | None = None, max_workers: int = (cpu_count() or 1) * 5, result_reuse_enable: bool = False, result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES, @@ -87,12 +81,12 @@ def __init__( self._cache_type = cache_type self._max_workers = max_workers 
self._auto_optimize_chunksize = auto_optimize_chunksize - self._result_set: Optional[AthenaPandasResultSet] = None + self._result_set: AthenaPandasResultSet | None = None @staticmethod def get_default_converter( unload: bool = False, - ) -> Union[DefaultPandasTypeConverter, Any]: + ) -> DefaultPandasTypeConverter | Any: if unload: return DefaultPandasUnloadTypeConverter() return DefaultPandasTypeConverter() @@ -100,19 +94,19 @@ def get_default_converter( async def execute( # type: ignore[override] self, operation: str, - parameters: Optional[Union[Dict[str, Any], List[str]]] = None, - work_group: Optional[str] = None, - s3_staging_dir: Optional[str] = None, - cache_size: Optional[int] = 0, - cache_expiration_time: Optional[int] = 0, - result_reuse_enable: Optional[bool] = None, - result_reuse_minutes: Optional[int] = None, - paramstyle: Optional[str] = None, + parameters: dict[str, Any] | list[str] | None = None, + work_group: str | None = None, + s3_staging_dir: str | None = None, + cache_size: int | None = 0, + cache_expiration_time: int | None = 0, + result_reuse_enable: bool | None = None, + result_reuse_minutes: int | None = None, + paramstyle: str | None = None, keep_default_na: bool = False, - na_values: Optional[Iterable[str]] = ("",), + na_values: Iterable[str] | None = ("",), quoting: int = 1, **kwargs, - ) -> "AioPandasCursor": + ) -> AioPandasCursor: """Execute a SQL query asynchronously and return results as pandas DataFrames. Args: @@ -175,7 +169,7 @@ async def execute( # type: ignore[override] async def fetchone( # type: ignore[override] self, - ) -> Optional[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + ) -> tuple[Any | None, ...] | dict[Any, Any | None] | None: """Fetch the next row of the result set. 
Wraps the synchronous fetch in ``asyncio.to_thread`` to avoid @@ -193,8 +187,8 @@ async def fetchone( # type: ignore[override] return await asyncio.to_thread(result_set.fetchone) async def fetchmany( # type: ignore[override] - self, size: Optional[int] = None - ) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + self, size: int | None = None + ) -> list[tuple[Any | None, ...] | dict[Any, Any | None]]: """Fetch multiple rows from the result set. Wraps the synchronous fetch in ``asyncio.to_thread`` to avoid @@ -216,7 +210,7 @@ async def fetchmany( # type: ignore[override] async def fetchall( # type: ignore[override] self, - ) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + ) -> list[tuple[Any | None, ...] | dict[Any, Any | None]]: """Fetch all remaining rows from the result set. Wraps the synchronous fetch in ``asyncio.to_thread`` to avoid @@ -239,7 +233,7 @@ async def __anext__(self): raise StopAsyncIteration return row - def as_pandas(self) -> Union["DataFrame", PandasDataFrameIterator]: + def as_pandas(self) -> DataFrame | PandasDataFrameIterator: """Return DataFrame or PandasDataFrameIterator based on chunksize setting. 
Returns: diff --git a/pyathena/aio/polars/__init__.py b/pyathena/aio/polars/__init__.py index 40a96afc..e69de29b 100644 --- a/pyathena/aio/polars/__init__.py +++ b/pyathena/aio/polars/__init__.py @@ -1 +0,0 @@ -# -*- coding: utf-8 -*- diff --git a/pyathena/aio/polars/cursor.py b/pyathena/aio/polars/cursor.py index 538839f4..484c413f 100644 --- a/pyathena/aio/polars/cursor.py +++ b/pyathena/aio/polars/cursor.py @@ -1,10 +1,9 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import asyncio import logging from multiprocessing import cpu_count -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Any, cast from pyathena.aio.common import WithAsyncFetch from pyathena.common import CursorIterator @@ -20,7 +19,7 @@ import polars as pl from pyarrow import Table -_logger = logging.getLogger(__name__) # type: ignore +_logger = logging.getLogger(__name__) class AioPolarsCursor(WithAsyncFetch): @@ -39,21 +38,21 @@ class AioPolarsCursor(WithAsyncFetch): def __init__( self, - s3_staging_dir: Optional[str] = None, - schema_name: Optional[str] = None, - catalog_name: Optional[str] = None, - work_group: Optional[str] = None, + s3_staging_dir: str | None = None, + schema_name: str | None = None, + catalog_name: str | None = None, + work_group: str | None = None, poll_interval: float = 1, - encryption_option: Optional[str] = None, - kms_key: Optional[str] = None, + encryption_option: str | None = None, + kms_key: str | None = None, kill_on_interrupt: bool = True, unload: bool = False, result_reuse_enable: bool = False, result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES, - block_size: Optional[int] = None, - cache_type: Optional[str] = None, + block_size: int | None = None, + cache_type: str | None = None, max_workers: int = (cpu_count() or 1) * 5, - chunksize: Optional[int] = None, + chunksize: int | None = None, **kwargs, ) -> None: super().__init__( @@ -74,12 +73,12 @@ def __init__( 
self._cache_type = cache_type self._max_workers = max_workers self._chunksize = chunksize - self._result_set: Optional[AthenaPolarsResultSet] = None + self._result_set: AthenaPolarsResultSet | None = None @staticmethod def get_default_converter( unload: bool = False, - ) -> Union[DefaultPolarsTypeConverter, DefaultPolarsUnloadTypeConverter, Any]: + ) -> DefaultPolarsTypeConverter | DefaultPolarsUnloadTypeConverter | Any: if unload: return DefaultPolarsUnloadTypeConverter() return DefaultPolarsTypeConverter() @@ -87,16 +86,16 @@ def get_default_converter( async def execute( # type: ignore[override] self, operation: str, - parameters: Optional[Union[Dict[str, Any], List[str]]] = None, - work_group: Optional[str] = None, - s3_staging_dir: Optional[str] = None, - cache_size: Optional[int] = 0, - cache_expiration_time: Optional[int] = 0, - result_reuse_enable: Optional[bool] = None, - result_reuse_minutes: Optional[int] = None, - paramstyle: Optional[str] = None, + parameters: dict[str, Any] | list[str] | None = None, + work_group: str | None = None, + s3_staging_dir: str | None = None, + cache_size: int | None = 0, + cache_expiration_time: int | None = 0, + result_reuse_enable: bool | None = None, + result_reuse_minutes: int | None = None, + paramstyle: str | None = None, **kwargs, - ) -> "AioPolarsCursor": + ) -> AioPolarsCursor: """Execute a SQL query asynchronously and return results as Polars DataFrames. Args: @@ -151,7 +150,7 @@ async def execute( # type: ignore[override] async def fetchone( # type: ignore[override] self, - ) -> Optional[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + ) -> tuple[Any | None, ...] | dict[Any, Any | None] | None: """Fetch the next row of the result set. 
Wraps the synchronous fetch in ``asyncio.to_thread`` to avoid @@ -169,8 +168,8 @@ async def fetchone( # type: ignore[override] return await asyncio.to_thread(result_set.fetchone) async def fetchmany( # type: ignore[override] - self, size: Optional[int] = None - ) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + self, size: int | None = None + ) -> list[tuple[Any | None, ...] | dict[Any, Any | None]]: """Fetch multiple rows from the result set. Wraps the synchronous fetch in ``asyncio.to_thread`` to avoid @@ -192,7 +191,7 @@ async def fetchmany( # type: ignore[override] async def fetchall( # type: ignore[override] self, - ) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + ) -> list[tuple[Any | None, ...] | dict[Any, Any | None]]: """Fetch all remaining rows from the result set. Wraps the synchronous fetch in ``asyncio.to_thread`` to avoid @@ -215,7 +214,7 @@ async def __anext__(self): raise StopAsyncIteration return row - def as_polars(self) -> "pl.DataFrame": + def as_polars(self) -> pl.DataFrame: """Return query results as a Polars DataFrame. Returns: @@ -226,7 +225,7 @@ def as_polars(self) -> "pl.DataFrame": result_set = cast(AthenaPolarsResultSet, self.result_set) return result_set.as_polars() - def as_arrow(self) -> "Table": + def as_arrow(self) -> Table: """Return query results as an Apache Arrow Table. 
Returns: diff --git a/pyathena/aio/result_set.py b/pyathena/aio/result_set.py index 776ac6d2..000337bd 100644 --- a/pyathena/aio/result_set.py +++ b/pyathena/aio/result_set.py @@ -1,15 +1,9 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import logging from typing import ( TYPE_CHECKING, Any, - Dict, - List, - Optional, - Tuple, - Union, cast, ) @@ -23,7 +17,7 @@ if TYPE_CHECKING: from pyathena.connection import Connection -_logger = logging.getLogger(__name__) # type: ignore +_logger = logging.getLogger(__name__) class AthenaAioResultSet(AthenaResultSet): @@ -36,7 +30,7 @@ class AthenaAioResultSet(AthenaResultSet): def __init__( self, - connection: "Connection[Any]", + connection: Connection[Any], converter: Converter, query_execution: AthenaQueryExecution, arraysize: int, @@ -54,12 +48,12 @@ def __init__( @classmethod async def create( cls, - connection: "Connection[Any]", + connection: Connection[Any], converter: Converter, query_execution: AthenaQueryExecution, arraysize: int, retry_config: RetryConfig, - ) -> "AthenaAioResultSet": + ) -> AthenaAioResultSet: """Async factory method. Creates an ``AthenaAioResultSet`` and awaits the initial data fetch. 
@@ -80,15 +74,15 @@ async def create( return result_set async def __async_get_query_results( - self, max_results: int, next_token: Optional[str] = None - ) -> Dict[str, Any]: + self, max_results: int, next_token: str | None = None + ) -> dict[str, Any]: if not self.query_id: raise ProgrammingError("QueryExecutionId is none or empty.") if self.state != AthenaQueryExecution.STATE_SUCCEEDED: raise ProgrammingError("QueryExecutionState is not SUCCEEDED.") if self.is_closed: raise ProgrammingError("AthenaAioResultSet is closed.") - request: Dict[str, Any] = { + request: dict[str, Any] = { "QueryExecutionId": self.query_id, "MaxResults": max_results, } @@ -105,9 +99,9 @@ async def __async_get_query_results( _logger.exception("Failed to fetch result set.") raise OperationalError(*e.args) from e else: - return cast(Dict[str, Any], response) + return cast(dict[str, Any], response) - async def __async_fetch(self, next_token: Optional[str] = None) -> Dict[str, Any]: + async def __async_fetch(self, next_token: str | None = None) -> dict[str, Any]: return await self.__async_get_query_results(self._arraysize, next_token) async def _async_fetch(self) -> None: @@ -127,7 +121,7 @@ async def _async_pre_fetch(self) -> None: async def fetchone( # type: ignore[override] self, - ) -> Optional[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + ) -> tuple[Any | None, ...] | dict[Any, Any | None] | None: """Fetch the next row of the result set. Automatically fetches the next page from Athena when the current @@ -146,8 +140,8 @@ async def fetchone( # type: ignore[override] return self._rows.popleft() async def fetchmany( # type: ignore[override] - self, size: Optional[int] = None - ) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + self, size: int | None = None + ) -> list[tuple[Any | None, ...] | dict[Any, Any | None]]: """Fetch multiple rows from the result set. 
Args: @@ -170,7 +164,7 @@ async def fetchmany( # type: ignore[override] async def fetchall( # type: ignore[override] self, - ) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + ) -> list[tuple[Any | None, ...] | dict[Any, Any | None]]: """Fetch all remaining rows from the result set. Returns: diff --git a/pyathena/aio/s3fs/__init__.py b/pyathena/aio/s3fs/__init__.py index 40a96afc..e69de29b 100644 --- a/pyathena/aio/s3fs/__init__.py +++ b/pyathena/aio/s3fs/__init__.py @@ -1 +0,0 @@ -# -*- coding: utf-8 -*- diff --git a/pyathena/aio/s3fs/cursor.py b/pyathena/aio/s3fs/cursor.py index 40ebcd78..cf9c92a8 100644 --- a/pyathena/aio/s3fs/cursor.py +++ b/pyathena/aio/s3fs/cursor.py @@ -1,9 +1,8 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import asyncio import logging -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from typing import Any, cast from pyathena.aio.common import WithAsyncFetch from pyathena.common import CursorIterator @@ -13,7 +12,7 @@ from pyathena.s3fs.converter import DefaultS3FSTypeConverter from pyathena.s3fs.result_set import AthenaS3FSResultSet, CSVReaderType -_logger = logging.getLogger(__name__) # type: ignore +_logger = logging.getLogger(__name__) class AioS3FSCursor(WithAsyncFetch): @@ -33,17 +32,17 @@ class AioS3FSCursor(WithAsyncFetch): def __init__( self, - s3_staging_dir: Optional[str] = None, - schema_name: Optional[str] = None, - catalog_name: Optional[str] = None, - work_group: Optional[str] = None, + s3_staging_dir: str | None = None, + schema_name: str | None = None, + catalog_name: str | None = None, + work_group: str | None = None, poll_interval: float = 1, - encryption_option: Optional[str] = None, - kms_key: Optional[str] = None, + encryption_option: str | None = None, + kms_key: str | None = None, kill_on_interrupt: bool = True, result_reuse_enable: bool = False, result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES, - csv_reader: Optional[CSVReaderType] = 
None, + csv_reader: CSVReaderType | None = None, **kwargs, ) -> None: super().__init__( @@ -60,11 +59,11 @@ def __init__( **kwargs, ) self._csv_reader = csv_reader - self._result_set: Optional[AthenaS3FSResultSet] = None + self._result_set: AthenaS3FSResultSet | None = None @staticmethod def get_default_converter( - unload: bool = False, # noqa: ARG004 + unload: bool = False, ) -> DefaultS3FSTypeConverter: """Get the default type converter for S3FS cursor. @@ -79,16 +78,16 @@ def get_default_converter( async def execute( # type: ignore[override] self, operation: str, - parameters: Optional[Union[Dict[str, Any], List[str]]] = None, - work_group: Optional[str] = None, - s3_staging_dir: Optional[str] = None, - cache_size: Optional[int] = 0, - cache_expiration_time: Optional[int] = 0, - result_reuse_enable: Optional[bool] = None, - result_reuse_minutes: Optional[int] = None, - paramstyle: Optional[str] = None, + parameters: dict[str, Any] | list[str] | None = None, + work_group: str | None = None, + s3_staging_dir: str | None = None, + cache_size: int | None = 0, + cache_expiration_time: int | None = 0, + result_reuse_enable: bool | None = None, + result_reuse_minutes: int | None = None, + paramstyle: str | None = None, **kwargs, - ) -> "AioS3FSCursor": + ) -> AioS3FSCursor: """Execute a SQL query asynchronously via S3FileSystem CSV reader. Args: @@ -138,7 +137,7 @@ async def execute( # type: ignore[override] async def fetchone( # type: ignore[override] self, - ) -> Optional[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + ) -> tuple[Any | None, ...] | dict[Any, Any | None] | None: """Fetch the next row of the result set. 
Wraps the synchronous fetch in ``asyncio.to_thread`` because @@ -156,8 +155,8 @@ async def fetchone( # type: ignore[override] return await asyncio.to_thread(result_set.fetchone) async def fetchmany( # type: ignore[override] - self, size: Optional[int] = None - ) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + self, size: int | None = None + ) -> list[tuple[Any | None, ...] | dict[Any, Any | None]]: """Fetch multiple rows from the result set. Wraps the synchronous fetch in ``asyncio.to_thread`` because @@ -179,7 +178,7 @@ async def fetchmany( # type: ignore[override] async def fetchall( # type: ignore[override] self, - ) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + ) -> list[tuple[Any | None, ...] | dict[Any, Any | None]]: """Fetch all remaining rows from the result set. Wraps the synchronous fetch in ``asyncio.to_thread`` because diff --git a/pyathena/aio/spark/__init__.py b/pyathena/aio/spark/__init__.py index 40a96afc..e69de29b 100644 --- a/pyathena/aio/spark/__init__.py +++ b/pyathena/aio/spark/__init__.py @@ -1 +0,0 @@ -# -*- coding: utf-8 -*- diff --git a/pyathena/aio/spark/cursor.py b/pyathena/aio/spark/cursor.py index ab9c9865..c414eba1 100644 --- a/pyathena/aio/spark/cursor.py +++ b/pyathena/aio/spark/cursor.py @@ -1,9 +1,8 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import asyncio import logging -from typing import Any, Dict, List, Optional, Union, cast +from typing import Any, cast from pyathena.aio.util import async_retry_api_call from pyathena.error import DatabaseError, NotSupportedError, OperationalError, ProgrammingError @@ -15,7 +14,7 @@ from pyathena.spark.common import SparkBaseCursor, WithCalculationExecution from pyathena.util import parse_output_location -_logger = logging.getLogger(__name__) # type: ignore +_logger = logging.getLogger(__name__) class AioSparkCursor(SparkBaseCursor, WithCalculationExecution): @@ -44,11 +43,11 @@ class AioSparkCursor(SparkBaseCursor, 
WithCalculationExecution): def __init__( self, - session_id: Optional[str] = None, - description: Optional[str] = None, - engine_configuration: Optional[Dict[str, Any]] = None, - notebook_version: Optional[str] = None, - session_idle_timeout_minutes: Optional[int] = None, + session_id: str | None = None, + description: str | None = None, + engine_configuration: dict[str, Any] | None = None, + notebook_version: str | None = None, + session_idle_timeout_minutes: int | None = None, **kwargs, ) -> None: super().__init__( @@ -61,7 +60,7 @@ def __init__( ) @property - def calculation_execution(self) -> Optional[AthenaCalculationExecution]: + def calculation_execution(self) -> AthenaCalculationExecution | None: return self._calculation_execution # --- async overrides of SparkBaseCursor I/O methods --- @@ -69,7 +68,7 @@ def calculation_execution(self) -> Optional[AthenaCalculationExecution]: async def _get_calculation_execution_status( # type: ignore[override] self, query_id: str ) -> AthenaCalculationExecutionStatus: - request: Dict[str, Any] = {"CalculationExecutionId": query_id} + request: dict[str, Any] = {"CalculationExecutionId": query_id} try: response = await async_retry_api_call( self._connection.client.get_calculation_execution_status, @@ -86,7 +85,7 @@ async def _get_calculation_execution_status( # type: ignore[override] async def _get_calculation_execution( # type: ignore[override] self, query_id: str ) -> AthenaCalculationExecution: - request: Dict[str, Any] = {"CalculationExecutionId": query_id} + request: dict[str, Any] = {"CalculationExecutionId": query_id} try: response = await async_retry_api_call( self._connection.client.get_calculation_execution, @@ -104,8 +103,8 @@ async def _calculate( # type: ignore[override] self, session_id: str, code_block: str, - description: Optional[str] = None, - client_request_token: Optional[str] = None, + description: str | None = None, + client_request_token: str | None = None, ) -> str: request = 
self._build_start_calculation_execution_request( session_id=session_id, @@ -126,9 +125,7 @@ async def _calculate( # type: ignore[override] raise DatabaseError(*e.args) from e return cast(str, calculation_id) - async def __poll( - self, query_id: str - ) -> Union[AthenaQueryExecution, AthenaCalculationExecution]: + async def __poll(self, query_id: str) -> AthenaQueryExecution | AthenaCalculationExecution: while True: calculation_status = await self._get_calculation_execution_status(query_id) if calculation_status.state in [ @@ -141,7 +138,7 @@ async def __poll( async def _poll( # type: ignore[override] self, query_id: str - ) -> Union[AthenaQueryExecution, AthenaCalculationExecution]: + ) -> AthenaQueryExecution | AthenaCalculationExecution: try: query_execution = await self.__poll(query_id) except asyncio.CancelledError: @@ -154,7 +151,7 @@ async def _poll( # type: ignore[override] return query_execution async def _cancel(self, query_id: str) -> None: # type: ignore[override] - request: Dict[str, Any] = {"CalculationExecutionId": query_id} + request: dict[str, Any] = {"CalculationExecutionId": query_id} try: await async_retry_api_call( self._connection.client.stop_calculation_execution, @@ -167,7 +164,7 @@ async def _cancel(self, query_id: str) -> None: # type: ignore[override] raise OperationalError(*e.args) from e async def _terminate_session(self) -> None: # type: ignore[override] - request: Dict[str, Any] = {"SessionId": self._session_id} + request: dict[str, Any] = {"SessionId": self._session_id} try: await async_retry_api_call( self._connection.client.terminate_session, @@ -192,7 +189,7 @@ async def _read_s3_file_as_text(self, uri) -> str: # type: ignore[override] # --- public API --- - async def get_std_out(self) -> Optional[str]: + async def get_std_out(self) -> str | None: """Get the standard output from the Spark calculation execution. 
Returns: @@ -202,7 +199,7 @@ async def get_std_out(self) -> Optional[str]: return None return await self._read_s3_file_as_text(self._calculation_execution.std_out_s3_uri) - async def get_std_error(self) -> Optional[str]: + async def get_std_error(self) -> str | None: """Get the standard error from the Spark calculation execution. Returns: @@ -215,13 +212,13 @@ async def get_std_error(self) -> Optional[str]: async def execute( # type: ignore[override] self, operation: str, - parameters: Optional[Union[Dict[str, Any], List[str]]] = None, - session_id: Optional[str] = None, - description: Optional[str] = None, - client_request_token: Optional[str] = None, - work_group: Optional[str] = None, + parameters: dict[str, Any] | list[str] | None = None, + session_id: str | None = None, + description: str | None = None, + client_request_token: str | None = None, + work_group: str | None = None, **kwargs, - ) -> "AioSparkCursor": + ) -> AioSparkCursor: """Execute PySpark code asynchronously. Args: @@ -267,7 +264,7 @@ async def close(self) -> None: # type: ignore[override] async def executemany( # type: ignore[override] self, operation: str, - seq_of_parameters: List[Optional[Union[Dict[str, Any], List[str]]]], + seq_of_parameters: list[dict[str, Any] | list[str] | None], **kwargs, ) -> None: raise NotSupportedError diff --git a/pyathena/aio/sqlalchemy/__init__.py b/pyathena/aio/sqlalchemy/__init__.py index 40a96afc..e69de29b 100644 --- a/pyathena/aio/sqlalchemy/__init__.py +++ b/pyathena/aio/sqlalchemy/__init__.py @@ -1 +0,0 @@ -# -*- coding: utf-8 -*- diff --git a/pyathena/aio/sqlalchemy/arrow.py b/pyathena/aio/sqlalchemy/arrow.py index 91c9f0d5..8338fdff 100644 --- a/pyathena/aio/sqlalchemy/arrow.py +++ b/pyathena/aio/sqlalchemy/arrow.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- from typing import TYPE_CHECKING from pyathena.aio.sqlalchemy.base import AthenaAioDialect diff --git a/pyathena/aio/sqlalchemy/base.py b/pyathena/aio/sqlalchemy/base.py index 42d69e4c..a0308937 100644 
--- a/pyathena/aio/sqlalchemy/base.py +++ b/pyathena/aio/sqlalchemy/base.py @@ -1,8 +1,8 @@ -# -*- coding: utf-8 -*- from __future__ import annotations from collections import deque -from typing import TYPE_CHECKING, Any, Dict, List, MutableMapping, Optional, Tuple, Union, cast +from collections.abc import MutableMapping +from typing import TYPE_CHECKING, Any, cast from sqlalchemy import pool from sqlalchemy.engine import AdaptedConnection @@ -29,7 +29,7 @@ from sqlalchemy import URL -class AsyncAdapt_pyathena_cursor: # noqa: N801 - follows SQLAlchemy's internal async adapter naming convention (e.g. AsyncAdapt_asyncpg_dbapi) +class AsyncAdapt_pyathena_cursor: """Wraps any async PyAthena cursor with a sync DBAPI interface. SQLAlchemy's async engine uses greenlet-based ``await_only()`` to call @@ -68,7 +68,7 @@ def execute(self, operation: str, parameters: Any = None, **kwargs: Any) -> Any: def executemany( self, operation: str, - seq_of_parameters: List[Optional[Union[Dict[str, Any], List[str]]]], + seq_of_parameters: list[dict[str, Any] | list[str] | None], **kwargs: Any, ) -> None: for parameters in seq_of_parameters: @@ -80,7 +80,7 @@ def fetchone(self) -> Any: return self._rows.popleft() return None - def fetchmany(self, size: Optional[int] = None) -> Any: + def fetchmany(self, size: int | None = None) -> Any: if size is None: size = self._cursor.arraysize if hasattr(self._cursor, "arraysize") else 1 return [self._rows.popleft() for _ in range(min(size, len(self._rows)))] @@ -106,14 +106,14 @@ def get_table_metadata(self, *args: Any, **kwargs: Any) -> Any: def list_table_metadata(self, *args: Any, **kwargs: Any) -> Any: return await_only(self._cursor.list_table_metadata(*args, **kwargs)) - def __enter__(self) -> "AsyncAdapt_pyathena_cursor": + def __enter__(self) -> AsyncAdapt_pyathena_cursor: return self def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: self.close() -class AsyncAdapt_pyathena_connection(AdaptedConnection): # noqa: N801 - 
follows SQLAlchemy's internal async adapter naming convention (e.g. AsyncAdapt_asyncpg_dbapi) +class AsyncAdapt_pyathena_connection(AdaptedConnection): """Wraps ``AioConnection`` with a sync DBAPI interface. This adapted connection delegates ``cursor()`` to the underlying @@ -121,9 +121,9 @@ class AsyncAdapt_pyathena_connection(AdaptedConnection): # noqa: N801 - follows ``AsyncAdapt_pyathena_cursor``. """ - __slots__ = ("dbapi", "_connection") + __slots__ = ("_connection", "dbapi") - def __init__(self, dbapi: "AsyncAdapt_pyathena_dbapi", connection: AioConnection) -> None: + def __init__(self, dbapi: AsyncAdapt_pyathena_dbapi, connection: AioConnection) -> None: self.dbapi = dbapi self._connection = connection # type: ignore[assignment] @@ -132,11 +132,11 @@ def driver_connection(self) -> AioConnection: return self._connection # type: ignore[return-value] @property - def catalog_name(self) -> Optional[str]: + def catalog_name(self) -> str | None: return self._connection.catalog_name # type: ignore[no-any-return] @property - def schema_name(self) -> Optional[str]: + def schema_name(self) -> str | None: return self._connection.schema_name # type: ignore[no-any-return] def cursor(self) -> AsyncAdapt_pyathena_cursor: @@ -153,7 +153,7 @@ def rollback(self) -> None: pass -class AsyncAdapt_pyathena_dbapi: # noqa: N801 - follows SQLAlchemy's internal async adapter naming convention (e.g. AsyncAdapt_asyncpg_dbapi) +class AsyncAdapt_pyathena_dbapi: """Fake DBAPI module for the async SQLAlchemy engine. 
SQLAlchemy expects ``import_dbapi()`` to return a module-like object @@ -201,21 +201,21 @@ class AthenaAioDialect(AthenaDialect): supports_statement_cache = True @classmethod - def get_pool_class(cls, url: "URL") -> type: + def get_pool_class(cls, url: URL) -> type: return pool.AsyncAdaptedQueuePool @classmethod - def import_dbapi(cls) -> "ModuleType": + def import_dbapi(cls) -> ModuleType: return AsyncAdapt_pyathena_dbapi() # type: ignore[return-value] @classmethod - def dbapi(cls) -> "ModuleType": # type: ignore[override] + def dbapi(cls) -> ModuleType: # type: ignore[override] return AsyncAdapt_pyathena_dbapi() # type: ignore[return-value] - def create_connect_args(self, url: "URL") -> Tuple[Tuple[str], MutableMapping[str, Any]]: + def create_connect_args(self, url: URL) -> tuple[tuple[str], MutableMapping[str, Any]]: opts = self._create_connect_args(url) self._connect_options = opts - return cast(Tuple[str], ()), opts + return cast(tuple[str], ()), opts def get_driver_connection(self, connection: Any) -> Any: return connection diff --git a/pyathena/aio/sqlalchemy/pandas.py b/pyathena/aio/sqlalchemy/pandas.py index bd6bad71..6d60e122 100644 --- a/pyathena/aio/sqlalchemy/pandas.py +++ b/pyathena/aio/sqlalchemy/pandas.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- from typing import TYPE_CHECKING from pyathena.aio.sqlalchemy.base import AthenaAioDialect @@ -50,7 +49,7 @@ def create_connect_args(self, url): if "engine" in opts: cursor_kwargs.update({"engine": opts.pop("engine")}) if "chunksize" in opts: - cursor_kwargs.update({"chunksize": int(opts.pop("chunksize"))}) # type: ignore + cursor_kwargs.update({"chunksize": int(opts.pop("chunksize"))}) # type: ignore[dict-item] if cursor_kwargs: opts.update({"cursor_kwargs": cursor_kwargs}) self._connect_options = opts diff --git a/pyathena/aio/sqlalchemy/polars.py b/pyathena/aio/sqlalchemy/polars.py index d7daa7e9..938c2d54 100644 --- a/pyathena/aio/sqlalchemy/polars.py +++ b/pyathena/aio/sqlalchemy/polars.py @@ -1,4 +1,3 
@@ -# -*- coding: utf-8 -*- from typing import TYPE_CHECKING from pyathena.aio.sqlalchemy.base import AthenaAioDialect diff --git a/pyathena/aio/sqlalchemy/rest.py b/pyathena/aio/sqlalchemy/rest.py index 86c93f5c..ce47f406 100644 --- a/pyathena/aio/sqlalchemy/rest.py +++ b/pyathena/aio/sqlalchemy/rest.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- from typing import TYPE_CHECKING from pyathena.aio.sqlalchemy.base import AthenaAioDialect diff --git a/pyathena/aio/sqlalchemy/s3fs.py b/pyathena/aio/sqlalchemy/s3fs.py index 945b016e..dd6f4e5d 100644 --- a/pyathena/aio/sqlalchemy/s3fs.py +++ b/pyathena/aio/sqlalchemy/s3fs.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- from typing import TYPE_CHECKING from pyathena.aio.sqlalchemy.base import AthenaAioDialect diff --git a/pyathena/aio/util.py b/pyathena/aio/util.py index e46f1333..7ff88d68 100644 --- a/pyathena/aio/util.py +++ b/pyathena/aio/util.py @@ -1,9 +1,9 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import asyncio import logging -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any from pyathena.util import RetryConfig, retry_api_call @@ -11,7 +11,7 @@ async def async_retry_api_call( func: Callable[..., Any], config: RetryConfig, - logger: Optional[logging.Logger] = None, + logger: logging.Logger | None = None, *args: Any, **kwargs: Any, ) -> Any: diff --git a/pyathena/arrow/__init__.py b/pyathena/arrow/__init__.py index 40a96afc..e69de29b 100644 --- a/pyathena/arrow/__init__.py +++ b/pyathena/arrow/__init__.py @@ -1 +0,0 @@ -# -*- coding: utf-8 -*- diff --git a/pyathena/arrow/async_cursor.py b/pyathena/arrow/async_cursor.py index 32b2215c..39d2950b 100644 --- a/pyathena/arrow/async_cursor.py +++ b/pyathena/arrow/async_cursor.py @@ -1,10 +1,9 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import logging from concurrent.futures import Future from multiprocessing import cpu_count -from typing import Any, Dict, List, Optional, Tuple, 
Union, cast +from typing import Any, cast from pyathena import ProgrammingError from pyathena.arrow.converter import ( @@ -16,7 +15,7 @@ from pyathena.common import CursorIterator from pyathena.model import AthenaQueryExecution -_logger = logging.getLogger(__name__) # type: ignore +_logger = logging.getLogger(__name__) class AsyncArrowCursor(AsyncCursor): @@ -61,21 +60,21 @@ class AsyncArrowCursor(AsyncCursor): def __init__( self, - s3_staging_dir: Optional[str] = None, - schema_name: Optional[str] = None, - catalog_name: Optional[str] = None, - work_group: Optional[str] = None, + s3_staging_dir: str | None = None, + schema_name: str | None = None, + catalog_name: str | None = None, + work_group: str | None = None, poll_interval: float = 1, - encryption_option: Optional[str] = None, - kms_key: Optional[str] = None, + encryption_option: str | None = None, + kms_key: str | None = None, kill_on_interrupt: bool = True, max_workers: int = (cpu_count() or 1) * 5, arraysize: int = CursorIterator.DEFAULT_FETCH_SIZE, unload: bool = False, result_reuse_enable: bool = False, result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES, - connect_timeout: Optional[float] = None, - request_timeout: Optional[float] = None, + connect_timeout: float | None = None, + request_timeout: float | None = None, **kwargs, ) -> None: """Initialize an AsyncArrowCursor. 
@@ -132,7 +131,7 @@ def __init__( @staticmethod def get_default_converter( unload: bool = False, - ) -> Union[DefaultArrowTypeConverter, DefaultArrowUnloadTypeConverter, Any]: + ) -> DefaultArrowTypeConverter | DefaultArrowUnloadTypeConverter | Any: if unload: return DefaultArrowUnloadTypeConverter() return DefaultArrowTypeConverter() @@ -150,8 +149,8 @@ def arraysize(self, value: int) -> None: def _collect_result_set( self, query_id: str, - unload_location: Optional[str] = None, - kwargs: Optional[Dict[str, Any]] = None, + unload_location: str | None = None, + kwargs: dict[str, Any] | None = None, ) -> AthenaArrowResultSet: if kwargs is None: kwargs = {} @@ -172,16 +171,16 @@ def _collect_result_set( def execute( self, operation: str, - parameters: Optional[Union[Dict[str, Any], List[str]]] = None, - work_group: Optional[str] = None, - s3_staging_dir: Optional[str] = None, - cache_size: Optional[int] = 0, - cache_expiration_time: Optional[int] = 0, - result_reuse_enable: Optional[bool] = None, - result_reuse_minutes: Optional[int] = None, - paramstyle: Optional[str] = None, + parameters: dict[str, Any] | list[str] | None = None, + work_group: str | None = None, + s3_staging_dir: str | None = None, + cache_size: int | None = 0, + cache_expiration_time: int | None = 0, + result_reuse_enable: bool | None = None, + result_reuse_minutes: int | None = None, + paramstyle: str | None = None, **kwargs, - ) -> Tuple[str, "Future[Union[AthenaArrowResultSet, Any]]"]: + ) -> tuple[str, Future[AthenaArrowResultSet | Any]]: operation, unload_location = self._prepare_unload(operation, s3_staging_dir) query_id = self._execute( operation, diff --git a/pyathena/arrow/converter.py b/pyathena/arrow/converter.py index f26dbeb3..1e5d5d91 100644 --- a/pyathena/arrow/converter.py +++ b/pyathena/arrow/converter.py @@ -1,9 +1,9 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import logging +from collections.abc import Callable from copy import deepcopy -from typing import 
Any, Callable, Dict, Optional, Type +from typing import Any from pyathena.converter import ( Converter, @@ -15,10 +15,10 @@ _to_time, ) -_logger = logging.getLogger(__name__) # type: ignore +_logger = logging.getLogger(__name__) -_DEFAULT_ARROW_CONVERTERS: Dict[str, Callable[[Optional[str]], Optional[Any]]] = { +_DEFAULT_ARROW_CONVERTERS: dict[str, Callable[[str | None], Any | None]] = { "date": _to_date, "time": _to_time, "decimal": _to_decimal, @@ -62,7 +62,7 @@ def __init__(self) -> None: ) @property - def _dtypes(self) -> Dict[str, Type[Any]]: + def _dtypes(self) -> dict[str, type[Any]]: if not hasattr(self, "__dtypes"): import pyarrow as pa @@ -90,7 +90,7 @@ def _dtypes(self) -> Dict[str, Type[Any]]: } return self.__dtypes - def convert(self, type_: str, value: Optional[str]) -> Optional[Any]: + def convert(self, type_: str, value: str | None) -> Any | None: converter = self.get(type_) return converter(value) @@ -114,5 +114,5 @@ def __init__(self) -> None: default=_to_default, ) - def convert(self, type_: str, value: Optional[str]) -> Optional[Any]: + def convert(self, type_: str, value: str | None) -> Any | None: pass diff --git a/pyathena/arrow/cursor.py b/pyathena/arrow/cursor.py index 026e0951..8f831831 100644 --- a/pyathena/arrow/cursor.py +++ b/pyathena/arrow/cursor.py @@ -1,8 +1,8 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union, cast +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, cast from pyathena.arrow.converter import ( DefaultArrowTypeConverter, @@ -18,7 +18,7 @@ import polars as pl from pyarrow import Table -_logger = logging.getLogger(__name__) # type: ignore +_logger = logging.getLogger(__name__) class ArrowCursor(WithFetch): @@ -52,20 +52,20 @@ class ArrowCursor(WithFetch): def __init__( self, - s3_staging_dir: Optional[str] = None, - schema_name: Optional[str] = None, - catalog_name: Optional[str] = None, - 
work_group: Optional[str] = None, + s3_staging_dir: str | None = None, + schema_name: str | None = None, + catalog_name: str | None = None, + work_group: str | None = None, poll_interval: float = 1, - encryption_option: Optional[str] = None, - kms_key: Optional[str] = None, + encryption_option: str | None = None, + kms_key: str | None = None, kill_on_interrupt: bool = True, unload: bool = False, result_reuse_enable: bool = False, result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES, - on_start_query_execution: Optional[Callable[[str], None]] = None, - connect_timeout: Optional[float] = None, - request_timeout: Optional[float] = None, + on_start_query_execution: Callable[[str], None] | None = None, + connect_timeout: float | None = None, + request_timeout: float | None = None, **kwargs, ) -> None: """Initialize an ArrowCursor. @@ -120,7 +120,7 @@ def __init__( @staticmethod def get_default_converter( unload: bool = False, - ) -> Union[DefaultArrowTypeConverter, DefaultArrowUnloadTypeConverter, Any]: + ) -> DefaultArrowTypeConverter | DefaultArrowUnloadTypeConverter | Any: if unload: return DefaultArrowUnloadTypeConverter() return DefaultArrowTypeConverter() @@ -128,15 +128,15 @@ def get_default_converter( def execute( self, operation: str, - parameters: Optional[Union[Dict[str, Any], List[str]]] = None, - work_group: Optional[str] = None, - s3_staging_dir: Optional[str] = None, - cache_size: Optional[int] = 0, - cache_expiration_time: Optional[int] = 0, - result_reuse_enable: Optional[bool] = None, - result_reuse_minutes: Optional[int] = None, - paramstyle: Optional[str] = None, - on_start_query_execution: Optional[Callable[[str], None]] = None, + parameters: dict[str, Any] | list[str] | None = None, + work_group: str | None = None, + s3_staging_dir: str | None = None, + cache_size: int | None = 0, + cache_expiration_time: int | None = 0, + result_reuse_enable: bool | None = None, + result_reuse_minutes: int | None = None, + paramstyle: str | None 
= None, + on_start_query_execution: Callable[[str], None] | None = None, **kwargs, ) -> ArrowCursor: """Execute a SQL query and return results as Apache Arrow Tables. @@ -203,7 +203,7 @@ def execute( raise OperationalError(query_execution.state_change_reason) return self - def as_arrow(self) -> "Table": + def as_arrow(self) -> Table: """Return query results as an Apache Arrow Table. Converts the entire result set into an Apache Arrow Table for efficient @@ -227,7 +227,7 @@ def as_arrow(self) -> "Table": result_set = cast(AthenaArrowResultSet, self.result_set) return result_set.as_arrow() - def as_polars(self) -> "pl.DataFrame": + def as_polars(self) -> pl.DataFrame: """Return query results as a Polars DataFrame. Converts the Apache Arrow Table to a Polars DataFrame for diff --git a/pyathena/arrow/result_set.py b/pyathena/arrow/result_set.py index ea68ff9e..12fab1ac 100644 --- a/pyathena/arrow/result_set.py +++ b/pyathena/arrow/result_set.py @@ -1,17 +1,11 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import logging +from collections.abc import Callable from typing import ( TYPE_CHECKING, Any, - Callable, - Dict, - List, - Optional, - Tuple, - Type, - Union, + ClassVar, ) from pyathena import OperationalError @@ -28,7 +22,7 @@ from pyathena.connection import Connection -_logger = logging.getLogger(__name__) # type: ignore +_logger = logging.getLogger(__name__) class AthenaArrowResultSet(AthenaResultSet): @@ -69,7 +63,7 @@ class AthenaArrowResultSet(AthenaResultSet): DEFAULT_BLOCK_SIZE = 1024 * 1024 * 128 - _timestamp_parsers: List[str] = [ + _timestamp_parsers: ClassVar[list[str]] = [ "%Y-%m-%d", "%Y-%m-%d %H:%M:%S", "%Y-%m-%d %H:%M:%S %Z", @@ -87,16 +81,16 @@ class AthenaArrowResultSet(AthenaResultSet): def __init__( self, - connection: "Connection[Any]", + connection: Connection[Any], converter: Converter, query_execution: AthenaQueryExecution, arraysize: int, retry_config: RetryConfig, - block_size: Optional[int] = None, + block_size: int | None 
= None, unload: bool = False, - unload_location: Optional[str] = None, - connect_timeout: Optional[float] = None, - request_timeout: Optional[float] = None, + unload_location: str | None = None, + connect_timeout: float | None = None, + request_timeout: float | None = None, **kwargs, ) -> None: super().__init__( @@ -137,7 +131,7 @@ def __s3_file_system(self): if self._request_timeout is not None: timeout_kwargs["request_timeout"] = self._request_timeout - if "role_arn" in connection._kwargs and connection._kwargs["role_arn"]: + if connection._kwargs.get("role_arn"): external_id = connection._kwargs.get("external_id") fs = fs.S3FileSystem( role_arn=connection._kwargs["role_arn"], @@ -193,13 +187,13 @@ def __s3_file_system(self): return fs @property - def timestamp_parsers(self) -> List[str]: + def timestamp_parsers(self) -> list[str]: from pyarrow.csv import ISO8601 - return [ISO8601] + self._timestamp_parsers + return [ISO8601, *self._timestamp_parsers] @property - def column_types(self) -> Dict[str, Type[Any]]: + def column_types(self) -> dict[str, type[Any]]: description = self.description if self.description else [] return { d[0]: dtype @@ -208,7 +202,7 @@ def column_types(self) -> Dict[str, Type[Any]]: } @property - def converters(self) -> Dict[str, Callable[[Optional[str]], Optional[Any]]]: + def converters(self) -> dict[str, Callable[[str | None], Any | None]]: description = self.description if self.description else [] return {d[0]: self._converter.get(d[1]) for d in description} @@ -228,7 +222,7 @@ def _fetch(self) -> None: def fetchone( self, - ) -> Optional[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + ) -> tuple[Any | None, ...] 
| dict[Any, Any | None] | None: if not self._rows: self._fetch() if not self._rows: @@ -239,8 +233,8 @@ def fetchone( return self._rows.popleft() def fetchmany( - self, size: Optional[int] = None - ) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + self, size: int | None = None + ) -> list[tuple[Any | None, ...] | dict[Any, Any | None]]: if not size or size <= 0: size = self._arraysize rows = [] @@ -254,7 +248,7 @@ def fetchmany( def fetchall( self, - ) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + ) -> list[tuple[Any | None, ...] | dict[Any, Any | None]]: rows = [] while True: row = self.fetchone() @@ -264,7 +258,7 @@ def fetchall( break return rows - def _read_csv(self) -> "Table": + def _read_csv(self) -> Table: import pyarrow as pa from pyarrow import csv @@ -319,10 +313,10 @@ def _read_csv(self) -> "Table": ), ) except Exception as e: - _logger.exception(f"Failed to read {bucket}/{key}.") + _logger.exception("Failed to read %s/%s.", bucket, key) raise OperationalError(*e.args) from e - def _read_parquet(self) -> "Table": + def _read_parquet(self) -> Table: import pyarrow as pa from pyarrow import parquet @@ -337,10 +331,10 @@ def _read_parquet(self) -> "Table": dataset = parquet.ParquetDataset(f"{bucket}/{key}", filesystem=self._fs) return dataset.read(use_threads=True) except Exception as e: - _logger.exception(f"Failed to read {bucket}/{key}.") + _logger.exception("Failed to read %s/%s.", bucket, key) raise OperationalError(*e.args) from e - def _as_arrow(self) -> "Table": + def _as_arrow(self) -> Table: if self.is_unload: table = self._read_parquet() self._metadata = to_column_info(table.schema) @@ -348,7 +342,7 @@ def _as_arrow(self) -> "Table": table = self._read_csv() return table - def _as_arrow_from_api(self, converter: Optional[Converter] = None) -> "Table": + def _as_arrow_from_api(self, converter: Converter | None = None) -> Table: """Build an Arrow Table from GetQueryResults API. 
Used as a fallback when ``output_location`` is not available @@ -367,10 +361,10 @@ def _as_arrow_from_api(self, converter: Optional[Converter] = None) -> "Table": columns = [d[0] for d in description] return pa.table(self._rows_to_columnar(rows, columns)) - def as_arrow(self) -> "Table": + def as_arrow(self) -> Table: return self._table - def as_polars(self) -> "pl.DataFrame": + def as_polars(self) -> pl.DataFrame: """Return query results as a Polars DataFrame. Converts the Apache Arrow Table to a Polars DataFrame for diff --git a/pyathena/arrow/util.py b/pyathena/arrow/util.py index 4bd66843..9497858d 100644 --- a/pyathena/arrow/util.py +++ b/pyathena/arrow/util.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Utilities for converting PyArrow types to Athena metadata. This module provides functions to convert PyArrow schema and type information @@ -8,14 +7,14 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, Tuple, cast +from typing import TYPE_CHECKING, Any, cast if TYPE_CHECKING: from pyarrow import Schema from pyarrow.lib import DataType -def to_column_info(schema: "Schema") -> Tuple[Dict[str, Any], ...]: +def to_column_info(schema: Schema) -> tuple[dict[str, Any], ...]: """Convert a PyArrow schema to Athena column information. Iterates through all fields in the schema and converts each field's @@ -47,7 +46,7 @@ def to_column_info(schema: "Schema") -> Tuple[Dict[str, Any], ...]: return tuple(columns) -def get_athena_type(type_: "DataType") -> Tuple[str, int, int]: +def get_athena_type(type_: DataType) -> tuple[str, int, int]: """Map a PyArrow data type to an Athena SQL type. 
Converts PyArrow type identifiers to corresponding Athena SQL type names diff --git a/pyathena/async_cursor.py b/pyathena/async_cursor.py index 4e8a28ea..1716cc3c 100644 --- a/pyathena/async_cursor.py +++ b/pyathena/async_cursor.py @@ -1,18 +1,17 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import logging from concurrent.futures import Future from concurrent.futures.thread import ThreadPoolExecutor from multiprocessing import cpu_count -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from typing import Any, cast from pyathena.common import BaseCursor, CursorIterator from pyathena.error import NotSupportedError, ProgrammingError from pyathena.model import AthenaQueryExecution from pyathena.result_set import AthenaDictResultSet, AthenaResultSet -_logger = logging.getLogger(__name__) # type: ignore +_logger = logging.getLogger(__name__) class AsyncCursor(BaseCursor): @@ -53,13 +52,13 @@ class AsyncCursor(BaseCursor): def __init__( self, - s3_staging_dir: Optional[str] = None, - schema_name: Optional[str] = None, - catalog_name: Optional[str] = None, - work_group: Optional[str] = None, + s3_staging_dir: str | None = None, + schema_name: str | None = None, + catalog_name: str | None = None, + work_group: str | None = None, poll_interval: float = 1, - encryption_option: Optional[str] = None, - kms_key: Optional[str] = None, + encryption_option: str | None = None, + kms_key: str | None = None, kill_on_interrupt: bool = True, max_workers: int = (cpu_count() or 1) * 5, arraysize: int = CursorIterator.DEFAULT_FETCH_SIZE, @@ -103,16 +102,16 @@ def close(self, wait: bool = False) -> None: def _description( self, query_id: str - ) -> Optional[List[Tuple[str, str, None, None, int, int, str]]]: + ) -> list[tuple[str, str, None, None, int, int, str]] | None: result_set = self._collect_result_set(query_id) return result_set.description def description( self, query_id: str - ) -> "Future[Optional[List[Tuple[str, str, None, None, int, int, 
str]]]]": + ) -> Future[list[tuple[str, str, None, None, int, int, str]] | None]: return self._executor.submit(self._description, query_id) - def query_execution(self, query_id: str) -> "Future[AthenaQueryExecution]": + def query_execution(self, query_id: str) -> Future[AthenaQueryExecution]: """Get query execution details asynchronously. Retrieves the current execution status and metadata for a query. @@ -126,7 +125,7 @@ def query_execution(self, query_id: str) -> "Future[AthenaQueryExecution]": """ return self._executor.submit(self._get_query_execution, query_id) - def poll(self, query_id: str) -> "Future[AthenaQueryExecution]": + def poll(self, query_id: str) -> Future[AthenaQueryExecution]: """Poll for query completion asynchronously. Waits for the query to complete (succeed, fail, or be cancelled) and @@ -158,16 +157,16 @@ def _collect_result_set(self, query_id: str) -> AthenaResultSet: def execute( self, operation: str, - parameters: Optional[Union[Dict[str, Any], List[str]]] = None, - work_group: Optional[str] = None, - s3_staging_dir: Optional[str] = None, - cache_size: Optional[int] = 0, - cache_expiration_time: Optional[int] = 0, - result_reuse_enable: Optional[bool] = None, - result_reuse_minutes: Optional[int] = None, - paramstyle: Optional[str] = None, + parameters: dict[str, Any] | list[str] | None = None, + work_group: str | None = None, + s3_staging_dir: str | None = None, + cache_size: int | None = 0, + cache_expiration_time: int | None = 0, + result_reuse_enable: bool | None = None, + result_reuse_minutes: int | None = None, + paramstyle: str | None = None, **kwargs, - ) -> Tuple[str, "Future[Union[AthenaResultSet, Any]]"]: + ) -> tuple[str, Future[AthenaResultSet | Any]]: """Execute a SQL query asynchronously. 
Starts query execution on Amazon Athena and returns immediately without @@ -213,7 +212,7 @@ def execute( def executemany( self, operation: str, - seq_of_parameters: List[Optional[Union[Dict[str, Any], List[str]]]], + seq_of_parameters: list[dict[str, Any] | list[str] | None], **kwargs, ) -> None: """Execute multiple queries asynchronously (not supported). @@ -235,7 +234,7 @@ def executemany( """ raise NotSupportedError - def cancel(self, query_id: str) -> "Future[None]": + def cancel(self, query_id: str) -> Future[None]: """Cancel a running query asynchronously. Submits a cancellation request for the specified query. The cancellation diff --git a/pyathena/common.py b/pyathena/common.py index f0688220..66ac4fba 100644 --- a/pyathena/common.py +++ b/pyathena/common.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import logging @@ -6,7 +5,7 @@ import time from abc import ABCMeta, abstractmethod from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Any, cast import pyathena from pyathena.converter import Converter, DefaultTypeConverter @@ -26,7 +25,7 @@ if TYPE_CHECKING: from pyathena.connection import Connection -_logger = logging.getLogger(__name__) # type: ignore +_logger = logging.getLogger(__name__) class CursorIterator(metaclass=ABCMeta): @@ -57,7 +56,7 @@ class CursorIterator(metaclass=ABCMeta): def __init__(self, **kwargs) -> None: super().__init__() self.arraysize: int = kwargs.get("arraysize", self.DEFAULT_FETCH_SIZE) - self._rownumber: Optional[int] = None + self._rownumber: int | None = None self._rowcount: int = -1 # By default, return -1 to indicate that this is not supported. 
@property @@ -73,7 +72,7 @@ def arraysize(self, value: int) -> None: self._arraysize = value @property - def rownumber(self) -> Optional[int]: + def rownumber(self) -> int | None: return self._rownumber @property @@ -151,17 +150,17 @@ class BaseCursor(metaclass=ABCMeta): def __init__( self, - connection: "Connection[Any]", + connection: Connection[Any], converter: Converter, formatter: Formatter, retry_config: RetryConfig, - s3_staging_dir: Optional[str], - schema_name: Optional[str], - catalog_name: Optional[str], - work_group: Optional[str], + s3_staging_dir: str | None, + schema_name: str | None, + catalog_name: str | None, + work_group: str | None, poll_interval: float, - encryption_option: Optional[str], - kms_key: Optional[str], + encryption_option: str | None, + kms_key: str | None, kill_on_interrupt: bool, result_reuse_enable: bool, result_reuse_minutes: int, @@ -184,7 +183,7 @@ def __init__( self._result_reuse_minutes = result_reuse_minutes @staticmethod - def get_default_converter(unload: bool = False) -> Union[DefaultTypeConverter, Any]: + def get_default_converter(unload: bool = False) -> DefaultTypeConverter | Any: """Get the default type converter for this cursor class. 
Args: @@ -197,19 +196,19 @@ def get_default_converter(unload: bool = False) -> Union[DefaultTypeConverter, A return DefaultTypeConverter() @property - def connection(self) -> "Connection[Any]": + def connection(self) -> Connection[Any]: return self._connection def _build_start_query_execution_request( self, query: str, - work_group: Optional[str] = None, - s3_staging_dir: Optional[str] = None, - result_reuse_enable: Optional[bool] = None, - result_reuse_minutes: Optional[int] = None, - execution_parameters: Optional[List[str]] = None, - ) -> Dict[str, Any]: - request: Dict[str, Any] = { + work_group: str | None = None, + s3_staging_dir: str | None = None, + result_reuse_enable: bool | None = None, + result_reuse_minutes: int | None = None, + execution_parameters: list[str] | None = None, + ) -> dict[str, Any]: + request: dict[str, Any] = { "QueryString": query, "QueryExecutionContext": {}, } @@ -217,7 +216,7 @@ def _build_start_query_execution_request( request["QueryExecutionContext"].update({"Database": self._schema_name}) if self._catalog_name: request["QueryExecutionContext"].update({"Catalog": self._catalog_name}) - result_configuration: Dict[str, Any] = {} + result_configuration: dict[str, Any] = {} if self._s3_staging_dir or s3_staging_dir: result_configuration["OutputLocation"] = ( s3_staging_dir if s3_staging_dir else self._s3_staging_dir @@ -251,10 +250,10 @@ def _build_start_calculation_execution_request( self, session_id: str, code_block: str, - description: Optional[str] = None, - client_request_token: Optional[str] = None, + description: str | None = None, + client_request_token: str | None = None, ): - request: Dict[str, Any] = { + request: dict[str, Any] = { "SessionId": session_id, "CodeBlock": code_block, } @@ -266,11 +265,11 @@ def _build_start_calculation_execution_request( def _build_list_query_executions_request( self, - work_group: Optional[str], - next_token: Optional[str] = None, - max_results: Optional[int] = None, - ) -> Dict[str, Any]: - 
request: Dict[str, Any] = { + work_group: str | None, + next_token: str | None = None, + max_results: int | None = None, + ) -> dict[str, Any]: + request: dict[str, Any] = { "MaxResults": max_results if max_results else self.LIST_QUERY_EXECUTIONS_MAX_RESULTS } if self._work_group or work_group: @@ -281,13 +280,13 @@ def _build_list_query_executions_request( def _build_list_table_metadata_request( self, - catalog_name: Optional[str], - schema_name: Optional[str], - expression: Optional[str] = None, - next_token: Optional[str] = None, - max_results: Optional[int] = None, - ) -> Dict[str, Any]: - request: Dict[str, Any] = { + catalog_name: str | None, + schema_name: str | None, + expression: str | None = None, + next_token: str | None = None, + max_results: int | None = None, + ) -> dict[str, Any]: + request: dict[str, Any] = { "CatalogName": catalog_name if catalog_name else self._catalog_name, "DatabaseName": schema_name if schema_name else self._schema_name, "MaxResults": max_results if max_results else self.LIST_TABLE_METADATA_MAX_RESULTS, @@ -302,11 +301,11 @@ def _build_list_table_metadata_request( def _build_list_databases_request( self, - catalog_name: Optional[str], - next_token: Optional[str] = None, - max_results: Optional[int] = None, + catalog_name: str | None, + next_token: str | None = None, + max_results: int | None = None, ): - request: Dict[str, Any] = { + request: dict[str, Any] = { "CatalogName": catalog_name if catalog_name else self._catalog_name, "MaxResults": max_results if max_results else self.LIST_DATABASES_MAX_RESULTS, } @@ -318,10 +317,10 @@ def _build_list_databases_request( def _list_databases( self, - catalog_name: Optional[str], - next_token: Optional[str] = None, - max_results: Optional[int] = None, - ) -> Tuple[Optional[str], List[AthenaDatabase]]: + catalog_name: str | None, + next_token: str | None = None, + max_results: int | None = None, + ) -> tuple[str | None, list[AthenaDatabase]]: request = self._build_list_databases_request( 
catalog_name=catalog_name, next_token=next_token, @@ -344,9 +343,9 @@ def _list_databases( def list_databases( self, - catalog_name: Optional[str], - max_results: Optional[int] = None, - ) -> List[AthenaDatabase]: + catalog_name: str | None, + max_results: int | None = None, + ) -> list[AthenaDatabase]: databases = [] next_token = None while True: @@ -363,10 +362,10 @@ def list_databases( def _build_get_table_metadata_request( self, table_name: str, - catalog_name: Optional[str] = None, - schema_name: Optional[str] = None, - ) -> Dict[str, Any]: - request: Dict[str, Any] = { + catalog_name: str | None = None, + schema_name: str | None = None, + ) -> dict[str, Any]: + request: dict[str, Any] = { "CatalogName": catalog_name if catalog_name else self._catalog_name, "DatabaseName": schema_name if schema_name else self._schema_name, "TableName": table_name, @@ -378,8 +377,8 @@ def _build_get_table_metadata_request( def _get_table_metadata( self, table_name: str, - catalog_name: Optional[str] = None, - schema_name: Optional[str] = None, + catalog_name: str | None = None, + schema_name: str | None = None, logging_: bool = True, ) -> AthenaTableMetadata: request = self._build_get_table_metadata_request( @@ -404,8 +403,8 @@ def _get_table_metadata( def get_table_metadata( self, table_name: str, - catalog_name: Optional[str] = None, - schema_name: Optional[str] = None, + catalog_name: str | None = None, + schema_name: str | None = None, logging_: bool = True, ) -> AthenaTableMetadata: return self._get_table_metadata( @@ -417,12 +416,12 @@ def get_table_metadata( def _list_table_metadata( self, - catalog_name: Optional[str] = None, - schema_name: Optional[str] = None, - expression: Optional[str] = None, - next_token: Optional[str] = None, - max_results: Optional[int] = None, - ) -> Tuple[Optional[str], List[AthenaTableMetadata]]: + catalog_name: str | None = None, + schema_name: str | None = None, + expression: str | None = None, + next_token: str | None = None, + 
max_results: int | None = None, + ) -> tuple[str | None, list[AthenaTableMetadata]]: request = self._build_list_table_metadata_request( catalog_name=catalog_name, schema_name=schema_name, @@ -448,11 +447,11 @@ def _list_table_metadata( def list_table_metadata( self, - catalog_name: Optional[str] = None, - schema_name: Optional[str] = None, - expression: Optional[str] = None, - max_results: Optional[int] = None, - ) -> List[AthenaTableMetadata]: + catalog_name: str | None = None, + schema_name: str | None = None, + expression: str | None = None, + max_results: int | None = None, + ) -> list[AthenaTableMetadata]: metadata = [] next_token = None while True: @@ -513,7 +512,7 @@ def _get_calculation_execution(self, query_id: str) -> AthenaCalculationExecutio else: return AthenaCalculationExecution(response) - def _batch_get_query_execution(self, query_ids: List[str]) -> List[AthenaQueryExecution]: + def _batch_get_query_execution(self, query_ids: list[str]) -> list[AthenaQueryExecution]: try: response = retry_api_call( self.connection._client.batch_get_query_execution, @@ -532,10 +531,10 @@ def _batch_get_query_execution(self, query_ids: List[str]) -> List[AthenaQueryEx def _list_query_executions( self, - work_group: Optional[str] = None, - next_token: Optional[str] = None, - max_results: Optional[int] = None, - ) -> Tuple[Optional[str], List[AthenaQueryExecution]]: + work_group: str | None = None, + next_token: str | None = None, + max_results: int | None = None, + ) -> tuple[str | None, list[AthenaQueryExecution]]: request = self._build_list_query_executions_request( work_group=work_group, next_token=next_token, max_results=max_results ) @@ -556,7 +555,7 @@ def _list_query_executions( return next_token, [] return next_token, self._batch_get_query_execution(query_ids) - def __poll(self, query_id: str) -> Union[AthenaQueryExecution, AthenaCalculationExecution]: + def __poll(self, query_id: str) -> AthenaQueryExecution | AthenaCalculationExecution: while True: 
query_execution = self._get_query_execution(query_id) if query_execution.state in [ @@ -567,7 +566,7 @@ def __poll(self, query_id: str) -> Union[AthenaQueryExecution, AthenaCalculation return query_execution time.sleep(self._poll_interval) - def _poll(self, query_id: str) -> Union[AthenaQueryExecution, AthenaCalculationExecution]: + def _poll(self, query_id: str) -> AthenaQueryExecution | AthenaCalculationExecution: try: query_execution = self.__poll(query_id) except KeyboardInterrupt as e: @@ -582,10 +581,10 @@ def _poll(self, query_id: str) -> Union[AthenaQueryExecution, AthenaCalculationE def _find_previous_query_id( self, query: str, - work_group: Optional[str], + work_group: str | None, cache_size: int = 0, cache_expiration_time: int = 0, - ) -> Optional[str]: + ) -> str | None: query_id = None if cache_size == 0 and cache_expiration_time > 0: cache_size = sys.maxsize @@ -609,7 +608,7 @@ def _find_previous_query_id( and e.statement_type == AthenaQueryExecution.STATEMENT_TYPE_DML ), # https://github.com/python/mypy/issues/9656 - key=lambda e: e.completion_date_time, # type: ignore + key=lambda e: e.completion_date_time, # type: ignore[arg-type, return-value] reverse=True, ): if ( @@ -632,9 +631,9 @@ def _find_previous_query_id( def _prepare_query( self, operation: str, - parameters: Optional[Union[Dict[str, Any], List[str]]] = None, - paramstyle: Optional[str] = None, - ) -> Tuple[str, Optional[List[str]]]: + parameters: dict[str, Any] | list[str] | None = None, + paramstyle: str | None = None, + ) -> tuple[str, list[str] | None]: """Format query and build execution parameters. No I/O. 
Args: @@ -647,9 +646,9 @@ def _prepare_query( """ if pyathena.paramstyle == "qmark" or paramstyle == "qmark": query = operation - execution_parameters = cast(Optional[List[str]], parameters) + execution_parameters = cast(list[str] | None, parameters) else: - query = self._formatter.format(operation, cast(Optional[Dict[str, Any]], parameters)) + query = self._formatter.format(operation, cast(dict[str, Any] | None, parameters)) execution_parameters = None _logger.debug(query) return query, execution_parameters @@ -657,8 +656,8 @@ def _prepare_query( def _prepare_unload( self, operation: str, - s3_staging_dir: Optional[str], - ) -> Tuple[str, Optional[str]]: + s3_staging_dir: str | None, + ) -> tuple[str, str | None]: """Wrap operation with UNLOAD if enabled. Args: @@ -683,14 +682,14 @@ def _prepare_unload( def _execute( self, operation: str, - parameters: Optional[Union[Dict[str, Any], List[str]]] = None, - work_group: Optional[str] = None, - s3_staging_dir: Optional[str] = None, - cache_size: Optional[int] = 0, - cache_expiration_time: Optional[int] = 0, - result_reuse_enable: Optional[bool] = None, - result_reuse_minutes: Optional[int] = None, - paramstyle: Optional[str] = None, + parameters: dict[str, Any] | list[str] | None = None, + work_group: str | None = None, + s3_staging_dir: str | None = None, + cache_size: int | None = 0, + cache_expiration_time: int | None = 0, + result_reuse_enable: bool | None = None, + result_reuse_minutes: int | None = None, + paramstyle: str | None = None, ) -> str: query, execution_parameters = self._prepare_query(operation, parameters, paramstyle) @@ -725,8 +724,8 @@ def _calculate( self, session_id: str, code_block: str, - description: Optional[str] = None, - client_request_token: Optional[str] = None, + description: str | None = None, + client_request_token: str | None = None, ) -> str: request = self._build_start_calculation_execution_request( session_id=session_id, @@ -750,7 +749,7 @@ def _calculate( def execute( self, 
operation: str, - parameters: Optional[Union[Dict[str, Any], List[str]]] = None, + parameters: dict[str, Any] | list[str] | None = None, **kwargs, ): raise NotImplementedError # pragma: no cover @@ -759,7 +758,7 @@ def execute( def executemany( self, operation: str, - seq_of_parameters: List[Optional[Union[Dict[str, Any], List[str]]]], + seq_of_parameters: list[dict[str, Any] | list[str] | None], **kwargs, ) -> None: raise NotImplementedError # pragma: no cover @@ -783,11 +782,9 @@ def _cancel(self, query_id: str) -> None: def setinputsizes(self, sizes): # noqa: B027 """Does nothing by default""" - pass def setoutputsize(self, size, column=None): # noqa: B027 """Does nothing by default""" - pass def __enter__(self): return self diff --git a/pyathena/connection.py b/pyathena/connection.py index d36da8f4..65a927f5 100644 --- a/pyathena/connection.py +++ b/pyathena/connection.py @@ -1,20 +1,15 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import logging import os import time +from collections.abc import Callable from typing import ( TYPE_CHECKING, Any, - Callable, - Dict, + ClassVar, Generic, - List, - Optional, - Type, TypeVar, - Union, cast, overload, ) @@ -33,7 +28,7 @@ if TYPE_CHECKING: from botocore.client import BaseClient -_logger = logging.getLogger(__name__) # type: ignore +_logger = logging.getLogger(__name__) ConnectionCursor = TypeVar("ConnectionCursor", bound=BaseCursor) @@ -84,7 +79,7 @@ class Connection(Generic[ConnectionCursor]): _ENV_S3_STAGING_DIR: str = "AWS_ATHENA_S3_STAGING_DIR" _ENV_WORK_GROUP: str = "AWS_ATHENA_WORK_GROUP" - _SESSION_PASSING_ARGS: List[str] = [ + _SESSION_PASSING_ARGS: ClassVar[list[str]] = [ "aws_access_key_id", "aws_secret_access_key", "aws_session_token", @@ -92,7 +87,7 @@ class Connection(Generic[ConnectionCursor]): "botocore_session", "profile_name", ] - _CLIENT_PASSING_ARGS: List[str] = [ + _CLIENT_PASSING_ARGS: ClassVar[list[str]] = [ "aws_access_key_id", "aws_secret_access_key", "aws_session_token", @@ 
-107,92 +102,92 @@ class Connection(Generic[ConnectionCursor]): @overload def __init__( self: Connection[Cursor], - s3_staging_dir: Optional[str] = ..., - region_name: Optional[str] = ..., - schema_name: Optional[str] = ..., - catalog_name: Optional[str] = ..., - work_group: Optional[str] = ..., + s3_staging_dir: str | None = ..., + region_name: str | None = ..., + schema_name: str | None = ..., + catalog_name: str | None = ..., + work_group: str | None = ..., poll_interval: float = ..., - encryption_option: Optional[str] = ..., - kms_key: Optional[str] = ..., - profile_name: Optional[str] = ..., - role_arn: Optional[str] = ..., + encryption_option: str | None = ..., + kms_key: str | None = ..., + profile_name: str | None = ..., + role_arn: str | None = ..., role_session_name: str = ..., - external_id: Optional[str] = ..., - serial_number: Optional[str] = ..., + external_id: str | None = ..., + serial_number: str | None = ..., duration_seconds: int = ..., - converter: Optional[Converter] = ..., - formatter: Optional[Formatter] = ..., - retry_config: Optional[RetryConfig] = ..., + converter: Converter | None = ..., + formatter: Formatter | None = ..., + retry_config: RetryConfig | None = ..., cursor_class: None = ..., - cursor_kwargs: Optional[Dict[str, Any]] = ..., + cursor_kwargs: dict[str, Any] | None = ..., kill_on_interrupt: bool = ..., - session: Optional[Session] = ..., - config: Optional[Config] = ..., + session: Session | None = ..., + config: Config | None = ..., result_reuse_enable: bool = ..., result_reuse_minutes: int = ..., - on_start_query_execution: Optional[Callable[[str], None]] = ..., + on_start_query_execution: Callable[[str], None] | None = ..., **kwargs, ) -> None: ... 
@overload def __init__( self: Connection[ConnectionCursor], - s3_staging_dir: Optional[str] = ..., - region_name: Optional[str] = ..., - schema_name: Optional[str] = ..., - catalog_name: Optional[str] = ..., - work_group: Optional[str] = ..., + s3_staging_dir: str | None = ..., + region_name: str | None = ..., + schema_name: str | None = ..., + catalog_name: str | None = ..., + work_group: str | None = ..., poll_interval: float = ..., - encryption_option: Optional[str] = ..., - kms_key: Optional[str] = ..., - profile_name: Optional[str] = ..., - role_arn: Optional[str] = ..., + encryption_option: str | None = ..., + kms_key: str | None = ..., + profile_name: str | None = ..., + role_arn: str | None = ..., role_session_name: str = ..., - external_id: Optional[str] = ..., - serial_number: Optional[str] = ..., + external_id: str | None = ..., + serial_number: str | None = ..., duration_seconds: int = ..., - converter: Optional[Converter] = ..., - formatter: Optional[Formatter] = ..., - retry_config: Optional[RetryConfig] = ..., - cursor_class: Type[ConnectionCursor] = ..., - cursor_kwargs: Optional[Dict[str, Any]] = ..., + converter: Converter | None = ..., + formatter: Formatter | None = ..., + retry_config: RetryConfig | None = ..., + cursor_class: type[ConnectionCursor] = ..., + cursor_kwargs: dict[str, Any] | None = ..., kill_on_interrupt: bool = ..., - session: Optional[Session] = ..., - config: Optional[Config] = ..., + session: Session | None = ..., + config: Config | None = ..., result_reuse_enable: bool = ..., result_reuse_minutes: int = ..., - on_start_query_execution: Optional[Callable[[str], None]] = ..., + on_start_query_execution: Callable[[str], None] | None = ..., **kwargs, ) -> None: ... 
def __init__( self, - s3_staging_dir: Optional[str] = None, - region_name: Optional[str] = None, - schema_name: Optional[str] = "default", - catalog_name: Optional[str] = "awsdatacatalog", - work_group: Optional[str] = None, + s3_staging_dir: str | None = None, + region_name: str | None = None, + schema_name: str | None = "default", + catalog_name: str | None = "awsdatacatalog", + work_group: str | None = None, poll_interval: float = 1, - encryption_option: Optional[str] = None, - kms_key: Optional[str] = None, - profile_name: Optional[str] = None, - role_arn: Optional[str] = None, + encryption_option: str | None = None, + kms_key: str | None = None, + profile_name: str | None = None, + role_arn: str | None = None, role_session_name: str = f"PyAthena-session-{int(time.time())}", - external_id: Optional[str] = None, - serial_number: Optional[str] = None, + external_id: str | None = None, + serial_number: str | None = None, duration_seconds: int = 3600, - converter: Optional[Converter] = None, - formatter: Optional[Formatter] = None, - retry_config: Optional[RetryConfig] = None, - cursor_class: Optional[Type[ConnectionCursor]] = None, - cursor_kwargs: Optional[Dict[str, Any]] = None, + converter: Converter | None = None, + formatter: Formatter | None = None, + retry_config: RetryConfig | None = None, + cursor_class: type[ConnectionCursor] | None = None, + cursor_kwargs: dict[str, Any] | None = None, kill_on_interrupt: bool = True, - session: Optional[Session] = None, - config: Optional[Config] = None, + session: Session | None = None, + config: Config | None = None, result_reuse_enable: bool = False, result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES, - on_start_query_execution: Optional[Callable[[str], None]] = None, + on_start_query_execution: Callable[[str], None] | None = None, **kwargs, ) -> None: """Initialize a new Athena database connection. 
@@ -253,21 +248,21 @@ def __init__( "duration_seconds": duration_seconds, } if s3_staging_dir is not None: - self.s3_staging_dir: Optional[str] = s3_staging_dir or None + self.s3_staging_dir: str | None = s3_staging_dir or None else: self.s3_staging_dir = os.getenv(self._ENV_S3_STAGING_DIR) self.region_name = region_name self.schema_name = schema_name self.catalog_name = catalog_name if work_group: - self.work_group: Optional[str] = work_group + self.work_group: str | None = work_group else: self.work_group = os.getenv(self._ENV_WORK_GROUP) self.poll_interval = poll_interval self.encryption_option = encryption_option self.kms_key = kms_key self.profile_name = profile_name - self.config: Optional[Config] = config if config else Config() + self.config: Config | None = config if config else Config() if not self.s3_staging_dir and not self.work_group: raise ProgrammingError("Required argument `s3_staging_dir` or `work_group` not found.") @@ -330,7 +325,7 @@ def __init__( self._converter = converter self._formatter = formatter if formatter else DefaultParameterFormatter() self._retry_config = retry_config if retry_config else RetryConfig() - self.cursor_class = cursor_class if cursor_class else cast(Type[ConnectionCursor], Cursor) + self.cursor_class = cursor_class if cursor_class else cast(type[ConnectionCursor], Cursor) self.cursor_kwargs = cursor_kwargs if cursor_kwargs else {} self.kill_on_interrupt = kill_on_interrupt self.result_reuse_enable = result_reuse_enable @@ -339,14 +334,14 @@ def __init__( def _assume_role( self, - profile_name: Optional[str], - region_name: Optional[str], + profile_name: str | None, + region_name: str | None, role_arn: str, role_session_name: str, - external_id: Optional[str], - serial_number: Optional[str], + external_id: str | None, + serial_number: str | None, duration_seconds: int, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Assume an IAM role and return temporary credentials. 
Uses AWS STS to assume the specified IAM role and obtain temporary @@ -396,16 +391,16 @@ def _assume_role( } ) response = client.assume_role(**request) - creds: Dict[str, Any] = response["Credentials"] + creds: dict[str, Any] = response["Credentials"] return creds def _get_session_token( self, - profile_name: Optional[str], - region_name: Optional[str], - serial_number: Optional[str], + profile_name: str | None, + region_name: str | None, + serial_number: str | None, duration_seconds: int, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Get session token using MFA authentication. Obtains temporary security credentials by providing MFA authentication. @@ -435,11 +430,11 @@ def _get_session_token( "TokenCode": token_code, } response = client.get_session_token(**request) - creds: Dict[str, Any] = response["Credentials"] + creds: dict[str, Any] = response["Credentials"] return creds @property - def _session_kwargs(self) -> Dict[str, Any]: + def _session_kwargs(self) -> dict[str, Any]: """Get session keyword arguments for AWS Session creation. Returns: @@ -449,7 +444,7 @@ def _session_kwargs(self) -> Dict[str, Any]: return {k: v for k, v in self._kwargs.items() if k in self._SESSION_PASSING_ARGS} @property - def _client_kwargs(self) -> Dict[str, Any]: + def _client_kwargs(self) -> dict[str, Any]: """Get client keyword arguments for AWS client creation. Returns: @@ -468,7 +463,7 @@ def session(self) -> Session: return self._session @property - def client(self) -> "BaseClient": + def client(self) -> BaseClient: """Get the boto3 Athena client used for query operations. Returns: @@ -507,11 +502,11 @@ def __exit__(self, exc_type, exc_val, exc_tb): def cursor(self, cursor: None = ..., **kwargs) -> ConnectionCursor: ... @overload - def cursor(self, cursor: Type[FunctionalCursor], **kwargs) -> FunctionalCursor: ... + def cursor(self, cursor: type[FunctionalCursor], **kwargs) -> FunctionalCursor: ... 
def cursor( - self, cursor: Optional[Type[FunctionalCursor]] = None, **kwargs - ) -> Union[FunctionalCursor, ConnectionCursor]: + self, cursor: type[FunctionalCursor] | None = None, **kwargs + ) -> FunctionalCursor | ConnectionCursor: """Create a new cursor object for executing queries. Creates and returns a cursor object that can be used to execute SQL @@ -574,7 +569,6 @@ def close(self) -> None: This method is called automatically when using the connection as a context manager (with statement). """ - pass def commit(self) -> None: """Commit any pending transaction. @@ -585,7 +579,6 @@ def commit(self) -> None: Note: Athena queries are auto-committed and cannot be rolled back. """ - pass def rollback(self) -> None: """Rollback any pending transaction. diff --git a/pyathena/converter.py b/pyathena/converter.py index bc39419f..ce664c81 100644 --- a/pyathena/converter.py +++ b/pyathena/converter.py @@ -1,23 +1,23 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import binascii import json import logging from abc import ABCMeta, abstractmethod +from collections.abc import Callable from copy import deepcopy from datetime import date, datetime, time from decimal import Decimal -from typing import Any, Callable, Dict, List, Optional, Type, Union +from typing import Any from dateutil.tz import gettz from pyathena.util import strtobool -_logger = logging.getLogger(__name__) # type: ignore +_logger = logging.getLogger(__name__) -def _to_date(value: Optional[Union[str, datetime, date]]) -> Optional[date]: +def _to_date(value: str | datetime | date | None) -> date | None: if value is None: return None if isinstance(value, datetime): @@ -27,62 +27,62 @@ def _to_date(value: Optional[Union[str, datetime, date]]) -> Optional[date]: return datetime.strptime(value, "%Y-%m-%d").date() -def _to_datetime(varchar_value: Optional[str]) -> Optional[datetime]: +def _to_datetime(varchar_value: str | None) -> datetime | None: if varchar_value is None: return None return 
datetime.strptime(varchar_value, "%Y-%m-%d %H:%M:%S.%f") -def _to_datetime_with_tz(varchar_value: Optional[str]) -> Optional[datetime]: +def _to_datetime_with_tz(varchar_value: str | None) -> datetime | None: if varchar_value is None: return None datetime_, _, tz = varchar_value.rpartition(" ") return datetime.strptime(datetime_, "%Y-%m-%d %H:%M:%S.%f").replace(tzinfo=gettz(tz)) -def _to_time(varchar_value: Optional[str]) -> Optional[time]: +def _to_time(varchar_value: str | None) -> time | None: if varchar_value is None: return None return datetime.strptime(varchar_value, "%H:%M:%S.%f").time() -def _to_float(varchar_value: Optional[str]) -> Optional[float]: +def _to_float(varchar_value: str | None) -> float | None: if varchar_value is None: return None return float(varchar_value) -def _to_int(varchar_value: Optional[str]) -> Optional[int]: +def _to_int(varchar_value: str | None) -> int | None: if varchar_value is None: return None return int(varchar_value) -def _to_decimal(varchar_value: Optional[str]) -> Optional[Decimal]: +def _to_decimal(varchar_value: str | None) -> Decimal | None: if not varchar_value: return None return Decimal(varchar_value) -def _to_boolean(varchar_value: Optional[str]) -> Optional[bool]: +def _to_boolean(varchar_value: str | None) -> bool | None: if not varchar_value: return None return bool(strtobool(varchar_value)) -def _to_binary(varchar_value: Optional[str]) -> Optional[bytes]: +def _to_binary(varchar_value: str | None) -> bytes | None: if varchar_value is None: return None return binascii.a2b_hex("".join(varchar_value.split(" "))) -def _to_json(varchar_value: Optional[str]) -> Optional[Any]: +def _to_json(varchar_value: str | None) -> Any | None: if varchar_value is None: return None return json.loads(varchar_value) -def _to_array(varchar_value: Optional[str]) -> Optional[List[Any]]: +def _to_array(varchar_value: str | None) -> list[Any] | None: """Convert array data to Python list. 
Supports two formats: @@ -128,7 +128,7 @@ def _to_array(varchar_value: Optional[str]) -> Optional[List[Any]]: return None -def _to_map(varchar_value: Optional[str]) -> Optional[Dict[str, Any]]: +def _to_map(varchar_value: str | None) -> dict[str, Any] | None: """Convert map data to Python dictionary. Supports two formats: @@ -179,7 +179,7 @@ def _to_map(varchar_value: Optional[str]) -> Optional[Dict[str, Any]]: return None -def _to_struct(varchar_value: Optional[str]) -> Optional[Dict[str, Any]]: +def _to_struct(varchar_value: str | None) -> dict[str, Any] | None: """Convert struct data to Python dictionary. Supports two formats: @@ -229,7 +229,7 @@ def _to_struct(varchar_value: Optional[str]) -> Optional[Dict[str, Any]]: return None -def _parse_array_native(inner: str) -> Optional[List[Any]]: +def _parse_array_native(inner: str) -> list[Any] | None: """Parse array native format: 1, 2, 3 or {a, b}, {c, d}. Args: @@ -266,7 +266,7 @@ def _parse_array_native(inner: str) -> Optional[List[Any]]: return result if result else None -def _split_array_items(inner: str) -> List[str]: +def _split_array_items(inner: str) -> list[str]: """Split array items by comma, respecting brace and bracket groupings. Args: @@ -304,7 +304,7 @@ def _split_array_items(inner: str) -> List[str]: return items -def _parse_map_native(inner: str) -> Optional[Dict[str, Any]]: +def _parse_map_native(inner: str) -> dict[str, Any] | None: """Parse map native format: key1=value1, key2=value2. Args: @@ -339,7 +339,7 @@ def _parse_map_native(inner: str) -> Optional[Dict[str, Any]]: return result if result else None -def _parse_named_struct(inner: str) -> Optional[Dict[str, Any]]: +def _parse_named_struct(inner: str) -> dict[str, Any] | None: """Parse named struct format: key1=value1, key2=value2. Supports nested structs: outer={inner_key=inner_value}, field=value. 
@@ -381,7 +381,7 @@ def _parse_named_struct(inner: str) -> Optional[Dict[str, Any]]: return result if result else None -def _parse_unnamed_struct(inner: str) -> Dict[str, Any]: +def _parse_unnamed_struct(inner: str) -> dict[str, Any]: """Parse unnamed struct format: Alice, 25. Args: @@ -409,18 +409,18 @@ def _convert_value(value: str) -> Any: return True if value.lower() == "false": return False - if value.isdigit() or value.startswith("-") and value[1:].isdigit(): + if value.isdigit() or (value.startswith("-") and value[1:].isdigit()): return int(value) if "." in value and value.replace(".", "", 1).replace("-", "", 1).isdigit(): return float(value) return value -def _to_default(varchar_value: Optional[str]) -> Optional[str]: +def _to_default(varchar_value: str | None) -> str | None: return varchar_value -_DEFAULT_CONVERTERS: Dict[str, Callable[[Optional[str]], Optional[Any]]] = { +_DEFAULT_CONVERTERS: dict[str, Callable[[str | None], Any | None]] = { "boolean": _to_boolean, "tinyint": _to_int, "smallint": _to_int, @@ -464,9 +464,9 @@ class Converter(metaclass=ABCMeta): def __init__( self, - mappings: Dict[str, Callable[[Optional[str]], Optional[Any]]], - default: Callable[[Optional[str]], Optional[Any]] = _to_default, - types: Optional[Dict[str, Type[Any]]] = None, + mappings: dict[str, Callable[[str | None], Any | None]], + default: Callable[[str | None], Any | None] = _to_default, + types: dict[str, type[Any]] | None = None, ) -> None: if mappings: self._mappings = mappings @@ -479,7 +479,7 @@ def __init__( self._types = {} @property - def mappings(self) -> Dict[str, Callable[[Optional[str]], Optional[Any]]]: + def mappings(self) -> dict[str, Callable[[str | None], Any | None]]: """Get the current type conversion mappings. 
Returns: @@ -488,7 +488,7 @@ def mappings(self) -> Dict[str, Callable[[Optional[str]], Optional[Any]]]: return self._mappings @property - def types(self) -> Dict[str, Type[Any]]: + def types(self) -> dict[str, type[Any]]: """Get the current type mappings for result set descriptions. Returns: @@ -496,7 +496,7 @@ def types(self) -> Dict[str, Type[Any]]: """ return self._types - def get(self, type_: str) -> Callable[[Optional[str]], Optional[Any]]: + def get(self, type_: str) -> Callable[[str | None], Any | None]: """Get the conversion function for a specific Athena data type. Args: @@ -507,7 +507,7 @@ def get(self, type_: str) -> Callable[[Optional[str]], Optional[Any]]: """ return self.mappings.get(type_, self._default) - def set(self, type_: str, converter: Callable[[Optional[str]], Optional[Any]]) -> None: + def set(self, type_: str, converter: Callable[[str | None], Any | None]) -> None: """Set a custom conversion function for an Athena data type. Args: @@ -524,7 +524,7 @@ def remove(self, type_: str) -> None: """ self.mappings.pop(type_, None) - def get_dtype(self, type_: str, precision: int = 0, scale: int = 0) -> Optional[Type[Any]]: + def get_dtype(self, type_: str, precision: int = 0, scale: int = 0) -> type[Any] | None: """Get the data type for a given Athena type. Subclasses may override this to provide custom type handling @@ -540,7 +540,7 @@ def get_dtype(self, type_: str, precision: int = 0, scale: int = 0) -> Optional[ """ return self._types.get(type_) - def update(self, mappings: Dict[str, Callable[[Optional[str]], Optional[Any]]]) -> None: + def update(self, mappings: dict[str, Callable[[str | None], Any | None]]) -> None: """Update multiple conversion functions at once. 
Args: @@ -549,7 +549,7 @@ def update(self, mappings: Dict[str, Callable[[Optional[str]], Optional[Any]]]) self.mappings.update(mappings) @abstractmethod - def convert(self, type_: str, value: Optional[str]) -> Optional[Any]: + def convert(self, type_: str, value: str | None) -> Any | None: raise NotImplementedError # pragma: no cover @@ -580,6 +580,6 @@ class DefaultTypeConverter(Converter): def __init__(self) -> None: super().__init__(mappings=deepcopy(_DEFAULT_CONVERTERS), default=_to_default) - def convert(self, type_: str, value: Optional[str]) -> Optional[Any]: + def convert(self, type_: str, value: str | None) -> Any | None: converter = self.get(type_) return converter(value) diff --git a/pyathena/cursor.py b/pyathena/cursor.py index 46bf967b..9557dc7f 100644 --- a/pyathena/cursor.py +++ b/pyathena/cursor.py @@ -1,15 +1,15 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import logging -from typing import Any, Callable, Dict, List, Optional, Union, cast +from collections.abc import Callable +from typing import Any, cast from pyathena.common import CursorIterator from pyathena.error import OperationalError, ProgrammingError from pyathena.model import AthenaQueryExecution from pyathena.result_set import AthenaDictResultSet, AthenaResultSet, WithFetch -_logger = logging.getLogger(__name__) # type: ignore +_logger = logging.getLogger(__name__) class Cursor(WithFetch): @@ -42,17 +42,17 @@ class Cursor(WithFetch): def __init__( self, - s3_staging_dir: Optional[str] = None, - schema_name: Optional[str] = None, - catalog_name: Optional[str] = None, - work_group: Optional[str] = None, + s3_staging_dir: str | None = None, + schema_name: str | None = None, + catalog_name: str | None = None, + work_group: str | None = None, poll_interval: float = 1, - encryption_option: Optional[str] = None, - kms_key: Optional[str] = None, + encryption_option: str | None = None, + kms_key: str | None = None, kill_on_interrupt: bool = True, result_reuse_enable: bool = 
False, result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES, - on_start_query_execution: Optional[Callable[[str], None]] = None, + on_start_query_execution: Callable[[str], None] | None = None, **kwargs, ) -> None: super().__init__( @@ -86,15 +86,15 @@ def arraysize(self, value: int) -> None: def execute( self, operation: str, - parameters: Optional[Union[Dict[str, Any], List[str]]] = None, - work_group: Optional[str] = None, - s3_staging_dir: Optional[str] = None, + parameters: dict[str, Any] | list[str] | None = None, + work_group: str | None = None, + s3_staging_dir: str | None = None, cache_size: int = 0, cache_expiration_time: int = 0, - result_reuse_enable: Optional[bool] = None, - result_reuse_minutes: Optional[int] = None, - paramstyle: Optional[str] = None, - on_start_query_execution: Optional[Callable[[str], None]] = None, + result_reuse_enable: bool | None = None, + result_reuse_minutes: int | None = None, + paramstyle: str | None = None, + on_start_query_execution: Callable[[str], None] | None = None, **kwargs, ) -> Cursor: """Execute a SQL query. diff --git a/pyathena/error.py b/pyathena/error.py index f29da893..41da74aa 100644 --- a/pyathena/error.py +++ b/pyathena/error.py @@ -1,14 +1,13 @@ -# -*- coding: utf-8 -*- __all__ = [ + "DataError", + "DatabaseError", "Error", - "Warning", "InterfaceError", - "DatabaseError", "InternalError", + "NotSupportedError", "OperationalError", "ProgrammingError", - "DataError", - "NotSupportedError", + "Warning", ] @@ -20,8 +19,6 @@ class Error(Exception): Python Database API Specification v2.0 (PEP 249). """ - pass - class Warning(Exception): # noqa: N818 """Exception for non-fatal warnings. @@ -31,8 +28,6 @@ class Warning(Exception): # noqa: N818 but follows the DB API 2.0 specification. """ - pass - class InterfaceError(Error): """Exception for errors related to the database interface. @@ -41,8 +36,6 @@ class InterfaceError(Error): such as connection problems or interface misuse. 
""" - pass - class DatabaseError(Error): """Base exception for database-related errors. @@ -52,8 +45,6 @@ class DatabaseError(Error): error types inherit from this class. """ - pass - class InternalError(DatabaseError): """Exception for internal database errors. @@ -62,8 +53,6 @@ class InternalError(DatabaseError): that is not due to user actions or programming errors. """ - pass - class OperationalError(DatabaseError): """Exception for errors during database operation processing. @@ -73,8 +62,6 @@ class OperationalError(DatabaseError): invalid query syntax that wasn't caught at the programming level. """ - pass - class ProgrammingError(DatabaseError): """Exception for programming errors in database operations. @@ -84,8 +71,6 @@ class ProgrammingError(DatabaseError): invalid parameters, or attempting operations on closed connections. """ - pass - class IntegrityError(DatabaseError): """Exception for data integrity constraint violations. @@ -95,8 +80,6 @@ class IntegrityError(DatabaseError): constraint failures. """ - pass - class DataError(DatabaseError): """Exception for errors due to invalid data. @@ -106,8 +89,6 @@ class DataError(DatabaseError): or malformed data structures. """ - pass - class NotSupportedError(DatabaseError): """Exception for unsupported database operations. @@ -116,5 +97,3 @@ class NotSupportedError(DatabaseError): by Athena, such as transactions (commit/rollback) or certain SQL features that are not available in the Athena query engine. 
""" - - pass diff --git a/pyathena/filesystem/__init__.py b/pyathena/filesystem/__init__.py index 40a96afc..e69de29b 100644 --- a/pyathena/filesystem/__init__.py +++ b/pyathena/filesystem/__init__.py @@ -1 +0,0 @@ -# -*- coding: utf-8 -*- diff --git a/pyathena/filesystem/s3.py b/pyathena/filesystem/s3.py index 57f4bed8..7885514c 100644 --- a/pyathena/filesystem/s3.py +++ b/pyathena/filesystem/s3.py @@ -1,15 +1,16 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import logging import mimetypes import os.path import re +from collections.abc import Callable from concurrent.futures import Future, as_completed from copy import deepcopy from datetime import datetime from multiprocessing import cpu_count -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Pattern, Tuple, Union, cast +from re import Pattern +from typing import TYPE_CHECKING, Any, cast import botocore.exceptions from boto3 import Session @@ -36,7 +37,7 @@ if TYPE_CHECKING: from pyathena.connection import Connection -_logger = logging.getLogger(__name__) # type: ignore +_logger = logging.getLogger(__name__) class S3FileSystem(AbstractFileSystem): @@ -102,9 +103,9 @@ class S3FileSystem(AbstractFileSystem): def __init__( self, - connection: Optional["Connection[Any]"] = None, - default_block_size: Optional[int] = None, - default_cache_type: Optional[str] = None, + connection: Connection[Any] | None = None, + default_block_size: int | None = None, + default_cache_type: str | None = None, max_workers: int = (cpu_count() or 1) * 5, s3_additional_kwargs=None, *args, @@ -186,13 +187,13 @@ def _get_client_compatible_with_s3fs(self, **kwargs) -> BaseClient: ) @staticmethod - def parse_path(path: str) -> Tuple[str, Optional[str], Optional[str]]: + def parse_path(path: str) -> tuple[str, str | None, str | None]: match = S3FileSystem.PATTERN_PATH.search(path) if match: return match.group("bucket"), match.group("key"), match.group("version_id") raise ValueError(f"Invalid S3 path format 
{path}.") - def _head_bucket(self, bucket, refresh: bool = False) -> Optional[S3Object]: + def _head_bucket(self, bucket, refresh: bool = False) -> S3Object | None: if bucket not in self.dircache or refresh: try: self._call( @@ -222,8 +223,8 @@ def _head_bucket(self, bucket, refresh: bool = False) -> Optional[S3Object]: return file def _head_object( - self, path: str, version_id: Optional[str] = None, refresh: bool = False - ) -> Optional[S3Object]: + self, path: str, version_id: str | None = None, refresh: bool = False + ) -> S3Object | None: bucket, key, path_version_id = self.parse_path(path) version_id = path_version_id if path_version_id else version_id if path not in self.dircache or refresh: @@ -254,7 +255,7 @@ def _head_object( file = self.dircache[path] return file - def _ls_buckets(self, refresh: bool = False) -> List[S3Object]: + def _ls_buckets(self, refresh: bool = False) -> list[S3Object]: if "" not in self.dircache or refresh: response = self._call( self._client.list_buckets, @@ -285,10 +286,10 @@ def _ls_dirs( path: str, prefix: str = "", delimiter: str = "/", - next_token: Optional[str] = None, - max_keys: Optional[int] = None, + next_token: str | None = None, + max_keys: int | None = None, refresh: bool = False, - ) -> List[S3Object]: + ) -> list[S3Object]: bucket, key, version_id = self.parse_path(path) if key: prefix = f"{key}/{prefix if prefix else ''}" @@ -296,11 +297,11 @@ def _ls_dirs( # Create a cache key that includes the delimiter cache_key = (path, delimiter) if cache_key in self.dircache and not refresh: - return cast(List[S3Object], self.dircache[cache_key]) + return cast(list[S3Object], self.dircache[cache_key]) - files: List[S3Object] = [] + files: list[S3Object] = [] while True: - request: Dict[Any, Any] = { + request: dict[Any, Any] = { "Bucket": bucket, "Prefix": prefix, "Delimiter": delimiter, @@ -347,7 +348,7 @@ def _ls_dirs( def ls( self, path: str, detail: bool = False, refresh: bool = False, **kwargs - ) -> 
Union[List[S3Object], List[str]]: + ) -> list[S3Object] | list[str]: """List contents of an S3 path. Lists buckets (when path is root) or objects within a bucket/prefix. @@ -398,7 +399,7 @@ def info(self, path: str, **kwargs) -> S3Object: version_id=None, ) if not refresh: - caches: Union[List[S3Object], S3Object] = self._ls_from_cache(path) + caches: list[S3Object] | S3Object = self._ls_from_cache(path) if caches is not None: if isinstance(caches, list): cache = next((c for c in caches if c.name == path), None) @@ -460,8 +461,8 @@ def info(self, path: str, **kwargs) -> S3Object: raise FileNotFoundError(path) def _extract_parent_directories( - self, files: List[S3Object], bucket: str, base_key: Optional[str] - ) -> List[S3Object]: + self, files: list[S3Object], bucket: str, base_key: str | None + ) -> list[S3Object]: """Extract parent directory objects from file paths. When listing files without delimiter, S3 doesn't return directory entries. @@ -521,10 +522,10 @@ def _extract_parent_directories( def _find( self, path: str, - maxdepth: Optional[int] = None, - withdirs: Optional[bool] = None, + maxdepth: int | None = None, + withdirs: bool | None = None, **kwargs, - ) -> List[S3Object]: + ) -> list[S3Object]: path = self._strip_protocol(path) if path in ["", "/"]: raise ValueError("Cannot traverse all files in S3.") @@ -533,7 +534,7 @@ def _find( # When maxdepth is specified, use a recursive approach with delimiter if maxdepth is not None: - result: List[S3Object] = [] + result: list[S3Object] = [] # List files and directories at current level current_items = self._ls_dirs(path, prefix=prefix, delimiter="/") @@ -578,11 +579,11 @@ def _find( def find( self, path: str, - maxdepth: Optional[int] = None, - withdirs: Optional[bool] = None, + maxdepth: int | None = None, + withdirs: bool | None = None, detail: bool = False, **kwargs, - ) -> Union[Dict[str, S3Object], List[str]]: + ) -> dict[str, S3Object] | list[str]: """Find all files below a given S3 path. 
Recursively searches for files under the specified path, with optional @@ -670,7 +671,7 @@ def rm(self, path, recursive=False, maxdepth=None, **kwargs) -> None: self.invalidate_cache(p) def _delete_object( - self, bucket: str, key: str, version_id: Optional[str] = None, **kwargs + self, bucket: str, key: str, version_id: str | None = None, **kwargs ) -> None: request = { "Bucket": bucket, @@ -679,7 +680,7 @@ def _delete_object( if version_id: request.update({"VersionId": version_id}) - _logger.debug(f"Delete object: s3://{bucket}/{key}?versionId={version_id}") + _logger.debug("Delete object: s3://%s/%s?versionId=%s", bucket, key, version_id) self._call( self._client.delete_object, **request, @@ -700,7 +701,7 @@ def _create_executor(self, max_workers: int) -> S3Executor: return S3ThreadPoolExecutor(max_workers=max_workers) def _delete_objects( - self, bucket: str, paths: List[str], max_workers: Optional[int] = None, **kwargs + self, bucket: str, paths: list[str], max_workers: int | None = None, **kwargs ) -> None: if not paths: return @@ -735,7 +736,7 @@ def _delete_objects( for f in as_completed(fs): f.result() - def touch(self, path: str, truncate: bool = True, **kwargs) -> Dict[str, Any]: + def touch(self, path: str, truncate: bool = True, **kwargs) -> dict[str, Any]: bucket, key, version_id = self.parse_path(path) if version_id: raise ValueError("Cannot touch the file with the version specified.") @@ -812,7 +813,7 @@ def _copy_object( self, bucket1: str, key1: str, - version_id1: Optional[str], + version_id1: str | None, bucket2: str, key2: str, **kwargs, @@ -830,8 +831,12 @@ def _copy_object( } _logger.debug( - f"Copy object from s3://{bucket1}/{key1}?versionId={version_id1} " - f"to s3://{bucket2}/{key2}." 
+ "Copy object from s3://%s/%s?versionId=%s to s3://%s/%s.", + bucket1, + key1, + version_id1, + bucket2, + key2, ) self._call(self._client.copy_object, **request, **kwargs) @@ -842,9 +847,9 @@ def _copy_object_with_multipart_upload( size1: int, bucket2: str, key2: str, - max_workers: Optional[int] = None, - block_size: Optional[int] = None, - version_id1: Optional[str] = None, + max_workers: int | None = None, + block_size: int | None = None, + version_id1: str | None = None, **kwargs, ) -> None: max_workers = max_workers if max_workers else self.max_workers @@ -896,7 +901,7 @@ def _copy_object_with_multipart_upload( } ) - parts.sort(key=lambda x: x["PartNumber"]) # type: ignore + parts.sort(key=lambda x: x["PartNumber"]) # type: ignore[arg-type, return-value] self._complete_multipart_upload( bucket=bucket2, key=key2, @@ -905,7 +910,7 @@ def _copy_object_with_multipart_upload( ) def cat_file( - self, path: str, start: Optional[int] = None, end: Optional[int] = None, **kwargs + self, path: str, start: int | None = None, end: int | None = None, **kwargs ) -> bytes: bucket, key, version_id = self.parse_path(path) if start is not None or end is not None: @@ -1073,7 +1078,7 @@ def sign(self, path: str, expiration: int = 3600, **kwargs): "ExpiresIn": expiration, } - _logger.debug(f"Generate signed url: s3://{bucket}/{key}?versionId={version_id}") + _logger.debug("Generate signed url: s3://%s/%s?versionId=%s", bucket, key, version_id) return self._call( self._client.generate_presigned_url, **request, @@ -1086,7 +1091,7 @@ def modified(self, path: str) -> datetime: info = self.info(path) return cast(datetime, info.get("last_modified")) - def invalidate_cache(self, path: Optional[str] = None) -> None: + def invalidate_cache(self, path: str | None = None) -> None: if path is None: self.dircache.clear() else: @@ -1099,10 +1104,10 @@ def _open( self, path: str, mode: str = "rb", - block_size: Optional[int] = None, - cache_type: Optional[str] = None, + block_size: int | None = 
None, + cache_type: str | None = None, autocommit: bool = True, - cache_options: Optional[Dict[Any, Any]] = None, + cache_options: dict[Any, Any] | None = None, **kwargs, ) -> S3File: if block_size is None: @@ -1132,10 +1137,10 @@ def _get_object( self, bucket: str, key: str, - ranges: Optional[Tuple[int, int]] = None, - version_id: Optional[str] = None, + ranges: tuple[int, int] | None = None, + version_id: str | None = None, **kwargs, - ) -> Tuple[int, bytes]: + ) -> tuple[int, bytes]: request = {"Bucket": bucket, "Key": key} if ranges: range_ = S3File._format_ranges(ranges) @@ -1146,7 +1151,13 @@ def _get_object( if version_id: request.update({"VersionId": version_id}) - _logger.debug(f"Get object: s3://{bucket}/{key}?versionId={version_id}&range={range_}") + _logger.debug( + "Get object: s3://%s/%s?versionId=%s&range=%s", + bucket, + key, + version_id, + range_, + ) response = self._call( self._client.get_object, **request, @@ -1154,12 +1165,12 @@ def _get_object( ) return ranges[0], cast(bytes, response["Body"].read()) - def _put_object(self, bucket: str, key: str, body: Optional[bytes], **kwargs) -> S3PutObject: - request: Dict[str, Any] = {"Bucket": bucket, "Key": key} + def _put_object(self, bucket: str, key: str, body: bytes | None, **kwargs) -> S3PutObject: + request: dict[str, Any] = {"Bucket": bucket, "Key": key} if body: request.update({"Body": body}) - _logger.debug(f"Put object: s3://{bucket}/{key}") + _logger.debug("Put object: s3://%s/%s", bucket, key) response = self._call( self._client.put_object, **request, @@ -1173,7 +1184,7 @@ def _create_multipart_upload(self, bucket: str, key: str, **kwargs) -> S3Multipa "Key": key, } - _logger.debug(f"Create multipart upload to s3://{bucket}/{key}.") + _logger.debug("Create multipart upload to s3://%s/%s.", bucket, key) response = self._call( self._client.create_multipart_upload, **request, @@ -1185,10 +1196,10 @@ def _upload_part_copy( self, bucket: str, key: str, - copy_source: Union[str, Dict[str, Any]], 
+ copy_source: str | dict[str, Any], upload_id: str, part_number: int, - copy_source_ranges: Optional[Tuple[int, int]] = None, + copy_source_ranges: tuple[int, int] | None = None, **kwargs, ) -> S3MultipartUploadPart: request = { @@ -1202,7 +1213,11 @@ def _upload_part_copy( range_ = S3File._format_ranges(copy_source_ranges) request.update({"CopySourceRange": range_}) _logger.debug( - f"Upload part copy from {copy_source} to s3://{bucket}/{key} as part {part_number}." + "Upload part copy from %s to s3://%s/%s as part %s.", + copy_source, + bucket, + key, + part_number, ) response = self._call( self._client.upload_part_copy, @@ -1228,7 +1243,13 @@ def _upload_part( "Body": body, } - _logger.debug(f"Upload part of {upload_id} to s3://{bucket}/{key} as part {part_number}.") + _logger.debug( + "Upload part of %s to s3://%s/%s as part %s.", + upload_id, + bucket, + key, + part_number, + ) response = self._call( self._client.upload_part, **request, @@ -1237,7 +1258,7 @@ def _upload_part( return S3MultipartUploadPart(part_number, response) def _complete_multipart_upload( - self, bucket: str, key: str, upload_id: str, parts: List[Dict[str, Any]], **kwargs + self, bucket: str, key: str, upload_id: str, parts: list[dict[str, Any]], **kwargs ) -> S3CompleteMultipartUpload: request = { "Bucket": bucket, @@ -1246,7 +1267,7 @@ def _complete_multipart_upload( "MultipartUpload": {"Parts": parts}, } - _logger.debug(f"Complete multipart upload {upload_id} to s3://{bucket}/{key}.") + _logger.debug("Complete multipart upload %s to s3://%s/%s.", upload_id, bucket, key) response = self._call( self._client.complete_multipart_upload, **request, @@ -1254,12 +1275,12 @@ def _complete_multipart_upload( ) return S3CompleteMultipartUpload(response) - def _call(self, method: Union[str, Callable[..., Any]], **kwargs) -> Dict[str, Any]: + def _call(self, method: str | Callable[..., Any], **kwargs) -> dict[str, Any]: func = getattr(self._client, method) if isinstance(method, str) else method 
response = retry_api_call( func, config=self._retry_config, logger=_logger, **kwargs, **self.request_kwargs ) - return cast(Dict[str, Any], response) + return cast(dict[str, Any], response) class S3File(AbstractBufferedFile): @@ -1268,15 +1289,15 @@ def __init__( fs: S3FileSystem, path: str, mode: str = "rb", - version_id: Optional[str] = None, + version_id: str | None = None, max_workers: int = (cpu_count() or 1) * 5, - executor: Optional[S3Executor] = None, + executor: S3Executor | None = None, block_size: int = S3FileSystem.DEFAULT_BLOCK_SIZE, cache_type: str = "bytes", autocommit: bool = True, - cache_options: Optional[Dict[Any, Any]] = None, - size: Optional[int] = None, - s3_additional_kwargs: Optional[Dict[str, Any]] = None, + cache_options: dict[Any, Any] | None = None, + size: int | None = None, + s3_additional_kwargs: dict[str, Any] | None = None, **kwargs, ) -> None: self.max_workers = max_workers @@ -1304,7 +1325,7 @@ def __init__( f"The version_id: {version_id} specified in the argument and " f"the version_id: {path_version_id} specified in the path do not match." 
) - self.version_id: Optional[str] = version_id + self.version_id: str | None = version_id elif path_version_id: self.version_id = path_version_id else: @@ -1332,8 +1353,8 @@ def __init__( else: self._details = {} - self.multipart_upload: Optional[S3MultipartUpload] = None - self.multipart_upload_parts: List[Future[S3MultipartUploadPart]] = [] + self.multipart_upload: S3MultipartUpload | None = None + self.multipart_upload_parts: list[Future[S3MultipartUploadPart]] = [] def close(self) -> None: super().close() @@ -1457,7 +1478,7 @@ def commit(self) -> None: if not self.multipart_upload: raise RuntimeError("Multipart upload is not initialized.") - parts: List[Dict[str, Any]] = [] + parts: list[dict[str, Any]] = [] for f in as_completed(self.multipart_upload_parts): result = f.result() parts.append( @@ -1519,13 +1540,13 @@ def _fetch_range(self, start: int, end: int) -> bytes: return object_ @staticmethod - def _format_ranges(ranges: Tuple[int, int]): + def _format_ranges(ranges: tuple[int, int]): return f"bytes={ranges[0]}-{ranges[1] - 1}" @staticmethod def _get_ranges( start: int, end: int, max_workers: int, worker_block_size: int - ) -> List[Tuple[int, int]]: + ) -> list[tuple[int, int]]: ranges = [] range_size = end - start if max_workers > 1 and range_size > worker_block_size: @@ -1542,6 +1563,6 @@ def _get_ranges( return ranges @staticmethod - def _merge_objects(objects: List[Tuple[int, bytes]]) -> bytes: + def _merge_objects(objects: list[tuple[int, bytes]]) -> bytes: objects.sort(key=lambda x: x[0]) return b"".join([obj for start, obj in objects]) diff --git a/pyathena/filesystem/s3_async.py b/pyathena/filesystem/s3_async.py index 44aeb834..c85c2394 100644 --- a/pyathena/filesystem/s3_async.py +++ b/pyathena/filesystem/s3_async.py @@ -1,10 +1,9 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import asyncio import logging from multiprocessing import cpu_count -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast +from 
typing import TYPE_CHECKING, Any, cast from fsspec.asyn import AsyncFileSystem from fsspec.callbacks import _DEFAULT_CALLBACK @@ -60,14 +59,14 @@ class AioS3FileSystem(AsyncFileSystem): def __init__( self, - connection: Optional["Connection[Any]"] = None, - default_block_size: Optional[int] = None, - default_cache_type: Optional[str] = None, + connection: Connection[Any] | None = None, + default_block_size: int | None = None, + default_cache_type: str | None = None, max_workers: int = (cpu_count() or 1) * 5, - s3_additional_kwargs: Optional[Dict[str, Any]] = None, + s3_additional_kwargs: dict[str, Any] | None = None, asynchronous: bool = False, - loop: Optional[Any] = None, - batch_size: Optional[int] = None, + loop: Any | None = None, + batch_size: int | None = None, **kwargs, ) -> None: super().__init__( @@ -88,19 +87,17 @@ def __init__( self.dircache = self._sync_fs.dircache @staticmethod - def parse_path(path: str) -> Tuple[str, Optional[str], Optional[str]]: + def parse_path(path: str) -> tuple[str, str | None, str | None]: return S3FileSystem.parse_path(path) async def _info(self, path: str, **kwargs) -> S3Object: return await asyncio.to_thread(self._sync_fs.info, path, **kwargs) - async def _ls( - self, path: str, detail: bool = False, **kwargs - ) -> Union[List[S3Object], List[str]]: + async def _ls(self, path: str, detail: bool = False, **kwargs) -> list[S3Object] | list[str]: return await asyncio.to_thread(self._sync_fs.ls, path, detail=detail, **kwargs) async def _cat_file( - self, path: str, start: Optional[int] = None, end: Optional[int] = None, **kwargs + self, path: str, start: int | None = None, end: int | None = None, **kwargs ) -> bytes: return await asyncio.to_thread(self._sync_fs.cat_file, path, start=start, end=end, **kwargs) @@ -125,7 +122,7 @@ async def _mkdir(self, path: str, create_parents: bool = True, **kwargs) -> None async def _makedirs(self, path: str, exist_ok: bool = False) -> None: await asyncio.to_thread(self._sync_fs.makedirs, 
path, exist_ok=exist_ok) - async def _rm(self, path: Union[str, List[str]], recursive: bool = False, **kwargs) -> None: + async def _rm(self, path: str | list[str], recursive: bool = False, **kwargs) -> None: """Remove files or directories using async parallel batch deletion. For multiple paths, chunks into batches of 1000 (S3 API limit) and uses @@ -136,7 +133,7 @@ async def _rm(self, path: Union[str, List[str]], recursive: bool = False, **kwar bucket, _, _ = self.parse_path(path[0]) - expand_paths: List[str] = [] + expand_paths: list[str] = [] for p in path: expanded = await asyncio.to_thread(self._sync_fs.expand_path, p, recursive=recursive) expand_paths.extend(expanded) @@ -145,11 +142,11 @@ async def _rm(self, path: Union[str, List[str]], recursive: bool = False, **kwar return quiet = kwargs.pop("Quiet", True) - delete_objects: List[Dict[str, Any]] = [] + delete_objects: list[dict[str, Any]] = [] for p in expand_paths: _, key, version_id = self.parse_path(p) if key: - object_: Dict[str, Any] = {"Key": key} + object_: dict[str, Any] = {"Key": key} if version_id: object_["VersionId"] = version_id delete_objects.append(object_) @@ -162,7 +159,7 @@ async def _rm(self, path: Union[str, List[str]], recursive: bool = False, **kwar for i in range(0, len(delete_objects), self.DELETE_OBJECTS_MAX_KEYS) ] - async def _delete_chunk(chunk: List[Dict[str, Any]]) -> None: + async def _delete_chunk(chunk: list[dict[str, Any]]) -> None: request = { "Bucket": bucket, "Delete": { @@ -220,8 +217,8 @@ async def _copy_object_with_multipart_upload( size1: int, bucket2: str, key2: str, - block_size: Optional[int] = None, - version_id1: Optional[str] = None, + block_size: int | None = None, + version_id1: str | None = None, **kwargs, ) -> None: block_size = block_size if block_size else S3FileSystem.MULTIPART_UPLOAD_MAX_PART_SIZE @@ -231,7 +228,7 @@ async def _copy_object_with_multipart_upload( ): raise ValueError("Block size must be greater than 5MiB and less than 5GiB.") - 
copy_source: Dict[str, Any] = { + copy_source: dict[str, Any] = { "Bucket": bucket1, "Key": key1, } @@ -251,7 +248,7 @@ async def _copy_object_with_multipart_upload( **kwargs, ) - async def _upload_part(i: int, range_: Tuple[int, int]) -> Dict[str, Any]: + async def _upload_part(i: int, range_: tuple[int, int]) -> dict[str, Any]: result = await asyncio.to_thread( self._sync_fs._upload_part_copy, bucket=bucket2, @@ -280,10 +277,10 @@ async def _upload_part(i: int, range_: Tuple[int, int]) -> Dict[str, Any]: async def _find( self, path: str, - maxdepth: Optional[int] = None, + maxdepth: int | None = None, withdirs: bool = False, **kwargs, - ) -> Union[Dict[str, S3Object], List[str]]: + ) -> dict[str, S3Object] | list[str]: detail = kwargs.pop("detail", False) files = await asyncio.to_thread( self._sync_fs._find, path, maxdepth=maxdepth, withdirs=withdirs, **kwargs @@ -296,12 +293,12 @@ def _open( self, path: str, mode: str = "rb", - block_size: Optional[int] = None, - cache_type: Optional[str] = None, + block_size: int | None = None, + cache_type: str | None = None, autocommit: bool = True, - cache_options: Optional[Dict[Any, Any]] = None, + cache_options: dict[Any, Any] | None = None, **kwargs, - ) -> "AioS3File": + ) -> AioS3File: if block_size is None: block_size = self._sync_fs.default_block_size if cache_type is None: @@ -331,13 +328,13 @@ def sign(self, path: str, expiration: int = 3600, **kwargs) -> str: def checksum(self, path: str, **kwargs) -> int: return cast(int, self._sync_fs.checksum(path, **kwargs)) - def created(self, path: str) -> "datetime": + def created(self, path: str) -> datetime: return self._sync_fs.created(path) - def modified(self, path: str) -> "datetime": + def modified(self, path: str) -> datetime: return self._sync_fs.modified(path) - def invalidate_cache(self, path: Optional[str] = None) -> None: + def invalidate_cache(self, path: str | None = None) -> None: self._sync_fs.invalidate_cache(path) async def _touch(self, path: str, 
truncate: bool = True, **kwargs) -> None: @@ -353,5 +350,3 @@ class AioS3File(S3File): through the ``S3Executor`` interface — the ``S3AioExecutor`` provided by ``AioS3FileSystem`` uses the event loop instead of threads. """ - - pass diff --git a/pyathena/filesystem/s3_executor.py b/pyathena/filesystem/s3_executor.py index 7c0bcee0..8e0717e8 100644 --- a/pyathena/filesystem/s3_executor.py +++ b/pyathena/filesystem/s3_executor.py @@ -1,11 +1,11 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import asyncio from abc import ABCMeta, abstractmethod +from collections.abc import Callable from concurrent.futures import Future from concurrent.futures.thread import ThreadPoolExecutor -from typing import Any, Callable, Optional, TypeVar +from typing import Any, TypeVar T = TypeVar("T") @@ -29,7 +29,7 @@ def shutdown(self, wait: bool = True) -> None: """Shut down the executor, freeing any resources.""" ... - def __enter__(self) -> "S3Executor": + def __enter__(self) -> S3Executor: return self def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: @@ -71,7 +71,7 @@ class S3AioExecutor(S3Executor): RuntimeError: If the event loop is not running when ``submit`` is called. 
""" - def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None) -> None: + def __init__(self, loop: asyncio.AbstractEventLoop | None = None) -> None: self._loop = loop def submit(self, fn: Callable[..., T], *args: Any, **kwargs: Any) -> Future[T]: diff --git a/pyathena/filesystem/s3_object.py b/pyathena/filesystem/s3_object.py index 24016d7d..bff60076 100644 --- a/pyathena/filesystem/s3_object.py +++ b/pyathena/filesystem/s3_object.py @@ -1,12 +1,12 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import copy import logging +from collections.abc import Iterator, MutableMapping from datetime import datetime -from typing import Any, Dict, Iterator, MutableMapping, Optional +from typing import Any -_logger = logging.getLogger(__name__) # type: ignore +_logger = logging.getLogger(__name__) _API_FIELD_TO_S3_OBJECT_PROPERTY = { "ETag": "etag", @@ -104,7 +104,7 @@ class S3Object(MutableMapping[str, Any]): def __init__( self, - init: Dict[str, Any], + init: dict[str, Any], **kwargs, ) -> None: if init: @@ -162,7 +162,7 @@ def __len__(self) -> int: def __str__(self): return str(self.__dict__) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """Convert S3Object to dictionary representation. Returns: @@ -170,7 +170,7 @@ def to_dict(self) -> Dict[str, Any]: """ return copy.deepcopy(self.__dict__) - def to_api_repr(self) -> Dict[str, Any]: + def to_api_repr(self) -> dict[str, Any]: fields = {} for k, v in _API_FIELD_TO_S3_OBJECT_PROPERTY.items(): if k in ["ETag", "ContentLength", "LastModified"]: @@ -201,14 +201,14 @@ class S3PutObject: typically not instantiated directly by users. 
""" - def __init__(self, response: Dict[str, Any]) -> None: - self._expiration: Optional[str] = response.get("Expiration") - self._version_id: Optional[str] = response.get("VersionId") - self._etag: Optional[str] = response.get("ETag") - self._checksum_crc32: Optional[str] = response.get("ChecksumCRC32") - self._checksum_crc32c: Optional[str] = response.get("ChecksumCRC32C") - self._checksum_sha1: Optional[str] = response.get("ChecksumSHA1") - self._checksum_sha256: Optional[str] = response.get("ChecksumSHA256") + def __init__(self, response: dict[str, Any]) -> None: + self._expiration: str | None = response.get("Expiration") + self._version_id: str | None = response.get("VersionId") + self._etag: str | None = response.get("ETag") + self._checksum_crc32: str | None = response.get("ChecksumCRC32") + self._checksum_crc32c: str | None = response.get("ChecksumCRC32C") + self._checksum_sha1: str | None = response.get("ChecksumSHA1") + self._checksum_sha256: str | None = response.get("ChecksumSHA256") self._server_side_encryption = response.get("ServerSideEncryption") self._sse_customer_algorithm = response.get("SSECustomerAlgorithm") self._sse_customer_key_md5 = response.get("SSECustomerKeyMD5") @@ -218,62 +218,62 @@ def __init__(self, response: Dict[str, Any]) -> None: self._request_charged = response.get("RequestCharged") @property - def expiration(self) -> Optional[str]: + def expiration(self) -> str | None: return self._expiration @property - def version_id(self) -> Optional[str]: + def version_id(self) -> str | None: return self._version_id @property - def etag(self) -> Optional[str]: + def etag(self) -> str | None: return self._etag @property - def checksum_crc32(self) -> Optional[str]: + def checksum_crc32(self) -> str | None: return self._checksum_crc32 @property - def checksum_crc32c(self) -> Optional[str]: + def checksum_crc32c(self) -> str | None: return self._checksum_crc32c @property - def checksum_sha1(self) -> Optional[str]: + def checksum_sha1(self) -> 
str | None: return self._checksum_sha1 @property - def checksum_sha256(self) -> Optional[str]: + def checksum_sha256(self) -> str | None: return self._checksum_sha256 @property - def server_side_encryption(self) -> Optional[str]: + def server_side_encryption(self) -> str | None: return self._server_side_encryption @property - def sse_customer_algorithm(self) -> Optional[str]: + def sse_customer_algorithm(self) -> str | None: return self._sse_customer_algorithm @property - def sse_customer_key_md5(self) -> Optional[str]: + def sse_customer_key_md5(self) -> str | None: return self._sse_customer_key_md5 @property - def sse_kms_key_id(self) -> Optional[str]: + def sse_kms_key_id(self) -> str | None: return self._sse_kms_key_id @property - def sse_kms_encryption_context(self) -> Optional[str]: + def sse_kms_encryption_context(self) -> str | None: return self._sse_kms_encryption_context @property - def bucket_key_enabled(self) -> Optional[bool]: + def bucket_key_enabled(self) -> bool | None: return self._bucket_key_enabled @property - def request_charged(self) -> Optional[str]: + def request_charged(self) -> str | None: return self._request_charged - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return copy.deepcopy(self.__dict__) @@ -295,7 +295,7 @@ class S3MultipartUpload: Used internally by S3FileSystem for large file upload operations. 
""" - def __init__(self, response: Dict[str, Any]) -> None: + def __init__(self, response: dict[str, Any]) -> None: self._abort_date = response.get("AbortDate") self._abort_rule_id = response.get("AbortRuleId") self._bucket = response.get("Bucket") @@ -311,55 +311,55 @@ def __init__(self, response: Dict[str, Any]) -> None: self._checksum_algorithm = response.get("ChecksumAlgorithm") @property - def abort_date(self) -> Optional[datetime]: + def abort_date(self) -> datetime | None: return self._abort_date @property - def abort_rule_id(self) -> Optional[str]: + def abort_rule_id(self) -> str | None: return self._abort_rule_id @property - def bucket(self) -> Optional[str]: + def bucket(self) -> str | None: return self._bucket @property - def key(self) -> Optional[str]: + def key(self) -> str | None: return self._key @property - def upload_id(self) -> Optional[str]: + def upload_id(self) -> str | None: return self._upload_id @property - def server_side_encryption(self) -> Optional[str]: + def server_side_encryption(self) -> str | None: return self._server_side_encryption @property - def sse_customer_algorithm(self) -> Optional[str]: + def sse_customer_algorithm(self) -> str | None: return self._sse_customer_algorithm @property - def sse_customer_key_md5(self) -> Optional[str]: + def sse_customer_key_md5(self) -> str | None: return self._sse_customer_key_md5 @property - def sse_kms_key_id(self) -> Optional[str]: + def sse_kms_key_id(self) -> str | None: return self._sse_kms_key_id @property - def sse_kms_encryption_context(self) -> Optional[str]: + def sse_kms_encryption_context(self) -> str | None: return self._sse_kms_encryption_context @property - def bucket_key_enabled(self) -> Optional[bool]: + def bucket_key_enabled(self) -> bool | None: return self._bucket_key_enabled @property - def request_charged(self) -> Optional[str]: + def request_charged(self) -> str | None: return self._request_charged @property - def checksum_algorithm(self) -> Optional[str]: + def 
checksum_algorithm(self) -> str | None: return self._checksum_algorithm @@ -381,17 +381,17 @@ class S3MultipartUploadPart: by S3FileSystem for chunked upload operations. """ - def __init__(self, part_number: int, response: Dict[str, Any]) -> None: + def __init__(self, part_number: int, response: dict[str, Any]) -> None: self._part_number = part_number - self._copy_source_version_id: Optional[str] = response.get("CopySourceVersionId") + self._copy_source_version_id: str | None = response.get("CopySourceVersionId") copy_part_result = response.get("CopyPartResult") if copy_part_result: - self._last_modified: Optional[datetime] = copy_part_result.get("LastModified") - self._etag: Optional[str] = copy_part_result.get("ETag") - self._checksum_crc32: Optional[str] = copy_part_result.get("ChecksumCRC32") - self._checksum_crc32c: Optional[str] = copy_part_result.get("ChecksumCRC32C") - self._checksum_sha1: Optional[str] = copy_part_result.get("ChecksumSHA1") - self._checksum_sha256: Optional[str] = copy_part_result.get("ChecksumSHA256") + self._last_modified: datetime | None = copy_part_result.get("LastModified") + self._etag: str | None = copy_part_result.get("ETag") + self._checksum_crc32: str | None = copy_part_result.get("ChecksumCRC32") + self._checksum_crc32c: str | None = copy_part_result.get("ChecksumCRC32C") + self._checksum_sha1: str | None = copy_part_result.get("ChecksumSHA1") + self._checksum_sha256: str | None = copy_part_result.get("ChecksumSHA256") else: self._last_modified = None self._etag = response.get("ETag") @@ -399,70 +399,70 @@ def __init__(self, part_number: int, response: Dict[str, Any]) -> None: self._checksum_crc32c = response.get("ChecksumCRC32C") self._checksum_sha1 = response.get("ChecksumSHA1") self._checksum_sha256 = response.get("ChecksumSHA256") - self._server_side_encryption: Optional[str] = response.get("ServerSideEncryption") - self._sse_customer_algorithm: Optional[str] = response.get("SSECustomerAlgorithm") - 
self._sse_customer_key_md5: Optional[str] = response.get("SSECustomerKeyMD5") - self._sse_kms_key_id: Optional[str] = response.get("SSEKMSKeyId") - self._bucket_key_enabled: Optional[bool] = response.get("BucketKeyEnabled") - self._request_charged: Optional[str] = response.get("RequestCharged") + self._server_side_encryption: str | None = response.get("ServerSideEncryption") + self._sse_customer_algorithm: str | None = response.get("SSECustomerAlgorithm") + self._sse_customer_key_md5: str | None = response.get("SSECustomerKeyMD5") + self._sse_kms_key_id: str | None = response.get("SSEKMSKeyId") + self._bucket_key_enabled: bool | None = response.get("BucketKeyEnabled") + self._request_charged: str | None = response.get("RequestCharged") @property def part_number(self) -> int: return self._part_number @property - def copy_source_version_id(self) -> Optional[str]: + def copy_source_version_id(self) -> str | None: return self._copy_source_version_id @property - def last_modified(self) -> Optional[datetime]: + def last_modified(self) -> datetime | None: return self._last_modified @property - def etag(self) -> Optional[str]: + def etag(self) -> str | None: return self._etag @property - def checksum_crc32(self) -> Optional[str]: + def checksum_crc32(self) -> str | None: return self._checksum_crc32 @property - def checksum_crc32c(self) -> Optional[str]: + def checksum_crc32c(self) -> str | None: return self._checksum_crc32c @property - def checksum_sha1(self) -> Optional[str]: + def checksum_sha1(self) -> str | None: return self._checksum_sha1 @property - def checksum_sha256(self) -> Optional[str]: + def checksum_sha256(self) -> str | None: return self._checksum_sha256 @property - def server_side_encryption(self) -> Optional[str]: + def server_side_encryption(self) -> str | None: return self._server_side_encryption @property - def sse_customer_algorithm(self) -> Optional[str]: + def sse_customer_algorithm(self) -> str | None: return self._sse_customer_algorithm @property - 
def sse_customer_key_md5(self) -> Optional[str]: + def sse_customer_key_md5(self) -> str | None: return self._sse_customer_key_md5 @property - def sse_kms_key_id(self) -> Optional[str]: + def sse_kms_key_id(self) -> str | None: return self._sse_kms_key_id @property - def bucket_key_enabled(self) -> Optional[bool]: + def bucket_key_enabled(self) -> bool | None: return self._bucket_key_enabled @property - def request_charged(self) -> Optional[str]: + def request_charged(self) -> str | None: return self._request_charged - def to_api_repr(self) -> Dict[str, Any]: + def to_api_repr(self) -> dict[str, Any]: return { "ETag": self.etag, "ChecksumCRC32": self.checksum_crc32, @@ -493,76 +493,76 @@ class S3CompleteMultipartUpload: Used internally by S3FileSystem operations. """ - def __init__(self, response: Dict[str, Any]) -> None: - self._location: Optional[str] = response.get("Location") - self._bucket: Optional[str] = response.get("Bucket") - self._key: Optional[str] = response.get("Key") - self._expiration: Optional[str] = response.get("Expiration") - self._version_id: Optional[str] = response.get("VersionId") - self._etag: Optional[str] = response.get("ETag") - self._checksum_crc32: Optional[str] = response.get("ChecksumCRC32") - self._checksum_crc32c: Optional[str] = response.get("ChecksumCRC32C") - self._checksum_sha1: Optional[str] = response.get("ChecksumSHA1") - self._checksum_sha256: Optional[str] = response.get("ChecksumSHA256") + def __init__(self, response: dict[str, Any]) -> None: + self._location: str | None = response.get("Location") + self._bucket: str | None = response.get("Bucket") + self._key: str | None = response.get("Key") + self._expiration: str | None = response.get("Expiration") + self._version_id: str | None = response.get("VersionId") + self._etag: str | None = response.get("ETag") + self._checksum_crc32: str | None = response.get("ChecksumCRC32") + self._checksum_crc32c: str | None = response.get("ChecksumCRC32C") + self._checksum_sha1: str | 
None = response.get("ChecksumSHA1") + self._checksum_sha256: str | None = response.get("ChecksumSHA256") self._server_side_encryption = response.get("ServerSideEncryption") self._sse_kms_key_id = response.get("SSEKMSKeyId") self._bucket_key_enabled = response.get("BucketKeyEnabled") self._request_charged = response.get("RequestCharged") @property - def location(self) -> Optional[str]: + def location(self) -> str | None: return self._location @property - def bucket(self) -> Optional[str]: + def bucket(self) -> str | None: return self._bucket @property - def key(self) -> Optional[str]: + def key(self) -> str | None: return self._key @property - def expiration(self) -> Optional[str]: + def expiration(self) -> str | None: return self._expiration @property - def version_id(self) -> Optional[str]: + def version_id(self) -> str | None: return self._version_id @property - def etag(self) -> Optional[str]: + def etag(self) -> str | None: return self._etag @property - def checksum_crc32(self) -> Optional[str]: + def checksum_crc32(self) -> str | None: return self._checksum_crc32 @property - def checksum_crc32c(self) -> Optional[str]: + def checksum_crc32c(self) -> str | None: return self._checksum_crc32c @property - def checksum_sha1(self) -> Optional[str]: + def checksum_sha1(self) -> str | None: return self._checksum_sha1 @property - def checksum_sha256(self) -> Optional[str]: + def checksum_sha256(self) -> str | None: return self._checksum_sha256 @property - def server_side_encryption(self) -> Optional[str]: + def server_side_encryption(self) -> str | None: return self._server_side_encryption @property - def sse_kms_key_id(self) -> Optional[str]: + def sse_kms_key_id(self) -> str | None: return self._sse_kms_key_id @property - def bucket_key_enabled(self) -> Optional[bool]: + def bucket_key_enabled(self) -> bool | None: return self._bucket_key_enabled @property - def request_charged(self) -> Optional[str]: + def request_charged(self) -> str | None: return 
self._request_charged def to_dict(self): diff --git a/pyathena/formatter.py b/pyathena/formatter.py index 4053aead..694b6a9a 100644 --- a/pyathena/formatter.py +++ b/pyathena/formatter.py @@ -1,19 +1,19 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import logging import textwrap import uuid from abc import ABCMeta, abstractmethod +from collections.abc import Callable from copy import deepcopy from datetime import date, datetime, timezone from decimal import Decimal -from typing import Any, Callable, Dict, Optional, Tuple, Type +from typing import Any from pyathena.error import ProgrammingError from pyathena.model import AthenaCompression, AthenaFileFormat -_logger = logging.getLogger(__name__) # type: ignore +_logger = logging.getLogger(__name__) class Formatter(metaclass=ABCMeta): @@ -33,8 +33,8 @@ class Formatter(metaclass=ABCMeta): def __init__( self, - mappings: Dict[Type[Any], Callable[[Formatter, Callable[[str], str], Any], Any]], - default: Optional[Callable[[Formatter, Callable[[str], str], Any], Any]] = None, + mappings: dict[type[Any], Callable[[Formatter, Callable[[str], str], Any], Any]], + default: Callable[[Formatter, Callable[[str], str], Any], Any] | None = None, ) -> None: self._mappings = mappings self._default = default @@ -42,7 +42,7 @@ def __init__( @property def mappings( self, - ) -> Dict[Type[Any], Callable[[Formatter, Callable[[str], str], Any], Any]]: + ) -> dict[type[Any], Callable[[Formatter, Callable[[str], str], Any], Any]]: """Get the current parameter formatting mappings. Returns: @@ -50,7 +50,7 @@ def mappings( """ return self._mappings - def get(self, type_) -> Optional[Callable[[Formatter, Callable[[str], str], Any], Any]]: + def get(self, type_) -> Callable[[Formatter, Callable[[str], str], Any], Any] | None: """Get the formatting function for a specific Python type. 
Args: @@ -63,21 +63,21 @@ def get(self, type_) -> Optional[Callable[[Formatter, Callable[[str], str], Any] def set( self, - type_: Type[Any], + type_: type[Any], formatter: Callable[[Formatter, Callable[[str], str], Any], Any], ) -> None: self.mappings[type_] = formatter - def remove(self, type_: Type[Any]) -> None: + def remove(self, type_: type[Any]) -> None: self.mappings.pop(type_, None) def update( - self, mappings: Dict[Type[Any], Callable[[Formatter, Callable[[str], str], Any], Any]] + self, mappings: dict[type[Any], Callable[[Formatter, Callable[[str], str], Any], Any]] ) -> None: self.mappings.update(mappings) @abstractmethod - def format(self, operation: str, parameters: Optional[Dict[str, Any]] = None) -> str: + def format(self, operation: str, parameters: dict[str, Any] | None = None) -> str: raise NotImplementedError # pragma: no cover @staticmethod @@ -86,7 +86,7 @@ def wrap_unload( s3_staging_dir: str, format_: str = AthenaFileFormat.FILE_FORMAT_PARQUET, compression: str = AthenaCompression.COMPRESSION_SNAPPY, - ) -> Tuple[str, Optional[str]]: + ) -> tuple[str, str | None]: """Wrap a SELECT query with UNLOAD statement for high-performance result retrieval. 
Transforms SELECT or WITH queries into UNLOAD statements that export results @@ -129,9 +129,9 @@ def wrap_unload( raise ProgrammingError("Query is none or empty.") operation_upper = operation.strip().upper() - if operation_upper.startswith("SELECT") or operation_upper.startswith("WITH"): + if operation_upper.startswith(("SELECT", "WITH")): now = datetime.now(timezone.utc).strftime("%Y%m%d") - location = f"{s3_staging_dir}unload/{now}/{str(uuid.uuid4())}/" + location = f"{s3_staging_dir}unload/{now}/{uuid.uuid4()!s}/" operation = textwrap.dedent( f""" UNLOAD ( @@ -220,7 +220,7 @@ def _format_decimal(formatter: Formatter, escaper: Callable[[str], str], val: An return f"DECIMAL {escaped}" -_DEFAULT_FORMATTERS: Dict[Type[Any], Callable[[Formatter, Callable[[str], str], Any], Any]] = { +_DEFAULT_FORMATTERS: dict[type[Any], Callable[[Formatter, Callable[[str], str], Any], Any]] = { type(None): _format_none, date: _format_date, datetime: _format_datetime, @@ -263,24 +263,18 @@ class DefaultParameterFormatter(Formatter): def __init__(self) -> None: super().__init__(mappings=deepcopy(_DEFAULT_FORMATTERS), default=None) - def format(self, operation: str, parameters: Optional[Dict[str, Any]] = None) -> str: + def format(self, operation: str, parameters: dict[str, Any] | None = None) -> str: if not operation or not operation.strip(): raise ProgrammingError("Query is none or empty.") operation = operation.strip() operation_upper = operation.upper() - if ( - operation_upper.startswith("SELECT") - or operation_upper.startswith("WITH") - or operation_upper.startswith("INSERT") - or operation_upper.startswith("UPDATE") - or operation_upper.startswith("MERGE") - ): + if operation_upper.startswith(("SELECT", "WITH", "INSERT", "UPDATE", "MERGE")): escaper = _escape_presto else: escaper = _escape_hive - kwargs: Optional[Dict[str, Any]] = None + kwargs: dict[str, Any] | None = None if parameters is not None: kwargs = {} if not parameters: diff --git a/pyathena/model.py 
b/pyathena/model.py index 9b90b068..ef5a1539 100644 --- a/pyathena/model.py +++ b/pyathena/model.py @@ -1,14 +1,14 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import logging import re from datetime import datetime -from typing import Any, Dict, List, Optional, Pattern +from re import Pattern +from typing import Any from pyathena.error import DataError -_logger = logging.getLogger(__name__) # type: ignore +_logger = logging.getLogger(__name__) class AthenaQueryExecution: @@ -64,212 +64,212 @@ class AthenaQueryExecution: S3_ACL_OPTION_BUCKET_OWNER_FULL_CONTROL = "BUCKET_OWNER_FULL_CONTROL" - def __init__(self, response: Dict[str, Any]) -> None: + def __init__(self, response: dict[str, Any]) -> None: query_execution = response.get("QueryExecution") if not query_execution: raise DataError("KeyError `QueryExecution`") query_execution_context = query_execution.get("QueryExecutionContext", {}) - self._database: Optional[str] = query_execution_context.get("Database") - self._catalog: Optional[str] = query_execution_context.get("Catalog") + self._database: str | None = query_execution_context.get("Database") + self._catalog: str | None = query_execution_context.get("Catalog") - self._query_id: Optional[str] = query_execution.get("QueryExecutionId") + self._query_id: str | None = query_execution.get("QueryExecutionId") if not self._query_id: raise DataError("KeyError `QueryExecutionId`") - self._query: Optional[str] = query_execution.get("Query") + self._query: str | None = query_execution.get("Query") if not self._query: raise DataError("KeyError `Query`") - self._statement_type: Optional[str] = query_execution.get("StatementType") - self._substatement_type: Optional[str] = query_execution.get("SubstatementType") - self._work_group: Optional[str] = query_execution.get("WorkGroup") - self._execution_parameters: List[str] = query_execution.get("ExecutionParameters", []) + self._statement_type: str | None = query_execution.get("StatementType") + 
self._substatement_type: str | None = query_execution.get("SubstatementType") + self._work_group: str | None = query_execution.get("WorkGroup") + self._execution_parameters: list[str] = query_execution.get("ExecutionParameters", []) status = query_execution.get("Status") if not status: raise DataError("KeyError `Status`") - self._state: Optional[str] = status.get("State") - self._state_change_reason: Optional[str] = status.get("StateChangeReason") - self._submission_date_time: Optional[datetime] = status.get("SubmissionDateTime") - self._completion_date_time: Optional[datetime] = status.get("CompletionDateTime") + self._state: str | None = status.get("State") + self._state_change_reason: str | None = status.get("StateChangeReason") + self._submission_date_time: datetime | None = status.get("SubmissionDateTime") + self._completion_date_time: datetime | None = status.get("CompletionDateTime") athena_error = status.get("AthenaError", {}) - self._error_category: Optional[int] = athena_error.get("ErrorCategory") - self._error_type: Optional[int] = athena_error.get("ErrorType") - self._retryable: Optional[bool] = athena_error.get("Retryable") - self._error_message: Optional[str] = athena_error.get("ErrorMessage") + self._error_category: int | None = athena_error.get("ErrorCategory") + self._error_type: int | None = athena_error.get("ErrorType") + self._retryable: bool | None = athena_error.get("Retryable") + self._error_message: str | None = athena_error.get("ErrorMessage") statistics = query_execution.get("Statistics", {}) - self._data_scanned_in_bytes: Optional[int] = statistics.get("DataScannedInBytes") - self._engine_execution_time_in_millis: Optional[int] = statistics.get( + self._data_scanned_in_bytes: int | None = statistics.get("DataScannedInBytes") + self._engine_execution_time_in_millis: int | None = statistics.get( "EngineExecutionTimeInMillis", None ) - self._query_queue_time_in_millis: Optional[int] = statistics.get( + self._query_queue_time_in_millis: int | 
None = statistics.get( "QueryQueueTimeInMillis", None ) - self._total_execution_time_in_millis: Optional[int] = statistics.get( + self._total_execution_time_in_millis: int | None = statistics.get( "TotalExecutionTimeInMillis", None ) - self._query_planning_time_in_millis: Optional[int] = statistics.get( + self._query_planning_time_in_millis: int | None = statistics.get( "QueryPlanningTimeInMillis", None ) - self._service_processing_time_in_millis: Optional[int] = statistics.get( + self._service_processing_time_in_millis: int | None = statistics.get( "ServiceProcessingTimeInMillis", None ) - self._data_manifest_location: Optional[str] = statistics.get("DataManifestLocation") + self._data_manifest_location: str | None = statistics.get("DataManifestLocation") reuse_info = statistics.get("ResultReuseInformation", {}) - self._reused_previous_result: Optional[bool] = reuse_info.get("ReusedPreviousResult") + self._reused_previous_result: bool | None = reuse_info.get("ReusedPreviousResult") result_conf = query_execution.get("ResultConfiguration", {}) - self._output_location: Optional[str] = result_conf.get("OutputLocation") + self._output_location: str | None = result_conf.get("OutputLocation") encryption_conf = result_conf.get("EncryptionConfiguration", {}) - self._encryption_option: Optional[str] = encryption_conf.get("EncryptionOption") - self._kms_key: Optional[str] = encryption_conf.get("KmsKey") - self._expected_bucket_owner: Optional[str] = result_conf.get("ExpectedBucketOwner") + self._encryption_option: str | None = encryption_conf.get("EncryptionOption") + self._kms_key: str | None = encryption_conf.get("KmsKey") + self._expected_bucket_owner: str | None = result_conf.get("ExpectedBucketOwner") acl_conf = result_conf.get("AclConfiguration", {}) - self._s3_acl_option: Optional[str] = acl_conf.get("S3AclOption") + self._s3_acl_option: str | None = acl_conf.get("S3AclOption") engine_version = query_execution.get("EngineVersion", {}) - self._selected_engine_version: 
Optional[str] = engine_version.get( + self._selected_engine_version: str | None = engine_version.get( "SelectedEngineVersion", None ) - self._effective_engine_version: Optional[str] = engine_version.get( + self._effective_engine_version: str | None = engine_version.get( "EffectiveEngineVersion", None ) reuse_conf = query_execution.get("ResultReuseConfiguration", {}) reuse_age_conf = reuse_conf.get("ResultReuseByAgeConfiguration", {}) - self._result_reuse_enabled: Optional[bool] = reuse_age_conf.get("Enabled") - self._result_reuse_minutes: Optional[int] = reuse_age_conf.get("MaxAgeInMinutes") + self._result_reuse_enabled: bool | None = reuse_age_conf.get("Enabled") + self._result_reuse_minutes: int | None = reuse_age_conf.get("MaxAgeInMinutes") @property - def database(self) -> Optional[str]: + def database(self) -> str | None: return self._database @property - def catalog(self) -> Optional[str]: + def catalog(self) -> str | None: return self._catalog @property - def query_id(self) -> Optional[str]: + def query_id(self) -> str | None: return self._query_id @property - def query(self) -> Optional[str]: + def query(self) -> str | None: return self._query @property - def statement_type(self) -> Optional[str]: + def statement_type(self) -> str | None: return self._statement_type @property - def substatement_type(self) -> Optional[str]: + def substatement_type(self) -> str | None: return self._substatement_type @property - def work_group(self) -> Optional[str]: + def work_group(self) -> str | None: return self._work_group @property - def execution_parameters(self) -> List[str]: + def execution_parameters(self) -> list[str]: return self._execution_parameters @property - def state(self) -> Optional[str]: + def state(self) -> str | None: return self._state @property - def state_change_reason(self) -> Optional[str]: + def state_change_reason(self) -> str | None: return self._state_change_reason @property - def submission_date_time(self) -> Optional[datetime]: + def 
submission_date_time(self) -> datetime | None: return self._submission_date_time @property - def completion_date_time(self) -> Optional[datetime]: + def completion_date_time(self) -> datetime | None: return self._completion_date_time @property - def error_category(self) -> Optional[int]: + def error_category(self) -> int | None: return self._error_category @property - def error_type(self) -> Optional[int]: + def error_type(self) -> int | None: return self._error_type @property - def retryable(self) -> Optional[bool]: + def retryable(self) -> bool | None: return self._retryable @property - def error_message(self) -> Optional[str]: + def error_message(self) -> str | None: return self._error_message @property - def data_scanned_in_bytes(self) -> Optional[int]: + def data_scanned_in_bytes(self) -> int | None: return self._data_scanned_in_bytes @property - def engine_execution_time_in_millis(self) -> Optional[int]: + def engine_execution_time_in_millis(self) -> int | None: return self._engine_execution_time_in_millis @property - def query_queue_time_in_millis(self) -> Optional[int]: + def query_queue_time_in_millis(self) -> int | None: return self._query_queue_time_in_millis @property - def total_execution_time_in_millis(self) -> Optional[int]: + def total_execution_time_in_millis(self) -> int | None: return self._total_execution_time_in_millis @property - def query_planning_time_in_millis(self) -> Optional[int]: + def query_planning_time_in_millis(self) -> int | None: return self._query_planning_time_in_millis @property - def service_processing_time_in_millis(self) -> Optional[int]: + def service_processing_time_in_millis(self) -> int | None: return self._service_processing_time_in_millis @property - def output_location(self) -> Optional[str]: + def output_location(self) -> str | None: return self._output_location @property - def data_manifest_location(self) -> Optional[str]: + def data_manifest_location(self) -> str | None: return self._data_manifest_location 
@property - def reused_previous_result(self) -> Optional[bool]: + def reused_previous_result(self) -> bool | None: return self._reused_previous_result @property - def encryption_option(self) -> Optional[str]: + def encryption_option(self) -> str | None: return self._encryption_option @property - def kms_key(self) -> Optional[str]: + def kms_key(self) -> str | None: return self._kms_key @property - def expected_bucket_owner(self) -> Optional[str]: + def expected_bucket_owner(self) -> str | None: return self._expected_bucket_owner @property - def s3_acl_option(self) -> Optional[str]: + def s3_acl_option(self) -> str | None: return self._s3_acl_option @property - def selected_engine_version(self) -> Optional[str]: + def selected_engine_version(self) -> str | None: return self._selected_engine_version @property - def effective_engine_version(self) -> Optional[str]: + def effective_engine_version(self) -> str | None: return self._effective_engine_version @property - def result_reuse_enabled(self) -> Optional[bool]: + def result_reuse_enabled(self) -> bool | None: return self._result_reuse_enabled @property - def result_reuse_minutes(self) -> Optional[int]: + def result_reuse_minutes(self) -> int | None: return self._result_reuse_minutes @@ -304,43 +304,43 @@ class AthenaCalculationExecutionStatus: STATE_COMPLETED: str = "COMPLETED" STATE_FAILED: str = "FAILED" - def __init__(self, response: Dict[str, Any]) -> None: + def __init__(self, response: dict[str, Any]) -> None: status = response.get("Status") if not status: raise DataError("KeyError `Status`") - self._state: Optional[str] = status.get("State") - self._state_change_reason: Optional[str] = status.get("StateChangeReason") - self._submission_date_time: Optional[datetime] = status.get("SubmissionDateTime") - self._completion_date_time: Optional[datetime] = status.get("CompletionDateTime") + self._state: str | None = status.get("State") + self._state_change_reason: str | None = status.get("StateChangeReason") + 
self._submission_date_time: datetime | None = status.get("SubmissionDateTime") + self._completion_date_time: datetime | None = status.get("CompletionDateTime") statistics = response.get("Statistics") if not statistics: raise DataError("KeyError `Statistics`") - self._dpu_execution_in_millis: Optional[int] = statistics.get("DpuExecutionInMillis") - self._progress: Optional[str] = statistics.get("Progress") + self._dpu_execution_in_millis: int | None = statistics.get("DpuExecutionInMillis") + self._progress: str | None = statistics.get("Progress") @property - def state(self) -> Optional[str]: + def state(self) -> str | None: return self._state @property - def state_change_reason(self) -> Optional[str]: + def state_change_reason(self) -> str | None: return self._state_change_reason @property - def submission_date_time(self) -> Optional[datetime]: + def submission_date_time(self) -> datetime | None: return self._submission_date_time @property - def completion_date_time(self) -> Optional[datetime]: + def completion_date_time(self) -> datetime | None: return self._completion_date_time @property - def dpu_execution_in_millis(self) -> Optional[int]: + def dpu_execution_in_millis(self) -> int | None: return self._dpu_execution_in_millis @property - def progress(self) -> Optional[str]: + def progress(self) -> str | None: return self._progress @@ -359,55 +359,55 @@ class AthenaCalculationExecution(AthenaCalculationExecutionStatus): https://docs.aws.amazon.com/athena/latest/APIReference/API_CalculationSummary.html """ - def __init__(self, response: Dict[str, Any]) -> None: + def __init__(self, response: dict[str, Any]) -> None: super().__init__(response) - self._calculation_id: Optional[str] = response.get("CalculationExecutionId") + self._calculation_id: str | None = response.get("CalculationExecutionId") if not self._calculation_id: raise DataError("KeyError `CalculationExecutionId`") - self._session_id: Optional[str] = response.get("SessionId") + self._session_id: str | 
None = response.get("SessionId") if not self._session_id: raise DataError("KeyError `SessionId`") - self._description: Optional[str] = response.get("Description") - self._working_directory: Optional[str] = response.get("WorkingDirectory") + self._description: str | None = response.get("Description") + self._working_directory: str | None = response.get("WorkingDirectory") # If cancelled, the result does not exist. result = response.get("Result", {}) - self._std_out_s3_uri: Optional[str] = result.get("StdOutS3Uri") - self._std_error_s3_uri: Optional[str] = result.get("StdErrorS3Uri") - self._result_s3_uri: Optional[str] = result.get("ResultS3Uri") - self._result_type: Optional[str] = result.get("ResultType") + self._std_out_s3_uri: str | None = result.get("StdOutS3Uri") + self._std_error_s3_uri: str | None = result.get("StdErrorS3Uri") + self._result_s3_uri: str | None = result.get("ResultS3Uri") + self._result_type: str | None = result.get("ResultType") @property - def calculation_id(self) -> Optional[str]: + def calculation_id(self) -> str | None: return self._calculation_id @property - def session_id(self) -> Optional[str]: + def session_id(self) -> str | None: return self._session_id @property - def description(self) -> Optional[str]: + def description(self) -> str | None: return self._description @property - def working_directory(self) -> Optional[str]: + def working_directory(self) -> str | None: return self._working_directory @property - def std_out_s3_uri(self) -> Optional[str]: + def std_out_s3_uri(self) -> str | None: return self._std_out_s3_uri @property - def std_error_s3_uri(self) -> Optional[str]: + def std_error_s3_uri(self) -> str | None: return self._std_error_s3_uri @property - def result_s3_uri(self) -> Optional[str]: + def result_s3_uri(self) -> str | None: return self._result_s3_uri @property - def result_type(self) -> Optional[str]: + def result_type(self) -> str | None: return self._result_type @@ -442,45 +442,45 @@ class AthenaSessionStatus: 
STATE_DEGRADED: str = "DEGRADED" STATE_FAILED: str = "FAILED" - def __init__(self, response: Dict[str, Any]) -> None: - self._session_id: Optional[str] = response.get("SessionId") + def __init__(self, response: dict[str, Any]) -> None: + self._session_id: str | None = response.get("SessionId") status = response.get("Status") if not status: raise DataError("KeyError `Status`") - self._state: Optional[str] = status.get("State") - self._state_change_reason: Optional[str] = status.get("StateChangeReason") - self._start_date_time: Optional[datetime] = status.get("StartDateTime") - self._last_modified_date_time: Optional[datetime] = status.get("LastModifiedDateTime") - self._end_date_time: Optional[datetime] = status.get("EndDateTime") - self._idle_since_date_time: Optional[datetime] = status.get("IdleSinceDateTime") + self._state: str | None = status.get("State") + self._state_change_reason: str | None = status.get("StateChangeReason") + self._start_date_time: datetime | None = status.get("StartDateTime") + self._last_modified_date_time: datetime | None = status.get("LastModifiedDateTime") + self._end_date_time: datetime | None = status.get("EndDateTime") + self._idle_since_date_time: datetime | None = status.get("IdleSinceDateTime") @property - def session_id(self) -> Optional[str]: + def session_id(self) -> str | None: return self._session_id @property - def state(self) -> Optional[str]: + def state(self) -> str | None: return self._state @property - def state_change_reason(self) -> Optional[str]: + def state_change_reason(self) -> str | None: return self._state_change_reason @property - def start_date_time(self) -> Optional[datetime]: + def start_date_time(self) -> datetime | None: return self._start_date_time @property - def last_modified_date_time(self) -> Optional[datetime]: + def last_modified_date_time(self) -> datetime | None: return self._last_modified_date_time @property - def end_date_time(self) -> Optional[datetime]: + def end_date_time(self) -> datetime | 
None: return self._end_date_time @property - def idle_since_date_time(self) -> Optional[datetime]: + def idle_since_date_time(self) -> datetime | None: return self._idle_since_date_time @@ -501,20 +501,20 @@ def __init__(self, response): if not database: raise DataError("KeyError `Database`") - self._name: Optional[str] = database.get("Name") - self._description: Optional[str] = database.get("Description") - self._parameters: Dict[str, str] = database.get("Parameters", {}) + self._name: str | None = database.get("Name") + self._description: str | None = database.get("Description") + self._parameters: dict[str, str] = database.get("Parameters", {}) @property - def name(self) -> Optional[str]: + def name(self) -> str | None: return self._name @property - def description(self) -> Optional[str]: + def description(self) -> str | None: return self._description @property - def parameters(self) -> Dict[str, str]: + def parameters(self) -> dict[str, str]: return self._parameters @@ -530,20 +530,20 @@ class AthenaTableMetadataColumn: """ def __init__(self, response): - self._name: Optional[str] = response.get("Name") - self._type: Optional[str] = response.get("Type") - self._comment: Optional[str] = response.get("Comment") + self._name: str | None = response.get("Name") + self._type: str | None = response.get("Type") + self._comment: str | None = response.get("Comment") @property - def name(self) -> Optional[str]: + def name(self) -> str | None: return self._name @property - def type(self) -> Optional[str]: + def type(self) -> str | None: return self._type @property - def comment(self) -> Optional[str]: + def comment(self) -> str | None: return self._comment @@ -560,20 +560,20 @@ class AthenaTableMetadataPartitionKey: """ def __init__(self, response): - self._name: Optional[str] = response.get("Name") - self._type: Optional[str] = response.get("Type") - self._comment: Optional[str] = response.get("Comment") + self._name: str | None = response.get("Name") + self._type: str | 
None = response.get("Type") + self._comment: str | None = response.get("Comment") @property - def name(self) -> Optional[str]: + def name(self) -> str | None: return self._name @property - def type(self) -> Optional[str]: + def type(self) -> str | None: return self._type @property - def comment(self) -> Optional[str]: + def comment(self) -> str | None: return self._comment @@ -597,76 +597,76 @@ def __init__(self, response): if not table_metadata: raise DataError("KeyError `TableMetadata`") - self._name: Optional[str] = table_metadata.get("Name") - self._create_time: Optional[datetime] = table_metadata.get("CreateTime") - self._last_access_time: Optional[datetime] = table_metadata.get("LastAccessTime") - self._table_type: Optional[str] = table_metadata.get("TableType") + self._name: str | None = table_metadata.get("Name") + self._create_time: datetime | None = table_metadata.get("CreateTime") + self._last_access_time: datetime | None = table_metadata.get("LastAccessTime") + self._table_type: str | None = table_metadata.get("TableType") columns = table_metadata.get("Columns", []) - self._columns: List[AthenaTableMetadataColumn] = [] + self._columns: list[AthenaTableMetadataColumn] = [] for column in columns: self._columns.append(AthenaTableMetadataColumn(column)) partition_keys = table_metadata.get("PartitionKeys", []) - self._partition_keys: List[AthenaTableMetadataPartitionKey] = [] + self._partition_keys: list[AthenaTableMetadataPartitionKey] = [] for key in partition_keys: self._partition_keys.append(AthenaTableMetadataPartitionKey(key)) - self._parameters: Dict[str, str] = table_metadata.get("Parameters", {}) + self._parameters: dict[str, str] = table_metadata.get("Parameters", {}) @property - def name(self) -> Optional[str]: + def name(self) -> str | None: return self._name @property - def create_time(self) -> Optional[datetime]: + def create_time(self) -> datetime | None: return self._create_time @property - def last_access_time(self) -> Optional[datetime]: + 
def last_access_time(self) -> datetime | None: return self._last_access_time @property - def table_type(self) -> Optional[str]: + def table_type(self) -> str | None: return self._table_type @property - def columns(self) -> List[AthenaTableMetadataColumn]: + def columns(self) -> list[AthenaTableMetadataColumn]: return self._columns @property - def partition_keys(self) -> List[AthenaTableMetadataPartitionKey]: + def partition_keys(self) -> list[AthenaTableMetadataPartitionKey]: return self._partition_keys @property - def parameters(self) -> Dict[str, str]: + def parameters(self) -> dict[str, str]: return self._parameters @property - def comment(self) -> Optional[str]: + def comment(self) -> str | None: return self._parameters.get("comment") @property - def location(self) -> Optional[str]: + def location(self) -> str | None: return self._parameters.get("location") @property - def input_format(self) -> Optional[str]: + def input_format(self) -> str | None: return self._parameters.get("inputformat") @property - def output_format(self) -> Optional[str]: + def output_format(self) -> str | None: return self._parameters.get("outputformat") @property - def row_format(self) -> Optional[str]: + def row_format(self) -> str | None: serde = self.serde_serialization_lib if serde: return f"SERDE '{serde}'" return None @property - def file_format(self) -> Optional[str]: + def file_format(self) -> str | None: input = self.input_format output = self.output_format if input and output: @@ -674,11 +674,11 @@ def file_format(self) -> Optional[str]: return None @property - def serde_serialization_lib(self) -> Optional[str]: + def serde_serialization_lib(self) -> str | None: return self._parameters.get("serde.serialization.lib") @property - def compression(self) -> Optional[str]: + def compression(self) -> str | None: if "write.compression" in self._parameters: # text or json return self._parameters["write.compression"] if "serde.param.write.compression" in self._parameters: # text or json 
@@ -690,7 +690,7 @@ def compression(self) -> Optional[str]: return None @property - def serde_properties(self) -> Dict[str, str]: + def serde_properties(self) -> dict[str, str]: return { k.replace("serde.param.", ""): v for k, v in self._parameters.items() @@ -698,7 +698,7 @@ def serde_properties(self) -> Dict[str, str]: } @property - def table_properties(self) -> Dict[str, str]: + def table_properties(self) -> dict[str, str]: return {k: v for k, v in self._parameters.items() if not k.startswith("serde.param.")} diff --git a/pyathena/pandas/__init__.py b/pyathena/pandas/__init__.py index 20a12efb..d2a6d61c 100644 --- a/pyathena/pandas/__init__.py +++ b/pyathena/pandas/__init__.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import fsspec fsspec.register_implementation("s3", "pyathena.filesystem.s3.S3FileSystem", clobber=True) diff --git a/pyathena/pandas/async_cursor.py b/pyathena/pandas/async_cursor.py index 5df82c1c..48d3d217 100644 --- a/pyathena/pandas/async_cursor.py +++ b/pyathena/pandas/async_cursor.py @@ -1,10 +1,10 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import logging +from collections.abc import Iterable from concurrent.futures import Future from multiprocessing import cpu_count -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, cast +from typing import Any, cast from pyathena import ProgrammingError from pyathena.async_cursor import AsyncCursor @@ -16,7 +16,7 @@ ) from pyathena.pandas.result_set import AthenaPandasResultSet -_logger = logging.getLogger(__name__) # type: ignore +_logger = logging.getLogger(__name__) class AsyncPandasCursor(AsyncCursor): @@ -61,19 +61,19 @@ class AsyncPandasCursor(AsyncCursor): def __init__( self, - s3_staging_dir: Optional[str] = None, - schema_name: Optional[str] = None, - catalog_name: Optional[str] = None, - work_group: Optional[str] = None, + s3_staging_dir: str | None = None, + schema_name: str | None = None, + catalog_name: str | None = None, + work_group: str | None = None, 
poll_interval: float = 1, - encryption_option: Optional[str] = None, - kms_key: Optional[str] = None, + encryption_option: str | None = None, + kms_key: str | None = None, kill_on_interrupt: bool = True, max_workers: int = (cpu_count() or 1) * 5, arraysize: int = CursorIterator.DEFAULT_FETCH_SIZE, unload: bool = False, engine: str = "auto", - chunksize: Optional[int] = None, + chunksize: int | None = None, result_reuse_enable: bool = False, result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES, **kwargs, @@ -100,7 +100,7 @@ def __init__( @staticmethod def get_default_converter( unload: bool = False, - ) -> Union[DefaultPandasTypeConverter, Any]: + ) -> DefaultPandasTypeConverter | Any: if unload: return DefaultPandasUnloadTypeConverter() return DefaultPandasTypeConverter() @@ -119,10 +119,10 @@ def _collect_result_set( self, query_id: str, keep_default_na: bool = False, - na_values: Optional[Iterable[str]] = ("",), + na_values: Iterable[str] | None = ("",), quoting: int = 1, - unload_location: Optional[str] = None, - kwargs: Optional[Dict[str, Any]] = None, + unload_location: str | None = None, + kwargs: dict[str, Any] | None = None, ) -> AthenaPandasResultSet: if kwargs is None: kwargs = {} @@ -146,19 +146,19 @@ def _collect_result_set( def execute( self, operation: str, - parameters: Optional[Union[Dict[str, Any], List[str]]] = None, - work_group: Optional[str] = None, - s3_staging_dir: Optional[str] = None, - cache_size: Optional[int] = 0, - cache_expiration_time: Optional[int] = 0, - result_reuse_enable: Optional[bool] = None, - result_reuse_minutes: Optional[int] = None, - paramstyle: Optional[str] = None, + parameters: dict[str, Any] | list[str] | None = None, + work_group: str | None = None, + s3_staging_dir: str | None = None, + cache_size: int | None = 0, + cache_expiration_time: int | None = 0, + result_reuse_enable: bool | None = None, + result_reuse_minutes: int | None = None, + paramstyle: str | None = None, keep_default_na: bool = 
False, - na_values: Optional[Iterable[str]] = ("",), + na_values: Iterable[str] | None = ("",), quoting: int = 1, **kwargs, - ) -> Tuple[str, "Future[Union[AthenaPandasResultSet, Any]]"]: + ) -> tuple[str, Future[AthenaPandasResultSet | Any]]: operation, unload_location = self._prepare_unload(operation, s3_staging_dir) query_id = self._execute( operation, diff --git a/pyathena/pandas/converter.py b/pyathena/pandas/converter.py index cf438ace..576d2f53 100644 --- a/pyathena/pandas/converter.py +++ b/pyathena/pandas/converter.py @@ -1,9 +1,9 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import logging +from collections.abc import Callable from copy import deepcopy -from typing import Any, Callable, Dict, Optional, Type +from typing import Any from pyathena.converter import ( Converter, @@ -14,10 +14,10 @@ _to_json, ) -_logger = logging.getLogger(__name__) # type: ignore +_logger = logging.getLogger(__name__) -_DEFAULT_PANDAS_CONVERTERS: Dict[str, Callable[[Optional[str]], Optional[Any]]] = { +_DEFAULT_PANDAS_CONVERTERS: dict[str, Callable[[str | None], Any | None]] = { "boolean": _to_boolean, "decimal": _to_decimal, "varbinary": _to_binary, @@ -59,7 +59,7 @@ def __init__(self) -> None: ) @property - def _dtypes(self) -> Dict[str, Type[Any]]: + def _dtypes(self) -> dict[str, type[Any]]: if not hasattr(self, "__dtypes"): import pandas as pd @@ -80,7 +80,7 @@ def _dtypes(self) -> Dict[str, Type[Any]]: } return self.__dtypes - def convert(self, type_: str, value: Optional[str]) -> Optional[Any]: + def convert(self, type_: str, value: str | None) -> Any | None: pass @@ -103,5 +103,5 @@ def __init__(self) -> None: default=_to_default, ) - def convert(self, type_: str, value: Optional[str]) -> Optional[Any]: + def convert(self, type_: str, value: str | None) -> Any | None: pass diff --git a/pyathena/pandas/cursor.py b/pyathena/pandas/cursor.py index b87ecd50..39b1cf32 100644 --- a/pyathena/pandas/cursor.py +++ b/pyathena/pandas/cursor.py @@ -1,18 +1,11 @@ 
-# -*- coding: utf-8 -*- from __future__ import annotations import logging +from collections.abc import Callable, Generator, Iterable from multiprocessing import cpu_count from typing import ( TYPE_CHECKING, Any, - Callable, - Dict, - Generator, - Iterable, - List, - Optional, - Union, cast, ) @@ -29,7 +22,7 @@ if TYPE_CHECKING: from pandas import DataFrame -_logger = logging.getLogger(__name__) # type: ignore +_logger = logging.getLogger(__name__) class PandasCursor(WithFetch): @@ -69,24 +62,24 @@ class PandasCursor(WithFetch): def __init__( self, - s3_staging_dir: Optional[str] = None, - schema_name: Optional[str] = None, - catalog_name: Optional[str] = None, - work_group: Optional[str] = None, + s3_staging_dir: str | None = None, + schema_name: str | None = None, + catalog_name: str | None = None, + work_group: str | None = None, poll_interval: float = 1, - encryption_option: Optional[str] = None, - kms_key: Optional[str] = None, + encryption_option: str | None = None, + kms_key: str | None = None, kill_on_interrupt: bool = True, unload: bool = False, engine: str = "auto", - chunksize: Optional[int] = None, - block_size: Optional[int] = None, - cache_type: Optional[str] = None, + chunksize: int | None = None, + block_size: int | None = None, + cache_type: str | None = None, max_workers: int = (cpu_count() or 1) * 5, result_reuse_enable: bool = False, result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES, auto_optimize_chunksize: bool = False, - on_start_query_execution: Optional[Callable[[str], None]] = None, + on_start_query_execution: Callable[[str], None] | None = None, **kwargs, ) -> None: """Initialize PandasCursor with configuration options. 
@@ -140,7 +133,7 @@ def __init__( @staticmethod def get_default_converter( unload: bool = False, - ) -> Union[DefaultPandasTypeConverter, Any]: + ) -> DefaultPandasTypeConverter | Any: if unload: return DefaultPandasUnloadTypeConverter() return DefaultPandasTypeConverter() @@ -148,18 +141,18 @@ def get_default_converter( def execute( self, operation: str, - parameters: Optional[Union[Dict[str, Any], List[str]]] = None, - work_group: Optional[str] = None, - s3_staging_dir: Optional[str] = None, - cache_size: Optional[int] = 0, - cache_expiration_time: Optional[int] = 0, - result_reuse_enable: Optional[bool] = None, - result_reuse_minutes: Optional[int] = None, - paramstyle: Optional[str] = None, + parameters: dict[str, Any] | list[str] | None = None, + work_group: str | None = None, + s3_staging_dir: str | None = None, + cache_size: int | None = 0, + cache_expiration_time: int | None = 0, + result_reuse_enable: bool | None = None, + result_reuse_minutes: int | None = None, + paramstyle: str | None = None, keep_default_na: bool = False, - na_values: Optional[Iterable[str]] = ("",), + na_values: Iterable[str] | None = ("",), quoting: int = 1, - on_start_query_execution: Optional[Callable[[str], None]] = None, + on_start_query_execution: Callable[[str], None] | None = None, **kwargs, ) -> PandasCursor: """Execute a SQL query and return results as pandas DataFrames. @@ -238,7 +231,7 @@ def execute( return self - def as_pandas(self) -> Union["DataFrame", PandasDataFrameIterator]: + def as_pandas(self) -> DataFrame | PandasDataFrameIterator: """Return DataFrame or PandasDataFrameIterator based on chunksize setting. 
Returns: @@ -249,7 +242,7 @@ def as_pandas(self) -> Union["DataFrame", PandasDataFrameIterator]: result_set = cast(AthenaPandasResultSet, self.result_set) return result_set.as_pandas() - def iter_chunks(self) -> Generator["DataFrame", None, None]: + def iter_chunks(self) -> Generator[DataFrame, None, None]: """Iterate over DataFrame chunks for memory-efficient processing. This method provides an iterator interface for processing large result sets diff --git a/pyathena/pandas/result_set.py b/pyathena/pandas/result_set.py index d8e3cdf7..1486b45d 100644 --- a/pyathena/pandas/result_set.py +++ b/pyathena/pandas/result_set.py @@ -1,21 +1,13 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import logging from collections import abc +from collections.abc import Callable, Iterable, Iterator from multiprocessing import cpu_count from typing import ( TYPE_CHECKING, Any, - Callable, - Dict, - Iterable, - Iterator, - List, - Optional, - Tuple, - Type, - Union, + ClassVar, ) from pyathena import OperationalError @@ -31,14 +23,14 @@ from pyathena.connection import Connection -_logger = logging.getLogger(__name__) # type: ignore +_logger = logging.getLogger(__name__) -def _no_trunc_date(df: "DataFrame") -> "DataFrame": +def _no_trunc_date(df: DataFrame) -> DataFrame: return df -class PandasDataFrameIterator(abc.Iterator): # type: ignore +class PandasDataFrameIterator(abc.Iterator): # type: ignore[type-arg] """Iterator for chunked DataFrame results from Athena queries. This class wraps either a pandas TextFileReader (for chunked reading) or @@ -65,8 +57,8 @@ class PandasDataFrameIterator(abc.Iterator): # type: ignore def __init__( self, - reader: Union["TextFileReader", "DataFrame"], - trunc_date: Callable[["DataFrame"], "DataFrame"], + reader: TextFileReader | DataFrame, + trunc_date: Callable[[DataFrame], DataFrame], ) -> None: """Initialize the iterator. 
@@ -82,7 +74,7 @@ def __init__( self._reader = reader self._trunc_date = trunc_date - def __next__(self) -> "DataFrame": + def __next__(self) -> DataFrame: """Get the next DataFrame chunk. Returns: @@ -98,11 +90,11 @@ def __next__(self) -> "DataFrame": self.close() raise - def __iter__(self) -> "PandasDataFrameIterator": + def __iter__(self) -> PandasDataFrameIterator: """Return self as iterator.""" return self - def __enter__(self) -> "PandasDataFrameIterator": + def __enter__(self) -> PandasDataFrameIterator: """Context manager entry.""" return self @@ -117,7 +109,7 @@ def close(self) -> None: if isinstance(self._reader, TextFileReader): self._reader.close() - def iterrows(self) -> Iterator[Tuple[int, Dict[str, Any]]]: + def iterrows(self) -> Iterator[tuple[int, dict[str, Any]]]: """Iterate over rows as (index, row_dict) tuples. Row indices are continuous across all chunks, starting from 0. @@ -134,7 +126,7 @@ def iterrows(self) -> Iterator[Tuple[int, Dict[str, Any]]]: yield (row_num, dict(zip(columns, row, strict=True))) row_num += 1 - def get_chunk(self, size: Optional[int] = None) -> "DataFrame": + def get_chunk(self, size: int | None = None) -> DataFrame: """Get a chunk of specified size. Args: @@ -149,7 +141,7 @@ def get_chunk(self, size: Optional[int] = None) -> "DataFrame": return self._reader.get_chunk(size) return next(self._reader) - def as_pandas(self) -> "DataFrame": + def as_pandas(self) -> DataFrame: """Collect all chunks into a single DataFrame. 
Returns: @@ -157,7 +149,7 @@ def as_pandas(self) -> "DataFrame": """ import pandas as pd - dfs: List["DataFrame"] = list(self) + dfs: list[DataFrame] = list(self) if not dfs: return pd.DataFrame() if len(dfs) == 1: @@ -211,7 +203,7 @@ class AthenaPandasResultSet(AthenaResultSet): AUTO_CHUNK_SIZE_LARGE: int = 100_000 AUTO_CHUNK_SIZE_MEDIUM: int = 50_000 - _PARSE_DATES: List[str] = [ + _PARSE_DATES: ClassVar[list[str]] = [ "date", "time", "time with time zone", @@ -221,20 +213,20 @@ class AthenaPandasResultSet(AthenaResultSet): def __init__( self, - connection: "Connection[Any]", + connection: Connection[Any], converter: Converter, query_execution: AthenaQueryExecution, arraysize: int, retry_config: RetryConfig, keep_default_na: bool = False, - na_values: Optional[Iterable[str]] = ("",), + na_values: Iterable[str] | None = ("",), quoting: int = 1, unload: bool = False, - unload_location: Optional[str] = None, + unload_location: str | None = None, engine: str = "auto", - chunksize: Optional[int] = None, - block_size: Optional[int] = None, - cache_type: Optional[str] = None, + chunksize: int | None = None, + block_size: int | None = None, + cache_type: str | None = None, max_workers: int = (cpu_count() or 1) * 5, auto_optimize_chunksize: bool = False, **kwargs, @@ -282,13 +274,13 @@ def __init__( self._cache_type = cache_type self._max_workers = max_workers self._auto_optimize_chunksize = auto_optimize_chunksize - self._data_manifest: List[str] = [] + self._data_manifest: list[str] = [] self._kwargs = kwargs self._fs = self.__s3_file_system() # Cache time column names for efficient _trunc_date processing description = self.description if self.description else [] - self._time_columns: List[str] = [ + self._time_columns: list[str] = [ d[0] for d in description if d[1] in ("time", "time with time zone") ] @@ -319,7 +311,7 @@ def _get_parquet_engine(self) -> str: return self._engine def _get_csv_engine( - self, file_size_bytes: Optional[int] = None, chunksize: 
Optional[int] = None + self, file_size_bytes: int | None = None, chunksize: int | None = None ) -> str: """Determine the appropriate CSV engine based on configuration and compatibility. @@ -341,7 +333,7 @@ def _get_csv_engine( # Auto-selection for "auto" or unknown engine values return self._get_optimal_csv_engine(file_size_bytes) - def _get_pyarrow_engine(self, file_size_bytes: Optional[int], chunksize: Optional[int]) -> str: + def _get_pyarrow_engine(self, file_size_bytes: int | None, chunksize: int | None) -> str: """Get PyArrow engine if compatible, otherwise return optimal engine.""" # Check parameter compatibility if chunksize is not None or self._quoting != 1 or self.converters: @@ -357,7 +349,7 @@ def _get_pyarrow_engine(self, file_size_bytes: Optional[int], chunksize: Optiona except ImportError: return self._get_optimal_csv_engine(file_size_bytes) - def _get_available_engine(self, engine_candidates: List[str]) -> str: + def _get_available_engine(self, engine_candidates: list[str]) -> str: """Get the first available engine from a list of candidates. Args: @@ -376,8 +368,8 @@ def _get_available_engine(self, engine_candidates: List[str]) -> str: try: module = importlib.import_module(engine) return module.__name__ - except ImportError as e: - error_msgs += f"\n - {str(e)}" + except ImportError as e: # noqa: PERF203 + error_msgs += f"\n - {e!s}" available_engines = ", ".join(f"'{e}'" for e in engine_candidates) raise ImportError( @@ -386,7 +378,7 @@ def _get_available_engine(self, engine_candidates: List[str]) -> str: f"{error_msgs}" ) - def _get_optimal_csv_engine(self, file_size_bytes: Optional[int] = None) -> str: + def _get_optimal_csv_engine(self, file_size_bytes: int | None = None) -> str: """Get the optimal CSV engine based on file size. 
Args: @@ -399,7 +391,7 @@ def _get_optimal_csv_engine(self, file_size_bytes: Optional[int] = None) -> str: return "python" return "c" - def _auto_determine_chunksize(self, file_size_bytes: int) -> Optional[int]: + def _auto_determine_chunksize(self, file_size_bytes: int) -> int | None: """Determine appropriate chunksize for large files based on file size. This method provides a simple file-size-based chunksize determination. @@ -435,7 +427,7 @@ def __s3_file_system(self): ) @property - def dtypes(self) -> Dict[str, Type[Any]]: + def dtypes(self) -> dict[str, type[Any]]: """Get pandas-compatible data types for result columns. Returns: @@ -452,18 +444,18 @@ def dtypes(self) -> Dict[str, Type[Any]]: @property def converters( self, - ) -> Dict[Optional[Any], Callable[[Optional[str]], Optional[Any]]]: + ) -> dict[Any | None, Callable[[str | None], Any | None]]: description = self.description if self.description else [] return { d[0]: self._converter.get(d[1]) for d in description if d[1] in self._converter.mappings } @property - def parse_dates(self) -> List[Optional[Any]]: + def parse_dates(self) -> list[Any | None]: description = self.description if self.description else [] return [d[0] for d in description if d[1] in self._PARSE_DATES] - def _trunc_date(self, df: "DataFrame") -> "DataFrame": + def _trunc_date(self, df: DataFrame) -> DataFrame: if self._time_columns: truncated = df.loc[:, self._time_columns].apply(lambda r: r.dt.time) for time_col in self._time_columns: @@ -472,7 +464,7 @@ def _trunc_date(self, df: "DataFrame") -> "DataFrame": def fetchone( self, - ) -> Optional[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + ) -> tuple[Any | None, ...] 
| dict[Any, Any | None] | None: try: row = next(self._iterrows) except StopIteration: @@ -483,8 +475,8 @@ def fetchone( return tuple([row[1][d[0]] for d in description]) def fetchmany( - self, size: Optional[int] = None - ) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + self, size: int | None = None + ) -> list[tuple[Any | None, ...] | dict[Any, Any | None]]: if not size or size <= 0: size = self._arraysize rows = [] @@ -498,7 +490,7 @@ def fetchmany( def fetchall( self, - ) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + ) -> list[tuple[Any | None, ...] | dict[Any, Any | None]]: rows = [] while True: row = self.fetchone() @@ -508,7 +500,7 @@ def fetchall( break return rows - def _read_csv(self) -> Union["TextFileReader", "DataFrame"]: + def _read_csv(self) -> TextFileReader | DataFrame: import pandas as pd if not self.output_location: @@ -546,8 +538,9 @@ def _read_csv(self) -> Union["TextFileReader", "DataFrame"]: effective_chunksize = self._auto_determine_chunksize(length) if effective_chunksize: _logger.debug( - f"Auto-determined chunksize: {effective_chunksize} " - f"for file size: {length} bytes" + "Auto-determined chunksize: %s for file size: %s bytes", + effective_chunksize, + length, ) csv_engine = self._get_csv_engine(length, effective_chunksize) @@ -586,15 +579,17 @@ def _read_csv(self) -> Union["TextFileReader", "DataFrame"]: # Log performance information for large files if length > self.LARGE_FILE_THRESHOLD_BYTES: mode = "chunked" if effective_chunksize else "full" - _logger.info( - f"Reading {length} bytes from S3 in {mode} mode using {csv_engine} engine" - + (f" with chunksize={effective_chunksize}" if effective_chunksize else "") - ) + msg = "Reading %s bytes from S3 in %s mode using %s engine" + args: tuple[object, ...] 
= (length, mode, csv_engine) + if effective_chunksize: + msg += " with chunksize=%s" + args = (*args, effective_chunksize) + _logger.info(msg, *args) return result except Exception as e: - _logger.exception(f"Failed to read {self.output_location}.") + _logger.exception("Failed to read %s.", self.output_location) error_msg = str(e).lower() if any( phrase in error_msg @@ -617,7 +612,7 @@ def _read_csv(self) -> Union["TextFileReader", "DataFrame"]: raise OperationalError(detailed_msg) from e raise OperationalError(*e.args) from e - def _read_parquet(self, engine) -> "DataFrame": + def _read_parquet(self, engine) -> DataFrame: import pandas as pd self._data_manifest = self._read_data_manifest() @@ -648,10 +643,10 @@ def _read_parquet(self, engine) -> "DataFrame": **kwargs, ) except Exception as e: - _logger.exception(f"Failed to read {self.output_location}.") + _logger.exception("Failed to read %s.", self.output_location) raise OperationalError(*e.args) from e - def _read_parquet_schema(self, engine) -> Tuple[Dict[str, Any], ...]: + def _read_parquet_schema(self, engine) -> tuple[dict[str, Any], ...]: if engine == "pyarrow": from pyarrow import parquet @@ -664,12 +659,12 @@ def _read_parquet_schema(self, engine) -> Tuple[Dict[str, Any], ...]: dataset = parquet.ParquetDataset(f"{bucket}/{key}", filesystem=self._fs) return to_column_info(dataset.schema) except Exception as e: - _logger.exception(f"Failed to read schema {bucket}/{key}.") + _logger.exception("Failed to read schema %s/%s.", bucket, key) raise OperationalError(*e.args) from e else: raise ProgrammingError("Engine must be `pyarrow`.") - def _as_pandas(self) -> Union["TextFileReader", "DataFrame"]: + def _as_pandas(self) -> TextFileReader | DataFrame: if self.is_unload: engine = self._get_parquet_engine() df = self._read_parquet(engine) @@ -681,7 +676,7 @@ def _as_pandas(self) -> Union["TextFileReader", "DataFrame"]: df = self._read_csv() return df - def _as_pandas_from_api(self, converter: Optional[Converter] 
= None) -> "DataFrame": + def _as_pandas_from_api(self, converter: Converter | None = None) -> DataFrame: """Build a DataFrame from GetQueryResults API. Used as a fallback when ``output_location`` is not available @@ -700,7 +695,7 @@ def _as_pandas_from_api(self, converter: Optional[Converter] = None) -> "DataFra columns = [d[0] for d in description] return pd.DataFrame(self._rows_to_columnar(rows, columns)) - def as_pandas(self) -> Union[PandasDataFrameIterator, "DataFrame"]: + def as_pandas(self) -> PandasDataFrameIterator | DataFrame: if self._chunksize is None: return next(self._df_iter) return self._df_iter diff --git a/pyathena/pandas/util.py b/pyathena/pandas/util.py index 3eb756c5..25e3657a 100644 --- a/pyathena/pandas/util.py +++ b/pyathena/pandas/util.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import concurrent @@ -6,6 +5,7 @@ import textwrap import uuid from collections import OrderedDict +from collections.abc import Callable, Iterator from concurrent.futures.process import ProcessPoolExecutor from concurrent.futures.thread import ThreadPoolExecutor from copy import deepcopy @@ -13,13 +13,6 @@ from typing import ( TYPE_CHECKING, Any, - Callable, - Dict, - Iterator, - List, - Optional, - Type, - Union, ) from boto3 import Session @@ -34,10 +27,10 @@ from pyathena.connection import Connection from pyathena.cursor import Cursor -_logger = logging.getLogger(__name__) # type: ignore +_logger = logging.getLogger(__name__) -def get_chunks(df: "DataFrame", chunksize: Optional[int] = None) -> Iterator["DataFrame"]: +def get_chunks(df: DataFrame, chunksize: int | None = None) -> Iterator[DataFrame]: """Split a DataFrame into chunks of specified size. 
Args: @@ -67,7 +60,7 @@ def get_chunks(df: "DataFrame", chunksize: Optional[int] = None) -> Iterator["Da yield df[start_i:end_i] -def reset_index(df: "DataFrame", index_label: Optional[str] = None) -> None: +def reset_index(df: DataFrame, index_label: str | None = None) -> None: """Reset the DataFrame index and add it as a column. Args: @@ -84,7 +77,7 @@ def reset_index(df: "DataFrame", index_label: Optional[str] = None) -> None: raise ValueError("Duplicate name in index/columns") from e -def as_pandas(cursor: "Cursor", coerce_float: bool = False) -> "DataFrame": +def as_pandas(cursor: Cursor, coerce_float: bool = False) -> DataFrame: """Convert cursor results to a pandas DataFrame. Fetches all remaining rows from the cursor and converts them to a @@ -107,7 +100,7 @@ def as_pandas(cursor: "Cursor", coerce_float: bool = False) -> "DataFrame": return DataFrame.from_records(cursor.fetchall(), columns=names, coerce_float=coerce_float) -def to_sql_type_mappings(col: "Series") -> str: +def to_sql_type_mappings(col: Series) -> str: """Map a pandas Series data type to an Athena SQL type. Infers the appropriate Athena SQL type based on the pandas Series dtype. @@ -151,13 +144,13 @@ def to_sql_type_mappings(col: "Series") -> str: def to_parquet( - df: "DataFrame", + df: DataFrame, bucket_name: str, prefix: str, retry_config: RetryConfig, - session_kwargs: Dict[str, Any], - client_kwargs: Dict[str, Any], - compression: Optional[str] = None, + session_kwargs: dict[str, Any], + client_kwargs: dict[str, Any], + compression: str | None = None, flavor: str = "spark", ) -> str: """Write a DataFrame to S3 as a Parquet file. 
@@ -197,20 +190,20 @@ def to_parquet( def to_sql( - df: "DataFrame", + df: DataFrame, name: str, - conn: "Connection[Any]", + conn: Connection[Any], location: str, schema: str = "default", index: bool = False, - index_label: Optional[str] = None, - partitions: Optional[List[str]] = None, - chunksize: Optional[int] = None, + index_label: str | None = None, + partitions: list[str] | None = None, + chunksize: int | None = None, if_exists: str = "fail", - compression: Optional[str] = None, + compression: str | None = None, flavor: str = "spark", - type_mappings: Callable[["Series"], str] = to_sql_type_mappings, - executor_class: Type[Union[ThreadPoolExecutor, ProcessPoolExecutor]] = ThreadPoolExecutor, + type_mappings: Callable[[Series], str] = to_sql_type_mappings, + executor_class: type[ThreadPoolExecutor | ProcessPoolExecutor] = ThreadPoolExecutor, max_workers: int = (cpu_count() or 1) * 5, repair_table=True, ) -> None: @@ -296,7 +289,7 @@ def to_sql( if index: reset_index(df, index_label) with executor_class(max_workers=max_workers) as e: - futures = [] + futures: list[concurrent.futures.Future[Any]] = [] session_kwargs = deepcopy(conn._session_kwargs) session_kwargs.update({"profile_name": conn.profile_name}) client_kwargs = deepcopy(conn._client_kwargs) @@ -318,38 +311,38 @@ def to_sql( f"{location}{partition_prefix}/", ) ) - for chunk in get_chunks(group, chunksize): - futures.append( - e.submit( - to_parquet, - chunk, - bucket_name, - f"{key_prefix}{partition_prefix}/", - conn._retry_config, - session_kwargs, - client_kwargs, - compression, - flavor, - ) - ) - else: - for chunk in get_chunks(df, chunksize): - futures.append( + futures.extend( e.submit( to_parquet, chunk, bucket_name, - key_prefix, + f"{key_prefix}{partition_prefix}/", conn._retry_config, session_kwargs, client_kwargs, compression, flavor, ) + for chunk in get_chunks(group, chunksize) ) + else: + futures.extend( + e.submit( + to_parquet, + chunk, + bucket_name, + key_prefix, + 
conn._retry_config, + session_kwargs, + client_kwargs, + compression, + flavor, + ) + for chunk in get_chunks(df, chunksize) + ) for future in concurrent.futures.as_completed(futures): result = future.result() - _logger.info(f"to_parquet: {result}") + _logger.info("to_parquet: %s", result) ddl = generate_ddl( df=df, @@ -374,7 +367,7 @@ def to_sql( cursor.execute(add_partition) -def get_column_names_and_types(df: "DataFrame", type_mappings) -> "OrderedDict[str, str]": +def get_column_names_and_types(df: DataFrame, type_mappings) -> OrderedDict[str, str]: """Extract column names and their SQL types from a DataFrame. Args: @@ -385,18 +378,18 @@ def get_column_names_and_types(df: "DataFrame", type_mappings) -> "OrderedDict[s An OrderedDict mapping column names to their SQL type strings. """ return OrderedDict( - ((str(df.columns[i]), type_mappings(df.iloc[:, i])) for i in range(len(df.columns))) + (str(df.columns[i]), type_mappings(df.iloc[:, i])) for i in range(len(df.columns)) ) def generate_ddl( - df: "DataFrame", + df: DataFrame, name: str, location: str, schema: str = "default", - partitions: Optional[List[str]] = None, - compression: Optional[str] = None, - type_mappings: Callable[["Series"], str] = to_sql_type_mappings, + partitions: list[str] | None = None, + compression: str | None = None, + type_mappings: Callable[[Series], str] = to_sql_type_mappings, ) -> str: """Generate CREATE EXTERNAL TABLE DDL for a DataFrame. 
diff --git a/pyathena/polars/__init__.py b/pyathena/polars/__init__.py index 20a12efb..d2a6d61c 100644 --- a/pyathena/polars/__init__.py +++ b/pyathena/polars/__init__.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import fsspec fsspec.register_implementation("s3", "pyathena.filesystem.s3.S3FileSystem", clobber=True) diff --git a/pyathena/polars/async_cursor.py b/pyathena/polars/async_cursor.py index 81f3625a..9349f61a 100644 --- a/pyathena/polars/async_cursor.py +++ b/pyathena/polars/async_cursor.py @@ -1,10 +1,9 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import logging from concurrent.futures import Future from multiprocessing import cpu_count -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from typing import Any, cast from pyathena import ProgrammingError from pyathena.async_cursor import AsyncCursor @@ -58,22 +57,22 @@ class AsyncPolarsCursor(AsyncCursor): def __init__( self, - s3_staging_dir: Optional[str] = None, - schema_name: Optional[str] = None, - catalog_name: Optional[str] = None, - work_group: Optional[str] = None, + s3_staging_dir: str | None = None, + schema_name: str | None = None, + catalog_name: str | None = None, + work_group: str | None = None, poll_interval: float = 1, - encryption_option: Optional[str] = None, - kms_key: Optional[str] = None, + encryption_option: str | None = None, + kms_key: str | None = None, kill_on_interrupt: bool = True, max_workers: int = (cpu_count() or 1) * 5, arraysize: int = CursorIterator.DEFAULT_FETCH_SIZE, unload: bool = False, result_reuse_enable: bool = False, result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES, - block_size: Optional[int] = None, - cache_type: Optional[str] = None, - chunksize: Optional[int] = None, + block_size: int | None = None, + cache_type: str | None = None, + chunksize: int | None = None, **kwargs, ) -> None: """Initialize an AsyncPolarsCursor. 
@@ -127,7 +126,7 @@ def __init__( @staticmethod def get_default_converter( unload: bool = False, - ) -> Union[DefaultPolarsTypeConverter, DefaultPolarsUnloadTypeConverter, Any]: + ) -> DefaultPolarsTypeConverter | DefaultPolarsUnloadTypeConverter | Any: """Get the default type converter for Polars results. Args: @@ -162,8 +161,8 @@ def arraysize(self, value: int) -> None: def _collect_result_set( self, query_id: str, - unload_location: Optional[str] = None, - kwargs: Optional[Dict[str, Any]] = None, + unload_location: str | None = None, + kwargs: dict[str, Any] | None = None, ) -> AthenaPolarsResultSet: if kwargs is None: kwargs = {} @@ -186,16 +185,16 @@ def _collect_result_set( def execute( self, operation: str, - parameters: Optional[Union[Dict[str, Any], List[str]]] = None, - work_group: Optional[str] = None, - s3_staging_dir: Optional[str] = None, - cache_size: Optional[int] = 0, - cache_expiration_time: Optional[int] = 0, - result_reuse_enable: Optional[bool] = None, - result_reuse_minutes: Optional[int] = None, - paramstyle: Optional[str] = None, + parameters: dict[str, Any] | list[str] | None = None, + work_group: str | None = None, + s3_staging_dir: str | None = None, + cache_size: int | None = 0, + cache_expiration_time: int | None = 0, + result_reuse_enable: bool | None = None, + result_reuse_minutes: int | None = None, + paramstyle: str | None = None, **kwargs, - ) -> Tuple[str, "Future[Union[AthenaPolarsResultSet, Any]]"]: + ) -> tuple[str, Future[AthenaPolarsResultSet | Any]]: """Execute a SQL query asynchronously and return results as Polars DataFrames. 
Executes the SQL query on Amazon Athena asynchronously and returns a diff --git a/pyathena/polars/converter.py b/pyathena/polars/converter.py index be3f7cdd..627a4ffd 100644 --- a/pyathena/polars/converter.py +++ b/pyathena/polars/converter.py @@ -1,9 +1,9 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import logging +from collections.abc import Callable from copy import deepcopy -from typing import Any, Callable, Dict, Optional +from typing import Any from pyathena.converter import ( Converter, @@ -17,7 +17,7 @@ _logger = logging.getLogger(__name__) -_DEFAULT_POLARS_CONVERTERS: Dict[str, Callable[[Optional[str]], Optional[Any]]] = { +_DEFAULT_POLARS_CONVERTERS: dict[str, Callable[[str | None], Any | None]] = { "date": _to_date, "time": _to_time, "varbinary": _to_binary, @@ -58,7 +58,7 @@ def __init__(self) -> None: ) @property - def _dtypes(self) -> Dict[str, Any]: + def _dtypes(self) -> dict[str, Any]: import polars as pl if not hasattr(self, "__dtypes"): @@ -103,7 +103,7 @@ def get_dtype(self, type_: str, precision: int = 0, scale: int = 0) -> Any: return pl.Decimal(precision=precision, scale=scale) return self._types.get(type_) - def convert(self, type_: str, value: Optional[str]) -> Optional[Any]: + def convert(self, type_: str, value: str | None) -> Any | None: converter = self.get(type_) return converter(value) @@ -127,5 +127,5 @@ def __init__(self) -> None: default=_to_default, ) - def convert(self, type_: str, value: Optional[str]) -> Optional[Any]: + def convert(self, type_: str, value: str | None) -> Any | None: pass diff --git a/pyathena/polars/cursor.py b/pyathena/polars/cursor.py index f3cb342a..b788feac 100644 --- a/pyathena/polars/cursor.py +++ b/pyathena/polars/cursor.py @@ -1,17 +1,11 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import logging +from collections.abc import Callable, Iterator from multiprocessing import cpu_count from typing import ( TYPE_CHECKING, Any, - Callable, - Dict, - Iterator, - List, - 
Optional, - Union, cast, ) @@ -69,22 +63,22 @@ class PolarsCursor(WithFetch): def __init__( self, - s3_staging_dir: Optional[str] = None, - schema_name: Optional[str] = None, - catalog_name: Optional[str] = None, - work_group: Optional[str] = None, + s3_staging_dir: str | None = None, + schema_name: str | None = None, + catalog_name: str | None = None, + work_group: str | None = None, poll_interval: float = 1, - encryption_option: Optional[str] = None, - kms_key: Optional[str] = None, + encryption_option: str | None = None, + kms_key: str | None = None, kill_on_interrupt: bool = True, unload: bool = False, result_reuse_enable: bool = False, result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES, - on_start_query_execution: Optional[Callable[[str], None]] = None, - block_size: Optional[int] = None, - cache_type: Optional[str] = None, + on_start_query_execution: Callable[[str], None] | None = None, + block_size: int | None = None, + cache_type: str | None = None, max_workers: int = (cpu_count() or 1) * 5, - chunksize: Optional[int] = None, + chunksize: int | None = None, **kwargs, ) -> None: """Initialize a PolarsCursor. @@ -138,7 +132,7 @@ def __init__( @staticmethod def get_default_converter( unload: bool = False, - ) -> Union[DefaultPolarsTypeConverter, DefaultPolarsUnloadTypeConverter, Any]: + ) -> DefaultPolarsTypeConverter | DefaultPolarsUnloadTypeConverter | Any: """Get the default type converter for Polars results. 
Args: @@ -154,17 +148,17 @@ def get_default_converter( def execute( self, operation: str, - parameters: Optional[Union[Dict[str, Any], List[str]]] = None, - work_group: Optional[str] = None, - s3_staging_dir: Optional[str] = None, - cache_size: Optional[int] = 0, - cache_expiration_time: Optional[int] = 0, - result_reuse_enable: Optional[bool] = None, - result_reuse_minutes: Optional[int] = None, - paramstyle: Optional[str] = None, - on_start_query_execution: Optional[Callable[[str], None]] = None, + parameters: dict[str, Any] | list[str] | None = None, + work_group: str | None = None, + s3_staging_dir: str | None = None, + cache_size: int | None = 0, + cache_expiration_time: int | None = 0, + result_reuse_enable: bool | None = None, + result_reuse_minutes: int | None = None, + paramstyle: str | None = None, + on_start_query_execution: Callable[[str], None] | None = None, **kwargs, - ) -> "PolarsCursor": + ) -> PolarsCursor: """Execute a SQL query and return results as Polars DataFrames. Executes the SQL query on Amazon Athena and configures the result set @@ -230,7 +224,7 @@ def execute( raise OperationalError(query_execution.state_change_reason) return self - def as_polars(self) -> "pl.DataFrame": + def as_polars(self) -> pl.DataFrame: """Return query results as a Polars DataFrame. Returns the query results as a Polars DataFrame. This is the primary @@ -254,7 +248,7 @@ def as_polars(self) -> "pl.DataFrame": result_set = cast(AthenaPolarsResultSet, self.result_set) return result_set.as_polars() - def as_arrow(self) -> "Table": + def as_arrow(self) -> Table: """Return query results as an Apache Arrow Table. Converts the Polars DataFrame to an Apache Arrow Table for @@ -278,7 +272,7 @@ def as_arrow(self) -> "Table": result_set = cast(AthenaPolarsResultSet, self.result_set) return result_set.as_arrow() - def iter_chunks(self) -> Iterator["pl.DataFrame"]: + def iter_chunks(self) -> Iterator[pl.DataFrame]: """Iterate over result chunks as Polars DataFrames. 
This method provides an iterator interface for processing result sets. diff --git a/pyathena/polars/result_set.py b/pyathena/polars/result_set.py index 07f28845..6818b84e 100644 --- a/pyathena/polars/result_set.py +++ b/pyathena/polars/result_set.py @@ -1,19 +1,12 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import logging from collections import abc +from collections.abc import Callable, Iterator from multiprocessing import cpu_count from typing import ( TYPE_CHECKING, Any, - Callable, - Dict, - Iterator, - List, - Optional, - Tuple, - Union, cast, ) @@ -39,7 +32,7 @@ def _identity(x: Any) -> Any: return x -class PolarsDataFrameIterator(abc.Iterator): # type: ignore +class PolarsDataFrameIterator(abc.Iterator): # type: ignore[type-arg] """Iterator for chunked DataFrame results from Athena queries. This class wraps either a Polars DataFrame iterator (for chunked reading) or @@ -66,9 +59,9 @@ class PolarsDataFrameIterator(abc.Iterator): # type: ignore def __init__( self, - reader: Union[Iterator["pl.DataFrame"], "pl.DataFrame"], - converters: Dict[str, Callable[[Optional[str]], Optional[Any]]], - column_names: List[str], + reader: Iterator[pl.DataFrame] | pl.DataFrame, + converters: dict[str, Callable[[str | None], Any | None]], + column_names: list[str], ) -> None: """Initialize the iterator. @@ -80,13 +73,13 @@ def __init__( import polars as pl if isinstance(reader, pl.DataFrame): - self._reader: Iterator["pl.DataFrame"] = iter([reader]) + self._reader: Iterator[pl.DataFrame] = iter([reader]) else: self._reader = reader self._converters = converters self._column_names = column_names - def __next__(self) -> "pl.DataFrame": + def __next__(self) -> pl.DataFrame: """Get the next DataFrame chunk. 
Returns: @@ -101,11 +94,11 @@ def __next__(self) -> "pl.DataFrame": self.close() raise - def __iter__(self) -> "PolarsDataFrameIterator": + def __iter__(self) -> PolarsDataFrameIterator: """Return self as iterator.""" return self - def __enter__(self) -> "PolarsDataFrameIterator": + def __enter__(self) -> PolarsDataFrameIterator: """Context manager entry.""" return self @@ -120,7 +113,7 @@ def close(self) -> None: if isinstance(self._reader, GeneratorType): self._reader.close() - def iterrows(self) -> Iterator[Tuple[int, Dict[str, Any]]]: + def iterrows(self) -> Iterator[tuple[int, dict[str, Any]]]: """Iterate over rows as (index, row_dict) tuples. Yields: @@ -137,7 +130,7 @@ def iterrows(self) -> Iterator[Tuple[int, Dict[str, Any]]]: yield (row_num, processed_row) row_num += 1 - def as_polars(self) -> "pl.DataFrame": + def as_polars(self) -> pl.DataFrame: """Collect all chunks into a single DataFrame. Returns: @@ -145,7 +138,7 @@ def as_polars(self) -> "pl.DataFrame": """ import polars as pl - dfs = cast(List["pl.DataFrame"], list(self)) + dfs = cast(list["pl.DataFrame"], list(self)) if not dfs: return pl.DataFrame() if len(dfs) == 1: @@ -198,17 +191,17 @@ class AthenaPolarsResultSet(AthenaResultSet): def __init__( self, - connection: "Connection[Any]", + connection: Connection[Any], converter: Converter, query_execution: AthenaQueryExecution, arraysize: int, retry_config: RetryConfig, unload: bool = False, - unload_location: Optional[str] = None, - block_size: Optional[int] = None, - cache_type: Optional[str] = None, + unload_location: str | None = None, + block_size: int | None = None, + cache_type: str | None = None, max_workers: int = (cpu_count() or 1) * 5, - chunksize: Optional[int] = None, + chunksize: int | None = None, **kwargs, ) -> None: """Initialize the Polars result set. 
@@ -263,11 +256,11 @@ def __init__( # Cache column names for efficient access in fetchone() # Must be after _create_dataframe_iterator() which updates _metadata for unload - self._column_names_cache: List[str] = self._get_column_names() + self._column_names_cache: list[str] = self._get_column_names() self._iterrows = self._df_iter.iterrows() @property - def _csv_storage_options(self) -> Dict[str, Any]: + def _csv_storage_options(self) -> dict[str, Any]: """Get storage options for Polars CSV reading via fsspec. Polars read_csv uses fsspec for cloud storage access, which works @@ -284,7 +277,7 @@ def _csv_storage_options(self) -> Dict[str, Any]: } @property - def _parquet_storage_options(self) -> Dict[str, Any]: + def _parquet_storage_options(self) -> dict[str, Any]: """Get storage options for Polars Parquet reading via native object_store. Polars read_parquet uses Rust's native object_store crate, which requires @@ -294,7 +287,7 @@ def _parquet_storage_options(self) -> Dict[str, Any]: Dictionary with AWS credentials and region for S3 access. """ credentials = self.connection.session.get_credentials() - options: Dict[str, Any] = {} + options: dict[str, Any] = {} if credentials: frozen_credentials = credentials.get_frozen_credentials() options["aws_access_key_id"] = frozen_credentials.access_key @@ -306,7 +299,7 @@ def _parquet_storage_options(self) -> Dict[str, Any]: return options @property - def dtypes(self) -> Dict[str, Any]: + def dtypes(self) -> dict[str, Any]: """Get Polars-compatible data types for result columns.""" description = self.description if self.description else [] return { @@ -316,7 +309,7 @@ def dtypes(self) -> Dict[str, Any]: } @property - def converters(self) -> Dict[str, Callable[[Optional[str]], Optional[Any]]]: + def converters(self) -> dict[str, Callable[[str | None], Any | None]]: """Get converter functions for each column. 
Returns: @@ -325,7 +318,7 @@ def converters(self) -> Dict[str, Callable[[Optional[str]], Optional[Any]]]: description = self.description if self.description else [] return {d[0]: self._converter.get(d[1]) for d in description} - def _get_column_names(self) -> List[str]: + def _get_column_names(self) -> list[str]: """Get column names from description. Returns: @@ -342,7 +335,7 @@ def _create_dataframe_iterator(self) -> PolarsDataFrameIterator: """ if self._chunksize is not None: # Chunked mode: create lazy iterator - reader: Union[Iterator["pl.DataFrame"], "pl.DataFrame"] = ( + reader: Iterator[pl.DataFrame] | pl.DataFrame = ( self._iter_parquet_chunks() if self.is_unload else self._iter_csv_chunks() ) else: @@ -353,7 +346,7 @@ def _create_dataframe_iterator(self) -> PolarsDataFrameIterator: def fetchone( self, - ) -> Optional[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + ) -> tuple[Any | None, ...] | dict[Any, Any | None] | None: """Fetch the next row of the query result. Returns: @@ -368,8 +361,8 @@ def fetchone( return tuple([row[1][col] for col in self._column_names_cache]) def fetchmany( - self, size: Optional[int] = None - ) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + self, size: int | None = None + ) -> list[tuple[Any | None, ...] | dict[Any, Any | None]]: """Fetch the next set of rows of the query result. Args: @@ -391,7 +384,7 @@ def fetchmany( def fetchall( self, - ) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + ) -> list[tuple[Any | None, ...] | dict[Any, Any | None]]: """Fetch all remaining rows of the query result. Returns: @@ -442,7 +435,7 @@ def _prepare_parquet_location(self) -> bool: self._unload_location = "/".join(manifests[0].split("/")[:-1]) + "/" return True - def _read_csv(self) -> "pl.DataFrame": + def _read_csv(self) -> pl.DataFrame: """Read query results from CSV file in S3. 
Returns: @@ -475,10 +468,10 @@ def _read_csv(self) -> "pl.DataFrame": df.columns = new_columns return df except Exception as e: - _logger.exception(f"Failed to read {self.output_location}.") + _logger.exception("Failed to read %s.", self.output_location) raise OperationalError(*e.args) from e - def _read_parquet(self) -> "pl.DataFrame": + def _read_parquet(self) -> pl.DataFrame: """Read query results from Parquet files in S3. Returns: @@ -502,10 +495,10 @@ def _read_parquet(self) -> "pl.DataFrame": **self._kwargs, ) except Exception as e: - _logger.exception(f"Failed to read {self._unload_location}.") + _logger.exception("Failed to read %s.", self._unload_location) raise OperationalError(*e.args) from e - def _read_parquet_schema(self) -> Tuple[Dict[str, Any], ...]: + def _read_parquet_schema(self) -> tuple[dict[str, Any], ...]: """Read schema from Parquet files for metadata.""" import polars as pl @@ -521,10 +514,10 @@ def _read_parquet_schema(self) -> Tuple[Dict[str, Any], ...]: schema = lazy_df.collect_schema() return to_column_info(schema) except Exception as e: - _logger.exception(f"Failed to read schema from {self._unload_location}.") + _logger.exception("Failed to read schema from %s.", self._unload_location) raise OperationalError(*e.args) from e - def _as_polars(self) -> "pl.DataFrame": + def _as_polars(self) -> pl.DataFrame: """Load query results as a Polars DataFrame. Reads from Parquet for UNLOAD queries, otherwise from CSV. @@ -542,7 +535,7 @@ def _as_polars(self) -> "pl.DataFrame": df = self._read_csv() return df - def _as_polars_from_api(self, converter: Optional[Converter] = None) -> "pl.DataFrame": + def _as_polars_from_api(self, converter: Converter | None = None) -> pl.DataFrame: """Build a Polars DataFrame from GetQueryResults API. 
Used as a fallback when ``output_location`` is not available @@ -561,7 +554,7 @@ def _as_polars_from_api(self, converter: Optional[Converter] = None) -> "pl.Data columns = [d[0] for d in description] return pl.DataFrame(self._rows_to_columnar(rows, columns)) - def as_polars(self) -> "pl.DataFrame": + def as_polars(self) -> pl.DataFrame: """Return query results as a Polars DataFrame. Returns the query results as a Polars DataFrame. This is the primary @@ -584,7 +577,7 @@ def as_polars(self) -> "pl.DataFrame": """ return self._df_iter.as_polars() - def as_arrow(self) -> "Table": + def as_arrow(self) -> Table: """Return query results as an Apache Arrow Table. Converts the Polars DataFrame to an Apache Arrow Table for @@ -609,7 +602,7 @@ def as_arrow(self) -> "Table": "pyarrow is required for as_arrow(). Install it with: pip install pyarrow" ) from e - def _get_csv_params(self) -> Tuple[str, bool, Optional[List[str]]]: + def _get_csv_params(self) -> tuple[str, bool, list[str] | None]: """Get CSV parsing parameters based on file type. Returns: @@ -618,14 +611,14 @@ def _get_csv_params(self) -> Tuple[str, bool, Optional[List[str]]]: if self.output_location and self.output_location.endswith(".txt"): separator = "\t" has_header = False - new_columns: Optional[List[str]] = self._get_column_names() + new_columns: list[str] | None = self._get_column_names() else: separator = "," has_header = True new_columns = None return separator, has_header, new_columns - def _iter_csv_chunks(self) -> Iterator["pl.DataFrame"]: + def _iter_csv_chunks(self) -> Iterator[pl.DataFrame]: """Iterate over CSV data in chunks using lazy evaluation. 
Yields: @@ -661,10 +654,10 @@ def _iter_csv_chunks(self) -> Iterator["pl.DataFrame"]: batch.columns = new_columns yield batch except Exception as e: - _logger.exception(f"Failed to read {self.output_location}.") + _logger.exception("Failed to read %s.", self.output_location) raise OperationalError(*e.args) from e - def _iter_parquet_chunks(self) -> Iterator["pl.DataFrame"]: + def _iter_parquet_chunks(self) -> Iterator[pl.DataFrame]: """Iterate over Parquet data in chunks using lazy evaluation. Yields: @@ -687,10 +680,9 @@ def _iter_parquet_chunks(self) -> Iterator["pl.DataFrame"]: storage_options=self._parquet_storage_options, **self._kwargs, ) - for batch in lazy_df.collect_batches(chunk_size=self._chunksize): - yield batch + yield from lazy_df.collect_batches(chunk_size=self._chunksize) except Exception as e: - _logger.exception(f"Failed to read {self._unload_location}.") + _logger.exception("Failed to read %s.", self._unload_location) raise OperationalError(*e.args) from e def iter_chunks(self) -> PolarsDataFrameIterator: diff --git a/pyathena/polars/util.py b/pyathena/polars/util.py index afc64b6b..91edd06d 100644 --- a/pyathena/polars/util.py +++ b/pyathena/polars/util.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Utilities for converting Polars types to Athena metadata. This module provides functions to convert Polars schema and type information @@ -8,13 +7,13 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, Tuple +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: import polars as pl -def to_column_info(schema: "pl.Schema") -> Tuple[Dict[str, Any], ...]: +def to_column_info(schema: pl.Schema) -> tuple[dict[str, Any], ...]: """Convert a Polars schema to Athena column information. 
Iterates through all fields in the schema and converts each field's @@ -46,7 +45,7 @@ def to_column_info(schema: "pl.Schema") -> Tuple[Dict[str, Any], ...]: return tuple(columns) -def get_athena_type(dtype: Any) -> Tuple[str, int, int]: +def get_athena_type(dtype: Any) -> tuple[str, int, int]: """Map a Polars data type to an Athena SQL type. Converts Polars type identifiers to corresponding Athena SQL type names @@ -73,7 +72,7 @@ def get_athena_type(dtype: Any) -> Tuple[str, int, int]: base_dtype = dtype.base_type() if hasattr(dtype, "base_type") else dtype # Type mapping: Polars type -> (Athena type, precision, scale) - type_mapping: Dict[Any, Tuple[str, int, int]] = { + type_mapping: dict[Any, tuple[str, int, int]] = { pl.Boolean: ("boolean", 0, 0), pl.Int8: ("tinyint", 3, 0), pl.Int16: ("smallint", 5, 0), diff --git a/pyathena/result_set.py b/pyathena/result_set.py index 103a8a22..7f579d06 100644 --- a/pyathena/result_set.py +++ b/pyathena/result_set.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import collections @@ -8,13 +7,6 @@ from typing import ( TYPE_CHECKING, Any, - Deque, - Dict, - List, - Optional, - Tuple, - Type, - Union, cast, ) @@ -27,7 +19,7 @@ if TYPE_CHECKING: from pyathena.connection import Connection -_logger = logging.getLogger(__name__) # type: ignore +_logger = logging.getLogger(__name__) class AthenaResultSet(CursorIterator): @@ -63,7 +55,7 @@ class AthenaResultSet(CursorIterator): def __init__( self, - connection: "Connection[Any]", + connection: Connection[Any], converter: Converter, query_execution: AthenaQueryExecution, arraysize: int, @@ -71,9 +63,9 @@ def __init__( _pre_fetch: bool = True, ) -> None: super().__init__(arraysize=arraysize) - self._connection: Optional["Connection[Any]"] = connection + self._connection: Connection[Any] | None = connection self._converter = converter - self._query_execution: Optional[AthenaQueryExecution] = query_execution + self._query_execution: AthenaQueryExecution | 
None = query_execution if not self._query_execution: raise ProgrammingError("Required argument `query_execution` not found.") self._retry_config = retry_config @@ -84,11 +76,11 @@ def __init__( **connection._client_kwargs, ) - self._metadata: Optional[Tuple[Dict[str, Any], ...]] = None - self._rows: Deque[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]] = ( + self._metadata: tuple[dict[str, Any], ...] | None = None + self._rows: collections.deque[tuple[Any | None, ...] | dict[Any, Any | None]] = ( collections.deque() ) - self._next_token: Optional[str] = None + self._next_token: str | None = None if self.state == AthenaQueryExecution.STATE_SUCCEEDED: self._rownumber = 0 @@ -96,151 +88,151 @@ def __init__( self._pre_fetch() @property - def database(self) -> Optional[str]: + def database(self) -> str | None: if not self._query_execution: return None return self._query_execution.database @property - def catalog(self) -> Optional[str]: + def catalog(self) -> str | None: if not self._query_execution: return None return self._query_execution.catalog @property - def query_id(self) -> Optional[str]: + def query_id(self) -> str | None: if not self._query_execution: return None return self._query_execution.query_id @property - def query(self) -> Optional[str]: + def query(self) -> str | None: if not self._query_execution: return None return self._query_execution.query @property - def statement_type(self) -> Optional[str]: + def statement_type(self) -> str | None: if not self._query_execution: return None return self._query_execution.statement_type @property - def substatement_type(self) -> Optional[str]: + def substatement_type(self) -> str | None: if not self._query_execution: return None return self._query_execution.substatement_type @property - def work_group(self) -> Optional[str]: + def work_group(self) -> str | None: if not self._query_execution: return None return self._query_execution.work_group @property - def execution_parameters(self) -> List[str]: + 
def execution_parameters(self) -> list[str]: if not self._query_execution: return [] return self._query_execution.execution_parameters @property - def state(self) -> Optional[str]: + def state(self) -> str | None: if not self._query_execution: return None return self._query_execution.state @property - def state_change_reason(self) -> Optional[str]: + def state_change_reason(self) -> str | None: if not self._query_execution: return None return self._query_execution.state_change_reason @property - def submission_date_time(self) -> Optional[datetime]: + def submission_date_time(self) -> datetime | None: if not self._query_execution: return None return self._query_execution.submission_date_time @property - def completion_date_time(self) -> Optional[datetime]: + def completion_date_time(self) -> datetime | None: if not self._query_execution: return None return self._query_execution.completion_date_time @property - def error_category(self) -> Optional[int]: + def error_category(self) -> int | None: if not self._query_execution: return None return self._query_execution.error_category @property - def error_type(self) -> Optional[int]: + def error_type(self) -> int | None: if not self._query_execution: return None return self._query_execution.error_type @property - def retryable(self) -> Optional[bool]: + def retryable(self) -> bool | None: if not self._query_execution: return None return self._query_execution.retryable @property - def error_message(self) -> Optional[str]: + def error_message(self) -> str | None: if not self._query_execution: return None return self._query_execution.error_message @property - def data_scanned_in_bytes(self) -> Optional[int]: + def data_scanned_in_bytes(self) -> int | None: if not self._query_execution: return None return self._query_execution.data_scanned_in_bytes @property - def engine_execution_time_in_millis(self) -> Optional[int]: + def engine_execution_time_in_millis(self) -> int | None: if not self._query_execution: return None return 
self._query_execution.engine_execution_time_in_millis @property - def query_queue_time_in_millis(self) -> Optional[int]: + def query_queue_time_in_millis(self) -> int | None: if not self._query_execution: return None return self._query_execution.query_queue_time_in_millis @property - def total_execution_time_in_millis(self) -> Optional[int]: + def total_execution_time_in_millis(self) -> int | None: if not self._query_execution: return None return self._query_execution.total_execution_time_in_millis @property - def query_planning_time_in_millis(self) -> Optional[int]: + def query_planning_time_in_millis(self) -> int | None: if not self._query_execution: return None return self._query_execution.query_planning_time_in_millis @property - def service_processing_time_in_millis(self) -> Optional[int]: + def service_processing_time_in_millis(self) -> int | None: if not self._query_execution: return None return self._query_execution.service_processing_time_in_millis @property - def output_location(self) -> Optional[str]: + def output_location(self) -> str | None: if not self._query_execution: return None return self._query_execution.output_location @property - def data_manifest_location(self) -> Optional[str]: + def data_manifest_location(self) -> str | None: if not self._query_execution: return None return self._query_execution.data_manifest_location @property - def reused_previous_result(self) -> Optional[bool]: + def reused_previous_result(self) -> bool | None: if not self._query_execution: return None return self._query_execution.reused_previous_result @@ -259,49 +251,49 @@ def is_unload(self) -> bool: ) @property - def encryption_option(self) -> Optional[str]: + def encryption_option(self) -> str | None: if not self._query_execution: return None return self._query_execution.encryption_option @property - def kms_key(self) -> Optional[str]: + def kms_key(self) -> str | None: if not self._query_execution: return None return self._query_execution.kms_key @property - def 
expected_bucket_owner(self) -> Optional[str]: + def expected_bucket_owner(self) -> str | None: if not self._query_execution: return None return self._query_execution.expected_bucket_owner @property - def s3_acl_option(self) -> Optional[str]: + def s3_acl_option(self) -> str | None: if not self._query_execution: return None return self._query_execution.s3_acl_option @property - def selected_engine_version(self) -> Optional[str]: + def selected_engine_version(self) -> str | None: if not self._query_execution: return None return self._query_execution.selected_engine_version @property - def effective_engine_version(self) -> Optional[str]: + def effective_engine_version(self) -> str | None: if not self._query_execution: return None return self._query_execution.effective_engine_version @property - def result_reuse_enabled(self) -> Optional[bool]: + def result_reuse_enabled(self) -> bool | None: if not self._query_execution: return None return self._query_execution.result_reuse_enabled @property - def result_reuse_minutes(self) -> Optional[int]: + def result_reuse_minutes(self) -> int | None: if not self._query_execution: return None return self._query_execution.result_reuse_minutes @@ -309,7 +301,7 @@ def result_reuse_minutes(self) -> Optional[int]: @property def description( self, - ) -> Optional[List[Tuple[str, str, None, None, int, int, str]]]: + ) -> list[tuple[str, str, None, None, int, int, str]] | None: if self._metadata is None: return None return [ @@ -326,21 +318,21 @@ def description( ] @property - def connection(self) -> "Connection[Any]": + def connection(self) -> Connection[Any]: if self.is_closed: raise ProgrammingError("AthenaResultSet is closed.") return cast("Connection[Any]", self._connection) def __get_query_results( - self, max_results: int, next_token: Optional[str] = None - ) -> Dict[str, Any]: + self, max_results: int, next_token: str | None = None + ) -> dict[str, Any]: if not self.query_id: raise ProgrammingError("QueryExecutionId is none or 
empty.") if self.state != AthenaQueryExecution.STATE_SUCCEEDED: raise ProgrammingError("QueryExecutionState is not SUCCEEDED.") if self.is_closed: raise ProgrammingError("AthenaResultSet is closed.") - request: Dict[str, Any] = { + request: dict[str, Any] = { "QueryExecutionId": self.query_id, "MaxResults": max_results, } @@ -357,9 +349,9 @@ def __get_query_results( _logger.exception("Failed to fetch result set.") raise OperationalError(*e.args) from e else: - return cast(Dict[str, Any], response) + return cast(dict[str, Any], response) - def __fetch(self, next_token: Optional[str] = None) -> Dict[str, Any]: + def __fetch(self, next_token: str | None = None) -> dict[str, Any]: return self.__get_query_results(self._arraysize, next_token) def _fetch(self) -> None: @@ -379,7 +371,7 @@ def _pre_fetch(self) -> None: def fetchone( self, - ) -> Optional[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + ) -> tuple[Any | None, ...] | dict[Any, Any | None] | None: if not self._rows and self._next_token: self._fetch() if not self._rows: @@ -390,8 +382,8 @@ def fetchone( return self._rows.popleft() def fetchmany( - self, size: Optional[int] = None - ) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + self, size: int | None = None + ) -> list[tuple[Any | None, ...] | dict[Any, Any | None]]: if not size or size <= 0: size = self._arraysize rows = [] @@ -405,7 +397,7 @@ def fetchmany( def fetchall( self, - ) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + ) -> list[tuple[Any | None, ...] 
| dict[Any, Any | None]]: rows = [] while True: row = self.fetchone() @@ -415,7 +407,7 @@ def fetchall( break return rows - def _process_metadata(self, response: Dict[str, Any]) -> None: + def _process_metadata(self, response: dict[str, Any]) -> None: result_set = response.get("ResultSet") if not result_set: raise DataError("KeyError `ResultSet`") @@ -427,7 +419,7 @@ def _process_metadata(self, response: Dict[str, Any]) -> None: raise DataError("KeyError `ColumnInfo`") self._metadata = tuple(column_info) - def _process_update_count(self, response: Dict[str, Any]) -> None: + def _process_update_count(self, response: dict[str, Any]) -> None: update_count = response.get("UpdateCount") if ( update_count is not None @@ -446,10 +438,10 @@ def _process_update_count(self, response: Dict[str, Any]) -> None: def _get_rows( self, offset: int, - metadata: Tuple[Any, ...], - rows: List[Dict[str, Any]], - converter: Optional[Converter] = None, - ) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + metadata: tuple[Any, ...], + rows: list[dict[str, Any]], + converter: Converter | None = None, + ) -> list[tuple[Any | None, ...] | dict[Any, Any | None]]: conv = converter or self._converter return [ tuple( @@ -462,8 +454,8 @@ def _get_rows( ] def _parse_result_rows( - self, response: Dict[str, Any] - ) -> Tuple[List[Dict[str, Any]], Optional[str]]: + self, response: dict[str, Any] + ) -> tuple[list[dict[str, Any]], str | None]: """Parse a GetQueryResults response into raw rows and next token. Handles response validation and pagination token extraction. 
@@ -485,12 +477,12 @@ def _parse_result_rows( next_token = response.get("NextToken") return rows, next_token - def _process_rows(self, rows: List[Dict[str, Any]], offset: int = 0) -> None: + def _process_rows(self, rows: list[dict[str, Any]], offset: int = 0) -> None: if rows and self._metadata: processed_rows = self._get_rows(offset, self._metadata, rows) self._rows.extend(processed_rows) - def _is_first_row_column_labels(self, rows: List[Dict[str, Any]]) -> bool: + def _is_first_row_column_labels(self, rows: list[dict[str, Any]]) -> bool: first_row_data = rows[0].get("Data", []) for meta, data in zip(self._metadata or (), first_row_data, strict=False): if meta.get("Name") != data.get("VarCharValue"): @@ -499,8 +491,8 @@ def _is_first_row_column_labels(self, rows: List[Dict[str, Any]]) -> bool: def _fetch_all_rows( self, - converter: Optional[Converter] = None, - ) -> List[Tuple[Optional[Any], ...]]: + converter: Converter | None = None, + ) -> list[tuple[Any | None, ...]]: """Fetch all rows via GetQueryResults API with type conversion. Paginates through all results from the beginning using MaxResults=1000. 
@@ -529,8 +521,8 @@ def _fetch_all_rows( ) converter = converter or DefaultTypeConverter() - all_rows: List[Tuple[Optional[Any], ...]] = [] - next_token: Optional[str] = None + all_rows: list[tuple[Any | None, ...]] = [] + next_token: str | None = None while True: response = self.__get_query_results(self.DEFAULT_FETCH_SIZE, next_token) @@ -539,7 +531,7 @@ def _fetch_all_rows( offset = 1 if rows and self._is_first_row_column_labels(rows) else 0 all_rows.extend( cast( - List[Tuple[Optional[Any], ...]], + list[tuple[Any | None, ...]], self._get_rows(offset, self._metadata, rows, converter), ) ) @@ -551,9 +543,9 @@ def _fetch_all_rows( @staticmethod def _rows_to_columnar( - rows: List[Tuple[Optional[Any], ...]], - columns: List[str], - ) -> Dict[str, List[Any]]: + rows: list[tuple[Any | None, ...]], + columns: list[str], + ) -> dict[str, list[Any]]: """Convert row-oriented data to columnar format. Args: @@ -563,7 +555,7 @@ def _rows_to_columnar( Returns: Dictionary mapping column names to lists of values. 
""" - columnar: Dict[str, List[Any]] = {col: [] for col in columns} + columnar: dict[str, list[Any]] = {col: [] for col in columns} for row in rows: for col, val in zip(columns, row, strict=False): columnar[col].append(val) @@ -587,7 +579,7 @@ def _get_content_length(self) -> int: else: return cast(int, response["ContentLength"]) - def _read_data_manifest(self) -> List[str]: + def _read_data_manifest(self) -> list[str]: if not self.data_manifest_location: raise ProgrammingError("DataManifestLocation is none or empty.") bucket, key = parse_output_location(self.data_manifest_location) @@ -600,7 +592,7 @@ def _read_data_manifest(self) -> List[str]: Key=key, ) except Exception as e: - _logger.exception(f"Failed to read {bucket}/{key}.") + _logger.exception("Failed to read %s/%s.", bucket, key) raise OperationalError(*e.args) from e else: manifest: str = response["Body"].read().decode("utf-8").strip() @@ -628,15 +620,15 @@ def __exit__(self, exc_type, exc_val, exc_tb): class AthenaDictResultSet(AthenaResultSet): # You can override this to use OrderedDict or other dict-like types. - dict_type: Type[Any] = dict + dict_type: type[Any] = dict def _get_rows( self, offset: int, - metadata: Tuple[Any, ...], - rows: List[Dict[str, Any]], - converter: Optional[Converter] = None, - ) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + metadata: tuple[Any, ...], + rows: list[dict[str, Any]], + converter: Converter | None = None, + ) -> list[tuple[Any | None, ...] 
| dict[Any, Any | None]]: conv = converter or self._converter return [ self.dict_type( @@ -657,19 +649,19 @@ def __init__(self): super().__init__() def _reset_state(self) -> None: - self.query_id = None # type: ignore + self.query_id = None if self.result_set and not self.result_set.is_closed: self.result_set.close() - self.result_set = None # type: ignore + self.result_set = None - @property # type: ignore + @property @abstractmethod - def result_set(self) -> Optional[AthenaResultSet]: + def result_set(self) -> AthenaResultSet | None: raise NotImplementedError # pragma: no cover - @result_set.setter # type: ignore + @result_set.setter @abstractmethod - def result_set(self, val: Optional[AthenaResultSet]) -> None: + def result_set(self, val: AthenaResultSet | None) -> None: raise NotImplementedError # pragma: no cover @property @@ -679,209 +671,209 @@ def has_result_set(self) -> bool: @property def description( self, - ) -> Optional[List[Tuple[str, str, None, None, int, int, str]]]: + ) -> list[tuple[str, str, None, None, int, int, str]] | None: if not self.result_set: return None return self.result_set.description @property - def database(self) -> Optional[str]: + def database(self) -> str | None: if not self.result_set: return None return self.result_set.database @property - def catalog(self) -> Optional[str]: + def catalog(self) -> str | None: if not self.result_set: return None return self.result_set.catalog - @property # type: ignore + @property @abstractmethod - def query_id(self) -> Optional[str]: + def query_id(self) -> str | None: raise NotImplementedError # pragma: no cover - @query_id.setter # type: ignore + @query_id.setter @abstractmethod - def query_id(self, val: Optional[str]) -> None: + def query_id(self, val: str | None) -> None: raise NotImplementedError # pragma: no cover @property - def query(self) -> Optional[str]: + def query(self) -> str | None: if not self.result_set: return None return self.result_set.query @property - def 
statement_type(self) -> Optional[str]: + def statement_type(self) -> str | None: if not self.result_set: return None return self.result_set.statement_type @property - def substatement_type(self) -> Optional[str]: + def substatement_type(self) -> str | None: if not self.result_set: return None return self.result_set.substatement_type @property - def work_group(self) -> Optional[str]: + def work_group(self) -> str | None: if not self.result_set: return None return self.result_set.work_group @property - def execution_parameters(self) -> List[str]: + def execution_parameters(self) -> list[str]: if not self.result_set: return [] return self.result_set.execution_parameters @property - def state(self) -> Optional[str]: + def state(self) -> str | None: if not self.result_set: return None return self.result_set.state @property - def state_change_reason(self) -> Optional[str]: + def state_change_reason(self) -> str | None: if not self.result_set: return None return self.result_set.state_change_reason @property - def submission_date_time(self) -> Optional[datetime]: + def submission_date_time(self) -> datetime | None: if not self.result_set: return None return self.result_set.submission_date_time @property - def completion_date_time(self) -> Optional[datetime]: + def completion_date_time(self) -> datetime | None: if not self.result_set: return None return self.result_set.completion_date_time @property - def error_category(self) -> Optional[int]: + def error_category(self) -> int | None: if not self.result_set: return None return self.result_set.error_category @property - def error_type(self) -> Optional[int]: + def error_type(self) -> int | None: if not self.result_set: return None return self.result_set.error_type @property - def retryable(self) -> Optional[bool]: + def retryable(self) -> bool | None: if not self.result_set: return None return self.result_set.retryable @property - def error_message(self) -> Optional[str]: + def error_message(self) -> str | None: if not 
self.result_set: return None return self.result_set.error_message @property - def data_scanned_in_bytes(self) -> Optional[int]: + def data_scanned_in_bytes(self) -> int | None: if not self.result_set: return None return self.result_set.data_scanned_in_bytes @property - def engine_execution_time_in_millis(self) -> Optional[int]: + def engine_execution_time_in_millis(self) -> int | None: if not self.result_set: return None return self.result_set.engine_execution_time_in_millis @property - def query_queue_time_in_millis(self) -> Optional[int]: + def query_queue_time_in_millis(self) -> int | None: if not self.result_set: return None return self.result_set.query_queue_time_in_millis @property - def total_execution_time_in_millis(self) -> Optional[int]: + def total_execution_time_in_millis(self) -> int | None: if not self.result_set: return None return self.result_set.total_execution_time_in_millis @property - def query_planning_time_in_millis(self) -> Optional[int]: + def query_planning_time_in_millis(self) -> int | None: if not self.result_set: return None return self.result_set.query_planning_time_in_millis @property - def service_processing_time_in_millis(self) -> Optional[int]: + def service_processing_time_in_millis(self) -> int | None: if not self.result_set: return None return self.result_set.service_processing_time_in_millis @property - def output_location(self) -> Optional[str]: + def output_location(self) -> str | None: if not self.result_set: return None return self.result_set.output_location @property - def data_manifest_location(self) -> Optional[str]: + def data_manifest_location(self) -> str | None: if not self.result_set: return None return self.result_set.data_manifest_location @property - def reused_previous_result(self) -> Optional[bool]: + def reused_previous_result(self) -> bool | None: if not self.result_set: return None return self.result_set.reused_previous_result @property - def encryption_option(self) -> Optional[str]: + def 
encryption_option(self) -> str | None: if not self.result_set: return None return self.result_set.encryption_option @property - def kms_key(self) -> Optional[str]: + def kms_key(self) -> str | None: if not self.result_set: return None return self.result_set.kms_key @property - def expected_bucket_owner(self) -> Optional[str]: + def expected_bucket_owner(self) -> str | None: if not self.result_set: return None return self.result_set.expected_bucket_owner @property - def s3_acl_option(self) -> Optional[str]: + def s3_acl_option(self) -> str | None: if not self.result_set: return None return self.result_set.s3_acl_option @property - def selected_engine_version(self) -> Optional[str]: + def selected_engine_version(self) -> str | None: if not self.result_set: return None return self.result_set.selected_engine_version @property - def effective_engine_version(self) -> Optional[str]: + def effective_engine_version(self) -> str | None: if not self.result_set: return None return self.result_set.effective_engine_version @property - def result_reuse_enabled(self) -> Optional[bool]: + def result_reuse_enabled(self) -> bool | None: if not self.result_set: return None return self.result_set.result_reuse_enabled @property - def result_reuse_minutes(self) -> Optional[int]: + def result_reuse_minutes(self) -> int | None: if not self.result_set: return None return self.result_set.result_reuse_minutes @@ -914,8 +906,8 @@ class WithFetch(BaseCursor, CursorIterator, WithResultSet): def __init__(self, **kwargs) -> None: super().__init__(**kwargs) - self._query_id: Optional[str] = None - self._result_set: Optional[AthenaResultSet] = None + self._query_id: str | None = None + self._result_set: AthenaResultSet | None = None @property def arraysize(self) -> int: @@ -927,8 +919,8 @@ def arraysize(self, value: int) -> None: raise ProgrammingError("arraysize must be a positive integer value.") self._arraysize = value - @property # type: ignore - def result_set(self) -> 
Optional[AthenaResultSet]: + @property + def result_set(self) -> AthenaResultSet | None: return self._result_set @result_set.setter @@ -936,7 +928,7 @@ def result_set(self, val) -> None: self._result_set = val @property - def query_id(self) -> Optional[str]: + def query_id(self) -> str | None: return self._query_id @query_id.setter @@ -944,7 +936,7 @@ def query_id(self, val) -> None: self._query_id = val @property - def rownumber(self) -> Optional[int]: + def rownumber(self) -> int | None: return self.result_set.rownumber if self.result_set else None @property @@ -959,7 +951,7 @@ def close(self) -> None: def executemany( self, operation: str, - seq_of_parameters: List[Optional[Union[Dict[str, Any], List[str]]]], + seq_of_parameters: list[dict[str, Any] | list[str] | None], **kwargs, ) -> None: """Execute a SQL query multiple times with different parameters. @@ -986,7 +978,7 @@ def cancel(self) -> None: def fetchone( self, - ) -> Optional[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + ) -> tuple[Any | None, ...] | dict[Any, Any | None] | None: """Fetch the next row of the result set. Returns: @@ -1001,8 +993,8 @@ def fetchone( return result_set.fetchone() def fetchmany( - self, size: Optional[int] = None - ) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + self, size: int | None = None + ) -> list[tuple[Any | None, ...] | dict[Any, Any | None]]: """Fetch multiple rows from the result set. Args: @@ -1021,7 +1013,7 @@ def fetchmany( def fetchall( self, - ) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + ) -> list[tuple[Any | None, ...] | dict[Any, Any | None]]: """Fetch all remaining rows from the result set. 
Returns: diff --git a/pyathena/s3fs/__init__.py b/pyathena/s3fs/__init__.py index 40a96afc..e69de29b 100644 --- a/pyathena/s3fs/__init__.py +++ b/pyathena/s3fs/__init__.py @@ -1 +0,0 @@ -# -*- coding: utf-8 -*- diff --git a/pyathena/s3fs/async_cursor.py b/pyathena/s3fs/async_cursor.py index 89073b3d..a98c4558 100644 --- a/pyathena/s3fs/async_cursor.py +++ b/pyathena/s3fs/async_cursor.py @@ -1,10 +1,9 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import logging from concurrent.futures import Future from multiprocessing import cpu_count -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from typing import Any, cast from pyathena.async_cursor import AsyncCursor from pyathena.common import CursorIterator @@ -49,19 +48,19 @@ class AsyncS3FSCursor(AsyncCursor): def __init__( self, - s3_staging_dir: Optional[str] = None, - schema_name: Optional[str] = None, - catalog_name: Optional[str] = None, - work_group: Optional[str] = None, + s3_staging_dir: str | None = None, + schema_name: str | None = None, + catalog_name: str | None = None, + work_group: str | None = None, poll_interval: float = 1, - encryption_option: Optional[str] = None, - kms_key: Optional[str] = None, + encryption_option: str | None = None, + kms_key: str | None = None, kill_on_interrupt: bool = True, max_workers: int = (cpu_count() or 1) * 5, arraysize: int = CursorIterator.DEFAULT_FETCH_SIZE, result_reuse_enable: bool = False, result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES, - csv_reader: Optional[CSVReaderType] = None, + csv_reader: CSVReaderType | None = None, **kwargs, ) -> None: """Initialize an AsyncS3FSCursor. @@ -109,7 +108,7 @@ def __init__( @staticmethod def get_default_converter( - unload: bool = False, # noqa: ARG004 + unload: bool = False, ) -> DefaultS3FSTypeConverter: """Get the default type converter for S3FS cursor. 
@@ -143,7 +142,7 @@ def arraysize(self, value: int) -> None: def _collect_result_set( self, query_id: str, - kwargs: Optional[Dict[str, Any]] = None, + kwargs: dict[str, Any] | None = None, ) -> AthenaS3FSResultSet: """Collect result set after query execution. @@ -170,16 +169,16 @@ def _collect_result_set( def execute( self, operation: str, - parameters: Optional[Union[Dict[str, Any], List[str]]] = None, - work_group: Optional[str] = None, - s3_staging_dir: Optional[str] = None, - cache_size: Optional[int] = 0, - cache_expiration_time: Optional[int] = 0, - result_reuse_enable: Optional[bool] = None, - result_reuse_minutes: Optional[int] = None, - paramstyle: Optional[str] = None, + parameters: dict[str, Any] | list[str] | None = None, + work_group: str | None = None, + s3_staging_dir: str | None = None, + cache_size: int | None = 0, + cache_expiration_time: int | None = 0, + result_reuse_enable: bool | None = None, + result_reuse_minutes: int | None = None, + paramstyle: str | None = None, **kwargs, - ) -> Tuple[str, "Future[Union[AthenaS3FSResultSet, Any]]"]: + ) -> tuple[str, Future[AthenaS3FSResultSet | Any]]: """Execute a SQL query asynchronously. Submits the query to Athena and returns immediately with a query ID diff --git a/pyathena/s3fs/converter.py b/pyathena/s3fs/converter.py index 4396d4ac..853f7a7d 100644 --- a/pyathena/s3fs/converter.py +++ b/pyathena/s3fs/converter.py @@ -1,9 +1,8 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import logging from copy import deepcopy -from typing import Any, Optional +from typing import Any from pyathena.converter import ( _DEFAULT_CONVERTERS, @@ -45,7 +44,7 @@ def __init__(self) -> None: default=_to_default, ) - def convert(self, type_: str, value: Optional[str]) -> Optional[Any]: + def convert(self, type_: str, value: str | None) -> Any | None: """Convert a string value to the appropriate Python type. 
Looks up the converter function for the given Athena type and applies diff --git a/pyathena/s3fs/cursor.py b/pyathena/s3fs/cursor.py index 138e7b71..fbb85fab 100644 --- a/pyathena/s3fs/cursor.py +++ b/pyathena/s3fs/cursor.py @@ -1,8 +1,8 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import logging -from typing import Any, Callable, Dict, List, Optional, Union, cast +from collections.abc import Callable +from typing import Any, cast from pyathena.common import CursorIterator from pyathena.error import OperationalError @@ -48,18 +48,18 @@ class S3FSCursor(WithFetch): def __init__( self, - s3_staging_dir: Optional[str] = None, - schema_name: Optional[str] = None, - catalog_name: Optional[str] = None, - work_group: Optional[str] = None, + s3_staging_dir: str | None = None, + schema_name: str | None = None, + catalog_name: str | None = None, + work_group: str | None = None, poll_interval: float = 1, - encryption_option: Optional[str] = None, - kms_key: Optional[str] = None, + encryption_option: str | None = None, + kms_key: str | None = None, kill_on_interrupt: bool = True, result_reuse_enable: bool = False, result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES, - on_start_query_execution: Optional[Callable[[str], None]] = None, - csv_reader: Optional[CSVReaderType] = None, + on_start_query_execution: Callable[[str], None] | None = None, + csv_reader: CSVReaderType | None = None, **kwargs, ) -> None: """Initialize an S3FSCursor. @@ -109,7 +109,7 @@ def __init__( @staticmethod def get_default_converter( - unload: bool = False, # noqa: ARG004 + unload: bool = False, ) -> DefaultS3FSTypeConverter: """Get the default type converter for S3FS cursor. 
@@ -124,17 +124,17 @@ def get_default_converter( def execute( self, operation: str, - parameters: Optional[Union[Dict[str, Any], List[str]]] = None, - work_group: Optional[str] = None, - s3_staging_dir: Optional[str] = None, - cache_size: Optional[int] = 0, - cache_expiration_time: Optional[int] = 0, - result_reuse_enable: Optional[bool] = None, - result_reuse_minutes: Optional[int] = None, - paramstyle: Optional[str] = None, - on_start_query_execution: Optional[Callable[[str], None]] = None, + parameters: dict[str, Any] | list[str] | None = None, + work_group: str | None = None, + s3_staging_dir: str | None = None, + cache_size: int | None = 0, + cache_expiration_time: int | None = 0, + result_reuse_enable: bool | None = None, + result_reuse_minutes: int | None = None, + paramstyle: str | None = None, + on_start_query_execution: Callable[[str], None] | None = None, **kwargs, - ) -> "S3FSCursor": + ) -> S3FSCursor: """Execute a SQL query and return results. Executes the SQL query on Amazon Athena and configures the result set diff --git a/pyathena/s3fs/reader.py b/pyathena/s3fs/reader.py index 11f29b15..a48524d8 100644 --- a/pyathena/s3fs/reader.py +++ b/pyathena/s3fs/reader.py @@ -1,12 +1,11 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import csv from collections.abc import Iterator -from typing import Any, List, Optional, Tuple +from typing import Any -class DefaultCSVReader(Iterator[List[str]]): +class DefaultCSVReader(Iterator[list[str]]): """CSV reader using Python's standard csv module. This reader wraps Python's standard csv.reader and treats empty fields @@ -34,14 +33,14 @@ def __init__(self, file_obj: Any, delimiter: str = ",") -> None: file_obj: File-like object to read from. delimiter: Field delimiter character. 
""" - self._file: Optional[Any] = file_obj + self._file: Any | None = file_obj self._reader = csv.reader(file_obj, delimiter=delimiter) - def __iter__(self) -> "DefaultCSVReader": + def __iter__(self) -> DefaultCSVReader: """Iterate over rows in the CSV file.""" return self - def __next__(self) -> List[str]: + def __next__(self) -> list[str]: """Read and parse the next line. Returns: @@ -65,7 +64,7 @@ def close(self) -> None: self._file.close() self._file = None - def __enter__(self) -> "DefaultCSVReader": + def __enter__(self) -> DefaultCSVReader: """Enter context manager.""" return self @@ -74,7 +73,7 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: self.close() -class AthenaCSVReader(Iterator[List[Optional[str]]]): +class AthenaCSVReader(Iterator[list[str | None]]): """CSV reader that distinguishes between NULL and empty string. This is the default reader for S3FSCursor. @@ -105,14 +104,14 @@ def __init__(self, file_obj: Any, delimiter: str = ",") -> None: file_obj: File-like object to read from. delimiter: Field delimiter character. """ - self._file: Optional[Any] = file_obj + self._file: Any | None = file_obj self._delimiter = delimiter - def __iter__(self) -> "AthenaCSVReader": + def __iter__(self) -> AthenaCSVReader: """Iterate over rows in the CSV file.""" return self - def __next__(self) -> List[Optional[str]]: + def __next__(self) -> list[str | None]: """Read and parse the next line. Returns: @@ -163,7 +162,7 @@ def _check_quote_state(self, text: str, starting_state: bool = False) -> bool: i += 1 return in_quotes - def _parse_line(self, line: str) -> List[Optional[str]]: + def _parse_line(self, line: str) -> list[str | None]: """Parse a single CSV line preserving NULL vs empty string distinction. 
Args: @@ -176,7 +175,7 @@ def _parse_line(self, line: str) -> List[Optional[str]]: if not line: return [None] - fields: List[Optional[str]] = [] + fields: list[str | None] = [] pos = 0 length = len(line) @@ -197,7 +196,7 @@ def _parse_line(self, line: str) -> List[Optional[str]]: return fields - def _parse_quoted_field(self, line: str, pos: int) -> Tuple[str, int]: + def _parse_quoted_field(self, line: str, pos: int) -> tuple[str, int]: """Parse a quoted field starting at pos. Args: @@ -231,7 +230,7 @@ def _parse_quoted_field(self, line: str, pos: int) -> Tuple[str, int]: return "".join(value_parts), pos - def _parse_unquoted_field(self, line: str, pos: int) -> Tuple[str, int]: + def _parse_unquoted_field(self, line: str, pos: int) -> tuple[str, int]: """Parse an unquoted field starting at pos. Args: @@ -261,7 +260,7 @@ def close(self) -> None: self._file.close() self._file = None - def __enter__(self) -> "AthenaCSVReader": + def __enter__(self) -> AthenaCSVReader: """Enter context manager.""" return self diff --git a/pyathena/s3fs/result_set.py b/pyathena/s3fs/result_set.py index 1068aaa0..4173f25b 100644 --- a/pyathena/s3fs/result_set.py +++ b/pyathena/s3fs/result_set.py @@ -1,9 +1,8 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import logging from io import TextIOWrapper -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union +from typing import TYPE_CHECKING, Any from fsspec import AbstractFileSystem @@ -18,7 +17,7 @@ if TYPE_CHECKING: from pyathena.connection import Connection -CSVReaderType = Union[Type[DefaultCSVReader], Type[AthenaCSVReader]] +CSVReaderType = type[DefaultCSVReader] | type[AthenaCSVReader] _logger = logging.getLogger(__name__) @@ -57,14 +56,14 @@ class AthenaS3FSResultSet(AthenaResultSet): def __init__( self, - connection: "Connection[Any]", + connection: Connection[Any], converter: Converter, query_execution: AthenaQueryExecution, arraysize: int, retry_config: RetryConfig, - block_size: 
Optional[int] = None, - csv_reader: Optional[CSVReaderType] = None, - filesystem_class: Optional[Type[AbstractFileSystem]] = None, + block_size: int | None = None, + csv_reader: CSVReaderType | None = None, + filesystem_class: type[AbstractFileSystem] | None = None, **kwargs, ) -> None: super().__init__( @@ -80,9 +79,9 @@ def __init__( self._arraysize = arraysize self._block_size = block_size if block_size else self.DEFAULT_BLOCK_SIZE self._csv_reader_class: CSVReaderType = csv_reader or AthenaCSVReader - self._filesystem_class: Type[AbstractFileSystem] = filesystem_class or S3FileSystem + self._filesystem_class: type[AbstractFileSystem] = filesystem_class or S3FileSystem self._fs = self._create_s3_file_system() - self._csv_reader: Optional[Any] = None + self._csv_reader: Any | None = None if self.state == AthenaQueryExecution.STATE_SUCCEEDED and self.output_location: self._init_csv_reader() @@ -140,7 +139,7 @@ def _init_csv_reader(self) -> None: next(self._csv_reader) except Exception as e: - _logger.exception(f"Failed to open {path}.") + _logger.exception("Failed to open %s.", path) raise OperationalError(*e.args) from e def _fetch(self) -> None: @@ -176,7 +175,7 @@ def _fetch(self) -> None: def fetchone( self, - ) -> Optional[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + ) -> tuple[Any | None, ...] | dict[Any, Any | None] | None: """Fetch the next row of the result set. Returns: @@ -192,8 +191,8 @@ def fetchone( return self._rows.popleft() def fetchmany( - self, size: Optional[int] = None - ) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + self, size: int | None = None + ) -> list[tuple[Any | None, ...] | dict[Any, Any | None]]: """Fetch the next set of rows of the result set. Args: @@ -215,7 +214,7 @@ def fetchmany( def fetchall( self, - ) -> List[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]]: + ) -> list[tuple[Any | None, ...] | dict[Any, Any | None]]: """Fetch all remaining rows of the result set. 
Returns: diff --git a/pyathena/spark/__init__.py b/pyathena/spark/__init__.py index 40a96afc..e69de29b 100644 --- a/pyathena/spark/__init__.py +++ b/pyathena/spark/__init__.py @@ -1 +0,0 @@ -# -*- coding: utf-8 -*- diff --git a/pyathena/spark/async_cursor.py b/pyathena/spark/async_cursor.py index 634feca3..1a476f20 100644 --- a/pyathena/spark/async_cursor.py +++ b/pyathena/spark/async_cursor.py @@ -1,8 +1,7 @@ -# -*- coding: utf-8 -*- import logging from concurrent.futures import Future, ThreadPoolExecutor from multiprocessing import cpu_count -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Any, cast from pyathena.model import AthenaCalculationExecution from pyathena.spark.common import SparkBaseCursor @@ -10,7 +9,7 @@ if TYPE_CHECKING: from pyathena.model import AthenaQueryExecution -_logger = logging.getLogger(__name__) # type: ignore +_logger = logging.getLogger(__name__) class AsyncSparkCursor(SparkBaseCursor): @@ -68,11 +67,11 @@ class AsyncSparkCursor(SparkBaseCursor): def __init__( self, - session_id: Optional[str] = None, - description: Optional[str] = None, - engine_configuration: Optional[Dict[str, Any]] = None, - notebook_version: Optional[str] = None, - session_idle_timeout_minutes: Optional[int] = None, + session_id: str | None = None, + description: str | None = None, + engine_configuration: dict[str, Any] | None = None, + notebook_version: str | None = None, + session_idle_timeout_minutes: int | None = None, max_workers: int = (cpu_count() or 1) * 5, **kwargs, ): @@ -96,7 +95,7 @@ def calculation_execution(self, query_id: str) -> "Future[AthenaCalculationExecu def get_std_out( self, calculation_execution: AthenaCalculationExecution - ) -> "Optional[Future[str]]": + ) -> "Future[str] | None": if not calculation_execution.std_out_s3_uri: return None return self._executor.submit( @@ -105,7 +104,7 @@ def get_std_out( def get_std_error( self, calculation_execution: 
AthenaCalculationExecution - ) -> "Optional[Future[str]]": + ) -> "Future[str] | None": if not calculation_execution.std_error_s3_uri: return None return self._executor.submit( @@ -120,13 +119,13 @@ def poll(self, query_id: str) -> "Future[AthenaCalculationExecution]": def execute( self, operation: str, - parameters: Optional[Union[Dict[str, Any], List[str]]] = None, - session_id: Optional[str] = None, - description: Optional[str] = None, - client_request_token: Optional[str] = None, - work_group: Optional[str] = None, + parameters: dict[str, Any] | list[str] | None = None, + session_id: str | None = None, + description: str | None = None, + client_request_token: str | None = None, + work_group: str | None = None, **kwargs, - ) -> Tuple[str, "Future[Union[AthenaQueryExecution, AthenaCalculationExecution]]"]: + ) -> tuple[str, "Future[AthenaQueryExecution | AthenaCalculationExecution]"]: calculation_id = self._calculate( session_id=session_id if session_id else self._session_id, code_block=operation, diff --git a/pyathena/spark/common.py b/pyathena/spark/common.py index ef599af2..959ad014 100644 --- a/pyathena/spark/common.py +++ b/pyathena/spark/common.py @@ -1,11 +1,10 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import logging import time from abc import ABCMeta, abstractmethod from datetime import datetime -from typing import Any, Dict, List, Optional, Union, cast +from typing import Any, cast import botocore @@ -19,7 +18,7 @@ ) from pyathena.util import parse_output_location, retry_api_call -_logger = logging.getLogger(__name__) # type: ignore +_logger = logging.getLogger(__name__) class SparkBaseCursor(BaseCursor, metaclass=ABCMeta): @@ -50,11 +49,11 @@ class SparkBaseCursor(BaseCursor, metaclass=ABCMeta): def __init__( self, - session_id: Optional[str] = None, - description: Optional[str] = None, - engine_configuration: Optional[Dict[str, Any]] = None, - notebook_version: Optional[str] = None, - session_idle_timeout_minutes: Optional[int] = 
None, + session_id: str | None = None, + description: str | None = None, + engine_configuration: dict[str, Any] | None = None, + notebook_version: str | None = None, + session_idle_timeout_minutes: int | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -75,8 +74,8 @@ def __init__( else: self._session_id = self._start_session() - self._calculation_id: Optional[str] = None - self._calculation_execution: Optional[AthenaCalculationExecution] = None + self._calculation_id: str | None = None + self._calculation_execution: AthenaCalculationExecution | None = None self._client = self.connection.session.client( "s3", @@ -90,11 +89,11 @@ def session_id(self) -> str: return self._session_id @property - def calculation_id(self) -> Optional[str]: + def calculation_id(self) -> str | None: return self._calculation_id @staticmethod - def get_default_engine_configuration() -> Dict[str, Any]: + def get_default_engine_configuration() -> dict[str, Any]: return { "CoordinatorDpuSize": 1, "MaxConcurrentDpus": 2, @@ -113,7 +112,7 @@ def _read_s3_file_as_text(self, uri) -> str: return cast(str, response["Body"].read().decode("utf-8").strip()) def _get_session_status(self, session_id: str): - request: Dict[str, Any] = {"SessionId": session_id} + request: dict[str, Any] = {"SessionId": session_id} try: response = retry_api_call( self._connection.client.get_session_status, @@ -154,7 +153,7 @@ def _exists_session(self, session_id: str) -> bool: isinstance(e, botocore.exceptions.ClientError) and e.response["Error"]["Code"] == "InvalidRequestException" ): - _logger.exception(f"Session: {session_id} not found.") + _logger.exception("Session: %s not found.", session_id) return False raise OperationalError(*e.args) from e else: @@ -162,7 +161,7 @@ def _exists_session(self, session_id: str) -> bool: return True def _start_session(self) -> str: - request: Dict[str, Any] = { + request: dict[str, Any] = { "WorkGroup": self._work_group, "EngineConfiguration": self._engine_configuration, 
} @@ -199,7 +198,7 @@ def _terminate_session(self) -> None: _logger.exception("Failed to terminate session.") raise OperationalError(*e.args) from e - def __poll(self, query_id: str) -> Union[AthenaQueryExecution, AthenaCalculationExecution]: + def __poll(self, query_id: str) -> AthenaQueryExecution | AthenaCalculationExecution: while True: calculation_status = self._get_calculation_execution_status(query_id) if calculation_status.state in [ @@ -210,7 +209,7 @@ def __poll(self, query_id: str) -> Union[AthenaQueryExecution, AthenaCalculation return self._get_calculation_execution(query_id) time.sleep(self._poll_interval) - def _poll(self, query_id: str) -> Union[AthenaQueryExecution, AthenaCalculationExecution]: + def _poll(self, query_id: str) -> AthenaQueryExecution | AthenaCalculationExecution: try: query_execution = self.__poll(query_id) except KeyboardInterrupt as e: @@ -241,7 +240,7 @@ def close(self) -> None: def executemany( self, operation: str, - seq_of_parameters: List[Optional[Union[Dict[str, Any], List[str]]]], + seq_of_parameters: list[dict[str, Any] | list[str] | None], **kwargs, ) -> None: raise NotSupportedError @@ -278,7 +277,7 @@ def __init__(self): @property @abstractmethod - def calculation_execution(self) -> Optional[AthenaCalculationExecution]: + def calculation_execution(self) -> AthenaCalculationExecution | None: raise NotImplementedError # pragma: no cover @property @@ -288,77 +287,77 @@ def session_id(self) -> str: @property @abstractmethod - def calculation_id(self) -> Optional[str]: + def calculation_id(self) -> str | None: raise NotImplementedError # pragma: no cover @property - def description(self) -> Optional[str]: + def description(self) -> str | None: if not self.calculation_execution: return None return self.calculation_execution.description @property - def working_directory(self) -> Optional[str]: + def working_directory(self) -> str | None: if not self.calculation_execution: return None return 
self.calculation_execution.working_directory @property - def state(self) -> Optional[str]: + def state(self) -> str | None: if not self.calculation_execution: return None return self.calculation_execution.state @property - def state_change_reason(self) -> Optional[str]: + def state_change_reason(self) -> str | None: if not self.calculation_execution: return None return self.calculation_execution.state_change_reason @property - def submission_date_time(self) -> Optional[datetime]: + def submission_date_time(self) -> datetime | None: if not self.calculation_execution: return None return self.calculation_execution.submission_date_time @property - def completion_date_time(self) -> Optional[datetime]: + def completion_date_time(self) -> datetime | None: if not self.calculation_execution: return None return self.calculation_execution.completion_date_time @property - def dpu_execution_in_millis(self) -> Optional[int]: + def dpu_execution_in_millis(self) -> int | None: if not self.calculation_execution: return None return self.calculation_execution.dpu_execution_in_millis @property - def progress(self) -> Optional[str]: + def progress(self) -> str | None: if not self.calculation_execution: return None return self.calculation_execution.progress @property - def std_out_s3_uri(self) -> Optional[str]: + def std_out_s3_uri(self) -> str | None: if not self.calculation_execution: return None return self.calculation_execution.std_out_s3_uri @property - def std_error_s3_uri(self) -> Optional[str]: + def std_error_s3_uri(self) -> str | None: if not self.calculation_execution: return None return self.calculation_execution.std_error_s3_uri @property - def result_s3_uri(self) -> Optional[str]: + def result_s3_uri(self) -> str | None: if not self.calculation_execution: return None return self.calculation_execution.result_s3_uri @property - def result_type(self) -> Optional[str]: + def result_type(self) -> str | None: if not self.calculation_execution: return None return 
self.calculation_execution.result_type diff --git a/pyathena/spark/cursor.py b/pyathena/spark/cursor.py index 42d09593..c484731d 100644 --- a/pyathena/spark/cursor.py +++ b/pyathena/spark/cursor.py @@ -1,14 +1,13 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import logging -from typing import Any, Dict, List, Optional, Union, cast +from typing import Any, cast from pyathena import OperationalError, ProgrammingError from pyathena.model import AthenaCalculationExecution, AthenaCalculationExecutionStatus from pyathena.spark.common import SparkBaseCursor, WithCalculationExecution -_logger = logging.getLogger(__name__) # type: ignore +_logger = logging.getLogger(__name__) class SparkCursor(SparkBaseCursor, WithCalculationExecution): @@ -59,11 +58,11 @@ class SparkCursor(SparkBaseCursor, WithCalculationExecution): def __init__( self, - session_id: Optional[str] = None, - description: Optional[str] = None, - engine_configuration: Optional[Dict[str, Any]] = None, - notebook_version: Optional[str] = None, - session_idle_timeout_minutes: Optional[int] = None, + session_id: str | None = None, + description: str | None = None, + engine_configuration: dict[str, Any] | None = None, + notebook_version: str | None = None, + session_idle_timeout_minutes: int | None = None, **kwargs, ) -> None: super().__init__( @@ -76,10 +75,10 @@ def __init__( ) @property - def calculation_execution(self) -> Optional[AthenaCalculationExecution]: + def calculation_execution(self) -> AthenaCalculationExecution | None: return self._calculation_execution - def get_std_out(self) -> Optional[str]: + def get_std_out(self) -> str | None: """Get the standard output from the Spark calculation execution. 
Retrieves and returns the contents of the standard output generated @@ -93,7 +92,7 @@ def get_std_out(self) -> Optional[str]: return None return self._read_s3_file_as_text(self._calculation_execution.std_out_s3_uri) - def get_std_error(self) -> Optional[str]: + def get_std_error(self) -> str | None: """Get the standard error from the Spark calculation execution. Retrieves and returns the contents of the standard error generated @@ -111,11 +110,11 @@ def get_std_error(self) -> Optional[str]: def execute( self, operation: str, - parameters: Optional[Union[Dict[str, Any], List[str]]] = None, - session_id: Optional[str] = None, - description: Optional[str] = None, - client_request_token: Optional[str] = None, - work_group: Optional[str] = None, + parameters: dict[str, Any] | list[str] | None = None, + session_id: str | None = None, + description: str | None = None, + client_request_token: str | None = None, + work_group: str | None = None, **kwargs, ) -> SparkCursor: self._calculation_id = self._calculate( diff --git a/pyathena/sqlalchemy/__init__.py b/pyathena/sqlalchemy/__init__.py index 40a96afc..e69de29b 100644 --- a/pyathena/sqlalchemy/__init__.py +++ b/pyathena/sqlalchemy/__init__.py @@ -1 +0,0 @@ -# -*- coding: utf-8 -*- diff --git a/pyathena/sqlalchemy/arrow.py b/pyathena/sqlalchemy/arrow.py index 34dde6da..da314859 100644 --- a/pyathena/sqlalchemy/arrow.py +++ b/pyathena/sqlalchemy/arrow.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- from typing import TYPE_CHECKING from pyathena.sqlalchemy.base import AthenaDialect diff --git a/pyathena/sqlalchemy/base.py b/pyathena/sqlalchemy/base.py index e57052c1..917f9f53 100644 --- a/pyathena/sqlalchemy/base.py +++ b/pyathena/sqlalchemy/base.py @@ -1,20 +1,12 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import contextlib import re +from collections.abc import Mapping, MutableMapping +from re import Pattern from typing import ( TYPE_CHECKING, Any, - Dict, - List, - Mapping, - MutableMapping, - Optional, - 
Pattern, - Tuple, - Type, - Union, cast, ) @@ -64,7 +56,7 @@ from sqlalchemy.sql.schema import SchemaItem -ischema_names: Dict[str, Type[Any]] = { +ischema_names: dict[str, type[Any]] = { "boolean": types.BOOLEAN, "float": types.FLOAT, "double": get_double_type(), @@ -137,28 +129,26 @@ class AthenaDialect(DefaultDialect): """ name: str = "awsathena" - preparer: Type[IdentifierPreparer] = AthenaDMLIdentifierPreparer - statement_compiler: Type[SQLCompiler] = AthenaStatementCompiler - ddl_compiler: Type[DDLCompiler] = AthenaDDLCompiler - type_compiler: Type[GenericTypeCompiler] = AthenaTypeCompiler + preparer: type[IdentifierPreparer] = AthenaDMLIdentifierPreparer + statement_compiler: type[SQLCompiler] = AthenaStatementCompiler + ddl_compiler: type[DDLCompiler] = AthenaDDLCompiler + type_compiler: type[GenericTypeCompiler] = AthenaTypeCompiler default_paramstyle: str = pyathena.paramstyle cte_follows_insert: bool = True supports_alter: bool = False - supports_pk_autoincrement: Optional[bool] = False + supports_pk_autoincrement: bool | None = False supports_default_values: bool = False supports_empty_insert: bool = False supports_multivalues_insert: bool = True supports_native_decimal: bool = True supports_native_boolean: bool = True - supports_unicode_statements: Optional[bool] = True - supports_unicode_binds: Optional[bool] = True + supports_unicode_statements: bool | None = True + supports_unicode_binds: bool | None = True supports_statement_cache: bool = True - returns_unicode_strings: Optional[bool] = True - description_encoding: Optional[bool] = None + returns_unicode_strings: bool | None = True + description_encoding: bool | None = None postfetch_lastrowid: bool = False - construct_arguments: Optional[ - List[Tuple[Type[Union["SchemaItem", "ClauseElement"]], Mapping[str, Any]]] - ] = [ + construct_arguments: list[tuple[type[SchemaItem | ClauseElement], Mapping[str, Any]]] | None = [ # noqa: RUF012 ( schema.Table, { @@ -183,15 +173,15 @@ class 
AthenaDialect(DefaultDialect): ), ] - colspecs = { + colspecs: dict[type[Any], type[Any]] = { # noqa: RUF012 types.DATE: AthenaDate, types.DATETIME: AthenaTimestamp, types.TIMESTAMP: AthenaTimestamp, } - ischema_names: Dict[str, Type[Any]] = ischema_names + ischema_names: dict[str, type[Any]] = ischema_names - _connect_options: Dict[str, Any] = {} # type: ignore + _connect_options: dict[str, Any] = {} # type: ignore[override] # noqa: RUF012 _pattern_column_type: Pattern[str] = re.compile(r"^([a-zA-Z]+)(?:$|[\(|<](.+)[\)|>]$)") def __init__(self, json_deserializer=None, json_serializer=None, **kwargs): @@ -200,28 +190,28 @@ def __init__(self, json_deserializer=None, json_serializer=None, **kwargs): self._json_serializer = json_serializer @classmethod - def import_dbapi(cls) -> "ModuleType": + def import_dbapi(cls) -> ModuleType: return pyathena @classmethod - def dbapi(cls) -> "ModuleType": # type: ignore + def dbapi(cls) -> ModuleType: # type: ignore[override] return pyathena - def _raw_connection(self, connection: Union[Engine, "Connection"]) -> "PoolProxiedConnection": + def _raw_connection(self, connection: Engine | Connection) -> PoolProxiedConnection: if isinstance(connection, Engine): return connection.raw_connection() return connection.connection - def create_connect_args(self, url: "URL") -> Tuple[Tuple[str], MutableMapping[str, Any]]: + def create_connect_args(self, url: URL) -> tuple[tuple[str], MutableMapping[str, Any]]: # Connection string format: # awsathena+rest:// # {aws_access_key_id}:{aws_secret_access_key}@athena.{region_name}.amazonaws.com:443/ # {schema_name}?s3_staging_dir={s3_staging_dir}&... 
self._connect_options = self._create_connect_args(url) - return cast(Tuple[str], ()), self._connect_options + return cast(tuple[str], ()), self._connect_options - def _create_connect_args(self, url: "URL") -> Dict[str, Any]: - opts: Dict[str, Any] = { + def _create_connect_args(self, url: URL) -> dict[str, Any]: + opts: dict[str, Any] = { "aws_access_key_id": url.username if url.username else None, "aws_secret_access_key": url.password if url.password else None, "region_name": re.sub( @@ -253,8 +243,8 @@ def _create_connect_args(self, url: "URL") -> Dict[str, Any]: @reflection.cache def _get_schemas(self, connection, **kw): raw_connection = self._raw_connection(connection) - catalog = raw_connection.catalog_name # type: ignore - with raw_connection.driver_connection.cursor() as cursor: # type: ignore + catalog = raw_connection.catalog_name # type: ignore[union-attr] + with raw_connection.driver_connection.cursor() as cursor: # type: ignore[union-attr] try: return cursor.list_databases(catalog) except pyathena.error.OperationalError as e: @@ -267,10 +257,10 @@ def _get_schemas(self, connection, **kw): raise @reflection.cache - def _get_table(self, connection, table_name: str, schema: Optional[str] = None, **kw): + def _get_table(self, connection, table_name: str, schema: str | None = None, **kw): raw_connection = self._raw_connection(connection) - schema = schema if schema else raw_connection.schema_name # type: ignore - with raw_connection.driver_connection.cursor() as cursor: # type: ignore + schema = schema if schema else raw_connection.schema_name # type: ignore[union-attr] + with raw_connection.driver_connection.cursor() as cursor: # type: ignore[union-attr] try: return cursor.get_table_metadata(table_name, schema_name=schema, logging_=False) except pyathena.error.OperationalError as e: @@ -283,17 +273,17 @@ def _get_table(self, connection, table_name: str, schema: Optional[str] = None, raise @reflection.cache - def _get_tables(self, connection, schema: 
Optional[str] = None, **kw): + def _get_tables(self, connection, schema: str | None = None, **kw): raw_connection = self._raw_connection(connection) - schema = schema if schema else raw_connection.schema_name # type: ignore - with raw_connection.driver_connection.cursor() as cursor: # type: ignore + schema = schema if schema else raw_connection.schema_name # type: ignore[union-attr] + with raw_connection.driver_connection.cursor() as cursor: # type: ignore[union-attr] return cursor.list_table_metadata(schema_name=schema) def get_schema_names(self, connection, **kw): schemas = self._get_schemas(connection, **kw) return [s.name for s in schemas] - def get_table_names(self, connection: "Connection", schema: Optional[str] = None, **kw): + def get_table_names(self, connection: Connection, schema: str | None = None, **kw): # Tables created by Athena are always classified as `EXTERNAL_TABLE`, # but Athena can also query tables classified as `MANAGED_TABLE`, `EXTERNAL`, or `customer`. # Managed Tables are created by default when creating tables via Spark when @@ -307,18 +297,18 @@ def get_table_names(self, connection: "Connection", schema: Optional[str] = None if t.table_type in ["EXTERNAL_TABLE", "MANAGED_TABLE", "EXTERNAL", "customer"] ] - def get_view_names(self, connection: "Connection", schema: Optional[str] = None, **kw): + def get_view_names(self, connection: Connection, schema: str | None = None, **kw): tables = self._get_tables(connection, schema, **kw) return [t.name for t in tables if t.table_type == "VIRTUAL_VIEW"] def get_table_comment( - self, connection: "Connection", table_name: str, schema: Optional[str] = None, **kw + self, connection: Connection, table_name: str, schema: str | None = None, **kw ): metadata = self._get_table(connection, table_name, schema=schema, **kw) return {"text": metadata.comment} def get_table_options( - self, connection: "Connection", table_name: str, schema: Optional[str] = None, **kw + self, connection: Connection, table_name: 
str, schema: str | None = None, **kw ): metadata = self._get_table(connection, table_name, schema=schema, **kw) # TODO The metadata retrieved from the API does not seem to include bucketing information. @@ -331,9 +321,7 @@ def get_table_options( "awsathena_tblproperties": _HashableDict(metadata.table_properties), } - def has_table( - self, connection: "Connection", table_name: str, schema: Optional[str] = None, **kw - ): + def has_table(self, connection: Connection, table_name: str, schema: str | None = None, **kw): try: columns = self.get_columns(connection, table_name, schema) return bool(columns) @@ -342,10 +330,10 @@ def has_table( @reflection.cache def get_view_definition( - self, connection: Connection, view_name: str, schema: Optional[str] = None, **kw + self, connection: Connection, view_name: str, schema: str | None = None, **kw ): raw_connection = self._raw_connection(connection) - schema = schema if schema else raw_connection.schema_name # type: ignore + schema = schema if schema else raw_connection.schema_name # type: ignore[union-attr] query = f"""SHOW CREATE VIEW "{schema}"."{view_name}";""" try: res = connection.scalars(text(query)) @@ -355,9 +343,7 @@ def get_view_definition( return "\n".join(res) @reflection.cache - def get_columns( - self, connection: "Connection", table_name: str, schema: Optional[str] = None, **kw - ): + def get_columns(self, connection: Connection, table_name: str, schema: str | None = None, **kw): metadata = self._get_table(connection, table_name, schema=schema, **kw) columns = [ { @@ -411,20 +397,20 @@ def _get_column_type(self, type_: str): return col_type(*args) def get_foreign_keys( - self, connection: "Connection", table_name: str, schema: Optional[str] = None, **kw - ) -> List["ReflectedForeignKeyConstraint"]: + self, connection: Connection, table_name: str, schema: str | None = None, **kw + ) -> list[ReflectedForeignKeyConstraint]: # Athena has no support for foreign keys. 
return [] # pragma: no cover def get_pk_constraint( - self, connection: "Connection", table_name: str, schema: Optional[str] = None, **kw - ) -> "ReflectedPrimaryKeyConstraint": + self, connection: Connection, table_name: str, schema: str | None = None, **kw + ) -> ReflectedPrimaryKeyConstraint: # Athena has no support for primary keys. return {"name": None, "constrained_columns": []} # pragma: no cover def get_indexes( - self, connection: "Connection", table_name: str, schema: Optional[str] = None, **kw - ) -> List["ReflectedIndex"]: + self, connection: Connection, table_name: str, schema: str | None = None, **kw + ) -> list[ReflectedIndex]: # Athena has no support for indexes. return [] # pragma: no cover @@ -440,16 +426,16 @@ def do_execute(self, cursor, statement, parameters, context=None): else: cursor.execute(statement, parameters) - def do_rollback(self, dbapi_connection: "PoolProxiedConnection") -> None: + def do_rollback(self, dbapi_connection: PoolProxiedConnection) -> None: # No transactions for Athena pass # pragma: no cover def _check_unicode_returns( - self, connection: "Connection", additional_tests: Optional[List[Any]] = None + self, connection: Connection, additional_tests: list[Any] | None = None ) -> bool: # Requests gives back Unicode strings return True # pragma: no cover - def _check_unicode_description(self, connection: "Connection") -> bool: + def _check_unicode_description(self, connection: Connection) -> bool: # Requests gives back Unicode strings return True # pragma: no cover diff --git a/pyathena/sqlalchemy/compiler.py b/pyathena/sqlalchemy/compiler.py index d3168fdb..f45ab310 100644 --- a/pyathena/sqlalchemy/compiler.py +++ b/pyathena/sqlalchemy/compiler.py @@ -1,7 +1,7 @@ -# -*- coding: utf-8 -*- from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast +from collections.abc import Mapping +from typing import TYPE_CHECKING, Any, cast from sqlalchemy import exc, types, 
util from sqlalchemy.sql.compiler import ( @@ -61,111 +61,111 @@ class AthenaTypeCompiler(GenericTypeCompiler): https://docs.aws.amazon.com/athena/latest/ug/data-types.html """ - def visit_FLOAT(self, type_: types.Float[Any], **kw: Any) -> str: # noqa: N802 + def visit_FLOAT(self, type_: types.Float[Any], **kw: Any) -> str: return self.visit_REAL(type_, **kw) # type: ignore[arg-type] - def visit_REAL(self, type_: types.REAL[Any], **kw: Any) -> str: # noqa: N802 + def visit_REAL(self, type_: types.REAL[Any], **kw: Any) -> str: return "FLOAT" - def visit_DOUBLE(self, type_, **kw) -> str: # noqa: N802 + def visit_DOUBLE(self, type_, **kw) -> str: return "DOUBLE" - def visit_DOUBLE_PRECISION(self, type_, **kw) -> str: # noqa: N802 + def visit_DOUBLE_PRECISION(self, type_, **kw) -> str: return "DOUBLE" - def visit_NUMERIC(self, type_: types.Numeric[Any], **kw: Any) -> str: # noqa: N802 + def visit_NUMERIC(self, type_: types.Numeric[Any], **kw: Any) -> str: return self.visit_DECIMAL(type_, **kw) # type: ignore[arg-type] - def visit_DECIMAL(self, type_: types.DECIMAL[Any], **kw: Any) -> str: # noqa: N802 + def visit_DECIMAL(self, type_: types.DECIMAL[Any], **kw: Any) -> str: if type_.precision is None: return "DECIMAL" if type_.scale is None: return f"DECIMAL({type_.precision})" return f"DECIMAL({type_.precision}, {type_.scale})" - def visit_TINYINT(self, type_: types.Integer, **kw: Any) -> str: # noqa: N802 + def visit_TINYINT(self, type_: types.Integer, **kw: Any) -> str: return "TINYINT" - def visit_INTEGER(self, type_: types.Integer, **kw: Any) -> str: # noqa: N802 + def visit_INTEGER(self, type_: types.Integer, **kw: Any) -> str: return "INTEGER" - def visit_SMALLINT(self, type_: types.SmallInteger, **kw: Any) -> str: # noqa: N802 + def visit_SMALLINT(self, type_: types.SmallInteger, **kw: Any) -> str: return "SMALLINT" - def visit_BIGINT(self, type_: types.BigInteger, **kw: Any) -> str: # noqa: N802 + def visit_BIGINT(self, type_: types.BigInteger, **kw: Any) -> 
str: return "BIGINT" - def visit_TIMESTAMP(self, type_: types.TIMESTAMP, **kw: Any) -> str: # noqa: N802 + def visit_TIMESTAMP(self, type_: types.TIMESTAMP, **kw: Any) -> str: return "TIMESTAMP" - def visit_DATETIME(self, type_: types.DateTime, **kw: Any) -> str: # noqa: N802 + def visit_DATETIME(self, type_: types.DateTime, **kw: Any) -> str: return self.visit_TIMESTAMP(type_, **kw) # type: ignore[arg-type] - def visit_DATE(self, type_: types.Date, **kw: Any) -> str: # noqa: N802 + def visit_DATE(self, type_: types.Date, **kw: Any) -> str: return "DATE" - def visit_TIME(self, type_: types.Time, **kw: Any) -> str: # noqa: N802 + def visit_TIME(self, type_: types.Time, **kw: Any) -> str: raise exc.CompileError(f"Data type `{type_}` is not supported") - def visit_CLOB(self, type_: types.CLOB, **kw: Any) -> str: # noqa: N802 + def visit_CLOB(self, type_: types.CLOB, **kw: Any) -> str: return self.visit_BINARY(type_, **kw) # type: ignore[arg-type] - def visit_NCLOB(self, type_: types.Text, **kw: Any) -> str: # noqa: N802 + def visit_NCLOB(self, type_: types.Text, **kw: Any) -> str: return self.visit_BINARY(type_, **kw) # type: ignore[arg-type] - def visit_CHAR(self, type_: types.CHAR, **kw: Any) -> str: # noqa: N802 + def visit_CHAR(self, type_: types.CHAR, **kw: Any) -> str: if type_.length: return self._render_string_type("CHAR", type_.length, type_.collation) return "STRING" - def visit_NCHAR(self, type_: types.NCHAR, **kw: Any) -> str: # noqa: N802 + def visit_NCHAR(self, type_: types.NCHAR, **kw: Any) -> str: return self.visit_CHAR(type_, **kw) # type: ignore[arg-type] - def visit_VARCHAR(self, type_: types.String, **kw: Any) -> str: # noqa: N802 + def visit_VARCHAR(self, type_: types.String, **kw: Any) -> str: if type_.length: return self._render_string_type("VARCHAR", type_.length, type_.collation) return "STRING" - def visit_NVARCHAR(self, type_: types.NVARCHAR, **kw: Any) -> str: # noqa: N802 + def visit_NVARCHAR(self, type_: types.NVARCHAR, **kw: Any) -> str: 
return self.visit_VARCHAR(type_, **kw) # type: ignore[arg-type] - def visit_TEXT(self, type_: types.Text, **kw: Any) -> str: # noqa: N802 + def visit_TEXT(self, type_: types.Text, **kw: Any) -> str: return "STRING" - def visit_BLOB(self, type_: types.LargeBinary, **kw: Any) -> str: # noqa: N802 + def visit_BLOB(self, type_: types.LargeBinary, **kw: Any) -> str: return self.visit_BINARY(type_, **kw) # type: ignore[arg-type] - def visit_BINARY(self, type_: types.BINARY, **kw: Any) -> str: # noqa: N802 + def visit_BINARY(self, type_: types.BINARY, **kw: Any) -> str: return "BINARY" - def visit_VARBINARY(self, type_: types.VARBINARY, **kw: Any) -> str: # noqa: N802 + def visit_VARBINARY(self, type_: types.VARBINARY, **kw: Any) -> str: return self.visit_BINARY(type_, **kw) # type: ignore[arg-type] - def visit_BOOLEAN(self, type_: types.Boolean, **kw: Any) -> str: # noqa: N802 + def visit_BOOLEAN(self, type_: types.Boolean, **kw: Any) -> str: return "BOOLEAN" - def visit_JSON(self, type_: types.JSON, **kw: Any) -> str: # noqa: N802 + def visit_JSON(self, type_: types.JSON, **kw: Any) -> str: return "JSON" - def visit_string(self, type_, **kw): # noqa: N802 + def visit_string(self, type_, **kw): return "STRING" - def visit_unicode(self, type_, **kw): # noqa: N802 + def visit_unicode(self, type_, **kw): return "STRING" - def visit_unicode_text(self, type_, **kw): # noqa: N802 + def visit_unicode_text(self, type_, **kw): return "STRING" - def visit_null(self, type_, **kw): # noqa: N802 + def visit_null(self, type_, **kw): return "NULL" - def visit_tinyint(self, type_, **kw): # noqa: N802 + def visit_tinyint(self, type_, **kw): return self.visit_TINYINT(type_, **kw) def visit_enum(self, type_, **kw): return self.visit_string(type_, **kw) - def visit_struct(self, type_, **kw): # noqa: N802 + def visit_struct(self, type_, **kw): if isinstance(type_, AthenaStruct): if type_.fields: field_specs = [] @@ -176,26 +176,26 @@ def visit_struct(self, type_, **kw): # noqa: N802 return 
"ROW()" return "ROW()" - def visit_STRUCT(self, type_, **kw): # noqa: N802 + def visit_STRUCT(self, type_, **kw): return self.visit_struct(type_, **kw) - def visit_map(self, type_, **kw): # noqa: N802 + def visit_map(self, type_, **kw): if isinstance(type_, AthenaMap): key_type_str = self.process(type_.key_type, **kw) value_type_str = self.process(type_.value_type, **kw) return f"MAP<{key_type_str}, {value_type_str}>" return "MAP" - def visit_MAP(self, type_, **kw): # noqa: N802 + def visit_MAP(self, type_, **kw): return self.visit_map(type_, **kw) - def visit_array(self, type_, **kw): # noqa: N802 + def visit_array(self, type_, **kw): if isinstance(type_, AthenaArray): item_type_str = self.process(type_.item_type, **kw) return f"ARRAY<{item_type_str}>" return "ARRAY" - def visit_ARRAY(self, type_, **kw): # noqa: N802 + def visit_ARRAY(self, type_, **kw): return self.visit_array(type_, **kw) @@ -219,10 +219,10 @@ class AthenaStatementCompiler(SQLCompiler): https://docs.aws.amazon.com/athena/latest/ug/ddl-sql-reference.html """ - def visit_char_length_func(self, fn: "Function[Any]", **kw: Any) -> str: + def visit_char_length_func(self, fn: Function[Any], **kw: Any) -> str: return f"length{self.function_argspec(fn, **kw)}" - def visit_filter_func(self, fn: "Function[Any]", **kw: Any) -> str: + def visit_filter_func(self, fn: Function[Any], **kw: Any) -> str: """Compile Athena filter() function with lambda expressions. 
Supports syntax: filter(array_expr, lambda_expr) @@ -249,7 +249,7 @@ def visit_filter_func(self, fn: "Function[Any]", **kw: Any) -> str: return f"filter({array_sql}, {lambda_sql})" - def visit_cast(self, cast: "Cast[Any]", **kwargs): + def visit_cast(self, cast: Cast[Any], **kwargs): if (isinstance(cast.type, types.VARCHAR) and cast.type.length is None) or isinstance( cast.type, types.String ): @@ -267,7 +267,7 @@ def visit_cast(self, cast: "Cast[Any]", **kwargs): type_clause = cast.typeclause._compiler_dispatch(self, **kwargs) return f"CAST({cast.clause._compiler_dispatch(self, **kwargs)} AS {type_clause})" - def limit_clause(self, select: "GenerativeSelect", **kw): + def limit_clause(self, select: GenerativeSelect, **kw): text = [] if select._offset_clause is not None: text.append(" OFFSET " + self.process(select._offset_clause, **kw)) @@ -344,11 +344,11 @@ def preparer(self, value: IdentifierPreparer): def __init__( self, - dialect: "AthenaDialect", - statement: "CreateTable", - schema_translate_map: Optional[Dict[Optional[str], Optional[str]]] = None, + dialect: AthenaDialect, + statement: CreateTable, + schema_translate_map: dict[str | None, str | None] | None = None, render_schema_translate: bool = False, - compile_kwargs: Optional[Dict[str, Any]] = None, + compile_kwargs: dict[str, Any] | None = None, ): self._preparer = AthenaDDLIdentifierPreparer(dialect) super().__init__( @@ -370,8 +370,8 @@ def _get_comment_specification(self, comment: str) -> str: return f"COMMENT {self._escape_comment(comment)}" def _get_bucket_count( - self, dialect_opts: "_DialectArgDict", connect_opts: Mapping[str, Any] - ) -> Optional[str]: + self, dialect_opts: _DialectArgDict, connect_opts: Mapping[str, Any] + ) -> str | None: if dialect_opts["bucket_count"]: bucket_count = dialect_opts["bucket_count"] elif connect_opts: @@ -381,18 +381,18 @@ def _get_bucket_count( return cast(str, bucket_count) if bucket_count is not None else None def _get_file_format( - self, dialect_opts: 
"_DialectArgDict", connect_opts: Mapping[str, Any] - ) -> Optional[str]: + self, dialect_opts: _DialectArgDict, connect_opts: Mapping[str, Any] + ) -> str | None: if dialect_opts["file_format"]: file_format = dialect_opts["file_format"] elif connect_opts: file_format = connect_opts.get("file_format") else: file_format = None - return cast(Optional[str], file_format) + return cast(str | None, file_format) def _get_file_format_specification( - self, dialect_opts: "_DialectArgDict", connect_opts: Mapping[str, Any] + self, dialect_opts: _DialectArgDict, connect_opts: Mapping[str, Any] ) -> str: file_format = self._get_file_format(dialect_opts, connect_opts) text = [] @@ -401,18 +401,18 @@ def _get_file_format_specification( return "\n".join(text) def _get_row_format( - self, dialect_opts: "_DialectArgDict", connect_opts: Mapping[str, Any] - ) -> Optional[str]: + self, dialect_opts: _DialectArgDict, connect_opts: Mapping[str, Any] + ) -> str | None: if dialect_opts["row_format"]: row_format = dialect_opts["row_format"] elif connect_opts: row_format = connect_opts.get("row_format") else: row_format = None - return cast(Optional[str], row_format) + return cast(str | None, row_format) def _get_row_format_specification( - self, dialect_opts: "_DialectArgDict", connect_opts: Mapping[str, Any] + self, dialect_opts: _DialectArgDict, connect_opts: Mapping[str, Any] ) -> str: row_format = self._get_row_format(dialect_opts, connect_opts) text = [] @@ -421,18 +421,18 @@ def _get_row_format_specification( return "\n".join(text) def _get_serde_properties( - self, dialect_opts: "_DialectArgDict", connect_opts: Mapping[str, Any] - ) -> Optional[Union[str, Dict[str, Any]]]: + self, dialect_opts: _DialectArgDict, connect_opts: Mapping[str, Any] + ) -> str | dict[str, Any] | None: if dialect_opts["serdeproperties"]: serde_properties = dialect_opts["serdeproperties"] elif connect_opts: serde_properties = connect_opts.get("serdeproperties") else: serde_properties = None - return 
cast(Optional[str], serde_properties) + return cast(str | None, serde_properties) def _get_serde_properties_specification( - self, dialect_opts: "_DialectArgDict", connect_opts: Mapping[str, Any] + self, dialect_opts: _DialectArgDict, connect_opts: Mapping[str, Any] ) -> str: serde_properties = self._get_serde_properties(dialect_opts, connect_opts) text = [] @@ -446,8 +446,8 @@ def _get_serde_properties_specification( return "\n".join(text) def _get_table_location( - self, table: "Table", dialect_opts: "_DialectArgDict", connect_opts: Mapping[str, Any] - ) -> Optional[str]: + self, table: Table, dialect_opts: _DialectArgDict, connect_opts: Mapping[str, Any] + ) -> str | None: if dialect_opts["location"]: location = cast(str, dialect_opts["location"]) location += "/" if not location.endswith("/") else "" @@ -464,7 +464,7 @@ def _get_table_location( return location def _get_table_location_specification( - self, table: "Table", dialect_opts: "_DialectArgDict", connect_opts: Mapping[str, Any] + self, table: Table, dialect_opts: _DialectArgDict, connect_opts: Mapping[str, Any] ) -> str: location = self._get_table_location(table, dialect_opts, connect_opts) text = [] @@ -482,8 +482,8 @@ def _get_table_location_specification( return "\n".join(text) def _get_table_properties( - self, dialect_opts: "_DialectArgDict", connect_opts: Mapping[str, Any] - ) -> Optional[Union[Dict[str, str], str]]: + self, dialect_opts: _DialectArgDict, connect_opts: Mapping[str, Any] + ) -> dict[str, str] | str | None: if dialect_opts["tblproperties"]: table_properties = cast(str, dialect_opts["tblproperties"]) elif connect_opts: @@ -493,8 +493,8 @@ def _get_table_properties( return table_properties def _get_compression( - self, dialect_opts: "_DialectArgDict", connect_opts: Mapping[str, Any] - ) -> Optional[str]: + self, dialect_opts: _DialectArgDict, connect_opts: Mapping[str, Any] + ) -> str | None: if dialect_opts["compression"]: compression = cast(str, dialect_opts["compression"]) elif 
connect_opts: @@ -504,7 +504,7 @@ def _get_compression( return compression def _get_table_properties_specification( - self, dialect_opts: "_DialectArgDict", connect_opts: Mapping[str, Any] + self, dialect_opts: _DialectArgDict, connect_opts: Mapping[str, Any] ) -> str: properties = self._get_table_properties(dialect_opts, connect_opts) if properties: @@ -541,7 +541,7 @@ def _get_table_properties_specification( text.append(")") return "\n".join(text) - def get_column_specification(self, column: "Column[Any]", **kwargs) -> str: + def get_column_specification(self, column: Column[Any], **kwargs) -> str: if type(column.type) in [types.Integer, types.INTEGER, types.INT]: # https://docs.aws.amazon.com/athena/latest/ug/create-table.html # In Data Definition Language (DDL) queries like CREATE TABLE, @@ -554,22 +554,22 @@ def get_column_specification(self, column: "Column[Any]", **kwargs) -> str: text.append(f"{self._get_comment_specification(column.comment)}") return " ".join(text) - def visit_check_constraint(self, constraint: "CheckConstraint", **kw: Any) -> str: + def visit_check_constraint(self, constraint: CheckConstraint, **kw: Any) -> str: return "" - def visit_column_check_constraint(self, constraint: "CheckConstraint", **kw: Any) -> str: + def visit_column_check_constraint(self, constraint: CheckConstraint, **kw: Any) -> str: return "" - def visit_foreign_key_constraint(self, constraint: "ForeignKeyConstraint", **kw: Any) -> str: + def visit_foreign_key_constraint(self, constraint: ForeignKeyConstraint, **kw: Any) -> str: return "" - def visit_primary_key_constraint(self, constraint: "PrimaryKeyConstraint", **kw: Any) -> str: + def visit_primary_key_constraint(self, constraint: PrimaryKeyConstraint, **kw: Any) -> str: return "" - def visit_unique_constraint(self, constraint: "UniqueConstraint", **kw: Any) -> str: + def visit_unique_constraint(self, constraint: UniqueConstraint, **kw: Any) -> str: return "" - def _get_connect_option_partitions(self, connect_opts: 
Mapping[str, Any]) -> List[str]: + def _get_connect_option_partitions(self, connect_opts: Mapping[str, Any]) -> list[str]: if connect_opts: partition = cast(str, connect_opts.get("partition")) partitions = partition.split(",") if partition else [] @@ -577,7 +577,7 @@ def _get_connect_option_partitions(self, connect_opts: Mapping[str, Any]) -> Lis partitions = [] return partitions - def _get_connect_option_buckets(self, connect_opts: Mapping[str, Any]) -> List[str]: + def _get_connect_option_buckets(self, connect_opts: Mapping[str, Any]) -> list[str]: if connect_opts: bucket = cast(str, connect_opts.get("cluster")) buckets = bucket.split(",") if bucket else [] @@ -617,11 +617,11 @@ def _prepared_partitions(self, column: Column[Any]): def _prepared_columns( self, - table: "Table", + table: Table, is_iceberg: bool, - create_columns: List["CreateColumn"], + create_columns: list[CreateColumn], connect_opts: Mapping[str, Any], - ) -> Tuple[List[str], List[str], List[str]]: + ) -> tuple[list[str], list[str], list[str]]: columns, partitions, buckets = [], [], [] conn_partitions = self._get_connect_option_partitions(connect_opts) conn_buckets = self._get_connect_option_buckets(connect_opts) @@ -656,7 +656,7 @@ def _prepared_columns( ) from e return columns, partitions, buckets - def visit_create_table(self, create: "CreateTable", **kwargs) -> str: + def visit_create_table(self, create: CreateTable, **kwargs) -> str: table = create.element dialect_opts = table.dialect_options["awsathena"] dialect = cast("AthenaDialect", self.dialect) @@ -701,8 +701,8 @@ def visit_create_table(self, create: "CreateTable", **kwargs) -> str: text.append(f"{self.post_create_table(table)}\n") return "\n".join(text) - def post_create_table(self, table: "Table") -> str: - dialect_opts: "_DialectArgDict" = table.dialect_options["awsathena"] + def post_create_table(self, table: Table) -> str: + dialect_opts: _DialectArgDict = table.dialect_options["awsathena"] dialect = cast("AthenaDialect", 
self.dialect) connect_opts = dialect._connect_options text = [ diff --git a/pyathena/sqlalchemy/constants.py b/pyathena/sqlalchemy/constants.py index 5ce4fa26..f0a9229e 100644 --- a/pyathena/sqlalchemy/constants.py +++ b/pyathena/sqlalchemy/constants.py @@ -1,12 +1,9 @@ -# -*- coding: utf-8 -*- """Constants for PyAthena SQLAlchemy dialect.""" from __future__ import annotations -from typing import Set - # https://docs.aws.amazon.com/athena/latest/ug/reserved-words.html#list-of-ddl-reserved-words -DDL_RESERVED_WORDS: Set[str] = { +DDL_RESERVED_WORDS: set[str] = { "all", "alter", "and", @@ -147,7 +144,7 @@ } # https://docs.aws.amazon.com/athena/latest/ug/reserved-words.html#list-of-reserved-words-sql-select -SELECT_STATEMENT_RESERVED_WORDS: Set[str] = { +SELECT_STATEMENT_RESERVED_WORDS: set[str] = { "all", "and", "any", @@ -258,4 +255,4 @@ "with", } -RESERVED_WORDS: Set[str] = set(DDL_RESERVED_WORDS | SELECT_STATEMENT_RESERVED_WORDS) +RESERVED_WORDS: set[str] = set(DDL_RESERVED_WORDS | SELECT_STATEMENT_RESERVED_WORDS) diff --git a/pyathena/sqlalchemy/pandas.py b/pyathena/sqlalchemy/pandas.py index 30383eea..86a2da9f 100644 --- a/pyathena/sqlalchemy/pandas.py +++ b/pyathena/sqlalchemy/pandas.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- from typing import TYPE_CHECKING from pyathena.sqlalchemy.base import AthenaDialect @@ -53,7 +52,7 @@ def create_connect_args(self, url): if "engine" in opts: cursor_kwargs.update({"engine": opts.pop("engine")}) if "chunksize" in opts: - cursor_kwargs.update({"chunksize": int(opts.pop("chunksize"))}) # type: ignore + cursor_kwargs.update({"chunksize": int(opts.pop("chunksize"))}) # type: ignore[dict-item] if cursor_kwargs: opts.update({"cursor_kwargs": cursor_kwargs}) return [[], opts] diff --git a/pyathena/sqlalchemy/polars.py b/pyathena/sqlalchemy/polars.py index 702ae522..43544a6b 100644 --- a/pyathena/sqlalchemy/polars.py +++ b/pyathena/sqlalchemy/polars.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- from typing import TYPE_CHECKING from 
pyathena.sqlalchemy.base import AthenaDialect diff --git a/pyathena/sqlalchemy/preparer.py b/pyathena/sqlalchemy/preparer.py index 3664027c..f884ff6a 100644 --- a/pyathena/sqlalchemy/preparer.py +++ b/pyathena/sqlalchemy/preparer.py @@ -1,7 +1,6 @@ -# -*- coding: utf-8 -*- from __future__ import annotations -from typing import TYPE_CHECKING, Optional, Set +from typing import TYPE_CHECKING from sqlalchemy.sql.compiler import ILLEGAL_INITIAL_CHARACTERS, IdentifierPreparer @@ -27,7 +26,7 @@ class AthenaDMLIdentifierPreparer(IdentifierPreparer): https://docs.aws.amazon.com/athena/latest/ug/reserved-words.html """ - reserved_words: Set[str] = SELECT_STATEMENT_RESERVED_WORDS + reserved_words: set[str] = SELECT_STATEMENT_RESERVED_WORDS class AthenaDDLIdentifierPreparer(IdentifierPreparer): @@ -53,9 +52,9 @@ class AthenaDDLIdentifierPreparer(IdentifierPreparer): def __init__( self, - dialect: "Dialect", + dialect: Dialect, initial_quote: str = "`", - final_quote: Optional[str] = None, + final_quote: str | None = None, escape_quote: str = "`", quote_case_sensitive_collations: bool = True, omit_schema: bool = False, diff --git a/pyathena/sqlalchemy/requirements.py b/pyathena/sqlalchemy/requirements.py index ed075b25..e600a949 100644 --- a/pyathena/sqlalchemy/requirements.py +++ b/pyathena/sqlalchemy/requirements.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- from sqlalchemy.testing import exclusions from sqlalchemy.testing.requirements import SuiteRequirements diff --git a/pyathena/sqlalchemy/rest.py b/pyathena/sqlalchemy/rest.py index 67d56afc..d641accc 100644 --- a/pyathena/sqlalchemy/rest.py +++ b/pyathena/sqlalchemy/rest.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- from typing import TYPE_CHECKING from pyathena.sqlalchemy.base import AthenaDialect diff --git a/pyathena/sqlalchemy/s3fs.py b/pyathena/sqlalchemy/s3fs.py index 3e35f514..5ea8bdea 100644 --- a/pyathena/sqlalchemy/s3fs.py +++ b/pyathena/sqlalchemy/s3fs.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- from typing import 
TYPE_CHECKING from pyathena.sqlalchemy.base import AthenaDialect diff --git a/pyathena/sqlalchemy/types.py b/pyathena/sqlalchemy/types.py index cba0425e..23cf0690 100644 --- a/pyathena/sqlalchemy/types.py +++ b/pyathena/sqlalchemy/types.py @@ -1,8 +1,7 @@ -# -*- coding: utf-8 -*- from __future__ import annotations from datetime import date, datetime -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union +from typing import TYPE_CHECKING, Any from sqlalchemy import types from sqlalchemy.sql import sqltypes @@ -13,7 +12,7 @@ from sqlalchemy.sql.type_api import _LiteralProcessorType -def get_double_type() -> Type[Any]: +def get_double_type() -> type[Any]: """Get the appropriate type for DOUBLE based on SQLAlchemy version. SQLAlchemy 2.0+ provides a native DOUBLE type, while earlier versions @@ -51,12 +50,12 @@ class AthenaTimestamp(TypeEngine[datetime]): render_bind_cast = True @staticmethod - def process(value: Optional[Union[datetime, Any]]) -> str: + def process(value: datetime | Any | None) -> str: if isinstance(value, datetime): return f"""TIMESTAMP '{value.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]}'""" - return f"TIMESTAMP '{str(value)}'" + return f"TIMESTAMP '{value!s}'" - def literal_processor(self, dialect: "Dialect") -> Optional["_LiteralProcessorType[datetime]"]: + def literal_processor(self, dialect: Dialect) -> _LiteralProcessorType[datetime] | None: return self.process @@ -80,12 +79,12 @@ class AthenaDate(TypeEngine[date]): render_bind_cast = True @staticmethod - def process(value: Union[date, datetime, Any]) -> str: + def process(value: date | datetime | Any) -> str: if isinstance(value, (date, datetime)): return f"DATE '{value:%Y-%m-%d}'" - return f"DATE '{str(value)}'" + return f"DATE '{value!s}'" - def literal_processor(self, dialect: "Dialect") -> Optional["_LiteralProcessorType[date]"]: + def literal_processor(self, dialect: Dialect) -> _LiteralProcessorType[date] | None: return self.process @@ -108,7 +107,7 @@ class
TINYINT(Tinyint): __visit_name__ = "TINYINT" -class AthenaStruct(TypeEngine[Dict[str, Any]]): +class AthenaStruct(TypeEngine[dict[str, Any]]): """SQLAlchemy type for Athena STRUCT/ROW complex type. STRUCT represents a record with named fields, similar to a database row @@ -139,8 +138,8 @@ class AthenaStruct(TypeEngine[Dict[str, Any]]): __visit_name__ = "struct" - def __init__(self, *fields: Union[str, Tuple[str, Any]]) -> None: - self.fields: Dict[str, TypeEngine[Any]] = {} + def __init__(self, *fields: str | tuple[str, Any]) -> None: + self.fields: dict[str, TypeEngine[Any]] = {} for field in fields: if isinstance(field, str): @@ -169,7 +168,7 @@ class STRUCT(AthenaStruct): __visit_name__ = "STRUCT" -class AthenaMap(TypeEngine[Dict[str, Any]]): +class AthenaMap(TypeEngine[dict[str, Any]]): """SQLAlchemy type for Athena MAP complex type. MAP represents a collection of key-value pairs where all keys have the @@ -222,7 +221,7 @@ class MAP(AthenaMap): __visit_name__ = "MAP" -class AthenaArray(TypeEngine[List[Any]]): +class AthenaArray(TypeEngine[list[Any]]): """SQLAlchemy type for Athena ARRAY complex type. ARRAY represents an ordered collection of elements of the same type. diff --git a/pyathena/sqlalchemy/util.py b/pyathena/sqlalchemy/util.py index 984f125c..c82228d8 100644 --- a/pyathena/sqlalchemy/util.py +++ b/pyathena/sqlalchemy/util.py @@ -1,8 +1,7 @@ -# -*- coding: utf-8 -*- """Utility classes for PyAthena SQLAlchemy dialect.""" -class _HashableDict(dict): # type: ignore +class _HashableDict(dict): # type: ignore[type-arg] """A dictionary subclass that can be used as a dictionary key. SQLAlchemy's reflection caching requires hashable objects. This class @@ -10,5 +9,5 @@ class _HashableDict(dict): # type: ignore making them hashable through tuple conversion. 
""" - def __hash__(self): # type: ignore + def __hash__(self): # type: ignore[override] return hash(tuple(sorted(self.items()))) diff --git a/pyathena/util.py b/pyathena/util.py index ce3328d6..d0e18760 100644 --- a/pyathena/util.py +++ b/pyathena/util.py @@ -1,23 +1,24 @@ -# -*- coding: utf-8 -*- from __future__ import annotations import logging import re -from typing import Any, Callable, Iterable, Optional, Pattern, Tuple, cast +from collections.abc import Callable, Iterable +from re import Pattern +from typing import Any, cast import tenacity from tenacity import after_log, retry_if_exception, stop_after_attempt, wait_exponential from pyathena import DataError -_logger = logging.getLogger(__name__) # type: ignore +_logger = logging.getLogger(__name__) PATTERN_OUTPUT_LOCATION: Pattern[str] = re.compile( r"^s3://(?P[a-zA-Z0-9.\-_]+)/(?P.+)$" ) -def parse_output_location(output_location: str) -> Tuple[str, str]: +def parse_output_location(output_location: str) -> tuple[str, str]: """Parse an S3 output location URL into bucket and key components. Args: @@ -135,7 +136,7 @@ def __init__( def retry_api_call( func: Callable[..., Any], config: RetryConfig, - logger: Optional[logging.Logger] = None, + logger: logging.Logger | None = None, *args, **kwargs, ) -> Any: @@ -173,10 +174,10 @@ def retry_api_call( Does not retry on client errors or non-AWS exceptions. 
""" - def _extract_code(ex: BaseException) -> Optional[str]: - resp = cast(Optional[dict[str, Any]], getattr(ex, "response", None)) - err = cast(Optional[dict[str, Any]], (resp or {}).get("Error")) - return cast(Optional[str], (err or {}).get("Code")) + def _extract_code(ex: BaseException) -> str | None: + resp = cast(dict[str, Any] | None, getattr(ex, "response", None)) + err = cast(dict[str, Any] | None, (resp or {}).get("Error")) + return cast(str | None, (err or {}).get("Code")) def _is_retryable(ex: BaseException) -> bool: code = _extract_code(ex) @@ -190,7 +191,7 @@ def _is_retryable(ex: BaseException) -> bool: max=config.max_delay, exp_base=config.exponential_base, ), - after=after_log(logger, logger.getEffectiveLevel()) if logger else None, # type: ignore + after=after_log(logger, logger.getEffectiveLevel()) if logger else None, # type: ignore[arg-type] reraise=True, ) return retry(func, *args, **kwargs) diff --git a/pyproject.toml b/pyproject.toml index cdb0a631..1322adb2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -135,6 +135,7 @@ line-length = 100 exclude = [ ".venv", ".tox", + "benchmarks", ] target-version = "py310" @@ -150,7 +151,31 @@ select = [ "C4", # flake8-comprehensions "SIM", # flake8-simplify "RET", # flake8-return - # "UP", # pyupgrade + "UP", # pyupgrade + "PIE", # flake8-pie + "PERF", # perflint + "T20", # flake8-print + "FLY", # flynt + "ISC", # flake8-implicit-str-concat + "RSE", # flake8-raise + "RUF", # ruff-specific rules + "PGH", # pygrep-hooks + "G", # flake8-logging-format + "PT", # flake8-pytest-style +] +ignore = [ + "RUF059", # unused-unpacked-variable (too noisy for interface-heavy code) +] + +[tool.ruff.lint.per-file-ignores] +"tests/**" = [ + "RUF012", # mutable-class-default (test classes often use mutable defaults) +] +"pyathena/sqlalchemy/compiler.py" = [ + "N802", # SQLAlchemy TypeCompiler requires visit_UPPERCASE method names +] +"pyathena/aio/sqlalchemy/base.py" = [ + "N801", # SQLAlchemy async adapter naming 
convention (AsyncAdapt_*) ] [tool.mypy] diff --git a/tests/__init__.py b/tests/__init__.py index 4cb72813..bbe284c9 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import os import random import string diff --git a/tests/pyathena/__init__.py b/tests/pyathena/__init__.py index 40a96afc..e69de29b 100644 --- a/tests/pyathena/__init__.py +++ b/tests/pyathena/__init__.py @@ -1 +0,0 @@ -# -*- coding: utf-8 -*- diff --git a/tests/pyathena/aio/__init__.py b/tests/pyathena/aio/__init__.py index 40a96afc..e69de29b 100644 --- a/tests/pyathena/aio/__init__.py +++ b/tests/pyathena/aio/__init__.py @@ -1 +0,0 @@ -# -*- coding: utf-8 -*- diff --git a/tests/pyathena/aio/arrow/__init__.py b/tests/pyathena/aio/arrow/__init__.py index 40a96afc..e69de29b 100644 --- a/tests/pyathena/aio/arrow/__init__.py +++ b/tests/pyathena/aio/arrow/__init__.py @@ -1 +0,0 @@ -# -*- coding: utf-8 -*- diff --git a/tests/pyathena/aio/arrow/test_cursor.py b/tests/pyathena/aio/arrow/test_cursor.py index efc8f6e6..5253dabe 100644 --- a/tests/pyathena/aio/arrow/test_cursor.py +++ b/tests/pyathena/aio/arrow/test_cursor.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import pytest from pyathena.arrow.result_set import AthenaArrowResultSet diff --git a/tests/pyathena/aio/conftest.py b/tests/pyathena/aio/conftest.py index e81b8b40..39da7709 100644 --- a/tests/pyathena/aio/conftest.py +++ b/tests/pyathena/aio/conftest.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import pytest from tests import ENV @@ -17,7 +16,7 @@ async def aio_cursor(request): from pyathena.aio.cursor import AioCursor if not hasattr(request, "param"): - setattr(request, "param", {}) # noqa: B010 + request.param = {} conn = await _aio_connect(schema_name=ENV.schema, cursor_class=AioCursor, **request.param) try: async with conn.cursor() as cursor: @@ -31,7 +30,7 @@ async def aio_dict_cursor(request): from pyathena.aio.cursor import AioDictCursor if not hasattr(request, "param"): - setattr(request, 
"param", {}) # noqa: B010 + request.param = {} conn = await _aio_connect(schema_name=ENV.schema, cursor_class=AioDictCursor, **request.param) try: async with conn.cursor() as cursor: @@ -45,7 +44,7 @@ async def aio_pandas_cursor(request): from pyathena.aio.pandas.cursor import AioPandasCursor if not hasattr(request, "param"): - setattr(request, "param", {}) # noqa: B010 + request.param = {} conn = await _aio_connect(schema_name=ENV.schema, cursor_class=AioPandasCursor, **request.param) try: async with conn.cursor() as cursor: @@ -59,7 +58,7 @@ async def aio_arrow_cursor(request): from pyathena.aio.arrow.cursor import AioArrowCursor if not hasattr(request, "param"): - setattr(request, "param", {}) # noqa: B010 + request.param = {} conn = await _aio_connect(schema_name=ENV.schema, cursor_class=AioArrowCursor, **request.param) try: async with conn.cursor() as cursor: @@ -73,7 +72,7 @@ async def aio_polars_cursor(request): from pyathena.aio.polars.cursor import AioPolarsCursor if not hasattr(request, "param"): - setattr(request, "param", {}) # noqa: B010 + request.param = {} conn = await _aio_connect(schema_name=ENV.schema, cursor_class=AioPolarsCursor, **request.param) try: async with conn.cursor() as cursor: @@ -87,7 +86,7 @@ async def aio_s3fs_cursor(request): from pyathena.aio.s3fs.cursor import AioS3FSCursor if not hasattr(request, "param"): - setattr(request, "param", {}) # noqa: B010 + request.param = {} conn = await _aio_connect(schema_name=ENV.schema, cursor_class=AioS3FSCursor, **request.param) try: async with conn.cursor() as cursor: @@ -103,7 +102,7 @@ async def aio_spark_cursor(request): from pyathena.aio.spark.cursor import AioSparkCursor if not hasattr(request, "param"): - setattr(request, "param", {}) # noqa: B010 + request.param = {} conn = await _aio_connect( schema_name=ENV.schema, cursor_class=AioSparkCursor, diff --git a/tests/pyathena/aio/pandas/__init__.py b/tests/pyathena/aio/pandas/__init__.py index 40a96afc..e69de29b 100644 --- 
a/tests/pyathena/aio/pandas/__init__.py +++ b/tests/pyathena/aio/pandas/__init__.py @@ -1 +0,0 @@ -# -*- coding: utf-8 -*- diff --git a/tests/pyathena/aio/pandas/test_cursor.py b/tests/pyathena/aio/pandas/test_cursor.py index 6e3794e7..06b02fc8 100644 --- a/tests/pyathena/aio/pandas/test_cursor.py +++ b/tests/pyathena/aio/pandas/test_cursor.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import pytest from pyathena.error import ProgrammingError diff --git a/tests/pyathena/aio/polars/__init__.py b/tests/pyathena/aio/polars/__init__.py index 40a96afc..e69de29b 100644 --- a/tests/pyathena/aio/polars/__init__.py +++ b/tests/pyathena/aio/polars/__init__.py @@ -1 +0,0 @@ -# -*- coding: utf-8 -*- diff --git a/tests/pyathena/aio/polars/test_cursor.py b/tests/pyathena/aio/polars/test_cursor.py index 81309fe5..da9e25c4 100644 --- a/tests/pyathena/aio/polars/test_cursor.py +++ b/tests/pyathena/aio/polars/test_cursor.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import pytest from pyathena.error import ProgrammingError diff --git a/tests/pyathena/aio/s3fs/__init__.py b/tests/pyathena/aio/s3fs/__init__.py index 40a96afc..e69de29b 100644 --- a/tests/pyathena/aio/s3fs/__init__.py +++ b/tests/pyathena/aio/s3fs/__init__.py @@ -1 +0,0 @@ -# -*- coding: utf-8 -*- diff --git a/tests/pyathena/aio/s3fs/test_cursor.py b/tests/pyathena/aio/s3fs/test_cursor.py index 720e63d4..4a1521fd 100644 --- a/tests/pyathena/aio/s3fs/test_cursor.py +++ b/tests/pyathena/aio/s3fs/test_cursor.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import pytest from pyathena.aio.s3fs.cursor import AioS3FSCursor @@ -61,9 +60,7 @@ async def test_invalid_arraysize(self, aio_s3fs_cursor): async def test_async_iterator(self, aio_s3fs_cursor): await aio_s3fs_cursor.execute("SELECT * FROM one_row") - rows = [] - async for row in aio_s3fs_cursor: - rows.append(row) + rows = [row async for row in aio_s3fs_cursor] assert rows == [(1,)] async def test_context_manager(self): diff --git a/tests/pyathena/aio/spark/__init__.py 
b/tests/pyathena/aio/spark/__init__.py index 40a96afc..e69de29b 100644 --- a/tests/pyathena/aio/spark/__init__.py +++ b/tests/pyathena/aio/spark/__init__.py @@ -1 +0,0 @@ -# -*- coding: utf-8 -*- diff --git a/tests/pyathena/aio/spark/test_cursor.py b/tests/pyathena/aio/spark/test_cursor.py index 6004d567..369aa13d 100644 --- a/tests/pyathena/aio/spark/test_cursor.py +++ b/tests/pyathena/aio/spark/test_cursor.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import textwrap import pytest diff --git a/tests/pyathena/aio/sqlalchemy/__init__.py b/tests/pyathena/aio/sqlalchemy/__init__.py index 40a96afc..e69de29b 100644 --- a/tests/pyathena/aio/sqlalchemy/__init__.py +++ b/tests/pyathena/aio/sqlalchemy/__init__.py @@ -1 +0,0 @@ -# -*- coding: utf-8 -*- diff --git a/tests/pyathena/aio/sqlalchemy/test_base.py b/tests/pyathena/aio/sqlalchemy/test_base.py index da2f6b2e..e65317e0 100644 --- a/tests/pyathena/aio/sqlalchemy/test_base.py +++ b/tests/pyathena/aio/sqlalchemy/test_base.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import pytest import sqlalchemy from sqlalchemy import text diff --git a/tests/pyathena/aio/test_cursor.py b/tests/pyathena/aio/test_cursor.py index e69a72c1..e8ea6cf7 100644 --- a/tests/pyathena/aio/test_cursor.py +++ b/tests/pyathena/aio/test_cursor.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import re from datetime import datetime @@ -59,9 +58,7 @@ async def test_fetchall(self, aio_cursor): async def test_async_iterator(self, aio_cursor): await aio_cursor.execute("SELECT * FROM one_row") - rows = [] - async for row in aio_cursor: - rows.append(row) + rows = [row async for row in aio_cursor] assert rows == [(1,)] async def test_execute_returns_self(self, aio_cursor): diff --git a/tests/pyathena/arrow/__init__.py b/tests/pyathena/arrow/__init__.py index 40a96afc..e69de29b 100644 --- a/tests/pyathena/arrow/__init__.py +++ b/tests/pyathena/arrow/__init__.py @@ -1 +0,0 @@ -# -*- coding: utf-8 -*- diff --git a/tests/pyathena/arrow/test_async_cursor.py 
b/tests/pyathena/arrow/test_async_cursor.py index b29d3804..a8fb770b 100644 --- a/tests/pyathena/arrow/test_async_cursor.py +++ b/tests/pyathena/arrow/test_async_cursor.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import contextlib import random import string diff --git a/tests/pyathena/arrow/test_cursor.py b/tests/pyathena/arrow/test_cursor.py index f2e218d3..417d5c16 100644 --- a/tests/pyathena/arrow/test_cursor.py +++ b/tests/pyathena/arrow/test_cursor.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import contextlib import random import string diff --git a/tests/pyathena/arrow/test_util.py b/tests/pyathena/arrow/test_util.py index dfb031aa..1060ecb5 100644 --- a/tests/pyathena/arrow/test_util.py +++ b/tests/pyathena/arrow/test_util.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import pyarrow as pa from pyathena.arrow.util import to_column_info diff --git a/tests/pyathena/conftest.py b/tests/pyathena/conftest.py index e9cabb02..6c0b07b5 100644 --- a/tests/pyathena/conftest.py +++ b/tests/pyathena/conftest.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import contextlib from io import BytesIO from pathlib import Path @@ -120,7 +119,7 @@ def create_async_engine(**kwargs): def _cursor(cursor_class, request): if not hasattr(request, "param"): - setattr(request, "param", {}) # noqa: B010 + request.param = {} with ( contextlib.closing( connect(schema_name=ENV.schema, cursor_class=cursor_class, **request.param) @@ -219,7 +218,7 @@ def spark_cursor(request): from pyathena.spark.cursor import SparkCursor if not hasattr(request, "param"): - setattr(request, "param", {}) # noqa: B010 + request.param = {} request.param.update({"work_group": ENV.spark_work_group}) yield from _cursor(SparkCursor, request) @@ -229,7 +228,7 @@ def async_spark_cursor(request): from pyathena.spark.async_cursor import AsyncSparkCursor if not hasattr(request, "param"): - setattr(request, "param", {}) # noqa: B010 + request.param = {} request.param.update({"work_group": ENV.spark_work_group}) yield from 
_cursor(AsyncSparkCursor, request) @@ -237,7 +236,7 @@ def async_spark_cursor(request): @pytest.fixture def engine(request): if not hasattr(request, "param"): - setattr(request, "param", {}) # noqa: B010 + request.param = {} engine_ = create_engine(**request.param) try: with contextlib.closing(engine_.connect()) as conn: @@ -249,7 +248,7 @@ def engine(request): @pytest.fixture async def async_engine(request): if not hasattr(request, "param"): - setattr(request, "param", {}) # noqa: B010 + request.param = {} engine_ = create_async_engine(**request.param) try: async with engine_.connect() as conn: diff --git a/tests/pyathena/filesystem/__init__.py b/tests/pyathena/filesystem/__init__.py index 40a96afc..e69de29b 100644 --- a/tests/pyathena/filesystem/__init__.py +++ b/tests/pyathena/filesystem/__init__.py @@ -1 +0,0 @@ -# -*- coding: utf-8 -*- diff --git a/tests/pyathena/filesystem/test_s3.py b/tests/pyathena/filesystem/test_s3.py index f737eb17..676060da 100644 --- a/tests/pyathena/filesystem/test_s3.py +++ b/tests/pyathena/filesystem/test_s3.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import os import tempfile import time @@ -104,35 +103,35 @@ def test_parse_path(self): assert actual[2] == "12345abcde" def test_parse_path_invalid(self): - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Invalid S3 path format"): S3FileSystem.parse_path("http://bucket") - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Invalid S3 path format"): S3FileSystem.parse_path("s3://bucket?") - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Invalid S3 path format"): S3FileSystem.parse_path("s3://bucket?foo=bar") - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Invalid S3 path format"): S3FileSystem.parse_path("s3://bucket/path/to/obj?foo=bar") - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Invalid S3 path format"): S3FileSystem.parse_path("s3a://bucket?") - with 
pytest.raises(ValueError): + with pytest.raises(ValueError, match="Invalid S3 path format"): S3FileSystem.parse_path("s3a://bucket?foo=bar") - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Invalid S3 path format"): S3FileSystem.parse_path("s3a://bucket/path/to/obj?foo=bar") @pytest.fixture(scope="class") def fs(self, request): if not hasattr(request, "param"): - setattr(request, "param", {}) # noqa: B010 + request.param = {} return S3FileSystem(connect(), **request.param) @pytest.mark.parametrize( - ["fs", "start", "end", "target_data"], + ("fs", "start", "end", "target_data"), list( chain( *[ @@ -165,7 +164,7 @@ def test_read(self, fs, start, end, target_data): assert data == target_data, data @pytest.mark.parametrize( - ["base", "exp"], + ("base", "exp"), [ # TODO: Comment out some test cases because of the high cost of AWS for testing. (1, 2**10), @@ -196,7 +195,7 @@ def test_write(self, fs, base, exp): assert actual == data @pytest.mark.parametrize( - ["base", "exp"], + ("base", "exp"), [ # TODO: Comment out some test cases because of the high cost of AWS for testing. (1, 2**10), @@ -457,7 +456,7 @@ def test_glob(self, fs): assert fs._strip_protocol(path) in fs.glob(f"{dir_}/nested/test_*") assert fs._strip_protocol(path) in fs.glob(f"{dir_}/*/*") - with pytest.raises(ValueError): + with pytest.raises(ValueError): # noqa: PT011 fs.glob("*") def test_exists_bucket(self, fs): @@ -539,12 +538,12 @@ def test_touch(self, fs): with fs.open(path, "wb") as f: f.write(b"data") assert fs.size(path) == 4 - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Cannot touch"): fs.touch(path, truncate=False) assert fs.size(path) == 4 @pytest.mark.parametrize( - ["base", "exp"], + ("base", "exp"), [ # TODO: Comment out some test cases because of the high cost of AWS for testing. 
(1, 2**10), @@ -585,7 +584,7 @@ def test_cat_ranges(self, fs): assert fs.cat_file(path, start=-5) == data[-5:] @pytest.mark.parametrize( - ["base", "exp"], + ("base", "exp"), [ # TODO: Comment out some test cases because of the high cost of AWS for testing. (1, 2**10), @@ -618,7 +617,7 @@ def test_put(self, fs, base, exp): assert fs.cat(rpath) == tmp.read() @pytest.mark.parametrize( - ["base", "exp"], + ("base", "exp"), [ # TODO: Comment out some test cases because of the high cost of AWS for testing. (1, 2**10), @@ -654,7 +653,7 @@ def test_put_with_callback(self, fs, base, exp): assert callback.value == callback.size @pytest.mark.parametrize( - ["base", "exp"], + ("base", "exp"), [ # TODO: Comment out some test cases because of the high cost of AWS for testing. (1, 2**10), @@ -786,10 +785,10 @@ def test_pandas_read_csv(self): assert [(row["col"],) for _, row in df.iterrows()] == [(123456789,)] @pytest.mark.parametrize( - ["line_count"], + "line_count", [ # TODO: Comment out some test cases because of the high cost of AWS for testing. - (1 * (2**20),), # Generates files of about 2 MB. + 1 * 2**20, # Generates files of about 2 MB. 
# (2 * (2**20),), # 4MB # (3 * (2**20),), # 6MB # (4 * (2**20),), # 8MB @@ -803,7 +802,7 @@ def test_pandas_write_csv(self, line_count): with tempfile.NamedTemporaryFile("w+t") as tmp: tmp.write("col1") tmp.write("\n") - for _ in range(0, line_count): + for _ in range(line_count): tmp.write("a") tmp.write("\n") tmp.flush() @@ -822,7 +821,7 @@ def test_pandas_write_csv(self, line_count): class TestS3File: @pytest.mark.parametrize( - ["objects", "target"], + ("objects", "target"), [ ([(0, b"")], b""), ([(0, b"foo")], b"foo"), @@ -837,7 +836,7 @@ def test_merge_objects(self, objects, target): assert S3File._merge_objects(objects) == target @pytest.mark.parametrize( - ["start", "end", "max_workers", "worker_block_size", "ranges"], + ("start", "end", "max_workers", "worker_block_size", "ranges"), [ (42, 1337, 1, 999, [(42, 1337)]), # single worker (42, 1337, 2, 999, [(42, 42 + 999), (42 + 999, 1337)]), # more workers diff --git a/tests/pyathena/filesystem/test_s3_async.py b/tests/pyathena/filesystem/test_s3_async.py index 5bd312bd..ba2e07b8 100644 --- a/tests/pyathena/filesystem/test_s3_async.py +++ b/tests/pyathena/filesystem/test_s3_async.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import os import tempfile import time @@ -109,35 +108,35 @@ def test_parse_path(self): assert actual[2] == "12345abcde" def test_parse_path_invalid(self): - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Invalid S3 path format"): AioS3FileSystem.parse_path("http://bucket") - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Invalid S3 path format"): AioS3FileSystem.parse_path("s3://bucket?") - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Invalid S3 path format"): AioS3FileSystem.parse_path("s3://bucket?foo=bar") - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Invalid S3 path format"): AioS3FileSystem.parse_path("s3://bucket/path/to/obj?foo=bar") - with pytest.raises(ValueError): + with 
pytest.raises(ValueError, match="Invalid S3 path format"): AioS3FileSystem.parse_path("s3a://bucket?") - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Invalid S3 path format"): AioS3FileSystem.parse_path("s3a://bucket?foo=bar") - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Invalid S3 path format"): AioS3FileSystem.parse_path("s3a://bucket/path/to/obj?foo=bar") @pytest.fixture(scope="class") def fs(self, request): if not hasattr(request, "param"): - setattr(request, "param", {}) # noqa: B010 + request.param = {} return AioS3FileSystem(connection=connect(), **request.param) @pytest.mark.parametrize( - ["fs", "start", "end", "target_data"], + ("fs", "start", "end", "target_data"), list( chain( *[ @@ -170,7 +169,7 @@ def test_read(self, fs, start, end, target_data): assert data == target_data, data @pytest.mark.parametrize( - ["base", "exp"], + ("base", "exp"), [ (1, 2**10), (1, 2**20), @@ -190,7 +189,7 @@ def test_write(self, fs, base, exp): assert actual == data @pytest.mark.parametrize( - ["base", "exp"], + ("base", "exp"), [ (1, 2**10), (1, 2**20), @@ -472,7 +471,7 @@ async def test_glob(self, fs): assert fs._strip_protocol(path) in fs.glob(f"{dir_}/nested/test_*") assert fs._strip_protocol(path) in fs.glob(f"{dir_}/*/*") - with pytest.raises(ValueError): + with pytest.raises(ValueError): # noqa: PT011 fs.glob("*") @pytest.mark.asyncio @@ -568,14 +567,14 @@ async def test_touch(self, fs): f.write(b"data") info = await fs._info(path, refresh=True) assert info.size == 4 - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Cannot touch"): await fs._touch(path, truncate=False) info = await fs._info(path, refresh=True) assert info.size == 4 @pytest.mark.asyncio @pytest.mark.parametrize( - ["base", "exp"], + ("base", "exp"), [ (1, 2**10), (1, 2**20), @@ -607,7 +606,7 @@ async def test_cat_ranges(self, fs): @pytest.mark.asyncio @pytest.mark.parametrize( - ["base", "exp"], + ("base", "exp"), [ (1, 
2**10), (1, 2**20), @@ -629,7 +628,7 @@ async def test_put(self, fs, base, exp): @pytest.mark.asyncio @pytest.mark.parametrize( - ["base", "exp"], + ("base", "exp"), [ (1, 2**10), (1, 2**20), @@ -654,7 +653,7 @@ async def test_put_with_callback(self, fs, base, exp): @pytest.mark.asyncio @pytest.mark.parametrize( - ["base", "exp"], + ("base", "exp"), [ (1, 2**10), (1, 2**20), @@ -764,9 +763,9 @@ def test_pandas_read_csv(self): assert [(row["col"],) for _, row in df.iterrows()] == [(123456789,)] @pytest.mark.parametrize( - ["line_count"], + "line_count", [ - (1 * (2**20),), + 1 * 2**20, ], ) def test_pandas_write_csv(self, line_count): @@ -775,7 +774,7 @@ def test_pandas_write_csv(self, line_count): with tempfile.NamedTemporaryFile("w+t") as tmp: tmp.write("col1") tmp.write("\n") - for _ in range(0, line_count): + for _ in range(line_count): tmp.write("a") tmp.write("\n") tmp.flush() @@ -816,7 +815,7 @@ def test_invalidate_cache(self, fs): class TestAioS3File: @pytest.mark.parametrize( - ["objects", "target"], + ("objects", "target"), [ ([(0, b"")], b""), ([(0, b"foo")], b"foo"), @@ -831,7 +830,7 @@ def test_merge_objects(self, objects, target): assert S3File._merge_objects(objects) == target @pytest.mark.parametrize( - ["start", "end", "max_workers", "worker_block_size", "ranges"], + ("start", "end", "max_workers", "worker_block_size", "ranges"), [ (42, 1337, 1, 999, [(42, 1337)]), # single worker (42, 1337, 2, 999, [(42, 42 + 999), (42 + 999, 1337)]), # more workers diff --git a/tests/pyathena/filesystem/test_s3_object.py b/tests/pyathena/filesystem/test_s3_object.py index a551c830..54beb8a1 100644 --- a/tests/pyathena/filesystem/test_s3_object.py +++ b/tests/pyathena/filesystem/test_s3_object.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- from datetime import datetime from pyathena.filesystem.s3_object import ( diff --git a/tests/pyathena/pandas/__init__.py b/tests/pyathena/pandas/__init__.py index 40a96afc..e69de29b 100644 --- a/tests/pyathena/pandas/__init__.py +++ 
b/tests/pyathena/pandas/__init__.py @@ -1 +0,0 @@ -# -*- coding: utf-8 -*- diff --git a/tests/pyathena/pandas/test_async_cursor.py b/tests/pyathena/pandas/test_async_cursor.py index 8d954bf3..28aa0b95 100644 --- a/tests/pyathena/pandas/test_async_cursor.py +++ b/tests/pyathena/pandas/test_async_cursor.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import contextlib import math import random @@ -21,7 +20,7 @@ class TestAsyncPandasCursor: @pytest.mark.parametrize( - "async_pandas_cursor, parquet_engine, chunksize", + ("async_pandas_cursor", "parquet_engine", "chunksize"), [ ({"cursor_kwargs": {"unload": False}}, "auto", None), ({"cursor_kwargs": {"unload": False}}, "auto", 1_000), @@ -42,7 +41,7 @@ def test_fetchone(self, async_pandas_cursor, parquet_engine, chunksize): assert result_set.fetchone() is None @pytest.mark.parametrize( - "async_pandas_cursor, parquet_engine, chunksize", + ("async_pandas_cursor", "parquet_engine", "chunksize"), [ ({"cursor_kwargs": {"unload": False}}, "auto", None), ({"cursor_kwargs": {"unload": False}}, "auto", 1_000), @@ -63,7 +62,7 @@ def test_fetchmany(self, async_pandas_cursor, parquet_engine, chunksize): assert len(result_set.fetchmany(10)) == 5 @pytest.mark.parametrize( - "async_pandas_cursor, chunksize", + ("async_pandas_cursor", "chunksize"), [ ({}, None), ({}, 1_000), @@ -84,7 +83,7 @@ def test_get_chunk(self, async_pandas_cursor, chunksize): assert len(df) == 15 @pytest.mark.parametrize( - "async_pandas_cursor, parquet_engine, chunksize", + ("async_pandas_cursor", "parquet_engine", "chunksize"), [ ({"cursor_kwargs": {"unload": False}}, "auto", None), ({"cursor_kwargs": {"unload": False}}, "auto", 1_000), @@ -107,7 +106,7 @@ def test_fetchall(self, async_pandas_cursor, parquet_engine, chunksize): assert result_set.fetchall() == [(i,) for i in range(10000)] @pytest.mark.parametrize( - "async_pandas_cursor, parquet_engine, chunksize", + ("async_pandas_cursor", "parquet_engine", "chunksize"), [ ({"cursor_kwargs": {"unload": False}}, 
"auto", None), ({"cursor_kwargs": {"unload": False}}, "auto", 1_000), @@ -128,7 +127,7 @@ def test_iterator(self, async_pandas_cursor, parquet_engine, chunksize): pytest.raises(StopIteration, result_set.__next__) @pytest.mark.parametrize( - "async_pandas_cursor, parquet_engine, chunksize", + ("async_pandas_cursor", "parquet_engine", "chunksize"), [ ({"cursor_kwargs": {"unload": False}}, "auto", None), ({"cursor_kwargs": {"unload": False}}, "auto", 1_000), @@ -158,7 +157,7 @@ def test_invalid_arraysize(self, async_pandas_cursor): async_pandas_cursor.arraysize = -1 @pytest.mark.parametrize( - "async_pandas_cursor, parquet_engine, chunksize", + ("async_pandas_cursor", "parquet_engine", "chunksize"), [ ({"cursor_kwargs": {"unload": False}}, "auto", None), ({"cursor_kwargs": {"unload": False}}, "auto", 1_000), @@ -186,7 +185,7 @@ def test_description(self, async_pandas_cursor, parquet_engine, chunksize): assert result_set.description == description @pytest.mark.parametrize( - "async_pandas_cursor, parquet_engine, chunksize", + ("async_pandas_cursor", "parquet_engine", "chunksize"), [ ({"cursor_kwargs": {"unload": False}}, "auto", None), ({"cursor_kwargs": {"unload": False}}, "auto", 1_000), @@ -270,7 +269,7 @@ def test_query_execution(self, async_pandas_cursor, parquet_engine, chunksize): assert result_set.effective_engine_version == query_execution.effective_engine_version @pytest.mark.parametrize( - "async_pandas_cursor, parquet_engine, chunksize", + ("async_pandas_cursor", "parquet_engine", "chunksize"), [ ({"cursor_kwargs": {"unload": False}}, "auto", None), ({"cursor_kwargs": {"unload": False}}, "auto", 1_000), @@ -297,7 +296,7 @@ def test_poll(self, async_pandas_cursor, parquet_engine, chunksize): ] @pytest.mark.parametrize( - "async_pandas_cursor, parquet_engine, chunksize", + ("async_pandas_cursor", "parquet_engine", "chunksize"), [ ({"cursor_kwargs": {"unload": False}}, "auto", None), ({"cursor_kwargs": {"unload": False}}, "auto", 1_000), @@ -319,7 +318,7 @@ 
def test_bad_query(self, async_pandas_cursor, parquet_engine, chunksize): assert result_set.error_type is not None @pytest.mark.parametrize( - "async_pandas_cursor, parquet_engine, chunksize", + ("async_pandas_cursor", "parquet_engine", "chunksize"), [ ({"cursor_kwargs": {"unload": False}}, "auto", None), ({"cursor_kwargs": {"unload": False}}, "auto", 1_000), @@ -343,7 +342,7 @@ def test_as_pandas(self, async_pandas_cursor, parquet_engine, chunksize): assert [(row["number_of_rows"],) for _, row in df.iterrows()] == [(1,)] @pytest.mark.parametrize( - "async_pandas_cursor, parquet_engine, chunksize", + ("async_pandas_cursor", "parquet_engine", "chunksize"), [ ({"cursor_kwargs": {"unload": False}}, "auto", None), ({"cursor_kwargs": {"unload": False}}, "auto", 1_000), @@ -393,7 +392,7 @@ def test_no_ops(self): conn.close() @pytest.mark.parametrize( - "async_pandas_cursor, parquet_engine, chunksize", + ("async_pandas_cursor", "parquet_engine", "chunksize"), [ ({"cursor_kwargs": {"unload": False}}, "auto", None), ({"cursor_kwargs": {"unload": False}}, "auto", 1_000), @@ -412,7 +411,7 @@ def test_show_columns(self, async_pandas_cursor, parquet_engine, chunksize): assert result_set.fetchall() == [("number_of_rows ",)] @pytest.mark.parametrize( - "async_pandas_cursor, parquet_engine, chunksize", + ("async_pandas_cursor", "parquet_engine", "chunksize"), [ ({"cursor_kwargs": {"unload": False}}, "auto", None), ({"cursor_kwargs": {"unload": False}}, "auto", 1_000), @@ -444,7 +443,7 @@ def test_empty_result_ddl(self, async_pandas_cursor, parquet_engine, chunksize): assert df.shape[1] == 0 @pytest.mark.parametrize( - "async_pandas_cursor, parquet_engine", + ("async_pandas_cursor", "parquet_engine"), [ ({"cursor_kwargs": {"unload": True}}, "auto"), ({"cursor_kwargs": {"unload": True}}, "pyarrow"), @@ -463,7 +462,7 @@ def test_empty_result_dml_unload(self, async_pandas_cursor, parquet_engine): assert df.shape[1] == 0 @pytest.mark.parametrize( - "async_pandas_cursor, 
parquet_engine", + ("async_pandas_cursor", "parquet_engine"), [ ({"cursor_kwargs": {"unload": False}}, "auto"), ({"cursor_kwargs": {"unload": True}}, "auto"), @@ -494,7 +493,7 @@ def test_integer_na_values(self, async_pandas_cursor, parquet_engine): assert rows == [(1, 2), (1, pd.NA), (pd.NA, pd.NA)] @pytest.mark.parametrize( - "async_pandas_cursor, parquet_engine", + ("async_pandas_cursor", "parquet_engine"), [ ({"cursor_kwargs": {"unload": False}}, "auto"), ({"cursor_kwargs": {"unload": True}}, "auto"), @@ -514,7 +513,7 @@ def test_float_na_values(self, async_pandas_cursor, parquet_engine): np.testing.assert_equal(rows, [(0.33,), (np.nan,)]) @pytest.mark.parametrize( - "async_pandas_cursor, parquet_engine", + ("async_pandas_cursor", "parquet_engine"), [ ({"cursor_kwargs": {"unload": False}}, "auto"), ({"cursor_kwargs": {"unload": True}}, "auto"), @@ -534,7 +533,7 @@ def test_boolean_na_values(self, async_pandas_cursor, parquet_engine): assert rows == [(True, False), (False, None), (None, None)] @pytest.mark.parametrize( - "async_pandas_cursor, parquet_engine", + ("async_pandas_cursor", "parquet_engine"), [ ({"cursor_kwargs": {"unload": False}}, "auto"), ({"cursor_kwargs": {"unload": True}}, "auto"), @@ -553,7 +552,7 @@ def test_not_skip_blank_lines(self, async_pandas_cursor, parquet_engine): assert len(result_set.fetchall()) == 2 @pytest.mark.parametrize( - "async_pandas_cursor, parquet_engine", + ("async_pandas_cursor", "parquet_engine"), [ ({"cursor_kwargs": {"unload": False}}, "auto"), ({"cursor_kwargs": {"unload": True}}, "auto"), @@ -599,7 +598,7 @@ def test_empty_and_null_string(self, async_pandas_cursor, parquet_engine): ] @pytest.mark.parametrize( - "async_pandas_cursor, parquet_engine", + ("async_pandas_cursor", "parquet_engine"), [ ({"cursor_kwargs": {"unload": False}}, "auto"), ({"cursor_kwargs": {"unload": True}}, "auto"), diff --git a/tests/pyathena/pandas/test_cursor.py b/tests/pyathena/pandas/test_cursor.py index 403cd62e..944dc025 100644 --- 
a/tests/pyathena/pandas/test_cursor.py +++ b/tests/pyathena/pandas/test_cursor.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import contextlib import math import random @@ -22,7 +21,7 @@ class TestPandasCursor: @pytest.mark.parametrize( - "pandas_cursor, parquet_engine, chunksize", + ("pandas_cursor", "parquet_engine", "chunksize"), [ ({"cursor_kwargs": {"unload": False}}, "auto", None), ({"cursor_kwargs": {"unload": False}}, "auto", 1_000), @@ -40,7 +39,7 @@ def test_fetchone(self, pandas_cursor, parquet_engine, chunksize): assert pandas_cursor.fetchone() is None @pytest.mark.parametrize( - "pandas_cursor, parquet_engine, chunksize", + ("pandas_cursor", "parquet_engine", "chunksize"), [ ({"cursor_kwargs": {"unload": False}}, "auto", None), ({"cursor_kwargs": {"unload": False}}, "auto", 1_000), @@ -60,7 +59,7 @@ def test_fetchmany(self, pandas_cursor, parquet_engine, chunksize): assert len(pandas_cursor.fetchmany(10)) == 5 @pytest.mark.parametrize( - "pandas_cursor, chunksize", + ("pandas_cursor", "chunksize"), [ ({}, None), ({}, 1_000), @@ -81,7 +80,7 @@ def test_get_chunk(self, pandas_cursor, chunksize): assert len(df) == 15 @pytest.mark.parametrize( - "pandas_cursor, parquet_engine, chunksize", + ("pandas_cursor", "parquet_engine", "chunksize"), [ ({"cursor_kwargs": {"unload": False}}, "auto", None), ({"cursor_kwargs": {"unload": False}}, "auto", 1_000), @@ -98,7 +97,7 @@ def test_fetchall(self, pandas_cursor, parquet_engine, chunksize): assert pandas_cursor.fetchall() == [(i,) for i in range(10000)] @pytest.mark.parametrize( - "pandas_cursor, parquet_engine, chunksize", + ("pandas_cursor", "parquet_engine", "chunksize"), [ ({"cursor_kwargs": {"unload": False}}, "auto", None), ({"cursor_kwargs": {"unload": False}}, "auto", 1_000), @@ -114,7 +113,7 @@ def test_iterator(self, pandas_cursor, parquet_engine, chunksize): pytest.raises(StopIteration, pandas_cursor.__next__) @pytest.mark.parametrize( - "pandas_cursor, parquet_engine, chunksize", + ("pandas_cursor", 
"parquet_engine", "chunksize"), [ ({"cursor_kwargs": {"unload": False}}, "auto", None), ({"cursor_kwargs": {"unload": False}}, "auto", 1_000), @@ -143,7 +142,7 @@ def test_invalid_arraysize(self, pandas_cursor): pandas_cursor.arraysize = -1 @pytest.mark.parametrize( - "pandas_cursor, chunksize", + ("pandas_cursor", "chunksize"), [ ({}, None), ({}, 1_000), @@ -224,7 +223,7 @@ def test_complex(self, pandas_cursor, chunksize): ] @pytest.mark.parametrize( - "pandas_cursor, parquet_engine", + ("pandas_cursor", "parquet_engine"), [ ({"cursor_kwargs": {"unload": True}}, "auto"), ({"cursor_kwargs": {"unload": True}}, "pyarrow"), @@ -341,7 +340,7 @@ def test_fetch_no_data(self, pandas_cursor): pytest.raises(ProgrammingError, pandas_cursor.as_pandas) @pytest.mark.parametrize( - "pandas_cursor, parquet_engine, chunksize", + ("pandas_cursor", "parquet_engine", "chunksize"), [ ({"cursor_kwargs": {"unload": False}}, "auto", None), ({"cursor_kwargs": {"unload": False}}, "auto", 1_000), @@ -362,7 +361,7 @@ def test_as_pandas(self, pandas_cursor, parquet_engine, chunksize): assert [(row["number_of_rows"],) for _, row in df.iterrows()] == [(1,)] @pytest.mark.parametrize( - "pandas_cursor, parquet_engine, chunksize", + ("pandas_cursor", "parquet_engine", "chunksize"), [ ({"cursor_kwargs": {"unload": False}}, "auto", None), ({"cursor_kwargs": {"unload": False}}, "auto", 1_000), @@ -383,7 +382,7 @@ def test_many_as_pandas(self, pandas_cursor, parquet_engine, chunksize): assert [(row["a"],) for _, row in df.iterrows()] == [(i,) for i in range(10000)] @pytest.mark.parametrize( - "pandas_cursor, chunksize", + ("pandas_cursor", "chunksize"), [ ({}, None), ({}, 1_000), @@ -513,7 +512,7 @@ def test_complex_as_pandas(self, pandas_cursor, chunksize): ] @pytest.mark.parametrize( - "pandas_cursor, parquet_engine", + ("pandas_cursor", "parquet_engine"), [ ({"cursor_kwargs": {"unload": True}}, "auto"), ({"cursor_kwargs": {"unload": True}}, "pyarrow"), @@ -814,7 +813,7 @@ def 
test_get_csv_engine_explicit_specification(self): mock_opt.assert_called_once() @pytest.mark.parametrize( - "pandas_cursor, parquet_engine, chunksize", + ("pandas_cursor", "parquet_engine", "chunksize"), [ ({"cursor_kwargs": {"unload": False}}, "auto", None), ({"cursor_kwargs": {"unload": False}}, "auto", 1_000), @@ -830,7 +829,7 @@ def test_show_columns(self, pandas_cursor, parquet_engine, chunksize): assert pandas_cursor.fetchall() == [("number_of_rows ",)] @pytest.mark.parametrize( - "pandas_cursor, parquet_engine, chunksize", + ("pandas_cursor", "parquet_engine", "chunksize"), [ ({"cursor_kwargs": {"unload": False}}, "auto", None), ({"cursor_kwargs": {"unload": False}}, "auto", 1_000), @@ -861,7 +860,7 @@ def test_empty_result_ddl(self, pandas_cursor, parquet_engine, chunksize): assert df.shape[1] == 0 @pytest.mark.parametrize( - "pandas_cursor, parquet_engine", + ("pandas_cursor", "parquet_engine"), [ ({"cursor_kwargs": {"unload": True}}, "auto"), ({"cursor_kwargs": {"unload": True}}, "pyarrow"), @@ -896,7 +895,7 @@ def test_ctas(self, pandas_cursor): assert pandas_cursor.fetchone() is None @pytest.mark.parametrize( - "pandas_cursor, parquet_engine", + ("pandas_cursor", "parquet_engine"), [ ({"cursor_kwargs": {"unload": False}}, "auto"), ({"cursor_kwargs": {"unload": True}}, "auto"), @@ -926,7 +925,7 @@ def test_integer_na_values(self, pandas_cursor, parquet_engine): assert rows == [(1, 2), (1, pd.NA), (pd.NA, pd.NA)] @pytest.mark.parametrize( - "pandas_cursor, parquet_engine", + ("pandas_cursor", "parquet_engine"), [ ({"cursor_kwargs": {"unload": False}}, "auto"), ({"cursor_kwargs": {"unload": True}}, "auto"), @@ -945,7 +944,7 @@ def test_float_na_values(self, pandas_cursor, parquet_engine): np.testing.assert_equal(rows, [(0.33,), (np.nan,)]) @pytest.mark.parametrize( - "pandas_cursor, parquet_engine", + ("pandas_cursor", "parquet_engine"), [ ({"cursor_kwargs": {"unload": False}}, "auto"), ({"cursor_kwargs": {"unload": True}}, "auto"), @@ -964,7 +963,7 @@ def 
test_boolean_na_values(self, pandas_cursor, parquet_engine): assert rows == [(True, False), (False, None), (None, None)] @pytest.mark.parametrize( - "pandas_cursor, parquet_engine", + ("pandas_cursor", "parquet_engine"), [ ({"cursor_kwargs": {"unload": False}}, "auto"), ({"cursor_kwargs": {"unload": True}}, "auto"), @@ -985,7 +984,7 @@ def test_executemany(self, pandas_cursor, parquet_engine): assert sorted(pandas_cursor.fetchall()) == list(rows) @pytest.mark.parametrize( - "pandas_cursor, parquet_engine", + ("pandas_cursor", "parquet_engine"), [ ({"cursor_kwargs": {"unload": False}}, "auto"), ({"cursor_kwargs": {"unload": True}}, "auto"), @@ -1002,7 +1001,7 @@ def test_executemany_fetch(self, pandas_cursor, parquet_engine): pytest.raises(ProgrammingError, pandas_cursor.as_pandas) @pytest.mark.parametrize( - "pandas_cursor, parquet_engine", + ("pandas_cursor", "parquet_engine"), [ ({"cursor_kwargs": {"unload": False}}, "auto"), ({"cursor_kwargs": {"unload": True}}, "auto"), @@ -1020,7 +1019,7 @@ def test_not_skip_blank_lines(self, pandas_cursor, parquet_engine): assert len(pandas_cursor.fetchall()) == 2 @pytest.mark.parametrize( - "pandas_cursor, parquet_engine", + ("pandas_cursor", "parquet_engine"), [ ({"cursor_kwargs": {"unload": False}}, "auto"), ({"cursor_kwargs": {"unload": True}}, "auto"), @@ -1064,7 +1063,7 @@ def test_null_vs_empty_string(self, pandas_cursor, parquet_engine): ] @pytest.mark.parametrize( - "pandas_cursor, parquet_engine", + ("pandas_cursor", "parquet_engine"), [ ({"cursor_kwargs": {"unload": False}}, "auto"), ({"cursor_kwargs": {"unload": True}}, "auto"), diff --git a/tests/pyathena/pandas/test_util.py b/tests/pyathena/pandas/test_util.py index 972dd192..a34d2341 100644 --- a/tests/pyathena/pandas/test_util.py +++ b/tests/pyathena/pandas/test_util.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import textwrap import uuid from datetime import date, datetime @@ -32,9 +31,9 @@ def test_get_chunks(): assert list(get_chunks(pd.DataFrame())) == [] # 
invalid - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Chunk size argument must be greater than zero"): list(get_chunks(df, chunksize=0)) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Chunk size argument must be greater than zero"): list(get_chunks(df, chunksize=-1)) @@ -48,7 +47,7 @@ def test_reset_index(): assert list(df.columns) == ["__index__", "a"] df = pd.DataFrame({"a": [1, 2, 3, 4, 5]}) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Duplicate name"): reset_index(df, index_label="a") @@ -167,7 +166,7 @@ def test_generate_ddl(): "col_timestamp": [datetime(2020, 1, 1, 0, 0, 0)], "col_date": [date(2020, 12, 31)], "col_timedelta": [np.timedelta64(1, "D")], - "col_binary": ["foobar".encode()], + "col_binary": [b"foobar"], } ) # Explicitly specify column order @@ -306,12 +305,12 @@ def test_generate_ddl(): # complex df = pd.DataFrame({"col_complex": np.complex128([1.0, 2.0, 3.0, 4.0, 5.0])}) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="is not supported"): generate_ddl(df, "test_table", "s3://bucket/path/to/") # time df = pd.DataFrame({"col_time": [datetime(2020, 1, 1, 0, 0, 0).time()]}, index=["i"]) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="is not supported"): generate_ddl(df, "test_table", "s3://bucket/path/to/") @@ -326,7 +325,7 @@ def test_to_sql(cursor): "col_boolean": np.bool_([True]), "col_timestamp": [datetime(2020, 1, 1, 0, 0, 0)], "col_date": [date(2020, 12, 31)], - "col_binary": "foobar".encode(), + "col_binary": b"foobar", } ) # Explicitly specify column order @@ -387,7 +386,7 @@ def test_to_sql(cursor): True, datetime(2020, 1, 1, 0, 0, 0), date(2020, 12, 31), - "foobar".encode(), + b"foobar", ) ] assert [(d[0], d[1]) for d in cursor.description] == [ @@ -423,7 +422,7 @@ def test_to_sql(cursor): True, datetime(2020, 1, 1, 0, 0, 0), date(2020, 12, 31), - "foobar".encode(), + b"foobar", ), ( 1, @@ -434,7 +433,7 @@ 
def test_to_sql(cursor): True, datetime(2020, 1, 1, 0, 0, 0), date(2020, 12, 31), - "foobar".encode(), + b"foobar", ), ] @@ -521,7 +520,7 @@ def test_to_sql_invalid_args(cursor): table_name = f"""to_sql_{str(uuid.uuid4()).replace("-", "")}""" location = f"{ENV.s3_staging_dir}{ENV.schema}/{table_name}/" # invalid if_exists - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="is not valid for if_exists"): to_sql( df, table_name, @@ -532,7 +531,7 @@ def test_to_sql_invalid_args(cursor): compression="snappy", ) # invalid compression - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="is not valid for compression"): to_sql( df, table_name, @@ -544,7 +543,7 @@ def test_to_sql_invalid_args(cursor): ) # invalid partition key (None) - with pytest.raises(ValueError) as exc_info: + with pytest.raises(ValueError, match="is None, no data will be written") as exc_info: to_sql( df, table_name, @@ -561,7 +560,7 @@ def test_to_sql_invalid_args(cursor): ) # invalid partition key value (None) df_with_none = pd.DataFrame({"col_int": np.int32([1]), "partition_key": [None]}) - with pytest.raises(ValueError) as exc_info: + with pytest.raises(ValueError, match="contains None values") as exc_info: to_sql( df_with_none, table_name, diff --git a/tests/pyathena/polars/__init__.py b/tests/pyathena/polars/__init__.py index 40a96afc..e69de29b 100644 --- a/tests/pyathena/polars/__init__.py +++ b/tests/pyathena/polars/__init__.py @@ -1 +0,0 @@ -# -*- coding: utf-8 -*- diff --git a/tests/pyathena/polars/test_async_cursor.py b/tests/pyathena/polars/test_async_cursor.py index fd9d61ba..dde6c431 100644 --- a/tests/pyathena/polars/test_async_cursor.py +++ b/tests/pyathena/polars/test_async_cursor.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import contextlib import random import string diff --git a/tests/pyathena/polars/test_cursor.py b/tests/pyathena/polars/test_cursor.py index 7522120c..6e4b482e 100644 --- a/tests/pyathena/polars/test_cursor.py +++ 
b/tests/pyathena/polars/test_cursor.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import contextlib import random import string diff --git a/tests/pyathena/s3fs/__init__.py b/tests/pyathena/s3fs/__init__.py index 40a96afc..e69de29b 100644 --- a/tests/pyathena/s3fs/__init__.py +++ b/tests/pyathena/s3fs/__init__.py @@ -1 +0,0 @@ -# -*- coding: utf-8 -*- diff --git a/tests/pyathena/s3fs/test_async_cursor.py b/tests/pyathena/s3fs/test_async_cursor.py index b9b37053..f7046fa3 100644 --- a/tests/pyathena/s3fs/test_async_cursor.py +++ b/tests/pyathena/s3fs/test_async_cursor.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import contextlib import random import time diff --git a/tests/pyathena/s3fs/test_cursor.py b/tests/pyathena/s3fs/test_cursor.py index c2c78d5e..fd8f6572 100644 --- a/tests/pyathena/s3fs/test_cursor.py +++ b/tests/pyathena/s3fs/test_cursor.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import contextlib import random import string @@ -427,7 +426,7 @@ def test_empty_string_with_athena_reader(self): assert result == ("",) @pytest.mark.parametrize( - "csv_reader, expected_empty", + ("csv_reader", "expected_empty"), [ (DefaultCSVReader, None), # DefaultCSVReader: empty string becomes None (AthenaCSVReader, ""), # AthenaCSVReader: empty string is preserved diff --git a/tests/pyathena/s3fs/test_reader.py b/tests/pyathena/s3fs/test_reader.py index 58383525..8c3050ff 100644 --- a/tests/pyathena/s3fs/test_reader.py +++ b/tests/pyathena/s3fs/test_reader.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- from collections.abc import Iterator from io import StringIO diff --git a/tests/pyathena/spark/__init__.py b/tests/pyathena/spark/__init__.py index 40a96afc..e69de29b 100644 --- a/tests/pyathena/spark/__init__.py +++ b/tests/pyathena/spark/__init__.py @@ -1 +0,0 @@ -# -*- coding: utf-8 -*- diff --git a/tests/pyathena/spark/test_async_cursor.py b/tests/pyathena/spark/test_async_cursor.py index 440d7186..e921bdc3 100644 --- a/tests/pyathena/spark/test_async_cursor.py +++ 
b/tests/pyathena/spark/test_async_cursor.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import textwrap import time from random import randint diff --git a/tests/pyathena/spark/test_spark_cursor.py b/tests/pyathena/spark/test_spark_cursor.py index b00d1d29..12723ff1 100644 --- a/tests/pyathena/spark/test_spark_cursor.py +++ b/tests/pyathena/spark/test_spark_cursor.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import textwrap import time from concurrent.futures import ThreadPoolExecutor diff --git a/tests/pyathena/sqlalchemy/__init__.py b/tests/pyathena/sqlalchemy/__init__.py index 40a96afc..e69de29b 100644 --- a/tests/pyathena/sqlalchemy/__init__.py +++ b/tests/pyathena/sqlalchemy/__init__.py @@ -1 +0,0 @@ -# -*- coding: utf-8 -*- diff --git a/tests/pyathena/sqlalchemy/test_base.py b/tests/pyathena/sqlalchemy/test_base.py index 73e3442c..7c4c44fc 100644 --- a/tests/pyathena/sqlalchemy/test_base.py +++ b/tests/pyathena/sqlalchemy/test_base.py @@ -1,10 +1,8 @@ -# -*- coding: utf-8 -*- import re import textwrap import uuid from datetime import date, datetime from decimal import Decimal -from typing import List from urllib.parse import quote_plus import numpy as np @@ -2390,7 +2388,7 @@ def test_create_table_with_complex_nested_types(self, engine): def test_sqlalchemy_execute_with_execution_options_callback(self, engine): """Test callback functionality through SQLAlchemy execution_options.""" engine, conn = engine - query_ids: List[str] = [] + query_ids: list[str] = [] def callback_function(query_id: str) -> None: query_ids.append(query_id) @@ -2414,7 +2412,7 @@ def callback_function(query_id: str) -> None: def test_sqlalchemy_connection_level_callback(self, engine): """Test connection-level callback functionality through SQLAlchemy engine creation.""" - query_ids: List[str] = [] + query_ids: list[str] = [] def callback_function(query_id: str) -> None: query_ids.append(query_id) diff --git a/tests/pyathena/sqlalchemy/test_compiler.py 
b/tests/pyathena/sqlalchemy/test_compiler.py index c91b7554..5fc20a65 100644 --- a/tests/pyathena/sqlalchemy/test_compiler.py +++ b/tests/pyathena/sqlalchemy/test_compiler.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - from unittest.mock import Mock import pytest @@ -189,18 +187,18 @@ def test_visit_filter_func_nested_access(self): def test_visit_filter_func_wrong_argument_count(self): """Test filter() function with wrong number of arguments.""" # Test error when wrong number of arguments provided + stmt = select(func.filter(self.test_table.c.numbers)) with pytest.raises( exc.CompileError, match="filter\\(\\) function expects exactly 2 arguments" ): - stmt = select(func.filter(self.test_table.c.numbers)) stmt.compile(dialect=self.dialect) + stmt = select( + func.filter(self.test_table.c.numbers, literal("x -> x > 0"), literal("extra_arg")) + ) with pytest.raises( exc.CompileError, match="filter\\(\\) function expects exactly 2 arguments" ): - stmt = select( - func.filter(self.test_table.c.numbers, literal("x -> x > 0"), literal("extra_arg")) - ) stmt.compile(dialect=self.dialect) def test_visit_filter_func_integration_example(self): diff --git a/tests/pyathena/sqlalchemy/test_types.py b/tests/pyathena/sqlalchemy/test_types.py index 73d34d63..a3fb89cf 100644 --- a/tests/pyathena/sqlalchemy/test_types.py +++ b/tests/pyathena/sqlalchemy/test_types.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import pytest from sqlalchemy import Integer, String, types from sqlalchemy.sql import sqltypes @@ -46,7 +45,7 @@ def test_python_type(self): assert struct_type.python_type is dict def test_invalid_field_specification(self): - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Invalid field specification"): AthenaStruct(123) # Invalid field type def test_visit_name(self): diff --git a/tests/pyathena/test_async_cursor.py b/tests/pyathena/test_async_cursor.py index 9433ecf0..f04d15f6 100644 --- a/tests/pyathena/test_async_cursor.py +++ 
b/tests/pyathena/test_async_cursor.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import contextlib import time from datetime import datetime diff --git a/tests/pyathena/test_converter.py b/tests/pyathena/test_converter.py index ede3f9a7..fbf78884 100644 --- a/tests/pyathena/test_converter.py +++ b/tests/pyathena/test_converter.py @@ -1,12 +1,10 @@ -# -*- coding: utf-8 -*- - import pytest from pyathena.converter import DefaultTypeConverter, _to_array, _to_struct @pytest.mark.parametrize( - "input_value,expected", + ("input_value", "expected"), [ (None, None), ( @@ -28,7 +26,7 @@ def test_to_struct_json_formats(input_value, expected): @pytest.mark.parametrize( - "input_value,expected", + ("input_value", "expected"), [ ("{a=1, b=2}", {"a": 1, "b": 2}), ("{}", {}), @@ -47,7 +45,7 @@ def test_to_struct_athena_native_formats(input_value, expected): @pytest.mark.parametrize( - "input_value,expected", + ("input_value", "expected"), [ # Single level nesting (Issue #627) ( @@ -153,7 +151,7 @@ def test_to_array_athena_unnamed_struct_elements(): @pytest.mark.parametrize( - "input_value,expected", + ("input_value", "expected"), [ # Array with nested structs (Issue #627) ( @@ -196,7 +194,7 @@ def test_to_struct_non_dict_json(input_value): @pytest.mark.parametrize( - "input_value,expected", + ("input_value", "expected"), [ (None, None), ( @@ -227,7 +225,7 @@ def test_to_array_json_formats(input_value, expected): @pytest.mark.parametrize( - "input_value,expected", + ("input_value", "expected"), [ ("[1, 2, 3]", [1, 2, 3]), ("[]", []), @@ -248,7 +246,7 @@ def test_to_array_athena_native_formats(input_value, expected): @pytest.mark.parametrize( - "input_value,expected", + ("input_value", "expected"), [ ("[ARRAY[1, 2], ARRAY[3, 4]]", None), # Nested arrays (native format) ("[[1, 2], [3, 4]]", [[1, 2], [3, 4]]), # Nested arrays (JSON format - parseable) @@ -292,7 +290,7 @@ def test_to_array_invalid_formats(input_value): class TestDefaultTypeConverter: @pytest.mark.parametrize( - 
"input_value,expected", + ("input_value", "expected"), [ ('{"name": "Alice", "age": 25}', {"name": "Alice", "age": 25}), (None, None), @@ -308,7 +306,7 @@ def test_struct_conversion(self, input_value, expected): assert result == expected @pytest.mark.parametrize( - "input_value,expected", + ("input_value", "expected"), [ ("[1, 2, 3]", [1, 2, 3]), ('["a", "b", "c"]', ["a", "b", "c"]), diff --git a/tests/pyathena/test_cursor.py b/tests/pyathena/test_cursor.py index 9964baa0..27c34950 100644 --- a/tests/pyathena/test_cursor.py +++ b/tests/pyathena/test_cursor.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import contextlib import json import logging @@ -82,7 +81,7 @@ def test_iterator(self, cursor): def test_cache_size(self, cursor): # To test caching, we need to make sure the query is unique, otherwise # we might accidentally pick up the cache results from another CI run. - query = f"SELECT * FROM one_row -- {str(datetime.now(timezone.utc))}" + query = f"SELECT * FROM one_row -- {datetime.now(timezone.utc)!s}" cursor.execute(query) first_query_id = cursor.query_id @@ -115,7 +114,7 @@ def test_cache_size_with_work_group(self, cursor): assert third_query_id in [first_query_id, second_query_id] def test_cache_expiration_time(self, cursor): - query = f"SELECT * FROM one_row -- {str(datetime.now(timezone.utc))}" + query = f"SELECT * FROM one_row -- {datetime.now(timezone.utc)!s}" cursor.execute(query) query_id_1 = cursor.query_id @@ -131,7 +130,7 @@ def test_cache_expiration_time(self, cursor): def test_cache_expiration_time_with_cache_size(self, cursor): # Cache miss - query = f"SELECT * FROM one_row -- {str(datetime.now(timezone.utc))}" + query = f"SELECT * FROM one_row -- {datetime.now(timezone.utc)!s}" cursor.execute(query) query_id_1 = cursor.query_id @@ -148,7 +147,7 @@ def test_cache_expiration_time_with_cache_size(self, cursor): assert query_id_3 not in [query_id_1, query_id_2] # Cache miss - query = f"SELECT * FROM one_row -- {str(datetime.now(timezone.utc))}" + 
query = f"SELECT * FROM one_row -- {datetime.now(timezone.utc)!s}" cursor.execute(query) query_id_4 = cursor.query_id @@ -166,7 +165,7 @@ def test_cache_expiration_time_with_cache_size(self, cursor): assert query_id_6 not in [query_id_4, query_id_5] # Cache hit - query = f"SELECT * FROM one_row -- {str(datetime.now(timezone.utc))}" + query = f"SELECT * FROM one_row -- {datetime.now(timezone.utc)!s}" cursor.execute(query) query_id_7 = cursor.query_id @@ -978,7 +977,7 @@ class TestComplexDataTypes: """Test complex data types (STRUCT, ARRAY, MAP) with actual Athena queries.""" @pytest.mark.parametrize( - "query,description", + ("query", "description"), [ ("SELECT ROW('John', 30) AS simple_struct", "simple_struct"), ( @@ -997,11 +996,11 @@ class TestComplexDataTypes: ) def test_struct_types(self, cursor, query, description): """Test various STRUCT type scenarios to understand Athena's behavior.""" - _logger.info(f"=== STRUCT Type Test: {description} ===") + _logger.info("=== STRUCT Type Test: %s ===", description) cursor.execute(query) result = cursor.fetchone() struct_value = result[0] - _logger.info(f"{description}: {struct_value!r} (type: {type(struct_value).__name__})") + _logger.info("%s: %r (type: %s)", description, struct_value, type(struct_value).__name__) # Validate struct value and converter behavior assert struct_value is not None, f"STRUCT value should not be None for {description}" @@ -1009,7 +1008,7 @@ def test_struct_types(self, cursor, query, description): # Test struct conversion behavior if isinstance(struct_value, str): converted = _to_struct(struct_value) - _logger.info(f"{description}: Converted {struct_value!r} -> {converted!r}") + _logger.info("%s: Converted %r -> %r", description, struct_value, converted) # For string structs, conversion should succeed or return None for complex cases if converted is not None: assert isinstance(converted, dict), ( @@ -1017,15 +1016,18 @@ def test_struct_types(self, cursor, query, description): ) elif 
isinstance(struct_value, dict): # Already converted by the cursor converter - _logger.info(f"{description}: Already converted to dict: {struct_value!r}") + _logger.info("%s: Already converted to dict: %r", description, struct_value) else: # Log unexpected types for debugging but don't fail _logger.warning( - f"{description}: Unexpected type {type(struct_value).__name__}: {struct_value!r}" + "%s: Unexpected type %s: %r", + description, + type(struct_value).__name__, + struct_value, ) @pytest.mark.parametrize( - "query,description", + ("query", "description"), [ ("SELECT ARRAY[1, 2, 3, 4, 5] AS simple_array", "simple_array"), ("SELECT ARRAY['apple', 'banana', 'cherry'] AS string_array", "string_array"), @@ -1044,18 +1046,18 @@ def test_struct_types(self, cursor, query, description): ) def test_array_types(self, cursor, query, description): """Test various ARRAY type scenarios.""" - _logger.info(f"=== ARRAY Type Test: {description} ===") + _logger.info("=== ARRAY Type Test: %s ===", description) cursor.execute(query) result = cursor.fetchone() array_value = result[0] - _logger.info(f"{description}: {array_value!r} (type: {type(array_value).__name__})") + _logger.info("%s: %r (type: %s)", description, array_value, type(array_value).__name__) # Validate array value assert array_value is not None, f"ARRAY value should not be None for {description}" - _logger.info(f"{description}: Array value type {type(array_value).__name__}") + _logger.info("%s: Array value type %s", description, type(array_value).__name__) @pytest.mark.parametrize( - "query,description", + ("query", "description"), [ ( "SELECT MAP(ARRAY[1, 2, 3], ARRAY['one', 'two', 'three']) AS simple_map", @@ -1084,11 +1086,11 @@ def test_array_types(self, cursor, query, description): ) def test_map_types(self, cursor, query, description): """Test various MAP type scenarios.""" - _logger.info(f"=== MAP Type Test: {description} ===") + _logger.info("=== MAP Type Test: %s ===", description) cursor.execute(query) result 
= cursor.fetchone() map_value = result[0] - _logger.info(f"{description}: {map_value!r} (type: {type(map_value).__name__})") + _logger.info("%s: %r (type: %s)", description, map_value, type(map_value).__name__) # Validate map value and converter behavior assert map_value is not None, f"MAP value should not be None for {description}" @@ -1098,26 +1100,29 @@ def test_map_types(self, cursor, query, description): # For complex MAP structures, string is expected (JSON or native format) if "ROW(" in map_value or "ARRAY[" in map_value: # Complex structure, expect string format - _logger.info(f"{description}: Complex MAP kept as string: {map_value!r}") + _logger.info("%s: Complex MAP kept as string: %r", description, map_value) else: # Simple MAP, try conversion converted = _to_map(map_value) - _logger.info(f"{description}: Converted {map_value!r} -> {converted!r}") + _logger.info("%s: Converted %r -> %r", description, map_value, converted) if converted is not None: assert isinstance(converted, dict), ( f"Converted map should be dict for {description}" ) elif isinstance(map_value, dict): # Already converted by the cursor converter - _logger.info(f"{description}: Already converted to dict: {map_value!r}") + _logger.info("%s: Already converted to dict: %r", description, map_value) else: # Log unexpected types for debugging but don't fail _logger.warning( - f"{description}: Unexpected type {type(map_value).__name__}: {map_value!r}" + "%s: Unexpected type %s: %r", + description, + type(map_value).__name__, + map_value, ) @pytest.mark.parametrize( - "query,description", + ("query", "description"), [ ( "SELECT CAST(ROW(ARRAY[1, 2, 3], MAP(ARRAY['a', 'b'], ARRAY[1, 2])) AS JSON) " @@ -1138,28 +1143,28 @@ def test_map_types(self, cursor, query, description): ) def test_complex_combinations(self, cursor, query, description): """Test complex combinations of data types.""" - _logger.info(f"=== Complex Combination Test: {description} ===") + _logger.info("=== Complex Combination Test: 
%s ===", description) cursor.execute(query) result = cursor.fetchone() complex_value = result[0] - _logger.info(f"{description}: {complex_value!r} (type: {type(complex_value).__name__})") + _logger.info("%s: %r (type: %s)", description, complex_value, type(complex_value).__name__) # For JSON cast results, expect string values that can be parsed as JSON if isinstance(complex_value, str): try: # Test that the JSON string can be parsed parsed = json.loads(complex_value) - _logger.info(f" Parsed JSON: {parsed!r}") + _logger.info(" Parsed JSON: %r", parsed) assert parsed is not None, f"Parsed JSON should not be None for {description}" except json.JSONDecodeError as e: raise AssertionError(f"JSON parsing failed for {description}: {e}") from e else: # If it's not a string, it should still be a valid value (not None) assert complex_value is not None, f"Complex value should not be None for {description}" - _logger.info(f"{description}: Complex value type {type(complex_value).__name__}") + _logger.info("%s: Complex value type %s", description, type(complex_value).__name__) @pytest.mark.parametrize( - "query,description", + ("query", "description"), [ ("SELECT ARRAY[1, 2, 3, 4, 5] AS simple_array", "simple_array"), ("SELECT ARRAY['apple', 'banana', 'cherry'] AS string_array", "string_array"), @@ -1188,11 +1193,11 @@ def test_complex_combinations(self, cursor, query, description): ) def test_array_types_basic(self, cursor, query, description): """Test basic ARRAY type scenarios.""" - _logger.info(f"=== ARRAY Type Test: {description} ===") + _logger.info("=== ARRAY Type Test: %s ===", description) cursor.execute(query) result = cursor.fetchone() array_value = result[0] - _logger.info(f"{description}: {array_value!r} (type: {type(array_value).__name__})") + _logger.info("%s: %r (type: %s)", description, array_value, type(array_value).__name__) # Validate array value assert array_value is not None or description == "empty_array", ( @@ -1202,10 +1207,10 @@ def 
test_array_types_basic(self, cursor, query, description): assert array_value == [], f"Empty array should be [] for {description}" else: assert isinstance(array_value, list), f"ARRAY value should be list for {description}" - _logger.info(f"{description}: Array value type {type(array_value).__name__}") + _logger.info("%s: Array value type %s", description, type(array_value).__name__) @pytest.mark.parametrize( - "query,description", + ("query", "description"), [ ( "SELECT ARRAY[ROW(1, 'Alice'), ROW(2, 'Bob'), ROW(3, 'Charlie')] AS struct_array", @@ -1225,11 +1230,11 @@ def test_array_types_basic(self, cursor, query, description): ) def test_array_types_with_structs(self, cursor, query, description): """Test ARRAY types containing STRUCT elements.""" - _logger.info(f"=== ARRAY with STRUCT Test: {description} ===") + _logger.info("=== ARRAY with STRUCT Test: %s ===", description) cursor.execute(query) result = cursor.fetchone() array_value = result[0] - _logger.info(f"{description}: {array_value!r} (type: {type(array_value).__name__})") + _logger.info("%s: %r (type: %s)", description, array_value, type(array_value).__name__) # Validate array value assert array_value is not None, f"ARRAY value should not be None for {description}" @@ -1241,10 +1246,10 @@ def test_array_types_with_structs(self, cursor, query, description): assert isinstance(first_element, dict), ( f"First array element should be dict (struct) for {description}" ) - _logger.info(f"{description}: First element: {first_element!r}") + _logger.info("%s: First element: %r", description, first_element) @pytest.mark.parametrize( - "query,description", + ("query", "description"), [ ( "SELECT CAST(ARRAY[1, 2, 3] AS JSON) AS arr_json", @@ -1262,21 +1267,21 @@ def test_array_types_with_structs(self, cursor, query, description): ) def test_array_types_json_cast(self, cursor, query, description): """Test ARRAY types with JSON casting.""" - _logger.info(f"=== ARRAY JSON Cast Test: {description} ===") + _logger.info("=== 
ARRAY JSON Cast Test: %s ===", description) cursor.execute(query) result = cursor.fetchone() array_value = result[0] - _logger.info(f"{description}: {array_value!r} (type: {type(array_value).__name__})") + _logger.info("%s: %r (type: %s)", description, array_value, type(array_value).__name__) # Validate array value assert array_value is not None, f"ARRAY value should not be None for {description}" assert isinstance(array_value, list), ( f"JSON cast ARRAY value should be list for {description}" ) - _logger.info(f"{description}: JSON cast array type {type(array_value).__name__}") + _logger.info("%s: JSON cast array type %s", description, type(array_value).__name__) @pytest.mark.parametrize( - "query,description", + ("query", "description"), [ ("SELECT CARDINALITY(ARRAY[1, 2, 3, 4, 5]) AS array_size", "array_size"), ("SELECT ARRAY[10, 20, 30, 40][2] AS array_element", "array_element"), @@ -1286,12 +1291,15 @@ def test_array_types_json_cast(self, cursor, query, description): ) def test_array_operations(self, cursor, query, description): """Test ARRAY operations and functions.""" - _logger.info(f"=== ARRAY Operation Test: {description} ===") + _logger.info("=== ARRAY Operation Test: %s ===", description) cursor.execute(query) result = cursor.fetchone() operation_result = result[0] _logger.info( - f"{description}: {operation_result!r} (type: {type(operation_result).__name__})" + "%s: %r (type: %s)", + description, + operation_result, + type(operation_result).__name__, ) # Validate operation result @@ -1319,7 +1327,7 @@ def test_array_converter_behavior(self, cursor): cursor.execute("SELECT ARRAY[1, 2, 3] AS simple") result = cursor.fetchone() simple_array = result[0] - _logger.info(f"Simple array: {simple_array!r}") + _logger.info("Simple array: %r", simple_array) assert simple_array == [1, 2, 3] # Test array with struct conversion @@ -1329,7 +1337,7 @@ def test_array_converter_behavior(self, cursor): ) result = cursor.fetchone() struct_array = result[0] - 
_logger.info(f"Struct array: {struct_array!r}") + _logger.info("Struct array: %r", struct_array) assert isinstance(struct_array, list) assert len(struct_array) == 2 assert isinstance(struct_array[0], dict) @@ -1346,7 +1354,7 @@ def test_array_converter_behavior(self, cursor): for test_input, expected in test_cases: result = _to_array(test_input) - _logger.info(f"Converter test: {test_input!r} -> {result!r}") + _logger.info("Converter test: %r -> %r", test_input, result) assert result == expected, ( f"Converter failed for {test_input}: expected {expected}, got {result}" ) diff --git a/tests/pyathena/test_formatter.py b/tests/pyathena/test_formatter.py index 13d7951b..631a1163 100644 --- a/tests/pyathena/test_formatter.py +++ b/tests/pyathena/test_formatter.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import textwrap from datetime import date, datetime from decimal import Decimal diff --git a/tests/pyathena/test_model.py b/tests/pyathena/test_model.py index a1c73188..628da672 100644 --- a/tests/pyathena/test_model.py +++ b/tests/pyathena/test_model.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- import copy from datetime import datetime diff --git a/tests/pyathena/test_util.py b/tests/pyathena/test_util.py index d8e285ef..409aa206 100644 --- a/tests/pyathena/test_util.py +++ b/tests/pyathena/test_util.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- from typing import Any import pytest diff --git a/tests/pyathena/util.py b/tests/pyathena/util.py index 7c8650fe..092301b9 100644 --- a/tests/pyathena/util.py +++ b/tests/pyathena/util.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- from pathlib import Path from jinja2 import Environment, FileSystemLoader diff --git a/tests/sqlalchemy/__init__.py b/tests/sqlalchemy/__init__.py index fd74e829..9a9dea77 100644 --- a/tests/sqlalchemy/__init__.py +++ b/tests/sqlalchemy/__init__.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- from sqlalchemy.dialects import registry registry.register("awsathena", "pyathena.sqlalchemy.base", "AthenaDialect") 
diff --git a/tests/sqlalchemy/conftest.py b/tests/sqlalchemy/conftest.py index ca0f0b26..b7056916 100644 --- a/tests/sqlalchemy/conftest.py +++ b/tests/sqlalchemy/conftest.py @@ -1,8 +1,7 @@ -# -*- coding: utf-8 -*- import contextlib from urllib.parse import quote_plus -from sqlalchemy.testing.plugin.pytestplugin import * # noqa +from sqlalchemy.testing.plugin.pytestplugin import * # noqa: F403 from sqlalchemy.testing.plugin.pytestplugin import ( pytest_sessionstart as sqlalchemy_pytest_sessionstart, ) diff --git a/tests/sqlalchemy/test_suite.py b/tests/sqlalchemy/test_suite.py index 1e2c59bc..4ffdd9cc 100644 --- a/tests/sqlalchemy/test_suite.py +++ b/tests/sqlalchemy/test_suite.py @@ -1,6 +1,5 @@ -# -*- coding: utf-8 -*- import pytest -from sqlalchemy.testing.suite import * # noqa +from sqlalchemy.testing.suite import * # noqa: F403 from sqlalchemy.testing.suite import FetchLimitOffsetTest as _FetchLimitOffsetTest from sqlalchemy.testing.suite import HasTableTest as _HasTableTest from sqlalchemy.testing.suite import InsertBehaviorTest as _InsertBehaviorTest @@ -8,25 +7,25 @@ from sqlalchemy.testing.suite import StringTest as _StringTest from sqlalchemy.testing.suite import TrueDivTest as _TrueDivTest -del BinaryTest # noqa -del ComponentReflectionTest # noqa -del ComponentReflectionTestExtra # noqa -del CompositeKeyReflectionTest # noqa -del CTETest # noqa -del DateTimeMicrosecondsTest # noqa -del DifficultParametersTest # noqa -del DistinctOnTest # noqa -del HasIndexTest # noqa -del IdentityAutoincrementTest # noqa -del JoinTest # noqa -del LongNameBlowoutTest # noqa -del QuotedNameArgumentTest # noqa -del RowCountTest # noqa -del SimpleUpdateDeleteTest # noqa -del TimeMicrosecondsTest # noqa -del TimeTest # noqa -del TimestampMicrosecondsTest # noqa -del UuidTest # noqa +del BinaryTest # noqa: F821 +del ComponentReflectionTest # noqa: F821 +del ComponentReflectionTestExtra # noqa: F821 +del CompositeKeyReflectionTest # noqa: F821 +del CTETest # noqa: F821 +del 
DateTimeMicrosecondsTest # noqa: F821 +del DifficultParametersTest # noqa: F821 +del DistinctOnTest # noqa: F821 +del HasIndexTest # noqa: F821 +del IdentityAutoincrementTest # noqa: F821 +del JoinTest # noqa: F821 +del LongNameBlowoutTest # noqa: F821 +del QuotedNameArgumentTest # noqa: F821 +del RowCountTest # noqa: F821 +del SimpleUpdateDeleteTest # noqa: F821 +del TimeMicrosecondsTest # noqa: F821 +del TimeTest # noqa: F821 +del TimestampMicrosecondsTest # noqa: F821 +del UuidTest # noqa: F821 class HasTableTest(_HasTableTest): From 46d2feb660c22898b842ade426d8cfa25791d2d5 Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Sun, 22 Feb 2026 21:43:34 +0900 Subject: [PATCH 2/2] Rename make fmt to make format Co-Authored-By: Claude Opus 4.6 --- CLAUDE.md | 2 +- Makefile | 4 ++-- docs/testing.md | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 54bc4452..c2725412 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -16,7 +16,7 @@ PyAthena is a Python DB API 2.0 (PEP 249) compliant client for Amazon Athena. Se ### Code Quality — Always Run Before Committing ```bash -make fmt # Auto-fix formatting and imports +make format # Auto-fix formatting and imports make lint # Lint + format check + mypy ``` diff --git a/Makefile b/Makefile index 465e8431..593bb0ae 100644 --- a/Makefile +++ b/Makefile @@ -1,8 +1,8 @@ RUFF_VERSION := 0.14.14 TOX_VERSION := 4.34.1 -.PHONY: fmt -fmt: +.PHONY: format +format: # TODO: https://github.com/astral-sh/uv/issues/5903 uvx ruff@$(RUFF_VERSION) check --select I --fix . uvx ruff@$(RUFF_VERSION) format . diff --git a/docs/testing.md b/docs/testing.md index 43fe5794..1bc2865f 100644 --- a/docs/testing.md +++ b/docs/testing.md @@ -51,7 +51,7 @@ The code formatting uses [ruff](https://github.com/astral-sh/ruff). ### Appy format ```bash -$ make fmt +$ make format ``` ### Lint and check format