1"""Interpreter for the query language's AST.
2
3This code accepts the abstract syntax tree produced by the query parser,
resolves the column and function names, compiles an interpreter and prepares a
5query to be run against a list of entries.
6"""
7__copyright__ = "Copyright (C) 2014-2016  Martin Blais"
8__license__ = "GNU GPLv2"
9
10import collections
11import copy
12import datetime
13import re
14import operator
15from decimal import Decimal
16
17from beancount.core import inventory
18from beancount.query import query_parser
19
20
# A global constant which sets whether we support inferred/implicit group-by
# semantics: when enabled, non-aggregate targets of an aggregate query are
# automatically used as GROUP-BY keys (see compile_group_by()).
SUPPORT_IMPLICIT_GROUPBY = True
24
25
class CompilationError(Exception):
    """Error raised while compiling or interpreting a query."""
28
29
class EvalNode:
    """Base class for all nodes of the compiled expression tree.

    Each node carries the data type ('dtype') of the value it produces;
    subclasses are expected to deduce it at construction time.
    """
    __slots__ = ('dtype',)

    def __init__(self, dtype):
        # The output data type produced by this node. This is intended to be
        # inferred by the nodes on construction.
        assert dtype is not None, "Internal error: Invalid dtype, must be deduced."
        self.dtype = dtype

    def __eq__(self, other):
        """Override the equality operator to compare the type and attributes
        of this node. This is used by tests for comparing nodes.

        Note: only the attributes listed in the instance class's own
        __slots__ are compared; inherited slots (including 'dtype' itself,
        for subclasses) are not part of the comparison.
        """
        return (isinstance(other, type(self))
                and all(
                    getattr(self, attribute) == getattr(other, attribute)
                    for attribute in self.__slots__))

    def __str__(self):
        return "{}({})".format(type(self).__name__,
                               ', '.join(repr(getattr(self, child))
                                         for child in self.__slots__))
    __repr__ = __str__

    def childnodes(self):
        """Returns the child nodes of this node.
        Yields:
          A list of EvalNode instances.
        """
        for attr in self.__slots__:
            child = getattr(self, attr)
            if isinstance(child, EvalNode):
                yield child
            elif isinstance(child, list):
                for element in child:
                    if isinstance(element, EvalNode):
                        yield element

    def __call__(self, context):
        """Evaluate this node. This is designed to recurse on its children.
        All subclasses must override and implement this method.

        Args:
          context: The evaluation object to which the evaluation need to apply.
            This is either an entry, a Posting instance, or a particular result
            set row from a sub-select. This is the provider for the underlying
            data.
        Returns:
          The evaluated value for this sub-expression tree.
        """
        raise NotImplementedError
81
82
class EvalConstant(EvalNode):
    """A leaf node evaluating to a literal value from the query text."""
    __slots__ = ('value',)

    def __init__(self, value):
        # The output type is simply the type of the literal itself.
        super().__init__(type(value))
        self.value = value

    def __call__(self, _):
        return self.value
92
93
class EvalUnaryOp(EvalNode):
    """Generic unary operator: applies 'operator' to its single operand."""
    __slots__ = ('operand', 'operator')

    def __init__(self, operator, operand, dtype):
        super().__init__(dtype)
        self.operator = operator
        self.operand = operand

    def __call__(self, context):
        return self.operator(self.operand(context))
104
class EvalNot(EvalUnaryOp):
    """Logical NOT of a boolean sub-expression."""

    def __init__(self, operand):
        super().__init__(operator.not_, operand, bool)
109
110
class EvalBinaryOp(EvalNode):
    """Generic binary operator: applies 'operator' to both evaluated operands."""
    __slots__ = ('left', 'right', 'operator')

    def __init__(self, operator, left, right, dtype):
        super().__init__(dtype)
        self.operator = operator
        self.left = left
        self.right = right

    def __call__(self, context):
        return self.operator(self.left(context), self.right(context))
122
class EvalEqual(EvalBinaryOp):
    """Equality comparison between two sub-expressions."""

    def __init__(self, left, right):
        super().__init__(operator.eq, left, right, bool)
127
class EvalAnd(EvalBinaryOp):
    """Logical AND of two boolean sub-expressions.

    Note: operator.and_ is Python's '&', so both operands are always
    evaluated (no short-circuiting); they are expected to be bool.
    """

    def __init__(self, left, right):
        super().__init__(operator.and_, left, right, bool)
132
class EvalOr(EvalBinaryOp):
    """Logical OR of two boolean sub-expressions.

    Note: operator.or_ is Python's '|', so both operands are always
    evaluated (no short-circuiting); they are expected to be bool.
    """

    def __init__(self, left, right):
        super().__init__(operator.or_, left, right, bool)
137
class EvalGreater(EvalBinaryOp):
    """Strict greater-than comparison."""

    def __init__(self, left, right):
        super().__init__(operator.gt, left, right, bool)
142
class EvalGreaterEq(EvalBinaryOp):
    """Greater-than-or-equal comparison."""

    def __init__(self, left, right):
        super().__init__(operator.ge, left, right, bool)
147
class EvalLess(EvalBinaryOp):
    """Strict less-than comparison."""

    def __init__(self, left, right):
        super().__init__(operator.lt, left, right, bool)
152
class EvalLessEq(EvalBinaryOp):
    """Less-than-or-equal comparison."""

    def __init__(self, left, right):
        super().__init__(operator.le, left, right, bool)
157
class EvalMatch(EvalBinaryOp):
    """Case-insensitive regular expression search (the '~' operator)."""

    @staticmethod
    def match(left, right):
        # A missing operand on either side never matches.
        if left is None or right is None:
            return False
        return bool(re.search(right, left, re.IGNORECASE))

    def __init__(self, left, right):
        super().__init__(self.match, left, right, bool)
        # Only a string pattern makes sense on the right-hand side.
        if right.dtype != str:
            raise CompilationError(
                "Invalid data type for RHS of match: '{}'; must be a string".format(
                    right.dtype))
172
class EvalContains(EvalBinaryOp):
    """Membership test: true if the left value occurs in the right container."""

    def __init__(self, left, right):
        super().__init__(operator.contains, left, right, bool)

    def __call__(self, context):
        # operator.contains(container, item) takes the container first, so
        # the evaluated operands must be swapped.
        item = self.left(context)
        container = self.right(context)
        return self.operator(container, item)
183
184
185# Note: We ought to implement implicit type promotion here,
186# e.g., int -> float -> Decimal.
187
188# Note(2): This does not support multiplication on Amount, Position, Inventory.
189# We need to rewrite the evaluator to support types in order to do this
190# properly.
191
class EvalMul(EvalBinaryOp):
    """Multiplication, with the result coerced to Decimal."""

    def __init__(self, left, right):
        super().__init__(lambda x, y: Decimal(x * y), left, right, Decimal)
197
class EvalDiv(EvalBinaryOp):
    """Division, with the result coerced to Decimal.

    Note: a zero divisor propagates ZeroDivisionError to the caller.
    """

    def __init__(self, left, right):
        super().__init__(lambda x, y: Decimal(x / y), left, right, Decimal)
203
class EvalAdd(EvalBinaryOp):
    """Addition, with the result coerced to Decimal."""

    def __init__(self, left, right):
        super().__init__(lambda x, y: Decimal(x + y), left, right, Decimal)
209
class EvalSub(EvalBinaryOp):
    """Subtraction, with the result coerced to Decimal."""

    def __init__(self, left, right):
        super().__init__(lambda x, y: Decimal(x - y), left, right, Decimal)
215
216
# Interpreter nodes.
#
# Maps each parser AST node type to the evaluator class that implements it;
# compile_expression() uses this table to bind unary and binary operators.
OPERATORS = {
    query_parser.Constant: EvalConstant,
    query_parser.Not: EvalNot,
    query_parser.Equal: EvalEqual,
    query_parser.Match: EvalMatch,
    query_parser.And: EvalAnd,
    query_parser.Or: EvalOr,
    query_parser.Greater: EvalGreater,
    query_parser.GreaterEq: EvalGreaterEq,
    query_parser.Less: EvalLess,
    query_parser.LessEq: EvalLessEq,
    query_parser.Contains: EvalContains,
    query_parser.Mul: EvalMul,
    query_parser.Div: EvalDiv,
    query_parser.Add: EvalAdd,
    query_parser.Sub: EvalSub,
    }



# Unique sentinel object. NOTE(review): not referenced in this portion of the
# file — presumably used as a wildcard in function '__intypes__' declarations
# elsewhere; confirm at the definition sites.
ANY = object()
239
class EvalFunction(EvalNode):
    """Base class for all function objects."""
    __slots__ = ('operands',)

    # Type constraints on the input arguments.
    __intypes__ = []

    def __init__(self, operands, dtype):
        super().__init__(dtype)
        assert isinstance(operands, list), "Internal error: invalid type for operands."
        self.operands = operands

        # Check the argument count against the declared constraints.
        expected = len(self.__intypes__)
        if len(operands) != expected:
            raise CompilationError(
                "Invalid number of arguments for {}: found {} expected {}".format(
                    type(self).__name__, len(operands), expected))

        # Check each of the types.
        for index, (operand, intype) in enumerate(zip(operands, self.__intypes__)):
            if not issubclass(operand.dtype, intype):
                raise CompilationError(
                    "Invalid type for argument {} of {}: found {} expected {}".format(
                        index, type(self).__name__, operand.dtype, intype))

    def eval_args(self, context):
        """Evaluate all the operands against the given context."""
        return [operand(context) for operand in self.operands]
268
269
class EvalColumn(EvalNode):
    """Base class for all column accessors."""
272
class EvalAggregator(EvalFunction):
    """Base class for all aggregator evaluator types."""

    # There is no need to recurse below an aggregator node: the compiler
    # forbids aggregations nested under other aggregations.

    def allocate(self, allocator):
        """Reserve storage handles for this aggregate's working data.

        Called once before aggregation begins. Override this if the
        computation phase requires per-aggregate storage slots.

        Args:
          allocator: An instance of Allocator; call allocate() on it to
            obtain a handle used to index into store objects later on.
        """
        # Default: no storage required.

    def initialize(self, store):
        """Reset this aggregate's working data before accumulation starts.

        Override this in the aggregator if you need data for storage.

        Args:
          store: An object indexable by handles obtained from allocate().
        """
        # Default: nothing to initialize.

    def update(self, store, context):
        """Fold one row into this aggregate's working data.

        Args:
          store: An object indexable by handles obtained from allocate().
          context: The row object being accumulated (see __call__).
        """
        # Default: nothing to accumulate.

    def finalize(self, store):
        """Finalize this node's aggregate data and return it.

        For aggregate methods, this finalizes the node and returns the final
        value.

        Args:
          store: An object indexable by handles obtained from allocate().
        """
        # Default: no final value.

    def __call__(self, context):
        """Return the value on evaluation.

        Args:
          context: The evaluation object to which the evaluation need to apply.
            This is either an entry, a Posting instance, or a particular result
            set row from a sub-select.
        Returns:
          The final aggregated value.
        """
        # Default: None.
334
335
class CompilationEnvironment:
    """Base class for all compilation contexts. A compilation context provides
    column accessors specific to the particular row objects that we will access.
    """
    # The name of the context, used to build error messages.
    context_name = None

    # Maps of names to evaluators for columns and functions.
    columns = None
    functions = None

    def get_column(self, name):
        """Return a column accessor for the given named column.
        Args:
          name: A string, the name of the column to access.
        """
        try:
            return self.columns[name]()
        except KeyError as exc:
            message = "Invalid column name '{}' in {} context.".format(
                name, self.context_name)
            raise CompilationError(message) from exc

    def get_function(self, name, operands):
        """Return a function accessor for the given named function.
        Args:
          name: A string, the name of the function to access.
          operands: A list of compiled operand nodes, whose dtypes select
            among overloads.
        """
        # First attempt an exact lookup keyed on the name plus operand types.
        try:
            key = tuple([name] + [operand.dtype for operand in operands])
            return self.functions[key](operands)
        except KeyError:
            # Fall back to a lookup by name alone.
            try:
                return self.functions[name](operands)
            except KeyError as exc:
                arg_types = ', '.join(operand.dtype.__name__
                                      for operand in operands)
                signature = '{}({})'.format(name, arg_types)
                message = "Invalid function '{}' in {} context".format(
                    signature, self.context_name)
                raise CompilationError(message) from exc
376
377
class AttributeColumn(EvalColumn):
    # Column accessor that reads a named attribute off a result-set row.
    #
    # NOTE(review): constructed as AttributeColumn(name) in
    # ResultSetEnvironment.get_column(), which routes 'name' into EvalNode's
    # 'dtype' slot; 'self.name' itself is never assigned here (there is no
    # __init__ and no 'name' slot in the visible bases), so this lookup
    # appears to raise AttributeError — confirm intended behavior.
    def __call__(self, row):
        return getattr(row, self.name)
381
class ResultSetEnvironment(CompilationEnvironment):
    """Compilation context over the rows of a sub-select's result set."""

    context_name = 'sub-query'

    def get_column(self, name):
        """Return an accessor that fetches attribute 'name' from a result row."""
        # FIXME: How do we figure out the data type here? We need the context.
        return AttributeColumn(name)
392
393
def compile_expression(expr, environ):
    """Bind an expression to its execution context.

    Args:
      expr: The root node of an expression.
      environ: An CompilationEnvironment instance.
    Returns:
      The root node of a bound expression.
    Raises:
      AssertionError: If the expression node type is unknown.
    """
    # Convert column references to the context.
    if isinstance(expr, query_parser.Column):
        c_expr = environ.get_column(expr.name)

    elif isinstance(expr, query_parser.Function):
        c_operands = [compile_expression(operand, environ)
                      for operand in expr.operands]
        c_expr = environ.get_function(expr.fname, c_operands)

    elif isinstance(expr, query_parser.UnaryOp):
        node_type = OPERATORS[type(expr)]
        c_expr = node_type(compile_expression(expr.operand, environ))

    elif isinstance(expr, query_parser.BinaryOp):
        node_type = OPERATORS[type(expr)]
        c_expr = node_type(compile_expression(expr.left, environ),
                           compile_expression(expr.right, environ))

    elif isinstance(expr, query_parser.Constant):
        c_expr = EvalConstant(expr.value)

    else:
        # Use an explicit raise rather than 'assert False': assertions are
        # stripped under 'python -O', which would otherwise fall through and
        # raise a confusing NameError on 'c_expr' below.
        raise AssertionError("Invalid expression to compile: {}".format(expr))

    return c_expr
428
429
def get_columns_and_aggregates(node):
    """Find the columns and aggregate nodes below this tree.

    All nodes under aggregate nodes are ignored.

    Args:
      node: An instance of EvalNode.
    Returns:
      A pair of (columns, aggregates), both of which are lists of EvalNode instances.
        columns: The list of all columns accessed not under an aggregate node.
        aggregates: The list of all aggregate nodes.
    """
    found_columns = []
    found_aggregates = []
    _get_columns_and_aggregates(node, found_columns, found_aggregates)
    return found_columns, found_aggregates
446
def _get_columns_and_aggregates(node, columns, aggregates):
    """Walk down a tree of nodes and fetch the column accessors and aggregates.

    This function ignores all nodes under aggregate nodes.

    Args:
      node: An instance of EvalNode.
      columns: An accumulator for columns found so far.
      aggregates: An accumulator for aggregate nodes found so far.
    """
    if isinstance(node, EvalAggregator):
        aggregates.append(node)
        return
    if isinstance(node, EvalColumn):
        columns.append(node)
        return
    for child in node.childnodes():
        _get_columns_and_aggregates(child, columns, aggregates)
464
465
def is_aggregate(node):
    """Return true if the node is an aggregate.

    Args:
      node: An instance of EvalNode.
    Returns:
      A boolean.
    """
    # A full traversal is slightly wasteful, but the performance of query
    # compilation matters very little overall.
    return bool(get_columns_and_aggregates(node)[1])
478
479
def is_hashable_type(node):
    """Return true if the node is of a hashable type.

    Args:
      node: An instance of EvalNode.
    Returns:
      A boolean.
    """
    # Only Inventory-typed nodes are treated as non-hashable here.
    if issubclass(node.dtype, inventory.Inventory):
        return False
    return True
489
490
def find_unique_name(name, allocated_set):
    """Come up with a unique name for 'name' amongst 'allocated_set'.

    Args:
      name: A string, the prefix of the name to find a unique for.
      allocated_set: A set of string, the set of already allocated names.
    Returns:
      A unique name. 'allocated_set' is unmodified.
    """
    # Append an increasing numeric suffix until the candidate is free.
    candidate = name
    counter = 1
    while candidate in allocated_set:
        candidate = '{}_{}'.format(name, counter)
        counter += 1
    return candidate
507
508
# A compiled target.
#
# Attributes:
#   c_expr: A compiled expression tree (an EvalNode root node).
#   name: The name of the target. If None, this is an invisible
#     target that gets evaluated but not displayed.
#   is_aggregate: A boolean, true if 'c_expr' is an aggregate; precomputed
#     at compilation time (see compile_targets()).
EvalTarget = collections.namedtuple('EvalTarget', 'c_expr name is_aggregate')
517
def compile_targets(targets, environ):
    """Compile the targets and check for their validity. Process wildcard.

    Args:
      targets: A list of target expressions from the parser.
      environ: A compilation context for the targets.
    Returns:
      A list of compiled target expressions with resolved names.
    Raises:
      CompilationError: If a target mixes aggregates and non-aggregates, or
        nests an aggregate under another aggregate.
    """
    # Bind the targets expressions to the execution context.
    if isinstance(targets, query_parser.Wildcard):
        # Insert the full list of available columns.
        targets = [query_parser.Target(query_parser.Column(name), None)
                   for name in environ.wildcard_columns]

    # Compile targets.
    c_targets = []
    target_names = set()
    for target in targets:
        c_expr = compile_expression(target.expression, environ)
        # Derive a unique display name for the target column.
        target_name = find_unique_name(
            target.name or query_parser.get_expression_name(target.expression),
            target_names)
        target_names.add(target_name)
        c_targets.append(EvalTarget(c_expr, target_name, is_aggregate(c_expr)))

    # Figure out if this query is an aggregate query and check validity of each
    # target's aggregation type.
    for c_target in c_targets:
        columns, aggregates = get_columns_and_aggregates(c_target.c_expr)

        # Check for mixed aggregates and non-aggregates.
        if columns and aggregates:
            raise CompilationError(
                "Mixed aggregates and non-aggregates are not allowed")

        if aggregates:
            # Check for aggregates of aggregates.
            for aggregate in aggregates:
                for child in aggregate.childnodes():
                    if is_aggregate(child):
                        raise CompilationError(
                            "Aggregates of aggregates are not allowed")

    return c_targets
563
564
def compile_group_by(group_by, c_targets, environ):
    """Process a group-by clause.

    Args:
      group_by: A GroupBy instance as provided by the parser.
      c_targets: A list of compiled target expressions.
      environ: A compilation context to be used to evaluate GROUP BY expressions.
    Returns:
      A tuple of
       new_targets: A list of new compiled target nodes.
       group_indexes: If the query is an aggregate query, a list of integer
         indexes to be used for processing grouping. Note that this list may be
         empty (in the case of targets with only aggregates). On the other hand,
         if this is not an aggregated query, this is set to None. So do
         distinguish the empty list vs. None.
    Raises:
      CompilationError: On an unsupported HAVING clause, an invalid or
        aggregate GROUP-BY column, or a group-by value of a non-hashable type.
    """
    # Work on a copy so callers' target list is not mutated.
    new_targets = copy.copy(c_targets)
    c_target_expressions = [c_target.c_expr for c_target in c_targets]

    group_indexes = []
    if group_by:
        # Check that HAVING is not supported yet.
        if group_by and group_by.having is not None:
            raise CompilationError("The HAVING clause is not supported yet")

        assert group_by.columns, "Internal error with GROUP-BY parsing"

        # Compile group-by expressions and resolve them to their targets if
        # possible. A GROUP-BY column may be one of the following:
        #
        # * A reference to a target by name.
        # * A reference to a target by index (starting at one).
        # * A new, non-aggregate expression.
        #
        # References by name are converted to indexes. New expressions are
        # inserted into the list of targets as invisible targets.
        targets_name_map = {target.name: index
                            for index, target in enumerate(c_targets)}
        for column in group_by.columns:
            index = None

            # Process target references by index.
            if isinstance(column, int):
                # GROUP-BY indexes are 1-based; convert to 0-based.
                index = column - 1
                if not (0 <= index < len(c_targets)):
                    raise CompilationError(
                        "Invalid GROUP-BY column index {}".format(column))

            else:
                # Process target references by name. These will be parsed as
                # simple Column expressions. If they refer to a target name, we
                # resolve them.
                if isinstance(column, query_parser.Column):
                    name = column.name
                    index = targets_name_map.get(name, None)

                # Otherwise we compile the expression and add it to the list of
                # targets to evaluate and index into that new target.
                if index is None:
                    c_expr = compile_expression(column, environ)

                    # Check if the new expression is an aggregate.
                    aggregate = is_aggregate(c_expr)
                    if aggregate:
                        raise CompilationError(
                            "GROUP-BY expressions may not be aggregates: '{}'".format(
                                column))

                    # Attempt to reconcile the expression with one of the existing
                    # target expressions.
                    try:
                        index = c_target_expressions.index(c_expr)
                    except ValueError:
                        # Add the new target. 'None' for the target name implies it
                        # should be invisible, not to be rendered.
                        index = len(new_targets)
                        new_targets.append(EvalTarget(c_expr, None, aggregate))
                        c_target_expressions.append(c_expr)

            assert index is not None, "Internal error, could not index group-by reference."
            group_indexes.append(index)

            # Check that the group-by column references a non-aggregate.
            c_expr = new_targets[index].c_expr
            if is_aggregate(c_expr):
                raise CompilationError(
                    "GROUP-BY expressions may not reference aggregates: '{}'".format(
                        column))

            # Check that the group-by column has a supported hashable type.
            if not is_hashable_type(c_expr):
                raise CompilationError(
                    "GROUP-BY a non-hashable type is not supported: '{}'".format(
                        column))


    else:
        # If it does not have a GROUP-BY clause...
        aggregate_bools = [c_target.is_aggregate for c_target in c_targets]
        if any(aggregate_bools):
            # If the query is an aggregate query, check that all the targets are
            # aggregates.
            if all(aggregate_bools):
                assert group_indexes == []
            else:
                # If some of the targets aren't aggregates, automatically infer
                # that they are to be implicit group by targets. This makes for
                # a much more convenient syntax for our lightweight SQL, where
                # grouping is optional.
                if SUPPORT_IMPLICIT_GROUPBY:
                    group_indexes = [index
                                     for index, c_target in enumerate(c_targets)
                                     if not c_target.is_aggregate]
                else:
                    raise CompilationError(
                        "Aggregate query without a GROUP-BY should have only aggregates")
        else:
            # This is not an aggregate query; don't set group_indexes to
            # anything useful, we won't need it.
            group_indexes = None

    # Return only the targets appended beyond the original list.
    return new_targets[len(c_targets):], group_indexes
687
688
def compile_order_by(order_by, c_targets, environ):
    """Process an order-by clause.

    Args:
      order_by: A OrderBy instance as provided by the parser.
      c_targets: A list of compiled target expressions.
      environ: A compilation context to be used to evaluate ORDER BY expressions.
    Returns:
      A tuple of
       new_targets: A list of new compiled target nodes.
       order_indexes: A list of integer indexes to be used for processing ordering.
    Raises:
      CompilationError: If an ORDER-BY column index is out of range.
    """
    # Work on a copy so callers' target list is not mutated.
    new_targets = copy.copy(c_targets)
    c_target_expressions = [c_target.c_expr for c_target in c_targets]
    order_indexes = []

    # Compile order-by expressions and resolve them to their targets if
    # possible. A ORDER-BY column may be one of the following:
    #
    # * A reference to a target by name.
    # * A reference to a target by index (starting at one).
    # * A new expression, aggregate or not.
    #
    # References by name are converted to indexes. New expressions are
    # inserted into the list of targets as invisible targets.
    targets_name_map = {target.name: index
                        for index, target in enumerate(c_targets)}
    for column in order_by.columns:
        index = None

        # Process target references by index.
        if isinstance(column, int):
            # ORDER-BY indexes are 1-based; convert to 0-based.
            index = column - 1
            if not (0 <= index < len(c_targets)):
                raise CompilationError(
                    "Invalid ORDER-BY column index {}".format(column))

        else:
            # Process target references by name. These will be parsed as
            # simple Column expressions. If they refer to a target name, we
            # resolve them.
            if isinstance(column, query_parser.Column):
                name = column.name
                index = targets_name_map.get(name, None)

            # Otherwise we compile the expression and add it to the list of
            # targets to evaluate and index into that new target.
            if index is None:
                c_expr = compile_expression(column, environ)

                # Attempt to reconcile the expression with one of the existing
                # target expressions.
                try:
                    index = c_target_expressions.index(c_expr)
                except ValueError:
                    # Add the new target. 'None' for the target name implies it
                    # should be invisible, not to be rendered.
                    index = len(new_targets)
                    new_targets.append(EvalTarget(c_expr, None, is_aggregate(c_expr)))
                    c_target_expressions.append(c_expr)

        assert index is not None, "Internal error, could not index order-by reference."
        order_indexes.append(index)

    # Return only the targets appended beyond the original list.
    return (new_targets[len(c_targets):], order_indexes)
754
755
# A compiled FROM clause.
#
# Attributes:
#   c_expr: A compiled expression tree (an EvalNode root node), or None.
#   open: (See query_parser.From.open.)
#   close: (See query_parser.From.close.)
#   clear: (See query_parser.From.clear.)
EvalFrom = collections.namedtuple('EvalFrom', 'c_expr open close clear')
762
def compile_from(from_clause, environ):
    """Compile a From clause as provided by the parser, in the given environment.

    Args:
      from_clause: An instance of query_parser.From, or None if absent.
      environ: A compilation context for evaluating entry filters.
    Returns:
      An instance of EvalFrom, or None if there was no From clause.
    Raises:
      CompilationError: If the filter expression contains aggregates, or if
        the OPEN date does not precede the CLOSE date.
    """
    if from_clause is not None:
        c_expression = (compile_expression(from_clause.expression, environ)
                        if from_clause.expression is not None
                        else None)

        # Check that the from clause does not contain aggregates.
        if c_expression is not None and is_aggregate(c_expression):
            raise CompilationError("Aggregates are not allowed in from clause")

        # Only validated when both dates are actually present.
        if (isinstance(from_clause.open, datetime.date) and
            isinstance(from_clause.close, datetime.date) and
            from_clause.open > from_clause.close):
            raise CompilationError("Invalid dates: CLOSE date must follow OPEN date")

        c_from = EvalFrom(c_expression,
                          from_clause.open,
                          from_clause.close,
                          from_clause.clear)
    else:
        c_from = None

    return c_from
794
795
# A compiled query, ready for execution.
#
# Attributes:
#   c_targets: A list of compiled targets (instancef of EvalTarget).
#   c_from: An instance of EvalNode, a compiled expression tree, for directives.
#   c_where: An instance of EvalNode, a compiled expression tree, for postings.
#   group_indexes: A list of integers that describe which target indexes to
#     group by. All the targets referenced here should be non-aggregates. In fact,
#     this list of indexes should always cover all non-aggregates in 'c_targets'.
#     And this list may well include some invisible columns if only specified in
#     the GROUP BY clause.
#   order_indexes: A list of integers that describe which targets to order by.
#     This list may refer to either aggregates or non-aggregates.
#   ordering: NOTE(review): undocumented field — presumably the sort
#     direction(s) associated with 'order_indexes'; confirm against the
#     parser's OrderBy definition.
#   limit: An optional integer used to cut off the number of result rows returned.
#   distinct: An optional boolean that requests we should uniquify the result rows.
#   flatten: An optional boolean that requests we should output a single posting
#     row for each currency present in an accumulated and output inventory.
EvalQuery = collections.namedtuple('EvalQuery', ('c_targets c_from c_where '
                                                 'group_indexes order_indexes ordering '
                                                 'limit distinct flatten'))
816
def compile_select(select, targets_environ, postings_environ, entries_environ):
    """Prepare an AST for a Select statement into a very rudimentary execution tree.
    The execution tree mostly looks much like an AST, but with some nodes
    replaced with knowledge specific to an execution context and eventually some
    basic optimizations.

    Args:
      select: An instance of query_parser.Select.
      targets_environ: A compilation environment for evaluating targets.
      postings_environ: A compilation environment for evaluating postings filters.
      entries_environ: A compilation environment for evaluating entry filters.
    Returns:
      An instance of EvalQuery, ready to be executed.
    Raises:
      CompilationError: If the query uses an unsupported feature (nested
        SELECT, PIVOT BY), has an aggregate in WHERE, or has non-aggregate
        targets not covered by the GROUP-BY clause.
    """
    # Process the FROM clause and figure out the execution environment for the
    # targets and the where clause.
    from_clause = select.from_clause
    if isinstance(from_clause, query_parser.Select):
        # Nested SELECT sources are not implemented yet; fail early. (The
        # previous version built throwaway ResultSetEnvironment instances
        # before unconditionally raising -- dead code, removed.)
        raise CompilationError("Queries from nested SELECT are not supported yet")

    if from_clause is None or isinstance(from_clause, query_parser.From):
        # Bind the from clause contents.
        c_from = compile_from(from_clause, entries_environ)
        environ_target = targets_environ
        environ_where = postings_environ
    else:
        raise CompilationError("Unexpected from clause in AST: {}".format(from_clause))

    # Compile the targets.
    c_targets = compile_targets(select.targets, environ_target)

    # Bind the WHERE expression to the execution environment.
    if select.where_clause is not None:
        c_where = compile_expression(select.where_clause, environ_where)

        # Aggregates are disallowed in this clause. Check for this.
        # NOTE: This should never trigger if the compilation environment does not
        # contain any aggregate. Just being manic and safe here.
        if is_aggregate(c_where):
            raise CompilationError("Aggregates are disallowed in WHERE clause")
    else:
        c_where = None

    # Process the GROUP-BY clause. This may produce extra (invisible) targets
    # when a grouping expression is only mentioned in the GROUP BY clause.
    new_targets, group_indexes = compile_group_by(select.group_by,
                                                  c_targets,
                                                  environ_target)
    if new_targets:
        c_targets.extend(new_targets)

    # Process the ORDER-BY clause; like GROUP-BY, it may add invisible targets.
    if select.order_by is not None:
        (new_targets, order_indexes) = compile_order_by(select.order_by,
                                                        c_targets,
                                                        environ_target)
        if new_targets:
            c_targets.extend(new_targets)
        ordering = select.order_by.ordering
    else:
        order_indexes = None
        ordering = None

    # If this is an aggregate query (it groups, see list of indexes), check that
    # the set of non-aggregates match exactly the group indexes. This should
    # always be the case at this point, because we have added all the necessary
    # targets to the list of group-by expressions and should have resolved all
    # the indexes.
    if group_indexes is not None:
        non_aggregate_indexes = {index
                                 for index, c_target in enumerate(c_targets)
                                 if not c_target.is_aggregate}
        if non_aggregate_indexes != set(group_indexes):
            missing_names = ['"{}"'.format(c_targets[index].name)
                             for index in non_aggregate_indexes - set(group_indexes)]
            raise CompilationError(
                "All non-aggregates must be covered by GROUP-BY clause in aggregate query; "
                "the following targets are missing: {}".format(",".join(missing_names)))

    # Check that PIVOT-BY is not supported yet.
    if select.pivot_by is not None:
        raise CompilationError("The PIVOT BY clause is not supported yet")

    return EvalQuery(c_targets,
                     c_from,
                     c_where,
                     group_indexes,
                     order_indexes,
                     ordering,
                     select.limit,
                     select.distinct,
                     select.flatten)
915
916
def transform_journal(journal):
    """Translate a Journal entry into an uncompiled Select statement.

    Args:
      journal: An instance of a Journal object.
    Returns:
      An instance of an uncompiled Select object.
    """
    # Build the template substitutions up front: an optional WHERE filter on
    # the account, and an optional summarization function wrapped around the
    # position and balance columns.
    if journal.account:
        where = 'WHERE account ~ "{}"'.format(journal.account)
    else:
        where = ''
    summary_func = journal.summary_func or ''

    query_text = """

        SELECT
           date,
           flag,
           MAXWIDTH(payee, 48),
           MAXWIDTH(narration, 80),
           account,
           {summary_func}(position),
           {summary_func}(balance)
        {where}

    """.format(where=where, summary_func=summary_func)
    cooked_select = query_parser.Parser().parse(query_text)

    # Keep the journal's own FROM clause; take targets and WHERE from the
    # cooked query above.
    return query_parser.Select(cooked_select.targets,
                               journal.from_clause,
                               cooked_select.where_clause,
                               None, None, None, None, None, None)
946
947
def transform_balances(balances):
    """Translate a Balances entry into an uncompiled Select statement.

    Args:
      balances: An instance of a Balance object.
    Returns:
      An instance of an uncompiled Select object.
    """
    ## FIXME: Change the aggregation rules to allow GROUP-BY not to include the
    ## non-aggregate ORDER-BY columns, so we could just GROUP-BY accounts here
    ## instead of having to include the sort-key. I think it should be fine if
    ## the first or last sort-order value gets used, because it would simplify
    ## the input statement.

    # Optional summarization function wrapped around the accumulated position.
    summary_func = balances.summary_func or ""

    query_text = """

      SELECT account, SUM({}(position))
      GROUP BY account, ACCOUNT_SORTKEY(account)
      ORDER BY ACCOUNT_SORTKEY(account)

    """.format(summary_func)
    cooked_select = query_parser.Parser().parse(query_text)

    # Keep the balances statement's own FROM and WHERE clauses; take the
    # targets, grouping and ordering from the cooked query above.
    return query_parser.Select(cooked_select.targets,
                               balances.from_clause,
                               balances.where_clause,
                               cooked_select.group_by,
                               cooked_select.order_by,
                               None, None, None, None)
976
977
# A compiled print statement, ready for execution.
#
# Attributes:
#   c_from: An instance of EvalNode, a compiled expression tree, for directives
#     (the result of compile_from on the statement's FROM clause; may be None
#     when no FROM clause was given).
EvalPrint = collections.namedtuple('EvalPrint', 'c_from')
983
def compile_print(print_stmt, env_entries):
    """Compile a Print statement.

    Args:
      print_stmt: An instance of query_parser.Print.
      env_entries: A compilation environment for evaluating entry filters.
    Returns:
      An instance of EvalPrint, ready to be executed.
    """
    # A print statement only carries a FROM clause; compile just that.
    return EvalPrint(compile_from(print_stmt.from_clause, env_entries))
995
996
def compile(statement, targets_environ, postings_environ, entries_environ):
    """Prepare an AST any of the statement into an executable statement.

    Args:
      statement: An instance of the parser's Select, Balances, Journal or Print.
      targets_environ: A compilation environment for evaluating targets.
      postings_environ: A compilation environment for evaluating postings filters.
      entries_environ: A compilation environment for evaluating entry filters.
    Returns:
      An instance of EvalQuery or EvalPrint, ready to be executed.
    Raises:
      CompilationError: If the statement cannot be compiled, or is not one of the
        supported statements.
    """
    # Desugar the convenience statements into a plain Select first.
    if isinstance(statement, query_parser.Balances):
        statement = transform_balances(statement)
    elif isinstance(statement, query_parser.Journal):
        statement = transform_journal(statement)

    # Dispatch on the (possibly transformed) statement type.
    if isinstance(statement, query_parser.Select):
        return compile_select(statement,
                              targets_environ, postings_environ, entries_environ)
    if isinstance(statement, query_parser.Print):
        return compile_print(statement, entries_environ)

    raise CompilationError(
        "Cannot compile a statement of type '{}'".format(type(statement)))
1026