1from aioredis.util import wait_convert, _NOTSET, _ScanIter
2
3
class SortedSetCommandsMixin:
    """Sorted Sets commands mixin.

    For commands details see: http://redis.io/commands/#sorted_set

    Every command validates its arguments, builds the Redis argument list
    and delegates to ``self.execute``, which must be provided by the class
    this mixin is combined with.
    """

    # ``exclude=`` values for the score-range commands (zcount,
    # zrangebyscore, zrevrangebyscore, zremrangebyscore): they turn the min
    # bound, the max bound, or both into exclusive ``(value`` bounds.
    ZSET_EXCLUDE_MIN = 'ZSET_EXCLUDE_MIN'
    ZSET_EXCLUDE_MAX = 'ZSET_EXCLUDE_MAX'
    ZSET_EXCLUDE_BOTH = 'ZSET_EXCLUDE_BOTH'

    # ``aggregate=`` values for zinterstore / zunionstore.
    ZSET_AGGREGATE_SUM = 'ZSET_AGGREGATE_SUM'
    ZSET_AGGREGATE_MIN = 'ZSET_AGGREGATE_MIN'
    ZSET_AGGREGATE_MAX = 'ZSET_AGGREGATE_MAX'

    # ``exist=`` values for zadd.
    ZSET_IF_NOT_EXIST = 'ZSET_IF_NOT_EXIST'  # NX
    ZSET_IF_EXIST = 'ZSET_IF_EXIST'  # XX

    def bzpopmax(self, key, *keys, timeout=0, encoding=_NOTSET):
        """Remove and get an element with the highest score in the sorted set,
        or block until one is available.

        :param key: first sorted set to pop from; ``keys`` adds more
        :param int timeout: sent as the last command argument; must be >= 0
        :raises TypeError: if timeout is not int
        :raises ValueError: if timeout is less than 0
        """
        if not isinstance(timeout, int):
            raise TypeError("timeout argument must be int")
        if timeout < 0:
            raise ValueError("timeout must be greater equal 0")
        return self.execute(b'BZPOPMAX', key, *keys, timeout,
                            encoding=encoding)

    def bzpopmin(self, key, *keys, timeout=0, encoding=_NOTSET):
        """Remove and get an element with the lowest score in the sorted set,
        or block until one is available.

        :param key: first sorted set to pop from; ``keys`` adds more
        :param int timeout: sent as the last command argument; must be >= 0
        :raises TypeError: if timeout is not int
        :raises ValueError: if timeout is less than 0
        """
        if not isinstance(timeout, int):
            raise TypeError("timeout argument must be int")
        if timeout < 0:
            raise ValueError("timeout must be greater equal 0")
        return self.execute(b'BZPOPMIN', key, *keys, timeout,
                            encoding=encoding)

    def zadd(self, key, score, member, *pairs, exist=None, changed=False,
             incr=False):
        """Add one or more members to a sorted set or update its score.

        :param score: score of ``member``; int or float
        :param member: member to add or update
        :param pairs: extra ``score, member, score, member, ...`` values
        :param exist: ``ZSET_IF_EXIST`` sends ``XX``, ``ZSET_IF_NOT_EXIST``
                      sends ``NX``; any other value sends neither
        :param bool changed: send the ``CH`` flag
        :param bool incr: send the ``INCR`` flag; only one score/member
                          pair may be given in this mode
        :raises TypeError: score not int or float
        :raises TypeError: length of pairs is not even number
        :raises ValueError: if ``incr`` is set and extra pairs are given
        """
        if not isinstance(score, (int, float)):
            raise TypeError("score argument must be int or float")
        if len(pairs) % 2 != 0:
            raise TypeError("length of pairs must be even number")
        # Scores sit at the even positions of ``pairs``.
        if any(not isinstance(s, (int, float)) for s in pairs[::2]):
            raise TypeError("all scores must be int or float")

        args = []
        # ``==`` (not ``is``) so an equal flag string is honoured even when
        # it is not the very same object as the class constant.
        if exist == self.ZSET_IF_EXIST:
            args.append(b'XX')
        elif exist == self.ZSET_IF_NOT_EXIST:
            args.append(b'NX')

        if changed:
            args.append(b'CH')

        if incr:
            if pairs:
                raise ValueError('only one score-element pair '
                                 'can be specified in this mode')
            args.append(b'INCR')

        args.extend([score, member])
        args.extend(pairs)
        return self.execute(b'ZADD', key, *args)

    def zcard(self, key):
        """Get the number of members in a sorted set."""
        return self.execute(b'ZCARD', key)

    def zcount(self, key, min=float('-inf'), max=float('inf'),
               *, exclude=None):
        """Count the members in a sorted set with scores
        within the given values.

        :param exclude: one of the ``ZSET_EXCLUDE_*`` constants to make a
                        bound exclusive, or None
        :raises TypeError: min or max is not float or int
        :raises ValueError: if min greater than max
        """
        if not isinstance(min, (int, float)):
            raise TypeError("min argument must be int or float")
        if not isinstance(max, (int, float)):
            raise TypeError("max argument must be int or float")
        if min > max:
            raise ValueError("min could not be greater than max")
        return self.execute(b'ZCOUNT', key,
                            *_encode_min_max(exclude, min, max))

    def zincrby(self, key, increment, member):
        """Increment the score of a member in a sorted set.

        The reply is converted with ``int_or_float``.

        :raises TypeError: increment is not float or int
        """
        if not isinstance(increment, (int, float)):
            raise TypeError("increment argument must be int or float")
        fut = self.execute(b'ZINCRBY', key, increment, member)
        return wait_convert(fut, int_or_float)

    def zinterstore(self, destkey, key, *keys,
                    with_weights=False, aggregate=None):
        """Intersect multiple sorted sets and store result in a new key.

        :param bool with_weights: when set to true each key must be a tuple
                                  in form of (key, weight)
        :param aggregate: one of the ``ZSET_AGGREGATE_*`` constants, or None
                          to send no ``AGGREGATE`` clause
        :raises AssertionError: if ``with_weights`` is set and a key is not
                                a (key, weight) list or tuple
        """
        return self._zset_store(b'ZINTERSTORE', destkey, (key,) + keys,
                                with_weights, aggregate)

    def zlexcount(self, key, min=b'-', max=b'+', include_min=True,
                  include_max=True):
        """Count the number of members in a sorted set between a given
        lexicographical range.

        ``b'-'`` / ``b'+'`` are sent unchanged; other bounds get a ``[``
        (inclusive) or ``(`` (exclusive) prefix per ``include_min`` /
        ``include_max``.

        :raises TypeError: if min is not bytes
        :raises TypeError: if max is not bytes
        """
        min, max = self._zset_lex_bounds(min, max, include_min, include_max)
        return self.execute(b'ZLEXCOUNT', key, min, max)

    def zrange(self, key, start=0, stop=-1, withscores=False,
               encoding=_NOTSET):
        """Return a range of members in a sorted set, by index.

        With ``withscores`` the reply is converted to a list of
        ``(member, score)`` tuples.

        :raises TypeError: if start is not int
        :raises TypeError: if stop is not int
        """
        if not isinstance(start, int):
            raise TypeError("start argument must be int")
        if not isinstance(stop, int):
            raise TypeError("stop argument must be int")
        args = [b'WITHSCORES'] if withscores else []
        fut = self.execute(b'ZRANGE', key, start, stop, *args,
                           encoding=encoding)
        if withscores:
            return wait_convert(fut, pairs_int_or_float)
        return fut

    def zrangebylex(self, key, min=b'-', max=b'+', include_min=True,
                    include_max=True, offset=None, count=None,
                    encoding=_NOTSET):
        """Return a range of members in a sorted set, by lexicographical range.

        :raises TypeError: if min is not bytes
        :raises TypeError: if max is not bytes
        :raises TypeError: if only one of offset and count is specified
        :raises TypeError: if offset is not int
        :raises TypeError: if count is not int
        """
        min, max = self._zset_lex_bounds(min, max, include_min, include_max)
        args = self._zset_limit_args(offset, count)
        return self.execute(b'ZRANGEBYLEX', key, min, max, *args,
                            encoding=encoding)

    def zrangebyscore(self, key, min=float('-inf'), max=float('inf'),
                      withscores=False, offset=None, count=None,
                      *, exclude=None, encoding=_NOTSET):
        """Return a range of members in a sorted set, by score.

        With ``withscores`` the reply is converted to a list of
        ``(member, score)`` tuples.

        :raises TypeError: if min or max is not float or int
        :raises TypeError: if only one of offset and count is specified
        :raises TypeError: if offset is not int
        :raises TypeError: if count is not int
        """
        if not isinstance(min, (int, float)):
            raise TypeError("min argument must be int or float")
        if not isinstance(max, (int, float)):
            raise TypeError("max argument must be int or float")
        limit = self._zset_limit_args(offset, count)

        min, max = _encode_min_max(exclude, min, max)

        # WITHSCORES goes before the LIMIT clause.
        args = [b'WITHSCORES'] if withscores else []
        args.extend(limit)
        fut = self.execute(b'ZRANGEBYSCORE', key, min, max, *args,
                           encoding=encoding)
        if withscores:
            return wait_convert(fut, pairs_int_or_float)
        return fut

    def zrank(self, key, member):
        """Determine the index of a member in a sorted set."""
        return self.execute(b'ZRANK', key, member)

    def zrem(self, key, member, *members):
        """Remove one or more members from a sorted set."""
        return self.execute(b'ZREM', key, member, *members)

    def zremrangebylex(self, key, min=b'-', max=b'+',
                       include_min=True, include_max=True):
        """Remove all members in a sorted set between the given
        lexicographical range.

        :raises TypeError: if min is not bytes
        :raises TypeError: if max is not bytes
        """
        min, max = self._zset_lex_bounds(min, max, include_min, include_max)
        return self.execute(b'ZREMRANGEBYLEX', key, min, max)

    def zremrangebyrank(self, key, start, stop):
        """Remove all members in a sorted set within the given indexes.

        :raises TypeError: if start is not int
        :raises TypeError: if stop is not int
        """
        if not isinstance(start, int):
            raise TypeError("start argument must be int")
        if not isinstance(stop, int):
            raise TypeError("stop argument must be int")
        return self.execute(b'ZREMRANGEBYRANK', key, start, stop)

    def zremrangebyscore(self, key, min=float('-inf'), max=float('inf'),
                         *, exclude=None):
        """Remove all members in a sorted set within the given scores.

        :param exclude: one of the ``ZSET_EXCLUDE_*`` constants, or None
        :raises TypeError: if min or max is not int or float
        """
        if not isinstance(min, (int, float)):
            raise TypeError("min argument must be int or float")
        if not isinstance(max, (int, float)):
            raise TypeError("max argument must be int or float")

        min, max = _encode_min_max(exclude, min, max)
        return self.execute(b'ZREMRANGEBYSCORE', key, min, max)

    def zrevrange(self, key, start=0, stop=-1, withscores=False,
                  encoding=_NOTSET):
        """Return a range of members in a sorted set, by index,
        with scores ordered from high to low.

        ``start``/``stop`` default to ``0``/``-1`` like :meth:`zrange`.
        With ``withscores`` the reply is converted to a list of
        ``(member, score)`` tuples.

        :raises TypeError: if start or stop is not int
        """
        if not isinstance(start, int):
            raise TypeError("start argument must be int")
        if not isinstance(stop, int):
            raise TypeError("stop argument must be int")
        args = [b'WITHSCORES'] if withscores else []
        fut = self.execute(b'ZREVRANGE', key, start, stop, *args,
                           encoding=encoding)
        if withscores:
            return wait_convert(fut, pairs_int_or_float)
        return fut

    def zrevrangebyscore(self, key, max=float('inf'), min=float('-inf'),
                         *, exclude=None, withscores=False,
                         offset=None, count=None, encoding=_NOTSET):
        """Return a range of members in a sorted set, by score,
        with scores ordered from high to low.

        Note the argument order: ``max`` comes before ``min``, as it is
        sent to Redis.

        :raises TypeError: if min or max is not float or int
        :raises TypeError: if only one of offset and count is specified
        :raises TypeError: if offset is not int
        :raises TypeError: if count is not int
        """
        if not isinstance(min, (int, float)):
            raise TypeError("min argument must be int or float")
        if not isinstance(max, (int, float)):
            raise TypeError("max argument must be int or float")
        limit = self._zset_limit_args(offset, count)

        min, max = _encode_min_max(exclude, min, max)

        args = [b'WITHSCORES'] if withscores else []
        args.extend(limit)
        fut = self.execute(b'ZREVRANGEBYSCORE', key, max, min, *args,
                           encoding=encoding)
        if withscores:
            return wait_convert(fut, pairs_int_or_float)
        return fut

    def zrevrangebylex(self, key, min=b'-', max=b'+', include_min=True,
                       include_max=True, offset=None, count=None,
                       encoding=_NOTSET):
        """Return a range of members in a sorted set, by lexicographical range
        from high to low.

        :raises TypeError: if min is not bytes
        :raises TypeError: if max is not bytes
        :raises TypeError: if only one of offset and count is specified
        :raises TypeError: if offset is not int
        :raises TypeError: if count is not int
        """
        min, max = self._zset_lex_bounds(min, max, include_min, include_max)
        args = self._zset_limit_args(offset, count)
        # The reverse command takes the upper bound first.
        return self.execute(b'ZREVRANGEBYLEX', key, max, min, *args,
                            encoding=encoding)

    def zrevrank(self, key, member):
        """Determine the index of a member in a sorted set, with
        scores ordered from high to low.
        """
        return self.execute(b'ZREVRANK', key, member)

    def zscore(self, key, member):
        """Get the score associated with the given member in a sorted set.

        The reply is converted with ``optional_int_or_float``, so a None
        reply stays None.
        """
        fut = self.execute(b'ZSCORE', key, member)
        return wait_convert(fut, optional_int_or_float)

    def zunionstore(self, destkey, key, *keys,
                    with_weights=False, aggregate=None):
        """Add multiple sorted sets and store result in a new key.

        :param bool with_weights: when set to true each key must be a tuple
                                  in form of (key, weight)
        :param aggregate: one of the ``ZSET_AGGREGATE_*`` constants, or None
                          to send no ``AGGREGATE`` clause
        :raises AssertionError: if ``with_weights`` is set and a key is not
                                a (key, weight) list or tuple
        """
        return self._zset_store(b'ZUNIONSTORE', destkey, (key,) + keys,
                                with_weights, aggregate)

    def zscan(self, key, cursor=0, match=None, count=None):
        """Incrementally iterate sorted sets elements and associated scores.

        :returns: ``(next_cursor, [(member, score), ...])`` where the cursor
                  is converted to int
        """
        args = []
        if match is not None:
            args += [b'MATCH', match]
        if count is not None:
            args += [b'COUNT', count]
        fut = self.execute(b'ZSCAN', key, cursor, *args)

        def _converter(obj):
            return (int(obj[0]), pairs_int_or_float(obj[1]))

        return wait_convert(fut, _converter)

    def izscan(self, key, *, match=None, count=None):
        """Incrementally iterate sorted set items using async for.

        Usage example:

        >>> async for val, score in redis.izscan(key, match='something*'):
        ...     print('Matched:', val, ':', score)

        """
        return _ScanIter(lambda cur: self.zscan(key, cur,
                                                match=match,
                                                count=count))

    def zpopmin(self, key, count=None, *, encoding=_NOTSET):
        """Removes and returns up to count members with the lowest scores
        in the sorted set stored at key.

        The reply is returned unconverted.

        :raises TypeError: if count is not int
        """
        if count is not None and not isinstance(count, int):
            raise TypeError("count argument must be int")
        args = [] if count is None else [count]
        return self.execute(b'ZPOPMIN', key, *args, encoding=encoding)

    def zpopmax(self, key, count=None, *, encoding=_NOTSET):
        """Removes and returns up to count members with the highest scores
        in the sorted set stored at key.

        The reply is returned unconverted.

        :raises TypeError: if count is not int
        """
        if count is not None and not isinstance(count, int):
            raise TypeError("count argument must be int")
        args = [] if count is None else [count]
        return self.execute(b'ZPOPMAX', key, *args, encoding=encoding)

    # -- private helpers -------------------------------------------------

    def _zset_store(self, command, destkey, keys, with_weights, aggregate):
        """Build and execute a ZINTERSTORE / ZUNIONSTORE command.

        :param command: ``b'ZINTERSTORE'`` or ``b'ZUNIONSTORE'``
        :param keys: plain keys, or ``(key, weight)`` pairs when
                     ``with_weights`` is true
        :raises AssertionError: if ``with_weights`` is set and a key is not
                                a list or tuple
        """
        args = []
        if with_weights:
            # Raised explicitly instead of via ``assert`` so the check still
            # runs under ``python -O`` (otherwise a 2-character key would be
            # silently unpacked as a (key, weight) pair). AssertionError is
            # kept for compatibility with the former ``assert`` statement.
            if not all(isinstance(val, (list, tuple)) for val in keys):
                raise AssertionError(
                    "All key arguments must be (key, weight) tuples")
            weights = ['WEIGHTS']
            for name, weight in keys:
                args.append(name)
                weights.append(weight)
            args.extend(weights)
        else:
            args.extend(keys)

        if aggregate == self.ZSET_AGGREGATE_SUM:
            args.extend(('AGGREGATE', 'SUM'))
        elif aggregate == self.ZSET_AGGREGATE_MAX:
            args.extend(('AGGREGATE', 'MAX'))
        elif aggregate == self.ZSET_AGGREGATE_MIN:
            args.extend(('AGGREGATE', 'MIN'))
        return self.execute(command, destkey, len(keys), *args)

    @staticmethod
    def _zset_lex_bounds(lo, hi, include_min, include_max):
        """Validate and encode a lexicographical ``(min, max)`` pair.

        ``b'-'`` as min and ``b'+'`` as max are returned unchanged; any
        other bound is prefixed with ``[`` (inclusive) or ``(`` (exclusive).

        :raises TypeError: if min or max is not bytes
        """
        # FIXME: only bytes bounds are accepted; str is rejected.
        if not isinstance(lo, bytes):
            raise TypeError("min argument must be bytes")
        if not isinstance(hi, bytes):
            raise TypeError("max argument must be bytes")
        if lo != b'-':
            lo = (b'[' if include_min else b'(') + lo
        if hi != b'+':
            hi = (b'[' if include_max else b'(') + hi
        return lo, hi

    @staticmethod
    def _zset_limit_args(offset, count):
        """Validate ``offset``/``count`` and build the LIMIT arguments.

        :returns: ``[b'LIMIT', offset, count]``, or ``[]`` when both are None
        :raises TypeError: if only one of offset and count is given
        :raises TypeError: if offset or count is not int
        """
        if (offset is None) != (count is None):
            raise TypeError("offset and count must both be specified")
        if offset is None:
            return []
        if not isinstance(offset, int):
            raise TypeError("offset argument must be int")
        if not isinstance(count, int):
            raise TypeError("count argument must be int")
        return [b'LIMIT', offset, count]
496
497
498def _encode_min_max(flag, min, max):
499    if flag is SortedSetCommandsMixin.ZSET_EXCLUDE_MIN:
500        return '({}'.format(min), max
501    elif flag is SortedSetCommandsMixin.ZSET_EXCLUDE_MAX:
502        return min, '({}'.format(max)
503    elif flag is SortedSetCommandsMixin.ZSET_EXCLUDE_BOTH:
504        return '({}'.format(min), '({}'.format(max)
505    return min, max
506
507
def int_or_float(value):
    """Convert a raw numeric reply to ``int``, falling back to ``float``.

    :param value: ``str`` or ``bytes`` reply such as ``b'3'`` or ``b'1.5'``
    :returns: ``int`` when ``value`` parses as an integer, else ``float``
    :raises ValueError: if ``value`` parses as neither
    """
    # Internal sanity check on reply data; the previous message named a
    # non-existent ``raw_value`` and claimed only bytes were accepted.
    assert isinstance(value, (str, bytes)), 'value must be str or bytes'
    try:
        return int(value)
    except ValueError:
        return float(value)
514
515
def optional_int_or_float(value):
    """Like ``int_or_float``, but pass ``None`` through unchanged."""
    return None if value is None else int_or_float(value)
520
521
def pairs_int_or_float(value):
    """Group a flat ``[member, score, member, score, ...]`` sequence into
    ``(member, score)`` tuples, converting each score with ``int_or_float``.

    A trailing member without a score is dropped.
    """
    grouped = []
    stream = iter(value)
    for member in stream:
        try:
            raw_score = next(stream)
        except StopIteration:
            break
        grouped.append((member, int_or_float(raw_score)))
    return grouped
526