1"""CSV importer.
2"""
3
4__copyright__ = "Copyright (C) 2016  Martin Blais"
5__license__ = "GNU GPLv2"
6
7# TODO(blais): Rename the beancount.ingest.importers.csv module and remove this.
8from beancount.utils import test_utils
9test_utils.remove_alt_csv_path()
10# pylint: disable=wrong-import-order
11import csv
12
13import collections
14import datetime
15import enum
16import io
17from inspect import signature
18from typing import Callable, Dict, Optional, Union
19
20import dateutil.parser
21
22from beancount.core import data
23from beancount.core.amount import Amount
24from beancount.core.number import ZERO, D
25from beancount.ingest.importers.mixins import filing, identifier
26from beancount.utils.date_utils import parse_date_liberally
27
28
class Col(enum.Enum):
    """The set of interpretable columns.

    An importer's `config` maps members of this enum to the header names or
    integer indexes of the CSV columns that hold the corresponding data.
    """

    # The settlement date, the date we should create the posting at.
    DATE = '[DATE]'

    # The date at which the transaction took place. The importer stores it in
    # the transaction's 'date' metadata rather than as the entry date.
    TXN_DATE = '[TXN_DATE]'

    # The time at which the transaction took place.
    # Beancount does not support a time field -- it is only added to metadata.
    TXN_TIME = '[TXN_TIME]'

    # The payee field.
    PAYEE = '[PAYEE]'

    # The narration fields. Use multiple fields to combine them together.
    # NARRATION has the same value as NARRATION1, so Enum turns it into an
    # alias of that member rather than a distinct member.
    NARRATION = NARRATION1 = '[NARRATION1]'
    NARRATION2 = '[NARRATION2]'
    NARRATION3 = '[NARRATION3]'

    # The amount being posted, in a single column. When configured, it takes
    # precedence over the debit/credit columns below.
    AMOUNT = '[AMOUNT]'

    # Debits and credits being posted in separate, dedicated columns.
    AMOUNT_DEBIT = '[DEBIT]'
    AMOUNT_CREDIT = '[CREDIT]'

    # The balance amount, after the row has posted.
    BALANCE = '[BALANCE]'

    # A field to use as a tag name.
    TAG = '[TAG]'

    # A field to use as a unique reference id or number (attached as a link).
    REFERENCE_ID = '[REF]'

    # A column which says DEBIT or CREDIT (generally ignored).
    DRCR = '[DRCR]'

    # Last 4 digits of the card; mapped through `last4_map` into 'card' metadata.
    LAST4 = '[LAST4]'

    # An account name. Not read by Importer.extract() itself.
    ACCOUNT = '[ACCOUNT]'

    # Categorization, if the institution supports it. You could, in theory,
    # specialize your importer to use this to automatically assign a good
    # expenses account. Not read by Importer.extract() itself.
    CATEGORY = '[CATEGORY]'
79
80
def get_amounts(iconfig, row, allow_zero_amounts, parse_amount):
    """Get the amount columns of a row.

    A row either has a single amount column (Col.AMOUNT) or separate
    debit/credit columns (Col.AMOUNT_DEBIT / Col.AMOUNT_CREDIT). When
    Col.AMOUNT is configured it takes precedence, and the debit slot of the
    result is always None.

    Args:
      iconfig: A dict of Col to row index.
      row: A row array containing the values of the given row.
      allow_zero_amounts: Is a transaction with amount D('0.00') okay? If not,
        return (None, None) when every non-empty amount cell parses to zero.
      parse_amount: A callable converting a non-empty cell string to a
        Decimal. It is never called on missing or empty cells.
    Returns:
      A pair of (debit-amount, credit-amount), both of which are either an
      instance of Decimal or None, or not available. The debit amount is
      negated so that both values are signed amounts to post.
    """
    if Col.AMOUNT in iconfig:
        debit_cell, credit_cell = None, row[iconfig[Col.AMOUNT]]
    else:
        debit_cell, credit_cell = (row[iconfig[col]] if col in iconfig else None
                                   for col in (Col.AMOUNT_DEBIT, Col.AMOUNT_CREDIT))

    # Missing columns and empty cells both mean "no amount on this side".
    debit = parse_amount(debit_cell) if debit_cell else None
    credit = parse_amount(credit_cell) if credit_cell else None

    # If zero amounts aren't allowed, return null value when every amount that
    # is present is zero. Only present amounts are considered, so this also
    # works when just one side exists (e.g. with Col.AMOUNT, where the debit is
    # always None).
    present = [number for number in (debit, credit) if number is not None]
    if (not allow_zero_amounts and present and
            all(number == ZERO for number in present)):
        return (None, None)

    return (-debit if debit is not None else None, credit)
108
109
class Importer(identifier.IdentifyMixin, filing.FilingMixin):
    """Importer for CSV files.

    Every data row becomes one transaction with a posting to the configured
    account. Data rows are those remaining after the skipped lines and the
    optional header, excluding blank rows and rows whose first cell starts
    with '#'. An optional categorizer may attach the other posting(s).
    """
    # pylint: disable=too-many-instance-attributes

    def __init__(self, config, account, currency,
                 regexps=None,
                 skip_lines: int = 0,
                 last4_map: Optional[Dict] = None,
                 categorizer: Optional[Callable] = None,
                 institution: Optional[str] = None,
                 debug: bool = False,
                 csv_dialect: Union[str, csv.Dialect] = 'excel',
                 dateutil_kwds: Optional[Dict] = None,
                 narration_sep: str = '; ',
                 encoding: Optional[str] = None,
                 invert_sign: Optional[bool] = False,
                 **kwds):
        """Constructor.

        Args:
          config: A dict of Col enum types to the names or indexes of the columns.
          account: An account string, the account to post this to.
          currency: A currency string, the currency of this account.
          regexps: A list of regular expression strings.
          skip_lines: Skip first x (garbage) lines of file.
          last4_map: A dict that maps last 4 digits of the card to a friendly string.
          categorizer: A callable with two arguments (transaction, row) that can attach
            the other posting (usually expenses) to a transaction with only single posting.
          institution: An optional name of an institution to rename the files to.
          debug: Whether or not to print debug information
          csv_dialect: A `csv` dialect given either as string or as instance or
            subclass of `csv.Dialect`.
          dateutil_kwds: An optional dict defining the dateutil parser kwargs.
          narration_sep: A string, a separator to use for splitting up the payee and
            narration fields of a source field.
          encoding: An optional encoding for the file. Typically useful for files
            encoded in 'latin1' instead of 'utf-8' (the default).
          invert_sign: If true, invert the amount's sign unconditionally.
          **kwds: Extra keyword arguments to provide to the base mixins. A
            'matchers' list given here is copied, never modified in place.
        """
        assert isinstance(config, dict), "Invalid type: {}".format(config)
        self.config = config

        self.currency = currency
        assert isinstance(skip_lines, int)
        self.skip_lines = skip_lines
        self.last4_map = last4_map or {}
        self.debug = debug
        self.dateutil_kwds = dateutil_kwds
        self.csv_dialect = csv_dialect
        self.narration_sep = narration_sep
        self.encoding = encoding
        self.invert_sign = invert_sign

        self.categorizer = categorizer

        # Prepare kwds for filing mixin.
        kwds['filing'] = account
        if institution:
            prefix = kwds.get('prefix', None)
            assert prefix is None
            kwds['prefix'] = institution

        # Prepare kwds for identifier mixin. Work on a copy of any
        # caller-supplied matchers. Appending to the caller's list would leak
        # matchers into other importers that share it.
        if isinstance(regexps, str):
            regexps = [regexps]
        matchers = list(kwds.get('matchers') or [])
        matchers.append(('mime', 'text/csv'))
        matchers.extend(('content', regexp) for regexp in regexps or [])
        kwds['matchers'] = matchers

        super().__init__(**kwds)

    def _iter_rows(self, file, has_header):
        """Yield (index, row) pairs for the data rows of `file`.

        The first `self.skip_lines` records and, if `has_header`, the header
        record are consumed first; a file that is too short simply yields
        nothing. Rows are numbered from 1 after that point. Blank rows and
        rows whose first cell starts with '#' are counted but not yielded.
        The file is closed when iteration ends.
        """
        with open(file.name, encoding=self.encoding) as infile:
            reader = csv.reader(infile, dialect=self.csv_dialect)
            for _ in range(self.skip_lines + int(has_header)):
                next(reader, None)
            for index, row in enumerate(reader, 1):
                if not row or row[0].startswith('#'):
                    continue
                yield index, row

    def file_date(self, file):
        """Get the maximum date from the file.

        Returns:
          The latest Col.DATE value among the data rows, or None if the
          configuration has no Col.DATE column or there are no data rows.
        """
        iconfig, has_header = normalize_config(
            self.config,
            file.head(encoding=self.encoding),
            self.csv_dialect,
            self.skip_lines,
        )
        if Col.DATE not in iconfig:
            return None
        date_index = iconfig[Col.DATE]
        return max((parse_date_liberally(row[date_index], self.dateutil_kwds)
                    for _, row in self._iter_rows(file, has_header)),
                   default=None)

    def extract(self, file, existing_entries=None):
        """Extract transactions, and possibly a closing balance, from a CSV file.

        Args:
          file: The file to import; its `name` and `head()` are used.
          existing_entries: Unused here.
        Returns:
          A list of Transaction entries. The list is in file order, reversed
          when the first data row is not dated strictly before the last one.
          It may be followed by a Balance entry dated one day after the last
          transaction, if Col.BALANCE is configured and that transaction's
          row had a non-empty balance cell.
        """
        account = self.file_account(file)
        entries = []

        # Normalize the configuration to fetch by index.
        iconfig, has_header = normalize_config(
            self.config,
            file.head(encoding=self.encoding),
            self.csv_dialect,
            self.skip_lines,
        )

        def get(row, ftype):
            """Return the cell of `row` for column type `ftype`, or None."""
            try:
                return row[iconfig[ftype]] if ftype in iconfig else None
            except IndexError:  # FIXME: this should not happen
                return None

        # Parse all the transactions.
        first_row = last_row = None
        index = 0
        for index, row in self._iter_rows(file, has_header):
            # If debugging, print out the rows.
            if self.debug:
                print(row)

            if first_row is None:
                first_row = row
            last_row = row

            # Extract the data we need from the row, based on the configuration.
            date = get(row, Col.DATE)
            txn_date = get(row, Col.TXN_DATE)
            txn_time = get(row, Col.TXN_TIME)

            payee = get(row, Col.PAYEE)
            if payee:
                payee = payee.strip()

            fields = filter(None, [get(row, field)
                                   for field in (Col.NARRATION1,
                                                 Col.NARRATION2,
                                                 Col.NARRATION3)])
            narration = self.narration_sep.join(
                field.strip() for field in fields).replace('\n', '; ')

            tag = get(row, Col.TAG)
            tags = {tag} if tag else data.EMPTY_SET

            link = get(row, Col.REFERENCE_ID)
            links = {link} if link else data.EMPTY_SET

            last4 = get(row, Col.LAST4)

            balance = get(row, Col.BALANCE)

            # Create a transaction
            meta = data.new_metadata(file.name, index)
            if txn_date is not None:
                meta['date'] = parse_date_liberally(txn_date, self.dateutil_kwds)
            if txn_time is not None:
                meta['time'] = str(dateutil.parser.parse(txn_time).time())
            # An empty balance cell means "unknown", not zero; parsing it would
            # yield a bogus zero balance assertion.
            if balance:
                meta['balance'] = self.parse_amount(balance)
            if last4:
                meta['card'] = self.last4_map.get(last4.strip()) or last4
            date = parse_date_liberally(date, self.dateutil_kwds)
            txn = data.Transaction(meta, date, self.FLAG, payee, narration,
                                   tags, links, [])

            # Attach one posting to the transaction
            amount_debit, amount_credit = self.get_amounts(iconfig, row,
                                                           False, self.parse_amount)

            # Skip empty transactions
            if amount_debit is None and amount_credit is None:
                continue

            for number in (amount_debit, amount_credit):
                if number is None:
                    continue
                if self.invert_sign:
                    number = -number
                units = Amount(number, self.currency)
                txn.postings.append(
                    data.Posting(account, units, None, None, None, None))

            # Attach the other posting(s) to the transaction.
            txn = self.call_categorizer(txn, row)

            # Add the transaction to the output list
            entries.append(txn)

        # Nothing to order or balance-check (e.g. a file without data rows).
        # Returning early also avoids reading dates from a missing first row.
        if not entries:
            return entries

        # Figure out if the file is in ascending or descending order.
        first_date = parse_date_liberally(get(first_row, Col.DATE),
                                          self.dateutil_kwds)
        last_date = parse_date_liberally(get(last_row, Col.DATE),
                                         self.dateutil_kwds)
        is_ascending = first_date < last_date

        # Reverse the list if the file is in descending order
        if not is_ascending:
            entries.reverse()

        # Add a balance entry if possible, dated the day after the last
        # transaction, using the balance recorded on that transaction's row.
        if Col.BALANCE in iconfig:
            entry = entries[-1]
            balance = entry.meta.get('balance', None)
            if balance is not None:
                date = entry.date + datetime.timedelta(days=1)
                meta = data.new_metadata(file.name, index)
                entries.append(
                    data.Balance(meta, date, account, Amount(balance, self.currency),
                                 None, None))

        # Remove the 'balance' metadata.
        for entry in entries:
            entry.meta.pop('balance', None)

        return entries

    def call_categorizer(self, txn, row):
        """Pass `txn` (and `row`, if the categorizer accepts it) to the categorizer.

        Returns the categorizer's result, or `txn` unchanged when no callable
        categorizer was configured.
        """
        if not callable(self.categorizer):
            return txn

        # TODO(blais): Remove introspection here, just commit to the two
        # parameter version.
        params = signature(self.categorizer).parameters
        if len(params) < 2:
            return self.categorizer(txn)
        return self.categorizer(txn, row)

    def parse_amount(self, string):
        """The method used to create Decimal instances. You can override this."""
        return D(string)

    def get_amounts(self, iconfig, row, allow_zero_amounts, parse_amount):
        """See function get_amounts() for details.

        This method is present to allow clients to override it in order to deal
        with special cases, e.g., columns with currency symbols in them.
        """
        return get_amounts(iconfig, row, allow_zero_amounts, parse_amount)
373
374
def normalize_config(config, head, dialect='excel', skip_lines: int = 0):
    """Using the header line, convert the configuration field name lookups to int indexes.

    Args:
      config: A dict of Col types to string or indexes. String names are
        matched against the header after stripping surrounding whitespace.
      head: A string, some decent number of bytes of the head of the file.
      dialect: A dialect definition to parse the header
      skip_lines: Skip first x (garbage) lines of file.
    Returns:
      A pair of
        A new dict of Col types to integer indexes of the fields, and
        a boolean, true if the file has a header.
    Raises:
      ValueError: If there is no header and the configuration does not consist
        entirely of integer indexes.
      KeyError: If there is a header and a configured column name is not in it.
    """
    # Skip garbage lines before sniffing the header. If `head` does not contain
    # more than `skip_lines` lines, nothing is left to sniff.
    assert isinstance(skip_lines, int)
    assert skip_lines >= 0
    if skip_lines:
        pieces = head.split('\n', skip_lines)
        head = pieces[skip_lines] if len(pieces) > skip_lines else ''

    has_header = csv.Sniffer().has_header(head)
    if has_header:
        header = next(csv.reader(io.StringIO(head), dialect=dialect))
        field_map = {field_name.strip(): index
                     for index, field_name in enumerate(header)}
        index_config = {}
        for field_type, field in config.items():
            if isinstance(field, str):
                try:
                    field = field_map[field.strip()]
                except KeyError:
                    raise KeyError("CSV header has no column named {!r}; found: {}"
                                   .format(field, header)) from None
            index_config[field_type] = field
    else:
        if not all(isinstance(field, int) for field in config.values()):
            raise ValueError("CSV config without header has non-index fields: "
                             "{}".format(config))
        index_config = dict(config)
    return index_config, has_header
414