@@ -58,7 +58,7 @@ struct socket_t {
58
58
};
59
59
60
60
// ggml_tensor is serialized into rpc_tensor
61
- #pragma pack(push, 1)
61
+ #pragma pack(1)
62
62
struct rpc_tensor {
63
63
uint64_t id;
64
64
uint32_t type;
@@ -96,6 +96,17 @@ enum rpc_cmd {
96
96
RPC_CMD_COUNT,
97
97
};
98
98
99
// Request payload for RPC_CMD_ALLOC_BUFFER.
// Sent over the wire as raw bytes, so the layout is pinned with 1-byte packing.
// push/pop keeps the packing change local to this struct instead of leaking
// into every definition that follows in the translation unit.
#pragma pack(push, 1)
struct request_alloc_buffer {
    uint64_t size; // requested buffer size, in bytes
};
#pragma pack(pop)
103
+
104
// Response payload for RPC_CMD_ALLOC_BUFFER.
// Received from the wire as raw bytes, so the layout is pinned with 1-byte packing.
// push/pop keeps the packing change local to this struct instead of leaking
// into every definition that follows in the translation unit.
#pragma pack(push, 1)
struct response_alloc_buffer {
    uint64_t remote_ptr;  // handle of the buffer on the server; the caller treats 0 as "allocation failed"
    uint64_t remote_size; // size reported by the server for the allocated buffer
};
#pragma pack(pop)
109
+
99
110
// RPC data structures
100
111
101
112
static ggml_guid_t ggml_backend_rpc_guid () {
@@ -252,30 +263,31 @@ static bool parse_endpoint(const std::string & endpoint, std::string & host, int
252
263
253
264
// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
254
265
// RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
255
- static bool send_rpc_cmd (const std::shared_ptr<socket_t > & sock, enum rpc_cmd cmd, const std::vector< uint8_t > & input, std::vector< uint8_t > & output) {
266
+ static bool send_rpc_cmd (const std::shared_ptr<socket_t > & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size ) {
256
267
uint8_t cmd_byte = cmd;
257
268
if (!send_data (sock->fd , &cmd_byte, sizeof (cmd_byte))) {
258
269
return false ;
259
270
}
260
- uint64_t input_size = input.size ();
261
271
if (!send_data (sock->fd , &input_size, sizeof (input_size))) {
262
272
return false ;
263
273
}
264
- if (!send_data (sock->fd , input. data (), input. size () )) {
274
+ if (!send_data (sock->fd , input, input_size )) {
265
275
return false ;
266
276
}
267
- uint64_t output_size;
268
- if (!recv_data (sock->fd , &output_size, sizeof (output_size))) {
277
+ // TODO: currently the output_size is always known, do we need support for commands with variable output size?
278
+ // even if we do, we can skip sending output_size from the server for commands with known output size
279
+ uint64_t out_size;
280
+ if (!recv_data (sock->fd , &out_size, sizeof (out_size))) {
269
281
return false ;
270
282
}
271
- if (output_size == 0 ) {
272
- output.clear ();
273
- return true ;
274
- }
275
- output.resize (output_size);
276
- if (!recv_data (sock->fd , output.data (), output_size)) {
283
+ if (out_size != output_size) {
277
284
return false ;
278
285
}
286
+ if (output_size > 0 ) {
287
+ if (!recv_data (sock->fd , output, output_size)) {
288
+ return false ;
289
+ }
290
+ }
279
291
return true ;
280
292
}
281
293
@@ -484,25 +496,15 @@ static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t
484
496
485
497
static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer (ggml_backend_buffer_type_t buft, size_t size) {
486
498
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context ;
487
- // input serialization format: | size (8 bytes) |
488
- int input_size = sizeof (uint64_t );
489
- std::vector<uint8_t > input (input_size, 0 );
490
- memcpy (input.data (), &size, sizeof (size));
491
- std::vector<uint8_t > output;
499
+ request_alloc_buffer request = {size};
500
+ response_alloc_buffer response;
492
501
auto sock = get_socket (buft_ctx->endpoint );
493
- bool status = send_rpc_cmd (sock, RPC_CMD_ALLOC_BUFFER, input, output);
494
- GGML_ASSERT (status);
495
- GGML_ASSERT (output.size () == 2 *sizeof (uint64_t ));
496
- // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
497
- uint64_t remote_ptr;
498
- memcpy (&remote_ptr, output.data (), sizeof (remote_ptr));
499
- size_t remote_size;
500
- memcpy (&remote_size, output.data () + sizeof (uint64_t ), sizeof (remote_size));
501
- if (remote_ptr != 0 ) {
502
+ bool status = send_rpc_cmd (sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof (request), &response, sizeof (response));
503
+ if (response.remote_ptr != 0 ) {
502
504
ggml_backend_buffer_t buffer = ggml_backend_buffer_init (buft,
503
505
ggml_backend_rpc_buffer_interface,
504
- new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, " RPC[" + std::string (buft_ctx->endpoint ) + " ]" },
505
- remote_size);
506
+ new ggml_backend_rpc_buffer_context{sock, {}, response. remote_ptr , " RPC[" + std::string (buft_ctx->endpoint ) + " ]" },
507
+ response. remote_size );
506
508
return buffer;
507
509
} else {
508
510
return nullptr ;
0 commit comments