-
Notifications
You must be signed in to change notification settings - Fork 74
RaggedIterDomain cloning #5707
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: raggediterdomain-asnested
Are you sure you want to change the base?
RaggedIterDomain cloning #5707
Conversation
|
Review updated until commit 2dd9287 Description
|
| Relevant files | |||||||||
|---|---|---|---|---|---|---|---|---|---|
| Enhancement |
| ||||||||
| Error handling |
| ||||||||
| Tests |
|
PR Reviewer Guide
Here are some key observations to aid the review process:
| 🧪 PR contains tests |
| ⚡ Recommended focus areas for review |
Incomplete Implementation
cloneWithoutRFactor method for RaggedIterDomain has a TODO comment and throws "Not implemented" when map_with_original is true. This suggests the implementation is incomplete and may cause runtime errors if this functionality is needed. |
Greptile SummaryExtends Key Changes:
Issues Found:
Confidence Score: 3/5
Important Files Changed
Sequence DiagramsequenceDiagram
participant User as User/Op
participant Builder as IterDomainBuilder
participant Factory as Build/Clone
participant RaggedID as RaggedIterDomain
participant Utils as ops::newOutputIterDomain
Note over User,Utils: Creating RaggedIterDomain from scratch
User->>Builder: IterDomainBuilder(start, extent)
User->>Builder: ragged_extents(extents_tv)
User->>Builder: build()
Builder->>Factory: Check if ragged_extents != nullptr
Factory->>RaggedID: Create RaggedIterDomain(builder_args)
RaggedID-->>User: Return RaggedIterDomain
Note over User,Utils: Cloning RaggedIterDomain
User->>RaggedID: cloneWithoutRFactor(map_with_original)
RaggedID->>Builder: Create with extents, iterType, parallelType
Builder->>RaggedID: IrBuilder::create<RaggedIterDomain>
alt map_with_original = true
RaggedID->>RaggedID: NVF_THROW("Not implemented")
end
RaggedID-->>User: Return cloned RaggedIterDomain
Note over User,Utils: Creating output from ragged inputs (unary/binary ops)
User->>Utils: newOutputIterDomain(input_ids)
Utils->>Utils: Check if any input isA<RaggedIterDomain>
alt has_ragged = true
Utils->>Utils: newOutputRaggedIterDomain(input_ids)
Utils->>Utils: Verify all_of inputs are RaggedIterDomain
Utils->>RaggedID: Create with first input's extents
RaggedID-->>User: Return output RaggedIterDomain
else has_ragged = false
Utils->>Factory: Create regular IterDomain
Factory-->>User: Return regular IterDomain
end
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
8 files reviewed, 4 comments
| NVF_CHECK( | ||
| inp->domain()->hasRaggedIterDomain(), | ||
| "Padding a tensor with RaggedIterDomain not supported: ", | ||
| inp->toString()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
logic: logic error: check is inverted - should be !inp->domain()->hasRaggedIterDomain() to reject ragged tensors
| NVF_CHECK( | |
| inp->domain()->hasRaggedIterDomain(), | |
| "Padding a tensor with RaggedIterDomain not supported: ", | |
| inp->toString()); | |
| NVF_CHECK( | |
| !inp->domain()->hasRaggedIterDomain(), | |
| "Padding a tensor with RaggedIterDomain not supported: ", | |
| inp->toString()); |
csrc/ir/internal_base_nodes.h
Outdated
| IterDomainBuilder& is_rfactor_domain(bool _is_rfactor_domain); | ||
| IterDomainBuilder& is_padded_dimension(bool _is_padded_dimension); | ||
| IterDomainBuilder& padded_to_size(std::optional<int64_t> _padded_to_size); | ||
| IterDomainBuilder& ragged_extents(TensorView* _ragged_extents); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
syntax: missing implementation for ragged_extents method - declared in header but not implemented in csrc/ir/internal_base_nodes.cpp
| IterDomainBuilder& ragged_extents(TensorView* _ragged_extents); | |
| // IterDomainBuilder& ragged_extents(TensorView* _ragged_extents); // TODO: implement or use direct member access |
| // Just using the first ragged ID as all input IDs are assumed to be | ||
| // equivalent | ||
| RaggedIterDomain* ref_input_id = input_ids.front()->as<RaggedIterDomain>(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
style: assumes all input RaggedIterDomains are equivalent without validation - consider verifying extents tensors match
| if (map_with_original) { | ||
| // TODO: Implement mapping if needed | ||
| NVF_THROW("Not implemented"); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
style: mapping implementation missing - should call fusion()->registerExactMapping(this, cloned) like base IterDomain::cloneWithoutRFactor does (line 334)
|
!test |
|
!test |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
8 files reviewed, 1 comment
| bool has_ragged = | ||
| std::any_of(input_ids.begin(), input_ids.end(), [](IterDomain* id) { | ||
| return id->isA<RaggedIterDomain>(); | ||
| }); | ||
|
|
||
| if (has_ragged) { | ||
| NVF_ERROR( | ||
| !force_iter_type.has_value(), | ||
| "force_iter_type not supported for RaggedIterDomain"); | ||
| return newOutputRaggedIterDomain(input_ids); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
logic: logic mismatch: checks if ANY input is ragged (std::any_of) but newOutputRaggedIterDomain requires ALL inputs to be ragged (std::ranges::all_of at line 321). This will fail when mixing ragged and non-ragged domains at the same position.
Should either:
- Check
std::ranges::all_ofhere instead ofstd::any_of - Or add validation in
newOutputRaggedIterDomainto filter/handle mixed cases
Extends
IterDomainBuilderandIterDomain::cloneWithoutRfactorforRaggedIterDomainso that utils likeops::newOutputTVcan correctly createRaggedIterDomainwhen an input ID is ragged.This is mainly for allowing ops like
set,add, etc to not generate invalid output tensors. We are not doing lowering, so this is just exercising Fusion IR constructions. Specifically, when an input tensor of a unary op has aRaggedIterDomain, its output should create aRaggedIterDomainat the corresponding position of its logical domain.The ops in csrc/ops/arith.h, csrc/ops/alias.h, csrc/ops/indexing.h should either generate valid tensors or immediately fail. The ops in the other files are not yet considered.