
Conversation

@rdspring1
Collaborator

No description provided.

@rdspring1
Collaborator Author

!test

@github-actions

Description

  • Add MXFP8 scaled matrix multiplication implementation using CUTLASS kernels for SM100+ GPUs

  • Implement comprehensive input validation for matrix dimensions, data types, and alignment requirements

  • Add Python bindings and test suite with quantization/dequantization utilities (a usage sketch follows after this list)

  • Support FP16 and BF16 output types with proper workspace management and error handling
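
As a rough usage sketch of the new entry point (the call order matches the sequence diagram further down; the import path, the (n, k) layout of b, the unswizzled scale layout, and the block size of 32 are assumptions here, and real inputs must satisfy validateInputsMxFp8ScaledMm):

    import torch
    # Assumed import path; the binding itself is defined in python/python_direct/cutlass.cpp.
    from nvfuser_direct import nvf_cutlass

    m, n, k, block = 128, 128, 256, 32  # block = 32 is the assumed MX scaling block size

    # a is (m, k) and b is (n, k), both FP8 e4m3fn; one e8m0 scale per 32-wide block along k.
    a = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fn)
    b = torch.randn(n, k, device="cuda").to(torch.float8_e4m3fn)
    # Biased exponent 127 encodes an e8m0 scale of 1.0 (requires a PyTorch build with torch.float8_e8m0fnu).
    scales_a = torch.full((m, k // block), 127, dtype=torch.uint8, device="cuda").view(torch.float8_e8m0fnu)
    scales_b = torch.full((n, k // block), 127, dtype=torch.uint8, device="cuda").view(torch.float8_e8m0fnu)
    alpha = torch.tensor(1.0, dtype=torch.float32, device="cuda")

    out = nvf_cutlass.mxfp8_scaled_mm(a, b, scales_a, scales_b, alpha, torch.bfloat16)
    print(out.shape)  # (m, n), BF16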

Changes walkthrough

Relevant files

Enhancement

cutlass/mxfp8_scaled_mm.cu: MXFP8 CUTLASS kernel implementation (+316/-0)

  • Implement MXFP8 scaled matrix multiplication using CUTLASS kernels for SM100+ architecture
  • Define kernel traits and configurations for FP16/BF16 output types
  • Add argument construction and GEMM execution functions with workspace management
  • Include fallback implementation for unsupported CUTLASS versions

cutlass/nvf_cutlass.h: MXFP8 function declarations and documentation (+51/-0)

  • Add function declarations for validateInputsMxFp8ScaledMm and mxfp8_scaled_mm
  • Include comprehensive documentation for input validation and matrix multiplication
  • Specify parameter requirements for MXFP8 format and scaling factors

cutlass/nvf_cutlass.cpp: MXFP8 input validation implementation (+112/-0)

  • Implement validateInputsMxFp8ScaledMm function with comprehensive input checking
  • Validate matrix dimensions, CUDA properties, data types, and alignment requirements
  • Check scale matrix properties and padding requirements for optimal performance

python/python_direct/cutlass.cpp: Python bindings for MXFP8 operations (+21/-1)

  • Add Python binding for mxfp8_scaled_mm function with proper tensor handling
  • Update nvfp4_scaled_mm docstring for accuracy
  • Expose MXFP8 functionality to Python interface with type safety

Tests

tests/python/direct/test_cutlass_mxfp8_gemm.py: MXFP8 GEMM test suite (+122/-0)

  • Add comprehensive test suite for MXFP8 GEMM with multiple data types and shapes
  • Implement quantization/dequantization utilities and reference computation
  • Include device capability checks and proper test parameterization
  • Validate output accuracy against PyTorch reference implementation

Configuration changes

CMakeLists.txt: Build configuration update (+1/-0)

  • Add mxfp8_scaled_mm.cu to NVFUSER_CUTLASS_SRCS list
  • Include new source file in build configuration for compilation

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests

⚡ Recommended focus areas for review

Missing Error Handling

The function lacks error handling for unsupported compute capabilities. While there is a fallback for unsupported CUTLASS versions, there is no explicit check for compute capability 10.0+ before attempting to use the SM100+ kernels, which could lead to runtime failures on older GPUs.

    torch::Tensor mxfp8_scaled_mm(
        const torch::Tensor& a,
        const torch::Tensor& b,
        const torch::Tensor& scales_a,
        const torch::Tensor& scales_b,
        const torch::Tensor& alpha,
        const at::ScalarType out_dtype,
        bool skip_checks) {
      // Validate all inputs and get matrix dimensions
      auto [m, n, k] =
          validateInputsMxFp8ScaledMm(a, b, scales_a, scales_b, alpha, skip_checks);
    
      at::cuda::CUDAGuard device_guard{(int8_t)a.get_device()};
      const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());
    
      auto options =
          at::TensorOptions().dtype(out_dtype).device(at::kCUDA, a.get_device());
      torch::Tensor output = at::empty({a.sizes()[0], b.sizes()[0]}, options);
    
      if (out_dtype == at::ScalarType::Half) {
        runGemm<cutlass::half_t>(
            output, a, b, scales_a, scales_b, alpha, m, n, k, stream);
      } else if (out_dtype == at::ScalarType::BFloat16) {
        runGemm<cutlass::bfloat16_t>(
            output, a, b, scales_a, scales_b, alpha, m, n, k, stream);
      } else {
        NVF_THROW("Unsupported output data type of mxfp8 scaled_mm.");
      }
      return output;
    }
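
One minimal caller-side pre-check (a sketch only; a built-in guard would more naturally live in validateInputsMxFp8ScaledMm, and this uses a standard PyTorch query rather than anything added by this PR) would surface the limitation before the kernel launch:

    import torch

    major, minor = torch.cuda.get_device_capability()
    if (major, minor) < (10, 0):
        raise RuntimeError(
            "mxfp8_scaled_mm requires an SM100+ GPU (compute capability 10.0 or newer), "
            f"but the current device reports {major}.{minor}"
        )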

Inconsistent Error Messages

Error messages for data type validation are inconsistent: line 166 mentions "Float4_e2m1fn_x2" but the actual check is for "Float8_e4m3fn". This mismatch could confuse users debugging type-related issues.

    NVF_CHECK(
        a.scalar_type() == at::ScalarType::Float8_e4m3fn,
        "Expected Float4_e2m1fn_x2 for Operand A.")
    NVF_CHECK(
        b.scalar_type() == at::ScalarType::Float8_e4m3fn,
        "Expected Float4_e2m1fn_x2 for Operand B.")
    NVF_CHECK(
        scales_a.scalar_type() == at::ScalarType::Float8_e8m0fnu,
        "Expected FP8_E4M3 for Blockscale scale_a.")
    NVF_CHECK(
        scales_b.scalar_type() == at::ScalarType::Float8_e8m0fnu,
        "Expected FP8_E4M3 for Blockscale scale_b.")
    NVF_CHECK(
        alpha.scalar_type() == at::ScalarType::Float,
        "Expected FP32 for alpha scalar.")

Test Coverage Gap

The test function is named test_nvfp4_gemm but tests MXFP8 functionality. This naming inconsistency could lead to confusion when running specific test suites or when the actual NVFP4 tests are added later.

    def test_nvfp4_gemm(
        dtype: torch.dtype,
        shape: tuple[int, int, int],
    ) -> None:

@greptile-apps
Contributor

greptile-apps bot commented Dec 23, 2025

Greptile Summary

Added CUTLASS-based MXFP8 (FP8_e4m3fn with block scaling) matrix multiplication support for SM100+ GPUs, mirroring the existing NVFP4 scaled GEMM implementation.

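Concretely, block scaling means every 32-element block along K shares one e8m0 scale. A dequantize-and-compare reference of the kind the test suite relies on can be sketched like this (a minimal sketch assuming the standard MX block size of 32, an unswizzled (rows, k/32) scale layout, and that e8m0 scales convert to float32 via .to(); the actual utilities in test_cutlass_mxfp8_gemm.py may differ):

    import torch

    BLOCK_SIZE = 32  # assumed OCP MX scaling block size along K


    def dequantize_mxfp8(x_fp8: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
        # Broadcast one e8m0 scale over each 32-element block along K, then undo the quantization.
        expanded = scales.to(torch.float32).repeat_interleave(BLOCK_SIZE, dim=1)
        return x_fp8.to(torch.float32) * expanded[:, : x_fp8.shape[1]]


    def reference_mxfp8_mm(a, b, scales_a, scales_b, alpha, out_dtype):
        # b is stored as (n, k), so the reference output is (m, n) = a_deq @ b_deq.T, scaled by alpha.
        a_deq = dequantize_mxfp8(a, scales_a)
        b_deq = dequantize_mxfp8(b, scales_b)
        return (alpha.to(torch.float32) * (a_deq @ b_deq.t())).to(out_dtype)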

Key Changes:

• New mxfp8_scaled_mm.cu kernel implementing CUTLASS GEMM for the MXFP8 format with FP16/BF16 output
• Added validateInputsMxFp8ScaledMm function with comprehensive input validation
• Python binding added to expose the functionality via nvf_cutlass.mxfp8_scaled_mm()
• Test suite covers multiple matrix shapes and output dtypes (FP16/BF16)

Issues Found:

• 4 error messages in the validation function were copied unchanged from the NVFP4 version (they say "Float4_e2m1fn_x2" instead of "Float8_e4m3fn", and "FP8_E4M3" instead of "Float8_e8m0fnu")
• Test function misnamed test_nvfp4_gemm instead of test_mxfp8_gemm

Confidence Score: 4/5

• Safe to merge after fixing the error messages and the test name
• Implementation follows the established patterns from NVFP4, but has 5 copy-paste errors in strings that need correction to avoid confusing error messages at runtime
• Pay attention to cutlass/nvf_cutlass.cpp (fix 4 error messages) and tests/python/direct/test_cutlass_mxfp8_gemm.py (fix the test name)

Important Files Changed

• cutlass/mxfp8_scaled_mm.cu: New CUTLASS kernel implementation for MXFP8 GEMM with SM100+ support, well-structured
• cutlass/nvf_cutlass.cpp: Added validation function with copy-paste errors in error messages (4 incorrect type names)
• tests/python/direct/test_cutlass_mxfp8_gemm.py: New test file with the test function misnamed test_nvfp4_gemm instead of test_mxfp8_gemm

Sequence Diagram

    sequenceDiagram
        participant User as Python User
        participant Binding as Python Binding<br/>(cutlass.cpp)
        participant API as C++ API<br/>(mxfp8_scaled_mm)
        participant Validator as Input Validator<br/>(validateInputsMxFp8ScaledMm)
        participant Kernel as CUTLASS Kernel<br/>(runGemm)
        participant GPU as GPU Device
    
        User->>Binding: nvf_cutlass.mxfp8_scaled_mm(a, b, scales_a, scales_b, alpha, dtype)
        Binding->>API: mxfp8_scaled_mm(tensors, out_dtype, skip_checks)
        API->>Validator: validateInputsMxFp8ScaledMm(a, b, scales_a, scales_b, alpha)
        Validator->>Validator: Check tensor dimensions
        Validator->>Validator: Verify CUDA device & contiguity
        Validator->>Validator: Validate data types
        Validator->>Validator: Check alignment requirements
        Validator->>Validator: Validate scale matrix shapes
        Validator-->>API: Return (m, n, k)
        API->>API: Set CUDA device guard
        API->>API: Create empty output tensor
        API->>Kernel: runGemm<T>(output, a, b, scales_a, scales_b, alpha, m, n, k, stream)
        Kernel->>Kernel: Construct CUTLASS arguments (args_from_options)
        Kernel->>Kernel: Allocate workspace memory
        Kernel->>Kernel: Check can_implement
        Kernel->>Kernel: Initialize GEMM kernel
        Kernel->>GPU: Execute GEMM kernel
        GPU-->>Kernel: Computation complete
        Kernel-->>API: Return
        API-->>Binding: Return output tensor
        Binding-->>User: Return output tensor
    

@greptile-apps greptile-apps bot left a comment

6 files reviewed, 5 comments

    // Validate data types
    NVF_CHECK(
    a.scalar_type() == at::ScalarType::Float8_e4m3fn,
    "Expected Float4_e2m1fn_x2 for Operand A.")

    syntax: error message says "Expected Float4_e2m1fn_x2" but should say "Expected Float8_e4m3fn" to match actual validation

Suggested change
-     "Expected Float4_e2m1fn_x2 for Operand A.")
+     "Expected Float8_e4m3fn for Operand A.")

    "Expected Float4_e2m1fn_x2 for Operand A.")
    NVF_CHECK(
    b.scalar_type() == at::ScalarType::Float8_e4m3fn,
    "Expected Float4_e2m1fn_x2 for Operand B.")

    syntax: error message says "Expected Float4_e2m1fn_x2" but should say "Expected Float8_e4m3fn" to match actual validation

Suggested change
-     "Expected Float4_e2m1fn_x2 for Operand B.")
+     "Expected Float8_e4m3fn for Operand B.")

    "Expected Float4_e2m1fn_x2 for Operand B.")
    NVF_CHECK(
    scales_a.scalar_type() == at::ScalarType::Float8_e8m0fnu,
    "Expected FP8_E4M3 for Blockscale scale_a.")

    syntax: error message says "Expected FP8_E4M3" but should say "Expected Float8_e8m0fnu" to match actual validation

Suggested change
-     "Expected FP8_E4M3 for Blockscale scale_a.")
+     "Expected Float8_e8m0fnu for Blockscale scale_a.")

    "Expected FP8_E4M3 for Blockscale scale_a.")
    NVF_CHECK(
    scales_b.scalar_type() == at::ScalarType::Float8_e8m0fnu,
    "Expected FP8_E4M3 for Blockscale scale_b.")

    syntax: error message says "Expected FP8_E4M3" but should say "Expected Float8_e8m0fnu" to match actual validation

Suggested change
-     "Expected FP8_E4M3 for Blockscale scale_b.")
+     "Expected Float8_e8m0fnu for Blockscale scale_b.")

    "shape", [(128, 128, 128), (128, 128, 256), (256, 128, 128), (128, 256, 256)]
    )
    @torch.inference_mode()
    def test_nvfp4_gemm(

    syntax: test function named test_nvfp4_gemm but tests mxfp8, should be test_mxfp8_gemm

Suggested change
-     def test_nvfp4_gemm(
+     def test_mxfp8_gemm(
