1import os
2import warnings
3
4import numpy as np
5
6from ..core import indexing
7from ..core.utils import Frozen, FrozenDict, close_on_error
8from ..core.variable import Variable
9from .common import (
10    BACKEND_ENTRYPOINTS,
11    AbstractDataStore,
12    BackendArray,
13    BackendEntrypoint,
14    _normalize_path,
15)
16from .locks import SerializableLock, ensure_lock
17from .store import StoreBackendEntrypoint
18
# Optional dependency: cfgrib (a wrapper around the ecCodes C library).
# ``has_cfgrib`` records whether it could be imported; it is used as
# ``CfgribfBackendEntrypoint.available`` below.
try:
    import cfgrib

    has_cfgrib = True
# cfgrib is simply not installed: mark it unavailable without a warning.
except ModuleNotFoundError:
    has_cfgrib = False
# cfgrib throws a RuntimeError if eccodes is not installed.
# ImportError here catches the remaining import failures (ModuleNotFoundError,
# its subclass, was already handled above), so the user is warned that cfgrib
# is installed but could not be loaded.
except (ImportError, RuntimeError):
    warnings.warn(
        "Failed to load cfgrib - most likely there is a problem accessing the ecCodes library. "
        "Try `import cfgrib` to get the full error message"
    )
    has_cfgrib = False

# NOTE: ecCodes is supposed to be thread-safe in most circumstances, but access
#   is still serialized through this dedicated lock to be safe. See:
#       https://confluence.ecmwf.int/display/ECC/Frequently+Asked+Questions
#   ``CfGribDataStore`` falls back to this lock when no lock is passed in.
ECCODES_LOCK = SerializableLock()
37
38
class CfGribArrayWrapper(BackendArray):
    """Lazy array adapter around an array-like object taken from a cfgrib variable.

    Reads are performed under the owning datastore's lock.
    """

    def __init__(self, datastore, array):
        self.datastore = datastore
        # Record the metadata xarray inspects eagerly; the data itself is
        # only touched when indexed.
        self.shape, self.dtype = array.shape, array.dtype
        self.array = array

    def __getitem__(self, key):
        # Advertise BASIC indexing support; the adapter decomposes ``key``
        # accordingly before delegating to ``_getitem``.
        support = indexing.IndexingSupport.BASIC
        return indexing.explicit_indexing_adapter(
            key, self.shape, support, self._getitem
        )

    def _getitem(self, key):
        # Serialize the actual read through the datastore's lock.
        lock = self.datastore.lock
        with lock:
            return self.array[key]
54
55
class CfGribDataStore(AbstractDataStore):
    """
    Implements the ``xr.AbstractDataStore`` read-only API for a GRIB file.

    Parameters
    ----------
    filename
        Path passed to ``cfgrib.open_file``.
    lock : optional
        Lock used around array reads; ``ECCODES_LOCK`` when omitted.
    **backend_kwargs
        Forwarded unchanged to ``cfgrib.open_file``.
    """

    def __init__(self, filename, lock=None, **backend_kwargs):
        self.lock = ensure_lock(ECCODES_LOCK if lock is None else lock)
        self.ds = cfgrib.open_file(filename, **backend_kwargs)

    def open_store_variable(self, name, var):
        """Convert the cfgrib variable ``var`` into an xarray ``Variable``.

        numpy-backed data is used as-is; any other data is wrapped lazily in
        ``CfGribArrayWrapper`` so reads go through ``self.lock``.
        """
        data = (
            var.data
            if isinstance(var.data, np.ndarray)
            else indexing.LazilyIndexedArray(CfGribArrayWrapper(self, var.data))
        )

        # Start from the dataset-level encoding and record the on-disk shape.
        encoding = self.ds.encoding.copy()
        encoding.update(original_shape=var.data.shape)

        return Variable(var.dimensions, data, var.attributes, encoding)

    def get_variables(self):
        """Return a read-only mapping of variable name -> xarray ``Variable``."""
        source = self.ds.variables
        return FrozenDict(
            {key: self.open_store_variable(key, value) for key, value in source.items()}
        )

    def get_attrs(self):
        """Return the dataset's global attributes as a read-only mapping."""
        return Frozen(self.ds.attributes)

    def get_dimensions(self):
        """Return the dataset's dimensions as a read-only mapping."""
        return Frozen(self.ds.dimensions)

    def get_encoding(self):
        """Report dimensions whose size is ``None`` as ``unlimited_dims``."""
        unlimited = {
            dim for dim, size in self.get_dimensions().items() if size is None
        }
        return {"unlimited_dims": unlimited}
94
95
class CfgribfBackendEntrypoint(BackendEntrypoint):
    """Backend entrypoint that opens GRIB files through ``cfgrib``.

    ``available`` is True only if ``cfgrib`` was importable when this module
    was loaded.
    """

    available = has_cfgrib

    def guess_can_open(self, filename_or_obj):
        """Return True if ``filename_or_obj`` has a GRIB file extension.

        Inputs that ``os.path.splitext`` rejects with ``TypeError`` (anything
        that is not path-like) are reported as not openable.
        """
        try:
            _, ext = os.path.splitext(filename_or_obj)
        except TypeError:
            return False
        return ext in {".grib", ".grib2", ".grb", ".grb2"}

    def open_dataset(
        self,
        filename_or_obj,
        *,
        mask_and_scale=True,
        decode_times=True,
        concat_characters=True,
        decode_coords=True,
        drop_variables=None,
        use_cftime=None,
        decode_timedelta=None,
        lock=None,
        indexpath="{path}.{short_hash}.idx",
        filter_by_keys=None,
        read_keys=None,
        encode_cf=("parameter", "time", "geography", "vertical"),
        squeeze=True,
        time_dims=("time", "step"),
    ):
        """Open a GRIB file and decode it into an xarray Dataset.

        Parameters
        ----------
        filename_or_obj
            Path of the GRIB file; normalized with ``_normalize_path``.
        mask_and_scale, decode_times, concat_characters, decode_coords, \
drop_variables, use_cftime, decode_timedelta
            Decoding options forwarded to ``StoreBackendEntrypoint.open_dataset``.
        lock : optional
            Lock for array reads; ``CfGribDataStore`` uses ``ECCODES_LOCK``
            when this is None.
        indexpath, encode_cf, squeeze, time_dims
            Forwarded unchanged to ``cfgrib.open_file``.
        filter_by_keys : dict, optional
            Forwarded to ``cfgrib.open_file``; None means an empty dict.
        read_keys : list, optional
            Forwarded to ``cfgrib.open_file``; None means an empty list.

        Returns
        -------
        The decoded Dataset produced by ``StoreBackendEntrypoint``.
        """
        # Build fresh containers per call. A literal ``{}`` / ``[]`` default
        # would be one object shared by every call, so any mutation made
        # downstream would leak into subsequent opens.
        if filter_by_keys is None:
            filter_by_keys = {}
        if read_keys is None:
            read_keys = []

        filename_or_obj = _normalize_path(filename_or_obj)
        store = CfGribDataStore(
            filename_or_obj,
            indexpath=indexpath,
            filter_by_keys=filter_by_keys,
            read_keys=read_keys,
            encode_cf=encode_cf,
            squeeze=squeeze,
            time_dims=time_dims,
            lock=lock,
        )
        store_entrypoint = StoreBackendEntrypoint()
        # close_on_error presumably closes the store if decoding raises.
        with close_on_error(store):
            ds = store_entrypoint.open_dataset(
                store,
                mask_and_scale=mask_and_scale,
                decode_times=decode_times,
                concat_characters=concat_characters,
                decode_coords=decode_coords,
                drop_variables=drop_variables,
                use_cftime=use_cftime,
                decode_timedelta=decode_timedelta,
            )
        return ds
150
151
# Register this entrypoint class in the shared backend registry under the key "cfgrib".
BACKEND_ENTRYPOINTS["cfgrib"] = CfgribfBackendEntrypoint
153