1 //===--- acxxel.h - The Acxxel API ------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 /// \mainpage Welcome to Acxxel
10 ///
11 /// \section Introduction
12 ///
13 /// \b Acxxel is a library providing a modern C++ interface for managing
14 /// accelerator devices such as GPUs. Acxxel handles operations such as
15 /// allocating device memory, copying data to and from device memory, creating
16 /// and managing device events, and creating and managing device streams.
17 ///
18 /// \subsection ExampleUsage Example Usage
19 ///
20 /// Below is some example code to show you the basics of Acxxel.
21 ///
22 /// \snippet examples/simple_example.cu Example simple saxpy
23 ///
24 /// The above code could be compiled with either `clang` or `nvcc`. Compare this
25 /// with the standard CUDA runtime library code to perform these same
26 /// operations:
27 ///
28 /// \snippet examples/simple_example.cu Example CUDA simple saxpy
29 ///
30 /// Notice that the CUDA runtime calls are not type safe. For example, if you
31 /// change the type of the inputs from `float` to `double`, you have to remember
32 /// to change the size calculation. If you forget, you will get garbage output
33 /// data. In the Acxxel example, you would instead get a helpful compile-time
34 /// error that wouldn't let you forget to change the types inside the function.
35 ///
36 /// The Acxxel example also automatically uses the right sizes for memory
37 /// copies, so you don't have to worry about computing the sizes yourself.
38 ///
39 /// The CUDA runtime interface makes it easy to get the source and destination
40 /// mixed up in a call to `cudaMemcpy`. If you pass the pointers in the wrong
41 /// order or pass the wrong enum value for the direction parameter, you won't
42 /// find out until runtime (if you remembered to check the error return value of
43 /// `cudaMemcpy`). In Acxxel there is no verbose direction enum because the name
44 /// of the function says which way the copy goes, and mixing up the order of
45 /// source and destination is a compile-time error.
46 ///
47 /// The CUDA runtime interface makes you clean up your device memory by calling
48 /// `cudaFree` for each call to `cudaMalloc`. In Acxxel, you don't have to worry
49 /// about that because the memory cleans itself up when it goes out of scope.
50 ///
51 /// \subsection AcxxelFeatures Acxxel Features
52 ///
53 /// Acxxel provides many nice features compared to the C-like interfaces, such
54 /// as the CUDA runtime API, which are normally used for the host code in
55 /// applications using accelerators.
56 ///
57 /// \subsubsection TypeSafety Type safety
58 ///
59 /// Most errors involving mixing up types, sources and destinations, or host and
60 /// device memory result in helpful compile-time errors.
61 ///
62 /// \subsubsection NoCopySizes No need to specify sizes for memory copies
63 ///
64 /// When the arguments to copy functions such as acxxel::Platform::copyHToD know
/// their sizes (e.g. std::array, std::vector, and C-style arrays), there is no
66 /// need to specify the amount of memory to copy; Acxxel will just copy the
67 /// whole thing. Of course the copy functions also have overloads that accept an
68 /// element count for those times when you don't want to copy everything.
69 ///
70 /// \subsubsection MemoryCleanup Automatic memory cleanup
71 ///
72 /// Device memory allocated with acxxel::Platform::mallocD is automatically
73 /// freed when it goes out of scope.
74 ///
75 /// \subsubsection NiceErrorHandling Error handling
76 ///
/// Operations that would normally return values return acxxel::Expected objects
78 /// in Acxxel. These `Expected` objects contain either a value or an error
79 /// message explaining why the value is not present. This reminds the user to
/// check for errors, but also allows them to opt out easily by calling the
81 /// acxxel::Expected::getValue or acxxel::Expected::takeValue methods. The
82 /// `getValue` method returns a reference to the value, leaving the `Expected`
83 /// instance as the value owner, whereas the `takeValue` method moves the value
84 /// out of the `Expected` object and transfers ownership to the caller.
85 ///
86 /// \subsubsection PlatformIndependence Platform independence
87 ///
88 /// Acxxel code works not only with CUDA, but also with any other platform that
89 /// can support its interface. For example, Acxxel supports OpenCL. The
90 /// acxxel::getCUDAPlatform and acxxel::getOpenCLPlatform functions are provided
91 /// to allow easy access to the built-in CUDA and OpenCL platforms. Other
92 /// platforms can be created by implementing the acxxel::Platform interface, and
93 /// instances of those classes can be created directly.
94 ///
95 /// \subsubsection CUDAInterop Seamless interoperation with CUDA
96 ///
97 /// Acxxel functions as a modern replacement for the standard CUDA runtime
98 /// library and interoperates seamlessly with kernel calls.
99
100 #ifndef ACXXEL_ACXXEL_H
101 #define ACXXEL_ACXXEL_H
102
#include "span.h"
#include "status.h"

#include <functional>
#include <limits>
#include <memory>
#include <new>
#include <string>
#include <type_traits>
110
111 #if defined(__clang__) || defined(__GNUC__)
112 #define ACXXEL_WARN_UNUSED_RESULT __attribute__((warn_unused_result))
113 #else
114 #define ACXXEL_WARN_UNUSED_RESULT
115 #endif
116
117 /// This type is declared here to provide smooth interoperability with the CUDA
118 /// triple-chevron kernel launch syntax.
119 ///
/// An acxxel::Stream instance will be implicitly convertible to a CUstream_st*,
/// which is the type expected for the stream argument in the triple-chevron
/// CUDA kernel launch. This means that an acxxel::Stream can be passed without
123 /// explicit casting as the fourth argument to a triple-chevron CUDA kernel
124 /// launch.
125 struct CUstream_st; // NOLINT
126
127 namespace acxxel {
128
129 class Event;
130 class Platform;
131 class Stream;
132
133 template <typename T> class DeviceMemory;
134
135 template <typename T> class DeviceMemorySpan;
136
137 template <typename T> class AsyncHostMemory;
138
139 template <typename T> class AsyncHostMemorySpan;
140
141 template <typename T> class OwnedAsyncHostMemory;
142
143 /// Function type used to destroy opaque handles given out by the platform.
144 using HandleDestructor = void (*)(void *);
145
146 /// Functor type for enqueuing host callbacks on a stream.
147 using StreamCallback = std::function<void(Stream &, const Status &)>;
148
/// Block and grid dimensions for a kernel launch. Every dimension that is not
/// given explicitly is 1.
struct KernelLaunchDimensions {
  // Deliberately not explicit: a single unsigned value converts to a launch
  // whose BlockX is that value and whose other dimensions are all 1.
  KernelLaunchDimensions(unsigned int BlockX = 1, unsigned int BlockY = 1,
                         unsigned int BlockZ = 1, unsigned int GridX = 1,
                         unsigned int GridY = 1, unsigned int GridZ = 1)
      : BlockX{BlockX}, BlockY{BlockY}, BlockZ{BlockZ}, GridX{GridX},
        GridY{GridY}, GridZ{GridZ} {}

  // Block dimensions.
  unsigned int BlockX;
  unsigned int BlockY;
  unsigned int BlockZ;

  // Grid dimensions.
  unsigned int GridX;
  unsigned int GridY;
  unsigned int GridZ;
};
164
165 /// Logs a warning message.
166 void logWarning(const std::string &Message);
167
168 /// Gets a pointer to the standard CUDA platform.
169 Expected<Platform *> getCUDAPlatform();
170
171 /// Gets a pointer to the standard OpenCL platform.
172 Expected<Platform *> getOpenCLPlatform();
173
174 /// A function that can be executed on the device.
175 ///
176 /// A Kernel is created from a Program by calling Program::createKernel, and a
177 /// kernel is enqueued into a Stream by calling Stream::asyncKernelLaunch.
178 class Kernel {
179 public:
180 Kernel(const Kernel &) = delete;
181 Kernel &operator=(const Kernel &) = delete;
182 Kernel(Kernel &&) noexcept;
183 Kernel &operator=(Kernel &&That) noexcept;
184 ~Kernel() = default;
185
186 private:
187 // Only a Program can make a kernel.
188 friend class Program;
Kernel(Platform * APlatform,void * AHandle,HandleDestructor Destructor)189 Kernel(Platform *APlatform, void *AHandle, HandleDestructor Destructor)
190 : ThePlatform(APlatform), TheHandle(AHandle, Destructor) {}
191
192 // Let stream get raw handle for kernel launches.
193 friend class Stream;
194
195 Platform *ThePlatform;
196 std::unique_ptr<void, HandleDestructor> TheHandle;
197 };
198
199 /// A program loaded on a device.
200 ///
201 /// A program can be created by calling Platform::createProgramFromSource, and a
202 /// Kernel can be created from a program by running Program::createKernel.
203 ///
204 /// A program can contain any number of kernels, and a program only needs to be
205 /// loaded once in order to use all its kernels.
206 class Program {
207 public:
208 Program(const Program &) = delete;
209 Program &operator=(const Program &) = delete;
210 Program(Program &&) noexcept;
211 Program &operator=(Program &&That) noexcept;
212 ~Program() = default;
213
214 Expected<Kernel> createKernel(const std::string &Name);
215
216 private:
217 // Only a platform can make a program.
218 friend class Platform;
Program(Platform * APlatform,void * AHandle,HandleDestructor Destructor)219 Program(Platform *APlatform, void *AHandle, HandleDestructor Destructor)
220 : ThePlatform(APlatform), TheHandle(AHandle, Destructor) {}
221
222 Platform *ThePlatform;
223 std::unique_ptr<void, HandleDestructor> TheHandle;
224 };
225
226 /// A stream of computation.
227 ///
228 /// All operations enqueued on a Stream are serialized, but operations enqueued
229 /// on different Streams may run concurrently.
230 ///
231 /// Each Stream is associated with a specific, fixed device.
232 class Stream {
233 public:
234 Stream(const Stream &) = delete;
235 Stream &operator=(const Stream &) = delete;
236 Stream(Stream &&) noexcept;
237 Stream &operator=(Stream &&) noexcept;
238 ~Stream() = default;
239
240 /// Gets the index of the device on which this Stream operates.
getDeviceIndex()241 int getDeviceIndex() { return TheDeviceIndex; }
242
243 /// Blocks the host until the Stream is done executing all previously enqueued
244 /// work.
245 ///
246 /// Returns a Status for any errors emitted by the asynchronous work on the
247 /// Stream, or by any error in the synchronization process itself. Clears the
248 /// Status state of the stream.
249 Status sync() ACXXEL_WARN_UNUSED_RESULT;
250
251 /// Makes all future work submitted to this stream wait until the event
252 /// reports completion.
253 ///
254 /// This is useful because the event argument may be recorded on a different
255 /// stream, so this method allows for synchronization between streams without
256 /// synchronizing all streams.
257 ///
258 /// Returns a Status for any errors emitted by the asynchronous work on the
259 /// Stream, or by any error in the synchronization process itself. Clears the
260 /// Status state of the stream.
261 Status waitOnEvent(Event &Event) ACXXEL_WARN_UNUSED_RESULT;
262
263 /// Adds a host callback function to the stream.
264 ///
265 /// The callback will be called on the host after all previously enqueued work
266 /// on the stream is complete, and no work enqueued after the callback will
267 /// begin until after the callback has finished.
268 Stream &addCallback(std::function<void(Stream &, const Status &)> Callback);
269
270 /// \name Asynchronous device memory copies.
271 ///
272 /// These functions enqueue asynchronous memory copy operations into the
273 /// stream. Only async host memory is allowed for host arguments to these
274 /// functions. Async host memory can be created from normal host memory by
275 /// registering it with Platform::registerHostMem. AsyncHostMemory can also be
276 /// allocated directly by calling Platform::newAsyncHostMem.
277 ///
278 /// For all these functions, DeviceSrcTy must be convertible to
279 /// DeviceMemorySpan<const T>, DeviceDstTy must be convertible to
280 /// DeviceMemorySpan<T>, HostSrcTy must be convertible to
281 /// AsyncHostMemorySpan<const T> and HostDstTy must be convertible to
282 /// AsyncHostMemorySpan<T>. Additionally, the T types must match for the
283 /// destination and source.
284 /// \{
285
286 /// Copies from device memory to device memory.
287 template <typename DeviceSrcTy, typename DeviceDstTy>
288 Stream &asyncCopyDToD(DeviceSrcTy &&DeviceSrc, DeviceDstTy &&DeviceDst);
289
290 /// Copies from device memory to device memory with a given element count.
291 template <typename DeviceSrcTy, typename DeviceDstTy>
292 Stream &asyncCopyDToD(DeviceSrcTy &&DeviceSrc, DeviceDstTy &&DeviceDst,
293 ptrdiff_t ElementCount);
294
295 /// Copies from device memory to host memory.
296 template <typename DeviceSrcTy, typename HostDstTy>
297 Stream &asyncCopyDToH(DeviceSrcTy &&DeviceSrc, HostDstTy &&HostDst);
298
299 /// Copies from device memory to host memory with a given element count.
300 template <typename DeviceSrcTy, typename HostDstTy>
301 Stream &asyncCopyDToH(DeviceSrcTy &&DeviceSrc, HostDstTy &&HostDst,
302 ptrdiff_t ElementCount);
303
304 /// Copies from host memory to device memory.
305 template <typename HostSrcTy, typename DeviceDstTy>
306 Stream &asyncCopyHToD(HostSrcTy &&HostSrc, DeviceDstTy &&DeviceDst);
307
308 /// Copies from host memory to device memory with a given element count.
309 template <typename HostSrcTy, typename DeviceDstTy>
310 Stream &asyncCopyHToD(HostSrcTy &&HostSrc, DeviceDstTy &DeviceDst,
311 ptrdiff_t ElementCount);
312
313 /// \}
314
315 /// \name Stream-synchronous device memory copies
316 ///
317 /// These functions block the host until the copy and all previously-enqueued
318 /// work on the stream has completed.
319 ///
320 /// For all these functions, DeviceSrcTy must be convertible to
321 /// DeviceMemorySpan<const T>, DeviceDstTy must be convertible to
322 /// DeviceMemorySpan<T>, HostSrcTy must be convertible to Span<const T> and
323 /// HostDstTy must be convertible to Span<T>. Additionally, the T types must
324 /// match for the destination and source.
325 /// \{
326
327 template <typename DeviceSrcTy, typename DeviceDstTy>
328 Stream &syncCopyDToD(DeviceSrcTy &&DeviceSrc, DeviceDstTy &&DeviceDst);
329
330 template <typename DeviceSrcTy, typename DeviceDstTy>
331 Stream &syncCopyDToD(DeviceSrcTy &&DeviceSrc, DeviceDstTy &&DeviceDst,
332 ptrdiff_t ElementCount);
333
334 template <typename DeviceSrcTy, typename HostDstTy>
335 Stream &syncCopyDToH(DeviceSrcTy &&DeviceSrc, HostDstTy &&HostDst);
336
337 template <typename DeviceSrcTy, typename HostDstTy>
338 Stream &syncCopyDToH(DeviceSrcTy &&DeviceSrc, HostDstTy &&HostDst,
339 ptrdiff_t ElementCount);
340
341 template <typename HostSrcTy, typename DeviceDstTy>
342 Stream &syncCopyHToD(HostSrcTy &&HostSrc, DeviceDstTy &&DeviceDst);
343
344 template <typename HostSrcTy, typename DeviceDstTy>
345 Stream &syncCopyHToD(HostSrcTy &&HostSrc, DeviceDstTy &DeviceDst,
346 ptrdiff_t ElementCount);
347
348 /// \}
349
350 /// Enqueues an operation in the stream to set the bytes of a given device
351 /// memory region to a given value.
352 ///
353 /// DeviceDstTy must be convertible to DeviceMemorySpan<T> for non-const T.
354 template <typename DeviceDstTy>
355 Stream &asyncMemsetD(DeviceDstTy &&DeviceDst, char ByteValue);
356
357 /// Enqueues a kernel launch operation on this stream.
358 Stream &asyncKernelLaunch(const Kernel &TheKernel,
359 KernelLaunchDimensions LaunchDimensions,
360 Span<void *> Arguments, Span<size_t> ArgumentSizes,
361 size_t SharedMemoryBytes = 0);
362
363 /// Enqueues an event in the stream.
364 Stream &enqueueEvent(Event &E);
365
366 // Allows implicit conversion to (CUstream_st *). This makes triple-chevron
367 // kernel calls look nicer because you can just pass a acxxel::Stream
368 // directly.
369 operator CUstream_st *() {
370 return static_cast<CUstream_st *>(TheHandle.get());
371 }
372
373 /// Gets the current status for the Stream and clears the Stream's status.
takeStatus()374 Status takeStatus() ACXXEL_WARN_UNUSED_RESULT {
375 Status OldStatus = TheStatus;
376 TheStatus = Status();
377 return OldStatus;
378 }
379
380 private:
381 // Only a platform can make a stream.
382 friend class Platform;
Stream(Platform * APlatform,int DeviceIndex,void * AHandle,HandleDestructor Destructor)383 Stream(Platform *APlatform, int DeviceIndex, void *AHandle,
384 HandleDestructor Destructor)
385 : ThePlatform(APlatform), TheDeviceIndex(DeviceIndex),
386 TheHandle(AHandle, Destructor) {}
387
setStatus(const Status & S)388 const Status &setStatus(const Status &S) {
389 if (S.isError() && !TheStatus.isError()) {
390 TheStatus = S;
391 }
392 return S;
393 }
394
takeStatusOr(const Status & S)395 Status takeStatusOr(const Status &S) {
396 if (TheStatus.isError()) {
397 Status OldStatus = TheStatus;
398 TheStatus = Status();
399 return OldStatus;
400 }
401 return S;
402 }
403
404 // The platform that created the stream.
405 Platform *ThePlatform;
406
407 // The index of the device on which the stream operates.
408 int TheDeviceIndex;
409
410 // A handle to the platform-specific handle implementation.
411 std::unique_ptr<void, HandleDestructor> TheHandle;
412 Status TheStatus;
413 };
414
415 /// A user-created event on a device.
416 ///
417 /// This is useful for setting synchronization points in a Stream. The host can
418 /// synchronize with a Stream without using events, but that requires all the
419 /// work in the Stream to be finished in order for the host to be notified.
420 /// Events provide more flexibility by allowing the host to be notified when a
421 /// single Event in the Stream is finished, rather than all the work in the
422 /// Stream.
423 class Event {
424 public:
425 Event(const Event &) = delete;
426 Event &operator=(const Event &) = delete;
427 Event(Event &&) noexcept;
428 Event &operator=(Event &&That) noexcept;
429 ~Event() = default;
430
431 /// Checks to see if the event is done running.
432 bool isDone();
433
434 /// Blocks the host until the event is done.
435 Status sync();
436
437 /// Gets the time elapsed between the previous event's execution and this
438 /// event's execution.
439 Expected<float> getSecondsSince(const Event &Previous);
440
441 private:
442 // Only a platform can make an event.
443 friend class Platform;
Event(Platform * APlatform,int DeviceIndex,void * AHandle,HandleDestructor Destructor)444 Event(Platform *APlatform, int DeviceIndex, void *AHandle,
445 HandleDestructor Destructor)
446 : ThePlatform(APlatform), TheDeviceIndex(DeviceIndex),
447 TheHandle(AHandle, Destructor) {}
448
449 Platform *ThePlatform;
450
451 // The index of the device on which the event can be enqueued.
452 int TheDeviceIndex;
453
454 std::unique_ptr<void, HandleDestructor> TheHandle;
455 };
456
457 /// An accelerator platform.
458 ///
459 /// This is the base class for all platforms such as CUDA and OpenCL. It
460 /// contains many virtual methods that must be overridden by each platform
461 /// implementation.
462 ///
463 /// It also has some template wrapper functions that take care of type checking
464 /// and then forward their arguments on to raw virtual functions that are
465 /// implemented by each specific platform.
466 class Platform {
467 public:
~Platform()468 virtual ~Platform(){};
469
470 /// Gets the number of devices for this platform in this system.
471 virtual Expected<int> getDeviceCount() = 0;
472
473 /// Creates a stream on the given device for the platform.
474 virtual Expected<Stream> createStream(int DeviceIndex = 0) = 0;
475
476 /// Creates an event on the given device for the platform.
477 virtual Expected<Event> createEvent(int DeviceIndex = 0) = 0;
478
479 /// Allocates owned device memory.
480 ///
481 /// \warning This function only allocates space in device memory, it does not
482 /// call the constructor of T.
483 template <typename T>
484 Expected<DeviceMemory<T>> mallocD(ptrdiff_t ElementCount,
485 int DeviceIndex = 0) {
486 Expected<void *> MaybePointer =
487 rawMallocD(ElementCount * sizeof(T), DeviceIndex);
488 if (MaybePointer.isError())
489 return MaybePointer.getError();
490 return DeviceMemory<T>(this, MaybePointer.getValue(), ElementCount,
491 this->getDeviceMemoryHandleDestructor());
492 }
493
494 /// Creates a DeviceMemorySpan for a device symbol.
495 ///
496 /// This function is present to support __device__ variables in CUDA. Given a
497 /// pointer to a __device__ variable, this function returns a DeviceMemorySpan
498 /// referencing the device memory that stores that __device__ variable.
499 template <typename ElementType>
500 Expected<DeviceMemorySpan<ElementType>> getSymbolMemory(ElementType *Symbol,
501 int DeviceIndex = 0) {
502 Expected<void *> MaybeAddress =
503 rawGetDeviceSymbolAddress(Symbol, DeviceIndex);
504 if (MaybeAddress.isError())
505 return MaybeAddress.getError();
506 ElementType *Address = static_cast<ElementType *>(MaybeAddress.getValue());
507 Expected<ptrdiff_t> MaybeSize = rawGetDeviceSymbolSize(Symbol, DeviceIndex);
508 if (MaybeSize.isError())
509 return MaybeSize.getError();
510 ptrdiff_t Size = MaybeSize.getValue();
511 return DeviceMemorySpan<ElementType>(this, Address,
512 Size / sizeof(ElementType), 0);
513 }
514
515 /// \name Host memory registration functions.
516 /// \{
517
518 template <typename T>
registerHostMem(Span<const T> Memory)519 Expected<AsyncHostMemory<const T>> registerHostMem(Span<const T> Memory) {
520 Status S = rawRegisterHostMem(Memory.data(), Memory.size() * sizeof(T));
521 if (S.isError())
522 return S;
523 return AsyncHostMemory<const T>(
524 Memory.data(), Memory.size(),
525 this->getUnregisterHostMemoryHandleDestructor());
526 }
527
528 template <typename T>
registerHostMem(Span<T> Memory)529 Expected<AsyncHostMemory<T>> registerHostMem(Span<T> Memory) {
530 Status S = rawRegisterHostMem(Memory.data(), Memory.size() * sizeof(T));
531 if (S.isError())
532 return S;
533 return AsyncHostMemory<T>(Memory.data(), Memory.size(),
534 this->getUnregisterHostMemoryHandleDestructor());
535 }
536
537 template <typename T, size_t N>
registerHostMem(T (& Array)[N])538 Expected<AsyncHostMemory<T>> registerHostMem(T (&Array)[N]) {
539 Span<T> Span(Array);
540 Status S = rawRegisterHostMem(Span.data(), Span.size() * sizeof(T));
541 if (S.isError())
542 return S;
543 return AsyncHostMemory<T>(Span.data(), Span.size(),
544 this->getUnregisterHostMemoryHandleDestructor());
545 }
546
547 /// Registers memory stored in a container with a data() member function and
548 /// which can be converted to a Span<T*>.
549 template <typename Container>
550 auto registerHostMem(Container &Cont) -> Expected<AsyncHostMemory<
551 typename std::remove_reference<decltype(*Cont.data())>::type>> {
552 using ValueType =
553 typename std::remove_reference<decltype(*Cont.data())>::type;
554 Span<ValueType> Span(Cont);
555 Status S = rawRegisterHostMem(Span.data(), Span.size() * sizeof(ValueType));
556 if (S.isError())
557 return S;
558 return AsyncHostMemory<ValueType>(
559 Span.data(), Span.size(),
560 this->getUnregisterHostMemoryHandleDestructor());
561 }
562
563 /// Allocates an owned, registered array of objects on the host.
564 ///
565 /// Default constructs each element in the resulting array.
566 template <typename T>
newAsyncHostMem(ptrdiff_t ElementCount)567 Expected<OwnedAsyncHostMemory<T>> newAsyncHostMem(ptrdiff_t ElementCount) {
568 Expected<void *> MaybeMemory =
569 rawMallocRegisteredH(ElementCount * sizeof(T));
570 if (MaybeMemory.isError())
571 return MaybeMemory.getError();
572 T *Memory = static_cast<T *>(MaybeMemory.getValue());
573 for (ptrdiff_t I = 0; I < ElementCount; ++I)
574 new (Memory + I) T;
575 return OwnedAsyncHostMemory<T>(Memory, ElementCount,
576 this->getFreeHostMemoryHandleDestructor());
577 }
578
579 /// \}
580
581 virtual Expected<Program> createProgramFromSource(Span<const char> Source,
582 int DeviceIndex = 0) = 0;
583
584 protected:
585 friend class Stream;
586 friend class Event;
587 friend class Program;
588 template <typename T> friend class DeviceMemorySpan;
589
getStreamHandle(Stream & Stream)590 void *getStreamHandle(Stream &Stream) { return Stream.TheHandle.get(); }
getEventHandle(Event & Event)591 void *getEventHandle(Event &Event) { return Event.TheHandle.get(); }
592
593 // Pass along access to Stream constructor to subclasses.
constructStream(Platform * APlatform,int DeviceIndex,void * AHandle,HandleDestructor Destructor)594 Stream constructStream(Platform *APlatform, int DeviceIndex, void *AHandle,
595 HandleDestructor Destructor) {
596 return Stream(APlatform, DeviceIndex, AHandle, Destructor);
597 }
598
599 // Pass along access to Event constructor to subclasses.
constructEvent(Platform * APlatform,int DeviceIndex,void * AHandle,HandleDestructor Destructor)600 Event constructEvent(Platform *APlatform, int DeviceIndex, void *AHandle,
601 HandleDestructor Destructor) {
602 return Event(APlatform, DeviceIndex, AHandle, Destructor);
603 }
604
605 // Pass along access to Program constructor to subclasses.
constructProgram(Platform * APlatform,void * AHandle,HandleDestructor Destructor)606 Program constructProgram(Platform *APlatform, void *AHandle,
607 HandleDestructor Destructor) {
608 return Program(APlatform, AHandle, Destructor);
609 }
610
611 virtual Status streamSync(void *Stream) = 0;
612 virtual Status streamWaitOnEvent(void *Stream, void *Event) = 0;
613
614 virtual Status enqueueEvent(void *Event, void *Stream) = 0;
615 virtual bool eventIsDone(void *Event) = 0;
616 virtual Status eventSync(void *Event) = 0;
617 virtual Expected<float> getSecondsBetweenEvents(void *StartEvent,
618 void *EndEvent) = 0;
619
620 virtual Expected<void *> rawMallocD(ptrdiff_t ByteCount, int DeviceIndex) = 0;
621 virtual HandleDestructor getDeviceMemoryHandleDestructor() = 0;
622 virtual void *getDeviceMemorySpanHandle(void *BaseHandle, size_t ByteSize,
623 size_t ByteOffset) = 0;
624 virtual void rawDestroyDeviceMemorySpanHandle(void *Handle) = 0;
625
626 virtual Expected<void *> rawGetDeviceSymbolAddress(const void *Symbol,
627 int DeviceIndex) = 0;
628 virtual Expected<ptrdiff_t> rawGetDeviceSymbolSize(const void *Symbol,
629 int DeviceIndex) = 0;
630
631 virtual Status rawRegisterHostMem(const void *Memory,
632 ptrdiff_t ByteCount) = 0;
633 virtual HandleDestructor getUnregisterHostMemoryHandleDestructor() = 0;
634
635 virtual Expected<void *> rawMallocRegisteredH(ptrdiff_t ByteCount) = 0;
636 virtual HandleDestructor getFreeHostMemoryHandleDestructor() = 0;
637
638 virtual Status asyncCopyDToD(const void *DeviceSrc,
639 ptrdiff_t DeviceSrcByteOffset, void *DeviceDst,
640 ptrdiff_t DeviceDstByteOffset,
641 ptrdiff_t ByteCount, void *Stream) = 0;
642 virtual Status asyncCopyDToH(const void *DeviceSrc,
643 ptrdiff_t DeviceSrcByteOffset, void *HostDst,
644 ptrdiff_t ByteCount, void *Stream) = 0;
645 virtual Status asyncCopyHToD(const void *HostSrc, void *DeviceDst,
646 ptrdiff_t DeviceDstByteOffset,
647 ptrdiff_t ByteCount, void *Stream) = 0;
648
649 virtual Status asyncMemsetD(void *DeviceDst, ptrdiff_t ByteOffset,
650 ptrdiff_t ByteCount, char ByteValue,
651 void *Stream) = 0;
652
653 virtual Status addStreamCallback(Stream &Stream, StreamCallback Callback) = 0;
654
655 virtual Expected<void *> rawCreateKernel(void *Program,
656 const std::string &Name) = 0;
657 virtual HandleDestructor getKernelHandleDestructor() = 0;
658
659 virtual Status rawEnqueueKernelLaunch(void *Stream, void *Kernel,
660 KernelLaunchDimensions LaunchDimensions,
661 Span<void *> Arguments,
662 Span<size_t> ArgumentSizes,
663 size_t SharedMemoryBytes) = 0;
664 };
665
666 // Implementation of templated Stream functions.
667
668 template <typename DeviceSrcTy, typename DeviceDstTy>
asyncCopyDToD(DeviceSrcTy && DeviceSrc,DeviceDstTy && DeviceDst)669 Stream &Stream::asyncCopyDToD(DeviceSrcTy &&DeviceSrc,
670 DeviceDstTy &&DeviceDst) {
671 using SrcElementTy =
672 typename std::remove_reference<DeviceSrcTy>::type::value_type;
673 using DstElementTy =
674 typename std::remove_reference<DeviceDstTy>::type::value_type;
675 static_assert(std::is_same<SrcElementTy, DstElementTy>::value,
676 "asyncCopyDToD cannot copy between arrays of different types");
677 DeviceMemorySpan<const SrcElementTy> DeviceSrcSpan(DeviceSrc);
678 DeviceMemorySpan<DstElementTy> DeviceDstSpan(DeviceDst);
679 if (DeviceSrcSpan.size() != DeviceDstSpan.size()) {
680 setStatus(Status("asyncCopyDToD source element count " +
681 std::to_string(DeviceSrcSpan.size()) +
682 " does not equal destination element count " +
683 std::to_string(DeviceDstSpan.size())));
684 return *this;
685 }
686 setStatus(ThePlatform->asyncCopyDToD(
687 DeviceSrcSpan.baseHandle(), DeviceSrcSpan.byte_offset(),
688 DeviceDstSpan.baseHandle(), DeviceDstSpan.byte_offset(),
689 DeviceSrcSpan.byte_size(), TheHandle.get()));
690 return *this;
691 }
692
693 template <typename DeviceSrcTy, typename DeviceDstTy>
asyncCopyDToD(DeviceSrcTy && DeviceSrc,DeviceDstTy && DeviceDst,ptrdiff_t ElementCount)694 Stream &Stream::asyncCopyDToD(DeviceSrcTy &&DeviceSrc, DeviceDstTy &&DeviceDst,
695 ptrdiff_t ElementCount) {
696 using SrcElementTy =
697 typename std::remove_reference<DeviceSrcTy>::type::value_type;
698 using DstElementTy =
699 typename std::remove_reference<DeviceDstTy>::type::value_type;
700 static_assert(std::is_same<SrcElementTy, DstElementTy>::value,
701 "asyncCopyDToD cannot copy between arrays of different types");
702 DeviceMemorySpan<const SrcElementTy> DeviceSrcSpan(DeviceSrc);
703 DeviceMemorySpan<DstElementTy> DeviceDstSpan(DeviceDst);
704 if (DeviceSrcSpan.size() < ElementCount) {
705 setStatus(Status("asyncCopyDToD source element count " +
706 std::to_string(DeviceSrcSpan.size()) +
707 " is less than requested element count " +
708 std::to_string(ElementCount)));
709 return *this;
710 }
711 if (DeviceDstSpan.size() < ElementCount) {
712 setStatus(Status("asyncCopyDToD destination element count " +
713 std::to_string(DeviceDst.size()) +
714 " is less than requested element count " +
715 std::to_string(ElementCount)));
716 return *this;
717 }
718 setStatus(ThePlatform->asyncCopyDToD(
719 DeviceSrcSpan.baseHandle(), DeviceSrcSpan.byte_offset(),
720 DeviceDstSpan.baseHandle(), DeviceDstSpan.byte_offset(),
721 ElementCount * sizeof(SrcElementTy), TheHandle.get()));
722 return *this;
723 }
724
725 template <typename DeviceSrcTy, typename HostDstTy>
asyncCopyDToH(DeviceSrcTy && DeviceSrc,HostDstTy && HostDst)726 Stream &Stream::asyncCopyDToH(DeviceSrcTy &&DeviceSrc, HostDstTy &&HostDst) {
727 using SrcElementTy =
728 typename std::remove_reference<DeviceSrcTy>::type::value_type;
729 DeviceMemorySpan<const SrcElementTy> DeviceSrcSpan(DeviceSrc);
730 AsyncHostMemorySpan<SrcElementTy> HostDstSpan(HostDst);
731 if (DeviceSrcSpan.size() != HostDstSpan.size()) {
732 setStatus(Status("asyncCopyDToH source element count " +
733 std::to_string(DeviceSrcSpan.size()) +
734 " does not equal destination element count " +
735 std::to_string(HostDstSpan.size())));
736 return *this;
737 }
738 setStatus(ThePlatform->asyncCopyDToH(
739 DeviceSrcSpan.baseHandle(), DeviceSrcSpan.byte_offset(),
740 HostDstSpan.data(), DeviceSrcSpan.byte_size(), TheHandle.get()));
741 return *this;
742 }
743
744 template <typename DeviceSrcTy, typename HostDstTy>
asyncCopyDToH(DeviceSrcTy && DeviceSrc,HostDstTy && HostDst,ptrdiff_t ElementCount)745 Stream &Stream::asyncCopyDToH(DeviceSrcTy &&DeviceSrc, HostDstTy &&HostDst,
746 ptrdiff_t ElementCount) {
747 using SrcElementTy =
748 typename std::remove_reference<DeviceSrcTy>::type::value_type;
749 DeviceMemorySpan<const SrcElementTy> DeviceSrcSpan(DeviceSrc);
750 AsyncHostMemorySpan<SrcElementTy> HostDstSpan(HostDst);
751 if (DeviceSrcSpan.size() < ElementCount) {
752 setStatus(Status("asyncCopyDToH source element count " +
753 std::to_string(DeviceSrcSpan.size()) +
754 " is less than requested element count " +
755 std::to_string(ElementCount)));
756 return *this;
757 }
758 if (HostDstSpan.size() < ElementCount) {
759 setStatus(Status("asyncCopyDToH destination element count " +
760 std::to_string(HostDstSpan.size()) +
761 " is less than requested element count " +
762 std::to_string(ElementCount)));
763 return *this;
764 }
765 setStatus(ThePlatform->asyncCopyDToH(
766 DeviceSrcSpan.baseHandle(), DeviceSrcSpan.byte_offset(),
767 HostDstSpan.data(), ElementCount * sizeof(SrcElementTy),
768 TheHandle.get()));
769 return *this;
770 }
771
772 template <typename HostSrcTy, typename DeviceDstTy>
asyncCopyHToD(HostSrcTy && HostSrc,DeviceDstTy && DeviceDst)773 Stream &Stream::asyncCopyHToD(HostSrcTy &&HostSrc, DeviceDstTy &&DeviceDst) {
774 using DstElementTy =
775 typename std::remove_reference<DeviceDstTy>::type::value_type;
776 AsyncHostMemorySpan<const DstElementTy> HostSrcSpan(HostSrc);
777 DeviceMemorySpan<DstElementTy> DeviceDstSpan(DeviceDst);
778 if (HostSrcSpan.size() != DeviceDstSpan.size()) {
779 setStatus(Status("asyncCopyHToD source element count " +
780 std::to_string(HostSrcSpan.size()) +
781 " does not equal destination element count " +
782 std::to_string(DeviceDstSpan.size())));
783 return *this;
784 }
785 setStatus(ThePlatform->asyncCopyHToD(
786 HostSrcSpan.data(), DeviceDstSpan.baseHandle(),
787 DeviceDstSpan.byte_offset(), HostSrcSpan.byte_size(), TheHandle.get()));
788 return *this;
789 }
790
791 template <typename HostSrcTy, typename DeviceDstTy>
asyncCopyHToD(HostSrcTy && HostSrc,DeviceDstTy & DeviceDst,ptrdiff_t ElementCount)792 Stream &Stream::asyncCopyHToD(HostSrcTy &&HostSrc, DeviceDstTy &DeviceDst,
793 ptrdiff_t ElementCount) {
794 using DstElementTy =
795 typename std::remove_reference<DeviceDstTy>::type::value_type;
796 AsyncHostMemorySpan<const DstElementTy> HostSrcSpan(HostSrc);
797 DeviceMemorySpan<DstElementTy> DeviceDstSpan(DeviceDst);
798 if (HostSrcSpan.size() < ElementCount) {
799 setStatus(Status("copyHToD source element count " +
800 std::to_string(HostSrcSpan.size()) +
801 " is less than requested element count " +
802 std::to_string(ElementCount)));
803 return *this;
804 }
805 if (DeviceDstSpan.size() < ElementCount) {
806 setStatus(Status("copyHToD destination element count " +
807 std::to_string(DeviceDstSpan.size()) +
808 " is less than requested element count " +
809 std::to_string(ElementCount)));
810 return *this;
811 }
812 setStatus(ThePlatform->asyncCopyHToD(
813 HostSrcSpan.data(), DeviceDstSpan.baseHandle(),
814 DeviceDstSpan.byte_offset(), ElementCount * sizeof(DstElementTy),
815 TheHandle.get()));
816 return *this;
817 }
818
819 template <typename DeviceDstTy>
asyncMemsetD(DeviceDstTy && DeviceDst,char ByteValue)820 Stream &Stream::asyncMemsetD(DeviceDstTy &&DeviceDst, char ByteValue) {
821 using DstElementTy =
822 typename std::remove_reference<DeviceDstTy>::type::value_type;
823 DeviceMemorySpan<DstElementTy> DeviceDstSpan(DeviceDst);
824 setStatus(ThePlatform->asyncMemsetD(
825 DeviceDstSpan.baseHandle(), DeviceDstSpan.byte_offset(),
826 DeviceDstSpan.byte_size(), ByteValue, TheHandle.get()));
827 return *this;
828 }
829
830 template <typename DeviceSrcTy, typename DeviceDstTy>
syncCopyDToD(DeviceSrcTy && DeviceSrc,DeviceDstTy && DeviceDst)831 Stream &Stream::syncCopyDToD(DeviceSrcTy &&DeviceSrc, DeviceDstTy &&DeviceDst) {
832 using SrcElementTy =
833 typename std::remove_reference<DeviceSrcTy>::type::value_type;
834 using DstElementTy =
835 typename std::remove_reference<DeviceDstTy>::type::value_type;
836 static_assert(std::is_same<SrcElementTy, DstElementTy>::value,
837 "copyDToD cannot copy between arrays of different types");
838 DeviceMemorySpan<const SrcElementTy> DeviceSrcSpan(DeviceSrc);
839 DeviceMemorySpan<DstElementTy> DeviceDstSpan(DeviceDst);
840 if (DeviceSrcSpan.size() != DeviceDstSpan.size()) {
841 setStatus(Status("copyDToD source element count " +
842 std::to_string(DeviceSrcSpan.size()) +
843 " does not equal destination element count " +
844 std::to_string(DeviceDstSpan.size())));
845 return *this;
846 }
847 if (setStatus(ThePlatform->asyncCopyDToD(
848 DeviceSrcSpan.baseHandle(), DeviceSrcSpan.byte_offset(),
849 DeviceDstSpan.baseHandle(), DeviceDstSpan.byte_offset(),
850 DeviceSrcSpan.byte_size(), TheHandle.get()))
851 .isError()) {
852 return *this;
853 }
854 setStatus(sync());
855 return *this;
856 }
857
858 template <typename DeviceSrcTy, typename DeviceDstTy>
syncCopyDToD(DeviceSrcTy && DeviceSrc,DeviceDstTy && DeviceDst,ptrdiff_t ElementCount)859 Stream &Stream::syncCopyDToD(DeviceSrcTy &&DeviceSrc, DeviceDstTy &&DeviceDst,
860 ptrdiff_t ElementCount) {
861 using SrcElementTy =
862 typename std::remove_reference<DeviceSrcTy>::type::value_type;
863 using DstElementTy =
864 typename std::remove_reference<DeviceDstTy>::type::value_type;
865 static_assert(std::is_same<SrcElementTy, DstElementTy>::value,
866 "copyDToD cannot copy between arrays of different types");
867 DeviceMemorySpan<const SrcElementTy> DeviceSrcSpan(DeviceSrc);
868 DeviceMemorySpan<DstElementTy> DeviceDstSpan(DeviceDst);
869 if (DeviceSrcSpan.size() < ElementCount) {
870 setStatus(Status("copyDToD source element count " +
871 std::to_string(DeviceSrcSpan.size()) +
872 " is less than requested element count " +
873 std::to_string(ElementCount)));
874 return *this;
875 }
876 if (DeviceDstSpan.size() < ElementCount) {
877 setStatus(Status("copyDToD destination element count " +
878 std::to_string(DeviceDst.size()) +
879 " is less than requested element count " +
880 std::to_string(ElementCount)));
881 return *this;
882 }
883 if (setStatus(ThePlatform->asyncCopyDToD(
884 DeviceSrcSpan.baseHandle(), DeviceSrcSpan.byte_offset(),
885 DeviceDstSpan.baseHandle(), DeviceDstSpan.byte_offset(),
886 ElementCount * sizeof(SrcElementTy), TheHandle.get()))
887 .isError()) {
888 return *this;
889 }
890 setStatus(sync());
891 return *this;
892 }
893
894 template <typename DeviceSrcTy, typename HostDstTy>
syncCopyDToH(DeviceSrcTy && DeviceSrc,HostDstTy && HostDst)895 Stream &Stream::syncCopyDToH(DeviceSrcTy &&DeviceSrc, HostDstTy &&HostDst) {
896 using SrcElementTy =
897 typename std::remove_reference<DeviceSrcTy>::type::value_type;
898 DeviceMemorySpan<const SrcElementTy> DeviceSrcSpan(DeviceSrc);
899 Span<SrcElementTy> HostDstSpan(HostDst);
900 if (DeviceSrcSpan.size() != HostDstSpan.size()) {
901 setStatus(Status("copyDToH source element count " +
902 std::to_string(DeviceSrcSpan.size()) +
903 " does not equal destination element count " +
904 std::to_string(HostDstSpan.size())));
905 return *this;
906 }
907 if (setStatus(ThePlatform->asyncCopyDToH(
908 DeviceSrcSpan.baseHandle(), DeviceSrcSpan.byte_offset(),
909 HostDstSpan.data(), DeviceSrcSpan.byte_size(),
910 TheHandle.get()))
911 .isError()) {
912 return *this;
913 }
914 setStatus(sync());
915 return *this;
916 }
917
918 template <typename DeviceSrcTy, typename HostDstTy>
syncCopyDToH(DeviceSrcTy && DeviceSrc,HostDstTy && HostDst,ptrdiff_t ElementCount)919 Stream &Stream::syncCopyDToH(DeviceSrcTy &&DeviceSrc, HostDstTy &&HostDst,
920 ptrdiff_t ElementCount) {
921 using SrcElementTy =
922 typename std::remove_reference<DeviceSrcTy>::type::value_type;
923 DeviceMemorySpan<const SrcElementTy> DeviceSrcSpan(DeviceSrc);
924 Span<SrcElementTy> HostDstSpan(HostDst);
925 if (DeviceSrcSpan.size() < ElementCount) {
926 setStatus(Status("copyDToH source element count " +
927 std::to_string(DeviceSrcSpan.size()) +
928 " is less than requested element count " +
929 std::to_string(ElementCount)));
930 return *this;
931 }
932 if (HostDstSpan.size() < ElementCount) {
933 setStatus(Status("copyDToH destination element count " +
934 std::to_string(HostDstSpan.size()) +
935 " is less than requested element count " +
936 std::to_string(ElementCount)));
937 return *this;
938 }
939 if (setStatus(ThePlatform->asyncCopyDToH(
940 DeviceSrcSpan.baseHandle(), DeviceSrcSpan.byte_offset(),
941 HostDstSpan.data(), ElementCount * sizeof(SrcElementTy),
942 TheHandle.get()))
943 .isError()) {
944 return *this;
945 }
946 setStatus(sync());
947 return *this;
948 }
949
950 template <typename HostSrcTy, typename DeviceDstTy>
syncCopyHToD(HostSrcTy && HostSrc,DeviceDstTy && DeviceDst)951 Stream &Stream::syncCopyHToD(HostSrcTy &&HostSrc, DeviceDstTy &&DeviceDst) {
952 using DstElementTy =
953 typename std::remove_reference<DeviceDstTy>::type::value_type;
954 Span<const DstElementTy> HostSrcSpan(HostSrc);
955 DeviceMemorySpan<DstElementTy> DeviceDstSpan(DeviceDst);
956 if (HostSrcSpan.size() != DeviceDstSpan.size()) {
957 setStatus(Status("copyHToD source element count " +
958 std::to_string(HostSrcSpan.size()) +
959 " does not equal destination element count " +
960 std::to_string(DeviceDstSpan.size())));
961 return *this;
962 }
963 if (setStatus(ThePlatform->asyncCopyHToD(
964 HostSrcSpan.data(), DeviceDstSpan.baseHandle(),
965 DeviceDstSpan.byte_offset(), DeviceDstSpan.byte_size(),
966 TheHandle.get()))
967 .isError()) {
968 return *this;
969 }
970 setStatus(sync());
971 return *this;
972 }
973
974 template <typename HostSrcTy, typename DeviceDstTy>
syncCopyHToD(HostSrcTy && HostSrc,DeviceDstTy & DeviceDst,ptrdiff_t ElementCount)975 Stream &Stream::syncCopyHToD(HostSrcTy &&HostSrc, DeviceDstTy &DeviceDst,
976 ptrdiff_t ElementCount) {
977 using DstElementTy =
978 typename std::remove_reference<DeviceDstTy>::type::value_type;
979 Span<const DstElementTy> HostSrcSpan(HostSrc);
980 DeviceMemorySpan<DstElementTy> DeviceDstSpan(DeviceDst);
981 if (HostSrcSpan.size() < ElementCount) {
982 setStatus(Status("copyHToD source element count " +
983 std::to_string(HostSrcSpan.size()) +
984 " is less than requested element count " +
985 std::to_string(ElementCount)));
986 return *this;
987 }
988 if (DeviceDstSpan.size() < ElementCount) {
989 setStatus(Status("copyHToD destination element count " +
990 std::to_string(DeviceDstSpan.size()) +
991 " is less than requested element count " +
992 std::to_string(ElementCount)));
993 return *this;
994 }
995 if (setStatus(ThePlatform->asyncCopyHToD(
996 HostSrcSpan.data(), DeviceDstSpan.baseHandle(),
997 DeviceDstSpan.byte_offset(),
998 ElementCount * sizeof(DstElementTy), TheHandle.get()))
999 .isError()) {
1000 return *this;
1001 }
1002 setStatus(sync());
1003 return *this;
1004 }
1005
1006 /// Owned device memory.
1007 ///
1008 /// Device memory that frees itself when it goes out of scope.
1009 template <typename ElementType> class DeviceMemory {
1010 public:
1011 using element_type = ElementType;
1012 using index_type = std::ptrdiff_t;
1013 using value_type = typename std::remove_const<element_type>::type;
1014
1015 DeviceMemory(const DeviceMemory &) = delete;
1016 DeviceMemory &operator=(const DeviceMemory &) = delete;
1017 DeviceMemory(DeviceMemory &&) noexcept;
1018 DeviceMemory &operator=(DeviceMemory &&) noexcept;
1019 ~DeviceMemory() = default;
1020
1021 /// Gets the raw base handle for the underlying platform implementation.
handle()1022 void *handle() const { return ThePointer.get(); }
1023
length()1024 index_type length() const { return TheSize; }
size()1025 index_type size() const { return TheSize; }
byte_size()1026 index_type byte_size() const { // NOLINT
1027 return TheSize * sizeof(element_type);
1028 }
empty()1029 bool empty() const { return TheSize == 0; }
1030
1031 // These conversion operators are useful for making triple-chevron kernel
1032 // launches more concise.
1033 operator element_type *() {
1034 return static_cast<element_type *>(ThePointer.get());
1035 }
1036 operator const element_type *() const { return ThePointer.get(); }
1037
1038 /// Converts a const object to a DeviceMemorySpan of const elements.
asSpan()1039 DeviceMemorySpan<const element_type> asSpan() const {
1040 return DeviceMemorySpan<const element_type>(
1041 ThePlatform, static_cast<const element_type *>(ThePointer.get()),
1042 TheSize, 0);
1043 }
1044
1045 /// Converts an object to a DeviceMemorySpan.
asSpan()1046 DeviceMemorySpan<element_type> asSpan() {
1047 return DeviceMemorySpan<element_type>(
1048 ThePlatform, static_cast<element_type *>(ThePointer.get()), TheSize, 0);
1049 }
1050
1051 private:
1052 friend class Platform;
1053 template <typename T> friend class DeviceMemorySpan;
1054
DeviceMemory(Platform * ThePlatform,void * Pointer,index_type ElementCount,HandleDestructor Destructor)1055 DeviceMemory(Platform *ThePlatform, void *Pointer, index_type ElementCount,
1056 HandleDestructor Destructor)
1057 : ThePlatform(ThePlatform), ThePointer(Pointer, Destructor),
1058 TheSize(ElementCount) {}
1059
1060 Platform *ThePlatform;
1061 std::unique_ptr<void, HandleDestructor> ThePointer;
1062 ptrdiff_t TheSize;
1063 };
1064
1065 template <typename T>
1066 DeviceMemory<T>::DeviceMemory(DeviceMemory &&) noexcept = default;
1067 template <typename T>
1068 DeviceMemory<T> &DeviceMemory<T>::operator=(DeviceMemory &&) noexcept = default;
1069
1070 /// View into device memory.
1071 ///
1072 /// Like a Span, but for device memory rather than host memory.
1073 template <typename ElementType> class DeviceMemorySpan {
1074 public:
1075 /// \name constants and types
1076 /// \{
1077 using element_type = ElementType;
1078 using index_type = std::ptrdiff_t;
1079 using pointer = element_type *;
1080 using reference = element_type &;
1081 using iterator = element_type *;
1082 using const_iterator = const element_type *;
1083 using value_type = typename std::remove_const<element_type>::type;
1084 /// \}
1085
DeviceMemorySpan()1086 DeviceMemorySpan()
1087 : ThePlatform(nullptr), TheHandle(nullptr), TheSize(0), TheOffset(0),
1088 TheSpanHandle(nullptr) {}
1089
1090 // Intentionally implicit.
1091 template <typename OtherElementType>
DeviceMemorySpan(DeviceMemorySpan<OtherElementType> & ASpan)1092 DeviceMemorySpan(DeviceMemorySpan<OtherElementType> &ASpan)
1093 : ThePlatform(ASpan.ThePlatform),
1094 TheHandle(static_cast<pointer>(ASpan.baseHandle())),
1095 TheSize(ASpan.size()), TheOffset(ASpan.offset()),
1096 TheSpanHandle(nullptr) {}
1097
1098 // Intentionally implicit.
1099 template <typename OtherElementType>
DeviceMemorySpan(DeviceMemorySpan<OtherElementType> && ASpan)1100 DeviceMemorySpan(DeviceMemorySpan<OtherElementType> &&ASpan)
1101 : ThePlatform(ASpan.ThePlatform),
1102 TheHandle(static_cast<pointer>(ASpan.baseHandle())),
1103 TheSize(ASpan.size()), TheOffset(ASpan.offset()),
1104 TheSpanHandle(nullptr) {}
1105
1106 // Intentionally implicit.
1107 template <typename OtherElementType>
DeviceMemorySpan(DeviceMemory<OtherElementType> & Memory)1108 DeviceMemorySpan(DeviceMemory<OtherElementType> &Memory)
1109 : ThePlatform(Memory.ThePlatform),
1110 TheHandle(static_cast<value_type *>(Memory.handle())),
1111 TheSize(Memory.size()), TheOffset(0), TheSpanHandle(nullptr) {}
1112
~DeviceMemorySpan()1113 ~DeviceMemorySpan() {
1114 if (TheSpanHandle) {
1115 ThePlatform->rawDestroyDeviceMemorySpanHandle(
1116 const_cast<value_type *>(TheSpanHandle));
1117 }
1118 }
1119
1120 /// \name observers
1121 /// \{
length()1122 index_type length() const { return TheSize; }
size()1123 index_type size() const { return TheSize; }
byte_size()1124 index_type byte_size() const { // NOLINT
1125 return TheSize * sizeof(element_type);
1126 }
offset()1127 index_type offset() const { return TheOffset; }
byte_offset()1128 index_type byte_offset() const { // NOLINT
1129 return TheOffset * sizeof(element_type);
1130 }
empty()1131 bool empty() const { return TheSize == 0; }
1132 /// \}
1133
baseHandle()1134 void *baseHandle() const {
1135 return static_cast<void *>(const_cast<value_type *>(TheHandle));
1136 }
1137
1138 /// Casts to a host memory pointer.
1139 ///
1140 /// This is only guaranteed to make sense for the CUDA platform, where device
1141 /// pointers can be stored and manipulated much like host pointers. This makes
1142 /// it easy to do triple-chevron kernel launches in CUDA because
1143 /// DeviceMemorySpan values can be passed to parameters expecting regular
1144 /// pointers.
1145 ///
1146 /// If the CUDA platform is using unified memory, it may also be possible to
1147 /// dereference this pointer on the host.
1148 ///
1149 /// For platforms other than CUDA, this may return a garbage pointer.
1150 operator element_type *() const {
1151 if (!TheSpanHandle)
1152 TheSpanHandle = ThePlatform->getDeviceMemorySpanHandle(
1153 TheHandle, TheSize * sizeof(element_type),
1154 TheOffset * sizeof(element_type));
1155 return TheSpanHandle;
1156 }
1157
first(index_type Count)1158 DeviceMemorySpan<element_type> first(index_type Count) const {
1159 bool Valid = Count >= 0 && Count <= TheSize;
1160 if (!Valid)
1161 std::terminate();
1162 return DeviceMemorySpan<element_type>(ThePlatform, TheHandle, Count,
1163 TheOffset);
1164 }
1165
last(index_type Count)1166 DeviceMemorySpan<element_type> last(index_type Count) const {
1167 bool Valid = Count >= 0 && Count <= TheSize;
1168 if (!Valid)
1169 std::terminate();
1170 return DeviceMemorySpan<element_type>(ThePlatform, TheHandle, Count,
1171 TheOffset + TheSize - Count);
1172 }
1173
1174 DeviceMemorySpan<element_type>
1175 subspan(index_type Offset, index_type Count = dynamic_extent) const {
1176 bool Valid =
1177 (Offset == 0 || (Offset > 0 && Offset <= TheSize)) &&
1178 (Count == dynamic_extent || (Count >= 0 && Offset + Count <= TheSize));
1179 if (!Valid)
1180 std::terminate();
1181 return DeviceMemorySpan<element_type>(ThePlatform, TheHandle, Count,
1182 TheOffset + Offset);
1183 }
1184
1185 private:
1186 template <typename T> friend class DeviceMemory;
1187 template <typename T> friend class DeviceMemorySpan;
1188 friend class Platform;
1189
DeviceMemorySpan(Platform * ThePlatform,pointer AHandle,index_type Size,index_type Offset)1190 DeviceMemorySpan(Platform *ThePlatform, pointer AHandle, index_type Size,
1191 index_type Offset)
1192 : ThePlatform(ThePlatform), TheHandle(AHandle), TheSize(Size),
1193 TheOffset(Offset), TheSpanHandle(nullptr) {}
1194
1195 Platform *ThePlatform;
1196 pointer TheHandle;
1197 index_type TheSize;
1198 index_type TheOffset;
1199 pointer TheSpanHandle;
1200 };
1201
1202 /// Asynchronous host memory.
1203 ///
1204 /// This memory is pinned or otherwise registered in the host memory space to
1205 /// allow for asynchronous copies between it and device memory.
1206 ///
1207 /// This memory unpins/unregisters itself when it goes out of scope, but does
1208 /// not free itself.
1209 template <typename ElementType> class AsyncHostMemory {
1210 public:
1211 using value_type = ElementType;
1212 using remove_const_type = typename std::remove_const<ElementType>::type;
1213
1214 AsyncHostMemory(const AsyncHostMemory &) = delete;
1215 AsyncHostMemory &operator=(const AsyncHostMemory &) = delete;
1216 AsyncHostMemory(AsyncHostMemory &&) noexcept;
1217 AsyncHostMemory &operator=(AsyncHostMemory &&) noexcept;
1218 ~AsyncHostMemory() = default;
1219
1220 template <typename OtherElementType>
AsyncHostMemory(AsyncHostMemory<OtherElementType> && Other)1221 AsyncHostMemory(AsyncHostMemory<OtherElementType> &&Other)
1222 : ThePointer(std::move(Other.ThePointer)),
1223 TheElementCount(Other.TheElementCount) {
1224 static_assert(
1225 std::is_assignable<ElementType *, OtherElementType *>::value,
1226 "cannot assign OtherElementType pointer to ElementType pointer type");
1227 }
1228
data()1229 ElementType *data() const {
1230 return const_cast<ElementType *>(
1231 static_cast<remove_const_type *>(ThePointer.get()));
1232 }
size()1233 ptrdiff_t size() const { return TheElementCount; }
1234
1235 private:
1236 template <typename U> friend class AsyncHostMemory;
1237 friend class Platform;
AsyncHostMemory(ElementType * Pointer,ptrdiff_t ElementCount,HandleDestructor Destructor)1238 AsyncHostMemory(ElementType *Pointer, ptrdiff_t ElementCount,
1239 HandleDestructor Destructor)
1240 : ThePointer(
1241 static_cast<void *>(const_cast<remove_const_type *>(Pointer)),
1242 Destructor),
1243 TheElementCount(ElementCount) {}
1244
1245 std::unique_ptr<void, HandleDestructor> ThePointer;
1246 ptrdiff_t TheElementCount;
1247 };
1248
1249 template <typename T>
1250 AsyncHostMemory<T>::AsyncHostMemory(AsyncHostMemory &&) noexcept = default;
1251 template <typename T>
1252 AsyncHostMemory<T> &AsyncHostMemory<T>::
1253 operator=(AsyncHostMemory &&) noexcept = default;
1254
1255 /// Owned registered host memory.
1256 ///
1257 /// Like AsyncHostMemory, but this memory also frees itself in addition to
1258 /// unpinning/unregistering itself when it goes out of scope.
1259 template <typename ElementType> class OwnedAsyncHostMemory {
1260 public:
1261 using remove_const_type = typename std::remove_const<ElementType>::type;
1262
1263 OwnedAsyncHostMemory(const OwnedAsyncHostMemory &) = delete;
1264 OwnedAsyncHostMemory &operator=(const OwnedAsyncHostMemory &) = delete;
1265 OwnedAsyncHostMemory(OwnedAsyncHostMemory &&) noexcept;
1266 OwnedAsyncHostMemory &operator=(OwnedAsyncHostMemory &&) noexcept;
1267
~OwnedAsyncHostMemory()1268 ~OwnedAsyncHostMemory() {
1269 if (ThePointer.get()) {
1270 // We use placement new to construct these objects, so we have to call the
1271 // destructors explicitly.
1272 for (ptrdiff_t I = 0; I < TheElementCount; ++I)
1273 static_cast<ElementType *>(ThePointer.get())[I].~ElementType();
1274 }
1275 }
1276
get()1277 ElementType *get() const {
1278 return const_cast<ElementType *>(
1279 static_cast<remove_const_type *>(ThePointer.get()));
1280 }
1281
1282 ElementType &operator[](ptrdiff_t I) const {
1283 assert(I >= 0 && I < TheElementCount);
1284 return get()[I];
1285 }
1286
1287 private:
1288 template <typename T> friend class AsyncHostMemorySpan;
1289
1290 friend class Platform;
1291
OwnedAsyncHostMemory(void * Memory,ptrdiff_t ElementCount,HandleDestructor Destructor)1292 OwnedAsyncHostMemory(void *Memory, ptrdiff_t ElementCount,
1293 HandleDestructor Destructor)
1294 : ThePointer(Memory, Destructor), TheElementCount(ElementCount) {}
1295
1296 std::unique_ptr<void, HandleDestructor> ThePointer;
1297 ptrdiff_t TheElementCount;
1298 };
1299
1300 template <typename T>
1301 OwnedAsyncHostMemory<T>::OwnedAsyncHostMemory(
1302 OwnedAsyncHostMemory &&) noexcept = default;
1303 template <typename T>
1304 OwnedAsyncHostMemory<T> &OwnedAsyncHostMemory<T>::
1305 operator=(OwnedAsyncHostMemory &&) noexcept = default;
1306
1307 /// View into registered host memory.
1308 ///
1309 /// Like Span but for registered host memory.
1310 template <typename ElementType> class AsyncHostMemorySpan {
1311 public:
1312 /// \name constants and types
1313 /// \{
1314 using element_type = ElementType;
1315 using index_type = std::ptrdiff_t;
1316 using pointer = element_type *;
1317 using reference = element_type &;
1318 using iterator = element_type *;
1319 using const_iterator = const element_type *;
1320 using value_type = typename std::remove_const<element_type>::type;
1321 /// \}
1322
AsyncHostMemorySpan()1323 AsyncHostMemorySpan() : TheSpan() {}
1324
1325 // Intentionally implicit.
1326 template <typename OtherElementType>
AsyncHostMemorySpan(AsyncHostMemory<OtherElementType> & Memory)1327 AsyncHostMemorySpan(AsyncHostMemory<OtherElementType> &Memory)
1328 : TheSpan(Memory.data(), Memory.size()) {}
1329
1330 // Intentionally implicit.
1331 template <typename OtherElementType>
AsyncHostMemorySpan(OwnedAsyncHostMemory<OtherElementType> & Owned)1332 AsyncHostMemorySpan(OwnedAsyncHostMemory<OtherElementType> &Owned)
1333 : TheSpan(Owned.get(), Owned.TheElementCount) {}
1334
1335 // Intentionally implicit.
1336 template <typename OtherElementType>
AsyncHostMemorySpan(AsyncHostMemorySpan<OtherElementType> & ASpan)1337 AsyncHostMemorySpan(AsyncHostMemorySpan<OtherElementType> &ASpan)
1338 : TheSpan(ASpan) {}
1339
1340 // Intentionally implicit.
1341 template <typename OtherElementType>
AsyncHostMemorySpan(AsyncHostMemorySpan<OtherElementType> && Span)1342 AsyncHostMemorySpan(AsyncHostMemorySpan<OtherElementType> &&Span)
1343 : TheSpan(Span) {}
1344
1345 /// \name observers
1346 /// \{
length()1347 index_type length() const { return TheSpan.length(); }
size()1348 index_type size() const { return TheSpan.size(); }
byte_size()1349 index_type byte_size() const { // NOLINT
1350 return TheSpan.size() * sizeof(element_type);
1351 }
empty()1352 bool empty() const { return TheSpan.empty(); }
1353 /// \}
1354
data()1355 pointer data() const noexcept { return TheSpan.data(); }
1356 operator element_type *() const { return TheSpan.data(); }
1357
first(index_type Count)1358 AsyncHostMemorySpan<element_type> first(index_type Count) const {
1359 return AsyncHostMemorySpan<element_type>(TheSpan.first(Count));
1360 }
1361
last(index_type Count)1362 AsyncHostMemorySpan<element_type> last(index_type Count) const {
1363 return AsyncHostMemorySpan<element_type>(TheSpan.last(Count));
1364 }
1365
1366 AsyncHostMemorySpan<element_type>
1367 subspan(index_type Offset, index_type Count = dynamic_extent) const {
1368 return AsyncHostMemorySpan<element_type>(TheSpan.subspan(Offset, Count));
1369 }
1370
1371 private:
1372 template <typename T> friend class AsyncHostMemory;
1373
AsyncHostMemorySpan(Span<ElementType> ArraySpan)1374 explicit AsyncHostMemorySpan(Span<ElementType> ArraySpan)
1375 : TheSpan(ArraySpan) {}
1376
1377 Span<ElementType> TheSpan;
1378 };
1379
1380 } // namespace acxxel
1381
1382 #endif // ACXXEL_ACXXEL_H
1383