
Conversation

@wujingyue wujingyue commented Dec 21, 2025

Other changes:

  1. Add TensorView.dtype for convenience.
  2. Clean up broadcast_in_dim_fn.
  3. Add layernorm to triangle attention.

cc @DejunL

@wujingyue wujingyue requested a review from Priya2698 December 21, 2025 04:05
@wujingyue (Collaborator, Author) commented:

!test


github-actions bot commented Dec 21, 2025

Review updated until commit 56a9c93

Description

  • Add dtype() method to TensorView class for accessing tensor data types

  • Refactor broadcast_in_dim_fn with improved parameter naming and validation

  • Implement AlphaFold3 triangle updates with layer normalization and gating

  • Complete triangle attention tests with proper masking and normalization

Changes walkthrough

Relevant files

Enhancement

ir.cpp: Add dtype method to TensorView (python/python_direct/ir.cpp, +18/-0)

  • Added a dtype() method to the TensorView class
  • Returns PrimDataType with proper validation
  • Provides data type access for tensor views
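
A minimal usage sketch of the new accessor (not code from the diff): it assumes the value is exposed as property-style TensorView.dtype in the direct Python frontend (the binding may instead expose it as a method), and the import path shown is an assumption.

    from nvfuser_direct import FusionDefinition, DataType  # import path assumed

    with FusionDefinition() as fd:
        x = fd.define_tensor(shape=[-1, -1], dtype=DataType.BFloat16, contiguity=True)
        assert x.dtype == DataType.BFloat16  # runtime dtype inspection on a TensorView
        y = fd.ops.cast(x, dtype=DataType.Float)
        assert y.dtype == DataType.Float
        fd.add_output(y)
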
ops.cpp: Refactor broadcast_in_dim function (python/python_direct/ops.cpp, +14/-33)

  • Refactored broadcast_in_dim_fn with cleaner parameter names
  • Improved validation logic and error handling
  • Added wrapDim utility for dimension handling
  • Added utils.h include for utility functions
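
A short, hypothetical fragment illustrating the semantics the refactor clarifies (assuming a FusionDefinition context `fd` with the tensors and sizes defined as in the test further down): broadcast_dims lists the output positions that the input's dimensions occupy, every other output position becomes a broadcast dimension that is then expanded, and with wrapDim those positions may now be negative.

    # [b, i, j] mask -> [b, i, j, c]: dims 0..2 are kept, dim 3 is broadcast
    mask4d = fd.ops.broadcast_in_dim(
        mask, shape=[batch_size, n_tokens, n_tokens, c_z], broadcast_dims=[0, 1, 2]
    )
    # [c] weight -> [b, i, j, c]: a negative position such as -1 should now be accepted
    w4d = fd.ops.broadcast_in_dim(
        w_norm_in, shape=[batch_size, n_tokens, n_tokens, c_z], broadcast_dims=[-1]
    )
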
Tests

test_alphafold3.py: Implement AlphaFold3 triangle mechanisms (tests/python/direct/test_alphafold3.py, +154/-7)

  • Added layer_norm and gating helper functions
  • Implemented a complete triangle updates test covering both directions
  • Enhanced the triangle attention test with layer normalization
  • Added proper masking and parameter handling for both test cases

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    🔒 No security concerns identified
    ⚡ Recommended focus areas for review
    API Change Validation

    The broadcast_in_dim_fn signature has been changed to put input first and use nonbroadcast_dims instead of broadcast_dims. This is a breaking API change that should be validated to ensure existing code using this function continues to work correctly.

    TensorView* broadcast_in_dim_fn(
        TensorView* input,
        ShapeType generic_output_shape,
        const std::vector<int64_t>& nonbroadcast_dims) {
      std::vector<Val*> output_shape = SequenceAsVector(generic_output_shape);
      NVF_CHECK_GE(output_shape.size(), nonbroadcast_dims.size());
    
      const auto input_ndim = std::ranges::distance(
          input->getLogicalDomain() | TensorDomain::kNoReductions);
      NVF_CHECK_GE(std::ssize(output_shape), input_ndim);
      NVF_CHECK_EQ(input_ndim, std::ssize(nonbroadcast_dims));
    
      std::vector<bool> is_broadcast_dim(output_shape.size(), true);
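      // Output positions listed in nonbroadcast_dims keep an input dimension;
      // all remaining output positions stay marked as broadcast dims and are
      // expanded to the requested sizes below.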
      for (auto nonbroadcast_dim : nonbroadcast_dims) {
        nonbroadcast_dim = wrapDim(nonbroadcast_dim, std::ssize(output_shape));
        is_broadcast_dim.at(nonbroadcast_dim) = false;
      }
    
      TensorView* output = broadcast(input, is_broadcast_dim);
      output = expand(output, output_shape);
      return output;
    }
    Test Coverage

    The triangle_updates test exercises both the OUTGOING and INCOMING directions but does not validate the result against expected outputs. Consider adding a comparison with a reference PyTorch implementation to ensure correctness (a sketch of the core reference contraction follows the test below).

    def test_triangle_updates(direction):
        c_z = _DEFAULT_CONFIG.c_z
    
        with FusionDefinition() as fd:
            z_in = fd.define_tensor(
                shape=[-1, -1, -1, c_z],
                dtype=DataType.BFloat16,
                contiguity=True,
            )  # [b, i, j, c_z]
            w_norm_in = fd.define_tensor(
                shape=[c_z], dtype=DataType.BFloat16, contiguity=True
            )
            b_norm_in = fd.define_tensor(
                shape=[c_z], dtype=DataType.BFloat16, contiguity=True
            )
            w_p_in = fd.define_tensor(
                shape=[c_z * 2, c_z], dtype=DataType.BFloat16, contiguity=True
            )
            w_g_in = fd.define_tensor(
                shape=[c_z * 2, c_z], dtype=DataType.BFloat16, contiguity=True
            )
            w_norm_out = fd.define_tensor(
                shape=[c_z], dtype=DataType.BFloat16, contiguity=True
            )
            b_norm_out = fd.define_tensor(
                shape=[c_z], dtype=DataType.BFloat16, contiguity=True
            )
            w_p_out = fd.define_tensor(
                shape=[c_z, c_z], dtype=DataType.BFloat16, contiguity=True
            )
            w_g_out = fd.define_tensor(
                shape=[c_z, c_z], dtype=DataType.BFloat16, contiguity=True
            )
            mask = fd.define_tensor(
                shape=[-1, -1, -1], dtype=DataType.Bool, contiguity=True
            )  # [b, i, j]
    
            batch_size = fd.ops.size(z_in, 0)
            n_tokens = fd.ops.size(z_in, 1)
    
            z_in = layer_norm(fd, z_in, w_norm_in, b_norm_in)
            z = gating(fd, z_in, w_p_in, z_in, w_g_in)
            mask = fd.ops.broadcast_in_dim(
                mask, shape=[batch_size, n_tokens, n_tokens, c_z], broadcast_dims=[0, 1, 2]
            )
            z = fd.ops.where(mask, z, 0.0)
            a = fd.ops.slice(z, [0, 0, 0, 0], [batch_size, n_tokens, n_tokens, c_z])
            b = fd.ops.slice(z, [0, 0, 0, c_z], [batch_size, n_tokens, n_tokens, c_z * 2])
    
            match direction:
                case Direction.OUTGOING:
                    # z_out = einsum("bikc,bjkc->bijc", a, b)
                    a = fd.ops.permute(a, [0, 3, 1, 2])  # [b, c, i, k]
                    b = fd.ops.permute(b, [0, 3, 2, 1])  # [b, c, k, j]
                case Direction.INCOMING:
                    # z_out = einsum("bkic,bkjc->bijc", a, b)
                    a = fd.ops.permute(a, [0, 3, 2, 1])
                    b = fd.ops.permute(b, [0, 3, 1, 2])
            z = fd.ops.matmul(a, b)  # [b, c, i, j]
            z = fd.ops.permute(z, [0, 2, 3, 1])  # [b, i, j, c]
    
            z = layer_norm(fd, z, w_norm_out, b_norm_out)
            z = gating(fd, z, w_p_out, z_in, w_g_out)
            fd.add_output(z)
    
        batch_size = 3
        n_tokens = 5
        z_in = torch.testing.make_tensor(
            batch_size, n_tokens, n_tokens, c_z, dtype=torch.bfloat16, device="cuda"
        )
        w_norm_in = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
        b_norm_in = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
        w_p_in = torch.testing.make_tensor(
            c_z * 2, c_z, dtype=torch.bfloat16, device="cuda"
        )
        w_g_in = torch.testing.make_tensor(
            c_z * 2, c_z, dtype=torch.bfloat16, device="cuda"
        )
        w_norm_out = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
        b_norm_out = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
        w_p_out = torch.testing.make_tensor(c_z, c_z, dtype=torch.bfloat16, device="cuda")
        w_g_out = torch.testing.make_tensor(c_z, c_z, dtype=torch.bfloat16, device="cuda")
        mask = torch.testing.make_tensor(
            batch_size, n_tokens, n_tokens, dtype=torch.bool, device="cuda"
        )
        (z_out,) = fd.execute(
            [
                z_in,
                w_norm_in,
                b_norm_in,
                w_p_in,
                w_g_in,
                w_norm_out,
                b_norm_out,
                w_p_out,
                w_g_out,
                mask,
            ]
        )
        assert z_out.shape == (batch_size, n_tokens, n_tokens, c_z)
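
    A hedged sketch of the reference comparison suggested above, for the OUTGOING direction only: the helper is hypothetical (not part of the PR) and assumes the masked, gated `a` and `b` activations have first been reproduced in eager PyTorch; the core pairwise contraction is then a single einsum.

    import torch

    def reference_outgoing_update(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # z_ref[b, i, j, c] = sum_k a[b, i, k, c] * b[b, j, k, c]
        return torch.einsum("bikc,bjkc->bijc", a.float(), b.float())

    # e.g. torch.testing.assert_close(z_out.float(), z_ref, rtol=1.6e-2, atol=1e-2)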

    Test failures

    • (Medium, 32) nvFuser scaled_dot_product_attention tests failing with TypeError in sdpfa_fwd (thunder.tests.test_nvfuser.test_sdpa_nvfuser_cuda_thunder)

      Failing tests (observed on GB200 and H100):
      thunder.tests.test_nvfuser.test_sdpa_nvfuser_cuda_thunder.dtypes.bfloat16[0.0-False-0.001]
      thunder.tests.test_nvfuser.test_sdpa_nvfuser_cuda_thunder.dtypes.bfloat16[0.0-False-None]
      thunder.tests.test_nvfuser.test_sdpa_nvfuser_cuda_thunder.dtypes.bfloat16[0.0-True-0.001]
      thunder.tests.test_nvfuser.test_sdpa_nvfuser_cuda_thunder.dtypes.bfloat16[0.0-True-None]
      thunder.tests.test_nvfuser.test_sdpa_nvfuser_cuda_thunder.dtypes.bfloat16[0.2-False-0.001]
      thunder.tests.test_nvfuser.test_sdpa_nvfuser_cuda_thunder.dtypes.bfloat16[0.2-False-None]
      thunder.tests.test_nvfuser.test_sdpa_nvfuser_cuda_thunder.dtypes.bfloat16[0.2-True-0.001]
      thunder.tests.test_nvfuser.test_sdpa_nvfuser_cuda_thunder.dtypes.bfloat16[0.2-True-None]
      thunder.tests.test_nvfuser.test_sdpa_nvfuser_cuda_thunder.dtypes.float16[0.0-False-0.001]
      thunder.tests.test_nvfuser.test_sdpa_nvfuser_cuda_thunder.dtypes.float16[0.0-False-None]
      ... with 6 more test failures omitted. Check internal logs.


    greptile-apps bot commented Dec 21, 2025

    Greptile Summary

    • Adds reference implementation for AlphaFold3 triangle updates with helper functions for layer normalization and gating
    • Adds TensorView.dtype convenience method to Python bindings for runtime type inspection
    • Refactors broadcast_in_dim_fn to improve parameter semantics, add negative index support, and simplify validation logic

    Important Files Changed

    • tests/python/direct/test_alphafold3.py: Implements the complete AlphaFold3 triangle update algorithm with layer norm and gating helpers
    • python/python_direct/ir.cpp: Adds a new dtype() method to the TensorView bindings for runtime type inspection
    • python/python_direct/ops.cpp: Refactors broadcast_in_dim_fn with improved parameter semantics and negative index support

    Confidence score: 4/5

    • This PR is relatively safe but implements complex mathematical operations that warrant careful review
    • Score reflects the complexity of the triangle update algorithm implementation and tensor manipulation logic
    • Pay close attention to test_alphafold3.py for correctness of the mathematical formulations and tensor operations

    Sequence Diagram

    sequenceDiagram
        participant User
        participant FusionDefinition as fd
        participant Ops as fd.ops
        participant TensorView as TV
    
        User->>FusionDefinition: "define_tensor(shape, dtype, contiguity)"
        FusionDefinition-->>User: "TensorView"
    
        User->>Ops: "cast(x, dtype=DataType.Float)"
        Ops-->>User: "TensorView"
    
        User->>Ops: "var_mean(x, dims=[-1], correction=0, keepdim=True)"
        Ops-->>User: "(var, mean)"
    
        User->>Ops: "sub(x, mean)"
        Ops-->>User: "TensorView"
    
        User->>FusionDefinition: "define_scalar(1e-5)"
        FusionDefinition-->>User: "Val"
    
        User->>Ops: "add(var, scalar)"
        Ops-->>User: "TensorView"
    
        User->>Ops: "rsqrt(var)"
        Ops-->>User: "TensorView"
    
        User->>Ops: "mul(y, rsqrt_result)"
        Ops-->>User: "TensorView"
    
        User->>Ops: "shape(x)"
        Ops-->>User: "shape_vector"
    
        User->>Ops: "broadcast_in_dim(w, shape=shape, broadcast_dims=[-1])"
        Ops-->>User: "TensorView"
    
        User->>Ops: "mul(y, broadcasted_w)"
        Ops-->>User: "TensorView"
    
        User->>Ops: "broadcast_in_dim(b, shape=shape, broadcast_dims=[-1])"
        Ops-->>User: "TensorView"
    
        User->>Ops: "add(y, broadcasted_b)"
        Ops-->>User: "TensorView"
    
        User->>Ops: "cast(y, dtype=io_dtype)"
        Ops-->>User: "TensorView"
    
        User->>Ops: "linear(z, w_p)"
        Ops-->>User: "TensorView"
    
        User->>Ops: "linear(z_in, w_g)"
        Ops-->>User: "TensorView"
    
        User->>Ops: "sigmoid(g)"
        Ops-->>User: "TensorView"
    
        User->>Ops: "mul(p, g)"
        Ops-->>User: "TensorView"
    
        User->>Ops: "size(z_in, dim)"
        Ops-->>User: "Val"
    
        User->>Ops: "slice(z, start_indices, end_indices)"
        Ops-->>User: "TensorView"
    
        User->>Ops: "permute(tensor, dims)"
        Ops-->>User: "TensorView"
    
        User->>Ops: "matmul(a, b)"
        Ops-->>User: "TensorView"
    
        User->>Ops: "reshape(tensor, new_shape)"
        Ops-->>User: "TensorView"
    
        User->>Ops: "sdpfa_fwd(q_h, k_h, v_h, bias=b_h, mask=mask, is_causal=False)"
        Ops-->>User: "(output, log_sumexp, philox_seed, philox_offset)"
    
        User->>Ops: "where(mask, tensor, value)"
        Ops-->>User: "TensorView"
    
        User->>FusionDefinition: "add_output(z)"
        
        User->>FusionDefinition: "execute([inputs...])"
        FusionDefinition-->>User: "outputs"
    
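    The call sequence above outlines the layer_norm and gating helpers used by the tests. As a rough sketch only (not the PR's actual code, and with signatures that may differ), those helpers could look like this in the direct Python frontend:

    def layer_norm(fd, x, w, b, eps=1e-5):
        io_dtype = x.dtype  # relies on the new TensorView.dtype accessor
        x = fd.ops.cast(x, dtype=DataType.Float)
        var, mean = fd.ops.var_mean(x, dims=[-1], correction=0, keepdim=True)
        y = fd.ops.sub(x, mean)
        y = fd.ops.mul(y, fd.ops.rsqrt(fd.ops.add(var, fd.define_scalar(eps))))
        shape = fd.ops.shape(x)
        y = fd.ops.mul(y, fd.ops.broadcast_in_dim(w, shape=shape, broadcast_dims=[-1]))
        y = fd.ops.add(y, fd.ops.broadcast_in_dim(b, shape=shape, broadcast_dims=[-1]))
        return fd.ops.cast(y, dtype=io_dtype)

    def gating(fd, p_in, w_p, g_in, w_g):
        p = fd.ops.linear(p_in, w_p)                    # projection branch
        g = fd.ops.sigmoid(fd.ops.linear(g_in, w_g))    # sigmoid gate branch
        return fd.ops.mul(p, g)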

    @greptile-apps greptile-apps bot left a comment

    1 file reviewed, 1 comment



    @Priya2698 Priya2698 left a comment

    I am assuming the layernorm is skipped for simplicity.

    Apart from that, the other differences are:

    1. No mask application -- looks like you are going to add it based on the above discussion.
    2. No output gating -- is this skipped intentionally?

    @wujingyue wujingyue requested a review from Priya2698 December 24, 2025 05:19
    @wujingyue (Collaborator, Author) commented:

    !test

    @greptile-apps greptile-apps bot left a comment

    3 files reviewed, 2 comments


    Comment on lines +249 to +251
    -------
    DataType
    The data type of this tensor.

    syntax: Docstring incorrectly states return type as 'DataType' but should be 'PrimDataType' to match the actual return type.

    Suggested change
      -------
    - DataType
    + PrimDataType
      The data type of this tensor.

    Comment on lines +129 to +130
    a = fd.ops.permute(a, [0, 3, 2, 1])
    b = fd.ops.permute(b, [0, 3, 1, 2])

    style: Missing comment for INCOMING case permutations - should document the einsum equivalent like the OUTGOING case for clarity

    Suggested change
    - a = fd.ops.permute(a, [0, 3, 2, 1])
    - b = fd.ops.permute(b, [0, 3, 1, 2])
    + # z_out = einsum("bkic,bkjc->bijc", a, b)
    + a = fd.ops.permute(a, [0, 3, 2, 1])
    + b = fd.ops.permute(b, [0, 3, 1, 2])

    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
