fserver/csrc/public.hpp (4 changes: 2 additions & 2 deletions)

@@ -142,7 +142,7 @@ int push_pull(std::vector<torch::Tensor>& push_tensors,
}

void wait(int handler, uint64_t timeout_ms = 1000) {
-fworker_->Wait(handler, timeout_ms);
+fworker_->Wait(handler, timeout_ms);
}

void barrier(bool include_server, bool include_worker, bool instrance_barrier=true) {
@@ -254,7 +254,7 @@ void pybind_public(py::module &m){
m.def("wait", &wait,
py::arg("handler"),
py::arg("timeout_ms") = 10000,
-py::call_guard<py::none>());
+py::call_guard<py::gil_scoped_release>());

// APIs for FFN Instances
m.def("get_batch", &get_batch, py::call_guard<py::none>());
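Side note on the call-guard fix above: py::call_guard<py::none> does not release the GIL, so the bound wait() held the Python GIL for the entire blocking call (up to the 10 s default timeout), stalling every other Python thread. py::call_guard<py::gil_scoped_release> releases the GIL before entering the C++ function and re-acquires it on return. A minimal, self-contained sketch of the pattern, where slow_wait is a hypothetical stand-in for fworker_->Wait():

// Sketch: releasing the GIL around a blocking pybind11 binding.
// slow_wait is a hypothetical stand-in for fworker_->Wait(); build this as a
// pybind11 module named gil_demo.
#include <pybind11/pybind11.h>
#include <chrono>
#include <thread>
namespace py = pybind11;

void slow_wait(int handler, uint64_t timeout_ms) {
  (void)handler;  // a real worker would block until this request completes
  std::this_thread::sleep_for(std::chrono::milliseconds(timeout_ms));
}

PYBIND11_MODULE(gil_demo, m) {
  // gil_scoped_release drops the GIL for the duration of the C++ call and
  // re-acquires it on return, so other Python threads keep running.
  m.def("wait", &slow_wait,
        py::arg("handler"), py::arg("timeout_ms") = 10000,
        py::call_guard<py::gil_scoped_release>());
}
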
include/ps/af_tensor_app.h (32 changes: 19 additions & 13 deletions)

@@ -15,11 +15,11 @@
#include <iostream>
#include <sstream>
#include <tuple>
-#include <unordered_map>
#include <utility>
#include <vector>

#include "ps/base.h"
+#include "ps/hash_table8.hpp"
#include "ps/internal/backend.h"
#include "ps/internal/utils.h"
#include "ps/kv_app.h"
@@ -138,7 +138,7 @@ class AFTensorWorker {
pushpull_queue_.Push(std::move(req));

// std::unique_lock<std::mutex> timestamp_lock(timestamp_mu_);
-batch_timestamps_[start_ts] = std::move(timestamps);
+batch_timestamps_.emplace_unique(start_ts, std::move(timestamps));
return start_ts;
}

@@ -149,12 +149,12 @@
void Wait(int timestamp, uint64_t timeout_ms = 10000) {
kv_.Wait(timestamp, timeout_ms);
// std::unique_lock<std::mutex> lock(timestamp_mu_);
-auto itr = batch_timestamps_.find(timestamp);
-if (itr != batch_timestamps_.end()) {
-  for (auto ts : itr->second) {
+auto v = batch_timestamps_.try_get(timestamp);
+if (v) {
+  for (auto ts : *v) {
kv_.Wait(ts, timeout_ms);
}
-  batch_timestamps_.erase(itr);
+  batch_timestamps_.erase(timestamp);
}
}

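The hunks above and below both move from the std::unordered_map iterator API (find/end/erase(itr)) to emhash8::HashMap from the emhash library (hash_table8.hpp). Going by the calls this PR makes, try_get(key) returns a pointer to the mapped value (null when absent) and emplace_unique(key, value) inserts on the assumption that the key is not already present. A sketch of the pattern, under those assumptions:

// Sketch of the lookup/erase pattern adopted in this PR; assumes the emhash8
// API as used here: try_get() -> value pointer (nullptr if absent), erase() by key.
#include <vector>
#include "ps/hash_table8.hpp"  // emhash8::HashMap

emhash8::HashMap<int, std::vector<int>> batch_timestamps;

void wait_batch(int key) {
  batch_timestamps.emplace_unique(key, std::vector<int>{1, 2, 3});
  if (auto* v = batch_timestamps.try_get(key)) {  // pointer, not iterator
    for (int ts : *v) {
      // ... kv_.Wait(ts, timeout_ms) in the real code ...
    }
    batch_timestamps.erase(key);  // erase by key; no iterator to keep valid
  }
}
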
@@ -166,12 +166,18 @@
std::vector<int> handlers;
handlers.push_back(timestamp);
std::unique_lock<std::mutex> lock(timestamp_mu_);
-auto itr = batch_timestamps_.find(timestamp);
-if (itr != batch_timestamps_.end()) {
-  for (auto ts : itr->second) {
+auto v = batch_timestamps_.try_get(timestamp);
+if (v) {
+  for (auto ts : *v) {
handlers.push_back(ts);
}
}
+// auto itr = batch_timestamps_.find(timestamp);
+// if (itr != batch_timestamps_.end()) {
+//   for (auto ts : itr->second) {
+//     handlers.push_back(ts);
+//   }
+// }
return handlers;
}

@@ -201,7 +207,7 @@ class AFTensorWorker {
}

void PushPullWorker() {
-BindCpuCore(3, 1);
+BindCpuCore(4, 1);
Backend::Get()->SetDevice(gpu_);
while (true) {
PS_VLOG(4) << "pushpull_queue_ Loop wait ";
@@ -288,7 +294,7 @@ class AFTensorWorker {
msg.meta.dtype = static_cast<int>(tensor.scalar_type());
msg.meta.shape.clear();
for (int64_t s = 0; s < tensor.dim(); s++) {
-msg.meta.shape.push_back(tensor.size(i));
+msg.meta.shape.push_back(tensor.size(s));
}
msg.data.clear();
msg.AddData(key);
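The size(i) to size(s) change in the hunk above fixes a stale index: the loop variable is s, so tensor.size(i) was reading whichever dimension i named in the enclosing scope on every iteration, filling msg.meta.shape with the wrong shape for the receiving side.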
@@ -360,7 +366,7 @@ class AFTensorWorker {
/** \brief API mutex */
mutable std::mutex mu_;
/** \brief record timestamps for each batch */
-std::unordered_map<int, std::vector<int>> batch_timestamps_;
+emhash8::HashMap<int, std::vector<int>> batch_timestamps_;
/** \brief mutex for record timestamps */
std::mutex timestamp_mu_;
/** \brief tensor events */
@@ -590,7 +596,7 @@ class AFTensorServer {
}

void ResponseWorker() {
-BindCpuCore(3, 1);
+BindCpuCore(1, 1);
Backend::Get()->SetDevice(gpu_);
PS_LOG(INFO) << "Start ResponseWorker " << gpu_;
while (!response_stop_.load()) {
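On the two BindCpuCore changes: before this PR both PushPullWorker and AFTensorServer::ResponseWorker pinned themselves with BindCpuCore(3, 1); afterwards they bind cores 4 and 1 respectively, presumably so the two hot loops stop contending for the same core. BindCpuCore's body is not part of this diff; a typical Linux helper that pins the calling thread to a single core looks like the sketch below (the name pin_to_core and the meaning of the arguments are assumptions):

// Sketch: pin the calling thread to one CPU core on Linux (glibc).
// Compile with -D_GNU_SOURCE (or define it before any include).
#include <pthread.h>
#include <sched.h>

bool pin_to_core(int core) {
  cpu_set_t set;
  CPU_ZERO(&set);
  CPU_SET(core, &set);
  // Restricts the calling thread's affinity mask to `core`; returns 0 on success.
  return pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &set) == 0;
}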