import numba
import numba.cuda
import numpy
import rmm

from .cuda import cuda_deserialize, cuda_serialize
from .serialize import dask_deserialize, dask_serialize

# `rmm.DeviceBuffer` only exists in RMM 0.11.0+; on older RMM versions
# nothing below is registered and the Numba serializers are used instead.
if hasattr(rmm, "DeviceBuffer"):

    @cuda_serialize.register(rmm.DeviceBuffer)
    def cuda_serialize_rmm_device_buffer(x):
        """Serialize a ``DeviceBuffer`` for device-to-device transfer.

        The header is a copy of the buffer's ``__cuda_array_interface__``
        with ``strides`` forced to ``(1,)``. The buffer object itself is the
        single frame, so no data is copied here.
        """
        header = x.__cuda_array_interface__.copy()
        header.update(strides=(1,))
        return header, [x]

    @cuda_deserialize.register(rmm.DeviceBuffer)
    def cuda_deserialize_rmm_device_buffer(header, frames):
        """Return the single received frame unchanged.

        The frame is expected to already be an ``rmm.DeviceBuffer``. When RMM
        is available, it is the preferred allocator for received device
        memory, so no conversion is performed. ``header`` is unused.
        """
        (device_buffer,) = frames
        assert isinstance(device_buffer, rmm.DeviceBuffer)
        return device_buffer

    @dask_serialize.register(rmm.DeviceBuffer)
    def dask_serialize_rmm_device_buffer(x):
        """Serialize a ``DeviceBuffer`` into host memory.

        Reuses the header produced by the CUDA serializer. Each device frame
        is copied to the host through Numba, and the host array's ``.data``
        buffer is sent in its place.
        """
        header, device_frames = cuda_serialize_rmm_device_buffer(x)
        host_frames = []
        for device_frame in device_frames:
            host_array = numba.cuda.as_cuda_array(device_frame).copy_to_host()
            host_frames.append(host_array.data)
        return header, host_frames

    @dask_deserialize.register(rmm.DeviceBuffer)
    def dask_deserialize_rmm_device_buffer(header, frames):
        """Rebuild a ``DeviceBuffer`` from a single host frame.

        The frame is viewed as a NumPy array to obtain its host address and
        byte size. Both are passed to ``rmm.DeviceBuffer(ptr=..., size=...)``,
        which presumably copies the host bytes into newly allocated device
        memory (confirm against the RMM docs). ``header`` is unused.
        """
        (frame,) = frames
        # Keep `host_view` bound until the DeviceBuffer has been constructed,
        # so the memory behind `host_ptr` stays alive during the call.
        host_view = numpy.asarray(memoryview(frame))
        host_ptr, _read_only = host_view.__array_interface__["data"]
        return rmm.DeviceBuffer(ptr=host_ptr, size=host_view.nbytes)