1package pubsub
2
3import (
4	"fmt"
5
6	pb "github.com/libp2p/go-libp2p-pubsub/pb"
7
8	"github.com/libp2p/go-libp2p-core/peer"
9)
10
11// NewMessageCache creates a sliding window cache that remembers messages for as
12// long as `history` slots.
13//
14// When queried for messages to advertise, the cache only returns messages in
15// the last `gossip` slots.
16//
17// The `gossip` parameter must be smaller or equal to `history`, or this
18// function will panic.
19//
20// The slack between `gossip` and `history` accounts for the reaction time
21// between when a message is advertised via IHAVE gossip, and the peer pulls it
22// via an IWANT command.
23func NewMessageCache(gossip, history int) *MessageCache {
24	if gossip > history {
25		err := fmt.Errorf("invalid parameters for message cache; gossip slots (%d) cannot be larger than history slots (%d)",
26			gossip, history)
27		panic(err)
28	}
29	return &MessageCache{
30		msgs:    make(map[string]*pb.Message),
31		peertx:  make(map[string]map[peer.ID]int),
32		history: make([][]CacheEntry, history),
33		gossip:  gossip,
34		msgID:   DefaultMsgIdFn,
35	}
36}
37
38type MessageCache struct {
39	msgs    map[string]*pb.Message
40	peertx  map[string]map[peer.ID]int
41	history [][]CacheEntry
42	gossip  int
43	msgID   MsgIdFunction
44}
45
46func (mc *MessageCache) SetMsgIdFn(msgID MsgIdFunction) {
47	mc.msgID = msgID
48}
49
// CacheEntry is one record in a MessageCache history slot: the ID under which
// a message was cached and the topic it was published to.
type CacheEntry struct {
	mid, topic string
}
54
55func (mc *MessageCache) Put(msg *pb.Message) {
56	mid := mc.msgID(msg)
57	mc.msgs[mid] = msg
58	mc.history[0] = append(mc.history[0], CacheEntry{mid: mid, topic: msg.GetTopic()})
59}
60
61func (mc *MessageCache) Get(mid string) (*pb.Message, bool) {
62	m, ok := mc.msgs[mid]
63	return m, ok
64}
65
66func (mc *MessageCache) GetForPeer(mid string, p peer.ID) (*pb.Message, int, bool) {
67	m, ok := mc.msgs[mid]
68	if !ok {
69		return nil, 0, false
70	}
71
72	tx, ok := mc.peertx[mid]
73	if !ok {
74		tx = make(map[peer.ID]int)
75		mc.peertx[mid] = tx
76	}
77	tx[p]++
78
79	return m, tx[p], true
80}
81
82func (mc *MessageCache) GetGossipIDs(topic string) []string {
83	var mids []string
84	for _, entries := range mc.history[:mc.gossip] {
85		for _, entry := range entries {
86			if entry.topic == topic {
87				mids = append(mids, entry.mid)
88			}
89		}
90	}
91	return mids
92}
93
94func (mc *MessageCache) Shift() {
95	last := mc.history[len(mc.history)-1]
96	for _, entry := range last {
97		delete(mc.msgs, entry.mid)
98		delete(mc.peertx, entry.mid)
99	}
100	for i := len(mc.history) - 2; i >= 0; i-- {
101		mc.history[i+1] = mc.history[i]
102	}
103	mc.history[0] = nil
104}
105