1# This file is part of python-sql.  The COPYRIGHT file at the top level of
2# this repository contains the full copyright notices and license terms.
3from sql import Expression, Window, Flavor, Literal
4
# Public aggregate classes exported by ``from sql.aggregate import *``.
__all__ = [
    'Avg',
    'BitAnd',
    'BitOr',
    'BoolAnd',
    'BoolOr',
    'Count',
    'Every',
    'Max',
    'Min',
    'Stddev',
    'Sum',
    'Variance',
    ]
7
8
class Aggregate(Expression):
    """Base class for SQL aggregate function calls.

    Subclasses provide the function name in ``_sql``. ``str()`` renders::

        NAME([DISTINCT ]expression[ ORDER BY ...])
            [ WITHIN GROUP (ORDER BY ...)]
            [ FILTER (WHERE ...)]
            [ OVER ...]

    and ``params`` returns the matching parameters in the same order. When
    the active ``Flavor`` has ``filter_`` disabled, the filter is emulated
    by rendering a ``CASE`` around ``_case_expression`` instead of emitting
    a ``FILTER`` clause.
    """
    __slots__ = ('expression', '_distinct', '_order_by', '_within',
        '_filter', '_window')
    # SQL function name; set by each concrete subclass.
    _sql = ''

    def __init__(self, expression, distinct=False, order_by=None, within=None,
            filter_=None, window=None):
        super(Aggregate, self).__init__()
        self.expression = expression
        # The remaining attributes go through the validating setters below.
        self.distinct = distinct
        self.order_by = order_by
        self.within = within
        self.filter_ = filter_
        self.window = window

    @staticmethod
    def _expression_list(value):
        """Normalize and validate an ORDER BY / WITHIN GROUP argument.

        ``None`` is returned unchanged and a single ``Expression`` is
        wrapped in a list. Any other iterable that is not already a list or
        tuple is materialized into a list first: the ``all(...)`` check
        iterates it, and storing the original one-shot iterator (e.g. a
        generator) would leave an exhausted, empty value behind.
        """
        if value is None:
            return value
        if isinstance(value, Expression):
            value = [value]
        elif not isinstance(value, (list, tuple)):
            value = list(value)
        assert all(isinstance(col, Expression) for col in value)
        return value

    @property
    def distinct(self):
        """Whether ``DISTINCT`` is rendered inside the call."""
        return self._distinct

    @distinct.setter
    def distinct(self, value):
        assert isinstance(value, bool)
        self._distinct = value

    @property
    def order_by(self):
        """Expressions rendered as ``ORDER BY`` inside the call, or None."""
        return self._order_by

    @order_by.setter
    def order_by(self, value):
        self._order_by = self._expression_list(value)

    @property
    def within(self):
        """Expressions rendered in ``WITHIN GROUP (ORDER BY ...)``, or None."""
        return self._within

    @within.setter
    def within(self, value):
        self._within = self._expression_list(value)

    @property
    def filter_(self):
        """Condition restricting the aggregated rows, or None."""
        return self._filter

    @filter_.setter
    def filter_(self, value):
        # Imported locally rather than at module level, presumably to avoid
        # a circular import with sql.operators.
        from sql.operators import And, Or
        if value is not None:
            assert isinstance(value, (Expression, And, Or))
        self._filter = value

    @property
    def window(self):
        """``Window`` rendered in the ``OVER`` clause, or None."""
        return self._window

    @window.setter
    def window(self, value):
        # Compare against None like the other setters so that a falsy but
        # non-None value is still type-checked instead of slipping through.
        if value is not None:
            assert isinstance(value, Window)
        self._window = value

    @property
    def _case_expression(self):
        """Expression placed in the ``THEN`` branch of the emulated filter."""
        return self.expression

    def __str__(self):
        quantifier = 'DISTINCT ' if self.distinct else ''
        has_filter = Flavor.get().filter_
        expression = self.expression
        if self.filter_ and not has_filter:
            # No native FILTER support: aggregate CASE WHEN filter THEN expr.
            from sql.conditionals import Case
            expression = Case((self.filter_, self._case_expression))
        order_by = ''
        if self.order_by:
            order_by = ' ORDER BY %s' % ', '.join(map(str, self.order_by))
        aggregate = '%s(%s%s%s)' % (
            self._sql, quantifier, expression, order_by)
        within = ''
        if self.within:
            within = (' WITHIN GROUP (ORDER BY %s)'
                % ', '.join(map(str, self.within)))
        filter_ = ''
        if self.filter_ and has_filter:
            filter_ = ' FILTER (WHERE %s)' % self.filter_
        window = ''
        if self.window:
            # A named window is referenced by alias; otherwise it is inlined.
            if self.window.has_alias:
                window = ' OVER "%s"' % self.window.alias
            else:
                window = ' OVER (%s)' % self.window
        return aggregate + within + filter_ + window

    @property
    def params(self):
        """Parameters in the same order as the placeholders in ``str()``."""
        has_filter = Flavor.get().filter_
        p = []
        if self.filter_ and not has_filter:
            # Mirrors the CASE emulation: WHEN condition, then THEN value.
            p.extend(self.filter_.params)
            p.extend(self._case_expression.params)
        else:
            p.extend(self.expression.params)
        if self.order_by:
            for expression in self.order_by:
                p.extend(expression.params)
        if self.within:
            for expression in self.within:
                p.extend(expression.params)
        if self.filter_ and has_filter:
            p.extend(self.filter_.params)
        if self.window and not self.window.has_alias:
            # An aliased window is rendered by name only, so it adds no
            # parameters here.
            p.extend(self.window.params)
        return tuple(p)
129
130
class Avg(Aggregate):
    """SQL ``AVG`` aggregate."""
    _sql = 'AVG'
    __slots__ = ()
134
135
class BitAnd(Aggregate):
    """SQL ``BIT_AND`` aggregate."""
    _sql = 'BIT_AND'
    __slots__ = ()
139
140
class BitOr(Aggregate):
    """SQL ``BIT_OR`` aggregate."""
    _sql = 'BIT_OR'
    __slots__ = ()
144
145
class BoolAnd(Aggregate):
    """SQL ``BOOL_AND`` aggregate."""
    _sql = 'BOOL_AND'
    __slots__ = ()
149
150
class BoolOr(Aggregate):
    """SQL ``BOOL_OR`` aggregate."""
    _sql = 'BOOL_OR'
    __slots__ = ()
154
155
class Count(Aggregate):
    """SQL ``COUNT`` aggregate.

    When the filter is emulated with ``CASE`` (no native ``FILTER``
    support in the active flavor), ``COUNT(*)`` cannot put ``*`` inside the
    ``CASE``, so the ``*`` literal is replaced by ``1`` there.
    """
    __slots__ = ()
    _sql = 'COUNT'

    @property
    def _case_expression(self):
        expression = super(Count, self)._case_expression
        # Type-check the same object whose ``value`` is read, so an
        # intermediate override of ``_case_expression`` cannot make the
        # check and the attribute access disagree.
        if isinstance(expression, Literal) and expression.value == '*':
            expression = Literal(1)
        return expression
167
168
class Every(Aggregate):
    """SQL ``EVERY`` aggregate."""
    _sql = 'EVERY'
    __slots__ = ()
172
173
class Max(Aggregate):
    """SQL ``MAX`` aggregate."""
    _sql = 'MAX'
    __slots__ = ()
177
178
class Min(Aggregate):
    """SQL ``MIN`` aggregate."""
    _sql = 'MIN'
    __slots__ = ()
182
183
class Stddev(Aggregate):
    """SQL ``STDDEV`` aggregate.

    Rendered in upper case like every other aggregate in this module;
    unquoted SQL function names are case-insensitive, so the generated
    query is equivalent to the previous mixed-case ``Stddev``.
    """
    __slots__ = ()
    _sql = 'STDDEV'
187
188
class Sum(Aggregate):
    """SQL ``SUM`` aggregate."""
    _sql = 'SUM'
    __slots__ = ()
192
193
class Variance(Aggregate):
    """SQL ``VARIANCE`` aggregate."""
    _sql = 'VARIANCE'
    __slots__ = ()
197