[WIP] Support TorchScript and graph rewrite #54
base: main
Changes from all commits: 3ac6c2c, 62b6657, 4c58e5a, 01e7d49, 7558a5c, a55e9de, f9d5a82, 510f7f1, 83154f1, 69e5410, 699fba7, 19fcdb1
New file:

```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
```

New file:

```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
```
fastseq/optimizer/jit/einsum_rewriter.py (new file):

```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Optimize einsum operations in the graph."""

from typing import List

import torch
from torch import Tensor

from fastseq.optimizer.jit.utils import graph_pattern, rewrite_graph


def einsum_pattern_0(eqn: str, operands: List[Tensor]):
    r = torch.einsum(eqn, operands)
    return r


def einsum_rewrite_pattern_0(eqn: str, operands: List[Tensor]):
    # eqn = eqn.replace(' ', '')  # TODO: fix the issue: ValueError: stoll

    # For cases like "bmhtd,bnhsd->bmhts": contract over d and sum out n.
    if (len(eqn) == 18 and eqn[0:4] == eqn[13:17] and eqn[0] == eqn[6] and
            eqn[2] == eqn[8] and eqn[4] == eqn[10] and eqn[9] == eqn[17]):
        t0 = operands[0]  # (b, m, h, t, d)
        t1 = operands[1]  # (b, n, h, s, d)
        b, m, h, t, d = t0.shape
        s = t1.size(3)
        n = t1.size(1)
        t1 = t1.permute(0, 2, 3, 4, 1)  # (b, h, s, d, n)
        if n > 1:
            t1 = t1.sum(dim=4, keepdim=True)  # (b, h, s, d, 1)

        t0 = t0.permute(0, 2, 1, 3, 4)  # (b, h, m, t, d)
        t1 = t1.permute(0, 1, 3, 4, 2)  # (b, h, d, 1, s)
        t0 = t0.reshape(b*h, m*t, d)
        # reshape rather than view: t1 is not contiguous after the permutes.
        t1 = t1.reshape(b*h, d, s)
        r = torch.bmm(t0, t1).view(b, h, m, t, s).permute(0, 2, 1, 3, 4)
        return r

    # For cases like "bmhts,bnhsd->bmhtd": contract over s and sum out n.
    if (len(eqn) == 18 and eqn[0:4] == eqn[13:17] and eqn[0] == eqn[6] and
            eqn[2] == eqn[8] and eqn[4] == eqn[9] and eqn[10] == eqn[17]):
        t0 = operands[0]  # (b, m, h, t, s)
        t1 = operands[1]  # (b, n, h, s, d)
        b, m, h, t, s = t0.shape
        n = t1.size(1)
        d = t1.size(4)
        t1 = t1.permute(0, 2, 4, 3, 1)  # (b, h, d, s, n)
        if n > 1:
            t1 = t1.sum(dim=4, keepdim=True)  # (b, h, d, s, 1)
        t0 = t0.permute(0, 2, 1, 3, 4)  # (b, h, m, t, s)
        t1 = t1.permute(0, 1, 3, 4, 2)  # (b, h, s, 1, d)
        t0 = t0.reshape(b*h, m*t, s)
        # reshape rather than view: t1 may not be contiguous after the permutes.
        t1 = t1.reshape(b*h, s, d)
        r = torch.bmm(t0, t1).view(b, h, m, t, d).permute(0, 2, 1, 3, 4)
        return r

    # Fall back to the original einsum for unmatched equations.
    return torch.einsum(eqn, operands)


EINSUM_PATTERN_STR = graph_pattern(einsum_pattern_0)()
EINSUM_REWRITE_PATTERN_STR = graph_pattern(einsum_rewrite_pattern_0)()


def rewrite_einsum(input_graph: torch._C.Graph):
    rewrite_graph(EINSUM_PATTERN_STR, EINSUM_REWRITE_PATTERN_STR, input_graph)
```
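As a sanity check, the rewrite function can be compared against plain `torch.einsum` in eager mode. This is a minimal sketch, not part of the PR; the shapes are illustrative:

```python
import torch
from fastseq.optimizer.jit.einsum_rewriter import einsum_rewrite_pattern_0

# Illustrative sizes: b=2, m=3, h=4, t=5, s=6, d=7, n=2.
t0 = torch.randn(2, 3, 4, 5, 7)  # (b, m, h, t, d)
t1 = torch.randn(2, 2, 4, 6, 7)  # (b, n, h, s, d)

eqn = "bmhtd,bnhsd->bmhts"
expected = torch.einsum(eqn, [t0, t1])          # reference result
actual = einsum_rewrite_pattern_0(eqn, [t0, t1])  # permute/sum/bmm path
assert torch.allclose(actual, expected, atol=1e-5)
```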
New file:

```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Load and apply the registered graph rewrite patterns."""

import torch

from fastseq.optimizer.jit.einsum_rewriter import rewrite_einsum


def optimize_graph(input_graph: torch._C.Graph):
    rewrite_einsum(input_graph)
```
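For context, here is a sketch of how `optimize_graph` might be applied to a scripted function. The module path `fastseq.optimizer.jit.graph_rewriter` is an assumed location for the file above, and `attn_scores` is an illustrative function, not from the PR:

```python
import torch
from torch import Tensor

# Assumed module path for the new file above.
from fastseq.optimizer.jit.graph_rewriter import optimize_graph

@torch.jit.script
def attn_scores(q: Tensor, k: Tensor) -> Tensor:
    # The list form is the TorchScript einsum overload, matching einsum_pattern_0.
    return torch.einsum("bmhtd,bnhsd->bmhts", [q, k])

print(attn_scores.graph)           # contains an aten::einsum node
optimize_graph(attn_scores.graph)  # rewrites the graph in place
print(attn_scores.graph)           # einsum replaced by the rewrite subgraph
```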
Can we make this more general? The same pattern could be used for equations without a batch dimension.
I'd prefer to leave it as it is. If we run into those cases in the future, they can be added easily with a similar code block. Making it more general would turn this into something close to a full einsum kernel implementation.
From the micro-benchmark results, the runtime for large tensors is very similar with and without the optimization.
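The kind of micro-benchmark being discussed could look like the sketch below. The sizes and iteration count are illustrative, not from the PR; this times on CPU (CUDA timing would additionally need `torch.cuda.synchronize()`):

```python
import time

import torch
from fastseq.optimizer.jit.einsum_rewriter import einsum_rewrite_pattern_0

eqn = "bmhts,bnhsd->bmhtd"
t0 = torch.randn(8, 4, 16, 1, 1024)   # (b, m, h, t, s)
t1 = torch.randn(8, 1, 16, 1024, 64)  # (b, n, h, s, d)

def bench(fn, iters=100):
    fn()  # warm-up
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters

print("einsum : %.6f s/iter" % bench(lambda: torch.einsum(eqn, [t0, t1])))
print("rewrite: %.6f s/iter" % bench(lambda: einsum_rewrite_pattern_0(eqn, [t0, t1])))
```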