1 //===----------------- Client.cpp - Client Implementation -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // gRPC (Client) for the remote plugin.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include <cmath>
14 
15 #include "Client.h"
16 #include "omptarget.h"
17 #include "openmp.pb.h"
18 
19 using namespace std::chrono;
20 
21 using grpc::ClientContext;
22 using grpc::ClientReader;
23 using grpc::ClientWriter;
24 using grpc::Status;
25 
26 template <typename Fn1, typename Fn2, typename TReturn>
remoteCall(Fn1 Preprocessor,Fn2 Postprocessor,TReturn ErrorValue,bool CanTimeOut)27 auto RemoteOffloadClient::remoteCall(Fn1 Preprocessor, Fn2 Postprocessor,
28                                      TReturn ErrorValue, bool CanTimeOut) {
29   ArenaAllocatorLock->lock();
30   if (Arena->SpaceAllocated() >= MaxSize)
31     Arena->Reset();
32   ArenaAllocatorLock->unlock();
33 
34   ClientContext Context;
35   if (CanTimeOut) {
36     auto Deadline =
37         std::chrono::system_clock::now() + std::chrono::seconds(Timeout);
38     Context.set_deadline(Deadline);
39   }
40 
41   Status RPCStatus;
42   auto Reply = Preprocessor(RPCStatus, Context);
43 
44   if (!RPCStatus.ok()) {
45     CLIENT_DBG("%s", RPCStatus.error_message().c_str())
46   } else {
47     return Postprocessor(Reply);
48   }
49 
50   CLIENT_DBG("Failed")
51   return ErrorValue;
52 }
53 
shutdown(void)54 int32_t RemoteOffloadClient::shutdown(void) {
55   ClientContext Context;
56   Null Request;
57   I32 Reply;
58   CLIENT_DBG("Shutting down server.")
59   auto Status = Stub->Shutdown(&Context, Request, &Reply);
60   if (Status.ok())
61     return Reply.number();
62   return 1;
63 }
64 
registerLib(__tgt_bin_desc * Desc)65 int32_t RemoteOffloadClient::registerLib(__tgt_bin_desc *Desc) {
66   return remoteCall(
67       /* Preprocessor */
68       [&](auto &RPCStatus, auto &Context) {
69         auto *Request = protobuf::Arena::CreateMessage<TargetBinaryDescription>(
70             Arena.get());
71         auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get());
72         loadTargetBinaryDescription(Desc, *Request);
73         Request->set_bin_ptr((uint64_t)Desc);
74 
75         RPCStatus = Stub->RegisterLib(&Context, *Request, Reply);
76         return Reply;
77       },
78       /* Postprocessor */
79       [&](const auto &Reply) {
80         if (Reply->number() == 0) {
81           CLIENT_DBG("Registered library")
82           return 0;
83         }
84         return 1;
85       },
86       /* Error Value */ 1);
87 }
88 
unregisterLib(__tgt_bin_desc * Desc)89 int32_t RemoteOffloadClient::unregisterLib(__tgt_bin_desc *Desc) {
90   return remoteCall(
91       /* Preprocessor */
92       [&](auto &RPCStatus, auto &Context) {
93         auto *Request = protobuf::Arena::CreateMessage<Pointer>(Arena.get());
94         auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get());
95 
96         Request->set_number((uint64_t)Desc);
97 
98         RPCStatus = Stub->UnregisterLib(&Context, *Request, Reply);
99         return Reply;
100       },
101       /* Postprocessor */
102       [&](const auto &Reply) {
103         if (Reply->number() == 0) {
104           CLIENT_DBG("Unregistered library")
105           return 0;
106         }
107         CLIENT_DBG("Failed to unregister library")
108         return 1;
109       },
110       /* Error Value */ 1);
111 }
112 
isValidBinary(__tgt_device_image * Image)113 int32_t RemoteOffloadClient::isValidBinary(__tgt_device_image *Image) {
114   return remoteCall(
115       /* Preprocessor */
116       [&](auto &RPCStatus, auto &Context) {
117         auto *Request =
118             protobuf::Arena::CreateMessage<TargetDeviceImagePtr>(Arena.get());
119         auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get());
120 
121         Request->set_image_ptr((uint64_t)Image->ImageStart);
122 
123         auto *EntryItr = Image->EntriesBegin;
124         while (EntryItr != Image->EntriesEnd)
125           Request->add_entry_ptrs((uint64_t)EntryItr++);
126 
127         RPCStatus = Stub->IsValidBinary(&Context, *Request, Reply);
128         return Reply;
129       },
130       /* Postprocessor */
131       [&](const auto &Reply) {
132         if (Reply->number()) {
133           CLIENT_DBG("Validated binary")
134         } else {
135           CLIENT_DBG("Could not validate binary")
136         }
137         return Reply->number();
138       },
139       /* Error Value */ 0);
140 }
141 
getNumberOfDevices()142 int32_t RemoteOffloadClient::getNumberOfDevices() {
143   return remoteCall(
144       /* Preprocessor */
145       [&](Status &RPCStatus, ClientContext &Context) {
146         auto *Request = protobuf::Arena::CreateMessage<Null>(Arena.get());
147         auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get());
148 
149         RPCStatus = Stub->GetNumberOfDevices(&Context, *Request, Reply);
150 
151         return Reply;
152       },
153       /* Postprocessor */
154       [&](const auto &Reply) {
155         if (Reply->number()) {
156           CLIENT_DBG("Found %d devices", Reply->number())
157         } else {
158           CLIENT_DBG("Could not get the number of devices")
159         }
160         return Reply->number();
161       },
162       /*Error Value*/ -1);
163 }
164 
initDevice(int32_t DeviceId)165 int32_t RemoteOffloadClient::initDevice(int32_t DeviceId) {
166   return remoteCall(
167       /* Preprocessor */
168       [&](auto &RPCStatus, auto &Context) {
169         auto *Request = protobuf::Arena::CreateMessage<I32>(Arena.get());
170         auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get());
171 
172         Request->set_number(DeviceId);
173 
174         RPCStatus = Stub->InitDevice(&Context, *Request, Reply);
175 
176         return Reply;
177       },
178       /* Postprocessor */
179       [&](const auto &Reply) {
180         if (!Reply->number()) {
181           CLIENT_DBG("Initialized device %d", DeviceId)
182         } else {
183           CLIENT_DBG("Could not initialize device %d", DeviceId)
184         }
185         return Reply->number();
186       },
187       /* Error Value */ -1);
188 }
189 
initRequires(int64_t RequiresFlags)190 int32_t RemoteOffloadClient::initRequires(int64_t RequiresFlags) {
191   return remoteCall(
192       /* Preprocessor */
193       [&](auto &RPCStatus, auto &Context) {
194         auto *Request = protobuf::Arena::CreateMessage<I64>(Arena.get());
195         auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get());
196         Request->set_number(RequiresFlags);
197         RPCStatus = Stub->InitRequires(&Context, *Request, Reply);
198         return Reply;
199       },
200       /* Postprocessor */
201       [&](const auto &Reply) {
202         if (Reply->number()) {
203           CLIENT_DBG("Initialized requires")
204         } else {
205           CLIENT_DBG("Could not initialize requires")
206         }
207         return Reply->number();
208       },
209       /* Error Value */ -1);
210 }
211 
loadBinary(int32_t DeviceId,__tgt_device_image * Image)212 __tgt_target_table *RemoteOffloadClient::loadBinary(int32_t DeviceId,
213                                                     __tgt_device_image *Image) {
214   return remoteCall(
215       /* Preprocessor */
216       [&](auto &RPCStatus, auto &Context) {
217         auto *ImageMessage =
218             protobuf::Arena::CreateMessage<Binary>(Arena.get());
219         auto *Reply = protobuf::Arena::CreateMessage<TargetTable>(Arena.get());
220         ImageMessage->set_image_ptr((uint64_t)Image->ImageStart);
221         ImageMessage->set_device_id(DeviceId);
222 
223         RPCStatus = Stub->LoadBinary(&Context, *ImageMessage, Reply);
224         return Reply;
225       },
226       /* Postprocessor */
227       [&](auto &Reply) {
228         if (Reply->entries_size() == 0) {
229           CLIENT_DBG("Could not load image %p onto device %d", Image, DeviceId)
230           return (__tgt_target_table *)nullptr;
231         }
232         DevicesToTables[DeviceId] = std::make_unique<__tgt_target_table>();
233         unloadTargetTable(*Reply, DevicesToTables[DeviceId].get(),
234                           RemoteEntries[DeviceId]);
235 
236         CLIENT_DBG("Loaded Image %p to device %d with %d entries", Image,
237                    DeviceId, Reply->entries_size())
238 
239         return DevicesToTables[DeviceId].get();
240       },
241       /* Error Value */ (__tgt_target_table *)nullptr,
242       /* CanTimeOut */ false);
243 }
244 
isDataExchangeable(int32_t SrcDevId,int32_t DstDevId)245 int32_t RemoteOffloadClient::isDataExchangeable(int32_t SrcDevId,
246                                                 int32_t DstDevId) {
247   return remoteCall(
248       /* Preprocessor */
249       [&](auto &RPCStatus, auto &Context) {
250         auto *Request = protobuf::Arena::CreateMessage<DevicePair>(Arena.get());
251         auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get());
252 
253         Request->set_src_dev_id(SrcDevId);
254         Request->set_dst_dev_id(DstDevId);
255 
256         RPCStatus = Stub->IsDataExchangeable(&Context, *Request, Reply);
257         return Reply;
258       },
259       /* Postprocessor */
260       [&](auto &Reply) {
261         if (Reply->number()) {
262           CLIENT_DBG("Data is exchangeable between %d, %d", SrcDevId, DstDevId)
263         } else {
264           CLIENT_DBG("Data is not exchangeable between %d, %d", SrcDevId,
265                      DstDevId)
266         }
267         return Reply->number();
268       },
269       /* Error Value */ -1);
270 }
271 
dataAlloc(int32_t DeviceId,int64_t Size,void * HstPtr)272 void *RemoteOffloadClient::dataAlloc(int32_t DeviceId, int64_t Size,
273                                      void *HstPtr) {
274   return remoteCall(
275       /* Preprocessor */
276       [&](auto &RPCStatus, auto &Context) {
277         auto *Reply = protobuf::Arena::CreateMessage<Pointer>(Arena.get());
278         auto *Request = protobuf::Arena::CreateMessage<AllocData>(Arena.get());
279 
280         Request->set_device_id(DeviceId);
281         Request->set_size(Size);
282         Request->set_hst_ptr((uint64_t)HstPtr);
283 
284         RPCStatus = Stub->DataAlloc(&Context, *Request, Reply);
285         return Reply;
286       },
287       /* Postprocessor */
288       [&](auto &Reply) {
289         if (Reply->number()) {
290           CLIENT_DBG("Allocated %ld bytes on device %d at %p", Size, DeviceId,
291                      (void *)Reply->number())
292         } else {
293           CLIENT_DBG("Could not allocate %ld bytes on device %d at %p", Size,
294                      DeviceId, (void *)Reply->number())
295         }
296         return (void *)Reply->number();
297       },
298       /* Error Value */ (void *)nullptr);
299 }
300 
dataSubmit(int32_t DeviceId,void * TgtPtr,void * HstPtr,int64_t Size)301 int32_t RemoteOffloadClient::dataSubmit(int32_t DeviceId, void *TgtPtr,
302                                         void *HstPtr, int64_t Size) {
303 
304   return remoteCall(
305       /* Preprocessor */
306       [&](auto &RPCStatus, auto &Context) {
307         auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get());
308         std::unique_ptr<ClientWriter<SubmitData>> Writer(
309             Stub->DataSubmit(&Context, Reply));
310 
311         if (Size > BlockSize) {
312           int64_t Start = 0, End = BlockSize;
313           for (auto I = 0; I < ceil((float)Size / BlockSize); I++) {
314             auto *Request =
315                 protobuf::Arena::CreateMessage<SubmitData>(Arena.get());
316 
317             Request->set_device_id(DeviceId);
318             Request->set_data((char *)HstPtr + Start, End - Start);
319             Request->set_hst_ptr((uint64_t)HstPtr);
320             Request->set_tgt_ptr((uint64_t)TgtPtr);
321             Request->set_start(Start);
322             Request->set_size(Size);
323 
324             if (!Writer->Write(*Request)) {
325               CLIENT_DBG("Broken stream when submitting data")
326               Reply->set_number(0);
327               return Reply;
328             }
329 
330             Start += BlockSize;
331             End += BlockSize;
332             if (End >= Size)
333               End = Size;
334           }
335         } else {
336           auto *Request =
337               protobuf::Arena::CreateMessage<SubmitData>(Arena.get());
338 
339           Request->set_device_id(DeviceId);
340           Request->set_data(HstPtr, Size);
341           Request->set_hst_ptr((uint64_t)HstPtr);
342           Request->set_tgt_ptr((uint64_t)TgtPtr);
343           Request->set_start(0);
344           Request->set_size(Size);
345 
346           if (!Writer->Write(*Request)) {
347             CLIENT_DBG("Broken stream when submitting data")
348             Reply->set_number(0);
349             return Reply;
350           }
351         }
352 
353         Writer->WritesDone();
354         RPCStatus = Writer->Finish();
355 
356         return Reply;
357       },
358       /* Postprocessor */
359       [&](auto &Reply) {
360         if (!Reply->number()) {
361           CLIENT_DBG(" submitted %ld bytes on device %d at %p", Size, DeviceId,
362                      TgtPtr)
363         } else {
364           CLIENT_DBG("Could not async submit %ld bytes on device %d at %p",
365                      Size, DeviceId, TgtPtr)
366         }
367         return Reply->number();
368       },
369       /* Error Value */ -1,
370       /* CanTimeOut */ false);
371 }
372 
dataRetrieve(int32_t DeviceId,void * HstPtr,void * TgtPtr,int64_t Size)373 int32_t RemoteOffloadClient::dataRetrieve(int32_t DeviceId, void *HstPtr,
374                                           void *TgtPtr, int64_t Size) {
375   return remoteCall(
376       /* Preprocessor */
377       [&](auto &RPCStatus, auto &Context) {
378         auto *Request =
379             protobuf::Arena::CreateMessage<RetrieveData>(Arena.get());
380 
381         Request->set_device_id(DeviceId);
382         Request->set_size(Size);
383         Request->set_hst_ptr((int64_t)HstPtr);
384         Request->set_tgt_ptr((int64_t)TgtPtr);
385 
386         auto *Reply = protobuf::Arena::CreateMessage<Data>(Arena.get());
387         std::unique_ptr<ClientReader<Data>> Reader(
388             Stub->DataRetrieve(&Context, *Request));
389         Reader->WaitForInitialMetadata();
390         while (Reader->Read(Reply)) {
391           if (Reply->ret()) {
392             CLIENT_DBG("Could not async retrieve %ld bytes on device %d at %p "
393                        "for %p",
394                        Size, DeviceId, TgtPtr, HstPtr)
395             return Reply;
396           }
397 
398           if (Reply->start() == 0 && Reply->size() == Reply->data().size()) {
399             memcpy(HstPtr, Reply->data().data(), Reply->data().size());
400 
401             return Reply;
402           }
403 
404           memcpy((void *)((char *)HstPtr + Reply->start()),
405                  Reply->data().data(), Reply->data().size());
406         }
407         RPCStatus = Reader->Finish();
408 
409         return Reply;
410       },
411       /* Postprocessor */
412       [&](auto &Reply) {
413         if (!Reply->ret()) {
414           CLIENT_DBG("Retrieved %ld bytes on Device %d", Size, DeviceId)
415         } else {
416           CLIENT_DBG("Could not async retrieve %ld bytes on Device %d", Size,
417                      DeviceId)
418         }
419         return Reply->ret();
420       },
421       /* Error Value */ -1,
422       /* CanTimeOut */ false);
423 }
424 
dataExchange(int32_t SrcDevId,void * SrcPtr,int32_t DstDevId,void * DstPtr,int64_t Size)425 int32_t RemoteOffloadClient::dataExchange(int32_t SrcDevId, void *SrcPtr,
426                                           int32_t DstDevId, void *DstPtr,
427                                           int64_t Size) {
428   return remoteCall(
429       /* Preprocessor */
430       [&](auto &RPCStatus, auto &Context) {
431         auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get());
432         auto *Request =
433             protobuf::Arena::CreateMessage<ExchangeData>(Arena.get());
434 
435         Request->set_src_dev_id(SrcDevId);
436         Request->set_src_ptr((uint64_t)SrcPtr);
437         Request->set_dst_dev_id(DstDevId);
438         Request->set_dst_ptr((uint64_t)DstPtr);
439         Request->set_size(Size);
440 
441         RPCStatus = Stub->DataExchange(&Context, *Request, Reply);
442         return Reply;
443       },
444       /* Postprocessor */
445       [&](auto &Reply) {
446         if (Reply->number()) {
447           CLIENT_DBG(
448               "Exchanged %ld bytes on device %d at %p for %p on device %d",
449               Size, SrcDevId, SrcPtr, DstPtr, DstDevId)
450         } else {
451           CLIENT_DBG("Could not exchange %ld bytes on device %d at %p for %p "
452                      "on device %d",
453                      Size, SrcDevId, SrcPtr, DstPtr, DstDevId)
454         }
455         return Reply->number();
456       },
457       /* Error Value */ -1);
458 }
459 
dataDelete(int32_t DeviceId,void * TgtPtr)460 int32_t RemoteOffloadClient::dataDelete(int32_t DeviceId, void *TgtPtr) {
461   return remoteCall(
462       /* Preprocessor */
463       [&](auto &RPCStatus, auto &Context) {
464         auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get());
465         auto *Request = protobuf::Arena::CreateMessage<DeleteData>(Arena.get());
466 
467         Request->set_device_id(DeviceId);
468         Request->set_tgt_ptr((uint64_t)TgtPtr);
469 
470         RPCStatus = Stub->DataDelete(&Context, *Request, Reply);
471         return Reply;
472       },
473       /* Postprocessor */
474       [&](auto &Reply) {
475         if (!Reply->number()) {
476           CLIENT_DBG("Deleted data at %p on device %d", TgtPtr, DeviceId)
477         } else {
478           CLIENT_DBG("Could not delete data at %p on device %d", TgtPtr,
479                      DeviceId)
480         }
481         return Reply->number();
482       },
483       /* Error Value */ -1);
484 }
485 
runTargetRegion(int32_t DeviceId,void * TgtEntryPtr,void ** TgtArgs,ptrdiff_t * TgtOffsets,int32_t ArgNum)486 int32_t RemoteOffloadClient::runTargetRegion(int32_t DeviceId,
487                                              void *TgtEntryPtr, void **TgtArgs,
488                                              ptrdiff_t *TgtOffsets,
489                                              int32_t ArgNum) {
490   return remoteCall(
491       /* Preprocessor */
492       [&](auto &RPCStatus, auto &Context) {
493         auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get());
494         auto *Request =
495             protobuf::Arena::CreateMessage<TargetRegion>(Arena.get());
496 
497         Request->set_device_id(DeviceId);
498 
499         Request->set_tgt_entry_ptr(
500             (uint64_t)RemoteEntries[DeviceId][TgtEntryPtr]);
501 
502         char **ArgPtr = (char **)TgtArgs;
503         for (auto I = 0; I < ArgNum; I++, ArgPtr++)
504           Request->add_tgt_args((uint64_t)*ArgPtr);
505 
506         char *OffsetPtr = (char *)TgtOffsets;
507         for (auto I = 0; I < ArgNum; I++, OffsetPtr++)
508           Request->add_tgt_offsets((uint64_t)*OffsetPtr);
509 
510         Request->set_arg_num(ArgNum);
511 
512         RPCStatus = Stub->RunTargetRegion(&Context, *Request, Reply);
513         return Reply;
514       },
515       /* Postprocessor */
516       [&](auto &Reply) {
517         if (!Reply->number()) {
518           CLIENT_DBG("Ran target region async on device %d", DeviceId)
519         } else {
520           CLIENT_DBG("Could not run target region async on device %d", DeviceId)
521         }
522         return Reply->number();
523       },
524       /* Error Value */ -1,
525       /* CanTimeOut */ false);
526 }
527 
runTargetTeamRegion(int32_t DeviceId,void * TgtEntryPtr,void ** TgtArgs,ptrdiff_t * TgtOffsets,int32_t ArgNum,int32_t TeamNum,int32_t ThreadLimit,uint64_t LoopTripcount)528 int32_t RemoteOffloadClient::runTargetTeamRegion(
529     int32_t DeviceId, void *TgtEntryPtr, void **TgtArgs, ptrdiff_t *TgtOffsets,
530     int32_t ArgNum, int32_t TeamNum, int32_t ThreadLimit,
531     uint64_t LoopTripcount) {
532   return remoteCall(
533       /* Preprocessor */
534       [&](auto &RPCStatus, auto &Context) {
535         auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get());
536         auto *Request =
537             protobuf::Arena::CreateMessage<TargetTeamRegion>(Arena.get());
538 
539         Request->set_device_id(DeviceId);
540 
541         Request->set_tgt_entry_ptr(
542             (uint64_t)RemoteEntries[DeviceId][TgtEntryPtr]);
543 
544         char **ArgPtr = (char **)TgtArgs;
545         for (auto I = 0; I < ArgNum; I++, ArgPtr++) {
546           Request->add_tgt_args((uint64_t)*ArgPtr);
547         }
548 
549         char *OffsetPtr = (char *)TgtOffsets;
550         for (auto I = 0; I < ArgNum; I++, OffsetPtr++)
551           Request->add_tgt_offsets((uint64_t)*OffsetPtr);
552 
553         Request->set_arg_num(ArgNum);
554         Request->set_team_num(TeamNum);
555         Request->set_thread_limit(ThreadLimit);
556         Request->set_loop_tripcount(LoopTripcount);
557 
558         RPCStatus = Stub->RunTargetTeamRegion(&Context, *Request, Reply);
559         return Reply;
560       },
561       /* Postprocessor */
562       [&](auto &Reply) {
563         if (!Reply->number()) {
564           CLIENT_DBG("Ran target team region async on device %d", DeviceId)
565         } else {
566           CLIENT_DBG("Could not run target team region async on device %d",
567                      DeviceId)
568         }
569         return Reply->number();
570       },
571       /* Error Value */ -1,
572       /* CanTimeOut */ false);
573 }
574 
shutdown(void)575 int32_t RemoteClientManager::shutdown(void) {
576   int32_t Ret = 0;
577   for (auto &Client : Clients)
578     Ret &= Client.shutdown();
579   return Ret;
580 }
581 
registerLib(__tgt_bin_desc * Desc)582 int32_t RemoteClientManager::registerLib(__tgt_bin_desc *Desc) {
583   int32_t Ret = 0;
584   for (auto &Client : Clients)
585     Ret &= Client.registerLib(Desc);
586   return Ret;
587 }
588 
unregisterLib(__tgt_bin_desc * Desc)589 int32_t RemoteClientManager::unregisterLib(__tgt_bin_desc *Desc) {
590   int32_t Ret = 0;
591   for (auto &Client : Clients)
592     Ret &= Client.unregisterLib(Desc);
593   return Ret;
594 }
595 
isValidBinary(__tgt_device_image * Image)596 int32_t RemoteClientManager::isValidBinary(__tgt_device_image *Image) {
597   int32_t ClientIdx = 0;
598   for (auto &Client : Clients) {
599     if (auto Ret = Client.isValidBinary(Image))
600       return Ret;
601     ClientIdx++;
602   }
603   return 0;
604 }
605 
getNumberOfDevices()606 int32_t RemoteClientManager::getNumberOfDevices() {
607   auto ClientIdx = 0;
608   for (auto &Client : Clients) {
609     if (auto NumDevices = Client.getNumberOfDevices()) {
610       Devices.push_back(NumDevices);
611     }
612     ClientIdx++;
613   }
614 
615   return std::accumulate(Devices.begin(), Devices.end(), 0);
616 }
617 
mapDeviceId(int32_t DeviceId)618 std::pair<int32_t, int32_t> RemoteClientManager::mapDeviceId(int32_t DeviceId) {
619   for (size_t ClientIdx = 0; ClientIdx < Devices.size(); ClientIdx++) {
620     if (DeviceId < Devices[ClientIdx])
621       return {ClientIdx, DeviceId};
622     DeviceId -= Devices[ClientIdx];
623   }
624   return {-1, -1};
625 }
626 
initDevice(int32_t DeviceId)627 int32_t RemoteClientManager::initDevice(int32_t DeviceId) {
628   int32_t ClientIdx, DeviceIdx;
629   std::tie(ClientIdx, DeviceIdx) = mapDeviceId(DeviceId);
630   return Clients[ClientIdx].initDevice(DeviceIdx);
631 }
632 
initRequires(int64_t RequiresFlags)633 int32_t RemoteClientManager::initRequires(int64_t RequiresFlags) {
634   for (auto &Client : Clients)
635     Client.initRequires(RequiresFlags);
636 
637   return RequiresFlags;
638 }
639 
loadBinary(int32_t DeviceId,__tgt_device_image * Image)640 __tgt_target_table *RemoteClientManager::loadBinary(int32_t DeviceId,
641                                                     __tgt_device_image *Image) {
642   int32_t ClientIdx, DeviceIdx;
643   std::tie(ClientIdx, DeviceIdx) = mapDeviceId(DeviceId);
644   return Clients[ClientIdx].loadBinary(DeviceIdx, Image);
645 }
646 
isDataExchangeable(int32_t SrcDevId,int32_t DstDevId)647 int32_t RemoteClientManager::isDataExchangeable(int32_t SrcDevId,
648                                                 int32_t DstDevId) {
649   int32_t SrcClientIdx, SrcDeviceIdx, DstClientIdx, DstDeviceIdx;
650   std::tie(SrcClientIdx, SrcDeviceIdx) = mapDeviceId(SrcDevId);
651   std::tie(DstClientIdx, DstDeviceIdx) = mapDeviceId(DstDevId);
652   return Clients[SrcClientIdx].isDataExchangeable(SrcDeviceIdx, DstDeviceIdx);
653 }
654 
dataAlloc(int32_t DeviceId,int64_t Size,void * HstPtr)655 void *RemoteClientManager::dataAlloc(int32_t DeviceId, int64_t Size,
656                                      void *HstPtr) {
657   int32_t ClientIdx, DeviceIdx;
658   std::tie(ClientIdx, DeviceIdx) = mapDeviceId(DeviceId);
659   return Clients[ClientIdx].dataAlloc(DeviceIdx, Size, HstPtr);
660 }
661 
dataDelete(int32_t DeviceId,void * TgtPtr)662 int32_t RemoteClientManager::dataDelete(int32_t DeviceId, void *TgtPtr) {
663   int32_t ClientIdx, DeviceIdx;
664   std::tie(ClientIdx, DeviceIdx) = mapDeviceId(DeviceId);
665   return Clients[ClientIdx].dataDelete(DeviceIdx, TgtPtr);
666 }
667 
dataSubmit(int32_t DeviceId,void * TgtPtr,void * HstPtr,int64_t Size)668 int32_t RemoteClientManager::dataSubmit(int32_t DeviceId, void *TgtPtr,
669                                         void *HstPtr, int64_t Size) {
670   int32_t ClientIdx, DeviceIdx;
671   std::tie(ClientIdx, DeviceIdx) = mapDeviceId(DeviceId);
672   return Clients[ClientIdx].dataSubmit(DeviceIdx, TgtPtr, HstPtr, Size);
673 }
674 
dataRetrieve(int32_t DeviceId,void * HstPtr,void * TgtPtr,int64_t Size)675 int32_t RemoteClientManager::dataRetrieve(int32_t DeviceId, void *HstPtr,
676                                           void *TgtPtr, int64_t Size) {
677   int32_t ClientIdx, DeviceIdx;
678   std::tie(ClientIdx, DeviceIdx) = mapDeviceId(DeviceId);
679   return Clients[ClientIdx].dataRetrieve(DeviceIdx, HstPtr, TgtPtr, Size);
680 }
681 
dataExchange(int32_t SrcDevId,void * SrcPtr,int32_t DstDevId,void * DstPtr,int64_t Size)682 int32_t RemoteClientManager::dataExchange(int32_t SrcDevId, void *SrcPtr,
683                                           int32_t DstDevId, void *DstPtr,
684                                           int64_t Size) {
685   int32_t SrcClientIdx, SrcDeviceIdx, DstClientIdx, DstDeviceIdx;
686   std::tie(SrcClientIdx, SrcDeviceIdx) = mapDeviceId(SrcDevId);
687   std::tie(DstClientIdx, DstDeviceIdx) = mapDeviceId(DstDevId);
688   return Clients[SrcClientIdx].dataExchange(SrcDeviceIdx, SrcPtr, DstDeviceIdx,
689                                             DstPtr, Size);
690 }
691 
runTargetRegion(int32_t DeviceId,void * TgtEntryPtr,void ** TgtArgs,ptrdiff_t * TgtOffsets,int32_t ArgNum)692 int32_t RemoteClientManager::runTargetRegion(int32_t DeviceId,
693                                              void *TgtEntryPtr, void **TgtArgs,
694                                              ptrdiff_t *TgtOffsets,
695                                              int32_t ArgNum) {
696   int32_t ClientIdx, DeviceIdx;
697   std::tie(ClientIdx, DeviceIdx) = mapDeviceId(DeviceId);
698   return Clients[ClientIdx].runTargetRegion(DeviceIdx, TgtEntryPtr, TgtArgs,
699                                             TgtOffsets, ArgNum);
700 }
701 
runTargetTeamRegion(int32_t DeviceId,void * TgtEntryPtr,void ** TgtArgs,ptrdiff_t * TgtOffsets,int32_t ArgNum,int32_t TeamNum,int32_t ThreadLimit,uint64_t LoopTripCount)702 int32_t RemoteClientManager::runTargetTeamRegion(
703     int32_t DeviceId, void *TgtEntryPtr, void **TgtArgs, ptrdiff_t *TgtOffsets,
704     int32_t ArgNum, int32_t TeamNum, int32_t ThreadLimit,
705     uint64_t LoopTripCount) {
706   int32_t ClientIdx, DeviceIdx;
707   std::tie(ClientIdx, DeviceIdx) = mapDeviceId(DeviceId);
708   return Clients[ClientIdx].runTargetTeamRegion(DeviceIdx, TgtEntryPtr, TgtArgs,
709                                                 TgtOffsets, ArgNum, TeamNum,
710                                                 ThreadLimit, LoopTripCount);
711 }
712