1import re
2from collections import OrderedDict
3
4try:
5    from collections.abc import Iterable
6except ImportError:
7    from collections import Iterable
8
9from functools import partial
10
11from graphql_relay import connection_from_list
12
13from ..types import Boolean, Enum, Int, Interface, List, NonNull, Scalar, String, Union
14from ..types.field import Field
15from ..types.objecttype import ObjectType, ObjectTypeOptions
16from ..utils.thenables import maybe_thenable
17from .node import is_node
18
19
class PageInfo(ObjectType):
    """Relay ``PageInfo`` object type.

    Tells a client whether more items exist in either pagination direction,
    and which cursors to continue from.
    """

    class Meta:
        description = (
            "The Relay compliant `PageInfo` type, "
            "containing data necessary to paginate this connection."
        )

    # Non-null flag: more items exist after this page (forward pagination).
    has_next_page = Boolean(
        name="hasNextPage",
        description="When paginating forwards, are there more items?",
        required=True,
    )

    # Non-null flag: more items exist before this page (backward pagination).
    has_previous_page = Boolean(
        name="hasPreviousPage",
        description="When paginating backwards, are there more items?",
        required=True,
    )

    # Nullable cursor to continue from when paginating backwards.
    start_cursor = String(
        description="When paginating backwards, the cursor to continue.",
        name="startCursor",
    )

    # Nullable cursor to continue from when paginating forwards.
    end_cursor = String(
        description="When paginating forwards, the cursor to continue.",
        name="endCursor",
    )
48
49
class ConnectionOptions(ObjectTypeOptions):
    """``ObjectTypeOptions`` extended with the node type a connection wraps."""

    # The paginated type; filled in by Connection.__init_subclass_with_meta__.
    node = None
52
53
class Connection(ObjectType):
    """Abstract base class for Relay connection object types.

    A concrete subclass must name the type it paginates via ``Meta.node``.
    When the subclass is created, ``__init_subclass_with_meta__`` builds a
    matching ``<BaseName>Edge`` object type, stores it on the subclass as
    ``Edge``, and gives the connection ``page_info`` and ``edges`` fields.
    """

    class Meta:
        abstract = True

    @classmethod
    def __init_subclass_with_meta__(cls, node=None, name=None, **options):
        """Configure a concrete connection subclass.

        Args:
            node: The type being paginated: either a ``NonNull`` wrapper or a
                subclass of ``Scalar``, ``Enum``, ``ObjectType``,
                ``Interface``, ``Union`` or ``NonNull``.
            name: Optional GraphQL name for the connection. When omitted it
                becomes ``"<BaseName>Connection"``.
            **options: Remaining ``Meta`` options, forwarded to the parent
                ``__init_subclass_with_meta__`` together with the final name.

        Raises:
            AssertionError: If ``node`` is missing or not an accepted type.
        """
        _meta = ConnectionOptions(cls)
        assert node, "You have to provide a node in {}.Meta".format(cls.__name__)
        # Check for a class before calling issubclass(): a non-class value
        # such as ``List(SomeType)`` would otherwise make issubclass() raise a
        # bare TypeError instead of failing this assertion with its message.
        assert isinstance(node, NonNull) or (
            isinstance(node, type)
            and issubclass(
                node, (Scalar, Enum, ObjectType, Interface, Union, NonNull)
            )
        ), ('Received incompatible node "{}" for Connection {}.').format(
            node, cls.__name__
        )

        # Base name: the explicit or class name minus a trailing "Connection",
        # falling back to the node type's own name when nothing is left
        # (e.g. a subclass literally named ``Connection``).
        base_name = re.sub("Connection$", "", name or cls.__name__)
        if not base_name:
            # Wrapper instances (NonNull / List) have no ``_meta``; take the
            # name from the innermost wrapped type instead.
            named_node = node
            while isinstance(named_node, (NonNull, List)):
                named_node = named_node.of_type
            base_name = named_node._meta.name
        if not name:
            name = "{}Connection".format(base_name)

        # An ``Edge`` class visible on the subclass (declared on it or
        # inherited) is mixed into the generated edge type as a base.
        edge_class = getattr(cls, "Edge", None)
        # Alias so the class body below can reference the node type while it
        # defines its own attribute called ``node``.
        _node = node

        class EdgeBase(object):
            node = Field(_node, description="The item at the end of the edge")
            cursor = String(required=True, description="A cursor for use in pagination")

        class EdgeMeta:
            description = "A Relay edge containing a `{}` and its cursor.".format(
                base_name
            )

        edge_name = "{}Edge".format(base_name)
        if edge_class:
            edge_bases = (edge_class, EdgeBase, ObjectType)
        else:
            edge_bases = (EdgeBase, ObjectType)

        edge = type(edge_name, edge_bases, {"Meta": EdgeMeta})
        cls.Edge = edge

        options["name"] = name
        _meta.node = node
        _meta.fields = OrderedDict(
            [
                (
                    "page_info",
                    Field(
                        PageInfo,
                        name="pageInfo",
                        required=True,
                        description="Pagination data for this connection.",
                    ),
                ),
                (
                    "edges",
                    Field(
                        NonNull(List(edge)),
                        description="Contains the nodes in this connection.",
                    ),
                ),
            ]
        )
        return super(Connection, cls).__init_subclass_with_meta__(
            _meta=_meta, **options
        )
118
119
class IterableConnectionField(Field):
    """Field whose value is a Relay ``Connection`` built from a resolved iterable.

    The field's type must be a ``Connection`` subclass, optionally wrapped in
    ``NonNull``. The ``before``/``after``/``first``/``last`` pagination
    arguments are added unless the caller already passed them. The resolver
    is wrapped so that an iterable result is paginated into an instance of
    that connection type.
    """

    def __init__(self, type, *args, **kwargs):
        # ``type`` shadows the builtin, but it is part of the public signature
        # (callers may pass it by keyword), so the name is kept.
        kwargs.setdefault("before", String())
        kwargs.setdefault("after", String())
        kwargs.setdefault("first", Int())
        kwargs.setdefault("last", Int())
        super(IterableConnectionField, self).__init__(type, *args, **kwargs)

    @property
    def type(self):
        """Return the declared field type after validating it.

        Raises:
            Exception: If ``is_node`` reports that the (NonNull-unwrapped)
                type is a Node; such fields need an explicit Connection type.
            AssertionError: If the unwrapped type is not a ``Connection``
                subclass.
        """
        field_type = super(IterableConnectionField, self).type
        # Validate the inner type, but return the original (possibly NonNull)
        # one so the field keeps its nullability.
        connection_type = field_type
        if isinstance(field_type, NonNull):
            connection_type = field_type.of_type

        if is_node(connection_type):
            raise Exception(
                "ConnectionFields now need a explicit ConnectionType for Nodes.\n"
                "Read more: https://github.com/graphql-python/graphene/blob/v2.0.0/UPGRADE-v2.0.md#node-connections"
            )

        # ``type`` here is the builtin: the property name lives in the class
        # namespace, not in this function's scope. Checking for a class first
        # keeps non-class values on the assertion path instead of letting
        # issubclass() raise a bare TypeError.
        assert isinstance(connection_type, type) and issubclass(
            connection_type, Connection
        ), (
            '{} type have to be a subclass of Connection. Received "{}".'
        ).format(self.__class__.__name__, connection_type)
        return field_type

    @classmethod
    def resolve_connection(cls, connection_type, args, resolved):
        """Turn a resolver's result into an instance of ``connection_type``.

        Args:
            connection_type: The (NonNull-unwrapped) ``Connection`` subclass.
            args: The field's arguments, passed on to ``connection_from_list``.
            resolved: The value produced by the field's resolver.

        Returns:
            ``resolved`` itself if it already is a ``connection_type``.
            Otherwise, a connection built by ``connection_from_list`` whose
            ``iterable`` attribute holds the data that was paginated.

        Raises:
            AssertionError: If ``resolved`` is neither a connection instance
                nor iterable.
        """
        if isinstance(resolved, connection_type):
            return resolved

        assert isinstance(resolved, Iterable), (
            "Resolved value from the connection field have to be iterable or instance of {}. "
            'Received "{}"'
        ).format(connection_type, resolved)

        # connection_from_list() takes len() of its data and slices it. Bare
        # iterables (generators, sets, ...) pass the check above but would
        # crash there with TypeError, so materialize them into a list. Values
        # that already support len() and indexing are passed through as-is.
        if not (hasattr(resolved, "__len__") and hasattr(resolved, "__getitem__")):
            resolved = list(resolved)

        connection = connection_from_list(
            resolved,
            args,
            connection_type=connection_type,
            edge_type=connection_type.Edge,
            pageinfo_type=PageInfo,
        )
        connection.iterable = resolved
        return connection

    @classmethod
    def connection_resolver(cls, resolver, connection_type, root, info, **args):
        """Call ``resolver`` and convert its result into a connection.

        ``connection_type`` may be wrapped in ``NonNull``; it is unwrapped
        before conversion. The conversion goes through ``maybe_thenable``.
        Presumably that handles both plain values and promise-like results;
        confirm in ``..utils.thenables``.
        """
        resolved = resolver(root, info, **args)

        if isinstance(connection_type, NonNull):
            connection_type = connection_type.of_type

        on_resolve = partial(cls.resolve_connection, connection_type, args)
        return maybe_thenable(resolved, on_resolve)

    def get_resolver(self, parent_resolver):
        """Wrap the base resolver so its result is paginated into a connection."""
        resolver = super(IterableConnectionField, self).get_resolver(parent_resolver)
        return partial(self.connection_resolver, resolver, self.type)
178
179
# Shorter alias: ``ConnectionField`` is the same class as IterableConnectionField.
ConnectionField = IterableConnectionField
181