1package miniredis
2
3import (
4	"errors"
5	"fmt"
6	"math/big"
7	"sort"
8	"strconv"
9	"time"
10)
11
// errInvalidEntryID reports a stream entry ID that is not valid.
var errInvalidEntryID = errors.New("stream ID is invalid")
15
16func (db *RedisDB) exists(k string) bool {
17	_, ok := db.keys[k]
18	return ok
19}
20
21// t gives the type of a key, or ""
22func (db *RedisDB) t(k string) string {
23	return db.keys[k]
24}
25
26// allKeys returns all keys. Sorted.
27func (db *RedisDB) allKeys() []string {
28	res := make([]string, 0, len(db.keys))
29	for k := range db.keys {
30		res = append(res, k)
31	}
32	sort.Strings(res) // To make things deterministic.
33	return res
34}
35
36// flush removes all keys and values.
37func (db *RedisDB) flush() {
38	db.keys = map[string]string{}
39	db.stringKeys = map[string]string{}
40	db.hashKeys = map[string]hashKey{}
41	db.listKeys = map[string]listKey{}
42	db.setKeys = map[string]setKey{}
43	db.sortedsetKeys = map[string]sortedSet{}
44	db.ttl = map[string]time.Duration{}
45}
46
47// move something to another db. Will return ok. Or not.
48func (db *RedisDB) move(key string, to *RedisDB) bool {
49	if _, ok := to.keys[key]; ok {
50		return false
51	}
52
53	t, ok := db.keys[key]
54	if !ok {
55		return false
56	}
57	to.keys[key] = db.keys[key]
58	switch t {
59	case "string":
60		to.stringKeys[key] = db.stringKeys[key]
61	case "hash":
62		to.hashKeys[key] = db.hashKeys[key]
63	case "list":
64		to.listKeys[key] = db.listKeys[key]
65	case "set":
66		to.setKeys[key] = db.setKeys[key]
67	case "zset":
68		to.sortedsetKeys[key] = db.sortedsetKeys[key]
69	case "stream":
70		to.streamKeys[key] = db.streamKeys[key]
71	default:
72		panic("unhandled key type")
73	}
74	to.keyVersion[key]++
75	if v, ok := db.ttl[key]; ok {
76		to.ttl[key] = v
77	}
78	db.del(key, true)
79	return true
80}
81
82func (db *RedisDB) rename(from, to string) {
83	db.del(to, true)
84	switch db.t(from) {
85	case "string":
86		db.stringKeys[to] = db.stringKeys[from]
87	case "hash":
88		db.hashKeys[to] = db.hashKeys[from]
89	case "list":
90		db.listKeys[to] = db.listKeys[from]
91	case "set":
92		db.setKeys[to] = db.setKeys[from]
93	case "zset":
94		db.sortedsetKeys[to] = db.sortedsetKeys[from]
95	case "stream":
96		db.streamKeys[to] = db.streamKeys[from]
97	default:
98		panic("missing case")
99	}
100	db.keys[to] = db.keys[from]
101	db.keyVersion[to]++
102	if v, ok := db.ttl[from]; ok {
103		db.ttl[to] = v
104	}
105
106	db.del(from, true)
107}
108
109func (db *RedisDB) del(k string, delTTL bool) {
110	if !db.exists(k) {
111		return
112	}
113	t := db.t(k)
114	delete(db.keys, k)
115	db.keyVersion[k]++
116	if delTTL {
117		delete(db.ttl, k)
118	}
119	switch t {
120	case "string":
121		delete(db.stringKeys, k)
122	case "hash":
123		delete(db.hashKeys, k)
124	case "list":
125		delete(db.listKeys, k)
126	case "set":
127		delete(db.setKeys, k)
128	case "zset":
129		delete(db.sortedsetKeys, k)
130	case "stream":
131		delete(db.streamKeys, k)
132	default:
133		panic("Unknown key type: " + t)
134	}
135}
136
137// stringGet returns the string key or "" on error/nonexists.
138func (db *RedisDB) stringGet(k string) string {
139	if t, ok := db.keys[k]; !ok || t != "string" {
140		return ""
141	}
142	return db.stringKeys[k]
143}
144
145// stringSet force set()s a key. Does not touch expire.
146func (db *RedisDB) stringSet(k, v string) {
147	db.del(k, false)
148	db.keys[k] = "string"
149	db.stringKeys[k] = v
150	db.keyVersion[k]++
151}
152
153// change int key value
154func (db *RedisDB) stringIncr(k string, delta int) (int, error) {
155	v := 0
156	if sv, ok := db.stringKeys[k]; ok {
157		var err error
158		v, err = strconv.Atoi(sv)
159		if err != nil {
160			return 0, ErrIntValueError
161		}
162	}
163	v += delta
164	db.stringSet(k, strconv.Itoa(v))
165	return v, nil
166}
167
168// change float key value
169func (db *RedisDB) stringIncrfloat(k string, delta *big.Float) (*big.Float, error) {
170	v := big.NewFloat(0.0)
171	v.SetPrec(128)
172	if sv, ok := db.stringKeys[k]; ok {
173		var err error
174		v, _, err = big.ParseFloat(sv, 10, 128, 0)
175		if err != nil {
176			return nil, ErrFloatValueError
177		}
178	}
179	v.Add(v, delta)
180	db.stringSet(k, formatBig(v))
181	return v, nil
182}
183
184// listLpush is 'left push', aka unshift. Returns the new length.
185func (db *RedisDB) listLpush(k, v string) int {
186	l, ok := db.listKeys[k]
187	if !ok {
188		db.keys[k] = "list"
189	}
190	l = append([]string{v}, l...)
191	db.listKeys[k] = l
192	db.keyVersion[k]++
193	return len(l)
194}
195
196// 'left pop', aka shift.
197func (db *RedisDB) listLpop(k string) string {
198	l := db.listKeys[k]
199	el := l[0]
200	l = l[1:]
201	if len(l) == 0 {
202		db.del(k, true)
203	} else {
204		db.listKeys[k] = l
205	}
206	db.keyVersion[k]++
207	return el
208}
209
210func (db *RedisDB) listPush(k string, v ...string) int {
211	l, ok := db.listKeys[k]
212	if !ok {
213		db.keys[k] = "list"
214	}
215	l = append(l, v...)
216	db.listKeys[k] = l
217	db.keyVersion[k]++
218	return len(l)
219}
220
221func (db *RedisDB) listPop(k string) string {
222	l := db.listKeys[k]
223	el := l[len(l)-1]
224	l = l[:len(l)-1]
225	if len(l) == 0 {
226		db.del(k, true)
227	} else {
228		db.listKeys[k] = l
229		db.keyVersion[k]++
230	}
231	return el
232}
233
234// setset replaces a whole set.
235func (db *RedisDB) setSet(k string, set setKey) {
236	db.keys[k] = "set"
237	db.setKeys[k] = set
238	db.keyVersion[k]++
239}
240
241// setadd adds members to a set. Returns nr of new keys.
242func (db *RedisDB) setAdd(k string, elems ...string) int {
243	s, ok := db.setKeys[k]
244	if !ok {
245		s = setKey{}
246		db.keys[k] = "set"
247	}
248	added := 0
249	for _, e := range elems {
250		if _, ok := s[e]; !ok {
251			added++
252		}
253		s[e] = struct{}{}
254	}
255	db.setKeys[k] = s
256	db.keyVersion[k]++
257	return added
258}
259
260// setrem removes members from a set. Returns nr of deleted keys.
261func (db *RedisDB) setRem(k string, fields ...string) int {
262	s, ok := db.setKeys[k]
263	if !ok {
264		return 0
265	}
266	removed := 0
267	for _, f := range fields {
268		if _, ok := s[f]; ok {
269			removed++
270			delete(s, f)
271		}
272	}
273	if len(s) == 0 {
274		db.del(k, true)
275	} else {
276		db.setKeys[k] = s
277	}
278	db.keyVersion[k]++
279	return removed
280}
281
282// All members of a set.
283func (db *RedisDB) setMembers(k string) []string {
284	set := db.setKeys[k]
285	members := make([]string, 0, len(set))
286	for k := range set {
287		members = append(members, k)
288	}
289	sort.Strings(members)
290	return members
291}
292
293// Is a SET value present?
294func (db *RedisDB) setIsMember(k, v string) bool {
295	set, ok := db.setKeys[k]
296	if !ok {
297		return false
298	}
299	_, ok = set[v]
300	return ok
301}
302
303// hashFields returns all (sorted) keys ('fields') for a hash key.
304func (db *RedisDB) hashFields(k string) []string {
305	v := db.hashKeys[k]
306	r := make([]string, 0, len(v))
307	for k := range v {
308		r = append(r, k)
309	}
310	sort.Strings(r)
311	return r
312}
313
314// hashGet a value
315func (db *RedisDB) hashGet(key, field string) string {
316	return db.hashKeys[key][field]
317}
318
319// hashSet returns the number of new keys
320func (db *RedisDB) hashSet(k string, fv ...string) int {
321	if t, ok := db.keys[k]; ok && t != "hash" {
322		db.del(k, true)
323	}
324	db.keys[k] = "hash"
325	if _, ok := db.hashKeys[k]; !ok {
326		db.hashKeys[k] = map[string]string{}
327	}
328	new := 0
329	for idx := 0; idx < len(fv)-1; idx = idx + 2 {
330		f, v := fv[idx], fv[idx+1]
331		_, ok := db.hashKeys[k][f]
332		db.hashKeys[k][f] = v
333		db.keyVersion[k]++
334		if !ok {
335			new++
336		}
337	}
338	return new
339}
340
341// hashIncr changes int key value
342func (db *RedisDB) hashIncr(key, field string, delta int) (int, error) {
343	v := 0
344	if h, ok := db.hashKeys[key]; ok {
345		if f, ok := h[field]; ok {
346			var err error
347			v, err = strconv.Atoi(f)
348			if err != nil {
349				return 0, ErrIntValueError
350			}
351		}
352	}
353	v += delta
354	db.hashSet(key, field, strconv.Itoa(v))
355	return v, nil
356}
357
358// hashIncrfloat changes float key value
359func (db *RedisDB) hashIncrfloat(key, field string, delta *big.Float) (*big.Float, error) {
360	v := big.NewFloat(0.0)
361	v.SetPrec(128)
362	if h, ok := db.hashKeys[key]; ok {
363		if f, ok := h[field]; ok {
364			var err error
365			v, _, err = big.ParseFloat(f, 10, 128, 0)
366			if err != nil {
367				return nil, ErrFloatValueError
368			}
369		}
370	}
371	v.Add(v, delta)
372	db.hashSet(key, field, formatBig(v))
373	return v, nil
374}
375
376// sortedSet set returns a sortedSet as map
377func (db *RedisDB) sortedSet(key string) map[string]float64 {
378	ss := db.sortedsetKeys[key]
379	return map[string]float64(ss)
380}
381
382// ssetSet sets a complete sorted set.
383func (db *RedisDB) ssetSet(key string, sset sortedSet) {
384	db.keys[key] = "zset"
385	db.keyVersion[key]++
386	db.sortedsetKeys[key] = sset
387}
388
389// ssetAdd adds member to a sorted set. Returns whether this was a new member.
390func (db *RedisDB) ssetAdd(key string, score float64, member string) bool {
391	ss, ok := db.sortedsetKeys[key]
392	if !ok {
393		ss = newSortedSet()
394		db.keys[key] = "zset"
395	}
396	_, ok = ss[member]
397	ss[member] = score
398	db.sortedsetKeys[key] = ss
399	db.keyVersion[key]++
400	return !ok
401}
402
403// All members from a sorted set, ordered by score.
404func (db *RedisDB) ssetMembers(key string) []string {
405	ss, ok := db.sortedsetKeys[key]
406	if !ok {
407		return nil
408	}
409	elems := ss.byScore(asc)
410	members := make([]string, 0, len(elems))
411	for _, e := range elems {
412		members = append(members, e.member)
413	}
414	return members
415}
416
417// All members+scores from a sorted set, ordered by score.
418func (db *RedisDB) ssetElements(key string) ssElems {
419	ss, ok := db.sortedsetKeys[key]
420	if !ok {
421		return nil
422	}
423	return ss.byScore(asc)
424}
425
426// ssetCard is the sorted set cardinality.
427func (db *RedisDB) ssetCard(key string) int {
428	ss := db.sortedsetKeys[key]
429	return ss.card()
430}
431
432// ssetRank is the sorted set rank.
433func (db *RedisDB) ssetRank(key, member string, d direction) (int, bool) {
434	ss := db.sortedsetKeys[key]
435	return ss.rankByScore(member, d)
436}
437
438// ssetScore is sorted set score.
439func (db *RedisDB) ssetScore(key, member string) float64 {
440	ss := db.sortedsetKeys[key]
441	return ss[member]
442}
443
444// ssetRem is sorted set key delete.
445func (db *RedisDB) ssetRem(key, member string) bool {
446	ss := db.sortedsetKeys[key]
447	_, ok := ss[member]
448	delete(ss, member)
449	if len(ss) == 0 {
450		// Delete key on removal of last member
451		db.del(key, true)
452	}
453	return ok
454}
455
456// ssetExists tells if a member exists in a sorted set.
457func (db *RedisDB) ssetExists(key, member string) bool {
458	ss := db.sortedsetKeys[key]
459	_, ok := ss[member]
460	return ok
461}
462
463// ssetIncrby changes float sorted set score.
464func (db *RedisDB) ssetIncrby(k, m string, delta float64) float64 {
465	ss, ok := db.sortedsetKeys[k]
466	if !ok {
467		ss = newSortedSet()
468		db.keys[k] = "zset"
469		db.sortedsetKeys[k] = ss
470	}
471
472	v, _ := ss.get(m)
473	v += delta
474	ss.set(v, m)
475	db.keyVersion[k]++
476	return v
477}
478
479// setDiff implements the logic behind SDIFF*
480func (db *RedisDB) setDiff(keys []string) (setKey, error) {
481	key := keys[0]
482	keys = keys[1:]
483	if db.exists(key) && db.t(key) != "set" {
484		return nil, ErrWrongType
485	}
486	s := setKey{}
487	for k := range db.setKeys[key] {
488		s[k] = struct{}{}
489	}
490	for _, sk := range keys {
491		if !db.exists(sk) {
492			continue
493		}
494		if db.t(sk) != "set" {
495			return nil, ErrWrongType
496		}
497		for e := range db.setKeys[sk] {
498			delete(s, e)
499		}
500	}
501	return s, nil
502}
503
504// setInter implements the logic behind SINTER*
505func (db *RedisDB) setInter(keys []string) (setKey, error) {
506	key := keys[0]
507	keys = keys[1:]
508	if !db.exists(key) {
509		return setKey{}, nil
510	}
511	if db.t(key) != "set" {
512		return nil, ErrWrongType
513	}
514	s := setKey{}
515	for k := range db.setKeys[key] {
516		s[k] = struct{}{}
517	}
518	for _, sk := range keys {
519		if !db.exists(sk) {
520			return setKey{}, nil
521		}
522		if db.t(sk) != "set" {
523			return nil, ErrWrongType
524		}
525		other := db.setKeys[sk]
526		for e := range s {
527			if _, ok := other[e]; ok {
528				continue
529			}
530			delete(s, e)
531		}
532	}
533	return s, nil
534}
535
536// setUnion implements the logic behind SUNION*
537func (db *RedisDB) setUnion(keys []string) (setKey, error) {
538	key := keys[0]
539	keys = keys[1:]
540	if db.exists(key) && db.t(key) != "set" {
541		return nil, ErrWrongType
542	}
543	s := setKey{}
544	for k := range db.setKeys[key] {
545		s[k] = struct{}{}
546	}
547	for _, sk := range keys {
548		if !db.exists(sk) {
549			continue
550		}
551		if db.t(sk) != "set" {
552			return nil, ErrWrongType
553		}
554		for e := range db.setKeys[sk] {
555			s[e] = struct{}{}
556		}
557	}
558	return s, nil
559}
560
561// stream set returns a stream as a slice. Lowest ID first.
562func (db *RedisDB) stream(key string) []StreamEntry {
563	return db.streamKeys[key]
564}
565
566func (db *RedisDB) streamCreate(key string) {
567	_, ok := db.streamKeys[key]
568	if !ok {
569		db.keys[key] = "stream"
570	}
571
572	db.streamKeys[key] = make(streamKey, 0)
573	db.keyVersion[key]++
574}
575
576// streamAdd adds an entry to a stream. Returns the new entry ID.
577// If id is empty or "*" the ID will be generated automatically.
578// `values` should have an even length.
579func (db *RedisDB) streamAdd(key, entryID string, values []string) (string, error) {
580	stream, ok := db.streamKeys[key]
581	if !ok {
582		db.keys[key] = "stream"
583	}
584
585	if entryID == "" || entryID == "*" {
586		entryID = stream.generateID(db.master.effectiveNow())
587	}
588	entryID, err := formatStreamID(entryID)
589	if err != nil {
590		return "", err
591	}
592	if entryID == "0-0" {
593		return "", errZeroStreamValue
594	}
595	if streamCmp(stream.lastID(), entryID) != -1 {
596		return "", errInvalidStreamValue
597	}
598	db.streamKeys[key] = append(stream, StreamEntry{
599		ID:     entryID,
600		Values: values,
601	})
602	db.keyVersion[key]++
603	return entryID, nil
604}
605
606func (db *RedisDB) streamMaxlen(key string, n int) {
607	stream, ok := db.streamKeys[key]
608	if !ok {
609		return
610	}
611	if len(stream) > n {
612		db.streamKeys[key] = stream[len(stream)-n:]
613	}
614}
615
616func (db *RedisDB) streamLen(key string) (int, error) {
617	stream, ok := db.streamKeys[key]
618	if !ok {
619		return 0, fmt.Errorf("stream %s not exists", key)
620	}
621	return len(stream), nil
622}
623
624func (db *RedisDB) streamGroupCreate(stream, group, id string) error {
625	streamData, ok := db.streamKeys[stream]
626	if !ok {
627		return fmt.Errorf("stream %s not exists", stream)
628	}
629
630	if _, ok := db.streamGroupKeys[stream]; !ok {
631		db.streamGroupKeys[stream] = streamGroupKey{}
632	}
633
634	if _, ok := db.streamGroupKeys[stream][group]; ok {
635		return errors.New("BUSYGROUP")
636	}
637
638	entry := streamGroupEntry{
639		pending: make([]pendingEntry, 0),
640	}
641
642	if id == "$" {
643		entry.lastID = streamData.lastID()
644	} else {
645		entry.lastID = id
646	}
647
648	db.streamGroupKeys[stream][group] = entry
649
650	return nil
651}
652
653func (db *RedisDB) streamRead(stream, group, consumer, id string, count int) ([]StreamEntry, error) {
654	streamData, ok := db.streamKeys[stream]
655	if !ok {
656		return nil, fmt.Errorf("stream %s not exists", stream)
657	}
658
659	if _, ok := db.streamGroupKeys[stream]; !ok {
660		// Error for group because this is key for group
661		return nil, fmt.Errorf("group %s not exists", group)
662	}
663
664	groupData, ok := db.streamGroupKeys[stream][group]
665	if !ok {
666		return nil, fmt.Errorf("group %s not exists", group)
667	}
668
669	res := make([]StreamEntry, 0)
670
671	if id == ">" {
672		next := sort.Search(len(streamData), func(i int) bool {
673			return streamCmp(groupData.lastID, streamData[i].ID) < 0
674		})
675
676		if len(streamData[next:]) == 0 {
677			return nil, nil
678		}
679
680		if count == 0 || count > len(streamData[next:]) {
681			count = len(streamData[next:])
682		}
683
684		res = append(res, streamData[next:count]...)
685
686		for _, en := range res {
687			pending := pendingEntry{
688				consumer: consumer,
689				ID:       en.ID,
690			}
691			groupData.pending = append(groupData.pending, pending)
692		}
693
694		groupData.lastID = res[len(res)-1].ID
695		db.streamGroupKeys[stream][group] = groupData
696	} else {
697		next := sort.Search(len(groupData.pending), func(i int) bool {
698			return streamCmp(id, groupData.pending[i].ID) < 0
699		})
700
701		if len(groupData.pending[next:]) == 0 {
702			return nil, nil
703		}
704
705		for _, e := range groupData.pending[next:] {
706			if e.consumer != consumer {
707				continue
708			}
709
710			pos := sort.Search(len(streamData), func(i int) bool {
711				return streamCmp(e.ID, streamData[i].ID) == 0
712			})
713
714			// Not found
715			if pos == len(streamData) {
716				continue
717			}
718
719			res = append(res, streamData[pos])
720
721			// Truncate to allow faster next search, because next element in pending
722			// is greater, then current, so on for stream
723			streamData = streamData[pos:]
724		}
725	}
726
727	return res, nil
728}
729
730func (db *RedisDB) streamDelete(stream string, ids []string) (int, error) {
731	streamData, ok := db.streamKeys[stream]
732	if !ok {
733		return 0, fmt.Errorf("stream %s not exists", stream)
734	}
735
736	count := 0
737
738	for _, id := range ids {
739		pos := sort.Search(len(streamData), func(i int) bool {
740			return streamCmp(id, streamData[i].ID) == 0
741		})
742
743		if pos == len(streamData) {
744			continue
745		}
746
747		streamData = append(streamData[:pos], streamData[pos+1:]...)
748		count++
749	}
750
751	if count > 0 {
752		db.streamKeys[stream] = streamData
753	}
754
755	return count, nil
756}
757
758func (db *RedisDB) streamAck(stream, group string, ids []string) (int, error) {
759	if _, ok := db.streamGroupKeys[stream]; !ok {
760		// Error for group because this is key for group
761		return 0, fmt.Errorf("group %s not exists", group)
762	}
763
764	groupData, ok := db.streamGroupKeys[stream][group]
765	if !ok {
766		return 0, fmt.Errorf("group %s not exists", group)
767	}
768
769	count := 0
770
771	for _, id := range ids {
772		pos := sort.Search(len(groupData.pending), func(i int) bool {
773			return streamCmp(id, groupData.pending[i].ID) == 0
774		})
775
776		if pos == len(groupData.pending) {
777			continue
778		}
779
780		groupData.pending = append(groupData.pending[:pos], groupData.pending[pos+1:]...)
781		count++
782	}
783
784	if count > 0 {
785		db.streamGroupKeys[stream][group] = groupData
786	}
787
788	return count, nil
789}
790
791// fastForward proceeds the current timestamp with duration, works as a time machine
792func (db *RedisDB) fastForward(duration time.Duration) {
793	for _, key := range db.allKeys() {
794		if value, ok := db.ttl[key]; ok {
795			db.ttl[key] = value - duration
796			db.checkTTL(key)
797		}
798	}
799}
800
801func (db *RedisDB) checkTTL(key string) {
802	if v, ok := db.ttl[key]; ok && v <= 0 {
803		db.del(key, true)
804	}
805}
806