1# =================================================================
2#
3# Authors: Gregory Petrochenkov <gpetrochenkov@usgs.gov>
4#
5# Copyright (c) 2020 Gregory Petrochenkov
6#
7# Permission is hereby granted, free of charge, to any person
8# obtaining a copy of this software and associated documentation
9# files (the "Software"), to deal in the Software without
10# restriction, including without limitation the rights to use,
11# copy, modify, merge, publish, distribute, sublicense, and/or sell
12# copies of the Software, and to permit persons to whom the
13# Software is furnished to do so, subject to the following
14# conditions:
15#
16# The above copyright notice and this permission notice shall be
17# included in all copies or substantial portions of the Software.
18#
19# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
20# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
21# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
22# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
23# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
24# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
25# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
26# OTHER DEALINGS IN THE SOFTWARE.
27#
28# =================================================================
29
30import os
31import logging
32import tempfile
33import zipfile
34
35import xarray
36import numpy as np
37
38from pygeoapi.provider.base import (BaseProvider,
39                                    ProviderConnectionError,
40                                    ProviderNoDataError,
41                                    ProviderQueryError)
42from pygeoapi.util import read_data
43
44LOGGER = logging.getLogger(__name__)
45
46
class XarrayProvider(BaseProvider):
    """Xarray Provider"""

    def __init__(self, provider_def):
        """
        Initialize object

        Opens the dataset (``xarray.open_zarr`` for ``.zarr`` stores,
        ``xarray.open_dataset`` otherwise), upcasts float32 variables to
        float64 and derives the coverage properties used by the other
        methods.

        :param provider_def: provider definition

        :returns: pygeoapi.provider.xarray_.XarrayProvider
        :raises ProviderConnectionError: if the dataset cannot be opened or
                                         its coverage properties derived
        """

        super().__init__(provider_def)

        try:
            # a trailing separator on a Zarr store path still means Zarr
            if provider_def['data'].rstrip('/').endswith('.zarr'):
                open_func = xarray.open_zarr
            else:
                open_func = xarray.open_dataset
            self._data = open_func(self.data)
            self._data = _convert_float32_to_float64(self._data)
            self._coverage_properties = self._get_coverage_properties()

            self.axes = [self._coverage_properties['x_axis_label'],
                         self._coverage_properties['y_axis_label'],
                         self._coverage_properties['time_axis_label']]

            self.fields = self._coverage_properties['fields']
        except Exception as err:
            LOGGER.warning(err)
            raise ProviderConnectionError(err) from err

    def get_coverage_domainset(self, *args, **kwargs):
        """
        Provide coverage domainset

        :returns: CIS JSON object of domainset metadata
        """

        c_props = self._coverage_properties
        domainset = {
            'type': 'DomainSetType',
            'generalGrid': {
                'type': 'GeneralGridCoverageType',
                'srsName': c_props['bbox_crs'],
                'axisLabels': [
                    c_props['x_axis_label'],
                    c_props['y_axis_label'],
                    c_props['time_axis_label']
                ],
                'axis': [{
                    'type': 'RegularAxisType',
                    'axisLabel': c_props['x_axis_label'],
                    'lowerBound': c_props['bbox'][0],
                    'upperBound': c_props['bbox'][2],
                    'uomLabel': c_props['bbox_units'],
                    'resolution': c_props['resx']
                }, {
                    'type': 'RegularAxisType',
                    'axisLabel': c_props['y_axis_label'],
                    'lowerBound': c_props['bbox'][1],
                    'upperBound': c_props['bbox'][3],
                    'uomLabel': c_props['bbox_units'],
                    'resolution': c_props['resy']
                }, {
                    'type': 'RegularAxisType',
                    'axisLabel': c_props['time_axis_label'],
                    'lowerBound': c_props['time_range'][0],
                    'upperBound': c_props['time_range'][1],
                    'uomLabel': c_props['restime'],
                    'resolution': c_props['restime']
                }],
                'gridLimits': {
                    'type': 'GridLimitsType',
                    'srsName': 'http://www.opengis.net/def/crs/OGC/0/Index2D',
                    'axisLabels': ['i', 'j'],
                    'axis': [{
                        'type': 'IndexAxisType',
                        'axisLabel': 'i',
                        'lowerBound': 0,
                        'upperBound': c_props['width']
                    }, {
                        'type': 'IndexAxisType',
                        'axisLabel': 'j',
                        'lowerBound': 0,
                        'upperBound': c_props['height']
                    }]
                }
            },
            '_meta': {
                'tags': self._data.attrs
            }
        }

        return domainset

    def get_coverage_rangetype(self, *args, **kwargs):
        """
        Provide coverage rangetype

        One field is listed per variable with at least three dimensions
        (the same rule used to build ``self.fields``).

        :returns: CIS JSON object of rangetype metadata
        """

        rangetype = {
            'type': 'DataRecordType',
            'field': []
        }

        for name, var in self._data.variables.items():
            LOGGER.debug('Determining rangetype for {}'.format(name))

            if len(var.shape) < 3:
                continue

            parameter = self._get_parameter_metadata(name, var.attrs)
            desc = parameter['description']
            units = parameter['unit_label']

            rangetype['field'].append({
                'id': name,
                'type': 'QuantityType',
                'name': var.attrs.get('long_name') or desc,
                'definition': str(var.dtype),
                'nodata': 'null',
                'uom': {
                    'id': 'http://www.opengis.net/def/uom/UCUM/{}'.format(
                        units),
                    'type': 'UnitReference',
                    'code': units
                },
                '_meta': {
                    'tags': var.attrs
                }
            })

        return rangetype

    def query(self, range_subset=None, subsets=None, bbox=None,
              datetime_=None, format_='json', **kwargs):
        """
        Extract data from collection

        :param range_subset: list of data variables to return (all if blank)
        :param subsets: dict of subset names with ``[low, high]`` ranges
        :param bbox: bounding box [minx,miny,maxx,maxy]
        :param datetime_: temporal (datestamp or ``begin/end`` extent; a
                          ``..`` or empty bound leaves that end open)
        :param format_: data format of output

        :returns: coverage data as dict of CoverageJSON or native format
        :raises ProviderQueryError: on invalid or conflicting parameters
        :raises ProviderNoDataError: if the selection is empty
        """

        # None defaults avoid sharing one mutable object between calls
        range_subset = range_subset or []
        subsets = subsets or {}
        bbox = bbox or []

        no_subsetting = (not range_subset and not subsets and not bbox
                         and datetime_ is None)
        if no_subsetting and format_ != 'json':
            LOGGER.debug('No parameters specified, returning native data')
            if format_ == 'zarr':
                return _get_zarr_data(self._data)
            return read_data(self.data)

        if not range_subset:
            range_subset = self.fields

        try:
            data = self._data[[*range_subset]]
        except KeyError as err:
            msg = 'Invalid range subset: {}'.format(err)
            LOGGER.warning(msg)
            raise ProviderQueryError(msg) from err

        c_props = self._coverage_properties
        if any([c_props['x_axis_label'] in subsets,
                c_props['y_axis_label'] in subsets,
                c_props['time_axis_label'] in subsets,
                bool(bbox),
                datetime_ is not None]):

            LOGGER.debug('Creating spatio-temporal subset')
            query_params = self._build_query_params(
                data, subsets, bbox, datetime_)

            LOGGER.debug('Query parameters: {}'.format(query_params))
            try:
                data = data.sel(query_params)
            except Exception as err:
                LOGGER.warning(err)
                raise ProviderQueryError(err) from err

            # selecting a single label (e.g. one datetime) drops that
            # dimension; restore it so the output below stays 3-D
            for axis in (self.x_field, self.y_field, self.time_field):
                if axis in data.coords and axis not in data.dims:
                    data = data.expand_dims(axis)

        if any([data.coords[self.x_field].size == 0,
                data.coords[self.y_field].size == 0,
                data.coords[self.time_field].size == 0]):
            msg = 'No data found'
            LOGGER.warning(msg)
            raise ProviderNoDataError(msg)

        out_meta = {
            'bbox': [
                data.coords[self.x_field].values[0],
                data.coords[self.y_field].values[0],
                data.coords[self.x_field].values[-1],
                data.coords[self.y_field].values[-1]
            ],
            'time': [
                _to_datetime_string(data.coords[self.time_field].values[0]),
                _to_datetime_string(data.coords[self.time_field].values[-1])
            ],
            'driver': 'xarray',
            'height': data.sizes[self.y_field],
            'width': data.sizes[self.x_field],
            'time_steps': data.sizes[self.time_field],
            'variables': {var_name: var.attrs
                          for var_name, var in data.variables.items()}
        }

        LOGGER.debug('Serializing data in memory')
        if format_ == 'json':
            LOGGER.debug('Creating output in CoverageJSON')
            return self.gen_covjson(out_meta, data, range_subset)
        elif format_ == 'zarr':
            LOGGER.debug('Returning data in native zarr format')
            return _get_zarr_data(data)
        else:  # return data in native format
            LOGGER.debug('Returning data in native NetCDF format')
            # to_netcdf() without a path serializes in memory; bytes()
            # guarantees a plain bytes result for the caller
            return bytes(data.to_netcdf())

    def _build_query_params(self, data, subsets, bbox, datetime_):
        """
        Translate query arguments into ``Dataset.sel`` indexers

        :param data: xarray Dataset being subset
        :param subsets: dict of axis names with ``[low, high]`` ranges
        :param bbox: bounding box [minx,miny,maxx,maxy] (may be empty)
        :param datetime_: temporal (datestamp or extent) or ``None``

        :returns: `dict` of axis name to label or `slice`
        :raises ProviderQueryError: on unknown axes, a malformed bbox or
                                    conflicting parameters
        """

        x_label = self._coverage_properties['x_axis_label']
        y_label = self._coverage_properties['y_axis_label']
        time_label = self._coverage_properties['time_axis_label']

        query_params = {}
        for key, val in subsets.items():
            LOGGER.debug('Processing subset: {}'.format(key))
            if key not in data.coords:
                msg = 'Invalid subset axis: {}'.format(key)
                LOGGER.warning(msg)
                raise ProviderQueryError(msg)
            query_params[key] = self._make_slice(
                data.coords[key].values, val[0], val[1])

        if bbox:
            # a bbox would silently overwrite either subset axis
            if x_label in subsets or y_label in subsets:
                msg = 'bbox and subsetting by coordinates are exclusive'
                LOGGER.warning(msg)
                raise ProviderQueryError(msg)
            if len(bbox) != 4:
                msg = 'bbox must have 4 values: minx,miny,maxx,maxy'
                LOGGER.warning(msg)
                raise ProviderQueryError(msg)
            query_params[x_label] = self._make_slice(
                data.coords[x_label].values, bbox[0], bbox[2])
            query_params[y_label] = self._make_slice(
                data.coords[y_label].values, bbox[1], bbox[3])

        if datetime_ is not None:
            if time_label in subsets:
                msg = 'datetime and temporal subsetting are exclusive'
                LOGGER.error(msg)
                raise ProviderQueryError(msg)
            query_params[self.time_field] = self._datetime_selector(datetime_)

        return query_params

    @staticmethod
    def _make_slice(coord_values, start, end):
        """
        Build a label slice running in the direction of a coordinate axis

        Label slicing only matches when the slice follows the order of the
        coordinate values, so the bounds are swapped for descending axes.

        :param coord_values: coordinate values of the axis
        :param start: lower bound requested
        :param end: upper bound requested

        :returns: `slice` suitable for ``Dataset.sel``
        """

        if coord_values[0] > coord_values[-1]:
            LOGGER.debug('Reversing slicing from high to low')
            return slice(end, start)
        return slice(start, end)

    @staticmethod
    def _datetime_selector(datetime_):
        """
        Convert a datetime query value into a ``Dataset.sel`` indexer

        :param datetime_: single datestamp or ``begin/end`` interval, where
                          ``..`` or an empty bound means open-ended

        :returns: the datestamp unchanged, or a `slice` for an interval
        :raises ProviderQueryError: if the interval is malformed
        """

        if '/' not in datetime_:
            return datetime_

        bounds = datetime_.split('/')
        if len(bounds) != 2:
            msg = 'Invalid datetime interval: {}'.format(datetime_)
            LOGGER.warning(msg)
            raise ProviderQueryError(msg)

        begin, end = [None if bound in ('', '..') else bound
                      for bound in bounds]
        if begin is not None and end is not None and end < begin:
            LOGGER.debug('Reversing slicing from high to low')
            begin, end = end, begin
        return slice(begin, end)

    def gen_covjson(self, metadata, data, range_type):
        """
        Generate coverage as CoverageJSON representation

        :param metadata: coverage metadata as built by ``query`` (``bbox``
                         holds the first and last x/y coordinate values,
                         ``time`` the first and last timestamps)
        :param data: xarray Dataset of the selected coverage data
        :param range_type: range type list (names of variables to encode)

        :returns: dict of CoverageJSON representation
        :raises ProviderQueryError: if range values cannot be extracted
        """

        LOGGER.debug('Creating CoverageJSON domain')
        # first/last coordinates follow the storage order of the values
        # emitted below, so each axis runs in the data's own direction
        # (ascending or descending)
        x_start, y_start, x_stop, y_stop = metadata['bbox']
        mint, maxt = metadata['time']

        cj = {
            'type': 'Coverage',
            'domain': {
                'type': 'Domain',
                'domainType': 'Grid',
                'axes': {
                    'x': {
                        'start': x_start,
                        'stop': x_stop,
                        'num': metadata['width']
                    },
                    'y': {
                        'start': y_start,
                        'stop': y_stop,
                        'num': metadata['height']
                    },
                    self.time_field: {
                        'start': mint,
                        'stop': maxt,
                        'num': metadata['time_steps']
                    }
                },
                'referencing': [{
                    'coordinates': ['x', 'y'],
                    'system': {
                        'type': self._coverage_properties['crs_type'],
                        'id': self._coverage_properties['bbox_crs']
                    }
                }]
            },
            'parameters': {},
            'ranges': {}
        }

        for variable in range_type:
            pm = self._get_parameter_metadata(
                variable, self._data[variable].attrs)

            cj['parameters'][pm['id']] = {
                'type': 'Parameter',
                'description': pm['description'],
                'unit': {
                    'symbol': pm['unit_label']
                },
                'observedProperty': {
                    'id': pm['observed_property_id'],
                    'label': {
                        'en': pm['observed_property_name']
                    }
                }
            }

        try:
            for key in cj['parameters']:
                cj['ranges'][key] = {
                    'type': 'NdArray',
                    'dataType': self._covjson_data_type(data[key].dtype),
                    'axisNames': [
                        'y', 'x', self._coverage_properties['time_axis_label']
                    ],
                    'shape': [metadata['height'],
                              metadata['width'],
                              metadata['time_steps']],
                    'values': self._ordered_values(data[key])
                }
        except IndexError as err:
            LOGGER.warning(err)
            raise ProviderQueryError('Invalid query parameter') from err

        return cj

    def _ordered_values(self, var):
        """
        Flatten a data variable for a CoverageJSON NdArray

        Values are emitted in (y, x, time) order to match the declared
        ``axisNames``, with any further dimensions last; floating-point
        NaNs become ``None`` (JSON ``null``).

        :param var: xarray DataArray of one range variable

        :returns: `list` of values
        """

        axes = [self.y_field, self.x_field, self.time_field]
        order = ([dim for dim in axes if dim in var.dims] +
                 [dim for dim in var.dims if dim not in axes])
        values = var.transpose(*order).values.flatten()
        if np.issubdtype(values.dtype, np.floating):
            values = np.where(np.isnan(values), None, values)
        return values.tolist()

    @staticmethod
    def _covjson_data_type(dtype):
        """
        Map a numpy dtype onto a CoverageJSON NdArray ``dataType``

        :param dtype: numpy dtype of a data variable

        :returns: ``'integer'``, ``'float'`` or ``'string'``
        """

        if np.issubdtype(dtype, np.integer):
            return 'integer'
        if np.issubdtype(dtype, np.floating):
            return 'float'
        return 'string'

    def _get_coverage_properties(self):
        """
        Helper function to normalize coverage properties

        Axis coordinates are detected by name (``time``, case-insensitive)
        or by their ``units`` attribute (``degrees_north`` /
        ``degrees_east``) unless ``self.time_field`` / ``self.y_field`` /
        ``self.x_field`` are already set.

        :returns: `dict` of coverage properties
        """

        time_var, y_var, x_var = None, None, None
        for coord in self._data.coords:
            # not every coordinate carries a ``units`` attribute
            units = self._data.coords[coord].attrs.get('units')
            if coord.lower() == 'time':
                time_var = coord
            elif units == 'degrees_north':
                y_var = coord
            elif units == 'degrees_east':
                x_var = coord

        if self.x_field is None:
            self.x_field = x_var
        if self.y_field is None:
            self.y_field = y_var
        if self.time_field is None:
            self.time_field = time_var

        # It would be preferable to use CF attributes to get width
        # resolution etc but for now a generic approach is used to assess
        # all of the attributes based on lat lon vars

        properties = {
            'bbox': [
                self._data.coords[self.x_field].values[0],
                self._data.coords[self.y_field].values[0],
                self._data.coords[self.x_field].values[-1],
                self._data.coords[self.y_field].values[-1],
            ],
            'time_range': [
                _to_datetime_string(
                    self._data.coords[self.time_field].values[0]
                ),
                _to_datetime_string(
                    self._data.coords[self.time_field].values[-1]
                )
            ],
            'bbox_crs': 'http://www.opengis.net/def/crs/OGC/1.3/CRS84',
            'crs_type': 'GeographicCRS',
            'x_axis_label': self.x_field,
            'y_axis_label': self.y_field,
            'time_axis_label': self.time_field,
            'width': self._data.sizes[self.x_field],
            'height': self._data.sizes[self.y_field],
            'time': self._data.sizes[self.time_field],
            'time_duration': self.get_time_coverage_duration(),
            'bbox_units': 'degrees',
            'resx': np.abs(self._data.coords[self.x_field].values[1]
                           - self._data.coords[self.x_field].values[0]),
            'resy': np.abs(self._data.coords[self.y_field].values[1]
                           - self._data.coords[self.y_field].values[0]),
            'restime': self.get_time_resolution()
        }

        if 'crs' in self._data.variables.keys():
            # no trailing slash on the base: the format adds the separator
            properties['bbox_crs'] = '{}/{}'.format(
                'http://www.opengis.net/def/crs/OGC/1.3',
                self._data.crs.epsg_code)

            properties['inverse_flattening'] = self._data.crs.\
                inverse_flattening

            properties['crs_type'] = 'ProjectedCRS'

        properties['axes'] = [
            properties['x_axis_label'],
            properties['y_axis_label'],
            properties['time_axis_label']
        ]

        properties['fields'] = [name for name in self._data.variables
                                if len(self._data.variables[name].shape) >= 3]

        return properties

    @staticmethod
    def _get_parameter_metadata(name, attrs):
        """
        Helper function to derive parameter name and units

        :param name: name of variable
        :param attrs: dictionary of variable attributes

        :returns: dict of parameter metadata
        """

        return {
            'id': name,
            'description': attrs.get('long_name', None),
            'unit_label': attrs.get('units', None),
            'unit_symbol': attrs.get('units', None),
            'observed_property_id': name,
            'observed_property_name': attrs.get('long_name', None)
        }

    def get_time_resolution(self):
        """
        Helper function to derive time resolution

        The step between the first two timestamps is expressed in the
        coarsest unit (years, months, days, hours, minutes, seconds,
        milliseconds) in which it is positive.

        :returns: time resolution string, or `None` if there are fewer than
                  two timestamps or no unit yields a positive step
        """

        if self._data[self.time_field].size <= 1:
            return None

        time_diff = (self._data[self.time_field][1] -
                     self._data[self.time_field][0])

        for unit in ['Y', 'M', 'D', 'h', 'm', 's', 'ms']:
            step = time_diff.values.astype('timedelta64[{}]'.format(unit))
            # builtin ``int``: the ``np.int`` alias was removed in NumPy 1.24
            if step.astype(int) > 0:
                return str(step)

        return None

    def get_time_coverage_duration(self):
        """
        Helper function to derive time coverage duration

        :returns: time coverage duration string (e.g. ``"2 days, 3 hours"``;
                  empty when the first and last timestamps coincide)
        """

        dur = self._data[self.time_field][-1] - self._data[self.time_field][0]
        ms_difference = dur.values.astype('timedelta64[ms]').astype(np.double)

        time_dict = {
            'days': int(ms_difference / 1000 / 60 / 60 / 24),
            'hours': int((ms_difference / 1000 / 60 / 60) % 24),
            'minutes': int((ms_difference / 1000 / 60) % 60),
            'seconds': int(ms_difference / 1000) % 60
        }

        times = ['{} {}'.format(val, key) for key, val
                 in time_dict.items() if val > 0]

        return ', '.join(times)
544
545
546def _to_datetime_string(datetime_obj):
547    """
548    Convenience function to formulate string from various datetime objects
549
550    :param datetime_obj: datetime object (native datetime, cftime)
551
552    :returns: str representation of datetime
553    """
554
555    try:
556        value = np.datetime_as_string(datetime_obj)
557    except Exception as err:
558        LOGGER.warning(err)
559        value = datetime_obj.strftime('%Y-%m-%dT%H:%M:%S.%fZ')
560
561    return value
562
563
564def _zip_dir(path, ziph, cwd):
565    """
566        Convenience function to zip directory with sub directories
567        (based on source: https://stackoverflow.com/questions/1855095/)
568        :param path: str directory to zip
569        :param ziph: zipfile file
570        :param cwd: current working directory
571
572        """
573    for root, dirs, files in os.walk(path):
574        for file in files:
575
576            if len(dirs) < 1:
577                new_root = '/'.join(root.split('/')[:-1])
578                new_path = os.path.join(root.split('/')[-1], file)
579            else:
580                new_root = root
581                new_path = file
582
583            os.chdir(new_root)
584            ziph.write(new_path)
585            os.chdir(cwd)
586
587
def _get_zarr_data(data):
    """
    Serialize a dataset to Zarr and return it as the bytes of a zip archive

    The Zarr store and the archive are written inside a temporary
    directory that is removed before returning.

    :param data: Xarray dataset of coverage data

    :returns: byte array of zip data
    """

    # keep the TemporaryDirectory alive for the whole operation; taking
    # only ``.name`` lets the directory be deleted straight away
    with tempfile.TemporaryDirectory() as tmp_dir:
        zarr_path = os.path.join(tmp_dir, 'zarr.zarr')
        zip_path = os.path.join(tmp_dir, 'zarr.zarr.zip')

        data.to_zarr(zarr_path, mode='w')
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            _zip_dir(zarr_path, zipf, os.getcwd())

        with open(zip_path, 'rb') as zip_file:
            return zip_file.read()
603
604
def _convert_float32_to_float64(data):
    """
    Upcast every float32 variable of a dataset to float64

    The dataset is modified in place (each converted variable is assigned
    back into it) and each variable's attributes are re-applied after the
    cast.

    :param data: Xarray dataset of coverage data

    :returns: the same Xarray dataset, with float32 variables as float64
    """

    for name in data.variables:
        if data[name].dtype != 'float32':
            continue

        saved_attrs = data[name].attrs
        data[name] = data[name].astype('float64')
        data[name].attrs = saved_attrs

    return data
620