[PROTOTYPE] generated batching rules for custom dispatcher ops #578
Conversation
bdhirsh commented on Mar 9, 2022
functorch/csrc/DynamicLayer.cpp
// doesn't play well with DynamicLayer (only one layer of vmap works right now).
// Why? In the generated batching rule, I effectively want to treat it as a "composite kernel",
// and have it run the python-defined forward function. But:
// (1) I want to go there through the dispatcher so other functionalities can run (e.g. AMP).
Contributor
Author
Actually, maybe I should just be treating this the same way that DynamicLayer already treats composite ops - just directly call into the composite function. That means that stuff like AMP will run on the base ops and not the composite ops, but maybe that's the right behavior (unless the user wants to write a custom "AMP rule" for their op)
Contributor
Author
(ended up doing this, although there's another issue with disabling autograd that I left a comment about)
Note: this PR has a bunch of issues, but it shows one potential way of getting "automatically generated" batching rules for custom ops registered to the dispatcher through python.
Let's say you have a custom operator (foo) in python that you've defined a derivative formula for (foo_vjp), and you want to vmap over it (a rough sketch of this setup is below). If I run that under vmap, the chain of calls in the dispatcher looks something like: your custom batching rule for foo, then foo's autograd kernel, then the backend kernel for foo.
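For concreteness, here is a minimal sketch of that setup using torch.library. The namespace ("myops"), the schema, and the sin/cos math are made up for illustration, and hooking foo_vjp up to autograd is elided; the PR itself hardcodes ops that take and return a single TensorList.

```python
import torch
from functorch import vmap

# Illustrative only: the namespace, schema, and math are placeholders.
lib = torch.library.Library("myops", "DEF")
lib.define("foo(Tensor[] xs) -> Tensor[]")

def foo_impl(xs):
    # Python-defined forward; it decomposes into ordinary aten ops.
    return [x.sin() for x in xs]

def foo_vjp(xs, grads):
    # User-provided derivative formula (how it gets registered to autograd
    # is elided here).
    return [g * x.cos() for g, x in zip(grads, xs)]

lib.impl("foo", foo_impl, "CPU")

# vmap over the custom op. This dispatches through a batching rule for foo,
# which today the user has to write by hand (or, with this PR, gets generated).
x = torch.randn(3, 4)
out = vmap(lambda t: torch.ops.myops.foo([t])[0])(x)
```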
That will work, but it has a downside: it requires you (the user) to write a custom batching rule for your custom op. In theory, we should be able to get the batching rule for free.
One way to "get the batching rule for free" is to run foo(), let it decompose into whatever aten ops it eventually calls, and run the batching rules on each of those aten ops. There's a problem with that, though: if we decompose foo when we run the batching rule, there's no way to "undo" the decomposition below it. Any kernels that we redispatch to will see the "base" ops instead of the original op.

How do we get around that? We can't really "undo" the decomposition inside of the call stack... but we could just run foo twice (see the sketch after the two points below): once for the forward pass, where we do decompose into the base ops and run the batching rule on each, and once for the backward pass, where we don't decompose, taking care that:
(1) When we run the forward, we skip autograd.
(2) When we run the autograd kernel (to set up the autograd graph), we don't redispatch and run the backend again.
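Here is a rough Python-level analogue of that "run it twice" idea, as a sketch only: the PR does this in C++ at the dispatcher level, and foo_forward / foo_vjp below stand in for the user's python-defined forward and derivative formula.

```python
import torch

def foo_forward(xs):
    # Stand-in for the user's python-defined forward (decomposes into aten ops).
    return [x.sin() for x in xs]

def foo_vjp(xs, grads):
    # Stand-in for the user's derivative formula.
    return [g * x.cos() for g, x in zip(grads, xs)]

class FooNode(torch.autograd.Function):
    @staticmethod
    def forward(ctx, precomputed, *xs):
        # Second run: only record the node in the autograd graph, reusing the
        # outputs from the first run instead of redispatching to the backend.
        ctx.save_for_backward(*xs)
        return tuple(precomputed)

    @staticmethod
    def backward(ctx, *grads):
        xs = ctx.saved_tensors
        # None for `precomputed`, then one grad per input tensor.
        return (None, *foo_vjp(xs, grads))

def run_foo(xs):
    # First run: the decomposed forward executes with autograd skipped, so a
    # transform like vmap only ever sees base aten ops, which already have
    # batching rules.
    with torch.no_grad():
        outs = foo_forward(xs)
    # Second run: attach the autograd node without re-running the backend.
    return FooNode.apply(outs, *xs)
```

The first run corresponds to (1) above (the forward skips autograd), and the second corresponds to (2) (the autograd graph gets set up without redispatching to the backend).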
Known issues:
(1) I'm not sure how composable this is. I haven't thought too hard yet about what would happen if you compose this logic with another functionality (e.g. amp or functionalization).
(2) It interacts poorly with DynamicLayer - I left a comment explaining why, but I'm not sure what the best solution is.
(3) I'm hardcoding that the custom ops accept and return a single TensorList argument (but so does the existing code 😛)
(4) I got one very basic test to pass, but there are probably other problems I just haven't run into.