1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3# ******************************************************************************
4#
5#  Project:  GDAL utils.auxiliary
6#  Purpose:  gdal utility functions
7#  Author:   Even Rouault <even.rouault at spatialys.com>
8#  Author:   Idan Miara <idan@miara.com>
9#
10# ******************************************************************************
11#  Copyright (c) 2015, Even Rouault <even.rouault at spatialys.com>
12#  Copyright (c) 2020, Idan Miara <idan@miara.com>
13#
14#  Permission is hereby granted, free of charge, to any person obtaining a
15#  copy of this software and associated documentation files (the "Software"),
16#  to deal in the Software without restriction, including without limitation
17#  the rights to use, copy, modify, merge, publish, distribute, sublicense,
18#  and/or sell copies of the Software, and to permit persons to whom the
19#  Software is furnished to do so, subject to the following conditions:
20#
21#  The above copyright notice and this permission notice shall be included
22#  in all copies or substantial portions of the Software.
23#
24#  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
25#  OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
26#  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
27#  THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
28#  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
29#  FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
30#  DEALINGS IN THE SOFTWARE.
31# ******************************************************************************
32import os
33from numbers import Real
34from typing import Optional, Union, Sequence, Tuple, Dict, Any, Iterator, List
35from warnings import warn
36
37from osgeo import gdal, __version__ as gdal_version_str
38from osgeo_utils.auxiliary.base import get_extension, is_path_like, PathLikeOrStr, enum_to_str, OptionalBoolStr, \
39    is_true, \
40    MaybeSequence, T
41
# A dataset reference: either a path-like/str filename or a gdal.Dataset object.
PathOrDS = Union[PathLikeOrStr, gdal.Dataset]
# A data type given either as a str (a type name, resolved by `get_data_type`) or as an int.
DataTypeOrStr = Union[str, int]
# Optional mapping of creation option names to their values.
CreationOptions = Optional[Dict[str, Any]]
# GDAL version as a tuple of at most three ints; dot-separated parts of the version
# string that are not purely digits are skipped.
gdal_version = tuple(int(s) for s in str(gdal_version_str).split('.') if s.isdigit())[:3]
46
47
def DoesDriverHandleExtension(drv: gdal.Driver, ext: str) -> bool:
    """
    Tells whether the extensions declared by a driver contain the given extension.

    :param drv: the driver whose `gdal.DMD_EXTENSIONS` metadata item is inspected
    :param ext: the extension to look for (compared case-insensitively)
    :return: False if the driver declares no extensions, otherwise whether `ext`
             occurs anywhere inside the declared extensions string
    """
    declared = drv.GetMetadataItem(gdal.DMD_EXTENSIONS)
    if declared is None:
        return False
    # NOTE(review): this is a substring test, not a whole-extension test, so a partial
    # extension (or an empty one) also matches - confirm whether that is intended.
    return ext.lower() in declared.lower()
51
52
def GetOutputDriversFor(filename: PathLikeOrStr, is_raster=True) -> List[str]:
    """
    Lists the short names of the drivers that could be used to write `filename`.

    A driver is a candidate when it declares DCAP_CREATE or DCAP_CREATECOPY, declares the
    raster (or vector) capability, and either handles the extension of `filename` or has a
    connection prefix that `filename` starts with (case-insensitive).

    :param filename: output filename (or connection string)
    :param is_raster: True to look for raster drivers, False for vector drivers
    :return: list of driver short names, in driver registration order
             (a 'vrt' extension short-circuits to ['VRT'])
    """
    filename = os.fspath(filename)
    ext = get_extension(filename)
    if ext.lower() == 'vrt':
        return ['VRT']

    kind_capability = gdal.DCAP_RASTER if is_raster else gdal.DCAP_VECTOR
    matches = []
    for driver_index in range(gdal.GetDriverCount()):
        driver = gdal.GetDriver(driver_index)
        can_write = (driver.GetMetadataItem(gdal.DCAP_CREATE) is not None
                     or driver.GetMetadataItem(gdal.DCAP_CREATECOPY) is not None)
        if not can_write or driver.GetMetadataItem(kind_capability) is None:
            continue
        if ext and DoesDriverHandleExtension(driver, ext):
            matches.append(driver.ShortName)
            continue
        # no extension match: fall back to the connection prefix of the driver, if any
        prefix = driver.GetMetadataItem(gdal.DMD_CONNECTION_PREFIX)
        if prefix is not None and filename.lower().startswith(prefix.lower()):
            matches.append(driver.ShortName)

    # GMT is registered before netCDF for opening reasons, but we want
    # netCDF to be used by default for output.
    gmt_before_netcdf = len(matches) >= 2 and matches[0].upper() == 'GMT' and matches[1].upper() == 'NETCDF'
    if ext.lower() == 'nc' and gmt_before_netcdf:
        matches = ['NETCDF', 'GMT']

    return matches
78
79
def GetOutputDriverFor(filename: PathLikeOrStr, is_raster=True, default_raster_format='GTiff',
                       default_vector_format='ESRI Shapefile') -> str:
    """
    Picks a single output driver name for `filename`.

    :param filename: output filename; a falsy value selects the 'MEM' driver
    :param is_raster: True to look for a raster driver, False for a vector driver
    :param default_raster_format: driver used for a raster filename without extension
    :param default_vector_format: driver used for a vector filename without extension
    :return: the first driver returned by `GetOutputDriversFor` (a message is printed when
             there are several), or the relevant default when nothing matched and the
             filename has no extension
    :raises Exception: when no driver matches a filename that does have an extension
    """
    if not filename:
        return 'MEM'
    candidates = GetOutputDriversFor(filename, is_raster)
    ext = get_extension(filename)
    if candidates:
        if len(candidates) > 1:
            print("Several drivers matching %s extension. Using %s" % (ext if ext else '', candidates[0]))
        return candidates[0]
    if ext:
        raise Exception("Cannot guess driver for %s" % filename)
    return default_raster_format if is_raster else default_vector_format
94
95
def open_ds(filename_or_ds: MaybeSequence[PathOrDS], *args, **kwargs) -> MaybeSequence[gdal.Dataset]:
    """
    Opens one dataset, or every dataset of a sequence, without using a `with` block.

    :param filename_or_ds: a filename, a dataset, or an iterable of those
    :param args: extra positional arguments forwarded to `OpenDS`
           (note that the first one is its `silent_fail` parameter)
    :param kwargs: extra keyword arguments forwarded to `OpenDS`
    :return: the dataset for a single input, or a list of datasets for an iterable input
    """
    if not isinstance(filename_or_ds, PathOrDS.__args__):
        # forward the opening arguments to every item; they used to be dropped here,
        # so e.g. access_mode or open_options were ignored for sequences
        return [open_ds(f, *args, **kwargs) for f in filename_or_ds]
    # enter the context manager by hand: the caller receives the dataset and
    # `OpenDS.__exit__` is never invoked for it
    return OpenDS(filename_or_ds, *args, **kwargs).__enter__()
101
102
def get_ovr_count(filename_or_ds: PathOrDS) -> int:
    """
    Returns the overview count reported by the first band of the dataset.

    :param filename_or_ds: filename or the dataset itself
    """
    with OpenDS(filename_or_ds) as ds:
        return ds.GetRasterBand(1).GetOverviewCount()
107
108
def get_pixel_size(filename_or_ds: PathOrDS) -> Tuple[Real, Real]:
    """
    Returns elements 1 and 5 of the dataset geotransform as an (x, y) pair.

    :param filename_or_ds: filename or the dataset itself
    :return: (geo_transform[1], geo_transform[5]), or (1, 1) when the dataset has no geotransform
    """
    dataset = open_ds(filename_or_ds)
    geo_transform = dataset.GetGeoTransform(can_return_null=True)
    if geo_transform is None:
        return 1, 1
    return geo_transform[1], geo_transform[5]
116
117
# Generic alias: a list whose items are either a single T or a (T, T) pair.
ListOfTupleTT_OrT = List[Union[T, Tuple[T, T]]]
119
120
def get_sizes_factors_resolutions(filename_or_ds: PathOrDS, dim: Optional[int] = 0) -> \
     Tuple[ListOfTupleTT_OrT[int], ListOfTupleTT_OrT[Real], ListOfTupleTT_OrT[Real]]:
    """
    Collects size, downsampling factor and pixel size for the full resolution raster and
    for each overview of its first band.

    :param filename_or_ds: filename or the dataset itself
    :param dim: 0 to keep only the x component, 1 for the y component,
           None to keep (x, y) tuples
    :return: (sizes, factors, resolutions); entry 0 of each list is the full resolution
             raster (factor 1), followed by one entry per overview that could be fetched
    """
    ds = open_ds(filename_or_ds)
    first_band = ds.GetRasterBand(1)
    overview_count = first_band.GetOverviewCount()
    base_res = get_pixel_size(ds)
    base_size = ds.RasterXSize, ds.RasterYSize

    # one (size, factor, resolution) record per level, starting with the base level
    levels = [(base_size, (1, 1), base_res)]
    for level in range(overview_count):
        overview = first_band.GetOverview(level)
        if overview is None:
            continue
        size = overview.XSize, overview.YSize
        factor = base_size[0] / size[0], base_size[1] / size[1]
        res = base_res[0] * factor[0], base_res[1] * factor[1]
        levels.append((size, factor, res))

    sizes = [record[0] for record in levels]
    factors = [record[1] for record in levels]
    resolutions = [record[2] for record in levels]
    if dim is not None:
        sizes = [pair[dim] for pair in sizes]
        factors = [pair[dim] for pair in factors]
        resolutions = [pair[dim] for pair in resolutions]
    return sizes, factors, resolutions
147
148
def get_best_ovr_by_resolutions(requested_res: Real, resolutions: Sequence[Real]):
    """
    Picks the index in `resolutions` that best fits `requested_res`.

    :param requested_res: the requested resolution
    :param resolutions: available resolutions, scanned in order
    :return: the index just before the first entry that is strictly greater than
             `requested_res` (never below 0); the last index when no entry is greater
             (which is -1 for an empty sequence)
    """
    first_greater = next((idx for idx, res in enumerate(resolutions) if res > requested_res), None)
    if first_greater is None:
        return len(resolutions) - 1
    return max(0, first_greater - 1)
154
155
def get_ovr_idx(filename_or_ds: PathOrDS,
                ovr_idx: Optional[int] = None,
                ovr_res: Optional[Union[int, float]] = None) -> int:
    """
    This function uses a different convention than the GDAL API itself:
    * ovr_idx = 0 means the full resolution image (GDAL API: no OVERVIEW_LEVEL)
    * ovr_idx = [1|2|3|...] means the [1st|2nd|3rd|...] overview (GDAL API: OVERVIEW_LEVEL = [0|1|2|...])
    * ovr_idx = -1 means the last overview (GDAL API: OVERVIEW_LEVEL = bnd.GetOverviewCount()-1)
    * ovr_idx = -i means the i-th overview from the last (GDAL API: OVERVIEW_LEVEL = bnd.GetOverviewCount()-i)

    returns a non-negative ovr_idx, from given mutually exclusive ovr_idx (index) or ovr_res (resolution)
    ovr_idx == None and ovr_res == None => returns 0
    ovr_idx: int >= 0 => returns the given ovr_idx as is
    ovr_idx: int < 0 => -1 is the last overview; -2 is the one before the last and so on
                        (clamped to 0 when it reaches beyond the full resolution image)
    ovr_res: float|int => returns the best suitable overview for a given resolution,
             as selected by `get_best_ovr_by_resolutions` from the x resolutions of the dataset
    ovr_idx: float = x => same as (ovr_idx=None, ovr_res=x)
    a sequence given as ovr_idx or ovr_res => only its first item (the x axis) is considered

    :param filename_or_ds: filename or the dataset itself; it is only opened/queried for a
           negative index or for a resolution
    :param ovr_idx: overview index in the convention above (or a float resolution)
    :param ovr_res: requested resolution
    :return: non-negative overview index in the convention above
    :raises Exception: if both ovr_idx and ovr_res are given, or ovr_idx is neither int nor float
    """
    if ovr_res is not None:
        if ovr_idx is not None:
            raise Exception(f'ovr_idx({ovr_idx}) and ovr_res({ovr_res}) are mutually exclusive both were set')
        if isinstance(ovr_res, Sequence) and not isinstance(ovr_res, (str, bytes)):
            # a per-axis resolution must be reduced to its x component before the float
            # conversion below (which would otherwise raise a TypeError);
            # strings are still converted as a whole
            ovr_res = ovr_res[0]
        ovr_idx = float(ovr_res)
    if ovr_idx is None:
        return 0
    if isinstance(ovr_idx, Sequence):
        ovr_idx = ovr_idx[0]  # in case resolution in both axis were given we'll consider only x resolution
    if isinstance(ovr_idx, int):
        if ovr_idx < 0:
            # count backwards from the last overview, never going below the full resolution image
            overview_count = get_ovr_count(filename_or_ds)
            ovr_idx = max(0, overview_count + 1 + ovr_idx)
    elif isinstance(ovr_idx, float):
        # a float is a resolution: pick the matching level among the x resolutions
        _sizes, _factors, resolutions = get_sizes_factors_resolutions(filename_or_ds)
        ovr_idx = get_best_ovr_by_resolutions(ovr_idx, resolutions)
    else:
        raise Exception(f'Got an unexpected overview: {ovr_idx}')
    return ovr_idx
192
193
class OpenDS:
    """
    Context manager that yields a gdal.Dataset from either a filename or a dataset.

    A filename is opened on `__enter__` (the object then "owns" the dataset and drops its
    reference to it on `__exit__`); a dataset that was passed in is yielded as is and is
    left untouched on `__exit__`.
    """
    __slots__ = ['filename', 'ds', 'args', 'kwargs', 'own', 'silent_fail']

    def __init__(self, filename_or_ds: PathOrDS, silent_fail=False, *args, **kwargs):
        """
        :param filename_or_ds: filename to be opened, or an already opened dataset
        :param silent_fail: if True, `__enter__` returns None instead of raising
               when the file cannot be opened
        :param args: positional arguments passed to `_open_ds` after the filename
               (note that `silent_fail` comes before them in this signature)
        :param kwargs: keyword arguments passed to `_open_ds`
        """
        self.ds: Optional[gdal.Dataset] = None
        self.filename: Optional[PathLikeOrStr] = None
        if is_path_like(filename_or_ds):
            self.filename = os.fspath(filename_or_ds)
        else:
            self.ds = filename_or_ds
        self.args = args
        self.kwargs = kwargs
        self.own = False
        self.silent_fail = silent_fail

    def __enter__(self) -> gdal.Dataset:
        """
        Returns the dataset, opening it first when only a filename is known.

        :raises IOError: if the file could not be opened and `silent_fail` is not set
        """
        if self.ds is None:
            self.ds = self._open_ds(self.filename, *self.args, **self.kwargs)
            if self.ds is None and not self.silent_fail:
                raise IOError('could not open file "{}"'.format(self.filename))
            self.own = True
        return self.ds

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Drops the reference to the dataset if it was opened by this object."""
        if self.own:
            # reset to None (it used to be False): `__enter__` tests `self.ds is None`,
            # so a False value made a re-entered object return False as the dataset
            # instead of reopening the file
            self.ds = None

    @staticmethod
    def _open_ds(
        filename: PathLikeOrStr,
        access_mode=gdal.GA_ReadOnly,
        ovr_idx: Optional[Union[int, float]] = None,
        ovr_only: bool = False,
        open_options: Optional[Union[Dict[str, str], Sequence[str]]] = None,
        logger=None,
    ) -> gdal.Dataset:
        """
        opens a gdal Dataset with the given arguments and returns it

        :param filename: filename of the dataset to be opened
        :param access_mode: access mode to open the dataset
        :param ovr_idx: the index of the overview of the dataset,
               Note: uses different numbering than GDAL API itself. read docs of: `get_ovr_idx`
        :param ovr_only: open the dataset without exposing its overviews
        :param open_options: gdal open options to be used to open the dataset, given as a
               dict, as a sequence of "KEY=VALUE" strings or as a single "KEY=VALUE" string
        :param logger: logger to be used to log the opening operation
        :return: gdal.Dataset (whatever `gdal.OpenEx` returns)
        :raises Exception: if the full resolution image is requested with `ovr_only`
                on a gdal version older than 3.3
        """
        if not open_options:
            open_options = dict()
        else:
            if isinstance(open_options, str):
                # a lone "KEY=VALUE" string is a single option, not a sequence of characters
                open_options = [open_options]
            if isinstance(open_options, Sequence):
                open_options = dict(s.split('=', 1) for s in open_options)
            else:
                # work on a copy so that the mapping of the caller is not modified below
                open_options = dict(open_options)
        ovr_idx = get_ovr_idx(filename, ovr_idx)
        # gdal overview 0 is the first overview (after the base layer)
        if ovr_idx == 0:
            if ovr_only:
                if gdal_version >= (3, 3):
                    open_options["OVERVIEW_LEVEL"] = 'NONE'
                else:
                    raise Exception('You asked to not expose overviews, Which is not supported in your gdal version, '
                                    'please update your gdal version to gdal >= 3.3 or do not ask to hide overviews')
        else:  # if ovr_idx > 0:
            open_options["OVERVIEW_LEVEL"] = f'{ovr_idx - 1}{"only" if ovr_only else ""}'
        if logger is not None:
            s = 'opening file: "{}"'.format(filename)
            if open_options:
                s = s + " with options: {}".format(str(open_options))
            logger.debug(s)
        # gdal.OpenEx is given the options as a list of "KEY=VALUE" strings
        open_options = ["{}={}".format(k, v) for k, v in open_options.items()]

        return gdal.OpenEx(str(filename), access_mode, open_options=open_options)
268
269
def get_data_type(data_type: Optional[DataTypeOrStr]):
    """
    Resolves a data type name to the value returned by `gdal.GetDataTypeByName`.

    :param data_type: a data type name, or any non-str value (including None)
    :return: the resolved data type for a str; any other value is returned unchanged
    """
    if isinstance(data_type, str):
        return gdal.GetDataTypeByName(data_type)
    # None and non-str values (e.g. an int data type) pass through as they are
    return data_type
277
278
def get_raster_bands(ds: gdal.Dataset) -> Iterator[gdal.Band]:
    """
    Lazily iterates over the bands of a dataset, from band 1 to band `ds.RasterCount`.

    :param ds: the dataset whose bands are iterated
    :return: a generator of the bands (each band is fetched only when it is consumed)
    """
    return (ds.GetRasterBand(band_num) for band_num in range(1, ds.RasterCount + 1))
281
282
def get_band_types(filename_or_ds: PathOrDS):
    """
    Returns the `DataType` of every band of the dataset, in band order.

    :param filename_or_ds: filename or the dataset itself
    """
    with OpenDS(filename_or_ds) as dataset:
        return [bnd.DataType for bnd in get_raster_bands(dataset)]
286
287
def get_band_minimum(band: gdal.Band):
    """
    Returns the minimum of a band, computing the band statistics first if needed.

    :param band: the band to query
    :return: `band.GetMinimum()`, read again after `ComputeStatistics(0)` when it
             was initially None
    """
    if band.GetMinimum() is None:
        # no minimum available yet: compute the statistics, then read it again below
        band.ComputeStatistics(0)
    return band.GetMinimum()
293
294
def get_raster_band(filename_or_ds: PathOrDS, bnd_index: int = 1, ovr_index: Optional[int] = None):
    """
    Returns a band of the dataset, or one of the overviews of that band.

    :param filename_or_ds: filename or the dataset itself
    :param bnd_index: band number, passed to `GetRasterBand`
    :param ovr_index: if not None, passed directly to `GetOverview` of the band
    :return: the band, or its requested overview
    """
    # NOTE(review): the band is returned after the `with` block has exited, i.e. after
    # OpenDS dropped its reference to a dataset it opened itself - confirm that the
    # band is still usable by the caller in that case.
    with OpenDS(filename_or_ds) as ds:
        band = ds.GetRasterBand(bnd_index)
        return band if ovr_index is None else band.GetOverview(ovr_index)
301
302
def get_raster_minimum(filename_or_ds: PathOrDS, bnd_index: Optional[int] = 1):
    """
    Returns the minimum of one band, or the smallest minimum over all the bands.

    :param filename_or_ds: filename or the dataset itself
    :param bnd_index: band number, or None to take the smallest minimum of all the bands
    :return: the minimum, as obtained from `get_band_minimum`
    """
    with OpenDS(filename_or_ds) as ds:
        if bnd_index is not None:
            return get_band_minimum(ds.GetRasterBand(bnd_index))
        return min(get_band_minimum(band) for band in get_raster_bands(ds))
310
311
def get_raster_min_max(filename_or_ds: PathOrDS, bnd_index: int = 1, approx_ok: Union[bool, int] = False):
    """
    Returns the result of `ComputeRasterMinMax` for one band of the dataset.

    :param filename_or_ds: filename or the dataset itself
    :param bnd_index: band number, passed to `GetRasterBand`
    :param approx_ok: converted to int and passed to `ComputeRasterMinMax`
    """
    with OpenDS(filename_or_ds) as ds:
        return ds.GetRasterBand(bnd_index).ComputeRasterMinMax(int(approx_ok))
317
318
def get_nodatavalue(filename_or_ds: PathOrDS):
    """
    Returns the no-data value of the first band of the dataset.

    :param filename_or_ds: filename or the dataset itself
    """
    with OpenDS(filename_or_ds) as ds:
        # only the first band yielded by get_raster_bands is consulted
        return next(get_raster_bands(ds)).GetNoDataValue()
323
324
def unset_nodatavalue(filename_or_ds: PathOrDS):
    """
    Deletes the no-data value of every band of the dataset.

    :param filename_or_ds: filename (opened with `gdal.GA_Update`) or the dataset itself
    """
    with OpenDS(filename_or_ds, access_mode=gdal.GA_Update) as ds:
        for band in get_raster_bands(ds):
            band.DeleteNoDataValue()
329
330
def get_metadata_item(filename_or_ds: PathOrDS, key: str, domain: str, default: Any = None):
    """
    Reads one metadata item of the dataset.

    :param filename_or_ds: filename or the dataset itself
    :param key: metadata key (surrounding whitespace is stripped)
    :param domain: metadata domain (surrounding whitespace is stripped)
    :param default: value returned when the item is None
    :return: the metadata item, or `default`
    """
    key, domain = key.strip(), domain.strip()
    with OpenDS(filename_or_ds) as ds:
        value = ds.GetMetadataItem(key, domain)
    if value is None:
        return default
    return value
337
338
def get_image_structure_metadata(filename_or_ds: PathOrDS, key: str, default: Any = None):
    """
    Reads one metadata item of the "IMAGE_STRUCTURE" domain of the dataset.

    :param filename_or_ds: filename or the dataset itself
    :param key: metadata key
    :param default: value returned when the item is None
    """
    image_structure_domain = "IMAGE_STRUCTURE"
    return get_metadata_item(filename_or_ds, key=key, domain=image_structure_domain, default=default)
341
342
def get_bigtiff_creation_option_value(big_tiff: OptionalBoolStr):
    """
    Turns a big_tiff setting into a string value.

    :param big_tiff: None, a string or a boolean-like value
    :return: "IF_SAFER" for None; a non-empty string is returned unchanged;
             anything else gives `str(is_true(big_tiff))`
    """
    if big_tiff is None:
        return "IF_SAFER"
    if big_tiff and isinstance(big_tiff, str):
        return big_tiff
    return str(is_true(big_tiff))
347
348
def get_ext_by_of(of: str):
    """
    Returns a dotted, lower case file extension for an output format.

    :param of: the output format (converted with `enum_to_str`)
    :return: '.tif' for gtiff, cog and mem; otherwise '.' followed by the lower-cased format name
    """
    format_name = enum_to_str(of).lower()
    if format_name in ('gtiff', 'cog', 'mem'):
        return '.tif'
    return '.' + format_name
354
355
def get_band_nums(ds: gdal.Dataset, band_nums: Optional[MaybeSequence[int]] = None):
    """
    Normalizes a band selection.

    :param ds: the dataset, used for its band count when no selection is given
    :param band_nums: sequence of band numbers, a single number, or a falsy value
    :return: all the band numbers (1..RasterCount) for a falsy `band_nums`,
             a one item list for a single int, otherwise `band_nums` itself
    """
    if not band_nums:
        return list(range(1, ds.RasterCount + 1))
    if isinstance(band_nums, int):
        return [band_nums]
    return band_nums
362
363
def get_bands(filename_or_ds: PathOrDS, band_nums: Optional[MaybeSequence[int]] = None, ovr_idx: Optional[int] = None) \
            -> List[gdal.Band]:
    """
    returns a list of gdal bands (or of their overviews) of the given dataset
    :param filename_or_ds: filename or the dataset itself
    :param band_nums: sequence of bands numbers (or a single number); all the bands if falsy
    :param ovr_idx: the index of the overview of the dataset,
           Note: uses different numbering than GDAL API itself. read docs of: `get_ovr_idx`
    :return: one band per requested band number, replaced by its overview when a non-zero
             overview index is resolved
    :raises Exception: if a band or its requested overview could not be fetched
    """
    ds = open_ds(filename_or_ds)
    selected = []
    for band_num in get_band_nums(ds, band_nums):
        band: gdal.Band = ds.GetRasterBand(band_num)
        if band is None:
            raise Exception(f'Could not get band {band_num} from file {filename_or_ds}')
        if ovr_idx:
            # resolved on every pass; once resolved it is a non-negative int,
            # which get_ovr_idx returns unchanged
            ovr_idx = get_ovr_idx(ds, ovr_idx)
            if ovr_idx != 0:
                # GetOverview counts from 0, while ovr_idx counts the overviews from 1
                band = band.GetOverview(ovr_idx-1)
                if band is None:
                    raise Exception(f'Could not get overview {ovr_idx} from band {band_num} of file {filename_or_ds}')
        selected.append(band)
    return selected
389
390
def get_scales_and_offsets(bands: Union[PathOrDS, MaybeSequence[gdal.Band]]) -> Tuple[bool, MaybeSequence[Real], MaybeSequence[Real]]:
    """
    Collects the scale and offset of the given band(s).

    :param bands: a filename or dataset (all of its bands are used), a sequence of bands,
           or a single band
    :return: (is_scaled, scales, offsets); a falsy scale counts as 1 and a falsy offset
             as 0; `is_scaled` is True if any scale differs from 1 or any offset from 0;
             scales and offsets are single values when a single band was given,
             lists otherwise
    """
    if isinstance(bands, PathOrDS.__args__):
        bands = get_bands(bands)
    single_band = not isinstance(bands, Sequence)
    band_list = [bands] if single_band else bands
    scales = [band.GetScale() or 1 for band in band_list]
    offsets = [band.GetOffset() or 0 for band in band_list]
    is_scaled = any(scale != 1 for scale in scales) or any(offset != 0 for offset in offsets)
    if single_band:
        return is_scaled, scales[0], offsets[0]
    return is_scaled, scales, offsets
403