package naffka

import (
	"fmt"
	"log"
	"sync"
	"time"

	sarama "github.com/Shopify/sarama"
	"github.com/matrix-org/naffka/storage"
	"github.com/matrix-org/naffka/types"
)

// Naffka is an implementation of the sarama kafka API designed to run within a
// single go process. It implements both the sarama.SyncProducer and the
// sarama.Consumer interfaces. This means it can act as a drop-in replacement
// for kafka for testing or single instance deployment.
// Naffka does not support multiple partitions.
type Naffka struct {
	db          storage.Database
	topicsMutex sync.Mutex
	topics      map[string]*topic
}

// New creates a new Naffka instance.
func New(db storage.Database) (*Naffka, error) {
	n := &Naffka{db: db, topics: map[string]*topic{}}
	maxOffsets, err := db.MaxOffsets()
	if err != nil {
		return nil, err
	}
	for topicName, offset := range maxOffsets {
		n.topics[topicName] = &topic{
			db:         db,
			topicName:  topicName,
			nextOffset: offset + 1,
		}
	}
	return n, nil
}
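
// A minimal usage sketch (not compiled here), assuming some storage.Database
// implementation is available; the NewMemoryDatabase constructor below is
// illustrative rather than part of this package:
//
//	db, _ := storage.NewMemoryDatabase()
//	producer, _ := naffka.New(db)
//	partition, offset, err := producer.SendMessage(&sarama.ProducerMessage{
//		Topic: "mytopic",
//		Value: sarama.StringEncoder("hello world"),
//	})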

// SendMessage implements sarama.SyncProducer
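// The offset returned is the offset that naffka assigned to the message.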
func (n *Naffka) SendMessage(msg *sarama.ProducerMessage) (partition int32, offset int64, err error) {
	err = n.SendMessages([]*sarama.ProducerMessage{msg})
	return msg.Partition, msg.Offset, err
}

// SendMessages implements sarama.SyncProducer
func (n *Naffka) SendMessages(msgs []*sarama.ProducerMessage) error {
	byTopic := map[string][]*sarama.ProducerMessage{}
	for _, msg := range msgs {
		byTopic[msg.Topic] = append(byTopic[msg.Topic], msg)
	}
	var topicNames []string
	for topicName := range byTopic {
		topicNames = append(topicNames, topicName)
	}

	now := time.Now()
	topics := n.getTopics(topicNames)
	for topicName := range byTopic {
		if err := topics[topicName].send(now, byTopic[topicName]); err != nil {
			return err
		}
	}
	return nil
}

func (n *Naffka) getTopics(topicNames []string) map[string]*topic {
	n.topicsMutex.Lock()
	defer n.topicsMutex.Unlock()
	result := map[string]*topic{}
	for _, topicName := range topicNames {
		t := n.topics[topicName]
		if t == nil {
			// If the topic doesn't already exist then create it.
			t = &topic{db: n.db, topicName: topicName}
			n.topics[topicName] = t
		}
		result[topicName] = t
	}
	return result
}

// Topics implements sarama.Consumer
func (n *Naffka) Topics() ([]string, error) {
	n.topicsMutex.Lock()
	defer n.topicsMutex.Unlock()
	var result []string
	for topic := range n.topics {
		result = append(result, topic)
	}
	return result, nil
}

// Partitions implements sarama.Consumer
func (n *Naffka) Partitions(topic string) ([]int32, error) {
	// Naffka stores a single partition per topic, so this always returns a single partition ID.
	return []int32{0}, nil
}

// ConsumePartition implements sarama.Consumer
// Note: offset is *inclusive*, i.e. it will include the message with that offset.
func (n *Naffka) ConsumePartition(topic string, partition int32, offset int64) (sarama.PartitionConsumer, error) {
	if partition != 0 {
		return nil, fmt.Errorf("unknown partition ID %d", partition)
	}
	topics := n.getTopics([]string{topic})
	return topics[topic].consume(offset), nil
}
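
// A minimal consumption sketch (not compiled here), assuming n is a *Naffka
// created with New and "mytopic" is illustrative. The offset argument is
// inclusive, so starting from sarama.OffsetOldest replays every stored
// message before streaming new ones:
//
//	consumer, _ := n.ConsumePartition("mytopic", 0, sarama.OffsetOldest)
//	for msg := range consumer.Messages() {
//		fmt.Printf("offset %d: %s\n", msg.Offset, string(msg.Value))
//	}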

// HighWaterMarks implements sarama.Consumer
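// The returned high-water mark for each topic is the offset that will be
// assigned to the next message, i.e. one greater than the offset of the most
// recently stored message.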
func (n *Naffka) HighWaterMarks() map[string]map[int32]int64 {
	n.topicsMutex.Lock()
	defer n.topicsMutex.Unlock()
	result := map[string]map[int32]int64{}
	for topicName, topic := range n.topics {
		result[topicName] = map[int32]int64{
			0: topic.highwaterMark(),
		}
	}
	return result
}

// Close implements sarama.SyncProducer and sarama.Consumer
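// It is currently a no-op; Naffka does not close the underlying storage.Database.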
func (n *Naffka) Close() error {
	return nil
}

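// channelSize is the capacity of the buffered channel used to deliver messages
// to each partition consumer. If a consumer falls more than this many messages
// behind, it switches to "catchup" mode and reads from the database instead.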
const channelSize = 1024

// partitionConsumer ensures that all messages written to a particular topic,
// starting from a given offset, are sent in order to a channel.
// Implements sarama.PartitionConsumer
type partitionConsumer struct {
	topic    *topic
	messages chan *sarama.ConsumerMessage
	// Whether the consumer is in "catchup" mode or not.
	// See the "catchup" function for details.
	// Reads and writes to this field are protected by the topic mutex.
	catchingUp bool
}

// AsyncClose implements sarama.PartitionConsumer
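// It is currently a no-op.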
func (c *partitionConsumer) AsyncClose() {
}

// Close implements sarama.PartitionConsumer
func (c *partitionConsumer) Close() error {
	// TODO: Add support for performing a clean shutdown of the consumer.
	return nil
}

// Messages implements sarama.PartitionConsumer
func (c *partitionConsumer) Messages() <-chan *sarama.ConsumerMessage {
	return c.messages
}

// Errors implements sarama.PartitionConsumer
func (c *partitionConsumer) Errors() <-chan *sarama.ConsumerError {
	// TODO: Add option to pass consumer errors to an errors channel.
	return nil
}

// HighWaterMarkOffset implements sarama.PartitionConsumer
func (c *partitionConsumer) HighWaterMarkOffset() int64 {
	return c.topic.highwaterMark()
}

// catchup makes the consumer go into "catchup" mode, where messages are read
// from the database instead of directly from producers.
// Once the consumer is up to date, i.e. there are no new messages in the
// database, the consumer goes back into normal mode where new messages are
// written directly to the channel.
// Must be called while holding c.topic.mutex.
func (c *partitionConsumer) catchup(fromOffset int64) {
	// If we're already in catchup mode or up to date, this is a no-op.
	if c.catchingUp || fromOffset == c.topic.nextOffset {
		return
	}

	c.catchingUp = true

	// Due to the checks above there can only be one of these goroutines
	// running at a time.
	go func() {
		for {
			// Check if we're up to date yet. If we are, exit catchup mode.
			c.topic.mutex.Lock()
			nextOffset := c.topic.nextOffset
			if fromOffset == nextOffset {
				c.catchingUp = false
				c.topic.mutex.Unlock()
				return
			}
			c.topic.mutex.Unlock()

			// Limit the number of messages we request from the database to the
			// capacity of the channel.
			if nextOffset > fromOffset+int64(cap(c.messages)) {
				nextOffset = fromOffset + int64(cap(c.messages))
			}
			// Fetch the messages from the database.
			msgs, err := c.topic.db.FetchMessages(c.topic.topicName, fromOffset, nextOffset)
			if err != nil {
				// TODO: Add option to write consumer errors to an errors channel
				// as an alternative to logging the errors.
				log.Print("Error reading messages: ", err)
				// Wait before retrying.
				// TODO: Maybe use an exponential backoff scheme here.
				// TODO: This timeout should take account of all the other goroutines
				// that might be doing the same thing. (If there are 10000 consumers
				// then we don't want to end up retrying every millisecond.)
				time.Sleep(10 * time.Second)
				continue
			}
			if len(msgs) == 0 {
				// This should only happen if the database is corrupted and has lost the
				// messages between the requested offsets.
				log.Fatalf("Corrupt database returned no messages between %d and %d", fromOffset, nextOffset)
			}

			// Pass the messages into the consumer channel.
			// Each write blocks until the channel has enough space for the message.
			for i := range msgs {
				c.messages <- msgs[i].ConsumerMessage(c.topic.topicName)
			}
			// Update the offset for the next loop iteration.
			fromOffset = msgs[len(msgs)-1].Offset + 1
		}
	}()
}

// notifyNewMessage tells the consumer about a new message.
// Must be called while holding c.topic.mutex.
func (c *partitionConsumer) notifyNewMessage(cmsg *sarama.ConsumerMessage) {
	// If we're in "catchup" mode then the catchup routine will send the
	// message later, since cmsg has already been written to the database.
	if c.catchingUp {
		return
	}

	// Otherwise, try writing the message directly to the channel.
	select {
	case c.messages <- cmsg:
	default:
		// The messages channel has filled up, so go into catchup mode.
		// Once the channel starts being read from again, messages will be
		// read from the database.
		c.catchup(cmsg.Offset)
	}
}

type topic struct {
	db        storage.Database
	topicName string
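	// mutex protects nextOffset, consumers, and each consumer's catchingUp flag.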
	mutex     sync.Mutex
	consumers []*partitionConsumer
	// nextOffset is the offset that will be assigned to the next message in
	// this topic, i.e. one greater than the last message offset.
	nextOffset int64
}

// send writes messages to a topic.
func (t *topic) send(now time.Time, pmsgs []*sarama.ProducerMessage) error {
	var err error
	// Encode the message keys and values.
	msgs := make([]types.Message, len(pmsgs))
	for i := range msgs {
		if pmsgs[i].Key != nil {
			msgs[i].Key, err = pmsgs[i].Key.Encode()
			if err != nil {
				return err
			}
		}
		if pmsgs[i].Value != nil {
			msgs[i].Value, err = pmsgs[i].Value.Encode()
			if err != nil {
				return err
			}
		}
		pmsgs[i].Timestamp = now
		msgs[i].Timestamp = now

		msgs[i].Headers = pmsgs[i].Headers
	}
	// Take the lock before assigning the offsets.
	t.mutex.Lock()
	defer t.mutex.Unlock()
	offset := t.nextOffset
	for i := range msgs {
		pmsgs[i].Offset = offset
		msgs[i].Offset = offset
		offset++
	}
	// Store the messages while we hold the lock.
	err = t.db.StoreMessages(t.topicName, msgs)
	if err != nil {
		return err
	}
	t.nextOffset = offset

	// Now notify the consumers about the messages.
	for _, msg := range msgs {
		cmsg := msg.ConsumerMessage(t.topicName)
		for _, c := range t.consumers {
			c.notifyNewMessage(cmsg)
		}
	}

	return nil
}

func (t *topic) consume(offset int64) *partitionConsumer {
	t.mutex.Lock()
	defer t.mutex.Unlock()
	c := &partitionConsumer{
		topic: t,
	}
	// Handle special offsets.
	if offset == sarama.OffsetNewest {
		offset = t.nextOffset
	}
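	// Offsets in naffka start at 0, so sarama.OffsetOldest maps to offset 0.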
	if offset == sarama.OffsetOldest {
		offset = 0
	}
	c.messages = make(chan *sarama.ConsumerMessage, channelSize)
	t.consumers = append(t.consumers, c)

	// If we're not streaming from the latest offset, we need to go into
	// "catchup" mode.
	if offset != t.nextOffset {
		c.catchup(offset)
	}
	return c
}

func (t *topic) highwaterMark() int64 {
	t.mutex.Lock()
	defer t.mutex.Unlock()
	return t.nextOffset
}