Add Premul Sum test #2690
base: main
Conversation
Pull request overview
This PR adds test coverage for the Premul Sum operation in the XCCL distributed communication backend. The Premul Sum operation performs element-wise multiplication by a factor before reduction, combining two operations into one for better performance.
- Adds Premul Sum tests to `test_allreduce_ops`, covering multiple data types (half, float, double) and factor types (scalar and tensor)
- Includes commented-out test code for reduce operations, with a note that Premul Sum is not supported for reduce ops
- Adds Premul Sum testing to `test_reduce_scatter_ops`, comparing results against separate multiplication and sum operations
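The semantics under test can be captured in a small torch-free sketch (the function name and plain-list "tensors" are illustrative, not taken from the PR): a premultiplied sum is equivalent to scaling every rank's contribution by the factor first, then performing an ordinary element-wise sum reduction.

```python
def premul_sum(rank_tensors, factor):
    """Reference semantics of a premultiplied-sum reduction.

    Each inner list stands in for one rank's tensor; every element is
    scaled by `factor` before the element-wise sum across ranks.
    """
    length = len(rank_tensors[0])
    return [sum(factor * t[i] for t in rank_tensors) for i in range(length)]

# Fused multiply-then-reduce gives the same result as doing the
# multiplication separately and then a plain sum reduction:
ranks = [[1.0, 2.0], [3.0, 4.0]]
assert premul_sum(ranks, 2.0) == [8.0, 12.0]
```

This mirrors how the reduce-scatter test validates results: the fused operation's output is compared against a separate multiplication followed by a sum.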
```python
for factor in (3.0, torch.tensor([5.0], device=self.rank)):
    if isinstance(factor, torch.Tensor):
        factor_ref = factor.cpu().item()
    else:
        factor_ref = factor
    output = [t.float() for t in output]
    tensor_lists = [[t.float() for t in tl] for tl in tensor_lists]
    output_ref = [t.float() for t in output]
    tensor_lists_ref = [
        [t.float() * factor_ref for t in tl] for tl in tensor_lists
```
Copilot AI · Jan 6, 2026
The variables `output` and `tensor_lists` are rebound inside the loop body, so subsequent iterations operate on the already-converted float tensors rather than the original inputs, which can mask dtype-specific failures or cause unexpected behavior. These conversions should either be moved outside the loop or replaced with fresh copies on each iteration.
Suggested change (replacing the snippet quoted above):

```python
# Create baseline float tensors once, and clone them inside the loop to
# avoid accumulating mutations across iterations.
base_output = [t.float() for t in output]
base_tensor_lists = [[t.float() for t in tl] for tl in tensor_lists]
for factor in (3.0, torch.tensor([5.0], device=self.rank)):
    if isinstance(factor, torch.Tensor):
        factor_ref = factor.cpu().item()
    else:
        factor_ref = factor
    # Fresh copies for this iteration
    output = [t.clone() for t in base_output]
    tensor_lists = [[t.clone() for t in tl] for tl in base_tensor_lists]
    output_ref = [t.clone() for t in base_output]
    tensor_lists_ref = [
        [t.clone() * factor_ref for t in tl] for tl in base_tensor_lists
```
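The pitfall flagged here can be reproduced with a torch-free sketch (plain lists stand in for tensors; the function names are illustrative): when the loop variable is rebound to the converted result, iteration two starts from iteration one's output instead of the original inputs.

```python
def scale_rebinding(values, factors):
    # Buggy pattern: `values` is rebound each iteration, so the second
    # iteration operates on the first iteration's output.
    for f in factors:
        values = [v * f for v in values]
    return values

def scale_fresh_copies(values, factors):
    # Fixed pattern: keep a pristine baseline and start from a copy of it
    # on every iteration.
    base = list(values)
    for f in factors:
        values = [v * f for v in base]
    return values

assert scale_rebinding([1.0, 2.0], (3.0, 5.0)) == [15.0, 30.0]   # factors compound
assert scale_fresh_copies([1.0, 2.0], (3.0, 5.0)) == [5.0, 10.0]  # last factor only
```

In the actual test the multiplier is `t.float()`, which is idempotent, so the visible symptom is subtler: after the first iteration the half/double dtypes are gone and later iterations no longer exercise them.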
```python
if isinstance(factor, torch.Tensor):
    factor_ref = factor.cpu().item()
else:
    factor_ref = factor
```
Copilot AI · Jan 6, 2026
This pattern of extracting a scalar reference from a factor (lines 773-776) is duplicated from earlier in the test file. Consider extracting this into a helper function to reduce code duplication and improve maintainability.
No description provided.