1 //===----------- MemoryManager.h - Target independent memory manager ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Target independent memory manager.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef LLVM_OPENMP_LIBOMPTARGET_PLUGINS_COMMON_MEMORYMANAGER_MEMORYMANAGER_H
14 #define LLVM_OPENMP_LIBOMPTARGET_PLUGINS_COMMON_MEMORYMANAGER_MEMORYMANAGER_H
15 
#include <algorithm>
#include <cassert>
#include <cerrno>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <list>
#include <mutex>
#include <set>
#include <unordered_map>
#include <utility>
#include <vector>
23 
24 #include "Debug.h"
25 #include "omptargetplugin.h"
26 
27 /// Base class of per-device allocator.
/// Base class of per-device allocator.
///
/// This is a pure interface: both \p allocate and \p free are pure virtual, so
/// a concrete allocator has to be supplied by whoever uses it. In this file
/// \p MemoryManagerTy holds a reference to one and forwards all of its device
/// requests to it.
class DeviceAllocatorTy {
public:
  /// Virtual destructor so that a derived allocator can be destroyed through a
  /// pointer or reference to this base class.
  virtual ~DeviceAllocatorTy() = default;

  /// Allocate a memory of size \p Size . \p HstPtr is used to assist the
  /// allocation. \p Kind selects the kind of target memory requested.
  ///
  /// \returns the target pointer. \p MemoryManagerTy in this file treats a
  /// \p nullptr result as an allocation failure (possibly out of memory).
  virtual void *allocate(size_t Size, void *HstPtr, TargetAllocTy Kind) = 0;

  /// Delete the pointer \p TgtPtr on the device.
  ///
  /// \returns an integer status code; \p MemoryManagerTy::free forwards it
  /// unchanged to its caller for pointers it does not manage.
  virtual int free(void *TgtPtr) = 0;
};
39 
40 /// Class of memory manager. The memory manager is per-device by using
41 /// per-device allocator. Therefore, each plugin using memory manager should
42 /// have an allocator for each device.
43 class MemoryManagerTy {
44   static constexpr const size_t BucketSize[] = {
45       0,       1U << 2, 1U << 3,  1U << 4,  1U << 5,  1U << 6, 1U << 7,
46       1U << 8, 1U << 9, 1U << 10, 1U << 11, 1U << 12, 1U << 13};
47 
48   static constexpr const int NumBuckets =
49       sizeof(BucketSize) / sizeof(BucketSize[0]);
50 
51   /// Find the previous number that is power of 2 given a number that is not
52   /// power of 2.
floorToPowerOfTwo(size_t Num)53   static size_t floorToPowerOfTwo(size_t Num) {
54     Num |= Num >> 1;
55     Num |= Num >> 2;
56     Num |= Num >> 4;
57     Num |= Num >> 8;
58     Num |= Num >> 16;
59 #if INTPTR_MAX == INT64_MAX
60     Num |= Num >> 32;
61 #elif INTPTR_MAX == INT32_MAX
62     // Do nothing with 32-bit
63 #else
64 #error Unsupported architecture
65 #endif
66     Num += 1;
67     return Num >> 1;
68   }
69 
70   /// Find a suitable bucket
findBucket(size_t Size)71   static int findBucket(size_t Size) {
72     const size_t F = floorToPowerOfTwo(Size);
73 
74     DP("findBucket: Size %zu is floored to %zu.\n", Size, F);
75 
76     int L = 0, H = NumBuckets - 1;
77     while (H - L > 1) {
78       int M = (L + H) >> 1;
79       if (BucketSize[M] == F)
80         return M;
81       if (BucketSize[M] > F)
82         H = M - 1;
83       else
84         L = M;
85     }
86 
87     assert(L >= 0 && L < NumBuckets && "L is out of range");
88 
89     DP("findBucket: Size %zu goes to bucket %d\n", Size, L);
90 
91     return L;
92   }
93 
94   /// A structure stores the meta data of a target pointer
95   struct NodeTy {
96     /// Memory size
97     const size_t Size;
98     /// Target pointer
99     void *Ptr;
100 
101     /// Constructor
NodeTyNodeTy102     NodeTy(size_t Size, void *Ptr) : Size(Size), Ptr(Ptr) {}
103   };
104 
105   /// To make \p NodePtrTy ordered when they're put into \p std::multiset.
106   struct NodeCmpTy {
operatorNodeCmpTy107     bool operator()(const NodeTy &LHS, const NodeTy &RHS) const {
108       return LHS.Size < RHS.Size;
109     }
110   };
111 
112   /// A \p FreeList is a set of Nodes. We're using \p std::multiset here to make
113   /// the look up procedure more efficient.
114   using FreeListTy = std::multiset<std::reference_wrapper<NodeTy>, NodeCmpTy>;
115 
116   /// A list of \p FreeListTy entries, each of which is a \p std::multiset of
117   /// Nodes whose size is less or equal to a specific bucket size.
118   std::vector<FreeListTy> FreeLists;
119   /// A list of mutex for each \p FreeListTy entry
120   std::vector<std::mutex> FreeListLocks;
121   /// A table to map from a target pointer to its node
122   std::unordered_map<void *, NodeTy> PtrToNodeTable;
123   /// The mutex for the table \p PtrToNodeTable
124   std::mutex MapTableLock;
125 
126   /// The reference to a device allocator
127   DeviceAllocatorTy &DeviceAllocator;
128 
129   /// The threshold to manage memory using memory manager. If the request size
130   /// is larger than \p SizeThreshold, the allocation will not be managed by the
131   /// memory manager.
132   size_t SizeThreshold = 1U << 13;
133 
134   /// Request memory from target device
allocateOnDevice(size_t Size,void * HstPtr)135   void *allocateOnDevice(size_t Size, void *HstPtr) const {
136     return DeviceAllocator.allocate(Size, HstPtr, TARGET_ALLOC_DEVICE);
137   }
138 
139   /// Deallocate data on device
deleteOnDevice(void * Ptr)140   int deleteOnDevice(void *Ptr) const { return DeviceAllocator.free(Ptr); }
141 
142   /// This function is called when it tries to allocate memory on device but the
143   /// device returns out of memory. It will first free all memory in the
144   /// FreeList and try to allocate again.
freeAndAllocate(size_t Size,void * HstPtr)145   void *freeAndAllocate(size_t Size, void *HstPtr) {
146     std::vector<void *> RemoveList;
147 
148     // Deallocate all memory in FreeList
149     for (int I = 0; I < NumBuckets; ++I) {
150       FreeListTy &List = FreeLists[I];
151       std::lock_guard<std::mutex> Lock(FreeListLocks[I]);
152       if (List.empty())
153         continue;
154       for (const NodeTy &N : List) {
155         deleteOnDevice(N.Ptr);
156         RemoveList.push_back(N.Ptr);
157       }
158       FreeLists[I].clear();
159     }
160 
161     // Remove all nodes in the map table which have been released
162     if (!RemoveList.empty()) {
163       std::lock_guard<std::mutex> LG(MapTableLock);
164       for (void *P : RemoveList)
165         PtrToNodeTable.erase(P);
166     }
167 
168     // Try allocate memory again
169     return allocateOnDevice(Size, HstPtr);
170   }
171 
172   /// The goal is to allocate memory on the device. It first tries to
173   /// allocate directly on the device. If a \p nullptr is returned, it might
174   /// be because the device is OOM. In that case, it will free all unused
175   /// memory and then try again.
allocateOrFreeAndAllocateOnDevice(size_t Size,void * HstPtr)176   void *allocateOrFreeAndAllocateOnDevice(size_t Size, void *HstPtr) {
177     void *TgtPtr = allocateOnDevice(Size, HstPtr);
178     // We cannot get memory from the device. It might be due to OOM. Let's
179     // free all memory in FreeLists and try again.
180     if (TgtPtr == nullptr) {
181       DP("Failed to get memory on device. Free all memory in FreeLists and "
182          "try again.\n");
183       TgtPtr = freeAndAllocate(Size, HstPtr);
184     }
185 
186     if (TgtPtr == nullptr)
187       DP("Still cannot get memory on device probably because the device is "
188          "OOM.\n");
189 
190     return TgtPtr;
191   }
192 
193 public:
194   /// Constructor. If \p Threshold is non-zero, then the default threshold will
195   /// be overwritten by \p Threshold.
196   MemoryManagerTy(DeviceAllocatorTy &DeviceAllocator, size_t Threshold = 0)
FreeLists(NumBuckets)197       : FreeLists(NumBuckets), FreeListLocks(NumBuckets),
198         DeviceAllocator(DeviceAllocator) {
199     if (Threshold)
200       SizeThreshold = Threshold;
201   }
202 
203   /// Destructor
~MemoryManagerTy()204   ~MemoryManagerTy() {
205     for (auto Itr = PtrToNodeTable.begin(); Itr != PtrToNodeTable.end();
206          ++Itr) {
207       assert(Itr->second.Ptr && "nullptr in map table");
208       deleteOnDevice(Itr->second.Ptr);
209     }
210   }
211 
212   /// Allocate memory of size \p Size from target device. \p HstPtr is used to
213   /// assist the allocation.
allocate(size_t Size,void * HstPtr)214   void *allocate(size_t Size, void *HstPtr) {
215     // If the size is zero, we will not bother the target device. Just return
216     // nullptr directly.
217     if (Size == 0)
218       return nullptr;
219 
220     DP("MemoryManagerTy::allocate: size %zu with host pointer " DPxMOD ".\n",
221        Size, DPxPTR(HstPtr));
222 
223     // If the size is greater than the threshold, allocate it directly from
224     // device.
225     if (Size > SizeThreshold) {
226       DP("%zu is greater than the threshold %zu. Allocate it directly from "
227          "device\n",
228          Size, SizeThreshold);
229       void *TgtPtr = allocateOrFreeAndAllocateOnDevice(Size, HstPtr);
230 
231       DP("Got target pointer " DPxMOD ". Return directly.\n", DPxPTR(TgtPtr));
232 
233       return TgtPtr;
234     }
235 
236     NodeTy *NodePtr = nullptr;
237 
238     // Try to get a node from FreeList
239     {
240       const int B = findBucket(Size);
241       FreeListTy &List = FreeLists[B];
242 
243       NodeTy TempNode(Size, nullptr);
244       std::lock_guard<std::mutex> LG(FreeListLocks[B]);
245       const auto Itr = List.find(TempNode);
246 
247       if (Itr != List.end()) {
248         NodePtr = &Itr->get();
249         List.erase(Itr);
250       }
251     }
252 
253     if (NodePtr != nullptr)
254       DP("Find one node " DPxMOD " in the bucket.\n", DPxPTR(NodePtr));
255 
256     // We cannot find a valid node in FreeLists. Let's allocate on device and
257     // create a node for it.
258     if (NodePtr == nullptr) {
259       DP("Cannot find a node in the FreeLists. Allocate on device.\n");
260       // Allocate one on device
261       void *TgtPtr = allocateOrFreeAndAllocateOnDevice(Size, HstPtr);
262 
263       if (TgtPtr == nullptr)
264         return nullptr;
265 
266       // Create a new node and add it into the map table
267       {
268         std::lock_guard<std::mutex> Guard(MapTableLock);
269         auto Itr = PtrToNodeTable.emplace(TgtPtr, NodeTy(Size, TgtPtr));
270         NodePtr = &Itr.first->second;
271       }
272 
273       DP("Node address " DPxMOD ", target pointer " DPxMOD ", size %zu\n",
274          DPxPTR(NodePtr), DPxPTR(TgtPtr), Size);
275     }
276 
277     assert(NodePtr && "NodePtr should not be nullptr at this point");
278 
279     return NodePtr->Ptr;
280   }
281 
282   /// Deallocate memory pointed by \p TgtPtr
free(void * TgtPtr)283   int free(void *TgtPtr) {
284     DP("MemoryManagerTy::free: target memory " DPxMOD ".\n", DPxPTR(TgtPtr));
285 
286     NodeTy *P = nullptr;
287 
288     // Look it up into the table
289     {
290       std::lock_guard<std::mutex> G(MapTableLock);
291       auto Itr = PtrToNodeTable.find(TgtPtr);
292 
293       // We don't remove the node from the map table because the map does not
294       // change.
295       if (Itr != PtrToNodeTable.end())
296         P = &Itr->second;
297     }
298 
299     // The memory is not managed by the manager
300     if (P == nullptr) {
301       DP("Cannot find its node. Delete it on device directly.\n");
302       return deleteOnDevice(TgtPtr);
303     }
304 
305     // Insert the node to the free list
306     const int B = findBucket(P->Size);
307 
308     DP("Found its node " DPxMOD ". Insert it to bucket %d.\n", DPxPTR(P), B);
309 
310     {
311       std::lock_guard<std::mutex> G(FreeListLocks[B]);
312       FreeLists[B].insert(*P);
313     }
314 
315     return OFFLOAD_SUCCESS;
316   }
317 
318   /// Get the size threshold from the environment variable
319   /// \p LIBOMPTARGET_MEMORY_MANAGER_THRESHOLD . Returns a <tt>
320   /// std::pair<size_t, bool> </tt> where the first element represents the
321   /// threshold and the second element represents whether user disables memory
322   /// manager explicitly by setting the var to 0. If user doesn't specify
323   /// anything, returns <0, true>.
getSizeThresholdFromEnv()324   static std::pair<size_t, bool> getSizeThresholdFromEnv() {
325     size_t Threshold = 0;
326 
327     if (const char *Env =
328             std::getenv("LIBOMPTARGET_MEMORY_MANAGER_THRESHOLD")) {
329       Threshold = std::stoul(Env);
330       if (Threshold == 0) {
331         DP("Disabled memory manager as user set "
332            "LIBOMPTARGET_MEMORY_MANAGER_THRESHOLD=0.\n");
333         return std::make_pair(0, false);
334       }
335     }
336 
337     return std::make_pair(Threshold, true);
338   }
339 };
340 
// Namespace-scope definitions of the static constexpr data members that are
// declared (with their initializers) inside MemoryManagerTy.
// GCC still cannot handle the static data member like Clang so we still need
// this part.
// NOTE(review): since C++17 static constexpr data members are implicitly
// inline, which would make these redeclarations redundant -- confirm the
// minimum language standard this header is built with before removing them.
constexpr const size_t MemoryManagerTy::BucketSize[];
constexpr const int MemoryManagerTy::NumBuckets;
345 
346 #endif // LLVM_OPENMP_LIBOMPTARGET_PLUGINS_COMMON_MEMORYMANAGER_MEMORYMANAGER_H
347