1# This file is part of python-sql.  The COPYRIGHT file at the top level of
2# this repository contains the full copyright notices and license terms.
3from itertools import chain
4
5from sql import Expression, Select, CombiningQuery, Flavor, FromItem, Window
6
# Public API: the names exported by a star import of this module.
__all__ = ['Abs', 'Cbrt', 'Ceil', 'Degrees', 'Div', 'Exp', 'Floor', 'Ln',
    'Log', 'Mod', 'Pi', 'Power', 'Radians', 'Random', 'Round', 'SetSeed',
    'Sign', 'Sqrt', 'Trunc', 'WidthBucket',
    'Acos', 'Asin', 'Atan', 'Atan2', 'Cos', 'Cot', 'Sin', 'Tan',
    'BitLength', 'CharLength', 'Lower', 'OctetLength', 'Overlay',
    'Position', 'Substring', 'Trim', 'Upper',
    'ToChar', 'ToDate', 'ToNumber', 'ToTimestamp',
    'Age', 'ClockTimestamp', 'CurrentDate', 'CurrentTime', 'CurrentTimestamp',
    'DatePart', 'DateTrunc', 'Extract', 'Isfinite', 'JustifyDays',
    'JustifyHours', 'JustifyInterval', 'Localtime', 'Localtimestamp', 'Now',
    'StatementTimestamp', 'Timeofday', 'TransactionTimestamp',
    'AtTimeZone',
    'RowNumber', 'Rank', 'DenseRank', 'PercentRank', 'CumeDist', 'Ntile',
    'Lag', 'Lead', 'FirstValue', 'LastValue', 'NthValue']
21
# Base classes, followed by the mathematical functions
23
24
class Function(Expression, FromItem):
    """Base class for SQL function calls, rendered as ``_function(arg, ...)``.

    Each positional argument is rendered by ``_format``:

    * ``Expression`` arguments are inlined;
    * ``Select`` / ``CombiningQuery`` arguments are parenthesized;
    * any other value becomes the flavor's parameter placeholder and is
      returned by ``params``.

    If the current flavor's ``function_mapping`` holds an entry for this
    class, that entry is built with the same arguments and used instead.

    The optional ``columns_definitions`` keyword argument is a list of
    ``(name, type)`` pairs, rendered as ``"name" type, ...``.  Presumably it
    is used when the function appears in a FROM clause (it is also a
    ``FromItem``); TODO confirm against the callers.
    """
    __slots__ = ('args', '_columns_definitions')
    table = ''
    name = ''
    _function = ''

    def __init__(self, *args, **kwargs):
        self.args = args
        self.columns_definitions = kwargs.get('columns_definitions', [])

    @property
    def columns_definitions(self):
        """Return the column definitions rendered as ``"name" type, ...``."""
        return ', '.join('"%s" %s' % (c, d)
            for c, d in self._columns_definitions)

    @columns_definitions.setter
    def columns_definitions(self, value):
        # Stored raw as a list of (name, type) pairs; the getter renders it.
        assert isinstance(value, list)
        self._columns_definitions = value

    @staticmethod
    def _format(value):
        """Return the SQL text used for the argument ``value``."""
        if isinstance(value, Expression):
            return str(value)
        elif isinstance(value, (Select, CombiningQuery)):
            return '(%s)' % value
        else:
            # Use the class-level accessor like the rest of this module;
            # there is no need to build a Flavor instance just to look up
            # the current one.
            return Flavor.get().param

    def __str__(self):
        Mapping = Flavor.get().function_mapping.get(self.__class__)
        if Mapping:
            return str(Mapping(*self.args))
        return self._function + '(' + ', '.join(
            map(self._format, self.args)) + ')'

    @property
    def params(self):
        """Return the parameter values matching the placeholders of str()."""
        Mapping = Flavor.get().function_mapping.get(self.__class__)
        if Mapping:
            return Mapping(*self.args).params
        p = []
        for arg in self.args:
            if isinstance(arg, (Expression, Select, CombiningQuery)):
                # Nested SQL objects carry their own parameters.
                p.extend(arg.params)
            else:
                # Plain values were rendered as a placeholder.
                p.append(arg)
        return tuple(p)
73
74
class FunctionKeyword(Function):
    """Function whose arguments are separated by SQL keywords.

    The i-th entry of ``_keywords`` is written before the i-th argument, so
    ``_keywords = ('', 'FROM', 'FOR')`` renders ``NAME(a FROM b FOR c)``.
    The first keyword is expected to be empty, because the first character
    of the joined text is dropped.

    NOTE(review): arguments beyond ``len(_keywords)`` are not rendered, but
    the inherited ``params`` still returns values for them.
    """
    __slots__ = ()
    _function = ''
    _keywords = ()

    def __str__(self):
        mapped = Flavor.get().function_mapping.get(self.__class__)
        if mapped:
            return str(mapped(*self.args))
        tokens = []
        for keyword, arg in zip(self._keywords, self.args):
            tokens.extend((keyword, self._format(arg)))
        # Joining ('', a, 'FROM', b) gives ' a FROM b'; drop the space
        # contributed by the empty first keyword.
        inner = ' '.join(tokens)[1:]
        return '%s(%s)' % (self._function, inner)
89
90
class FunctionNotCallable(Function):
    """Function rendered as its bare name, with no parentheses or
    arguments (for example ``CURRENT_DATE``)."""
    __slots__ = ()
    _function = ''

    def __str__(self):
        mapped = Flavor.get().function_mapping.get(self.__class__)
        return str(mapped(*self.args)) if mapped else self._function
100
101
class Abs(Function):
    """SQL ``ABS`` function."""
    _function = 'ABS'
    __slots__ = ()
105
106
class Cbrt(Function):
    """SQL ``CBRT`` function."""
    _function = 'CBRT'
    __slots__ = ()
110
111
class Ceil(Function):
    """SQL ``CEIL`` function."""
    _function = 'CEIL'
    __slots__ = ()
115
116
class Degrees(Function):
    """SQL ``DEGREES`` function."""
    _function = 'DEGREES'
    __slots__ = ()
120
121
class Div(Function):
    """SQL ``DIV`` function."""
    _function = 'DIV'
    __slots__ = ()
125
126
class Exp(Function):
    """SQL ``EXP`` function."""
    _function = 'EXP'
    __slots__ = ()
130
131
class Floor(Function):
    """SQL ``FLOOR`` function."""
    _function = 'FLOOR'
    __slots__ = ()
135
136
class Ln(Function):
    """SQL ``LN`` function."""
    _function = 'LN'
    __slots__ = ()
140
141
class Log(Function):
    """SQL ``LOG`` function."""
    _function = 'LOG'
    __slots__ = ()
145
146
class Mod(Function):
    """SQL ``MOD`` function."""
    _function = 'MOD'
    __slots__ = ()
150
151
class Pi(Function):
    """SQL ``PI`` function."""
    _function = 'PI'
    __slots__ = ()
155
156
class Power(Function):
    """SQL ``POWER`` function."""
    _function = 'POWER'
    __slots__ = ()
160
161
class Radians(Function):
    """SQL ``RADIANS`` function."""
    _function = 'RADIANS'
    __slots__ = ()
165
166
class Random(Function):
    """SQL ``RANDOM`` function."""
    _function = 'RANDOM'
    __slots__ = ()
170
171
class Round(Function):
    """SQL ``ROUND`` function."""
    _function = 'ROUND'
    __slots__ = ()
175
176
class SetSeed(Function):
    """SQL ``SETSEED`` function."""
    _function = 'SETSEED'
    __slots__ = ()
180
181
class Sign(Function):
    """SQL ``SIGN`` function."""
    _function = 'SIGN'
    __slots__ = ()
185
186
class Sqrt(Function):
    """SQL ``SQRT`` function."""
    _function = 'SQRT'
    __slots__ = ()
190
191
class Trunc(Function):
    """SQL ``TRUNC`` function."""
    _function = 'TRUNC'
    __slots__ = ()
195
196
class WidthBucket(Function):
    """SQL ``WIDTH_BUCKET`` function."""
    _function = 'WIDTH_BUCKET'
    __slots__ = ()
200
201# Trigonometric
202
203
class Acos(Function):
    """SQL ``ACOS`` function."""
    _function = 'ACOS'
    __slots__ = ()
207
208
class Asin(Function):
    """SQL ``ASIN`` function."""
    _function = 'ASIN'
    __slots__ = ()
212
213
class Atan(Function):
    """SQL ``ATAN`` function."""
    _function = 'ATAN'
    __slots__ = ()
217
218
class Atan2(Function):
    """SQL ``ATAN2`` function."""
    _function = 'ATAN2'
    __slots__ = ()
222
223
class Cos(Function):
    """SQL ``COS`` function."""
    __slots__ = ()
    # Spelled in upper case like every other SQL function name in this module.
    _function = 'COS'
227
228
class Cot(Function):
    """SQL ``COT`` function."""
    _function = 'COT'
    __slots__ = ()
232
233
class Sin(Function):
    """SQL ``SIN`` function."""
    _function = 'SIN'
    __slots__ = ()
237
238
class Tan(Function):
    """SQL ``TAN`` function."""
    _function = 'TAN'
    __slots__ = ()
242
243# String
244
245
class BitLength(Function):
    """SQL ``BIT_LENGTH`` function."""
    _function = 'BIT_LENGTH'
    __slots__ = ()
249
250
class CharLength(Function):
    """SQL ``CHAR_LENGTH`` function."""
    _function = 'CHAR_LENGTH'
    __slots__ = ()
254
255
class Lower(Function):
    """SQL ``LOWER`` function."""
    _function = 'LOWER'
    __slots__ = ()
259
260
class OctetLength(Function):
    """SQL ``OCTET_LENGTH`` function."""
    _function = 'OCTET_LENGTH'
    __slots__ = ()
264
265
class Overlay(FunctionKeyword):
    """SQL ``OVERLAY(a PLACING b FROM c FOR d)`` function."""
    _function = 'OVERLAY'
    _keywords = ('', 'PLACING', 'FROM', 'FOR')
    __slots__ = ()
270
271
class Position(FunctionKeyword):
    """SQL ``POSITION(a IN b)`` function."""
    _function = 'POSITION'
    _keywords = ('', 'IN')
    __slots__ = ()
276
277
class Substring(FunctionKeyword):
    """SQL ``SUBSTRING(a FROM b FOR c)`` function."""
    _function = 'SUBSTRING'
    _keywords = ('', 'FROM', 'FOR')
    __slots__ = ()
282
283
class Trim(Function):
    """SQL ``TRIM(position characters FROM string)`` function.

    ``position`` must be ``'LEADING'``, ``'TRAILING'`` or ``'BOTH'``; it is
    checked case-insensitively and stored upper-cased.  ``characters`` and
    ``string`` may be plain ``str`` values, which are sent as query
    parameters.  Other objects are rendered inline, with ``Select`` and
    ``CombiningQuery`` operands parenthesized.
    """
    __slots__ = ('position', 'characters', 'string')
    _function = 'TRIM'

    def __init__(self, string, position='BOTH', characters=' '):
        assert position.upper() in ('LEADING', 'TRAILING', 'BOTH')
        # Initialise the slots declared by Function (args and column
        # definitions) so they are never left unset; Trim itself keeps its
        # operands in dedicated attributes rather than in ``args``.
        super(Trim, self).__init__()
        self.position = position.upper()
        self.characters = characters
        self.string = string

    def __str__(self):
        flavor = Flavor.get()
        Mapping = flavor.function_mapping.get(self.__class__)
        if Mapping:
            return str(Mapping(self.string, self.position, self.characters))
        param = flavor.param

        def render(arg):
            # Plain strings become placeholders; ``params`` supplies them.
            if isinstance(arg, str):
                return param
            elif isinstance(arg, (Select, CombiningQuery)):
                # A sub-query operand must be parenthesized to be valid SQL.
                return '(%s)' % arg
            return str(arg)
        return self._function + '(%s %s FROM %s)' % (
            self.position, render(self.characters), render(self.string))

    @property
    def params(self):
        """Return the parameter values matching the placeholders of str()."""
        Mapping = Flavor.get().function_mapping.get(self.__class__)
        if Mapping:
            return Mapping(self.string, self.position, self.characters).params
        p = []
        for arg in (self.characters, self.string):
            if isinstance(arg, str):
                p.append(arg)
            elif hasattr(arg, 'params'):
                p.extend(arg.params)
        return tuple(p)
321
322
class Upper(Function):
    """SQL ``UPPER`` function."""
    _function = 'UPPER'
    __slots__ = ()
326
327
class ToChar(Function):
    """SQL ``TO_CHAR`` function."""
    _function = 'TO_CHAR'
    __slots__ = ()
331
332
class ToDate(Function):
    """SQL ``TO_DATE`` function."""
    _function = 'TO_DATE'
    __slots__ = ()
336
337
class ToNumber(Function):
    """SQL ``TO_NUMBER`` function."""
    _function = 'TO_NUMBER'
    __slots__ = ()
341
342
class ToTimestamp(Function):
    """SQL ``TO_TIMESTAMP`` function."""
    _function = 'TO_TIMESTAMP'
    __slots__ = ()
346
347
class Age(Function):
    """SQL ``AGE`` function."""
    _function = 'AGE'
    __slots__ = ()
351
352
class ClockTimestamp(Function):
    """SQL ``CLOCK_TIMESTAMP`` function."""
    _function = 'CLOCK_TIMESTAMP'
    __slots__ = ()
356
357
class CurrentDate(FunctionNotCallable):
    """SQL ``CURRENT_DATE``, rendered without parentheses."""
    _function = 'CURRENT_DATE'
    __slots__ = ()
361
362
class CurrentTime(FunctionNotCallable):
    """SQL ``CURRENT_TIME``, rendered without parentheses."""
    _function = 'CURRENT_TIME'
    __slots__ = ()
366
367
class CurrentTimestamp(FunctionNotCallable):
    """SQL ``CURRENT_TIMESTAMP``, rendered without parentheses."""
    _function = 'CURRENT_TIMESTAMP'
    __slots__ = ()
371
372
class DatePart(Function):
    """SQL ``DATE_PART`` function."""
    _function = 'DATE_PART'
    __slots__ = ()
376
377
class DateTrunc(Function):
    """SQL ``DATE_TRUNC`` function."""
    _function = 'DATE_TRUNC'
    __slots__ = ()
381
382
class Extract(FunctionKeyword):
    """SQL ``EXTRACT(a FROM b)`` function."""
    _function = 'EXTRACT'
    _keywords = ('', 'FROM')
    __slots__ = ()
387
388
class Isfinite(Function):
    """SQL ``ISFINITE`` function."""
    _function = 'ISFINITE'
    __slots__ = ()
392
393
class JustifyDays(Function):
    """SQL ``JUSTIFY_DAYS`` function."""
    _function = 'JUSTIFY_DAYS'
    __slots__ = ()
397
398
class JustifyHours(Function):
    """SQL ``JUSTIFY_HOURS`` function."""
    _function = 'JUSTIFY_HOURS'
    __slots__ = ()
402
403
class JustifyInterval(Function):
    """SQL ``JUSTIFY_INTERVAL`` function."""
    _function = 'JUSTIFY_INTERVAL'
    __slots__ = ()
407
408
class Localtime(FunctionNotCallable):
    """SQL ``LOCALTIME``, rendered without parentheses."""
    _function = 'LOCALTIME'
    __slots__ = ()
412
413
class Localtimestamp(FunctionNotCallable):
    """SQL ``LOCALTIMESTAMP``, rendered without parentheses."""
    _function = 'LOCALTIMESTAMP'
    __slots__ = ()
417
418
class Now(Function):
    """SQL ``NOW`` function."""
    _function = 'NOW'
    __slots__ = ()
422
423
class StatementTimestamp(Function):
    """SQL ``STATEMENT_TIMESTAMP`` function."""
    _function = 'STATEMENT_TIMESTAMP'
    __slots__ = ()
427
428
class Timeofday(Function):
    """SQL ``TIMEOFDAY`` function."""
    _function = 'TIMEOFDAY'
    __slots__ = ()
432
433
class TransactionTimestamp(Function):
    """SQL ``TRANSACTION_TIMESTAMP`` function."""
    _function = 'TRANSACTION_TIMESTAMP'
    __slots__ = ()
437
438
class AtTimeZone(Function):
    """SQL ``field AT TIME ZONE zone`` construct.

    Both operands are formatted like ``Function`` arguments:

    * ``Expression`` objects are rendered inline;
    * ``Select`` / ``CombiningQuery`` objects are parenthesized;
    * any other value becomes a query parameter.

    Previously ``field`` was interpolated with ``str()``, which inlined
    plain values into the SQL text and broke ``params``.
    """
    __slots__ = ('field', 'zone')

    def __init__(self, field, zone):
        # Initialise the slots declared by Function (args and column
        # definitions) so they are never left unset.
        super(AtTimeZone, self).__init__()
        self.field = field
        self.zone = zone

    def __str__(self):
        Mapping = Flavor.get().function_mapping.get(self.__class__)
        if Mapping:
            return str(Mapping(self.field, self.zone))
        return '%s AT TIME ZONE %s' % (
            self._format(self.field), self._format(self.zone))

    @property
    def params(self):
        """Return the parameter values matching the placeholders of str()."""
        Mapping = Flavor.get().function_mapping.get(self.__class__)
        if Mapping:
            return Mapping(self.field, self.zone).params
        p = []
        for operand in (self.field, self.zone):
            if isinstance(operand, (Expression, Select, CombiningQuery)):
                p.extend(operand.params)
            else:
                # Rendered as a placeholder by _format.
                p.append(operand)
        return tuple(p)
468
469
class WindowFunction(Function):
    """Function evaluated over a window, rendered with an OVER clause.

    Keyword arguments:

    * ``window`` (required): the ``Window``.  When it has an alias, only
      the quoted alias is emitted (``OVER "w"``); otherwise its definition
      is inlined (``OVER (...)``).
    * ``filter_`` (optional): condition emitted as ``FILTER (WHERE ...)``.

    Any other keyword argument is forwarded to ``Function``.

    NOTE(review): the ``window`` setter accepts a falsy value such as
    ``None``, but ``__str__`` and ``params`` then fail on ``has_alias``.
    """
    __slots__ = ('_filter', '_window')

    def __init__(self, *args, **kwargs):
        self.filter_ = kwargs.pop('filter_', None)
        # Raises KeyError when no window is supplied.
        self.window = kwargs['window']
        super(WindowFunction, self).__init__(*args, **kwargs)

    @property
    def filter_(self):
        return self._filter

    @filter_.setter
    def filter_(self, value):
        # Imported locally, presumably to avoid an import cycle with
        # sql.operators -- TODO confirm.
        from sql.operators import And, Or
        assert value is None or isinstance(value, (Expression, And, Or))
        self._filter = value

    @property
    def window(self):
        return self._window

    @window.setter
    def window(self, value):
        if value:
            assert isinstance(value, Window)
        self._window = value

    def __str__(self):
        parts = [super(WindowFunction, self).__str__()]
        if self.filter_:
            parts.append(' FILTER (WHERE %s)' % self.filter_)
        window = self.window
        parts.append(' OVER "%s"' % window.alias if window.has_alias
            else ' OVER (%s)' % window)
        return ''.join(parts)

    @property
    def params(self):
        collected = list(super(WindowFunction, self).params)
        if self.filter_:
            collected += self.filter_.params
        # An aliased window is only referenced by name, so it has no params.
        if not self.window.has_alias:
            collected += self.window.params
        return tuple(collected)
518
519
class RowNumber(WindowFunction):
    """``ROW_NUMBER`` window function (requires a ``window`` keyword)."""
    _function = 'ROW_NUMBER'
    __slots__ = ()
523
524
class Rank(WindowFunction):
    """``RANK`` window function (requires a ``window`` keyword)."""
    _function = 'RANK'
    __slots__ = ()
528
529
class DenseRank(WindowFunction):
    """``DENSE_RANK`` window function (requires a ``window`` keyword)."""
    _function = 'DENSE_RANK'
    __slots__ = ()
533
534
class PercentRank(WindowFunction):
    """``PERCENT_RANK`` window function (requires a ``window`` keyword)."""
    _function = 'PERCENT_RANK'
    __slots__ = ()
538
539
class CumeDist(WindowFunction):
    """``CUME_DIST`` window function (requires a ``window`` keyword)."""
    _function = 'CUME_DIST'
    __slots__ = ()
543
544
class Ntile(WindowFunction):
    """``NTILE`` window function (requires a ``window`` keyword)."""
    _function = 'NTILE'
    __slots__ = ()
548
549
class Lag(WindowFunction):
    """``LAG`` window function (requires a ``window`` keyword)."""
    _function = 'LAG'
    __slots__ = ()
553
554
class Lead(WindowFunction):
    """``LEAD`` window function (requires a ``window`` keyword)."""
    _function = 'LEAD'
    __slots__ = ()
558
559
class FirstValue(WindowFunction):
    """``FIRST_VALUE`` window function (requires a ``window`` keyword)."""
    _function = 'FIRST_VALUE'
    __slots__ = ()
563
564
class LastValue(WindowFunction):
    """``LAST_VALUE`` window function (requires a ``window`` keyword)."""
    _function = 'LAST_VALUE'
    __slots__ = ()
568
569
class NthValue(WindowFunction):
    """``NTH_VALUE`` window function (requires a ``window`` keyword)."""
    _function = 'NTH_VALUE'
    __slots__ = ()
573