rgerganov commited on
Commit
ff22836
·
1 Parent(s): fb0d243

rpc : add RPC_CMD_HELLO (llama/12955)

Browse files

Add RPC_CMD_HELLO for getting the version of the protocol implemented by
the server. The version follows the semantic versioning rules at https://semver.org

Hopefully this brings a better user experience when we make breaking
changes at the protocol level and avoids issues like #12465

ggml/include/ggml-rpc.h CHANGED
@@ -7,6 +7,9 @@
7
  extern "C" {
8
  #endif
9
 
 
 
 
10
  #define GGML_RPC_MAX_SERVERS 16
11
 
12
  // backend API
 
7
  extern "C" {
8
  #endif
9
 
10
// Version of the RPC protocol (major.minor.patch). The server reports these
// values in its reply to RPC_CMD_HELLO and the client compares them with its own.
+ #define RPC_PROTO_MAJOR_VERSION 1
11
+ #define RPC_PROTO_MINOR_VERSION 0
12
+ #define RPC_PROTO_PATCH_VERSION 0
13
  #define GGML_RPC_MAX_SERVERS 16
14
 
15
  // backend API
ggml/src/ggml-rpc/ggml-rpc.cpp CHANGED
@@ -92,12 +92,19 @@ enum rpc_cmd {
92
  RPC_CMD_GET_DEVICE_MEMORY,
93
  RPC_CMD_INIT_TENSOR,
94
  RPC_CMD_GET_ALLOC_SIZE,
 
95
  RPC_CMD_COUNT,
96
  };
97
 
98
  // Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
99
  const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
100
 
 
 
 
 
 
 
101
  struct rpc_msg_get_alloc_size_req {
102
  rpc_tensor tensor;
103
  };
@@ -400,6 +407,20 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
400
 
401
  // RPC client-side implementation
402
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
403
  static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
404
  static std::mutex mutex;
405
  std::lock_guard<std::mutex> lock(mutex);
@@ -433,6 +454,9 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
433
  if (sock == nullptr) {
434
  return nullptr;
435
  }
 
 
 
436
  GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
437
  sockets[endpoint] = sock;
438
  return sock;
@@ -818,6 +842,7 @@ public:
818
  }
819
  ~rpc_server();
820
 
 
821
  void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
822
  void get_alignment(rpc_msg_get_alignment_rsp & response);
823
  void get_max_size(rpc_msg_get_max_size_rsp & response);
@@ -846,6 +871,13 @@ private:
846
  std::unordered_set<ggml_backend_buffer_t> buffers;
847
  };
848
 
 
 
 
 
 
 
 
849
  bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
850
  ggml_backend_buffer_type_t buft;
851
  struct ggml_init_params params {
@@ -1271,8 +1303,24 @@ rpc_server::~rpc_server() {
1271
  static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
1272
  sockfd_t sockfd, size_t free_mem, size_t total_mem) {
1273
  rpc_server server(backend, cache_dir);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1274
  while (true) {
1275
- uint8_t cmd;
1276
  if (!recv_data(sockfd, &cmd, 1)) {
1277
  break;
1278
  }
@@ -1282,6 +1330,10 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
1282
  break;
1283
  }
1284
  switch (cmd) {
 
 
 
 
1285
  case RPC_CMD_ALLOC_BUFFER: {
1286
  rpc_msg_alloc_buffer_req request;
1287
  if (!recv_msg(sockfd, &request, sizeof(request))) {
 
92
  RPC_CMD_GET_DEVICE_MEMORY,
93
  RPC_CMD_INIT_TENSOR,
94
  RPC_CMD_GET_ALLOC_SIZE,
95
+ RPC_CMD_HELLO,
96
  RPC_CMD_COUNT,
97
  };
98
 
99
  // Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
100
  const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
101
 
102
// Reply payload for RPC_CMD_HELLO: the protocol version implemented by the server.
// The server fills it from RPC_PROTO_{MAJOR,MINOR,PATCH}_VERSION; the client
// receives it as raw bytes of sizeof(rpc_msg_hello_rsp).
+ struct rpc_msg_hello_rsp {
103
+ uint8_t major; // compared for equality by the client; a mismatch rejects the connection
104
+ uint8_t minor; // a server minor newer than the client's rejects the connection, an older one only warns
105
+ uint8_t patch; // a difference only produces a warning on the client
106
+ };
107
+
108
  struct rpc_msg_get_alloc_size_req {
109
  rpc_tensor tensor;
110
  };
 
407
 
408
  // RPC client-side implementation
409
 
410
+ static bool check_server_version(const std::shared_ptr<socket_t> & sock) {
411
+ rpc_msg_hello_rsp response;
412
+ bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response));
413
+ GGML_ASSERT(status);
414
+ if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) {
415
+ fprintf(stderr, "RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
416
+ return false;
417
+ }
418
+ if (response.minor != RPC_PROTO_MINOR_VERSION || response.patch != RPC_PROTO_PATCH_VERSION) {
419
+ fprintf(stderr, "WARNING: RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
420
+ }
421
+ return true;
422
+ }
423
+
424
  static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
425
  static std::mutex mutex;
426
  std::lock_guard<std::mutex> lock(mutex);
 
454
  if (sock == nullptr) {
455
  return nullptr;
456
  }
457
+ if (!check_server_version(sock)) {
458
+ return nullptr;
459
+ }
460
  GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
461
  sockets[endpoint] = sock;
462
  return sock;
 
842
  }
843
  ~rpc_server();
844
 
845
+ void hello(rpc_msg_hello_rsp & response);
846
  void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
847
  void get_alignment(rpc_msg_get_alignment_rsp & response);
848
  void get_max_size(rpc_msg_get_max_size_rsp & response);
 
871
  std::unordered_set<ggml_backend_buffer_t> buffers;
872
  };
873
 
874
// Handles RPC_CMD_HELLO on the server side: writes the protocol version this
// server was compiled with (RPC_PROTO_*_VERSION) into `response` and emits it
// through GGML_PRINT_DEBUG. The caller is responsible for sending `response`.
+ void rpc_server::hello(rpc_msg_hello_rsp & response) {
875
+ response.major = RPC_PROTO_MAJOR_VERSION;
876
+ response.minor = RPC_PROTO_MINOR_VERSION;
877
+ response.patch = RPC_PROTO_PATCH_VERSION;
878
+ GGML_PRINT_DEBUG("[%s] version: %d.%d.%d\n", __func__, response.major, response.minor, response.patch);
879
+ }
880
+
881
  bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
882
  ggml_backend_buffer_type_t buft;
883
  struct ggml_init_params params {
 
1303
  static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
1304
  sockfd_t sockfd, size_t free_mem, size_t total_mem) {
1305
  rpc_server server(backend, cache_dir);
1306
+ uint8_t cmd;
1307
+ if (!recv_data(sockfd, &cmd, 1)) {
1308
+ return;
1309
+ }
1310
+ // the first command sent by the client must be HELLO
1311
+ if (cmd != RPC_CMD_HELLO) {
1312
+ fprintf(stderr, "Expected HELLO command, update client\n");
1313
+ return;
1314
+ }
1315
+ if (!recv_msg(sockfd, nullptr, 0)) {
1316
+ return;
1317
+ }
1318
+ rpc_msg_hello_rsp response;
1319
+ server.hello(response);
1320
+ if (!send_msg(sockfd, &response, sizeof(response))) {
1321
+ return;
1322
+ }
1323
  while (true) {
 
1324
  if (!recv_data(sockfd, &cmd, 1)) {
1325
  break;
1326
  }
 
1330
  break;
1331
  }
1332
  switch (cmd) {
1333
+ case RPC_CMD_HELLO: {
1334
+ // HELLO command is handled above
1335
+ return;
1336
+ }
1337
  case RPC_CMD_ALLOC_BUFFER: {
1338
  rpc_msg_alloc_buffer_req request;
1339
  if (!recv_msg(sockfd, &request, sizeof(request))) {