1import numpy as np
2
3from ..core import indexing
4from ..core.utils import Frozen, FrozenDict, close_on_error
5from ..core.variable import Variable
6from .common import (
7    BACKEND_ENTRYPOINTS,
8    AbstractDataStore,
9    BackendArray,
10    BackendEntrypoint,
11    _normalize_path,
12)
13from .file_manager import CachingFileManager
14from .locks import HDF5_LOCK, NETCDFC_LOCK, SerializableLock, combine_locks, ensure_lock
15from .store import StoreBackendEntrypoint
16
has_pynio = False
try:
    import Nio
except ModuleNotFoundError:
    pass
else:
    has_pynio = True


# PyNIO may call into the netCDF libraries internally, so it shares the HDF5
# and netCDF-C locks. NCL also gets a dedicated lock, in case it is not
# thread-safe either.
NCL_LOCK = SerializableLock()
PYNIO_LOCK = combine_locks([HDF5_LOCK, NETCDFC_LOCK, NCL_LOCK])
29
30
class NioArrayWrapper(BackendArray):
    """Lazily indexed wrapper around one variable of a PyNIO-opened file.

    Shape and dtype are read once at construction. Data is only fetched
    when the wrapper is indexed.
    """

    def __init__(self, variable_name, datastore):
        # Both attributes must be set before get_array() is called below,
        # because get_array() reads them.
        self.datastore = datastore
        self.variable_name = variable_name
        nio_var = self.get_array()
        self.shape = nio_var.shape
        self.dtype = np.dtype(nio_var.typecode())

    def get_array(self, needs_lock=True):
        """Return the PyNIO variable object named ``self.variable_name``.

        Parameters
        ----------
        needs_lock : bool, default True
            Forwarded to the datastore's file manager ``acquire`` call.
        """
        nio_file = self.datastore._manager.acquire(needs_lock)
        return nio_file.variables[self.variable_name]

    def __getitem__(self, key):
        # Only BASIC indexing support is declared; explicit_indexing_adapter
        # is responsible for mapping ``key`` onto calls to _getitem.
        return indexing.explicit_indexing_adapter(
            key,
            self.shape,
            indexing.IndexingSupport.BASIC,
            self._getitem,
        )

    def _getitem(self, key):
        with self.datastore.lock:
            # The datastore lock is already held, so acquire the file
            # without locking again.
            nio_var = self.get_array(needs_lock=False)
            # A 0-d variable indexed with an empty key is read via
            # get_value() instead of item access.
            reading_scalar = key == () and self.ndim == 0
            return nio_var.get_value() if reading_scalar else nio_var[key]
56
57
class NioDataStore(AbstractDataStore):
    """Store for accessing datasets via PyNIO"""

    def __init__(self, filename, mode="r", lock=None, **kwargs):
        """Open ``filename`` with ``Nio.open_file`` through a caching file manager.

        Parameters
        ----------
        filename : str
            Path handed to ``Nio.open_file``.
        mode : str, default "r"
            File mode forwarded to ``Nio.open_file``.
        lock : lock-like, optional
            Lock guarding file access; ``None`` selects ``PYNIO_LOCK``.
        **kwargs
            Extra keyword arguments forwarded to ``Nio.open_file``.
        """
        if lock is None:
            lock = PYNIO_LOCK
        self.lock = ensure_lock(lock)
        self._manager = CachingFileManager(
            Nio.open_file, filename, lock=lock, mode=mode, kwargs=kwargs
        )
        # xarray provides its own support for FillValue,
        # so turn off PyNIO's support for the same.
        # Opening (via self.ds) or configuring the file can fail. If it does,
        # release the manager's handle instead of leaving it behind.
        # NOTE(review): the option is set on the handle acquired here; if the
        # file manager ever reopens the file, confirm the option is reapplied.
        with close_on_error(self):
            self.ds.set_option("MaskedArrayMode", "MaskedNever")

    @property
    def ds(self):
        """The open PyNIO file object, acquired from the file manager."""
        return self._manager.acquire()

    def open_store_variable(self, name, var):
        """Wrap PyNIO variable ``var`` as a lazily indexed xarray Variable."""
        data = indexing.LazilyIndexedArray(NioArrayWrapper(name, self))
        return Variable(var.dimensions, data, var.attributes)

    def get_variables(self):
        """Return a read-only mapping of variable name to lazy Variable."""
        return FrozenDict(
            (k, self.open_store_variable(k, v)) for k, v in self.ds.variables.items()
        )

    def get_attrs(self):
        """Return the file's global attributes as a read-only mapping."""
        return Frozen(self.ds.attributes)

    def get_dimensions(self):
        """Return the file's dimensions as a read-only mapping."""
        return Frozen(self.ds.dimensions)

    def get_encoding(self):
        """Return encoding hints: the set of unlimited dimension names."""
        # Acquire the file once rather than once per dimension.
        ds = self.ds
        return {"unlimited_dims": {k for k in ds.dimensions if ds.unlimited(k)}}

    def close(self):
        """Close the underlying file through the file manager."""
        self._manager.close()
99
100
class PynioBackendEntrypoint(BackendEntrypoint):
    """Backend entrypoint that reads files through PyNIO.

    ``available`` reflects whether the optional ``Nio`` module could be
    imported when this module was loaded.
    """

    available = has_pynio

    def open_dataset(
        self,
        filename_or_obj,
        mask_and_scale=True,
        decode_times=True,
        concat_characters=True,
        decode_coords=True,
        drop_variables=None,
        use_cftime=None,
        decode_timedelta=None,
        mode="r",
        lock=None,
    ):
        """Open ``filename_or_obj`` with PyNIO and return the decoded dataset.

        The path is normalized and wrapped in a ``NioDataStore``. That store
        is then passed to ``StoreBackendEntrypoint.open_dataset``, which
        applies the decoding options. Decoding runs inside ``close_on_error``
        so a failure does not leave the store open.

        Parameters
        ----------
        filename_or_obj : str or path-like
            Location of the file to open.
        mask_and_scale, decode_times, concat_characters, decode_coords,
        drop_variables, use_cftime, decode_timedelta
            Decoding options forwarded unchanged to
            ``StoreBackendEntrypoint.open_dataset``.
        mode : str, default "r"
            File mode passed to ``NioDataStore``.
        lock : lock-like, optional
            Lock passed to ``NioDataStore``; ``None`` selects ``PYNIO_LOCK``.
        """
        path = _normalize_path(filename_or_obj)
        nio_store = NioDataStore(path, mode=mode, lock=lock)

        decoder = StoreBackendEntrypoint()
        with close_on_error(nio_store):
            return decoder.open_dataset(
                nio_store,
                mask_and_scale=mask_and_scale,
                decode_times=decode_times,
                concat_characters=concat_characters,
                decode_coords=decode_coords,
                drop_variables=drop_variables,
                use_cftime=use_cftime,
                decode_timedelta=decode_timedelta,
            )
137
138
# Register this entrypoint in the shared backend registry under the key "pynio".
BACKEND_ENTRYPOINTS["pynio"] = PynioBackendEntrypoint
140