1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file rpc_session.cc
22 * \brief RPC session for remote function call.
23 */
24 #include <tvm/runtime/packed_func.h>
25 #include <tvm/runtime/device_api.h>
26 #include <tvm/runtime/registry.h>
27 #include <tvm/runtime/serializer.h>
28 #include <memory>
29 #include <array>
30 #include <string>
31 #include <chrono>
32 #include <vector>
33 #include <utility>
34 #include <cmath>
35 #include <algorithm>
36 #include "rpc_session.h"
37 #include "../object_internal.h"
38 #include "../../common/ring_buffer.h"
39 #include "../../common/socket.h"
40
41 namespace tvm {
42 namespace runtime {
/*!
 * \brief Temporary storage backing a received string or byte-array argument.
 */
struct RPCByteArrayBuffer {
  // Byte-array view; the receive path points it at `data` for kBytes arguments.
  TVMByteArray arr;
  // Owned copy of the received payload bytes.
  std::string data;
};
/*!
 * \brief Temporary storage backing a received tensor (kArrayHandle) argument.
 */
struct RPCDataArrayBuffer {
  // Tensor descriptor; the receive path points tensor.shape at `shape`.
  DLTensor tensor;
  // Owned storage for the tensor's shape.
  std::vector<int64_t> shape;
};
53 /*!
54 * \brief Temporal argument buffer.
55 */
56 struct RPCArgBuffer {
57 // The argument values
58 std::vector<TVMValue> value;
59 // The type codes.
60 std::vector<int> tcode;
61 // Temporal resources.
62 std::vector<std::unique_ptr<RPCByteArrayBuffer> > temp_bytes;
63 // Temporal array
64 std::vector<std::unique_ptr<RPCDataArrayBuffer> > temp_array;
65 // convert buffer as TVMArgs
AsTVMArgstvm::runtime::RPCArgBuffer66 TVMArgs AsTVMArgs() const {
67 return TVMArgs(value.data(), tcode.data(), static_cast<int>(value.size()));
68 }
69 };
70
71 // Event handler for RPC events.
72 class RPCSession::EventHandler : public dmlc::Stream {
73 public:
/*!
 * \brief Create an event handler bound to a pair of ring buffers.
 * \param reader Buffer that incoming bytes are consumed from.
 * \param writer Buffer that outgoing bytes are appended to.
 * \param rpc_sess_table_index Session table index, used by StripSessMask to
 *        validate contexts.
 * \param name Name of the session.
 * \param remote_key Remote key string. When it holds "%toinit" it is cleared
 *        and the handler starts by reading the key from the initial header.
 */
EventHandler(common::RingBuffer* reader,
             common::RingBuffer* writer,
             int rpc_sess_table_index,
             std::string name,
             std::string* remote_key)
    : reader_(reader),
      writer_(writer),
      rpc_sess_table_index_(rpc_sess_table_index),
      // `name` is a by-value sink parameter: move it instead of copying again.
      name_(std::move(name)),
      remote_key_(remote_key) {
  this->Clear();
  if (*remote_key == "%toinit") {
    state_ = kInitHeader;
    remote_key_->resize(0);
    // The header starts with a 32-bit length of the key.
    pending_request_bytes_ = sizeof(int32_t);
  }
}
91 // Bytes needed to fulfill current request
BytesNeeded()92 size_t BytesNeeded() {
93 if (reader_->bytes_available() < pending_request_bytes_) {
94 return pending_request_bytes_ - reader_->bytes_available();
95 } else {
96 return 0;
97 }
98 }
99 // Request number of bytes from reader.
RequestBytes(size_t nbytes)100 void RequestBytes(size_t nbytes) {
101 pending_request_bytes_ += nbytes;
102 reader_->Reserve(pending_request_bytes_);
103 }
104 // Whether we are ready to handle next request.
Ready()105 bool Ready() {
106 return reader_->bytes_available() >= pending_request_bytes_;
107 }
CanCleanShutdown() const108 bool CanCleanShutdown() const {
109 return state_ == kRecvCode;
110 }
FinishCopyAck()111 void FinishCopyAck() {
112 this->SwitchToState(kRecvCode);
113 }
HandleNextEvent(TVMRetValue * rv,bool client_mode,const PackedFunc * fwrap)114 RPCCode HandleNextEvent(TVMRetValue* rv,
115 bool client_mode,
116 const PackedFunc* fwrap) {
117 std::swap(client_mode_, client_mode);
118 while (this->Ready()) {
119 switch (state_) {
120 case kInitHeader: HandleInitHeader(); break;
121 case kRecvCode: HandleRecvCode(); break;
122 case kRecvCallHandle: {
123 CHECK(this->Read(&call_handle_));
124 this->SwitchToState(kRecvPackedSeqNumArgs);
125 break;
126 }
127 case kRecvPackedSeqNumArgs: {
128 CHECK(this->Read(&num_packed_args_));
129 arg_buf_.reset(new RPCArgBuffer());
130 arg_buf_->value.resize(num_packed_args_);
131 arg_buf_->tcode.resize(num_packed_args_);
132 this->SwitchToState(kRecvPackedSeqTypeCode);
133 break;
134 }
135 case kRecvPackedSeqTypeCode: {
136 if (num_packed_args_ != 0) {
137 this->ReadArray(arg_buf_->tcode.data(), num_packed_args_);
138 }
139 arg_index_ = 0;
140 arg_recv_stage_ = 0;
141 this->SwitchToState(kRecvPackedSeqArg);
142 break;
143 }
144 case kRecvPackedSeqArg: {
145 this->HandleRecvPackedSeqArg();
146 break;
147 }
148 case kDoCopyFromRemote: {
149 this->HandleCopyFromRemote();
150 break;
151 }
152 case kDoCopyToRemote: {
153 this->HandleCopyToRemote();
154 break;
155 }
156 case kReturnReceived: {
157 CHECK_GE(arg_buf_->value.size(), 1U);
158
159 TVMArgValue argv = arg_buf_->AsTVMArgs()[0];
160 if (argv.type_code() == kFuncHandle ||
161 argv.type_code() == kModuleHandle ||
162 argv.type_code() == kArrayHandle) {
163 CHECK(fwrap != nullptr) << "function/module wrapper not available";
164 fwrap->CallPacked(arg_buf_->AsTVMArgs(), rv);
165 } else {
166 CHECK_EQ(arg_buf_->value.size(), 1U);
167 *rv = argv;
168 }
169 arg_buf_.reset();
170 this->SwitchToState(kRecvCode);
171 std::swap(client_mode_, client_mode);
172 return RPCCode::kReturn;
173 }
174 case kCopyAckReceived: {
175 std::swap(client_mode_, client_mode);
176 return RPCCode::kCopyAck;
177 }
178 case kShutdownReceived: {
179 std::swap(client_mode_, client_mode);
180 return RPCCode::kShutdown;
181 }
182 }
183 }
184 std::swap(client_mode_, client_mode);
185 return RPCCode::kNone;
186 }
187 // Reset and clear all states.
Clear()188 void Clear() {
189 state_ = kRecvCode;
190 pending_request_bytes_ = sizeof(RPCCode);
191 arg_recv_stage_ = 0;
192 arg_buf_.reset();
193 }
194 // strip session on mask
StripSessMask(TVMContext ctx)195 TVMContext StripSessMask(TVMContext ctx) {
196 int dev_type = ctx.device_type;
197 CHECK_EQ(dev_type / kRPCSessMask, rpc_sess_table_index_ + 1)
198 << "Can not pass in local context or context with a different remote session";
199 ctx.device_type = static_cast<DLDeviceType>(dev_type % kRPCSessMask);
200 return ctx;
201 }
202 // Send Packed sequence to writer.
203 // return_ndarray is a special flag to handle returning of ndarray
204 // In this case, we return the shape, context and data of the array,
205 // as well as a customized PackedFunc that handles deletion of
206 // the array in the remote.
SendPackedSeq(const TVMValue * arg_values,const int * type_codes,int n,bool return_ndarray=false)207 void SendPackedSeq(const TVMValue* arg_values,
208 const int* type_codes,
209 int n,
210 bool return_ndarray = false) {
211 this->Write(n);
212 for (int i = 0; i < n; ++i) {
213 int tcode = type_codes[i];
214 if (tcode == kNDArrayContainer) tcode = kArrayHandle;
215 this->Write(tcode);
216 }
217
218 // Argument packing.
219 for (int i = 0; i < n; ++i) {
220 int tcode = type_codes[i];
221 TVMValue value = arg_values[i];
222 switch (tcode) {
223 case kDLInt:
224 case kDLUInt:
225 case kDLFloat: {
226 this->Write<int64_t>(value.v_int64);
227 break;
228 }
229 case kTVMType: {
230 this->Write(value.v_type);
231 // padding
232 int32_t padding = 0;
233 this->Write<int32_t>(padding);
234 break;
235 }
236 case kTVMContext: {
237 value.v_ctx = StripSessMask(value.v_ctx);
238 this->Write(value.v_ctx);
239 break;
240 }
241 case kFuncHandle:
242 case kModuleHandle:
243 case kHandle: {
244 // always send handle in 64 bit.
245 uint64_t handle = reinterpret_cast<uint64_t>(value.v_handle);
246 this->Write(handle);
247 break;
248 }
249 case kNDArrayContainer:
250 case kArrayHandle: {
251 DLTensor* arr = static_cast<DLTensor*>(value.v_handle);
252 TVMContext ctx;
253 uint64_t data;
254 if (!return_ndarray) {
255 // in the client mode
256 // ctx contains the remote table index
257 // the space is wrapped by an RemoteSpace
258 // that holds reference to the session.
259 ctx = StripSessMask(arr->ctx);
260 data = reinterpret_cast<uint64_t>(
261 static_cast<RemoteSpace*>(arr->data)->data);
262 } else {
263 // When we return NDArray, we directly return
264 // the space and the context
265 // The client will be further wrapping
266 ctx = arr->ctx;
267 data = reinterpret_cast<uint64_t>(arr->data);
268 }
269 this->Write(data);
270 this->Write(ctx);
271 this->Write(arr->ndim);
272 this->Write(arr->dtype);
273 this->WriteArray(arr->shape, arr->ndim);
274 CHECK(arr->strides == nullptr)
275 << "Do not support strided remote array";
276 CHECK_EQ(arr->byte_offset, 0)
277 << "Do not support send byte offset";
278 break;
279 }
280 case kNull: break;
281 case kStr: {
282 const char* s = value.v_str;
283 uint64_t len = strlen(s);
284 this->Write(len);
285 this->WriteArray(s, len);
286 break;
287 }
288 case kBytes: {
289 TVMByteArray* bytes = static_cast<TVMByteArray*>(arg_values[i].v_handle);
290 uint64_t len = bytes->size;
291 this->Write(len);
292 this->WriteArray(bytes->data, len);
293 break;
294 }
295 default: {
296 LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode);
297 break;
298 }
299 }
300 }
301 }
302
303 // Endian aware IO handling
304 using Stream::Read;
305 using Stream::Write;
306 using Stream::ReadArray;
307 using Stream::WriteArray;
308
Read(RPCCode * code)309 inline bool Read(RPCCode* code) {
310 int cdata;
311 if (!this->Read(&cdata)) return false;
312 *code = static_cast<RPCCode>(cdata);
313 return true;
314 }
Write(RPCCode code)315 inline void Write(RPCCode code) {
316 int cdata = static_cast<int>(code);
317 this->Write(cdata);
318 }
319
320 protected:
// States of the receive state machine.
enum State {
  // Reading the initial header that carries the remote key.
  kInitHeader,
  // Waiting for the next RPCCode.
  kRecvCode,
  // Reading the handle of the function to call.
  kRecvCallHandle,
  // Reading the number of packed arguments.
  kRecvPackedSeqNumArgs,
  // Reading the type codes of the packed arguments.
  kRecvPackedSeqTypeCode,
  // Reading the packed argument values one by one.
  kRecvPackedSeqArg,
  // Serving a copy-from-remote request.
  kDoCopyFromRemote,
  // Serving a copy-to-remote request.
  kDoCopyToRemote,
  // A returned value is buffered in arg_buf_ and ready to be handed back.
  kReturnReceived,
  // A copy acknowledgement code was received.
  kCopyAckReceived,
  // A shutdown code was received.
  kShutdownReceived
};
334 // Current state;
335 State state_;
336 // The RPCCode to be read.
337 RPCCode code_;
338 // Handle for the remote function call.
339 uint64_t call_handle_;
340 // Initialize remote header
341 bool init_header_step_{0};
342 // Number of packed arguments.
343 int num_packed_args_;
344 // Current argument index.
345 int arg_index_;
346 // The stage of each argument receiver.
347 int arg_recv_stage_;
348 // Whether current handler is client or server mode.
349 bool client_mode_{false};
350 // Argument buffer
351 std::unique_ptr<RPCArgBuffer> arg_buf_;
352 // Temp byte buffer.
353 std::unique_ptr<RPCByteArrayBuffer> temp_bytes_;
354 // Temp array buffer.
355 std::unique_ptr<RPCDataArrayBuffer> temp_array_;
356 // Internal temporal data space.
357 std::string temp_data_;
358 // Temp variables for copy request state.
359 TVMContext copy_ctx_;
360 TVMType copy_dtype_;
361 uint64_t copy_handle_, copy_offset_, copy_size_;
362 // State switcher
SwitchToState(State state)363 void SwitchToState(State state) {
364 // invariant
365 CHECK_EQ(pending_request_bytes_, 0U)
366 << "state=" << state;
367 state_ = state;
368 switch (state) {
369 case kInitHeader: {
370 LOG(FATAL) << "cannot switch to init header";
371 break;
372 }
373 case kRecvCode: {
374 this->RequestBytes(sizeof(RPCCode));
375 break;
376 }
377 case kRecvCallHandle: {
378 this->RequestBytes(sizeof(call_handle_));
379 break;
380 }
381 case kRecvPackedSeqNumArgs: {
382 this->RequestBytes(sizeof(num_packed_args_));
383 break;
384 }
385 case kRecvPackedSeqTypeCode: {
386 this->RequestBytes(sizeof(int) * num_packed_args_);
387 break;
388 }
389 case kRecvPackedSeqArg: {
390 CHECK_LE(arg_index_, num_packed_args_);
391 if (arg_index_ == num_packed_args_) {
392 // The function can change state_ again.
393 HandlePackedCall();
394 } else {
395 RequestRecvPackedSeqArg();
396 }
397 break;
398 }
399 case kDoCopyFromRemote: {
400 this->RequestBytes(sizeof(uint64_t) * 3);
401 this->RequestBytes(sizeof(TVMContext));
402 this->RequestBytes(sizeof(TVMType));
403 break;
404 }
405 case kDoCopyToRemote: {
406 this->RequestBytes(sizeof(uint64_t) * 3);
407 this->RequestBytes(sizeof(TVMContext));
408 this->RequestBytes(sizeof(TVMType));
409 break;
410 }
411 case kCopyAckReceived:
412 case kReturnReceived:
413 case kShutdownReceived: {
414 break;
415 }
416 }
417 }
418 // Requets bytes needed for next computation.
RequestRecvPackedSeqArg()419 void RequestRecvPackedSeqArg() {
420 CHECK_EQ(arg_recv_stage_, 0);
421 int tcode = arg_buf_->tcode[arg_index_];
422 static_assert(sizeof(TVMValue) == sizeof(uint64_t), "invariant");
423 switch (tcode) {
424 case kDLInt:
425 case kDLUInt:
426 case kDLFloat:
427 case kTVMType:
428 case kHandle:
429 case kStr:
430 case kBytes:
431 case kTVMContext: {
432 this->RequestBytes(sizeof(TVMValue)); break;
433 }
434 case kFuncHandle:
435 case kModuleHandle: {
436 CHECK(client_mode_)
437 << "Only client can receive remote functions";
438 this->RequestBytes(sizeof(TVMValue)); break;
439 }
440 case kNull: break;
441 case kArrayHandle: {
442 this->RequestBytes(sizeof(uint64_t));
443 this->RequestBytes(sizeof(TVMContext));
444 this->RequestBytes(sizeof(int));
445 this->RequestBytes(sizeof(DLDataType));
446 break;
447 }
448 default: {
449 LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode);
450 break;
451 }
452 }
453 }
454 // Handler for packed sequence argument receive.
HandleRecvPackedSeqArg()455 void HandleRecvPackedSeqArg() {
456 CHECK_LT(arg_index_, num_packed_args_);
457 int tcode = arg_buf_->tcode[arg_index_];
458 TVMValue& value = arg_buf_->value[arg_index_];
459 if (arg_recv_stage_ == 0) {
460 switch (tcode) {
461 case kDLInt:
462 case kDLUInt:
463 case kDLFloat: {
464 this->Read<int64_t>(&(value.v_int64));
465 ++arg_index_;
466 this->SwitchToState(kRecvPackedSeqArg);
467 break;
468 }
469 case kTVMType: {
470 this->Read(&(value.v_type));
471 int32_t padding = 0;
472 this->Read<int32_t>(&padding);
473 ++arg_index_;
474 this->SwitchToState(kRecvPackedSeqArg);
475 break;
476 }
477 case kTVMContext: {
478 this->Read(&(value.v_ctx));
479 ++arg_index_;
480 this->SwitchToState(kRecvPackedSeqArg);
481 break;
482 }
483 case kFuncHandle:
484 case kModuleHandle:
485 case kHandle: {
486 // always send handle in 64 bit.
487 uint64_t handle;
488 this->Read(&handle);
489 value.v_handle = reinterpret_cast<void*>(handle);
490 ++arg_index_;
491 this->SwitchToState(kRecvPackedSeqArg);
492 break;
493 }
494 case kNull: {
495 value.v_handle = nullptr;
496 ++arg_index_;
497 this->SwitchToState(kRecvPackedSeqArg);
498 break;
499 }
500 case kStr:
501 case kBytes: {
502 uint64_t len;
503 this->Read(&len);
504 temp_bytes_.reset( new RPCByteArrayBuffer());
505 temp_bytes_->data.resize(len);
506 arg_recv_stage_ = 1;
507 this->RequestBytes(len);
508 break;
509 }
510 case kArrayHandle: {
511 temp_array_.reset(new RPCDataArrayBuffer());
512 uint64_t handle;
513 this->Read(&handle);
514 DLTensor& tensor = temp_array_->tensor;
515 tensor.data = reinterpret_cast<void*>(handle);
516 this->Read(&(tensor.ctx));
517 this->Read(&(tensor.ndim));
518 this->Read(&(tensor.dtype));
519 temp_array_->shape.resize(tensor.ndim);
520 tensor.shape = temp_array_->shape.data();
521 arg_recv_stage_ = 1;
522 tensor.strides = nullptr;
523 tensor.byte_offset = 0;
524 this->RequestBytes(sizeof(int64_t) * tensor.ndim);
525 break;
526 }
527 default: {
528 LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode);
529 break;
530 }
531 }
532 } else {
533 CHECK_EQ(arg_recv_stage_, 1);
534 if (tcode == kStr || tcode == kBytes) {
535 if (temp_bytes_->data.size() != 0) {
536 this->ReadArray(&(temp_bytes_->data[0]), temp_bytes_->data.size());
537 }
538 if (tcode == kStr) {
539 value.v_str = temp_bytes_->data.c_str();
540 } else {
541 temp_bytes_->arr.size = static_cast<size_t>(temp_bytes_->data.size());
542 temp_bytes_->arr.data = dmlc::BeginPtr(temp_bytes_->data);
543 value.v_handle = &(temp_bytes_->arr);
544 }
545 arg_buf_->temp_bytes.emplace_back(std::move(temp_bytes_));
546 } else {
547 CHECK_EQ(tcode, kArrayHandle);
548 DLTensor& tensor = temp_array_->tensor;
549 this->ReadArray(tensor.shape, tensor.ndim);
550 value.v_handle = &tensor;
551 arg_buf_->temp_array.emplace_back(std::move(temp_array_));
552 }
553 ++arg_index_;
554 arg_recv_stage_ = 0;
555 this->SwitchToState(kRecvPackedSeqArg);
556 }
557 }
558 // handler for initial header read
HandleInitHeader()559 void HandleInitHeader() {
560 if (init_header_step_ == 0) {
561 int32_t len;
562 this->Read(&len);
563 remote_key_->resize(len);
564 init_header_step_ = 1;
565 this->RequestBytes(len);
566 return;
567 } else {
568 CHECK_EQ(init_header_step_, 1);
569 this->ReadArray(dmlc::BeginPtr(*remote_key_), remote_key_->length());
570 this->SwitchToState(kRecvCode);
571 }
572 }
573 // Handler for read code.
HandleRecvCode()574 void HandleRecvCode() {
575 this->Read(&code_);
576 if (code_ > RPCCode::kSystemFuncStart) {
577 SwitchToState(kRecvPackedSeqNumArgs);
578 return;
579 }
580 // invariant.
581 CHECK_EQ(arg_recv_stage_, 0);
582 switch (code_) {
583 case RPCCode::kCallFunc: {
584 SwitchToState(kRecvCallHandle);
585 break;
586 }
587 case RPCCode::kException:
588 case RPCCode::kReturn: {
589 SwitchToState(kRecvPackedSeqNumArgs);
590 break;
591 }
592 case RPCCode::kCopyFromRemote: {
593 SwitchToState(kDoCopyFromRemote);
594 break;
595 }
596 case RPCCode::kCopyToRemote: {
597 SwitchToState(kDoCopyToRemote);
598 break;
599 }
600 case RPCCode::kShutdown: {
601 SwitchToState(kShutdownReceived);
602 break;
603 }
604 case RPCCode::kCopyAck: {
605 SwitchToState(kCopyAckReceived);
606 break;
607 }
608 default: LOG(FATAL) << "Unknown event " << static_cast<int>(code_);
609 }
610 }
611
HandleCopyFromRemote()612 void HandleCopyFromRemote() {
613 uint64_t handle, offset, num_bytes;
614 TVMContext ctx;
615 TVMType type_hint;
616 this->Read(&handle);
617 this->Read(&offset);
618 this->Read(&num_bytes);
619 this->Read(&ctx);
620 this->Read(&type_hint);
621 size_t elem_bytes = (type_hint.bits * type_hint.lanes + 7) / 8;
622
623 if (ctx.device_type == kDLCPU) {
624 RPCCode code = RPCCode::kCopyAck;
625 this->Write(code);
626 char* dptr = reinterpret_cast<char*>(handle) + offset;
627 if (!DMLC_IO_NO_ENDIAN_SWAP) {
628 temp_data_.resize(0);
629 temp_data_.insert(temp_data_.end(), dptr, dptr + num_bytes);
630 dmlc::ByteSwap(dmlc::BeginPtr(temp_data_), elem_bytes, num_bytes / elem_bytes);
631 this->WriteArray(temp_data_.data(), num_bytes);
632 } else {
633 this->WriteArray(dptr, num_bytes);
634 }
635 } else {
636 temp_data_.resize(num_bytes + 1);
637 try {
638 TVMContext cpu_ctx;
639 cpu_ctx.device_type = kDLCPU;
640 cpu_ctx.device_id = 0;
641 DeviceAPI::Get(ctx)->CopyDataFromTo(
642 reinterpret_cast<void*>(handle), offset,
643 dmlc::BeginPtr(temp_data_), 0,
644 num_bytes, ctx, cpu_ctx, type_hint, nullptr);
645 RPCCode code = RPCCode::kCopyAck;
646 this->Write(code);
647 if (!DMLC_IO_NO_ENDIAN_SWAP) {
648 dmlc::ByteSwap(dmlc::BeginPtr(temp_data_), elem_bytes, num_bytes / elem_bytes);
649 }
650 this->WriteArray(&temp_data_[0], num_bytes);
651 } catch (const std::runtime_error &e) {
652 RPCCode code = RPCCode::kException;
653 this->Write(code);
654 TVMValue ret_value;
655 ret_value.v_str = e.what();
656 int ret_tcode = kStr;
657 SendPackedSeq(&ret_value, &ret_tcode, 1);
658 }
659 }
660 this->SwitchToState(kRecvCode);
661 }
662
HandleCopyToRemote()663 void HandleCopyToRemote() {
664 // use static variable to persist state.
665 // This only works if next stage is immediately after this.
666 if (arg_recv_stage_ == 0) {
667 CHECK(this->Read(©_handle_));
668 CHECK(this->Read(©_offset_));
669 CHECK(this->Read(©_size_));
670 CHECK(this->Read(©_ctx_));
671 CHECK(this->Read(©_dtype_));
672 arg_recv_stage_ = 1;
673 CHECK_EQ(pending_request_bytes_, 0U);
674 this->RequestBytes(copy_size_);
675 } else {
676 CHECK_EQ(arg_recv_stage_, 1);
677 TVMValue ret_value;
678 ret_value.v_handle = nullptr;
679 int ret_tcode = kNull;
680 RPCCode code = RPCCode::kReturn;
681 std::string errmsg;
682
683 size_t elem_bytes = (copy_dtype_.bits * copy_dtype_.lanes + 7) / 8;
684 if (copy_ctx_.device_type == kDLCPU) {
685 char* dptr = reinterpret_cast<char*>(copy_handle_) + copy_offset_;
686 this->ReadArray(dptr, copy_size_);
687 if (!DMLC_IO_NO_ENDIAN_SWAP) {
688 dmlc::ByteSwap(dptr, elem_bytes, copy_size_ / elem_bytes);
689 }
690 } else {
691 temp_data_.resize(copy_size_ + 1);
692 this->ReadArray(&temp_data_[0], copy_size_);
693 if (!DMLC_IO_NO_ENDIAN_SWAP) {
694 dmlc::ByteSwap(dmlc::BeginPtr(temp_data_), elem_bytes, copy_size_ / elem_bytes);
695 }
696 try {
697 TVMContext cpu_ctx;
698 cpu_ctx.device_type = kDLCPU;
699 cpu_ctx.device_id = 0;
700 DeviceAPI::Get(copy_ctx_)->CopyDataFromTo(
701 temp_data_.data(), 0,
702 reinterpret_cast<void*>(copy_handle_), copy_offset_,
703 copy_size_, cpu_ctx, copy_ctx_, copy_dtype_, nullptr);
704 } catch (const std::runtime_error &e) {
705 code = RPCCode::kException;
706 errmsg = e.what();
707 ret_value.v_str = errmsg.c_str();
708 ret_tcode = kStr;
709 }
710 }
711 this->Write(code);
712 SendPackedSeq(&ret_value, &ret_tcode, 1);
713 arg_recv_stage_ = 0;
714 this->SwitchToState(kRecvCode);
715 }
716 }
// Handler for a fully received packed call; invoked by SwitchToState once
// every argument has arrived. Only declared here.
void HandlePackedCall();
719
720 template<typename F>
CallHandler(F f)721 void CallHandler(F f) {
722 TVMRetValue rv;
723 TVMValue ret_value;
724 int ret_tcode;
725 try {
726 // Need to move out, in case f itself need to call RecvPackedSeq
727 // Which will override argbuf again.
728 std::unique_ptr<RPCArgBuffer> args = std::move(arg_buf_);
729 f(args->AsTVMArgs(), &rv);
730 RPCCode code = RPCCode::kReturn;
731 this->Write(code);
732 if (rv.type_code() == kStr) {
733 ret_value.v_str = rv.ptr<std::string>()->c_str();
734 ret_tcode = kStr;
735 SendPackedSeq(&ret_value, &ret_tcode, 1);
736 } else if (rv.type_code() == kBytes) {
737 std::string* bytes = rv.ptr<std::string>();
738 TVMByteArray arr;
739 arr.data = bytes->c_str();
740 arr.size = bytes->length();
741 ret_value.v_handle = &arr;
742 ret_tcode = kBytes;
743 SendPackedSeq(&ret_value, &ret_tcode, 1);
744 } else if (rv.type_code() == kFuncHandle ||
745 rv.type_code() == kModuleHandle) {
746 // always send handle in 64 bit.
747 CHECK(!client_mode_)
748 << "Only server can send function and module handle back.";
749 rv.MoveToCHost(&ret_value, &ret_tcode);
750 SendPackedSeq(&ret_value, &ret_tcode, 1);
751 } else if (rv.type_code() == kNDArrayContainer) {
752 // always send handle in 64 bit.
753 CHECK(!client_mode_)
754 << "Only server can send NDArray back";
755 // We follow a special protocol to return NDArray to client side
756 // The first pack value is the NDArray handle as DLTensor
757 // The second pack value is a customized deleter that deletes the NDArray.
758 TVMValue ret_value_pack[2];
759 int ret_tcode_pack[2];
760 rv.MoveToCHost(&ret_value_pack[0], &ret_tcode_pack[0]);
761
762 NDArray::Container* nd = static_cast<NDArray::Container*>(ret_value_pack[0].v_handle);
763 ret_value_pack[1].v_handle = nd;
764 ret_tcode_pack[1] = kHandle;
765 SendPackedSeq(ret_value_pack, ret_tcode_pack, 2, true);
766 } else {
767 ret_value = rv.value();
768 ret_tcode = rv.type_code();
769 SendPackedSeq(&ret_value, &ret_tcode, 1);
770 }
771 } catch (const std::runtime_error& e) {
772 RPCCode code = RPCCode::kException;
773 this->Write(code);
774 ret_value.v_str = e.what();
775 ret_tcode = kStr;
776 SendPackedSeq(&ret_value, &ret_tcode, 1);
777 }
778 }
779
780 private:
781 // Utility functions
782 // Internal read function, update pending_request_bytes_
Read(void * data,size_t size)783 size_t Read(void* data, size_t size) final {
784 CHECK_LE(size, pending_request_bytes_);
785 reader_->Read(data, size);
786 pending_request_bytes_ -= size;
787 return size;
788 }
Write(const void * data,size_t size)789 void Write(const void* data, size_t size) final {
790 writer_->Write(data, size);
791 }
// Number of bytes requested from the reader and not yet consumed.
size_t pending_request_bytes_;
// The ring buffer to read data from.
common::RingBuffer* reader_;
// The ring buffer to write replies to.
common::RingBuffer* writer_;
// Session table index.
int rpc_sess_table_index_;
// Name of session.
std::string name_;
// Remote key; filled in from the initial header when requested.
std::string* remote_key_;
804 };
805
806 struct RPCSessTable {
807 public:
808 static constexpr int kMaxRPCSession = 32;
809 // Get global singleton
Globaltvm::runtime::RPCSessTable810 static RPCSessTable* Global() {
811 static RPCSessTable inst;
812 return &inst;
813 }
814 // Get session from table
Gettvm::runtime::RPCSessTable815 std::shared_ptr<RPCSession> Get(int index) {
816 CHECK(index >= 0 && index < kMaxRPCSession);
817 return tbl_[index].lock();
818 }
819 // Insert session into table.
Inserttvm::runtime::RPCSessTable820 int Insert(std::shared_ptr<RPCSession> ptr) {
821 std::lock_guard<std::mutex> lock(mutex_);
822 for (int i = 0; i < kMaxRPCSession; ++i) {
823 if (tbl_[i].lock() == nullptr) {
824 tbl_[i] = ptr; return i;
825 }
826 }
827 LOG(FATAL) << "maximum number of RPC session reached";
828 return 0;
829 }
830
831 private:
832 // The mutex
833 std::mutex mutex_;
834 // Use weak_ptr intentionally
835 // If the RPCSession get released, the pointer session will be released
836 std::array<std::weak_ptr<RPCSession>, kMaxRPCSession> tbl_;
837 };
838
HandleUntilReturnEvent(TVMRetValue * rv,bool client_mode,const PackedFunc * fwrap)839 RPCCode RPCSession::HandleUntilReturnEvent(
840 TVMRetValue* rv, bool client_mode, const PackedFunc* fwrap) {
841 RPCCode code = RPCCode::kCallFunc;
842 while (code != RPCCode::kReturn &&
843 code != RPCCode::kShutdown &&
844 code != RPCCode::kCopyAck) {
845 while (writer_.bytes_available() != 0) {
846 writer_.ReadWithCallback([this](const void *data, size_t size) {
847 return channel_->Send(data, size);
848 }, writer_.bytes_available());
849 }
850 size_t bytes_needed = handler_->BytesNeeded();
851 if (bytes_needed != 0) {
852 size_t n = reader_.WriteWithCallback([this](void* data, size_t size) {
853 return channel_->Recv(data, size);
854 }, bytes_needed);
855 if (n == 0) {
856 if (handler_->CanCleanShutdown()) {
857 return RPCCode::kShutdown;
858 } else {
859 LOG(FATAL) << "Channel closes before we get neded bytes";
860 }
861 }
862 }
863 code = handler_->HandleNextEvent(rv, client_mode, fwrap);
864 }
865 return code;
866 }
867
Init()868 void RPCSession::Init() {
869 // Event handler
870 handler_ = std::make_shared<EventHandler>(
871 &reader_, &writer_, table_index_, name_, &remote_key_);
872 // Quick function to call remote.
873 call_remote_ = PackedFunc([this](TVMArgs args, TVMRetValue* rv) {
874 handler_->SendPackedSeq(args.values, args.type_codes, args.num_args);
875 RPCCode code = HandleUntilReturnEvent(rv, true, nullptr);
876 CHECK(code == RPCCode::kReturn) << "code=" << static_cast<int>(code);
877 });
878 }
879
Create(std::unique_ptr<RPCChannel> channel,std::string name,std::string remote_key)880 std::shared_ptr<RPCSession> RPCSession::Create(
881 std::unique_ptr<RPCChannel> channel,
882 std::string name,
883 std::string remote_key) {
884 std::shared_ptr<RPCSession> sess = std::make_shared<RPCSession>();
885 sess->channel_ = std::move(channel);
886 sess->name_ = std::move(name);
887 sess->remote_key_ = std::move(remote_key);
888 sess->table_index_ = RPCSessTable::Global()->Insert(sess);
889 sess->Init();
890 return sess;
891 }
892
Get(int table_index)893 std::shared_ptr<RPCSession> RPCSession::Get(int table_index) {
894 return RPCSessTable::Global()->Get(table_index);
895 }
896
~RPCSession()897 RPCSession::~RPCSession() {
898 this->Shutdown();
899 }
900
Shutdown()901 void RPCSession::Shutdown() {
902 if (channel_ != nullptr) {
903 RPCCode code = RPCCode::kShutdown;
904 handler_->Write(code);
905 // flush all writing buffer to output channel.
906 try {
907 while (writer_.bytes_available() != 0) {
908 size_t n = writer_.ReadWithCallback([this](const void *data, size_t size) {
909 return channel_->Send(data, size);
910 }, writer_.bytes_available());
911 if (n == 0) break;
912 }
913 } catch (const dmlc::Error& e) {
914 }
915 channel_.reset(nullptr);
916 }
917 }
918
ServerLoop()919 void RPCSession::ServerLoop() {
920 std::lock_guard<std::recursive_mutex> lock(mutex_);
921 if (const auto* f = Registry::Get("tvm.rpc.server.start")) {
922 (*f)();
923 }
924 TVMRetValue rv;
925 CHECK(HandleUntilReturnEvent(&rv, false, nullptr) == RPCCode::kShutdown);
926 if (const auto* f = Registry::Get("tvm.rpc.server.shutdown")) {
927 (*f)();
928 }
929 channel_.reset(nullptr);
930 }
931
ServerEventHandler(const std::string & bytes,int event_flag)932 int RPCSession::ServerEventHandler(const std::string& bytes, int event_flag) {
933 std::lock_guard<std::recursive_mutex> lock(mutex_);
934 RPCCode code = RPCCode::kNone;
935 if (bytes.length() != 0) {
936 reader_.Write(bytes.c_str(), bytes.length());
937 TVMRetValue rv;
938 code = handler_->HandleNextEvent(&rv, false, nullptr);
939 }
940 if ((event_flag & 2) != 0 && writer_.bytes_available() != 0) {
941 writer_.ReadWithCallback([this](const void *data, size_t size) {
942 return channel_->Send(data, size);
943 }, writer_.bytes_available());
944 }
945 CHECK(code != RPCCode::kReturn && code != RPCCode::kCopyAck);
946 if (code == RPCCode::kShutdown) return 0;
947 if (writer_.bytes_available() != 0) return 2;
948 return 1;
949 }
950
951 // Get remote function with name
CallFunc(void * h,TVMArgs args,TVMRetValue * rv,const PackedFunc * fwrap)952 void RPCSession::CallFunc(void* h,
953 TVMArgs args,
954 TVMRetValue* rv,
955 const PackedFunc* fwrap) {
956 std::lock_guard<std::recursive_mutex> lock(mutex_);
957 RPCCode code = RPCCode::kCallFunc;
958 handler_->Write(code);
959 uint64_t handle = reinterpret_cast<uint64_t>(h);
960 handler_->Write(handle);
961 handler_->SendPackedSeq(args.values, args.type_codes, args.num_args);
962 code = HandleUntilReturnEvent(rv, true, fwrap);
963 CHECK(code == RPCCode::kReturn) << "code=" << static_cast<int>(code);
964 }
965
CopyToRemote(void * from,size_t from_offset,void * to,size_t to_offset,size_t data_size,TVMContext ctx_to,TVMType type_hint)966 void RPCSession::CopyToRemote(void* from,
967 size_t from_offset,
968 void* to,
969 size_t to_offset,
970 size_t data_size,
971 TVMContext ctx_to,
972 TVMType type_hint) {
973 std::lock_guard<std::recursive_mutex> lock(mutex_);
974 ctx_to = handler_->StripSessMask(ctx_to);
975 RPCCode code = RPCCode::kCopyToRemote;
976 handler_->Write(code);
977 uint64_t handle = reinterpret_cast<uint64_t>(to);
978 handler_->Write(handle);
979 uint64_t offset = static_cast<uint64_t>(to_offset);
980 handler_->Write(offset);
981 uint64_t size = static_cast<uint64_t>(data_size);
982 handler_->Write(size);
983 handler_->Write(ctx_to);
984 handler_->Write(type_hint);
985 handler_->WriteArray(reinterpret_cast<char*>(from) + from_offset, data_size);
986 TVMRetValue rv;
987 CHECK(HandleUntilReturnEvent(&rv, true, nullptr) == RPCCode::kReturn);
988 }
989
CopyFromRemote(void * from,size_t from_offset,void * to,size_t to_offset,size_t data_size,TVMContext ctx_from,TVMType type_hint)990 void RPCSession::CopyFromRemote(void* from,
991 size_t from_offset,
992 void* to,
993 size_t to_offset,
994 size_t data_size,
995 TVMContext ctx_from,
996 TVMType type_hint) {
997 std::lock_guard<std::recursive_mutex> lock(mutex_);
998 ctx_from = handler_->StripSessMask(ctx_from);
999 RPCCode code = RPCCode::kCopyFromRemote;
1000 handler_->Write(code);
1001 uint64_t handle = reinterpret_cast<uint64_t>(from);
1002 handler_->Write(handle);
1003 uint64_t offset = static_cast<uint64_t>(from_offset);
1004 handler_->Write(offset);
1005 uint64_t size = static_cast<uint64_t>(data_size);
1006 handler_->Write(size);
1007 handler_->Write(ctx_from);
1008 handler_->Write(type_hint);
1009 TVMRetValue rv;
1010 CHECK(HandleUntilReturnEvent(&rv, true, nullptr) == RPCCode::kCopyAck);
1011 reader_.Reserve(data_size);
1012 handler_->RequestBytes(data_size);
1013 while (!handler_->Ready()) {
1014 size_t bytes_needed = handler_->BytesNeeded();
1015 reader_.WriteWithCallback([this](void* data, size_t size) {
1016 size_t n = channel_->Recv(data, size);
1017 CHECK_NE(n, 0U) << "Channel closes before we get neded bytes";
1018 return n;
1019 }, bytes_needed);
1020 }
1021 handler_->ReadArray(reinterpret_cast<char*>(to) + to_offset, data_size);
1022 handler_->FinishCopyAck();
1023 }
1024
GetTimeEvaluator(RPCFuncHandle fhandle,TVMContext ctx,int number,int repeat,int min_repeat_ms)1025 RPCFuncHandle RPCSession::GetTimeEvaluator(
1026 RPCFuncHandle fhandle, TVMContext ctx, int number, int repeat, int min_repeat_ms) {
1027 return this->CallRemote(
1028 RPCCode::kGetTimeEvaluator, fhandle, ctx, number, repeat, min_repeat_ms);
1029 }
1030
1031 // Event handler functions
RPCGetGlobalFunc(TVMArgs args,TVMRetValue * rv)1032 void RPCGetGlobalFunc(TVMArgs args, TVMRetValue* rv) {
1033 std::string name = args[0];
1034 auto *fp = tvm::runtime::Registry::Get(name);
1035 if (fp != nullptr) {
1036 *rv = static_cast<void*>(new tvm::runtime::PackedFunc(*fp));
1037 } else {
1038 *rv = nullptr;
1039 }
1040 }
1041
RPCFreeFunc(TVMArgs args,TVMRetValue * rv)1042 void RPCFreeFunc(TVMArgs args, TVMRetValue *rv) {
1043 void* handle = args[0];
1044 delete static_cast<PackedFunc*>(handle);
1045 }
1046
RPCDevSetDevice(TVMArgs args,TVMRetValue * rv)1047 void RPCDevSetDevice(TVMArgs args, TVMRetValue *rv) {
1048 TVMContext ctx = args[0];
1049 DeviceAPI::Get(ctx)->SetDevice(ctx);
1050 }
1051
RPCDevGetAttr(TVMArgs args,TVMRetValue * rv)1052 void RPCDevGetAttr(TVMArgs args, TVMRetValue *rv) {
1053 TVMContext ctx = args[0];
1054 DeviceAttrKind kind = static_cast<DeviceAttrKind>(args[1].operator int());
1055 if (kind == kExist) {
1056 DeviceAPI* api = DeviceAPI::Get(ctx, true);
1057 if (api != nullptr) {
1058 api->GetAttr(ctx, kind, rv);
1059 } else {
1060 *rv = 0;
1061 }
1062 } else {
1063 DeviceAPI::Get(ctx)->GetAttr(
1064 ctx, static_cast<DeviceAttrKind>(kind), rv);
1065 }
1066 }
1067
RPCDevAllocData(TVMArgs args,TVMRetValue * rv)1068 void RPCDevAllocData(TVMArgs args, TVMRetValue *rv) {
1069 TVMContext ctx = args[0];
1070 uint64_t nbytes = args[1];
1071 uint64_t alignment = args[2];
1072 TVMType type_hint = args[3];
1073 void* data = DeviceAPI::Get(ctx)->AllocDataSpace(
1074 ctx, nbytes, alignment, type_hint);
1075 *rv = data;
1076 }
1077
RPCDevFreeData(TVMArgs args,TVMRetValue * rv)1078 void RPCDevFreeData(TVMArgs args, TVMRetValue *rv) {
1079 TVMContext ctx = args[0];
1080 void* ptr = args[1];
1081 DeviceAPI::Get(ctx)->FreeDataSpace(ctx, ptr);
1082 }
1083
RPCDevStreamSync(TVMArgs args,TVMRetValue * rv)1084 void RPCDevStreamSync(TVMArgs args, TVMRetValue *rv) {
1085 TVMContext ctx = args[0];
1086 TVMStreamHandle handle = args[1];
1087 DeviceAPI::Get(ctx)->StreamSync(ctx, handle);
1088 }
1089
RPCCopyAmongRemote(TVMArgs args,TVMRetValue * rv)1090 void RPCCopyAmongRemote(TVMArgs args, TVMRetValue *rv) {
1091 void* from = args[0];
1092 uint64_t from_offset = args[1];
1093 void* to = args[2];
1094 uint64_t to_offset = args[3];
1095 uint64_t size = args[4];
1096 TVMContext ctx_from = args[5];
1097 TVMContext ctx_to = args[6];
1098 TVMType type_hint = args[7];
1099 TVMStreamHandle stream = args[8];
1100 TVMContext ctx = ctx_from;
1101 if (ctx.device_type == kDLCPU) {
1102 ctx = ctx_to;
1103 } else {
1104 CHECK(ctx_to.device_type == kDLCPU ||
1105 ctx_to.device_type == ctx_from.device_type)
1106 << "Can not copy across different ctx types directly";
1107 }
1108 DeviceAPI::Get(ctx)->CopyDataFromTo(
1109 from, from_offset,
1110 to, to_offset,
1111 size, ctx_from, ctx_to, type_hint, stream);
1112 }
1113
RPCModuleLoad(TVMArgs args,TVMRetValue * rv)1114 void RPCModuleLoad(TVMArgs args, TVMRetValue *rv) {
1115 static const PackedFunc* fsys_load_ = nullptr;
1116 if (fsys_load_ == nullptr) {
1117 fsys_load_ = runtime::Registry::Get("tvm.rpc.server.load_module");
1118 CHECK(fsys_load_ != nullptr);
1119 }
1120 std::string file_name = args[0];
1121 TVMRetValue ret = (*fsys_load_)(file_name);
1122 // pass via void*
1123 TVMValue value;
1124 int rcode;
1125 ret.MoveToCHost(&value, &rcode);
1126 CHECK_EQ(rcode, kModuleHandle);
1127 *rv = static_cast<void*>(value.v_handle);
1128 }
1129
RPCModuleImport(TVMArgs args,TVMRetValue * rv)1130 void RPCModuleImport(TVMArgs args, TVMRetValue *rv) {
1131 void* pmod = args[0];
1132 void* cmod = args[1];
1133 ObjectInternal::GetModuleNode(pmod)->Import(
1134 GetRef<Module>(ObjectInternal::GetModuleNode(cmod)));
1135 }
1136
RPCModuleFree(TVMArgs args,TVMRetValue * rv)1137 void RPCModuleFree(TVMArgs args, TVMRetValue *rv) {
1138 void* mhandle = args[0];
1139 ObjectInternal::ObjectFree(mhandle);
1140 }
1141
RPCModuleGetFunc(TVMArgs args,TVMRetValue * rv)1142 void RPCModuleGetFunc(TVMArgs args, TVMRetValue *rv) {
1143 void* mhandle = args[0];
1144 PackedFunc pf = ObjectInternal::GetModuleNode(mhandle)->GetFunction(
1145 args[1], false);
1146 if (pf != nullptr) {
1147 *rv = static_cast<void*>(new PackedFunc(pf));
1148 } else {
1149 *rv = nullptr;
1150 }
1151 }
1152
RPCModuleGetSource(TVMArgs args,TVMRetValue * rv)1153 void RPCModuleGetSource(TVMArgs args, TVMRetValue *rv) {
1154 void* mhandle = args[0];
1155 std::string fmt = args[1];
1156 *rv = ObjectInternal::GetModuleNode(mhandle)->GetSource(fmt);
1157 }
1158
RPCNDArrayFree(TVMArgs args,TVMRetValue * rv)1159 void RPCNDArrayFree(TVMArgs args, TVMRetValue *rv) {
1160 void* handle = args[0];
1161 static_cast<NDArray::Container*>(handle)->DecRef();
1162 }
1163
RPCGetTimeEvaluator(TVMArgs args,TVMRetValue * rv)1164 void RPCGetTimeEvaluator(TVMArgs args, TVMRetValue *rv) {
1165 PackedFunc *pf = static_cast<PackedFunc*>(args[0].operator void*());
1166 void *fhandle = new PackedFunc(WrapTimeEvaluator(*pf, args[1], args[2], args[3], args[4]));
1167 delete pf;
1168 *rv = fhandle;
1169 }
1170
HandlePackedCall()1171 void RPCSession::EventHandler::HandlePackedCall() {
1172 CHECK_EQ(pending_request_bytes_, 0U);
1173 if (code_ == RPCCode::kReturn) {
1174 state_ = kReturnReceived; return;
1175 }
1176 // reset state to clean init state
1177 state_ = kRecvCode;
1178 this->RequestBytes(sizeof(RPCCode));
1179 // Event handler sit at clean state at this point.
1180 switch (code_) {
1181 case RPCCode::kCallFunc: {
1182 PackedFunc* pf = reinterpret_cast<PackedFunc*>(call_handle_);
1183 CallHandler([pf](TVMArgs args, TVMRetValue* rv) {
1184 pf->CallPacked(args, rv);
1185 });
1186 break;
1187 }
1188 case RPCCode::kException: {
1189 CHECK_EQ(arg_buf_->value.size(), 1U);
1190 CHECK_EQ(arg_buf_->tcode[0], kStr);
1191 std::ostringstream os;
1192 os << "Except caught from RPC call: " << arg_buf_->value[0].v_str;
1193 arg_buf_.reset();
1194 throw dmlc::Error(os.str());
1195 break;
1196 }
1197 // system functions
1198 case RPCCode::kGetTimeEvaluator: CallHandler(RPCGetTimeEvaluator); break;
1199 case RPCCode::kFreeFunc: CallHandler(RPCFreeFunc); break;
1200 case RPCCode::kGetGlobalFunc: CallHandler(RPCGetGlobalFunc); break;
1201 case RPCCode::kDevSetDevice: CallHandler(RPCDevSetDevice); break;
1202 case RPCCode::kDevGetAttr: CallHandler(RPCDevGetAttr); break;
1203 case RPCCode::kDevAllocData: CallHandler(RPCDevAllocData); break;
1204 case RPCCode::kDevFreeData: CallHandler(RPCDevFreeData); break;
1205 case RPCCode::kDevStreamSync: CallHandler(RPCDevStreamSync); break;
1206 case RPCCode::kCopyAmongRemote: CallHandler(RPCCopyAmongRemote); break;
1207 case RPCCode::kModuleLoad: CallHandler(RPCModuleLoad); break;
1208 case RPCCode::kModuleImport: CallHandler(RPCModuleImport); break;
1209 case RPCCode::kModuleFree: CallHandler(RPCModuleFree); break;
1210 case RPCCode::kModuleGetFunc: CallHandler(RPCModuleGetFunc); break;
1211 case RPCCode::kModuleGetSource: CallHandler(RPCModuleGetSource); break;
1212 case RPCCode::kNDArrayFree: CallHandler(RPCNDArrayFree); break;
1213 default: LOG(FATAL) << "Unknown event " << static_cast<int>(code_);
1214 }
1215 CHECK_EQ(state_, kRecvCode);
1216 }
1217
WrapTimeEvaluator(PackedFunc pf,TVMContext ctx,int number,int repeat,int min_repeat_ms)1218 PackedFunc WrapTimeEvaluator(PackedFunc pf,
1219 TVMContext ctx,
1220 int number,
1221 int repeat,
1222 int min_repeat_ms) {
1223 auto ftimer = [pf, ctx, number, repeat, min_repeat_ms](TVMArgs args, TVMRetValue *rv) mutable {
1224 TVMRetValue temp;
1225 std::ostringstream os;
1226 // skip first time call, to activate lazy compilation components.
1227 pf.CallPacked(args, &temp);
1228 DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr);
1229
1230 for (int i = 0; i < repeat; ++i) {
1231 std::chrono::time_point<
1232 std::chrono::high_resolution_clock, std::chrono::nanoseconds> tbegin, tend;
1233 double duration_ms = 0.0;
1234
1235 do {
1236 if (duration_ms > 0.0) {
1237 number = static_cast<int>(
1238 std::max((min_repeat_ms / (duration_ms / number) + 1),
1239 number * 1.618)); // 1.618 is chosen by random
1240 }
1241
1242 tbegin = std::chrono::high_resolution_clock::now();
1243 // start timing
1244 for (int i = 0; i < number; ++i) {
1245 pf.CallPacked(args, &temp);
1246 }
1247 DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr);
1248 tend = std::chrono::high_resolution_clock::now();
1249
1250 duration_ms = std::chrono::duration_cast<std::chrono::duration<double> >
1251 (tend - tbegin).count() * 1000;
1252 } while (duration_ms < min_repeat_ms);
1253
1254 double speed = std::chrono::duration_cast<std::chrono::duration<double> >(
1255 tend - tbegin).count() / number;
1256 os.write(reinterpret_cast<char*>(&speed), sizeof(speed));
1257 }
1258 std::string blob = os.str();
1259 TVMByteArray arr;
1260 arr.size = blob.length();
1261 arr.data = blob.data();
1262 // return the time.
1263 *rv = arr;
1264 };
1265 return PackedFunc(ftimer);
1266 }
1267
Send(const void * data,size_t size)1268 size_t CallbackChannel::Send(const void* data, size_t size) {
1269 TVMByteArray bytes;
1270 bytes.data = static_cast<const char*>(data);
1271 bytes.size = size;
1272 int64_t n = fsend_(bytes);
1273 if (n == -1) {
1274 common::Socket::Error("CallbackChannel::Send");
1275 }
1276 return static_cast<size_t>(n);
1277 }
1278
Recv(void * data,size_t size)1279 size_t CallbackChannel::Recv(void* data, size_t size) {
1280 TVMRetValue ret = frecv_(size);
1281 if (ret.type_code() != kBytes) {
1282 common::Socket::Error("CallbackChannel::Recv");
1283 }
1284 std::string* bytes = ret.ptr<std::string>();
1285 memcpy(static_cast<char*>(data), bytes->c_str(), bytes->length());
1286 return bytes->length();
1287 }
1288
1289 } // namespace runtime
1290 } // namespace tvm
1291