@@ -92,12 +92,19 @@ enum rpc_cmd {
92
92
RPC_CMD_GET_DEVICE_MEMORY,
93
93
RPC_CMD_INIT_TENSOR,
94
94
RPC_CMD_GET_ALLOC_SIZE,
95
+ RPC_CMD_HELLO,
95
96
RPC_CMD_COUNT,
96
97
};
97
98
98
99
// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
99
100
const size_t HASH_THRESHOLD = 10 * 1024 * 1024 ; // 10 * 2^20, i.e. 10 MiB assuming data sizes are counted in bytes
100
101
102
// Response payload of RPC_CMD_HELLO: the protocol version reported by the server.
// rpc_server::hello fills it from the RPC_PROTO_* constants and
// check_server_version compares it against the client's own constants.
+ struct rpc_msg_hello_rsp {
103
+ uint8_t major; // client rejects the connection when this differs from its own major version
104
+ uint8_t minor; // client rejects the connection when this is greater than its own minor version
105
+ uint8_t patch; // a difference here only produces a warning on the client
106
+ };
107
+
101
108
struct rpc_msg_get_alloc_size_req {
102
109
rpc_tensor tensor;
103
110
};
@@ -400,6 +407,20 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
400
407
401
408
// RPC client-side implementation
402
409
410
+ static bool check_server_version (const std::shared_ptr<socket_t > & sock) {
411
+ rpc_msg_hello_rsp response;
412
+ bool status = send_rpc_cmd (sock, RPC_CMD_HELLO, nullptr , 0 , &response, sizeof (response));
413
+ GGML_ASSERT (status);
414
+ if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) {
415
+ fprintf (stderr, " RPC server version mismatch: %d.%d.%d\n " , response.major , response.minor , response.patch );
416
+ return false ;
417
+ }
418
+ if (response.minor != RPC_PROTO_MINOR_VERSION || response.patch != RPC_PROTO_PATCH_VERSION) {
419
+ fprintf (stderr, " WARNING: RPC server version mismatch: %d.%d.%d\n " , response.major , response.minor , response.patch );
420
+ }
421
+ return true ;
422
+ }
423
+
403
424
static std::shared_ptr<socket_t > get_socket (const std::string & endpoint) {
404
425
static std::mutex mutex;
405
426
std::lock_guard<std::mutex> lock (mutex);
@@ -433,6 +454,9 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
433
454
if (sock == nullptr ) {
434
455
return nullptr ;
435
456
}
457
+ if (!check_server_version (sock)) {
458
+ return nullptr ;
459
+ }
436
460
GGML_PRINT_DEBUG (" [%s] connected to %s, sockfd=%d\n " , __func__, endpoint.c_str (), sock->fd );
437
461
sockets[endpoint] = sock;
438
462
return sock;
@@ -818,6 +842,7 @@ class rpc_server {
818
842
}
819
843
~rpc_server ();
820
844
845
+ void hello (rpc_msg_hello_rsp & response);
821
846
void alloc_buffer (const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
822
847
void get_alignment (rpc_msg_get_alignment_rsp & response);
823
848
void get_max_size (rpc_msg_get_max_size_rsp & response);
@@ -846,6 +871,13 @@ class rpc_server {
846
871
std::unordered_set<ggml_backend_buffer_t > buffers;
847
872
};
848
873
874
+ void rpc_server::hello (rpc_msg_hello_rsp & response) {
875
+ response.major = RPC_PROTO_MAJOR_VERSION;
876
+ response.minor = RPC_PROTO_MINOR_VERSION;
877
+ response.patch = RPC_PROTO_PATCH_VERSION;
878
+ GGML_PRINT_DEBUG (" [%s] version: %d.%d.%d\n " , __func__, response.major , response.minor , response.patch );
879
+ }
880
+
849
881
bool rpc_server::get_alloc_size (const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
850
882
ggml_backend_buffer_type_t buft;
851
883
struct ggml_init_params params {
@@ -1271,8 +1303,24 @@ rpc_server::~rpc_server() {
1271
1303
static void rpc_serve_client (ggml_backend_t backend, const char * cache_dir,
1272
1304
sockfd_t sockfd, size_t free_mem, size_t total_mem) {
1273
1305
rpc_server server (backend, cache_dir);
1306
+ uint8_t cmd;
1307
+ if (!recv_data (sockfd, &cmd, 1 )) {
1308
+ return ;
1309
+ }
1310
+ // the first command sent by the client must be HELLO
1311
+ if (cmd != RPC_CMD_HELLO) {
1312
+ fprintf (stderr, " Expected HELLO command, update client\n " );
1313
+ return ;
1314
+ }
1315
+ if (!recv_msg (sockfd, nullptr , 0 )) {
1316
+ return ;
1317
+ }
1318
+ rpc_msg_hello_rsp response;
1319
+ server.hello (response);
1320
+ if (!send_msg (sockfd, &response, sizeof (response))) {
1321
+ return ;
1322
+ }
1274
1323
while (true ) {
1275
- uint8_t cmd;
1276
1324
if (!recv_data (sockfd, &cmd, 1 )) {
1277
1325
break ;
1278
1326
}
@@ -1282,6 +1330,10 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
1282
1330
break ;
1283
1331
}
1284
1332
switch (cmd) {
1333
+ case RPC_CMD_HELLO: {
1334
+ // HELLO command is handled above
1335
+ return ;
1336
+ }
1285
1337
case RPC_CMD_ALLOC_BUFFER: {
1286
1338
rpc_msg_alloc_buffer_req request;
1287
1339
if (!recv_msg (sockfd, &request, sizeof (request))) {
0 commit comments