#pragma once
#include <cstdlib>
#include <new>
#ifdef _MSC_VER
// Ensure _HAS_EXCEPTIONS is defined
#include <vcruntime.h>
#include <malloc.h>
#endif

// Aligned simple vector.

namespace intgemm {

// Fixed-size, move-only array whose storage is aligned to `alignment` bytes.
//
// Only raw memory is allocated and freed: element constructors and destructors
// are never run and the contents start uninitialized, so T should be a trivial
// type (integers, floats, ...).
//
// Allocation failure (including an element count whose byte size does not fit
// in std::size_t) throws std::bad_alloc when exceptions are enabled and calls
// std::abort() otherwise.
template <class T> class AlignedVector {
  public:
    // Empty vector: no storage, size() == 0.
    AlignedVector() : mem_(nullptr), size_(0) {}

    // Allocates uninitialized storage for `size` elements aligned to
    // `alignment` bytes (must be a power of two; on POSIX also a multiple of
    // sizeof(void*)). A size of 0 allocates nothing and leaves begin() null.
    explicit AlignedVector(std::size_t size, std::size_t alignment = 64 /* CPU cares about this */)
      : mem_(nullptr), size_(size) {
      // Nothing to allocate. This also avoids MSVC's _aligned_malloc, which
      // rejects a zero size through the invalid-parameter handler.
      if (size == 0) return;
      // Reject counts whose byte size would wrap around; otherwise a tiny
      // buffer would be allocated while size() reports the huge count.
      // static_cast<std::size_t>(-1) is the largest size_t (avoids the
      // Windows `max` macro clash with std::numeric_limits<>::max()).
      if (size > static_cast<std::size_t>(-1) / sizeof(T)) AllocationFailed();
      const std::size_t bytes = size * sizeof(T);
#ifdef _MSC_VER
      mem_ = static_cast<T*>(_aligned_malloc(bytes, alignment));
      if (!mem_) AllocationFailed();
#else
      // Allocate through a genuine void* rather than type-punning &mem_.
      void *raw = nullptr;
      if (posix_memalign(&raw, alignment, bytes)) AllocationFailed();
      mem_ = static_cast<T*>(raw);
#endif
    }

    // Takes ownership of `from`'s storage; `from` is left empty.
    AlignedVector(AlignedVector &&from) noexcept : mem_(from.mem_), size_(from.size_) {
      from.mem_ = nullptr;
      from.size_ = 0;
    }

    // Frees any current storage, then takes ownership of `from`'s storage;
    // `from` is left empty. Self-assignment is a no-op.
    AlignedVector &operator=(AlignedVector &&from) noexcept {
      if (this == &from) return *this;
      release();
      mem_ = from.mem_;
      size_ = from.size_;
      from.mem_ = nullptr;
      from.size_ = 0;
      return *this;
    }

    AlignedVector(const AlignedVector&) = delete;
    AlignedVector& operator=(const AlignedVector&) = delete;

    ~AlignedVector() { release(); }

    // Number of elements (not bytes).
    std::size_t size() const { return size_; }

    // Unchecked element access.
    T &operator[](std::size_t offset) { return mem_[offset]; }
    const T &operator[](std::size_t offset) const { return mem_[offset]; }

    T *begin() { return mem_; }
    const T *begin() const { return mem_; }
    T *end() { return mem_ + size_; }
    const T *end() const { return mem_ + size_; }

    // Reinterprets the start of the storage as a ReturnType pointer.
    template <typename ReturnType>
    ReturnType *as() { return reinterpret_cast<ReturnType*>(mem_); }

    template <typename ReturnType>
    const ReturnType *as() const { return reinterpret_cast<const ReturnType*>(mem_); }

  private:
    T *mem_;
    std::size_t size_;

    // Single place for the out-of-memory policy: throw when exceptions are
    // available, abort otherwise.
    [[noreturn]] static void AllocationFailed() {
#if (defined(_MSC_VER) && !defined(__clang__)) ? (_HAS_EXCEPTIONS) : (__EXCEPTIONS)
      throw std::bad_alloc();
#else
      std::abort();
#endif
    }

    // Frees the storage with the deallocator matching the allocator used in
    // the constructor. Both accept a null pointer.
    void release() noexcept {
#ifdef _MSC_VER
      _aligned_free(mem_);
#else
      std::free(mem_);
#endif
    }
};

} // namespace intgemm