diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index b5a86728..534ed9fb 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -1851,14 +1851,6 @@ } ], "./arraycontext/context.py": [ - { - "code": "reportDeprecated", - "range": { - "startColumn": 69, - "endColumn": 74, - "lineCount": 1 - } - }, { "code": "reportAny", "range": { @@ -1867,14 +1859,6 @@ "lineCount": 1 } }, - { - "code": "reportDeprecated", - "range": { - "startColumn": 27, - "endColumn": 32, - "lineCount": 1 - } - }, { "code": "reportAny", "range": { @@ -17009,6 +16993,14 @@ "lineCount": 1 } }, + { + "code": "reportInvalidCast", + "range": { + "startColumn": 15, + "endColumn": 45, + "lineCount": 1 + } + }, { "code": "reportUnknownParameterType", "range": { @@ -19761,6 +19753,46 @@ "lineCount": 1 } }, + { + "code": "reportArgumentType", + "range": { + "startColumn": 54, + "endColumn": 63, + "lineCount": 1 + } + }, + { + "code": "reportArgumentType", + "range": { + "startColumn": 54, + "endColumn": 63, + "lineCount": 1 + } + }, + { + "code": "reportArgumentType", + "range": { + "startColumn": 54, + "endColumn": 63, + "lineCount": 1 + } + }, + { + "code": "reportArgumentType", + "range": { + "startColumn": 54, + "endColumn": 63, + "lineCount": 1 + } + }, + { + "code": "reportArgumentType", + "range": { + "startColumn": 54, + "endColumn": 63, + "lineCount": 1 + } + }, { "code": "reportIndexIssue", "range": { @@ -20265,6 +20297,14 @@ "lineCount": 1 } }, + { + "code": "reportOperatorIssue", + "range": { + "startColumn": 21, + "endColumn": 28, + "lineCount": 1 + } + }, { "code": "reportUnknownLambdaType", "range": { @@ -20313,6 +20353,22 @@ "lineCount": 1 } }, + { + "code": "reportOperatorIssue", + "range": { + "startColumn": 8, + "endColumn": 53, + "lineCount": 1 + } + }, + { + "code": "reportOperatorIssue", + "range": { + "startColumn": 8, + "endColumn": 32, + "lineCount": 1 + } + }, { "code": "reportUnusedExpression", "range": { @@ -20321,6 +20377,14 @@ "lineCount": 1 } }, + { + "code": "reportOperatorIssue", + "range": { + "startColumn": 8, + "endColumn": 27, + "lineCount": 1 + } + }, { "code": "reportUnusedExpression", "range": { @@ -20377,6 +20441,22 @@ "lineCount": 1 } }, + { + "code": "reportOperatorIssue", + "range": { + "startColumn": 36, + "endColumn": 55, + "lineCount": 1 + } + }, + { + "code": "reportOperatorIssue", + "range": { + "startColumn": 32, + "endColumn": 51, + "lineCount": 1 + } + }, { "code": "reportMissingParameterType", "range": { @@ -20601,6 +20681,14 @@ "lineCount": 1 } }, + { + "code": "reportOperatorIssue", + "range": { + "startColumn": 12, + "endColumn": 21, + "lineCount": 1 + } + }, { "code": "reportMissingParameterType", "range": { @@ -21063,22 +21151,6 @@ "lineCount": 1 } }, - { - "code": "reportGeneralTypeIssues", - "range": { - "startColumn": 10, - "endColumn": 31, - "lineCount": 1 - } - }, - { - "code": "reportGeneralTypeIssues", - "range": { - "startColumn": 14, - "endColumn": 35, - "lineCount": 1 - } - }, { "code": "reportUnannotatedClassAttribute", "range": { @@ -21087,22 +21159,6 @@ "lineCount": 1 } }, - { - "code": "reportGeneralTypeIssues", - "range": { - "startColumn": 10, - "endColumn": 31, - "lineCount": 1 - } - }, - { - "code": "reportGeneralTypeIssues", - "range": { - "startColumn": 14, - "endColumn": 35, - "lineCount": 1 - } - }, { "code": "reportUnannotatedClassAttribute", "range": { diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index 87313fd9..23495c8f 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -8,6 +8,13 @@ .. autofunction:: with_container_arithmetic .. autoclass:: BcastUntilActxArray + +References +---------- + +.. class:: TypeT + + A type variable with an upper bound of :class:`type`. """ @@ -62,6 +69,8 @@ T = TypeVar("T") +TypeT = TypeVar("TypeT", bound=type) + @enum.unique class _OpClass(enum.Enum): @@ -190,7 +199,7 @@ def with_container_arithmetic( bcast_numpy_array: bool = False, _bcast_actx_array_type: bool | None = None, bcast_container_types: tuple[type, ...] | None = None, - ) -> Callable[[type], type]: + ) -> Callable[[TypeT], TypeT]: """A class decorator that implements built-in operators for array containers by propagating the operations to the elements of the container. diff --git a/arraycontext/context.py b/arraycontext/context.py index d064392d..61e338ab 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -171,12 +171,20 @@ THE SOFTWARE. """ + from abc import ABC, abstractmethod from collections.abc import Callable, Mapping -from typing import TYPE_CHECKING, Any, Protocol, TypeAlias, TypeVar, Union, overload +from typing import ( + TYPE_CHECKING, + Any, + Literal, + Protocol, + TypeAlias, + TypeVar, + overload, +) from warnings import warn -import numpy as np from typing_extensions import Self from pymbolic.typing import Integer, Scalar as _Scalar @@ -184,6 +192,9 @@ if TYPE_CHECKING: + import numpy as np + from numpy.typing import DTypeLike + import loopy from pytools.tag import ToTagSetConvertible @@ -243,6 +254,21 @@ def __rpow__(self, other: Self | ScalarLike) -> Array: ... def __truediv__(self, other: Self | ScalarLike) -> Array: ... def __rtruediv__(self, other: Self | ScalarLike) -> Array: ... + def copy(self) -> Self: ... + + @property + def real(self) -> Array: ... + @property + def imag(self) -> Array: ... + def conj(self) -> Array: ... + + def astype(self, dtype: DTypeLike) -> Array: ... + + def reshape(self, + *shape: int, + order: Literal["C"] | Literal["F"] + ) -> Array: ... + # deprecated, use ScalarLike instead Scalar = _Scalar @@ -287,7 +313,7 @@ def __rtruediv__(self, other: Self | ScalarLike) -> Array: ... ContainerOrScalarT = TypeVar("ContainerOrScalarT", bound="ArrayContainer | ScalarLike") -NumpyOrContainerOrScalar = Union[np.ndarray, "ArrayContainer", ScalarLike] +NumpyOrContainerOrScalar: TypeAlias = "np.ndarray | ArrayContainer | ScalarLike" # }}} @@ -476,7 +502,7 @@ def tag(self, @abstractmethod def tag_axis(self, iaxis: int, tags: ToTagSetConvertible, - array: ArrayOrContainerOrScalarT) -> ArrayOrContainerOrScalarT: + array: ArrayOrContainerT) -> ArrayOrContainerT: """If the array type used by the array context is capable of capturing metadata, return a version of *array* in which axis number *iaxis* has the *tags* applied. *array* itself is not modified. When working with @@ -623,7 +649,7 @@ def permits_advanced_indexing(self) -> bool: def tag_axes( actx: ArrayContext, dim_to_tags: Mapping[int, ToTagSetConvertible], - ary: ArrayT) -> ArrayT: + ary: ArrayOrContainerT) -> ArrayOrContainerT: """ Return a copy of *ary* with the axes in *dim_to_tags* tagged with their corresponding tags. Equivalent to repeated application of