Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 104 additions & 4 deletions driver/utils/accl_network_utils/src/accl_network_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,28 @@ bool check_arp(vnx::Networklayer &network_layer,
} // namespace

namespace accl_network_utils {

// Checks that local_rank can be used as an index into a ranks vector that
// holds ranks_size entries.
//
// Throws std::invalid_argument when ranks_size is zero or when local_rank is
// negative, and std::out_of_range when local_rank is not below ranks_size.
// The emptiness check runs first, so an empty vector is reported even if
// local_rank is also negative. Every message starts with function_name.
inline void validate_local_rank(int local_rank, size_t ranks_size, const std::string &function_name) {
  const std::string prefix = function_name + ": ";

  if (ranks_size == 0) {
    throw std::invalid_argument(prefix + "ranks vector is empty");
  }

  const std::string rank_text = std::to_string(local_rank);
  if (local_rank < 0) {
    throw std::invalid_argument(prefix + "local_rank cannot be negative (got " + rank_text + ")");
  }

  // local_rank is known to be non-negative here, so the conversion is lossless.
  const size_t rank_index = static_cast<size_t>(local_rank);
  if (rank_index < ranks_size) {
    return;
  }

  throw std::out_of_range(prefix + "local_rank (" + rank_text +
                          ") is out of range for ranks vector of size " +
                          std::to_string(ranks_size));
}

void configure_vnx(vnx::CMAC &cmac, vnx::Networklayer &network_layer,
const std::vector<rank_t> &ranks, int local_rank,
bool rsfec) {
validate_local_rank(local_rank, ranks.size(), "configure_vnx");

if (ranks.size() > vnx::max_sockets_size) {
throw std::runtime_error("Too many ranks. VNX supports up to " +
std::to_string(vnx::max_sockets_size) +
Expand Down Expand Up @@ -195,6 +214,8 @@ void configure_vnx(vnx::CMAC &cmac, vnx::Networklayer &network_layer,
void configure_tcp(XRTBuffer<int8_t> &tx_buf_network, XRTBuffer<int8_t> &rx_buf_network,
xrt::kernel &network_krnl, xrt::kernel &session_krnl,
std::vector<rank_t> &ranks, int local_rank) {
validate_local_rank(local_rank, ranks.size(), "configure_tcp");

tx_buf_network.sync_to_device();
rx_buf_network.sync_to_device();

Expand Down Expand Up @@ -270,7 +291,25 @@ void configure_tcp(XRTBuffer<int8_t> &tx_buf_network, XRTBuffer<int8_t> &rx_buf_
}

void exchange_qp(unsigned int master_rank, unsigned int slave_rank, unsigned int local_rank, std::vector<fpga::ibvQpConn*> &ibvQpConn_vec, std::vector<ACCL::rank_t> &ranks){

// Validate all rank indices
if (ranks.empty()) {
throw std::invalid_argument("exchange_qp: ranks vector is empty");
}
if (ibvQpConn_vec.empty()) {
throw std::invalid_argument("exchange_qp: ibvQpConn_vec is empty");
}
if (master_rank >= ranks.size()) {
throw std::out_of_range("exchange_qp: master_rank (" + std::to_string(master_rank) +
") out of range for ranks size " + std::to_string(ranks.size()));
}
if (slave_rank >= ranks.size()) {
throw std::out_of_range("exchange_qp: slave_rank (" + std::to_string(slave_rank) +
") out of range for ranks size " + std::to_string(ranks.size()));
}
if (slave_rank >= ibvQpConn_vec.size() || master_rank >= ibvQpConn_vec.size()) {
throw std::out_of_range("exchange_qp: rank index out of range for ibvQpConn_vec");
}

if (local_rank == master_rank)
{
std::cout<<"Local rank "<<local_rank<<" sending local QP to remote rank "<<slave_rank<<std::endl;
Expand Down Expand Up @@ -333,6 +372,11 @@ void exchange_qp(unsigned int master_rank, unsigned int slave_rank, unsigned int
}

void configure_cyt_rdma(std::vector<ACCL::rank_t> &ranks, int local_rank, ACCL::CoyoteDevice* device){
validate_local_rank(local_rank, ranks.size(), "configure_cyt_rdma");

if (device == nullptr) {
throw std::invalid_argument("configure_cyt_rdma: device pointer is null");
}

std::cout<<"Initializing QP connections..."<<std::endl;
// create queue pair connections
Expand All @@ -357,6 +401,12 @@ void configure_cyt_rdma(std::vector<ACCL::rank_t> &ranks, int local_rank, ACCL::
}

void configure_cyt_tcp(std::vector<ACCL::rank_t> &ranks, int local_rank, ACCL::CoyoteDevice* device){
validate_local_rank(local_rank, ranks.size(), "configure_cyt_tcp");

if (device == nullptr) {
throw std::invalid_argument("configure_cyt_tcp: device pointer is null");
}

std::cout<<"Configuring Coyote TCP..."<<std::endl;
// arp lookup
for(int i=0; i<ranks.size(); i++){
Expand Down Expand Up @@ -410,6 +460,15 @@ std::vector<std::string> get_ips(fs::path config_file) {
}

std::vector<std::string> get_ips(bool local, int world_size) {
if (world_size <= 0) {
throw std::invalid_argument("get_ips: world_size must be positive (got " +
std::to_string(world_size) + ")");
}
if (world_size > 254) {
throw std::invalid_argument("get_ips: world_size cannot exceed 254 for IP generation (got " +
std::to_string(world_size) + ")");
}

std::vector<std::string> ips{};
for (int i = 0; i < world_size; ++i) {
if (local) {
Expand All @@ -424,11 +483,24 @@ std::vector<std::string> get_ips(bool local, int world_size) {
std::vector<rank_t> generate_ranks(fs::path config_file, int local_rank,
std::uint16_t start_port,
unsigned int rxbuf_size) {
if (local_rank < 0) {
throw std::invalid_argument("generate_ranks: local_rank cannot be negative");
}

std::vector<rank_t> ranks{};
std::vector<std::string> ips = get_ips(config_file);

if (ips.empty()) {
throw std::runtime_error("generate_ranks: no IPs found in config file");
}

// Check for port overflow
if (static_cast<size_t>(start_port) + ips.size() - 1 > 65535) {
throw std::overflow_error("generate_ranks: start_port + number of ranks would exceed max port 65535");
}

for (int i = 0; i < static_cast<int>(ips.size()); ++i) {
rank_t new_rank = {ips[i], start_port + i, i, rxbuf_size};
rank_t new_rank = {ips[i], static_cast<std::uint16_t>(start_port + i), i, rxbuf_size};
ranks.emplace_back(new_rank);
}

Expand All @@ -438,10 +510,26 @@ std::vector<rank_t> generate_ranks(fs::path config_file, int local_rank,
std::vector<rank_t> generate_ranks(bool local, int local_rank, int world_size,
std::uint16_t start_port,
unsigned int rxbuf_size) {
if (local_rank < 0) {
throw std::invalid_argument("generate_ranks: local_rank cannot be negative");
}
if (world_size <= 0) {
throw std::invalid_argument("generate_ranks: world_size must be positive");
}
if (local_rank >= world_size) {
throw std::out_of_range("generate_ranks: local_rank (" + std::to_string(local_rank) +
") must be less than world_size (" + std::to_string(world_size) + ")");
}

// Check for port overflow
if (static_cast<size_t>(start_port) + world_size - 1 > 65535) {
throw std::overflow_error("generate_ranks: start_port + world_size would exceed max port 65535");
}

std::vector<rank_t> ranks{};
std::vector<std::string> ips = get_ips(local, world_size);
for (int i = 0; i < static_cast<int>(ips.size()); ++i) {
rank_t new_rank = {ips[i], start_port + i, i, rxbuf_size};
rank_t new_rank = {ips[i], static_cast<std::uint16_t>(start_port + i), i, rxbuf_size};
ranks.emplace_back(new_rank);
}

Expand All @@ -451,8 +539,20 @@ std::vector<rank_t> generate_ranks(bool local, int local_rank, int world_size,
std::unique_ptr<ACCL::ACCL>
initialize_accl(std::vector<rank_t> &ranks, int local_rank,
bool simulator, acclDesign design, xrt::device device,
fs::path xclbin, unsigned int nbufs, unsigned int bufsize,
fs::path xclbin, unsigned int nbufs, unsigned int bufsize,
unsigned int egrsize, bool rsfec) {
validate_local_rank(local_rank, ranks.size(), "initialize_accl");

if (nbufs == 0) {
throw std::invalid_argument("initialize_accl: nbufs must be greater than 0");
}
if (bufsize == 0) {
throw std::invalid_argument("initialize_accl: bufsize must be greater than 0");
}
if (!simulator && xclbin.empty()) {
throw std::invalid_argument("initialize_accl: xclbin path required for non-simulator mode");
}

std::size_t world_size = ranks.size();
std::unique_ptr<ACCL::ACCL> accl;

Expand Down
35 changes: 18 additions & 17 deletions driver/xrt/include/accl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "accl/simdevice.hpp"
#include "accl/coyotebuffer.hpp"
#include "accl/coyotedevice.hpp"
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
Expand Down Expand Up @@ -778,12 +779,12 @@ ACCLRequest *barrier(communicatorId comm_id = GLOBAL_COMM,
std::unique_ptr<Buffer<dtype>> create_buffer_host(size_t length, dataType type) {
if (sim_mode) {
return std::unique_ptr<Buffer<dtype>>(new SimBuffer<dtype>(
length, type, static_cast<SimDevice *>(cclo)->get_context(), true));
length, type, static_cast<SimDevice *>(cclo.get())->get_context(), true));
} else if (cclo->get_device_type() == CCLO::xrt_device) {
return std::unique_ptr<Buffer<dtype>>(new XRTBuffer<dtype>(
length, type, *(static_cast<XRTDevice *>(cclo)->get_device()), xrt::bo::flags::host_only, (xrt::memory_group)0));
length, type, *(static_cast<XRTDevice *>(cclo.get())->get_device()), xrt::bo::flags::host_only, (xrt::memory_group)0));
} else {
return std::unique_ptr<Buffer<dtype>>(new CoyoteBuffer<dtype>(length, type, cclo));
return std::unique_ptr<Buffer<dtype>>(new CoyoteBuffer<dtype>(length, type, static_cast<CoyoteDevice *>(cclo.get())));
}
}

Expand All @@ -808,13 +809,13 @@ ACCLRequest *barrier(communicatorId comm_id = GLOBAL_COMM,
std::unique_ptr<Buffer<dtype>> create_buffer(size_t length, dataType type, unsigned mem_grp) {
if (sim_mode) {
return std::unique_ptr<Buffer<dtype>>(new SimBuffer<dtype>(
length, type, static_cast<SimDevice *>(cclo)->get_context(), false, mem_grp));
length, type, static_cast<SimDevice *>(cclo.get())->get_context(), false, mem_grp));
} else if (cclo->get_device_type() == CCLO::xrt_device) {
return std::unique_ptr<Buffer<dtype>>(new XRTBuffer<dtype>(
length, type, *(static_cast<XRTDevice *>(cclo)->get_device()), (xrt::memory_group)mem_grp));
length, type, *(static_cast<XRTDevice *>(cclo.get())->get_device()), (xrt::memory_group)mem_grp));
} else {
return std::unique_ptr<Buffer<dtype>>(new CoyoteBuffer<dtype>(
length, type, cclo));
length, type, static_cast<CoyoteDevice *>(cclo.get())));
}
}

Expand Down Expand Up @@ -874,10 +875,10 @@ ACCLRequest *barrier(communicatorId comm_id = GLOBAL_COMM,
if (sim_mode) {
return std::unique_ptr<Buffer<dtype>>(
new SimBuffer<dtype>(host_buffer, length, type,
static_cast<SimDevice *>(cclo)->get_context(), false, mem_grp));
static_cast<SimDevice *>(cclo.get())->get_context(), false, mem_grp));
} else if(cclo->get_device_type() == CCLO::xrt_device ){
return std::unique_ptr<Buffer<dtype>>(new XRTBuffer<dtype>(
host_buffer, length, type, *(static_cast<XRTDevice *>(cclo)->get_device()), (xrt::memory_group)mem_grp));
host_buffer, length, type, *(static_cast<XRTDevice *>(cclo.get())->get_device()), (xrt::memory_group)mem_grp));
}
return std::unique_ptr<Buffer<dtype>>(nullptr);
}
Expand Down Expand Up @@ -905,8 +906,8 @@ ACCLRequest *barrier(communicatorId comm_id = GLOBAL_COMM,
dataType type) {
if (sim_mode) {
return std::unique_ptr<Buffer<dtype>>(
new SimBuffer<dtype>(bo, *(static_cast<SimDevice *>(cclo)->get_device()), length, type,
static_cast<SimDevice *>(cclo)->get_context()));
new SimBuffer<dtype>(bo, *(static_cast<SimDevice *>(cclo.get())->get_device()), length, type,
static_cast<SimDevice *>(cclo.get())->get_context()));
} else {
return std::unique_ptr<Buffer<dtype>>(
new XRTBuffer<dtype>(bo, length, type));
Expand Down Expand Up @@ -956,18 +957,18 @@ ACCLRequest *barrier(communicatorId comm_id = GLOBAL_COMM,
unsigned mem_grp) {
if (sim_mode) {
return std::unique_ptr<Buffer<dtype>>(new SimBuffer<dtype>(
length, type, static_cast<SimDevice *>(cclo)->get_context()));
length, type, static_cast<SimDevice *>(cclo.get())->get_context()));
} else if(cclo->get_device_type() == CCLO::xrt_device ){
return std::unique_ptr<Buffer<dtype>>(new XRTBuffer<dtype>(
length, type, *(static_cast<XRTDevice *>(cclo)->get_device()), xrt::bo::flags::p2p, (xrt::memory_group)mem_grp));
length, type, *(static_cast<XRTDevice *>(cclo.get())->get_device()), xrt::bo::flags::p2p, (xrt::memory_group)mem_grp));
} else {
//for Coyote there's no concept of a p2p buffer
throw std::runtime_error("p2p buffers not supported in Coyote");
}
}

/**
* Construct a new coyote buffer object without an existing host buffer
* Construct a new coyote buffer object without an existing host buffer
*
* Coyote buffer object doesn't have a notion of memory banks
*
Expand All @@ -982,7 +983,7 @@ ACCLRequest *barrier(communicatorId comm_id = GLOBAL_COMM,
if (sim_mode) {
throw std::runtime_error("create_coyotebuffer sim_mode unsupported!!!");
} else {
return std::unique_ptr<Buffer<dtype>>(new CoyoteBuffer<dtype>(length, type, static_cast<CoyoteDevice *>(cclo)));
return std::unique_ptr<Buffer<dtype>>(new CoyoteBuffer<dtype>(length, type, static_cast<CoyoteDevice *>(cclo.get())));
}
}

Expand Down Expand Up @@ -1066,19 +1067,19 @@ ACCLRequest *barrier(communicatorId comm_id = GLOBAL_COMM,
void close_con(communicatorId comm_id = GLOBAL_COMM);

private:
CCLO *cclo{};
std::unique_ptr<CCLO> cclo;
// Supported types and corresponding arithmetic config
arithConfigMap arith_config;
// Address to put new configurations like arithmetic configs
// and communicators
addr_t current_config_address{};
// RX spare buffers for eager mode
addr_t max_eager_msg_size{};
std::vector<Buffer<int8_t> *> eager_rx_buffers;
std::vector<std::unique_ptr<Buffer<int8_t>>> eager_rx_buffers;
addr_t eager_rx_buffer_size{};
// Spare buffers for use in rendezvous reduces
addr_t max_rndzv_msg_size{};
std::vector<Buffer<int8_t> *> utility_spares;
std::vector<std::unique_ptr<Buffer<int8_t>>> utility_spares;
// List of communicators, to which users will add
std::vector<Communicator> communicators;
// safety checks
Expand Down
18 changes: 18 additions & 0 deletions driver/xrt/include/accl/constants.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,24 @@ const addr_t EXCHANGE_MEM_ADDRESS_RANGE = 0x2000;
/** Global Communicator */
const communicatorId GLOBAL_COMM = 0x0;

/**
 * Configuration constants
 *
 * Named limits and default values, all compile-time unsigned integers.
 */
/** Largest valid stream ID for stream_put operations; valid IDs run from 0 through 246 inclusive. */
constexpr unsigned int MAX_STREAM_ID = 246;
/** Default timeout in milliseconds for soft reset operations */
constexpr unsigned int DEFAULT_RESET_TIMEOUT_MS = 100;
/** Default timeout value for ACCL operations (in hardware cycles) */
constexpr unsigned int DEFAULT_OPERATION_TIMEOUT = 1000000;
/**
 * Maximum count for flat tree gather/reduce operations (32 * 1024).
 * NOTE(review): previously described as "32 KB"; confirm whether this
 * count is measured in bytes or in elements.
 */
constexpr unsigned int FLAT_TREE_MAX_COUNT = 32 * 1024;
/** Default max fanin for flat tree gather operations */
constexpr unsigned int DEFAULT_GATHER_FLAT_TREE_MAX_FANIN = 2;
/** Default max ranks for flat tree broadcast operations */
constexpr unsigned int DEFAULT_BCAST_FLAT_TREE_MAX_RANKS = 3;
/** Default max ranks for flat tree reduce operations */
constexpr unsigned int DEFAULT_REDUCE_FLAT_TREE_MAX_RANKS = 4;

/**
* Address offsets inside the HOSTCTRL internal memory
*
Expand Down
Loading