diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index 281f936fd772bb..b7d21e4c665b22 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -244,11 +244,11 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_gpu_nccl_collective_max_nchannels(0); opts.set_xla_gpu_nccl_p2p_max_nchannels(0); -#if GOOGLE_CUDA - opts.set_xla_gpu_mlir_emitter_level(4); -#else - opts.set_xla_gpu_mlir_emitter_level(0); -#endif +// #if GOOGLE_CUDA +// opts.set_xla_gpu_mlir_emitter_level(4); +// #else +// opts.set_xla_gpu_mlir_emitter_level(0); +// #endif opts.set_xla_gpu_multi_streamed_windowed_einsum(false); @@ -1798,12 +1798,12 @@ void MakeDebugOptionsFlags(std::vector* flag_list, "Specify the maximum number of channels(SMs) NCCL will use " "for p2p operations. Default is 0 which is to let " "NCCL decide.")); - flag_list->push_back( - tsl::Flag("xla_gpu_mlir_emitter_level", - int64_setter_for(&DebugOptions::set_xla_gpu_mlir_emitter_level), - debug_options->xla_gpu_mlir_emitter_level(), - "Enable new MLIR-based emitters. Level 0 means disabled, " - "higher levels enable more of the emitters.")); +// flag_list->push_back( +// tsl::Flag("xla_gpu_mlir_emitter_level", +// int64_setter_for(&DebugOptions::set_xla_gpu_mlir_emitter_level), +// debug_options->xla_gpu_mlir_emitter_level(), +// "Enable new MLIR-based emitters. 
Level 0 means disabled, " +// "higher levels enable more of the emitters.")); flag_list->push_back(tsl::Flag( "xla_gpu_multi_streamed_windowed_einsum", bool_setter_for( diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index a7acaf2c5146ab..52e14392648f54 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -191,8 +191,10 @@ xla_cc_test( srcs = ["gpu_copy_insertion_test.cc"], deps = [ ":buffer_sharing", + ":gpu_device_info_for_tests", "//xla:test", "//xla:test_helpers", + "//xla/hlo/analysis:hlo_dataflow_analysis", "//xla/hlo/ir:hlo", "//xla/service:copy_insertion", "//xla/tests:hlo_test_base", @@ -263,7 +265,7 @@ xla_cc_test( cc_library( name = "gpu_device_info_for_tests", - testonly = 1, + testonly = 0, srcs = ["gpu_device_info_for_tests.cc"], hdrs = ["gpu_device_info_for_tests.h"], compatible_with = get_compatible_with_portable(), @@ -696,25 +698,18 @@ cc_library( srcs = ["reduction_utils.cc"], hdrs = ["reduction_utils.h"], compatible_with = get_compatible_with_portable(), - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ ":ir_emission_utils", "//xla:shape_util", "//xla:util", "//xla/hlo/ir:hlo", - "//xla/service:hlo_module_config", - "//xla/stream_executor:semantic_version", + "//xla/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", - "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:span", "@local_tsl//tsl/platform:logging", - ] + if_cuda_is_configured([ - ":gpu_asm_opts_util", - "//xla/stream_executor/cuda:cuda_asm_compiler", - ]), + ], ) xla_cc_test( @@ -1345,6 +1340,7 @@ cc_library( "//xla/service/gpu/transforms:copy_fusion", "//xla/service/gpu/transforms:horizontal_loop_fusion", "//xla/service/gpu/transforms:sanitize_constant_names", + 
"//xla/stream_executor:device_description", ], ) @@ -2262,6 +2258,11 @@ xla_cc_test( ], ) +cc_library( + name = "stream_executor_util_kernel_stub", + srcs = ["stream_executor_util_kernel_stub.cc"], +) + gpu_kernel_library( name = "stream_executor_util_kernel", srcs = ["stream_executor_util_kernel.cu.cc"], @@ -2278,7 +2279,6 @@ cc_library( srcs = ["stream_executor_util.cc"], hdrs = ["stream_executor_util.h"], copts = tsl_copts(), - local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]), deps = [ ":cublas_cudnn", ":launch_dimensions", @@ -2311,9 +2311,10 @@ cc_library( "@local_tsl//tsl/platform:status", "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/protobuf:dnn_proto_cc", - ] + if_gpu_is_configured([ - ":stream_executor_util_kernel", - ]), + ] + if_gpu_is_configured( + if_false = [":stream_executor_util_kernel_stub"], + if_true = [":stream_executor_util_kernel"], + ), ) xla_cc_test( @@ -2520,6 +2521,10 @@ xla_cc_test( ":gpu_fusible", "//xla/hlo/ir:hlo", "//xla/service:hlo_parser", + "//xla/service:hlo_runner", + "//xla/service:instruction_fusion", + "//xla/service:platform_util", + "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/strings", diff --git a/third_party/xla/xla/service/gpu/autotuning/autotuner_util.h b/third_party/xla/xla/service/gpu/autotuning/autotuner_util.h index e70b252abb30a0..427381a1bccc86 100644 --- a/third_party/xla/xla/service/gpu/autotuning/autotuner_util.h +++ b/third_party/xla/xla/service/gpu/autotuning/autotuner_util.h @@ -141,14 +141,8 @@ class AutotuneConfig { debug_options.xla_gpu_experimental_autotune_cache_mode()) {} std::string GetModelStr() const { - if (auto deviceless_config = std::get_if(&config_)) { - return AutotuneCacheKey::DeviceDescriptionToCacheKey( - deviceless_config->device_description); - } - - const auto& device_config = std::get(config_); return AutotuneCacheKey::DeviceDescriptionToCacheKey( - 
device_config.stream_exec->GetDeviceDescription()); + GetDeviceDescription()); } se::StreamExecutor* GetExecutor() const { @@ -175,11 +169,14 @@ class AutotuneConfig { } const se::GpuComputeCapability& GetGpuComputeCapability() const { - if (auto c = std::get_if(&config_)) { - return c->stream_exec->GetDeviceDescription().gpu_compute_capability(); + return GetDeviceDescription().gpu_compute_capability(); + } + + const se::DeviceDescription& GetDeviceDescription() const { + if (auto* device_config = std::get_if(&config_)) { + return device_config->stream_exec->GetDeviceDescription(); } - return std::get(config_) - .device_description.gpu_compute_capability(); + return std::get(config_).device_description; } bool IsDeviceless() const { diff --git a/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc b/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc index 90437a5633f509..a9f45084fc2e39 100644 --- a/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc +++ b/third_party/xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc @@ -459,8 +459,7 @@ GpuConvAlgorithmPicker::AutotuneRuntimeArguments::FromInstruction( // Get canonical HLO. std::string canonical_hlo( - AutotuneCacheKey(config.GetExecutor()->GetDeviceDescription(), *instr) - .GetHlo()); + AutotuneCacheKey(config.GetDeviceDescription(), *instr).GetHlo()); TF_ASSIGN_OR_RETURN(GpuConvConfig gpu_conv_config, GetGpuConvConfig(instr)); diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc index 79524924584c97..e9b90dccb3af4d 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc @@ -380,7 +380,7 @@ absl::StatusOr> TritonGemmAutotuneExtractor( // If the priority fusion pass above skipped some instructions, turn them // into fusions. 
- FusionWrapper fusion_wrapper; + FusionWrapper fusion_wrapper(gpu_device_info); TF_RETURN_IF_ERROR(fusion_wrapper.Run(new_module.get()).status()); } return new_module; @@ -528,7 +528,7 @@ absl::Status DumpAutotunedFusion(const AutotuneConfig& autotune_config, TritonGemmConfig::FromProto(result.triton())); } const se::DeviceDescription& device_desc = - autotune_config.GetExecutor()->GetDeviceDescription(); + autotune_config.GetDeviceDescription(); TF_ASSIGN_OR_RETURN( std::unique_ptr module, util.ExtractModule([&](const DebugOptions& debug_opts) { @@ -693,12 +693,12 @@ GemmFusionAutotunerImpl::GenerateTritonConfigs(const HloDotInstruction& dot) { // a sufficient number of thread block programs to occupy all available cores. // Around 5 full waves completely avoid the need for split-K. // n_tiles = split_k * (M * N) / (block_m * block_n) - const int kCoreCount = - !config_.IsDeviceless() - ? config_.GetExecutor()->GetDeviceDescription().core_count() - : 100; // some sensible default + const int kCoreCount = config_.GetDeviceDescription().core_count(); + CHECK_GE(kCoreCount, 1); const int64_t kSufficientNumberOfTiles = kMaxWavesForSplitK * kCoreCount; const int64_t result_size = ShapeUtil::ElementsIn(dot.shape()); + const int64_t threads_per_warp = + config_.GetDeviceDescription().threads_per_warp(); // Triton configurations are adjusted and deduplicated. 
absl::flat_hash_set added; @@ -735,7 +735,7 @@ GemmFusionAutotunerImpl::GenerateTritonConfigs(const HloDotInstruction& dot) { 2 * std::max(kMinTileSize, kLdmatrixGranularity / minBitWidth)); int meta_elements = config.block_m * config.block_k / 16; config.num_warps = - std::min(config.num_warps, meta_elements / WarpSize()); + std::min(config.num_warps, meta_elements / threads_per_warp); } if (added.insert(config).second) { @@ -783,11 +783,11 @@ GemmFusionAutotunerImpl::CompileAll(AutotunerCompileUtil& compile_util, -> absl::StatusOr { std::unique_ptr executable; if (std::holds_alternative(config)) { - TF_ASSIGN_OR_RETURN( - executable, compile_util.Compile([&](const DebugOptions& opts) { + TF_ASSIGN_OR_RETURN(executable, + compile_util.Compile([&](const DebugOptions& opts) { return TritonGemmAutotuneExtractor( std::get(config), - config_.GetExecutor()->GetDeviceDescription(), fusion, opts, + config_.GetDeviceDescription(), fusion, opts, allow_filtering_kernels_spilling_registers); })); } else if (std::holds_alternative(config)) { @@ -802,7 +802,7 @@ GemmFusionAutotunerImpl::CompileAll(AutotunerCompileUtil& compile_util, TF_ASSIGN_OR_RETURN( executable, compile_util.Compile([&](const DebugOptions& opts) { return CublasGemmAutotuneExtractor( - config_, config_.GetExecutor()->GetDeviceDescription(), + config_, config_.GetDeviceDescription(), toolkit_version_, fusion, opts); })); } else { @@ -1005,6 +1005,8 @@ GemmFusionAutotunerImpl::GetExhaustiveTritonConfigs() const { bool tune_ctas = debug_options_.xla_gpu_enable_triton_hopper() && cc.IsAtLeastHopper(); + const int64_t threads_per_warp = + config_.GetDeviceDescription().threads_per_warp(); for (int num_stages : kNumStages) { // Volta doesn't support num_stages > 2. 
if (!cc.IsAtLeastAmpere() && num_stages > 2) { @@ -1017,7 +1019,7 @@ GemmFusionAutotunerImpl::GetExhaustiveTritonConfigs() const { const int tile_rhs = tile_k * tile_n; for (int num_warps : kNumWarps) { // Each thread should read at least one input element. - if (num_warps * WarpSize() > std::min(tile_lhs, tile_rhs)) { + if (num_warps * threads_per_warp > std::min(tile_lhs, tile_rhs)) { break; } for (int split_k : kSplitK) { diff --git a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc index d9bec3a09906a8..3f92bcf1b96d50 100644 --- a/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc +++ b/third_party/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc @@ -256,6 +256,8 @@ absl::StatusOr> GetPossibleMatmulAutotuneConfigs( auto ccc = deviceless_proto.mutable_cuda_compute_capability(); ccc->set_major(compute_capability.major); ccc->set_minor(compute_capability.minor); + deviceless_proto.set_core_count(100); + deviceless_proto.set_threads_per_warp(32); DevicelessConfig test_config{se::DeviceDescription{deviceless_proto}}; AutotuneConfig autotune_config{test_config, debug_options}; GemmFusionAutotunerImpl autotuner(autotune_config, toolkit_version, @@ -941,7 +943,9 @@ ENTRY wais { compute_capability, GetToolkitVersion(), debug_options)); for (const auto& config : configs) { int metadata_size = config.block_m * config.block_k / 16; - EXPECT_LE(config.num_warps * WarpSize(), metadata_size); + EXPECT_LE(config.num_warps * + WarpSize(backend().default_stream_executor()->GetDeviceDescription()), + metadata_size); EXPECT_GT(config.block_k, 16); // kMinTileSize } } diff --git a/third_party/xla/xla/service/gpu/buffer_sharing.cc b/third_party/xla/xla/service/gpu/buffer_sharing.cc index 0ffb8e3fe63de9..9307c8ddbf7bc6 100644 --- a/third_party/xla/xla/service/gpu/buffer_sharing.cc +++ b/third_party/xla/xla/service/gpu/buffer_sharing.cc @@ -42,7 +42,8 @@ 
namespace gpu { std::optional FusionCanShareBufferHint(const HloInstruction* user, const HloInstruction* operand, - const ShapeIndex& user_index) { + const ShapeIndex& user_index, + const se::DeviceDescription& device_description) { const HloFusionInstruction* fusion = DynCast(user); if (fusion == nullptr) { return std::nullopt; @@ -77,8 +78,6 @@ std::optional FusionCanShareBufferHint(const HloInstruction* user, // Allow multiple output users, if they end in reductions. // This only works for the reduction emitter, as it calculates the reduction // first, i.e. before processing other outputs (that may overwrite the input). - stream_executor::GpuDeviceInfoProto device_info; - stream_executor::DeviceDescription device_description(device_info); auto analysis = HloFusionAnalysis::Create(*user, device_description); bool is_reduction_emitter = analysis.GetEmitterFusionKind() == HloFusionAnalysis::EmitterFusionKind::kReduction; @@ -221,7 +220,8 @@ std::optional FusionCanShareBufferHint(const HloInstruction* user, std::optional CanShareBufferHint(const HloInstruction* user, const HloInstruction* operand, - const ShapeIndex& user_index) { + const ShapeIndex& user_index, + const se::DeviceDescription& device_description) { switch (user->opcode()) { case HloOpcode::kAllReduce: case HloOpcode::kCollectiveBroadcast: @@ -243,7 +243,7 @@ std::optional CanShareBufferHint(const HloInstruction* user, } return false; case HloOpcode::kFusion: - return FusionCanShareBufferHint(user, operand, user_index); + return FusionCanShareBufferHint(user, operand, user_index, device_description); default: return std::nullopt; } diff --git a/third_party/xla/xla/service/gpu/buffer_sharing.h b/third_party/xla/xla/service/gpu/buffer_sharing.h index 7fdf4af78c11c7..4beb8db5b08a1f 100644 --- a/third_party/xla/xla/service/gpu/buffer_sharing.h +++ b/third_party/xla/xla/service/gpu/buffer_sharing.h @@ -20,16 +20,19 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_instruction.h" #include "xla/shape_util.h" +#include "xla/stream_executor/device_description.h" namespace xla { namespace gpu { std::optional FusionCanShareBufferHint(const HloInstruction* user, const HloInstruction* operand, - const ShapeIndex& user_index); + const ShapeIndex& user_index, + const se::DeviceDescription& device_description); std::optional CanShareBufferHint(const HloInstruction* user, const HloInstruction* operand, - const ShapeIndex& user_index); + const ShapeIndex& user_index, + const se::DeviceDescription& device_description); } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusion_pipeline.cc b/third_party/xla/xla/service/gpu/fusion_pipeline.cc index e27865c06c63d1..6100774e3766aa 100644 --- a/third_party/xla/xla/service/gpu/fusion_pipeline.cc +++ b/third_party/xla/xla/service/gpu/fusion_pipeline.cc @@ -89,7 +89,7 @@ HloPassPipeline FusionPipeline( HloPassPipeline HorizontalFusionPipeline( const se::DeviceDescription& gpu_device_info) { HloPassFix horizontal_fusion("horizontal fusion"); - horizontal_fusion.AddPass(); + horizontal_fusion.AddPass(gpu_device_info); horizontal_fusion.AddPass(gpu_device_info); horizontal_fusion.AddPass(/*is_layout_sensitive=*/true, /*only_fusion_computations=*/true); diff --git a/third_party/xla/xla/service/gpu/fusions/BUILD b/third_party/xla/xla/service/gpu/fusions/BUILD index 290c451dfffb8b..d82b30a55ef4aa 100644 --- a/third_party/xla/xla/service/gpu/fusions/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/BUILD @@ -222,17 +222,17 @@ cc_library( "//xla/service/gpu:hlo_fusion_analysis", "//xla/service/gpu:hlo_traversal", "//xla/service/gpu:ir_emission_utils", - "//xla/service/gpu/fusions/legacy:concatenate", - "//xla/service/gpu/fusions/legacy:in_place_dynamic_update_slice", - "//xla/service/gpu/fusions/legacy:input_slices", - "//xla/service/gpu/fusions/legacy:loop", - "//xla/service/gpu/fusions/legacy:reduction", - 
"//xla/service/gpu/fusions/legacy:scatter", - "//xla/service/gpu/fusions/legacy:transpose", - "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", + # "//xla/service/gpu/fusions/legacy:concatenate", + # "//xla/service/gpu/fusions/legacy:in_place_dynamic_update_slice", + # "//xla/service/gpu/fusions/legacy:input_slices", + # "//xla/service/gpu/fusions/legacy:loop", + # "//xla/service/gpu/fusions/legacy:reduction", + # "//xla/service/gpu/fusions/legacy:scatter", + # "//xla/service/gpu/fusions/legacy:transpose", + # "//xla/service/gpu/fusions/mlir:elemental_hlo_to_mlir", "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", + # "@com_google_absl//absl/log", + # "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", ], diff --git a/third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc b/third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc index 432d600701d1ab..6babb17a57f6ad 100644 --- a/third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc +++ b/third_party/xla/xla/service/gpu/fusions/fusion_emitter.cc @@ -91,11 +91,15 @@ absl::Status AnnotateKernelLaunchDimensions( const se::DeviceDescription& device_info, const LaunchDimensions& launch_dims, const std::string& kernel_name, llvm::Module* llvm_module) { - TF_RET_CHECK(device_info.block_dim_limit().x == 0 || - launch_dims.block_counts().x < device_info.block_dim_limit().x) + TF_RET_CHECK( + (device_info.block_dim_limit().x == 0 || + launch_dims.block_counts().x < device_info.block_dim_limit().x) && + (device_info.block_dim_limit().y == 0 || + launch_dims.block_counts().y < device_info.block_dim_limit().y)) << "Kernel '" << kernel_name << "' launch needs more blocks (" - << launch_dims.block_counts().x << ") than allowed by hardware (" - << device_info.block_dim_limit().x << ")."; + << launch_dims.block_counts().x << ", " << launch_dims.block_counts().y + << ") than allowed by hardware 
(" << device_info.block_dim_limit().x + << ", " << device_info.block_dim_limit().y << ")."; // Add __launch_bounds__ to metadata. This limits registers per thread to // avoid out-of-resources launching errors. diff --git a/third_party/xla/xla/service/gpu/fusions/fusions.cc b/third_party/xla/xla/service/gpu/fusions/fusions.cc index 200f06f8461db5..2414c7e3619cac 100644 --- a/third_party/xla/xla/service/gpu/fusions/fusions.cc +++ b/third_party/xla/xla/service/gpu/fusions/fusions.cc @@ -34,13 +34,13 @@ limitations under the License. #include "xla/service/gpu/fusions/fusion_emitter.h" #include "xla/service/gpu/fusions/in_place_dynamic_update_slice_mlir.h" #include "xla/service/gpu/fusions/input_slices_mlir.h" -#include "xla/service/gpu/fusions/legacy/concatenate.h" -#include "xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice.h" -#include "xla/service/gpu/fusions/legacy/input_slices.h" -#include "xla/service/gpu/fusions/legacy/loop.h" -#include "xla/service/gpu/fusions/legacy/reduction.h" -#include "xla/service/gpu/fusions/legacy/scatter.h" -#include "xla/service/gpu/fusions/legacy/transpose.h" +// #include "xla/service/gpu/fusions/legacy/concatenate.h" +// #include "xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice.h" +// #include "xla/service/gpu/fusions/legacy/input_slices.h" +// #include "xla/service/gpu/fusions/legacy/loop.h" +// #include "xla/service/gpu/fusions/legacy/reduction.h" +// #include "xla/service/gpu/fusions/legacy/scatter.h" +// #include "xla/service/gpu/fusions/legacy/transpose.h" #include "xla/service/gpu/fusions/loop_mlir.h" #include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" #include "xla/service/gpu/fusions/reduction_mlir.h" @@ -108,14 +108,14 @@ std::unique_ptr GetFusionEmitter( const auto& analysis = fusion_info.analysis(); const FusionBackendConfig& backend_config = analysis.fusion_backend_config(); - const auto& opts = analysis.fusion_root(0) - .instruction() - .GetModule() - ->config() - .debug_options(); - 
auto check_mlir_emitters = [&](int64_t required_level) { - return opts.xla_gpu_mlir_emitter_level() >= required_level; - }; + // const auto& opts = analysis.fusion_root(0) + // .instruction() + // .GetModule() + // ->config() + // .debug_options(); + // auto check_mlir_emitters = [&](int64_t required_level) { + // return opts.xla_gpu_mlir_emitter_level() >= required_level; + // }; switch (analysis.GetEmitterFusionKind()) { case HloFusionAnalysis::EmitterFusionKind::kCustomFusion: { @@ -126,52 +126,52 @@ std::unique_ptr GetFusionEmitter( return std::make_unique(); } case HloFusionAnalysis::EmitterFusionKind::kInputSlices: - if (check_mlir_emitters(/*required_level=*/2)) { + // if (check_mlir_emitters(/*required_level=*/2)) { return std::make_unique(analysis); - } - return std::make_unique(analysis); + // } + // return std::make_unique(analysis); case HloFusionAnalysis::EmitterFusionKind::kLoop: { if (IsDynamicUpdateSliceFusion(analysis) && fusion_info.CanEmitDynamicUpdateSliceInPlace()) { - if (check_mlir_emitters(/*required_level=*/2)) { + // if (check_mlir_emitters(/*required_level=*/2)) { return std::make_unique( analysis); - } - return std::make_unique(analysis); + // } + // return std::make_unique(analysis); } if (auto copy_fusion = fusion_info.GetCopyFusion()) { return *std::move(copy_fusion); } - if (check_mlir_emitters(/*required_level=*/1)) { + // if (check_mlir_emitters(/*required_level=*/1)) { return std::make_unique(analysis); - } - return std::make_unique(analysis); + // } + // return std::make_unique(analysis); } case HloFusionAnalysis::EmitterFusionKind::kReduction: - if (check_mlir_emitters(/*required_level=*/4)) { + //if (check_mlir_emitters(/*required_level=*/4)) { return CreateMlirReductionFusion(analysis); - } - return std::make_unique(analysis); + //} + // return std::make_unique(analysis); case HloFusionAnalysis::EmitterFusionKind::kScatter: { - if (check_mlir_emitters(/*required_level=*/2)) { + //if (check_mlir_emitters(/*required_level=*/2)) 
{ return std::make_unique(analysis); - } - return std::make_unique(analysis); + //} + //return std::make_unique(analysis); } case HloFusionAnalysis::EmitterFusionKind::kTranspose: { - if (check_mlir_emitters(/*required_level=*/3)) { + // if (check_mlir_emitters(/*required_level=*/3)) { return std::make_unique(analysis); - } - return std::make_unique(analysis.device_info(), - analysis); + // } + // return std::make_unique(analysis.device_info(), + // analysis); } case HloFusionAnalysis::EmitterFusionKind::kConcatenate: { - if (check_mlir_emitters(/*required_level=*/2)) { + //if (check_mlir_emitters(/*required_level=*/2)) { return std::make_unique(analysis); - } - return std::make_unique(analysis); + //} + // return std::make_unique(analysis); } case HloFusionAnalysis::EmitterFusionKind::kTriton: return std::make_unique(analysis); diff --git a/third_party/xla/xla/service/gpu/fusions/ir/tests/ops.mlir b/third_party/xla/xla/service/gpu/fusions/ir/tests/ops.mlir index 572202bf148ce2..3b1174df5ec31d 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/tests/ops.mlir +++ b/third_party/xla/xla/service/gpu/fusions/ir/tests/ops.mlir @@ -231,38 +231,6 @@ func.func @reduce_middle_dim(%in: tensor<16x8x4xf32>, %init: f32) // CHECK-SAME: combiner=@add // CHECK-SAME: : tensor<16x8x4xf32> to tensor<16x4xf32> -// ----- - -#map = #xla_gpu.indexing_map<(d0, d1) -> (d0 * 64 + d1), domain: d0 in [0, 15], d1 in [0, 63], is_simplified: false> -func.func @reindex(%in0: tensor<1024xf32>) -> tensor<16x64xf32> { - %0 = xla_gpu.reindex %in0 at #map : tensor<1024xf32> -> tensor<16x64xf32> - func.return %0 : tensor<16x64xf32> -} - -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d0 * 64 + d1) -// CHECK-LABEL: func.func @reindex( -// CHECK-SAME: %[[IN1:.*]]: tensor<1024xf32> -// CHECK: xla_gpu.reindex %[[IN1]] at #[[$MAP]] : -// CHECK-SAME: tensor<1024xf32> -> tensor<16x64xf32> - -// ----- - -#map = #xla_gpu.indexing_map<(d0, d1) -> (d0 * 64 + d1), domain: d0 in [0, 15], d1 in [0, 
63], is_simplified: false> -func.func @reindex_pad(%in0: tensor<1022xf32>) -> tensor<16x64xf32> { - %c0 = arith.constant 0.0 : f32 - %0 = xla_gpu.reindex %in0 at #map default %c0 - : tensor<1022xf32> -> tensor<16x64xf32> - func.return %0 : tensor<16x64xf32> -} - -// CHECK: #[[$MAP:.*]] = #xla_gpu.indexing_map<(d0, d1) -> (d0 * 64 + d1) -// CHECK-LABEL: func.func @reindex_pad( -// CHECK-SAME: %[[IN1:.*]]: tensor<1022xf32> -// CHECK: %[[C0:.*]] = arith.constant 0.00 -// CHECK: xla_gpu.reindex %[[IN1]] at #[[$MAP]] default %[[C0]] : -// CHECK-SAME: tensor<1022xf32> -> tensor<16x64xf32> - - // ----- func.func @do_nothing(%a: f32, %b: i32, %c: f32, %d: i32) -> (f32, i32) { diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.td b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.td index 44e8dd4353a5b6..af865f897979c8 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.td +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_attrs.td @@ -68,6 +68,7 @@ def XLAGPU_IndexingMapAttr : XLAGPU_Attr<"IndexingMap"> { }]; } +/* def XLAGPU_LaunchGridAttr : XLAGPU_Attr<"LaunchGrid"> { let summary = "An attribute representing a launch grid."; let mnemonic = "launch_grid"; @@ -81,6 +82,7 @@ def XLAGPU_LaunchGridAttr : XLAGPU_Attr<"LaunchGrid"> { `>` }]; } +*/ //===----------------------------------------------------------------------===// // Tensor layout attribute diff --git a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc index 967df31ba84397..9907b269792719 100644 --- a/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc +++ b/third_party/xla/xla/service/gpu/fusions/ir/xla_gpu_ops.cc @@ -1057,17 +1057,17 @@ LogicalResult InsertOp::verify() { return success(); } -//===----------------------------------------------------------------------===// -// ReindexOp -//===----------------------------------------------------------------------===// - -void 
ReindexOp::build(mlir::OpBuilder& builder, mlir::OperationState& result, - mlir::Type type, mlir::Value operand, mlir::Value padding, - const xla::gpu::IndexingMap& indexing_map) { - IndexingMapAttr indexing_map_attr = - IndexingMapAttr::get(builder.getContext(), indexing_map); - build(builder, result, type, operand, padding, indexing_map_attr); -} +// //===----------------------------------------------------------------------===// +// // ReindexOp +// //===----------------------------------------------------------------------===// + +// void ReindexOp::build(mlir::OpBuilder& builder, mlir::OperationState& result, +// mlir::Type type, mlir::Value operand, mlir::Value padding, +// const xla::gpu::IndexingMap& indexing_map) { +// IndexingMapAttr indexing_map_attr = +// IndexingMapAttr::get(builder.getContext(), indexing_map); +// build(builder, result, type, operand, padding, indexing_map_attr); +// } //===----------------------------------------------------------------------===// // ReduceOp diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/BUILD b/third_party/xla/xla/service/gpu/fusions/legacy/BUILD deleted file mode 100644 index 98d8ade7c5e5c3..00000000000000 --- a/third_party/xla/xla/service/gpu/fusions/legacy/BUILD +++ /dev/null @@ -1,406 +0,0 @@ -load("//xla:xla.bzl", "xla_cc_test") - -package( - # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = ["//xla/service/gpu/fusions:__pkg__"], - licenses = ["notice"], -) - -cc_library( - name = "in_place_dynamic_update_slice", - srcs = ["in_place_dynamic_update_slice.cc"], - hdrs = ["in_place_dynamic_update_slice.h"], - deps = [ - "//xla/hlo/ir:hlo", - "//xla/service/gpu:hlo_fusion_analysis", - "//xla/service/gpu:ir_emission_utils", - "//xla/service/gpu:ir_emitter", - "//xla/service/gpu:ir_emitter_context", - "//xla/service/gpu:launch_dimensions", - "//xla/service/gpu/fusions:fusion_emitter", - "//xla/service/gpu/model:indexing_analysis", - 
"//xla/service/llvm_ir:dynamic_update_slice_util", - "//xla/service/llvm_ir:fused_ir_emitter", - "//xla/service/llvm_ir:ir_array", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@llvm-project//llvm:Support", - "@llvm-project//llvm:ir_headers", - "@llvm-project//mlir:IR", - ], -) - -xla_cc_test( - name = "in_place_dynamic_update_slice_test", - srcs = ["in_place_dynamic_update_slice_test.cc"], - deps = [ - ":in_place_dynamic_update_slice", - "//xla/service/gpu:gpu_device_info_for_tests", - "//xla/service/gpu:hlo_fusion_analysis", - "//xla/service/gpu/fusions", - "//xla/service/gpu/model:affine_map_printer", - "//xla/service/gpu/model:indexing_test_utils", - "//xla/stream_executor:device_description", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_googletest//:gtest", - "@llvm-project//mlir:IR", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "loop", - srcs = ["loop.cc"], - hdrs = ["loop.h"], - deps = [ - "//xla:shape_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/service/gpu:gpu_fusible", - "//xla/service/gpu:hlo_fusion_analysis", - "//xla/service/gpu:hlo_traversal", - "//xla/service/gpu:ir_emission_utils", - "//xla/service/gpu:ir_emitter", - "//xla/service/gpu:ir_emitter_context", - "//xla/service/gpu:launch_dimensions", - "//xla/service/gpu:parallel_loop_emitter", - "//xla/service/gpu/fusions:fusion_emitter", - "//xla/service/gpu/model:indexing_analysis", - "//xla/service/llvm_ir:fused_ir_emitter", - "//xla/service/llvm_ir:ir_array", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/numeric:bits", - "@com_google_absl//absl/status", - "@llvm-project//llvm:ir_headers", - "@llvm-project//mlir:IR", - "@local_tsl//tsl/platform:macros", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "loop_test", - srcs = ["loop_test.cc"], - deps = [ - "//xla:status_macros", - 
"//xla/service/gpu:gpu_device_info_for_tests", - "//xla/service/gpu:hlo_fusion_analysis", - "//xla/service/gpu/fusions", - "//xla/service/gpu/fusions:fusion_emitter", - "//xla/service/gpu/model:affine_map_printer", - "//xla/service/gpu/model:indexing_test_utils", - "//xla/stream_executor:device_description", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/status:statusor", - "@com_google_googletest//:gtest", - "@llvm-project//mlir:IR", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "scatter", - srcs = ["scatter.cc"], - hdrs = ["scatter.h"], - deps = [ - ":loop", - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service/gpu:gpu_fusible", - "//xla/service/gpu:hlo_fusion_analysis", - "//xla/service/gpu:ir_emission_utils", - "//xla/service/gpu:ir_emitter", - "//xla/service/gpu:ir_emitter_context", - "//xla/service/gpu:launch_dimensions", - "//xla/service/gpu:parallel_loop_emitter", - "//xla/service/gpu/fusions:fusion_emitter", - "//xla/service/gpu/model:indexing_analysis", - "//xla/service/llvm_ir:fused_ir_emitter", - "//xla/service/llvm_ir:ir_array", - "//xla/service/llvm_ir:llvm_util", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/types:span", - "@llvm-project//llvm:Support", - "@llvm-project//llvm:ir_headers", - "@llvm-project//mlir:IR", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "scatter_test", - srcs = ["scatter_test.cc"], - deps = [ - ":scatter", - "//xla/service/gpu:gpu_device_info_for_tests", - "//xla/service/gpu:hlo_fusion_analysis", - "//xla/service/gpu/fusions", - "//xla/service/gpu/model:affine_map_printer", - "//xla/service/gpu/model:indexing_test_utils", - "//xla/stream_executor:device_description", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_googletest//:gtest", - "@llvm-project//mlir:IR", - "@local_tsl//tsl/platform:statusor", - 
], -) - -cc_library( - name = "tiling_util", - srcs = ["tiling_util.cc"], - hdrs = ["tiling_util.h"], - visibility = ["//xla/service/gpu:__subpackages__"], - deps = [ - "//xla:shape_util", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/service/gpu:ir_emission_utils", - "//xla/service/gpu:target_util", - "//xla/service/gpu/model:indexing_analysis", - "//xla/service/llvm_ir:ir_array", - "//xla/service/llvm_ir:kernel_support_library", - "//xla/service/llvm_ir:llvm_loop", - "//xla/service/llvm_ir:llvm_util", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@com_google_absl//absl/types:span", - "@llvm-project//llvm:Support", - "@llvm-project//llvm:ir_headers", - "@llvm-project//mlir:IR", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "reduction", - srcs = ["reduction.cc"], - hdrs = ["reduction.h"], - deps = [ - ":tiling_util", - "//xla:shape_util", - "//xla:status_macros", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:buffer_assignment", - "//xla/service/gpu:hlo_fusion_analysis", - "//xla/service/gpu:hlo_traversal", - "//xla/service/gpu:ir_emission_utils", - "//xla/service/gpu:ir_emitter", - "//xla/service/gpu:ir_emitter_context", - "//xla/service/gpu:kernel_arguments", - "//xla/service/gpu:kernel_reuse_cache", - "//xla/service/gpu:launch_dimensions", - "//xla/service/gpu:parallel_loop_emitter", - "//xla/service/gpu:reduction_utils", - "//xla/service/gpu:target_util", - "//xla/service/gpu/fusions:fusion_emitter", - "//xla/service/gpu/fusions:reduction_base", - "//xla/service/gpu/fusions:thunk_util", - "//xla/service/gpu/runtime:kernel_thunk", - "//xla/service/gpu/runtime:thunk", - "//xla/service/llvm_ir:fused_ir_emitter", - "//xla/service/llvm_ir:ir_array", - "//xla/service/llvm_ir:kernel_support_library", - "//xla/service/llvm_ir:llvm_loop", - 
"//xla/service/llvm_ir:llvm_util", - "//xla/service/llvm_ir:loop_emitter", - "//xla/stream_executor:device_description", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/container:node_hash_map", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@llvm-project//llvm:Support", - "@llvm-project//llvm:ir_headers", - "@llvm-project//mlir:Support", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:status", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "reduction_test", - srcs = ["reduction_test.cc"], - deps = [ - ":reduction", - "//xla/hlo/ir:hlo", - "//xla/service/gpu:gpu_device_info_for_tests", - "//xla/service/gpu:hlo_fusion_analysis", - "//xla/service/gpu:ir_emitter_context", - "//xla/service/gpu/fusions:fusion_emitter", - "//xla/service/gpu/model:indexing_analysis", - "//xla/service/gpu/model:indexing_test_utils", - "//xla/stream_executor:device_description", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_googletest//:gtest", - "@llvm-project//mlir:IR", - ], -) - -cc_library( - name = "concatenate", - srcs = ["concatenate.cc"], - hdrs = ["concatenate.h"], - deps = [ - "//xla:shape_util", - "//xla/hlo/ir:hlo", - "//xla/service/gpu:hlo_fusion_analysis", - "//xla/service/gpu:ir_emission_utils", - "//xla/service/gpu:ir_emitter", - "//xla/service/gpu:ir_emitter_context", - "//xla/service/gpu:launch_dimensions", - "//xla/service/gpu:parallel_loop_emitter", - "//xla/service/gpu/fusions:fusion_emitter", - "//xla/service/gpu/model:indexing_analysis", - "//xla/service/llvm_ir:fused_ir_emitter", - "//xla/service/llvm_ir:ir_array", - 
"//xla/service/llvm_ir:loop_emitter", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/status", - "@llvm-project//llvm:Support", - "@llvm-project//llvm:ir_headers", - "@llvm-project//mlir:IR", - "@local_tsl//tsl/platform:errors", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "concatenate_test", - srcs = ["concatenate_test.cc"], - deps = [ - ":concatenate", - "//xla/service/gpu:gpu_device_info_for_tests", - "//xla/service/gpu:hlo_fusion_analysis", - "//xla/service/gpu/fusions", - "//xla/service/gpu/model:affine_map_printer", - "//xla/service/gpu/model:indexing_test_utils", - "//xla/stream_executor:device_description", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_googletest//:gtest", - "@llvm-project//mlir:IR", - ], -) - -cc_library( - name = "transpose", - srcs = ["transpose.cc"], - hdrs = ["transpose.h"], - deps = [ - ":tiling_util", - "//xla:permutation_util", - "//xla:shape_util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service/gpu:hlo_fusion_analysis", - "//xla/service/gpu:ir_emission_utils", - "//xla/service/gpu:ir_emitter", - "//xla/service/gpu:ir_emitter_context", - "//xla/service/gpu:launch_dimensions", - "//xla/service/gpu:target_util", - "//xla/service/gpu/fusions:fusion_emitter", - "//xla/service/gpu/model:indexing_analysis", - "//xla/service/llvm_ir:fused_ir_emitter", - "//xla/service/llvm_ir:ir_array", - "//xla/service/llvm_ir:llvm_util", - "//xla/service/llvm_ir:loop_emitter", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:inlined_vector", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@llvm-project//llvm:Support", - "@llvm-project//llvm:ir_headers", - "@llvm-project//mlir:IR", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "transpose_test", - srcs = ["transpose_test.cc"], - 
deps = [ - ":transpose", - "//xla:status_macros", - "//xla/service/gpu:gpu_device_info_for_tests", - "//xla/service/gpu:hlo_fusion_analysis", - "//xla/service/gpu/fusions", - "//xla/service/gpu/model:indexing_test_utils", - "//xla/stream_executor:device_description", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_absl//absl/status:statusor", - "@com_google_googletest//:gtest", - "@llvm-project//mlir:IR", - "@local_tsl//tsl/platform:statusor", - ], -) - -cc_library( - name = "input_slices", - srcs = ["input_slices.cc"], - hdrs = ["input_slices.h"], - deps = [ - "//xla:shape_util", - "//xla:util", - "//xla/hlo/ir:hlo", - "//xla/service:elemental_ir_emitter", - "//xla/service/gpu:hlo_fusion_analysis", - "//xla/service/gpu:ir_emission_utils", - "//xla/service/gpu:ir_emitter", - "//xla/service/gpu:ir_emitter_context", - "//xla/service/gpu:launch_dimensions", - "//xla/service/gpu:parallel_loop_emitter", - "//xla/service/gpu/fusions:fusion_emitter", - "//xla/service/gpu/model:indexing_analysis", - "//xla/service/llvm_ir:fused_ir_emitter", - "//xla/service/llvm_ir:ir_array", - "//xla/service/llvm_ir:kernel_support_library", - "//xla/service/llvm_ir:llvm_loop", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:span", - "@llvm-project//llvm:ir_headers", - "@llvm-project//mlir:IR", - "@local_tsl//tsl/platform:statusor", - ], -) - -xla_cc_test( - name = "input_slices_test", - srcs = ["input_slices_test.cc"], - deps = [ - ":input_slices", - "//xla/service/gpu:gpu_device_info_for_tests", - "//xla/service/gpu:hlo_fusion_analysis", - "//xla/service/gpu/fusions", - "//xla/service/gpu/model:affine_map_printer", - "//xla/service/gpu/model:indexing_test_utils", - "//xla/stream_executor:device_description", - "//xla/tests:hlo_test_base", - "//xla/tests:xla_internal_test_main", - "@com_google_googletest//:gtest", - 
"@llvm-project//mlir:IR", - ], -) diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/README.md b/third_party/xla/xla/service/gpu/fusions/legacy/README.md deleted file mode 100644 index 0fa6bb98f73147..00000000000000 --- a/third_party/xla/xla/service/gpu/fusions/legacy/README.md +++ /dev/null @@ -1,8 +0,0 @@ -# Deprecated emitters - -The emitters in this directory are deprecated. Please do not add any new -features. If you believe you need to add a feature, please reach out and -describe your use case. - -These emitters have more modern MLIR-based equivalents in the directory above -this one. \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/concatenate.cc b/third_party/xla/xla/service/gpu/fusions/legacy/concatenate.cc deleted file mode 100644 index 8bb0e04cc7f337..00000000000000 --- a/third_party/xla/xla/service/gpu/fusions/legacy/concatenate.cc +++ /dev/null @@ -1,137 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include "xla/service/gpu/fusions/legacy/concatenate.h" - -#include -#include -#include - -#include "absl/algorithm/container.h" -#include "absl/status/status.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/IR/Constants.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/Value.h" -#include "mlir/IR/MLIRContext.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/service/gpu/elemental_ir_emitter.h" -#include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/gpu/ir_emission_utils.h" -#include "xla/service/gpu/ir_emitter_context.h" -#include "xla/service/gpu/launch_dimensions.h" -#include "xla/service/gpu/model/indexing_map.h" -#include "xla/service/gpu/parallel_loop_emitter.h" -#include "xla/service/llvm_ir/fused_ir_emitter.h" -#include "xla/service/llvm_ir/ir_array.h" -#include "xla/service/llvm_ir/loop_emitter.h" -#include "xla/shape.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/statusor.h" - -namespace xla { -namespace gpu { - -const Shape& GetLargestConcatOperandShape(const HloFusionAnalysis& analysis) { - const HloInstruction& concat = analysis.fusion_hero(0).instruction(); - int64_t dim = concat.concatenate_dimension(); - auto less = [&](const HloInstruction* lhs, const HloInstruction* rhs) { - return lhs->shape().dimensions(dim) < rhs->shape().dimensions(dim); - }; - HloInstruction* operand = *absl::c_max_element(concat.operands(), less); - return operand->shape(); -} - -ConcatenateFusion::ConcatenateFusion(const HloFusionAnalysis& analysis) - : analysis_(analysis) {} - -std::optional ConcatenateFusion::ComputeThreadIdToOutputIndexing( - int64_t root_index, mlir::MLIRContext* ctx) const { - return std::nullopt; -} - -std::optional ConcatenateFusion::ComputeThreadIdToInputIndexing( - int64_t root_index, int64_t hero_operand_index, - mlir::MLIRContext* ctx) const { - return 
GetDefaultThreadIdIndexingMap(launch_dimensions(), /*unroll_factor=*/1, - GetLargestConcatOperandShape(analysis_), - ctx); -} - -absl::Status ConcatenateFusion::EmitKernel( - IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion, - const LaunchDimensions& launch_dims, std::vector inputs, - std::vector outputs, llvm::IRBuilder<>* builder) const { - GpuElementalIrEmitter elemental_emitter(ir_emitter_context, builder); - FusedIrEmitter fused_emitter(elemental_emitter); - for (int i = 0; i < fusion.fused_parameters().size(); i++) { - fused_emitter.BindGenerator( - *fusion.fused_parameter(i), [&, i](llvm_ir::IrArray::Index index) { - return inputs[i].EmitReadArrayElement(index, builder); - }); - } - - llvm::Type* index_type = - GetIndexTypeForKernel(&fusion, launch_dims.launch_bound(), builder); - - const HloInstruction& concat = analysis_.fusion_hero(0).instruction(); - int64_t concat_dim = concat.concatenate_dimension(); - int64_t operand_offset = 0; - - // Emit the slices that correspond to the operands of the concat hero. - for (const HloInstruction* operand : concat.operands()) { - llvm_ir::BodyEmitter body_emitter = - [&](const llvm_ir::IrArray::Index& operand_index) -> absl::Status { - // Bind concat to generate the current operand. - TF_ASSIGN_OR_RETURN(auto operand_generator, - fused_emitter.GetGenerator(*operand)); - fused_emitter.BindGenerator(concat, [&](llvm_ir::IrArray::Index) { - return operand_generator(operand_index); - }); - - // Create the index of the slice corresponding to the current operand. - llvm_ir::IrArray::Index result_index = operand_index.AddOffsetToDim( - llvm::ConstantInt::get(index_type, operand_offset), concat_dim, - builder); - operand_offset += operand->shape().dimensions(concat_dim); - - // Generate and write out the slice for each root. 
- for (const auto& [output, root] : - llvm::zip_equal(outputs, analysis_.fusion_roots())) { - llvm_ir::IrArray::Index root_index = result_index.SourceIndexOfBitcast( - concat.shape(), root.shape(), builder); - TF_ASSIGN_OR_RETURN(auto generator, - fused_emitter.GetGenerator(root.instruction())); - TF_ASSIGN_OR_RETURN(llvm::Value * value, generator(root_index)); - output.EmitWriteArrayElement(root_index, value, builder); - } - return absl::OkStatus(); - }; - - ParallelLoopEmitter emitter(body_emitter, operand->shape(), launch_dims, - builder); - TF_RETURN_IF_ERROR(emitter.EmitLoop(fusion.name(), index_type)); - } - - return absl::OkStatus(); -} - -LaunchDimensions ConcatenateFusion::launch_dimensions() const { - return CalculateLaunchDimensions(GetLargestConcatOperandShape(analysis_), - analysis_.device_info()); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/concatenate.h b/third_party/xla/xla/service/gpu/fusions/legacy/concatenate.h deleted file mode 100644 index be0465b421e916..00000000000000 --- a/third_party/xla/xla/service/gpu/fusions/legacy/concatenate.h +++ /dev/null @@ -1,67 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#ifndef XLA_SERVICE_GPU_FUSIONS_LEGACY_CONCATENATE_H_ -#define XLA_SERVICE_GPU_FUSIONS_LEGACY_CONCATENATE_H_ - -#include -#include - -#include "absl/status/status.h" -#include "llvm/IR/IRBuilder.h" -#include "mlir/IR/MLIRContext.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/service/gpu/fusions/fusion_emitter.h" -#include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/gpu/ir_emitter_context.h" -#include "xla/service/gpu/launch_dimensions.h" -#include "xla/service/gpu/model/indexing_map.h" -#include "xla/service/llvm_ir/ir_array.h" -#include "xla/shape.h" - -namespace xla { -namespace gpu { - -const Shape& GetLargestConcatOperandShape(const HloFusionAnalysis& analysis); - -// Emits a kernel for the given hlo instruction where each thread produces -// one element of each concat operand. -class ConcatenateFusion : public KernelFusionEmitterBase { - public: - explicit ConcatenateFusion(const HloFusionAnalysis& analysis); - LaunchDimensions launch_dimensions() const override; - - std::optional ComputeThreadIdToOutputIndexing( - int64_t root_index, mlir::MLIRContext* ctx) const override; - - std::optional ComputeThreadIdToInputIndexing( - int64_t root_index, int64_t hero_operand_index, - mlir::MLIRContext* ctx) const override; - - protected: - absl::Status EmitKernel(IrEmitterContext& ir_emitter_context, - const HloFusionInstruction& fusion, - const LaunchDimensions& launch_dims, - std::vector inputs, - std::vector outputs, - llvm::IRBuilder<>* builder) const override; - - private: - const HloFusionAnalysis& analysis_; -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_FUSIONS_LEGACY_CONCATENATE_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/concatenate_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/concatenate_test.cc deleted file mode 100644 index 9a9bdc2dd488b2..00000000000000 --- 
a/third_party/xla/xla/service/gpu/fusions/legacy/concatenate_test.cc +++ /dev/null @@ -1,121 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "xla/service/gpu/fusions/legacy/concatenate.h" - -#include - -#include -#include -#include "mlir/IR/MLIRContext.h" -#include "xla/service/gpu/fusions/fusions.h" -#include "xla/service/gpu/gpu_device_info_for_tests.h" -#include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/gpu/model/affine_map_printer.h" -#include "xla/service/gpu/model/indexing_test_utils.h" -#include "xla/stream_executor/device_description.h" -#include "xla/tests/hlo_test_base.h" - -namespace xla { -namespace gpu { -namespace { - -class ConcatenateTest : public HloTestBase { - public: - void SetUp() override { - HloTestBase::SetUp(); - printer_ = - AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, - {"chunk_id", "unroll_id"}); - } - - protected: - DebugOptions GetDebugOptionsForTest() override { - auto opts = HloTestBase::GetDebugOptionsForTest(); - opts.set_xla_gpu_mlir_emitter_level(0); - return opts; - } - AffineMapPrinter printer_; - mlir::MLIRContext mlir_context_; -}; - -TEST_F(ConcatenateTest, ThreadIndexing) { - auto module = ParseAndReturnVerifiedModule(R"( - HloModule module - - fused_computation { - param0 = f32[200] parameter(0) - param1 = f32[400] parameter(1) - param2 = f32[300] parameter(2) - ROOT 
concat = f32[900] concatenate(param0, param1, param2), dimensions={0} - } - ENTRY main { - param0 = f32[200] parameter(0) - param1 = f32[400] parameter(1) - param2 = f32[300] parameter(2) - ROOT fusion = f32[900] fusion(param0, param1, param2), - calls=fused_computation, kind=kLoop - } - )") - .value(); - - stream_executor::DeviceDescription device_info = - TestGpuDeviceInfo::RTXA6000DeviceInfo(); - - auto* root = module->entry_computation()->root_instruction(); - auto analysis_fused = HloFusionAnalysis::Create(*root, device_info); - - auto emitter = - GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused}); - auto fusion = dynamic_cast(emitter.get()); - ASSERT_NE(fusion, nullptr); - - constexpr auto kIndexing = R"( - (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> - (bl_x * 128 + th_x), - domain: - th_x in [0, 127], - th_y in [0, 0], - th_z in [0, 0], - bl_x in [0, 3], - bl_y in [0, 0], - bl_z in [0, 0], - chunk_id in [0, 0], - unroll_id in [0, 0], - bl_x * 128 + th_x in [0, 399], - is_simplified: true - )"; - EXPECT_THAT( - fusion - ->ComputeThreadIdToInputIndexing( - /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_) - ->ToString(printer_), - MatchIndexingString(kIndexing)); - EXPECT_THAT( - fusion - ->ComputeThreadIdToInputIndexing( - /*root_index=*/0, /*hero_operand_index=*/1, &mlir_context_) - ->ToString(printer_), - MatchIndexingString(kIndexing)); - EXPECT_THAT( - fusion - ->ComputeThreadIdToInputIndexing( - /*root_index=*/0, /*hero_operand_index=*/2, &mlir_context_) - ->ToString(printer_), - MatchIndexingString(kIndexing)); -} - -} // namespace -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice.cc b/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice.cc deleted file mode 100644 index 38a3e5b68d12f6..00000000000000 --- a/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice.cc +++ /dev/null @@ -1,105 
+0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice.h" - -#include -#include -#include - -#include "absl/status/status.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/IR/IRBuilder.h" -#include "mlir/IR/MLIRContext.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/service/gpu/elemental_ir_emitter.h" -#include "xla/service/gpu/ir_emitter_context.h" -#include "xla/service/gpu/launch_dimensions.h" -#include "xla/service/gpu/model/indexing_map.h" -#include "xla/service/llvm_ir/dynamic_update_slice_util.h" -#include "xla/service/llvm_ir/fused_ir_emitter.h" -#include "xla/service/llvm_ir/ir_array.h" - -namespace xla { -namespace gpu { -namespace { - -constexpr int kDUSUpdateIndex = 1; - -} // namespace - -LaunchDimensions InPlaceDynamicUpdateSliceFusion::launch_dimensions() const { - const auto& update_shape = dus_ops_.front().GetOperand(1).shape(); - return CalculateLaunchDimensions(update_shape, analysis_.device_info()); -} - -std::optional -InPlaceDynamicUpdateSliceFusion::ComputeThreadIdToInputIndexing( - int64_t root_index, int64_t hero_operand_index, - mlir::MLIRContext* mlir_context) const { - if (hero_operand_index != kDUSUpdateIndex) { - return std::nullopt; - } - auto launch_dims = launch_dimensions(); - // It is guaranteed that all DUS 
ops have the same output shape at this point. - const auto& update_shape = - dus_ops_.front().GetOperand(kDUSUpdateIndex).shape(); - return GetDefaultThreadIdIndexingMap(launch_dims, /*unroll_factor=*/1, - update_shape, mlir_context); -} - -absl::Status InPlaceDynamicUpdateSliceFusion::EmitKernel( - IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion, - const LaunchDimensions& launch_dims, std::vector inputs, - std::vector outputs, llvm::IRBuilder<>* builder) const { - // In case a dynamic slice update's output is bitcasted, we need to ensure we - // write to the output array using the shape and layout of the dynamic slice - // update. This cast is known to be safe to do iff, in the case the output of - // the dynamic slice update is bitcasted, that bitcast is either the fusion's - // output, or has a single user and is part of the fusion's tuple output. - // This condition should be enforced explicitly in the - // 'CanEmitFusedDynamicUpdateSliceInPlaceForGpu' matcher. 
- for (auto [op, output] : llvm::zip(dus_ops_, outputs)) { - output = output.CastToShape(op.shape(), builder); - } - - auto* fused_computation = fusion.fused_instructions_computation(); - GpuElementalIrEmitter elemental_emitter(ir_emitter_context, builder); - FusedIrEmitter fused_emitter(elemental_emitter); - for (auto [index, input] : llvm::enumerate(inputs)) { - auto fused_operand = fused_computation->parameter_instruction(index); - fused_emitter.BindGenerator( - *fused_operand, [input = input, builder, - fused_operand](const llvm_ir::IrArray::Index& index) { - return input.EmitReadArrayElement(index, builder, - fused_operand->name()); - }); - } - - std::vector> - dus_and_output_array; - dus_and_output_array.reserve(dus_ops_.size()); - - for (auto [op, output] : llvm::zip(dus_ops_, outputs)) { - dus_and_output_array.push_back(std::make_pair(&op.instruction(), output)); - } - - return llvm_ir::EmitParallelFusedDynamicUpdateSliceInPlace( - fused_computation, dus_and_output_array, &fused_emitter, launch_dims, - builder); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice.h b/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice.h deleted file mode 100644 index db12c3cbbf4643..00000000000000 --- a/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice.h +++ /dev/null @@ -1,98 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef XLA_SERVICE_GPU_FUSIONS_LEGACY_IN_PLACE_DYNAMIC_UPDATE_SLICE_H_ -#define XLA_SERVICE_GPU_FUSIONS_LEGACY_IN_PLACE_DYNAMIC_UPDATE_SLICE_H_ - -#include -#include -#include - -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "llvm/IR/IRBuilder.h" -#include "mlir/IR/MLIRContext.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/service/gpu/fusions/fusion_emitter.h" -#include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/gpu/ir_emission_utils.h" -#include "xla/service/gpu/ir_emitter_context.h" -#include "xla/service/gpu/launch_dimensions.h" -#include "xla/service/gpu/model/indexing_map.h" -#include "xla/service/llvm_ir/ir_array.h" - -namespace xla { -namespace gpu { - -// Fusion node where the root is either: -// 1. a dynamic-update-slice op -// 2. a bitcast of a dynamic-update-slice op -// 3. a tuple op returning the result of several dynamic-update-slice ops -// 4. a tuple op returning the result of several bitcast -// dynamic-update-slice ops -// -// Additionally, all the dynamic-update-slice ops have exactly one user. The -// fusion parameter that they update can have users (in addition to the -// dynamic-update-slice op) that read in either -// a. a dynamic-slice corresponding exactly to the slice of the parameter that -// is updated by the dynamic-update-slice op -// b. a dynamic-slice reading in a single element anywhere in the parameter. -// This is only allowed if the dynamic-update-slice op updates a single -// element -// -// In both cases, the additional users must not flow into any other output -// than the dynamic-slice-update corresponding to that particular slice of the -// parameter. -// -// The assumption is that each op's input (i.e. 
array to update) shares the -// same slice as its output. In this case, we have a special algorithm that -// modifies the output in place without touching the un-updated elements. The -// update slice is assumed to be the exact same for all the -// dynamic-update-slice ops. -class InPlaceDynamicUpdateSliceFusion : public KernelFusionEmitterBase { - public: - explicit InPlaceDynamicUpdateSliceFusion(const HloFusionAnalysis& analysis) - : analysis_(analysis), - dus_ops_( - GetOutputDefiningDynamicUpdateSlices(analysis.fusion_roots())) {} - LaunchDimensions launch_dimensions() const override; - - std::optional ComputeThreadIdToOutputIndexing( - int64_t root_index, mlir::MLIRContext* ctx) const override { - // The mapping cannot be statically computed in general, since the offsets - // are unknown. - return std::nullopt; - } - - std::optional ComputeThreadIdToInputIndexing( - int64_t root_index, int64_t hero_operand_index, - mlir::MLIRContext* mlir_context) const override; - - protected: - absl::Status EmitKernel(IrEmitterContext& ir_emitter_context, - const HloFusionInstruction& fusion, - const LaunchDimensions& launch_dims, - std::vector inputs, - std::vector outputs, - llvm::IRBuilder<>* builder) const override; - - const HloFusionAnalysis& analysis_; - std::vector dus_ops_; -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_FUSIONS_LEGACY_IN_PLACE_DYNAMIC_UPDATE_SLICE_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice_test.cc deleted file mode 100644 index 53be6363567cdd..00000000000000 --- a/third_party/xla/xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice_test.cc +++ /dev/null @@ -1,145 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "xla/service/gpu/fusions/legacy/in_place_dynamic_update_slice.h" - -#include - -#include -#include -#include "mlir/IR/MLIRContext.h" -#include "xla/service/gpu/fusions/fusions.h" -#include "xla/service/gpu/gpu_device_info_for_tests.h" -#include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/gpu/model/affine_map_printer.h" -#include "xla/service/gpu/model/indexing_test_utils.h" -#include "xla/stream_executor/device_description.h" -#include "xla/tests/hlo_test_base.h" -#include "tsl/platform/statusor.h" - -namespace xla { -namespace gpu { -namespace { - -class InPlaceDynamicUpdateSliceFusionTest : public HloTestBase { - public: - void SetUp() override { - HloTestBase::SetUp(); - printer_ = - AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, - {"chunk_id", "unroll_id"}); - } - - protected: - DebugOptions GetDebugOptionsForTest() override { - auto opts = HloTestBase::GetDebugOptionsForTest(); - opts.set_xla_gpu_mlir_emitter_level(0); - return opts; - } - AffineMapPrinter printer_; - mlir::MLIRContext mlir_context_; - stream_executor::DeviceDescription device_info_ = - TestGpuDeviceInfo::RTXA6000DeviceInfo(); -}; - -TEST_F(InPlaceDynamicUpdateSliceFusionTest, ThreadIndexing) { - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( - HloModule module - - fused_computation { - in = f32[20,30] parameter(0) - updates = f32[5,6] parameter(1) - i0 = s32[] parameter(2) - i1 = s32[] parameter(3) - ROOT updated = f32[20,30] 
dynamic-update-slice(in, updates, i0, i1) - } - ENTRY entry { - in = f32[20,30] parameter(0) - updates = f32[5,6] parameter(1) - i0 = s32[] constant(2) - i1 = s32[] constant(3) - ROOT fusion = f32[20,30] fusion(in, updates, i0, i1), kind=kLoop, calls=fused_computation - } - )")); - - auto* root = module->entry_computation()->root_instruction(); - auto analysis_fused = HloFusionAnalysis::Create(*root, device_info_); - - auto emitter = - GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused}); - auto fusion = dynamic_cast(emitter.get()); - ASSERT_NE(fusion, nullptr); - - auto thread_id_update_indexing = fusion->ComputeThreadIdToInputIndexing( - /*root_index=*/0, /*hero_operand_index=*/1, &mlir_context_); - EXPECT_THAT(thread_id_update_indexing->ToString(printer_), - MatchIndexingString(R"( - (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( - th_x floordiv 6, th_x mod 6), - domain: - th_x in [0, 29], - th_y in [0, 0], - th_z in [0, 0], - bl_x in [0, 0], - bl_y in [0, 0], - bl_z in [0, 0], - chunk_id in [0, 0], - unroll_id in [0, 0], - is_simplified: true - )")); - auto thread_id_dst_indexing = fusion->ComputeThreadIdToInputIndexing( - /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_); - EXPECT_THAT(thread_id_dst_indexing, ::testing::Eq(std::nullopt)); -} - -TEST_F(InPlaceDynamicUpdateSliceFusionTest, ProduceConsumerFusion) { - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( - HloModule m - - fused_computation.1 { - param_0 = bf16[1,2,5,1,2] parameter(0) - bitcast = bf16[1,5,1,2,2] bitcast(param_0) - param_1 = bf16[1,1,1,2,2] parameter(1) - param_2 = s32[] parameter(2) - param_3 = s32[] parameter(3) - ROOT dynamic-update-slice = bf16[1,5,1,2,2] dynamic-update-slice(bitcast, param_1, param_2, param_3, param_2, param_2, param_2) - } - - ENTRY entry_computation { - param_0.2 = bf16[1,2,5,1,2] parameter(3) - param_1.2 = bf16[1,1,1,2,2] parameter(0) - param_2.2 = s32[] parameter(1) - param_3.2 = s32[] parameter(2) - 
fusion = bf16[1,5,1,2,2] fusion(param_0.2, param_1.2, param_2.2, param_3.2), kind=kLoop, calls=fused_computation.1 - ROOT bitcast.1 = bf16[1,2,5,1,2] bitcast(fusion) - } - )")); - - auto* root = module->entry_computation()->root_instruction(); - - auto analysis_fused = - HloFusionAnalysis::Create(*root->operand(0), *root, device_info_); - - auto emitter = - GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused}); - - auto fusion = dynamic_cast(emitter.get()); - - ASSERT_NE(fusion, nullptr); - EXPECT_EQ(fusion->launch_dimensions().launch_bound(), 4 /* update size */); -} - -} // namespace -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/input_slices.cc b/third_party/xla/xla/service/gpu/fusions/legacy/input_slices.cc deleted file mode 100644 index d336f9226256a2..00000000000000 --- a/third_party/xla/xla/service/gpu/fusions/legacy/input_slices.cc +++ /dev/null @@ -1,220 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include "xla/service/gpu/fusions/legacy/input_slices.h" - -#include -#include -#include -#include - -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "absl/types/span.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/Value.h" -#include "mlir/IR/MLIRContext.h" -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/elemental_ir_emitter.h" -#include "xla/service/gpu/elemental_ir_emitter.h" -#include "xla/service/gpu/ir_emission_utils.h" -#include "xla/service/gpu/ir_emitter_context.h" -#include "xla/service/gpu/launch_dimensions.h" -#include "xla/service/gpu/model/indexing_map.h" -#include "xla/service/gpu/parallel_loop_emitter.h" -#include "xla/service/llvm_ir/fused_ir_emitter.h" -#include "xla/service/llvm_ir/ir_array.h" -#include "xla/service/llvm_ir/kernel_support_library.h" -#include "xla/service/llvm_ir/llvm_loop.h" -#include "xla/shape.h" -#include "xla/shape_util.h" -#include "xla/util.h" -#include "tsl/platform/statusor.h" - -namespace xla { -namespace gpu { -namespace { - -// Emits code for slices based on the below structure. An if statement with -// a guarding condition is generated for each ROOT slice. 
-// -// Pseudo code: -// -// Compute values of slice input operands -// -// Compute guarding_cond0 -// if (guarding_cond0) { -// Write to output of slice0 -// } -// -// Compute guarding_cond1 -// if (guarding_cond1) { -// Write to output of slice1 -// } -// -absl::Status EmitElementForInputFusibleSlices( - ElementalIrEmitter& elemental_emitter, - const HloComputation* fused_computation, - const std::vector& inputs, - const std::vector& outputs, - const llvm_ir::IrArray::Index& index, llvm::IRBuilder<>* builder) { - VLOG(10) << "Emitting slice input fusion for " - << fused_computation->ToString(); - - HloInstruction* slice_or_tuple = fused_computation->root_instruction(); - auto slice_instructions = [&]() -> absl::Span { - if (slice_or_tuple->opcode() == HloOpcode::kSlice) { - return absl::Span(&slice_or_tuple, 1); - } - CHECK_EQ(slice_or_tuple->opcode(), HloOpcode::kTuple); - return slice_or_tuple->operands(); - }(); - - // Emit input operand values of slices. - std::vector input_ir_values; - FusedIrEmitter fused_emitter(elemental_emitter); - for (int i = 0; i < fused_computation->num_parameters(); i++) { - fused_emitter.BindGenerator( - *fused_computation->parameter_instruction(i), - [&inputs, i, builder](llvm_ir::IrArray::Index index) { - return inputs[i].EmitReadArrayElement(index, builder); - }); - } - for (const HloInstruction* slice : slice_instructions) { - auto input_generator = *fused_emitter.GetGenerator(*slice->operand(0)); - input_ir_values.push_back(input_generator(index).value()); - } - - // Emit for slice_instructions. - KernelSupportLibrary ksl(builder, llvm_ir::UnrollMode::kDefaultUnroll); - for (int64_t i = 0; i < slice_instructions.size(); ++i) { - HloInstruction* slice = slice_instructions[i]; - - // guarding_cond := index >= start && index < limit, for each dim. 
- std::vector index_within_ranges; - for (size_t dim = 0; dim < slice->slice_starts().size(); ++dim) { - CHECK_EQ(slice->slice_strides(dim), 1); - auto larger_or_equal_than_start = builder->CreateICmpSGE( - index.multidim()[dim], - index.GetConstantWithIndexType(slice->slice_starts(dim))); - llvm::Value* smaller_than_limit = builder->CreateICmpSLT( - index.multidim()[dim], - index.GetConstantWithIndexType(slice->slice_limits(dim))); - llvm::Value* within_range = - builder->CreateAnd(larger_or_equal_than_start, smaller_than_limit); - index_within_ranges.push_back(within_range); - } - llvm::Value* guarding_cond = builder->CreateAnd(index_within_ranges); - - auto emit_slice_elem_func = [&] { - const std::vector& src_multidim = index.multidim(); - std::vector dst_multidim(src_multidim.size()); - for (size_t dim = 0; dim < src_multidim.size(); ++dim) { - dst_multidim[dim] = builder->CreateSub( - src_multidim[dim], - index.GetConstantWithIndexType(slice->slice_starts(dim))); - } - const llvm_ir::IrArray& src_ir_array = outputs[i]; - llvm_ir::IrArray::Index slice_dst_index(dst_multidim, slice->shape(), - index.GetType()); - src_ir_array.EmitWriteArrayElement(slice_dst_index, input_ir_values[i], - builder); - }; - - ksl.If(absl::StrCat("slice", i), guarding_cond, emit_slice_elem_func); - } - return absl::OkStatus(); -} - -// Gets the input shape of the ROOT slices, which will be used as the kernel -// launch dims. The slice input fusion requires the input shapes of the ROOT -// slices to be the same although the (slice) output shapes can be different. -// -// Returns the input shape of the ROOT slices if all the input shapes of ROOT -// slices are the same and the slices are non-strided. Otherwise, returns -// FailedPrecondition. 
-absl::StatusOr GetConsistentInputShapeForRootSlices( - const HloComputation* fused_computation) { - const HloInstruction& root = *fused_computation->root_instruction(); - if (root.opcode() == HloOpcode::kSlice) { - return root.operands()[0]->shape(); - } - - CHECK_EQ(root.opcode(), HloOpcode::kTuple); - const Shape& first_slice_operand_shape = - root.operands()[0]->operands()[0]->shape(); - for (size_t i = 1; i < root.operands().size(); ++i) { - const HloInstruction* slice = root.operands()[i]; - const Shape& operand_shape = slice->operands()[0]->shape(); - if (!ShapeUtil::EqualIgnoringElementType(first_slice_operand_shape, - operand_shape)) { - return FailedPrecondition( - "Fused slices do not have the same input shape, fused computation = " - "%s.", - root.parent()->name()); - } - } - - return first_slice_operand_shape; -} - -} // namespace - -LaunchDimensions InputSlicesFusion::launch_dimensions() const { - const auto& root = analysis_.fusion_root(0).instruction(); - const auto& shape = root.operand(0)->shape(); - return CalculateLaunchDimensions(shape, analysis_.device_info(), - {unroll_factor_}); -} - -std::optional InputSlicesFusion::ComputeThreadIdToOutputIndexing( - int64_t output_id, mlir::MLIRContext* ctx) const { - // The mapping here is trivial and the same for all outputs - slice offsets - // are applied in the indexing from slice outputs to slice inputs. - auto launch_dims = launch_dimensions(); - // The implementation requires the shapes and layouts to be the same, but we - // still use the requested output's shape for clarity. 
- const auto& shape = analysis_.fusion_root(output_id).shape(); - return GetDefaultThreadIdIndexingMap(launch_dims, unroll_factor_, shape, ctx); -} - -absl::Status InputSlicesFusion::EmitKernel( - IrEmitterContext& ir_emitter_context, const HloFusionInstruction& fusion, - const LaunchDimensions& launch_dims, std::vector inputs, - std::vector outputs, llvm::IRBuilder<>* builder) const { - TF_ASSIGN_OR_RETURN(Shape element_shape, - GetConsistentInputShapeForRootSlices( - fusion.fused_instructions_computation())); - LaunchDimensionsConfig launch_config; - launch_config.unroll_factor = unroll_factor_; - GpuElementalIrEmitter elemental_emitter(ir_emitter_context, builder); - return ParallelLoopEmitter( - [&](const llvm_ir::IrArray::Index index) -> absl::Status { - return EmitElementForInputFusibleSlices( - elemental_emitter, fusion.fused_instructions_computation(), - inputs, outputs, index, builder); - }, - element_shape, launch_dims, builder, launch_config) - .EmitLoop( - fusion.name(), - GetIndexTypeForKernel(&fusion, launch_dims.launch_bound(), builder)); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/input_slices.h b/third_party/xla/xla/service/gpu/fusions/legacy/input_slices.h deleted file mode 100644 index e6532241123aee..00000000000000 --- a/third_party/xla/xla/service/gpu/fusions/legacy/input_slices.h +++ /dev/null @@ -1,79 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#ifndef XLA_SERVICE_GPU_FUSIONS_LEGACY_INPUT_SLICES_H_ -#define XLA_SERVICE_GPU_FUSIONS_LEGACY_INPUT_SLICES_H_ - -#include -#include -#include - -#include "absl/status/status.h" -#include "llvm/IR/IRBuilder.h" -#include "mlir/IR/MLIRContext.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/service/gpu/fusions/fusion_emitter.h" -#include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/gpu/ir_emitter_context.h" -#include "xla/service/gpu/launch_dimensions.h" -#include "xla/service/gpu/model/indexing_analysis.h" -#include "xla/service/gpu/model/indexing_map.h" -#include "xla/service/llvm_ir/ir_array.h" -#include "xla/util.h" - -namespace xla { -namespace gpu { - -// Generates code for input-fusible slices. -// -// Prerequisite: ROOT is either a slice or a tuple of slices. The input shapes -// of all ROOT slices need to be the same while their output shapes can be -// different. On the other hand, the input ranges of slices can be -// overlapping. Further generalization/specialization when the needs are seen -// in the future. -class InputSlicesFusion : public KernelFusionEmitterBase { - public: - explicit InputSlicesFusion(const HloFusionAnalysis& analysis) - : analysis_(analysis), - unroll_factor_(CeilOfRatio( - 8, analysis.input_output_info().smallest_output_dtype_bits)) {} - LaunchDimensions launch_dimensions() const override; - - std::optional ComputeThreadIdToOutputIndexing( - int64_t output_id, mlir::MLIRContext* ctx) const override; - - std::optional ComputeThreadIdToInputIndexing( - int64_t root_index, int64_t hero_operand_index, - mlir::MLIRContext* ctx) const override { - // TODO(b/319081342): Implement this. 
- return std::nullopt; - } - - protected: - absl::Status EmitKernel(IrEmitterContext& ir_emitter_context, - const HloFusionInstruction& fusion, - const LaunchDimensions& launch_dims, - std::vector inputs, - std::vector outputs, - llvm::IRBuilder<>* builder) const override; - - private: - const HloFusionAnalysis& analysis_; - const int unroll_factor_; -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_FUSIONS_LEGACY_INPUT_SLICES_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/input_slices_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/input_slices_test.cc deleted file mode 100644 index 0c604502bd51d1..00000000000000 --- a/third_party/xla/xla/service/gpu/fusions/legacy/input_slices_test.cc +++ /dev/null @@ -1,105 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include "xla/service/gpu/fusions/legacy/input_slices.h" - -#include - -#include -#include -#include "mlir/IR/MLIRContext.h" -#include "xla/service/gpu/fusions/fusions.h" -#include "xla/service/gpu/gpu_device_info_for_tests.h" -#include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/gpu/model/affine_map_printer.h" -#include "xla/service/gpu/model/indexing_test_utils.h" -#include "xla/stream_executor/device_description.h" -#include "xla/tests/hlo_test_base.h" - -namespace xla { -namespace gpu { -namespace { - -class InputSlicesTest : public HloTestBase { - public: - void SetUp() override { - HloTestBase::SetUp(); - printer_ = - AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, - {"chunk_id", "unroll_id"}); - } - - protected: - DebugOptions GetDebugOptionsForTest() override { - auto opts = HloTestBase::GetDebugOptionsForTest(); - opts.set_xla_gpu_mlir_emitter_level(0); - return opts; - } - AffineMapPrinter printer_; - mlir::MLIRContext mlir_context_; -}; - -TEST_F(InputSlicesTest, ThreadIndexing) { - auto module = ParseAndReturnVerifiedModule(R"( - HloModule module - - fused_computation { - %input = f32[2,3,5,7]{2,1,0,3} parameter(0) - slice0 = f32[1,2,3,5]{2,1,0,3} slice(input), slice={[0:1],[1:3],[0:3],[2:7]} - slice1 = f32[1,2,3,5]{2,1,0,3} slice(input), slice={[0:1],[0:2],[0:3],[2:7]} - ROOT tuple = (f32[1,2,3,5]{2,1,0,3}, f32[1,2,3,5]{2,1,0,3}) tuple(slice0, slice1) - } - - ENTRY entry { - %input = f32[2,3,5,7]{2,1,0,3} parameter(0) - ROOT %fusion = (f32[1,2,3,5]{2,1,0,3}, f32[1,2,3,5]{2,1,0,3}) fusion(%input), kind=kLoop, calls=fused_computation - })") - .value(); - - stream_executor::DeviceDescription device_info = - TestGpuDeviceInfo::RTXA6000DeviceInfo(); - - auto* root = module->entry_computation()->root_instruction(); - auto analysis_fused = HloFusionAnalysis::Create(*root, device_info); - - auto emitter = - 
GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused}); - auto fusion = dynamic_cast(emitter.get()); - ASSERT_NE(fusion, nullptr); - - auto thread_id_to_output_indexing = - fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context_); - EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_), - MatchIndexingString(R"( - (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (0, - ((bl_x * 128 + th_x) floordiv 3) mod 2, - (bl_x * 128 + th_x) mod 3, - (bl_x * 128 + th_x) floordiv 6), - domain: - th_x in [0, 127], - th_y in [0, 0], - th_z in [0, 0], - bl_x in [0, 1], - bl_y in [0, 0], - bl_z in [0, 0], - chunk_id in [0, 0], - unroll_id in [0, 0], - bl_x * 128 + th_x in [0, 29], - is_simplified: true - )")); -} - -} // namespace -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/loop.cc b/third_party/xla/xla/service/gpu/fusions/legacy/loop.cc deleted file mode 100644 index e6ce5f113c713b..00000000000000 --- a/third_party/xla/xla/service/gpu/fusions/legacy/loop.cc +++ /dev/null @@ -1,132 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include "xla/service/gpu/fusions/legacy/loop.h" - -#include -#include -#include -#include -#include -#include - -#include "absl/log/check.h" -#include "absl/log/log.h" -#include "absl/numeric/bits.h" -#include "absl/status/status.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/Type.h" -#include "mlir/IR/MLIRContext.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/layout_util.h" -#include "xla/service/gpu/elemental_ir_emitter.h" -#include "xla/service/gpu/gpu_fusible.h" -#include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/gpu/hlo_traversal.h" -#include "xla/service/gpu/ir_emission_utils.h" -#include "xla/service/gpu/ir_emitter_context.h" -#include "xla/service/gpu/launch_dimensions.h" -#include "xla/service/gpu/model/indexing_analysis.h" -#include "xla/service/gpu/model/indexing_map.h" -#include "xla/service/gpu/parallel_loop_emitter.h" -#include "xla/service/llvm_ir/fused_ir_emitter.h" -#include "xla/service/llvm_ir/ir_array.h" -#include "xla/shape.h" -#include "xla/shape_util.h" -#include "xla/util.h" -#include "tsl/platform/macros.h" -#include "tsl/platform/statusor.h" - -namespace xla { -namespace gpu { -namespace { - -const Shape& GetElementShape(const HloFusionAnalysis& analysis) { - const Shape* shape = &analysis.fusion_root(0).shape(); - while (shape->IsTuple()) { - shape = &shape->tuple_shapes(0); - } - return *shape; -} - -} // namespace - -LoopFusion::LoopFusion(const HloFusionAnalysis& analysis) - : analysis_(analysis), config_(ComputeLoopFusionConfig(analysis)) {} - -std::optional LoopFusion::ComputeThreadIdToOutputIndexing( - int64_t root_index, mlir::MLIRContext* ctx) const { - auto launch_dims = launch_dimensions(); - return GetDefaultThreadIdIndexingMap(launch_dims, config_.unroll_factor, - GetElementShape(analysis_), ctx); -} - -std::optional 
LoopFusion::ComputeThreadIdToInputIndexing( - int64_t root_index, int64_t hero_operand_index, - mlir::MLIRContext* ctx) const { - std::optional thread_id_to_output_indexing = - ComputeThreadIdToOutputIndexing(root_index, ctx); - if (!thread_id_to_output_indexing.has_value()) { - return std::nullopt; - } - const HloInstruction* fusion_root = - &analysis_.fusion_root(root_index).instruction(); - auto output_to_input_indexing = - ComputeOutputToInputIndexing(fusion_root, /*output_id=*/0, ctx); - IndexingMapSet output_to_input_indexing_set = - output_to_input_indexing.indexing_maps[hero_operand_index]; - // Since we are computing the indexing for a non-fusion op, there is only one - // indexing map per operand. - CHECK_EQ(output_to_input_indexing_set.size(), 1); - IndexingMap thread_id_to_input_indexing_map = ComposeIndexingMaps( - *thread_id_to_output_indexing, *output_to_input_indexing_set.begin()); - thread_id_to_input_indexing_map.Simplify(); - return thread_id_to_input_indexing_map; -} - -absl::Status LoopFusion::EmitKernel(IrEmitterContext& ir_emitter_context, - const HloFusionInstruction& fusion, - const LaunchDimensions& launch_dims, - std::vector inputs, - std::vector outputs, - llvm::IRBuilder<>* builder) const { - GpuElementalIrEmitter elemental_emitter(ir_emitter_context, builder); - FusedIrEmitter fused_emitter(elemental_emitter); - for (int i = 0; i < fusion.fused_parameters().size(); i++) { - fused_emitter.BindGenerator( - *fusion.fused_parameter(i), [&, i](llvm_ir::IrArray::Index index) { - return inputs[i].EmitReadArrayElement(index, builder); - }); - } - TF_ASSIGN_OR_RETURN( - auto element_generator, - fused_emitter.GetGenerator(*fusion.fused_expression_root())); - - llvm::Type* index_type = - GetIndexTypeForKernel(&fusion, launch_dims.launch_bound(), builder); - - return ParallelLoopEmitter(element_generator, outputs, launch_dims, builder, - config_) - .EmitLoop(fusion.name(), index_type); -} - -LaunchDimensions LoopFusion::launch_dimensions() const 
{ - return CalculateLaunchDimensions(GetElementShape(analysis_), - analysis_.device_info(), config_); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/loop.h b/third_party/xla/xla/service/gpu/fusions/legacy/loop.h deleted file mode 100644 index 30e5007bec658f..00000000000000 --- a/third_party/xla/xla/service/gpu/fusions/legacy/loop.h +++ /dev/null @@ -1,65 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef XLA_SERVICE_GPU_FUSIONS_LEGACY_LOOP_H_ -#define XLA_SERVICE_GPU_FUSIONS_LEGACY_LOOP_H_ - -#include -#include -#include - -#include "absl/status/status.h" -#include "llvm/IR/IRBuilder.h" -#include "mlir/IR/MLIRContext.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/service/gpu/fusions/fusion_emitter.h" -#include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/gpu/ir_emitter_context.h" -#include "xla/service/gpu/launch_dimensions.h" -#include "xla/service/gpu/model/indexing_map.h" -#include "xla/service/llvm_ir/ir_array.h" - -namespace xla { -namespace gpu { - -// Generic loop fusion. 
-class LoopFusion : public KernelFusionEmitterBase { - public: - explicit LoopFusion(const HloFusionAnalysis& analysis); - LaunchDimensions launch_dimensions() const override; - - std::optional ComputeThreadIdToOutputIndexing( - int64_t root_index, mlir::MLIRContext* ctx) const override; - - std::optional ComputeThreadIdToInputIndexing( - int64_t root_index, int64_t hero_operand_index, - mlir::MLIRContext* ctx) const override; - - protected: - absl::Status EmitKernel(IrEmitterContext& ir_emitter_context, - const HloFusionInstruction& fusion, - const LaunchDimensions& launch_dims, - std::vector inputs, - std::vector outputs, - llvm::IRBuilder<>* builder) const override; - - private: - const HloFusionAnalysis& analysis_; - LaunchDimensionsConfig config_; -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_FUSIONS_LEGACY_LOOP_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/loop_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/loop_test.cc deleted file mode 100644 index 82a9de34c7cc49..00000000000000 --- a/third_party/xla/xla/service/gpu/fusions/legacy/loop_test.cc +++ /dev/null @@ -1,227 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include -#include - -#include -#include -#include "absl/status/statusor.h" -#include "mlir/IR/MLIRContext.h" -#include "xla/service/gpu/fusions/fusion_emitter.h" -#include "xla/service/gpu/fusions/fusions.h" -#include "xla/service/gpu/gpu_device_info_for_tests.h" -#include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/gpu/model/affine_map_printer.h" -#include "xla/service/gpu/model/indexing_test_utils.h" -#include "xla/status_macros.h" -#include "xla/stream_executor/device_description.h" -#include "xla/tests/hlo_test_base.h" -#include "tsl/platform/statusor.h" - -namespace xla { -namespace gpu { -namespace { - -class LoopTest : public HloTestBase { - public: - void SetUp() override { - HloTestBase::SetUp(); - - printer_ = - AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, - {"chunk_id", "unroll_id"}); - } - - protected: - stream_executor::DeviceDescription device_info_ = - TestGpuDeviceInfo::RTXA6000DeviceInfo(); - AffineMapPrinter printer_; - mlir::MLIRContext mlir_context_; -}; - -absl::StatusOr> GetFusion( - const HloFusionAnalysis& analysis) { - auto emitter = GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis}); - auto fusion = dynamic_cast(emitter.get()); - TF_RET_CHECK(fusion != nullptr); - - emitter.release(); - return std::unique_ptr{fusion}; -} - -TEST_F(LoopTest, ThreadIndexingUnrolled) { - auto module = ParseAndReturnVerifiedModule(R"( - HloModule module - - neg { - %input = f32[100,200,300] parameter(0) - ROOT neg = f32[100,200,300] negate(%input) - } - - ENTRY entry { - %input = f32[100,200,300] parameter(0) - ROOT %fusion = f32[100,200,300] fusion(%input), kind=kLoop, calls=neg - })") - .value(); - - auto* root = module->entry_computation()->root_instruction(); - auto analysis = HloFusionAnalysis::Create(*root, device_info_); - - TF_ASSERT_OK_AND_ASSIGN(auto loop_fusion, GetFusion(analysis)); - auto 
thread_id_to_output_indexing = - loop_fusion->ComputeThreadIdToOutputIndexing(/*root_index=*/0, - &mlir_context_); - - EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_), - MatchIndexingString(R"( - (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( - (bl_x * 128 + th_x) floordiv 15000, - ((bl_x * 128 + th_x) floordiv 75) mod 200, - ((bl_x * 128 + th_x) mod 75) * 4 + unroll_id - ), - domain: - th_x in [0, 127], - th_y in [0, 0], - th_z in [0, 0], - bl_x in [0, 11718], - bl_y in [0, 0], - bl_z in [0, 0], - chunk_id in [0, 0], - unroll_id in [0, 3], - bl_x * 128 + th_x in [0, 1499999], - is_simplified: true -)")); -} - -TEST_F(LoopTest, ThreadIndexingNotUnrolled) { - auto module = ParseAndReturnVerifiedModule(R"( - HloModule module - - neg { - %input = f32[20] parameter(0) - ROOT neg = f32[20] negate(%input) - } - - ENTRY entry { - %input = f32[20] parameter(0) - ROOT %fusion = f32[20] fusion(%input), kind=kLoop, calls=neg - })") - .value(); - - auto* root = module->entry_computation()->root_instruction(); - auto analysis = HloFusionAnalysis::Create(*root, device_info_); - - TF_ASSERT_OK_AND_ASSIGN(auto loop_fusion, GetFusion(analysis)); - auto thread_id_to_output_indexing = - loop_fusion->ComputeThreadIdToOutputIndexing(/*root_index=*/0, - &mlir_context_); - EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_), - MatchIndexingString(R"( - (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (th_x), - domain: - th_x in [0, 19], - th_y in [0, 0], - th_z in [0, 0], - bl_x in [0, 0], - bl_y in [0, 0], - bl_z in [0, 0], - chunk_id in [0, 0], - unroll_id in [0, 0], - is_simplified: true - )")); - auto thread_id_to_input_indexing = - loop_fusion->ComputeThreadIdToInputIndexing( - /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_); - EXPECT_THAT(thread_id_to_input_indexing->ToString(printer_), - MatchIndexingString(R"( - (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> (th_x), - domain: - th_x in [0, 19], - th_y in 
[0, 0], - th_z in [0, 0], - bl_x in [0, 0], - bl_y in [0, 0], - bl_z in [0, 0], - chunk_id in [0, 0], - unroll_id in [0, 0], - is_simplified: true - )")); -} - -TEST_F(LoopTest, Broadcast) { - auto module = ParseAndReturnVerifiedModule(R"( - HloModule module - - bcast { - %input = f32[20] parameter(0) - ROOT bcast = f32[10, 20, 30] broadcast(%input), dimensions={1} - } - - ENTRY entry { - %input = f32[20] parameter(0) - ROOT %fusion = f32[10, 20, 30] fusion(%input), kind=kLoop, calls=bcast - })") - .value(); - - auto* root = module->entry_computation()->root_instruction(); - auto analysis = HloFusionAnalysis::Create(*root, device_info_); - - TF_ASSERT_OK_AND_ASSIGN(auto loop_fusion, GetFusion(analysis)); - auto thread_id_to_output_indexing = - loop_fusion->ComputeThreadIdToOutputIndexing(/*root_index=*/0, - &mlir_context_); - EXPECT_THAT(thread_id_to_output_indexing->ToString(printer_), - MatchIndexingString(R"( - (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( - (bl_x * 128 + th_x) floordiv 600, - ((bl_x * 128 + th_x) floordiv 30) mod 20, - (bl_x * 128 + th_x) mod 30), - domain: - th_x in [0, 127], - th_y in [0, 0], - th_z in [0, 0], - bl_x in [0, 46], - bl_y in [0, 0], - bl_z in [0, 0], - chunk_id in [0, 0], - unroll_id in [0, 0], - bl_x * 128 + th_x in [0, 5999], - is_simplified: true - )")); - auto thread_id_to_input_indexing = - loop_fusion->ComputeThreadIdToInputIndexing( - /*root_index=*/0, /*hero_operand_index=*/0, &mlir_context_); - EXPECT_THAT(thread_id_to_input_indexing->ToString(printer_), - MatchIndexingString(R"( - (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> - (((bl_x * 128 + th_x) floordiv 30) mod 20), - domain: - th_x in [0, 127], - th_y in [0, 0], - th_z in [0, 0], - bl_x in [0, 46], - bl_y in [0, 0], - bl_z in [0, 0], - chunk_id in [0, 0], - unroll_id in [0, 0], - bl_x * 128 + th_x in [0, 5999], - is_simplified: true - )")); -} - -} // namespace -} // namespace gpu -} // namespace xla diff --git 
a/third_party/xla/xla/service/gpu/fusions/legacy/reduction.cc b/third_party/xla/xla/service/gpu/fusions/legacy/reduction.cc deleted file mode 100644 index e009ea18e0b48c..00000000000000 --- a/third_party/xla/xla/service/gpu/fusions/legacy/reduction.cc +++ /dev/null @@ -1,1330 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "xla/service/gpu/fusions/legacy/reduction.h" - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/container/flat_hash_set.h" -#include "absl/container/inlined_vector.h" -#include "absl/container/node_hash_map.h" -#include "absl/log/check.h" -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/string_view.h" -#include "absl/types/span.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/Twine.h" -#include "llvm/IR/Constants.h" -#include "llvm/IR/DerivedTypes.h" -#include "llvm/IR/GlobalVariable.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/Instructions.h" -#include "llvm/IR/Type.h" -#include "llvm/IR/Value.h" -#include "llvm/Support/AtomicOrdering.h" -#include "llvm/Support/Casting.h" -#include "mlir/Support/LLVM.h" -#include "xla/hlo/ir/hlo_casting_utils.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include 
"xla/layout_util.h" -#include "xla/service/buffer_assignment.h" -#include "xla/service/gpu/elemental_ir_emitter.h" -#include "xla/service/gpu/fusions/fusion_emitter.h" -#include "xla/service/gpu/fusions/legacy/tiling_util.h" -#include "xla/service/gpu/fusions/reduction_base.h" -#include "xla/service/gpu/fusions/thunk_util.h" -#include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/gpu/hlo_traversal.h" -#include "xla/service/gpu/ir_emission_utils.h" -#include "xla/service/gpu/ir_emitter_context.h" -#include "xla/service/gpu/ir_emitter_nested.h" -#include "xla/service/gpu/kernel_arguments.h" -#include "xla/service/gpu/kernel_reuse_cache.h" -#include "xla/service/gpu/launch_dimensions.h" -#include "xla/service/gpu/parallel_loop_emitter.h" -#include "xla/service/gpu/reduction_utils.h" -#include "xla/service/gpu/runtime/kernel_thunk.h" -#include "xla/service/gpu/runtime/thunk.h" -#include "xla/service/gpu/target_util.h" -#include "xla/service/llvm_ir/fused_ir_emitter.h" -#include "xla/service/llvm_ir/ir_array.h" -#include "xla/service/llvm_ir/kernel_support_library.h" -#include "xla/service/llvm_ir/llvm_loop.h" -#include "xla/service/llvm_ir/llvm_util.h" -#include "xla/service/llvm_ir/loop_emitter.h" -#include "xla/shape.h" -#include "xla/shape_util.h" -#include "xla/status_macros.h" -#include "xla/stream_executor/device_description.h" -#include "xla/util.h" -#include "xla/xla_data.pb.h" -#include "tsl/platform/errors.h" -#include "tsl/platform/logging.h" -#include "tsl/platform/status.h" -#include "tsl/platform/statusor.h" - -namespace xla { -namespace gpu { -namespace { - -using TypedPointer = std::pair; - -// Fusion root -> array of indexes, one per reduction output. 
-using ReductionOutputMap = - ConstHloInstructionMap>; - -using ExtraOutputGensMap = ConstHloInstructionMap; - -int GetNumOutputs(const Shape& shape) { - if (shape.IsTuple()) { - return shape.tuple_shapes_size(); - } - return 1; -} - -const Shape& OutputShape(const Shape& output_shape, int output_index) { - CHECK(output_index == 0 || output_shape.IsTuple()); - return output_shape.IsTuple() ? output_shape.tuple_shapes(output_index) - : output_shape; -} - -llvm::Type* GetIndexType(const HloFusionInstruction& fusion, - const Tiling& tiling, llvm::IRBuilder<>* builder) { - return GetIndexTypeForKernel( - &fusion, tiling.GetNumThreadsPerBlock() * tiling.GetNumBlocks(), builder); -} - -llvm::Value* CastSharedToGlobal(llvm::IRBuilder<>* builder, llvm::Value* input, - llvm::Type* element_type, llvm::Twine name) { - return builder->CreateAddrSpaceCast( - input, - llvm::PointerType::get(element_type, - /*AddressSpace=*/0), - name); -} - -class ReductionEmitter { - public: - ReductionEmitter(const HloFusionAnalysis& analysis, - const ReductionInfo& reduction_codegen_info, - IrEmitterContext& ir_emitter_context, - const HloFusionInstruction& fusion, - llvm::IRBuilder<>* builder) - : builder_(builder), - elemental_emitter_(ir_emitter_context, builder_), - analysis_(analysis), - reduction_codegen_info_(reduction_codegen_info), - ir_emitter_context_(ir_emitter_context), - fusion_(fusion), - index_ty_(GetIndexType(fusion, reduction_codegen_info.GetTiling(), - elemental_emitter_.builder())) { - for (auto hero : analysis.fusion_heroes()) { - if (hero.opcode() == HloOpcode::kReduce) { - for (int i = 0; i < hero.instruction().operand_count() / 2; ++i) { - CHECK(LayoutUtil::IsMonotonicWithDim0Major( - hero.instruction().operand(i)->shape().layout())) - << "reduction-layout-normalizer must run before code generation"; - } - } - } - } - - absl::StatusOr EmitInitializers(); - absl::Status EmitKernel(const LaunchDimensions& launch_dims, - std::vector inputs, - std::vector outputs); - - 
private: - friend class ReductionGroupEmitter; - - absl::StatusOr> BuildKernelThunkForFusion( - const LaunchDimensions& launch_dimensions, - absl::string_view discriminator, - std::function, - std::vector)> - kernel_builder_fn); - - absl::StatusOr> BuildFusedInitializerThunk( - const HloInstruction* fusion_root, BufferAllocation::Slice dest_slice, - int output_index); - - absl::Status EmitIRForReduction( - absl::Span instr_index_group, - FusedIrEmitter& fused_emitter, const ReductionOutputMap& result_ir_arrays, - const Shape& input_shape); - - void MaybeEmitFenceForAMDGPU(); - void EmitSyncThreads(); - - int ReducedDimensionSize() const { - return reduction_codegen_info_.GetTiling().GetShape()[2]; - } - - llvm::IRBuilder<>* builder_; - GpuElementalIrEmitter elemental_emitter_; - const HloFusionAnalysis& analysis_; - const ReductionInfo& reduction_codegen_info_; - IrEmitterContext& ir_emitter_context_; - const HloFusionInstruction& fusion_; - llvm::Type* index_ty_; -}; - -class ReductionEmitter; - -class ReductionGroupEmitter { - public: - struct ReductionCalculationState { - std::optional shared_cache; - llvm::Value* initial_value; - llvm::AllocaInst* partial_result_address; - llvm::AllocaInst* input_address; - llvm_ir::ElementGenerator input_gen; - }; - - ReductionGroupEmitter( - ReductionEmitter& reduction_emitter, - absl::Span reduce_instr_index_group, - const ReductionOutputMap& result_ir_arrays, - FusedIrEmitter& fused_emitter); - - const ReductionCalculationState& GetCalculationStateFor( - const HloInstruction* instruction, int operand_idx) const { - const ReductionOpState& op_state = state_.at(instruction); - CHECK_LT(operand_idx, op_state.size()); - return op_state[operand_idx]; - } - - void SetCalculationStateFor( - const ReductionCalculationState& calculation_state, - const HloInstruction* instruction, int operand_idx) { - ReductionOpState& op_state = state_[instruction]; - CHECK_EQ(operand_idx, op_state.size()); - op_state.push_back(calculation_state); - 
} - - void EmitReductionOutputForRowReduction( - const TilingKernelInfo& tiling_kernel_info, - const HloReduceInstruction* reduction, - const std::vector& roots) const; - - void EmitReductionOutputForColumnReduction( - const TilingKernelInfo& tiling_kernel_info, - const HloReduceInstruction* reduction, - const std::vector& roots) const; - - void EmitFullWarpShuffleDownLoopForReduce( - const HloComputation* reducer, - absl::Span partial_result_addresses, - int threads_per_block, int num_results_per_warp) const; - - void WriteReductionOutput(const TilingKernelInfo& tiling_kernel_info, - const HloReduceInstruction* reduction, - const std::vector& roots, - absl::Span values) const; - - llvm_ir::IrArray::Index GetOutputIndexForReduction( - const TilingKernelInfo& tiling_kernel_info, - const HloReduceInstruction* reduction, const HloInstruction* root, - int output_idx) const; - - void GenerateElementForReducer(const HloReduceInstruction* reduction, - const llvm_ir::IrArray::Index& index) const; - - absl::Status EmitExtraOutputsForReduce( - const Shape& reduction_operand_shape, - const llvm_ir::IrArray::Index& index, - const ExtraOutputGensMap& extra_output_gens); - - private: - ReductionEmitter& reduction_emitter_; - const ReductionOutputMap& result_ir_arrays_; - - // One state per reduction operand. - using ReductionOpState = absl::InlinedVector; - - // HloInstruction -> operand_idx -> cache - absl::flat_hash_map state_; -}; - -// Creates accumulator alloca's, populates them with initial values, generates -// __shared__ caches and returns the populated object. 
-ReductionGroupEmitter::ReductionGroupEmitter( - ReductionEmitter& reduction_emitter, - absl::Span reduce_instr_index_group, - const ReductionOutputMap& result_ir_arrays, FusedIrEmitter& fused_emitter) - : reduction_emitter_(reduction_emitter), - result_ir_arrays_(result_ir_arrays) { - const ReductionInfo& reduction_info = - reduction_emitter_.reduction_codegen_info_; - VLOG(10) << "Emit prologue for reduction: " - << reduction_emitter_.fusion_.ToString(); - - auto* builder = reduction_emitter_.builder_; - for (const HloReduceInstruction* reduce_hlo : reduce_instr_index_group) { - for (int op_result_idx = 0; - op_result_idx < GetNumOutputs(reduce_hlo->shape()); op_result_idx++) { - Shape result_shape = OutputShape(reduce_hlo->shape(), op_result_idx); - - llvm::Type* element_type = llvm_ir::PrimitiveTypeToIrType( - result_shape.element_type(), builder->GetInsertBlock()->getModule()); - llvm::AllocaInst* reduction_input_address = - llvm_ir::EmitAllocaAtFunctionEntry( - element_type, "reduction_input_address", builder); - - llvm::AllocaInst* result_address = llvm_ir::EmitAllocaAtFunctionEntry( - element_type, "partial_reduction_result", builder); - - const HloInstruction* init_value = - reduce_hlo->init_values()[op_result_idx]; - - // Initialize the partial result with the initial value of the reduction. - llvm::Value* init_ir_value = (*fused_emitter.GetGenerator( - *init_value))(llvm_ir::IrArray::Index(builder->getInt32Ty())) - .value(); - - builder->CreateStore(init_ir_value, result_address); - const Tiling& tiling = reduction_info.GetTiling(); - auto shared_cache = [&]() -> std::optional { - auto* module = reduction_emitter.ir_emitter_context_.llvm_module(); - if (reduction_info.IsRowReduction()) { - // Multi-row reductions do not use shared memory. - if (RowReductionGetRowsPerWarp( - reduction_emitter_.ReducedDimensionSize()) > 1) { - return std::nullopt; - } - // Allocate one shared memory element per warp. 
- auto block_size = tiling.GetThreadsPerBlock(); - CHECK_EQ(block_size[ReductionDimensions::kRowMinorReducedDimension] % - WarpSize(), - 0); - return llvm_ir::AllocateSharedMemoryTile( - module, element_type, - {block_size[ReductionDimensions::kRowKeptDimension], - block_size[ReductionDimensions::kRowMinorReducedDimension] / - WarpSize()}, - "shared_cache"); - } - const auto& num_threads = tiling.GetThreadsPerBlock(); - int n = num_threads[ReductionDimensions::kColReducedDimension]; - CHECK_EQ(n, num_threads[ReductionDimensions::kColMinorKeptDimension]); - // The "+1" is used to avoid bank conflicts. - return llvm_ir::AllocateSharedMemoryTile(module, element_type, - {n, n + 1}, "shared_cache"); - }(); - - llvm_ir::ElementGenerator input_gen = - *fused_emitter.GetGenerator(*reduce_hlo->inputs()[op_result_idx]); - SetCalculationStateFor({shared_cache, init_ir_value, result_address, - reduction_input_address, input_gen}, - reduce_hlo, op_result_idx); - } - } -} - -void ReductionEmitter::MaybeEmitFenceForAMDGPU() { - auto* module = builder_->GetInsertBlock()->getModule(); - if (IsAMDGPU(module) && - ir_emitter_context_.rocm_compute_capability().fence_before_barrier()) { - builder_->CreateFence( - llvm::AtomicOrdering::SequentiallyConsistent, - builder_->getContext().getOrInsertSyncScopeID("workgroup")); - } -} - -void ReductionEmitter::EmitSyncThreads() { - MaybeEmitFenceForAMDGPU(); - EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, builder_); -} - -// Builds a thunk that calls a new or reused kernel for a fusion operation. -// -// The caller must specify the same launch dimensions for fusions which have -// the same computation. -// -// If a given fusion is implemented using multiple kernels, then for each -// kernel we should provide a discriminator, such as "init" and "impl". -// -// The builder_fn is only invoked if the kernel couldn't be reused. 
-// -// This is the typical usage pattern of this method: -// -// ``` -// auto builder_fn = [](std::vector inputs, -// std::vector outputs) { ... }; -// TF_ASSIGN_OR_RETURN( -// auto thunk, -// BuildKernelThunkForFusion(..., launch_dimensions, builder_fn)); -// AddThunkToThunkSequence(std::move(thunk)) -// ``` -absl::StatusOr> -ReductionEmitter::BuildKernelThunkForFusion( - const LaunchDimensions& launch_dimensions, absl::string_view discriminator, - std::function, - std::vector)> - kernel_builder_fn) { - const HloComputation* fused_computation = - fusion_.fused_instructions_computation(); - std::string suggested_kernel_name = std::string(fusion_.name()); - - TF_ASSIGN_OR_RETURN(auto kernel_arguments, - KernelArguments::Create( - ir_emitter_context_.buffer_assignment(), &fusion_)); - - auto [status_or_entry, cached] = - ir_emitter_context_.kernel_cache().GetWithStatus( - fused_computation, kernel_arguments.args(), discriminator, - [&]() -> absl::StatusOr { - llvm::Function* kernel; - std::vector input_arrays; - std::vector output_arrays; - TF_ASSIGN_OR_RETURN( - std::tie(kernel, input_arrays, output_arrays), - BuildKernelPrototype(ir_emitter_context_, suggested_kernel_name, - kernel_arguments.args(), - fusion_.operand_count(), launch_dimensions, - builder_)); - TF_RETURN_IF_ERROR(kernel_builder_fn(input_arrays, output_arrays)); - // Shared memory is allocated statically. 
- return {{kernel->getName().str(), launch_dimensions, - /*cluster_dim=*/std::nullopt, - /*shmem_bytes=*/0}}; - }); - TF_ASSIGN_OR_RETURN(const KernelReuseCache::Entry* entry, status_or_entry); - if (cached) { - VLOG(3) << "Reuse: " << suggested_kernel_name << " -> " - << entry->kernel_name; - } - - return std::make_unique( - &fusion_, entry->kernel_name, kernel_arguments.args(), launch_dimensions, - entry->cluster_dim, entry->shmem_bytes); -} - -absl::Status ReductionGroupEmitter::EmitExtraOutputsForReduce( - const Shape& reduction_operand_shape, const llvm_ir::IrArray::Index& index, - const ExtraOutputGensMap& extra_output_gens) { - if (extra_output_gens.empty()) { - return absl::OkStatus(); - } - - auto* builder = reduction_emitter_.builder_; - // Compute all extra output values before writing them. This avoids - // overwriting aliased input/output buffers before all reads occurred. - std::vector> - extra_output_ir_values; - extra_output_ir_values.reserve(extra_output_gens.size()); - - auto get_index = [&](const HloInstruction* instr) { - const Shape& s = instr->shape(); - return ShapeUtil::EqualIgnoringElementType(reduction_operand_shape, s) - ? 
index - : index.SourceIndexOfBitcast(reduction_operand_shape, s, - builder); - }; - - for (const auto& [instr, generator] : extra_output_gens) { - TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value, - generator(get_index(instr))); - extra_output_ir_values.emplace_back(instr, extra_output_ir_value); - } - - for (const auto& [instr, generator] : extra_output_ir_values) { - absl::Span result_ir = result_ir_arrays_.at(instr); - CHECK_EQ(result_ir.size(), 1); - result_ir[0].EmitWriteArrayElement(get_index(instr), generator, builder); - } - return absl::OkStatus(); -} - -absl::StatusOr> -ReductionEmitter::BuildFusedInitializerThunk(const HloInstruction* fusion_root, - BufferAllocation::Slice dest_slice, - int output_index) { - const HloReduceInstruction* reduce = - DynCast(fusion_root); - TF_RET_CHECK(reduce); - - const HloInstruction* init_value = reduce->init_values()[0]; - TF_ASSIGN_OR_RETURN( - std::optional> constant_init_thunk, - BuildConstantInitializerThunk(ir_emitter_context_, fusion_root, - init_value, dest_slice)); - if (constant_init_thunk) { - return *std::move(constant_init_thunk); - } - - const Shape& dest_shape = fusion_root->shape(); - - LaunchDimensions launch_dimensions = CalculateLaunchDimensions( - dest_shape, ir_emitter_context_.gpu_device_info()); - const HloComputation* fused_computation = - fusion_.fused_instructions_computation(); - - auto builder_fn = [&](std::vector inputs, - std::vector outputs) -> absl::Status { - FusedIrEmitter fused_emitter(elemental_emitter_); - for (int i = 0; i < fused_computation->num_parameters(); i++) { - fused_emitter.BindGenerator( - *fused_computation->parameter_instruction(i), - [builder = builder_, - input = inputs[i]](llvm_ir::IrArray::Index index) { - return input.EmitReadArrayElement(index, builder); - }); - } - HloInstruction* instr = fused_computation->root_instruction(); - if (instr->opcode() == HloOpcode::kTuple) { - instr = instr->mutable_operand(output_index); - } else { - CHECK_EQ(0, 
output_index); - } - TF_RET_CHECK(instr->shape().IsArray()); - TF_ASSIGN_OR_RETURN(auto generator, - fused_emitter.GetGenerator(*instr->operand(1))); - TF_RETURN_IF_ERROR(ParallelLoopEmitter(generator, {outputs[output_index]}, - launch_dimensions, builder_) - .EmitLoop(fusion_.name())); - return absl::OkStatus(); - }; - - return BuildKernelThunkForFusion(launch_dimensions, - /*discriminator=*/ - absl::StrCat("init_", output_index), - builder_fn); -} - -// Emits shuffle-down reduction for the `partial_result_address` using the -// reduction computation `reducer`, writes output into -// `partial_result_address`. -// -// Multiple partial_result_address inputs happen when doing variadic -// reduction: each one should get the output value. -void ReductionGroupEmitter::EmitFullWarpShuffleDownLoopForReduce( - const HloComputation* reducer, - absl::Span partial_result_addresses, - int threads_per_block, int num_results_per_warp) const { - // This only works when the block size is a multiple of 32 threads. - // We check this here as a mistake in the number of threads per - // block is very hard to detect. - CHECK_EQ(threads_per_block % 32, 0); - CHECK_EQ(WarpSize() % num_results_per_warp, 0); - - auto* builder = reduction_emitter_.builder_; - for (int distance = 16 / num_results_per_warp; distance >= 1; distance /= 2) { - absl::InlinedVector reduction_params; - - for (auto acc : partial_result_addresses) { - reduction_params.push_back(acc.first); - } - - for (auto [partial_result_address, element_type] : - partial_result_addresses) { - int bit_width = llvm_ir::GetSizeInBits(element_type); - llvm::Value* result_from_other_lane = llvm_ir::EmitAllocaAtFunctionEntry( - element_type, "result_from_other_lane", builder); - - reduction_params.push_back(result_from_other_lane); - - // Bitcast cannot be applied to aggregate types (even packed ones), so - // we bitcast addresses of load/store to intN* of the same bit-width. 
- llvm::Type* shuffled_value_type = element_type->isStructTy() - ? builder->getIntNTy(bit_width) - : element_type; - - llvm::Value* partial_result = - builder->CreateLoad(shuffled_value_type, partial_result_address, - "partial_reduction_result"); - builder->CreateStore( - EmitFullWarpShuffleDown( - partial_result, builder->getInt32(distance), builder, - reduction_emitter_.ir_emitter_context_.gpu_device_info()), - result_from_other_lane); - } - - absl::StatusOr> returned_scalars = - CallNestedComputationWithScalarAddrs( - builder, reduction_emitter_.ir_emitter_context_, *reducer, - reduction_params); - TF_CHECK_OK(returned_scalars.status()); - - for (int i = 0; i < returned_scalars->size(); i++) { - builder->CreateStore(/*Val=*/returned_scalars->at(i), - /*Ptr=*/partial_result_addresses[i].first); - } - } -} - -llvm_ir::IrArray::Index ReductionGroupEmitter::GetOutputIndexForReduction( - const TilingKernelInfo& tiling_kernel_info, - const HloReduceInstruction* reduction, const HloInstruction* root, - int output_idx) const { - auto* builder = reduction_emitter_.builder_; - auto* index_ty = reduction_emitter_.index_ty_; - - // 1d or 2d output index (for row/column reduction). 
- auto projected_index = [&]() -> llvm_ir::IrArray::Index { - const auto& reduction_info = reduction_emitter_.reduction_codegen_info_; - const auto& offset = tiling_kernel_info.tile_origin; - const auto& shape = reduction_info.GetTiling().GetXlaShape(); - const auto& thread_ids = tiling_kernel_info.thread_id_info.thread_ids; - if (reduction_info.IsRowReduction()) { - constexpr int kDim = ReductionDimensions::kRowKeptDimension; - return {{builder->CreateAdd(offset[kDim], thread_ids[kDim])}, - {shape.dimensions(kDim)}, - index_ty}; - } - auto* major_idx = offset[ReductionDimensions::kColMajorKeptDimension]; - auto* minor_idx = builder->CreateAdd( - offset[ReductionDimensions::kColMinorKeptDimension], - thread_ids[ReductionDimensions::kColReducedDimension]); - return {{major_idx, minor_idx}, - ShapeUtil::DeleteDimension( - ReductionDimensions::kColReducedDimension, shape), - index_ty}; - }(); - - auto physical_shape = ShapeUtil::DeleteDimensions( - reduction->dimensions(), reduction->operand(output_idx)->shape()); - auto physical_index = - projected_index.SourceIndexOfBitcast(physical_shape, builder); - return llvm_ir::IrArray::Index(physical_index.multidim(), - OutputShape(reduction->shape(), output_idx), - index_ty) - .SourceIndexOfBitcast(OutputShape(root->shape(), output_idx), builder); -} - -void ReductionGroupEmitter::WriteReductionOutput( - const TilingKernelInfo& tiling_kernel_info, - const HloReduceInstruction* reduction, - const std::vector& roots, - const absl::Span values) const { - auto* builder = reduction_emitter_.builder_; - const auto& reduction_info = reduction_emitter_.reduction_codegen_info_; - const HloComputation* reducer = reduction->to_apply(); - for (const auto& [oidx, typed_ptr] : llvm::enumerate(values)) { - auto [output_ptr, type] = typed_ptr; - for (auto root : roots) { - llvm_ir::IrArray::Index output_index = - GetOutputIndexForReduction(tiling_kernel_info, reduction, root, oidx); - - llvm::Value* output_address = - 
result_ir_arrays_.at(root)[oidx].EmitArrayElementAddress( - output_index, builder, "output_element_address"); - if (reduction_info.IsRaceFree()) { - FusedIrEmitter fused_emitter(reduction_emitter_.elemental_emitter_); - llvm::Value* loaded = builder->CreateLoad(type, output_ptr, "output"); - fused_emitter.BindGenerator( - *reduction, - [&](const llvm_ir::IrArray::Index& index) { return loaded; }); - llvm_ir::ElementGenerator gen = *fused_emitter.GetGenerator(*root); - llvm::Value* generated = *gen(output_index); - builder->CreateStore(generated, output_address); - } else { - CHECK_EQ(values.size(), 1); - CHECK_EQ(roots.size(), 1); - CHECK_EQ(reduction, root) - << "output fusion is not allowed for racing reductions"; - TF_CHECK_OK(EmitAtomicOperationForNestedComputation( - builder, reduction_emitter_.ir_emitter_context_, *reducer, - output_address, output_ptr, type)); - } - } - } -} - -void ReductionGroupEmitter::EmitReductionOutputForRowReduction( - const TilingKernelInfo& tiling_kernel_info, - const HloReduceInstruction* reduction, - const std::vector& roots) const { - const HloComputation* reducer = reduction->to_apply(); - const auto& thread_id_info = tiling_kernel_info.thread_id_info; - const auto& thread_ids = thread_id_info.thread_ids; - auto* thread_id_x = - thread_ids[ReductionDimensions::kRowMinorReducedDimension]; - auto constant = [&](uint64_t c) -> llvm::Constant* { - return llvm::ConstantInt::get(reduction_emitter_.index_ty_, c); - }; - - auto* builder = reduction_emitter_.builder_; - auto is_zero = [&](llvm::Value* value) { - return builder->CreateICmpEQ(value, constant(0)); - }; - - int num_outputs = reducer->num_parameters() / 2; - absl::InlinedVector current_outputs; - for (int output_idx = 0; output_idx < num_outputs; output_idx++) { - const auto& state = GetCalculationStateFor(reduction, output_idx); - current_outputs.push_back( - {state.partial_result_address, - state.partial_result_address->getAllocatedType()}); - } - - const auto& 
reduction_info = reduction_emitter_.reduction_codegen_info_; - const Tiling& tiling = reduction_info.GetTiling(); - int num_rows_per_warp = - RowReductionGetRowsPerWarp(reduction_emitter_.ReducedDimensionSize()); - EmitFullWarpShuffleDownLoopForReduce(reducer, absl::MakeSpan(current_outputs), - tiling.GetNumThreadsPerBlock(), - num_rows_per_warp); - - KernelSupportLibrary ksl(builder); - llvm::Value* warp_id = builder->CreateUDiv(thread_id_x, constant(WarpSize())); - - auto emit_write_output = [&](llvm::Value* write_condition, - const absl::Span values) { - ksl.If("reduction_write_output", write_condition, [&] { - WriteReductionOutput(tiling_kernel_info, reduction, roots, values); - }); - }; - - // The major kept dimension and vector dimension are not tiled, so they're - // always in bounds. - llvm::Value* is_in_bounds_y = builder->CreateICmpULT( - thread_ids[ReductionDimensions::kRowKeptDimension], - tiling_kernel_info - .output_tile_bounds[ReductionDimensions::kRowKeptDimension]); - - ksl.If("thread_in_bounds", is_in_bounds_y, [&] { - if (num_rows_per_warp > 1) { - llvm::Value* is_writing_thread = is_zero(builder->CreateAnd( - thread_id_x, - constant(reduction_emitter_.ReducedDimensionSize() - 1))); - emit_write_output(is_writing_thread, current_outputs); - return; - } - - ksl.If("intra_warp_reduce_write", is_zero(thread_id_info.lane_id), [&] { - for (int oidx = 0; oidx < num_outputs; oidx++) { - auto& state = GetCalculationStateFor(reduction, oidx); - state.shared_cache->Store( - builder->CreateLoad(current_outputs[oidx].second, - current_outputs[oidx].first), - {thread_id_info.thread_ids[ReductionDimensions::kRowKeptDimension], - warp_id}, - builder); - } - }); - - // TODO(cheshire): Don't we want to sync it once for everything in the - // output? Not once per each? 
- reduction_emitter_.EmitSyncThreads(); - ksl.If("inter_warp_reduce", is_zero(warp_id), [&] { - absl::InlinedVector selected_values; - for (int oidx = 0; oidx < num_outputs; oidx++) { - auto& state = GetCalculationStateFor(reduction, oidx); - llvm::Value* block_accum_addr = state.shared_cache->Address( - {thread_id_info.thread_ids[ReductionDimensions::kRowKeptDimension], - thread_id_info.lane_id}, - builder); - - llvm::Type* element_type = - state.partial_result_address->getAllocatedType(); - - // Ensure initial value address is in generic, not scratch. - llvm::Value* initial_value_addr = - CastSharedToGlobal(builder, - llvm_ir::EmitAllocaAtFunctionEntry( - element_type, "initial_value_addr", builder), - element_type, /*name=*/""); - builder->CreateStore(state.initial_value, initial_value_addr); - - llvm::Value* warp_exists = builder->CreateICmpULT( - thread_id_x, - constant(tiling.GetThreadsPerBlock() - [ReductionDimensions::kRowMinorReducedDimension] / - WarpSize())); - - llvm::Value* selected_value = builder->CreateSelect( - warp_exists, block_accum_addr, initial_value_addr); - - selected_values.push_back({selected_value, element_type}); - } - - // If only one warp produces the output element, we don't need to emit - // an inter warp reduce. In our tiling, DimX is the minor reduced - // dimension. The major reduced dimension is always emitted as a loop. - // TODO(b/241414088) If only warp is present, then inter-warp - // communication using shared memory and synchronization using barrier is - // also unnecessary and should be removed. - if (tiling.GetThreadsPerBlock() - [ReductionDimensions::kRowMinorReducedDimension] > WarpSize()) { - EmitFullWarpShuffleDownLoopForReduce( - reducer, absl::MakeSpan(selected_values), - tiling.GetNumThreadsPerBlock(), /*num_results_per_warp=*/1); - } - - emit_write_output(is_zero(thread_id_x), selected_values); - }); - }); -} - -// Same arguments as EmitReductionOutputForRowReduction. 
-void ReductionGroupEmitter::EmitReductionOutputForColumnReduction( - const TilingKernelInfo& tiling_kernel_info, - const HloReduceInstruction* reduction, - const std::vector& roots) const { - auto* builder = reduction_emitter_.builder_; - KernelSupportLibrary ksl(builder); - const HloComputation* reducer = reduction->to_apply(); - const auto& thread_id_info = tiling_kernel_info.thread_id_info; - const auto& thread_ids = thread_id_info.thread_ids; - - auto constant = [&](uint64_t c) -> llvm::Constant* { - return llvm::ConstantInt::get(reduction_emitter_.index_ty_, c); - }; - auto is_zero = [&](llvm::Value* value) { - return builder->CreateICmpEQ(value, constant(0)); - }; - const auto& reduction_info = reduction_emitter_.reduction_codegen_info_; - const Tiling& tiling = reduction_info.GetTiling(); - int num_outputs = reducer->num_parameters() / 2; - - auto* kept_index = thread_ids[ReductionDimensions::kColMinorKeptDimension]; - auto* reduced_index = thread_ids[ReductionDimensions::kColReducedDimension]; - - // Store the transpose in shared memory. - for (int output_idx = 0; output_idx < num_outputs; output_idx++) { - const auto& state = GetCalculationStateFor(reduction, output_idx); - auto* current_output_value = - builder->CreateLoad(state.partial_result_address->getAllocatedType(), - state.partial_result_address); - state.shared_cache->Store(current_output_value, {kept_index, reduced_index}, - builder); - } - - reduction_emitter_.EmitSyncThreads(); - - // Get transposed element from shared memory. 
- absl::InlinedVector shmem_transposed_addrs; - for (int output_idx = 0; output_idx < num_outputs; output_idx++) { - const auto& state = GetCalculationStateFor(reduction, output_idx); - auto* shmem_transposed_addr = - state.shared_cache->Address({reduced_index, kept_index}, builder); - shmem_transposed_addrs.push_back( - {shmem_transposed_addr, state.shared_cache->GetElementType()}); - } - - EmitFullWarpShuffleDownLoopForReduce(reducer, - absl::MakeSpan(shmem_transposed_addrs), - tiling.GetNumThreadsPerBlock(), - /*num_results_per_warp=*/1); - - // Some warps in the block are completely outside of the bound of the - // tensor, so they should not write any output at all. - llvm::Value* has_output = builder->CreateAnd( - builder->CreateICmpULT( - reduced_index, - tiling_kernel_info - .output_tile_bounds[ReductionDimensions::kColMinorKeptDimension]), - builder->CreateICmpULT( - kept_index, - tiling_kernel_info - .output_tile_bounds[ReductionDimensions::kColReducedDimension])); - - ksl.If("reduction_write_output", - builder->CreateAnd(has_output, is_zero(thread_id_info.lane_id)), [&] { - WriteReductionOutput(tiling_kernel_info, reduction, roots, - shmem_transposed_addrs); - }); -} - -// Generate a single element of the tile (update the accumulator state) for a -// given reducer. 
-void ReductionGroupEmitter::GenerateElementForReducer( - const HloReduceInstruction* reduction, - const llvm_ir::IrArray::Index& index) const { - HloComputation* reducer = reduction->to_apply(); - auto* builder = reduction_emitter_.builder_; - CHECK_EQ(reducer->num_parameters() % 2, 0); - - absl::InlinedVector reduction_accumulators; - absl::InlinedVector reduction_input_value; - for (int red_idx = 0; red_idx < reducer->num_parameters() / 2; red_idx++) { - const auto& state = GetCalculationStateFor(reduction, red_idx); - - llvm::AllocaInst* input_address = state.input_address; - auto input_index = - index.SourceIndexOfBitcast(reduction->operand(0)->shape(), builder); - llvm::Value* const input_ir_value = *state.input_gen(input_index); - builder->CreateStore(input_ir_value, input_address); - reduction_accumulators.push_back(state.partial_result_address); - reduction_input_value.push_back(input_address); - } - - absl::InlinedVector reduction_params; - for (llvm::Value* acc : reduction_accumulators) { - reduction_params.push_back(acc); - } - for (llvm::Value* value : reduction_input_value) { - reduction_params.push_back(value); - } - - // Emit a call to the variadic reducer. Since it may be returning a - // tuple, we can't return it directly as a value. Instead, before - // the call, we create N (N = # arguments in the tuple) allocas, one - // for each returned argument, then when we make the call we pass N - // pointers as last parameters, the called computation writes into - // those pointers, and we have returned values on the stack (as well - // as pointers to them). - absl::StatusOr> returned_scalars = - CallNestedComputationWithScalarAddrs( - builder, reduction_emitter_.ir_emitter_context_, *reducer, - reduction_params); - TF_CHECK_OK(returned_scalars.status()); - - for (int i = 0; i < returned_scalars->size(); i++) { - builder->CreateStore(returned_scalars->at(i), reduction_accumulators[i]); - } -} - -// Emits code for reductions in the output_instructions. 
-absl::Status ReductionEmitter::EmitIRForReduction( - absl::Span instr_index_group, - FusedIrEmitter& fused_emitter, const ReductionOutputMap& result_ir_arrays, - const Shape& input_shape) { - ExtraOutputGensMap extra_output_gens; - absl::flat_hash_map> - heroes_to_roots; - // Keep a list of deduplicated heroes separate from heroes_to_roots to make - // the CodeGen deterministic. - std::vector heroes; - - for (const HloInstruction* hlo : instr_index_group) { - auto& hero = FindNonTrivialHero(*hlo); - if (IsRealReductionHero(*hlo, hero)) { - auto reduction = Cast(&hero); - if (heroes_to_roots.find(reduction) == heroes_to_roots.end()) { - heroes.push_back(reduction); - } - heroes_to_roots[reduction].push_back(hlo); - } else { - extra_output_gens[hlo] = *fused_emitter.GetGenerator(*hlo); - } - } - - CHECK(!heroes.empty()) << " expect at least one reduce instructions."; - const Tiling& tiling = reduction_codegen_info_.GetTiling(); - CHECK_EQ(tiling.GetNumThreadsPerBlock() % WarpSize(), 0); - ReductionGroupEmitter group_emitter(*this, heroes, result_ir_arrays, - fused_emitter); - - TF_ASSIGN_OR_RETURN( - TilingKernelInfo tiling_kernel_info, - EmitTilingKernel( - builder_, tiling, index_ty_, - [&](const TilingThreadIdInfo& thread_id_info, - const llvm_ir::IrArray::Index& tile_index, - absl::Span tile_dimensions) { - auto emit_element = - [&](absl::Span index_in_tile) { - auto index = tile_index.AddOffset(index_in_tile, builder_); - - // Emit code to generate the input and perform the reduction - // computation for each reduction instruction. - for (const HloReduceInstruction* reduce : heroes) { - group_emitter.GenerateElementForReducer(reduce, index); - } - - // Emit code to generate the output for the non-reduction - // instructions in the fusion, if any. 
- TF_CHECK_OK(group_emitter.EmitExtraOutputsForReduce( - ShapeUtil::MakeShape( - F32, reduction_codegen_info_.GetTiling().GetShape()), - index, extra_output_gens)); - }; - EmitTile(builder_, reduction_codegen_info_.GetTiling(), - thread_id_info, tile_dimensions, emit_element); - })); - - KernelSupportLibrary ksl(builder_); - for (auto reduce : heroes) { - if (reduction_codegen_info_.IsRowReduction()) { - group_emitter.EmitReductionOutputForRowReduction( - tiling_kernel_info, reduce, heroes_to_roots[reduce]); - } else { - group_emitter.EmitReductionOutputForColumnReduction( - tiling_kernel_info, reduce, heroes_to_roots[reduce]); - } - } - - return absl::OkStatus(); -} - -absl::StatusOr ReductionEmitter::EmitInitializers() { - FusionEmissionResult result; - if (reduction_codegen_info_.IsRaceFree()) { - return result; - } - // We need to get the dest slice by traversing the slice assigned to - // fusion, because instructions inside fusion don't have buffer assignment. - // - // The order of fusion roots is determined by its position in the result - // tuple. For example, in the following fused computation - // - // %fused_computation { - // %a = ... - // &b = ... - // ROOT %root = tuple(%a, %b) - // } - // - // The fusion root with index = 0 is %a, and the fusion root %b has index 1. - // Therefore we can get the ordered slices by calling ForEachSubshape on the - // result shape. 
- std::vector slices; - TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( - fusion_.shape(), [&](const Shape& subshape, ShapeIndex index) { - if (!ShapeUtil::IsLeafIndex(fusion_.shape(), index)) { - return absl::OkStatus(); - } - - TF_ASSIGN_OR_RETURN( - BufferAllocation::Slice slice, - ir_emitter_context_.buffer_assignment().GetUniqueSlice(&fusion_, - index)); - slices.push_back(slice); - return absl::OkStatus(); - })); - - absl::Span fusion_roots = - analysis_.fusion_roots(); - for (int i = 0; i < fusion_roots.size(); ++i) { - const HloInstruction* fusion_root = &fusion_roots[i].instruction(); - - if (IsReductionFromOrToContiguousDimensions(*fusion_root)) { - TF_ASSIGN_OR_RETURN( - result.thunks.emplace_back(), - BuildFusedInitializerThunk(fusion_root, slices[i], i)); - } - } - return result; -} - -absl::Status ReductionEmitter::EmitKernel( - const LaunchDimensions& launch_dims, std::vector inputs, - std::vector outputs) { - const HloComputation* fused_computation = - fusion_.fused_instructions_computation(); - FusedIrEmitter fused_emitter(elemental_emitter_); - for (int i = 0; i < fused_computation->num_parameters(); i++) { - HloInstruction* fused_operand = fused_computation->parameter_instruction(i); - fused_emitter.BindGenerator( - *fused_operand, [builder = builder_, input = inputs[i], - fused_operand](const llvm_ir::IrArray::Index& index) { - return input.EmitReadArrayElement(index, builder, - fused_operand->name()); - }); - } - - // Get outputs. - ReductionOutputMap result_ir_arrays; - - int ir_arrays_idx = 0; - for (const HloInstructionAdaptor& root : analysis_.fusion_roots()) { - int get_num_results = GetNumOutputs(root.shape()); - result_ir_arrays[&root.instruction()] = - absl::MakeSpan(outputs).subspan(ir_arrays_idx, get_num_results); - ir_arrays_idx += get_num_results; - } - - KernelSupportLibrary ksl(builder_, llvm_ir::UnrollMode::kDefaultUnroll); - - // Use raw block_id_y to select the i-th parallel reduction to run. 
Using - // block_id_y instead of block_id_x simplifies the index calculation - // for reduction code generation as the block_id_y is orthogonal to - // the indices used within the reductions. - const auto& instr_index_groups = - reduction_codegen_info_.GetGroups().grouped_roots; - Shape reduce_operand_shape = reduction_codegen_info_.GetReduceOperandShape(); - - llvm::Value* block_id_y = gpu::EmitCallToTargetIntrinsic( - gpu::TargetIntrinsicID::kBlockIdy, {}, {}, builder_); - llvm_ir::AddRangeMetadata(0, instr_index_groups.size(), - llvm::cast(block_id_y), - builder_->GetInsertBlock()->getModule()); - block_id_y = builder_->CreateZExtOrTrunc(block_id_y, builder_->getInt32Ty()); - block_id_y->setName("block.id.y"); - for (int i = 0; i < instr_index_groups.size(); ++i) { - TF_RETURN_IF_ERROR(ksl.IfWithStatus( - absl::StrCat("reduce-group-", i), - builder_->CreateICmpEQ(block_id_y, builder_->getInt32(i)), [&] { - return EmitIRForReduction(instr_index_groups[i], fused_emitter, - result_ir_arrays, reduce_operand_shape); - })); - } - - return absl::OkStatus(); -} - -} // namespace - -absl::StatusOr ReductionFusion::EmitInitializers( - IrEmitterContext& ir_emitter_context, - const HloFusionInstruction& fusion) const { - llvm::IRBuilder<> builder(ir_emitter_context.llvm_module()->getContext()); - return ReductionEmitter(analysis_, reduction_info_, ir_emitter_context, - fusion, &builder) - .EmitInitializers(); -} - -absl::Status ReductionFusion::EmitKernel(IrEmitterContext& ir_emitter_context, - const HloFusionInstruction& fusion, - const LaunchDimensions& launch_dims, - std::vector inputs, - std::vector outputs, - llvm::IRBuilder<>* builder) const { - return ReductionEmitter(analysis_, reduction_info_, ir_emitter_context, - fusion, builder) - .EmitKernel(launch_dims, inputs, outputs); -} - -int ReductionInfo::GetRowsPerWarp() const { - if (!is_row_reduction_) return 1; - return RowReductionGetRowsPerWarp( - 
tiling_.GetShape()[ReductionDimensions::kRowMinorReducedDimension]); -} - -LaunchDimensions ReductionInfo::launch_dimensions() const { - size_t blocks_y = groups_.grouped_roots.size(); - return {se::BlockDim(/*x=*/tiling_.GetNumBlocks(), - /*y=*/static_cast(blocks_y), /*z=*/1), - se::ThreadDim(/*x=*/tiling_.GetNumThreadsPerBlock(), - /*y=*/1, /*z=*/1)}; -} - -ReductionInfo ReductionInfo::Create(const HloFusionAnalysis& analysis) { - auto* hero_reduction = analysis.FindHeroReduction(); - CHECK_NE(hero_reduction, nullptr); - Shape input_shape = hero_reduction->operand(0)->shape(); - ReductionDimensions reduction_dimensions = - GetReductionKindAndContiguousComponents(*hero_reduction); - auto shape = reduction_dimensions.dimensions; - VLOG(10) << "is_row_reduction " << reduction_dimensions.is_row_reduction - << " " << shape[0] << " " << shape[1] << " " << shape[2]; - Vector3 reduction_tiling = GetReductionTiling(reduction_dimensions); - - int64_t num_threads_y = - reduction_dimensions.is_row_reduction ? 1 : WarpSize(); - int64_t rows_per_warp = - reduction_dimensions.is_row_reduction - ? RowReductionGetRowsPerWarp( - shape[ReductionDimensions::kRowMinorReducedDimension]) - : 1; - int64_t num_threads_x = [&] { - if (reduction_dimensions.is_row_reduction) { - if (rows_per_warp > 1) { - return shape[ReductionDimensions::kRowMinorReducedDimension]; - } - int64_t max_block_size = - MinThreadsXRowReduction(hero_reduction->GetModule()->config()); - return std::min( - max_block_size, - RoundUpTo( - CeilOfRatio(shape[ReductionDimensions::kRowMinorReducedDimension], - reduction_tiling - [ReductionDimensions::kRowMinorReducedDimension]), - WarpSize())); - } - return WarpSize(); - }(); - - // If we're limited by the size of the x dimension, add additional parallelism - // in the y dimension. The code generator doesn't currently support - // parallelizing the z dimension (major reduced dimensions). 
The general - // recommendation is to use between 128 and 512 threads, so we just go for - // 256. See https://forums.developer.nvidia.com/t/55529 - constexpr int64_t kThreadsPerBlockTarget = 256; - if (reduction_dimensions.is_row_reduction && - num_threads_x * 2 <= kThreadsPerBlockTarget) { - int64_t kept_size = - reduction_dimensions.dimensions[ReductionDimensions::kRowKeptDimension]; - // Increase the size of the y dimension as long as there's remaining - // parallelism. - if (kept_size * num_threads_x <= kThreadsPerBlockTarget) { - num_threads_y = kept_size; - // num_threads_x is a power of two, but it may be less than 32. If dim_y - // is also small, we may have to increase the bound so the total number of - // threads is a multiple of 32. - while ((num_threads_x * num_threads_y) % 32) ++num_threads_y; - } else { - num_threads_y = kThreadsPerBlockTarget / num_threads_x; - } - } - - int vector_size = GetVectorSize(analysis, reduction_dimensions, num_threads_x, - reduction_tiling); - - absl::InlinedVector num_threads{1, num_threads_y, num_threads_x}; - absl::InlinedVector tiled_shape{shape[0], shape[1], - shape[2] / vector_size}; - absl::InlinedVector tile_per_thread{ - reduction_tiling[0], reduction_tiling[1], - std::max(reduction_tiling[2] / vector_size, 1)}; - if (rows_per_warp > 1) { - // If we produce more than one element per thread, that means the reduced - // dimension is small and it can't be tiled - we already have more threads - // in a warp than the size of the reduced dimension. The code generator - // doesn't currently support tiling the kept dimension, because it just - // uses the thread ID as the coordinate. - tile_per_thread[2] = 1; - } - if (vector_size != 1) { - num_threads.push_back(1); // The vector dimension is a loop. 
- tiled_shape.push_back(vector_size); - tile_per_thread.push_back(vector_size); - } - - Tiling tiling(tiled_shape, tile_per_thread, num_threads, - /*loops_to_unroll=*/{false, false, true, false}); - bool reduction_is_race_free = ReductionIsRaceFree( - hero_reduction->GetModule()->config(), reduction_dimensions); - return ReductionInfo(analysis, tiling, reduction_dimensions.is_row_reduction, - reduction_is_race_free, - GroupDisjointReductions(analysis, /*for_mlir=*/false), - hero_reduction); -} - -std::optional ReductionInfo::ComputeThreadIdToOutputIndexing( - int64_t root_index, mlir::MLIRContext* ctx) const { - if (!groups_.is_reduction_root[root_index]) { - auto map = ComposeIndexingMaps( - GetIndexingMapForTiling(tiling_, ctx), - GetBitcastMap(tiling_.GetXlaShape(), - analysis_.fusion_root(root_index).shape(), ctx)); - AddGroupIdConstraint(map, root_index, groups_); - return map; - } - const auto& hero = analysis_.fusion_hero(root_index).instruction(); - - auto block_offsets = GetBlockOffsetsForTiling(tiling_, ctx); - auto thread_ids = DelinearizeInBoundsIndex(mlir::getAffineDimExpr(0, ctx), - tiling_.GetThreadsPerBlock()); - - auto physical_shape = - ShapeUtil::DeleteDimensions(hero.dimensions(), hero.operand(0)->shape()); - std::vector dimension_ranges{ - {{0, tiling_.GetNumThreadsPerBlock() - 1}}, - {}, - {}, - {{0, tiling_.GetNumBlocks() - 1}}, - {{0, static_cast(groups_.grouped_roots.size() - 1)}}, - {}, - }; - - constexpr int kRowKept = ReductionDimensions::kRowKeptDimension; - constexpr int kRowMinorReduced = - ReductionDimensions::kRowMinorReducedDimension; - - constexpr int kColMajorKept = ReductionDimensions::kColMajorKeptDimension; - constexpr int kColMinorKept = ReductionDimensions::kColMinorKeptDimension; - constexpr int kColReduced = ReductionDimensions::kColReducedDimension; - - auto map = [&]() { - if (is_row_reduction_) { - IndexingMap linear_index( - mlir::AffineMap::get( - 6, 0, block_offsets.getResult(kRowKept) + thread_ids[kRowKept], - ctx), 
- dimension_ranges, /*range_vars=*/{}, /*rt_vars=*/{}); - int rows_per_warp = GetRowsPerWarp(); - if (rows_per_warp > 1) { - linear_index.AddConstraint( - thread_ids[kRowMinorReduced] % (WarpSize() / rows_per_warp), - {0, 0}); - } else { - linear_index.AddConstraint(thread_ids[kRowMinorReduced], {0, 0}); - } - return ComposeIndexingMaps( - linear_index, GetBitcastMap(ShapeUtil::MakeShape( - PRED, {tiling_.GetShape()[kRowKept]}), - physical_shape, ctx)); - } - - mlir::SmallVector projected_dims{ - block_offsets.getResult(kColMajorKept), - block_offsets.getResult(kColMinorKept) + thread_ids[kColReduced]}; - std::vector range_vars; - if (thread_ids.size() == 4) { - int vector_size = tiling_.GetThreadTileSize().back(); - range_vars.push_back({0, vector_size - 1}); - projected_dims.push_back(mlir::getAffineSymbolExpr(0, ctx)); - } - IndexingMap projected_index( - mlir::AffineMap::get(6, range_vars.size(), projected_dims, ctx), - dimension_ranges, range_vars, /*rt_vars=*/{}); - - projected_index.AddConstraint( - mlir::getAffineDimExpr( - KernelFusionInterface::kIndexingMapThreadIdxDims[0], ctx) % - WarpSize(), - {0, 0}); - if (!is_row_reduction_) { - projected_index.AddConstraint( - projected_index.GetAffineMap().getResult(1), - {0, tiling_.GetShape()[ReductionDimensions::kColMinorKeptDimension] - - 1}); - } - - return ComposeIndexingMaps( - projected_index, - GetBitcastMap(ShapeUtil::DeleteDimension( - ReductionDimensions::kColReducedDimension, - tiling_.GetXlaShape()), - physical_shape, ctx)); - }(); - - AddGroupIdConstraint(map, root_index, groups_); - map.Simplify(); - return map; -} - -std::optional ReductionInfo::ComputeThreadIdToInputIndexing( - int64_t root_index, int64_t hero_operand_index, - mlir::MLIRContext* ctx) const { - const auto& hero = analysis_.fusion_hero(root_index).instruction(); - if (groups_.is_reduction_root[root_index] && - hero_operand_index >= hero.operand_count() / 2) { - // We don't have indexing for the init values. 
- return std::nullopt; - } - if (!groups_.is_reduction_root[root_index]) { - return ComposeIndexingMaps( - *ComputeThreadIdToOutputIndexing(root_index, ctx), - *ComputeOutputToInputIndexing( - &analysis_.fusion_root(root_index).instruction(), 0, ctx) - .indexing_maps[hero_operand_index] - .begin()); - } - - auto map = ComposeIndexingMaps( - GetIndexingMapForTiling(tiling_, ctx), - GetBitcastMap(tiling_.GetXlaShape(), - hero.operand(hero_operand_index)->shape(), ctx)); - AddGroupIdConstraint(map, root_index, groups_); - map.Simplify(); - return map; -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/reduction.h b/third_party/xla/xla/service/gpu/fusions/legacy/reduction.h deleted file mode 100644 index 131b4ec38c7693..00000000000000 --- a/third_party/xla/xla/service/gpu/fusions/legacy/reduction.h +++ /dev/null @@ -1,190 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#ifndef XLA_SERVICE_GPU_FUSIONS_LEGACY_REDUCTION_H_ -#define XLA_SERVICE_GPU_FUSIONS_LEGACY_REDUCTION_H_ - -#include -#include -#include - -#include "absl/log/check.h" -#include "absl/status/status.h" -#include "llvm/IR/IRBuilder.h" -#include "xla/hlo/ir/hlo_instruction.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/service/gpu/fusions/fusion_emitter.h" -#include "xla/service/gpu/fusions/legacy/tiling_util.h" -#include "xla/service/gpu/fusions/reduction_base.h" -#include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/gpu/ir_emitter_context.h" -#include "xla/service/gpu/launch_dimensions.h" -#include "xla/service/llvm_ir/ir_array.h" -#include "xla/shape.h" - -namespace xla { -namespace gpu { - -class ReductionInfo { - public: - static ReductionInfo Create(const HloFusionAnalysis& analysis); - - const Tiling& GetTiling() const { return tiling_; } - const ReductionGroups& GetGroups() const { return groups_; } - Shape GetReduceOperandShape() const { - return first_reduce_->operand(0)->shape(); - } - - bool IsRowReduction() const { return is_row_reduction_; } - bool IsRaceFree() const { return is_race_free_; } - int GetRowsPerWarp() const; - - std::optional ComputeThreadIdToOutputIndexing( - int64_t root_index, mlir::MLIRContext* ctx) const; - - std::optional ComputeThreadIdToInputIndexing( - int64_t root_index, int64_t hero_operand_index, - mlir::MLIRContext* ctx) const; - - LaunchDimensions launch_dimensions() const; - - private: - ReductionInfo(const HloFusionAnalysis& analysis, Tiling tiling, - bool is_row_reduction, bool is_race_free, - ReductionGroups groups, const HloInstruction* first_reduce) - : analysis_(analysis), - tiling_(tiling), - is_row_reduction_(is_row_reduction), - is_race_free_(is_race_free), - groups_(std::move(groups)), - first_reduce_(first_reduce) {} - - const HloFusionAnalysis& analysis_; - Tiling tiling_; - bool is_row_reduction_; - 
bool is_race_free_; - ReductionGroups groups_; - const HloInstruction* first_reduce_; -}; - -// Generates code for reduction to contiguous dimensions. -// -// Row reduction uses the following algorithm described in CUDA-like -// pseudocode: -// -// ``` -// __global__ void reduce(int num_rows, float *in, float out) { -// __shared__ float[32] cache; -// int offset = blockDim.x * blockIdx.x + threadIdx.x; -// if (offset >= num_rows) return; -// int tile_bound = std::min(offset + kTileSizeX, num_rows); -// float accum = 0; -// for (int i=offset; i ComputeThreadIdToOutputIndexing( - int64_t root_index, mlir::MLIRContext* ctx) const override { - return reduction_info_.ComputeThreadIdToOutputIndexing(root_index, ctx); - } - - std::optional ComputeThreadIdToInputIndexing( - int64_t root_index, int64_t hero_operand_index, - mlir::MLIRContext* ctx) const override { - return reduction_info_.ComputeThreadIdToInputIndexing( - root_index, hero_operand_index, ctx); - } - - LaunchDimensions launch_dimensions() const override { - return reduction_info_.launch_dimensions(); - } - - const ReductionInfo& reduction_info() const { return reduction_info_; } - - protected: - absl::StatusOr EmitInitializers( - IrEmitterContext& ir_emitter_context, - const HloFusionInstruction& fusion) const override; - - absl::Status EmitKernel(IrEmitterContext& ir_emitter_context, - const HloFusionInstruction& fusion, - const LaunchDimensions& launch_dims, - std::vector inputs, - std::vector outputs, - llvm::IRBuilder<>* builder) const override; - - private: - const HloFusionAnalysis& analysis_; - ReductionInfo reduction_info_; -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_FUSIONS_LEGACY_REDUCTION_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/reduction_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/reduction_test.cc deleted file mode 100644 index 144159ce442424..00000000000000 --- a/third_party/xla/xla/service/gpu/fusions/legacy/reduction_test.cc +++ 
/dev/null @@ -1,178 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/fusions/legacy/reduction.h" - -#include - -#include -#include -#include "absl/status/status.h" -#include "absl/status/statusor.h" -#include "mlir/IR/MLIRContext.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/service/gpu/fusions/fusion_emitter.h" -#include "xla/service/gpu/gpu_device_info_for_tests.h" -#include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/gpu/ir_emitter_context.h" -#include "xla/service/gpu/model/indexing_analysis.h" -#include "xla/service/gpu/model/indexing_test_utils.h" -#include "xla/stream_executor/device_description.h" -#include "xla/tests/hlo_test_base.h" - -namespace xla { -namespace gpu { -namespace { - -using ::testing::ElementsAre; -using ::testing::SizeIs; - -class ReductionTest : public HloTestBase { - protected: - stream_executor::DeviceDescription device_info_ = - TestGpuDeviceInfo::RTXA6000DeviceInfo(); - mlir::MLIRContext mlir_context_; -}; - -TEST_F(ReductionTest, ThreadIndexingRowReduction) { - auto module = ParseAndReturnVerifiedModule(R"( - HloModule module - - add { - p0 = f32[] parameter(0) - p1 = f32[] parameter(1) - ROOT add = f32[] add(p0, p1) - } - - fusion { - %input = f32[100,64,512] parameter(0) - %c0 = f32[] constant(0) - ROOT reduce = f32[100,64] reduce(%input, 
%c0), dimensions={2}, to_apply=add - } - - ENTRY entry { - %input = f32[100,64,512] parameter(0) - ROOT %fusion = f32[100,64] fusion(%input), kind=kInput, calls=fusion - })") - .value(); - - auto* root = module->entry_computation()->root_instruction(); - auto analysis = HloFusionAnalysis::Create(*root, device_info_); - ReductionFusion fusion(analysis); - - EXPECT_THAT( - fusion.ComputeThreadIdToInputIndexing(0, 0, &mlir_context_)->ToString(), - MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3] -> ( - d3 floordiv 8, - (d3 mod 8) * 8 + d0 floordiv 32, - (d0 mod 32) * 2 + s2 * 64 + s3 - ), - domain: - d0 in [0, 255], - d1 in [0, 0], - d2 in [0, 0], - d3 in [0, 799], - d4 in [0, 0], - d5 in [0, 0], - s0 in [0, 0], - s1 in [0, 0], - s2 in [0, 7], - s3 in [0, 1], - is_simplified: true - )")); - EXPECT_THAT( - fusion.ComputeThreadIdToOutputIndexing(0, &mlir_context_)->ToString(), - MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5) -> ( - d3 floordiv 8, - (d3 mod 8) * 8 + d0 floordiv 32 - ), - domain: - d0 in [0, 224], - d1 in [0, 0], - d2 in [0, 0], - d3 in [0, 799], - d4 in [0, 0], - d5 in [0, 0], - d0 mod 32 in [0, 0], - is_simplified: true - )")); -} - -TEST_F(ReductionTest, TwoGroups) { - auto module = ParseAndReturnVerifiedModule(R"( - add { - p0 = f32[] parameter(0) - p1 = f32[] parameter(1) - ROOT add = f32[] add(p0, p1) - } - fusion { - %p0 = f32[2] parameter(0) - %p1 = f32[2] parameter(1) - %c0 = f32[] constant(-inf) - %r0 = f32[] reduce(%p0, %c0), dimensions={0}, to_apply=add - %c1 = f32[] constant(inf) - %r1 = f32[] reduce(%p1, %c1), dimensions={0}, to_apply=add - ROOT %tuple = (f32[], f32[]) tuple(%r0, %r1) - } - ENTRY entry { - %p0 = f32[2] parameter(0) - %p1 = f32[2] parameter(1) - ROOT %fusion = (f32[], f32[]) fusion(%p0, %p1), kind=kInput, calls=fusion - })") - .value(); - - auto* root = module->entry_computation()->root_instruction(); - auto analysis = HloFusionAnalysis::Create(*root, device_info_); - ReductionFusion fusion(analysis); - 
- EXPECT_THAT(fusion.reduction_info().GetGroups().grouped_roots, - ElementsAre(ElementsAre(&analysis.fusion_root(0).instruction()), - ElementsAre(&analysis.fusion_root(1).instruction()))); -} - -TEST_F(ReductionTest, OneGroup) { - auto module = ParseAndReturnVerifiedModule(R"( - %add { - %p0 = c128[] parameter(0) - %p1 = c128[] parameter(1) - ROOT %add.35 = c128[] add(c128[] %p0, c128[] %p1) - } - %fusion { - %p0 = c128[1,2] parameter(0) - %c0 = c128[] constant((0, 0)) - %reduce = c128[] reduce(%p0, %c0), dimensions={0,1}, to_apply=%add - %real = f64[] real(c128[] %reduce) - %imag = f64[] imag(c128[] %reduce) - %negate = f64[] negate(f64[] %imag) - ROOT %tuple.29 = (f64[], f64[]) tuple(f64[] %real, f64[] %negate) - } - ENTRY entry { - %p0 = c128[1,2] parameter(0) - ROOT %fusion = (f64[], f64[]) fusion(%p0), kind=kInput, calls=fusion - })") - .value(); - - auto* root = module->entry_computation()->root_instruction(); - auto analysis = HloFusionAnalysis::Create(*root, device_info_); - ReductionFusion fusion(analysis); - - EXPECT_THAT(fusion.reduction_info().GetGroups().grouped_roots, SizeIs(2)); -} - -} // namespace -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/scatter.cc b/third_party/xla/xla/service/gpu/fusions/legacy/scatter.cc deleted file mode 100644 index 07987886a73120..00000000000000 --- a/third_party/xla/xla/service/gpu/fusions/legacy/scatter.cc +++ /dev/null @@ -1,294 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "xla/service/gpu/fusions/legacy/scatter.h" - -#include -#include -#include -#include -#include - -#include "absl/log/check.h" -#include "absl/status/status.h" -#include "absl/types/span.h" -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/Value.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/AffineMap.h" -#include "mlir/IR/MLIRContext.h" -#include "xla/hlo/ir/hlo_casting_utils.h" -#include "xla/hlo/ir/hlo_computation.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/gpu/elemental_ir_emitter.h" -#include "xla/service/gpu/fusions/legacy/loop.h" -#include "xla/service/gpu/gpu_fusible.h" -#include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/gpu/ir_emission_utils.h" -#include "xla/service/gpu/ir_emitter_context.h" -#include "xla/service/gpu/ir_emitter_nested.h" -#include "xla/service/gpu/launch_dimensions.h" -#include "xla/service/gpu/model/indexing_analysis.h" -#include "xla/service/gpu/model/indexing_map.h" -#include "xla/service/gpu/parallel_loop_emitter.h" -#include "xla/service/llvm_ir/fused_ir_emitter.h" -#include "xla/service/llvm_ir/ir_array.h" -#include "xla/service/llvm_ir/llvm_util.h" -#include "xla/shape.h" -#include "xla/shape_util.h" -#include "xla/xla_data.pb.h" -#include "tsl/platform/statusor.h" - -namespace xla { -namespace gpu { - -ScatterFusion::ScatterFusion(const HloFusionAnalysis& analysis) - : analysis_(analysis), config_(ComputeLoopFusionConfig(analysis)) { - CHECK_EQ(analysis.fusion_root_count(), 1); - CHECK_EQ(analysis.fusion_root(0).opcode(), HloOpcode::kScatter); -} - -LaunchDimensions ScatterFusion::launch_dimensions() const { - const auto& updates_shape = - 
analysis_.fusion_root(0).instruction().operands().back()->shape(); - return CalculateLaunchDimensions(updates_shape, analysis_.device_info()); -} - -absl::Status ScatterFusion::EmitKernel(IrEmitterContext& ir_emitter_context, - const HloFusionInstruction& fusion, - const LaunchDimensions& launch_dims, - std::vector inputs, - std::vector outputs, - llvm::IRBuilder<>* builder) const { - GpuElementalIrEmitter elemental_emitter(ir_emitter_context, builder); - // Spin up a new fused emitter for the scatter kernel and emit it. - FusedIrEmitter scatter_fused_emitter(elemental_emitter); - auto* fused_computation = fusion.fused_instructions_computation(); - for (int i = 0; i < fused_computation->num_parameters(); i++) { - auto fused_operand = fused_computation->parameter_instruction(i); - scatter_fused_emitter.BindGenerator( - *fused_operand, [builder, &input = inputs[i], - fused_operand](llvm_ir::IrArray::Index index) { - return input.EmitReadArrayElement(index, builder, - fused_operand->name()); - }); - } - - auto* root = fused_computation->root_instruction(); - const xla::ScatterDimensionNumbers& scatter_dims = - Cast(root)->scatter_dimension_numbers(); - - std::string name = llvm_ir::IrName(root); - const Shape& operand_shape = root->operand(0)->shape(); - const Shape& scatter_indices_shape = root->operand(1)->shape(); - const Shape& updates_shape = root->operand(2)->shape(); - const HloComputation& update_computation = *root->called_computations()[0]; - - TF_ASSIGN_OR_RETURN(auto scatter_indices_gen, - scatter_fused_emitter.GetGenerator(*root->operand(1))); - TF_ASSIGN_OR_RETURN(auto updates_gen, - scatter_fused_emitter.GetGenerator(*root->operand(2))); - - auto loop_body_emitter = - [&](const llvm_ir::IrArray::Index& index) -> absl::Status { - std::vector raw_window_multidim; - std::vector input_scatter_multidim; - std::vector raw_window_bounds; - - auto get_i64_array = [](absl::Span container) { - return llvm::ArrayRef{container.data(), - 
static_cast(container.size())}; - }; - - llvm::ArrayRef update_window_dims = - get_i64_array(scatter_dims.update_window_dims()); - // Partition the index into window indices and scatter indices. - for (int64_t i = 0, e = index.size(); i != e; ++i) { - // For window indices also remember the window size, this comes in handy - // later. - if (llvm::is_contained(update_window_dims, i)) { - raw_window_multidim.push_back(index[i]); - raw_window_bounds.push_back(updates_shape.dimensions(i)); - } else { - input_scatter_multidim.push_back(index[i]); - } - } - DCHECK_EQ(raw_window_multidim.size(), - scatter_dims.update_window_dims_size()); - - // Apply inserted_window_dims to the window dimensions. - int64_t raw_window_multidim_idx = 0; - llvm::SmallVector input_window_multidim; - llvm::SmallVector input_window_bounds; - const int64_t rank = operand_shape.rank(); - input_window_bounds.reserve(rank); - input_window_multidim.reserve(rank); - - llvm::ArrayRef inserted_window_dims = - get_i64_array(scatter_dims.inserted_window_dims()); - for (int64_t i = 0; i != rank; ++i) { - if (llvm::is_contained(inserted_window_dims, i)) { - input_window_bounds.push_back(1); // Trivial dimension. - input_window_multidim.push_back(index.GetConstantWithIndexType(0)); - } else { - input_window_bounds.push_back( - raw_window_bounds[raw_window_multidim_idx]); - input_window_multidim.push_back( - raw_window_multidim[raw_window_multidim_idx]); - ++raw_window_multidim_idx; - } - } - DCHECK_EQ(input_window_multidim.size(), operand_shape.rank()); - - // Insert a 1 dimension at the end if index_vector_dim requests one. - Shape scatter_indices_shape_fixed = scatter_indices_shape; - if (scatter_dims.index_vector_dim() == scatter_indices_shape.rank()) { - scatter_indices_shape_fixed.add_dimensions(1); - scatter_indices_shape_fixed.mutable_layout()->add_minor_to_major( - scatter_dims.index_vector_dim()); - } - - // Now load the indices corresponding to the current window from - // scatter_indices. 
- std::vector raw_scatter_index_multidim = - input_scatter_multidim; - raw_scatter_index_multidim.insert( - raw_scatter_index_multidim.begin() + scatter_dims.index_vector_dim(), - nullptr); - - llvm::ArrayRef scatter_dims_to_operand_dims = - get_i64_array(scatter_dims.scatter_dims_to_operand_dims()); - llvm::Value* is_in_bounds = builder->getTrue(); - for (int64_t i = 0, e = scatter_dims_to_operand_dims.size(); i != e; ++i) { - // Our index is stored along index_vector_dim, insert that into the lookup - // index into scatter_indices. - raw_scatter_index_multidim[scatter_dims.index_vector_dim()] = - index.GetConstantWithIndexType(i); - llvm_ir::IrArray::Index raw_scatter_index_index( - raw_scatter_index_multidim, scatter_indices_shape_fixed, - index.GetType()); - - int64_t operand_dim = scatter_dims_to_operand_dims[i]; - if (operand_dim > rank) { - return absl::OutOfRangeError( - "The provided scatter_dims_to_operand_dims was out of range."); - } - TF_ASSIGN_OR_RETURN( - llvm::Value* const loaded_scatter_index, - scatter_indices_gen(raw_scatter_index_index.SourceIndexOfReshape( - scatter_indices_shape_fixed, scatter_indices_shape, builder))); - // And add the index to our window index. This yields the output index. - llvm::Value* casted_scatter_index = builder->CreateIntCast( - loaded_scatter_index, index.GetType(), - /*isSigned=*/ShapeUtil::ElementIsSigned(scatter_indices_shape)); - llvm::Value* dim_offset = builder->CreateAdd( - input_window_multidim[operand_dim], casted_scatter_index); - input_window_multidim[operand_dim] = dim_offset; - - // Also do the bounds check now. 
- int64_t max_index = operand_shape.dimensions(operand_dim) - - input_window_bounds[operand_dim] + 1; - // is_in_bounds = index >= 0 && index < dim_size-window_size+1 - // --> index u< dim_size-window_size+1 - is_in_bounds = builder->CreateAnd( - is_in_bounds, - builder->CreateICmpULT(casted_scatter_index, - index.GetConstantWithIndexType(max_index))); - } - - llvm_ir::LlvmIfData if_window_in_bounds_data = llvm_ir::EmitIfThenElse( - is_in_bounds, "scatter.in_bounds", builder, /*emit_else=*/false); - llvm_ir::SetToFirstInsertPoint(if_window_in_bounds_data.true_block, - builder); - // All done, now just read from the calculated input from the window, and do - // an atomic store to the calculated location in the output. - llvm_ir::IrArray::Index input_window_index( - input_window_multidim, outputs.back().GetShape(), index.GetType()); - llvm::Value* output_address = - outputs.back().EmitArrayElementAddress(input_window_index, builder); - llvm::Value* input_address = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(updates_shape.element_type(), - ir_emitter_context.llvm_module()), - "input_address", builder); - TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, updates_gen(index)); - builder->CreateStore(input_ir_value, input_address); - - if (root->unique_indices()) { - return CallNestedComputation( - builder, ir_emitter_context, update_computation, - {output_address, input_address}, output_address); - } - return EmitAtomicOperationForNestedComputation( - builder, ir_emitter_context, update_computation, output_address, - input_address, outputs.back().GetElementLlvmType()); - }; - - // Launch a kernel that reads every element in the updates tensor. We could - // also do one kernel per window instead if bounds checks turn out to be a - // bottleneck. 
- auto index_type = - GetIndexTypeForKernel(root, launch_dims.launch_bound(), builder); - return ParallelLoopEmitter(loop_body_emitter, updates_shape, launch_dims, - builder) - .EmitLoop(name, index_type); -} - -std::optional ScatterFusion::ComputeThreadIdToInputIndexing( - int64_t root_index, int64_t hero_operand_index, - mlir::MLIRContext* ctx) const { - const auto* scatter = - DynCast(&analysis_.fusion_hero(0).instruction()); - int64_t scatter_operand_count = scatter->scatter_operand_count(); - // Scatter operands a packed in the following way: - // Operand IDs [0, scatter_operand_count - 1] for `scatter operands`. - // Operand ID scatter_operand_count for `scatter indices`. - // Operand IDs [scatter_operand_count + 1, 2 * scatter_operand_count] for - // `scatter updates`. - - // For scatter operands we do not know the thread ID indexing. - if (hero_operand_index < scatter_operand_count) { - return std::nullopt; - } - // Compute thread id mapping based on the first update operand. - Shape scatter_update_shape = scatter->scatter_updates().front()->shape(); - IndexingMap scatter_update_map = GetDefaultThreadIdIndexingMap( - launch_dimensions(), config_.unroll_factor, scatter_update_shape, ctx); - - // For scatter indices we project indexing for scatter updates and take the - // first result of the affine map only, because they coincide. - if (hero_operand_index == scatter_operand_count) { - Shape scatter_indices_shape = scatter->scatter_indices()->shape(); - CHECK_EQ(scatter_indices_shape.rank(), 2) << scatter->ToString(); - // Create a map from scatter update to scatter indices. 
- IndexingMap updates_to_indices_map{ - mlir::AffineMap::get( - /*dimCount=*/scatter_update_shape.rank(), /*symbolCount=*/1, - {mlir::getAffineDimExpr(0, ctx), mlir::getAffineSymbolExpr(0, ctx)}, - ctx), - DimVarsFromTensorSizes(scatter_update_shape.dimensions()), - RangeVarsFromTensorSizes({scatter_indices_shape.dimensions(1)}), - /*rt_vars=*/{}}; - auto scatter_indices_map = scatter_update_map * updates_to_indices_map; - scatter_indices_map.Simplify(); - return scatter_indices_map; - } - return scatter_update_map; -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/scatter.h b/third_party/xla/xla/service/gpu/fusions/legacy/scatter.h deleted file mode 100644 index 862d0b3543b4ad..00000000000000 --- a/third_party/xla/xla/service/gpu/fusions/legacy/scatter.h +++ /dev/null @@ -1,71 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#ifndef XLA_SERVICE_GPU_FUSIONS_LEGACY_SCATTER_H_ -#define XLA_SERVICE_GPU_FUSIONS_LEGACY_SCATTER_H_ - -#include -#include - -#include "absl/log/check.h" -#include "absl/status/status.h" -#include "llvm/IR/IRBuilder.h" -#include "mlir/IR/MLIRContext.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/service/gpu/fusions/fusion_emitter.h" -#include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/gpu/ir_emitter_context.h" -#include "xla/service/gpu/launch_dimensions.h" -#include "xla/service/gpu/model/indexing_map.h" -#include "xla/service/llvm_ir/ir_array.h" - -namespace xla { -namespace gpu { - -// A scatter, implemented as a loop over the updates. All scatters are in-place. -class ScatterFusion : public KernelFusionEmitterBase { - public: - explicit ScatterFusion(const HloFusionAnalysis& analysis); - - LaunchDimensions launch_dimensions() const override; - - std::optional ComputeThreadIdToOutputIndexing( - int64_t root_index, mlir::MLIRContext* ctx) const override { - // The kernel iterates over updates, whose correspondence to output - // elements cannot be computed statically. 
- return std::nullopt; - } - - std::optional ComputeThreadIdToInputIndexing( - int64_t root_index, int64_t hero_operand_index, - mlir::MLIRContext* ctx) const override; - - protected: - absl::Status EmitKernel(IrEmitterContext& ir_emitter_context, - const HloFusionInstruction& fusion, - const LaunchDimensions& launch_dims, - std::vector inputs, - std::vector outputs, - llvm::IRBuilder<>* builder) const override; - - private: - const HloFusionAnalysis& analysis_; - LaunchDimensionsConfig config_; -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_FUSIONS_LEGACY_SCATTER_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/scatter_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/scatter_test.cc deleted file mode 100644 index 8c6674d4a2b546..00000000000000 --- a/third_party/xla/xla/service/gpu/fusions/legacy/scatter_test.cc +++ /dev/null @@ -1,226 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include "xla/service/gpu/fusions/legacy/scatter.h" - -#include - -#include -#include -#include "mlir/IR/MLIRContext.h" -#include "xla/service/gpu/fusions/fusions.h" -#include "xla/service/gpu/gpu_device_info_for_tests.h" -#include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/gpu/model/affine_map_printer.h" -#include "xla/service/gpu/model/indexing_test_utils.h" -#include "xla/stream_executor/device_description.h" -#include "xla/tests/hlo_test_base.h" -#include "tsl/platform/statusor.h" - -namespace xla { -namespace gpu { -namespace { - -class ScatterFusionTest : public HloTestBase { - public: - void SetUp() override { - HloTestBase::SetUp(); - printer_ = - AffineMapPrinter({"th_x", "th_y", "th_z", "bl_x", "bl_y", "bl_z"}, - {"chunk_id", "unroll_id", "index_id"}); - } - DebugOptions GetDebugOptionsForTest() override { - auto opts = HloTestBase::GetDebugOptionsForTest(); - opts.set_xla_gpu_mlir_emitter_level(0); - return opts; - } - - protected: - AffineMapPrinter printer_; - mlir::MLIRContext mlir_context_; -}; - -TEST_F(ScatterFusionTest, ScatterFusion) { - auto module = ParseAndReturnVerifiedModule(R"( - HloModule module - - add (lhs: f32[], rhs: f32[]) -> f32[] { - lhs = f32[] parameter(0) - rhs = f32[] parameter(1) - ROOT sum = f32[] add(lhs, rhs) - } - - fused_computation { - %input = f32[2,9] parameter(0) - %indices = s32[3] parameter(1) - %updates = f32[3,9] parameter(2) - ROOT %scatter = f32[2,9] scatter(%input, %indices, %updates), - to_apply=add, - update_window_dims={1}, - inserted_window_dims={0}, - scatter_dims_to_operand_dims={0}, - index_vector_dim=1 - } - - ENTRY entry { - %input = f32[2,9] parameter(0) - %indices = s32[3] parameter(1) - %updates = f32[3,9] parameter(2) - ROOT %fusion = f32[2,9] fusion(%input, %indices, %updates), kind=kLoop, calls=fused_computation - })") - .value(); - - stream_executor::DeviceDescription device_info = - 
TestGpuDeviceInfo::RTXA6000DeviceInfo(); - - auto* root = module->entry_computation()->root_instruction(); - auto analysis_fused = HloFusionAnalysis::Create(*root, device_info); - - auto emitter = - GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused}); - auto scatter_fusion = dynamic_cast(emitter.get()); - ASSERT_NE(scatter_fusion, nullptr); - EXPECT_EQ(scatter_fusion->launch_dimensions().launch_bound(), - 3 * 9 /* updates size */); -} - -TEST_F(ScatterFusionTest, ThreadIdIndexing) { - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(R"( - HloModule module - - computation { - %p0 = f32[] parameter(0) - %p1 = f32[] parameter(1) - %p2 = f32[] parameter(2) - %p3 = f32[] parameter(3) - ROOT %tuple = (f32[], f32[]) tuple(f32[] %p2, f32[] %p3) - } - scatter { - %operand0 = f32[300,200] parameter(0) - %operand1 = f32[300,200] parameter(1) - %indices = s32[42,1] parameter(2) - %update.1 = f32[42,10,20] parameter(3) - %update.2 = f32[42,10,20]parameter(4) - - ROOT %scatter = (f32[300,200], f32[300,200]) scatter( - f32[300,200] %operand0, - f32[300,200] %operand1, - s32[42,1] %indices, - f32[42,10,20] %update.1, - f32[42,10,20] %update.2 - ), - update_window_dims={1,2}, - inserted_window_dims={}, - scatter_dims_to_operand_dims={0}, - index_vector_dim=1, - to_apply=computation - } - ENTRY entry { - %operand0 = f32[300,200] parameter(0) - %operand1 = f32[300,200] parameter(1) - %indices = s32[42,1] parameter(2) - %update.1 = f32[42,10,20] parameter(3) - %update.2 = f32[42,10,20]parameter(4) - ROOT %fusion = (f32[300,200], f32[300,200]) fusion( - %operand0, %operand1, %indices, %update.1, %update.2), - kind=kLoop, calls=scatter - } - )")); - stream_executor::DeviceDescription device_info = - TestGpuDeviceInfo::RTXA6000DeviceInfo(); - - auto* root = module->entry_computation()->root_instruction(); - auto analysis_fused = HloFusionAnalysis::Create(*root, device_info); - - auto emitter = - GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused}); 
- auto fusion = dynamic_cast(emitter.get()); - ASSERT_NE(fusion, nullptr); - - constexpr auto kUpdatesIndexing = R"( - (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id] -> ( - (bl_x * 128 + th_x) floordiv 200, - ((bl_x * 128 + th_x) floordiv 20) mod 10, - (bl_x * 128 + th_x) mod 20 - ), - domain: - th_x in [0, 127], - th_y in [0, 0], - th_z in [0, 0], - bl_x in [0, 65], - bl_y in [0, 0], - bl_z in [0, 0], - chunk_id in [0, 0], - unroll_id in [0, 0], - bl_x * 128 + th_x in [0, 8399], - is_simplified: true - )"; - EXPECT_THAT( - fusion - ->ComputeThreadIdToInputIndexing( - /*root_index=*/0, /*hero_operand_index=*/3, &mlir_context_) - ->ToString(printer_), - MatchIndexingString(kUpdatesIndexing)); - EXPECT_THAT( - fusion - ->ComputeThreadIdToInputIndexing( - /*root_index=*/0, /*hero_operand_index=*/4, &mlir_context_) - ->ToString(printer_), - MatchIndexingString(kUpdatesIndexing)); - EXPECT_THAT( - fusion - ->ComputeThreadIdToInputIndexing( - /*root_index=*/1, /*hero_operand_index=*/3, &mlir_context_) - ->ToString(printer_), - MatchIndexingString(kUpdatesIndexing)); - EXPECT_THAT( - fusion - ->ComputeThreadIdToInputIndexing( - /*root_index=*/1, /*hero_operand_index=*/4, &mlir_context_) - ->ToString(printer_), - MatchIndexingString(kUpdatesIndexing)); - - constexpr auto kIndicesIndexing = R"( - (th_x, th_y, th_z, bl_x, bl_y, bl_z)[chunk_id, unroll_id, index_id] -> - ((bl_x * 128 + th_x) floordiv 200, 0), - domain: - th_x in [0, 127], - th_y in [0, 0], - th_z in [0, 0], - bl_x in [0, 65], - bl_y in [0, 0], - bl_z in [0, 0], - chunk_id in [0, 0], - unroll_id in [0, 0], - index_id in [0, 0], - bl_x * 128 + th_x in [0, 8399], - is_simplified: true - )"; - EXPECT_THAT( - fusion - ->ComputeThreadIdToInputIndexing( - /*root_index=*/0, /*hero_operand_index=*/2, &mlir_context_) - ->ToString(printer_), - MatchIndexingString(kIndicesIndexing)); - EXPECT_THAT( - fusion - ->ComputeThreadIdToInputIndexing( - /*root_index=*/1, /*hero_operand_index=*/2, &mlir_context_) - 
->ToString(printer_), - MatchIndexingString(kIndicesIndexing)); -} - -} // namespace -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/tiling_util.cc b/third_party/xla/xla/service/gpu/fusions/legacy/tiling_util.cc deleted file mode 100644 index a1a7acb58388a7..00000000000000 --- a/third_party/xla/xla/service/gpu/fusions/legacy/tiling_util.cc +++ /dev/null @@ -1,351 +0,0 @@ -/*Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "xla/service/gpu/fusions/legacy/tiling_util.h" - -#include -#include -#include -#include - -#include "absl/container/inlined_vector.h" -#include "absl/log/check.h" -#include "absl/strings/str_cat.h" -#include "absl/types/span.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/IR/Constants.h" -#include "llvm/IR/DerivedTypes.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/Instructions.h" -#include "llvm/IR/Value.h" -#include "llvm/Support/Casting.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/AffineMap.h" -#include "mlir/IR/MLIRContext.h" -#include "xla/service/gpu/ir_emission_utils.h" -#include "xla/service/gpu/model/indexing_analysis.h" -#include "xla/service/gpu/target_util.h" -#include "xla/service/llvm_ir/ir_array.h" -#include "xla/service/llvm_ir/kernel_support_library.h" -#include "xla/service/llvm_ir/llvm_loop.h" -#include "xla/service/llvm_ir/llvm_util.h" -#include 
"xla/shape_util.h" -#include "xla/util.h" -#include "xla/xla_data.pb.h" -#include "tsl/platform/statusor.h" - -namespace xla { -namespace gpu { -namespace { - -using mlir::AffineExpr; -using mlir::AffineMap; -using mlir::MLIRContext; - -void EmitTileRec(const TilingThreadIdInfo& thread_id_info, const Tiling& tiling, - int dim, absl::InlinedVector tile_idx, - absl::Span tile_dimensions, - llvm::IRBuilder<>* b, const TileElementGenerator& emit_elem) { - llvm::Type* index_ty = thread_id_info.thread_id->getType(); - auto constant = [&](int64_t val) { - return llvm::ConstantInt::get(index_ty, val); - }; - - auto recurse = [&] { - if (dim == tile_idx.size() - 1) { - emit_elem(tile_idx); - } else { - EmitTileRec(thread_id_info, tiling, dim + 1, tile_idx, tile_dimensions, b, - emit_elem); - } - }; - - bool unroll = tiling.GetLoopsToUnroll()[dim]; - KernelSupportLibrary ksl(b, unroll ? llvm_ir::UnrollMode::kFullyUnroll - : llvm_ir::UnrollMode::kDefaultUnroll); - - if (tiling.GetBlockTileSize()[dim] == 1) { - tile_idx[dim] = constant(0); - recurse(); - } else if (unroll) { - // TODO(jreiffers): Check if this unrolling does anything useful. - int64_t stride = tiling.GetThreadsPerBlock()[dim]; - int64_t dim_size = tiling.GetThreadTileSize()[dim]; - - auto make_loop = [&](bool emit_bounds_checks) { - auto body = [&, emit_bounds_checks](llvm::Value* i) { - tile_idx[dim] = b->CreateAdd(i, thread_id_info.thread_ids[dim]); - if (emit_bounds_checks) { - auto* in_bounds = - b->CreateICmpULT(tile_idx[dim], tile_dimensions[dim]); - ksl.If("x_in_tile", in_bounds, recurse); - } else { - recurse(); - } - }; - return [&, body] { - ksl.For(absl::StrCat("loop", dim), constant(0), - constant(dim_size * stride), constant(stride), body); - }; - }; - if (stride > 1 && dim_size > 1) { - // Most tiles will be full, so we emit a single bounds check for those. 
- auto* is_full_tile = b->CreateICmpEQ( - constant(tiling.GetBlockTileSize()[dim]), tile_dimensions[dim]); - ksl.If("is_full_tile", is_full_tile, make_loop(false), make_loop(true)); - } else { - make_loop(true)(); - } - } else { - // All dimensions are strided (thread 0 processes elements 0, num_threads, - // num_threads+2, ...; thread 1 processes elements 1, num_threads + 1 and so - // on). - ksl.For(absl::StrCat("loop", dim), /*start=*/thread_id_info.thread_ids[dim], - /*end=*/tile_dimensions[dim], - /*step=*/tiling.GetThreadsPerBlock()[dim], [&](llvm::Value* i) { - tile_idx[dim] = i; - recurse(); - }); - } -} - -} // namespace - -void EmitTile(llvm::IRBuilder<>* builder, const Tiling& tiling, - const TilingThreadIdInfo& thread_id_info, - absl::Span tile_dimensions, - const TileElementGenerator& emit_elem_function) { - absl::InlinedVector tile_idx(tiling.GetShape().size()); - EmitTileRec(thread_id_info, tiling, 0, tile_idx, tile_dimensions, builder, - emit_elem_function); -} - -namespace { - -// Emits current block id. -llvm::Value* EmitBlockId(llvm::IRBuilder<>* builder, int32_t num_blocks, - llvm::Type* index_ty) { - llvm::Value* block_id = - EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, builder); - if (num_blocks != 0) { - llvm_ir::AddRangeMetadata(0, num_blocks, - llvm::cast(block_id), - builder->GetInsertBlock()->getModule()); - } - auto ret = builder->CreateIntCast(block_id, index_ty, /*isSigned=*/true); - ret->setName("block.id.x"); - return ret; -} - -// Emits current thread id with the given type. -// -// Sets the return value range to [0, threads_per_block). -llvm::Value* EmitThreadId(llvm::IRBuilder<>* builder, int64_t threads_per_block, - llvm::Type* index_ty) { - // Calculate (y, x) coordinates respectively in the 2D view of thread block, - // defined by (num_thread_y, num_thread_x) from thread_id. 
- llvm::CallInst* thread_id = - EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, builder); - llvm_ir::AddRangeMetadata(0, threads_per_block, thread_id, - builder->GetInsertBlock()->getModule()); - auto ret = builder->CreateIntCast(thread_id, index_ty, /*isSigned=*/true); - ret->setName("thread.id.x"); - return ret; -} - -// Emits the LLVM values for thread_id, block_id, coordinates of the current -// tile and strides of the loops to iterate over the current tile. -absl::StatusOr EmitThreadIdInfo(llvm::IRBuilder<>* builder, - const Tiling& tiling, - llvm::Type* index_ty) { - auto constant = [&](uint64_t c) -> llvm::Constant* { - return llvm::ConstantInt::get(index_ty, c); - }; - int64_t num_blocks = tiling.GetNumBlocks(); - if (num_blocks > (int64_t)std::numeric_limits::max()) { - return FailedPrecondition( - "Number of physical blocks (%d) does not fit in an i32 in tiling " - "scheme: %s", - num_blocks, tiling.ToString()); - } - - TilingThreadIdInfo info; - info.thread_id = - EmitThreadId(builder, tiling.GetNumThreadsPerBlock(), index_ty); - info.block_id = EmitBlockId(builder, num_blocks, index_ty); - - for (auto [dim, stride] : llvm::enumerate(tiling.GetThreadStrides())) { - int64_t size = tiling.GetThreadsPerBlock()[dim]; - if (size == 1) { - info.thread_ids.emplace_back(constant(0)); - } else { - auto& dim_id = info.thread_ids.emplace_back(info.thread_id); - if (stride > 1) { - dim_id = builder->CreateUDiv(dim_id, constant(stride)); - } - if (dim) { - dim_id = builder->CreateURem(dim_id, constant(size)); - } - dim_id->setName(absl::StrCat("thread.id.", dim)); - } - } - - info.lane_id = - builder->CreateURem(info.thread_id, constant(WarpSize()), "lane_id"); - return info; -} - -AffineMap GetTilingAffineMap(llvm::ArrayRef exprs, - int64_t num_symbols) { - return AffineMap::get( - /*dimCount=*/6, /*symbolCount=*/num_symbols, exprs, - exprs[0].getContext()); -} - -} // namespace - -absl::StatusOr EmitTilingKernel( - llvm::IRBuilder<>* builder, const 
Tiling& tiling, llvm::Type* index_ty, - const TileGenerator& tile_element_generator) { - absl::Span dims_in_elems = tiling.GetShape(); - const auto& block_counts = tiling.GetBlockCounts(); - auto constant = [&](uint64_t c) -> llvm::Constant* { - return llvm::ConstantInt::get(index_ty, c); - }; - - TF_ASSIGN_OR_RETURN(TilingThreadIdInfo thread_id_info, - EmitThreadIdInfo(builder, tiling, index_ty)); - - KernelSupportLibrary ksl(builder, llvm_ir::UnrollMode::kDefaultUnroll); - - const llvm_ir::IrArray::Index block_coords( - thread_id_info.block_id, - ShapeUtil::MakeShape(PRED /*arbitrary*/, block_counts), builder); - - absl::InlinedVector tile_dimensions; - for (int i = 0; i < block_counts.size(); ++i) { - int64_t block_tile_size = tiling.GetBlockTileSize()[i]; - if (dims_in_elems[i] % block_tile_size == 0) { - // The block tile size evenly divides the tiled shape -> no need to emit - // the bounds check. - tile_dimensions.push_back(constant(block_tile_size)); - } else { - // Only the last tile in each dimension may not have full size. 
- llvm::Value* is_last = - builder->CreateICmpEQ(block_coords[i], constant(block_counts[i] - 1)); - int64_t partial_row = - dims_in_elems[i] - (block_counts[i] - 1) * block_tile_size; - tile_dimensions.push_back(builder->CreateSelect( - is_last, constant(partial_row), constant(block_tile_size), - absl::StrCat("tile_bound.", i))); - } - } - - llvm_ir::IrArray::Index tile_offset = [&] { - std::vector elem_multi_index = block_coords.multidim(); - llvm::Type* index_ty = block_coords.GetType(); - for (int i = 0; i < block_counts.size(); ++i) { - elem_multi_index[i] = builder->CreateMul( - block_coords[i], - llvm::ConstantInt::get(index_ty, tiling.GetBlockTileSize()[i]), - absl::StrCat("tile_origin.", i)); - } - return llvm_ir::IrArray::Index(elem_multi_index, tiling.GetShape(), - index_ty); - }(); - - tile_element_generator(thread_id_info, tile_offset, tile_dimensions); - return {{tile_dimensions, tile_offset, thread_id_info}}; -} - -AffineMap GetBlockOffsetsForTiling( - absl::Span num_blocks, - absl::Span tile_sizes_per_block, int64_t rank, - MLIRContext* mlir_context) { - auto offsets = - DelinearizeInBoundsIndex(getAffineDimExpr(3, mlir_context), num_blocks); - for (auto&& [offset, tile_size] : llvm::zip(offsets, tile_sizes_per_block)) { - offset = offset * tile_size; - } - return GetTilingAffineMap(offsets, rank); -} - -AffineMap GetBlockOffsetsForTiling(const Tiling& tiling, - MLIRContext* mlir_context) { - return GetBlockOffsetsForTiling(tiling.GetBlockCounts(), - tiling.GetBlockTileSize(), - tiling.GetShape().size(), mlir_context); -} - -AffineMap GetThreadOffsetsForTiling( - absl::Span num_threads, - absl::Span tile_sizes_per_thread, int64_t rank, - MLIRContext* mlir_context) { - auto offsets = - DelinearizeInBoundsIndex(getAffineDimExpr(0, mlir_context), num_threads); - for (int dim = 0; dim < rank; ++dim) { - if (tile_sizes_per_thread[dim] > 1) { - offsets[dim] = offsets[dim] + - getAffineSymbolExpr(dim, mlir_context) * num_threads[dim]; - } - } - return 
GetTilingAffineMap(offsets, rank); -} - -AffineMap GetThreadOffsetsForTiling(const Tiling& tiling, - MLIRContext* mlir_context) { - return GetThreadOffsetsForTiling(tiling.GetThreadsPerBlock(), - tiling.GetThreadTileSize(), - tiling.GetShape().size(), mlir_context); -} - -IndexingMap GetIndexingMapForTiling(const Tiling& tiling, - MLIRContext* mlir_context) { - return GetIndexingMapForTiling( - GetBlockOffsetsForTiling(tiling, mlir_context), - GetThreadOffsetsForTiling(tiling, mlir_context), - tiling.GetNumThreadsPerBlock(), tiling.GetNumBlocks(), - tiling.GetThreadTileSize(), tiling.GetShape()); -} - -IndexingMap GetIndexingMapForTiling(AffineMap block_offsets, - AffineMap thread_offsets, - int64_t threads_per_block, - int64_t num_blocks, - absl::Span thread_tile_sizes, - absl::Span tiled_shape) { - auto* mlir_context = block_offsets.getContext(); - llvm::SmallVector offsets; - offsets.reserve(block_offsets.getNumResults()); - for (auto [block, thread] : - llvm::zip(block_offsets.getResults(), thread_offsets.getResults())) { - offsets.push_back(block + thread); - } - std::vector dimension_ranges{ - {{0, threads_per_block - 1}}, {}, {}, {{0, num_blocks - 1}}, {}, {}, - }; - auto affine_map = mlir::AffineMap::get(block_offsets.getNumDims(), - block_offsets.getNumSymbols(), offsets, - mlir_context); - IndexingMap map{affine_map, dimension_ranges, - RangeVarsFromTensorSizes(thread_tile_sizes), /*rt_vars=*/{}}; - for (int i = 0; i < tiled_shape.size(); ++i) { - map.AddConstraint(affine_map.getResult(i), {0, tiled_shape[i] - 1}); - } - return map; -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/tiling_util.h b/third_party/xla/xla/service/gpu/fusions/legacy/tiling_util.h deleted file mode 100644 index de367e36addb61..00000000000000 --- a/third_party/xla/xla/service/gpu/fusions/legacy/tiling_util.h +++ /dev/null @@ -1,215 +0,0 @@ -/*Copyright 2023 The OpenXLA Authors. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef XLA_SERVICE_GPU_FUSIONS_LEGACY_TILING_UTIL_H_ -#define XLA_SERVICE_GPU_FUSIONS_LEGACY_TILING_UTIL_H_ - -#include -#include -#include - -#include "absl/container/inlined_vector.h" -#include "absl/log/check.h" -#include "absl/strings/str_format.h" -#include "absl/strings/str_join.h" -#include "absl/types/span.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/Type.h" -#include "llvm/IR/Value.h" -#include "mlir/IR/AffineMap.h" -#include "mlir/IR/MLIRContext.h" -#include "xla/service/gpu/model/indexing_map.h" -#include "xla/service/llvm_ir/ir_array.h" -#include "xla/shape.h" -#include "xla/shape_util.h" -#include "xla/util.h" -#include "xla/xla_data.pb.h" - -namespace xla { -namespace gpu { - -// Describes tiling used by the kernel. -// -// Used by reduction and transpose emitters. -class Tiling { - public: - Tiling(absl::Span shape, absl::Span tile_sizes, - absl::Span num_threads, - // By default, don't unroll anything. 
- absl::InlinedVector loops_to_unroll = {}) - : shape_{shape.begin(), shape.end()}, - tile_sizes_per_thread_{tile_sizes.begin(), tile_sizes.end()}, - tile_sizes_per_block_(shape.size()), - num_threads_{num_threads.begin(), num_threads.end()}, - num_blocks_(shape.size()), - loops_to_unroll_(loops_to_unroll) { - for (int64_t i = 0; i < shape.size(); ++i) { - tile_sizes_per_block_[i] = tile_sizes[i] * num_threads[i]; - CHECK_NE(tile_sizes_per_block_[i], 0); - num_blocks_[i] = CeilOfRatio(shape[i], tile_sizes_per_block_[i]); - CHECK_NE(num_blocks_[i], 0); - } - if (loops_to_unroll_.empty()) loops_to_unroll_.resize(shape.size()); - } - Tiling() = default; - - std::string ToString() const { - return absl::StrJoin( - {absl::StrFormat("shape = {%s}", absl::StrJoin(shape_, ", ")), - absl::StrFormat("tile_sizes = {%s}", - absl::StrJoin(tile_sizes_per_thread_, ", ")), - absl::StrFormat("num_threads = {%s}", - absl::StrJoin(num_threads_, ", "))}, - ", "); - } - - // Number of elements in each dimension. - const absl::InlinedVector& GetShape() const { return shape_; } - xla::Shape GetXlaShape(PrimitiveType element_type = F32) const { - return ShapeUtil::MakeShape(element_type, shape_); - } - - const absl::InlinedVector& GetBlockCounts() const { - return num_blocks_; - } - - // Tile size for each thread. - // - // Equals to the number of iterations in the loop each tile will make. - const absl::InlinedVector& GetThreadTileSize() const { - return tile_sizes_per_thread_; - } - - // Tile size for an entire thread block. - const absl::InlinedVector& GetBlockTileSize() const { - return tile_sizes_per_block_; - } - - const absl::InlinedVector& GetThreadsPerBlock() const { - return num_threads_; - } - - // Returns the strides of the thread index dimensions wrt. the linear thread - // id. 
- absl::InlinedVector GetThreadStrides() const { - return *ShapeUtil::ByteStrides(ShapeUtil::MakeShape(U8, num_threads_)); - } - - int64_t GetNumThreadsPerBlock() const { return Product(num_threads_); } - - int64_t GetNumBlocks() const { return Product(num_blocks_); } - - const absl::InlinedVector& GetLoopsToUnroll() const { - return loops_to_unroll_; - } - - private: - // The number of elements in each dimension. - absl::InlinedVector shape_; - - // The number of elements for each dimension of a tile. - absl::InlinedVector tile_sizes_per_thread_; - absl::InlinedVector tile_sizes_per_block_; - - absl::InlinedVector num_threads_; - absl::InlinedVector num_blocks_; - - absl::InlinedVector loops_to_unroll_; -}; - -struct TilingThreadIdInfo { - llvm::Value* thread_id; - - absl::InlinedVector thread_ids; - - // Lane id: `thread_id % WarpSize` - llvm::Value* lane_id; - - // Block id. - llvm::Value* block_id; -}; - -struct TilingKernelInfo { - // Tiling bounds. - absl::InlinedVector output_tile_bounds; - - // Starting tile, as calculated from block id only. - llvm_ir::IrArray::Index tile_origin; - - // Thread meta-info. - TilingThreadIdInfo thread_id_info; -}; - -// A function to generate the code to emit the entire tile. -// -// index: Absolute coordinate of the start of the tile in input. -// tile_dimensions: Size of the tile -using TileGenerator = - std::function tile_dimensions)>; - -// A function object to generate code to process one element in a tile. -// -// index_in_tile: the current coordinates within the tile. To get the global -// coordinates, use `tile_start_index.AddOffset(index_in_tile, ...)`. -using TileElementGenerator = - std::function index_in_tile)>; - -// Emits code to iterate through a tile with given tile dimensions and generate -// elements using the callback. 
-void EmitTile(llvm::IRBuilder<>* builder, const Tiling& tiling, - const TilingThreadIdInfo& thread_id_info, - absl::Span tile_dimensions, - const TileElementGenerator& emit_elem_function); - -// Emits a kernel for the hlo instruction using the given kernel mapping -// scheme. -absl::StatusOr EmitTilingKernel( - llvm::IRBuilder<>* builder, const Tiling& tiling, llvm::Type* index_ty, - const TileGenerator& tile_element_generator); - -// Creates an indexing map from thread and block IDs to elements of the tiled -// shape. Uses the same convention as KernelFusionInterface: dimensions 0 to 2 -// are thread indices (currently only 0 is used), dimensions 3 to 5 are block -// indices (currently only 3 is used). -mlir::AffineMap GetBlockOffsetsForTiling( - absl::Span num_blocks, - absl::Span tile_sizes_per_block, int64_t rank, - mlir::MLIRContext* mlir_context); -mlir::AffineMap GetBlockOffsetsForTiling(const Tiling& tiling, - mlir::MLIRContext* mlir_context); -mlir::AffineMap GetThreadOffsetsForTiling( - absl::Span num_threads, - absl::Span tile_sizes_per_thread, int64_t rank, - mlir::MLIRContext* mlir_context); -mlir::AffineMap GetThreadOffsetsForTiling(const Tiling& tiling, - mlir::MLIRContext* mlir_context); - -// Convenience functions for the two functions above -// (`GetBlockOffsestsForTiling` + `GetThreadOffsetsForTiling`). Also sets up -// the ranges of dimensions and symbols. 
-IndexingMap GetIndexingMapForTiling(const Tiling& tiling, - mlir::MLIRContext* mlir_context); -IndexingMap GetIndexingMapForTiling(mlir::AffineMap block_offsets, - mlir::AffineMap thread_offsets, - int64_t threads_per_block, - int64_t num_blocks, - absl::Span thread_tile_sizes, - absl::Span tiled_shape); - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_FUSIONS_LEGACY_TILING_UTIL_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/transpose.cc b/third_party/xla/xla/service/gpu/fusions/legacy/transpose.cc deleted file mode 100644 index f91a0a4b6b120f..00000000000000 --- a/third_party/xla/xla/service/gpu/fusions/legacy/transpose.cc +++ /dev/null @@ -1,365 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include "xla/service/gpu/fusions/legacy/transpose.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/container/inlined_vector.h" -#include "absl/log/check.h" -#include "absl/status/status.h" -#include "absl/strings/str_cat.h" -#include "absl/types/span.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/IR/DerivedTypes.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/Type.h" -#include "llvm/IR/Value.h" -#include "llvm/Support/AtomicOrdering.h" -#include "mlir/IR/AffineMap.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/hlo/ir/hlo_opcode.h" -#include "xla/permutation_util.h" -#include "xla/service/gpu/elemental_ir_emitter.h" -#include "xla/service/gpu/fusions/legacy/tiling_util.h" -#include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/gpu/ir_emission_utils.h" -#include "xla/service/gpu/ir_emitter_context.h" -#include "xla/service/gpu/launch_dimensions.h" -#include "xla/service/gpu/model/indexing_analysis.h" -#include "xla/service/gpu/model/indexing_map.h" -#include "xla/service/gpu/target_util.h" -#include "xla/service/llvm_ir/fused_ir_emitter.h" -#include "xla/service/llvm_ir/ir_array.h" -#include "xla/service/llvm_ir/llvm_util.h" -#include "xla/service/llvm_ir/loop_emitter.h" -#include "xla/shape_util.h" -#include "xla/xla_data.pb.h" -#include "tsl/platform/statusor.h" - -namespace xla { -namespace gpu { -namespace { - -Tiling ComputeTransposeTiling(const se::DeviceDescription& gpu_device_info, - const TransposeDescription& tiled_transpose) { - constexpr int kNumRows = 4; - static_assert(WarpSize() % kNumRows == 0); - - // 3D view over the output shape. - absl::InlinedVector transposed_dims = tiled_transpose.dimensions; - absl::InlinedVector permutation = tiled_transpose.permutation; - - // Note: the supported permutations are their own inverses. 
Therefore we - // always use the permutation, even when we want the inverse. - CHECK((permutation == absl::InlinedVector{0, 2, 1}) || - (permutation == absl::InlinedVector{2, 1, 0})); - - absl::InlinedVector input_dims{transposed_dims[permutation[0]], - transposed_dims[permutation[1]], - transposed_dims[permutation[2]]}; - - // We tile along the minor dimensions pre- and post-transpose. - absl::InlinedVector tile_sizes{1, 1, 1}; - tile_sizes[permutation[2]] = WarpSize() / kNumRows; - absl::InlinedVector num_threads{1, 1, WarpSize()}; - num_threads[permutation[2]] = kNumRows; - - auto capability = gpu_device_info.gpu_compute_capability(); - std::visit( - [&](const auto& capability) { - if constexpr (std::is_same_v, - stream_executor::RocmComputeCapability>) { - // kNumRows = 8 works well on MI300 with wavefront size 64. - if (capability.gfx9_mi300()) { - tile_sizes[permutation[2]] = gpu_device_info.threads_per_warp() / 8; - num_threads[permutation[2]] = 8; - } - } - }, - capability); - - return Tiling(input_dims, tile_sizes, num_threads); -} - -void MaybeEmitFenceForAMDGPU(llvm::IRBuilder<>* builder, - IrEmitterContext& ir_emitter_context) { - auto* module = builder->GetInsertBlock()->getModule(); - if (IsAMDGPU(module) && - ir_emitter_context.rocm_compute_capability().fence_before_barrier()) { - builder->CreateFence( - llvm::AtomicOrdering::SequentiallyConsistent, - builder->getContext().getOrInsertSyncScopeID("workgroup")); - } -} - -void EmitSyncThreads(llvm::IRBuilder<>* builder, - IrEmitterContext& ir_emitter_context) { - MaybeEmitFenceForAMDGPU(builder, ir_emitter_context); - EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, builder); -} - -llvm_ir::IrArray::Index PermuteIndex(const llvm_ir::IrArray::Index& index, - absl::Span permutation) { - return llvm_ir::IrArray::Index{Permute(index.multidim(), permutation), - Permute(index.dims(), permutation), - index.GetType()}; -} - -} // namespace - -TransposeFusion::TransposeFusion(const 
se::DeviceDescription& gpu_device_info, - const HloFusionAnalysis& analysis) - : analysis_(analysis), - tiling_( - ComputeTransposeTiling(gpu_device_info, analysis.tiled_transpose())) { - for (auto [root, hero] : - llvm::zip(analysis_.fusion_roots(), analysis_.fusion_heroes())) { - if (auto transpose = - GetDescriptionForTiledTransposeEmitter(hero.instruction())) { - permutation_ = transpose->permutation; - break; - } - } -} - -absl::Status TransposeFusion::EmitKernel(IrEmitterContext& ir_emitter_context, - const HloFusionInstruction& fusion, - const LaunchDimensions& launch_dims, - std::vector inputs, - std::vector outputs, - llvm::IRBuilder<>* builder) const { - const auto& hlo_roots = analysis_.fusion_roots(); - GpuElementalIrEmitter elemental_emitter(ir_emitter_context, builder); - FusedIrEmitter fused_emitter(elemental_emitter); - for (auto [i, input] : llvm::enumerate(inputs)) { - HloInstruction* fused_operand = fusion.fused_parameter(i); - fused_emitter.BindGenerator( - *fused_operand, [input = input, builder, - fused_operand](const llvm_ir::IrArray::Index& index) { - return input.EmitReadArrayElement(index, builder, - fused_operand->name()); - }); - } - - absl::flat_hash_map>> - transposes_to_roots; - // Keep a list of deduplicated transpose heroes separate from - // transposes_to_roots to make the CodeGen deterministic. 
- std::vector transposes; - transposes.reserve(hlo_roots.size()); - std::vector> extra_outputs; - - for (const auto& [output_idx, root] : llvm::enumerate(hlo_roots)) { - const auto& hero = analysis_.fusion_hero(output_idx).instruction(); - auto transpose_descr = GetDescriptionForTiledTransposeEmitter(hero); - if (transpose_descr.has_value()) { - auto iterator_inserted = transposes_to_roots.insert(std::make_pair( - &hero, std::vector>{ - {output_idx, &root.instruction()}})); - if (iterator_inserted.second) { - transposes.push_back(*transpose_descr); - } else { - iterator_inserted.first->second.push_back( - {output_idx, &root.instruction()}); - } - } else { - extra_outputs.push_back({output_idx, &root.instruction()}); - } - } - - absl::flat_hash_map tiles; - absl::InlinedVector permutation; - for (const auto& [tile_idx, tr] : llvm::enumerate(transposes)) { - permutation = tr.permutation; - auto tile_size = tiling_.GetBlockTileSize(); - ++tile_size.back(); // Prevent bank conflicts. - auto* module = ir_emitter_context.llvm_module(); - tiles[tr.instr] = llvm_ir::AllocateSharedMemoryTile( - module, - llvm_ir::PrimitiveTypeToIrType(tr.instr->shape().element_type(), - module), - tile_size, absl::StrCat("tr_tile_", tile_idx)); - } - - auto tile_generator = [&](const TilingThreadIdInfo& thread_id_info, - const llvm_ir::IrArray::Index& tile_start_index, - absl::Span tile_dimensions) { - // Copy input parameter values to shared memory buffers: - // tile[thread_id_y, thread_id_x] = input[index] - EmitTile(builder, tiling_, thread_id_info, tile_dimensions, - [&](absl::Span index_in_tile) { - auto index = tile_start_index.AddOffset(index_in_tile, builder); - for (const auto& tr : transposes) { - auto input_gen = - *fused_emitter.GetGenerator(*tr.instr->operand(0)); - auto input_index = index.SourceIndexOfBitcast( - tr.instr->operand(0)->shape(), builder); - llvm::Value* value = *input_gen(input_index); - tiles[tr.instr].Store(value, index_in_tile, builder); - } - - // Compute 
all extra output values before writing them. This - // avoids overwriting aliased input/output values before all - // reads occurred. - std::vector> - scheduled_writes; - for (const auto& [output_idx, root] : extra_outputs) { - auto extra_output_index = - index.SourceIndexOfBitcast(root->shape(), builder); - auto output_gen = *fused_emitter.GetGenerator(*root); - llvm::Value* output_value = *output_gen(extra_output_index); - scheduled_writes.emplace_back( - outputs[output_idx], extra_output_index, output_value); - } - - for (const auto& [output, idx, value] : scheduled_writes) { - output.EmitWriteArrayElement(idx, value, builder); - } - }); - - EmitSyncThreads(builder, ir_emitter_context); - - auto output_tile_index = PermuteIndex(tile_start_index, permutation); - auto transposed_tile_dimensions = Permute(tile_dimensions, permutation); - - EmitTile( - builder, tiling_, thread_id_info, transposed_tile_dimensions, - /*emit_elem_function=*/ - [&](absl::Span index_in_tile) { - auto index = output_tile_index.AddOffset(index_in_tile, builder); - for (const auto& tr : transposes) { - llvm::Value* loaded = tiles[tr.instr].Load( - Permute(index_in_tile, permutation), builder); - - FusedIrEmitter fused_emitter(elemental_emitter); - fused_emitter.BindGenerator( - *tr.instr, - [&](const llvm_ir::IrArray::Index&) { return loaded; }); - for (int64_t i = 0; - i < fusion.fused_instructions_computation()->num_parameters(); - ++i) { - llvm_ir::IrArray ir_array = inputs[i]; - HloInstruction* fused_operand = fusion.fused_parameter(i); - fused_emitter.BindGenerator( - *fused_operand, [=](const llvm_ir::IrArray::Index& index) { - return ir_array.EmitReadArrayElement(index, builder, - fused_operand->name()); - }); - } - - // Apply code generation for the code after the real hero. - // Compute all output values before writing them. This avoids - // overwriting aliased input/output values before all reads - // occurred. 
- std::vector> - scheduled_writes; - for (const auto& [output_idx, root] : - transposes_to_roots[tr.instr]) { - TF_ASSIGN_OR_RETURN(llvm_ir::ElementGenerator gen, - fused_emitter.GetGenerator(*root)); - - // Both for emission and writing it should be - // index-as-transformed by the computation. - auto untiled_index = - index.SourceIndexOfBitcast(root->shape(), builder); - TF_ASSIGN_OR_RETURN(llvm::Value * generated, gen(untiled_index)); - scheduled_writes.emplace_back(outputs[output_idx], untiled_index, - generated); - } - for (const auto& [output, idx, value] : scheduled_writes) { - output.EmitWriteArrayElement(idx, value, builder); - } - } - return absl::OkStatus(); - }); - }; - - llvm::Type* index_type = - GetIndexTypeForKernel(&fusion, launch_dims.launch_bound(), builder); - return EmitTilingKernel(builder, tiling_, index_type, tile_generator) - .status(); -} - -LaunchDimensions TransposeFusion::launch_dimensions() const { - return LaunchDimensions(tiling_.GetNumBlocks(), - tiling_.GetNumThreadsPerBlock()); -} - -std::optional TransposeFusion::ComputeThreadIdToOutputIndexing( - int64_t root_index, mlir::MLIRContext* ctx) const { - const auto& hero = analysis_.fusion_hero(root_index); - if (hero.opcode() != HloOpcode::kTranspose) { - // The shape of non-transpose roots are bitcast compatible with the input - // shape of transpose heroes. - auto map = ComposeIndexingMaps( - GetIndexingMapForTiling(tiling_, ctx), - GetBitcastMap(tiling_.GetXlaShape(), - analysis_.fusion_root(root_index).shape(), ctx)); - map.Simplify(); - return map; - } - - // The block offsets are permuted, but the thread offsets remain the same. 
- auto block_offset = GetBlockOffsetsForTiling(tiling_, ctx) - .getSubMap(std::vector{permutation_.begin(), - permutation_.end()}); - auto thread_offset = GetThreadOffsetsForTiling(tiling_, ctx); - auto permuted_tiled_shape = - ShapeUtil::MakeShape(U8, Permute(tiling_.GetShape(), permutation_)); - - auto map = ComposeIndexingMaps( - GetIndexingMapForTiling( - block_offset, thread_offset, tiling_.GetNumThreadsPerBlock(), - tiling_.GetNumBlocks(), tiling_.GetThreadTileSize(), - permuted_tiled_shape.dimensions()), - GetBitcastMap(permuted_tiled_shape, hero.shape(), ctx)); - map.Simplify(); - return map; -} - -std::optional TransposeFusion::ComputeThreadIdToInputIndexing( - int64_t root_index, int64_t hero_operand_index, - mlir::MLIRContext* ctx) const { - const auto& hero = analysis_.fusion_hero(root_index).instruction(); - if (hero.opcode() != HloOpcode::kTranspose) { - auto map = ComposeIndexingMaps( - *ComputeThreadIdToOutputIndexing(root_index, ctx), - *ComputeOutputToInputIndexing( - &analysis_.fusion_root(root_index).instruction(), 0, ctx) - .indexing_maps[hero_operand_index] - .begin()); - map.Simplify(); - return map; - } - - auto map = ComposeIndexingMaps( - GetIndexingMapForTiling(tiling_, ctx), - GetBitcastMap(tiling_.GetXlaShape(), hero.operand(0)->shape(), ctx)); - map.Simplify(); - return map; -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/transpose.h b/third_party/xla/xla/service/gpu/fusions/legacy/transpose.h deleted file mode 100644 index 3366130c05546b..00000000000000 --- a/third_party/xla/xla/service/gpu/fusions/legacy/transpose.h +++ /dev/null @@ -1,91 +0,0 @@ -/* Copyright 2023 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#ifndef XLA_SERVICE_GPU_FUSIONS_LEGACY_TRANSPOSE_H_ -#define XLA_SERVICE_GPU_FUSIONS_LEGACY_TRANSPOSE_H_ - -#include -#include -#include - -#include "absl/container/inlined_vector.h" -#include "absl/status/status.h" -#include "llvm/IR/IRBuilder.h" -#include "mlir/IR/MLIRContext.h" -#include "xla/hlo/ir/hlo_instructions.h" -#include "xla/service/gpu/fusions/fusion_emitter.h" -#include "xla/service/gpu/fusions/legacy/tiling_util.h" -#include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/gpu/ir_emitter_context.h" -#include "xla/service/gpu/launch_dimensions.h" -#include "xla/service/gpu/model/indexing_map.h" -#include "xla/service/llvm_ir/ir_array.h" - -namespace xla { -namespace gpu { - -// Emits a kernel for the given hlo instruction using a tiled 0-2-1 transpose -// algorithm to improve the memory access patterns for the input parameters -// with a shape that is a 0-2-1 transpose of the output tensor shape. The -// caller is responsible for making sure that it is safe to apply the shared -// memory transpose on the input parameters. -// -// For the purpose of tiling, the output tensors have a logical shape of three -// components 0-2-1 while the relevant input parameters have a logical shape -// of three components 0-1-2 in the order major to minor. The x- and y- -// dimensions of the tensors are tiled in square tiles with an edge length -// `kTileSize`. 
Each thread block of `kTileSize` x `kNumRows` threads -// transposes one tile: each thread copies kTileSize/kNumRows elements from -// the input to a shared memory tile, then the otherwise "regular HLO kernel" -// reads from the shared memory instead of the original input. -// -// This is similar to the following CUDA algorithm in TensorFlow: -// https://goo.gl/MStRV6. -// -// `kTileSize` should usually be same as warp size. We currently choose 32 for -// `kTileSize` and 4 for `kNumRows`. The CUDA algorithm uses 8 for `kNumRows`. -// -// TODO(b/33320379): Here each block transposes 1 tile. It may be more -// efficient to launch fewer blocks so each transposes many tiles. -class TransposeFusion : public KernelFusionEmitterBase { - public: - explicit TransposeFusion(const se::DeviceDescription& gpu_device_info, - const HloFusionAnalysis& analysis); - LaunchDimensions launch_dimensions() const override; - - std::optional ComputeThreadIdToOutputIndexing( - int64_t root_index, mlir::MLIRContext* ctx) const override; - - std::optional ComputeThreadIdToInputIndexing( - int64_t root_index, int64_t hero_operand_index, - mlir::MLIRContext* ctx) const override; - - protected: - absl::Status EmitKernel(IrEmitterContext& ir_emitter_context, - const HloFusionInstruction& fusion, - const LaunchDimensions& launch_dims, - std::vector inputs, - std::vector outputs, - llvm::IRBuilder<>* builder) const override; - - private: - const HloFusionAnalysis& analysis_; - Tiling tiling_; - absl::InlinedVector permutation_; -}; - -} // namespace gpu -} // namespace xla - -#endif // XLA_SERVICE_GPU_FUSIONS_LEGACY_TRANSPOSE_H_ diff --git a/third_party/xla/xla/service/gpu/fusions/legacy/transpose_test.cc b/third_party/xla/xla/service/gpu/fusions/legacy/transpose_test.cc deleted file mode 100644 index bba3d721368e5b..00000000000000 --- a/third_party/xla/xla/service/gpu/fusions/legacy/transpose_test.cc +++ /dev/null @@ -1,352 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "xla/service/gpu/fusions/legacy/transpose.h" - -#include -#include - -#include -#include -#include "absl/status/statusor.h" -#include "mlir/IR/MLIRContext.h" -#include "xla/service/gpu/fusions/fusions.h" -#include "xla/service/gpu/gpu_device_info_for_tests.h" -#include "xla/service/gpu/hlo_fusion_analysis.h" -#include "xla/service/gpu/model/indexing_test_utils.h" -#include "xla/status_macros.h" -#include "xla/stream_executor/device_description.h" -#include "xla/tests/hlo_test_base.h" -#include "tsl/platform/statusor.h" - -namespace xla { -namespace gpu { -namespace { - -class TransposeTest : public HloTestBase { - protected: - DebugOptions GetDebugOptionsForTest() override { - auto opts = HloTestBase::GetDebugOptionsForTest(); - opts.set_xla_gpu_mlir_emitter_level(0); - return opts; - } - stream_executor::DeviceDescription device_info_ = - TestGpuDeviceInfo::RTXA6000DeviceInfo(); -}; - -absl::StatusOr> GetTransposeFusion( - const HloFusionAnalysis& analysis) { - auto emitter = GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis}); - auto fusion = dynamic_cast(emitter.get()); - TF_RET_CHECK(fusion != nullptr); - - emitter.release(); - return std::unique_ptr{fusion}; -} - -TEST_F(TransposeTest, ThreadIndexing021) { - auto module = ParseAndReturnVerifiedModule(R"( - HloModule module - - fusion { - %input = f32[100,32,64] parameter(0) - ROOT transpose 
= f32[100,64,32] transpose(%input), dimensions={0,2,1} - } - - ENTRY entry { - %input = f32[100,32,64] parameter(0) - ROOT %fusion = f32[100,64,32] fusion(%input), kind=kInput, calls=fusion - })") - .value(); - - auto* root = module->entry_computation()->root_instruction(); - auto analysis = HloFusionAnalysis::Create(*root, device_info_); - - TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetTransposeFusion(analysis)); - mlir::MLIRContext mlir_context; - - EXPECT_THAT( - fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(), - MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( - d3 floordiv 2, - d0 floordiv 32 + s1 * 4, - (d3 mod 2) * 32 + d0 mod 32 - ), - domain: - d0 in [0, 127], - d1 in [0, 0], - d2 in [0, 0], - d3 in [0, 199], - d4 in [0, 0], - d5 in [0, 0], - - s0 in [0, 0], - s1 in [0, 7], - s2 in [0, 0], - is_simplified: true - )")); - EXPECT_THAT( - fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(), - MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( - d3 floordiv 2, - (d3 mod 2) * 32 + s1 * 4 + d0 floordiv 32, - d0 mod 32 - ), - domain: - d0 in [0, 127], - d1 in [0, 0], - d2 in [0, 0], - d3 in [0, 199], - d4 in [0, 0], - d5 in [0, 0], - - s0 in [0, 0], - s1 in [0, 7], - s2 in [0, 0], - is_simplified: true - )")); -} - -TEST_F(TransposeTest, ThreadIndexing201_SimplifiedTo021) { - auto module = ParseAndReturnVerifiedModule(R"( - HloModule module - - fusion { - %input = f32[1,6400,32] parameter(0) - ROOT transpose = f32[1,32,6400] transpose(%input), dimensions={0,2,1} - } - - ENTRY entry { - %input = f32[1,6400,32] parameter(0) - ROOT %fusion = f32[1,32,6400] fusion(%input), kind=kInput, calls=fusion - })") - .value(); - - auto* root = module->entry_computation()->root_instruction(); - auto analysis = HloFusionAnalysis::Create(*root, device_info_); - - TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetTransposeFusion(analysis)); - mlir::MLIRContext mlir_context; - EXPECT_THAT( - 
fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(), - MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( - 0, - d3 * 32 + s1 * 4 + d0 floordiv 32, - d0 mod 32 - ), - domain: - d0 in [0, 127], - d1 in [0, 0], - d2 in [0, 0], - d3 in [0, 199], - d4 in [0, 0], - d5 in [0, 0], - - s0 in [0, 0], - s1 in [0, 7], - s2 in [0, 0], - is_simplified: true - )")); - EXPECT_THAT( - fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(), - MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( - 0, - d0 floordiv 32 + s1 * 4, - d3 * 32 + d0 mod 32 - ), - domain: - d0 in [0, 127], - d1 in [0, 0], - d2 in [0, 0], - d3 in [0, 199], - d4 in [0, 0], - d5 in [0, 0], - - s0 in [0, 0], - s1 in [0, 7], - s2 in [0, 0], - is_simplified: true - )")); -} - -TEST_F(TransposeTest, ThreadIndexingPartialBlock) { - auto module = ParseAndReturnVerifiedModule(R"( - HloModule m - - fused_computation { - %p0 = f64[24,2,24] parameter(0) - ROOT %t = f64[24,2,24] transpose(%p0), dimensions={2,1,0} - } - - ENTRY main { - %p0 = f64[24,2,24] parameter(0) - ROOT %fusion = f64[24,2,24] fusion(%p0), kind=kInput, - calls=%fused_computation - } - )") - .value(); - - auto* root = module->entry_computation()->root_instruction(); - auto analysis = HloFusionAnalysis::Create(*root, device_info_); - - TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetTransposeFusion(analysis)); - mlir::MLIRContext mlir_context; - EXPECT_THAT( - fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(), - MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( - d0 floordiv 32 + s0 * 4, - d3, - d0 mod 32 - ), - domain: - d0 in [0, 127], - d1 in [0, 0], - d2 in [0, 0], - d3 in [0, 1], - d4 in [0, 0], - d5 in [0, 0], - s0 in [0, 5], - s1 in [0, 0], - s2 in [0, 0], - d0 mod 32 in [0, 23], - is_simplified: true - )")); - EXPECT_THAT( - fusion->ComputeThreadIdToOutputIndexing(0, &mlir_context)->ToString(), - MatchIndexingString(R"( - (d0, d1, d2, d3, d4, 
d5)[s0, s1, s2] -> ( - d0 floordiv 32 + s0 * 4, - d3, - d0 mod 32 - ), - domain: - d0 in [0, 127], - d1 in [0, 0], - d2 in [0, 0], - d3 in [0, 1], - d4 in [0, 0], - d5 in [0, 0], - s0 in [0, 5], - s1 in [0, 0], - s2 in [0, 0], - d0 mod 32 in [0, 23], - is_simplified: true - )")); -} - -TEST_F(TransposeTest, SameInputIndexingForRealHeroAndSideOutput) { - auto module = ParseAndReturnVerifiedModule(R"( - HloModule module - - fusion { - %input = f32[100,32,64] parameter(0) - %transpose = f32[100,64,32] transpose(%input), dimensions={0,2,1} - %bitcast = f32[100,2048] bitcast(%input) - ROOT %tuple = (f32[100,64,32], f32[100,2048]) tuple(%transpose, %bitcast) - } - - ENTRY entry { - %input = f32[100,32,64] parameter(0) - ROOT %fusion = (f32[100,64,32], f32[100,2048]) fusion(%input), kind=kInput, calls=fusion - })") - .value(); - - auto* root = module->entry_computation()->root_instruction(); - auto analysis = HloFusionAnalysis::Create(*root, device_info_); - - TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetTransposeFusion(analysis)); - mlir::MLIRContext mlir_context; - - EXPECT_THAT( - fusion->ComputeThreadIdToInputIndexing(0, 0, &mlir_context)->ToString(), - fusion->ComputeThreadIdToInputIndexing(1, 0, &mlir_context)->ToString()); -} - -TEST_F(TransposeTest, ThreadIndexingSideOutput) { - auto module = ParseAndReturnVerifiedModule(R"( - HloModule module - - fusion { - %input0 = f32[100,32,64] parameter(0) - %input1 = f32[100,32] parameter(1) - %transpose = f32[100,64,32] transpose(%input0), dimensions={0,2,1} - %broadcast = f32[100,32,64] broadcast(%input1), dimensions={0,1} - ROOT %tuple = (f32[100,64,32], f32[100,32,64]) tuple(%transpose, %broadcast) - } - - ENTRY entry { - %input0 = f32[100,32,64] parameter(0) - %input1 = f32[100,32] parameter(1) - ROOT %fusion = (f32[100,64,32], f32[100,32,64]) fusion(%input0, %input1), kind=kInput, calls=fusion - })") - .value(); - - auto* root = module->entry_computation()->root_instruction(); - auto analysis = 
HloFusionAnalysis::Create(*root, device_info_); - - TF_ASSERT_OK_AND_ASSIGN(auto fusion, GetTransposeFusion(analysis)); - mlir::MLIRContext mlir_context; - // Check if side output `%broadcast` get the correct input indexing, which - // should corresponds to `%input1` with shape [100,32]. - EXPECT_THAT( - fusion->ComputeThreadIdToInputIndexing(1, 0, &mlir_context)->ToString(), - MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( - d3 floordiv 2, - d0 floordiv 32 + s1 * 4 - ), - domain: - d0 in [0, 127], - d1 in [0, 0], - d2 in [0, 0], - d3 in [0, 199], - d4 in [0, 0], - d5 in [0, 0], - - s0 in [0, 0], - s1 in [0, 7], - s2 in [0, 0], - is_simplified: true - )")); - EXPECT_THAT( - fusion->ComputeThreadIdToOutputIndexing(1, &mlir_context)->ToString(), - MatchIndexingString(R"( - (d0, d1, d2, d3, d4, d5)[s0, s1, s2] -> ( - d3 floordiv 2, - d0 floordiv 32 + s1 * 4, - (d3 mod 2) * 32 + d0 mod 32 - ), - domain: - d0 in [0, 127], - d1 in [0, 0], - d2 in [0, 0], - d3 in [0, 199], - d4 in [0, 0], - d5 in [0, 0], - - s0 in [0, 0], - s1 in [0, 7], - s2 in [0, 0], - is_simplified: true - )")); -} - -} // namespace -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc index 448d3050c30bd4..4b17a1873132e3 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc +++ b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.cc @@ -305,7 +305,7 @@ MlirFusionEmitterBase::CreateLLVMModule( mlir::PassManager pm(&mlir_context); AddXlaGpuOpsOptimizationPasses(pm); - AddLoopTransformationPasses(pm); + AddLoopTransformationPasses(pm, device); AddLoweringPasses(pm, device); auto pipeline_status = RunPassPipeline(module.get(), pm, trace.get()); if (trace) { @@ -539,15 +539,17 @@ void AddXlaGpuOpsOptimizationPasses(mlir::OpPassManager& pm) { pm.addPass(mlir::createCSEPass()); } -void 
AddLoopTransformationPasses(mlir::OpPassManager& pm) { - pm.addNestedPass(CreateLowerXlaGpuToScfPass()); +void AddLoopTransformationPasses(mlir::OpPassManager& pm, + const se::DeviceDescription& device) { + pm.addNestedPass( + CreateLowerXlaGpuToScfPass(device.threads_per_warp())); + pm.addNestedPass(CreatePeelLoopsPass()); pm.addPass(mlir::createInlinerPass({}, [&](mlir::OpPassManager& pm) { // CSE after inlining because inlining can introduce duplicates. pm.addPass(mlir::createCSEPass()); })); pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(mlir::createCSEPass()); - pm.addNestedPass(CreatePeelLoopsPass()); pm.addNestedPass(CreateLowerXlaGpuLoopsToScfPass()); pm.addPass(mlir::mhlo::createConvertToSignlessPass()); pm.addPass(CreatePropagateSliceIndicesPass()); diff --git a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h index 68ce87f4374aab..8c2cbc090edce6 100644 --- a/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h +++ b/third_party/xla/xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h @@ -118,7 +118,8 @@ void AddXlaGpuOpsOptimizationPasses(mlir::OpPassManager& pm); // Adds passes that transform XLA_GPU and SCF loops, e.g. peel, pipeline, // vectorize. -void AddLoopTransformationPasses(mlir::OpPassManager& pm); +void AddLoopTransformationPasses(mlir::OpPassManager& pm, + const se::DeviceDescription& device); // Adds passes that lower transformed loops to LLVM. 
void AddLoweringPasses(mlir::OpPassManager& pm, diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_base.cc b/third_party/xla/xla/service/gpu/fusions/reduction_base.cc index b7f62e3c7d1d54..f0a5e7b9c8fc8c 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_base.cc +++ b/third_party/xla/xla/service/gpu/fusions/reduction_base.cc @@ -56,11 +56,11 @@ namespace xla { namespace gpu { int RowReductionGetRowsPerWarp(int reduced_dimension_size) { - if (WarpSize() % reduced_dimension_size != 0 || - reduced_dimension_size >= WarpSize()) { + if (64 % reduced_dimension_size != 0 || + reduced_dimension_size >= 64) { return 1; } - return WarpSize() / reduced_dimension_size; + return 64 / reduced_dimension_size; } int GetVectorSize(const HloFusionAnalysis& analysis, @@ -168,8 +168,8 @@ ReductionGroups GroupDisjointReductions(const HloFusionAnalysis& analysis, auto [it, inserted] = disjoint_sets.try_emplace(root, root); CHECK(inserted) << "Duplicate root " << root.ToString(); // Crash OK reachable_outputs[root].insert(root); - result.is_reduction_root.push_back( - IsRealReductionHero(root.instruction(), hero.instruction())); + result.is_reduction_root.push_back(IsRealReductionHero( + root.instruction(), hero.instruction(), analysis.device_info())); if (result.is_reduction_root.back()) { roots_with_reduction.insert(root); } else if (first_non_reduction_root != nullptr) { diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc index b2db8fa5cd1730..e1da6e3f96a40c 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.cc @@ -117,10 +117,18 @@ struct MlirReductionFusion::EmitterState { PerThreadOutputs EmitPerThreadElements(int group_id, const HloValueMap& inits, const SmallVector& outputs); + mlir::ValueRange ReduceViaSharedMemory(int group_id, + const PerThreadOutputs& per_thread, + const HloValueMap& inits, + 
std::optional padding, + int max_dist); + mlir::ValueRange ReduceViaSharedMemory( int group_id, const PerThreadOutputs& per_thread, - const HloValueMap& inits, std::optional padding = std::nullopt, - int max_dist = WarpSize() / 2); + const HloValueMap& inits, std::optional padding = std::nullopt) { + return ReduceViaSharedMemory(group_id, per_thread, inits, padding, + owner.WarpSize() / 2); + } mlir::func::FuncOp GetReducer(const HloInstruction* hero) const { return call_target(hero->called_computations()[0]->root_instruction()); @@ -133,8 +141,12 @@ struct MlirReductionFusion::EmitterState { const HloValueMap& values, std::optional padding = std::nullopt); HloValueMap ShuffleReduce(absl::Span reductions, - const HloValueMap& per_thread_values, - int max_dist = WarpSize() / 2); + const HloValueMap& per_thread_values, int max_dist); + + HloValueMap ShuffleReduce(absl::Span reductions, + const HloValueMap& per_thread_values) { + return ShuffleReduce(reductions, per_thread_values, owner.WarpSize() / 2); + } SmallVector FusionParams() { return ValueRange(entry_function.getArguments().take_front( @@ -253,7 +265,7 @@ SmallVector MlirReductionFusion::EmitterState::WriteToSharedMemory( } if (padding) { shape.back() += *padding; - } else if ((shape.back() % WarpSize()) == 0) { + } else if ((shape.back() % owner.WarpSize()) == 0) { // Avoid bank conflicts. ++shape.back(); } @@ -325,7 +337,7 @@ mlir::ValueRange MlirReductionFusion::EmitterState::ReduceViaSharedMemory( // The constraints may have reduced the upper bound of the dimension. If // that's the case, we reset it to a multiple of the warp size. 
auto& bound = loop_indexing.GetMutableDimensionBound(0); - bound.upper = RoundUpTo(bound.upper + 1, WarpSize()) - 1; + bound.upper = RoundUpTo(bound.upper + 1, owner.WarpSize()) - 1; auto tiles = WriteToSharedMemory(reductions, per_thread.reduction_scalars, padding); @@ -364,8 +376,7 @@ MlirReductionFusion::MlirReductionFusion(const HloFusionAnalysis& analysis) GetReductionKindAndContiguousComponents(*hero_reduction); VLOG(10) << reduction_dimensions_; - CHECK(ReductionIsRaceFree(hero_reduction->GetModule()->config(), - reduction_dimensions_)) + CHECK(ReductionIsRaceFree(reduction_dimensions_, analysis.device_info())) << "Non-race-free reductions should have been decomposed. Did " "tree_reduction_rewriter run?"; @@ -770,38 +781,17 @@ llvm::SmallVector MlirSmallColumnReductionFusion::EmitReduction( shared_rows_ / 2); } -std::unique_ptr CreateMlirReductionFusion( - const HloFusionAnalysis& analysis) { - auto* hero_reduction = analysis.FindHeroReduction(); - CHECK_NE(hero_reduction, nullptr); - ReductionDimensions reduction_dimensions = - GetReductionKindAndContiguousComponents(*hero_reduction); - if (reduction_dimensions.is_row_reduction) { - if (RowReductionGetRowsPerWarp( - reduction_dimensions.dimensions[kRowMinorReduced]) > 1) { - return std::make_unique(analysis); - } - return std::make_unique(analysis); - } - - if (WarpSize() % reduction_dimensions.dimensions[kColMinorKept] == 0) { - return std::make_unique(analysis); - } - return std::make_unique(analysis); -} MlirRowReductionFusion::MlirRowReductionFusion( const HloFusionAnalysis& analysis) : MlirReductionFusion(analysis) { CHECK(reduction_dimensions_.is_row_reduction); Vector3 shape = reduction_dimensions_.dimensions; - CHECK_EQ(RowReductionGetRowsPerWarp(shape[kRowMinorReduced]), 1); constexpr int64_t kMinorReducedElementsPerThread = 16; int64_t num_threads_kept = 1; int64_t num_threads_reduced = [&] { - int64_t max_block_size = - MinThreadsXRowReduction(first_reduce_->GetModule()->config()); + int64_t 
max_block_size = MinThreadsXRowReduction(); return std::min(max_block_size, RoundUpTo(CeilOfRatio(shape[kRowMinorReduced], kMinorReducedElementsPerThread), @@ -931,34 +921,29 @@ llvm::SmallVector MlirRowReductionFusion::EmitReduction( } MlirMultiRowReductionFusion::MlirMultiRowReductionFusion( - const HloFusionAnalysis& analysis) + const HloFusionAnalysis& analysis, int vector_size) : MlirReductionFusion(analysis) { CHECK(reduction_dimensions_.is_row_reduction); Vector3 shape = reduction_dimensions_.dimensions; - int64_t rows_per_warp = RowReductionGetRowsPerWarp(shape[kRowMinorReduced]); - input_shape_ = {shape[0], shape[1], shape[2]}; - CHECK_GT(rows_per_warp, 1); - - auto compute_block_size = [&](int vector_size) { - int64_t num_threads_reduced = shape[kRowMinorReduced] / vector_size; - - constexpr int64_t kThreadsPerBlockTarget = 256; - int64_t kept_size = reduction_dimensions_.dimensions[kRowKept]; - int64_t num_threads_kept = 1; - if (kept_size * num_threads_reduced <= kThreadsPerBlockTarget) { - num_threads_kept = kept_size; - } else { - num_threads_kept = kThreadsPerBlockTarget / num_threads_reduced; - } - num_threads_ = {num_threads_kept, num_threads_reduced}; - tile_sizes_per_thread_ = {shape[0], vector_size}; - num_blocks_ = {CeilOfRatio(input_shape_[kRowKept], num_threads_kept)}; - }; - - // Compute the launch grid without vectorization. We use the results to - // compute the vectorized launch grid. 
- compute_block_size(1); + input_shape_ = {shape[0], shape[1], shape[2]}; + num_threads_ = GetNumThreads(reduction_dimensions_, vector_size); + num_blocks_ = {GetNumBlocks(reduction_dimensions_, num_threads_)}; + tile_sizes_per_thread_ = {shape[0], vector_size}; +} + +std::unique_ptr MlirMultiRowReductionFusion::TryCreate( + const HloFusionAnalysis& analysis) { + auto* hero_reduction = analysis.FindHeroReduction(); + CHECK_NE(hero_reduction, nullptr); + auto reduction_dimensions = + GetReductionKindAndContiguousComponents(*hero_reduction); + auto shape = reduction_dimensions.dimensions; + // This emitter only supports reductions where the reduced dimension is a + // power of 2. + if (shape[kRowMinorReduced] & (shape[kRowMinorReduced] - 1)) { + return nullptr; + } // Normally, we only consider input types for vectorization. However, in // multi-row reductions, the input:output ratio is much higher, so we consider // both inputs and outputs. @@ -966,23 +951,74 @@ MlirMultiRowReductionFusion::MlirMultiRowReductionFusion( std::min(analysis.input_output_info().smallest_input_dtype_bits, analysis.input_output_info().smallest_output_dtype_bits); - // This vector size is always valid: we know that the reduced dimension is a - // power of 2, since otherwise RowReductionGetRowsPerWarp would have - // returned 1. + int largest_input_or_output_bits = + std::max(analysis.input_output_info().smallest_input_dtype_bits, + analysis.input_output_info().smallest_output_dtype_bits); // Our codegen can't currently deal with vectorization across rows, so we // limit the vector size to the size of the row. Note that this emitter // essentially reverts to the loop emitter in this case, except for side // outputs. - int vector_size = std::min(static_cast(input_shape_[kRowMinorReduced]), - 32 / smallest_input_or_output_bits); - - // We target 8 warps per block, which means there could be up to 8 blocks per - // SM, but we have no good way of knowing. 
In practice, enabling vectorization - // for decently sized reductions at least does not hurt. - if (num_blocks_.front() > analysis.device_info().core_count() && - vector_size > 1) { - compute_block_size(vector_size); + int vector_size = std::min(static_cast(shape[kRowMinorReduced]), + 64 / smallest_input_or_output_bits); + + // Very large vector sizes for f32 can be detrimental, so we limit the vector + // size to 16 bytes if we have some >= 32 bit inputs or outputs. This is still + // a bit on the high side, but remember that we also have very small inputs + // or outputs. + if (largest_input_or_output_bits >= 32) { + vector_size = std::min(128 / largest_input_or_output_bits, vector_size); + } + + // The reduced dimension must fit into a single warp. + const int64_t warp_size = analysis.device_info().threads_per_warp(); + if (shape[kRowMinorReduced] > warp_size * vector_size) { + return nullptr; + } + + // At the very least, we want to have work for every SM. + // TODO(jreiffers): This limit is probably too low: if we have as many blocks + // as SMs, we'll only run about 8 warps per SM, so occupancy will be very low. + // Further measurements are needed to refine this heuristic. + int64_t min_desired_blocks = analysis.device_info().core_count(); + while (vector_size > 1 && + GetNumBlocks(reduction_dimensions, + GetNumThreads(reduction_dimensions, vector_size)) < + min_desired_blocks) { + vector_size /= 2; + } + // Check again that the reduced dimension fits after potentially reducing the + // vector size. 
+ if (shape[kRowMinorReduced] > warp_size * vector_size) { + return nullptr; } + + return std::make_unique(analysis, vector_size); +} + +absl::InlinedVector MlirMultiRowReductionFusion::GetNumThreads( + const ReductionDimensions& reduction_dimensions, int vector_size) { + int64_t num_threads_reduced = + reduction_dimensions.dimensions[kRowMinorReduced] / vector_size; + + constexpr int64_t kThreadsPerBlockTarget = 256; + int64_t kept_size = reduction_dimensions.dimensions[kRowKept]; + int64_t num_threads_kept = 1; + if (kept_size * num_threads_reduced <= kThreadsPerBlockTarget) { + num_threads_kept = kept_size; + } else { + num_threads_kept = kThreadsPerBlockTarget / num_threads_reduced; + } + return {num_threads_kept, num_threads_reduced}; +} + +int64_t MlirMultiRowReductionFusion::GetNumBlocks( + const ReductionDimensions& reduction_dimensions, + const absl::InlinedVector& num_threads) { + CHECK_EQ(num_threads.size(), 2) + << "Expected num_threads to contain the number of threads in the {kept, " + "reduced} dimensions."; + return CeilOfRatio(reduction_dimensions.dimensions[kRowKept], + num_threads.front()); } IndexingMap MlirMultiRowReductionFusion::ComputeReductionInputIndexing( @@ -1039,5 +1075,26 @@ llvm::SmallVector MlirMultiRowReductionFusion::EmitReduction( group_id, /*symbol_values=*/{}); } +std::unique_ptr CreateMlirReductionFusion( + const HloFusionAnalysis& analysis) { + auto* hero_reduction = analysis.FindHeroReduction(); + CHECK_NE(hero_reduction, nullptr); + ReductionDimensions reduction_dimensions = + GetReductionKindAndContiguousComponents(*hero_reduction); + if (reduction_dimensions.is_row_reduction) { + auto multi_row_emitter = MlirMultiRowReductionFusion::TryCreate(analysis); + if (multi_row_emitter != nullptr) { + return multi_row_emitter; + } + return std::make_unique(analysis); + } + + const int64_t warp_size = analysis.device_info().threads_per_warp(); + if (warp_size % reduction_dimensions.dimensions[kColMinorKept] == 0) { + return 
std::make_unique(analysis); + } + return std::make_unique(analysis); +} + } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.h b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.h index 838729254070ac..a418515215f7b2 100644 --- a/third_party/xla/xla/service/gpu/fusions/reduction_mlir.h +++ b/third_party/xla/xla/service/gpu/fusions/reduction_mlir.h @@ -16,6 +16,7 @@ limitations under the License. #define XLA_SERVICE_GPU_FUSIONS_REDUCTION_MLIR_H_ #include +#include #include #include #include @@ -123,6 +124,10 @@ class MlirReductionFusion : public MlirFusionEmitterBase { return IndexingMap::GetUndefined(); } + int64_t WarpSize() const { + return ::xla::gpu::WarpSize(analysis_.device_info()); + } + // The reduction heroes for each reduction group. std::vector> reduction_heroes_; // The roots that have reduction heroes for each reduction group. @@ -168,9 +173,21 @@ class MlirRowReductionFusion : public MlirReductionFusion { class MlirMultiRowReductionFusion : public MlirReductionFusion { public: - explicit MlirMultiRowReductionFusion(const HloFusionAnalysis& analysis); + MlirMultiRowReductionFusion(const HloFusionAnalysis& analysis, int vector_size); + // Attempts to create a multi-row reduction emitter for the given analysis. + // Returns nullptr if the fusion is not supported. + static std::unique_ptr TryCreate( + const HloFusionAnalysis& analysis); protected: + // Returns the number of {kept, reduced} threads for the given reduction and + // vector size. 
+ static absl::InlinedVector GetNumThreads( + const ReductionDimensions& reduction_dimensions, int vector_size); +static int64_t GetNumBlocks( + const ReductionDimensions& reduction_dimensions, + const absl::InlinedVector& num_threads); + int GetRowsPerWarp() const; llvm::SmallVector EmitReduction( int group_id, EmitterState& state) const override; diff --git a/third_party/xla/xla/service/gpu/fusions/tests/loop/broadcast_constant_block_dim_limit.hlo b/third_party/xla/xla/service/gpu/fusions/tests/loop/broadcast_constant_block_dim_limit.hlo new file mode 100644 index 00000000000000..b225b378acc732 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/loop/broadcast_constant_block_dim_limit.hlo @@ -0,0 +1,15 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt --xla-gpu-test-optimize \ +// RUN: --inline="default-pipeline='cse'" | FileCheck %s + +bcast { + one = bf16[] constant(1) + ROOT broadcast = bf16[24,2048,2048,3,4096]{4,3,2,1,0} broadcast(one), dimensions={} +} + +// CHECK: func.func @main(%[[ARG0:.*]]: tensor<24x2048x2048x3x4096xbf16> +// CHECK: gpu.block_id x {xla.range = [0 : index, 1207959551 : index]} +// CHECK: gpu.block_id y {xla.range = [0 : index, 1 : index]} +// CHECK: xla_gpu.loop ({{.*}})[{{.*}}] -> (%[[RA:.*]], %[[RB:.*]], %[[RC:.*]], %[[RD:.*]], %[[RE:.*]]) in +// CHECK-SAME: iter_args(%[[ITER:.*]] = %[[ARG0]]) +// CHECK: %[[CST:.*]] = arith.constant 1.000 +// CHECK: %[[INSERTED:.*]] = tensor.insert %[[CST]] into %[[ITER]][%[[RA]], %[[RB]], %[[RC]], %[[RD]], %[[RE]]] diff --git a/third_party/xla/xla/service/gpu/fusions/tests/reduce_multirow/f16_v4.hlo b/third_party/xla/xla/service/gpu/fusions/tests/reduce_multirow/f16_v4.hlo new file mode 100644 index 00000000000000..241da1bab33948 --- /dev/null +++ b/third_party/xla/xla/service/gpu/fusions/tests/reduce_multirow/f16_v4.hlo @@ -0,0 +1,22 @@ +// RUN: fusion_to_mlir %s | mlir_fusions_opt -xla-gpu-test-optimize \ +// RUN: -xla-gpu-test-transform-loops | FileCheck %s + +// The reference 
implementation reduces in f64, so we need a larger tolerance. +// RUN: test_correctness %s --bijection_inputs=reduce:0 \ +// RUN: --bijection_outputs=reduce --abs_error_bound=0.005 --rel_error_bound=0.005 + +add { + lhs = f16[] parameter(0) + rhs = f16[] parameter(1) + ROOT add = f16[] add(lhs, rhs) +} + +fusion { + param_0 = f16[2048,64] parameter(0) + c = f16[] constant(0) + ROOT reduce = f16[2048] reduce(param_0, c), dimensions={1}, to_apply=add +} + +// If unvectorized, this would be a regular row reduction. However, since we can +// vectorize to size four, we can emit this as a multi-row reduction. +// CHECK: vector.transfer_read {{.*}} vector<4xf16> \ No newline at end of file diff --git a/third_party/xla/xla/service/gpu/fusions/tools/BUILD b/third_party/xla/xla/service/gpu/fusions/tools/BUILD index 66079f5dcb02b3..72a52fa8c7d020 100644 --- a/third_party/xla/xla/service/gpu/fusions/tools/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/tools/BUILD @@ -11,6 +11,7 @@ xla_cc_binary( visibility = ["//xla/service/gpu/fusions:__subpackages__"], deps = [ "//xla/mlir_hlo", + "//xla/service/gpu:gpu_device_info_for_tests", "//xla/service/gpu/fusions/ir:xla_gpu", "//xla/service/gpu/fusions/mlir:mlir_fusion_emitter", "//xla/service/gpu/fusions/transforms:passes", diff --git a/third_party/xla/xla/service/gpu/fusions/tools/mlir_fusions_opt.cc b/third_party/xla/xla/service/gpu/fusions/tools/mlir_fusions_opt.cc index 0db20fb3a3bbe0..43a1f708286456 100644 --- a/third_party/xla/xla/service/gpu/fusions/tools/mlir_fusions_opt.cc +++ b/third_party/xla/xla/service/gpu/fusions/tools/mlir_fusions_opt.cc @@ -38,6 +38,7 @@ limitations under the License. 
#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/fusions/mlir/mlir_fusion_emitter.h" #include "xla/service/gpu/fusions/transforms/passes.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" int main(int argc, char** argv) { mlir::DialectRegistry registry; @@ -76,7 +77,8 @@ int main(int argc, char** argv) { llvm::function_ref errorHandler) { if (!options.empty()) return mlir::failure(); - xla::gpu::AddLoopTransformationPasses(pm); + xla::gpu::AddLoopTransformationPasses( + pm, xla::gpu::TestGpuDeviceInfo::RTXA6000DeviceInfo()); return mlir::success(); }, [](llvm::function_ref) {}); diff --git a/third_party/xla/xla/service/gpu/fusions/tools/test_lib.cc b/third_party/xla/xla/service/gpu/fusions/tools/test_lib.cc index 11b82ddd517072..cb9f40644c684f 100644 --- a/third_party/xla/xla/service/gpu/fusions/tools/test_lib.cc +++ b/third_party/xla/xla/service/gpu/fusions/tools/test_lib.cc @@ -50,9 +50,9 @@ namespace gpu { absl::StatusOr> LoadTestModule( absl::string_view filename) { auto module = *xla::LoadModuleFromFile(std::string(filename)); - module->mutable_config() - .mutable_debug_options() - .set_xla_gpu_mlir_emitter_level(4); + // module->mutable_config() + // .mutable_debug_options() + // .set_xla_gpu_mlir_emitter_level(4); int num_fusions = absl::c_count_if( module->entry_computation()->instructions(), diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/BUILD b/third_party/xla/xla/service/gpu/fusions/transforms/BUILD index e06494acb6262e..ee8a424c5e93e3 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/BUILD +++ b/third_party/xla/xla/service/gpu/fusions/transforms/BUILD @@ -45,7 +45,6 @@ cc_library( "optimize_loops.cc", "peel_loops.cc", "propagate_slice_indices.cc", - "rewrite_reductions.cc", "simplify_affine.cc", "simplify_arith.cc", "unswitch_loops.cc", diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/lower_xla_gpu_to_scf.cc 
b/third_party/xla/xla/service/gpu/fusions/transforms/lower_xla_gpu_to_scf.cc index be1686164d656f..729b26b8a47212 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/lower_xla_gpu_to_scf.cc +++ b/third_party/xla/xla/service/gpu/fusions/transforms/lower_xla_gpu_to_scf.cc @@ -44,6 +44,7 @@ limitations under the License. #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" #include "xla/service/gpu/fusions/mlir/elemental_hlo_to_mlir.h" +#include "xla/service/gpu/fusions/transforms/passes.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/service/gpu/model/indexing_map.h" #include "xla/util.h" @@ -66,7 +67,9 @@ using mlir::ValueRange; using mlir::scf::IfOp; struct RewritePredicatedInsert : mlir::OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + RewritePredicatedInsert(mlir::MLIRContext* context, + const LowerXlaGpuToScfPassOptions& options) + : OpRewritePattern(context) {} mlir::LogicalResult matchAndRewrite( PredicatedInsertOp op, mlir::PatternRewriter& rewriter) const override { @@ -86,7 +89,9 @@ struct RewritePredicatedInsert : mlir::OpRewritePattern { }; struct RewritePredicatedExtract : mlir::OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + RewritePredicatedExtract(mlir::MLIRContext* context, + const LowerXlaGpuToScfPassOptions& options) + : OpRewritePattern(context) {} mlir::LogicalResult matchAndRewrite( PredicatedExtractOp op, mlir::PatternRewriter& rewriter) const override { @@ -106,15 +111,19 @@ struct RewritePredicatedExtract : mlir::OpRewritePattern { }; struct RewriteShuffleReduce : mlir::OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + const int64_t warp_size; + + RewriteShuffleReduce(mlir::MLIRContext* context, + const LowerXlaGpuToScfPassOptions& options) + : OpRewritePattern(context), warp_size(options.warp_size) {} mlir::LogicalResult matchAndRewrite( ShuffleReduceOp op, mlir::PatternRewriter& rewriter) const override { int 
max_distance = mlir::cast(op->getAttr("max_distance")).getInt(); // TODO(jreiffers): Do this in a verifier. - if (max_distance & (max_distance - 1) || max_distance >= WarpSize()) { - return op->emitOpError("max_distance must be a power of 2 < WarpSize()"); + if (max_distance & (max_distance - 1) || max_distance >= warp_size) { + return op->emitOpError("max_distance must be a power of 2 < warp_size"); } ImplicitLocOpBuilder b(op.getLoc(), rewriter); @@ -123,7 +132,7 @@ struct RewriteShuffleReduce : mlir::OpRewritePattern { namespace ml = mlir::LLVM; auto shuffle_32 = [&](Value v) { return b - .create(v, distance, WarpSize(), + .create(v, distance, warp_size, mlir::gpu::ShuffleMode::DOWN) .getShuffleResult(); }; @@ -259,7 +268,9 @@ mlir::VectorType getThreadLevelVectorType(IndexedVectorType indexed_vector) { } struct RewriteMaterialize : mlir::OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + RewriteMaterialize(mlir::MLIRContext* context, + const LowerXlaGpuToScfPassOptions& options) + : OpRewritePattern(context) {} mlir::LogicalResult matchAndRewrite( MaterializeOp op, mlir::PatternRewriter& rewriter) const override { @@ -316,7 +327,9 @@ struct RewriteMaterialize : mlir::OpRewritePattern { }; struct RewriteInsert : mlir::OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + RewriteInsert(mlir::MLIRContext* context, + const LowerXlaGpuToScfPassOptions& options) + : OpRewritePattern(context) {} mlir::LogicalResult matchAndRewrite( InsertOp op, mlir::PatternRewriter& rewriter) const override { @@ -368,16 +381,22 @@ struct RewriteInsert : mlir::OpRewritePattern { class LowerXlaGpuToScfPass : public impl::LowerXlaGpuToScfPassBase { public: + explicit LowerXlaGpuToScfPass(const LowerXlaGpuToScfPassOptions& options) + : options_(options) {} + void runOnOperation() override { auto* ctx = &getContext(); mlir::RewritePatternSet patterns(ctx); patterns.add(ctx); + RewriteShuffleReduce, RewriteMaterialize, RewriteInsert>( + ctx, options_); if 
(mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { signalPassFailure(); } } + private: + const LowerXlaGpuToScfPassOptions options_; }; class LowerXlaGpuLoopsToScfPass @@ -396,8 +415,11 @@ class LowerXlaGpuLoopsToScfPass } // namespace -std::unique_ptr<::mlir::Pass> CreateLowerXlaGpuToScfPass() { - return std::make_unique(); +std::unique_ptr<::mlir::Pass> CreateLowerXlaGpuToScfPass( + const int64_t warp_size) { + LowerXlaGpuToScfPassOptions options; + options.warp_size = warp_size; + return std::make_unique(options); } std::unique_ptr<::mlir::Pass> CreateLowerXlaGpuLoopsToScfPass() { diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/passes.h b/third_party/xla/xla/service/gpu/fusions/transforms/passes.h index 470a333f70ccca..08db1729cf315c 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/passes.h +++ b/third_party/xla/xla/service/gpu/fusions/transforms/passes.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef XLA_SERVICE_GPU_FUSIONS_TRANSFORMS_PASSES_H_ #define XLA_SERVICE_GPU_FUSIONS_TRANSFORMS_PASSES_H_ +#include #include #include #include @@ -47,13 +48,12 @@ std::unique_ptr CreateFlattenTensorsPass(); std::unique_ptr CreateLowerTensorsPass( bool is_amd_gpu = false, const std::string& gpu_arch = "6.0"); std::unique_ptr CreateLowerToLLVMPass(bool use_rocdl); -std::unique_ptr CreateLowerXlaGpuToScfPass(); +std::unique_ptr CreateLowerXlaGpuToScfPass(int64_t warp_size = 32); std::unique_ptr CreateLowerXlaGpuLoopsToScfPass(); std::unique_ptr CreateMergePointersToSameSlicePass(); std::unique_ptr CreateOptimizeLoopsPass(); std::unique_ptr CreatePeelLoopsPass(); std::unique_ptr CreatePropagateSliceIndicesPass(); -std::unique_ptr CreateRewriteReductionsPass(); std::unique_ptr CreateSimplifyAffinePass(); std::unique_ptr CreateSimplifyArithPass(); std::unique_ptr CreateUnswitchLoopsPass(); diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/passes.td 
b/third_party/xla/xla/service/gpu/fusions/transforms/passes.td index 52a0dacbc3db8f..184af77d9e94aa 100644 --- a/third_party/xla/xla/service/gpu/fusions/transforms/passes.td +++ b/third_party/xla/xla/service/gpu/fusions/transforms/passes.td @@ -177,7 +177,10 @@ def LowerXlaGpuToScfPass : "mlir::gpu::GPUDialect", "mlir::LLVM::LLVMDialect", "mlir::scf::SCFDialect", "mlir::tensor::TensorDialect", "xla::gpu::XlaGpuDialect", "mlir::vector::VectorDialect", ]; - + + let options = [ + Option<"warp_size", "warp_size", "int64_t", /*default=*/"32", "Warp size.">, + ]; let constructor = "CreateLowerXlaGpuToScfPass()"; } @@ -234,6 +237,7 @@ def LowerToLLVMPass : ]; } +/* def RewriteReductionsPass : Pass< "xla-gpu-rewrite-reductions", "mlir::func::FuncOp"> { let summary = "Rewrites reductions to pieces that can efficiently be emitted."; @@ -255,6 +259,7 @@ def RewriteReductionsPass : Pass< let constructor = "CreateRewriteReductionsPass()"; } +*/ def VectorizeLoadsAndStoresPass : Pass<"xla-gpu-vectorize-loads-stores", "mlir::func::FuncOp"> { diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/rewrite_reductions.cc b/third_party/xla/xla/service/gpu/fusions/transforms/rewrite_reductions.cc deleted file mode 100644 index 50969b8bd6bbd8..00000000000000 --- a/third_party/xla/xla/service/gpu/fusions/transforms/rewrite_reductions.cc +++ /dev/null @@ -1,346 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#include -#include -#include -#include -#include - -#include "llvm/ADT/SmallBitVector.h" -#include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/AffineMap.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "xla/service/gpu/fusions/ir/xla_gpu_ops.h" -#include "xla/service/gpu/ir_emission_utils.h" -#include "xla/service/gpu/model/indexing_analysis.h" -#include "xla/service/gpu/model/indexing_map.h" -#include "xla/util.h" - -namespace xla { -namespace gpu { - -#define GEN_PASS_DEF_REWRITEREDUCTIONSPASS -#include "xla/service/gpu/fusions/transforms/passes.h.inc" - -namespace { - -class RewriteReductionsPass - : public impl::RewriteReductionsPassBase { - public: - void runOnOperation() override; -}; - -mlir::ShapedType GetInputType(ReduceOp op) { - return mlir::cast(op.getOperand(0).getType()); -} - -mlir::ShapedType GetOutputType(ReduceOp op) { - return mlir::cast(op.getResult(0).getType()); -} - -int GetNumThreads(mlir::Operation* op) { - auto grid = - op->getParentOfType()->getAttrOfType( - "xla_gpu.launch_grid"); - assert(grid); - return Product(grid.getThreadCounts()); -} - -struct DimensionGroup { - int64_t size; - int64_t stride; - int first_dimension; - int num_dimensions; -}; - -DimensionGroup GetMinorMostReduction(ReduceOp op) { - llvm::ArrayRef dims = op.getDimensions(); - - auto input_ty = GetInputType(op); - DimensionGroup result{1, 1, static_cast(input_ty.getRank()), 0}; - llvm::SmallBitVector reduced_dims(input_ty.getRank()); - for (int64_t dim : dims) { - reduced_dims.set(dim); - } - - // Look for the first group of consecutive reduced 
dimensions and compute the - // stride and size of the group. - bool in_reduction = false; - for (int dim = input_ty.getRank() - 1; - dim >= 0 && (!in_reduction || reduced_dims[dim]); --dim) { - assert(input_ty.getDimSize(dim) > 1 && - "degenerate dimensions are not allowed"); - --result.first_dimension; - if (reduced_dims[dim]) { - in_reduction = true; - result.size *= input_ty.getDimSize(dim); - ++result.num_dimensions; - } else { - result.stride *= input_ty.getDimSize(dim); - } - } - - return result; -} - -llvm::SmallVector ReindexTensors( - mlir::OpBuilder& b, mlir::ValueRange tensors, mlir::ValueRange defaults, - llvm::ArrayRef new_shape, const IndexingMap& map) { - llvm::SmallVector reindexed; - reindexed.reserve(tensors.size()); - for (auto [tensor, def] : llvm::zip(tensors, defaults)) { - auto new_ty = - mlir::cast(tensor.getType()).clone(new_shape); - reindexed.push_back( - b.create(tensor.getLoc(), new_ty, tensor, def, map)); - } - return reindexed; -} - -// Rewrites large row reductions to three reductions: -// 1. to one element per thread. -// 2. to one element per warp. -// 3. to one element per block. -// This also pads the input if the number of threads does not divide the row -// size. 
-struct RewriteRowReduction : mlir::OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite( - ReduceOp op, mlir::PatternRewriter& rewriter) const override { - auto* ctx = op.getContext(); - - auto minor_reduction = GetMinorMostReduction(op); - if (minor_reduction.stride > 1) { - return rewriter.notifyMatchFailure(op, "not a row reduction"); - } - - if (minor_reduction.size <= WarpSize()) { - return rewriter.notifyMatchFailure(op, "small minor dimension"); - } - - int64_t num_threads = GetNumThreads(op); - assert(num_threads % WarpSize() == 0); - - llvm::ArrayRef input_shape = GetInputType(op).getShape(); - auto projected_input_shape = llvm::to_vector( - input_shape.take_front(minor_reduction.first_dimension)); - projected_input_shape.push_back(minor_reduction.size); - - // Collapse the minor dimensions into one. - // [..., 123, 456] -> [..., 123 * 456] - auto projection_map = - GetBitcastMap(projected_input_shape, input_shape, ctx); - - // Pad the new minor dimension to a multiple of the number of threads. For - // example, for 128 threads, 123 * 456 = 56088 is padded to 56192. - auto padded_projected_input_shape = projected_input_shape; - int64_t padded_size = RoundUpTo(minor_reduction.size, num_threads); - padded_projected_input_shape.back() = padded_size; - - // Reshape the padded minor dimension so that we can reduce it per thread - // and then per warp. 
- // [..., 56192] -> [..., 439, 4, 32] - auto per_thread_reduction_input_shape = llvm::to_vector( - input_shape.take_front(minor_reduction.first_dimension)); - per_thread_reduction_input_shape.push_back(padded_size / num_threads); - per_thread_reduction_input_shape.push_back(num_threads / WarpSize()); - per_thread_reduction_input_shape.push_back(WarpSize()); - - int per_thread_input_rank = per_thread_reduction_input_shape.size(); - - auto reindex_map = GetBitcastMap(per_thread_reduction_input_shape, - padded_projected_input_shape, ctx) * - projection_map; - reindex_map.AddConstraint( - mlir::getAffineDimExpr(per_thread_input_rank - 1, ctx) + - mlir::getAffineDimExpr(per_thread_input_rank - 2, ctx) * - num_threads, - {0, minor_reduction.size - 1}); - - auto new_inputs = - ReindexTensors(rewriter, op.getInputs(), op.getInits(), - per_thread_reduction_input_shape, reindex_map); - - // Reduce the non-minor dimensions and the third to last dimension. - auto dims_for_first_reduction = llvm::to_vector( - op.getDimensions().drop_back(minor_reduction.num_dimensions)); - dims_for_first_reduction.push_back(per_thread_input_rank - 3); - auto first_reduction = - rewriter.create(op.getLoc(), new_inputs, op.getInits(), - dims_for_first_reduction, op.getCombiner()); - - // Reduce the last and the second-to-last dimensions. First to produce one - // output element per warp, then to produce one output element per block. - int rank = GetOutputType(first_reduction).getRank(); - auto second_reduction = rewriter.create( - op.getLoc(), first_reduction.getResults(), op.getInits(), - llvm::ArrayRef{rank - 1}, op.getCombiner()); - rewriter.replaceOpWithNewOp( - op, second_reduction.getResults(), op.getInits(), - llvm::ArrayRef{rank - 2}, op.getCombiner()); - - return mlir::success(); - } -}; - -// Rewrites column reductions to a reduce-transpose-reduce. 
-struct RewriteColumnReduction : mlir::OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite( - ReduceOp op, mlir::PatternRewriter& rewriter) const override { - auto* ctx = op.getContext(); - - auto minor_reduction = GetMinorMostReduction(op); - - if (minor_reduction.stride == 1) { - return rewriter.notifyMatchFailure(op, "not a column reduction"); - } - - int64_t num_threads = GetNumThreads(op); - - // If the stride is larger than the number of threads, we can efficiently - // emit this reduction as a simple loop, assuming there's no excessive - // padding. - // TODO(jreiffers): Is there anything we can do if the number of threads - // doesn't divide the stride? - if (minor_reduction.stride >= num_threads) { - return rewriter.notifyMatchFailure(op, "efficient loop reduction"); - } - - // A column reduction reduces [a, b] to [b]. We do this in four steps: - // 1. reshape [a, b] to [a ceildiv c, c, b] - // 2. reduce [a ceildiv c, c, b] to [c, b] via a loop - // 3. transpose [c, b] to [b, c] - // 4. emit a row reduction on [b, c]. - // - // We are constrained in our choice for `c`: - // - // - we need one element of shared memory (or a register) for each element - // of the intermediate results, so a larger c needs more shared memory. - // - we can have at most WarpSize intermediate results per final result, - // so c can be at most 32. - // - c must be a power of two so we can use a warp shuffle. - // - c * b should be less than the number of threads (but as close to it - // as possible, so we don't have excessive padding). - // - // All of this assumes no vectorization. - // TODO(jreiffers): Handle vectorization here. - - // Emitters always choose `c = 32` if `b` is not a small power of two. - // Also, reductions are tiled so `b = 32`. The number of threads is always - // 1024. This satisfies all the constraints above. - // Reduce the size of the reduction dimension. 
The maximum size we can - // handle is the warp size. - - assert(num_threads > minor_reduction.stride); - int64_t c = std::min(WarpSize(), num_threads / minor_reduction.stride); - - llvm::ArrayRef input_shape = GetInputType(op).getShape(); - auto projected_input_shape = llvm::to_vector( - input_shape.take_front(minor_reduction.first_dimension)); - projected_input_shape.push_back(minor_reduction.size); - projected_input_shape.push_back(minor_reduction.stride); - auto projection_map = - GetBitcastMap(projected_input_shape, input_shape, ctx); - int64_t projected_rank = projected_input_shape.size(); - - // Pad the new minor dimension to a multiple of c. - auto padded_projected_input_shape = projected_input_shape; - int64_t padded_size = RoundUpTo(minor_reduction.size, c); - padded_projected_input_shape[projected_rank - 2] = padded_size; - - // Reshape the input to [..., a ceildiv c, c, b] - auto reshaped_input_shape = llvm::to_vector( - input_shape.take_front(minor_reduction.first_dimension)); - reshaped_input_shape.push_back(padded_size / c); - reshaped_input_shape.push_back(c); - reshaped_input_shape.push_back(minor_reduction.stride); - int64_t reshaped_rank = reshaped_input_shape.size(); - - auto reindex_map = - GetBitcastMap(reshaped_input_shape, padded_projected_input_shape, ctx) * - projection_map; - reindex_map.AddConstraint( - mlir::getAffineDimExpr(reshaped_rank - 2, ctx) + - mlir::getAffineDimExpr(reshaped_rank - 3, ctx) * c, - {0, minor_reduction.size - 1}); - - auto new_inputs = ReindexTensors(rewriter, op.getInputs(), op.getInits(), - reshaped_input_shape, reindex_map); - - // Reduce the non-minor dimensions and the third to last dimension. 
- // [..., a ceildiv c, c, b] -> [..., c, b] - auto dims_for_first_reduction = llvm::to_vector( - op.getDimensions().drop_back(minor_reduction.num_dimensions)); - dims_for_first_reduction.push_back(reshaped_rank - 3); - auto first_reduction = - rewriter.create(op.getLoc(), new_inputs, op.getInits(), - dims_for_first_reduction, op.getCombiner()); - - // Transpose [..., c, b] to [..., b, c] - auto shape = GetOutputType(first_reduction).getShape(); - int64_t first_reduction_rank = shape.size(); - llvm::SmallVector permutation(first_reduction_rank); - std::iota(permutation.begin(), permutation.end(), 0); - std::swap(permutation[first_reduction_rank - 1], - permutation[first_reduction_rank - 2]); - - auto transposed_shape = llvm::to_vector(shape); - std::swap(transposed_shape[first_reduction_rank - 1], - transposed_shape[first_reduction_rank - 2]); - IndexingMap transpose_map( - mlir::AffineMap::getPermutationMap(permutation, ctx), - DimVarsFromTensorSizes(transposed_shape), {}, {}); - - auto transposed = - ReindexTensors(rewriter, first_reduction.getResults(), op.getInits(), - transposed_shape, transpose_map); - - rewriter.replaceOpWithNewOp( - op, transposed, op.getInits(), - llvm::ArrayRef{first_reduction_rank - 1}, op.getCombiner()); - return mlir::success(); - } -}; - -void RewriteReductionsPass::runOnOperation() { - mlir::RewritePatternSet patterns(&getContext()); - patterns.add(&getContext()); - if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { - signalPassFailure(); - } -} - -} // namespace - -std::unique_ptr> -CreateRewriteReductionsPass() { - return std::make_unique(); -} - -} // namespace gpu -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/fusions/transforms/tests/rewrite_reductions.mlir b/third_party/xla/xla/service/gpu/fusions/transforms/tests/rewrite_reductions.mlir deleted file mode 100644 index 94c6cddd4a8a40..00000000000000 --- 
a/third_party/xla/xla/service/gpu/fusions/transforms/tests/rewrite_reductions.mlir +++ /dev/null @@ -1,93 +0,0 @@ -// RUN: mlir_fusions_opt %s -split-input-file -xla-gpu-rewrite-reductions | \ -// RUN: FileCheck %s - -func.func @add(%a: f32, %b: f32) -> f32 { - %0 = arith.addf %a, %b : f32 - return %0 : f32 -} - -func.func @row_reduction(%arg0: tensor<128x1027xf32>) - -> tensor<128xf32> attributes { - xla_gpu.launch_grid = #xla_gpu.launch_grid< - block_counts = [42, 1, 1], - thread_counts = [128, 1, 1] - > - } { - %c0 = arith.constant 0.0 : f32 - %0 = xla_gpu.reduce (%arg0) inits(%c0) dimensions=[1] combiner=@add - : tensor<128x1027xf32> to tensor<128xf32> - return %0 : tensor<128xf32> -} - -// CHECK: #[[$PAD_AND_RESHAPE:.*]] = #xla_gpu.indexing_map<(d0, d1, d2, d3) -> (d0, d1 * 128 + d2 * 32 + d3), -// CHECK-SAME: domain: d0 in [0, 127], d1 in [0, 8], d2 in [0, 3], d3 in [0, 31], d1 * 128 + d2 * 32 + d3 in [0, 1026] -// CHECK-LABEL: @row_reduction -// CHECK-SAME: %[[IN:.*]]: tensor<128x1027xf32> -// CHECK: %[[C0:.*]] = arith.constant 0.00 -// CHECK: %[[REINDEXED:.*]] = xla_gpu.reindex %[[IN]] at #[[$PAD_AND_RESHAPE]] default %[[C0]] -// CHECK: %[[R1:.*]] = xla_gpu.reduce(%[[REINDEXED]]) inits(%[[C0]]) dimensions=[1] combiner=@add -// CHECK: %[[R2:.*]] = xla_gpu.reduce(%[[R1]]) inits(%[[C0]]) dimensions=[2] combiner=@add -// CHECK: %[[R3:.*]] = xla_gpu.reduce(%[[R2]]) inits(%[[C0]]) dimensions=[1] combiner=@add -// CHECK: return %[[R3]] : tensor<128xf32> - -// ----- - -func.func @add(%a: f32, %b: f32) -> f32 { - %0 = arith.addf %a, %b : f32 - return %0 : f32 -} - -func.func @row_reduction_with_major_reduced_dim(%arg0: tensor<2x42x128x32x8xf32>) - -> tensor<2x128xf32> attributes { - xla_gpu.launch_grid = #xla_gpu.launch_grid< - block_counts = [42, 1, 1], - thread_counts = [128, 1, 1] - > - } { - %c0 = arith.constant 0.0 : f32 - %0 = xla_gpu.reduce (%arg0) inits(%c0) dimensions=[1, 3, 4] combiner=@add - : tensor<2x42x128x32x8xf32> to tensor<2x128xf32> - return %0 : 
tensor<2x128xf32> -} - -// CHECK-LABEL: @row_reduction_with_major_reduced_dim -// CHECK: %[[REINDEXED:.*]] = xla_gpu.reindex -// CHECK-SAME: : tensor<2x42x128x32x8xf32> -> tensor<2x42x128x2x4x32xf32> -// CHECK: xla_gpu.reduce(%[[REINDEXED]]) -// CHECK-SAME: dimensions=[1, 3] -// CHECK-SAME: : tensor<2x42x128x2x4x32xf32> - -// ----- - -func.func @add(%a: f32, %b: f32) -> f32 { - %0 = arith.addf %a, %b : f32 - return %0 : f32 -} - -func.func @column(%arg0: tensor<2x32x32xf32>) - -> tensor<2x32xf32> attributes { - xla_gpu.launch_grid = #xla_gpu.launch_grid< - block_counts = [42, 1, 1], - thread_counts = [128, 1, 1] - > - } { - %c0 = arith.constant 0.0 : f32 - %0 = xla_gpu.reduce (%arg0) inits(%c0) dimensions=[1] combiner=@add - : tensor<2x32x32xf32> to tensor<2x32xf32> - return %0 : tensor<2x32xf32> -} - -// CHECK: #[[$RESHAPE:.*]] = #xla_gpu.indexing_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3) -// CHECK-SAME: d1 * 4 + d2 in [0, 31] -// CHECK: #[[$TRANSPOSE:.*]] = #xla_gpu.indexing_map<(d0, d1, d2) -> (d0, d2, d1) -// CHECK-LABEL: @column -// CHECK-SAME: %[[IN:.*]]: tensor<2x32x32xf32> -// CHECK: %[[C0:.*]] = arith.constant 0.00 -// CHECK: %[[REINDEXED:.*]] = xla_gpu.reindex %[[IN]] at #[[$RESHAPE]] default %[[C0]] -// CHECK-SAME: -> tensor<2x8x4x32xf32> -// CHECK: %[[R1:.*]] = xla_gpu.reduce(%[[REINDEXED]]) inits(%[[C0]]) dimensions=[1] -// CHECK-SAME: to tensor<2x4x32xf32> -// CHECK: %[[TRANSPOSED:.*]] = xla_gpu.reindex %[[R1]] at #[[$TRANSPOSE]] -// CHECK-SAME: -> tensor<2x32x4xf32> -// CHECK: %[[R2:.*]] = xla_gpu.reduce(%[[TRANSPOSED]]) inits(%[[C0]]) dimensions=[2] -// CHECK: return %[[R2]] : tensor<2x32xf32> diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc index fd18cef310a8fb..7f54fa2f341316 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc +++ b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.cc @@ -74,7 +74,6 @@ using mlir::func::ReturnOp; using 
mlir_converter::ApplyIndexing; constexpr int kNumRows = 4; -constexpr int kBaseBlockSize = WarpSize(); constexpr int kNumThreadsPerBlock = 128; constexpr int kMaxVectorizedBytes = 4; @@ -85,7 +84,8 @@ MlirTransposeFusion::MlirTransposeFusion(const HloFusionAnalysis& analysis) transpose_(analysis.tiled_transpose()), permutation_(transpose_.permutation), input_shape_( - Permute(transpose_.dimensions, InversePermutation(permutation_))) { + Permute(transpose_.dimensions, InversePermutation(permutation_))), + base_block_size_(WarpSize(analysis_.device_info())) { ConstHloInstructionSet transposes_to_tile; int index = 0; int64_t shmem_usage = 0; @@ -103,7 +103,7 @@ MlirTransposeFusion::MlirTransposeFusion(const HloFusionAnalysis& analysis) size *= input_shape_.back(); } max_element_bytes = std::max(max_element_bytes, size); - shmem_usage += kBaseBlockSize * (kBaseBlockSize + 1) * size; + shmem_usage += base_block_size_ * (base_block_size_ + 1) * size; shmem_transpose_root_indices_.push_back(index); } else { side_output_roots_.push_back(&root.instruction()); @@ -117,12 +117,12 @@ MlirTransposeFusion::MlirTransposeFusion(const HloFusionAnalysis& analysis) vector_size_ = vector_size; block_sizes_.assign(input_shape_.size(), 1); if (MostMinorDimensionUnchanged()) { - block_size_ = kBaseBlockSize; + block_size_ = base_block_size_; block_sizes_.back() = vector_size_; block_sizes_[block_sizes_.size() - 2] = block_size_; block_sizes_[permutation_[block_sizes_.size() - 2]] = block_size_; } else { - block_size_ = kBaseBlockSize * vector_size_; + block_size_ = base_block_size_ * vector_size_; block_sizes_.back() = block_size_; block_sizes_[permutation_.back()] = block_size_; } diff --git a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h index 9602242fe4745a..6fcb666ab2abbe 100644 --- a/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h +++ b/third_party/xla/xla/service/gpu/fusions/transpose_mlir.h @@ -109,7 
+109,8 @@ class MlirTransposeFusion : public MlirFusionEmitterBase { std::vector block_counts_; int vector_size_; int block_size_; - + int64_t base_block_size_; + std::vector shmem_transposes_; std::vector shmem_transpose_roots_; std::vector shmem_transpose_root_indices_; diff --git a/third_party/xla/xla/service/gpu/fusions/triton.cc b/third_party/xla/xla/service/gpu/fusions/triton.cc index 7d235c132989c4..0b3bbd0eb421f6 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton.cc @@ -174,7 +174,8 @@ absl::StatusOr TritonFusion::Emit( TF_ASSIGN_OR_RETURN( launch_dimensions, - GetMatMulLaunchDimensions(analysis, analysis_.fusion(), config)); + GetMatMulLaunchDimensions(analysis, analysis_.fusion(), config, + analysis_.device_info())); } llvm::Function* impl_fn = @@ -233,7 +234,8 @@ std::optional TritonFusion::launch_config() const { LaunchConfig launch_config; launch_config.launch_dimensions = LaunchDimensions{ static_cast(num_blocks), - static_cast(block_level_parameters.num_warps * WarpSize())}; + static_cast(block_level_parameters.num_warps * + WarpSize(analysis_.device_info()))}; launch_config.block_level_parameters = std::move(block_level_parameters); return launch_config; } diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc index e9974d3ce1584f..f72483a80fe01c 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.cc @@ -1290,7 +1290,8 @@ struct MatMulDims { struct MatMulLaunchConfig { explicit MatMulLaunchConfig(const TritonGemmConfig& config, const HloDotInstruction& dot, - const MatMulDims& dims); + const MatMulDims& dims, + const se::DeviceDescription& device_info); int64_t grid_m; int64_t grid_n; @@ -1387,7 +1388,8 @@ struct MatMulLaunchConfig { MatMulLaunchConfig::MatMulLaunchConfig(const 
TritonGemmConfig& config, const HloDotInstruction& dot, - const MatMulDims& dims) + const MatMulDims& dims, + const se::DeviceDescription& device_info) : grid_m((dims.m + config.block_m - 1) / config.block_m), grid_n((dims.n + config.block_n - 1) / config.block_n) { int64_t batch_size = dims.lhs_noncontracting_split.value_or( @@ -1409,13 +1411,13 @@ MatMulLaunchConfig::MatMulLaunchConfig(const TritonGemmConfig& config, noncontracting_program_id_dim = mt::ProgramIDDim::Y; launch_dims = LaunchDimensions( se::BlockDim(batch_size, grid_m * grid_n, config.split_k), - se::ThreadDim(config.num_warps * WarpSize(), 1, 1)); + se::ThreadDim(config.num_warps * WarpSize(device_info), 1, 1)); } else { batch_program_id_dim = mt::ProgramIDDim::Y; noncontracting_program_id_dim = mt::ProgramIDDim::X; launch_dims = LaunchDimensions( se::BlockDim(grid_m * grid_n, batch_size, config.split_k), - se::ThreadDim(config.num_warps * WarpSize(), 1, 1)); + se::ThreadDim(config.num_warps * WarpSize(device_info), 1, 1)); } } @@ -1962,7 +1964,7 @@ class MatMulEmitterHelper { absl::StatusOr GetMatMulLaunchDimensions( const TritonFusionAnalysis& analysis, const HloFusionAdaptor& fusion, - const TritonGemmConfig& config) { + const TritonGemmConfig& config, const se::DeviceDescription& device_info) { auto dot = HloBfsFindIf(fusion.GetRoots(), fusion, [](auto node) { return node.opcode() == HloOpcode::kDot; }); @@ -1971,7 +1973,7 @@ absl::StatusOr GetMatMulLaunchDimensions( *static_cast(&dot->instruction()); TF_ASSIGN_OR_RETURN(MatMulDims dims, MatMulDims::Create(config, dot_instr, analysis)); - MatMulLaunchConfig launch_config(config, dot_instr, dims); + MatMulLaunchConfig launch_config(config, dot_instr, dims, device_info); return launch_config.launch_dims; } @@ -2516,7 +2518,7 @@ absl::Status EmitMatMul(mlir::OpBuilder builder, TF_ASSIGN_OR_RETURN(const MatMulDims dims, MatMulDims::Create(config, *dot_instr, analysis)); - const MatMulLaunchConfig launch_config(config, *dot_instr, dims); + const 
MatMulLaunchConfig launch_config(config, *dot_instr, dims, device_info); VLOG(6) << analysis.ToString(); MatMulEmitterHelper emitter(libdevice_path, device_info, dot_instr, b, diff --git a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.h b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.h index 9c7cd49cd3d862..08909175cb23a9 100644 --- a/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.h +++ b/third_party/xla/xla/service/gpu/fusions/triton/triton_fusion_emitter.h @@ -77,7 +77,7 @@ absl::Status EmitGeneric(mlir::OpBuilder b, absl::string_view libdevice_path, // Compute the launch dimensions for the given Triton MatMul. absl::StatusOr GetMatMulLaunchDimensions( const TritonFusionAnalysis& analysis, const HloFusionAdaptor& fusion, - const TritonGemmConfig& config); + const TritonGemmConfig& config, const se::DeviceDescription& device_info); // Use tiling and execution parameters from 'config'. output_tile_sizes is // ignored. diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc index 0db21cdf7173d2..281fcb1d7003c0 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.cc +++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc @@ -976,7 +976,8 @@ absl::Status RunCollectiveOptimizationPasses( absl::Status RunLayoutAssignmentPasses(HloModule* hlo_module, se::GpuComputeCapability gpu_version, - se::dnn::VersionInfo dnn_version) { + se::dnn::VersionInfo dnn_version, + const se::DeviceDescription& device_description) { // Run layout assignment in a separate pipeline from // "post-layout-assignment" because we want everything after layout // assignment to have a layout-sensitive invariant-checker, but @@ -990,7 +991,7 @@ absl::Status RunLayoutAssignmentPasses(HloModule* hlo_module, ChannelLayoutConstraints layout_constraints; pipeline.AddPass( hlo_module->mutable_entry_computation_layout(), gpu_version, dnn_version, - &layout_constraints); + 
device_description, &layout_constraints); // Run SubByteNormalization because GpuLayoutAssignment may modify a // Layout's element_size_in_bits field. pipeline.AddPass( @@ -1151,7 +1152,8 @@ absl::Status RunPostFusionCollectiveOptimizationPasses(HloModule* hlo_module) { absl::Status RunPostFusionSimplificationPasses( HloModule* hlo_module, const AlgebraicSimplifierOptions& layout_insensitive_algsimp_opts, - se::GpuComputeCapability gpu_version) { + se::GpuComputeCapability gpu_version, + const Compiler::TargetConfig& gpu_target_config) { HloPassPipeline pipeline("post-fusion-simplification-pipeline optimization"); AlgebraicSimplifierOptions options = layout_insensitive_algsimp_opts; options.set_is_layout_sensitive(true); @@ -1166,7 +1168,8 @@ absl::Status RunPostFusionSimplificationPasses( if (hlo_module->config() .debug_options() .xla_gpu_multi_streamed_windowed_einsum()) { - pipeline.AddPass(); + pipeline.AddPass( + gpu_target_config.device_description); pipeline.AddPass(); } @@ -1314,7 +1317,8 @@ absl::Status GpuCompiler::OptimizeHloModule( gpu_target_config.device_description.runtime_version())); TF_RETURN_IF_ERROR( - RunLayoutAssignmentPasses(hlo_module, gpu_version, dnn_version)); + RunLayoutAssignmentPasses(hlo_module, gpu_version, dnn_version, + gpu_target_config.device_description)); TF_RETURN_IF_ERROR(RunLayoutNormalizationPasses(hlo_module, gpu_version)); @@ -1338,7 +1342,8 @@ absl::Status GpuCompiler::OptimizeHloModule( })); TF_RETURN_IF_ERROR(RunPostFusionCollectiveOptimizationPasses(hlo_module)); TF_RETURN_IF_ERROR(RunPostFusionSimplificationPasses( - hlo_module, layout_insensitive_algsimp_opts, gpu_version)); + hlo_module, layout_insensitive_algsimp_opts, gpu_version, + gpu_target_config)); TF_RETURN_IF_ERROR(RunPostFusionVerificationPasses( hlo_module, stream_exec, options, gpu_target_config)); @@ -1361,8 +1366,11 @@ AlgebraicSimplifierOptions GpuCompiler::GetAlgebraicSimplifierOptions( // Modifies the given HLO module so that it will be accepted by 
IrEmitter. // Unlike optimization passes, the passes are necessary for correctness. -absl::Status GpuCompiler::PrepareHloModuleForIrEmitting(HloModule* hlo_module) { - return PrepareHloModuleForIrEmittingPipeline(*hlo_module, GetCanShareBuffer()) +absl::Status GpuCompiler::PrepareHloModuleForIrEmitting( + HloModule* hlo_module, const se::DeviceDescription& device_description) { + return PrepareHloModuleForIrEmittingPipeline( + *hlo_module, GetCanShareBuffer(device_description), + device_description) .Run(hlo_module) .status(); } @@ -1451,7 +1459,8 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( pipeline.AddPass(); pipeline.AddPass([&](const HloInstruction* r) { - return IsReductionFromOrToContiguousDimensions(*r); + return IsReductionFromOrToContiguousDimensions( + *r, gpu_target_config.device_description); }); // Greedy pattern matching for custom kernel fusions. We run it before @@ -1535,10 +1544,13 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment( pipeline.AddPass(); // Do not split small reduction dimensions unless priority fusion is // enabled, which handles such cases well. - bool ignore_small_reduce_dims = - !debug_options.xla_gpu_enable_priority_fusion(); - pipeline.AddPass>(ignore_small_reduce_dims); - pipeline.AddPass>(gpu_version); + // bool ignore_small_reduce_dims = + // !debug_options.xla_gpu_enable_priority_fusion(); + pipeline.AddPass>( + gpu_target_config.device_description, + /*ignore_small_reduce_dims=*/false); + pipeline.AddPass>( + gpu_target_config.device_description); TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } @@ -1695,7 +1707,8 @@ absl::StatusOr> GpuCompiler::RunHloPasses( is_deviceless ? 
nullptr : stream_exec, options, gpu_target_config)); - TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(module.get())); + TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting( + module.get(), gpu_target_config.device_description)); uint64_t end_usecs = tsl::Env::Default()->NowMicros(); @@ -2160,7 +2173,7 @@ GpuCompiler::CompileToBackendResult( const se::DeviceDescription& gpu_device_info) { tsl::profiler::TraceMe traceme("GpuCompiler::CompileToBackendResult"); - TF_RETURN_IF_ERROR(RunPreSchedulingPasses(module, executor)); + TF_RETURN_IF_ERROR(RunPreSchedulingPasses(module, executor, gpu_device_info)); TF_ASSIGN_OR_RETURN( ScheduleMetadata schedule_metadata, ScheduleGpuModule(module, pointer_size_, gpu_device_info)); @@ -2190,7 +2203,8 @@ GpuCompiler::CompileToBackendResult( CompileModuleResults compile_module_results, CompileModuleToLlvmIr(module, llvm_context, target_triple_, data_layout_, platform->Name(), platform->id(), gpu_device_info, - GetCanShareBuffer(), BufferSizeBytesFunction(), + GetCanShareBuffer(gpu_device_info), + BufferSizeBytesFunction(), /*split_constants_module=*/use_cache)); if (user_pre_optimization_hook_) { @@ -2454,9 +2468,10 @@ absl::StatusOr> GpuCompiler::Export( } absl::Status GpuCompiler::RunPreSchedulingPasses( - HloModule* module, se::StreamExecutor* stream_exec) { + HloModule* module, se::StreamExecutor* stream_exec, + const se::DeviceDescription& gpu_device_info) { HloPassPipeline pipeline("pre-scheduling-passes"); - pipeline.AddPass(); + pipeline.AddPass(gpu_device_info); return pipeline.Run(module).status(); } @@ -2525,7 +2540,8 @@ absl::Status GpuCompiler::RunPostSchedulingPipelines( HloModule* module, int64_t scheduler_mem_limit, const se::DeviceDescription& gpu_device_info) const { TF_RETURN_IF_ERROR( - RunPostSchedulingCopyInsertion(module, GetCanShareBuffer())); + RunPostSchedulingCopyInsertion( + module, GetCanShareBuffer(gpu_device_info))); HloPassPipeline main_pipeline("post-scheduling-passes"); // Pipeline for async -> sync 
conversion on for non-overlapped async ops. @@ -2559,7 +2575,7 @@ absl::Status GpuCompiler::RunPostSchedulingPipelines( main_pipeline.AddPass("remat-pipeline"); pipeline.AddPass(remat_opts, sizes); - pipeline.AddPass(); + pipeline.AddPass(gpu_device_info); pipeline.AddPass(); } @@ -2569,7 +2585,7 @@ absl::Status GpuCompiler::RunPostSchedulingPipelines( { HloPassPipeline& pipeline = main_pipeline.AddPass("fusion-wrapper"); - pipeline.AddPass(); + pipeline.AddPass(gpu_device_info); } // Pipeline with passes which wrap a scheduled module into command buffers. diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.h b/third_party/xla/xla/service/gpu/gpu_compiler.h index b18b48abfcc4d9..168f3e30c2e63e 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.h +++ b/third_party/xla/xla/service/gpu/gpu_compiler.h @@ -114,8 +114,13 @@ class GpuCompiler : public LLVMCompiler { const Compiler::CompileOptions& options, const DebugOptions& debug_opts, se::StreamExecutor* executor); - virtual HloDataflowAnalysis::CanShareBuffer GetCanShareBuffer() const { - return &FusionCanShareBufferHint; + virtual HloDataflowAnalysis::CanShareBuffer GetCanShareBuffer( + const se::DeviceDescription& device_description) const { + return [&](const HloInstruction* user, const HloInstruction* operand, + const ShapeIndex& user_index) { + return FusionCanShareBufferHint(user, operand, user_index, + device_description); + }; } virtual absl::StatusOr CanUseLinkModules( @@ -212,7 +217,8 @@ class GpuCompiler : public LLVMCompiler { const DebugOptions& debug_options); absl::Status RunPreSchedulingPasses(HloModule* module, - se::StreamExecutor* stream_exec); + se::StreamExecutor* stream_exec, + const se::DeviceDescription& gpu_device_info); absl::Status RunCollectiveScheduleLinearizerPasses( HloModule* hlo_module, se::StreamExecutor* stream_exec); @@ -237,7 +243,8 @@ class GpuCompiler : public LLVMCompiler { se::GpuComputeCapability gpu_version, bool relocatable, const HloModule* debug_module, const 
CompileOptions& options) = 0; - absl::Status PrepareHloModuleForIrEmitting(HloModule* hlo_module); + absl::Status PrepareHloModuleForIrEmitting( + HloModule* hlo_module, const se::DeviceDescription& device_description); virtual absl::StatusOr> LinkModules( se::GpuComputeCapability gpu_compute_capability, diff --git a/third_party/xla/xla/service/gpu/gpu_copy_insertion_test.cc b/third_party/xla/xla/service/gpu/gpu_copy_insertion_test.cc index 874fa087af7ba2..2eb2c9a3c93fa7 100644 --- a/third_party/xla/xla/service/gpu/gpu_copy_insertion_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_copy_insertion_test.cc @@ -19,12 +19,14 @@ limitations under the License. #include "absl/log/check.h" #include "absl/log/log.h" +#include "xla/hlo/analysis/hlo_dataflow_analysis.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/copy_insertion.h" #include "xla/service/gpu/buffer_sharing.h" +#include "xla/service/gpu/gpu_device_info_for_tests.h" #include "xla/test.h" #include "xla/test_helpers.h" #include "xla/tests/hlo_test_base.h" @@ -64,7 +66,38 @@ void ExpectOptionalFalse(std::optional value) { EXPECT_FALSE(*value); } -using GpuCopyInsertionTest = HloTestBase; +class CanShareBufferWrapper { + public: + CanShareBufferWrapper() + : can_share_buffer_([&](const HloInstruction* fusion, + const HloInstruction* operand, + const ShapeIndex& user_index) { + return FusionCanShareBufferHint(fusion, operand, user_index, + device_description_); + }) {} + + HloDataflowAnalysis::CanShareBuffer GetCanShareBuffer() const { + return can_share_buffer_; + } + + private: + const se::DeviceDescription device_description_{ + xla::gpu::TestGpuDeviceInfo::CudaOrRocmDeviceInfo()}; + const HloDataflowAnalysis::CanShareBuffer can_share_buffer_; + }; + + class GpuCopyInsertionTest : public HloTestBase { + public: + using HloTestBase::HloTestBase; + + CopyInsertion CreateCopyInsertion() 
const { + return CopyInsertion(can_share_buffer_wrapper_.GetCanShareBuffer(), + /*use_region_based_live_range_analysis=*/0); + } + + private: + const CanShareBufferWrapper can_share_buffer_wrapper_; + }; // This is some kind of end-to-end test for FusionCanShareBufferHint. TEST_F(GpuCopyInsertionTest, DUSBitcastNoCopy) { @@ -116,8 +149,7 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kModuleString)); - CopyInsertion copy_insertion(FusionCanShareBufferHint, - /*use_region_based_live_range_analysis=*/0); + CopyInsertion copy_insertion = CreateCopyInsertion(); ASSERT_IS_OK(copy_insertion.Run(module.get(), {"foobar"}).status()); VLOG(2) << module->ToString(); // Copy insertion adds two copies inside the entry computation. @@ -127,7 +159,21 @@ ENTRY main { EXPECT_EQ(CountCopies(*module), 2); } -using FusionCanShareBufferHintTest = HloTestBase; +class FusionCanShareBufferHintTest : public HloTestBase { + public: + FusionCanShareBufferHintTest() + : can_share_buffer_(can_share_buffer_wrapper_.GetCanShareBuffer()) {} + + std::optional FusionCanShareBufferHint(const HloInstruction* fusion, + const HloInstruction* operand, + const ShapeIndex& user_index) { + return can_share_buffer_(fusion, operand, user_index); + } + + private: + const CanShareBufferWrapper can_share_buffer_wrapper_; + const HloDataflowAnalysis::CanShareBuffer can_share_buffer_; +}; TEST_F(FusionCanShareBufferHintTest, BufferCanBeSharedSameShape) { const char* const kModuleString = R"( @@ -990,8 +1036,8 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kModuleString)); - CopyInsertion copy_insertion(FusionCanShareBufferHint, - /*use_region_based_live_range_analysis=*/0); + CopyInsertion copy_insertion = CreateCopyInsertion(); + ASSERT_IS_OK(copy_insertion.Run(module.get(), {"foobar"}).status()); VLOG(2) << module->ToString(); EXPECT_EQ(CountCopies(*module), 0); diff --git a/third_party/xla/xla/service/gpu/gpu_fusible.cc 
b/third_party/xla/xla/service/gpu/gpu_fusible.cc index 94e67e43c1adb6..712b5069eccb6e 100644 --- a/third_party/xla/xla/service/gpu/gpu_fusible.cc +++ b/third_party/xla/xla/service/gpu/gpu_fusible.cc @@ -101,13 +101,14 @@ bool IfFusedReadsElementsMultipleTimes(const HloInstruction& instr) { return false; } -bool IsExpensiveToRepeat(const HloInstruction& instr) { +bool IsExpensiveToRepeat(const HloInstruction& instr, + const se::DeviceDescription& device_info) { CHECK_NE(instr.opcode(), HloOpcode::kFusion) << "`instr` has to be unfused."; // Reductions which use many input elements to calculate one output element // are both memory and computationally heavy. constexpr int kMaxInputsPerOutput = 10; if (instr.opcode() == HloOpcode::kReduce && - !IsReductionFromOrToContiguousDimensions(instr)) { + !IsReductionFromOrToContiguousDimensions(instr, device_info)) { int64_t reduction_ratio = ShapeUtil::ElementsIn(instr.operand(0)->shape()) / ShapeUtil::ElementsIn(instr.shape()); if (reduction_ratio > kMaxInputsPerOutput) return true; @@ -192,24 +193,27 @@ bool TransposesMinorDimension(const HloInstruction* instr) { } } -bool IsReduceInputFusion(const HloInstruction& instr) { +bool IsReduceInputFusion(const HloInstruction& instr, + const se::DeviceDescription& device_info) { return instr.opcode() == HloOpcode::kFusion && absl::c_any_of(GetFusionRoots(*instr.called_computations()[0]), - [](const HloInstruction* root) { - return IsRealReductionHero(*root, - FindNonTrivialHero(*root)); + [&](const HloInstruction* root) { + return IsRealReductionHero( + *root, FindNonTrivialHero(*root), device_info); }); } -bool IsInputFusibleReduction(const HloInstruction& instr) { - return IsReduceInputFusion(instr) || - IsReductionFromOrToContiguousDimensions(instr); +bool IsInputFusibleReduction(const HloInstruction& instr, + const se::DeviceDescription& device_info) { + return IsReduceInputFusion(instr, device_info) || + IsReductionFromOrToContiguousDimensions(instr, device_info); } -bool 
IsNestableVariadicReduction(const HloInstruction& instr) { +bool IsNestableVariadicReduction(const HloInstruction& instr, + const se::DeviceDescription& device_info) { return instr.shape().IsTuple() && ((instr.opcode() == HloOpcode::kReduce && - !IsReductionFromOrToContiguousDimensions(instr)) || + !IsReductionFromOrToContiguousDimensions(instr, device_info)) || (instr.opcode() == HloOpcode::kFusion && instr.fusion_kind() == HloInstruction::FusionKind::kLoop && instr.fused_expression_root()->opcode() == HloOpcode::kReduce)); @@ -226,14 +230,14 @@ bool IsInputFusibleTranspose(const HloInstruction& instr) { } const HloInstruction* GetRealHeroForMultiOutputFusion( - const HloInstruction& instr) { + const HloInstruction& instr, const se::DeviceDescription& device_info) { if (instr.opcode() != HloOpcode::kFusion) { return &instr; } auto fused_expression_root = instr.fused_expression_root(); if (!instr.IsMultiOutputFusion()) { const auto& hero = FindNonTrivialHero(*fused_expression_root); - if (IsRealReductionHero(*fused_expression_root, hero) || + if (IsRealReductionHero(*fused_expression_root, hero, device_info) || GetDescriptionForTiledTransposeEmitter(hero).has_value()) { return &hero; } @@ -245,7 +249,7 @@ const HloInstruction* GetRealHeroForMultiOutputFusion( // we find any, we can immediately return it. 
for (auto* inst : fused_expression_root->mutable_operands()) { const auto& hero = FindNonTrivialHero(*inst); - if (IsRealReductionHero(*inst, hero) || + if (IsRealReductionHero(*inst, hero, device_info) || GetDescriptionForTiledTransposeEmitter(hero).has_value()) { return &hero; } @@ -253,14 +257,15 @@ const HloInstruction* GetRealHeroForMultiOutputFusion( return fused_expression_root->operands()[0]; } -FusionDecision FusionHeroesAreCompatible(const HloInstruction* hero1, - const HloInstruction* hero2) { +FusionDecision FusionHeroesAreCompatible( + const HloInstruction* hero1, const HloInstruction* hero2, + const se::DeviceDescription& device_info) { auto hero1_is_unnested_reduce = - IsReductionFromOrToContiguousDimensions(*hero1); + IsReductionFromOrToContiguousDimensions(*hero1, device_info); auto tiled_transpose_hero1 = GetDescriptionForTiledTransposeEmitter(*hero1); bool hero1_is_unnested_transpose = tiled_transpose_hero1.has_value(); bool hero2_is_unnested_reduce = - IsReductionFromOrToContiguousDimensions(*hero2); + IsReductionFromOrToContiguousDimensions(*hero2, device_info); auto tiled_transpose_hero2 = GetDescriptionForTiledTransposeEmitter(*hero2); bool hero2_is_unnested_transpose = tiled_transpose_hero2.has_value(); @@ -318,7 +323,8 @@ FusionDecision FusionHeroesAreCompatible(const HloInstruction* hero1, } FusionDecision ShapesCompatibleForMultiOutputFusion( - const HloInstruction& instr1, const HloInstruction& instr2) { + const HloInstruction& instr1, const HloInstruction& instr2, + const se::DeviceDescription& device_info) { // Multi-output fusion kernels share a common parallel loop. The loop // dimensions are determined by instruction shapes. auto get_loop_shape = [&](const HloInstruction* element_instr) { @@ -328,7 +334,7 @@ FusionDecision ShapesCompatibleForMultiOutputFusion( const auto& hero = element_instr->parent()->IsFusionComputation() ? 
FindNonTrivialHero(*element_instr) : *element_instr; - if (IsReductionFromOrToContiguousDimensions(*element_instr) || + if (IsReductionFromOrToContiguousDimensions(*element_instr, device_info) || GetDescriptionForTiledTransposeEmitter(hero).has_value()) { return hero.operand(0)->shape(); } @@ -339,10 +345,13 @@ FusionDecision ShapesCompatibleForMultiOutputFusion( // root ops should have equal output shapes. An exception are // reduction-to-vector ops. Here the input shapes of the reduction (first // operand shape) and the reduction dimensions need to match. - const HloInstruction* hero1 = GetRealHeroForMultiOutputFusion(instr1); - const HloInstruction* hero2 = GetRealHeroForMultiOutputFusion(instr2); + const HloInstruction* hero1 = + GetRealHeroForMultiOutputFusion(instr1, device_info); + const HloInstruction* hero2 = + GetRealHeroForMultiOutputFusion(instr2, device_info); - if (auto compatible = FusionHeroesAreCompatible(hero1, hero2); !compatible) { + if (auto compatible = FusionHeroesAreCompatible(hero1, hero2, device_info); + !compatible) { return compatible; } @@ -371,10 +380,12 @@ bool IsInputFusibleScatter(const HloInstruction& instr) { return false; } -bool IsInputFusible(const HloInstruction& instr) { +bool IsInputFusible(const HloInstruction& instr, + const se::DeviceDescription& device_info) { // Input fusion only handles non-elemental reduction and scatter operations. return instr.IsFusible() && - (IsInputFusibleReduction(instr) || IsInputFusibleScatter(instr) || + (IsInputFusibleReduction(instr, device_info) || + IsInputFusibleScatter(instr) || IsInputFusibleTranspose(instr)); } @@ -414,7 +425,8 @@ bool IsUniversallyLoopFusible(const HloInstruction& instr) { } // Returns true if `instr` can be fused as a consumer into a kLoop fusion. -bool IsLoopFusibleAsConsumer(const HloInstruction& instr) { +bool IsLoopFusibleAsConsumer(const HloInstruction& instr, + const se::DeviceDescription& device_info) { // Instr should be fusible. 
if (!instr.IsFusible()) return false; @@ -429,7 +441,8 @@ bool IsLoopFusibleAsConsumer(const HloInstruction& instr) { // We may have input fusions which effectively have turned into loop // fusions. Those should still be considered as loop fusible consumers, // but they are not universally loop fusible. - if (!IsInputFusible(instr) && instr.opcode() == HloOpcode::kFusion && + if (!IsInputFusible(instr, device_info) && + instr.opcode() == HloOpcode::kFusion && instr.fusion_kind() == HloInstruction::FusionKind::kInput) { return true; } @@ -496,13 +509,14 @@ FusionDecision CanEmitInputFusedScatter(const HloInstruction& producer, } FusionDecision IsProducerConsumerFusible(const HloInstruction& producer, - const HloInstruction& consumer) { + const HloInstruction& consumer, + const se::DeviceDescription& device_info) { if (!IsLoopFusibleAsProducer(producer) && !IsInputFusibleTranspose(producer)) { return FusionDecision::Forbid("the producer is not loop-fusible"); } - if (IsInputFusibleReduction(producer)) { + if (IsInputFusibleReduction(producer, device_info)) { if (!producer.GetModule() ->config() .debug_options() @@ -515,8 +529,8 @@ FusionDecision IsProducerConsumerFusible(const HloInstruction& producer, ? 
FindNonTrivialHero(*producer.fused_expression_root()) : producer; if (!ReductionIsRaceFree( - reduce_hero.GetModule()->config(), - GetReductionKindAndContiguousComponents(reduce_hero))) { + GetReductionKindAndContiguousComponents(reduce_hero), + device_info)) { return FusionDecision::Forbid( "Reduction output fusion only works for race free reductions"); } @@ -537,7 +551,8 @@ FusionDecision IsProducerConsumerFusible(const HloInstruction& producer, return can_fuse; } - if (!IsInputFusible(consumer) && !IsLoopFusibleAsConsumer(consumer)) { + if (!IsInputFusible(consumer, device_info) && + !IsLoopFusibleAsConsumer(consumer, device_info)) { return FusionDecision::Forbid( "the consumer is not input-fusible and not loop-fusible"); } @@ -566,7 +581,8 @@ FusionDecision IsProducerConsumerFusible(const HloInstruction& producer, return InstructionFusion::ShouldFuseInPlaceOp(&producer, &consumer); } -FusionDecision IsProducerMultiOutputFusible(const HloInstruction& producer) { +FusionDecision IsProducerMultiOutputFusible( + const HloInstruction& producer, const se::DeviceDescription& device_info) { // Skip multiple output fusion. It's not yet supported. if (producer.IsMultiOutputFusion()) { return FusionDecision::Forbid("Producer is a multi-output fusion"); @@ -613,16 +629,17 @@ FusionDecision IsProducerMultiOutputFusible(const HloInstruction& producer) { // Returns an estimate of the shared memory usage for a given instruction in // bytes. 
-static int64_t SharedMemoryUsageNoCache(const HloInstruction& instr) { +static int64_t SharedMemoryUsageNoCache( + const HloInstruction& instr, const se::DeviceDescription& device_info) { if (instr.opcode() == HloOpcode::kFusion) { int64_t sum = 0; for (const HloInstruction* hlo : instr.fused_instructions_computation()->instructions()) { - sum += SharedMemoryUsageNoCache(*hlo); + sum += SharedMemoryUsageNoCache(*hlo, device_info); } return sum; } else if (instr.opcode() == HloOpcode::kReduce && - IsReductionFromOrToContiguousDimensions(instr)) { + IsReductionFromOrToContiguousDimensions(instr, device_info)) { ReductionDimensions reduction_info = GetReductionKindAndContiguousComponents(instr); int64_t primitive_size = ShapeUtil::ByteSizeOfPrimitiveType( @@ -665,16 +682,17 @@ int64_t FusionInfoCache::GetSharedMemoryUsage(const HloInstruction& instr) { // instructions, not instructions inside fusion nodes. Therefore we can only // cache top-level instructions; it would not be valid to pass the cache to // SharedMemoryUsageNoCache and use the cache *within* the fusion. - int64_t shared_memory_usage = SharedMemoryUsageNoCache(instr); + int64_t shared_memory_usage = SharedMemoryUsageNoCache(instr, device_info_); absl::MutexLock lock(&mutex_); shared_memory_usage_.emplace(&instr, shared_memory_usage); return shared_memory_usage; } -int64_t SharedMemoryUsage(const HloInstruction& instr, FusionInfoCache* cache) { +int64_t SharedMemoryUsage(const HloInstruction& instr, FusionInfoCache* cache, + const se::DeviceDescription& device_info) { if (!cache) { - return SharedMemoryUsageNoCache(instr); + return SharedMemoryUsageNoCache(instr, device_info); } return cache->GetSharedMemoryUsage(instr); } @@ -684,16 +702,17 @@ int64_t SharedMemoryUsage(const HloInstruction& instr, FusionInfoCache* cache) { constexpr int64_t kMaxUnnestedReductionOutputsPerFusion = 8; // Returns the number of unnested reductions in the instruction output. 
-static int64_t NumUnnestedReductionsNoCache(const HloInstruction& instr) { +static int64_t NumUnnestedReductionsNoCache( + const HloInstruction& instr, const se::DeviceDescription& device_info) { if (instr.opcode() == HloOpcode::kReduce && - IsReductionFromOrToContiguousDimensions(instr)) { + IsReductionFromOrToContiguousDimensions(instr, device_info)) { return 1; } if (instr.opcode() == HloOpcode::kFusion) { int64_t sum = 0; for (const HloInstruction* hlo : instr.fused_instructions_computation()->instructions()) { - sum += NumUnnestedReductionsNoCache(*hlo); + sum += NumUnnestedReductionsNoCache(*hlo, device_info); } return sum; } @@ -713,7 +732,8 @@ int64_t FusionInfoCache::GetNumUnnestedReductions(const HloInstruction& instr) { // instructions, not instructions inside fusion nodes. Therefore we can only // cache top-level instructions; it would not be valid to pass the cache to // NumUnnestedReductionsNoCache and use the cache *within* the fusion. - int64_t num_unnested_reductions = NumUnnestedReductionsNoCache(instr); + int64_t num_unnested_reductions = + NumUnnestedReductionsNoCache(instr, device_info_); absl::MutexLock lock(&mutex_); num_unnested_reductions_.emplace(&instr, num_unnested_reductions); @@ -721,9 +741,10 @@ int64_t FusionInfoCache::GetNumUnnestedReductions(const HloInstruction& instr) { } static int64_t NumUnnestedReductions(const HloInstruction& instr, - FusionInfoCache* cache) { + FusionInfoCache* cache, + const se::DeviceDescription& device_info) { if (!cache) { - return NumUnnestedReductionsNoCache(instr); + return NumUnnestedReductionsNoCache(instr, device_info); } return cache->GetNumUnnestedReductions(instr); @@ -757,15 +778,16 @@ FusionDecision FusionFitsInBudget(const HloInstruction& instr1, const se::DeviceDescription& device_info, bool is_consumer_producer_fusion, FusionInfoCache* cache /*=nullptr*/) { - if (SharedMemoryUsage(instr1, cache) + SharedMemoryUsage(instr2, cache) > + if (SharedMemoryUsage(instr1, cache, device_info) + + 
SharedMemoryUsage(instr2, cache, device_info) > device_info.shared_memory_per_block()) { return FusionDecision::Forbid( "shared memory usage would be over the budget of ") << device_info.shared_memory_per_block() << "B"; } - if (NumUnnestedReductions(instr1, cache) + - NumUnnestedReductions(instr2, cache) > + if (NumUnnestedReductions(instr1, cache, device_info) + + NumUnnestedReductions(instr2, cache, device_info) > kMaxUnnestedReductionOutputsPerFusion) { return FusionDecision::Forbid("over ") << kMaxUnnestedReductionOutputsPerFusion @@ -837,16 +859,17 @@ FusionDecision FusionFitsInBudget(const HloInstruction& instr1, return FusionDecision::Allow(); } -bool CreatesHeavyComputation(const HloInstruction& producer, - const HloInstruction& consumer) { +bool CreatesHeavyComputation( + const HloInstruction& producer, const HloInstruction& consumer, + const se::DeviceDescription& device_info) { // If producer's computation is not expensive to repeat even in the consumer // requests the same element multiple times there is nothing to do. auto producer_is_heavy = [&](const HloInstruction& instr) { if (producer.opcode() != HloOpcode::kFusion) { - return IsExpensiveToRepeat(producer); + return IsExpensiveToRepeat(producer, device_info); } for (const auto& instr : producer.fused_instructions()) { - if (IsExpensiveToRepeat(*instr)) { + if (IsExpensiveToRepeat(*instr, device_info)) { return true; } } @@ -901,21 +924,25 @@ bool CreatesHeavyComputation(const HloInstruction& producer, return false; } -bool IsFusibleAsMultiOutputFusionRoot(const HloInstruction& instr) { +bool IsFusibleAsMultiOutputFusionRoot( + const HloInstruction& instr, const se::DeviceDescription& device_info) { // We can fuse reduces and loop fusions. Elementwise instructions can be fused // with any other instruction. // Note that scatter cannot be the root of a multi-output fusion because // its emitter doesn't support it. 
return instr.IsFusible() && - (IsInputFusibleReduction(instr) || IsInputFusibleTranspose(instr) || + (IsInputFusibleReduction(instr, device_info) || + IsInputFusibleTranspose(instr) || instr.IsLoopFusion() || // TODO(b/130013493): Use IsLoopFusible here. instr.IsElementwise()); } -HloInstruction::FusionKind ChooseFusionKind(const HloInstruction& producer, - const HloInstruction& consumer) { - return (IsInputFusible(consumer) || IsInputFusible(producer)) +HloInstruction::FusionKind ChooseFusionKind( + const HloInstruction& producer, const HloInstruction& consumer, + const se::DeviceDescription& device_info) { + return (IsInputFusible(consumer, device_info) || + IsInputFusible(producer, device_info)) ? HloInstruction::FusionKind::kInput : HloInstruction::FusionKind::kLoop; } @@ -932,7 +959,7 @@ bool IsConsumerTheOnlyNonRootUser(const HloInstruction& instr, }); } -size_t GetInstrCountOfFusible(const HloInstruction& instr) { +int64_t GetInstrCountOfFusible(const HloInstruction& instr) { return instr.opcode() == HloOpcode::kFusion ? instr.fused_instruction_count() : 1; } diff --git a/third_party/xla/xla/service/gpu/gpu_fusible.h b/third_party/xla/xla/service/gpu/gpu_fusible.h index 0dadbfa36f5476..7a579a8f93479c 100644 --- a/third_party/xla/xla/service/gpu/gpu_fusible.h +++ b/third_party/xla/xla/service/gpu/gpu_fusible.h @@ -55,6 +55,8 @@ bool IsExpensiveToRepeat(const HloInstruction& instr); // Invariant: After modifying or removing a fusion node, call Invalidate(node). class FusionInfoCache { public: + explicit FusionInfoCache(const se::DeviceDescription& device_info) + : device_info_(device_info) {} // Must be called after modifying or removing a fusion node (or other node // that's part of this cache). 
void Invalidate(const HloInstruction* instr) { @@ -69,6 +71,7 @@ class FusionInfoCache { int64_t GetNumUnnestedReductions(const HloInstruction& instr); private: + const se::DeviceDescription& device_info_; absl::Mutex mutex_; absl::flat_hash_map shared_memory_usage_; @@ -110,15 +113,18 @@ bool TransposesMinorDimension(const HloInstruction* instr); // Whether `instr` is an input fusion rooted at a reduction-to-vector op or a // multi-output input fusion with at least one reduction-to-vector op root. -bool IsReduceInputFusion(const HloInstruction& instr); +bool IsReduceInputFusion(const HloInstruction& instr, + const se::DeviceDescription& device_info); // Whether `instr` is fusible as root of a reduce input fusions, i.e. `instr` // is either an unfused reduction-to-vector op or a reduce input fusion. -bool IsInputFusibleReduction(const HloInstruction& instr); +bool IsInputFusibleReduction(const HloInstruction& instr, + const se::DeviceDescription& device_info); // Whether `instr` is a nestable variadic reduction // or a loop fusion rooted with such. -bool IsNestableVariadicReduction(const HloInstruction& instr); +bool IsNestableVariadicReduction(const HloInstruction& instr, + const se::DeviceDescription& device_info); // Whether `instr` is fusible as root of a scatter input fusions, i.e. `instr` // is either an unfused scatter op or a scatter input fusion. @@ -139,18 +145,20 @@ FusionDecision FusionFitsInBudget(const HloInstruction& instr1, // producer has a complex computation per output and consumer calls this // computations multiple times. bool CreatesHeavyComputation(const HloInstruction& producer, - const HloInstruction& consumer); + const HloInstruction& consumer, + const se::DeviceDescription& device_info); // Returns the instruction that determines the emitter used for lowering, // sometimes referred to as "the real hero". 
const HloInstruction* GetRealHeroForMultiOutputFusion( - const HloInstruction& instr); + const HloInstruction& instr, const se::DeviceDescription& device_info); // Whether 'hero1' and 'hero2' are compatible if the two fusions containing // 'hero1' and 'hero2' are merged together. For example merging two fusions with // a reduction hero and a transpose here, respectively, does not work. -FusionDecision FusionHeroesAreCompatible(const HloInstruction* hero1, - const HloInstruction* hero2); +FusionDecision FusionHeroesAreCompatible( + const HloInstruction* hero1, const HloInstruction* hero2, + const se::DeviceDescription& device_info); // Whether instruction shapes are compatible for multi-output fusion, i.e. // whether the emitters support lowering the resulting fusion. @@ -160,7 +168,8 @@ FusionDecision FusionHeroesAreCompatible(const HloInstruction* hero1, // input fusions only. It is up to the caller to ensure the instructions // themselves are fusible! FusionDecision ShapesCompatibleForMultiOutputFusion( - const HloInstruction& instr1, const HloInstruction& instr2); + const HloInstruction& instr1, const HloInstruction& instr2, + const se::DeviceDescription& device_info); // Whether fusing producer into consumer creates a scatter fusion that cannot be // handled by the scatter emitter. @@ -172,19 +181,23 @@ FusionDecision CanEmitInputFusedScatter(const HloInstruction& producer, // they are not library calls. // Used both by instruction fusion and fusion-fusion merging. FusionDecision IsProducerConsumerFusible(const HloInstruction& producer, - const HloInstruction& consumer); + const HloInstruction& consumer, + const se::DeviceDescription& device_info); // Whether the producer is a valid candidate for a multi-output fusion. // That is, the root tuple of the multi-output fusion will contain the results // of both, the producer and consumer. 
-FusionDecision IsProducerMultiOutputFusible(const HloInstruction& producer); +FusionDecision IsProducerMultiOutputFusible( + const HloInstruction& producer, const se::DeviceDescription& device_info); // Whether `instr` is a candidate for sibling fusion or as a consumer in // a producer-consumer multi-output fusion. -bool IsFusibleAsMultiOutputFusionRoot(const HloInstruction& instr); +bool IsFusibleAsMultiOutputFusionRoot(const HloInstruction& instr, + const se::DeviceDescription& device_info); // Determines the fusion kind to be used when fusing into `consumer`. -HloInstruction::FusionKind ChooseFusionKind(const HloInstruction& producer, - const HloInstruction& consumer); +HloInstruction::FusionKind ChooseFusionKind( + const HloInstruction& producer, const HloInstruction& consumer, + const se::DeviceDescription& device_info); // Returns whether `consumer` is the only non-root user of `instr`. bool IsConsumerTheOnlyNonRootUser(const HloInstruction& instr, @@ -192,7 +205,7 @@ bool IsConsumerTheOnlyNonRootUser(const HloInstruction& instr, // Returns number of instructions in the fusible `instr`. If `instr` is not a // fusion instruction, 1 is returned. -size_t GetInstrCountOfFusible(const HloInstruction& instr); +int64_t GetInstrCountOfFusible(const HloInstruction& instr); // Returns the outputs of the fusible `instr`. absl::InlinedVector GetOutputsOfFusible( diff --git a/third_party/xla/xla/service/gpu/gpu_fusible_test.cc b/third_party/xla/xla/service/gpu/gpu_fusible_test.cc index 735709cbd346f8..88d679e6e30527 100644 --- a/third_party/xla/xla/service/gpu/gpu_fusible_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_fusible_test.cc @@ -16,7 +16,6 @@ limitations under the License. #include "xla/service/gpu/gpu_fusible.h" #include -#include #include #include "absl/strings/str_cat.h" @@ -24,14 +23,71 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/hlo_parser.h" #include "xla/tests/hlo_test_base.h" +#include "xla/service/hlo_runner.h" +#include "xla/service/instruction_fusion.h" +#include "xla/service/platform_util.h" +#include "xla/stream_executor/device_description.h" #include "tsl/platform/statusor.h" namespace xla { namespace gpu { - +namespace { using ::testing::ElementsAre; -using GpuFusibleTest = HloTestBase; +auto MakeDeviceDescription() { + stream_executor::DeviceDescription device_description{ + stream_executor::GpuDeviceInfoProto{}}; + device_description.set_threads_per_warp(32); + return device_description; +} + +class GpuFusibleTest : public NewHloTestBase { + public: + GpuFusibleTest() + : NewHloTestBase(std::make_unique( + PlatformUtil::GetDefaultPlatform().value()), + std::make_unique( + PlatformUtil::GetDefaultPlatform().value())), + device_description_(MakeDeviceDescription()) {} + + bool IsReduceInputFusion(const HloInstruction& instr) const { + return ::xla::gpu::IsReduceInputFusion(instr, device_description_); + } + + bool IsInputFusibleReduction(const HloInstruction& instr) const { + return ::xla::gpu::IsInputFusibleReduction(instr, device_description_); + } + + FusionDecision IsProducerMultiOutputFusible( + const HloInstruction& producer) const { + return ::xla::gpu::IsProducerMultiOutputFusible(producer, + device_description_); + } + + bool IsFusibleAsMultiOutputFusionRoot(const HloInstruction& instr) const { + return ::xla::gpu::IsFusibleAsMultiOutputFusionRoot(instr, + device_description_); + } + + FusionDecision FusionHeroesAreCompatible(const HloInstruction* hero1, + const HloInstruction* hero2) const { + return ::xla::gpu::FusionHeroesAreCompatible(hero1, hero2, + device_description_); + } + + FusionDecision ShapesCompatibleForMultiOutputFusion( + const HloInstruction& instr1, const HloInstruction& instr2) const { + return ::xla::gpu::ShapesCompatibleForMultiOutputFusion( + instr1, instr2, device_description_); + } + + const 
se::DeviceDescription& device_description() const { + return device_description_; + } + + private: + const se::DeviceDescription device_description_; +}; const char kModulePrefix[] = R"( HloModule test_module @@ -1068,7 +1124,8 @@ TEST_F(GpuFusibleTest, ProducerConsumerFusionElementwiseAndReduce) { const HloInstruction* producer = root->operand(1); EXPECT_TRUE(IsProducerMultiOutputFusible(*producer)); EXPECT_TRUE(IsFusibleAsMultiOutputFusionRoot(*consumer)); - EXPECT_TRUE(ShapesCompatibleForMultiOutputFusion(*producer, *consumer)); + EXPECT_TRUE(ShapesCompatibleForMultiOutputFusion(*producer, *consumer, + device_description())); } TEST_F(GpuFusibleTest, ProducerConsumerFusionTransposeAndLoopFusion) { @@ -1090,7 +1147,8 @@ TEST_F(GpuFusibleTest, ProducerConsumerFusionTransposeAndLoopFusion) { const HloInstruction* root = module->entry_computation()->root_instruction(); const HloInstruction* consumer = root; const HloInstruction* producer = root->operand(1); - EXPECT_TRUE(IsProducerConsumerFusible(*producer, *consumer)); + EXPECT_TRUE(IsProducerConsumerFusible(*producer, *consumer, + device_description())); } TEST_F(GpuFusibleTest, ProducerConsumerFusionReduceAndLoopFusion) { @@ -1113,7 +1171,8 @@ TEST_F(GpuFusibleTest, ProducerConsumerFusionReduceAndLoopFusion) { const HloInstruction* root = module->entry_computation()->root_instruction(); const HloInstruction* consumer = root; const HloInstruction* producer = root->operand(1); - EXPECT_TRUE(IsProducerConsumerFusible(*producer, *consumer)); + EXPECT_TRUE(IsProducerConsumerFusible(*producer, *consumer, + device_description())); } TEST_F(GpuFusibleTest, ProducerConsumerFusionLoopFusionAndReduce) { @@ -1307,9 +1366,11 @@ TEST_F(GpuFusibleTest, NonscalarConstantsNotFused) { const HloInstruction* consumer2 = root->operand(2); const HloInstruction* producer2 = root->operand(3); EXPECT_FALSE( - static_cast(IsProducerConsumerFusible(*producer, *consumer))); + static_cast(IsProducerConsumerFusible(*producer, *consumer, + 
device_description()))); EXPECT_FALSE( - static_cast(IsProducerConsumerFusible(*producer2, *consumer2))); + static_cast(IsProducerConsumerFusible(*producer2, *consumer2, + device_description()))); } TEST_F(GpuFusibleTest, FuseLayoutChangingOpWithElementwise) { @@ -1326,7 +1387,8 @@ TEST_F(GpuFusibleTest, FuseLayoutChangingOpWithElementwise) { module->entry_computation()->root_instruction(); const HloInstruction* producer = consumer->operand(0); EXPECT_TRUE( - static_cast(IsProducerConsumerFusible(*producer, *consumer))); + static_cast(IsProducerConsumerFusible(*producer, *consumer, + device_description()))); } TEST_F(GpuFusibleTest, FuseReduceWithUnaryElementwise) { @@ -1343,7 +1405,8 @@ TEST_F(GpuFusibleTest, FuseReduceWithUnaryElementwise) { module->entry_computation()->root_instruction(); const HloInstruction* producer = consumer->operand(0); EXPECT_TRUE( - static_cast(IsProducerConsumerFusible(*producer, *consumer))); + static_cast(IsProducerConsumerFusible(*producer, *consumer, + device_description()))); } TEST_F(GpuFusibleTest, DoNotFuseReduceWithRacesWithUnaryElementwise) { @@ -1360,7 +1423,8 @@ TEST_F(GpuFusibleTest, DoNotFuseReduceWithRacesWithUnaryElementwise) { module->entry_computation()->root_instruction(); const HloInstruction* producer = consumer->operand(0); EXPECT_FALSE( - static_cast(IsProducerConsumerFusible(*producer, *consumer))); + static_cast(IsProducerConsumerFusible(*producer, *consumer, + device_description()))); } TEST_F(GpuFusibleTest, CreatesHeavyComputation_NonfusionInstr) { @@ -1383,7 +1447,8 @@ TEST_F(GpuFusibleTest, CreatesHeavyComputation_NonfusionInstr) { const HloInstruction* root = module->entry_computation()->root_instruction(); const HloInstruction* producer = root->operand(0); const HloInstruction* consumer = root->operand(1); - EXPECT_TRUE(CreatesHeavyComputation(*producer, *consumer)); + EXPECT_TRUE(CreatesHeavyComputation(*producer, *consumer, + device_description())); } TEST_F(GpuFusibleTest, 
DoesNotCreateHeavyComputation_NonfusionInstr) { @@ -1404,7 +1469,8 @@ TEST_F(GpuFusibleTest, DoesNotCreateHeavyComputation_NonfusionInstr) { const HloInstruction* root = module->entry_computation()->root_instruction(); const HloInstruction* producer = root->operand(0); const HloInstruction* consumer = root->operand(1); - EXPECT_FALSE(CreatesHeavyComputation(*producer, *consumer)); + EXPECT_FALSE(CreatesHeavyComputation(*producer, *consumer, + device_description())); } TEST_F(GpuFusibleTest, @@ -1427,7 +1493,8 @@ TEST_F(GpuFusibleTest, const HloInstruction* root = module->entry_computation()->root_instruction(); const HloInstruction* producer = root->operand(0); const HloInstruction* consumer = root->operand(1); - EXPECT_FALSE(CreatesHeavyComputation(*producer, *consumer)); + EXPECT_FALSE(CreatesHeavyComputation(*producer, *consumer, + device_description())); } TEST_F(GpuFusibleTest, CreatesHeavyComputation_ReduceWindowGather) { @@ -1450,7 +1517,8 @@ TEST_F(GpuFusibleTest, CreatesHeavyComputation_ReduceWindowGather) { EXPECT_FALSE(IfFusedReadsElementsMultipleTimes(*reduce_window)); EXPECT_TRUE(IsExpensiveToRepeat(*reduce_window)); EXPECT_TRUE(IfFusedReadsElementsMultipleTimes(*gather)); - EXPECT_TRUE(CreatesHeavyComputation(*reduce_window, *gather)); + EXPECT_TRUE(CreatesHeavyComputation(*reduce_window, *gather, + device_description())); } TEST_F(GpuFusibleTest, CreatesHeavyComputation_FusionInstr) { @@ -1483,7 +1551,8 @@ TEST_F(GpuFusibleTest, CreatesHeavyComputation_FusionInstr) { const HloInstruction* root = module->entry_computation()->root_instruction(); const HloInstruction* producer = root->operand(0); const HloInstruction* consumer = root->operand(1); - EXPECT_TRUE(CreatesHeavyComputation(*producer, *consumer)); + EXPECT_TRUE(CreatesHeavyComputation(*producer, *consumer, + device_description())); } TEST_F(GpuFusibleTest, DoesNotCreateHeavyComputation_FusionInstr) { @@ -1516,7 +1585,8 @@ TEST_F(GpuFusibleTest, DoesNotCreateHeavyComputation_FusionInstr) { 
const HloInstruction* root = module->entry_computation()->root_instruction(); const HloInstruction* producer = root->operand(0); const HloInstruction* consumer = root->operand(1); - EXPECT_FALSE(CreatesHeavyComputation(*producer, *consumer)); + EXPECT_FALSE(CreatesHeavyComputation(*producer, *consumer, + device_description())); } TEST_F(GpuFusibleTest, ChooseFusionKind) { @@ -1532,7 +1602,7 @@ ENTRY computation { .value(); const HloInstruction* root = module->entry_computation()->root_instruction(); const HloInstruction* producer = root->operand(0); - EXPECT_EQ(ChooseFusionKind(*producer, *root), + EXPECT_EQ(ChooseFusionKind(*producer, *root, device_description()), HloInstruction::FusionKind::kInput); } @@ -1773,12 +1843,13 @@ TEST_F(GpuFusibleTest, GetSharedMemoryUsage) { ROOT res = f32[1024,128,2]{2,1,0} fusion(p), kind=kInput, calls=wrapped_transpose })")) .value(); - auto& debug_options = module->mutable_config().mutable_debug_options(); - debug_options.set_xla_gpu_mlir_emitter_level(3); - FusionInfoCache cache; + // auto& debug_options = module->mutable_config().mutable_debug_options(); + // debug_options.set_xla_gpu_mlir_emitter_level(3); + FusionInfoCache cache(device_description()); auto fusion = module->entry_computation()->root_instruction(); EXPECT_EQ(cache.GetSharedMemoryUsage(*fusion), 32 * 33 * 2 * 4); } +} // namespace } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/gpu_offloading_test.cc b/third_party/xla/xla/service/gpu/gpu_offloading_test.cc index 3aa9d79977bd81..61b2a996ea14a6 100644 --- a/third_party/xla/xla/service/gpu/gpu_offloading_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_offloading_test.cc @@ -218,7 +218,9 @@ TEST_F(GpuOffloadingTest, CopyIRCreationTest) { RunHloRematerialization( /*memory_limit_bytes=*/10 * 1024, module.get())); ASSERT_TRUE(changed); - StreamAttributeAnnotator attr_annotator; + stream_executor::StreamExecutor* executor = + backend().default_stream_executor(); + 
StreamAttributeAnnotator attr_annotator(executor->GetDeviceDescription()); TF_ASSERT_OK_AND_ASSIGN(bool changed_attr, attr_annotator.Run(module.get())); EXPECT_TRUE(changed_attr); // Verify that the stream attribute for a copy-start is annotated diff --git a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc index 94b4c55a802c67..c559a6bf3f7a45 100644 --- a/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc +++ b/third_party/xla/xla/service/gpu/hlo_fusion_analysis.cc @@ -256,7 +256,8 @@ HloFusionAnalysis::EmitterFusionKind HloFusionAnalysis::GetEmitterFusionKind() std::optional first_reduce_hero; for (auto [root, hero] : llvm::zip(fusion_roots_, fusion_heroes_)) { - if (IsRealReductionHero(root.instruction(), hero.instruction())) { + if (IsRealReductionHero(root.instruction(), hero.instruction(), + *device_info_)) { first_reduce_hero = hero; break; } @@ -268,7 +269,8 @@ HloFusionAnalysis::EmitterFusionKind HloFusionAnalysis::GetEmitterFusionKind() if (root == *first_reduce_hero) { continue; } - if (!IsRealReductionHero(root.instruction(), hero.instruction())) { + if (!IsRealReductionHero(root.instruction(), hero.instruction(), + *device_info_)) { // Needs to have a compatible shape to the reduce operand (compatible // meaning same number of elements). if (ShapeUtil::ElementsIn(root.shape()) != @@ -322,7 +324,8 @@ const HloInstruction* HloFusionAnalysis::FindHeroReduction() const { // have the same shape and layout as verified by // `IsFusedReductionOutputConsistent()`. 
for (auto [root, hero] : llvm::zip(roots, fusion_heroes_)) { - if (IsRealReductionHero(root.instruction(), hero.instruction())) { + if (IsRealReductionHero(root.instruction(), hero.instruction(), + *device_info_)) { return &hero.instruction(); } } diff --git a/third_party/xla/xla/service/gpu/hlo_fusion_analysis_test.cc b/third_party/xla/xla/service/gpu/hlo_fusion_analysis_test.cc index 7328bc6dad0ec9..afaf89176c0c77 100644 --- a/third_party/xla/xla/service/gpu/hlo_fusion_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/hlo_fusion_analysis_test.cc @@ -318,7 +318,8 @@ TEST_F(HloFusionAnalysisTest, InvalidDevice) { stream_executor::GpuDeviceInfoProto device_info_proto; stream_executor::DeviceDescription device_info(device_info_proto); - + device_info.set_threads_per_warp(32); + auto* root = module->entry_computation()->root_instruction(); auto analysis_fused = HloFusionAnalysis::Create(*root->operand(0), *root, device_info); diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils.cc b/third_party/xla/xla/service/gpu/ir_emission_utils.cc index 406fcd9534a9dc..a1521ed0b5fbe0 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils.cc +++ b/third_party/xla/xla/service/gpu/ir_emission_utils.cc @@ -82,12 +82,12 @@ bool IsRank1(const Shape& shape, int64_t batch_dimensions_size) { return shape.rank() == batch_dimensions_size + 1; } -bool IsMlirTransposeEmitterEnabled(const HloInstruction& hlo) { - return hlo.GetModule() - ->config() - .debug_options() - .xla_gpu_mlir_emitter_level() >= 3; -} +// bool IsMlirTransposeEmitterEnabled(const HloInstruction& hlo) { +// return hlo.GetModule() +// ->config() +// .debug_options() +// .xla_gpu_mlir_emitter_level() >= 3; +// } } // namespace @@ -273,7 +273,7 @@ llvm::Value* EmitNVPTXShflDown(llvm::Value* value, llvm::Value* offset, llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration(module, llvm_intrinsic_id, {}); return b->CreateCall( - intrinsic, {b->getInt32(-1), value, offset, b->getInt32(WarpSize() - 
1)}); + intrinsic, {b->getInt32(-1), value, offset, b->getInt32(32 - 1)}); } // Helper function to emit call to SPIR shfl_down intrinsic. @@ -557,7 +557,40 @@ std::optional GetDescriptionForTiledTransposeEmitter( absl::InlinedVector dimensions(hero.shape().dimensions().begin(), hero.shape().dimensions().end()); int64_t operand_most_minor_dim = hero.operand(0)->shape().dimensions().back(); - if (IsMlirTransposeEmitterEnabled(hero)) { + // if (IsMlirTransposeEmitterEnabled(hero)) { + // if (permutation.back() == dimensions.size() - 1) { + // operand_most_minor_dim = + // hero.operand(0)->shape().dimensions(dimensions.size() - 2); + // auto byte_width = primitive_util::ByteWidth(hero.shape().element_type()); + // if (byte_width * dimensions.back() <= kMaxBytesInMostMinorDimension && + // byte_width * dimensions.back() * + // std::min(operand_most_minor_dim, + // dimensions[dimensions.size() - 2]) >= + // kMinDimensionToTransposeTiled) { + // return TransposeDescription{&hero, dimensions, permutation}; + // } + // } else if ((operand_most_minor_dim >= kMinDimensionToTransposeTiled && + // dimensions.back() >= kMinDimensionToTransposeTiled) || + // (operand_most_minor_dim >= kMinDimensionToTransposeTiled2 && + // dimensions.back() >= kMinDimensionToTransposeTiled2 && + // operand_most_minor_dim * dimensions.back() >= + // kMinTotalDimensionsToTransposeTiled)) { + // return TransposeDescription{&hero, dimensions, permutation}; + // } + // } else if (permutation == absl::InlinedVector{1, 0} || + // permutation == absl::InlinedVector{0, 2, 1} || + // permutation == absl::InlinedVector{2, 1, 0}) { + // // The old emitter needs a normalization to rank 3. 
+ // if (permutation.size() == 2) { + // permutation = {0, 2, 1}; + // dimensions.insert(dimensions.begin(), 1); + // } + // if ((dimensions.back() >= kMinDimensionToTransposeTiled && + // operand_most_minor_dim >= kMinDimensionToTransposeTiled) || + // (dimensions.back() >= kMinDimensionToTransposeTiled2 && + // operand_most_minor_dim >= kMinDimensionToTransposeTiled2 && + // dimensions.back() * operand_most_minor_dim >= + // kMinTotalDimensionsToTransposeTiled)) { if (permutation.back() == dimensions.size() - 1) { operand_most_minor_dim = hero.operand(0)->shape().dimensions(dimensions.size() - 2); @@ -566,33 +599,16 @@ std::optional GetDescriptionForTiledTransposeEmitter( byte_width * dimensions.back() * std::min(operand_most_minor_dim, dimensions[dimensions.size() - 2]) >= - kMinDimensionToTransposeTiled) { - return TransposeDescription{&hero, dimensions, permutation}; - } - } else if ((operand_most_minor_dim >= kMinDimensionToTransposeTiled && - dimensions.back() >= kMinDimensionToTransposeTiled) || - (operand_most_minor_dim >= kMinDimensionToTransposeTiled2 && - dimensions.back() >= kMinDimensionToTransposeTiled2 && - operand_most_minor_dim * dimensions.back() >= - kMinTotalDimensionsToTransposeTiled)) { - return TransposeDescription{&hero, dimensions, permutation}; - } - } else if (permutation == absl::InlinedVector{1, 0} || - permutation == absl::InlinedVector{0, 2, 1} || - permutation == absl::InlinedVector{2, 1, 0}) { - // The old emitter needs a normalization to rank 3. 
- if (permutation.size() == 2) { - permutation = {0, 2, 1}; - dimensions.insert(dimensions.begin(), 1); - } - if ((dimensions.back() >= kMinDimensionToTransposeTiled && - operand_most_minor_dim >= kMinDimensionToTransposeTiled) || - (dimensions.back() >= kMinDimensionToTransposeTiled2 && - operand_most_minor_dim >= kMinDimensionToTransposeTiled2 && - dimensions.back() * operand_most_minor_dim >= - kMinTotalDimensionsToTransposeTiled)) { + kMinDimensionToTransposeTiled) { return TransposeDescription{&hero, dimensions, permutation}; } + } else if ((operand_most_minor_dim >= kMinDimensionToTransposeTiled && + dimensions.back() >= kMinDimensionToTransposeTiled) || + (operand_most_minor_dim >= kMinDimensionToTransposeTiled2 && + dimensions.back() >= kMinDimensionToTransposeTiled2 && + operand_most_minor_dim * dimensions.back() >= + kMinTotalDimensionsToTransposeTiled)) { + return TransposeDescription{&hero, dimensions, permutation}; } return std::nullopt; } diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils.h b/third_party/xla/xla/service/gpu/ir_emission_utils.h index eef3943fcd8ee8..6aa6695aeb7e29 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils.h +++ b/third_party/xla/xla/service/gpu/ir_emission_utils.h @@ -64,7 +64,10 @@ inline constexpr int64_t kMaxBytesInMostMinorDimension = 8; bool IsMatrixMultiplication(const HloInstruction& dot); bool IsMatrixVectorMultiplication(const HloInstruction& dot); -inline constexpr int64_t WarpSize() { return 32; } +inline constexpr int64_t WarpSize( + const se::DeviceDescription& gpu_device_info) { + return gpu_device_info.threads_per_warp(); +} // Fusions that implemented with pre-compiled device kernels have // FusionBackendConfig.kind requel to this string. 
diff --git a/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc b/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc index a74bfd3c8eb18d..38ed40cb0a9ad1 100644 --- a/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc +++ b/third_party/xla/xla/service/gpu/ir_emission_utils_test.cc @@ -56,8 +56,8 @@ ENTRY entry { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo)); - auto& debug_options = module->mutable_config().mutable_debug_options(); - debug_options.set_xla_gpu_mlir_emitter_level(3); + // auto& debug_options = module->mutable_config().mutable_debug_options(); + // debug_options.set_xla_gpu_mlir_emitter_level(3); HloInstruction* tr = module->entry_computation()->root_instruction(); @@ -68,29 +68,29 @@ ENTRY entry { EXPECT_EQ(result->permutation, InlinedVector({1, 0})); } -TEST_F(IrEmissionUtilsTest, FindTiledLogicalTransposeNoMlirEmitters) { - const char* hlo = R"( -HloModule module +// TEST_F(IrEmissionUtilsTest, FindTiledLogicalTransposeNoMlirEmitters) { +// const char* hlo = R"( +// HloModule module -ENTRY entry { - p = f32[1536,64]{1,0} parameter(0) - ROOT t = f32[64,1536]{1,0} transpose(p), dimensions={1,0} -} -)"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, - ParseAndReturnVerifiedModule(hlo)); - auto& debug_options = module->mutable_config().mutable_debug_options(); - debug_options.set_xla_gpu_mlir_emitter_level(0); +// ENTRY entry { +// p = f32[1536,64]{1,0} parameter(0) +// ROOT t = f32[64,1536]{1,0} transpose(p), dimensions={1,0} +// } +// )"; +// TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, +// ParseAndReturnVerifiedModule(hlo)); +// auto& debug_options = module->mutable_config().mutable_debug_options(); +// debug_options.set_xla_gpu_mlir_emitter_level(0); - HloInstruction* tr = module->entry_computation()->root_instruction(); +// HloInstruction* tr = module->entry_computation()->root_instruction(); - auto result = GetDescriptionForTiledTransposeEmitter(*tr); - 
EXPECT_TRUE(result.has_value()); - EXPECT_EQ(result->instr, tr); - // If MLIR emitters are disabled, we pad the shape to rank 3. - EXPECT_EQ(result->dimensions, InlinedVector({1, 64, 1536})); - EXPECT_EQ(result->permutation, InlinedVector({0, 2, 1})); -} +// auto result = GetDescriptionForTiledTransposeEmitter(*tr); +// EXPECT_TRUE(result.has_value()); +// EXPECT_EQ(result->instr, tr); +// // If MLIR emitters are disabled, we pad the shape to rank 3. +// EXPECT_EQ(result->dimensions, InlinedVector({1, 64, 1536})); +// EXPECT_EQ(result->permutation, InlinedVector({0, 2, 1})); +// } TEST_F(IrEmissionUtilsTest, FindTiledLogical102Transpose) { const char* hlo = R"( @@ -103,8 +103,8 @@ ENTRY entry { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo)); - auto& debug_options = module->mutable_config().mutable_debug_options(); - debug_options.set_xla_gpu_mlir_emitter_level(3); + // auto& debug_options = module->mutable_config().mutable_debug_options(); + // debug_options.set_xla_gpu_mlir_emitter_level(3); HloInstruction* tr = module->entry_computation()->root_instruction(); @@ -126,8 +126,8 @@ ENTRY entry { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo)); - auto& debug_options = module->mutable_config().mutable_debug_options(); - debug_options.set_xla_gpu_mlir_emitter_level(3); + // auto& debug_options = module->mutable_config().mutable_debug_options(); + // debug_options.set_xla_gpu_mlir_emitter_level(3); HloInstruction* tr = module->entry_computation()->root_instruction(); @@ -146,8 +146,8 @@ ENTRY entry { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo)); - auto& debug_options = module->mutable_config().mutable_debug_options(); - debug_options.set_xla_gpu_mlir_emitter_level(3); + // auto& debug_options = module->mutable_config().mutable_debug_options(); + // debug_options.set_xla_gpu_mlir_emitter_level(3); HloInstruction* tr = 
module->entry_computation()->root_instruction(); @@ -169,8 +169,8 @@ ENTRY entry { )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(hlo)); - auto& debug_options = module->mutable_config().mutable_debug_options(); - debug_options.set_xla_gpu_mlir_emitter_level(3); + // auto& debug_options = module->mutable_config().mutable_debug_options(); + // debug_options.set_xla_gpu_mlir_emitter_level(3); HloInstruction* tr = module->entry_computation()->root_instruction(); diff --git a/third_party/xla/xla/service/gpu/launch_dimensions.cc b/third_party/xla/xla/service/gpu/launch_dimensions.cc index f9e28995d09960..db060f1eb4b66e 100644 --- a/third_party/xla/xla/service/gpu/launch_dimensions.cc +++ b/third_party/xla/xla/service/gpu/launch_dimensions.cc @@ -39,8 +39,13 @@ LaunchDimensions CalculateLaunchDimensions( const int kWarpSchedulers = 4; int64_t threads_per_block = std::min( gpu_device_info.threads_per_warp() * kWarpSchedulers, num_elements); - int64_t num_blocks = CeilOfRatio(num_elements, threads_per_block); - return LaunchDimensions(se::BlockDim(num_blocks, 1, 1), + int64_t num_blocks_total = CeilOfRatio(num_elements, threads_per_block); + + int64_t num_blocks_y = CeilOfRatio( + num_blocks_total, gpu_device_info.block_dim_limit().x); + int64_t num_blocks_x = CeilOfRatio(num_blocks_total, num_blocks_y); + + return LaunchDimensions(se::BlockDim(num_blocks_x, num_blocks_y, 1), se::ThreadDim(threads_per_block, 1, 1)); } diff --git a/third_party/xla/xla/service/gpu/model/coalescing_analysis.cc b/third_party/xla/xla/service/gpu/model/coalescing_analysis.cc index 9612a5566a3c69..01fc4ca2f28b3c 100644 --- a/third_party/xla/xla/service/gpu/model/coalescing_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/coalescing_analysis.cc @@ -54,6 +54,7 @@ namespace gpu { // producer and consumer are considered as one fusion, otherwise it's only the // producer. 
bool IsReadCoalescedHeuristic(HloFusionAnalysis::EmitterFusionKind fusion_kind, + const se::DeviceDescription& device_info, const HloInstruction* producer, const HloInstruction* consumer) { // Transposing minor dimension breaks coalescing. @@ -91,8 +92,8 @@ bool IsReadCoalescedHeuristic(HloFusionAnalysis::EmitterFusionKind fusion_kind, } // Fusing two row reductions breaks coalescing. if (fusion_kind == HloFusionAnalysis::EmitterFusionKind::kReduction && - IsInputFusibleReduction(*producer) && consumer && - IsInputFusibleReduction(*consumer)) { + IsInputFusibleReduction(*producer, device_info) && consumer && + IsInputFusibleReduction(*consumer, device_info)) { return false; } return true; @@ -586,7 +587,8 @@ CoalescingAnalysis::CoalescingAnalysis( } // If ComputeCoalescingForAllOperands fails, fallback to using the heuristic. is_coalesced_computed_by_heuristic_ = - IsReadCoalescedHeuristic(fusion_analysis.GetEmitterFusionKind(), instr); + IsReadCoalescedHeuristic(fusion_analysis.GetEmitterFusionKind(), + fusion_analysis.device_info(), instr); } CoalescingAnalysis::CoalescingAnalysis( @@ -604,7 +606,8 @@ CoalescingAnalysis::CoalescingAnalysis( } // If ComputeCoalescingForAllOperands fails, fallback to using the heuristic. is_coalesced_computed_by_heuristic_ = IsReadCoalescedHeuristic( - fusion_analysis.GetEmitterFusionKind(), producer, consumer); + fusion_analysis.GetEmitterFusionKind(), fusion_analysis.device_info(), + producer, consumer); } bool CoalescingAnalysis::ComputeCoalescingForAllOperands( diff --git a/third_party/xla/xla/service/gpu/model/coalescing_analysis.h b/third_party/xla/xla/service/gpu/model/coalescing_analysis.h index 5e82b6455afcd7..e097b8cd46b4df 100644 --- a/third_party/xla/xla/service/gpu/model/coalescing_analysis.h +++ b/third_party/xla/xla/service/gpu/model/coalescing_analysis.h @@ -70,6 +70,7 @@ class CoalescingAnalysis { // producer and consumer are considered as one fusion, otherwise it's only the // producer. 
bool IsReadCoalescedHeuristic(HloFusionAnalysis::EmitterFusionKind fusion_kind, + const se::DeviceDescription& device_info, const HloInstruction* producer, const HloInstruction* consumer = nullptr); diff --git a/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc b/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc index 70cca31981174e..6044241b6efe1a 100644 --- a/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc +++ b/third_party/xla/xla/service/gpu/model/coalescing_analysis_test.cc @@ -78,8 +78,8 @@ class CoalescingTest : public HloTestBase { auto module = ParseAndReturnVerifiedModule(hlo_string).value(); HloInstruction* root = module->entry_computation()->root_instruction(); auto analysis = HloFusionAnalysis::Create(*root, device_info_); - return xla::gpu::IsReadCoalescedHeuristic(analysis.GetEmitterFusionKind(), - root->operand(0), root); + return xla::gpu::IsReadCoalescedHeuristic( + analysis.GetEmitterFusionKind(), device_info_, root->operand(0), root); } protected: diff --git a/third_party/xla/xla/service/gpu/model/gpu_hlo_cost_analysis.cc b/third_party/xla/xla/service/gpu/model/gpu_hlo_cost_analysis.cc index f32ce743b4504b..c63bcebd020ee0 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_hlo_cost_analysis.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_hlo_cost_analysis.cc @@ -218,17 +218,17 @@ bool GpuHloCostAnalysis::ProducerConsumerMergedTooLarge( << ", " << IrBasicBlockSplitCount(consumer) << " -> " << n_splits; int64_t merged_ir_size = (IrSize(producer) * producer_replication + IrSize(consumer)); - // The MLIR emitters don't have the problem with cache invalidation, so we - // don't need to evaluate basic block split counts. 
- if (producer.GetModule() - ->config() - .debug_options() - .xla_gpu_mlir_emitter_level() < 4) { - if (n_splits > kMaxBasicBlockSplitsPerFusion) { - return true; - } - merged_ir_size *= (1 << n_splits); - } + // // The MLIR emitters don't have the problem with cache invalidation, so we + // // don't need to evaluate basic block split counts. + // if (producer.GetModule() + // ->config() + // .debug_options() + // .xla_gpu_mlir_emitter_level() < 4) { + // if (n_splits > kMaxBasicBlockSplitsPerFusion) { + // return true; + // } + // merged_ir_size *= (1 << n_splits); + // } VLOG(5) << "IR sizes: " << IrSize(producer) << ", " << IrSize(consumer) << " -> " << merged_ir_size; return merged_ir_size > kMaxIRSize; diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc index 7798b80d17a681..71c791d34a6875 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.cc @@ -314,7 +314,7 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForInstruction( auto fusion_analysis = HloFusionAnalysis::Create(*producer, *device_info_); bool is_coalesced = IsReadCoalescedHeuristic( - fusion_analysis.GetEmitterFusionKind(), producer); + fusion_analysis.GetEmitterFusionKind(), *device_info_, producer); return EstimateRunTimeForFusion(fusion_analysis, is_coalesced); } @@ -324,8 +324,8 @@ GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForProducerConsumer( auto fusion_analysis = HloFusionAnalysis::Create(*producer, *consumer, *device_info_); - bool is_coalesced = IsReadCoalescedHeuristic( - fusion_analysis.GetEmitterFusionKind(), producer, consumer); + bool is_coalesced = IsReadCoalescedHeuristic(fusion_analysis.GetEmitterFusionKind(), + *device_info_, producer, consumer); return EstimateRunTimeForFusion(fusion_analysis, is_coalesced); } @@ -504,13 +504,14 @@ 
GpuPerformanceModelWithIndexingAnalysis::EstimateRunTimeForTriton( /*static*/ LaunchDimensions GpuPerformanceModelWithIndexingAnalysis::GetLaunchDimensionsForTiledFusion( - const TiledHloComputation& tiled_hlo_computation) { + const TiledHloComputation& tiled_hlo_computation, + const se::DeviceDescription& device_info) { const auto* tiled_root = tiled_hlo_computation.GetRoot(); int64_t num_blocks = tiled_hlo_computation.num_output_tiles(); int64_t num_warps = GetNumWarps(GetPaddedTileSize(tiled_root->tile_sizes())); return {static_cast(num_blocks), - static_cast(num_warps * WarpSize())}; + static_cast(num_warps * WarpSize(device_info))}; } absl::StatusOr @@ -538,7 +539,7 @@ GpuPerformanceModelWithIndexingAnalysis::TryFindBestTilingForFusion( analysis.ComputeTiledHloInstructions(tiling)); LaunchDimensions launch_dimensions = - GetLaunchDimensionsForTiledFusion(tiled_hlo_computation); + GetLaunchDimensionsForTiledFusion(tiled_hlo_computation, *device_info_); TF_ASSIGN_OR_RETURN( EstimateRunTimeData estimate_run_time_data, @@ -552,7 +553,7 @@ GpuPerformanceModelWithIndexingAnalysis::TryFindBestTilingForFusion( block_level_parameters.output_tile_sizes = std::vector(tiling.begin(), tiling.end()); block_level_parameters.num_warps = - launch_dimensions.num_threads_per_block() / WarpSize(); + launch_dimensions.num_threads_per_block() / WarpSize(*device_info_); best_tiled_run_time_data = TiledRunTimeData{estimate_run_time_data, block_level_parameters}; diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h index 499b75ea61bffe..a8e7fed4314a30 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h +++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model.h @@ -73,7 +73,8 @@ class GpuPerformanceModelWithIndexingAnalysis : public GpuPerformanceModelBase { // Returns the launch dimensions for the given tiled HLO computation. 
static LaunchDimensions GetLaunchDimensionsForTiledFusion( - const TiledHloComputation& tiled_hlo_computation); + const TiledHloComputation& tiled_hlo_computation, + const se::DeviceDescription& device_info); EstimateRunTimeData EstimateRunTimeForFusion( const HloFusionAnalysis& fusion_analysis, bool is_coalesced = true); diff --git a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc index dddf7b1d428f9d..2e61dd28df014e 100644 --- a/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc +++ b/third_party/xla/xla/service/gpu/model/gpu_indexing_performance_model_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "xla/service/gpu/model/gpu_indexing_performance_model.h" +#include #include #include #include @@ -72,6 +73,8 @@ class GpuIndexingPerformanceModelTest : public HloTestBase { &device_info_, &fusion_analysis_cache_, ShapeSizeBytesFunction(), &mlir_context_}; + size_t WarpSize() const { return ::xla::gpu::WarpSize(device_info_); } + GpuIndexingPerformanceModelTest() : HloTestBase() {} }; @@ -613,7 +616,7 @@ ENTRY main { .ComputeTiledHloInstructions(/*tile_parameters=*/{9, 9, 9})); LaunchDimensions launch_dimensions = GpuPerformanceModelWithIndexingAnalysis:: - GetLaunchDimensionsForTiledFusion(tiled_hlo_computation); + GetLaunchDimensionsForTiledFusion(tiled_hlo_computation, device_info_); EXPECT_EQ(launch_dimensions.num_blocks(), 1); // Tile size is 9 * 9 * 9 = 729 that corresponds to 2 warps. 
But we estimate diff --git a/third_party/xla/xla/service/gpu/nvptx_compiler.cc b/third_party/xla/xla/service/gpu/nvptx_compiler.cc index e29420c8bfb8b3..3c619e7f03e6d8 100644 --- a/third_party/xla/xla/service/gpu/nvptx_compiler.cc +++ b/third_party/xla/xla/service/gpu/nvptx_compiler.cc @@ -569,8 +569,12 @@ NVPTXCompiler::NVPTXCompiler() : GpuCompiler(stream_executor::cuda::kCudaPlatformId, nvptx::TargetTriple(), nvptx::DataLayout()) {} -HloDataflowAnalysis::CanShareBuffer NVPTXCompiler::GetCanShareBuffer() const { - return &CanShareBufferHint; +HloDataflowAnalysis::CanShareBuffer NVPTXCompiler::GetCanShareBuffer( + const se::DeviceDescription& device_description) const { + return [device_description](const HloInstruction* user, const HloInstruction* operand, + const ShapeIndex& user_index) { + return CanShareBufferHint(user, operand, user_index, device_description); + }; } constexpr const uint8_t kPtxPrefix[] = {'P', 'T', 'X', ':', ' '}; diff --git a/third_party/xla/xla/service/gpu/nvptx_compiler.h b/third_party/xla/xla/service/gpu/nvptx_compiler.h index 78591bb2c42a7d..5a28591a3dbdb8 100644 --- a/third_party/xla/xla/service/gpu/nvptx_compiler.h +++ b/third_party/xla/xla/service/gpu/nvptx_compiler.h @@ -89,8 +89,9 @@ class NVPTXCompiler : public GpuCompiler { se::StreamExecutor* stream_exec, BinaryMap* dnn_compiled_graphs) override; - HloDataflowAnalysis::CanShareBuffer GetCanShareBuffer() const override; - + HloDataflowAnalysis::CanShareBuffer GetCanShareBuffer( + const se::DeviceDescription& device_description) const override; + absl::StatusOr CompileTargetBinary( const HloModuleConfig& module_config, llvm::Module* llvm_module, se::GpuComputeCapability gpu_version, bool relocatable, diff --git a/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.cc b/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.cc index 50c627981ce7c5..6591a369a52735 100644 --- a/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.cc +++ 
b/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.cc @@ -32,14 +32,15 @@ limitations under the License. #include "xla/service/hlo_verifier.h" #include "xla/service/layout_assignment.h" #include "xla/service/loop_schedule_linearizer.h" +#include "xla/stream_executor/device_description.h" #include "xla/xla.pb.h" namespace xla { namespace gpu { HloPassPipeline PrepareHloModuleForIrEmittingPipeline( - HloModule& hlo_module, - HloDataflowAnalysis::CanShareBuffer can_share_buffer) { + HloModule& hlo_module, HloDataflowAnalysis::CanShareBuffer can_share_buffer, + const se::DeviceDescription& device_description) { const DebugOptions& debug_options = hlo_module.config().debug_options(); // In some cases, we have to place the result of an instruction in a temporary @@ -83,8 +84,8 @@ HloPassPipeline PrepareHloModuleForIrEmittingPipeline( auto& sub_pipeline = pipeline.AddPass("horizontal-loop-fusion-for-copy"); // To fuse the copy. - sub_pipeline.AddPass(); - sub_pipeline.AddPass("copy_"); + sub_pipeline.AddPass(device_description); + sub_pipeline.AddPass(device_description, "copy_"); sub_pipeline.AddPass(); pipeline.AddPass(); return pipeline; diff --git a/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.h b/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.h index 095907f39794ac..87bca409f09639 100644 --- a/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.h +++ b/third_party/xla/xla/service/gpu/prepare_hlo_for_ir_emitting_pipeline.h @@ -19,6 +19,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" #include "xla/service/hlo_dataflow_analysis.h" +#include "xla/stream_executor/device_description.h" namespace xla { namespace gpu { @@ -28,7 +29,8 @@ namespace gpu { // correctness of the input module. 
HloPassPipeline PrepareHloModuleForIrEmittingPipeline( HloModule& hlo_module, - HloDataflowAnalysis::CanShareBuffer can_share_buffer); + HloDataflowAnalysis::CanShareBuffer can_share_buffer, + const se::DeviceDescription& device_description); } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/reduction_utils.cc b/third_party/xla/xla/service/gpu/reduction_utils.cc index 447c0427bbb07a..fc7e5b48dcfafa 100644 --- a/third_party/xla/xla/service/gpu/reduction_utils.cc +++ b/third_party/xla/xla/service/gpu/reduction_utils.cc @@ -16,32 +16,22 @@ limitations under the License. #include "xla/service/gpu/reduction_utils.h" #include -#include -#include #include #include #include "absl/algorithm/container.h" -#include "absl/base/const_init.h" #include "absl/strings/str_join.h" -#include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/layout_util.h" #include "xla/service/gpu/ir_emission_utils.h" -#include "xla/service/hlo_module_config.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/stream_executor/semantic_version.h" +#include "xla/stream_executor/device_description.h" #include "xla/util.h" #include "tsl/platform/logging.h" -#ifdef GOOGLE_CUDA -#include "xla/service/gpu/gpu_asm_opts_util.h" -#include "xla/stream_executor/cuda/cuda_asm_compiler.h" -#endif // GOOGLE_CUDA - namespace xla { namespace gpu { @@ -79,56 +69,6 @@ Vector3 PartitionShapeByMiddleDimensions( } // namespace -int64_t MinThreadsXRowReduction(const HloModuleConfig& hlo_module_config) { -#ifdef GOOGLE_CUDA - // The call to `GetAsmCompilerVersion` is expensive, but the result never - // changes during one execution and doesn't really depend on - // `hlo_module_config`. To avoid repeated calls, we cache the result in a - // static variable. 
- static absl::Mutex mutex(absl::kConstInit); - static std::atomic use_reduced_thread_count_atomic = nullptr; - - bool* use_reduced_thread_count = - use_reduced_thread_count_atomic.load(std::memory_order_acquire); - - if (use_reduced_thread_count == nullptr) { - absl::MutexLock lock(&mutex); - // We might have raced with another thread, so check again! - // Note: We can use relaxed memory ordering here because we hold - // the mutex lock and all updates happen under the same lock. - // When unsure, use release and acquire pairs for stores and loads. - use_reduced_thread_count = - use_reduced_thread_count_atomic.load(std::memory_order_relaxed); - - if (use_reduced_thread_count == nullptr) { - auto ptxas_config = - PtxOptsFromDebugOptions(hlo_module_config.debug_options()); - auto ptxas_version_tuple = - se::GetAsmCompilerVersion(ptxas_config.preferred_cuda_dir); - - use_reduced_thread_count = new bool(false); - - // ptxas versions prior to 12.2 have a very rare bug when very high - // register spilling occurs with some order of instructions, so use less - // threads to reduce register pressure. 
- if (!ptxas_version_tuple.ok() || - ptxas_version_tuple.value() < - stream_executor::SemanticVersion{12, 2, 0}) { - *use_reduced_thread_count = true; - } - - use_reduced_thread_count_atomic.store(use_reduced_thread_count, - std::memory_order_release); - } - } - - if (*use_reduced_thread_count) { - return 512; - } -#endif // GOOGLE_CUDA - return 1024; -} - Vector3 GetReductionTiling(const ReductionDimensions& reduction_dimensions) { if (reduction_dimensions.is_row_reduction) { int64_t tile_z = std::min(reduction_dimensions.dimensions[0], @@ -141,25 +81,27 @@ Vector3 GetReductionTiling(const ReductionDimensions& reduction_dimensions) { } int64_t ReductionDimensionRaceFreeBound( - const HloModuleConfig& hlo_module_config, - const ReductionDimensions& reduction_dimensions) { + const ReductionDimensions& reduction_dimensions, + const se::DeviceDescription& device_description) { Vector3 reduction_tiling = GetReductionTiling(reduction_dimensions); if (reduction_dimensions.is_row_reduction) { - return MinThreadsXRowReduction(hlo_module_config) * reduction_tiling[2]; + return MinThreadsXRowReduction() * reduction_tiling[2]; } - return WarpSize() * reduction_tiling[1]; + return WarpSize(device_description) * reduction_tiling[1]; } bool IsUnnestedReductionFasterThanElemental( - const ReductionDimensions& reduction_dimensions) { + const ReductionDimensions& reduction_dimensions, + const se::DeviceDescription& device_description) { + const int64_t warp_size = WarpSize(device_description); if (reduction_dimensions.is_row_reduction) { // For row reduction, the tile block is 1 x tile_size_x, and we are reducing // along tile_size_x which needs to be large enough to make the tiling // implementation efficient. // For very small reductions with a power-of-two size, we can fit multiple // reductions inside a single warp, which is more efficient than a loop. 
- return (reduction_dimensions.dimensions[2] >= WarpSize()) || - ((WarpSize() % reduction_dimensions.dimensions[2]) == 0); + return (reduction_dimensions.dimensions[2] >= warp_size) || + ((warp_size % reduction_dimensions.dimensions[2]) == 0); } // For column reduction, the tile block is tile_size_y x tile_size_x, and we @@ -170,15 +112,16 @@ bool IsUnnestedReductionFasterThanElemental( // Rule generated by sweeping the search space of small column reductions. bool prefer_elemental_emitter = - (major_size < WarpSize()) || - (major_size < 2 * WarpSize() && minor_size < WarpSize()) || - (major_size < 4 * WarpSize() && minor_size < 8) || - (major_size < 8 * WarpSize() && minor_size < 3); + (major_size < warp_size) || + (major_size < 2 * warp_size && minor_size < warp_size) || + (major_size < 4 * warp_size && minor_size < 8) || + (major_size < 8 * warp_size && minor_size < 3); return !prefer_elemental_emitter; } -bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce) { +bool IsReductionFromOrToContiguousDimensions( + const HloInstruction& reduce, const se::DeviceDescription& device_description) { if (reduce.opcode() != HloOpcode::kReduce) { return false; } @@ -201,23 +144,24 @@ bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce) { LayoutUtil::AreDimensionsConsecutive(operand_shape.layout(), dims_to_reduce)) && IsUnnestedReductionFasterThanElemental( - GetReductionKindAndContiguousComponents(reduce)); + GetReductionKindAndContiguousComponents(reduce), + device_description); } -bool ReductionIsRaceFree(const HloModuleConfig& hlo_module_config, - const ReductionDimensions& reduction_dimensions) { +bool ReductionIsRaceFree(const ReductionDimensions& reduction_dimensions, + const se::DeviceDescription& device_description) { if (reduction_dimensions.is_row_reduction) { return reduction_dimensions.dimensions[2] <= - ReductionDimensionRaceFreeBound(hlo_module_config, - reduction_dimensions) && + 
ReductionDimensionRaceFreeBound(reduction_dimensions, + device_description) && reduction_dimensions.dimensions[0] <= BatchedReductionRaceFreeBound(); } // Column reduction. return reduction_dimensions.dimensions[1] <= - ReductionDimensionRaceFreeBound(hlo_module_config, - reduction_dimensions); + ReductionDimensionRaceFreeBound(reduction_dimensions, + device_description); } std::ostream& operator<<(std::ostream& os, @@ -275,14 +219,14 @@ ReductionDimensions GetReductionKindAndContiguousComponents( return {/*is_row_reduction=*/false, shape_partition}; } -bool IsRealReductionHero(const HloInstruction& root, - const HloInstruction& hero) { - if (!IsReductionFromOrToContiguousDimensions(hero)) { +bool IsRealReductionHero(const HloInstruction& root, const HloInstruction& hero, + const se::DeviceDescription& device_description) { + if (!IsReductionFromOrToContiguousDimensions(hero, device_description)) { return false; } return &root == &hero || - ReductionIsRaceFree(hero.GetModule()->config(), - GetReductionKindAndContiguousComponents(hero)); + ReductionIsRaceFree(GetReductionKindAndContiguousComponents(hero), + device_description); } bool AreReductionsMultiOutputFusionCompatible( diff --git a/third_party/xla/xla/service/gpu/reduction_utils.h b/third_party/xla/xla/service/gpu/reduction_utils.h index 7e5e31bc464ce3..5a8c37dd3b0ae4 100644 --- a/third_party/xla/xla/service/gpu/reduction_utils.h +++ b/third_party/xla/xla/service/gpu/reduction_utils.h @@ -21,7 +21,7 @@ limitations under the License. #include "absl/container/inlined_vector.h" #include "xla/hlo/ir/hlo_instruction.h" -#include "xla/service/hlo_module_config.h" +#include "xla/stream_executor/device_description.h" #include "xla/util.h" namespace xla { @@ -29,7 +29,7 @@ namespace gpu { // Need at least 1024 threads/block for reasonable tree reduction // performance (assuming all data fits). 
-int64_t MinThreadsXRowReduction(const HloModuleConfig& hlo_module_config); +inline constexpr int64_t MinThreadsXRowReduction() { return 1024; } // When doing batched row reduction, how big the batch dimension could be. inline constexpr int64_t BatchedReductionRaceFreeBound() { return 8; } @@ -79,11 +79,14 @@ std::ostream& operator<<(std::ostream& os, // Returns true if using the reduction emitter is estimated to be faster than // using the elemental emitter. bool IsUnnestedReductionFasterThanElemental( - const ReductionDimensions& reduction_dimensions); + const ReductionDimensions& reduction_dimensions, + const se::DeviceDescription& device_description); // Returns true if either the dimensions being reduced or the dimensions being // kept are contiguous in the input of the reduce instruction. -bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce); +bool IsReductionFromOrToContiguousDimensions( + const HloInstruction& reduce, + const se::DeviceDescription& device_description); // Given the input shape and dimensions to reduce for a reduction, returns // ReductionDimensions. @@ -99,17 +102,17 @@ Vector3 GetReductionTiling(const ReductionDimensions& reduction_dimensions); // How big the reduction dimension can be to be race free. int64_t ReductionDimensionRaceFreeBound( - const HloModuleConfig& hlo_module_config, - const ReductionDimensions& reduction_dimensions); + const ReductionDimensions& reduction_dimensions, + const se::DeviceDescription& device_description); // Returns whether the given reduction can be safely generated without atomics : // that is, at most one block will write to every output element. -bool ReductionIsRaceFree(const HloModuleConfig& hlo_module_config, - const ReductionDimensions& reduction_dimensions); +bool ReductionIsRaceFree(const ReductionDimensions& reduction_dimensions, + const se::DeviceDescription& device_description); // Whether the instruction is a reduction hero for the given root. 
-bool IsRealReductionHero(const HloInstruction& root, - const HloInstruction& hero); +bool IsRealReductionHero(const HloInstruction& root, const HloInstruction& hero, + const se::DeviceDescription& device_description); // Whether `reduction_hero` is compatible with `first_reduce`. bool AreReductionsMultiOutputFusionCompatible( diff --git a/third_party/xla/xla/service/gpu/stream_executor_util_kernel_stub.cc b/third_party/xla/xla/service/gpu/stream_executor_util_kernel_stub.cc new file mode 100644 index 00000000000000..392d08eb63705b --- /dev/null +++ b/third_party/xla/xla/service/gpu/stream_executor_util_kernel_stub.cc @@ -0,0 +1,21 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +namespace xla::gpu::repeat_buffer_kernel { + +// Stub to make CPU build linker find undefined symbol. 
+void* kernel() { return nullptr; } + +} // namespace xla::gpu::repeat_buffer_kernel diff --git a/third_party/xla/xla/service/gpu/tests/BUILD b/third_party/xla/xla/service/gpu/tests/BUILD index 7dbc7016841c61..4b406a3b4a9bb3 100644 --- a/third_party/xla/xla/service/gpu/tests/BUILD +++ b/third_party/xla/xla/service/gpu/tests/BUILD @@ -411,42 +411,42 @@ xla_test( ], ) -xla_test( - name = "concatenate_emitter_test", - srcs = ["concatenate_emitter_test.cc"], - backends = ["gpu"], - deps = [ - ":gpu_codegen_test", - "//xla:error_spec", - "//xla/tests:hlo_test_base", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", - ], -) +# xla_test( +# name = "concatenate_emitter_test", +# srcs = ["concatenate_emitter_test.cc"], +# backends = ["gpu"], +# deps = [ +# ":gpu_codegen_test", +# "//xla:error_spec", +# "//xla/tests:hlo_test_base", +# "@local_tsl//tsl/platform:test", +# "@local_tsl//tsl/platform:test_main", +# ], +# ) -xla_test( - name = "transpose_emitter_test", - srcs = ["transpose_emitter_test.cc"], - backends = ["gpu"], - deps = [ - ":gpu_codegen_test", - "//xla:error_spec", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", - ], -) +# xla_test( +# name = "transpose_emitter_test", +# srcs = ["transpose_emitter_test.cc"], +# backends = ["gpu"], +# deps = [ +# ":gpu_codegen_test", +# "//xla:error_spec", +# "@local_tsl//tsl/platform:test", +# "@local_tsl//tsl/platform:test_main", +# ], +# ) -xla_test( - name = "reduction_emitter_test", - srcs = ["reduction_emitter_test.cc"], - backends = ["gpu"], - deps = [ - ":gpu_codegen_test", - "//xla:error_spec", - "@local_tsl//tsl/platform:test", - "@local_tsl//tsl/platform:test_main", - ], -) +# xla_test( +# name = "reduction_emitter_test", +# srcs = ["reduction_emitter_test.cc"], +# backends = ["gpu"], +# deps = [ +# ":gpu_codegen_test", +# "//xla:error_spec", +# "@local_tsl//tsl/platform:test", +# "@local_tsl//tsl/platform:test_main", +# ], +# ) xla_test( name = "gpu_ldg_test", @@ 
-623,35 +623,35 @@ lit_test_suite( name = "hlo_lit_tests", srcs = enforce_glob( [ - "add_preds.hlo", + # "add_preds.hlo", "calling_convention.hlo", "dot_bf16.hlo", - "dynamic_update_slice_inplace.hlo", - "fused_scatter.hlo", - "fused_slice.hlo", + # "dynamic_update_slice_inplace.hlo", + # "fused_scatter.hlo", + # "fused_slice.hlo", "kernel_reuse.hlo", "pad_to_static.hlo", - "reduce_atomic_min.hlo", - "reduce_column_layout_change.hlo", - "reduce_f64_column.hlo", - "reduce_large_row_to_scalar.hlo", - "reduce_row_vectorized.hlo", - "reduce_to_scalar_vectorized.hlo", - "reduce_unnested.hlo", - "reduce_variadic_column.hlo", - "reduction_vectorization_sm_all.hlo", + # "reduce_atomic_min.hlo", + # "reduce_column_layout_change.hlo", + # "reduce_f64_column.hlo", + # "reduce_large_row_to_scalar.hlo", + # "reduce_row_vectorized.hlo", + # "reduce_to_scalar_vectorized.hlo", + # "reduce_unnested.hlo", + # "reduce_variadic_column.hlo", + # "reduction_vectorization_sm_all.hlo", "rng_get_and_update_state.hlo", - "scatter.hlo", - "scatter_bf16.hlo", + # "scatter.hlo", + # "scatter_bf16.hlo", "select_and_scatter.hlo", "single_instruction.hlo", "slice_to_dynamic.hlo", "sorting.hlo", - "transpose_021.hlo", - "transpose_021_extra_output.hlo", - "transpose_10.hlo", - "transpose_210.hlo", - "transpose_210_extra_output.hlo", + # "transpose_021.hlo", + # "transpose_021_extra_output.hlo", + # "transpose_10.hlo", + # "transpose_210.hlo", + # "transpose_210_extra_output.hlo", "triton_naming.hlo", ], include = [ diff --git a/third_party/xla/xla/service/gpu/tests/add_preds.hlo b/third_party/xla/xla/service/gpu/tests/add_preds.hlo deleted file mode 100644 index d86113ae2ad603..00000000000000 --- a/third_party/xla/xla/service/gpu/tests/add_preds.hlo +++ /dev/null @@ -1,26 +0,0 @@ -// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb | FileCheck %s - -// CHECK: define{{( amdgpu_kernel)?}} void 
@fusion({{.*}}%[[ARG0:.*]], {{.*}}%[[ARG1:.*]], -// CHECK: %[[A:.*]] = load {{.*}} ptr %[[ARG0]] -// CHECK: %[[B:.*]] = load {{.*}} ptr %[[ARG1]] -// CHECK: or {{.*}} %[[A]], %[[B]] - -HloModule xla_computation_f.8, is_scheduled=true - -// Since the conversion to MLIR goes through completely different code paths -// depending on whether an op is fused or not, this separately tests pred -// "addition" in fused context. - -%fused_computation (param_0.1: pred[], param_1: pred[]) -> pred[] { - %param_0.1 = pred[] parameter(0) - %param_1 = pred[] parameter(1) - %add.1 = pred[] add(pred[] %param_0.1, pred[] %param_1) - ROOT %not.1 = pred[] not(pred[] %add.1) -} - -ENTRY %xla_computation_f.8 (parameter.0: pred[], parameter.1: pred[]) -> (pred[]) { - %parameter.0 = pred[] parameter(0) - %parameter.1 = pred[] parameter(1) - %fusion = pred[] fusion(pred[] %parameter.0, pred[] %parameter.1), kind=kLoop, calls=%fused_computation - ROOT %tuple.7 = (pred[]) tuple(pred[] %fusion) -} diff --git a/third_party/xla/xla/service/gpu/tests/concatenate_emitter_test.cc b/third_party/xla/xla/service/gpu/tests/concatenate_emitter_test.cc deleted file mode 100644 index 66920a3f182792..00000000000000 --- a/third_party/xla/xla/service/gpu/tests/concatenate_emitter_test.cc +++ /dev/null @@ -1,177 +0,0 @@ -/* Copyright 2024 The OpenXLA Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include - -#include "xla/error_spec.h" -#include "xla/service/gpu/tests/gpu_codegen_test.h" -#include "xla/tests/hlo_test_base.h" -#include "tsl/platform/test.h" - -namespace xla { -namespace { - -class ConcatenateEmitterTest : public gpu::GpuCodegenTest { - protected: - ConcatenateEmitterTest() = default; - DebugOptions GetDebugOptionsForTest() override { - auto opts = HloTestBase::GetDebugOptionsForTest(); - opts.set_xla_gpu_mlir_emitter_level(0); - return opts; - } -}; - -TEST_F(ConcatenateEmitterTest, Simple) { - const char* const kHloString = R"( - HloModule module - - ENTRY main { - param0 = f32[128] parameter(0) - param1 = f32[128] parameter(1) - ROOT concat = f32[256] concatenate(param0, param1), dimensions={0} - })"; - - auto expected_ir = R"( -; CHECK-DAG: %[[ARG0:.*]] = addrspacecast ptr %arg0 -; CHECK-DAG: %[[ARG1:.*]] = addrspacecast ptr %arg1 -; CHECK-DAG: %[[ARG2:.*]] = addrspacecast ptr %arg2 -; CHECK: %[[PTR:.*]] = getelementptr float, ptr addrspace(1) %[[ARG0]] -; CHECK-DAG: %[[VAL:.*]] = load float, ptr addrspace(1) %[[PTR]] -; CHECK-DAG: %[[DST:.*]] = getelementptr inbounds [256 x float], ptr addrspace(1) %[[ARG2]] -; CHECK: store float %[[VAL]], ptr addrspace(1) %[[DST]] -; CHECK: %[[PTR:.*]] = getelementptr float, ptr addrspace(1) %[[ARG1]] -; CHECK-DAG: %[[VAL:.*]] = load float, ptr addrspace(1) %[[PTR]] -; CHECK-DAG: %[[PTR:.*]] = getelementptr inbounds i8, ptr addrspace(1) %[[DST]], i64 512 -; CHECK: store float %[[VAL]], ptr addrspace(1) %[[PTR]] -; CHECK: !"reqntidx", i32 128 -)"; - CompileAndVerifyIr(kHloString, MakePlatformSpecificLlvm(expected_ir), - /*match_optimized_ir=*/true, - /*run_optimization_passes=*/false); - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(ConcatenateEmitterTest, PrologueAndEpilogue) { - const char* const kHloString = R"( - HloModule module - - fused_computation { - param0 = f32[128] 
parameter(0) - negate = f32[128] negate(param0) - param1 = f32[128] parameter(1) - concat = f32[256] concatenate(negate, param1), dimensions={0} - param2 = f32[256] parameter(2) - ROOT add = f32[256] add(concat, param2) - } - - ENTRY main { - param0 = f32[128] parameter(0) - param1 = f32[128] parameter(1) - param2 = f32[256] parameter(2) - ROOT %fusion = f32[256] fusion(param0, param1, param2), kind=kInput, calls=fused_computation - })"; - - auto expected_ir = R"( -; CHECK-DAG: %[[ARG0:.*]] = addrspacecast ptr %arg0 -; CHECK-DAG: %[[ARG1:.*]] = addrspacecast ptr %arg1 -; CHECK-DAG: %[[ARG2:.*]] = addrspacecast ptr %arg2 -; CHECK-DAG: %[[ARG3:.*]] = addrspacecast ptr %arg3 -; CHECK: %[[PTR:.*]] = getelementptr float, ptr addrspace(1) %[[ARG0]] -; CHECK: %[[RHS:.*]] = load float, ptr addrspace(1) %[[PTR]] -; CHECK: %[[SRC:.*]] = getelementptr inbounds [256 x float], ptr addrspace(1) %[[ARG2]] -; CHECK: %[[LHS:.*]] = load float, ptr addrspace(1) %[[SRC]] -; CHECK: %[[VAL:.*]] = fsub float %[[LHS]], %[[RHS]] -; CHECK: %[[DST:.*]] = getelementptr inbounds [256 x float], ptr addrspace(1) %[[ARG3]] -; CHECK: store float %[[VAL]], ptr addrspace(1) %[[DST]] -; CHECK: %[[PTR:.*]] = getelementptr float, ptr addrspace(1) %[[ARG1]] -; CHECK: %[[LHS:.*]] = load float, ptr addrspace(1) %[[PTR]] -; CHECK: %[[PTR:.*]] = getelementptr inbounds i8, ptr addrspace(1) %[[SRC]], i64 512 -; CHECK: %[[RHS:.*]] = load float, ptr addrspace(1) %[[PTR]] -; CHECK: %[[VAL:.*]] = fadd float %[[LHS]], %[[RHS]] -; CHECK: %[[PTR:.*]] = getelementptr inbounds i8, ptr addrspace(1) %[[DST]], i64 512 -; CHECK: store float %[[VAL]], ptr addrspace(1) %[[PTR]] -)"; - CompileAndVerifyIr(kHloString, MakePlatformSpecificLlvm(expected_ir), - /*match_optimized_ir=*/true, - /*run_optimization_passes=*/false); - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(ConcatenateEmitterTest, MajorDimension) { - const char* const kHloString = R"( - HloModule module - - fused_computation { - 
param0 = f32[16,16] parameter(0) - param1 = f32[16,16] parameter(1) - ROOT concat = f32[32,16] concatenate(param0, param1), dimensions={0} - } - - ENTRY main { - param0 = f32[16,16] parameter(0) - param1 = f32[16,16] parameter(1) - ROOT %fusion = f32[32,16] fusion(param0, param1), kind=kInput, calls=fused_computation - })"; - - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(ConcatenateEmitterTest, DifferentSizes) { - const char* const kHloString = R"( - HloModule module - - ENTRY main { - param0 = f32[112] parameter(0) - param1 = f32[128] parameter(1) - ROOT concat = f32[240] concatenate(param0, param1), dimensions={0} - })"; - - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(ConcatenateEmitterTest, RepeatedInput) { - const char* const kHloString = R"( - HloModule module - - ENTRY main { - param0 = f32[128] parameter(0) - ROOT concat = f32[256] concatenate(param0, param0), dimensions={0} - })"; - - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -TEST_F(ConcatenateEmitterTest, BitcastEpilogue) { - const char* const kHloString = R"( - HloModule module - - fused_computation { - param0 = f32[128] parameter(0) - param1 = f32[128] parameter(1) - concat = f32[256] concatenate(param0, param1), dimensions={0} - ROOT bitcast = f32[1,16,16] bitcast(concat) - } - - ENTRY main { - param0 = f32[128] parameter(0) - param1 = f32[128] parameter(1) - ROOT %fusion = f32[1,16,16] fusion(param0, param1), kind=kInput, calls=fused_computation - })"; - - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{1e-3})); -} - -} // namespace -} // namespace xla diff --git a/third_party/xla/xla/service/gpu/tests/dynamic_update_slice_inplace.hlo b/third_party/xla/xla/service/gpu/tests/dynamic_update_slice_inplace.hlo deleted file mode 100644 index 8fd8c1cbe667d3..00000000000000 --- a/third_party/xla/xla/service/gpu/tests/dynamic_update_slice_inplace.hlo +++ /dev/null @@ -1,349 +0,0 @@ -// This test 
explicitly disables MLIR emitters. The unit tests in -// in_place_dynamic_update_slice_mlir_test cover what is tested here - when -// the flag is removed, this test file can be deleted. -// RUN: hlo-opt %s --xla_gpu_mlir_emitter_level=0 --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb --split-input-file | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %s - -// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py - -// CHECK-LABEL: entry: -// CHECK: %[[VAL_0:.*]] = load i32, ptr @0, align 4 -// CHECK: %[[VAL_1:.*]] = icmp sge i32 0, %[[VAL_0]] -// CHECK: %[[VAL_2:.*]] = select i1 %[[VAL_1]], i32 0, i32 %[[VAL_0]] -// CHECK: %[[VAL_3:.*]] = icmp sle i32 49, %[[VAL_2]] -// CHECK: %[[VAL_4:.*]] = select i1 %[[VAL_3]], i32 49, i32 %[[VAL_2]] -// CHECK: %[[VAL_5:.*]] = load i32, ptr @0, align 4 -// CHECK: %[[VAL_6:.*]] = icmp sge i32 0, %[[VAL_5]] -// CHECK: %[[VAL_7:.*]] = select i1 %[[VAL_6]], i32 0, i32 %[[VAL_5]] -// CHECK: %[[VAL_8:.*]] = icmp sle i32 0, %[[VAL_7]] -// CHECK: %[[VAL_9:.*]] = select i1 %[[VAL_8]], i32 0, i32 %[[VAL_7]] -// CHECK: %[[VAL_10:.*]] = load i32, ptr @0, align 4 -// CHECK: %[[VAL_11:.*]] = icmp sge i32 0, %[[VAL_10]] -// CHECK: %[[VAL_12:.*]] = select i1 %[[VAL_11]], i32 0, i32 %[[VAL_10]] -// CHECK: %[[VAL_13:.*]] = icmp sle i32 0, %[[VAL_12]] -// CHECK: %[[VAL_14:.*]] = select i1 %[[VAL_13]], i32 0, i32 %[[VAL_12]] -// CHECK-PTX: %[[VAL_15:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x -// CHECK-GCN: %[[VAL_15:.*]] = call i32 @llvm.amdgcn.workgroup.id.x -// CHECK: %[[VAL_16:.*]] = zext i32 %[[VAL_15]] to i64 -// CHECK-PTX: %[[VAL_17:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x -// CHECK-GCN: %[[VAL_17:.*]] = call i32 @llvm.amdgcn.workitem.id.x -// CHECK: %[[VAL_18:.*]] = zext i32 %[[VAL_17]] to i64 -// CHECK-PTX: %[[VAL_19:.*]] = mul nuw nsw i64 %[[VAL_16]], 128 -// CHECK-GCN: %[[VAL_19:.*]] = mul nuw nsw i64 %[[VAL_16]], 256 -// 
CHECK: %[[VAL_20:.*]] = add nuw nsw i64 %[[VAL_19]], %[[VAL_18]] -// CHECK: %[[VAL_21:.*]] = icmp ult i64 %[[VAL_20]], 98304 -// CHECK: call void @llvm.assume(i1 %[[VAL_21]]) -// CHECK: %[[VAL_22:.*]] = add nuw nsw i64 %[[VAL_20]], 0 -// CHECK: %[[VAL_23:.*]] = udiv i64 %[[VAL_22]], 1 -// CHECK: %[[VAL_24:.*]] = urem i64 %[[VAL_23]], 1024 -// CHECK: %[[VAL_25:.*]] = udiv i64 %[[VAL_22]], 1024 -// CHECK: %[[VAL_26:.*]] = urem i64 %[[VAL_25]], 96 -// CHECK: %[[VAL_27:.*]] = udiv i64 %[[VAL_22]], 98304 -// CHECK: %[[VAL_28:.*]] = icmp ult i64 %[[VAL_20]], 98304 -// CHECK: br i1 %[[VAL_28]], label %[[VAL_29:.*]], label %[[VAL_30:.*]] -// CHECK: dynamic-update-slice.in_bounds-after: ; preds = %[[VAL_29]], %[[VAL_31:.*]] -// CHECK: ret void -// CHECK: dynamic-update-slice.in_bounds-true: ; preds = %[[VAL_31]] -// CHECK: %[[VAL_32:.*]] = sext i32 %[[VAL_4]] to i64 -// CHECK: %[[VAL_33:.*]] = add i64 %[[VAL_32]], %[[VAL_27]] -// CHECK: %[[VAL_34:.*]] = sext i32 %[[VAL_9]] to i64 -// CHECK: %[[VAL_35:.*]] = add i64 %[[VAL_34]], %[[VAL_26]] -// CHECK: %[[VAL_36:.*]] = sext i32 %[[VAL_14]] to i64 -// CHECK: %[[VAL_37:.*]] = add i64 %[[VAL_36]], %[[VAL_24]] -// CHECK: %[[VAL_38:.*]] = getelementptr half, ptr %[[VAL_39:.*]], i64 %[[VAL_20]] -// CHECK: %[[VAL_40:.*]] = getelementptr inbounds half, ptr %[[VAL_38]], i64 0 -// CHECK: %[[VAL_41:.*]] = load half, ptr %[[VAL_40]], align 2, !invariant.load -// CHECK: %[[VAL_42:.*]] = getelementptr inbounds [50 x [96 x [1024 x half]]], ptr %[[VAL_43:.*]], i64 0, i64 %[[VAL_33]], i64 %[[VAL_35]], i64 %[[VAL_37]] -// CHECK: store half %[[VAL_41]], ptr %[[VAL_42]], align 2 -// CHECK: br label %[[VAL_30]] - -HloModule TestModule, is_scheduled=true - -fusion.1 { - p.0 = f16[50,96,1024]{2,1,0} parameter(0) - p.1 = f16[1,96,1024]{2,1,0} parameter(1) - c.0 = s32[] constant(0) - ROOT %dynamic-update-slice = f16[50,96,1024]{2,1,0} dynamic-update-slice(p.0, p.1, c.0, c.0, c.0) -} - -ENTRY entry { - p.0 = f16[50,96,1024]{2,1,0} parameter(0) - p.1 = 
f16[1,96,1024]{2,1,0} parameter(1) - ROOT f1 = f16[50,96,1024] fusion(p.0, p.1), kind=kLoop, calls=fusion.1 -} - -// ----- - -// CHECK-LABEL: @fusion -// CHECK-SAME: %[[ARG0:[A-Za-z0-9]*]], -// CHECK-SAME: %[[ARG1:[A-Za-z0-9]*]], -// CHECK-SAME: %[[ARG2:[A-Za-z0-9]*]] -// CHECK: call void @llvm.assume -// CHECK: %[[COND:.*]] = icmp ult i32 %[[LINEAR_INDEX:.*]], 6 -// CHECK: br i1 %[[COND]], label %[[FUSION:.*]].in_bounds-true, label %[[FUSION]].in_bounds-after -// CHECK: [[FUSION]].in_bounds-after: -// CHECK: ret void -// CHECK: [[FUSION]].in_bounds-true: -// CHECK: br i1 %{{.*}}, label %[[SLICE:.*]]-true, label %[[SLICE]]-false -// CHECK: [[SLICE]]-after: -// CHECK-PTX: %[[VAL_85:.*]] = load i32, ptr %[[RET_VALUE_ADDR:.*]], align 4 -// CHECK-GCN: %[[VAL_85:.*]] = load i32, ptr addrspace(5) %[[RET_VALUE_ADDR:.*]], align 4 -// CHECK: %[[VAL_86:.*]] = getelementptr i32, ptr %[[ARG2]], i32 %[[LINEAR_INDEX]] -// CHECK: %[[VAL_88:.*]] = getelementptr inbounds i32, ptr %[[VAL_86]], i32 0 -// CHECK: store i32 %[[VAL_85]], ptr %[[VAL_88]], align 4 -// CHECK: br label %[[FUSION]].in_bounds-after -// CHECK: [[SLICE]]-true: -// CHECK: %[[VAL_108:.*]] = getelementptr inbounds [6 x i32], ptr %[[ARG0]], i32 0, i32 %[[VAL_106:.*]] -// CHECK: %[[VAL_110:.*]] = load i32, ptr %[[VAL_108]], align 4, !invariant.load -// CHECK: %[[VAL_111:.*]] = load i32, ptr @1, align 4 -// CHECK: %[[VAL_112:.*]] = add i32 %[[VAL_110]], %[[VAL_111]] -// CHECK-PTX: store i32 %[[VAL_112]], ptr %[[RET_VALUE_ADDR]], align 4 -// CHECK-GCN: store i32 %[[VAL_112]], ptr addrspace(5) %[[RET_VALUE_ADDR]], align 4 -// CHECK: br label %[[SLICE]]-after -// CHECK: [[SLICE]]-false: -// CHECK: %[[VAL_118:.*]] = getelementptr i32, ptr %[[ARG0]], i32 %[[LINEAR_INDEX]] -// CHECK: %[[VAL_119:.*]] = getelementptr inbounds i32, ptr %[[VAL_118]], i32 0 -// CHECK: %[[VAL_120:.*]] = load i32, ptr %[[VAL_119]], align 4, !invariant.load -// CHECK-PTX: store i32 %[[VAL_120]], ptr %[[RET_VALUE_ADDR]], align 4 -// CHECK-GCN: store 
i32 %[[VAL_120]], ptr addrspace(5) %[[RET_VALUE_ADDR]], align 4 -// CHECK: br label %[[SLICE]]-after - -HloModule fusion, is_scheduled=true - -fused_computation { - param_0.1 = s32[6]{0} parameter(0) - bitcast = s32[2,3]{1,0} bitcast(param_0.1) - zero = s32[] constant(0) - param_1.1 = s32[] parameter(1) - dynamic-slice = s32[1,1]{1,0} dynamic-slice(bitcast, param_1.1, zero), dynamic_slice_sizes={1,1} - one = s32[] constant(1) - bitcasted_one = s32[1,1]{1,0} bitcast(one) - add = s32[1,1] add(dynamic-slice, bitcasted_one) - dynamic-update-slice = s32[2,3]{1,0} dynamic-update-slice(bitcast, add, param_1.1, zero) - ROOT bitcast.1 = s32[6]{0} bitcast(dynamic-update-slice) -} - -ENTRY main { - param_0 = s32[6]{0} parameter(0) - param_1 = s32[] parameter(1) - ROOT fusion = s32[6]{0} fusion(param_0, param_1), kind=kInput, calls=fused_computation -} - -// ----- - -// CHECK-LABEL: @fusion_root_multiple -// CHECK-SAME: %[[ARG0:[A-Za-z0-9]*]], -// CHECK-SAME: %[[ARG1:[A-Za-z0-9]*]], -// CHECK-SAME: %[[ARG2:[A-Za-z0-9]*]], -// CHECK-SAME: %[[ARG3:[A-Za-z0-9]*]], -// CHECK-SAME: %[[ARG4:[A-Za-z0-9]*]]) -// CHECK: %[[COND:[A-Za-z0-9]*]] = icmp ult i64 %[[LINEAR_INDEX:.*]], 132 -// CHECK: br i1 %[[COND]], label %[[DUS0:.*]].in_bounds-true, label %[[DUS0]].in_bounds-after -// CHECK: [[DUS0]].in_bounds-after: -// CHECK: %[[COND2:[A-Za-z0-9]*]] = icmp ult i64 %[[LINEAR_INDEX:.*]], 132 -// CHECK: br i1 %[[COND2]], label %[[DUS1:.*]].in_bounds-true, label %[[DUS1]].in_bounds-after -// CHECK: [[DUS1]].in_bounds-after: -// CHECK-NEXT: ret void -// CHECK: [[DUS0]].in_bounds-true -// CHECK-DAG: getelementptr bfloat, ptr %[[ARG1]], i64 %[[VAL_141:.*]] -// CHECK-DAG: getelementptr bfloat, ptr %[[ARG3]], i64 %[[VAL_141]] -// CHECK-DAG: getelementptr inbounds [10 x [11 x [12 x bfloat]]], ptr %[[ARG0]], i64 0, i64 %[[VAL_185:.*]], i64 %[[VAL_187:.*]], i64 %[[VAL_189:.*]] -// CHECK: [[DUS1]].in_bounds-true -// CHECK-DAG: getelementptr bfloat, ptr %[[ARG1]], i64 %[[VAL_173:.*]] -// CHECK-DAG: 
getelementptr bfloat, ptr %[[ARG3]], i64 %[[VAL_173]] -// CHECK-DAG: getelementptr inbounds [8 x [11 x [12 x bfloat]]], ptr %[[ARG2]], i64 0, i64 %[[VAL_208:.*]], i64 %[[VAL_210:.*]], i64 %[[VAL_212:.*]] - -HloModule MultipleInplaceDus, is_scheduled=true, input_output_alias={ {0}: (0, {}), {1}: (2, {}) } - -fused_computation { - p0 = bf16[10,11,12] parameter(0) - p1 = bf16[1,11,12] parameter(1) - p2 = bf16[8,11,12] parameter(2) - p3 = bf16[1,11,12] parameter(3) - p4 = s32[] parameter(4) - c0 = s32[] constant(0) - cmp = pred[] compare(p4, c0), direction=EQ - broadcast = pred[1,11,12] broadcast(cmp), dimensions={} - select = bf16[1,11,12] select(broadcast, p1, p3) - dus0 = bf16[10,11,12] dynamic-update-slice(p0, select, c0, c0, c0) - dus1 = bf16[8,11,12] dynamic-update-slice(p2, select, c0, c0, c0) - ROOT tuple = (bf16[10,11,12], bf16[8,11,12]) tuple(dus0, dus1) -} - -ENTRY main { - p0 = bf16[10,11,12] parameter(0) - p1 = bf16[1,11,12] parameter(1) - p2 = bf16[8,11,12] parameter(2) - p3 = bf16[1,11,12] parameter(3) - p4 = s32[] parameter(4) - ROOT fusion_root_multiple = (bf16[10,11,12], bf16[8,11,12]) fusion(p0, p1, p2, p3, p4), kind=kLoop, calls=fused_computation -} - -// ----- - -// CHECK-LABEL: @fusion_root_multiple_transpose_bitcast -// CHECK-SAME: %[[ARG0:[A-Za-z0-9]*]], -// CHECK-SAME: %[[ARG1:[A-Za-z0-9]*]], -// CHECK-SAME: %[[ARG2:[A-Za-z0-9]*]], -// CHECK-SAME: %[[ARG3:[A-Za-z0-9]*]], -// CHECK-SAME: %[[ARG4:[A-Za-z0-9]*]]) -// CHECK: %[[COND:[A-Za-z0-9]*]] = icmp ult i64 %[[LINEAR_INDEX:.*]], 132 -// CHECK: br i1 %[[COND]], label %[[DUS0:.*]].in_bounds-true, label %[[DUS0]].in_bounds-after -// CHECK: [[DUS0]].in_bounds-after: -// CHECK: %[[COND2:[A-Za-z0-9]*]] = icmp ult i64 %[[LINEAR_INDEX:.*]], 132 -// CHECK: br i1 %[[COND2]], label %[[DUS1:.*]].in_bounds-true, label %[[DUS1]].in_bounds-after -// CHECK: [[DUS1]].in_bounds-after: -// CHECK-NEXT: ret void -// CHECK: [[DUS0]].in_bounds-true -// CHECK-DAG: getelementptr bfloat, ptr %[[ARG1]], i64 
%[[VAL_247:.*]] -// CHECK-DAG: getelementptr bfloat, ptr %[[ARG3]], i64 %[[VAL_247]] -// CHECK-DAG: getelementptr inbounds [10 x [11 x [12 x bfloat]]], ptr %[[ARG0]], i64 0, i64 %[[VAL_291:.*]], i64 %[[VAL_293:.*]], i64 %[[VAL_295:.*]] -// CHECK: [[DUS1]].in_bounds-true -// CHECK-DAG: getelementptr bfloat, ptr %[[ARG1]], i64 %[[VAL_279:.*]] -// CHECK-DAG: getelementptr bfloat, ptr %[[ARG3]], i64 %[[VAL_279]] -// CHECK-DAG: getelementptr inbounds [8 x [11 x [12 x bfloat]]], ptr %[[ARG2]], i64 0, i64 %[[VAL_314:.*]], i64 %[[VAL_316:.*]], i64 %[[VAL_318:.*]] - -HloModule MultipleInplaceDusWithTransposeBitcastToTheRoot, is_scheduled=true, input_output_alias={ {0}: (0, {}), {1}: (2, {}) } - -fused_computation { - p0 = bf16[10,11,12] parameter(0) - p1 = bf16[1,11,12] parameter(1) - p2 = bf16[8,11,12] parameter(2) - p3 = bf16[1,11,12] parameter(3) - p4 = s32[] parameter(4) - c0 = s32[] constant(0) - cmp = pred[] compare(p4, c0), direction=EQ - broadcast = pred[1,11,12] broadcast(cmp), dimensions={} - select = bf16[1,11,12] select(broadcast, p1, p3) - dus0 = bf16[10,11,12] dynamic-update-slice(p0, select, c0, c0, c0) - bitcasted_dus0 = bf16[11,10,12] bitcast(dus0) - dus1 = bf16[8,11,12] dynamic-update-slice(p2, select, c0, c0, c0) - ROOT tuple = (bf16[11,10,12], bf16[8,11,12]) tuple(bitcasted_dus0, dus1) -} - -ENTRY main { - p0 = bf16[10,11,12] parameter(0) - p1 = bf16[1,11,12] parameter(1) - p2 = bf16[8,11,12] parameter(2) - p3 = bf16[1,11,12] parameter(3) - p4 = s32[] parameter(4) - ROOT fusion_root_multiple_transpose_bitcast = (bf16[11,10,12], bf16[8,11,12]) fusion(p0, p1, p2, p3, p4), kind=kLoop, calls=fused_computation -} - -// ----- - -// CHECK-LABEL: @fusion_root_transpose_bitcast -// CHECK-SAME: %[[ARG0:[A-Za-z0-9]*]], -// CHECK-SAME: %[[ARG1:[A-Za-z0-9]*]], -// CHECK-SAME: %[[ARG2:[A-Za-z0-9]*]], -// CHECK-SAME: %[[ARG3:[A-Za-z0-9]*]]) -// CHECK: %[[COND:[A-Za-z0-9]*]] = icmp ult i64 %[[LINEAR_INDEX:.*]], 132 -// CHECK: br i1 %[[COND]], label 
%[[DUS0:.*]].in_bounds-true, label %[[DUS0]].in_bounds-after -// CHECK: [[DUS0]].in_bounds-after: -// CHECK-NEXT: ret void -// CHECK: [[DUS0]].in_bounds-true -// CHECK-DAG: getelementptr bfloat, ptr %[[ARG1]], i64 %[[VAL_353:.*]] -// CHECK-DAG: getelementptr bfloat, ptr %[[ARG2]], i64 %[[VAL_353]] -// CHECK-DAG: getelementptr inbounds [10 x [11 x [12 x bfloat]]], ptr %[[ARG0]], i64 0, i64 %[[VAL_366:.*]], i64 %[[VAL_368:.*]], i64 %[[VAL_370:.*]] - -HloModule SingleInplaceDusWithTransposeBitcastToTheRoot, is_scheduled=true, input_output_alias={ {}: (0, {}) } - -single_inplace_dus_with_transpose_bitcast { - p0 = bf16[10,11,12] parameter(0) - p1 = bf16[1,11,12] parameter(1) - p2 = bf16[1,11,12] parameter(2) - p3 = s32[] parameter(3) - c0 = s32[] constant(0) - cmp = pred[] compare(p3, c0), direction=EQ - broadcast = pred[1,11,12] broadcast(cmp), dimensions={} - select = bf16[1,11,12] select(broadcast, p1, p2) - dus0 = bf16[10,11,12] dynamic-update-slice(p0, select, c0, c0, c0) - ROOT bitcasted_dus0 = bf16[11,10,12] bitcast(dus0) -} - -ENTRY main { - p0 = bf16[10,11,12] parameter(0) - p1 = bf16[1,11,12] parameter(1) - p2 = bf16[1,11,12] parameter(2) - p3 = s32[] parameter(3) - ROOT fusion_root_transpose_bitcast = bf16[11,10,12] fusion(p0, p1, p2, p3), kind=kLoop, calls=single_inplace_dus_with_transpose_bitcast -} - -// ----- - -// CHECK-LABEL: @fusion_root_reshape_bitcast -// CHECK-SAME: %[[ARG0:[A-Za-z0-9]*]], -// CHECK-SAME: %[[ARG1:[A-Za-z0-9]*]], -// CHECK-SAME: %[[ARG2:[A-Za-z0-9]*]], -// CHECK-SAME: %[[ARG3:[A-Za-z0-9]*]]) -// CHECK: %[[COND:[A-Za-z0-9]*]] = icmp ult i64 %[[LINEAR_INDEX:.*]], 132 -// CHECK: br i1 %[[COND]], label %[[DUS0:.*]].in_bounds-true, label %[[DUS0]].in_bounds-after -// CHECK: [[DUS0]].in_bounds-after: -// CHECK-NEXT: ret void -// CHECK: [[DUS0]].in_bounds-true -// CHECK-DAG: getelementptr bfloat, ptr %[[ARG1]], i64 %[[VAL_408:.*]] -// CHECK-DAG: getelementptr bfloat, ptr %[[ARG2]], i64 %[[VAL_408:.*]] -// CHECK-DAG: getelementptr inbounds 
[10 x [11 x [12 x bfloat]]], ptr %[[ARG0]], i64 0, i64 %[[VAL_421:.*]], i64 %[[VAL_423:.*]], i64 %[[VAL_425:.*]] - -HloModule SingleInplaceDusWithReshapeBitcastToTheRoot, is_scheduled=true, input_output_alias={ {}: (0, {}) } - -single_inplace_dus_with_reshape_bitcast { - p0 = bf16[10,11,12] parameter(0) - p1 = bf16[1,11,12] parameter(1) - p2 = bf16[1,11,12] parameter(2) - p3 = s32[] parameter(3) - c0 = s32[] constant(0) - cmp = pred[] compare(p3, c0), direction=EQ - broadcast = pred[1,11,12] broadcast(cmp), dimensions={} - select = bf16[1,11,12] select(broadcast, p1, p2) - dus0 = bf16[10,11,12] dynamic-update-slice(p0, select, c0, c0, c0) - ROOT bitcasted_dus0 = bf16[10,11,6,2] bitcast(dus0) -} - -ENTRY main { - p0 = bf16[10,11,12] parameter(0) - p1 = bf16[1,11,12] parameter(1) - p2 = bf16[1,11,12] parameter(2) - p3 = s32[] parameter(3) - ROOT fusion_root_reshape_bitcast = bf16[10,11,6,2] fusion(p0, p1, p2, p3), kind=kLoop, calls=single_inplace_dus_with_reshape_bitcast -} - -// ----- - -// CHECK-LABEL: @fusion_root_bitcast_both_ways -// CHECK-SAME: %[[ARG0:[A-Za-z0-9]*]], -// CHECK-SAME: %[[ARG1:[A-Za-z0-9]*]], -// CHECK-SAME: %[[ARG2:[A-Za-z0-9]*]], -// CHECK-SAME: %[[ARG3:[A-Za-z0-9]*]]) -// CHECK: %[[COND:[A-Za-z0-9]*]] = icmp ult i64 %[[LINEAR_INDEX:.*]], 132 -// CHECK: br i1 %[[COND]], label %[[DUS0:.*]].in_bounds-true, label %[[DUS0]].in_bounds-after -// CHECK: [[DUS0]].in_bounds-after: -// CHECK-NEXT: ret void -// CHECK: [[DUS0]].in_bounds-true -// CHECK-DAG: getelementptr bfloat, ptr %[[ARG1]], i64 %[[VAL_468:.*]] -// CHECK-DAG: getelementptr bfloat, ptr %[[ARG2]], i64 %[[VAL_468]] -// CHECK-DAG: getelementptr inbounds [10 x [6 x [2 x [11 x bfloat]]]], ptr %[[ARG0]], i64 0, i64 %[[VAL_483:.*]], i64 %[[VAL_485:.*]], i64 %[[VAL_487:.*]], i64 %[[VAL_489:.*]] - -HloModule SingleInplaceDusWithBitcastToTheRootAndFromTheParameter, is_scheduled=true, input_output_alias={ {}: (0, {}) } - -single_inplace_dus_with_bitcast_to_the_root_and_from_the_parameter { - p0 = 
bf16[10,11,12] parameter(0) - p1 = bf16[1,11,12] parameter(1) - p2 = bf16[1,11,12] parameter(2) - p3 = s32[] parameter(3) - c0 = s32[] constant(0) - cmp = pred[] compare(p3, c0), direction=EQ - broadcast = pred[1,11,12] broadcast(cmp), dimensions={} - select = bf16[1,11,12] select(broadcast, p1, p2) - bitcasted_p0 = bf16[10,6,2,11] bitcast(p0) - bitcasted_select = bf16[1,6,2,11] bitcast(select) - dus0 = bf16[10,6,2,11] dynamic-update-slice(bitcasted_p0, bitcasted_select, c0, c0, c0, c0) - ROOT bitcasted_dus0 = bf16[10,11,6,2] bitcast(dus0) -} - -ENTRY main { - p0 = bf16[10,11,12] parameter(0) - p1 = bf16[1,11,12] parameter(1) - p2 = bf16[1,11,12] parameter(2) - p3 = s32[] parameter(3) - ROOT fusion_root_bitcast_both_ways = bf16[10,11,6,2] fusion(p0, p1, p2, p3), kind=kLoop, calls=single_inplace_dus_with_bitcast_to_the_root_and_from_the_parameter -} diff --git a/third_party/xla/xla/service/gpu/tests/fused_scatter.hlo b/third_party/xla/xla/service/gpu/tests/fused_scatter.hlo deleted file mode 100644 index b63fed6cb439eb..00000000000000 --- a/third_party/xla/xla/service/gpu/tests/fused_scatter.hlo +++ /dev/null @@ -1,85 +0,0 @@ -// RUN: hlo-opt %s --xla_gpu_mlir_emitter_level=0 --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %s - -// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py - -// CHECK: %[[VAL_0:.*]] = alloca i32, align 4 -// CHECK-PTX: %[[VAL_1:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x -// CHECK-GCN: %[[VAL_1:.*]] = call i32 @llvm.amdgcn.workgroup.id.x -// CHECK-PTX: %[[VAL_2:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x -// CHECK-GCN: %[[VAL_2:.*]] = call i32 @llvm.amdgcn.workitem.id.x -// CHECK: %[[VAL_3:.*]] = mul nuw nsw i32 %[[VAL_1]], 6 -// CHECK: %[[VAL_4:.*]] = add nuw nsw i32 %[[VAL_3]], %[[VAL_2]] -// CHECK: %[[VAL_5:.*]] = icmp ult i32 %[[VAL_4]], 6 -// CHECK: call void @llvm.assume(i1 
%[[VAL_5]]) -// CHECK: %[[VAL_6:.*]] = add nuw nsw i32 %[[VAL_4]], 0 -// CHECK: %[[VAL_7:.*]] = udiv i32 %[[VAL_6]], 1 -// CHECK: %[[VAL_8:.*]] = urem i32 %[[VAL_7]], 3 -// CHECK: %[[VAL_9:.*]] = udiv i32 %[[VAL_6]], 3 -// CHECK: %[[VAL_10:.*]] = urem i32 %[[VAL_9]], 1 -// CHECK: %[[VAL_11:.*]] = udiv i32 %[[VAL_6]], 3 -// CHECK: %[[VAL_12:.*]] = icmp ult i32 %[[VAL_4]], 6 -// CHECK: br i1 %[[VAL_12]], label %[[VAL_13:.*]], label %[[VAL_14:.*]] -// CHECK: scatter.in_bounds-after: ; preds = %[[VAL_15:.*]], %[[VAL_16:.*]] -// CHECK: ret void -// CHECK: scatter.in_bounds-true: ; preds = %[[VAL_16]] -// CHECK: %[[VAL_17:.*]] = getelementptr inbounds [2 x [1 x i32]], ptr %[[VAL_18:.*]], i32 0, i32 %[[VAL_11]], i32 0 -// CHECK: %[[VAL_19:.*]] = load i32, ptr %[[VAL_17]], align 4, !invariant.load !10 -// CHECK: %[[VAL_20:.*]] = add i32 %[[VAL_10]], %[[VAL_19]] -// CHECK: %[[VAL_21:.*]] = icmp ult i32 %[[VAL_19]], 3 -// CHECK: %[[VAL_22:.*]] = and i1 true, %[[VAL_21]] -// CHECK: br i1 %[[VAL_22]], label %[[VAL_23:.*]], label %[[VAL_15]] -// CHECK: scatter.in_bounds-after3: ; preds = %[[VAL_23]], %[[VAL_13]] -// CHECK: br label %[[VAL_14]] -// CHECK: scatter.in_bounds-true2: ; preds = %[[VAL_13]] -// CHECK: %[[VAL_24:.*]] = getelementptr inbounds [3 x [3 x i32]], ptr %[[VAL_25:.*]], i32 0, i32 %[[VAL_20]], i32 %[[VAL_8]] -// CHECK: %[[VAL_26:.*]] = getelementptr i32, ptr %[[VAL_27:.*]], i32 %[[VAL_4]] -// CHECK: %[[VAL_28:.*]] = getelementptr inbounds i32, ptr %[[VAL_26]], i32 0 -// CHECK: %[[VAL_29:.*]] = load i32, ptr %[[VAL_28]], align 4, !invariant.load !10 -// CHECK: store i32 %[[VAL_29]], ptr %[[VAL_0]], align 4 -// CHECK: %[[VAL_30:.*]] = load i32, ptr %[[VAL_0]], align 4 -// CHECK: store atomic i32 %[[VAL_30]], ptr %[[VAL_24]] unordered, align 4 -// CHECK: br label %[[VAL_15]] - -HloModule TensorFlowScatterV1, is_scheduled=true - -update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { - lhs = s32[] parameter(0) - ROOT rhs = s32[] parameter(1) -} - -fused_computation { - 
param_0 = s32[3,3]{1,0} parameter(0) - ROOT operand.1 = s32[3,3]{1,0} add(param_0, param_0) -} - -fused_computation.1 { - param_0.1 = s32[2,1]{1,0} parameter(0) - ROOT indices.1 = s32[2,1]{1,0} add(param_0.1, param_0.1) -} - -fused_computation.2 { - param_0.2 = s32[2,1,3]{2,1,0} parameter(0) - ROOT updates.1 = s32[2,1,3]{2,1,0} add(param_0.2, param_0.2) -} - -fused_computation.3 { - operand = s32[3,3]{1,0} parameter(0) - indices = s32[2,1]{1,0} parameter(1) - updates = s32[2,1,3]{2,1,0} parameter(2) - ROOT scatter = s32[3,3] scatter(operand, indices, updates), - to_apply=update_s32, - update_window_dims={1,2}, - inserted_window_dims={}, - scatter_dims_to_operand_dims={0}, - index_vector_dim=1 -} - -ENTRY main { - p1 = s32[2,1] parameter(1) - wrapped_indices = s32[2,1]{1,0} fusion(p1), kind=kLoop, calls=fused_computation.1 - p2 = s32[2,1,3] parameter(2) - wrapped_updates = s32[2,1,3]{2,1,0} fusion(p2), kind=kLoop, calls=fused_computation.2 - p0 = s32[3,3] parameter(0) - wrapped_operand = s32[3,3]{1,0} fusion(p0), kind=kLoop, calls=fused_computation - ROOT wrapped_scatter = s32[3,3] fusion(wrapped_operand, wrapped_indices, wrapped_updates), kind=kInput, calls=fused_computation.3 -} diff --git a/third_party/xla/xla/service/gpu/tests/fused_slice.hlo b/third_party/xla/xla/service/gpu/tests/fused_slice.hlo deleted file mode 100644 index e71cb4880385e7..00000000000000 --- a/third_party/xla/xla/service/gpu/tests/fused_slice.hlo +++ /dev/null @@ -1,106 +0,0 @@ -// RUN: hlo-opt %s --xla_gpu_mlir_emitter_level=0 --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %s - -// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py - -// CHECK-LABEL: entry: -// CHECK-PTX: %[[VAL_0:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x -// CHECK-GCN: %[[VAL_0:.*]] = call i32 @llvm.amdgcn.workgroup.id.x -// CHECK-PTX: %[[VAL_1:.*]] = call i32 
@llvm.nvvm.read.ptx.sreg.tid.x -// CHECK-GCN: %[[VAL_1:.*]] = call i32 @llvm.amdgcn.workitem.id.x -// CHECK-PTX: %[[VAL_2:.*]] = mul nuw nsw i32 %[[VAL_0]], 128 -// CHECK-GCN: %[[VAL_2:.*]] = mul nuw nsw i32 %[[VAL_0]], 256 -// CHECK: %[[VAL_3:.*]] = add nuw nsw i32 %[[VAL_2]], %[[VAL_1]] -// CHECK: %[[VAL_4:.*]] = icmp ult i32 %[[VAL_3]], 2048 -// CHECK: call void @llvm.assume(i1 %[[VAL_4]]) -// CHECK: %[[VAL_5:.*]] = add nuw nsw i32 %[[VAL_3]], 0 -// CHECK: %[[VAL_6:.*]] = udiv i32 %[[VAL_5]], 1 -// CHECK: %[[VAL_7:.*]] = icmp ult i32 %[[VAL_3]], 2047 -// CHECK: br i1 %[[VAL_7]], label %[[VAL_8:.*]], label %[[VAL_9:.*]] -// CHECK: fusion.in_bounds-after: ; preds = %[[VAL_10:.*]], %[[VAL_11:.*]] -// CHECK: ret void -// CHECK: fusion.in_bounds-true: ; preds = %[[VAL_11]] -// CHECK: br label %[[VAL_12:.*]] -// CHECK: concat_index_from_operand_id0: ; preds = %[[VAL_13:.*]] -// CHECK: %[[VAL_14:.*]] = phi i32 [ 0, %[[VAL_13]] ] -// CHECK: %[[VAL_15:.*]] = sub nsw i32 %[[VAL_6]], %[[VAL_14]] -// CHECK: %[[VAL_16:.*]] = getelementptr inbounds [1024 x half], ptr %[[VAL_17:.*]], i32 0, i32 %[[VAL_15]] -// CHECK: %[[VAL_18:.*]] = load half, ptr %[[VAL_16]], align 2, !invariant.load -// CHECK: %[[VAL_19:.*]] = getelementptr inbounds [1024 x half], ptr %[[VAL_20:.*]], i32 0, i32 %[[VAL_15]] -// CHECK: %[[VAL_21:.*]] = load half, ptr %[[VAL_19]], align 2, !invariant.load -// CHECK: %[[VAL_22:.*]] = fmul half %[[VAL_18]], %[[VAL_21]] -// CHECK: br label %[[VAL_23:.*]] -// CHECK: concat_index_from_operand_id1: ; preds = %[[VAL_24:.*]] -// CHECK: %[[VAL_25:.*]] = phi i32 [ 1024, %[[VAL_24]] ] -// CHECK: %[[VAL_26:.*]] = sub nsw i32 %[[VAL_6]], %[[VAL_25]] -// CHECK: %[[VAL_27:.*]] = getelementptr inbounds [1023 x half], ptr %[[VAL_28:.*]], i32 0, i32 %[[VAL_26]] -// CHECK: %[[VAL_29:.*]] = load half, ptr %[[VAL_27]], align 2, !invariant.load -// CHECK: %[[VAL_30:.*]] = getelementptr inbounds [1023 x half], ptr %[[VAL_31:.*]], i32 0, i32 %[[VAL_26]] -// CHECK: %[[VAL_32:.*]] = 
load half, ptr %[[VAL_30]], align 2, !invariant.load -// CHECK: %[[VAL_33:.*]] = fadd half %[[VAL_29]], %[[VAL_32]] -// CHECK: br label %[[VAL_23]] -// CHECK: concatenate.pivot.1024.: ; preds = %[[VAL_8]] -// CHECK: %[[VAL_34:.*]] = icmp ult i32 %[[VAL_6]], 1024 -// CHECK: br i1 %[[VAL_34]], label %[[VAL_13]], label %[[VAL_24]] -// CHECK: concatenate.pivot.0.: ; preds = %[[VAL_12]] -// CHECK: br label %[[VAL_35:.*]] -// CHECK: concatenate.pivot.1024.1: ; preds = %[[VAL_12]] -// CHECK: br label %[[VAL_36:.*]] -// CHECK: concat.1.merge: ; preds = %[[VAL_36]], %[[VAL_35]] -// CHECK: %[[VAL_37:.*]] = phi half [ %[[VAL_22]], %[[VAL_35]] ], [ %[[VAL_33]], %[[VAL_36]] ] -// CHECK: %[[VAL_38:.*]] = icmp sge i32 %[[VAL_6]], 0 -// CHECK: %[[VAL_39:.*]] = icmp slt i32 %[[VAL_6]], 1024 -// CHECK: %[[VAL_40:.*]] = and i1 %[[VAL_38]], %[[VAL_39]] -// CHECK: br i1 %[[VAL_40]], label %[[VAL_41:.*]], label %[[VAL_42:.*]] -// CHECK: slice0-after: ; preds = %[[VAL_41]], %[[VAL_23]] -// CHECK: %[[VAL_43:.*]] = icmp sge i32 %[[VAL_6]], 1024 -// CHECK: %[[VAL_44:.*]] = icmp slt i32 %[[VAL_6]], 2047 -// CHECK: %[[VAL_45:.*]] = and i1 %[[VAL_43]], %[[VAL_44]] -// CHECK: br i1 %[[VAL_45]], label %[[VAL_46:.*]], label %[[VAL_47:.*]] -// CHECK: slice1-after: ; preds = %[[VAL_46]], %[[VAL_42]] -// CHECK: %[[VAL_48:.*]] = icmp sge i32 %[[VAL_6]], 2047 -// CHECK: %[[VAL_49:.*]] = icmp slt i32 %[[VAL_6]], 2047 -// CHECK: %[[VAL_50:.*]] = and i1 %[[VAL_48]], %[[VAL_49]] -// CHECK: br i1 %[[VAL_50]], label %[[VAL_51:.*]], label %[[VAL_10]] -// CHECK: slice2-after: ; preds = %[[VAL_51]], %[[VAL_47]] -// CHECK: br label %[[VAL_9]] -// CHECK: slice0-true: ; preds = %[[VAL_23]] -// CHECK: %[[VAL_52:.*]] = sub i32 %[[VAL_6]], 0 -// CHECK: %[[VAL_53:.*]] = getelementptr inbounds [1024 x half], ptr %[[VAL_54:.*]], i32 0, i32 %[[VAL_52]] -// CHECK: store half %[[VAL_37]], ptr %[[VAL_53]], align 2 -// CHECK: br label %[[VAL_42]] -// CHECK: slice1-true: ; preds = %[[VAL_42]] -// CHECK: %[[VAL_55:.*]] = sub 
i32 %[[VAL_6]], 1024 -// CHECK: %[[VAL_56:.*]] = getelementptr inbounds [1023 x half], ptr %[[VAL_57:.*]], i32 0, i32 %[[VAL_55]] -// CHECK: store half %[[VAL_37]], ptr %[[VAL_56]], align 2 -// CHECK: br label %[[VAL_47]] -// CHECK: slice2-true: ; preds = %[[VAL_47]] -// CHECK: %[[VAL_58:.*]] = sub i32 %[[VAL_6]], 2047 -// CHECK: %[[VAL_59:.*]] = getelementptr inbounds [0 x half], ptr %[[VAL_60:.*]], i32 0, i32 %[[VAL_58]] -// CHECK: store half %[[VAL_37]], ptr %[[VAL_59]], align 2 -// CHECK: br label %[[VAL_10]] - -HloModule input_fusion_with_a_tuple_of_slices, is_scheduled=true - -fused_computation { - arg.1 = f16[1024]{0} parameter(0) - arg.2 = f16[1024]{0} parameter(1) - arg.3 = f16[1023]{0} parameter(2) - arg.4 = f16[1023]{0} parameter(3) - mul.1 = f16[1024]{0} multiply(arg.1, arg.2) - add.1 = f16[1023]{0} add(arg.3, arg.4) - concat.1 = f16[2047]{0} concatenate(mul.1, add.1), dimensions={0} - slice.1 = f16[1024]{0} slice(concat.1), slice={[0:1024]} - slice.2 = f16[1023]{0} slice(concat.1), slice={[1024:2047]} - slice.3 = f16[0]{0} slice(concat.1), slice={[2047:2047]} - ROOT tuple.1 = (f16[1024]{0}, f16[1023]{0}, f16[0]{0}) - tuple(slice.1, slice.2, slice.3) -} - -ENTRY kernel_entry { - arg.1 = f16[1024]{0} parameter(0) - arg.2 = f16[1024]{0} parameter(1) - arg.3 = f16[1023]{0} parameter(2) - arg.4 = f16[1023]{0} parameter(3) - ROOT fusion = (f16[1024]{0}, f16[1023]{0}, f16[0]{0}) - fusion(arg.1, arg.2, arg.3, arg.4), kind=kInput, calls=fused_computation -} diff --git a/third_party/xla/xla/service/gpu/tests/gpu_int4_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_int4_test.cc index e0f399dc017e61..e60e5040f03030 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_int4_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_int4_test.cc @@ -99,19 +99,20 @@ TEST_F(GpuInt4Test, TestOddElements) { // A conditional branch should check if the index is in bounds within the // unrolled loop - absl::string_view expected_ir; - if 
(GetDebugOptionsForTest().xla_gpu_mlir_emitter_level() == 0) { - expected_ir = R"( - ; CHECK: {{.*}}.in_bounds-true: - ; CHECK-NEXT: %[[in_bounds:.*]] = icmp ult i32 %linear_index0, 5 - ; CHECK-NEXT: br i1 %{{.*}}, label %[[in_bounds_true:.*unrolled_in_bounds-true]], label %[[in_bounds_after:.*unrolled_in_bounds-after]] - ; - ; CHECK: [[in_bounds_true]]: - ; CHECK: %{{.*}} = load i8, ptr %{{.*}}, align 1 - ; CHECK: store i8 %{{.*}}, ptr %{{.*}}, align 1 - ; CHECK: br label %[[in_bounds_after]])"; - } else { - expected_ir = R"( + // absl::string_view expected_ir; + // if (GetDebugOptionsForTest().xla_gpu_mlir_emitter_level() == 0) { + // expected_ir = R"( + // ; CHECK: {{.*}}.in_bounds-true: + // ; CHECK-NEXT: %[[in_bounds:.*]] = icmp ult i32 %linear_index0, 5 + // ; CHECK-NEXT: br i1 %{{.*}}, label %[[in_bounds_true:.*unrolled_in_bounds-true]], label %[[in_bounds_after:.*unrolled_in_bounds-after]] + // ; + // ; CHECK: [[in_bounds_true]]: + // ; CHECK: %{{.*}} = load i8, ptr %{{.*}}, align 1 + // ; CHECK: store i8 %{{.*}}, ptr %{{.*}}, align 1 + // ; CHECK: br label %[[in_bounds_after]])"; + // } else { + // expected_ir = R"( +absl::string_view expected_ir = R"( ; CHECK: %[[in_bounds:.*]] = icmp sle i32 %{{.*}}, 1 ; CHECK-NEXT: br i1 %[[in_bounds]], label %[[in_bounds_true:.*]], label %[[in_bounds_after:.*]] ; CHECK: [[in_bounds_true]]: @@ -120,7 +121,7 @@ TEST_F(GpuInt4Test, TestOddElements) { ; CHECK: br label %[[in_bounds_after]] ; CHECK: [[in_bounds_after]]: ; CHECK-NEXT: ret void)"; - } + //} CompileAndVerifyIr(std::move(hlo_module), MakePlatformSpecificLlvm(expected_ir), /*match_optimized_ir=*/false); diff --git a/third_party/xla/xla/service/gpu/tests/gpu_kernel_tiling_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_kernel_tiling_test.cc index 132f1524f27b0f..5ca4c0c774dc10 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_kernel_tiling_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_kernel_tiling_test.cc @@ -36,7 +36,7 @@ class 
GpuKernelTilingTest : public GpuCodegenTest { HloModuleConfig ConfigWithLayoutAssignment() { HloModuleConfig config; auto debug_options = HloTestBase::GetDebugOptionsForTest(); - debug_options.set_xla_gpu_mlir_emitter_level(3); + // debug_options.set_xla_gpu_mlir_emitter_level(3); config.set_debug_options(debug_options); return config; } @@ -46,7 +46,7 @@ class GpuKernelTilingTest : public GpuCodegenTest { auto debug_options = HloTestBase::GetDebugOptionsForTest(); // Disable layout_assignment to use the preassigned layouts. debug_options.add_xla_disable_hlo_passes("layout-assignment"); - debug_options.set_xla_gpu_mlir_emitter_level(3); + // debug_options.set_xla_gpu_mlir_emitter_level(3); config.set_debug_options(debug_options); return config; } @@ -358,251 +358,251 @@ TEST_F(GpuKernelTilingTest, ColumnReductionWithLayoutChangeTiled) { EXPECT_TRUE(RunAndCompare(kHloString, ErrorSpec{0.001})); } -TEST_F(GpuKernelTilingTest, RowReductionWithLayoutChangeTiled) { - const char *const kHloString = R"( - HloModule reduce_with_layout_change - reduction0 { - x0 = f32[] parameter(0) - y0 = f32[] parameter(1) - ROOT add0 = f32[] add(x0, y0) - } - - ENTRY kernel_entry { - arg0 = f32[8,6,64]{2,1,0} parameter(0) - constant0 = f32[] constant(0) - ROOT reduce0 = f32[8,6]{0,1} reduce(arg0, constant0), dimensions={2}, - to_apply=reduction0 - })"; - - // Check that the kernel is tiled by looking for llvm.nvvm.shfl.sync.down. - auto hlo_module = - ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) - .value(); - auto expected_ir = R"( -; CHECK-LABEL: define KERNEL_ANNOTATION @{{(wrapped_reduce|.*fusion)}} -; CHECK: call SHUFFLE -; CHECK: } -)"; - CompileAndVerifyIr(std::move(hlo_module), - MakePlatformSpecificLlvm(expected_ir), - /*match_optimized_ir=*/true); - - // Check that the kernel runs correctly. 
- EXPECT_TRUE(RunAndCompare(kHloString, ErrorSpec{0.001})); -} - -TEST_F(GpuKernelTilingTest, RowReductionTwoRowsPerWarp) { - const char *const kHloString = R"( - HloModule reduce_with_layout_change - reduction0 { - x0 = f32[] parameter(0) - y0 = f32[] parameter(1) - ROOT add0 = f32[] add(x0, y0) - } - - ENTRY kernel_entry { - arg0 = f32[10000,16]{1,0} parameter(0) - constant0 = f32[] constant(0) - ROOT reduce0 = f32[10000]{0} reduce(arg0, constant0), dimensions={1}, - to_apply=reduction0 - })"; - - // Check that the kernel is tiled by looking for llvm.nvvm.shfl.sync.down and - // a write condition based on the logical thread ID (two writes per warp). - auto hlo_module = - ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) - .value(); - auto expected_ir = R"( -; CHECK-LABEL: define KERNEL_ANNOTATION @{{(wrapped_reduce|.*fusion)}} -; CHECK: %[[TID_X:.*]] = tail call i32 TIDX() -; CHECK: %[[TID_LOGICAL:.*]] = and i32 %[[TID_X]], 15 -; CHECK: call SHUFFLE -; CHECK: %[[LOGICAL_T0:.*]] = icmp eq i32 %[[TID_LOGICAL]], 0 -; CHECK: LCAL -; CHECK: EXTV -; CHECK: BR_CAL -)"; - CompileAndVerifyIr(std::move(hlo_module), - MakePlatformSpecificLlvm(expected_ir), - /*match_optimized_ir=*/true); - - // Check that the kernel runs correctly. - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001})); -} - -TEST_F(GpuKernelTilingTest, RowReductionFourRowsPerWarp) { - const char *const kHloString = R"( - HloModule reduce_with_layout_change - reduction0 { - x0 = f32[] parameter(0) - y0 = f32[] parameter(1) - ROOT add0 = f32[] add(x0, y0) - } - - ENTRY kernel_entry { - arg0 = f32[10000,8]{1,0} parameter(0) - constant0 = f32[] constant(0) - ROOT reduce0 = f32[10000]{0} reduce(arg0, constant0), dimensions={1}, - to_apply=reduction0 - })"; - - // Check that the kernel is tiled by looking for llvm.nvvm.shfl.sync.down and - // a write condition based on the logical thread ID (four writes per warp). 
- auto hlo_module = - ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) - .value(); - auto expected_ir = R"( -; CHECK-LABEL: define KERNEL_ANNOTATION @{{(wrapped_reduce|.*fusion)}} -; CHECK: %[[TID_X:.*]] = tail call i32 TIDX() -; CHECK: %[[TID_LOGICAL:.*]] = and i32 %[[TID_X]], 7 -; CHECK: call SHUFFLE -; CHECK: %[[LOGICAL_T0:.*]] = icmp eq i32 %[[TID_LOGICAL]], 0 -)"; - - CompileAndVerifyIr(std::move(hlo_module), - MakePlatformSpecificLlvm(expected_ir), - /*match_optimized_ir=*/true); - - // Check that the kernel runs correctly. - EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001})); -} - -TEST_F(GpuKernelTilingTest, - ColumnReductionResultTwoPartsWithLayoutChangeTiled) { - const char *const kHloString = R"( - HloModule reduce_with_no_layout_change - reduction0 { - x0 = f32[] parameter(0) - y0 = f32[] parameter(1) - ROOT add0 = f32[] add(x0, y0) - } - - ENTRY kernel_entry { - arg0 = f32[8,64,32]{2,1,0} parameter(0) - constant0 = f32[] constant(0) - ROOT reduce0 = f32[8,32]{0,1} reduce(arg0, constant0), dimensions={1}, - to_apply=reduction0 - })"; - - // Check that the kernel is tiled by looking for llvm.nvvm.atomic. - auto hlo_module = - ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) - .value(); - const char *expected_ir = R"( -; CHECK-LABEL: define KERNEL_ANNOTATION @{{(wrapped_reduce|.*fusion)}} -; CHECK: store float %{{.*}}, ptr addrspace(1) -; CHECK: } -)"; - CompileAndVerifyIr(std::move(hlo_module), - MakePlatformSpecificLlvm(expected_ir), - /*match_optimized_ir=*/true); - - // Check that the kernel runs correctly. 
- EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001})); -} - -TEST_F(GpuKernelTilingTest, ColumnReductionSmallTileSizeX) { - const char *const kHloString = R"( - HloModule Test - - scalar_add_computation.1 { - scalar_lhs.1 = f32[] parameter(0) - scalar_rhs.1 = f32[] parameter(1) - ROOT add.6 = f32[] add(scalar_lhs.1, scalar_rhs.1) - } - ENTRY Test { - param_3.241 = f16[512,2,9,9]{1,3,2,0} parameter(3) - constant_661 = f16[] constant(0) - broadcast.695 = f16[512,2,9,9]{1,3,2,0} broadcast(constant_661), dimensions={} - compare.42 = pred[512,2,9,9]{1,3,2,0} compare(param_3.241, broadcast.695), direction=GT - param_2.401 = f16[512,2,9,9]{1,3,2,0} parameter(2) - select.40 = f16[512,2,9,9]{1,3,2,0} select(compare.42, param_2.401, broadcast.695) - convert.196 = f32[512,2,9,9]{1,3,2,0} convert(select.40) - param_1.809 = f16[512,2,9,9]{1,3,2,0} parameter(1) - copy.335 = f16[512,2,9,9]{1,3,2,0} copy(param_1.809) - convert.218 = f32[512,2,9,9]{1,3,2,0} convert(copy.335) - param_0.668 = f32[2]{0} parameter(0) - broadcast.687 = f32[512,2,9,9]{1,3,2,0} broadcast(param_0.668), dimensions={1} - subtract.136 = f32[512,2,9,9]{1,3,2,0} subtract(convert.218, broadcast.687) - multiply.579 = f32[512,2,9,9]{1,3,2,0} multiply(convert.196, subtract.136) - constant_485 = f32[] constant(0) - reduce.139 = f32[2]{0} reduce(multiply.579, constant_485), dimensions={0,2,3}, to_apply=scalar_add_computation.1 - reduce.140.clone.1 = f32[2]{0} reduce(convert.196, constant_485), dimensions={0,2,3}, to_apply=scalar_add_computation.1 - ROOT tuple.102 = (f32[2]{0}, f32[2]{0}) tuple(reduce.139, reduce.140.clone.1) - })"; - - // Check that no loop is generated for reduction. 
- auto hlo_module = - ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) - .value(); - const char *expected_ir = R"( -; CHECK-NOT: reduce.0.loop_header -; CHECK: } -)"; - CompileAndVerifyIr(std::move(hlo_module), expected_ir, - /*match_optimized_ir=*/true); - // Check that the kernel runs correctly. - EXPECT_TRUE(RunAndCompare(kHloString, ErrorSpec{1.0e-5, 1.0e-5})); -} - -TEST_F(GpuKernelTilingTest, - RowReductionWithSmallNonPowerOfTwoDimensionNotTiled) { - const char *const kHloString = R"( - HloModule reduction - reduction0 { - x0 = f32[] parameter(0) - y0 = f32[] parameter(1) - ROOT add0 = f32[] add(x0, y0) - } - - ENTRY kernel_entry { - arg0 = f32[8,6,15]{2,1,0} parameter(0) - constant0 = f32[] constant(0) - ROOT reduce0 = f32[8,6]{1,0} reduce(arg0, constant0), dimensions={2}, - to_apply=reduction0 - })"; - - // Check that the kernel is not tiled by looking for llvm.nvvm.shfl.sync.down. - auto hlo_module = - ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) - .value(); - auto expected_ir = R"( -; CHECK-LABEL: define KERNEL_ANNOTATION @{{(wrapped_reduce|.*fusion)}} -; CHECK-NOT: call SHUFFLE -; CHECK: } -)"; - CompileAndVerifyIr(std::move(hlo_module), - MakePlatformSpecificLlvm(expected_ir), - /*match_optimized_ir=*/true); - - // Check that the kernel runs correctly. 
- EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001})); -} - -TEST_F(GpuKernelTilingTest, RowReductionRequiring64BitIndex) { - const char *const kHloString = R"( - HloModule LargeReduction - - Sum { - x.1 = f32[] parameter(0) - y.1 = f32[] parameter(1) - ROOT add.1 = f32[] add(x.1, y.1) - } - - ENTRY reduce.1 { - parameter = f32[3048576000] parameter(0) - init_value = f32[] constant(0) - ROOT out = f32[] reduce(parameter, init_value), dimensions={0}, to_apply=Sum - } - )"; - std::unique_ptr hlo_module = - ParseAndReturnVerifiedModule(kHloString).value(); - const char *expected_ir = R"( -; CHECK: i64 - )"; - CompileAndVerifyIr(std::move(hlo_module), expected_ir, - /*match_optimized_ir=*/true); -} +// TEST_F(GpuKernelTilingTest, RowReductionWithLayoutChangeTiled) { +// const char *const kHloString = R"( +// HloModule reduce_with_layout_change +// reduction0 { +// x0 = f32[] parameter(0) +// y0 = f32[] parameter(1) +// ROOT add0 = f32[] add(x0, y0) +// } + +// ENTRY kernel_entry { +// arg0 = f32[8,6,64]{2,1,0} parameter(0) +// constant0 = f32[] constant(0) +// ROOT reduce0 = f32[8,6]{0,1} reduce(arg0, constant0), dimensions={2}, +// to_apply=reduction0 +// })"; + +// // Check that the kernel is tiled by looking for llvm.nvvm.shfl.sync.down. +// auto hlo_module = +// ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) +// .value(); +// auto expected_ir = R"( +// ; CHECK-LABEL: define KERNEL_ANNOTATION @{{(wrapped_reduce|.*fusion)}} +// ; CHECK: call SHUFFLE +// ; CHECK: } +// )"; +// CompileAndVerifyIr(std::move(hlo_module), +// MakePlatformSpecificLlvm(expected_ir), +// /*match_optimized_ir=*/true); + +// // Check that the kernel runs correctly. 
+// EXPECT_TRUE(RunAndCompare(kHloString, ErrorSpec{0.001})); +// } + +// TEST_F(GpuKernelTilingTest, RowReductionTwoRowsPerWarp) { +// const char *const kHloString = R"( +// HloModule reduce_with_layout_change +// reduction0 { +// x0 = f32[] parameter(0) +// y0 = f32[] parameter(1) +// ROOT add0 = f32[] add(x0, y0) +// } + +// ENTRY kernel_entry { +// arg0 = f32[10000,16]{1,0} parameter(0) +// constant0 = f32[] constant(0) +// ROOT reduce0 = f32[10000]{0} reduce(arg0, constant0), dimensions={1}, +// to_apply=reduction0 +// })"; + +// // Check that the kernel is tiled by looking for llvm.nvvm.shfl.sync.down and +// // a write condition based on the logical thread ID (two writes per warp). +// auto hlo_module = +// ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) +// .value(); +// auto expected_ir = R"( +// ; CHECK-LABEL: define KERNEL_ANNOTATION @{{(wrapped_reduce|.*fusion)}} +// ; CHECK: %[[TID_X:.*]] = tail call i32 TIDX() +// ; CHECK: %[[TID_LOGICAL:.*]] = and i32 %[[TID_X]], 15 +// ; CHECK: call SHUFFLE +// ; CHECK: %[[LOGICAL_T0:.*]] = icmp eq i32 %[[TID_LOGICAL]], 0 +// ; CHECK: LCAL +// ; CHECK: EXTV +// ; CHECK: BR_CAL +// )"; +// CompileAndVerifyIr(std::move(hlo_module), +// MakePlatformSpecificLlvm(expected_ir), +// /*match_optimized_ir=*/true); + +// // Check that the kernel runs correctly. 
+// EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001})); +// } + +// TEST_F(GpuKernelTilingTest, RowReductionFourRowsPerWarp) { +// const char *const kHloString = R"( +// HloModule reduce_with_layout_change +// reduction0 { +// x0 = f32[] parameter(0) +// y0 = f32[] parameter(1) +// ROOT add0 = f32[] add(x0, y0) +// } + +// ENTRY kernel_entry { +// arg0 = f32[10000,8]{1,0} parameter(0) +// constant0 = f32[] constant(0) +// ROOT reduce0 = f32[10000]{0} reduce(arg0, constant0), dimensions={1}, +// to_apply=reduction0 +// })"; + +// // Check that the kernel is tiled by looking for llvm.nvvm.shfl.sync.down and +// // a write condition based on the logical thread ID (four writes per warp). +// auto hlo_module = +// ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) +// .value(); +// auto expected_ir = R"( +// ; CHECK-LABEL: define KERNEL_ANNOTATION @{{(wrapped_reduce|.*fusion)}} +// ; CHECK: %[[TID_X:.*]] = tail call i32 TIDX() +// ; CHECK: %[[TID_LOGICAL:.*]] = and i32 %[[TID_X]], 7 +// ; CHECK: call SHUFFLE +// ; CHECK: %[[LOGICAL_T0:.*]] = icmp eq i32 %[[TID_LOGICAL]], 0 +// )"; + +// CompileAndVerifyIr(std::move(hlo_module), +// MakePlatformSpecificLlvm(expected_ir), +// /*match_optimized_ir=*/true); + +// // Check that the kernel runs correctly. +// EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001})); +// } + +// TEST_F(GpuKernelTilingTest, +// ColumnReductionResultTwoPartsWithLayoutChangeTiled) { +// const char *const kHloString = R"( +// HloModule reduce_with_no_layout_change +// reduction0 { +// x0 = f32[] parameter(0) +// y0 = f32[] parameter(1) +// ROOT add0 = f32[] add(x0, y0) +// } + +// ENTRY kernel_entry { +// arg0 = f32[8,64,32]{2,1,0} parameter(0) +// constant0 = f32[] constant(0) +// ROOT reduce0 = f32[8,32]{0,1} reduce(arg0, constant0), dimensions={1}, +// to_apply=reduction0 +// })"; + +// // Check that the kernel is tiled by looking for llvm.nvvm.atomic. 
+// auto hlo_module = +// ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) +// .value(); +// const char *expected_ir = R"( +// ; CHECK-LABEL: define KERNEL_ANNOTATION @{{(wrapped_reduce|.*fusion)}} +// ; CHECK: store float %{{.*}}, ptr addrspace(1) +// ; CHECK: } +// )"; +// CompileAndVerifyIr(std::move(hlo_module), +// MakePlatformSpecificLlvm(expected_ir), +// /*match_optimized_ir=*/true); + +// // Check that the kernel runs correctly. +// EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001})); +// } + +// TEST_F(GpuKernelTilingTest, ColumnReductionSmallTileSizeX) { +// const char *const kHloString = R"( +// HloModule Test + +// scalar_add_computation.1 { +// scalar_lhs.1 = f32[] parameter(0) +// scalar_rhs.1 = f32[] parameter(1) +// ROOT add.6 = f32[] add(scalar_lhs.1, scalar_rhs.1) +// } +// ENTRY Test { +// param_3.241 = f16[512,2,9,9]{1,3,2,0} parameter(3) +// constant_661 = f16[] constant(0) +// broadcast.695 = f16[512,2,9,9]{1,3,2,0} broadcast(constant_661), dimensions={} +// compare.42 = pred[512,2,9,9]{1,3,2,0} compare(param_3.241, broadcast.695), direction=GT +// param_2.401 = f16[512,2,9,9]{1,3,2,0} parameter(2) +// select.40 = f16[512,2,9,9]{1,3,2,0} select(compare.42, param_2.401, broadcast.695) +// convert.196 = f32[512,2,9,9]{1,3,2,0} convert(select.40) +// param_1.809 = f16[512,2,9,9]{1,3,2,0} parameter(1) +// copy.335 = f16[512,2,9,9]{1,3,2,0} copy(param_1.809) +// convert.218 = f32[512,2,9,9]{1,3,2,0} convert(copy.335) +// param_0.668 = f32[2]{0} parameter(0) +// broadcast.687 = f32[512,2,9,9]{1,3,2,0} broadcast(param_0.668), dimensions={1} +// subtract.136 = f32[512,2,9,9]{1,3,2,0} subtract(convert.218, broadcast.687) +// multiply.579 = f32[512,2,9,9]{1,3,2,0} multiply(convert.196, subtract.136) +// constant_485 = f32[] constant(0) +// reduce.139 = f32[2]{0} reduce(multiply.579, constant_485), dimensions={0,2,3}, to_apply=scalar_add_computation.1 +// reduce.140.clone.1 = f32[2]{0} reduce(convert.196, 
constant_485), dimensions={0,2,3}, to_apply=scalar_add_computation.1 +// ROOT tuple.102 = (f32[2]{0}, f32[2]{0}) tuple(reduce.139, reduce.140.clone.1) +// })"; + +// // Check that no loop is generated for reduction. +// auto hlo_module = +// ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) +// .value(); +// const char *expected_ir = R"( +// ; CHECK-NOT: reduce.0.loop_header +// ; CHECK: } +// )"; +// CompileAndVerifyIr(std::move(hlo_module), expected_ir, +// /*match_optimized_ir=*/true); +// // Check that the kernel runs correctly. +// EXPECT_TRUE(RunAndCompare(kHloString, ErrorSpec{1.0e-5, 1.0e-5})); +// } + +// TEST_F(GpuKernelTilingTest, +// RowReductionWithSmallNonPowerOfTwoDimensionNotTiled) { +// const char *const kHloString = R"( +// HloModule reduction +// reduction0 { +// x0 = f32[] parameter(0) +// y0 = f32[] parameter(1) +// ROOT add0 = f32[] add(x0, y0) +// } + +// ENTRY kernel_entry { +// arg0 = f32[8,6,15]{2,1,0} parameter(0) +// constant0 = f32[] constant(0) +// ROOT reduce0 = f32[8,6]{1,0} reduce(arg0, constant0), dimensions={2}, +// to_apply=reduction0 +// })"; + +// // Check that the kernel is not tiled by looking for llvm.nvvm.shfl.sync.down. +// auto hlo_module = +// ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment()) +// .value(); +// auto expected_ir = R"( +// ; CHECK-LABEL: define KERNEL_ANNOTATION @{{(wrapped_reduce|.*fusion)}} +// ; CHECK-NOT: call SHUFFLE +// ; CHECK: } +// )"; +// CompileAndVerifyIr(std::move(hlo_module), +// MakePlatformSpecificLlvm(expected_ir), +// /*match_optimized_ir=*/true); + +// // Check that the kernel runs correctly. 
+// EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001})); +// } + +// TEST_F(GpuKernelTilingTest, RowReductionRequiring64BitIndex) { +// const char *const kHloString = R"( +// HloModule LargeReduction + +// Sum { +// x.1 = f32[] parameter(0) +// y.1 = f32[] parameter(1) +// ROOT add.1 = f32[] add(x.1, y.1) +// } + +// ENTRY reduce.1 { +// parameter = f32[3048576000] parameter(0) +// init_value = f32[] constant(0) +// ROOT out = f32[] reduce(parameter, init_value), dimensions={0}, to_apply=Sum +// } +// )"; +// std::unique_ptr hlo_module = +// ParseAndReturnVerifiedModule(kHloString).value(); +// const char *expected_ir = R"( +// ; CHECK: i64 +// )"; +// CompileAndVerifyIr(std::move(hlo_module), expected_ir, +// /*match_optimized_ir=*/true); +// } TEST_F(GpuKernelTilingTest, Hlo021CopyNoOobAccess) { const char *const kHloString = R"( @@ -629,35 +629,35 @@ ENTRY %primitive_computation_svd.38 (constant_5: f32[841,3], fusion.3: pred[3]) EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0.001})); } -TEST_F(GpuKernelTilingTest, RowReductionCorrectShmemUsage) { - const char *const kHloString = R"( - HloModule RowReduce - - Sum { - x.1 = f32[] parameter(0) - y.1 = f32[] parameter(1) - ROOT add.1 = f32[] add(x.1, y.1) - } - - ENTRY reduce.1 { - parameter = f32[1048576] parameter(0) - init_value = f32[] constant(0) - ROOT reduce = f32[] reduce(parameter, init_value), dimensions={0}, to_apply=Sum - } - )"; - auto hlo_module = ParseAndReturnVerifiedModule(kHloString).value(); - auto &debug_options = hlo_module->mutable_config().mutable_debug_options(); - debug_options.set_xla_gpu_mlir_emitter_level(3); - auto expected_ir = is_built_with_rocm_ ? 
R"( -; CHECK: %llvm.amdgcn.kernel.input_reduce_fusion.lds.t = type { [4 x [2 x float]] } -; CHECK: @llvm.amdgcn.kernel.input_reduce_fusion.lds = internal addrspace(3) global %llvm.amdgcn.kernel.input_reduce_fusion.lds.t poison - )" - : R"( -; CHECK: shared_cache = private unnamed_addr addrspace({{[0-9]*}}) global [4 x [2 x float]] - )"; - CompileAndVerifyIr(std::move(hlo_module), expected_ir, - /*match_optimized_ir=*/true); -} +// TEST_F(GpuKernelTilingTest, RowReductionCorrectShmemUsage) { +// const char *const kHloString = R"( +// HloModule RowReduce + +// Sum { +// x.1 = f32[] parameter(0) +// y.1 = f32[] parameter(1) +// ROOT add.1 = f32[] add(x.1, y.1) +// } + +// ENTRY reduce.1 { +// parameter = f32[1048576] parameter(0) +// init_value = f32[] constant(0) +// ROOT reduce = f32[] reduce(parameter, init_value), dimensions={0}, to_apply=Sum +// } +// )"; +// auto hlo_module = ParseAndReturnVerifiedModule(kHloString).value(); +// auto &debug_options = hlo_module->mutable_config().mutable_debug_options(); +// debug_options.set_xla_gpu_mlir_emitter_level(3); +// auto expected_ir = is_built_with_rocm_ ? 
R"( +// ; CHECK: %llvm.amdgcn.kernel.input_reduce_fusion.lds.t = type { [4 x [2 x float]] } +// ; CHECK: @llvm.amdgcn.kernel.input_reduce_fusion.lds = internal addrspace(3) global %llvm.amdgcn.kernel.input_reduce_fusion.lds.t poison +// )" +// : R"( +// ; CHECK: shared_cache = private unnamed_addr addrspace({{[0-9]*}}) global [4 x [2 x float]] +// )"; +// CompileAndVerifyIr(std::move(hlo_module), expected_ir, +// /*match_optimized_ir=*/true); +// } TEST_F(GpuKernelTilingTest, ReductionInputTooLarge) { const char *const kHloString = R"( @@ -679,8 +679,8 @@ TEST_F(GpuKernelTilingTest, ReductionInputTooLarge) { absl::Status status = CompileToExecutable(std::move(hlo_module)).status(); EXPECT_THAT(status.message(), ::testing::ContainsRegex( - "Kernel '.*' launch needs more blocks [(]4294967296[)] than " - "allowed by hardware [(]2147483647[)]")); + "Kernel '.*' launch needs more blocks [(]4294967296, 1[)] " + "than allowed by hardware [(]2147483647, 65535[)]")); } } // namespace diff --git a/third_party/xla/xla/service/gpu/tests/gpu_too_many_blocks_test.cc b/third_party/xla/xla/service/gpu/tests/gpu_too_many_blocks_test.cc index bf346ed724cf8b..0cae854c51e303 100644 --- a/third_party/xla/xla/service/gpu/tests/gpu_too_many_blocks_test.cc +++ b/third_party/xla/xla/service/gpu/tests/gpu_too_many_blocks_test.cc @@ -40,13 +40,11 @@ TEST_F(TooManyBlocksTest, FailsWithInvalidStatus) { HloModule primitive_computation_mul.8 ENTRY primitive_computation_mul.8 { - parameter.1 = f32[4,1048576,1,1]{3,2,1,0} parameter(0) - reshape.3 = f32[4,1048576,1]{2,1,0} reshape(parameter.1) - broadcast.4 = f32[4,1048576,1048576,1]{3,2,1,0} broadcast(reshape.3), dimensions={0,1,3} - parameter.2 = f32[4,1,1048576,1]{3,2,1,0} parameter(1) - reshape.5 = f32[4,1048576,1]{2,1,0} reshape(parameter.2) - broadcast.6 = f32[4,1048576,1048576,1]{3,2,1,0} broadcast(reshape.5), dimensions={0,2,3} - ROOT multiply.7 = f32[4,1048576,1048576,1]{3,2,1,0} multiply(broadcast.4, broadcast.6) + parameter.1 = 
f32[16,1048576] parameter(0) + parameter.2 = f32[16,1048576] parameter(1) + broadcast.3 = f32[16,1048576,1048576,65536] broadcast(parameter.1), dimensions={0,1} + broadcast.4 = f32[16,1048576,1048576,65536] broadcast(parameter.2), dimensions={0,2} + ROOT multiply.5 = f32[16,1048576,1048576,65536] multiply(broadcast.3, broadcast.4) } )"; TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, diff --git a/third_party/xla/xla/service/gpu/tests/parallel_reduction_test.cc b/third_party/xla/xla/service/gpu/tests/parallel_reduction_test.cc index 2e0a975bbfca9d..d9411f4d0b1732 100644 --- a/third_party/xla/xla/service/gpu/tests/parallel_reduction_test.cc +++ b/third_party/xla/xla/service/gpu/tests/parallel_reduction_test.cc @@ -77,19 +77,19 @@ ENTRY %cluster { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, ParseAndReturnVerifiedModule(hlo_text)); - if (GetDebugOptionsForTest().xla_gpu_mlir_emitter_level() >= 4) { + // if (GetDebugOptionsForTest().xla_gpu_mlir_emitter_level() >= 4) { CompileAndVerifyIr(std::move(hlo_module), R"(CHECK: switch {{.*}} label {{.*}} [ CHECK-NEXT: label CHECK-NEXT: ])", /*match_optimized_ir=*/false); - } else { - CompileAndVerifyIr(std::move(hlo_module), - R"(CHECK: reduce-group-0 - CHECK: reduce-group-1 - CHECK-NOT: reduce-group-2)", - /*match_optimized_ir=*/false); - } + // } else { + // CompileAndVerifyIr(std::move(hlo_module), + // R"(CHECK: reduce-group-0 + // CHECK: reduce-group-1 + // CHECK-NOT: reduce-group-2)", + // /*match_optimized_ir=*/false); + // } EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); } @@ -124,65 +124,65 @@ ENTRY %cluster { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, ParseAndReturnVerifiedModule(hlo_text)); - if (GetDebugOptionsForTest().xla_gpu_mlir_emitter_level() >= 4) { + // if (GetDebugOptionsForTest().xla_gpu_mlir_emitter_level() >= 4) { CompileAndVerifyIr(std::move(hlo_module), R"(CHECK: switch {{.*}} label {{.*}} [ CHECK-NEXT: label CHECK-NEXT: ])", /*match_optimized_ir=*/false); - } 
else { - CompileAndVerifyIr(std::move(hlo_module), - R"(CHECK: reduce-group-0 - CHECK: reduce-group-1 - CHECK-NOT: reduce-group-2)", - /*match_optimized_ir=*/false); - } + // } else { + // CompileAndVerifyIr(std::move(hlo_module), + // R"(CHECK: reduce-group-0 + // CHECK: reduce-group-1 + // CHECK-NOT: reduce-group-2)", + // /*match_optimized_ir=*/false); + // } EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); } -TEST_F(ParallelReductionTest, - UnnestedReductionWithLoopReductionDifferentShape) { - if (GetDebugOptionsForTest().xla_gpu_mlir_emitter_level() >= 4) { - GTEST_SKIP() - << "reduction does not occur in real pipelines and is not supported"; - } - const char* hlo = R"( - -HloModule module - -max { - a = f32[] parameter(0) - b = f32[] parameter(1) - ROOT c = f32[] maximum(a, b) -} - -fused_computation { - one_1_clone_1 = f32[] constant(1) - one_b.1.clone.1 = f32[100,200]{1,0} broadcast(one_1_clone_1), dimensions={} - param_1.9 = f16[100,200]{1,0} parameter(1) - c.2.clone.1 = f32[100,200]{1,0} convert(param_1.9) - param_0.6 = f32[] parameter(0) - b.2.clone.1 = f32[100,200]{1,0} broadcast(param_0.6), dimensions={} - d.1.clone.1 = f32[100,200]{1,0} divide(c.2.clone.1, b.2.clone.1) - a.2.clone.1 = f32[100,200]{1,0} add(one_b.1.clone.1, d.1.clone.1) - bitcast.1 = f32[20000]{0} bitcast(a.2.clone.1) - z_1 = f32[] constant(0) - r.1 = f32[] reduce(bitcast.1, z_1), dimensions={0}, to_apply=max - ROOT tuple = (f32[], f32[100,200]{1,0}) tuple(r.1, a.2.clone.1) -} - -ENTRY computation { - input_scale = f32[] parameter(1) - p = f16[100,200]{1,0} parameter(0) - fusion = (f32[], f32[100,200]{1,0}) fusion(input_scale, p), kind=kInput, calls=fused_computation - get-tuple-element.1 = f32[100,200]{1,0} get-tuple-element(fusion), index=1 - get-tuple-element = f32[] get-tuple-element(fusion), index=0 - ROOT out = (f32[100,200]{1,0}, f32[]) tuple(get-tuple-element.1, get-tuple-element) -} - - )"; - EXPECT_TRUE(RunAndCompare(hlo, ErrorSpec{1e-5, 1e-5})); -} +// 
TEST_F(ParallelReductionTest, +// UnnestedReductionWithLoopReductionDifferentShape) { +// if (GetDebugOptionsForTest().xla_gpu_mlir_emitter_level() >= 4) { +// GTEST_SKIP() +// << "reduction does not occur in real pipelines and is not supported"; +// } +// const char* hlo = R"( + +// HloModule module + +// max { +// a = f32[] parameter(0) +// b = f32[] parameter(1) +// ROOT c = f32[] maximum(a, b) +// } + +// fused_computation { +// one_1_clone_1 = f32[] constant(1) +// one_b.1.clone.1 = f32[100,200]{1,0} broadcast(one_1_clone_1), dimensions={} +// param_1.9 = f16[100,200]{1,0} parameter(1) +// c.2.clone.1 = f32[100,200]{1,0} convert(param_1.9) +// param_0.6 = f32[] parameter(0) +// b.2.clone.1 = f32[100,200]{1,0} broadcast(param_0.6), dimensions={} +// d.1.clone.1 = f32[100,200]{1,0} divide(c.2.clone.1, b.2.clone.1) +// a.2.clone.1 = f32[100,200]{1,0} add(one_b.1.clone.1, d.1.clone.1) +// bitcast.1 = f32[20000]{0} bitcast(a.2.clone.1) +// z_1 = f32[] constant(0) +// r.1 = f32[] reduce(bitcast.1, z_1), dimensions={0}, to_apply=max +// ROOT tuple = (f32[], f32[100,200]{1,0}) tuple(r.1, a.2.clone.1) +// } + +// ENTRY computation { +// input_scale = f32[] parameter(1) +// p = f16[100,200]{1,0} parameter(0) +// fusion = (f32[], f32[100,200]{1,0}) fusion(input_scale, p), kind=kInput, calls=fused_computation +// get-tuple-element.1 = f32[100,200]{1,0} get-tuple-element(fusion), index=1 +// get-tuple-element = f32[] get-tuple-element(fusion), index=0 +// ROOT out = (f32[100,200]{1,0}, f32[]) tuple(get-tuple-element.1, get-tuple-element) +// } + +// )"; +// EXPECT_TRUE(RunAndCompare(hlo, ErrorSpec{1e-5, 1e-5})); +// } TEST_F(ParallelReductionTest, UnnestedReductionWithLoopReduction) { const char* hlo_text = R"( @@ -366,19 +366,19 @@ ENTRY %cluster { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr hlo_module, ParseAndReturnVerifiedModule(hlo_text)); - if (GetDebugOptionsForTest().xla_gpu_mlir_emitter_level() >= 4) { + // if (GetDebugOptionsForTest().xla_gpu_mlir_emitter_level() >= 
4) { CompileAndVerifyIr(std::move(hlo_module), R"(CHECK: switch {{.*}} label {{.*}} [ CHECK-NEXT: label CHECK-NEXT: ])", /*match_optimized_ir=*/false); - } else { - CompileAndVerifyIr(std::move(hlo_module), - R"(CHECK: reduce-group-0 - CHECK: reduce-group-1 - CHECK-NOT: reduce-group-2)", - /*match_optimized_ir=*/false); - } + // } else { + // CompileAndVerifyIr(std::move(hlo_module), + // R"(CHECK: reduce-group-0 + // CHECK: reduce-group-1 + // CHECK-NOT: reduce-group-2)", + // /*match_optimized_ir=*/false); + // } EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); } diff --git a/third_party/xla/xla/service/gpu/tests/reduce_atomic_min.hlo b/third_party/xla/xla/service/gpu/tests/reduce_atomic_min.hlo deleted file mode 100644 index 474fd86488d3b6..00000000000000 --- a/third_party/xla/xla/service/gpu/tests/reduce_atomic_min.hlo +++ /dev/null @@ -1,443 +0,0 @@ -// RUN: hlo-opt %s --platform=gpu --xla_gpu_mlir_emitter_level=0 --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb | FileCheck %s --check-prefixes=CHECK,CHECK-%{PTX} - -// Check that for "min" we are still using atomics (CAS loop). 
- -HloModule MinReduce, is_scheduled=true - -Min { - x.1 = f32[] parameter(0) - y.1 = f32[] parameter(1) - ROOT min.1 = f32[] minimum(x.1, y.1) -} - -fused_computation { - param_0 = f32[300000]{0} parameter(0) - param_1 = f32[] parameter(1) - ROOT reduce.1 = f32[] reduce(f32[300000]{0} param_0, f32[] param_1), dimensions={0}, to_apply=Min -} - -ENTRY reduce.1 { - parameter = f32[300000] parameter(0) - init_value = f32[] constant(0) - ROOT wrapped_reduce = f32[] fusion(f32[300000]{0} parameter, f32[] init_value), kind=kInput, calls=fused_computation -} - -// CHECK-LABEL: entry: -// CHECK-PTX: %[[VAL_0:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !4 -// CHECK-GCN: %[[VAL_0:.*]] = call i32 @llvm.amdgcn.workgroup.id.x -// CHECK: %[[VAL_1:.*]] = zext i32 %[[VAL_0]] to i64 -// CHECK-PTX: %[[VAL_2:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !4 -// CHECK-GCN: %[[VAL_2:.*]] = call i32 @llvm.amdgcn.workitem.id.x -// CHECK: %[[VAL_3:.*]] = zext i32 %[[VAL_2]] to i64 -// CHECK: %[[VAL_4:.*]] = mul nuw nsw i64 %[[VAL_1]], 1 -// CHECK: %[[VAL_5:.*]] = add nuw nsw i64 %[[VAL_4]], %[[VAL_3]] -// CHECK: %[[VAL_6:.*]] = icmp ult i64 %[[VAL_5]], 1 -// CHECK: call void @llvm.assume(i1 %[[VAL_6]]) -// CHECK: %[[VAL_7:.*]] = add nuw nsw i64 %[[VAL_5]], 0 -// CHECK: %[[VAL_8:.*]] = icmp ult i64 %[[VAL_5]], 1 -// CHECK: br i1 %[[VAL_8]], label %[[VAL_9:.*]], label %[[VAL_10:.*]] -// CHECK: wrapped_reduce.in_bounds-after: ; preds = %[[VAL_9]], %[[VAL_11:.*]] -// CHECK: ret void -// CHECK: wrapped_reduce.in_bounds-true: ; preds = %[[VAL_11]] -// CHECK: %[[VAL_12:.*]] = load float, ptr %[[VAL_13:.*]], align 4, !invariant.load -// CHECK: store float %[[VAL_12]], ptr %[[VAL_14:.*]], align 4 -// CHECK: br label %[[VAL_10]] -// CHECK: entry: -// CHECK: %[[VAL_15:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_16:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_17:.*]] = alloca float, align 4 -// CHECK: %[[VAL_18:.*]] = alloca float, align 4 -// CHECK: %[[VAL_19:.*]] = alloca 
float, align 4 -// CHECK: %[[VAL_20:.*]] = alloca float, align 4 -// CHECK: %[[VAL_21:.*]] = alloca float, align 4 -// CHECK: %[[VAL_22:.*]] = alloca float, align 4 -// CHECK: %[[VAL_23:.*]] = alloca float, align 4 -// CHECK: %[[VAL_24:.*]] = alloca float, align 4 -// CHECK: %[[VAL_25:.*]] = alloca float, align 4 -// CHECK: %[[VAL_26:.*]] = alloca float, align 4 -// CHECK: %[[VAL_27:.*]] = alloca float, align 4 -// CHECK: %[[VAL_28:.*]] = alloca float, align 4 -// CHECK: %[[VAL_29:.*]] = alloca float, align 4 -// CHECK: %[[VAL_30:.*]] = alloca float, align 4 -// CHECK: %[[VAL_31:.*]] = alloca float, align 4 -// CHECK: %[[VAL_32:.*]] = alloca float, align 4 -// CHECK: %[[VAL_33:.*]] = alloca float, align 4 -// CHECK: %[[VAL_34:.*]] = alloca float, align 4 -// CHECK: %[[VAL_35:.*]] = alloca float, align 4 -// CHECK: %[[VAL_36:.*]] = alloca float, align 4 -// CHECK: %[[VAL_37:.*]] = alloca float, align 4 -// CHECK: %[[VAL_38:.*]] = alloca float, align 4 -// CHECK: %[[LOOP3_I_2:loop[23].invar_address.*]] = alloca i32, align 4 -// CHECK-GCN: %[[VAL_42:return_buffer.*]] = alloca float, align 4 -// CHECK: %[[LOOP2_I_2:loop2.invar_address.*]] = alloca i32, align 4 -// CHECK-PTX: %[[VAL_42:return_buffer.*]] = alloca float, align 4 -// CHECK-PTX: %[[VAL_40:.*]] = alloca i32, align 4 -// CHECK-PTX: %[[VAL_43:.*]] = alloca i32, align 4 -// CHECK: %partial_reduction_result = alloca float, align 4 -// CHECK: %reduction_input_address = alloca float, align 4 -// CHECK-PTX: %[[VAL_47:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.y(), !range !4 -// CHECK-GCN: %[[VAL_47:.*]] = call i32 @llvm.amdgcn.workgroup.id.y -// CHECK: %[[VAL_48:.*]] = icmp eq i32 %[[VAL_47]], 0 -// CHECK: br i1 %[[VAL_48]], label %[[VAL_49:.*]], label %[[VAL_50:.*]] -// CHECK: reduce-group-0-after: ; preds = %[[VAL_51:.*]], %[[VAL_52:.*]] -// CHECK: ret void -// CHECK: reduce-group-0-true: ; preds = %[[VAL_52]] -// CHECK: %[[VAL_53:.*]] = load float, ptr %[[VAL_54:.*]], align 4, !invariant.load !{{[0-9]}} -// 
CHECK: store float %[[VAL_53]], ptr{{.*}} %partial_reduction_result, align 4 -// CHECK-PTX: %thread.id.x = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !6 -// CHECK-GCN: %thread.id.x = call i32 @llvm.amdgcn.workitem.id.x -// CHECK-PTX: %block.id.x = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !7 -// CHECK-GCN: %block.id.x = call i32 @llvm.amdgcn.workgroup.id.x -// CHECK: %thread.id.2 = urem i32 %thread.id.x, 1024 -// CHECK: %lane_id = urem i32 %thread.id.x, 32 -// CHECK-PTX: %[[VAL_63:.*]] = udiv i32 %block.id.x, 1 -// CHECK-PTX: %[[VECTOR_OFFSET:.*]] = urem i32 %[[VAL_63]], 1 -// CHECK: %[[VAL_63_2:.*]] = udiv i32 %block.id.x, 1 -// CHECK: %[[VAL_64:.*]] = urem i32 %[[VAL_63_2]], 19 -// CHECK: %[[VAL_65:.*]] = udiv i32 %block.id.x, 19 -// CHECK: %[[VAL_66:.*]] = urem i32 %[[VAL_65]], 1 -// CHECK: %[[VAL_67:.*]] = udiv i32 %block.id.x, 19 -// CHECK: %[[VAL_68:.*]] = icmp eq i32 %[[VAL_64]], 18 -// CHECK-PTX: %tile_bound.2 = select i1 %[[VAL_68]], i32 2544, i32 8192 -// CHECK-GCN: %tile_bound.2 = select i1 %[[VAL_68]], i32 5088, i32 16384 -// CHECK: %tile_origin.0 = mul i32 %[[VAL_67]], 1 -// CHECK: %tile_origin.1 = mul i32 %[[VAL_66]], 1 -// CHECK-PTX: %tile_origin.2 = mul i32 %[[VAL_64]], 8192 -// CHECK-GCN: %tile_origin.2 = mul i32 %[[VAL_64]], 16384 -// CHECK-PTX: %tile_origin.3 = mul i32 %[[VECTOR_OFFSET]], 2 -// CHECK-PTX: %[[VAL_81:.*]] = icmp eq i32 8192, %tile_bound.2 -// CHECK-GCN: %[[VAL_81:.*]] = icmp eq i32 16384, %tile_bound.2 -// CHECK: br i1 %[[VAL_81]], label %[[VAL_82:.*]], label %[[VAL_83:.*]] -// CHECK: is_full_tile-after: ; preds = %[[VAL_84:.*]], %[[VAL_85:.*]] -// CHECK: %[[VAL_86:.*]] = load float, ptr{{.*}} %partial_reduction_result, align 4 -// CHECK-PTX: %[[VAL_87:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_86]], i32 16, i32 31) -// CHECK-GCN: %[[VAL_87_1:.*]] = bitcast float %[[VAL_86]] to i32 -// CHECK-GCN: %[[VAL_87_2:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_87_1]], i32 543) -// CHECK-GCN: 
%[[VAL_87:.*]] = bitcast i32 %[[VAL_87_2]] to float -// CHECK: store float %[[VAL_87]], ptr{{.*}} %[[VAL_37]], align 4 -// CHECK-GCN: %[[VAL_88_1:.*]] = addrspacecast ptr{{.*}} %partial_reduction_result to ptr -// CHECK-GCN: %[[VAL_88_2:.*]] = addrspacecast ptr{{.*}} %[[VAL_37]] to ptr -// CHECK-GCN: %[[VAL_88_3:.*]] = addrspacecast ptr{{.*}} %[[VAL_36]] to ptr -// CHECK-PTX: call void @[[MIN:Min.*]](ptr %partial_reduction_result, ptr %[[VAL_37]], ptr %[[VAL_36]]) -// CHECK-GCN: call void @[[MIN:Min.*]](ptr %[[VAL_88_1]], ptr %[[VAL_88_2]], ptr %[[VAL_88_3]]) -// CHECK: %[[VAL_88:.*]] = load float, ptr{{.*}} %[[VAL_36]], align 4 -// CHECK: store float %[[VAL_88]], ptr{{.*}} %partial_reduction_result, align 4 -// CHECK: %[[VAL_89:.*]] = load float, ptr{{.*}} %partial_reduction_result, align 4 -// CHECK-PTX: %[[VAL_90:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_89]], i32 8, i32 31) -// CHECK-GCN: %[[VAL_90_1:.*]] = bitcast float %[[VAL_89]] to i32 -// CHECK-GCN: %[[VAL_90_2:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_90_1]], i32 287) -// CHECK-GCN: %[[VAL_90:.*]] = bitcast i32 %[[VAL_90_2]] to float -// CHECK: store float %[[VAL_90]], ptr{{.*}} %[[VAL_35]], align 4 -// CHECK-GCN: %[[VAL_91_1:.*]] = addrspacecast ptr{{.*}} %partial_reduction_result to ptr -// CHECK-GCN: %[[VAL_91_2:.*]] = addrspacecast ptr{{.*}} %[[VAL_35]] to ptr -// CHECK-GCN: %[[VAL_91_3:.*]] = addrspacecast ptr{{.*}} %[[VAL_34]] to ptr -// CHECK-PTX: call void @[[MIN]](ptr %partial_reduction_result, ptr %[[VAL_35]], ptr %[[VAL_34]]) -// CHECK-GCN: call void @[[MIN]](ptr %[[VAL_91_1]], ptr %[[VAL_91_2]], ptr %[[VAL_91_3]]) -// CHECK: %[[VAL_91:.*]] = load float, ptr{{.*}} %[[VAL_34]], align 4 -// CHECK: store float %[[VAL_91]], ptr{{.*}} %partial_reduction_result, align 4 -// CHECK: %[[VAL_92:.*]] = load float, ptr{{.*}} %partial_reduction_result, align 4 -// CHECK-PTX: %[[VAL_93:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_92]], i32 4, i32 
31) -// CHECK-GCN: %[[VAL_93_1:.*]] = bitcast float %[[VAL_92]] to i32 -// CHECK-GCN: %[[VAL_93_2:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_93_1]], i32 159) -// CHECK-GCN: %[[VAL_93:.*]] = bitcast i32 %[[VAL_93_2]] to float -// CHECK: store float %[[VAL_93]], ptr{{.*}} %[[VAL_33]], align 4 -// CHECK-GCN: %[[VAL_94_1:.*]] = addrspacecast ptr{{.*}} %partial_reduction_result to ptr -// CHECK-GCN: %[[VAL_94_2:.*]] = addrspacecast ptr{{.*}} %[[VAL_33]] to ptr -// CHECK-GCN: %[[VAL_94_3:.*]] = addrspacecast ptr{{.*}} %[[VAL_32]] to ptr -// CHECK-PTX: call void @[[MIN]](ptr %partial_reduction_result, ptr %[[VAL_33]], ptr %[[VAL_32]]) -// CHECK-GCN: call void @[[MIN]](ptr %[[VAL_94_1]], ptr %[[VAL_94_2]], ptr %[[VAL_94_3]]) -// CHECK: %[[VAL_94:.*]] = load float, ptr{{.*}} %[[VAL_32]], align 4 -// CHECK: store float %[[VAL_94]], ptr{{.*}} %partial_reduction_result, align 4 -// CHECK: %[[VAL_95:.*]] = load float, ptr{{.*}} %partial_reduction_result, align 4 -// CHECK-PTX: %[[VAL_96:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_95]], i32 2, i32 31) -// CHECK-GCN: %[[VAL_96_1:.*]] = bitcast float %[[VAL_95]] to i32 -// CHECK-GCN: %[[VAL_96_2:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_96_1]], i32 95) -// CHECK-GCN: %[[VAL_96:.*]] = bitcast i32 %[[VAL_96_2]] to float -// CHECK: store float %[[VAL_96]], ptr{{.*}} %[[VAL_31]], align 4 -// CHECK-GCN: %[[VAL_97_1:.*]] = addrspacecast ptr{{.*}} %partial_reduction_result to ptr -// CHECK-GCN: %[[VAL_97_2:.*]] = addrspacecast ptr{{.*}} %[[VAL_31]] to ptr -// CHECK-GCN: %[[VAL_97_3:.*]] = addrspacecast ptr{{.*}} %[[VAL_30]] to ptr -// CHECK-PTX: call void @[[MIN]](ptr %partial_reduction_result, ptr %[[VAL_31]], ptr %[[VAL_30]]) -// CHECK-GCN: call void @[[MIN]](ptr %[[VAL_97_1]], ptr %[[VAL_97_2]], ptr %[[VAL_97_3]]) -// CHECK: %[[VAL_97:.*]] = load float, ptr{{.*}} %[[VAL_30]], align 4 -// CHECK: store float %[[VAL_97]], ptr{{.*}} %partial_reduction_result, align 4 -// CHECK: %[[VAL_98:.*]] = 
load float, ptr{{.*}} %partial_reduction_result, align 4 -// CHECK-PTX: %[[VAL_99:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_98]], i32 1, i32 31) -// CHECK-GCN: %[[VAL_99_1:.*]] = bitcast float %[[VAL_98]] to i32 -// CHECK-GCN: %[[VAL_99_2:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_99_1]], i32 63) -// CHECK-GCN: %[[VAL_99:.*]] = bitcast i32 %[[VAL_99_2]] to float -// CHECK: store float %[[VAL_99]], ptr{{.*}} %[[VAL_29]], align 4 -// CHECK-GCN: %[[VAL_100_1:.*]] = addrspacecast ptr{{.*}} %partial_reduction_result to ptr -// CHECK-GCN: %[[VAL_100_2:.*]] = addrspacecast ptr{{.*}} %[[VAL_29]] to ptr -// CHECK-GCN: %[[VAL_100_3:.*]] = addrspacecast ptr{{.*}} %[[VAL_28]] to ptr -// CHECK-PTX: call void @[[MIN]](ptr %partial_reduction_result, ptr %[[VAL_29]], ptr %[[VAL_28]]) -// CHECK-GCN: call void @[[MIN]](ptr %[[VAL_100_1]], ptr %[[VAL_100_2]], ptr %[[VAL_100_3]]) -// CHECK: %[[VAL_100:.*]] = load float, ptr{{.*}} %[[VAL_28]], align 4 -// CHECK: store float %[[VAL_100]], ptr{{.*}} %partial_reduction_result, align 4 -// CHECK: %[[VAL_101:.*]] = udiv i32 %thread.id.2, 32 -// CHECK: br i1 true, label %[[VAL_105:.*]], label %[[VAL_51]] - -// CHECK: thread_in_bounds-after: -// CHECK: br label %[[VAL_50]] -// CHECK: is_full_tile-true: -// CHECK-PTX: store i32 0, ptr{{.*}} %[[VAL_43]], align 4 -// CHECK-GCN: store i32 0, ptr{{.*}} %[[LOOP2_I_2]], align 4 -// CHECK: br label %[[VAL_107:.*]] -// CHECK: loop2.loop_header: ; preds = %[[VAL_108:.*]], %[[VAL_82]] -// CHECK-PTX: %[[VAL_109:.*]] = load i32, ptr %[[VAL_43]], align 4 -// CHECK-GCN: %[[VAL_109:.*]] = load i32, ptr{{.*}} %[[LOOP2_I_2]], align 4 -// CHECK: %[[VAL_110:.*]] = icmp uge i32 %[[VAL_109]], -// CHECK: br i1 %[[VAL_110]], label %loop2.loop_exit, label %loop2.loop_body - -// CHECK: loop2.loop_body: ; preds = %[[VAL_107]] -// CHECK: %[[VAL_111:.*]] = add nuw nsw i32 %[[VAL_109]], 1 -// CHECK-PTX: store i32 %[[VAL_111]], ptr %[[VAL_43]], align 4 -// CHECK-GCN: store i32 
%[[VAL_111]], ptr{{.*}} %[[LOOP2_I_2]], align 4 -// CHECK: %[[OFFSET_2:.*]] = add i32 %loop2.indvar, %thread.id.2 -// CHECK-GCN: %[[START_0:.*]] = add i32 %tile_origin.0, 0 -// CHECK-GCN: %[[START_1:.*]] = add i32 %tile_origin.1, 0 -// CHECK-GCN: %[[START_2:.*]] = add i32 %tile_origin.2, %[[OFFSET_2]] -// CHECK-GCN: %[[VAL_119:.*]] = getelementptr inbounds [300000 x float], ptr %[[VAL_120:.*]], i32 0, i32 %[[START_2]] -// CHECK-GCN: %[[VAL_121:.*]] = load float, ptr %[[VAL_119]], align 4, !invariant.load !3 -// CHECK-GCN: store float %[[VAL_121]], ptr{{.*}} %reduction_input_address, align 4 -// CHECK-GCN: %[[VAL_123_1:.*]] = addrspacecast ptr addrspace(5) %partial_reduction_result to ptr -// CHECK-GCN: %[[VAL_123_2:.*]] = addrspacecast ptr addrspace(5) %reduction_input_address to ptr -// CHECK-GCN: %[[VAL_123_3:.*]] = addrspacecast ptr addrspace(5) %[[VAL_42]] to ptr -// CHECK-GCN: call void @[[MIN]](ptr %[[VAL_123_1]], ptr %[[VAL_123_2]], ptr %[[VAL_123_3]]) -// CHECK-GCN: %[[VAL_123:.*]] = load float, ptr{{.*}} %[[VAL_42]], align 4 -// CHECK-GCN: store float %[[VAL_123]], ptr{{.*}} %partial_reduction_result, align 4 -// CHECK-GCN: br label %loop2.loop_header -// CHECK-PTX: store i32 0, ptr %loop3.invar_address, align 4 -// CHECK-PTX: br label %loop3.loop_header - -// CHECK-PTX: loop3.loop_header: -// CHECK-PTX: %loop3.indvar = load i32, ptr %loop3.invar_address, align 4 -// CHECK-PTX: %[[LOOP3_OOB:.*]] = icmp uge i32 %loop3.indvar, 2 -// CHECK-PTX: br i1 %[[LOOP3_OOB]], label %loop3.loop_exit, label %loop3.loop_body -// CHECK-PTX: loop3.loop_body: -// CHECK-PTX: %[[LOOP3_INC:.*]] = add nuw nsw i32 %loop3.indvar, 1 -// CHECK-PTX: store i32 %[[LOOP3_INC]], ptr %loop3.invar_address, align 4 -// CHECK-PTX: %[[START_0:.*]] = add i32 %tile_origin.0, 0 -// CHECK-PTX: %[[START_1:.*]] = add i32 %tile_origin.1, 0 -// CHECK-PTX: %[[START_2:.*]] = add i32 %tile_origin.2, %[[OFFSET_2]] -// CHECK-PTX: %[[START_3:.*]] = add i32 %tile_origin.3, %loop3.indvar -// CHECK-PTX: 
%[[VAL_113:.*]] = mul nuw nsw i32 %[[START_3]], 1 -// CHECK-PTX: %[[VAL_114:.*]] = add nuw nsw i32 0, %[[VAL_113]] -// CHECK-PTX: %[[VAL_115:.*]] = mul nuw nsw i32 %[[START_2]], 2 -// CHECK-PTX: %[[VAL_116:.*]] = add nuw nsw i32 %[[VAL_114]], %[[VAL_115]] -// CHECK-PTX: %[[VAL_119:.*]] = getelementptr inbounds [300000 x float], ptr %[[VAL_120:.*]], i32 0, i32 %[[VAL_116]] -// CHECK-PTX: %[[VAL_121:.*]] = load float, ptr %[[VAL_119]], align 4, !invariant.load !5 -// CHECK-PTX: store float %[[VAL_121]], ptr %reduction_input_address, align 4 -// CHECK-PTX: call void @[[MIN]](ptr %partial_reduction_result, ptr %reduction_input_address, ptr %[[VAL_42]]) -// CHECK-PTX: %[[VAL_123:.*]] = load float, ptr %[[VAL_42]], align 4 -// CHECK-PTX: store float %[[VAL_123]], ptr %partial_reduction_result, align 4 -// CHECK-PTX: br label %loop3.loop_header -// CHECK-PTX: loop3.loop_exit: -// CHECK-PTX: br label %loop2.loop_header - -// CHECK: loop2.loop_exit: -// CHECK: br label %is_full_tile-after - -// CHECK: is_full_tile-false: -// CHECK-PTX: store i32 0, ptr %[[LOOP2_I_2]], align 4 -// CHECK-GCN: store i32 0, ptr{{.*}} %[[LOOP3_I_2]], align 4 -// CHECK: br label %[[VAL_134:.*]] - -// CHECK: loop2.loop_header{{(4|3)}}: -// CHECK-PTX: %[[VAL_136:.*]] = load i32, ptr %[[LOOP2_I_2]], align 4 -// CHECK-GCN: %[[VAL_136:.*]] = load i32, ptr{{.*}} %[[LOOP3_I_2]], align 4 -// CHECK: %[[VAL_137:.*]] = icmp uge i32 %[[VAL_136]], {{(8|16384)}} -// CHECK: br i1 %[[VAL_137]], label %[[VAL_84]], label %[[VAL_138:.*]] - -// CHECK: loop2.loop_body{{(5|4)}}: -// CHECK: %[[VAL_139:.*]] = add nuw nsw i32 %[[VAL_136]], 1 -// CHECK-PTX: store i32 %[[VAL_139]], ptr %[[LOOP2_I_2]], align 4 -// CHECK-GCN: store i32 %[[VAL_139]], ptr{{.*}} %[[LOOP3_I_2]], align 4 -// CHECK: %[[VAL_141:.*]] = add i32 %[[VAL_136]], %thread.id.2 -// CHECK: %[[VAL_144:.*]] = icmp ult i32 %[[VAL_141]], %tile_bound.2 -// CHECK: br i1 %[[VAL_144]], label %x_in_tile-true, label %x_in_tile-after - -// CHECK: x_in_tile-after: -// 
CHECK: br label %loop2.loop_header{{(4|3)}} - -// CHECK: loop2.loop_exit{{(3|2)}}: -// CHECK: br label %is_full_tile-after - -// CHECK: x_in_tile-true: ; preds = %[[VAL_138]] -// CHECK-GCN: %[[IDX0:.*]] = add i32 %tile_origin.0, 0 -// CHECK-GCN: %[[IDX1:.*]] = add i32 %tile_origin.1, 0 -// CHECK-GCN: %[[IDX2:.*]] = add i32 %tile_origin.2, %[[VAL_141]] -// CHECK-GCN: %[[VAL_155:.*]] = getelementptr inbounds [300000 x float], ptr %[[VAL_120]], i32 0, i32 %[[IDX2]] -// CHECK-GCN: %[[VAL_156:.*]] = load float, ptr %[[VAL_155]], align 4, !invariant.load !3 -// CHECK-GCN: store float %[[VAL_156]], ptr{{.*}} %reduction_input_address, align 4 -// CHECK-GCN: %[[VAL_158_1:.*]] = addrspacecast ptr addrspace(5) %partial_reduction_result to ptr -// CHECK-GCN: %[[VAL_158_2:.*]] = addrspacecast ptr addrspace(5) %reduction_input_address to ptr -// CHECK-GCN: %[[VAL_158_3:.*]] = addrspacecast ptr addrspace(5) %[[VAL_38]] to ptr -// CHECK-GCN: call void @[[MIN]](ptr %[[VAL_158_1]], ptr %[[VAL_158_2]], ptr %[[VAL_158_3]]) -// CHECK-GCN: %[[VAL_158:.*]] = load float, ptr{{.*}} %[[VAL_38]], align 4 -// CHECK-GCN: store float %[[VAL_158]], ptr{{.*}} %partial_reduction_result, align 4 -// CHECK-GCN: br label %x_in_tile-after -// CHECK-PTX: store i32 0, ptr %[[LOOP3_I_2]], align 4 -// CHECK-PTX: br label %loop3.loop_header10 - -// CHECK-PTX: loop3.loop_header10: -// CHECK-PTX: %[[VAL_145:.*]] = load i32, ptr %[[LOOP3_I_2]], align 4 -// CHECK-PTX: %[[VAL_146:.*]] = icmp uge i32 %[[VAL_145]], 2 -// CHECK-PTX: br i1 %[[VAL_146]], label %loop3.loop_exit9, label %loop3.loop_body11 - -// CHECK-PTX: loop3.loop_body11: -// CHECK-PTX: %[[VAL_147:.*]] = add nuw nsw i32 %[[VAL_145]], 1 -// CHECK-PTX: store i32 %[[VAL_147]], ptr %[[LOOP3_I_2]], align 4 -// CHECK-PTX: %[[IDX0:.*]] = add i32 %tile_origin.0, 0 -// CHECK-PTX: %[[IDX1:.*]] = add i32 %tile_origin.1, 0 -// CHECK-PTX: %[[IDX2:.*]] = add i32 %tile_origin.2, %[[VAL_141]] -// CHECK-PTX: %[[IDX3:.*]] = add i32 %tile_origin.3, %[[VAL_145]] -// 
CHECK-PTX: %[[VAL_148:.*]] = mul nuw nsw i32 %[[IDX3]], 1 -// CHECK-PTX: %[[VAL_149:.*]] = add nuw nsw i32 0, %[[VAL_148]] -// CHECK-PTX: %[[VAL_150:.*]] = mul nuw nsw i32 %[[IDX2]], 2 -// CHECK-PTX: %[[VAL_151:.*]] = add nuw nsw i32 %[[VAL_149]], %[[VAL_150]] -// CHECK-PTX: %[[VAL_155:.*]] = getelementptr inbounds [300000 x float], ptr %[[VAL_120]], i32 0, i32 %[[VAL_151]] -// CHECK-PTX: %[[VAL_156:.*]] = load float, ptr %[[VAL_155]], align 4, !invariant.load !5 -// CHECK-PTX: store float %[[VAL_156]], ptr %reduction_input_address, align 4 -// CHECK-PTX: call void @[[MIN]](ptr %partial_reduction_result, ptr %reduction_input_address, ptr %[[VAL_38]]) -// CHECK-PTX: %[[VAL_158:.*]] = load float, ptr %[[VAL_38]], align 4 -// CHECK-PTX: store float %[[VAL_158]], ptr %partial_reduction_result, align 4 -// CHECK-PTX: br label %loop3.loop_header10 - -// CHECK-PTX: loop3.loop_exit9: -// CHECK-PTX: br label %x_in_tile-after - -// CHECK: thread_in_bounds-true: -// CHECK: %[[VAL_166:.*]] = icmp eq i32 %lane_id, 0 -// CHECK: br i1 %[[VAL_166]], label %[[VAL_167:.*]], label %[[VAL_168:.*]] - -// CHECK: intra_warp_reduce_write-after: ; preds = %[[VAL_167]], %[[VAL_105]] -// CHECK-GCM: fence syncscope("workgroup") seq_cst -// CHECK-GCM: call void @llvm.amdgcn.s.barrier() -// CHECK-PTX: call void @llvm.nvvm.barrier0() -// CHECK: %[[VAL_169:.*]] = icmp eq i32 %[[VAL_101]], 0 -// CHECK: br i1 %[[VAL_169]], label %inter_warp_reduce-true, label %inter_warp_reduce-after -// CHECK: inter_warp_reduce-after: ; preds = %[[VAL_171:.*]], %[[VAL_168]] -// CHECK: br label %[[VAL_51]] -// CHECK: intra_warp_reduce_write-true: ; preds = %[[VAL_105]] -// CHECK: %[[VAL_172:.*]] = load float, ptr{{.*}} %partial_reduction_result, align 4 -// CHECK: %[[VAL_173:.*]] = getelementptr inbounds [1 x [32 x float]], ptr addrspace(3) @shared_cache, i32 0, i32 0, i32 %[[VAL_101]] -// CHECK: %[[VAL_174:.*]] = addrspacecast ptr addrspace(3) %[[VAL_173]] to ptr -// CHECK: store float %[[VAL_172]], ptr 
%[[VAL_174]], align 4 -// CHECK: br label %[[VAL_168]] -// CHECK: inter_warp_reduce-true: ; preds = %[[VAL_168]] -// CHECK: %[[VAL_175:.*]] = getelementptr inbounds [1 x [32 x float]], ptr addrspace(3) @shared_cache, i32 0, i32 0, i32 %lane_id -// CHECK: %[[VAL_176:.*]] = addrspacecast ptr addrspace(3) %[[VAL_175]] to ptr -// CHECK-GCN: %[[VAL_176_1:.*]] = addrspacecast ptr addrspace(5) %[[VAL_27]] to ptr -// CHECK-GCN: store float %[[VAL_53]], ptr{{.*}} %[[VAL_176_1]], align 4 -// CHECK-PTX: store float %[[VAL_53]], ptr %[[VAL_27]], align 4 -// CHECK: %[[VAL_177:.*]] = icmp ult i32 %thread.id.2, 32 -// CHECK-GCN: %[[VAL_178:.*]] = select i1 %[[VAL_177]], ptr %[[VAL_176]], ptr %[[VAL_176_1]] -// CHECK-PTX: %[[VAL_178:.*]] = select i1 %[[VAL_177]], ptr %[[VAL_176]], ptr %[[VAL_27]] -// CHECK: %[[VAL_179:.*]] = load float, ptr %[[VAL_178]], align 4 -// CHECK-GCN: %[[VAL_179_1:.*]] = bitcast float %[[VAL_179]] to i32 -// CHECK-GCN: %[[VAL_180:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_179_1]], i32 543) -// CHECK-GCN: %[[VAL_180_1:.*]] = bitcast i32 %[[VAL_180]] to float -// CHECK-GCN: store float %[[VAL_180_1]], ptr{{.*}} %[[VAL_26]], align 4 -// CHECK-PTX: %[[VAL_180:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_179]], i32 16, i32 31) -// CHECK-PTX: store float %[[VAL_180]], ptr %[[VAL_26]], align 4 -// CHECK-GCN: %[[VAL_181_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_26]] to ptr -// CHECK-GCN: %[[VAL_181_3:.*]] = addrspacecast ptr addrspace(5) %[[VAL_25]] to ptr -// CHECK-GCN: call void @[[MIN]](ptr %[[VAL_178]], ptr %[[VAL_181_2]], ptr %[[VAL_181_3]]) -// CHECK-PTX: call void @[[MIN]](ptr %[[VAL_178]], ptr %[[VAL_26]], ptr %[[VAL_25]]) -// CHECK: %[[VAL_181:.*]] = load float, ptr{{.*}} %[[VAL_25]], align 4 -// CHECK: store float %[[VAL_181]], ptr %[[VAL_178]], align 4 -// CHECK: %[[VAL_182:.*]] = load float, ptr %[[VAL_178]], align 4 -// CHECK-GCN: %[[VAL_182_1:.*]] = bitcast float %[[VAL_182]] to i32 -// CHECK-GCN: %[[VAL_183:.*]] 
= call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_182_1]], i32 287) -// CHECK-GCN: %[[VAL_183_1:.*]] = bitcast i32 %[[VAL_183]] to float -// CHECK-GCN: store float %[[VAL_183_1]], ptr{{.*}} %[[VAL_24]], align 4 -// CHECK-PTX: %[[VAL_183:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_182]], i32 8, i32 31) -// CHECK-PTX: store float %[[VAL_183]], ptr %[[VAL_24]], align 4 -// CHECK-GCN: %[[VAL_184_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_24]] to ptr -// CHECK-GCN: %[[VAL_184_3:.*]] = addrspacecast ptr addrspace(5) %[[VAL_23]] to ptr -// CHECK-GCN: call void @[[MIN]](ptr %[[VAL_178]], ptr %[[VAL_184_2]], ptr %[[VAL_184_3]]) -// CHECK-PTX: call void @[[MIN]](ptr %[[VAL_178]], ptr %[[VAL_24]], ptr %[[VAL_23]]) -// CHECK: %[[VAL_184:.*]] = load float, ptr{{.*}} %[[VAL_23]], align 4 -// CHECK: store float %[[VAL_184]], ptr %[[VAL_178]], align 4 -// CHECK: %[[VAL_185:.*]] = load float, ptr %[[VAL_178]], align 4 -// CHECK-GCN: %[[VAL_185_1:.*]] = bitcast float %[[VAL_185]] to i32 -// CHECK-GCN: %[[VAL_186:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_185_1]], i32 159) -// CHECK-GCN: %[[VAL_186_1:.*]] = bitcast i32 %[[VAL_186]] to float -// CHECK-GCN: store float %[[VAL_186_1]], ptr{{.*}} %[[VAL_22]], align 4 -// CHECK-PTX: %[[VAL_186:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_185]], i32 4, i32 31) -// CHECK-PTX: store float %[[VAL_186]], ptr %[[VAL_22]], align 4 -// CHECK-GCN: %[[VAL_187_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_22]] to ptr -// CHECK-GCN: %[[VAL_187_3:.*]] = addrspacecast ptr addrspace(5) %[[VAL_21]] to ptr -// CHECK-GCN: call void @[[MIN]](ptr %[[VAL_178]], ptr %[[VAL_187_2]], ptr %[[VAL_187_3]]) -// CHECK-PTX: call void @[[MIN]](ptr %[[VAL_178]], ptr %[[VAL_22]], ptr %[[VAL_21]]) -// CHECK: %[[VAL_187:.*]] = load float, ptr{{.*}} %[[VAL_21]], align 4 -// CHECK: store float %[[VAL_187]], ptr %[[VAL_178]], align 4 -// CHECK: %[[VAL_188:.*]] = load float, ptr %[[VAL_178]], align 4 -// CHECK-GCN: 
%[[VAL_188_1:.*]] = bitcast float %[[VAL_188]] to i32 -// CHECK-GCN: %[[VAL_189:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_188_1]], i32 95) -// CHECK-GCN: %[[VAL_189_1:.*]] = bitcast i32 %[[VAL_189]] to float -// CHECK-GCN: store float %[[VAL_189_1]], ptr{{.*}} %[[VAL_20]], align 4 -// CHECK-PTX: %[[VAL_189:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_188]], i32 2, i32 31) -// CHECK-PTX: store float %[[VAL_189]], ptr %[[VAL_20]], align 4 -// CHECK-GCN: %[[VAL_190_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_20]] to ptr -// CHECK-GCN: %[[VAL_190_3:.*]] = addrspacecast ptr addrspace(5) %[[VAL_19]] to ptr -// CHECK-GCN: call void @[[MIN]](ptr %[[VAL_178]], ptr %[[VAL_190_2]], ptr %[[VAL_190_3]]) -// CHECK-PTX: call void @[[MIN]](ptr %[[VAL_178]], ptr %[[VAL_20]], ptr %[[VAL_19]]) -// CHECK: %[[VAL_190:.*]] = load float, ptr{{.*}} %[[VAL_19]], align 4 -// CHECK: store float %[[VAL_190]], ptr %[[VAL_178]], align 4 -// CHECK: %[[VAL_191:.*]] = load float, ptr %[[VAL_178]], align 4 -// CHECK-GCN: %[[VAL_191_1:.*]] = bitcast float %[[VAL_191]] to i32 -// CHECK-GCN: %[[VAL_192:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_191_1]], i32 63) -// CHECK-GCN: %[[VAL_192_1:.*]] = bitcast i32 %[[VAL_192]] to float -// CHECK-GCN: store float %[[VAL_192_1]], ptr{{.*}} %[[VAL_18]], align 4 -// CHECK-PTX: %[[VAL_192:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_191]], i32 1, i32 31) -// CHECK-PTX: store float %[[VAL_192]], ptr %[[VAL_18]], align 4 -// CHECK-GCN: %[[VAL_193_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_18]] to ptr -// CHECK-GCN: %[[VAL_193_3:.*]] = addrspacecast ptr addrspace(5) %[[VAL_17]] to ptr -// CHECK-GCN: call void @[[MIN]](ptr %[[VAL_178]], ptr %[[VAL_193_2]], ptr %[[VAL_193_3]]) -// CHECK-PTX: call void @[[MIN]](ptr %[[VAL_178]], ptr %[[VAL_18]], ptr %[[VAL_17]]) -// CHECK: %[[VAL_193:.*]] = load float, ptr{{.*}} %[[VAL_17]], align 4 -// CHECK: store float %[[VAL_193]], ptr %[[VAL_178]], align 4 -// 
CHECK: %[[VAL_194:.*]] = icmp eq i32 %thread.id.2, 0 -// CHECK: br i1 %[[VAL_194]], label %[[VAL_195:.*]], label %[[VAL_171]] -// CHECK: reduction_write_output-after: -// CHECK: br label %inter_warp_reduce-after -// CHECK: reduction_write_output-true: -// CHECK: %[[VAL_200:.*]] = load float, ptr %[[VAL_178]], align 4 -// CHECK: %[[VAL_201:.*]] = load i32, ptr %[[VAL_202:.*]], align 4 -// CHECK: store i32 %[[VAL_201]], ptr{{.*}} %[[VAL_16]], align 4 -// CHECK: br label %[[VAL_203:.*]] -// CHECK: atomic_op_loop_exit: ; preds = %[[VAL_204:.*]], %[[VAL_203]] -// CHECK: br label %[[VAL_171]] -// CHECK: atomic_op_loop_body: ; preds = %[[VAL_204]], %[[VAL_195]] -// CHECK: %[[VAL_205:.*]] = load i32, ptr{{.*}} %[[VAL_16]], align 4 -// CHECK: store i32 %[[VAL_205]], ptr{{.*}} %[[VAL_15]], align 4 -// CHECK-GCN: %[[VAL_206_1:.*]] = addrspacecast ptr addrspace(5) %[[VAL_15]] to ptr -// CHECK-GCN: %[[VAL_206_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_15]] to ptr -// CHECK-GCN: call void @[[MIN]](ptr %[[VAL_206_1]], ptr %[[VAL_178]], ptr %[[VAL_206_2]]) -// CHECK-PTX: call void @[[MIN]](ptr %[[VAL_15]], ptr %[[VAL_178]], ptr %[[VAL_15]]) -// CHECK: %[[VAL_206:.*]] = load i32, ptr{{.*}} %[[VAL_15]], align 4 -// CHECK: %[[VAL_207:.*]] = icmp eq i32 %[[VAL_205]], %[[VAL_206]] -// CHECK: br i1 %[[VAL_207]], label %atomic_op_loop_exit, label %atomic_op_loop_cas -// CHECK: atomic_op_loop_cas: ; preds = %[[VAL_203]] -// CHECK: %[[VAL_208:.*]] = cmpxchg ptr %[[VAL_202]], i32 %[[VAL_205]], i32 %[[VAL_206]]{{.*}} seq_cst seq_cst, align 4 -// CHECK: %[[VAL_209:.*]] = extractvalue { i32, i1 } %[[VAL_208]], 0 -// CHECK: store i32 %[[VAL_209]], ptr{{.*}} %[[VAL_16]], align 4 -// CHECK: %[[VAL_210:.*]] = extractvalue { i32, i1 } %[[VAL_208]], 1 -// CHECK: br i1 %[[VAL_210]], label %atomic_op_loop_exit, label %atomic_op_loop_body -// CHECK: entry: -// CHECK: %[[VAL_211:.*]] = alloca float, align 4 -// CHECK: %[[VAL_212:.*]] = load float, ptr %[[VAL_213:.*]], align 4 -// CHECK: 
%[[VAL_214:.*]] = load float, ptr %[[VAL_215:.*]], align 4 -// CHECK-PTX: %[[VAL_216:.*]] = call float @llvm.minimum.f32(float %[[VAL_212]], float %[[VAL_214]]) -// CHECK-GCN: %[[VAL_216_1:.*]] = fcmp une float %[[VAL_212]], %[[VAL_212]] -// CHECK-GCN: %[[VAL_216_2:.*]] = fcmp oeq float %[[VAL_214]], %[[VAL_214]] -// CHECK-GCN: %[[VAL_216_3:.*]] = fcmp ole float %[[VAL_212]], %[[VAL_214]] -// CHECK-GCN: %[[VAL_216_4:.*]] = and i1 %[[VAL_216_2]], %[[VAL_216_3]] -// CHECK-GCN: %[[VAL_216_5:.*]] = or i1 %[[VAL_216_1]], %[[VAL_216_4]] -// CHECK-GCN: %[[VAL_216:.*]] = select i1 %[[VAL_216_5]], float %[[VAL_212]], float %[[VAL_214]] -// CHECK: store float %[[VAL_216]], ptr{{.*}} %[[VAL_211]], align 4 -// CHECK: %[[VAL_217:.*]] = load float, ptr{{.*}} %[[VAL_211]], align 4 -// CHECK: store float %[[VAL_217]], ptr %[[VAL_218:.*]], align 4 -// CHECK: ret void diff --git a/third_party/xla/xla/service/gpu/tests/reduce_column_layout_change.hlo b/third_party/xla/xla/service/gpu/tests/reduce_column_layout_change.hlo deleted file mode 100644 index 50508e5496a535..00000000000000 --- a/third_party/xla/xla/service/gpu/tests/reduce_column_layout_change.hlo +++ /dev/null @@ -1,207 +0,0 @@ -// RUN: hlo-opt %s --platform=gpu --xla_gpu_mlir_emitter_level=0 --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb | FileCheck %s --check-prefixes=CHECK,CHECK-%{PTX} - -HloModule reduce_with_layout_change, is_scheduled=true - -reduction0 { - x0 = f32[] parameter(0) - y0 = f32[] parameter(1) - ROOT add0 = f32[] add(x0, y0) -} - -fused_computation { - arg0 = f32[12,3,32,16,32,4,3,12] parameter(0) - constant0 = f32[] constant(0) - ROOT reduce0 = f32[16,32,4,3,12]{1,3,2,0,4} reduce(arg0, constant0), dimensions={0,1,2}, to_apply=reduction0 -} - -ENTRY kernel_entry { - arg0 = f32[12,3,32,16,32,4,3,12] parameter(0) - ROOT fusion = f32[16,32,4,3,12]{1,3,2,0,4} fusion(arg0), kind=kInput, calls=fused_computation -} - -// CHECK-LABEL: entry: -// 
CHECK: %[[VAL_0:.*]] = alloca float, align 4 -// CHECK: %[[VAL_1:.*]] = alloca float, align 4 -// CHECK: %[[VAL_2:.*]] = alloca float, align 4 -// CHECK: %[[VAL_3:.*]] = alloca float, align 4 -// CHECK: %[[VAL_4:.*]] = alloca float, align 4 -// CHECK: %[[VAL_5:.*]] = alloca float, align 4 -// CHECK: %[[VAL_6:.*]] = alloca float, align 4 -// CHECK: %[[VAL_7:.*]] = alloca float, align 4 -// CHECK: %[[VAL_8:.*]] = alloca float, align 4 -// CHECK: %[[VAL_9:.*]] = alloca float, align 4 -// CHECK: %[[VAL_10:.*]] = alloca float, align 4 -// CHECK: %[[VAL_11:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_12:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_13:.*]] = alloca float, align 4 -// CHECK: %[[VAL_14:.*]] = alloca float, align 4 -// CHECK-PTX: %[[VAL_15:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.y(), !range !2 -// CHECK-GCN: %[[VAL_15:.*]] = call i32 @llvm.amdgcn.workgroup.id.y -// CHECK: %[[VAL_16:.*]] = icmp eq i32 %[[VAL_15]], 0 -// CHECK: br i1 %[[VAL_16]], label %[[VAL_17:.*]], label %[[VAL_18:.*]] -// CHECK: reduce-group-0-after: ; preds = %[[VAL_19:.*]], %[[VAL_20:.*]] -// CHECK: ret void -// CHECK: reduce-group-0-true: ; preds = %[[VAL_20]] -// CHECK: %[[VAL_21:.*]] = load float, ptr @0, align 4 -// CHECK: store float %[[VAL_21]], ptr{{( addrspace\(5\))?}} %[[VAL_13]], align 4 -// CHECK-PTX: %thread.id.x = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3 -// CHECK-GCN: %thread.id.x = call i32 @llvm.amdgcn.workitem.id.x -// CHECK-PTX: %block.id.x = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !4 -// CHECK-GCN: %block.id.x = call i32 @llvm.amdgcn.workgroup.id.x -// CHECK: %[[VAL_22:.*]] = udiv i32 %thread.id.x, 32 -// CHECK: %thread.id.1 = urem i32 %[[VAL_22]], 32 -// CHECK: %thread.id.2 = urem i32 %thread.id.x, 32 -// CHECK: %lane_id = urem i32 %thread.id.x, 32 -// CHECK: %[[VAL_23:.*]] = udiv i32 %block.id.x, 1 -// CHECK: %[[VAL_24:.*]] = urem i32 %[[VAL_23]], 2304 -// CHECK: %[[VAL_25:.*]] = udiv i32 %block.id.x, 2304 -// CHECK: %[[VAL_26:.*]] = 
urem i32 %[[VAL_25]], 1 -// CHECK: %[[VAL_27:.*]] = udiv i32 %block.id.x, 2304 -// CHECK: %[[VAL_28:.*]] = icmp eq i32 %[[VAL_26]], 0 -// CHECK: %tile_bound.1 = select i1 %[[VAL_28]], i32 1152, i32 4096 -// CHECK: %tile_origin.0 = mul i32 %[[VAL_27]], 1 -// CHECK: %tile_origin.1 = mul i32 %[[VAL_26]], 4096 -// CHECK: %tile_origin.2 = mul i32 %[[VAL_24]], 32 -// CHECK: store i32 %thread.id.1, ptr{{( addrspace\(5\))?}} %[[VAL_12]], align 4 -// CHECK: br label %[[VAL_29:.*]] -// CHECK: loop1.loop_header: ; preds = %[[VAL_30:.*]], %[[VAL_17]] -// CHECK: %[[VAL_31:.*]] = load i32, ptr{{( addrspace\(5\))?}} %[[VAL_12]], align 4 -// CHECK: %[[VAL_32:.*]] = icmp uge i32 %[[VAL_31]], %tile_bound.1 -// CHECK: br i1 %[[VAL_32]], label %[[VAL_33:.*]], label %[[VAL_34:.*]] -// CHECK: loop1.loop_body: ; preds = %[[VAL_29]] -// CHECK: %[[VAL_35:.*]] = add nuw nsw i32 %[[VAL_31]], 32 -// CHECK: store i32 %[[VAL_35]], ptr{{( addrspace\(5\))?}} %[[VAL_12]], align 4 -// CHECK: store i32 0, ptr{{( addrspace\(5\))?}} %[[VAL_11]], align 4 -// CHECK: br label %[[VAL_37:.*]] -// CHECK: loop2.loop_header: ; preds = %[[VAL_38:.*]], %[[VAL_34]] -// CHECK: %[[VAL_39:.*]] = load i32, ptr{{( addrspace\(5\))?}} %[[VAL_11]], align 4 -// CHECK: %[[VAL_40:.*]] = icmp uge i32 %[[VAL_39]], 32 -// CHECK: br i1 %[[VAL_40]], label %[[VAL_30]], label %[[VAL_41:.*]] -// CHECK: loop2.loop_body: ; preds = %[[VAL_37]] -// CHECK: %[[VAL_42:.*]] = add nuw nsw i32 %[[VAL_39]], 32 -// CHECK: store i32 %[[VAL_42]], ptr{{( addrspace\(5\))?}} %[[VAL_11]], align 4 -// CHECK: %[[VAL_44:.*]] = add i32 %[[VAL_39]], %thread.id.2 -// CHECK: %[[VAL_45:.*]] = icmp ult i32 %[[VAL_44]], 32 -// CHECK: br i1 %[[VAL_45]], label %[[VAL_46:.*]], label %[[VAL_38]] -// CHECK: x_in_tile-after: ; preds = %[[VAL_46]], %[[VAL_41]] -// CHECK: br label %[[VAL_37]], !llvm.loop -// CHECK: loop2.loop_exit: ; preds = %[[VAL_37]] -// CHECK: br label %[[VAL_29]], !llvm.loop -// CHECK: loop1.loop_exit: ; preds = %[[VAL_29]] -// CHECK: 
%[[VAL_47:.*]] = load float, ptr{{( addrspace\(5\))?}} %[[VAL_13]], align 4 -// CHECK: %[[VAL_48:.*]] = getelementptr inbounds [32 x [33 x float]], ptr addrspace(3) @shared_cache, i32 0, i32 %thread.id.2, i32 %thread.id.1 -// CHECK: %[[VAL_49:.*]] = addrspacecast ptr addrspace(3) %[[VAL_48]] to ptr -// CHECK: store float %[[VAL_47]], ptr %[[VAL_49]], align 4 -// CHECK-PTX: call void @llvm.nvvm.barrier0() -// CHECK-GCN: call void @llvm.amdgcn.s.barrier() -// CHECK: %[[VAL_50:.*]] = getelementptr inbounds [32 x [33 x float]], ptr addrspace(3) @shared_cache, i32 0, i32 %thread.id.1, i32 %thread.id.2 -// CHECK: %[[VAL_51:.*]] = addrspacecast ptr addrspace(3) %[[VAL_50]] to ptr -// CHECK: %[[VAL_52:.*]] = load float, ptr %[[VAL_51]], align 4 -// CHECK-PTX: %[[VAL_53:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_52]], i32 16, i32 31) -// CHECK-GCN: %[[VAL_53_:.*]] = call i32 @llvm.amdgcn.ds.swizzle -// CHECK-GCN: %[[VAL_53:.*]] = bitcast i32 -// CHECK: store float %[[VAL_53]], ptr{{( addrspace\(5\))?}} %[[VAL_9]], align 4 -// CHECK-PTX: call void @[[REDUCTION0:reduction0.*]](ptr %[[VAL_51]], ptr %[[VAL_9]], ptr %[[VAL_8]]) -// CHECK: %[[VAL_54:.*]] = load float, ptr{{( addrspace\(5\))?}} %[[VAL_8]], align 4 -// CHECK: store float %[[VAL_54]], ptr %[[VAL_51]], align 4 -// CHECK: %[[VAL_55:.*]] = load float, ptr{{( addrspace\(5\))?}} %[[VAL_51]], align 4 -// CHECK-PTX: %[[VAL_56:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_55]], i32 8, i32 31) -// CHECK-GCN: %[[VAL_56_1_:.*]] = call i32 @llvm.amdgcn.ds.swizzle -// CHECK-GCN: %[[VAL_56:.*]] = bitcast i32 -// CHECK: store float %[[VAL_56]], ptr{{( addrspace\(5\))?}} %[[VAL_7]], align 4 -// CHECK-PTX: call void @[[REDUCTION0]](ptr %[[VAL_51]], ptr %[[VAL_7]], ptr %[[VAL_6]]) -// CHECK: %[[VAL_57:.*]] = load float, ptr{{( addrspace\(5\))?}} %[[VAL_6]], align 4 -// CHECK: store float %[[VAL_57]], ptr{{( addrspace\(5\))?}} %[[VAL_51]], align 4 -// CHECK: %[[VAL_58:.*]] = load float, 
ptr{{( addrspace\(5\))?}} %[[VAL_51]], align 4 -// CHECK-PTX: %[[VAL_59:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_58]], i32 4, i32 31) -// CHECK-GCN: %[[VAL_59_:.*]] = call i32 @llvm.amdgcn.ds.swizzle -// CHECK-GCN: %[[VAL_59:.*]] = bitcast i32 -// CHECK: store float %[[VAL_59]], ptr{{( addrspace\(5\))?}} %[[VAL_5]], align 4 -// CHECK-PTX: call void @[[REDUCTION0]](ptr %[[VAL_51]], ptr %[[VAL_5]], ptr %[[VAL_4]]) -// CHECK: %[[VAL_60:.*]] = load float, ptr{{( addrspace\(5\))?}} %[[VAL_4]], align 4 -// CHECK: store float %[[VAL_60]], ptr{{( addrspace\(5\))?}} %[[VAL_51]], align 4 -// CHECK: %[[VAL_61:.*]] = load float, ptr{{( addrspace\(5\))?}} %[[VAL_51]], align 4 -// CHECK-PTX: %[[VAL_62:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_61]], i32 2, i32 31) -// CHECK-GCN: %[[VAL_62_:.*]] = call i32 @llvm.amdgcn.ds.swizzle -// CHECK-GCN: %[[VAL_62:.*]] = bitcast i32 -// CHECK: store float %[[VAL_62]], ptr{{( addrspace\(5\))?}} %[[VAL_3]], align 4 -// CHECK-PTX: call void @[[REDUCTION0]](ptr %[[VAL_51]], ptr %[[VAL_3]], ptr %[[VAL_2]]) -// CHECK: %[[VAL_63:.*]] = load float, ptr{{( addrspace\(5\))?}} %[[VAL_2]], align 4 -// CHECK: store float %[[VAL_63]], ptr{{( addrspace\(5\))?}} %[[VAL_51]], align 4 -// CHECK: %[[VAL_64:.*]] = load float, ptr{{( addrspace\(5\))?}} %[[VAL_51]], align 4 -// CHECK-PTX: %[[VAL_65:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_64]], i32 1, i32 31) -// CHECK-GCN: %[[VAL_65_:.*]] = call i32 @llvm.amdgcn.ds.swizzle -// CHECK-GCN: %[[VAL_65:.*]] = bitcast i32 -// CHECK: store float %[[VAL_65]], ptr{{( addrspace\(5\))?}} %[[VAL_1]], align 4 -// CHECK-PTX: call void @[[REDUCTION0]](ptr %[[VAL_51]], ptr %[[VAL_1]], ptr %[[VAL_0]]) -// CHECK: %[[VAL_66:.*]] = load float, ptr{{( addrspace\(5\))?}} %[[VAL_0]], align 4 -// CHECK: store float %[[VAL_66]], ptr{{( addrspace\(5\))?}} %[[VAL_51]], align 4 -// CHECK-PTX: %[[VAL_67:.*]] = icmp ult i32 %thread.id.1, 32 -// CHECK-PTX: 
%[[VAL_68:.*]] = icmp ult i32 %thread.id.2, %tile_bound.1 -// CHECK-GCN: %[[VAL_68:.*]] = icmp ult i32 %thread.id.2, %tile_bound.1 -// CHECK-GCN: %[[VAL_67:.*]] = icmp ult i32 %thread.id.1, 32 -// CHECK: %[[VAL_69:.*]] = and i1 %[[VAL_67]], %[[VAL_68]] -// CHECK: %[[VAL_70:.*]] = icmp eq i32 %lane_id, 0 -// CHECK: %[[VAL_71:.*]] = and i1 %[[VAL_69]], %[[VAL_70]] -// CHECK: br i1 %[[VAL_71]], label %[[VAL_72:.*]], label %[[VAL_19]] -// CHECK: reduction_write_output-after: ; preds = %[[VAL_72]], %[[VAL_33]] -// CHECK: br label %[[VAL_18]] -// CHECK: x_in_tile-true: ; preds = %[[VAL_41]] -// CHECK: %[[VAL_73:.*]] = add i32 %tile_origin.0, 0 -// CHECK: %[[VAL_74:.*]] = add i32 %tile_origin.1, %[[VAL_31]] -// CHECK: %[[VAL_75:.*]] = add i32 %tile_origin.2, %[[VAL_44]] -// CHECK: %[[VAL_76:.*]] = mul nuw nsw i32 %[[VAL_75]], 1 -// CHECK: %[[VAL_77:.*]] = add nuw nsw i32 0, %[[VAL_76]] -// CHECK: %[[VAL_78:.*]] = urem i32 %[[VAL_77]], 12 -// CHECK: %[[VAL_79:.*]] = udiv i32 %[[VAL_77]], 12 -// CHECK: %[[VAL_80:.*]] = urem i32 %[[VAL_79]], 3 -// CHECK: %[[VAL_81:.*]] = udiv i32 %[[VAL_79]], 3 -// CHECK: %[[VAL_82:.*]] = urem i32 %[[VAL_81]], 4 -// CHECK: %[[VAL_83:.*]] = udiv i32 %[[VAL_81]], 4 -// CHECK: %[[VAL_84:.*]] = urem i32 %[[VAL_83]], 32 -// CHECK: %[[VAL_85:.*]] = udiv i32 %[[VAL_83]], 32 -// CHECK: %[[VAL_86:.*]] = udiv i32 %[[VAL_85]], 16 -// CHECK: %[[VAL_87:.*]] = mul nuw nsw i32 %[[VAL_74]], 1 -// CHECK: %[[VAL_88:.*]] = add nuw nsw i32 0, %[[VAL_87]] -// CHECK: %[[VAL_89:.*]] = urem i32 %[[VAL_88]], 32 -// CHECK: %[[VAL_90:.*]] = udiv i32 %[[VAL_88]], 32 -// CHECK: %[[VAL_91:.*]] = urem i32 %[[VAL_90]], 3 -// CHECK: %[[VAL_92:.*]] = udiv i32 %[[VAL_90]], 3 -// CHECK: %[[VAL_93:.*]] = udiv i32 %[[VAL_92]], 12 -// CHECK: %[[VAL_94:.*]] = mul nuw nsw i32 %[[VAL_73]], 1 -// CHECK: %[[VAL_95:.*]] = add nuw nsw i32 0, %[[VAL_94]] -// CHECK: %[[VAL_96:.*]] = getelementptr inbounds [12 x [3 x [32 x [16 x [32 x [4 x [3 x [12 x float]]]]]]]], ptr %[[VAL_97:.*]], i32 
0, i32 %[[VAL_92]], i32 %[[VAL_91]], i32 %[[VAL_89]], i32 %[[VAL_85]], i32 %[[VAL_84]], i32 %[[VAL_82]], i32 %[[VAL_80]], i32 %[[VAL_78]] -// CHECK: %[[VAL_98:.*]] = load float, ptr{{( addrspace\(5\))?}} %[[VAL_96]], align 4, !invariant.load -// CHECK: store float %[[VAL_98]], ptr{{( addrspace\(5\))?}} %[[VAL_14]], align 4 -// CHECK-PTX: call void @[[REDUCTION0]](ptr %[[VAL_13]], ptr %[[VAL_14]], ptr %[[VAL_10]]) -// CHECK: %[[VAL_99:.*]] = load float, ptr{{( addrspace\(5\))?}} %[[VAL_10]], align 4 -// CHECK: store float %[[VAL_99]], ptr{{( addrspace\(5\))?}} %[[VAL_13]], align 4 -// CHECK: br label %[[VAL_38]] -// CHECK: reduction_write_output-true: ; preds = %[[VAL_33]] -// CHECK: %[[VAL_100:.*]] = add i32 %tile_origin.2, %thread.id.1 -// CHECK: %[[VAL_101:.*]] = mul nuw nsw i32 %[[VAL_100]], 1 -// CHECK: %[[VAL_102:.*]] = add nuw nsw i32 0, %[[VAL_101]] -// CHECK: %[[VAL_103:.*]] = urem i32 %[[VAL_102]], 12 -// CHECK: %[[VAL_104:.*]] = udiv i32 %[[VAL_102]], 12 -// CHECK: %[[VAL_105:.*]] = urem i32 %[[VAL_104]], 3 -// CHECK: %[[VAL_106:.*]] = udiv i32 %[[VAL_104]], 3 -// CHECK: %[[VAL_107:.*]] = urem i32 %[[VAL_106]], 4 -// CHECK: %[[VAL_108:.*]] = udiv i32 %[[VAL_106]], 4 -// CHECK: %[[VAL_109:.*]] = urem i32 %[[VAL_108]], 32 -// CHECK: %[[VAL_110:.*]] = udiv i32 %[[VAL_108]], 32 -// CHECK: %[[VAL_111:.*]] = udiv i32 %[[VAL_110]], 16 -// CHECK: %[[VAL_112:.*]] = mul nuw nsw i32 %tile_origin.0, 1 -// CHECK: %[[VAL_113:.*]] = add nuw nsw i32 0, %[[VAL_112]] -// CHECK: %[[VAL_114:.*]] = getelementptr inbounds [12 x [16 x [4 x [3 x [32 x float]]]]], ptr %[[VAL_115:.*]], i32 0, i32 %[[VAL_103]], i32 %[[VAL_110]], i32 %[[VAL_107]], i32 %[[VAL_105]], i32 %[[VAL_109]] -// CHECK: %[[VAL_116:.*]] = load float, ptr{{( addrspace\(5\))?}} %[[VAL_51]], align 4 -// CHECK: store float %[[VAL_116]], ptr{{( addrspace\(5\))?}} %[[VAL_114]], align 4 -// CHECK: br label %[[VAL_19]] -// CHECK: entry: -// CHECK: %[[VAL_117:.*]] = alloca float, align 4 -// CHECK: %[[VAL_118:.*]] = 
load float, ptr{{( addrspace\(5\))?}} %[[VAL_119:.*]], align 4 -// CHECK: %[[VAL_120:.*]] = load float, ptr{{( addrspace\(5\))?}} %[[VAL_121:.*]], align 4 -// CHECK: %[[VAL_122:.*]] = fadd float %[[VAL_118]], %[[VAL_120]] -// CHECK: store float %[[VAL_122]], ptr{{( addrspace\(5\))?}} %[[VAL_117]], align 4 -// CHECK: %[[VAL_123:.*]] = load float, ptr{{( addrspace\(5\))?}} %[[VAL_117]], align 4 -// CHECK: store float %[[VAL_123]], ptr{{( addrspace\(5\))?}} %[[VAL_124:.*]], align 4 -// CHECK: ret void diff --git a/third_party/xla/xla/service/gpu/tests/reduce_f64_column.hlo b/third_party/xla/xla/service/gpu/tests/reduce_f64_column.hlo deleted file mode 100644 index f1f7cc198b6541..00000000000000 --- a/third_party/xla/xla/service/gpu/tests/reduce_f64_column.hlo +++ /dev/null @@ -1,254 +0,0 @@ -// RUN: hlo-opt %s --platform=gpu --xla_gpu_mlir_emitter_level=0 --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb | FileCheck %s --check-prefixes=CHECK,CHECK-%{PTX} - -HloModule m, is_scheduled=true - -add { - a = f64[] parameter(0) - b = f64[] parameter(1) - ROOT out = f64[] add(a, b) -} - -fused_computation { - p1 = f64[1024,1024]{1,0} parameter(0) - p2 = f64[1024,1024]{1,0} parameter(1) - s = pred[1024,1024]{1,0} parameter(2) - p = f64[1024,1024]{1,0} select(s, p1, p2) - z = f64[] constant(0) - ROOT out = f64[1024]{0} reduce(p, z), to_apply=add, dimensions={0} -} - -ENTRY e { - p1 = f64[1024,1024]{1,0} parameter(0) - p2 = f64[1024,1024]{1,0} parameter(1) - s = pred[1024,1024]{1,0} parameter(2) - ROOT f = f64[1024]{0} fusion(p1, p2, s), kind=kInput, calls=fused_computation -} - -// CHECK: @shared_cache = private addrspace(3) global [32 x [33 x double]] - -// CHECK-LABEL: entry: -// CHECK: %[[VAL_0:.*]] = alloca double, align 8 -// CHECK: %[[VAL_1:.*]] = alloca double, align 8 -// CHECK: %[[VAL_2:.*]] = alloca double, align 8 -// CHECK: %[[VAL_3:.*]] = alloca double, align 8 -// CHECK: %[[VAL_4:.*]] = alloca 
double, align 8 -// CHECK: %[[VAL_5:.*]] = alloca double, align 8 -// CHECK: %[[VAL_6:.*]] = alloca double, align 8 -// CHECK: %[[VAL_7:.*]] = alloca double, align 8 -// CHECK: %[[VAL_8:.*]] = alloca double, align 8 -// CHECK: %[[VAL_9:.*]] = alloca double, align 8 -// CHECK: %[[VAL_10:.*]] = alloca double, align 8 -// CHECK: %[[VAL_11:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_12:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_13:.*]] = alloca double, align 8 -// CHECK: %[[VAL_14:.*]] = alloca double, align 8 -// CHECK-PTX: %[[VAL_15:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.y(), !range !2 -// CHECK-GCN: %[[VAL_15:.*]] = call i32 @llvm.amdgcn.workgroup.id.y -// CHECK: %[[VAL_16:.*]] = icmp eq i32 %[[VAL_15]], 0 -// CHECK: br i1 %[[VAL_16]], label %[[VAL_17:.*]], label %[[VAL_18:.*]] -// CHECK: reduce-group-0-after: ; preds = %[[VAL_19:.*]], %[[VAL_20:.*]] -// CHECK: ret void -// CHECK: reduce-group-0-true: ; preds = %[[VAL_20]] -// CHECK: %[[VAL_21:.*]] = load double, ptr @0, align 8 -// CHECK: store double %[[VAL_21]], ptr{{.*}}%[[VAL_13]], align 8 -// CHECK-PTX: %thread.id.x = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3 -// CHECK-GCN: %thread.id.x = call i32 @llvm.amdgcn.workitem.id.x -// CHECK-PTX: %block.id.x = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !4 -// CHECK-GCN: %block.id.x = call i32 @llvm.amdgcn.workgroup.id.x -// CHECK: %[[VAL_22:.*]] = udiv i32 %thread.id.x, 32 -// CHECK: %thread.id.1 = urem i32 %[[VAL_22]], 32 -// CHECK: %thread.id.2 = urem i32 %thread.id.x, 32 -// CHECK: %lane_id = urem i32 %thread.id.x, 32 -// CHECK: %[[VAL_23:.*]] = udiv i32 %block.id.x, 1 -// CHECK: %[[VAL_24:.*]] = urem i32 %[[VAL_23]], 32 -// CHECK: %[[VAL_25:.*]] = udiv i32 %block.id.x, 32 -// CHECK: %[[VAL_26:.*]] = urem i32 %[[VAL_25]], 1 -// CHECK: %[[VAL_27:.*]] = udiv i32 %block.id.x, 32 -// CHECK: %[[VAL_28:.*]] = icmp eq i32 %[[VAL_26]], 0 -// CHECK: %tile_bound.1 = select i1 %[[VAL_28]], i32 1024, i32 4096 -// CHECK: %tile_origin.0 = mul i32 
%[[VAL_27]], 1 -// CHECK: %tile_origin.1 = mul i32 %[[VAL_26]], 4096 -// CHECK: %tile_origin.2 = mul i32 %[[VAL_24]], 32 -// CHECK: store i32 %thread.id.1, ptr{{.*}}%[[VAL_12]], align 4 -// CHECK: br label %[[VAL_29:.*]] -// CHECK: loop1.loop_header: ; preds = %[[VAL_30:.*]], %[[VAL_17]] -// CHECK: %[[VAL_31:.*]] = load i32, ptr{{.*}}%[[VAL_12]], align 4 -// CHECK: %[[VAL_32:.*]] = icmp uge i32 %[[VAL_31]], %tile_bound.1 -// CHECK: br i1 %[[VAL_32]], label %[[VAL_33:.*]], label %[[VAL_34:.*]] -// CHECK: loop1.loop_body: ; preds = %[[VAL_29]] -// CHECK: %[[VAL_35:.*]] = add nuw nsw i32 %[[VAL_31]], 32 -// CHECK: store i32 %[[VAL_35]], ptr{{.*}}%[[VAL_12]], align 4 -// CHECK: store i32 0, ptr{{.*}}%[[VAL_11]], align 4 -// CHECK: br label %[[VAL_37:.*]] -// CHECK: loop2.loop_header: ; preds = %[[VAL_38:.*]], %[[VAL_34]] -// CHECK: %[[VAL_39:.*]] = load i32, ptr{{.*}}%[[VAL_11]], align 4 -// CHECK: %[[VAL_40:.*]] = icmp uge i32 %[[VAL_39]], 32 -// CHECK: br i1 %[[VAL_40]], label %[[VAL_30]], label %[[VAL_41:.*]] -// CHECK: loop2.loop_body: ; preds = %[[VAL_37]] -// CHECK: %[[VAL_42:.*]] = add nuw nsw i32 %[[VAL_39]], 32 -// CHECK: store i32 %[[VAL_42]], ptr{{.*}}%[[VAL_11]], align 4 -// CHECK: %[[VAL_44:.*]] = add i32 %[[VAL_39]], %thread.id.2 -// CHECK: %[[VAL_45:.*]] = icmp ult i32 %[[VAL_44]], 32 -// CHECK: br i1 %[[VAL_45]], label %[[VAL_46:.*]], label %[[VAL_38]] -// CHECK: x_in_tile-after: ; preds = %[[VAL_46]], %[[VAL_41]] -// CHECK: br label %[[VAL_37]], !llvm.loop !{{[0-9]}} -// CHECK: loop2.loop_exit: ; preds = %[[VAL_37]] -// CHECK: br label %[[VAL_29]], !llvm.loop !{{[0-9]}} -// CHECK: loop1.loop_exit: ; preds = %[[VAL_29]] -// CHECK: %[[VAL_47:.*]] = load double, ptr{{.*}}%[[VAL_13]], align 8 -// CHECK: %[[VAL_48:.*]] = getelementptr inbounds [32 x [33 x double]], ptr addrspace(3) @shared_cache, i32 0, i32 %thread.id.2, i32 %thread.id.1 -// CHECK: %[[VAL_49:.*]] = addrspacecast ptr addrspace(3) %[[VAL_48]] to ptr -// CHECK: store double %[[VAL_47]], 
ptr{{.*}}%[[VAL_49]], align 8 -// CHECK-PTX: call void @llvm.nvvm.barrier0() -// CHECK-GCN: call void @llvm.amdgcn.s.barrier() -// CHECK: %[[VAL_50:.*]] = getelementptr inbounds [32 x [33 x double]], ptr addrspace(3) @shared_cache, i32 0, i32 %thread.id.1, i32 %thread.id.2 -// CHECK: %[[VAL_51:.*]] = addrspacecast ptr addrspace(3) %[[VAL_50]] to ptr -// CHECK: %[[VAL_52:.*]] = load double, ptr{{.*}}%[[VAL_51]], align 8 -// CHECK: %[[VAL_53:.*]] = bitcast double %[[VAL_52]] to i64 -// CHECK: %[[VAL_54:.*]] = bitcast i64 %[[VAL_53]] to <2 x i32> -// CHECK: %[[VAL_55:.*]] = extractelement <2 x i32> %[[VAL_54]], i64 0 -// CHECK-PTX: %[[VAL_56:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_55]], i32 16, i32 31) -// CHECK-GCN: %[[VAL_56:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_55]], i32 543) -// CHECK: %[[VAL_57:.*]] = insertelement <2 x i32> %[[VAL_54]], i32 %[[VAL_56]], i64 0 -// CHECK: %[[VAL_58:.*]] = extractelement <2 x i32> %[[VAL_57]], i64 1 -// CHECK-PTX: %[[VAL_59:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_58]], i32 16, i32 31) -// CHECK-GCN: %[[VAL_59:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_58]], i32 543) -// CHECK: %[[VAL_60:.*]] = insertelement <2 x i32> %[[VAL_57]], i32 %[[VAL_59]], i64 1 -// CHECK: %[[VAL_61:.*]] = bitcast <2 x i32> %[[VAL_60]] to i64 -// CHECK: %[[VAL_62:.*]] = bitcast i64 %[[VAL_61]] to double -// CHECK: store double %[[VAL_62]], ptr{{.*}}%[[VAL_9]], align 8 -// CHECK-PTX: call void @[[ADD:add.*]](ptr %[[VAL_51]], ptr %[[VAL_9]], ptr %[[VAL_8]]) -// CHECK-GCN: %[[VAL_9_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_9]] to ptr -// CHECK-GCN: %[[VAL_8_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_8]] to ptr -// CHECK-GCN: call void @[[ADD:add.*]](ptr %[[VAL_51]], ptr %[[VAL_9_1]], ptr %[[VAL_8_1]]) -// CHECK: %[[VAL_63:.*]] = load double, ptr{{.*}}%[[VAL_8]], align 8 -// CHECK: store double %[[VAL_63]], ptr{{.*}}%[[VAL_51]], align 8 -// CHECK: %[[VAL_64:.*]] = load double, ptr{{.*}}%[[VAL_51]], 
align 8 -// CHECK: %[[VAL_65:.*]] = bitcast double %[[VAL_64]] to i64 -// CHECK: %[[VAL_66:.*]] = bitcast i64 %[[VAL_65]] to <2 x i32> -// CHECK: %[[VAL_67:.*]] = extractelement <2 x i32> %[[VAL_66]], i64 0 -// CHECK-PTX: %[[VAL_68:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_67]], i32 8, i32 31) -// CHECK-GCN: %[[VAL_68:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_67]], i32 287) -// CHECK: %[[VAL_69:.*]] = insertelement <2 x i32> %[[VAL_66]], i32 %[[VAL_68]], i64 0 -// CHECK: %[[VAL_70:.*]] = extractelement <2 x i32> %[[VAL_69]], i64 1 -// CHECK-PTX: %[[VAL_71:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_70]], i32 8, i32 31) -// CHECK-GCN: %[[VAL_71:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_70]], i32 287) -// CHECK: %[[VAL_72:.*]] = insertelement <2 x i32> %[[VAL_69]], i32 %[[VAL_71]], i64 1 -// CHECK: %[[VAL_73:.*]] = bitcast <2 x i32> %[[VAL_72]] to i64 -// CHECK: %[[VAL_74:.*]] = bitcast i64 %[[VAL_73]] to double -// CHECK: store double %[[VAL_74]], ptr{{.*}}%[[VAL_7]], align 8 -// CHECK-PTX: call void @[[ADD]](ptr %[[VAL_51]], ptr %[[VAL_7]], ptr %[[VAL_6]]) -// CHECK-GCN: %[[VAL_7_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_7]] to ptr -// CHECK-GCN: %[[VAL_6_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_6]] to ptr -// CHECK-GCN: call void @[[ADD]](ptr %[[VAL_51]], ptr %[[VAL_7_1]], ptr %[[VAL_6_1]]) -// CHECK: %[[VAL_75:.*]] = load double, ptr{{.*}}%[[VAL_6]], align 8 -// CHECK: store double %[[VAL_75]], ptr{{.*}}%[[VAL_51]], align 8 -// CHECK: %[[VAL_76:.*]] = load double, ptr{{.*}}%[[VAL_51]], align 8 -// CHECK: %[[VAL_77:.*]] = bitcast double %[[VAL_76]] to i64 -// CHECK: %[[VAL_78:.*]] = bitcast i64 %[[VAL_77]] to <2 x i32> -// CHECK: %[[VAL_79:.*]] = extractelement <2 x i32> %[[VAL_78]], i64 0 -// CHECK-PTX: %[[VAL_80:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_79]], i32 4, i32 31) -// CHECK-GCN: %[[VAL_80:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_79]], i32 159) -// CHECK: 
%[[VAL_81:.*]] = insertelement <2 x i32> %[[VAL_78]], i32 %[[VAL_80]], i64 0 -// CHECK: %[[VAL_82:.*]] = extractelement <2 x i32> %[[VAL_81]], i64 1 -// CHECK-PTX: %[[VAL_83:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_82]], i32 4, i32 31) -// CHECK-GCN: %[[VAL_83:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_82]], i32 159) -// CHECK: %[[VAL_84:.*]] = insertelement <2 x i32> %[[VAL_81]], i32 %[[VAL_83]], i64 1 -// CHECK: %[[VAL_85:.*]] = bitcast <2 x i32> %[[VAL_84]] to i64 -// CHECK: %[[VAL_86:.*]] = bitcast i64 %[[VAL_85]] to double -// CHECK: store double %[[VAL_86]], ptr{{.*}}%[[VAL_5]], align 8 -// CHECK-PTX: call void @[[ADD]](ptr %[[VAL_51]], ptr %[[VAL_5]], ptr %[[VAL_4]]) -// CHECK-GCN: %[[VAL_5_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_5]] to ptr -// CHECK-GCN: %[[VAL_4_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_4]] to ptr -// CHECK-GCN: call void @[[ADD]](ptr %[[VAL_51]], ptr %[[VAL_5_1]], ptr %[[VAL_4_1]]) -// CHECK: %[[VAL_87:.*]] = load double, ptr{{.*}}%[[VAL_4]], align 8 -// CHECK: store double %[[VAL_87]], ptr{{.*}}%[[VAL_51]], align 8 -// CHECK: %[[VAL_88:.*]] = load double, ptr{{.*}}%[[VAL_51]], align 8 -// CHECK: %[[VAL_89:.*]] = bitcast double %[[VAL_88]] to i64 -// CHECK: %[[VAL_90:.*]] = bitcast i64 %[[VAL_89]] to <2 x i32> -// CHECK: %[[VAL_91:.*]] = extractelement <2 x i32> %[[VAL_90]], i64 0 -// CHECK-PTX: %[[VAL_92:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_91]], i32 2, i32 31) -// CHECK-GCN: %[[VAL_92:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_91]], i32 95) -// CHECK: %[[VAL_93:.*]] = insertelement <2 x i32> %[[VAL_90]], i32 %[[VAL_92]], i64 0 -// CHECK: %[[VAL_94:.*]] = extractelement <2 x i32> %[[VAL_93]], i64 1 -// CHECK-PTX: %[[VAL_95:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_94]], i32 2, i32 31) -// CHECK-GCN: %[[VAL_95:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_94]], i32 95) -// CHECK: %[[VAL_96:.*]] = insertelement <2 x i32> %[[VAL_93]], i32 
%[[VAL_95]], i64 1 -// CHECK: %[[VAL_97:.*]] = bitcast <2 x i32> %[[VAL_96]] to i64 -// CHECK: %[[VAL_98:.*]] = bitcast i64 %[[VAL_97]] to double -// CHECK: store double %[[VAL_98]], ptr{{.*}}%[[VAL_3]], align 8 -// CHECK-PTX: call void @[[ADD]](ptr %[[VAL_51]], ptr %[[VAL_3]], ptr %[[VAL_2]]) -// CHECK-GCN: %[[VAL_3_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_3]] to ptr -// CHECK-GCN: %[[VAL_2_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_2]] to ptr -// CHECK-GCN: call void @[[ADD]](ptr %[[VAL_51]], ptr %[[VAL_3_1]], ptr %[[VAL_2_1]]) -// CHECK: %[[VAL_99:.*]] = load double, ptr{{.*}}%[[VAL_2]], align 8 -// CHECK: store double %[[VAL_99]], ptr{{.*}}%[[VAL_51]], align 8 -// CHECK: %[[VAL_100:.*]] = load double, ptr{{.*}}%[[VAL_51]], align 8 -// CHECK: %[[VAL_101:.*]] = bitcast double %[[VAL_100]] to i64 -// CHECK: %[[VAL_102:.*]] = bitcast i64 %[[VAL_101]] to <2 x i32> -// CHECK: %[[VAL_103:.*]] = extractelement <2 x i32> %[[VAL_102]], i64 0 -// CHECK-PTX: %[[VAL_104:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_103]], i32 1, i32 31) -// CHECK-GCN: %[[VAL_104:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_103]], i32 63) -// CHECK: %[[VAL_105:.*]] = insertelement <2 x i32> %[[VAL_102]], i32 %[[VAL_104]], i64 0 -// CHECK: %[[VAL_106:.*]] = extractelement <2 x i32> %[[VAL_105]], i64 1 -// CHECK-PTX: %[[VAL_107:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_106]], i32 1, i32 31) -// CHECK-GCN: %[[VAL_107:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_106]], i32 63) -// CHECK: %[[VAL_108:.*]] = insertelement <2 x i32> %[[VAL_105]], i32 %[[VAL_107]], i64 1 -// CHECK: %[[VAL_109:.*]] = bitcast <2 x i32> %[[VAL_108]] to i64 -// CHECK: %[[VAL_110:.*]] = bitcast i64 %[[VAL_109]] to double -// CHECK: store double %[[VAL_110]], ptr{{.*}}%[[VAL_1]], align 8 -// CHECK-PTX: call void @[[ADD]](ptr %[[VAL_51]], ptr %[[VAL_1]], ptr %[[VAL_0]]) -// CHECK-GCN: %[[VAL_1_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_1]] to ptr -// CHECK-GCN: %[[VAL_0_1:.*]] = 
addrspacecast ptr{{.*}}%[[VAL_0]] to ptr -// CHECK-GCN: call void @[[ADD]](ptr %[[VAL_51]], ptr %[[VAL_1_1]], ptr %[[VAL_0_1]]) -// CHECK: %[[VAL_111:.*]] = load double, ptr{{.*}}%[[VAL_0]], align 8 -// CHECK: store double %[[VAL_111]], ptr{{.*}}%[[VAL_51]], align 8 -// CHECK-PTX: %[[VAL_112:.*]] = icmp ult i32 %thread.id.1, 32 -// CHECK-PTX: %[[VAL_113:.*]] = icmp ult i32 %thread.id.2, %tile_bound.1 -// CHECK-GCN: %[[VAL_113:.*]] = icmp ult i32 %thread.id.2, %tile_bound.1 -// CHECK-GCN: %[[VAL_112:.*]] = icmp ult i32 %thread.id.1, 32 -// CHECK: %[[VAL_114:.*]] = and i1 %[[VAL_112]], %[[VAL_113]] -// CHECK: %[[VAL_115:.*]] = icmp eq i32 %lane_id, 0 -// CHECK: %[[VAL_116:.*]] = and i1 %[[VAL_114]], %[[VAL_115]] -// CHECK: br i1 %[[VAL_116]], label %[[VAL_117:.*]], label %[[VAL_19]] -// CHECK: reduction_write_output-after: ; preds = %[[VAL_117]], %[[VAL_33]] -// CHECK: br label %[[VAL_18]] -// CHECK: x_in_tile-true: ; preds = %[[VAL_41]] -// CHECK: %[[VAL_118:.*]] = add i32 %tile_origin.0, 0 -// CHECK: %[[VAL_119:.*]] = add i32 %tile_origin.1, %[[VAL_31]] -// CHECK: %[[VAL_120:.*]] = add i32 %tile_origin.2, %[[VAL_44]] -// CHECK: %[[VAL_121:.*]] = getelementptr inbounds [1024 x [1024 x i8]], ptr{{.*}}%[[VAL_122:.*]], i32 0, i32 %[[VAL_119]], i32 %[[VAL_120]] -// CHECK: %[[VAL_123:.*]] = load i8, ptr{{.*}}%[[VAL_121]], align 1, !invariant.load !{{[0-9]}} -// CHECK: %[[VAL_124:.*]] = getelementptr inbounds [1024 x [1024 x double]], ptr{{.*}}%[[VAL_125:.*]], i32 0, i32 %[[VAL_119]], i32 %[[VAL_120]] -// CHECK: %[[VAL_126:.*]] = load double, ptr{{.*}}%[[VAL_124]], align 8, !invariant.load !{{[0-9]}} -// CHECK: %[[VAL_127:.*]] = getelementptr inbounds [1024 x [1024 x double]], ptr{{.*}}%[[VAL_128:.*]], i32 0, i32 %[[VAL_119]], i32 %[[VAL_120]] -// CHECK: %[[VAL_129:.*]] = load double, ptr{{.*}}%[[VAL_127]], align 8, !invariant.load !{{[0-9]}} -// CHECK: %[[VAL_130:.*]] = trunc i8 %[[VAL_123]] to i1 -// CHECK: %[[VAL_131:.*]] = select i1 %[[VAL_130]], double %[[VAL_126]], 
double %[[VAL_129]] -// CHECK: store double %[[VAL_131]], ptr{{.*}}%[[VAL_14]], align 8 -// CHECK-PTX: call void @[[ADD]](ptr %[[VAL_13]], ptr %[[VAL_14]], ptr %[[VAL_10]]) -// CHECK-GCN: %[[VAL_13_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_13]] to ptr -// CHECK-GCN: %[[VAL_14_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_14]] to ptr -// CHECK-GCN: %[[VAL_10_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_10]] to ptr -// CHECK-GCN: call void @[[ADD]](ptr %[[VAL_13_1]], ptr %[[VAL_14_1]], ptr %[[VAL_10_1]]) -// CHECK: %[[VAL_132:.*]] = load double, ptr{{.*}}%[[VAL_10]], align 8 -// CHECK: store double %[[VAL_132]], ptr{{.*}}%[[VAL_13]], align 8 -// CHECK: br label %[[VAL_38]] -// CHECK: reduction_write_output-true: ; preds = %[[VAL_33]] -// CHECK: %[[VAL_135:.*]] = add i32 %tile_origin.2, %thread.id.1 -// CHECK: %[[VAL_139:.*]] = getelementptr inbounds [1024 x double], ptr{{.*}}%[[VAL_140:.*]], i32 0, i32 %[[VAL_135]] -// CHECK: %[[VAL_141:.*]] = load double, ptr{{.*}}%[[VAL_51]], align 8 -// CHECK: store double %[[VAL_141]], ptr{{.*}}%[[VAL_139]], align 8 -// CHECK: br label %[[VAL_19]] -// CHECK: entry: -// CHECK: %[[VAL_142:.*]] = alloca double, align 8 -// CHECK: %[[VAL_143:.*]] = load double, ptr{{.*}}%[[VAL_144:.*]], align 8 -// CHECK: %[[VAL_145:.*]] = load double, ptr{{.*}}%[[VAL_146:.*]], align 8 -// CHECK: %[[VAL_147:.*]] = fadd double %[[VAL_143]], %[[VAL_145]] -// CHECK: store double %[[VAL_147]], ptr{{.*}}%[[VAL_142]], align 8 -// CHECK: %[[VAL_148:.*]] = load double, ptr{{.*}}%[[VAL_142]], align 8 -// CHECK: store double %[[VAL_148]], ptr{{.*}}%[[VAL_149:.*]], align 8 -// CHECK: ret void - -// CHECK-PTX: !3 = !{i32 0, i32 1024} -// CHECK-PTX: !4 = !{i32 0, i32 32} diff --git a/third_party/xla/xla/service/gpu/tests/reduce_large_row_to_scalar.hlo b/third_party/xla/xla/service/gpu/tests/reduce_large_row_to_scalar.hlo deleted file mode 100644 index ac932291805d1d..00000000000000 --- a/third_party/xla/xla/service/gpu/tests/reduce_large_row_to_scalar.hlo +++ /dev/null @@ -1,554 
+0,0 @@ -// RUN: hlo-opt %s --platform=gpu --xla_gpu_mlir_emitter_level=0 --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb --split-input-file | FileCheck %s --check-prefixes=CHECK,CHECK-%{PTX} - -HloModule LargeReduction, is_scheduled=true - -Sum { - x.1 = c128[] parameter(0) - y.1 = c128[] parameter(1) - ROOT add.1 = c128[] add(x.1, y.1) -} - -fused_computation { - param_0 = c128[10000]{0} parameter(0) - param_1 = c128[] parameter(1) - ROOT out1.1 = c128[] reduce(c128[10000]{0} param_0, c128[] param_1), dimensions={0}, to_apply=Sum -} - -ENTRY reduce.1 { - parameter = c128[10000] parameter(0) - init_value = c128[] constant((0, 0)) - ROOT wrapped_out1 = c128[] fusion(c128[10000]{0} parameter, c128[] init_value), kind=kInput, calls=fused_computation -} - -// CHECK-LABEL: entry: -// CHECK: %[[VAL_0:.*]] = alloca %[[VAL_1:.*]], align 8 -// CHECK: %[[VAL_2:.*]] = alloca %[[VAL_1]], align 8 -// CHECK: %[[VAL_3:.*]] = alloca %[[VAL_1]], align 8 -// CHECK: %[[VAL_4:.*]] = alloca %[[VAL_1]], align 8 -// CHECK: %[[VAL_5:.*]] = alloca %[[VAL_1]], align 8 -// CHECK: %[[VAL_6:.*]] = alloca %[[VAL_1]], align 8 -// CHECK: %[[VAL_7:.*]] = alloca %[[VAL_1]], align 8 -// CHECK: %[[VAL_8:.*]] = alloca %[[VAL_1]], align 8 -// CHECK: %[[VAL_9:.*]] = alloca %[[VAL_1]], align 8 -// CHECK: %[[VAL_10:.*]] = alloca %[[VAL_1]], align 8 -// CHECK: %[[VAL_11:.*]] = alloca %[[VAL_1]], align 8 -// CHECK: %[[VAL_12:.*]] = alloca %[[VAL_1]], align 8 -// CHECK: %[[VAL_13:.*]] = alloca %[[VAL_1]], align 8 -// CHECK: %[[VAL_14:.*]] = alloca %[[VAL_1]], align 8 -// CHECK: %[[VAL_15:.*]] = alloca %[[VAL_1]], align 8 -// CHECK: %[[VAL_16:.*]] = alloca %[[VAL_1]], align 8 -// CHECK: %[[VAL_17:.*]] = alloca %[[VAL_1]], align 8 -// CHECK: %[[VAL_18:.*]] = alloca %[[VAL_1]], align 8 -// CHECK: %[[VAL_19:.*]] = alloca %[[VAL_1]], align 8 -// CHECK: %[[VAL_20:.*]] = alloca %[[VAL_1]], align 8 -// CHECK: %[[VAL_21:.*]] = alloca %[[VAL_1]], 
align 8 -// CHECK: %[[VAL_22:.*]] = alloca %[[VAL_1]], align 8 -// CHECK-PTX: %[[VAL_23:.*]] = alloca i32, align 4 -// CHECK-PTX: %[[VAL_24:.*]] = alloca i32, align 4 -// CHECK-DAG: %[[VAL_25:.*]] = alloca %[[VAL_1]], align 8 -// CHECK-DAG: %[[VAL_26:.*]] = alloca i32, align 4 -// CHECK-DAG: %[[VAL_27:.*]] = alloca i32, align 4 -// CHECK-DAG: %[[VAL_28:.*]] = alloca %[[VAL_1]], align 8 -// CHECK-DAG: %[[VAL_29:.*]] = alloca %[[VAL_1]], align 8 -// CHECK-PTX: %[[VAL_30:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.y(), !range !2 -// CHECK-GCN: %[[VAL_30:.*]] = call i32 @llvm.amdgcn.workgroup.id.y -// CHECK: %[[VAL_31:.*]] = icmp eq i32 %[[VAL_30]], 0 -// CHECK: br i1 %[[VAL_31]], label %[[VAL_32:.*]], label %[[VAL_33:.*]] -// CHECK: reduce-group-0-after: ; preds = %thread_in_bounds-after, %[[VAL_34:.*]] -// CHECK: ret void -// CHECK: reduce-group-0-true: ; preds = %[[VAL_34]] -// CHECK: %[[VAL_35:.*]] = load %[[VAL_1]], ptr %[[VAL_36:.*]], align 1, !invariant.load !{{[0-9]}} -// CHECK: store %[[VAL_1]] %[[VAL_35]], ptr{{.*}} %[[VAL_28]], align 1 -// CHECK-PTX: %thread.id.x = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !4 -// CHECK-GCN: %thread.id.x = call i32 @llvm.amdgcn.workitem.id.x -// CHECK-PTX: %block.id.x = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2 -// CHECK-GCN: %block.id.x = call i32 @llvm.amdgcn.workgroup.id.x -// CHECK: %thread.id.2 = urem i32 %thread.id.x, 640 -// CHECK: %lane_id = urem i32 %thread.id.x, 32 -// CHECK: %[[VAL_37:.*]] = udiv i32 %block.id.x, 1 -// CHECK: %[[VAL_38:.*]] = urem i32 %[[VAL_37]], 1 -// CHECK: %[[VAL_39:.*]] = udiv i32 %block.id.x, 1 -// CHECK: %[[VAL_40:.*]] = urem i32 %[[VAL_39]], 1 -// CHECK-PTX: %[[VAL_41:.*]] = udiv i32 %block.id.x, 1 -// CHECK-PTX: %[[VAL_42:.*]] = urem i32 %[[VAL_41]], 1 -// CHECK: %[[VAL_43:.*]] = udiv i32 %block.id.x, 1 -// CHECK-PTX: %[[VAL_44:.*]] = icmp eq i32 %[[VAL_40]], 0 -// CHECK-GCN: %[[VAL_44:.*]] = icmp eq i32 %[[VAL_38]], 0 -// CHECK-PTX: %tile_bound.2 = select i1 
%[[VAL_44]], i32 5000, i32 5120 -// CHECK-GCN: %tile_bound.2 = select i1 %[[VAL_44]], i32 10000, i32 10240 -// CHECK: %tile_origin.0 = mul i32 %[[VAL_43]], 1 -// CHECK-PTX: %tile_origin.1 = mul i32 %[[VAL_42]], 1 -// CHECK-GCN: %tile_origin.1 = mul i32 %[[VAL_40]], 1 -// CHECK-PTX: %tile_origin.2 = mul i32 %[[VAL_40]], 5120 -// CHECK-GCN: %tile_origin.2 = mul i32 %[[VAL_38]], 10240 -// CHECK-PTX: %tile_origin.3 = mul i32 %[[VAL_38]], 2 -// CHECK-PTX: %[[VAL_45:.*]] = icmp eq i32 5120, %tile_bound.2 -// CHECK-GCN: %[[VAL_45:.*]] = icmp eq i32 10240, %tile_bound.2 -// CHECK: br i1 %[[VAL_45]], label %[[VAL_46:.*]], label %[[VAL_47:.*]] -// CHECK: is_full_tile-after: ; preds = %[[VAL_48:.*]], %[[VAL_49:.*]] -// CHECK: %[[VAL_50:.*]] = load i128, ptr{{.*}} %[[VAL_28]], align {{(16|8)}} -// CHECK: %[[VAL_51:.*]] = bitcast i128 %[[VAL_50]] to <4 x i32> -// CHECK: %[[VAL_52:.*]] = extractelement <4 x i32> %[[VAL_51]], i64 0 -// CHECK-PTX: %[[VAL_53:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_52]], i32 16, i32 31) -// CHECK-GCN: %[[VAL_53:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_52]], i32 543) -// CHECK: %[[VAL_54:.*]] = insertelement <4 x i32> %[[VAL_51]], i32 %[[VAL_53]], i64 0 -// CHECK: %[[VAL_55:.*]] = extractelement <4 x i32> %[[VAL_54]], i64 1 -// CHECK-PTX: %[[VAL_56:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_55]], i32 16, i32 31) -// CHECK-GCN: %[[VAL_56:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_55]], i32 543) -// CHECK: %[[VAL_57:.*]] = insertelement <4 x i32> %[[VAL_54]], i32 %[[VAL_56]], i64 1 -// CHECK: %[[VAL_58:.*]] = extractelement <4 x i32> %[[VAL_57]], i64 2 -// CHECK-PTX: %[[VAL_59:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_58]], i32 16, i32 31) -// CHECK-GCN: %[[VAL_59:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_58]], i32 543) -// CHECK: %[[VAL_60:.*]] = insertelement <4 x i32> %[[VAL_57]], i32 %[[VAL_59]], i64 2 -// CHECK: %[[VAL_61:.*]] = extractelement <4 x 
i32> %[[VAL_60]], i64 3 -// CHECK-PTX: %[[VAL_62:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_61]], i32 16, i32 31) -// CHECK-GCN: %[[VAL_62:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_61]], i32 543) -// CHECK: %[[VAL_63:.*]] = insertelement <4 x i32> %[[VAL_60]], i32 %[[VAL_62]], i64 3 -// CHECK: %[[VAL_64:.*]] = bitcast <4 x i32> %[[VAL_63]] to i128 -// CHECK: store i128 %[[VAL_64]], ptr{{.*}} %[[VAL_21]], align {{(16|8)}} -// CHECK-GCN: %[[VAL_65_1:.*]] = addrspacecast ptr{{.*}} %[[VAL_28]] to ptr -// CHECK-GCN: %[[VAL_65_2:.*]] = addrspacecast ptr{{.*}} %[[VAL_21]] to ptr -// CHECK-GCN: %[[VAL_65_3:.*]] = addrspacecast ptr{{.*}} %[[VAL_20]] to ptr -// CHECK-GCN: call void @[[SUM:Sum.*]](ptr %[[VAL_65_1]], ptr %[[VAL_65_2]], ptr %[[VAL_65_3]]) -// CHECK-PTX: call void @[[SUM:Sum.*]](ptr %[[VAL_28]], ptr %[[VAL_21]], ptr %[[VAL_20]]) -// CHECK: %[[VAL_65:.*]] = load %[[VAL_1]], ptr{{.*}} %[[VAL_20]], align 1 -// CHECK: store %[[VAL_1]] %[[VAL_65]], ptr{{.*}} %[[VAL_28]], align 1 -// CHECK: %[[VAL_66:.*]] = load i128, ptr{{.*}} %[[VAL_28]], align {{(16|8)}} -// CHECK: %[[VAL_67:.*]] = bitcast i128 %[[VAL_66]] to <4 x i32> -// CHECK: %[[VAL_68:.*]] = extractelement <4 x i32> %[[VAL_67]], i64 0 -// CHECK-PTX: %[[VAL_69:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_68]], i32 8, i32 31) -// CHECK-GCN: %[[VAL_69:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_68]], i32 287) -// CHECK: %[[VAL_70:.*]] = insertelement <4 x i32> %[[VAL_67]], i32 %[[VAL_69]], i64 0 -// CHECK: %[[VAL_71:.*]] = extractelement <4 x i32> %[[VAL_70]], i64 1 -// CHECK-PTX: %[[VAL_72:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_71]], i32 8, i32 31) -// CHECK-GCN: %[[VAL_72:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_71]], i32 287) -// CHECK: %[[VAL_73:.*]] = insertelement <4 x i32> %[[VAL_70]], i32 %[[VAL_72]], i64 1 -// CHECK: %[[VAL_74:.*]] = extractelement <4 x i32> %[[VAL_73]], i64 2 -// CHECK-PTX: %[[VAL_75:.*]] = 
call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_74]], i32 8, i32 31) -// CHECK-GCN: %[[VAL_75:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_74]], i32 287) -// CHECK: %[[VAL_76:.*]] = insertelement <4 x i32> %[[VAL_73]], i32 %[[VAL_75]], i64 2 -// CHECK: %[[VAL_77:.*]] = extractelement <4 x i32> %[[VAL_76]], i64 3 -// CHECK-PTX: %[[VAL_78:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_77]], i32 8, i32 31) -// CHECK-GCN: %[[VAL_78:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_77]], i32 287) -// CHECK: %[[VAL_79:.*]] = insertelement <4 x i32> %[[VAL_76]], i32 %[[VAL_78]], i64 3 -// CHECK: %[[VAL_80:.*]] = bitcast <4 x i32> %[[VAL_79]] to i128 -// CHECK: store i128 %[[VAL_80]], ptr{{.*}} %[[VAL_19]], align {{(16|8)}} -// CHECK-GCN: %[[VAL_81_1:.*]] = addrspacecast ptr{{.*}} %[[VAL_28]] to ptr -// CHECK-GCN: %[[VAL_81_2:.*]] = addrspacecast ptr{{.*}} %[[VAL_19]] to ptr -// CHECK-GCN: %[[VAL_81_3:.*]] = addrspacecast ptr{{.*}} %[[VAL_18]] to ptr -// CHECK-GCN: call void @[[SUM]](ptr %[[VAL_81_1]], ptr %[[VAL_81_2]], ptr %[[VAL_81_3]]) -// CHECK-PTX: call void @[[SUM]](ptr %[[VAL_28]], ptr %[[VAL_19]], ptr %[[VAL_18]]) -// CHECK: %[[VAL_81:.*]] = load %[[VAL_1]], ptr{{.*}} %[[VAL_18]], align 1 -// CHECK: store %[[VAL_1]] %[[VAL_81]], ptr{{.*}} %[[VAL_28]], align 1 -// CHECK: %[[VAL_82:.*]] = load i128, ptr{{.*}} %[[VAL_28]], align {{(16|8)}} -// CHECK: %[[VAL_83:.*]] = bitcast i128 %[[VAL_82]] to <4 x i32> -// CHECK: %[[VAL_84:.*]] = extractelement <4 x i32> %[[VAL_83]], i64 0 -// CHECK-PTX: %[[VAL_85:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_84]], i32 4, i32 31) -// CHECK-GCN: %[[VAL_85:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_84]], i32 159) -// CHECK: %[[VAL_86:.*]] = insertelement <4 x i32> %[[VAL_83]], i32 %[[VAL_85]], i64 0 -// CHECK: %[[VAL_87:.*]] = extractelement <4 x i32> %[[VAL_86]], i64 1 -// CHECK-PTX: %[[VAL_88:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_87]], i32 4, 
i32 31) -// CHECK-GCN: %[[VAL_88:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_87]], i32 159) -// CHECK: %[[VAL_89:.*]] = insertelement <4 x i32> %[[VAL_86]], i32 %[[VAL_88]], i64 1 -// CHECK: %[[VAL_90:.*]] = extractelement <4 x i32> %[[VAL_89]], i64 2 -// CHECK-PTX: %[[VAL_91:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_90]], i32 4, i32 31) -// CHECK-GCN: %[[VAL_91:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_90]], i32 159) -// CHECK: %[[VAL_92:.*]] = insertelement <4 x i32> %[[VAL_89]], i32 %[[VAL_91]], i64 2 -// CHECK: %[[VAL_93:.*]] = extractelement <4 x i32> %[[VAL_92]], i64 3 -// CHECK-PTX: %[[VAL_94:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_93]], i32 4, i32 31) -// CHECK-GCN: %[[VAL_94:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_93]], i32 159) -// CHECK: %[[VAL_95:.*]] = insertelement <4 x i32> %[[VAL_92]], i32 %[[VAL_94]], i64 3 -// CHECK: %[[VAL_96:.*]] = bitcast <4 x i32> %[[VAL_95]] to i128 -// CHECK: store i128 %[[VAL_96]], ptr{{.*}} %[[VAL_17]], align {{(16|8)}} -// CHECK-GCN: %[[VAL_98_1:.*]] = addrspacecast ptr{{.*}} %[[VAL_28]] to ptr -// CHECK-GCN: %[[VAL_98_2:.*]] = addrspacecast ptr{{.*}} %[[VAL_17]] to ptr -// CHECK-GCN: %[[VAL_98_3:.*]] = addrspacecast ptr{{.*}} %[[VAL_16]] to ptr -// CHECK-GCN: call void @[[SUM]](ptr %[[VAL_98_1]], ptr %[[VAL_98_2]], ptr %[[VAL_98_3]]) -// CHECK-PTX: call void @[[SUM]](ptr %[[VAL_28]], ptr %[[VAL_17]], ptr %[[VAL_16]]) -// CHECK: %[[VAL_97:.*]] = load %[[VAL_1]], ptr{{.*}} %[[VAL_16]], align 1 -// CHECK: store %[[VAL_1]] %[[VAL_97]], ptr{{.*}} %[[VAL_28]], align 1 -// CHECK: %[[VAL_98:.*]] = load i128, ptr{{.*}} %[[VAL_28]], align {{(16|8)}} -// CHECK: %[[VAL_99:.*]] = bitcast i128 %[[VAL_98]] to <4 x i32> -// CHECK: %[[VAL_100:.*]] = extractelement <4 x i32> %[[VAL_99]], i64 0 -// CHECK-PTX: %[[VAL_101:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_100]], i32 2, i32 31) -// CHECK-GCN: %[[VAL_101:.*]] = call i32 
@llvm.amdgcn.ds.swizzle(i32 %[[VAL_100]], i32 95) -// CHECK: %[[VAL_102:.*]] = insertelement <4 x i32> %[[VAL_99]], i32 %[[VAL_101]], i64 0 -// CHECK: %[[VAL_103:.*]] = extractelement <4 x i32> %[[VAL_102]], i64 1 -// CHECK-PTX: %[[VAL_104:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_103]], i32 2, i32 31) -// CHECK-GCN: %[[VAL_104:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_103]], i32 95) -// CHECK: %[[VAL_105:.*]] = insertelement <4 x i32> %[[VAL_102]], i32 %[[VAL_104]], i64 1 -// CHECK: %[[VAL_106:.*]] = extractelement <4 x i32> %[[VAL_105]], i64 2 -// CHECK-PTX: %[[VAL_107:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_106]], i32 2, i32 31) -// CHECK-GCN: %[[VAL_107:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_106]], i32 95) -// CHECK: %[[VAL_108:.*]] = insertelement <4 x i32> %[[VAL_105]], i32 %[[VAL_107]], i64 2 -// CHECK: %[[VAL_109:.*]] = extractelement <4 x i32> %[[VAL_108]], i64 3 -// CHECK-PTX: %[[VAL_110:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_109]], i32 2, i32 31) -// CHECK-GCN: %[[VAL_110:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_109]], i32 95) -// CHECK: %[[VAL_111:.*]] = insertelement <4 x i32> %[[VAL_108]], i32 %[[VAL_110]], i64 3 -// CHECK: %[[VAL_112:.*]] = bitcast <4 x i32> %[[VAL_111]] to i128 -// CHECK: store i128 %[[VAL_112]], ptr{{.*}} %[[VAL_15]], align {{(16|8)}} -// CHECK-GCN: %[[VAL_113_1:.*]] = addrspacecast ptr{{.*}} %[[VAL_28]] to ptr -// CHECK-GCN: %[[VAL_113_2:.*]] = addrspacecast ptr{{.*}} %[[VAL_15]] to ptr -// CHECK-GCN: %[[VAL_113_3:.*]] = addrspacecast ptr{{.*}} %[[VAL_14]] to ptr -// CHECK-GCN: call void @[[SUM]](ptr %[[VAL_113_1]], ptr %[[VAL_113_2]], ptr %[[VAL_113_3]]) -// CHECK_PTX: call void @[[SUM]](ptr %[[VAL_28]], ptr %[[VAL_15]], ptr %[[VAL_14]]) -// CHECK: %[[VAL_113:.*]] = load %[[VAL_1]], ptr{{.*}} %[[VAL_14]], align 1 -// CHECK: store %[[VAL_1]] %[[VAL_113]], ptr{{.*}} %[[VAL_28]], align 1 -// CHECK: %[[VAL_114:.*]] = load i128, 
ptr{{.*}} %[[VAL_28]], align {{(16|8)}} -// CHECK: %[[VAL_115:.*]] = bitcast i128 %[[VAL_114]] to <4 x i32> -// CHECK: %[[VAL_116:.*]] = extractelement <4 x i32> %[[VAL_115]], i64 0 -// CHECK-PTX: %[[VAL_117:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_116]], i32 1, i32 31) -// CHECK-GCN: %[[VAL_117:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_116]], i32 63) -// CHECK: %[[VAL_118:.*]] = insertelement <4 x i32> %[[VAL_115]], i32 %[[VAL_117]], i64 0 -// CHECK: %[[VAL_119:.*]] = extractelement <4 x i32> %[[VAL_118]], i64 1 -// CHECK-PTX: %[[VAL_120:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_119]], i32 1, i32 31) -// CHECK-GCN: %[[VAL_120:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_119]], i32 63) -// CHECK: %[[VAL_121:.*]] = insertelement <4 x i32> %[[VAL_118]], i32 %[[VAL_120]], i64 1 -// CHECK: %[[VAL_122:.*]] = extractelement <4 x i32> %[[VAL_121]], i64 2 -// CHECK-PTX: %[[VAL_123:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_122]], i32 1, i32 31) -// CHECK-GCN: %[[VAL_123:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_122]], i32 63) -// CHECK: %[[VAL_124:.*]] = insertelement <4 x i32> %[[VAL_121]], i32 %[[VAL_123]], i64 2 -// CHECK: %[[VAL_125:.*]] = extractelement <4 x i32> %[[VAL_124]], i64 3 -// CHECK-PTX: %[[VAL_126:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_125]], i32 1, i32 31) -// CHECK-GCN: %[[VAL_126:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_125]], i32 63) -// CHECK: %[[VAL_127:.*]] = insertelement <4 x i32> %[[VAL_124]], i32 %[[VAL_126]], i64 3 -// CHECK: %[[VAL_128:.*]] = bitcast <4 x i32> %[[VAL_127]] to i128 -// CHECK: store i128 %[[VAL_128]], ptr{{.*}} %[[VAL_13]], align {{(16|8)}} -// CHECK-GCN: %[[VAL_129_1:.*]] = addrspacecast ptr{{.*}} %[[VAL_28]] to ptr -// CHECK-GCN: %[[VAL_129_2:.*]] = addrspacecast ptr{{.*}} %[[VAL_13]] to ptr -// CHECK-GCN: %[[VAL_129_3:.*]] = addrspacecast ptr{{.*}} %[[VAL_12]] to ptr -// CHECK-GCN: call void 
@[[SUM]](ptr %[[VAL_129_1]], ptr %[[VAL_129_2]], ptr %[[VAL_129_3]]) -// CHECK-PTX: call void @[[SUM]](ptr %[[VAL_28]], ptr %[[VAL_13]], ptr %[[VAL_12]]) -// CHECK: %[[VAL_129:.*]] = load %[[VAL_1]], ptr{{.*}} %[[VAL_12]], align 1 -// CHECK: store %[[VAL_1]] %[[VAL_129]], ptr{{.*}} %[[VAL_28]], align 1 -// CHECK: %[[VAL_130:.*]] = udiv i32 %thread.id.2, 32 -// CHECK: br i1 true, label %thread_in_bounds-true, label %thread_in_bounds-after - -// CHECK: thread_in_bounds-after: ; preds = %[[VAL_131:.*]], %[[VAL_132:.*]] -// CHECK: br label %[[VAL_33]] - -// CHECK: is_full_tile-true: ; preds = %[[VAL_32]] -// CHECK: store i32 0, ptr{{.*}} %[[VAL_27]], align 4 -// CHECK: br label %[[VAL_133:.*]] - -// CHECK: loop2.loop_header: ; preds = %[[VAL_134:.*]], %[[VAL_46]] -// CHECK: %[[VAL_135:.*]] = load i32, ptr{{.*}} %[[VAL_27]], align 4 -// CHECK-PTX: %[[VAL_136:.*]] = icmp uge i32 %[[VAL_135]], 5120 -// CHECK-GCN: %[[VAL_136:.*]] = icmp uge i32 %[[VAL_135]], 10240 -// CHECK: br i1 %[[VAL_136]], label %[[VAL_49]], label %[[VAL_137:.*]] - -// CHECK: loop2.loop_body: ; preds = %[[VAL_133]] -// CHECK: %[[VAL_138:.*]] = add nuw nsw i32 %[[VAL_135]], 640 -// CHECK: store i32 %[[VAL_138]], ptr{{.*}} %[[VAL_27]], align 4 -// CHECK: %[[VAL_140:.*]] = add i32 %[[VAL_135]], %thread.id.2 -// CHECK-GCN: %[[VAL_147:.*]] = add i32 %tile_origin.0, 0 -// CHECK-GCN: %[[VAL_148:.*]] = add i32 %tile_origin.1, 0 -// CHECK-GCN: %[[VAL_149:.*]] = add i32 %tile_origin.2, %[[VAL_140]] -// CHECK-GCN: %[[VAL_160:.*]] = getelementptr inbounds [10000 x %[[VAL_1]]], ptr %[[VAL_161:.*]], i32 0, i32 %[[VAL_149]] -// CHECK-GCN: %[[VAL_162:.*]] = load %[[VAL_1]], ptr %[[VAL_160]], align 1, !invariant.load !2 -// CHECK-GCN: store %[[VAL_1]] %[[VAL_162]], ptr{{.*}} %[[VAL_29]], align 1 -// CHECK-GCN: %[[VAL_163_1:.*]] = addrspacecast ptr{{.*}} %[[VAL_28]] to ptr -// CHECK-GCN: %[[VAL_163_2:.*]] = addrspacecast ptr{{.*}} %[[VAL_29]] to ptr -// CHECK-GCN: %[[VAL_163_3:.*]] = addrspacecast ptr{{.*}} %[[VAL_25]] 
to ptr -// CHECK-GCN: call void @[[SUM]](ptr %[[VAL_163_1]], ptr %[[VAL_163_2]], ptr %[[VAL_163_3]]) -// CHECK-GCN: %[[VAL_163:.*]] = load %[[VAL_1]], ptr{{.*}} %[[VAL_25]], align 1 -// CHECK-GCN: store %[[VAL_1]] %[[VAL_163]], ptr{{.*}} %[[VAL_28]], align 1 -// CHECK-PTX: store i32 0, ptr %[[VAL_26]], align 4 -// CHECK: br label %[[VAL_141:.*]] - -// CHECK-PTX: loop3.loop_header: ; preds = %[[VAL_142:.*]], %[[VAL_137]] -// CHECK-PTX: %[[VAL_143:.*]] = load i32, ptr %[[VAL_26]], align 4 -// CHECK-PTX: %[[VAL_144:.*]] = icmp uge i32 %[[VAL_143]], 2 -// CHECK-PTX: br i1 %[[VAL_144]], label %[[VAL_134]], label %[[VAL_142]] - -// CHECK-PTX: loop3.loop_body: ; preds = %[[VAL_141]] -// CHECK-PTX: %[[VAL_145:.*]] = add nuw nsw i32 %[[VAL_143]], 1 -// CHECK-PTX: store i32 %[[VAL_145]], ptr %[[VAL_26]], align 4 -// CHECK-PTX: %[[VAL_147:.*]] = add i32 %tile_origin.0, 0 -// CHECK-PTX: %[[VAL_148:.*]] = add i32 %tile_origin.1, 0 -// CHECK-PTX: %[[VAL_149:.*]] = add i32 %tile_origin.2, %[[VAL_140]] -// CHECK-PTX: %[[VAL_150:.*]] = add i32 %tile_origin.3, %[[VAL_143]] -// CHECK-PTX: %[[VAL_151:.*]] = mul nuw nsw i32 %[[VAL_150]], 1 -// CHECK-PTX: %[[VAL_152:.*]] = add nuw nsw i32 0, %[[VAL_151]] -// CHECK-PTX: %[[VAL_153:.*]] = mul nuw nsw i32 %[[VAL_149]], 2 -// CHECK-PTX: %[[VAL_154:.*]] = add nuw nsw i32 %[[VAL_152]], %[[VAL_153]] -// CHECK-PTX: %[[VAL_155:.*]] = udiv i32 %[[VAL_154]], 10000 -// CHECK-PTX: %[[VAL_156:.*]] = mul nuw nsw i32 %[[VAL_148]], 1 -// CHECK-PTX: %[[VAL_157:.*]] = add nuw nsw i32 0, %[[VAL_156]] -// CHECK-PTX: %[[VAL_158:.*]] = mul nuw nsw i32 %[[VAL_147]], 1 -// CHECK-PTX: %[[VAL_159:.*]] = add nuw nsw i32 0, %[[VAL_158]] -// CHECK-PTX: %[[VAL_160:.*]] = getelementptr inbounds [10000 x %[[VAL_1]]], ptr %[[VAL_161:.*]], i32 0, i32 %[[VAL_154]] -// CHECK-PTX: %[[VAL_162:.*]] = load %[[VAL_1]], ptr %[[VAL_160]], align 1, !invariant.load !3 -// CHECK-PTX: store %[[VAL_1]] %[[VAL_162]], ptr %[[VAL_29]], align 1 -// CHECK-PTX: call void @[[SUM]](ptr 
%[[VAL_28]], ptr %[[VAL_29]], ptr %[[VAL_25]]) -// CHECK-PTX: %[[VAL_163:.*]] = load %[[VAL_1]], ptr %[[VAL_25]], align 1 -// CHECK-PTX: store %[[VAL_1]] %[[VAL_163]], ptr %[[VAL_28]], align 1 -// CHECK-PTX: br label %[[VAL_141]], !llvm.loop !5 - -// CHECK-PTX: loop3.loop_exit: ; preds = %[[VAL_141]] -// CHECK-PTX: br label %[[VAL_133]], !llvm.loop !7 - -// CHECK: loop2.loop_exit: ; preds = %[[VAL_133]] -// CHECK: br label %[[VAL_132]] - -// CHECK: is_full_tile-false: ; preds = %[[VAL_32]] -// CHECK-PTX: store i32 0, ptr %[[VAL_24]], align 4 -// CHECK-GCN: store i32 0, ptr{{.*}} %[[VAL_26]], align 4 -// CHECK: br label %[[VAL_164:.*]] - -// CHECK: loop2.loop_header{{4|3}}: ; preds = %[[VAL_165:.*]], %[[VAL_47]] -// CHECK-PTX: %[[VAL_166:.*]] = load i32, ptr %[[VAL_24]], align 4 -// CHECK-PTX: %[[VAL_167:.*]] = icmp uge i32 %[[VAL_166]], 5120 -// CHECK-GCN: %[[VAL_166:.*]] = load i32, ptr{{.*}} %[[VAL_26]], align 4 -// CHECK-GCN: %[[VAL_167:.*]] = icmp uge i32 %[[VAL_166]], 10240 -// CHECK: br i1 %[[VAL_167]], label %[[VAL_48]], label %[[VAL_168:.*]] - -// CHECK: loop2.loop_body{{5|4}}: ; preds = %[[VAL_164]] -// CHECK: %[[VAL_169:.*]] = add nuw nsw i32 %[[VAL_166]], 640 -// CHECK-PTX: store i32 %[[VAL_169]], ptr %[[VAL_24]], align 4 -// CHECK-GCN: store i32 %[[VAL_169]], ptr{{.*}} %[[VAL_26]], align 4 -// CHECK: %[[VAL_171:.*]] = add i32 %[[VAL_166]], %thread.id.2 -// CHECK: %[[VAL_172:.*]] = icmp ult i32 %[[VAL_171]], %tile_bound.2 -// CHECK: br i1 %[[VAL_172]], label %[[VAL_173:.*]], label %[[VAL_165]] - -// CHECK: x_in_tile-after: ; preds = %[[VAL_174:.*]], %[[VAL_168]] -// CHECK: br label %[[VAL_164]], !llvm.loop !{{9|7}} - -// CHECK: loop2.loop_exit{{3|2}}: ; preds = %[[VAL_164]] -// CHECK: br label %[[VAL_132]] - -// CHECK: x_in_tile-true: ; preds = %[[VAL_168]] -// CHECK-GCN: %[[VAL_181:.*]] = add i32 %tile_origin.0, 0 -// CHECK-GCN: %[[VAL_182:.*]] = add i32 %tile_origin.1, 0 -// CHECK-GCN: %[[VAL_183:.*]] = add i32 %tile_origin.2, %[[VAL_171]] -// 
CHECK-GCN: %[[VAL_194:.*]] = getelementptr inbounds [10000 x %[[VAL_1]]], ptr %[[VAL_161]], i32 0, i32 %[[VAL_183]] -// CHECK-GCN: %[[VAL_195:.*]] = load %[[VAL_1]], ptr %[[VAL_194]], align 1, !invariant.load !2 -// CHECK-GCN: store %[[VAL_1]] %[[VAL_195]], ptr{{.*}} %[[VAL_29]], align 1 -// CHECK-GCN: %[[VAL_196_1:.*]] = addrspacecast ptr{{.*}} %[[VAL_28]] to ptr -// CHECK-GCN: %[[VAL_196_2:.*]] = addrspacecast ptr{{.*}} %[[VAL_29]] to ptr -// CHECK-GCN: %[[VAL_196_3:.*]] = addrspacecast ptr{{.*}} %[[VAL_22]] to ptr -// CHECK-GCN: call void @[[SUM]](ptr %[[VAL_196_1]], ptr %[[VAL_196_2]], ptr %[[VAL_196_3]]) -// CHECK-GCN: %[[VAL_196:.*]] = load %[[VAL_1]], ptr{{.*}} %[[VAL_22]], align 1 -// CHECK-GCN: store %[[VAL_1]] %[[VAL_196]], ptr{{.*}} %[[VAL_28]], align 1 -// CHECK-PTX: store i32 0, ptr %[[VAL_23]], align 4 -// CHECK: br label %[[VAL_175:.*]] - -// CHECK-PTX: loop3.loop_header10: ; preds = %[[VAL_176:.*]], %[[VAL_173]] -// CHECK-PTX: %[[VAL_177:.*]] = load i32, ptr %[[VAL_23]], align 4 -// CHECK-PTX: %[[VAL_178:.*]] = icmp uge i32 %[[VAL_177]], 2 -// CHECK-PTX: br i1 %[[VAL_178]], label %[[VAL_174]], label %[[VAL_176]] -// CHECK-PTX: loop3.loop_body11: ; preds = %[[VAL_175]] -// CHECK-PTX: %[[VAL_179:.*]] = add nuw nsw i32 %[[VAL_177]], 1 -// CHECK-PTX: store i32 %[[VAL_179]], ptr %[[VAL_23]], align 4 -// CHECK-PTX: %[[VAL_181:.*]] = add i32 %tile_origin.0, 0 -// CHECK-PTX: %[[VAL_182:.*]] = add i32 %tile_origin.1, 0 -// CHECK-PTX: %[[VAL_183:.*]] = add i32 %tile_origin.2, %[[VAL_171]] -// CHECK-PTX: %[[VAL_184:.*]] = add i32 %tile_origin.3, %[[VAL_177]] -// CHECK-PTX: %[[VAL_185:.*]] = mul nuw nsw i32 %[[VAL_184]], 1 -// CHECK-PTX: %[[VAL_186:.*]] = add nuw nsw i32 0, %[[VAL_185]] -// CHECK-PTX: %[[VAL_187:.*]] = mul nuw nsw i32 %[[VAL_183]], 2 -// CHECK-PTX: %[[VAL_188:.*]] = add nuw nsw i32 %[[VAL_186]], %[[VAL_187]] -// CHECK-PTX: %[[VAL_189:.*]] = udiv i32 %[[VAL_188]], 10000 -// CHECK-PTX: %[[VAL_190:.*]] = mul nuw nsw i32 %[[VAL_182]], 1 -// 
CHECK-PTX: %[[VAL_191:.*]] = add nuw nsw i32 0, %[[VAL_190]] -// CHECK-PTX: %[[VAL_192:.*]] = mul nuw nsw i32 %[[VAL_181]], 1 -// CHECK-PTX: %[[VAL_193:.*]] = add nuw nsw i32 0, %[[VAL_192]] -// CHECK-PTX: %[[VAL_194:.*]] = getelementptr inbounds [10000 x %[[VAL_1]]], ptr %[[VAL_161]], i32 0, i32 %[[VAL_188]] -// CHECK-PTX: %[[VAL_195:.*]] = load %[[VAL_1]], ptr %[[VAL_194]], align 1, !invariant.load !3 -// CHECK-PTX: store %[[VAL_1]] %[[VAL_195]], ptr %[[VAL_29]], align 1 -// CHECK-PTX: call void @[[SUM]](ptr %[[VAL_28]], ptr %[[VAL_29]], ptr %[[VAL_22]]) -// CHECK-PTX: %[[VAL_196:.*]] = load %[[VAL_1]], ptr %[[VAL_22]], align 1 -// CHECK-PTX: store %[[VAL_1]] %[[VAL_196]], ptr %[[VAL_28]], align 1 -// CHECK-PTX: br label %[[VAL_175]], !llvm.loop !10 -// CHECK-PTX: loop3.loop_exit9: ; preds = %[[VAL_175]] -// CHECK-PTX: br label %[[VAL_165]] - -// CHECK: thread_in_bounds-true: ; preds = %[[VAL_132]] -// CHECK: %[[VAL_197:.*]] = icmp eq i32 %lane_id, 0 -// CHECK: br i1 %[[VAL_197]], label %[[VAL_198:.*]], label %[[VAL_199:.*]] -// CHECK: intra_warp_reduce_write-after: ; preds = %[[VAL_198]], %thread_in_bounds-true -// CHECK-GCN: fence syncscope("workgroup") seq_cst -// CHECK-GCN: call void @llvm.amdgcn.s.barrier() -// CHECK-PTX: call void @llvm.nvvm.barrier0() -// CHECK: %[[VAL_200:.*]] = icmp eq i32 %[[VAL_130]], 0 -// CHECK: br i1 %[[VAL_200]], label %[[VAL_201:.*]], label %[[VAL_131]] -// CHECK: inter_warp_reduce-after: ; preds = %[[VAL_202:.*]], %[[VAL_199]] -// CHECK: br label %thread_in_bounds-after -// CHECK: intra_warp_reduce_write-true: ; preds = %thread_in_bounds-true -// CHECK: %[[VAL_203:.*]] = load %[[VAL_1]], ptr{{.*}} %[[VAL_28]], align 1 -// CHECK: %[[VAL_204:.*]] = getelementptr inbounds [1 x [20 x %[[VAL_1]]]], ptr addrspace(3) @shared_cache, i32 0, i32 0, i32 %[[VAL_130]] -// CHECK: %[[VAL_205:.*]] = addrspacecast ptr addrspace(3) %[[VAL_204]] to ptr -// CHECK: store %[[VAL_1]] %[[VAL_203]], ptr %[[VAL_205]], align 1 -// CHECK: br label 
%[[VAL_199]] -// CHECK: inter_warp_reduce-true: ; preds = %[[VAL_199]] -// CHECK: %[[VAL_206:.*]] = getelementptr inbounds [1 x [20 x %[[VAL_1]]]], ptr addrspace(3) @shared_cache, i32 0, i32 0, i32 %lane_id -// CHECK: %[[VAL_207:.*]] = addrspacecast ptr addrspace(3) %[[VAL_206]] to ptr -// CHECK-GCN: %[[VAL_207_1:.*]] = addrspacecast ptr addrspace(5) %[[VAL_11]] to ptr -// CHECK-GCN: store %[[VAL_1]] %[[VAL_35]], ptr %[[VAL_207_1]], align 1 -// CHECK-PTX: store %[[VAL_1]] %[[VAL_35]], ptr %[[VAL_11]], align 1 -// CHECK: %[[VAL_208:.*]] = icmp ult i32 %thread.id.2, 20 -// CHECK-GCN: %[[VAL_209:.*]] = select i1 %[[VAL_208]], ptr %[[VAL_207]], ptr %[[VAL_207_1]] -// CHECK-PTX: %[[VAL_209:.*]] = select i1 %[[VAL_208]], ptr %[[VAL_207]], ptr %[[VAL_11]] -// CHECK: %[[VAL_210:.*]] = load i128, ptr %[[VAL_209]], align {{(16|8)}} -// CHECK: %[[VAL_211:.*]] = bitcast i128 %[[VAL_210]] to <4 x i32> -// CHECK: %[[VAL_212:.*]] = extractelement <4 x i32> %[[VAL_211]], i64 0 -// CHECK-GCN: %[[VAL_213:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_212]], i32 543) -// CHECK-PTX: %[[VAL_213:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_212]], i32 16, i32 31) -// CHECK: %[[VAL_214:.*]] = insertelement <4 x i32> %[[VAL_211]], i32 %[[VAL_213]], i64 0 -// CHECK: %[[VAL_215:.*]] = extractelement <4 x i32> %[[VAL_214]], i64 1 -// CHECK-GCN: %[[VAL_216:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_215]], i32 543) -// CHECK-PTX: %[[VAL_216:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_215]], i32 16, i32 31) -// CHECK: %[[VAL_217:.*]] = insertelement <4 x i32> %[[VAL_214]], i32 %[[VAL_216]], i64 1 -// CHECK: %[[VAL_218:.*]] = extractelement <4 x i32> %[[VAL_217]], i64 2 -// CHECK-GCN: %[[VAL_219:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_218]], i32 543) -// CHECK-PTX: %[[VAL_219:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_218]], i32 16, i32 31) -// CHECK: %[[VAL_220:.*]] = insertelement <4 x i32> %[[VAL_217]], i32 
%[[VAL_219]], i64 2 -// CHECK: %[[VAL_221:.*]] = extractelement <4 x i32> %[[VAL_220]], i64 3 -// CHECK-GCN: %[[VAL_222:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_221]], i32 543) -// CHECK-PTX: %[[VAL_222:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_221]], i32 16, i32 31) -// CHECK: %[[VAL_223:.*]] = insertelement <4 x i32> %[[VAL_220]], i32 %[[VAL_222]], i64 3 -// CHECK: %[[VAL_224:.*]] = bitcast <4 x i32> %[[VAL_223]] to i128 -// CHECK: store i128 %[[VAL_224]], ptr{{.*}} %[[VAL_10]], align {{(16|8)}} -// CHECK-GCN: %[[VAL_225_1:.*]] = addrspacecast ptr addrspace(5) %[[VAL_10]] to ptr -// CHECK-GCN: %[[VAL_225_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_9]] to ptr -// CHECK-GCN: call void @[[SUM]](ptr %[[VAL_209]], ptr %[[VAL_225_1]], ptr %[[VAL_225_2]]) -// CHECK-PTX: call void @[[SUM]](ptr %[[VAL_209]], ptr %[[VAL_10]], ptr %[[VAL_9]]) -// CHECK: %[[VAL_225:.*]] = load %[[VAL_1]], ptr{{.*}} %[[VAL_9]], align 1 -// CHECK: store %[[VAL_1]] %[[VAL_225]], ptr %[[VAL_209]], align 1 -// CHECK: %[[VAL_226:.*]] = load i128, ptr %[[VAL_209]], align {{(16|8)}} -// CHECK: %[[VAL_227:.*]] = bitcast i128 %[[VAL_226]] to <4 x i32> -// CHECK: %[[VAL_228:.*]] = extractelement <4 x i32> %[[VAL_227]], i64 0 -// CHECK-GCN: %[[VAL_229:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_228]], i32 287) -// CHECK-PTX: %[[VAL_229:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_228]], i32 8, i32 31) -// CHECK: %[[VAL_230:.*]] = insertelement <4 x i32> %[[VAL_227]], i32 %[[VAL_229]], i64 0 -// CHECK: %[[VAL_231:.*]] = extractelement <4 x i32> %[[VAL_230]], i64 1 -// CHECK-GCN: %[[VAL_232:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_231]], i32 287) -// CHECK-PTX: %[[VAL_232:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_231]], i32 8, i32 31) -// CHECK: %[[VAL_233:.*]] = insertelement <4 x i32> %[[VAL_230]], i32 %[[VAL_232]], i64 1 -// CHECK: %[[VAL_234:.*]] = extractelement <4 x i32> %[[VAL_233]], i64 2 -// CHECK-GCN: 
%[[VAL_235:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_234]], i32 287) -// CHECK-PTX: %[[VAL_235:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_234]], i32 8, i32 31) -// CHECK: %[[VAL_236:.*]] = insertelement <4 x i32> %[[VAL_233]], i32 %[[VAL_235]], i64 2 -// CHECK: %[[VAL_237:.*]] = extractelement <4 x i32> %[[VAL_236]], i64 3 -// CHECK-GCN: %[[VAL_238:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_237]], i32 287) -// CHECK-PTX: %[[VAL_238:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_237]], i32 8, i32 31) -// CHECK: %[[VAL_239:.*]] = insertelement <4 x i32> %[[VAL_236]], i32 %[[VAL_238]], i64 3 -// CHECK: %[[VAL_240:.*]] = bitcast <4 x i32> %[[VAL_239]] to i128 -// CHECK: store i128 %[[VAL_240]], ptr{{.*}} %[[VAL_8]], align {{(16|8)}} -// CHECK-GCN: %[[VAL_241_1:.*]] = addrspacecast ptr addrspace(5) %[[VAL_8]] to ptr -// CHECK-GCN: %[[VAL_241_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_7]] to ptr -// CHECK-GCN: call void @[[SUM]](ptr %[[VAL_209]], ptr %[[VAL_241_1]], ptr %[[VAL_241_2]]) -// CHECK-PTX: call void @[[SUM]](ptr %[[VAL_209]], ptr %[[VAL_8]], ptr %[[VAL_7]]) -// CHECK: %[[VAL_241:.*]] = load %[[VAL_1]], ptr{{.*}} %[[VAL_7]], align 1 -// CHECK: store %[[VAL_1]] %[[VAL_241]], ptr %[[VAL_209]], align 1 -// CHECK: %[[VAL_242:.*]] = load i128, ptr %[[VAL_209]], align {{(16|8)}} -// CHECK: %[[VAL_243:.*]] = bitcast i128 %[[VAL_242]] to <4 x i32> -// CHECK: %[[VAL_244:.*]] = extractelement <4 x i32> %[[VAL_243]], i64 0 -// CHECK-GCN: %[[VAL_245:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_244]], i32 159) -// CHECK-PTX: %[[VAL_245:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_244]], i32 4, i32 31) -// CHECK: %[[VAL_246:.*]] = insertelement <4 x i32> %[[VAL_243]], i32 %[[VAL_245]], i64 0 -// CHECK: %[[VAL_247:.*]] = extractelement <4 x i32> %[[VAL_246]], i64 1 -// CHECK-GCN: %[[VAL_248:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_247]], i32 159) -// CHECK-PTX: %[[VAL_248:.*]] = 
call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_247]], i32 4, i32 31) -// CHECK: %[[VAL_249:.*]] = insertelement <4 x i32> %[[VAL_246]], i32 %[[VAL_248]], i64 1 -// CHECK: %[[VAL_250:.*]] = extractelement <4 x i32> %[[VAL_249]], i64 2 -// CHECK-GCN: %[[VAL_251:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_250]], i32 159) -// CHECK-PTX: %[[VAL_251:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_250]], i32 4, i32 31) -// CHECK: %[[VAL_252:.*]] = insertelement <4 x i32> %[[VAL_249]], i32 %[[VAL_251]], i64 2 -// CHECK: %[[VAL_253:.*]] = extractelement <4 x i32> %[[VAL_252]], i64 3 -// CHECK-GCN: %[[VAL_254:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_253]], i32 159) -// CHECK-PTX: %[[VAL_254:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_253]], i32 4, i32 31) -// CHECK: %[[VAL_255:.*]] = insertelement <4 x i32> %[[VAL_252]], i32 %[[VAL_254]], i64 3 -// CHECK: %[[VAL_256:.*]] = bitcast <4 x i32> %[[VAL_255]] to i128 -// CHECK: store i128 %[[VAL_256]], ptr{{.*}} %[[VAL_6]], align {{(16|8)}} -// CHECK-GCN: %[[VAL_257_1:.*]] = addrspacecast ptr addrspace(5) %[[VAL_6]] to ptr -// CHECK-GCN: %[[VAL_257_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_5]] to ptr -// CHECK-GCN: call void @[[SUM]](ptr %[[VAL_209]], ptr %[[VAL_257_1]], ptr %[[VAL_257_2]]) -// CHECK-PTX: call void @[[SUM]](ptr %[[VAL_209]], ptr %[[VAL_6]], ptr %[[VAL_5]]) -// CHECK: %[[VAL_257:.*]] = load %[[VAL_1]], ptr{{.*}} %[[VAL_5]], align 1 -// CHECK: store %[[VAL_1]] %[[VAL_257]], ptr{{.*}} %[[VAL_209]], align 1 -// CHECK: %[[VAL_258:.*]] = load i128, ptr %[[VAL_209]], align {{(16|8)}} -// CHECK: %[[VAL_259:.*]] = bitcast i128 %[[VAL_258]] to <4 x i32> -// CHECK: %[[VAL_260:.*]] = extractelement <4 x i32> %[[VAL_259]], i64 0 -// CHECK-GCN: %[[VAL_261:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_260]], i32 95) -// CHECK-PTX: %[[VAL_261:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_260]], i32 2, i32 31) -// CHECK: %[[VAL_262:.*]] 
= insertelement <4 x i32> %[[VAL_259]], i32 %[[VAL_261]], i64 0 -// CHECK: %[[VAL_263:.*]] = extractelement <4 x i32> %[[VAL_262]], i64 1 -// CHECK-GCN: %[[VAL_264:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_263]], i32 95) -// CHECK-PTX: %[[VAL_264:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_263]], i32 2, i32 31) -// CHECK: %[[VAL_265:.*]] = insertelement <4 x i32> %[[VAL_262]], i32 %[[VAL_264]], i64 1 -// CHECK: %[[VAL_266:.*]] = extractelement <4 x i32> %[[VAL_265]], i64 2 -// CHECK-GCN: %[[VAL_267:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_266]], i32 95) -// CHECK-PTX: %[[VAL_267:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_266]], i32 2, i32 31) -// CHECK: %[[VAL_268:.*]] = insertelement <4 x i32> %[[VAL_265]], i32 %[[VAL_267]], i64 2 -// CHECK: %[[VAL_269:.*]] = extractelement <4 x i32> %[[VAL_268]], i64 3 -// CHECK-GCN: %[[VAL_270:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_269]], i32 95) -// CHECK-PTX: %[[VAL_270:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_269]], i32 2, i32 31) -// CHECK: %[[VAL_271:.*]] = insertelement <4 x i32> %[[VAL_268]], i32 %[[VAL_270]], i64 3 -// CHECK: %[[VAL_272:.*]] = bitcast <4 x i32> %[[VAL_271]] to i128 -// CHECK: store i128 %[[VAL_272]], ptr{{.*}} %[[VAL_4]], align {{(16|8)}} -// CHECK-GCN: %[[VAL_273_1:.*]] = addrspacecast ptr addrspace(5) %[[VAL_4]] to ptr -// CHECK-GCN: %[[VAL_273_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_3]] to ptr -// CHECK-GCN: call void @[[SUM]](ptr %[[VAL_209]], ptr %[[VAL_273_1]], ptr %[[VAL_273_2]]) -// CHECK-PTX: call void @[[SUM]](ptr %[[VAL_209]], ptr %[[VAL_4]], ptr %[[VAL_3]]) -// CHECK: %[[VAL_273:.*]] = load %[[VAL_1]], ptr{{.*}} %[[VAL_3]], align 1 -// CHECK: store %[[VAL_1]] %[[VAL_273]], ptr %[[VAL_209]], align 1 -// CHECK: %[[VAL_274:.*]] = load i128, ptr %[[VAL_209]], align {{(16|8)}} -// CHECK: %[[VAL_275:.*]] = bitcast i128 %[[VAL_274]] to <4 x i32> -// CHECK: %[[VAL_276:.*]] = extractelement <4 x 
i32> %[[VAL_275]], i64 0 -// CHECK-GCN: %[[VAL_277:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_276]], i32 63) -// CHECK-PTX: %[[VAL_277:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_276]], i32 1, i32 31) -// CHECK: %[[VAL_278:.*]] = insertelement <4 x i32> %[[VAL_275]], i32 %[[VAL_277]], i64 0 -// CHECK: %[[VAL_279:.*]] = extractelement <4 x i32> %[[VAL_278]], i64 1 -// CHECK-GCN: %[[VAL_280:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_279]], i32 63) -// CHECK-PTX: %[[VAL_280:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_279]], i32 1, i32 31) -// CHECK: %[[VAL_281:.*]] = insertelement <4 x i32> %[[VAL_278]], i32 %[[VAL_280]], i64 1 -// CHECK: %[[VAL_282:.*]] = extractelement <4 x i32> %[[VAL_281]], i64 2 -// CHECK-GCN: %[[VAL_283:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_282]], i32 63) -// CHECK-PTX: %[[VAL_283:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_282]], i32 1, i32 31) -// CHECK: %[[VAL_284:.*]] = insertelement <4 x i32> %[[VAL_281]], i32 %[[VAL_283]], i64 2 -// CHECK: %[[VAL_285:.*]] = extractelement <4 x i32> %[[VAL_284]], i64 3 -// CHECK-GCN: %[[VAL_286:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_285]], i32 63) -// CHECK-PTX: %[[VAL_286:.*]] = call i32 @llvm.nvvm.shfl.sync.down.i32(i32 -1, i32 %[[VAL_285]], i32 1, i32 31) -// CHECK: %[[VAL_287:.*]] = insertelement <4 x i32> %[[VAL_284]], i32 %[[VAL_286]], i64 3 -// CHECK: %[[VAL_288:.*]] = bitcast <4 x i32> %[[VAL_287]] to i128 -// CHECK: store i128 %[[VAL_288]], ptr{{.*}} %[[VAL_2]], align {{(16|8)}} -// CHECK-GCN: %[[VAL_289_1:.*]] = addrspacecast ptr addrspace(5) %[[VAL_2]] to ptr -// CHECK-GCN: %[[VAL_289_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_0]] to ptr -// CHECK-GCN: call void @[[SUM]](ptr %[[VAL_209]], ptr %[[VAL_289_1]], ptr %[[VAL_289_2]]) -// CHECK-PTX: call void @[[SUM]](ptr %[[VAL_209]], ptr %[[VAL_2]], ptr %[[VAL_0]]) -// CHECK: %[[VAL_289:.*]] = load %[[VAL_1]], ptr{{.*}} %[[VAL_0]], align 1 -// 
CHECK: store %[[VAL_1]] %[[VAL_289]], ptr %[[VAL_209]], align 1 -// CHECK: %[[VAL_290:.*]] = icmp eq i32 %thread.id.2, 0 -// CHECK: br i1 %[[VAL_290]], label %[[VAL_291:.*]], label %[[VAL_202]] -// CHECK: reduction_write_output-after: ; preds = %[[VAL_291]], %[[VAL_201]] -// CHECK: br label %[[VAL_131]] -// CHECK: reduction_write_output-true: ; preds = %[[VAL_201]] -// CHECK: %[[VAL_293:.*]] = add i32 %tile_origin.1, 0 -// CHECK: %[[VAL_296:.*]] = load %[[VAL_1]], ptr %[[VAL_209]], align 1 -// CHECK: store %[[VAL_1]] %[[VAL_296]], ptr %[[VAL_297:.*]], align 1 -// CHECK: br label %[[VAL_202]] -// CHECK: entry: -// CHECK: %[[VAL_298:.*]] = alloca %[[VAL_299:.*]], align 8 -// CHECK: %[[VAL_300:.*]] = load %[[VAL_299]], ptr %[[VAL_301:.*]], align 1 -// CHECK: %[[VAL_302:.*]] = load %[[VAL_299]], ptr %[[VAL_303:.*]], align 1 -// CHECK-GCN: %[[VAL_304:.*]] = extractvalue %[[VAL_299]] %[[VAL_302]], 1 -// CHECK-GCN: %[[VAL_305:.*]] = extractvalue %[[VAL_299]] %[[VAL_300]], 1 -// CHECK-PTX: %[[VAL_304:.*]] = extractvalue %[[VAL_299]] %[[VAL_300]], 0 -// CHECK-PTX: %[[VAL_305:.*]] = extractvalue %[[VAL_299]] %[[VAL_302]], 0 -// CHECK-GCN: %[[VAL_306:.*]] = fadd double %[[VAL_305]], %[[VAL_304]] -// CHECK-PTX: %[[VAL_306:.*]] = fadd double %[[VAL_304]], %[[VAL_305]] -// CHECK-GCN: %[[VAL_307:.*]] = extractvalue %[[VAL_299]] %[[VAL_302]], 0 -// CHECK-GCN: %[[VAL_308:.*]] = extractvalue %[[VAL_299]] %[[VAL_300]], 0 -// CHECK-PTX: %[[VAL_307:.*]] = extractvalue %[[VAL_299]] %[[VAL_300]], 1 -// CHECK-PTX: %[[VAL_308:.*]] = extractvalue %[[VAL_299]] %[[VAL_302]], 1 -// CHECK-GCN: %[[VAL_309:.*]] = fadd double %[[VAL_308]], %[[VAL_307]] -// CHECK-PTX: %[[VAL_309:.*]] = fadd double %[[VAL_307]], %[[VAL_308]] -// CHECK-GCN: %[[VAL_310:.*]] = insertvalue %[[VAL_299]] zeroinitializer, double %[[VAL_309]], 0 -// CHECK-GCN: %[[VAL_311:.*]] = insertvalue %[[VAL_299]] %[[VAL_310]], double %[[VAL_306]], 1 -// CHECK-PTX: %[[VAL_310:.*]] = insertvalue %[[VAL_299]] zeroinitializer, double 
%[[VAL_306]], 0 -// CHECK-PTX: %[[VAL_311:.*]] = insertvalue %[[VAL_299]] %[[VAL_310]], double %[[VAL_309]], 1 -// CHECK: store %[[VAL_299]] %[[VAL_311]], ptr{{.*}} %[[VAL_298]], align 1 -// CHECK: %[[VAL_312:.*]] = load %[[VAL_299]], ptr{{.*}} %[[VAL_298]], align 1 -// CHECK: store %[[VAL_299]] %[[VAL_312]], ptr %[[VAL_313:.*]], align 1 -// CHECK: ret void diff --git a/third_party/xla/xla/service/gpu/tests/reduce_row_vectorized.hlo b/third_party/xla/xla/service/gpu/tests/reduce_row_vectorized.hlo deleted file mode 100644 index ba3c28957c1138..00000000000000 --- a/third_party/xla/xla/service/gpu/tests/reduce_row_vectorized.hlo +++ /dev/null @@ -1,419 +0,0 @@ -// RUN: hlo-opt %s --platform=gpu --xla_gpu_mlir_emitter_level=0 --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb | FileCheck %s --check-prefixes=CHECK,CHECK-%{PTX} - -HloModule RowReductionVectorized, is_scheduled=true - -Sum { - x.1 = f32[] parameter(0) - y.1 = f32[] parameter(1) - ROOT add.1 = f32[] add(x.1, y.1) -} - -fusion_vectorized { - a = f32[131072,1024] parameter(0) - init = f32[] constant(0) - ROOT reduce = f32[131072] reduce(a, init), dimensions={1}, to_apply=Sum -} - -ENTRY reduce.1 { - parameter0 = f32[131072,1024] parameter(0) - ROOT fusion_row_reduction_vectorized = f32[131072] fusion( - f32[131072,1024] parameter0 - ), kind=kLoop, calls=fusion_vectorized -} - -// CHECK-LABEL: entry: -// CHECK: %[[VAL_0:.*]] = alloca float, align 4 -// CHECK: %[[VAL_1:.*]] = alloca float, align 4 -// CHECK: %[[VAL_2:.*]] = alloca float, align 4 -// CHECK: %[[VAL_3:.*]] = alloca float, align 4 -// CHECK: %[[VAL_4:.*]] = alloca float, align 4 -// CHECK: %[[VAL_5:.*]] = alloca float, align 4 -// CHECK: %[[VAL_6:.*]] = alloca float, align 4 -// CHECK: %[[VAL_7:.*]] = alloca float, align 4 -// CHECK: %[[VAL_8:.*]] = alloca float, align 4 -// CHECK: %[[VAL_9:.*]] = alloca float, align 4 -// CHECK: %[[VAL_10:.*]] = alloca float, align 4 -// CHECK: 
%[[VAL_11:.*]] = alloca float, align 4 -// CHECK: %[[VAL_12:.*]] = alloca float, align 4 -// CHECK: %[[VAL_13:.*]] = alloca float, align 4 -// CHECK: %[[VAL_14:.*]] = alloca float, align 4 -// CHECK: %[[VAL_15:.*]] = alloca float, align 4 -// CHECK: %[[VAL_16:.*]] = alloca float, align 4 -// CHECK: %[[VAL_17:.*]] = alloca float, align 4 -// CHECK: %[[VAL_18:.*]] = alloca float, align 4 -// CHECK: %[[VAL_19:.*]] = alloca float, align 4 -// CHECK: %[[VAL_20:.*]] = alloca float, align 4 -// CHECK: %[[VAL_21:.*]] = alloca float, align 4 -// CHECK-PTX: %[[VAL_22:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_23:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_24:.*]] = alloca float, align 4 -// CHECK-PTX: %[[VAL_25:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_26:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_27:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_28:.*]] = alloca float, align 4 -// CHECK: %[[VAL_29:.*]] = alloca float, align 4 -// CHECK-PTX: %[[VAL_30:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.y(), !range !2 -// CHECK-GCN: %[[VAL_30:.*]] = call i32 @llvm.amdgcn.workgroup.id.y -// CHECK: %[[VAL_31:.*]] = icmp eq i32 %[[VAL_30]], 0 -// CHECK: br i1 %[[VAL_31]], label %[[VAL_32:.*]], label %[[VAL_33:.*]] -// CHECK: reduce-group-0-after: ; preds = %thread_in_bounds-after, %[[VAL_34:.*]] -// CHECK: ret void -// CHECK: reduce-group-0-true: ; preds = %[[VAL_34]] -// CHECK: %[[VAL_35:.*]] = load float, ptr @0, align 4 -// CHECK: store float %[[VAL_35]], ptr{{.*}} %[[VAL_28]], align 4 -// CHECK-PTX: %thread.id.x = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3 -// CHECK-GCN: %thread.id.x = call i32 @llvm.amdgcn.workitem.id.x -// CHECK-PTX: %block.id.x = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !4 -// CHECK-GCN: %block.id.x = call i32 @llvm.amdgcn.workgroup.id.x -// CHECK: %[[VAL_36:.*]] = udiv i32 %thread.id.x, 64 -// CHECK: %thread.id.1 = urem i32 %[[VAL_36]], 4 -// CHECK: %thread.id.2 = urem i32 %thread.id.x, 64 -// CHECK: %lane_id = urem i32 
%thread.id.x, 32 -// CHECK: %[[VAL_37:.*]] = udiv i32 %block.id.x, 1 -// CHECK: %[[VAL_38:.*]] = urem i32 %[[VAL_37]], 1 -// CHECK-PTX: %[[VAL_39:.*]] = udiv i32 %block.id.x, 1 -// CHECK-PTX: %[[VAL_40:.*]] = urem i32 %[[VAL_39]], 1 -// CHECK: %[[VAL_41:.*]] = udiv i32 %block.id.x, 1 -// CHECK: %[[VAL_42:.*]] = urem i32 %[[VAL_41]], 32768 -// CHECK: %[[VAL_43:.*]] = udiv i32 %block.id.x, 32768 -// CHECK: %tile_origin.0 = mul i32 %[[VAL_43]], 1 -// CHECK: %tile_origin.1 = mul i32 %[[VAL_42]], 4 -// CHECK-PTX: %tile_origin.2 = mul i32 %[[VAL_40]], 512 -// CHECK-GCN: %tile_origin.2 = mul i32 %[[VAL_38]], 1024 -// CHECK-PTX: %tile_origin.3 = mul i32 %[[VAL_38]], 2 -// CHECK: store i32 %thread.id.1, ptr{{.*}} %[[VAL_27]], align 4 -// CHECK: br label %[[VAL_44:.*]] - -// CHECK: loop1.loop_header: ; preds = %[[VAL_45:.*]], %[[VAL_32]] -// CHECK: %[[VAL_46:.*]] = load i32, ptr{{.*}} %[[VAL_27]], align 4 -// CHECK: %[[VAL_47:.*]] = icmp uge i32 %[[VAL_46]], 4 -// CHECK: br i1 %[[VAL_47]], label %[[VAL_48:.*]], label %[[VAL_49:.*]] - -// CHECK: loop1.loop_body: ; preds = %[[VAL_44]] -// CHECK: %[[VAL_50:.*]] = add nuw nsw i32 %[[VAL_46]], 4 -// CHECK: store i32 %[[VAL_50]], ptr{{.*}} %[[VAL_27]], align 4 -// CHECK: br i1 true, label %[[VAL_52:.*]], label %[[VAL_53:.*]] - -// CHECK: is_full_tile-after: ; preds = %[[VAL_54:.*]], %[[VAL_55:.*]] -// CHECK: br label %[[VAL_44]], !llvm.loop !{{(5|4)}} - -// CHECK: loop1.loop_exit: ; preds = %[[VAL_44]] -// CHECK: %[[VAL_56:.*]] = load float, ptr{{.*}} %[[VAL_28]], align 4 -// CHECK-GCN: %[[VAL_57_1:.*]] = bitcast float %[[VAL_56]] to i32 -// CHECK-GCN: %[[VAL_57_2:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_57_1]], i32 543) -// CHECK-GCN: %[[VAL_57:.*]] = bitcast i32 %[[VAL_57_2]] to float -// CHECK-PTX: %[[VAL_57:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_56]], i32 16, i32 31) -// CHECK: store float %[[VAL_57]], ptr{{.*}} %[[VAL_20]], align 4 -// CHECK-GCN: %[[VAL_58_1:.*]] = addrspacecast ptr 
addrspace(5) %[[VAL_28]] to ptr -// CHECK-GCN: %[[VAL_58_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_20]] to ptr -// CHECK-GCN: %[[VAL_58_3:.*]] = addrspacecast ptr addrspace(5) %[[VAL_19]] to ptr -// CHECK-GCN: call void @[[SUM:Sum.*]](ptr %[[VAL_58_1]], ptr %[[VAL_58_2]], ptr %[[VAL_58_3]]) -// CHECK-PTX: call void @[[SUM:Sum.*]](ptr %[[VAL_28]], ptr %[[VAL_20]], ptr %[[VAL_19]]) -// CHECK: %[[VAL_58:.*]] = load float, ptr{{.*}} %[[VAL_19]], align 4 -// CHECK: store float %[[VAL_58]], ptr{{.*}} %[[VAL_28]], align 4 -// CHECK: %[[VAL_59:.*]] = load float, ptr{{.*}} %[[VAL_28]], align 4 -// CHECK-GCN: %[[VAL_60_1:.*]] = bitcast float %[[VAL_59]] to i32 -// CHECK-GCN: %[[VAL_60_2:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_60_1]], i32 287) -// CHECK-GCN: %[[VAL_60:.*]] = bitcast i32 %[[VAL_60_2]] to float -// CHECK-PTX: %[[VAL_60:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_59]], i32 8, i32 31) -// CHECK: store float %[[VAL_60]], ptr{{.*}} %[[VAL_18]], align 4 -// CHECK-GCN: %[[VAL_61_1:.*]] = addrspacecast ptr addrspace(5) %[[VAL_28]] to ptr -// CHECK-GCN: %[[VAL_61_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_18]] to ptr -// CHECK-GCN: %[[VAL_61_3:.*]] = addrspacecast ptr addrspace(5) %[[VAL_17]] to ptr -// CHECK-GCN: call void @[[SUM]](ptr %[[VAL_61_1]], ptr %[[VAL_61_2]], ptr %[[VAL_61_3]]) -// CHECK-PTX: call void @[[SUM]](ptr %[[VAL_28]], ptr %[[VAL_18]], ptr %[[VAL_17]]) -// CHECK: %[[VAL_61:.*]] = load float, ptr{{.*}} %[[VAL_17]], align 4 -// CHECK: store float %[[VAL_61]], ptr{{.*}} %[[VAL_28]], align 4 -// CHECK: %[[VAL_62:.*]] = load float, ptr{{.*}} %[[VAL_28]], align 4 -// CHECK-GCN: %[[VAL_63_1:.*]] = bitcast float %[[VAL_62]] to i32 -// CHECK-GCN: %[[VAL_63_2:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_63_1]], i32 159) -// CHECK-GCN: %[[VAL_63:.*]] = bitcast i32 %[[VAL_63_2]] to float -// CHECK-PTX: %[[VAL_63:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_62]], i32 4, i32 31) -// CHECK: 
store float %[[VAL_63]], ptr{{.*}} %[[VAL_16]], align 4 -// CHECK-GCN: %[[VAL_64_1:.*]] = addrspacecast ptr addrspace(5) %[[VAL_28]] to ptr -// CHECK-GCN: %[[VAL_64_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_16]] to ptr -// CHECK-GCN: %[[VAL_64_3:.*]] = addrspacecast ptr addrspace(5) %[[VAL_15]] to ptr -// CHECK-GCN: call void @[[SUM]](ptr %[[VAL_64_1]], ptr %[[VAL_64_2]], ptr %[[VAL_64_3]]) -// CHECK-PTX: call void @[[SUM]](ptr %[[VAL_28]], ptr %[[VAL_16]], ptr %[[VAL_15]]) -// CHECK: %[[VAL_64:.*]] = load float, ptr{{.*}} %[[VAL_15]], align 4 -// CHECK: store float %[[VAL_64]], ptr{{.*}} %[[VAL_28]], align 4 -// CHECK: %[[VAL_65:.*]] = load float, ptr{{.*}} %[[VAL_28]], align 4 -// CHECK-GCN: %[[VAL_66_1:.*]] = bitcast float %[[VAL_65]] to i32 -// CHECK-GCN: %[[VAL_66_2:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_66_1]], i32 95) -// CHECK-GCN: %[[VAL_66:.*]] = bitcast i32 %[[VAL_66_2]] to float -// CHECK-PTX: %[[VAL_66:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_65]], i32 2, i32 31) -// CHECK: store float %[[VAL_66]], ptr{{.*}} %[[VAL_14]], align 4 -// CHECK-GCN: %[[VAL_67_1:.*]] = addrspacecast ptr addrspace(5) %[[VAL_28]] to ptr -// CHECK-GCN: %[[VAL_67_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_14]] to ptr -// CHECK-GCN: %[[VAL_67_3:.*]] = addrspacecast ptr addrspace(5) %[[VAL_13]] to ptr -// CHECK-GCN: call void @[[SUM]](ptr %[[VAL_67_1]], ptr %[[VAL_67_2]], ptr %[[VAL_67_3]]) -// CHECK-PTX: call void @[[SUM]](ptr %[[VAL_28]], ptr %[[VAL_14]], ptr %[[VAL_13]]) -// CHECK: %[[VAL_67:.*]] = load float, ptr{{.*}} %[[VAL_13]], align 4 -// CHECK: store float %[[VAL_67]], ptr{{.*}} %[[VAL_28]], align 4 -// CHECK: %[[VAL_68:.*]] = load float, ptr{{.*}} %[[VAL_28]], align 4 -// CHECK-GCN: %[[VAL_69_1:.*]] = bitcast float %[[VAL_68]] to i32 -// CHECK-GCN: %[[VAL_69_2:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_69_1]], i32 63) -// CHECK-GCN: %[[VAL_69:.*]] = bitcast i32 %[[VAL_69_2]] to float -// CHECK-PTX: %[[VAL_69:.*]] = 
call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_68]], i32 1, i32 31) -// CHECK: store float %[[VAL_69]], ptr{{.*}} %[[VAL_12]], align 4 -// CHECK-GCN: %[[VAL_70_1:.*]] = addrspacecast ptr addrspace(5) %[[VAL_28]] to ptr -// CHECK-GCN: %[[VAL_70_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_12]] to ptr -// CHECK-GCN: %[[VAL_70_3:.*]] = addrspacecast ptr addrspace(5) %[[VAL_11]] to ptr -// CHECK-GCN: call void @[[SUM]](ptr %[[VAL_70_1]], ptr %[[VAL_70_2]], ptr %[[VAL_70_3]]) -// CHECK-PTX: call void @[[SUM]](ptr %[[VAL_28]], ptr %[[VAL_12]], ptr %[[VAL_11]]) -// CHECK: %[[VAL_70:.*]] = load float, ptr{{.*}} %[[VAL_11]], align 4 -// CHECK: store float %[[VAL_70]], ptr{{.*}} %[[VAL_28]], align 4 -// CHECK: %[[VAL_71:.*]] = udiv i32 %thread.id.2, 32 -// CHECK: %[[VAL_72:.*]] = icmp ult i32 %thread.id.1, 4 -// CHECK: br i1 %[[VAL_72]], label %thread_in_bounds-true, label %thread_in_bounds-after - -// CHECK: thread_in_bounds-after: ; preds = %[[VAL_73:.*]], %[[VAL_48]] -// CHECK: br label %[[VAL_33]] - -// CHECK: is_full_tile-true: ; preds = %[[VAL_49]] -// CHECK: store i32 0, ptr{{.*}} %[[VAL_26]], align 4 -// CHECK: br label %[[VAL_74:.*]] - -// CHECK: loop2.loop_header: ; preds = %[[VAL_75:.*]], %[[VAL_52]] -// CHECK: %[[VAL_76:.*]] = load i32, ptr{{.*}} %[[VAL_26]], align 4 -// CHECK-PTX: %[[VAL_77:.*]] = icmp uge i32 %[[VAL_76]], 512 -// CHECK-GCN: %[[VAL_77:.*]] = icmp uge i32 %[[VAL_76]], 1024 -// CHECK: br i1 %[[VAL_77]], label %[[VAL_55]], label %[[VAL_78:.*]] - -// CHECK: loop2.loop_body: ; preds = %[[VAL_74]] -// CHECK: %[[VAL_79:.*]] = add nuw nsw i32 %[[VAL_76]], 64 -// CHECK: store i32 %[[VAL_79]], ptr{{.*}} %[[VAL_26]], align 4 -// CHECK: %[[VAL_81:.*]] = add i32 %[[VAL_76]], %thread.id.2 -// CHECK-GCN: %[[VAL_88:.*]] = add i32 %tile_origin.0, 0 -// CHECK-GCN: %[[VAL_89:.*]] = add i32 %tile_origin.1, %[[VAL_46]] -// CHECK-GCN: %[[VAL_90:.*]] = add i32 %tile_origin.2, %[[VAL_81]] -// CHECK-GCN: %[[VAL_102:.*]] = getelementptr inbounds [131072 
x [1024 x float]], ptr %[[VAL_103:.*]], i32 0, i32 %[[VAL_89]], i32 %[[VAL_90]] -// CHECK-GCN: %[[VAL_104:.*]] = load float, ptr %[[VAL_102]], align 4, !invariant.load !6 -// CHECK-GCN: store float %[[VAL_104]], ptr{{.*}} %[[VAL_29]], align 4 -// CHECK-GCN: %[[VAL_105_1:.*]] = addrspacecast ptr addrspace(5) %[[VAL_28]] to ptr -// CHECK-GCN: %[[VAL_105_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_29]] to ptr -// CHECK-GCN: %[[VAL_105_3:.*]] = addrspacecast ptr addrspace(5) %[[VAL_24]] to ptr -// CHECK-GCN: call void @[[SUM]](ptr %[[VAL_105_1]], ptr %[[VAL_105_2]], ptr %[[VAL_105_3]]) -// CHECK-GCN: %[[VAL_105:.*]] = load float, ptr{{.*}} %[[VAL_24]], align 4 -// CHECK-GCN: store float %[[VAL_105]], ptr{{.*}} %[[VAL_28]], align 4 -// CHECK-PTX: store i32 0, ptr %[[VAL_25]], align 4 -// CHECK: br label %[[VAL_82:.*]] - -// CHECK-PTX: loop3.loop_header: ; preds = %[[VAL_83:.*]], %[[VAL_78]] -// CHECK-PTX: %[[VAL_84:.*]] = load i32, ptr %[[VAL_25]], align 4 -// CHECK-PTX: %[[VAL_85:.*]] = icmp uge i32 %[[VAL_84]], 2 -// CHECK-PTX: br i1 %[[VAL_85]], label %[[VAL_75]], label %[[VAL_83]] - -// CHECK-PTX: loop3.loop_body: ; preds = %[[VAL_82]] -// CHECK-PTX: %[[VAL_86:.*]] = add nuw nsw i32 %[[VAL_84]], 1 -// CHECK-PTX: store i32 %[[VAL_86]], ptr %[[VAL_25]], align 4 -// CHECK-PTX: %[[VAL_88:.*]] = add i32 %tile_origin.0, 0 -// CHECK-PTX: %[[VAL_89:.*]] = add i32 %tile_origin.1, %[[VAL_46]] -// CHECK-PTX: %[[VAL_90:.*]] = add i32 %tile_origin.2, %[[VAL_81]] -// CHECK-PTX: %[[VAL_91:.*]] = add i32 %tile_origin.3, %[[VAL_84]] -// CHECK-PTX: %[[VAL_92:.*]] = mul nuw nsw i32 %[[VAL_91]], 1 -// CHECK-PTX: %[[VAL_93:.*]] = add nuw nsw i32 0, %[[VAL_92]] -// CHECK-PTX: %[[VAL_94:.*]] = mul nuw nsw i32 %[[VAL_90]], 2 -// CHECK-PTX: %[[VAL_95:.*]] = add nuw nsw i32 %[[VAL_93]], %[[VAL_94]] -// CHECK-PTX: %[[VAL_96:.*]] = udiv i32 %[[VAL_95]], 1024 -// CHECK-PTX: %[[VAL_97:.*]] = mul nuw nsw i32 %[[VAL_89]], 1 -// CHECK-PTX: %[[VAL_98:.*]] = add nuw nsw i32 0, %[[VAL_97]] -// 
CHECK-PTX: %[[VAL_99:.*]] = udiv i32 %[[VAL_98]], 131072 -// CHECK-PTX: %[[VAL_100:.*]] = mul nuw nsw i32 %[[VAL_88]], 1 -// CHECK-PTX: %[[VAL_101:.*]] = add nuw nsw i32 0, %[[VAL_100]] -// CHECK-PTX: %[[VAL_102:.*]] = getelementptr inbounds [131072 x [1024 x float]], ptr %[[VAL_103:.*]], i32 0, i32 %[[VAL_98]], i32 %[[VAL_95]] -// CHECK-PTX: %[[VAL_104:.*]] = load float, ptr %[[VAL_102]], align 4, !invariant.load !7 -// CHECK-PTX: store float %[[VAL_104]], ptr %[[VAL_29]], align 4 -// CHECK-PTX: call void @[[SUM]](ptr %[[VAL_28]], ptr %[[VAL_29]], ptr %[[VAL_24]]) -// CHECK-PTX: %[[VAL_105:.*]] = load float, ptr %[[VAL_24]], align 4 -// CHECK-PTX: store float %[[VAL_105]], ptr %[[VAL_28]], align 4 -// CHECK-PTX: br label %[[VAL_82]], !llvm.loop !8 - -// CHECK-PTX: loop3.loop_exit: ; preds = %[[VAL_82]] -// CHECK-PTX: br label %[[VAL_74]], !llvm.loop !9 - -// CHECK: loop2.loop_exit: ; preds = %[[VAL_74]] -// CHECK: br label %[[VAL_45]] -// CHECK: is_full_tile-false: ; preds = %[[VAL_49]] -// CHECK: store i32 0, ptr{{.*}} %[[VAL_23]], align 4 -// CHECK: br label %[[VAL_106:.*]] - -// CHECK: loop2.loop_header{{(5|4)}}: ; preds = %[[VAL_107:.*]], %[[VAL_53]] -// CHECK: %[[VAL_108:.*]] = load i32, ptr{{.*}} %[[VAL_23]], align 4 -// CHECK-PTX: %[[VAL_109:.*]] = icmp uge i32 %[[VAL_108]], 512 -// CHECK-GCN: %[[VAL_109:.*]] = icmp uge i32 %[[VAL_108]], 1024 -// CHECK: br i1 %[[VAL_109]], label %[[VAL_54]], label %[[VAL_110:.*]] - -// CHECK: loop2.loop_body{{(6|5)}}: ; preds = %[[VAL_106]] -// CHECK: %[[VAL_111:.*]] = add nuw nsw i32 %[[VAL_108]], 64 -// CHECK: store i32 %[[VAL_111]], ptr{{.*}} %[[VAL_23]], align 4 -// CHECK: %[[VAL_113:.*]] = add i32 %[[VAL_108]], %thread.id.2 -// CHECK-PTX: %[[VAL_114:.*]] = icmp ult i32 %[[VAL_113]], 512 -// CHECK-GCN: %[[VAL_114:.*]] = icmp ult i32 %[[VAL_113]], 1024 -// CHECK: br i1 %[[VAL_114]], label %[[VAL_115:.*]], label %[[VAL_107]] - -// CHECK: x_in_tile-after: ; preds = %[[VAL_116:.*]], %[[VAL_110]] -// CHECK: br label 
%[[VAL_106]], !llvm.loop !{{(11|9)}} - -// CHECK: loop2.loop_exit{{(4|3)}}: ; preds = %[[VAL_106]] -// CHECK: br label %[[VAL_45]] - -// CHECK: x_in_tile-true: ; preds = %[[VAL_110]] -// CHECK-GCN: %[[VAL_123:.*]] = add i32 %tile_origin.0, 0 -// CHECK-GCN: %[[VAL_124:.*]] = add i32 %tile_origin.1, %[[VAL_46]] -// CHECK-GCN: %[[VAL_125:.*]] = add i32 %tile_origin.2, %[[VAL_113]] -// CHECK-GCN: %[[VAL_137:.*]] = getelementptr inbounds [131072 x [1024 x float]], ptr %[[VAL_103]], i32 0, i32 %[[VAL_124]], i32 %[[VAL_125]] -// CHECK-GCN: %[[VAL_138:.*]] = load float, ptr %[[VAL_137]], align 4, !invariant.load !6 -// CHECK-GCN: store float %[[VAL_138]], ptr{{.*}} %[[VAL_29]], align 4 -// CHECK-GCN: %[[VAL_139_1:.*]] = addrspacecast ptr addrspace(5) %[[VAL_28]] to ptr -// CHECK-GCN: %[[VAL_139_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_29]] to ptr -// CHECK-GCN: %[[VAL_139_3:.*]] = addrspacecast ptr addrspace(5) %[[VAL_21]] to ptr -// CHECK-GCN: call void @[[SUM]](ptr %[[VAL_139_1]], ptr %[[VAL_139_2]], ptr %[[VAL_139_3]]) -// CHECK-GCN: %[[VAL_139:.*]] = load float, ptr{{.*}} %[[VAL_21]], align 4 -// CHECK-GCN: store float %[[VAL_139]], ptr{{.*}} %[[VAL_28]], align 4 -// CHECK-PTX: store i32 0, ptr %[[VAL_22]], align 4 -// CHECK: br label %[[VAL_117:.*]] - -// CHECK-PTX: loop3.loop_header11: ; preds = %[[VAL_118:.*]], %[[VAL_115]] -// CHECK-PTX: %[[VAL_119:.*]] = load i32, ptr %[[VAL_22]], align 4 -// CHECK-PTX: %[[VAL_120:.*]] = icmp uge i32 %[[VAL_119]], 2 -// CHECK-PTX: br i1 %[[VAL_120]], label %[[VAL_116]], label %[[VAL_118]] - -// CHECK-PTX: loop3.loop_body12: ; preds = %[[VAL_117]] -// CHECK-PTX: %[[VAL_121:.*]] = add nuw nsw i32 %[[VAL_119]], 1 -// CHECK-PTX: store i32 %[[VAL_121]], ptr %[[VAL_22]], align 4 -// CHECK-PTX: %[[VAL_123:.*]] = add i32 %tile_origin.0, 0 -// CHECK-PTX: %[[VAL_124:.*]] = add i32 %tile_origin.1, %[[VAL_46]] -// CHECK-PTX: %[[VAL_125:.*]] = add i32 %tile_origin.2, %[[VAL_113]] -// CHECK-PTX: %[[VAL_126:.*]] = add i32 %tile_origin.3, 
%[[VAL_119]] -// CHECK-PTX: %[[VAL_127:.*]] = mul nuw nsw i32 %[[VAL_126]], 1 -// CHECK-PTX: %[[VAL_128:.*]] = add nuw nsw i32 0, %[[VAL_127]] -// CHECK-PTX: %[[VAL_129:.*]] = mul nuw nsw i32 %[[VAL_125]], 2 -// CHECK-PTX: %[[VAL_130:.*]] = add nuw nsw i32 %[[VAL_128]], %[[VAL_129]] -// CHECK-PTX: %[[VAL_131:.*]] = udiv i32 %[[VAL_130]], 1024 -// CHECK-PTX: %[[VAL_132:.*]] = mul nuw nsw i32 %[[VAL_124]], 1 -// CHECK-PTX: %[[VAL_133:.*]] = add nuw nsw i32 0, %[[VAL_132]] -// CHECK-PTX: %[[VAL_134:.*]] = udiv i32 %[[VAL_133]], 131072 -// CHECK-PTX: %[[VAL_135:.*]] = mul nuw nsw i32 %[[VAL_123]], 1 -// CHECK-PTX: %[[VAL_136:.*]] = add nuw nsw i32 0, %[[VAL_135]] -// CHECK-PTX: %[[VAL_137:.*]] = getelementptr inbounds [131072 x [1024 x float]], ptr %[[VAL_103]], i32 0, i32 %[[VAL_133]], i32 %[[VAL_130]] -// CHECK-PTX: %[[VAL_138:.*]] = load float, ptr %[[VAL_137]], align 4, !invariant.load !7 -// CHECK-PTX: store float %[[VAL_138]], ptr %[[VAL_29]], align 4 -// CHECK-PTX: call void @[[SUM]](ptr %[[VAL_28]], ptr %[[VAL_29]], ptr %[[VAL_21]]) -// CHECK-PTX: %[[VAL_139:.*]] = load float, ptr %[[VAL_21]], align 4 -// CHECK-PTX: store float %[[VAL_139]], ptr %[[VAL_28]], align 4 -// CHECK-PTX: br label %[[VAL_117]], !llvm.loop !12 - -// CHECK-PTX: loop3.loop_exit10: ; preds = %[[VAL_117]] -// CHECK-PTX: br label %[[VAL_107]] - -// CHECK: thread_in_bounds-true: ; preds = %[[VAL_48]] -// CHECK: %[[VAL_140:.*]] = icmp eq i32 %lane_id, 0 -// CHECK: br i1 %[[VAL_140]], label %[[VAL_141:.*]], label %[[VAL_142:.*]] - -// CHECK: intra_warp_reduce_write-after: ; preds = %[[VAL_141]], %thread_in_bounds-true -// CHECK-PTX: call void @llvm.nvvm.barrier0() -// CHECK-GCN: fence syncscope("workgroup") seq_cst -// CHECK-GCN: call void @llvm.amdgcn.s.barrier() -// CHECK: %[[VAL_143:.*]] = icmp eq i32 %[[VAL_71]], 0 -// CHECK: br i1 %[[VAL_143]], label %[[VAL_144:.*]], label %[[VAL_73]] - -// CHECK: inter_warp_reduce-after: ; preds = %[[VAL_145:.*]], %[[VAL_142]] -// CHECK: br label 
%thread_in_bounds-after - -// CHECK: intra_warp_reduce_write-true: ; preds = %thread_in_bounds-true -// CHECK: %[[VAL_146:.*]] = load float, ptr{{.*}} %[[VAL_28]], align 4 -// CHECK: %[[VAL_147:.*]] = getelementptr inbounds [4 x [2 x float]], ptr addrspace(3) @shared_cache, i32 0, i32 %thread.id.1, i32 %[[VAL_71]] -// CHECK: %[[VAL_148:.*]] = addrspacecast ptr addrspace(3) %[[VAL_147]] to ptr -// CHECK: store float %[[VAL_146]], ptr %[[VAL_148]], align 4 -// CHECK: br label %[[VAL_142]] - -// CHECK: inter_warp_reduce-true: ; preds = %[[VAL_142]] -// CHECK: %[[VAL_149:.*]] = getelementptr inbounds [4 x [2 x float]], ptr addrspace(3) @shared_cache, i32 0, i32 %thread.id.1, i32 %lane_id -// CHECK: %[[VAL_150:.*]] = addrspacecast ptr addrspace(3) %[[VAL_149]] to ptr -// CHECK-GCN: %[[VAL_150_1:.*]] = addrspacecast ptr addrspace(5) %[[VAL_10]] to ptr -// CHECK-GCN: store float %[[VAL_35]], ptr %[[VAL_150_1]], align 4 -// CHECK-PTX: store float %[[VAL_35]], ptr %[[VAL_10]], align 4 -// CHECK: %[[VAL_151:.*]] = icmp ult i32 %thread.id.2, 2 -// CHECK-GCN: %[[VAL_152:.*]] = select i1 %[[VAL_151]], ptr %[[VAL_150]], ptr %[[VAL_150_1]] -// CHECK-PTX: %[[VAL_152:.*]] = select i1 %[[VAL_151]], ptr %[[VAL_150]], ptr %[[VAL_10]] -// CHECK: %[[VAL_153:.*]] = load float, ptr %[[VAL_152]], align 4 -// CHECK-GCN: %[[VAL_154_1:.*]] = bitcast float %[[VAL_153]] to i32 -// CHECK-GCN: %[[VAL_154_2:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_154_1]], i32 543) -// CHECK-GCN: %[[VAL_154:.*]] = bitcast i32 %[[VAL_154_2]] to float -// CHECK-PTX: %[[VAL_154:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_153]], i32 16, i32 31) -// CHECK: store float %[[VAL_154]], ptr{{.*}} %[[VAL_9]], align 4 -// CHECK-GCN: %[[VAL_155_1:.*]] = addrspacecast ptr addrspace(5) %[[VAL_9]] to ptr -// CHECK-GCN: %[[VAL_155_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_8]] to ptr -// CHECK-GCN: call void @[[SUM]](ptr %[[VAL_152]], ptr %[[VAL_155_1]], ptr %[[VAL_155_2]]) -// CHECK-PTX: 
call void @[[SUM]](ptr %[[VAL_152]], ptr %[[VAL_9]], ptr %[[VAL_8]]) -// CHECK: %[[VAL_155:.*]] = load float, ptr{{.*}} %[[VAL_8]], align 4 -// CHECK: store float %[[VAL_155]], ptr %[[VAL_152]], align 4 -// CHECK: %[[VAL_156:.*]] = load float, ptr %[[VAL_152]], align 4 -// CHECK-GCN: %[[VAL_157_1:.*]] = bitcast float %[[VAL_156]] to i32 -// CHECK-GCN: %[[VAL_157_2:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_157_1]], i32 287) -// CHECK-GCN: %[[VAL_157:.*]] = bitcast i32 %[[VAL_157_2]] to float -// CHECK-PTX: %[[VAL_157:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_156]], i32 8, i32 31) -// CHECK: store float %[[VAL_157]], ptr{{.*}} %[[VAL_7]], align 4 -// CHECK-GCN: %[[VAL_158_1:.*]] = addrspacecast ptr addrspace(5) %[[VAL_7]] to ptr -// CHECK-GCN: %[[VAL_158_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_6]] to ptr -// CHECK-GCN: call void @[[SUM]](ptr %[[VAL_152]], ptr %[[VAL_158_1]], ptr %[[VAL_158_2]]) -// CHECK-PTX: call void @[[SUM]](ptr %[[VAL_152]], ptr %[[VAL_7]], ptr %[[VAL_6]]) -// CHECK: %[[VAL_158:.*]] = load float, ptr{{.*}} %[[VAL_6]], align 4 -// CHECK: store float %[[VAL_158]], ptr %[[VAL_152]], align 4 -// CHECK: %[[VAL_159:.*]] = load float, ptr %[[VAL_152]], align 4 -// CHECK-GCN: %[[VAL_160_1:.*]] = bitcast float %[[VAL_159]] to i32 -// CHECK-GCN: %[[VAL_160_2:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_160_1]], i32 159) -// CHECK-GCN: %[[VAL_160:.*]] = bitcast i32 %[[VAL_160_2]] to float -// CHECK-PTX: %[[VAL_160:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_159]], i32 4, i32 31) -// CHECK: store float %[[VAL_160]], ptr{{.*}} %[[VAL_5]], align 4 -// CHECK-GCN: %[[VAL_161_1:.*]] = addrspacecast ptr addrspace(5) %[[VAL_5]] to ptr -// CHECK-GCN: %[[VAL_161_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_4]] to ptr -// CHECK-GCN: call void @[[SUM]](ptr %[[VAL_152]], ptr %[[VAL_161_1]], ptr %[[VAL_161_2]]) -// CHECK-PTX: call void @[[SUM]](ptr %[[VAL_152]], ptr %[[VAL_5]], ptr %[[VAL_4]]) 
-// CHECK: %[[VAL_161:.*]] = load float, ptr{{.*}} %[[VAL_4]], align 4 -// CHECK: store float %[[VAL_161]], ptr %[[VAL_152]], align 4 -// CHECK: %[[VAL_162:.*]] = load float, ptr %[[VAL_152]], align 4 -// CHECK-GCN: %[[VAL_163_1:.*]] = bitcast float %[[VAL_162]] to i32 -// CHECK-GCN: %[[VAL_163_2:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_163_1]], i32 95) -// CHECK-GCN: %[[VAL_163:.*]] = bitcast i32 %[[VAL_163_2]] to float -// CHECK-PTX: %[[VAL_163:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_162]], i32 2, i32 31) -// CHECK: store float %[[VAL_163]], ptr{{.*}} %[[VAL_3]], align 4 -// CHECK-GCN: %[[VAL_164_1:.*]] = addrspacecast ptr addrspace(5) %[[VAL_3]] to ptr -// CHECK-GCN: %[[VAL_164_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_2]] to ptr -// CHECK-GCN: call void @[[SUM]](ptr %[[VAL_152]], ptr %[[VAL_164_1]], ptr %[[VAL_164_2]]) -// CHECK-PTX: call void @[[SUM]](ptr %[[VAL_152]], ptr %[[VAL_3]], ptr %[[VAL_2]]) -// CHECK: %[[VAL_164:.*]] = load float, ptr{{.*}} %[[VAL_2]], align 4 -// CHECK: store float %[[VAL_164]], ptr %[[VAL_152]], align 4 -// CHECK: %[[VAL_165:.*]] = load float, ptr %[[VAL_152]], align 4 -// CHECK-GCN: %[[VAL_166_1:.*]] = bitcast float %[[VAL_165]] to i32 -// CHECK-GCN: %[[VAL_166_2:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_166_1]], i32 63) -// CHECK-GCN: %[[VAL_166:.*]] = bitcast i32 %[[VAL_166_2]] to float -// CHECK-PTX: %[[VAL_166:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_165]], i32 1, i32 31) -// CHECK: store float %[[VAL_166]], ptr{{.*}} %[[VAL_1]], align 4 -// CHECK-GCN: %[[VAL_167_1:.*]] = addrspacecast ptr addrspace(5) %[[VAL_1]] to ptr -// CHECK-GCN: %[[VAL_167_2:.*]] = addrspacecast ptr addrspace(5) %[[VAL_0]] to ptr -// CHECK-GCN: call void @[[SUM]](ptr %[[VAL_152]], ptr %[[VAL_167_1]], ptr %[[VAL_167_2]]) -// CHECK-PTX: call void @[[SUM]](ptr %[[VAL_152]], ptr %[[VAL_1]], ptr %[[VAL_0]]) -// CHECK: %[[VAL_167:.*]] = load float, ptr{{.*}} %[[VAL_0]], align 4 
-// CHECK: store float %[[VAL_167]], ptr %[[VAL_152]], align 4 -// CHECK: %[[VAL_168:.*]] = icmp eq i32 %thread.id.2, 0 -// CHECK: br i1 %[[VAL_168]], label %[[VAL_169:.*]], label %[[VAL_145]] - -// CHECK: reduction_write_output-after: ; preds = %[[VAL_169]], %[[VAL_144]] -// CHECK: br label %[[VAL_73]] -// CHECK: reduction_write_output-true: ; preds = %[[VAL_144]] -// CHECK: %[[VAL_171:.*]] = add i32 %tile_origin.1, %thread.id.1 -// CHECK: %[[VAL_175:.*]] = getelementptr inbounds [131072 x float], ptr %[[VAL_176:.*]], i32 0, i32 %[[VAL_171]] -// CHECK: %[[VAL_177:.*]] = load float, ptr %[[VAL_152]], align 4 -// CHECK: store float %[[VAL_177]], ptr %[[VAL_175]], align 4 -// CHECK: br label %[[VAL_145]] -// CHECK: entry: -// CHECK: %[[VAL_178:.*]] = alloca float, align 4 -// CHECK: %[[VAL_179:.*]] = load float, ptr %[[VAL_180:.*]], align 4 -// CHECK: %[[VAL_181:.*]] = load float, ptr %[[VAL_182:.*]], align 4 -// CHECK: %[[VAL_183:.*]] = fadd float %[[VAL_179]], %[[VAL_181]] -// CHECK: store float %[[VAL_183]], ptr{{.*}} %[[VAL_178]], align 4 -// CHECK: %[[VAL_184:.*]] = load float, ptr{{.*}} %[[VAL_178]], align 4 -// CHECK: store float %[[VAL_184]], ptr %[[VAL_185:.*]], align 4 -// CHECK: ret void diff --git a/third_party/xla/xla/service/gpu/tests/reduce_to_scalar_vectorized.hlo b/third_party/xla/xla/service/gpu/tests/reduce_to_scalar_vectorized.hlo deleted file mode 100644 index 7fcf676da8f94a..00000000000000 --- a/third_party/xla/xla/service/gpu/tests/reduce_to_scalar_vectorized.hlo +++ /dev/null @@ -1,28 +0,0 @@ -// RUN: hlo-opt %s --platform=gpu --xla_gpu_mlir_emitter_level=3 --stage=llvm-after-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb | FileCheck %s --check-prefixes=CHECK - -HloModule ReductionToScalarVectorized, is_scheduled=true - -region_0.7 { - Arg_1.9 = pred[] parameter(1) - Arg_0.8 = pred[] parameter(0) - ROOT and.1 = pred[] and(Arg_0.8, Arg_1.9) -} - -fused_reduce { - param_1.8 = s8[2,3,4]{2,1,0} 
parameter(1) - convert.2.3 = s16[2,3,4]{2,1,0} convert(param_1.8) - param_0.5 = u8[2,3,4]{2,1,0} parameter(0) - convert.3.3 = s16[2,3,4]{2,1,0} convert(param_0.5) - compare.1.3 = pred[2,3,4]{2,1,0} compare(convert.2.3, convert.3.3), direction=EQ - bitcast.26.1 = pred[24]{0} bitcast(compare.1.3) - constant_3_1 = pred[] constant(true) - ROOT reduce.11.1 = pred[] reduce(bitcast.26.1, constant_3_1), dimensions={0}, to_apply=region_0.7 -} // fused_reduce - -ENTRY main.12 { - Arg_1.2.0 = u8[2,3,4]{2,1,0} parameter(1) - Arg_0.1.0 = s8[2,3,4]{2,1,0} parameter(0) - ROOT loop_reduce_fusion = pred[] fusion(Arg_1.2.0, Arg_0.1.0), kind=kLoop, calls=fused_reduce -} - -// CHECK: load <16 x i8>, ptr addrspace(1) %{{.*}}, align 16, !invariant.load !4 diff --git a/third_party/xla/xla/service/gpu/tests/reduce_unnested.hlo b/third_party/xla/xla/service/gpu/tests/reduce_unnested.hlo deleted file mode 100644 index c662094fecb9cb..00000000000000 --- a/third_party/xla/xla/service/gpu/tests/reduce_unnested.hlo +++ /dev/null @@ -1,82 +0,0 @@ -// RUN: hlo-opt %s --platform=gpu --xla_gpu_mlir_emitter_level=0 --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb --split-input-file | FileCheck %s - -// CHECK: define void @fusion_row_reduction_too_small( -// CHECK-COUNT-1: {{^x_in_tile-true}} -// CHECK-NOT: {{^x_in_tile-true}} - -HloModule RowReductionNotVectorized, is_scheduled=true - -Sum { - x.1 = f32[] parameter(0) - y.1 = f32[] parameter(1) - ROOT add.1 = f32[] add(x.1, y.1) -} - -fusion_vectorized { - a = f32[131072,512] parameter(0) - init = f32[] constant(0) - ROOT reduce = f32[131072] reduce(a, init), dimensions={1}, to_apply=Sum -} - -ENTRY reduce.1 { - parameter0 = f32[131072,512] parameter(0) - ROOT fusion_row_reduction_too_small = f32[131072] fusion( - f32[131072,512] parameter0 - ), kind=kLoop, calls=fusion_vectorized -} - -// ----- - -// CHECK: define void @fusion_row_reduction_odd_dimx( -// CHECK-COUNT-1: 
{{^x_in_tile-true}} -// CHECK-NOT: {{^x_in_tile-true}} - -HloModule RowReductionNotVectorized, is_scheduled=true - -Sum { - x.1 = f32[] parameter(0) - y.1 = f32[] parameter(1) - ROOT add.1 = f32[] add(x.1, y.1) -} - -fusion_vectorized { - a = f32[131072,1025] parameter(0) - init = f32[] constant(0) - ROOT reduce = f32[131072] reduce(a, init), dimensions={1}, to_apply=Sum -} - -ENTRY reduce.1 { - parameter0 = f32[131072,1025] parameter(0) - ROOT fusion_row_reduction_odd_dimx = f32[131072] fusion( - f32[131072,1025] parameter0 - ), kind=kLoop, calls=fusion_vectorized -} - - -// ----- - -// CHECK: define void @fusion_row_reduction_sin_prevents_vectorization( -// CHECK-COUNT-1: {{^x_in_tile-true}} -// CHECK-NOT: {{^x_in_tile-true}} - -HloModule RowReductionNotVectorized, is_scheduled=true - -Sum { - x.1 = f32[] parameter(0) - y.1 = f32[] parameter(1) - ROOT add.1 = f32[] add(x.1, y.1) -} - -fusion_not_vectorized { - a = f32[131072,1024] parameter(0) - c0 = f32[] constant(0) - init = f32[] sine(c0) - ROOT reduce = f32[131072] reduce(a, init), dimensions={1}, to_apply=Sum -} - -ENTRY reduce.1 { - parameter0 = f32[131072,1024] parameter(0) - ROOT fusion_row_reduction_sin_prevents_vectorization = f32[131072] fusion( - f32[131072,1024] parameter0 - ), kind=kLoop, calls=fusion_not_vectorized -} diff --git a/third_party/xla/xla/service/gpu/tests/reduce_variadic_column.hlo b/third_party/xla/xla/service/gpu/tests/reduce_variadic_column.hlo deleted file mode 100644 index 26eace7068ef2d..00000000000000 --- a/third_party/xla/xla/service/gpu/tests/reduce_variadic_column.hlo +++ /dev/null @@ -1,460 +0,0 @@ -// RUN: hlo-opt %s --platform=gpu --xla_gpu_mlir_emitter_level=0 --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb | FileCheck %s --check-prefixes=CHECK,CHECK-%{PTX} - -HloModule Test, is_scheduled=true - -Add { - scalar_lhs.0 = f32[] parameter(0) - scalar_rhs.0 = f32[] parameter(1) - scalar_lhs.1 = f32[] 
parameter(2) - scalar_rhs.1 = f32[] parameter(3) - add.0 = f32[] add(scalar_lhs.0, scalar_lhs.1) - add.1 = f32[] add(scalar_rhs.0, scalar_rhs.1) - ROOT t = (f32[], f32[]) tuple(add.0, add.1) -} - -fused_computation { - param_0 = f32[5,200,300]{2,1,0} parameter(0) - param_1 = f32[5,200,300]{2,1,0} parameter(1) - param_2 = f32[] parameter(2) - ROOT d.1 = (f32[200]{0}, f32[200]{0}) reduce(f32[5,200,300]{2,1,0} param_0, f32[5,200,300]{2,1,0} %param_1, f32[] param_2, f32[] param_2), dimensions={0,2}, to_apply=Add -} - -ENTRY main { - a = f32[5, 200, 300]{2,1,0} parameter(0) - b = f32[5, 200, 300]{2,1,0} parameter(1) - c = f32[] constant(0) - ROOT wrapped_d = (f32[200]{0}, f32[200]{0}) fusion(f32[5,200,300]{2,1,0} a, f32[5,200,300]{2,1,0} b, f32[] c), kind=kInput, calls=fused_computation -} - -// CHECK-LABEL: entry: -// CHECK: %[[VAL_0:.*]] = alloca float, align 4 -// CHECK: %[[VAL_1:.*]] = alloca float, align 4 -// CHECK: %[[VAL_2:.*]] = alloca float, align 4 -// CHECK: %[[VAL_3:.*]] = alloca float, align 4 -// CHECK: %[[VAL_4:.*]] = alloca [2 x ptr], align 8 -// CHECK: %[[VAL_5:.*]] = alloca float, align 4 -// CHECK: %[[VAL_6:.*]] = alloca float, align 4 -// CHECK: %[[VAL_7:.*]] = alloca float, align 4 -// CHECK: %[[VAL_8:.*]] = alloca float, align 4 -// CHECK: %[[VAL_9:.*]] = alloca [2 x ptr], align 8 -// CHECK: %[[VAL_10:.*]] = alloca float, align 4 -// CHECK: %[[VAL_11:.*]] = alloca float, align 4 -// CHECK: %[[VAL_12:.*]] = alloca float, align 4 -// CHECK: %[[VAL_13:.*]] = alloca float, align 4 -// CHECK: %[[VAL_14:.*]] = alloca [2 x ptr], align 8 -// CHECK: %[[VAL_15:.*]] = alloca float, align 4 -// CHECK: %[[VAL_16:.*]] = alloca float, align 4 -// CHECK: %[[VAL_17:.*]] = alloca float, align 4 -// CHECK: %[[VAL_18:.*]] = alloca float, align 4 -// CHECK: %[[VAL_19:.*]] = alloca [2 x ptr], align 8 -// CHECK: %[[VAL_20:.*]] = alloca float, align 4 -// CHECK: %[[VAL_21:.*]] = alloca float, align 4 -// CHECK: %[[VAL_22:.*]] = alloca float, align 4 -// CHECK: 
%[[VAL_23:.*]] = alloca float, align 4 -// CHECK: %[[VAL_24:.*]] = alloca [2 x ptr], align 8 -// CHECK: %[[VAL_25:.*]] = alloca float, align 4 -// CHECK: %[[VAL_26:.*]] = alloca float, align 4 -// CHECK: %[[VAL_27:.*]] = alloca float, align 4 -// CHECK: %[[VAL_28:.*]] = alloca float, align 4 -// CHECK: %[[VAL_29:.*]] = alloca [2 x ptr], align 8 -// CHECK: %[[VAL_30:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_31:.*]] = alloca float, align 4 -// CHECK: %[[VAL_32:.*]] = alloca float, align 4 -// CHECK: %[[VAL_33:.*]] = alloca [2 x ptr], align 8 -// CHECK: %[[VAL_34:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_35:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_36:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_37:.*]] = alloca float, align 4 -// CHECK: %[[VAL_38:.*]] = alloca float, align 4 -// CHECK: %[[VAL_39:.*]] = alloca float, align 4 -// CHECK: %[[VAL_40:.*]] = alloca float, align 4 -// CHECK-PTX: %[[VAL_41:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.y(), !range !2 -// CHECK-GCN: %[[VAL_41:.*]] = call i32 @llvm.amdgcn.workgroup.id.y -// CHECK: %[[VAL_42:.*]] = icmp eq i32 %[[VAL_41]], 0 -// CHECK: br i1 %[[VAL_42]], label %[[VAL_43:.*]], label %[[VAL_44:.*]] -// CHECK: reduce-group-0-after: ; preds = %thread_in_bounds-after, %[[VAL_45:.*]] -// CHECK: ret void -// CHECK: reduce-group-0-true: ; preds = %[[VAL_45]] -// CHECK: %[[VAL_46:.*]] = load float, ptr{{.*}}%[[VAL_47:.*]], align 4, !invariant.load !{{[0-9]}} -// CHECK: store float %[[VAL_46]], ptr{{.*}}%[[VAL_39]], align 4 -// CHECK: %[[VAL_48:.*]] = load float, ptr{{.*}}%[[VAL_47]], align 4, !invariant.load !{{[0-9]}} -// CHECK: store float %[[VAL_48]], ptr{{.*}}%[[VAL_37]], align 4 -// CHECK-PTX: %thread.id.x = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !4 -// CHECK-GCN: %thread.id.x = call i32 @llvm.amdgcn.workitem.id.x -// CHECK-PTX: %block.id.x = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !5 -// CHECK-GCN: %block.id.x = call i32 @llvm.amdgcn.workgroup.id.x -// CHECK: %[[VAL_49:.*]] = udiv 
i32 %thread.id.x, 32 -// CHECK: %thread.id.1 = urem i32 %[[VAL_49]], 8 -// CHECK: %thread.id.2 = urem i32 %thread.id.x, 32 -// CHECK: %lane_id = urem i32 %thread.id.x, 32 -// CHECK: %[[VAL_50:.*]] = udiv i32 %block.id.x, 1 -// CHECK: %[[VAL_51:.*]] = urem i32 %[[VAL_50]], 1 -// CHECK: %[[VAL_52:.*]] = udiv i32 %block.id.x, 1 -// CHECK: %[[VAL_53:.*]] = urem i32 %[[VAL_52]], 25 -// CHECK: %[[VAL_54:.*]] = udiv i32 %block.id.x, 25 -// CHECK: %[[VAL_55:.*]] = icmp eq i32 %[[VAL_51]], 0 -// CHECK: %tile_bound.2 = select i1 %[[VAL_55]], i32 300, i32 512 -// CHECK: %tile_origin.0 = mul i32 %[[VAL_54]], 5 -// CHECK: %tile_origin.1 = mul i32 %[[VAL_53]], 8 -// CHECK: %tile_origin.2 = mul i32 %[[VAL_51]], 512 -// CHECK: store i32 0, ptr{{.*}}%[[VAL_36]], align 4 -// CHECK: br label %[[VAL_56:.*]] -// CHECK: loop0.loop_header: ; preds = %[[VAL_57:.*]], %[[VAL_43]] -// CHECK: %[[VAL_58:.*]] = load i32, ptr{{.*}}%[[VAL_36]], align 4 -// CHECK: %[[VAL_59:.*]] = icmp uge i32 %[[VAL_58]], 5 -// CHECK: br i1 %[[VAL_59]], label %[[VAL_60:.*]], label %[[VAL_61:.*]] -// CHECK: loop0.loop_body: ; preds = %[[VAL_56]] -// CHECK: %[[VAL_62:.*]] = add nuw nsw i32 %[[VAL_58]], 1 -// CHECK: store i32 %[[VAL_62]], ptr{{.*}}%[[VAL_36]], align 4 -// CHECK: store i32 %thread.id.1, ptr{{.*}}%[[VAL_35]], align 4 -// CHECK: br label %[[VAL_64:.*]] -// CHECK: loop1.loop_header: ; preds = %[[VAL_65:.*]], %[[VAL_61]] -// CHECK: %[[VAL_66:.*]] = load i32, ptr{{.*}}%[[VAL_35]], align 4 -// CHECK: %[[VAL_67:.*]] = icmp uge i32 %[[VAL_66]], 8 -// CHECK: br i1 %[[VAL_67]], label %[[VAL_57]], label %[[VAL_68:.*]] -// CHECK: loop1.loop_body: ; preds = %[[VAL_64]] -// CHECK: %[[VAL_69:.*]] = add nuw nsw i32 %[[VAL_66]], 8 -// CHECK: store i32 %[[VAL_69]], ptr{{.*}}%[[VAL_35]], align 4 -// CHECK: %[[VAL_71:.*]] = icmp eq i32 512, %tile_bound.2 -// CHECK: br i1 %[[VAL_71]], label %[[VAL_72:.*]], label %[[VAL_73:.*]] -// CHECK: is_full_tile-after: ; preds = %[[VAL_74:.*]], %[[VAL_75:.*]] -// CHECK: br label 
%[[VAL_64]], !llvm.loop !{{[0-9]}} -// CHECK: loop1.loop_exit: ; preds = %[[VAL_64]] -// CHECK: br label %[[VAL_56]], !llvm.loop !{{[0-9]}} -// CHECK: loop0.loop_exit: ; preds = %[[VAL_56]] -// CHECK: %[[VAL_76:.*]] = load float, ptr{{.*}}%[[VAL_39]], align 4 -// CHECK-PTX: %[[VAL_77:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_76]], i32 16, i32 31) -// CHECK-GCN: %[[VAL_76_1:.*]] = bitcast float %[[VAL_76]] to i32 -// CHECK-GCN: %[[VAL_77_1:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_76_1:.*]], i32 543) -// CHECK-GCN: %[[VAL_77:.*]] = bitcast i32 %[[VAL_77_1:.*]] to float -// CHECK: store float %[[VAL_77]], ptr{{.*}}%[[VAL_26]], align 4 -// CHECK: %[[VAL_78:.*]] = load float, ptr{{.*}}%[[VAL_37]], align 4 -// CHECK-PTX: %[[VAL_79:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_78]], i32 16, i32 31) -// CHECK-GCN: %[[VAL_78_1:.*]] = bitcast float %[[VAL_78]] to i32 -// CHECK-GCN: %[[VAL_79_1:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_78_1:.*]], i32 543) -// CHECK-GCN: %[[VAL_79:.*]] = bitcast i32 %[[VAL_79_1:.*]] to float -// CHECK: store float %[[VAL_79]], ptr{{.*}}%[[VAL_25]], align 4 -// CHECK-GCN: %[[VAL_22_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_22]] to ptr -// CHECK: %[[VAL_80:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_24]], i64 0, i64 0 -// CHECK-PTX: store ptr %[[VAL_22]], ptr %[[VAL_80]], align 8 -// CHECK-GCN: store ptr %[[VAL_22_1]], ptr{{.*}}%[[VAL_80]], align 8 -// CHECK-GCN: %[[VAL_23_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_23]] to ptr -// CHECK: %[[VAL_81:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_24]], i64 0, i64 1 -// CHECK-PTX: store ptr %[[VAL_23]], ptr %[[VAL_81]], align 8 -// CHECK-GCN: store ptr %[[VAL_23_1]], ptr{{.*}}%[[VAL_81]], align 8 -// CHECK-PTX: call void @[[ADD:Add.*]](ptr %[[VAL_39]], ptr %[[VAL_37]], ptr %[[VAL_26]], ptr %[[VAL_25]], ptr %[[VAL_24]]) -// CHECK-GCN: %[[VAL_39_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_39]] to ptr -// CHECK-GCN: 
%[[VAL_37_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_37]] to ptr -// CHECK-GCN: %[[VAL_26_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_26]] to ptr -// CHECK-GCN: %[[VAL_25_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_25]] to ptr -// CHECK-GCN: %[[VAL_24_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_24]] to ptr -// CHECK-GCN: call void @[[ADD:Add.*]](ptr %[[VAL_39_1]], ptr %[[VAL_37_1]], ptr %[[VAL_26_1]], ptr %[[VAL_25_1]], ptr %[[VAL_24_1]]) -// CHECK: %[[VAL_82:.*]] = load float, ptr{{.*}}%[[VAL_22]], align 4 -// CHECK: %[[VAL_83:.*]] = load float, ptr{{.*}}%[[VAL_23]], align 4 -// CHECK: store float %[[VAL_82]], ptr{{.*}}%[[VAL_39]], align 4 -// CHECK: store float %[[VAL_83]], ptr{{.*}}%[[VAL_37]], align 4 -// CHECK: %[[VAL_84:.*]] = load float, ptr{{.*}}%[[VAL_39]], align 4 -// CHECK-PTX: %[[VAL_85:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_84]], i32 8, i32 31) -// CHECK-GCN: %[[VAL_84_1:.*]] = bitcast float %[[VAL_84]] to i32 -// CHECK-GCN: %[[VAL_85_1:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_84_1:.*]], i32 287) -// CHECK-GCN: %[[VAL_85:.*]] = bitcast i32 %[[VAL_85_1:.*]] to float -// CHECK: store float %[[VAL_85]], ptr{{.*}}%[[VAL_21]], align 4 -// CHECK: %[[VAL_86:.*]] = load float, ptr{{.*}}%[[VAL_37]], align 4 -// CHECK-PTX: %[[VAL_87:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_86]], i32 8, i32 31) -// CHECK-GCN: %[[VAL_86_1:.*]] = bitcast float %[[VAL_86]] to i32 -// CHECK-GCN: %[[VAL_87_1:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_86_1:.*]], i32 287) -// CHECK-GCN: %[[VAL_87:.*]] = bitcast i32 %[[VAL_87_1:.*]] to float -// CHECK: store float %[[VAL_87]], ptr{{.*}}%[[VAL_20]], align 4 -// CHECK-GCN: %[[VAL_17_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_17]] to ptr -// CHECK: %[[VAL_88:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_19]], i64 0, i64 0 -// CHECK-PTX: store ptr %[[VAL_17]], ptr %[[VAL_88]], align 8 -// CHECK-GCN: store ptr %[[VAL_17_1]], ptr{{.*}}%[[VAL_88]], align 8 -// CHECK-GCN: 
%[[VAL_18_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_18]] to ptr -// CHECK: %[[VAL_89:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_19]], i64 0, i64 1 -// CHECK-PTX: store ptr %[[VAL_18]], ptr %[[VAL_89]], align 8 -// CHECK-GCN: store ptr %[[VAL_18_1]], ptr{{.*}}%[[VAL_89]], align 8 -// CHECK-PTX: call void @[[ADD:Add.*]](ptr %[[VAL_39]], ptr %[[VAL_37]], ptr %[[VAL_21]], ptr %[[VAL_20]], ptr %[[VAL_19]]) -// CHECK-GCN: %[[VAL_39_2:.*]] = addrspacecast ptr{{.*}}%[[VAL_39]] to ptr -// CHECK-GCN: %[[VAL_37_2:.*]] = addrspacecast ptr{{.*}}%[[VAL_37]] to ptr -// CHECK-GCN: %[[VAL_21_2:.*]] = addrspacecast ptr{{.*}}%[[VAL_21]] to ptr -// CHECK-GCN: %[[VAL_20_2:.*]] = addrspacecast ptr{{.*}}%[[VAL_20]] to ptr -// CHECK-GCN: %[[VAL_19_2:.*]] = addrspacecast ptr{{.*}}%[[VAL_19]] to ptr -// CHECK-GCN: call void @[[ADD:Add.*]](ptr %[[VAL_39_2]], ptr %[[VAL_37_2]], ptr %[[VAL_21_2]], ptr %[[VAL_20_2]], ptr %[[VAL_19_2]]) -// CHECK: %[[VAL_90:.*]] = load float, ptr{{.*}}%[[VAL_17]], align 4 -// CHECK: %[[VAL_91:.*]] = load float, ptr{{.*}}%[[VAL_18]], align 4 -// CHECK: store float %[[VAL_90]], ptr{{.*}}%[[VAL_39]], align 4 -// CHECK: store float %[[VAL_91]], ptr{{.*}}%[[VAL_37]], align 4 -// CHECK: %[[VAL_92:.*]] = load float, ptr{{.*}}%[[VAL_39]], align 4 -// CHECK-PTX: %[[VAL_93:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_92]], i32 4, i32 31) -// CHECK-GCN: %[[VAL_92_1:.*]] = bitcast float %[[VAL_92]] to i32 -// CHECK-GCN: %[[VAL_93_1:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_92_1:.*]], i32 159) -// CHECK-GCN: %[[VAL_93:.*]] = bitcast i32 %[[VAL_93_1:.*]] to float -// CHECK: store float %[[VAL_93]], ptr{{.*}}%[[VAL_16]], align 4 -// CHECK: %[[VAL_94:.*]] = load float, ptr{{.*}}%[[VAL_37]], align 4 -// CHECK-PTX: %[[VAL_95:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_94]], i32 4, i32 31) -// CHECK-GCN: %[[VAL_94_1:.*]] = bitcast float %[[VAL_94]] to i32 -// CHECK-GCN: %[[VAL_95_1:.*]] = call i32 
@llvm.amdgcn.ds.swizzle(i32 %[[VAL_94_1:.*]], i32 159) -// CHECK-GCN: %[[VAL_95:.*]] = bitcast i32 %[[VAL_95_1:.*]] to float -// CHECK: store float %[[VAL_95]], ptr{{.*}}%[[VAL_15]], align 4 -// CHECK-GCN: %[[VAL_12_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_12]] to ptr -// CHECK: %[[VAL_96:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_14]], i64 0, i64 0 -// CHECK-PTX: store ptr %[[VAL_12]], ptr %[[VAL_96]], align 8 -// CHECK-GCN: store ptr %[[VAL_12_1]], ptr{{.*}}%[[VAL_96]], align 8 -// CHECK-GCN: %[[VAL_13_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_13]] to ptr -// CHECK: %[[VAL_97:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_14]], i64 0, i64 1 -// CHECK-PTX: store ptr %[[VAL_13]], ptr %[[VAL_97]], align 8 -// CHECK-GCN: store ptr %[[VAL_13_1]], ptr{{.*}}%[[VAL_97]], align 8 -// CHECK-PTX: call void @[[ADD:Add.*]](ptr %[[VAL_39]], ptr %[[VAL_37]], ptr %[[VAL_16]], ptr %[[VAL_15]], ptr %[[VAL_14]]) -// CHECK-GCN: %[[VAL_39_3:.*]] = addrspacecast ptr{{.*}}%[[VAL_39]] to ptr -// CHECK-GCN: %[[VAL_37_3:.*]] = addrspacecast ptr{{.*}}%[[VAL_37]] to ptr -// CHECK-GCN: %[[VAL_16_3:.*]] = addrspacecast ptr{{.*}}%[[VAL_16]] to ptr -// CHECK-GCN: %[[VAL_15_3:.*]] = addrspacecast ptr{{.*}}%[[VAL_15]] to ptr -// CHECK-GCN: %[[VAL_14_3:.*]] = addrspacecast ptr{{.*}}%[[VAL_14]] to ptr -// CHECK-GCN: call void @[[ADD:Add.*]](ptr %[[VAL_39_3]], ptr %[[VAL_37_3]], ptr %[[VAL_16_3]], ptr %[[VAL_15_3]], ptr %[[VAL_14_3]]) -// CHECK: %[[VAL_98:.*]] = load float, ptr{{.*}}%[[VAL_12]], align 4 -// CHECK: %[[VAL_99:.*]] = load float, ptr{{.*}}%[[VAL_13]], align 4 -// CHECK: store float %[[VAL_98]], ptr{{.*}}%[[VAL_39]], align 4 -// CHECK: store float %[[VAL_99]], ptr{{.*}}%[[VAL_37]], align 4 -// CHECK: %[[VAL_100:.*]] = load float, ptr{{.*}}%[[VAL_39]], align 4 -// CHECK-PTX: %[[VAL_101:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_100]], i32 2, i32 31) -// CHECK-GCN: %[[VAL_100_1:.*]] = bitcast float %[[VAL_100]] to i32 -// CHECK-GCN: 
%[[VAL_101_1:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_100_1:.*]], i32 95) -// CHECK-GCN: %[[VAL_101:.*]] = bitcast i32 %[[VAL_101_1:.*]] to float -// CHECK: store float %[[VAL_101]], ptr{{.*}}%[[VAL_11]], align 4 -// CHECK: %[[VAL_102:.*]] = load float, ptr{{.*}}%[[VAL_37]], align 4 -// CHECK-PTX: %[[VAL_103:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_102]], i32 2, i32 31) -// CHECK-GCN: %[[VAL_102_1:.*]] = bitcast float %[[VAL_102]] to i32 -// CHECK-GCN: %[[VAL_103_1:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_102_1:.*]], i32 95) -// CHECK-GCN: %[[VAL_103:.*]] = bitcast i32 %[[VAL_103_1:.*]] to float -// CHECK: store float %[[VAL_103]], ptr{{.*}}%[[VAL_10]], align 4 -// CHECK-GCN: %[[VAL_7_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_7]] to ptr -// CHECK: %[[VAL_104:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_9]], i64 0, i64 0 -// CHECK-PTX: store ptr %[[VAL_7]], ptr %[[VAL_104]], align 8 -// CHECK-GCN: store ptr %[[VAL_7_1]], ptr{{.*}}%[[VAL_104]], align 8 -// CHECK-GCN: %[[VAL_8_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_8]] to ptr -// CHECK: %[[VAL_105:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_9]], i64 0, i64 1 -// CHECK-PTX: store ptr %[[VAL_8]], ptr %[[VAL_105]], align 8 -// CHECK-GCN: store ptr %[[VAL_8_1]], ptr{{.*}}%[[VAL_105]], align 8 -// CHECK-PTX: call void @[[ADD:Add.*]](ptr %[[VAL_39]], ptr %[[VAL_37]], ptr %[[VAL_11]], ptr %[[VAL_10]], ptr %[[VAL_9]]) -// CHECK-GCN: %[[VAL_39_4:.*]] = addrspacecast ptr{{.*}}%[[VAL_39]] to ptr -// CHECK-GCN: %[[VAL_37_4:.*]] = addrspacecast ptr{{.*}}%[[VAL_37]] to ptr -// CHECK-GCN: %[[VAL_11_4:.*]] = addrspacecast ptr{{.*}}%[[VAL_11]] to ptr -// CHECK-GCN: %[[VAL_10_4:.*]] = addrspacecast ptr{{.*}}%[[VAL_10]] to ptr -// CHECK-GCN: %[[VAL_9_4:.*]] = addrspacecast ptr{{.*}}%[[VAL_9]] to ptr -// CHECK-GCN: call void @[[ADD:Add.*]](ptr %[[VAL_39_4]], ptr %[[VAL_37_4]], ptr %[[VAL_11_4]], ptr %[[VAL_10_4]], ptr %[[VAL_9_4]]) -// CHECK: %[[VAL_106:.*]] = load 
float, ptr{{.*}}%[[VAL_7]], align 4 -// CHECK: %[[VAL_107:.*]] = load float, ptr{{.*}}%[[VAL_8]], align 4 -// CHECK: store float %[[VAL_106]], ptr{{.*}}%[[VAL_39]], align 4 -// CHECK: store float %[[VAL_107]], ptr{{.*}}%[[VAL_37]], align 4 -// CHECK: %[[VAL_108:.*]] = load float, ptr{{.*}}%[[VAL_39]], align 4 -// CHECK-PTX: %[[VAL_109:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_108]], i32 1, i32 31) -// CHECK-GCN: %[[VAL_108_1:.*]] = bitcast float %[[VAL_108]] to i32 -// CHECK-GCN: %[[VAL_109_1:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_108_1:.*]], i32 63) -// CHECK-GCN: %[[VAL_109:.*]] = bitcast i32 %[[VAL_109_1:.*]] to float -// CHECK: store float %[[VAL_109]], ptr{{.*}}%[[VAL_6]], align 4 -// CHECK: %[[VAL_110:.*]] = load float, ptr{{.*}}%[[VAL_37]], align 4 -// CHECK-PTX: %[[VAL_111:.*]] = call float @llvm.nvvm.shfl.sync.down.f32(i32 -1, float %[[VAL_110]], i32 1, i32 31) -// CHECK-GCN: %[[VAL_110_1:.*]] = bitcast float %[[VAL_110]] to i32 -// CHECK-GCN: %[[VAL_111_1:.*]] = call i32 @llvm.amdgcn.ds.swizzle(i32 %[[VAL_110_1:.*]], i32 63) -// CHECK-GCN: %[[VAL_111:.*]] = bitcast i32 %[[VAL_111_1:.*]] to float -// CHECK: store float %[[VAL_111]], ptr{{.*}}%[[VAL_5]], align 4 -// CHECK-GCN: %[[VAL_2_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_2]] to ptr -// CHECK: %[[VAL_112:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_4]], i64 0, i64 0 -// CHECK-PTX: store ptr %[[VAL_2]], ptr %[[VAL_112]], align 8 -// CHECK-GCN: store ptr %[[VAL_2_1]], ptr{{.*}}%[[VAL_112]], align 8 -// CHECK-GCN: %[[VAL_3_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_3]] to ptr -// CHECK: %[[VAL_113:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_4]], i64 0, i64 1 -// CHECK-PTX: store ptr %[[VAL_3]], ptr %[[VAL_113]], align 8 -// CHECK-GCN: store ptr %[[VAL_3_1]], ptr{{.*}}%[[VAL_113]], align 8 -// CHECK-PTX: call void @[[ADD:Add.*]](ptr %[[VAL_39]], ptr %[[VAL_37]], ptr %[[VAL_6]], ptr %[[VAL_5]], ptr %[[VAL_4]]) -// CHECK-GCN: %[[VAL_39_5:.*]] = 
addrspacecast ptr{{.*}}%[[VAL_39]] to ptr -// CHECK-GCN: %[[VAL_37_5:.*]] = addrspacecast ptr{{.*}}%[[VAL_37]] to ptr -// CHECK-GCN: %[[VAL_6_5:.*]] = addrspacecast ptr{{.*}}%[[VAL_6]] to ptr -// CHECK-GCN: %[[VAL_5_5:.*]] = addrspacecast ptr{{.*}}%[[VAL_5]] to ptr -// CHECK-GCN: %[[VAL_4_5:.*]] = addrspacecast ptr{{.*}}%[[VAL_4]] to ptr -// CHECK-GCN: call void @[[ADD:Add.*]](ptr %[[VAL_39_5]], ptr %[[VAL_37_5]], ptr %[[VAL_6_5]], ptr %[[VAL_5_5]], ptr %[[VAL_4_5]]) -// CHECK: %[[VAL_114:.*]] = load float, ptr{{.*}}%[[VAL_2]], align 4 -// CHECK: %[[VAL_115:.*]] = load float, ptr{{.*}}%[[VAL_3]], align 4 -// CHECK: store float %[[VAL_114]], ptr{{.*}}%[[VAL_39]], align 4 -// CHECK: store float %[[VAL_115]], ptr{{.*}}%[[VAL_37]], align 4 -// CHECK: %[[VAL_116:.*]] = udiv i32 %thread.id.2, 32 -// CHECK: %[[VAL_117:.*]] = icmp ult i32 %thread.id.1, 8 -// CHECK: br i1 %[[VAL_117]], label %thread_in_bounds-true, label %thread_in_bounds-after -// CHECK: thread_in_bounds-after: ; preds = %[[VAL_118:.*]], %[[VAL_60]] -// CHECK: br label %[[VAL_44]] -// CHECK: is_full_tile-true: ; preds = %[[VAL_68]] -// CHECK: store i32 0, ptr{{.*}}%[[VAL_34]], align 4 -// CHECK: br label %[[VAL_119:.*]] -// CHECK: loop2.loop_header: ; preds = %[[VAL_120:.*]], %[[VAL_72]] -// CHECK: %[[VAL_121:.*]] = load i32, ptr{{.*}}%[[VAL_34]], align 4 -// CHECK: %[[VAL_122:.*]] = icmp uge i32 %[[VAL_121]], 512 -// CHECK: br i1 %[[VAL_122]], label %[[VAL_75]], label %[[VAL_120]] -// CHECK: loop2.loop_body: ; preds = %[[VAL_119]] -// CHECK: %[[VAL_123:.*]] = add nuw nsw i32 %[[VAL_121]], 32 -// CHECK: store i32 %[[VAL_123]], ptr{{.*}}%[[VAL_34]], align 4 -// CHECK: %[[VAL_125:.*]] = add i32 %[[VAL_121]], %thread.id.2 -// CHECK: %[[VAL_126:.*]] = add i32 %tile_origin.0, %[[VAL_58]] -// CHECK: %[[VAL_127:.*]] = add i32 %tile_origin.1, %[[VAL_66]] -// CHECK: %[[VAL_128:.*]] = add i32 %tile_origin.2, %[[VAL_125]] -// CHECK: %[[VAL_129:.*]] = getelementptr inbounds [5 x [200 x [300 x float]]], 
ptr{{.*}}%[[VAL_130:.*]], i32 0, i32 %[[VAL_126]], i32 %[[VAL_127]], i32 %[[VAL_128]] -// CHECK: %[[VAL_131:.*]] = load float, ptr{{.*}}%[[VAL_129]], align 4, !invariant.load !{{[0-9]}} -// CHECK: store float %[[VAL_131]], ptr{{.*}}%[[VAL_40]], align 4 -// CHECK: %[[VAL_132:.*]] = getelementptr inbounds [5 x [200 x [300 x float]]], ptr{{.*}}%[[VAL_133:.*]], i32 0, i32 %[[VAL_126]], i32 %[[VAL_127]], i32 %[[VAL_128]] -// CHECK: %[[VAL_134:.*]] = load float, ptr{{.*}}%[[VAL_132]], align 4, !invariant.load !{{[0-9]}} -// CHECK: store float %[[VAL_134]], ptr{{.*}}%[[VAL_38]], align 4 -// CHECK-GCN: %[[VAL_31_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_31]] to ptr -// CHECK: %[[VAL_135:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_33]], i64 0, i64 0 -// CHECK-PTX: store ptr %[[VAL_31]], ptr %[[VAL_135]], align 8 -// CHECK-GCN: store ptr %[[VAL_31_1]], ptr{{.*}}%[[VAL_135]], align 8 -// CHECK-GCN: %[[VAL_32_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_32]] to ptr -// CHECK: %[[VAL_136:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_33]], i64 0, i64 1 -// CHECK-PTX: store ptr %[[VAL_32]], ptr %[[VAL_136]], align 8 -// CHECK-GCN: store ptr %[[VAL_32_1]], ptr{{.*}}%[[VAL_136]], align 8 -// CHECK-PTX: call void @[[ADD:Add.*]](ptr %[[VAL_39]], ptr %[[VAL_37]], ptr %[[VAL_40]], ptr %[[VAL_38]], ptr %[[VAL_33]]) -// CHECK-GCN: %[[VAL_39_6:.*]] = addrspacecast ptr{{.*}}%[[VAL_39]] to ptr -// CHECK-GCN: %[[VAL_37_6:.*]] = addrspacecast ptr{{.*}}%[[VAL_37]] to ptr -// CHECK-GCN: %[[VAL_40_6:.*]] = addrspacecast ptr{{.*}}%[[VAL_40]] to ptr -// CHECK-GCN: %[[VAL_38_6:.*]] = addrspacecast ptr{{.*}}%[[VAL_38]] to ptr -// CHECK-GCN: %[[VAL_33_6:.*]] = addrspacecast ptr{{.*}}%[[VAL_33]] to ptr -// CHECK-GCN: call void @[[ADD:Add.*]](ptr %[[VAL_39_6]], ptr %[[VAL_37_6]], ptr %[[VAL_40_6]], ptr %[[VAL_38_6]], ptr %[[VAL_33_6]]) -// CHECK: %[[VAL_137:.*]] = load float, ptr{{.*}}%[[VAL_31]], align 4 -// CHECK: %[[VAL_138:.*]] = load float, ptr{{.*}}%[[VAL_32]], align 4 -// CHECK: 
store float %[[VAL_137]], ptr{{.*}}%[[VAL_39]], align 4 -// CHECK: store float %[[VAL_138]], ptr{{.*}}%[[VAL_37]], align 4 -// CHECK: br label %[[VAL_119]], !llvm.loop !{{[0-9]}} -// CHECK: loop2.loop_exit: ; preds = %[[VAL_119]] -// CHECK: br label %[[VAL_65]] -// CHECK: is_full_tile-false: ; preds = %[[VAL_68]] -// CHECK: store i32 0, ptr{{.*}}%[[VAL_30]], align 4 -// CHECK: br label %[[VAL_139:.*]] -// CHECK: loop2.loop_header9: ; preds = %[[VAL_140:.*]], %[[VAL_73]] -// CHECK: %[[VAL_141:.*]] = load i32, ptr{{.*}}%[[VAL_30]], align 4 -// CHECK: %[[VAL_142:.*]] = icmp uge i32 %[[VAL_141]], 512 -// CHECK: br i1 %[[VAL_142]], label %[[VAL_74]], label %[[VAL_143:.*]] -// CHECK: loop2.loop_body10: ; preds = %[[VAL_139]] -// CHECK: %[[VAL_144:.*]] = add nuw nsw i32 %[[VAL_141]], 32 -// CHECK: store i32 %[[VAL_144]], ptr{{.*}}%[[VAL_30]], align 4 -// CHECK: %[[VAL_146:.*]] = add i32 %[[VAL_141]], %thread.id.2 -// CHECK: %[[VAL_147:.*]] = icmp ult i32 %[[VAL_146]], %tile_bound.2 -// CHECK: br i1 %[[VAL_147]], label %[[VAL_148:.*]], label %[[VAL_140]] -// CHECK: x_in_tile-after: ; preds = %[[VAL_148]], %[[VAL_143]] -// CHECK: br label %[[VAL_139]], !llvm.loop !{{[0-9]}} -// CHECK: loop2.loop_exit8: ; preds = %[[VAL_139]] -// CHECK: br label %[[VAL_65]] -// CHECK: x_in_tile-true: ; preds = %[[VAL_143]] -// CHECK: %[[VAL_149:.*]] = add i32 %tile_origin.0, %[[VAL_58]] -// CHECK: %[[VAL_150:.*]] = add i32 %tile_origin.1, %[[VAL_66]] -// CHECK: %[[VAL_151:.*]] = add i32 %tile_origin.2, %[[VAL_146]] -// CHECK: %[[VAL_152:.*]] = getelementptr inbounds [5 x [200 x [300 x float]]], ptr{{.*}}%[[VAL_130]], i32 0, i32 %[[VAL_149]], i32 %[[VAL_150]], i32 %[[VAL_151]] -// CHECK: %[[VAL_153:.*]] = load float, ptr{{.*}}%[[VAL_152]], align 4, !invariant.load !{{[0-9]}} -// CHECK: store float %[[VAL_153]], ptr{{.*}}%[[VAL_40]], align 4 -// CHECK: %[[VAL_154:.*]] = getelementptr inbounds [5 x [200 x [300 x float]]], ptr{{.*}}%[[VAL_133]], i32 0, i32 %[[VAL_149]], i32 %[[VAL_150]], i32 
%[[VAL_151]] -// CHECK: %[[VAL_155:.*]] = load float, ptr{{.*}}%[[VAL_154]], align 4, !invariant.load !{{[0-9]}} -// CHECK: store float %[[VAL_155]], ptr{{.*}}%[[VAL_38]], align 4 -// CHECK-GCN: %[[VAL_27_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_27]] to ptr -// CHECK: %[[VAL_156:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_29]], i64 0, i64 0 -// CHECK-PTX: store ptr %[[VAL_27]], ptr %[[VAL_156]], align 8 -// CHECK-GCN: store ptr %[[VAL_27_1]], ptr{{.*}}%[[VAL_156]], align 8 -// CHECK-GCN: %[[VAL_28_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_28]] to ptr -// CHECK: %[[VAL_157:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_29]], i64 0, i64 1 -// CHECK-PTX: store ptr %[[VAL_28]], ptr %[[VAL_157]], align 8 -// CHECK-GCN: store ptr %[[VAL_28_1]], ptr{{.*}}%[[VAL_157]], align 8 -// CHECK-PTX: call void @[[ADD:Add.*]](ptr %[[VAL_39]], ptr %[[VAL_37]], ptr %[[VAL_40]], ptr %[[VAL_38]], ptr %[[VAL_29]]) -// CHECK-GCN: %[[VAL_39_7:.*]] = addrspacecast ptr{{.*}}%[[VAL_39]] to ptr -// CHECK-GCN: %[[VAL_37_7:.*]] = addrspacecast ptr{{.*}}%[[VAL_37]] to ptr -// CHECK-GCN: %[[VAL_40_7:.*]] = addrspacecast ptr{{.*}}%[[VAL_40]] to ptr -// CHECK-GCN: %[[VAL_38_7:.*]] = addrspacecast ptr{{.*}}%[[VAL_38]] to ptr -// CHECK-GCN: %[[VAL_29_7:.*]] = addrspacecast ptr{{.*}}%[[VAL_29]] to ptr -// CHECK-GCN: call void @[[ADD:Add.*]](ptr %[[VAL_39_7]], ptr %[[VAL_37_7]], ptr %[[VAL_40_7]], ptr %[[VAL_38_7]], ptr %[[VAL_29_7]]) -// CHECK: %[[VAL_158:.*]] = load float, ptr{{.*}}%[[VAL_27]], align 4 -// CHECK: %[[VAL_159:.*]] = load float, ptr{{.*}}%[[VAL_28]], align 4 -// CHECK: store float %[[VAL_158]], ptr{{.*}}%[[VAL_39]], align 4 -// CHECK: store float %[[VAL_159]], ptr{{.*}}%[[VAL_37]], align 4 -// CHECK: br label %[[VAL_140]] -// CHECK: thread_in_bounds-true: ; preds = %[[VAL_60]] -// CHECK: %[[VAL_160:.*]] = icmp eq i32 %lane_id, 0 -// CHECK: br i1 %[[VAL_160]], label %[[VAL_161:.*]], label %[[VAL_162:.*]] -// CHECK: intra_warp_reduce_write-after: ; preds = %[[VAL_161]], 
%thread_in_bounds-true -// CHECK-PTX: call void @llvm.nvvm.barrier0() -// CHECK-GCN: call void @llvm.amdgcn.s.barrier() -// CHECK: %[[VAL_163:.*]] = icmp eq i32 %[[VAL_116]], 0 -// CHECK: br i1 %[[VAL_163]], label %[[VAL_164:.*]], label %[[VAL_118]] -// CHECK: inter_warp_reduce-after: ; preds = %[[VAL_165:.*]], %[[VAL_162]] -// CHECK: br label %thread_in_bounds-after -// CHECK: intra_warp_reduce_write-true: ; preds = %thread_in_bounds-true -// CHECK: %[[VAL_166:.*]] = load float, ptr{{.*}}%[[VAL_39]], align 4 -// CHECK: %[[VAL_167:.*]] = getelementptr inbounds [8 x [1 x float]], ptr addrspace(3) @shared_cache, i32 0, i32 %thread.id.1, i32 %[[VAL_116]] -// CHECK: %[[VAL_168:.*]] = addrspacecast ptr addrspace(3) %[[VAL_167]] to ptr -// CHECK: store float %[[VAL_166]], ptr{{.*}}%[[VAL_168]], align 4 -// CHECK: %[[VAL_169:.*]] = load float, ptr{{.*}}%[[VAL_37]], align 4 -// CHECK: %[[VAL_170:.*]] = getelementptr inbounds [8 x [1 x float]], ptr addrspace(3) @shared_cache{{.*}}, i32 0, i32 %thread.id.1, i32 %[[VAL_116]] -// CHECK: %[[VAL_171:.*]] = addrspacecast ptr addrspace(3) %[[VAL_170]] to ptr -// CHECK: store float %[[VAL_169]], ptr{{.*}}%[[VAL_171]], align 4 -// CHECK: br label %[[VAL_162]] -// CHECK: inter_warp_reduce-true: ; preds = %[[VAL_162]] -// CHECK: %[[VAL_172:.*]] = getelementptr inbounds [8 x [1 x float]], ptr addrspace(3) @shared_cache, i32 0, i32 %thread.id.1, i32 %lane_id -// CHECK: %[[VAL_173:.*]] = addrspacecast ptr addrspace(3) %[[VAL_172]] to ptr -// CHECK-GCN: %[[VAL_1_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_1]] to ptr -// CHECK-PTX: store float %[[VAL_46]], ptr %[[VAL_1]], align 4 -// CHECK-GCN: store float %[[VAL_46]], ptr %[[VAL_1_1]], align 4 -// CHECK: %[[VAL_174:.*]] = icmp ult i32 %thread.id.2, 1 -// CHECK-PTX: %[[VAL_175:.*]] = select i1 %[[VAL_174]], ptr %[[VAL_173]], ptr %[[VAL_1]] -// CHECK-GCN: %[[VAL_175:.*]] = select i1 %[[VAL_174]], ptr %[[VAL_173]], ptr %[[VAL_1_1]] -// CHECK: %[[VAL_176:.*]] = getelementptr inbounds [8 x [1 x 
float]], ptr addrspace(3) @shared_cache{{.*}}, i32 0, i32 %thread.id.1, i32 %lane_id -// CHECK: %[[VAL_177:.*]] = addrspacecast ptr addrspace(3) %[[VAL_176]] to ptr -// CHECK-GCN: %[[VAL_0_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_0]] to ptr -// CHECK-PTX: store float %[[VAL_48]], ptr{{.*}}%[[VAL_0]], align 4 -// CHECK-GCN: store float %[[VAL_48]], ptr{{.*}}%[[VAL_0_1]], align 4 -// CHECK: %[[VAL_178:.*]] = icmp ult i32 %thread.id.2, 1 -// CHECK-PTX: %[[VAL_179:.*]] = select i1 %[[VAL_178]], ptr{{.*}}%[[VAL_177]], ptr %[[VAL_0]] -// CHECK-GCN: %[[VAL_179:.*]] = select i1 %[[VAL_178]], ptr{{.*}}%[[VAL_177]], ptr %[[VAL_0_1]] -// CHECK: %[[VAL_180:.*]] = icmp eq i32 %thread.id.2, 0 -// CHECK: br i1 %[[VAL_180]], label %[[VAL_181:.*]], label %[[VAL_165]] -// CHECK: reduction_write_output-after: ; preds = %[[VAL_181]], %[[VAL_164]] -// CHECK: br label %[[VAL_118]] -// CHECK: reduction_write_output-true: ; preds = %[[VAL_164]] -// CHECK: %[[VAL_183:.*]] = add i32 %tile_origin.1, %thread.id.1 -// CHECK: %[[VAL_186:.*]] = getelementptr inbounds [200 x float], ptr{{.*}}%[[VAL_187:.*]], i32 0, i32 %[[VAL_183]] -// CHECK: %[[VAL_188:.*]] = load float, ptr{{.*}}%[[VAL_175]], align 4 -// CHECK: store float %[[VAL_188]], ptr{{.*}}%[[VAL_186]], align 4 -// CHECK: %[[VAL_190:.*]] = add i32 %tile_origin.1, %thread.id.1 -// CHECK: %[[VAL_193:.*]] = getelementptr inbounds [200 x float], ptr{{.*}}%[[VAL_194:.*]], i32 0, i32 %[[VAL_190]] -// CHECK: %[[VAL_195:.*]] = load float, ptr{{.*}}%[[VAL_179]], align 4 -// CHECK: store float %[[VAL_195]], ptr{{.*}}%[[VAL_193]], align 4 -// CHECK: br label %[[VAL_165]] -// CHECK: entry: -// CHECK: %[[VAL_196:.*]] = alloca float, align 4 -// CHECK: %[[VAL_197:.*]] = alloca float, align 4 -// CHECK: %[[VAL_198:.*]] = alloca [2 x ptr], align 8 -// CHECK: %[[VAL_199:.*]] = alloca [2 x ptr], align 8 -// CHECK: %[[VAL_200:.*]] = alloca [2 x ptr], align 8 -// CHECK: %[[VAL_201:.*]] = load float, ptr{{.*}}%[[VAL_202:.*]], align 4 -// CHECK: %[[VAL_203:.*]] 
= load float, ptr{{.*}}%[[VAL_204:.*]], align 4 -// CHECK: %[[VAL_205:.*]] = fadd float %[[VAL_201]], %[[VAL_203]] -// CHECK: store float %[[VAL_205]], ptr{{.*}}%[[VAL_197]], align 4 -// CHECK: %[[VAL_206:.*]] = load float, ptr{{.*}}%[[VAL_207:.*]], align 4 -// CHECK: %[[VAL_208:.*]] = load float, ptr{{.*}}%[[VAL_209:.*]], align 4 -// CHECK: %[[VAL_210:.*]] = fadd float %[[VAL_206]], %[[VAL_208]] -// CHECK: store float %[[VAL_210]], ptr{{.*}}%[[VAL_196]], align 4 -// CHECK-GCN: %[[VAL_197_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_197]] to ptr -// CHECK: %[[VAL_211:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_200]], i64 0, i64 0 -// CHECK-PTX: store ptr %[[VAL_197]], ptr %[[VAL_211]], align 8 -// CHECK-GCN: store ptr %[[VAL_197_1]], ptr{{.*}}%[[VAL_211]], align 8 -// CHECK-GCN: %[[VAL_196_1:.*]] = addrspacecast ptr{{.*}}%[[VAL_196]] to ptr -// CHECK: %[[VAL_212:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_200]], i64 0, i64 1 -// CHECK-PTX: store ptr %[[VAL_196]], ptr %[[VAL_212]], align 8 -// CHECK-GCN: store ptr %[[VAL_196_1]], ptr{{.*}}%[[VAL_212]], align 8 -// CHECK: %[[VAL_213:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_214:.*]], i64 0, i64 0 -// CHECK: %[[VAL_215:.*]] = load ptr, ptr{{.*}}%[[VAL_213]], align 8, !dereferenceable !{{[0-9]*}}, !align !{{[0-9]*}} -// CHECK: %[[VAL_216:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_200]], i64 0, i64 0 -// CHECK: %[[VAL_217:.*]] = load ptr, ptr{{.*}}%[[VAL_216]], align 8, !dereferenceable !{{[0-9]*}}, !align !{{[0-9]*}} -// CHECK: %[[VAL_218:.*]] = load float, ptr{{.*}}%[[VAL_217]], align 4 -// CHECK: store float %[[VAL_218]], ptr{{.*}}%[[VAL_215]], align 4 -// CHECK: %[[VAL_219:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_214]], i64 0, i64 1 -// CHECK: %[[VAL_220:.*]] = load ptr, ptr{{.*}}%[[VAL_219]], align 8, !dereferenceable !{{[0-9]*}}, !align !{{[0-9]*}} -// CHECK: %[[VAL_221:.*]] = getelementptr inbounds [2 x ptr], ptr{{.*}}%[[VAL_200]], i64 0, i64 1 -// 
CHECK: %[[VAL_222:.*]] = load ptr, ptr{{.*}}%[[VAL_221]], align 8, !dereferenceable !{{[0-9]*}}, !align !{{[0-9]*}} -// CHECK: %[[VAL_223:.*]] = load float, ptr{{.*}}%[[VAL_222]], align 4 -// CHECK: store float %[[VAL_223]], ptr{{.*}}%[[VAL_220]], align 4 -// CHECK: ret void - diff --git a/third_party/xla/xla/service/gpu/tests/reduction_vectorization_sm_all.hlo b/third_party/xla/xla/service/gpu/tests/reduction_vectorization_sm_all.hlo deleted file mode 100644 index baeb614b18d6e1..00000000000000 --- a/third_party/xla/xla/service/gpu/tests/reduction_vectorization_sm_all.hlo +++ /dev/null @@ -1,210 +0,0 @@ -// RUN: hlo-opt %s --platform=gpu --xla_gpu_mlir_emitter_level=0 --stage=ptx --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb --split-input-file | FileCheck %s -// RUN: hlo-opt %s --platform=gpu --xla_gpu_mlir_emitter_level=0 --stage=ptx --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/p100.txtpb --split-input-file | FileCheck %s --check-prefixes=CHECK-SM60 -// RUN: hlo-opt %s --platform=gpu --xla_gpu_mlir_emitter_level=0 --stage=ptx --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/v100.txtpb --split-input-file | FileCheck %s --check-prefixes=CHECK-SM70 -// RUN: hlo-opt %s --platform=gpu --xla_gpu_mlir_emitter_level=0 --stage=ptx --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/a6000.txtpb --split-input-file | FileCheck %s --check-prefixes=CHECK-SM86 - -// CHECK-LABEL: .entry wrapped_reduce_odd_row -// CHECK-NOT: ld.global.nc.v2.f32 -// CHECK-NOT: ld.global.nc.v4.f32 -// CHECK-NOT: ld.global.nc.u64 -// CHECK-NOT: ld.global.u64 - -HloModule ReduceOddRowSize, is_scheduled=true - -%max_ { - %x = f32[] parameter(0) - %y = f32[] parameter(1) - ROOT %maximum.7 = f32[] maximum(%x, %y) -} - -%fused_computation { - %param_0.1 = f32[5,4071]{1,0} parameter(0) - %param_1 = f32[] parameter(1) - ROOT %reduce.odd_row.1 = f32[5]{0} reduce(f32[5,4071]{1,0} %param_0.1, f32[] %param_1), 
dimensions={1}, to_apply=%max_ -} - -ENTRY %main { - %param_0 = f32[5,4071] parameter(0) - %constant.3 = f32[] constant(0) - ROOT %wrapped_reduce.odd_row = f32[5]{0} fusion(f32[5,4071]{1,0} %param_0, f32[] %constant.3), kind=kInput, calls=%fused_computation -} - -// ----- - -// CHECK-SM86-LABEL: .entry wrapped_reduce_small_row -// CHECK-SM86: .reqntid 256, 1, 1 - -HloModule ReduceSmallRow, is_scheduled=true - -addition { - x = f32[] parameter(0) - y = f32[] parameter(1) - ROOT out = add(%x, %y) -} - -%fused_computation { - %param_0 = f32[700000,32]{1,0} parameter(0) - %param_1 = f32[] parameter(1) - ROOT %reduce_small_row.1 = f32[700000]{0} reduce(f32[700000,32]{1,0} %param_0, f32[] %param_1), dimensions={1}, to_apply=%addition -} - -ENTRY main { - p = f32[700000,32] parameter(0) - zero = f32[] constant(0) - ROOT %wrapped_reduce_small_row = f32[700000]{0} fusion(f32[700000,32]{1,0} %p, f32[] %zero), kind=kInput, calls=%fused_computation -} - -// ----- - -// CHECK-LABEL: .entry wrapped_reduce_sine -// CHECK-COUNT-7: ld.global.nc.v2.f32 - -HloModule DisableSin, is_scheduled=true - -%add_float { - %x = f32[] parameter(0) - %y = f32[] parameter(1) - ROOT %add.17 = f32[] add(f32[] %x, f32[] %y) -} - -%fused_computation { - %param_0 = f32[5,3584]{1,0} parameter(0) - ROOT %sine.1 = f32[5,3584]{1,0} sine(f32[5,3584]{1,0} %param_0) -} - -%fused_computation.1 { - %param_0.1 = f32[5,3584]{1,0} parameter(0) - %param_1 = f32[] parameter(1) - ROOT %reduce.sine.1 = f32[5]{0} reduce(f32[5,3584]{1,0} %param_0.1, f32[] %param_1), dimensions={1}, to_apply=%add_float -} - -ENTRY %main { - %arg0.1 = f32[5,3584] parameter(0) - %wrapped_sine = f32[5,3584]{1,0} fusion(f32[5,3584]{1,0} %arg0.1), kind=kLoop, calls=%fused_computation - %constant.0 = f32[] constant(0) - ROOT %wrapped_reduce.sine = f32[5]{0} fusion(f32[5,3584]{1,0} %wrapped_sine, f32[] %constant.0), kind=kInput, calls=%fused_computation.1 -} - -// ----- - -// SM dependent tests - -// CHECK-SM60: .entry wrapped_exp -// 
CHECK-SM60-LABEL: .entry wrapped_reduce_exp -// CHECK-SM60-COUNT-8: ld.global.nc.v2.f32 - -// CHECK-SM70: .entry wrapped_exp -// CHECK-SM70-LABEL: .entry wrapped_reduce_exp -// CHECK-SM70-COUNT-8: ld.global.nc.v2.f32 - -HloModule Exp, is_scheduled=true - -%add_float { - %x = f32[] parameter(0) - %y = f32[] parameter(1) - ROOT %add.17 = f32[] add(f32[] %x, f32[] %y) -} - -%fused_computation { - %param_0 = f32[5,3584]{1,0} parameter(0) - ROOT %exp.1 = f32[5,3584]{1,0} exponential(f32[5,3584]{1,0} %param_0) -} - -%fused_computation.1 { - %param_0.1 = f32[5,3584]{1,0} parameter(0) - %param_1 = f32[] parameter(1) - ROOT %reduce.exp.1 = f32[5]{0} reduce(f32[5,3584]{1,0} %param_0.1, f32[] %param_1), dimensions={1}, to_apply=%add_float -} - -ENTRY %main { - %arg0.1 = f32[5,3584] parameter(0) - %wrapped_exp = f32[5,3584]{1,0} fusion(f32[5,3584]{1,0} %arg0.1), kind=kLoop, calls=%fused_computation - %constant.0 = f32[] constant(0) - ROOT %wrapped_reduce.exp = f32[5]{0} fusion(f32[5,3584]{1,0} %wrapped_exp, f32[] %constant.0), kind=kInput, calls=%fused_computation.1 -} - -// ----- - -HloModule ReduceTileFit, is_scheduled=true - -// CHECK-SM60-LABEL: .entry wrapped_reduce_tile_fit -// CHECK-SM60-COUNT-8: ld.global.nc.v2.f32 - -// CHECK-SM70-LABEL: .entry wrapped_reduce_tile_fit -// CHECK-SM70-COUNT-4: ld.global.nc.v2.f32 - -%max_ { - %x = f32[] parameter(0) - %y = f32[] parameter(1) - ROOT %maximum.7 = f32[] maximum(f32[] %x, f32[] %y) -} - -%fused_computation { - %param_0.1 = f32[5,3584]{1,0} parameter(0) - %param_1 = f32[] parameter(1) - ROOT %reduce.tile_fit.1 = f32[5]{0} reduce(f32[5,3584]{1,0} %param_0.1, f32[] %param_1), dimensions={1}, to_apply=%max_ -} - -ENTRY %main { - %param_0 = f32[5,3584] parameter(0) - %constant.3 = f32[] constant(0) - ROOT %wrapped_reduce.tile_fit = f32[5]{0} fusion(f32[5,3584]{1,0} %param_0, f32[] %constant.3), kind=kInput, calls=%fused_computation -} - -// ----- - -HloModule ReducePower2, is_scheduled=true - -// CHECK-SM60-LABEL: .entry 
wrapped_reduce_pow_2 -// CHECK-SM60-COUNT-4: ld.global.nc.v2.f32 - -// CHECK-SM70-LABEL: .entry wrapped_reduce_pow_2 -// CHECK-SM70-COUNT-4: ld.global.nc.v2.f32 - -%max_ { - %x = f32[] parameter(0) - %y = f32[] parameter(1) - ROOT %maximum.7 = f32[] maximum(f32[] %x, f32[] %y) -} - -%fused_computation { - %param_0.1 = f32[5,4096]{1,0} parameter(0) - %param_1 = f32[] parameter(1) - ROOT %reduce.pow_2.1 = f32[5]{0} reduce(f32[5,4096]{1,0} %param_0.1, f32[] %param_1), dimensions={1}, to_apply=%max_ -} - -ENTRY %main { - %param_0 = f32[5,4096] parameter(0) - %constant.3 = f32[] constant(0) - ROOT %wrapped_reduce.pow_2 = f32[5]{0} fusion(f32[5,4096]{1,0} %param_0, f32[] %constant.3), kind=kInput, calls=%fused_computation -} - -// ----- - -HloModule ReduceEvenColumns, is_scheduled=true - -// CHECK-SM60-LABEL: .entry wrapped_reduce_even_col -// CHECK-SM60-NOT: ld.global.nc.f32 -// CHECK-SM60-COUNT-8: ld.global.nc.f32 - -// CHECK-SM70-LABEL: .entry wrapped_reduce_even_col -// CHECK-SM70-COUNT-2: ld.global.nc.v2.f32 -// CHECK-SM70-COUNT-2: ld.global.nc.v2.f32 - -%max_ { - %x = f32[] parameter(0) - %y = f32[] parameter(1) - ROOT %maximum.7 = f32[] maximum(f32[] %x, f32[] %y) -} - -%fused_computation { - %param_0.1 = f32[5,4070]{1,0} parameter(0) - %param_1 = f32[] parameter(1) - ROOT %reduce.even_col.1 = f32[5]{0} reduce(f32[5,4070]{1,0} %param_0.1, f32[] %param_1), dimensions={1}, to_apply=%max_ -} - -ENTRY %main { - %param_0 = f32[5,4070] parameter(0) - %constant.3 = f32[] constant(0) - ROOT %wrapped_reduce.even_col = f32[5]{0} fusion(f32[5,4070]{1,0} %param_0, f32[] %constant.3), kind=kInput, calls=%fused_computation -} diff --git a/third_party/xla/xla/service/gpu/tests/reduction_vectorization_test.cc b/third_party/xla/xla/service/gpu/tests/reduction_vectorization_test.cc index 680391c2fa7db6..7d78d1ccb12a16 100644 --- a/third_party/xla/xla/service/gpu/tests/reduction_vectorization_test.cc +++ b/third_party/xla/xla/service/gpu/tests/reduction_vectorization_test.cc @@ 
-106,60 +106,60 @@ CHECK: st.global.v2.f32 EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); } -TEST_F(ReductionVectorizationTest, NoVectorizationForBlockSmallerThanWarpSize) { - if (GetDebugOptionsForTest().xla_gpu_mlir_emitter_level() >= 4) { - GTEST_SKIP() << "MLIR emitters can vectorize this"; - } - const char* hlo_text = R"( -HloModule SlowModule - -%search_fn (x: f32[], y: f32[]) -> f32[] { - %x = f32[] parameter(0) - %y = f32[] parameter(1) - ROOT %add0 = f32[] add(f32[] %x, f32[] %y) -} - -ENTRY %fused_computation.371 (param_0: f32[6400,4,8,32]) -> f32[6400,4,8] { - %param_0 = f32[6400,4,8,32]{3,2,1,0} parameter(0) - %constant_0 = f32[] constant(0.0) - ROOT %reduce.277 = f32[6400,4,8]{2,1,0} reduce(f32[6400,4,8,32]{3,2,1,0} %param_0, f32[] %constant_0), dimensions={3}, to_apply=%search_fn -} -)"; - - std::string expected_optimized_llvm_ir = R"( -CHECK: %[[thread_id:.*]] = tail call i32 X_THREAD -CHECK: %[[masked_thread_id:.*]] = and i32 %[[thread_id]], 31 -// Verify that there is no comparison masking half the warp. -CHECK-NOT: icmp ult i32 %[[masked_thread_id]], 16 -// Verify that we only do one warp reducton by checking that there are 6 -// shfl.sync corresponding to 1 declaration and 5 shuffle instructions. The -// second warp reduction was originally produced for inter-warp reduction -// which we have now optimized away. -CHECK-COUNT-6: SHUFFLE -CHECK-NOT: SHUFFLE -)"; - - expected_optimized_llvm_ir = absl::StrReplaceAll( - expected_optimized_llvm_ir, - {{"X_THREAD", is_built_with_rocm_ ? "@llvm.amdgcn.workitem.id.x" - : "@llvm.nvvm.read.ptx.sreg.tid.x"}, - {"SHUFFLE", is_built_with_rocm_ ? "@llvm.amdgcn.ds.swizzle" - : "llvm.nvvm.shfl.sync.down.f32"}}); - - CompileAndVerifyIr(hlo_text, expected_optimized_llvm_ir, true); - - // Check that there is a single scalar load. 
- const char* expected_ptx = R"( -CHECK: ld.global.nc.f32 -CHECK: shfl.sync.down -CHECK-NOT: ld.global.nc.f32 -CHECK-NOT: ld.global.v2.f32 -)"; - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, - ParseAndReturnVerifiedModule(hlo_text)); - CompileAndOptionallyVerifyPtx(std::move(optimized_module), expected_ptx); - EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); -} +// TEST_F(ReductionVectorizationTest, NoVectorizationForBlockSmallerThanWarpSize) { +// if (GetDebugOptionsForTest().xla_gpu_mlir_emitter_level() >= 4) { +// GTEST_SKIP() << "MLIR emitters can vectorize this"; +// } +// const char* hlo_text = R"( +// HloModule SlowModule + +// %search_fn (x: f32[], y: f32[]) -> f32[] { +// %x = f32[] parameter(0) +// %y = f32[] parameter(1) +// ROOT %add0 = f32[] add(f32[] %x, f32[] %y) +// } + +// ENTRY %fused_computation.371 (param_0: f32[6400,4,8,32]) -> f32[6400,4,8] { +// %param_0 = f32[6400,4,8,32]{3,2,1,0} parameter(0) +// %constant_0 = f32[] constant(0.0) +// ROOT %reduce.277 = f32[6400,4,8]{2,1,0} reduce(f32[6400,4,8,32]{3,2,1,0} %param_0, f32[] %constant_0), dimensions={3}, to_apply=%search_fn +// } +// )"; + +// std::string expected_optimized_llvm_ir = R"( +// CHECK: %[[thread_id:.*]] = tail call i32 X_THREAD +// CHECK: %[[masked_thread_id:.*]] = and i32 %[[thread_id]], 31 +// // Verify that there is no comparison masking half the warp. +// CHECK-NOT: icmp ult i32 %[[masked_thread_id]], 16 +// // Verify that we only do one warp reducton by checking that there are 6 +// // shfl.sync corresponding to 1 declaration and 5 shuffle instructions. The +// // second warp reduction was originally produced for inter-warp reduction +// // which we have now optimized away. +// CHECK-COUNT-6: SHUFFLE +// CHECK-NOT: SHUFFLE +// )"; + +// expected_optimized_llvm_ir = absl::StrReplaceAll( +// expected_optimized_llvm_ir, +// {{"X_THREAD", is_built_with_rocm_ ? 
"@llvm.amdgcn.workitem.id.x" +// : "@llvm.nvvm.read.ptx.sreg.tid.x"}, +// {"SHUFFLE", is_built_with_rocm_ ? "@llvm.amdgcn.ds.swizzle" +// : "llvm.nvvm.shfl.sync.down.f32"}}); + +// CompileAndVerifyIr(hlo_text, expected_optimized_llvm_ir, true); + +// // Check that there is a single scalar load. +// const char* expected_ptx = R"( +// CHECK: ld.global.nc.f32 +// CHECK: shfl.sync.down +// CHECK-NOT: ld.global.nc.f32 +// CHECK-NOT: ld.global.v2.f32 +// )"; +// TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr optimized_module, +// ParseAndReturnVerifiedModule(hlo_text)); +// CompileAndOptionallyVerifyPtx(std::move(optimized_module), expected_ptx); +// EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{1e-5, 1e-5})); +// } } // namespace } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/tests/scatter.hlo b/third_party/xla/xla/service/gpu/tests/scatter.hlo deleted file mode 100644 index 21dfb77611ef0b..00000000000000 --- a/third_party/xla/xla/service/gpu/tests/scatter.hlo +++ /dev/null @@ -1,300 +0,0 @@ -// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py - -// RUN: hlo-opt %s --xla_gpu_mlir_emitter_level=0 --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb --split-input-file | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %s - -// CHECK-LABEL: ModuleID = 'TensorFlowScatterV1' -// CHECK: %[[VAL_0:.*]] = alloca i32, align 4 -// CHECK-PTX: %[[VAL_1:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x -// CHECK-GCN: %[[VAL_1:.*]] = call i32 @llvm.amdgcn.workgroup.id.x -// CHECK-PTX: %[[VAL_2:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x -// CHECK-GCN: %[[VAL_2:.*]] = call i32 @llvm.amdgcn.workitem.id.x -// CHECK: %[[VAL_3:.*]] = mul nuw nsw i32 %[[VAL_1]], 6 -// CHECK: %[[VAL_4:.*]] = add nuw nsw i32 %[[VAL_3]], %[[VAL_2]] -// CHECK: %[[VAL_5:.*]] = icmp ult i32 %[[VAL_4]], 6 -// CHECK: call void @llvm.assume(i1 %[[VAL_5]]) -// CHECK: %[[VAL_6:.*]] = add nuw nsw i32 
%[[VAL_4]], 0 -// CHECK: %[[VAL_7:.*]] = udiv i32 %[[VAL_6]], 1 -// CHECK: %[[VAL_8:.*]] = urem i32 %[[VAL_7]], 3 -// CHECK: %[[VAL_9:.*]] = udiv i32 %[[VAL_6]], 3 -// CHECK: %[[VAL_10:.*]] = urem i32 %[[VAL_9]], 1 -// CHECK: %[[VAL_11:.*]] = udiv i32 %[[VAL_6]], 3 -// CHECK: %[[VAL_12:.*]] = icmp ult i32 %[[VAL_4]], 6 -// CHECK: br i1 %[[VAL_12]], label %[[VAL_13:.*]], label %[[VAL_14:.*]] -// CHECK: scatter_TensorFlowScatterV1.in_bounds-after: ; preds = %[[VAL_15:.*]], %[[VAL_16:.*]] -// CHECK: ret void -// CHECK: scatter_TensorFlowScatterV1.in_bounds-true: ; preds = %[[VAL_16]] -// CHECK: %[[VAL_17:.*]] = getelementptr inbounds [2 x [1 x i32]], ptr %[[VAL_18:.*]], i32 0, i32 %[[VAL_11]], i32 0 -// CHECK: %[[VAL_19:.*]] = load i32, ptr %[[VAL_17]], align 4, !invariant.load !4 -// CHECK: %[[VAL_20:.*]] = add i32 %[[VAL_10]], %[[VAL_19]] -// CHECK: %[[VAL_21:.*]] = icmp ult i32 %[[VAL_19]], 3 -// CHECK: %[[VAL_22:.*]] = and i1 true, %[[VAL_21]] -// CHECK: br i1 %[[VAL_22]], label %[[VAL_23:.*]], label %[[VAL_15]] -// CHECK: scatter.in_bounds-after: ; preds = %[[VAL_23]], %[[VAL_13]] -// CHECK: br label %[[VAL_14]] -// CHECK: scatter.in_bounds-true: ; preds = %[[VAL_13]] -// CHECK: %[[VAL_24:.*]] = getelementptr inbounds [3 x [3 x i32]], ptr %[[VAL_25:.*]], i32 0, i32 %[[VAL_20]], i32 %[[VAL_8]] -// CHECK: %[[VAL_26:.*]] = getelementptr i32, ptr %[[VAL_27:.*]], i32 %[[VAL_4]] -// CHECK: %[[VAL_28:.*]] = getelementptr inbounds i32, ptr %[[VAL_26]], i32 0 -// CHECK: %[[VAL_29:.*]] = load i32, ptr %[[VAL_28]], align 4, !invariant.load !4 -// CHECK: store i32 %[[VAL_29]], ptr %[[VAL_0]], align 4 -// CHECK: %[[VAL_30:.*]] = load i32, ptr %[[VAL_0]], align 4 -// CHECK: store atomic i32 %[[VAL_30]], ptr %[[VAL_24]] unordered, align 4 -// CHECK: br label %[[VAL_15]] - -HloModule TensorFlowScatterV1, is_scheduled=true - -update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { - lhs = s32[] parameter(0) - ROOT rhs = s32[] parameter(1) -} - -fused_computation { - operand = s32[3,3] 
parameter(0) - indices = s32[2,1] parameter(1) - updates = s32[2,1,3] parameter(2) - ROOT scatter_TensorFlowScatterV1 = s32[3,3] scatter(operand, indices, updates), - to_apply=update_s32, - update_window_dims={1,2}, - inserted_window_dims={}, - scatter_dims_to_operand_dims={0}, - index_vector_dim=1 -} - -ENTRY main { - p0 = s32[3,3] parameter(0) - p1 = s32[2,1] parameter(1) - p2 = s32[2,1,3] parameter(2) - ROOT wrapped_scatter = s32[3,3] fusion(p0, p1, p2), kind=kInput, calls=fused_computation -} - - -// ----- - -// CHECK-LABEL: ModuleID = 'TensorFlowScatter_Mul' -// CHECK: %[[VAL_0:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_1:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_2:.*]] = alloca i32, align 4 -// CHECK-PTX: %[[VAL_3:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x -// CHECK-GCN: %[[VAL_3:.*]] = call i32 @llvm.amdgcn.workgroup.id.x -// CHECK-PTX: %[[VAL_4:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x -// CHECK-GCN: %[[VAL_4:.*]] = call i32 @llvm.amdgcn.workitem.id.x -// CHECK: %[[VAL_5:.*]] = mul nuw nsw i32 %[[VAL_3]], 6 -// CHECK: %[[VAL_6:.*]] = add nuw nsw i32 %[[VAL_5]], %[[VAL_4]] -// CHECK: %[[VAL_7:.*]] = icmp ult i32 %[[VAL_6]], 6 -// CHECK: call void @llvm.assume(i1 %[[VAL_7]]) -// CHECK: %[[VAL_8:.*]] = add nuw nsw i32 %[[VAL_6]], 0 -// CHECK: %[[VAL_9:.*]] = udiv i32 %[[VAL_8]], 1 -// CHECK: %[[VAL_10:.*]] = urem i32 %[[VAL_9]], 3 -// CHECK: %[[VAL_11:.*]] = udiv i32 %[[VAL_8]], 3 -// CHECK: %[[VAL_12:.*]] = urem i32 %[[VAL_11]], 1 -// CHECK: %[[VAL_13:.*]] = udiv i32 %[[VAL_8]], 3 -// CHECK: %[[VAL_14:.*]] = icmp ult i32 %[[VAL_6]], 6 -// CHECK: br i1 %[[VAL_14]], label %[[VAL_15:.*]], label %[[VAL_16:.*]] -// CHECK: scatter_TensorFlowScatter_Mul.in_bounds-after: ; preds = %[[VAL_17:.*]], %[[VAL_18:.*]] -// CHECK: ret void -// CHECK: scatter_TensorFlowScatter_Mul.in_bounds-true: ; preds = %[[VAL_18]] -// CHECK: %[[VAL_19:.*]] = getelementptr inbounds [2 x [1 x i32]], ptr %[[VAL_20:.*]], i32 0, i32 %[[VAL_13]], i32 0 -// CHECK: %[[VAL_21:.*]] = 
load i32, ptr %[[VAL_19]], align 4, !invariant.load !4 -// CHECK: %[[VAL_22:.*]] = add i32 %[[VAL_12]], %[[VAL_21]] -// CHECK: %[[VAL_23:.*]] = icmp ult i32 %[[VAL_21]], 3 -// CHECK: %[[VAL_24:.*]] = and i1 true, %[[VAL_23]] -// CHECK: br i1 %[[VAL_24]], label %[[VAL_25:.*]], label %[[VAL_17]] -// CHECK: scatter.in_bounds-after: ; preds = %[[VAL_26:.*]], %[[VAL_15]] -// CHECK: br label %[[VAL_16]] -// CHECK: scatter.in_bounds-true: ; preds = %[[VAL_15]] -// CHECK: %[[VAL_27:.*]] = getelementptr inbounds [3 x [3 x i32]], ptr %[[VAL_28:.*]], i32 0, i32 %[[VAL_22]], i32 %[[VAL_10]] -// CHECK: %[[VAL_29:.*]] = getelementptr i32, ptr %[[VAL_30:.*]], i32 %[[VAL_6]] -// CHECK: %[[VAL_31:.*]] = getelementptr inbounds i32, ptr %[[VAL_29]], i32 0 -// CHECK: %[[VAL_32:.*]] = load i32, ptr %[[VAL_31]], align 4, !invariant.load !4 -// CHECK: store i32 %[[VAL_32]], ptr %[[VAL_2]], align 4 -// CHECK: %[[VAL_33:.*]] = load i32, ptr %[[VAL_2]], align 4 -// CHECK: %[[VAL_34:.*]] = load i32, ptr %[[VAL_27]], align 4 -// CHECK: store i32 %[[VAL_34]], ptr %[[VAL_1]], align 4 -// CHECK: br label %[[VAL_35:.*]] -// CHECK: atomic_op_loop_exit: ; preds = %[[VAL_36:.*]], %[[VAL_35]] -// CHECK: br label %[[VAL_17]] -// CHECK: atomic_op_loop_body: ; preds = %[[VAL_36]], %[[VAL_25]] -// CHECK: %[[VAL_37:.*]] = load i32, ptr %[[VAL_1]], align 4 -// CHECK: store i32 %[[VAL_37]], ptr %[[VAL_0]], align 4 -// CHECK: call void @mul_s32_{{.*}}(ptr %[[VAL_0]], ptr %[[VAL_2]], ptr %[[VAL_0]]) -// CHECK: %[[VAL_38:.*]] = load i32, ptr %[[VAL_0]], align 4 -// CHECK: %[[VAL_39:.*]] = icmp eq i32 %[[VAL_37]], %[[VAL_38]] -// CHECK: br i1 %[[VAL_39]], label %[[VAL_26]], label %[[VAL_36]] -// CHECK: atomic_op_loop_cas: ; preds = %[[VAL_35]] -// CHECK: %[[VAL_40:.*]] = cmpxchg ptr %[[VAL_27]], i32 %[[VAL_37]], i32 %[[VAL_38]] seq_cst seq_cst, align 4 -// CHECK: %[[VAL_41:.*]] = extractvalue { i32, i1 } %[[VAL_40]], 0 -// CHECK: store i32 %[[VAL_41]], ptr %[[VAL_1]], align 4 -// CHECK: %[[VAL_42:.*]] = 
extractvalue { i32, i1 } %[[VAL_40]], 1 -// CHECK: br i1 %[[VAL_42]], label %[[VAL_26]], label %[[VAL_35]] -// CHECK: entry: -// CHECK: %[[VAL_43:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_44:.*]] = load i32, ptr %[[VAL_45:.*]], align 4 -// CHECK: %[[VAL_46:.*]] = load i32, ptr %[[VAL_47:.*]], align 4 -// CHECK: %[[VAL_48:.*]] = mul i32 %[[VAL_44]], %[[VAL_46]] -// CHECK: store i32 %[[VAL_48]], ptr %[[VAL_43]], align 4 -// CHECK: %[[VAL_49:.*]] = load i32, ptr %[[VAL_43]], align 4 -// CHECK: store i32 %[[VAL_49]], ptr %[[VAL_50:.*]], align 4 -// CHECK: ret void - - -HloModule TensorFlowScatter_Mul, is_scheduled=true - -mul_s32 (lhs: s32[], rhs: s32[]) -> s32[] { - lhs = s32[] parameter(0) - rhs = s32[] parameter(1) - ROOT mul = s32[] multiply(s32[] lhs, s32[] rhs) -} - -fused_computation { - operand = s32[3,3] parameter(0) - indices = s32[2,1] parameter(1) - updates = s32[2,1,3] parameter(2) - ROOT scatter_TensorFlowScatter_Mul = s32[3,3] scatter(operand, indices, updates), - to_apply=mul_s32, - update_window_dims={1,2}, - inserted_window_dims={}, - scatter_dims_to_operand_dims={0}, - index_vector_dim=1 -} - -ENTRY main { - p0 = s32[3,3] parameter(0) - p1 = s32[2,1] parameter(1) - p2 = s32[2,1,3] parameter(2) - ROOT wrapped_scatter = s32[3,3] fusion(p0, p1, p2), kind=kInput, calls=fused_computation -} - -// ----- - - -// CHECK-LABEL: ModuleID = 'ScalarUpdate' -// CHECK: %[[VAL_0:.*]] = alloca i32, align 4 -// CHECK-PTX: %[[VAL_1:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x -// CHECK-GCN: %[[VAL_1:.*]] = call i32 @llvm.amdgcn.workgroup.id.x -// CHECK-PTX: %[[VAL_2:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x -// CHECK-GCN: %[[VAL_2:.*]] = call i32 @llvm.amdgcn.workitem.id.x -// CHECK: %[[VAL_3:.*]] = mul nuw nsw i32 %[[VAL_1]], 1 -// CHECK: %[[VAL_4:.*]] = add nuw nsw i32 %[[VAL_3]], %[[VAL_2]] -// CHECK: %[[VAL_5:.*]] = icmp ult i32 %[[VAL_4]], 1 -// CHECK: call void @llvm.assume(i1 %[[VAL_5]]) -// CHECK: %[[VAL_6:.*]] = add nuw nsw i32 %[[VAL_4]], 0 -// 
CHECK: %[[VAL_7:.*]] = udiv i32 %[[VAL_6]], 1 -// CHECK: %[[VAL_8:.*]] = urem i32 %[[VAL_7]], 1 -// CHECK: %[[VAL_9:.*]] = udiv i32 %[[VAL_6]], 1 -// CHECK: %[[VAL_10:.*]] = icmp ult i32 %[[VAL_4]], 1 -// CHECK: br i1 %[[VAL_10]], label %[[VAL_11:.*]], label %[[VAL_12:.*]] -// CHECK: scatter.in_bounds-after: ; preds = %[[VAL_13:.*]], %[[VAL_14:.*]] -// CHECK: ret void -// CHECK: scatter.in_bounds-true: ; preds = %[[VAL_14]] -// CHECK: %[[VAL_15:.*]] = getelementptr inbounds [1 x [1 x i32]], ptr %[[VAL_16:.*]], i32 0, i32 0, i32 0 -// CHECK: %[[VAL_17:.*]] = load i32, ptr %[[VAL_15]], align 4, !invariant.load !3 -// CHECK: %[[VAL_18:.*]] = add i32 %[[VAL_8]], %[[VAL_17]] -// CHECK: %[[VAL_19:.*]] = icmp ult i32 %[[VAL_17]], 4 -// CHECK: %[[VAL_20:.*]] = and i1 true, %[[VAL_19]] -// CHECK: br i1 %[[VAL_20]], label %[[VAL_21:.*]], label %[[VAL_13]] -// CHECK: scatter.in_bounds-after3: ; preds = %[[VAL_21]], %[[VAL_11]] -// CHECK: br label %[[VAL_12]] -// CHECK: scatter.in_bounds-true2: ; preds = %[[VAL_11]] -// CHECK: %[[VAL_22:.*]] = getelementptr inbounds [4 x i32], ptr %[[VAL_23:.*]], i32 0, i32 %[[VAL_18]] -// CHECK: %[[VAL_24:.*]] = getelementptr i32, ptr %[[VAL_25:.*]], i32 %[[VAL_4]] -// CHECK: %[[VAL_26:.*]] = getelementptr inbounds i32, ptr %[[VAL_24]], i32 0 -// CHECK: %[[VAL_27:.*]] = load i32, ptr %[[VAL_26]], align 4, !invariant.load !3 -// CHECK: store i32 %[[VAL_27]], ptr %[[VAL_0]], align 4 -// CHECK: %[[VAL_28:.*]] = load i32, ptr %[[VAL_0]], align 4 -// CHECK: store atomic i32 %[[VAL_28]], ptr %[[VAL_22]] unordered, align 4 -// CHECK: br label %[[VAL_13]] - -HloModule ScalarUpdate, is_scheduled=true - -update_s32 (lhs: s32[], rhs: s32[]) -> s32[] { - lhs = s32[] parameter(0) - ROOT rhs = s32[] parameter(1) -} - -fused_computation { - operand = s32[4]{0} parameter(0) - index = s32[1,1] parameter(1) - updates = s32[1,1] parameter(2) - ROOT scatter = s32[4]{0} scatter(operand, index, updates), - to_apply=update_s32, - update_window_dims={1}, - 
inserted_window_dims={}, - scatter_dims_to_operand_dims={0}, - index_vector_dim=1 -} - -ENTRY main { - p0 = s32[4]{0} parameter(0) - p1 = s32[1,1] parameter(1) - p2 = s32[1,1] parameter(2) - ROOT wrapped_scatter = s32[4] fusion(p0, p1, p2), kind=kInput, calls=fused_computation -} - -// ----- - - -// CHECK-LABEL: ModuleID = 'TensorFlowScatter_Add' -// CHECK: %[[VAL_0:.*]] = alloca half, align 2 -// CHECK-PTX: %[[VAL_1:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x -// CHECK-GCN: %[[VAL_1:.*]] = call i32 @llvm.amdgcn.workgroup.id.x -// CHECK-PTX: %[[VAL_2:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x -// CHECK-GCN: %[[VAL_2:.*]] = call i32 @llvm.amdgcn.workitem.id.x -// CHECK: %[[VAL_3:.*]] = mul nuw nsw i32 %[[VAL_1]], 6 -// CHECK: %[[VAL_4:.*]] = add nuw nsw i32 %[[VAL_3]], %[[VAL_2]] -// CHECK: %[[VAL_5:.*]] = icmp ult i32 %[[VAL_4]], 6 -// CHECK: call void @llvm.assume(i1 %[[VAL_5]]) -// CHECK: %[[VAL_6:.*]] = add nuw nsw i32 %[[VAL_4]], 0 -// CHECK: %[[VAL_7:.*]] = udiv i32 %[[VAL_6]], 1 -// CHECK: %[[VAL_8:.*]] = urem i32 %[[VAL_7]], 3 -// CHECK: %[[VAL_9:.*]] = udiv i32 %[[VAL_6]], 3 -// CHECK: %[[VAL_10:.*]] = urem i32 %[[VAL_9]], 1 -// CHECK: %[[VAL_11:.*]] = udiv i32 %[[VAL_6]], 3 -// CHECK: %[[VAL_12:.*]] = icmp ult i32 %[[VAL_4]], 6 -// CHECK: br i1 %[[VAL_12]], label %[[VAL_13:.*]], label %[[VAL_14:.*]] -// CHECK: scatter_TensorFlowScatter_Add.in_bounds-after: ; preds = %[[VAL_15:.*]], %[[VAL_16:.*]] -// CHECK: ret void -// CHECK: scatter_TensorFlowScatter_Add.in_bounds-true: ; preds = %[[VAL_16]] -// CHECK: %[[VAL_17:.*]] = getelementptr inbounds [2 x [1 x i32]], ptr %[[VAL_18:.*]], i32 0, i32 %[[VAL_11]], i32 0 -// CHECK: %[[VAL_19:.*]] = load i32, ptr %[[VAL_17]], align 4, !invariant.load !4 -// CHECK: %[[VAL_20:.*]] = add i32 %[[VAL_10]], %[[VAL_19]] -// CHECK: %[[VAL_21:.*]] = icmp ult i32 %[[VAL_19]], 3 -// CHECK: %[[VAL_22:.*]] = and i1 true, %[[VAL_21]] -// CHECK: br i1 %[[VAL_22]], label %[[VAL_23:.*]], label %[[VAL_15]] -// CHECK: 
scatter.in_bounds-after: ; preds = %[[VAL_23]], %[[VAL_13]] -// CHECK: br label %[[VAL_14]] -// CHECK: scatter.in_bounds-true: ; preds = %[[VAL_13]] -// CHECK: %[[VAL_24:.*]] = getelementptr inbounds [3 x [3 x half]], ptr %[[VAL_25:.*]], i32 0, i32 %[[VAL_20]], i32 %[[VAL_8]] -// CHECK: %[[VAL_26:.*]] = getelementptr half, ptr %[[VAL_27:.*]], i32 %[[VAL_4]] -// CHECK: %[[VAL_28:.*]] = getelementptr inbounds half, ptr %[[VAL_26]], i32 0 -// CHECK: %[[VAL_29:.*]] = load half, ptr %[[VAL_28]], align 2, !invariant.load !4 -// CHECK: store half %[[VAL_29]], ptr %[[VAL_0]], align 2 -// CHECK: %[[VAL_30:.*]] = load half, ptr %[[VAL_0]], align 2 -// CHECK: %[[VAL_31:.*]] = atomicrmw fadd ptr %[[VAL_24]], half %[[VAL_30]] seq_cst, align 2 -// CHECK: br label %[[VAL_15]] - -HloModule TensorFlowScatter_Add, is_scheduled=true - -add_f16 (lhs: f16[], rhs: f16[]) -> f16[] { - lhs = f16[] parameter(0) - rhs = f16[] parameter(1) - ROOT add = f16[] add(f16[] lhs, f16[] rhs) -} - -fused_computation { - operand = f16[3,3] parameter(0) - indices = s32[2,1] parameter(1) - updates = f16[2,1,3] parameter(2) - ROOT scatter_TensorFlowScatter_Add = f16[3,3] scatter(operand, indices, updates), - to_apply=add_f16, - update_window_dims={1,2}, - inserted_window_dims={}, - scatter_dims_to_operand_dims={0}, - index_vector_dim=1 -} - -ENTRY main { - p0 = f16[3,3] parameter(0) - p1 = s32[2,1] parameter(1) - p2 = f16[2,1,3] parameter(2) - ROOT wrapped_scatter = f16[3,3] fusion(p0, p1, p2), kind=kInput, calls=fused_computation -} diff --git a/third_party/xla/xla/service/gpu/tests/scatter_bf16.hlo b/third_party/xla/xla/service/gpu/tests/scatter_bf16.hlo deleted file mode 100644 index 59e54e48ee8d4b..00000000000000 --- a/third_party/xla/xla/service/gpu/tests/scatter_bf16.hlo +++ /dev/null @@ -1,34 +0,0 @@ -// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/a6000.txtpb --split-input-file | FileCheck %s 
--check-prefixes=CHECK-SM86 -// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/h100_sxm.txtpb --split-input-file | FileCheck %s --check-prefixes=CHECK-SM90 -// RUN: hlo-opt %s --platform=gpu --stage=ptx --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/h100_sxm.txtpb --split-input-file | FileCheck %s --check-prefixes=CHECK-PTX-SM90 - -HloModule TensorFlowScatter_Add, is_scheduled=true - -add_bf16 (lhs: bf16[], rhs: bf16[]) -> bf16[] { - lhs = bf16[] parameter(0) - rhs = bf16[] parameter(1) - ROOT add = bf16[] add(bf16[] lhs, bf16[] rhs) -} - -fused_computation { - operand = bf16[3,3] parameter(0) - indices = s32[2,1] parameter(1) - updates = bf16[2,1,3] parameter(2) - ROOT scatter_TensorFlowScatter_Mul = bf16[3,3] scatter(operand, indices, updates), - to_apply=add_bf16, - update_window_dims={1,2}, - inserted_window_dims={}, - scatter_dims_to_operand_dims={0}, - index_vector_dim=1 -} - -ENTRY main { - p0 = bf16[3,3] parameter(0) - p1 = s32[2,1] parameter(1) - p2 = bf16[2,1,3] parameter(2) - ROOT wrapped_scatter = bf16[3,3] fusion(p0, p1, p2), kind=kInput, calls=fused_computation -} - -// CHECK-SM86-NOT: atomicrmw fadd -// CHECK-SM90: atomicrmw fadd -// CHECK-PTX-SM90: atom.global.add.noftz.bf16 diff --git a/third_party/xla/xla/service/gpu/tests/transpose_021.hlo b/third_party/xla/xla/service/gpu/tests/transpose_021.hlo deleted file mode 100644 index d36d1a6c22e29a..00000000000000 --- a/third_party/xla/xla/service/gpu/tests/transpose_021.hlo +++ /dev/null @@ -1,103 +0,0 @@ -// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py -// RUN: hlo-opt %s --platform=gpu --xla_gpu_mlir_emitter_level=0 --xla_gpu_mlir_emitter_level=0 --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb --split-input-file | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %s - -HloModule Transpose, 
is_scheduled=true - -%fused_computation { - %p0 = f32[2,16,17]{2,1,0} parameter(0) - ROOT %transpose = f32[2,17,16]{2,1,0} transpose(%p0), dimensions={0,2,1} -} - -ENTRY main { - %param = f32[2,16,17]{2,1,0} parameter(0) - ROOT %fusion = f32[2,17,16] fusion(%param), kind=kInput, calls=%fused_computation -} - -// CHECK-LABEL: entry: -// CHECK: %[[VAL_0:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_1:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_2:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_3:.*]] = alloca i32, align 4 -// CHECK-PTX: %thread.id.x = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !2 -// CHECK-GCN: %thread.id.x = call i32 @llvm.amdgcn.workitem.id.x -// CHECK-PTX: %block.id.x = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !3 -// CHECK-GCN: %block.id.x = call i32 @llvm.amdgcn.workgroup.id.x -// CHECK: %[[VAL_4:.*]] = udiv i32 %thread.id.x, 32 -// CHECK: %thread.id.1 = urem i32 %[[VAL_4]], 4 -// CHECK: %thread.id.2 = urem i32 %thread.id.x, 32 -// CHECK: %lane_id = urem i32 %thread.id.x, 32 -// CHECK: %[[VAL_5:.*]] = udiv i32 %block.id.x, 1 -// CHECK: %[[VAL_6:.*]] = urem i32 %[[VAL_5]], 1 -// CHECK: %[[VAL_7:.*]] = udiv i32 %block.id.x, 1 -// CHECK: %[[VAL_8:.*]] = urem i32 %[[VAL_7]], 1 -// CHECK: %[[VAL_9:.*]] = udiv i32 %block.id.x, 1 -// CHECK: %[[VAL_10:.*]] = icmp eq i32 %[[VAL_8]], 0 -// CHECK: %tile_bound.1 = select i1 %[[VAL_10]], i32 16, i32 32 -// CHECK: %[[VAL_11:.*]] = icmp eq i32 %[[VAL_6]], 0 -// CHECK: %tile_bound.2 = select i1 %[[VAL_11]], i32 17, i32 32 -// CHECK: %tile_origin.0 = mul i32 %[[VAL_9]], 1 -// CHECK: %tile_origin.1 = mul i32 %[[VAL_8]], 32 -// CHECK: %tile_origin.2 = mul i32 %[[VAL_6]], 32 -// CHECK: store i32 %thread.id.1, ptr{{.*}} %[[VAL_3]], align 4 -// CHECK: br label %[[VAL_12:.*]] -// CHECK: loop1.loop_header: ; preds = %[[VAL_13:.*]], %[[VAL_14:.*]] -// CHECK: %[[VAL_15:.*]] = load i32, ptr{{.*}} %[[VAL_3]], align 4 -// CHECK: %[[VAL_16:.*]] = icmp uge i32 %[[VAL_15]], %tile_bound.1 -// CHECK: br i1 
%[[VAL_16]], label %[[VAL_17:.*]], label %[[VAL_18:.*]] -// CHECK: loop1.loop_body: ; preds = %[[VAL_12]] -// CHECK: %[[VAL_19:.*]] = add nuw nsw i32 %[[VAL_15]], 4 -// CHECK: store i32 %[[VAL_19]], ptr{{.*}} %[[VAL_3]], align 4 -// CHECK: store i32 %thread.id.2, ptr{{.*}} %[[VAL_2]], align 4 -// CHECK: br label %[[VAL_21:.*]] -// CHECK: loop2.loop_header: ; preds = %[[VAL_22:.*]], %[[VAL_18]] -// CHECK: %[[VAL_23:.*]] = load i32, ptr{{.*}} %[[VAL_2]], align 4 -// CHECK: %[[VAL_24:.*]] = icmp uge i32 %[[VAL_23]], %tile_bound.2 -// CHECK: br i1 %[[VAL_24]], label %[[VAL_13]], label %[[VAL_22]] -// CHECK: loop2.loop_body: ; preds = %[[VAL_21]] -// CHECK: %[[VAL_25:.*]] = add nuw nsw i32 %[[VAL_23]], 32 -// CHECK: store i32 %[[VAL_25]], ptr{{.*}} %[[VAL_2]], align 4 -// CHECK: %[[VAL_27:.*]] = add i32 %tile_origin.0, 0 -// CHECK: %[[VAL_28:.*]] = add i32 %tile_origin.1, %[[VAL_15]] -// CHECK: %[[VAL_29:.*]] = add i32 %tile_origin.2, %[[VAL_23]] -// CHECK: %[[VAL_30:.*]] = getelementptr{{.*}} inbounds [2 x [16 x [17 x float]]], ptr{{.*}} %[[VAL_31:.*]], i32 0, i32 %[[VAL_27]], i32 %[[VAL_28]], i32 %[[VAL_29]] -// CHECK: %[[VAL_32:.*]] = load float, ptr{{.*}} %[[VAL_30]], align 4, !invariant.load !{{[0-9]}} -// CHECK: %[[VAL_33:.*]] = getelementptr{{.*}} inbounds [1 x [32 x [33 x float]]], ptr{{.*}} addrspace(3) @tr_tile_0, i32 0, i32 0, i32 %[[VAL_15]], i32 %[[VAL_23]] -// CHECK: %[[VAL_34:.*]] = addrspacecast ptr{{.*}} addrspace(3) %[[VAL_33]] to ptr -// CHECK: store float %[[VAL_32]], ptr{{.*}} %[[VAL_34]], align 4 -// CHECK: br label %[[VAL_21]], !llvm.loop !{{[0-9]}} -// CHECK: loop2.loop_exit: ; preds = %[[VAL_21]] -// CHECK: br label %[[VAL_12]], !llvm.loop !{{[0-9]}} -// CHECK: loop1.loop_exit: ; preds = %[[VAL_12]] -// CHECK-PTX: call void @llvm.nvvm.barrier0() -// CHECK-GCN: call void @llvm.amdgcn.s.barrier() -// CHECK: store i32 %thread.id.1, ptr{{.*}} %[[VAL_1]], align 4 -// CHECK: br label %[[VAL_35:.*]] -// CHECK: loop1.loop_header4: ; preds = 
%[[VAL_36:.*]], %[[VAL_17]] -// CHECK: %[[VAL_37:.*]] = load i32, ptr{{.*}} %[[VAL_1]], align 4 -// CHECK: %[[VAL_38:.*]] = icmp uge i32 %[[VAL_37]], %tile_bound.2 -// CHECK: br i1 %[[VAL_38]], label %[[VAL_39:.*]], label %[[VAL_40:.*]] -// CHECK: loop1.loop_body5: ; preds = %[[VAL_35]] -// CHECK: %[[VAL_41:.*]] = add nuw nsw i32 %[[VAL_37]], 4 -// CHECK: store i32 %[[VAL_41]], ptr{{.*}} %[[VAL_1]], align 4 -// CHECK: store i32 %thread.id.2, ptr{{.*}} %[[VAL_0]], align 4 -// CHECK: br label %[[VAL_43:.*]] -// CHECK: loop2.loop_header10: ; preds = %[[VAL_44:.*]], %[[VAL_40]] -// CHECK: %[[VAL_45:.*]] = load i32, ptr{{.*}} %[[VAL_0]], align 4 -// CHECK: %[[VAL_46:.*]] = icmp uge i32 %[[VAL_45]], %tile_bound.1 -// CHECK: br i1 %[[VAL_46]], label %[[VAL_36]], label %[[VAL_44]] -// CHECK: loop2.loop_body11: ; preds = %[[VAL_43]] -// CHECK: %[[VAL_47:.*]] = add nuw nsw i32 %[[VAL_45]], 32 -// CHECK: store i32 %[[VAL_47]], ptr{{.*}} %[[VAL_0]], align 4 -// CHECK: %[[VAL_49:.*]] = add i32 %tile_origin.0, 0 -// CHECK: %[[VAL_50:.*]] = add i32 %tile_origin.2, %[[VAL_37]] -// CHECK: %[[VAL_51:.*]] = add i32 %tile_origin.1, %[[VAL_45]] -// CHECK: %[[VAL_52:.*]] = getelementptr{{.*}} inbounds [1 x [32 x [33 x float]]], ptr{{.*}} addrspace(3) @tr_tile_0, i32 0, i32 0, i32 %[[VAL_45]], i32 %[[VAL_37]] -// CHECK: %[[VAL_53:.*]] = addrspacecast ptr{{.*}} addrspace(3) %[[VAL_52]] to ptr -// CHECK: %[[VAL_54:.*]] = load float, ptr{{.*}} %[[VAL_53]], align 4 -// CHECK: %[[VAL_55:.*]] = getelementptr{{.*}} inbounds [2 x [17 x [16 x float]]], ptr{{.*}} %[[VAL_56:.*]], i32 0, i32 %[[VAL_49]], i32 %[[VAL_50]], i32 %[[VAL_51]] -// CHECK: store float %[[VAL_54]], ptr{{.*}} %[[VAL_55]], align 4 -// CHECK: br label %[[VAL_43]], !llvm.loop !{{[0-9]}} -// CHECK: loop2.loop_exit9: ; preds = %[[VAL_43]] -// CHECK: br label %[[VAL_35]], !llvm.loop !{{[0-9]}} -// CHECK: loop1.loop_exit3: ; preds = %[[VAL_35]] -// CHECK: ret void diff --git 
a/third_party/xla/xla/service/gpu/tests/transpose_021_extra_output.hlo b/third_party/xla/xla/service/gpu/tests/transpose_021_extra_output.hlo deleted file mode 100644 index bcbe49c04a8a7e..00000000000000 --- a/third_party/xla/xla/service/gpu/tests/transpose_021_extra_output.hlo +++ /dev/null @@ -1,111 +0,0 @@ -// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py -// RUN: hlo-opt %s --platform=gpu --xla_gpu_mlir_emitter_level=0 --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb --split-input-file | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %s - -HloModule Transpose, is_scheduled=true - -%fused_computation { - %p0 = f32[2,16,17] parameter(0) - %neg = f32[2,16,17] negate(%p0) - %transpose = f32[2,17,16] transpose(%p0), dimensions={0,2,1} - ROOT %tuple = (f32[2,16,17], f32[2,17,16]) tuple(%neg, %transpose) -} - -ENTRY main { - %param = f32[2,16,17] parameter(0) - ROOT %fusion = (f32[2,16,17], f32[2,17,16]) fusion(%param), kind=kInput, calls=%fused_computation -} - - -// CHECK-LABEL: entry: -// CHECK: %[[VAL_0:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_1:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_2:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_3:.*]] = alloca i32, align 4 -// CHECK-PTX: %thread.id.x = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !2 -// CHECK-GCN: %thread.id.x = call i32 @llvm.amdgcn.workitem.id.x -// CHECK-PTX: %block.id.x = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !3 -// CHECK-GCN: %block.id.x = call i32 @llvm.amdgcn.workgroup.id.x -// CHECK: %[[VAL_4:.*]] = udiv i32 %thread.id.x, 32 -// CHECK: %thread.id.1 = urem i32 %[[VAL_4]], 4 -// CHECK: %thread.id.2 = urem i32 %thread.id.x, 32 -// CHECK: %lane_id = urem i32 %thread.id.x, 32 -// CHECK: %[[VAL_5:.*]] = udiv i32 %block.id.x, 1 -// CHECK: %[[VAL_6:.*]] = urem i32 %[[VAL_5]], 1 -// CHECK: %[[VAL_7:.*]] = udiv i32 %block.id.x, 1 -// CHECK: %[[VAL_8:.*]] = urem i32 %[[VAL_7]], 1 -// CHECK: 
%[[VAL_9:.*]] = udiv i32 %block.id.x, 1 -// CHECK: %[[VAL_10:.*]] = icmp eq i32 %[[VAL_8]], 0 -// CHECK: %tile_bound.1 = select i1 %[[VAL_10]], i32 16, i32 32 -// CHECK: %[[VAL_11:.*]] = icmp eq i32 %[[VAL_6]], 0 -// CHECK: %tile_bound.2 = select i1 %[[VAL_11]], i32 17, i32 32 -// CHECK: %tile_origin.0 = mul i32 %[[VAL_9]], 1 -// CHECK: %tile_origin.1 = mul i32 %[[VAL_8]], 32 -// CHECK: %tile_origin.2 = mul i32 %[[VAL_6]], 32 -// CHECK: store i32 %thread.id.1, ptr{{.*}} %[[VAL_3]], align 4 -// CHECK: br label %[[VAL_12:.*]] -// CHECK: loop1.loop_header: ; preds = %[[VAL_13:.*]], %[[VAL_14:.*]] -// CHECK: %[[VAL_15:.*]] = load i32, ptr{{.*}} %[[VAL_3]], align 4 -// CHECK: %[[VAL_16:.*]] = icmp uge i32 %[[VAL_15]], %tile_bound.1 -// CHECK: br i1 %[[VAL_16]], label %[[VAL_17:.*]], label %[[VAL_18:.*]] -// CHECK: loop1.loop_body: ; preds = %[[VAL_12]] -// CHECK: %[[VAL_19:.*]] = add nuw nsw i32 %[[VAL_15]], 4 -// CHECK: store i32 %[[VAL_19]], ptr{{.*}} %[[VAL_3]], align 4 -// CHECK: store i32 %thread.id.2, ptr{{.*}} %[[VAL_2]], align 4 -// CHECK: br label %[[VAL_21:.*]] -// CHECK: loop2.loop_header: ; preds = %[[VAL_22:.*]], %[[VAL_18]] -// CHECK: %[[VAL_23:.*]] = load i32, ptr{{.*}} %[[VAL_2]], align 4 -// CHECK: %[[VAL_24:.*]] = icmp uge i32 %[[VAL_23]], %tile_bound.2 -// CHECK: br i1 %[[VAL_24]], label %[[VAL_13]], label %[[VAL_22]] -// CHECK: loop2.loop_body: ; preds = %[[VAL_21]] -// CHECK: %[[VAL_25:.*]] = add nuw nsw i32 %[[VAL_23]], 32 -// CHECK: store i32 %[[VAL_25]], ptr{{.*}} %[[VAL_2]], align 4 -// CHECK: %[[VAL_27:.*]] = add i32 %tile_origin.0, 0 -// CHECK: %[[VAL_28:.*]] = add i32 %tile_origin.1, %[[VAL_15]] -// CHECK: %[[VAL_29:.*]] = add i32 %tile_origin.2, %[[VAL_23]] -// CHECK: %[[VAL_30:.*]] = getelementptr inbounds [2 x [16 x [17 x float]]], ptr{{.*}} %[[VAL_31:.*]], i32 0, i32 %[[VAL_27]], i32 %[[VAL_28]], i32 %[[VAL_29]] -// CHECK: %[[VAL_32:.*]] = load float, ptr{{.*}} %[[VAL_30]], align 4, !invariant.load !{{[0-9]}} -// CHECK: %[[VAL_33:.*]] = 
getelementptr inbounds [1 x [32 x [33 x float]]], ptr{{.*}} addrspace(3) @tr_tile_0, i32 0, i32 0, i32 %[[VAL_15]], i32 %[[VAL_23]] -// CHECK: %[[VAL_34:.*]] = addrspacecast ptr{{.*}} addrspace(3) %[[VAL_33]] to ptr -// CHECK: store float %[[VAL_32]], ptr{{.*}} %[[VAL_34]], align 4 -// CHECK: %[[VAL_35:.*]] = getelementptr inbounds [2 x [16 x [17 x float]]], ptr{{.*}} %[[VAL_31]], i32 0, i32 %[[VAL_27]], i32 %[[VAL_28]], i32 %[[VAL_29]] -// CHECK: %[[VAL_36:.*]] = load float, ptr{{.*}} %[[VAL_35]], align 4, !invariant.load !{{[0-9]}} -// CHECK: %[[VAL_37:.*]] = fneg float %[[VAL_36]] -// CHECK: %[[VAL_38:.*]] = getelementptr inbounds [2 x [16 x [17 x float]]], ptr{{.*}} %[[VAL_39:.*]], i32 0, i32 %[[VAL_27]], i32 %[[VAL_28]], i32 %[[VAL_29]] -// CHECK: store float %[[VAL_37]], ptr{{.*}} %[[VAL_38]], align 4 -// CHECK: br label %[[VAL_21]], !llvm.loop !{{[0-9]}} -// CHECK: loop2.loop_exit: ; preds = %[[VAL_21]] -// CHECK: br label %[[VAL_12]], !llvm.loop !{{[0-9]}} -// CHECK: loop1.loop_exit: ; preds = %[[VAL_12]] -// CHECK-PTX: call void @llvm.nvvm.barrier0() -// CHECK-GCN: call void @llvm.amdgcn.s.barrier() -// CHECK: store i32 %thread.id.1, ptr{{.*}} %[[VAL_1]], align 4 -// CHECK: br label %[[VAL_40:.*]] -// CHECK: loop1.loop_header6: ; preds = %[[VAL_41:.*]], %[[VAL_17]] -// CHECK: %[[VAL_42:.*]] = load i32, ptr{{.*}} %[[VAL_1]], align 4 -// CHECK: %[[VAL_43:.*]] = icmp uge i32 %[[VAL_42]], %tile_bound.2 -// CHECK: br i1 %[[VAL_43]], label %[[VAL_44:.*]], label %[[VAL_45:.*]] -// CHECK: loop1.loop_body7: ; preds = %[[VAL_40]] -// CHECK: %[[VAL_46:.*]] = add nuw nsw i32 %[[VAL_42]], 4 -// CHECK: store i32 %[[VAL_46]], ptr{{.*}} %[[VAL_1]], align 4 -// CHECK: store i32 %thread.id.2, ptr{{.*}} %[[VAL_0]], align 4 -// CHECK: br label %[[VAL_48:.*]] -// CHECK: loop2.loop_header12: ; preds = %[[VAL_49:.*]], %[[VAL_45]] -// CHECK: %[[VAL_50:.*]] = load i32, ptr{{.*}} %[[VAL_0]], align 4 -// CHECK: %[[VAL_51:.*]] = icmp uge i32 %[[VAL_50]], %tile_bound.1 -// CHECK: br 
i1 %[[VAL_51]], label %[[VAL_41]], label %[[VAL_49]] -// CHECK: loop2.loop_body13: ; preds = %[[VAL_48]] -// CHECK: %[[VAL_52:.*]] = add nuw nsw i32 %[[VAL_50]], 32 -// CHECK: store i32 %[[VAL_52]], ptr{{.*}} %[[VAL_0]], align 4 -// CHECK: %[[VAL_54:.*]] = add i32 %tile_origin.0, 0 -// CHECK: %[[VAL_55:.*]] = add i32 %tile_origin.2, %[[VAL_42]] -// CHECK: %[[VAL_56:.*]] = add i32 %tile_origin.1, %[[VAL_50]] -// CHECK: %[[VAL_57:.*]] = getelementptr inbounds [1 x [32 x [33 x float]]], ptr{{.*}} addrspace(3) @tr_tile_0, i32 0, i32 0, i32 %[[VAL_50]], i32 %[[VAL_42]] -// CHECK: %[[VAL_58:.*]] = addrspacecast ptr{{.*}} addrspace(3) %[[VAL_57]] to ptr -// CHECK: %[[VAL_59:.*]] = load float, ptr{{.*}} %[[VAL_58]], align 4 -// CHECK: %[[VAL_60:.*]] = getelementptr inbounds [2 x [17 x [16 x float]]], ptr{{.*}} %[[VAL_61:.*]], i32 0, i32 %[[VAL_54]], i32 %[[VAL_55]], i32 %[[VAL_56]] -// CHECK: store float %[[VAL_59]], ptr{{.*}} %[[VAL_60]], align 4 -// CHECK: br label %[[VAL_48]], !llvm.loop !{{[0-9]}} -// CHECK: loop2.loop_exit11: ; preds = %[[VAL_48]] -// CHECK: br label %[[VAL_40]], !llvm.loop !{{[0-9]}} -// CHECK: loop1.loop_exit5: ; preds = %[[VAL_40]] -// CHECK: ret void diff --git a/third_party/xla/xla/service/gpu/tests/transpose_10.hlo b/third_party/xla/xla/service/gpu/tests/transpose_10.hlo deleted file mode 100644 index 28b765be3d2a86..00000000000000 --- a/third_party/xla/xla/service/gpu/tests/transpose_10.hlo +++ /dev/null @@ -1,17 +0,0 @@ -// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb | FileCheck --check-prefixes=CHECK-%{PTX} %s - -// CHECK-PTX: call void @llvm.nvvm.barrier0 -// CHECK-GCN: call void @llvm.amdgcn.s.barrier - -HloModule Test, is_scheduled=true - - -fused_computation { - param_0 = f32[100,200]{1,0} parameter(0) - ROOT b.1 = f32[200,100]{1,0} transpose(f32[100,200]{1,0} param_0), dimensions={1,0} -} - -ENTRY main { - a = f32[100, 200]{1,0} 
parameter(0) - ROOT wrapped_b = f32[200,100]{1,0} fusion(f32[100,200]{1,0} a), kind=kInput, calls=fused_computation -} diff --git a/third_party/xla/xla/service/gpu/tests/transpose_210.hlo b/third_party/xla/xla/service/gpu/tests/transpose_210.hlo deleted file mode 100644 index 306aee10ad20d5..00000000000000 --- a/third_party/xla/xla/service/gpu/tests/transpose_210.hlo +++ /dev/null @@ -1,102 +0,0 @@ -// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py -// RUN: hlo-opt %s --platform=gpu --xla_gpu_mlir_emitter_level=0 --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb --split-input-file | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %s - -HloModule Transpose, is_scheduled=true - -%fused_computation { - %p0 = f32[33,49,65]{2,1,0} parameter(0) - ROOT %transpose = f32[65,49,33]{2,1,0} transpose(%p0), dimensions={2,1,0} -} - -ENTRY main { - %param = f32[33,49,65]{2,1,0} parameter(0) - ROOT %fusion = f32[65,49,33] fusion(%param), kind=kInput, calls=%fused_computation -} - -// CHECK-LABEL: entry: -// CHECK: %[[VAL_0:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_1:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_2:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_3:.*]] = alloca i32, align 4 -// CHECK-PTX: %thread.id.x = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !2 -// CHECK-GCN: %thread.id.x = call i32 @llvm.amdgcn.workitem.id.x -// CHECK-PTX: %block.id.x = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !3 -// CHECK-GCN: %block.id.x = call i32 @llvm.amdgcn.workgroup.id.x -// CHECK: %thread.id.0 = udiv i32 %thread.id.x, 32 -// CHECK: %thread.id.2 = urem i32 %thread.id.x, 32 -// CHECK: %lane_id = urem i32 %thread.id.x, 32 -// CHECK: %[[VAL_5:.*]] = udiv i32 %block.id.x, 1 -// CHECK: %[[VAL_6:.*]] = urem i32 %[[VAL_5]], 3 -// CHECK: %[[VAL_7:.*]] = udiv i32 %block.id.x, 3 -// CHECK: %[[VAL_8:.*]] = urem i32 %[[VAL_7]], 49 -// CHECK: %[[VAL_9:.*]] = udiv i32 %block.id.x, 147 -// CHECK: 
%[[VAL_10:.*]] = icmp eq i32 %[[VAL_9]], 1 -// CHECK: %tile_bound.0 = select i1 %[[VAL_10]], i32 1, i32 32 -// CHECK: %[[VAL_11:.*]] = icmp eq i32 %[[VAL_6]], 2 -// CHECK: %tile_bound.2 = select i1 %[[VAL_11]], i32 1, i32 32 -// CHECK: %tile_origin.0 = mul i32 %[[VAL_9]], 32 -// CHECK: %tile_origin.1 = mul i32 %[[VAL_8]], 1 -// CHECK: %tile_origin.2 = mul i32 %[[VAL_6]], 32 -// CHECK: store i32 %thread.id.0, ptr{{.*}} %[[VAL_3]], align 4 -// CHECK: br label %[[VAL_12:.*]] -// CHECK: loop0.loop_header: ; preds = %[[VAL_13:.*]], %[[VAL_14:.*]] -// CHECK: %[[VAL_15:.*]] = load i32, ptr{{.*}} %[[VAL_3]], align 4 -// CHECK: %[[VAL_16:.*]] = icmp uge i32 %[[VAL_15]], %tile_bound.0 -// CHECK: br i1 %[[VAL_16]], label %[[VAL_17:.*]], label %[[VAL_18:.*]] -// CHECK: loop0.loop_body: ; preds = %[[VAL_12]] -// CHECK: %[[VAL_19:.*]] = add nuw nsw i32 %[[VAL_15]], 4 -// CHECK: store i32 %[[VAL_19]], ptr{{.*}} %[[VAL_3]], align 4 -// CHECK: store i32 %thread.id.2, ptr{{.*}} %[[VAL_2]], align 4 -// CHECK: br label %[[VAL_21:.*]] -// CHECK: loop2.loop_header: ; preds = %[[VAL_22:.*]], %[[VAL_18]] -// CHECK: %[[VAL_23:.*]] = load i32, ptr{{.*}} %[[VAL_2]], align 4 -// CHECK: %[[VAL_24:.*]] = icmp uge i32 %[[VAL_23]], %tile_bound.2 -// CHECK: br i1 %[[VAL_24]], label %[[VAL_13]], label %[[VAL_22]] -// CHECK: loop2.loop_body: ; preds = %[[VAL_21]] -// CHECK: %[[VAL_25:.*]] = add nuw nsw i32 %[[VAL_23]], 32 -// CHECK: store i32 %[[VAL_25]], ptr{{.*}} %[[VAL_2]], align 4 -// CHECK: %[[VAL_27:.*]] = add i32 %tile_origin.0, %[[VAL_15]] -// CHECK: %[[VAL_28:.*]] = add i32 %tile_origin.1, 0 -// CHECK: %[[VAL_29:.*]] = add i32 %tile_origin.2, %[[VAL_23]] -// CHECK: %[[VAL_30:.*]] = getelementptr inbounds [33 x [49 x [65 x float]]], ptr %[[VAL_31:.*]], i32 0, i32 %[[VAL_27]], i32 %[[VAL_28]], i32 %[[VAL_29]] -// CHECK: %[[VAL_32:.*]] = load float, ptr %[[VAL_30]], align 4, !invariant.load !{{[0-9]}} -// CHECK: %[[VAL_33:.*]] = getelementptr inbounds [32 x [1 x [33 x float]]], ptr 
addrspace(3) @tr_tile_0, i32 0, i32 %[[VAL_15]], i32 0, i32 %[[VAL_23]] -// CHECK: %[[VAL_34:.*]] = addrspacecast ptr addrspace(3) %[[VAL_33]] to ptr -// CHECK: store float %[[VAL_32]], ptr %[[VAL_34]], align 4 -// CHECK: br label %[[VAL_21]], !llvm.loop !{{[0-9]}} -// CHECK: loop2.loop_exit: ; preds = %[[VAL_21]] -// CHECK: br label %[[VAL_12]], !llvm.loop !{{[0-9]}} -// CHECK: loop0.loop_exit: ; preds = %[[VAL_12]] -// CHECK-PTX: call void @llvm.nvvm.barrier0() -// CHECK-GCN: call void @llvm.amdgcn.s.barrier() -// CHECK: store i32 %thread.id.0, ptr{{.*}} %[[VAL_1]], align 4 -// CHECK: br label %[[VAL_35:.*]] -// CHECK: loop0.loop_header4: ; preds = %[[VAL_36:.*]], %[[VAL_17]] -// CHECK: %[[VAL_37:.*]] = load i32, ptr{{.*}} %[[VAL_1]], align 4 -// CHECK: %[[VAL_38:.*]] = icmp uge i32 %[[VAL_37]], %tile_bound.2 -// CHECK: br i1 %[[VAL_38]], label %[[VAL_39:.*]], label %[[VAL_40:.*]] -// CHECK: loop0.loop_body5: ; preds = %[[VAL_35]] -// CHECK: %[[VAL_41:.*]] = add nuw nsw i32 %[[VAL_37]], 4 -// CHECK: store i32 %[[VAL_41]], ptr{{.*}} %[[VAL_1]], align 4 -// CHECK: store i32 %thread.id.2, ptr{{.*}} %[[VAL_0]], align 4 -// CHECK: br label %[[VAL_43:.*]] -// CHECK: loop2.loop_header10: ; preds = %[[VAL_44:.*]], %[[VAL_40]] -// CHECK: %[[VAL_45:.*]] = load i32, ptr{{.*}} %[[VAL_0]], align 4 -// CHECK: %[[VAL_46:.*]] = icmp uge i32 %[[VAL_45]], %tile_bound.0 -// CHECK: br i1 %[[VAL_46]], label %[[VAL_36]], label %[[VAL_44]] -// CHECK: loop2.loop_body11: ; preds = %[[VAL_43]] -// CHECK: %[[VAL_47:.*]] = add nuw nsw i32 %[[VAL_45]], 32 -// CHECK: store i32 %[[VAL_47]], ptr{{.*}} %[[VAL_0]], align 4 -// CHECK: %[[VAL_49:.*]] = add i32 %tile_origin.2, %[[VAL_37]] -// CHECK: %[[VAL_50:.*]] = add i32 %tile_origin.1, 0 -// CHECK: %[[VAL_51:.*]] = add i32 %tile_origin.0, %[[VAL_45]] -// CHECK: %[[VAL_52:.*]] = getelementptr inbounds [32 x [1 x [33 x float]]], ptr addrspace(3) @tr_tile_0, i32 0, i32 %[[VAL_45]], i32 0, i32 %[[VAL_37]] -// CHECK: %[[VAL_53:.*]] = addrspacecast 
ptr addrspace(3) %[[VAL_52]] to ptr -// CHECK: %[[VAL_54:.*]] = load float, ptr{{.*}} %[[VAL_53]], align 4 -// CHECK: %[[VAL_55:.*]] = getelementptr inbounds [65 x [49 x [33 x float]]], ptr %[[VAL_56:.*]], i32 0, i32 %[[VAL_49]], i32 %[[VAL_50]], i32 %[[VAL_51]] -// CHECK: store float %[[VAL_54]], ptr %[[VAL_55]], align 4 -// CHECK: br label %[[VAL_43]], !llvm.loop !{{[0-9]}} -// CHECK: loop2.loop_exit9: ; preds = %[[VAL_43]] -// CHECK: br label %[[VAL_35]], !llvm.loop !{{[0-9]}} -// CHECK: loop0.loop_exit3: ; preds = %[[VAL_35]] -// CHECK: ret void diff --git a/third_party/xla/xla/service/gpu/tests/transpose_210_extra_output.hlo b/third_party/xla/xla/service/gpu/tests/transpose_210_extra_output.hlo deleted file mode 100644 index 550b824b79a4a1..00000000000000 --- a/third_party/xla/xla/service/gpu/tests/transpose_210_extra_output.hlo +++ /dev/null @@ -1,109 +0,0 @@ -// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py -// RUN: hlo-opt %s --platform=gpu --xla_gpu_mlir_emitter_level=0 --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/../../../tools/hlo_opt/gpu_specs/%{GPU}.txtpb --split-input-file | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %s - -HloModule Transpose, is_scheduled=true - -%fused_computation { - %p0 = f32[33,49,65] parameter(0) - %neg = f32[33,49,65] negate(%p0) - %transpose = f32[65,49,33] transpose(%p0), dimensions={2,1,0} - ROOT %tuple = (f32[33,49,65], f32[65,49,33]) tuple(%neg, %transpose) -} - -ENTRY main { - %param = f32[33,49,65]{2,1,0} parameter(0) - ROOT %fusion = (f32[33,49,65], f32[65,49,33]) fusion(%param), kind=kInput, calls=%fused_computation -} - -// CHECK-LABEL: entry: -// CHECK: %[[VAL_0:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_1:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_2:.*]] = alloca i32, align 4 -// CHECK: %[[VAL_3:.*]] = alloca i32, align 4 -// CHECK-PTX: %thread.id.x = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !2 -// CHECK-GCN: %thread.id.x = call i32 
@llvm.amdgcn.workitem.id.x -// CHECK-PTX: %block.id.x = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !3 -// CHECK-GCN: %block.id.x = call i32 @llvm.amdgcn.workgroup.id.x -// CHECK: %thread.id.0 = udiv i32 %thread.id.x, 32 -// CHECK: %thread.id.2 = urem i32 %thread.id.x, 32 -// CHECK: %lane_id = urem i32 %thread.id.x, 32 -// CHECK: %[[VAL_5:.*]] = udiv i32 %block.id.x, 1 -// CHECK: %[[VAL_6:.*]] = urem i32 %[[VAL_5]], 3 -// CHECK: %[[VAL_7:.*]] = udiv i32 %block.id.x, 3 -// CHECK: %[[VAL_8:.*]] = urem i32 %[[VAL_7]], 49 -// CHECK: %[[VAL_9:.*]] = udiv i32 %block.id.x, 147 -// CHECK: %[[VAL_10:.*]] = icmp eq i32 %[[VAL_9]], 1 -// CHECK: %tile_bound.0 = select i1 %[[VAL_10]], i32 1, i32 32 -// CHECK: %[[VAL_11:.*]] = icmp eq i32 %[[VAL_6]], 2 -// CHECK: %tile_bound.2 = select i1 %[[VAL_11]], i32 1, i32 32 -// CHECK: %tile_origin.0 = mul i32 %[[VAL_9]], 32 -// CHECK: %tile_origin.1 = mul i32 %[[VAL_8]], 1 -// CHECK: %tile_origin.2 = mul i32 %[[VAL_6]], 32 -// CHECK: store i32 %thread.id.0, ptr{{.*}} %[[VAL_3]], align 4 -// CHECK: br label %[[VAL_12:.*]] -// CHECK: loop0.loop_header: ; preds = %[[VAL_13:.*]], %[[VAL_14:.*]] -// CHECK: %[[VAL_15:.*]] = load i32, ptr{{.*}} %[[VAL_3]], align 4 -// CHECK: %[[VAL_16:.*]] = icmp uge i32 %[[VAL_15]], %tile_bound.0 -// CHECK: br i1 %[[VAL_16]], label %[[VAL_17:.*]], label %[[VAL_18:.*]] -// CHECK: loop0.loop_body: ; preds = %[[VAL_12]] -// CHECK: %[[VAL_19:.*]] = add nuw nsw i32 %[[VAL_15]], 4 -// CHECK: store i32 %[[VAL_19]], ptr{{.*}} %[[VAL_3]], align 4 -// CHECK: store i32 %thread.id.2, ptr{{.*}} %[[VAL_2]], align 4 -// CHECK: br label %[[VAL_21:.*]] -// CHECK: loop2.loop_header: ; preds = %[[VAL_22:.*]], %[[VAL_18]] -// CHECK: %[[VAL_23:.*]] = load i32, ptr{{.*}} %[[VAL_2]], align 4 -// CHECK: %[[VAL_24:.*]] = icmp uge i32 %[[VAL_23]], %tile_bound.2 -// CHECK: br i1 %[[VAL_24]], label %[[VAL_13]], label %[[VAL_22]] -// CHECK: loop2.loop_body: ; preds = %[[VAL_21]] -// CHECK: %[[VAL_25:.*]] = add nuw nsw i32 
%[[VAL_23]], 32 -// CHECK: store i32 %[[VAL_25]], ptr{{.*}} %[[VAL_2]], align 4 -// CHECK: %[[VAL_27:.*]] = add i32 %tile_origin.0, %[[VAL_15]] -// CHECK: %[[VAL_28:.*]] = add i32 %tile_origin.1, 0 -// CHECK: %[[VAL_29:.*]] = add i32 %tile_origin.2, %[[VAL_23]] -// CHECK: %[[VAL_30:.*]] = getelementptr inbounds [33 x [49 x [65 x float]]], ptr{{.*}} %[[VAL_31:.*]], i32 0, i32 %[[VAL_27]], i32 %[[VAL_28]], i32 %[[VAL_29]] -// CHECK: %[[VAL_32:.*]] = load float, ptr{{.*}} %[[VAL_30]], align 4, !invariant.load !{{[0-9]}} -// CHECK: %[[VAL_33:.*]] = getelementptr inbounds [32 x [1 x [33 x float]]], ptr{{.*}} addrspace(3) @tr_tile_0, i32 0, i32 %[[VAL_15]], i32 0, i32 %[[VAL_23]] -// CHECK: %[[VAL_34:.*]] = addrspacecast ptr{{.*}} addrspace(3) %[[VAL_33]] to ptr -// CHECK: store float %[[VAL_32]], ptr{{.*}} %[[VAL_34]], align 4 -// CHECK: %[[VAL_35:.*]] = getelementptr inbounds [33 x [49 x [65 x float]]], ptr{{.*}} %[[VAL_31]], i32 0, i32 %[[VAL_27]], i32 %[[VAL_28]], i32 %[[VAL_29]] -// CHECK: %[[VAL_36:.*]] = load float, ptr{{.*}} %[[VAL_35]], align 4, !invariant.load !{{[0-9]}} -// CHECK: %[[VAL_37:.*]] = fneg float %[[VAL_36]] -// CHECK: %[[VAL_38:.*]] = getelementptr inbounds [33 x [49 x [65 x float]]], ptr{{.*}} %[[VAL_39:.*]], i32 0, i32 %[[VAL_27]], i32 %[[VAL_28]], i32 %[[VAL_29]] -// CHECK: store float %[[VAL_37]], ptr{{.*}} %[[VAL_38]], align 4 -// CHECK: br label %[[VAL_21]], !llvm.loop !{{[0-9]}} -// CHECK: loop2.loop_exit: ; preds = %[[VAL_21]] -// CHECK: br label %[[VAL_12]], !llvm.loop !{{[0-9]}} -// CHECK: loop0.loop_exit: ; preds = %[[VAL_12]] -// CHECK-PTX: call void @llvm.nvvm.barrier0() -// CHECK-GCN: call void @llvm.amdgcn.s.barrier() -// CHECK: store i32 %thread.id.0, ptr{{.*}} %[[VAL_1]], align 4 -// CHECK: br label %[[VAL_40:.*]] -// CHECK: loop0.loop_header6: ; preds = %[[VAL_41:.*]], %[[VAL_17]] -// CHECK: %[[VAL_42:.*]] = load i32, ptr{{.*}} %[[VAL_1]], align 4 -// CHECK: %[[VAL_43:.*]] = icmp uge i32 %[[VAL_42]], %tile_bound.2 -// CHECK: br 
i1 %[[VAL_43]], label %[[VAL_44:.*]], label %[[VAL_45:.*]] -// CHECK: loop0.loop_body7: ; preds = %[[VAL_40]] -// CHECK: %[[VAL_46:.*]] = add nuw nsw i32 %[[VAL_42]], 4 -// CHECK: store i32 %[[VAL_46]], ptr{{.*}} %[[VAL_1]], align 4 -// CHECK: store i32 %thread.id.2, ptr{{.*}} %[[VAL_0]], align 4 -// CHECK: br label %[[VAL_48:.*]] -// CHECK: loop2.loop_header12: ; preds = %[[VAL_49:.*]], %[[VAL_45]] -// CHECK: %[[VAL_50:.*]] = load i32, ptr{{.*}} %[[VAL_0]], align 4 -// CHECK: %[[VAL_51:.*]] = icmp uge i32 %[[VAL_50]], %tile_bound.0 -// CHECK: br i1 %[[VAL_51]], label %[[VAL_41]], label %[[VAL_49]] -// CHECK: loop2.loop_body13: ; preds = %[[VAL_48]] -// CHECK: %[[VAL_52:.*]] = add nuw nsw i32 %[[VAL_50]], 32 -// CHECK: store i32 %[[VAL_52]], ptr{{.*}} %[[VAL_0]], align 4 -// CHECK: %[[VAL_54:.*]] = add i32 %tile_origin.2, %[[VAL_42]] -// CHECK: %[[VAL_55:.*]] = add i32 %tile_origin.1, 0 -// CHECK: %[[VAL_56:.*]] = add i32 %tile_origin.0, %[[VAL_50]] -// CHECK: %[[VAL_57:.*]] = getelementptr inbounds [32 x [1 x [33 x float]]], ptr{{.*}} addrspace(3) @tr_tile_0, i32 0, i32 %[[VAL_50]], i32 0, i32 %[[VAL_42]] -// CHECK: %[[VAL_58:.*]] = addrspacecast ptr{{.*}} addrspace(3) %[[VAL_57]] to ptr -// CHECK: %[[VAL_59:.*]] = load float, ptr{{.*}} %[[VAL_58]], align 4 -// CHECK: %[[VAL_60:.*]] = getelementptr inbounds [65 x [49 x [33 x float]]], ptr{{.*}} %[[VAL_61:.*]], i32 0, i32 %[[VAL_54]], i32 %[[VAL_55]], i32 %[[VAL_56]] -// CHECK: store float %[[VAL_59]], ptr{{.*}} %[[VAL_60]], align 4 -// CHECK: br label %[[VAL_48]], !llvm.loop !{{[0-9]}} -// CHECK: loop2.loop_exit11: ; preds = %[[VAL_48]] -// CHECK: br label %[[VAL_40]], !llvm.loop !{{[0-9]}} -// CHECK: loop0.loop_exit5: ; preds = %[[VAL_40]] -// CHECK: ret void diff --git a/third_party/xla/xla/service/gpu/transforms/BUILD b/third_party/xla/xla/service/gpu/transforms/BUILD index 31c99ac6984708..dc97630546573e 100644 --- a/third_party/xla/xla/service/gpu/transforms/BUILD +++ 
b/third_party/xla/xla/service/gpu/transforms/BUILD @@ -641,6 +641,7 @@ cc_library( "//xla/service/gpu:hlo_traversal", "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu:reduction_utils", + "//xla/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status:statusor", @@ -658,6 +659,7 @@ xla_cc_test( "//xla/hlo/ir:hlo", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", + "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", "@com_google_absl//absl/strings", @@ -1547,6 +1549,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/service/gpu:gpu_fusible", + "//xla/stream_executor:device_description", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -1857,6 +1860,7 @@ cc_library( "//xla/service:hlo_creation_utils", "//xla/service:sub_byte_normalization", "//xla/service/gpu:gpu_fusible", + "//xla/stream_executor:device_description", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", @@ -2367,6 +2371,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/service/gpu:reduction_utils", + "//xla/stream_executor:device_description", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", @@ -2388,6 +2393,7 @@ xla_cc_test( "//xla/service:hlo_parser", "//xla/service:pattern_matcher", "//xla/service:pattern_matcher_gmock", + "//xla/stream_executor:device_description", "//xla/tests:hlo_test_base", "//xla/tests:xla_internal_test_main", ], @@ -2702,6 +2708,7 @@ cc_library( "//xla/service/gpu:backend_configs_cc", "//xla/service/gpu:gpu_fusible", "//xla/service/gpu/runtime:thunk", + "//xla/stream_executor:device_description", 
"@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status:statusor", @@ -2719,6 +2726,7 @@ xla_cc_test( ":stream_attribute_annotator", "//xla/hlo/ir:hlo", "//xla/service/gpu:backend_configs_cc", + "//xla/stream_executor:device_description", "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "@com_google_absl//absl/algorithm:container", @@ -2960,6 +2968,7 @@ cc_library( "//xla/hlo/ir:hlo", "//xla/hlo/pass:hlo_pass", "//xla/service:executable", + "//xla/service:hlo_cost_analysis", "//xla/service:hlo_module_config", "//xla/service:shaped_buffer", "//xla/service/gpu:backend_configs_cc", @@ -2967,6 +2976,10 @@ cc_library( "//xla/service/gpu:ir_emission_utils", "//xla/service/gpu/autotuning:autotuner_compile_util", "//xla/service/gpu/autotuning:autotuner_util", + "//xla/service/gpu/transforms:fusion_wrapper", + "//xla/service/gpu/transforms:priority_fusion", + "//xla/service/gpu/transforms:tree_reduction_rewriter", + "//xla/stream_executor:device_description", "//xla/stream_executor:stream", "//xla/tools:hlo_decomposer_lib", "@com_google_absl//absl/container:flat_hash_set", @@ -2982,6 +2995,7 @@ cc_library( xla_test( name = "triton_fusion_numerics_verifier_test", + timeout = "short", srcs = ["triton_fusion_numerics_verifier_test.cc"], backends = [ "gpu_a100", diff --git a/third_party/xla/xla/service/gpu/transforms/copy_fusion.cc b/third_party/xla/xla/service/gpu/transforms/copy_fusion.cc index eb43ca2364f0c8..80b924de98b36d 100644 --- a/third_party/xla/xla/service/gpu/transforms/copy_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/copy_fusion.cc @@ -75,7 +75,7 @@ absl::StatusOr CopyFusion::DoCopyFusion(HloComputation* computation) { continue; } HloInstruction* root = fused_computation->root_instruction(); - if (IsReductionFromOrToContiguousDimensions(*root) || + if (IsReductionFromOrToContiguousDimensions(*root, device_description_) || root->opcode() == HloOpcode::kScatter || 
(hlo->IsMultiOutputFusion() && absl::c_all_of(root->operands(), [](const HloInstruction* slice) { @@ -89,7 +89,8 @@ absl::StatusOr CopyFusion::DoCopyFusion(HloComputation* computation) { if (copy_user->opcode() == HloOpcode::kGetTupleElement && copy_user->user_count() == 1) { if (IsReductionFromOrToContiguousDimensions( - *(root->operand(copy_user->tuple_index())))) { + *(root->operand(copy_user->tuple_index())), + device_description_)) { other_users.push_back(user); continue; } diff --git a/third_party/xla/xla/service/gpu/transforms/copy_fusion.h b/third_party/xla/xla/service/gpu/transforms/copy_fusion.h index 8350935c8982d5..b56f98f799be66 100644 --- a/third_party/xla/xla/service/gpu/transforms/copy_fusion.h +++ b/third_party/xla/xla/service/gpu/transforms/copy_fusion.h @@ -22,6 +22,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/stream_executor/device_description.h" namespace xla { namespace gpu { @@ -30,7 +31,8 @@ namespace gpu { // those copies to the fusion, replacing the copies with get_tuple_elements. class CopyFusion : public HloModulePass { public: - CopyFusion() = default; + explicit CopyFusion(const se::DeviceDescription& device_description) + : device_description_(device_description) {} absl::string_view name() const override { return "copy_fusion"; } @@ -41,6 +43,8 @@ class CopyFusion : public HloModulePass { private: absl::StatusOr DoCopyFusion(HloComputation* computation); + + const se::DeviceDescription& device_description_; }; } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/transforms/copy_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/copy_fusion_test.cc index 1bd2d11fe7ddc7..a812bb614f2f22 100644 --- a/third_party/xla/xla/service/gpu/transforms/copy_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/copy_fusion_test.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_instruction.h" #include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" +#include "xla/stream_executor/device_description.h" #include "xla/tests/hlo_test_base.h" namespace xla { @@ -28,8 +29,18 @@ namespace gpu { namespace m = ::xla::match; +auto MakeDeviceDescriptor() { + stream_executor::DeviceDescription device_description{ + stream_executor::GpuDeviceInfoProto{}}; + device_description.set_threads_per_warp(32); + return device_description; +} + class CopyFusionTest : public HloTestBase { public: + CopyFusionTest() + : device_description_(MakeDeviceDescriptor()), cf_(device_description_) {} + const stream_executor::DeviceDescription device_description_; CopyFusion cf_; }; diff --git a/third_party/xla/xla/service/gpu/transforms/fusion_merger.cc b/third_party/xla/xla/service/gpu/transforms/fusion_merger.cc index d132cf6f3ae682..5e246016cea911 100644 --- a/third_party/xla/xla/service/gpu/transforms/fusion_merger.cc +++ b/third_party/xla/xla/service/gpu/transforms/fusion_merger.cc @@ -55,6 +55,7 @@ class FusionInstructionMerger { : computation_(computation), shape_size_function_(shape_size_function), gpu_device_info_(gpu_device_info), + fusion_info_cache_(gpu_device_info_), dump_fusion_visualization_(computation->parent() ->config() .debug_options() @@ -113,7 +114,8 @@ absl::Status FusionInstructionMerger::FuseIntoAllUsers( HloInstruction* consumer = user; if (consumer->opcode() != HloOpcode::kFusion) { consumer = computation_->AddInstruction(HloInstruction::CreateFusion( - user->shape(), ChooseFusionKind(*producer, *user), user)); + user->shape(), ChooseFusionKind(*producer, *user, gpu_device_info_), + user)); TF_CHECK_OK(computation_->ReplaceInstruction(user, consumer)); } @@ -223,7 +225,8 @@ FusionDecision FusionInstructionMerger::ShouldFuse(HloInstruction* producer) { return FusionDecision::Forbid("not a loop fusion"); } - auto producer_hero = GetRealHeroForMultiOutputFusion(*producer); + auto producer_hero 
= GetRealHeroForMultiOutputFusion(*producer, + gpu_device_info_); bool has_reduction_user = false; for (const HloInstruction* user : producer->users()) { @@ -235,19 +238,22 @@ FusionDecision FusionInstructionMerger::ShouldFuse(HloInstruction* producer) { ++num_fail_merge_all_users_; return FusionDecision::Forbid("not fusing custom fusions"); } - auto consumer_hero = GetRealHeroForMultiOutputFusion(*user); + auto consumer_hero = GetRealHeroForMultiOutputFusion(*user, + gpu_device_info_); if (auto compatible = - FusionHeroesAreCompatible(producer_hero, consumer_hero); + FusionHeroesAreCompatible(producer_hero, consumer_hero, + gpu_device_info_); !compatible) { return compatible; } - FusionDecision fusible = IsProducerConsumerFusible(*producer, *user); + FusionDecision fusible = IsProducerConsumerFusible(*producer, *user, + gpu_device_info_); if (!fusible) { ++num_fail_merge_all_users_; VLOG(9) << user->ToString(); return fusible; } - if (IsInputFusibleReduction(*user)) { + if (IsInputFusibleReduction(*user, gpu_device_info_)) { has_reduction_user = true; } } diff --git a/third_party/xla/xla/service/gpu/transforms/fusion_wrapper.cc b/third_party/xla/xla/service/gpu/transforms/fusion_wrapper.cc index 16957f80d370e0..013c48228631a3 100644 --- a/third_party/xla/xla/service/gpu/transforms/fusion_wrapper.cc +++ b/third_party/xla/xla/service/gpu/transforms/fusion_wrapper.cc @@ -117,7 +117,9 @@ absl::StatusOr FusionWrapper::Run( auto* fusion_instruction = computation->AddInstruction(HloInstruction::CreateFusion( instruction->shape(), - ChooseFusionKind(*instruction, *instruction), instruction)); + ChooseFusionKind(*instruction, *instruction, + device_description_), + instruction)); const absl::string_view wrapped_opcode = HloOpcodeString(instruction->opcode()); module->SetAndUniquifyInstrName( diff --git a/third_party/xla/xla/service/gpu/transforms/fusion_wrapper.h b/third_party/xla/xla/service/gpu/transforms/fusion_wrapper.h index 1e9a085fbb0b26..fec5e13424d79f 100644 --- 
a/third_party/xla/xla/service/gpu/transforms/fusion_wrapper.h +++ b/third_party/xla/xla/service/gpu/transforms/fusion_wrapper.h @@ -20,6 +20,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/stream_executor/device_description.h" namespace xla { namespace gpu { @@ -28,12 +29,17 @@ namespace gpu { // have no LHLO equivalent in fusions containing just that instruction. class FusionWrapper : public HloModulePass { public: + explicit FusionWrapper(const se::DeviceDescription& device_description) + : device_description_(device_description) {} absl::string_view name() const override { return "fusion-wrapper"; } using HloPassInterface::Run; absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; + + private: + const se::DeviceDescription& device_description_; }; } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/transforms/fusion_wrapper_test.cc b/third_party/xla/xla/service/gpu/transforms/fusion_wrapper_test.cc index be5f0d7dfd49c6..8d95298e383a53 100644 --- a/third_party/xla/xla/service/gpu/transforms/fusion_wrapper_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/fusion_wrapper_test.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "xla/service/gpu/transforms/fusion_wrapper.h" +#include #include #include @@ -23,7 +24,25 @@ namespace xla { namespace gpu { namespace { -class FusionWrapperTest : public HloTestBase {}; +auto MakeDeviceDescription() { + stream_executor::DeviceDescription device_description{ + stream_executor::GpuDeviceInfoProto{}}; + device_description.set_threads_per_warp(32); + return device_description; +} + +class FusionWrapperTest : public HloTestBase { + public: + using HloTestBase::HloTestBase; + + const stream_executor::DeviceDescription& device_description() const { + return device_description_; + } + + private: + const stream_executor::DeviceDescription device_description_{ + MakeDeviceDescription()}; +}; TEST_F(FusionWrapperTest, ConvolutionWorks) { RunAndFilecheckHloRewrite(R"(HloModule TestModule @@ -33,7 +52,7 @@ ENTRY TestComputation { kernel = f32[20,1,2,1,4,15]{5,4,3,2,1,0} parameter(1) ROOT conv = f32[15,1,9,1,7,5]{5,4,3,2,1,0} convolution(input, kernel), dim_labels=0123bf_i0123o->f0123b, window={size=1x2x1x4} })", - FusionWrapper(), R"( + FusionWrapper(device_description()), R"( // CHECK: %wrapped_convolution_computation (param_0: f32[1,10,1,10,5,20], param_1: f32[20,1,2,1,4,15]) -> f32[15,1,9,1,7,5] { // CHECK: %param_0 = f32[1,10,1,10,5,20]{5,4,3,2,1,0} parameter(0) // CHECK: %param_1 = f32[20,1,2,1,4,15]{5,4,3,2,1,0} parameter(1) @@ -56,7 +75,7 @@ TEST_F(FusionWrapperTest, SimpleOp) { p1 = f16[30,41] parameter(1) ROOT result = f16[60, 41] concatenate(p0, p1), dimensions={0} })", - FusionWrapper(), R"( + FusionWrapper(device_description()), R"( // CHECK: %wrapped_concatenate_computation (param_0: f16[30,41], param_1: f16[30,41]) -> f16[60,41] { // CHECK: %param_0 = f16[30,41]{1,0} parameter(0) // CHECK: %param_1 = f16[30,41]{1,0} parameter(1) @@ -90,7 +109,7 @@ TEST_F(FusionWrapperTest, Scatter) { index_vector_dim=0, to_apply=update_s32 })", - FusionWrapper(), R"( + 
FusionWrapper(device_description()), R"( // CHECK: wrapped_scatter_computation // CHECK: %[[param_0:.*]] = s32[] parameter(0) // CHECK: %[[param_1:.*]] = s32[0]{0} parameter(1) @@ -119,7 +138,7 @@ TEST_F(FusionWrapperTest, ControlDependency) { constant_one = f32[] constant(1) ROOT add = f32[] add(param, constant_one), control-predecessors={fusion} })", - FusionWrapper(), R"( + FusionWrapper(device_description()), R"( // CHECK: ROOT %wrapped_add = f32[] fusion(%param.1, %constant_one), // CHECK-SAME: control-predecessors={%fusion})"); } @@ -146,7 +165,7 @@ TEST_F(FusionWrapperTest, While) { %tuple = (f32[5]{0}) tuple(f32[5]{0} %copy.3) ROOT %while.19 = (f32[5]{0}) while((f32[5]{0}) %tuple), condition=%cond, body=%body })", - FusionWrapper(), R"( + FusionWrapper(device_description()), R"( // CHECK: %wrapped_broadcast_computation {{.*}} { // CHECK: %param_0.1 = f32[] parameter(0) // CHECK: ROOT %broadcast.0 = f32[5]{0} broadcast(%param_0.1), dimensions={} @@ -200,7 +219,7 @@ TEST_F(FusionWrapperTest, WhileInFusion) { %parameter.1 = f32[5]{0} parameter(0) ROOT %fusion = (f32[5]{0}) fusion(f32[5]{0} %parameter.1), kind=kLoop, calls=%fusion })", - FusionWrapper(), + FusionWrapper(device_description()), // No change std::nullopt); } diff --git a/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.cc b/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.cc index befe869ac072df..89b7bf3e082e8f 100644 --- a/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion.cc @@ -17,6 +17,8 @@ limitations under the License. #include #include +#include +#include #include #include "absl/container/flat_hash_set.h" @@ -42,9 +44,11 @@ namespace gpu { namespace { // Gets the representative input shape of the multi-output fusion. 
-Shape GetInputShapeForMultiOutputFusion(const HloInstruction& instr) { +Shape GetInputShapeForMultiOutputFusion( + const HloInstruction& instr, const se::DeviceDescription& device_info) { // Get the HLO that determines the emitter used for lowering. - const HloInstruction* real_hero = GetRealHeroForMultiOutputFusion(instr); + const HloInstruction* real_hero = + GetRealHeroForMultiOutputFusion(instr, device_info); if (real_hero->operands().empty()) { // Simply return an empty shape if the representative node has no input // operands. @@ -71,23 +75,23 @@ class HorizontalInputFusionImpl { // Compares one-by-one the dimensions of `shape_a` and `shape_b` from left to // right. -bool CompareShapeDimsFromLeftToRight(const Shape& shape_a, - const Shape& shape_b) { - if (shape_a.rank() != shape_b.rank()) { - return shape_a.rank() < shape_b.rank(); - } - auto dims_a = shape_a.dimensions(); - auto dims_b = shape_b.dimensions(); - for (size_t i = 0; i < dims_a.size(); ++i) { - if (dims_a[i] != dims_b[i]) { - return dims_a[i] < dims_b[i]; - } - } - return true; -} +// bool CompareShapeDimsFromLeftToRight(const Shape& shape_a, +// const Shape& shape_b) { +// if (shape_a.rank() != shape_b.rank()) { +// return shape_a.rank() < shape_b.rank(); +// } +// auto dims_a = shape_a.dimensions(); +// auto dims_b = shape_b.dimensions(); +// for (size_t i = 0; i < dims_a.size(); ++i) { +// if (dims_a[i] != dims_b[i]) { +// return dims_a[i] < dims_b[i]; +// } +// } +// return true; +// } std::vector FindAndSortFusionCandidates( - HloInstruction* consumer) { + HloInstruction* consumer, const se::DeviceDescription& device_info) { absl::flat_hash_set fusion_instr_set; std::vector fusion_instrs; for (HloInstruction* opnd : consumer->operands()) { @@ -95,7 +99,7 @@ std::vector FindAndSortFusionCandidates( // Find out the input fusion instructions whose only consumer is `consumer`. // This guarantees that fusing these candidates will never create cycles, as // there is no back edge. 
- if (IsInputFusibleReduction(*predecessor) && + if (IsInputFusibleReduction(*predecessor, device_info) && IsConsumerTheOnlyNonRootUser(*predecessor, *consumer)) { if (fusion_instr_set.insert(predecessor).second) { fusion_instrs.push_back(predecessor); @@ -105,16 +109,22 @@ std::vector FindAndSortFusionCandidates( std::sort(fusion_instrs.begin(), fusion_instrs.end(), [&](const HloInstruction* a, const HloInstruction* b) { - Shape shape_a = GetInputShapeForMultiOutputFusion(*a); - Shape shape_b = GetInputShapeForMultiOutputFusion(*b); - if (!ShapeUtil::EqualIgnoringElementType(shape_a, shape_b)) { + Shape shape_a = + GetInputShapeForMultiOutputFusion(*a, device_info); + Shape shape_b = + GetInputShapeForMultiOutputFusion(*b, device_info); + auto tuple_for_op = [](const Shape& shape, + const HloInstruction* op){ // Sort shapes according to dimensions, so that the same input // shapes will be placed adjacent each other. - return CompareShapeDimsFromLeftToRight(shape_a, shape_b); - } + // return CompareShapeDimsFromLeftToRight(shape_a, shape_b); + return std::tuple{shape.rank(), shape.dimensions(), + GetInstrCountOfFusible(*op), op->unique_id()}; + }; + return tuple_for_op(shape_a, a) < tuple_for_op(shape_b, b); // Sort `fusion_instrs` according to instruction counts, because // we'd like to fuse together computations of similar sizes. 
- return GetInstrCountOfFusible(*a) < GetInstrCountOfFusible(*b); + // return GetInstrCountOfFusible(*a) < GetInstrCountOfFusible(*b); }); return fusion_instrs; @@ -128,7 +138,7 @@ absl::StatusOr HorizontalInputFusionImpl::Run() { std::vector def_to_use_order = computation_->MakeInstructionPostOrder(); for (HloInstruction* consumer : def_to_use_order) { - auto candidates = FindAndSortFusionCandidates(consumer); + auto candidates = FindAndSortFusionCandidates(consumer, device_info_); if (candidates.size() <= 1) { continue; } @@ -149,7 +159,8 @@ absl::StatusOr HorizontalInputFusionImpl::Run() { for (size_t j = 1; j < candidates.size(); ++j) { HloInstruction* fusion_anchor = candidates[fusion_anchor_id]; HloInstruction* fused = candidates[j]; - if (ShapesCompatibleForMultiOutputFusion(*fusion_anchor, *fused) && + if (ShapesCompatibleForMultiOutputFusion(*fusion_anchor, *fused, + device_info_) && FusionFitsInBudget(*fusion_anchor, *fused, device_info_)) { VLOG(3) << "Fuse " << fused->ToString() << " into " << fusion_anchor->ToString(); diff --git a/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion_test.cc index 5fc1a54acd8d53..70fdd8adf42362 100644 --- a/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/horizontal_input_fusion_test.cc @@ -143,19 +143,19 @@ TEST_F(HorizontalInputFusionTest, ManyInputFusions) { builder.AddInstruction(HloInstruction::CreateTuple(var_outs)); module->AddEntryComputation(builder.Build()); - // Verify that horizontal fusion is kicked in. Check that there are multiple - // `reduce` instructions fused into the same fusion. - if (GetDebugOptionsForTest().xla_gpu_mlir_emitter_level() < 4) { - // 6 is just a randomly picked number as we don't exactly know how large the - // fusion will be created due to the `FusionFitsInBudget` constraint. 
- CompileAndVerifyIr(module->Clone(), R"(CHECK: reduce-group-6)", - /*match_optimized_ir=*/false); - } else { + // // Verify that horizontal fusion is kicked in. Check that there are multiple + // // `reduce` instructions fused into the same fusion. + // if (GetDebugOptionsForTest().xla_gpu_mlir_emitter_level() < 4) { + // // 6 is just a randomly picked number as we don't exactly know how large the + // // fusion will be created due to the `FusionFitsInBudget` constraint. + // CompileAndVerifyIr(module->Clone(), R"(CHECK: reduce-group-6)", + // /*match_optimized_ir=*/false); + // } else { // Verify that we produced a multi-output reduction with independent groups. CompileAndVerifyIr(module->Clone(), R"(CHECK: switch {{.*}} label {{.*}} [ CHECK-NEXT: label)", /*match_optimized_ir=*/false); - } + // } // Testing with the entire gpu optimization pipeline. EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1e-5, 1e-5})); diff --git a/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.cc b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.cc index 0a3d705103c416..7decdd6a8b925f 100644 --- a/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include #include @@ -43,6 +44,7 @@ limitations under the License. 
#include "xla/service/sub_byte_normalization.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/stream_executor/device_description.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" @@ -70,9 +72,12 @@ PrimitiveType GetUniqueOutputTypeOfFusible(const HloInstruction& fusible) { class HorizontalLoopFusionImpl { public: - explicit HorizontalLoopFusionImpl(HloComputation* computation, - absl::string_view prefix) - : computation_(computation), prefix_(prefix) {} + explicit HorizontalLoopFusionImpl( + HloComputation* computation, + const se::DeviceDescription& device_description, absl::string_view prefix) + : computation_(computation), + device_description_(device_description), + prefix_(prefix) {} ~HorizontalLoopFusionImpl() = default; @@ -116,18 +121,20 @@ class HorizontalLoopFusionImpl { class FusionCandidates { public: explicit FusionCandidates(HloInstruction* consumer, - bool sliced_input_fusion) + bool sliced_input_fusion, + const se::DeviceDescription& device_description) : fusible_instrs_(), pos_(0), sliced_input_fusion_(sliced_input_fusion) { - Initialize(consumer); + Initialize(consumer, device_description); } // Gets a span of fusions to be fused. absl::Span GetNextSpanOfFusions(); private: - void Initialize(HloInstruction*); + void Initialize(HloInstruction* consumer, + const se::DeviceDescription& device_description); std::vector fusible_instrs_; // `pos_` points to the start position of the next span. @@ -138,17 +145,19 @@ class HorizontalLoopFusionImpl { }; HloComputation* computation_; + const se::DeviceDescription& device_description_; std::string prefix_; }; // HorizontalLoopFusionImpl -bool IsFusibleCandidate(const HloInstruction& instr) { +bool IsFusibleCandidate(const HloInstruction& instr, + const se::DeviceDescription& device_description) { // For now, we do not support fusing instruction with control flow. 
if (!instr.control_successors().empty() || !instr.control_predecessors().empty()) { return false; } - if (IsNestableVariadicReduction(instr)) { + if (IsNestableVariadicReduction(instr, device_description)) { return false; } @@ -266,7 +275,7 @@ bool AnyOpndIsParamSharedAmongFusions( } void HorizontalLoopFusionImpl::FusionCandidates::Initialize( - HloInstruction* consumer) { + HloInstruction* consumer, const se::DeviceDescription& device_description) { // First, find out all potential target candidates. We will filter out // unsupported/non-profitable cases below. absl::flat_hash_set fusible_candidates; @@ -275,7 +284,7 @@ void HorizontalLoopFusionImpl::FusionCandidates::Initialize( HloInstruction* predecessor = opnd->LatestNonGteAncestor(); // We support kLoop fusion and element-wise HLOs now. We may extend the // support list if needs arise. - if (IsFusibleCandidate(*predecessor)) { + if (IsFusibleCandidate(*predecessor, device_description)) { if (fusible_candidates.insert(predecessor).second) { // Add unseen fusion to ordered list. ordered_fusible_candidates.push_back(predecessor); @@ -321,22 +330,33 @@ void HorizontalLoopFusionImpl::FusionCandidates::Initialize( // the fused instructions to have the same number/type of outputs and also the // same output shape. We did a sort here so the fusion candidates is // populating a continuous span. 
- std::stable_sort( - fusible_instrs_.begin(), fusible_instrs_.end(), - [&](const HloInstruction* a, const HloInstruction* b) { - if (GetUniqueOutputTypeOfFusible(*a) != - GetUniqueOutputTypeOfFusible(*b)) { - return GetUniqueOutputTypeOfFusible(*a) < - GetUniqueOutputTypeOfFusible(*b); - } else if (GetOutputSizeOfFusible(*a) != GetOutputSizeOfFusible(*b)) { - return GetOutputSizeOfFusible(*a) < GetOutputSizeOfFusible(*b); - } else if (GetInstrCountOfFusible(*a) != GetInstrCountOfFusible(*b)) { - return GetInstrCountOfFusible(*a) < GetInstrCountOfFusible(*b); - } else { - return ShapeUtil::ElementsIn(GetOutputsOfFusible(*a)[0]->shape()) < - ShapeUtil::ElementsIn(GetOutputsOfFusible(*b)[0]->shape()); - } - }); + // std::stable_sort( + // fusible_instrs_.begin(), fusible_instrs_.end(), + // [&](const HloInstruction* a, const HloInstruction* b) { + // if (GetUniqueOutputTypeOfFusible(*a) != + // GetUniqueOutputTypeOfFusible(*b)) { + // return GetUniqueOutputTypeOfFusible(*a) < + // GetUniqueOutputTypeOfFusible(*b); + // } else if (GetOutputSizeOfFusible(*a) != GetOutputSizeOfFusible(*b)) { + // return GetOutputSizeOfFusible(*a) < GetOutputSizeOfFusible(*b); + // } else if (GetInstrCountOfFusible(*a) != GetInstrCountOfFusible(*b)) { + // return GetInstrCountOfFusible(*a) < GetInstrCountOfFusible(*b); + // } else { + // return ShapeUtil::ElementsIn(GetOutputsOfFusible(*a)[0]->shape()) < + // ShapeUtil::ElementsIn(GetOutputsOfFusible(*b)[0]->shape()); + // } + // }); + std::sort(fusible_instrs_.begin(), fusible_instrs_.end(), + [&](const HloInstruction* a, const HloInstruction* b) { + auto make_tuple_for_op = [](const HloInstruction* op) { + return std::tuple{ + GetUniqueOutputTypeOfFusible(*op), + GetOutputSizeOfFusible(*op), GetInstrCountOfFusible(*op), + ShapeUtil::ElementsIn(GetOutputsOfFusible(*op)[0]->shape()), + op->unique_id()}; + }; + return make_tuple_for_op(a) < make_tuple_for_op(b); + }); } // Gets a next span of fusion instructions to be fused. 
@@ -423,7 +443,8 @@ absl::StatusOr HorizontalLoopFusionImpl::FuseConsumerOperands( HloInstruction* consumer, bool sliced_input_fusion, std::vector& to_fuse_candidates) { bool changed = false; - FusionCandidates loop_fusion_candidates(consumer, sliced_input_fusion); + FusionCandidates loop_fusion_candidates(consumer, sliced_input_fusion, + device_description_); while (true) { auto fusibles = loop_fusion_candidates.GetNextSpanOfFusions(); if (fusibles.empty()) { @@ -715,7 +736,8 @@ absl::StatusOr HorizontalLoopFusionImpl::Run() { absl::StatusOr HorizontalLoopFusion::RunOnComputation( HloComputation* computation) { - HorizontalLoopFusionImpl horizontal_fusion_impl(computation, prefix_); + HorizontalLoopFusionImpl horizontal_fusion_impl(computation, + device_description_, prefix_); return horizontal_fusion_impl.Run(); } diff --git a/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.h b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.h index a602a516b724b6..b6df65b92727c2 100644 --- a/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.h +++ b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion.h @@ -25,6 +25,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/stream_executor/device_description.h" namespace xla { namespace gpu { @@ -124,8 +125,9 @@ namespace gpu { // Note, reshapes are added only if the tensors isn't already a vector. 
class HorizontalLoopFusion : public HloModulePass { public: - HorizontalLoopFusion() = default; - explicit HorizontalLoopFusion(absl::string_view prefix) : prefix_(prefix) {} + explicit HorizontalLoopFusion(const se::DeviceDescription& device_description, + absl::string_view prefix = "") + : device_description_(device_description), prefix_(prefix) {} absl::string_view name() const override { return "horizontal_loop_fusion"; } @@ -136,6 +138,8 @@ class HorizontalLoopFusion : public HloModulePass { private: absl::StatusOr RunOnComputation(HloComputation*); + + const se::DeviceDescription& device_description_; std::string prefix_; }; diff --git a/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion_test.cc index d3fb82e9d4b05f..bd6f17b86ae7cf 100644 --- a/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/horizontal_loop_fusion_test.cc @@ -47,11 +47,19 @@ namespace { namespace m = ::xla::match; +auto MakeDeviceDescription() { + stream_executor::DeviceDescription device_description{ + stream_executor::GpuDeviceInfoProto{}}; + device_description.set_threads_per_warp(32); + return device_description; +} + class HorizontalLoopFusionTest : public HloTestBase { public: static bool IsFusion(const HloInstruction* instr) { return instr->opcode() == HloOpcode::kFusion; } + const se::DeviceDescription device_description_{MakeDeviceDescription()}; }; TEST_F(HorizontalLoopFusionTest, BasicTest) { @@ -85,7 +93,8 @@ TEST_F(HorizontalLoopFusionTest, BasicTest) { )") .value(); - EXPECT_TRUE(HorizontalLoopFusion().Run(module.get()).value()); + EXPECT_TRUE( + HorizontalLoopFusion{device_description_}.Run(module.get()).value()); TF_ASSERT_OK(verifier().Run(module.get()).status()); EXPECT_FALSE(HloDCE().Run(module.get()).value()); @@ -136,7 +145,7 @@ TEST_F(HorizontalLoopFusionTest, NegativeTestForCycle) { )") .value(); - 
EXPECT_FALSE(HorizontalLoopFusion().Run(module.get()).value()); + EXPECT_FALSE(HorizontalLoopFusion{device_description_}.Run(module.get()).value()); } TEST_F(HorizontalLoopFusionTest, NegativeTestForIncompatibleTypes) { @@ -172,7 +181,7 @@ TEST_F(HorizontalLoopFusionTest, NegativeTestForIncompatibleTypes) { )") .value(); - EXPECT_FALSE(HorizontalLoopFusion().Run(module.get()).value()); + EXPECT_FALSE(HorizontalLoopFusion{device_description_}.Run(module.get()).value()); } TEST_F(HorizontalLoopFusionTest, FusingIntoKLoopAndKInputTogether) { @@ -259,7 +268,7 @@ TEST_F(HorizontalLoopFusionTest, FusingIntoKLoopAndKInputTogether) { )") .value(); - EXPECT_TRUE(HorizontalLoopFusion().Run(module.get()).value()); + EXPECT_TRUE(HorizontalLoopFusion{device_description_}.Run(module.get()).value()); int input_fusion_count = 0; int loop_fusion_count = 0; @@ -308,7 +317,7 @@ TEST_F(HorizontalLoopFusionTest, HorizontalLoopFusionAfterVerticalFusion) { fusion.AddPass(/*may_duplicate=*/true, device_info); EXPECT_TRUE(fusion.Run(module.get()).value()); - EXPECT_TRUE(HorizontalLoopFusion().Run(module.get()).value()); + EXPECT_TRUE(HorizontalLoopFusion{device_description_}.Run(module.get()).value()); TF_ASSERT_OK(verifier().Run(module.get()).status()); VLOG(2) << "Dump after horizontal fusion:"; @@ -415,7 +424,7 @@ TEST_F(HorizontalLoopFusionTest, FusingDifferentOutputs) { )") .value(); - EXPECT_TRUE(HorizontalLoopFusion().Run(module.get()).value()); + EXPECT_TRUE(HorizontalLoopFusion{device_description_}.Run(module.get()).value()); TF_ASSERT_OK(verifier().Run(module.get()).status()); EXPECT_FALSE(HloDCE().Run(module.get()).value()); @@ -545,7 +554,7 @@ TEST_F(HorizontalLoopFusionTest, DynamicUpdateSlice) { })") .value(); - EXPECT_TRUE(HorizontalLoopFusion().Run(module.get()).value()); + EXPECT_TRUE(HorizontalLoopFusion{device_description_}.Run(module.get()).value()); TF_ASSERT_OK(verifier().Run(module.get()).status()); EXPECT_FALSE(HloDCE().Run(module.get()).value()); @@ -586,7 +595,7 
@@ TEST_F(HorizontalLoopFusionTest, NegativeTestForSharedParam) { )") .value(); - EXPECT_FALSE(HorizontalLoopFusion().Run(module.get()).value()); + EXPECT_FALSE(HorizontalLoopFusion{device_description_}.Run(module.get()).value()); } TEST_F(HorizontalLoopFusionTest, IterativeHorizontalFusion) { @@ -627,7 +636,7 @@ TEST_F(HorizontalLoopFusionTest, IterativeHorizontalFusion) { .value(); HloPassFix iterative_h_fusion("iterative_h_fusion"); - iterative_h_fusion.AddPass(); + iterative_h_fusion.AddPass(device_description_); iterative_h_fusion.AddPass(); EXPECT_TRUE(iterative_h_fusion.Run(module.get()).value()); @@ -699,7 +708,7 @@ TEST_F(HorizontalLoopFusionTest, TraversalOrder) { .value(); HloPassFix iterative_h_fusion("iterative_h_fusion"); - iterative_h_fusion.AddPass(); + iterative_h_fusion.AddPass(device_description_); EXPECT_TRUE(iterative_h_fusion.Run(module.get()).value()); // Verify that the total number of fusion instructions is 2 so that we @@ -773,7 +782,7 @@ ENTRY main { )"; auto module = ParseAndReturnUnverifiedModule(hlo_text).value(); - EXPECT_TRUE(HorizontalLoopFusion().Run(module.get()).value()); + EXPECT_TRUE(HorizontalLoopFusion{device_description_}.Run(module.get()).value()); VLOG(2) << module->ToString(); @@ -843,7 +852,7 @@ TEST_F(HorizontalLoopFusionTest, DoNotMergeVariadicReductions) { })") .value(); - EXPECT_FALSE(HorizontalLoopFusion().Run(module.get()).value()); + EXPECT_FALSE(HorizontalLoopFusion{device_description_}.Run(module.get()).value()); } } // namespace diff --git a/third_party/xla/xla/service/gpu/transforms/instruction_fusion.cc b/third_party/xla/xla/service/gpu/transforms/instruction_fusion.cc index bfd8c5bbb6b0a9..0fbefe9731738d 100644 --- a/third_party/xla/xla/service/gpu/transforms/instruction_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/instruction_fusion.cc @@ -108,15 +108,16 @@ FusionDecision GpuInstructionFusion::ShouldFuseInexpensiveChecks( // Do not fuse into fusions if the resulting kernel would suffer from // 
uncoalesced reads due to a transposed memory access pattern. - if (IsInputFusibleReduction(*consumer) && + if (IsInputFusibleReduction(*consumer, device_info_) && IsPhysicallyTransposing(*producer)) { return FusionDecision::Forbid( "fusing the producer would break read coalescing"); } - RETURN_IF_NOT_FUSIBLE(IsProducerConsumerFusible(*producer, *consumer)); + RETURN_IF_NOT_FUSIBLE(IsProducerConsumerFusible(*producer, *consumer, + device_info_)); - if (CreatesHeavyComputation(*producer, *consumer)) { + if (CreatesHeavyComputation(*producer, *consumer, device_info_)) { return FusionDecision::Forbid( "the fusion would create a heavy computation"); } @@ -160,7 +161,7 @@ FusionDecision GpuInstructionFusion::ShouldFuse(HloInstruction* consumer, HloInstruction::FusionKind GpuInstructionFusion::ChooseKind( const HloInstruction* producer, const HloInstruction* consumer) { - return ChooseFusionKind(*producer, *consumer); + return ChooseFusionKind(*producer, *consumer, device_info_); } HloInstruction* GpuInstructionFusion::FuseInstruction( diff --git a/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc b/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc index 9af3e7e04d4d47..8847322cfc26fa 100644 --- a/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc +++ b/third_party/xla/xla/service/gpu/transforms/layout_assignment.cc @@ -586,7 +586,8 @@ bool GpuLayoutAssignment::PropagateReductionLayoutToOperand( } int64_t kept_dimension_size = ShapeUtil::ElementsIn(user->shape()); return IsUnnestedReductionFasterThanElemental( - {/*is_row_reduction=*/true, {1, kept_dimension_size, reduction_size}}); + {/*is_row_reduction=*/true, {1, kept_dimension_size, reduction_size}}, + device_description_); } bool GpuLayoutAssignment::InstructionCanChangeLayoutInstance( diff --git a/third_party/xla/xla/service/gpu/transforms/layout_assignment.h b/third_party/xla/xla/service/gpu/transforms/layout_assignment.h index efa58f3f8c3c72..dec76b7f141426 100644 --- 
a/third_party/xla/xla/service/gpu/transforms/layout_assignment.h +++ b/third_party/xla/xla/service/gpu/transforms/layout_assignment.h @@ -39,10 +39,12 @@ class GpuLayoutAssignment : public LayoutAssignment { ComputationLayout* entry_computation_layout, const se::GpuComputeCapability& gpu_version, const se::dnn::VersionInfo& dnn_version, + const se::DeviceDescription& device_description, ChannelLayoutConstraints* channel_constraints = nullptr) : LayoutAssignment(entry_computation_layout, channel_constraints), gpu_version_(gpu_version), - dnn_version_(dnn_version) {} + dnn_version_(dnn_version), + device_description_ (device_description) {} ~GpuLayoutAssignment() override = default; protected: @@ -73,6 +75,7 @@ class GpuLayoutAssignment : public LayoutAssignment { const se::GpuComputeCapability gpu_version_; const se::dnn::VersionInfo dnn_version_; + const se::DeviceDescription& device_description_; }; } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/transforms/layout_assignment_test.cc b/third_party/xla/xla/service/gpu/transforms/layout_assignment_test.cc index 4dbd453e1d4850..3cab3479b00818 100644 --- a/third_party/xla/xla/service/gpu/transforms/layout_assignment_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/layout_assignment_test.cc @@ -50,28 +50,26 @@ namespace m = ::xla::match; using ::tsl::testing::IsOkAndHolds; class LayoutAssignmentTest : public HloTestBase { - public: - se::CudaComputeCapability GetCudaComputeCapability() { - return backend() - .default_stream_executor() - ->GetDeviceDescription() - .cuda_compute_capability(); - } - - se::GpuComputeCapability GetGpuComputeCapability() { - return backend() - .default_stream_executor() - ->GetDeviceDescription() - .gpu_compute_capability(); - } - - se::dnn::VersionInfo GetDnnVersion() { - // GpuLayoutAssignment has a special case heuristic for cudnn <= 7.3, but - // none of the tests trigger this heuristic. 
- return GetDnnVersionInfoOrDefault(backend().default_stream_executor(), - se::dnn::VersionInfo{8, 3, 0}); - } -}; + public: + se::DeviceDescription GetDeviceDescription() { + return backend().default_stream_executor()->GetDeviceDescription(); + } + + se::CudaComputeCapability GetCudaComputeCapability() { + return GetDeviceDescription().cuda_compute_capability(); + } + + se::GpuComputeCapability GetGpuComputeCapability() { + return GetDeviceDescription().gpu_compute_capability(); + } + + se::dnn::VersionInfo GetDnnVersion() { + // GpuLayoutAssignment has a special case heuristic for cudnn <= 7.3, but + // none of the tests trigger this heuristic. + return GetDnnVersionInfoOrDefault(backend().default_stream_executor(), + se::dnn::VersionInfo{8, 3, 0}); + } + }; TEST_F(LayoutAssignmentTest, Elementwise) { Shape ashape = ShapeUtil::MakeShape(F32, {42, 12}); @@ -110,7 +108,8 @@ TEST_F(LayoutAssignmentTest, Elementwise) { ShapeLayout(result_shape_with_layout); GpuLayoutAssignment layout_assignment( - &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); + &computation_layout, GetGpuComputeCapability(), GetDnnVersion(), + GetDeviceDescription()); EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true)); for (const HloInstruction* operand : add->operands()) { @@ -140,7 +139,8 @@ TEST_F(LayoutAssignmentTest, DotLayoutUnchangedIfValid) { module->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false); GpuLayoutAssignment layout_assignment( - &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); + &computation_layout, GetGpuComputeCapability(), GetDnnVersion(), + GetDeviceDescription()); EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true)); EXPECT_THAT(module->entry_computation()->root_instruction(), GmockMatch(m::Dot(m::Op().WithShape(F32, {5, 2, 3}, {1, 2, 0}), @@ -166,7 +166,8 @@ TEST_F(LayoutAssignmentTest, DotLayoutSetToDefaultIfDefaultValid) { module->entry_computation()->ComputeProgramShape(), 
/*ignore_layouts=*/false); GpuLayoutAssignment layout_assignment( - &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); + &computation_layout, GetGpuComputeCapability(), GetDnnVersion(), + GetDeviceDescription()); EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true)); EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -193,7 +194,8 @@ TEST_F(LayoutAssignmentTest, DotOperandLayoutSetToBatchRowsColsOtherwise) { module->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false); GpuLayoutAssignment layout_assignment( - &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); + &computation_layout, GetGpuComputeCapability(), GetDnnVersion(), + GetDeviceDescription()); EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true)); EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -219,7 +221,8 @@ TEST_F(LayoutAssignmentTest, DotOperandInconsistentDimLayouts) { module->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false); GpuLayoutAssignment layout_assignment( - &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); + &computation_layout, GetGpuComputeCapability(), GetDnnVersion(), + GetDeviceDescription()); EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true)); EXPECT_THAT( @@ -247,7 +250,8 @@ TEST_F(LayoutAssignmentTest, TransposedDotLayout) { module->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false); GpuLayoutAssignment layout_assignment( - &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); + &computation_layout, GetGpuComputeCapability(), GetDnnVersion(), + GetDeviceDescription()); EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true)); EXPECT_THAT( @@ -280,7 +284,8 @@ TEST_F(LayoutAssignmentTest, TransposedDotOfDotLayout) { module->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false); GpuLayoutAssignment layout_assignment( - &computation_layout, GetGpuComputeCapability(), 
GetDnnVersion()); + &computation_layout, GetGpuComputeCapability(), GetDnnVersion(), + GetDeviceDescription()); EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true)); // The transpose layout is not supported by dot.2. Also, we need a copy @@ -316,7 +321,8 @@ TEST_F(LayoutAssignmentTest, DotLayoutS8) { module->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false); GpuLayoutAssignment layout_assignment( - &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); + &computation_layout, GetGpuComputeCapability(), GetDnnVersion(), + GetDeviceDescription()); EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true)); EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -351,7 +357,8 @@ TEST_F(LayoutAssignmentTest, SortLayout) { module->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false); GpuLayoutAssignment layout_assignment( - &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); + &computation_layout, GetGpuComputeCapability(), GetDnnVersion(), + GetDeviceDescription()); EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true)); @@ -394,7 +401,8 @@ TEST_F(LayoutAssignmentTest, TopKLayout) { module->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false); GpuLayoutAssignment layout_assignment( - &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); + &computation_layout, GetGpuComputeCapability(), GetDnnVersion(), + GetDeviceDescription()); EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true)); @@ -421,7 +429,8 @@ TEST_F(LayoutAssignmentTest, FftLayout) { module->entry_computation()->ComputeProgramShape(), /*ignore_layouts=*/false); GpuLayoutAssignment layout_assignment( - &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); + &computation_layout, GetGpuComputeCapability(), GetDnnVersion(), + GetDeviceDescription()); EXPECT_THAT(layout_assignment.Run(module.get()), IsOkAndHolds(true)); 
EXPECT_THAT(module->entry_computation()->root_instruction(), @@ -449,7 +458,8 @@ ENTRY entry { m->entry_computation()->ComputeProgramShape()); GpuLayoutAssignment layout_assignment( - &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); + &computation_layout, GetGpuComputeCapability(), GetDnnVersion(), + GetDeviceDescription()); EXPECT_THAT(layout_assignment.Run(m.get()), IsOkAndHolds(true)); @@ -481,7 +491,8 @@ ENTRY entry { m->entry_computation()->ComputeProgramShape()); GpuLayoutAssignment layout_assignment( - &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); + &computation_layout, GetGpuComputeCapability(), GetDnnVersion(), + GetDeviceDescription()); EXPECT_THAT(layout_assignment.Run(m.get()), IsOkAndHolds(true)); @@ -509,7 +520,8 @@ ENTRY entry { m->entry_computation()->ComputeProgramShape()); GpuLayoutAssignment layout_assignment( - &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); + &computation_layout, GetGpuComputeCapability(), GetDnnVersion(), + GetDeviceDescription()); EXPECT_THAT(layout_assignment.Run(m.get()), IsOkAndHolds(true)); @@ -617,7 +629,8 @@ ENTRY main { ComputationLayout computation_layout( m->entry_computation()->ComputeProgramShape()); GpuLayoutAssignment layout_assignment( - &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); + &computation_layout, GetGpuComputeCapability(), GetDnnVersion(), + GetDeviceDescription()); EXPECT_THAT(layout_assignment.Run(m.get()), IsOkAndHolds(true)); auto reduce = m->entry_computation()->root_instruction(); @@ -647,7 +660,8 @@ ENTRY main { ComputationLayout computation_layout( m->entry_computation()->ComputeProgramShape()); GpuLayoutAssignment layout_assignment( - &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); + &computation_layout, GetGpuComputeCapability(), GetDnnVersion(), + GetDeviceDescription()); EXPECT_THAT(layout_assignment.Run(m.get()), IsOkAndHolds(true)); auto reduce = m->entry_computation()->root_instruction(); @@ 
-683,7 +697,8 @@ ENTRY main { ComputationLayout computation_layout( m->entry_computation()->ComputeProgramShape()); GpuLayoutAssignment layout_assignment( - &computation_layout, GetGpuComputeCapability(), GetDnnVersion()); + &computation_layout, GetGpuComputeCapability(), GetDnnVersion(), + GetDeviceDescription()); EXPECT_THAT(layout_assignment.Run(m.get()), IsOkAndHolds(true)); auto reduce = m->entry_computation()->root_instruction(); @@ -744,7 +759,7 @@ ENTRY %main { RunAndFilecheckHloRewrite( hlo, GpuLayoutAssignment{&computation_layout, GetGpuComputeCapability(), - GetDnnVersion()}, + GetDnnVersion(), GetDeviceDescription()}, R"( // CHECK: (f32[100,100]{1,0}, u32[], token[]) recv // CHECK: (f32[100,100]{1,0}, token[]) recv-done diff --git a/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.cc b/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.cc index 9ab9729b3b9202..41898a4977e84d 100644 --- a/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/multi_output_fusion.cc @@ -189,13 +189,13 @@ FusionDecision ProducerCandidateIsFusible( const HloDfsReachability& reachability, FusionInfoCache* fusion_info_cache, const se::DeviceDescription& device_info, GpuHloCostAnalysis* cost_analysis) { - if (!IsFusibleAsMultiOutputFusionRoot(consumer)) { + if (!IsFusibleAsMultiOutputFusionRoot(consumer, device_info)) { return FusionDecision::Forbid( "consumer not eligible as multi-output fusion root."); } RETURN_IF_NOT_FUSIBLE( - ShapesCompatibleForMultiOutputFusion(consumer, producer)); + ShapesCompatibleForMultiOutputFusion(consumer, producer, device_info)); RETURN_IF_NOT_FUSIBLE( OperandReachableFromProducer(producer, consumer, reachability)); @@ -233,7 +233,7 @@ std::vector GetProducerConsumerMultiOutputFusionCandidates( // If the producer is not a valid candidate for MOF, no need to check any of // its users. 
- if (!IsProducerMultiOutputFusible(*producer)) { + if (!IsProducerMultiOutputFusible(*producer, device_info)) { return fusion_candidates; } @@ -265,9 +265,11 @@ std::vector GetProducerConsumerMultiOutputFusionCandidates( return fusion_candidates; } -bool IsSiblingFusionCandidate(const HloInstruction* instr) { - if (instr->users().empty() || !IsFusibleAsMultiOutputFusionRoot(*instr) || - IsNestableVariadicReduction(*instr)) { +bool IsSiblingFusionCandidate(const HloInstruction* instr, + const se::DeviceDescription& device_info) { + if (instr->users().empty() || + !IsFusibleAsMultiOutputFusionRoot(*instr, device_info) || + IsNestableVariadicReduction(*instr, device_info)) { return false; } // Check if the users of multioutput fusion is not a get-tuple-element. @@ -292,7 +294,7 @@ FusionDecision CanFuseSiblings(const HloInstruction& sibling_consumer_1, } RETURN_IF_NOT_FUSIBLE(ShapesCompatibleForMultiOutputFusion( - sibling_consumer_1, sibling_consumer_2)); + sibling_consumer_1, sibling_consumer_2, device_info)); // Technically, this check is order-dependent (e.g. siblings A, B, C where // {A, B} and {B, C} overlap, but {A, C} do not. If the priority order is @@ -331,7 +333,9 @@ bool MultiOutputFusion::FuseSiblings(HloInstruction* parent, std::vector siblings; // Only consider siblings that are fusion candidates. absl::c_copy_if(parent->users(), std::back_inserter(siblings), - IsSiblingFusionCandidate); + [&](const HloInstruction* instr) { + return IsSiblingFusionCandidate(instr, device_info_); + }); // Sort the siblings such that multi-output fusion ops occur first, followed // by fusion ops, followed by unfused ops. absl::c_stable_sort(siblings, @@ -418,7 +422,7 @@ absl::StatusOr MultiOutputFusion::DoMultiOutputFusion() { std::vector defs_before_uses = computation_->MakeInstructionPostOrder(); - FusionInfoCache fusion_info_cache; + FusionInfoCache fusion_info_cache(device_info_); // Traverse the HLO in uses-before-defs order. 
for (auto it = defs_before_uses.rbegin(); it != defs_before_uses.rend(); ++it) { @@ -467,7 +471,7 @@ absl::StatusOr MultiOutputFusion::DoMultiOutputFusion() { } else { input_fusion = computation_->AddInstruction(HloInstruction::CreateFusion( consumer_for_fusion->shape(), - ChooseFusionKind(*producer, *consumer_for_fusion), + ChooseFusionKind(*producer, *consumer_for_fusion, device_info_), consumer_for_fusion)); VLOG(2) << "Fuse producer " << producer->name() << " and its consumer " << consumer_for_fusion->name() << " into " diff --git a/third_party/xla/xla/service/gpu/transforms/multi_output_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/multi_output_fusion_test.cc index abd45a15538959..1a8b55bc3b1dac 100644 --- a/third_party/xla/xla/service/gpu/transforms/multi_output_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/multi_output_fusion_test.cc @@ -1529,9 +1529,10 @@ ENTRY main { } )") .value(); - auto& debug_options = module->mutable_config().mutable_debug_options(); - debug_options.set_xla_gpu_mlir_emitter_level(3); - EXPECT_FALSE(mof_.Run(module.get()).value()); + // auto& debug_options = module->mutable_config().mutable_debug_options(); + // debug_options.set_xla_gpu_mlir_emitter_level(3); + // EXPECT_FALSE(mof_.Run(module.get()).value()); + EXPECT_TRUE(mof_.Run(module.get()).value()); } TEST_F(MultiOutputFusionTest, DoNotFuseRoot) { @@ -1765,8 +1766,8 @@ class TransposeMultiOutputFusionTest : public MultiOutputFusionTest { DebugOptions GetDebugOptionsForTest() override { DebugOptions debug_options = MultiOutputFusionTest::GetDebugOptionsForTest(); - // Only the MLIR transpose emitter supports unpadded 2D transposes. - debug_options.set_xla_gpu_mlir_emitter_level(3); + // // Only the MLIR transpose emitter supports unpadded 2D transposes. 
+ // debug_options.set_xla_gpu_mlir_emitter_level(3); return debug_options; } }; diff --git a/third_party/xla/xla/service/gpu/transforms/priority_fusion.cc b/third_party/xla/xla/service/gpu/transforms/priority_fusion.cc index c0e86818d36fe8..d09798c0c78e37 100644 --- a/third_party/xla/xla/service/gpu/transforms/priority_fusion.cc +++ b/third_party/xla/xla/service/gpu/transforms/priority_fusion.cc @@ -161,6 +161,7 @@ class PriorityFusionQueue { mlir_context_(mlir_context), fusion_analysis_cache_(fusion_analysis_cache), fusion_deduplication_cache_(fusion_deduplication_cache), + fusion_info_cache_(*device_info_), triton_softmax_priority_fusion_enabled_( triton_softmax_priority_fusion_enabled) { VLOG(2) << "Running full HLO cost analysis for " << computation_->name(); diff --git a/third_party/xla/xla/service/gpu/transforms/priority_fusion_test.cc b/third_party/xla/xla/service/gpu/transforms/priority_fusion_test.cc index 2a0254f55294e4..c7ab7660a9604a 100644 --- a/third_party/xla/xla/service/gpu/transforms/priority_fusion_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/priority_fusion_test.cc @@ -896,9 +896,10 @@ TEST_F(PriorityFusionTest, DoNotFuseProducerConsumerMergedTooLarge) { ROOT fusion2 = pred[6]{0} fusion(fusion1), kind=kInput, calls=fused_computation.2 } )"); - auto& debug_options = module->mutable_config().mutable_debug_options(); - debug_options.set_xla_gpu_mlir_emitter_level(3); - EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(false)); + // auto& debug_options = module->mutable_config().mutable_debug_options(); + // debug_options.set_xla_gpu_mlir_emitter_level(3); + // EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(false)); + EXPECT_THAT(priority_fusion_.Run(module.get()), IsOkAndHolds(true)); } TEST_F(PriorityFusionWithTritonEnabledTest, diff --git a/third_party/xla/xla/service/gpu/transforms/reduction_splitter.cc b/third_party/xla/xla/service/gpu/transforms/reduction_splitter.cc index dce9288888a8a5..c394ed88ed4d31 
100644 --- a/third_party/xla/xla/service/gpu/transforms/reduction_splitter.cc +++ b/third_party/xla/xla/service/gpu/transforms/reduction_splitter.cc @@ -35,6 +35,7 @@ limitations under the License. #include "xla/service/gpu/reduction_utils.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/stream_executor/device_description.h" #include "tsl/platform/statusor.h" namespace xla { @@ -42,14 +43,16 @@ namespace gpu { class ReductionSplitterVisitor : public DfsHloRewriteVisitor { public: - explicit ReductionSplitterVisitor(bool ignore_small_dims) - : ignore_small_dims_(ignore_small_dims) {} + ReductionSplitterVisitor(const se::DeviceDescription &device_description, + bool ignore_small_dims) + : device_description_(device_description), + ignore_small_dims_(ignore_small_dims) {} absl::Status HandleReduce(HloInstruction *reduce) override { VLOG(4) << "Input: " << reduce->ToString(); // Reductions with contiguous dimensions are lowered to efficient code. No // need to split such ops. - if (IsReductionFromOrToContiguousDimensions(*reduce)) { + if (IsReductionFromOrToContiguousDimensions(*reduce, device_description_)) { VLOG(4) << "Reduction with contiguous dimensions. 
Return."; return absl::OkStatus(); } @@ -124,15 +127,17 @@ class ReductionSplitterVisitor : public DfsHloRewriteVisitor { } private: + const se::DeviceDescription &device_description_; bool ignore_small_dims_; }; absl::StatusOr ReductionSplitter::Run( HloModule *module, const absl::flat_hash_set &execution_threads) { - TF_ASSIGN_OR_RETURN(bool changed, - ReductionSplitterVisitor(ignore_small_dims_) - .RunOnModule(module, execution_threads)); + TF_ASSIGN_OR_RETURN( + bool changed, + ReductionSplitterVisitor(device_description_, ignore_small_dims_) + .RunOnModule(module, execution_threads)); return changed; } diff --git a/third_party/xla/xla/service/gpu/transforms/reduction_splitter.h b/third_party/xla/xla/service/gpu/transforms/reduction_splitter.h index 74a5a1d6f31a71..b9aac3590b4568 100644 --- a/third_party/xla/xla/service/gpu/transforms/reduction_splitter.h +++ b/third_party/xla/xla/service/gpu/transforms/reduction_splitter.h @@ -36,12 +36,15 @@ namespace gpu { // fixpoint to split reduce ops along multiple dimensions. // // Precondition: ReductionDimensionGrouper has been run and adjacent reduce -// dimentsions have been grouped. Reduction layouts have been normalized. +// dimensions have been grouped. Reduction layouts have been normalized. 
class ReductionSplitter : public HloModulePass { public: - explicit ReductionSplitter(bool ignore_small_dims) - : ignore_small_dims_(ignore_small_dims) {} + ReductionSplitter(const se::DeviceDescription& device_description, + bool ignore_small_dims) + : device_description_(device_description), + ignore_small_dims_(ignore_small_dims) {} + absl::string_view name() const override { return "reduction-splitter"; } using HloPassInterface::Run; @@ -50,7 +53,8 @@ class ReductionSplitter : public HloModulePass { const absl::flat_hash_set& execution_threads) override; private: - bool ignore_small_dims_; + const se::DeviceDescription& device_description_; + const bool ignore_small_dims_; }; } // namespace gpu diff --git a/third_party/xla/xla/service/gpu/transforms/reduction_splitter_test.cc b/third_party/xla/xla/service/gpu/transforms/reduction_splitter_test.cc index 4b9f6fb130ed0f..d75ec4361ba812 100644 --- a/third_party/xla/xla/service/gpu/transforms/reduction_splitter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/reduction_splitter_test.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include "xla/service/pattern_matcher.h" #include "xla/service/pattern_matcher_gmock.h" #include "xla/shape_util.h" +#include "xla/stream_executor/device_description.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" @@ -32,7 +33,26 @@ namespace { namespace m = ::xla::match; -class ReductionSplitterTest : public HloTestBase {}; +auto MakeDeviceDescription() { + stream_executor::DeviceDescription device_description{ + stream_executor::GpuDeviceInfoProto{}}; + device_description.set_threads_per_warp(32); + return device_description; +} + +class ReductionSplitterTest : public HloTestBase { + public: + using HloTestBase::HloTestBase; + + auto MakeReductionSplitter(bool ignore_small_dims) const { + return ReductionSplitter{device_description_, + /*ignore_small_dims=*/ignore_small_dims}; + } + + private: + const stream_executor::DeviceDescription device_description_{ + MakeDeviceDescription()}; +}; TEST_F(ReductionSplitterTest, SplitReductionAtDimensionTwo) { auto module = ParseAndReturnVerifiedModule(R"( @@ -54,8 +74,9 @@ TEST_F(ReductionSplitterTest, SplitReductionAtDimensionTwo) { } )") .value(); - ASSERT_TRUE( - ReductionSplitter(/*ignore_small_dims=*/true).Run(module.get()).value()); + ASSERT_TRUE(MakeReductionSplitter(/*ignore_small_dims=*/true) + .Run(module.get()) + .value()); SCOPED_TRACE(module->ToString()); const HloInstruction* root_reduction = module->entry_computation()->root_instruction(); @@ -86,8 +107,9 @@ TEST_F(ReductionSplitterTest, SplitReductionAtDimensionZero) { } )") .value(); - ASSERT_TRUE( - ReductionSplitter(/*ignore_small_dims=*/false).Run(module.get()).value()); + ASSERT_TRUE(MakeReductionSplitter(/*ignore_small_dims=*/false) + .Run(module.get()) + .value()); SCOPED_TRACE(module->ToString()); const HloInstruction* root_reduction = module->entry_computation()->root_instruction(); @@ -119,11 +141,13 @@ TEST_F(ReductionSplitterTest, DontSplitReductionWithSmallDimensions) { } )") .value(); - EXPECT_FALSE( - 
ReductionSplitter(/*ignore_small_dims=*/true).Run(module.get()).value()); - EXPECT_TRUE( - ReductionSplitter(/*ignore_small_dims=*/false).Run(module.get()).value()); -} + EXPECT_FALSE(MakeReductionSplitter(/*ignore_small_dims=*/true) + .Run(module.get()) + .value()); + EXPECT_TRUE(MakeReductionSplitter(/*ignore_small_dims=*/false) + .Run(module.get()) + .value()); + } TEST_F(ReductionSplitterTest, DontSplitReductionsWithContiguousDimensions) { auto module = ParseAndReturnVerifiedModule(R"( @@ -143,8 +167,9 @@ TEST_F(ReductionSplitterTest, DontSplitReductionsWithContiguousDimensions) { } )") .value(); - EXPECT_FALSE( - ReductionSplitter(/*ignore_small_dims=*/false).Run(module.get()).value()); + EXPECT_FALSE(MakeReductionSplitter(/*ignore_small_dims=*/false) + .Run(module.get()) + .value()); } } // namespace diff --git a/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.cc b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.cc index 4b2e12c1ce36b8..3c42961f95ce73 100644 --- a/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.cc +++ b/third_party/xla/xla/service/gpu/transforms/softmax_rewriter_triton.cc @@ -438,9 +438,9 @@ absl::Status RunFusionPipeline( // transform reductions. reduction_pipeline.AddPass(); reduction_pipeline.AddPass>( + device_info, /*ignore_small_reduce_dims=*/false); - reduction_pipeline.AddPass>( - device_info.gpu_compute_capability()); + reduction_pipeline.AddPass>(device_info); TF_RETURN_IF_ERROR(reduction_pipeline.Run(module).status()); diff --git a/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.cc b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.cc index 68805b1ddc3c0c..866b8a11fd7d6c 100644 --- a/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.cc +++ b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.cc @@ -30,6 +30,7 @@ limitations under the License. 
#include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/gpu_fusible.h" #include "xla/service/gpu/runtime/thunk.h" +#include "xla/stream_executor/device_description.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tsl/platform/errors.h" @@ -105,12 +106,14 @@ absl::StatusOr AnnotateStreamAttributesForCopyStart( absl::StatusOr WrapIntoFusionAndAnnotateStreamAttributes( HloInstruction* instruction, int64_t channel_id, - GpuBackendConfig& instr_gpu_config) { + GpuBackendConfig& instr_gpu_config, + const se::DeviceDescription& device_description) { auto* computation = instruction->parent(); auto* module = computation->parent(); auto* fusion_instruction = computation->AddInstruction(HloInstruction::CreateFusion( - instruction->shape(), ChooseFusionKind(*instruction, *instruction), + instruction->shape(), + ChooseFusionKind(*instruction, *instruction, device_description), instruction)); const absl::string_view wrapped_opcode = HloOpcodeString(instruction->opcode()); @@ -206,7 +209,8 @@ absl::StatusOr StreamAttributeAnnotator::Run( instr->opcode() == HloOpcode::kDynamicUpdateSlice)) { TF_ASSIGN_OR_RETURN(bool comp_result, WrapIntoFusionAndAnnotateStreamAttributes( - instr, channel_id, instr_gpu_config.value())); + instr, channel_id, instr_gpu_config.value(), + device_description_)); changed |= comp_result; continue; } diff --git a/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.h b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.h index 84428f359491fc..d773d9a4f3d6e4 100644 --- a/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.h +++ b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator.h @@ -22,6 +22,7 @@ limitations under the License. 
#include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/pass/hlo_pass_interface.h" +#include "xla/stream_executor/device_description.h" namespace xla::gpu { @@ -45,6 +46,10 @@ namespace xla::gpu { class StreamAttributeAnnotator : public HloModulePass { public: + explicit StreamAttributeAnnotator( + const se::DeviceDescription& device_description) + : device_description_(device_description) {} + absl::string_view name() const override { return "stream-attribute-annotator"; } @@ -53,6 +58,9 @@ class StreamAttributeAnnotator : public HloModulePass { absl::StatusOr Run( HloModule* module, const absl::flat_hash_set& execution_threads) override; + + private: + const se::DeviceDescription& device_description_; }; } // namespace xla::gpu diff --git a/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator_test.cc b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator_test.cc index c7d2ca59cff0e9..f4e40afa475048 100644 --- a/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/stream_attribute_annotator_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/service/gpu/backend_configs.pb.h" +#include "xla/stream_executor/device_description.h" #include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "tsl/platform/statusor.h" @@ -33,7 +34,22 @@ limitations under the License. 
namespace xla::gpu { namespace { -using StreamAttributeAnnotatorTest = HloTestBase; +auto MakeDeviceDescription() { + stream_executor::DeviceDescription device_description{ + stream_executor::GpuDeviceInfoProto{}}; + device_description.set_threads_per_warp(32); + return device_description; +} + +class StreamAttributeAnnotatorTest : public HloTestBase { + public: + const se::DeviceDescription& device_description() const { + return device_description_; + } + + private: + const se::DeviceDescription device_description_{MakeDeviceDescription()}; +}; TEST_F(StreamAttributeAnnotatorTest, AllUsersAreAnnotated) { constexpr absl::string_view kHloString = R"( @@ -53,7 +69,7 @@ TEST_F(StreamAttributeAnnotatorTest, AllUsersAreAnnotated) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString)); - StreamAttributeAnnotator attr_annotator; + StreamAttributeAnnotator attr_annotator{device_description()}; bool changed; TF_ASSERT_OK_AND_ASSIGN(changed, attr_annotator.Run(module.get())); EXPECT_TRUE(changed); @@ -85,7 +101,7 @@ TEST_F(StreamAttributeAnnotatorTest, MultipleStreamsAreCombined) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString)); - StreamAttributeAnnotator attr_annotator; + StreamAttributeAnnotator attr_annotator{device_description()}; bool changed; TF_ASSERT_OK_AND_ASSIGN(changed, attr_annotator.Run(module.get())); EXPECT_TRUE(changed); @@ -122,7 +138,7 @@ TEST_F(StreamAttributeAnnotatorTest, GTEUserIsAnnotated) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString)); - StreamAttributeAnnotator attr_annotator; + StreamAttributeAnnotator attr_annotator{device_description()}; bool changed; TF_ASSERT_OK_AND_ASSIGN(changed, attr_annotator.Run(module.get())); EXPECT_TRUE(changed); @@ -154,7 +170,7 @@ TEST_F(StreamAttributeAnnotatorTest, FusionIsAnnotated) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString)); - 
StreamAttributeAnnotator attr_annotator; + StreamAttributeAnnotator attr_annotator{device_description()}; bool changed; TF_ASSERT_OK_AND_ASSIGN(changed, attr_annotator.Run(module.get())); EXPECT_TRUE(changed); @@ -195,7 +211,7 @@ TEST_F(StreamAttributeAnnotatorTest, CopyStartIsAnnotated) { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kHloString)); - StreamAttributeAnnotator attr_annotator; + StreamAttributeAnnotator attr_annotator{device_description()}; bool changed; TF_ASSERT_OK_AND_ASSIGN(changed, attr_annotator.Run(module.get())); EXPECT_TRUE(changed); @@ -232,7 +248,7 @@ TEST_F(StreamAttributeAnnotatorTest, DynamicUpdateSliceWrappedAndAnnotated) { EXPECT_TRUE(module->has_schedule()); TF_ASSERT_OK_AND_ASSIGN(bool changed, - StreamAttributeAnnotator().Run(module.get())); + StreamAttributeAnnotator{device_description()}.Run(module.get())); EXPECT_TRUE(changed); // Check that the dynamic-update-slice instruction is wrapped in a fusion @@ -295,7 +311,7 @@ TEST_F(StreamAttributeAnnotatorTest, DynamicSliceWrappedAndAnnotated) { EXPECT_TRUE(module->has_schedule()); TF_ASSERT_OK_AND_ASSIGN(bool changed, - StreamAttributeAnnotator().Run(module.get())); + StreamAttributeAnnotator{device_description()}.Run(module.get())); EXPECT_TRUE(changed); // Check that the dynamic-slice instruction is wrapped in a fusion diff --git a/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter.cc b/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter.cc index fb023fc8cc693f..3766e9e16e873a 100644 --- a/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter.cc +++ b/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter.cc @@ -71,8 +71,9 @@ bool IsMinMaxReduction(HloReduceInstruction *reduce) { class ReductionRewriterVisitor : public DfsHloRewriteVisitor { public: - explicit ReductionRewriterVisitor(se::GpuComputeCapability gpu_version) - : gpu_version_(gpu_version) {} + explicit ReductionRewriterVisitor( + 
const se::DeviceDescription &device_description) + : device_description_(device_description) {} absl::Status HandleReduce(HloInstruction *hlo) override { auto *reduce = Cast(hlo); @@ -84,7 +85,7 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor { } ReductionDimensions reduction_dims = GetReductionKindAndContiguousComponents(*hlo); - if (ReductionIsRaceFree(config, reduction_dims)) { + if (ReductionIsRaceFree(reduction_dims, device_description_)) { VLOG(3) << "Base case: dimensions fit"; return absl::OkStatus(); } @@ -110,18 +111,19 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor { bool MatchReductionForSplit(HloReduceInstruction *reduce, const HloModuleConfig &config) { // MLIR emitters only support race-free reductions. - // TODO(jreiffers: Verify performance and implement atomics for reductions + // TODO(jreiffers): Verify performance and implement atomics for reductions // if needed. - bool reductions_via_mlir_disabled = - config.debug_options().xla_gpu_mlir_emitter_level() < 4; - if (reductions_via_mlir_disabled && IsMinMaxReduction(reduce)) { - // TODO(cheshire): Also enable for integers. - VLOG(1) << "Not performing tree expansion on min/max-reduction: " - << reduce->ToString() - << " since min/max operations are associative"; - return false; - } - if (!IsReductionFromOrToContiguousDimensions(*reduce)) { + // bool reductions_via_mlir_disabled = + // config.debug_options().xla_gpu_mlir_emitter_level() < 4; + // if (reductions_via_mlir_disabled && IsMinMaxReduction(reduce)) { + // // TODO(cheshire): Also enable for integers. 
+ // VLOG(1) << "Not performing tree expansion on min/max-reduction: " + // << reduce->ToString() + // << " since min/max operations are associative"; + // return false; + // } + if (!IsReductionFromOrToContiguousDimensions(*reduce, + device_description_)) { VLOG(3) << "Is not a reduction from or to contiguous dimensions"; return false; } @@ -136,7 +138,7 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor { uint64_t n, int64_t race_free_bound, bool is_row_reduction) { - CHECK(k1 >= k2); + CHECK_GE(k1, k2); // Keep inner reduction as race free. if (k1 > race_free_bound) { return false; @@ -200,8 +202,8 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor { // will have a power of 2 in that range. uint64_t k2 = static_cast(std::floor(std::sqrt(reduced_dim_size))); - int64_t race_free_bound = ReductionDimensionRaceFreeBound( - reduce->GetModule()->config(), reduction_dims); + int64_t race_free_bound = + ReductionDimensionRaceFreeBound(reduction_dims, device_description_); if (k2 > race_free_bound) { // This means we need more than one split. It is best to limit the n/k // dimension to the maximum size that doesn't require further splitting. 
@@ -371,7 +373,7 @@ class ReductionRewriterVisitor : public DfsHloRewriteVisitor { return ReplaceWithNewInstruction(hlo, std::move(out)); } - se::GpuComputeCapability gpu_version_; + const se::DeviceDescription &device_description_; }; absl::StatusOr TreeReductionRewriter::Run( @@ -379,7 +381,7 @@ absl::StatusOr TreeReductionRewriter::Run( const absl::flat_hash_set &execution_threads) { VLOG(5) << "Rewriter input: " << module->ToString(); TF_ASSIGN_OR_RETURN(bool changed, - ReductionRewriterVisitor(gpu_version_) + ReductionRewriterVisitor(device_description_) .RunOnModule(module, execution_threads)); VLOG(5) << "Rewriter output: " << module->ToString(); return changed; diff --git a/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter.h b/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter.h index 4002ca94d585f2..19abdbf645cb4b 100644 --- a/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter.h +++ b/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter.h @@ -75,8 +75,9 @@ namespace gpu { // class TreeReductionRewriter : public HloModulePass { public: - explicit TreeReductionRewriter(se::GpuComputeCapability gpu_version) - : gpu_version_(gpu_version) {} + explicit TreeReductionRewriter( + const se::DeviceDescription& device_description) + : device_description_(device_description) {} ~TreeReductionRewriter() override = default; absl::string_view name() const override { return "tree-reduction-rewriter"; } @@ -87,7 +88,7 @@ class TreeReductionRewriter : public HloModulePass { const absl::flat_hash_set& execution_threads) override; private: - se::GpuComputeCapability gpu_version_; + const se::DeviceDescription& device_description_; }; } // end namespace gpu diff --git a/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter_test.cc b/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter_test.cc index 91f4481a202885..639acc185f9b21 100644 --- 
a/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter_test.cc +++ b/third_party/xla/xla/service/gpu/transforms/tree_reduction_rewriter_test.cc @@ -30,22 +30,11 @@ class TreeReductionRewriterTest : public HloTestBase { public: void CheckTreeRewriter(absl::string_view hlo, std::optional expected) { -#if GOOGLE_CUDA + stream_executor::DeviceDescription device_description{ + stream_executor::GpuDeviceInfoProto{}}; + device_description.set_threads_per_warp(32); RunAndFilecheckHloRewrite( - hlo, -#if TENSORFLOW_USE_ROCM - gpu::TreeReductionRewriter{se::RocmComputeCapability { - "908" - }}, -#else - gpu::TreeReductionRewriter{se::CudaComputeCapability{8, 1}}, -#endif - expected); -#elif TENSORFLOW_USE_ROCM - RunAndFilecheckHloRewrite( - hlo, gpu::GpuTreeReductionRewriter{se::RocmComputeCapability{"908"}}, - expected); -#endif + hlo, gpu::TreeReductionRewriter{device_description}, expected); } }; diff --git a/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.cc b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.cc index b39a50bde50203..70d8887fca49db 100644 --- a/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.cc +++ b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.cc @@ -36,10 +36,15 @@ limitations under the License. 
#include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/buffer_comparator.h" #include "xla/service/gpu/ir_emission_utils.h" +#include "xla/service/gpu/transforms/fusion_wrapper.h" +#include "xla/service/gpu/transforms/priority_fusion.h" +#include "xla/service/gpu/transforms/tree_reduction_rewriter.h" +#include "xla/service/hlo_cost_analysis.h" #include "xla/service/hlo_module_config.h" #include "xla/service/shaped_buffer.h" #include "xla/shape.h" #include "xla/status_macros.h" +#include "xla/stream_executor/device_description.h" #include "xla/stream_executor/stream.h" #include "xla/tools/hlo_decomposer.h" #include "xla/util.h" @@ -70,16 +75,41 @@ absl::StatusOr AsTritonFusion( return nullptr; } -std::unique_ptr NewHloModuleFromFusion( +// Extracts the fusion, disables Triton, and re-runs the fusion pass in order +// to make sure that the fusions are suitable for the MLIR emitters and will be +// reasonably fast. Without this the generated code can be extremely slow (e.g. +// days instead of milliseconds). 
+absl::StatusOr> NewHloModuleWithoutTritonFromFusion( const HloFusionInstruction& fusion, const DebugOptions& debug_opts, - bool clear_backend_config) { + const se::DeviceDescription& gpu_device_info) { std::unique_ptr new_module = ExtractInstructionIntoNewModule(fusion); - if (clear_backend_config) { - new_module->entry_computation()->root_instruction()->clear_backend_config(); - } + new_module->mutable_config().set_debug_options(debug_opts); + new_module->mutable_config() + .mutable_debug_options() + .clear_xla_gpu_experimental_enable_triton_softmax_priority_fusion(); + + TreeReductionRewriter tree_reduction_rewriter(gpu_device_info); + TF_RETURN_IF_ERROR(tree_reduction_rewriter.Run(new_module.get()).status()); + PriorityFusion fusion_pass( + /*thread_pool=*/nullptr, gpu_device_info, HloCostAnalysis::Options{}); + TF_RETURN_IF_ERROR(fusion_pass.Run(new_module.get()).status()); + + // If the priority fusion pass above skipped some instructions, turn them + // into fusions. + FusionWrapper fusion_wrapper(gpu_device_info); + TF_RETURN_IF_ERROR(fusion_wrapper.Run(new_module.get()).status()); + + return new_module; +} + +std::unique_ptr NewHloModuleWithTritonFromFusion( + const HloFusionInstruction& fusion, const DebugOptions& debug_opts) { + std::unique_ptr new_module = + ExtractInstructionIntoNewModule(fusion); + new_module->mutable_config().set_debug_options(debug_opts); return new_module; } @@ -90,12 +120,16 @@ namespace triton_fusion_numerics_pass_internal { absl::StatusOr CompileAndRunFusion( AutotunerCompileUtil& util, const HloFusionInstruction& fusion, const AutotuneConfig& config, const DebugOptions& debug_opts, - bool clear_backend_config) { - TF_ASSIGN_OR_RETURN(std::unique_ptr executable, - util.Compile([&](const DebugOptions& opts) { - return NewHloModuleFromFusion(fusion, opts, - clear_backend_config); - })); + bool disable_triton) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr executable, + util.Compile([&](const DebugOptions& opts) { + return 
disable_triton + ? NewHloModuleWithoutTritonFromFusion( + fusion, opts, + config.GetExecutor()->GetDeviceDescription()) + : NewHloModuleWithTritonFromFusion(fusion, opts); + })); if (executable == nullptr) { return Internal("Failed to compile Triton fusion."); } @@ -157,11 +191,11 @@ absl::Status VerifyTritonFusion(AutotunerCompileUtil& util, TF_ASSIGN_OR_RETURN(auto triton_result, triton_fusion_numerics_pass_internal::CompileAndRunFusion( util, fusion, config, debug_opts, - /*clear_backend_config=*/false)); + /*disable_triton=*/false)); TF_ASSIGN_OR_RETURN(auto emitters_result, triton_fusion_numerics_pass_internal::CompileAndRunFusion( util, fusion, config, debug_opts, - /*clear_backend_config=*/true)); + /*disable_triton=*/true)); TF_ASSIGN_OR_RETURN(auto stream, config.GetStream()); auto status = triton_fusion_numerics_pass_internal::CompareBuffers( diff --git a/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.h b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.h index f23a90bff8e4b7..9329bc3a045bf0 100644 --- a/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.h +++ b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier.h @@ -58,7 +58,7 @@ namespace triton_fusion_numerics_pass_internal { absl::StatusOr CompileAndRunFusion( AutotunerCompileUtil& util, const HloFusionInstruction& fusion, const AutotuneConfig& config, const DebugOptions& debug_opts, - bool clear_backend_config); + bool disable_triton); absl::Status CompareBuffers(const ScopedShapedBuffer& current, const ScopedShapedBuffer& expected, const Shape& shape, const HloModuleConfig& config, diff --git a/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier_test.cc b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier_test.cc index dd562e07d38aa8..56e71d1dac79d4 100644 --- a/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier_test.cc +++ 
b/third_party/xla/xla/service/gpu/transforms/triton_fusion_numerics_verifier_test.cc @@ -233,7 +233,7 @@ ENTRY main { auto compilation_result = triton_fusion_numerics_pass_internal::CompileAndRunFusion( compile_util, *fusion, autotune_config, GetDebugOptionsForTest(), - /*clear_backend_config=*/false); + /*disable_triton=*/false); // Verify that the compilation with default flags fails. The compilation // fails, because the kernel will spill registers, but the error is @@ -245,6 +245,47 @@ ENTRY main { ::testing::HasSubstr("Failed to compile Triton fusion")); } +TEST_F(TritonFusionNumericsVerifierTest, VerifyThatDisablingTritonIsFast) { + // This computation results in a single Triton fusion. If that fusion is + // compiled without Triton and without rerunning the fusion pass, the + // resulting kernel is extremely slow and the test will timeout. This test + // ensures that the fusion pass is rerun. + absl::string_view hlo_text = R"( +max { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT max = f32[] maximum(p0, p1) +} + +add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT add = f32[] add(p0, p1) +} + +ENTRY computation { + p0 = f32[16384,16384] parameter(0) + reshape1 = f32[1,1,16384,16384] reshape(p0) + reshape2 = f32[1,16384,16384] reshape(p0) + constant3 = f32[] constant(-inf) + reduce0 = f32[1,16384] reduce(reshape2, constant3), dimensions={2}, to_apply=max + broadcast3 = f32[1,1,16384,16384] broadcast(reduce0), dimensions={1,2} + sub = f32[1,1,16384,16384] subtract(reshape1, broadcast3) + exp = f32[1,1,16384,16384] exponential(sub) + reshape3 = f32[1,16384,16384] reshape(exp) + constant4 = f32[] constant(0) + reduce1 = f32[1,16384] reduce(reshape3, constant4), dimensions={2}, to_apply=add + broadcast4 = f32[1,1,16384,16384] broadcast(reduce1), dimensions={1,2} + ROOT div = f32[1,1,16384,16384] divide(exp, broadcast4) +} + )"; + auto module = Module(hlo_text, ""); + + EXPECT_TRUE(HloPassHasRun(*module, 
TritonFusionNumericsVerifier::Name())); + auto fusion = TritonFusion(*module); + EXPECT_NE(fusion, nullptr); +} + INSTANTIATE_TEST_SUITE_P(TritonFusionNumericsVerifierTestSuite, TritonFusionNumericsVerifierTest, ::testing::Values(F32, F16, BF16)); diff --git a/third_party/xla/xla/tests/multioutput_fusion_test.cc b/third_party/xla/xla/tests/multioutput_fusion_test.cc index 97ee8b70575426..e57ae6e9974662 100644 --- a/third_party/xla/xla/tests/multioutput_fusion_test.cc +++ b/third_party/xla/xla/tests/multioutput_fusion_test.cc @@ -245,9 +245,10 @@ XLA_TEST_F(MultiOutputFusionTest, XLA_TEST_F(MultiOutputFusionTest, MultiOutputLoopFeedingMap) { #ifdef XLA_TEST_BACKEND_GPU - if (GetDebugOptionsForTest().xla_gpu_mlir_emitter_level() > 0) { - GTEST_SKIP() << "Nested fusions not supported on GPU with MLIR emitters."; - } + // if (GetDebugOptionsForTest().xla_gpu_mlir_emitter_level() > 0) { + // GTEST_SKIP() << "Nested fusions not supported on GPU with MLIR emitters."; + // } + GTEST_SKIP() << "Nested fusions not supported on GPU with MLIR emitters."; #endif const char* testcase = R"( HloModule m, is_scheduled=true diff --git a/third_party/xla/xla/tools/hlo_opt/gpu_hlo_llvm.hlo b/third_party/xla/xla/tools/hlo_opt/gpu_hlo_llvm.hlo index e323b0c1930dfa..a4a136d5bd0d7a 100644 --- a/third_party/xla/xla/tools/hlo_opt/gpu_hlo_llvm.hlo +++ b/third_party/xla/xla/tools/hlo_opt/gpu_hlo_llvm.hlo @@ -1,4 +1,4 @@ -// RUN: hlo-opt %s --platform=gpu --xla_gpu_mlir_emitter_level=0 --stage=llvm --xla_gpu_target_config_filename=%S/gpu_specs/%{GPU}.txtpb --split-input-file | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %s +// RUN: hlo-opt %s --platform=gpu --stage=llvm --xla_gpu_target_config_filename=%S/gpu_specs/%{GPU}.txtpb --split-input-file | FileCheck --check-prefixes=CHECK,CHECK-%{PTX} %s HloModule m @@ -10,7 +10,7 @@ add { // CHECK-LABEL: fusion -// CHECK: 2 x half +// CHECK: 4 x half ENTRY e { p1 = f16[1048576] parameter(0) i = f16[] constant(0) diff --git 
a/third_party/xla/xla/tools/hlo_opt/gpu_hlo_unoptimized_llvm.hlo b/third_party/xla/xla/tools/hlo_opt/gpu_hlo_unoptimized_llvm.hlo index 51b89db1007350..a36dd2a2271989 100644 --- a/third_party/xla/xla/tools/hlo_opt/gpu_hlo_unoptimized_llvm.hlo +++ b/third_party/xla/xla/tools/hlo_opt/gpu_hlo_unoptimized_llvm.hlo @@ -1,6 +1,7 @@ -// RUN: hlo-opt %s --xla_gpu_mlir_emitter_level=0 --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/gpu_specs/%{GPU}.txtpb | FileCheck %s +// RUN: hlo-opt %s --platform=gpu --stage=llvm-before-optimizations --xla_gpu_target_config_filename=%S/gpu_specs/%{GPU}.txtpb | FileCheck %s -// CHECK: fusion.in_bounds-true: +// CHECK: define void @fusion +// CHECK: br i1 // CHECK: br label // CHECK: concat_index_from_operand_id0: diff --git a/third_party/xla/xla/tools/hlo_opt/gpu_opt.cc b/third_party/xla/xla/tools/hlo_opt/gpu_opt.cc index 095cd94dc60ac7..5644a3f4ddc96f 100644 --- a/third_party/xla/xla/tools/hlo_opt/gpu_opt.cc +++ b/third_party/xla/xla/tools/hlo_opt/gpu_opt.cc @@ -119,7 +119,8 @@ class GpuOptProvider : public OptProvider { xla::gpu::CompileModuleToLlvmIr( optimized_module, &llvm_context, gpu_compiler->GetTargetTriple(), gpu_compiler->GetDataLayout(), platform->Name(), platform->id(), - target_config.device_description, gpu_compiler->GetCanShareBuffer(), + target_config.device_description, + gpu_compiler->GetCanShareBuffer(target_config.device_description), gpu_compiler->BufferSizeBytesFunction())); return llvm_ir::DumpToString(results.llvm_module.get()); } diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index 74eaf1166e459e..3150967231dac7 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -805,13 +805,14 @@ message DebugOptions { // 2: + Loop-like emitters // 3: + Transpose // 4: + Reduce - int64 xla_gpu_mlir_emitter_level = 303; + // int64 xla_gpu_mlir_emitter_level = 303; // The maximum number of kernels to emit with MLIR. Unlimited if 0. 
reserved 281; // was xla_gpu_max_mlir_kernels // The number of initial kernels to not emit with MLIR. Only supported kernels // are counted. reserved 282; // was xla_gpu_skip_mlir_kernels - + reserved 303; // was xla_gpu_mlir_emitter_level + // Threshold to rewrite matmul to cuBLAS or Triton (minimum combined number of // elements of both matrices in non-batch dimensions to be considered for a // rewrite).