1 //===------------------ Client.h - Client Implementation ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // gRPC Client for the remote plugin. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef LLVM_OPENMP_LIBOMPTARGET_PLUGINS_REMOTE_SRC_CLIENT_H 14 #define LLVM_OPENMP_LIBOMPTARGET_PLUGINS_REMOTE_SRC_CLIENT_H 15 16 #include "Utils.h" 17 #include "omptarget.h" 18 #include <google/protobuf/arena.h> 19 #include <grpcpp/grpcpp.h> 20 #include <grpcpp/security/credentials.h> 21 #include <grpcpp/support/channel_arguments.h> 22 #include <memory> 23 #include <mutex> 24 #include <numeric> 25 26 using grpc::Channel; 27 using openmp::libomptarget::remote::RemoteOffload; 28 using namespace RemoteOffloading; 29 30 using namespace google; 31 32 class RemoteOffloadClient { 33 int DebugLevel; 34 const int Timeout; 35 const uint64_t MaxSize; 36 const int64_t BlockSize; 37 38 std::unique_ptr<RemoteOffload::Stub> Stub; 39 std::unique_ptr<protobuf::Arena> Arena; 40 41 std::unique_ptr<std::mutex> ArenaAllocatorLock; 42 43 std::map<int32_t, std::unordered_map<void *, void *>> RemoteEntries; 44 std::map<int32_t, std::unique_ptr<__tgt_target_table>> DevicesToTables; 45 46 template <typename Fn1, typename Fn2, typename TReturn> 47 auto remoteCall(Fn1 Preprocessor, Fn2 Postprocessor, TReturn ErrorValue, 48 bool CanTimeOut = true); 49 50 public: RemoteOffloadClient(std::shared_ptr<Channel> Channel,int Timeout,uint64_t MaxSize,int64_t BlockSize)51 RemoteOffloadClient(std::shared_ptr<Channel> Channel, int Timeout, 52 uint64_t MaxSize, int64_t BlockSize) 53 : Timeout(Timeout), MaxSize(MaxSize), BlockSize(BlockSize), 54 Stub(RemoteOffload::NewStub(Channel)) { 55 DebugLevel = getDebugLevel(); 56 Arena = std::make_unique<protobuf::Arena>(); 57 ArenaAllocatorLock = std::make_unique<std::mutex>(); 58 } 59 60 RemoteOffloadClient(RemoteOffloadClient &&C) = default; 61 ~RemoteOffloadClient()62 ~RemoteOffloadClient() { 63 for (auto &TableIt : DevicesToTables) 64 freeTargetTable(TableIt.second.get()); 65 } 66 67 int32_t shutdown(void); 68 69 int32_t registerLib(__tgt_bin_desc *Desc); 70 int32_t unregisterLib(__tgt_bin_desc *Desc); 71 72 int32_t isValidBinary(__tgt_device_image *Image); 73 int32_t getNumberOfDevices(); 74 75 int32_t initDevice(int32_t DeviceId); 76 int32_t initRequires(int64_t RequiresFlags); 77 78 __tgt_target_table *loadBinary(int32_t DeviceId, __tgt_device_image *Image); 79 80 void *dataAlloc(int32_t DeviceId, int64_t Size, void *HstPtr); 81 int32_t dataDelete(int32_t DeviceId, void *TgtPtr); 82 83 int32_t dataSubmit(int32_t DeviceId, void *TgtPtr, void *HstPtr, 84 int64_t Size); 85 int32_t dataRetrieve(int32_t DeviceId, void *HstPtr, void *TgtPtr, 86 int64_t Size); 87 88 int32_t isDataExchangeable(int32_t SrcDevId, int32_t DstDevId); 89 int32_t dataExchange(int32_t SrcDevId, void *SrcPtr, int32_t DstDevId, 90 void *DstPtr, int64_t Size); 91 92 int32_t runTargetRegion(int32_t DeviceId, void *TgtEntryPtr, void **TgtArgs, 93 ptrdiff_t *TgtOffsets, int32_t ArgNum); 94 int32_t runTargetTeamRegion(int32_t DeviceId, void *TgtEntryPtr, 95 void **TgtArgs, ptrdiff_t *TgtOffsets, 96 int32_t ArgNum, int32_t TeamNum, 97 int32_t ThreadLimit, uint64_t LoopTripCount); 98 }; 99 100 class RemoteClientManager { 101 private: 102 std::vector<RemoteOffloadClient> Clients; 103 std::vector<int> Devices; 104 105 std::pair<int32_t, int32_t> mapDeviceId(int32_t DeviceId); 106 int DebugLevel; 107 108 public: RemoteClientManager()109 RemoteClientManager() { 110 ClientManagerConfigTy Config; 111 112 grpc::ChannelArguments ChArgs; 113 ChArgs.SetMaxReceiveMessageSize(-1); 114 DebugLevel = getDebugLevel(); 115 for (auto Address : Config.ServerAddresses) { 116 Clients.push_back(RemoteOffloadClient( 117 grpc::CreateChannel(Address, grpc::InsecureChannelCredentials()), 118 Config.Timeout, Config.MaxSize, Config.BlockSize)); 119 } 120 } 121 122 int32_t shutdown(void); 123 124 int32_t registerLib(__tgt_bin_desc *Desc); 125 int32_t unregisterLib(__tgt_bin_desc *Desc); 126 127 int32_t isValidBinary(__tgt_device_image *Image); 128 int32_t getNumberOfDevices(); 129 130 int32_t initDevice(int32_t DeviceId); 131 int32_t initRequires(int64_t RequiresFlags); 132 133 __tgt_target_table *loadBinary(int32_t DeviceId, __tgt_device_image *Image); 134 135 void *dataAlloc(int32_t DeviceId, int64_t Size, void *HstPtr); 136 int32_t dataDelete(int32_t DeviceId, void *TgtPtr); 137 138 int32_t dataSubmit(int32_t DeviceId, void *TgtPtr, void *HstPtr, 139 int64_t Size); 140 int32_t dataRetrieve(int32_t DeviceId, void *HstPtr, void *TgtPtr, 141 int64_t Size); 142 143 int32_t isDataExchangeable(int32_t SrcDevId, int32_t DstDevId); 144 int32_t dataExchange(int32_t SrcDevId, void *SrcPtr, int32_t DstDevId, 145 void *DstPtr, int64_t Size); 146 147 int32_t runTargetRegion(int32_t DeviceId, void *TgtEntryPtr, void **TgtArgs, 148 ptrdiff_t *TgtOffsets, int32_t ArgNum); 149 int32_t runTargetTeamRegion(int32_t DeviceId, void *TgtEntryPtr, 150 void **TgtArgs, ptrdiff_t *TgtOffsets, 151 int32_t ArgNum, int32_t TeamNum, 152 int32_t ThreadLimit, uint64_t LoopTripCount); 153 }; 154 155 #endif 156