From 838e153bf259becc26d9eabcbb258f0bd35bcc92 Mon Sep 17 00:00:00 2001 From: Xiang Gao Date: Wed, 17 Dec 2025 23:58:04 -0800 Subject: [PATCH] Check consistency of multiplier --- csrc/multidevice/execution_utils.cpp | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/csrc/multidevice/execution_utils.cpp b/csrc/multidevice/execution_utils.cpp index a7a7da703e2..5caee19c070 100644 --- a/csrc/multidevice/execution_utils.cpp +++ b/csrc/multidevice/execution_utils.cpp @@ -68,6 +68,9 @@ std::vector unshardedSizes( "Producing logical axis not found for ", sharded_id); + // Global map to track extent -> multiplier relationships + static std::unordered_map extent_to_multiplier_map; + auto multiplier = [&]() -> int64_t { if (parallel_type == ParallelType::Stream) { // TODO(#5525): hack for MultiDeviceExecutor. MultiDeviceExecutor looks @@ -101,6 +104,22 @@ std::vector unshardedSizes( NVF_THROW("Unexpected parallel type: ", parallel_type); }(); + + // Check consistency: for the same extent, we should always get the same multiplier + Val* extent = sharded_id->extent(); + auto it = extent_to_multiplier_map.find(extent); + if (it != extent_to_multiplier_map.end()) { + NVF_ERROR( + it->second == multiplier, + "Inconsistent multiplier for extent ", + extent->toString(), + ": expected ", + it->second, + " but got ", + multiplier); + } else { + extent_to_multiplier_map[extent] = multiplier; + } unsharded_sizes.at(sharded_axis) *= multiplier; }