Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions pyathena/aio/sqlalchemy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ def fetchall(self) -> Any:
def setinputsizes(self, sizes: Any) -> None:
self._cursor.setinputsizes(sizes)

async def _async_soft_close(self) -> None:
return

# PyAthena-specific methods used by AthenaDialect reflection
def list_databases(self, *args: Any, **kwargs: Any) -> Any:
return await_only(self._cursor.list_databases(*args, **kwargs))
Expand Down Expand Up @@ -122,11 +125,11 @@ class AsyncAdapt_pyathena_connection(AdaptedConnection): # noqa: N801 - follows

def __init__(self, dbapi: "AsyncAdapt_pyathena_dbapi", connection: AioConnection) -> None:
self.dbapi = dbapi
self._connection = connection
self._connection = connection # type: ignore[assignment]

@property
def driver_connection(self) -> AioConnection:
return self._connection # type: ignore[no-any-return]
return self._connection # type: ignore[return-value]

@property
def catalog_name(self) -> Optional[str]:
Expand All @@ -144,7 +147,7 @@ def close(self) -> None:
self._connection.close()

def commit(self) -> None:
self._connection.commit()
self._connection.commit() # type: ignore[unused-coroutine]

def rollback(self) -> None:
pass
Expand Down
134 changes: 65 additions & 69 deletions pyathena/sqlalchemy/compiler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union, cast
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast

from sqlalchemy import exc, types, util
from sqlalchemy.sql.compiler import (
Expand Down Expand Up @@ -31,12 +31,12 @@
UniqueConstraint,
)
from sqlalchemy.sql.ddl import CreateTable
from sqlalchemy.sql.elements import FunctionElement
from sqlalchemy.sql.functions import Function
from sqlalchemy.sql.selectable import GenerativeSelect

from pyathena.sqlalchemy.base import AthenaDialect

_DialectArgDict = Dict[str, Any]
_DialectArgDict = Mapping[str, Any]
CreateColumn = Any


Expand All @@ -61,10 +61,10 @@ class AthenaTypeCompiler(GenericTypeCompiler):
https://docs.aws.amazon.com/athena/latest/ug/data-types.html
"""

def visit_FLOAT(self, type_: Type[Any], **kw) -> str: # noqa: N802
return self.visit_REAL(type_, **kw)
def visit_FLOAT(self, type_: types.Float[Any], **kw: Any) -> str: # noqa: N802
return self.visit_REAL(type_, **kw) # type: ignore[arg-type]

def visit_REAL(self, type_: Type[Any], **kw) -> str: # noqa: N802
def visit_REAL(self, type_: types.REAL[Any], **kw: Any) -> str: # noqa: N802
return "FLOAT"

def visit_DOUBLE(self, type_, **kw) -> str: # noqa: N802
Expand All @@ -73,78 +73,78 @@ def visit_DOUBLE(self, type_, **kw) -> str: # noqa: N802
def visit_DOUBLE_PRECISION(self, type_, **kw) -> str: # noqa: N802
return "DOUBLE"

def visit_NUMERIC(self, type_: Type[Any], **kw) -> str: # noqa: N802
return self.visit_DECIMAL(type_, **kw)
def visit_NUMERIC(self, type_: types.Numeric[Any], **kw: Any) -> str: # noqa: N802
return self.visit_DECIMAL(type_, **kw) # type: ignore[arg-type]

def visit_DECIMAL(self, type_: Type[Any], **kw) -> str: # noqa: N802
def visit_DECIMAL(self, type_: types.DECIMAL[Any], **kw: Any) -> str: # noqa: N802
if type_.precision is None:
return "DECIMAL"
if type_.scale is None:
return f"DECIMAL({type_.precision})"
return f"DECIMAL({type_.precision}, {type_.scale})"

def visit_TINYINT(self, type_: Type[Any], **kw) -> str: # noqa: N802
def visit_TINYINT(self, type_: types.Integer, **kw: Any) -> str: # noqa: N802
return "TINYINT"

def visit_INTEGER(self, type_: Type[Any], **kw) -> str: # noqa: N802
def visit_INTEGER(self, type_: types.Integer, **kw: Any) -> str: # noqa: N802
return "INTEGER"

def visit_SMALLINT(self, type_: Type[Any], **kw) -> str: # noqa: N802
def visit_SMALLINT(self, type_: types.SmallInteger, **kw: Any) -> str: # noqa: N802
return "SMALLINT"

def visit_BIGINT(self, type_: Type[Any], **kw) -> str: # noqa: N802
def visit_BIGINT(self, type_: types.BigInteger, **kw: Any) -> str: # noqa: N802
return "BIGINT"

def visit_TIMESTAMP(self, type_: Type[Any], **kw) -> str: # noqa: N802
def visit_TIMESTAMP(self, type_: types.TIMESTAMP, **kw: Any) -> str: # noqa: N802
return "TIMESTAMP"

def visit_DATETIME(self, type_: Type[Any], **kw) -> str: # noqa: N802
return self.visit_TIMESTAMP(type_, **kw)
def visit_DATETIME(self, type_: types.DateTime, **kw: Any) -> str: # noqa: N802
return self.visit_TIMESTAMP(type_, **kw) # type: ignore[arg-type]

def visit_DATE(self, type_: Type[Any], **kw) -> str: # noqa: N802
def visit_DATE(self, type_: types.Date, **kw: Any) -> str: # noqa: N802
return "DATE"

def visit_TIME(self, type_: Type[Any], **kw) -> str: # noqa: N802
def visit_TIME(self, type_: types.Time, **kw: Any) -> str: # noqa: N802
raise exc.CompileError(f"Data type `{type_}` is not supported")

def visit_CLOB(self, type_: Type[Any], **kw) -> str: # noqa: N802
return self.visit_BINARY(type_, **kw)
def visit_CLOB(self, type_: types.CLOB, **kw: Any) -> str: # noqa: N802
return self.visit_BINARY(type_, **kw) # type: ignore[arg-type]

def visit_NCLOB(self, type_: Type[Any], **kw) -> str: # noqa: N802
return self.visit_BINARY(type_, **kw)
def visit_NCLOB(self, type_: types.Text, **kw: Any) -> str: # noqa: N802
return self.visit_BINARY(type_, **kw) # type: ignore[arg-type]

def visit_CHAR(self, type_: Type[Any], **kw) -> str: # noqa: N802
def visit_CHAR(self, type_: types.CHAR, **kw: Any) -> str: # noqa: N802
if type_.length:
return cast(str, self._render_string_type(type_, "CHAR"))
return self._render_string_type("CHAR", type_.length, type_.collation)
return "STRING"

def visit_NCHAR(self, type_: Type[Any], **kw) -> str: # noqa: N802
return self.visit_CHAR(type_, **kw)
def visit_NCHAR(self, type_: types.NCHAR, **kw: Any) -> str: # noqa: N802
return self.visit_CHAR(type_, **kw) # type: ignore[arg-type]

def visit_VARCHAR(self, type_: Type[Any], **kw) -> str: # noqa: N802
def visit_VARCHAR(self, type_: types.String, **kw: Any) -> str: # noqa: N802
if type_.length:
return cast(str, self._render_string_type(type_, "VARCHAR"))
return self._render_string_type("VARCHAR", type_.length, type_.collation)
return "STRING"

def visit_NVARCHAR(self, type_: Type[Any], **kw) -> str: # noqa: N802
return self.visit_VARCHAR(type_, **kw)
def visit_NVARCHAR(self, type_: types.NVARCHAR, **kw: Any) -> str: # noqa: N802
return self.visit_VARCHAR(type_, **kw) # type: ignore[arg-type]

def visit_TEXT(self, type_: Type[Any], **kw) -> str: # noqa: N802
def visit_TEXT(self, type_: types.Text, **kw: Any) -> str: # noqa: N802
return "STRING"

def visit_BLOB(self, type_: Type[Any], **kw) -> str: # noqa: N802
return self.visit_BINARY(type_, **kw)
def visit_BLOB(self, type_: types.LargeBinary, **kw: Any) -> str: # noqa: N802
return self.visit_BINARY(type_, **kw) # type: ignore[arg-type]

def visit_BINARY(self, type_: Type[Any], **kw) -> str: # noqa: N802
def visit_BINARY(self, type_: types.BINARY, **kw: Any) -> str: # noqa: N802
return "BINARY"

def visit_VARBINARY(self, type_: Type[Any], **kw) -> str: # noqa: N802
return self.visit_BINARY(type_, **kw)
def visit_VARBINARY(self, type_: types.VARBINARY, **kw: Any) -> str: # noqa: N802
return self.visit_BINARY(type_, **kw) # type: ignore[arg-type]

def visit_BOOLEAN(self, type_: Type[Any], **kw) -> str: # noqa: N802
def visit_BOOLEAN(self, type_: types.Boolean, **kw: Any) -> str: # noqa: N802
return "BOOLEAN"

def visit_JSON(self, type_: Type[Any], **kw) -> str: # noqa: N802
def visit_JSON(self, type_: types.JSON, **kw: Any) -> str: # noqa: N802
return "JSON"

def visit_string(self, type_, **kw): # noqa: N802
Expand Down Expand Up @@ -219,10 +219,10 @@ class AthenaStatementCompiler(SQLCompiler):
https://docs.aws.amazon.com/athena/latest/ug/ddl-sql-reference.html
"""

def visit_char_length_func(self, fn: "FunctionElement[Any]", **kw):
def visit_char_length_func(self, fn: "Function[Any]", **kw: Any) -> str:
return f"length{self.function_argspec(fn, **kw)}"

def visit_filter_func(self, fn: "FunctionElement[Any]", **kw) -> str:
def visit_filter_func(self, fn: "Function[Any]", **kw: Any) -> str:
"""Compile Athena filter() function with lambda expressions.

Supports syntax: filter(array_expr, lambda_expr)
Expand Down Expand Up @@ -370,7 +370,7 @@ def _get_comment_specification(self, comment: str) -> str:
return f"COMMENT {self._escape_comment(comment)}"

def _get_bucket_count(
self, dialect_opts: "_DialectArgDict", connect_opts: Dict[str, Any]
self, dialect_opts: "_DialectArgDict", connect_opts: Mapping[str, Any]
) -> Optional[str]:
if dialect_opts["bucket_count"]:
bucket_count = dialect_opts["bucket_count"]
Expand All @@ -381,7 +381,7 @@ def _get_bucket_count(
return cast(str, bucket_count) if bucket_count is not None else None

def _get_file_format(
self, dialect_opts: "_DialectArgDict", connect_opts: Dict[str, Any]
self, dialect_opts: "_DialectArgDict", connect_opts: Mapping[str, Any]
) -> Optional[str]:
if dialect_opts["file_format"]:
file_format = dialect_opts["file_format"]
Expand All @@ -392,7 +392,7 @@ def _get_file_format(
return cast(Optional[str], file_format)

def _get_file_format_specification(
self, dialect_opts: "_DialectArgDict", connect_opts: Dict[str, Any]
self, dialect_opts: "_DialectArgDict", connect_opts: Mapping[str, Any]
) -> str:
file_format = self._get_file_format(dialect_opts, connect_opts)
text = []
Expand All @@ -401,7 +401,7 @@ def _get_file_format_specification(
return "\n".join(text)

def _get_row_format(
self, dialect_opts: "_DialectArgDict", connect_opts: Dict[str, Any]
self, dialect_opts: "_DialectArgDict", connect_opts: Mapping[str, Any]
) -> Optional[str]:
if dialect_opts["row_format"]:
row_format = dialect_opts["row_format"]
Expand All @@ -412,7 +412,7 @@ def _get_row_format(
return cast(Optional[str], row_format)

def _get_row_format_specification(
self, dialect_opts: "_DialectArgDict", connect_opts: Dict[str, Any]
self, dialect_opts: "_DialectArgDict", connect_opts: Mapping[str, Any]
) -> str:
row_format = self._get_row_format(dialect_opts, connect_opts)
text = []
Expand All @@ -421,7 +421,7 @@ def _get_row_format_specification(
return "\n".join(text)

def _get_serde_properties(
self, dialect_opts: "_DialectArgDict", connect_opts: Dict[str, Any]
self, dialect_opts: "_DialectArgDict", connect_opts: Mapping[str, Any]
) -> Optional[Union[str, Dict[str, Any]]]:
if dialect_opts["serdeproperties"]:
serde_properties = dialect_opts["serdeproperties"]
Expand All @@ -432,7 +432,7 @@ def _get_serde_properties(
return cast(Optional[str], serde_properties)

def _get_serde_properties_specification(
self, dialect_opts: "_DialectArgDict", connect_opts: Dict[str, Any]
self, dialect_opts: "_DialectArgDict", connect_opts: Mapping[str, Any]
) -> str:
serde_properties = self._get_serde_properties(dialect_opts, connect_opts)
text = []
Expand All @@ -446,7 +446,7 @@ def _get_serde_properties_specification(
return "\n".join(text)

def _get_table_location(
self, table: "Table", dialect_opts: "_DialectArgDict", connect_opts: Dict[str, Any]
self, table: "Table", dialect_opts: "_DialectArgDict", connect_opts: Mapping[str, Any]
) -> Optional[str]:
if dialect_opts["location"]:
location = cast(str, dialect_opts["location"])
Expand All @@ -464,7 +464,7 @@ def _get_table_location(
return location

def _get_table_location_specification(
self, table: "Table", dialect_opts: "_DialectArgDict", connect_opts: Dict[str, Any]
self, table: "Table", dialect_opts: "_DialectArgDict", connect_opts: Mapping[str, Any]
) -> str:
location = self._get_table_location(table, dialect_opts, connect_opts)
text = []
Expand All @@ -482,7 +482,7 @@ def _get_table_location_specification(
return "\n".join(text)

def _get_table_properties(
self, dialect_opts: "_DialectArgDict", connect_opts: Dict[str, Any]
self, dialect_opts: "_DialectArgDict", connect_opts: Mapping[str, Any]
) -> Optional[Union[Dict[str, str], str]]:
if dialect_opts["tblproperties"]:
table_properties = cast(str, dialect_opts["tblproperties"])
Expand All @@ -493,7 +493,7 @@ def _get_table_properties(
return table_properties

def _get_compression(
self, dialect_opts: "_DialectArgDict", connect_opts: Dict[str, Any]
self, dialect_opts: "_DialectArgDict", connect_opts: Mapping[str, Any]
) -> Optional[str]:
if dialect_opts["compression"]:
compression = cast(str, dialect_opts["compression"])
Expand All @@ -504,7 +504,7 @@ def _get_compression(
return compression

def _get_table_properties_specification(
self, dialect_opts: "_DialectArgDict", connect_opts: Dict[str, Any]
self, dialect_opts: "_DialectArgDict", connect_opts: Mapping[str, Any]
) -> str:
properties = self._get_table_properties(dialect_opts, connect_opts)
if properties:
Expand Down Expand Up @@ -554,34 +554,30 @@ def get_column_specification(self, column: "Column[Any]", **kwargs) -> str:
text.append(f"{self._get_comment_specification(column.comment)}")
return " ".join(text)

def visit_check_constraint(self, constraint: "CheckConstraint", **kw) -> Optional[str]:
return None
def visit_check_constraint(self, constraint: "CheckConstraint", **kw: Any) -> str:
return ""

def visit_column_check_constraint(self, constraint: "CheckConstraint", **kw) -> Optional[str]:
return None
def visit_column_check_constraint(self, constraint: "CheckConstraint", **kw: Any) -> str:
return ""

def visit_foreign_key_constraint(
self, constraint: "ForeignKeyConstraint", **kw
) -> Optional[str]:
return None
def visit_foreign_key_constraint(self, constraint: "ForeignKeyConstraint", **kw: Any) -> str:
return ""

def visit_primary_key_constraint(
self, constraint: "PrimaryKeyConstraint", **kw
) -> Optional[str]:
return None
def visit_primary_key_constraint(self, constraint: "PrimaryKeyConstraint", **kw: Any) -> str:
return ""

def visit_unique_constraint(self, constraint: "UniqueConstraint", **kw) -> Optional[str]:
return None
def visit_unique_constraint(self, constraint: "UniqueConstraint", **kw: Any) -> str:
return ""

def _get_connect_option_partitions(self, connect_opts: Dict[str, Any]) -> List[str]:
def _get_connect_option_partitions(self, connect_opts: Mapping[str, Any]) -> List[str]:
if connect_opts:
partition = cast(str, connect_opts.get("partition"))
partitions = partition.split(",") if partition else []
else:
partitions = []
return partitions

def _get_connect_option_buckets(self, connect_opts: Dict[str, Any]) -> List[str]:
def _get_connect_option_buckets(self, connect_opts: Mapping[str, Any]) -> List[str]:
if connect_opts:
bucket = cast(str, connect_opts.get("cluster"))
buckets = bucket.split(",") if bucket else []
Expand Down Expand Up @@ -624,7 +620,7 @@ def _prepared_columns(
table: "Table",
is_iceberg: bool,
create_columns: List["CreateColumn"],
connect_opts: Dict[str, Any],
connect_opts: Mapping[str, Any],
) -> Tuple[List[str], List[str], List[str]]:
columns, partitions, buckets = [], [], []
conn_partitions = self._get_connect_option_partitions(connect_opts)
Expand Down
1 change: 0 additions & 1 deletion tests/sqlalchemy/test_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from sqlalchemy.testing.suite import TrueDivTest as _TrueDivTest

del BinaryTest # noqa
del BizarroCharacterFKResolutionTest # noqa
del ComponentReflectionTest # noqa
del ComponentReflectionTestExtra # noqa
del CompositeKeyReflectionTest # noqa
Expand Down
Loading