Skip to content

Commit 2db9ba1

Browse files
authored
rpc : add RPC_CMD_HELLO (#12955)
Add RPC_CMD_HELLO for getting the version of the protocol implemented by the server. Follow the semantic versioning rules at https://semver.org. Hopefully this brings a better user experience when we make breaking changes at the protocol level and avoids issues like #12465.
1 parent 2f74c35 commit 2db9ba1

File tree

3 files changed

+60
-2
lines changed

3 files changed

+60
-2
lines changed

examples/rpc/rpc-server.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,10 @@ int main(int argc, char * argv[]) {
297297
}
298298
cache_dir = cache_dir_str.c_str();
299299
}
300-
printf("Starting RPC server\n");
300+
printf("Starting RPC server v%d.%d.%d\n",
301+
RPC_PROTO_MAJOR_VERSION,
302+
RPC_PROTO_MINOR_VERSION,
303+
RPC_PROTO_PATCH_VERSION);
301304
printf(" endpoint : %s\n", endpoint.c_str());
302305
printf(" local cache : %s\n", cache_dir ? cache_dir : "n/a");
303306
printf(" backend memory : %zu MB\n", free_mem / (1024 * 1024));

ggml/include/ggml-rpc.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
extern "C" {
88
#endif
99

10+
// RPC protocol version, split into major / minor / patch components
#define RPC_PROTO_MAJOR_VERSION 1
11+
#define RPC_PROTO_MINOR_VERSION 0
12+
#define RPC_PROTO_PATCH_VERSION 0
1013
#define GGML_RPC_MAX_SERVERS 16
1114

1215
// backend API

ggml/src/ggml-rpc/ggml-rpc.cpp

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,19 @@ enum rpc_cmd {
9292
RPC_CMD_GET_DEVICE_MEMORY,
9393
RPC_CMD_INIT_TENSOR,
9494
RPC_CMD_GET_ALLOC_SIZE,
95+
RPC_CMD_HELLO,
9596
RPC_CMD_COUNT,
9697
};
9798

9899
// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
99100
const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
100101

102+
// Response to RPC_CMD_HELLO: the three components of the protocol version,
// one byte each
struct rpc_msg_hello_rsp {
103+
uint8_t major;
104+
uint8_t minor;
105+
uint8_t patch;
106+
};
107+
101108
struct rpc_msg_get_alloc_size_req {
102109
rpc_tensor tensor;
103110
};
@@ -400,6 +407,20 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
400407

401408
// RPC client-side implementation
402409

410+
// Sends RPC_CMD_HELLO over `sock` and compares the version in the response with the
// RPC_PROTO_*_VERSION values this code was compiled with.
// Returns false (after printing the server version to stderr) when the major versions
// differ or the server's minor version is greater than the local one.
// Returns true otherwise; a remaining minor/patch difference only prints a warning.
// The request status is checked with GGML_ASSERT, so a failed request is not reported
// through the return value.
static bool check_server_version(const std::shared_ptr<socket_t> & sock) {
411+
rpc_msg_hello_rsp response;
412+
// no input payload is passed for HELLO (nullptr, 0); the reply is written into `response`
bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response));
413+
GGML_ASSERT(status);
414+
// hard failure: different major version, or server minor version ahead of ours
if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) {
415+
fprintf(stderr, "RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
416+
return false;
417+
}
418+
// soft mismatch: same major, but minor or patch differs -> warn and continue
if (response.minor != RPC_PROTO_MINOR_VERSION || response.patch != RPC_PROTO_PATCH_VERSION) {
419+
fprintf(stderr, "WARNING: RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
420+
}
421+
return true;
422+
}
423+
403424
static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
404425
static std::mutex mutex;
405426
std::lock_guard<std::mutex> lock(mutex);
@@ -433,6 +454,9 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
433454
if (sock == nullptr) {
434455
return nullptr;
435456
}
457+
if (!check_server_version(sock)) {
458+
return nullptr;
459+
}
436460
GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
437461
sockets[endpoint] = sock;
438462
return sock;
@@ -818,6 +842,7 @@ class rpc_server {
818842
}
819843
~rpc_server();
820844

845+
void hello(rpc_msg_hello_rsp & response);
821846
void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
822847
void get_alignment(rpc_msg_get_alignment_rsp & response);
823848
void get_max_size(rpc_msg_get_max_size_rsp & response);
@@ -846,6 +871,13 @@ class rpc_server {
846871
std::unordered_set<ggml_backend_buffer_t> buffers;
847872
};
848873

874+
// Handles RPC_CMD_HELLO: fills `response` with the compile-time RPC_PROTO_*_VERSION
// values and passes the version to GGML_PRINT_DEBUG.
void rpc_server::hello(rpc_msg_hello_rsp & response) {
875+
response.major = RPC_PROTO_MAJOR_VERSION;
876+
response.minor = RPC_PROTO_MINOR_VERSION;
877+
response.patch = RPC_PROTO_PATCH_VERSION;
878+
GGML_PRINT_DEBUG("[%s] version: %d.%d.%d\n", __func__, response.major, response.minor, response.patch);
879+
}
880+
849881
bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
850882
ggml_backend_buffer_type_t buft;
851883
struct ggml_init_params params {
@@ -1271,8 +1303,24 @@ rpc_server::~rpc_server() {
12711303
static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
12721304
sockfd_t sockfd, size_t free_mem, size_t total_mem) {
12731305
rpc_server server(backend, cache_dir);
1306+
uint8_t cmd;
1307+
if (!recv_data(sockfd, &cmd, 1)) {
1308+
return;
1309+
}
1310+
// the first command sent by the client must be HELLO
1311+
if (cmd != RPC_CMD_HELLO) {
1312+
fprintf(stderr, "Expected HELLO command, update client\n");
1313+
return;
1314+
}
1315+
if (!recv_msg(sockfd, nullptr, 0)) {
1316+
return;
1317+
}
1318+
rpc_msg_hello_rsp response;
1319+
server.hello(response);
1320+
if (!send_msg(sockfd, &response, sizeof(response))) {
1321+
return;
1322+
}
12741323
while (true) {
1275-
uint8_t cmd;
12761324
if (!recv_data(sockfd, &cmd, 1)) {
12771325
break;
12781326
}
@@ -1282,6 +1330,10 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
12821330
break;
12831331
}
12841332
switch (cmd) {
1333+
case RPC_CMD_HELLO: {
1334+
// HELLO command is handled above
1335+
return;
1336+
}
12851337
case RPC_CMD_ALLOC_BUFFER: {
12861338
rpc_msg_alloc_buffer_req request;
12871339
if (!recv_msg(sockfd, &request, sizeof(request))) {

0 commit comments

Comments
 (0)