Description
When exporting only a part of your PyTorch model (like a single layer) with torch.export, it’s not always clear what the dummy input should be, especially for complex submodules with custom forward signatures.
Below are two practical, robust ways to automatically capture the dummy input for any submodule during a real inference pass.
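For context, torch.export.export expects the module together with example positional args (and optional kwargs) that match its forward signature. Here is a minimal sketch with a standalone layer; the layer and shapes are purely illustrative:

import torch
import torch.nn as nn

# Exporting a standalone layer is straightforward: you control its inputs.
layer = nn.Linear(16, 32)
exported = torch.export.export(layer, (torch.randn(1, 16),))
print(exported)
# The hard part is a submodule buried inside a larger model, whose inputs
# you never construct yourself; that is what the two recipes below address.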
1. Forward Hook with with_kwargs=True (PyTorch 2.0+)
Register a forward hook on your submodule with with_kwargs=True to automatically collect both positional and keyword arguments during a real forward pass.
Example
import torch
import torch.nn as nn

# Illustrative model; substitute your own nn.Module. Any module with named
# submodules works the same way.
class TwoLayerNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(16, 32)
        self.layer2 = nn.Linear(32, 8)

    def forward(self, x):
        return self.layer2(torch.relu(self.layer1(x)))

model = TwoLayerNet()

captured_inputs = {}

def make_hook(name):
    def hook_fn(module, args, kwargs, output):
        captured_inputs[name] = {
            'args': args,
            'kwargs': kwargs if kwargs is not None else {},
            'output': output,
        }
    return hook_fn

handles = []
for name, module in [('layer1', model.layer1), ('layer2', model.layer2)]:
    handle = module.register_forward_hook(make_hook(name), with_kwargs=True)
    handles.append(handle)

# Run a normal forward pass
dummy_input = torch.randn(1, 16)
_ = model(dummy_input)

# Inspect captured inputs
for name in captured_inputs:
    print(f'[{name}]')
    print(' Positional:', [x.shape if isinstance(x, torch.Tensor) else x for x in captured_inputs[name]['args']])
    print(' Keyword:', {k: (v.shape if isinstance(v, torch.Tensor) else v) for k, v in captured_inputs[name]['kwargs'].items()})

# Remove hooks
for handle in handles:
    handle.remove()

2. Patch the Submodule’s Forward Method
import torch
import types

def record_inputs(self, *args, **kwargs):
    # Save both args and kwargs for later use (only the most recent call is kept)
    record_inputs.args = args
    record_inputs.kwargs = kwargs
    # Call the real forward so the rest of the model runs unchanged
    return self._real_forward(*args, **kwargs)

# Suppose you want to export model.layer2 ('model' is the example model above)
submodule = model.layer2

# Save the original forward
submodule._real_forward = submodule.forward

# Properly bind the recorder method to the submodule (preserves 'self')
submodule.forward = types.MethodType(record_inputs, submodule)

# Run a normal forward pass with a realistic input to the full model
some_real_input = torch.randn(1, 16)  # shape matches the example model above
_ = model(some_real_input)

# Now retrieve the captured input for the submodule
args = record_inputs.args
kwargs = record_inputs.kwargs
print("Captured dummy input for submodule:")
print("args:", [a.shape if isinstance(a, torch.Tensor) else type(a) for a in args])
print("kwargs:", {k: (v.shape if isinstance(v, torch.Tensor) else type(v)) for k, v in kwargs.items()})

# Restore the original forward (important!)
submodule.forward = submodule._real_forward
del submodule._real_forward

# Now use (args, kwargs) as the dummy input for torch.export
# Example:
# exported = torch.export.export(submodule, args, kwargs)

Why Use types.MethodType instead of just overriding the forward method?
Binding with types.MethodType ensures the patched forward method receives the correct self argument. It avoids subtle bugs, especially when the submodule’s forward accesses self attributes or methods, and it is more robust under complex inheritance or when the submodule’s forward is nontrivial.
By temporarily replacing the submodule’s forward with a types.MethodType-wrapped recorder, you can safely and reliably capture all input arguments for dummy input construction, making partial module export with torch.export straightforward.
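To see why the binding matters, here is a small self-contained sketch (the recorder function is hypothetical, mirroring the recipe above). Assigning a bare function to an instance attribute bypasses Python’s descriptor protocol, so self is never bound; types.MethodType binds it explicitly:

import types
import torch
import torch.nn as nn

def recorder(self, *args, **kwargs):
    recorder.args = args  # stash the inputs, as in the recipe above
    return self._real_forward(*args, **kwargs)

layer = nn.Linear(4, 4)
layer._real_forward = layer.forward

# Plain assignment stores a bare function in the instance __dict__; the
# descriptor protocol that normally binds 'self' only applies to class
# attributes, so here the input tensor is passed as 'self' instead.
layer.forward = recorder
try:
    layer(torch.randn(1, 4))
except AttributeError as e:
    print("bare function fails:", e)  # Tensor has no '_real_forward'

# types.MethodType binds the function to the instance explicitly,
# so 'self' is the layer and '_real_forward' resolves as intended.
layer.forward = types.MethodType(recorder, layer)
out = layer(torch.randn(1, 4))
print("bound method works:", out.shape, [a.shape for a in recorder.args])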
Which Should You Use?
For maximum compatibility (e.g., PyTorch <2.0 or very simple cases), use forward patching. For modern codebases (PyTorch 2.0+), multiple submodules, or less intrusive instrumentation, use forward hooks with with_kwargs=True.
Both methods allow you to programmatically extract the correct dummy input for partial model export, even for submodules with complex signatures.
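As an end-to-end sketch, here is the hook recipe feeding torch.export.export, reusing the illustrative TwoLayerNet model from the first example (the shapes are assumptions):

import torch

captured = {}

def hook_fn(module, args, kwargs, output):
    captured['args'] = args
    captured['kwargs'] = kwargs or {}

# Capture the inputs model.layer2 actually receives during real inference
handle = model.layer2.register_forward_hook(hook_fn, with_kwargs=True)
with torch.no_grad():  # keep the captured tensors free of autograd state
    _ = model(torch.randn(1, 16))
handle.remove()

# Export just layer2, using the captured (args, kwargs) as the dummy input
exported = torch.export.export(model.layer2, captured['args'], captured['kwargs'])
print(exported)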