1# ================================================================= 2# 3# Authors: Gregory Petrochenkov <gpetrochenkov@usgs.gov> 4# 5# Copyright (c) 2020 Gregory Petrochenkov 6# 7# Permission is hereby granted, free of charge, to any person 8# obtaining a copy of this software and associated documentation 9# files (the "Software"), to deal in the Software without 10# restriction, including without limitation the rights to use, 11# copy, modify, merge, publish, distribute, sublicense, and/or sell 12# copies of the Software, and to permit persons to whom the 13# Software is furnished to do so, subject to the following 14# conditions: 15# 16# The above copyright notice and this permission notice shall be 17# included in all copies or substantial portions of the Software. 18# 19# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 20# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES 21# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 22# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT 23# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 24# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 25# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 26# OTHER DEALINGS IN THE SOFTWARE. 
import os
import logging
import tempfile
import zipfile

import xarray
import numpy as np

from pygeoapi.provider.base import (BaseProvider,
                                    ProviderConnectionError,
                                    ProviderNoDataError,
                                    ProviderQueryError)
from pygeoapi.util import read_data

LOGGER = logging.getLogger(__name__)


class XarrayProvider(BaseProvider):
    """Xarray Provider"""

    def __init__(self, provider_def):
        """
        Initialize object

        :param provider_def: provider definition

        :returns: pygeoapi.provider.xarray_.XarrayProvider
        """

        super().__init__(provider_def)

        try:
            # Zarr stores need a dedicated opener; everything else goes
            # through the generic xarray dataset entry point
            if provider_def['data'].endswith('.zarr'):
                open_func = xarray.open_zarr
            else:
                open_func = xarray.open_dataset
            self._data = open_func(self.data)
            self._data = _convert_float32_to_float64(self._data)
            self._coverage_properties = self._get_coverage_properties()

            self.axes = [self._coverage_properties['x_axis_label'],
                         self._coverage_properties['y_axis_label'],
                         self._coverage_properties['time_axis_label']]

            self.fields = self._coverage_properties['fields']
        except Exception as err:
            LOGGER.warning(err)
            raise ProviderConnectionError(err)

    def get_coverage_domainset(self, *args, **kwargs):
        """
        Provide coverage domainset

        :returns: CIS JSON object of domainset metadata
        """

        c_props = self._coverage_properties
        domainset = {
            'type': 'DomainSetType',
            'generalGrid': {
                'type': 'GeneralGridCoverageType',
                'srsName': c_props['bbox_crs'],
                'axisLabels': [
                    c_props['x_axis_label'],
                    c_props['y_axis_label'],
                    c_props['time_axis_label']
                ],
                'axis': [{
                    'type': 'RegularAxisType',
                    'axisLabel': c_props['x_axis_label'],
                    'lowerBound': c_props['bbox'][0],
                    'upperBound': c_props['bbox'][2],
                    'uomLabel': c_props['bbox_units'],
                    'resolution': c_props['resx']
                }, {
                    'type': 'RegularAxisType',
                    'axisLabel': c_props['y_axis_label'],
                    'lowerBound': c_props['bbox'][1],
                    'upperBound': c_props['bbox'][3],
                    'uomLabel': c_props['bbox_units'],
                    'resolution': c_props['resy']
                }, {
                    'type': 'RegularAxisType',
                    'axisLabel': c_props['time_axis_label'],
                    'lowerBound': c_props['time_range'][0],
                    'upperBound': c_props['time_range'][1],
                    'uomLabel': c_props['restime'],
                    'resolution': c_props['restime']
                }],
                'gridLimits': {
                    'type': 'GridLimitsType',
                    'srsName': 'http://www.opengis.net/def/crs/OGC/0/Index2D',
                    'axisLabels': ['i', 'j'],
                    'axis': [{
                        'type': 'IndexAxisType',
                        'axisLabel': 'i',
                        'lowerBound': 0,
                        'upperBound': c_props['width']
                    }, {
                        'type': 'IndexAxisType',
                        'axisLabel': 'j',
                        'lowerBound': 0,
                        'upperBound': c_props['height']
                    }]
                }
            },
            '_meta': {
                'tags': self._data.attrs
            }
        }

        return domainset

    def get_coverage_rangetype(self, *args, **kwargs):
        """
        Provide coverage rangetype

        :returns: CIS JSON object of rangetype metadata
        """

        rangetype = {
            'type': 'DataRecordType',
            'field': []
        }

        for name, var in self._data.variables.items():
            LOGGER.debug('Determining rangetype for {}'.format(name))

            desc, units = None, None
            # only data variables (>= 3 dims, matching self.fields) become
            # range fields; coordinate variables belong to the domain
            if len(var.shape) >= 3:
                parameter = self._get_parameter_metadata(
                    name, var.attrs)
                desc = parameter['description']
                units = parameter['unit_label']

                rangetype['field'].append({
                    'id': name,
                    'type': 'QuantityType',
                    'name': var.attrs.get('long_name') or desc,
                    'definition': str(var.dtype),
                    'nodata': 'null',
                    'uom': {
                        'id': 'http://www.opengis.net/def/uom/UCUM/{}'.format(
                            units),
                        'type': 'UnitReference',
                        'code': units
                    },
                    '_meta': {
                        'tags': var.attrs
                    }
                })

        return rangetype

    def query(self, range_subset=None, subsets=None, bbox=None,
              datetime_=None, format_='json', **kwargs):
        """
        Extract data from collection

        :param range_subset: list of data variables to return (all if blank)
        :param subsets: dict of subset names with lists of ranges
        :param bbox: bounding box [minx,miny,maxx,maxy]
        :param datetime_: temporal (datestamp or extent)
        :param format_: data format of output

        :returns: coverage data as dict of CoverageJSON or native format
        """

        # normalize defaults here instead of using mutable default
        # arguments ([] / {}), which are shared across calls
        range_subset = [] if range_subset is None else range_subset
        subsets = {} if subsets is None else subsets
        bbox = [] if bbox is None else bbox

        if not range_subset and not subsets and format_ != 'json':
            LOGGER.debug('No parameters specified, returning native data')
            if format_ == 'zarr':
                return _get_zarr_data(self._data)
            else:
                return read_data(self.data)

        if len(range_subset) < 1:
            range_subset = self.fields

        data = self._data[[*range_subset]]

        if any([self._coverage_properties['x_axis_label'] in subsets,
                self._coverage_properties['y_axis_label'] in subsets,
                self._coverage_properties['time_axis_label'] in subsets,
                datetime_ is not None]):

            LOGGER.debug('Creating spatio-temporal subset')

            query_params = {}
            for key, val in subsets.items():
                LOGGER.debug('Processing subset: {}'.format(key))
                # coordinates stored high-to-low need a reversed slice
                if data.coords[key].values[0] > data.coords[key].values[-1]:
                    LOGGER.debug('Reversing slicing from high to low')
                    query_params[key] = slice(val[1], val[0])
                else:
                    query_params[key] = slice(val[0], val[1])

            if bbox:
                if all([self._coverage_properties['x_axis_label'] in subsets,
                        self._coverage_properties['y_axis_label'] in subsets,
                        len(bbox) > 0]):
                    msg = 'bbox and subsetting by coordinates are exclusive'
                    LOGGER.warning(msg)
                    raise ProviderQueryError(msg)
                else:
                    query_params[self._coverage_properties['x_axis_label']] = \
                        slice(bbox[0], bbox[2])
                    query_params[self._coverage_properties['y_axis_label']] = \
                        slice(bbox[1], bbox[3])

            if datetime_ is not None:
                if self._coverage_properties['time_axis_label'] in subsets:
                    msg = 'datetime and temporal subsetting are exclusive'
                    LOGGER.error(msg)
                    raise ProviderQueryError(msg)
                else:
                    if '/' in datetime_:
                        begin, end = datetime_.split('/')
                        if begin < end:
                            query_params[self.time_field] = slice(begin, end)
                        else:
                            LOGGER.debug('Reversing slicing from high to low')
                            query_params[self.time_field] = slice(end, begin)
                    else:
                        query_params[self.time_field] = datetime_

            LOGGER.debug('Query parameters: {}'.format(query_params))
            try:
                data = data.sel(query_params)
            except Exception as err:
                LOGGER.warning(err)
                raise ProviderQueryError(err)

        if (any([data.coords[self.x_field].size == 0,
                 data.coords[self.y_field].size == 0,
                 data.coords[self.time_field].size == 0])):
            msg = 'No data found'
            LOGGER.warning(msg)
            raise ProviderNoDataError(msg)

        out_meta = {
            'bbox': [
                data.coords[self.x_field].values[0],
                data.coords[self.y_field].values[0],
                data.coords[self.x_field].values[-1],
                data.coords[self.y_field].values[-1]
            ],
            "time": [
                _to_datetime_string(data.coords[self.time_field].values[0]),
                _to_datetime_string(data.coords[self.time_field].values[-1])
            ],
            "driver": "xarray",
            "height": data.dims[self.y_field],
            "width": data.dims[self.x_field],
            "time_steps": data.dims[self.time_field],
            "variables": {var_name: var.attrs
                          for var_name, var in data.variables.items()}
        }

        LOGGER.debug('Serializing data in memory')
        if format_ == 'json':
            LOGGER.debug('Creating output in CoverageJSON')
            return self.gen_covjson(out_meta, data, range_subset)
        elif format_ == 'zarr':
            LOGGER.debug('Returning data in native zarr format')
            return _get_zarr_data(data)
        else:  # return data in native format
            with tempfile.TemporaryFile() as fp:
                LOGGER.debug('Returning data in native NetCDF format')
                fp.write(data.to_netcdf())
                fp.seek(0)
                return fp.read()

    def gen_covjson(self, metadata, data, range_type):
        """
        Generate coverage as CoverageJSON representation

        :param metadata: coverage metadata
        :param data: xarray Dataset of subset coverage data
        :param range_type: range type list

        :returns: dict of CoverageJSON representation
        """

        LOGGER.debug('Creating CoverageJSON domain')
        minx, miny, maxx, maxy = metadata['bbox']
        mint, maxt = metadata['time']

        # a fully-subset axis may be 0-d; fall back to the scalar value
        try:
            tmp_min = data.coords[self.y_field].values[0]
        except IndexError:
            tmp_min = data.coords[self.y_field].values
        try:
            tmp_max = data.coords[self.y_field].values[-1]
        except IndexError:
            tmp_max = data.coords[self.y_field].values

        if tmp_min > tmp_max:
            LOGGER.debug('Reversing direction of {}'.format(self.y_field))
            miny = tmp_max
            maxy = tmp_min

        cj = {
            'type': 'Coverage',
            'domain': {
                'type': 'Domain',
                'domainType': 'Grid',
                'axes': {
                    'x': {
                        'start': minx,
                        'stop': maxx,
                        'num': metadata['width']
                    },
                    'y': {
                        'start': maxy,
                        'stop': miny,
                        'num': metadata['height']
                    },
                    self.time_field: {
                        'start': mint,
                        'stop': maxt,
                        'num': metadata['time_steps']
                    }
                },
                'referencing': [{
                    'coordinates': ['x', 'y'],
                    'system': {
                        'type': self._coverage_properties['crs_type'],
                        'id': self._coverage_properties['bbox_crs']
                    }
                }]
            },
            'parameters': {},
            'ranges': {}
        }

        for variable in range_type:
            pm = self._get_parameter_metadata(
                variable, self._data[variable].attrs)

            parameter = {
                'type': 'Parameter',
                'description': pm['description'],
                'unit': {
                    'symbol': pm['unit_label']
                },
                'observedProperty': {
                    'id': pm['observed_property_id'],
                    'label': {
                        'en': pm['observed_property_name']
                    }
                }
            }

            cj['parameters'][pm['id']] = parameter

        try:
            # fill missing values once, before serializing the ranges
            # (was re-run on every loop iteration)
            data = data.fillna(None)
            for key in cj['parameters'].keys():
                cj['ranges'][key] = {
                    'type': 'NdArray',
                    # dtype of the variable being serialized; previously
                    # this read the stale loop variable `variable`, giving
                    # every range the dtype of the last parameter
                    'dataType': str(self._data[key].dtype),
                    'axisNames': [
                        'y', 'x', self._coverage_properties['time_axis_label']
                    ],
                    'shape': [metadata['height'],
                              metadata['width'],
                              metadata['time_steps']]
                }
                cj['ranges'][key]['values'] = data[key].values.flatten().tolist()  # noqa
        except IndexError as err:
            LOGGER.warning(err)
            raise ProviderQueryError('Invalid query parameter')

        return cj

    def _get_coverage_properties(self):
        """
        Helper function to normalize coverage properties

        :returns: `dict` of coverage properties
        """

        time_var, y_var, x_var = [None, None, None]
        for coord in self._data.coords:
            if coord.lower() == 'time':
                time_var = coord
                continue
            # .get avoids a KeyError on coordinates without a units attr
            if self._data.coords[coord].attrs.get('units') == 'degrees_north':
                y_var = coord
                continue
            if self._data.coords[coord].attrs.get('units') == 'degrees_east':
                x_var = coord
                continue

        # honour axis names from the provider definition when supplied
        if self.x_field is None:
            self.x_field = x_var
        if self.y_field is None:
            self.y_field = y_var
        if self.time_field is None:
            self.time_field = time_var

        # It would be preferable to use CF attributes to get width
        # resolution etc but for now a generic approach is used to asess
        # all of the attributes based on lat lon vars

        properties = {
            'bbox': [
                self._data.coords[self.x_field].values[0],
                self._data.coords[self.y_field].values[0],
                self._data.coords[self.x_field].values[-1],
                self._data.coords[self.y_field].values[-1],
            ],
            'time_range': [
                _to_datetime_string(
                    self._data.coords[self.time_field].values[0]
                ),
                _to_datetime_string(
                    self._data.coords[self.time_field].values[-1]
                )
            ],
            'bbox_crs': 'http://www.opengis.net/def/crs/OGC/1.3/CRS84',
            'crs_type': 'GeographicCRS',
            'x_axis_label': self.x_field,
            'y_axis_label': self.y_field,
            'time_axis_label': self.time_field,
            'width': self._data.dims[self.x_field],
            'height': self._data.dims[self.y_field],
            'time': self._data.dims[self.time_field],
            'time_duration': self.get_time_coverage_duration(),
            'bbox_units': 'degrees',
            'resx': np.abs(self._data.coords[self.x_field].values[1]
                           - self._data.coords[self.x_field].values[0]),
            'resy': np.abs(self._data.coords[self.y_field].values[1]
                           - self._data.coords[self.y_field].values[0]),
            'restime': self.get_time_resolution()
        }

        if 'crs' in self._data.variables.keys():
            properties['bbox_crs'] = '{}/{}'.format(
                'http://www.opengis.net/def/crs/OGC/1.3/',
                self._data.crs.epsg_code)

            properties['inverse_flattening'] = self._data.crs.\
                inverse_flattening

            properties['crs_type'] = 'ProjectedCRS'

        properties['axes'] = [
            properties['x_axis_label'],
            properties['y_axis_label'],
            properties['time_axis_label']
        ]

        properties['fields'] = [name for name in self._data.variables
                                if len(self._data.variables[name].shape) >= 3]

        return properties

    @staticmethod
    def _get_parameter_metadata(name, attrs):
        """
        Helper function to derive parameter name and units

        :param name: name of variable
        :param attrs: dictionary of variable attributes

        :returns: dict of parameter metadata
        """

        return {
            'id': name,
            'description': attrs.get('long_name', None),
            'unit_label': attrs.get('units', None),
            'unit_symbol': attrs.get('units', None),
            'observed_property_id': name,
            'observed_property_name': attrs.get('long_name', None)
        }

    def get_time_resolution(self):
        """
        Helper function to derive time resolution

        :returns: time resolution string, or None for a single time step
        """

        if self._data[self.time_field].size > 1:
            time_diff = (self._data[self.time_field][1] -
                         self._data[self.time_field][0])

            # express the step in successively finer units and report the
            # coarsest one with a non-zero count
            dt = np.array([time_diff.values.astype('timedelta64[{}]'.format(x))
                           for x in ['Y', 'M', 'D', 'h', 'm', 's', 'ms']])

            # np.int was removed in NumPy 1.24; cast explicitly instead
            return str(dt[np.array([x.astype('int64') for x in dt]) > 0][0])
        else:
            return None

    def get_time_coverage_duration(self):
        """
        Helper function to derive time coverage duration

        :returns: time coverage duration string
        """

        dur = self._data[self.time_field][-1] - self._data[self.time_field][0]
        ms_difference = dur.values.astype('timedelta64[ms]').astype(np.double)

        time_dict = {
            'days': int(ms_difference / 1000 / 60 / 60 / 24),
            'hours': int((ms_difference / 1000 / 60 / 60) % 24),
            'minutes': int((ms_difference / 1000 / 60) % 60),
            'seconds': int(ms_difference / 1000) % 60
        }

        times = ['{} {}'.format(val, key) for key, val
                 in time_dict.items() if val > 0]

        return ', '.join(times)


def _to_datetime_string(datetime_obj):
    """
    Convenience function to formulate string from various datetime objects

    :param datetime_obj: datetime object (native datetime, cftime)

    :returns: str representation of datetime
    """

    try:
        value = np.datetime_as_string(datetime_obj)
    except Exception as err:
        # not a numpy datetime64 (e.g. cftime); fall back to strftime
        LOGGER.warning(err)
        value = datetime_obj.strftime('%Y-%m-%dT%H:%M:%S.%fZ')

    return value


def _zip_dir(path, ziph, cwd):
    """
    Convenience function to zip directory with sub directories
    (based on source: https://stackoverflow.com/questions/1855095/)

    :param path: str directory to zip
    :param ziph: zipfile file
    :param cwd: current working directory
    """
    # NOTE: relies on os.chdir to control archive entry names, so the
    # process working directory is mutated and restored per file
    for root, dirs, files in os.walk(path):
        for file in files:

            if len(dirs) < 1:
                new_root = '/'.join(root.split('/')[:-1])
                new_path = os.path.join(root.split('/')[-1], file)
            else:
                new_root = root
                new_path = file

            os.chdir(new_root)
            ziph.write(new_path)
            os.chdir(cwd)


def _get_zarr_data(data):
    """
    Returns bytes to read from Zarr directory zip

    :param data: Xarray dataset of coverage data

    :returns: byte array of zip data
    """

    # mkdtemp keeps the directory alive; tempfile.TemporaryDirectory().name
    # was garbage-collected immediately, deleting the directory out from
    # under us, and the un-joined paths leaked sibling files in the temp dir
    tmp_dir = tempfile.mkdtemp()
    zarr_path = os.path.join(tmp_dir, 'zarr.zarr')
    zip_path = '{}.zip'.format(zarr_path)

    data.to_zarr(zarr_path, mode='w')
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        _zip_dir(zarr_path, zipf, os.getcwd())
    # close the zip file handle instead of leaking it
    with open(zip_path, 'rb') as zip_file:
        return zip_file.read()


def _convert_float32_to_float64(data):
    """
    Converts DataArray values of float32 to float64

    :param data: Xarray dataset of coverage data

    :returns: Xarray dataset of coverage data
    """

    for var_name in data.variables:
        if data[var_name].dtype == 'float32':
            # astype drops attributes; preserve and restore them
            og_attrs = data[var_name].attrs
            data[var_name] = data[var_name].astype('float64')
            data[var_name].attrs = og_attrs

    return data