Skip to content

Commit 4631edc

Browse files
committed
rpc : refactor backend
Use structs for RPC request/response messages
1 parent becfd38 commit 4631edc

File tree

1 file changed

+30
-28
lines changed

1 file changed

+30
-28
lines changed

ggml/src/ggml-rpc.cpp

Lines changed: 30 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ struct socket_t {
5858
};
5959

6060
// ggml_tensor is serialized into rpc_tensor
61-
#pragma pack(push, 1)
61+
#pragma pack(1)
6262
struct rpc_tensor {
6363
uint64_t id;
6464
uint32_t type;
@@ -96,6 +96,17 @@ enum rpc_cmd {
9696
RPC_CMD_COUNT,
9797
};
9898

99+
#pragma pack(1)
100+
struct request_alloc_buffer {
101+
uint64_t size;
102+
};
103+
104+
#pragma pack(1)
105+
struct response_alloc_buffer {
106+
uint64_t remote_ptr;
107+
uint64_t remote_size;
108+
};
109+
99110
// RPC data structures
100111

101112
static ggml_guid_t ggml_backend_rpc_guid() {
@@ -252,30 +263,31 @@ static bool parse_endpoint(const std::string & endpoint, std::string & host, int
252263

253264
// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
254265
// RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
255-
static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
266+
static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) {
256267
uint8_t cmd_byte = cmd;
257268
if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) {
258269
return false;
259270
}
260-
uint64_t input_size = input.size();
261271
if (!send_data(sock->fd, &input_size, sizeof(input_size))) {
262272
return false;
263273
}
264-
if (!send_data(sock->fd, input.data(), input.size())) {
274+
if (!send_data(sock->fd, input, input_size)) {
265275
return false;
266276
}
267-
uint64_t output_size;
268-
if (!recv_data(sock->fd, &output_size, sizeof(output_size))) {
277+
// TODO: currently the output_size is always known, do we need support for commands with variable output size?
278+
// even if we do, we can skip sending output_size from the server for commands with known output size
279+
uint64_t out_size;
280+
if (!recv_data(sock->fd, &out_size, sizeof(out_size))) {
269281
return false;
270282
}
271-
if (output_size == 0) {
272-
output.clear();
273-
return true;
274-
}
275-
output.resize(output_size);
276-
if (!recv_data(sock->fd, output.data(), output_size)) {
283+
if (out_size != output_size) {
277284
return false;
278285
}
286+
if (output_size > 0) {
287+
if (!recv_data(sock->fd, output, output_size)) {
288+
return false;
289+
}
290+
}
279291
return true;
280292
}
281293

@@ -484,25 +496,15 @@ static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t
484496

485497
static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
486498
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
487-
// input serialization format: | size (8 bytes) |
488-
int input_size = sizeof(uint64_t);
489-
std::vector<uint8_t> input(input_size, 0);
490-
memcpy(input.data(), &size, sizeof(size));
491-
std::vector<uint8_t> output;
499+
request_alloc_buffer request = {size};
500+
response_alloc_buffer response;
492501
auto sock = get_socket(buft_ctx->endpoint);
493-
bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, input, output);
494-
GGML_ASSERT(status);
495-
GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
496-
// output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
497-
uint64_t remote_ptr;
498-
memcpy(&remote_ptr, output.data(), sizeof(remote_ptr));
499-
size_t remote_size;
500-
memcpy(&remote_size, output.data() + sizeof(uint64_t), sizeof(remote_size));
501-
if (remote_ptr != 0) {
502+
bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response));
503+
if (response.remote_ptr != 0) {
502504
ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
503505
ggml_backend_rpc_buffer_interface,
504-
new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC[" + std::string(buft_ctx->endpoint) + "]"},
505-
remote_size);
506+
new ggml_backend_rpc_buffer_context{sock, {}, response.remote_ptr, "RPC[" + std::string(buft_ctx->endpoint) + "]"},
507+
response.remote_size);
506508
return buffer;
507509
} else {
508510
return nullptr;

0 commit comments

Comments
 (0)