1import sparse
2
3from .serialize import dask_deserialize, dask_serialize, deserialize, serialize
4
5
@dask_serialize.register(sparse.COO)
def serialize_sparse(x):
    """Serialize a ``sparse.COO`` array into a header and a flat frame list.

    ``x.coords`` and ``x.data`` are each passed through ``serialize``. The
    two resulting sub-headers are nested in the returned header, and the two
    frame sequences are concatenated with the coords frames first.

    Returns
    -------
    header : dict
        Holds ``"coords-header"``, ``"data-header"``, ``"shape"`` and
        ``"nframes"`` (the frame counts for coords and data, in that order).
    frames
        The coords frames followed by the data frames.

    NOTE(review): only coords, data and shape are carried; any other state on
    the array (e.g. a non-default fill value) is not recorded -- confirm this
    is acceptable.
    """
    coords_meta, coords_payload = serialize(x.coords)
    data_meta, data_payload = serialize(x.data)

    header = {"coords-header": coords_meta, "data-header": data_meta}
    header["shape"] = x.shape
    # Per-part frame counts let the receiver split the concatenated frames.
    header["nframes"] = [len(coords_payload), len(data_payload)]

    return header, coords_payload + data_payload
18
19
@dask_deserialize.register(sparse.COO)
def deserialize_sparse(header, frames):
    """Rebuild a ``sparse.COO`` array from a header and its frame list.

    The first ``header["nframes"][0]`` entries of ``frames`` are treated as
    the coords frames and everything after them as the data frames. Each
    group is passed to ``deserialize`` with its matching sub-header, and the
    results are handed to ``sparse.COO`` together with ``header["shape"]``.
    """
    # Index at which the coords frames end and the data frames begin.
    split = header["nframes"][0]

    coords = deserialize(header["coords-header"], frames[:split])
    data = deserialize(header["data-header"], frames[split:])

    return sparse.COO(coords, data, shape=header["shape"])
32