1import numba
2import numba.cuda
3import numpy
4import rmm
5
6from .cuda import cuda_deserialize, cuda_serialize
7from .serialize import dask_deserialize, dask_serialize
8
# `rmm.DeviceBuffer` only exists in RMM 0.11.0+. Register these serializers
# only when it is available; otherwise the Numba serializers are used.
if hasattr(rmm, "DeviceBuffer"):

    @cuda_serialize.register(rmm.DeviceBuffer)
    def cuda_serialize_rmm_device_buffer(x):
        """Serialize an ``rmm.DeviceBuffer`` for CUDA (device-side) transfer.

        The header is a copy of the buffer's ``__cuda_array_interface__`` with
        ``strides`` forced to ``(1,)``. The buffer object itself is passed
        through, uncopied, as the single frame.
        """
        # Work on a copy so the dict returned by the interface is not mutated.
        header = x.__cuda_array_interface__.copy()
        header["strides"] = (1,)
        frames = [x]
        return header, frames

    @cuda_deserialize.register(rmm.DeviceBuffer)
    def cuda_deserialize_rmm_device_buffer(header, frames):
        """Return the ``rmm.DeviceBuffer`` carried by CUDA-serialized frames.

        ``frames`` must contain exactly one item, which is returned unchanged;
        ``header`` is not used. Unpacking raises ``ValueError`` if ``frames``
        does not hold exactly one item.
        """
        (arr,) = frames

        # We should already have a `DeviceBuffer`, as RMM is used
        # preferentially for allocations when it is available (as it is here).
        assert isinstance(arr, rmm.DeviceBuffer)

        return arr

    @dask_serialize.register(rmm.DeviceBuffer)
    def dask_serialize_rmm_device_buffer(x):
        """Serialize an ``rmm.DeviceBuffer`` into host memory.

        Reuses the CUDA serializer's header. Each frame is copied device-to-host
        through Numba, and the resulting host array's memoryview (``.data``) is
        emitted in its place.
        """
        header, frames = cuda_serialize_rmm_device_buffer(x)
        frames = [numba.cuda.as_cuda_array(f).copy_to_host().data for f in frames]
        return header, frames

    @dask_deserialize.register(rmm.DeviceBuffer)
    def dask_deserialize_rmm_device_buffer(header, frames):
        """Rebuild an ``rmm.DeviceBuffer`` from a single host-memory frame.

        ``frames`` must contain exactly one bytes-like object; ``header`` is
        not used. A new ``rmm.DeviceBuffer`` is constructed from the frame's
        host address and size in bytes.
        """
        (frame,) = frames

        arr = numpy.asarray(memoryview(frame))
        # The buffer is described by a start address plus a byte count below,
        # which only matches the payload when the bytes are contiguous in
        # memory. A strided view would otherwise copy the wrong bytes, so
        # compact it first. Contiguous inputs are used as-is, with no copy.
        if not (arr.flags.c_contiguous or arr.flags.f_contiguous):
            arr = numpy.ascontiguousarray(arr)
        ptr = arr.__array_interface__["data"][0]
        size = arr.nbytes

        # `arr` stays bound through construction; for a compacted copy it is
        # the only owner of the bytes at `ptr`.
        # NOTE(review): assumes rmm finishes reading `ptr` before the
        # constructor returns — confirm for the rmm version in use.
        buf = rmm.DeviceBuffer(ptr=ptr, size=size)

        return buf
47