@naoyam
Collaborator

@naoyam naoyam commented Dec 18, 2025

Extends IterDomainBuilder and IterDomain::cloneWithoutRFactor for RaggedIterDomain so that utilities like ops::newOutputTV can correctly create a RaggedIterDomain when an input ID is ragged.

This is mainly to keep ops like set, add, etc. from generating invalid output tensors. We are not doing lowering, so this only exercises Fusion IR construction. Specifically, when an input tensor of a unary op has a RaggedIterDomain, its output should have a RaggedIterDomain at the corresponding position of its logical domain.

The ops in csrc/ops/arith.h, csrc/ops/alias.h, csrc/ops/indexing.h should either generate valid tensors or immediately fail. The ops in the other files are not yet considered.
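The builder behavior described above can be pictured with a small standalone model. This is only a sketch under stated assumptions: `ExtentsTensor`, `IterDomainModel`, `RaggedIterDomainModel`, and `IterDomainBuilderModel` are hypothetical stand-ins invented for illustration, not the nvFuser classes, and the real builder carries many more fields.

```cpp
#include <cassert>
#include <memory>
#include <string>

// Hypothetical stand-ins for illustration only; not the nvFuser classes.
struct ExtentsTensor {
  std::string name;
};

struct IterDomainModel {
  virtual ~IterDomainModel() = default;
  virtual bool isRagged() const { return false; }
};

struct RaggedIterDomainModel : IterDomainModel {
  explicit RaggedIterDomainModel(const ExtentsTensor* extents_tensor)
      : extents(extents_tensor) {
    // A ragged domain without per-component extents would be meaningless.
    assert(extents != nullptr);
  }
  bool isRagged() const override { return true; }
  const ExtentsTensor* extents;
};

class IterDomainBuilderModel {
 public:
  // Returns *this so that setter calls can be chained.
  IterDomainBuilderModel& ragged_extents(const ExtentsTensor* extents_tensor) {
    ragged_extents_ = extents_tensor;
    return *this;
  }

  // Produces a ragged domain only when ragged extents were provided.
  std::unique_ptr<IterDomainModel> build() const {
    if (ragged_extents_ != nullptr) {
      return std::make_unique<RaggedIterDomainModel>(ragged_extents_);
    }
    return std::make_unique<IterDomainModel>();
  }

 private:
  const ExtentsTensor* ragged_extents_ = nullptr;
};
```

In this model, a caller that never sets `ragged_extents` keeps getting a regular domain, so existing call sites would be unaffected.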

@github-actions

github-actions bot commented Dec 18, 2025

Review updated until commit 2dd9287

Description

  • Extend IterDomainBuilder to support RaggedIterDomain construction with ragged_extents parameter

  • Implement RaggedIterDomain::cloneWithoutRFactor override to preserve ragged structure

  • Add hasRaggedIterDomain() helper method to TensorDomain for detecting ragged tensors

  • Update newOutputIterDomain utility to create RaggedIterDomain when input contains ragged dimensions

  • Add validation checks across ops (reshape, flatten, pad, cat, slice, etc.) to reject ragged tensors

  • Prevent reduction operations on RaggedIterDomain dimensions

  • Add comprehensive test coverage for nested tensor operations and error cases
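The "prevent reduction on ragged dimensions" item can be sketched as a guard that runs before any output is built. This is a simplified model, not the code in newForReduction: `AxisKind` and `reduceAxes` are hypothetical names, and the model simply drops reduced axes from the result.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical model of a tensor's axes; not the nvFuser representation.
enum class AxisKind { Regular, Ragged };

// Returns the axes that survive the reduction, or throws if any requested
// reduction axis is ragged or out of range.
std::vector<AxisKind> reduceAxes(
    const std::vector<AxisKind>& axes,
    const std::vector<std::size_t>& reduction_axes) {
  std::vector<bool> reduced(axes.size(), false);
  for (std::size_t axis : reduction_axes) {
    if (axis >= axes.size()) {
      throw std::out_of_range("Reduction axis out of range");
    }
    if (axes[axis] == AxisKind::Ragged) {
      throw std::invalid_argument(
          "Reducing a ragged dimension is not supported");
    }
    reduced[axis] = true;
  }

  std::vector<AxisKind> kept;
  for (std::size_t i = 0; i < axes.size(); ++i) {
    if (!reduced[i]) {
      kept.push_back(axes[i]);
    }
  }
  return kept;
}
```

Failing before the output is constructed matches the stated goal that an op should either produce a valid tensor or fail immediately.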

Changes walkthrough

Relevant files

Enhancement

internal_base_nodes.cpp: Extend IterDomainBuilder and RaggedIterDomain support
csrc/ir/internal_base_nodes.cpp (+130/-22)

  • Modified IterDomainBuilder constructor to extract ragged_extents from RaggedIterDomain
  • Added ragged_extents_ member and ragged_extents() setter method to IterDomainBuilder
  • Updated build() method to create RaggedIterDomain when ragged_extents are provided
  • Implemented RaggedIterDomain constructor accepting IterDomainBuilder
  • Added RaggedIterDomain::cloneWithoutRFactor() override to preserve ragged structure
  • Added TensorDomain::hasRaggedIterDomain() method to detect ragged tensors

utils.cpp: Add utility functions for RaggedIterDomain output creation
csrc/ops/utils.cpp (+37/-0)

  • Added newOutputRaggedIterDomain() function to create ragged output domains
  • Modified newOutputIterDomain() to detect RaggedIterDomain and delegate to newOutputRaggedIterDomain
  • Ensures ragged structure is preserved in tensor operations

internal_base_nodes.h: Update interfaces for RaggedIterDomain support
csrc/ir/internal_base_nodes.h (+22/-11)

  • Updated IterDomainBuilder interface with ragged_extents() setter method
  • Declared RaggedIterDomain constructor accepting IterDomainBuilder
  • Made cloneWithoutRFactor() virtual in IterDomain base class
  • Added hasRaggedIterDomain() method declaration to TensorDomain

utils.h: Add utility function declaration for RaggedIterDomain
csrc/ops/utils.h (+6/-0)

  • Added declaration for newOutputRaggedIterDomain() utility function
  • Function creates output RaggedIterDomain from input ragged domains

Error handling

alias.cpp: Add ragged tensor validation to alias operations
csrc/ops/alias.cpp (+48/-1)

  • Added hasRaggedIterDomain() checks in reshape, flatten, pad, cat, slice, expand, repeat operations
  • These operations now explicitly reject tensors with RaggedIterDomain for safety
  • Updated asNested to check for existing ragged structure before creating nested tensors

arith.cpp: Prevent reduction on RaggedIterDomain dimensions
csrc/ops/arith.cpp (+8/-0)

  • Added validation in newForReduction to prevent reduction of RaggedIterDomain dimensions
  • Reduction operations on ragged dimensions are explicitly rejected with error messages

indexing.cpp: Add ragged tensor validation to indexing operations
csrc/ops/indexing.cpp (+52/-0)

  • Added hasRaggedIterDomain() checks in select, indexSelect, indexPutAccumulate, gather, scatter operations
  • These indexing operations now reject tensors with RaggedIterDomain for safety

Tests

test_ragged_iter_domain.cpp: Add comprehensive test coverage for RaggedIterDomain operations
tests/cpp/test_ragged_iter_domain.cpp (+485/-1)

  • Added comprehensive test coverage for nested tensor operations (load/store, binary/unary ops)
  • Added tests for broadcast, squeeze, unsqueeze, permute operations on nested tensors
  • Added reduction tests (valid on non-ragged dims, error on ragged dims)
  • Added error case tests for unsupported operations (reshape, flatten, slice, cat, pad)
  • Tests verify RaggedIterDomain structure preservation and proper error handling

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review
    Incomplete Implementation

    The cloneWithoutRFactor method for RaggedIterDomain has a TODO comment and throws "Not implemented" when map_with_original is true. This suggests the implementation is incomplete and may cause runtime errors if this functionality is needed.

    IterDomain* RaggedIterDomain::cloneWithoutRFactor(bool map_with_original) {
      // Create a new RaggedIterDomain with the same extents and properties
      auto cloned = IrBuilder::create<RaggedIterDomain>(
          extents_, getIterType(), getParallelType());
    
      // Optionally map the clone with the original in the Exact graph
      if (map_with_original) {
        // TODO: Implement mapping if needed
        NVF_THROW("Not implemented");
      }
    
      return cloned;
    }
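Why the override matters can be shown with a standalone model: once the clone method is virtual, cloning through a base-class reference still yields the derived type. `Domain` and `RaggedDomain` below are hypothetical stand-ins, and the exception type is an assumption made for the sketch (the quoted code uses NVF_THROW and creates the clone before the check).

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical model; not the nvFuser IterDomain hierarchy.
struct Domain {
  virtual ~Domain() = default;

  // The base model has nothing to map, so the flag is ignored here.
  virtual std::unique_ptr<Domain> cloneWithoutRFactor(
      bool /*map_with_original*/) const {
    return std::make_unique<Domain>();
  }
};

struct RaggedDomain : Domain {
  explicit RaggedDomain(std::vector<int> per_component_extents)
      : extents(std::move(per_component_extents)) {}

  // Preserves the ragged structure; mapping is left unimplemented, as in
  // the quoted snippet.
  std::unique_ptr<Domain> cloneWithoutRFactor(
      bool map_with_original) const override {
    if (map_with_original) {
      throw std::logic_error("Mapping is not implemented for ragged domains");
    }
    return std::make_unique<RaggedDomain>(extents);
  }

  std::vector<int> extents;
};
```

Without the `virtual` on the base method, a call through `const Domain&` would build a plain `Domain` and silently lose the extents.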
    Potential Runtime Error

    The newOutputRaggedIterDomain function assumes all input IDs are RaggedIterDomain and uses the first one as a reference. If called with mixed input types or empty vector, this could cause runtime errors or unexpected behavior.

    RaggedIterDomain* newOutputRaggedIterDomain(
        const std::vector<IterDomain*>& input_ids) {
      NVF_ERROR(
          std::ranges::all_of(
              input_ids,
              [](IterDomain* input_id) {
                return input_id->isA<RaggedIterDomain>();
              }),
          "All input iter domains must be RaggedIterDomain");
    
      NVF_ERROR(!input_ids.empty());
    
      // Just using the first ragged ID as all input IDs are assumed to be
      // equivalent
      RaggedIterDomain* ref_input_id = input_ids.front()->as<RaggedIterDomain>();
    
      return IrBuilder::create<RaggedIterDomain>(
          ref_input_id->extents(),
          ref_input_id->getIterType(),
          ref_input_id->getParallelType());
    }
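One detail of the quoted function is worth spelling out: an all-of test over an empty range is vacuously true, so the emptiness check after it is what rejects an empty input. The sketch below models this with hypothetical names (`Id`, `allRagged`, `pickRaggedReference`) and standard exceptions in place of NVF_ERROR.

```cpp
#include <algorithm>
#include <cassert>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for an iter domain; only tracks raggedness.
struct Id {
  bool is_ragged;
};

// True when every element is ragged. Also true for an empty vector.
bool allRagged(const std::vector<Id>& ids) {
  return std::all_of(
      ids.begin(), ids.end(), [](const Id& id) { return id.is_ragged; });
}

// Mirrors the order of checks in the quoted snippet: all-ragged first,
// then non-empty, then use the first element as the reference.
const Id& pickRaggedReference(const std::vector<Id>& ids) {
  if (!allRagged(ids)) {
    throw std::invalid_argument("All input iter domains must be ragged");
  }
  if (ids.empty()) {
    throw std::invalid_argument("At least one input iter domain is required");
  }
  return ids.front();
}
```

Because both checks run before `front()` is touched, neither a mixed nor an empty input reaches the dereference in this model.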
    Test Coverage Gap

    While comprehensive tests were added for basic operations, there are no tests specifically for the cloneWithoutRFactor functionality that was implemented. This is a key feature of the PR and should have dedicated tests.

    TEST_F(RaggedIterDomainTest, LoadStoreWithNestedTensor) {
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      auto data = makeSymbolicTensor(2, DataType::Float);
      fusion.addInput(data);
    
      auto offsets = makeSymbolicTensor(1, DataType::Index);
      fusion.addInput(offsets);
    
      // Create nested tensor from dimension 0
      auto nested = asNested(data, offsets, 0);
    
      // This should still be a nested tensor
      auto copy_of_nested = set(nested);
    
      fusion.addOutput(copy_of_nested);
    
      // Verify the output is a new TensorView
      EXPECT_TRUE(copy_of_nested != nullptr);
      EXPECT_NE(copy_of_nested, data);
      EXPECT_TRUE(copy_of_nested->isA<TensorView>());
    
      // Verify copy_of_nested tensor has 3 dimensions: [component, ragged,
      // original_dim1]
      EXPECT_EQ(copy_of_nested->nDims(), 3);
    
      // First axis should be a regular IterDomain (component)
      EXPECT_TRUE(copy_of_nested->axis(0)->isStrictlyA<IterDomain>());
      EXPECT_FALSE(copy_of_nested->axis(0)->isA<RaggedIterDomain>());
    
      // Second axis should be a RaggedIterDomain
      EXPECT_TRUE(copy_of_nested->axis(1)->isA<RaggedIterDomain>());
    
      // Third axis should be the original second dimension
      EXPECT_TRUE(copy_of_nested->axis(2)->isStrictlyA<IterDomain>());
    
      // The copy of the nested tensor does not inherit the
      // Partition op
      EXPECT_TRUE(copy_of_nested->axis(0)->definition() == nullptr);
    }
    
    // Test binary operations with nested tensors
    TEST_F(RaggedIterDomainTest, BinaryOpWithNestedTensors) {
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      // Create two 2D input tensors
      auto data1 = makeSymbolicTensor(2, DataType::Float);
      fusion.addInput(data1);
    
      auto data2 = makeSymbolicTensor(2, DataType::Float);
      fusion.addInput(data2);
    
      auto offsets = makeSymbolicTensor(1, DataType::Index);
      fusion.addInput(offsets);
    
      // Create nested tensors from both inputs
      auto nested1 = asNested(data1, offsets, 0);
      auto nested2 = asNested(data2, offsets, 0);
    
      // Perform binary operation. The result should be a nested tensor
      auto result = add(nested1, nested2);
    
      fusion.addOutput(result);
    
      // Verify the result has 3 dimensions: [component, ragged, original_dim1]
      EXPECT_EQ(result->nDims(), 3);
    
      // First axis should be a regular IterDomain (component)
      EXPECT_TRUE(result->axis(0)->isStrictlyA<IterDomain>());
      EXPECT_FALSE(result->axis(0)->isA<RaggedIterDomain>());
    
      // Second axis should be a RaggedIterDomain
      EXPECT_TRUE(result->axis(1)->isA<RaggedIterDomain>());
    
      // Third axis should be the original second dimension
      EXPECT_TRUE(result->axis(2)->isStrictlyA<IterDomain>());
    }
    
    // Test binary operation with mixed inputs (one ragged, one not) - should error
    TEST_F(RaggedIterDomainTest, BinaryOpMixedInputsError) {
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      auto data1 = makeSymbolicTensor(2, DataType::Float);
      fusion.addInput(data1);
    
      auto data2 = makeSymbolicTensor(2, DataType::Float);
      fusion.addInput(data2);
    
      auto offsets = makeSymbolicTensor(1, DataType::Index);
      fusion.addInput(offsets);
    
      // Create nested tensor from first input only
      auto nested1 = asNested(data1, offsets, 0);
    
      // Try to add nested tensor with non-nested tensor
      // This should fail because one is ragged and one is not
      EXPECT_THROW(add(nested1, data2), nvfuser::nvfError);
    }
    
    // Test binary operation with different offsets
    TEST_F(RaggedIterDomainTest, BinaryOpDifferentRaggedStructures) {
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      auto data1 = makeSymbolicTensor(2, DataType::Float);
      fusion.addInput(data1);
    
      auto data2 = makeSymbolicTensor(2, DataType::Float);
      fusion.addInput(data2);
    
      auto offsets1 = makeSymbolicTensor(1, DataType::Index);
      fusion.addInput(offsets1);
    
      auto offsets2 = makeSymbolicTensor(1, DataType::Index);
      fusion.addInput(offsets2);
    
      // Create nested tensors with different offset tensors
      auto nested1 = asNested(data1, offsets1, 0);
      auto nested2 = asNested(data2, offsets2, 0);
    
      // This would be an error if, for example, the values of the offset
      // tensors are not equivalent, but, like binary ops with normal
      // tensors, we assume their shapes are indeed compatible
      auto result = add(nested1, nested2);
      fusion.addOutput(result);
    
      EXPECT_TRUE(result->axis(1)->isA<RaggedIterDomain>());
    }
    
    // Test unary operations with nested tensors
    TEST_F(RaggedIterDomainTest, UnaryOpWithNestedTensors) {
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      auto data = makeSymbolicTensor(2, DataType::Float);
      fusion.addInput(data);
    
      auto offsets = makeSymbolicTensor(1, DataType::Index);
      fusion.addInput(offsets);
    
      // Create nested tensor
      auto nested = asNested(data, offsets, 0);
    
      // Perform unary operation: neg
      auto result = neg(nested);
    
      fusion.addOutput(result);
    
      // Verify the result preserves RaggedIterDomain structure
      EXPECT_EQ(result->nDims(), 3);
      EXPECT_TRUE(result->axis(0)->isStrictlyA<IterDomain>());
      EXPECT_TRUE(result->axis(1)->isA<RaggedIterDomain>());
      EXPECT_TRUE(result->axis(2)->isStrictlyA<IterDomain>());
    }
    
    // Test broadcast with nested tensors
    TEST_F(RaggedIterDomainTest, BroadcastWithNestedTensors) {
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      auto data = makeSymbolicTensor(2, DataType::Float);
      fusion.addInput(data);
    
      auto offsets = makeSymbolicTensor(1, DataType::Index);
      fusion.addInput(offsets);
    
      // Create nested tensor: [component, ragged, dim1]
      auto nested = asNested(data, offsets, 0);
    
      auto result = broadcast(nested, {false, false, false, true});
    
      fusion.addOutput(result);
    
      // Result should be: [component, ragged, dim1, broadcast_dim]
      EXPECT_EQ(result->nDims(), 4);
      EXPECT_TRUE(result->axis(0)->isStrictlyA<IterDomain>());
      EXPECT_TRUE(result->axis(1)->isA<RaggedIterDomain>());
      EXPECT_TRUE(result->axis(2)->isStrictlyA<IterDomain>());
      EXPECT_TRUE(result->axis(3)->isStrictlyA<IterDomain>());
      EXPECT_TRUE(result->axis(3)->isBroadcast());
    }
    
    // Test squeeze on non-ragged dimension
    TEST_F(RaggedIterDomainTest, SqueezeNonRaggedDim) {
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      auto data = makeSymbolicTensor(2, DataType::Float);
      fusion.addInput(data);
    
      auto offsets = makeSymbolicTensor(1, DataType::Index);
      fusion.addInput(offsets);
    
      // Create nested tensor: [component, ragged, dim1]
      auto nested = asNested(data, offsets, 0);
    
      // First broadcast to add a dimension: [component, ragged, dim1, 1]
      auto broadcasted = broadcast(nested, {false, false, false, true});
    
      // Then squeeze the broadcast dimension (dimension index 3)
      auto result = squeeze(broadcasted, {3});
    
      fusion.addOutput(result);
    
      // Result should be: [component, ragged, dim1]
      EXPECT_EQ(result->nDims(), 3);
      EXPECT_TRUE(result->axis(0)->isStrictlyA<IterDomain>());
      EXPECT_TRUE(result->axis(1)->isA<RaggedIterDomain>());
      EXPECT_TRUE(result->axis(2)->isStrictlyA<IterDomain>());
    }
    
    // Test unsqueeze with nested tensors

    @naoyam naoyam marked this pull request as ready for review December 18, 2025 06:47
    @greptile-apps
    Contributor

    greptile-apps bot commented Dec 18, 2025

    Greptile Summary

    Extends IterDomainBuilder and cloning infrastructure to support RaggedIterDomain, enabling ops like set, add, etc. to correctly propagate ragged domains to output tensors.

    Key Changes:

    • Added ragged_extents field and builder method to IterDomainBuilder for creating RaggedIterDomain instances
    • Implemented RaggedIterDomain::cloneWithoutRFactor override (mapping logic incomplete - throws if map_with_original=true)
    • Extended newOutputIterDomain to detect ragged inputs and create ragged outputs via newOutputRaggedIterDomain
    • Added validation checks across alias, indexing, and arithmetic ops to reject unsupported operations on ragged tensors (reshape, flatten, pad, cat, slice, expand, repeat, select, gather, scatter, reduction)
    • Comprehensive test coverage for ragged domain cloning through unary/binary ops

    Issues Found:

    • Logic mismatch in newOutputIterDomain (line 352): uses std::any_of to check for ragged inputs but calls newOutputRaggedIterDomain which requires ALL inputs to be ragged (std::ranges::all_of). This will cause confusing errors when mixing ragged and non-ragged domains.

    Confidence Score: 3/5

    • This PR has a logic bug that will cause runtime errors in mixed ragged/non-ragged scenarios, but is otherwise well-structured
    • The implementation correctly extends the cloning infrastructure for RaggedIterDomain and adds comprehensive validation checks. However, there's a critical logic mismatch in newOutputIterDomain that checks if ANY input is ragged but then requires ALL inputs to be ragged, which will cause confusing errors. The cloneWithoutRFactor mapping implementation is incomplete (throws if map_with_original=true). Test coverage is good and validates expected behaviors including the mixed-input error case.
    • Pay close attention to csrc/ops/utils.cpp (logic mismatch at lines 352-361) and csrc/ir/internal_base_nodes.cpp (incomplete mapping implementation at line 991)

    Important Files Changed

    Filename | Overview
    csrc/ir/internal_base_nodes.h | Added ragged_extents field and builder method to IterDomainBuilder, and made cloneWithoutRFactor virtual in IterDomain
    csrc/ir/internal_base_nodes.cpp | Added RaggedIterDomain constructor from builder args, implemented cloneWithoutRFactor override with incomplete mapping logic
    csrc/ops/utils.h | Added newOutputRaggedIterDomain function declaration
    csrc/ops/utils.cpp | Implemented ragged domain output creation, but doesn't validate that all input ragged domains have equivalent extents

    Sequence Diagram

    sequenceDiagram
        participant User as User/Op
        participant Builder as IterDomainBuilder
        participant Factory as Build/Clone
        participant RaggedID as RaggedIterDomain
        participant Utils as ops::newOutputIterDomain
    
        Note over User,Utils: Creating RaggedIterDomain from scratch
        User->>Builder: IterDomainBuilder(start, extent)
        User->>Builder: ragged_extents(extents_tv)
        User->>Builder: build()
        Builder->>Factory: Check if ragged_extents != nullptr
        Factory->>RaggedID: Create RaggedIterDomain(builder_args)
        RaggedID-->>User: Return RaggedIterDomain
    
        Note over User,Utils: Cloning RaggedIterDomain
        User->>RaggedID: cloneWithoutRFactor(map_with_original)
        RaggedID->>Builder: Create with extents, iterType, parallelType
        Builder->>RaggedID: IrBuilder::create<RaggedIterDomain>
        alt map_with_original = true
            RaggedID->>RaggedID: NVF_THROW("Not implemented")
        end
        RaggedID-->>User: Return cloned RaggedIterDomain
    
        Note over User,Utils: Creating output from ragged inputs (unary/binary ops)
        User->>Utils: newOutputIterDomain(input_ids)
        Utils->>Utils: Check if any input isA<RaggedIterDomain>
        alt has_ragged = true
            Utils->>Utils: newOutputRaggedIterDomain(input_ids)
            Utils->>Utils: Verify all_of inputs are RaggedIterDomain
            Utils->>RaggedID: Create with first input's extents
            RaggedID-->>User: Return output RaggedIterDomain
        else has_ragged = false
            Utils->>Factory: Create regular IterDomain
            Factory-->>User: Return regular IterDomain
        end
    
    Contributor

    @greptile-apps greptile-apps bot left a comment

    8 files reviewed, 4 comments

    Comment on lines 533 to 536
    NVF_CHECK(
    inp->domain()->hasRaggedIterDomain(),
    "Padding a tensor with RaggedIterDomain not supported: ",
    inp->toString());

    logic: the check is inverted. It should be !inp->domain()->hasRaggedIterDomain() so that ragged tensors are rejected.

    Suggested change
    - NVF_CHECK(
    -     inp->domain()->hasRaggedIterDomain(),
    -     "Padding a tensor with RaggedIterDomain not supported: ",
    -     inp->toString());
    + NVF_CHECK(
    +     !inp->domain()->hasRaggedIterDomain(),
    +     "Padding a tensor with RaggedIterDomain not supported: ",
    +     inp->toString());
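The effect of the inverted condition is easy to see in a standalone model. The sketch assumes, as the review comment implies, that the check helper throws when its first argument is false. `check`, `Tensor`, `padInverted`, `padFixed`, and `throwsOn` are hypothetical names, not nvFuser APIs.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Hypothetical helper following the "throw when the condition is false"
// convention assumed for NVF_CHECK.
void check(bool condition, const std::string& message) {
  if (!condition) {
    throw std::runtime_error(message);
  }
}

// Hypothetical tensor model that only records whether it has a ragged dim.
struct Tensor {
  bool has_ragged_dim;
};

// Inverted guard: passes for ragged tensors and throws for regular ones.
void padInverted(const Tensor& t) {
  check(t.has_ragged_dim, "Padding a ragged tensor is not supported");
}

// Intended guard: throws for ragged tensors and passes for regular ones.
void padFixed(const Tensor& t) {
  check(!t.has_ragged_dim, "Padding a ragged tensor is not supported");
}

// Reports whether calling op on t throws.
bool throwsOn(void (*op)(const Tensor&), const Tensor& t) {
  try {
    op(t);
  } catch (const std::runtime_error&) {
    return true;
  }
  return false;
}
```

With the inverted guard, every ordinary pad call would fail while ragged inputs slip through, which is the opposite of what the error message says.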

    IterDomainBuilder& is_rfactor_domain(bool _is_rfactor_domain);
    IterDomainBuilder& is_padded_dimension(bool _is_padded_dimension);
    IterDomainBuilder& padded_to_size(std::optional<int64_t> _padded_to_size);
    IterDomainBuilder& ragged_extents(TensorView* _ragged_extents);

    syntax: missing implementation for ragged_extents method - declared in header but not implemented in csrc/ir/internal_base_nodes.cpp

    Suggested change
    - IterDomainBuilder& ragged_extents(TensorView* _ragged_extents);
    + // IterDomainBuilder& ragged_extents(TensorView* _ragged_extents); // TODO: implement or use direct member access

    Comment on lines +330 to +332
    // Just using the first ragged ID as all input IDs are assumed to be
    // equivalent
    RaggedIterDomain* ref_input_id = input_ids.front()->as<RaggedIterDomain>();

    style: assumes all input RaggedIterDomains are equivalent without validation - consider verifying extents tensors match

    Comment on lines +985 to +988
    if (map_with_original) {
    // TODO: Implement mapping if needed
    NVF_THROW("Not implemented");
    }

    style: mapping implementation missing - should call fusion()->registerExactMapping(this, cloned) like base IterDomain::cloneWithoutRFactor does (line 334)

    @naoyam
    Collaborator Author

    naoyam commented Dec 18, 2025

    !test

    @naoyam
    Collaborator Author

    naoyam commented Dec 18, 2025

    !test

    Contributor

    @greptile-apps greptile-apps bot left a comment

    8 files reviewed, 1 comment

    Comment on lines +352 to +361
    bool has_ragged =
    std::any_of(input_ids.begin(), input_ids.end(), [](IterDomain* id) {
    return id->isA<RaggedIterDomain>();
    });

    if (has_ragged) {
    NVF_ERROR(
    !force_iter_type.has_value(),
    "force_iter_type not supported for RaggedIterDomain");
    return newOutputRaggedIterDomain(input_ids);

    logic: logic mismatch: checks if ANY input is ragged (std::any_of) but newOutputRaggedIterDomain requires ALL inputs to be ragged (std::ranges::all_of at line 321). This will fail when mixing ragged and non-ragged domains at the same position.

    Should either:

    1. Check std::ranges::all_of here instead of std::any_of
    2. Or add validation in newOutputRaggedIterDomain to filter/handle mixed cases
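The any-of versus all-of mismatch can be made concrete with a small dispatch model. It is a sketch of one way to get an explicit error for the mixed case, not a proposed patch: `Id`, `OutputKind`, and `chooseOutputKind` are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <stdexcept>
#include <vector>

// Hypothetical stand-ins; not the nvFuser types.
struct Id {
  bool is_ragged;
};

enum class OutputKind { Regular, Ragged };

// Decides which kind of output domain to build for one logical position.
// Mixed ragged and non-ragged inputs are rejected with a dedicated message
// instead of failing later inside the ragged-only helper.
OutputKind chooseOutputKind(const std::vector<Id>& ids) {
  auto is_ragged = [](const Id& id) { return id.is_ragged; };
  const bool any = std::any_of(ids.begin(), ids.end(), is_ragged);
  const bool all = std::all_of(ids.begin(), ids.end(), is_ragged);

  if (any && !all) {
    throw std::invalid_argument(
        "Cannot mix ragged and non-ragged iter domains at the same position");
  }
  return any ? OutputKind::Ragged : OutputKind::Regular;
}
```

Either behavior still rejects mixed inputs, which is what the BinaryOpMixedInputsError test expects; the difference is only in where the failure is reported and how clear the message is.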

    @naoyam naoyam changed the title from "[WIP] RaggedIterDomain cloning" to "RaggedIterDomain cloning" Dec 19, 2025
    2 participants