1# -*- coding: utf-8 -*-
2
3"""
4/***************************************************************************
5Name                 : DB Manager
6Description          : Database manager plugin for QGIS
7Date                 : May 23, 2011
8copyright            : (C) 2011 by Giuseppe Sucameli
9email                : brush.tyler@gmail.com
10
11 ***************************************************************************/
12
13/***************************************************************************
14 *                                                                         *
15 *   This program is free software; you can redistribute it and/or modify  *
16 *   it under the terms of the GNU General Public License as published by  *
17 *   the Free Software Foundation; either version 2 of the License, or     *
18 *   (at your option) any later version.                                   *
19 *                                                                         *
20 ***************************************************************************/
21"""
22from builtins import str
23from builtins import range
24
25from qgis.PyQt.QtCore import (Qt,
26                              QTime,
27                              QRegExp,
28                              QAbstractTableModel,
29                              pyqtSignal,
30                              QObject)
31from qgis.PyQt.QtGui import (QFont,
32                             QStandardItemModel,
33                             QStandardItem)
34from qgis.PyQt.QtWidgets import QApplication
35
36from qgis.core import QgsTask
37
38from .plugin import DbError, BaseError
39
40
class BaseTableModel(QAbstractTableModel):
    """Read-only Qt table model over an in-memory header and row list.

    ``_header`` holds the column names and ``resdata`` the rows, addressed as
    ``resdata[row][col]``. Subclasses can override getData() to serve cells
    from somewhere else (e.g. a lazily fetched window).
    """

    def __init__(self, header=None, data=None, parent=None):
        QAbstractTableModel.__init__(self, parent)
        # any falsy argument (None, empty list, ...) becomes an empty list
        self._header = header or []
        self.resdata = data or []

    def headerToString(self, sep=u"\t"):
        """Return the column names joined by *sep*."""
        return sep.join(self._header)

    def rowToString(self, row, sep=u"\t"):
        """Return every cell of *row*, converted with str(), joined by *sep*."""
        cells = [str(self.getData(row, column)) for column in range(self.columnCount())]
        return sep.join(cells)

    def getData(self, row, col):
        """Return the raw value stored at (*row*, *col*)."""
        return self.resdata[row][col]

    def columnNames(self):
        """Return a new list containing the column names."""
        return list(self._header)

    def rowCount(self, parent=None):
        return len(self.resdata)

    def columnCount(self, parent=None):
        return len(self._header)

    def data(self, index, role):
        """Return the cell value for Display, Edit and Font roles; None otherwise.

        EditRole yields the raw value. FontRole yields an italic font only for
        NULL (None) cells. DisplayRole shows None as "NULL", hides memoryview
        (binary) values and truncates strings to their first 300 characters.
        """
        if role not in (Qt.DisplayRole, Qt.EditRole, Qt.FontRole):
            return None

        value = self.getData(index.row(), index.column())

        if role == Qt.EditRole:
            return value

        if role == Qt.FontRole:
            # only NULL cells get a custom (italic) font
            if value is not None:
                return None
            italic = QFont()
            italic.setItalic(True)
            return italic

        # from here on: Qt.DisplayRole
        if value is None:
            return "NULL"
        if isinstance(value, memoryview):
            # binary payloads are not rendered
            return None
        if isinstance(value, str) and len(value) > 300:
            # too much data to display: keep only the beginning
            value = value[:300]
        try:
            return str(value)
        except UnicodeDecodeError:
            # decode as utf-8, replacing undecodable bytes
            return str(value, 'utf-8', 'replace')

    def headerData(self, section, orientation, role):
        """Return 1-based row numbers vertically and column names horizontally."""
        if role != Qt.DisplayRole:
            return None
        if orientation == Qt.Vertical:
            return section + 1
        return self._header[section]
111
112
class TableDataModel(BaseTableModel):
    """Model showing the rows of a database table through a sliding window.

    ``resdata`` holds at most ``fetchedCount`` rows, starting at table row
    ``fetchedFrom``. When a requested row falls outside that window,
    getData() calls fetchMoreData(start). That method is a no-op here and
    presumably is overridden by provider-specific subclasses to refill
    ``resdata`` and update ``fetchedFrom`` (TODO confirm).
    """

    def __init__(self, table, parent=None):
        self.db = table.database().connector
        self.table = table

        BaseTableModel.__init__(self, [fld.name for fld in table.fields()], None, parent)

        # quoted identifier for each table field, in field order
        self.fields = [self._sanitizeTableField(fld) for fld in table.fields()]

        self.fetchedCount = 201
        # place the initial window entirely before row 0 so that the first
        # getData() call is outside it and triggers fetchMoreData(0)
        self.fetchedFrom = -self.fetchedCount - 1

    def _sanitizeTableField(self, field):
        """ quote column names to avoid some problems (e.g. columns with upper case) """
        return self.db.quoteId(field)

    def getData(self, row, col):
        """Return cell (*row*, *col*), refetching the window when *row* is outside it.

        The new window is roughly centred on *row*, clamped so that it does
        not start before row 0 and is shifted back near the end of the table.
        """
        window_end = self.fetchedFrom + self.fetchedCount
        if not (self.fetchedFrom <= row < window_end):
            half = self.fetchedCount / 2
            if row + half >= self.rowCount():
                start = int(self.rowCount() - half)
            else:
                start = int(row - half)
            self.fetchMoreData(max(start, 0))
        return self.resdata[row - self.fetchedFrom][col]

    def fetchMoreData(self, row_start):
        """Hook for loading the window starting at *row_start*; does nothing here."""
        pass

    def rowCount(self, index=None):
        """Return the table's row count, or 0 if it is unknown or there are no columns."""
        total = self.table.rowCount
        # case for tables with no columns ... any reason to use them? :-)
        if total is None or self.columnCount(index) <= 0:
            return 0
        return total
149
150
class SqlResultModelAsync(QObject):
    """Container for the result of a SQL statement computed by a task.

    ``task`` is None here. Presumably subclasses create a SqlResultModelTask
    and connect its completion to modelDone() (TODO confirm). When
    modelDone() runs, the task's status, model and error are copied onto
    this object and the ``done`` signal is emitted.
    """
    done = pyqtSignal()

    def __init__(self):
        super().__init__()
        self.error = BaseError('')
        self.status = None
        self.model = None
        self.task = None
        self.canceled = False

    def cancel(self):
        """Flag the request as canceled and cancel the running task, if any."""
        self.canceled = True
        if self.task:
            self.task.cancel()

    def modelDone(self):
        """Copy the task's status, model and error onto this object, then emit ``done``."""
        if self.task:
            # On a QgsTask, ``status`` is a method. Store the status value it
            # returns, not the bound method. A plain attribute set by a
            # subclass is kept unchanged.
            status = self.task.status
            self.status = status() if callable(status) else status
            self.model = self.task.model
            self.error = self.task.error

        self.done.emit()
174
175
class SqlResultModelTask(QgsTask):
    """QgsTask carrying the inputs and outputs of one SQL execution.

    This base class only stores state; it defines no run(). Presumably
    provider subclasses implement run() and fill ``model`` / ``error``
    (TODO confirm).
    """

    def __init__(self, db, sql, parent):
        description = QApplication.translate("DBManagerPlugin", "Executing SQL")
        super().__init__(description=description)
        # inputs
        self.db = db
        self.sql = sql
        # NOTE(review): this instance attribute shadows QObject.parent();
        # confirm nothing calls task.parent() expecting the Qt method.
        self.parent = parent
        # outputs: an empty error until execution reports one
        self.error = BaseError('')
        self.model = None
185
186
class SqlResultModel(BaseTableModel):
    """Table model holding the complete result of one SQL statement.

    The statement runs synchronously in the constructor through the
    connector's private helpers. All rows are fetched at once, the
    connection is committed and the cursor is closed. The cursor is closed
    even when one of those steps raises.
    """

    def __init__(self, db, sql, parent=None):
        """Execute *sql* through ``db.connector`` and load its result.

        :param db: database object whose ``connector`` executes the query.
        :param sql: SQL text to execute.
        :param parent: optional parent QObject for the model.
        """
        self.db = db.connector

        t = QTime()
        t.start()
        c = self.db._execute(None, sql)
        try:
            self._affectedRows = 0
            data = []
            header = self.db._get_cursor_columns(c)
            if header is None:
                header = []

            try:
                # only statements that produce columns have rows to fetch
                if len(header) > 0:
                    data = self.db._fetchall(c)
                self._affectedRows = len(data)
            except DbError:
                # nothing to fetch!
                data = []
                header = []

            super().__init__(header, data, parent)

            # commit before closing the cursor to make sure that the changes are stored
            self.db._commit()
        finally:
            # always release the cursor, also when fetching/committing fails
            c.close()
        self._secs = t.elapsed() / 1000.0

    def secs(self):
        """Return the time spent executing, fetching and committing, in seconds."""
        return self._secs

    def affectedRows(self):
        """Return the number of fetched rows (0 when the statement returned no columns)."""
        return self._affectedRows
225
226
class SimpleTableModel(QStandardItemModel):
    """QStandardItemModel with a fixed horizontal header.

    ``editable`` decides whether items built by rowFromData() carry the
    ItemIsEditable flag. Subclasses map rows to domain objects by
    overriding _getNewObject() and getObject().
    """

    def __init__(self, header, editable=False, parent=None):
        self.header = header
        self.editable = editable
        QStandardItemModel.__init__(self, 0, len(self.header), parent)

    def rowFromData(self, data):
        """Return one QStandardItem per value in *data*, text set via str()."""
        items = []
        for value in data:
            item = QStandardItem(str(value))
            if self.editable:
                item.setFlags(item.flags() | Qt.ItemIsEditable)
            else:
                item.setFlags(item.flags() & ~Qt.ItemIsEditable)
            items.append(item)
        return items

    def headerData(self, section, orientation, role):
        """Return the column title for horizontal display requests, None otherwise."""
        if orientation != Qt.Horizontal or role != Qt.DisplayRole:
            return None
        return self.header[section]

    def _getNewObject(self):
        """Return a blank object for a row; the base implementation returns None."""
        return None

    def getObject(self, row):
        """Return the object for *row*; the base implementation ignores *row*."""
        return self._getNewObject()

    def getObjectIter(self):
        """Yield getObject(row) for every row, in order."""
        for index in range(self.rowCount()):
            yield self.getObject(index)
256
257
class TableFieldsModel(SimpleTableModel):
    """Five-column model (Name, Type, Null, Default, Comment) of table fields.

    Column 0 keeps the original field object under Qt.UserRole and column 1
    keeps its primary-key flag. The Null column is shown as a checkbox,
    where checked means the field is nullable.
    """

    def __init__(self, parent, editable=False):
        SimpleTableModel.__init__(self, ['Name', 'Type', 'Null', 'Default', 'Comment'], editable, parent)

    def headerData(self, section, orientation, role):
        """Number rows from 1; defer horizontal headers to the base class."""
        if orientation != Qt.Vertical or role != Qt.DisplayRole:
            return SimpleTableModel.headerData(self, section, orientation, role)
        return section + 1

    def flags(self, index):
        """Make the Null column (2) checkable instead of text-editable."""
        flags = SimpleTableModel.flags(self, index)
        if index.column() != 2 or not (flags & Qt.ItemIsEditable):
            return flags
        return (flags & ~Qt.ItemIsEditable) | Qt.ItemIsUserCheckable

    def append(self, fld):
        """Append a row describing *fld*."""
        cells = [fld.name, fld.type2String(), not fld.notNull, fld.default2String(), fld.getComment()]
        self.appendRow(self.rowFromData(cells))
        last = self.rowCount() - 1

        def put(column, value, role):
            self.setData(self.index(last, column), value, role)

        put(0, fld, Qt.UserRole)
        put(1, fld.primaryKey, Qt.UserRole)
        # blank the True/False text; the checkbox state carries the value
        put(2, None, Qt.DisplayRole)
        put(2, Qt.Unchecked if fld.notNull else Qt.Checked, Qt.CheckStateRole)

    def _getNewObject(self):
        from .plugin import TableField

        return TableField(None)

    def getObject(self, row):
        """Return the field for *row*, updated from the row's current cells."""
        stored = self.data(self.index(row, 0), Qt.UserRole)
        fld = self._getNewObject() if stored is None else stored
        fld.name = self.data(self.index(row, 0)) or ""

        typestr = self.data(self.index(row, 1)) or ""
        # split e.g. "varchar(255)" into data type and modifier
        regex = QRegExp("([^\\(]+)\\(([^\\)]+)\\)")
        if regex.indexIn(typestr) < 0:
            fld.modifier = None
            fld.dataType = typestr
        else:
            fld.dataType = regex.cap(1).strip()
            fld.modifier = regex.cap(2).strip()

        fld.notNull = self.data(self.index(row, 2), Qt.CheckStateRole) == Qt.Unchecked
        fld.primaryKey = self.data(self.index(row, 1), Qt.UserRole)
        fld.comment = self.data(self.index(row, 4))
        return fld

    def getFields(self):
        """Return the fields of all rows as a list."""
        return list(self.getObjectIter())
312
313
class TableConstraintsModel(SimpleTableModel):
    """Three-column model (Name, Type, Column(s)) listing table constraints.

    Column 0 stores the constraint object, column 1 its type and column 2
    its columns, all under Qt.UserRole.
    """

    def __init__(self, parent, editable=False):
        # literal translate() calls are kept so pylupdate can extract them
        SimpleTableModel.__init__(self, [QApplication.translate("DBManagerPlugin", 'Name'),
                                         QApplication.translate("DBManagerPlugin", 'Type'),
                                         QApplication.translate("DBManagerPlugin", 'Column(s)')], editable, parent)

    def append(self, constr):
        """Append a row describing *constr*."""
        names = [str(field.name) for _num, field in constr.fields().items()]
        cells = [constr.name, constr.type2String(), u", ".join(names)]
        self.appendRow(self.rowFromData(cells))
        last = self.rowCount() - 1
        self.setData(self.index(last, 0), constr, Qt.UserRole)
        self.setData(self.index(last, 1), constr.type, Qt.UserRole)
        self.setData(self.index(last, 2), constr.columns, Qt.UserRole)

    def _getNewObject(self):
        from .plugin import TableConstraint

        return TableConstraint(None)

    def getObject(self, row):
        """Return the constraint for *row*, updated from the row's cells."""
        constr = self.data(self.index(row, 0), Qt.UserRole) or self._getNewObject()
        constr.name = self.data(self.index(row, 0)) or ""
        constr.type = self.data(self.index(row, 1), Qt.UserRole)
        constr.columns = self.data(self.index(row, 2), Qt.UserRole)
        return constr

    def getConstraints(self):
        """Return the constraints of all rows as a list."""
        return list(self.getObjectIter())
349
350
class TableIndexesModel(SimpleTableModel):
    """Two-column model (Name, Column(s)) listing table indexes.

    Column 0 stores the index object and column 1 its columns, both under
    Qt.UserRole.
    """

    def __init__(self, parent, editable=False):
        # literal translate() calls are kept so pylupdate can extract them
        SimpleTableModel.__init__(self, [QApplication.translate("DBManagerPlugin", 'Name'),
                                         QApplication.translate("DBManagerPlugin", 'Column(s)')], editable, parent)

    def append(self, idx):
        """Append a row describing *idx*."""
        field_names = [str(fld.name) for _num, fld in idx.fields().items()]
        data = [idx.name, u", ".join(field_names)]
        self.appendRow(self.rowFromData(data))
        row = self.rowCount() - 1
        self.setData(self.index(row, 0), idx, Qt.UserRole)
        self.setData(self.index(row, 1), idx.columns, Qt.UserRole)

    def _getNewObject(self):
        from .plugin import TableIndex

        return TableIndex(None)

    def getObject(self, row):
        """Return the index for *row*, updated from the row's cells."""
        idx = self.data(self.index(row, 0), Qt.UserRole)
        if not idx:
            idx = self._getNewObject()
        # default to "" (as the fields/constraints models do) rather than None
        idx.name = self.data(self.index(row, 0)) or ""
        idx.columns = self.data(self.index(row, 1), Qt.UserRole)
        return idx

    def getIndexes(self):
        """Return the indexes of all rows as a list."""
        return list(self.getObjectIter())
383