@@ -82,17 +82,18 @@ static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of
82
82
83
83
// RPC commands
84
84
enum rpc_cmd {
85
- ALLOC_BUFFER = 0 ,
86
- GET_ALIGNMENT,
87
- GET_MAX_SIZE,
88
- BUFFER_GET_BASE,
89
- FREE_BUFFER,
90
- BUFFER_CLEAR,
91
- SET_TENSOR,
92
- GET_TENSOR,
93
- COPY_TENSOR,
94
- GRAPH_COMPUTE,
95
- GET_DEVICE_MEMORY,
85
+ RPC_CMD_ALLOC_BUFFER = 0 ,
86
+ RPC_CMD_GET_ALIGNMENT,
87
+ RPC_CMD_GET_MAX_SIZE,
88
+ RPC_CMD_BUFFER_GET_BASE,
89
+ RPC_CMD_FREE_BUFFER,
90
+ RPC_CMD_BUFFER_CLEAR,
91
+ RPC_CMD_SET_TENSOR,
92
+ RPC_CMD_GET_TENSOR,
93
+ RPC_CMD_COPY_TENSOR,
94
+ RPC_CMD_GRAPH_COMPUTE,
95
+ RPC_CMD_GET_DEVICE_MEMORY,
96
+ RPC_CMD_COUNT,
96
97
};
97
98
98
99
// RPC data structures
@@ -330,7 +331,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t
330
331
uint64_t remote_ptr = ctx->remote_ptr ;
331
332
memcpy (input.data (), &remote_ptr, sizeof (remote_ptr));
332
333
std::vector<uint8_t > output;
333
- bool status = send_rpc_cmd (ctx->sock , FREE_BUFFER , input, output);
334
+ bool status = send_rpc_cmd (ctx->sock , RPC_CMD_FREE_BUFFER , input, output);
334
335
GGML_ASSERT (status);
335
336
GGML_ASSERT (output.empty ());
336
337
delete ctx;
@@ -346,7 +347,7 @@ GGML_CALL static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t b
346
347
uint64_t remote_ptr = ctx->remote_ptr ;
347
348
memcpy (input.data (), &remote_ptr, sizeof (remote_ptr));
348
349
std::vector<uint8_t > output;
349
- bool status = send_rpc_cmd (ctx->sock , BUFFER_GET_BASE , input, output);
350
+ bool status = send_rpc_cmd (ctx->sock , RPC_CMD_BUFFER_GET_BASE , input, output);
350
351
GGML_ASSERT (status);
351
352
GGML_ASSERT (output.size () == sizeof (uint64_t ));
352
353
// output serialization format: | base_ptr (8 bytes) |
@@ -405,7 +406,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t b
405
406
memcpy (input.data () + sizeof (rpc_tensor), &offset, sizeof (offset));
406
407
memcpy (input.data () + sizeof (rpc_tensor) + sizeof (offset), data, size);
407
408
std::vector<uint8_t > output;
408
- bool status = send_rpc_cmd (ctx->sock , SET_TENSOR , input, output);
409
+ bool status = send_rpc_cmd (ctx->sock , RPC_CMD_SET_TENSOR , input, output);
409
410
GGML_ASSERT (status);
410
411
}
411
412
@@ -419,7 +420,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t b
419
420
memcpy (input.data () + sizeof (rpc_tensor), &offset, sizeof (offset));
420
421
memcpy (input.data () + sizeof (rpc_tensor) + sizeof (offset), &size, sizeof (size));
421
422
std::vector<uint8_t > output;
422
- bool status = send_rpc_cmd (ctx->sock , GET_TENSOR , input, output);
423
+ bool status = send_rpc_cmd (ctx->sock , RPC_CMD_GET_TENSOR , input, output);
423
424
GGML_ASSERT (status);
424
425
GGML_ASSERT (output.size () == size);
425
426
// output serialization format: | data (size bytes) |
@@ -444,7 +445,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t b
444
445
memcpy (input.data (), &rpc_src, sizeof (rpc_src));
445
446
memcpy (input.data () + sizeof (rpc_src), &rpc_dst, sizeof (rpc_dst));
446
447
std::vector<uint8_t > output;
447
- bool status = send_rpc_cmd (ctx->sock , COPY_TENSOR , input, output);
448
+ bool status = send_rpc_cmd (ctx->sock , RPC_CMD_COPY_TENSOR , input, output);
448
449
GGML_ASSERT (status);
449
450
// output serialization format: | result (1 byte) |
450
451
GGML_ASSERT (output.size () == 1 );
@@ -459,7 +460,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer
459
460
memcpy (input.data (), &ctx->remote_ptr , sizeof (ctx->remote_ptr ));
460
461
memcpy (input.data () + sizeof (ctx->remote_ptr ), &value, sizeof (value));
461
462
std::vector<uint8_t > output;
462
- bool status = send_rpc_cmd (ctx->sock , BUFFER_CLEAR , input, output);
463
+ bool status = send_rpc_cmd (ctx->sock , RPC_CMD_BUFFER_CLEAR , input, output);
463
464
GGML_ASSERT (status);
464
465
}
465
466
@@ -488,7 +489,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
488
489
memcpy (input.data (), &size, sizeof (size));
489
490
std::vector<uint8_t > output;
490
491
auto sock = get_socket (buft_ctx->endpoint );
491
- bool status = send_rpc_cmd (sock, ALLOC_BUFFER , input, output);
492
+ bool status = send_rpc_cmd (sock, RPC_CMD_ALLOC_BUFFER , input, output);
492
493
GGML_ASSERT (status);
493
494
GGML_ASSERT (output.size () == 2 *sizeof (uint64_t ));
494
495
// output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
@@ -511,7 +512,7 @@ static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
511
512
// input serialization format: | 0 bytes |
512
513
std::vector<uint8_t > input;
513
514
std::vector<uint8_t > output;
514
- bool status = send_rpc_cmd (sock, GET_ALIGNMENT , input, output);
515
+ bool status = send_rpc_cmd (sock, RPC_CMD_GET_ALIGNMENT , input, output);
515
516
GGML_ASSERT (status);
516
517
GGML_ASSERT (output.size () == sizeof (uint64_t ));
517
518
// output serialization format: | alignment (8 bytes) |
@@ -529,7 +530,7 @@ static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
529
530
// input serialization format: | 0 bytes |
530
531
std::vector<uint8_t > input;
531
532
std::vector<uint8_t > output;
532
- bool status = send_rpc_cmd (sock, GET_MAX_SIZE , input, output);
533
+ bool status = send_rpc_cmd (sock, RPC_CMD_GET_MAX_SIZE , input, output);
533
534
GGML_ASSERT (status);
534
535
GGML_ASSERT (output.size () == sizeof (uint64_t ));
535
536
// output serialization format: | max_size (8 bytes) |
@@ -622,7 +623,7 @@ GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t
622
623
serialize_graph (cgraph, input);
623
624
std::vector<uint8_t > output;
624
625
auto sock = get_socket (rpc_ctx->endpoint );
625
- bool status = send_rpc_cmd (sock, GRAPH_COMPUTE , input, output);
626
+ bool status = send_rpc_cmd (sock, RPC_CMD_GRAPH_COMPUTE , input, output);
626
627
GGML_ASSERT (status);
627
628
GGML_ASSERT (output.size () == 1 );
628
629
return (enum ggml_status)output[0 ];
@@ -719,7 +720,7 @@ static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * f
719
720
// input serialization format: | 0 bytes |
720
721
std::vector<uint8_t > input;
721
722
std::vector<uint8_t > output;
722
- bool status = send_rpc_cmd (sock, GET_DEVICE_MEMORY , input, output);
723
+ bool status = send_rpc_cmd (sock, RPC_CMD_GET_DEVICE_MEMORY , input, output);
723
724
GGML_ASSERT (status);
724
725
GGML_ASSERT (output.size () == 2 *sizeof (uint64_t ));
725
726
// output serialization format: | free (8 bytes) | total (8 bytes) |
@@ -1098,59 +1099,69 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
1098
1099
if (!recv_data (sockfd, &cmd, 1 )) {
1099
1100
break ;
1100
1101
}
1102
+ if (cmd >= RPC_CMD_COUNT) {
1103
+ // fail fast if the command is invalid
1104
+ fprintf (stderr, " Unknown command: %d\n " , cmd);
1105
+ break ;
1106
+ }
1101
1107
std::vector<uint8_t > input;
1102
1108
std::vector<uint8_t > output;
1103
1109
uint64_t input_size;
1104
1110
if (!recv_data (sockfd, &input_size, sizeof (input_size))) {
1105
1111
break ;
1106
1112
}
1107
- input.resize (input_size);
1113
+ try {
1114
+ input.resize (input_size);
1115
+ } catch (const std::bad_alloc & e) {
1116
+ fprintf (stderr, " Failed to allocate input buffer of size %" PRIu64 " \n " , input_size);
1117
+ break ;
1118
+ }
1108
1119
if (!recv_data (sockfd, input.data (), input_size)) {
1109
1120
break ;
1110
1121
}
1111
1122
bool ok = true ;
1112
1123
switch (cmd) {
1113
- case ALLOC_BUFFER : {
1124
+ case RPC_CMD_ALLOC_BUFFER : {
1114
1125
ok = server.alloc_buffer (input, output);
1115
1126
break ;
1116
1127
}
1117
- case GET_ALIGNMENT : {
1128
+ case RPC_CMD_GET_ALIGNMENT : {
1118
1129
server.get_alignment (output);
1119
1130
break ;
1120
1131
}
1121
- case GET_MAX_SIZE : {
1132
+ case RPC_CMD_GET_MAX_SIZE : {
1122
1133
server.get_max_size (output);
1123
1134
break ;
1124
1135
}
1125
- case BUFFER_GET_BASE : {
1136
+ case RPC_CMD_BUFFER_GET_BASE : {
1126
1137
ok = server.buffer_get_base (input, output);
1127
1138
break ;
1128
1139
}
1129
- case FREE_BUFFER : {
1140
+ case RPC_CMD_FREE_BUFFER : {
1130
1141
ok = server.free_buffer (input);
1131
1142
break ;
1132
1143
}
1133
- case BUFFER_CLEAR : {
1144
+ case RPC_CMD_BUFFER_CLEAR : {
1134
1145
ok = server.buffer_clear (input);
1135
1146
break ;
1136
1147
}
1137
- case SET_TENSOR : {
1148
+ case RPC_CMD_SET_TENSOR : {
1138
1149
ok = server.set_tensor (input);
1139
1150
break ;
1140
1151
}
1141
- case GET_TENSOR : {
1152
+ case RPC_CMD_GET_TENSOR : {
1142
1153
ok = server.get_tensor (input, output);
1143
1154
break ;
1144
1155
}
1145
- case COPY_TENSOR : {
1156
+ case RPC_CMD_COPY_TENSOR : {
1146
1157
ok = server.copy_tensor (input, output);
1147
1158
break ;
1148
1159
}
1149
- case GRAPH_COMPUTE : {
1160
+ case RPC_CMD_GRAPH_COMPUTE : {
1150
1161
ok = server.graph_compute (input, output);
1151
1162
break ;
1152
1163
}
1153
- case GET_DEVICE_MEMORY : {
1164
+ case RPC_CMD_GET_DEVICE_MEMORY : {
1154
1165
// output serialization format: | free (8 bytes) | total (8 bytes) |
1155
1166
output.resize (2 *sizeof (uint64_t ), 0 );
1156
1167
memcpy (output.data (), &free_mem, sizeof (free_mem));
@@ -1203,8 +1214,10 @@ void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free
1203
1214
return ;
1204
1215
}
1205
1216
printf (" Accepted client connection, free_mem=%zu, total_mem=%zu\n " , free_mem, total_mem);
1217
+ fflush (stdout);
1206
1218
rpc_serve_client (backend, client_socket->fd , free_mem, total_mem);
1207
1219
printf (" Client connection closed\n " );
1220
+ fflush (stdout);
1208
1221
}
1209
1222
#ifdef _WIN32
1210
1223
WSACleanup ();
0 commit comments