1 //===------------------ Client.h - Client Implementation ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // gRPC Client for the remote plugin.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef LLVM_OPENMP_LIBOMPTARGET_PLUGINS_REMOTE_SRC_CLIENT_H
14 #define LLVM_OPENMP_LIBOMPTARGET_PLUGINS_REMOTE_SRC_CLIENT_H
15 
#include "Utils.h"
#include "omptarget.h"
#include <google/protobuf/arena.h>
#include <grpcpp/create_channel.h>
#include <grpcpp/grpcpp.h>
#include <grpcpp/security/credentials.h>
#include <grpcpp/support/channel_arguments.h>
#include <map>
#include <memory>
#include <mutex>
#include <numeric>
#include <unordered_map>
#include <utility>
#include <vector>
25 
26 using grpc::Channel;
27 using openmp::libomptarget::remote::RemoteOffload;
28 using namespace RemoteOffloading;
29 
30 using namespace google;
31 
32 class RemoteOffloadClient {
33   int DebugLevel;
34   const int Timeout;
35   const uint64_t MaxSize;
36   const int64_t BlockSize;
37 
38   std::unique_ptr<RemoteOffload::Stub> Stub;
39   std::unique_ptr<protobuf::Arena> Arena;
40 
41   std::unique_ptr<std::mutex> ArenaAllocatorLock;
42 
43   std::map<int32_t, std::unordered_map<void *, void *>> RemoteEntries;
44   std::map<int32_t, std::unique_ptr<__tgt_target_table>> DevicesToTables;
45 
46   template <typename Fn1, typename Fn2, typename TReturn>
47   auto remoteCall(Fn1 Preprocessor, Fn2 Postprocessor, TReturn ErrorValue,
48                   bool CanTimeOut = true);
49 
50 public:
RemoteOffloadClient(std::shared_ptr<Channel> Channel,int Timeout,uint64_t MaxSize,int64_t BlockSize)51   RemoteOffloadClient(std::shared_ptr<Channel> Channel, int Timeout,
52                       uint64_t MaxSize, int64_t BlockSize)
53       : Timeout(Timeout), MaxSize(MaxSize), BlockSize(BlockSize),
54         Stub(RemoteOffload::NewStub(Channel)) {
55     DebugLevel = getDebugLevel();
56     Arena = std::make_unique<protobuf::Arena>();
57     ArenaAllocatorLock = std::make_unique<std::mutex>();
58   }
59 
60   RemoteOffloadClient(RemoteOffloadClient &&C) = default;
61 
~RemoteOffloadClient()62   ~RemoteOffloadClient() {
63     for (auto &TableIt : DevicesToTables)
64       freeTargetTable(TableIt.second.get());
65   }
66 
67   int32_t shutdown(void);
68 
69   int32_t registerLib(__tgt_bin_desc *Desc);
70   int32_t unregisterLib(__tgt_bin_desc *Desc);
71 
72   int32_t isValidBinary(__tgt_device_image *Image);
73   int32_t getNumberOfDevices();
74 
75   int32_t initDevice(int32_t DeviceId);
76   int32_t initRequires(int64_t RequiresFlags);
77 
78   __tgt_target_table *loadBinary(int32_t DeviceId, __tgt_device_image *Image);
79 
80   void *dataAlloc(int32_t DeviceId, int64_t Size, void *HstPtr);
81   int32_t dataDelete(int32_t DeviceId, void *TgtPtr);
82 
83   int32_t dataSubmit(int32_t DeviceId, void *TgtPtr, void *HstPtr,
84                      int64_t Size);
85   int32_t dataRetrieve(int32_t DeviceId, void *HstPtr, void *TgtPtr,
86                        int64_t Size);
87 
88   int32_t isDataExchangeable(int32_t SrcDevId, int32_t DstDevId);
89   int32_t dataExchange(int32_t SrcDevId, void *SrcPtr, int32_t DstDevId,
90                        void *DstPtr, int64_t Size);
91 
92   int32_t runTargetRegion(int32_t DeviceId, void *TgtEntryPtr, void **TgtArgs,
93                           ptrdiff_t *TgtOffsets, int32_t ArgNum);
94   int32_t runTargetTeamRegion(int32_t DeviceId, void *TgtEntryPtr,
95                               void **TgtArgs, ptrdiff_t *TgtOffsets,
96                               int32_t ArgNum, int32_t TeamNum,
97                               int32_t ThreadLimit, uint64_t LoopTripCount);
98 };
99 
100 class RemoteClientManager {
101 private:
102   std::vector<RemoteOffloadClient> Clients;
103   std::vector<int> Devices;
104 
105   std::pair<int32_t, int32_t> mapDeviceId(int32_t DeviceId);
106   int DebugLevel;
107 
108 public:
RemoteClientManager()109   RemoteClientManager() {
110     ClientManagerConfigTy Config;
111 
112     grpc::ChannelArguments ChArgs;
113     ChArgs.SetMaxReceiveMessageSize(-1);
114     DebugLevel = getDebugLevel();
115     for (auto Address : Config.ServerAddresses) {
116       Clients.push_back(RemoteOffloadClient(
117           grpc::CreateChannel(Address, grpc::InsecureChannelCredentials()),
118           Config.Timeout, Config.MaxSize, Config.BlockSize));
119     }
120   }
121 
122   int32_t shutdown(void);
123 
124   int32_t registerLib(__tgt_bin_desc *Desc);
125   int32_t unregisterLib(__tgt_bin_desc *Desc);
126 
127   int32_t isValidBinary(__tgt_device_image *Image);
128   int32_t getNumberOfDevices();
129 
130   int32_t initDevice(int32_t DeviceId);
131   int32_t initRequires(int64_t RequiresFlags);
132 
133   __tgt_target_table *loadBinary(int32_t DeviceId, __tgt_device_image *Image);
134 
135   void *dataAlloc(int32_t DeviceId, int64_t Size, void *HstPtr);
136   int32_t dataDelete(int32_t DeviceId, void *TgtPtr);
137 
138   int32_t dataSubmit(int32_t DeviceId, void *TgtPtr, void *HstPtr,
139                      int64_t Size);
140   int32_t dataRetrieve(int32_t DeviceId, void *HstPtr, void *TgtPtr,
141                        int64_t Size);
142 
143   int32_t isDataExchangeable(int32_t SrcDevId, int32_t DstDevId);
144   int32_t dataExchange(int32_t SrcDevId, void *SrcPtr, int32_t DstDevId,
145                        void *DstPtr, int64_t Size);
146 
147   int32_t runTargetRegion(int32_t DeviceId, void *TgtEntryPtr, void **TgtArgs,
148                           ptrdiff_t *TgtOffsets, int32_t ArgNum);
149   int32_t runTargetTeamRegion(int32_t DeviceId, void *TgtEntryPtr,
150                               void **TgtArgs, ptrdiff_t *TgtOffsets,
151                               int32_t ArgNum, int32_t TeamNum,
152                               int32_t ThreadLimit, uint64_t LoopTripCount);
153 };
154 
155 #endif
156