1 //===----------------- Client.cpp - Client Implementation -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // gRPC (Client) for the remote plugin.
10 //
11 //===----------------------------------------------------------------------===//
12
#include <cmath>
#include <mutex>
#include <numeric>

#include "Client.h"
#include "omptarget.h"
#include "openmp.pb.h"
18
19 using namespace std::chrono;
20
21 using grpc::ClientContext;
22 using grpc::ClientReader;
23 using grpc::ClientWriter;
24 using grpc::Status;
25
26 template <typename Fn1, typename Fn2, typename TReturn>
remoteCall(Fn1 Preprocessor,Fn2 Postprocessor,TReturn ErrorValue,bool CanTimeOut)27 auto RemoteOffloadClient::remoteCall(Fn1 Preprocessor, Fn2 Postprocessor,
28 TReturn ErrorValue, bool CanTimeOut) {
29 ArenaAllocatorLock->lock();
30 if (Arena->SpaceAllocated() >= MaxSize)
31 Arena->Reset();
32 ArenaAllocatorLock->unlock();
33
34 ClientContext Context;
35 if (CanTimeOut) {
36 auto Deadline =
37 std::chrono::system_clock::now() + std::chrono::seconds(Timeout);
38 Context.set_deadline(Deadline);
39 }
40
41 Status RPCStatus;
42 auto Reply = Preprocessor(RPCStatus, Context);
43
44 if (!RPCStatus.ok()) {
45 CLIENT_DBG("%s", RPCStatus.error_message().c_str())
46 } else {
47 return Postprocessor(Reply);
48 }
49
50 CLIENT_DBG("Failed")
51 return ErrorValue;
52 }
53
shutdown(void)54 int32_t RemoteOffloadClient::shutdown(void) {
55 ClientContext Context;
56 Null Request;
57 I32 Reply;
58 CLIENT_DBG("Shutting down server.")
59 auto Status = Stub->Shutdown(&Context, Request, &Reply);
60 if (Status.ok())
61 return Reply.number();
62 return 1;
63 }
64
registerLib(__tgt_bin_desc * Desc)65 int32_t RemoteOffloadClient::registerLib(__tgt_bin_desc *Desc) {
66 return remoteCall(
67 /* Preprocessor */
68 [&](auto &RPCStatus, auto &Context) {
69 auto *Request = protobuf::Arena::CreateMessage<TargetBinaryDescription>(
70 Arena.get());
71 auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get());
72 loadTargetBinaryDescription(Desc, *Request);
73 Request->set_bin_ptr((uint64_t)Desc);
74
75 RPCStatus = Stub->RegisterLib(&Context, *Request, Reply);
76 return Reply;
77 },
78 /* Postprocessor */
79 [&](const auto &Reply) {
80 if (Reply->number() == 0) {
81 CLIENT_DBG("Registered library")
82 return 0;
83 }
84 return 1;
85 },
86 /* Error Value */ 1);
87 }
88
unregisterLib(__tgt_bin_desc * Desc)89 int32_t RemoteOffloadClient::unregisterLib(__tgt_bin_desc *Desc) {
90 return remoteCall(
91 /* Preprocessor */
92 [&](auto &RPCStatus, auto &Context) {
93 auto *Request = protobuf::Arena::CreateMessage<Pointer>(Arena.get());
94 auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get());
95
96 Request->set_number((uint64_t)Desc);
97
98 RPCStatus = Stub->UnregisterLib(&Context, *Request, Reply);
99 return Reply;
100 },
101 /* Postprocessor */
102 [&](const auto &Reply) {
103 if (Reply->number() == 0) {
104 CLIENT_DBG("Unregistered library")
105 return 0;
106 }
107 CLIENT_DBG("Failed to unregister library")
108 return 1;
109 },
110 /* Error Value */ 1);
111 }
112
isValidBinary(__tgt_device_image * Image)113 int32_t RemoteOffloadClient::isValidBinary(__tgt_device_image *Image) {
114 return remoteCall(
115 /* Preprocessor */
116 [&](auto &RPCStatus, auto &Context) {
117 auto *Request =
118 protobuf::Arena::CreateMessage<TargetDeviceImagePtr>(Arena.get());
119 auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get());
120
121 Request->set_image_ptr((uint64_t)Image->ImageStart);
122
123 auto *EntryItr = Image->EntriesBegin;
124 while (EntryItr != Image->EntriesEnd)
125 Request->add_entry_ptrs((uint64_t)EntryItr++);
126
127 RPCStatus = Stub->IsValidBinary(&Context, *Request, Reply);
128 return Reply;
129 },
130 /* Postprocessor */
131 [&](const auto &Reply) {
132 if (Reply->number()) {
133 CLIENT_DBG("Validated binary")
134 } else {
135 CLIENT_DBG("Could not validate binary")
136 }
137 return Reply->number();
138 },
139 /* Error Value */ 0);
140 }
141
getNumberOfDevices()142 int32_t RemoteOffloadClient::getNumberOfDevices() {
143 return remoteCall(
144 /* Preprocessor */
145 [&](Status &RPCStatus, ClientContext &Context) {
146 auto *Request = protobuf::Arena::CreateMessage<Null>(Arena.get());
147 auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get());
148
149 RPCStatus = Stub->GetNumberOfDevices(&Context, *Request, Reply);
150
151 return Reply;
152 },
153 /* Postprocessor */
154 [&](const auto &Reply) {
155 if (Reply->number()) {
156 CLIENT_DBG("Found %d devices", Reply->number())
157 } else {
158 CLIENT_DBG("Could not get the number of devices")
159 }
160 return Reply->number();
161 },
162 /*Error Value*/ -1);
163 }
164
initDevice(int32_t DeviceId)165 int32_t RemoteOffloadClient::initDevice(int32_t DeviceId) {
166 return remoteCall(
167 /* Preprocessor */
168 [&](auto &RPCStatus, auto &Context) {
169 auto *Request = protobuf::Arena::CreateMessage<I32>(Arena.get());
170 auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get());
171
172 Request->set_number(DeviceId);
173
174 RPCStatus = Stub->InitDevice(&Context, *Request, Reply);
175
176 return Reply;
177 },
178 /* Postprocessor */
179 [&](const auto &Reply) {
180 if (!Reply->number()) {
181 CLIENT_DBG("Initialized device %d", DeviceId)
182 } else {
183 CLIENT_DBG("Could not initialize device %d", DeviceId)
184 }
185 return Reply->number();
186 },
187 /* Error Value */ -1);
188 }
189
initRequires(int64_t RequiresFlags)190 int32_t RemoteOffloadClient::initRequires(int64_t RequiresFlags) {
191 return remoteCall(
192 /* Preprocessor */
193 [&](auto &RPCStatus, auto &Context) {
194 auto *Request = protobuf::Arena::CreateMessage<I64>(Arena.get());
195 auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get());
196 Request->set_number(RequiresFlags);
197 RPCStatus = Stub->InitRequires(&Context, *Request, Reply);
198 return Reply;
199 },
200 /* Postprocessor */
201 [&](const auto &Reply) {
202 if (Reply->number()) {
203 CLIENT_DBG("Initialized requires")
204 } else {
205 CLIENT_DBG("Could not initialize requires")
206 }
207 return Reply->number();
208 },
209 /* Error Value */ -1);
210 }
211
loadBinary(int32_t DeviceId,__tgt_device_image * Image)212 __tgt_target_table *RemoteOffloadClient::loadBinary(int32_t DeviceId,
213 __tgt_device_image *Image) {
214 return remoteCall(
215 /* Preprocessor */
216 [&](auto &RPCStatus, auto &Context) {
217 auto *ImageMessage =
218 protobuf::Arena::CreateMessage<Binary>(Arena.get());
219 auto *Reply = protobuf::Arena::CreateMessage<TargetTable>(Arena.get());
220 ImageMessage->set_image_ptr((uint64_t)Image->ImageStart);
221 ImageMessage->set_device_id(DeviceId);
222
223 RPCStatus = Stub->LoadBinary(&Context, *ImageMessage, Reply);
224 return Reply;
225 },
226 /* Postprocessor */
227 [&](auto &Reply) {
228 if (Reply->entries_size() == 0) {
229 CLIENT_DBG("Could not load image %p onto device %d", Image, DeviceId)
230 return (__tgt_target_table *)nullptr;
231 }
232 DevicesToTables[DeviceId] = std::make_unique<__tgt_target_table>();
233 unloadTargetTable(*Reply, DevicesToTables[DeviceId].get(),
234 RemoteEntries[DeviceId]);
235
236 CLIENT_DBG("Loaded Image %p to device %d with %d entries", Image,
237 DeviceId, Reply->entries_size())
238
239 return DevicesToTables[DeviceId].get();
240 },
241 /* Error Value */ (__tgt_target_table *)nullptr,
242 /* CanTimeOut */ false);
243 }
244
isDataExchangeable(int32_t SrcDevId,int32_t DstDevId)245 int32_t RemoteOffloadClient::isDataExchangeable(int32_t SrcDevId,
246 int32_t DstDevId) {
247 return remoteCall(
248 /* Preprocessor */
249 [&](auto &RPCStatus, auto &Context) {
250 auto *Request = protobuf::Arena::CreateMessage<DevicePair>(Arena.get());
251 auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get());
252
253 Request->set_src_dev_id(SrcDevId);
254 Request->set_dst_dev_id(DstDevId);
255
256 RPCStatus = Stub->IsDataExchangeable(&Context, *Request, Reply);
257 return Reply;
258 },
259 /* Postprocessor */
260 [&](auto &Reply) {
261 if (Reply->number()) {
262 CLIENT_DBG("Data is exchangeable between %d, %d", SrcDevId, DstDevId)
263 } else {
264 CLIENT_DBG("Data is not exchangeable between %d, %d", SrcDevId,
265 DstDevId)
266 }
267 return Reply->number();
268 },
269 /* Error Value */ -1);
270 }
271
dataAlloc(int32_t DeviceId,int64_t Size,void * HstPtr)272 void *RemoteOffloadClient::dataAlloc(int32_t DeviceId, int64_t Size,
273 void *HstPtr) {
274 return remoteCall(
275 /* Preprocessor */
276 [&](auto &RPCStatus, auto &Context) {
277 auto *Reply = protobuf::Arena::CreateMessage<Pointer>(Arena.get());
278 auto *Request = protobuf::Arena::CreateMessage<AllocData>(Arena.get());
279
280 Request->set_device_id(DeviceId);
281 Request->set_size(Size);
282 Request->set_hst_ptr((uint64_t)HstPtr);
283
284 RPCStatus = Stub->DataAlloc(&Context, *Request, Reply);
285 return Reply;
286 },
287 /* Postprocessor */
288 [&](auto &Reply) {
289 if (Reply->number()) {
290 CLIENT_DBG("Allocated %ld bytes on device %d at %p", Size, DeviceId,
291 (void *)Reply->number())
292 } else {
293 CLIENT_DBG("Could not allocate %ld bytes on device %d at %p", Size,
294 DeviceId, (void *)Reply->number())
295 }
296 return (void *)Reply->number();
297 },
298 /* Error Value */ (void *)nullptr);
299 }
300
dataSubmit(int32_t DeviceId,void * TgtPtr,void * HstPtr,int64_t Size)301 int32_t RemoteOffloadClient::dataSubmit(int32_t DeviceId, void *TgtPtr,
302 void *HstPtr, int64_t Size) {
303
304 return remoteCall(
305 /* Preprocessor */
306 [&](auto &RPCStatus, auto &Context) {
307 auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get());
308 std::unique_ptr<ClientWriter<SubmitData>> Writer(
309 Stub->DataSubmit(&Context, Reply));
310
311 if (Size > BlockSize) {
312 int64_t Start = 0, End = BlockSize;
313 for (auto I = 0; I < ceil((float)Size / BlockSize); I++) {
314 auto *Request =
315 protobuf::Arena::CreateMessage<SubmitData>(Arena.get());
316
317 Request->set_device_id(DeviceId);
318 Request->set_data((char *)HstPtr + Start, End - Start);
319 Request->set_hst_ptr((uint64_t)HstPtr);
320 Request->set_tgt_ptr((uint64_t)TgtPtr);
321 Request->set_start(Start);
322 Request->set_size(Size);
323
324 if (!Writer->Write(*Request)) {
325 CLIENT_DBG("Broken stream when submitting data")
326 Reply->set_number(0);
327 return Reply;
328 }
329
330 Start += BlockSize;
331 End += BlockSize;
332 if (End >= Size)
333 End = Size;
334 }
335 } else {
336 auto *Request =
337 protobuf::Arena::CreateMessage<SubmitData>(Arena.get());
338
339 Request->set_device_id(DeviceId);
340 Request->set_data(HstPtr, Size);
341 Request->set_hst_ptr((uint64_t)HstPtr);
342 Request->set_tgt_ptr((uint64_t)TgtPtr);
343 Request->set_start(0);
344 Request->set_size(Size);
345
346 if (!Writer->Write(*Request)) {
347 CLIENT_DBG("Broken stream when submitting data")
348 Reply->set_number(0);
349 return Reply;
350 }
351 }
352
353 Writer->WritesDone();
354 RPCStatus = Writer->Finish();
355
356 return Reply;
357 },
358 /* Postprocessor */
359 [&](auto &Reply) {
360 if (!Reply->number()) {
361 CLIENT_DBG(" submitted %ld bytes on device %d at %p", Size, DeviceId,
362 TgtPtr)
363 } else {
364 CLIENT_DBG("Could not async submit %ld bytes on device %d at %p",
365 Size, DeviceId, TgtPtr)
366 }
367 return Reply->number();
368 },
369 /* Error Value */ -1,
370 /* CanTimeOut */ false);
371 }
372
dataRetrieve(int32_t DeviceId,void * HstPtr,void * TgtPtr,int64_t Size)373 int32_t RemoteOffloadClient::dataRetrieve(int32_t DeviceId, void *HstPtr,
374 void *TgtPtr, int64_t Size) {
375 return remoteCall(
376 /* Preprocessor */
377 [&](auto &RPCStatus, auto &Context) {
378 auto *Request =
379 protobuf::Arena::CreateMessage<RetrieveData>(Arena.get());
380
381 Request->set_device_id(DeviceId);
382 Request->set_size(Size);
383 Request->set_hst_ptr((int64_t)HstPtr);
384 Request->set_tgt_ptr((int64_t)TgtPtr);
385
386 auto *Reply = protobuf::Arena::CreateMessage<Data>(Arena.get());
387 std::unique_ptr<ClientReader<Data>> Reader(
388 Stub->DataRetrieve(&Context, *Request));
389 Reader->WaitForInitialMetadata();
390 while (Reader->Read(Reply)) {
391 if (Reply->ret()) {
392 CLIENT_DBG("Could not async retrieve %ld bytes on device %d at %p "
393 "for %p",
394 Size, DeviceId, TgtPtr, HstPtr)
395 return Reply;
396 }
397
398 if (Reply->start() == 0 && Reply->size() == Reply->data().size()) {
399 memcpy(HstPtr, Reply->data().data(), Reply->data().size());
400
401 return Reply;
402 }
403
404 memcpy((void *)((char *)HstPtr + Reply->start()),
405 Reply->data().data(), Reply->data().size());
406 }
407 RPCStatus = Reader->Finish();
408
409 return Reply;
410 },
411 /* Postprocessor */
412 [&](auto &Reply) {
413 if (!Reply->ret()) {
414 CLIENT_DBG("Retrieved %ld bytes on Device %d", Size, DeviceId)
415 } else {
416 CLIENT_DBG("Could not async retrieve %ld bytes on Device %d", Size,
417 DeviceId)
418 }
419 return Reply->ret();
420 },
421 /* Error Value */ -1,
422 /* CanTimeOut */ false);
423 }
424
dataExchange(int32_t SrcDevId,void * SrcPtr,int32_t DstDevId,void * DstPtr,int64_t Size)425 int32_t RemoteOffloadClient::dataExchange(int32_t SrcDevId, void *SrcPtr,
426 int32_t DstDevId, void *DstPtr,
427 int64_t Size) {
428 return remoteCall(
429 /* Preprocessor */
430 [&](auto &RPCStatus, auto &Context) {
431 auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get());
432 auto *Request =
433 protobuf::Arena::CreateMessage<ExchangeData>(Arena.get());
434
435 Request->set_src_dev_id(SrcDevId);
436 Request->set_src_ptr((uint64_t)SrcPtr);
437 Request->set_dst_dev_id(DstDevId);
438 Request->set_dst_ptr((uint64_t)DstPtr);
439 Request->set_size(Size);
440
441 RPCStatus = Stub->DataExchange(&Context, *Request, Reply);
442 return Reply;
443 },
444 /* Postprocessor */
445 [&](auto &Reply) {
446 if (Reply->number()) {
447 CLIENT_DBG(
448 "Exchanged %ld bytes on device %d at %p for %p on device %d",
449 Size, SrcDevId, SrcPtr, DstPtr, DstDevId)
450 } else {
451 CLIENT_DBG("Could not exchange %ld bytes on device %d at %p for %p "
452 "on device %d",
453 Size, SrcDevId, SrcPtr, DstPtr, DstDevId)
454 }
455 return Reply->number();
456 },
457 /* Error Value */ -1);
458 }
459
dataDelete(int32_t DeviceId,void * TgtPtr)460 int32_t RemoteOffloadClient::dataDelete(int32_t DeviceId, void *TgtPtr) {
461 return remoteCall(
462 /* Preprocessor */
463 [&](auto &RPCStatus, auto &Context) {
464 auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get());
465 auto *Request = protobuf::Arena::CreateMessage<DeleteData>(Arena.get());
466
467 Request->set_device_id(DeviceId);
468 Request->set_tgt_ptr((uint64_t)TgtPtr);
469
470 RPCStatus = Stub->DataDelete(&Context, *Request, Reply);
471 return Reply;
472 },
473 /* Postprocessor */
474 [&](auto &Reply) {
475 if (!Reply->number()) {
476 CLIENT_DBG("Deleted data at %p on device %d", TgtPtr, DeviceId)
477 } else {
478 CLIENT_DBG("Could not delete data at %p on device %d", TgtPtr,
479 DeviceId)
480 }
481 return Reply->number();
482 },
483 /* Error Value */ -1);
484 }
485
runTargetRegion(int32_t DeviceId,void * TgtEntryPtr,void ** TgtArgs,ptrdiff_t * TgtOffsets,int32_t ArgNum)486 int32_t RemoteOffloadClient::runTargetRegion(int32_t DeviceId,
487 void *TgtEntryPtr, void **TgtArgs,
488 ptrdiff_t *TgtOffsets,
489 int32_t ArgNum) {
490 return remoteCall(
491 /* Preprocessor */
492 [&](auto &RPCStatus, auto &Context) {
493 auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get());
494 auto *Request =
495 protobuf::Arena::CreateMessage<TargetRegion>(Arena.get());
496
497 Request->set_device_id(DeviceId);
498
499 Request->set_tgt_entry_ptr(
500 (uint64_t)RemoteEntries[DeviceId][TgtEntryPtr]);
501
502 char **ArgPtr = (char **)TgtArgs;
503 for (auto I = 0; I < ArgNum; I++, ArgPtr++)
504 Request->add_tgt_args((uint64_t)*ArgPtr);
505
506 char *OffsetPtr = (char *)TgtOffsets;
507 for (auto I = 0; I < ArgNum; I++, OffsetPtr++)
508 Request->add_tgt_offsets((uint64_t)*OffsetPtr);
509
510 Request->set_arg_num(ArgNum);
511
512 RPCStatus = Stub->RunTargetRegion(&Context, *Request, Reply);
513 return Reply;
514 },
515 /* Postprocessor */
516 [&](auto &Reply) {
517 if (!Reply->number()) {
518 CLIENT_DBG("Ran target region async on device %d", DeviceId)
519 } else {
520 CLIENT_DBG("Could not run target region async on device %d", DeviceId)
521 }
522 return Reply->number();
523 },
524 /* Error Value */ -1,
525 /* CanTimeOut */ false);
526 }
527
runTargetTeamRegion(int32_t DeviceId,void * TgtEntryPtr,void ** TgtArgs,ptrdiff_t * TgtOffsets,int32_t ArgNum,int32_t TeamNum,int32_t ThreadLimit,uint64_t LoopTripcount)528 int32_t RemoteOffloadClient::runTargetTeamRegion(
529 int32_t DeviceId, void *TgtEntryPtr, void **TgtArgs, ptrdiff_t *TgtOffsets,
530 int32_t ArgNum, int32_t TeamNum, int32_t ThreadLimit,
531 uint64_t LoopTripcount) {
532 return remoteCall(
533 /* Preprocessor */
534 [&](auto &RPCStatus, auto &Context) {
535 auto *Reply = protobuf::Arena::CreateMessage<I32>(Arena.get());
536 auto *Request =
537 protobuf::Arena::CreateMessage<TargetTeamRegion>(Arena.get());
538
539 Request->set_device_id(DeviceId);
540
541 Request->set_tgt_entry_ptr(
542 (uint64_t)RemoteEntries[DeviceId][TgtEntryPtr]);
543
544 char **ArgPtr = (char **)TgtArgs;
545 for (auto I = 0; I < ArgNum; I++, ArgPtr++) {
546 Request->add_tgt_args((uint64_t)*ArgPtr);
547 }
548
549 char *OffsetPtr = (char *)TgtOffsets;
550 for (auto I = 0; I < ArgNum; I++, OffsetPtr++)
551 Request->add_tgt_offsets((uint64_t)*OffsetPtr);
552
553 Request->set_arg_num(ArgNum);
554 Request->set_team_num(TeamNum);
555 Request->set_thread_limit(ThreadLimit);
556 Request->set_loop_tripcount(LoopTripcount);
557
558 RPCStatus = Stub->RunTargetTeamRegion(&Context, *Request, Reply);
559 return Reply;
560 },
561 /* Postprocessor */
562 [&](auto &Reply) {
563 if (!Reply->number()) {
564 CLIENT_DBG("Ran target team region async on device %d", DeviceId)
565 } else {
566 CLIENT_DBG("Could not run target team region async on device %d",
567 DeviceId)
568 }
569 return Reply->number();
570 },
571 /* Error Value */ -1,
572 /* CanTimeOut */ false);
573 }
574
shutdown(void)575 int32_t RemoteClientManager::shutdown(void) {
576 int32_t Ret = 0;
577 for (auto &Client : Clients)
578 Ret &= Client.shutdown();
579 return Ret;
580 }
581
registerLib(__tgt_bin_desc * Desc)582 int32_t RemoteClientManager::registerLib(__tgt_bin_desc *Desc) {
583 int32_t Ret = 0;
584 for (auto &Client : Clients)
585 Ret &= Client.registerLib(Desc);
586 return Ret;
587 }
588
unregisterLib(__tgt_bin_desc * Desc)589 int32_t RemoteClientManager::unregisterLib(__tgt_bin_desc *Desc) {
590 int32_t Ret = 0;
591 for (auto &Client : Clients)
592 Ret &= Client.unregisterLib(Desc);
593 return Ret;
594 }
595
isValidBinary(__tgt_device_image * Image)596 int32_t RemoteClientManager::isValidBinary(__tgt_device_image *Image) {
597 int32_t ClientIdx = 0;
598 for (auto &Client : Clients) {
599 if (auto Ret = Client.isValidBinary(Image))
600 return Ret;
601 ClientIdx++;
602 }
603 return 0;
604 }
605
getNumberOfDevices()606 int32_t RemoteClientManager::getNumberOfDevices() {
607 auto ClientIdx = 0;
608 for (auto &Client : Clients) {
609 if (auto NumDevices = Client.getNumberOfDevices()) {
610 Devices.push_back(NumDevices);
611 }
612 ClientIdx++;
613 }
614
615 return std::accumulate(Devices.begin(), Devices.end(), 0);
616 }
617
mapDeviceId(int32_t DeviceId)618 std::pair<int32_t, int32_t> RemoteClientManager::mapDeviceId(int32_t DeviceId) {
619 for (size_t ClientIdx = 0; ClientIdx < Devices.size(); ClientIdx++) {
620 if (DeviceId < Devices[ClientIdx])
621 return {ClientIdx, DeviceId};
622 DeviceId -= Devices[ClientIdx];
623 }
624 return {-1, -1};
625 }
626
initDevice(int32_t DeviceId)627 int32_t RemoteClientManager::initDevice(int32_t DeviceId) {
628 int32_t ClientIdx, DeviceIdx;
629 std::tie(ClientIdx, DeviceIdx) = mapDeviceId(DeviceId);
630 return Clients[ClientIdx].initDevice(DeviceIdx);
631 }
632
initRequires(int64_t RequiresFlags)633 int32_t RemoteClientManager::initRequires(int64_t RequiresFlags) {
634 for (auto &Client : Clients)
635 Client.initRequires(RequiresFlags);
636
637 return RequiresFlags;
638 }
639
loadBinary(int32_t DeviceId,__tgt_device_image * Image)640 __tgt_target_table *RemoteClientManager::loadBinary(int32_t DeviceId,
641 __tgt_device_image *Image) {
642 int32_t ClientIdx, DeviceIdx;
643 std::tie(ClientIdx, DeviceIdx) = mapDeviceId(DeviceId);
644 return Clients[ClientIdx].loadBinary(DeviceIdx, Image);
645 }
646
isDataExchangeable(int32_t SrcDevId,int32_t DstDevId)647 int32_t RemoteClientManager::isDataExchangeable(int32_t SrcDevId,
648 int32_t DstDevId) {
649 int32_t SrcClientIdx, SrcDeviceIdx, DstClientIdx, DstDeviceIdx;
650 std::tie(SrcClientIdx, SrcDeviceIdx) = mapDeviceId(SrcDevId);
651 std::tie(DstClientIdx, DstDeviceIdx) = mapDeviceId(DstDevId);
652 return Clients[SrcClientIdx].isDataExchangeable(SrcDeviceIdx, DstDeviceIdx);
653 }
654
dataAlloc(int32_t DeviceId,int64_t Size,void * HstPtr)655 void *RemoteClientManager::dataAlloc(int32_t DeviceId, int64_t Size,
656 void *HstPtr) {
657 int32_t ClientIdx, DeviceIdx;
658 std::tie(ClientIdx, DeviceIdx) = mapDeviceId(DeviceId);
659 return Clients[ClientIdx].dataAlloc(DeviceIdx, Size, HstPtr);
660 }
661
dataDelete(int32_t DeviceId,void * TgtPtr)662 int32_t RemoteClientManager::dataDelete(int32_t DeviceId, void *TgtPtr) {
663 int32_t ClientIdx, DeviceIdx;
664 std::tie(ClientIdx, DeviceIdx) = mapDeviceId(DeviceId);
665 return Clients[ClientIdx].dataDelete(DeviceIdx, TgtPtr);
666 }
667
dataSubmit(int32_t DeviceId,void * TgtPtr,void * HstPtr,int64_t Size)668 int32_t RemoteClientManager::dataSubmit(int32_t DeviceId, void *TgtPtr,
669 void *HstPtr, int64_t Size) {
670 int32_t ClientIdx, DeviceIdx;
671 std::tie(ClientIdx, DeviceIdx) = mapDeviceId(DeviceId);
672 return Clients[ClientIdx].dataSubmit(DeviceIdx, TgtPtr, HstPtr, Size);
673 }
674
dataRetrieve(int32_t DeviceId,void * HstPtr,void * TgtPtr,int64_t Size)675 int32_t RemoteClientManager::dataRetrieve(int32_t DeviceId, void *HstPtr,
676 void *TgtPtr, int64_t Size) {
677 int32_t ClientIdx, DeviceIdx;
678 std::tie(ClientIdx, DeviceIdx) = mapDeviceId(DeviceId);
679 return Clients[ClientIdx].dataRetrieve(DeviceIdx, HstPtr, TgtPtr, Size);
680 }
681
dataExchange(int32_t SrcDevId,void * SrcPtr,int32_t DstDevId,void * DstPtr,int64_t Size)682 int32_t RemoteClientManager::dataExchange(int32_t SrcDevId, void *SrcPtr,
683 int32_t DstDevId, void *DstPtr,
684 int64_t Size) {
685 int32_t SrcClientIdx, SrcDeviceIdx, DstClientIdx, DstDeviceIdx;
686 std::tie(SrcClientIdx, SrcDeviceIdx) = mapDeviceId(SrcDevId);
687 std::tie(DstClientIdx, DstDeviceIdx) = mapDeviceId(DstDevId);
688 return Clients[SrcClientIdx].dataExchange(SrcDeviceIdx, SrcPtr, DstDeviceIdx,
689 DstPtr, Size);
690 }
691
runTargetRegion(int32_t DeviceId,void * TgtEntryPtr,void ** TgtArgs,ptrdiff_t * TgtOffsets,int32_t ArgNum)692 int32_t RemoteClientManager::runTargetRegion(int32_t DeviceId,
693 void *TgtEntryPtr, void **TgtArgs,
694 ptrdiff_t *TgtOffsets,
695 int32_t ArgNum) {
696 int32_t ClientIdx, DeviceIdx;
697 std::tie(ClientIdx, DeviceIdx) = mapDeviceId(DeviceId);
698 return Clients[ClientIdx].runTargetRegion(DeviceIdx, TgtEntryPtr, TgtArgs,
699 TgtOffsets, ArgNum);
700 }
701
runTargetTeamRegion(int32_t DeviceId,void * TgtEntryPtr,void ** TgtArgs,ptrdiff_t * TgtOffsets,int32_t ArgNum,int32_t TeamNum,int32_t ThreadLimit,uint64_t LoopTripCount)702 int32_t RemoteClientManager::runTargetTeamRegion(
703 int32_t DeviceId, void *TgtEntryPtr, void **TgtArgs, ptrdiff_t *TgtOffsets,
704 int32_t ArgNum, int32_t TeamNum, int32_t ThreadLimit,
705 uint64_t LoopTripCount) {
706 int32_t ClientIdx, DeviceIdx;
707 std::tie(ClientIdx, DeviceIdx) = mapDeviceId(DeviceId);
708 return Clients[ClientIdx].runTargetTeamRegion(DeviceIdx, TgtEntryPtr, TgtArgs,
709 TgtOffsets, ArgNum, TeamNum,
710 ThreadLimit, LoopTripCount);
711 }
712