Skip to content

Commit 18eaf29

Browse files
authored
rpc : prevent crashes on invalid input (#9040)
Add more checks which prevent RPC server from crashing if invalid input is received from client
1 parent 554b049 commit 18eaf29

File tree

1 file changed

+47
-34
lines changed

1 file changed

+47
-34
lines changed

ggml/src/ggml-rpc.cpp

Lines changed: 47 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -82,17 +82,18 @@ static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of
8282

8383
// RPC commands
8484
enum rpc_cmd {
85-
ALLOC_BUFFER = 0,
86-
GET_ALIGNMENT,
87-
GET_MAX_SIZE,
88-
BUFFER_GET_BASE,
89-
FREE_BUFFER,
90-
BUFFER_CLEAR,
91-
SET_TENSOR,
92-
GET_TENSOR,
93-
COPY_TENSOR,
94-
GRAPH_COMPUTE,
95-
GET_DEVICE_MEMORY,
85+
RPC_CMD_ALLOC_BUFFER = 0,
86+
RPC_CMD_GET_ALIGNMENT,
87+
RPC_CMD_GET_MAX_SIZE,
88+
RPC_CMD_BUFFER_GET_BASE,
89+
RPC_CMD_FREE_BUFFER,
90+
RPC_CMD_BUFFER_CLEAR,
91+
RPC_CMD_SET_TENSOR,
92+
RPC_CMD_GET_TENSOR,
93+
RPC_CMD_COPY_TENSOR,
94+
RPC_CMD_GRAPH_COMPUTE,
95+
RPC_CMD_GET_DEVICE_MEMORY,
96+
RPC_CMD_COUNT,
9697
};
9798

9899
// RPC data structures
@@ -330,7 +331,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t
330331
uint64_t remote_ptr = ctx->remote_ptr;
331332
memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
332333
std::vector<uint8_t> output;
333-
bool status = send_rpc_cmd(ctx->sock, FREE_BUFFER, input, output);
334+
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, input, output);
334335
GGML_ASSERT(status);
335336
GGML_ASSERT(output.empty());
336337
delete ctx;
@@ -346,7 +347,7 @@ GGML_CALL static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t b
346347
uint64_t remote_ptr = ctx->remote_ptr;
347348
memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
348349
std::vector<uint8_t> output;
349-
bool status = send_rpc_cmd(ctx->sock, BUFFER_GET_BASE, input, output);
350+
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, input, output);
350351
GGML_ASSERT(status);
351352
GGML_ASSERT(output.size() == sizeof(uint64_t));
352353
// output serialization format: | base_ptr (8 bytes) |
@@ -405,7 +406,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t b
405406
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
406407
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
407408
std::vector<uint8_t> output;
408-
bool status = send_rpc_cmd(ctx->sock, SET_TENSOR, input, output);
409+
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input, output);
409410
GGML_ASSERT(status);
410411
}
411412

@@ -419,7 +420,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t b
419420
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
420421
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &size, sizeof(size));
421422
std::vector<uint8_t> output;
422-
bool status = send_rpc_cmd(ctx->sock, GET_TENSOR, input, output);
423+
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, input, output);
423424
GGML_ASSERT(status);
424425
GGML_ASSERT(output.size() == size);
425426
// output serialization format: | data (size bytes) |
@@ -444,7 +445,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t b
444445
memcpy(input.data(), &rpc_src, sizeof(rpc_src));
445446
memcpy(input.data() + sizeof(rpc_src), &rpc_dst, sizeof(rpc_dst));
446447
std::vector<uint8_t> output;
447-
bool status = send_rpc_cmd(ctx->sock, COPY_TENSOR, input, output);
448+
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, input, output);
448449
GGML_ASSERT(status);
449450
// output serialization format: | result (1 byte) |
450451
GGML_ASSERT(output.size() == 1);
@@ -459,7 +460,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer
459460
memcpy(input.data(), &ctx->remote_ptr, sizeof(ctx->remote_ptr));
460461
memcpy(input.data() + sizeof(ctx->remote_ptr), &value, sizeof(value));
461462
std::vector<uint8_t> output;
462-
bool status = send_rpc_cmd(ctx->sock, BUFFER_CLEAR, input, output);
463+
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, input, output);
463464
GGML_ASSERT(status);
464465
}
465466

@@ -488,7 +489,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
488489
memcpy(input.data(), &size, sizeof(size));
489490
std::vector<uint8_t> output;
490491
auto sock = get_socket(buft_ctx->endpoint);
491-
bool status = send_rpc_cmd(sock, ALLOC_BUFFER, input, output);
492+
bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, input, output);
492493
GGML_ASSERT(status);
493494
GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
494495
// output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
@@ -511,7 +512,7 @@ static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
511512
// input serialization format: | 0 bytes |
512513
std::vector<uint8_t> input;
513514
std::vector<uint8_t> output;
514-
bool status = send_rpc_cmd(sock, GET_ALIGNMENT, input, output);
515+
bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, input, output);
515516
GGML_ASSERT(status);
516517
GGML_ASSERT(output.size() == sizeof(uint64_t));
517518
// output serialization format: | alignment (8 bytes) |
@@ -529,7 +530,7 @@ static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
529530
// input serialization format: | 0 bytes |
530531
std::vector<uint8_t> input;
531532
std::vector<uint8_t> output;
532-
bool status = send_rpc_cmd(sock, GET_MAX_SIZE, input, output);
533+
bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, input, output);
533534
GGML_ASSERT(status);
534535
GGML_ASSERT(output.size() == sizeof(uint64_t));
535536
// output serialization format: | max_size (8 bytes) |
@@ -622,7 +623,7 @@ GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t
622623
serialize_graph(cgraph, input);
623624
std::vector<uint8_t> output;
624625
auto sock = get_socket(rpc_ctx->endpoint);
625-
bool status = send_rpc_cmd(sock, GRAPH_COMPUTE, input, output);
626+
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input, output);
626627
GGML_ASSERT(status);
627628
GGML_ASSERT(output.size() == 1);
628629
return (enum ggml_status)output[0];
@@ -719,7 +720,7 @@ static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * f
719720
// input serialization format: | 0 bytes |
720721
std::vector<uint8_t> input;
721722
std::vector<uint8_t> output;
722-
bool status = send_rpc_cmd(sock, GET_DEVICE_MEMORY, input, output);
723+
bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, input, output);
723724
GGML_ASSERT(status);
724725
GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
725726
// output serialization format: | free (8 bytes) | total (8 bytes) |
@@ -1098,59 +1099,69 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
10981099
if (!recv_data(sockfd, &cmd, 1)) {
10991100
break;
11001101
}
1102+
if (cmd >= RPC_CMD_COUNT) {
1103+
// fail fast if the command is invalid
1104+
fprintf(stderr, "Unknown command: %d\n", cmd);
1105+
break;
1106+
}
11011107
std::vector<uint8_t> input;
11021108
std::vector<uint8_t> output;
11031109
uint64_t input_size;
11041110
if (!recv_data(sockfd, &input_size, sizeof(input_size))) {
11051111
break;
11061112
}
1107-
input.resize(input_size);
1113+
try {
1114+
input.resize(input_size);
1115+
} catch (const std::bad_alloc & e) {
1116+
fprintf(stderr, "Failed to allocate input buffer of size %" PRIu64 "\n", input_size);
1117+
break;
1118+
}
11081119
if (!recv_data(sockfd, input.data(), input_size)) {
11091120
break;
11101121
}
11111122
bool ok = true;
11121123
switch (cmd) {
1113-
case ALLOC_BUFFER: {
1124+
case RPC_CMD_ALLOC_BUFFER: {
11141125
ok = server.alloc_buffer(input, output);
11151126
break;
11161127
}
1117-
case GET_ALIGNMENT: {
1128+
case RPC_CMD_GET_ALIGNMENT: {
11181129
server.get_alignment(output);
11191130
break;
11201131
}
1121-
case GET_MAX_SIZE: {
1132+
case RPC_CMD_GET_MAX_SIZE: {
11221133
server.get_max_size(output);
11231134
break;
11241135
}
1125-
case BUFFER_GET_BASE: {
1136+
case RPC_CMD_BUFFER_GET_BASE: {
11261137
ok = server.buffer_get_base(input, output);
11271138
break;
11281139
}
1129-
case FREE_BUFFER: {
1140+
case RPC_CMD_FREE_BUFFER: {
11301141
ok = server.free_buffer(input);
11311142
break;
11321143
}
1133-
case BUFFER_CLEAR: {
1144+
case RPC_CMD_BUFFER_CLEAR: {
11341145
ok = server.buffer_clear(input);
11351146
break;
11361147
}
1137-
case SET_TENSOR: {
1148+
case RPC_CMD_SET_TENSOR: {
11381149
ok = server.set_tensor(input);
11391150
break;
11401151
}
1141-
case GET_TENSOR: {
1152+
case RPC_CMD_GET_TENSOR: {
11421153
ok = server.get_tensor(input, output);
11431154
break;
11441155
}
1145-
case COPY_TENSOR: {
1156+
case RPC_CMD_COPY_TENSOR: {
11461157
ok = server.copy_tensor(input, output);
11471158
break;
11481159
}
1149-
case GRAPH_COMPUTE: {
1160+
case RPC_CMD_GRAPH_COMPUTE: {
11501161
ok = server.graph_compute(input, output);
11511162
break;
11521163
}
1153-
case GET_DEVICE_MEMORY: {
1164+
case RPC_CMD_GET_DEVICE_MEMORY: {
11541165
// output serialization format: | free (8 bytes) | total (8 bytes) |
11551166
output.resize(2*sizeof(uint64_t), 0);
11561167
memcpy(output.data(), &free_mem, sizeof(free_mem));
@@ -1203,8 +1214,10 @@ void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free
12031214
return;
12041215
}
12051216
printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
1217+
fflush(stdout);
12061218
rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
12071219
printf("Client connection closed\n");
1220+
fflush(stdout);
12081221
}
12091222
#ifdef _WIN32
12101223
WSACleanup();

0 commit comments

Comments
 (0)