1# -*- coding: utf-8 -*-
2
3"""
4/***************************************************************************
5Name                 : DB Manager
6Description          : Database manager plugin for QGIS
7Date                 : May 23, 2011
8copyright            : (C) 2011 by Giuseppe Sucameli
9email                : brush.tyler@gmail.com
10
11 ***************************************************************************/
12
13/***************************************************************************
14 *                                                                         *
15 *   This program is free software; you can redistribute it and/or modify  *
16 *   it under the terms of the GNU General Public License as published by  *
17 *   the Free Software Foundation; either version 2 of the License, or     *
18 *   (at your option) any later version.                                   *
19 *                                                                         *
20 ***************************************************************************/
21"""
22from builtins import str
23from builtins import map
24from builtins import range
25
# this will disable the dbplugin if the connector raises an ImportError
27from .connector import PostGisDBConnector
28
29from qgis.PyQt.QtCore import Qt, QRegExp, QCoreApplication
30from qgis.PyQt.QtGui import QIcon
31from qgis.PyQt.QtWidgets import QAction, QApplication, QMessageBox
32from qgis.core import Qgis, QgsApplication, QgsSettings
33from qgis.gui import QgsMessageBar
34
35from ..plugin import ConnectionError, InvalidDataException, DBPlugin, Database, Schema, Table, VectorTable, RasterTable, \
36    TableField, TableConstraint, TableIndex, TableTrigger, TableRule
37
38import re
39
40
def classFactory():
    """Factory hook: hand back the plugin class (not an instance) of this module."""
    plugin_class = PostGisDBPlugin
    return plugin_class
43
44
class PostGisDBPlugin(DBPlugin):
    """DB Manager plugin entry for PostgreSQL/PostGIS connections."""

    @classmethod
    def icon(cls):
        """Return the theme icon shown for PostGIS connections."""
        return QgsApplication.getThemeIcon("/mIconPostgis.svg")

    @classmethod
    def typeName(cls):
        """Return the internal identifier of this plugin type."""
        return 'postgis'

    @classmethod
    def typeNameString(cls):
        """Return the translated, user-visible name of this plugin type."""
        return QCoreApplication.translate('db_manager', 'PostGIS')

    @classmethod
    def providerName(cls):
        """Return the QGIS data provider key used for this plugin's layers."""
        return 'postgres'

    @classmethod
    def connectionSettingsKey(cls):
        """Return the settings group under which the connections are stored."""
        return '/PostgreSQL/connections'

    def databasesFactory(self, connection, uri):
        """Wrap *connection* / *uri* in a PGDatabase node."""
        return PGDatabase(connection, uri)

    def connect(self, parent=None):
        """Open the stored connection named ``self.connectionName()``.

        Reads service/host/port/database/username/password/authcfg plus the
        estimatedMetadata and sslmode entries from the connection's settings
        group, builds a QgsDataSourceUri (service-based when a service is
        set, host/port-based otherwise) and passes it to ``connectToUri``.

        :param parent: not used here.
        :returns: the result of ``connectToUri``, or False when it raises
            ConnectionError.
        :raises InvalidDataException: when the settings group has no
            "database" entry.
        """
        connectionName = self.connectionName()
        settings = QgsSettings()
        settings.beginGroup(u"/%s/%s" % (self.connectionSettingsKey(), connectionName))

        # A group without a "database" key is treated as a non-existent connection.
        if not settings.contains("database"):
            raise InvalidDataException(self.tr('There is no defined database connection "{0}".').format(connectionName))

        from qgis.core import QgsDataSourceUri

        keys = ("service", "host", "port", "database", "username", "password", "authcfg")
        stored = {key: settings.value(key, "", type=str) for key in keys}

        useEstimatedMetadata = settings.value("estimatedMetadata", False, type=bool)
        try:
            sslmode = settings.enumValue("sslmode", QgsDataSourceUri.SslPrefer)
        except TypeError:
            # enumValue() could not provide the enum: keep the default mode
            sslmode = QgsDataSourceUri.SslPrefer

        settings.endGroup()

        authcfg = stored["authcfg"]
        # a null Qt string value counts as "no auth config"
        if hasattr(authcfg, 'isNull') and authcfg.isNull():
            authcfg = ''

        uri = QgsDataSourceUri()
        sharedArgs = (stored["database"], stored["username"], stored["password"], sslmode, authcfg)
        if stored["service"]:
            uri.setConnection(stored["service"], *sharedArgs)
        else:
            uri.setConnection(stored["host"], stored["port"], *sharedArgs)
        uri.setUseEstimatedMetadata(useEstimatedMetadata)

        try:
            return self.connectToUri(uri)
        except ConnectionError:
            return False
107
108
class PGDatabase(Database):
    """Database node for a PostGIS connection.

    Supplies the PostGIS connector, the schema/table node factories, the SQL
    result models and two extra "Table" menu actions (vacuum analyze and
    refresh materialized view).
    """

    def __init__(self, connection, uri):
        Database.__init__(self, connection, uri)

    def connectorsFactory(self, uri):
        """Build the PostGIS connector for *uri*, bound to this node's connection."""
        return PostGisDBConnector(uri, self.connection())

    def dataTablesFactory(self, row, db, schema=None):
        """Create a plain (non-spatial) table node."""
        return PGTable(row, db, schema)

    def info(self):
        """Return the info model describing this database."""
        from .info_model import PGDatabaseInfo
        return PGDatabaseInfo(self)

    def vectorTablesFactory(self, row, db, schema=None):
        """Create a vector table node."""
        return PGVectorTable(row, db, schema)

    def rasterTablesFactory(self, row, db, schema=None):
        """Create a raster table node."""
        return PGRasterTable(row, db, schema)

    def schemasFactory(self, row, db):
        """Create a schema node."""
        return PGSchema(row, db)

    def sqlResultModel(self, sql, parent):
        """Return a model holding the result of *sql*."""
        from .data_model import PGSqlResultModel

        return PGSqlResultModel(self, sql, parent)

    def sqlResultModelAsync(self, sql, parent):
        """Return the asynchronous variant of the SQL result model."""
        from .data_model import PGSqlResultModelAsync

        return PGSqlResultModelAsync(self, sql, parent)

    def registerDatabaseActions(self, mainWindow):
        """Register the generic actions plus the PostGIS "Table" menu entries."""
        Database.registerDatabaseActions(self, mainWindow)

        tableMenu = self.tr("&Table")

        # separator between the generic table actions and the PostGIS ones
        separator = QAction(self)
        separator.setSeparator(True)
        mainWindow.registerAction(separator, tableMenu)

        entries = (
            (self.tr("Run &Vacuum Analyze"), self.runVacuumAnalyzeActionSlot),
            (self.tr("Run &Refresh Materialized View"), self.runRefreshMaterializedViewSlot),
        )
        for label, slot in entries:
            mainWindow.registerAction(QAction(label, self), tableMenu, slot)

    def runVacuumAnalyzeActionSlot(self, item, action, parent):
        """Menu slot: vacuum-analyze *item* when it is a table that is not a view.

        The override cursor is lifted while a hint is pushed to the info bar
        and is always re-applied before leaving the check.
        """
        QApplication.restoreOverrideCursor()
        try:
            isPlainTable = isinstance(item, Table) and not item.isView
            if not isPlainTable:
                hint = self.tr("Select a table for vacuum analyze.")
                parent.infoBar.pushMessage(hint, Qgis.Info, parent.iface.messageTimeout())
                return
        finally:
            QApplication.setOverrideCursor(Qt.WaitCursor)

        item.runVacuumAnalyze()

    def runRefreshMaterializedViewSlot(self, item, action, parent):
        """Menu slot: refresh *item* when it is a PostGIS materialized view ('m')."""
        QApplication.restoreOverrideCursor()
        try:
            isMaterializedView = isinstance(item, PGTable) and item._relationType == 'm'
            if not isMaterializedView:
                hint = self.tr("Select a materialized view for refresh.")
                parent.infoBar.pushMessage(hint, Qgis.Info, parent.iface.messageTimeout())
                return
        finally:
            QApplication.setOverrideCursor(Qt.WaitCursor)

        item.runRefreshMaterializedView()

    def hasLowercaseFieldNamesOption(self):
        return True

    def supportsComment(self):
        return True

    def executeSql(self, sql):
        """Forward *sql* to the connector's ``_executeSql``."""
        return self.connector._executeSql(sql)
189
190
class PGSchema(Schema):
    """Schema node; ``row`` is unpacked as (oid, name, owner, perms, comment)."""

    def __init__(self, row, db):
        Schema.__init__(self, db)
        (self.oid,
         self.name,
         self.owner,
         self.perms,
         self.comment) = row
196
197
class PGTable(Table):
    """Table or view node of a PostGIS database.

    ``row`` is unpacked as (name, schema name, relation type, owner,
    estimated row count, pages, comment). The schema name from the row is
    not used; the owning schema comes from the ``schema`` argument.
    Relation types 'v' and 'm' are treated as views; 'm' is the
    materialized-view case.
    """

    def __init__(self, row, db, schema=None):
        Table.__init__(self, db, schema)
        self.name, _schema_name, self._relationType, self.owner, self.estimatedRowCount, self.pages, self.comment = row
        self.isView = self._relationType in ('v', 'm')
        self.estimatedRowCount = int(self.estimatedRowCount)

    def _refreshContainer(self):
        """Reload the owning schema, or the whole database when there is no schema."""
        # TODO: change only this item, not re-create all the tables in the schema/database
        schema = self.schema()
        if schema:
            schema.refresh()
        else:
            self.database().refresh()

    def runVacuumAnalyze(self):
        """Ask the connector to vacuum-analyze this table, then refresh the tree."""
        self.aboutToChange.emit()
        self.database().connector.runVacuumAnalyze((self.schemaName(), self.name))
        self._refreshContainer()

    def runRefreshMaterializedView(self):
        """Ask the connector to refresh this materialized view, then refresh the tree."""
        self.aboutToChange.emit()
        self.database().connector.runRefreshMaterializedView((self.schemaName(), self.name))
        self._refreshContainer()

    def runAction(self, action):
        """Handle PostGIS-specific actions and delegate everything else to Table.

        Handled here: "vacuumanalyze/run", "refreshmaterializedview/run" and
        "rule/<rule name>/<action>" (the user is asked for confirmation;
        only "delete" is carried out here).

        :returns: True when the action was handled, False when the user
            declined a rule action, otherwise ``Table.runAction``'s result.
        """
        action = str(action)

        if action.startswith("vacuumanalyze/"):
            if action == "vacuumanalyze/run":
                self.runVacuumAnalyze()
                return True

        elif action.startswith("rule/"):
            # Split off only the trailing "/<action>": the rule name itself
            # may contain '/', which a plain split('/') would cut apart.
            rule_name, sep, rule_action = action[len("rule/"):].rpartition('/')
            if not sep:
                # malformed "rule/<x>" without an action part
                return Table.runAction(self, action)

            msg = self.tr(u"Do you want to {0} rule {1}?").format(rule_action, rule_name)

            # lift the wait cursor while the question dialog is shown
            QApplication.restoreOverrideCursor()

            try:
                if QMessageBox.question(None, self.tr("Table rule"), msg,
                                        QMessageBox.Yes | QMessageBox.No) == QMessageBox.No:
                    return False
            finally:
                QApplication.setOverrideCursor(Qt.WaitCursor)

            if rule_action == "delete":
                self.aboutToChange.emit()
                self.database().connector.deleteTableRule(rule_name, (self.schemaName(), self.name))
                self.refreshRules()
                return True

        elif action.startswith("refreshmaterializedview/"):
            if action == "refreshmaterializedview/run":
                self.runRefreshMaterializedView()
                return True

        return Table.runAction(self, action)

    def tableFieldsFactory(self, row, table):
        return PGTableField(row, table)

    def tableConstraintsFactory(self, row, table):
        return PGTableConstraint(row, table)

    def tableIndexesFactory(self, row, table):
        return PGTableIndex(row, table)

    def tableTriggersFactory(self, row, table):
        return PGTableTrigger(row, table)

    def tableRulesFactory(self, row, table):
        return PGTableRule(row, table)

    def info(self):
        from .info_model import PGTableInfo

        return PGTableInfo(self)

    def crs(self):
        """Return the CRS the connector resolves for ``self.srid``."""
        return self.database().connector.getCrs(self.srid)

    def tableDataModel(self, parent):
        from .data_model import PGTableDataModel

        return PGTableDataModel(self, parent)

    def delete(self):
        """Drop this table or view through the connector.

        ``deleted`` is emitted only when the connector call returns a falsy
        value; that return value is passed back to the caller.
        """
        self.aboutToChange.emit()
        if self.isView:
            ret = self.database().connector.deleteView((self.schemaName(), self.name), self._relationType == 'm')
        else:
            ret = self.database().connector.deleteTable((self.schemaName(), self.name))
        if not ret:
            self.deleted.emit()
        return ret
292
293
class PGVectorTable(PGTable, VectorTable):
    """Vector table of a PostGIS database.

    The last four items of ``row`` are (geometry column, geometry type,
    dimension, SRID); everything before them is handed to PGTable.
    """

    def __init__(self, row, db, schema=None):
        tableRow, geometryRow = row[:-4], row[-4:]
        PGTable.__init__(self, tableRow, db, schema)
        VectorTable.__init__(self, db, schema)
        self.geomColumn, self.geomType, self.geomDim, self.srid = geometryRow

    def info(self):
        from .info_model import PGVectorTableInfo

        return PGVectorTableInfo(self)

    def runAction(self, action):
        """Try the PGTable actions first, then fall back to the vector ones."""
        handled = PGTable.runAction(self, action)
        return True if handled else VectorTable.runAction(self, action)

    def geometryType(self):
        """Return the WKT geometry type name, including Z/M suffixes.

        PostGIS records the type without a Z suffix and keeps the dimension
        separately, e.g.:

        | WKT Type     | geomType    | geomDim |
        |--------------|-------------|---------|
        | LineString   | LineString  | 2       |
        | LineStringZ  | LineString  | 3       |
        | LineStringM  | LineStringM | 3       |
        | LineStringZM | LineString  | 4       |
        """
        if self.geomDim == 4:
            return self.geomType + "ZM"
        # a 3-dimensional type already ending in "M" is XYM, not XYZ
        if self.geomDim == 3 and self.geomType[-1] != "M":
            return self.geomType + "Z"
        return self.geomType
328
329
class PGRasterTable(PGTable, RasterTable):
    """Raster table of a PostGIS database, loaded through "postgresraster".

    The last six items of ``row`` are (raster column, pixel type, pixel
    size X, pixel size Y, external flag, SRID); the rest goes to PGTable.
    """

    def __init__(self, row, db, schema=None):
        PGTable.__init__(self, row[:-6], db, schema)
        RasterTable.__init__(self, db, schema)
        self.geomColumn, self.pixelType, self.pixelSizeX, self.pixelSizeY, self.isExternal, self.srid = row[-6:]
        self.geomType = 'RASTER'

    def info(self):
        from .info_model import PGRasterTableInfo

        return PGRasterTableInfo(self)

    @staticmethod
    def _quoteUriValue(value, delim=u"'"):
        """Wrap *value* in *delim*, backslash-escaping backslashes and *delim*.

        This is the quoting accepted by QgsDataSourceUri when it parses a
        datasource string, so quoted values may contain spaces and quotes.
        """
        escaped = value.replace(u'\\', u'\\\\').replace(delim, u'\\' + delim)
        return delim + escaped + delim

    def uri(self, uri=None):
        """Return the datasource string for the postgresraster provider.

        :param uri: QgsDataSourceUri holding the connection parameters;
            defaults to this table's database URI.
        :returns: a space-separated ``key=value`` string that always has a
            dbname and ends with ``table="<schema>"."<name>"``.

        service, dbname, user, password and column values are quoted and
        escaped so they may contain spaces or quotes; host and port are
        emitted as-is.
        """
        if not uri:
            uri = self.database().uri()
        quote = self._quoteUriValue

        database = uri.database()
        if not database:
            # postgresraster provider *requires* a dbname
            connector = self.database().connector
            cursor = connector._execute(None, "SELECT current_database()")
            try:
                database = connector._fetchone(cursor)[0]
            finally:
                # release the cursor even when fetching fails
                connector._close_cursor(cursor)

        parts = []
        if uri.service():
            parts.append(u'service=%s' % quote(uri.service()))
        parts.append(u'dbname=%s' % quote(database))
        if uri.host():
            parts.append(u'host=%s' % uri.host())
        if uri.username():
            parts.append(u'user=%s' % quote(uri.username()))
        if uri.password():
            parts.append(u'password=%s' % quote(uri.password()))
        if uri.port():
            parts.append(u'port=%s' % uri.port())

        # Find first raster field
        for fld in self.fields():
            if fld.dataType == "raster":
                parts.append(u'column=%s' % quote(fld.name))
                break

        schema = self.schemaName() or 'public'
        parts.append(u'table=%s.%s' % (quote(schema, u'"'), quote(self.name, u'"')))

        return u' '.join(parts)

    def mimeUri(self):
        """Return the mime-data URI "raster:postgresraster:<name>:<datasource>".

        ':' separates the fields, so backslashes and colons inside both the
        layer name and the datasource string are backslash-escaped.
        """
        def escape(text):
            return text.replace(u'\\', u'\\\\').replace(u':', u'\\:')

        return u"raster:postgresraster:{}:{}".format(escape(self.name), escape(self.uri()))

    def toMapLayer(self, geometryType=None, crs=None):
        """Create a QgsRasterLayer for this table.

        When the first attempt is invalid, credentials are requested through
        QgsCredentials up to three times; prompting stops as soon as the
        user cancels or a valid layer is obtained. A valid layer gets a
        min/max contrast stretch. ``geometryType`` and ``crs`` are accepted
        for interface compatibility and ignored.

        :returns: the QgsRasterLayer (possibly invalid).
        """
        from qgis.core import QgsRasterLayer, QgsContrastEnhancement, QgsDataSourceUri, QgsCredentials

        provider = "postgresraster"
        rl = QgsRasterLayer(self.uri(), self.name, provider)
        if not rl.isValid():
            err = rl.error().summary()
            uri = QgsDataSourceUri(self.database().uri())
            conninfo = uri.connectionInfo(False)
            username = uri.username()
            password = uri.password()

            for _attempt in range(3):
                (ok, username, password) = QgsCredentials.instance().get(conninfo, username, password, err)
                if not ok:
                    # the user cancelled the credentials dialog
                    break
                uri.setUsername(username)
                uri.setPassword(password)
                # Same provider as the first attempt: without the key,
                # QgsRasterLayer uses its default (GDAL) provider.
                rl = QgsRasterLayer(self.uri(uri), self.name, provider)
                if rl.isValid():
                    break
                # show the latest failure on the next prompt
                err = rl.error().summary()

        if rl.isValid():
            rl.setContrastEnhancement(QgsContrastEnhancement.StretchToMinimumMaximum)
        return rl
404
405
class PGTableField(TableField):
    """Column of a PostGIS table.

    ``row`` is unpacked as (num, name, dataType, charMaxLen, modifier,
    notNull, hasDefault, default, formatted type string); the modifier
    from the row is replaced by the one parsed from the type string.
    """

    # trailing "(...)" of a formatted type, e.g. "numeric(10,2)" -> "10,2";
    # DOTALL keeps '.' matching any character, as with the previous QRegExp
    _MODIFIER_RE = re.compile(r"\((.+)\)$", re.DOTALL)

    def __init__(self, row, table):
        TableField.__init__(self, table)
        self.num, self.name, self.dataType, self.charMaxLen, self.modifier, self.notNull, self.hasDefault, self.default, typeStr = row
        self.primaryKey = False

        # get modifier (e.g. "precision,scale") from formatted type string
        match = self._MODIFIER_RE.search(typeStr.strip())
        self.modifier = match.group(1).strip() if match else None

        # find out whether the field is part of the primary key
        for con in self.table().constraints():
            if con.type == TableConstraint.TypePrimaryKey and self.num in con.columns:
                self.primaryKey = True
                break

    def getComment(self):
        """Return the comment stored for this field, or '' when there is none.

        A single pg_description lookup is made, restricted to the table's
        schema when it is known. Names are embedded as SQL string literals
        with single quotes doubled, so quotes in names cannot break the query.
        """
        tab = self.table()
        connector = tab.database().connector

        def literal(text):
            return u"'%s'" % text.replace(u"'", u"''")

        sql = (u"SELECT pd.description "
               u"FROM pg_description pd "
               u"JOIN pg_class pc ON pd.objoid = pc.oid "
               u"JOIN pg_namespace pn ON pc.relnamespace = pn.oid "
               u"JOIN pg_attribute pa ON pa.attrelid = pc.oid AND pa.attnum = pd.objsubid "
               u"WHERE pd.classoid = 'pg_class'::regclass "
               u"AND pc.relname = %s AND pa.attname = %s") % (literal(tab.name), literal(self.name))
        schema = tab.schemaName()
        if schema:
            # without this, a same-named table in another schema could match
            sql += u" AND pn.nspname = %s" % literal(schema)

        c = connector._execute(None, sql)
        try:
            row = connector._fetchone(c)
        finally:
            connector._close_cursor(c)  # always release the cursor
        # assumes _fetchone returns None when there is no row (DB-API style)
        return row[0] if row else ''
445
446
class PGTableConstraint(TableConstraint):
    """Constraint on a PostGIS table.

    ``row[:5]`` is (name, type code, deferrable flag, deferred flag,
    space-separated column numbers). Check constraints also read ``row[5]``
    (check source); foreign keys read ``row[6]`` (foreign table),
    ``row[7]``/``row[8]`` (on-update/on-delete codes), ``row[9]`` (match
    type code) and ``row[10]`` (foreign keys).
    """

    def __init__(self, row, table):
        TableConstraint.__init__(self, table)
        self.name, constr_type_str, self.isDefferable, self.isDeffered, columns = row[:5]
        # split() without an argument tolerates repeated spaces and yields []
        # for an empty or missing column list instead of raising
        self.columns = [int(num) for num in (columns or '').split()]

        if constr_type_str in TableConstraint.types:
            self.type = TableConstraint.types[constr_type_str]
        else:
            self.type = TableConstraint.TypeUnknown

        if self.type == TableConstraint.TypeCheck:
            self.checkSource = row[5]
        elif self.type == TableConstraint.TypeForeignKey:
            self.foreignTable = row[6]
            self.foreignOnUpdate = TableConstraint.onAction[row[7]]
            self.foreignOnDelete = TableConstraint.onAction[row[8]]
            self.foreignMatchType = TableConstraint.matchTypes[row[9]]
            self.foreignKeys = row[10]
467
468
class PGTableIndex(TableIndex):
    """Index on a PostGIS table; ``row`` is (name, space-separated column numbers, unique flag)."""

    def __init__(self, row, table):
        TableIndex.__init__(self, table)
        self.name, columns, self.isUnique = row
        # split() without an argument tolerates repeated spaces and yields []
        # for an empty or missing column list instead of raising
        self.columns = [int(num) for num in (columns or '').split()]
475
476
class PGTableTrigger(TableTrigger):
    """Trigger on a PostGIS table; ``row`` is (name, function, type, enabled)."""

    def __init__(self, row, table):
        TableTrigger.__init__(self, table)
        (self.name,
         self.function,
         self.type,
         self.enabled) = row
482
483
class PGTableRule(TableRule):
    """Rewrite rule on a PostGIS table; ``row`` is (name, definition)."""

    def __init__(self, row, table):
        TableRule.__init__(self, table)
        (self.name,
         self.definition) = row
489