1 /*
2  * Copyright (c) 2009-2012, Salvatore Sanfilippo <antirez at gmail dot com>
3  * All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions are met:
7  *
8  *   * Redistributions of source code must retain the above copyright notice,
9  *     this list of conditions and the following disclaimer.
10  *   * Redistributions in binary form must reproduce the above copyright
11  *     notice, this list of conditions and the following disclaimer in the
12  *     documentation and/or other materials provided with the distribution.
13  *   * Neither the name of Redis nor the names of its contributors may be used
14  *     to endorse or promote products derived from this software without
15  *     specific prior written permission.
16  *
17  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
20  * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
21  * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
22  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
23  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
24  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
25  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
26  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
27  * POSSIBILITY OF SUCH DAMAGE.
28  */
29 
30 #include "server.h"
31 
32 /*-----------------------------------------------------------------------------
33  * Set Commands
34  *----------------------------------------------------------------------------*/
35 
36 void sunionDiffGenericCommand(client *c, robj **setkeys, int setnum,
37                               robj *dstkey, int op);
38 
39 /* Factory method to return a set that *can* hold "value". When the object has
40  * an integer-encodable value, an intset will be returned. Otherwise a regular
41  * hash table. */
setTypeCreate(sds value)42 robj *setTypeCreate(sds value) {
43     if (isSdsRepresentableAsLongLong(value,NULL) == C_OK)
44         return createIntsetObject();
45     return createSetObject();
46 }
47 
48 /* Add the specified value into a set.
49  *
50  * If the value was already member of the set, nothing is done and 0 is
51  * returned, otherwise the new element is added and 1 is returned. */
setTypeAdd(robj * subject,sds value)52 int setTypeAdd(robj *subject, sds value) {
53     long long llval;
54     if (subject->encoding == OBJ_ENCODING_HT) {
55         dict *ht = subject->ptr;
56         dictEntry *de = dictAddRaw(ht,value,NULL);
57         if (de) {
58             dictSetKey(ht,de,sdsdup(value));
59             dictSetVal(ht,de,NULL);
60             return 1;
61         }
62     } else if (subject->encoding == OBJ_ENCODING_INTSET) {
63         if (isSdsRepresentableAsLongLong(value,&llval) == C_OK) {
64             uint8_t success = 0;
65             subject->ptr = intsetAdd(subject->ptr,llval,&success);
66             if (success) {
67                 /* Convert to regular set when the intset contains
68                  * too many entries. */
69                 size_t max_entries = server.set_max_intset_entries;
70                 /* limit to 1G entries due to intset internals. */
71                 if (max_entries >= 1<<30) max_entries = 1<<30;
72                 if (intsetLen(subject->ptr) > max_entries)
73                     setTypeConvert(subject,OBJ_ENCODING_HT);
74                 return 1;
75             }
76         } else {
77             /* Failed to get integer from object, convert to regular set. */
78             setTypeConvert(subject,OBJ_ENCODING_HT);
79 
80             /* The set *was* an intset and this value is not integer
81              * encodable, so dictAdd should always work. */
82             serverAssert(dictAdd(subject->ptr,sdsdup(value),NULL) == DICT_OK);
83             return 1;
84         }
85     } else {
86         serverPanic("Unknown set encoding");
87     }
88     return 0;
89 }
90 
/* Delete "value" from the set "setobj".
 *
 * Returns 1 if the element was removed, 0 if it was not found (a value that
 * is not an integer is never found in an intset encoded set). For hash table
 * encoded sets the table is resized when htNeedsResize() asks for it. */
int setTypeRemove(robj *setobj, sds value) {
    if (setobj->encoding == OBJ_ENCODING_HT) {
        if (dictDelete(setobj->ptr,value) != DICT_OK) return 0;
        if (htNeedsResize(setobj->ptr)) dictResize(setobj->ptr);
        return 1;
    }

    if (setobj->encoding == OBJ_ENCODING_INTSET) {
        long long number;
        int removed;

        if (isSdsRepresentableAsLongLong(value,&number) != C_OK) return 0;
        /* intsetRemove() returns the (possibly new) intset pointer. */
        setobj->ptr = intsetRemove(setobj->ptr,number,&removed);
        return removed ? 1 : 0;
    }

    serverPanic("Unknown set encoding");
    return 0;
}
109 
/* Membership test: returns non-zero when "value" is an element of "subject",
 * zero otherwise. For intset encoded sets a value that is not an integer is
 * reported as missing without touching the intset. */
int setTypeIsMember(robj *subject, sds value) {
    switch (subject->encoding) {
    case OBJ_ENCODING_HT:
        return dictFind((dict*)subject->ptr,value) != NULL;
    case OBJ_ENCODING_INTSET: {
        long long number;

        if (isSdsRepresentableAsLongLong(value,&number) != C_OK) return 0;
        return intsetFind((intset*)subject->ptr,number);
    }
    default:
        serverPanic("Unknown set encoding");
    }
    return 0;
}
123 
setTypeInitIterator(robj * subject)124 setTypeIterator *setTypeInitIterator(robj *subject) {
125     setTypeIterator *si = zmalloc(sizeof(setTypeIterator));
126     si->subject = subject;
127     si->encoding = subject->encoding;
128     if (si->encoding == OBJ_ENCODING_HT) {
129         si->di = dictGetIterator(subject->ptr);
130     } else if (si->encoding == OBJ_ENCODING_INTSET) {
131         si->ii = 0;
132     } else {
133         serverPanic("Unknown set encoding");
134     }
135     return si;
136 }
137 
/* Free an iterator created by setTypeInitIterator(). The inner dict iterator
 * exists only for hash table encoded sets, so it is released only then. */
void setTypeReleaseIterator(setTypeIterator *si) {
    int has_dict_iterator = (si->encoding == OBJ_ENCODING_HT);

    if (has_dict_iterator) dictReleaseIterator(si->di);
    zfree(si);
}
143 
144 /* Move to the next entry in the set. Returns the object at the current
145  * position.
146  *
147  * Since set elements can be internally be stored as SDS strings or
148  * simple arrays of integers, setTypeNext returns the encoding of the
149  * set object you are iterating, and will populate the appropriate pointer
150  * (sdsele) or (llele) accordingly.
151  *
152  * Note that both the sdsele and llele pointers should be passed and cannot
153  * be NULL since the function will try to defensively populate the non
154  * used field with values which are easy to trap if misused.
155  *
156  * When there are no longer elements -1 is returned. */
setTypeNext(setTypeIterator * si,sds * sdsele,int64_t * llele)157 int setTypeNext(setTypeIterator *si, sds *sdsele, int64_t *llele) {
158     if (si->encoding == OBJ_ENCODING_HT) {
159         dictEntry *de = dictNext(si->di);
160         if (de == NULL) return -1;
161         *sdsele = dictGetKey(de);
162         *llele = -123456789; /* Not needed. Defensive. */
163     } else if (si->encoding == OBJ_ENCODING_INTSET) {
164         if (!intsetGet(si->subject->ptr,si->ii++,llele))
165             return -1;
166         *sdsele = NULL; /* Not needed. Defensive. */
167     } else {
168         serverPanic("Wrong set encoding in setTypeNext");
169     }
170     return si->encoding;
171 }
172 
173 /* The not copy on write friendly version but easy to use version
174  * of setTypeNext() is setTypeNextObject(), returning new SDS
175  * strings. So if you don't retain a pointer to this object you should call
176  * sdsfree() against it.
177  *
178  * This function is the way to go for write operations where COW is not
179  * an issue. */
setTypeNextObject(setTypeIterator * si)180 sds setTypeNextObject(setTypeIterator *si) {
181     int64_t intele;
182     sds sdsele;
183     int encoding;
184 
185     encoding = setTypeNext(si,&sdsele,&intele);
186     switch(encoding) {
187         case -1:    return NULL;
188         case OBJ_ENCODING_INTSET:
189             return sdsfromlonglong(intele);
190         case OBJ_ENCODING_HT:
191             return sdsdup(sdsele);
192         default:
193             serverPanic("Unsupported encoding");
194     }
195     return NULL; /* just to suppress warnings */
196 }
197 
198 /* Return random element from a non empty set.
199  * The returned element can be an int64_t value if the set is encoded
200  * as an "intset" blob of integers, or an SDS string if the set
201  * is a regular set.
202  *
203  * The caller provides both pointers to be populated with the right
204  * object. The return value of the function is the object->encoding
205  * field of the object and is used by the caller to check if the
206  * int64_t pointer or the redis object pointer was populated.
207  *
208  * Note that both the sdsele and llele pointers should be passed and cannot
209  * be NULL since the function will try to defensively populate the non
210  * used field with values which are easy to trap if misused. */
setTypeRandomElement(robj * setobj,sds * sdsele,int64_t * llele)211 int setTypeRandomElement(robj *setobj, sds *sdsele, int64_t *llele) {
212     if (setobj->encoding == OBJ_ENCODING_HT) {
213         dictEntry *de = dictGetFairRandomKey(setobj->ptr);
214         *sdsele = dictGetKey(de);
215         *llele = -123456789; /* Not needed. Defensive. */
216     } else if (setobj->encoding == OBJ_ENCODING_INTSET) {
217         *llele = intsetRandom(setobj->ptr);
218         *sdsele = NULL; /* Not needed. Defensive. */
219     } else {
220         serverPanic("Unknown set encoding");
221     }
222     return setobj->encoding;
223 }
224 
setTypeSize(const robj * subject)225 unsigned long setTypeSize(const robj *subject) {
226     if (subject->encoding == OBJ_ENCODING_HT) {
227         return dictSize((const dict*)subject->ptr);
228     } else if (subject->encoding == OBJ_ENCODING_INTSET) {
229         return intsetLen((const intset*)subject->ptr);
230     } else {
231         serverPanic("Unknown set encoding");
232     }
233 }
234 
235 /* Convert the set to specified encoding. The resulting dict (when converting
236  * to a hash table) is presized to hold the number of elements in the original
237  * set. */
setTypeConvert(robj * setobj,int enc)238 void setTypeConvert(robj *setobj, int enc) {
239     setTypeIterator *si;
240     serverAssertWithInfo(NULL,setobj,setobj->type == OBJ_SET &&
241                              setobj->encoding == OBJ_ENCODING_INTSET);
242 
243     if (enc == OBJ_ENCODING_HT) {
244         int64_t intele;
245         dict *d = dictCreate(&setDictType,NULL);
246         sds element;
247 
248         /* Presize the dict to avoid rehashing */
249         dictExpand(d,intsetLen(setobj->ptr));
250 
251         /* To add the elements we extract integers and create redis objects */
252         si = setTypeInitIterator(setobj);
253         while (setTypeNext(si,&element,&intele) != -1) {
254             element = sdsfromlonglong(intele);
255             serverAssert(dictAdd(d,element,NULL) == DICT_OK);
256         }
257         setTypeReleaseIterator(si);
258 
259         setobj->encoding = OBJ_ENCODING_HT;
260         zfree(setobj->ptr);
261         setobj->ptr = d;
262     } else {
263         serverPanic("Unsupported set conversion");
264     }
265 }
266 
/* SADD key member [member ...]
 *
 * Adds every given member to the set stored at key, creating the set when
 * the key does not exist (the encoding of the new set is chosen looking at
 * the first member). Replies with the number of members actually added, or
 * with a wrong type error when the key holds something that is not a set. */
void saddCommand(client *c) {
    robj *key = c->argv[1];
    robj *set = lookupKeyWrite(c->db,key);
    int idx, added = 0;

    if (set != NULL && set->type != OBJ_SET) {
        addReply(c,shared.wrongtypeerr);
        return;
    }
    if (set == NULL) {
        set = setTypeCreate(c->argv[2]->ptr);
        dbAdd(c->db,key,set);
    }

    for (idx = 2; idx < c->argc; idx++)
        added += setTypeAdd(set,c->argv[idx]->ptr) ? 1 : 0;

    /* Signal the change only if at least one member was really added. */
    if (added > 0) {
        signalModifiedKey(c,c->db,key);
        notifyKeyspaceEvent(NOTIFY_SET,"sadd",key,c->db->id);
    }
    server.dirty += added;
    addReplyLongLong(c,added);
}
292 
/* SREM key member [member ...]
 *
 * Removes the given members from the set stored at key and replies with the
 * number of members removed. When the set becomes empty the key is deleted
 * and the remaining arguments are not processed. A missing key is answered
 * with zero by lookupKeyWriteOrReply(). */
void sremCommand(client *c) {
    robj *set = lookupKeyWriteOrReply(c,c->argv[1],shared.czero);
    int idx, removed = 0, emptied = 0;

    if (set == NULL) return;
    if (checkType(c,set,OBJ_SET)) return;

    for (idx = 2; idx < c->argc && !emptied; idx++) {
        if (!setTypeRemove(set,c->argv[idx]->ptr)) continue;
        removed++;
        if (setTypeSize(set) == 0) {
            /* Last element gone: drop the key and stop the loop. */
            dbDelete(c->db,c->argv[1]);
            emptied = 1;
        }
    }

    if (removed > 0) {
        signalModifiedKey(c,c->db,c->argv[1]);
        notifyKeyspaceEvent(NOTIFY_SET,"srem",c->argv[1],c->db->id);
        if (emptied) {
            notifyKeyspaceEvent(NOTIFY_GENERIC,"del",c->argv[1],
                                c->db->id);
        }
        server.dirty += removed;
    }
    addReplyLongLong(c,removed);
}
320 
/* SMOVE source destination member
 *
 * Moves "member" from the source set to the destination set. Replies 1 when
 * the member was found in the source set, 0 when the source key is missing
 * or does not contain the member. The destination set is created if needed
 * and the source key is deleted when it becomes empty. */
void smoveCommand(client *c) {
    robj *srckey = c->argv[1], *dstkey = c->argv[2];
    robj *source = lookupKeyWrite(c->db,srckey);
    robj *target = lookupKeyWrite(c->db,dstkey);
    sds member = c->argv[3]->ptr;

    /* Nothing to move from a source key that does not exist. */
    if (source == NULL) {
        addReply(c,shared.czero);
        return;
    }

    /* Stop if the source, or an existing destination, fails the set type
     * check. */
    if (checkType(c,source,OBJ_SET)) return;
    if (target != NULL && checkType(c,target,OBJ_SET)) return;

    /* Same set on both sides: no change, just report the membership. */
    if (source == target) {
        if (setTypeIsMember(source,member))
            addReply(c,shared.cone);
        else
            addReply(c,shared.czero);
        return;
    }

    /* The member must be removed from the source for the move to happen. */
    if (!setTypeRemove(source,member)) {
        addReply(c,shared.czero);
        return;
    }
    notifyKeyspaceEvent(NOTIFY_SET,"srem",srckey,c->db->id);

    /* An empty source set is removed from the database. */
    if (setTypeSize(source) == 0) {
        dbDelete(c->db,srckey);
        notifyKeyspaceEvent(NOTIFY_GENERIC,"del",srckey,c->db->id);
    }

    /* Create the destination on demand. */
    if (target == NULL) {
        target = setTypeCreate(member);
        dbAdd(c->db,dstkey,target);
    }

    signalModifiedKey(c,c->db,srckey);
    server.dirty++;

    /* The destination counts as modified only if the member was new. */
    if (setTypeAdd(target,member)) {
        server.dirty++;
        signalModifiedKey(c,c->db,dstkey);
        notifyKeyspaceEvent(NOTIFY_SET,"sadd",dstkey,c->db->id);
    }
    addReply(c,shared.cone);
}
375 
/* SISMEMBER key member
 *
 * Replies 1 if "member" belongs to the set stored at key, 0 otherwise
 * (a missing key is answered with 0 by lookupKeyReadOrReply()). */
void sismemberCommand(client *c) {
    robj *set = lookupKeyReadOrReply(c,c->argv[1],shared.czero);

    if (set == NULL || checkType(c,set,OBJ_SET)) return;

    addReply(c,setTypeIsMember(set,c->argv[2]->ptr) ?
        shared.cone : shared.czero);
}
387 
/* SCARD key
 *
 * Replies with the number of elements of the set stored at key (a missing
 * key is answered with 0 by lookupKeyReadOrReply()). */
void scardCommand(client *c) {
    robj *set = lookupKeyReadOrReply(c,c->argv[1],shared.czero);

    if (!set) return;
    if (checkType(c,set,OBJ_SET)) return;

    addReplyLongLong(c,setTypeSize(set));
}
396 
397 /* Handle the "SPOP key <count>" variant. The normal version of the
398  * command is handled by the spopCommand() function itself. */
399 
400 /* How many times bigger should be the set compared to the remaining size
401  * for us to use the "create new set" strategy? Read later in the
402  * implementation for more info. */
403 #define SPOP_MOVE_STRATEGY_MUL 5
404 
spopWithCountCommand(client * c)405 void spopWithCountCommand(client *c) {
406     long l;
407     unsigned long count, size;
408     robj *set;
409 
410     /* Get the count argument */
411     if (getLongFromObjectOrReply(c,c->argv[2],&l,NULL) != C_OK) return;
412     if (l >= 0) {
413         count = (unsigned long) l;
414     } else {
415         addReply(c,shared.outofrangeerr);
416         return;
417     }
418 
419     /* Make sure a key with the name inputted exists, and that it's type is
420      * indeed a set. Otherwise, return nil */
421     if ((set = lookupKeyWriteOrReply(c,c->argv[1],shared.emptyset[c->resp]))
422         == NULL || checkType(c,set,OBJ_SET)) return;
423 
424     /* If count is zero, serve an empty set ASAP to avoid special
425      * cases later. */
426     if (count == 0) {
427         addReply(c,shared.emptyset[c->resp]);
428         return;
429     }
430 
431     size = setTypeSize(set);
432 
433     /* Generate an SPOP keyspace notification */
434     notifyKeyspaceEvent(NOTIFY_SET,"spop",c->argv[1],c->db->id);
435     server.dirty += count;
436 
437     /* CASE 1:
438      * The number of requested elements is greater than or equal to
439      * the number of elements inside the set: simply return the whole set. */
440     if (count >= size) {
441         /* We just return the entire set */
442         sunionDiffGenericCommand(c,c->argv+1,1,NULL,SET_OP_UNION);
443 
444         /* Delete the set as it is now empty */
445         dbDelete(c->db,c->argv[1]);
446         notifyKeyspaceEvent(NOTIFY_GENERIC,"del",c->argv[1],c->db->id);
447 
448         /* Propagate this command as a DEL operation */
449         rewriteClientCommandVector(c,2,shared.del,c->argv[1]);
450         signalModifiedKey(c,c->db,c->argv[1]);
451         server.dirty++;
452         return;
453     }
454 
455     /* Case 2 and 3 require to replicate SPOP as a set of SREM commands.
456      * Prepare our replication argument vector. Also send the array length
457      * which is common to both the code paths. */
458     robj *propargv[3];
459     propargv[0] = createStringObject("SREM",4);
460     propargv[1] = c->argv[1];
461     addReplySetLen(c,count);
462 
463     /* Common iteration vars. */
464     sds sdsele;
465     robj *objele;
466     int encoding;
467     int64_t llele;
468     unsigned long remaining = size-count; /* Elements left after SPOP. */
469 
470     /* If we are here, the number of requested elements is less than the
471      * number of elements inside the set. Also we are sure that count < size.
472      * Use two different strategies.
473      *
474      * CASE 2: The number of elements to return is small compared to the
475      * set size. We can just extract random elements and return them to
476      * the set. */
477     if (remaining*SPOP_MOVE_STRATEGY_MUL > count) {
478         while(count--) {
479             /* Emit and remove. */
480             encoding = setTypeRandomElement(set,&sdsele,&llele);
481             if (encoding == OBJ_ENCODING_INTSET) {
482                 addReplyBulkLongLong(c,llele);
483                 objele = createStringObjectFromLongLong(llele);
484                 set->ptr = intsetRemove(set->ptr,llele,NULL);
485             } else {
486                 addReplyBulkCBuffer(c,sdsele,sdslen(sdsele));
487                 objele = createStringObject(sdsele,sdslen(sdsele));
488                 setTypeRemove(set,sdsele);
489             }
490 
491             /* Replicate/AOF this command as an SREM operation */
492             propargv[2] = objele;
493             alsoPropagate(server.sremCommand,c->db->id,propargv,3,
494                 PROPAGATE_AOF|PROPAGATE_REPL);
495             decrRefCount(objele);
496         }
497     } else {
498     /* CASE 3: The number of elements to return is very big, approaching
499      * the size of the set itself. After some time extracting random elements
500      * from such a set becomes computationally expensive, so we use
501      * a different strategy, we extract random elements that we don't
502      * want to return (the elements that will remain part of the set),
503      * creating a new set as we do this (that will be stored as the original
504      * set). Then we return the elements left in the original set and
505      * release it. */
506         robj *newset = NULL;
507 
508         /* Create a new set with just the remaining elements. */
509         while(remaining--) {
510             encoding = setTypeRandomElement(set,&sdsele,&llele);
511             if (encoding == OBJ_ENCODING_INTSET) {
512                 sdsele = sdsfromlonglong(llele);
513             } else {
514                 sdsele = sdsdup(sdsele);
515             }
516             if (!newset) newset = setTypeCreate(sdsele);
517             setTypeAdd(newset,sdsele);
518             setTypeRemove(set,sdsele);
519             sdsfree(sdsele);
520         }
521 
522         /* Transfer the old set to the client. */
523         setTypeIterator *si;
524         si = setTypeInitIterator(set);
525         while((encoding = setTypeNext(si,&sdsele,&llele)) != -1) {
526             if (encoding == OBJ_ENCODING_INTSET) {
527                 addReplyBulkLongLong(c,llele);
528                 objele = createStringObjectFromLongLong(llele);
529             } else {
530                 addReplyBulkCBuffer(c,sdsele,sdslen(sdsele));
531                 objele = createStringObject(sdsele,sdslen(sdsele));
532             }
533 
534             /* Replicate/AOF this command as an SREM operation */
535             propargv[2] = objele;
536             alsoPropagate(server.sremCommand,c->db->id,propargv,3,
537                 PROPAGATE_AOF|PROPAGATE_REPL);
538             decrRefCount(objele);
539         }
540         setTypeReleaseIterator(si);
541 
542         /* Assign the new set as the key value. */
543         dbOverwrite(c->db,c->argv[1],newset);
544     }
545 
546     /* Don't propagate the command itself even if we incremented the
547      * dirty counter. We don't want to propagate an SPOP command since
548      * we propagated the command as a set of SREMs operations using
549      * the alsoPropagate() API. */
550     decrRefCount(propargv[0]);
551     preventCommandPropagation(c);
552     signalModifiedKey(c,c->db,c->argv[1]);
553     server.dirty++;
554 }
555 
/* SPOP key [count]
 *
 * Without a count, removes one random element from the set and sends it to
 * the client (a missing key is answered with a null reply). The variant
 * with a count is delegated to spopWithCountCommand(); more than three
 * arguments are a syntax error.
 *
 * The command vector is rewritten as "SREM key element", so that is what
 * gets replicated / written to the AOF. */
void spopCommand(client *c) {
    robj *set, *popped, *srem;
    sds member;
    int64_t number;

    if (c->argc > 3) {
        addReply(c,shared.syntaxerr);
        return;
    }
    if (c->argc == 3) {
        spopWithCountCommand(c);
        return;
    }

    /* The key must exist and must hold a set. */
    set = lookupKeyWriteOrReply(c,c->argv[1],shared.null[c->resp]);
    if (set == NULL || checkType(c,set,OBJ_SET)) return;

    /* Pick a random element and remove it, keeping a string object copy
     * to be used both for the reply and for the propagated command. */
    if (setTypeRandomElement(set,&member,&number) == OBJ_ENCODING_INTSET) {
        popped = createStringObjectFromLongLong(number);
        set->ptr = intsetRemove(set->ptr,number,NULL);
    } else {
        popped = createStringObject(member,sdslen(member));
        setTypeRemove(set,popped->ptr);
    }

    notifyKeyspaceEvent(NOTIFY_SET,"spop",c->argv[1],c->db->id);

    /* Replicate/AOF this command as an SREM operation. */
    srem = createStringObject("SREM",4);
    rewriteClientCommandVector(c,3,srem,c->argv[1],popped);
    decrRefCount(srem);

    /* Send the popped element to the client. */
    addReplyBulk(c,popped);
    decrRefCount(popped);

    /* An empty set is removed from the database. */
    if (setTypeSize(set) == 0) {
        dbDelete(c->db,c->argv[1]);
        notifyKeyspaceEvent(NOTIFY_GENERIC,"del",c->argv[1],c->db->id);
    }

    /* The set was modified in any case. */
    signalModifiedKey(c,c->db,c->argv[1]);
    server.dirty++;
}
608 
609 /* handle the "SRANDMEMBER key <count>" variant. The normal version of the
610  * command is handled by the srandmemberCommand() function itself. */
611 
612 /* How many times bigger should be the set compared to the requested size
613  * for us to don't use the "remove elements" strategy? Read later in the
614  * implementation for more info. */
615 #define SRANDMEMBER_SUB_STRATEGY_MUL 3
616 
srandmemberWithCountCommand(client * c)617 void srandmemberWithCountCommand(client *c) {
618     long l;
619     unsigned long count, size;
620     int uniq = 1;
621     robj *set;
622     sds ele;
623     int64_t llele;
624     int encoding;
625 
626     dict *d;
627 
628     if (getLongFromObjectOrReply(c,c->argv[2],&l,NULL) != C_OK) return;
629     if (l >= 0) {
630         count = (unsigned long) l;
631     } else {
632         /* A negative count means: return the same elements multiple times
633          * (i.e. don't remove the extracted element after every extraction). */
634         count = -l;
635         uniq = 0;
636     }
637 
638     if ((set = lookupKeyReadOrReply(c,c->argv[1],shared.emptyset[c->resp]))
639         == NULL || checkType(c,set,OBJ_SET)) return;
640     size = setTypeSize(set);
641 
642     /* If count is zero, serve it ASAP to avoid special cases later. */
643     if (count == 0) {
644         addReply(c,shared.emptyset[c->resp]);
645         return;
646     }
647 
648     /* CASE 1: The count was negative, so the extraction method is just:
649      * "return N random elements" sampling the whole set every time.
650      * This case is trivial and can be served without auxiliary data
651      * structures. */
652     if (!uniq) {
653         addReplySetLen(c,count);
654         while(count--) {
655             encoding = setTypeRandomElement(set,&ele,&llele);
656             if (encoding == OBJ_ENCODING_INTSET) {
657                 addReplyBulkLongLong(c,llele);
658             } else {
659                 addReplyBulkCBuffer(c,ele,sdslen(ele));
660             }
661         }
662         return;
663     }
664 
665     /* CASE 2:
666      * The number of requested elements is greater than the number of
667      * elements inside the set: simply return the whole set. */
668     if (count >= size) {
669         sunionDiffGenericCommand(c,c->argv+1,1,NULL,SET_OP_UNION);
670         return;
671     }
672 
673     /* For CASE 3 and CASE 4 we need an auxiliary dictionary. */
674     d = dictCreate(&objectKeyPointerValueDictType,NULL);
675 
676     /* CASE 3:
677      * The number of elements inside the set is not greater than
678      * SRANDMEMBER_SUB_STRATEGY_MUL times the number of requested elements.
679      * In this case we create a set from scratch with all the elements, and
680      * subtract random elements to reach the requested number of elements.
681      *
682      * This is done because if the number of requested elements is just
683      * a bit less than the number of elements in the set, the natural approach
684      * used into CASE 3 is highly inefficient. */
685     if (count*SRANDMEMBER_SUB_STRATEGY_MUL > size) {
686         setTypeIterator *si;
687 
688         /* Add all the elements into the temporary dictionary. */
689         si = setTypeInitIterator(set);
690         while((encoding = setTypeNext(si,&ele,&llele)) != -1) {
691             int retval = DICT_ERR;
692 
693             if (encoding == OBJ_ENCODING_INTSET) {
694                 retval = dictAdd(d,createStringObjectFromLongLong(llele),NULL);
695             } else {
696                 retval = dictAdd(d,createStringObject(ele,sdslen(ele)),NULL);
697             }
698             serverAssert(retval == DICT_OK);
699         }
700         setTypeReleaseIterator(si);
701         serverAssert(dictSize(d) == size);
702 
703         /* Remove random elements to reach the right count. */
704         while(size > count) {
705             dictEntry *de;
706 
707             de = dictGetRandomKey(d);
708             dictDelete(d,dictGetKey(de));
709             size--;
710         }
711     }
712 
713     /* CASE 4: We have a big set compared to the requested number of elements.
714      * In this case we can simply get random elements from the set and add
715      * to the temporary set, trying to eventually get enough unique elements
716      * to reach the specified count. */
717     else {
718         unsigned long added = 0;
719         robj *objele;
720 
721         while(added < count) {
722             encoding = setTypeRandomElement(set,&ele,&llele);
723             if (encoding == OBJ_ENCODING_INTSET) {
724                 objele = createStringObjectFromLongLong(llele);
725             } else {
726                 objele = createStringObject(ele,sdslen(ele));
727             }
728             /* Try to add the object to the dictionary. If it already exists
729              * free it, otherwise increment the number of objects we have
730              * in the result dictionary. */
731             if (dictAdd(d,objele,NULL) == DICT_OK)
732                 added++;
733             else
734                 decrRefCount(objele);
735         }
736     }
737 
738     /* CASE 3 & 4: send the result to the user. */
739     {
740         dictIterator *di;
741         dictEntry *de;
742 
743         addReplySetLen(c,count);
744         di = dictGetIterator(d);
745         while((de = dictNext(di)) != NULL)
746             addReplyBulk(c,dictGetKey(de));
747         dictReleaseIterator(di);
748         dictRelease(d);
749     }
750 }
751 
/* SRANDMEMBER key [count]
 *
 * Without a count, replies with one random element of the set, leaving the
 * set untouched (a missing key is answered with a null reply). The variant
 * with a count is delegated to srandmemberWithCountCommand(); more than
 * three arguments are a syntax error. */
void srandmemberCommand(client *c) {
    robj *set;
    sds member;
    int64_t number;

    if (c->argc > 3) {
        addReply(c,shared.syntaxerr);
        return;
    }
    if (c->argc == 3) {
        srandmemberWithCountCommand(c);
        return;
    }

    set = lookupKeyReadOrReply(c,c->argv[1],shared.null[c->resp]);
    if (set == NULL || checkType(c,set,OBJ_SET)) return;

    /* The reply format depends on how the element is stored. */
    if (setTypeRandomElement(set,&member,&number) == OBJ_ENCODING_INTSET)
        addReplyBulkLongLong(c,number);
    else
        addReplyBulkCBuffer(c,member,sdslen(member));
}
776 
qsortCompareSetsByCardinality(const void * s1,const void * s2)777 int qsortCompareSetsByCardinality(const void *s1, const void *s2) {
778     if (setTypeSize(*(robj**)s1) > setTypeSize(*(robj**)s2)) return 1;
779     if (setTypeSize(*(robj**)s1) < setTypeSize(*(robj**)s2)) return -1;
780     return 0;
781 }
782 
783 /* This is used by SDIFF and in this case we can receive NULL that should
784  * be handled as empty sets. */
qsortCompareSetsByRevCardinality(const void * s1,const void * s2)785 int qsortCompareSetsByRevCardinality(const void *s1, const void *s2) {
786     robj *o1 = *(robj**)s1, *o2 = *(robj**)s2;
787     unsigned long first = o1 ? setTypeSize(o1) : 0;
788     unsigned long second = o2 ? setTypeSize(o2) : 0;
789 
790     if (first < second) return 1;
791     if (first > second) return -1;
792     return 0;
793 }
794 
/* Generic implementation shared by SINTER and SINTERSTORE.
 *
 * c       - client issuing the command; every reply is written to it.
 * setkeys - array of 'setnum' key names whose sets are intersected.
 * setnum  - number of entries in 'setkeys'.
 * dstkey  - when NULL the members of the intersection are sent to the
 *           client; when non-NULL the intersection is stored at this key
 *           and its cardinality is sent instead.
 *
 * A key that does not exist counts as an empty set, so a single missing key
 * makes the whole result empty. If checkType() rejects one of the values
 * the function returns immediately without touching the keyspace.
 *
 * In store mode the key is signaled as modified and server.dirty is
 * incremented only when the keyspace really changed, that is, when a
 * non-empty result was stored or an existing destination was deleted. */
void sinterGenericCommand(client *c, robj **setkeys,
                          unsigned long setnum, robj *dstkey) {
    robj **sets = zmalloc(sizeof(robj*)*setnum);
    setTypeIterator *si;
    robj *dstset = NULL;
    sds elesds;
    int64_t intobj;
    void *replylen = NULL;
    unsigned long j, cardinality = 0;
    int encoding, empty = 0;

    /* Resolve every key, using the write lookup when a destination key
     * was given and the read lookup otherwise. */
    for (j = 0; j < setnum; j++) {
        robj *setobj = dstkey ?
            lookupKeyWrite(c->db,setkeys[j]) :
            lookupKeyRead(c->db,setkeys[j]);
        if (!setobj) {
            /* A NULL is considered an empty set */
            empty += 1;
            sets[j] = NULL;
            continue;
        }
        if (checkType(c,setobj,OBJ_SET)) {
            zfree(sets);
            return;
        }
        sets[j] = setobj;
    }

    /* Set intersection with an empty set always results in an empty set.
     * Return ASAP if there is an empty set. */
    if (empty > 0) {
        zfree(sets);
        if (dstkey) {
            /* The result is empty: the destination, if any, goes away. */
            if (dbDelete(c->db,dstkey)) {
                signalModifiedKey(c,c->db,dstkey);
                notifyKeyspaceEvent(NOTIFY_GENERIC,"del",dstkey,c->db->id);
                server.dirty++;
            }
            addReply(c,shared.czero);
        } else {
            addReply(c,shared.emptyset[c->resp]);
        }
        return;
    }

    /* Sort sets from the smallest to largest, this will improve our
     * algorithm's performance */
    qsort(sets,setnum,sizeof(robj*),qsortCompareSetsByCardinality);

    /* The first thing we should output is the total number of elements...
     * since this is a multi-bulk write, but at this stage we don't know
     * the intersection set size, so we use a trick, append an empty object
     * to the output list and save the pointer to later modify it with the
     * right length */
    if (!dstkey) {
        replylen = addReplyDeferredLen(c);
    } else {
        /* If we have a target key where to store the resulting set
         * create this key with an empty set inside */
        dstset = createIntsetObject();
    }

    /* Iterate all the elements of the first (smallest) set, and test
     * the element against all the other sets, if at least one set does
     * not include the element it is discarded */
    si = setTypeInitIterator(sets[0]);
    while((encoding = setTypeNext(si,&elesds,&intobj)) != -1) {
        for (j = 1; j < setnum; j++) {
            /* The same object listed twice trivially contains the element. */
            if (sets[j] == sets[0]) continue;
            if (encoding == OBJ_ENCODING_INTSET) {
                /* intset with intset is simple... and fast */
                if (sets[j]->encoding == OBJ_ENCODING_INTSET &&
                    !intsetFind((intset*)sets[j]->ptr,intobj))
                {
                    break;
                /* in order to compare an integer with an object we
                 * have to use the generic function, creating an object
                 * for this */
                } else if (sets[j]->encoding == OBJ_ENCODING_HT) {
                    elesds = sdsfromlonglong(intobj);
                    if (!setTypeIsMember(sets[j],elesds)) {
                        sdsfree(elesds);
                        break;
                    }
                    sdsfree(elesds);
                }
            } else if (encoding == OBJ_ENCODING_HT) {
                if (!setTypeIsMember(sets[j],elesds)) {
                    break;
                }
            }
        }

        /* Only take action when all sets contain the member: the inner
         * loop ran to completion without hitting a break. */
        if (j == setnum) {
            if (!dstkey) {
                if (encoding == OBJ_ENCODING_HT)
                    addReplyBulkCBuffer(c,elesds,sdslen(elesds));
                else
                    addReplyBulkLongLong(c,intobj);
                cardinality++;
            } else {
                if (encoding == OBJ_ENCODING_INTSET) {
                    /* setTypeAdd() takes an sds: build a temporary one
                     * from the integer and release it afterwards. */
                    elesds = sdsfromlonglong(intobj);
                    setTypeAdd(dstset,elesds);
                    sdsfree(elesds);
                } else {
                    setTypeAdd(dstset,elesds);
                }
            }
        }
    }
    setTypeReleaseIterator(si);

    if (dstkey) {
        /* Store the resulting set into the target, if the intersection
         * is not an empty set. */
        int deleted = dbDelete(c->db,dstkey);
        int modified = 0;

        if (setTypeSize(dstset) > 0) {
            dbAdd(c->db,dstkey,dstset);
            addReplyLongLong(c,setTypeSize(dstset));
            notifyKeyspaceEvent(NOTIFY_SET,"sinterstore",
                dstkey,c->db->id);
            modified = 1;
        } else {
            decrRefCount(dstset);
            addReply(c,shared.czero);
            if (deleted) {
                notifyKeyspaceEvent(NOTIFY_GENERIC,"del",
                    dstkey,c->db->id);
                modified = 1;
            }
        }

        /* An empty result over a destination that did not exist leaves the
         * keyspace untouched: in that case do not signal the key nor mark
         * the dataset dirty, like the early-return path above. */
        if (modified) {
            signalModifiedKey(c,c->db,dstkey);
            server.dirty++;
        }
    } else {
        /* Fill in the length that was deferred before the iteration. */
        setDeferredSetLen(c,replylen,cardinality);
    }
    zfree(sets);
}
932 
933 /* SINTER key [key ...] */
sinterCommand(client * c)934 void sinterCommand(client *c) {
935     sinterGenericCommand(c,c->argv+1,c->argc-1,NULL);
936 }
937 
938 /* SINTERSTORE destination key [key ...] */
sinterstoreCommand(client * c)939 void sinterstoreCommand(client *c) {
940     sinterGenericCommand(c,c->argv+2,c->argc-2,c->argv[1]);
941 }
942 
/* Operation selectors. SET_OP_UNION and SET_OP_DIFF are the values passed
 * as the 'op' argument of sunionDiffGenericCommand() by the SUNION* and
 * SDIFF* commands in this file. */
#define SET_OP_UNION 0
#define SET_OP_DIFF 1
#define SET_OP_INTER 2
946 
/* Generic implementation shared by SUNION, SUNIONSTORE, SDIFF and
 * SDIFFSTORE.
 *
 * c       - client issuing the command; every reply is written to it.
 * setkeys - array of 'setnum' key names to operate on.
 * setnum  - number of entries in 'setkeys'.
 * dstkey  - when NULL the members of the result are sent to the client;
 *           when non-NULL the result is stored at this key and its
 *           cardinality is sent instead.
 * op      - SET_OP_UNION or SET_OP_DIFF. With any other value none of the
 *           branches below runs and the result set stays empty.
 *
 * A key that does not exist counts as an empty set. If checkType() rejects
 * one of the values the function returns immediately.
 *
 * In store mode the key is signaled as modified and server.dirty is
 * incremented only when the keyspace really changed, that is, when a
 * non-empty result was stored or an existing destination was deleted. */
void sunionDiffGenericCommand(client *c, robj **setkeys, int setnum,
                              robj *dstkey, int op) {
    robj **sets = zmalloc(sizeof(robj*)*setnum);
    setTypeIterator *si;
    robj *dstset = NULL;
    sds ele;
    int j, cardinality = 0;
    int diff_algo = 1;

    /* Resolve every key, using the write lookup when a destination key
     * was given and the read lookup otherwise. */
    for (j = 0; j < setnum; j++) {
        robj *setobj = dstkey ?
            lookupKeyWrite(c->db,setkeys[j]) :
            lookupKeyRead(c->db,setkeys[j]);
        if (!setobj) {
            sets[j] = NULL;
            continue;
        }
        if (checkType(c,setobj,OBJ_SET)) {
            zfree(sets);
            return;
        }
        sets[j] = setobj;
    }

    /* Select what DIFF algorithm to use.
     *
     * Algorithm 1 is O(N*M) where N is the size of the first set
     * and M the total number of sets.
     *
     * Algorithm 2 is O(N) where N is the total number of elements in all
     * the sets.
     *
     * We compute what is the best bet with the current input here. */
    if (op == SET_OP_DIFF && sets[0]) {
        long long algo_one_work = 0, algo_two_work = 0;

        for (j = 0; j < setnum; j++) {
            if (sets[j] == NULL) continue;

            algo_one_work += setTypeSize(sets[0]);
            algo_two_work += setTypeSize(sets[j]);
        }

        /* Algorithm 1 has better constant times and performs less operations
         * if there are elements in common. Give it some advantage. */
        algo_one_work /= 2;
        diff_algo = (algo_one_work <= algo_two_work) ? 1 : 2;

        if (diff_algo == 1 && setnum > 1) {
            /* With algorithm 1 it is better to order the sets to subtract
             * by decreasing size, so that we are more likely to find
             * duplicated elements ASAP. The first set stays in place. */
            qsort(sets+1,setnum-1,sizeof(robj*),
                qsortCompareSetsByRevCardinality);
        }
    }

    /* We need a temp set object to store our union. If the dstkey
     * is not NULL (that is, we are inside an SUNIONSTORE operation) then
     * this set object will be the resulting object to set into the target key*/
    dstset = createIntsetObject();

    if (op == SET_OP_UNION) {
        /* Union is trivial, just add every element of every set to the
         * temporary set. */
        for (j = 0; j < setnum; j++) {
            if (!sets[j]) continue; /* non existing keys are like empty sets */

            si = setTypeInitIterator(sets[j]);
            while((ele = setTypeNextObject(si)) != NULL) {
                if (setTypeAdd(dstset,ele)) cardinality++;
                sdsfree(ele);
            }
            setTypeReleaseIterator(si);
        }
    } else if (op == SET_OP_DIFF && sets[0] && diff_algo == 1) {
        /* DIFF Algorithm 1:
         *
         * We perform the diff by iterating all the elements of the first set,
         * and only adding it to the target set if the element does not exist
         * into all the other sets.
         *
         * This way we perform at max N*M operations, where N is the size of
         * the first set, and M the number of sets. */
        si = setTypeInitIterator(sets[0]);
        while((ele = setTypeNextObject(si)) != NULL) {
            for (j = 1; j < setnum; j++) {
                if (!sets[j]) continue; /* no key is an empty set. */
                if (sets[j] == sets[0]) break; /* same set! */
                if (setTypeIsMember(sets[j],ele)) break;
            }
            if (j == setnum) {
                /* There is no other set with this element. Add it. */
                setTypeAdd(dstset,ele);
                cardinality++;
            }
            sdsfree(ele);
        }
        setTypeReleaseIterator(si);
    } else if (op == SET_OP_DIFF && sets[0] && diff_algo == 2) {
        /* DIFF Algorithm 2:
         *
         * Add all the elements of the first set to the auxiliary set.
         * Then remove all the elements of all the next sets from it.
         *
         * This is O(N) where N is the sum of all the elements in every
         * set. */
        for (j = 0; j < setnum; j++) {
            if (!sets[j]) continue; /* non existing keys are like empty sets */

            si = setTypeInitIterator(sets[j]);
            while((ele = setTypeNextObject(si)) != NULL) {
                if (j == 0) {
                    if (setTypeAdd(dstset,ele)) cardinality++;
                } else {
                    if (setTypeRemove(dstset,ele)) cardinality--;
                }
                sdsfree(ele);
            }
            setTypeReleaseIterator(si);

            /* Exit if result set is empty as any additional removal
             * of elements will have no effect. */
            if (cardinality == 0) break;
        }
    }

    /* Output the content of the resulting set, if not in STORE mode */
    if (!dstkey) {
        addReplySetLen(c,cardinality);
        si = setTypeInitIterator(dstset);
        while((ele = setTypeNextObject(si)) != NULL) {
            addReplyBulkCBuffer(c,ele,sdslen(ele));
            sdsfree(ele);
        }
        setTypeReleaseIterator(si);
        /* The temporary set is no longer needed: release it through
         * freeObjAsync() when lazyfree_lazy_server_del is set, otherwise
         * drop the reference directly. */
        if (server.lazyfree_lazy_server_del) {
            freeObjAsync(dstset);
        } else {
            decrRefCount(dstset);
        }
    } else {
        /* If we have a target key where to store the resulting set
         * create this key with the result set inside */
        int deleted = dbDelete(c->db,dstkey);
        int modified = 0;

        if (setTypeSize(dstset) > 0) {
            dbAdd(c->db,dstkey,dstset);
            addReplyLongLong(c,setTypeSize(dstset));
            notifyKeyspaceEvent(NOTIFY_SET,
                op == SET_OP_UNION ? "sunionstore" : "sdiffstore",
                dstkey,c->db->id);
            modified = 1;
        } else {
            decrRefCount(dstset);
            addReply(c,shared.czero);
            if (deleted) {
                notifyKeyspaceEvent(NOTIFY_GENERIC,"del",
                    dstkey,c->db->id);
                modified = 1;
            }
        }

        /* An empty result over a destination that did not exist leaves the
         * keyspace untouched: in that case do not signal the key nor mark
         * the dataset dirty. */
        if (modified) {
            signalModifiedKey(c,c->db,dstkey);
            server.dirty++;
        }
    }
    zfree(sets);
}
1107 
1108 /* SUNION key [key ...] */
sunionCommand(client * c)1109 void sunionCommand(client *c) {
1110     sunionDiffGenericCommand(c,c->argv+1,c->argc-1,NULL,SET_OP_UNION);
1111 }
1112 
1113 /* SUNIONSTORE destination key [key ...] */
sunionstoreCommand(client * c)1114 void sunionstoreCommand(client *c) {
1115     sunionDiffGenericCommand(c,c->argv+2,c->argc-2,c->argv[1],SET_OP_UNION);
1116 }
1117 
1118 /* SDIFF key [key ...] */
sdiffCommand(client * c)1119 void sdiffCommand(client *c) {
1120     sunionDiffGenericCommand(c,c->argv+1,c->argc-1,NULL,SET_OP_DIFF);
1121 }
1122 
1123 /* SDIFFSTORE destination key [key ...] */
sdiffstoreCommand(client * c)1124 void sdiffstoreCommand(client *c) {
1125     sunionDiffGenericCommand(c,c->argv+2,c->argc-2,c->argv[1],SET_OP_DIFF);
1126 }
1127 
/* Scan the set stored at argv[1], starting from the cursor given in argv[2].
 *
 * The cursor is parsed before the key is looked up, so a cursor that fails
 * to parse ends the command without any lookup. The actual iteration is
 * delegated to scanGenericCommand(). */
void sscanCommand(client *c) {
    unsigned long cursor;
    robj *setobj;

    if (parseScanCursorOrReply(c,c->argv[2],&cursor) == C_ERR) return;

    /* shared.emptyscan is handed to the lookup helper together with the
     * key; nothing more is done here when the lookup yields NULL or when
     * checkType() rejects the value. */
    setobj = lookupKeyReadOrReply(c,c->argv[1],shared.emptyscan);
    if (setobj == NULL) return;
    if (checkType(c,setobj,OBJ_SET)) return;

    scanGenericCommand(c,setobj,cursor);
}
1137