From c760c638f3b2acc829eca4420cc51d62dbd54c5a Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Sun, 7 Dec 2025 15:47:50 +0200 Subject: [PATCH 1/3] feat: return original expression if children unchanged --- pymbolic/mapper/__init__.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/pymbolic/mapper/__init__.py b/pymbolic/mapper/__init__.py index 68523e3..9971599 100644 --- a/pymbolic/mapper/__init__.py +++ b/pymbolic/mapper/__init__.py @@ -1062,9 +1062,12 @@ def map_comparison(self, def map_list(self, expr: list[Expression], /, *args: P.args, **kwargs: P.kwargs ) -> Expression: + children = [self.rec(child, *args, **kwargs) for child in expr] + if all(a is b for a, b in zip(children, expr, strict=True)): + return expr # True fact: lists aren't expressions - return [self.rec(child, *args, **kwargs) for child in expr] + return children @override def map_tuple(self, @@ -1081,22 +1084,35 @@ def map_tuple(self, def map_numpy_array(self, expr: NDArray[np.generic], /, *args: P.args, **kwargs: P.kwargs ) -> Expression: + import numpy as np - import numpy - result = numpy.empty(expr.shape, dtype=object) + result = np.empty(expr.shape, dtype=object) + is_same = True for i in ndindex(expr.shape): result[i] = self.rec(expr[i], *args, **kwargs) + is_same = is_same and result[i] is expr[i] - return result + return expr if is_same else result @override def map_multivector(self, expr: MultiVector[ArithmeticExpression], /, *args: P.args, **kwargs: P.kwargs ) -> Expression: + is_same = True + + def rec(ch: ArithmeticExpression) -> ArithmeticExpression: + nonlocal is_same + + result = self.rec_arith(ch, *args, **kwargs) + is_same = is_same and result is ch + + return result + + result = expr.map(rec) + # True fact: MultiVectors aren't expressions - return expr.map(lambda ch: cast("ArithmeticExpression", - self.rec(ch, *args, **kwargs))) + return expr if is_same else result def map_common_subexpression(self, expr: p.CommonSubexpression, /, From d1e2018c3c32c814ca22bc997d692e66cff0cc27 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Sun, 7 Dec 2025 15:48:05 +0200 Subject: [PATCH 2/3] chore: small formatting changes --- pymbolic/mapper/__init__.py | 52 ++++++++++++++++++------------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/pymbolic/mapper/__init__.py b/pymbolic/mapper/__init__.py index 9971599..eb8074d 100644 --- a/pymbolic/mapper/__init__.py +++ b/pymbolic/mapper/__init__.py @@ -846,9 +846,8 @@ def map_call(self, self.rec(child, *args, **kwargs) for child in expr.parameters ]) if (function is expr.function - and all(child is orig_child - for child, orig_child in zip( - expr.parameters, parameters, strict=True))): + and all(child is orig_child for child, orig_child in + zip(expr.parameters, parameters, strict=True))): return expr return type(expr)(function, parameters) @@ -866,11 +865,12 @@ def map_call_with_kwargs(self, for key, val in expr.kw_parameters.items()}) if (function is expr.function - and all(child is orig_child for child, orig_child in - zip(parameters, expr.parameters, strict=True)) + and all(child is orig_child for child, orig_child in + zip(parameters, expr.parameters, strict=True)) and all(kw_parameters[k] is v for k, v in expr.kw_parameters.items())): return expr + return type(expr)(function, parameters, kw_parameters) @override @@ -897,8 +897,8 @@ def map_sum(self, expr: p.Sum, /, *args: P.args, **kwargs: P.kwargs ) -> Expression: children = [self.rec_arith(child, *args, **kwargs) for child in expr.children] - if all(child is orig_child - for child, orig_child in zip(children, expr.children, strict=True)): + if all(child is orig_child for child, orig_child in + zip(children, expr.children, strict=True)): return expr return type(expr)(tuple(children)) @@ -908,8 +908,8 @@ def map_product(self, expr: p.Product, /, *args: P.args, **kwargs: P.kwargs ) -> Expression: children = [self.rec_arith(child, *args, **kwargs) for child in expr.children] - if all(child is orig_child - for child, orig_child in zip(children, expr.children, strict=True)): + if all(child is orig_child for child, orig_child in + zip(children, expr.children, strict=True)): return expr return type(expr)(tuple(children)) @@ -922,7 +922,7 @@ def map_quotient(self, denominator = self.rec_arith(expr.denominator, *args, **kwargs) if numerator is expr.numerator and denominator is expr.denominator: return expr - return expr.__class__(numerator, denominator) + return type(expr)(numerator, denominator) @override def map_floor_div(self, @@ -932,7 +932,7 @@ def map_floor_div(self, denominator = self.rec_arith(expr.denominator, *args, **kwargs) if numerator is expr.numerator and denominator is expr.denominator: return expr - return expr.__class__(numerator, denominator) + return type(expr)(numerator, denominator) @override def map_remainder(self, @@ -942,7 +942,7 @@ def map_remainder(self, denominator = self.rec_arith(expr.denominator, *args, **kwargs) if numerator is expr.numerator and denominator is expr.denominator: return expr - return expr.__class__(numerator, denominator) + return type(expr)(numerator, denominator) @override def map_power(self, @@ -952,7 +952,7 @@ def map_power(self, exponent = self.rec_arith(expr.exponent, *args, **kwargs) if base is expr.base and exponent is expr.exponent: return expr - return expr.__class__(base, exponent) + return type(expr)(base, exponent) @override def map_left_shift(self, @@ -1074,8 +1074,7 @@ def map_tuple(self, expr: tuple[Expression, ...], /, *args: P.args, **kwargs: P.kwargs ) -> Expression: children = [self.rec(child, *args, **kwargs) for child in expr] - if all(child is orig_child - for child, orig_child in zip(children, expr, strict=True)): + if all(a is b for a, b in zip(children, expr, strict=True)): return expr return tuple(children) @@ -1092,6 +1091,7 @@ def map_numpy_array(self, result[i] = self.rec(expr[i], *args, **kwargs) is_same = is_same and result[i] is expr[i] + # True fact: ndarrays aren't expressions return expr if is_same else result @override @@ -1133,8 +1133,9 @@ def map_substitution(self, *args: P.args, **kwargs: P.kwargs) -> Expression: child = self.rec(expr.child, *args, **kwargs) values = tuple([self.rec(v, *args, **kwargs) for v in expr.values]) - if child is expr.child and all(val is orig_val - for val, orig_val in zip(values, expr.values, strict=True)): + if (child is expr.child + and all(val is orig_val for val, orig_val in + zip(values, expr.values, strict=True))): return expr return type(expr)(child, expr.variables, values) @@ -1157,8 +1158,8 @@ def map_slice(self, None if child is None else self.rec(child, *args, **kwargs) for child in expr.children ])) - if all(child is orig_child - for child, orig_child in zip(children, expr.children, strict=True)): + if all(child is orig_child for child, orig_child in + zip(children, expr.children, strict=True)): return expr return type(expr)(children) @@ -1168,9 +1169,8 @@ def map_if(self, expr: p.If, /, *args: P.args, **kwargs: P.kwargs) -> Expression condition = self.rec(expr.condition, *args, **kwargs) then = self.rec(expr.then, *args, **kwargs) else_ = self.rec(expr.else_, *args, **kwargs) - if condition is expr.condition \ - and then is expr.then \ - and else_ is expr.else_: + + if condition is expr.condition and then is expr.then and else_ is expr.else_: return expr return type(expr)(condition, then, else_) @@ -1181,8 +1181,8 @@ def map_min(self, children = tuple([ self.rec(child, *args, **kwargs) for child in expr.children ]) - if all(child is orig_child - for child, orig_child in zip(children, expr.children, strict=True)): + if all(child is orig_child for child, orig_child in + zip(children, expr.children, strict=True)): return expr return type(expr)(children) @@ -1193,8 +1193,8 @@ def map_max(self, children = tuple([ self.rec(child, *args, **kwargs) for child in expr.children ]) - if all(child is orig_child - for child, orig_child in zip(children, expr.children, strict=True)): + if all(child is orig_child for child, orig_child in + zip(children, expr.children, strict=True)): return expr return type(expr)(children) From b4cd5559fc8961d30804b40eb1cef5ccb064f6ea Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Sun, 7 Dec 2025 15:49:22 +0200 Subject: [PATCH 3/3] chore: update baseline --- .basedpyright/baseline.json | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index 90f2dc9..02e6c87 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -6367,11 +6367,19 @@ "lineCount": 1 } }, + { + "code": "reportReturnType", + "range": { + "startColumn": 19, + "endColumn": 23, + "lineCount": 1 + } + }, { "code": "reportReturnType", "range": { "startColumn": 15, - "endColumn": 67, + "endColumn": 23, "lineCount": 1 } }, @@ -6387,7 +6395,7 @@ "code": "reportReturnType", "range": { "startColumn": 15, - "endColumn": 21, + "endColumn": 42, "lineCount": 1 } }, @@ -6395,8 +6403,8 @@ "code": "reportReturnType", "range": { "startColumn": 15, - "endColumn": 71, - "lineCount": 2 + "endColumn": 42, + "lineCount": 1 } }, {