From 587471d4a226ce4a8c0c390eb0580e58f62576e5 Mon Sep 17 00:00:00 2001
From: Aleksei Nurmukhametov
Date: Wed, 14 Jan 2026 06:18:40 -0600
Subject: [PATCH] [ROCm][XLA:GPU] Prefer loop emitter for sibling concat
 fusions

Concat fusions (kConcatenate) cannot be merged by multi_output_fusion.
When sibling concats share inputs, this leads to duplicate memory reads
and serialized kernel launches. Fall back to the loop emitter so that
multi-output fusion can later merge the siblings into a single kernel.
---
 xla/service/gpu/hlo_fusion_analysis.cc      | 87 ++++++++++++++++++
 xla/service/gpu/hlo_fusion_analysis_test.cc | 99 +++++++++++++++++++++
 2 files changed, 186 insertions(+)

diff --git a/xla/service/gpu/hlo_fusion_analysis.cc b/xla/service/gpu/hlo_fusion_analysis.cc
index 85368317fa137..3d42324c3048c 100644
--- a/xla/service/gpu/hlo_fusion_analysis.cc
+++ b/xla/service/gpu/hlo_fusion_analysis.cc
@@ -23,6 +23,7 @@ limitations under the License.
 #include <vector>
 
 #include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
 #include "absl/container/inlined_vector.h"
 #include "absl/log/check.h"
 #include "absl/log/log.h"
@@ -87,6 +88,85 @@ std::optional<TransposeDescription> FindConsistentTransposeHero(
   return tiled_transpose_hero;
 }
 
+// Returns true if the instruction is a trivial data-moving operation.
+bool IsTrivialDataMover(const HloInstruction* instr) {
+  switch (instr->opcode()) {
+    case HloOpcode::kSlice:
+    case HloOpcode::kReshape:
+    case HloOpcode::kBroadcast:
+    case HloOpcode::kBitcast:
+    case HloOpcode::kTranspose:
+    case HloOpcode::kCopy:
+      return true;
+    default:
+      return false;
+  }
+}
+
+// Collects the "root" inputs of an instruction by traversing through trivial
+// data-moving operations (slice, bitcast, etc.). These are the real sources
+// of data that the instruction depends on.
+void CollectRootInputs(const HloInstruction* instr,
+                       absl::flat_hash_set<const HloInstruction*>& roots) {
+  if (IsTrivialDataMover(instr)) {
+    for (const HloInstruction* operand : instr->operands()) {
+      CollectRootInputs(operand, roots);
+    }
+  } else {
+    roots.insert(instr);
+  }
+}
+
+// Returns true if the fusion's root is a concatenate.
+bool FusionHasConcatenateRoot(const HloInstruction* fusion) {
+  if (fusion->opcode() != HloOpcode::kFusion) return false;
+  return fusion->fused_instructions_computation()
+             ->root_instruction()
+             ->opcode() == HloOpcode::kConcatenate;
+}
+
+// Returns true if the concat shares inputs with another concat (or fusion
+// containing a concat) in the parent computation.
+bool SharesInputsWithSiblingConcatenates(const HloInstruction& concat) {
+  const HloComputation* comp = concat.parent();
+
+  // If inside a fused computation, use the fusion instruction as our
+  // representative and check siblings in the fusion's parent computation.
+  const HloInstruction* our_instr = &concat;
+  if (comp->IsFusionComputation()) {
+    our_instr = comp->FusionInstruction();
+    comp = our_instr->parent();
+  }
+
+  // Collect root inputs for our instruction
+  absl::flat_hash_set<const HloInstruction*> our_roots;
+  for (const HloInstruction* operand : our_instr->operands()) {
+    CollectRootInputs(operand, our_roots);
+  }
+
+  // Check other concats and concat-fusions for shared inputs
+  for (const HloInstruction* instr : comp->instructions()) {
+    if (instr == our_instr) continue;
+
+    bool is_concat = instr->opcode() == HloOpcode::kConcatenate;
+    bool is_concat_fusion = FusionHasConcatenateRoot(instr);
+    if (!is_concat && !is_concat_fusion) continue;
+
+    absl::flat_hash_set<const HloInstruction*> other_roots;
+    for (const HloInstruction* operand : instr->operands()) {
+      CollectRootInputs(operand, other_roots);
+    }
+
+    for (const HloInstruction* root : other_roots) {
+      if (our_roots.contains(root)) {
+        return true;  // Found shared input with sibling concat
+      }
+    }
+  }
+
+  return false;
+}
+
 bool UseConcatenateFusion(absl::Span<const HloInstructionAdaptor>
     roots, absl::Span<const HloInstructionAdaptor> heroes) {
   if (heroes.size() != 1) return false;
@@ -96,6 +176,13 @@ bool UseConcatenateFusion(absl::Span<const HloInstructionAdaptor> roots,
   // Limit the number of operands because the concat emitter produces code for
   // each operand, hurting occupancy.
   if (heroes.front().instruction().operand_count() > 4) return false;
+  // Don't use concat fusion if there may be sibling concats that share the
+  // same input. The loop emitter is preferred because it can be merged via
+  // multi-output fusion, avoiding duplicate computation and memory reads.
+  // TODO(nurmukhametov): Remove once concat emitter supports multiple outputs.
+  if (SharesInputsWithSiblingConcatenates(heroes.front().instruction())) {
+    return false;
+  }
   // The loop emitter is faster when warp divergence and occupancy are both low.
   // TODO(csigg): exclude this case.
   return true;
diff --git a/xla/service/gpu/hlo_fusion_analysis_test.cc b/xla/service/gpu/hlo_fusion_analysis_test.cc
index 93914ae323263..8262fda5e464d 100644
--- a/xla/service/gpu/hlo_fusion_analysis_test.cc
+++ b/xla/service/gpu/hlo_fusion_analysis_test.cc
@@ -510,5 +510,104 @@ TEST_F(HloFusionAnalysisTest, ConcatenateFusionFallbackToLoop) {
   EXPECT_EQ(&analysis.fusion_hero(0).instruction(), multiply);
 }
 
+// Tests that when two concat fusions share computation through slices (like in
+// concat_fusion_minimal.hlo), the emitter falls back to kLoop to allow
+// multi-output fusion to merge them later.
+TEST_F(HloFusionAnalysisTest, ConcatenateFusionWithSiblingSharingInputs) {
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
+    HloModule module
+
+    fusion1 {
+      a = f32[64,64] parameter(0)
+      b = f32[64,64] parameter(1)
+      s = f32[64,64] add(a, b)
+      top = f32[32,64] slice(s), slice={[0:32],[0:64]}
+      bot = f32[32,64] slice(s), slice={[32:64],[0:64]}
+      ROOT concatenate = f32[64,64] concatenate(top, bot), dimensions={0}
+    }
+
+    fusion2 {
+      a = f32[64,64] parameter(0)
+      b = f32[64,64] parameter(1)
+      d = f32[64,64] subtract(a, b)
+      left = f32[64,32] slice(d), slice={[0:64],[0:32]}
+      right = f32[64,32] slice(d), slice={[0:64],[32:64]}
+      ROOT concatenate = f32[64,64] concatenate(left, right), dimensions={1}
+    }
+
+    ENTRY entry_computation {
+      a = f32[64,64] parameter(0)
+      b = f32[64,64] parameter(1)
+      fusion.1 = f32[64,64] fusion(a, b), kind=kInput, calls=fusion1
+      fusion.2 = f32[64,64] fusion(a, b), kind=kInput, calls=fusion2
+      ROOT tuple = (f32[64,64], f32[64,64]) tuple(fusion.1, fusion.2)
+    })"));
+
+  auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+
+  auto* entry = module->entry_computation();
+  auto* fusion1 = entry->GetInstructionWithName("fusion.1");
+  auto* fusion2 = entry->GetInstructionWithName("fusion.2");
+
+  // Both fusions share inputs a and b, so they should fall back to kLoop
+  // to enable multi-output fusion.
+  auto analysis1 = HloFusionAnalysis::Create(*fusion1, device_info);
+  EXPECT_EQ(analysis1.emitter_fusion_kind(),
+            HloFusionAnalysis::EmitterFusionKind::kLoop);
+
+  auto analysis2 = HloFusionAnalysis::Create(*fusion2, device_info);
+  EXPECT_EQ(analysis2.emitter_fusion_kind(),
+            HloFusionAnalysis::EmitterFusionKind::kLoop);
+}
+
+// Tests that concat fusions with non-overlapping inputs use kConcatenate.
+TEST_F(HloFusionAnalysisTest, ConcatenateFusionsWithNonOverlappingInputs) {
+  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"(
+    HloModule module
+
+    fusion1 {
+      a = f32[64,64] parameter(0)
+      b = f32[64,64] parameter(1)
+      s = f32[64,64] add(a, b)
+      top = f32[32,64] slice(s), slice={[0:32],[0:64]}
+      bot = f32[32,64] slice(s), slice={[32:64],[0:64]}
+      ROOT concatenate = f32[64,64] concatenate(top, bot), dimensions={0}
+    }
+
+    fusion2 {
+      c = f32[64,64] parameter(0)
+      d = f32[64,64] parameter(1)
+      m = f32[64,64] multiply(c, d)
+      left = f32[64,32] slice(m), slice={[0:64],[0:32]}
+      right = f32[64,32] slice(m), slice={[0:64],[32:64]}
+      ROOT concatenate = f32[64,64] concatenate(left, right), dimensions={1}
+    }
+
+    ENTRY entry_computation {
+      a = f32[64,64] parameter(0)
+      b = f32[64,64] parameter(1)
+      c = f32[64,64] parameter(2)
+      d = f32[64,64] parameter(3)
+      fusion.1 = f32[64,64] fusion(a, b), kind=kInput, calls=fusion1
+      fusion.2 = f32[64,64] fusion(c, d), kind=kInput, calls=fusion2
+      ROOT tuple = (f32[64,64], f32[64,64]) tuple(fusion.1, fusion.2)
+    })"));
+
+  auto device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo();
+
+  auto* entry = module->entry_computation();
+  auto* fusion1 = entry->GetInstructionWithName("fusion.1");
+  auto* fusion2 = entry->GetInstructionWithName("fusion.2");
+
+  // Different inputs (a,b vs c,d), so both should use kConcatenate.
+  auto analysis1 = HloFusionAnalysis::Create(*fusion1, device_info);
+  EXPECT_EQ(analysis1.emitter_fusion_kind(),
+            HloFusionAnalysis::EmitterFusionKind::kConcatenate);
+
+  auto analysis2 = HloFusionAnalysis::Create(*fusion2, device_info);
+  EXPECT_EQ(analysis2.emitter_fusion_kind(),
+            HloFusionAnalysis::EmitterFusionKind::kConcatenate);
+}
+
 }  // namespace
 }  // namespace xla::gpu