10 changes: 9 additions & 1 deletion CMakeLists.txt
@@ -3,7 +3,13 @@ cmake_minimum_required(VERSION 3.22 FATAL_ERROR)
project(af LANGUAGES C CXX)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17 -fPIC -O3 -Wall -finline-functions -msse2 ")
execute_process(COMMAND ${Python_EXECUTABLE}
-c "import torch; print(int(torch.compiled_with_cxx11_abi()))"
OUTPUT_VARIABLE TORCH_CXX11_ABI OUTPUT_STRIP_TRAILING_WHITESPACE)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17 -fPIC -O3 -Wall -finline-functions -msse2 -D_GLIBCXX_USE_CXX11_ABI=${TORCH_CXX11_ABI} ")


# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17 -fPIC -O3 -Wall -finline-functions -msse2 ")

# import pytorch library
find_package (Python COMPONENTS Interpreter Development)
@@ -19,7 +25,9 @@ find_package(Torch REQUIRED CONFIG)
message("MY TORCH_INCLUDE_DIRS ${TORCH_INCLUDE_DIRS}")
message("MY CUDA_INCLUDE_DIRS ${CUDA_INCLUDE_DIRS}")
include_directories(${TORCH_INCLUDE_DIRS})
# Save ABI setting before adding TORCH_CXX_FLAGS (which might override it)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
# Ensure ABI setting is preserved after TORCH_CXX_FLAGS

list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Modules)
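Note on the change above: the ABI probe exists because the extension must be compiled with the same _GLIBCXX_USE_CXX11_ABI value as libtorch, otherwise linking fails or the module dies at import with undefined C++ symbol errors. Also note that ${Python_EXECUTABLE} must already be defined at the point execute_process runs (via an earlier find_package(Python ...) or -DPython_EXECUTABLE=... on the configure command line). A quick manual check of the value the probe captures:

    import torch
    # Prints 1 or 0 -- the value CMake stores in TORCH_CXX11_ABI.
    print(int(torch.compiled_with_cxx11_abi()))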
20 changes: 20 additions & 0 deletions fserver/csrc/kernel.hpp
@@ -0,0 +1,20 @@
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>

#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>


torch::Tensor map_pinned_tensor(torch::Tensor tensor, int64_t device_index);
void write_flag(torch::Tensor flag, torch::Tensor seq);
void wait_flag(torch::Tensor flag, torch::Tensor seq);
void seq_add_one(torch::Tensor seq);

void pybind_kernel(py::module &m){
// StepMesh utils
m.def("map_pinned_tensor", &map_pinned_tensor, py::arg("tensor"), py::arg("device_index"));
m.def("write_flag", &write_flag, py::arg("flag"), py::arg("seq"));
m.def("wait_flag", &wait_flag, py::arg("flag"), py::arg("seq"));
m.def("seq_add_one", &seq_add_one, py::arg("seq"));
}
10 changes: 8 additions & 2 deletions fserver/csrc/ops.cc
@@ -1,12 +1,18 @@
/* Copyright (c) 2025, StepFun Authors. All rights reserved. */


#include "./util.h"
#include "./util.hpp"

#include "./public.hpp"
#include "./private.hpp"
#ifdef DMLC_USE_CUDA
#include "./private.hpp"
#include "./kernel.hpp"
#endif

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
pybind_public(m);
#ifdef DMLC_USE_CUDA
pybind_private(m);
pybind_kernel(m);
#endif
}
2 changes: 1 addition & 1 deletion fserver/csrc/private.hpp
@@ -1,6 +1,6 @@
/* Copyright (c) 2025, StepFun Authors. All rights reserved. */

#include "./util.h"
#include "./util.hpp"
#include "./public.hpp"
#include <future>
#ifdef DMLC_USE_CUDA
10 changes: 6 additions & 4 deletions fserver/csrc/public.hpp
@@ -15,7 +15,7 @@

#include <chrono>

#include "./util.h"
#include "./util.hpp"

#ifndef PUBLIC_OPS_
#define PUBLIC_OPS_
@@ -121,7 +121,8 @@ void respond_vec(torch::Tensor& ret_buffer,
int push_pull(std::vector<torch::Tensor>& push_tensors,
std::vector<uint64_t>& push_keys,
std::vector<torch::Tensor>& pull_tensors,
std::vector<uint64_t>& pull_keys) {
std::vector<uint64_t>& pull_keys,
bool need_event = true) {

PS_CHECK_EQ(push_tensors.size(), push_keys.size());
PS_CHECK_EQ(pull_tensors.size(), pull_keys.size());
@@ -138,7 +139,7 @@ int push_pull(std::vector<torch::Tensor>& push_tensors,
static_cast<uint64_t>(pull_keys[i]), std::move(pull_tensors[i].detach())
};
}
return fworker_->ZBatchPushPull(push_batch, pull_batch);
return fworker_->ZBatchPushPull(push_batch, pull_batch, need_event);
}

void wait(int handler, uint64_t timeout_ms = 1000) {
@@ -237,7 +238,7 @@ uint64_t get_nanosecond() {


void pybind_public(py::module &m){
m.def("init", &init, py::call_guard<py::gil_scoped_release>());
m.def("init", &init, py::call_guard<py::gil_scoped_release>());
m.def("stop", &stop, py::call_guard<py::gil_scoped_release>());

m.def("register_recv_buffer",
@@ -250,6 +251,7 @@ void pybind_public(py::module &m){
py::arg("push_keys"),
py::arg("pull_tensors"),
py::arg("pull_keys"),
py::arg("need_event") = true,
py::call_guard<py::none>());
m.def("wait", &wait,
py::arg("handler"),
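The new need_event argument threads through to ZBatchPushPull and lets a caller skip CUDA-event recording when completion is ordered some other way (e.g., the flag kernels added in wait_kernel.cu below). A minimal sketch of the Python call, assuming the extension imports as fserver_lib (the Extension name in setup.py) and that a server is already running via init(); payloads and keys here are placeholders:

    import torch
    import fserver_lib  # import path is an assumption

    push = [torch.ones(4, device="cuda:0")]
    pull = [torch.empty(4, device="cuda:0")]

    # need_event=False skips GetEvent()/Record() on the worker side;
    # the caller is then responsible for ordering completion.
    h = fserver_lib.push_pull(push, [0], pull, [1], need_event=False)
    fserver_lib.wait(h)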
File renamed without changes (fserver/csrc/util.h → fserver/csrc/util.hpp).
72 changes: 72 additions & 0 deletions fserver/csrc/wait_kernel.cu
@@ -0,0 +1,72 @@
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>

#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>

__global__ void write_flag_kernel(int64_t* flag, int64_t* seq) {
int64_t seq_value = seq[0];
if (threadIdx.x == 0) {
flag[0] = seq_value;
// After the write, issue a system-wide fence so it becomes visible to all threads and to the CPU.
}
__threadfence_system();
}

__global__ void wait_flag_kernel(int64_t* flag, int64_t* seq) {
if (threadIdx.x == 0) {
// Mark pointer volatile so we reload host-written values each iteration.
volatile int64_t* flag_ptr = flag, *seq_ptr = seq;
int64_t flag_value = flag_ptr[0];
int64_t seq_value = seq_ptr[0];
while (flag_value < seq_value) {
__nanosleep(128);
flag_value = flag_ptr[0];
}
}
}

__global__ void seq_add_one_kernel(int64_t* seq) {
if (threadIdx.x == 0) {
seq[0]++;
}
__threadfence_system();
}

static void check_cuda(cudaError_t err, const char* msg) {
TORCH_CHECK(err == cudaSuccess, msg, ": ", cudaGetErrorString(err));
}

torch::Tensor map_pinned_tensor(torch::Tensor tensor, int64_t device_index) {
TORCH_CHECK(tensor.is_pinned(), "tensor must be pinned");
void* host_ptr = tensor.data_ptr();
void* device_ptr = nullptr;
check_cuda(cudaHostGetDevicePointer(&device_ptr, host_ptr, 0),
"cudaHostGetDevicePointer failed");
auto options = tensor.options().device(torch::kCUDA, device_index);
auto sizes = tensor.sizes();
auto strides = tensor.strides();
return torch::from_blob(device_ptr, sizes, strides, [](void*){}, options);
}

void write_flag(torch::Tensor flag, torch::Tensor seq) {
TORCH_CHECK(flag.is_cuda(), "flag must be a CUDA tensor");
auto stream = at::cuda::getCurrentCUDAStream(flag.device().index());
write_flag_kernel<<<1, 1, 0, stream>>>(flag.data_ptr<int64_t>(), seq.data_ptr<int64_t>());
check_cuda(cudaGetLastError(), "write_flag_kernel launch failed");
}

void wait_flag(torch::Tensor flag, torch::Tensor seq) {
TORCH_CHECK(flag.is_cuda(), "flag must be a CUDA tensor");
auto stream = at::cuda::getCurrentCUDAStream(flag.device().index());
wait_flag_kernel<<<1, 1, 0, stream>>>(flag.data_ptr<int64_t>(), seq.data_ptr<int64_t>());
check_cuda(cudaGetLastError(), "wait_flag_kernel launch failed");
}

void seq_add_one(torch::Tensor seq) {
TORCH_CHECK(seq.is_cuda(), "seq must be a CUDA tensor");
auto stream = at::cuda::getCurrentCUDAStream(seq.device().index());
seq_add_one_kernel<<<1, 1, 0, stream>>>(seq.data_ptr<int64_t>());
check_cuda(cudaGetLastError(), "seq_add_one_kernel launch failed");
}
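Together these kernels form a CPU-GPU handshake over zero-copy pinned memory: seq holds the expected sequence number, wait_flag parks the current stream until flag catches up, and write_flag/seq_add_one publish progress from the GPU side. A usage sketch, assuming the module imports as fserver_lib and device 0 -- an untested illustration of the intended protocol, not code from this PR:

    import torch
    import fserver_lib  # import path is an assumption

    flag_host = torch.zeros(1, dtype=torch.int64).pin_memory()
    flag_dev = fserver_lib.map_pinned_tensor(flag_host, 0)  # GPU view of the pinned buffer
    seq = torch.zeros(1, dtype=torch.int64, device="cuda:0")

    fserver_lib.seq_add_one(seq)          # GPU: seq becomes 1 (async, current stream)
    fserver_lib.wait_flag(flag_dev, seq)  # GPU: stream spins until flag >= seq
    y = torch.ones(8, device="cuda:0") * 2  # enqueued behind the wait

    flag_host[0] = 1                      # CPU: release the spinning kernel
    torch.cuda.synchronize()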
11 changes: 7 additions & 4 deletions include/ps/af_tensor_app.h
@@ -93,8 +93,8 @@ class AFTensorWorker {
* where the pulled tensors and their associated keys will be stored.
* @return An integer indicating the result of the operation.
*/
int ZBatchPushPull(KeyTensorBatch& push_tensors,
KeyTensorBatch& pull_tensors) {
int ZBatchPushPull(KeyTensorBatch& push_tensors, KeyTensorBatch& pull_tensors,
bool need_event = true) {
Backend::Get()->SetDevice(gpu_);
auto server_ranges =
Postoffice::GetWorker(instance_id_)->GetServerKeyRanges();
@@ -130,8 +130,11 @@

req.push = push_tensors;
req.pull = pull_tensors;
req.event = GetEvent();
req.event->Record();
req.event = nullptr;
if (need_event) {
req.event = GetEvent();
req.event->Record();
}

PS_VLOG(3) << "ts" << start_ts << " pushpull_queue_ push "
<< pushpull_queue_.Size();
14 changes: 5 additions & 9 deletions setup.py
@@ -10,7 +10,7 @@
from pathlib import Path

def get_version():
version = '0.0.4.post1'
version = '0.0.5.post1'
# with open('stepkv/version.py', 'r') as fd:
# version = re.search(r'^__version__\s*=\s*[\'"]([^\'"]*)[\'"]',
# fd.read(), re.MULTILINE).group(1)
@@ -62,16 +62,11 @@ def _get_cuda_bare_metal_version(cuda_dir):
if use_cuda:
extra_link += ['-lcuda', '-lcudart']
extra_compile_args['cxx'] += ['-DDMLC_USE_CUDA',]
extra_compile_args['nvcc'] = ['-O3', '-gencode', 'arch=compute_70,code=sm_70',
'--use_fast_math'] + cc_flag
extra_compile_args['nvcc'] = ['-O3', '-gencode', 'arch=compute_90,code=sm_90', '-gencode', 'arch=compute_80,code=sm_80', '-gencode', 'arch=compute_89,code=sm_89','-gencode', 'arch=compute_90a,code=sm_90a',
'--use_fast_math', f'-D_GLIBCXX_USE_CXX11_ABI={str(int(torch_cxx11_abi))}'] + cc_flag
bare_metal_major, bare_metal_minor = \
_get_cuda_bare_metal_version(cpp_extension.CUDA_HOME)
if int(bare_metal_major) >= 11:
cc_flag.append('-gencode')
cc_flag.append('arch=compute_80,code=sm_80')
if int(bare_metal_minor) >= 8 or int(bare_metal_major) >= 12:
cc_flag.append('-gencode')
cc_flag.append('arch=compute_90,code=sm_90')

setup(
name='FServer',
description='A Remote FFN Server Implementation for AF Disaggregation',
@@ -84,6 +79,7 @@ def _get_cuda_bare_metal_version(cuda_dir):
'fserver_lib',
[
__SRC_PATH__ + 'ops.cc',
__SRC_PATH__ + 'wait_kernel.cu',
],
extra_compile_args=extra_compile_args,
extra_link_args=extra_link,