Spaces:
Running
Running
rpc : add RPC_CMD_HELLO (llama/12955)
Browse filesAdd RPC_CMD_HELLO for getting the version of the protocol implemend by
the server. Follow the semantic versioning rules at https://semver.org
Hopefully this bring better user experience when we make breaking
changes at the protocol level and avoid issues like #12465
- ggml/include/ggml-rpc.h +3 -0
- ggml/src/ggml-rpc/ggml-rpc.cpp +53 -1
ggml/include/ggml-rpc.h
CHANGED
|
@@ -7,6 +7,9 @@
|
|
| 7 |
extern "C" {
|
| 8 |
#endif
|
| 9 |
|
|
|
|
|
|
|
|
|
|
| 10 |
#define GGML_RPC_MAX_SERVERS 16
|
| 11 |
|
| 12 |
// backend API
|
|
|
|
| 7 |
extern "C" {
|
| 8 |
#endif
|
| 9 |
|
| 10 |
+
#define RPC_PROTO_MAJOR_VERSION 1
|
| 11 |
+
#define RPC_PROTO_MINOR_VERSION 0
|
| 12 |
+
#define RPC_PROTO_PATCH_VERSION 0
|
| 13 |
#define GGML_RPC_MAX_SERVERS 16
|
| 14 |
|
| 15 |
// backend API
|
ggml/src/ggml-rpc/ggml-rpc.cpp
CHANGED
|
@@ -92,12 +92,19 @@ enum rpc_cmd {
|
|
| 92 |
RPC_CMD_GET_DEVICE_MEMORY,
|
| 93 |
RPC_CMD_INIT_TENSOR,
|
| 94 |
RPC_CMD_GET_ALLOC_SIZE,
|
|
|
|
| 95 |
RPC_CMD_COUNT,
|
| 96 |
};
|
| 97 |
|
| 98 |
// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
|
| 99 |
const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
|
| 100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
struct rpc_msg_get_alloc_size_req {
|
| 102 |
rpc_tensor tensor;
|
| 103 |
};
|
|
@@ -400,6 +407,20 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
|
|
| 400 |
|
| 401 |
// RPC client-side implementation
|
| 402 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 403 |
static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
|
| 404 |
static std::mutex mutex;
|
| 405 |
std::lock_guard<std::mutex> lock(mutex);
|
|
@@ -433,6 +454,9 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
|
|
| 433 |
if (sock == nullptr) {
|
| 434 |
return nullptr;
|
| 435 |
}
|
|
|
|
|
|
|
|
|
|
| 436 |
GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
|
| 437 |
sockets[endpoint] = sock;
|
| 438 |
return sock;
|
|
@@ -818,6 +842,7 @@ public:
|
|
| 818 |
}
|
| 819 |
~rpc_server();
|
| 820 |
|
|
|
|
| 821 |
void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
|
| 822 |
void get_alignment(rpc_msg_get_alignment_rsp & response);
|
| 823 |
void get_max_size(rpc_msg_get_max_size_rsp & response);
|
|
@@ -846,6 +871,13 @@ private:
|
|
| 846 |
std::unordered_set<ggml_backend_buffer_t> buffers;
|
| 847 |
};
|
| 848 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 849 |
bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
|
| 850 |
ggml_backend_buffer_type_t buft;
|
| 851 |
struct ggml_init_params params {
|
|
@@ -1271,8 +1303,24 @@ rpc_server::~rpc_server() {
|
|
| 1271 |
static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
|
| 1272 |
sockfd_t sockfd, size_t free_mem, size_t total_mem) {
|
| 1273 |
rpc_server server(backend, cache_dir);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1274 |
while (true) {
|
| 1275 |
-
uint8_t cmd;
|
| 1276 |
if (!recv_data(sockfd, &cmd, 1)) {
|
| 1277 |
break;
|
| 1278 |
}
|
|
@@ -1282,6 +1330,10 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
|
|
| 1282 |
break;
|
| 1283 |
}
|
| 1284 |
switch (cmd) {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1285 |
case RPC_CMD_ALLOC_BUFFER: {
|
| 1286 |
rpc_msg_alloc_buffer_req request;
|
| 1287 |
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
|
|
|
| 92 |
RPC_CMD_GET_DEVICE_MEMORY,
|
| 93 |
RPC_CMD_INIT_TENSOR,
|
| 94 |
RPC_CMD_GET_ALLOC_SIZE,
|
| 95 |
+
RPC_CMD_HELLO,
|
| 96 |
RPC_CMD_COUNT,
|
| 97 |
};
|
| 98 |
|
| 99 |
// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
|
| 100 |
const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
|
| 101 |
|
| 102 |
+
struct rpc_msg_hello_rsp {
|
| 103 |
+
uint8_t major;
|
| 104 |
+
uint8_t minor;
|
| 105 |
+
uint8_t patch;
|
| 106 |
+
};
|
| 107 |
+
|
| 108 |
struct rpc_msg_get_alloc_size_req {
|
| 109 |
rpc_tensor tensor;
|
| 110 |
};
|
|
|
|
| 407 |
|
| 408 |
// RPC client-side implementation
|
| 409 |
|
| 410 |
+
static bool check_server_version(const std::shared_ptr<socket_t> & sock) {
|
| 411 |
+
rpc_msg_hello_rsp response;
|
| 412 |
+
bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response));
|
| 413 |
+
GGML_ASSERT(status);
|
| 414 |
+
if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) {
|
| 415 |
+
fprintf(stderr, "RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
|
| 416 |
+
return false;
|
| 417 |
+
}
|
| 418 |
+
if (response.minor != RPC_PROTO_MINOR_VERSION || response.patch != RPC_PROTO_PATCH_VERSION) {
|
| 419 |
+
fprintf(stderr, "WARNING: RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
|
| 420 |
+
}
|
| 421 |
+
return true;
|
| 422 |
+
}
|
| 423 |
+
|
| 424 |
static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
|
| 425 |
static std::mutex mutex;
|
| 426 |
std::lock_guard<std::mutex> lock(mutex);
|
|
|
|
| 454 |
if (sock == nullptr) {
|
| 455 |
return nullptr;
|
| 456 |
}
|
| 457 |
+
if (!check_server_version(sock)) {
|
| 458 |
+
return nullptr;
|
| 459 |
+
}
|
| 460 |
GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
|
| 461 |
sockets[endpoint] = sock;
|
| 462 |
return sock;
|
|
|
|
| 842 |
}
|
| 843 |
~rpc_server();
|
| 844 |
|
| 845 |
+
void hello(rpc_msg_hello_rsp & response);
|
| 846 |
void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
|
| 847 |
void get_alignment(rpc_msg_get_alignment_rsp & response);
|
| 848 |
void get_max_size(rpc_msg_get_max_size_rsp & response);
|
|
|
|
| 871 |
std::unordered_set<ggml_backend_buffer_t> buffers;
|
| 872 |
};
|
| 873 |
|
| 874 |
+
void rpc_server::hello(rpc_msg_hello_rsp & response) {
|
| 875 |
+
response.major = RPC_PROTO_MAJOR_VERSION;
|
| 876 |
+
response.minor = RPC_PROTO_MINOR_VERSION;
|
| 877 |
+
response.patch = RPC_PROTO_PATCH_VERSION;
|
| 878 |
+
GGML_PRINT_DEBUG("[%s] version: %d.%d.%d\n", __func__, response.major, response.minor, response.patch);
|
| 879 |
+
}
|
| 880 |
+
|
| 881 |
bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
|
| 882 |
ggml_backend_buffer_type_t buft;
|
| 883 |
struct ggml_init_params params {
|
|
|
|
| 1303 |
static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
|
| 1304 |
sockfd_t sockfd, size_t free_mem, size_t total_mem) {
|
| 1305 |
rpc_server server(backend, cache_dir);
|
| 1306 |
+
uint8_t cmd;
|
| 1307 |
+
if (!recv_data(sockfd, &cmd, 1)) {
|
| 1308 |
+
return;
|
| 1309 |
+
}
|
| 1310 |
+
// the first command sent by the client must be HELLO
|
| 1311 |
+
if (cmd != RPC_CMD_HELLO) {
|
| 1312 |
+
fprintf(stderr, "Expected HELLO command, update client\n");
|
| 1313 |
+
return;
|
| 1314 |
+
}
|
| 1315 |
+
if (!recv_msg(sockfd, nullptr, 0)) {
|
| 1316 |
+
return;
|
| 1317 |
+
}
|
| 1318 |
+
rpc_msg_hello_rsp response;
|
| 1319 |
+
server.hello(response);
|
| 1320 |
+
if (!send_msg(sockfd, &response, sizeof(response))) {
|
| 1321 |
+
return;
|
| 1322 |
+
}
|
| 1323 |
while (true) {
|
|
|
|
| 1324 |
if (!recv_data(sockfd, &cmd, 1)) {
|
| 1325 |
break;
|
| 1326 |
}
|
|
|
|
| 1330 |
break;
|
| 1331 |
}
|
| 1332 |
switch (cmd) {
|
| 1333 |
+
case RPC_CMD_HELLO: {
|
| 1334 |
+
// HELLO command is handled above
|
| 1335 |
+
return;
|
| 1336 |
+
}
|
| 1337 |
case RPC_CMD_ALLOC_BUFFER: {
|
| 1338 |
rpc_msg_alloc_buffer_req request;
|
| 1339 |
if (!recv_msg(sockfd, &request, sizeof(request))) {
|