diff --git a/pyathena/aio/arrow/cursor.py b/pyathena/aio/arrow/cursor.py index cb7b9c2f..b9c7d6f1 100644 --- a/pyathena/aio/arrow/cursor.py +++ b/pyathena/aio/arrow/cursor.py @@ -13,7 +13,7 @@ from pyathena.arrow.result_set import AthenaArrowResultSet from pyathena.common import CursorIterator from pyathena.error import OperationalError, ProgrammingError -from pyathena.model import AthenaCompression, AthenaFileFormat, AthenaQueryExecution +from pyathena.model import AthenaQueryExecution if TYPE_CHECKING: import polars as pl @@ -110,18 +110,7 @@ async def execute( # type: ignore[override] Self reference for method chaining. """ self._reset_state() - if self._unload: - s3_staging_dir = s3_staging_dir if s3_staging_dir else self._s3_staging_dir - if not s3_staging_dir: - raise ProgrammingError("If the unload option is used, s3_staging_dir is required.") - operation, unload_location = self._formatter.wrap_unload( - operation, - s3_staging_dir=s3_staging_dir, - format_=AthenaFileFormat.FILE_FORMAT_PARQUET, - compression=AthenaCompression.COMPRESSION_SNAPPY, - ) - else: - unload_location = None + operation, unload_location = self._prepare_unload(operation, s3_staging_dir) self.query_id = await self._execute( operation, parameters=parameters, diff --git a/pyathena/aio/pandas/cursor.py b/pyathena/aio/pandas/cursor.py index fbee2625..f9686e6b 100644 --- a/pyathena/aio/pandas/cursor.py +++ b/pyathena/aio/pandas/cursor.py @@ -18,7 +18,7 @@ from pyathena.aio.common import WithAsyncFetch from pyathena.common import CursorIterator from pyathena.error import OperationalError, ProgrammingError -from pyathena.model import AthenaCompression, AthenaFileFormat, AthenaQueryExecution +from pyathena.model import AthenaQueryExecution from pyathena.pandas.converter import ( DefaultPandasTypeConverter, DefaultPandasUnloadTypeConverter, @@ -133,18 +133,7 @@ async def execute( # type: ignore[override] Self reference for method chaining. """ self._reset_state() - if self._unload: - s3_staging_dir = s3_staging_dir if s3_staging_dir else self._s3_staging_dir - if not s3_staging_dir: - raise ProgrammingError("If the unload option is used, s3_staging_dir is required.") - operation, unload_location = self._formatter.wrap_unload( - operation, - s3_staging_dir=s3_staging_dir, - format_=AthenaFileFormat.FILE_FORMAT_PARQUET, - compression=AthenaCompression.COMPRESSION_SNAPPY, - ) - else: - unload_location = None + operation, unload_location = self._prepare_unload(operation, s3_staging_dir) self.query_id = await self._execute( operation, parameters=parameters, diff --git a/pyathena/aio/polars/cursor.py b/pyathena/aio/polars/cursor.py index 1f897f9d..5b048f2b 100644 --- a/pyathena/aio/polars/cursor.py +++ b/pyathena/aio/polars/cursor.py @@ -9,7 +9,7 @@ from pyathena.aio.common import WithAsyncFetch from pyathena.common import CursorIterator from pyathena.error import OperationalError, ProgrammingError -from pyathena.model import AthenaCompression, AthenaFileFormat, AthenaQueryExecution +from pyathena.model import AthenaQueryExecution from pyathena.polars.converter import ( DefaultPolarsTypeConverter, DefaultPolarsUnloadTypeConverter, @@ -115,18 +115,7 @@ async def execute( # type: ignore[override] Self reference for method chaining. """ self._reset_state() - if self._unload: - s3_staging_dir = s3_staging_dir if s3_staging_dir else self._s3_staging_dir - if not s3_staging_dir: - raise ProgrammingError("If the unload option is used, s3_staging_dir is required.") - operation, unload_location = self._formatter.wrap_unload( - operation, - s3_staging_dir=s3_staging_dir, - format_=AthenaFileFormat.FILE_FORMAT_PARQUET, - compression=AthenaCompression.COMPRESSION_SNAPPY, - ) - else: - unload_location = None + operation, unload_location = self._prepare_unload(operation, s3_staging_dir) self.query_id = await self._execute( operation, parameters=parameters, diff --git a/pyathena/arrow/async_cursor.py b/pyathena/arrow/async_cursor.py index 875d3cc3..32b2215c 100644 --- a/pyathena/arrow/async_cursor.py +++ b/pyathena/arrow/async_cursor.py @@ -14,7 +14,7 @@ from pyathena.arrow.result_set import AthenaArrowResultSet from pyathena.async_cursor import AsyncCursor from pyathena.common import CursorIterator -from pyathena.model import AthenaCompression, AthenaFileFormat, AthenaQueryExecution +from pyathena.model import AthenaQueryExecution _logger = logging.getLogger(__name__) # type: ignore @@ -182,18 +182,7 @@ def execute( paramstyle: Optional[str] = None, **kwargs, ) -> Tuple[str, "Future[Union[AthenaArrowResultSet, Any]]"]: - if self._unload: - s3_staging_dir = s3_staging_dir if s3_staging_dir else self._s3_staging_dir - if not s3_staging_dir: - raise ProgrammingError("If the unload option is used, s3_staging_dir is required.") - operation, unload_location = self._formatter.wrap_unload( - operation, - s3_staging_dir=s3_staging_dir, - format_=AthenaFileFormat.FILE_FORMAT_PARQUET, - compression=AthenaCompression.COMPRESSION_SNAPPY, - ) - else: - unload_location = None + operation, unload_location = self._prepare_unload(operation, s3_staging_dir) query_id = self._execute( operation, parameters=parameters, diff --git a/pyathena/arrow/cursor.py b/pyathena/arrow/cursor.py index 514333be..026e0951 100644 --- a/pyathena/arrow/cursor.py +++ b/pyathena/arrow/cursor.py @@ -11,7 +11,7 @@ from pyathena.arrow.result_set import AthenaArrowResultSet from pyathena.common import CursorIterator from pyathena.error import OperationalError, ProgrammingError -from pyathena.model import AthenaCompression, AthenaFileFormat, AthenaQueryExecution +from pyathena.model import AthenaQueryExecution from pyathena.result_set import WithFetch if TYPE_CHECKING: @@ -166,18 +166,7 @@ def execute( >>> table = cursor.as_arrow() # Returns Apache Arrow Table """ self._reset_state() - if self._unload: - s3_staging_dir = s3_staging_dir if s3_staging_dir else self._s3_staging_dir - if not s3_staging_dir: - raise ProgrammingError("If the unload option is used, s3_staging_dir is required.") - operation, unload_location = self._formatter.wrap_unload( - operation, - s3_staging_dir=s3_staging_dir, - format_=AthenaFileFormat.FILE_FORMAT_PARQUET, - compression=AthenaCompression.COMPRESSION_SNAPPY, - ) - else: - unload_location = None + operation, unload_location = self._prepare_unload(operation, s3_staging_dir) self.query_id = self._execute( operation, parameters=parameters, diff --git a/pyathena/common.py b/pyathena/common.py index 2fc36f5b..f0688220 100644 --- a/pyathena/common.py +++ b/pyathena/common.py @@ -15,7 +15,9 @@ from pyathena.model import ( AthenaCalculationExecution, AthenaCalculationExecutionStatus, + AthenaCompression, AthenaDatabase, + AthenaFileFormat, AthenaQueryExecution, AthenaTableMetadata, ) @@ -652,6 +654,32 @@ def _prepare_query( _logger.debug(query) return query, execution_parameters + def _prepare_unload( + self, + operation: str, + s3_staging_dir: Optional[str], + ) -> Tuple[str, Optional[str]]: + """Wrap operation with UNLOAD if enabled. + + Args: + operation: SQL query string. + s3_staging_dir: S3 location for query results. + + Returns: + Tuple of (possibly-wrapped operation, unload_location or None). + """ + if not getattr(self, "_unload", False): + return operation, None + s3_staging_dir = s3_staging_dir if s3_staging_dir else self._s3_staging_dir + if not s3_staging_dir: + raise ProgrammingError("If the unload option is used, s3_staging_dir is required.") + return self._formatter.wrap_unload( + operation, + s3_staging_dir=s3_staging_dir, + format_=AthenaFileFormat.FILE_FORMAT_PARQUET, + compression=AthenaCompression.COMPRESSION_SNAPPY, + ) + def _execute( self, operation: str, diff --git a/pyathena/formatter.py b/pyathena/formatter.py index 1c1f6860..4053aead 100644 --- a/pyathena/formatter.py +++ b/pyathena/formatter.py @@ -8,7 +8,7 @@ from copy import deepcopy from datetime import date, datetime, timezone from decimal import Decimal -from typing import Any, Callable, Dict, Optional, Type +from typing import Any, Callable, Dict, Optional, Tuple, Type from pyathena.error import ProgrammingError from pyathena.model import AthenaCompression, AthenaFileFormat @@ -86,7 +86,7 @@ def wrap_unload( s3_staging_dir: str, format_: str = AthenaFileFormat.FILE_FORMAT_PARQUET, compression: str = AthenaCompression.COMPRESSION_SNAPPY, - ): + ) -> Tuple[str, Optional[str]]: """Wrap a SELECT query with UNLOAD statement for high-performance result retrieval. Transforms SELECT or WITH queries into UNLOAD statements that export results diff --git a/pyathena/pandas/async_cursor.py b/pyathena/pandas/async_cursor.py index 7fb87211..5df82c1c 100644 --- a/pyathena/pandas/async_cursor.py +++ b/pyathena/pandas/async_cursor.py @@ -9,7 +9,7 @@ from pyathena import ProgrammingError from pyathena.async_cursor import AsyncCursor from pyathena.common import CursorIterator -from pyathena.model import AthenaCompression, AthenaFileFormat, AthenaQueryExecution +from pyathena.model import AthenaQueryExecution from pyathena.pandas.converter import ( DefaultPandasTypeConverter, DefaultPandasUnloadTypeConverter, @@ -159,18 +159,7 @@ def execute( quoting: int = 1, **kwargs, ) -> Tuple[str, "Future[Union[AthenaPandasResultSet, Any]]"]: - if self._unload: - s3_staging_dir = s3_staging_dir if s3_staging_dir else self._s3_staging_dir - if not s3_staging_dir: - raise ProgrammingError("If the unload option is used, s3_staging_dir is required.") - operation, unload_location = self._formatter.wrap_unload( - operation, - s3_staging_dir=s3_staging_dir, - format_=AthenaFileFormat.FILE_FORMAT_PARQUET, - compression=AthenaCompression.COMPRESSION_SNAPPY, - ) - else: - unload_location = None + operation, unload_location = self._prepare_unload(operation, s3_staging_dir) query_id = self._execute( operation, parameters=parameters, diff --git a/pyathena/pandas/cursor.py b/pyathena/pandas/cursor.py index e7bcc455..b87ecd50 100644 --- a/pyathena/pandas/cursor.py +++ b/pyathena/pandas/cursor.py @@ -18,7 +18,7 @@ from pyathena.common import CursorIterator from pyathena.error import OperationalError, ProgrammingError -from pyathena.model import AthenaCompression, AthenaFileFormat, AthenaQueryExecution +from pyathena.model import AthenaQueryExecution from pyathena.pandas.converter import ( DefaultPandasTypeConverter, DefaultPandasUnloadTypeConverter, @@ -193,18 +193,7 @@ def execute( >>> df = cursor.fetchall() # Returns pandas DataFrame """ self._reset_state() - if self._unload: - s3_staging_dir = s3_staging_dir if s3_staging_dir else self._s3_staging_dir - if not s3_staging_dir: - raise ProgrammingError("If the unload option is used, s3_staging_dir is required.") - operation, unload_location = self._formatter.wrap_unload( - operation, - s3_staging_dir=s3_staging_dir, - format_=AthenaFileFormat.FILE_FORMAT_PARQUET, - compression=AthenaCompression.COMPRESSION_SNAPPY, - ) - else: - unload_location = None + operation, unload_location = self._prepare_unload(operation, s3_staging_dir) self.query_id = self._execute( operation, parameters=parameters, diff --git a/pyathena/polars/async_cursor.py b/pyathena/polars/async_cursor.py index 61e5432c..81f3625a 100644 --- a/pyathena/polars/async_cursor.py +++ b/pyathena/polars/async_cursor.py @@ -9,7 +9,7 @@ from pyathena import ProgrammingError from pyathena.async_cursor import AsyncCursor from pyathena.common import CursorIterator -from pyathena.model import AthenaCompression, AthenaFileFormat, AthenaQueryExecution +from pyathena.model import AthenaQueryExecution from pyathena.polars.converter import ( DefaultPolarsTypeConverter, DefaultPolarsUnloadTypeConverter, @@ -221,18 +221,7 @@ def execute( >>> result_set = future.result() >>> df = result_set.as_polars() # Returns Polars DataFrame """ - if self._unload: - s3_staging_dir = s3_staging_dir if s3_staging_dir else self._s3_staging_dir - if not s3_staging_dir: - raise ProgrammingError("If the unload option is used, s3_staging_dir is required.") - operation, unload_location = self._formatter.wrap_unload( - operation, - s3_staging_dir=s3_staging_dir, - format_=AthenaFileFormat.FILE_FORMAT_PARQUET, - compression=AthenaCompression.COMPRESSION_SNAPPY, - ) - else: - unload_location = None + operation, unload_location = self._prepare_unload(operation, s3_staging_dir) query_id = self._execute( operation, parameters=parameters, diff --git a/pyathena/polars/cursor.py b/pyathena/polars/cursor.py index 5ad08010..f3cb342a 100644 --- a/pyathena/polars/cursor.py +++ b/pyathena/polars/cursor.py @@ -17,7 +17,7 @@ from pyathena.common import CursorIterator from pyathena.error import OperationalError, ProgrammingError -from pyathena.model import AthenaCompression, AthenaFileFormat, AthenaQueryExecution +from pyathena.model import AthenaQueryExecution from pyathena.polars.converter import ( DefaultPolarsTypeConverter, DefaultPolarsUnloadTypeConverter, @@ -191,18 +191,7 @@ def execute( >>> df = cursor.as_polars() # Returns Polars DataFrame """ self._reset_state() - if self._unload: - s3_staging_dir = s3_staging_dir if s3_staging_dir else self._s3_staging_dir - if not s3_staging_dir: - raise ProgrammingError("If the unload option is used, s3_staging_dir is required.") - operation, unload_location = self._formatter.wrap_unload( - operation, - s3_staging_dir=s3_staging_dir, - format_=AthenaFileFormat.FILE_FORMAT_PARQUET, - compression=AthenaCompression.COMPRESSION_SNAPPY, - ) - else: - unload_location = None + operation, unload_location = self._prepare_unload(operation, s3_staging_dir) self.query_id = self._execute( operation, parameters=parameters,