Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions sw/nic/gpuagent/cli/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ import (
)

var (
svcURL string
svcPort string
svcURL string
svcPort string
svcSocketPath string
useUnixSocket bool
)

// RootCmd represents the base command when called without any subcommands
Expand All @@ -59,6 +61,9 @@ func init() {
RootCmd.PersistentFlags().StringVar(&svcPort, "node-svc-port",
utils.GRPCDefaultPort,
"Remote node's service port")
RootCmd.PersistentFlags().StringVar(&svcSocketPath, "node-svc-socket",
utils.GRPCDefaultSocketPath,
"Unix socket path for GPU agent connection")
}

// NewGpuctlCommand exports the RootCmd for bash-completion
Expand All @@ -68,6 +73,19 @@ func NewGpuctlCommand() *cobra.Command {

// initConfig publishes the connection settings held in the flag-bound
// package variables (svcURL, svcPort, svcSocketPath) to the utils package.
//
// Transport priority: --node-svc-socket > auto-detected default socket >
// TCP/IP. The Unix socket is selected only when the candidate path exists
// and really is a socket; otherwise utils.GRPCDefaultSocketPath is cleared,
// which signals that the TCP/IP endpoint should be used.
func initConfig() {
	// The TCP/IP endpoint is always recorded; it is harmless when a Unix
	// socket ends up being selected.
	utils.GRPCDefaultBaseURL = svcURL
	utils.GRPCDefaultPort = svcPort

	// svcSocketPath is either the value given via --node-svc-socket or the
	// flag's default, so it is always the candidate path to probe.
	utils.GRPCDefaultSocketPath = svcSocketPath

	// Existence alone is not enough: a regular file or directory at this
	// path would otherwise force a Unix-socket dial that can never succeed
	// instead of falling back to TCP/IP.
	info, err := os.Stat(svcSocketPath)
	if err != nil || info.Mode()&os.ModeSocket == 0 {
		utils.GRPCDefaultSocketPath = ""
	}
}
29 changes: 19 additions & 10 deletions sw/nic/gpuagent/cli/utils/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ import (
)

var (
GRPCDefaultBaseURL = "127.0.0.1"
GRPCDefaultPort = "50061"
GRPCDefaultBaseURL = "127.0.0.1"
GRPCDefaultPort = "50061"
GRPCDefaultSocketPath = "/var/run/gpuagent.sock"
)

const (
Expand Down Expand Up @@ -66,16 +67,24 @@ func getClientReqTimeout() (uint, error) {
return uint(timeout), nil
}

// createNewGRPCClient creates a grpc connection to HAL
// we first check if secure grpc exists and if not fallback
// to regular grpc
// createNewGRPCClient creates a grpc connection to GPU agent
// supports both TCP/IP and Unix socket connections
func createNewGRPCClient() (*grpc.ClientConn, error) {
// unsecure grpc
agaPort := os.Getenv("AGA_GRPC_PORT")
if agaPort == "" {
agaPort = GRPCDefaultPort
var srvURL string

// check if Unix socket path is specified
if GRPCDefaultSocketPath != "" {
// use Unix socket
srvURL = "unix:" + GRPCDefaultSocketPath
} else {
// use TCP/IP
agaPort := os.Getenv("AGA_GRPC_PORT")
if agaPort == "" {
agaPort = GRPCDefaultPort
}
srvURL = GRPCDefaultBaseURL + ":" + agaPort
}
srvURL := GRPCDefaultBaseURL + ":" + agaPort

timeout, err := getClientPortConnTimeout()
if err != nil {
return nil, err
Expand Down
144 changes: 141 additions & 3 deletions sw/nic/gpuagent/init.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ limitations under the License.
//----------------------------------------------------------------------------

#include <memory>
#include <cerrno>
#include <cstring>
#include <sys/stat.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <unistd.h>
#include <grpc++/grpc++.h>
#include "nic/sdk/include/sdk/base.hpp"
#include "nic/sdk/lib/logger/logger.h"
Expand Down Expand Up @@ -179,15 +185,112 @@ create_gpus (void)
return SDK_RET_OK;
}

/// \brief prepare Unix socket for gRPC server
///
/// Validates the socket path, refuses to proceed when another process is
/// already accepting connections on it, removes a stale socket file left
/// behind by a previous run and makes sure the parent directory exists.
///
/// \param[in] socket_path Unix socket file path
/// \return SDK_RET_OK on success, SDK_RET_ERR on failure
static sdk_ret_t
prepare_unix_socket (const std::string& socket_path)
{
    struct stat st;
    struct sockaddr_un addr;

    // the path must fit in sockaddr_un (including the terminating NUL);
    // a silently truncated copy would make the liveness probe below hit a
    // different path than the one we are about to listen on
    if (socket_path.empty() || (socket_path.size() >= sizeof(addr.sun_path))) {
        AGA_TRACE_ERR("Socket path {} is empty or too long",
                      socket_path.c_str());
        fprintf(stderr, "Error: Socket path %s is empty or too long "
                "(max %zu characters)\n", socket_path.c_str(),
                sizeof(addr.sun_path) - 1);
        return SDK_RET_ERR;
    }

    // check if socket file exists
    if (stat(socket_path.c_str(), &st) == 0) {
        if (!S_ISSOCK(st.st_mode)) {
            // file exists but is not a socket - never delete it
            AGA_TRACE_ERR("Path {} exists but is not a socket",
                          socket_path.c_str());
            fprintf(stderr, "Error: Path %s exists but is not a socket\n",
                    socket_path.c_str());
            return SDK_RET_ERR;
        }

        // socket file exists - try to connect to see if another instance
        // is running
        int test_sock = socket(AF_UNIX, SOCK_STREAM, 0);
        if (test_sock >= 0) {
            memset(&addr, 0, sizeof(addr));
            addr.sun_family = AF_UNIX;
            // length was validated above, so this copy never truncates
            strncpy(addr.sun_path, socket_path.c_str(),
                    sizeof(addr.sun_path) - 1);

            if (connect(test_sock, (struct sockaddr *)&addr,
                        sizeof(addr)) == 0) {
                // another instance is running
                close(test_sock);
                AGA_TRACE_ERR("Another instance is already running on "
                              "socket {}", socket_path.c_str());
                fprintf(stderr, "Error: Another GPU agent instance is "
                        "already running on socket %s\n",
                        socket_path.c_str());
                return SDK_RET_ERR;
            }
            close(test_sock);
        }

        // socket exists but not in use - clean it up; ENOENT means the file
        // disappeared in the meantime, which is the state we want anyway
        AGA_TRACE_INFO("Cleaning up stale socket file {}",
                       socket_path.c_str());
        if ((unlink(socket_path.c_str()) != 0) && (errno != ENOENT)) {
            AGA_TRACE_ERR("Failed to remove stale socket file {}: {}",
                          socket_path.c_str(), strerror(errno));
            fprintf(stderr, "Error: Failed to remove stale socket "
                    "file %s: %s\n", socket_path.c_str(),
                    strerror(errno));
            return SDK_RET_ERR;
        }
    }

    // ensure parent directory exists; nothing to do when the path has no
    // directory component or when the socket lives directly under "/"
    // (last_slash == 0 would otherwise yield an empty directory name)
    size_t last_slash = socket_path.find_last_of('/');
    if ((last_slash != std::string::npos) && (last_slash != 0)) {
        std::string dir_path = socket_path.substr(0, last_slash);
        struct stat dir_st;

        if (stat(dir_path.c_str(), &dir_st) != 0) {
            // directory doesn't exist - create it (single level only)
            AGA_TRACE_INFO("Creating directory {} for socket",
                           dir_path.c_str());
            if ((mkdir(dir_path.c_str(), 0755) != 0) && (errno != EEXIST)) {
                AGA_TRACE_ERR("Failed to create directory {}: {}",
                              dir_path.c_str(), strerror(errno));
                fprintf(stderr, "Error: Failed to create directory %s: %s\n",
                        dir_path.c_str(), strerror(errno));
                return SDK_RET_ERR;
            }
        } else if (!S_ISDIR(dir_st.st_mode)) {
            AGA_TRACE_ERR("Path {} exists but is not a directory",
                          dir_path.c_str());
            fprintf(stderr, "Error: Path %s exists but is not a directory\n",
                    dir_path.c_str());
            return SDK_RET_ERR;
        }
    }

    return SDK_RET_OK;
}

/// \brief remove the Unix socket file, if present
///
/// Best-effort cleanup: a failure is only logged and never propagated to
/// the caller. A file that does not exist is not treated as a failure,
/// since this may be invoked for a socket that was never created.
///
/// \param[in] socket_path Unix socket file path
static void
clean_unix_socket (const std::string& socket_path)
{
    // ENOENT means there is nothing to clean up - stay silent in that case
    if ((unlink(socket_path.c_str()) != 0) && (errno != ENOENT)) {
        AGA_TRACE_WARN("Failed to remove socket file {} on exit: {}",
                       socket_path.c_str(), strerror(errno));
    }
}

/// \brief start the gRPC server
/// \param[in] grpc_server gRPC server (IP:port) string
/// \param[in] grpc_server gRPC server (IP:port or unix:socket_path) string
/// \param[in] grpc_server_type gRPC server type (TCP or Unix socket)
static void
grpc_server_start (const std::string& grpc_server)
grpc_server_start (const std::string& grpc_server,
aga_grpc_server_type_t grpc_server_type)
{
GPUSvcImpl gpu_svc;
TopoSvcImpl topo_svc;
DebugSvcImpl debug_svc;
EventSvcImpl event_svc;
std::string socket_path;
ServerBuilder server_builder;
DebugGPUSvcImpl debug_gpu_svc;
GPUWatchSvcImpl gpu_watch_svc;
Expand All @@ -209,6 +312,20 @@ grpc_server_start (const std::string& grpc_server)
1);
// send continuous keepalive messages as long as channel is open
server_builder.AddChannelArgument(GRPC_ARG_HTTP2_MAX_PINGS_WITHOUT_DATA, 0);

// handle Unix socket specific setup
if (grpc_server_type == AGA_GRPC_SERVER_TYPE_UNIX) {
// extract socket path from "unix:path" format
socket_path = grpc_server.substr(5); // skip "unix:"
// prepare Unix socket (check for existing instances, cleanup, etc.)
if (prepare_unix_socket(socket_path) != SDK_RET_OK) {
exit(1);
}
} else {
// cleanup any existing Unix socket on TCP initialization
clean_unix_socket(AGA_DEFAULT_UNIX_SOCKET_PATH);
}

server_builder.AddListeningPort(grpc_server,
grpc::InsecureServerCredentials());
// restrict max. no. of gRPC threads that can be spawned & active at any
Expand All @@ -227,7 +344,28 @@ grpc_server_start (const std::string& grpc_server)
AGA_TRACE_DEBUG("gRPC server listening on {} ...",
grpc_server.c_str());
g_grpc_server = server_builder.BuildAndStart();

if (grpc_server_type == AGA_GRPC_SERVER_TYPE_UNIX)
{
// set socket permissions to 0600 (rw-------) so only root can access
if (chmod(socket_path.c_str(), 0600) != 0)
{
AGA_TRACE_ERR("Failed to set permissions on socket {}: {}",
socket_path.c_str(), strerror(errno));
fprintf(stderr, "Error: Failed to set permissions on socket %s: %s\n",
socket_path.c_str(), strerror(errno));
exit(1);
}
AGA_TRACE_DEBUG("Set permissions on socket {} to 0600",
socket_path.c_str());
}

g_grpc_server->Wait();

// cleanup Unix socket on exit
if (grpc_server_type == AGA_GRPC_SERVER_TYPE_UNIX) {
clean_unix_socket(socket_path);
}
}

static int
Expand Down Expand Up @@ -291,6 +429,6 @@ aga_init (aga_init_params_t *init_params)
return ret;
}
// register for all gRPC services and start the gRPC server
grpc_server_start(init_params->grpc_server);
grpc_server_start(init_params->grpc_server, init_params->grpc_server_type);
return SDK_RET_OK;
}
12 changes: 11 additions & 1 deletion sw/nic/gpuagent/init.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,21 @@ limitations under the License.
#define AGA_DEFAULT_GRPC_SERVER_PORT 50061
/// gRPC server:port string length
#define AGA_GRPC_SERVER_STR_LEN 64
/// Default Unix socket path
#define AGA_DEFAULT_UNIX_SOCKET_PATH "/var/run/gpuagent.sock"

/// \brief gRPC server type
/// identifies the transport the gRPC server listens on, i.e. whether the
/// gRPC server string is an IP:port pair or a "unix:<socket-path>" string
typedef enum aga_grpc_server_type_e {
AGA_GRPC_SERVER_TYPE_TCP = 0, ///< TCP/IP based gRPC server
AGA_GRPC_SERVER_TYPE_UNIX = 1, ///< Unix socket based gRPC server
} aga_grpc_server_type_t;

/// \brief initialization parameters
typedef struct aga_init_params_s {
// gRPC server listen address: either "IP:port" or "unix:<socket-path>"
// NOTE(review): the buffer holds at most AGA_GRPC_SERVER_STR_LEN bytes,
// so a long Unix socket path (plus the "unix:" prefix) may not fit -
// confirm that callers validate the length before copying
char grpc_server[AGA_GRPC_SERVER_STR_LEN];
// gRPC server type (TCP or Unix socket)
aga_grpc_server_type_t grpc_server_type;
} aga_init_params_t;

/// \brief initialize the agent state, threads etc.
Expand Down
50 changes: 40 additions & 10 deletions sw/nic/gpuagent/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,17 @@ static void inline
print_usage (char **argv)
{
fprintf(stdout, "Usage : %s [-p <port> | --grpc-server-port <port>] "
"[-i <ip-addr> | --grpc-server-ip <ip-addr>]\n\n",
"[-i <ip-addr> | --grpc-server-ip <ip-addr>] "
"[-s <socket-path> | --grpc-unix-socket <socket-path>]\n\n",
argv[0]);
fprintf(stdout, "Options:\n");
fprintf(stdout, " -p, --grpc-server-port <port> gRPC server port (default: %d)\n",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

%u

AGA_DEFAULT_GRPC_SERVER_PORT);
fprintf(stdout, " -i, --grpc-server-ip <ip-addr> gRPC server IP address (default: 127.0.0.1)\n");
fprintf(stdout, " -s, --grpc-unix-socket <path> Use Unix socket instead of TCP/IP (default: %s)\n",
AGA_DEFAULT_UNIX_SOCKET_PATH);
fprintf(stdout, " -h, --help Display this help message\n\n");
fprintf(stdout, "Note: If -s/--grpc-unix-socket is specified, -p and -i options are ignored.\n\n");
fprintf(stdout, "Use -h | --help for help\n");
}

Expand All @@ -66,17 +75,20 @@ main (int argc, char **argv)
std::string grpc_server;
std::string grpc_server_ip;
std::string grpc_server_port;
std::string grpc_unix_socket;
bool use_unix_socket = false;
aga_init_params_t init_params = {};
// command line options
struct option longopts[] = {
{ "grpc-server-port", required_argument, NULL, 'p' },
{ "grpc-server-ip", required_argument, NULL, 'i' },
{ "grpc-unix-socket", optional_argument, NULL, 's' },
{ "help", no_argument, NULL, 'h' },
{ 0, 0, NULL, 0 }
};

// parse CLI options
while ((oc = getopt_long(argc, argv, ":hp:i:", longopts, NULL)) != -1) {
while ((oc = getopt_long(argc, argv, ":hp:i:s:", longopts, NULL)) != -1) {
switch (oc) {
case 'p':
try {
Expand Down Expand Up @@ -107,6 +119,11 @@ main (int argc, char **argv)
grpc_server_ip = optarg;
break;

case 's':
grpc_unix_socket = optarg;
use_unix_socket = true;
break;

case 'h':
print_usage(argv);
exit(0);
Expand All @@ -117,15 +134,28 @@ main (int argc, char **argv)
break;
}
}
// use default IP for gRPC server if not specified
if (grpc_server_ip.empty()) {
grpc_server_ip = "127.0.0.1";
}
// use default port for gRPC server if not specified
if (grpc_server_port.empty()) {
grpc_server_port = std::to_string(AGA_DEFAULT_GRPC_SERVER_PORT);
// determine gRPC server type and address
if (use_unix_socket) {
// use Unix socket
if (grpc_unix_socket.empty()) {
grpc_unix_socket = AGA_DEFAULT_UNIX_SOCKET_PATH;
}
grpc_server = "unix:" + grpc_unix_socket;
init_params.grpc_server_type = AGA_GRPC_SERVER_TYPE_UNIX;
} else {
// use TCP/IP
// use default IP for gRPC server if not specified
if (grpc_server_ip.empty()) {
grpc_server_ip = "127.0.0.1";
}
// use default port for gRPC server if not specified
if (grpc_server_port.empty()) {
grpc_server_port = std::to_string(AGA_DEFAULT_GRPC_SERVER_PORT);
}
grpc_server = grpc_server_ip + ":" + grpc_server_port;
init_params.grpc_server_type = AGA_GRPC_SERVER_TYPE_TCP;
}
grpc_server = grpc_server_ip + ":" + grpc_server_port;

// initialize the init params
strncpy(init_params.grpc_server, grpc_server.c_str(),
AGA_GRPC_SERVER_STR_LEN);
Expand Down