1from functools import partial
2
3import six
4from django.db.models.query import QuerySet
5from graphql_relay.connection.arrayconnection import (
6    connection_from_list_slice,
7    cursor_to_offset,
8    get_offset_with_default,
9    offset_to_cursor,
10)
11from promise import Promise
12
13from graphene import Int, NonNull
14from graphene.relay import ConnectionField, PageInfo
15from graphene.types import Field, List
16
17from .settings import graphene_settings
18from .utils import maybe_queryset
19
20
class DjangoListField(Field):
    """Graphene field exposing a list of non-null ``DjangoObjectType`` items.

    Whatever type is passed in (optionally wrapped in ``NonNull``) ends up
    declared as ``List(NonNull(<type>))``.
    """

    def __init__(self, _type, *args, **kwargs):
        """Declare the field as ``List(NonNull(_type))``.

        Args:
            _type: the item type; an outer ``NonNull`` wrapper is removed
                before the type is re-wrapped.
            *args, **kwargs: forwarded unchanged to ``Field.__init__``.

        Raises:
            AssertionError: if the innermost type is not a subclass of
                ``DjangoObjectType``.
        """
        from .types import DjangoObjectType

        item_type = _type.of_type if isinstance(_type, NonNull) else _type

        # Items are always wrapped in NonNull: Django would never hand back
        # a set containing None.
        super(DjangoListField, self).__init__(
            List(NonNull(item_type)), *args, **kwargs
        )

        assert issubclass(
            self._underlying_type, DjangoObjectType
        ), "DjangoListField only accepts DjangoObjectType types"

    @property
    def _underlying_type(self):
        """Innermost type, found by following ``of_type`` until it runs out."""
        innermost = self._type
        while hasattr(innermost, "of_type"):
            innermost = innermost.of_type
        return innermost

    @property
    def model(self):
        """Model class taken from the underlying type's ``_meta``."""
        return self._underlying_type._meta.model

    def get_manager(self):
        """Return the model's ``_default_manager``."""
        return self.model._default_manager

    @staticmethod
    def list_resolver(
        django_object_type, resolver, default_manager, root, info, **args
    ):
        """Resolve the list, falling back to the default manager.

        The wrapped ``resolver`` is called first; when its (normalised)
        result is ``None`` the default manager is used instead. Only results
        that are ``QuerySet`` instances are passed on to
        ``django_object_type.get_queryset``; anything else is returned as is.
        """
        resolved = maybe_queryset(resolver(root, info, **args))
        if resolved is None:
            resolved = maybe_queryset(default_manager)

        if not isinstance(resolved, QuerySet):
            return resolved

        # Give the DjangoObjectType a chance to refine the queryset.
        return maybe_queryset(django_object_type.get_queryset(resolved, info))

    def get_resolver(self, parent_resolver):
        """Bind ``list_resolver`` to this field's item type and manager."""
        list_type = self.type
        if isinstance(list_type, NonNull):
            list_type = list_type.of_type
        # list_type is List(NonNull(<object type>)): unwrap both layers.
        non_null_item = list_type.of_type
        django_object_type = non_null_item.of_type
        return partial(
            self.list_resolver,
            django_object_type,
            parent_resolver,
            self.get_manager(),
        )
71
72
class DjangoConnectionField(ConnectionField):
    """Relay connection field whose nodes are ``DjangoObjectType`` instances.

    Keyword arguments consumed by the constructor (all others are forwarded
    to ``ConnectionField``):

    * ``on`` -- name of an attribute on the model to use instead of
      ``_default_manager`` (see ``get_manager``); defaults to ``False``.
    * ``max_limit`` -- upper bound for ``first``/``last``; defaults to
      ``graphene_settings.RELAY_CONNECTION_MAX_LIMIT``.
    * ``enforce_first_or_last`` -- require ``first`` or ``last`` on every
      query; defaults to
      ``graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST``.

    An ``offset`` argument of type ``Int`` is added unless the caller already
    supplied one.
    """

    def __init__(self, *args, **kwargs):
        """Pop the Django-specific options and register the ``offset`` argument."""
        self.on = kwargs.pop("on", False)
        self.max_limit = kwargs.pop(
            "max_limit", graphene_settings.RELAY_CONNECTION_MAX_LIMIT
        )
        self.enforce_first_or_last = kwargs.pop(
            "enforce_first_or_last",
            graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST,
        )
        kwargs.setdefault("offset", Int())
        super(DjangoConnectionField, self).__init__(*args, **kwargs)

    @property
    def type(self):
        """Return the connection type declared on the node type's ``_meta``.

        The declared type may be wrapped in ``NonNull``; in that case the
        returned connection type is wrapped in ``NonNull`` as well.

        Raises:
            AssertionError: if the declared type is not a ``DjangoObjectType``
                subclass, or if it has no connection.
        """
        from .types import DjangoObjectType

        # The lookup starts *after* ConnectionField in the MRO, so
        # ConnectionField's own ``type`` property is not used here.
        _type = super(ConnectionField, self).type
        non_null = False
        if isinstance(_type, NonNull):
            _type = _type.of_type
            non_null = True
        assert issubclass(
            _type, DjangoObjectType
        ), "DjangoConnectionField only accepts DjangoObjectType types"
        assert _type._meta.connection, "The type {} doesn't have a connection".format(
            _type.__name__
        )
        connection_type = _type._meta.connection
        if non_null:
            return NonNull(connection_type)
        return connection_type

    @property
    def connection_type(self):
        """Return ``self.type`` with any outer ``NonNull`` wrapper removed."""
        # Local deliberately not named ``type`` to avoid shadowing the builtin.
        field_type = self.type
        if isinstance(field_type, NonNull):
            return field_type.of_type
        return field_type

    @property
    def node_type(self):
        """Node type taken from the connection's ``_meta``."""
        return self.connection_type._meta.node

    @property
    def model(self):
        """Model class taken from the node type's ``_meta``."""
        return self.node_type._meta.model

    def get_manager(self):
        """Return the model attribute named by ``on``, else ``_default_manager``."""
        if self.on:
            return getattr(self.model, self.on)
        else:
            return self.model._default_manager

    @classmethod
    def resolve_queryset(cls, connection, queryset, info, args):
        """Pass the resolved iterable through the node type's ``get_queryset``.

        ``args`` is accepted but not used by this implementation.
        """
        # queryset is the resolved iterable from ObjectType
        return connection._meta.node.get_queryset(queryset, info)

    @classmethod
    def resolve_connection(cls, connection, args, iterable, max_limit=None):
        """Build the connection object for ``iterable``.

        Args:
            connection: connection type to instantiate; its ``Edge`` attribute
                is used as the edge type.
            args: pagination arguments. Mutated in place: ``offset`` is
                removed, and ``after`` / ``first`` may be set.
            iterable: resolved data, normalised through ``maybe_queryset``.
                ``QuerySet`` instances are counted with ``.count()``,
                everything else with ``len()``.
            max_limit: optional cap applied when ``first`` is absent.

        Returns:
            The object produced by ``connection_from_list_slice`` with two
            extra attributes: ``iterable`` and ``length``.
        """
        # Remove the offset parameter and convert it to an after cursor.
        offset = args.pop("offset", None)
        after = args.get("after")
        if offset:
            if after:
                # Decode with get_offset_with_default, the same helper used
                # for ``after`` further down: a cursor that cannot be decoded
                # falls back to -1 and therefore contributes 0, instead of
                # failing with a TypeError on ``None + 1``.
                offset += get_offset_with_default(after, -1) + 1
            # input offset starts at 1 while the graphene offset starts at 0
            args["after"] = offset_to_cursor(offset - 1)

        iterable = maybe_queryset(iterable)

        if isinstance(iterable, QuerySet):
            list_length = iterable.count()
        else:
            list_length = len(iterable)
        list_slice_length = (
            min(max_limit, list_length) if max_limit is not None else list_length
        )

        # If after is higher than list_length, connection_from_list_slice
        # would try to do a negative slicing which makes django throw an
        # AssertionError
        after = min(get_offset_with_default(args.get("after"), -1) + 1, list_length)

        if max_limit is not None and "first" not in args:
            if "last" in args:
                # With only ``last`` given, expose the whole list to the
                # slicing helper rather than the max_limit-sized prefix.
                args["first"] = list_length
                list_slice_length = list_length
            else:
                args["first"] = max_limit

        connection = connection_from_list_slice(
            iterable[after:],
            args,
            slice_start=after,
            list_length=list_length,
            list_slice_length=list_slice_length,
            connection_type=connection,
            edge_type=connection.Edge,
            pageinfo_type=PageInfo,
        )
        connection.iterable = iterable
        connection.length = list_length
        return connection

    @classmethod
    def connection_resolver(
        cls,
        resolver,
        connection,
        default_manager,
        queryset_resolver,
        max_limit,
        enforce_first_or_last,
        root,
        info,
        **args
    ):
        """Validate pagination arguments, resolve the data, build the connection.

        The leading parameters up to ``enforce_first_or_last`` are bound by
        ``get_resolver``; ``root``, ``info`` and ``**args`` are supplied per
        query.

        Returns:
            The result of ``resolve_connection``, or a ``Promise`` resolving
            to it when the resolved iterable is thenable.

        Raises:
            AssertionError: when ``enforce_first_or_last`` is set and neither
                ``first`` nor ``last`` is given, when ``first``/``last``
                exceeds ``max_limit``, or when ``offset`` and ``before`` are
                combined.
        """
        first = args.get("first")
        last = args.get("last")
        offset = args.get("offset")
        before = args.get("before")

        if enforce_first_or_last:
            assert first or last, (
                "You must provide a `first` or `last` value to properly paginate the `{}` connection."
            ).format(info.field_name)

        if max_limit:
            if first:
                assert first <= max_limit, (
                    "Requesting {} records on the `{}` connection exceeds the `first` limit of {} records."
                ).format(first, info.field_name, max_limit)
                # Only has an effect when the assert above is stripped
                # (python -O): the value is then clamped instead of rejected.
                args["first"] = min(first, max_limit)

            if last:
                assert last <= max_limit, (
                    "Requesting {} records on the `{}` connection exceeds the `last` limit of {} records."
                ).format(last, info.field_name, max_limit)
                args["last"] = min(last, max_limit)

        if offset is not None:
            assert before is None, (
                "You can't provide a `before` value at the same time as an `offset` value to properly paginate the `{}` connection."
            ).format(info.field_name)

        # eventually leads to DjangoObjectType's get_queryset (accepts queryset)
        # or a resolve_foo (does not accept queryset)
        iterable = resolver(root, info, **args)
        if iterable is None:
            iterable = default_manager
        # thus the iterable gets refiltered by resolve_queryset
        # but iterable might be promise
        iterable = queryset_resolver(connection, iterable, info, args)
        on_resolve = partial(
            cls.resolve_connection, connection, args, max_limit=max_limit
        )

        if Promise.is_thenable(iterable):
            return Promise.resolve(iterable).then(on_resolve)

        return on_resolve(iterable)

    def get_resolver(self, parent_resolver):
        """Bind ``connection_resolver`` to this field's configuration."""
        return partial(
            self.connection_resolver,
            parent_resolver,
            self.connection_type,
            self.get_manager(),
            self.get_queryset_resolver(),
            self.max_limit,
            self.enforce_first_or_last,
        )

    def get_queryset_resolver(self):
        """Return the callable used to refine the resolved iterable."""
        return self.resolve_queryset
250