1 //===----------------- Server.cpp - Server Implementation -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Offloading gRPC server for remote host.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include <cmath>
14 #include <future>
15 
16 #include "Server.h"
17 #include "omptarget.h"
18 #include "openmp.grpc.pb.h"
19 #include "openmp.pb.h"
20 
21 using grpc::WriteOptions;
22 
23 extern std::promise<void> ShutdownPromise;
24 
Shutdown(ServerContext * Context,const Null * Request,I32 * Reply)25 Status RemoteOffloadImpl::Shutdown(ServerContext *Context, const Null *Request,
26                                    I32 *Reply) {
27   SERVER_DBG("Shutting down the server")
28 
29   Reply->set_number(0);
30   ShutdownPromise.set_value();
31   return Status::OK;
32 }
33 
34 Status
RegisterLib(ServerContext * Context,const TargetBinaryDescription * Description,I32 * Reply)35 RemoteOffloadImpl::RegisterLib(ServerContext *Context,
36                                const TargetBinaryDescription *Description,
37                                I32 *Reply) {
38   auto Desc = std::make_unique<__tgt_bin_desc>();
39 
40   unloadTargetBinaryDescription(Description, Desc.get(),
41                                 HostToRemoteDeviceImage);
42   PM->RTLs.RegisterLib(Desc.get());
43 
44   if (Descriptions.find((void *)Description->bin_ptr()) != Descriptions.end())
45     freeTargetBinaryDescription(
46         Descriptions[(void *)Description->bin_ptr()].get());
47   else
48     Descriptions[(void *)Description->bin_ptr()] = std::move(Desc);
49 
50   SERVER_DBG("Registered library")
51   Reply->set_number(0);
52   return Status::OK;
53 }
54 
UnregisterLib(ServerContext * Context,const Pointer * Request,I32 * Reply)55 Status RemoteOffloadImpl::UnregisterLib(ServerContext *Context,
56                                         const Pointer *Request, I32 *Reply) {
57   if (Descriptions.find((void *)Request->number()) == Descriptions.end()) {
58     Reply->set_number(1);
59     return Status::OK;
60   }
61 
62   PM->RTLs.UnregisterLib(Descriptions[(void *)Request->number()].get());
63   freeTargetBinaryDescription(Descriptions[(void *)Request->number()].get());
64   Descriptions.erase((void *)Request->number());
65 
66   SERVER_DBG("Unregistered library")
67   Reply->set_number(0);
68   return Status::OK;
69 }
70 
IsValidBinary(ServerContext * Context,const TargetDeviceImagePtr * DeviceImage,I32 * IsValid)71 Status RemoteOffloadImpl::IsValidBinary(ServerContext *Context,
72                                         const TargetDeviceImagePtr *DeviceImage,
73                                         I32 *IsValid) {
74   __tgt_device_image *Image =
75       HostToRemoteDeviceImage[(void *)DeviceImage->image_ptr()];
76 
77   IsValid->set_number(0);
78 
79   for (auto &RTL : PM->RTLs.AllRTLs)
80     if (auto Ret = RTL.is_valid_binary(Image)) {
81       IsValid->set_number(Ret);
82       break;
83     }
84 
85   SERVER_DBG("Checked if binary (%p) is valid",
86              (void *)(DeviceImage->image_ptr()))
87   return Status::OK;
88 }
89 
GetNumberOfDevices(ServerContext * Context,const Null * Null,I32 * NumberOfDevices)90 Status RemoteOffloadImpl::GetNumberOfDevices(ServerContext *Context,
91                                              const Null *Null,
92                                              I32 *NumberOfDevices) {
93   std::call_once(PM->RTLs.initFlag, &RTLsTy::LoadRTLs, &PM->RTLs);
94 
95   int32_t Devices = 0;
96   PM->RTLsMtx.lock();
97   for (auto &RTL : PM->RTLs.AllRTLs)
98     Devices += RTL.NumberOfDevices;
99   PM->RTLsMtx.unlock();
100 
101   NumberOfDevices->set_number(Devices);
102 
103   SERVER_DBG("Got number of devices")
104   return Status::OK;
105 }
106 
InitDevice(ServerContext * Context,const I32 * DeviceNum,I32 * Reply)107 Status RemoteOffloadImpl::InitDevice(ServerContext *Context,
108                                      const I32 *DeviceNum, I32 *Reply) {
109   Reply->set_number(PM->Devices[DeviceNum->number()].RTL->init_device(
110       mapHostRTLDeviceId(DeviceNum->number())));
111 
112   SERVER_DBG("Initialized device %d", DeviceNum->number())
113   return Status::OK;
114 }
115 
InitRequires(ServerContext * Context,const I64 * RequiresFlag,I32 * Reply)116 Status RemoteOffloadImpl::InitRequires(ServerContext *Context,
117                                        const I64 *RequiresFlag, I32 *Reply) {
118   for (auto &Device : PM->Devices)
119     if (Device.RTL->init_requires)
120       Device.RTL->init_requires(RequiresFlag->number());
121   Reply->set_number(RequiresFlag->number());
122 
123   SERVER_DBG("Initialized requires for devices")
124   return Status::OK;
125 }
126 
LoadBinary(ServerContext * Context,const Binary * Binary,TargetTable * Reply)127 Status RemoteOffloadImpl::LoadBinary(ServerContext *Context,
128                                      const Binary *Binary, TargetTable *Reply) {
129   __tgt_device_image *Image =
130       HostToRemoteDeviceImage[(void *)Binary->image_ptr()];
131 
132   Table = PM->Devices[Binary->device_id()].RTL->load_binary(
133       mapHostRTLDeviceId(Binary->device_id()), Image);
134   if (Table)
135     loadTargetTable(Table, *Reply, Image);
136 
137   SERVER_DBG("Loaded binary (%p) to device %d", (void *)Binary->image_ptr(),
138              Binary->device_id())
139   return Status::OK;
140 }
141 
IsDataExchangeable(ServerContext * Context,const DevicePair * Request,I32 * Reply)142 Status RemoteOffloadImpl::IsDataExchangeable(ServerContext *Context,
143                                              const DevicePair *Request,
144                                              I32 *Reply) {
145   Reply->set_number(-1);
146   if (PM->Devices[mapHostRTLDeviceId(Request->src_dev_id())]
147           .RTL->is_data_exchangable)
148     Reply->set_number(PM->Devices[mapHostRTLDeviceId(Request->src_dev_id())]
149                           .RTL->is_data_exchangable(Request->src_dev_id(),
150                                                     Request->dst_dev_id()));
151 
152   SERVER_DBG("Checked if data exchangeable between device %d and device %d",
153              Request->src_dev_id(), Request->dst_dev_id())
154   return Status::OK;
155 }
156 
DataAlloc(ServerContext * Context,const AllocData * Request,Pointer * Reply)157 Status RemoteOffloadImpl::DataAlloc(ServerContext *Context,
158                                     const AllocData *Request, Pointer *Reply) {
159   uint64_t TgtPtr = (uint64_t)PM->Devices[Request->device_id()].RTL->data_alloc(
160       mapHostRTLDeviceId(Request->device_id()), Request->size(),
161       (void *)Request->hst_ptr(), TARGET_ALLOC_DEFAULT);
162   Reply->set_number(TgtPtr);
163 
164   SERVER_DBG("Allocated at " DPxMOD "", DPxPTR((void *)TgtPtr))
165 
166   return Status::OK;
167 }
168 
DataSubmit(ServerContext * Context,ServerReader<SubmitData> * Reader,I32 * Reply)169 Status RemoteOffloadImpl::DataSubmit(ServerContext *Context,
170                                      ServerReader<SubmitData> *Reader,
171                                      I32 *Reply) {
172   SubmitData Request;
173   uint8_t *HostCopy = nullptr;
174   while (Reader->Read(&Request)) {
175     if (Request.start() == 0 && Request.size() == Request.data().size()) {
176       Reader->SendInitialMetadata();
177 
178       Reply->set_number(PM->Devices[Request.device_id()].RTL->data_submit(
179           mapHostRTLDeviceId(Request.device_id()), (void *)Request.tgt_ptr(),
180           (void *)Request.data().data(), Request.data().size()));
181 
182       SERVER_DBG("Submitted %lu bytes async to (%p) on device %d",
183                  Request.data().size(), (void *)Request.tgt_ptr(),
184                  Request.device_id())
185 
186       return Status::OK;
187     }
188     if (!HostCopy) {
189       HostCopy = new uint8_t[Request.size()];
190       Reader->SendInitialMetadata();
191     }
192 
193     memcpy((void *)((char *)HostCopy + Request.start()), Request.data().data(),
194            Request.data().size());
195   }
196 
197   Reply->set_number(PM->Devices[Request.device_id()].RTL->data_submit(
198       mapHostRTLDeviceId(Request.device_id()), (void *)Request.tgt_ptr(),
199       HostCopy, Request.size()));
200 
201   delete[] HostCopy;
202 
203   SERVER_DBG("Submitted %lu bytes to (%p) on device %d", Request.data().size(),
204              (void *)Request.tgt_ptr(), Request.device_id())
205 
206   return Status::OK;
207 }
208 
DataRetrieve(ServerContext * Context,const RetrieveData * Request,ServerWriter<Data> * Writer)209 Status RemoteOffloadImpl::DataRetrieve(ServerContext *Context,
210                                        const RetrieveData *Request,
211                                        ServerWriter<Data> *Writer) {
212   auto HstPtr = std::make_unique<char[]>(Request->size());
213 
214   auto Ret = PM->Devices[Request->device_id()].RTL->data_retrieve(
215       mapHostRTLDeviceId(Request->device_id()), HstPtr.get(),
216       (void *)Request->tgt_ptr(), Request->size());
217 
218   if (Arena->SpaceAllocated() >= MaxSize)
219     Arena->Reset();
220 
221   if (Request->size() > BlockSize) {
222     uint64_t Start = 0, End = BlockSize;
223     for (auto I = 0; I < ceil((float)Request->size() / BlockSize); I++) {
224       auto *Reply = protobuf::Arena::CreateMessage<Data>(Arena.get());
225 
226       Reply->set_start(Start);
227       Reply->set_size(Request->size());
228       Reply->set_data((char *)HstPtr.get() + Start, End - Start);
229       Reply->set_ret(Ret);
230 
231       if (!Writer->Write(*Reply)) {
232         CLIENT_DBG("Broken stream when submitting data")
233       }
234 
235       SERVER_DBG("Retrieved %lu-%lu/%lu bytes from (%p) on device %d", Start,
236                  End, Request->size(), (void *)Request->tgt_ptr(),
237                  mapHostRTLDeviceId(Request->device_id()))
238 
239       Start += BlockSize;
240       End += BlockSize;
241       if (End >= Request->size())
242         End = Request->size();
243     }
244   } else {
245     auto *Reply = protobuf::Arena::CreateMessage<Data>(Arena.get());
246 
247     Reply->set_start(0);
248     Reply->set_size(Request->size());
249     Reply->set_data((char *)HstPtr.get(), Request->size());
250     Reply->set_ret(Ret);
251 
252     SERVER_DBG("Retrieved %lu bytes from (%p) on device %d", Request->size(),
253                (void *)Request->tgt_ptr(),
254                mapHostRTLDeviceId(Request->device_id()))
255 
256     Writer->WriteLast(*Reply, WriteOptions());
257   }
258 
259   return Status::OK;
260 }
261 
DataExchange(ServerContext * Context,const ExchangeData * Request,I32 * Reply)262 Status RemoteOffloadImpl::DataExchange(ServerContext *Context,
263                                        const ExchangeData *Request,
264                                        I32 *Reply) {
265   if (PM->Devices[Request->src_dev_id()].RTL->data_exchange) {
266     int32_t Ret = PM->Devices[Request->src_dev_id()].RTL->data_exchange(
267         mapHostRTLDeviceId(Request->src_dev_id()), (void *)Request->src_ptr(),
268         mapHostRTLDeviceId(Request->dst_dev_id()), (void *)Request->dst_ptr(),
269         Request->size());
270     Reply->set_number(Ret);
271   } else
272     Reply->set_number(-1);
273 
274   SERVER_DBG(
275       "Exchanged data asynchronously from device %d (%p) to device %d (%p) of "
276       "size %lu",
277       mapHostRTLDeviceId(Request->src_dev_id()), (void *)Request->src_ptr(),
278       mapHostRTLDeviceId(Request->dst_dev_id()), (void *)Request->dst_ptr(),
279       Request->size())
280   return Status::OK;
281 }
282 
DataDelete(ServerContext * Context,const DeleteData * Request,I32 * Reply)283 Status RemoteOffloadImpl::DataDelete(ServerContext *Context,
284                                      const DeleteData *Request, I32 *Reply) {
285   auto Ret = PM->Devices[Request->device_id()].RTL->data_delete(
286       mapHostRTLDeviceId(Request->device_id()), (void *)Request->tgt_ptr());
287   Reply->set_number(Ret);
288 
289   SERVER_DBG("Deleted data from (%p) on device %d", (void *)Request->tgt_ptr(),
290              mapHostRTLDeviceId(Request->device_id()))
291   return Status::OK;
292 }
293 
RunTargetRegion(ServerContext * Context,const TargetRegion * Request,I32 * Reply)294 Status RemoteOffloadImpl::RunTargetRegion(ServerContext *Context,
295                                           const TargetRegion *Request,
296                                           I32 *Reply) {
297   std::vector<uint8_t> TgtArgs(Request->arg_num());
298   for (auto I = 0; I < Request->arg_num(); I++)
299     TgtArgs[I] = (uint64_t)Request->tgt_args()[I];
300 
301   std::vector<ptrdiff_t> TgtOffsets(Request->arg_num());
302   const auto *TgtOffsetItr = Request->tgt_offsets().begin();
303   for (auto I = 0; I < Request->arg_num(); I++, TgtOffsetItr++)
304     TgtOffsets[I] = (ptrdiff_t)*TgtOffsetItr;
305 
306   void *TgtEntryPtr = ((__tgt_offload_entry *)Request->tgt_entry_ptr())->addr;
307 
308   int32_t Ret = PM->Devices[Request->device_id()].RTL->run_region(
309       mapHostRTLDeviceId(Request->device_id()), TgtEntryPtr,
310       (void **)TgtArgs.data(), TgtOffsets.data(), Request->arg_num());
311 
312   Reply->set_number(Ret);
313 
314   SERVER_DBG("Ran TargetRegion on device %d with %d args",
315              mapHostRTLDeviceId(Request->device_id()), Request->arg_num())
316   return Status::OK;
317 }
318 
RunTargetTeamRegion(ServerContext * Context,const TargetTeamRegion * Request,I32 * Reply)319 Status RemoteOffloadImpl::RunTargetTeamRegion(ServerContext *Context,
320                                               const TargetTeamRegion *Request,
321                                               I32 *Reply) {
322   std::vector<uint64_t> TgtArgs(Request->arg_num());
323   for (auto I = 0; I < Request->arg_num(); I++)
324     TgtArgs[I] = (uint64_t)Request->tgt_args()[I];
325 
326   std::vector<ptrdiff_t> TgtOffsets(Request->arg_num());
327   const auto *TgtOffsetItr = Request->tgt_offsets().begin();
328   for (auto I = 0; I < Request->arg_num(); I++, TgtOffsetItr++)
329     TgtOffsets[I] = (ptrdiff_t)*TgtOffsetItr;
330 
331   void *TgtEntryPtr = ((__tgt_offload_entry *)Request->tgt_entry_ptr())->addr;
332 
333   int32_t Ret = PM->Devices[Request->device_id()].RTL->run_team_region(
334       mapHostRTLDeviceId(Request->device_id()), TgtEntryPtr,
335       (void **)TgtArgs.data(), TgtOffsets.data(), Request->arg_num(),
336       Request->team_num(), Request->thread_limit(), Request->loop_tripcount());
337 
338   Reply->set_number(Ret);
339 
340   SERVER_DBG("Ran TargetTeamRegion on device %d with %d args",
341              mapHostRTLDeviceId(Request->device_id()), Request->arg_num())
342   return Status::OK;
343 }
344 
mapHostRTLDeviceId(int32_t RTLDeviceID)345 int32_t RemoteOffloadImpl::mapHostRTLDeviceId(int32_t RTLDeviceID) {
346   for (auto &RTL : PM->RTLs.UsedRTLs) {
347     if (RTLDeviceID - RTL->NumberOfDevices >= 0)
348       RTLDeviceID -= RTL->NumberOfDevices;
349     else
350       break;
351   }
352   return RTLDeviceID;
353 }
354