import contextlib
import csv
import locale
import pickle
import re
import sys
import warnings
from typing import List, Iterable

from functools import lru_cache
from importlib import import_module
from itertools import chain

from os import path, remove
from tempfile import NamedTemporaryFile
from urllib.parse import urlparse, urlsplit, urlunsplit, \
    unquote as urlunquote, quote
from urllib.request import urlopen, Request
from pathlib import Path

import numpy as np

import xlrd
import xlsxwriter
import openpyxl

from Orange.data import _io, Table, Domain, ContinuousVariable
from Orange.data import Compression, open_compressed, detect_encoding, \
    isnastr, guess_data_type, sanitize_variable
from Orange.data.io_base import FileFormatBase, Flags, DataTableMixin, PICKLE_PROTOCOL

from Orange.util import flatten


# Support values longer than the default 128 KB field size limit
# (e.g. features that hold long text)
csv.field_size_limit(100 * 1024 * 1024)

__all__ = ["Flags", "FileFormat"]


# Re-export the imported names so they remain accessible from this module
# (kept for backwards compatibility with code that imports them from here)
Compression = Compression
open_compressed = open_compressed
detect_encoding = detect_encoding
isnastr = isnastr
guess_data_type = guess_data_type
sanitize_variable = sanitize_variable
Flags = Flags
FileFormatMeta = type(FileFormatBase)


class FileFormat(FileFormatBase):
    """
    Subclasses set the following attributes and override the following methods:

        EXTENSIONS = ('.ext1', '.ext2', ...)
        DESCRIPTION = 'human-readable file format description'
        SUPPORT_COMPRESSED = False
        SUPPORT_SPARSE_DATA = False

        def read(self):
            ...  # load headers, data, ...
            return self.data_table(data, headers)

        @classmethod
        def write_file(cls, filename, data):
            ...
            cls.write_headers(writer.write, data)
            writer.writerows(data)

    The wrapper FileFormat.data_table() builds an Orange.data.Table from an
    iterable `data` of rows, where each row is a list of column values.
    """

    # Priority when multiple formats support the same extension. Also
    # the sort order in file open/save combo boxes. Lower is better.
    PRIORITY = 10000
    OPTIONAL_TYPE_ANNOTATIONS = False
    SUPPORT_COMPRESSED = False
    SUPPORT_SPARSE_DATA = False

    def __init__(self, filename):
        """
        Parameters
        ----------
        filename : str
            name of the file to open
        """
        self.filename = filename
        self.sheet = None

    @property
    def sheets(self) -> List:
        """FileFormats with a notion of sheets should override this property
        to return a list of sheet names in the file.

        Returns
        -------
        a list of sheet names
        """
        return []

    def select_sheet(self, sheet):
        """Select sheet to be read

        Parameters
        ----------
        sheet : str
            sheet name
        """
        self.sheet = sheet


def class_from_qualified_name(format_name):
    """Return a file format class, given its fully qualified name."""
    elements = format_name.split(".")
    m = import_module(".".join(elements[:-1]))
    return getattr(m, elements[-1])
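
# For example, to look up a reader class by its qualified name (assuming
# this module is importable as Orange.data.io):
#     reader_cls = class_from_qualified_name("Orange.data.io.CSVReader")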


class CSVReader(FileFormat, DataTableMixin):
    """Reader for comma-separated files"""

    EXTENSIONS = ('.csv',)
    DESCRIPTION = 'Comma-separated values'
    DELIMITERS = ',;:\t$ '
    SUPPORT_COMPRESSED = True
    SUPPORT_SPARSE_DATA = False
    PRIORITY = 20
    OPTIONAL_TYPE_ANNOTATIONS = True

    def read(self):
        for encoding in (lambda: ('us-ascii', None),                 # fast
                         lambda: (detect_encoding(self.filename), None),  # precise
                         lambda: (locale.getpreferredencoding(False), None),
                         lambda: (sys.getdefaultencoding(), None),   # desperate
                         lambda: ('utf-8', None),                    # ...
                         lambda: ('utf-8', 'ignore')):               # fallback
            encoding, errors = encoding()
            # Clear the stored error for all but the last attempt: the error
            # of the second-to-last attempt is kept and shown as a warning
            # in owfile
            if errors != 'ignore':
                error = ''
            with self.open(self.filename, mode='rt', newline='',
                           encoding=encoding, errors=errors) as file:
                # Sniff the CSV dialect (delimiter, quotes, ...)
                try:
                    dialect = csv.Sniffer().sniff(
                        # Take the first few *complete* lines as a sample
                        ''.join(file.readline() for _ in range(10)),
                        self.DELIMITERS)
                    delimiter = dialect.delimiter
                    quotechar = dialect.quotechar
                except UnicodeDecodeError as e:
                    error = e
                    continue
                except csv.Error:
                    delimiter = self.DELIMITERS[0]
                    quotechar = csv.excel.quotechar

                file.seek(0)
                try:
                    reader = csv.reader(
                        file, delimiter=delimiter, quotechar=quotechar,
                        skipinitialspace=True,
                    )
                    data = self.data_table(reader)

                    # TODO: Name can be set unconditionally when/if
                    # self.filename will always be a string with the file name.
                    # Currently, some tests pass StringIO instead of
                    # the file name to a reader.
                    if isinstance(self.filename, str):
                        data.name = path.splitext(
                            path.split(self.filename)[-1])[0]
                    if error and isinstance(error, UnicodeDecodeError):
                        pos, endpos = error.args[2], error.args[3]
                        warning = ('Skipped invalid byte(s) in position '
                                   '{}{}').format(
                                       pos,
                                       '-' + str(endpos) if endpos - pos > 1 else '')
                        warnings.warn(warning)
                    self.set_table_metadata(self.filename, data)
                    return data
                except Exception as e:
                    error = e
                    continue
        raise ValueError('Cannot parse dataset {}: {}'
                         .format(self.filename, error)) from error

    @classmethod
    def write_file(cls, filename, data, with_annotations=True):
        with cls.open(filename, mode='wt', newline='', encoding='utf-8') as file:
            writer = csv.writer(file, delimiter=cls.DELIMITERS[0])
            cls.write_headers(writer.writerow, data, with_annotations)
            cls.write_data(writer.writerow, data)
            cls.write_table_metadata(filename, data)
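
    # Round-trip sketch (hypothetical file name; `table` is any Orange Table):
    #     CSVReader.write_file("iris.csv", table)
    #     table = CSVReader("iris.csv").read()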


class TabReader(CSVReader):
    """Reader for tab-separated files"""
    EXTENSIONS = ('.tab', '.tsv')
    DESCRIPTION = 'Tab-separated values'
    DELIMITERS = '\t'
    PRIORITY = 10


class PickleReader(FileFormat):
    """Reader for pickled Table objects"""
    EXTENSIONS = ('.pkl', '.pickle')
    DESCRIPTION = 'Pickled Orange data'
    SUPPORT_COMPRESSED = True
    SUPPORT_SPARSE_DATA = True

    def read(self):
        with self.open(self.filename, 'rb') as f:
            table = pickle.load(f)
            if not isinstance(table, Table):
                raise TypeError("file does not contain a data table")
            return table

    @classmethod
    def write_file(cls, filename, data):
        with cls.open(filename, 'wb') as f:
            pickle.dump(data, f, protocol=PICKLE_PROTOCOL)
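
    # Round-trip sketch (hypothetical file name); a compressed name such as
    # "data.pkl.gz" should also work, since SUPPORT_COMPRESSED is set:
    #     PickleReader.write_file("data.pkl", table)
    #     table = PickleReader("data.pkl").read()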


class BasketReader(FileFormat):
    """Reader for basket (sparse) files"""
    EXTENSIONS = ('.basket', '.bsk')
    DESCRIPTION = 'Basket file'
    SUPPORT_SPARSE_DATA = True

    def read(self):
        def constr_vars(inds):
            if inds:
                return [ContinuousVariable(x.decode("utf-8")) for _, x in
                        sorted((ind, name) for name, ind in inds.items())]

        X, Y, metas, attr_indices, class_indices, meta_indices = \
            _io.sparse_read_float(self.filename.encode(sys.getdefaultencoding()))

        attrs = constr_vars(attr_indices)
        classes = constr_vars(class_indices)
        meta_attrs = constr_vars(meta_indices)
        domain = Domain(attrs, classes, meta_attrs)
        table = Table.from_numpy(
            domain, attrs and X, classes and Y, meta_attrs and metas)
        table.name = path.splitext(path.split(self.filename)[-1])[0]
        return table


class _BaseExcelReader(FileFormat, DataTableMixin):
    """Base class for reading Excel files"""
    SUPPORT_COMPRESSED = False
    SUPPORT_SPARSE_DATA = False

    def __init__(self, filename):
        super().__init__(filename=filename)
        self._workbook = None

    def get_cells(self) -> Iterable:
        raise NotImplementedError

    def read(self):
        try:
            cells = self.get_cells()
            table = self.data_table(cells)
            table.name = path.splitext(path.split(self.filename)[-1])[0]
            if self.sheet and len(self.sheets) > 1:
                table.name = '-'.join((table.name, self.sheet))
        except Exception as e:
            raise IOError("Couldn't load spreadsheet from "
                          + self.filename) from e
        return table


class ExcelReader(_BaseExcelReader):
    """Reader for .xlsx files"""
    EXTENSIONS = ('.xlsx',)
    DESCRIPTION = 'Microsoft Excel spreadsheet'
    ERRORS = ("#VALUE!", "#DIV/0!", "#REF!", "#NUM!", "#NULL!", "#NAME?")

    def __init__(self, filename):
        super().__init__(filename)
        self.sheet = self.workbook.active.title

    @property
    def workbook(self) -> openpyxl.Workbook:
        if not self._workbook:
            with warnings.catch_warnings():
                # We don't care about extensions, but we hate warnings
                warnings.filterwarnings(
                    "ignore",
                    ".*extension is not supported and will be removed.*",
                    UserWarning)
                self._workbook = openpyxl.load_workbook(self.filename,
                                                        data_only=True)
        return self._workbook

    @property
    @lru_cache(1)
    def sheets(self) -> List:
        return self.workbook.sheetnames if self.workbook else []

    def get_cells(self) -> Iterable:
        def str_(x):
            return str(x) if x is not None and x not in ExcelReader.ERRORS \
                else ""

        sheet = self._get_active_sheet()
        min_col = sheet.min_column
        max_col = sheet.max_column
        # iter_rows' bounds are inclusive, so max_row already covers the last row
        cells = ([str_(cell.value) for cell in row[min_col - 1: max_col]]
                 for row in sheet.iter_rows(sheet.min_row, sheet.max_row))
        return filter(any, cells)

    def _get_active_sheet(self) -> openpyxl.worksheet.worksheet.Worksheet:
        if self.sheet:
            return self.workbook[self.sheet]
        else:
            return self.workbook.active

    @classmethod
    def write_file(cls, filename, data):
        vars = list(chain((ContinuousVariable('_w'),) if data.has_weights() else (),
                          data.domain.attributes,
                          data.domain.class_vars,
                          data.domain.metas))
        formatters = [cls.formatter(v) for v in vars]
        zipped_list_data = zip(data.W if data.W.ndim > 1 else data.W[:, np.newaxis],
                               data.X,
                               data.Y if data.Y.ndim > 1 else data.Y[:, np.newaxis],
                               data.metas)
        headers = cls.header_names(data)
        workbook = xlsxwriter.Workbook(filename)
        sheet = workbook.add_worksheet()
        for c, header in enumerate(headers):
            sheet.write(0, c, header)
        for i, row in enumerate(zipped_list_data, 1):
            for j, (fmt, v) in enumerate(zip(formatters, flatten(row))):
                sheet.write(i, j, fmt(v))
        workbook.close()
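
    # Sheet-aware usage sketch (hypothetical file name):
    #     reader = ExcelReader("report.xlsx")
    #     reader.select_sheet(reader.sheets[-1])
    #     table = reader.read()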


class XlsReader(_BaseExcelReader):
    """Reader for .xls files"""
    EXTENSIONS = ('.xls',)
    DESCRIPTION = 'Microsoft Excel 97-2004 spreadsheet'

    def __init__(self, filename):
        super().__init__(filename)
        self.sheet = self.workbook.sheet_by_index(0).name

    @property
    def workbook(self) -> xlrd.Book:
        if not self._workbook:
            self._workbook = xlrd.open_workbook(self.filename, on_demand=True)
        return self._workbook

    @property
    @lru_cache(1)
    def sheets(self) -> List:
        return self.workbook.sheet_names() if self.workbook else []

    def get_cells(self) -> Iterable:
        def str_(cell):
            return "" if cell.ctype == xlrd.XL_CELL_ERROR else str(cell.value)

        sheet = self._get_active_sheet()
        first_row = next(i for i in range(sheet.nrows)
                         if any(sheet.row_values(i)))
        first_col = next(i for i in range(sheet.ncols)
                         if sheet.cell_value(first_row, i))
        row_len = sheet.row_len(first_row)
        return filter(any, ([str_(sheet.cell(row, col))
                             if col < sheet.row_len(row) else ''
                             for col in range(first_col, row_len)]
                            for row in range(first_row, sheet.nrows)))

    def _get_active_sheet(self) -> xlrd.sheet.Sheet:
        if self.sheet:
            return self.workbook.sheet_by_name(self.sheet)
        else:
            return self.workbook.sheet_by_index(0)


class DotReader(FileFormat):
    """Writer for dot (graph) files"""
    EXTENSIONS = ('.dot', '.gv')
    DESCRIPTION = 'Dot graph description'
    SUPPORT_COMPRESSED = True
    SUPPORT_SPARSE_DATA = False

    @classmethod
    def write_graph(cls, filename, graph):
        from sklearn import tree
        tree.export_graphviz(graph, out_file=cls.open(filename, 'wt'))

    @classmethod
    def write(cls, filename, tree):
        if isinstance(tree, dict):
            tree = tree['tree']
        cls.write_graph(filename, tree)
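
    # Usage sketch: `write` expects a fitted sklearn decision tree, or a dict
    # holding one under the 'tree' key; `clf` here is hypothetical:
    #     DotReader.write("tree.dot", clf)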


class UrlReader(FileFormat):
    def __init__(self, filename):
        filename = filename.strip()
        if not urlparse(filename).scheme:
            filename = 'http://' + filename
        filename = quote(filename, safe="/:")
        super().__init__(filename)

    @staticmethod
    def urlopen(url):
        req = Request(
            url,
            # Avoid 403 error with servers that dislike scrapers
            headers={'User-Agent': 'Mozilla/5.0 (X11; Linux) Gecko/20100101 Firefox/'})
        return urlopen(req, timeout=10)

    def read(self):
        self.filename = self._trim(self._resolve_redirects(self.filename))
        with contextlib.closing(self.urlopen(self.filename)) as response:
            name = self._suggest_filename(response.headers['content-disposition'])
            # using Path since splitext does not extract more extensions
            extension = ''.join(Path(name).suffixes)  # get only file extension
            with NamedTemporaryFile(suffix=extension, delete=False) as f:
                f.write(response.read())
                # delete=False is a workaround for https://bugs.python.org/issue14243

            reader = self.get_reader(f.name)
            data = reader.read()
            remove(f.name)
        # Override name set in from_file() to avoid holding the temp prefix
        data.name = path.splitext(name)[0]
        data.origin = self.filename
        return data

    def _resolve_redirects(self, url):
        # Resolve (potential) redirects to a final URL
        with contextlib.closing(self.urlopen(url)) as response:
            return response.url

    @classmethod
    def _trim(cls, url):
        URL_TRIMMERS = (
            cls._trim_googlesheet,
            cls._trim_dropbox,
        )
        for trim in URL_TRIMMERS:
            try:
                url = trim(url)
            except ValueError:
                continue
            else:
                break
        return url

    @staticmethod
    def _trim_googlesheet(url):
        match = re.match(r'(?:https?://)?(?:www\.)?'
                         r'docs\.google\.com/spreadsheets/d/'
                         r'(?P<workbook_id>[-\w_]+)'
                         r'(?:/.*?gid=(?P<sheet_id>\d+).*|.*)?',
                         url, re.IGNORECASE)
        try:
            workbook, sheet = match.group('workbook_id'), match.group('sheet_id')
            if not workbook:
                raise ValueError
        except (AttributeError, ValueError):
            raise ValueError
        url = 'https://docs.google.com/spreadsheets/d/{}/export?format=tsv'.format(workbook)
        if sheet:
            url += '&gid=' + sheet
        return url
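
    # For instance, a (hypothetical) sharing link such as
    #     https://docs.google.com/spreadsheets/d/SOME_ID/edit#gid=0
    # is rewritten into the TSV export URL
    #     https://docs.google.com/spreadsheets/d/SOME_ID/export?format=tsv&gid=0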

    @staticmethod
    def _trim_dropbox(url):
        parts = urlsplit(url)
        if not parts.netloc.endswith('dropbox.com'):
            raise ValueError
        return urlunsplit(parts._replace(query='dl=1'))
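
    # For instance, a (hypothetical) share link
    #     https://www.dropbox.com/s/SOME_ID/iris.csv?dl=0
    # becomes a direct-download link with its query replaced by dl=1:
    #     https://www.dropbox.com/s/SOME_ID/iris.csv?dl=1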

    def _suggest_filename(self, content_disposition):
        default_name = re.sub(r'[\\:/]', '_', urlparse(self.filename).path)

        # See https://tools.ietf.org/html/rfc6266#section-4.1
        matches = re.findall(r"filename\*?=(?:\"|.{0,10}?'[^']*')([^\"]+)",
                             content_disposition or '')
        return urlunquote(matches[-1]) if matches else default_name
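
# Usage sketch (hypothetical URL); 'http://' is prepended when the URL has
# no scheme, and the reader is then picked by file extension via get_reader():
#     table = UrlReader("example.com/datasets/iris.csv").read()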