Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions xla/service/gpu/hlo_fusion_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ limitations under the License.
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
Expand Down Expand Up @@ -87,6 +88,85 @@ std::optional<TransposeDescription> FindConsistentTransposeHero(
return tiled_transpose_hero;
}

// Returns true if the instruction is a trivial data-moving operation.
bool IsTrivialDataMover(const HloInstruction* instr) {
switch (instr->opcode()) {
case HloOpcode::kSlice:
case HloOpcode::kReshape:
case HloOpcode::kBroadcast:
case HloOpcode::kBitcast:
case HloOpcode::kTranspose:
case HloOpcode::kCopy:
return true;
default:
return false;
}
}

// Collects the "root" inputs of an instruction by traversing through trivial
// data-moving operations (slice, bitcast, etc.). These are the real sources
// of data that the instruction depends on.
void CollectRootInputs(const HloInstruction* instr,
absl::flat_hash_set<const HloInstruction*>& roots) {
if (IsTrivialDataMover(instr)) {
for (const HloInstruction* operand : instr->operands()) {
CollectRootInputs(operand, roots);
}
} else {
roots.insert(instr);
}
}

// Returns true if the fusion's root is a concatenate.
bool FusionHasConcatenateRoot(const HloInstruction* fusion) {
if (fusion->opcode() != HloOpcode::kFusion) return false;
return fusion->fused_instructions_computation()
->root_instruction()
->opcode() == HloOpcode::kConcatenate;
}

// Returns true if the concat shares inputs with another concat (or fusion
// containing a concat) in the parent computation.
bool SharesInputsWithSiblingConcatenates(const HloInstruction& concat) {
const HloComputation* comp = concat.parent();

// If inside a fused computation, use the fusion instruction as our
// representative and check siblings in the fusion's parent computation.
const HloInstruction* our_instr = &concat;
if (comp->IsFusionComputation()) {
our_instr = comp->FusionInstruction();
comp = our_instr->parent();
}

// Collect root inputs for our instruction
absl::flat_hash_set<const HloInstruction*> our_roots;
for (const HloInstruction* operand : our_instr->operands()) {
CollectRootInputs(operand, our_roots);
}

// Check other concats and concat-fusions for shared inputs
for (const HloInstruction* instr : comp->instructions()) {
if (instr == our_instr) continue;

bool is_concat = instr->opcode() == HloOpcode::kConcatenate;
bool is_concat_fusion = FusionHasConcatenateRoot(instr);
if (!is_concat && !is_concat_fusion) continue;

absl::flat_hash_set<const HloInstruction*> other_roots;
for (const HloInstruction* operand : instr->operands()) {
CollectRootInputs(operand, other_roots);
}

for (const HloInstruction* root : other_roots) {
if (our_roots.contains(root)) {
return true; // Found shared input with sibling concat
}
}
}

return false;
}

bool UseConcatenateFusion(absl::Span<const HloInstructionAdaptor> roots,
absl::Span<const HloInstructionAdaptor> heroes) {
if (heroes.size() != 1) return false;
Expand All @@ -96,6 +176,13 @@ bool UseConcatenateFusion(absl::Span<const HloInstructionAdaptor> roots,
// Limit the number of operands because the concat emitter produces code for
// each operand, hurting occupancy.
if (heroes.front().instruction().operand_count() > 4) return false;
// Don't use concat fusion if there may be sibling concats that share same
// input. The loop emitter is preferred because it can be merged via
// multi-output fusion, avoiding duplicate computation and memory reads.
// TODO: Remove this check when concat emitter supports multiple outputs.
if (SharesInputsWithSiblingConcatenates(heroes.front().instruction())) {
return false;
}
// The loop emitter is faster when warp divergence and occupancy are both low.
// TODO(csigg): exclude this case.
return true;
Expand Down
99 changes: 99 additions & 0 deletions xla/service/gpu/hlo_fusion_analysis_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -510,5 +510,104 @@ TEST_F(HloFusionAnalysisTest, ConcatenateFusionFallbackToLoop) {
EXPECT_EQ(&analysis.fusion_hero(0).instruction(), multiply);
}

// Two concat fusions that derive their operands from the same parameters via
// slices (as in concat_fusion_minimal.hlo) must pick the kLoop emitter so
// that multi-output fusion can merge them later and avoid duplicate
// computation and memory reads.
TEST_F(HloFusionAnalysisTest, ConcatenateFusionWithSiblingSharingInputs) {
  TF_ASSERT_OK_AND_ASSIGN(auto verified_module, ParseAndReturnVerifiedModule(R"(
HloModule module

fusion1 {
a = f32[64,64] parameter(0)
b = f32[64,64] parameter(1)
s = f32[64,64] add(a, b)
top = f32[32,64] slice(s), slice={[0:32],[0:64]}
bot = f32[32,64] slice(s), slice={[32:64],[0:64]}
ROOT concatenate = f32[64,64] concatenate(top, bot), dimensions={0}
}

fusion2 {
a = f32[64,64] parameter(0)
b = f32[64,64] parameter(1)
d = f32[64,64] subtract(a, b)
left = f32[64,32] slice(d), slice={[0:64],[0:32]}
right = f32[64,32] slice(d), slice={[0:64],[32:64]}
ROOT concatenate = f32[64,64] concatenate(left, right), dimensions={1}
}

ENTRY entry_computation {
a = f32[64,64] parameter(0)
b = f32[64,64] parameter(1)
fusion.1 = f32[64,64] fusion(a, b), kind=kInput, calls=fusion1
fusion.2 = f32[64,64] fusion(a, b), kind=kInput, calls=fusion2
ROOT tuple = (f32[64,64], f32[64,64]) tuple(fusion.1, fusion.2)
})"));

  auto gpu_info = TestGpuDeviceInfo::RTXA6000DeviceInfo();

  auto* entry_comp = verified_module->entry_computation();
  auto* concat_fusion_a = entry_comp->GetInstructionWithName("fusion.1");
  auto* concat_fusion_b = entry_comp->GetInstructionWithName("fusion.2");

  // Both fusions read parameters a and b, so each must fall back to kLoop to
  // keep multi-output fusion possible.
  auto analysis_a = HloFusionAnalysis::Create(*concat_fusion_a, gpu_info);
  EXPECT_EQ(analysis_a.emitter_fusion_kind(),
            HloFusionAnalysis::EmitterFusionKind::kLoop);

  auto analysis_b = HloFusionAnalysis::Create(*concat_fusion_b, gpu_info);
  EXPECT_EQ(analysis_b.emitter_fusion_kind(),
            HloFusionAnalysis::EmitterFusionKind::kLoop);
}

// Concat fusions whose inputs do not overlap keep the specialized
// kConcatenate emitter: there is no sharing for multi-output fusion to
// exploit, so falling back to kLoop would only cost performance.
TEST_F(HloFusionAnalysisTest, ConcatenateFusionsWithNonOverlappingInputs) {
  TF_ASSERT_OK_AND_ASSIGN(auto verified_module, ParseAndReturnVerifiedModule(R"(
HloModule module

fusion1 {
a = f32[64,64] parameter(0)
b = f32[64,64] parameter(1)
s = f32[64,64] add(a, b)
top = f32[32,64] slice(s), slice={[0:32],[0:64]}
bot = f32[32,64] slice(s), slice={[32:64],[0:64]}
ROOT concatenate = f32[64,64] concatenate(top, bot), dimensions={0}
}

fusion2 {
c = f32[64,64] parameter(0)
d = f32[64,64] parameter(1)
m = f32[64,64] multiply(c, d)
left = f32[64,32] slice(m), slice={[0:64],[0:32]}
right = f32[64,32] slice(m), slice={[0:64],[32:64]}
ROOT concatenate = f32[64,64] concatenate(left, right), dimensions={1}
}

ENTRY entry_computation {
a = f32[64,64] parameter(0)
b = f32[64,64] parameter(1)
c = f32[64,64] parameter(2)
d = f32[64,64] parameter(3)
fusion.1 = f32[64,64] fusion(a, b), kind=kInput, calls=fusion1
fusion.2 = f32[64,64] fusion(c, d), kind=kInput, calls=fusion2
ROOT tuple = (f32[64,64], f32[64,64]) tuple(fusion.1, fusion.2)
})"));

  auto gpu_info = TestGpuDeviceInfo::RTXA6000DeviceInfo();

  auto* entry_comp = verified_module->entry_computation();
  auto* concat_fusion_a = entry_comp->GetInstructionWithName("fusion.1");
  auto* concat_fusion_b = entry_comp->GetInstructionWithName("fusion.2");

  // Disjoint inputs (a,b vs. c,d): both fusions keep kConcatenate.
  auto analysis_a = HloFusionAnalysis::Create(*concat_fusion_a, gpu_info);
  EXPECT_EQ(analysis_a.emitter_fusion_kind(),
            HloFusionAnalysis::EmitterFusionKind::kConcatenate);

  auto analysis_b = HloFusionAnalysis::Create(*concat_fusion_b, gpu_info);
  EXPECT_EQ(analysis_b.emitter_fusion_kind(),
            HloFusionAnalysis::EmitterFusionKind::kConcatenate);
}

} // namespace
} // namespace xla::gpu
Loading