Reference implementation for triangle updates #5732
Base branch: wjy/end
Conversation
!test
Review updated until commit 56a9c93

Description

Relevant files
- Enhancement
- Tests
PR Reviewer Guide

Here are some key observations to aid the review process:

- 🧪 PR contains tests
- 🔒 No security concerns identified
- ⚡ Recommended focus areas for review: API Change Validation
Test failures

- (Medium, 32) nvFuser `scaled_dot_product_attention` tests failing with TypeError in sdpfa_fwd (thunder.tests.test_nvfuser.test_sdpa_nvfuser_cuda_thunder)

| Test Name | GB200 | H100 |
| --- | --- | --- |
| thunder.tests.test_nvfuser.test_sdpa_nvfuser_cuda_thunder.dtypes.bfloat16[0.0-False-0.001] | ❌ | ❌ |
| thunder.tests.test_nvfuser.test_sdpa_nvfuser_cuda_thunder.dtypes.bfloat16[0.0-False-None] | ❌ | ❌ |
| thunder.tests.test_nvfuser.test_sdpa_nvfuser_cuda_thunder.dtypes.bfloat16[0.0-True-0.001] | ❌ | ❌ |
| thunder.tests.test_nvfuser.test_sdpa_nvfuser_cuda_thunder.dtypes.bfloat16[0.0-True-None] | ❌ | ❌ |
| thunder.tests.test_nvfuser.test_sdpa_nvfuser_cuda_thunder.dtypes.bfloat16[0.2-False-0.001] | ❌ | ❌ |
| thunder.tests.test_nvfuser.test_sdpa_nvfuser_cuda_thunder.dtypes.bfloat16[0.2-False-None] | ❌ | ❌ |
| thunder.tests.test_nvfuser.test_sdpa_nvfuser_cuda_thunder.dtypes.bfloat16[0.2-True-0.001] | ❌ | ❌ |
| thunder.tests.test_nvfuser.test_sdpa_nvfuser_cuda_thunder.dtypes.bfloat16[0.2-True-None] | ❌ | ❌ |
| thunder.tests.test_nvfuser.test_sdpa_nvfuser_cuda_thunder.dtypes.float16[0.0-False-0.001] | ❌ | ❌ |
| thunder.tests.test_nvfuser.test_sdpa_nvfuser_cuda_thunder.dtypes.float16[0.0-False-None] | ❌ | ❌ |

... with 6 more test failures omitted. Check internal logs.
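Given the new bias/mask arguments in the diagram's sdpfa_fwd call below, a TypeError in existing Thunder callers is the classic symptom of a changed signature. A minimal illustration of that failure mode; both signatures here are hypothetical, not the actual nvFuser API:

```python
# Hypothetical old and new signatures -- illustration only, not the real nvFuser API.
def sdpfa_fwd_old(q, k, v, dropout_p, is_causal, scale):
    ...

def sdpfa_fwd_new(q, k, v, bias, mask, dropout_p, is_causal, scale):
    ...

# A caller written against the old signature now misbinds its arguments:
# 0.0 lands in `bias`, and the call is short two required parameters.
try:
    sdpfa_fwd_new("q", "k", "v", 0.0, False, None)
except TypeError as e:
    print(e)  # missing 2 required positional arguments: 'is_causal' and 'scale'
```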
Greptile Summary
Important Files Changed
Confidence score: 4/5
Sequence Diagram

```mermaid
sequenceDiagram
participant User
participant FusionDefinition as fd
participant Ops as fd.ops
participant TensorView as TV
User->>FusionDefinition: "define_tensor(shape, dtype, contiguity)"
FusionDefinition-->>User: "TensorView"
User->>Ops: "cast(x, dtype=DataType.Float)"
Ops-->>User: "TensorView"
User->>Ops: "var_mean(x, dims=[-1], correction=0, keepdim=True)"
Ops-->>User: "(var, mean)"
User->>Ops: "sub(x, mean)"
Ops-->>User: "TensorView"
User->>FusionDefinition: "define_scalar(1e-5)"
FusionDefinition-->>User: "Val"
User->>Ops: "add(var, scalar)"
Ops-->>User: "TensorView"
User->>Ops: "rsqrt(var)"
Ops-->>User: "TensorView"
User->>Ops: "mul(y, rsqrt_result)"
Ops-->>User: "TensorView"
User->>Ops: "shape(x)"
Ops-->>User: "shape_vector"
User->>Ops: "broadcast_in_dim(w, shape=shape, broadcast_dims=[-1])"
Ops-->>User: "TensorView"
User->>Ops: "mul(y, broadcasted_w)"
Ops-->>User: "TensorView"
User->>Ops: "broadcast_in_dim(b, shape=shape, broadcast_dims=[-1])"
Ops-->>User: "TensorView"
User->>Ops: "add(y, broadcasted_b)"
Ops-->>User: "TensorView"
User->>Ops: "cast(y, dtype=io_dtype)"
Ops-->>User: "TensorView"
User->>Ops: "linear(z, w_p)"
Ops-->>User: "TensorView"
User->>Ops: "linear(z_in, w_g)"
Ops-->>User: "TensorView"
User->>Ops: "sigmoid(g)"
Ops-->>User: "TensorView"
User->>Ops: "mul(p, g)"
Ops-->>User: "TensorView"
User->>Ops: "size(z_in, dim)"
Ops-->>User: "Val"
User->>Ops: "slice(z, start_indices, end_indices)"
Ops-->>User: "TensorView"
User->>Ops: "permute(tensor, dims)"
Ops-->>User: "TensorView"
User->>Ops: "matmul(a, b)"
Ops-->>User: "TensorView"
User->>Ops: "reshape(tensor, new_shape)"
Ops-->>User: "TensorView"
User->>Ops: "sdpfa_fwd(q_h, k_h, v_h, bias=b_h, mask=mask, is_causal=False)"
Ops-->>User: "(output, log_sumexp, philox_seed, philox_offset)"
User->>Ops: "where(mask, tensor, value)"
Ops-->>User: "TensorView"
User->>FusionDefinition: "add_output(z)"
User->>FusionDefinition: "execute([inputs...])"
FusionDefinition-->>User: "outputs"
```
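The first half of the diagram is a layernorm built from primitives. A minimal runnable sketch of that portion, assuming the frontend calls shown in the diagram; the shapes, dtypes, and 1e-5 epsilon are illustrative:

```python
import torch
from nvfuser import FusionDefinition, DataType

with FusionDefinition() as fd:
    x = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16)
    w = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16)
    b = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16)

    xf = fd.ops.cast(x, dtype=DataType.Float)  # compute in fp32, as in the diagram
    var, mean = fd.ops.var_mean(xf, dims=[-1], correction=0, keepdim=True)
    eps = fd.define_scalar(1e-5)
    y = fd.ops.mul(fd.ops.sub(xf, mean), fd.ops.rsqrt(fd.ops.add(var, eps)))

    # Broadcast the 1-D weight/bias along the normalized (last) dimension.
    shape = fd.ops.shape(xf)
    wf = fd.ops.broadcast_in_dim(fd.ops.cast(w, dtype=DataType.Float),
                                 shape=shape, broadcast_dims=[1])
    bf = fd.ops.broadcast_in_dim(fd.ops.cast(b, dtype=DataType.Float),
                                 shape=shape, broadcast_dims=[1])
    y = fd.ops.add(fd.ops.mul(y, wf), bf)

    fd.add_output(fd.ops.cast(y, dtype=DataType.BFloat16))  # cast back to I/O dtype

x = torch.randn(8, 64, device="cuda", dtype=torch.bfloat16)
w = torch.randn(64, device="cuda", dtype=torch.bfloat16)
b = torch.randn(64, device="cuda", dtype=torch.bfloat16)
(out,) = fd.execute([x, w, b])
```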
1 file reviewed, 1 comment
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
I am assuming the layernorm is skipped for simplicity.
Apart from that, the other differences are:
- No mask application -- looks like you are going to add it based on the discussion above.
- No output gating -- is this skipped intentionally? (Both steps are sketched below for reference.)
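Both steps already appear in the sequence diagram above; a minimal sketch of how they map onto the Python frontend. The helper name and the `z_in`/`w_p`/`w_g`/`mask` parameters are illustrative, not from this PR:

```python
from nvfuser import FusionDefinition

def apply_gating_and_mask(fd: FusionDefinition, z_in, w_p, w_g, mask):
    # Output gating as in the diagram: p * sigmoid(linear(z_in, w_g)).
    p = fd.ops.linear(z_in, w_p)
    g = fd.ops.sigmoid(fd.ops.linear(z_in, w_g))
    z = fd.ops.mul(p, g)
    # Mask application via where(): zero out positions where mask is False.
    zero = fd.define_scalar(0.0)
    return fd.ops.where(mask, z, zero)
```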
Force-pushed from 382fd62 to e1f4f08 (compare)
!test
3 files reviewed, 2 comments
```
-------
DataType
    The data type of this tensor.
```
syntax: Docstring incorrectly states the return type as 'DataType'; it should be 'PrimDataType' to match the actual return type.
Suggested change:
```diff
 -------
-DataType
+PrimDataType
 The data type of this tensor.
```
```python
a = fd.ops.permute(a, [0, 3, 2, 1])
b = fd.ops.permute(b, [0, 3, 1, 2])
```
style: Missing comment for the INCOMING case permutations; it should document the einsum equivalent, as the OUTGOING case does, for clarity.
Suggested change:
```diff
+# z_out = einsum("bkic,bkjc->bijc", a, b)
 a = fd.ops.permute(a, [0, 3, 2, 1])
 b = fd.ops.permute(b, [0, 3, 1, 2])
```
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
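The suggested einsum comment can be sanity-checked in plain PyTorch. A quick equivalence check; the trailing permute back to `bijc` layout is an assumption about the surrounding code, which is not shown in this hunk:

```python
import torch

a = torch.randn(2, 5, 3, 4)   # [b, k, i, c]
b = torch.randn(2, 5, 6, 4)   # [b, k, j, c]

ref = torch.einsum("bkic,bkjc->bijc", a, b)

# Permute to [b, c, i, k] and [b, c, k, j]; matmul contracts over k,
# then permute the [b, c, i, j] result back to [b, i, j, c].
out = torch.matmul(a.permute(0, 3, 2, 1), b.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)

assert torch.allclose(ref, out, atol=1e-5)
```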
Other changes:
cc @DejunL