1 //===- StorageUniquer.cpp - Common Storage Class Uniquer ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Support/StorageUniquer.h"

#include "mlir/Support/LLVM.h"
#include "mlir/Support/ThreadLocalCache.h"
#include "mlir/Support/TypeID.h"
#include "llvm/Support/RWMutex.h"

#include <atomic>
#include <memory>
15 
16 using namespace mlir;
17 using namespace mlir::detail;
18 
19 namespace {
20 /// This class represents a uniquer for storage instances of a specific type
21 /// that has parametric storage. It contains all of the necessary data to unique
22 /// storage instances in a thread safe way. This allows for the main uniquer to
23 /// bucket each of the individual sub-types removing the need to lock the main
24 /// uniquer itself.
25 class ParametricStorageUniquer {
26 public:
27   using BaseStorage = StorageUniquer::BaseStorage;
28   using StorageAllocator = StorageUniquer::StorageAllocator;
29 
30   /// A lookup key for derived instances of storage objects.
31   struct LookupKey {
32     /// The known hash value of the key.
33     unsigned hashValue;
34 
35     /// An equality function for comparing with an existing storage instance.
36     function_ref<bool(const BaseStorage *)> isEqual;
37   };
38 
39 private:
40   /// A utility wrapper object representing a hashed storage object. This class
41   /// contains a storage object and an existing computed hash value.
42   struct HashedStorage {
HashedStorage__anon59e4b6ac0111::ParametricStorageUniquer::HashedStorage43     HashedStorage(unsigned hashValue = 0, BaseStorage *storage = nullptr)
44         : hashValue(hashValue), storage(storage) {}
45     unsigned hashValue;
46     BaseStorage *storage;
47   };
48 
49   /// Storage info for derived TypeStorage objects.
50   struct StorageKeyInfo : DenseMapInfo<HashedStorage> {
getEmptyKey__anon59e4b6ac0111::ParametricStorageUniquer::StorageKeyInfo51     static HashedStorage getEmptyKey() {
52       return HashedStorage(0, DenseMapInfo<BaseStorage *>::getEmptyKey());
53     }
getTombstoneKey__anon59e4b6ac0111::ParametricStorageUniquer::StorageKeyInfo54     static HashedStorage getTombstoneKey() {
55       return HashedStorage(0, DenseMapInfo<BaseStorage *>::getTombstoneKey());
56     }
57 
getHashValue__anon59e4b6ac0111::ParametricStorageUniquer::StorageKeyInfo58     static unsigned getHashValue(const HashedStorage &key) {
59       return key.hashValue;
60     }
getHashValue__anon59e4b6ac0111::ParametricStorageUniquer::StorageKeyInfo61     static unsigned getHashValue(LookupKey key) { return key.hashValue; }
62 
isEqual__anon59e4b6ac0111::ParametricStorageUniquer::StorageKeyInfo63     static bool isEqual(const HashedStorage &lhs, const HashedStorage &rhs) {
64       return lhs.storage == rhs.storage;
65     }
isEqual__anon59e4b6ac0111::ParametricStorageUniquer::StorageKeyInfo66     static bool isEqual(const LookupKey &lhs, const HashedStorage &rhs) {
67       if (isEqual(rhs, getEmptyKey()) || isEqual(rhs, getTombstoneKey()))
68         return false;
69       // Invoke the equality function on the lookup key.
70       return lhs.isEqual(rhs.storage);
71     }
72   };
73   using StorageTypeSet = DenseSet<HashedStorage, StorageKeyInfo>;
74 
75   /// This class represents a single shard of the uniquer. The uniquer uses a
76   /// set of shards to allow for multiple threads to create instances with less
77   /// lock contention.
78   struct Shard {
79     /// The set containing the allocated storage instances.
80     StorageTypeSet instances;
81 
82     /// Allocator to use when constructing derived instances.
83     StorageAllocator allocator;
84 
85 #if LLVM_ENABLE_THREADS != 0
86     /// A mutex to keep uniquing thread-safe.
87     llvm::sys::SmartRWMutex<true> mutex;
88 #endif
89   };
90 
91   /// Get or create an instance of a param derived type in an thread-unsafe
92   /// fashion.
93   BaseStorage *
getOrCreateUnsafe(Shard & shard,LookupKey & key,function_ref<BaseStorage * (StorageAllocator &)> ctorFn)94   getOrCreateUnsafe(Shard &shard, LookupKey &key,
95                     function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
96     auto existing = shard.instances.insert_as({key.hashValue}, key);
97     BaseStorage *&storage = existing.first->storage;
98     if (existing.second)
99       storage = ctorFn(shard.allocator);
100     return storage;
101   }
102 
103   /// Destroy all of the storage instances within the given shard.
destroyShardInstances(Shard & shard)104   void destroyShardInstances(Shard &shard) {
105     if (!destructorFn)
106       return;
107     for (HashedStorage &instance : shard.instances)
108       destructorFn(instance.storage);
109   }
110 
111 public:
112 #if LLVM_ENABLE_THREADS != 0
113   /// Initialize the storage uniquer with a given number of storage shards to
114   /// use. The provided shard number is required to be a valid power of 2. The
115   /// destructor function is used to destroy any allocated storage instances.
ParametricStorageUniquer(function_ref<void (BaseStorage *)> destructorFn,size_t numShards=8)116   ParametricStorageUniquer(function_ref<void(BaseStorage *)> destructorFn,
117                            size_t numShards = 8)
118       : shards(new std::atomic<Shard *>[numShards]), numShards(numShards),
119         destructorFn(destructorFn) {
120     assert(llvm::isPowerOf2_64(numShards) &&
121            "the number of shards is required to be a power of 2");
122     for (size_t i = 0; i < numShards; i++)
123       shards[i].store(nullptr, std::memory_order_relaxed);
124   }
~ParametricStorageUniquer()125   ~ParametricStorageUniquer() {
126     // Free all of the allocated shards.
127     for (size_t i = 0; i != numShards; ++i) {
128       if (Shard *shard = shards[i].load()) {
129         destroyShardInstances(*shard);
130         delete shard;
131       }
132     }
133   }
134   /// Get or create an instance of a parametric type.
135   BaseStorage *
getOrCreate(bool threadingIsEnabled,unsigned hashValue,function_ref<bool (const BaseStorage *)> isEqual,function_ref<BaseStorage * (StorageAllocator &)> ctorFn)136   getOrCreate(bool threadingIsEnabled, unsigned hashValue,
137               function_ref<bool(const BaseStorage *)> isEqual,
138               function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
139     Shard &shard = getShard(hashValue);
140     ParametricStorageUniquer::LookupKey lookupKey{hashValue, isEqual};
141     if (!threadingIsEnabled)
142       return getOrCreateUnsafe(shard, lookupKey, ctorFn);
143 
144     // Check for a instance of this object in the local cache.
145     auto localIt = localCache->insert_as({hashValue}, lookupKey);
146     BaseStorage *&localInst = localIt.first->storage;
147     if (localInst)
148       return localInst;
149 
150     // Check for an existing instance in read-only mode.
151     {
152       llvm::sys::SmartScopedReader<true> typeLock(shard.mutex);
153       auto it = shard.instances.find_as(lookupKey);
154       if (it != shard.instances.end())
155         return localInst = it->storage;
156     }
157 
158     // Acquire a writer-lock so that we can safely create the new storage
159     // instance.
160     llvm::sys::SmartScopedWriter<true> typeLock(shard.mutex);
161     return localInst = getOrCreateUnsafe(shard, lookupKey, ctorFn);
162   }
163   /// Run a mutation function on the provided storage object in a thread-safe
164   /// way.
165   LogicalResult
mutate(bool threadingIsEnabled,BaseStorage * storage,function_ref<LogicalResult (StorageAllocator &)> mutationFn)166   mutate(bool threadingIsEnabled, BaseStorage *storage,
167          function_ref<LogicalResult(StorageAllocator &)> mutationFn) {
168     Shard &shard = getShardFor(storage);
169     if (!threadingIsEnabled)
170       return mutationFn(shard.allocator);
171 
172     llvm::sys::SmartScopedWriter<true> lock(shard.mutex);
173     return mutationFn(shard.allocator);
174   }
175 
176 private:
177   /// Return the shard used for the given hash value.
getShard(unsigned hashValue)178   Shard &getShard(unsigned hashValue) {
179     // Get a shard number from the provided hashvalue.
180     unsigned shardNum = hashValue & (numShards - 1);
181 
182     // Try to acquire an already initialized shard.
183     Shard *shard = shards[shardNum].load(std::memory_order_acquire);
184     if (shard)
185       return *shard;
186 
187     // Otherwise, try to allocate a new shard.
188     Shard *newShard = new Shard();
189     if (shards[shardNum].compare_exchange_strong(shard, newShard))
190       return *newShard;
191 
192     // If one was allocated before we can initialize ours, delete ours.
193     delete newShard;
194     return *shard;
195   }
196 
197   /// Return the shard that allocated the provided storage object.
getShardFor(BaseStorage * storage)198   Shard &getShardFor(BaseStorage *storage) {
199     for (size_t i = 0; i != numShards; ++i) {
200       if (Shard *shard = shards[i].load(std::memory_order_acquire)) {
201         llvm::sys::SmartScopedReader<true> lock(shard->mutex);
202         if (shard->allocator.allocated(storage))
203           return *shard;
204       }
205     }
206     llvm_unreachable("expected storage object to have a valid shard");
207   }
208 
209   /// A thread local cache for storage objects. This helps to reduce the lock
210   /// contention when an object already existing in the cache.
211   ThreadLocalCache<StorageTypeSet> localCache;
212 
213   /// A set of uniquer shards to allow for further bucketing accesses for
214   /// instances of this storage type. Each shard is lazily initialized to reduce
215   /// the overhead when only a small amount of shards are in use.
216   std::unique_ptr<std::atomic<Shard *>[]> shards;
217 
218   /// The number of available shards.
219   size_t numShards;
220 
221   /// Function to used to destruct any allocated storage instances.
222   function_ref<void(BaseStorage *)> destructorFn;
223 
224 #else
225   /// If multi-threading is disabled, ignore the shard parameter as we will
226   /// always use one shard. The destructor function is used to destroy any
227   /// allocated storage instances.
228   ParametricStorageUniquer(function_ref<void(BaseStorage *)> destructorFn,
229                            size_t numShards = 0)
230       : destructorFn(destructorFn) {}
231   ~ParametricStorageUniquer() { destroyShardInstances(shard); }
232 
233   /// Get or create an instance of a parametric type.
234   BaseStorage *
235   getOrCreate(bool threadingIsEnabled, unsigned hashValue,
236               function_ref<bool(const BaseStorage *)> isEqual,
237               function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
238     ParametricStorageUniquer::LookupKey lookupKey{hashValue, isEqual};
239     return getOrCreateUnsafe(shard, lookupKey, ctorFn);
240   }
241   /// Run a mutation function on the provided storage object in a thread-safe
242   /// way.
243   LogicalResult
244   mutate(bool threadingIsEnabled, BaseStorage *storage,
245          function_ref<LogicalResult(StorageAllocator &)> mutationFn) {
246     return mutationFn(shard.allocator);
247   }
248 
249 private:
250   /// The main uniquer shard that is used for allocating storage instances.
251   Shard shard;
252 
253   /// Function to used to destruct any allocated storage instances.
254   function_ref<void(BaseStorage *)> destructorFn;
255 #endif
256 };
257 } // end anonymous namespace
258 
259 namespace mlir {
260 namespace detail {
261 /// This is the implementation of the StorageUniquer class.
262 struct StorageUniquerImpl {
263   using BaseStorage = StorageUniquer::BaseStorage;
264   using StorageAllocator = StorageUniquer::StorageAllocator;
265 
266   //===--------------------------------------------------------------------===//
267   // Parametric Storage
268   //===--------------------------------------------------------------------===//
269 
270   /// Check if an instance of a parametric storage class exists.
hasParametricStoragemlir::detail::StorageUniquerImpl271   bool hasParametricStorage(TypeID id) { return parametricUniquers.count(id); }
272 
273   /// Get or create an instance of a parametric type.
274   BaseStorage *
getOrCreatemlir::detail::StorageUniquerImpl275   getOrCreate(TypeID id, unsigned hashValue,
276               function_ref<bool(const BaseStorage *)> isEqual,
277               function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
278     assert(parametricUniquers.count(id) &&
279            "creating unregistered storage instance");
280     ParametricStorageUniquer &storageUniquer = *parametricUniquers[id];
281     return storageUniquer.getOrCreate(threadingIsEnabled, hashValue, isEqual,
282                                       ctorFn);
283   }
284 
285   /// Run a mutation function on the provided storage object in a thread-safe
286   /// way.
287   LogicalResult
mutatemlir::detail::StorageUniquerImpl288   mutate(TypeID id, BaseStorage *storage,
289          function_ref<LogicalResult(StorageAllocator &)> mutationFn) {
290     assert(parametricUniquers.count(id) &&
291            "mutating unregistered storage instance");
292     ParametricStorageUniquer &storageUniquer = *parametricUniquers[id];
293     return storageUniquer.mutate(threadingIsEnabled, storage, mutationFn);
294   }
295 
296   //===--------------------------------------------------------------------===//
297   // Singleton Storage
298   //===--------------------------------------------------------------------===//
299 
300   /// Get or create an instance of a singleton storage class.
getSingletonmlir::detail::StorageUniquerImpl301   BaseStorage *getSingleton(TypeID id) {
302     BaseStorage *singletonInstance = singletonInstances[id];
303     assert(singletonInstance && "expected singleton instance to exist");
304     return singletonInstance;
305   }
306 
307   /// Check if an instance of a singleton storage class exists.
hasSingletonmlir::detail::StorageUniquerImpl308   bool hasSingleton(TypeID id) const { return singletonInstances.count(id); }
309 
310   //===--------------------------------------------------------------------===//
311   // Instance Storage
312   //===--------------------------------------------------------------------===//
313 
314   /// Map of type ids to the storage uniquer to use for registered objects.
315   DenseMap<TypeID, std::unique_ptr<ParametricStorageUniquer>>
316       parametricUniquers;
317 
318   /// Map of type ids to a singleton instance when the storage class is a
319   /// singleton.
320   DenseMap<TypeID, BaseStorage *> singletonInstances;
321 
322   /// Allocator used for uniquing singleton instances.
323   StorageAllocator singletonAllocator;
324 
325   /// Flag specifying if multi-threading is enabled within the uniquer.
326   bool threadingIsEnabled = true;
327 };
328 } // end namespace detail
329 } // namespace mlir
330 
StorageUniquer()331 StorageUniquer::StorageUniquer() : impl(new StorageUniquerImpl()) {}
~StorageUniquer()332 StorageUniquer::~StorageUniquer() {}
333 
334 /// Set the flag specifying if multi-threading is disabled within the uniquer.
disableMultithreading(bool disable)335 void StorageUniquer::disableMultithreading(bool disable) {
336   impl->threadingIsEnabled = !disable;
337 }
338 
339 /// Implementation for getting/creating an instance of a derived type with
340 /// parametric storage.
getParametricStorageTypeImpl(TypeID id,unsigned hashValue,function_ref<bool (const BaseStorage *)> isEqual,function_ref<BaseStorage * (StorageAllocator &)> ctorFn)341 auto StorageUniquer::getParametricStorageTypeImpl(
342     TypeID id, unsigned hashValue,
343     function_ref<bool(const BaseStorage *)> isEqual,
344     function_ref<BaseStorage *(StorageAllocator &)> ctorFn) -> BaseStorage * {
345   return impl->getOrCreate(id, hashValue, isEqual, ctorFn);
346 }
347 
348 /// Implementation for registering an instance of a derived type with
349 /// parametric storage.
registerParametricStorageTypeImpl(TypeID id,function_ref<void (BaseStorage *)> destructorFn)350 void StorageUniquer::registerParametricStorageTypeImpl(
351     TypeID id, function_ref<void(BaseStorage *)> destructorFn) {
352   impl->parametricUniquers.try_emplace(
353       id, std::make_unique<ParametricStorageUniquer>(destructorFn));
354 }
355 
356 /// Implementation for getting an instance of a derived type with default
357 /// storage.
getSingletonImpl(TypeID id)358 auto StorageUniquer::getSingletonImpl(TypeID id) -> BaseStorage * {
359   return impl->getSingleton(id);
360 }
361 
362 /// Test is the storage singleton is initialized.
isSingletonStorageInitialized(TypeID id)363 bool StorageUniquer::isSingletonStorageInitialized(TypeID id) {
364   return impl->hasSingleton(id);
365 }
366 
367 /// Test is the parametric storage is initialized.
isParametricStorageInitialized(TypeID id)368 bool StorageUniquer::isParametricStorageInitialized(TypeID id) {
369   return impl->hasParametricStorage(id);
370 }
371 
372 /// Implementation for registering an instance of a derived type with default
373 /// storage.
registerSingletonImpl(TypeID id,function_ref<BaseStorage * (StorageAllocator &)> ctorFn)374 void StorageUniquer::registerSingletonImpl(
375     TypeID id, function_ref<BaseStorage *(StorageAllocator &)> ctorFn) {
376   assert(!impl->singletonInstances.count(id) &&
377          "storage class already registered");
378   impl->singletonInstances.try_emplace(id, ctorFn(impl->singletonAllocator));
379 }
380 
381 /// Implementation for mutating an instance of a derived storage.
mutateImpl(TypeID id,BaseStorage * storage,function_ref<LogicalResult (StorageAllocator &)> mutationFn)382 LogicalResult StorageUniquer::mutateImpl(
383     TypeID id, BaseStorage *storage,
384     function_ref<LogicalResult(StorageAllocator &)> mutationFn) {
385   return impl->mutate(id, storage, mutationFn);
386 }
387