Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions test/quantization/pass/test_propagate_quant_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,3 +260,21 @@ def test_s16_different_scale(self):

# The test will check cat's scale is 1.0, the larger one
self.run_test()


class ExpandModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return x.expand(5, 3)

def get_example_inputs(self):
return (torch.randn(1, 3),), {}


class ExpandTest(SingleOpPropagateQParamForwardTest):
# TODO Support u8
def test_s16(self):
self.setup(ExpandModule(), torch.ops.aten.expand.default, dtype="int16")
self.run_test()
5 changes: 4 additions & 1 deletion tico/quantization/passes/propagate_qparam_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from tico.utils.trace_decorators import trace_graph_diff_on_pass
from tico.utils.validate_args_kwargs import (
CatArgs,
ExpandArgs,
NegArgs,
PermuteArgs,
ReshapeArgs,
Expand Down Expand Up @@ -130,7 +131,9 @@ def _propagate_qparam_if_possible(src: torch.fx.Node, dst: torch.fx.Node):

assert max_scale_node is not None
_propagate_qparam_if_possible(max_scale_node, node)

elif node.target == torch.ops.aten.expand.default:
expand_args = ExpandArgs(*node.args, **node.kwargs)
_propagate_qparam_if_possible(expand_args.input, node)
# TODO Support more ops.

graph.eliminate_dead_code()
Expand Down