from functools import partial

import six
from django.db.models.query import QuerySet
from graphql_relay.connection.arrayconnection import (
    connection_from_list_slice,
    cursor_to_offset,
    get_offset_with_default,
    offset_to_cursor,
)
from promise import Promise

from graphene import Int, NonNull
from graphene.relay import ConnectionField, PageInfo
from graphene.types import Field, List

from .settings import graphene_settings
from .utils import maybe_queryset


class DjangoListField(Field):
    """Field resolving to a list of a ``DjangoObjectType``.

    The declared type is always wrapped as ``List(NonNull(_type))``; a
    ``NonNull`` wrapper passed in by the caller is stripped first so it is
    not doubled.
    """

    def __init__(self, _type, *args, **kwargs):
        # Imported locally, presumably to avoid a circular import with .types.
        from .types import DjangoObjectType

        if isinstance(_type, NonNull):
            _type = _type.of_type

        # Django querysets never yield None items, so mark the items NonNull.
        super(DjangoListField, self).__init__(List(NonNull(_type)), *args, **kwargs)

        assert issubclass(
            self._underlying_type, DjangoObjectType
        ), "DjangoListField only accepts DjangoObjectType types"

    @property
    def _underlying_type(self):
        """Return the innermost type, unwrapping every ``of_type`` layer."""
        _type = self._type
        while hasattr(_type, "of_type"):
            _type = _type.of_type
        return _type

    @property
    def model(self):
        """The model declared on the underlying ``DjangoObjectType``'s Meta."""
        return self._underlying_type._meta.model

    def get_manager(self):
        """Return the model's ``_default_manager`` (fallback data source)."""
        return self.model._default_manager

    @staticmethod
    def list_resolver(
        django_object_type, resolver, default_manager, root, info, **args
    ):
        """Resolve the list value for this field.

        Calls ``resolver``; when it yields None, ``default_manager`` is used
        instead. Both are passed through ``maybe_queryset``. If the result is a
        ``QuerySet`` it is further filtered by
        ``django_object_type.get_queryset(queryset, info)``.
        """
        queryset = maybe_queryset(resolver(root, info, **args))
        if queryset is None:
            queryset = maybe_queryset(default_manager)

        if isinstance(queryset, QuerySet):
            # Pass queryset to the DjangoObjectType get_queryset method
            queryset = maybe_queryset(django_object_type.get_queryset(queryset, info))

        return queryset

    def get_resolver(self, parent_resolver):
        """Bind ``list_resolver`` to the object type, parent resolver and manager."""
        _type = self.type
        if isinstance(_type, NonNull):
            _type = _type.of_type
        # _type is now List(NonNull(X)); unwrap both layers to reach X.
        django_object_type = _type.of_type.of_type
        return partial(
            self.list_resolver, django_object_type, parent_resolver, self.get_manager(),
        )


class DjangoConnectionField(ConnectionField):
    """Relay connection field whose node is a ``DjangoObjectType``.

    Extra keyword arguments, removed before ``ConnectionField.__init__``:

    * ``on`` -- name of a model attribute (manager) to use instead of
      ``_default_manager``; defaults to False (use the default manager).
    * ``max_limit`` -- upper bound for ``first``/``last``; defaults to
      ``graphene_settings.RELAY_CONNECTION_MAX_LIMIT``.
    * ``enforce_first_or_last`` -- require ``first`` or ``last``; defaults
      to ``graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST``.

    An ``offset`` Int argument is added unless the caller supplies one.
    """

    def __init__(self, *args, **kwargs):
        self.on = kwargs.pop("on", False)
        self.max_limit = kwargs.pop(
            "max_limit", graphene_settings.RELAY_CONNECTION_MAX_LIMIT
        )
        self.enforce_first_or_last = kwargs.pop(
            "enforce_first_or_last",
            graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST,
        )
        kwargs.setdefault("offset", Int())
        super(DjangoConnectionField, self).__init__(*args, **kwargs)

    @property
    def type(self):
        """Return ``_meta.connection`` of the declared ``DjangoObjectType``.

        A ``NonNull`` around the declared node type is carried over to the
        returned connection type.
        """
        from .types import DjangoObjectType

        # Deliberately skip ConnectionField.type and read the raw declared
        # type: here it is the node type, mapped to its connection below.
        _type = super(ConnectionField, self).type
        non_null = False
        if isinstance(_type, NonNull):
            _type = _type.of_type
            non_null = True
        assert issubclass(
            _type, DjangoObjectType
        ), "DjangoConnectionField only accepts DjangoObjectType types"
        assert _type._meta.connection, "The type {} doesn't have a connection".format(
            _type.__name__
        )
        connection_type = _type._meta.connection
        if non_null:
            return NonNull(connection_type)
        return connection_type

    @property
    def connection_type(self):
        """The connection type with any ``NonNull`` wrapper removed."""
        field_type = self.type
        if isinstance(field_type, NonNull):
            return field_type.of_type
        return field_type

    @property
    def node_type(self):
        """The node type declared on the connection's Meta."""
        return self.connection_type._meta.node

    @property
    def model(self):
        """The model declared on the node type's Meta."""
        return self.node_type._meta.model

    def get_manager(self):
        """Return the manager named by ``on``, else the model's default manager."""
        if self.on:
            return getattr(self.model, self.on)
        else:
            return self.model._default_manager

    @classmethod
    def resolve_queryset(cls, connection, queryset, info, args):
        """Filter ``queryset`` through the node type's ``get_queryset``."""
        # queryset is the resolved iterable from ObjectType
        return connection._meta.node.get_queryset(queryset, info)

    @classmethod
    def resolve_connection(cls, connection, args, iterable, max_limit=None):
        """Slice ``iterable`` into an instance of ``connection``.

        ``args`` is mutated in place: ``offset`` is removed and folded into
        the ``after`` cursor, and ``first`` may be filled in so that
        ``max_limit`` is honoured. Sets ``iterable`` and ``length`` on the
        returned connection.
        """
        # Remove the offset parameter and convert it to an after cursor.
        offset = args.pop("offset", None)
        if offset:
            # A missing or unparsable ``after`` cursor is treated as -1
            # ("before the first item"), the same default applied to
            # ``after`` further down, instead of failing on ``None + 1``.
            offset += get_offset_with_default(args.get("after"), -1) + 1
            # Skipping ``offset`` items means the last skipped index is
            # ``offset - 1``; that index becomes the ``after`` cursor.
            args["after"] = offset_to_cursor(offset - 1)

        iterable = maybe_queryset(iterable)

        if isinstance(iterable, QuerySet):
            list_length = iterable.count()
        else:
            list_length = len(iterable)
        list_slice_length = (
            min(max_limit, list_length) if max_limit is not None else list_length
        )

        # If after is higher than list_length, connection_from_list_slice
        # would try to do a negative slicing which makes django throw an
        # AssertionError
        after = min(get_offset_with_default(args.get("after"), -1) + 1, list_length)

        # An explicit null (e.g. passed through a variable) must count as
        # "not provided". Otherwise ``last: null`` would set ``first`` to the
        # full length and bypass ``max_limit``, and ``first: null`` combined
        # with ``last`` would slice the wrong end of the list.
        if max_limit is not None and args.get("first") is None:
            if args.get("last") is not None:
                # Let ``last`` pick from the whole list; it was already
                # bounded by max_limit in connection_resolver.
                args["first"] = list_length
                list_slice_length = list_length
            else:
                args["first"] = max_limit

        connection = connection_from_list_slice(
            iterable[after:],
            args,
            slice_start=after,
            list_length=list_length,
            list_slice_length=list_slice_length,
            connection_type=connection,
            edge_type=connection.Edge,
            pageinfo_type=PageInfo,
        )
        connection.iterable = iterable
        connection.length = list_length
        return connection

    @classmethod
    def connection_resolver(
        cls,
        resolver,
        connection,
        default_manager,
        queryset_resolver,
        max_limit,
        enforce_first_or_last,
        root,
        info,
        **args
    ):
        """Validate pagination arguments, resolve the data and paginate it.

        Raises AssertionError when ``first``/``last`` are required but
        missing, when either exceeds ``max_limit``, or when ``offset`` is
        combined with ``before``. Returns a Promise when the resolved
        iterable is thenable, otherwise the connection instance.
        """
        first = args.get("first")
        last = args.get("last")
        offset = args.get("offset")
        before = args.get("before")

        if enforce_first_or_last:
            assert first or last, (
                "You must provide a `first` or `last` value to properly paginate the `{}` connection."
            ).format(info.field_name)

        if max_limit:
            if first:
                assert first <= max_limit, (
                    "Requesting {} records on the `{}` connection exceeds the `first` limit of {} records."
                ).format(first, info.field_name, max_limit)
                # Still clamp explicitly: asserts are stripped under -O.
                args["first"] = min(first, max_limit)

            if last:
                assert last <= max_limit, (
                    "Requesting {} records on the `{}` connection exceeds the `last` limit of {} records."
                ).format(last, info.field_name, max_limit)
                args["last"] = min(last, max_limit)

        if offset is not None:
            assert before is None, (
                "You can't provide a `before` value at the same time as an `offset` value to properly paginate the `{}` connection."
            ).format(info.field_name)

        # eventually leads to DjangoObjectType's get_queryset (accepts queryset)
        # or a resolve_foo (does not accept queryset)
        iterable = resolver(root, info, **args)
        if iterable is None:
            iterable = default_manager
        # thus the iterable gets refiltered by resolve_queryset
        # but iterable might be promise
        iterable = queryset_resolver(connection, iterable, info, args)
        on_resolve = partial(
            cls.resolve_connection, connection, args, max_limit=max_limit
        )

        if Promise.is_thenable(iterable):
            return Promise.resolve(iterable).then(on_resolve)

        return on_resolve(iterable)

    def get_resolver(self, parent_resolver):
        """Bind ``connection_resolver`` to this field's configuration."""
        return partial(
            self.connection_resolver,
            parent_resolver,
            self.connection_type,
            self.get_manager(),
            self.get_queryset_resolver(),
            self.max_limit,
            self.enforce_first_or_last,
        )

    def get_queryset_resolver(self):
        """Return the callable used to refilter the resolved iterable."""
        return self.resolve_queryset