1 /*
2  * Copyright (c) 2009-2012, Salvatore Sanfilippo <antirez at gmail dot com>
3  * Copyright (c) 2009-2012, Pieter Noordhuis <pcnoordhuis at gmail dot com>
4  * All rights reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions are met:
8  *
9  *   * Redistributions of source code must retain the above copyright notice,
10  *     this list of conditions and the following disclaimer.
11  *   * Redistributions in binary form must reproduce the above copyright
12  *     notice, this list of conditions and the following disclaimer in the
13  *     documentation and/or other materials provided with the distribution.
14  *   * Neither the name of Redis nor the names of its contributors may be used
15  *     to endorse or promote products derived from this software without
16  *     specific prior written permission.
17  *
18  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
21  * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
22  * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
23  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
24  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
25  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
26  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
27  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
28  * POSSIBILITY OF SUCH DAMAGE.
29  */
30 
31 /*-----------------------------------------------------------------------------
32  * Sorted set API
33  *----------------------------------------------------------------------------*/
34 
35 /* ZSETs are ordered sets using two data structures to hold the same elements
36  * in order to get O(log(N)) INSERT and REMOVE operations into a sorted
37  * data structure.
38  *
39  * The elements are added to a hash table mapping Redis objects to scores.
40  * At the same time the elements are added to a skip list mapping scores
41  * to Redis objects (so objects are sorted by scores in this "view").
42  *
43  * Note that the SDS string representing the element is the same in both
44  * the hash table and skiplist in order to save memory. What we do in order
45  * to manage the shared SDS string more easily is to free the SDS string
46  * only in zslFreeNode(). The dictionary has no value free method set.
47  * So we should always remove an element from the dictionary, and later from
48  * the skiplist.
49  *
50  * This skiplist implementation is almost a C translation of the original
51  * algorithm described by William Pugh in "Skip Lists: A Probabilistic
52  * Alternative to Balanced Trees", modified in three ways:
53  * a) this implementation allows for repeated scores.
54  * b) the comparison is not just by key (our 'score') but by satellite data.
55  * c) there is a back pointer, so it's a doubly linked list with the back
56  * pointers being only at "level 1". This allows to traverse the list
57  * from tail to head, useful for ZREVRANGE. */
58 
59 #include "server.h"
60 #include <math.h>
61 
62 /*-----------------------------------------------------------------------------
63  * Skiplist implementation of the low level API
64  *----------------------------------------------------------------------------*/
65 
/* Forward declarations: zslDeleteRangeByLex() calls these two helpers
 * before their definitions appear further down in this file. */
int zslLexValueGteMin(sds value, zlexrangespec *spec);
int zslLexValueLteMax(sds value, zlexrangespec *spec);
68 
69 /* Create a skiplist node with the specified number of levels.
70  * The SDS string 'ele' is referenced by the node after the call. */
zslCreateNode(int level,double score,sds ele)71 zskiplistNode *zslCreateNode(int level, double score, sds ele) {
72     zskiplistNode *zn =
73         zmalloc(sizeof(*zn)+level*sizeof(struct zskiplistLevel));
74     zn->score = score;
75     zn->ele = ele;
76     return zn;
77 }
78 
79 /* Create a new skiplist. */
zslCreate(void)80 zskiplist *zslCreate(void) {
81     int j;
82     zskiplist *zsl;
83 
84     zsl = zmalloc(sizeof(*zsl));
85     zsl->level = 1;
86     zsl->length = 0;
87     zsl->header = zslCreateNode(ZSKIPLIST_MAXLEVEL,0,NULL);
88     for (j = 0; j < ZSKIPLIST_MAXLEVEL; j++) {
89         zsl->header->level[j].forward = NULL;
90         zsl->header->level[j].span = 0;
91     }
92     zsl->header->backward = NULL;
93     zsl->tail = NULL;
94     return zsl;
95 }
96 
97 /* Free the specified skiplist node. The referenced SDS string representation
98  * of the element is freed too, unless node->ele is set to NULL before calling
99  * this function. */
zslFreeNode(zskiplistNode * node)100 void zslFreeNode(zskiplistNode *node) {
101     sdsfree(node->ele);
102     zfree(node);
103 }
104 
105 /* Free a whole skiplist. */
zslFree(zskiplist * zsl)106 void zslFree(zskiplist *zsl) {
107     zskiplistNode *node = zsl->header->level[0].forward, *next;
108 
109     zfree(zsl->header);
110     while(node) {
111         next = node->level[0].forward;
112         zslFreeNode(node);
113         node = next;
114     }
115     zfree(zsl);
116 }
117 
118 /* Returns a random level for the new skiplist node we are going to create.
119  * The return value of this function is between 1 and ZSKIPLIST_MAXLEVEL
120  * (both inclusive), with a powerlaw-alike distribution where higher
121  * levels are less likely to be returned. */
zslRandomLevel(void)122 int zslRandomLevel(void) {
123     int level = 1;
124     while ((random()&0xFFFF) < (ZSKIPLIST_P * 0xFFFF))
125         level += 1;
126     return (level<ZSKIPLIST_MAXLEVEL) ? level : ZSKIPLIST_MAXLEVEL;
127 }
128 
129 /* Insert a new node in the skiplist. Assumes the element does not already
130  * exist (up to the caller to enforce that). The skiplist takes ownership
131  * of the passed SDS string 'ele'. */
zslInsert(zskiplist * zsl,double score,sds ele)132 zskiplistNode *zslInsert(zskiplist *zsl, double score, sds ele) {
133     zskiplistNode *update[ZSKIPLIST_MAXLEVEL], *x;
134     unsigned long rank[ZSKIPLIST_MAXLEVEL];
135     int i, level;
136 
137     serverAssert(!isnan(score));
138     x = zsl->header;
139     for (i = zsl->level-1; i >= 0; i--) {
140         /* store rank that is crossed to reach the insert position */
141         rank[i] = i == (zsl->level-1) ? 0 : rank[i+1];
142         while (x->level[i].forward &&
143                 (x->level[i].forward->score < score ||
144                     (x->level[i].forward->score == score &&
145                     sdscmp(x->level[i].forward->ele,ele) < 0)))
146         {
147             rank[i] += x->level[i].span;
148             x = x->level[i].forward;
149         }
150         update[i] = x;
151     }
152     /* we assume the element is not already inside, since we allow duplicated
153      * scores, reinserting the same element should never happen since the
154      * caller of zslInsert() should test in the hash table if the element is
155      * already inside or not. */
156     level = zslRandomLevel();
157     if (level > zsl->level) {
158         for (i = zsl->level; i < level; i++) {
159             rank[i] = 0;
160             update[i] = zsl->header;
161             update[i]->level[i].span = zsl->length;
162         }
163         zsl->level = level;
164     }
165     x = zslCreateNode(level,score,ele);
166     for (i = 0; i < level; i++) {
167         x->level[i].forward = update[i]->level[i].forward;
168         update[i]->level[i].forward = x;
169 
170         /* update span covered by update[i] as x is inserted here */
171         x->level[i].span = update[i]->level[i].span - (rank[0] - rank[i]);
172         update[i]->level[i].span = (rank[0] - rank[i]) + 1;
173     }
174 
175     /* increment span for untouched levels */
176     for (i = level; i < zsl->level; i++) {
177         update[i]->level[i].span++;
178     }
179 
180     x->backward = (update[0] == zsl->header) ? NULL : update[0];
181     if (x->level[0].forward)
182         x->level[0].forward->backward = x;
183     else
184         zsl->tail = x;
185     zsl->length++;
186     return x;
187 }
188 
189 /* Internal function used by zslDelete, zslDeleteRangeByScore and
190  * zslDeleteRangeByRank. */
zslDeleteNode(zskiplist * zsl,zskiplistNode * x,zskiplistNode ** update)191 void zslDeleteNode(zskiplist *zsl, zskiplistNode *x, zskiplistNode **update) {
192     int i;
193     for (i = 0; i < zsl->level; i++) {
194         if (update[i]->level[i].forward == x) {
195             update[i]->level[i].span += x->level[i].span - 1;
196             update[i]->level[i].forward = x->level[i].forward;
197         } else {
198             update[i]->level[i].span -= 1;
199         }
200     }
201     if (x->level[0].forward) {
202         x->level[0].forward->backward = x->backward;
203     } else {
204         zsl->tail = x->backward;
205     }
206     while(zsl->level > 1 && zsl->header->level[zsl->level-1].forward == NULL)
207         zsl->level--;
208     zsl->length--;
209 }
210 
211 /* Delete an element with matching score/element from the skiplist.
212  * The function returns 1 if the node was found and deleted, otherwise
213  * 0 is returned.
214  *
215  * If 'node' is NULL the deleted node is freed by zslFreeNode(), otherwise
216  * it is not freed (but just unlinked) and *node is set to the node pointer,
217  * so that it is possible for the caller to reuse the node (including the
218  * referenced SDS string at node->ele). */
zslDelete(zskiplist * zsl,double score,sds ele,zskiplistNode ** node)219 int zslDelete(zskiplist *zsl, double score, sds ele, zskiplistNode **node) {
220     zskiplistNode *update[ZSKIPLIST_MAXLEVEL], *x;
221     int i;
222 
223     x = zsl->header;
224     for (i = zsl->level-1; i >= 0; i--) {
225         while (x->level[i].forward &&
226                 (x->level[i].forward->score < score ||
227                     (x->level[i].forward->score == score &&
228                      sdscmp(x->level[i].forward->ele,ele) < 0)))
229         {
230             x = x->level[i].forward;
231         }
232         update[i] = x;
233     }
234     /* We may have multiple elements with the same score, what we need
235      * is to find the element with both the right score and object. */
236     x = x->level[0].forward;
237     if (x && score == x->score && sdscmp(x->ele,ele) == 0) {
238         zslDeleteNode(zsl, x, update);
239         if (!node)
240             zslFreeNode(x);
241         else
242             *node = x;
243         return 1;
244     }
245     return 0; /* not found */
246 }
247 
248 /* Update the score of an element inside the sorted set skiplist.
249  * Note that the element must exist and must match 'score'.
250  * This function does not update the score in the hash table side, the
251  * caller should take care of it.
252  *
253  * Note that this function attempts to just update the node, in case after
254  * the score update, the node would be exactly at the same position.
255  * Otherwise the skiplist is modified by removing and re-adding a new
256  * element, which is more costly.
257  *
258  * The function returns the updated element skiplist node pointer. */
zslUpdateScore(zskiplist * zsl,double curscore,sds ele,double newscore)259 zskiplistNode *zslUpdateScore(zskiplist *zsl, double curscore, sds ele, double newscore) {
260     zskiplistNode *update[ZSKIPLIST_MAXLEVEL], *x;
261     int i;
262 
263     /* We need to seek to element to update to start: this is useful anyway,
264      * we'll have to update or remove it. */
265     x = zsl->header;
266     for (i = zsl->level-1; i >= 0; i--) {
267         while (x->level[i].forward &&
268                 (x->level[i].forward->score < curscore ||
269                     (x->level[i].forward->score == curscore &&
270                      sdscmp(x->level[i].forward->ele,ele) < 0)))
271         {
272             x = x->level[i].forward;
273         }
274         update[i] = x;
275     }
276 
277     /* Jump to our element: note that this function assumes that the
278      * element with the matching score exists. */
279     x = x->level[0].forward;
280     serverAssert(x && curscore == x->score && sdscmp(x->ele,ele) == 0);
281 
282     /* If the node, after the score update, would be still exactly
283      * at the same position, we can just update the score without
284      * actually removing and re-inserting the element in the skiplist. */
285     if ((x->backward == NULL || x->backward->score < newscore) &&
286         (x->level[0].forward == NULL || x->level[0].forward->score > newscore))
287     {
288         x->score = newscore;
289         return x;
290     }
291 
292     /* No way to reuse the old node: we need to remove and insert a new
293      * one at a different place. */
294     zslDeleteNode(zsl, x, update);
295     zskiplistNode *newnode = zslInsert(zsl,newscore,x->ele);
296     /* We reused the old node x->ele SDS string, free the node now
297      * since zslInsert created a new one. */
298     x->ele = NULL;
299     zslFreeNode(x);
300     return newnode;
301 }
302 
/* Return non-zero if 'value' satisfies the lower bound of 'spec': strictly
 * greater than spec->min when minex is set, greater or equal otherwise. */
int zslValueGteMin(double value, zrangespec *spec) {
    if (spec->minex) return value > spec->min;
    return value >= spec->min;
}
306 
/* Return non-zero if 'value' satisfies the upper bound of 'spec': strictly
 * less than spec->max when maxex is set, less or equal otherwise. */
int zslValueLteMax(double value, zrangespec *spec) {
    if (spec->maxex) return value < spec->max;
    return value <= spec->max;
}
310 
311 /* Returns if there is a part of the zset is in range. */
zslIsInRange(zskiplist * zsl,zrangespec * range)312 int zslIsInRange(zskiplist *zsl, zrangespec *range) {
313     zskiplistNode *x;
314 
315     /* Test for ranges that will always be empty. */
316     if (range->min > range->max ||
317             (range->min == range->max && (range->minex || range->maxex)))
318         return 0;
319     x = zsl->tail;
320     if (x == NULL || !zslValueGteMin(x->score,range))
321         return 0;
322     x = zsl->header->level[0].forward;
323     if (x == NULL || !zslValueLteMax(x->score,range))
324         return 0;
325     return 1;
326 }
327 
328 /* Find the first node that is contained in the specified range.
329  * Returns NULL when no element is contained in the range. */
zslFirstInRange(zskiplist * zsl,zrangespec * range)330 zskiplistNode *zslFirstInRange(zskiplist *zsl, zrangespec *range) {
331     zskiplistNode *x;
332     int i;
333 
334     /* If everything is out of range, return early. */
335     if (!zslIsInRange(zsl,range)) return NULL;
336 
337     x = zsl->header;
338     for (i = zsl->level-1; i >= 0; i--) {
339         /* Go forward while *OUT* of range. */
340         while (x->level[i].forward &&
341             !zslValueGteMin(x->level[i].forward->score,range))
342                 x = x->level[i].forward;
343     }
344 
345     /* This is an inner range, so the next node cannot be NULL. */
346     x = x->level[0].forward;
347     serverAssert(x != NULL);
348 
349     /* Check if score <= max. */
350     if (!zslValueLteMax(x->score,range)) return NULL;
351     return x;
352 }
353 
354 /* Find the last node that is contained in the specified range.
355  * Returns NULL when no element is contained in the range. */
zslLastInRange(zskiplist * zsl,zrangespec * range)356 zskiplistNode *zslLastInRange(zskiplist *zsl, zrangespec *range) {
357     zskiplistNode *x;
358     int i;
359 
360     /* If everything is out of range, return early. */
361     if (!zslIsInRange(zsl,range)) return NULL;
362 
363     x = zsl->header;
364     for (i = zsl->level-1; i >= 0; i--) {
365         /* Go forward while *IN* range. */
366         while (x->level[i].forward &&
367             zslValueLteMax(x->level[i].forward->score,range))
368                 x = x->level[i].forward;
369     }
370 
371     /* This is an inner range, so this node cannot be NULL. */
372     serverAssert(x != NULL);
373 
374     /* Check if score >= min. */
375     if (!zslValueGteMin(x->score,range)) return NULL;
376     return x;
377 }
378 
379 /* Delete all the elements with score between min and max from the skiplist.
380  * Both min and max can be inclusive or exclusive (see range->minex and
381  * range->maxex). When inclusive a score >= min && score <= max is deleted.
382  * Note that this function takes the reference to the hash table view of the
383  * sorted set, in order to remove the elements from the hash table too. */
zslDeleteRangeByScore(zskiplist * zsl,zrangespec * range,dict * dict)384 unsigned long zslDeleteRangeByScore(zskiplist *zsl, zrangespec *range, dict *dict) {
385     zskiplistNode *update[ZSKIPLIST_MAXLEVEL], *x;
386     unsigned long removed = 0;
387     int i;
388 
389     x = zsl->header;
390     for (i = zsl->level-1; i >= 0; i--) {
391         while (x->level[i].forward &&
392             !zslValueGteMin(x->level[i].forward->score, range))
393                 x = x->level[i].forward;
394         update[i] = x;
395     }
396 
397     /* Current node is the last with score < or <= min. */
398     x = x->level[0].forward;
399 
400     /* Delete nodes while in range. */
401     while (x && zslValueLteMax(x->score, range)) {
402         zskiplistNode *next = x->level[0].forward;
403         zslDeleteNode(zsl,x,update);
404         dictDelete(dict,x->ele);
405         zslFreeNode(x); /* Here is where x->ele is actually released. */
406         removed++;
407         x = next;
408     }
409     return removed;
410 }
411 
/* Remove every element that compares inside the lexicographic 'range' and
 * return how many elements were removed. Each removed element is deleted
 * from the 'dict' hash table as well. */
unsigned long zslDeleteRangeByLex(zskiplist *zsl, zlexrangespec *range, dict *dict) {
    zskiplistNode *update[ZSKIPLIST_MAXLEVEL];
    zskiplistNode *cur = zsl->header, *victim, *following;
    unsigned long removed = 0;
    int lvl;

    for (lvl = zsl->level-1; lvl >= 0; lvl--) {
        while (cur->level[lvl].forward != NULL &&
               !zslLexValueGteMin(cur->level[lvl].forward->ele,range))
        {
            cur = cur->level[lvl].forward;
        }
        update[lvl] = cur;
    }

    /* 'cur' is the last node sorting before the lower bound: start right
     * after it and go on while the upper bound holds. */
    for (victim = cur->level[0].forward;
         victim != NULL && zslLexValueLteMax(victim->ele,range);
         victim = following)
    {
        following = victim->level[0].forward;
        zslDeleteNode(zsl,victim,update);
        dictDelete(dict,victim->ele);
        zslFreeNode(victim); /* The element string is released here. */
        removed++;
    }
    return removed;
}
440 
441 /* Delete all the elements with rank between start and end from the skiplist.
442  * Start and end are inclusive. Note that start and end need to be 1-based */
zslDeleteRangeByRank(zskiplist * zsl,unsigned int start,unsigned int end,dict * dict)443 unsigned long zslDeleteRangeByRank(zskiplist *zsl, unsigned int start, unsigned int end, dict *dict) {
444     zskiplistNode *update[ZSKIPLIST_MAXLEVEL], *x;
445     unsigned long traversed = 0, removed = 0;
446     int i;
447 
448     x = zsl->header;
449     for (i = zsl->level-1; i >= 0; i--) {
450         while (x->level[i].forward && (traversed + x->level[i].span) < start) {
451             traversed += x->level[i].span;
452             x = x->level[i].forward;
453         }
454         update[i] = x;
455     }
456 
457     traversed++;
458     x = x->level[0].forward;
459     while (x && traversed <= end) {
460         zskiplistNode *next = x->level[0].forward;
461         zslDeleteNode(zsl,x,update);
462         dictDelete(dict,x->ele);
463         zslFreeNode(x);
464         removed++;
465         traversed++;
466         x = next;
467     }
468     return removed;
469 }
470 
471 /* Find the rank for an element by both score and key.
472  * Returns 0 when the element cannot be found, rank otherwise.
473  * Note that the rank is 1-based due to the span of zsl->header to the
474  * first element. */
zslGetRank(zskiplist * zsl,double score,sds ele)475 unsigned long zslGetRank(zskiplist *zsl, double score, sds ele) {
476     zskiplistNode *x;
477     unsigned long rank = 0;
478     int i;
479 
480     x = zsl->header;
481     for (i = zsl->level-1; i >= 0; i--) {
482         while (x->level[i].forward &&
483             (x->level[i].forward->score < score ||
484                 (x->level[i].forward->score == score &&
485                 sdscmp(x->level[i].forward->ele,ele) <= 0))) {
486             rank += x->level[i].span;
487             x = x->level[i].forward;
488         }
489 
490         /* x might be equal to zsl->header, so test if obj is non-NULL */
491         if (x->ele && x->score == score && sdscmp(x->ele,ele) == 0) {
492             return rank;
493         }
494     }
495     return 0;
496 }
497 
498 /* Finds an element by its rank. The rank argument needs to be 1-based. */
zslGetElementByRank(zskiplist * zsl,unsigned long rank)499 zskiplistNode* zslGetElementByRank(zskiplist *zsl, unsigned long rank) {
500     zskiplistNode *x;
501     unsigned long traversed = 0;
502     int i;
503 
504     x = zsl->header;
505     for (i = zsl->level-1; i >= 0; i--) {
506         while (x->level[i].forward && (traversed + x->level[i].span) <= rank)
507         {
508             traversed += x->level[i].span;
509             x = x->level[i].forward;
510         }
511         if (traversed == rank) {
512             return x;
513         }
514     }
515     return NULL;
516 }
517 
518 /* Populate the rangespec according to the objects min and max. */
zslParseRange(robj * min,robj * max,zrangespec * spec)519 static int zslParseRange(robj *min, robj *max, zrangespec *spec) {
520     char *eptr;
521     spec->minex = spec->maxex = 0;
522 
523     /* Parse the min-max interval. If one of the values is prefixed
524      * by the "(" character, it's considered "open". For instance
525      * ZRANGEBYSCORE zset (1.5 (2.5 will match min < x < max
526      * ZRANGEBYSCORE zset 1.5 2.5 will instead match min <= x <= max */
527     if (min->encoding == OBJ_ENCODING_INT) {
528         spec->min = (long)min->ptr;
529     } else {
530         if (((char*)min->ptr)[0] == '(') {
531             spec->min = strtod((char*)min->ptr+1,&eptr);
532             if (eptr[0] != '\0' || isnan(spec->min)) return C_ERR;
533             spec->minex = 1;
534         } else {
535             spec->min = strtod((char*)min->ptr,&eptr);
536             if (eptr[0] != '\0' || isnan(spec->min)) return C_ERR;
537         }
538     }
539     if (max->encoding == OBJ_ENCODING_INT) {
540         spec->max = (long)max->ptr;
541     } else {
542         if (((char*)max->ptr)[0] == '(') {
543             spec->max = strtod((char*)max->ptr+1,&eptr);
544             if (eptr[0] != '\0' || isnan(spec->max)) return C_ERR;
545             spec->maxex = 1;
546         } else {
547             spec->max = strtod((char*)max->ptr,&eptr);
548             if (eptr[0] != '\0' || isnan(spec->max)) return C_ERR;
549         }
550     }
551 
552     return C_OK;
553 }
554 
555 /* ------------------------ Lexicographic ranges ---------------------------- */
556 
557 /* Parse max or min argument of ZRANGEBYLEX.
558   * (foo means foo (open interval)
559   * [foo means foo (closed interval)
560   * - means the min string possible
561   * + means the max string possible
562   *
563   * If the string is valid the *dest pointer is set to the redis object
564   * that will be used for the comparison, and ex will be set to 0 or 1
565   * respectively if the item is exclusive or inclusive. C_OK will be
566   * returned.
567   *
568   * If the string is not a valid range C_ERR is returned, and the value
569   * of *dest and *ex is undefined. */
zslParseLexRangeItem(robj * item,sds * dest,int * ex)570 int zslParseLexRangeItem(robj *item, sds *dest, int *ex) {
571     char *c = item->ptr;
572 
573     switch(c[0]) {
574     case '+':
575         if (c[1] != '\0') return C_ERR;
576         *ex = 1;
577         *dest = shared.maxstring;
578         return C_OK;
579     case '-':
580         if (c[1] != '\0') return C_ERR;
581         *ex = 1;
582         *dest = shared.minstring;
583         return C_OK;
584     case '(':
585         *ex = 1;
586         *dest = sdsnewlen(c+1,sdslen(c)-1);
587         return C_OK;
588     case '[':
589         *ex = 0;
590         *dest = sdsnewlen(c+1,sdslen(c)-1);
591         return C_OK;
592     default:
593         return C_ERR;
594     }
595 }
596 
597 /* Free a lex range structure, must be called only after zslParseLexRange()
598  * populated the structure with success (C_OK returned). */
zslFreeLexRange(zlexrangespec * spec)599 void zslFreeLexRange(zlexrangespec *spec) {
600     if (spec->min != shared.minstring &&
601         spec->min != shared.maxstring) sdsfree(spec->min);
602     if (spec->max != shared.minstring &&
603         spec->max != shared.maxstring) sdsfree(spec->max);
604 }
605 
606 /* Populate the lex rangespec according to the objects min and max.
607  *
608  * Return C_OK on success. On error C_ERR is returned.
609  * When OK is returned the structure must be freed with zslFreeLexRange(),
610  * otherwise no release is needed. */
zslParseLexRange(robj * min,robj * max,zlexrangespec * spec)611 int zslParseLexRange(robj *min, robj *max, zlexrangespec *spec) {
612     /* The range can't be valid if objects are integer encoded.
613      * Every item must start with ( or [. */
614     if (min->encoding == OBJ_ENCODING_INT ||
615         max->encoding == OBJ_ENCODING_INT) return C_ERR;
616 
617     spec->min = spec->max = NULL;
618     if (zslParseLexRangeItem(min, &spec->min, &spec->minex) == C_ERR ||
619         zslParseLexRangeItem(max, &spec->max, &spec->maxex) == C_ERR) {
620         zslFreeLexRange(spec);
621         return C_ERR;
622     } else {
623         return C_OK;
624     }
625 }
626 
627 /* This is just a wrapper to sdscmp() that is able to
628  * handle shared.minstring and shared.maxstring as the equivalent of
629  * -inf and +inf for strings */
sdscmplex(sds a,sds b)630 int sdscmplex(sds a, sds b) {
631     if (a == b) return 0;
632     if (a == shared.minstring || b == shared.maxstring) return -1;
633     if (a == shared.maxstring || b == shared.minstring) return 1;
634     return sdscmp(a,b);
635 }
636 
/* Return non-zero if 'value' satisfies the lower bound of the lex range:
 * strictly after spec->min when minex is set, at or after it otherwise. */
int zslLexValueGteMin(sds value, zlexrangespec *spec) {
    int cmp = sdscmplex(value,spec->min);

    if (spec->minex) return cmp > 0;
    return cmp >= 0;
}
642 
/* Return non-zero if 'value' satisfies the upper bound of the lex range:
 * strictly before spec->max when maxex is set, at or before it otherwise. */
int zslLexValueLteMax(sds value, zlexrangespec *spec) {
    int cmp = sdscmplex(value,spec->max);

    if (spec->maxex) return cmp < 0;
    return cmp <= 0;
}
648 
649 /* Returns if there is a part of the zset is in the lex range. */
zslIsInLexRange(zskiplist * zsl,zlexrangespec * range)650 int zslIsInLexRange(zskiplist *zsl, zlexrangespec *range) {
651     zskiplistNode *x;
652 
653     /* Test for ranges that will always be empty. */
654     int cmp = sdscmplex(range->min,range->max);
655     if (cmp > 0 || (cmp == 0 && (range->minex || range->maxex)))
656         return 0;
657     x = zsl->tail;
658     if (x == NULL || !zslLexValueGteMin(x->ele,range))
659         return 0;
660     x = zsl->header->level[0].forward;
661     if (x == NULL || !zslLexValueLteMax(x->ele,range))
662         return 0;
663     return 1;
664 }
665 
666 /* Find the first node that is contained in the specified lex range.
667  * Returns NULL when no element is contained in the range. */
zslFirstInLexRange(zskiplist * zsl,zlexrangespec * range)668 zskiplistNode *zslFirstInLexRange(zskiplist *zsl, zlexrangespec *range) {
669     zskiplistNode *x;
670     int i;
671 
672     /* If everything is out of range, return early. */
673     if (!zslIsInLexRange(zsl,range)) return NULL;
674 
675     x = zsl->header;
676     for (i = zsl->level-1; i >= 0; i--) {
677         /* Go forward while *OUT* of range. */
678         while (x->level[i].forward &&
679             !zslLexValueGteMin(x->level[i].forward->ele,range))
680                 x = x->level[i].forward;
681     }
682 
683     /* This is an inner range, so the next node cannot be NULL. */
684     x = x->level[0].forward;
685     serverAssert(x != NULL);
686 
687     /* Check if score <= max. */
688     if (!zslLexValueLteMax(x->ele,range)) return NULL;
689     return x;
690 }
691 
692 /* Find the last node that is contained in the specified range.
693  * Returns NULL when no element is contained in the range. */
zslLastInLexRange(zskiplist * zsl,zlexrangespec * range)694 zskiplistNode *zslLastInLexRange(zskiplist *zsl, zlexrangespec *range) {
695     zskiplistNode *x;
696     int i;
697 
698     /* If everything is out of range, return early. */
699     if (!zslIsInLexRange(zsl,range)) return NULL;
700 
701     x = zsl->header;
702     for (i = zsl->level-1; i >= 0; i--) {
703         /* Go forward while *IN* range. */
704         while (x->level[i].forward &&
705             zslLexValueLteMax(x->level[i].forward->ele,range))
706                 x = x->level[i].forward;
707     }
708 
709     /* This is an inner range, so this node cannot be NULL. */
710     serverAssert(x != NULL);
711 
712     /* Check if score >= min. */
713     if (!zslLexValueGteMin(x->ele,range)) return NULL;
714     return x;
715 }
716 
717 /*-----------------------------------------------------------------------------
718  * Listpack-backed sorted set API
719  *----------------------------------------------------------------------------*/
720 
/* Convert the first 'vlen' bytes at 'vstr' to a double using strtod().
 *
 * The input does not need to be NUL-terminated: it is copied into a local
 * buffer and terminated there. Input that does not fit in the buffer is
 * truncated to sizeof(buf)-1 bytes before the conversion. */
double zzlStrtod(unsigned char *vstr, unsigned int vlen) {
    char buf[128];

    /* Reserve one byte for the terminator. Clamping to sizeof(buf) instead
     * would let the store below write one byte past the end of buf. */
    if (vlen > sizeof(buf) - 1)
        vlen = sizeof(buf) - 1;
    memcpy(buf,vstr,vlen);
    buf[vlen] = '\0';
    return strtod(buf,NULL);
}
729 
/* Return the score stored in the listpack entry at 'sptr', which must not
 * be NULL. When lpGetValue() hands back a string the score is parsed with
 * zzlStrtod(); otherwise the integer value it stored is converted. */
double zzlGetScore(unsigned char *sptr) {
    unsigned char *vstr;
    unsigned int vlen;
    long long vlong;

    serverAssert(sptr != NULL);
    vstr = lpGetValue(sptr,&vlen,&vlong);

    if (vstr == NULL) return vlong;
    return zzlStrtod(vstr,vlen);
}
747 
748 /* Return a listpack element as an SDS string. */
lpGetObject(unsigned char * sptr)749 sds lpGetObject(unsigned char *sptr) {
750     unsigned char *vstr;
751     unsigned int vlen;
752     long long vlong;
753 
754     serverAssert(sptr != NULL);
755     vstr = lpGetValue(sptr,&vlen,&vlong);
756 
757     if (vstr) {
758         return sdsnewlen((char*)vstr,vlen);
759     } else {
760         return sdsfromlonglong(vlong);
761     }
762 }
763 
/* Compare the listpack entry at 'eptr' with the 'clen' bytes at 'cstr'.
 *
 * The common prefix is compared with memcmp(); if it is equal the result
 * is the difference between the two lengths. An integer entry is first
 * rendered as a string so that it can be compared byte by byte. */
int zzlCompareElements(unsigned char *eptr, unsigned char *cstr, unsigned int clen) {
    unsigned char vbuf[32];
    unsigned char *vstr;
    unsigned int vlen;
    long long vlong;
    int minlen, cmp;

    vstr = lpGetValue(eptr,&vlen,&vlong);
    if (!vstr) {
        /* Integer entry: use its decimal representation. */
        vlen = ll2string((char*)vbuf,sizeof(vbuf),vlong);
        vstr = vbuf;
    }

    if (clen <= vlen)
        minlen = clen;
    else
        minlen = vlen;

    cmp = memcmp(vstr,cstr,minlen);
    if (cmp != 0) return cmp;
    return vlen-clen;
}
784 
/* Return the number of members stored in the listpack 'zl'. */
unsigned int zzlLength(unsigned char *zl) {
    /* Every member takes two consecutive listpack entries (the element
     * followed by its score), hence half the entry count. */
    return lpLength(zl) / 2;
}
788 
/* Advance *eptr / *sptr (element and score pointers, both non-NULL on
 * entry) to the following member of the listpack 'zl'. Both are set to
 * NULL when there is no next member. */
void zzlNext(unsigned char *zl, unsigned char **eptr, unsigned char **sptr) {
    unsigned char *_eptr, *_sptr = NULL;
    serverAssert(*eptr != NULL && *sptr != NULL);

    /* The next element, if any, comes right after the current score. */
    _eptr = lpNext(zl,*sptr);
    if (_eptr != NULL) {
        /* An element is always followed by its score. */
        _sptr = lpNext(zl,_eptr);
        serverAssert(_sptr != NULL);
    }

    *eptr = _eptr;
    *sptr = _sptr;
}
807 
/* Step the (*eptr,*sptr) cursor back to the preceding element/score pair.
 * Both pointers are set to NULL when the current pair was the first one.
 * On entry neither *eptr nor *sptr may be NULL. */
void zzlPrev(unsigned char *zl, unsigned char **eptr, unsigned char **sptr) {
    serverAssert(*eptr != NULL && *sptr != NULL);

    unsigned char *prev_score = lpPrev(zl,*eptr);
    unsigned char *prev_ele = NULL;

    if (prev_score != NULL) {
        /* A score is always preceded by its element. */
        prev_ele = lpPrev(zl,prev_score);
        serverAssert(prev_ele != NULL);
    }

    *eptr = prev_ele;
    *sptr = prev_score;
}
826 
827 /* Returns if there is a part of the zset is in range. Should only be used
828  * internally by zzlFirstInRange and zzlLastInRange. */
zzlIsInRange(unsigned char * zl,zrangespec * range)829 int zzlIsInRange(unsigned char *zl, zrangespec *range) {
830     unsigned char *p;
831     double score;
832 
833     /* Test for ranges that will always be empty. */
834     if (range->min > range->max ||
835             (range->min == range->max && (range->minex || range->maxex)))
836         return 0;
837 
838     p = lpSeek(zl,-1); /* Last score. */
839     if (p == NULL) return 0; /* Empty sorted set */
840     score = zzlGetScore(p);
841     if (!zslValueGteMin(score,range))
842         return 0;
843 
844     p = lpSeek(zl,1); /* First score. */
845     serverAssert(p != NULL);
846     score = zzlGetScore(p);
847     if (!zslValueLteMax(score,range))
848         return 0;
849 
850     return 1;
851 }
852 
853 /* Find pointer to the first element contained in the specified range.
854  * Returns NULL when no element is contained in the range. */
zzlFirstInRange(unsigned char * zl,zrangespec * range)855 unsigned char *zzlFirstInRange(unsigned char *zl, zrangespec *range) {
856     unsigned char *eptr = lpSeek(zl,0), *sptr;
857     double score;
858 
859     /* If everything is out of range, return early. */
860     if (!zzlIsInRange(zl,range)) return NULL;
861 
862     while (eptr != NULL) {
863         sptr = lpNext(zl,eptr);
864         serverAssert(sptr != NULL);
865 
866         score = zzlGetScore(sptr);
867         if (zslValueGteMin(score,range)) {
868             /* Check if score <= max. */
869             if (zslValueLteMax(score,range))
870                 return eptr;
871             return NULL;
872         }
873 
874         /* Move to next element. */
875         eptr = lpNext(zl,sptr);
876     }
877 
878     return NULL;
879 }
880 
881 /* Find pointer to the last element contained in the specified range.
882  * Returns NULL when no element is contained in the range. */
zzlLastInRange(unsigned char * zl,zrangespec * range)883 unsigned char *zzlLastInRange(unsigned char *zl, zrangespec *range) {
884     unsigned char *eptr = lpSeek(zl,-2), *sptr;
885     double score;
886 
887     /* If everything is out of range, return early. */
888     if (!zzlIsInRange(zl,range)) return NULL;
889 
890     while (eptr != NULL) {
891         sptr = lpNext(zl,eptr);
892         serverAssert(sptr != NULL);
893 
894         score = zzlGetScore(sptr);
895         if (zslValueLteMax(score,range)) {
896             /* Check if score >= min. */
897             if (zslValueGteMin(score,range))
898                 return eptr;
899             return NULL;
900         }
901 
902         /* Move to previous element by moving to the score of previous element.
903          * When this returns NULL, we know there also is no element. */
904         sptr = lpPrev(zl,eptr);
905         if (sptr != NULL)
906             serverAssert((eptr = lpPrev(zl,sptr)) != NULL);
907         else
908             eptr = NULL;
909     }
910 
911     return NULL;
912 }
913 
/* Return the result of zslLexValueGteMin() for the listpack entry 'p'.
 * The entry is materialized as a temporary SDS string that is released
 * before returning. */
int zzlLexValueGteMin(unsigned char *p, zlexrangespec *spec) {
    sds member = lpGetObject(p);
    int in_range = zslLexValueGteMin(member,spec);

    sdsfree(member);
    return in_range;
}
920 
/* Return the result of zslLexValueLteMax() for the listpack entry 'p'.
 * The entry is materialized as a temporary SDS string that is released
 * before returning. */
int zzlLexValueLteMax(unsigned char *p, zlexrangespec *spec) {
    sds member = lpGetObject(p);
    int in_range = zslLexValueLteMax(member,spec);

    sdsfree(member);
    return in_range;
}
927 
928 /* Returns if there is a part of the zset is in range. Should only be used
929  * internally by zzlFirstInRange and zzlLastInRange. */
zzlIsInLexRange(unsigned char * zl,zlexrangespec * range)930 int zzlIsInLexRange(unsigned char *zl, zlexrangespec *range) {
931     unsigned char *p;
932 
933     /* Test for ranges that will always be empty. */
934     int cmp = sdscmplex(range->min,range->max);
935     if (cmp > 0 || (cmp == 0 && (range->minex || range->maxex)))
936         return 0;
937 
938     p = lpSeek(zl,-2); /* Last element. */
939     if (p == NULL) return 0;
940     if (!zzlLexValueGteMin(p,range))
941         return 0;
942 
943     p = lpSeek(zl,0); /* First element. */
944     serverAssert(p != NULL);
945     if (!zzlLexValueLteMax(p,range))
946         return 0;
947 
948     return 1;
949 }
950 
951 /* Find pointer to the first element contained in the specified lex range.
952  * Returns NULL when no element is contained in the range. */
zzlFirstInLexRange(unsigned char * zl,zlexrangespec * range)953 unsigned char *zzlFirstInLexRange(unsigned char *zl, zlexrangespec *range) {
954     unsigned char *eptr = lpSeek(zl,0), *sptr;
955 
956     /* If everything is out of range, return early. */
957     if (!zzlIsInLexRange(zl,range)) return NULL;
958 
959     while (eptr != NULL) {
960         if (zzlLexValueGteMin(eptr,range)) {
961             /* Check if score <= max. */
962             if (zzlLexValueLteMax(eptr,range))
963                 return eptr;
964             return NULL;
965         }
966 
967         /* Move to next element. */
968         sptr = lpNext(zl,eptr); /* This element score. Skip it. */
969         serverAssert(sptr != NULL);
970         eptr = lpNext(zl,sptr); /* Next element. */
971     }
972 
973     return NULL;
974 }
975 
976 /* Find pointer to the last element contained in the specified lex range.
977  * Returns NULL when no element is contained in the range. */
zzlLastInLexRange(unsigned char * zl,zlexrangespec * range)978 unsigned char *zzlLastInLexRange(unsigned char *zl, zlexrangespec *range) {
979     unsigned char *eptr = lpSeek(zl,-2), *sptr;
980 
981     /* If everything is out of range, return early. */
982     if (!zzlIsInLexRange(zl,range)) return NULL;
983 
984     while (eptr != NULL) {
985         if (zzlLexValueLteMax(eptr,range)) {
986             /* Check if score >= min. */
987             if (zzlLexValueGteMin(eptr,range))
988                 return eptr;
989             return NULL;
990         }
991 
992         /* Move to previous element by moving to the score of previous element.
993          * When this returns NULL, we know there also is no element. */
994         sptr = lpPrev(zl,eptr);
995         if (sptr != NULL)
996             serverAssert((eptr = lpPrev(zl,sptr)) != NULL);
997         else
998             eptr = NULL;
999     }
1000 
1001     return NULL;
1002 }
1003 
/* Look up 'ele' in the listpack 'lp'. On a match the pointer to the
 * element entry is returned and, when 'score' is not NULL, the score of
 * the element is stored in '*score'. NULL is returned when the listpack
 * is empty or the element is not found. */
unsigned char *zzlFind(unsigned char *lp, sds ele, double *score) {
    unsigned char *first = lpFirst(lp);
    if (first == NULL) return NULL;

    /* NOTE(review): the trailing 1 is presumably lpFind()'s skip count, so
     * that score entries are not compared -- confirm against listpack.h. */
    unsigned char *match = lpFind(lp, first, (unsigned char*)ele, sdslen(ele), 1);
    if (match == NULL) return NULL;

    /* The score sits right after the matching element. */
    unsigned char *sc = lpNext(lp,match);
    serverAssert(sc != NULL);

    if (score != NULL) *score = zzlGetScore(sc);
    return match;
}
1020 
/* Remove the (element,score) pair starting at 'eptr' from the listpack and
 * return the resulting listpack pointer. lpDeleteRangeWithEntry() receives
 * the address of a private cursor, so the caller's pointer value is never
 * modified. */
unsigned char *zzlDelete(unsigned char *zl, unsigned char *eptr) {
    unsigned char *cursor = eptr;
    return lpDeleteRangeWithEntry(zl,&cursor,2);
}
1026 
/* Store the pair ('ele','score') in the listpack 'zl' and return the
 * resulting listpack pointer. With 'eptr' == NULL the pair is appended at
 * the tail, otherwise it is placed right before the entry 'eptr'. The
 * score is stored as the text produced by d2string(). */
unsigned char *zzlInsertAt(unsigned char *zl, unsigned char *eptr, sds ele, double score) {
    char sbuf[128];
    int slen = d2string(sbuf,sizeof(sbuf),score);

    if (eptr == NULL) {
        /* Tail insertion: member first, then its score. */
        zl = lpAppend(zl,(unsigned char*)ele,sdslen(ele));
        return lpAppend(zl,(unsigned char*)sbuf,slen);
    }

    /* Put the member before 'eptr' and remember where it ended up... */
    unsigned char *inserted;
    zl = lpInsertString(zl,(unsigned char*)ele,sdslen(ele),eptr,LP_BEFORE,&inserted);

    /* ...so that the score can be placed right after it. */
    return lpInsertString(zl,(unsigned char*)sbuf,slen,inserted,LP_AFTER,NULL);
}
1045 
1046 /* Insert (element,score) pair in listpack. This function assumes the element is
1047  * not yet present in the list. */
zzlInsert(unsigned char * zl,sds ele,double score)1048 unsigned char *zzlInsert(unsigned char *zl, sds ele, double score) {
1049     unsigned char *eptr = lpSeek(zl,0), *sptr;
1050     double s;
1051 
1052     while (eptr != NULL) {
1053         sptr = lpNext(zl,eptr);
1054         serverAssert(sptr != NULL);
1055         s = zzlGetScore(sptr);
1056 
1057         if (s > score) {
1058             /* First element with score larger than score for element to be
1059              * inserted. This means we should take its spot in the list to
1060              * maintain ordering. */
1061             zl = zzlInsertAt(zl,eptr,ele,score);
1062             break;
1063         } else if (s == score) {
1064             /* Ensure lexicographical ordering for elements. */
1065             if (zzlCompareElements(eptr,(unsigned char*)ele,sdslen(ele)) > 0) {
1066                 zl = zzlInsertAt(zl,eptr,ele,score);
1067                 break;
1068             }
1069         }
1070 
1071         /* Move to next element. */
1072         eptr = lpNext(zl,sptr);
1073     }
1074 
1075     /* Push on tail of list when it was not yet inserted. */
1076     if (eptr == NULL)
1077         zl = zzlInsertAt(zl,NULL,ele,score);
1078     return zl;
1079 }
1080 
/* Delete every pair whose score lies inside 'range' from the listpack and
 * return the resulting listpack pointer. When 'deleted' is not NULL it
 * receives the number of pairs removed. */
unsigned char *zzlDeleteRangeByScore(unsigned char *zl, zrangespec *range, unsigned long *deleted) {
    unsigned long removed = 0;

    if (deleted != NULL) *deleted = 0;

    unsigned char *ele = zzlFirstInRange(zl,range);
    if (ele == NULL) return zl;

    /* After the tail of the listpack is deleted 'ele' becomes NULL. */
    while (ele != NULL) {
        unsigned char *sc = lpNext(zl,ele);
        if (sc == NULL) break;

        /* Stop at the first score that is no longer in range. */
        if (!zslValueLteMax(zzlGetScore(sc),range)) break;

        /* Drop both the element and its score. */
        zl = lpDeleteRangeWithEntry(zl,&ele,2);
        removed++;
    }

    if (deleted != NULL) *deleted = removed;
    return zl;
}
1107 
/* Delete every pair whose element lies inside the lexicographical 'range'
 * from the listpack and return the resulting listpack pointer. When
 * 'deleted' is not NULL it receives the number of pairs removed. */
unsigned char *zzlDeleteRangeByLex(unsigned char *zl, zlexrangespec *range, unsigned long *deleted) {
    unsigned long removed = 0;

    if (deleted != NULL) *deleted = 0;

    unsigned char *ele = zzlFirstInLexRange(zl,range);
    if (ele == NULL) return zl;

    /* After the tail of the listpack is deleted 'ele' becomes NULL. */
    while (ele != NULL) {
        if (lpNext(zl,ele) == NULL) break;

        /* Stop at the first element that is no longer in range. */
        if (!zzlLexValueLteMax(ele,range)) break;

        /* Drop both the element and its score. */
        zl = lpDeleteRangeWithEntry(zl,&ele,2);
        removed++;
    }

    if (deleted != NULL) *deleted = removed;
    return zl;
}
1132 
/* Delete all the pairs with rank between 'start' and 'end' from the
 * listpack. Both bounds are inclusive and 1-based. When 'deleted' is not
 * NULL it receives end-start+1. Returns the resulting listpack pointer. */
unsigned char *zzlDeleteRangeByRank(unsigned char *zl, unsigned int start, unsigned int end, unsigned long *deleted) {
    unsigned int count = (end-start)+1;

    if (deleted) *deleted = count;

    /* Each pair spans two listpack entries, so both the 0-based starting
     * index and the amount to delete are doubled. */
    return lpDeleteRange(zl,2*(start-1),2*count);
}
1141 
1142 /*-----------------------------------------------------------------------------
1143  * Common sorted set API
1144  *----------------------------------------------------------------------------*/
1145 
zsetLength(const robj * zobj)1146 unsigned long zsetLength(const robj *zobj) {
1147     unsigned long length = 0;
1148     if (zobj->encoding == OBJ_ENCODING_LISTPACK) {
1149         length = zzlLength(zobj->ptr);
1150     } else if (zobj->encoding == OBJ_ENCODING_SKIPLIST) {
1151         length = ((const zset*)zobj->ptr)->zsl->length;
1152     } else {
1153         serverPanic("Unknown sorted set encoding");
1154     }
1155     return length;
1156 }
1157 
/* Convert the sorted set 'zobj' to the requested 'encoding'. Nothing is
 * done when the object already uses it. The only supported conversions are
 * listpack -> skiplist and skiplist -> listpack; anything else panics. The
 * old representation is released and zobj->ptr / zobj->encoding are
 * updated to describe the new one. */
void zsetConvert(robj *zobj, int encoding) {
    if (zobj->encoding == encoding) return;

    if (zobj->encoding == OBJ_ENCODING_LISTPACK) {
        unsigned char *lp = zobj->ptr;
        unsigned char *ele_ptr, *score_ptr;

        if (encoding != OBJ_ENCODING_SKIPLIST)
            serverPanic("Unknown target encoding");

        zset *target = zmalloc(sizeof(*target));
        target->dict = dictCreate(&zsetDictType);
        target->zsl = zslCreate();

        ele_ptr = lpSeek(lp,0);
        serverAssertWithInfo(NULL,zobj,ele_ptr != NULL);
        score_ptr = lpNext(lp,ele_ptr);
        serverAssertWithInfo(NULL,zobj,score_ptr != NULL);

        /* Copy every pair into both the skiplist and the dict. The dict
         * value points at the score stored inside the skiplist node. */
        while (ele_ptr != NULL) {
            unsigned int len;
            long long num;
            double score = zzlGetScore(score_ptr);
            unsigned char *str = lpGetValue(ele_ptr,&len,&num);
            sds member = (str == NULL) ? sdsfromlonglong(num)
                                       : sdsnewlen((char*)str,len);

            zskiplistNode *node = zslInsert(target->zsl,score,member);
            serverAssert(dictAdd(target->dict,member,&node->score) == DICT_OK);
            zzlNext(lp,&ele_ptr,&score_ptr);
        }

        zfree(zobj->ptr);
        zobj->ptr = target;
        zobj->encoding = OBJ_ENCODING_SKIPLIST;
    } else if (zobj->encoding == OBJ_ENCODING_SKIPLIST) {
        unsigned char *lp = lpNew(0);

        if (encoding != OBJ_ENCODING_LISTPACK)
            serverPanic("Unknown target encoding");

        /* Tear the skiplist down by hand, in the spirit of zslFree(), so
         * that each node is released as soon as it has been copied into
         * the listpack. */
        zset *source = zobj->ptr;
        dictRelease(source->dict);
        zskiplistNode *node = source->zsl->header->level[0].forward;
        zfree(source->zsl->header);
        zfree(source->zsl);

        while (node != NULL) {
            zskiplistNode *following = node->level[0].forward;

            /* Nodes are visited in skiplist order, so appending at the
             * tail is enough. */
            lp = zzlInsertAt(lp,NULL,node->ele,node->score);
            zslFreeNode(node);
            node = following;
        }

        zfree(source);
        zobj->ptr = lp;
        zobj->encoding = OBJ_ENCODING_LISTPACK;
    } else {
        serverPanic("Unknown sorted set encoding");
    }
}
1228 
1229 /* Convert the sorted set object into a listpack if it is not already a listpack
1230  * and if the number of elements and the maximum element size and total elements size
1231  * are within the expected ranges. */
zsetConvertToListpackIfNeeded(robj * zobj,size_t maxelelen,size_t totelelen)1232 void zsetConvertToListpackIfNeeded(robj *zobj, size_t maxelelen, size_t totelelen) {
1233     if (zobj->encoding == OBJ_ENCODING_LISTPACK) return;
1234     zset *zset = zobj->ptr;
1235 
1236     if (zset->zsl->length <= server.zset_max_listpack_entries &&
1237         maxelelen <= server.zset_max_listpack_value &&
1238         lpSafeToAdd(NULL, totelelen))
1239     {
1240         zsetConvert(zobj,OBJ_ENCODING_LISTPACK);
1241     }
1242 }
1243 
1244 /* Return (by reference) the score of the specified member of the sorted set
1245  * storing it into *score. If the element does not exist C_ERR is returned
1246  * otherwise C_OK is returned and *score is correctly populated.
1247  * If 'zobj' or 'member' is NULL, C_ERR is returned. */
zsetScore(robj * zobj,sds member,double * score)1248 int zsetScore(robj *zobj, sds member, double *score) {
1249     if (!zobj || !member) return C_ERR;
1250 
1251     if (zobj->encoding == OBJ_ENCODING_LISTPACK) {
1252         if (zzlFind(zobj->ptr, member, score) == NULL) return C_ERR;
1253     } else if (zobj->encoding == OBJ_ENCODING_SKIPLIST) {
1254         zset *zs = zobj->ptr;
1255         dictEntry *de = dictFind(zs->dict, member);
1256         if (de == NULL) return C_ERR;
1257         *score = *(double*)dictGetVal(de);
1258     } else {
1259         serverPanic("Unknown sorted set encoding");
1260     }
1261     return C_OK;
1262 }
1263 
1264 /* Add a new element or update the score of an existing element in a sorted
1265  * set, regardless of its encoding.
1266  *
1267  * The set of flags change the command behavior.
1268  *
1269  * The input flags are the following:
1270  *
1271  * ZADD_INCR: Increment the current element score by 'score' instead of updating
1272  *            the current element score. If the element does not exist, we
1273  *            assume 0 as previous score.
1274  * ZADD_NX:   Perform the operation only if the element does not exist.
1275  * ZADD_XX:   Perform the operation only if the element already exist.
1276  * ZADD_GT:   Perform the operation on existing elements only if the new score is
1277  *            greater than the current score.
1278  * ZADD_LT:   Perform the operation on existing elements only if the new score is
1279  *            less than the current score.
1280  *
1281  * When ZADD_INCR is used, the new score of the element is stored in
1282  * '*newscore' if 'newscore' is not NULL.
1283  *
1284  * The returned flags are the following:
1285  *
1286  * ZADD_NAN:     The resulting score is not a number.
1287  * ZADD_ADDED:   The element was added (not present before the call).
1288  * ZADD_UPDATED: The element score was updated.
1289  * ZADD_NOP:     No operation was performed because of NX or XX.
1290  *
1291  * Return value:
1292  *
1293  * The function returns 1 on success, and sets the appropriate flags
1294  * ADDED or UPDATED to signal what happened during the operation (note that
1295  * none could be set if we re-added an element using the same score it used
1296  * to have, or in the case a zero increment is used).
1297  *
1298  * The function returns 0 on error, currently only when the increment
1299  * produces a NAN condition, or when the 'score' value is NAN since the
1300  * start.
1301  *
1302  * The command as a side effect of adding a new element may convert the sorted
1303  * set internal encoding from listpack to hashtable+skiplist.
1304  *
1305  * Memory management of 'ele':
1306  *
1307  * The function does not take ownership of the 'ele' SDS string, but copies
1308  * it if needed. */
zsetAdd(robj * zobj,double score,sds ele,int in_flags,int * out_flags,double * newscore)1309 int zsetAdd(robj *zobj, double score, sds ele, int in_flags, int *out_flags, double *newscore) {
1310     /* Turn options into simple to check vars. */
1311     int incr = (in_flags & ZADD_IN_INCR) != 0;
1312     int nx = (in_flags & ZADD_IN_NX) != 0;
1313     int xx = (in_flags & ZADD_IN_XX) != 0;
1314     int gt = (in_flags & ZADD_IN_GT) != 0;
1315     int lt = (in_flags & ZADD_IN_LT) != 0;
1316     *out_flags = 0; /* We'll return our response flags. */
1317     double curscore;
1318 
1319     /* NaN as input is an error regardless of all the other parameters. */
1320     if (isnan(score)) {
1321         *out_flags = ZADD_OUT_NAN;
1322         return 0;
1323     }
1324 
1325     /* Update the sorted set according to its encoding. */
1326     if (zobj->encoding == OBJ_ENCODING_LISTPACK) {
1327         unsigned char *eptr;
1328 
1329         if ((eptr = zzlFind(zobj->ptr,ele,&curscore)) != NULL) {
1330             /* NX? Return, same element already exists. */
1331             if (nx) {
1332                 *out_flags |= ZADD_OUT_NOP;
1333                 return 1;
1334             }
1335 
1336             /* Prepare the score for the increment if needed. */
1337             if (incr) {
1338                 score += curscore;
1339                 if (isnan(score)) {
1340                     *out_flags |= ZADD_OUT_NAN;
1341                     return 0;
1342                 }
1343             }
1344 
1345             /* GT/LT? Only update if score is greater/less than current. */
1346             if ((lt && score >= curscore) || (gt && score <= curscore)) {
1347                 *out_flags |= ZADD_OUT_NOP;
1348                 return 1;
1349             }
1350 
1351             if (newscore) *newscore = score;
1352 
1353             /* Remove and re-insert when score changed. */
1354             if (score != curscore) {
1355                 zobj->ptr = zzlDelete(zobj->ptr,eptr);
1356                 zobj->ptr = zzlInsert(zobj->ptr,ele,score);
1357                 *out_flags |= ZADD_OUT_UPDATED;
1358             }
1359             return 1;
1360         } else if (!xx) {
1361             /* check if the element is too large or the list
1362              * becomes too long *before* executing zzlInsert. */
1363             if (zzlLength(zobj->ptr)+1 > server.zset_max_listpack_entries ||
1364                 sdslen(ele) > server.zset_max_listpack_value ||
1365                 !lpSafeToAdd(zobj->ptr, sdslen(ele)))
1366             {
1367                 zsetConvert(zobj,OBJ_ENCODING_SKIPLIST);
1368             } else {
1369                 zobj->ptr = zzlInsert(zobj->ptr,ele,score);
1370                 if (newscore) *newscore = score;
1371                 *out_flags |= ZADD_OUT_ADDED;
1372                 return 1;
1373             }
1374         } else {
1375             *out_flags |= ZADD_OUT_NOP;
1376             return 1;
1377         }
1378     }
1379 
1380     /* Note that the above block handling listpack would have either returned or
1381      * converted the key to skiplist. */
1382     if (zobj->encoding == OBJ_ENCODING_SKIPLIST) {
1383         zset *zs = zobj->ptr;
1384         zskiplistNode *znode;
1385         dictEntry *de;
1386 
1387         de = dictFind(zs->dict,ele);
1388         if (de != NULL) {
1389             /* NX? Return, same element already exists. */
1390             if (nx) {
1391                 *out_flags |= ZADD_OUT_NOP;
1392                 return 1;
1393             }
1394 
1395             curscore = *(double*)dictGetVal(de);
1396 
1397             /* Prepare the score for the increment if needed. */
1398             if (incr) {
1399                 score += curscore;
1400                 if (isnan(score)) {
1401                     *out_flags |= ZADD_OUT_NAN;
1402                     return 0;
1403                 }
1404             }
1405 
1406             /* GT/LT? Only update if score is greater/less than current. */
1407             if ((lt && score >= curscore) || (gt && score <= curscore)) {
1408                 *out_flags |= ZADD_OUT_NOP;
1409                 return 1;
1410             }
1411 
1412             if (newscore) *newscore = score;
1413 
1414             /* Remove and re-insert when score changes. */
1415             if (score != curscore) {
1416                 znode = zslUpdateScore(zs->zsl,curscore,ele,score);
1417                 /* Note that we did not removed the original element from
1418                  * the hash table representing the sorted set, so we just
1419                  * update the score. */
1420                 dictGetVal(de) = &znode->score; /* Update score ptr. */
1421                 *out_flags |= ZADD_OUT_UPDATED;
1422             }
1423             return 1;
1424         } else if (!xx) {
1425             ele = sdsdup(ele);
1426             znode = zslInsert(zs->zsl,score,ele);
1427             serverAssert(dictAdd(zs->dict,ele,&znode->score) == DICT_OK);
1428             *out_flags |= ZADD_OUT_ADDED;
1429             if (newscore) *newscore = score;
1430             return 1;
1431         } else {
1432             *out_flags |= ZADD_OUT_NOP;
1433             return 1;
1434         }
1435     } else {
1436         serverPanic("Unknown sorted set encoding");
1437     }
1438     return 0; /* Never reached. */
1439 }
1440 
1441 /* Deletes the element 'ele' from the sorted set encoded as a skiplist+dict,
1442  * returning 1 if the element existed and was deleted, 0 otherwise (the
1443  * element was not there). It does not resize the dict after deleting the
1444  * element. */
zsetRemoveFromSkiplist(zset * zs,sds ele)1445 static int zsetRemoveFromSkiplist(zset *zs, sds ele) {
1446     dictEntry *de;
1447     double score;
1448 
1449     de = dictUnlink(zs->dict,ele);
1450     if (de != NULL) {
1451         /* Get the score in order to delete from the skiplist later. */
1452         score = *(double*)dictGetVal(de);
1453 
1454         /* Delete from the hash table and later from the skiplist.
1455          * Note that the order is important: deleting from the skiplist
1456          * actually releases the SDS string representing the element,
1457          * which is shared between the skiplist and the hash table, so
1458          * we need to delete from the skiplist as the final step. */
1459         dictFreeUnlinkedEntry(zs->dict,de);
1460 
1461         /* Delete from skiplist. */
1462         int retval = zslDelete(zs->zsl,score,ele,NULL);
1463         serverAssert(retval);
1464 
1465         return 1;
1466     }
1467 
1468     return 0;
1469 }
1470 
1471 /* Delete the element 'ele' from the sorted set, returning 1 if the element
1472  * existed and was deleted, 0 otherwise (the element was not there). */
zsetDel(robj * zobj,sds ele)1473 int zsetDel(robj *zobj, sds ele) {
1474     if (zobj->encoding == OBJ_ENCODING_LISTPACK) {
1475         unsigned char *eptr;
1476 
1477         if ((eptr = zzlFind(zobj->ptr,ele,NULL)) != NULL) {
1478             zobj->ptr = zzlDelete(zobj->ptr,eptr);
1479             return 1;
1480         }
1481     } else if (zobj->encoding == OBJ_ENCODING_SKIPLIST) {
1482         zset *zs = zobj->ptr;
1483         if (zsetRemoveFromSkiplist(zs, ele)) {
1484             if (htNeedsResize(zs->dict)) dictResize(zs->dict);
1485             return 1;
1486         }
1487     } else {
1488         serverPanic("Unknown sorted set encoding");
1489     }
1490     return 0; /* No such element found. */
1491 }
1492 
1493 /* Given a sorted set object returns the 0-based rank of the object or
1494  * -1 if the object does not exist.
1495  *
1496  * For rank we mean the position of the element in the sorted collection
1497  * of elements. So the first element has rank 0, the second rank 1, and so
1498  * forth up to length-1 elements.
1499  *
1500  * If 'reverse' is false, the rank is returned considering as first element
1501  * the one with the lowest score. Otherwise if 'reverse' is non-zero
1502  * the rank is computed considering as element with rank 0 the one with
1503  * the highest score. */
zsetRank(robj * zobj,sds ele,int reverse)1504 long zsetRank(robj *zobj, sds ele, int reverse) {
1505     unsigned long llen;
1506     unsigned long rank;
1507 
1508     llen = zsetLength(zobj);
1509 
1510     if (zobj->encoding == OBJ_ENCODING_LISTPACK) {
1511         unsigned char *zl = zobj->ptr;
1512         unsigned char *eptr, *sptr;
1513 
1514         eptr = lpSeek(zl,0);
1515         serverAssert(eptr != NULL);
1516         sptr = lpNext(zl,eptr);
1517         serverAssert(sptr != NULL);
1518 
1519         rank = 1;
1520         while(eptr != NULL) {
1521             if (lpCompare(eptr,(unsigned char*)ele,sdslen(ele)))
1522                 break;
1523             rank++;
1524             zzlNext(zl,&eptr,&sptr);
1525         }
1526 
1527         if (eptr != NULL) {
1528             if (reverse)
1529                 return llen-rank;
1530             else
1531                 return rank-1;
1532         } else {
1533             return -1;
1534         }
1535     } else if (zobj->encoding == OBJ_ENCODING_SKIPLIST) {
1536         zset *zs = zobj->ptr;
1537         zskiplist *zsl = zs->zsl;
1538         dictEntry *de;
1539         double score;
1540 
1541         de = dictFind(zs->dict,ele);
1542         if (de != NULL) {
1543             score = *(double*)dictGetVal(de);
1544             rank = zslGetRank(zsl,score,ele);
1545             /* Existing elements always have a rank. */
1546             serverAssert(rank != 0);
1547             if (reverse)
1548                 return llen-rank;
1549             else
1550                 return rank-1;
1551         } else {
1552             return -1;
1553         }
1554     } else {
1555         serverPanic("Unknown sorted set encoding");
1556     }
1557 }
1558 
1559 /* This is a helper function for the COPY command.
1560  * Duplicate a sorted set object, with the guarantee that the returned object
1561  * has the same encoding as the original one.
1562  *
1563  * The resulting object always has refcount set to 1 */
zsetDup(robj * o)1564 robj *zsetDup(robj *o) {
1565     robj *zobj;
1566     zset *zs;
1567     zset *new_zs;
1568 
1569     serverAssert(o->type == OBJ_ZSET);
1570 
1571     /* Create a new sorted set object that have the same encoding as the original object's encoding */
1572     if (o->encoding == OBJ_ENCODING_LISTPACK) {
1573         unsigned char *zl = o->ptr;
1574         size_t sz = lpBytes(zl);
1575         unsigned char *new_zl = zmalloc(sz);
1576         memcpy(new_zl, zl, sz);
1577         zobj = createObject(OBJ_ZSET, new_zl);
1578         zobj->encoding = OBJ_ENCODING_LISTPACK;
1579     } else if (o->encoding == OBJ_ENCODING_SKIPLIST) {
1580         zobj = createZsetObject();
1581         zs = o->ptr;
1582         new_zs = zobj->ptr;
1583         dictExpand(new_zs->dict,dictSize(zs->dict));
1584         zskiplist *zsl = zs->zsl;
1585         zskiplistNode *ln;
1586         sds ele;
1587         long llen = zsetLength(o);
1588 
1589         /* We copy the skiplist elements from the greatest to the
1590          * smallest (that's trivial since the elements are already ordered in
1591          * the skiplist): this improves the load process, since the next loaded
1592          * element will always be the smaller, so adding to the skiplist
1593          * will always immediately stop at the head, making the insertion
1594          * O(1) instead of O(log(N)). */
1595         ln = zsl->tail;
1596         while (llen--) {
1597             ele = ln->ele;
1598             sds new_ele = sdsdup(ele);
1599             zskiplistNode *znode = zslInsert(new_zs->zsl,ln->score,new_ele);
1600             dictAdd(new_zs->dict,new_ele,&znode->score);
1601             ln = ln->backward;
1602         }
1603     } else {
1604         serverPanic("Unknown sorted set encoding");
1605     }
1606     return zobj;
1607 }
1608 
1609 /* Create a new sds string from the listpack entry. */
zsetSdsFromListpackEntry(listpackEntry * e)1610 sds zsetSdsFromListpackEntry(listpackEntry *e) {
1611     return e->sval ? sdsnewlen(e->sval, e->slen) : sdsfromlonglong(e->lval);
1612 }
1613 
1614 /* Reply with bulk string from the listpack entry. */
zsetReplyFromListpackEntry(client * c,listpackEntry * e)1615 void zsetReplyFromListpackEntry(client *c, listpackEntry *e) {
1616     if (e->sval)
1617         addReplyBulkCBuffer(c, e->sval, e->slen);
1618     else
1619         addReplyBulkLongLong(c, e->lval);
1620 }
1621 
1622 
1623 /* Return random element from a non empty zset.
1624  * 'key' and 'val' will be set to hold the element.
1625  * The memory in `key` is not to be freed or modified by the caller.
1626  * 'score' can be NULL in which case it's not extracted. */
zsetTypeRandomElement(robj * zsetobj,unsigned long zsetsize,listpackEntry * key,double * score)1627 void zsetTypeRandomElement(robj *zsetobj, unsigned long zsetsize, listpackEntry *key, double *score) {
1628     if (zsetobj->encoding == OBJ_ENCODING_SKIPLIST) {
1629         zset *zs = zsetobj->ptr;
1630         dictEntry *de = dictGetFairRandomKey(zs->dict);
1631         sds s = dictGetKey(de);
1632         key->sval = (unsigned char*)s;
1633         key->slen = sdslen(s);
1634         if (score)
1635             *score = *(double*)dictGetVal(de);
1636     } else if (zsetobj->encoding == OBJ_ENCODING_LISTPACK) {
1637         listpackEntry val;
1638         lpRandomPair(zsetobj->ptr, zsetsize, key, &val);
1639         if (score) {
1640             if (val.sval) {
1641                 *score = zzlStrtod(val.sval,val.slen);
1642             } else {
1643                 *score = (double)val.lval;
1644             }
1645         }
1646     } else {
1647         serverPanic("Unknown zset encoding");
1648     }
1649 }
1650 
1651 /*-----------------------------------------------------------------------------
1652  * Sorted set commands
1653  *----------------------------------------------------------------------------*/
1654 
1655 /* This generic command implements both ZADD and ZINCRBY. */
zaddGenericCommand(client * c,int flags)1656 void zaddGenericCommand(client *c, int flags) {
1657     static char *nanerr = "resulting score is not a number (NaN)";
1658     robj *key = c->argv[1];
1659     robj *zobj;
1660     sds ele;
1661     double score = 0, *scores = NULL;
1662     int j, elements, ch = 0;
1663     int scoreidx = 0;
1664     /* The following vars are used in order to track what the command actually
1665      * did during the execution, to reply to the client and to trigger the
1666      * notification of keyspace change. */
1667     int added = 0;      /* Number of new elements added. */
1668     int updated = 0;    /* Number of elements with updated score. */
1669     int processed = 0;  /* Number of elements processed, may remain zero with
1670                            options like XX. */
1671 
1672     /* Parse options. At the end 'scoreidx' is set to the argument position
1673      * of the score of the first score-element pair. */
1674     scoreidx = 2;
1675     while(scoreidx < c->argc) {
1676         char *opt = c->argv[scoreidx]->ptr;
1677         if (!strcasecmp(opt,"nx")) flags |= ZADD_IN_NX;
1678         else if (!strcasecmp(opt,"xx")) flags |= ZADD_IN_XX;
1679         else if (!strcasecmp(opt,"ch")) ch = 1; /* Return num of elements added or updated. */
1680         else if (!strcasecmp(opt,"incr")) flags |= ZADD_IN_INCR;
1681         else if (!strcasecmp(opt,"gt")) flags |= ZADD_IN_GT;
1682         else if (!strcasecmp(opt,"lt")) flags |= ZADD_IN_LT;
1683         else break;
1684         scoreidx++;
1685     }
1686 
1687     /* Turn options into simple to check vars. */
1688     int incr = (flags & ZADD_IN_INCR) != 0;
1689     int nx = (flags & ZADD_IN_NX) != 0;
1690     int xx = (flags & ZADD_IN_XX) != 0;
1691     int gt = (flags & ZADD_IN_GT) != 0;
1692     int lt = (flags & ZADD_IN_LT) != 0;
1693 
1694     /* After the options, we expect to have an even number of args, since
1695      * we expect any number of score-element pairs. */
1696     elements = c->argc-scoreidx;
1697     if (elements % 2 || !elements) {
1698         addReplyErrorObject(c,shared.syntaxerr);
1699         return;
1700     }
1701     elements /= 2; /* Now this holds the number of score-element pairs. */
1702 
1703     /* Check for incompatible options. */
1704     if (nx && xx) {
1705         addReplyError(c,
1706             "XX and NX options at the same time are not compatible");
1707         return;
1708     }
1709 
1710     if ((gt && nx) || (lt && nx) || (gt && lt)) {
1711         addReplyError(c,
1712             "GT, LT, and/or NX options at the same time are not compatible");
1713         return;
1714     }
1715     /* Note that XX is compatible with either GT or LT */
1716 
1717     if (incr && elements > 1) {
1718         addReplyError(c,
1719             "INCR option supports a single increment-element pair");
1720         return;
1721     }
1722 
1723     /* Start parsing all the scores, we need to emit any syntax error
1724      * before executing additions to the sorted set, as the command should
1725      * either execute fully or nothing at all. */
1726     scores = zmalloc(sizeof(double)*elements);
1727     for (j = 0; j < elements; j++) {
1728         if (getDoubleFromObjectOrReply(c,c->argv[scoreidx+j*2],&scores[j],NULL)
1729             != C_OK) goto cleanup;
1730     }
1731 
1732     /* Lookup the key and create the sorted set if does not exist. */
1733     zobj = lookupKeyWrite(c->db,key);
1734     if (checkType(c,zobj,OBJ_ZSET)) goto cleanup;
1735     if (zobj == NULL) {
1736         if (xx) goto reply_to_client; /* No key + XX option: nothing to do. */
1737         if (server.zset_max_listpack_entries == 0 ||
1738             server.zset_max_listpack_value < sdslen(c->argv[scoreidx+1]->ptr))
1739         {
1740             zobj = createZsetObject();
1741         } else {
1742             zobj = createZsetListpackObject();
1743         }
1744         dbAdd(c->db,key,zobj);
1745     }
1746 
1747     for (j = 0; j < elements; j++) {
1748         double newscore;
1749         score = scores[j];
1750         int retflags = 0;
1751 
1752         ele = c->argv[scoreidx+1+j*2]->ptr;
1753         int retval = zsetAdd(zobj, score, ele, flags, &retflags, &newscore);
1754         if (retval == 0) {
1755             addReplyError(c,nanerr);
1756             goto cleanup;
1757         }
1758         if (retflags & ZADD_OUT_ADDED) added++;
1759         if (retflags & ZADD_OUT_UPDATED) updated++;
1760         if (!(retflags & ZADD_OUT_NOP)) processed++;
1761         score = newscore;
1762     }
1763     server.dirty += (added+updated);
1764 
1765 reply_to_client:
1766     if (incr) { /* ZINCRBY or INCR option. */
1767         if (processed)
1768             addReplyDouble(c,score);
1769         else
1770             addReplyNull(c);
1771     } else { /* ZADD. */
1772         addReplyLongLong(c,ch ? added+updated : added);
1773     }
1774 
1775 cleanup:
1776     zfree(scores);
1777     if (added || updated) {
1778         signalModifiedKey(c,c->db,key);
1779         notifyKeyspaceEvent(NOTIFY_ZSET,
1780             incr ? "zincr" : "zadd", key, c->db->id);
1781     }
1782 }
1783 
/* ZADD entry point: runs the generic implementation with no preset option
 * flags, so every option comes from the command arguments. */
void zaddCommand(client *c) {
    int preset_flags = ZADD_IN_NONE;
    zaddGenericCommand(c,preset_flags);
}
1787 
/* ZINCRBY entry point: runs the generic ZADD implementation with the INCR
 * flag preset. */
void zincrbyCommand(client *c) {
    int preset_flags = ZADD_IN_INCR;
    zaddGenericCommand(c,preset_flags);
}
1791 
/* ZREM key member [member ...]
 *
 * Removes the given members from the sorted set at 'key' and replies with the
 * number of members actually removed. A missing key replies with zero. When
 * the set becomes empty the key is deleted and the remaining arguments are
 * not examined. */
void zremCommand(client *c) {
    robj *key = c->argv[1];
    robj *zobj = lookupKeyWriteOrReply(c,key,shared.czero);
    int deleted = 0, keyremoved = 0, j;

    if (zobj == NULL || checkType(c,zobj,OBJ_ZSET)) return;

    /* Stop as soon as the key is gone: there is nothing left to remove. */
    for (j = 2; j < c->argc && !keyremoved; j++) {
        if (zsetDel(zobj,c->argv[j]->ptr)) deleted++;
        if (zsetLength(zobj) == 0) {
            dbDelete(c->db,key);
            keyremoved = 1;
        }
    }

    if (deleted != 0) {
        notifyKeyspaceEvent(NOTIFY_ZSET,"zrem",key,c->db->id);
        if (keyremoved) {
            notifyKeyspaceEvent(NOTIFY_GENERIC,"del",key,c->db->id);
        }
        signalModifiedKey(c,c->db,key);
        server.dirty += deleted;
    }
    addReplyLongLong(c,deleted);
}
1818 
/* Kind of range a range-based sorted set command operates on. */
typedef enum {
    ZRANGE_AUTO  = 0,   /* No explicit kind selected. */
    ZRANGE_RANK  = 1,   /* Range expressed as element ranks (indexes). */
    ZRANGE_SCORE = 2,   /* Range expressed as min/max scores. */
    ZRANGE_LEX   = 3,   /* Range expressed as lexicographical bounds. */
} zrange_type;
1825 
1826 /* Implements ZREMRANGEBYRANK, ZREMRANGEBYSCORE, ZREMRANGEBYLEX commands. */
zremrangeGenericCommand(client * c,zrange_type rangetype)1827 void zremrangeGenericCommand(client *c, zrange_type rangetype) {
1828     robj *key = c->argv[1];
1829     robj *zobj;
1830     int keyremoved = 0;
1831     unsigned long deleted = 0;
1832     zrangespec range;
1833     zlexrangespec lexrange;
1834     long start, end, llen;
1835     char *notify_type = NULL;
1836 
1837     /* Step 1: Parse the range. */
1838     if (rangetype == ZRANGE_RANK) {
1839         notify_type = "zremrangebyrank";
1840         if ((getLongFromObjectOrReply(c,c->argv[2],&start,NULL) != C_OK) ||
1841             (getLongFromObjectOrReply(c,c->argv[3],&end,NULL) != C_OK))
1842             return;
1843     } else if (rangetype == ZRANGE_SCORE) {
1844         notify_type = "zremrangebyscore";
1845         if (zslParseRange(c->argv[2],c->argv[3],&range) != C_OK) {
1846             addReplyError(c,"min or max is not a float");
1847             return;
1848         }
1849     } else if (rangetype == ZRANGE_LEX) {
1850         notify_type = "zremrangebylex";
1851         if (zslParseLexRange(c->argv[2],c->argv[3],&lexrange) != C_OK) {
1852             addReplyError(c,"min or max not valid string range item");
1853             return;
1854         }
1855     } else {
1856         serverPanic("unknown rangetype %d", (int)rangetype);
1857     }
1858 
1859     /* Step 2: Lookup & range sanity checks if needed. */
1860     if ((zobj = lookupKeyWriteOrReply(c,key,shared.czero)) == NULL ||
1861         checkType(c,zobj,OBJ_ZSET)) goto cleanup;
1862 
1863     if (rangetype == ZRANGE_RANK) {
1864         /* Sanitize indexes. */
1865         llen = zsetLength(zobj);
1866         if (start < 0) start = llen+start;
1867         if (end < 0) end = llen+end;
1868         if (start < 0) start = 0;
1869 
1870         /* Invariant: start >= 0, so this test will be true when end < 0.
1871          * The range is empty when start > end or start >= length. */
1872         if (start > end || start >= llen) {
1873             addReply(c,shared.czero);
1874             goto cleanup;
1875         }
1876         if (end >= llen) end = llen-1;
1877     }
1878 
1879     /* Step 3: Perform the range deletion operation. */
1880     if (zobj->encoding == OBJ_ENCODING_LISTPACK) {
1881         switch(rangetype) {
1882         case ZRANGE_AUTO:
1883         case ZRANGE_RANK:
1884             zobj->ptr = zzlDeleteRangeByRank(zobj->ptr,start+1,end+1,&deleted);
1885             break;
1886         case ZRANGE_SCORE:
1887             zobj->ptr = zzlDeleteRangeByScore(zobj->ptr,&range,&deleted);
1888             break;
1889         case ZRANGE_LEX:
1890             zobj->ptr = zzlDeleteRangeByLex(zobj->ptr,&lexrange,&deleted);
1891             break;
1892         }
1893         if (zzlLength(zobj->ptr) == 0) {
1894             dbDelete(c->db,key);
1895             keyremoved = 1;
1896         }
1897     } else if (zobj->encoding == OBJ_ENCODING_SKIPLIST) {
1898         zset *zs = zobj->ptr;
1899         switch(rangetype) {
1900         case ZRANGE_AUTO:
1901         case ZRANGE_RANK:
1902             deleted = zslDeleteRangeByRank(zs->zsl,start+1,end+1,zs->dict);
1903             break;
1904         case ZRANGE_SCORE:
1905             deleted = zslDeleteRangeByScore(zs->zsl,&range,zs->dict);
1906             break;
1907         case ZRANGE_LEX:
1908             deleted = zslDeleteRangeByLex(zs->zsl,&lexrange,zs->dict);
1909             break;
1910         }
1911         if (htNeedsResize(zs->dict)) dictResize(zs->dict);
1912         if (dictSize(zs->dict) == 0) {
1913             dbDelete(c->db,key);
1914             keyremoved = 1;
1915         }
1916     } else {
1917         serverPanic("Unknown sorted set encoding");
1918     }
1919 
1920     /* Step 4: Notifications and reply. */
1921     if (deleted) {
1922         signalModifiedKey(c,c->db,key);
1923         notifyKeyspaceEvent(NOTIFY_ZSET,notify_type,key,c->db->id);
1924         if (keyremoved)
1925             notifyKeyspaceEvent(NOTIFY_GENERIC,"del",key,c->db->id);
1926     }
1927     server.dirty += deleted;
1928     addReplyLongLong(c,deleted);
1929 
1930 cleanup:
1931     if (rangetype == ZRANGE_LEX) zslFreeLexRange(&lexrange);
1932 }
1933 
/* ZREMRANGEBYRANK entry point: generic range removal, by rank. */
void zremrangebyrankCommand(client *c) {
    zrange_type kind = ZRANGE_RANK;
    zremrangeGenericCommand(c,kind);
}
1937 
/* ZREMRANGEBYSCORE entry point: generic range removal, by score. */
void zremrangebyscoreCommand(client *c) {
    zrange_type kind = ZRANGE_SCORE;
    zremrangeGenericCommand(c,kind);
}
1941 
/* ZREMRANGEBYLEX entry point: generic range removal, by lexicographical
 * range. */
void zremrangebylexCommand(client *c) {
    zrange_type kind = ZRANGE_LEX;
    zremrangeGenericCommand(c,kind);
}
1945 
1946 typedef struct {
1947     robj *subject;
1948     int type; /* Set, sorted set */
1949     int encoding;
1950     double weight;
1951 
1952     union {
1953         /* Set iterators. */
1954         union _iterset {
1955             struct {
1956                 intset *is;
1957                 int ii;
1958             } is;
1959             struct {
1960                 dict *dict;
1961                 dictIterator *di;
1962                 dictEntry *de;
1963             } ht;
1964         } set;
1965 
1966         /* Sorted set iterators. */
1967         union _iterzset {
1968             struct {
1969                 unsigned char *zl;
1970                 unsigned char *eptr, *sptr;
1971             } zl;
1972             struct {
1973                 zset *zs;
1974                 zskiplistNode *node;
1975             } sl;
1976         } zset;
1977     } iter;
1978 } zsetopsrc;
1979 
1980 
/* Flags stored in zsetopval.flags.
 *
 * OPVAL_DIRTY_SDS marks an sds in 'ele' that must be released before the
 * zsetopval is reused for the next element.
 *
 * OPVAL_DIRTY_LL is special, since long long values need no cleanup: it
 * records that "ell" was already checked, i.e. we already know whether it
 * holds a long long or already tried to convert another representation into
 * one. OPVAL_VALID_LL is set in addition when that check succeeded. */
#define OPVAL_DIRTY_SDS (1<<0)
#define OPVAL_DIRTY_LL (1<<1)
#define OPVAL_VALID_LL (1<<2)
1990 
/* Store value retrieved from the iterator.
 *
 * An element can be available in up to three representations at once: an sds
 * ('ele'), a raw buffer ('estr'/'elen') and an integer ('ell'). The zui*FromValue
 * helpers fill in a missing representation on demand. */
typedef struct {
    int flags;              /* OPVAL_* bits. */
    unsigned char _buf[32]; /* Private buffer. Receives the textual form of
                             * 'ell' when a raw buffer is requested. */
    sds ele;                /* Element as sds, or NULL. Owned by this struct
                             * only while OPVAL_DIRTY_SDS is set. */
    unsigned char *estr;    /* Element as raw bytes, or NULL. */
    unsigned int elen;      /* Length of 'estr'. */
    long long ell;          /* Element as integer. */
    double score;           /* Score associated with the element. */
} zsetopval;
2001 
/* Short names for the iterator state unions declared inside zsetopsrc. */
typedef union _iterset iterset;
typedef union _iterzset iterzset;
2004 
/* Position the iterator embedded in 'op' on the first element to visit.
 * Does nothing when the source has no subject. Panics on an unsupported type
 * or encoding. */
void zuiInitIterator(zsetopsrc *op) {
    if (op->subject == NULL) return;

    if (op->type == OBJ_ZSET) {
        /* Sorted sets are traversed in reverse order to optimize for
         * the insertion of the elements in a new list as in
         * ZDIFF/ZINTER/ZUNION */
        iterzset *zit = &op->iter.zset;

        if (op->encoding == OBJ_ENCODING_SKIPLIST) {
            zit->sl.zs = op->subject->ptr;
            zit->sl.node = zit->sl.zs->zsl->tail;
        } else if (op->encoding == OBJ_ENCODING_LISTPACK) {
            unsigned char *lp = op->subject->ptr;
            unsigned char *last_ele = lpSeek(lp,-2);

            zit->zl.zl = lp;
            zit->zl.eptr = last_ele;
            /* Entry -2 is the last element; its score is the entry after. */
            if (last_ele != NULL) {
                zit->zl.sptr = lpNext(lp,last_ele);
                serverAssert(zit->zl.sptr != NULL);
            }
        } else {
            serverPanic("Unknown sorted set encoding");
        }
    } else if (op->type == OBJ_SET) {
        iterset *sit = &op->iter.set;

        if (op->encoding == OBJ_ENCODING_HT) {
            sit->ht.dict = op->subject->ptr;
            sit->ht.di = dictGetIterator(op->subject->ptr);
            /* Pre-fetch the first entry; zuiNext() consumes it. */
            sit->ht.de = dictNext(sit->ht.di);
        } else if (op->encoding == OBJ_ENCODING_INTSET) {
            sit->is.is = op->subject->ptr;
            sit->is.ii = 0;
        } else {
            serverPanic("Unknown set encoding");
        }
    } else {
        serverPanic("Unsupported type");
    }
}
2043 
/* Release the resources held by the iterator embedded in 'op'. Only the hash
 * table set iterator owns something (its dictIterator); the other encodings
 * need no cleanup. Does nothing when the source has no subject. Panics on an
 * unsupported type or encoding. */
void zuiClearIterator(zsetopsrc *op) {
    if (op->subject == NULL) return;

    if (op->type == OBJ_SET) {
        if (op->encoding == OBJ_ENCODING_HT) {
            dictReleaseIterator(op->iter.set.ht.di);
        } else if (op->encoding != OBJ_ENCODING_INTSET) {
            serverPanic("Unknown set encoding");
        }
    } else if (op->type == OBJ_ZSET) {
        if (op->encoding != OBJ_ENCODING_LISTPACK &&
            op->encoding != OBJ_ENCODING_SKIPLIST)
        {
            serverPanic("Unknown sorted set encoding");
        }
    } else {
        serverPanic("Unsupported type");
    }
}
2070 
/* Free the sds held in 'val' when it is flagged as OPVAL_DIRTY_SDS, then
 * clear both the pointer and the flag. No-op when the flag is not set. */
void zuiDiscardDirtyValue(zsetopval *val) {
    if (!(val->flags & OPVAL_DIRTY_SDS)) return;

    sdsfree(val->ele);
    val->ele = NULL;
    val->flags &= ~OPVAL_DIRTY_SDS;
}
2078 
/* Return the number of elements in the source 'op'; zero when it has no
 * subject. Panics on an unsupported type or encoding. */
unsigned long zuiLength(zsetopsrc *op) {
    if (op->subject == NULL) return 0;

    void *payload = op->subject->ptr;

    if (op->type == OBJ_ZSET) {
        if (op->encoding == OBJ_ENCODING_SKIPLIST) {
            zset *zs = payload;
            return zs->zsl->length;
        } else if (op->encoding == OBJ_ENCODING_LISTPACK) {
            return zzlLength(payload);
        } else {
            serverPanic("Unknown sorted set encoding");
        }
    } else if (op->type == OBJ_SET) {
        if (op->encoding == OBJ_ENCODING_HT) {
            dict *ht = payload;
            return dictSize(ht);
        } else if (op->encoding == OBJ_ENCODING_INTSET) {
            return intsetLen(payload);
        } else {
            serverPanic("Unknown set encoding");
        }
    } else {
        serverPanic("Unsupported type");
    }
}
2105 
2106 /* Check if the current value is valid. If so, store it in the passed structure
2107  * and move to the next element. If not valid, this means we have reached the
2108  * end of the structure and can abort. */
zuiNext(zsetopsrc * op,zsetopval * val)2109 int zuiNext(zsetopsrc *op, zsetopval *val) {
2110     if (op->subject == NULL)
2111         return 0;
2112 
2113     zuiDiscardDirtyValue(val);
2114 
2115     memset(val,0,sizeof(zsetopval));
2116 
2117     if (op->type == OBJ_SET) {
2118         iterset *it = &op->iter.set;
2119         if (op->encoding == OBJ_ENCODING_INTSET) {
2120             int64_t ell;
2121 
2122             if (!intsetGet(it->is.is,it->is.ii,&ell))
2123                 return 0;
2124             val->ell = ell;
2125             val->score = 1.0;
2126 
2127             /* Move to next element. */
2128             it->is.ii++;
2129         } else if (op->encoding == OBJ_ENCODING_HT) {
2130             if (it->ht.de == NULL)
2131                 return 0;
2132             val->ele = dictGetKey(it->ht.de);
2133             val->score = 1.0;
2134 
2135             /* Move to next element. */
2136             it->ht.de = dictNext(it->ht.di);
2137         } else {
2138             serverPanic("Unknown set encoding");
2139         }
2140     } else if (op->type == OBJ_ZSET) {
2141         iterzset *it = &op->iter.zset;
2142         if (op->encoding == OBJ_ENCODING_LISTPACK) {
2143             /* No need to check both, but better be explicit. */
2144             if (it->zl.eptr == NULL || it->zl.sptr == NULL)
2145                 return 0;
2146             val->estr = lpGetValue(it->zl.eptr,&val->elen,&val->ell);
2147             val->score = zzlGetScore(it->zl.sptr);
2148 
2149             /* Move to next element (going backwards, see zuiInitIterator). */
2150             zzlPrev(it->zl.zl,&it->zl.eptr,&it->zl.sptr);
2151         } else if (op->encoding == OBJ_ENCODING_SKIPLIST) {
2152             if (it->sl.node == NULL)
2153                 return 0;
2154             val->ele = it->sl.node->ele;
2155             val->score = it->sl.node->score;
2156 
2157             /* Move to next element. (going backwards, see zuiInitIterator) */
2158             it->sl.node = it->sl.node->backward;
2159         } else {
2160             serverPanic("Unknown sorted set encoding");
2161         }
2162     } else {
2163         serverPanic("Unsupported type");
2164     }
2165     return 1;
2166 }
2167 
/* Make sure 'val' was checked for an integer representation and return non
 * zero when 'val->ell' holds a valid long long.
 *
 * The check runs once per value: OPVAL_DIRTY_LL records that it was done and
 * OPVAL_VALID_LL caches a positive outcome. When the value has neither an sds
 * nor a raw buffer, 'ell' is assumed to be already set. */
int zuiLongLongFromValue(zsetopval *val) {
    if (val->flags & OPVAL_DIRTY_LL) return val->flags & OPVAL_VALID_LL;

    val->flags |= OPVAL_DIRTY_LL;

    int valid;
    if (val->ele != NULL) {
        valid = string2ll(val->ele,sdslen(val->ele),&val->ell);
    } else if (val->estr != NULL) {
        valid = string2ll((char*)val->estr,val->elen,&val->ell);
    } else {
        /* The long long was already set, flag as valid. */
        valid = 1;
    }
    if (valid) val->flags |= OPVAL_VALID_LL;

    return val->flags & OPVAL_VALID_LL;
}
2185 
zuiSdsFromValue(zsetopval * val)2186 sds zuiSdsFromValue(zsetopval *val) {
2187     if (val->ele == NULL) {
2188         if (val->estr != NULL) {
2189             val->ele = sdsnewlen((char*)val->estr,val->elen);
2190         } else {
2191             val->ele = sdsfromlonglong(val->ell);
2192         }
2193         val->flags |= OPVAL_DIRTY_SDS;
2194     }
2195     return val->ele;
2196 }
2197 
2198 /* This is different from zuiSdsFromValue since returns a new SDS string
2199  * which is up to the caller to free. */
zuiNewSdsFromValue(zsetopval * val)2200 sds zuiNewSdsFromValue(zsetopval *val) {
2201     if (val->flags & OPVAL_DIRTY_SDS) {
2202         /* We have already one to return! */
2203         sds ele = val->ele;
2204         val->flags &= ~OPVAL_DIRTY_SDS;
2205         val->ele = NULL;
2206         return ele;
2207     } else if (val->ele) {
2208         return sdsdup(val->ele);
2209     } else if (val->estr) {
2210         return sdsnewlen((char*)val->estr,val->elen);
2211     } else {
2212         return sdsfromlonglong(val->ell);
2213     }
2214 }
2215 
/* Make sure 'val' exposes its element as a raw buffer in 'estr'/'elen'.
 * If no buffer is set yet it is aliased to the sds when one exists, otherwise
 * the integer is rendered into the private '_buf'. Always returns 1. */
int zuiBufferFromValue(zsetopval *val) {
    if (val->estr != NULL) return 1;

    if (val->ele == NULL) {
        val->elen = ll2string((char*)val->_buf,sizeof(val->_buf),val->ell);
        val->estr = val->_buf;
    } else {
        val->elen = sdslen(val->ele);
        val->estr = (unsigned char*)val->ele;
    }
    return 1;
}
2228 
2229 /* Find value pointed to by val in the source pointer to by op. When found,
2230  * return 1 and store its score in target. Return 0 otherwise. */
zuiFind(zsetopsrc * op,zsetopval * val,double * score)2231 int zuiFind(zsetopsrc *op, zsetopval *val, double *score) {
2232     if (op->subject == NULL)
2233         return 0;
2234 
2235     if (op->type == OBJ_SET) {
2236         if (op->encoding == OBJ_ENCODING_INTSET) {
2237             if (zuiLongLongFromValue(val) &&
2238                 intsetFind(op->subject->ptr,val->ell))
2239             {
2240                 *score = 1.0;
2241                 return 1;
2242             } else {
2243                 return 0;
2244             }
2245         } else if (op->encoding == OBJ_ENCODING_HT) {
2246             dict *ht = op->subject->ptr;
2247             zuiSdsFromValue(val);
2248             if (dictFind(ht,val->ele) != NULL) {
2249                 *score = 1.0;
2250                 return 1;
2251             } else {
2252                 return 0;
2253             }
2254         } else {
2255             serverPanic("Unknown set encoding");
2256         }
2257     } else if (op->type == OBJ_ZSET) {
2258         zuiSdsFromValue(val);
2259 
2260         if (op->encoding == OBJ_ENCODING_LISTPACK) {
2261             if (zzlFind(op->subject->ptr,val->ele,score) != NULL) {
2262                 /* Score is already set by zzlFind. */
2263                 return 1;
2264             } else {
2265                 return 0;
2266             }
2267         } else if (op->encoding == OBJ_ENCODING_SKIPLIST) {
2268             zset *zs = op->subject->ptr;
2269             dictEntry *de;
2270             if ((de = dictFind(zs->dict,val->ele)) != NULL) {
2271                 *score = *(double*)dictGetVal(de);
2272                 return 1;
2273             } else {
2274                 return 0;
2275             }
2276         } else {
2277             serverPanic("Unknown sorted set encoding");
2278         }
2279     } else {
2280         serverPanic("Unsupported type");
2281     }
2282 }
2283 
zuiCompareByCardinality(const void * s1,const void * s2)2284 int zuiCompareByCardinality(const void *s1, const void *s2) {
2285     unsigned long first = zuiLength((zsetopsrc*)s1);
2286     unsigned long second = zuiLength((zsetopsrc*)s2);
2287     if (first > second) return 1;
2288     if (first < second) return -1;
2289     return 0;
2290 }
2291 
/* Same as zuiCompareByCardinality() with the sign flipped, so that sorting
 * yields decreasing cardinality. */
static int zuiCompareByRevCardinality(const void *s1, const void *s2) {
    int ascending = zuiCompareByCardinality(s1, s2);
    return -ascending;
}
2295 
/* Aggregation modes understood by zunionInterAggregate(). */
#define REDIS_AGGR_SUM 1
#define REDIS_AGGR_MIN 2
#define REDIS_AGGR_MAX 3
/* Score carried by dict entry '_e': the double its value points to, or 1.0
 * when the entry has no value. */
#define zunionInterDictValue(_e) (dictGetVal(_e) != NULL ? *(double*)dictGetVal(_e) : 1.0)
2300 
zunionInterAggregate(double * target,double val,int aggregate)2301 inline static void zunionInterAggregate(double *target, double val, int aggregate) {
2302     if (aggregate == REDIS_AGGR_SUM) {
2303         *target = *target + val;
2304         /* The result of adding two doubles is NaN when one variable
2305          * is +inf and the other is -inf. When these numbers are added,
2306          * we maintain the convention of the result being 0.0. */
2307         if (isnan(*target)) *target = 0.0;
2308     } else if (aggregate == REDIS_AGGR_MIN) {
2309         *target = val < *target ? val : *target;
2310     } else if (aggregate == REDIS_AGGR_MAX) {
2311         *target = val > *target ? val : *target;
2312     } else {
2313         /* safety net */
2314         serverPanic("Unknown ZUNION/INTER aggregate type");
2315     }
2316 }
2317 
/* Scan every key of dict 'd' and return the length of the longest one (0 for
 * an empty dict). When 'totallen' is not NULL the length of every key is
 * added to '*totallen'; the caller is responsible for its initial value. */
static size_t zsetDictGetMaxElementLength(dict *d, size_t *totallen) {
    size_t longest = 0;
    dictIterator *iter = dictGetIterator(d);
    dictEntry *entry;

    for (entry = dictNext(iter); entry != NULL; entry = dictNext(iter)) {
        sds ele = dictGetKey(entry);
        size_t len = sdslen(ele);

        if (len > longest) longest = len;
        if (totallen != NULL) *totallen += len;
    }
    dictReleaseIterator(iter);

    return longest;
}
2336 
/* DIFF Algorithm 1.
 *
 * Iterate all the elements of the first set and add an element to the target
 * set only if it does not exist in any of the other sets.
 *
 * This performs at most N*M operations, where N is the size of the first set
 * and M the number of sets, plus an O(K*log(K)) cost for adding the resulting
 * elements to the target set, where K is the final size of the target set.
 * The overall complexity is O(N*M + K*log(K)).
 *
 * src[1..setnum-1] is reordered in place. '*maxelelen' is raised to the
 * longest added element and '*totelelen' is increased by the length of every
 * added element. */
static void zdiffAlgorithm1(zsetopsrc *src, long setnum, zset *dstzset, size_t *maxelelen, size_t *totelelen) {
    zsetopval cur;
    int j;

    /* Subtract the biggest sets first: this makes it more likely to find a
     * duplicated element ASAP. */
    qsort(src+1,setnum-1,sizeof(zsetopsrc),zuiCompareByRevCardinality);

    memset(&cur, 0, sizeof(cur));
    zuiInitIterator(&src[0]);
    while (zuiNext(&src[0],&cur)) {
        double ignored_score;
        int found = 0;

        for (j = 1; j < setnum && !found; j++) {
            /* It is not safe to access the zset we are iterating, so an
             * input that is the very same object as the first one counts as
             * a match without being searched. This check isn't really
             * needed anymore since duplicate sets are already detected in
             * zsetChooseDiffAlgorithm, but it is kept for future-proofing. */
            if (src[j].subject == src[0].subject) {
                found = 1;
            } else if (zuiFind(&src[j],&cur,&ignored_score)) {
                found = 1;
            }
        }
        if (found) continue;

        /* Not present anywhere else: the element belongs to the result. */
        sds ele = zuiNewSdsFromValue(&cur);
        size_t elelen = sdslen(ele);
        zskiplistNode *node = zslInsert(dstzset->zsl,cur.score,ele);

        dictAdd(dstzset->dict,ele,&node->score);
        if (elelen > *maxelelen) *maxelelen = elelen;
        *totelelen += elelen;
    }
    zuiClearIterator(&src[0]);
}
2390 
2391 
/* DIFF Algorithm 2.
 *
 * 'src' holds 'setnum' inputs; the result is built directly in 'dstzset'.
 * On return '*maxelelen' is the length of the longest element left in the
 * result and '*totelelen' has been increased by the length of every such
 * element. */
static void zdiffAlgorithm2(zsetopsrc *src, long setnum, zset *dstzset, size_t *maxelelen, size_t *totelelen) {
    /* DIFF Algorithm 2:
     *
     * Add all the elements of the first set to the auxiliary set.
     * Then remove all the elements of all the next sets from it.
     *
     * This is O(L + (N-K)log(N)) where L is the sum of all the elements in every
     * set, N is the size of the first set, and K is the size of the result set.
     *
     * Note that from the (L-N) dict searches, (N-K) got to the zsetRemoveFromSkiplist
     * which costs log(N)
     *
     * There is also a O(K) cost at the end for finding the largest element
     * size, but this doesn't change the algorithm complexity since K < L, and
     * O(2L) is the same as O(L). */
    int j;
    int cardinality = 0;
    zsetopval zval;
    zskiplistNode *znode;
    sds tmp;

    for (j = 0; j < setnum; j++) {
        if (zuiLength(&src[j]) == 0) continue;

        memset(&zval, 0, sizeof(zval));
        zuiInitIterator(&src[j]);
        while (zuiNext(&src[j],&zval)) {
            if (j == 0) {
                /* First set: every element goes into the result. The sds is
                 * owned by the result from here on. */
                tmp = zuiNewSdsFromValue(&zval);
                znode = zslInsert(dstzset->zsl,zval.score,tmp);
                dictAdd(dstzset->dict,tmp,&znode->score);
                cardinality++;
            } else {
                /* Other sets: subtract the element. The sds may have been
                 * allocated by zuiSdsFromValue(), in which case it stays
                 * owned by zval (OPVAL_DIRTY_SDS). */
                tmp = zuiSdsFromValue(&zval);
                if (zsetRemoveFromSkiplist(dstzset, tmp)) {
                    cardinality--;
                }
            }

            /* Exit if result set is empty as any additional removal
             * of elements will have no effect. zuiNext() is not going to be
             * called again for this value, so release the sds zval may own
             * here: otherwise it would leak. */
            if (cardinality == 0) {
                zuiDiscardDirtyValue(&zval);
                break;
            }
        }
        zuiClearIterator(&src[j]);

        if (cardinality == 0) break;
    }

    /* Resize dict if needed after removing multiple elements */
    if (htNeedsResize(dstzset->dict)) dictResize(dstzset->dict);

    /* Using this algorithm, we can't calculate the max element as we go,
     * we have to iterate through all elements to find the max one after. */
    *maxelelen = zsetDictGetMaxElementLength(dstzset->dict, totelelen);
}
2448 
/* Pick the DIFF strategy that looks cheaper for the given inputs.
 *
 * Algorithm 1 is O(N*M + K*log(K)) where N is the size of the first set,
 * M the total number of sets, and K is the size of the result set.
 *
 * Algorithm 2 is O(L + (N-K)log(N)) where L is the total number of elements
 * in all the sets, N is the size of the first set, and K is the size of the
 * result set.
 *
 * Returns 1 or 2 to name the algorithm to run, or 0 when one of the later
 * inputs is the very same object as the first one: subtracting a set from
 * itself leaves nothing, so no work is needed at all. */
static int zsetChooseDiffAlgorithm(zsetopsrc *src, long setnum) {
    long long cost_algo1 = 0;
    long long cost_algo2 = 0;
    int idx;

    for (idx = 0; idx < setnum; idx++) {
        /* Same object as the first input: the result is empty. */
        if (idx > 0 && src[0].subject == src[idx].subject) return 0;

        /* Algorithm 1 looks up each element of the first set once per
         * input; algorithm 2 walks every input exactly once. */
        cost_algo1 += zuiLength(&src[0]);
        cost_algo2 += zuiLength(&src[idx]);
    }

    /* Algorithm 1 has better constant times and performs fewer operations
     * if there are elements in common, so halve its estimated cost. */
    cost_algo1 /= 2;

    if (cost_algo1 <= cost_algo2) return 1;
    return 2;
}
2482 
/* Compute src[0] minus all the other inputs into 'dstzset', dispatching to
 * the algorithm selected by zsetChooseDiffAlgorithm(). 'maxelelen' and
 * 'totelelen' are forwarded to the chosen algorithm. */
static void zdiff(zsetopsrc *src, long setnum, zset *dstzset, size_t *maxelelen, size_t *totelelen) {
    int diff_algo;

    /* An empty first input means an empty difference: nothing to do. */
    if (!(zuiLength(&src[0]) > 0)) return;

    diff_algo = zsetChooseDiffAlgorithm(src, setnum);
    switch (diff_algo) {
    case 0:
        /* The chooser already established that the result is empty. */
        break;
    case 1:
        zdiffAlgorithm1(src, setnum, dstzset, maxelelen, totelelen);
        break;
    case 2:
        zdiffAlgorithm2(src, setnum, dstzset, maxelelen, totelelen);
        break;
    default:
        serverPanic("Unknown algorithm");
    }
}
2496 
/* Dictionary type of the temporary "accumulator" that the UNION branch of
 * zunionInterDiffGenericCommand() creates with dictCreate(). It hashes and
 * compares keys as sds strings and installs no dup or destructor callbacks.
 * NOTE(review): with no key destructor the dict presumably never frees its
 * keys, which matches the union code handing the very same sds pointers over
 * to the destination sorted set -- confirm against the dict implementation. */
dictType setAccumulatorDictType = {
    dictSdsHash,               /* hash function */
    NULL,                      /* key dup */
    NULL,                      /* val dup */
    dictSdsKeyCompare,         /* key compare */
    NULL,                      /* key destructor */
    NULL,                      /* val destructor */
    NULL                       /* allow to expand */
};
2506 
/* The zunionInterDiffGenericCommand() function is called in order to implement the
 * following commands: ZUNION, ZINTER, ZDIFF, ZUNIONSTORE, ZINTERSTORE, ZDIFFSTORE,
 * ZINTERCARD.
 *
 * 'dstkey' is the key that receives the result for the *STORE variants. It is
 * NULL for the other commands, whose result is sent back to the client.
 *
 * 'numkeysIndex' is the position in c->argv of the "numkeys" argument. For the
 * ZUNION/ZINTER/ZDIFF/ZINTERCARD commands this value is 1, for the
 * ZUNIONSTORE/ZINTERSTORE/ZDIFFSTORE commands this value is 2. The input keys
 * are read starting from the argument right after it.
 *
 * 'op' SET_OP_INTER, SET_OP_UNION or SET_OP_DIFF.
 *
 * 'cardinality_only' is currently only applicable when 'op' is SET_OP_INTER.
 * It is used by ZINTERCARD: only the cardinality is returned, with minimum
 * processing and memory overheads (matching elements are counted, not stored).
 */
void zunionInterDiffGenericCommand(client *c, robj *dstkey, int numkeysIndex, int op,
                                   int cardinality_only) {
    int i, j;
    long setnum;
    int aggregate = REDIS_AGGR_SUM;
    zsetopsrc *src;
    zsetopval zval;
    sds tmp;
    size_t maxelelen = 0, totelelen = 0;
    robj *dstobj;
    zset *dstzset;
    zskiplistNode *znode;
    int withscores = 0;
    unsigned long cardinality = 0;
    long limit = 0; /* Stop searching after reaching the limit. 0 means unlimited. */

    /* expect setnum input keys to be given */
    if ((getLongFromObjectOrReply(c, c->argv[numkeysIndex], &setnum, NULL) != C_OK))
        return;

    if (setnum < 1) {
        addReplyErrorFormat(c,
            "at least 1 input key is needed for %s", c->cmd->name);
        return;
    }

    /* Reject a key count larger than the number of arguments that actually
     * follow "numkeys". */
    if (setnum > (c->argc-(numkeysIndex+1))) {
        addReplyErrorObject(c,shared.syntaxerr);
        return;
    }

    /* read keys to be used for input: both sorted sets and plain sets are
     * accepted, a missing key is recorded with a NULL subject. */
    src = zcalloc(sizeof(zsetopsrc) * setnum);
    for (i = 0, j = numkeysIndex+1; i < setnum; i++, j++) {
        robj *obj = lookupKeyRead(c->db, c->argv[j]);
        if (obj != NULL) {
            if (obj->type != OBJ_ZSET && obj->type != OBJ_SET) {
                zfree(src);
                addReplyErrorObject(c,shared.wrongtypeerr);
                return;
            }

            src[i].subject = obj;
            src[i].type = obj->type;
            src[i].encoding = obj->encoding;
        } else {
            src[i].subject = NULL;
        }

        /* Default all weights to 1. */
        src[i].weight = 1.0;
    }

    /* parse optional extra arguments; 'j' now indexes the first argument
     * after the input keys. Every error path below releases 'src'. */
    if (j < c->argc) {
        int remaining = c->argc - j;

        while (remaining) {
            /* WEIGHTS w1 .. wN: exactly one weight per input key. Not
             * available for DIFF nor in cardinality-only mode. */
            if (op != SET_OP_DIFF && !cardinality_only &&
                remaining >= (setnum + 1) &&
                !strcasecmp(c->argv[j]->ptr,"weights"))
            {
                j++; remaining--;
                for (i = 0; i < setnum; i++, j++, remaining--) {
                    if (getDoubleFromObjectOrReply(c,c->argv[j],&src[i].weight,
                            "weight value is not a float") != C_OK)
                    {
                        zfree(src);
                        return;
                    }
                }
            /* AGGREGATE SUM|MIN|MAX: same availability as WEIGHTS. */
            } else if (op != SET_OP_DIFF && !cardinality_only &&
                       remaining >= 2 &&
                       !strcasecmp(c->argv[j]->ptr,"aggregate"))
            {
                j++; remaining--;
                if (!strcasecmp(c->argv[j]->ptr,"sum")) {
                    aggregate = REDIS_AGGR_SUM;
                } else if (!strcasecmp(c->argv[j]->ptr,"min")) {
                    aggregate = REDIS_AGGR_MIN;
                } else if (!strcasecmp(c->argv[j]->ptr,"max")) {
                    aggregate = REDIS_AGGR_MAX;
                } else {
                    zfree(src);
                    addReplyErrorObject(c,shared.syntaxerr);
                    return;
                }
                j++; remaining--;
            /* WITHSCORES: only when replying with elements to the client
             * (no destination key, not cardinality-only). */
            } else if (remaining >= 1 &&
                       !dstkey && !cardinality_only &&
                       !strcasecmp(c->argv[j]->ptr,"withscores"))
            {
                j++; remaining--;
                withscores = 1;
            /* LIMIT n: only in cardinality-only mode. */
            } else if (cardinality_only && remaining >= 2 &&
                       !strcasecmp(c->argv[j]->ptr, "limit"))
            {
                j++; remaining--;
                if (getPositiveLongFromObjectOrReply(c, c->argv[j], &limit,
                                                     "LIMIT can't be negative") != C_OK)
                {
                    zfree(src);
                    return;
                }
                j++; remaining--;
            } else {
                zfree(src);
                addReplyErrorObject(c,shared.syntaxerr);
                return;
            }
        }
    }

    if (op != SET_OP_DIFF) {
        /* sort sets from the smallest to largest, this will improve our
         * algorithm's performance. DIFF is excluded: its inputs are kept in
         * command order. */
        qsort(src,setnum,sizeof(zsetopsrc),zuiCompareByCardinality);
    }

    /* The result is always built as a temporary zset object, released at the
     * end of the function with decrRefCount(). */
    dstobj = createZsetObject();
    dstzset = dstobj->ptr;
    memset(&zval, 0, sizeof(zval));

    if (op == SET_OP_INTER) {
        /* Skip everything if the smallest input is empty. */
        if (zuiLength(&src[0]) > 0) {
            /* Precondition: as src[0] is non-empty and the inputs are ordered
             * by size, all src[i > 0] are non-empty too. */
            zuiInitIterator(&src[0]);
            while (zuiNext(&src[0],&zval)) {
                double score, value;

                score = src[0].weight * zval.score;
                if (isnan(score)) score = 0;

                /* Look the element up in every other input; the loop ends
                 * early (j < setnum) as soon as one input lacks it. */
                for (j = 1; j < setnum; j++) {
                    /* It is not safe to access the zset we are
                     * iterating, so explicitly check for equal object. */
                    if (src[j].subject == src[0].subject) {
                        value = zval.score*src[j].weight;
                        zunionInterAggregate(&score,value,aggregate);
                    } else if (zuiFind(&src[j],&zval,&value)) {
                        value *= src[j].weight;
                        zunionInterAggregate(&score,value,aggregate);
                    } else {
                        break;
                    }
                }

                /* Only continue when present in every input. */
                if (j == setnum && cardinality_only) {
                    cardinality++;

                    /* We stop the searching after reaching the limit. */
                    if (limit && cardinality >= (unsigned long)limit) {
                        /* Cleanup before we break the zuiNext loop. */
                        zuiDiscardDirtyValue(&zval);
                        break;
                    }
                } else if (j == setnum) {
                    tmp = zuiNewSdsFromValue(&zval);
                    znode = zslInsert(dstzset->zsl,score,tmp);
                    dictAdd(dstzset->dict,tmp,&znode->score);
                    totelelen += sdslen(tmp);
                    if (sdslen(tmp) > maxelelen) maxelelen = sdslen(tmp);
                }
            }
            zuiClearIterator(&src[0]);
        }
    } else if (op == SET_OP_UNION) {
        dict *accumulator = dictCreate(&setAccumulatorDictType);
        dictIterator *di;
        dictEntry *de, *existing;
        double score;

        if (setnum) {
            /* Our union is at least as large as the largest set.
             * Resize the dictionary ASAP to avoid useless rehashing. */
            dictExpand(accumulator,zuiLength(&src[setnum-1]));
        }

        /* Step 1: Create a dictionary of elements -> aggregated-scores
         * by iterating one sorted set after the other. */
        for (i = 0; i < setnum; i++) {
            if (zuiLength(&src[i]) == 0) continue;

            zuiInitIterator(&src[i]);
            while (zuiNext(&src[i],&zval)) {
                /* Initialize value */
                score = src[i].weight * zval.score;
                if (isnan(score)) score = 0;

                /* Search for this element in the accumulating dictionary. */
                de = dictAddRaw(accumulator,zuiSdsFromValue(&zval),&existing);
                /* If we don't have it, we need to create a new entry. */
                if (!existing) {
                    tmp = zuiNewSdsFromValue(&zval);
                    /* Remember the longest single element encountered,
                     * to understand if it's possible to convert to listpack
                     * at the end. */
                     totelelen += sdslen(tmp);
                     if (sdslen(tmp) > maxelelen) maxelelen = sdslen(tmp);
                    /* Update the element with its initial score. */
                    dictSetKey(accumulator, de, tmp);
                    dictSetDoubleVal(de,score);
                } else {
                    /* Update the score with the score of the new instance
                     * of the element found in the current sorted set.
                     *
                     * Here we access directly the dictEntry double
                     * value inside the union as it is a big speedup
                     * compared to using the getDouble/setDouble API. */
                    zunionInterAggregate(&existing->v.d,score,aggregate);
                }
            }
            zuiClearIterator(&src[i]);
        }

        /* Step 2: convert the dictionary into the final sorted set. */
        di = dictGetIterator(accumulator);

        /* We now are aware of the final size of the resulting sorted set,
         * let's resize the dictionary embedded inside the sorted set to the
         * right size, in order to save rehashing time. */
        dictExpand(dstzset->dict,dictSize(accumulator));

        /* The accumulator keys are inserted as-is (same sds pointers) into
         * the destination skiplist and dict. */
        while((de = dictNext(di)) != NULL) {
            sds ele = dictGetKey(de);
            score = dictGetDoubleVal(de);
            znode = zslInsert(dstzset->zsl,score,ele);
            dictAdd(dstzset->dict,ele,&znode->score);
        }
        dictReleaseIterator(di);
        dictRelease(accumulator);
    } else if (op == SET_OP_DIFF) {
        zdiff(src, setnum, dstzset, &maxelelen, &totelelen);
    } else {
        serverPanic("Unknown operator");
    }

    if (dstkey) {
        /* *STORE variants: a non-empty result replaces the destination key,
         * an empty result deletes it. The reply is the result cardinality. */
        if (dstzset->zsl->length) {
            zsetConvertToListpackIfNeeded(dstobj, maxelelen, totelelen);
            setKey(c, c->db, dstkey, dstobj, 0);
            addReplyLongLong(c, zsetLength(dstobj));
            notifyKeyspaceEvent(NOTIFY_ZSET,
                                (op == SET_OP_UNION) ? "zunionstore" :
                                    (op == SET_OP_INTER ? "zinterstore" : "zdiffstore"),
                                dstkey, c->db->id);
            server.dirty++;
        } else {
            addReply(c, shared.czero);
            if (dbDelete(c->db, dstkey)) {
                signalModifiedKey(c, c->db, dstkey);
                notifyKeyspaceEvent(NOTIFY_GENERIC, "del", dstkey, c->db->id);
                server.dirty++;
            }
        }
    } else if (cardinality_only) {
        /* Cardinality-only mode: reply with the counter alone. */
        addReplyLongLong(c, cardinality);
    } else {
        /* Reply with the elements, walking the result skiplist from its
         * first node. */
        unsigned long length = dstzset->zsl->length;
        zskiplist *zsl = dstzset->zsl;
        zskiplistNode *zn = zsl->header->level[0].forward;
        /* In case of WITHSCORES, respond with a single array in RESP2, and
         * nested arrays in RESP3. We can't use a map response type since the
         * client library needs to know to respect the order. */
        if (withscores && c->resp == 2)
            addReplyArrayLen(c, length*2);
        else
            addReplyArrayLen(c, length);

        while (zn != NULL) {
            if (withscores && c->resp > 2) addReplyArrayLen(c,2);
            addReplyBulkCBuffer(c,zn->ele,sdslen(zn->ele));
            if (withscores) addReplyDouble(c,zn->score);
            zn = zn->level[0].forward;
        }
    }
    /* Drop our reference to the temporary result object and free the
     * source descriptors. */
    decrRefCount(dstobj);
    zfree(src);
}
2802 
2803 /* ZUNIONSTORE destination numkeys key [key ...] [WEIGHTS weight] [AGGREGATE SUM|MIN|MAX] */
zunionstoreCommand(client * c)2804 void zunionstoreCommand(client *c) {
2805     zunionInterDiffGenericCommand(c, c->argv[1], 2, SET_OP_UNION, 0);
2806 }
2807 
2808 /* ZINTERSTORE destination numkeys key [key ...] [WEIGHTS weight] [AGGREGATE SUM|MIN|MAX] */
zinterstoreCommand(client * c)2809 void zinterstoreCommand(client *c) {
2810     zunionInterDiffGenericCommand(c, c->argv[1], 2, SET_OP_INTER, 0);
2811 }
2812 
2813 /* ZDIFFSTORE destination numkeys key [key ...] */
zdiffstoreCommand(client * c)2814 void zdiffstoreCommand(client *c) {
2815     zunionInterDiffGenericCommand(c, c->argv[1], 2, SET_OP_DIFF, 0);
2816 }
2817 
2818 /* ZUNION numkeys key [key ...] [WEIGHTS weight] [AGGREGATE SUM|MIN|MAX] [WITHSCORES] */
zunionCommand(client * c)2819 void zunionCommand(client *c) {
2820     zunionInterDiffGenericCommand(c, NULL, 1, SET_OP_UNION, 0);
2821 }
2822 
2823 /* ZINTER numkeys key [key ...] [WEIGHTS weight] [AGGREGATE SUM|MIN|MAX] [WITHSCORES] */
zinterCommand(client * c)2824 void zinterCommand(client *c) {
2825     zunionInterDiffGenericCommand(c, NULL, 1, SET_OP_INTER, 0);
2826 }
2827 
2828 /* ZINTERCARD numkeys key [key ...] [LIMIT limit] */
zinterCardCommand(client * c)2829 void zinterCardCommand(client *c) {
2830     zunionInterDiffGenericCommand(c, NULL, 1, SET_OP_INTER, 1);
2831 }
2832 
2833 /* ZDIFF numkeys key [key ...] [WITHSCORES] */
zdiffCommand(client * c)2834 void zdiffCommand(client *c) {
2835     zunionInterDiffGenericCommand(c, NULL, 1, SET_OP_DIFF, 0);
2836 }
2837 
/* Iteration direction handed to zrangeGenericCommand(). ZREVRANGE passes
 * ZRANGE_DIRECTION_REVERSE, while ZRANGE and ZRANGESTORE pass
 * ZRANGE_DIRECTION_AUTO.
 * NOTE(review): AUTO presumably lets the generic command derive the direction
 * from the REV option -- confirm in zrangeGenericCommand(). */
typedef enum {
    ZRANGE_DIRECTION_AUTO = 0,
    ZRANGE_DIRECTION_FORWARD,
    ZRANGE_DIRECTION_REVERSE
} zrange_direction;
2843 
/* Kind of consumer a zrange_result_handler feeds, as selected in
 * zrangeResultHandlerInit(): CLIENT installs the callbacks that reply to the
 * client, INTERNAL installs the callbacks that store the range into a zset
 * object (used by ZRANGESTORE). */
typedef enum {
    ZRANGE_CONSUMER_TYPE_CLIENT = 0,
    ZRANGE_CONSUMER_TYPE_INTERNAL
} zrange_consumer_type;
2848 
/* Forward declaration, so that the callback types below can take the handler
 * as their first argument. */
typedef struct zrange_result_handler zrange_result_handler;

/* Called once before any element is emitted. */
typedef void (*zrangeResultBeginFunction)(zrange_result_handler *c);
/* Called once at the end with the number of elements that were emitted
 * (0 for an empty range). */
typedef void (*zrangeResultFinalizeFunction)(
    zrange_result_handler *c, size_t result_count);
/* Emits one element given as a buffer 'p' of 'len' bytes, with its score. */
typedef void (*zrangeResultEmitCBufferFunction)(
    zrange_result_handler *c, const void *p, size_t len, double score);
/* Emits one element given as an integer value 'll', with its score. */
typedef void (*zrangeResultEmitLongLongFunction)(
    zrange_result_handler *c, long long ll, double score);
2858 
/* Prototype of the generic ZRANGE implementation, which is defined elsewhere
 * in this file. The command entry points in this file call it with
 * argc_start/store set to 1/0 for ZRANGE and ZREVRANGE and to 2/1 for
 * ZRANGESTORE.
 * NOTE(review): 'argc_start' looks like the argv index of the source key and
 * 'store' like a "store the result" flag -- confirm at the definition. */
void zrangeGenericCommand (zrange_result_handler *handler, int argc_start, int store,
                           zrange_type rangetype, zrange_direction direction);
2861 
/* Interface struct for ZRANGE/ZRANGESTORE generic implementation.
 * There is one implementation of this interface that sends a RESP reply to
 * clients, and one implementation that stores the range result into a zset
 * object. The range code only talks to the four callbacks at the bottom. */
struct zrange_result_handler {
    /* Kind of consumer, see zrange_consumer_type. */
    zrange_consumer_type                 type;
    /* Client executing the command; set by zrangeResultHandlerInit(). */
    client                              *client;
    /* Destination key, set by zrangeResultHandlerDestinationKeySet(). */
    robj                                *dstkey;
    /* Object being filled by the "store" callbacks: created by the begin
     * callback and released by the finalize callback. */
    robj                                *dstobj;
    /* Opaque slot owned by the callbacks; the client implementation keeps
     * here the handle returned by addReplyDeferredLen(). */
    void                                *userdata;
    /* Non-zero once score emission was enabled. */
    int                                  withscores;
    /* Non-zero when each element must be preceded by a 2-item array header
     * (score emission enabled and client->resp > 2). */
    int                                  should_emit_array_length;
    zrangeResultBeginFunction            beginResultEmission;
    zrangeResultFinalizeFunction         finalizeResultEmission;
    zrangeResultEmitCBufferFunction      emitResultFromCBuffer;
    zrangeResultEmitLongLongFunction     emitResultFromLongLong;
};
2878 
2879 /* Result handler methods for responding the ZRANGE to clients. */
zrangeResultBeginClient(zrange_result_handler * handler)2880 static void zrangeResultBeginClient(zrange_result_handler *handler) {
2881     handler->userdata = addReplyDeferredLen(handler->client);
2882 }
2883 
/* Emit callback (buffer flavor): reply with the element as a bulk string,
 * preceded by a 2-item array header when the handler asks for one and
 * followed by the score when scores are enabled. */
static void zrangeResultEmitCBufferToClient(zrange_result_handler *handler,
    const void *value, size_t value_length_in_bytes, double score)
{
    client *c = handler->client;

    if (handler->should_emit_array_length) addReplyArrayLen(c, 2);
    addReplyBulkCBuffer(c, value, value_length_in_bytes);
    if (handler->withscores) addReplyDouble(c, score);
}
2897 
/* Emit callback (integer flavor): same as the buffer flavor, but the element
 * is replied with addReplyBulkLongLong(). */
static void zrangeResultEmitLongLongToClient(zrange_result_handler *handler,
    long long value, double score)
{
    client *c = handler->client;

    if (handler->should_emit_array_length) addReplyArrayLen(c, 2);
    addReplyBulkLongLong(c, value);
    if (handler->withscores) addReplyDouble(c, score);
}
2911 
/* Finalize callback: fill in the deferred array length reserved by the begin
 * callback.
 *
 * In case of WITHSCORES, respond with a single array in RESP2, and nested
 * arrays in RESP3. We can't use a map response type since the client library
 * needs to know to respect the order. So in RESP2 the flat array holds two
 * items per element. */
static void zrangeResultFinalizeClient(zrange_result_handler *handler,
    size_t result_count)
{
    client *c = handler->client;
    size_t reply_items = result_count;

    if (handler->withscores && c->resp == 2) reply_items = result_count * 2;

    setDeferredArrayLen(c, handler->userdata, reply_items);
}
2924 
2925 /* Result handler methods for storing the ZRANGESTORE to a zset. */
zrangeResultBeginStore(zrange_result_handler * handler)2926 static void zrangeResultBeginStore(zrange_result_handler *handler)
2927 {
2928     handler->dstobj = createZsetListpackObject();
2929 }
2930 
/* Emit callback (buffer flavor): add the element, copied into a temporary
 * sds, to the destination object with the given score. The temporary is
 * freed right after zsetAdd(), whose success is asserted. */
static void zrangeResultEmitCBufferForStore(zrange_result_handler *handler,
    const void *value, size_t value_length_in_bytes, double score)
{
    sds member = sdsnewlen(value, value_length_in_bytes);
    int out_flags = 0;
    double ignored_score;
    int added;

    added = zsetAdd(handler->dstobj, score, member, ZADD_IN_NONE, &out_flags, &ignored_score);
    sdsfree(member);
    serverAssert(added);
}
2941 
/* Emit callback (integer flavor): same as the buffer flavor, but the
 * temporary sds is built from the integer with sdsfromlonglong(). */
static void zrangeResultEmitLongLongForStore(zrange_result_handler *handler,
    long long value, double score)
{
    sds member = sdsfromlonglong(value);
    int out_flags = 0;
    double ignored_score;
    int added;

    added = zsetAdd(handler->dstobj, score, member, ZADD_IN_NONE, &out_flags, &ignored_score);
    sdsfree(member);
    serverAssert(added);
}
2952 
/* Finalize callback: a non-empty result is stored under the destination key
 * and its size replied; an empty result deletes the destination key (if it
 * exists) and replies zero. In both cases our reference to the temporary
 * destination object is dropped. */
static void zrangeResultFinalizeStore(zrange_result_handler *handler, size_t result_count)
{
    client *c = handler->client;
    robj *dstkey = handler->dstkey;

    if (result_count == 0) {
        addReply(c, shared.czero);
        if (dbDelete(c->db, dstkey)) {
            signalModifiedKey(c, c->db, dstkey);
            notifyKeyspaceEvent(NOTIFY_GENERIC, "del", dstkey, c->db->id);
            server.dirty++;
        }
    } else {
        setKey(c, c->db, dstkey, handler->dstobj, 0);
        addReplyLongLong(c, result_count);
        notifyKeyspaceEvent(NOTIFY_ZSET, "zrangestore", dstkey, c->db->id);
        server.dirty++;
    }

    decrRefCount(handler->dstobj);
}
2970 
2971 /* Initialize the consumer interface type with the requested type. */
zrangeResultHandlerInit(zrange_result_handler * handler,client * client,zrange_consumer_type type)2972 static void zrangeResultHandlerInit(zrange_result_handler *handler,
2973     client *client, zrange_consumer_type type)
2974 {
2975     memset(handler, 0, sizeof(*handler));
2976 
2977     handler->client = client;
2978 
2979     switch (type) {
2980     case ZRANGE_CONSUMER_TYPE_CLIENT:
2981         handler->beginResultEmission = zrangeResultBeginClient;
2982         handler->finalizeResultEmission = zrangeResultFinalizeClient;
2983         handler->emitResultFromCBuffer = zrangeResultEmitCBufferToClient;
2984         handler->emitResultFromLongLong = zrangeResultEmitLongLongToClient;
2985         break;
2986 
2987     case ZRANGE_CONSUMER_TYPE_INTERNAL:
2988         handler->beginResultEmission = zrangeResultBeginStore;
2989         handler->finalizeResultEmission = zrangeResultFinalizeStore;
2990         handler->emitResultFromCBuffer = zrangeResultEmitCBufferForStore;
2991         handler->emitResultFromLongLong = zrangeResultEmitLongLongForStore;
2992         break;
2993     }
2994 }
2995 
/* Turn on score emission for this handler. When the client's protocol
 * version is greater than 2, each element/score pair is additionally
 * announced with its own array header. */
static void zrangeResultHandlerScoreEmissionEnable(zrange_result_handler *handler) {
    int nested_pairs = (handler->client->resp > 2);

    handler->withscores = 1;
    handler->should_emit_array_length = nested_pairs;
}
3000 
/* Record the key under which the "store" finalize callback will save the
 * result. */
static void zrangeResultHandlerDestinationKeySet(zrange_result_handler *handler, robj *dstkey) {
    handler->dstkey = dstkey;
}
3006 
3007 /* This command implements ZRANGE, ZREVRANGE. */
genericZrangebyrankCommand(zrange_result_handler * handler,robj * zobj,long start,long end,int withscores,int reverse)3008 void genericZrangebyrankCommand(zrange_result_handler *handler,
3009     robj *zobj, long start, long end, int withscores, int reverse) {
3010 
3011     client *c = handler->client;
3012     long llen;
3013     long rangelen;
3014     size_t result_cardinality;
3015 
3016     /* Sanitize indexes. */
3017     llen = zsetLength(zobj);
3018     if (start < 0) start = llen+start;
3019     if (end < 0) end = llen+end;
3020     if (start < 0) start = 0;
3021 
3022     handler->beginResultEmission(handler);
3023 
3024     /* Invariant: start >= 0, so this test will be true when end < 0.
3025      * The range is empty when start > end or start >= length. */
3026     if (start > end || start >= llen) {
3027         handler->finalizeResultEmission(handler, 0);
3028         return;
3029     }
3030     if (end >= llen) end = llen-1;
3031     rangelen = (end-start)+1;
3032     result_cardinality = rangelen;
3033 
3034     if (zobj->encoding == OBJ_ENCODING_LISTPACK) {
3035         unsigned char *zl = zobj->ptr;
3036         unsigned char *eptr, *sptr;
3037         unsigned char *vstr;
3038         unsigned int vlen;
3039         long long vlong;
3040         double score = 0.0;
3041 
3042         if (reverse)
3043             eptr = lpSeek(zl,-2-(2*start));
3044         else
3045             eptr = lpSeek(zl,2*start);
3046 
3047         serverAssertWithInfo(c,zobj,eptr != NULL);
3048         sptr = lpNext(zl,eptr);
3049 
3050         while (rangelen--) {
3051             serverAssertWithInfo(c,zobj,eptr != NULL && sptr != NULL);
3052             vstr = lpGetValue(eptr,&vlen,&vlong);
3053 
3054             if (withscores) /* don't bother to extract the score if it's gonna be ignored. */
3055                 score = zzlGetScore(sptr);
3056 
3057             if (vstr == NULL) {
3058                 handler->emitResultFromLongLong(handler, vlong, score);
3059             } else {
3060                 handler->emitResultFromCBuffer(handler, vstr, vlen, score);
3061             }
3062 
3063             if (reverse)
3064                 zzlPrev(zl,&eptr,&sptr);
3065             else
3066                 zzlNext(zl,&eptr,&sptr);
3067         }
3068 
3069     } else if (zobj->encoding == OBJ_ENCODING_SKIPLIST) {
3070         zset *zs = zobj->ptr;
3071         zskiplist *zsl = zs->zsl;
3072         zskiplistNode *ln;
3073 
3074         /* Check if starting point is trivial, before doing log(N) lookup. */
3075         if (reverse) {
3076             ln = zsl->tail;
3077             if (start > 0)
3078                 ln = zslGetElementByRank(zsl,llen-start);
3079         } else {
3080             ln = zsl->header->level[0].forward;
3081             if (start > 0)
3082                 ln = zslGetElementByRank(zsl,start+1);
3083         }
3084 
3085         while(rangelen--) {
3086             serverAssertWithInfo(c,zobj,ln != NULL);
3087             sds ele = ln->ele;
3088             handler->emitResultFromCBuffer(handler, ele, sdslen(ele), ln->score);
3089             ln = reverse ? ln->backward : ln->level[0].forward;
3090         }
3091     } else {
3092         serverPanic("Unknown sorted set encoding");
3093     }
3094 
3095     handler->finalizeResultEmission(handler, result_cardinality);
3096 }
3097 
3098 /* ZRANGESTORE <dst> <src> <min> <max> [BYSCORE | BYLEX] [REV] [LIMIT offset count] */
zrangestoreCommand(client * c)3099 void zrangestoreCommand (client *c) {
3100     robj *dstkey = c->argv[1];
3101     zrange_result_handler handler;
3102     zrangeResultHandlerInit(&handler, c, ZRANGE_CONSUMER_TYPE_INTERNAL);
3103     zrangeResultHandlerDestinationKeySet(&handler, dstkey);
3104     zrangeGenericCommand(&handler, 2, 1, ZRANGE_AUTO, ZRANGE_DIRECTION_AUTO);
3105 }
3106 
3107 /* ZRANGE <key> <min> <max> [BYSCORE | BYLEX] [REV] [WITHSCORES] [LIMIT offset count] */
zrangeCommand(client * c)3108 void zrangeCommand(client *c) {
3109     zrange_result_handler handler;
3110     zrangeResultHandlerInit(&handler, c, ZRANGE_CONSUMER_TYPE_CLIENT);
3111     zrangeGenericCommand(&handler, 1, 0, ZRANGE_AUTO, ZRANGE_DIRECTION_AUTO);
3112 }
3113 
3114 /* ZREVRANGE <key> <start> <stop> [WITHSCORES] */
zrevrangeCommand(client * c)3115 void zrevrangeCommand(client *c) {
3116     zrange_result_handler handler;
3117     zrangeResultHandlerInit(&handler, c, ZRANGE_CONSUMER_TYPE_CLIENT);
3118     zrangeGenericCommand(&handler, 1, 0, ZRANGE_RANK, ZRANGE_DIRECTION_REVERSE);
3119 }
3120 
3121 /* This command implements ZRANGEBYSCORE, ZREVRANGEBYSCORE. */
genericZrangebyscoreCommand(zrange_result_handler * handler,zrangespec * range,robj * zobj,long offset,long limit,int reverse)3122 void genericZrangebyscoreCommand(zrange_result_handler *handler,
3123     zrangespec *range, robj *zobj, long offset, long limit,
3124     int reverse) {
3125     unsigned long rangelen = 0;
3126 
3127     handler->beginResultEmission(handler);
3128 
3129     /* For invalid offset, return directly. */
3130     if (offset > 0 && offset >= (long)zsetLength(zobj)) {
3131         handler->finalizeResultEmission(handler, 0);
3132         return;
3133     }
3134 
3135     if (zobj->encoding == OBJ_ENCODING_LISTPACK) {
3136         unsigned char *zl = zobj->ptr;
3137         unsigned char *eptr, *sptr;
3138         unsigned char *vstr;
3139         unsigned int vlen;
3140         long long vlong;
3141 
3142         /* If reversed, get the last node in range as starting point. */
3143         if (reverse) {
3144             eptr = zzlLastInRange(zl,range);
3145         } else {
3146             eptr = zzlFirstInRange(zl,range);
3147         }
3148 
3149         /* Get score pointer for the first element. */
3150         if (eptr)
3151             sptr = lpNext(zl,eptr);
3152 
3153         /* If there is an offset, just traverse the number of elements without
3154          * checking the score because that is done in the next loop. */
3155         while (eptr && offset--) {
3156             if (reverse) {
3157                 zzlPrev(zl,&eptr,&sptr);
3158             } else {
3159                 zzlNext(zl,&eptr,&sptr);
3160             }
3161         }
3162 
3163         while (eptr && limit--) {
3164             double score = zzlGetScore(sptr);
3165 
3166             /* Abort when the node is no longer in range. */
3167             if (reverse) {
3168                 if (!zslValueGteMin(score,range)) break;
3169             } else {
3170                 if (!zslValueLteMax(score,range)) break;
3171             }
3172 
3173             vstr = lpGetValue(eptr,&vlen,&vlong);
3174             rangelen++;
3175             if (vstr == NULL) {
3176                 handler->emitResultFromLongLong(handler, vlong, score);
3177             } else {
3178                 handler->emitResultFromCBuffer(handler, vstr, vlen, score);
3179             }
3180 
3181             /* Move to next node */
3182             if (reverse) {
3183                 zzlPrev(zl,&eptr,&sptr);
3184             } else {
3185                 zzlNext(zl,&eptr,&sptr);
3186             }
3187         }
3188     } else if (zobj->encoding == OBJ_ENCODING_SKIPLIST) {
3189         zset *zs = zobj->ptr;
3190         zskiplist *zsl = zs->zsl;
3191         zskiplistNode *ln;
3192 
3193         /* If reversed, get the last node in range as starting point. */
3194         if (reverse) {
3195             ln = zslLastInRange(zsl,range);
3196         } else {
3197             ln = zslFirstInRange(zsl,range);
3198         }
3199 
3200         /* If there is an offset, just traverse the number of elements without
3201          * checking the score because that is done in the next loop. */
3202         while (ln && offset--) {
3203             if (reverse) {
3204                 ln = ln->backward;
3205             } else {
3206                 ln = ln->level[0].forward;
3207             }
3208         }
3209 
3210         while (ln && limit--) {
3211             /* Abort when the node is no longer in range. */
3212             if (reverse) {
3213                 if (!zslValueGteMin(ln->score,range)) break;
3214             } else {
3215                 if (!zslValueLteMax(ln->score,range)) break;
3216             }
3217 
3218             rangelen++;
3219             handler->emitResultFromCBuffer(handler, ln->ele, sdslen(ln->ele), ln->score);
3220 
3221             /* Move to next node */
3222             if (reverse) {
3223                 ln = ln->backward;
3224             } else {
3225                 ln = ln->level[0].forward;
3226             }
3227         }
3228     } else {
3229         serverPanic("Unknown sorted set encoding");
3230     }
3231 
3232     handler->finalizeResultEmission(handler, rangelen);
3233 }
3234 
3235 /* ZRANGEBYSCORE <key> <min> <max> [WITHSCORES] [LIMIT offset count] */
zrangebyscoreCommand(client * c)3236 void zrangebyscoreCommand(client *c) {
3237     zrange_result_handler handler;
3238     zrangeResultHandlerInit(&handler, c, ZRANGE_CONSUMER_TYPE_CLIENT);
3239     zrangeGenericCommand(&handler, 1, 0, ZRANGE_SCORE, ZRANGE_DIRECTION_FORWARD);
3240 }
3241 
3242 /* ZREVRANGEBYSCORE <key> <max> <min> [WITHSCORES] [LIMIT offset count] */
zrevrangebyscoreCommand(client * c)3243 void zrevrangebyscoreCommand(client *c) {
3244     zrange_result_handler handler;
3245     zrangeResultHandlerInit(&handler, c, ZRANGE_CONSUMER_TYPE_CLIENT);
3246     zrangeGenericCommand(&handler, 1, 0, ZRANGE_SCORE, ZRANGE_DIRECTION_REVERSE);
3247 }
3248 
/* Reply with the number of elements of the sorted set stored at argv[1]
 * whose score falls inside the score range described by argv[2] (min) and
 * argv[3] (max).
 *
 * An unparsable range produces an error reply. A missing key is answered
 * with shared.czero by the lookup helper, and a key of the wrong type is
 * handled by checkType(). */
void zcountCommand(client *c) {
    zrangespec range;
    unsigned long count = 0;
    robj *zobj;

    /* Validate the range before touching the keyspace. */
    if (zslParseRange(c->argv[2],c->argv[3],&range) != C_OK) {
        addReplyError(c,"min or max is not a float");
        return;
    }

    zobj = lookupKeyReadOrReply(c, c->argv[1], shared.czero);
    if (zobj == NULL) return;
    if (checkType(c, zobj, OBJ_ZSET)) return;

    switch (zobj->encoding) {
    case OBJ_ENCODING_LISTPACK: {
        unsigned char *lp = zobj->ptr;
        unsigned char *member = zzlFirstInRange(lp,&range);
        unsigned char *scoreptr;
        double score;

        /* Nothing in range at all. */
        if (member == NULL) {
            addReply(c, shared.czero);
            return;
        }

        /* The element we start from has to be inside the range. */
        scoreptr = lpNext(lp,member);
        score = zzlGetScore(scoreptr);
        serverAssertWithInfo(c,zobj,zslValueLteMax(score,&range));

        /* Count entries one by one until the list ends or the score
         * exceeds the maximum of the range. */
        while (member != NULL && zslValueLteMax(zzlGetScore(scoreptr),&range)) {
            count++;
            zzlNext(lp,&member,&scoreptr);
        }
        break;
    }
    case OBJ_ENCODING_SKIPLIST: {
        zset *zs = zobj->ptr;
        zskiplist *zsl = zs->zsl;
        zskiplistNode *first = zslFirstInRange(zsl, &range);

        /* No walk is needed here: the count is derived from the ranks of
         * the first and of the last in-range nodes. */
        if (first != NULL) {
            unsigned long first_rank = zslGetRank(zsl, first->score, first->ele);
            zskiplistNode *last;

            /* Preliminary count: everything from the first node onwards. */
            count = (zsl->length - (first_rank - 1));

            last = zslLastInRange(zsl, &range);
            if (last != NULL) {
                unsigned long last_rank = zslGetRank(zsl, last->score, last->ele);

                /* Drop the nodes that follow the last in-range node. */
                count -= (zsl->length - last_rank);
            }
        }
        break;
    }
    default:
        serverPanic("Unknown sorted set encoding");
    }

    addReplyLongLong(c, count);
}
3325 
/* Reply with the number of elements of the sorted set stored at argv[1]
 * that fall inside the lexicographical range described by argv[2] (min) and
 * argv[3] (max).
 *
 * Once the range has been parsed successfully, it is released with
 * zslFreeLexRange() on every exit path of this function. */
void zlexcountCommand(client *c) {
    zlexrangespec range;
    unsigned long count = 0;
    robj *zobj;

    /* Validate the range before touching the keyspace. */
    if (zslParseLexRange(c->argv[2],c->argv[3],&range) != C_OK) {
        addReplyError(c,"min or max not valid string range item");
        return;
    }

    zobj = lookupKeyReadOrReply(c, c->argv[1], shared.czero);
    if (zobj == NULL || checkType(c, zobj, OBJ_ZSET)) {
        zslFreeLexRange(&range);
        return;
    }

    switch (zobj->encoding) {
    case OBJ_ENCODING_LISTPACK: {
        unsigned char *lp = zobj->ptr;
        unsigned char *eptr = zzlFirstInLexRange(lp,&range);
        unsigned char *sptr;

        /* Nothing in range at all. */
        if (eptr == NULL) {
            zslFreeLexRange(&range);
            addReply(c, shared.czero);
            return;
        }

        /* The element we start from has to be inside the range. */
        sptr = lpNext(lp,eptr);
        serverAssertWithInfo(c,zobj,zzlLexValueLteMax(eptr,&range));

        /* Count entries one by one until the list ends or an element
         * beyond the maximum of the range is found. */
        while (eptr != NULL && zzlLexValueLteMax(eptr,&range)) {
            count++;
            zzlNext(lp,&eptr,&sptr);
        }
        break;
    }
    case OBJ_ENCODING_SKIPLIST: {
        zset *zs = zobj->ptr;
        zskiplist *zsl = zs->zsl;
        zskiplistNode *first = zslFirstInLexRange(zsl, &range);

        /* No walk is needed here: the count is derived from the ranks of
         * the first and of the last in-range nodes. */
        if (first != NULL) {
            unsigned long first_rank = zslGetRank(zsl, first->score, first->ele);
            zskiplistNode *last;

            /* Preliminary count: everything from the first node onwards. */
            count = (zsl->length - (first_rank - 1));

            last = zslLastInLexRange(zsl, &range);
            if (last != NULL) {
                unsigned long last_rank = zslGetRank(zsl, last->score, last->ele);

                /* Drop the nodes that follow the last in-range node. */
                count -= (zsl->length - last_rank);
            }
        }
        break;
    }
    default:
        serverPanic("Unknown sorted set encoding");
    }

    zslFreeLexRange(&range);
    addReplyLongLong(c, count);
}
3404 
3405 /* This command implements ZRANGEBYLEX, ZREVRANGEBYLEX. */
genericZrangebylexCommand(zrange_result_handler * handler,zlexrangespec * range,robj * zobj,int withscores,long offset,long limit,int reverse)3406 void genericZrangebylexCommand(zrange_result_handler *handler,
3407     zlexrangespec *range, robj *zobj, int withscores, long offset, long limit,
3408     int reverse)
3409 {
3410     unsigned long rangelen = 0;
3411 
3412     handler->beginResultEmission(handler);
3413 
3414     if (zobj->encoding == OBJ_ENCODING_LISTPACK) {
3415         unsigned char *zl = zobj->ptr;
3416         unsigned char *eptr, *sptr;
3417         unsigned char *vstr;
3418         unsigned int vlen;
3419         long long vlong;
3420 
3421         /* If reversed, get the last node in range as starting point. */
3422         if (reverse) {
3423             eptr = zzlLastInLexRange(zl,range);
3424         } else {
3425             eptr = zzlFirstInLexRange(zl,range);
3426         }
3427 
3428         /* Get score pointer for the first element. */
3429         if (eptr)
3430             sptr = lpNext(zl,eptr);
3431 
3432         /* If there is an offset, just traverse the number of elements without
3433          * checking the score because that is done in the next loop. */
3434         while (eptr && offset--) {
3435             if (reverse) {
3436                 zzlPrev(zl,&eptr,&sptr);
3437             } else {
3438                 zzlNext(zl,&eptr,&sptr);
3439             }
3440         }
3441 
3442         while (eptr && limit--) {
3443             double score = 0;
3444             if (withscores) /* don't bother to extract the score if it's gonna be ignored. */
3445                 score = zzlGetScore(sptr);
3446 
3447             /* Abort when the node is no longer in range. */
3448             if (reverse) {
3449                 if (!zzlLexValueGteMin(eptr,range)) break;
3450             } else {
3451                 if (!zzlLexValueLteMax(eptr,range)) break;
3452             }
3453 
3454             vstr = lpGetValue(eptr,&vlen,&vlong);
3455             rangelen++;
3456             if (vstr == NULL) {
3457                 handler->emitResultFromLongLong(handler, vlong, score);
3458             } else {
3459                 handler->emitResultFromCBuffer(handler, vstr, vlen, score);
3460             }
3461 
3462             /* Move to next node */
3463             if (reverse) {
3464                 zzlPrev(zl,&eptr,&sptr);
3465             } else {
3466                 zzlNext(zl,&eptr,&sptr);
3467             }
3468         }
3469     } else if (zobj->encoding == OBJ_ENCODING_SKIPLIST) {
3470         zset *zs = zobj->ptr;
3471         zskiplist *zsl = zs->zsl;
3472         zskiplistNode *ln;
3473 
3474         /* If reversed, get the last node in range as starting point. */
3475         if (reverse) {
3476             ln = zslLastInLexRange(zsl,range);
3477         } else {
3478             ln = zslFirstInLexRange(zsl,range);
3479         }
3480 
3481         /* If there is an offset, just traverse the number of elements without
3482          * checking the score because that is done in the next loop. */
3483         while (ln && offset--) {
3484             if (reverse) {
3485                 ln = ln->backward;
3486             } else {
3487                 ln = ln->level[0].forward;
3488             }
3489         }
3490 
3491         while (ln && limit--) {
3492             /* Abort when the node is no longer in range. */
3493             if (reverse) {
3494                 if (!zslLexValueGteMin(ln->ele,range)) break;
3495             } else {
3496                 if (!zslLexValueLteMax(ln->ele,range)) break;
3497             }
3498 
3499             rangelen++;
3500             handler->emitResultFromCBuffer(handler, ln->ele, sdslen(ln->ele), ln->score);
3501 
3502             /* Move to next node */
3503             if (reverse) {
3504                 ln = ln->backward;
3505             } else {
3506                 ln = ln->level[0].forward;
3507             }
3508         }
3509     } else {
3510         serverPanic("Unknown sorted set encoding");
3511     }
3512 
3513     handler->finalizeResultEmission(handler, rangelen);
3514 }
3515 
3516 /* ZRANGEBYLEX <key> <min> <max> [LIMIT offset count] */
zrangebylexCommand(client * c)3517 void zrangebylexCommand(client *c) {
3518     zrange_result_handler handler;
3519     zrangeResultHandlerInit(&handler, c, ZRANGE_CONSUMER_TYPE_CLIENT);
3520     zrangeGenericCommand(&handler, 1, 0, ZRANGE_LEX, ZRANGE_DIRECTION_FORWARD);
3521 }
3522 
3523 /* ZREVRANGEBYLEX <key> <max> <min> [LIMIT offset count] */
zrevrangebylexCommand(client * c)3524 void zrevrangebylexCommand(client *c) {
3525     zrange_result_handler handler;
3526     zrangeResultHandlerInit(&handler, c, ZRANGE_CONSUMER_TYPE_CLIENT);
3527     zrangeGenericCommand(&handler, 1, 0, ZRANGE_LEX, ZRANGE_DIRECTION_REVERSE);
3528 }
3529 
/**
 * This function handles ZRANGE and ZRANGESTORE, and also the deprecated
 * ZREVRANGE and Z[REV]RANGE[BYSCORE|BYLEX] commands.
 *
 * The simple ZRANGE and ZRANGESTORE can take _AUTO in rangetype and direction
 * (resolved from the BYSCORE / BYLEX / REV options, defaulting to a forward
 * rank range); the other commands pass explicit values, in which case the
 * corresponding option keywords are rejected as syntax errors.
 *
 * The argc_start points to the src key argument, so following syntax is like:
 * <src> <min> <max> [BYSCORE | BYLEX] [REV] [WITHSCORES] [LIMIT offset count]
 *
 * When 'store' is non-zero the WITHSCORES option is not accepted, score
 * emission is always enabled on the handler, and a missing source key is
 * reported through the handler as an empty emission instead of an empty
 * array reply.
 *
 * All the results (and errors detected after parsing, like a wrong key type)
 * are delivered through 'handler' or directly to handler->client.
 */
void zrangeGenericCommand(zrange_result_handler *handler, int argc_start, int store,
                          zrange_type rangetype, zrange_direction direction)
{
    client *c = handler->client;
    robj *key = c->argv[argc_start];
    robj *zobj;
    zrangespec range;
    zlexrangespec lexrange;
    int minidx = argc_start + 1;
    int maxidx = argc_start + 2;

    /* Options common to all */
    long opt_start = 0;
    long opt_end = 0;
    int opt_withscores = 0;
    long opt_offset = 0;
    long opt_limit = -1;

    /* Step 1: Skip the <src> <min> <max> args and parse remaining optional arguments. */
    for (int j=argc_start + 3; j < c->argc; j++) {
        /* Number of arguments following the current one. */
        int leftargs = c->argc-j-1;
        if (!store && !strcasecmp(c->argv[j]->ptr,"withscores")) {
            opt_withscores = 1;
        } else if (!strcasecmp(c->argv[j]->ptr,"limit") && leftargs >= 2) {
            /* LIMIT needs both <offset> and <count> to follow it. */
            if ((getLongFromObjectOrReply(c, c->argv[j+1], &opt_offset, NULL) != C_OK) ||
                (getLongFromObjectOrReply(c, c->argv[j+2], &opt_limit, NULL) != C_OK))
            {
                return;
            }
            j += 2;
        } else if (direction == ZRANGE_DIRECTION_AUTO &&
                   !strcasecmp(c->argv[j]->ptr,"rev"))
        {
            direction = ZRANGE_DIRECTION_REVERSE;
        } else if (rangetype == ZRANGE_AUTO &&
                   !strcasecmp(c->argv[j]->ptr,"bylex"))
        {
            rangetype = ZRANGE_LEX;
        } else if (rangetype == ZRANGE_AUTO &&
                   !strcasecmp(c->argv[j]->ptr,"byscore"))
        {
            rangetype = ZRANGE_SCORE;
        } else {
            addReplyErrorObject(c,shared.syntaxerr);
            return;
        }
    }

    /* Use defaults if not overridden by arguments. */
    if (direction == ZRANGE_DIRECTION_AUTO)
        direction = ZRANGE_DIRECTION_FORWARD;
    if (rangetype == ZRANGE_AUTO)
        rangetype = ZRANGE_RANK;

    /* Check for conflicting arguments. Note that the LIMIT check relies on
     * opt_limit differing from its -1 default. */
    if (opt_limit != -1 && rangetype == ZRANGE_RANK) {
        addReplyError(c,"syntax error, LIMIT is only supported in combination with either BYSCORE or BYLEX");
        return;
    }
    if (opt_withscores && rangetype == ZRANGE_LEX) {
        addReplyError(c,"syntax error, WITHSCORES not supported in combination with BYLEX");
        return;
    }

    if (direction == ZRANGE_DIRECTION_REVERSE &&
        ((ZRANGE_SCORE == rangetype) || (ZRANGE_LEX == rangetype)))
    {
        /* Range is given as [max,min] */
        int tmp = maxidx;
        maxidx = minidx;
        minidx = tmp;
    }

    /* Step 2: Parse the range. */
    switch (rangetype) {
    case ZRANGE_AUTO:
    case ZRANGE_RANK:
        /* Z[REV]RANGE, ZRANGESTORE [REV]RANGE */
        if ((getLongFromObjectOrReply(c, c->argv[minidx], &opt_start,NULL) != C_OK) ||
            (getLongFromObjectOrReply(c, c->argv[maxidx], &opt_end,NULL) != C_OK))
        {
            return;
        }
        break;

    case ZRANGE_SCORE:
        /* Z[REV]RANGEBYSCORE, ZRANGESTORE [REV]RANGEBYSCORE */
        if (zslParseRange(c->argv[minidx], c->argv[maxidx], &range) != C_OK) {
            addReplyError(c, "min or max is not a float");
            return;
        }
        break;

    case ZRANGE_LEX:
        /* Z[REV]RANGEBYLEX, ZRANGESTORE [REV]RANGEBYLEX */
        if (zslParseLexRange(c->argv[minidx], c->argv[maxidx], &lexrange) != C_OK) {
            addReplyError(c, "min or max not valid string range item");
            return;
        }
        break;
    }

    /* From this point on a successfully parsed lexrange has to be released,
     * so every exit goes through the 'cleanup' label instead of returning. */

    if (opt_withscores || store) {
        zrangeResultHandlerScoreEmissionEnable(handler);
    }

    /* Step 3: Lookup the key and get the range. */
    zobj = lookupKeyRead(c->db, key);
    if (zobj == NULL) {
        if (store) {
            /* Let the handler process an empty result. */
            handler->beginResultEmission(handler);
            handler->finalizeResultEmission(handler, 0);
        } else {
            addReply(c, shared.emptyarray);
        }
        goto cleanup;
    }

    if (checkType(c,zobj,OBJ_ZSET)) goto cleanup;

    /* Step 4: Pass this to the command-specific handler. */
    switch (rangetype) {
    case ZRANGE_AUTO:
    case ZRANGE_RANK:
        genericZrangebyrankCommand(handler, zobj, opt_start, opt_end,
            opt_withscores || store, direction == ZRANGE_DIRECTION_REVERSE);
        break;

    case ZRANGE_SCORE:
        genericZrangebyscoreCommand(handler, &range, zobj, opt_offset,
            opt_limit, direction == ZRANGE_DIRECTION_REVERSE);
        break;

    case ZRANGE_LEX:
        genericZrangebylexCommand(handler, &lexrange, zobj, opt_withscores || store,
            opt_offset, opt_limit, direction == ZRANGE_DIRECTION_REVERSE);
        break;
    }

    /* Instead of returning here, we'll just fall-through the clean-up. */

cleanup:

    /* Only the lexicographical range spec needs to be released. */
    if (rangetype == ZRANGE_LEX) {
        zslFreeLexRange(&lexrange);
    }
}
3687 
/* Reply with the number of elements of the sorted set stored at argv[1].
 * A missing key is answered with shared.czero by the lookup helper; a key
 * holding another type is handled by checkType(). */
void zcardCommand(client *c) {
    robj *zobj = lookupKeyReadOrReply(c,c->argv[1],shared.czero);

    if (zobj == NULL) return;
    if (checkType(c,zobj,OBJ_ZSET)) return;

    addReplyLongLong(c,zsetLength(zobj));
}
3697 
/* Reply with the score of member argv[2] in the sorted set stored at
 * argv[1]. A null reply is sent when the member is not found; a missing key
 * is answered with the protocol-specific shared null by the lookup helper. */
void zscoreCommand(client *c) {
    double score;
    robj *zobj = lookupKeyReadOrReply(c,c->argv[1],shared.null[c->resp]);

    if (zobj == NULL) return;
    if (checkType(c,zobj,OBJ_ZSET)) return;

    if (zsetScore(zobj,c->argv[2]->ptr,&score) != C_ERR) {
        addReplyDouble(c,score);
        return;
    }

    /* Member not found. */
    addReplyNull(c);
}
3712 
/* Reply with an array holding, for every member given from argv[2] onwards,
 * its score in the sorted set stored at argv[1], or null when the member is
 * not found. A missing key is not an error: every entry is null then. */
void zmscoreCommand(client *c) {
    robj *zobj = lookupKeyRead(c->db,c->argv[1]);

    /* checkType() is called even when the key is missing (zobj == NULL). */
    if (checkType(c,zobj,OBJ_ZSET)) return;

    addReplyArrayLen(c,c->argc - 2);
    for (int idx = 2; idx < c->argc; idx++) {
        double score;
        int found = 0;

        /* Treat a missing set the same way as an empty set */
        if (zobj != NULL && zsetScore(zobj,c->argv[idx]->ptr,&score) != C_ERR)
            found = 1;

        if (found)
            addReplyDouble(c,score);
        else
            addReplyNull(c);
    }
}
3730 
/* Shared implementation of the rank commands: reply with the rank of member
 * argv[2] inside the sorted set stored at argv[1], as computed by zsetRank()
 * with the given 'reverse' flag. A negative rank from zsetRank() results in
 * a null reply. */
void zrankGenericCommand(client *c, int reverse) {
    robj *ele = c->argv[2];
    robj *zobj = lookupKeyReadOrReply(c,c->argv[1],shared.null[c->resp]);
    long rank;

    if (zobj == NULL) return;
    if (checkType(c,zobj,OBJ_ZSET)) return;

    /* The member is accessed through ele->ptr as an sds string below. */
    serverAssertWithInfo(c,ele,sdsEncodedObject(ele));

    rank = zsetRank(zobj,ele->ptr,reverse);
    if (rank < 0) {
        addReplyNull(c);
        return;
    }
    addReplyLongLong(c,rank);
}
3748 
/* Rank lookup with the 'reverse' flag of the generic implementation off. */
void zrankCommand(client *c) {
    const int reverse = 0;

    zrankGenericCommand(c, reverse);
}
3752 
/* Rank lookup with the 'reverse' flag of the generic implementation on. */
void zrevrankCommand(client *c) {
    const int reverse = 1;

    zrankGenericCommand(c, reverse);
}
3756 
/* Scan the sorted set stored at argv[1] starting from the cursor given in
 * argv[2]. The cursor is validated before the key lookup; a missing key is
 * answered with shared.emptyscan by the lookup helper. The actual scan is
 * performed by scanGenericCommand(). */
void zscanCommand(client *c) {
    unsigned long cursor;
    robj *zobj;

    if (parseScanCursorOrReply(c,c->argv[2],&cursor) == C_ERR) return;

    zobj = lookupKeyReadOrReply(c,c->argv[1],shared.emptyscan);
    if (zobj == NULL) return;
    if (checkType(c,zobj,OBJ_ZSET)) return;

    scanGenericCommand(c,zobj,cursor);
}
3766 
/* This command implements the generic zpop operation, used by:
 * ZPOPMIN, ZPOPMAX, BZPOPMIN, BZPOPMAX and ZMPOP. This function is also used
 * inside blocked.c in the unblocking stage of BZPOPMIN, BZPOPMAX and BZMPOP.
 *
 * 'keyv' / 'keyc' is the list of candidate keys: the first one that exists
 * is the one elements are popped from. If a key of the wrong type is met
 * before that, the function returns right after checkType().
 *
 * 'where' is ZSET_MIN or ZSET_MAX and selects which end of the sorted set
 * elements are popped from.
 *
 * If 'emitkey' is true also the key name is emitted, useful for the blocking
 * behavior of BZPOP[MIN|MAX], since we can block into multiple keys.
 * Or in ZMPOP/BZMPOP, because we also can take multiple keys.
 *
 * 'count' is the number of elements requested to pop, or -1 for plain single pop.
 * A 'count' of 0 replies with an empty array without popping anything. When
 * 'count' exceeds the cardinality, all the elements are popped.
 *
 * 'use_nested_array' when false it generates a flat array (with or without key name).
 * When true, it generates a nested 2 level array of field + score pairs, or 3 level when emitkey is set.
 *
 * 'reply_nil_when_empty' when true we reply a NIL if we are not able to pop up any elements.
 * Like in ZMPOP/BZMPOP we reply with a structured nested array containing key name
 * and member + score pairs. In these commands, we reply with null when we have no result.
 * Otherwise in ZPOPMIN/ZPOPMAX we reply an empty array by default.
 *
 * 'deleted' is an optional output argument to get an indication
 * if the key got deleted by this function.
 *
 * Side effects visible here: server.dirty is incremented once per popped
 * element, a "zpopmin"/"zpopmax" keyspace event is fired once, the key is
 * deleted (with a "del" event) when it becomes empty, and when the running
 * command is ZMPOP the client command vector is rewritten as
 * ZPOP[MIN|MAX] key count.
 * */
void genericZpopCommand(client *c, robj **keyv, int keyc, int where, int emitkey,
                        long count, int use_nested_array, int reply_nil_when_empty, int *deleted) {
    int idx;
    robj *key = NULL;
    robj *zobj = NULL;
    sds ele;
    double score;

    if (deleted) *deleted = 0;

    /* Check type and break on the first error, otherwise identify candidate. */
    idx = 0;
    while (idx < keyc) {
        key = keyv[idx++];
        zobj = lookupKeyWrite(c->db,key);
        if (!zobj) continue;
        if (checkType(c,zobj,OBJ_ZSET)) return;
        break;
    }

    /* No candidate for zpopping, return empty. */
    if (!zobj) {
        if (reply_nil_when_empty) {
            addReplyNullArray(c);
        } else {
            addReply(c,shared.emptyarray);
        }
        return;
    }

    if (count == 0) {
        /* ZPOPMIN/ZPOPMAX with count 0. */
        addReply(c, shared.emptyarray);
        return;
    }

    long result_count = 0;

    /* When count is -1, we need to correct it to 1 for plain single pop. */
    if (count == -1) count = 1;

    /* Never pop more elements than the sorted set holds. */
    long llen = zsetLength(zobj);
    long rangelen = (count > llen) ? llen : count;

    /* Emit the reply header: its shape depends on nesting and on whether the
     * key name is part of the reply. */
    if (!use_nested_array && !emitkey) {
        /* ZPOPMIN/ZPOPMAX with or without COUNT option in RESP2. */
        addReplyArrayLen(c, rangelen * 2);
    } else if (use_nested_array && !emitkey) {
        /* ZPOPMIN/ZPOPMAX with COUNT option in RESP3. */
        addReplyArrayLen(c, rangelen);
    } else if (!use_nested_array && emitkey) {
        /* BZPOPMIN/BZPOPMAX in RESP2 and RESP3. */
        addReplyArrayLen(c, rangelen * 2 + 1);
        addReplyBulk(c, key);
    } else if (use_nested_array && emitkey) {
        /* ZMPOP/BZMPOP in RESP2 and RESP3. */
        addReplyArrayLen(c, 2);
        addReplyBulk(c, key);
        addReplyArrayLen(c, rangelen);
    }

    /* Remove the element. */
    /* NOTE(review): being a do/while, the body runs at least once, so this
     * relies on the sorted set being non-empty and on rangelen >= 1 here
     * (presumably guaranteed by the callers; confirm). */
    do {
        if (zobj->encoding == OBJ_ENCODING_LISTPACK) {
            unsigned char *zl = zobj->ptr;
            unsigned char *eptr, *sptr;
            unsigned char *vstr;
            unsigned int vlen;
            long long vlong;

            /* Get the first or last element in the sorted set. */
            /* Index -2 is used for the max since the member entry is followed
             * by its score entry (fetched with lpNext() below). */
            eptr = lpSeek(zl,where == ZSET_MAX ? -2 : 0);
            serverAssertWithInfo(c,zobj,eptr != NULL);
            /* Copy the member into a new sds string: a NULL vstr means the
             * value is returned as an integer in vlong. */
            vstr = lpGetValue(eptr,&vlen,&vlong);
            if (vstr == NULL)
                ele = sdsfromlonglong(vlong);
            else
                ele = sdsnewlen(vstr,vlen);

            /* Get the score. */
            sptr = lpNext(zl,eptr);
            serverAssertWithInfo(c,zobj,sptr != NULL);
            score = zzlGetScore(sptr);
        } else if (zobj->encoding == OBJ_ENCODING_SKIPLIST) {
            zset *zs = zobj->ptr;
            zskiplist *zsl = zs->zsl;
            zskiplistNode *zln;

            /* Get the first or last element in the sorted set. */
            zln = (where == ZSET_MAX ? zsl->tail :
                                       zsl->header->level[0].forward);

            /* There must be an element in the sorted set. */
            serverAssertWithInfo(c,zobj,zln != NULL);
            /* Duplicate the member: the copy is used for the deletion and
             * for the reply, then freed below. */
            ele = sdsdup(zln->ele);
            score = zln->score;
        } else {
            serverPanic("Unknown sorted set encoding");
        }

        /* The deletion is performed inside the assertion expression. */
        serverAssertWithInfo(c,zobj,zsetDel(zobj,ele));
        server.dirty++;

        if (result_count == 0) { /* Do this only for the first iteration. */
            /* Indexed by 'where': assumes ZSET_MIN is 0 and ZSET_MAX is 1
             * (defined outside this section; confirm). */
            char *events[2] = {"zpopmin","zpopmax"};
            notifyKeyspaceEvent(NOTIFY_ZSET,events[where],key,c->db->id);
            signalModifiedKey(c,c->db,key);
        }

        /* Each member + score pair gets its own array when nesting. */
        if (use_nested_array) {
            addReplyArrayLen(c,2);
        }
        addReplyBulkCBuffer(c,ele,sdslen(ele));
        addReplyDouble(c,score);
        sdsfree(ele);
        ++result_count;
    } while(--rangelen);

    /* Remove the key, if indeed needed. */
    if (zsetLength(zobj) == 0) {
        if (deleted) *deleted = 1;

        dbDelete(c->db,key);
        notifyKeyspaceEvent(NOTIFY_GENERIC,"del",key,c->db->id);
    }

    if (c->cmd->proc == zmpopCommand) {
        /* Always replicate it as ZPOP[MIN|MAX] with COUNT option instead of ZMPOP. */
        /* rangelen was consumed by the loop, so the effective count is
         * computed again from 'count' and the original length. */
        robj *count_obj = createStringObjectFromLongLong((count > llen) ? llen : count);
        rewriteClientCommandVector(c, 3,
                                   (where == ZSET_MAX) ? shared.zpopmax : shared.zpopmin,
                                   key, count_obj);
        decrRefCount(count_obj);
    }
}
3923 
3924 /* ZPOPMIN/ZPOPMAX key [<count>] */
zpopMinMaxCommand(client * c,int where)3925 void zpopMinMaxCommand(client *c, int where) {
3926     if (c->argc > 3) {
3927         addReplyErrorObject(c,shared.syntaxerr);
3928         return;
3929     }
3930 
3931     long count = -1; /* -1 for plain single pop. */
3932     if (c->argc == 3 && getPositiveLongFromObjectOrReply(c, c->argv[2], &count, NULL) != C_OK)
3933         return;
3934 
3935     /* Respond with a single (flat) array in RESP2 or if count is -1
3936      * (returning a single element). In RESP3, when count > 0 use nested array. */
3937     int use_nested_array = (c->resp > 2 && count != -1);
3938 
3939     genericZpopCommand(c, &c->argv[1], 1, where, 0, count, use_nested_array, 0, NULL);
3940 }
3941 
3942 /* ZPOPMIN key [<count>] */
zpopminCommand(client * c)3943 void zpopminCommand(client *c) {
3944     zpopMinMaxCommand(c, ZSET_MIN);
3945 }
3946 
3947 /* ZMAXPOP key [<count>] */
zpopmaxCommand(client * c)3948 void zpopmaxCommand(client *c) {
3949     zpopMinMaxCommand(c, ZSET_MAX);
3950 }
3951 
3952 /* BZPOPMIN, BZPOPMAX, BZMPOP actual implementation.
3953  *
3954  * 'numkeys' is the number of keys.
3955  *
3956  * 'timeout_idx' parameter position of block timeout.
3957  *
3958  * 'where' ZSET_MIN or ZSET_MAX.
3959  *
3960  * 'count' is the number of elements requested to pop, or -1 for plain single pop.
3961  *
3962  * 'use_nested_array' when false it generates a flat array (with or without key name).
3963  * When true, it generates a nested 3 level array of keyname, field + score pairs.
3964  * */
blockingGenericZpopCommand(client * c,robj ** keys,int numkeys,int where,int timeout_idx,long count,int use_nested_array,int reply_nil_when_empty)3965 void blockingGenericZpopCommand(client *c, robj **keys, int numkeys, int where,
3966                                 int timeout_idx, long count, int use_nested_array, int reply_nil_when_empty) {
3967     robj *o;
3968     robj *key;
3969     mstime_t timeout;
3970     int j;
3971 
3972     if (getTimeoutFromObjectOrReply(c,c->argv[timeout_idx],&timeout,UNIT_SECONDS)
3973         != C_OK) return;
3974 
3975     for (j = 0; j < numkeys; j++) {
3976         key = keys[j];
3977         o = lookupKeyWrite(c->db,key);
3978         /* Non-existing key, move to next key. */
3979         if (o == NULL) continue;
3980 
3981         if (checkType(c,o,OBJ_ZSET)) return;
3982 
3983         long llen = zsetLength(o);
3984         /* Empty zset, move to next key. */
3985         if (llen == 0) continue;
3986 
3987         /* Non empty zset, this is like a normal ZPOP[MIN|MAX]. */
3988         genericZpopCommand(c, &key, 1, where, 1, count, use_nested_array, reply_nil_when_empty, NULL);
3989 
3990         if (count == -1) {
3991             /* Replicate it as ZPOP[MIN|MAX] instead of BZPOP[MIN|MAX]. */
3992             rewriteClientCommandVector(c,2,
3993                                        (where == ZSET_MAX) ? shared.zpopmax : shared.zpopmin,
3994                                        key);
3995         } else {
3996             /* Replicate it as ZPOP[MIN|MAX] with COUNT option. */
3997             robj *count_obj = createStringObjectFromLongLong((count > llen) ? llen : count);
3998             rewriteClientCommandVector(c, 3,
3999                                        (where == ZSET_MAX) ? shared.zpopmax : shared.zpopmin,
4000                                        key, count_obj);
4001             decrRefCount(count_obj);
4002         }
4003 
4004         return;
4005     }
4006 
4007     /* If we are not allowed to block the client and the zset is empty the only thing
4008      * we can do is treating it as a timeout (even with timeout 0). */
4009     if (c->flags & CLIENT_DENY_BLOCKING) {
4010         addReplyNullArray(c);
4011         return;
4012     }
4013 
4014     /* If the keys do not exist we must block */
4015     struct blockPos pos = {where};
4016     blockForKeys(c,BLOCKED_ZSET,c->argv+1,c->argc-2,count,timeout,NULL,&pos,NULL);
4017 }
4018 
4019 // BZPOPMIN key [key ...] timeout
bzpopminCommand(client * c)4020 void bzpopminCommand(client *c) {
4021     blockingGenericZpopCommand(c, c->argv+1, c->argc-2, ZSET_MIN, c->argc-1, -1, 0, 0);
4022 }
4023 
4024 // BZPOPMAX key [key ...] timeout
bzpopmaxCommand(client * c)4025 void bzpopmaxCommand(client *c) {
4026     blockingGenericZpopCommand(c, c->argv+1, c->argc-2, ZSET_MAX, c->argc-1, -1, 0, 0);
4027 }
4028 
/* Emit 'count' sampled entries to the client.
 *
 * 'keys' holds the members and 'vals', when not NULL, the scores at the same
 * positions. An entry with a non-NULL 'sval' is sent from its sval/slen
 * buffer, otherwise its integer 'lval' is used. When scores are present and
 * the client speaks RESP3, each member/score pair is preceded by a two
 * element array header; in every other case the items are emitted one after
 * the other. No outer array header is written here. */
static void zarndmemberReplyWithListpack(client *c, unsigned int count, listpackEntry *keys, listpackEntry *vals) {
    for (unsigned long idx = 0; idx < count; idx++) {
        listpackEntry *member = &keys[idx];

        if (vals && c->resp > 2) {
            addReplyArrayLen(c,2);
        }

        /* Member: string form if available, integer form otherwise. */
        if (member->sval) {
            addReplyBulkCBuffer(c, member->sval, member->slen);
        } else {
            addReplyBulkLongLong(c, member->lval);
        }

        if (!vals) continue;

        /* Score: parsed from its string form, or taken from the integer. */
        listpackEntry *score = &vals[idx];
        if (score->sval) {
            addReplyDouble(c, zzlStrtod(score->sval,score->slen));
        } else {
            addReplyDouble(c, score->lval);
        }
    }
}
4045 
/* How many times bigger should be the zset compared to the requested size
 * for us to not use the "remove elements" strategy? Read later in the
 * implementation for more info.
 *
 * zrandmemberWithCountCommand() picks the "remove elements" strategy
 * (its CASE 3) when count * ZRANDMEMBER_SUB_STRATEGY_MUL > zset size. */
#define ZRANDMEMBER_SUB_STRATEGY_MUL 3

/* If client is trying to ask for a very large number of random elements,
 * queuing may consume an unlimited amount of memory, so we want to limit
 * the number of randoms per time.
 *
 * zrandmemberWithCountCommand() uses it as the maximum number of entries
 * sampled and replied per batch when serving a listpack encoded zset in the
 * path that allows repeated elements. */
#define ZRANDMEMBER_RANDOM_SAMPLE_LIMIT 1000
4055 
/* ZRANDMEMBER with the <count> argument.
 *
 * 'l' is the count as parsed from the command line. When it is >= 0 at most
 * 'l' distinct members are replied (never more than the zset holds). When it
 * is negative exactly |l| members are replied, each sampled independently so
 * that repetitions are possible.
 *
 * 'withscores' when true every member is followed by its score: as a flat
 * array of twice the length in RESP2, as an array of two element arrays in
 * RESP3.
 *
 * The key is c->argv[1]; shared.emptyarray is handed to
 * lookupKeyReadOrReply() as the reply for the lookup-failure case.
 *
 * NOTE(review): with 'withscores' in RESP2 the array length is count*2 and
 * is not range checked here; callers are expected to reject counts whose
 * double does not fit a long. */
void zrandmemberWithCountCommand(client *c, long l, int withscores) {
    unsigned long count, size;
    int uniq = 1;
    robj *zsetobj;

    if ((zsetobj = lookupKeyReadOrReply(c, c->argv[1], shared.emptyarray))
        == NULL || checkType(c, zsetobj, OBJ_ZSET)) return;
    size = zsetLength(zsetobj);

    if(l >= 0) {
        count = (unsigned long) l;
    } else {
        /* Negate in the unsigned domain: -l on the signed value overflows
         * (undefined behavior) when l is LONG_MIN, while this form is well
         * defined and yields the magnitude of l for every negative input. */
        count = 0UL - (unsigned long) l;
        uniq = 0;
    }

    /* If count is zero, serve it ASAP to avoid special cases later. */
    if (count == 0) {
        addReply(c,shared.emptyarray);
        return;
    }

    /* CASE 1: The count was negative, so the extraction method is just:
     * "return N random elements" sampling the whole set every time.
     * This case is trivial and can be served without auxiliary data
     * structures. This case is the only one that also needs to return the
     * elements in random order.
     * A positive count of exactly 1 is served here as well. */
    if (!uniq || count == 1) {
        if (withscores && c->resp == 2)
            addReplyArrayLen(c, count*2);
        else
            addReplyArrayLen(c, count);
        if (zsetobj->encoding == OBJ_ENCODING_SKIPLIST) {
            zset *zs = zsetobj->ptr;
            while (count--) {
                dictEntry *de = dictGetFairRandomKey(zs->dict);
                sds key = dictGetKey(de);
                if (withscores && c->resp > 2)
                    addReplyArrayLen(c,2);
                addReplyBulkCBuffer(c, key, sdslen(key));
                if (withscores)
                    addReplyDouble(c, *(double*)dictGetVal(de));
            }
        } else if (zsetobj->encoding == OBJ_ENCODING_LISTPACK) {
            listpackEntry *keys, *vals = NULL;
            unsigned long limit, sample_count;
            /* Sample and reply in batches of at most
             * ZRANDMEMBER_RANDOM_SAMPLE_LIMIT entries so the scratch arrays
             * stay bounded whatever the requested count is. */
            limit = count > ZRANDMEMBER_RANDOM_SAMPLE_LIMIT ? ZRANDMEMBER_RANDOM_SAMPLE_LIMIT : count;
            keys = zmalloc(sizeof(listpackEntry)*limit);
            if (withscores)
                vals = zmalloc(sizeof(listpackEntry)*limit);
            while (count) {
                sample_count = count > limit ? limit : count;
                count -= sample_count;
                lpRandomPairs(zsetobj->ptr, sample_count, keys, vals);
                zarndmemberReplyWithListpack(c, sample_count, keys, vals);
            }
            zfree(keys);
            zfree(vals);
        }
        return;
    }

    zsetopsrc src;
    zsetopval zval;
    src.subject = zsetobj;
    src.type = zsetobj->type;
    src.encoding = zsetobj->encoding;
    zuiInitIterator(&src);
    memset(&zval, 0, sizeof(zval));

    /* Initiate reply count, RESP3 responds with nested array, RESP2 with flat one. */
    long reply_size = count < size ? count : size;
    if (withscores && c->resp == 2)
        addReplyArrayLen(c, reply_size*2);
    else
        addReplyArrayLen(c, reply_size);

    /* CASE 2:
    * The number of requested elements is greater than the number of
    * elements inside the zset: simply return the whole zset. */
    if (count >= size) {
        while (zuiNext(&src, &zval)) {
            if (withscores && c->resp > 2)
                addReplyArrayLen(c,2);
            addReplyBulkSds(c, zuiNewSdsFromValue(&zval));
            if (withscores)
                addReplyDouble(c, zval.score);
        }
        zuiClearIterator(&src);
        return;
    }

    /* CASE 3:
     * The number of elements inside the zset is not greater than
     * ZRANDMEMBER_SUB_STRATEGY_MUL times the number of requested elements.
     * In this case we create a dict from scratch with all the elements, and
     * subtract random elements to reach the requested number of elements.
     *
     * This is done because if the number of requested elements is just
     * a bit less than the number of elements in the set, the natural approach
     * used into CASE 4 is highly inefficient. */
    if (count*ZRANDMEMBER_SUB_STRATEGY_MUL > size) {
        dict *d = dictCreate(&sdsReplyDictType);
        dictExpand(d, size);
        /* Add all the elements into the temporary dictionary. */
        while (zuiNext(&src, &zval)) {
            sds key = zuiNewSdsFromValue(&zval);
            dictEntry *de = dictAddRaw(d, key, NULL);
            serverAssert(de);
            if (withscores)
                dictSetDoubleVal(de, zval.score);
        }
        serverAssert(dictSize(d) == size);

        /* Remove random elements to reach the right count. */
        while (size > count) {
            dictEntry *de;
            de = dictGetFairRandomKey(d);
            dictUnlink(d,dictGetKey(de));
            sdsfree(dictGetKey(de));
            dictFreeUnlinkedEntry(d,de);
            size--;
        }

        /* Reply with what's in the dict and release memory */
        dictIterator *di;
        dictEntry *de;
        di = dictGetIterator(d);
        while ((de = dictNext(di)) != NULL) {
            if (withscores && c->resp > 2)
                addReplyArrayLen(c,2);
            addReplyBulkSds(c, dictGetKey(de));
            if (withscores)
                addReplyDouble(c, dictGetDoubleVal(de));
        }

        dictReleaseIterator(di);
        dictRelease(d);
    }

    /* CASE 4: We have a big zset compared to the requested number of elements.
     * In this case we can simply get random elements from the zset and add
     * to the temporary set, trying to eventually get enough unique elements
     * to reach the specified count. */
    else {
        if (zsetobj->encoding == OBJ_ENCODING_LISTPACK) {
            /* it is inefficient to repeatedly pick one random element from a
             * listpack. so we use this instead: */
            listpackEntry *keys, *vals = NULL;
            keys = zmalloc(sizeof(listpackEntry)*count);
            if (withscores)
                vals = zmalloc(sizeof(listpackEntry)*count);
            serverAssert(lpRandomPairsUnique(zsetobj->ptr, count, keys, vals) == count);
            zarndmemberReplyWithListpack(c, count, keys, vals);
            zfree(keys);
            zfree(vals);
            zuiClearIterator(&src);
            return;
        }

        /* Non-listpack encoding (generic implementation): 'd' only tracks
         * which members were already replied. */
        unsigned long added = 0;
        dict *d = dictCreate(&hashDictType);
        dictExpand(d, count);

        while (added < count) {
            listpackEntry key;
            double score;
            zsetTypeRandomElement(zsetobj, size, &key, withscores ? &score: NULL);

            /* Try to add the object to the dictionary. If it already exists
            * free it, otherwise increment the number of objects we have
            * in the result dictionary. */
            sds skey = zsetSdsFromListpackEntry(&key);
            if (dictAdd(d,skey,NULL) != DICT_OK) {
                sdsfree(skey);
                continue;
            }
            added++;

            if (withscores && c->resp > 2)
                addReplyArrayLen(c,2);
            zsetReplyFromListpackEntry(c, &key);
            if (withscores)
                addReplyDouble(c, score);
        }

        /* Release memory */
        dictRelease(d);
    }
    zuiClearIterator(&src);
}
4248 
4249 /* ZRANDMEMBER key [<count> [WITHSCORES]] */
zrandmemberCommand(client * c)4250 void zrandmemberCommand(client *c) {
4251     long l;
4252     int withscores = 0;
4253     robj *zset;
4254     listpackEntry ele;
4255 
4256     if (c->argc >= 3) {
4257         if (getLongFromObjectOrReply(c,c->argv[2],&l,NULL) != C_OK) return;
4258         if (c->argc > 4 || (c->argc == 4 && strcasecmp(c->argv[3]->ptr,"withscores"))) {
4259             addReplyErrorObject(c,shared.syntaxerr);
4260             return;
4261         } else if (c->argc == 4)
4262             withscores = 1;
4263         zrandmemberWithCountCommand(c, l, withscores);
4264         return;
4265     }
4266 
4267     /* Handle variant without <count> argument. Reply with simple bulk string */
4268     if ((zset = lookupKeyReadOrReply(c,c->argv[1],shared.null[c->resp]))== NULL ||
4269         checkType(c,zset,OBJ_ZSET)) {
4270         return;
4271     }
4272 
4273     zsetTypeRandomElement(zset, zsetLength(zset), &ele,NULL);
4274     zsetReplyFromListpackEntry(c,&ele);
4275 }
4276 
4277 /* ZMPOP/BZMPOP
4278  *
4279  * 'numkeys_idx' parameter position of key number.
4280  * 'is_block' this indicates whether it is a blocking variant. */
zmpopGenericCommand(client * c,int numkeys_idx,int is_block)4281 void zmpopGenericCommand(client *c, int numkeys_idx, int is_block) {
4282     long j;
4283     long numkeys = 0;      /* Number of keys. */
4284     int where = 0;         /* ZSET_MIN or ZSET_MAX. */
4285     long count = -1;       /* Reply will consist of up to count elements, depending on the zset's length. */
4286 
4287     /* Parse the numkeys. */
4288     if (getRangeLongFromObjectOrReply(c, c->argv[numkeys_idx], 1, LONG_MAX,
4289                                       &numkeys, "numkeys should be greater than 0") != C_OK)
4290         return;
4291 
4292     /* Parse the where. where_idx: the index of where in the c->argv. */
4293     long where_idx = numkeys_idx + numkeys + 1;
4294     if (where_idx >= c->argc) {
4295         addReplyErrorObject(c, shared.syntaxerr);
4296         return;
4297     }
4298     if (!strcasecmp(c->argv[where_idx]->ptr, "MIN")) {
4299         where = ZSET_MIN;
4300     } else if (!strcasecmp(c->argv[where_idx]->ptr, "MAX")) {
4301         where = ZSET_MAX;
4302     } else {
4303         addReplyErrorObject(c, shared.syntaxerr);
4304         return;
4305     }
4306 
4307     /* Parse the optional arguments. */
4308     for (j = where_idx + 1; j < c->argc; j++) {
4309         char *opt = c->argv[j]->ptr;
4310         int moreargs = (c->argc - 1) - j;
4311 
4312         if (count == -1 && !strcasecmp(opt, "COUNT") && moreargs) {
4313             j++;
4314             if (getRangeLongFromObjectOrReply(c, c->argv[j], 1, LONG_MAX,
4315                                               &count,"count should be greater than 0") != C_OK)
4316                 return;
4317         } else {
4318             addReplyErrorObject(c, shared.syntaxerr);
4319             return;
4320         }
4321     }
4322 
4323     if (count == -1) count = 1;
4324 
4325     if (is_block) {
4326         /* BLOCK. We will handle CLIENT_DENY_BLOCKING flag in blockingGenericZpopCommand. */
4327         blockingGenericZpopCommand(c, c->argv+numkeys_idx+1, numkeys, where, 1, count, 1, 1);
4328     } else {
4329         /* NON-BLOCK */
4330         genericZpopCommand(c, c->argv+numkeys_idx+1, numkeys, where, 1, count, 1, 1, NULL);
4331     }
4332 }
4333 
4334 /* ZMPOP numkeys [<key> ...] MIN|MAX [COUNT count] */
zmpopCommand(client * c)4335 void zmpopCommand(client *c) {
4336     zmpopGenericCommand(c, 1, 0);
4337 }
4338 
4339 /* BZMPOP timeout numkeys [<key> ...] MIN|MAX [COUNT count] */
bzmpopCommand(client * c)4340 void bzmpopCommand(client *c) {
4341     zmpopGenericCommand(c, 2, 1);
4342 }
4343