# This file is part of python-sql. The COPYRIGHT file at the top level of
# this repository contains the full copyright notices and license terms.
"""Expressions for SQL aggregate function calls (AVG, COUNT, SUM, ...)."""
from sql import Expression, Window, Flavor, Literal

__all__ = ['Avg', 'BitAnd', 'BitOr', 'BoolAnd', 'BoolOr', 'Count', 'Every',
    'Max', 'Min', 'Stddev', 'Sum', 'Variance']


class Aggregate(Expression):
    """Base class for an aggregate function call.

    Subclasses only set ``_sql`` to the SQL function name. The rendered
    form is::

        NAME([DISTINCT ]expression[ ORDER BY ...])
            [ WITHIN GROUP (ORDER BY ...)]
            [ FILTER (WHERE ...)]
            [ OVER ...]

    When the current ``Flavor`` does not support ``FILTER`` (its ``filter_``
    attribute is false), the filter is emulated by aggregating
    ``Case((filter_, expression))`` instead of the bare expression.
    """
    __slots__ = ('expression', '_distinct', '_order_by', '_within',
        '_filter', '_window')
    # SQL function name; overridden by every concrete subclass.
    _sql = ''

    def __init__(self, expression, distinct=False, order_by=None, within=None,
            filter_=None, window=None):
        """Build the aggregate call.

        :param expression: the aggregated expression.
        :param distinct: if True, prefix the argument with ``DISTINCT``.
        :param order_by: an Expression or iterable of Expressions rendered
            as ``ORDER BY`` inside the call parentheses.
        :param within: an Expression or iterable of Expressions rendered as
            ``WITHIN GROUP (ORDER BY ...)``.
        :param filter_: condition rendered as ``FILTER (WHERE ...)`` (or
            emulated with ``Case`` when the flavor lacks FILTER support).
        :param window: a ``Window`` rendered as the ``OVER`` clause.
        """
        super(Aggregate, self).__init__()
        self.expression = expression
        self.distinct = distinct
        self.order_by = order_by
        self.within = within
        self.filter_ = filter_
        self.window = window

    @property
    def distinct(self):
        """Whether the aggregate is rendered with ``DISTINCT``."""
        return self._distinct

    @distinct.setter
    def distinct(self, value):
        assert isinstance(value, bool)
        self._distinct = value

    @property
    def order_by(self):
        """List of Expressions ordering the aggregated rows, or None."""
        return self._order_by

    @order_by.setter
    def order_by(self, value):
        if value is not None:
            # A single Expression is accepted as shorthand for a list of one.
            if isinstance(value, Expression):
                value = [value]
            assert all(isinstance(col, Expression) for col in value)
        self._order_by = value

    @property
    def within(self):
        """List of Expressions for ``WITHIN GROUP (ORDER BY ...)``, or None."""
        return self._within

    @within.setter
    def within(self, value):
        if value is not None:
            # A single Expression is accepted as shorthand for a list of one.
            if isinstance(value, Expression):
                value = [value]
            assert all(isinstance(col, Expression) for col in value)
        self._within = value

    @property
    def filter_(self):
        """Condition restricting the aggregated rows, or None."""
        return self._filter

    @filter_.setter
    def filter_(self, value):
        # Imported locally, as in the original module layout (presumably to
        # avoid a circular import with sql.operators — TODO confirm).
        from sql.operators import And, Or
        if value is not None:
            assert isinstance(value, (Expression, And, Or))
        self._filter = value

    @property
    def window(self):
        """``Window`` rendered as the ``OVER`` clause, or None."""
        return self._window

    @window.setter
    def window(self, value):
        # Same None test as the other setters: only None skips validation.
        if value is not None:
            assert isinstance(value, Window)
        self._window = value

    @property
    def _case_expression(self):
        """Expression placed in the ``Case`` used to emulate ``FILTER``."""
        return self.expression

    def __str__(self):
        quantifier = 'DISTINCT ' if self.distinct else ''
        has_filter = Flavor.get().filter_
        expression = self.expression
        if self.filter_ and not has_filter:
            # Flavor lacks FILTER support: aggregate a conditional instead.
            from sql.conditionals import Case
            expression = Case((self.filter_, self._case_expression))
        order_by = ''
        if self.order_by:
            order_by = ' ORDER BY %s' % ', '.join(map(str, self.order_by))
        aggregate = '%s(%s%s%s)' % (
            self._sql, quantifier, expression, order_by)
        within = ''
        if self.within:
            within = (' WITHIN GROUP (ORDER BY %s)'
                % ', '.join(map(str, self.within)))
        filter_ = ''
        if self.filter_ and has_filter:
            filter_ = ' FILTER (WHERE %s)' % self.filter_
        window = ''
        if self.window:
            if self.window.has_alias:
                # A named window is referenced by its alias only.
                window = ' OVER "%s"' % self.window.alias
            else:
                window = ' OVER (%s)' % self.window
        return aggregate + within + filter_ + window

    @property
    def params(self):
        """Query parameters, in the same order as placeholders in __str__."""
        has_filter = Flavor.get().filter_
        p = []
        if self.filter_ and not has_filter:
            # Mirrors Case((filter_, case_expression)) built in __str__.
            p.extend(self.filter_.params)
            p.extend(self._case_expression.params)
        else:
            p.extend(self.expression.params)
        if self.order_by:
            for expression in self.order_by:
                p.extend(expression.params)
        if self.within:
            for expression in self.within:
                p.extend(expression.params)
        if self.filter_ and has_filter:
            p.extend(self.filter_.params)
        # An aliased window is rendered by name only, so has no parameters.
        if self.window and not self.window.has_alias:
            p.extend(self.window.params)
        return tuple(p)


class Avg(Aggregate):
    __slots__ = ()
    _sql = 'AVG'


class BitAnd(Aggregate):
    __slots__ = ()
    _sql = 'BIT_AND'


class BitOr(Aggregate):
    __slots__ = ()
    _sql = 'BIT_OR'


class BoolAnd(Aggregate):
    __slots__ = ()
    _sql = 'BOOL_AND'


class BoolOr(Aggregate):
    __slots__ = ()
    _sql = 'BOOL_OR'


class Count(Aggregate):
    __slots__ = ()
    _sql = 'COUNT'

    @property
    def _case_expression(self):
        """Expression used when ``FILTER`` is emulated with ``Case``.

        A ``Literal('*')`` argument is replaced by ``Literal(1)`` so the
        conditional yields a plain constant for matching rows.
        """
        expression = super(Count, self)._case_expression
        if isinstance(expression, Literal) and expression.value == '*':
            expression = Literal(1)
        return expression


class Every(Aggregate):
    __slots__ = ()
    _sql = 'EVERY'


class Max(Aggregate):
    __slots__ = ()
    _sql = 'MAX'


class Min(Aggregate):
    __slots__ = ()
    _sql = 'MIN'


class Stddev(Aggregate):
    __slots__ = ()
    # Upper-case like every other aggregate (SQL names are case-insensitive).
    _sql = 'STDDEV'


class Sum(Aggregate):
    __slots__ = ()
    _sql = 'SUM'


class Variance(Aggregate):
    __slots__ = ()
    _sql = 'VARIANCE'