@@ -58,7 +58,7 @@ struct socket_t {
58
58
};
59
59
60
60
// ggml_tensor is serialized into rpc_tensor
61
- #pragma pack(push, 1)
61
+ #pragma pack(1)
62
62
struct rpc_tensor {
63
63
uint64_t id;
64
64
uint32_t type;
@@ -76,7 +76,6 @@ struct rpc_tensor {
76
76
77
77
char padding[4 ];
78
78
};
79
- #pragma pack(pop)
80
79
81
80
static_assert (sizeof (rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of 8");
82
81
@@ -96,6 +95,77 @@ enum rpc_cmd {
96
95
RPC_CMD_COUNT,
97
96
};
98
97
98
+ #pragma pack(1)
99
+ struct rpc_msg_alloc_buffer_req {
100
+ uint64_t size;
101
+ };
102
+
103
+ #pragma pack(1)
104
+ struct rpc_msg_alloc_buffer_rsp {
105
+ uint64_t remote_ptr;
106
+ uint64_t remote_size;
107
+ };
108
+
109
+ #pragma pack(1)
110
+ struct rpc_msg_get_alignment_rsp {
111
+ uint64_t alignment;
112
+ };
113
+
114
+ #pragma pack(1)
115
+ struct rpc_msg_get_max_size_rsp {
116
+ uint64_t max_size;
117
+ };
118
+
119
+ #pragma pack(1)
120
+ struct rpc_msg_buffer_get_base_req {
121
+ uint64_t remote_ptr;
122
+ };
123
+
124
+ #pragma pack(1)
125
+ struct rpc_msg_buffer_get_base_rsp {
126
+ uint64_t base_ptr;
127
+ };
128
+
129
+ #pragma pack(1)
130
+ struct rpc_msg_free_buffer_req {
131
+ uint64_t remote_ptr;
132
+ };
133
+
134
+ #pragma pack(1)
135
+ struct rpc_msg_buffer_clear_req {
136
+ uint64_t remote_ptr;
137
+ uint8_t value;
138
+ };
139
+
140
+ #pragma pack(1)
141
+ struct rpc_msg_get_tensor_req {
142
+ rpc_tensor tensor;
143
+ uint64_t offset;
144
+ uint64_t size;
145
+ };
146
+
147
+ #pragma pack(1)
148
+ struct rpc_msg_copy_tensor_req {
149
+ rpc_tensor src;
150
+ rpc_tensor dst;
151
+ };
152
+
153
+ #pragma pack(1)
154
+ struct rpc_msg_copy_tensor_rsp {
155
+ uint8_t result;
156
+ };
157
+
158
+ #pragma pack(1)
159
+ struct rpc_msg_graph_compute_rsp {
160
+ uint8_t result;
161
+ };
162
+
163
+ #pragma pack(1)
164
+ struct rpc_msg_get_device_memory_rsp {
165
+ uint64_t free_mem;
166
+ uint64_t total_mem;
167
+ };
168
+
99
169
// RPC data structures
100
170
101
171
static ggml_guid_t ggml_backend_rpc_guid () {
@@ -252,28 +322,27 @@ static bool parse_endpoint(const std::string & endpoint, std::string & host, int
252
322
253
323
// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
254
324
// RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
255
- static bool send_rpc_cmd (const std::shared_ptr<socket_t > & sock, enum rpc_cmd cmd, const std::vector< uint8_t > & input, std::vector< uint8_t > & output) {
325
+ static bool send_rpc_cmd (const std::shared_ptr<socket_t > & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size ) {
256
326
uint8_t cmd_byte = cmd;
257
327
if (!send_data (sock->fd , &cmd_byte, sizeof (cmd_byte))) {
258
328
return false ;
259
329
}
260
- uint64_t input_size = input.size ();
261
330
if (!send_data (sock->fd , &input_size, sizeof (input_size))) {
262
331
return false ;
263
332
}
264
- if (!send_data (sock->fd , input. data (), input. size () )) {
333
+ if (!send_data (sock->fd , input, input_size )) {
265
334
return false ;
266
335
}
267
- uint64_t output_size;
268
- if (!recv_data (sock->fd , &output_size, sizeof (output_size))) {
336
+ // TODO: currently the output_size is always known, do we need support for commands with variable output size?
337
+ // even if we do, we can skip sending output_size from the server for commands with known output size
338
+ uint64_t out_size;
339
+ if (!recv_data (sock->fd , &out_size, sizeof (out_size))) {
269
340
return false ;
270
341
}
271
- if (output_size == 0 ) {
272
- output.clear ();
273
- return true ;
342
+ if (out_size != output_size) {
343
+ return false ;
274
344
}
275
- output.resize (output_size);
276
- if (!recv_data (sock->fd , output.data (), output_size)) {
345
+ if (!recv_data (sock->fd , output, output_size)) {
277
346
return false ;
278
347
}
279
348
return true ;
@@ -326,14 +395,9 @@ static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffe
326
395
327
396
static void ggml_backend_rpc_buffer_free_buffer (ggml_backend_buffer_t buffer) {
328
397
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context ;
329
- // input serialization format: | remote_ptr (8 bytes) |
330
- std::vector<uint8_t > input (sizeof (uint64_t ), 0 );
331
- uint64_t remote_ptr = ctx->remote_ptr ;
332
- memcpy (input.data (), &remote_ptr, sizeof (remote_ptr));
333
- std::vector<uint8_t > output;
334
- bool status = send_rpc_cmd (ctx->sock , RPC_CMD_FREE_BUFFER, input, output);
398
+ rpc_msg_free_buffer_req request = {ctx->remote_ptr };
399
+ bool status = send_rpc_cmd (ctx->sock , RPC_CMD_FREE_BUFFER, &request, sizeof (request), nullptr , 0 );
335
400
GGML_ASSERT (status);
336
- GGML_ASSERT (output.empty ());
337
401
delete ctx;
338
402
}
339
403
@@ -342,20 +406,13 @@ static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
342
406
if (ctx->base_cache .find (buffer) != ctx->base_cache .end ()) {
343
407
return ctx->base_cache [buffer];
344
408
}
345
- // input serialization format: | remote_ptr (8 bytes) |
346
- std::vector<uint8_t > input (sizeof (uint64_t ), 0 );
347
- uint64_t remote_ptr = ctx->remote_ptr ;
348
- memcpy (input.data (), &remote_ptr, sizeof (remote_ptr));
349
- std::vector<uint8_t > output;
350
- bool status = send_rpc_cmd (ctx->sock , RPC_CMD_BUFFER_GET_BASE, input, output);
409
+ rpc_msg_buffer_get_base_req request = {ctx->remote_ptr };
410
+ rpc_msg_buffer_get_base_rsp response;
411
+ bool status = send_rpc_cmd (ctx->sock , RPC_CMD_BUFFER_GET_BASE, &request, sizeof (request), &response, sizeof (response));
351
412
GGML_ASSERT (status);
352
- GGML_ASSERT (output.size () == sizeof (uint64_t ));
353
- // output serialization format: | base_ptr (8 bytes) |
354
- uint64_t base_ptr;
355
- memcpy (&base_ptr, output.data (), sizeof (base_ptr));
356
- void * base = reinterpret_cast <void *>(base_ptr);
357
- ctx->base_cache [buffer] = base;
358
- return base;
413
+ void * base_ptr = reinterpret_cast <void *>(response.base_ptr );
414
+ ctx->base_cache [buffer] = base_ptr;
415
+ return base_ptr;
359
416
}
360
417
361
418
static rpc_tensor serialize_tensor (const ggml_tensor * tensor) {
@@ -405,26 +462,18 @@ static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggm
405
462
memcpy (input.data (), &rpc_tensor, sizeof (rpc_tensor));
406
463
memcpy (input.data () + sizeof (rpc_tensor), &offset, sizeof (offset));
407
464
memcpy (input.data () + sizeof (rpc_tensor) + sizeof (offset), data, size);
408
- std::vector<uint8_t > output;
409
- bool status = send_rpc_cmd (ctx->sock , RPC_CMD_SET_TENSOR, input, output);
465
+ bool status = send_rpc_cmd (ctx->sock , RPC_CMD_SET_TENSOR, input.data (), input.size (), nullptr , 0 );
410
466
GGML_ASSERT (status);
411
467
}
412
468
413
469
static void ggml_backend_rpc_buffer_get_tensor (ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
414
470
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context ;
415
- // input serialization format: | rpc_tensor | offset (8 bytes) | size (8 bytes) |
416
- int input_size = sizeof (rpc_tensor) + 2 *sizeof (uint64_t );
417
- std::vector<uint8_t > input (input_size, 0 );
418
- rpc_tensor rpc_tensor = serialize_tensor (tensor);
419
- memcpy (input.data (), &rpc_tensor, sizeof (rpc_tensor));
420
- memcpy (input.data () + sizeof (rpc_tensor), &offset, sizeof (offset));
421
- memcpy (input.data () + sizeof (rpc_tensor) + sizeof (offset), &size, sizeof (size));
422
- std::vector<uint8_t > output;
423
- bool status = send_rpc_cmd (ctx->sock , RPC_CMD_GET_TENSOR, input, output);
471
+ rpc_msg_get_tensor_req request;
472
+ request.tensor = serialize_tensor (tensor);
473
+ request.offset = offset;
474
+ request.size = size;
475
+ bool status = send_rpc_cmd (ctx->sock , RPC_CMD_GET_TENSOR, &request, sizeof (request), data, size);
424
476
GGML_ASSERT (status);
425
- GGML_ASSERT (output.size () == size);
426
- // output serialization format: | data (size bytes) |
427
- memcpy (data, output.data (), size);
428
477
}
429
478
430
479
static bool ggml_backend_rpc_buffer_cpy_tensor (ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
@@ -437,30 +486,19 @@ static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, con
437
486
return false ;
438
487
}
439
488
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context ;
440
- // input serialization format: | rpc_tensor src | rpc_tensor dst |
441
- int input_size = 2 *sizeof (rpc_tensor);
442
- std::vector<uint8_t > input (input_size, 0 );
443
- rpc_tensor rpc_src = serialize_tensor (src);
444
- rpc_tensor rpc_dst = serialize_tensor (dst);
445
- memcpy (input.data (), &rpc_src, sizeof (rpc_src));
446
- memcpy (input.data () + sizeof (rpc_src), &rpc_dst, sizeof (rpc_dst));
447
- std::vector<uint8_t > output;
448
- bool status = send_rpc_cmd (ctx->sock , RPC_CMD_COPY_TENSOR, input, output);
489
+ rpc_msg_copy_tensor_req request;
490
+ request.src = serialize_tensor (src);
491
+ request.dst = serialize_tensor (dst);
492
+ rpc_msg_copy_tensor_rsp response;
493
+ bool status = send_rpc_cmd (ctx->sock , RPC_CMD_COPY_TENSOR, &request, sizeof (request), &response, sizeof (response));
449
494
GGML_ASSERT (status);
450
- // output serialization format: | result (1 byte) |
451
- GGML_ASSERT (output.size () == 1 );
452
- return output[0 ];
495
+ return response.result ;
453
496
}
454
497
455
498
static void ggml_backend_rpc_buffer_clear (ggml_backend_buffer_t buffer, uint8_t value) {
456
499
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context ;
457
- // serialization format: | bufptr (8 bytes) | value (1 byte) |
458
- int input_size = sizeof (uint64_t ) + sizeof (uint8_t );
459
- std::vector<uint8_t > input (input_size, 0 );
460
- memcpy (input.data (), &ctx->remote_ptr , sizeof (ctx->remote_ptr ));
461
- memcpy (input.data () + sizeof (ctx->remote_ptr ), &value, sizeof (value));
462
- std::vector<uint8_t > output;
463
- bool status = send_rpc_cmd (ctx->sock , RPC_CMD_BUFFER_CLEAR, input, output);
500
+ rpc_msg_buffer_clear_req request = {ctx->remote_ptr , value};
501
+ bool status = send_rpc_cmd (ctx->sock , RPC_CMD_BUFFER_CLEAR, &request, sizeof (request), nullptr , 0 );
464
502
GGML_ASSERT (status);
465
503
}
466
504
@@ -484,42 +522,27 @@ static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t
484
522
485
523
static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer (ggml_backend_buffer_type_t buft, size_t size) {
486
524
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context ;
487
- // input serialization format: | size (8 bytes) |
488
- int input_size = sizeof (uint64_t );
489
- std::vector<uint8_t > input (input_size, 0 );
490
- memcpy (input.data (), &size, sizeof (size));
491
- std::vector<uint8_t > output;
525
+ rpc_msg_alloc_buffer_req request = {size};
526
+ rpc_msg_alloc_buffer_rsp response;
492
527
auto sock = get_socket (buft_ctx->endpoint );
493
- bool status = send_rpc_cmd (sock, RPC_CMD_ALLOC_BUFFER, input, output );
528
+ bool status = send_rpc_cmd (sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof (request), &response, sizeof (response) );
494
529
GGML_ASSERT (status);
495
- GGML_ASSERT (output.size () == 2 *sizeof (uint64_t ));
496
- // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
497
- uint64_t remote_ptr;
498
- memcpy (&remote_ptr, output.data (), sizeof (remote_ptr));
499
- size_t remote_size;
500
- memcpy (&remote_size, output.data () + sizeof (uint64_t ), sizeof (remote_size));
501
- if (remote_ptr != 0 ) {
530
+ if (response.remote_ptr != 0 ) {
502
531
ggml_backend_buffer_t buffer = ggml_backend_buffer_init (buft,
503
532
ggml_backend_rpc_buffer_interface,
504
- new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, " RPC[" + std::string (buft_ctx->endpoint ) + " ]" },
505
- remote_size);
533
+ new ggml_backend_rpc_buffer_context{sock, {}, response. remote_ptr , " RPC[" + std::string (buft_ctx->endpoint ) + " ]" },
534
+ response. remote_size );
506
535
return buffer;
507
536
} else {
508
537
return nullptr ;
509
538
}
510
539
}
511
540
512
541
static size_t get_alignment (const std::shared_ptr<socket_t > & sock) {
513
- // input serialization format: | 0 bytes |
514
- std::vector<uint8_t > input;
515
- std::vector<uint8_t > output;
516
- bool status = send_rpc_cmd (sock, RPC_CMD_GET_ALIGNMENT, input, output);
542
+ rpc_msg_get_alignment_rsp response;
543
+ bool status = send_rpc_cmd (sock, RPC_CMD_GET_ALIGNMENT, nullptr , 0 , &response, sizeof (response));
517
544
GGML_ASSERT (status);
518
- GGML_ASSERT (output.size () == sizeof (uint64_t ));
519
- // output serialization format: | alignment (8 bytes) |
520
- uint64_t alignment;
521
- memcpy (&alignment, output.data (), sizeof (alignment));
522
- return alignment;
545
+ return response.alignment ;
523
546
}
524
547
525
548
static size_t ggml_backend_rpc_buffer_type_get_alignment (ggml_backend_buffer_type_t buft) {
@@ -528,16 +551,10 @@ static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_typ
528
551
}
529
552
530
553
static size_t get_max_size (const std::shared_ptr<socket_t > & sock) {
531
- // input serialization format: | 0 bytes |
532
- std::vector<uint8_t > input;
533
- std::vector<uint8_t > output;
534
- bool status = send_rpc_cmd (sock, RPC_CMD_GET_MAX_SIZE, input, output);
554
+ rpc_msg_get_max_size_rsp response;
555
+ bool status = send_rpc_cmd (sock, RPC_CMD_GET_MAX_SIZE, nullptr , 0 , &response, sizeof (response));
535
556
GGML_ASSERT (status);
536
- GGML_ASSERT (output.size () == sizeof (uint64_t ));
537
- // output serialization format: | max_size (8 bytes) |
538
- uint64_t max_size;
539
- memcpy (&max_size, output.data (), sizeof (max_size));
540
- return max_size;
557
+ return response.max_size ;
541
558
}
542
559
543
560
static size_t ggml_backend_rpc_get_max_size (ggml_backend_buffer_type_t buft) {
@@ -622,12 +639,11 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g
622
639
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context ;
623
640
std::vector<uint8_t > input;
624
641
serialize_graph (cgraph, input);
625
- std::vector< uint8_t > output ;
642
+ rpc_msg_graph_compute_rsp response ;
626
643
auto sock = get_socket (rpc_ctx->endpoint );
627
- bool status = send_rpc_cmd (sock, RPC_CMD_GRAPH_COMPUTE, input, output );
644
+ bool status = send_rpc_cmd (sock, RPC_CMD_GRAPH_COMPUTE, input. data (), input. size (), &response, sizeof (response) );
628
645
GGML_ASSERT (status);
629
- GGML_ASSERT (output.size () == 1 );
630
- return (enum ggml_status)output[0 ];
646
+ return (enum ggml_status)response.result ;
631
647
}
632
648
633
649
static ggml_backend_i ggml_backend_rpc_interface = {
@@ -702,19 +718,11 @@ GGML_API bool ggml_backend_is_rpc(ggml_backend_t backend) {
702
718
}
703
719
704
720
static void get_device_memory (const std::shared_ptr<socket_t > & sock, size_t * free, size_t * total) {
705
- // input serialization format: | 0 bytes |
706
- std::vector<uint8_t > input;
707
- std::vector<uint8_t > output;
708
- bool status = send_rpc_cmd (sock, RPC_CMD_GET_DEVICE_MEMORY, input, output);
721
+ rpc_msg_get_device_memory_rsp response;
722
+ bool status = send_rpc_cmd (sock, RPC_CMD_GET_DEVICE_MEMORY, nullptr , 0 , &response, sizeof (response));
709
723
GGML_ASSERT (status);
710
- GGML_ASSERT (output.size () == 2 *sizeof (uint64_t ));
711
- // output serialization format: | free (8 bytes) | total (8 bytes) |
712
- uint64_t free_mem;
713
- memcpy (&free_mem, output.data (), sizeof (free_mem));
714
- uint64_t total_mem;
715
- memcpy (&total_mem, output.data () + sizeof (uint64_t ), sizeof (total_mem));
716
- *free = free_mem;
717
- *total = total_mem;
724
+ *free = response.free_mem ;
725
+ *total = response.total_mem ;
718
726
}
719
727
720
728
GGML_API void ggml_backend_rpc_get_device_memory (const char * endpoint, size_t * free, size_t * total) {
0 commit comments