1# -*- coding: utf-8 -*-
2
3"""
4***************************************************************************
5    AlgorithmsTest.py
6    ---------------------
7    Date                 : January 2016
8    Copyright            : (C) 2016 by Matthias Kuhn
9    Email                : matthias@opengis.ch
10***************************************************************************
11*                                                                         *
12*   This program is free software; you can redistribute it and/or modify  *
13*   it under the terms of the GNU General Public License as published by  *
14*   the Free Software Foundation; either version 2 of the License, or     *
15*   (at your option) any later version.                                   *
16*                                                                         *
17***************************************************************************
18"""
19
20__author__ = 'Matthias Kuhn'
21__date__ = 'January 2016'
22__copyright__ = '(C) 2016, Matthias Kuhn'
23
24import qgis  # NOQA switch sip api
25
26import os
27import yaml
28import nose2
29import shutil
30import glob
31import hashlib
32import tempfile
33import re
34
35from osgeo import gdal
36from osgeo.gdalconst import GA_ReadOnly
37from numpy import nan_to_num
38from copy import deepcopy
39
40from qgis.core import (QgsVectorLayer,
41                       QgsRasterLayer,
42                       QgsCoordinateReferenceSystem,
43                       QgsFeatureRequest,
44                       QgsMapLayer,
45                       QgsProject,
46                       QgsApplication,
47                       QgsProcessingContext,
48                       QgsProcessingUtils,
49                       QgsProcessingFeedback)
50from qgis.analysis import (QgsNativeAlgorithms)
51from qgis.testing import (_UnexpectedSuccess,
52                          start_app,
53                          unittest)
54from utilities import unitTestDataPath
55
56import processing
57
58
def processingTestDataPath():
    """Return the path of the ``testdata`` directory that sits next to this module."""
    module_dir = os.path.dirname(__file__)
    return os.path.join(module_dir, 'testdata')
61
62
class AlgorithmsTest(object):
    """
    Mixin running processing algorithm tests described in a YAML definition file.

    The assertion helpers used below are not defined here, so this class is meant
    to be combined with a test case class. Concrete subclasses must provide:

    * ``test_definition_file()`` - name of the YAML file inside the processing
      ``testdata`` directory;
    * ``cleanup_paths`` - list collecting temporary directories to delete later;
    * ``in_place_layers`` - dict, needed by tests using ``in_place`` inputs.

    NOTE(review): ``assertLayersEqual``, ``checkLayersEqual``, ``assertFilesEqual``,
    ``checkFilesEqual`` and ``assertDirectoriesEqual`` are presumably supplied by the
    qgis.testing TestCase base of the concrete class -- confirm.
    """

    def test_algorithms(self):
        """
        This is the main test function. It yields one ``check_algorithm`` test for every
        entry of the ``tests`` list in the YAML file named by ``test_definition_file()``.
        """
        definition_path = os.path.join(processingTestDataPath(), self.test_definition_file())
        with open(definition_path, 'r', encoding='utf-8') as stream:
            algorithm_tests = yaml.load(stream, Loader=yaml.SafeLoader)

        # An empty YAML file loads as None; a missing or empty 'tests' key means nothing to run.
        tests = algorithm_tests.get('tests') if isinstance(algorithm_tests, dict) else None
        if not tests:
            return

        for idx, algtest in enumerate(tests):
            print('About to start {} of {}: "{}"'.format(idx, len(tests), algtest['name']))
            yield self.check_algorithm, algtest['name'], algtest

    def check_algorithm(self, name, defs):
        """
        Will run an algorithm definition and check if it generates the expected result

        :param name: The identifier name used in the test output heading
        :param defs: A python dict containing a test algorithm definition
        :raises _UnexpectedSuccess: when a failure was expected but the run succeeded
        """
        self.vector_layer_params = {}
        QgsProject.instance().clear()
        project = QgsProject.instance()

        if 'project' in defs:
            full_project_path = os.path.join(processingTestDataPath(), defs['project'])
            project_read_success = project.read(full_project_path)
            self.assertTrue(project_read_success, 'Failed to load project file: ' + defs['project'])

        # Always set CRS and ellipsoid, falling back to an empty CRS / no ellipsoid.
        if 'project_crs' in defs:
            project.setCrs(QgsCoordinateReferenceSystem(defs['project_crs']))
        else:
            project.setCrs(QgsCoordinateReferenceSystem())
        project.setEllipsoid(defs.get('ellipsoid', ''))

        params = self.load_params(defs['params'])

        print('Running alg: "{}"'.format(defs['algorithm']))
        alg = QgsApplication.processingRegistry().createAlgorithmById(defs['algorithm'])
        self.assertTrue(alg is not None, 'Unknown algorithm: {}'.format(defs['algorithm']))

        if isinstance(params, list):
            # Positional parameters map onto the algorithm's parameter definitions in order.
            parameters = {definition.name(): value
                          for definition, value in zip(alg.parameterDefinitions(), params)}
        else:
            parameters = dict(params or {})

        for result_name, result_def in defs['results'].items():
            # in-place results are read back from the modified input layer instead
            if not result_def.get('in_place_result'):
                parameters[result_name] = self.load_result_param(result_def)

        expectFailure = False
        if 'expectedFailure' in defs:
            # All lines but the last are setup statements; the last is an expression that
            # decides whether failure is expected. Both run in one explicit namespace so
            # names bound by the setup are visible to the expression (exec() assignments
            # into a locals() snapshot are not reliably visible afterwards).
            # NOTE: this executes code from the test definition -- trusted repo data only.
            snippet = defs['expectedFailure']
            namespace = dict(locals())
            exec('\n'.join(snippet[:-1]), globals(), namespace)
            expectFailure = eval(snippet[-1], globals(), namespace)

        if 'expectedException' in defs:
            expectFailure = True

        # ignore user setting for invalid geometry handling
        context = QgsProcessingContext()
        context.setProject(project)

        if defs.get('skipInvalid'):
            context.setInvalidGeometryCheck(QgsFeatureRequest.GeometrySkipInvalid)

        feedback = QgsProcessingFeedback()

        print('Algorithm parameters are {}'.format(parameters))

        # first check that algorithm accepts the parameters we pass...
        ok, msg = alg.checkParameterValues(parameters, context)
        self.assertTrue(ok, 'Algorithm failed checkParameterValues with result {}'.format(msg))

        if expectFailure:
            try:
                results, ok = alg.run(parameters, context, feedback)
                self.check_results(results, context, parameters, defs['results'])
            except Exception:
                # The algorithm or the result comparison failed, as expected.
                pass
            else:
                # Raised outside the try block so the broad except above cannot swallow it.
                if ok:
                    raise _UnexpectedSuccess
        else:
            results, ok = alg.run(parameters, context, feedback)
            self.assertTrue(ok, 'params: {}, results: {}'.format(parameters, results))
            self.check_results(results, context, parameters, defs['results'])

    def load_params(self, params):
        """
        Loads an array of parameters.

        Lists are loaded element-wise, dicts value-wise (the key is passed on as the
        parameter id); anything else is returned unchanged.
        """
        if isinstance(params, list):
            return [self.load_param(p) for p in params]
        elif isinstance(params, dict):
            return {key: self.load_param(p, key) for key, p in params.items()}
        else:
            return params

    def load_param(self, param, id=None):
        """
        Loads a parameter. If it's not a map, the parameter will be returned as-is. If it is a map, it will process the
        parameter based on its key `type` and return the appropriate parameter to pass to the algorithm.

        :param param: raw parameter value from the test definition
        :param id: parameter name, used to record in-place layers
        :raises KeyError: for a map with a missing or unknown `type`
        """
        try:
            param_type = param['type']
        except TypeError:
            # No type specified (plain string, number, list, None...): use whatever is there.
            # Only the lookup is guarded so errors raised while loading are not masked.
            return param

        if param_type in ('vector', 'raster', 'table'):
            return self.load_layer(id, param).id()
        if param_type == 'vrtlayers':
            layers = []
            for layer_def in param['params']:
                entry = dict(layer_def)  # copy so the test definition itself is not mutated
                entry['layer'] = self.load_layer(None, {'type': 'vector', 'name': entry['layer']})
                layers.append(entry)
            return layers
        if param_type == 'multi':
            return [self.load_param(p) for p in param['params']]
        if param_type == 'file':
            return self.filepath_from_param(param)
        if param_type == 'interpolation':
            # 'name' holds '::|::'-separated layer specs whose fields are '::~::'-separated;
            # the first field is a path relative to the test data directory.
            prefix = processingTestDataPath()
            entries = []
            for spec in param['name'].split('::|::'):
                v = spec.split('::~::')
                entries.append('{}::~::{}::~::{}::~::{}'.format(os.path.join(prefix, v[0]),
                                                                v[1], v[2], v[3]))
            return ';'.join(entries)

        raise KeyError("Unknown type '{}' specified for parameter".format(param_type))

    def _make_temp_dir(self):
        """Create a temporary directory and register it in ``cleanup_paths``."""
        path = tempfile.mkdtemp()
        self.cleanup_paths.append(path)
        return path

    def load_result_param(self, param):
        """
        Loads a result parameter. Creates a temporary destination where the result should go to and returns this location
        so it can be sent to the algorithm as parameter.

        :raises KeyError: for an unknown result `type`
        """
        param_type = param['type']
        if param_type in ['vector', 'file', 'table', 'regex']:
            outdir = self._make_temp_dir()
            name = param['name']
            # 'name' may list several acceptable expected files; the first one names the output.
            basename = os.path.basename(name if isinstance(name, str) else name[0])
            return self.uri_path_join(outdir, basename)
        if param_type == 'rasterhash':
            outdir = self._make_temp_dir()
            # SAGA test definitions expect an .sdat output, all others a GeoTIFF.
            if self.test_definition_file().lower().startswith('saga'):
                basename = 'raster.sdat'
            else:
                basename = 'raster.tif'
            return os.path.join(outdir, basename)
        if param_type == 'directory':
            # registered for cleanup like every other temporary output
            return self._make_temp_dir()

        raise KeyError("Unknown type '{}' specified for parameter".format(param_type))

    def load_layers(self, id, param):
        """
        Loads the layer(s) described by a parameter and returns them as a list.

        For vector/table params, a list-valued 'name' (without 'uri') yields one layer
        per name; every other param yields a single layer.
        """
        layers = []
        if param['type'] in ('vector', 'table'):
            if isinstance(param['name'], str) or 'uri' in param:
                layers.append(self.load_layer(id, param))
            else:
                for n in param['name']:
                    layer_param = deepcopy(param)
                    layer_param['name'] = n
                    layers.append(self.load_layer(id, layer_param))
        else:
            layers.append(self.load_layer(id, param))
        return layers

    @staticmethod
    def _force_gml_srs_detection(uri):
        """Append the OGR option forcing SRS detection when *uri* refers to a GML file."""
        # ewwwww - GML files are otherwise loaded with no srs
        if re.search(r'\.gml\b', uri, re.IGNORECASE):
            uri += '|option:FORCE_SRS_DETECTION=YES'
        return uri

    def load_layer(self, id, param):
        """
        Loads a layer which was specified as parameter and adds it to the project.

        With a truthy 'in_place', the source file and every sibling file sharing its base
        name are copied to a temporary directory first, and that copy is loaded instead;
        its path is recorded in ``in_place_layers[id]``. Vector/table layers are cached
        per resolved path in ``vector_layer_params``.

        :raises KeyError: for a param type that is not vector, table or raster
        """
        filepath = self.filepath_from_param(param)

        if param.get('in_place'):
            # check if alg modifies layer in place
            tmpdir = self._make_temp_dir()
            path, file_name = os.path.split(filepath)
            base = os.path.splitext(file_name)[0]
            # glob returns paths already joined with 'path'; escape in case of [ ] * ? in names
            pattern = os.path.join(glob.escape(path), glob.escape(base) + '.*')
            for sibling in glob.glob(pattern):
                shutil.copy(sibling, tmpdir)
            filepath = os.path.join(tmpdir, file_name)
            self.in_place_layers[id] = filepath

        if param['type'] in ('vector', 'table'):
            filepath = self._force_gml_srs_detection(filepath)

            if filepath in self.vector_layer_params:
                return self.vector_layer_params[filepath]

            options = QgsVectorLayer.LayerOptions()
            options.loadDefaultStyle = False
            lyr = QgsVectorLayer(filepath, param['name'], 'ogr', options)
            self.vector_layer_params[filepath] = lyr
        elif param['type'] == 'raster':
            options = QgsRasterLayer.LayerOptions()
            options.loadDefaultStyle = False
            lyr = QgsRasterLayer(filepath, param['name'], 'gdal', options)
        else:
            raise KeyError("Unknown type '{}' specified for parameter".format(param['type']))

        self.assertTrue(lyr.isValid(), 'Could not load layer "{}" from param {}'.format(filepath, param))
        QgsProject.instance().addMapLayer(lyr)
        return lyr

    def filepath_from_param(self, param):
        """
        Creates a filepath from a param.

        Uses param['uri'] if present, else param['name'], relative to the processing
        test data directory, or to the QGIS unit test data directory when
        param['location'] == 'qgs'.
        """
        if param.get('location') == 'qgs':
            prefix = unitTestDataPath()
        else:
            prefix = processingTestDataPath()

        path = param['uri'] if 'uri' in param else param['name']
        return self.uri_path_join(prefix, path)

    def uri_path_join(self, prefix, filepath):
        """
        Joins *prefix* onto *filepath*.

        For 'ogr:' URIs the prefix is inserted after every "dbname='"; otherwise
        os.path.join is used.
        """
        if filepath.startswith('ogr:'):
            if prefix and not prefix.endswith(os.path.sep):
                prefix += os.path.sep
            # Plain string replacement: re.sub would treat backslashes in Windows
            # paths as escape sequences in the replacement template.
            filepath = filepath.replace("dbname='", "dbname='" + prefix)
        else:
            filepath = os.path.join(prefix, filepath)

        return filepath

    def _check_vector_result(self, result_id, expected_result, results, context):
        """Compare one vector/table output with its expected layer(s)."""
        if 'compare' in expected_result and not expected_result['compare']:
            # skipping the comparison, so just make sure output is valid
            if isinstance(results[result_id], QgsMapLayer):
                result_lyr = results[result_id]
            else:
                result_lyr = QgsProcessingUtils.mapLayerFromString(results[result_id], context)
            self.assertTrue(result_lyr.isValid())
            return

        expected_lyrs = self.load_layers(result_id, expected_result)
        # truthiness test, consistent with check_algorithm's handling of 'in_place_result'
        if expected_result.get('in_place_result'):
            result_lyr = QgsProcessingUtils.mapLayerFromString(self.in_place_layers[result_id], context)
            self.assertTrue(result_lyr.isValid(), self.in_place_layers[result_id])
        else:
            try:
                result_value = results[result_id]
            except KeyError as e:
                raise KeyError('Expected result {} does not exist in {}'.format(str(e), list(results.keys()))) from e

            if isinstance(result_value, QgsMapLayer):
                result_lyr = result_value
            else:
                result_lyr = QgsProcessingUtils.mapLayerFromString(
                    self._force_gml_srs_detection(result_value), context)
            # explicit None check rather than relying on the layer object's truthiness
            self.assertTrue(result_lyr is not None, result_value)

        compare = expected_result.get('compare', {})
        pk = expected_result.get('pk', None)

        if len(expected_lyrs) == 1:
            self.assertLayersEqual(expected_lyrs[0], result_lyr, compare=compare, pk=pk)
        else:
            matched = any(self.checkLayersEqual(lyr, result_lyr, compare=compare, pk=pk)
                          for lyr in expected_lyrs)
            self.assertTrue(matched, 'Could not find matching layer in expected results')

    def check_results(self, results, context, params, expected):
        """
        Checks if result produced by an algorithm matches with the expected specification.

        :param results: output mapping returned by the algorithm run
        :param context: processing context used to resolve layer strings
        :param params: parameters the algorithm ran with (used in failure messages)
        :param expected: the 'results' mapping of the test definition
        """
        for result_id, expected_result in expected.items():
            result_type = expected_result['type']
            if result_type in ('vector', 'table'):
                self._check_vector_result(result_id, expected_result, results, context)

            elif 'rasterhash' == result_type:
                raster_path = results[result_id]
                print("id:{} result:{}".format(result_id, raster_path))
                self.assertTrue(os.path.exists(raster_path), 'File does not exist: {}, {}'.format(raster_path, params))
                dataset = gdal.Open(raster_path, GA_ReadOnly)
                self.assertTrue(dataset is not None, 'Could not open raster: {}'.format(raster_path))
                # NaN -> 0 before hashing, presumably so NaN bit patterns don't affect the hash
                dataArray = nan_to_num(dataset.ReadAsArray(0))
                strhash = hashlib.sha224(dataArray.data).hexdigest()

                # 'hash' may be a single string or a collection of acceptable hashes
                if not isinstance(expected_result['hash'], str):
                    self.assertIn(strhash, expected_result['hash'])
                else:
                    self.assertEqual(strhash, expected_result['hash'])
            elif 'file' == result_type:
                result_filepath = results[result_id]
                if isinstance(expected_result.get('name'), list):
                    # test to see if any match expected; fall back to the first for the report
                    for path in expected_result['name']:
                        expected_filepath = self.filepath_from_param({'name': path})
                        if self.checkFilesEqual(expected_filepath, result_filepath):
                            break
                    else:
                        expected_filepath = self.filepath_from_param({'name': expected_result['name'][0]})
                else:
                    expected_filepath = self.filepath_from_param(expected_result)

                self.assertFilesEqual(expected_filepath, result_filepath)
            elif 'directory' == result_type:
                expected_dirpath = self.filepath_from_param(expected_result)
                result_dirpath = results[result_id]

                self.assertDirectoriesEqual(expected_dirpath, result_dirpath)
            elif 'regex' == result_type:
                with open(results[result_id], 'r') as file:
                    data = file.read()

                for rule in expected_result.get('rules', []):
                    self.assertRegex(data, rule)
395
396
class GenericAlgorithmsTest(unittest.TestCase):
    """
    Provider-independent checks applied to every registered processing algorithm.
    """

    @classmethod
    def setUpClass(cls):
        """Start the QGIS application and initialize processing once for this class."""
        cls.cleanup_paths = []
        start_app()
        from processing.core.Processing import Processing
        Processing.initialize()

    @classmethod
    def tearDownClass(cls):
        """Shut processing down, then delete every path registered in ``cleanup_paths``."""
        from processing.core.Processing import Processing
        Processing.deinitialize()
        for leftover in cls.cleanup_paths:
            shutil.rmtree(leftover)

    def testAlgorithmCompliance(self):
        """Run ``check_algorithm`` over each algorithm of each registered provider."""
        registry = QgsApplication.processingRegistry()
        for provider in registry.providers():
            print('testing provider {}'.format(provider.id()))
            for algorithm in provider.algorithms():
                print('testing algorithm {}'.format(algorithm.id()))
                self.check_algorithm(algorithm)

    def check_algorithm(self, alg):
        """Smoke check for one algorithm: calling ``helpUrl()`` must not raise."""
        alg.helpUrl()
426
427
if '__main__' == __name__:
    # Executed directly as a script: let nose2 discover and run the tests.
    nose2.main()
430