
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow #115

@MHGL

🐛 Bug

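I get an error when both of the following are true:

  • the forward method contains an if-else whose condition depends on the input (here, x.size(1))
  • I call torch.quantization.quantize_fx.prepare_fx on the model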

To Reproduce

Steps to reproduce the behavior:

  1. Code example:
import torch
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx

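# define a module whose forward() branches on the input's shape
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        ...

    def forward(self, x):
        # during symbolic tracing, x.size(1) != 3 evaluates to a Proxy, not a bool,
        # so the `if` calls Proxy.__bool__, which raises TraceError
        if x.size(1) != 3:
            return
        return

torch_model = MyModule().eval()

# FX graph mode quantization: prepare_fx symbolically traces the model
s_qconfig_dict = {'': get_default_qconfig("fbgemm")}
prepare_fx(torch_model, s_qconfig_dict)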
  2. Stack trace:
Traceback (most recent call last):
  File "mini_code.py", line 22, in <module>
    prepare_fx(torch_model, s_qconfig_dict)
  File "/opt/conda/lib/python3.8/site-packages/torch/quantization/quantize_fx.py", line 392, in prepare_fx
    return _prepare_fx(model, qconfig_dict, prepare_custom_config_dict)
  File "/opt/conda/lib/python3.8/site-packages/torch/quantization/quantize_fx.py", line 174, in _prepare_fx
    graph_module = GraphModule(model, tracer.trace(model))
  File "/opt/conda/lib/python3.8/site-packages/torch/fx/symbolic_trace.py", line 571, in trace
    self.create_node('output', 'output', (self.create_arg(fn(*args)),), {},
  File "mini_code.py", line 14, in forward
    if x.size(1) != 3:
  File "/opt/conda/lib/python3.8/site-packages/torch/fx/proxy.py", line 199, in __bool__
    return self.tracer.to_bool(self)
  File "/opt/conda/lib/python3.8/site-packages/torch/fx/proxy.py", line 129, in to_bool
    raise TraceError('symbolically traced variables cannot be used as inputs to control flow')
torch.fx.proxy.TraceError: symbolically traced variables cannot be used as inputs to control flow
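The traceback ends in `Proxy.__bool__` delegating to `tracer.to_bool`, which raises. The pure-Python sketch below mimics that path with a toy tracer, assuming a simplified model of symbolic tracing: the traced function receives a placeholder object instead of a real tensor, each operation on it is recorded as a graph node, and there is no concrete value to hand back when Python's `if` asks for a boolean. All names here (`Proxy`, `symbolic_trace`, `forward_without_branch`) are hypothetical stand-ins, not torch.fx's actual implementation.

```python
# Toy illustration (hypothetical names, NOT torch.fx's code) of why
# symbolic tracing rejects control flow that depends on traced values.


class TraceError(Exception):
    pass


class Proxy:
    """Placeholder for a tensor during tracing; records operations instead of computing."""

    def __init__(self, name, graph):
        self.name = name
        self.graph = graph

    def _record(self, op, *args):
        # Append a node (result_name, op, input_name, args) and return a Proxy for the result.
        node = f"%{len(self.graph)}"
        self.graph.append((node, op, self.name, args))
        return Proxy(node, self.graph)

    def size(self, dim):
        return self._record("size", dim)

    def __ne__(self, other):
        # Comparisons are recorded as nodes too, so `x.size(1) != 3` is still a Proxy.
        return self._record("ne", other)

    def __bool__(self):
        # An `if` needs a concrete True/False, but the tracer only has a symbolic value.
        raise TraceError(
            "symbolically traced variables cannot be used as inputs to control flow"
        )


def symbolic_trace(fn):
    """Run fn on a Proxy and return the list of recorded graph nodes."""
    graph = []
    fn(Proxy("x", graph))
    return graph


def forward(x):
    if x.size(1) != 3:  # bool(Proxy) is evaluated here and raises
        return None
    return None


def forward_without_branch(x):
    # The comparison is only recorded, never converted to bool.
    return x.size(1) != 3
```

Tracing `forward_without_branch` records a `size` node followed by a `ne` node, while tracing `forward` fails at the `if`, mirroring the `if x.size(1) != 3:` frame in the traceback.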

Expected behavior

Environment

  • PyTorch Version: 1.9.0
  • OS (e.g., MacOS, Linux): Ubuntu 20.04 LTS
  • How you installed Python (anaconda, virtualenv, system): miniconda
  • Python version (e.g. 3.7): 3.8.5
  • Any other relevant information:
    • GPU: GeForce GTX 1650
    • Driver version: 460.80
    • CUDA version: 11.2