From b0ffb2807804b0f2df122b513b0d05e71bb0f228 Mon Sep 17 00:00:00 2001 From: Christoph Berganski Date: Sat, 20 Dec 2025 21:03:08 +0100 Subject: [PATCH] [Rewriter] Extend list of supported commutative operations Signed-off-by: Christoph Berganski --- onnxscript/rewriter/_pattern_ir.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/onnxscript/rewriter/_pattern_ir.py b/onnxscript/rewriter/_pattern_ir.py index 9b81e33581..674d1fc593 100644 --- a/onnxscript/rewriter/_pattern_ir.py +++ b/onnxscript/rewriter/_pattern_ir.py @@ -899,12 +899,26 @@ def num_outputs(self) -> int: return len(self._outputs) def commute(self) -> Sequence[GraphPattern]: + # List all commutative elementwise (binary) operators for which we + # consider swapping the inputs + COMMUTATIVE_OPS = { + ("", "Add", ""), + ("", "Mul", ""), + ("", "And", ""), + ("", "Or", ""), + ("", "Xor", ""), + ("", "BitwiseAnd", ""), + ("", "BitwiseOr", ""), + ("", "BitwiseXor", ""), + ("", "Equal", ""), + ("", "Max", ""), + ("", "Mean", ""), + ("", "Min", ""), + ("", "Sum", ""), + } + def commute_node(node: NodePattern) -> Iterable[bool]: - if node.op_identifier() == ("", "Add", "") or node.op_identifier() == ( - "", - "Mul", - "", - ): + if node.op_identifier() in COMMUTATIVE_OPS: # Try with and without swapping inputs. return [False, True] # No swapping of inputs