diff --git a/.gitignore b/.gitignore index 1d407a7..6ea5777 100644 --- a/.gitignore +++ b/.gitignore @@ -167,3 +167,27 @@ cython_debug/ # Use wildcards as well *~ *.o +# Miscellaneous files generated by DGraph data processing +skbuild/ +.vscode/ +logs/ +torchrun_* +*.png +rdvz +*.pt +*.core +*.graph +*.out +*.gz +data_processed +*.zip +cache +graph_cache +*.nsys-rep +*.nsys +*.pth +*.pyc +*.npy +*.npz +*.sqlite +*.csv \ No newline at end of file diff --git a/DGraph/CommunicatorBase.py b/DGraph/CommunicatorBase.py index 502c841..28e5049 100644 --- a/DGraph/CommunicatorBase.py +++ b/DGraph/CommunicatorBase.py @@ -11,21 +11,40 @@ # https://github.com/LBANN and https://github.com/LLNL/LBANN. # # SPDX-License-Identifier: (Apache-2.0) -class CommunicatorBase: - _is_initialized = False +from abc import ABC, abstractmethod + + +class CommunicatorBase(ABC): + _is_initialized: bool = False def __init__(self): self.backend = "" pass + @abstractmethod def init_process_group(self, backend: str, **kwargs): raise NotImplementedError + @abstractmethod def get_rank(self) -> int: raise NotImplementedError + @abstractmethod def get_world_size(self) -> int: raise NotImplementedError - def barrier(self): + @abstractmethod + def barrier(self) -> None: + raise NotImplementedError + + @abstractmethod + def scatter(self, *args, **kwargs): + raise NotImplementedError + + @abstractmethod + def gather(self, *args, **kwargs): + raise NotImplementedError + + @abstractmethod + def destroy(self) -> None: raise NotImplementedError diff --git a/DGraph/__init__.py b/DGraph/__init__.py index abd700a..1d77cc6 100644 --- a/DGraph/__init__.py +++ b/DGraph/__init__.py @@ -11,6 +11,8 @@ # https://github.com/LBANN and https://github.com/LLNL/LBANN. # # SPDX-License-Identifier: (Apache-2.0) +from DGraph.Communicator import Communicator +from DGraph.CommunicatorBase import CommunicatorBase """DGraph.
diff --git a/DGraph/distributed/Engine.py b/DGraph/distributed/Engine.py index 19e7774..547aada 100644 --- a/DGraph/distributed/Engine.py +++ b/DGraph/distributed/Engine.py @@ -50,7 +50,7 @@ def scatter( output_size: int, rank_mappings: Optional[torch.Tensor] = None, *args, - **kwargs + **kwargs, ) -> torch.Tensor: raise NotImplementedError @@ -60,7 +60,7 @@ def gather( indices: Union[torch.Tensor, torch.LongTensor], rank_mappings: Optional[torch.Tensor] = None, *args, - **kwargs + **kwargs, ) -> torch.Tensor: raise NotImplementedError diff --git a/DGraph/distributed/RankLocalOps.py b/DGraph/distributed/RankLocalOps.py index c4b6de0..b7302f1 100644 --- a/DGraph/distributed/RankLocalOps.py +++ b/DGraph/distributed/RankLocalOps.py @@ -16,9 +16,15 @@ """ import torch +import torch.distributed as dist try: - from DGraph.torch_local import local_masked_gather, local_masked_scatter + from DGraph.torch_local import ( + local_masked_gather, + local_masked_scatter, + local_masked_scatter_gather, + local_masked_scatter_add_gather, + ) _LOCAL_OPT_KERNELS_AVAILABLE = True except ImportError: @@ -81,6 +87,93 @@ def OptimizedRankLocalMaskedGather( return output +def OptimizedLocalScatterGather( + src: torch.Tensor, + src_indices: torch.Tensor, + dst_indices: torch.Tensor, + output: torch.Tensor, +): + """ + Performs the operation + + for i in range(len(src_indices)): + output[dst_indices[i]] = src[src_indices[i]] + Args: + src (torch.Tensor): Source tensor + src_indices (torch.Tensor): Source indices + dst_indices (torch.Tensor): Destination indices + output (torch.Tensor): Output tensor + Returns: + torch.Tensor: Output tensor after scatter-gather + """ + + if not _LOCAL_OPT_KERNELS_AVAILABLE: + warnings.warn( + "Optimized local kernels are not available. Falling back to the default implementation." + ) + output[:, dst_indices, :] = src[:, src_indices, :] + else: + bs = src.shape[0] + num_src_rows = src.shape[1] + num_features = src.shape[-1] + num_output_rows = output.shape[1] + local_masked_scatter_gather( + src, + src_indices.cuda(), + dst_indices.cuda(), + output, + bs, + num_src_rows, + num_features, + num_output_rows, + ) + return output + + +def OptimizedLocalScatterSumGather( + src: torch.Tensor, + src_indices: torch.Tensor, + dst_indices: torch.Tensor, + output: torch.Tensor, +): + """ + Performs the operation + + for i in range(len(src_indices)): + output[dst_indices[i]] += src[src_indices[i]] + Args: + src (torch.Tensor): Source tensor + src_indices (torch.Tensor): Source indices + dst_indices (torch.Tensor): Destination indices + output (torch.Tensor): Output tensor + Returns: + torch.Tensor: Output tensor after scatter-sum-gather + """ + + if not _LOCAL_OPT_KERNELS_AVAILABLE: + warnings.warn( + "Optimized local kernels are not available. Falling back to the default implementation."
+ ) + for i in range(src_indices.shape[0]): + output[:, dst_indices[i], :] += src[:, src_indices[i], :] + else: + bs = src.shape[0] + num_src_rows = src.shape[1] + num_features = src.shape[-1] + num_output_rows = output.shape[1] + local_masked_scatter_add_gather( + src, + src_indices.cuda(), + dst_indices.cuda(), + output, + bs, + num_src_rows, + num_features, + num_output_rows, + ) + return output + + + def OutOfPlaceRankLocalMaskedGather( _src: torch.Tensor, indices: torch.Tensor, rank_mapping: torch.Tensor, rank: int ) -> torch.Tensor: @@ -140,7 +233,9 @@ def RankLocalRenumberingWithMapping(_indices, rank_mapping): unique_indices, inverse_indices = torch.unique(_indices, return_inverse=True) rank_mapping = rank_mapping.to(_indices.device) renumbered_indices = inverse_indices - unique_rank_mapping = torch.zeros_like(unique_indices, dtype=rank_mapping.dtype, device=rank_mapping.device) + unique_rank_mapping = torch.zeros_like( + unique_indices, dtype=rank_mapping.dtype, device=rank_mapping.device + ) unique_rank_mapping.scatter_(0, inverse_indices, rank_mapping) return renumbered_indices, unique_indices, unique_rank_mapping diff --git a/DGraph/distributed/csrc/local_data_kernels.cuh b/DGraph/distributed/csrc/local_data_kernels.cuh index f12ca4a..1b2ea2b 100644 --- a/DGraph/distributed/csrc/local_data_kernels.cuh +++ b/DGraph/distributed/csrc/local_data_kernels.cuh @@ -251,4 +251,144 @@ namespace Local { } } } + + + + template <typename T> + struct FloatAtomicAddOp + { + __device__ __forceinline__ void operator()(T *cur_addr, const T new_val) + { + atomicAdd(cur_addr, new_val); + } + }; + + template <typename T> + struct FloatSetOp + { + __device__ __forceinline__ void operator()(T *cur_addr, const T new_val) + { + *cur_addr = new_val; + } + }; + + + /** + * + * Masked scatter-gather kernel that performs the operation: + Y [mask[i]] = Op(Y [mask[i]], X [indices[i]]) + + where Y is the output matrix, X is the input matrix, indices is the index matrix, and mask is the mask matrix. + */ + + template <typename Op> + __global__ void Masked_Scatter_Gather_Kernel( + const float *__restrict__ values, + const long *__restrict__ indices, + const long *__restrict__ mask, + float *__restrict__ output, + const int mini_batch_size, + const int num_indices, + const int num_cols, + const int num_output_rows) + { + const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x; + const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y; + const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z; + + const size_t nthreadsx = gridDim.x * blockDim.x; + const size_t nthreadsy = gridDim.y * blockDim.y; + const size_t nthreadsz = gridDim.z * blockDim.z; + + Op op; + + for (size_t mb_i = gidz; mb_i < mini_batch_size; mb_i += nthreadsz) + { + const auto values_offset = mb_i * num_cols * num_indices; + const auto output_offset = mb_i * num_cols * num_output_rows; + const auto ind_offset = mb_i * num_indices; + const auto mask_offset = mb_i * num_indices; + + for (size_t row = gidy; row < num_indices; row += nthreadsy) + { + const auto output_row = mask[mask_offset + row]; + const auto input_row = indices[ind_offset + row]; + + for (size_t col = gidx; col < num_cols; col += nthreadsx) + { + auto *output_addr = &output[output_offset + output_row * num_cols + col]; + const auto input_val = values[values_offset + input_row * num_cols + col]; + op(output_addr, input_val); + } + } + } + } + + /* + * + Optimized masked scatter gather kernel that performs the operation: + Y [mask[i]] = X [indices[i]] + + This kernel is optimized for the case where num_cols is a multiple of 4. + + where Y is the output matrix, X is the input matrix, indices is the index matrix, and mask is the mask matrix. + */ + template <typename Op> + __global__ void Optimized_Masked_Scatter_Gather_Kernel( + const float *__restrict__ values, + const long *__restrict__ indices, + const long *__restrict__ mask, + float *__restrict__ output, + const int mini_batch_size, + const int num_indices, + const int num_cols, + const int num_output_rows) + { + const size_t gidx = threadIdx.x + blockIdx.x * blockDim.x; + const size_t gidy = threadIdx.y + blockIdx.y * blockDim.y; + const size_t gidz = threadIdx.z + blockIdx.z * blockDim.z; + + const size_t nthreadsx = gridDim.x * blockDim.x; + const size_t nthreadsy = gridDim.y * blockDim.y; + const size_t nthreadsz = gridDim.z * blockDim.z; + + // Grid-stride loop over mini-batches + + Op binary_operator; + for (size_t mb_i = gidz; mb_i < mini_batch_size; mb_i += nthreadsz) + { + const auto values_offset = mb_i * num_cols / 4 * num_indices; + const auto output_offset = mb_i * num_cols / 4 * num_output_rows; + const auto ind_offset = mb_i * num_indices; + const auto mask_offset = mb_i * num_indices; + + // Grid-stride loop over rows + for (size_t row = gidy; row < num_indices; row += nthreadsy) + { + long output_row, input_row; + + // Lane 0 reads the row indices and broadcasts them to the rest of the warp + if (threadIdx.x == 0) + { + output_row = mask[mask_offset + row]; + input_row = indices[ind_offset + row]; + } + + output_row = __shfl_sync(0xFFFFFFFF, output_row, 0); + input_row = __shfl_sync(0xFFFFFFFF, input_row, 0); + + size_t col = gidx; + + for (; col < num_cols / 4; col += nthreadsx) + { + const float4 values_vec = reinterpret_cast<const float4 *>(values)[values_offset + input_row * num_cols / 4 + col]; + float4* output_addr = &reinterpret_cast<float4 *>(output)[output_offset + output_row * num_cols / 4 + col]; + binary_operator(output_addr, values_vec); + } + } + } + } + } // namespace Local \ No newline at end of file
diff --git a/DGraph/distributed/csrc/torch_local_bindings.cpp b/DGraph/distributed/csrc/torch_local_bindings.cpp index a91f516..fe685b6 100644 --- a/DGraph/distributed/csrc/torch_local_bindings.cpp +++ b/DGraph/distributed/csrc/torch_local_bindings.cpp @@ -21,4 +21,6 @@ PYBIND11_MODULE(torch_local, m) { m.def("local_masked_gather", &local_masked_gather, "Masked Gather"); m.def("local_masked_scatter", &local_masked_scatter, "Masked Scatter"); + m.def("local_masked_scatter_gather", &local_masked_scatter_gather, "Masked Scatter Gather"); + m.def("local_masked_scatter_add_gather", &local_masked_scatter_add_gather, "Masked Scatter Add Gather"); } diff --git a/DGraph/distributed/csrc/torch_local_kernels.cu b/DGraph/distributed/csrc/torch_local_kernels.cu index b70bf36..896050f 100644 --- a/DGraph/distributed/csrc/torch_local_kernels.cu +++ b/DGraph/distributed/csrc/torch_local_kernels.cu @@ -114,4 +114,120 @@ torch::Tensor local_masked_scatter(torch::Tensor input, rank); CUDACHECK(cudaGetLastError()); return output; +} + +torch::Tensor local_masked_scatter_gather(torch::Tensor input, + torch::Tensor indices, + torch::Tensor mask, + torch::Tensor output, + const int num_batches, + const int num_values_rows, + const int num_cols, + const int num_output_rows) +{ + CHECK_INPUT(input); + CHECK_INPUT(indices); + CHECK_INPUT(mask); + CHECK_INPUT(output); + + const float *input_ptr = input.data_ptr<float>(); + const long *indices_ptr = indices.data_ptr<long>(); + const long *mask_ptr = mask.data_ptr<long>(); + float *output_ptr = output.data_ptr<float>(); + + dim3 block_dims, grid_dims; + block_dims.x = 32; + block_dims.y = 32; + block_dims.z = 1; + + const auto num_grids_needed = (num_output_rows + block_dims.y - 1) / block_dims.y; + const auto num_col_grids_needed = (num_cols + block_dims.x - 1) / block_dims.x; + grid_dims.x = num_col_grids_needed < 65535 ? num_col_grids_needed : 65535; + grid_dims.y = num_grids_needed < 65535 ? num_grids_needed : 65535; + grid_dims.z = 1; + + at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream(input.device().index()); + + if (num_cols % 4 != 0) + { + Local::Masked_Scatter_Gather_Kernel<Local::FloatSetOp<float>><<<grid_dims, block_dims, 0, defaultStream>>>(input_ptr, + indices_ptr, + mask_ptr, + output_ptr, + num_batches, + num_values_rows, + num_cols, + num_output_rows); + } + else + { + Local::Optimized_Masked_Scatter_Gather_Kernel<Local::FloatSetOp<float4>><<<grid_dims, block_dims, 0, defaultStream>>>(input_ptr, + indices_ptr, + mask_ptr, + output_ptr, + num_batches, + num_values_rows, + num_cols, + num_output_rows); + } + CUDACHECK(cudaGetLastError()); + return output; +} + +torch::Tensor local_masked_scatter_add_gather(torch::Tensor input, + torch::Tensor indices, + torch::Tensor mask, + torch::Tensor output, + const int num_batches, + const int num_values_rows, + const int num_cols, + const int num_output_rows) +{ + CHECK_INPUT(input); + CHECK_INPUT(indices); + CHECK_INPUT(mask); + CHECK_INPUT(output); + + const float *input_ptr = input.data_ptr<float>(); + const long *indices_ptr = indices.data_ptr<long>(); + const long *mask_ptr = mask.data_ptr<long>(); + float *output_ptr = output.data_ptr<float>(); + + dim3 block_dims, grid_dims; + block_dims.x = 32; + block_dims.y = 32; + block_dims.z = 1; + + const auto num_grids_needed = (num_output_rows + block_dims.y - 1) / block_dims.y; + const auto num_col_grids_needed = (num_cols + block_dims.x - 1) / block_dims.x; + grid_dims.x = num_col_grids_needed < 65535 ? num_col_grids_needed : 65535; + grid_dims.y = num_grids_needed < 65535 ? num_grids_needed : 65535; + grid_dims.z = 1; + + at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream(input.device().index()); + + if (num_cols % 4 != 0) + { + Local::Masked_Scatter_Gather_Kernel<Local::FloatAtomicAddOp<float>><<<grid_dims, block_dims, 0, defaultStream>>>(input_ptr, + indices_ptr, + mask_ptr, + output_ptr, + num_batches, + num_values_rows, + num_cols, + num_output_rows); + } + else + { + Local::Optimized_Masked_Scatter_Gather_Kernel<Local::FloatAtomicAddOp<float4>><<<grid_dims, block_dims, 0, defaultStream>>>(input_ptr, + indices_ptr, + mask_ptr, + output_ptr, + num_batches, + num_values_rows, + num_cols, + num_output_rows); + } + CUDACHECK(cudaGetLastError()); + return output; } \ No newline at end of file diff --git a/DGraph/distributed/include/torch_local.hpp b/DGraph/distributed/include/torch_local.hpp index f780160..7a4a258 100644 --- a/DGraph/distributed/include/torch_local.hpp +++ b/DGraph/distributed/include/torch_local.hpp @@ -19,4 +19,22 @@ torch::Tensor local_masked_scatter(torch::Tensor input, const int num_values_rows, const int num_cols, const int num_output_rows, - const int rank); \ No newline at end of file + const int rank); + +torch::Tensor local_masked_scatter_gather(torch::Tensor input, + torch::Tensor indices, + torch::Tensor rank_local_placement, + torch::Tensor output, + const int num_batches, + const int num_values_rows, + const int num_cols, + const int num_output_rows); + +torch::Tensor local_masked_scatter_add_gather(torch::Tensor input, + torch::Tensor indices, + torch::Tensor rank_local_placement, + torch::Tensor output, + const int num_batches, + const int num_values_rows, + const int num_cols, + const int num_output_rows); \ No newline at end of file diff --git a/DGraph/distributed/nccl/NCCLBackendEngine.py b/DGraph/distributed/nccl/NCCLBackendEngine.py index b3ea11a..b8d2fd0 100644 --- a/DGraph/distributed/nccl/NCCLBackendEngine.py +++ b/DGraph/distributed/nccl/NCCLBackendEngine.py @@ -16,488 +16,22 @@ import torch import torch.distributed as dist from DGraph.distributed.Engine import BackendEngine -from DGraph.distributed.nccl._indices_utils import ( - _generate_local_rank_mapping, - _get_local_unique_recv_placement, -) -from DGraph.distributed.nccl._nccl_cache import NCCLGatherCache, NCCLScatterCache -from DGraph.distributed.nccl.alltoallv_impl import ( - _nccl_alltoall_v, - _nccl_alltoallv_with_dict, -) -from DGraph.distributed.RankLocalOps import ( - RankLocalMaskedGather, - RankLocalMaskedScatter, - RankLocalRenumberingWithMapping, - OptimizedRankLocalMaskedGather, +from DGraph.distributed.nccl._NCCLCommPlan import NCCLGraphCommPlan +from DGraph.distributed.nccl._torch_func_impl import ( + GatherFunction, + ScatterFunction, + CommPlan_ScatterFunction, + CommPlan_GatherFunction, ) + from torch.autograd import Function from DGraph.utils import largest_split +from typing import overload TIMINGS = {"Gather_Index_Forward": [], "Gather_Forward_Local": []} -class GatherFunction(Function): - @staticmethod - def forward( - ctx, - local_send_tensor: torch.Tensor, - indices: torch.LongTensor, - # vertex_ranks: torch.Tensor, - edge_rank_loc: torch.Tensor, - edge_dest_ranks: torch.Tensor, - rank: int, - world_size: int, - cache: Optional[NCCLGatherCache] = None, - ): - num_local_input_rows = local_send_tensor.shape[1] - - if cache is not None: - # We have a cache, use it, don't need to save anything - ctx.has_cache = True - ctx.cache = cache - # TODO: Should we cash the indices as well?
- S.Z - else: - ctx.has_cache = False - - ctx.save_for_backward( - indices, - edge_rank_loc, - edge_dest_ranks, - torch.tensor(num_local_input_rows), - torch.tensor(rank), - torch.tensor(world_size), - ) - - # Since NCCL is two-sided, we need to push from local rank and pull from - # remote rank to get the global gather - - # TODO: One possible optmization is cache all these calculations - # and only do the gather when the cache is invalidated. Essentially - # if we are working with static graphs, the indices and distribution pattern - # will not change and we can cache the communication pattern. - S.Z - - # We can also pre-compute this on the data ingestion side. Might - # be worth looking to some kind of cached communication pattern store - # that can be passed to the communicator. - S.Z - - batch_size = 1 - num_features = local_send_tensor.shape[2] - - if cache is not None: - local_indices = cache.gather_local_indices % local_send_tensor.shape[1] - local_gather_mask = cache.gather_local_comm_mask - needs_comm = cache.gather_needs_comm - local_output_rows = cache.gather_num_output_rows - local_rank_mapping = cache.gather_local_remapped_ranks - recv_tensor = torch.zeros(batch_size, local_output_rows, num_features).to( - local_send_tensor.device - ) - local_recv_tensor = cache.gather_local_recv_mapping - else: - # Get the edges that are local to the rank - - local_slice_mask = edge_rank_loc == rank - - num_local_output_rows = int(local_slice_mask.sum().item()) - - recv_tensor = torch.zeros( - batch_size, num_local_output_rows, num_features - ).to(local_send_tensor.device) - - local_indices_slice = indices[local_slice_mask.unsqueeze(0)] - local_rank_mapping = edge_rank_loc[local_slice_mask] - local_recv_tensor = edge_dest_ranks[local_slice_mask] - - # assert torch.all(local_recv_tensor == rank), local_recv_tensor - - local_indices = local_indices_slice % local_send_tensor.shape[1] - - needs_comm = (local_recv_tensor != rank).any() - - recv_tensor = OptimizedRankLocalMaskedGather( - local_send_tensor, - local_indices, - local_rank_mapping, - recv_tensor, - rank, - ) - - if needs_comm: - - recv_tensor = _nccl_alltoall_v( - local_send_tensor=local_send_tensor, - local_recv_tensor=recv_tensor, - indices=indices, - local_rank_mapping=local_recv_tensor, - edge_rank_loc=edge_rank_loc, - src_rank_loc=edge_dest_ranks, - rank=rank, - world_size=world_size, - cache=cache, - ) - - return recv_tensor - - @staticmethod - def backward(ctx, grad_output): - # We need to switch the send and recv ranks - ( - indices, - recv_ranks, - send_ranks, - # vertices_per_rank, - num_local_input_rows, - rank, - world_size, - ) = ctx.saved_tensors - - if ctx.has_cache: - cache: Optional[NCCLGatherCache] = ctx.cache - else: - cache = None - - num_local_output_rows = num_local_input_rows.item() - rank = rank.item() - world_size = world_size.item() - send_tensor = grad_output - - # Now it's a scatter operation - num_features = send_tensor.shape[-1] - device = send_tensor.device - local_rank_output = torch.zeros(1, num_local_output_rows, num_features).to( - device - ) - - indices = indices.view(-1) - local_slice_mask = recv_ranks == rank - local_indices_slice = indices[local_slice_mask] - local_dest_ranks = send_ranks[local_slice_mask] - - local_rank_output = RankLocalMaskedScatter( - send_tensor, - local_rank_output, - local_indices_slice, - local_dest_ranks, - rank, - ) - - if cache is not None: - local_comm_mask = cache.scatter_local_comm_mask - else: - local_comm_mask = local_dest_ranks != rank - - send_buffer_dict = {} 
- if torch.any(local_comm_mask): - # These rows need to be sent to other ranks - # First aggregate these into a single buffer - - if cache is not None: - num_remote_rows = cache.scatter_num_remote_rows - remapped_ranks = cache.scatter_local_remapped_ranks - renumbered_indices = cache.scatter_renumbered_indices - receiving_ranks = cache.scatter_remote_send_to_ranks - - else: - - local_comm_indices = local_indices_slice[local_comm_mask] - local_remote_dest_mappings = local_dest_ranks[local_comm_mask] - - renumbered_indices, unique_indices, remapped_ranks = ( - RankLocalRenumberingWithMapping( - local_comm_indices, local_remote_dest_mappings - ) - ) - receiving_ranks = torch.unique(local_dest_ranks[local_comm_mask]) - num_remote_rows = len(unique_indices) - - buffer = torch.zeros(1, num_remote_rows, num_features).to(device) - buffer.scatter_add_( - 1, - renumbered_indices.view(1, -1, 1).expand(1, -1, num_features), - send_tensor[:, local_comm_mask, :], - ) - - for _recv_rank in receiving_ranks: - _recv_indices = remapped_ranks == _recv_rank - send_buffer_dict[_recv_rank.item()] = buffer[:, _recv_indices, :] - - # Now we need to receive the data from the remote ranks - - recv_buffer_dict = {} - - recv_placement = {} - - if cache is not None: - recv_placement = cache.scatter_recv_local_placement - - # Allocate the receive buffers for the communication based on the - # size of the recv_placement indices. - for key, unique_send_indices in recv_placement.items(): - num_elements = unique_send_indices.shape[0] - recv_buffer_dict[key] = torch.zeros(1, num_elements, num_features).to( - device - ) - else: - send_to_rank = send_ranks # Pedantic variable name change - all_comm_mask = send_to_rank != recv_ranks - reciever_mask = send_to_rank == rank - receive_from_remote = all_comm_mask & reciever_mask - - if torch.any(receive_from_remote): - receive_from_ranks = recv_ranks[receive_from_remote] - - for _sender in range(world_size): - if _sender == rank: - continue - if torch.any(receive_from_ranks == _sender): - _send_mask = (recv_ranks == _sender) & receive_from_remote - _send_indices = indices[_send_mask] % num_local_output_rows - # TODO: This is brittle, look into a better way to do this - S.Z - - unique_send_indices = torch.unique(_send_indices) - num_elements = unique_send_indices.shape[0] - recv_buffer_dict[_sender] = torch.zeros( - 1, num_elements, num_features - ).cuda() - recv_placement[_sender] = unique_send_indices - - recv_buffer_dict = _nccl_alltoallv_with_dict( - send_buffer_dict, recv_buffer_dict, rank, world_size - ) - for key, recv_buffer in recv_buffer_dict.items(): - local_rank_output.scatter_add_( - 1, - recv_placement[key].view(1, -1, 1).expand(1, -1, num_features), - recv_buffer, - ) - - send_tensor_grad = local_rank_output - indices_grad = None - send_ranks_grad = None - recv_ranks_grad = None - rank_grad = None - world_size_grad = None - cache_grad = None - - return ( - send_tensor_grad, - indices_grad, - send_ranks_grad, - recv_ranks_grad, - rank_grad, - world_size_grad, - cache_grad, - ) - - -class ScatterFunction(Function): - @staticmethod - def forward( - ctx, - send_tensor: torch.Tensor, - indices: torch.Tensor, - edge_src_ranks: torch.Tensor, - edge_dest_ranks: torch.Tensor, - num_local_output_rows: int, - rank: int, - world_size: int, - scatter_cache: Optional[NCCLScatterCache] = None, - ) -> torch.Tensor: - - ctx.save_for_backward( - indices, - edge_src_ranks, - edge_dest_ranks, - torch.tensor(num_local_output_rows), - torch.tensor(rank), - torch.tensor(world_size), - ) - 
use_cache = scatter_cache is not None - if use_cache: - ctx.scatter_cache = scatter_cache - ctx.has_cache = True - else: - ctx.has_cache = False - - num_features = send_tensor.shape[-1] - device = send_tensor.device - - local_rank_output = torch.zeros(1, num_local_output_rows, num_features).to( - device - ) - - indices = indices.view(-1) - - local_edge_mask = edge_src_ranks == rank - - local_indices_slice = indices[local_edge_mask] - local_dest_ranks = edge_dest_ranks[local_edge_mask] - - local_rank_output = RankLocalMaskedScatter( - send_tensor, - local_rank_output, - local_indices_slice, - local_dest_ranks, - rank, - ) - - if use_cache: - local_comm_mask = scatter_cache.scatter_local_comm_mask - else: - local_comm_mask = local_dest_ranks != rank - - all_comm_mask = edge_src_ranks != edge_dest_ranks - reciever_mask = edge_dest_ranks == rank - receive_from_remote_mask = all_comm_mask & reciever_mask - - send_buffer_dict = {} - - if torch.any(local_comm_mask): - - if use_cache: - num_remote_rows = scatter_cache.scatter_num_remote_rows - remapped_ranks = scatter_cache.scatter_local_remapped_ranks - renumbered_indices = scatter_cache.scatter_local_renumbered_indices - receving_ranks = scatter_cache.scatter_remote_send_to_ranks - - else: - # These rows need to be sent to other ranks - # First aggregate these into a single buffer - local_comm_indices = local_indices_slice[local_comm_mask] - local_remote_dest_mappings = local_dest_ranks[local_comm_mask] - # TODO: This is very slow, look into a better way to do this - S.Z - # Uncached is slow, should look into augmenting torch functions - # to speed this up - S.Z - renumbered_indices, unique_indices, remapped_ranks = ( - RankLocalRenumberingWithMapping( - local_comm_indices, local_remote_dest_mappings - ) - ) - num_remote_rows = len(unique_indices) - receving_ranks = torch.unique(local_dest_ranks[local_comm_mask]) - - buffer = torch.zeros(1, num_remote_rows, num_features).to(device) - buffer.scatter_add_( - 1, - renumbered_indices.view(1, -1, 1).expand(1, -1, num_features), - send_tensor[:, local_comm_mask, :], - ) - - for _recv_rank in receving_ranks: - _recv_indices = remapped_ranks == _recv_rank - send_buffer_dict[_recv_rank.item()] = buffer[:, _recv_indices, :] - - recv_buffer_dict = {} - recv_placement = {} - if use_cache: - recv_placement = scatter_cache.scatter_recv_local_placement - else: - recv_placement = _get_local_unique_recv_placement( - indices, - edge_src_ranks, - receive_from_remote_mask, - num_local_output_rows, - rank, - world_size, - ) - - # Allocate the receive buffers for the communication based on the - # size of the recv_placement indices. 
- for key, unique_send_indices in recv_placement.items(): - num_elements = unique_send_indices.shape[0] - recv_buffer_dict[key] = torch.zeros(1, num_elements, num_features).to( - device - ) - recv_buffer_dict = _nccl_alltoallv_with_dict( - send_buffer_dict, recv_buffer_dict, rank, world_size - ) - for key, recv_buffer in recv_buffer_dict.items(): - local_rank_output.scatter_add_( - 1, - recv_placement[key].view(1, -1, 1).expand(1, -1, num_features), - recv_buffer, - ) - return local_rank_output - - @staticmethod - def backward(ctx, grad_output): - # We need to switch the send and recv ranks - indices, recv_ranks, send_ranks, num_input_rows, rank, world_size = ( - ctx.saved_tensors - ) - - local_mask = recv_ranks == rank - if ctx.has_cache: - cache: NCCLScatterCache = ctx.scatter_cache - num_local_output_rows = cache.gather_num_output_rows - - else: - rank = int(rank.item()) - world_size = int(world_size.item()) - - indices = indices.view(1, -1) - - # Now it's a gather operation - - num_local_output_rows = int(local_mask.sum().item()) - - batch_size = 1 - num_features = grad_output.shape[2] - - recv_tensor = torch.zeros(batch_size, num_local_output_rows, num_features).to( - grad_output.device - ) - - local_indices_slice = indices[0][local_mask] - local_rank_mapping = send_ranks[local_mask] - - local_indices = local_indices_slice % grad_output.shape[1] - - if len(local_indices_slice) > 0: - - recv_tensor[:, local_rank_mapping == rank, :] = RankLocalMaskedGather( - grad_output, local_indices, local_rank_mapping, rank - ) - - recv_tensor = _nccl_alltoall_v( - local_send_tensor=grad_output, - local_recv_tensor=recv_tensor, - indices=indices, - local_rank_mapping=local_rank_mapping, - edge_rank_loc=send_ranks, - src_rank_loc=recv_ranks, - rank=rank, - world_size=world_size, - ) - - # if rank == 0: - # breakpoint() - # dist.barrier() - # NOTE: even if the inputs are non-tensors, the number of backward outputs - # must be the same as the number of inputs. - send_tensor_grad = recv_tensor - indices_grad = None - send_ranks_grad = None - recv_ranks_grad = None - num_local_output_rows_grad = None - rank_grad = None - world_size_grad = None - scatter_cache_grad = None - - return ( - send_tensor_grad, - indices_grad, - send_ranks_grad, - recv_ranks_grad, - num_local_output_rows_grad, - rank_grad, - world_size_grad, - scatter_cache_grad, - ) - - class NCCLBackendEngine(BackendEngine): _is_initialized = False _rank = -1 @@ -559,66 +93,99 @@ def get_local_rank_slice(self, tensor: torch.Tensor, dim: int) -> torch.Tensor: end_index = start_index + local_size return tensor[:, start_index:end_index] + @overload def scatter( self, local_send_tensor: torch.Tensor, indices: torch.Tensor, rank_mappings: torch.Tensor, output_size: int, - cache: Optional[NCCLScatterCache] = None, - *args, - **kwargs, + ) -> torch.Tensor: ... + + @overload + def scatter( + self, + local_send_tensor: torch.Tensor, + *, + comm_plan: NCCLGraphCommPlan, + ) -> torch.Tensor: ... 
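+    # Hedged usage sketch (descriptive comment only; `engine`, `edge_feats`,
+    # `indices`, `rank_mappings`, `output_size`, and `plan` are assumed names
+    # for an initialized NCCLBackendEngine, a (1, E, F) CUDA tensor, and the
+    # matching index/rank-mapping tensors or NCCLGraphCommPlan):
+    #
+    #     out = engine.scatter(edge_feats, indices, rank_mappings, output_size)
+    #     out = engine.scatter(edge_feats, comm_plan=plan)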
+ + def scatter( + self, + local_send_tensor: torch.Tensor, + indices: Optional[torch.Tensor] = None, + rank_mappings: Optional[torch.Tensor] = None, + output_size: Optional[int] = None, + comm_plan: Optional[NCCLGraphCommPlan] = None, ) -> torch.Tensor: - send_tensor_shape = local_send_tensor.shape - b_size = send_tensor_shape[0] - world_size = self.get_world_size() - rank = self.get_rank() - assert b_size == 1, "Multi-batch gather disabled for testing" - assert len(send_tensor_shape) == 3, "Currently only support 3D tensors" - assert indices.shape[-1] == rank_mappings.shape[-1], ( - f"Indices shape: {indices.shape} and rank mappings shape: " - + f" {rank_mappings.shape} must match" - ) - assert rank_mappings.shape[0] == 2, ( - "Rank mappings shape[0] expected to be 2, " - + f"but got {rank_mappings.shape[0]}" - ) - assert ( - local_send_tensor.device.type == "cuda" - ), f"Device: {local_send_tensor.device.type} expected cuda" - assert output_size > 0, "Output size must be greater than 0" + if comm_plan is not None: + return CommPlan_ScatterFunction.apply(local_send_tensor, comm_plan) # type: ignore + else: + if indices is None or rank_mappings is None or output_size is None: + raise ValueError( + "Indices, rank mappings, and output size must be provided for NCCL backend" + ) - src_ranks = rank_mappings[0] - dest_ranks = rank_mappings[1] + send_tensor_shape = local_send_tensor.shape + b_size = send_tensor_shape[0] - use_cache = cache is not None + world_size = self.get_world_size() + rank = self.get_rank() + assert b_size == 1, "Multi-batch gather disabled for testing" + assert len(send_tensor_shape) == 3, "Currently only support 3D tensors" + assert indices.shape[-1] == rank_mappings.shape[-1], ( + f"Indices shape: {indices.shape} and rank mappings shape: " + + f" {rank_mappings.shape} must match" + ) + assert rank_mappings.shape[0] == 2, ( + "Rank mappings shape[0] expected to be 2, " + + f"but got {rank_mappings.shape[0]}" + ) + assert ( + local_send_tensor.device.type == "cuda" + ), f"Device: {local_send_tensor.device.type} expected cuda" + assert output_size > 0, "Output size must be greater than 0" - if use_cache: - assert type(cache) == NCCLScatterCache - scatter_cache = cache - else: - scatter_cache = None + src_ranks = rank_mappings[0] + dest_ranks = rank_mappings[1] - output_tensor = ScatterFunction.apply( - local_send_tensor, - indices, - src_ranks, - dest_ranks, - output_size, - rank, - world_size, - scatter_cache, - ) + output_tensor = ScatterFunction.apply( + local_send_tensor, + indices, + src_ranks, + dest_ranks, + output_size, + rank, + world_size, + ) return output_tensor # type: ignore + @overload + def gather( + self, + local_send_tensor: torch.Tensor, + indices: torch.Tensor, + rank_mappings: torch.Tensor, + **kwargs, + ) -> torch.Tensor: ... + + @overload + def gather( + self, + local_send_tensor: torch.Tensor, + *, + comm_plan: NCCLGraphCommPlan, + **kwargs, + ) -> torch.Tensor: ... 
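+    # Hedged usage sketch (descriptive comment only), mirroring the gather
+    # overloads declared above:
+    #
+    #     out = engine.gather(vertex_feats, indices, rank_mappings)
+    #     out = engine.gather(vertex_feats, comm_plan=plan)
+    #
+    # where `vertex_feats` is assumed to be a (1, num_local_vertices, F) CUDA
+    # tensor and `plan` an NCCLGraphCommPlan built for the same partition.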
+ def gather( self, local_send_tensor: torch.Tensor, indices: torch.Tensor, rank_mappings: torch.Tensor, - cache: Optional[NCCLGatherCache] = None, + comm_plan: Optional[NCCLGraphCommPlan] = None, **kwargs, ) -> torch.Tensor: """Gather the distributed tensor across all ranks according to the indices @@ -644,6 +211,9 @@ def gather( rank_mappings (torch.Tensor): The rank mappings for the gather operation """ + if comm_plan is not None: + return CommPlan_GatherFunction.apply(local_send_tensor, comm_plan) # type: ignore + send_tensor_shape = local_send_tensor.shape b_size = send_tensor_shape[0] world_size = self.get_world_size() @@ -667,14 +237,6 @@ def gather( send_rank = rank_mappings[0] recv_rank = rank_mappings[1] - use_cache = cache is not None - - if use_cache: - assert type(cache) == NCCLGatherCache, f"Invalid cache type {type(cache)}" - gather_cache = cache - else: - gather_cache = None - output_tensor = GatherFunction.apply( local_send_tensor, indices, @@ -682,7 +244,6 @@ def gather( recv_rank, rank, world_size, - gather_cache, ) dist.barrier() diff --git a/DGraph/distributed/nccl/_NCCLCommPlan.py b/DGraph/distributed/nccl/_NCCLCommPlan.py new file mode 100644 index 0000000..8fb3d48 --- /dev/null +++ b/DGraph/distributed/nccl/_NCCLCommPlan.py @@ -0,0 +1,303 @@ +import torch +from dataclasses import dataclass +from typing import List, Optional +import torch.distributed as dist + + +@dataclass +class NCCLGraphCommPlan: + """ + Class to store communication plan for distributed gather-scatter (vector addressing) + + Attributes: + rank (int): Local rank + world_size (int): World size + local_num_vertices (int): Number of local vertices + local_src_idx (torch.Tensor): Local source indices for scatter-sum + local_dst_idx (torch.Tensor): Local destination indices for scatter-sum + send_src_idx (torch.Tensor): Source indices to send to other ranks + send_buffer_idx (torch.Tensor): Buffer indices to store data to send to other ranks + send_comm_vector (torch.Tensor): Communication vector of shape [world_size] of messages to send to each rank + recv_dst_idx (torch.Tensor): Destination indices to receive from other ranks + recv_comm_vector (torch.Tensor): Communication vector of shape [world_size] of messages to + """ + + rank: int + world_size: int + + # Allocation meta data + num_local_vertices: int + num_local_edges: int + + # Local edge-vertex mapping + # + # Used for: + # 1) Local scatter-sum (edge -> vertex aggregation) + # y[local_vertex_idx] += x[local_edge_idx] + # 2) Local gather (vertex -> edge gathering) + # y[local_edge_idx] = x[local_vertex_idx] + + local_edge_idx: torch.Tensor + local_vertex_idx: torch.Tensor + + # Boundary edges (data must be sent/received to/from other ranks for gather/scatter) + + boundary_edge_idx: torch.Tensor + boundary_edge_buffer_map: torch.Tensor + boundary_edge_splits: List[int] + + # Boundary vertices (vertices that have edges on other ranks) + boundary_vertex_idx: torch.Tensor + boundary_vertex_splits: List[int] + + def to(self, device: torch.device): + self.local_edge_idx = self.local_edge_idx.to(device) + self.local_vertex_idx = self.local_vertex_idx.to(device) + self.boundary_edge_idx = self.boundary_edge_idx.to(device) + self.boundary_edge_buffer_map = self.boundary_edge_buffer_map.to(device) + self.boundary_vertex_idx = self.boundary_vertex_idx.to(device) + return self + + +@dataclass +class NCCLEdgeConditionedGraphCommPlan: + """ + Class to store communication plan for distributed gather-scatter for edge-conditioned + graphs where both source and 
destination vertices are needed. + + Attributes: + rank (int): Local rank + world_size (int): World size + + source_graph_plan (NCCLGraphCommPlan): Communication plan for source vertices + dest_graph_plan (NCCLGraphCommPlan): Communication plan for destination vertices + """ + + rank: int + world_size: int + + source_graph_plan: NCCLGraphCommPlan + dest_graph_plan: Optional[NCCLGraphCommPlan] = None + + def to(self, device: torch.device): + self.source_graph_plan = self.source_graph_plan.to(device) + if self.dest_graph_plan is not None: + self.dest_graph_plan = self.dest_graph_plan.to(device) + return self + + def reverse(self): + if self.dest_graph_plan is None: + raise ValueError("Destination graph plan is None, cannot reverse.") + return NCCLEdgeConditionedGraphCommPlan( + rank=self.rank, + world_size=self.world_size, + source_graph_plan=self.dest_graph_plan, + dest_graph_plan=self.source_graph_plan, + ) + + +def compute_edge_slices(dest_ranks, rank, my_dst_global, offset): + + is_internal = dest_ranks == rank + internal_dst_global = my_dst_global[is_internal] + internal_node_idx = internal_dst_global - offset[rank] + + internal_edge_indices = torch.nonzero(is_internal, as_tuple=True)[0] + + remote_mask = ~is_internal + + boundary_edge_indices = torch.nonzero(remote_mask, as_tuple=True)[0] + + b_dst_global = my_dst_global[remote_mask] + b_dest_ranks = dest_ranks[remote_mask] + + return ( + internal_node_idx, + internal_edge_indices, + b_dst_global, + b_dest_ranks, + boundary_edge_indices, + ) + + +def fast_2D_unique(indices_1, indices_2): + packed_keys = indices_1.to(torch.int64) << 32 | indices_2.to(torch.int64) + unique_packed, inverse_indices = torch.unique( + packed_keys, return_inverse=True, sorted=False + ) + unique_1 = unique_packed >> 32 + unique_2 = unique_packed & 0xFFFFFFFF + return unique_1, unique_2, inverse_indices + + +def COO_to_NCCLCommPlan( + rank: int, + world_size: int, + global_edges_dst: torch.Tensor, + local_edge_list: torch.Tensor, + offset: torch.Tensor, +) -> NCCLGraphCommPlan: + """ + + Convert COO (Coordinate List) format graph to NCCLGraphCommPlan for distributed gather-scatter operations. + + Args: + rank (int): Local rank + world_size (int): World size + global_edges_src (torch.Tensor): Global source indices of edges + global_edges_dst (torch.Tensor): Global destination indices of edges + vertex_rank_placement (torch.Tensor): Rank placement of vertices + local_edge_list (torch.Tensor): List of indices of local edges + offset (torch.Tensor): Offset for each rank. + The vertices are partitioned among ranks in a contiguous manner. + All vertices in the range [offset[rank], offset[rank + 1]) are assigned to the rank. 
+ + """ + + device = local_edge_list.device + my_dst_global = global_edges_dst[local_edge_list].to(device) + + if int(offset[-1].item()) > (2**32): + raise ValueError( + f"{offset[-1]}, Number of vertices exceeding {2**32}, which is not supported" + ) + + my_start = offset[rank].item() + my_end = offset[rank + 1].item() + num_local_vertices = int(my_end - my_start) + num_local_edges = local_edge_list.size(0) + + dest_ranks = torch.bucketize(my_dst_global, offset, right=True) - 1 + + # Seperate this out to reduce memory usage + ( + internal_node_idx, + internal_edge_indices, + b_dst_global, + b_dest_ranks, + boundary_edge_indices, + ) = compute_edge_slices(dest_ranks, rank, my_dst_global, offset) + + unique_ranks, unique_global_ids, inverse_indices = fast_2D_unique( + b_dest_ranks, b_dst_global + ) + + print(f"Rank {rank} has {len(boundary_edge_indices)} edges to send ") + print(f"Rank {rank} has {len(unique_ranks)} unique messages to send ") + + if len(unique_ranks) > 0: + print( + f"Rank {rank} message reduction ratio: {len(boundary_edge_indices)/len(unique_ranks)}" + ) + + boundary_edge_buffer_map = inverse_indices + + boundary_edge_splits = torch.bincount(unique_ranks, minlength=world_size).tolist() + + recv_counts_tensor = torch.zeros(world_size, dtype=torch.long, device=device) + send_counts_tensor = torch.tensor( + boundary_edge_splits, dtype=torch.long, device=device + ) + if recv_counts_tensor.device == torch.device("cpu"): + recv_counts_tensor = recv_counts_tensor.cuda() + + if send_counts_tensor.device == torch.device("cpu"): + send_counts_tensor = send_counts_tensor.cuda() + + dist.all_to_all_single(recv_counts_tensor, send_counts_tensor) + print(f"rank: {rank} recv_counts_tensor: {recv_counts_tensor}") + + boundary_node_splits = recv_counts_tensor.tolist() + + total_recv_nodes = sum(boundary_node_splits) + if total_recv_nodes > 0: + recv_global_ids = torch.empty(total_recv_nodes, dtype=torch.long, device=device) + else: + recv_global_ids = torch.empty(0, dtype=torch.long, device=device) + + if sum(send_counts_tensor) == 0: + unique_global_ids = torch.empty(0, dtype=torch.long, device=device) + + if recv_global_ids.device == torch.device("cpu"): + recv_global_ids = recv_global_ids.cuda() + + if unique_global_ids.device == torch.device("cpu"): + unique_global_ids = unique_global_ids.cuda() + + dist.all_to_all_single( + recv_global_ids, + unique_global_ids, + output_split_sizes=boundary_node_splits, + input_split_sizes=boundary_edge_splits, + ) + + boundary_node_idx = recv_global_ids - my_start + + return NCCLGraphCommPlan( + rank=rank, + world_size=world_size, + num_local_vertices=num_local_vertices, + num_local_edges=num_local_edges, + local_edge_idx=internal_edge_indices, + local_vertex_idx=internal_node_idx, + boundary_edge_idx=boundary_edge_indices, + boundary_edge_buffer_map=boundary_edge_buffer_map, + boundary_edge_splits=boundary_edge_splits, + boundary_vertex_idx=boundary_node_idx, + boundary_vertex_splits=boundary_node_splits, + ) + + +def COO_to_NCCLEdgeConditionedCommPlan( + rank: int, + world_size: int, + global_edges_src: torch.Tensor, + global_edges_dst: torch.Tensor, + local_edge_list: torch.Tensor, + src_offset: torch.Tensor, + dest_offset: Optional[torch.Tensor], +) -> NCCLEdgeConditionedGraphCommPlan: + """ + + Convert COO (Coordinate List) format graph to NCCLEdgeConditionedGraphCommPlan for distributed gather-scatter operations. 
+ + Args: + rank (int): Local rank + world_size (int): World size + global_edges_src (torch.Tensor): Global source indices of edges + global_edges_dst (torch.Tensor): Global destination indices of edges + local_edge_list (torch.Tensor): List of indices of local edges + src_offset (torch.Tensor): Offset for each rank for source vertices. + The vertices are partitioned among ranks in a contiguous manner. + All vertices in the range [src_offset[rank], src_offset[rank + 1]) are assigned to the rank. + dest_offset (Optional[torch.Tensor]): Offset for each rank for destination vertices. + The vertices are partitioned among ranks in a contiguous manner. + All vertices in the range [dest_offset[rank], dest_offset[rank + 1]) are assigned to the rank. + """ + device = local_edge_list.device + + source_plan = COO_to_NCCLCommPlan( + rank, + world_size, + global_edges_src, + local_edge_list, + src_offset, + ) + + if dest_offset is None: + dest_offset = src_offset + + dest_plan = COO_to_NCCLCommPlan( + rank, + world_size, + global_edges_dst, + local_edge_list, + dest_offset, + ) + + return NCCLEdgeConditionedGraphCommPlan( + rank=rank, + world_size=world_size, + source_graph_plan=source_plan, + dest_graph_plan=dest_plan, + ) diff --git a/DGraph/distributed/nccl/__init__.py b/DGraph/distributed/nccl/__init__.py index cf28164..aae0291 100644 --- a/DGraph/distributed/nccl/__init__.py +++ b/DGraph/distributed/nccl/__init__.py @@ -12,9 +12,9 @@ # # SPDX-License-Identifier: (Apache-2.0) from DGraph.distributed.nccl.NCCLBackendEngine import NCCLBackendEngine, TIMINGS -from DGraph.distributed.nccl._nccl_cache import ( - NCCLGatherCache, - NCCLScatterCache, - NCCLScatterCacheGenerator, - NCCLGatherCacheGenerator, +from DGraph.distributed.nccl._NCCLCommPlan import ( + NCCLGraphCommPlan, + NCCLEdgeConditionedGraphCommPlan, + COO_to_NCCLCommPlan, + COO_to_NCCLEdgeConditionedCommPlan, ) diff --git a/DGraph/distributed/nccl/_nccl_cache.py b/DGraph/distributed/nccl/_nccl_cache.py index 28a2d01..b378f1c 100644 --- a/DGraph/distributed/nccl/_nccl_cache.py +++ b/DGraph/distributed/nccl/_nccl_cache.py @@ -64,6 +64,9 @@ class NCCLScatterCache: world_size: int +# @dataclass +# class + def all_to_all_cache_helper( indices, edge_placement, edge_vertex_ranks, num_rows, rank, world_size ): @@ -200,7 +203,6 @@ def NCCLScatterCacheGenerator( recv_placement = _get_local_unique_recv_placement( indices, edge_placement, remote_recv_mask, num_output_rows, rank, world_size ) - # Information for the backward pass # It's a gather operation so quite a bit simpler @@ -253,7 +255,6 @@ def NCCLGatherCacheGenerator( indices, edge_placement, edge_dest_ranks, num_input_rows, rank, world_size ) ) - local_slice_mask = edge_placement == rank local_mask = edge_placement[local_slice_mask] diff --git a/DGraph/distributed/nccl/_torch_func_impl.py b/DGraph/distributed/nccl/_torch_func_impl.py new file mode 100644 index 0000000..71880b7 --- /dev/null +++ b/DGraph/distributed/nccl/_torch_func_impl.py @@ -0,0 +1,673 @@ +import torch +from typing import Optional +from torch.autograd import Function +import torch.distributed as dist +from dataclasses import dataclass +from DGraph.distributed.nccl._nccl_cache import NCCLGatherCache, NCCLScatterCache +from DGraph.distributed.RankLocalOps import ( + OptimizedRankLocalMaskedGather, + OptimizedLocalScatterGather, + OptimizedLocalScatterSumGather, +) +from DGraph.distributed.nccl._NCCLCommPlan import NCCLGraphCommPlan + + +class CommPlan_GatherFunction(Function): + @staticmethod + def forward( + ctx, + 
local_send_tensor: torch.Tensor, + comm_plan: NCCLGraphCommPlan, + ) -> torch.Tensor: + """ + Forward pass for distributed gather using the common plan to effectively perform: + y[i] = x[indices[i]] + + The process is as follows: + 1) Perform local gather from local vertices to local edges + 2) Gather + + Args: + ctx (torch.autograd.FunctionContext): Context object + local_send_tensor (torch.Tensor): Local send tensor + comm_plan (GatherCommPlan): Communication plan + """ + assert ( + len(local_send_tensor.shape) == 3 + ), "Local send tensor must be of shape (batch_size, num_rows, num_features)" + ctx.comm_plan = comm_plan + + num_features = local_send_tensor.shape[-1] + num_batches = local_send_tensor.shape[0] + + output_tensor = torch.zeros( + num_batches, comm_plan.num_local_edges, num_features + ).to(local_send_tensor.device) + + # Local vertex to edge gather + output_tensor = OptimizedLocalScatterGather( + src=local_send_tensor, + src_indices=comm_plan.local_edge_idx, + dst_indices=comm_plan.local_vertex_idx, + output=output_tensor, + ) + + # To do: Combine this with the local gather above to reduce kernel launches + send_buf = local_send_tensor[:, comm_plan.boundary_edge_idx, :] + + total_recv = sum(comm_plan.boundary_edge_splits) + + recv_buffer = torch.empty(num_batches, total_recv, num_features).to( + local_send_tensor.device + ) + dist.all_to_all_single( + recv_buffer, + send_buf, + output_split_sizes=comm_plan.boundary_edge_splits, + input_split_sizes=comm_plan.boundary_edge_splits, + ) + + output_tensor = OptimizedLocalScatterGather( + src=recv_buffer, + src_indices=comm_plan.boundary_edge_buffer_map, + dst_indices=comm_plan.boundary_edge_idx, + output=output_tensor, + ) + + return output_tensor + + @staticmethod + def backward(ctx, grad_output): + """ + Backward pass for distributed gather + + Args: + ctx (torch.autograd.FunctionContext): Context object + grad_output (torch.Tensor): Gradient of the output tensor. 
+ Shape: (batch_size, num_local_edges, num_features) + """ + comm_plan = ctx.comm_plan + num_features = grad_output.shape[-1] + num_batches = grad_output.shape[0] + device = grad_output.device + + grad_input = torch.zeros( + num_batches, comm_plan.num_local_vertices, num_features, device=device + ) + + grad_input = OptimizedLocalScatterSumGather( + src=grad_output, + output=grad_input, + src_indices=comm_plan.local_edge_idx, + dst_indices=comm_plan.local_vertex_idx, + ) + + send_buf = grad_output[:, comm_plan.boundary_vertex_idx, :] + total_recv = sum(comm_plan.boundary_vertex_splits) + recv_buffer = torch.empty(num_batches, total_recv, num_features).to(device) + dist.all_to_all_single( + recv_buffer, + send_buf, + output_split_sizes=comm_plan.boundary_vertex_splits, + input_split_sizes=comm_plan.boundary_edge_splits, + ) + grad_input = OptimizedLocalScatterSumGather( + src=recv_buffer, + output=grad_input, + src_indices=comm_plan.boundary_edge_buffer_map, + dst_indices=comm_plan.boundary_vertex_idx, + ) + + return grad_input, None + + +class CommPlan_ScatterFunction(Function): + @staticmethod + def forward( + ctx, + local_send_tensor: torch.Tensor, + comm_plan: NCCLGraphCommPlan, + ) -> torch.Tensor: + """ + Forward pass for distributed scatter + + Args: + ctx (torch.autograd.FunctionContext): Context object + local_send_tensor (torch.Tensor): Local send tensor + comm_plan (NCCLGraphCommPlan): Communication plan + """ + assert ( + len(local_send_tensor.shape) == 3 + ), "Local send tensor must be of shape (batch_size, num_rows, num_features)" + ctx.comm_plan = comm_plan + + num_features = local_send_tensor.shape[-1] + num_batches = local_send_tensor.shape[0] + + output_tensor = torch.zeros( + num_batches, comm_plan.num_local_vertices, num_features + ).to(local_send_tensor.device) + + output_tensor = OptimizedLocalScatterSumGather( + src=local_send_tensor, + output=output_tensor, + src_indices=comm_plan.local_edge_idx, + dst_indices=comm_plan.local_vertex_idx, + ) + + total_send_rows = sum(comm_plan.boundary_edge_splits) + + send_buf = torch.zeros( + num_batches, total_send_rows, num_features, device=local_send_tensor.device + ) + + send_buf = OptimizedLocalScatterSumGather( + src=local_send_tensor, + output=send_buf, + src_indices=comm_plan.boundary_edge_idx, + dst_indices=comm_plan.boundary_edge_buffer_map, + ) + + total_recv_rows = sum(comm_plan.boundary_vertex_splits) + recv_buffer = torch.empty( + num_batches, total_recv_rows, num_features, device=local_send_tensor.device + ) + dist.all_to_all_single( + recv_buffer, + send_buf, + output_split_sizes=comm_plan.boundary_vertex_splits, + input_split_sizes=comm_plan.boundary_edge_splits, + ) + output_tensor = OptimizedLocalScatterSumGather( + src=recv_buffer, + output=output_tensor, + src_indices=comm_plan.boundary_edge_buffer_map, + dst_indices=comm_plan.boundary_vertex_idx, + ) + + return output_tensor + + @staticmethod + def backward(ctx, grad_output): + """ + Backward pass for distributed scatter + + Args: + ctx (torch.autograd.FunctionContext): Context object + grad_output (torch.Tensor): Gradient of the output tensor + """ + comm_plan = ctx.comm_plan + num_features = grad_output.shape[-1] + num_batches = grad_output.shape[0] + device = grad_output.device + num_output_rows = comm_plan.num_local_edges + + grad_input = torch.zeros( + num_batches, num_output_rows, num_features, device=device + ) + + grad_input = OptimizedLocalScatterGather( + src=grad_output, + src_indices=comm_plan.local_vertex_idx, + 
dst_indices=comm_plan.local_edge_idx, + output=grad_input, + ) + + num_send_rows = sum(comm_plan.boundary_vertex_splits) + send_buf_locs = torch.arange(num_send_rows, device=device) + send_buf = torch.zeros(num_batches, num_send_rows, num_features, device=device) + send_buf = OptimizedLocalScatterGather( + src=grad_output, + src_indices=comm_plan.boundary_vertex_idx, + dst_indices=send_buf_locs, + output=send_buf, + ) + total_recv_rows = sum(comm_plan.boundary_edge_splits) + recv_buffer = torch.empty( + num_batches, total_recv_rows, num_features, device=device + ) + dist.all_to_all_single( + recv_buffer, + send_buf, + output_split_sizes=comm_plan.boundary_edge_splits, + input_split_sizes=comm_plan.boundary_vertex_splits, + ) + + grad_input = OptimizedLocalScatterGather( + src=recv_buffer, + src_indices=comm_plan.boundary_edge_idx, + dst_indices=comm_plan.boundary_edge_buffer_map, + output=grad_input, + ) + + return grad_input, None + + +class GatherFunction(Function): + @staticmethod + def forward( + ctx, + local_send_tensor: torch.Tensor, + indices: torch.LongTensor, + # vertex_ranks: torch.Tensor, + edge_rank_loc: torch.Tensor, + edge_dest_ranks: torch.Tensor, + rank: int, + world_size: int, + ): + num_local_input_rows = local_send_tensor.shape[1] + + ctx.save_for_backward( + indices, + edge_rank_loc, + edge_dest_ranks, + torch.tensor(num_local_input_rows), + torch.tensor(rank), + torch.tensor(world_size), + ) + + # Since NCCL is two-sided, we need to push from local rank and pull from + # remote rank to get the global gather + + # TODO: One possible optmization is cache all these calculations + # and only do the gather when the cache is invalidated. Essentially + # if we are working with static graphs, the indices and distribution pattern + # will not change and we can cache the communication pattern. - S.Z + + # We can also pre-compute this on the data ingestion side. Might + # be worth looking to some kind of cached communication pattern store + # that can be passed to the communicator. 
- S.Z + + batch_size = 1 + num_features = local_send_tensor.shape[2] + + local_slice_mask = edge_rank_loc == rank + + num_local_output_rows = int(local_slice_mask.sum().item()) + + recv_tensor = torch.zeros(batch_size, num_local_output_rows, num_features).to( + local_send_tensor.device + ) + + local_indices_slice = indices[local_slice_mask.unsqueeze(0)] + local_rank_mapping = edge_rank_loc[local_slice_mask] + local_recv_tensor = edge_dest_ranks[local_slice_mask] + + # assert torch.all(local_recv_tensor == rank), local_recv_tensor + + local_indices = local_indices_slice % local_send_tensor.shape[1] + + needs_comm = (local_recv_tensor != rank).any() + + recv_tensor = OptimizedRankLocalMaskedGather( + local_send_tensor, + local_indices, + local_rank_mapping, + recv_tensor, + rank, + ) + + if needs_comm: + + recv_tensor = _nccl_alltoall_v( + local_send_tensor=local_send_tensor, + local_recv_tensor=recv_tensor, + indices=indices, + local_rank_mapping=local_recv_tensor, + edge_rank_loc=edge_rank_loc, + src_rank_loc=edge_dest_ranks, + rank=rank, + world_size=world_size, + cache=cache, + ) + + return recv_tensor + + @staticmethod + def backward(ctx, grad_output): + # We need to switch the send and recv ranks + ( + indices, + recv_ranks, + send_ranks, + # vertices_per_rank, + num_local_input_rows, + rank, + world_size, + ) = ctx.saved_tensors + + num_local_output_rows = num_local_input_rows.item() + rank = rank.item() + world_size = world_size.item() + send_tensor = grad_output + + # Now it's a scatter operation + num_features = send_tensor.shape[-1] + device = send_tensor.device + local_rank_output = torch.zeros(1, num_local_output_rows, num_features).to( + device + ) + + indices = indices.view(-1) + local_slice_mask = recv_ranks == rank + local_indices_slice = indices[local_slice_mask] + local_dest_ranks = send_ranks[local_slice_mask] + + local_rank_output = RankLocalMaskedScatter( + send_tensor, + local_rank_output, + local_indices_slice, + local_dest_ranks, + rank, + ) + + if cache is not None: + local_comm_mask = cache.scatter_local_comm_mask + else: + local_comm_mask = local_dest_ranks != rank + + send_buffer_dict = {} + if torch.any(local_comm_mask): + # These rows need to be sent to other ranks + # First aggregate these into a single buffer + + if cache is not None: + num_remote_rows = cache.scatter_num_remote_rows + remapped_ranks = cache.scatter_local_remapped_ranks + renumbered_indices = cache.scatter_renumbered_indices + receiving_ranks = cache.scatter_remote_send_to_ranks + + else: + + local_comm_indices = local_indices_slice[local_comm_mask] + local_remote_dest_mappings = local_dest_ranks[local_comm_mask] + + renumbered_indices, unique_indices, remapped_ranks = ( + RankLocalRenumberingWithMapping( + local_comm_indices, local_remote_dest_mappings + ) + ) + receiving_ranks = torch.unique(local_dest_ranks[local_comm_mask]) + num_remote_rows = len(unique_indices) + + buffer = torch.zeros(1, num_remote_rows, num_features).to(device) + buffer.scatter_add_( + 1, + renumbered_indices.view(1, -1, 1).expand(1, -1, num_features), + send_tensor[:, local_comm_mask, :], + ) + + for _recv_rank in receiving_ranks: + _recv_indices = remapped_ranks == _recv_rank + send_buffer_dict[_recv_rank.item()] = buffer[:, _recv_indices, :] + + # Now we need to receive the data from the remote ranks + + recv_buffer_dict = {} + + recv_placement = {} + + if cache is not None: + recv_placement = cache.scatter_recv_local_placement + + # Allocate the receive buffers for the communication based on the + # size of the 
recv_placement indices. + for key, unique_send_indices in recv_placement.items(): + num_elements = unique_send_indices.shape[0] + recv_buffer_dict[key] = torch.zeros(1, num_elements, num_features).to( + device + ) + else: + send_to_rank = send_ranks # Pedantic variable name change + all_comm_mask = send_to_rank != recv_ranks + reciever_mask = send_to_rank == rank + receive_from_remote = all_comm_mask & reciever_mask + + if torch.any(receive_from_remote): + receive_from_ranks = recv_ranks[receive_from_remote] + + for _sender in range(world_size): + if _sender == rank: + continue + if torch.any(receive_from_ranks == _sender): + _send_mask = (recv_ranks == _sender) & receive_from_remote + _send_indices = indices[_send_mask] % num_local_output_rows + # TODO: This is brittle, look into a better way to do this - S.Z + + unique_send_indices = torch.unique(_send_indices) + num_elements = unique_send_indices.shape[0] + recv_buffer_dict[_sender] = torch.zeros( + 1, num_elements, num_features + ).cuda() + recv_placement[_sender] = unique_send_indices + + recv_buffer_dict = _nccl_alltoallv_with_dict( + send_buffer_dict, recv_buffer_dict, rank, world_size + ) + for key, recv_buffer in recv_buffer_dict.items(): + local_rank_output.scatter_add_( + 1, + recv_placement[key].view(1, -1, 1).expand(1, -1, num_features), + recv_buffer, + ) + + send_tensor_grad = local_rank_output + indices_grad = None + send_ranks_grad = None + recv_ranks_grad = None + rank_grad = None + world_size_grad = None + cache_grad = None + + return ( + send_tensor_grad, + indices_grad, + send_ranks_grad, + recv_ranks_grad, + rank_grad, + world_size_grad, + cache_grad, + ) + + +class ScatterFunction(Function): + @staticmethod + def forward( + ctx, + send_tensor: torch.Tensor, + indices: torch.Tensor, + edge_src_ranks: torch.Tensor, + edge_dest_ranks: torch.Tensor, + num_local_output_rows: int, + rank: int, + world_size: int, + ) -> torch.Tensor: + + ctx.save_for_backward( + indices, + edge_src_ranks, + edge_dest_ranks, + torch.tensor(num_local_output_rows), + torch.tensor(rank), + torch.tensor(world_size), + ) + use_cache = scatter_cache is not None + if use_cache: + ctx.scatter_cache = scatter_cache + ctx.has_cache = True + else: + ctx.has_cache = False + + num_features = send_tensor.shape[-1] + device = send_tensor.device + + local_rank_output = torch.zeros(1, num_local_output_rows, num_features).to( + device + ) + + indices = indices.view(-1) + + local_edge_mask = edge_src_ranks == rank + + local_indices_slice = indices[local_edge_mask] + local_dest_ranks = edge_dest_ranks[local_edge_mask] + + local_rank_output = RankLocalMaskedScatter( + send_tensor, + local_rank_output, + local_indices_slice, + local_dest_ranks, + rank, + ) + + if use_cache: + local_comm_mask = scatter_cache.scatter_local_comm_mask + else: + local_comm_mask = local_dest_ranks != rank + + all_comm_mask = edge_src_ranks != edge_dest_ranks + reciever_mask = edge_dest_ranks == rank + receive_from_remote_mask = all_comm_mask & reciever_mask + + send_buffer_dict = {} + + if torch.any(local_comm_mask): + + if use_cache: + num_remote_rows = scatter_cache.scatter_num_remote_rows + remapped_ranks = scatter_cache.scatter_local_remapped_ranks + renumbered_indices = scatter_cache.scatter_local_renumbered_indices + receving_ranks = scatter_cache.scatter_remote_send_to_ranks + + else: + # These rows need to be sent to other ranks + # First aggregate these into a single buffer + local_comm_indices = local_indices_slice[local_comm_mask] + local_remote_dest_mappings = 
local_dest_ranks[local_comm_mask] + # TODO: This is very slow, look into a better way to do this - S.Z + # Uncached is slow, should look into augmenting torch functions + # to speed this up - S.Z + renumbered_indices, unique_indices, remapped_ranks = ( + RankLocalRenumberingWithMapping( + local_comm_indices, local_remote_dest_mappings + ) + ) + num_remote_rows = len(unique_indices) + receving_ranks = torch.unique(local_dest_ranks[local_comm_mask]) + + buffer = torch.zeros(1, num_remote_rows, num_features).to(device) + buffer.scatter_add_( + 1, + renumbered_indices.view(1, -1, 1).expand(1, -1, num_features), + send_tensor[:, local_comm_mask, :], + ) + + for _recv_rank in receving_ranks: + _recv_indices = remapped_ranks == _recv_rank + send_buffer_dict[_recv_rank.item()] = buffer[:, _recv_indices, :] + + recv_buffer_dict = {} + recv_placement = {} + if use_cache: + recv_placement = scatter_cache.scatter_recv_local_placement + else: + recv_placement = _get_local_unique_recv_placement( + indices, + edge_src_ranks, + receive_from_remote_mask, + num_local_output_rows, + rank, + world_size, + ) + + # Allocate the receive buffers for the communication based on the + # size of the recv_placement indices. + for key, unique_send_indices in recv_placement.items(): + num_elements = unique_send_indices.shape[0] + recv_buffer_dict[key] = torch.zeros(1, num_elements, num_features).to( + device + ) + recv_buffer_dict = _nccl_alltoallv_with_dict( + send_buffer_dict, recv_buffer_dict, rank, world_size + ) + for key, recv_buffer in recv_buffer_dict.items(): + local_rank_output.scatter_add_( + 1, + recv_placement[key].view(1, -1, 1).expand(1, -1, num_features), + recv_buffer, + ) + return local_rank_output + + @staticmethod + def backward(ctx, grad_output): + # We need to switch the send and recv ranks + indices, recv_ranks, send_ranks, num_input_rows, rank, world_size = ( + ctx.saved_tensors + ) + + local_mask = recv_ranks == rank + if ctx.has_cache: + cache: NCCLScatterCache = ctx.scatter_cache + num_local_output_rows = cache.gather_num_output_rows + + else: + rank = int(rank.item()) + world_size = int(world_size.item()) + + indices = indices.view(1, -1) + + # Now it's a gather operation + + num_local_output_rows = int(local_mask.sum().item()) + + batch_size = 1 + num_features = grad_output.shape[2] + + recv_tensor = torch.zeros(batch_size, num_local_output_rows, num_features).to( + grad_output.device + ) + + local_indices_slice = indices[0][local_mask] + local_rank_mapping = send_ranks[local_mask] + + local_indices = local_indices_slice % grad_output.shape[1] + + if len(local_indices_slice) > 0: + + recv_tensor[:, local_rank_mapping == rank, :] = RankLocalMaskedGather( + grad_output, local_indices, local_rank_mapping, rank + ) + + recv_tensor = _nccl_alltoall_v( + local_send_tensor=grad_output, + local_recv_tensor=recv_tensor, + indices=indices, + local_rank_mapping=local_rank_mapping, + edge_rank_loc=send_ranks, + src_rank_loc=recv_ranks, + rank=rank, + world_size=world_size, + cache=cache, + ) + + # NOTE: even if the inputs are non-tensors, the number of backward outputs + # must be the same as the number of inputs. 
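The NOTE above reflects the general `torch.autograd.Function` contract: `backward` must return exactly one value per argument of `forward`, with `None` in the slots of non-differentiable inputs such as index tensors, ranks, and cache objects. A minimal standalone sketch of that pattern (the class and names here are illustrative, not part of this diff):

```python
import torch
from torch.autograd import Function


class ScaleByRank(Function):
    """Toy Function whose forward takes one tensor and two non-tensor arguments."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
        ctx.scale = float(rank + 1) / float(world_size)
        return x * ctx.scale

    @staticmethod
    def backward(ctx, grad_output):
        # Three forward inputs -> three backward outputs:
        # a gradient for x, and None for the non-tensor rank/world_size.
        return grad_output * ctx.scale, None, None


x = torch.randn(4, requires_grad=True)
y = ScaleByRank.apply(x, 1, 4)
y.sum().backward()
```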
+ send_tensor_grad = recv_tensor + indices_grad = None + send_ranks_grad = None + recv_ranks_grad = None + num_local_output_rows_grad = None + rank_grad = None + world_size_grad = None + scatter_cache_grad = None + + return ( + send_tensor_grad, + indices_grad, + send_ranks_grad, + recv_ranks_grad, + num_local_output_rows_grad, + rank_grad, + world_size_grad, + scatter_cache_grad, + ) diff --git a/DGraph/distributed/nccl/alltoallv_impl.py b/DGraph/distributed/nccl/alltoallv_impl.py index 060c390..d549cb9 100644 --- a/DGraph/distributed/nccl/alltoallv_impl.py +++ b/DGraph/distributed/nccl/alltoallv_impl.py @@ -159,3 +159,23 @@ def _nccl_alltoallv_with_dict(send_buffer_dict, recv_buffer_dict, rank, world_si for key, recv_buffer in recv_buffer_dict.items(): recv_buffer_dict[key] = recv_buffer.float() return recv_buffer_dict + + +def torch_alltoallv_with_comm_map(contiguous_send_tensor: torch.Tensor, + contiguous_recv_tensor: torch.Tensor, + send_comm_map: torch.Tensor, + recv_comm_map: torch.Tensor, + rank: int, + world_size: int): + assert len(send_comm_map) == world_size, "Send comm map should be of size world_size" + assert len(recv_comm_map) == world_size, "Recv comm map should be of size world_size" + + send_sizes = send_comm_map.tolist() + recv_sizes = recv_comm_map.tolist() + + send_list = list(torch.split(contiguous_send_tensor, send_sizes, dim=1)) + recv_list = list(torch.split(contiguous_recv_tensor, recv_sizes, dim=1)) + + dist.all_to_all(recv_list, send_list) + return recv_list + diff --git a/experiments/Benchmarks/README.md b/experiments/Benchmarks/README.md index cabbb68..e60f14c 100644 --- a/experiments/Benchmarks/README.md +++ b/experiments/Benchmarks/README.md @@ -34,7 +34,8 @@ class ScatterGraphData: data_rank_mapping: torch.Tensor # Where each data is located edge_rank_placement: torch.Tensor # Where each edge is located edge_dst_rank: torch.Tensor # Rank of the destination vertex of each edge - edge_indices: torch.Tensor # Vertex index of the destination vertex of each num_local_vertices: int # Number of vertices on each rank + edge_indices: torch.Tensor # Vertex index of the destination vertex of each edge + num_local_vertices: int # Number of vertices on each rank ``` *** New communication patterns can be added to the benchmarking code by creating new instances of these dataclasses. *** diff --git a/experiments/OGB-LSC/CacheGenerator.py b/experiments/OGB-LSC/CacheGenerator.py new file mode 100644 index 0000000..e99ad06 --- /dev/null +++ b/experiments/OGB-LSC/CacheGenerator.py @@ -0,0 +1,184 @@ +# Copyright (c) 2014-2025, Lawrence Livermore National Security, LLC. +# Produced at the Lawrence Livermore National Laboratory. +# Written by the LBANN Research Team (B. Van Essen, et al.) listed in +# the CONTRIBUTORS file. See the top-level LICENSE file for details. +# +# LLNL-CODE-697807. +# All rights reserved. +# +# This file is part of LBANN: Livermore Big Artificial Neural Network +# Toolkit. For details, see http://software.llnl.gov/LBANN or +# https://github.com/LBANN and https://github.com/LLNL/LBANN. 
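The `torch_alltoallv_with_comm_map` helper added to `alltoallv_impl.py` expresses a variable-sized exchange purely through per-rank row counts: entry `r` of `send_comm_map` is the number of rows sent to rank `r`, and entry `r` of `recv_comm_map` is the number of rows expected back from rank `r`. A rough sketch of the same split-then-`all_to_all` mechanics (it assumes a process group is already initialized, e.g. under `torchrun`; the counts in the usage comment are made up):

```python
import torch
import torch.distributed as dist


def alltoallv_rows(send: torch.Tensor, send_counts: list, recv_counts: list) -> torch.Tensor:
    """Exchange rows of a (1, N, F) tensor; send_counts[r] rows go to rank r."""
    assert send.shape[1] == sum(send_counts)
    # With the NCCL backend, `send` must already live on this rank's GPU.
    recv = torch.empty(1, sum(recv_counts), send.shape[-1], device=send.device)

    # Split both contiguous buffers into per-rank chunks along the row dimension;
    # all_to_all routes chunk r of the input to rank r and fills chunk r of the
    # output with whatever rank r sent to us.
    send_list = list(torch.split(send, send_counts, dim=1))
    recv_list = list(torch.split(recv, recv_counts, dim=1))
    dist.all_to_all(recv_list, send_list)
    return recv


# Example on 2 ranks: rank 0 sends 3 rows to itself and 2 rows to rank 1,
# rank 1 sends 4 rows to rank 0 and 1 row to itself.
# rank 0: alltoallv_rows(x0, send_counts=[3, 2], recv_counts=[3, 4])
# rank 1: alltoallv_rows(x1, send_counts=[4, 1], recv_counts=[2, 1])
```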
+# +# SPDX-License-Identifier: (Apache-2.0) + +import torch + +import os.path as osp +from DGraph.distributed.nccl._nccl_cache import ( + NCCLGatherCacheGenerator, + NCCLScatterCacheGenerator, +) + + +def get_cache( + src_gather_cache, + dest_gather_cache, + dest_scatter_cache, + src_gather_cache_file, + dest_gather_cache_file, + dest_scatter_cache_file, + rank, + world_size, + src_indices, + dest_indices, + edge_location, + src_data_mappings, + dest_data_mappings, + num_src_rows, + num_dest_rows, +): + """ """ + if src_gather_cache is None: + + _src_gather_cache = NCCLGatherCacheGenerator( + indices=src_indices, + edge_placement=edge_location, + edge_dest_ranks=src_data_mappings, + num_input_rows=num_src_rows, + rank=rank, + world_size=world_size, + ) + + torch.save(_src_gather_cache, src_gather_cache_file) + else: + _src_gather_cache = src_gather_cache + + if dest_scatter_cache is None: + _dest_scatter_cache = NCCLScatterCacheGenerator( + indices=src_indices, + edge_placement=edge_location, + edge_dest_ranks=src_data_mappings, + num_output_rows=num_src_rows, + rank=rank, + world_size=world_size, + ) + + torch.save(_dest_scatter_cache, dest_scatter_cache_file) + else: + _dest_scatter_cache = dest_scatter_cache + + if dest_gather_cache is None: + _dest_gather_cache = NCCLGatherCacheGenerator( + indices=dest_indices, + edge_placement=edge_location, + edge_dest_ranks=dest_data_mappings, + num_input_rows=num_dest_rows, + rank=rank, + world_size=world_size, + ) + + torch.save(_dest_gather_cache, dest_gather_cache_file) + else: + _dest_gather_cache = dest_gather_cache + + # Unit tests + + return _src_gather_cache, _dest_scatter_cache, _dest_gather_cache + + +if __name__ == "__main__": + from fire import Fire + from functools import partial + from config import SyntheticDatasetConfig + + # Use this script to generate the caches prior to running the main training script + # This is useful because cache generation can take a long time and could cause issues + # with timeouts on some systems. 
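Since `get_cache` writes each generated cache with `torch.save`, a natural split is to run this script once per (rank, world size) combination ahead of time and have the training job only load the files. A small loading sketch under that assumption (the directory layout and file-name pattern mirror the test paths used below and are otherwise hypothetical):

```python
import os.path as osp
import torch


def load_caches(cache_dir: str, rel: int, rank: int, world_size: int):
    """Load pre-generated NCCL caches for one relation if present, else return None."""
    names = {
        "src_gather": f"synthetic_src_gather_cache_{rel}_{rank}_{world_size}.pt",
        "dest_gather": f"synthetic_dest_gather_cache_{rel}_{rank}_{world_size}.pt",
        "dest_scatter": f"synthetic_dest_scatter_cache_{rel}_{rank}_{world_size}.pt",
    }
    caches = {}
    for key, name in names.items():
        path = osp.join(cache_dir, name)
        # weights_only=False because the caches are pickled plan objects, not plain tensors.
        caches[key] = (
            torch.load(path, weights_only=False) if osp.exists(path) else None
        )
    return caches["src_gather"], caches["dest_gather"], caches["dest_scatter"]
```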
+ + def main(dataset): + assert dataset in ["synthetic", "mag240m"] + if dataset == "synthetic": + from synthetic.synthetic_dataset import HeterogeneousDataset as Dataset + + synthetic_config = SyntheticDatasetConfig() + graph_dataset = partial( + Dataset, + num_papers=synthetic_config.num_papers, + num_authors=synthetic_config.num_authors, + num_institutions=synthetic_config.num_institutions, + num_features=synthetic_config.num_features, + num_classes=synthetic_config.num_classes, + ) + elif dataset == "mag240m": + from mag240m.DGraph_MAG240M import DGraph_MAG240M as Dataset + + graph_dataset = partial(Dataset, data_dir="data/MAG240M") + + rank = 0 + world_size = 4 + COMM = type( + "dummy_comm", + (object,), + {"get_rank": lambda self: rank, "get_world_size": lambda self: world_size}, + ) + comm = COMM() + + dataset = graph_dataset( + comm=comm, + ) + + dataset = dataset.add_batch_dimension() + dataset = dataset.to("cpu") + + xs, edge_indices, edge_types, rank_mappings = dataset[0] + + # for simulated_rank in range(world_size): + simulated_rank = 0 + for simulated_rank in [0, 1]: + rel = 0 + + for edge_index, edge_type, rank_mapping in zip( + edge_indices, edge_types, rank_mappings + ): + if rel != 3: + rel += 1 + continue + print(f"Edge index shape: {edge_index.shape}") + print(f"Edge type shape: {edge_type}") + print(f"Rank mapping shape: {rank_mapping[0].shape}") + print(f"Rank mapping shape: {rank_mapping[1].shape}") + + get_cache( + src_gather_cache=None, + dest_gather_cache=None, + dest_scatter_cache=None, + src_gather_cache_file=f"test_cache/synthetic_src_gather_cache_{rel}_{simulated_rank}_{world_size}.pt", + dest_gather_cache_file=f"test_cache/synthetic_dest_gather_cache_{rel}_{simulated_rank}_{world_size}.pt", + dest_scatter_cache_file=f"test_cache/synthetic_dest_scatter_cache_{rel}_{simulated_rank}_{world_size}.pt", + rank=simulated_rank, + world_size=world_size, + src_indices=edge_index[:, 0], + dest_indices=edge_index[:, 1], + edge_location=rank_mapping[0], + src_data_mappings=rank_mapping[0], + dest_data_mappings=rank_mapping[1], + num_src_rows=xs[edge_type[0]].shape[1], + num_dest_rows=xs[edge_type[1]].shape[1], + ) + + rel += 1 + rel = 3 + synthetic_scatter_cache_1 = torch.load( + f"test_cache/synthetic_dest_scatter_cache_{rel}_1_{world_size}.pt", + weights_only=False, + ) + synthetic_scatter_cache_0 = torch.load( + f"test_cache/synthetic_dest_scatter_cache_{rel}_0_{world_size}.pt", + weights_only=False, + ) + + print(synthetic_scatter_cache_1.scatter_recv_local_placement) + print(synthetic_scatter_cache_0.scatter_recv_local_placement) + + Fire(main) diff --git a/experiments/OGB-LSC/README.md b/experiments/OGB-LSC/README.md new file mode 100644 index 0000000..0d6ac53 --- /dev/null +++ b/experiments/OGB-LSC/README.md @@ -0,0 +1,45 @@ +# Directed Heterogeneous Graphs on DGraph + +`DGraph` supports arbitrary graph types, GNNs, and structures for distributed training. This example shows how to use `DGraph` to train a Relational Graph Attention Network ([RGAT](https://arxiv.org/abs/1703.06103)) on the [OGB-LSC MAG240M](https://ogb.stanford.edu/docs/lsc/mag240m/) dataset, which is a large-scale heterogeneous graph with three types of nodes (paper, author, institution) and three types of edges (paper->paper, paper->author, author->institution). 
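Both the cache-generation script above and the training code below consume the heterogeneous graph as parallel lists: one feature tensor per node type, and, for each relation, a COO edge index plus the (source node type, destination node type) pair. A schematic of that layout with toy sizes (illustrative only; the dataset classes additionally prepend a batch dimension of 1 via `add_batch_dimension`):

```python
import torch

num_features = 16

# One feature tensor per node type: papers, authors, institutions.
xs = [
    torch.randn(2048, num_features),  # paper features
    torch.randn(512, num_features),   # author features
    torch.randn(16, num_features),    # institution features
]

# One COO edge index per relation: row 0 = source vertices, row 1 = destinations.
edge_indices = [
    torch.randint(0, 2048, (2, 4096)),            # paper -> paper (cites)
    torch.stack([torch.randint(0, 512, (4096,)),  # author -> paper (writes)
                 torch.randint(0, 2048, (4096,))]),
    torch.stack([torch.randint(0, 512, (1024,)),  # author -> institution
                 torch.randint(0, 16, (1024,))]),
]

# (source node-type index, destination node-type index) for each relation,
# used to pick the right entries of `xs` for a given edge type.
edge_types = [(0, 0), (1, 0), (1, 2)]
```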
+
+## Requirements
+
+Make sure you have the following packages installed:
+- `torch`
+- `torch_geometric`
+- `ogb`
+- `torch_sparse`
+- `numpy`
+- `tqdm`
+- `fire`
+
+## Preprocessing the dataset
+The MAG240M dataset is fairly large and requires some preprocessing before it can be used with DGraph; processing takes a while. The following script processes the dataset and saves the processed data in a directory:
+
+```bash
+torchrun-hpc -N <nodes> -n <procs> setup_dataset_comms.py --comm_type nccl --dataset mag240m --data_dir <data_dir>
+```
+
+Make sure to replace `<data_dir>` with the path where you want to store the processed data. The script will download the dataset if it is not already present in the specified directory, and the processed data will be saved in the same directory.
+
+The processing machine requires at least `128GB` of RAM to process the dataset.
+
+## Data preparation
+The dataset is fairly large (over 100GB). Please follow the instructions in the `mag240m` folder to download and preprocess the dataset.
+
+## Training
+To train RGAT on the synthetic dataset or on MAG240M, run one of the following commands:
+
+```bash
+# Synthetic dataset
+torchrun-hpc -N <nodes> -n <procs> main.py \
+    --dataset synthetic --num_papers <num_papers> \
+    --num_authors <num_authors> --num_institutions <num_institutions>
+
+# MAG240M
+torchrun-hpc -N <nodes> -n <procs> main.py --dataset mag240m \
+    --data_dir <data_dir>
+```
diff --git a/experiments/OGB-LSC/RGAT.py b/experiments/OGB-LSC/RGAT.py
new file mode 100644
index 0000000..ebea143
--- /dev/null
+++ b/experiments/OGB-LSC/RGAT.py
@@ -0,0 +1,362 @@
+# Copyright (c) 2014-2025, Lawrence Livermore National Security, LLC.
+# Produced at the Lawrence Livermore National Laboratory.
+# Written by the LBANN Research Team (B. Van Essen, et al.) listed in
+# the CONTRIBUTORS file. See the top-level LICENSE file for details.
+#
+# LLNL-CODE-697807.
+# All rights reserved.
+#
+# This file is part of LBANN: Livermore Big Artificial Neural Network
+# Toolkit. For details, see http://software.llnl.gov/LBANN or
+# https://github.com/LBANN and https://github.com/LLNL/LBANN.
+#
+# SPDX-License-Identifier: (Apache-2.0)
+
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+from distributed_layers import DistributedBatchNorm1D
+import os.path as osp
+from CacheGenerator import get_cache
+import os
+from typing import Any, List, Optional, overload
+from DGraph.distributed.nccl import (
+    NCCLBackendEngine,
+    NCCLGraphCommPlan,
+    NCCLEdgeConditionedGraphCommPlan,
+)
+
+
+class ConvLayer(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(ConvLayer, self).__init__()
+        self.conv = nn.Linear(in_channels, out_channels)
+        self.act = nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.act(x)
+        return x
+
+
+class CommAwareGAT(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        comm,
+        heads=1,
+        bias=True,
+        residual=False,
+        hetero=False,
+    ):
+        super(CommAwareGAT, self).__init__()
+        self.conv1 = nn.Linear(in_channels, out_channels, bias=False)
+        self.comm = comm
+        self.project_message = nn.Linear(2 * out_channels, 1)
+        self.leaky_relu = nn.LeakyReLU(0.2)
+        self.residual = residual
+        self.heads = heads
+        self.hetero = hetero
+        if self.residual:
+            self.res_net = nn.Linear(in_channels, out_channels, bias=False)
+        if bias:
+            self.bias = nn.Parameter(torch.empty(out_channels))
+            nn.init.zeros_(self.bias)
+        else:
+            self.register_parameter("bias", None)
+
+    @overload
+    def forward(
+        self,
+        x: torch.Tensor,
+        comm_plan: NCCLEdgeConditionedGraphCommPlan,
+        *,
+        x_j: Optional[torch.Tensor] = None,
+    ): ...
+ + @overload + def forward( + self, + x: torch.Tensor, + *, + edge_index: Any, + rank_mapping: Any, + x_j: Optional[torch.Tensor] = None, + src_gather_cache: Optional[Any] = None, + dest_gather_cache: Optional[Any] = None, + dest_scatter_cache: Optional[Any] = None, + ): ... + + def forward( + self, + x, + comm_plan=None, + *, + edge_index=None, + rank_mapping=None, + x_j=None, + src_gather_cache=None, + dest_gather_cache=None, + dest_scatter_cache=None, + ): + """Forward method that can use either a communication plan or COO format + + Args: + x: Node features tensor + comm_plan: Communication plan object (if available) + edge_index: Edge index tensor in COO format + rank_mapping: Rank mapping tensors + x_j: Optional source node features tensor (for hetero graphs) + src_gather_cache: Optional cache for source gather communication + dest_gather_cache: Optional cache for destination gather communication + dest_scatter_cache: Optional cache for destination scatter communication + + Returns: + out: Output node features tensor + """ + if comm_plan is not None: + return self._forward_comm_plan(x, comm_plan, x_j=x_j) + + return self._forward_coo( + x, + edge_index=edge_index, + rank_mapping=rank_mapping, + x_j=x_j, + src_gather_cache=src_gather_cache, + dest_gather_cache=dest_gather_cache, + dest_scatter_cache=dest_scatter_cache, + ) + + def _process_messages( + self, + h, + h_j, + ): + messages = torch.cat([h, h_j], dim=-1) + edge_scores = self.leaky_relu(self.project_message(messages)) + numerator = torch.exp(edge_scores) + return numerator + + def _calc_attention_messages( + self, + neighbor_features, + numerator, + denominator, + ): + alpha_ij = numerator / (denominator + 1e-16) + attention_messages = neighbor_features * alpha_ij + return attention_messages + + def _apply_res_and_bias(self, out, x): + if self.residual: + out = out + self.res_net(x) + if self.bias is not None: + out = out + self.bias + return out + + def _forward_comm_plan( + self, x, comm_plan: NCCLEdgeConditionedGraphCommPlan, x_j=None + ): + h = self.conv1(x) + + source_graph_plan = comm_plan.source_graph_plan + if self.hetero: + assert x_j is not None + h_j = self.conv1(x_j) + assert comm_plan.dest_graph_plan is not None + dest_graph_plan = comm_plan.dest_graph_plan + else: + h_j = h + dest_graph_plan = source_graph_plan + + assert isinstance(self.comm.__backend_engine, NCCLBackendEngine) + + h_i = self.comm.__backend_engine.gather(h, comm_plan=source_graph_plan) + + h_j = self.comm.__backend_engine.gather(h_j, comm_plan=dest_graph_plan) + + numerator = self._process_messages(h_i, h_j) + + denominator = self.comm.__backend_engine.scatter( + numerator, comm_plan=source_graph_plan + ) + + denominator = self.comm.__backend_engine.gather( + denominator, comm_plan=dest_graph_plan + ) + + attention_messages = self._calc_attention_messages(h_j, numerator, denominator) + + out = self.comm.__backend_engine.scatter( + attention_messages, comm_plan=source_graph_plan + ) + out = self._apply_res_and_bias(out, x) + + return out + + def _forward_coo( + self, + x, + edge_index, + rank_mapping, + x_j=None, + src_gather_cache=None, + dest_gather_cache=None, + dest_scatter_cache=None, + ): + h = self.conv1(x) + if self.hetero: + assert x_j is not None + h_j = self.conv1(x_j) + else: + h_j = h + + _src_indices = edge_index[:, 0, :] + _dst_indices = edge_index[:, 1, :] + _src_rank_mappings = torch.cat( + [rank_mapping[0].unsqueeze(0), rank_mapping[0].unsqueeze(0)], dim=0 + ) + _dst_rank_mappings = torch.cat( + [rank_mapping[0].unsqueeze(0), 
rank_mapping[1].unsqueeze(0)], dim=0 + ) + + h_i = self.comm.gather( + h, _dst_indices, _dst_rank_mappings, cache=dest_gather_cache + ) + + h_j = self.comm.gather( + h_j, _src_indices, _src_rank_mappings, cache=src_gather_cache + ) + + numerator = self._process_messages(h_i, h_j) + + denominator = self.comm.scatter( + numerator, + _dst_indices, + _dst_rank_mappings, + h.size(1), + cache=dest_scatter_cache, + ) + + denominator = self.comm.gather( + denominator, _src_indices, _src_rank_mappings, cache=dest_gather_cache + ) + + attention_messages = self._calc_attention_messages(h_j, numerator, denominator) + + out = self.comm.scatter( + attention_messages, + _dst_indices, + _dst_rank_mappings, + h.size(1), + cache=dest_scatter_cache, + ) + + out = self._apply_res_and_bias(out, x) + + return out + + +class CommAwareRGAT(nn.Module): + def __init__( + self, + in_channels, + out_channels, + hidden_channels, + num_relations, + num_layers, + heads, + comm, + dropout=0.5, + ): + super(CommAwareRGAT, self).__init__() + self.layers = nn.ModuleList() + self.bn_layers = nn.ModuleList() + self.skip_layers = nn.ModuleList() + self.num_layers = num_layers + self.dropout = dropout + self.comm = comm + relation_specific_convs = [] + + for _ in range(num_relations): + relation_specific_convs.append( + CommAwareGAT( + in_channels, + hidden_channels, + heads=heads, + bias=True, + residual=True, + comm=comm, + hetero=True, + ) + ) + self.layers.append(nn.ModuleList(relation_specific_convs)) + + for _ in range(num_layers - 1): + relation_specific_convs = [] + for _ in range(num_relations): + relation_specific_convs.append( + CommAwareGAT( + hidden_channels, + hidden_channels, + heads=heads, + bias=True, + residual=True, + comm=comm, + hetero=True, + ) + ) + self.layers.append(nn.ModuleList(relation_specific_convs)) + + for _ in range(num_layers): + self.bn_layers.append(DistributedBatchNorm1D(hidden_channels)) + + self.skip_layers.append(nn.Linear(in_channels, hidden_channels)) + for _ in range(num_layers - 1): + self.skip_layers.append(nn.Linear(hidden_channels, hidden_channels)) + + self.mlp = nn.Sequential( + nn.Linear(hidden_channels, hidden_channels), + DistributedBatchNorm1D(hidden_channels), + nn.ReLU(inplace=True), + nn.Dropout(dropout), + nn.Linear(hidden_channels, out_channels), + ) + self.num_relations = num_relations + + def forward(self, xs, edge_types, comm_plans: List[NCCLGraphCommPlan]): + + assert len(edge_types) == len(comm_plans) + outs = xs + + for i in range(self.num_layers): + temp_outs = [self.skip_layers[i](outs[feat]) for feat in range(len(outs))] + + for j, (edge_type, comm_plan) in enumerate(zip(edge_types, comm_plans)): + + src_edge_type, dst_edge_type = edge_type + + temp_outs[dst_edge_type] += self.layers[i][j]( # type: ignore + outs[dst_edge_type], x_j=outs[src_edge_type], comm_plan=comm_plan + ) + outs = [ + self.bn_layers[i](temp_outs[feat]) for feat in range(len(temp_outs)) + ] + outs = [torch.relu(outs[feat]) for feat in range(len(outs))] + outs = [ + torch.dropout(outs[feat], p=self.dropout, train=self.training) + for feat in range(len(outs)) + ] + + dummy_prameters_use = bool(int(os.getenv("RGAT_DUMMY_ALL_PARAMS_USE", "0"))) + if dummy_prameters_use: + # Dummy operation to touch all outs to avoid DDP's 'unused parameters' + dummy = torch.zeros(1, device=outs[0].device, dtype=outs[0].dtype) + for t in outs: + dummy = dummy + ( + t[0].sum() * 0.0 + ) # zero-valued scalar that depends on t + outs[0][0] = outs[0][0] + dummy + + return self.mlp(outs[0]) diff --git 
a/experiments/OGB-LSC/Trainer.py b/experiments/OGB-LSC/Trainer.py new file mode 100644 index 0000000..97bc66c --- /dev/null +++ b/experiments/OGB-LSC/Trainer.py @@ -0,0 +1,134 @@ +# Copyright (c) 2014-2025, Lawrence Livermore National Security, LLC. +# Produced at the Lawrence Livermore National Laboratory. +# Written by the LBANN Research Team (B. Van Essen, et al.) listed in +# the CONTRIBUTORS file. See the top-level LICENSE file for details. +# +# LLNL-CODE-697807. +# All rights reserved. +# +# This file is part of LBANN: Livermore Big Artificial Neural Network +# Toolkit. For details, see http://software.llnl.gov/LBANN or +# https://github.com/LBANN and https://github.com/LLNL/LBANN. +# +# SPDX-License-Identifier: (Apache-2.0) +import torch +from RGAT import CommAwareRGAT +from config import ModelConfig, TrainingConfig +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from distributed_layers import GetGlobalVal +from lsc_datasets import DistributedHeteroGraphDataset + +import os +from typing import Union + + +class Trainer: + def __init__(self, dataset, comm): + self.dataset: DistributedHeteroGraphDataset = dataset + self.comm = comm + self.model_config = ModelConfig() + self.training_config = TrainingConfig() + # TODO: We need some better way to set the device but + # difficult to do that since systems have different bindings. + # self.device = torch.device(f"cuda:{comm.get_local_rank()}") + rank = comm.get_rank() + print(f"Rank {rank} using GPU {rank % torch.cuda.device_count()}") + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + self.device = torch.device("cuda") + self.model = CommAwareRGAT( + in_channels=self.dataset.num_features, + out_channels=self.dataset.num_classes, + num_relations=self.dataset.num_relations, + hidden_channels=self.model_config.hidden_channels, + num_layers=self.model_config.num_layers, + heads=self.model_config.heads, + comm=comm, + dropout=self.model_config.dropout, + ).to(self.device) + # Enable unused-parameter detection only if requested (reduces sync errors with moderate overhead) + ddp_find_unused = bool(int(os.getenv("RGAT_DDP_FIND_UNUSED", "0"))) + self.model = DDP( + self.model, + device_ids=[rank % num_gpus], + find_unused_parameters=ddp_find_unused, + ) + self.optimizer = torch.optim.Adam( + self.model.parameters(), lr=self.training_config.lr, weight_decay=5e-4 + ) + + def prepare_data(self): + self.dataset = self.dataset.add_batch_dimension() + self.dataset = self.dataset.to(self.device) + + def train(self): + self.model.train() + + xs, _, _, _ = self.dataset[0] + comm_plans = self.dataset.get_NCCL_comm_plans() + + # Fetch once; masks/targets are static across epochs + train_mask = self.dataset.get_mask("train") + target = self.dataset.get_target("train") + + for epoch in range(1, self.training_config.epochs + 1): + # zero grads before forward to avoid dangling reduction state + self.optimizer.zero_grad(set_to_none=True) + + out = self.model(xs, comm_plans) + local_train_vertices = out[:, train_mask, :].squeeze(0) + + loss = torch.nn.functional.cross_entropy( + local_train_vertices, target, reduction="sum" + ) + local_num_targets = target.size(0) + global_num_targets = GetGlobalVal(local_num_targets) + loss = loss / global_num_targets # Average the loss + + loss.backward() + self.optimizer.step() + if self.comm.get_rank() == 0: + print(f"Epoch {epoch:03d} | loss {loss.item():.4f}") + return loss.item() + + @torch.no_grad() + def evaluate(self): + self.model.eval() + + 
xs, edge_index, edge_type, rank_mapping = self.dataset[0] + out = self.model(xs, edge_index, edge_type, rank_mapping) + + y_pred = out.argmax(dim=-1, keepdim=True).cpu().numpy() + train_mask = self.dataset.get_mask("train").cpu().numpy() + val_mask = self.dataset.get_mask("val").cpu().numpy() + test_mask = self.dataset.get_mask("test").cpu().numpy() + y_true_train = self.dataset.get_target("train").cpu().numpy() + y_pred_val = self.dataset.get_target("val").cpu().numpy() + y_pred_test = self.dataset.get_target("test").cpu().numpy() + + train_acc = (y_pred[train_mask] == y_true_train).sum() / int(train_mask.sum()) + # Not guaranteed to have validation or test samples on every rank + num_local_val_samples = int(val_mask.sum()) + num_local_test_samples = int(test_mask.sum()) + if num_local_val_samples == 0: + val_acc = 0.0 + else: + val_acc = (y_pred[val_mask] == y_pred_val).sum().item() + val_acc = GetGlobalVal(val_acc) + + num_global_val_samples = GetGlobalVal(num_local_val_samples) + val_acc = val_acc / int(num_global_val_samples) + + if num_local_test_samples == 0: + test_acc = 0.0 + else: + test_acc = (y_pred[test_mask] == y_pred_test).sum().item() + + test_acc = GetGlobalVal(test_acc) + num_global_test_samples = GetGlobalVal(num_local_test_samples) + test_acc = test_acc / int(num_global_test_samples) + + # All ranks should have the same accuracy values + + return train_acc, val_acc, test_acc diff --git a/experiments/OGB-LSC/config.py b/experiments/OGB-LSC/config.py new file mode 100644 index 0000000..49d8164 --- /dev/null +++ b/experiments/OGB-LSC/config.py @@ -0,0 +1,45 @@ +# Copyright (c) 2014-2025, Lawrence Livermore National Security, LLC. +# Produced at the Lawrence Livermore National Laboratory. +# Written by the LBANN Research Team (B. Van Essen, et al.) listed in +# the CONTRIBUTORS file. See the top-level LICENSE file for details. +# +# LLNL-CODE-697807. +# All rights reserved. +# +# This file is part of LBANN: Livermore Big Artificial Neural Network +# Toolkit. For details, see http://software.llnl.gov/LBANN or +# https://github.com/LBANN and https://github.com/LLNL/LBANN. +# +# SPDX-License-Identifier: (Apache-2.0) + +from dataclasses import dataclass + + +@dataclass +class ModelConfig: + hidden_channels: int = 1024 + dropout: float = 0.5 + num_layers: int = 2 + heads: int = 4 + use_cache: bool = True + # Those numbers are available in the dataset classes (synthetic or mag240m) + # num_features: int = 768 + # num_relations: int = 5 + # num_classes: int = 153 + + +@dataclass +class TrainingConfig: + epochs: int = 100 + lr: float = 0.0001 + lr_step_size: int = 25 + lr_gamma: float = 0.25 + + +@dataclass +class SyntheticDatasetConfig: + num_papers: int = 2048 + num_authors: int = 512 + num_institutions: int = 16 + num_features: int = 16 + num_classes: int = 153 diff --git a/experiments/OGB-LSC/distributed_layers.py b/experiments/OGB-LSC/distributed_layers.py new file mode 100644 index 0000000..54408b1 --- /dev/null +++ b/experiments/OGB-LSC/distributed_layers.py @@ -0,0 +1,214 @@ +# Copyright (c) 2014-2025, Lawrence Livermore National Security, LLC. +# Produced at the Lawrence Livermore National Laboratory. +# Written by the LBANN Research Team (B. Van Essen, et al.) listed in +# the CONTRIBUTORS file. See the top-level LICENSE file for details. +# +# LLNL-CODE-697807. +# All rights reserved. +# +# This file is part of LBANN: Livermore Big Artificial Neural Network +# Toolkit. 
For details, see http://software.llnl.gov/LBANN or +# https://github.com/LBANN and https://github.com/LLNL/LBANN. +# +# SPDX-License-Identifier: (Apache-2.0) + +import torch +from torch import nn +import torch.distributed as dist +from torch.autograd import Function +from typing import Callable + + +def _compute_bn_forward(input, learned_gamma=None, learned_beta=None): + local_sum = torch.mean(input, dim=0) + global_sum = local_sum.clone() + num_rows = torch.tensor([input.size(0)], dtype=torch.float32, device=input.device) + + global_num_rows = num_rows.clone() + + dist.all_reduce(global_num_rows, op=dist.ReduceOp.SUM) + global_mean = global_sum / global_num_rows + local_var = ((input - global_mean) ** 2).sum(dim=0) + global_var = local_var.clone() + dist.all_reduce(global_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(global_var, op=dist.ReduceOp.SUM) + global_var = global_var / global_num_rows + + x_hat = (input - global_mean) / torch.sqrt(global_var + 1e-5) + if learned_gamma is not None and learned_beta is not None: + output = x_hat * learned_gamma + learned_beta + + return output, x_hat, global_mean, global_var, global_num_rows + + +def _compute_bn_backward( + grad_output, x, x_hat, mean, var, num_rows, learned_gamma=None, learned_beta=None +): + if learned_gamma is not None and learned_beta is not None: + local_dbeta = torch.sum(grad_output, dim=0) + global_dbeta = local_dbeta.clone().unsqueeze(0) + dist.all_reduce(global_dbeta, op=dist.ReduceOp.SUM) + local_dgamma = torch.sum(grad_output * x_hat, dim=0) + global_dgamma = local_dgamma.clone().unsqueeze(0) + dist.all_reduce(global_dgamma, op=dist.ReduceOp.SUM) + dx_hat = grad_output * learned_gamma + else: + dx_hat = grad_output + global_dgamma = None + global_dbeta = None + + local_dvar = torch.sum(dx_hat * (x - mean) * -0.5 * (var + 1e-5) ** 2, dim=0) + global_dvar = local_dvar.clone() + dist.all_reduce(global_dvar, op=dist.ReduceOp.SUM) + + local_dmean = torch.sum( + dx_hat * -1 / torch.sqrt(var + 1e-5), dim=0 + ) + global_dvar * torch.mean(-2 * (x - mean), dim=0) + global_dmean = local_dmean.clone() + dist.all_reduce(global_dmean, op=dist.ReduceOp.SUM) + dx = ( + (dx_hat / torch.sqrt(var + 1e-5)) + + (global_dvar * 2 * (x - mean) / num_rows) + + (global_dmean / num_rows) + ) + return dx, global_dgamma, global_dbeta + + +class DistributedBN_with_Recompute(Function): + @staticmethod + def forward(ctx, input, learned_gamma=None, learned_beta=None): + ctx.save_for_backward(input) + ctx.learned_gamma = learned_gamma + ctx.learned_beta = learned_beta + output, _, global_mean, global_var, global_num_rows = _compute_bn_forward( + input, learned_gamma, learned_beta + ) + ctx.mean = global_mean + ctx.var = global_var + ctx.input = input + ctx.num_rows = global_num_rows + return output, global_mean, global_var + + @staticmethod + def backward(ctx, grad_output, grad_mean, grad_var): + x = ctx.input + mean = ctx.mean + var = ctx.var + # recompute x_hat to save memory + x_hat = (x - mean) / torch.sqrt(var + 1e-5) + learned_gamma = ctx.learned_gamma + learned_beta = ctx.learned_beta + num_rows = ctx.num_rows + + dx, global_dgamma, global_dbeta = _compute_bn_backward( + grad_output, x, x_hat, mean, var, num_rows, learned_gamma, learned_beta + ) + + return dx, global_dgamma, global_dbeta + + +class DistributedBN_Impl(Function): + @staticmethod + def forward(ctx, input, learned_gamma=None, learned_beta=None): + output, x_hat, global_mean, global_var, global_num_rows = _compute_bn_forward( + input, learned_gamma, learned_beta + ) + + 
ctx.save_for_backward(x_hat) + ctx.learned_gamma = learned_gamma + ctx.learned_beta = learned_beta + ctx.mean = global_mean + ctx.var = global_var + ctx.num_rows = global_num_rows + ctx.input = input + ctx.x_hat = x_hat + return output, global_mean, global_var + + @staticmethod + def backward(ctx, grad_output, grad_mean, grad_var): + + learned_gamma = ctx.learned_gamma + learned_beta = ctx.learned_beta + mean = ctx.mean + var = ctx.var + x_hat = ctx.x_hat + num_rows = ctx.num_rows + x = ctx.input + dx, global_dgamma, global_dbeta = _compute_bn_backward( + grad_output, x, x_hat, mean, var, num_rows, learned_gamma, learned_beta + ) + + return dx, global_dgamma, global_dbeta + + +class DistributedBatchNorm1D(nn.Module): + def __init__( + self, + num_features, + eps=1e-5, + momentum=0.1, + affine=True, + track_running_stats=True, + recompute=False, + ): + super(DistributedBatchNorm1D, self).__init__() + if affine: + self.gamma = nn.Parameter(torch.ones(1, num_features)) + self.beta = nn.Parameter(torch.zeros(1, num_features)) + else: + self.register_parameter("gamma", None) + self.register_parameter("beta", None) + self.eps = eps + self.momentum = momentum + self.track_running_stats = track_running_stats + if self.track_running_stats: + self.register_buffer("running_mean", torch.zeros(1, num_features)) + self.register_buffer("running_var", torch.ones(1, num_features)) + self.register_buffer( + "num_batches_tracked", torch.tensor(0, dtype=torch.long) + ) + else: + self.register_parameter("running_mean", None) + self.register_parameter("running_var", None) + self.register_parameter("num_batches_tracked", None) + self.recompute = recompute + if recompute: + self.bn: Callable = DistributedBN_with_Recompute.apply + else: + self.bn: Callable = DistributedBN_Impl.apply + + def forward(self, x): + if x.dim() == 3: + assert x.size(0) == 1, "only mini-batch size 1 is supported" + x = x.squeeze(0) + elif x.dim() != 2: + raise ValueError("Expected 2D or 3D input (got {}D input)".format(x.dim())) + + if self.training: + if self.track_running_stats: + self.num_batches_tracked += 1 + y, mean, var = self.bn(x, self.gamma, self.beta) + + if self.track_running_stats: + with torch.no_grad(): + self.running_mean = ( + 1 - self.momentum + ) * self.running_mean + self.momentum * mean + self.running_var = ( + 1 - self.momentum + ) * self.running_var + self.momentum * var + else: + y = (x - self.running_mean) / torch.sqrt(self.running_var + self.eps) + if self.gamma is not None and self.beta is not None: + y = y * self.gamma + self.beta + + if y.dim() == 2: + y = y.unsqueeze(0) + return y + + +def GetGlobalVal(local_val): + """Get the global sum of a local value across all ranks.""" + global_val = torch.tensor([local_val]).cuda() + dist.all_reduce(global_val, op=dist.ReduceOp.SUM) + return global_val.item() diff --git a/experiments/OGB-LSC/lsc_datasets/MAG240M_dataset.py b/experiments/OGB-LSC/lsc_datasets/MAG240M_dataset.py new file mode 100644 index 0000000..2584bf3 --- /dev/null +++ b/experiments/OGB-LSC/lsc_datasets/MAG240M_dataset.py @@ -0,0 +1,331 @@ +# Copyright (c) 2014-2025, Lawrence Livermore National Security, LLC. +# Produced at the Lawrence Livermore National Laboratory. +# Written by the LBANN Research Team (B. Van Essen, et al.) listed in +# the CONTRIBUTORS file. See the top-level LICENSE file for details. +# +# LLNL-CODE-697807. +# All rights reserved. +# +# This file is part of LBANN: Livermore Big Artificial Neural Network +# Toolkit. 
For details, see http://software.llnl.gov/LBANN or +# https://github.com/LBANN and https://github.com/LLNL/LBANN. +# +# SPDX-License-Identifier: (Apache-2.0) +from ogb.lsc import MAG240MDataset +import torch +from typing import Optional, Tuple +from torch_sparse import SparseTensor +import numpy as np +from tqdm import tqdm +import os.path as osp +from DGraph.Communicator import Communicator +from DGraph.distributed.nccl import NCCLGraphCommPlan, COO_to_NCCLCommPlan +from distributed_graph_dataset import ( + get_rank_mappings, + get_vertex_offsets, + DistributedHeteroGraphDataset, +) + + +def get_col_slice(x, start_row_idx, end_row_idx, start_col_idx, end_col_idx): + """Obtained from: + https://github.com/snap-stanford/ogb/blob/master/examples/lsc/mag240m/rgnn.py + """ + outs = [] + chunk = 100000 + for i in tqdm(range(start_row_idx, end_row_idx, chunk)): + j = min(i + chunk, end_row_idx) + outs.append(x[i:j, start_col_idx:end_col_idx].copy()) + return np.concatenate(outs, axis=0) + + +def save_col_slice( + x_src, x_dst, start_row_idx, end_row_idx, start_col_idx, end_col_idx +): + """Obtained from: + https://github.com/snap-stanford/ogb/blob/master/examples/lsc/mag240m/rgnn.py + """ + assert x_src.shape[0] == end_row_idx - start_row_idx + assert x_src.shape[1] == end_col_idx - start_col_idx + chunk, offset = 100000, start_row_idx + for i in tqdm(range(0, end_row_idx - start_row_idx, chunk)): + j = min(i + chunk, end_row_idx - start_row_idx) + x_dst[offset + i : offset + j, start_col_idx:end_col_idx] = x_src[i:j] + + +def get_edge_mappings(src_indices, dst_indices, rank_mappings): + edge_mappings = torch.zeros_like(src_indices) + # The edges are mapped to the rank of the destination node + # Because that is the accumulation rank + edge_mappings = rank_mappings[dst_indices] + return edge_mappings + + +def _generate_features_from_paper_features( + out: np.memmap, + num_nodes: int, + num_papers: int, + paper_feat: np.ndarray, + edge_index: np.ndarray, + num_features: int, +): + + row, col = torch.from_numpy(edge_index) + adj = SparseTensor( + row=row, col=col, sparse_sizes=(num_nodes, num_papers), is_sorted=True + ) + + dim_chunk_size = 64 + + for i in tqdm(range(0, num_features, dim_chunk_size)): + j = min(i + dim_chunk_size, num_features) + inputs = get_col_slice( + paper_feat, + start_row_idx=0, + end_row_idx=num_papers, + start_col_idx=i, + end_col_idx=j, + ) + inputs = torch.from_numpy(inputs) + out_ = adj.matmul(inputs, reduce="mean").numpy() # type: ignore + del inputs + save_col_slice( + x_src=out_, + x_dst=out, + start_row_idx=0, + end_row_idx=num_nodes, + start_col_idx=i, + end_col_idx=j, + ) + del out_ + out.flush() + + +def load_or_generate_vertex_rank_mask( + rank_mapping: Optional[torch.Tensor], num_vertices: int, world_size: int, rank: int +): + if rank_mapping is None: + rank_mapping, vertices_cur_rank = get_rank_mappings( + num_vertices, world_size, rank + ) + rank_mask = rank_mapping == rank + return rank_mask, vertices_cur_rank + + +class DGraph_MAG240M_Dataset(DistributedHeteroGraphDataset): + + # data_dir must be the location where all ranks can access + def __init__( + self, + comm: Communicator, + data_dir: str = "lsc_datasets/data/MAG240M", + comm_plan_only: bool = True, + paper_rank_mappings: Optional[torch.Tensor] = None, + author_rank_mappings: Optional[torch.Tensor] = None, + institution_rank_mappings: Optional[torch.Tensor] = None, + cached_comm_plans: Optional[str] = None, + ): + rank = comm.get_rank() + world_size = comm.get_world_size() + self.comm = comm + 
self.dataset = MAG240MDataset(root=data_dir) + self.num_papers = self.dataset.num_papers + self.num_authors = self.dataset.num_authors + self.num_institutions = self.dataset.num_institutions + + self.train_mask = self.dataset.get_idx_split("train") + self.val_mask = self.dataset.get_idx_split("valid") + self.test_mask = self.dataset.get_idx_split("test-dev") + + local_papers_mask, num_local_papers = load_or_generate_vertex_rank_mask( + paper_rank_mappings, self.num_papers, world_size, rank + ) + + local_authors_mask, num_local_authors = load_or_generate_vertex_rank_mask( + author_rank_mappings, self.num_authors, world_size, rank + ) + + local_institutions_mask, num_local_institutions = ( + load_or_generate_vertex_rank_mask( + institution_rank_mappings, self.num_institutions, world_size, rank + ) + ) + + self.num_local_papers = num_local_papers + self.num_local_authors = num_local_authors + self.num_local_institutions = num_local_institutions + + self.generate_feature_data() + + paper_features = torch.from_numpy(self.dataset.paper_feat[local_papers_mask]) + + path = self.dataset.dir + + author_features = torch.from_numpy( + np.memmap( + filename=path + "/author_feat.npy", + mode="r", + dtype=np.float16, + shape=(self.num_authors, self.num_features), + )[local_authors_mask] + ) + institution_features = torch.from_numpy( + np.memmap( + filename=path + "/institution_feat.npy", + mode="r", + dtype=np.float16, + shape=(self.num_institutions, self.num_features), + )[local_institutions_mask] + ) + labels = torch.from_numpy(self.dataset.paper_label) + + paper_2_paper_edges = torch.from_numpy( + self.dataset.edge_index("paper", "cites", "paper") + ) + author_2_paper_edges = torch.from_numpy( + self.dataset.edge_index("author", "writes", "paper") + ) + author_2_institution_edges = torch.from_numpy( + self.dataset.edge_index("author", "institution") + ) + + num_features = self.dataset.num_paper_features + num_classes = self.dataset.num_classes + + num_authors = self.dataset.num_authors + num_papers = self.dataset.num_papers + num_institutions = self.dataset.num_institutions + + paper_vertex_offsets = get_vertex_offsets( + num_vertices=num_papers, world_size=world_size + ) + author_vertex_offsets = get_vertex_offsets( + num_vertices=num_authors, world_size=world_size + ) + institution_vertex_offsets = get_vertex_offsets( + num_vertices=num_institutions, world_size=world_size + ) + + super().__init__( + rank=rank, + world_size=world_size, + num_features=num_features, + num_classes=num_classes, + num_relations=5, + paper_features=paper_features, + author_features=author_features, + institution_features=institution_features, + paper_vertex_offset=paper_vertex_offsets, + author_vertex_offset=author_vertex_offsets, + institution_vertex_offset=institution_vertex_offsets, + paper_labels=labels, + paper_2_paper_edges=paper_2_paper_edges, + author_2_paper_edges=author_2_paper_edges, + author_2_institution_edges=author_2_institution_edges, + comm_plan_only=comm_plan_only, + paper_vertex_rank_mapping=paper_rank_mappings, + author_vertex_rank_mapping=author_rank_mappings, + institution_vertex_rank_mapping=institution_rank_mappings, + ) + + if cached_comm_plans is not None: + comm_plans = torch.load(cached_comm_plans) + self.paper_2_paper_comm_plan = comm_plans["paper_2_paper_comm_plan"] + self.paper_2_author_comm_plan = comm_plans["paper_2_author_comm_plan"] + self.author_2_institution_comm_plan = comm_plans[ + "author_2_institution_comm_plan" + ] + self.institution_2_author_comm_plan = comm_plans[ + 
"institution_2_author_comm_plan" + ] + self.author_2_paper_comm_plan = comm_plans["author_2_paper_comm_plan"] + + else: + f_name = ( + f"synthetic_dataset_rank_{self.rank}_of_{self.world_size}_comm_plans.pt" + ) + f_name = osp.join(data_dir, f_name) + if osp.exists(f_name): + self._load_comm_plans(f_name) + else: + self._generate_comm_plans(f_name) + + def generate_feature_data(self): + dataset = self.dataset + # This function emulates the author and institute features generation steps here + # https://github.com/snap-stanford/ogb/blob/61e9784ca76edeaa6e259ba0f836099608ff0586/examples/lsc/mag240m/rgnn.py#L82 + + # Generate author features + # Mag240M author features are generated from paper features + num_authors = dataset.num_authors + num_papers = dataset.num_papers + path = dataset.dir + paper_feat = dataset.paper_feat + rank = self.comm.get_rank() + # Only one rank must do this work + if rank == 0: + if not osp.exists(path + "/author_feat.npy"): + print("Generating author features") + author_feat = np.memmap( + filename=path + "/author_feat.npy", + mode="w+", + dtype=np.float16, + shape=(num_authors, self.num_features), + ) + _generate_features_from_paper_features( + out=author_feat, + num_nodes=num_authors, + num_papers=num_papers, + paper_feat=paper_feat, + edge_index=dataset.edge_index("author", "paper"), + num_features=self.num_features, + ) + + if not osp.exists(path + "/institution_feat.npy"): + print("Generating institution features") + # Generate institution features + num_institutions = dataset.num_institutions + institution_feat = np.memmap( + filename=path + "/institution_feat.npy", + mode="w+", + dtype=np.float16, + shape=(num_institutions, self.num_features), + ) + _generate_features_from_paper_features( + out=institution_feat, + num_nodes=num_authors, + num_papers=num_institutions, + paper_feat=paper_feat, + edge_index=dataset.edge_index("author", "institution"), + num_features=self.num_features, + ) + self.comm.barrier() + + # Make sure all ranks can see the generated files + if not osp.exists(path + "/author_feat.npy"): + raise FileNotFoundError("author_feat.npy not found") + if not osp.exists(path + "/institution_feat.npy"): + raise FileNotFoundError("institution_feat.npy not found") + self.comm.barrier() + + print("Data processing complete") + + +if __name__ == "__main__": + import DGraph.Communicator as Comm + + class DummyCommunicator(Comm): + def __init__(self, rank, world_size): + self._rank = rank + self._world_size = world_size + + def get_rank(self): + return self._rank + + def get_world_size(self): + return self._world_size + + def barrier(self): + pass + + dataset = DGraph_MAG240M_Dataset(comm=DummyCommunicator(0, 1)) diff --git a/experiments/OGB-LSC/lsc_datasets/__init__.py b/experiments/OGB-LSC/lsc_datasets/__init__.py new file mode 100644 index 0000000..8fd220f --- /dev/null +++ b/experiments/OGB-LSC/lsc_datasets/__init__.py @@ -0,0 +1 @@ +from .distributed_graph_dataset import DistributedHeteroGraphDataset diff --git a/experiments/OGB-LSC/lsc_datasets/synthetic_dataset.py b/experiments/OGB-LSC/lsc_datasets/synthetic_dataset.py new file mode 100644 index 0000000..1455d76 --- /dev/null +++ b/experiments/OGB-LSC/lsc_datasets/synthetic_dataset.py @@ -0,0 +1,196 @@ +# Copyright (c) 2014-2025, Lawrence Livermore National Security, LLC. +# Produced at the Lawrence Livermore National Laboratory. +# Written by the LBANN Research Team (B. Van Essen, et al.) listed in +# the CONTRIBUTORS file. See the top-level LICENSE file for details. +# +# LLNL-CODE-697807. 
+# All rights reserved. +# +# This file is part of LBANN: Livermore Big Artificial Neural Network +# Toolkit. For details, see http://software.llnl.gov/LBANN or +# https://github.com/LBANN and https://github.com/LLNL/LBANN. +# +# SPDX-License-Identifier: (Apache-2.0) +from DGraph.Communicator import Communicator +import torch +from typing import Optional + +from .distributed_graph_dataset import ( + get_vertex_offsets, + DistributedHeteroGraphDataset, +) + +import os.path as osp + + +import hashlib + + +def generate_config_hash(numbers): + # Convert tuple to string, encode to bytes, hash + payload = str(tuple(numbers)).encode("utf-8") + return hashlib.md5(payload).hexdigest() + + +torch.random.manual_seed(0) + + +def _generate_paper_2_paper_edges(num_papers): + # Average degree of a paper is ~11 + num_edges = num_papers * 11 + coo_list = torch.randint( + low=0, high=num_papers, size=(2, num_edges), dtype=torch.long + ) + coo_list = torch.unique(coo_list, dim=1) + transpose = coo_list.flip(0) + coo_list = torch.cat([coo_list, transpose], dim=1) + coo_list = torch.sort(coo_list, dim=1).values + return coo_list + + +def _generate_author_2_paper_edges(num_authors, num_papers): + # Average number of authors per paper is ~3.5 + num_edges = int(num_authors * 3.5) + dest_papers = torch.randint( + low=0, high=num_papers, size=(1, num_edges), dtype=torch.long + ) + src_authors = torch.randint( + low=0, high=num_authors, size=(1, num_edges), dtype=torch.long + ) + coo_list = torch.cat([src_authors, dest_papers], dim=0) + coo_list = torch.unique(coo_list, dim=1) + return coo_list + + +def _generate_author_2_institution_edges(num_authors, num_institutions): + # Average number of institutions per author is ~0.35 + num_edges = int(num_authors * 0.35) + dest_num_institutions = torch.randint( + low=0, high=num_institutions, size=(1, num_edges), dtype=torch.long + ) + src_authors = torch.randint( + low=0, high=num_authors, size=(1, num_edges), dtype=torch.long + ) + coo_list = torch.cat([src_authors, dest_num_institutions], dim=0) + coo_list = torch.unique(coo_list, dim=1) + return coo_list + + +class SyntheticHeterogeneousDataset(DistributedHeteroGraphDataset): + def __init__( + self, + synthetic_config, + comm: Communicator, + cached_comm_plans: Optional[str] = None, + ): + """Synthetic heterogeneous graph dataset for OGB-LSC experiments built to + mimic the MAG240M dataset. + + Args: + synthetic_config: Configuration object for synthetic dataset. Must have `num_papers`, + `num_authors`, `num_institutions`, `num_features`, and `num_classes` attributes. + comm: DGraph communicator object. + cached_comm_plans: Optional path to cached communication plans. 
+ + """ + num_papers = synthetic_config.num_papers + num_authors = synthetic_config.num_authors + num_institutions = synthetic_config.num_institutions + num_features = synthetic_config.num_features + num_classes = synthetic_config.num_classes + self._num_relations = 5 + self.comm = comm + rank = comm.get_rank() + world_size = comm.get_world_size() + + # Set up synthetic data for papers, authors, and institutions and call superclass init + + _vertices = torch.randperm(num_papers) + + self.train_mask = _vertices[: int(0.7 * num_papers)] + self.val_mask = _vertices[int(0.7 * num_papers) : int(0.85 * num_papers)] + self.test_mask = _vertices[int(0.85 * num_papers) :] + + labels = torch.randint( + low=0, high=num_classes, size=(num_papers,), dtype=torch.long + ) + + # Generate edges + paper_2_paper_edges = _generate_paper_2_paper_edges(num_papers) + author_2_paper_edges = _generate_author_2_paper_edges(num_authors, num_papers) + author_2_institution_edges = _generate_author_2_institution_edges( + num_authors, num_institutions + ) + + paper_vertex_offsets = get_vertex_offsets( + num_vertices=num_papers, world_size=comm.get_world_size() + ) + author_vertex_offsets = get_vertex_offsets( + num_vertices=num_authors, world_size=comm.get_world_size() + ) + institution_vertex_offsets = get_vertex_offsets( + num_vertices=num_institutions, world_size=comm.get_world_size() + ) + + num_paper_vertices_cur_rank = int( + paper_vertex_offsets[rank + 1] - paper_vertex_offsets[rank] + ) + num_author_vertices_cur_rank = int( + author_vertex_offsets[rank + 1] - author_vertex_offsets[rank] + ) + num_institution_vertices_cur_rank = int( + institution_vertex_offsets[rank + 1] - institution_vertex_offsets[rank] + ) + + # Generate random feature data for vertices + paper_features = torch.randn( + (num_paper_vertices_cur_rank, num_features), dtype=torch.float32 + ) + author_features = torch.randn( + (num_author_vertices_cur_rank, num_features), dtype=torch.float32 + ) + institution_features = torch.randn( + (num_institution_vertices_cur_rank, num_features), dtype=torch.float32 + ) + + super().__init__( + rank=rank, + world_size=comm.get_world_size(), + num_features=num_features, + num_classes=num_classes, + num_relations=5, + paper_features=paper_features, + author_features=author_features, + institution_features=institution_features, + paper_vertex_offset=paper_vertex_offsets, + author_vertex_offset=author_vertex_offsets, + institution_vertex_offset=institution_vertex_offsets, + paper_labels=labels, + paper_2_paper_edges=paper_2_paper_edges, + author_2_paper_edges=author_2_paper_edges, + author_2_institution_edges=author_2_institution_edges, + comm_plan_only=True, + ) + + if cached_comm_plans is not None: + comm_plans = torch.load(cached_comm_plans) + self.paper_2_paper_comm_plan = comm_plans["paper_2_paper_comm_plan"] + self.paper_2_author_comm_plan = comm_plans["paper_2_author_comm_plan"] + self.author_2_institution_comm_plan = comm_plans[ + "author_2_institution_comm_plan" + ] + self.institution_2_author_comm_plan = comm_plans[ + "institution_2_author_comm_plan" + ] + self.author_2_paper_comm_plan = comm_plans["author_2_paper_comm_plan"] + + else: + dataset_hash = generate_config_hash( + [num_papers, num_authors, num_institutions, num_features, num_classes] + ) + f_name = f"synthetic_dataset_{dataset_hash}_rank_{self.rank}_of_{self.world_size}_comm_plans.pt" + if osp.exists(f_name): + print(f"Loading comm plans from {f_name}") + self._load_comm_plans(f_name) + else: + self._generate_comm_plans(f_name) diff --git 
a/experiments/OGB-LSC/main.py b/experiments/OGB-LSC/main.py new file mode 100644 index 0000000..698dcb2 --- /dev/null +++ b/experiments/OGB-LSC/main.py @@ -0,0 +1,116 @@ +# Copyright (c) 2014-2025, Lawrence Livermore National Security, LLC. +# Produced at the Lawrence Livermore National Laboratory. +# Written by the LBANN Research Team (B. Van Essen, et al.) listed in +# the CONTRIBUTORS file. See the top-level LICENSE file for details. +# +# LLNL-CODE-697807. +# All rights reserved. +# +# This file is part of LBANN: Livermore Big Artificial Neural Network +# Toolkit. For details, see http://software.llnl.gov/LBANN or +# https://github.com/LBANN and https://github.com/LLNL/LBANN. +# +# SPDX-License-Identifier: (Apache-2.0) +import fire +import torch +from functools import partial +import os.path as osp +import DGraph.Communicator as Comm +from Trainer import Trainer +from config import SyntheticDatasetConfig +import torch.distributed as dist + + +def _load_optional_file(file_path: str): + if osp.exists(file_path): + return torch.load(file_path, weights_only=False) + return None + + +def main( + comm_type: str = "nccl", + dataset: str = "synthetic", + num_papers: int = 2048, + num_authors: int = 512, + num_institutions: int = 16, + optimized_graph_structure: bool = True, + paper_rank_mapping_file: str = "", + author_rank_mapping_file: str = "", + institution_rank_mapping_file: str = "", + data_dir: str = "datasets/data/MAG240M", +): + """Main function to run DGraph experiments on OGB-LSC datasets. + + Args: + comm_type (str): Type of communicator to use. Options are 'nccl' and + 'nvshmem'. Default is 'nccl'. + dataset (str): Dataset to use. Options are 'synthetic' and 'mag240m'. + Default is 'synthetic'. + num_papers (int): Number of paper nodes to use in the synthetic dataset. + Default is 2048. + num_authors (int): Number of author nodes to use in the synthetic dataset. + Default is 512. + num_institutions (int): Number of institution nodes to use in the synthetic + dataset. Default is 16. + paper_rank_mapping_file (str): Path to the paper rank mapping file for + mag240m dataset. Default is ''. + author_rank_mapping_file (str): Path to the author rank mapping file for + mag240m dataset. Default is not set. + institution_rank_mapping_file (str): Path to the institution rank mapping + file for mag240m dataset. Default is not set. + data_dir (str): Path to the mag240m dataset directory. Default is + 'mag240m/data/MAG240M'. + """ + assert dataset in ["synthetic", "mag240m"] + if dataset == "synthetic": + from lsc_datasets.synthetic_dataset import ( + SyntheticHeterogeneousDataset as Dataset, + ) + + synthetic_config = SyntheticDatasetConfig( + num_papers=num_papers, + num_authors=num_authors, + num_institutions=num_institutions, + ) + graph_dataset = partial( + Dataset, + synthetic_config=synthetic_config, + ) + + elif dataset == "mag240m": + from lsc_datasets import DGraph_MAG240M_Dataset as Dataset + + graph_dataset = partial( + Dataset, + paper_rank_mappings=_load_optional_file(paper_rank_mapping_file), + author_rank_mappings=_load_optional_file(author_rank_mapping_file), + institution_rank_mappings=_load_optional_file( + institution_rank_mapping_file + ), + data_dir=data_dir, + comm_plan_only=optimized_graph_structure, + ) + else: + raise ValueError(f"Invalid dataset: {dataset}") + + assert comm_type in ["nccl", "nvshmem"] + comm = Comm.Communicator.init_process_group(comm_type) + + comm.barrier() + print(f"Running with {comm.get_world_size()} ranks. 
Rank: {comm.get_rank()}") + + graph_dataset = graph_dataset(comm=comm) + + trainer = Trainer(graph_dataset, comm) + trainer.prepare_data() + trainer.train() + comm.destroy() + + if dist.is_initialized(): + dist.destroy_process_group() + + return 0 + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/experiments/OGB-LSC/setup_dataset_comms.py b/experiments/OGB-LSC/setup_dataset_comms.py new file mode 100644 index 0000000..556c989 --- /dev/null +++ b/experiments/OGB-LSC/setup_dataset_comms.py @@ -0,0 +1,95 @@ +# Copyright (c) 2014-2025, Lawrence Livermore National Security, LLC. +# Produced at the Lawrence Livermore National Laboratory. +# Written by the LBANN Research Team (B. Van Essen, et al.) listed in +# the CONTRIBUTORS file. See the top-level LICENSE file for details. +# +# LLNL-CODE-697807. +# All rights reserved. +# +# This file is part of LBANN: Livermore Big Artificial Neural Network +# Toolkit. For details, see http://software.llnl.gov/LBANN or +# https://github.com/LBANN and https://github.com/LLNL/LBANN. +# +# SPDX-License-Identifier: (Apache-2.0) +import fire +import torch +import torch.distributed as dist +from typing import Literal +from DGraph import Communicator +from config import SyntheticDatasetConfig +from functools import partial + + +def main( + comm_type: Literal["nccl", "nvshmem"] = "nccl", + dataset: Literal["synthetic", "mag240m"] = "mag240m", + num_papers: int = 2048, + num_authors: int = 512, + num_institutions: int = 16, + data_dir: str = "lsc_datasets/data/MAG240M", +): + assert comm_type in ["nccl", "nvshmem"] + comm = Communicator.init_process_group(comm_type) + + device_id = comm.get_rank() % torch.cuda.device_count() + torch.cuda.set_device(device_id) + + def comm_print(*args, **kwargs): + comm.barrier() + if comm.get_rank() == 0: + print(*args, **kwargs) + comm.barrier() + + world_size = comm.get_world_size() + comm_print(f"Communicator initialized with World Size: {world_size}") + + assert dataset in ["synthetic", "mag240m"] + if dataset == "synthetic": + from lsc_datasets.synthetic_dataset import ( + SyntheticHeterogeneousDataset as Dataset, + ) + + synthetic_config = SyntheticDatasetConfig( + num_papers=num_papers, + num_authors=num_authors, + num_institutions=num_institutions, + ) + + comm_print( + f"Setting up synthetic dataset with configuration {synthetic_config}" + ) + graph_dataset = partial( + Dataset, + synthetic_config=synthetic_config, + ) + comm_print(f"Finished setting up synthetic dataset") + + elif dataset == "mag240m": + from lsc_datasets.MAG240M_dataset import DGraph_MAG240M_Dataset as Dataset + + comm_print(f"Setting up MAG240M dataset") + graph_dataset = partial( + Dataset, + data_dir=data_dir, + comm_plan_only=True, + ) + comm_print(f"Finished setting up MAG240M dataset") + + else: + raise ValueError(f"Invalid dataset: {dataset}") + + graph_dataset = graph_dataset(comm=comm) + + comm_plans = graph_dataset.get_NCCL_comm_plans() + + for i, comm_plan in enumerate(comm_plans): + comm_plan = comm_plan.source_graph_plan + comm_print(f"Comm Plan # {i}") + comm_print(f"Num Local Vertices: {comm_plan.num_local_vertices}") + comm_print(f"Num Boundary Vertices: {comm_plan.boundary_vertex_splits}") + comm_print(f"Num Local Edges: {comm_plan.num_local_edges}") + comm_print(f"Num Boundary Edges: {comm_plan.boundary_edge_splits}") + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/experiments/OGB/GenerateCache.py b/experiments/OGB/GenerateCache.py index a16e795..6fc3e98 100644 --- a/experiments/OGB/GenerateCache.py +++ 
b/experiments/OGB/GenerateCache.py @@ -21,6 +21,9 @@ NCCLGatherCacheGenerator, NCCLScatterCacheGenerator, ) + +from DGraph.distributed.nccl import COO_to_NCCLEdgeConditionedCommPlan + from time import perf_counter from tqdm import tqdm from multiprocessing import get_context @@ -33,67 +36,37 @@ } -def generate_cache_file( - dist_graph, - src_indices, - dst_indices, - edge_placement, - edge_src_placement, - edge_dest_placement, - cache_prefix_str: str, - rank: int, - world_size: int, +def generate_comm_plan( + coo_list, + offsets, + rank, + world_size, + dest_offsets=None, ): - print(f"Generating cache for rank {rank}...") - local_node_features = dist_graph.get_local_node_features(rank).unsqueeze(0) - num_input_rows = local_node_features.size(1) - - print( - f"Rank {rank} has {num_input_rows} input rows with shape {local_node_features.shape}" - ) - gather_cache = NCCLGatherCacheGenerator( - dst_indices, - edge_placement, - edge_dest_placement, - num_input_rows, - rank, - world_size, - ) - - nodes_per_rank = dist_graph.get_nodes_per_rank() - nodes_per_rank = int(nodes_per_rank[rank].item()) - - scatter_cache = NCCLScatterCacheGenerator( - src_indices, - edge_placement, - edge_src_placement, - nodes_per_rank, + # Source edges belonging to this rank should be where the source + # vertex falls within the rank's offset range. + src_start = offsets[rank].item() + src_end = offsets[rank + 1].item() + local_edges = torch.nonzero( + (coo_list[0] >= src_start) & (coo_list[0] < src_end), as_tuple=True + )[0] + + comm_plan = COO_to_NCCLEdgeConditionedCommPlan( rank, world_size, + coo_list[0], + coo_list[1], + local_edges, + offsets, + dest_offset=dest_offsets, ) - print(f"Rank {rank} completed cache generation") - with open( - f"{cache_prefix_str}_gather_cache_rank_{world_size}_{rank}.pt", "wb" - ) as f: - torch.save(gather_cache, f) - - with open( - f"{cache_prefix_str}_scatter_cache_rank_{world_size}_{rank}.pt", "wb" - ) as f: - torch.save(scatter_cache, f) - return 0 + return comm_plan def main(dset: str, world_size: int, node_rank_placement_file: str): assert dset in ["ogbn-arxiv", "ogbn-products", "ogbn-papers100M"] assert world_size > 0 - assert os.path.exists( - node_rank_placement_file - ), "Node rank placement file does not exist." - - node_rank_placement = torch.load(node_rank_placement_file) - dataset = NodePropPredDataset( dset, ) @@ -102,62 +75,9 @@ def main(dset: str, world_size: int, node_rank_placement_file: str): assert split_index is not None, "Split index is None." 
graph, labels = dataset[0] - num_edges = graph["edge_index"].shape print(num_edges) - - dist_graph = process_homogenous_data( - graph_data=graph, - labels=labels, - world_Size=world_size, - split_idx=split_index, - node_rank_placement=node_rank_placement, - rank=0, - ) - - edge_indices = dist_graph.get_global_edge_indices() - rank_mappings = dist_graph.get_global_rank_mappings() - - print("Edge indices shape:", edge_indices.shape) - print("Rank mappings shape:", rank_mappings.shape) - - edge_indices = edge_indices.unsqueeze(0) - src_indices = edge_indices[:, 0, :] - dst_indices = edge_indices[:, 1, :] - - edge_placement = rank_mappings[0] - edge_src_placement = rank_mappings[0] - edge_dest_placement = rank_mappings[1] - - start_time = perf_counter() - cache_prefix_str = f"cache/{cache_prefix[dset]}" - with get_context("spawn").Pool(min(world_size, 8)) as pool: - args = [ - ( - dist_graph, - src_indices, - dst_indices, - edge_placement, - edge_src_placement, - edge_dest_placement, - cache_prefix_str, - rank, - world_size, - ) - for rank in range(world_size) - ] - - out = pool.starmap(generate_cache_file, args) - - end_time = perf_counter() - print(f"Cache generation time: {end_time - start_time:.4f} seconds") - print("Cache files generated successfully.") - print( - f"Gather cache file: {cache_prefix_str}_gather_cache_rank_{world_size}_.pt" - ) - print( - f"Scatter cache file: {cache_prefix_str}_scatter_cache_rank_{world_size}_.pt" - ) + num_nodes = graph["num_nodes"] if __name__ == "__main__": diff --git a/tests/test_NCCLCommPlan.py b/tests/test_NCCLCommPlan.py new file mode 100644 index 0000000..45740b4 --- /dev/null +++ b/tests/test_NCCLCommPlan.py @@ -0,0 +1,166 @@ +# Copyright (c) 2014-2024, Lawrence Livermore National Security, LLC. +# Produced at the Lawrence Livermore National Laboratory. +# Written by the LBANN Research Team (B. Van Essen, et al.) listed in +# the CONTRIBUTORS file. See the top-level LICENSE file for details. +# +# LLNL-CODE-697807. +# All rights reserved. +# +# This file is part of LBANN: Livermore Big Artificial Neural Network +# Toolkit. For details, see http://software.llnl.gov/LBANN or +# https://github.com/LBANN and https://github.com/LLNL/LBANN. 
+# +# SPDX-License-Identifier: (Apache-2.0) + +import pytest +from DGraph.distributed.nccl import ( + NCCLGraphCommPlan, + COO_to_NCCLCommPlan, + COO_to_NCCLEdgeConditionedCommPlan, +) +from DGraph.Communicator import Communicator +import torch.distributed as dist +import torch + + +@pytest.fixture(scope="module") +def init_nccl_backend_communicator(): + dist.init_process_group(backend="nccl") + + comm = Communicator.init_process_group("nccl") + + return comm + + +def setup_coo_matrix(world_size): + torch.manual_seed(0) + num_nodes = 32 * world_size + + # generate num_nodes x num_nodes adjacency matrix + adj_matrix = torch.rand(num_nodes, num_nodes) + adj_matrix = (adj_matrix + adj_matrix.t()) / 2 + adj_matrix[adj_matrix < 0.8] = 0.0 # sparsify + adj_matrix[adj_matrix >= 0.8] = 1.0 + adj_matrix.fill_diagonal_(0) + coo_matrix = adj_matrix.nonzero(as_tuple=False).t().contiguous() + return num_nodes, coo_matrix + + +def test_coo_to_nccl_comm_plan(init_nccl_backend_communicator): + comm = init_nccl_backend_communicator + + rank = comm.get_rank() + world_size = comm.get_world_size() + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + + num_nodes, coo_matrix = setup_coo_matrix(world_size) + coo_matrix = coo_matrix.to(device) + + nodes_per_rank = num_nodes // world_size + offset = torch.arange(world_size + 1, device=device) * nodes_per_rank + + my_start = offset[rank] + my_end = offset[rank + 1] + + src = coo_matrix[0] + dst = coo_matrix[1] + + is_local_edge = (src >= my_start) & (src < my_end) + local_edge_indices = torch.nonzero(is_local_edge, as_tuple=True)[0] + + plan = COO_to_NCCLCommPlan( + rank=rank, + world_size=world_size, + global_edges_dst=dst, + local_edge_list=local_edge_indices, + offset=offset, + ) + + # 1. Check internal vs boundary edges + my_dst = dst[local_edge_indices] + is_internal_gt = (my_dst >= my_start) & (my_dst < my_end) + internal_indices_gt = torch.nonzero(is_internal_gt, as_tuple=True)[0] + + assert torch.equal( + plan.local_edge_idx.sort()[0], internal_indices_gt.sort()[0] + ), f"Rank {rank}: Local edge indices mismatch" + + internal_dst_gt = my_dst[internal_indices_gt] + local_vertex_idx_gt = internal_dst_gt - my_start + + assert torch.equal( + plan.local_vertex_idx.sort()[0], local_vertex_idx_gt.sort()[0] + ), f"Rank {rank}: Local vertex indices mismatch" + + # 2. Check boundary edges + boundary_indices_gt = torch.nonzero(~is_internal_gt, as_tuple=True)[0] + assert torch.equal( + plan.boundary_edge_idx.sort()[0], boundary_indices_gt.sort()[0] + ), f"Rank {rank}: Boundary edge indices mismatch" + + # 3. 
Check boundary vertices (received from other ranks)
+    expected_recv_vertices_unique_per_rank = []
+    for r in range(world_size):
+        if r == rank:
+            continue
+        r_start = offset[r]
+        r_end = offset[r + 1]
+        is_r_edge = (src >= r_start) & (src < r_end)
+        r_dst = dst[is_r_edge]
+        is_to_me = (r_dst >= my_start) & (r_dst < my_end)
+        dst_to_me = r_dst[is_to_me]
+        unique_dst_to_me = torch.unique(dst_to_me)
+        expected_recv_vertices_unique_per_rank.append(unique_dst_to_me)
+
+    if len(expected_recv_vertices_unique_per_rank) > 0:
+        expected_recv_stream = torch.cat(expected_recv_vertices_unique_per_rank)
+    else:
+        expected_recv_stream = torch.tensor([], device=device, dtype=torch.long)
+
+    expected_local_stream = expected_recv_stream - my_start
+
+    assert torch.equal(
+        plan.boundary_vertex_idx.sort()[0], expected_local_stream.sort()[0]
+    ), f"Rank {rank}: Boundary vertex indices mismatch"
+
+
+def test_edge_conditioned_comm_plan(init_nccl_backend_communicator):
+    comm = init_nccl_backend_communicator
+    rank = comm.get_rank()
+    world_size = comm.get_world_size()
+    device = torch.device(f"cuda:{rank}")
+    torch.cuda.set_device(device)
+
+    num_nodes, coo_matrix = setup_coo_matrix(world_size)
+    coo_matrix = coo_matrix.to(device)
+
+    nodes_per_rank = num_nodes // world_size
+    offset = torch.arange(world_size + 1, device=device) * nodes_per_rank
+
+    my_start = offset[rank]
+    my_end = offset[rank + 1]
+
+    src = coo_matrix[0]
+    dst = coo_matrix[1]
+    is_local_edge = (src >= my_start) & (src < my_end)
+    local_edge_indices = torch.nonzero(is_local_edge, as_tuple=True)[0]
+
+    ec_plan = COO_to_NCCLEdgeConditionedCommPlan(
+        rank=rank,
+        world_size=world_size,
+        global_edges_src=src,
+        global_edges_dst=dst,
+        local_edge_list=local_edge_indices,
+        src_offset=offset,
+        dest_offset=offset,
+    )
+
+    assert ec_plan.source_graph_plan is not None
+    assert ec_plan.dest_graph_plan is not None
+
+    assert ec_plan.source_graph_plan.boundary_edge_idx.numel() == 0
+    assert (
+        ec_plan.source_graph_plan.local_edge_idx.numel() == local_edge_indices.numel()
+    )
+    assert ec_plan.dest_graph_plan.num_local_edges == local_edge_indices.numel()
diff --git a/tests/test_local_kernels.py b/tests/test_local_kernels.py
index 1544644..f3a44ac 100644
--- a/tests/test_local_kernels.py
+++ b/tests/test_local_kernels.py
@@ -67,3 +67,84 @@ def test_optimized_local_gather():
     assert torch.allclose(
         out_tensor.cpu(), out_tensor_gt
     ), "Optimized local gather failed"
+
+
+def test_optimized_scatter_gather():
+    try:
+        from torch_local import local_masked_scatter_gather
+    except ImportError as e:
+        pytest.fail(f"Failed to import local_masked_scatter_gather: {e}")
+
+    num_src_rows = 8
+    num_out_rows = 8
+    bs = 1
+    num_features = 4
+    src_tensor = torch.randn(bs, num_src_rows, num_features)
+    src_indices = torch.tensor([0, 3, 2, 1])
+    dst_indices = torch.tensor([1, 3, 5, 7])
+
+    out_tensor_gt = torch.zeros(bs, num_out_rows, num_features)
+
+    for i in range(bs):
+        for j in range(len(src_indices)):
+            out_tensor_gt[i, dst_indices[j]] = src_tensor[i, src_indices[j]]
+    out_tensor_gt = out_tensor_gt.view(bs, num_out_rows, num_features)
+    out_tensor = torch.zeros_like(out_tensor_gt)
+    out_tensor = out_tensor.cuda()
+    src_tensor = src_tensor.cuda()
+    src_indices = src_indices.cuda().long()
+    dst_indices = dst_indices.cuda().long()
+    local_masked_scatter_gather(
+        src_tensor,
+        src_indices,
+        dst_indices,
+        out_tensor,
+        bs,
+        num_src_rows,
+        num_features,
+        num_out_rows,
+    )
+    assert torch.allclose(
+        out_tensor.cpu(), out_tensor_gt
+    ), "Optimized local scatter-gather
failed" + + +def test_optimized_scatter_add_gather(): + try: + from torch_local import local_masked_scatter_add_gather + except ImportError as e: + pytest.fail(f"Failed to import local_masked_scatter_add_gather: {e}") + + num_src_rows = 8 + num_out_rows = 8 + bs = 1 + num_features = 4 + src_tensor = torch.randn(bs, num_src_rows, num_features) + src_indices = torch.tensor([0, 3, 2, 1, 3]) + dst_indices = torch.tensor([1, 3, 5, 7, 3]) + + out_tensor_gt = torch.zeros(bs, num_out_rows, num_features) + + for i in range(bs): + for j in range(len(src_indices)): + out_tensor_gt[i, dst_indices[j]] += src_tensor[i, src_indices[j]] + + out_tensor_gt = out_tensor_gt.view(bs, num_out_rows, num_features) + out_tensor = torch.zeros_like(out_tensor_gt) + out_tensor = out_tensor.cuda() + src_tensor = src_tensor.cuda() + src_indices = src_indices.cuda().long() + dst_indices = dst_indices.cuda().long() + local_masked_scatter_add_gather( + src_tensor, + src_indices, + dst_indices, + out_tensor, + bs, + num_src_rows, + num_features, + num_out_rows, + ) + assert torch.allclose( + out_tensor.cpu(), out_tensor_gt + ), "Optimized local scatter-add-gather failed" diff --git a/tests/test_nccl_backend.py b/tests/test_nccl_backend.py index 08a83e9..8788b55 100644 --- a/tests/test_nccl_backend.py +++ b/tests/test_nccl_backend.py @@ -232,6 +232,59 @@ def setup_unbalanced_scatter_data(init_nccl_backend): ) +@pytest.fixture(scope="module") +def setup_comm_plan(init_nccl_backend): + comm = init_nccl_backend + torch.manual_seed(0) + torch.cuda.set_device(comm.get_rank()) + rank = comm.get_rank() + world_size = comm.get_world_size() + + num_nodes = 32 * world_size + num_features = 64 + + # generate num_nodes x num_nodes adjacency matrix + adj_matrix = torch.rand(num_nodes, num_nodes) + adj_matrix = (adj_matrix + adj_matrix.t()) / 2 + adj_matrix[adj_matrix < 0.8] = 0.0 # sparsify + adj_matrix[adj_matrix >= 0.8] = 1.0 + adj_matrix.fill_diagonal_(0) + coo_matrix = adj_matrix.nonzero(as_tuple=False).t().contiguous() + from DGraph.distributed.nccl import ( + COO_to_NCCLCommPlan, + ) + + nodes_per_rank = num_nodes // world_size + offset = torch.arange(world_size + 1) * nodes_per_rank + + my_start = offset[rank] + my_end = offset[rank + 1] + + src = coo_matrix[0] + dst = coo_matrix[1] + + num_edges = src.shape[0] + + is_local_edge = (src >= my_start) & (src < my_end) + local_edge_indices = torch.nonzero(is_local_edge, as_tuple=True)[0] + + plan = COO_to_NCCLCommPlan( + rank=rank, + world_size=world_size, + global_edges_dst=dst, + local_edge_list=local_edge_indices, + offset=offset, + ) + + global_input = torch.randn(1, num_nodes, num_features) + global_output = torch.zeros(1, num_edges, num_features) + + for i in range(num_edges): + global_output[:, i] = global_input[:, dst[i]] + + return plan + + def test_nccl_backend_init(init_nccl_backend): comm = init_nccl_backend rank = comm.get_rank() @@ -350,3 +403,10 @@ def test_nccl_backend_scatter(init_nccl_backend, setup_scatter_data): assert local_output_gt.shape == (1, 2, 4) assert dgraph_output_tensor.shape == (1, 2, 4) assert torch.allclose(dgraph_output_tensor.cpu(), local_output_gt) + + +def test_nccl_backend_gather_comm_plan(init_nccl_backend, setup_comm_plan): + comm: Comm.Communicator = init_nccl_backend + plan = setup_comm_plan + rank = comm.get_rank() + world_size = comm.get_world_size()