Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 64 additions & 18 deletions tensorflow_networking/verbs/BUILD
Original file line number Diff line number Diff line change
@@ -1,13 +1,41 @@
# Description:
# Verbs RDMA communication interfaces and implementations for TensorFlow.

package(default_visibility = [
"//tensorflow_networking:__subpackages__",
])
load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_cuda_library")

licenses(["notice"]) # Apache 2.0
# For platform specific build config
load(
"@org_tensorflow//tensorflow/core/platform:default/build_config.bzl",
"tf_proto_library_cc",
)

load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_cuda_library")
load(
"@org_tensorflow//tensorflow:tensorflow.bzl",
"tf_cc_binary",
"tf_cc_test",
"tf_cuda_library",
)

load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
load("@org_tensorflow//tensorflow:tensorflow.bzl", "tf_cuda_cc_tests")

# For platform specific build config
load(
"@org_tensorflow//tensorflow/core/platform:default/build_config.bzl",
"tf_kernel_tests_linkstatic",
)

load(
"@org_tensorflow//tensorflow/core/platform:default/build_config_root.bzl",
"tf_cuda_tests_tags",
)

package(
default_visibility = [
"//tensorflow_networking:__subpackages__",
],
licenses = ["notice"], # Apache 2.0
)

exports_files(["LICENSE"])

Expand All @@ -19,12 +47,6 @@ filegroup(
]),
)

# For platform specific build config
load(
"@org_tensorflow//tensorflow/core:platform/default/build_config.bzl",
"tf_proto_library_cc",
)

tf_proto_library_cc(
name = "verbs_service_proto",
srcs = ["verbs_service.proto"],
Expand All @@ -43,6 +65,10 @@ cc_library(
"@org_tensorflow//tensorflow/core:framework",
"@org_tensorflow//tensorflow/core:lib",
],
linkopts = select({
"@org_tensorflow//tensorflow:with_verbs_support": ["-libverbs"],
"//conditions:default": [],
}),
)

cc_library(
Expand All @@ -52,9 +78,10 @@ cc_library(
deps = [
":grpc_verbs_service_impl",
":rdma_mgr",
":rdma",
":verbs_service_proto_cc",
"@org_tensorflow//tensorflow:grpc++",
"@org_tensorflow//tensorflow/core:lib",
#"@org_tensorflow//tensorflow/core:lib_internal",
"@org_tensorflow//tensorflow/core/distributed_runtime:session_mgr",
"@org_tensorflow//tensorflow/core/distributed_runtime/rpc:async_service_interface",
"@org_tensorflow//tensorflow/core/distributed_runtime/rpc:grpc_call",
Expand All @@ -77,6 +104,7 @@ cc_library(
name = "grpc_verbs_client",
srcs = ["grpc_verbs_client.cc"],
hdrs = ["grpc_verbs_client.h"],
copts = ["-Og", "-g3"],
deps = [
":grpc_verbs_service_impl",
":verbs_service_proto_cc",
Expand All @@ -90,49 +118,66 @@ cc_library(
cc_library(
name = "rdma_rendezvous_mgr",
srcs = ["rdma_rendezvous_mgr.cc"],
hdrs = ["rdma_rendezvous_mgr.h"],
hdrs = ["rdma_rendezvous_mgr.h", "rdma.h"],
copts = ["-Og", "-g3"],
deps = [
":rdma_mgr",
":verbs_util",
"@org_tensorflow//tensorflow/core",
#"@org_tensorflow//tensorflow/core:core_cpu_internal",
#"@org_tensorflow//tensorflow/core:gpu_runtime",
"@org_tensorflow//tensorflow/core:lib",
"@org_tensorflow//tensorflow/core/distributed_runtime:base_rendezvous_mgr",
"@org_tensorflow//tensorflow/core/distributed_runtime:worker_env",
#"@org_tensorflow//tensorflow/core/distributed_runtime:worker_cache_partial",
],
)

tf_cuda_library(
name = "rdma_mgr",
srcs = ["rdma_mgr.cc"],
hdrs = ["rdma_mgr.h"],
hdrs = ["rdma_mgr.h", "rdma.h"],
copts = ["-Og", "-g3"],
deps = [
":grpc_verbs_client",
":rdma",
#":rdma",
":verbs_util",
":verbs_service_proto_cc",
"@org_tensorflow//tensorflow/core",
#"@org_tensorflow//tensorflow/core:core_cpu_internal",
"@org_tensorflow//tensorflow/core:lib",
#"@org_tensorflow//tensorflow/core:lib_internal",
"@org_tensorflow//tensorflow/core/distributed_runtime:session_mgr",
"@org_tensorflow//tensorflow/core/distributed_runtime:worker_env",
"@org_tensorflow//tensorflow/core/distributed_runtime/rpc:grpc_channel",
"@org_tensorflow//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
],
)


tf_cuda_library(
name = "rdma",
srcs = ["rdma.cc"],
hdrs = ["rdma.h"],
linkopts = ["-libverbs"],
linkopts = select({
"@org_tensorflow//tensorflow:with_verbs_support": ["-libverbs"],
"//conditions:default": [],
}),
copts = ["-Og", "-g3"],
deps = [
":rdma_mgr",
":grpc_verbs_client",
":verbs_service_proto_cc",
":verbs_util",
"@org_tensorflow//tensorflow/core",
#"@org_tensorflow//tensorflow/core:core_cpu_internal",
"@org_tensorflow//tensorflow/core:framework",
#"@org_tensorflow//tensorflow/core:framework_internal",
#"@org_tensorflow//tensorflow/core:gpu_runtime",
"@org_tensorflow//tensorflow/core:lib",
#"@org_tensorflow//tensorflow/core:lib_internal",
"@org_tensorflow//tensorflow/core/distributed_runtime:rendezvous_mgr_interface",
"@org_tensorflow//tensorflow/core/distributed_runtime:session_mgr",
"@org_tensorflow//tensorflow/core/distributed_runtime:worker_env",
"@org_tensorflow//tensorflow/core/distributed_runtime/rpc:grpc_channel",
],
)

Expand All @@ -151,3 +196,4 @@ cc_library(
],
alwayslink = 1,
)

82 changes: 0 additions & 82 deletions tensorflow_networking/verbs/Dockerfile

This file was deleted.

13 changes: 0 additions & 13 deletions tensorflow_networking/verbs/docker_howto.txt

This file was deleted.

33 changes: 33 additions & 0 deletions tensorflow_networking/verbs/grpc_verbs_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,39 @@ Status GrpcVerbsClient::GetRemoteAddress(const GetRemoteAddressRequest* request,
return GetRemoteAddress(&call_options, request, response);
}


Status GrpcVerbsClient::ReqDriverMessage(CallOptions* call_options,
const DriverMessageReq* request,
DriverMessageResp* response) {
::grpc::ClientContext ctx;
ctx.set_fail_fast(false);
SetDeadline(&ctx, call_options->GetTimeout());
return FromGrpcStatus(stub_->ReqDriverMessage(&ctx, *request, response));
}

Status GrpcVerbsClient::ReqDriverMessage(const DriverMessageReq* request,
DriverMessageResp* response) {
CallOptions call_options;
call_options.SetTimeout(-1); // no time out
return ReqDriverMessage(&call_options, request, response);
}

Status GrpcVerbsClient::ReqPleSendOrCheck(CallOptions* call_options,
const PleSendOrCheckReq* request,
PleSendOrCheckResp* response) {
::grpc::ClientContext ctx;
ctx.set_fail_fast(false);
SetDeadline(&ctx, call_options->GetTimeout());
return FromGrpcStatus(stub_->ReqPleSendOrCheck(&ctx, *request, response));
}

Status GrpcVerbsClient::ReqPleSendOrCheck(const PleSendOrCheckReq* request,
PleSendOrCheckResp* response) {
CallOptions call_options;
call_options.SetTimeout(-1); // no time out
return ReqPleSendOrCheck(&call_options, request, response);
}

void GrpcVerbsClient::SetDeadline(::grpc::ClientContext* ctx,
int64 time_in_ms) {
if (time_in_ms > 0) {
Expand Down
17 changes: 15 additions & 2 deletions tensorflow_networking/verbs/grpc_verbs_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_CLIENT_H_
#define TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_CLIENT_H_

#include "tensorflow_networking/verbs/grpc_verbs_service_impl.h"
#include "tensorflow_networking/verbs/verbs_service.pb.h"
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow_networking/verbs/grpc_verbs_service_impl.h"
#include "tensorflow_networking/verbs/verbs_service.pb.h"

namespace tensorflow {

Expand All @@ -37,6 +37,19 @@ class GrpcVerbsClient {
Status GetRemoteAddress(const GetRemoteAddressRequest* request,
GetRemoteAddressResponse* response);

Status ReqDriverMessage(CallOptions* call_options,
const DriverMessageReq* request,
DriverMessageResp* response);
Status ReqDriverMessage(const DriverMessageReq* request,
DriverMessageResp* response);

Status ReqPleSendOrCheck(CallOptions* call_options,
const PleSendOrCheckReq* request,
PleSendOrCheckResp* response);

Status ReqPleSendOrCheck(const PleSendOrCheckReq* request,
PleSendOrCheckResp* response);

private:
std::unique_ptr<grpc::VerbsService::Stub> stub_;

Expand Down
Loading