1 //===----------- MemoryManager.h - Target independent memory manager ------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Target independent memory manager. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef LLVM_OPENMP_LIBOMPTARGET_PLUGINS_COMMON_MEMORYMANAGER_MEMORYMANAGER_H 14 #define LLVM_OPENMP_LIBOMPTARGET_PLUGINS_COMMON_MEMORYMANAGER_MEMORYMANAGER_H 15 16 #include <cassert> 17 #include <functional> 18 #include <list> 19 #include <mutex> 20 #include <set> 21 #include <unordered_map> 22 #include <vector> 23 24 #include "Debug.h" 25 #include "omptargetplugin.h" 26 27 /// Base class of per-device allocator. 28 class DeviceAllocatorTy { 29 public: 30 virtual ~DeviceAllocatorTy() = default; 31 32 /// Allocate a memory of size \p Size . \p HstPtr is used to assist the 33 /// allocation. 34 virtual void *allocate(size_t Size, void *HstPtr, TargetAllocTy Kind) = 0; 35 36 /// Delete the pointer \p TgtPtr on the device 37 virtual int free(void *TgtPtr) = 0; 38 }; 39 40 /// Class of memory manager. The memory manager is per-device by using 41 /// per-device allocator. Therefore, each plugin using memory manager should 42 /// have an allocator for each device. 43 class MemoryManagerTy { 44 static constexpr const size_t BucketSize[] = { 45 0, 1U << 2, 1U << 3, 1U << 4, 1U << 5, 1U << 6, 1U << 7, 46 1U << 8, 1U << 9, 1U << 10, 1U << 11, 1U << 12, 1U << 13}; 47 48 static constexpr const int NumBuckets = 49 sizeof(BucketSize) / sizeof(BucketSize[0]); 50 51 /// Find the previous number that is power of 2 given a number that is not 52 /// power of 2. floorToPowerOfTwo(size_t Num)53 static size_t floorToPowerOfTwo(size_t Num) { 54 Num |= Num >> 1; 55 Num |= Num >> 2; 56 Num |= Num >> 4; 57 Num |= Num >> 8; 58 Num |= Num >> 16; 59 #if INTPTR_MAX == INT64_MAX 60 Num |= Num >> 32; 61 #elif INTPTR_MAX == INT32_MAX 62 // Do nothing with 32-bit 63 #else 64 #error Unsupported architecture 65 #endif 66 Num += 1; 67 return Num >> 1; 68 } 69 70 /// Find a suitable bucket findBucket(size_t Size)71 static int findBucket(size_t Size) { 72 const size_t F = floorToPowerOfTwo(Size); 73 74 DP("findBucket: Size %zu is floored to %zu.\n", Size, F); 75 76 int L = 0, H = NumBuckets - 1; 77 while (H - L > 1) { 78 int M = (L + H) >> 1; 79 if (BucketSize[M] == F) 80 return M; 81 if (BucketSize[M] > F) 82 H = M - 1; 83 else 84 L = M; 85 } 86 87 assert(L >= 0 && L < NumBuckets && "L is out of range"); 88 89 DP("findBucket: Size %zu goes to bucket %d\n", Size, L); 90 91 return L; 92 } 93 94 /// A structure stores the meta data of a target pointer 95 struct NodeTy { 96 /// Memory size 97 const size_t Size; 98 /// Target pointer 99 void *Ptr; 100 101 /// Constructor NodeTyNodeTy102 NodeTy(size_t Size, void *Ptr) : Size(Size), Ptr(Ptr) {} 103 }; 104 105 /// To make \p NodePtrTy ordered when they're put into \p std::multiset. 106 struct NodeCmpTy { operatorNodeCmpTy107 bool operator()(const NodeTy &LHS, const NodeTy &RHS) const { 108 return LHS.Size < RHS.Size; 109 } 110 }; 111 112 /// A \p FreeList is a set of Nodes. We're using \p std::multiset here to make 113 /// the look up procedure more efficient. 114 using FreeListTy = std::multiset<std::reference_wrapper<NodeTy>, NodeCmpTy>; 115 116 /// A list of \p FreeListTy entries, each of which is a \p std::multiset of 117 /// Nodes whose size is less or equal to a specific bucket size. 118 std::vector<FreeListTy> FreeLists; 119 /// A list of mutex for each \p FreeListTy entry 120 std::vector<std::mutex> FreeListLocks; 121 /// A table to map from a target pointer to its node 122 std::unordered_map<void *, NodeTy> PtrToNodeTable; 123 /// The mutex for the table \p PtrToNodeTable 124 std::mutex MapTableLock; 125 126 /// The reference to a device allocator 127 DeviceAllocatorTy &DeviceAllocator; 128 129 /// The threshold to manage memory using memory manager. If the request size 130 /// is larger than \p SizeThreshold, the allocation will not be managed by the 131 /// memory manager. 132 size_t SizeThreshold = 1U << 13; 133 134 /// Request memory from target device allocateOnDevice(size_t Size,void * HstPtr)135 void *allocateOnDevice(size_t Size, void *HstPtr) const { 136 return DeviceAllocator.allocate(Size, HstPtr, TARGET_ALLOC_DEVICE); 137 } 138 139 /// Deallocate data on device deleteOnDevice(void * Ptr)140 int deleteOnDevice(void *Ptr) const { return DeviceAllocator.free(Ptr); } 141 142 /// This function is called when it tries to allocate memory on device but the 143 /// device returns out of memory. It will first free all memory in the 144 /// FreeList and try to allocate again. freeAndAllocate(size_t Size,void * HstPtr)145 void *freeAndAllocate(size_t Size, void *HstPtr) { 146 std::vector<void *> RemoveList; 147 148 // Deallocate all memory in FreeList 149 for (int I = 0; I < NumBuckets; ++I) { 150 FreeListTy &List = FreeLists[I]; 151 std::lock_guard<std::mutex> Lock(FreeListLocks[I]); 152 if (List.empty()) 153 continue; 154 for (const NodeTy &N : List) { 155 deleteOnDevice(N.Ptr); 156 RemoveList.push_back(N.Ptr); 157 } 158 FreeLists[I].clear(); 159 } 160 161 // Remove all nodes in the map table which have been released 162 if (!RemoveList.empty()) { 163 std::lock_guard<std::mutex> LG(MapTableLock); 164 for (void *P : RemoveList) 165 PtrToNodeTable.erase(P); 166 } 167 168 // Try allocate memory again 169 return allocateOnDevice(Size, HstPtr); 170 } 171 172 /// The goal is to allocate memory on the device. It first tries to 173 /// allocate directly on the device. If a \p nullptr is returned, it might 174 /// be because the device is OOM. In that case, it will free all unused 175 /// memory and then try again. allocateOrFreeAndAllocateOnDevice(size_t Size,void * HstPtr)176 void *allocateOrFreeAndAllocateOnDevice(size_t Size, void *HstPtr) { 177 void *TgtPtr = allocateOnDevice(Size, HstPtr); 178 // We cannot get memory from the device. It might be due to OOM. Let's 179 // free all memory in FreeLists and try again. 180 if (TgtPtr == nullptr) { 181 DP("Failed to get memory on device. Free all memory in FreeLists and " 182 "try again.\n"); 183 TgtPtr = freeAndAllocate(Size, HstPtr); 184 } 185 186 if (TgtPtr == nullptr) 187 DP("Still cannot get memory on device probably because the device is " 188 "OOM.\n"); 189 190 return TgtPtr; 191 } 192 193 public: 194 /// Constructor. If \p Threshold is non-zero, then the default threshold will 195 /// be overwritten by \p Threshold. 196 MemoryManagerTy(DeviceAllocatorTy &DeviceAllocator, size_t Threshold = 0) FreeLists(NumBuckets)197 : FreeLists(NumBuckets), FreeListLocks(NumBuckets), 198 DeviceAllocator(DeviceAllocator) { 199 if (Threshold) 200 SizeThreshold = Threshold; 201 } 202 203 /// Destructor ~MemoryManagerTy()204 ~MemoryManagerTy() { 205 for (auto Itr = PtrToNodeTable.begin(); Itr != PtrToNodeTable.end(); 206 ++Itr) { 207 assert(Itr->second.Ptr && "nullptr in map table"); 208 deleteOnDevice(Itr->second.Ptr); 209 } 210 } 211 212 /// Allocate memory of size \p Size from target device. \p HstPtr is used to 213 /// assist the allocation. allocate(size_t Size,void * HstPtr)214 void *allocate(size_t Size, void *HstPtr) { 215 // If the size is zero, we will not bother the target device. Just return 216 // nullptr directly. 217 if (Size == 0) 218 return nullptr; 219 220 DP("MemoryManagerTy::allocate: size %zu with host pointer " DPxMOD ".\n", 221 Size, DPxPTR(HstPtr)); 222 223 // If the size is greater than the threshold, allocate it directly from 224 // device. 225 if (Size > SizeThreshold) { 226 DP("%zu is greater than the threshold %zu. Allocate it directly from " 227 "device\n", 228 Size, SizeThreshold); 229 void *TgtPtr = allocateOrFreeAndAllocateOnDevice(Size, HstPtr); 230 231 DP("Got target pointer " DPxMOD ". Return directly.\n", DPxPTR(TgtPtr)); 232 233 return TgtPtr; 234 } 235 236 NodeTy *NodePtr = nullptr; 237 238 // Try to get a node from FreeList 239 { 240 const int B = findBucket(Size); 241 FreeListTy &List = FreeLists[B]; 242 243 NodeTy TempNode(Size, nullptr); 244 std::lock_guard<std::mutex> LG(FreeListLocks[B]); 245 const auto Itr = List.find(TempNode); 246 247 if (Itr != List.end()) { 248 NodePtr = &Itr->get(); 249 List.erase(Itr); 250 } 251 } 252 253 if (NodePtr != nullptr) 254 DP("Find one node " DPxMOD " in the bucket.\n", DPxPTR(NodePtr)); 255 256 // We cannot find a valid node in FreeLists. Let's allocate on device and 257 // create a node for it. 258 if (NodePtr == nullptr) { 259 DP("Cannot find a node in the FreeLists. Allocate on device.\n"); 260 // Allocate one on device 261 void *TgtPtr = allocateOrFreeAndAllocateOnDevice(Size, HstPtr); 262 263 if (TgtPtr == nullptr) 264 return nullptr; 265 266 // Create a new node and add it into the map table 267 { 268 std::lock_guard<std::mutex> Guard(MapTableLock); 269 auto Itr = PtrToNodeTable.emplace(TgtPtr, NodeTy(Size, TgtPtr)); 270 NodePtr = &Itr.first->second; 271 } 272 273 DP("Node address " DPxMOD ", target pointer " DPxMOD ", size %zu\n", 274 DPxPTR(NodePtr), DPxPTR(TgtPtr), Size); 275 } 276 277 assert(NodePtr && "NodePtr should not be nullptr at this point"); 278 279 return NodePtr->Ptr; 280 } 281 282 /// Deallocate memory pointed by \p TgtPtr free(void * TgtPtr)283 int free(void *TgtPtr) { 284 DP("MemoryManagerTy::free: target memory " DPxMOD ".\n", DPxPTR(TgtPtr)); 285 286 NodeTy *P = nullptr; 287 288 // Look it up into the table 289 { 290 std::lock_guard<std::mutex> G(MapTableLock); 291 auto Itr = PtrToNodeTable.find(TgtPtr); 292 293 // We don't remove the node from the map table because the map does not 294 // change. 295 if (Itr != PtrToNodeTable.end()) 296 P = &Itr->second; 297 } 298 299 // The memory is not managed by the manager 300 if (P == nullptr) { 301 DP("Cannot find its node. Delete it on device directly.\n"); 302 return deleteOnDevice(TgtPtr); 303 } 304 305 // Insert the node to the free list 306 const int B = findBucket(P->Size); 307 308 DP("Found its node " DPxMOD ". Insert it to bucket %d.\n", DPxPTR(P), B); 309 310 { 311 std::lock_guard<std::mutex> G(FreeListLocks[B]); 312 FreeLists[B].insert(*P); 313 } 314 315 return OFFLOAD_SUCCESS; 316 } 317 318 /// Get the size threshold from the environment variable 319 /// \p LIBOMPTARGET_MEMORY_MANAGER_THRESHOLD . Returns a <tt> 320 /// std::pair<size_t, bool> </tt> where the first element represents the 321 /// threshold and the second element represents whether user disables memory 322 /// manager explicitly by setting the var to 0. If user doesn't specify 323 /// anything, returns <0, true>. getSizeThresholdFromEnv()324 static std::pair<size_t, bool> getSizeThresholdFromEnv() { 325 size_t Threshold = 0; 326 327 if (const char *Env = 328 std::getenv("LIBOMPTARGET_MEMORY_MANAGER_THRESHOLD")) { 329 Threshold = std::stoul(Env); 330 if (Threshold == 0) { 331 DP("Disabled memory manager as user set " 332 "LIBOMPTARGET_MEMORY_MANAGER_THRESHOLD=0.\n"); 333 return std::make_pair(0, false); 334 } 335 } 336 337 return std::make_pair(Threshold, true); 338 } 339 }; 340 341 // GCC still cannot handle the static data member like Clang so we still need 342 // this part. 343 constexpr const size_t MemoryManagerTy::BucketSize[]; 344 constexpr const int MemoryManagerTy::NumBuckets; 345 346 #endif // LLVM_OPENMP_LIBOMPTARGET_PLUGINS_COMMON_MEMORYMANAGER_MEMORYMANAGER_H 347