1from enum import EnumMeta
2
3from singledispatch import singledispatch
4from sqlalchemy import types
5from sqlalchemy.dialects import postgresql
6from sqlalchemy.orm import interfaces, strategies
7
8from graphene import (ID, Boolean, Dynamic, Enum, Field, Float, Int, List,
9                      String)
10from graphene.types.json import JSONString
11
12from .batching import get_batch_resolver
13from .enums import enum_for_sa_enum
14from .fields import (BatchSQLAlchemyConnectionField,
15                     default_connection_field_factory)
16from .registry import get_global_registry
17from .resolvers import get_attr_resolver, get_custom_resolver
18
19try:
20    from sqlalchemy_utils import ChoiceType, JSONType, ScalarListType, TSVectorType
21except ImportError:
22    ChoiceType = JSONType = ScalarListType = TSVectorType = object
23
24
# ``SelectInLoader`` may be absent from some SQLAlchemy releases; relationship
# batching is turned off when it is. Holds the loader class itself (truthy) or
# None -- not a bool.
try:
    is_selectin_available = strategies.SelectInLoader
except AttributeError:
    is_selectin_available = None
26
27
def get_column_doc(column):
    """Return ``column.doc``, or ``None`` when the column has no ``doc`` attribute."""
    try:
        return column.doc
    except AttributeError:
        return None
30
31
def is_column_nullable(column):
    """
    Tell whether *column* accepts NULL values.

    Objects without a ``nullable`` attribute are treated as nullable; any other
    value is coerced with ``bool``.
    """
    try:
        nullable = column.nullable
    except AttributeError:
        nullable = True
    return bool(nullable)
34
35
def convert_sqlalchemy_relationship(relationship_prop, obj_type, connection_field_factory, batching,
                                    orm_field_name, **field_kwargs):
    """
    Convert a SQLAlchemy relationship into a graphene ``Dynamic`` field.

    The field itself is produced by a closure handed to ``Dynamic``, so the registry
    lookup for the related model happens only when that closure is called.

    :param sqlalchemy.RelationshipProperty relationship_prop:
    :param SQLAlchemyObjectType obj_type:
    :param function|None connection_field_factory: factory for connection fields of
        to-many relationships; ``None`` selects a default
    :param bool batching: request batched loading (ignored when the "selectin"
        loader is unavailable)
    :param str orm_field_name:
    :param dict field_kwargs: extra keyword arguments forwarded to the field
    :rtype: Dynamic
    """
    def build_field():
        """
        Build the field for the relationship.

        Returns ``None`` when no object type is registered for the related model,
        or when the relationship direction is not one of the handled ones.

        :rtype: Field|None
        """
        direction = relationship_prop.direction
        related_model = relationship_prop.mapper.entity
        child_type = obj_type._meta.registry.get_type_for_model(related_model)
        if not child_type:
            return None

        # Batching is only honoured when SQLAlchemy provides the selectin loader.
        use_batching = batching if is_selectin_available else False

        # Scalar relationships (many-to-one, or uselist=False) become object fields.
        if direction == interfaces.MANYTOONE or not relationship_prop.uselist:
            return _convert_o2o_or_m2o_relationship(
                relationship_prop, obj_type, use_batching, orm_field_name, **field_kwargs
            )

        # Collection relationships become list or connection fields.
        if direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY):
            return _convert_o2m_or_m2m_relationship(
                relationship_prop, obj_type, use_batching, connection_field_factory, **field_kwargs
            )

        return None

    return Dynamic(build_field)
65
66
def _convert_o2o_or_m2o_relationship(relationship_prop, obj_type, batching, orm_field_name, **field_kwargs):
    """
    Convert a one-to-one or many-to-one relationship into an object field.

    :param sqlalchemy.RelationshipProperty relationship_prop:
    :param SQLAlchemyObjectType obj_type:
    :param bool batching: use the batch resolver instead of a plain attribute resolver
    :param str orm_field_name: name used to look up a custom resolver on ``obj_type``
    :param dict field_kwargs: extra keyword arguments forwarded to ``Field``
    :rtype: Field
    """
    related_model = relationship_prop.mapper.entity
    child_type = obj_type._meta.registry.get_type_for_model(related_model)

    # A resolver returned by get_custom_resolver takes precedence; otherwise one is
    # chosen according to the batching flag.
    resolver = get_custom_resolver(obj_type, orm_field_name)
    if resolver is None:
        if batching:
            resolver = get_batch_resolver(relationship_prop)
        else:
            resolver = get_attr_resolver(obj_type, relationship_prop.key)

    return Field(child_type, resolver=resolver, **field_kwargs)
86
87
def _convert_o2m_or_m2m_relationship(relationship_prop, obj_type, batching, connection_field_factory, **field_kwargs):
    """
    Convert a one-to-many or many-to-many relationship into a collection field.

    If the related object type has a connection, a connection field is built by
    ``connection_field_factory`` (or a default chosen from ``batching``); otherwise a
    plain ``List`` field is returned.

    :param sqlalchemy.RelationshipProperty relationship_prop:
    :param SQLAlchemyObjectType obj_type:
    :param bool batching: selects the batching connection field when no factory is given
    :param function|None connection_field_factory:
    :param dict field_kwargs: extra keyword arguments forwarded to the field
    :rtype: Field
    """
    registry = obj_type._meta.registry
    child_type = registry.get_type_for_model(relationship_prop.mapper.entity)

    if child_type._meta.connection:
        # TODO Allow override of connection_field_factory and resolver via ORMField
        factory = connection_field_factory
        if factory is None:
            if batching:
                factory = BatchSQLAlchemyConnectionField.from_relationship
            else:
                factory = default_connection_field_factory
        return factory(relationship_prop, registry, **field_kwargs)

    # No connection on the child type: expose the relationship as a plain list.
    return Field(List(child_type), **field_kwargs)
110
111
def convert_sqlalchemy_hybrid_method(hybrid_prop, resolver, **field_kwargs):
    """
    Build a graphene ``Field`` for a hybrid property.

    :param hybrid_prop: the hybrid property (not inspected here)
    :param resolver: resolver attached to the field
    :param dict field_kwargs: keyword arguments forwarded to ``Field``; ``type``
        defaults to ``String`` when not supplied
    :rtype: Field
    """
    # TODO The default type should depend on the type of the hybrid property.
    field_kwargs.setdefault('type', String)
    return Field(resolver=resolver, **field_kwargs)
121
122
def convert_sqlalchemy_composite(composite_prop, registry, resolver):
    """
    Convert a composite property with the converter registered for its class.

    :param composite_prop: composite property; its ``composite_class`` selects the
        converter
    :param registry: registry queried via ``get_converter_for_composite`` and passed
        on to the converter
    :param resolver: currently unused by this function
    :returns: whatever the registered converter returns
    :raises Exception: if no converter is registered for
        ``composite_prop.composite_class``
    """
    converter = registry.get_converter_for_composite(composite_prop.composite_class)
    if not converter:
        # str() of a field that is not attached to a class yet (has no parent) can
        # raise AttributeError; fall back to repr() for the message in that case.
        # Formatting happens before the raise so the "no converter" error is not
        # chained onto that AttributeError.
        try:
            prop_description = str(composite_prop)
        except AttributeError:
            prop_description = repr(composite_prop)
        raise Exception(
            "Don't know how to convert the composite field %s (%s)"
            % (prop_description, composite_prop.composite_class)
        )

    # TODO Add a way to override composite fields default parameters
    return converter(composite_prop, registry)
140
141
def _register_composite_class(cls, registry=None):
    """
    Return a decorator that registers a composite converter for *cls*.

    :param cls: composite class the decorated converter handles
    :param registry: registry to register into; the global registry when ``None``
    :returns: a decorator that registers the function and returns it unchanged
    """
    if registry is None:
        registry = get_global_registry()

    def inner(fn):
        registry.register_composite_converter(cls, fn)
        # Hand the function back so the decorated name stays bound to it instead
        # of being rebound to None.
        return fn

    return inner
152
153
# Attach the registration helper as an attribute so composite converters can be
# declared with ``@convert_sqlalchemy_composite.register(SomeCompositeClass)``.
setattr(convert_sqlalchemy_composite, 'register', _register_composite_class)
155
156
def convert_sqlalchemy_column(column_prop, registry, resolver, **field_kwargs):
    """
    Build a graphene ``Field`` for a column property.

    Only the first column of the property is inspected. ``type``, ``required`` and
    ``description`` supplied by the caller take precedence over the values derived
    from the column.

    :param column_prop: column property whose first column is converted
    :param registry: passed through to ``convert_sqlalchemy_type``
    :param resolver: resolver attached to the field
    :param dict field_kwargs: extra keyword arguments forwarded to ``Field``
    :rtype: Field
    """
    column = column_prop.columns[0]

    # Only convert the column type when the caller did not supply one:
    # ``dict.setdefault`` evaluates its default eagerly, so an explicit ``type``
    # override would still fail on a column type no converter knows about.
    if 'type' not in field_kwargs:
        field_kwargs['type'] = convert_sqlalchemy_type(getattr(column, "type", None), column, registry)
    field_kwargs.setdefault('required', not is_column_nullable(column))
    field_kwargs.setdefault('description', get_column_doc(column))

    return Field(
        resolver=resolver,
        **field_kwargs
    )
167
168
@singledispatch
def convert_sqlalchemy_type(type, column, registry=None):
    """
    Convert a SQLAlchemy column type into a graphene type.

    Dispatches on the class of *type*. This base implementation handles any type
    with no registered converter and always raises.

    :param type: SQLAlchemy type instance to convert
    :param column: column the type belongs to (used in the error message)
    :param registry: registry forwarded to converters that need one
    :raises Exception: for unsupported column types
    """
    # Report the class of the type that could not be dispatched; the column's own
    # class is the same for every column and says nothing about the failure.
    raise Exception(
        "Don't know how to convert the SQLAlchemy field %s (%s)"
        % (column, type.__class__)
    )
175
176
@convert_sqlalchemy_type.register(types.Date)
@convert_sqlalchemy_type.register(types.Time)
@convert_sqlalchemy_type.register(types.String)
@convert_sqlalchemy_type.register(types.Text)
@convert_sqlalchemy_type.register(types.Unicode)
@convert_sqlalchemy_type.register(types.UnicodeText)
@convert_sqlalchemy_type.register(postgresql.UUID)
@convert_sqlalchemy_type.register(postgresql.INET)
@convert_sqlalchemy_type.register(postgresql.CIDR)
def convert_column_to_string(type, column, registry=None):
    """Map textual, date/time, UUID, network-address and tsvector column types to ``String``."""
    return String


# ``TSVectorType`` is replaced by ``object`` when sqlalchemy_utils cannot be
# imported. Registering ``object`` would replace the generic fallback of
# ``convert_sqlalchemy_type`` (which raises for unknown column types), so only
# register the real class.
if TSVectorType is not object:
    convert_sqlalchemy_type.register(TSVectorType, convert_column_to_string)
189
190
@convert_sqlalchemy_type.register(types.DateTime)
def convert_column_to_datetime(type, column, registry=None):
    """Map ``DateTime`` columns to graphene's ``DateTime`` scalar."""
    # Imported at call time rather than at module import.
    from graphene.types.datetime import DateTime as GrapheneDateTime

    return GrapheneDateTime
195
196
@convert_sqlalchemy_type.register(types.SmallInteger)
@convert_sqlalchemy_type.register(types.Integer)
def convert_column_to_int_or_id(type, column, registry=None):
    """Map integer columns to ``ID`` when they are primary keys, otherwise to ``Int``."""
    if column.primary_key:
        return ID
    return Int
201
202
@convert_sqlalchemy_type.register(types.Boolean)
def convert_column_to_boolean(type, column, registry=None):
    """Map ``Boolean`` columns to graphene's ``Boolean`` scalar."""
    graphene_type = Boolean
    return graphene_type
206
207
@convert_sqlalchemy_type.register(types.Float)
@convert_sqlalchemy_type.register(types.Numeric)
@convert_sqlalchemy_type.register(types.BigInteger)
def convert_column_to_float(type, column, registry=None):
    """
    Map floating-point, numeric and big-integer column types to ``Float``.

    NOTE(review): ``BigInteger`` presumably maps to ``Float`` because GraphQL ``Int``
    is limited to 32 bits. Because this registration is more specific than the
    ``Integer`` one, big-integer primary keys also become ``Float`` rather than
    ``ID`` -- confirm that is intended.
    """
    graphene_type = Float
    return graphene_type
213
214
@convert_sqlalchemy_type.register(types.Enum)
def convert_enum_to_enum(type, column, registry=None):
    """
    Map a SQLAlchemy ``Enum`` column to a graphene enum, lazily.

    Returns a zero-argument callable rather than the enum itself. The enum is built
    by ``enum_for_sa_enum`` only when that callable is invoked, using *registry*
    or, if it is falsy, the global registry.
    """
    def build_enum():
        return enum_for_sa_enum(type, registry or get_global_registry())

    return build_enum
218
219
# TODO Make ChoiceType conversion consistent with other enums
def convert_choice_to_enum(type, column, registry=None):
    """
    Build a graphene ``Enum`` for a sqlalchemy_utils ``ChoiceType`` column.

    The enum is named ``<TABLE>_<COLUMN>`` in upper case. When ``type.choices`` is a
    Python enum class, its members' names and values become the enum's values;
    otherwise ``type.choices`` is passed to ``Enum`` as-is.
    """
    name = "{}_{}".format(column.table.name, column.name).upper()
    if isinstance(type.choices, EnumMeta):
        # type.choices may be Enum/IntEnum, in ChoiceType both presented as EnumMeta
        # do not use from_enum here because we can have more than one enum column in table
        return Enum(name, [(member.name, member.value) for member in type.choices])
    return Enum(name, type.choices)


# ``ChoiceType`` is replaced by ``object`` when sqlalchemy_utils cannot be
# imported. Registering ``object`` would replace the generic fallback of
# ``convert_sqlalchemy_type`` (which raises for unknown column types), so only
# register the real class.
if ChoiceType is not object:
    convert_sqlalchemy_type.register(ChoiceType, convert_choice_to_enum)
230
231
def convert_scalar_list_to_list(type, column, registry=None):
    """Map sqlalchemy_utils ``ScalarListType`` columns to a list of strings."""
    return List(String)


# ``ScalarListType`` is replaced by ``object`` when sqlalchemy_utils cannot be
# imported. Registering ``object`` would replace the generic fallback of
# ``convert_sqlalchemy_type`` (which raises for unknown column types), so only
# register the real class.
if ScalarListType is not object:
    convert_sqlalchemy_type.register(ScalarListType, convert_scalar_list_to_list)
235
236
@convert_sqlalchemy_type.register(types.ARRAY)
@convert_sqlalchemy_type.register(postgresql.ARRAY)
def convert_array_to_list(_type, column, registry=None):
    """
    Convert an ARRAY column type into a graphene ``List`` of its converted item type.

    :param _type: the ARRAY type instance this converter was dispatched on
    :param column: the column owning the type; forwarded to the item conversion
    :param registry: forwarded to the item conversion, so converters that use a
        registry (such as the enum converter) see the same one as the column
    :rtype: List
    """
    # Convert the item type of the ARRAY instance we were dispatched on, and pass the
    # registry along instead of silently falling back to the global one.
    inner_type = convert_sqlalchemy_type(_type.item_type, column, registry)
    return List(inner_type)
242
243
@convert_sqlalchemy_type.register(postgresql.HSTORE)
@convert_sqlalchemy_type.register(postgresql.JSON)
@convert_sqlalchemy_type.register(postgresql.JSONB)
def convert_json_to_string(type, column, registry=None):
    """Expose PostgreSQL ``HSTORE``, ``JSON`` and ``JSONB`` columns as ``JSONString``."""
    graphene_type = JSONString
    return graphene_type
249
250
def convert_json_type_to_string(type, column, registry=None):
    """Expose sqlalchemy_utils ``JSONType`` columns as ``JSONString``."""
    return JSONString


# ``JSONType`` is replaced by ``object`` when sqlalchemy_utils cannot be imported.
# Registering ``object`` would replace the generic fallback of
# ``convert_sqlalchemy_type`` (which raises for unknown column types), silently
# mapping every unsupported type to ``JSONString``; so only register the real class.
if JSONType is not object:
    convert_sqlalchemy_type.register(JSONType, convert_json_type_to_string)
254