From 2e0476dae412eb97f147e6c97e99d24f53be147f Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Sat, 17 Aug 2024 22:51:26 +0900 Subject: [PATCH 1/3] Support Athena parameterized queries when paramstyle is qmark (fix #545) --- pyathena/__init__.py | 6 ++---- pyathena/arrow/async_cursor.py | 6 ++++-- pyathena/arrow/cursor.py | 9 +++++++-- pyathena/async_cursor.py | 7 ++++++- pyathena/common.py | 22 ++++++++++++++++++---- pyathena/connection.py | 12 ++++-------- pyathena/cursor.py | 9 +++++++-- pyathena/filesystem/s3.py | 4 +++- pyathena/filesystem/s3_object.py | 6 +++--- pyathena/pandas/async_cursor.py | 4 +++- pyathena/pandas/cursor.py | 4 +++- pyathena/result_set.py | 6 +++--- pyathena/spark/async_cursor.py | 4 ++-- pyathena/spark/cursor.py | 2 +- 14 files changed, 66 insertions(+), 35 deletions(-) diff --git a/pyathena/__init__.py b/pyathena/__init__.py index 86c76111..2a0e2247 100644 --- a/pyathena/__init__.py +++ b/pyathena/__init__.py @@ -59,15 +59,13 @@ def __hash__(self): @overload -def connect(*args, cursor_class: None = ..., **kwargs) -> "Connection[Cursor]": - ... +def connect(*args, cursor_class: None = ..., **kwargs) -> "Connection[Cursor]": ... @overload def connect( *args, cursor_class: Type[ConnectionCursor], **kwargs -) -> "Connection[ConnectionCursor]": - ... +) -> "Connection[ConnectionCursor]": ... def connect(*args, **kwargs) -> "Connection[Any]": diff --git a/pyathena/arrow/async_cursor.py b/pyathena/arrow/async_cursor.py index b0ab1002..797aeef2 100644 --- a/pyathena/arrow/async_cursor.py +++ b/pyathena/arrow/async_cursor.py @@ -4,7 +4,7 @@ import logging from concurrent.futures import Future from multiprocessing import cpu_count -from typing import Any, Dict, Optional, Tuple, Union, cast +from typing import Any, Dict, List, Optional, Tuple, Union, cast from pyathena import ProgrammingError from pyathena.arrow.converter import ( @@ -96,13 +96,14 @@ def _collect_result_set( def execute( self, operation: str, - parameters: Optional[Dict[str, Any]] = None, + parameters: Optional[Union[Dict[str, Any], List[str]]] = None, work_group: Optional[str] = None, s3_staging_dir: Optional[str] = None, cache_size: Optional[int] = 0, cache_expiration_time: Optional[int] = 0, result_reuse_enable: Optional[bool] = None, result_reuse_minutes: Optional[int] = None, + paramstyle: Optional[str] = None, **kwargs, ) -> Tuple[str, "Future[Union[AthenaArrowResultSet, Any]]"]: if self._unload: @@ -125,6 +126,7 @@ def execute( cache_expiration_time=cache_expiration_time, result_reuse_enable=result_reuse_enable, result_reuse_minutes=result_reuse_minutes, + paramstyle=paramstyle, ) return ( query_id, diff --git a/pyathena/arrow/cursor.py b/pyathena/arrow/cursor.py index 38dc691f..42474cfc 100644 --- a/pyathena/arrow/cursor.py +++ b/pyathena/arrow/cursor.py @@ -99,13 +99,14 @@ def close(self) -> None: def execute( self, operation: str, - parameters: Optional[Dict[str, Any]] = None, + parameters: Optional[Union[Dict[str, Any], List[str]]] = None, work_group: Optional[str] = None, s3_staging_dir: Optional[str] = None, cache_size: Optional[int] = 0, cache_expiration_time: Optional[int] = 0, result_reuse_enable: Optional[bool] = None, result_reuse_minutes: Optional[int] = None, + paramstyle: Optional[str] = None, **kwargs, ) -> ArrowCursor: self._reset_state() @@ -129,6 +130,7 @@ def execute( cache_expiration_time=cache_expiration_time, result_reuse_enable=result_reuse_enable, result_reuse_minutes=result_reuse_minutes, + paramstyle=paramstyle, ) query_execution = cast(AthenaQueryExecution, self._poll(self.query_id)) if query_execution.state == AthenaQueryExecution.STATE_SUCCEEDED: @@ -147,7 +149,10 @@ def execute( return self def executemany( - self, operation: str, seq_of_parameters: List[Optional[Dict[str, Any]]], **kwargs + self, + operation: str, + seq_of_parameters: List[Optional[Union[Dict[str, Any], List[str]]]], + **kwargs, ) -> None: for parameters in seq_of_parameters: self.execute(operation, parameters, **kwargs) diff --git a/pyathena/async_cursor.py b/pyathena/async_cursor.py index 48fe74ca..d9f40f73 100644 --- a/pyathena/async_cursor.py +++ b/pyathena/async_cursor.py @@ -104,6 +104,7 @@ def execute( cache_expiration_time: Optional[int] = 0, result_reuse_enable: Optional[bool] = None, result_reuse_minutes: Optional[int] = None, + paramstyle: Optional[str] = None, **kwargs, ) -> Tuple[str, "Future[Union[AthenaResultSet, Any]]"]: query_id = self._execute( @@ -115,11 +116,15 @@ def execute( cache_expiration_time=cache_expiration_time, result_reuse_enable=result_reuse_enable, result_reuse_minutes=result_reuse_minutes, + paramstyle=paramstyle, ) return query_id, self._executor.submit(self._collect_result_set, query_id) def executemany( - self, operation: str, seq_of_parameters: List[Optional[Dict[str, Any]]], **kwargs + self, + operation: str, + seq_of_parameters: List[Optional[Union[Dict[str, Any], List[str]]]], + **kwargs, ) -> None: raise NotSupportedError diff --git a/pyathena/common.py b/pyathena/common.py index 64ce0b0e..3f600a98 100644 --- a/pyathena/common.py +++ b/pyathena/common.py @@ -8,6 +8,7 @@ from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast +import pyathena from pyathena.converter import Converter, DefaultTypeConverter from pyathena.error import DatabaseError, OperationalError, ProgrammingError from pyathena.formatter import Formatter @@ -144,6 +145,7 @@ def _build_start_query_execution_request( s3_staging_dir: Optional[str] = None, result_reuse_enable: Optional[bool] = None, result_reuse_minutes: Optional[int] = None, + execution_parameters: Optional[List[str]] = None, ) -> Dict[str, Any]: request: Dict[str, Any] = { "QueryString": query, @@ -177,6 +179,8 @@ def _build_start_query_execution_request( else self._result_reuse_minutes, } request["ResultReuseConfiguration"] = {"ResultReuseByAgeConfiguration": reuse_conf} + if execution_parameters: + request["ExecutionParameters"] = execution_parameters return request def _build_start_calculation_execution_request( @@ -546,15 +550,21 @@ def _find_previous_query_id( def _execute( self, operation: str, - parameters: Optional[Dict[str, Any]] = None, + parameters: Optional[Union[Dict[str, Any], List[str]]] = None, work_group: Optional[str] = None, s3_staging_dir: Optional[str] = None, cache_size: Optional[int] = 0, cache_expiration_time: Optional[int] = 0, result_reuse_enable: Optional[bool] = None, result_reuse_minutes: Optional[int] = None, + paramstyle: Optional[str] = None, ) -> str: - query = self._formatter.format(operation, parameters) + if pyathena.paramstyle == "qmark" or paramstyle == "qmark": + query = operation + execution_parameters = cast(Optional[List[str]], parameters) + else: + query = self._formatter.format(operation, cast(Optional[Dict[str, Any]], parameters)) + execution_parameters = None _logger.debug(query) request = self._build_start_query_execution_request( @@ -563,6 +573,7 @@ def _execute( s3_staging_dir=s3_staging_dir, result_reuse_enable=result_reuse_enable, result_reuse_minutes=result_reuse_minutes, + execution_parameters=execution_parameters, ) query_id = self._find_previous_query_id( query, @@ -612,14 +623,17 @@ def _calculate( def execute( self, operation: str, - parameters: Optional[Dict[str, Any]] = None, + parameters: Optional[Union[Dict[str, Any], List[str]]] = None, **kwargs, ): raise NotImplementedError # pragma: no cover @abstractmethod def executemany( - self, operation: str, seq_of_parameters: List[Optional[Dict[str, Any]]], **kwargs + self, + operation: str, + seq_of_parameters: List[Optional[Union[Dict[str, Any], List[str]]]], + **kwargs, ) -> None: raise NotImplementedError # pragma: no cover diff --git a/pyathena/connection.py b/pyathena/connection.py index bfd56150..d10c052d 100644 --- a/pyathena/connection.py +++ b/pyathena/connection.py @@ -90,8 +90,7 @@ def __init__( result_reuse_enable: bool = ..., result_reuse_minutes: int = ..., **kwargs, - ) -> None: - ... + ) -> None: ... @overload def __init__( @@ -121,8 +120,7 @@ def __init__( result_reuse_enable: bool = ..., result_reuse_minutes: int = ..., **kwargs, - ) -> None: - ... + ) -> None: ... def __init__( self, @@ -329,12 +327,10 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.close() @overload - def cursor(self, cursor: None = ..., **kwargs) -> ConnectionCursor: - ... + def cursor(self, cursor: None = ..., **kwargs) -> ConnectionCursor: ... @overload - def cursor(self, cursor: Type[FunctionalCursor], **kwargs) -> FunctionalCursor: - ... + def cursor(self, cursor: Type[FunctionalCursor], **kwargs) -> FunctionalCursor: ... def cursor( self, cursor: Optional[Type[FunctionalCursor]] = None, **kwargs diff --git a/pyathena/cursor.py b/pyathena/cursor.py index 81d6fcc8..65bef0d0 100644 --- a/pyathena/cursor.py +++ b/pyathena/cursor.py @@ -75,13 +75,14 @@ def close(self) -> None: def execute( self, operation: str, - parameters: Optional[Dict[str, Any]] = None, + parameters: Optional[Union[Dict[str, Any], List[str]]] = None, work_group: Optional[str] = None, s3_staging_dir: Optional[str] = None, cache_size: int = 0, cache_expiration_time: int = 0, result_reuse_enable: Optional[bool] = None, result_reuse_minutes: Optional[int] = None, + paramstyle: Optional[str] = None, **kwargs, ) -> Cursor: self._reset_state() @@ -94,6 +95,7 @@ def execute( cache_expiration_time=cache_expiration_time, result_reuse_enable=result_reuse_enable, result_reuse_minutes=result_reuse_minutes, + paramstyle=paramstyle, ) query_execution = cast(AthenaQueryExecution, self._poll(self.query_id)) if query_execution.state == AthenaQueryExecution.STATE_SUCCEEDED: @@ -109,7 +111,10 @@ def execute( return self def executemany( - self, operation: str, seq_of_parameters: List[Optional[Dict[str, Any]]], **kwargs + self, + operation: str, + seq_of_parameters: List[Optional[Union[Dict[str, Any], List[str]]]], + **kwargs, ) -> None: for parameters in seq_of_parameters: self.execute(operation, parameters, **kwargs) diff --git a/pyathena/filesystem/s3.py b/pyathena/filesystem/s3.py index f00e8399..bf9b8de5 100644 --- a/pyathena/filesystem/s3.py +++ b/pyathena/filesystem/s3.py @@ -542,7 +542,9 @@ def touch(self, path: str, truncate: bool = True, **kwargs) -> Dict[str, Any]: self.invalidate_cache(path) return object_.to_dict() - def cp_file(self, path1: str, path2: str, recursive=False, maxdepth=None, on_error=None, **kwargs): + def cp_file( + self, path1: str, path2: str, recursive=False, maxdepth=None, on_error=None, **kwargs + ): # TODO: Delete the value that seems to be a typo, onerror=false. # https://github.com/fsspec/filesystem_spec/commit/346a589fef9308550ffa3d0d510f2db67281bb05 # https://github.com/fsspec/filesystem_spec/blob/2024.10.0/fsspec/spec.py#L1185 diff --git a/pyathena/filesystem/s3_object.py b/pyathena/filesystem/s3_object.py index 14b94cf6..15deba0d 100644 --- a/pyathena/filesystem/s3_object.py +++ b/pyathena/filesystem/s3_object.py @@ -67,9 +67,9 @@ def __init__( # https://docs.aws.amazon.com/AmazonS3/latest/API/API_HeadObject.html#API_HeadObject_ResponseSyntax # Amazon S3 returns this header for all objects except for # S3 Standard storage class objects. - filtered[ - _API_FIELD_TO_S3_OBJECT_PROPERTY["StorageClass"] - ] = S3StorageClass.S3_STORAGE_CLASS_STANDARD + filtered[_API_FIELD_TO_S3_OBJECT_PROPERTY["StorageClass"]] = ( + S3StorageClass.S3_STORAGE_CLASS_STANDARD + ) super().update(filtered) if "Size" in init: self.content_length = init["Size"] diff --git a/pyathena/pandas/async_cursor.py b/pyathena/pandas/async_cursor.py index 5668204b..bea6677f 100644 --- a/pyathena/pandas/async_cursor.py +++ b/pyathena/pandas/async_cursor.py @@ -108,13 +108,14 @@ def _collect_result_set( def execute( self, operation: str, - parameters: Optional[Dict[str, Any]] = None, + parameters: Optional[Union[Dict[str, Any], List[str]]] = None, work_group: Optional[str] = None, s3_staging_dir: Optional[str] = None, cache_size: Optional[int] = 0, cache_expiration_time: Optional[int] = 0, result_reuse_enable: Optional[bool] = None, result_reuse_minutes: Optional[int] = None, + paramstyle: Optional[str] = None, keep_default_na: bool = False, na_values: Optional[Iterable[str]] = ("",), quoting: int = 1, @@ -140,6 +141,7 @@ def execute( cache_expiration_time=cache_expiration_time, result_reuse_enable=result_reuse_enable, result_reuse_minutes=result_reuse_minutes, + paramstyle=paramstyle, ) return ( query_id, diff --git a/pyathena/pandas/cursor.py b/pyathena/pandas/cursor.py index b1bb1bbd..5c273ab3 100644 --- a/pyathena/pandas/cursor.py +++ b/pyathena/pandas/cursor.py @@ -121,13 +121,14 @@ def close(self) -> None: def execute( self, operation: str, - parameters: Optional[Dict[str, Any]] = None, + parameters: Optional[Union[Dict[str, Any], List[str]]] = None, work_group: Optional[str] = None, s3_staging_dir: Optional[str] = None, cache_size: Optional[int] = 0, cache_expiration_time: Optional[int] = 0, result_reuse_enable: Optional[bool] = None, result_reuse_minutes: Optional[int] = None, + paramstyle: Optional[str] = None, keep_default_na: bool = False, na_values: Optional[Iterable[str]] = ("",), quoting: int = 1, @@ -154,6 +155,7 @@ def execute( cache_expiration_time=cache_expiration_time, result_reuse_enable=result_reuse_enable, result_reuse_minutes=result_reuse_minutes, + paramstyle=paramstyle, ) query_execution = cast(AthenaQueryExecution, self._poll(self.query_id)) if query_execution.state == AthenaQueryExecution.STATE_SUCCEEDED: diff --git a/pyathena/result_set.py b/pyathena/result_set.py index b96a8901..4c9891aa 100644 --- a/pyathena/result_set.py +++ b/pyathena/result_set.py @@ -53,9 +53,9 @@ def __init__( ) self._metadata: Optional[Tuple[Dict[str, Any], ...]] = None - self._rows: Deque[ - Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]] - ] = collections.deque() + self._rows: Deque[Union[Tuple[Optional[Any], ...], Dict[Any, Optional[Any]]]] = ( + collections.deque() + ) self._next_token: Optional[str] = None if self.state == AthenaQueryExecution.STATE_SUCCEEDED: diff --git a/pyathena/spark/async_cursor.py b/pyathena/spark/async_cursor.py index ce91d06e..eb273a8b 100644 --- a/pyathena/spark/async_cursor.py +++ b/pyathena/spark/async_cursor.py @@ -2,7 +2,7 @@ import logging from concurrent.futures import Future, ThreadPoolExecutor from multiprocessing import cpu_count -from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast from pyathena.model import AthenaCalculationExecution from pyathena.spark.common import SparkBaseCursor @@ -68,7 +68,7 @@ def poll(self, query_id: str) -> "Future[AthenaCalculationExecution]": def execute( self, operation: str, - parameters: Optional[Dict[str, Any]] = None, + parameters: Optional[Union[Dict[str, Any], List[str]]] = None, session_id: Optional[str] = None, description: Optional[str] = None, client_request_token: Optional[str] = None, diff --git a/pyathena/spark/cursor.py b/pyathena/spark/cursor.py index 3ec6e813..63bfe546 100644 --- a/pyathena/spark/cursor.py +++ b/pyathena/spark/cursor.py @@ -47,7 +47,7 @@ def get_std_error(self) -> Optional[str]: def execute( self, operation: str, - parameters: Optional[Dict[str, Any]] = None, + parameters: Optional[Union[Dict[str, Any], List[str]]] = None, session_id: Optional[str] = None, description: Optional[str] = None, client_request_token: Optional[str] = None, From 975810b4312c9955cc188e7210969614fc959788 Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Thu, 26 Dec 2024 15:31:48 +0900 Subject: [PATCH 2/3] Add test cases for qmark parameter --- tests/pyathena/test_cursor.py | 40 +++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tests/pyathena/test_cursor.py b/tests/pyathena/test_cursor.py index e1fad86b..a3de67a3 100644 --- a/tests/pyathena/test_cursor.py +++ b/tests/pyathena/test_cursor.py @@ -248,6 +248,46 @@ def test_fetch_no_data(self, cursor): pytest.raises(ProgrammingError, cursor.fetchmany) pytest.raises(ProgrammingError, cursor.fetchall) + def test_query_with_parameter(self, cursor): + cursor.execute( + """ + SELECT * FROM many_rows + WHERE a < %(param)d + """, + {"param": 10}, + ) + assert cursor.fetchall() == [(i,) for i in range(10)] + + cursor.execute( + """ + SELECT col_string FROM one_row_complex + WHERE col_string = %(param)s + """, + {"param": "a string"}, + ) + assert cursor.fetchall() == [("a string",)] + + def test_query_with_parameter_qmark(self, cursor): + cursor.execute( + """ + SELECT * FROM many_rows + WHERE a < ? + """, + ["10"], + paramstyle="qmark", + ) + assert cursor.fetchall() == [(i,) for i in range(10)] + + cursor.execute( + """ + SELECT col_string FROM one_row_complex + WHERE col_string = ? + """, + ["'a string'"], + paramstyle="qmark", + ) + assert cursor.fetchall() == [("a string",)] + def test_null_param(self, cursor): cursor.execute("SELECT %(param)s FROM one_row", {"param": None}) assert cursor.fetchall() == [(None,)] From 13565b5a76fe18a4183c154a8021b94453af03ed Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Thu, 26 Dec 2024 16:05:42 +0900 Subject: [PATCH 3/3] Update docs --- docs/usage.rst | 41 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/docs/usage.rst b/docs/usage.rst index fda041b1..ef3e99e6 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -44,7 +44,8 @@ Supported `DB API paramstyle`_ is only ``PyFormat``. cursor.execute(""" SELECT col_string FROM one_row_complex WHERE col_string = %(param)s - """, {"param": "a string"}) + """, + {"param": "a string"}) print(cursor.fetchall()) if ``%`` character is contained in your query, it must be escaped with ``%%`` like the following: @@ -54,6 +55,43 @@ if ``%`` character is contained in your query, it must be escaped with ``%%`` li SELECT col_string FROM one_row_complex WHERE col_string = %(param)s OR col_string LIKE 'a%%' +Use parameterized queries +~~~~~~~~~~~~~~~~~~~~~~~~~ + +If you want to use Athena's parameterized queries, you can do so by changing the ``paramstyle`` to ``qmark`` as follows. + +.. code:: python + + from pyathena import connect + + pyathena.paramstyle = "qmark" + cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2").cursor() + cursor.execute(""" + SELECT col_string FROM one_row_complex + WHERE col_string = ? + """, + ["'a string'"]) + print(cursor.fetchall()) + +You can also specify the ``paramstyle`` using the execute method when executing a query. + +.. code:: python + + from pyathena import connect + + cursor = connect(s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2").cursor() + cursor.execute(""" + SELECT col_string FROM one_row_complex + WHERE col_string = ? + """, + ["'a string'"], + paramstyle="qmark") + print(cursor.fetchall()) + +You can find more information about the `considerations and limitations of parameterized queries`_ in the official documentation. + Quickly re-run queries ---------------------- @@ -270,3 +308,4 @@ No need to specify credential information. .. _`reuse the results of previous queries`: https://docs.aws.amazon.com/athena/latest/ug/reusing-query-results.html .. _`Boto3 environment variables`: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-environment-variables .. _`Boto3 credentials`: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html +.. _`considerations and limitations of parameterized queries`: https://docs.aws.amazon.com/athena/latest/ug/querying-with-prepared-statements.html