From 6c98756db5da91d28cf455625b87a8fb4c02a24c Mon Sep 17 00:00:00 2001
From: niehao
Date: Tue, 4 Nov 2025 10:46:38 +0800
Subject: [PATCH 1/3] Add wait kernel impl

---
 fserver/csrc/kernel.hpp           | 18 ++++++++++
 fserver/csrc/ops.cc               | 10 ++++--
 fserver/csrc/private.hpp          |  2 +-
 fserver/csrc/public.hpp           |  4 +--
 fserver/csrc/{util.h => util.hpp} |  0
 fserver/csrc/wait_kernel.cu       | 55 +++++++++++++++++++++++++++++++
 setup.py                          |  3 +-
 7 files changed, 86 insertions(+), 6 deletions(-)
 create mode 100644 fserver/csrc/kernel.hpp
 rename fserver/csrc/{util.h => util.hpp} (100%)
 create mode 100644 fserver/csrc/wait_kernel.cu

diff --git a/fserver/csrc/kernel.hpp b/fserver/csrc/kernel.hpp
new file mode 100644
index 0000000..575edb0
--- /dev/null
+++ b/fserver/csrc/kernel.hpp
@@ -0,0 +1,18 @@
+#include <cstdint>
+#include <string>
+#include <vector>
+
+#include <pybind11/pybind11.h>
+#include <torch/extension.h>
+
+
+torch::Tensor map_pinned_tensor(torch::Tensor tensor, int64_t device_index);
+void write_flag(torch::Tensor flag, int64_t seq);
+void wait_flag(torch::Tensor flag, int64_t seq);
+
+void pybind_kernel(py::module &m){
+  // StepMesh utils
+  m.def("map_pinned_tensor", &map_pinned_tensor, py::arg("tensor"), py::arg("device_index"));
+  m.def("write_flag", &write_flag, py::arg("flag"), py::arg("seq"));
+  m.def("wait_flag", &wait_flag, py::arg("flag"), py::arg("seq"));
+}
\ No newline at end of file
diff --git a/fserver/csrc/ops.cc b/fserver/csrc/ops.cc
index 27c876d..9f04610 100644
--- a/fserver/csrc/ops.cc
+++ b/fserver/csrc/ops.cc
@@ -1,12 +1,18 @@
 /* Copyright (c) 2025, StepFun Authors. All rights reserved. */
 
-#include "./util.h"
+#include "./util.hpp"
 #include "./public.hpp"
-#include "./private.hpp"
+#ifdef DMLC_USE_CUDA
+  #include "./private.hpp"
+  #include "./kernel.hpp"
+#endif
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   pybind_public(m);
+#ifdef DMLC_USE_CUDA
   pybind_private(m);
+  pybind_kernel(m);
+#endif
 }
diff --git a/fserver/csrc/private.hpp b/fserver/csrc/private.hpp
index 816563d..05b7cfa 100644
--- a/fserver/csrc/private.hpp
+++ b/fserver/csrc/private.hpp
@@ -1,6 +1,6 @@
 /* Copyright (c) 2025, StepFun Authors. All rights reserved. */
 
-#include "./util.h"
+#include "./util.hpp"
 #include "./public.hpp"
 #include <torch/extension.h>
 #ifdef DMLC_USE_CUDA
diff --git a/fserver/csrc/public.hpp b/fserver/csrc/public.hpp
index 249af16..a127c6c 100644
--- a/fserver/csrc/public.hpp
+++ b/fserver/csrc/public.hpp
@@ -15,7 +15,7 @@
 
 #include <torch/extension.h>
 
-#include "./util.h"
+#include "./util.hpp"
 
 #ifndef PUBLIC_OPS_
 #define PUBLIC_OPS_
@@ -237,7 +237,7 @@ uint64_t get_nanosecond() {
 
 
 void pybind_public(py::module &m){
-  m.def("init", &init, py::call_guard<py::gil_scoped_release>());
+  m.def("init", &init, py::call_guard<py::gil_scoped_release>());
   m.def("stop", &stop, py::call_guard<py::gil_scoped_release>());
 
   m.def("register_recv_buffer",
diff --git a/fserver/csrc/util.h b/fserver/csrc/util.hpp
similarity index 100%
rename from fserver/csrc/util.h
rename to fserver/csrc/util.hpp
diff --git a/fserver/csrc/wait_kernel.cu b/fserver/csrc/wait_kernel.cu
new file mode 100644
index 0000000..20e45ae
--- /dev/null
+++ b/fserver/csrc/wait_kernel.cu
@@ -0,0 +1,55 @@
+#include <cstdint>
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/extension.h>
+
+__global__ void write_flag_kernel(int64_t* flag, int64_t seq) {
+  if (threadIdx.x == 0) {
+    __threadfence_system();
+    flag[0] = seq;
+  }
+}
+
+__global__ void wait_flag_kernel(int64_t* flag, int64_t seq) {
+  if (threadIdx.x == 0) {
+    // Mark pointer volatile so we reload host-written values each iteration.
+    volatile int64_t* flag_ptr = flag;
+    int64_t value = flag_ptr[0];
+    while (value < seq) {
+      __nanosleep(128);
+      value = flag_ptr[0];
+    }
+  }
+}
+
+static void check_cuda(cudaError_t err, const char* msg) {
+  TORCH_CHECK(err == cudaSuccess, msg, ": ", cudaGetErrorString(err));
+}
+
+torch::Tensor map_pinned_tensor(torch::Tensor tensor, int64_t device_index) {
+  TORCH_CHECK(tensor.is_pinned(), "tensor must be pinned");
+  void* host_ptr = tensor.data_ptr();
+  void* device_ptr = nullptr;
+  check_cuda(cudaHostGetDevicePointer(&device_ptr, host_ptr, 0),
+             "cudaHostGetDevicePointer failed");
+  auto options = tensor.options().device(torch::kCUDA, device_index);
+  auto sizes = tensor.sizes();
+  auto strides = tensor.strides();
+  return torch::from_blob(device_ptr, sizes, strides, [](void*){}, options);
+}
+
+void write_flag(torch::Tensor flag, int64_t seq) {
+  TORCH_CHECK(flag.is_cuda(), "flag must be a CUDA tensor");
+  auto stream = at::cuda::getCurrentCUDAStream(flag.device().index());
+  write_flag_kernel<<<1, 1, 0, stream>>>(flag.data_ptr<int64_t>(), seq);
+  check_cuda(cudaGetLastError(), "write_flag_kernel launch failed");
+}
+
+void wait_flag(torch::Tensor flag, int64_t seq) {
+  TORCH_CHECK(flag.is_cuda(), "flag must be a CUDA tensor");
+  auto stream = at::cuda::getCurrentCUDAStream(flag.device().index());
+  wait_flag_kernel<<<1, 1, 0, stream>>>(flag.data_ptr<int64_t>(), seq);
+  check_cuda(cudaGetLastError(), "wait_flag_kernel launch failed");
+}
diff --git a/setup.py b/setup.py
index 1a95a9a..15c74dd 100644
--- a/setup.py
+++ b/setup.py
@@ -62,7 +62,7 @@ def _get_cuda_bare_metal_version(cuda_dir):
     if use_cuda:
         extra_link += ['-lcuda', '-lcudart']
         extra_compile_args['cxx'] += ['-DDMLC_USE_CUDA',]
-        extra_compile_args['nvcc'] = ['-O3', '-gencode', 'arch=compute_70,code=sm_70',
+        extra_compile_args['nvcc'] = ['-O3', '-gencode', 'arch=compute_80,code=sm_80',
                                       '--use_fast_math'] + cc_flag
         bare_metal_major, bare_metal_minor = \
             _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
@@ -84,6 +84,7 @@ def _get_cuda_bare_metal_version(cuda_dir):
         'fserver_lib',
         [
             __SRC_PATH__ + 'ops.cc',
+            __SRC_PATH__ + 'wait_kernel.cu',
         ],
         extra_compile_args=extra_compile_args,
         extra_link_args=extra_link,

From 7388307197b641bdf3252ed3e701c4fa3e81390b Mon Sep 17 00:00:00 2001
From: niehao
Date: Wed, 5 Nov 2025 17:03:52 +0800
Subject: [PATCH 2/3] Add push_pull/wait kernel graph support

---
 CMakeLists.txt                    |  10 +-
 fserver/csrc/kernel.hpp           |   6 +-
 fserver/csrc/public.hpp           |   6 +-
 fserver/csrc/wait_kernel.cu       |  41 ++++--
 include/ps/af_tensor_app.h        |  10 +-
 setup.py                          |  13 +-
 tests/fserver/test_kernel_wait.py | 210 ++++++++++++++++++++++++++++++
 7 files changed, 267 insertions(+), 29 deletions(-)
 create mode 100644 tests/fserver/test_kernel_wait.py

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 63dfa24..4007994 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -3,7 +3,13 @@ cmake_minimum_required(VERSION 3.22 FATAL_ERROR)
 
 project(af LANGUAGES C CXX)
 
 set(CMAKE_CXX_STANDARD 17)
-set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17 -fPIC -O3 -Wall -finline-functions -msse2 ")
+execute_process(COMMAND ${Python_EXECUTABLE}
+  -c "import torch; print(int(torch.compiled_with_cxx11_abi()))"
+  OUTPUT_VARIABLE TORCH_CXX11_ABI OUTPUT_STRIP_TRAILING_WHITESPACE)
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17 -fPIC -O3 -Wall -finline-functions -msse2 -D_GLIBCXX_USE_CXX11_ABI=${TORCH_CXX11_ABI} ")
+
+
+# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17 -fPIC -O3 -Wall -finline-functions -msse2 ")
 
 # import pytorch library
 find_package (Python COMPONENTS
               Interpreter Development)
@@ -19,7 +25,9 @@ find_package(Torch REQUIRED CONFIG)
 message("MY TORCH_INCLUDE_DIRS ${TORCH_INCLUDE_DIRS}")
 message("MY CUDA_INCLUDE_DIRS ${CUDA_INCLUDE_DIRS}")
 include_directories(${TORCH_INCLUDE_DIRS})
+# Save ABI setting before adding TORCH_CXX_FLAGS (which might override it)
 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
+# Ensure ABI setting is preserved after TORCH_CXX_FLAGS
 
 list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
 list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Modules)
diff --git a/fserver/csrc/kernel.hpp b/fserver/csrc/kernel.hpp
index 575edb0..ee4ab8e 100644
--- a/fserver/csrc/kernel.hpp
+++ b/fserver/csrc/kernel.hpp
@@ -7,12 +7,14 @@
 
 
 torch::Tensor map_pinned_tensor(torch::Tensor tensor, int64_t device_index);
-void write_flag(torch::Tensor flag, int64_t seq);
-void wait_flag(torch::Tensor flag, int64_t seq);
+void write_flag(torch::Tensor flag, torch::Tensor seq);
+void wait_flag(torch::Tensor flag, torch::Tensor seq);
+void seq_add_one(torch::Tensor seq);
 
 void pybind_kernel(py::module &m){
   // StepMesh utils
   m.def("map_pinned_tensor", &map_pinned_tensor, py::arg("tensor"), py::arg("device_index"));
   m.def("write_flag", &write_flag, py::arg("flag"), py::arg("seq"));
   m.def("wait_flag", &wait_flag, py::arg("flag"), py::arg("seq"));
+  m.def("seq_add_one", &seq_add_one, py::arg("seq"));
 }
\ No newline at end of file
diff --git a/fserver/csrc/public.hpp b/fserver/csrc/public.hpp
index a127c6c..a1143e3 100644
--- a/fserver/csrc/public.hpp
+++ b/fserver/csrc/public.hpp
@@ -121,7 +121,8 @@ void respond_vec(torch::Tensor& ret_buffer,
 int push_pull(std::vector<torch::Tensor>& push_tensors,
               std::vector<int64_t>& push_keys,
               std::vector<torch::Tensor>& pull_tensors,
-              std::vector<int64_t>& pull_keys) {
+              std::vector<int64_t>& pull_keys,
+              bool need_event = true) {
   PS_CHECK_EQ(push_tensors.size(), push_keys.size());
   PS_CHECK_EQ(pull_tensors.size(), pull_keys.size());
 
@@ -138,7 +139,7 @@ int push_pull(std::vector<torch::Tensor>& push_tensors,
       static_cast<uint64_t>(pull_keys[i]), std::move(pull_tensors[i].detach()) };
   }
 
-  return fworker_->ZBatchPushPull(push_batch, pull_batch);
+  return fworker_->ZBatchPushPull(push_batch, pull_batch, need_event);
 }
 
 void wait(int handler, uint64_t timeout_ms = 1000) {
@@ -250,6 +251,7 @@ void pybind_public(py::module &m){
         py::arg("push_keys"),
         py::arg("pull_tensors"),
         py::arg("pull_keys"),
+        py::arg("need_event") = true,
         py::call_guard<py::gil_scoped_release>());
   m.def("wait", &wait,
         py::arg("handler"),
diff --git a/fserver/csrc/wait_kernel.cu b/fserver/csrc/wait_kernel.cu
index 20e45ae..2665b86 100644
--- a/fserver/csrc/wait_kernel.cu
+++ b/fserver/csrc/wait_kernel.cu
@@ -5,25 +5,35 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
 
-__global__ void write_flag_kernel(int64_t* flag, int64_t seq) {
+__global__ void write_flag_kernel(int64_t* flag, int64_t* seq) {
+  int64_t seq_value = seq[0];
   if (threadIdx.x == 0) {
-    __threadfence_system();
-    flag[0] = seq;
+    flag[0] = seq_value;
+    // After the store, issue a system-wide fence so the write becomes
+    // visible to all threads and to the CPU.
   }
+  __threadfence_system();
 }
 
-__global__ void wait_flag_kernel(int64_t* flag, int64_t seq) {
+__global__ void wait_flag_kernel(int64_t* flag, int64_t* seq) {
   if (threadIdx.x == 0) {
     // Mark pointer volatile so we reload host-written values each iteration.
-    volatile int64_t* flag_ptr = flag;
-    int64_t value = flag_ptr[0];
-    while (value < seq) {
+    volatile int64_t* flag_ptr = flag, *seq_ptr = seq;
+    int64_t flag_value = flag_ptr[0];
+    int64_t seq_value = seq_ptr[0];
+    while (flag_value < seq_value) {
       __nanosleep(128);
-      value = flag_ptr[0];
+      flag_value = flag_ptr[0];
     }
   }
 }
 
+__global__ void seq_add_one_kernel(int64_t* seq) {
+  if (threadIdx.x == 0) {
+    seq[0]++;
+  }
+  __threadfence_system();
+}
+
 static void check_cuda(cudaError_t err, const char* msg) {
   TORCH_CHECK(err == cudaSuccess, msg, ": ", cudaGetErrorString(err));
 }
@@ -40,16 +50,23 @@ torch::Tensor map_pinned_tensor(torch::Tensor tensor, int64_t device_index) {
   return torch::from_blob(device_ptr, sizes, strides, [](void*){}, options);
 }
 
-void write_flag(torch::Tensor flag, int64_t seq) {
+void write_flag(torch::Tensor flag, torch::Tensor seq) {
   TORCH_CHECK(flag.is_cuda(), "flag must be a CUDA tensor");
   auto stream = at::cuda::getCurrentCUDAStream(flag.device().index());
-  write_flag_kernel<<<1, 1, 0, stream>>>(flag.data_ptr<int64_t>(), seq);
+  write_flag_kernel<<<1, 1, 0, stream>>>(flag.data_ptr<int64_t>(), seq.data_ptr<int64_t>());
   check_cuda(cudaGetLastError(), "write_flag_kernel launch failed");
 }
 
-void wait_flag(torch::Tensor flag, int64_t seq) {
+void wait_flag(torch::Tensor flag, torch::Tensor seq) {
   TORCH_CHECK(flag.is_cuda(), "flag must be a CUDA tensor");
   auto stream = at::cuda::getCurrentCUDAStream(flag.device().index());
-  wait_flag_kernel<<<1, 1, 0, stream>>>(flag.data_ptr<int64_t>(), seq);
+  wait_flag_kernel<<<1, 1, 0, stream>>>(flag.data_ptr<int64_t>(), seq.data_ptr<int64_t>());
   check_cuda(cudaGetLastError(), "wait_flag_kernel launch failed");
 }
+
+void seq_add_one(torch::Tensor seq) {
+  TORCH_CHECK(seq.is_cuda(), "seq must be a CUDA tensor");
+  auto stream = at::cuda::getCurrentCUDAStream(seq.device().index());
+  seq_add_one_kernel<<<1, 1, 0, stream>>>(seq.data_ptr<int64_t>());
+  check_cuda(cudaGetLastError(), "seq_add_one_kernel launch failed");
+}
\ No newline at end of file
diff --git a/include/ps/af_tensor_app.h b/include/ps/af_tensor_app.h
index 1c55d71..25d0341 100644
--- a/include/ps/af_tensor_app.h
+++ b/include/ps/af_tensor_app.h
@@ -94,7 +94,8 @@ class AFTensorWorker {
    * @return An integer indicating the result of the operation.
    */
   int ZBatchPushPull(KeyTensorBatch& push_tensors,
-                     KeyTensorBatch& pull_tensors) {
+                     KeyTensorBatch& pull_tensors,
+                     bool need_event = true) {
     Backend::Get()->SetDevice(gpu_);
     auto server_ranges =
         Postoffice::GetWorker(instance_id_)->GetServerKeyRanges();
@@ -130,8 +131,11 @@
 
     req.push = push_tensors;
     req.pull = pull_tensors;
-    req.event = GetEvent();
-    req.event->Record();
+    req.event = nullptr;
+    if (need_event) {
+      req.event = GetEvent();
+      req.event->Record();
+    } 
 
     PS_VLOG(3) << "ts" << start_ts << " pushpull_queue_ push "
                << pushpull_queue_.Size();
diff --git a/setup.py b/setup.py
index 15c74dd..c27d1e3 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ from pathlib import Path
 
 
 def get_version():
-    version = '0.0.4.post1'
+    version = '0.0.5.post1'
     # with open('stepkv/version.py', 'r') as fd:
     #     version = re.search(r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]',
     #                         fd.read(), re.MULTILINE).group(1)
@@ -62,16 +62,11 @@ def _get_cuda_bare_metal_version(cuda_dir):
     if use_cuda:
         extra_link += ['-lcuda', '-lcudart']
         extra_compile_args['cxx'] += ['-DDMLC_USE_CUDA',]
-        extra_compile_args['nvcc'] = ['-O3', '-gencode', 'arch=compute_80,code=sm_80',
-                                      '--use_fast_math'] + cc_flag
+        extra_compile_args['nvcc'] = ['-O3', '-gencode', 'arch=compute_90,code=sm_90', '-gencode', 'arch=compute_80,code=sm_80', '-gencode', 'arch=compute_89,code=sm_89', '-gencode', 'arch=compute_90a,code=sm_90a',
+                                      '--use_fast_math', f'-D_GLIBCXX_USE_CXX11_ABI={int(torch_cxx11_abi)}'] + cc_flag
         bare_metal_major, bare_metal_minor = \
             _get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
-        if int(bare_metal_major) >= 11:
-            cc_flag.append('-gencode')
-            cc_flag.append('arch=compute_80,code=sm_80')
-            if int(bare_metal_minor) >= 8 or int(bare_metal_major) >= 12:
-                cc_flag.append('-gencode')
-                cc_flag.append('arch=compute_90,code=sm_90')
+
 
 setup(
     name='FServer',
     description='A Remote FFN Server Implementation for AF Disaggregation',
diff --git a/tests/fserver/test_kernel_wait.py b/tests/fserver/test_kernel_wait.py
new file mode 100644
index 0000000..549baa6
--- /dev/null
+++ b/tests/fserver/test_kernel_wait.py
@@ -0,0 +1,210 @@
+
+import threading
+from queue import Queue
+
+import torch, os
+import fserver_lib as f
+import optimus
+import numpy as np
+
+
+def gen_push_key(private_key, microbatch=0, worker_rank=-1):
+    """
+    Generate a key for push tensors based on microbatch, worker rank, and private key.
+    :param private_key: your own key, ranging from 0-255, can be used to identify different tensors
+    :param microbatch: microbatch id
+    :param worker_rank: current worker rank; otherwise it is retrieved from the environment
+    :return: the key for fserver
+    """
+    assert 0 <= private_key < 256, f"illegal private key: {private_key}"
+    if worker_rank == -1:
+        if "DMLC_NODE_RANK" in os.environ:
+            worker_rank = int(os.environ["DMLC_NODE_RANK"])
+        else:
+            worker_rank = 0
+    return private_key + microbatch * (1 << 8) + worker_rank * (1 << 16)
+
+
+def gen_pull_key(private_key, microbatch=0, worker_rank=-1):
+    """
+    Generate a key for pull tensors based on microbatch, worker rank, and private key.
+    :param private_key: your own key, ranging from 0-255, can be used to identify different tensors
+    :param microbatch: microbatch id
+    :param worker_rank: current worker rank; otherwise it is retrieved from the environment
+    :return: the key for fserver
+    """
+    assert 0 <= private_key < 256, f"illegal private key: {private_key}"
+    if worker_rank == -1:
+        if "DMLC_NODE_RANK" in os.environ:
+            worker_rank = int(os.environ["DMLC_NODE_RANK"])
+        else:
+            worker_rank = 0
+    return private_key + microbatch * (1 << 8) + worker_rank * (1 << 16) + (1 << 24)
+
+
+def setup_seed(seed=42):
+    import random
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+    torch.backends.cudnn.deterministic = True
+
+
+is_worker = os.environ.get('DMLC_ROLE') == 'worker'
+is_server = os.environ.get('DMLC_ROLE') == 'server'
+server_count = int(os.environ.get('DMLC_NUM_SERVER', '1'))
+worker_count = int(os.environ.get('DMLC_NUM_WORKER', '1'))
+gpu = os.environ.get('STEPMESH_GPU', '0')
+local_rank_num = int(os.environ.get('DMLC_GROUP_SIZE', '1'))
+node_rank = int(os.environ.get('DMLC_NODE_RANK', '0'))
+rank = node_rank * local_rank_num + int(gpu)
+bsz, num_token, dim = 1, 8, 8
+num_iters = 24
+setup_seed(42)
+torch.cuda.set_device(int(gpu))
+
+f.init()
+
+# prepare initial buffers (each key should have a dedicated buffer)
+
+inp_tensors_buffers = []
+inp_tensors_keys = []
+out_tensors_buffers = []
+out_tensors_keys = []
+
+for mb in range(3):
+    expert_token_cnt_buffer = torch.tensor([num_token for _ in range(bsz)], dtype=torch.int32, device=f'cuda:{gpu}')
+    tokens_buffers = [
+        torch.rand([num_token, dim], dtype=torch.bfloat16, device=f'cuda:{gpu}') for _ in range(bsz)
+    ]
+    tokens_buffers.append(expert_token_cnt_buffer)
+    inp_tensors_buffers.append(tokens_buffers)
+    inp_tensors_keys.append([gen_push_key(i, mb) for i in range(len(tokens_buffers))])
+
+    o_tensors = []
+    for _ in range(server_count):
+        o_tensors += [torch.rand([num_token, dim], dtype=torch.bfloat16, device=f'cuda:{gpu}') for _ in range(bsz)]
+    out_tensors_buffers.append(o_tensors)
+
+    out_tensors_keys.append([gen_pull_key(i, mb) for i in range(len(o_tensors))])
+
+
+print_queue = Queue()
+
+if is_worker:
+    f.barrier(False, True)
+    q = Queue()
+    time_list = []
+    net_cost_list = [[] for _ in range(bsz + 1 + bsz)]
+    idx = 0
+    device = torch.cuda.device(int(gpu))
+    signal_flag_host = torch.zeros(1, dtype=torch.int64, pin_memory=True)
+    ack_flag_host = torch.zeros(1, dtype=torch.int64, pin_memory=True)
+
+    signal_flag_dev = f.map_pinned_tensor(signal_flag_host, int(gpu))
+    ack_flag_dev = f.map_pinned_tensor(ack_flag_host, int(gpu))
+
+    sequence_tensor = torch.zeros(1, dtype=torch.int64, device=f'cuda:{gpu}')
+
+    expected_sequence = 1
+    # The sequence value lives in a tensor because of CUDA graph capture:
+    # a graph cannot use a dynamic Python int, so the graph updates this tensor.
+    profiler = torch.profiler.profile(
+        activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
+        schedule=torch.profiler.schedule(wait=0, warmup=0, active=24),
+        on_trace_ready=torch.profiler.tensorboard_trace_handler(f'./profiler_logs/rank_{rank}', use_gzip=True),
+        record_shapes=True,
+        with_stack=True,
+    )
+
+    thread_stop = threading.Event()
+    def cpu_handle_thread():
+        global expected_sequence, idx, signal_flag_dev, ack_flag_dev, stream
+        while True:
+            signal_value = signal_flag_host.item()
+            if signal_value < expected_sequence:
+                if thread_stop.is_set():
+                    print("Cpu handle thread stop")
+                    break
+                continue
+            handler = f.push_pull(
+                inp_tensors_buffers[idx % 3],
+                inp_tensors_keys[idx % 3],
+                out_tensors_buffers[idx % 3],
+                out_tensors_keys[idx % 3],
+                need_event=False,
+            )
+            f.wait(handler)
+            expected_sequence += 1
+            idx += 1
+            print(f"wait done signal_value:{signal_value} expected_seq:{expected_sequence}")
+            ack_flag_host.fill_(signal_value)
+
+
+    th = threading.Thread(target=cpu_handle_thread)
+
+    print(f"start to run {num_iters}")
+
+    graph = torch.cuda.CUDAGraph()
+
+    ack_flag_host.fill_(num_iters)
+    stream = torch.cuda.Stream()
+    with torch.cuda.stream(stream):
+        graph.capture_begin()
+        for i in range(num_iters):
+            f.seq_add_one(sequence_tensor)  # advance the sequence value on device
+            f.write_flag(signal_flag_dev, sequence_tensor)
+            torch.cuda._sleep(1000)
+            f.wait_flag(ack_flag_dev, sequence_tensor)
+        graph.capture_end()
+    torch.cuda.synchronize()
+
+    def reset_control():
+        global expected_sequence
+        sequence_tensor.fill_(0)
+        ack_flag_host.fill_(0)
+        expected_sequence = 1
+        torch.cuda.synchronize()
+    reset_control()
+
+    th.start()
+    profiler.start()
+    # big graph replay
+    graph.replay()
+    profiler.step()
+
+    # small graph replay
+    # for itr in range(num_iters):
+    #     graph.replay()
+    #     profiler.step()
+    torch.cuda.synchronize()
+    print("Worker stop")
+    profiler.stop()
+    thread_stop.set()
+    th.join()
+
+elif is_server:
+    ret_buffer = torch.rand([65535, dim], dtype=torch.bfloat16, device='cuda')
+    count = 0
+    f.barrier(True, False)
+    def server():
+        global count
+        iter_count = 0
+        while True:
+            batches = f.get_batch()
+            # print(f"Server get batch: {batches}")
+            if len(batches) != 0:
+                iter_count += 1
+                recv_tensor_list = [batches[i][1][0] for i in range(worker_count)]
+                comm_id_list = [batches[i][0] for i in range(worker_count)]
+                # torch.cuda._sleep(10000)
+                f.respond_vec(ret_buffer, recv_tensor_list, comm_id_list)
+                print(f"Server iter: {iter_count}/{num_iters}")
+                if iter_count == num_iters:
+                    break
+    server()
+    torch.cuda.synchronize()
+    print("Server stop")
+
+print("Fserver Stop")
+f.stop()

From 2b7f86c64b8c3380b7624acccee8d2ca79013676 Mon Sep 17 00:00:00 2001
From: niehao
Date: Thu, 6 Nov 2025 18:03:04 +0800
Subject: [PATCH 3/3] Lint Code

---
 include/ps/af_tensor_app.h | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/include/ps/af_tensor_app.h b/include/ps/af_tensor_app.h
index 25d0341..2f71765 100644
--- a/include/ps/af_tensor_app.h
+++ b/include/ps/af_tensor_app.h
@@ -93,8 +93,7 @@ class AFTensorWorker {
    * where the pulled tensors and their associated keys will be stored.
    * @return An integer indicating the result of the operation.
    */
-  int ZBatchPushPull(KeyTensorBatch& push_tensors,
-                     KeyTensorBatch& pull_tensors,
+  int ZBatchPushPull(KeyTensorBatch& push_tensors, KeyTensorBatch& pull_tensors,
                      bool need_event = true) {
     Backend::Get()->SetDevice(gpu_);
     auto server_ranges =
@@ -135,7 +134,7 @@ class AFTensorWorker {
     if (need_event) {
       req.event = GetEvent();
       req.event->Record();
-    } 
+    }
 
     PS_VLOG(3) << "ts" << start_ts << " pushpull_queue_ push "
                << pushpull_queue_.Size();
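
---

Reviewer note, not part of the patches: the handshake the CUDA graph replays above is easier to see without graph capture. Below is a minimal sketch, assuming only the kernel bindings introduced in these patches (map_pinned_tensor, seq_add_one, write_flag, wait_flag) and an extension built with DMLC_USE_CUDA; variable names are the sketch's own, torch.cuda.synchronize() stands in for the polling CPU thread that the real test drives with f.push_pull(), and f.init()/f.stop() are omitted on the assumption that the plain kernel launches need no parameter-server setup:

    import torch
    import fserver_lib as f

    gpu = 0
    torch.cuda.set_device(gpu)

    # Pinned host flags, mapped so the single-thread device kernels can spin on them.
    signal_host = torch.zeros(1, dtype=torch.int64, pin_memory=True)
    ack_host = torch.zeros(1, dtype=torch.int64, pin_memory=True)
    signal_dev = f.map_pinned_tensor(signal_host, gpu)  # device view of pinned memory
    ack_dev = f.map_pinned_tensor(ack_host, gpu)
    seq = torch.zeros(1, dtype=torch.int64, device=f'cuda:{gpu}')

    for i in range(1, 4):
        f.seq_add_one(seq)             # seq[0] += 1, enqueued on the current stream
        f.write_flag(signal_dev, seq)  # device publishes seq[0] into the pinned flag
        torch.cuda.synchronize()       # stand-in for the CPU polling thread
        assert signal_host.item() == i
        # ... CPU-side work (f.push_pull(..., need_event=False) + f.wait) goes here ...
        ack_host.fill_(i)              # CPU acknowledges; the pinned write is GPU-visible
        f.wait_flag(ack_dev, seq)      # wait_flag_kernel spins until ack_host[0] >= seq[0]
    torch.cuda.synchronize()

The ordering matters: because wait_flag_kernel blocks the stream until the host raises the ack past the current sequence number, the graph version can replay all num_iters iterations as one launch while the CPU thread paces it through the two pinned flags.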