1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file rpc_session.cc
22  * \brief RPC session for remote function call.
23  */
24 #include <tvm/runtime/packed_func.h>
25 #include <tvm/runtime/device_api.h>
26 #include <tvm/runtime/registry.h>
27 #include <tvm/runtime/serializer.h>
28 #include <memory>
29 #include <array>
30 #include <string>
31 #include <chrono>
32 #include <vector>
33 #include <utility>
34 #include <cmath>
35 #include <algorithm>
36 #include "rpc_session.h"
37 #include "../object_internal.h"
38 #include "../../common/ring_buffer.h"
39 #include "../../common/socket.h"
40 
41 namespace tvm {
42 namespace runtime {
// Temporary storage for a string/bytes argument received over RPC.
struct RPCByteArrayBuffer {
  // Byte-array descriptor; the receiver points it at `data` for kBytes arguments.
  TVMByteArray arr;
  // Owning storage for the received payload.
  std::string data;
};
// Temporary storage for a tensor (array handle) argument received over RPC.
struct RPCDataArrayBuffer {
  // Tensor descriptor; the receiver points `tensor.shape` at `shape.data()`.
  DLTensor tensor;
  // Owning storage for the tensor shape.
  std::vector<int64_t> shape;
};
53 /*!
54  * \brief Temporal argument buffer.
55  */
56 struct RPCArgBuffer {
57   // The argument values
58   std::vector<TVMValue> value;
59   // The type codes.
60   std::vector<int> tcode;
61   // Temporal resources.
62   std::vector<std::unique_ptr<RPCByteArrayBuffer> > temp_bytes;
63   // Temporal array
64   std::vector<std::unique_ptr<RPCDataArrayBuffer> > temp_array;
65   // convert buffer as TVMArgs
AsTVMArgstvm::runtime::RPCArgBuffer66   TVMArgs AsTVMArgs() const {
67     return TVMArgs(value.data(), tcode.data(), static_cast<int>(value.size()));
68   }
69 };
70 
71 // Event handler for RPC events.
72 class RPCSession::EventHandler : public dmlc::Stream {
73  public:
EventHandler(common::RingBuffer * reader,common::RingBuffer * writer,int rpc_sess_table_index,std::string name,std::string * remote_key)74   EventHandler(common::RingBuffer* reader,
75                common::RingBuffer* writer,
76                int rpc_sess_table_index,
77                std::string name,
78                std::string* remote_key)
79       : reader_(reader),
80         writer_(writer),
81         rpc_sess_table_index_(rpc_sess_table_index),
82         name_(name),
83         remote_key_(remote_key) {
84     this->Clear();
85     if (*remote_key == "%toinit") {
86       state_ = kInitHeader;
87       remote_key_->resize(0);
88       pending_request_bytes_ = sizeof(int32_t);
89     }
90   }
91   // Bytes needed to fulfill current request
BytesNeeded()92   size_t BytesNeeded() {
93     if (reader_->bytes_available() < pending_request_bytes_) {
94       return pending_request_bytes_ - reader_->bytes_available();
95     } else {
96       return 0;
97     }
98   }
99   // Request number of bytes from reader.
RequestBytes(size_t nbytes)100   void RequestBytes(size_t nbytes) {
101     pending_request_bytes_ += nbytes;
102     reader_->Reserve(pending_request_bytes_);
103   }
104   // Whether we are ready to handle next request.
Ready()105   bool Ready() {
106     return reader_->bytes_available() >= pending_request_bytes_;
107   }
CanCleanShutdown() const108   bool CanCleanShutdown() const {
109     return state_ == kRecvCode;
110   }
FinishCopyAck()111   void FinishCopyAck() {
112     this->SwitchToState(kRecvCode);
113   }
HandleNextEvent(TVMRetValue * rv,bool client_mode,const PackedFunc * fwrap)114   RPCCode HandleNextEvent(TVMRetValue* rv,
115                           bool client_mode,
116                           const PackedFunc* fwrap) {
117     std::swap(client_mode_, client_mode);
118     while (this->Ready()) {
119       switch (state_) {
120         case kInitHeader: HandleInitHeader(); break;
121         case kRecvCode: HandleRecvCode(); break;
122         case kRecvCallHandle: {
123           CHECK(this->Read(&call_handle_));
124           this->SwitchToState(kRecvPackedSeqNumArgs);
125           break;
126         }
127         case kRecvPackedSeqNumArgs: {
128           CHECK(this->Read(&num_packed_args_));
129           arg_buf_.reset(new RPCArgBuffer());
130           arg_buf_->value.resize(num_packed_args_);
131           arg_buf_->tcode.resize(num_packed_args_);
132           this->SwitchToState(kRecvPackedSeqTypeCode);
133           break;
134         }
135         case kRecvPackedSeqTypeCode: {
136           if (num_packed_args_ != 0) {
137             this->ReadArray(arg_buf_->tcode.data(), num_packed_args_);
138           }
139           arg_index_ = 0;
140           arg_recv_stage_ = 0;
141           this->SwitchToState(kRecvPackedSeqArg);
142           break;
143         }
144         case kRecvPackedSeqArg: {
145           this->HandleRecvPackedSeqArg();
146           break;
147         }
148         case kDoCopyFromRemote: {
149           this->HandleCopyFromRemote();
150           break;
151         }
152         case kDoCopyToRemote: {
153           this->HandleCopyToRemote();
154           break;
155         }
156         case kReturnReceived: {
157           CHECK_GE(arg_buf_->value.size(), 1U);
158 
159           TVMArgValue argv = arg_buf_->AsTVMArgs()[0];
160           if (argv.type_code() == kFuncHandle ||
161               argv.type_code() == kModuleHandle ||
162               argv.type_code() == kArrayHandle) {
163             CHECK(fwrap != nullptr) << "function/module wrapper not available";
164             fwrap->CallPacked(arg_buf_->AsTVMArgs(), rv);
165           } else {
166             CHECK_EQ(arg_buf_->value.size(), 1U);
167             *rv = argv;
168           }
169           arg_buf_.reset();
170           this->SwitchToState(kRecvCode);
171           std::swap(client_mode_, client_mode);
172           return RPCCode::kReturn;
173         }
174         case kCopyAckReceived: {
175           std::swap(client_mode_, client_mode);
176           return RPCCode::kCopyAck;
177         }
178         case kShutdownReceived: {
179           std::swap(client_mode_, client_mode);
180           return RPCCode::kShutdown;
181         }
182       }
183     }
184     std::swap(client_mode_, client_mode);
185     return RPCCode::kNone;
186   }
187   // Reset and clear all states.
Clear()188   void Clear() {
189     state_ = kRecvCode;
190     pending_request_bytes_ = sizeof(RPCCode);
191     arg_recv_stage_ = 0;
192     arg_buf_.reset();
193   }
194   // strip session on mask
StripSessMask(TVMContext ctx)195   TVMContext StripSessMask(TVMContext ctx) {
196     int dev_type = ctx.device_type;
197     CHECK_EQ(dev_type / kRPCSessMask, rpc_sess_table_index_ + 1)
198         << "Can not pass in local context or context with a different remote session";
199     ctx.device_type = static_cast<DLDeviceType>(dev_type % kRPCSessMask);
200     return ctx;
201   }
202   // Send Packed sequence to writer.
203   // return_ndarray is a special flag to handle returning of ndarray
204   //    In this case, we return the shape, context and data of the array,
205   //    as well as a customized PackedFunc that handles deletion of
206   //    the array in the remote.
SendPackedSeq(const TVMValue * arg_values,const int * type_codes,int n,bool return_ndarray=false)207   void SendPackedSeq(const TVMValue* arg_values,
208                      const int* type_codes,
209                      int n,
210                      bool return_ndarray = false) {
211     this->Write(n);
212     for (int i = 0; i < n; ++i) {
213       int tcode = type_codes[i];
214       if (tcode == kNDArrayContainer) tcode = kArrayHandle;
215       this->Write(tcode);
216     }
217 
218     // Argument packing.
219     for (int i = 0; i < n; ++i) {
220       int tcode = type_codes[i];
221       TVMValue value = arg_values[i];
222       switch (tcode) {
223         case kDLInt:
224         case kDLUInt:
225         case kDLFloat: {
226           this->Write<int64_t>(value.v_int64);
227           break;
228         }
229         case kTVMType: {
230           this->Write(value.v_type);
231           // padding
232           int32_t padding = 0;
233           this->Write<int32_t>(padding);
234           break;
235         }
236         case kTVMContext: {
237           value.v_ctx = StripSessMask(value.v_ctx);
238           this->Write(value.v_ctx);
239           break;
240         }
241         case kFuncHandle:
242         case kModuleHandle:
243         case kHandle: {
244           // always send handle in 64 bit.
245           uint64_t handle = reinterpret_cast<uint64_t>(value.v_handle);
246           this->Write(handle);
247           break;
248         }
249         case kNDArrayContainer:
250         case kArrayHandle: {
251           DLTensor* arr = static_cast<DLTensor*>(value.v_handle);
252           TVMContext ctx;
253           uint64_t data;
254           if (!return_ndarray) {
255             // in the client mode
256             // ctx contains the remote table index
257             // the space is wrapped by an RemoteSpace
258             // that holds reference to the session.
259             ctx = StripSessMask(arr->ctx);
260             data = reinterpret_cast<uint64_t>(
261                 static_cast<RemoteSpace*>(arr->data)->data);
262           } else {
263             // When we return NDArray, we directly return
264             // the space and the context
265             // The client will be further wrapping
266             ctx = arr->ctx;
267             data = reinterpret_cast<uint64_t>(arr->data);
268           }
269           this->Write(data);
270           this->Write(ctx);
271           this->Write(arr->ndim);
272           this->Write(arr->dtype);
273           this->WriteArray(arr->shape, arr->ndim);
274           CHECK(arr->strides == nullptr)
275               << "Do not support strided remote array";
276           CHECK_EQ(arr->byte_offset, 0)
277               << "Do not support send byte offset";
278           break;
279         }
280         case kNull: break;
281         case kStr: {
282           const char* s = value.v_str;
283           uint64_t len = strlen(s);
284           this->Write(len);
285           this->WriteArray(s, len);
286           break;
287         }
288         case kBytes: {
289           TVMByteArray* bytes = static_cast<TVMByteArray*>(arg_values[i].v_handle);
290           uint64_t len = bytes->size;
291           this->Write(len);
292           this->WriteArray(bytes->data, len);
293           break;
294         }
295         default: {
296           LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode);
297           break;
298         }
299       }
300     }
301   }
302 
303   // Endian aware IO handling
304   using Stream::Read;
305   using Stream::Write;
306   using Stream::ReadArray;
307   using Stream::WriteArray;
308 
Read(RPCCode * code)309   inline bool Read(RPCCode* code) {
310     int cdata;
311     if (!this->Read(&cdata)) return false;
312     *code = static_cast<RPCCode>(cdata);
313     return true;
314   }
Write(RPCCode code)315   inline void Write(RPCCode code) {
316     int cdata = static_cast<int>(code);
317     this->Write(cdata);
318   }
319 
320  protected:
  // States of the receive state machine.
  enum State {
    // Receiving the initial header (length-prefixed remote key).
    kInitHeader,
    // Waiting for the next RPCCode.
    kRecvCode,
    // Waiting for the 64-bit handle of the function to call.
    kRecvCallHandle,
    // Waiting for the number of packed arguments.
    kRecvPackedSeqNumArgs,
    // Waiting for the array of argument type codes.
    kRecvPackedSeqTypeCode,
    // Receiving the arguments one by one.
    kRecvPackedSeqArg,
    // Waiting for the header of a copy-from-remote request.
    kDoCopyFromRemote,
    // Waiting for the header (then the payload) of a copy-to-remote request.
    kDoCopyToRemote,
    // A return value is buffered; HandleNextEvent reports kReturn.
    kReturnReceived,
    // A copy ack code was read; HandleNextEvent reports kCopyAck.
    kCopyAckReceived,
    // A shutdown code was read; HandleNextEvent reports kShutdown.
    kShutdownReceived
  };
334   // Current state;
335   State state_;
336   // The RPCCode to be read.
337   RPCCode code_;
338   // Handle for the remote function call.
339   uint64_t call_handle_;
340   // Initialize remote header
341   bool init_header_step_{0};
342   // Number of packed arguments.
343   int num_packed_args_;
344   // Current argument index.
345   int arg_index_;
346   // The stage of each argument receiver.
347   int arg_recv_stage_;
348   // Whether current handler is client or server mode.
349   bool client_mode_{false};
350   // Argument buffer
351   std::unique_ptr<RPCArgBuffer> arg_buf_;
352   // Temp byte buffer.
353   std::unique_ptr<RPCByteArrayBuffer> temp_bytes_;
354   // Temp array buffer.
355   std::unique_ptr<RPCDataArrayBuffer> temp_array_;
356   // Internal temporal data space.
357   std::string temp_data_;
358   // Temp variables for copy request state.
359   TVMContext copy_ctx_;
360   TVMType copy_dtype_;
361   uint64_t copy_handle_, copy_offset_, copy_size_;
362   // State switcher
SwitchToState(State state)363   void SwitchToState(State state) {
364     // invariant
365     CHECK_EQ(pending_request_bytes_, 0U)
366         << "state=" << state;
367     state_ = state;
368     switch (state) {
369       case kInitHeader: {
370         LOG(FATAL) << "cannot switch to init header";
371         break;
372       }
373       case kRecvCode: {
374         this->RequestBytes(sizeof(RPCCode));
375         break;
376       }
377       case kRecvCallHandle: {
378         this->RequestBytes(sizeof(call_handle_));
379         break;
380       }
381       case kRecvPackedSeqNumArgs: {
382         this->RequestBytes(sizeof(num_packed_args_));
383         break;
384       }
385       case kRecvPackedSeqTypeCode: {
386         this->RequestBytes(sizeof(int) * num_packed_args_);
387         break;
388       }
389       case kRecvPackedSeqArg: {
390         CHECK_LE(arg_index_, num_packed_args_);
391         if (arg_index_ == num_packed_args_) {
392           // The function can change state_ again.
393           HandlePackedCall();
394         } else {
395           RequestRecvPackedSeqArg();
396         }
397         break;
398       }
399       case kDoCopyFromRemote: {
400         this->RequestBytes(sizeof(uint64_t) * 3);
401         this->RequestBytes(sizeof(TVMContext));
402         this->RequestBytes(sizeof(TVMType));
403         break;
404       }
405       case kDoCopyToRemote: {
406         this->RequestBytes(sizeof(uint64_t) * 3);
407         this->RequestBytes(sizeof(TVMContext));
408         this->RequestBytes(sizeof(TVMType));
409         break;
410       }
411       case kCopyAckReceived:
412       case kReturnReceived:
413       case kShutdownReceived: {
414         break;
415       }
416     }
417   }
418   // Requets bytes needed for next computation.
RequestRecvPackedSeqArg()419   void RequestRecvPackedSeqArg() {
420     CHECK_EQ(arg_recv_stage_, 0);
421     int tcode = arg_buf_->tcode[arg_index_];
422     static_assert(sizeof(TVMValue) == sizeof(uint64_t), "invariant");
423     switch (tcode) {
424       case kDLInt:
425       case kDLUInt:
426       case kDLFloat:
427       case kTVMType:
428       case kHandle:
429       case kStr:
430       case kBytes:
431       case kTVMContext: {
432         this->RequestBytes(sizeof(TVMValue)); break;
433       }
434       case kFuncHandle:
435       case kModuleHandle: {
436         CHECK(client_mode_)
437             << "Only client can receive remote functions";
438         this->RequestBytes(sizeof(TVMValue)); break;
439       }
440       case kNull: break;
441       case kArrayHandle: {
442         this->RequestBytes(sizeof(uint64_t));
443         this->RequestBytes(sizeof(TVMContext));
444         this->RequestBytes(sizeof(int));
445         this->RequestBytes(sizeof(DLDataType));
446         break;
447       }
448       default: {
449         LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode);
450         break;
451       }
452     }
453   }
454   // Handler for packed sequence argument receive.
HandleRecvPackedSeqArg()455   void HandleRecvPackedSeqArg() {
456     CHECK_LT(arg_index_, num_packed_args_);
457     int tcode = arg_buf_->tcode[arg_index_];
458     TVMValue& value = arg_buf_->value[arg_index_];
459     if (arg_recv_stage_ == 0) {
460       switch (tcode) {
461         case kDLInt:
462         case kDLUInt:
463         case kDLFloat: {
464           this->Read<int64_t>(&(value.v_int64));
465           ++arg_index_;
466           this->SwitchToState(kRecvPackedSeqArg);
467           break;
468         }
469         case kTVMType: {
470           this->Read(&(value.v_type));
471           int32_t padding = 0;
472           this->Read<int32_t>(&padding);
473           ++arg_index_;
474           this->SwitchToState(kRecvPackedSeqArg);
475           break;
476         }
477         case kTVMContext: {
478           this->Read(&(value.v_ctx));
479           ++arg_index_;
480           this->SwitchToState(kRecvPackedSeqArg);
481           break;
482         }
483         case kFuncHandle:
484         case kModuleHandle:
485         case kHandle: {
486           // always send handle in 64 bit.
487           uint64_t handle;
488           this->Read(&handle);
489           value.v_handle = reinterpret_cast<void*>(handle);
490           ++arg_index_;
491           this->SwitchToState(kRecvPackedSeqArg);
492           break;
493         }
494         case kNull: {
495           value.v_handle = nullptr;
496           ++arg_index_;
497           this->SwitchToState(kRecvPackedSeqArg);
498           break;
499         }
500         case kStr:
501         case kBytes: {
502           uint64_t len;
503           this->Read(&len);
504           temp_bytes_.reset( new RPCByteArrayBuffer());
505           temp_bytes_->data.resize(len);
506           arg_recv_stage_ = 1;
507           this->RequestBytes(len);
508           break;
509         }
510         case kArrayHandle: {
511           temp_array_.reset(new RPCDataArrayBuffer());
512           uint64_t handle;
513           this->Read(&handle);
514           DLTensor& tensor = temp_array_->tensor;
515           tensor.data = reinterpret_cast<void*>(handle);
516           this->Read(&(tensor.ctx));
517           this->Read(&(tensor.ndim));
518           this->Read(&(tensor.dtype));
519           temp_array_->shape.resize(tensor.ndim);
520           tensor.shape = temp_array_->shape.data();
521           arg_recv_stage_ = 1;
522           tensor.strides = nullptr;
523           tensor.byte_offset = 0;
524           this->RequestBytes(sizeof(int64_t) * tensor.ndim);
525           break;
526         }
527         default: {
528           LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode);
529           break;
530         }
531       }
532     } else {
533       CHECK_EQ(arg_recv_stage_, 1);
534       if (tcode == kStr || tcode == kBytes) {
535         if (temp_bytes_->data.size() != 0) {
536           this->ReadArray(&(temp_bytes_->data[0]), temp_bytes_->data.size());
537         }
538         if (tcode == kStr) {
539           value.v_str = temp_bytes_->data.c_str();
540         } else {
541           temp_bytes_->arr.size = static_cast<size_t>(temp_bytes_->data.size());
542           temp_bytes_->arr.data = dmlc::BeginPtr(temp_bytes_->data);
543           value.v_handle = &(temp_bytes_->arr);
544         }
545         arg_buf_->temp_bytes.emplace_back(std::move(temp_bytes_));
546       } else {
547         CHECK_EQ(tcode, kArrayHandle);
548         DLTensor& tensor = temp_array_->tensor;
549         this->ReadArray(tensor.shape, tensor.ndim);
550         value.v_handle = &tensor;
551         arg_buf_->temp_array.emplace_back(std::move(temp_array_));
552       }
553       ++arg_index_;
554       arg_recv_stage_ = 0;
555       this->SwitchToState(kRecvPackedSeqArg);
556     }
557   }
558   // handler for initial header read
HandleInitHeader()559   void HandleInitHeader() {
560     if (init_header_step_ == 0) {
561       int32_t len;
562       this->Read(&len);
563       remote_key_->resize(len);
564       init_header_step_ = 1;
565       this->RequestBytes(len);
566       return;
567     } else {
568       CHECK_EQ(init_header_step_, 1);
569       this->ReadArray(dmlc::BeginPtr(*remote_key_), remote_key_->length());
570       this->SwitchToState(kRecvCode);
571     }
572   }
573   // Handler for read code.
HandleRecvCode()574   void HandleRecvCode() {
575     this->Read(&code_);
576     if (code_ > RPCCode::kSystemFuncStart) {
577       SwitchToState(kRecvPackedSeqNumArgs);
578       return;
579     }
580     // invariant.
581     CHECK_EQ(arg_recv_stage_, 0);
582     switch (code_) {
583       case RPCCode::kCallFunc: {
584         SwitchToState(kRecvCallHandle);
585         break;
586       }
587       case RPCCode::kException:
588       case RPCCode::kReturn: {
589         SwitchToState(kRecvPackedSeqNumArgs);
590         break;
591       }
592       case RPCCode::kCopyFromRemote: {
593         SwitchToState(kDoCopyFromRemote);
594         break;
595       }
596       case RPCCode::kCopyToRemote: {
597         SwitchToState(kDoCopyToRemote);
598         break;
599       }
600       case RPCCode::kShutdown: {
601         SwitchToState(kShutdownReceived);
602         break;
603       }
604       case RPCCode::kCopyAck: {
605         SwitchToState(kCopyAckReceived);
606         break;
607       }
608       default: LOG(FATAL) << "Unknown event "  << static_cast<int>(code_);
609     }
610   }
611 
HandleCopyFromRemote()612   void HandleCopyFromRemote() {
613     uint64_t handle, offset, num_bytes;
614     TVMContext ctx;
615     TVMType type_hint;
616     this->Read(&handle);
617     this->Read(&offset);
618     this->Read(&num_bytes);
619     this->Read(&ctx);
620     this->Read(&type_hint);
621     size_t elem_bytes = (type_hint.bits * type_hint.lanes + 7) / 8;
622 
623     if (ctx.device_type == kDLCPU) {
624       RPCCode code = RPCCode::kCopyAck;
625       this->Write(code);
626       char* dptr = reinterpret_cast<char*>(handle) + offset;
627       if (!DMLC_IO_NO_ENDIAN_SWAP) {
628         temp_data_.resize(0);
629         temp_data_.insert(temp_data_.end(), dptr, dptr + num_bytes);
630         dmlc::ByteSwap(dmlc::BeginPtr(temp_data_), elem_bytes, num_bytes / elem_bytes);
631         this->WriteArray(temp_data_.data(), num_bytes);
632       } else {
633         this->WriteArray(dptr, num_bytes);
634       }
635     } else {
636       temp_data_.resize(num_bytes + 1);
637       try {
638         TVMContext cpu_ctx;
639         cpu_ctx.device_type = kDLCPU;
640         cpu_ctx.device_id = 0;
641         DeviceAPI::Get(ctx)->CopyDataFromTo(
642             reinterpret_cast<void*>(handle), offset,
643             dmlc::BeginPtr(temp_data_), 0,
644             num_bytes, ctx, cpu_ctx, type_hint, nullptr);
645         RPCCode code = RPCCode::kCopyAck;
646         this->Write(code);
647         if (!DMLC_IO_NO_ENDIAN_SWAP) {
648           dmlc::ByteSwap(dmlc::BeginPtr(temp_data_), elem_bytes, num_bytes / elem_bytes);
649         }
650         this->WriteArray(&temp_data_[0], num_bytes);
651       } catch (const std::runtime_error &e) {
652         RPCCode code = RPCCode::kException;
653         this->Write(code);
654         TVMValue ret_value;
655         ret_value.v_str = e.what();
656         int ret_tcode = kStr;
657         SendPackedSeq(&ret_value, &ret_tcode, 1);
658       }
659     }
660     this->SwitchToState(kRecvCode);
661   }
662 
HandleCopyToRemote()663   void HandleCopyToRemote() {
664     // use static variable to persist state.
665     // This only works if next stage is immediately after this.
666     if (arg_recv_stage_ == 0) {
667       CHECK(this->Read(&copy_handle_));
668       CHECK(this->Read(&copy_offset_));
669       CHECK(this->Read(&copy_size_));
670       CHECK(this->Read(&copy_ctx_));
671       CHECK(this->Read(&copy_dtype_));
672       arg_recv_stage_ = 1;
673       CHECK_EQ(pending_request_bytes_, 0U);
674       this->RequestBytes(copy_size_);
675     } else {
676       CHECK_EQ(arg_recv_stage_, 1);
677       TVMValue ret_value;
678       ret_value.v_handle = nullptr;
679       int ret_tcode = kNull;
680       RPCCode code = RPCCode::kReturn;
681       std::string errmsg;
682 
683       size_t elem_bytes = (copy_dtype_.bits * copy_dtype_.lanes + 7) / 8;
684       if (copy_ctx_.device_type == kDLCPU) {
685         char* dptr = reinterpret_cast<char*>(copy_handle_) + copy_offset_;
686         this->ReadArray(dptr, copy_size_);
687         if (!DMLC_IO_NO_ENDIAN_SWAP) {
688           dmlc::ByteSwap(dptr, elem_bytes, copy_size_ / elem_bytes);
689         }
690       } else {
691         temp_data_.resize(copy_size_ + 1);
692         this->ReadArray(&temp_data_[0], copy_size_);
693         if (!DMLC_IO_NO_ENDIAN_SWAP) {
694           dmlc::ByteSwap(dmlc::BeginPtr(temp_data_), elem_bytes, copy_size_ / elem_bytes);
695         }
696         try {
697           TVMContext cpu_ctx;
698           cpu_ctx.device_type = kDLCPU;
699           cpu_ctx.device_id = 0;
700           DeviceAPI::Get(copy_ctx_)->CopyDataFromTo(
701               temp_data_.data(), 0,
702               reinterpret_cast<void*>(copy_handle_), copy_offset_,
703               copy_size_, cpu_ctx, copy_ctx_, copy_dtype_, nullptr);
704         } catch (const std::runtime_error &e) {
705           code = RPCCode::kException;
706           errmsg = e.what();
707           ret_value.v_str = errmsg.c_str();
708           ret_tcode = kStr;
709         }
710       }
711       this->Write(code);
712       SendPackedSeq(&ret_value, &ret_tcode, 1);
713       arg_recv_stage_ = 0;
714       this->SwitchToState(kRecvCode);
715     }
716   }
717   // Handle for packed call.
718   void HandlePackedCall();
719 
  /*!
   * \brief Invoke f on the buffered arguments and send the result back.
   *
   * On success a kReturn code is written followed by the packed result; a
   * std::runtime_error caught anywhere in the try block is reported as a
   * kException code followed by the error message.
   *
   * \param f Callable invoked as f(TVMArgs, TVMRetValue*).
   * \tparam F Type of the callable.
   */
  template<typename F>
  void CallHandler(F f) {
    TVMRetValue rv;
    TVMValue ret_value;
    int ret_tcode;
    try {
      // Need to move out, in case f itself need to call RecvPackedSeq
      // Which will override argbuf again.
      std::unique_ptr<RPCArgBuffer> args = std::move(arg_buf_);
      f(args->AsTVMArgs(), &rv);
      // NOTE(review): kReturn is written before the result is serialized; if
      // anything below throws std::runtime_error (e.g. a failed CHECK, when
      // CHECK is configured to throw), the catch block appends kException
      // after the kReturn code already written — confirm this is intended.
      RPCCode code = RPCCode::kReturn;
      this->Write(code);
      if (rv.type_code() == kStr) {
        // The string stays owned by rv while it is being sent.
        ret_value.v_str = rv.ptr<std::string>()->c_str();
        ret_tcode = kStr;
        SendPackedSeq(&ret_value, &ret_tcode, 1);
      } else if (rv.type_code() == kBytes) {
        // Describe the bytes held by rv with a temporary TVMByteArray.
        std::string* bytes = rv.ptr<std::string>();
        TVMByteArray arr;
        arr.data = bytes->c_str();
        arr.size = bytes->length();
        ret_value.v_handle = &arr;
        ret_tcode = kBytes;
        SendPackedSeq(&ret_value, &ret_tcode, 1);
      } else if (rv.type_code() == kFuncHandle ||
                 rv.type_code() == kModuleHandle) {
        // always send handle in 64 bit.
        CHECK(!client_mode_)
              << "Only server can send function and module handle back.";
        // Move the handle out of rv before sending it.
        rv.MoveToCHost(&ret_value, &ret_tcode);
        SendPackedSeq(&ret_value, &ret_tcode, 1);
      } else if (rv.type_code() == kNDArrayContainer) {
        // always send handle in 64 bit.
        CHECK(!client_mode_)
            << "Only server can send NDArray back";
        // We follow a special protocol to return NDArray to client side
        // The first pack value is the NDArray handle as DLTensor
        // The second pack value is a customized deleter that deletes the NDArray.
        TVMValue ret_value_pack[2];
        int ret_tcode_pack[2];
        rv.MoveToCHost(&ret_value_pack[0], &ret_tcode_pack[0]);

        NDArray::Container* nd = static_cast<NDArray::Container*>(ret_value_pack[0].v_handle);
        ret_value_pack[1].v_handle = nd;
        ret_tcode_pack[1] = kHandle;
        // return_ndarray = true: the array is sent with its own context and
        // data pointer.
        SendPackedSeq(ret_value_pack, ret_tcode_pack, 2, true);
      } else {
        // Every other type is sent as the raw value with its type code.
        ret_value = rv.value();
        ret_tcode = rv.type_code();
        SendPackedSeq(&ret_value, &ret_tcode, 1);
      }
    } catch (const std::runtime_error& e) {
      // Report the failure to the peer as an exception carrying the message.
      RPCCode code = RPCCode::kException;
      this->Write(code);
      ret_value.v_str = e.what();
      ret_tcode = kStr;
      SendPackedSeq(&ret_value, &ret_tcode, 1);
    }
  }
779 
780  private:
781   // Utility functions
782   // Internal read function, update pending_request_bytes_
Read(void * data,size_t size)783   size_t Read(void* data, size_t size) final {
784     CHECK_LE(size, pending_request_bytes_);
785     reader_->Read(data, size);
786     pending_request_bytes_ -= size;
787     return size;
788   }
Write(const void * data,size_t size)789   void Write(const void* data, size_t size) final {
790     writer_->Write(data, size);
791   }
  // Number of requested bytes that have not been consumed yet.
  size_t pending_request_bytes_;
  // The ring buffer to read data from (not owned).
  common::RingBuffer* reader_;
  // The ring buffer to write replies to (not owned).
  common::RingBuffer* writer_;
  // Index of the owning session in the session table.
  int rpc_sess_table_index_;
  // Name of session.
  std::string name_;
  // Remote key (not owned); filled in from the initial header when the
  // handler was created with the "%toinit" placeholder.
  std::string* remote_key_;
804 };
805 
806 struct RPCSessTable {
807  public:
808   static constexpr int kMaxRPCSession = 32;
809   // Get global singleton
Globaltvm::runtime::RPCSessTable810   static RPCSessTable* Global() {
811     static RPCSessTable inst;
812     return &inst;
813   }
814   // Get session from table
Gettvm::runtime::RPCSessTable815   std::shared_ptr<RPCSession> Get(int index) {
816     CHECK(index >= 0 && index < kMaxRPCSession);
817     return tbl_[index].lock();
818   }
819   // Insert session into table.
Inserttvm::runtime::RPCSessTable820   int Insert(std::shared_ptr<RPCSession> ptr) {
821     std::lock_guard<std::mutex> lock(mutex_);
822     for (int i = 0; i < kMaxRPCSession; ++i) {
823       if (tbl_[i].lock() == nullptr) {
824         tbl_[i] = ptr; return i;
825       }
826     }
827     LOG(FATAL) << "maximum number of RPC session reached";
828     return 0;
829   }
830 
831  private:
832   // The mutex
833   std::mutex mutex_;
834   // Use weak_ptr intentionally
835   // If the RPCSession get released, the pointer session will be released
836   std::array<std::weak_ptr<RPCSession>, kMaxRPCSession> tbl_;
837 };
838 
HandleUntilReturnEvent(TVMRetValue * rv,bool client_mode,const PackedFunc * fwrap)839 RPCCode RPCSession::HandleUntilReturnEvent(
840     TVMRetValue* rv,  bool client_mode, const PackedFunc* fwrap) {
841   RPCCode code = RPCCode::kCallFunc;
842   while (code != RPCCode::kReturn &&
843          code != RPCCode::kShutdown &&
844          code != RPCCode::kCopyAck) {
845     while (writer_.bytes_available() != 0) {
846       writer_.ReadWithCallback([this](const void *data, size_t size) {
847           return channel_->Send(data, size);
848         }, writer_.bytes_available());
849     }
850     size_t bytes_needed = handler_->BytesNeeded();
851     if (bytes_needed != 0) {
852       size_t n = reader_.WriteWithCallback([this](void* data, size_t size) {
853           return channel_->Recv(data, size);
854         }, bytes_needed);
855       if (n == 0) {
856         if (handler_->CanCleanShutdown()) {
857           return RPCCode::kShutdown;
858         } else {
859           LOG(FATAL) << "Channel closes before we get neded bytes";
860         }
861       }
862     }
863     code = handler_->HandleNextEvent(rv, client_mode, fwrap);
864   }
865   return code;
866 }
867 
Init()868 void RPCSession::Init() {
869   // Event handler
870   handler_ = std::make_shared<EventHandler>(
871       &reader_, &writer_, table_index_, name_, &remote_key_);
872   // Quick function to call remote.
873   call_remote_ = PackedFunc([this](TVMArgs args, TVMRetValue* rv) {
874       handler_->SendPackedSeq(args.values, args.type_codes, args.num_args);
875       RPCCode code = HandleUntilReturnEvent(rv, true, nullptr);
876       CHECK(code == RPCCode::kReturn) << "code=" << static_cast<int>(code);
877     });
878 }
879 
Create(std::unique_ptr<RPCChannel> channel,std::string name,std::string remote_key)880 std::shared_ptr<RPCSession> RPCSession::Create(
881     std::unique_ptr<RPCChannel> channel,
882     std::string name,
883     std::string remote_key) {
884   std::shared_ptr<RPCSession> sess = std::make_shared<RPCSession>();
885   sess->channel_ = std::move(channel);
886   sess->name_ = std::move(name);
887   sess->remote_key_ = std::move(remote_key);
888   sess->table_index_ = RPCSessTable::Global()->Insert(sess);
889   sess->Init();
890   return sess;
891 }
892 
Get(int table_index)893 std::shared_ptr<RPCSession> RPCSession::Get(int table_index) {
894   return RPCSessTable::Global()->Get(table_index);
895 }
896 
~RPCSession()897 RPCSession::~RPCSession() {
898   this->Shutdown();
899 }
900 
Shutdown()901 void RPCSession::Shutdown() {
902   if (channel_ != nullptr) {
903     RPCCode code = RPCCode::kShutdown;
904     handler_->Write(code);
905     // flush all writing buffer to output channel.
906     try {
907       while (writer_.bytes_available() != 0) {
908         size_t n = writer_.ReadWithCallback([this](const void *data, size_t size) {
909             return channel_->Send(data, size);
910           }, writer_.bytes_available());
911         if (n == 0) break;
912       }
913     } catch (const dmlc::Error& e) {
914     }
915     channel_.reset(nullptr);
916   }
917 }
918 
ServerLoop()919 void RPCSession::ServerLoop() {
920   std::lock_guard<std::recursive_mutex> lock(mutex_);
921   if (const auto* f = Registry::Get("tvm.rpc.server.start")) {
922     (*f)();
923   }
924   TVMRetValue rv;
925   CHECK(HandleUntilReturnEvent(&rv, false, nullptr) == RPCCode::kShutdown);
926   if (const auto* f = Registry::Get("tvm.rpc.server.shutdown")) {
927     (*f)();
928   }
929   channel_.reset(nullptr);
930 }
931 
ServerEventHandler(const std::string & bytes,int event_flag)932 int RPCSession::ServerEventHandler(const std::string& bytes, int event_flag) {
933   std::lock_guard<std::recursive_mutex> lock(mutex_);
934   RPCCode code = RPCCode::kNone;
935   if (bytes.length() != 0) {
936     reader_.Write(bytes.c_str(), bytes.length());
937     TVMRetValue rv;
938     code = handler_->HandleNextEvent(&rv, false, nullptr);
939   }
940   if ((event_flag & 2) != 0 && writer_.bytes_available() != 0) {
941     writer_.ReadWithCallback([this](const void *data, size_t size) {
942         return channel_->Send(data, size);
943       }, writer_.bytes_available());
944   }
945   CHECK(code != RPCCode::kReturn && code != RPCCode::kCopyAck);
946   if (code == RPCCode::kShutdown) return 0;
947   if (writer_.bytes_available() != 0) return 2;
948   return 1;
949 }
950 
951 // Get remote function with name
CallFunc(void * h,TVMArgs args,TVMRetValue * rv,const PackedFunc * fwrap)952 void RPCSession::CallFunc(void* h,
953                           TVMArgs args,
954                           TVMRetValue* rv,
955                           const PackedFunc* fwrap) {
956   std::lock_guard<std::recursive_mutex> lock(mutex_);
957   RPCCode code = RPCCode::kCallFunc;
958   handler_->Write(code);
959   uint64_t handle = reinterpret_cast<uint64_t>(h);
960   handler_->Write(handle);
961   handler_->SendPackedSeq(args.values, args.type_codes, args.num_args);
962   code = HandleUntilReturnEvent(rv, true, fwrap);
963   CHECK(code == RPCCode::kReturn) << "code=" << static_cast<int>(code);
964 }
965 
CopyToRemote(void * from,size_t from_offset,void * to,size_t to_offset,size_t data_size,TVMContext ctx_to,TVMType type_hint)966 void RPCSession::CopyToRemote(void* from,
967                               size_t from_offset,
968                               void* to,
969                               size_t to_offset,
970                               size_t data_size,
971                               TVMContext ctx_to,
972                               TVMType type_hint) {
973   std::lock_guard<std::recursive_mutex> lock(mutex_);
974   ctx_to = handler_->StripSessMask(ctx_to);
975   RPCCode code = RPCCode::kCopyToRemote;
976   handler_->Write(code);
977   uint64_t handle = reinterpret_cast<uint64_t>(to);
978   handler_->Write(handle);
979   uint64_t offset = static_cast<uint64_t>(to_offset);
980   handler_->Write(offset);
981   uint64_t size = static_cast<uint64_t>(data_size);
982   handler_->Write(size);
983   handler_->Write(ctx_to);
984   handler_->Write(type_hint);
985   handler_->WriteArray(reinterpret_cast<char*>(from) + from_offset, data_size);
986   TVMRetValue rv;
987   CHECK(HandleUntilReturnEvent(&rv, true, nullptr) == RPCCode::kReturn);
988 }
989 
CopyFromRemote(void * from,size_t from_offset,void * to,size_t to_offset,size_t data_size,TVMContext ctx_from,TVMType type_hint)990 void RPCSession::CopyFromRemote(void* from,
991                                 size_t from_offset,
992                                 void* to,
993                                 size_t to_offset,
994                                 size_t data_size,
995                                 TVMContext ctx_from,
996                                 TVMType type_hint) {
997   std::lock_guard<std::recursive_mutex> lock(mutex_);
998   ctx_from = handler_->StripSessMask(ctx_from);
999   RPCCode code = RPCCode::kCopyFromRemote;
1000   handler_->Write(code);
1001   uint64_t handle = reinterpret_cast<uint64_t>(from);
1002   handler_->Write(handle);
1003   uint64_t offset = static_cast<uint64_t>(from_offset);
1004   handler_->Write(offset);
1005   uint64_t size = static_cast<uint64_t>(data_size);
1006   handler_->Write(size);
1007   handler_->Write(ctx_from);
1008   handler_->Write(type_hint);
1009   TVMRetValue rv;
1010   CHECK(HandleUntilReturnEvent(&rv, true, nullptr) == RPCCode::kCopyAck);
1011   reader_.Reserve(data_size);
1012   handler_->RequestBytes(data_size);
1013   while (!handler_->Ready()) {
1014     size_t bytes_needed = handler_->BytesNeeded();
1015     reader_.WriteWithCallback([this](void* data, size_t size) {
1016         size_t n = channel_->Recv(data, size);
1017         CHECK_NE(n, 0U) << "Channel closes before we get neded bytes";
1018         return n;
1019       }, bytes_needed);
1020   }
1021   handler_->ReadArray(reinterpret_cast<char*>(to) + to_offset, data_size);
1022   handler_->FinishCopyAck();
1023 }
1024 
GetTimeEvaluator(RPCFuncHandle fhandle,TVMContext ctx,int number,int repeat,int min_repeat_ms)1025 RPCFuncHandle RPCSession::GetTimeEvaluator(
1026     RPCFuncHandle fhandle, TVMContext ctx, int number, int repeat, int min_repeat_ms) {
1027   return this->CallRemote(
1028       RPCCode::kGetTimeEvaluator, fhandle, ctx, number, repeat, min_repeat_ms);
1029 }
1030 
1031 // Event handler functions
RPCGetGlobalFunc(TVMArgs args,TVMRetValue * rv)1032 void RPCGetGlobalFunc(TVMArgs args, TVMRetValue* rv) {
1033   std::string name = args[0];
1034   auto *fp = tvm::runtime::Registry::Get(name);
1035   if (fp != nullptr) {
1036     *rv = static_cast<void*>(new tvm::runtime::PackedFunc(*fp));
1037   } else {
1038     *rv = nullptr;
1039   }
1040 }
1041 
RPCFreeFunc(TVMArgs args,TVMRetValue * rv)1042 void RPCFreeFunc(TVMArgs args, TVMRetValue *rv) {
1043   void* handle = args[0];
1044   delete static_cast<PackedFunc*>(handle);
1045 }
1046 
RPCDevSetDevice(TVMArgs args,TVMRetValue * rv)1047 void RPCDevSetDevice(TVMArgs args, TVMRetValue *rv) {
1048   TVMContext ctx = args[0];
1049   DeviceAPI::Get(ctx)->SetDevice(ctx);
1050 }
1051 
RPCDevGetAttr(TVMArgs args,TVMRetValue * rv)1052 void RPCDevGetAttr(TVMArgs args, TVMRetValue *rv) {
1053   TVMContext ctx = args[0];
1054   DeviceAttrKind kind = static_cast<DeviceAttrKind>(args[1].operator int());
1055   if (kind == kExist) {
1056     DeviceAPI* api = DeviceAPI::Get(ctx, true);
1057     if (api != nullptr) {
1058       api->GetAttr(ctx, kind, rv);
1059     } else {
1060       *rv = 0;
1061     }
1062   } else {
1063     DeviceAPI::Get(ctx)->GetAttr(
1064         ctx, static_cast<DeviceAttrKind>(kind), rv);
1065   }
1066 }
1067 
RPCDevAllocData(TVMArgs args,TVMRetValue * rv)1068 void RPCDevAllocData(TVMArgs args, TVMRetValue *rv) {
1069   TVMContext ctx = args[0];
1070   uint64_t nbytes = args[1];
1071   uint64_t alignment = args[2];
1072   TVMType type_hint = args[3];
1073   void* data = DeviceAPI::Get(ctx)->AllocDataSpace(
1074       ctx, nbytes, alignment, type_hint);
1075   *rv = data;
1076 }
1077 
RPCDevFreeData(TVMArgs args,TVMRetValue * rv)1078 void RPCDevFreeData(TVMArgs args, TVMRetValue *rv) {
1079   TVMContext ctx = args[0];
1080   void* ptr = args[1];
1081   DeviceAPI::Get(ctx)->FreeDataSpace(ctx, ptr);
1082 }
1083 
RPCDevStreamSync(TVMArgs args,TVMRetValue * rv)1084 void RPCDevStreamSync(TVMArgs args, TVMRetValue *rv) {
1085   TVMContext ctx = args[0];
1086   TVMStreamHandle handle = args[1];
1087   DeviceAPI::Get(ctx)->StreamSync(ctx, handle);
1088 }
1089 
RPCCopyAmongRemote(TVMArgs args,TVMRetValue * rv)1090 void RPCCopyAmongRemote(TVMArgs args, TVMRetValue *rv) {
1091   void* from = args[0];
1092   uint64_t from_offset = args[1];
1093   void* to = args[2];
1094   uint64_t to_offset = args[3];
1095   uint64_t size = args[4];
1096   TVMContext ctx_from = args[5];
1097   TVMContext ctx_to = args[6];
1098   TVMType type_hint = args[7];
1099   TVMStreamHandle stream = args[8];
1100   TVMContext ctx = ctx_from;
1101   if (ctx.device_type == kDLCPU) {
1102     ctx = ctx_to;
1103   } else {
1104     CHECK(ctx_to.device_type == kDLCPU ||
1105           ctx_to.device_type == ctx_from.device_type)
1106         << "Can not copy across different ctx types directly";
1107   }
1108   DeviceAPI::Get(ctx)->CopyDataFromTo(
1109       from, from_offset,
1110       to, to_offset,
1111       size, ctx_from, ctx_to, type_hint, stream);
1112 }
1113 
RPCModuleLoad(TVMArgs args,TVMRetValue * rv)1114 void RPCModuleLoad(TVMArgs args, TVMRetValue *rv) {
1115   static const PackedFunc* fsys_load_ = nullptr;
1116   if (fsys_load_ == nullptr) {
1117     fsys_load_ = runtime::Registry::Get("tvm.rpc.server.load_module");
1118     CHECK(fsys_load_ != nullptr);
1119   }
1120   std::string file_name = args[0];
1121   TVMRetValue ret = (*fsys_load_)(file_name);
1122   // pass via void*
1123   TVMValue value;
1124   int rcode;
1125   ret.MoveToCHost(&value, &rcode);
1126   CHECK_EQ(rcode, kModuleHandle);
1127   *rv = static_cast<void*>(value.v_handle);
1128 }
1129 
RPCModuleImport(TVMArgs args,TVMRetValue * rv)1130 void RPCModuleImport(TVMArgs args, TVMRetValue *rv) {
1131   void* pmod = args[0];
1132   void* cmod = args[1];
1133   ObjectInternal::GetModuleNode(pmod)->Import(
1134       GetRef<Module>(ObjectInternal::GetModuleNode(cmod)));
1135 }
1136 
RPCModuleFree(TVMArgs args,TVMRetValue * rv)1137 void RPCModuleFree(TVMArgs args, TVMRetValue *rv) {
1138   void* mhandle = args[0];
1139   ObjectInternal::ObjectFree(mhandle);
1140 }
1141 
RPCModuleGetFunc(TVMArgs args,TVMRetValue * rv)1142 void RPCModuleGetFunc(TVMArgs args, TVMRetValue *rv) {
1143   void* mhandle = args[0];
1144   PackedFunc pf = ObjectInternal::GetModuleNode(mhandle)->GetFunction(
1145       args[1], false);
1146   if (pf != nullptr) {
1147     *rv = static_cast<void*>(new PackedFunc(pf));
1148   } else {
1149     *rv = nullptr;
1150   }
1151 }
1152 
RPCModuleGetSource(TVMArgs args,TVMRetValue * rv)1153 void RPCModuleGetSource(TVMArgs args, TVMRetValue *rv) {
1154   void* mhandle = args[0];
1155   std::string fmt = args[1];
1156   *rv = ObjectInternal::GetModuleNode(mhandle)->GetSource(fmt);
1157 }
1158 
RPCNDArrayFree(TVMArgs args,TVMRetValue * rv)1159 void RPCNDArrayFree(TVMArgs args, TVMRetValue *rv) {
1160   void* handle = args[0];
1161   static_cast<NDArray::Container*>(handle)->DecRef();
1162 }
1163 
RPCGetTimeEvaluator(TVMArgs args,TVMRetValue * rv)1164 void RPCGetTimeEvaluator(TVMArgs args, TVMRetValue *rv) {
1165   PackedFunc *pf = static_cast<PackedFunc*>(args[0].operator void*());
1166   void *fhandle = new PackedFunc(WrapTimeEvaluator(*pf, args[1], args[2], args[3], args[4]));
1167   delete pf;
1168   *rv = fhandle;
1169 }
1170 
HandlePackedCall()1171 void RPCSession::EventHandler::HandlePackedCall() {
1172   CHECK_EQ(pending_request_bytes_, 0U);
1173   if (code_ == RPCCode::kReturn) {
1174     state_ = kReturnReceived; return;
1175   }
1176   // reset state to clean init state
1177   state_ = kRecvCode;
1178   this->RequestBytes(sizeof(RPCCode));
1179   // Event handler sit at clean state at this point.
1180   switch (code_) {
1181     case RPCCode::kCallFunc: {
1182       PackedFunc* pf = reinterpret_cast<PackedFunc*>(call_handle_);
1183       CallHandler([pf](TVMArgs args, TVMRetValue* rv) {
1184           pf->CallPacked(args, rv);
1185         });
1186       break;
1187     }
1188     case RPCCode::kException: {
1189       CHECK_EQ(arg_buf_->value.size(), 1U);
1190       CHECK_EQ(arg_buf_->tcode[0], kStr);
1191       std::ostringstream os;
1192       os << "Except caught from RPC call: " << arg_buf_->value[0].v_str;
1193       arg_buf_.reset();
1194       throw dmlc::Error(os.str());
1195       break;
1196     }
1197     // system functions
1198     case RPCCode::kGetTimeEvaluator: CallHandler(RPCGetTimeEvaluator); break;
1199     case RPCCode::kFreeFunc: CallHandler(RPCFreeFunc); break;
1200     case RPCCode::kGetGlobalFunc: CallHandler(RPCGetGlobalFunc); break;
1201     case RPCCode::kDevSetDevice: CallHandler(RPCDevSetDevice); break;
1202     case RPCCode::kDevGetAttr: CallHandler(RPCDevGetAttr); break;
1203     case RPCCode::kDevAllocData: CallHandler(RPCDevAllocData); break;
1204     case RPCCode::kDevFreeData: CallHandler(RPCDevFreeData); break;
1205     case RPCCode::kDevStreamSync: CallHandler(RPCDevStreamSync); break;
1206     case RPCCode::kCopyAmongRemote: CallHandler(RPCCopyAmongRemote); break;
1207     case RPCCode::kModuleLoad: CallHandler(RPCModuleLoad); break;
1208     case RPCCode::kModuleImport: CallHandler(RPCModuleImport); break;
1209     case RPCCode::kModuleFree: CallHandler(RPCModuleFree); break;
1210     case RPCCode::kModuleGetFunc: CallHandler(RPCModuleGetFunc); break;
1211     case RPCCode::kModuleGetSource: CallHandler(RPCModuleGetSource); break;
1212     case RPCCode::kNDArrayFree: CallHandler(RPCNDArrayFree); break;
1213     default: LOG(FATAL) << "Unknown event " << static_cast<int>(code_);
1214   }
1215   CHECK_EQ(state_, kRecvCode);
1216 }
1217 
WrapTimeEvaluator(PackedFunc pf,TVMContext ctx,int number,int repeat,int min_repeat_ms)1218 PackedFunc WrapTimeEvaluator(PackedFunc pf,
1219                              TVMContext ctx,
1220                              int number,
1221                              int repeat,
1222                              int min_repeat_ms) {
1223   auto ftimer = [pf, ctx, number, repeat, min_repeat_ms](TVMArgs args, TVMRetValue *rv) mutable {
1224     TVMRetValue temp;
1225     std::ostringstream os;
1226     // skip first time call, to activate lazy compilation components.
1227     pf.CallPacked(args, &temp);
1228     DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr);
1229 
1230     for (int i = 0; i < repeat; ++i) {
1231       std::chrono::time_point<
1232         std::chrono::high_resolution_clock, std::chrono::nanoseconds> tbegin, tend;
1233       double duration_ms = 0.0;
1234 
1235       do {
1236         if (duration_ms > 0.0) {
1237           number = static_cast<int>(
1238               std::max((min_repeat_ms / (duration_ms / number) + 1),
1239                        number * 1.618));   // 1.618 is chosen by random
1240         }
1241 
1242         tbegin = std::chrono::high_resolution_clock::now();
1243         // start timing
1244         for (int i = 0; i < number; ++i) {
1245           pf.CallPacked(args, &temp);
1246         }
1247         DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr);
1248         tend = std::chrono::high_resolution_clock::now();
1249 
1250         duration_ms = std::chrono::duration_cast<std::chrono::duration<double> >
1251             (tend - tbegin).count() * 1000;
1252       } while (duration_ms < min_repeat_ms);
1253 
1254       double speed = std::chrono::duration_cast<std::chrono::duration<double> >(
1255           tend - tbegin).count() / number;
1256       os.write(reinterpret_cast<char*>(&speed), sizeof(speed));
1257     }
1258     std::string blob = os.str();
1259     TVMByteArray arr;
1260     arr.size = blob.length();
1261     arr.data = blob.data();
1262     // return the time.
1263     *rv = arr;
1264   };
1265   return PackedFunc(ftimer);
1266 }
1267 
Send(const void * data,size_t size)1268 size_t CallbackChannel::Send(const void* data, size_t size) {
1269   TVMByteArray bytes;
1270   bytes.data = static_cast<const char*>(data);
1271   bytes.size = size;
1272   int64_t n = fsend_(bytes);
1273   if (n == -1) {
1274     common::Socket::Error("CallbackChannel::Send");
1275   }
1276   return static_cast<size_t>(n);
1277 }
1278 
Recv(void * data,size_t size)1279 size_t CallbackChannel::Recv(void* data, size_t size) {
1280   TVMRetValue ret = frecv_(size);
1281   if (ret.type_code() != kBytes) {
1282     common::Socket::Error("CallbackChannel::Recv");
1283   }
1284   std::string* bytes = ret.ptr<std::string>();
1285   memcpy(static_cast<char*>(data), bytes->c_str(), bytes->length());
1286   return bytes->length();
1287 }
1288 
1289 }  // namespace runtime
1290 }  // namespace tvm
1291