diff --git a/pyathena/__init__.py b/pyathena/__init__.py index b17dc4be..f58d39f9 100644 --- a/pyathena/__init__.py +++ b/pyathena/__init__.py @@ -28,14 +28,12 @@ class DBAPITypeObject(FrozenSet[str]): def __eq__(self, other: object): if isinstance(other, frozenset): return frozenset.__eq__(self, other) - else: - return other in self + return other in self def __ne__(self, other: object): if isinstance(other, frozenset): return frozenset.__ne__(self, other) - else: - return other not in self + return other not in self def __hash__(self): return frozenset.__hash__(self) diff --git a/pyathena/arrow/async_cursor.py b/pyathena/arrow/async_cursor.py index 797aeef2..a58f256d 100644 --- a/pyathena/arrow/async_cursor.py +++ b/pyathena/arrow/async_cursor.py @@ -60,8 +60,7 @@ def get_default_converter( ) -> Union[DefaultArrowTypeConverter, DefaultArrowUnloadTypeConverter, Any]: if unload: return DefaultArrowUnloadTypeConverter() - else: - return DefaultArrowTypeConverter() + return DefaultArrowTypeConverter() @property def arraysize(self) -> int: @@ -80,7 +79,7 @@ def _collect_result_set( kwargs: Optional[Dict[str, Any]] = None, ) -> AthenaArrowResultSet: if kwargs is None: - kwargs = dict() + kwargs = {} query_execution = cast(AthenaQueryExecution, self._poll(query_id)) return AthenaArrowResultSet( connection=self._connection, diff --git a/pyathena/arrow/converter.py b/pyathena/arrow/converter.py index fa22633c..49d20faf 100644 --- a/pyathena/arrow/converter.py +++ b/pyathena/arrow/converter.py @@ -22,10 +22,9 @@ def _to_date(value: Optional[Union[str, datetime]]) -> Optional[date]: if value is None: return None - elif isinstance(value, datetime): + if isinstance(value, datetime): return value.date() - else: - return datetime.strptime(value, "%Y-%m-%d").date() + return datetime.strptime(value, "%Y-%m-%d").date() _DEFAULT_ARROW_CONVERTERS: Dict[str, Callable[[Optional[str]], Optional[Any]]] = { @@ -82,7 +81,7 @@ def convert(self, type_: str, value: Optional[str]) -> Optional[Any]: class DefaultArrowUnloadTypeConverter(Converter): def __init__(self) -> None: super().__init__( - mappings=dict(), + mappings={}, default=_to_default, ) diff --git a/pyathena/arrow/cursor.py b/pyathena/arrow/cursor.py index 42474cfc..9b3458be 100644 --- a/pyathena/arrow/cursor.py +++ b/pyathena/arrow/cursor.py @@ -59,8 +59,7 @@ def get_default_converter( ) -> Union[DefaultArrowTypeConverter, DefaultArrowUnloadTypeConverter, Any]: if unload: return DefaultArrowUnloadTypeConverter() - else: - return DefaultArrowTypeConverter() + return DefaultArrowTypeConverter() @property def arraysize(self) -> int: diff --git a/pyathena/arrow/result_set.py b/pyathena/arrow/result_set.py index c7d695a0..c42518f1 100644 --- a/pyathena/arrow/result_set.py +++ b/pyathena/arrow/result_set.py @@ -80,7 +80,7 @@ def __init__( else: import pyarrow as pa - self._table = pa.Table.from_pydict(dict()) + self._table = pa.Table.from_pydict({}) self._batches = iter(self._table.to_batches(arraysize)) def __s3_file_system(self): @@ -199,14 +199,14 @@ def _read_csv(self) -> "Table": if not self.output_location: raise ProgrammingError("OutputLocation is none or empty.") if not self.output_location.endswith((".csv", ".txt")): - return pa.Table.from_pydict(dict()) + return pa.Table.from_pydict({}) if self.substatement_type and self.substatement_type.upper() in ( "UPDATE", "DELETE", "MERGE", "VACUUM_TABLE", ): - return pa.Table.from_pydict(dict()) + return pa.Table.from_pydict({}) length = self._get_content_length() if length and self.output_location.endswith(".txt"): description = self.description if self.description else [] @@ -232,7 +232,7 @@ def _read_csv(self) -> "Table": escape_char=False, ) else: - return pa.Table.from_pydict(dict()) + return pa.Table.from_pydict({}) bucket, key = parse_output_location(self.output_location) try: @@ -256,7 +256,7 @@ def _read_parquet(self) -> "Table": manifests = self._read_data_manifest() if not manifests: - return pa.Table.from_pydict(dict()) + return pa.Table.from_pydict({}) if not self._unload_location: self._unload_location = "/".join(manifests[0].split("/")[:-1]) + "/" @@ -283,5 +283,5 @@ def close(self) -> None: import pyarrow as pa super().close() - self._table = pa.Table.from_pydict(dict()) + self._table = pa.Table.from_pydict({}) self._batches = [] diff --git a/pyathena/arrow/util.py b/pyathena/arrow/util.py index 962ea754..6b24b7eb 100644 --- a/pyathena/arrow/util.py +++ b/pyathena/arrow/util.py @@ -29,42 +29,41 @@ def get_athena_type(type_: "DataType") -> Tuple[str, int, int]: if type_.id in [types.Type_BOOL]: # 1 return "boolean", 0, 0 - elif type_.id in [types.Type_UINT8, types.Type_INT8]: # 2, 3 + if type_.id in [types.Type_UINT8, types.Type_INT8]: # 2, 3 return "tinyint", 3, 0 - elif type_.id in [types.Type_UINT16, types.Type_INT16]: # 4, 5 + if type_.id in [types.Type_UINT16, types.Type_INT16]: # 4, 5 return "smallint", 5, 0 - elif type_.id in [types.Type_UINT32, types.Type_INT32]: # 6, 7 + if type_.id in [types.Type_UINT32, types.Type_INT32]: # 6, 7 return "integer", 10, 0 - elif type_.id in [types.Type_UINT64, types.Type_INT64]: # 8, 9 + if type_.id in [types.Type_UINT64, types.Type_INT64]: # 8, 9 return "bigint", 19, 0 - elif type_.id in [types.Type_HALF_FLOAT, types.Type_FLOAT]: # 10, 11 + if type_.id in [types.Type_HALF_FLOAT, types.Type_FLOAT]: # 10, 11 return "float", 17, 0 - elif type_.id in [types.Type_DOUBLE]: # 12 + if type_.id in [types.Type_DOUBLE]: # 12 return "double", 17, 0 - elif type_.id in [types.Type_STRING, types.Type_LARGE_STRING]: # 13, 34 + if type_.id in [types.Type_STRING, types.Type_LARGE_STRING]: # 13, 34 return "varchar", 2147483647, 0 - elif type_.id in [ + if type_.id in [ types.Type_BINARY, types.Type_FIXED_SIZE_BINARY, types.Type_LARGE_BINARY, ]: # 14, 15, 35 return "varbinary", 1073741824, 0 - elif type_.id in [types.Type_DATE32, types.Type_DATE64]: # 16, 17 + if type_.id in [types.Type_DATE32, types.Type_DATE64]: # 16, 17 return "date", 0, 0 - elif type_.id == types.Type_TIMESTAMP: # 18 + if type_.id == types.Type_TIMESTAMP: # 18 return "timestamp", 3, 0 - elif type_.id in [types.Type_DECIMAL128, types.Decimal256Type]: # 23, 24 + if type_.id in [types.Type_DECIMAL128, types.Decimal256Type]: # 23, 24 type_ = cast(types.Decimal128Type, type_) return "decimal", type_.precision, type_.scale - elif type_.id in [ + if type_.id in [ types.Type_LIST, types.Type_FIXED_SIZE_LIST, types.Type_LARGE_LIST, ]: # 25, 32, 36 return "array", 0, 0 - elif type_.id in [types.Type_STRUCT]: # 26 + if type_.id in [types.Type_STRUCT]: # 26 return "row", 0, 0 - elif type_.id in [types.Type_MAP]: # 30 + if type_.id in [types.Type_MAP]: # 30 return "map", 0, 0 - else: - return "string", 2147483647, 0 + return "string", 2147483647, 0 diff --git a/pyathena/common.py b/pyathena/common.py index 3f600a98..59488a55 100644 --- a/pyathena/common.py +++ b/pyathena/common.py @@ -78,8 +78,7 @@ def __next__(self): row = self.fetchone() if row is None: raise StopIteration - else: - return row + return row def __iter__(self): return self @@ -482,8 +481,7 @@ def __poll(self, query_id: str) -> Union[AthenaQueryExecution, AthenaCalculation AthenaQueryExecution.STATE_CANCELLED, ]: return query_execution - else: - time.sleep(self._poll_interval) + time.sleep(self._poll_interval) def _poll(self, query_id: str) -> Union[AthenaQueryExecution, AthenaCalculationExecution]: try: @@ -654,11 +652,11 @@ def _cancel(self, query_id: str) -> None: _logger.exception("Failed to cancel query.") raise OperationalError(*e.args) from e - def setinputsizes(self, sizes): + def setinputsizes(self, sizes): # noqa: B027 """Does nothing by default""" pass - def setoutputsize(self, size, column=None): + def setoutputsize(self, size, column=None): # noqa: B027 """Does nothing by default""" pass diff --git a/pyathena/connection.py b/pyathena/connection.py index d10c052d..5d4cd0d2 100644 --- a/pyathena/connection.py +++ b/pyathena/connection.py @@ -141,7 +141,7 @@ def __init__( converter: Optional[Converter] = None, formatter: Optional[Formatter] = None, retry_config: Optional[RetryConfig] = None, - cursor_class: Optional[Type[ConnectionCursor]] = cast(Type[ConnectionCursor], Cursor), + cursor_class: Optional[Type[ConnectionCursor]] = None, cursor_kwargs: Optional[Dict[str, Any]] = None, kill_on_interrupt: bool = True, session: Optional[Session] = None, @@ -234,8 +234,8 @@ def __init__( self._converter = converter self._formatter = formatter if formatter else DefaultParameterFormatter() self._retry_config = retry_config if retry_config else RetryConfig() - self.cursor_class = cast(Type[ConnectionCursor], cursor_class) - self.cursor_kwargs = cursor_kwargs if cursor_kwargs else dict() + self.cursor_class = cursor_class if cursor_class else cast(Type[ConnectionCursor], Cursor) + self.cursor_kwargs = cursor_kwargs if cursor_kwargs else {} self.kill_on_interrupt = kill_on_interrupt self.result_reuse_enable = result_reuse_enable self.result_reuse_minutes = result_reuse_minutes diff --git a/pyathena/converter.py b/pyathena/converter.py index bbe6f823..23baebd9 100644 --- a/pyathena/converter.py +++ b/pyathena/converter.py @@ -117,12 +117,12 @@ def __init__( if mappings: self._mappings = mappings else: - self._mappings = dict() + self._mappings = {} self._default = default if types: self._types = types else: - self._types = dict() + self._types = {} @property def mappings(self) -> Dict[str, Callable[[Optional[str]], Optional[Any]]]: diff --git a/pyathena/fastparquet/util.py b/pyathena/fastparquet/util.py index 1e1390b0..45d53db3 100644 --- a/pyathena/fastparquet/util.py +++ b/pyathena/fastparquet/util.py @@ -37,34 +37,29 @@ def get_athena_type(type_: "SchemaElement") -> Tuple[str, int, int]: if type_.type in [Type.BOOLEAN]: return "boolean", 0, 0 - elif type_.type in [Type.INT32]: + if type_.type in [Type.INT32]: if type_.converted_type == ConvertedType.DATE: return "date", 0, 0 - else: - return "integer", 10, 0 - elif type_.type in [Type.INT64]: + return "integer", 10, 0 + if type_.type in [Type.INT64]: return "bigint", 19, 0 - elif type_.type in [Type.INT96]: + if type_.type in [Type.INT96]: return "timestamp", 3, 0 - elif type_.type in [Type.FLOAT]: + if type_.type in [Type.FLOAT]: return "float", 17, 0 - elif type_.type in [Type.DOUBLE]: + if type_.type in [Type.DOUBLE]: return "double", 17, 0 - elif type_.type in [Type.BYTE_ARRAY, Type.FIXED_LEN_BYTE_ARRAY]: + if type_.type in [Type.BYTE_ARRAY, Type.FIXED_LEN_BYTE_ARRAY]: if type_.converted_type == ConvertedType.UTF8: return "varchar", 2147483647, 0 - elif type_.converted_type == ConvertedType.DECIMAL: + if type_.converted_type == ConvertedType.DECIMAL: return "decimal", type_.precision, type_.scale - else: - return "varbinary", 1073741824, 0 - else: - if type_.converted_type == ConvertedType.LIST: - return "array", 0, 0 - elif type_.converted_type == ConvertedType.MAP: - return "map", 0, 0 - else: - children = getattr(type_, "children", []) - if type_.type is None and type_.converted_type is None and children: - return "row", 0, 0 - else: - return "string", 2147483647, 0 + return "varbinary", 1073741824, 0 + if type_.converted_type == ConvertedType.LIST: + return "array", 0, 0 + if type_.converted_type == ConvertedType.MAP: + return "map", 0, 0 + children = getattr(type_, "children", []) + if type_.type is None and type_.converted_type is None and children: + return "row", 0, 0 + return "string", 2147483647, 0 diff --git a/pyathena/filesystem/s3.py b/pyathena/filesystem/s3.py index bf9b8de5..8a062570 100644 --- a/pyathena/filesystem/s3.py +++ b/pyathena/filesystem/s3.py @@ -125,11 +125,11 @@ def _get_client_compatible_with_s3fs(self, **kwargs) -> BaseClient: if anon: config_kwargs.update({"signature_version": UNSIGNED}) else: - creds = dict( - aws_access_key_id=kwargs.pop("key", kwargs.pop("username", None)), - aws_secret_access_key=kwargs.pop("secret", kwargs.pop("password", None)), - aws_session_token=kwargs.pop("token", None), - ) + creds = { + "aws_access_key_id": kwargs.pop("key", kwargs.pop("username", None)), + "aws_secret_access_key": kwargs.pop("secret", kwargs.pop("password", None)), + "aws_session_token": kwargs.pop("token", None), + } kwargs.update(**creds) client_kwargs.update(**creds) @@ -148,8 +148,7 @@ def parse_path(path: str) -> Tuple[str, Optional[str], Optional[str]]: match = S3FileSystem.PATTERN_PATH.search(path) if match: return match.group("bucket"), match.group("key"), match.group("version_id") - else: - raise ValueError(f"Invalid S3 path format {path}.") + raise ValueError(f"Invalid S3 path format {path}.") def _head_bucket(self, bucket, refresh: bool = False) -> Optional[S3Object]: if bucket not in self.dircache or refresh: @@ -299,10 +298,7 @@ def _ls_dirs( self.dircache[path] = files else: cache = self.dircache[path] - if not isinstance(cache, list): - files = [cache] - else: - files = cache + files = cache if isinstance(cache, list) else [cache] return files def ls( @@ -317,7 +313,7 @@ def ls( file = self._head_object(path, refresh=refresh) if file: files = [file] - return [f for f in files] if detail else [f.name for f in files] + return list(files) if detail else [f.name for f in files] def info(self, path: str, **kwargs) -> S3Object: refresh = kwargs.pop("refresh", False) @@ -350,20 +346,19 @@ def info(self, path: str, **kwargs) -> S3Object: if cache: return cache - else: - return S3Object( - init={ - "ContentLength": 0, - "ContentType": None, - "StorageClass": S3StorageClass.S3_STORAGE_CLASS_DIRECTORY, - "ETag": None, - "LastModified": None, - }, - type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY, - bucket=bucket, - key=key.rstrip("/") if key else None, - version_id=version_id, - ) + return S3Object( + init={ + "ContentLength": 0, + "ContentType": None, + "StorageClass": S3StorageClass.S3_STORAGE_CLASS_DIRECTORY, + "ETag": None, + "LastModified": None, + }, + type=S3ObjectType.S3_OBJECT_TYPE_DIRECTORY, + bucket=bucket, + key=key.rstrip("/") if key else None, + version_id=version_id, + ) if key: object_info = self._head_object(path, refresh=refresh, version_id=version_id) if object_info: @@ -372,8 +367,7 @@ def info(self, path: str, **kwargs) -> S3Object: bucket_info = self._head_bucket(path, refresh=refresh) if bucket_info: return bucket_info - else: - raise FileNotFoundError(path) + raise FileNotFoundError(path) response = self._call( self._client.list_objects_v2, @@ -400,8 +394,7 @@ def info(self, path: str, **kwargs) -> S3Object: key=key.rstrip("/") if key else None, version_id=version_id, ) - else: - raise FileNotFoundError(path) + raise FileNotFoundError(path) def find( self, @@ -426,8 +419,7 @@ def find( files = [] if detail: return {f.name: f for f in files} - else: - return [f.name for f in files] + return [f.name for f in files] def exists(self, path: str, **kwargs) -> bool: path = self._strip_protocol(path) @@ -440,10 +432,7 @@ def exists(self, path: str, **kwargs) -> bool: if self._ls_from_cache(path): return True info = self.info(path) - if info: - return True - else: - return False + return bool(info) except FileNotFoundError: return False elif self.dircache.get(bucket, False): @@ -455,10 +444,7 @@ def exists(self, path: str, **kwargs) -> bool: except FileNotFoundError: pass file = self._head_bucket(bucket) - if file: - return True - else: - return False + return bool(file) def rm_file(self, path: str, **kwargs) -> None: bucket, key, version_id = self.parse_path(path) @@ -725,11 +711,13 @@ def put_file(self, lpath: str, rpath: str, callback=_DEFAULT_CALLBACK, **kwargs) if content_type is not None: kwargs["ContentType"] = content_type - with self.open(rpath, "wb", s3_additional_kwargs=kwargs) as remote: - with open(lpath, "rb") as local: - while data := local.read(remote.blocksize): - remote.write(data) - callback.relative_update(len(data)) + with ( + self.open(rpath, "wb", s3_additional_kwargs=kwargs) as remote, + open(lpath, "rb") as local, + ): + while data := local.read(remote.blocksize): + remote.write(data) + callback.relative_update(len(data)) self.invalidate_cache(rpath) @@ -737,20 +725,18 @@ def get_file(self, rpath: str, lpath: str, callback=_DEFAULT_CALLBACK, outfile=N if os.path.isdir(lpath): return - with open(lpath, "wb") as local: - with self.open(rpath, "rb", **kwargs) as remote: - callback.set_size(remote.size) - while data := remote.read(remote.blocksize): - local.write(data) - callback.relative_update(len(data)) + with open(lpath, "wb") as local, self.open(rpath, "rb", **kwargs) as remote: + callback.set_size(remote.size) + while data := remote.read(remote.blocksize): + local.write(data) + callback.relative_update(len(data)) def checksum(self, path: str, **kwargs): refresh = kwargs.pop("refresh", False) info = self.info(path, refresh=refresh) if info.get("type") != S3ObjectType.S3_OBJECT_TYPE_DIRECTORY: return int(info.get("etag").strip('"').split("-")[0], 16) - else: - return int(tokenize(info), 16) + return int(tokenize(info), 16) def sign(self, path: str, expiration: int = 3600, **kwargs): bucket, key, version_id = self.parse_path(path) @@ -947,10 +933,7 @@ def _complete_multipart_upload( return S3CompleteMultipartUpload(response) def _call(self, method: Union[str, Callable[..., Any]], **kwargs) -> Dict[str, Any]: - if isinstance(method, str): - func = getattr(self._client, method) - else: - func = method + func = getattr(self._client, method) if isinstance(method, str) else method response = retry_api_call( func, config=self._retry_config, logger=_logger, **kwargs, **self.request_kwargs ) @@ -1235,9 +1218,8 @@ def _get_ranges( if range_end > end: ranges.append((range_start, end)) break - else: - ranges.append((range_start, range_end)) - range_start += worker_block_size + ranges.append((range_start, range_end)) + range_start += worker_block_size else: ranges.append((start, end)) return ranges diff --git a/pyathena/formatter.py b/pyathena/formatter.py index a3074e9f..cf6f77d4 100644 --- a/pyathena/formatter.py +++ b/pyathena/formatter.py @@ -193,7 +193,7 @@ def format(self, operation: str, parameters: Optional[Dict[str, Any]] = None) -> kwargs: Optional[Dict[str, Any]] = None if parameters is not None: - kwargs = dict() + kwargs = {} if not parameters: pass elif isinstance(parameters, dict): diff --git a/pyathena/model.py b/pyathena/model.py index 78b6ef2f..365d12d6 100644 --- a/pyathena/model.py +++ b/pyathena/model.py @@ -547,14 +547,13 @@ def serde_serialization_lib(self) -> Optional[str]: def compression(self) -> Optional[str]: if "write.compression" in self._parameters: # text or json return self._parameters["write.compression"] - elif "serde.param.write.compression" in self._parameters: # text or json + if "serde.param.write.compression" in self._parameters: # text or json return self._parameters["serde.param.write.compression"] - elif "parquet.compress" in self._parameters: # parquet + if "parquet.compress" in self._parameters: # parquet return self._parameters["parquet.compress"] - elif "orc.compress" in self._parameters: # orc + if "orc.compress" in self._parameters: # orc return self._parameters["orc.compress"] - else: - return None + return None @property def serde_properties(self) -> Dict[str, str]: diff --git a/pyathena/pandas/async_cursor.py b/pyathena/pandas/async_cursor.py index d6317e04..838833cb 100644 --- a/pyathena/pandas/async_cursor.py +++ b/pyathena/pandas/async_cursor.py @@ -64,8 +64,7 @@ def get_default_converter( ) -> Union[DefaultPandasTypeConverter, Any]: if unload: return DefaultPandasUnloadTypeConverter() - else: - return DefaultPandasTypeConverter() + return DefaultPandasTypeConverter() @property def arraysize(self) -> int: @@ -87,7 +86,7 @@ def _collect_result_set( kwargs: Optional[Dict[str, Any]] = None, ) -> AthenaPandasResultSet: if kwargs is None: - kwargs = dict() + kwargs = {} query_execution = cast(AthenaQueryExecution, self._poll(query_id)) return AthenaPandasResultSet( connection=self._connection, diff --git a/pyathena/pandas/converter.py b/pyathena/pandas/converter.py index b1e7f938..2ff0969d 100644 --- a/pyathena/pandas/converter.py +++ b/pyathena/pandas/converter.py @@ -62,7 +62,7 @@ def convert(self, type_: str, value: Optional[str]) -> Optional[Any]: class DefaultPandasUnloadTypeConverter(Converter): def __init__(self) -> None: super().__init__( - mappings=dict(), + mappings={}, default=_to_default, ) diff --git a/pyathena/pandas/cursor.py b/pyathena/pandas/cursor.py index 0d39f8e3..65c724fd 100644 --- a/pyathena/pandas/cursor.py +++ b/pyathena/pandas/cursor.py @@ -81,8 +81,7 @@ def get_default_converter( ) -> Union[DefaultPandasTypeConverter, Any]: if unload: return DefaultPandasUnloadTypeConverter() - else: - return DefaultPandasTypeConverter() + return DefaultPandasTypeConverter() @property def arraysize(self) -> int: diff --git a/pyathena/pandas/result_set.py b/pyathena/pandas/result_set.py index 235f44e1..6c586e34 100644 --- a/pyathena/pandas/result_set.py +++ b/pyathena/pandas/result_set.py @@ -85,8 +85,7 @@ def get_chunk(self, size=None): if isinstance(self._reader, TextFileReader): return self._reader.get_chunk(size) - else: - return next(self._reader) + return next(self._reader) class AthenaPandasResultSet(AthenaResultSet): @@ -141,10 +140,7 @@ def __init__( self._fs = self.__s3_file_system() if self.state == AthenaQueryExecution.STATE_SUCCEEDED and self.output_location: df = self._as_pandas() - if self.is_unload: - trunc_date = _no_trunc_date - else: - trunc_date = self._trunc_date + trunc_date = _no_trunc_date if self.is_unload else self._trunc_date self._df_iter = DataFrameIterator(df, trunc_date) else: import pandas as pd @@ -169,8 +165,7 @@ def _get_engine(self) -> "str": "Trying to import the above resulted in these errors:" f"{error_msgs}" ) - else: - return self._engine + return self._engine def __s3_file_system(self): from pyathena.filesystem.s3 import S3FileSystem @@ -320,7 +315,7 @@ def _read_parquet(self, engine) -> "DataFrame": } elif engine == "fastparquet": unload_location = f"{self._unload_location}*" - kwargs = dict() + kwargs = {} else: raise ProgrammingError("Engine must be one of `pyarrow`, `fastparquet`.") kwargs.update(self._kwargs) @@ -379,7 +374,7 @@ def _as_pandas(self) -> Union["TextFileReader", "DataFrame"]: engine = self._get_engine() df = self._read_parquet(engine) if df.empty: - self._metadata = tuple() + self._metadata = () else: self._metadata = self._read_parquet_schema(engine) else: @@ -389,8 +384,7 @@ def _as_pandas(self) -> Union["TextFileReader", "DataFrame"]: def as_pandas(self) -> Union[DataFrameIterator, "DataFrame"]: if self._chunksize is None: return next(self._df_iter) - else: - return self._df_iter + return self._df_iter def close(self) -> None: import pandas as pd diff --git a/pyathena/pandas/util.py b/pyathena/pandas/util.py index 3c4be339..bfc6d2ae 100644 --- a/pyathena/pandas/util.py +++ b/pyathena/pandas/util.py @@ -60,7 +60,7 @@ def reset_index(df: "DataFrame", index_label: Optional[str] = None) -> None: try: df.reset_index(inplace=True) except ValueError as e: - raise ValueError(f"Duplicate name in index/columns: {e}") + raise ValueError("Duplicate name in index/columns") from e def as_pandas(cursor: "Cursor", coerce_float: bool = False) -> "DataFrame": @@ -79,27 +79,25 @@ def to_sql_type_mappings(col: "Series") -> str: col_type = pd.api.types.infer_dtype(col, skipna=True) if col_type == "datetime64" or col_type == "datetime": return "TIMESTAMP" - elif col_type == "timedelta": + if col_type == "timedelta": return "INT" - elif col_type == "timedelta64": + if col_type == "timedelta64": return "BIGINT" - elif col_type == "floating": + if col_type == "floating": if col.dtype == "float32": return "FLOAT" - else: - return "DOUBLE" - elif col_type == "integer": + return "DOUBLE" + if col_type == "integer": if col.dtype == "int32": return "INT" - else: - return "BIGINT" - elif col_type == "boolean": + return "BIGINT" + if col_type == "boolean": return "BOOLEAN" - elif col_type == "date": + if col_type == "date": return "DATE" - elif col_type == "bytes": + if col_type == "bytes": return "BINARY" - elif col_type in ["complex", "time"]: + if col_type in ["complex", "time"]: raise ValueError(f"Data type `{col_type}` is not supported") return "STRING" @@ -189,18 +187,17 @@ def to_sql( if if_exists == "fail": if table: raise OperationalError(f"Table `{schema}.{name}` already exists.") - elif if_exists == "replace": - if table: - cursor.execute( - textwrap.dedent( - f""" - DROP TABLE `{schema}`.`{name}` - """ - ) + elif if_exists == "replace" and table: + cursor.execute( + textwrap.dedent( + f""" + DROP TABLE `{schema}`.`{name}` + """ ) - objects = bucket.objects.filter(Prefix=key_prefix) - if list(objects.limit(1)): - objects.delete() + ) + objects = bucket.objects.filter(Prefix=key_prefix) + if list(objects.limit(1)): + objects.delete() if index: reset_index(df, index_label) diff --git a/pyathena/result_set.py b/pyathena/result_set.py index 4c9891aa..5c292262 100644 --- a/pyathena/result_set.py +++ b/pyathena/result_set.py @@ -330,11 +330,10 @@ def fetchone( self._fetch() if not self._rows: return None - else: - if self._rownumber is None: - self._rownumber = 0 - self._rownumber += 1 - return self._rows.popleft() + if self._rownumber is None: + self._rownumber = 0 + self._rownumber += 1 + return self._rows.popleft() def fetchmany( self, size: Optional[int] = None diff --git a/pyathena/spark/common.py b/pyathena/spark/common.py index c85be14b..0e890385 100644 --- a/pyathena/spark/common.py +++ b/pyathena/spark/common.py @@ -107,14 +107,13 @@ def _wait_for_idle_session(self, session_id: str): session_status = self._get_session_status(session_id) if session_status.state in [AthenaSessionStatus.STATE_IDLE]: break - elif session_status in [ + if session_status in [ AthenaSessionStatus.STATE_TERMINATED, AthenaSessionStatus.STATE_DEGRADED, AthenaSessionStatus.STATE_FAILED, ]: raise OperationalError(session_status.state_change_reason) - else: - time.sleep(self._poll_interval) + time.sleep(self._poll_interval) def _exists_session(self, session_id: str) -> bool: request = {"SessionId": session_id} @@ -132,8 +131,7 @@ def _exists_session(self, session_id: str) -> bool: ): _logger.exception(f"Session: {session_id} not found.") return False - else: - raise OperationalError(*e.args) from e + raise OperationalError(*e.args) from e else: self._wait_for_idle_session(session_id) return True @@ -185,8 +183,7 @@ def __poll(self, query_id: str) -> Union[AthenaQueryExecution, AthenaCalculation AthenaCalculationExecutionStatus.STATE_CANCELED, ]: return self._get_calculation_execution(query_id) - else: - time.sleep(self._poll_interval) + time.sleep(self._poll_interval) def _poll(self, query_id: str) -> Union[AthenaQueryExecution, AthenaCalculationExecution]: try: diff --git a/pyathena/sqlalchemy/arrow.py b/pyathena/sqlalchemy/arrow.py index 8c4f6a39..67528954 100644 --- a/pyathena/sqlalchemy/arrow.py +++ b/pyathena/sqlalchemy/arrow.py @@ -12,7 +12,7 @@ def create_connect_args(self, url): opts = super()._create_connect_args(url) opts.update({"cursor_class": ArrowCursor}) - cursor_kwargs = dict() + cursor_kwargs = {} if "unload" in opts: cursor_kwargs.update({"unload": bool(strtobool(opts.pop("unload")))}) if cursor_kwargs: diff --git a/pyathena/sqlalchemy/base.py b/pyathena/sqlalchemy/base.py index 1fe9ea33..d594756f 100644 --- a/pyathena/sqlalchemy/base.py +++ b/pyathena/sqlalchemy/base.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from __future__ import annotations +import contextlib import re from typing import ( TYPE_CHECKING, @@ -350,7 +351,7 @@ "WINDOW", "WITH", } -RESERVED_WORDS: Set[str] = set(sorted(DDL_RESERVED_WORDS | SELECT_STATEMENT_RESERVED_WORDS)) +RESERVED_WORDS: Set[str] = set(DDL_RESERVED_WORDS | SELECT_STATEMENT_RESERVED_WORDS) ischema_names: Dict[str, Type[Any]] = { "boolean": types.BOOLEAN, @@ -441,17 +442,19 @@ def get_from_hint_text(self, table, text): def format_from_hint_text(self, sqltext, table, hint, iscrud): hint_upper = hint.upper() - if any( - [ - hint_upper.startswith("FOR TIMESTAMP AS OF"), - hint_upper.startswith("FOR SYSTEM_TIME AS OF"), - hint_upper.startswith("FOR VERSION AS OF"), - hint_upper.startswith("FOR SYSTEM_VERSION AS OF"), - ] + if ( + any( + [ + hint_upper.startswith("FOR TIMESTAMP AS OF"), + hint_upper.startswith("FOR SYSTEM_TIME AS OF"), + hint_upper.startswith("FOR VERSION AS OF"), + hint_upper.startswith("FOR SYSTEM_VERSION AS OF"), + ] + ) + and "AS" in sqltext ): - if "AS" in sqltext: - _, alias = sqltext.split(" AS ", 1) - return f"{table.original.fullname} {hint} AS {alias}" + _, alias = sqltext.split(" AS ", 1) + return f"{table.original.fullname} {hint} AS {alias}" return f"{sqltext} {hint}" @@ -475,10 +478,9 @@ def visit_NUMERIC(self, type_: Type[Any], **kw) -> str: # noqa: N802 def visit_DECIMAL(self, type_: Type[Any], **kw) -> str: # noqa: N802 if type_.precision is None: return "DECIMAL" - elif type_.scale is None: + if type_.scale is None: return f"DECIMAL({type_.precision})" - else: - return f"DECIMAL({type_.precision}, {type_.scale})" + return f"DECIMAL({type_.precision}, {type_.scale})" def visit_TINYINT(self, type_: Type[Any], **kw) -> str: # noqa: N802 return "TINYINT" @@ -703,11 +705,10 @@ def _get_table_location_specification( "`location` or `s3_staging_dir` parameter is required " "in the connection string" ) - else: - raise exc.CompileError( - "The location of the table should be specified " - "by the dialect keyword argument `awsathena_location`" - ) + raise exc.CompileError( + "The location of the table should be specified " + "by the dialect keyword argument `awsathena_location`" + ) return "\n".join(text) def _get_table_properties( @@ -898,11 +899,8 @@ def visit_create_table(self, create: "CreateTable", **kwargs) -> str: if ("table_type" in table_properties) and ("iceberg" in table_properties): is_iceberg = True - if is_iceberg: - # https://docs.aws.amazon.com/athena/latest/ug/querying-iceberg-creating-tables.html - text = ["\nCREATE TABLE"] - else: - text = ["\nCREATE EXTERNAL TABLE"] + # https://docs.aws.amazon.com/athena/latest/ug/querying-iceberg-creating-tables.html + text = ["\nCREATE TABLE"] if is_iceberg else ["\nCREATE EXTERNAL TABLE"] if create.if_not_exists: text.append("IF NOT EXISTS") @@ -1003,7 +1001,7 @@ class AthenaDialect(DefaultDialect): ischema_names: Dict[str, Type[Any]] = ischema_names - _connect_options: Dict[str, Any] = dict() # type: ignore + _connect_options: Dict[str, Any] = {} # type: ignore _pattern_column_type: Pattern[str] = re.compile(r"^([a-zA-Z]+)(?:$|[\(|<](.+)[\)|>]$)") def __init__(self, json_deserializer=None, json_serializer=None, **kwargs): @@ -1030,7 +1028,7 @@ def create_connect_args(self, url: "URL") -> Tuple[Tuple[str], MutableMapping[st # {aws_access_key_id}:{aws_secret_access_key}@athena.{region_name}.amazonaws.com:443/ # {schema_name}?s3_staging_dir={s3_staging_dir}&... self._connect_options = self._create_connect_args(url) - return cast(Tuple[str], tuple()), self._connect_options + return cast(Tuple[str], ()), self._connect_options def _create_connect_args(self, url: "URL") -> Dict[str, Any]: opts: Dict[str, Any] = { @@ -1046,11 +1044,9 @@ def _create_connect_args(self, url: "URL") -> Dict[str, Any]: opts.update(url.query) if "verify" in opts: verify = opts["verify"] - try: + # If a ValueError occurs, it is probably the file name of the CA certificate being used. + with contextlib.suppress(ValueError): verify = bool(strtobool(verify)) - except ValueError: - # Probably a file name of the CA cert bundle to use - pass opts.update({"verify": verify}) if "duration_seconds" in opts: opts.update({"duration_seconds": int(opts["duration_seconds"])}) @@ -1150,7 +1146,7 @@ def has_table( ): try: columns = self.get_columns(connection, table_name, schema) - return True if columns else False + return bool(columns) except exc.NoSuchTableError: return False @@ -1166,7 +1162,7 @@ def get_view_definition( except exc.OperationalError as e: raise exc.NoSuchTableError(f"{schema}.{view_name}") from e else: - return "\n".join([r for r in res]) + return "\n".join(res) @reflection.cache def get_columns( diff --git a/pyathena/sqlalchemy/pandas.py b/pyathena/sqlalchemy/pandas.py index 5a6dbd0f..a7bc88fa 100644 --- a/pyathena/sqlalchemy/pandas.py +++ b/pyathena/sqlalchemy/pandas.py @@ -12,7 +12,7 @@ def create_connect_args(self, url): opts = super()._create_connect_args(url) opts.update({"cursor_class": PandasCursor}) - cursor_kwargs = dict() + cursor_kwargs = {} if "unload" in opts: cursor_kwargs.update({"unload": bool(strtobool(opts.pop("unload")))}) if "engine" in opts: diff --git a/pyathena/util.py b/pyathena/util.py index acaddd63..2015a744 100644 --- a/pyathena/util.py +++ b/pyathena/util.py @@ -21,8 +21,7 @@ def parse_output_location(output_location: str) -> Tuple[str, str]: match = PATTERN_OUTPUT_LOCATION.search(output_location) if match: return match.group("bucket"), match.group("key") - else: - raise DataError("Unknown `output_location` format.") + raise DataError("Unknown `output_location` format.") def strtobool(val): @@ -34,10 +33,9 @@ def strtobool(val): val = val.lower() if val in ("y", "yes", "t", "true", "on", "1"): return 1 - elif val in ("n", "no", "f", "false", "off", "0"): + if val in ("n", "no", "f", "false", "off", "0"): return 0 - else: - raise ValueError(f"invalid truth value {val!r}") + raise ValueError(f"invalid truth value {val!r}") class RetryConfig: diff --git a/pyproject.toml b/pyproject.toml index 538feb38..fa963608 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,7 +88,8 @@ path = "pyathena/__init__.py" norecursedirs = [ "benchmarks", ".venv", - ".tox" + ".tox", + "docs.*", ] [tool.sqla_testing] @@ -100,6 +101,7 @@ line-length = 100 exclude = [ ".venv", ".tox", + "docs", ] target-version = "py39" @@ -111,9 +113,10 @@ select = [ "F", # pyflakes "I", # isort "N", # pep8-naming - # "SIM", # flake8-simplify - # "B", # flake8-bugbear - # "C4", # flake8-comprehensions + "B", # flake8-bugbear + "C4", # flake8-comprehensions + "SIM", # flake8-simplify + "RET", # flake8-return # "UP", # pyupgrade ] @@ -135,6 +138,7 @@ exclude = [ "tests.*", ".venv.*", ".tox.*", + "docs.*", ] [tool.tox] diff --git a/tests/pyathena/arrow/test_async_cursor.py b/tests/pyathena/arrow/test_async_cursor.py index 5b89cee1..3d5ce82d 100644 --- a/tests/pyathena/arrow/test_async_cursor.py +++ b/tests/pyathena/arrow/test_async_cursor.py @@ -226,7 +226,7 @@ def test_as_arrow(self, async_arrow_cursor): table = future.result().as_arrow() assert table.shape[0] == 1 assert table.shape[1] == 1 - assert [row for row in zip(*table.to_pydict().values())] == [(1,)] + assert list(zip(*table.to_pydict().values())) == [(1,)] @pytest.mark.parametrize( "async_arrow_cursor", @@ -238,7 +238,7 @@ def test_many_as_arrow(self, async_arrow_cursor): table = future.result().as_arrow() assert table.shape[0] == 10000 assert table.shape[1] == 1 - assert [row for row in zip(*table.to_pydict().values())] == [(i,) for i in range(10000)] + assert list(zip(*table.to_pydict().values())) == [(i,) for i in range(10000)] def test_cancel(self, async_arrow_cursor): query_id, future = async_arrow_cursor.execute( @@ -258,9 +258,8 @@ def test_cancel(self, async_arrow_cursor): assert result_set.fetchall() == [] def test_open_close(self): - with contextlib.closing(connect()) as conn: - with conn.cursor(AsyncArrowCursor): - pass + with contextlib.closing(connect()) as conn, conn.cursor(AsyncArrowCursor): + pass def test_no_ops(self): conn = connect() diff --git a/tests/pyathena/arrow/test_cursor.py b/tests/pyathena/arrow/test_cursor.py index 84af881f..b780676f 100644 --- a/tests/pyathena/arrow/test_cursor.py +++ b/tests/pyathena/arrow/test_cursor.py @@ -259,7 +259,7 @@ def test_as_arrow(self, arrow_cursor): table = arrow_cursor.execute("SELECT * FROM one_row").as_arrow() assert table.shape[0] == 1 assert table.shape[1] == 1 - assert [row for row in zip(*table.to_pydict().values())] == [(1,)] + assert list(zip(*table.to_pydict().values())) == [(1,)] @pytest.mark.parametrize( "arrow_cursor", @@ -270,7 +270,7 @@ def test_many_as_arrow(self, arrow_cursor): table = arrow_cursor.execute("SELECT * FROM many_rows").as_arrow() assert table.shape[0] == 10000 assert table.shape[1] == 1 - assert [row for row in zip(*table.to_pydict().values())] == [(i,) for i in range(10000)] + assert list(zip(*table.to_pydict().values())) == [(i,) for i in range(10000)] def test_complex_as_arrow(self, arrow_cursor): table = arrow_cursor.execute( @@ -323,7 +323,7 @@ def test_complex_as_arrow(self, arrow_cursor): pa.field("col_decimal", pa.string()), ] ) - assert [row for row in zip(*table.to_pydict().values())] == [ + assert list(zip(*table.to_pydict().values())) == [ ( True, 127, @@ -406,7 +406,7 @@ def test_complex_unload_as_arrow(self, arrow_cursor): pa.field("col_decimal", pa.decimal128(10, 1)), ] ) - assert [row for row in zip(*table.to_pydict().values())] == [ + assert list(zip(*table.to_pydict().values())) == [ ( True, 127, @@ -450,9 +450,8 @@ def test_cancel_initial(self, arrow_cursor): pytest.raises(ProgrammingError, arrow_cursor.cancel) def test_open_close(self): - with contextlib.closing(connect()) as conn: - with conn.cursor(ArrowCursor): - pass + with contextlib.closing(connect()) as conn, conn.cursor(ArrowCursor): + pass def test_no_ops(self): conn = connect() @@ -522,7 +521,7 @@ def test_executemany(self, arrow_cursor): [{"a": a, "b": b} for a, b in rows], ) arrow_cursor.execute(f"SELECT * FROM {table_name}") - assert sorted(arrow_cursor.fetchall()) == [(a, b) for a, b in rows] + assert sorted(arrow_cursor.fetchall()) == list(rows) @pytest.mark.parametrize( "arrow_cursor", diff --git a/tests/pyathena/conftest.py b/tests/pyathena/conftest.py index 42358aa0..da2961ad 100644 --- a/tests/pyathena/conftest.py +++ b/tests/pyathena/conftest.py @@ -13,16 +13,14 @@ def pytest_sessionstart(session): _upload_rows() - with contextlib.closing(connect()) as conn: - with conn.cursor() as cursor: - _create_database(cursor) - _create_table(cursor) + with contextlib.closing(connect()) as conn, conn.cursor() as cursor: + _create_database(cursor) + _create_table(cursor) def pytest_sessionfinish(session): - with contextlib.closing(connect()) as conn: - with conn.cursor() as cursor: - _drop_database(cursor) + with contextlib.closing(connect()) as conn, conn.cursor() as cursor: + _drop_database(cursor) _delete_rows() @@ -106,12 +104,14 @@ def create_engine(**kwargs): def _cursor(cursor_class, request): if not hasattr(request, "param"): - setattr(request, "param", {}) - with contextlib.closing( - connect(schema_name=ENV.schema, cursor_class=cursor_class, **request.param) - ) as conn: - with conn.cursor() as cursor: - yield cursor + setattr(request, "param", {}) # noqa: B010 + with ( + contextlib.closing( + connect(schema_name=ENV.schema, cursor_class=cursor_class, **request.param) + ) as conn, + conn.cursor() as cursor, + ): + yield cursor @pytest.fixture @@ -175,7 +175,7 @@ def spark_cursor(request): from pyathena.spark.cursor import SparkCursor if not hasattr(request, "param"): - setattr(request, "param", {}) + setattr(request, "param", {}) # noqa: B010 request.param.update({"work_group": ENV.spark_work_group}) yield from _cursor(SparkCursor, request) @@ -185,7 +185,7 @@ def async_spark_cursor(request): from pyathena.spark.async_cursor import AsyncSparkCursor if not hasattr(request, "param"): - setattr(request, "param", {}) + setattr(request, "param", {}) # noqa: B010 request.param.update({"work_group": ENV.spark_work_group}) yield from _cursor(AsyncSparkCursor, request) @@ -193,7 +193,7 @@ def async_spark_cursor(request): @pytest.fixture def engine(request): if not hasattr(request, "param"): - setattr(request, "param", {}) + setattr(request, "param", {}) # noqa: B010 engine_ = create_engine(**request.param) try: with contextlib.closing(engine_.connect()) as conn: diff --git a/tests/pyathena/filesystem/test_s3.py b/tests/pyathena/filesystem/test_s3.py index 4d2dab5a..5f9704b0 100644 --- a/tests/pyathena/filesystem/test_s3.py +++ b/tests/pyathena/filesystem/test_s3.py @@ -128,7 +128,7 @@ def test_parse_path_invalid(self): @pytest.fixture(scope="class") def fs(self, request): if not hasattr(request, "param"): - setattr(request, "param", {}) + setattr(request, "param", {}) # noqa: B010 return S3FileSystem(connect(), **request.param) @pytest.mark.parametrize( @@ -757,7 +757,7 @@ def test_pandas_write_csv(self, line_count): with tempfile.NamedTemporaryFile("w+t") as tmp: tmp.write("col1") tmp.write("\n") - for i in range(0, line_count): + for _ in range(0, line_count): tmp.write("a") tmp.write("\n") tmp.flush() diff --git a/tests/pyathena/pandas/test_async_cursor.py b/tests/pyathena/pandas/test_async_cursor.py index 347edf24..5c226669 100644 --- a/tests/pyathena/pandas/test_async_cursor.py +++ b/tests/pyathena/pandas/test_async_cursor.py @@ -393,9 +393,8 @@ def test_cancel(self, async_pandas_cursor): assert result_set.fetchall() == [] def test_open_close(self): - with contextlib.closing(connect()) as conn: - with conn.cursor(AsyncPandasCursor): - pass + with contextlib.closing(connect()) as conn, conn.cursor(AsyncPandasCursor): + pass def test_no_ops(self): conn = connect() @@ -497,18 +496,16 @@ def test_integer_na_values(self, async_pandas_cursor, parquet_engine): df = future.result().as_pandas() if async_pandas_cursor._unload: rows = [ - tuple( - [ - True if math.isnan(row["a"]) else row["a"], - True if math.isnan(row["b"]) else row["b"], - ] + ( + True if math.isnan(row["a"]) else row["a"], + True if math.isnan(row["b"]) else row["b"], ) for _, row in df.iterrows() ] # If the UNLOAD option is enabled, it is converted to float for some reason. assert rows == [(1.0, 2.0), (1.0, True), (True, True)] else: - rows = [tuple([row["a"], row["b"]]) for _, row in df.iterrows()] + rows = [(row["a"], row["b"]) for _, row in df.iterrows()] assert rows == [(1, 2), (1, pd.NA), (pd.NA, pd.NA)] @pytest.mark.parametrize( @@ -529,7 +526,7 @@ def test_float_na_values(self, async_pandas_cursor, parquet_engine): engine=parquet_engine, ) df = future.result().as_pandas() - rows = [tuple([row["col"]]) for _, row in df.iterrows()] + rows = [(row["col"],) for _, row in df.iterrows()] np.testing.assert_equal(rows, [(0.33,), (np.nan,)]) @pytest.mark.parametrize( @@ -552,17 +549,15 @@ def test_boolean_na_values(self, async_pandas_cursor, parquet_engine): df = future.result().as_pandas() if parquet_engine == "fastparquet": rows = [ - tuple( - [ - True if math.isnan(row["a"]) else row["a"], - True if math.isnan(row["b"]) else row["b"], - ] + ( + True if math.isnan(row["a"]) else row["a"], + True if math.isnan(row["b"]) else row["b"], ) for _, row in df.iterrows() ] assert rows == [(1.0, 0.0), (0.0, True), (True, True)] else: - rows = [tuple([row["a"], row["b"]]) for _, row in df.iterrows()] + rows = [(row["a"], row["b"]) for _, row in df.iterrows()] assert rows == [(True, False), (False, None), (None, None)] @pytest.mark.parametrize( @@ -649,14 +644,7 @@ def test_null_decimal_value(self, async_pandas_cursor, parquet_engine): ) result_set = future.result() if parquet_engine == "fastparquet": - rows = [ - tuple( - [ - True if math.isnan(row[0]) else row[0], - ] - ) - for row in result_set.fetchall() - ] + rows = [(True if math.isnan(row[0]) else row[0],) for row in result_set.fetchall()] assert rows == [(True,)] else: assert result_set.fetchall() == [(None,)] diff --git a/tests/pyathena/pandas/test_cursor.py b/tests/pyathena/pandas/test_cursor.py index 4944f272..19f4d557 100644 --- a/tests/pyathena/pandas/test_cursor.py +++ b/tests/pyathena/pandas/test_cursor.py @@ -295,25 +295,23 @@ def test_complex_unload_pyarrow(self, pandas_cursor, parquet_engine): ("col_decimal", "decimal", None, None, 10, 1, "NULLABLE"), ] rows = [ - tuple( - [ - row[0], - row[1], - row[2], - row[3], - row[4], - row[5], - row[6], - row[7], - row[8], - row[9], - row[10], - row[11], - [a for a in row[12]], - row[13], - row[14], - row[15], - ] + ( + row[0], + row[1], + row[2], + row[3], + row[4], + row[5], + row[6], + row[7], + row[8], + row[9], + row[10], + row[11], + list(row[12]), + row[13], + row[14], + row[15], ) for row in pandas_cursor.fetchall() ] @@ -333,7 +331,7 @@ def test_complex_unload_pyarrow(self, pandas_cursor, parquet_engine): b"123", # ValueError: The truth value of an array with more than one element is ambiguous. # Use a.any() or a.all() - [a for a in np.array([1, 2], dtype=np.int32)], + list(np.array([1, 2], dtype=np.int32)), [(1, 2), (3, 4)], {"a": 1, "b": 2}, Decimal("0.1"), @@ -412,26 +410,24 @@ def test_complex_unload_fastparquet(self, pandas_cursor): ("col_struct.b", "integer", None, None, 10, 0, "NULLABLE"), ] rows = [ - tuple( - [ - row[0], - row[1], - row[2], - row[3], - row[4], - row[5], - row[6], - row[7], - row[8], - row[9], - row[10], - row[11], - [a for a in row[12]], - row[13], - row[14], - row[15], - row[16], - ] + ( + row[0], + row[1], + row[2], + row[3], + row[4], + row[5], + row[6], + row[7], + row[8], + row[9], + row[10], + row[11], + list(row[12]), + row[13], + row[14], + row[15], + row[16], ) for row in pandas_cursor.fetchall() ] @@ -451,7 +447,7 @@ def test_complex_unload_fastparquet(self, pandas_cursor): b"123", # ValueError: The truth value of an array with more than one element is ambiguous. # Use a.any() or a.all() - [a for a in np.array([1, 2], dtype=np.int32)], + list(np.array([1, 2], dtype=np.int32)), {1: 2, 3: 4}, # In the case of fastparquet, decimal types are handled as floats. 0.1, @@ -550,75 +546,69 @@ def test_complex_as_pandas(self, pandas_cursor, chunksize): df = pd.concat((d for d in df), ignore_index=True) assert df.shape[0] == 1 assert df.shape[1] == 19 - dtypes = tuple( - [ - df["col_boolean"].dtype.type, - df["col_tinyint"].dtype.type, - df["col_smallint"].dtype.type, - df["col_int"].dtype.type, - df["col_bigint"].dtype.type, - df["col_float"].dtype.type, - df["col_double"].dtype.type, - df["col_string"].dtype.type, - df["col_varchar"].dtype.type, - df["col_timestamp"].dtype.type, - df["col_time"].dtype.type, - df["col_date"].dtype.type, - df["col_binary"].dtype.type, - df["col_array"].dtype.type, - df["col_array_json"].dtype.type, - df["col_map"].dtype.type, - df["col_map_json"].dtype.type, - df["col_struct"].dtype.type, - df["col_decimal"].dtype.type, - ] + dtypes = ( + df["col_boolean"].dtype.type, + df["col_tinyint"].dtype.type, + df["col_smallint"].dtype.type, + df["col_int"].dtype.type, + df["col_bigint"].dtype.type, + df["col_float"].dtype.type, + df["col_double"].dtype.type, + df["col_string"].dtype.type, + df["col_varchar"].dtype.type, + df["col_timestamp"].dtype.type, + df["col_time"].dtype.type, + df["col_date"].dtype.type, + df["col_binary"].dtype.type, + df["col_array"].dtype.type, + df["col_array_json"].dtype.type, + df["col_map"].dtype.type, + df["col_map_json"].dtype.type, + df["col_struct"].dtype.type, + df["col_decimal"].dtype.type, ) - assert dtypes == tuple( - [ - np.bool_, - np.int64, - np.int64, - np.int64, - np.int64, - np.float64, - np.float64, - np.object_, - np.object_, - np.datetime64, - np.object_, - np.datetime64, - np.object_, - np.object_, - np.object_, - np.object_, - np.object_, - np.object_, - np.object_, - ] + assert dtypes == ( + np.bool_, + np.int64, + np.int64, + np.int64, + np.int64, + np.float64, + np.float64, + np.object_, + np.object_, + np.datetime64, + np.object_, + np.datetime64, + np.object_, + np.object_, + np.object_, + np.object_, + np.object_, + np.object_, + np.object_, ) rows = [ - tuple( - [ - row["col_boolean"], - row["col_tinyint"], - row["col_smallint"], - row["col_int"], - row["col_bigint"], - row["col_float"], - row["col_double"], - row["col_string"], - row["col_varchar"], - row["col_timestamp"], - row["col_time"], - row["col_date"], - row["col_binary"], - row["col_array"], - row["col_array_json"], - row["col_map"], - row["col_map_json"], - row["col_struct"], - row["col_decimal"], - ] + ( + row["col_boolean"], + row["col_tinyint"], + row["col_smallint"], + row["col_int"], + row["col_bigint"], + row["col_float"], + row["col_double"], + row["col_string"], + row["col_varchar"], + row["col_timestamp"], + row["col_time"], + row["col_date"], + row["col_binary"], + row["col_array"], + row["col_array_json"], + row["col_map"], + row["col_map_json"], + row["col_struct"], + row["col_decimal"], ) for _, row in df.iterrows() ] @@ -681,66 +671,60 @@ def test_complex_unload_as_pandas_pyarrow(self, pandas_cursor, parquet_engine): ).as_pandas() assert df.shape[0] == 1 assert df.shape[1] == 16 - dtypes = tuple( - [ - df["col_boolean"].dtype.type, - df["col_tinyint"].dtype.type, - df["col_smallint"].dtype.type, - df["col_int"].dtype.type, - df["col_bigint"].dtype.type, - df["col_float"].dtype.type, - df["col_double"].dtype.type, - df["col_string"].dtype.type, - df["col_varchar"].dtype.type, - df["col_timestamp"].dtype.type, - df["col_date"].dtype.type, - df["col_binary"].dtype.type, - df["col_array"].dtype.type, - df["col_map"].dtype.type, - df["col_struct"].dtype.type, - df["col_decimal"].dtype.type, - ] + dtypes = ( + df["col_boolean"].dtype.type, + df["col_tinyint"].dtype.type, + df["col_smallint"].dtype.type, + df["col_int"].dtype.type, + df["col_bigint"].dtype.type, + df["col_float"].dtype.type, + df["col_double"].dtype.type, + df["col_string"].dtype.type, + df["col_varchar"].dtype.type, + df["col_timestamp"].dtype.type, + df["col_date"].dtype.type, + df["col_binary"].dtype.type, + df["col_array"].dtype.type, + df["col_map"].dtype.type, + df["col_struct"].dtype.type, + df["col_decimal"].dtype.type, ) - assert dtypes == tuple( - [ - np.bool_, - np.int8, - np.int16, - np.int32, - np.int64, - np.float32, - np.float64, - np.object_, - np.object_, - np.datetime64, - np.object_, - np.object_, - np.object_, - np.object_, - np.object_, - np.object_, - ] + assert dtypes == ( + np.bool_, + np.int8, + np.int16, + np.int32, + np.int64, + np.float32, + np.float64, + np.object_, + np.object_, + np.datetime64, + np.object_, + np.object_, + np.object_, + np.object_, + np.object_, + np.object_, ) rows = [ - tuple( - [ - row["col_boolean"], - row["col_tinyint"], - row["col_smallint"], - row["col_int"], - row["col_bigint"], - row["col_float"], - row["col_double"], - row["col_string"], - row["col_varchar"], - row["col_timestamp"], - row["col_date"], - row["col_binary"], - [a for a in row["col_array"]], - row["col_map"], - row["col_struct"], - row["col_decimal"], - ] + ( + row["col_boolean"], + row["col_tinyint"], + row["col_smallint"], + row["col_int"], + row["col_bigint"], + row["col_float"], + row["col_double"], + row["col_string"], + row["col_varchar"], + row["col_timestamp"], + row["col_date"], + row["col_binary"], + list(row["col_array"]), + row["col_map"], + row["col_struct"], + row["col_decimal"], ) for _, row in df.iterrows() ] @@ -760,7 +744,7 @@ def test_complex_unload_as_pandas_pyarrow(self, pandas_cursor, parquet_engine): b"123", # ValueError: The truth value of an array with more than one element is ambiguous. # Use a.any() or a.all() - [a for a in np.array([1, 2], dtype=np.int32)], + list(np.array([1, 2], dtype=np.int32)), [(1, 2), (3, 4)], {"a": 1, "b": 2}, Decimal("0.1"), @@ -803,71 +787,65 @@ def test_complex_unload_as_pandas_fastparquet(self, pandas_cursor): ).as_pandas() assert df.shape[0] == 1 assert df.shape[1] == 17 - dtypes = tuple( - [ - df["col_boolean"].dtype.type, - df["col_tinyint"].dtype.type, - df["col_smallint"].dtype.type, - df["col_int"].dtype.type, - df["col_bigint"].dtype.type, - df["col_float"].dtype.type, - df["col_double"].dtype.type, - df["col_string"].dtype.type, - df["col_varchar"].dtype.type, - df["col_timestamp"].dtype.type, - df["col_date"].dtype.type, - df["col_binary"].dtype.type, - df["col_array"].dtype.type, - df["col_map"].dtype.type, - df["col_decimal"].dtype.type, - # In the case of fastparquet, child elements of struct types are handled - # as fields separated by dots. - df["col_struct.a"].dtype.type, - df["col_struct.b"].dtype.type, - ] + dtypes = ( + df["col_boolean"].dtype.type, + df["col_tinyint"].dtype.type, + df["col_smallint"].dtype.type, + df["col_int"].dtype.type, + df["col_bigint"].dtype.type, + df["col_float"].dtype.type, + df["col_double"].dtype.type, + df["col_string"].dtype.type, + df["col_varchar"].dtype.type, + df["col_timestamp"].dtype.type, + df["col_date"].dtype.type, + df["col_binary"].dtype.type, + df["col_array"].dtype.type, + df["col_map"].dtype.type, + df["col_decimal"].dtype.type, + # In the case of fastparquet, child elements of struct types are handled + # as fields separated by dots. + df["col_struct.a"].dtype.type, + df["col_struct.b"].dtype.type, ) - assert dtypes == tuple( - [ - np.bool_, - np.int8, - np.int16, - np.int32, - np.int64, - np.float32, - np.float64, - np.object_, - np.object_, - np.datetime64, - np.datetime64, - np.object_, - np.object_, - np.object_, - np.float64, - np.int32, - np.int32, - ] + assert dtypes == ( + np.bool_, + np.int8, + np.int16, + np.int32, + np.int64, + np.float32, + np.float64, + np.object_, + np.object_, + np.datetime64, + np.datetime64, + np.object_, + np.object_, + np.object_, + np.float64, + np.int32, + np.int32, ) rows = [ - tuple( - [ - row["col_boolean"], - row["col_tinyint"], - row["col_smallint"], - row["col_int"], - row["col_bigint"], - row["col_float"], - row["col_double"], - row["col_string"], - row["col_varchar"], - row["col_timestamp"], - row["col_date"], - row["col_binary"], - [a for a in row["col_array"]], - row["col_map"], - row["col_decimal"], - row["col_struct.a"], - row["col_struct.b"], - ] + ( + row["col_boolean"], + row["col_tinyint"], + row["col_smallint"], + row["col_int"], + row["col_bigint"], + row["col_float"], + row["col_double"], + row["col_string"], + row["col_varchar"], + row["col_timestamp"], + row["col_date"], + row["col_binary"], + list(row["col_array"]), + row["col_map"], + row["col_decimal"], + row["col_struct.a"], + row["col_struct.b"], ) for _, row in df.iterrows() ] @@ -887,7 +865,7 @@ def test_complex_unload_as_pandas_fastparquet(self, pandas_cursor): b"123", # ValueError: The truth value of an array with more than one element is ambiguous. # Use a.any() or a.all() - [a for a in np.array([1, 2], dtype=np.int32)], + list(np.array([1, 2], dtype=np.int32)), {1: 2, 3: 4}, # In the case of fastparquet, decimal types are handled as floats. 0.1, @@ -919,9 +897,8 @@ def test_cancel_initial(self, pandas_cursor): pytest.raises(ProgrammingError, pandas_cursor.cancel) def test_open_close(self): - with contextlib.closing(connect()) as conn: - with conn.cursor(PandasCursor): - pass + with contextlib.closing(connect()) as conn, conn.cursor(PandasCursor): + pass def test_no_ops(self): conn = connect() @@ -1016,18 +993,16 @@ def test_integer_na_values(self, pandas_cursor, parquet_engine): ).as_pandas() if pandas_cursor._unload: rows = [ - tuple( - [ - True if math.isnan(row["a"]) else row["a"], - True if math.isnan(row["b"]) else row["b"], - ] + ( + True if math.isnan(row["a"]) else row["a"], + True if math.isnan(row["b"]) else row["b"], ) for _, row in df.iterrows() ] # If the UNLOAD option is enabled, it is converted to float for some reason. assert rows == [(1.0, 2.0), (1.0, True), (True, True)] else: - rows = [tuple([row["a"], row["b"]]) for _, row in df.iterrows()] + rows = [(row["a"], row["b"]) for _, row in df.iterrows()] assert rows == [(1, 2), (1, pd.NA), (pd.NA, pd.NA)] @pytest.mark.parametrize( @@ -1047,7 +1022,7 @@ def test_float_na_values(self, pandas_cursor, parquet_engine): """, engine=parquet_engine, ).as_pandas() - rows = [tuple([row["col"]]) for _, row in df.iterrows()] + rows = [(row["col"],) for _, row in df.iterrows()] np.testing.assert_equal(rows, [(0.33,), (np.nan,)]) @pytest.mark.parametrize( @@ -1069,17 +1044,15 @@ def test_boolean_na_values(self, pandas_cursor, parquet_engine): ).as_pandas() if parquet_engine == "fastparquet": rows = [ - tuple( - [ - True if math.isnan(row["a"]) else row["a"], - True if math.isnan(row["b"]) else row["b"], - ] + ( + True if math.isnan(row["a"]) else row["a"], + True if math.isnan(row["b"]) else row["b"], ) for _, row in df.iterrows() ] assert rows == [(1.0, 0.0), (0.0, True), (True, True)] else: - rows = [tuple([row["a"], row["b"]]) for _, row in df.iterrows()] + rows = [(row["a"], row["b"]) for _, row in df.iterrows()] assert rows == [(True, False), (False, None), (None, None)] @pytest.mark.parametrize( @@ -1102,7 +1075,7 @@ def test_executemany(self, pandas_cursor, parquet_engine): [{"a": a, "b": b} for a, b in rows], ) pandas_cursor.execute(f"SELECT * FROM {table_name}", engine=parquet_engine) - assert sorted(pandas_cursor.fetchall()) == [(a, b) for a, b in rows] + assert sorted(pandas_cursor.fetchall()) == list(rows) @pytest.mark.parametrize( "pandas_cursor, parquet_engine", @@ -1199,14 +1172,7 @@ def test_empty_and_null_string(self, pandas_cursor, parquet_engine): def test_null_decimal_value(self, pandas_cursor, parquet_engine): pandas_cursor.execute("SELECT CAST(null AS DECIMAL) AS col_decimal", engine=parquet_engine) if parquet_engine == "fastparquet": - rows = [ - tuple( - [ - True if math.isnan(row[0]) else row[0], - ] - ) - for row in pandas_cursor.fetchall() - ] + rows = [(True if math.isnan(row[0]) else row[0],) for row in pandas_cursor.fetchall()] assert rows == [(True,)] else: assert pandas_cursor.fetchall() == [(None,)] diff --git a/tests/pyathena/pandas/test_util.py b/tests/pyathena/pandas/test_util.py index ed4af03f..c71b4c84 100644 --- a/tests/pyathena/pandas/test_util.py +++ b/tests/pyathena/pandas/test_util.py @@ -79,27 +79,25 @@ def test_as_pandas(cursor): ) df = as_pandas(cursor) rows = [ - tuple( - [ - row["col_boolean"], - row["col_tinyint"], - row["col_smallint"], - row["col_int"], - row["col_bigint"], - row["col_float"], - row["col_double"], - row["col_string"], - row["col_timestamp"], - row["col_time"], - row["col_date"], - row["col_binary"], - row["col_array"], - row["col_array_json"], - row["col_map"], - row["col_map_json"], - row["col_struct"], - row["col_decimal"], - ] + ( + row["col_boolean"], + row["col_tinyint"], + row["col_smallint"], + row["col_int"], + row["col_bigint"], + row["col_float"], + row["col_double"], + row["col_string"], + row["col_timestamp"], + row["col_time"], + row["col_date"], + row["col_binary"], + row["col_array"], + row["col_array_json"], + row["col_map"], + row["col_map_json"], + row["col_struct"], + row["col_decimal"], ) for _, row in df.iterrows() ] @@ -135,7 +133,7 @@ def test_as_pandas_integer_na_values(cursor): """ ) df = as_pandas(cursor, coerce_float=True) - rows = [tuple([row["a"], row["b"]]) for _, row in df.iterrows()] + rows = [(row["a"], row["b"]) for _, row in df.iterrows()] # TODO AssertionError: Lists differ: # [(1.0, 2.0), (1.0, nan), (nan, nan)] != [(1.0, 2.0), (1.0, nan), (nan, nan)] # assert rows == [ @@ -153,7 +151,7 @@ def test_as_pandas_boolean_na_values(cursor): """ ) df = as_pandas(cursor) - rows = [tuple([row["a"], row["b"]]) for _, row in df.iterrows()] + rows = [(row["a"], row["b"]) for _, row in df.iterrows()] assert rows == [(True, False), (False, None), (None, None)] @@ -467,7 +465,7 @@ def test_to_sql_with_index(cursor): def test_to_sql_with_partitions(cursor): df = pd.DataFrame( { - "col_int": np.int32([i for i in range(10)]), + "col_int": np.int32(range(10)), "col_bigint": np.int64([12345 for _ in range(10)]), "col_string": ["a" for _ in range(10)], } @@ -493,7 +491,7 @@ def test_to_sql_with_partitions(cursor): def test_to_sql_with_multiple_partitions(cursor): df = pd.DataFrame( { - "col_int": np.int32([i for i in range(10)]), + "col_int": np.int32(range(10)), "col_bigint": np.int64([12345 for _ in range(10)]), "col_string": ["a" for _ in range(5)] + ["b" for _ in range(5)], } diff --git a/tests/pyathena/test_async_cursor.py b/tests/pyathena/test_async_cursor.py index 59b12dba..9433ecf0 100644 --- a/tests/pyathena/test_async_cursor.py +++ b/tests/pyathena/test_async_cursor.py @@ -199,9 +199,8 @@ def test_cancel(self, async_cursor): assert result_set.fetchall() == [] def test_open_close(self): - with contextlib.closing(connect()) as conn: - with conn.cursor(AsyncCursor): - pass + with contextlib.closing(connect()) as conn, conn.cursor(AsyncCursor): + pass def test_no_ops(self): conn = connect() diff --git a/tests/pyathena/test_cursor.py b/tests/pyathena/test_cursor.py index a3de67a3..f101c631 100644 --- a/tests/pyathena/test_cursor.py +++ b/tests/pyathena/test_cursor.py @@ -230,10 +230,8 @@ def test_description_initial(self, cursor): assert cursor.description is None def test_description_failed(self, cursor): - try: + with contextlib.suppress(DatabaseError): cursor.execute("blah_blah") - except DatabaseError: - pass assert cursor.description is None def test_bad_query(self, cursor): @@ -387,9 +385,8 @@ def test_invalid_params(self, cursor): def test_open_close(self): with contextlib.closing(connect()): pass - with contextlib.closing(connect()) as conn: - with conn.cursor(): - pass + with contextlib.closing(connect()) as conn, conn.cursor(): + pass def test_unicode(self, cursor): unicode_str = "王兢" @@ -572,10 +569,12 @@ def test_cancel_initial(self, cursor): def test_multiple_connection(self): def execute_other_thread(): - with contextlib.closing(connect(schema_name=ENV.schema)) as conn: - with conn.cursor() as cursor: - cursor.execute("SELECT * FROM one_row") - return cursor.fetchall() + with ( + contextlib.closing(connect(schema_name=ENV.schema)) as conn, + conn.cursor() as cursor, + ): + cursor.execute("SELECT * FROM one_row") + return cursor.fetchall() with ThreadPoolExecutor(max_workers=2) as executor: fs = [executor.submit(execute_other_thread) for _ in range(2)] @@ -626,7 +625,7 @@ def test_executemany(self, cursor): # rowcount is not supported for executemany assert cursor.rowcount == -1 cursor.execute("SELECT * FROM execute_many") - assert sorted(cursor.fetchall()) == [(a, b) for a, b in rows] + assert sorted(cursor.fetchall()) == list(rows) def test_executemany_fetch(self, cursor): cursor.executemany("SELECT %(x)d FROM one_row", [{"x": i} for i in range(1, 2)]) diff --git a/tests/pyathena/test_model.py b/tests/pyathena/test_model.py index fb50f4cc..e6a20559 100644 --- a/tests/pyathena/test_model.py +++ b/tests/pyathena/test_model.py @@ -332,7 +332,7 @@ def test_init_json(self): "serialization.format": "1", "write.compression": "GZIP", } - for key in actual.table_properties.keys(): + for key in actual.table_properties: assert not key.startswith("serde.param.") assert actual.row_format == "SERDE 'org.openx.data.jsonserde.JsonSerDe'" assert ( @@ -353,7 +353,7 @@ def test_init_json_hcatalog(self): assert actual.serde_serialization_lib == "org.apache.hive.hcatalog.data.JsonSerDe" assert actual.compression == "SNAPPY" assert not actual.serde_properties - for key in actual.table_properties.keys(): + for key in actual.table_properties: assert not key.startswith("serde.param.") assert actual.row_format == "SERDE 'org.apache.hive.hcatalog.data.JsonSerDe'" assert ( @@ -383,7 +383,7 @@ def test_init_parquet(self): assert actual.serde_properties == { "serialization.format": "1", } - for key in actual.table_properties.keys(): + for key in actual.table_properties: assert not key.startswith("serde.param.") assert ( actual.row_format @@ -408,7 +408,7 @@ def test_init_orc(self): assert actual.serde_serialization_lib == "org.apache.hadoop.hive.ql.io.orc.OrcSerde" assert actual.compression == "SNAPPY" assert not actual.serde_properties - for key in actual.table_properties.keys(): + for key in actual.table_properties: assert not key.startswith("serde.param.") assert actual.row_format == "SERDE 'org.apache.hadoop.hive.ql.io.orc.OrcSerde'" assert ( @@ -429,7 +429,7 @@ def test_init_avro(self): assert actual.serde_serialization_lib == "org.apache.hadoop.hive.serde2.avro.AvroSerDe" assert actual.compression is None assert not actual.serde_properties - for key in actual.table_properties.keys(): + for key in actual.table_properties: assert not key.startswith("serde.param.") assert actual.row_format == "SERDE 'org.apache.hadoop.hive.serde2.avro.AvroSerDe'" assert ( diff --git a/tests/sqlalchemy/conftest.py b/tests/sqlalchemy/conftest.py index 52aa410a..f5ee2ca7 100644 --- a/tests/sqlalchemy/conftest.py +++ b/tests/sqlalchemy/conftest.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +import contextlib from urllib.parse import quote_plus from sqlalchemy.testing.plugin.pytestplugin import * # noqa @@ -32,11 +33,8 @@ def pytest_sessionstart(session): @create_db.for_db("awsathena") def _awsathena_create_db(cfg, eng, ident): - with eng.begin() as conn: - try: - _awsathena_drop_db(cfg, conn, ident) - except Exception: - pass + with eng.begin() as conn, contextlib.suppress(Exception): + _awsathena_drop_db(cfg, conn, ident) with eng.begin() as conn: conn.exec_driver_sql(f"CREATE DATABASE {ident}")