Add Cutlass MxFp8 Block Scale Matrix Multiplication #5736
base: main
Conversation
!test

Description

Relevant files:
- Enhancement
- Tests
- Configuration changes
PR Reviewer Guide
Here are some key observations to aid the review process:
- 🧪 PR contains tests
- ⚡ Recommended focus areas for review: Missing Error Handling
Greptile Summary

Added CUTLASS-based MXFP8 (FP8_e4m3fn with block scaling) matrix multiplication support for SM100+ GPUs, mirroring the existing NVFP4 scaled GEMM implementation.

Key Changes:

Issues Found:

Confidence Score: 4/5

Important Files Changed
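To make the MXFP8 semantics concrete, here is a hedged, pure-Python reference for what a block-scaled GEMM computes: one shared scale per 32-element block along K for each operand, applied before the matmul, with a final `alpha` scaling. All names below (`dequantize`, `scaled_mm_ref`) are illustrative sketches, not the PR's actual API, and real MX scales are `e8m0` powers of two rather than arbitrary floats.

```python
# Sketch of the math behind MXFP8 block-scaled matmul:
# C = alpha * (A ⊙ scales_a) @ (B ⊙ scales_b)^T, with one scale
# per 32-element block along K. Names here are hypothetical.

BLOCK = 32  # MX block size along the K dimension


def dequantize(mat, scales):
    """Apply one per-block scale to each 32-element run along K."""
    out = []
    for row, srow in zip(mat, scales):
        out.append([v * srow[k // BLOCK] for k, v in enumerate(row)])
    return out


def scaled_mm_ref(a, b, scales_a, scales_b, alpha):
    """a: m x k, b: n x k (both row-major operands); returns m x n."""
    ad = dequantize(a, scales_a)
    bd = dequantize(b, scales_b)
    m, n, k = len(ad), len(bd), len(ad[0])
    return [
        [alpha * sum(ad[i][t] * bd[j][t] for t in range(k)) for j in range(n)]
        for i in range(m)
    ]
```

The CUTLASS kernel fuses the scale application into the tensor-core pipeline; this reference only pins down the expected numerics.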
Sequence Diagram

```mermaid
sequenceDiagram
    participant User as Python User
    participant Binding as Python Binding<br/>(cutlass.cpp)
    participant API as C++ API<br/>(mxfp8_scaled_mm)
    participant Validator as Input Validator<br/>(validateInputsMxFp8ScaledMm)
    participant Kernel as CUTLASS Kernel<br/>(runGemm)
    participant GPU as GPU Device
    User->>Binding: nvf_cutlass.mxfp8_scaled_mm(a, b, scales_a, scales_b, alpha, dtype)
    Binding->>API: mxfp8_scaled_mm(tensors, out_dtype, skip_checks)
    API->>Validator: validateInputsMxFp8ScaledMm(a, b, scales_a, scales_b, alpha)
    Validator->>Validator: Check tensor dimensions
    Validator->>Validator: Verify CUDA device & contiguity
    Validator->>Validator: Validate data types
    Validator->>Validator: Check alignment requirements
    Validator->>Validator: Validate scale matrix shapes
    Validator-->>API: Return (m, n, k)
    API->>API: Set CUDA device guard
    API->>API: Create empty output tensor
    API->>Kernel: runGemm<T>(output, a, b, scales_a, scales_b, alpha, m, n, k, stream)
    Kernel->>Kernel: Construct CUTLASS arguments (args_from_options)
    Kernel->>Kernel: Allocate workspace memory
    Kernel->>Kernel: Check can_implement
    Kernel->>Kernel: Initialize GEMM kernel
    Kernel->>GPU: Execute GEMM kernel
    GPU-->>Kernel: Computation complete
    Kernel-->>API: Return
    API-->>Binding: Return output tensor
    Binding-->>User: Return output tensor
```
6 files reviewed, 5 comments
```cpp
// Validate data types
NVF_CHECK(
    a.scalar_type() == at::ScalarType::Float8_e4m3fn,
    "Expected Float4_e2m1fn_x2 for Operand A.")
```
syntax: error message says "Expected Float4_e2m1fn_x2" but should say "Expected Float8_e4m3fn" to match actual validation
Suggested change:

```diff
-    "Expected Float4_e2m1fn_x2 for Operand A.")
+    "Expected Float8_e4m3fn for Operand A.")
```
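All four review findings in this PR are the same bug class: an error message copied from the NVFP4 path and left out of sync with the dtype the check actually enforces. A hedged sketch of one way to eliminate the class, by deriving the message from the expected dtype instead of hand-writing it. The names here (`check_dtype`, `TensorStub`) are illustrative Python, not nvFuser's C++ `NVF_CHECK` API.

```python
# Sketch: build the dtype-mismatch message from the expected dtype
# itself, so the message can never disagree with the check.
# TensorStub stands in for an ATen tensor's scalar_type().


class TensorStub:
    def __init__(self, scalar_type):
        self.scalar_type = scalar_type


def check_dtype(tensor, expected, operand):
    """Raise if tensor's dtype differs from `expected` (a dtype name)."""
    if tensor.scalar_type != expected:
        raise TypeError(
            f"Expected {expected} for {operand}, got {tensor.scalar_type}."
        )


# Passing case: no exception raised.
check_dtype(TensorStub("Float8_e4m3fn"), "Float8_e4m3fn", "Operand A")
```

The same shape works in C++ by stringifying the expected `at::ScalarType` in a small helper instead of repeating literal messages at each call site.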
```cpp
    "Expected Float4_e2m1fn_x2 for Operand A.")
NVF_CHECK(
    b.scalar_type() == at::ScalarType::Float8_e4m3fn,
    "Expected Float4_e2m1fn_x2 for Operand B.")
```
syntax: error message says "Expected Float4_e2m1fn_x2" but should say "Expected Float8_e4m3fn" to match actual validation
Suggested change:

```diff
-    "Expected Float4_e2m1fn_x2 for Operand B.")
+    "Expected Float8_e4m3fn for Operand B.")
```
```cpp
    "Expected Float4_e2m1fn_x2 for Operand B.")
NVF_CHECK(
    scales_a.scalar_type() == at::ScalarType::Float8_e8m0fnu,
    "Expected FP8_E4M3 for Blockscale scale_a.")
```
syntax: error message says "Expected FP8_E4M3" but should say "Expected Float8_e8m0fnu" to match actual validation
Suggested change:

```diff
-    "Expected FP8_E4M3 for Blockscale scale_a.")
+    "Expected Float8_e8m0fnu for Blockscale scale_a.")
```
```cpp
    "Expected FP8_E4M3 for Blockscale scale_a.")
NVF_CHECK(
    scales_b.scalar_type() == at::ScalarType::Float8_e8m0fnu,
    "Expected FP8_E4M3 for Blockscale scale_b.")
```
syntax: error message says "Expected FP8_E4M3" but should say "Expected Float8_e8m0fnu" to match actual validation
Suggested change:

```diff
-    "Expected FP8_E4M3 for Blockscale scale_b.")
+    "Expected Float8_e8m0fnu for Blockscale scale_b.")
```
```python
    "shape", [(128, 128, 128), (128, 128, 256), (256, 128, 128), (128, 256, 256)]
)
@torch.inference_mode()
def test_nvfp4_gemm(
```
syntax: test function named test_nvfp4_gemm but tests mxfp8, should be test_mxfp8_gemm
Suggested change:

```diff
-def test_nvfp4_gemm(
+def test_mxfp8_gemm(
```
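Beyond the rename, a typical shape-parameterized GEMM test compares the kernel's output against a plain-float reference under a relative-error tolerance. A hedged sketch of that comparison loop, with the quantization fixtures stubbed out since the PR's actual helpers are not shown in this excerpt; `relative_error` and the alignment assumptions are illustrative, not the test's real code.

```python
# Sketch: iterate the parameterized shapes shown above and sanity-check
# them against assumed kernel alignment constraints, plus the relative
# error metric such a test would typically use for the comparison.

SHAPES = [(128, 128, 128), (128, 128, 256), (256, 128, 128), (128, 256, 256)]


def relative_error(ref, out):
    """L2 relative error between a reference and a kernel output."""
    num = sum((r - o) ** 2 for r, o in zip(ref, out)) ** 0.5
    den = sum(r**2 for r in ref) ** 0.5 or 1.0
    return num / den


# In the real test each shape would run mxfp8_scaled_mm and compare
# against a matmul on dequantized inputs within a tolerance.
for m, n, k in SHAPES:
    # Assumed constraints: K divisible by the 32-element MX block size.
    assert k % 32 == 0 and m > 0 and n > 0
```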
No description provided.