//===--- acxxel.h - The Acxxel API ------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

/// \mainpage Welcome to Acxxel
///
/// \section Introduction
///
/// \b Acxxel is a library providing a modern C++ interface for managing
/// accelerator devices such as GPUs. Acxxel handles operations such as
/// allocating device memory, copying data to and from device memory, creating
/// and managing device events, and creating and managing device streams.
///
/// \subsection ExampleUsage Example Usage
///
/// Below is some example code to show you the basics of Acxxel.
///
/// \snippet examples/simple_example.cu Example simple saxpy
///
/// The above code can be compiled with either `clang` or `nvcc`. Compare it
/// with the standard CUDA runtime library code that performs the same
/// operations:
///
/// \snippet examples/simple_example.cu Example CUDA simple saxpy
///
/// Notice that the CUDA runtime calls are not type safe. For example, if you
/// change the type of the inputs from `float` to `double`, you have to remember
/// to change the size calculation. If you forget, you will get garbage output
/// data. In the Acxxel example, you would instead get a helpful compile-time
/// error that wouldn't let you forget to change the types inside the function.
///
/// The Acxxel example also automatically uses the right sizes for memory
/// copies, so you don't have to worry about computing the sizes yourself.
///
/// The CUDA runtime interface makes it easy to get the source and destination
/// mixed up in a call to `cudaMemcpy`. If you pass the pointers in the wrong
/// order, or pass the wrong enum value for the direction parameter, you won't
/// find out until runtime (if you remembered to check the error return value of
/// `cudaMemcpy`). In Acxxel there is no verbose direction enum because the name
/// of the function says which way the copy goes, and mixing up the order of
/// source and destination is a compile-time error.
///
/// The CUDA runtime interface makes you clean up your device memory by calling
/// `cudaFree` for each call to `cudaMalloc`. In Acxxel, you don't have to worry
/// about that because the memory cleans itself up when it goes out of scope.
///
/// \subsection AcxxelFeatures Acxxel Features
///
/// Acxxel provides many nice features compared to the C-like interfaces, such
/// as the CUDA runtime API, that are normally used for the host code in
/// applications using accelerators.
///
/// \subsubsection TypeSafety Type safety
///
/// Most errors involving mixing up types, sources and destinations, or host and
/// device memory result in helpful compile-time errors.
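///
/// For instance, a minimal sketch (hypothetical names, not from the shipped
/// examples): if the host data's element type changes but the device
/// allocation's does not, the copy below no longer compiles because the
/// element types must match.
///
/// \code{.cpp}
/// std::vector<double> HostData(100, 1.0);
/// acxxel::DeviceMemory<float> DeviceData =
///     Platform->mallocD<float>(100).takeValue();
/// // Compile-time error: element types double and float do not match.
/// Stream.syncCopyHToD(HostData, DeviceData);
/// \endcode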
///
/// \subsubsection NoCopySizes No need to specify sizes for memory copies
///
/// When the arguments to copy functions such as acxxel::Stream::syncCopyHToD
/// know their sizes (e.g. std::array, std::vector, and C-style arrays), there
/// is no need to specify the amount of memory to copy; Acxxel will just copy
/// the whole thing. Of course, the copy functions also have overloads that
/// accept an element count for those times when you don't want to copy
/// everything.
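///
/// A brief sketch (assuming valid Platform and Stream objects): the container
/// forms carry their own sizes, and the element-count overloads copy a prefix.
///
/// \code{.cpp}
/// std::array<float, 16> Host{};
/// acxxel::DeviceMemory<float> Device =
///     Platform->mallocD<float>(16).takeValue();
/// Stream.syncCopyHToD(Host, Device);    // Copies all 16 elements.
/// Stream.syncCopyHToD(Host, Device, 8); // Copies only the first 8 elements.
/// \endcode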
///
/// \subsubsection MemoryCleanup Automatic memory cleanup
///
/// Device memory allocated with acxxel::Platform::mallocD is automatically
/// freed when it goes out of scope.
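///
/// A minimal sketch (assuming a valid Platform pointer):
///
/// \code{.cpp}
/// {
///   acxxel::DeviceMemory<int> Memory =
///       Platform->mallocD<int>(256).takeValue();
///   // ... use Memory ...
/// } // The device memory is freed here; no explicit free call is needed.
/// \endcode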
///
/// \subsubsection NiceErrorHandling Error handling
///
/// Operations that would normally return values return acxxel::Expected objects
/// in Acxxel. These `Expected` objects contain either a value or an error
/// message explaining why the value is not present. This reminds the user to
/// check for errors, but also allows them to opt out easily by calling the
/// acxxel::Expected::getValue or acxxel::Expected::takeValue methods. The
/// `getValue` method returns a reference to the value, leaving the `Expected`
/// instance as the value owner, whereas the `takeValue` method moves the value
/// out of the `Expected` object and transfers ownership to the caller.
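///
/// A sketch of both styles (assuming a valid Platform pointer):
///
/// \code{.cpp}
/// acxxel::Expected<acxxel::Stream> MaybeStream = Platform->createStream();
/// if (MaybeStream.isError())
///   return; // or report MaybeStream.getError()
/// acxxel::Stream Stream = MaybeStream.takeValue(); // Take ownership.
/// \endcode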
///
/// \subsubsection PlatformIndependence Platform independence
///
/// Acxxel code works not only with CUDA, but also with any other platform that
/// can support its interface. For example, Acxxel supports OpenCL. The
/// acxxel::getCUDAPlatform and acxxel::getOpenCLPlatform functions provide
/// easy access to the built-in CUDA and OpenCL platforms. Other platforms can
/// be created by implementing the acxxel::Platform interface, and instances of
/// those classes can be created directly.
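///
/// A sketch of selecting a platform at a single point in the code:
///
/// \code{.cpp}
/// acxxel::Platform *Platform = acxxel::getCUDAPlatform().takeValue();
/// // The rest of the code depends only on the acxxel::Platform interface, so
/// // switching to acxxel::getOpenCLPlatform() is a one-line change.
/// \endcode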
///
/// \subsubsection CUDAInterop Seamless interoperation with CUDA
///
/// Acxxel functions as a modern replacement for the standard CUDA runtime
/// library and interoperates seamlessly with kernel calls.
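///
/// A sketch (assuming a user-defined __global__ kernel named saxpy): an
/// acxxel::Stream converts implicitly to the stream type expected by a
/// triple-chevron launch, and acxxel::DeviceMemory converts to a raw device
/// pointer, so both can be passed directly.
///
/// \code{.cpp}
/// // X and Y are acxxel::DeviceMemory<float>; Stream is an acxxel::Stream.
/// saxpy<<<1, N, 0, Stream>>>(A, X, Y, N);
/// \endcode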

#ifndef ACXXEL_ACXXEL_H
#define ACXXEL_ACXXEL_H

#include "span.h"
#include "status.h"

#include <cassert>
#include <cstddef>
#include <exception>
#include <functional>
#include <memory>
#include <string>
#include <type_traits>

#if defined(__clang__) || defined(__GNUC__)
#define ACXXEL_WARN_UNUSED_RESULT __attribute__((warn_unused_result))
#else
#define ACXXEL_WARN_UNUSED_RESULT
#endif

/// This type is declared here to provide smooth interoperability with the CUDA
/// triple-chevron kernel launch syntax.
///
/// An acxxel::Stream instance will be implicitly convertible to a CUstream_st*,
/// which is the type expected for the stream argument in the triple-chevron
/// CUDA kernel launch. This means that an acxxel::Stream can be passed without
/// explicit casting as the fourth argument to a triple-chevron CUDA kernel
/// launch.
struct CUstream_st; // NOLINT

namespace acxxel {

class Event;
class Platform;
class Stream;

template <typename T> class DeviceMemory;

template <typename T> class DeviceMemorySpan;

template <typename T> class AsyncHostMemory;

template <typename T> class AsyncHostMemorySpan;

template <typename T> class OwnedAsyncHostMemory;

/// Function type used to destroy opaque handles given out by the platform.
using HandleDestructor = void (*)(void *);

/// Functor type for enqueuing host callbacks on a stream.
using StreamCallback = std::function<void(Stream &, const Status &)>;

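/// Block and grid dimensions for a kernel launch.
///
/// A brief illustrative sketch: a one-dimensional launch of 4 blocks of 256
/// threads each could be described as
///
/// \code{.cpp}
/// acxxel::KernelLaunchDimensions Dims(256, 1, 1, 4);
/// \endcode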
struct KernelLaunchDimensions {
  // Intentionally implicit.
  KernelLaunchDimensions(unsigned int BlockX = 1, unsigned int BlockY = 1,
                         unsigned int BlockZ = 1, unsigned int GridX = 1,
                         unsigned int GridY = 1, unsigned int GridZ = 1)
      : BlockX(BlockX), BlockY(BlockY), BlockZ(BlockZ), GridX(GridX),
        GridY(GridY), GridZ(GridZ) {}

  unsigned int BlockX;
  unsigned int BlockY;
  unsigned int BlockZ;
  unsigned int GridX;
  unsigned int GridY;
  unsigned int GridZ;
};

/// Logs a warning message.
void logWarning(const std::string &Message);

/// Gets a pointer to the standard CUDA platform.
Expected<Platform *> getCUDAPlatform();

/// Gets a pointer to the standard OpenCL platform.
Expected<Platform *> getOpenCLPlatform();

/// A function that can be executed on the device.
///
/// A Kernel is created from a Program by calling Program::createKernel, and a
/// kernel is enqueued into a Stream by calling Stream::asyncKernelLaunch.
class Kernel {
public:
  Kernel(const Kernel &) = delete;
  Kernel &operator=(const Kernel &) = delete;
  Kernel(Kernel &&) noexcept;
  Kernel &operator=(Kernel &&That) noexcept;
  ~Kernel() = default;

private:
  // Only a Program can make a Kernel.
  friend class Program;
  Kernel(Platform *APlatform, void *AHandle, HandleDestructor Destructor)
      : ThePlatform(APlatform), TheHandle(AHandle, Destructor) {}

  // Let Stream get the raw handle for kernel launches.
  friend class Stream;

  Platform *ThePlatform;
  std::unique_ptr<void, HandleDestructor> TheHandle;
};

/// A program loaded on a device.
///
/// A Program can be created by calling Platform::createProgramFromSource, and a
/// Kernel can be created from a Program by calling Program::createKernel.
///
/// A program can contain any number of kernels, and a program only needs to be
/// loaded once in order to use all its kernels.
class Program {
public:
  Program(const Program &) = delete;
  Program &operator=(const Program &) = delete;
  Program(Program &&) noexcept;
  Program &operator=(Program &&That) noexcept;
  ~Program() = default;

  Expected<Kernel> createKernel(const std::string &Name);

private:
  // Only a Platform can make a Program.
  friend class Platform;
  Program(Platform *APlatform, void *AHandle, HandleDestructor Destructor)
      : ThePlatform(APlatform), TheHandle(AHandle, Destructor) {}

  Platform *ThePlatform;
  std::unique_ptr<void, HandleDestructor> TheHandle;
};

/// A stream of computation.
///
/// All operations enqueued on a Stream are serialized, but operations enqueued
/// on different Streams may run concurrently.
///
/// Each Stream is associated with a specific, fixed device.
class Stream {
public:
  Stream(const Stream &) = delete;
  Stream &operator=(const Stream &) = delete;
  Stream(Stream &&) noexcept;
  Stream &operator=(Stream &&) noexcept;
  ~Stream() = default;

  /// Gets the index of the device on which this Stream operates.
  int getDeviceIndex() { return TheDeviceIndex; }

  /// Blocks the host until the Stream is done executing all previously
  /// enqueued work.
  ///
  /// Returns a Status for any errors emitted by the asynchronous work on the
  /// Stream, or by any error in the synchronization process itself. Clears the
  /// Status state of the stream.
  Status sync() ACXXEL_WARN_UNUSED_RESULT;

  /// Makes all future work submitted to this stream wait until the event
  /// reports completion.
  ///
  /// This is useful because the event argument may be recorded on a different
  /// stream, so this method allows for synchronization between streams without
  /// synchronizing all streams.
  ///
  /// Returns a Status for any errors emitted by the asynchronous work on the
  /// Stream, or by any error in the synchronization process itself. Clears the
  /// Status state of the stream.
  Status waitOnEvent(Event &Event) ACXXEL_WARN_UNUSED_RESULT;

  /// Adds a host callback function to the stream.
  ///
  /// The callback will be called on the host after all previously enqueued
  /// work on the stream is complete, and no work enqueued after the callback
  /// will begin until after the callback has finished.
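  ///
  /// A minimal sketch (assuming a valid Stream named S):
  ///
  /// \code{.cpp}
  /// S.addCallback([](acxxel::Stream &TheStream, const acxxel::Status &S2) {
  ///   if (S2.isError())
  ///     acxxel::logWarning("asynchronous stream error");
  /// });
  /// \endcode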
  Stream &addCallback(std::function<void(Stream &, const Status &)> Callback);

  /// \name Asynchronous device memory copies.
  ///
  /// These functions enqueue asynchronous memory copy operations into the
  /// stream. Only async host memory is allowed for host arguments to these
  /// functions. Async host memory can be created from normal host memory by
  /// registering it with Platform::registerHostMem. AsyncHostMemory can also be
  /// allocated directly by calling Platform::newAsyncHostMem.
  ///
  /// For all these functions, DeviceSrcTy must be convertible to
  /// DeviceMemorySpan<const T>, DeviceDstTy must be convertible to
  /// DeviceMemorySpan<T>, HostSrcTy must be convertible to
  /// AsyncHostMemorySpan<const T>, and HostDstTy must be convertible to
  /// AsyncHostMemorySpan<T>. Additionally, the T types must match for the
  /// destination and source.
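  ///
  /// A sketch of an asynchronous round trip (assuming AsyncHostMemory values
  /// In and Out and a DeviceMemory value Dev, all with matching element
  /// counts):
  ///
  /// \code{.cpp}
  /// Stream.asyncCopyHToD(In, Dev).asyncCopyDToH(Dev, Out);
  /// acxxel::Status S = Stream.sync(); // Wait and collect any errors.
  /// \endcode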
  /// \{

  /// Copies from device memory to device memory.
  template <typename DeviceSrcTy, typename DeviceDstTy>
  Stream &asyncCopyDToD(DeviceSrcTy &&DeviceSrc, DeviceDstTy &&DeviceDst);

  /// Copies from device memory to device memory with a given element count.
  template <typename DeviceSrcTy, typename DeviceDstTy>
  Stream &asyncCopyDToD(DeviceSrcTy &&DeviceSrc, DeviceDstTy &&DeviceDst,
                        ptrdiff_t ElementCount);

  /// Copies from device memory to host memory.
  template <typename DeviceSrcTy, typename HostDstTy>
  Stream &asyncCopyDToH(DeviceSrcTy &&DeviceSrc, HostDstTy &&HostDst);

  /// Copies from device memory to host memory with a given element count.
  template <typename DeviceSrcTy, typename HostDstTy>
  Stream &asyncCopyDToH(DeviceSrcTy &&DeviceSrc, HostDstTy &&HostDst,
                        ptrdiff_t ElementCount);

  /// Copies from host memory to device memory.
  template <typename HostSrcTy, typename DeviceDstTy>
  Stream &asyncCopyHToD(HostSrcTy &&HostSrc, DeviceDstTy &&DeviceDst);

  /// Copies from host memory to device memory with a given element count.
  template <typename HostSrcTy, typename DeviceDstTy>
  Stream &asyncCopyHToD(HostSrcTy &&HostSrc, DeviceDstTy &&DeviceDst,
                        ptrdiff_t ElementCount);

  /// \}

  /// \name Stream-synchronous device memory copies
  ///
  /// These functions block the host until the copy and all previously enqueued
  /// work on the stream have completed.
  ///
  /// For all these functions, DeviceSrcTy must be convertible to
  /// DeviceMemorySpan<const T>, DeviceDstTy must be convertible to
  /// DeviceMemorySpan<T>, HostSrcTy must be convertible to Span<const T>, and
  /// HostDstTy must be convertible to Span<T>. Additionally, the T types must
  /// match for the destination and source.
  /// \{

  template <typename DeviceSrcTy, typename DeviceDstTy>
  Stream &syncCopyDToD(DeviceSrcTy &&DeviceSrc, DeviceDstTy &&DeviceDst);

  template <typename DeviceSrcTy, typename DeviceDstTy>
  Stream &syncCopyDToD(DeviceSrcTy &&DeviceSrc, DeviceDstTy &&DeviceDst,
                       ptrdiff_t ElementCount);

  template <typename DeviceSrcTy, typename HostDstTy>
  Stream &syncCopyDToH(DeviceSrcTy &&DeviceSrc, HostDstTy &&HostDst);

  template <typename DeviceSrcTy, typename HostDstTy>
  Stream &syncCopyDToH(DeviceSrcTy &&DeviceSrc, HostDstTy &&HostDst,
                       ptrdiff_t ElementCount);

  template <typename HostSrcTy, typename DeviceDstTy>
  Stream &syncCopyHToD(HostSrcTy &&HostSrc, DeviceDstTy &&DeviceDst);

  template <typename HostSrcTy, typename DeviceDstTy>
  Stream &syncCopyHToD(HostSrcTy &&HostSrc, DeviceDstTy &&DeviceDst,
                       ptrdiff_t ElementCount);

  /// \}

  /// Enqueues an operation in the stream to set the bytes of a given device
  /// memory region to a given value.
  ///
  /// DeviceDstTy must be convertible to DeviceMemorySpan<T> for non-const T.
  template <typename DeviceDstTy>
  Stream &asyncMemsetD(DeviceDstTy &&DeviceDst, char ByteValue);

  /// Enqueues a kernel launch operation on this stream.
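  ///
  /// A sketch (assuming a Program named TheProgram that defines a kernel
  /// taking a single float argument):
  ///
  /// \code{.cpp}
  /// acxxel::Kernel K = TheProgram.createKernel("myKernel").takeValue();
  /// float X = 42.0f;
  /// void *Arguments[] = {&X};
  /// size_t ArgumentSizes[] = {sizeof(float)};
  /// Stream.asyncKernelLaunch(K, acxxel::KernelLaunchDimensions(256),
  ///                          Arguments, ArgumentSizes);
  /// \endcode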
  Stream &asyncKernelLaunch(const Kernel &TheKernel,
                            KernelLaunchDimensions LaunchDimensions,
                            Span<void *> Arguments, Span<size_t> ArgumentSizes,
                            size_t SharedMemoryBytes = 0);

  /// Enqueues an event in the stream.
  Stream &enqueueEvent(Event &E);

  // Allows implicit conversion to (CUstream_st *). This makes triple-chevron
  // kernel calls look nicer because you can just pass an acxxel::Stream
  // directly.
  operator CUstream_st *() {
    return static_cast<CUstream_st *>(TheHandle.get());
  }

  /// Gets the current status for the Stream and clears the Stream's status.
  Status takeStatus() ACXXEL_WARN_UNUSED_RESULT {
    Status OldStatus = TheStatus;
    TheStatus = Status();
    return OldStatus;
  }

private:
  // Only a platform can make a stream.
  friend class Platform;
  Stream(Platform *APlatform, int DeviceIndex, void *AHandle,
         HandleDestructor Destructor)
      : ThePlatform(APlatform), TheDeviceIndex(DeviceIndex),
        TheHandle(AHandle, Destructor) {}

  const Status &setStatus(const Status &S) {
    if (S.isError() && !TheStatus.isError()) {
      TheStatus = S;
    }
    return S;
  }

  Status takeStatusOr(const Status &S) {
    if (TheStatus.isError()) {
      Status OldStatus = TheStatus;
      TheStatus = Status();
      return OldStatus;
    }
    return S;
  }

  // The platform that created the stream.
  Platform *ThePlatform;

  // The index of the device on which the stream operates.
  int TheDeviceIndex;

  // A handle to the platform-specific stream implementation.
  std::unique_ptr<void, HandleDestructor> TheHandle;
  Status TheStatus;
};

/// A user-created event on a device.
///
/// This is useful for setting synchronization points in a Stream. The host can
/// synchronize with a Stream without using events, but that requires all the
/// work in the Stream to be finished in order for the host to be notified.
/// Events provide more flexibility by allowing the host to be notified when a
/// single Event in the Stream is finished, rather than all the work in the
/// Stream.
class Event {
public:
  Event(const Event &) = delete;
  Event &operator=(const Event &) = delete;
  Event(Event &&) noexcept;
  Event &operator=(Event &&That) noexcept;
  ~Event() = default;

  /// Checks to see if the event is done running.
  bool isDone();

  /// Blocks the host until the event is done.
  Status sync();

  /// Gets the time elapsed between the previous event's execution and this
  /// event's execution.
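  ///
  /// A sketch of event-based timing (assuming valid Platform and Stream
  /// objects):
  ///
  /// \code{.cpp}
  /// acxxel::Event Start = Platform->createEvent().takeValue();
  /// acxxel::Event End = Platform->createEvent().takeValue();
  /// Stream.enqueueEvent(Start);
  /// // ... enqueue the work to be timed ...
  /// Stream.enqueueEvent(End);
  /// acxxel::Status S = End.sync(); // Block until End has executed.
  /// float Seconds = End.getSecondsSince(Start).getValue();
  /// \endcode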
  Expected<float> getSecondsSince(const Event &Previous);

private:
  // Only a platform can make an event.
  friend class Platform;
  Event(Platform *APlatform, int DeviceIndex, void *AHandle,
        HandleDestructor Destructor)
      : ThePlatform(APlatform), TheDeviceIndex(DeviceIndex),
        TheHandle(AHandle, Destructor) {}

  Platform *ThePlatform;

  // The index of the device on which the event can be enqueued.
  int TheDeviceIndex;

  std::unique_ptr<void, HandleDestructor> TheHandle;
};

/// An accelerator platform.
///
/// This is the base class for all platforms such as CUDA and OpenCL. It
/// contains many virtual methods that must be overridden by each platform
/// implementation.
///
/// It also has some template wrapper functions that take care of type checking
/// and then forward their arguments on to raw virtual functions that are
/// implemented by each specific platform.
class Platform {
public:
  virtual ~Platform() {}

  /// Gets the number of devices for this platform in this system.
  virtual Expected<int> getDeviceCount() = 0;

  /// Creates a stream on the given device for the platform.
  virtual Expected<Stream> createStream(int DeviceIndex = 0) = 0;

  /// Creates an event on the given device for the platform.
  virtual Expected<Event> createEvent(int DeviceIndex = 0) = 0;

  /// Allocates owned device memory.
  ///
  /// \warning This function only allocates space in device memory; it does not
  /// call the constructor of T.
  template <typename T>
  Expected<DeviceMemory<T>> mallocD(ptrdiff_t ElementCount,
                                    int DeviceIndex = 0) {
    Expected<void *> MaybePointer =
        rawMallocD(ElementCount * sizeof(T), DeviceIndex);
    if (MaybePointer.isError())
      return MaybePointer.getError();
    return DeviceMemory<T>(this, MaybePointer.getValue(), ElementCount,
                           this->getDeviceMemoryHandleDestructor());
  }

  /// Creates a DeviceMemorySpan for a device symbol.
  ///
  /// This function is present to support __device__ variables in CUDA. Given a
  /// pointer to a __device__ variable, this function returns a DeviceMemorySpan
  /// referencing the device memory that stores that __device__ variable.
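  ///
  /// A sketch (assuming a CUDA __device__ variable named DeviceCounter):
  ///
  /// \code{.cpp}
  /// __device__ int DeviceCounter;
  ///
  /// acxxel::DeviceMemorySpan<int> Span =
  ///     Platform->getSymbolMemory(&DeviceCounter).takeValue();
  /// \endcode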
  template <typename ElementType>
  Expected<DeviceMemorySpan<ElementType>> getSymbolMemory(ElementType *Symbol,
                                                          int DeviceIndex = 0) {
    Expected<void *> MaybeAddress =
        rawGetDeviceSymbolAddress(Symbol, DeviceIndex);
    if (MaybeAddress.isError())
      return MaybeAddress.getError();
    ElementType *Address = static_cast<ElementType *>(MaybeAddress.getValue());
    Expected<ptrdiff_t> MaybeSize = rawGetDeviceSymbolSize(Symbol, DeviceIndex);
    if (MaybeSize.isError())
      return MaybeSize.getError();
    ptrdiff_t Size = MaybeSize.getValue();
    return DeviceMemorySpan<ElementType>(this, Address,
                                         Size / sizeof(ElementType), 0);
  }

  /// \name Host memory registration functions.
  /// \{

  template <typename T>
  Expected<AsyncHostMemory<const T>> registerHostMem(Span<const T> Memory) {
    Status S = rawRegisterHostMem(Memory.data(), Memory.size() * sizeof(T));
    if (S.isError())
      return S;
    return AsyncHostMemory<const T>(
        Memory.data(), Memory.size(),
        this->getUnregisterHostMemoryHandleDestructor());
  }

  template <typename T>
  Expected<AsyncHostMemory<T>> registerHostMem(Span<T> Memory) {
    Status S = rawRegisterHostMem(Memory.data(), Memory.size() * sizeof(T));
    if (S.isError())
      return S;
    return AsyncHostMemory<T>(Memory.data(), Memory.size(),
                              this->getUnregisterHostMemoryHandleDestructor());
  }

  template <typename T, size_t N>
  Expected<AsyncHostMemory<T>> registerHostMem(T (&Array)[N]) {
    Span<T> ArraySpan(Array);
    Status S =
        rawRegisterHostMem(ArraySpan.data(), ArraySpan.size() * sizeof(T));
    if (S.isError())
      return S;
    return AsyncHostMemory<T>(ArraySpan.data(), ArraySpan.size(),
                              this->getUnregisterHostMemoryHandleDestructor());
  }

  /// Registers memory stored in a container that has a data() member function
  /// and can be converted to a Span<T>.
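  ///
  /// A sketch (assuming a valid Platform pointer):
  ///
  /// \code{.cpp}
  /// std::vector<float> Data(100);
  /// acxxel::AsyncHostMemory<float> Registered =
  ///     Platform->registerHostMem(Data).takeValue();
  /// // Registered can now be used with the asynchronous copy functions.
  /// \endcode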
  template <typename Container>
  auto registerHostMem(Container &Cont) -> Expected<AsyncHostMemory<
      typename std::remove_reference<decltype(*Cont.data())>::type>> {
    using ValueType =
        typename std::remove_reference<decltype(*Cont.data())>::type;
    Span<ValueType> ContSpan(Cont);
    Status S = rawRegisterHostMem(ContSpan.data(),
                                  ContSpan.size() * sizeof(ValueType));
    if (S.isError())
      return S;
    return AsyncHostMemory<ValueType>(
        ContSpan.data(), ContSpan.size(),
        this->getUnregisterHostMemoryHandleDestructor());
  }

  /// Allocates an owned, registered array of objects on the host.
  ///
  /// Default constructs each element in the resulting array.
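  ///
  /// A minimal sketch:
  ///
  /// \code{.cpp}
  /// acxxel::OwnedAsyncHostMemory<float> Host =
  ///     Platform->newAsyncHostMem<float>(100).takeValue();
  /// Host[0] = 42.0f; // Elements are host-accessible and default-constructed.
  /// \endcode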
  template <typename T>
  Expected<OwnedAsyncHostMemory<T>> newAsyncHostMem(ptrdiff_t ElementCount) {
    Expected<void *> MaybeMemory =
        rawMallocRegisteredH(ElementCount * sizeof(T));
    if (MaybeMemory.isError())
      return MaybeMemory.getError();
    T *Memory = static_cast<T *>(MaybeMemory.getValue());
    for (ptrdiff_t I = 0; I < ElementCount; ++I)
      new (Memory + I) T;
    return OwnedAsyncHostMemory<T>(Memory, ElementCount,
                                   this->getFreeHostMemoryHandleDestructor());
  }

  /// \}

  virtual Expected<Program> createProgramFromSource(Span<const char> Source,
                                                    int DeviceIndex = 0) = 0;

protected:
  friend class Stream;
  friend class Event;
  friend class Program;
  template <typename T> friend class DeviceMemorySpan;

  void *getStreamHandle(Stream &Stream) { return Stream.TheHandle.get(); }
  void *getEventHandle(Event &Event) { return Event.TheHandle.get(); }

  // Pass along access to the Stream constructor to subclasses.
  Stream constructStream(Platform *APlatform, int DeviceIndex, void *AHandle,
                         HandleDestructor Destructor) {
    return Stream(APlatform, DeviceIndex, AHandle, Destructor);
  }

  // Pass along access to the Event constructor to subclasses.
  Event constructEvent(Platform *APlatform, int DeviceIndex, void *AHandle,
                       HandleDestructor Destructor) {
    return Event(APlatform, DeviceIndex, AHandle, Destructor);
  }

  // Pass along access to the Program constructor to subclasses.
  Program constructProgram(Platform *APlatform, void *AHandle,
                           HandleDestructor Destructor) {
    return Program(APlatform, AHandle, Destructor);
  }

  virtual Status streamSync(void *Stream) = 0;
  virtual Status streamWaitOnEvent(void *Stream, void *Event) = 0;

  virtual Status enqueueEvent(void *Event, void *Stream) = 0;
  virtual bool eventIsDone(void *Event) = 0;
  virtual Status eventSync(void *Event) = 0;
  virtual Expected<float> getSecondsBetweenEvents(void *StartEvent,
                                                  void *EndEvent) = 0;

  virtual Expected<void *> rawMallocD(ptrdiff_t ByteCount, int DeviceIndex) = 0;
  virtual HandleDestructor getDeviceMemoryHandleDestructor() = 0;
  virtual void *getDeviceMemorySpanHandle(void *BaseHandle, size_t ByteSize,
                                          size_t ByteOffset) = 0;
  virtual void rawDestroyDeviceMemorySpanHandle(void *Handle) = 0;

  virtual Expected<void *> rawGetDeviceSymbolAddress(const void *Symbol,
                                                     int DeviceIndex) = 0;
  virtual Expected<ptrdiff_t> rawGetDeviceSymbolSize(const void *Symbol,
                                                     int DeviceIndex) = 0;

  virtual Status rawRegisterHostMem(const void *Memory,
                                    ptrdiff_t ByteCount) = 0;
  virtual HandleDestructor getUnregisterHostMemoryHandleDestructor() = 0;

  virtual Expected<void *> rawMallocRegisteredH(ptrdiff_t ByteCount) = 0;
  virtual HandleDestructor getFreeHostMemoryHandleDestructor() = 0;

  virtual Status asyncCopyDToD(const void *DeviceSrc,
                               ptrdiff_t DeviceSrcByteOffset, void *DeviceDst,
                               ptrdiff_t DeviceDstByteOffset,
                               ptrdiff_t ByteCount, void *Stream) = 0;
  virtual Status asyncCopyDToH(const void *DeviceSrc,
                               ptrdiff_t DeviceSrcByteOffset, void *HostDst,
                               ptrdiff_t ByteCount, void *Stream) = 0;
  virtual Status asyncCopyHToD(const void *HostSrc, void *DeviceDst,
                               ptrdiff_t DeviceDstByteOffset,
                               ptrdiff_t ByteCount, void *Stream) = 0;

  virtual Status asyncMemsetD(void *DeviceDst, ptrdiff_t ByteOffset,
                              ptrdiff_t ByteCount, char ByteValue,
                              void *Stream) = 0;

  virtual Status addStreamCallback(Stream &Stream, StreamCallback Callback) = 0;

  virtual Expected<void *> rawCreateKernel(void *Program,
                                           const std::string &Name) = 0;
  virtual HandleDestructor getKernelHandleDestructor() = 0;

  virtual Status rawEnqueueKernelLaunch(void *Stream, void *Kernel,
                                        KernelLaunchDimensions LaunchDimensions,
                                        Span<void *> Arguments,
                                        Span<size_t> ArgumentSizes,
                                        size_t SharedMemoryBytes) = 0;
};

// Implementation of templated Stream functions.

template <typename DeviceSrcTy, typename DeviceDstTy>
Stream &Stream::asyncCopyDToD(DeviceSrcTy &&DeviceSrc,
                              DeviceDstTy &&DeviceDst) {
  using SrcElementTy =
      typename std::remove_reference<DeviceSrcTy>::type::value_type;
  using DstElementTy =
      typename std::remove_reference<DeviceDstTy>::type::value_type;
  static_assert(std::is_same<SrcElementTy, DstElementTy>::value,
                "asyncCopyDToD cannot copy between arrays of different types");
  DeviceMemorySpan<const SrcElementTy> DeviceSrcSpan(DeviceSrc);
  DeviceMemorySpan<DstElementTy> DeviceDstSpan(DeviceDst);
  if (DeviceSrcSpan.size() != DeviceDstSpan.size()) {
    setStatus(Status("asyncCopyDToD source element count " +
                     std::to_string(DeviceSrcSpan.size()) +
                     " does not equal destination element count " +
                     std::to_string(DeviceDstSpan.size())));
    return *this;
  }
  setStatus(ThePlatform->asyncCopyDToD(
      DeviceSrcSpan.baseHandle(), DeviceSrcSpan.byte_offset(),
      DeviceDstSpan.baseHandle(), DeviceDstSpan.byte_offset(),
      DeviceSrcSpan.byte_size(), TheHandle.get()));
  return *this;
}

template <typename DeviceSrcTy, typename DeviceDstTy>
Stream &Stream::asyncCopyDToD(DeviceSrcTy &&DeviceSrc, DeviceDstTy &&DeviceDst,
                              ptrdiff_t ElementCount) {
  using SrcElementTy =
      typename std::remove_reference<DeviceSrcTy>::type::value_type;
  using DstElementTy =
      typename std::remove_reference<DeviceDstTy>::type::value_type;
  static_assert(std::is_same<SrcElementTy, DstElementTy>::value,
                "asyncCopyDToD cannot copy between arrays of different types");
  DeviceMemorySpan<const SrcElementTy> DeviceSrcSpan(DeviceSrc);
  DeviceMemorySpan<DstElementTy> DeviceDstSpan(DeviceDst);
  if (DeviceSrcSpan.size() < ElementCount) {
    setStatus(Status("asyncCopyDToD source element count " +
                     std::to_string(DeviceSrcSpan.size()) +
                     " is less than requested element count " +
                     std::to_string(ElementCount)));
    return *this;
  }
  if (DeviceDstSpan.size() < ElementCount) {
    setStatus(Status("asyncCopyDToD destination element count " +
                     std::to_string(DeviceDstSpan.size()) +
                     " is less than requested element count " +
                     std::to_string(ElementCount)));
    return *this;
  }
  setStatus(ThePlatform->asyncCopyDToD(
      DeviceSrcSpan.baseHandle(), DeviceSrcSpan.byte_offset(),
      DeviceDstSpan.baseHandle(), DeviceDstSpan.byte_offset(),
      ElementCount * sizeof(SrcElementTy), TheHandle.get()));
  return *this;
}

template <typename DeviceSrcTy, typename HostDstTy>
Stream &Stream::asyncCopyDToH(DeviceSrcTy &&DeviceSrc, HostDstTy &&HostDst) {
  using SrcElementTy =
      typename std::remove_reference<DeviceSrcTy>::type::value_type;
  DeviceMemorySpan<const SrcElementTy> DeviceSrcSpan(DeviceSrc);
  AsyncHostMemorySpan<SrcElementTy> HostDstSpan(HostDst);
  if (DeviceSrcSpan.size() != HostDstSpan.size()) {
    setStatus(Status("asyncCopyDToH source element count " +
                     std::to_string(DeviceSrcSpan.size()) +
                     " does not equal destination element count " +
                     std::to_string(HostDstSpan.size())));
    return *this;
  }
  setStatus(ThePlatform->asyncCopyDToH(
      DeviceSrcSpan.baseHandle(), DeviceSrcSpan.byte_offset(),
      HostDstSpan.data(), DeviceSrcSpan.byte_size(), TheHandle.get()));
  return *this;
}

template <typename DeviceSrcTy, typename HostDstTy>
Stream &Stream::asyncCopyDToH(DeviceSrcTy &&DeviceSrc, HostDstTy &&HostDst,
                              ptrdiff_t ElementCount) {
  using SrcElementTy =
      typename std::remove_reference<DeviceSrcTy>::type::value_type;
  DeviceMemorySpan<const SrcElementTy> DeviceSrcSpan(DeviceSrc);
  AsyncHostMemorySpan<SrcElementTy> HostDstSpan(HostDst);
  if (DeviceSrcSpan.size() < ElementCount) {
    setStatus(Status("asyncCopyDToH source element count " +
                     std::to_string(DeviceSrcSpan.size()) +
                     " is less than requested element count " +
                     std::to_string(ElementCount)));
    return *this;
  }
  if (HostDstSpan.size() < ElementCount) {
    setStatus(Status("asyncCopyDToH destination element count " +
                     std::to_string(HostDstSpan.size()) +
                     " is less than requested element count " +
                     std::to_string(ElementCount)));
    return *this;
  }
  setStatus(ThePlatform->asyncCopyDToH(
      DeviceSrcSpan.baseHandle(), DeviceSrcSpan.byte_offset(),
      HostDstSpan.data(), ElementCount * sizeof(SrcElementTy),
      TheHandle.get()));
  return *this;
}

template <typename HostSrcTy, typename DeviceDstTy>
Stream &Stream::asyncCopyHToD(HostSrcTy &&HostSrc, DeviceDstTy &&DeviceDst) {
  using DstElementTy =
      typename std::remove_reference<DeviceDstTy>::type::value_type;
  AsyncHostMemorySpan<const DstElementTy> HostSrcSpan(HostSrc);
  DeviceMemorySpan<DstElementTy> DeviceDstSpan(DeviceDst);
  if (HostSrcSpan.size() != DeviceDstSpan.size()) {
    setStatus(Status("asyncCopyHToD source element count " +
                     std::to_string(HostSrcSpan.size()) +
                     " does not equal destination element count " +
                     std::to_string(DeviceDstSpan.size())));
    return *this;
  }
  setStatus(ThePlatform->asyncCopyHToD(
      HostSrcSpan.data(), DeviceDstSpan.baseHandle(),
      DeviceDstSpan.byte_offset(), HostSrcSpan.byte_size(), TheHandle.get()));
  return *this;
}

template <typename HostSrcTy, typename DeviceDstTy>
Stream &Stream::asyncCopyHToD(HostSrcTy &&HostSrc, DeviceDstTy &&DeviceDst,
                              ptrdiff_t ElementCount) {
  using DstElementTy =
      typename std::remove_reference<DeviceDstTy>::type::value_type;
  AsyncHostMemorySpan<const DstElementTy> HostSrcSpan(HostSrc);
  DeviceMemorySpan<DstElementTy> DeviceDstSpan(DeviceDst);
  if (HostSrcSpan.size() < ElementCount) {
    setStatus(Status("asyncCopyHToD source element count " +
                     std::to_string(HostSrcSpan.size()) +
                     " is less than requested element count " +
                     std::to_string(ElementCount)));
    return *this;
  }
  if (DeviceDstSpan.size() < ElementCount) {
    setStatus(Status("asyncCopyHToD destination element count " +
                     std::to_string(DeviceDstSpan.size()) +
                     " is less than requested element count " +
                     std::to_string(ElementCount)));
    return *this;
  }
  setStatus(ThePlatform->asyncCopyHToD(
      HostSrcSpan.data(), DeviceDstSpan.baseHandle(),
      DeviceDstSpan.byte_offset(), ElementCount * sizeof(DstElementTy),
      TheHandle.get()));
  return *this;
}

template <typename DeviceDstTy>
Stream &Stream::asyncMemsetD(DeviceDstTy &&DeviceDst, char ByteValue) {
  using DstElementTy =
      typename std::remove_reference<DeviceDstTy>::type::value_type;
  DeviceMemorySpan<DstElementTy> DeviceDstSpan(DeviceDst);
  setStatus(ThePlatform->asyncMemsetD(
      DeviceDstSpan.baseHandle(), DeviceDstSpan.byte_offset(),
      DeviceDstSpan.byte_size(), ByteValue, TheHandle.get()));
  return *this;
}

template <typename DeviceSrcTy, typename DeviceDstTy>
Stream &Stream::syncCopyDToD(DeviceSrcTy &&DeviceSrc, DeviceDstTy &&DeviceDst) {
  using SrcElementTy =
      typename std::remove_reference<DeviceSrcTy>::type::value_type;
  using DstElementTy =
      typename std::remove_reference<DeviceDstTy>::type::value_type;
  static_assert(std::is_same<SrcElementTy, DstElementTy>::value,
                "copyDToD cannot copy between arrays of different types");
  DeviceMemorySpan<const SrcElementTy> DeviceSrcSpan(DeviceSrc);
  DeviceMemorySpan<DstElementTy> DeviceDstSpan(DeviceDst);
  if (DeviceSrcSpan.size() != DeviceDstSpan.size()) {
    setStatus(Status("copyDToD source element count " +
                     std::to_string(DeviceSrcSpan.size()) +
                     " does not equal destination element count " +
                     std::to_string(DeviceDstSpan.size())));
    return *this;
  }
  if (setStatus(ThePlatform->asyncCopyDToD(
                    DeviceSrcSpan.baseHandle(), DeviceSrcSpan.byte_offset(),
                    DeviceDstSpan.baseHandle(), DeviceDstSpan.byte_offset(),
                    DeviceSrcSpan.byte_size(), TheHandle.get()))
          .isError()) {
    return *this;
  }
  setStatus(sync());
  return *this;
}

template <typename DeviceSrcTy, typename DeviceDstTy>
Stream &Stream::syncCopyDToD(DeviceSrcTy &&DeviceSrc, DeviceDstTy &&DeviceDst,
                             ptrdiff_t ElementCount) {
  using SrcElementTy =
      typename std::remove_reference<DeviceSrcTy>::type::value_type;
  using DstElementTy =
      typename std::remove_reference<DeviceDstTy>::type::value_type;
  static_assert(std::is_same<SrcElementTy, DstElementTy>::value,
                "copyDToD cannot copy between arrays of different types");
  DeviceMemorySpan<const SrcElementTy> DeviceSrcSpan(DeviceSrc);
  DeviceMemorySpan<DstElementTy> DeviceDstSpan(DeviceDst);
  if (DeviceSrcSpan.size() < ElementCount) {
    setStatus(Status("copyDToD source element count " +
                     std::to_string(DeviceSrcSpan.size()) +
                     " is less than requested element count " +
                     std::to_string(ElementCount)));
    return *this;
  }
  if (DeviceDstSpan.size() < ElementCount) {
    setStatus(Status("copyDToD destination element count " +
                     std::to_string(DeviceDstSpan.size()) +
                     " is less than requested element count " +
                     std::to_string(ElementCount)));
    return *this;
  }
  if (setStatus(ThePlatform->asyncCopyDToD(
                    DeviceSrcSpan.baseHandle(), DeviceSrcSpan.byte_offset(),
                    DeviceDstSpan.baseHandle(), DeviceDstSpan.byte_offset(),
                    ElementCount * sizeof(SrcElementTy), TheHandle.get()))
          .isError()) {
    return *this;
  }
  setStatus(sync());
  return *this;
}

template <typename DeviceSrcTy, typename HostDstTy>
Stream &Stream::syncCopyDToH(DeviceSrcTy &&DeviceSrc, HostDstTy &&HostDst) {
  using SrcElementTy =
      typename std::remove_reference<DeviceSrcTy>::type::value_type;
  DeviceMemorySpan<const SrcElementTy> DeviceSrcSpan(DeviceSrc);
  Span<SrcElementTy> HostDstSpan(HostDst);
  if (DeviceSrcSpan.size() != HostDstSpan.size()) {
    setStatus(Status("copyDToH source element count " +
                     std::to_string(DeviceSrcSpan.size()) +
                     " does not equal destination element count " +
                     std::to_string(HostDstSpan.size())));
    return *this;
  }
  if (setStatus(ThePlatform->asyncCopyDToH(
                    DeviceSrcSpan.baseHandle(), DeviceSrcSpan.byte_offset(),
                    HostDstSpan.data(), DeviceSrcSpan.byte_size(),
                    TheHandle.get()))
          .isError()) {
    return *this;
  }
  setStatus(sync());
  return *this;
}

template <typename DeviceSrcTy, typename HostDstTy>
Stream &Stream::syncCopyDToH(DeviceSrcTy &&DeviceSrc, HostDstTy &&HostDst,
                             ptrdiff_t ElementCount) {
  using SrcElementTy =
      typename std::remove_reference<DeviceSrcTy>::type::value_type;
  DeviceMemorySpan<const SrcElementTy> DeviceSrcSpan(DeviceSrc);
  Span<SrcElementTy> HostDstSpan(HostDst);
  if (DeviceSrcSpan.size() < ElementCount) {
    setStatus(Status("copyDToH source element count " +
                     std::to_string(DeviceSrcSpan.size()) +
                     " is less than requested element count " +
                     std::to_string(ElementCount)));
    return *this;
  }
  if (HostDstSpan.size() < ElementCount) {
    setStatus(Status("copyDToH destination element count " +
                     std::to_string(HostDstSpan.size()) +
                     " is less than requested element count " +
                     std::to_string(ElementCount)));
    return *this;
  }
  if (setStatus(ThePlatform->asyncCopyDToH(
                    DeviceSrcSpan.baseHandle(), DeviceSrcSpan.byte_offset(),
                    HostDstSpan.data(), ElementCount * sizeof(SrcElementTy),
                    TheHandle.get()))
          .isError()) {
    return *this;
  }
  setStatus(sync());
  return *this;
}

template <typename HostSrcTy, typename DeviceDstTy>
Stream &Stream::syncCopyHToD(HostSrcTy &&HostSrc, DeviceDstTy &&DeviceDst) {
  using DstElementTy =
      typename std::remove_reference<DeviceDstTy>::type::value_type;
  Span<const DstElementTy> HostSrcSpan(HostSrc);
  DeviceMemorySpan<DstElementTy> DeviceDstSpan(DeviceDst);
  if (HostSrcSpan.size() != DeviceDstSpan.size()) {
    setStatus(Status("copyHToD source element count " +
                     std::to_string(HostSrcSpan.size()) +
                     " does not equal destination element count " +
                     std::to_string(DeviceDstSpan.size())));
    return *this;
  }
  if (setStatus(ThePlatform->asyncCopyHToD(
                    HostSrcSpan.data(), DeviceDstSpan.baseHandle(),
                    DeviceDstSpan.byte_offset(), DeviceDstSpan.byte_size(),
                    TheHandle.get()))
          .isError()) {
    return *this;
  }
  setStatus(sync());
  return *this;
}

template <typename HostSrcTy, typename DeviceDstTy>
Stream &Stream::syncCopyHToD(HostSrcTy &&HostSrc, DeviceDstTy &&DeviceDst,
                             ptrdiff_t ElementCount) {
  using DstElementTy =
      typename std::remove_reference<DeviceDstTy>::type::value_type;
  Span<const DstElementTy> HostSrcSpan(HostSrc);
  DeviceMemorySpan<DstElementTy> DeviceDstSpan(DeviceDst);
  if (HostSrcSpan.size() < ElementCount) {
    setStatus(Status("copyHToD source element count " +
                     std::to_string(HostSrcSpan.size()) +
                     " is less than requested element count " +
                     std::to_string(ElementCount)));
    return *this;
  }
  if (DeviceDstSpan.size() < ElementCount) {
    setStatus(Status("copyHToD destination element count " +
                     std::to_string(DeviceDstSpan.size()) +
                     " is less than requested element count " +
                     std::to_string(ElementCount)));
    return *this;
  }
  if (setStatus(ThePlatform->asyncCopyHToD(
                    HostSrcSpan.data(), DeviceDstSpan.baseHandle(),
                    DeviceDstSpan.byte_offset(),
                    ElementCount * sizeof(DstElementTy), TheHandle.get()))
          .isError()) {
    return *this;
  }
  setStatus(sync());
  return *this;
}

/// Owned device memory.
///
/// Device memory that frees itself when it goes out of scope.
template <typename ElementType> class DeviceMemory {
public:
  using element_type = ElementType;
  using index_type = std::ptrdiff_t;
  using value_type = typename std::remove_const<element_type>::type;

  DeviceMemory(const DeviceMemory &) = delete;
  DeviceMemory &operator=(const DeviceMemory &) = delete;
  DeviceMemory(DeviceMemory &&) noexcept;
  DeviceMemory &operator=(DeviceMemory &&) noexcept;
  ~DeviceMemory() = default;

  /// Gets the raw base handle for the underlying platform implementation.
  void *handle() const { return ThePointer.get(); }

  index_type length() const { return TheSize; }
  index_type size() const { return TheSize; }
  index_type byte_size() const { // NOLINT
    return TheSize * sizeof(element_type);
  }
  bool empty() const { return TheSize == 0; }

  // These conversion operators are useful for making triple-chevron kernel
  // launches more concise.
  operator element_type *() {
    return static_cast<element_type *>(ThePointer.get());
  }
  operator const element_type *() const {
    return static_cast<const element_type *>(ThePointer.get());
  }

  /// Converts a const object to a DeviceMemorySpan of const elements.
  DeviceMemorySpan<const element_type> asSpan() const {
    return DeviceMemorySpan<const element_type>(
        ThePlatform, static_cast<const element_type *>(ThePointer.get()),
        TheSize, 0);
  }

  /// Converts an object to a DeviceMemorySpan.
  DeviceMemorySpan<element_type> asSpan() {
    return DeviceMemorySpan<element_type>(
        ThePlatform, static_cast<element_type *>(ThePointer.get()), TheSize, 0);
  }

private:
  friend class Platform;
  template <typename T> friend class DeviceMemorySpan;

  DeviceMemory(Platform *ThePlatform, void *Pointer, index_type ElementCount,
               HandleDestructor Destructor)
      : ThePlatform(ThePlatform), ThePointer(Pointer, Destructor),
        TheSize(ElementCount) {}

  Platform *ThePlatform;
  std::unique_ptr<void, HandleDestructor> ThePointer;
  ptrdiff_t TheSize;
};

template <typename T>
DeviceMemory<T>::DeviceMemory(DeviceMemory &&) noexcept = default;
template <typename T>
DeviceMemory<T> &DeviceMemory<T>::operator=(DeviceMemory &&) noexcept = default;

/// View into device memory.
///
/// Like a Span, but for device memory rather than host memory.
template <typename ElementType> class DeviceMemorySpan {
public:
  /// \name constants and types
  /// \{
  using element_type = ElementType;
  using index_type = std::ptrdiff_t;
  using pointer = element_type *;
  using reference = element_type &;
  using iterator = element_type *;
  using const_iterator = const element_type *;
  using value_type = typename std::remove_const<element_type>::type;
  /// \}

  DeviceMemorySpan()
      : ThePlatform(nullptr), TheHandle(nullptr), TheSize(0), TheOffset(0),
        TheSpanHandle(nullptr) {}

  // Intentionally implicit.
  template <typename OtherElementType>
  DeviceMemorySpan(DeviceMemorySpan<OtherElementType> &ASpan)
      : ThePlatform(ASpan.ThePlatform),
        TheHandle(static_cast<pointer>(ASpan.baseHandle())),
        TheSize(ASpan.size()), TheOffset(ASpan.offset()),
        TheSpanHandle(nullptr) {}

  // Intentionally implicit.
  template <typename OtherElementType>
  DeviceMemorySpan(DeviceMemorySpan<OtherElementType> &&ASpan)
      : ThePlatform(ASpan.ThePlatform),
        TheHandle(static_cast<pointer>(ASpan.baseHandle())),
        TheSize(ASpan.size()), TheOffset(ASpan.offset()),
        TheSpanHandle(nullptr) {}

  // Intentionally implicit.
  template <typename OtherElementType>
  DeviceMemorySpan(DeviceMemory<OtherElementType> &Memory)
      : ThePlatform(Memory.ThePlatform),
        TheHandle(static_cast<value_type *>(Memory.handle())),
        TheSize(Memory.size()), TheOffset(0), TheSpanHandle(nullptr) {}

  ~DeviceMemorySpan() {
    if (TheSpanHandle) {
      ThePlatform->rawDestroyDeviceMemorySpanHandle(
          const_cast<value_type *>(TheSpanHandle));
    }
  }

  /// \name observers
  /// \{
  index_type length() const { return TheSize; }
  index_type size() const { return TheSize; }
  index_type byte_size() const { // NOLINT
    return TheSize * sizeof(element_type);
  }
  index_type offset() const { return TheOffset; }
  index_type byte_offset() const { // NOLINT
    return TheOffset * sizeof(element_type);
  }
  bool empty() const { return TheSize == 0; }
  /// \}

  void *baseHandle() const {
    return static_cast<void *>(const_cast<value_type *>(TheHandle));
  }

  /// Casts to a host memory pointer.
  ///
  /// This is only guaranteed to make sense for the CUDA platform, where device
  /// pointers can be stored and manipulated much like host pointers. This makes
  /// it easy to do triple-chevron kernel launches in CUDA because
  /// DeviceMemorySpan values can be passed to parameters expecting regular
  /// pointers.
  ///
  /// If the CUDA platform is using unified memory, it may also be possible to
  /// dereference this pointer on the host.
  ///
  /// For platforms other than CUDA, this may return a garbage pointer.
  operator element_type *() const {
    if (!TheSpanHandle)
      TheSpanHandle =
          static_cast<pointer>(ThePlatform->getDeviceMemorySpanHandle(
              baseHandle(), TheSize * sizeof(element_type),
              TheOffset * sizeof(element_type)));
    return TheSpanHandle;
  }

  DeviceMemorySpan<element_type> first(index_type Count) const {
    bool Valid = Count >= 0 && Count <= TheSize;
    if (!Valid)
      std::terminate();
    return DeviceMemorySpan<element_type>(ThePlatform, TheHandle, Count,
                                          TheOffset);
  }

  DeviceMemorySpan<element_type> last(index_type Count) const {
    bool Valid = Count >= 0 && Count <= TheSize;
    if (!Valid)
      std::terminate();
    return DeviceMemorySpan<element_type>(ThePlatform, TheHandle, Count,
                                          TheOffset + TheSize - Count);
  }

  DeviceMemorySpan<element_type>
  subspan(index_type Offset, index_type Count = dynamic_extent) const {
    bool Valid =
        (Offset == 0 || (Offset > 0 && Offset <= TheSize)) &&
        (Count == dynamic_extent || (Count >= 0 && Offset + Count <= TheSize));
    if (!Valid)
      std::terminate();
    return DeviceMemorySpan<element_type>(ThePlatform, TheHandle, Count,
                                          TheOffset + Offset);
  }

private:
  template <typename T> friend class DeviceMemory;
  template <typename T> friend class DeviceMemorySpan;
  friend class Platform;

  DeviceMemorySpan(Platform *ThePlatform, pointer AHandle, index_type Size,
                   index_type Offset)
      : ThePlatform(ThePlatform), TheHandle(AHandle), TheSize(Size),
        TheOffset(Offset), TheSpanHandle(nullptr) {}

  Platform *ThePlatform;
  pointer TheHandle;
  index_type TheSize;
  index_type TheOffset;

  // Created lazily by the conversion operator above, which is const, so this
  // member is mutable.
  mutable pointer TheSpanHandle;
};

/// Asynchronous host memory.
///
/// This memory is pinned or otherwise registered in the host memory space to
/// allow for asynchronous copies between it and device memory.
///
/// This memory unpins/unregisters itself when it goes out of scope, but does
/// not free itself.
template <typename ElementType> class AsyncHostMemory {
public:
  using value_type = ElementType;
  using remove_const_type = typename std::remove_const<ElementType>::type;

  AsyncHostMemory(const AsyncHostMemory &) = delete;
  AsyncHostMemory &operator=(const AsyncHostMemory &) = delete;
  AsyncHostMemory(AsyncHostMemory &&) noexcept;
  AsyncHostMemory &operator=(AsyncHostMemory &&) noexcept;
  ~AsyncHostMemory() = default;

  template <typename OtherElementType>
  AsyncHostMemory(AsyncHostMemory<OtherElementType> &&Other)
      : ThePointer(std::move(Other.ThePointer)),
        TheElementCount(Other.TheElementCount) {
    static_assert(
        std::is_assignable<ElementType *&, OtherElementType *>::value,
        "cannot assign OtherElementType pointer to ElementType pointer type");
  }

  ElementType *data() const {
    return const_cast<ElementType *>(
        static_cast<remove_const_type *>(ThePointer.get()));
  }
  ptrdiff_t size() const { return TheElementCount; }

private:
  template <typename U> friend class AsyncHostMemory;
  friend class Platform;
  AsyncHostMemory(ElementType *Pointer, ptrdiff_t ElementCount,
                  HandleDestructor Destructor)
      : ThePointer(
            static_cast<void *>(const_cast<remove_const_type *>(Pointer)),
            Destructor),
        TheElementCount(ElementCount) {}

  std::unique_ptr<void, HandleDestructor> ThePointer;
  ptrdiff_t TheElementCount;
};

template <typename T>
AsyncHostMemory<T>::AsyncHostMemory(AsyncHostMemory &&) noexcept = default;
template <typename T>
AsyncHostMemory<T> &
AsyncHostMemory<T>::operator=(AsyncHostMemory &&) noexcept = default;

/// Owned registered host memory.
///
/// Like AsyncHostMemory, but this memory also frees itself in addition to
/// unpinning/unregistering itself when it goes out of scope.
template <typename ElementType> class OwnedAsyncHostMemory {
public:
  using remove_const_type = typename std::remove_const<ElementType>::type;

  OwnedAsyncHostMemory(const OwnedAsyncHostMemory &) = delete;
  OwnedAsyncHostMemory &operator=(const OwnedAsyncHostMemory &) = delete;
  OwnedAsyncHostMemory(OwnedAsyncHostMemory &&) noexcept;
  OwnedAsyncHostMemory &operator=(OwnedAsyncHostMemory &&) noexcept;

  ~OwnedAsyncHostMemory() {
    if (ThePointer.get()) {
      // We use placement new to construct these objects, so we have to call
      // the destructors explicitly.
      for (ptrdiff_t I = 0; I < TheElementCount; ++I)
        static_cast<ElementType *>(ThePointer.get())[I].~ElementType();
    }
  }

  ElementType *get() const {
    return const_cast<ElementType *>(
        static_cast<remove_const_type *>(ThePointer.get()));
  }

  ElementType &operator[](ptrdiff_t I) const {
    assert(I >= 0 && I < TheElementCount);
    return get()[I];
  }

private:
  template <typename T> friend class AsyncHostMemorySpan;

  friend class Platform;

  OwnedAsyncHostMemory(void *Memory, ptrdiff_t ElementCount,
                       HandleDestructor Destructor)
      : ThePointer(Memory, Destructor), TheElementCount(ElementCount) {}

  std::unique_ptr<void, HandleDestructor> ThePointer;
  ptrdiff_t TheElementCount;
};

template <typename T>
OwnedAsyncHostMemory<T>::OwnedAsyncHostMemory(
    OwnedAsyncHostMemory &&) noexcept = default;
template <typename T>
OwnedAsyncHostMemory<T> &
OwnedAsyncHostMemory<T>::operator=(OwnedAsyncHostMemory &&) noexcept = default;
/// View into registered host memory.
///
/// Like Span, but for registered host memory.
template <typename ElementType> class AsyncHostMemorySpan {
public:
  /// \name constants and types
  /// \{
  using element_type = ElementType;
  using index_type = std::ptrdiff_t;
  using pointer = element_type *;
  using reference = element_type &;
  using iterator = element_type *;
  using const_iterator = const element_type *;
  using value_type = typename std::remove_const<element_type>::type;
  /// \}

  AsyncHostMemorySpan() : TheSpan() {}

  // Intentionally implicit.
  template <typename OtherElementType>
  AsyncHostMemorySpan(AsyncHostMemory<OtherElementType> &Memory)
      : TheSpan(Memory.data(), Memory.size()) {}

  // Intentionally implicit.
  template <typename OtherElementType>
  AsyncHostMemorySpan(OwnedAsyncHostMemory<OtherElementType> &Owned)
      : TheSpan(Owned.get(), Owned.TheElementCount) {}

  // Intentionally implicit.
  template <typename OtherElementType>
  AsyncHostMemorySpan(AsyncHostMemorySpan<OtherElementType> &ASpan)
      : TheSpan(ASpan) {}

  // Intentionally implicit.
  template <typename OtherElementType>
  AsyncHostMemorySpan(AsyncHostMemorySpan<OtherElementType> &&ASpan)
      : TheSpan(ASpan) {}

  /// \name observers
  /// \{
  index_type length() const { return TheSpan.length(); }
  index_type size() const { return TheSpan.size(); }
  index_type byte_size() const { // NOLINT
    return TheSpan.size() * sizeof(element_type);
  }
  bool empty() const { return TheSpan.empty(); }
  /// \}

  pointer data() const noexcept { return TheSpan.data(); }
  operator element_type *() const { return TheSpan.data(); }

  AsyncHostMemorySpan<element_type> first(index_type Count) const {
    return AsyncHostMemorySpan<element_type>(TheSpan.first(Count));
  }

  AsyncHostMemorySpan<element_type> last(index_type Count) const {
    return AsyncHostMemorySpan<element_type>(TheSpan.last(Count));
  }

  AsyncHostMemorySpan<element_type>
  subspan(index_type Offset, index_type Count = dynamic_extent) const {
    return AsyncHostMemorySpan<element_type>(TheSpan.subspan(Offset, Count));
  }

private:
  template <typename T> friend class AsyncHostMemory;

  explicit AsyncHostMemorySpan(Span<ElementType> ArraySpan)
      : TheSpan(ArraySpan) {}

  Span<ElementType> TheSpan;
};

} // namespace acxxel

#endif // ACXXEL_ACXXEL_H