diff --git a/test/quantization/pass/test_propagate_quant_param.py b/test/quantization/pass/test_propagate_quant_param.py index 4fddfcc4..e0ad6537 100644 --- a/test/quantization/pass/test_propagate_quant_param.py +++ b/test/quantization/pass/test_propagate_quant_param.py @@ -260,3 +260,21 @@ def test_s16_different_scale(self): # The test will check cat's scale is 1.0, the larger one self.run_test() + + +class ExpandModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x.expand(5, 3) + + def get_example_inputs(self): + return (torch.randn(1, 3),), {} + + +class ExpandTest(SingleOpPropagateQParamForwardTest): + # TODO Support u8 + def test_s16(self): + self.setup(ExpandModule(), torch.ops.aten.expand.default, dtype="int16") + self.run_test() diff --git a/tico/quantization/passes/propagate_qparam_forward.py b/tico/quantization/passes/propagate_qparam_forward.py index eb3c8bee..887b4b56 100644 --- a/tico/quantization/passes/propagate_qparam_forward.py +++ b/tico/quantization/passes/propagate_qparam_forward.py @@ -27,6 +27,7 @@ from tico.utils.trace_decorators import trace_graph_diff_on_pass from tico.utils.validate_args_kwargs import ( CatArgs, + ExpandArgs, NegArgs, PermuteArgs, ReshapeArgs, @@ -130,7 +131,9 @@ def _propagate_qparam_if_possible(src: torch.fx.Node, dst: torch.fx.Node): assert max_scale_node is not None _propagate_qparam_if_possible(max_scale_node, node) - + elif node.target == torch.ops.aten.expand.default: + expand_args = ExpandArgs(*node.args, **node.kwargs) + _propagate_qparam_if_possible(expand_args.input, node) # TODO Support more ops. graph.eliminate_dead_code()