1# -*- coding: utf-8 -*-
2
3"""
4***************************************************************************
5    plugin_test.py
6    ---------------------
7    Date                 : May 2017
8    Copyright            : (C) 2017, Sandro Santilli
9    Email                : strk at kbt dot io
10***************************************************************************
11*                                                                         *
12*   This program is free software; you can redistribute it and/or modify  *
13*   it under the terms of the GNU General Public License as published by  *
14*   the Free Software Foundation; either version 2 of the License, or     *
15*   (at your option) any later version.                                   *
16*                                                                         *
17***************************************************************************
18"""
19
20__author__ = 'Sandro Santilli'
21__date__ = 'May 2017'
22__copyright__ = '(C) 2017, Sandro Santilli'
23
import os
import re
import tempfile

import qgis
from qgis.testing import start_app, unittest
from qgis.core import QgsDataSourceUri
from qgis.utils import iface
from qgis.PyQt.QtCore import QObject
31
# Start the QGIS test application. The db_manager imports below are
# deliberately placed after this call (presumably they need a running
# QgsApplication at import time).
start_app()
33
34from db_manager.db_plugins.postgis.plugin import PostGisDBPlugin, PGRasterTable
35from db_manager.db_plugins.postgis.plugin import PGDatabase
36from db_manager.db_plugins.postgis.data_model import PGSqlResultModel
37from db_manager.db_plugins.plugin import Table
38from db_manager.db_plugins.postgis.connector import PostGisDBConnector
39
40
class TestDBManagerPostgisPlugin(unittest.TestCase):
    """Tests for the DB Manager PostGIS plugin against a live test database."""

    @classmethod
    def setUpClass(cls):
        """Point PostgreSQL client connections at the test database.

        Saves the current PGDATABASE and PGSERVICEFILE values so that
        tearDownClass can restore them. Sets PGDATABASE to the test database
        name and writes a temporary pg_service file that defines a
        ``dbmanager`` service for that database.
        """
        cls.old_pgdatabase_env = os.environ.get('PGDATABASE')
        # QGIS_PGTEST_DB contains the full connection string and not only the DB name!
        pgtest_db = os.environ.get('QGIS_PGTEST_DB')
        if pgtest_db is not None:
            cls.testdb = QgsDataSourceUri(pgtest_db).database()
        else:
            cls.testdb = 'qgis_test'
        os.environ['PGDATABASE'] = cls.testdb

        # Create a temporary service file in the platform temp directory,
        # with a unique name, rather than a fixed path under /tmp.
        cls.old_pgservicefile_env = os.environ.get('PGSERVICEFILE')
        fd, cls.tmpservicefile = tempfile.mkstemp(
            prefix='qgis-test-{}-'.format(os.getpid()),
            suffix='-pg_service.conf')
        with os.fdopen(fd, 'w') as f:
            f.write("[dbmanager]\ndbname={}\n".format(cls.testdb))
            # TODO: add more things if PGSERVICEFILE was already set ?
        os.environ['PGSERVICEFILE'] = cls.tmpservicefile

    @classmethod
    def tearDownClass(cls):
        """Restore the environment saved by setUpClass and delete the service file."""
        cls._restoreEnv('PGDATABASE', cls.old_pgdatabase_env)
        cls._restoreEnv('PGSERVICEFILE', cls.old_pgservicefile_env)
        os.unlink(cls.tmpservicefile)

    @staticmethod
    def _restoreEnv(name, value):
        """Set environment variable *name* back to *value*, or unset it if *value* is None.

        Unsetting matters: if the variable did not exist before the tests,
        the value written by setUpClass must not leak into the rest of the
        process. An originally empty string is restored as-is.
        """
        if value is None:
            os.environ.pop(name, None)
        else:
            os.environ[name] = value

    @staticmethod
    def _fakeConnection():
        """Return a stand-in for a DB Manager connection object.

        Only ``connectionName()`` and ``providerName()`` are stubbed. That is
        presumably all PGDatabase needs here; verify if the tests start
        failing on other attributes. The caller must keep the returned object
        alive while the PGDatabase built from it is in use.
        """
        obj = QObject()
        obj.connectionName = lambda: 'fake'
        obj.providerName = lambda: 'postgres'
        return obj

    def _checkRasterTableURI(self, database, expected_dbname):
        """Assert that every raster table of *database* has ``dbname='<expected_dbname>'`` in its URI.

        Also fails when *database* exposes no raster table at all, because
        the check would otherwise pass vacuously. The test database must
        contain at least one raster table.
        """
        raster_tables_count = 0
        for tab in database.tables():
            if tab.type != Table.RasterType:
                continue
            raster_tables_count += 1
            uri = tab.uri()
            m = re.search(r" dbname='([^ ]*)' ", uri)
            self.assertIsNotNone(
                m, 'no dbname found in raster table URI: {}'.format(uri))
            self.assertEqual(m.group(1), expected_dbname)

        self.assertGreaterEqual(raster_tables_count, 1)

    def test_rasterTableURI(self):
        """Raster table URIs must carry the name of the database actually used.

        Covers an empty connection URI, where the database name comes from
        PGDATABASE (see https://github.com/qgis/QGIS/issues/24525 and
        https://github.com/qgis/QGIS/issues/19005). Also covers a
        service-only URI, where it comes from the pg_service file (see
        https://github.com/qgis/QGIS/issues/24526).
        """
        obj = self._fakeConnection()  # needs to be kept alive

        # Empty URI: the database name is taken from PGDATABASE.
        expected_dbname = self.testdb
        os.environ['PGDATABASE'] = expected_dbname

        database = PGDatabase(obj, QgsDataSourceUri())
        self.assertIsInstance(database, PGDatabase)

        uri = database.uri()
        self.assertEqual(uri.host(), '')
        self.assertEqual(uri.username(), '')
        self.assertEqual(uri.database(), expected_dbname)
        self.assertEqual(uri.service(), '')

        self._checkRasterTableURI(database, expected_dbname)

        # Service-only URI: PGDATABASE is set to a bogus name. If the raster
        # URIs picked it up, the dbname check against expected_dbname (the
        # name written into the 'dbmanager' service) would fail.
        os.environ['PGDATABASE'] = 'fake'
        database = PGDatabase(obj, QgsDataSourceUri('service=dbmanager'))
        self.assertIsInstance(database, PGDatabase)

        uri = database.uri()
        self.assertEqual(uri.host(), '')
        self.assertEqual(uri.username(), '')
        self.assertEqual(uri.database(), '')
        self.assertEqual(uri.service(), 'dbmanager')

        self._checkRasterTableURI(database, expected_dbname)

    def test_unicodeInQuery(self):
        """Non-ASCII text must round-trip through sqlResultModel.

        See https://github.com/qgis/QGIS/issues/24732
        """
        os.environ['PGDATABASE'] = self.testdb
        obj = self._fakeConnection()  # needs to be kept alive
        database = PGDatabase(obj, QgsDataSourceUri())
        self.assertIsInstance(database, PGDatabase)
        # In Python 3 both spellings are the same str type. Both are kept,
        # presumably from the Python 2 era when they differed.
        for sql in ("SELECT 'é'::text", u"SELECT 'é'::text"):
            with self.subTest(sql=sql):
                res = database.sqlResultModel(sql, obj)
                self.assertIsInstance(res, PGSqlResultModel)
                self.assertEqual(res.getData(0, 0), u"é")
153
154
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()
157