1#!/usr/local/bin/python3.8
2# -*- coding: utf-8 -*-
3############################################################################
4#
5# MODULE:    v.class.mlpy
6# AUTHOR(S): Vaclav Petras
# PURPOSE:   Classifies features in vector map.
8# COPYRIGHT: (C) 2012 by Vaclav Petras, and the GRASS Development Team
9#
10#  This program is free software; you can redistribute it and/or modify
11#  it under the terms of the GNU General Public License as published by
12#  the Free Software Foundation; either version 2 of the License, or
13#  (at your option) any later version.
14#
15#  This program is distributed in the hope that it will be useful,
16#  but WITHOUT ANY WARRANTY; without even the implied warranty of
17#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
18#  GNU General Public License for more details.
19#
20############################################################################
21
22#%module
23#% description: Vector supervised classification tool which uses attributes as classification parameters (order of columns matters, names not), cat column identifies feature, class_column is excluded from classification parameters.
24#% keyword: vector
25#% keyword: classification
26#% keyword: supervised
27#%end
28#%option G_OPT_V_MAP
29#%  key: input
30#%  description: Input vector map (attribute table required)
31#%  required: yes
32#%  multiple: no
33#%end
34#%option G_OPT_V_MAP
35#%  key: training
36#%  description: Training vector map (attribute table required)
37#%  required: yes
38#%  multiple: no
39#%end
40#%option G_OPT_V_FIELD
41#%  key: class_column
42#%  type: string
43#%  label: Name of column containing class
44#%  description: Used for both input/output and training dataset. If column does not exists in input map attribute table, it will be created.
45#%  required: no
46#%  multiple: no
47#%  answer: class
48#%end
49#%option
50#%  key: columns
51#%  type: string
52#%  label: Columns to be used in classification
53#%  description: Columns to be used in classification. If left empty, all columns will be used for classification except for class_column and cat column.
54#%  required: no
55#%  multiple: yes
56#%end
57
58
59# TODO: add other classifiers
60# TODO: improve doc
61# TODO: input/training could be multiple
62# TODO: handle layers
# TODO: output to new map (all classes/one class), depends on what is faster
64
65
66import grass.script as grass
67
68import numpy as np
69
70
def addColumn(mapName, columnName, columnType):
    """Add a column to the attribute table of a vector map.

    Runs v.db.addcolumn with a "<name> <type>" column definition.
    """
    definition = ' '.join([columnName, columnType])
    grass.run_command('v.db.addcolumn', map=mapName, columns=definition)
76
77
def hasColumn(tableDescription, column):
    """Tell whether a column name appears in a table description.

    The description is a mapping whose 'cols' entry is a sequence of
    column records; the first item of each record is the column name.

    TODO: This should be part of some object in the lib.
    """
    return any(record[0] == column for record in tableDescription['cols'])
87
88
def updateColumn(mapName, column, cats, values=None):
    r"""!Updates column values for rows with the given categories.

    One UPDATE statement is built per category and all of them are sent
    to db.execute through its standard input in a single call.  Nothing
    is executed when there are no categories to update.

    \param mapName name of the table to be updated
    \param column name of the column to be set
    \param cats categories to be updated
    or a sequence of (cat, value) pairs if \p values is None
    \param values values to be set for column (indexed in step with
    \p cats) or \c None
    """
    if values is None:
        # every item carries its own value
        pairs = ((item[0], item[1]) for item in cats)
    else:
        # values are looked up by position, so a too short `values`
        # still raises IndexError instead of being silently truncated
        pairs = ((cat, values[i]) for i, cat in enumerate(cats))

    statements = ''.join(
        'UPDATE %s SET %s = %s WHERE cat = %s;\n' % (mapName, column,
                                                     val, cat)
        for cat, val in pairs)

    if not statements:
        # nothing to update, avoid a pointless db.execute call
        return

    grass.write_command('db.execute', input='-', stdin=statements)
110
111
class Classifier:

    """!Adapter between the mlpy library and the rest of this module.

    The methods accept plain Python sequences and convert them to numpy
    arrays internally before handing them to mlpy.
    """

    def __init__(self):
        """Create the wrapped mlpy DLDA classifier (delta=0.01).

        mlpy is imported here rather than at module level so that a
        missing library is reported through grass.fatal.
        """
        try:
            import mlpy
        except ImportError:
            grass.fatal(_("Cannot import mlpy (http://mlpy.sourceforge.net)"
                          " library."
                          " Please install it or ensure that it is on path"
                          " (use PYTHONPATH variable)."))
        # Pylint has a problem with mlpy and v.class.mlpy.py,
        # thus warnings for objects from mlpy have to be disabled
        self.mlclassifier = mlpy.DLDA(delta=0.01)  # pylint: disable=E1101

    def learn(self, values, classes):
        """Train the classifier on rows of values and their classes."""
        samples = np.array(values)
        labels = np.array(classes)
        self.mlclassifier.learn(samples, labels)

    def pred(self, values):
        """Return what the wrapped classifier predicts for rows of values."""
        samples = np.array(values)
        return self.mlclassifier.pred(samples)
136
137
138# TODO: raise exception when str can not be float
139# TODO: repair those functions, probably create a class
140# TODO: use numpy or array
def fromDbTableToSimpleTable(dbTable, columnsDescription, columnWithClass):
    """Convert database rows to rows of floats usable as parameters.

    A cell is kept when the name of its column (first item of the
    description record at the same position) is neither the class column
    nor 'cat'.  Kept cells are converted with float().
    """
    skipped = (columnWithClass, 'cat')
    return [
        [float(cell)
         for position, cell in enumerate(record)
         if columnsDescription[position][0] not in skipped]
        for record in dbTable
    ]
152
153
def extractColumnWithClass(dbTable, columnsDescription, columnWithClass):
    """Collect the class column of all rows as a flat list of floats.

    The column is located by comparing columnWithClass with the name
    stored as the first item of each description record.
    """
    return [
        float(cell)
        for record in dbTable
        for position, cell in enumerate(record)
        if columnsDescription[position][0] == columnWithClass
    ]
163
164
def extractNthColumn(dbTable, columnNumber):
    """Collect the column at position columnNumber as a list of floats.

    Rows which do not have a cell at that position are skipped.  Only
    non-negative positions are matched, so a negative columnNumber gives
    an empty list.  The cell is read by direct indexing instead of
    walking through all cells of every row.
    """
    return [float(row[columnNumber])
            for row in dbTable
            if 0 <= columnNumber < len(row)]
173
174
def extractColumnWithCats(dbTable, columnsDescription):
    """Collect the 'cat' column of all rows as a flat list of floats.

    The column is located through the name stored as the first item of
    each description record.
    """
    return [
        float(cell)
        for record in dbTable
        for position, cell in enumerate(record)
        if columnsDescription[position][0] == 'cat'
    ]
184
185
186# unused
def fatal_noAttributeTable(mapName):
    """Report through grass.fatal that the map lacks a usable table."""
    message = _("Vector map <%s> has no or empty attribute table") % mapName
    grass.fatal(message)
190
191
def fatal_noEnoughColumns(mapName, ncols, required):
    """Report through grass.fatal that the table has too few columns.

    mapName is the name of the vector map, ncols the number of columns
    found and required the minimal number of columns expected.
    """
    # the closing parenthesis was missing in the original message
    grass.fatal(_("Not enough columns in vector map <%(map)s>"
                  " (found %(ncols)s, expected at least %(r)s)")
                % {'map': mapName, 'ncols': ncols, 'r': required})
196
197
def fatal_noClassColumn(mapName, columnName):
    """Report through grass.fatal that the class column is missing."""
    details = {'map': mapName, 'col': columnName}
    message = _("Vector map <%(map)s> does not have"
                " the column <%(col)s> containing class") % details
    grass.fatal(message)
202
203
def fatal_noRows(mapName):
    """Report through grass.fatal that the attribute table is empty."""
    details = {'map': mapName}
    message = _("Empty attribute table for map vector <%(map)s>") % details
    grass.fatal(message)
207
208
def checkNcols(mapName, tableDescription, requiredNcols):
    """Check that the table has at least requiredNcols columns.

    The count is read from the 'ncols' entry of the description; when it
    is too low the problem is handed over to fatal_noEnoughColumns.
    """
    found = tableDescription['ncols']
    if found >= requiredNcols:
        return
    fatal_noEnoughColumns(mapName, found, requiredNcols)
213
214
def checkNrows(mapName, tableDescription):
    """Check that the table has at least one row.

    The count is read from the 'nrows' entry of the description; an
    empty table is handed over to fatal_noRows.
    """
    if tableDescription['nrows'] > 0:
        return
    fatal_noRows(mapName)
218
219
def checkDbConnection(mapName):
    """!Checks if vector map has an attribute table.

    A falsy result of grass.vector_db for the map is reported through
    grass.fatal.

    TODO: check layer
    """
    if grass.vector_db(mapName):
        return
    grass.fatal(_("Vector map <%s> has no attribute table") % mapName)
228
229
def _selectedColumnsDescription(firstColumn, columns):
    """Describe the layout of rows produced by 'SELECT firstColumn,columns'.

    Returns one record per selected column, in selection order, with the
    column name as its first item, i.e. the shape the extract* helpers
    and fromDbTableToSimpleTable read from a table description.
    """
    names = [firstColumn] + [name.strip() for name in columns.split(',')]
    return [(name,) for name in names]


def main():
    """Classify features of the input map using the training map.

    Steps: parse options, check that both maps have a non-empty
    attribute table (and that the training map has the class column),
    load the attributes, train the classifier on the training map,
    predict classes for the input map and write them to the class
    column of the input map (creating the column if it is missing).
    """
    options, unused = grass.parser()

    mapName = options['input']
    trainingMapName = options['training']

    columnWithClass = options['class_column']

    useAllColumns = True
    if options['columns']:
        # columns as string
        columns = options['columns'].strip()
        useAllColumns = False

    # TODO: allow same input and output map only if --overwrite was specified
    # TODO: is adding column overwriting or overwriting is only updating of existing?

    # variable names connected to training dataset have training prefix
    # variable names connected to classified dataset have no prefix

    # checking database connection (if map has a table)
    # TODO: layer
    checkDbConnection(trainingMapName)
    checkDbConnection(mapName)

    # loading descriptions first to check them

    trainingTableDescription = grass.db_describe(table=trainingMapName)

    if useAllColumns:
        # cat, class and at least one parameter column
        trainingMinNcols = 3
        checkNcols(trainingMapName, trainingTableDescription, trainingMinNcols)

    checkNrows(trainingMapName, trainingTableDescription)

    if not hasColumn(trainingTableDescription, columnWithClass):
        fatal_noClassColumn(trainingMapName, columnWithClass)

    tableDescription = grass.db_describe(table=mapName)

    if useAllColumns:
        # cat and at least one parameter column
        minNcols = 2
        checkNcols(mapName, tableDescription, minNcols)

    checkNrows(mapName, tableDescription)

    # TODO: check same (+-1) number of columns

    # loading data

    # TODO: make fun from this
    if useAllColumns:
        dbTable = grass.db_select(table=trainingMapName)
        trainingColumns = trainingTableDescription['cols']
    else:
        # assuming that columns concatenated by comma
        sql = 'SELECT %s,%s FROM %s' % (columnWithClass, columns, trainingMapName)
        dbTable = grass.db_select(sql=sql)
        # The rows follow the order of the SELECT, not of the table, so
        # the full table description must not be used to name the cells.
        trainingColumns = _selectedColumnsDescription(columnWithClass, columns)

    trainingParameters = fromDbTableToSimpleTable(dbTable,
                                                  columnsDescription=trainingColumns,
                                                  columnWithClass=columnWithClass)

    if useAllColumns:
        trainingClasses = extractColumnWithClass(dbTable,
                                                 columnsDescription=trainingColumns,
                                                 columnWithClass=columnWithClass)
    else:
        # the class column is the first one in the SELECT above
        trainingClasses = extractNthColumn(dbTable, 0)

    # TODO: hard coded 'cat'?
    if useAllColumns:
        dbTable = grass.db_select(table=mapName)
        selectedColumns = tableDescription['cols']
    else:
        # assuming that columns concatenated by comma
        sql = 'SELECT %s,%s FROM %s' % ('cat', columns, mapName)
        dbTable = grass.db_select(sql=sql)
        # same as for the training data: describe what was selected
        selectedColumns = _selectedColumnsDescription('cat', columns)

    parameters = fromDbTableToSimpleTable(dbTable,
                                          columnsDescription=selectedColumns,
                                          columnWithClass=columnWithClass)
    if useAllColumns:
        cats = extractColumnWithCats(dbTable, columnsDescription=selectedColumns)
    else:
        # the cat column is the first one in the SELECT above
        cats = extractNthColumn(dbTable, 0)

    # since dbTable can be big it is better to avoid to have it in memory twice
    del dbTable
    del trainingTableDescription

    classifier = Classifier()
    classifier.learn(trainingParameters, trainingClasses)
    classes = classifier.pred(parameters)

    # add column only if not exists and the classification was successful
    if not hasColumn(tableDescription, columnWithClass):
        addColumn(mapName, columnWithClass, 'int')

    updateColumn(mapName, columnWithClass, cats, classes)

    # TODO: output as a new map (use INSERT, can be faster)
    # TODO: output as a new layer?
337
# Run the module only when the file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
340