1import sqlalchemy as sa
2from sqlalchemy.dialects import postgresql
3from sqlalchemy.ext.compiler import compiles
4from sqlalchemy.sql.expression import ColumnElement, FunctionElement
5from sqlalchemy.sql.functions import GenericFunction
6
7from .functions.orm import quote
8
9
class array_get(FunctionElement):
    """SQL construct that subscripts an array expression.

    Build it as ``array_get(<array expression>, <integer index>)``.
    Rendering is done by the ``@compiles(array_get)`` hook in this module.
    That hook requires exactly two arguments and a literal integer index.
    """

    name = "array_get"
12
13
@compiles(array_get)
def compile_array_get(element, compiler, **kw):
    """Render ``array_get(expr, index)`` as ``(expr)[index + 1]``.

    The index must be a literal Python integer. It is embedded directly in
    the SQL text, shifted by one because PostgreSQL array subscripts start
    at 1 by default.

    :param element: the ``array_get`` construct being compiled.
    :param compiler: the active SQL compiler.
    :param kw: compiler options; forwarded when rendering the array
        expression so settings such as ``literal_binds`` are honoured.
    :raises TypeError: if not exactly two arguments are given, or if the
        second argument is not a literal (non-bool) integer.
    """
    args = list(element.clauses)
    if len(args) != 2:
        raise TypeError(
            "Function 'array_get' expects two arguments (%d given)." %
            len(args)
        )

    index = getattr(args[1], 'value', None)
    # ``bool`` is a subclass of ``int``; reject it explicitly so that
    # ``True``/``False`` are not silently rendered as subscripts 2/1.
    if not isinstance(index, int) or isinstance(index, bool):
        raise TypeError(
            "Second argument should be an integer."
        )
    return '(%s)[%d]' % (
        compiler.process(args[0], **kw),
        index + 1,
    )
31
32
class row_to_json(GenericFunction):
    """SQL ``row_to_json`` function whose result is typed as ``postgresql.JSON``."""

    type = postgresql.JSON
    name = 'row_to_json'
36
37
@compiles(row_to_json, 'postgresql')
def compile_row_to_json(element, compiler, **kw):
    """Render ``row_to_json(...)`` for the ``postgresql`` dialect.

    Emits ``<name>(<arguments>)``. The compiler options in ``kw`` are
    forwarded when rendering the arguments, so settings such as
    ``literal_binds`` also apply to nested expressions.
    """
    return "%s(%s)" % (element.name, compiler.process(element.clauses, **kw))
41
42
class json_array_length(GenericFunction):
    """SQL ``json_array_length`` function whose result is typed as ``Integer``."""

    type = sa.Integer
    name = 'json_array_length'
46
47
@compiles(json_array_length, 'postgresql')
def compile_json_array_length(element, compiler, **kw):
    """Render ``json_array_length(...)`` for the ``postgresql`` dialect.

    Emits ``<name>(<arguments>)``. The compiler options in ``kw`` are
    forwarded when rendering the arguments, so settings such as
    ``literal_binds`` also apply to nested expressions.
    """
    return "%s(%s)" % (element.name, compiler.process(element.clauses, **kw))
51
52
class Asterisk(ColumnElement):
    """Column expression standing for every column of a selectable.

    Compiled by ``compile_asterisk`` into ``<quoted name>.*``.
    """

    def __init__(self, selectable):
        # Stored as-is; the compile hook reads ``selectable.name``.
        self.selectable = selectable
56
57
@compiles(Asterisk)
def compile_asterisk(element, compiler, **kw):
    """Render an ``Asterisk`` as ``<quoted selectable name>.*``.

    The name is quoted through the project's ``quote`` helper, using the
    active compiler's dialect.
    """
    dialect = compiler.dialect
    quoted_name = quote(dialect, element.selectable.name)
    return '%s.*' % quoted_name
61