1# -*- coding: utf-8 -*-
2
3"""
4/***************************************************************************
5Name                 : Virtual layers plugin for DB Manager
6Date                 : December 2015
7copyright            : (C) 2015 by Hugo Mercier
8email                : hugo dot mercier at oslandia dot com
9
10 ***************************************************************************/
11
12/***************************************************************************
13 *                                                                         *
14 *   This program is free software; you can redistribute it and/or modify  *
15 *   it under the terms of the GNU General Public License as published by  *
16 *   the Free Software Foundation; either version 2 of the License, or     *
17 *   (at your option) any later version.                                   *
18 *                                                                         *
19 ***************************************************************************/
20"""
21
22from qgis.PyQt.QtCore import QUrl, QTemporaryFile
23
24from ..connector import DBConnector
25from ..plugin import Table
26
27from qgis.core import (
28    QgsDataSourceUri,
29    QgsVirtualLayerDefinition,
30    QgsProject,
31    QgsMapLayerType,
32    QgsVectorLayer,
33    QgsCoordinateReferenceSystem,
34    QgsWkbTypes
35)
36
37import sqlite3
38
39
class sqlite3_connection(object):
    """Context manager that owns a sqlite3 connection to ``sqlite_file``.

    The connection is opened as soon as the object is constructed, handed
    to the ``with`` body by ``__enter__`` and closed in ``__exit__``.
    Exceptions raised inside the ``with`` body are never suppressed.
    """

    def __init__(self, sqlite_file):
        connection = sqlite3.connect(sqlite_file)
        self.conn = connection

    def __enter__(self):
        conn = self.conn
        return conn

    def __exit__(self, ex_type, value, traceback):
        self.conn.close()
        # Returning False lets an in-flight exception propagate; True is
        # only returned when the body finished without raising.
        if ex_type is not None:
            return False
        return True
51
52
def getQueryGeometryName(sqlite_file):
    """Return the geometry field name recorded in a virtual layer SQLite file.

    Every value of the ``url`` column of the file's ``_meta`` table is parsed
    as a virtual layer definition. The geometry field of the first
    definition that declares a geometry is returned.

    :param sqlite_file: path of the SQLite file to introspect.
    :returns: the geometry field name, or None when no definition declares
        a geometry.
    """
    with sqlite3_connection(sqlite_file) as conn:
        rows = conn.cursor().execute("SELECT url FROM _meta")
        for (url,) in rows:
            definition = QgsVirtualLayerDefinition.fromUrl(QUrl(url))
            if definition.hasDefinedGeometry():
                return definition.geometryField()
    return None
62
63
def classFactory():
    """Return the connector class of this plugin (the class, not an instance).

    NOTE(review): presumably looked up by name by the DB Manager plugin
    loader — confirm against the caller.
    """
    connector_class = VLayerConnector
    return connector_class
66
67
# Tables in DB Manager are identified by their display names.
# This global registry maps each display name to a layer id.
# It is filled when getVectorTables is called.
class VLayerRegistry(object):
    """Shared mapping from table display names to QGIS map layer ids.

    Obtain the shared registry with :meth:`instance`. Entries are plain
    ``name -> layer id`` pairs; :meth:`getLayer` resolves an entry to the
    live layer of the current project.
    """

    # Lazily created shared registry returned by instance().
    _instance = None

    @classmethod
    def instance(cls):
        """Return the shared registry, creating it on first use."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def __init__(self):
        # display name -> layer id
        self.layers = {}

    def reset(self):
        """Forget every registered name."""
        self.layers = {}

    def has(self, k):
        """Return True if display name ``k`` is registered."""
        return k in self.layers

    def __contains__(self, k):
        return self.has(k)

    def get(self, k):
        """Return the layer id registered under ``k``, or None."""
        return self.layers.get(k)

    def __getitem__(self, k):
        return self.get(k)

    def set(self, k, l):
        """Register display name ``k`` for layer id ``l``."""
        self.layers[k] = l

    def __setitem__(self, k, l):
        self.set(k, l)

    def items(self):
        """Return a list of (display name, layer id) pairs."""
        return list(self.layers.items())

    def getLayer(self, l):
        """Return the project layer registered under display name ``l``.

        Returns None when the name is unknown. If the registered layer id is
        no longer part of the current project, the stale entry is removed
        and None is returned.
        """
        lid = self.layers.get(l)
        if lid is None:
            return None
        # mapLayer() returns None for unknown ids, so there is no need to
        # build the whole mapLayers() dict just to test membership.
        layer = QgsProject.instance().mapLayer(lid)
        if layer is None:
            del self.layers[l]
        return layer
112
113
class VLayerConnector(DBConnector):
    """DB Manager connector exposing the vector layers of the current project.

    Every vector layer loaded in ``QgsProject.instance()`` is listed as a
    table, and custom queries are evaluated through the QGIS "virtual"
    vector layer provider. Schema-modifying operations are not supported:
    they print an "**unimplemented**" notice and return False.
    """

    def __init__(self, uri):
        # No connection is opened: all data comes from the current project.
        # NOTE(review): DBConnector.__init__ is not called here — confirm the
        # base class does not need any state set up.
        pass

    def _execute(self, cursor, sql):
        """Return a placeholder cursor that only remembers ``sql``.

        No SQL is executed here. This is only used to get the list of fields
        (see _get_cursor_columns, which reads ``.sql`` back).
        """
        class DummyCursor(object):

            def __init__(self, sql):
                self.sql = sql

            def close(self):
                pass

        return DummyCursor(sql)

    def _get_cursor(self, name=None):
        # Debug trace only; no real cursor exists, so None is returned.
        # fix_print_with_import
        print(("_get_cursor_", name))

    def _get_cursor_columns(self, c):
        """Return the column names produced by the query held by ``c``.

        The query (``c.sql``) is loaded as a temporary virtual layer. Its
        attribute field names are returned, followed by the geometry column
        name when the result has a geometry. Returns [] when the query does
        not produce a valid layer.
        """
        # Reserve a unique temporary file path for the virtual layer. ``tf``
        # is kept referenced until the method returns.
        # NOTE(review): presumably because QTemporaryFile removes its file
        # when destroyed — confirm.
        tf = QTemporaryFile()
        tf.open()
        tmp = tf.fileName()
        tf.close()

        df = QgsVirtualLayerDefinition()
        df.setFilePath(tmp)
        df.setQuery(c.sql)
        p = QgsVectorLayer(df.toString(), "vv", "virtual")
        if not p.isValid():
            return []
        columns = [field.name() for field in p.fields()]
        if p.geometryType() != QgsWkbTypes.NullGeometry:
            # The geometry column is not among the attribute fields; read its
            # name back from the temporary file.
            # NOTE(review): assumes the virtual provider stores its definition
            # in the file's _meta table.
            gn = getQueryGeometryName(tmp)
            if gn:
                columns.append(gn)
        return columns

    def uri(self):
        return QgsDataSourceUri("qgis")

    def getInfo(self):
        return "info"

    def getSpatialInfo(self):
        return None

    def hasSpatialSupport(self):
        return True

    def hasRasterSupport(self):
        return False

    def hasCustomQuerySupport(self):
        return True

    def hasTableColumnEditingSupport(self):
        return False

    def fieldTypes(self):
        """Return the field type names offered for this connector."""
        return [
            "integer", "bigint", "smallint",  # integers
            "real", "double", "float", "numeric",  # floats
            "varchar", "varchar(255)", "character(20)", "text",  # strings
            "date", "datetime"  # date/time
        ]

    def getSchemas(self):
        # Project layers have no schemas.
        return None

    def getTables(self, schema=None, add_sys_tables=False):
        """ get list of tables (same as getVectorTables; arguments are ignored) """
        return self.getVectorTables()

    def getVectorTables(self, schema=None):
        """Register and list every vector layer of the current project.

        The shared VLayerRegistry is reset and refilled so that each returned
        table name maps to its layer id. A layer whose name is already taken
        is listed under its layer id instead. ``schema`` is ignored.

        Returns a list of tuples:

        * spatial layers: (Table.VectorType, name, False, False, layer id,
          'geometry', upper-cased flat geometry type name, coordinate
          dimension ('XY' plus 'Z' and/or 'M', or None when the type name is
          empty), PostGIS SRID of the layer CRS)
        * non-spatial layers: (Table.TableType, name, False, False)

        NOTE(review): the two False flags presumably mean "is view" and
        "is system table" — confirm against Table.
        """
        reg = VLayerRegistry.instance()
        reg.reset()
        lst = []
        for l in QgsProject.instance().mapLayers().values():
            if l.type() != QgsMapLayerType.VectorLayer:
                continue

            lname = l.name()
            # if there is already a layer with this name, use the layer id
            # as name
            if reg.has(lname):
                lname = l.id()
            reg.set(lname, l.id())

            if not l.isSpatial():
                lst.append((Table.TableType, lname, False, False))
                continue

            g = l.dataProvider().wkbType()
            geomType = QgsWkbTypes.displayString(QgsWkbTypes.flatType(g)).upper()
            dim = None
            if geomType:
                dim = 'XY'
                if QgsWkbTypes.hasZ(g):
                    dim += 'Z'
                if QgsWkbTypes.hasM(g):
                    dim += 'M'
            lst.append(
                (Table.VectorType, lname, False, False, l.id(), 'geometry', geomType, dim, l.crs().postgisSrid()))
        return lst

    def getRasterTables(self, schema=None):
        return []

    def getTableRowCount(self, table):
        """Return the feature count of the layer named ``table[1]``.

        Returns None if the name is unknown or the layer is invalid.
        """
        l = VLayerRegistry.instance().getLayer(table[1])
        if not l or not l.isValid():
            return None
        return l.featureCount()

    def getTableFields(self, table):
        """Return the list of columns of the layer named ``table[1]``.

        Each column is a tuple (index, name, type name, not-null, default,
        primary key). A spatial layer gets an extra trailing "geometry"
        column. Returns [] if the name is unknown or the layer is invalid.
        """
        l = VLayerRegistry.instance().getLayer(table[1])
        if not l or not l.isValid():
            return []
        fields = l.dataProvider().fields()
        # id, name, type, nonnull, default, pk
        columns = [(i, field.name(), field.typeName(), False, None, False)
                   for i, field in enumerate(fields)]
        if l.isSpatial():
            columns.append((fields.size(), "geometry", "geometry", False, None, False))
        return columns

    def getTableIndexes(self, table):
        return []

    def getTableConstraints(self, table):
        return None

    def getTableTriggers(self, table):
        return []

    def deleteTableTrigger(self, trigger, table=None):
        return

    def getTableExtent(self, table, geom):
        """Return (xmin, ymin, xmax, ymax) of a layer's extent.

        ``table`` is a pair (is_id, t). When ``is_id`` is truthy, ``t`` is a
        layer id; otherwise it is a display name registered in
        VLayerRegistry. ``geom`` is ignored. Returns None if the layer cannot
        be found or is invalid.
        """
        is_id, t = table
        if is_id:
            l = QgsProject.instance().mapLayer(t)
        else:
            l = VLayerRegistry.instance().getLayer(t)
        if not l or not l.isValid():
            return None
        e = l.extent()
        return (e.xMinimum(), e.yMinimum(), e.xMaximum(), e.yMaximum())

    def getViewDefinition(self, view):
        print("**unimplemented** getViewDefinition")

    def getSpatialRefInfo(self, srid):
        """Return the description of the CRS built from ``srid``."""
        # NOTE(review): srid is presumably the PostGIS SRID reported by
        # getVectorTables — confirm for other callers.
        crs = QgsCoordinateReferenceSystem(srid)
        return crs.description()

    def isVectorTable(self, table):
        return True

    def isRasterTable(self, table):
        return False

    def createTable(self, table, field_defs, pkey):
        print("**unimplemented** createTable")
        return False

    def deleteTable(self, table):
        print("**unimplemented** deleteTable")
        return False

    def emptyTable(self, table):
        print("**unimplemented** emptyTable")
        return False

    def renameTable(self, table, new_table):
        print("**unimplemented** renameTable")
        return False

    def moveTable(self, table, new_table, new_schema=None):
        print("**unimplemented** moveTable")
        return False

    def createView(self, view, query):
        print("**unimplemented** createView")
        return False

    def deleteView(self, view):
        print("**unimplemented** deleteView")
        return False

    def renameView(self, view, new_name):
        print("**unimplemented** renameView")
        return False

    def runVacuum(self):
        print("**unimplemented** runVacuum")
        return False

    def addTableColumn(self, table, field_def):
        print("**unimplemented** addTableColumn")
        return False

    def deleteTableColumn(self, table, column):
        print("**unimplemented** deleteTableColumn")
        # Return False like every other unimplemented operation.
        return False

    def updateTableColumn(self, table, column, new_name, new_data_type=None, new_not_null=None, new_default=None, comment=None):
        print("**unimplemented** updateTableColumn")
        # Return False like every other unimplemented operation.
        return False

    def renameTableColumn(self, table, column, new_name):
        print("**unimplemented** renameTableColumn")
        return False

    def setColumnType(self, table, column, data_type):
        print("**unimplemented** setColumnType")
        return False

    def setColumnDefault(self, table, column, default):
        print("**unimplemented** setColumnDefault")
        return False

    def setColumnNull(self, table, column, is_null):
        print("**unimplemented** setColumnNull")
        return False

    def isGeometryColumn(self, table, column):
        print("**unimplemented** isGeometryColumn")
        return False

    def addGeometryColumn(self, table, geom_column='geometry', geom_type='POINT', srid=-1, dim=2):
        print("**unimplemented** addGeometryColumn")
        return False

    def deleteGeometryColumn(self, table, geom_column):
        print("**unimplemented** deleteGeometryColumn")
        return False

    def addTableUniqueConstraint(self, table, column):
        print("**unimplemented** addTableUniqueConstraint")
        return False

    def deleteTableConstraint(self, table, constraint):
        print("**unimplemented** deleteTableConstraint")
        return False

    def addTablePrimaryKey(self, table, column):
        print("**unimplemented** addTablePrimaryKey")
        return False

    def createTableIndex(self, table, name, column, unique=False):
        print("**unimplemented** createTableIndex")
        return False

    def deleteTableIndex(self, table, name):
        print("**unimplemented** deleteTableIndex")
        return False

    def createSpatialIndex(self, table, geom_column='geometry'):
        print("**unimplemented** createSpatialIndex")
        return False

    def deleteSpatialIndex(self, table, geom_column='geometry'):
        print("**unimplemented** deleteSpatialIndex")
        return False

    def hasSpatialIndex(self, table, geom_column='geometry'):
        print("**unimplemented** hasSpatialIndex")
        return False

    def execution_error_types(self):
        print("**unimplemented** execution_error_types")
        # An empty tuple (still falsy) is the valid "no exception types"
        # value: it can be used in an ``except`` clause and concatenated with
        # other tuples, whereas False raises TypeError there.
        return ()

    def connection_error_types(self):
        print("**unimplemented** connection_error_types")
        # See execution_error_types: () is a valid, empty except target.
        return ()

    def getSqlDictionary(self):
        """Return the SQL dictionary extended with table and field names.

        The "identifier" entry lists every table name followed by its field
        names.
        """
        from .sql_dictionary import getSqlDictionary
        sql_dict = getSqlDictionary()

        items = []
        for tbl in self.getTables():
            items.append(tbl[1])  # table name
            items.extend(fld[1] for fld in self.getTableFields((None, tbl[1])))  # field names

        sql_dict["identifier"] = items
        return sql_dict

    def getQueryBuilderDictionary(self):
        from .sql_dictionary import getQueryBuilderDictionary

        return getQueryBuilderDictionary()
427