Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions .basedpyright/baseline.json
Original file line number Diff line number Diff line change
Expand Up @@ -6367,11 +6367,19 @@
"lineCount": 1
}
},
{
"code": "reportReturnType",
"range": {
"startColumn": 19,
"endColumn": 23,
"lineCount": 1
}
},
{
"code": "reportReturnType",
"range": {
"startColumn": 15,
"endColumn": 67,
"endColumn": 23,
"lineCount": 1
}
},
Expand All @@ -6387,16 +6395,16 @@
"code": "reportReturnType",
"range": {
"startColumn": 15,
"endColumn": 21,
"endColumn": 42,
"lineCount": 1
}
},
{
"code": "reportReturnType",
"range": {
"startColumn": 15,
"endColumn": 71,
"lineCount": 2
"endColumn": 42,
"lineCount": 1
}
},
{
Expand Down
80 changes: 48 additions & 32 deletions pymbolic/mapper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@
else:
return self.handle_unsupported_expression(expr, *args, **kwargs)
else:
return self.map_foreign(expr, *args, **kwargs)

Check warning on line 208 in pymbolic/mapper/__init__.py

View workflow job for this annotation

GitHub Actions / Pytest on Py3.10

List found in expression graph. This is deprecated and will stop working in 2025. Use tuples instead.

Check warning on line 208 in pymbolic/mapper/__init__.py

View workflow job for this annotation

GitHub Actions / Pytest on Py3.10

List found in expression graph. This is deprecated and will stop working in 2025. Use tuples instead.

Check warning on line 208 in pymbolic/mapper/__init__.py

View workflow job for this annotation

GitHub Actions / Pytest on Py3.10

List found in expression graph. This is deprecated and will stop working in 2025. Use tuples instead.

Check warning on line 208 in pymbolic/mapper/__init__.py

View workflow job for this annotation

GitHub Actions / Pytest on Py3.10

List found in expression graph. This is deprecated and will stop working in 2025. Use tuples instead.

rec = __call__
"""Identical to :meth:`__call__`, but intended for use in recursive dispatch
Expand Down Expand Up @@ -846,9 +846,8 @@
self.rec(child, *args, **kwargs) for child in expr.parameters
])
if (function is expr.function
and all(child is orig_child
for child, orig_child in zip(
expr.parameters, parameters, strict=True))):
and all(child is orig_child for child, orig_child in
zip(expr.parameters, parameters, strict=True))):
return expr

return type(expr)(function, parameters)
Expand All @@ -866,11 +865,12 @@
for key, val in expr.kw_parameters.items()})

if (function is expr.function
and all(child is orig_child for child, orig_child in
zip(parameters, expr.parameters, strict=True))
and all(child is orig_child for child, orig_child in
zip(parameters, expr.parameters, strict=True))
and all(kw_parameters[k] is v for k, v in
expr.kw_parameters.items())):
return expr

return type(expr)(function, parameters, kw_parameters)

@override
Expand All @@ -897,8 +897,8 @@
expr: p.Sum, /, *args: P.args, **kwargs: P.kwargs
) -> Expression:
children = [self.rec_arith(child, *args, **kwargs) for child in expr.children]
if all(child is orig_child
for child, orig_child in zip(children, expr.children, strict=True)):
if all(child is orig_child for child, orig_child in
zip(children, expr.children, strict=True)):
return expr

return type(expr)(tuple(children))
Expand All @@ -908,8 +908,8 @@
expr: p.Product, /, *args: P.args, **kwargs: P.kwargs
) -> Expression:
children = [self.rec_arith(child, *args, **kwargs) for child in expr.children]
if all(child is orig_child
for child, orig_child in zip(children, expr.children, strict=True)):
if all(child is orig_child for child, orig_child in
zip(children, expr.children, strict=True)):
return expr

return type(expr)(tuple(children))
Expand All @@ -922,7 +922,7 @@
denominator = self.rec_arith(expr.denominator, *args, **kwargs)
if numerator is expr.numerator and denominator is expr.denominator:
return expr
return expr.__class__(numerator, denominator)
return type(expr)(numerator, denominator)

@override
def map_floor_div(self,
Expand All @@ -932,7 +932,7 @@
denominator = self.rec_arith(expr.denominator, *args, **kwargs)
if numerator is expr.numerator and denominator is expr.denominator:
return expr
return expr.__class__(numerator, denominator)
return type(expr)(numerator, denominator)

@override
def map_remainder(self,
Expand All @@ -942,7 +942,7 @@
denominator = self.rec_arith(expr.denominator, *args, **kwargs)
if numerator is expr.numerator and denominator is expr.denominator:
return expr
return expr.__class__(numerator, denominator)
return type(expr)(numerator, denominator)

@override
def map_power(self,
Expand All @@ -952,7 +952,7 @@
exponent = self.rec_arith(expr.exponent, *args, **kwargs)
if base is expr.base and exponent is expr.exponent:
return expr
return expr.__class__(base, exponent)
return type(expr)(base, exponent)

@override
def map_left_shift(self,
Expand Down Expand Up @@ -1062,17 +1062,19 @@
def map_list(self,
expr: list[Expression], /, *args: P.args, **kwargs: P.kwargs
) -> Expression:
children = [self.rec(child, *args, **kwargs) for child in expr]
if all(a is b for a, b in zip(children, expr, strict=True)):
return expr

# True fact: lists aren't expressions
return [self.rec(child, *args, **kwargs) for child in expr]
return children

@override
def map_tuple(self,
expr: tuple[Expression, ...], /, *args: P.args, **kwargs: P.kwargs
) -> Expression:
children = [self.rec(child, *args, **kwargs) for child in expr]
if all(child is orig_child
for child, orig_child in zip(children, expr, strict=True)):
if all(a is b for a, b in zip(children, expr, strict=True)):
return expr

return tuple(children)
Expand All @@ -1081,22 +1083,36 @@
def map_numpy_array(self,
expr: NDArray[np.generic], /, *args: P.args, **kwargs: P.kwargs
) -> Expression:
import numpy as np

import numpy
result = numpy.empty(expr.shape, dtype=object)
result = np.empty(expr.shape, dtype=object)
is_same = True
for i in ndindex(expr.shape):
result[i] = self.rec(expr[i], *args, **kwargs)
is_same = is_same and result[i] is expr[i]

return result
# True fact: ndarrays aren't expressions
return expr if is_same else result

@override
def map_multivector(self,
expr: MultiVector[ArithmeticExpression], /,
*args: P.args, **kwargs: P.kwargs
) -> Expression:
is_same = True

def rec(ch: ArithmeticExpression) -> ArithmeticExpression:
nonlocal is_same

result = self.rec_arith(ch, *args, **kwargs)
is_same = is_same and result is ch

return result

result = expr.map(rec)

# True fact: MultiVectors aren't expressions
return expr.map(lambda ch: cast("ArithmeticExpression",
self.rec(ch, *args, **kwargs)))
return expr if is_same else result

def map_common_subexpression(self,
expr: p.CommonSubexpression, /,
Expand All @@ -1117,8 +1133,9 @@
*args: P.args, **kwargs: P.kwargs) -> Expression:
child = self.rec(expr.child, *args, **kwargs)
values = tuple([self.rec(v, *args, **kwargs) for v in expr.values])
if child is expr.child and all(val is orig_val
for val, orig_val in zip(values, expr.values, strict=True)):
if (child is expr.child
and all(val is orig_val for val, orig_val in
zip(values, expr.values, strict=True))):
return expr

return type(expr)(child, expr.variables, values)
Expand All @@ -1141,8 +1158,8 @@
None if child is None else self.rec(child, *args, **kwargs)
for child in expr.children
]))
if all(child is orig_child
for child, orig_child in zip(children, expr.children, strict=True)):
if all(child is orig_child for child, orig_child in
zip(children, expr.children, strict=True)):
return expr

return type(expr)(children)
Expand All @@ -1152,9 +1169,8 @@
condition = self.rec(expr.condition, *args, **kwargs)
then = self.rec(expr.then, *args, **kwargs)
else_ = self.rec(expr.else_, *args, **kwargs)
if condition is expr.condition \
and then is expr.then \
and else_ is expr.else_:

if condition is expr.condition and then is expr.then and else_ is expr.else_:
return expr

return type(expr)(condition, then, else_)
Expand All @@ -1165,8 +1181,8 @@
children = tuple([
self.rec(child, *args, **kwargs) for child in expr.children
])
if all(child is orig_child
for child, orig_child in zip(children, expr.children, strict=True)):
if all(child is orig_child for child, orig_child in
zip(children, expr.children, strict=True)):
return expr

return type(expr)(children)
Expand All @@ -1177,8 +1193,8 @@
children = tuple([
self.rec(child, *args, **kwargs) for child in expr.children
])
if all(child is orig_child
for child, orig_child in zip(children, expr.children, strict=True)):
if all(child is orig_child for child, orig_child in
zip(children, expr.children, strict=True)):
return expr

return type(expr)(children)
Expand Down
Loading