1package nsq
2
3import (
4	"bytes"
5	"errors"
6	"fmt"
7	"log"
8	"math"
9	"math/rand"
10	"net"
11	"net/url"
12	"os"
13	"strconv"
14	"strings"
15	"sync"
16	"sync/atomic"
17	"time"
18)
19
// Handler is the message processing interface for Consumer.
//
// Implement this interface for handlers that return whether or not message
// processing completed successfully.
//
// When the return value is nil Consumer will automatically handle FINishing.
//
// When the returned value is non-nil Consumer will automatically handle REQueueing.
//
// Both automatic responses are skipped for messages that report
// IsAutoResponseDisabled (see handlerLoop).
type Handler interface {
	// HandleMessage processes a single message; the returned error decides
	// whether the message is finished (nil) or requeued (non-nil).
	HandleMessage(message *Message) error
}
31
// HandlerFunc is a convenience type to avoid having to declare a struct
// to implement the Handler interface, it can be used like this:
//
// 	consumer.AddHandler(nsq.HandlerFunc(func(m *Message) error {
// 		// handle the message
// 		return nil
// 	}))
type HandlerFunc func(message *Message) error

// HandleMessage implements the Handler interface by calling the function
// itself and returning its result unchanged.
func (h HandlerFunc) HandleMessage(m *Message) error {
	return h(m)
}
44
// DiscoveryFilter is an interface accepted by `SetBehaviorDelegate()`
// for filtering the nsqds returned from discovery via nsqlookupd.
//
// Filter receives the "host:port" TCP addresses assembled from a lookup
// response and returns the addresses the Consumer should connect to.
type DiscoveryFilter interface {
	Filter([]string) []string
}
50
// FailedMessageLogger is an interface that can be implemented by handlers that wish
// to receive a callback when a message is deemed "failed" (i.e. the number of attempts
// exceeded the Consumer's configured MaxAttempts).
//
// The callback is invoked before the failed message is finished.
type FailedMessageLogger interface {
	LogFailedMessage(message *Message)
}
57
// ConsumerStats represents a snapshot of the state of a Consumer's connections and the messages
// it has seen
type ConsumerStats struct {
	// MessagesReceived counts messages delivered to the Consumer by its connections.
	MessagesReceived uint64
	// MessagesFinished counts messages reported as finished by its connections.
	MessagesFinished uint64
	// MessagesRequeued counts messages reported as requeued by its connections.
	MessagesRequeued uint64
	// Connections is the number of established nsqd connections at snapshot time.
	Connections int
}
66
// instCount is incremented atomically by NewConsumer; the resulting value
// becomes the id of the new Consumer.
var instCount int64

// backoffSignal tells startStopContinueBackoff which event a connection reported.
type backoffSignal int

const (
	// backoffFlag is sent by onConnBackoff and may raise the backoff level.
	backoffFlag backoffSignal = iota
	// continueFlag is sent by onConnContinue and leaves the backoff level unchanged.
	continueFlag
	// resumeFlag is sent by onConnResume and may lower the backoff level.
	resumeFlag
)
76
// Consumer is a high-level type to consume from NSQ.
//
// A Consumer instance is supplied a Handler that will be executed
// concurrently via goroutines to handle processing the stream of messages
// consumed from the specified topic/channel. See: Handler/HandlerFunc
// for details on implementing the interface to create handlers.
//
// If configured, it will poll nsqlookupd instances and handle connection (and
// reconnection) to any discovered nsqds.
type Consumer struct {
	// 64bit atomic vars need to be first for proper alignment on 32bit platforms
	messagesReceived uint64 // incremented in onConnMessage
	messagesFinished uint64 // incremented in onConnMessageFinished
	messagesRequeued uint64 // incremented in onConnMessageRequeued
	totalRdyCount    int64  // sum of the RDY counts handed out to connections
	backoffDuration  int64  // nanoseconds of the pending backoff timeout, 0 when none
	backoffCounter   int32  // current backoff level, > 0 means "in backoff"
	maxInFlight      int32  // global max in flight, see ChangeMaxInFlight

	// mtx guards connections, pendingConnections, nsqdTCPAddrs and lookupdHTTPAddrs
	mtx sync.RWMutex

	// logger and logLvl are guarded by logGuard
	logger   logger
	logLvl   LogLevel
	logGuard sync.RWMutex

	// set via SetBehaviorDelegate
	behaviorDelegate interface{}

	id      int64
	topic   string
	channel string
	config  Config

	// rng is guarded by rngMtx
	rngMtx sync.Mutex
	rng    *rand.Rand

	// set to 1 to request a RDY redistribution on the next redistributeRDY pass
	needRDYRedistributed int32

	// serializes backoff state transitions in startStopContinueBackoff
	backoffMtx sync.Mutex

	// unbuffered channel that handler goroutines receive from
	incomingMessages chan *Message

	// pending RDY retries keyed by conn.String(), guarded by rdyRetryMtx
	rdyRetryMtx    sync.Mutex
	rdyRetryTimers map[string]*time.Timer

	// connections keyed by nsqd address; pending ones have not subscribed yet
	pendingConnections map[string]*Conn
	connections        map[string]*Conn

	// nsqd addresses passed to ConnectToNSQD that have not been removed
	// through DisconnectFromNSQD
	nsqdTCPAddrs []string

	// used at connection close to force a possible reconnect
	lookupdRecheckChan chan int
	lookupdHTTPAddrs   []string
	lookupdQueryIndex  int

	// wg tracks the rdyLoop and lookupdLoop goroutines
	wg              sync.WaitGroup
	runningHandlers int32
	stopFlag        int32 // 1 once Stop has been called
	connectedFlag   int32 // 1 once any connect method got past its checks
	stopHandler     sync.Once
	exitHandler     sync.Once

	// read from this channel to block until consumer is cleanly stopped
	StopChan chan int
	exitChan chan int
}
142
// NewConsumer creates a new instance of Consumer for the specified topic/channel
//
// The only valid way to create a Config is via NewConfig, using a struct literal will panic.
// After Config is passed into NewConsumer the values are no longer mutable (they are copied).
//
// An error is returned when the config fails validation or when the topic or
// channel name is not valid. On success the RDY redistribution loop is already
// running in its own goroutine.
func NewConsumer(topic string, channel string, config *Config) (*Consumer, error) {
	config.assertInitialized()

	if err := config.Validate(); err != nil {
		return nil, err
	}

	if !IsValidTopicName(topic) {
		return nil, errors.New("invalid topic name")
	}

	if !IsValidChannelName(channel) {
		return nil, errors.New("invalid channel name")
	}

	r := &Consumer{
		id: atomic.AddInt64(&instCount, 1),

		topic:   topic,
		channel: channel,
		// store a copy so later changes by the caller do not affect us
		config: *config,

		// default to logging on stderr at info level
		logger:      log.New(os.Stderr, "", log.Flags()),
		logLvl:      LogLevelInfo,
		maxInFlight: int32(config.MaxInFlight),

		incomingMessages: make(chan *Message),

		rdyRetryTimers:     make(map[string]*time.Timer),
		pendingConnections: make(map[string]*Conn),
		connections:        make(map[string]*Conn),

		// buffer of 1 so that a recheck request can be queued without blocking
		lookupdRecheckChan: make(chan int, 1),

		rng: rand.New(rand.NewSource(time.Now().UnixNano())),

		StopChan: make(chan int),
		exitChan: make(chan int),
	}
	// rdyLoop calls wg.Done when it observes exitChan being closed
	r.wg.Add(1)
	go r.rdyLoop()
	return r, nil
}
190
191// Stats retrieves the current connection and message statistics for a Consumer
192func (r *Consumer) Stats() *ConsumerStats {
193	return &ConsumerStats{
194		MessagesReceived: atomic.LoadUint64(&r.messagesReceived),
195		MessagesFinished: atomic.LoadUint64(&r.messagesFinished),
196		MessagesRequeued: atomic.LoadUint64(&r.messagesRequeued),
197		Connections:      len(r.conns()),
198	}
199}
200
201func (r *Consumer) conns() []*Conn {
202	r.mtx.RLock()
203	conns := make([]*Conn, 0, len(r.connections))
204	for _, c := range r.connections {
205		conns = append(conns, c)
206	}
207	r.mtx.RUnlock()
208	return conns
209}
210
211// SetLogger assigns the logger to use as well as a level
212//
213// The logger parameter is an interface that requires the following
214// method to be implemented (such as the the stdlib log.Logger):
215//
216//    Output(calldepth int, s string)
217//
218func (r *Consumer) SetLogger(l logger, lvl LogLevel) {
219	r.logGuard.Lock()
220	defer r.logGuard.Unlock()
221
222	r.logger = l
223	r.logLvl = lvl
224}
225
226func (r *Consumer) SetLoggerLevel(lvl LogLevel) {
227	r.logGuard.Lock()
228	defer r.logGuard.Unlock()
229
230	r.logLvl = lvl
231}
232
233func (r *Consumer) getLogger() (logger, LogLevel) {
234	r.logGuard.RLock()
235	defer r.logGuard.RUnlock()
236
237	return r.logger, r.logLvl
238}
239
240// SetBehaviorDelegate takes a type implementing one or more
241// of the following interfaces that modify the behavior
242// of the `Consumer`:
243//
244//    DiscoveryFilter
245//
246func (r *Consumer) SetBehaviorDelegate(cb interface{}) {
247	matched := false
248
249	if _, ok := cb.(DiscoveryFilter); ok {
250		matched = true
251	}
252
253	if !matched {
254		panic("behavior delegate does not have any recognized methods")
255	}
256
257	r.behaviorDelegate = cb
258}
259
260// perConnMaxInFlight calculates the per-connection max-in-flight count.
261//
262// This may change dynamically based on the number of connections to nsqd the Consumer
263// is responsible for.
264func (r *Consumer) perConnMaxInFlight() int64 {
265	b := float64(r.getMaxInFlight())
266	s := b / float64(len(r.conns()))
267	return int64(math.Min(math.Max(1, s), b))
268}
269
270// IsStarved indicates whether any connections for this consumer are blocked on processing
271// before being able to receive more messages (ie. RDY count of 0 and not exiting)
272func (r *Consumer) IsStarved() bool {
273	for _, conn := range r.conns() {
274		threshold := int64(float64(conn.RDY()) * 0.85)
275		inFlight := atomic.LoadInt64(&conn.messagesInFlight)
276		if inFlight >= threshold && inFlight > 0 && !conn.IsClosing() {
277			return true
278		}
279	}
280	return false
281}
282
// getMaxInFlight atomically reads the global max-in-flight setting.
func (r *Consumer) getMaxInFlight() int32 {
	return atomic.LoadInt32(&r.maxInFlight)
}
286
287// ChangeMaxInFlight sets a new maximum number of messages this comsumer instance
288// will allow in-flight, and updates all existing connections as appropriate.
289//
290// For example, ChangeMaxInFlight(0) would pause message flow
291//
292// If already connected, it updates the reader RDY state for each connection.
293func (r *Consumer) ChangeMaxInFlight(maxInFlight int) {
294	if r.getMaxInFlight() == int32(maxInFlight) {
295		return
296	}
297
298	atomic.StoreInt32(&r.maxInFlight, int32(maxInFlight))
299
300	for _, c := range r.conns() {
301		r.maybeUpdateRDY(c)
302	}
303}
304
305// ConnectToNSQLookupd adds an nsqlookupd address to the list for this Consumer instance.
306//
307// If it is the first to be added, it initiates an HTTP request to discover nsqd
308// producers for the configured topic.
309//
310// A goroutine is spawned to handle continual polling.
311func (r *Consumer) ConnectToNSQLookupd(addr string) error {
312	if atomic.LoadInt32(&r.stopFlag) == 1 {
313		return errors.New("consumer stopped")
314	}
315	if atomic.LoadInt32(&r.runningHandlers) == 0 {
316		return errors.New("no handlers")
317	}
318
319	if err := validatedLookupAddr(addr); err != nil {
320		return err
321	}
322
323	atomic.StoreInt32(&r.connectedFlag, 1)
324
325	r.mtx.Lock()
326	for _, x := range r.lookupdHTTPAddrs {
327		if x == addr {
328			r.mtx.Unlock()
329			return nil
330		}
331	}
332	r.lookupdHTTPAddrs = append(r.lookupdHTTPAddrs, addr)
333	numLookupd := len(r.lookupdHTTPAddrs)
334	r.mtx.Unlock()
335
336	// if this is the first one, kick off the go loop
337	if numLookupd == 1 {
338		r.queryLookupd()
339		r.wg.Add(1)
340		go r.lookupdLoop()
341	}
342
343	return nil
344}
345
346// ConnectToNSQLookupds adds multiple nsqlookupd address to the list for this Consumer instance.
347//
348// If adding the first address it initiates an HTTP request to discover nsqd
349// producers for the configured topic.
350//
351// A goroutine is spawned to handle continual polling.
352func (r *Consumer) ConnectToNSQLookupds(addresses []string) error {
353	for _, addr := range addresses {
354		err := r.ConnectToNSQLookupd(addr)
355		if err != nil {
356			return err
357		}
358	}
359	return nil
360}
361
// validatedLookupAddr performs a minimal sanity check on an nsqlookupd address.
//
// An address containing "/" is treated as a URL and must be parseable by
// url.Parse; any other address must contain a ":" (host:port form).
func validatedLookupAddr(addr string) error {
	if strings.Contains(addr, "/") {
		// URL form: the parse result is the verdict
		_, parseErr := url.Parse(addr)
		return parseErr
	}
	if strings.Contains(addr, ":") {
		return nil
	}
	return errors.New("missing port")
}
375
376// poll all known lookup servers every LookupdPollInterval
377func (r *Consumer) lookupdLoop() {
378	// add some jitter so that multiple consumers discovering the same topic,
379	// when restarted at the same time, dont all connect at once.
380	r.rngMtx.Lock()
381	jitter := time.Duration(int64(r.rng.Float64() *
382		r.config.LookupdPollJitter * float64(r.config.LookupdPollInterval)))
383	r.rngMtx.Unlock()
384	var ticker *time.Ticker
385
386	select {
387	case <-time.After(jitter):
388	case <-r.exitChan:
389		goto exit
390	}
391
392	ticker = time.NewTicker(r.config.LookupdPollInterval)
393
394	for {
395		select {
396		case <-ticker.C:
397			r.queryLookupd()
398		case <-r.lookupdRecheckChan:
399			r.queryLookupd()
400		case <-r.exitChan:
401			goto exit
402		}
403	}
404
405exit:
406	if ticker != nil {
407		ticker.Stop()
408	}
409	r.log(LogLevelInfo, "exiting lookupdLoop")
410	r.wg.Done()
411}
412
413// return the next lookupd endpoint to query
414// keeping track of which one was last used
415func (r *Consumer) nextLookupdEndpoint() string {
416	r.mtx.RLock()
417	if r.lookupdQueryIndex >= len(r.lookupdHTTPAddrs) {
418		r.lookupdQueryIndex = 0
419	}
420	addr := r.lookupdHTTPAddrs[r.lookupdQueryIndex]
421	num := len(r.lookupdHTTPAddrs)
422	r.mtx.RUnlock()
423	r.lookupdQueryIndex = (r.lookupdQueryIndex + 1) % num
424
425	urlString := addr
426	if !strings.Contains(urlString, "://") {
427		urlString = "http://" + addr
428	}
429
430	u, err := url.Parse(urlString)
431	if err != nil {
432		panic(err)
433	}
434	if u.Path == "/" || u.Path == "" {
435		u.Path = "/lookup"
436	}
437
438	v, err := url.ParseQuery(u.RawQuery)
439	v.Add("topic", r.topic)
440	u.RawQuery = v.Encode()
441	return u.String()
442}
443
// lookupResp is the decoded body of an nsqlookupd lookup response.
type lookupResp struct {
	Channels  []string    `json:"channels"`
	Producers []*peerInfo `json:"producers"`
	Timestamp int64       `json:"timestamp"`
}

// peerInfo describes one producer entry of a lookup response.
//
// queryLookupd only uses BroadcastAddress and TCPPort, which it joins into
// the address it connects to.
type peerInfo struct {
	RemoteAddress    string `json:"remote_address"`
	Hostname         string `json:"hostname"`
	BroadcastAddress string `json:"broadcast_address"`
	TCPPort          int    `json:"tcp_port"`
	HTTPPort         int    `json:"http_port"`
	Version          string `json:"version"`
}
458
459// make an HTTP req to one of the configured nsqlookupd instances to discover
460// which nsqd's provide the topic we are consuming.
461//
462// initiate a connection to any new producers that are identified.
463func (r *Consumer) queryLookupd() {
464	retries := 0
465
466retry:
467	endpoint := r.nextLookupdEndpoint()
468
469	r.log(LogLevelInfo, "querying nsqlookupd %s", endpoint)
470
471	var data lookupResp
472	err := apiRequestNegotiateV1("GET", endpoint, nil, &data)
473	if err != nil {
474		r.log(LogLevelError, "error querying nsqlookupd (%s) - %s", endpoint, err)
475		retries++
476		if retries < 3 {
477			r.log(LogLevelInfo, "retrying with next nsqlookupd")
478			goto retry
479		}
480		return
481	}
482
483	var nsqdAddrs []string
484	for _, producer := range data.Producers {
485		broadcastAddress := producer.BroadcastAddress
486		port := producer.TCPPort
487		joined := net.JoinHostPort(broadcastAddress, strconv.Itoa(port))
488		nsqdAddrs = append(nsqdAddrs, joined)
489	}
490	// apply filter
491	if discoveryFilter, ok := r.behaviorDelegate.(DiscoveryFilter); ok {
492		nsqdAddrs = discoveryFilter.Filter(nsqdAddrs)
493	}
494	for _, addr := range nsqdAddrs {
495		err = r.ConnectToNSQD(addr)
496		if err != nil && err != ErrAlreadyConnected {
497			r.log(LogLevelError, "(%s) error connecting to nsqd - %s", addr, err)
498			continue
499		}
500	}
501}
502
503// ConnectToNSQDs takes multiple nsqd addresses to connect directly to.
504//
505// It is recommended to use ConnectToNSQLookupd so that topics are discovered
506// automatically.  This method is useful when you want to connect to local instance.
507func (r *Consumer) ConnectToNSQDs(addresses []string) error {
508	for _, addr := range addresses {
509		err := r.ConnectToNSQD(addr)
510		if err != nil {
511			return err
512		}
513	}
514	return nil
515}
516
// ConnectToNSQD takes a nsqd address to connect directly to.
//
// It is recommended to use ConnectToNSQLookupd so that topics are discovered
// automatically.  This method is useful when you want to connect to a single, local,
// instance.
//
// It returns an error when the consumer is stopped or has no handlers,
// ErrAlreadyConnected when a connection (pending or established) to addr
// exists, and otherwise any error from connecting or subscribing.
func (r *Consumer) ConnectToNSQD(addr string) error {
	if atomic.LoadInt32(&r.stopFlag) == 1 {
		return errors.New("consumer stopped")
	}

	if atomic.LoadInt32(&r.runningHandlers) == 0 {
		return errors.New("no handlers")
	}

	// from here on AddHandler/AddConcurrentHandlers will panic
	atomic.StoreInt32(&r.connectedFlag, 1)

	logger, logLvl := r.getLogger()

	conn := NewConn(addr, &r.config, &consumerConnDelegate{r})
	conn.SetLogger(logger, logLvl,
		fmt.Sprintf("%3d [%s/%s] (%%s)", r.id, r.topic, r.channel))

	// register the connection as pending, unless one already exists for addr
	r.mtx.Lock()
	_, pendingOk := r.pendingConnections[addr]
	_, ok := r.connections[addr]
	if ok || pendingOk {
		r.mtx.Unlock()
		return ErrAlreadyConnected
	}
	r.pendingConnections[addr] = conn
	// remember the address; onConnClose consults this list before reconnecting
	if idx := indexOf(addr, r.nsqdTCPAddrs); idx == -1 {
		r.nsqdTCPAddrs = append(r.nsqdTCPAddrs, addr)
	}
	r.mtx.Unlock()

	r.log(LogLevelInfo, "(%s) connecting to nsqd", addr)

	// undoes the pending registration and closes the connection on failure
	cleanupConnection := func() {
		r.mtx.Lock()
		delete(r.pendingConnections, addr)
		r.mtx.Unlock()
		conn.Close()
	}

	resp, err := conn.Connect()
	if err != nil {
		cleanupConnection()
		return err
	}

	if resp != nil {
		if resp.MaxRdyCount < int64(r.getMaxInFlight()) {
			r.log(LogLevelWarning,
				"(%s) max RDY count %d < consumer max in flight %d, truncation possible",
				conn.String(), resp.MaxRdyCount, r.getMaxInFlight())
		}
	}

	cmd := Subscribe(r.topic, r.channel)
	err = conn.WriteCommand(cmd)
	if err != nil {
		cleanupConnection()
		return fmt.Errorf("[%s] failed to subscribe to %s:%s - %s",
			conn, r.topic, r.channel, err.Error())
	}

	// subscribed: promote the connection from pending to established
	r.mtx.Lock()
	delete(r.pendingConnections, addr)
	r.connections[addr] = conn
	r.mtx.Unlock()

	// pre-emptive signal to existing connections to lower their RDY count
	for _, c := range r.conns() {
		r.maybeUpdateRDY(c)
	}

	return nil
}
595
// indexOf returns the position of the first element of h equal to n,
// or -1 when h does not contain n.
func indexOf(n string, h []string) int {
	for i := 0; i < len(h); i++ {
		if h[i] == n {
			return i
		}
	}
	return -1
}
604
605// DisconnectFromNSQD closes the connection to and removes the specified
606// `nsqd` address from the list
607func (r *Consumer) DisconnectFromNSQD(addr string) error {
608	r.mtx.Lock()
609	defer r.mtx.Unlock()
610
611	idx := indexOf(addr, r.nsqdTCPAddrs)
612	if idx == -1 {
613		return ErrNotConnected
614	}
615
616	// slice delete
617	r.nsqdTCPAddrs = append(r.nsqdTCPAddrs[:idx], r.nsqdTCPAddrs[idx+1:]...)
618
619	pendingConn, pendingOk := r.pendingConnections[addr]
620	conn, ok := r.connections[addr]
621
622	if ok {
623		conn.Close()
624	} else if pendingOk {
625		pendingConn.Close()
626	}
627
628	return nil
629}
630
631// DisconnectFromNSQLookupd removes the specified `nsqlookupd` address
632// from the list used for periodic discovery.
633func (r *Consumer) DisconnectFromNSQLookupd(addr string) error {
634	r.mtx.Lock()
635	defer r.mtx.Unlock()
636
637	idx := indexOf(addr, r.lookupdHTTPAddrs)
638	if idx == -1 {
639		return ErrNotConnected
640	}
641
642	if len(r.lookupdHTTPAddrs) == 1 {
643		return fmt.Errorf("cannot disconnect from only remaining nsqlookupd HTTP address %s", addr)
644	}
645
646	r.lookupdHTTPAddrs = append(r.lookupdHTTPAddrs[:idx], r.lookupdHTTPAddrs[idx+1:]...)
647
648	return nil
649}
650
// onConnMessage counts a received message and hands it to the handler
// goroutines; the send blocks until a handler receives it because
// incomingMessages is unbuffered.
func (r *Consumer) onConnMessage(c *Conn, msg *Message) {
	atomic.AddUint64(&r.messagesReceived, 1)
	r.incomingMessages <- msg
}

// onConnMessageFinished counts a finished message.
func (r *Consumer) onConnMessageFinished(c *Conn, msg *Message) {
	atomic.AddUint64(&r.messagesFinished, 1)
}

// onConnMessageRequeued counts a requeued message.
func (r *Consumer) onConnMessageRequeued(c *Conn, msg *Message) {
	atomic.AddUint64(&r.messagesRequeued, 1)
}
663
// onConnBackoff forwards a backoff signal from a connection to the backoff
// state machine.
func (r *Consumer) onConnBackoff(c *Conn) {
	r.startStopContinueBackoff(c, backoffFlag)
}

// onConnContinue forwards a continue signal (backoff level unchanged) to the
// backoff state machine.
func (r *Consumer) onConnContinue(c *Conn) {
	r.startStopContinueBackoff(c, continueFlag)
}

// onConnResume forwards a resume signal from a connection to the backoff
// state machine.
func (r *Consumer) onConnResume(c *Conn) {
	r.startStopContinueBackoff(c, resumeFlag)
}
675
676func (r *Consumer) onConnResponse(c *Conn, data []byte) {
677	switch {
678	case bytes.Equal(data, []byte("CLOSE_WAIT")):
679		// server is ready for us to close (it ack'd our StartClose)
680		// we can assume we will not receive any more messages over this channel
681		// (but we can still write back responses)
682		r.log(LogLevelInfo, "(%s) received CLOSE_WAIT from nsqd", c.String())
683		c.Close()
684	}
685}
686
// onConnError is intentionally a no-op for the Consumer.
func (r *Consumer) onConnError(c *Conn, data []byte) {}

// onConnHeartbeat is intentionally a no-op for the Consumer.
func (r *Consumer) onConnHeartbeat(c *Conn) {}

// onConnIOError closes the connection that reported the I/O error.
func (r *Consumer) onConnIOError(c *Conn, err error) {
	c.Close()
}
694
// onConnClose cleans up after a closed connection: it releases the
// connection's RDY count, cancels its pending RDY retry, removes it from the
// connection map and then either finishes a stop in progress or arranges for
// a reconnect (through lookupd when configured, directly otherwise).
func (r *Consumer) onConnClose(c *Conn) {
	var hasRDYRetryTimer bool

	// remove this connections RDY count from the consumer's total
	rdyCount := c.RDY()
	atomic.AddInt64(&r.totalRdyCount, -rdyCount)

	r.rdyRetryMtx.Lock()
	if timer, ok := r.rdyRetryTimers[c.String()]; ok {
		// stop any pending retry of an old RDY update
		timer.Stop()
		delete(r.rdyRetryTimers, c.String())
		hasRDYRetryTimer = true
	}
	r.rdyRetryMtx.Unlock()

	r.mtx.Lock()
	delete(r.connections, c.String())
	left := len(r.connections)
	r.mtx.Unlock()

	r.log(LogLevelWarning, "there are %d connections left alive", left)

	if (hasRDYRetryTimer || rdyCount > 0) &&
		(int32(left) == r.getMaxInFlight() || r.inBackoff()) {
		// we're toggling out of (normal) redistribution cases and this conn
		// had a RDY count...
		//
		// trigger RDY redistribution to make sure this RDY is moved
		// to a new connection
		atomic.StoreInt32(&r.needRDYRedistributed, 1)
	}

	// we were the last one (and stopping)
	if atomic.LoadInt32(&r.stopFlag) == 1 {
		if left == 0 {
			r.stopHandlers()
		}
		return
	}

	r.mtx.RLock()
	numLookupd := len(r.lookupdHTTPAddrs)
	reconnect := indexOf(c.String(), r.nsqdTCPAddrs) >= 0
	r.mtx.RUnlock()
	if numLookupd > 0 {
		// trigger a poll of the lookupd
		// (non-blocking: a recheck that is already queued is sufficient)
		select {
		case r.lookupdRecheckChan <- 1:
		default:
		}
	} else if reconnect {
		// there are no lookupd and we still have this nsqd TCP address in our list...
		// try to reconnect after a bit
		go func(addr string) {
			for {
				r.log(LogLevelInfo, "(%s) re-connecting in %s", addr, r.config.LookupdPollInterval)
				time.Sleep(r.config.LookupdPollInterval)
				// give up once the consumer has been stopped
				if atomic.LoadInt32(&r.stopFlag) == 1 {
					break
				}
				// the address may have been removed while we were sleeping
				r.mtx.RLock()
				reconnect := indexOf(addr, r.nsqdTCPAddrs) >= 0
				r.mtx.RUnlock()
				if !reconnect {
					r.log(LogLevelWarning, "(%s) skipped reconnect after removal...", addr)
					return
				}
				err := r.ConnectToNSQD(addr)
				if err != nil && err != ErrAlreadyConnected {
					// keep retrying until connected, removed or stopped
					r.log(LogLevelError, "(%s) error connecting to nsqd - %s", addr, err)
					continue
				}
				break
			}
		}(c.String())
	}
}
773
774func (r *Consumer) startStopContinueBackoff(conn *Conn, signal backoffSignal) {
775	// prevent many async failures/successes from immediately resulting in
776	// max backoff/normal rate (by ensuring that we dont continually incr/decr
777	// the counter during a backoff period)
778	r.backoffMtx.Lock()
779	defer r.backoffMtx.Unlock()
780	if r.inBackoffTimeout() {
781		return
782	}
783
784	// update backoff state
785	backoffUpdated := false
786	backoffCounter := atomic.LoadInt32(&r.backoffCounter)
787	switch signal {
788	case resumeFlag:
789		if backoffCounter > 0 {
790			backoffCounter--
791			backoffUpdated = true
792		}
793	case backoffFlag:
794		nextBackoff := r.config.BackoffStrategy.Calculate(int(backoffCounter) + 1)
795		if nextBackoff <= r.config.MaxBackoffDuration {
796			backoffCounter++
797			backoffUpdated = true
798		}
799	}
800	atomic.StoreInt32(&r.backoffCounter, backoffCounter)
801
802	if r.backoffCounter == 0 && backoffUpdated {
803		// exit backoff
804		count := r.perConnMaxInFlight()
805		r.log(LogLevelWarning, "exiting backoff, returning all to RDY %d", count)
806		for _, c := range r.conns() {
807			r.updateRDY(c, count)
808		}
809	} else if r.backoffCounter > 0 {
810		// start or continue backoff
811		backoffDuration := r.config.BackoffStrategy.Calculate(int(backoffCounter))
812
813		if backoffDuration > r.config.MaxBackoffDuration {
814			backoffDuration = r.config.MaxBackoffDuration
815		}
816
817		r.log(LogLevelWarning, "backing off for %s (backoff level %d), setting all to RDY 0",
818			backoffDuration, backoffCounter)
819
820		// send RDY 0 immediately (to *all* connections)
821		for _, c := range r.conns() {
822			r.updateRDY(c, 0)
823		}
824
825		r.backoff(backoffDuration)
826	}
827}
828
// backoff records d as the pending backoff timeout (which makes
// inBackoffTimeout report true for a positive d) and schedules resume to run
// after d has elapsed.
func (r *Consumer) backoff(d time.Duration) {
	atomic.StoreInt64(&r.backoffDuration, d.Nanoseconds())
	time.AfterFunc(d, r.resume)
}
833
// resume runs when a backoff timeout expires. It sends RDY 1 to one randomly
// chosen connection and clears the backoff timeout; when there is no
// connection or the RDY update fails it backs off for another second instead.
// On a stopped consumer it only clears the timeout.
func (r *Consumer) resume() {
	if atomic.LoadInt32(&r.stopFlag) == 1 {
		atomic.StoreInt64(&r.backoffDuration, 0)
		return
	}

	// pick a random connection to test the waters
	conns := r.conns()
	if len(conns) == 0 {
		r.log(LogLevelWarning, "no connection available to resume")
		r.log(LogLevelWarning, "backing off for %s", time.Second)
		r.backoff(time.Second)
		return
	}
	r.rngMtx.Lock()
	idx := r.rng.Intn(len(conns))
	r.rngMtx.Unlock()
	choice := conns[idx]

	r.log(LogLevelWarning,
		"(%s) backoff timeout expired, sending RDY 1",
		choice.String())

	// while in backoff only ever let 1 message at a time through
	err := r.updateRDY(choice, 1)
	if err != nil {
		r.log(LogLevelWarning, "(%s) error resuming RDY 1 - %s", choice.String(), err)
		r.log(LogLevelWarning, "backing off for %s", time.Second)
		r.backoff(time.Second)
		return
	}

	// the timeout is over; backoff signals are processed again from now on
	atomic.StoreInt64(&r.backoffDuration, 0)
}
868
// inBackoff reports whether the backoff level is above zero.
func (r *Consumer) inBackoff() bool {
	return atomic.LoadInt32(&r.backoffCounter) > 0
}

// inBackoffTimeout reports whether a backoff timeout is currently pending,
// i.e. backoff has been called and resume has not cleared it yet.
func (r *Consumer) inBackoffTimeout() bool {
	return atomic.LoadInt64(&r.backoffDuration) > 0
}
876
877func (r *Consumer) maybeUpdateRDY(conn *Conn) {
878	inBackoff := r.inBackoff()
879	inBackoffTimeout := r.inBackoffTimeout()
880	if inBackoff || inBackoffTimeout {
881		r.log(LogLevelDebug, "(%s) skip sending RDY inBackoff:%v || inBackoffTimeout:%v",
882			conn, inBackoff, inBackoffTimeout)
883		return
884	}
885
886	count := r.perConnMaxInFlight()
887	r.log(LogLevelDebug, "(%s) sending RDY %d", conn, count)
888	r.updateRDY(conn, count)
889}
890
891func (r *Consumer) rdyLoop() {
892	redistributeTicker := time.NewTicker(r.config.RDYRedistributeInterval)
893
894	for {
895		select {
896		case <-redistributeTicker.C:
897			r.redistributeRDY()
898		case <-r.exitChan:
899			goto exit
900		}
901	}
902
903exit:
904	redistributeTicker.Stop()
905	r.log(LogLevelInfo, "rdyLoop exiting")
906	r.wg.Done()
907}
908
// updateRDY sets the RDY count of connection c to count, truncated so that
// neither the connection's max RDY nor the consumer's global max in flight is
// exceeded.
//
// It returns ErrClosing when the connection is closing, ErrOverMaxInFlight
// when no RDY capacity is left for a non-zero count, and otherwise the result
// of sending the RDY command.
func (r *Consumer) updateRDY(c *Conn, count int64) error {
	if c.IsClosing() {
		return ErrClosing
	}

	// never exceed the nsqd's configured max RDY count
	if count > c.MaxRDY() {
		count = c.MaxRDY()
	}

	// stop any pending retry of an old RDY update
	r.rdyRetryMtx.Lock()
	if timer, ok := r.rdyRetryTimers[c.String()]; ok {
		timer.Stop()
		delete(r.rdyRetryTimers, c.String())
	}
	r.rdyRetryMtx.Unlock()

	// never exceed our global max in flight. truncate if possible.
	// this could help a new connection get partial max-in-flight
	rdyCount := c.RDY()
	// capacity left for this connection: the global max minus what all other
	// connections currently hold
	maxPossibleRdy := int64(r.getMaxInFlight()) - atomic.LoadInt64(&r.totalRdyCount) + rdyCount
	if maxPossibleRdy > 0 && maxPossibleRdy < count {
		count = maxPossibleRdy
	}
	if maxPossibleRdy <= 0 && count > 0 {
		if rdyCount == 0 {
			// we wanted to exit a zero RDY count but we couldn't send it...
			// in order to prevent eternal starvation we reschedule this attempt
			// (if any other RDY update succeeds this timer will be stopped)
			r.rdyRetryMtx.Lock()
			r.rdyRetryTimers[c.String()] = time.AfterFunc(5*time.Second,
				func() {
					r.updateRDY(c, count)
				})
			r.rdyRetryMtx.Unlock()
		}
		return ErrOverMaxInFlight
	}

	return r.sendRDY(c, count)
}
951
952func (r *Consumer) sendRDY(c *Conn, count int64) error {
953	if count == 0 && c.LastRDY() == 0 {
954		// no need to send. It's already that RDY count
955		return nil
956	}
957
958	atomic.AddInt64(&r.totalRdyCount, count-c.RDY())
959	c.SetRDY(count)
960	err := c.WriteCommand(Ready(int(count)))
961	if err != nil {
962		r.log(LogLevelError, "(%s) error sending RDY %d - %s", c.String(), count, err)
963		return err
964	}
965	return nil
966}
967
// redistributeRDY moves RDY counts between connections when there is not
// enough RDY to give every connection at least one: when there are more
// connections than max in flight, when in backoff with more than one
// connection, or when a redistribution was requested through
// needRDYRedistributed.
//
// Connections that hold RDY but have been idle for too long give it up; the
// available RDY is then handed out, one at a time, to randomly chosen
// connections.
func (r *Consumer) redistributeRDY() {
	if r.inBackoffTimeout() {
		return
	}

	// if an external heuristic set needRDYRedistributed we want to wait
	// until we can actually redistribute to proceed
	conns := r.conns()
	if len(conns) == 0 {
		return
	}

	maxInFlight := r.getMaxInFlight()
	if len(conns) > int(maxInFlight) {
		r.log(LogLevelDebug, "redistributing RDY state (%d conns > %d max_in_flight)",
			len(conns), maxInFlight)
		atomic.StoreInt32(&r.needRDYRedistributed, 1)
	}

	if r.inBackoff() && len(conns) > 1 {
		r.log(LogLevelDebug, "redistributing RDY state (in backoff and %d conns > 1)", len(conns))
		atomic.StoreInt32(&r.needRDYRedistributed, 1)
	}

	// consume the request flag; nothing to do when it was not set
	if !atomic.CompareAndSwapInt32(&r.needRDYRedistributed, 1, 0) {
		return
	}

	possibleConns := make([]*Conn, 0, len(conns))
	for _, c := range conns {
		lastMsgDuration := time.Now().Sub(c.LastMessageTime())
		lastRdyDuration := time.Now().Sub(c.LastRdyTime())
		rdyCount := c.RDY()
		r.log(LogLevelDebug, "(%s) rdy: %d (last message received %s)",
			c.String(), rdyCount, lastMsgDuration)
		if rdyCount > 0 {
			if lastMsgDuration > r.config.LowRdyIdleTimeout {
				r.log(LogLevelDebug, "(%s) idle connection, giving up RDY", c.String())
				r.updateRDY(c, 0)
			} else if lastRdyDuration > r.config.LowRdyTimeout {
				r.log(LogLevelDebug, "(%s) RDY timeout, giving up RDY", c.String())
				r.updateRDY(c, 0)
			}
		}
		// every connection is a candidate for receiving RDY below
		possibleConns = append(possibleConns, c)
	}

	availableMaxInFlight := int64(maxInFlight) - atomic.LoadInt64(&r.totalRdyCount)
	if r.inBackoff() {
		// while in backoff at most one RDY is handed out in total
		availableMaxInFlight = 1 - atomic.LoadInt64(&r.totalRdyCount)
	}

	for len(possibleConns) > 0 && availableMaxInFlight > 0 {
		availableMaxInFlight--
		r.rngMtx.Lock()
		i := r.rng.Int() % len(possibleConns)
		r.rngMtx.Unlock()
		c := possibleConns[i]
		// delete
		possibleConns = append(possibleConns[:i], possibleConns[i+1:]...)
		r.log(LogLevelDebug, "(%s) redistributing RDY", c.String())
		r.updateRDY(c, 1)
	}
}
1032
// Stop will initiate a graceful stop of the Consumer (permanent)
//
// Only the first call has an effect. Without connections the handlers are
// stopped right away; otherwise every connection is asked to close and the
// handlers are stopped once the last connection has closed (see onConnClose).
// A final exit is forced after 30 seconds in any case.
//
// NOTE: receive on StopChan to block until this process completes
func (r *Consumer) Stop() {
	if !atomic.CompareAndSwapInt32(&r.stopFlag, 0, 1) {
		return
	}

	r.log(LogLevelInfo, "stopping...")

	if len(r.conns()) == 0 {
		r.stopHandlers()
	} else {
		for _, c := range r.conns() {
			err := c.WriteCommand(StartClose())
			if err != nil {
				r.log(LogLevelError, "(%s) error sending CLS - %s", c.String(), err)
			}
		}

		time.AfterFunc(time.Second*30, func() {
			// if we've waited this long handlers are blocked on processing messages
			// so we can't just stopHandlers (if any adtl. messages were pending processing
			// we would cause a panic on channel close)
			//
			// instead, we just bypass handler closing and skip to the final exit
			r.exit()
		})
	}
}
1063
// stopHandlers closes incomingMessages, which makes every handlerLoop exit
// after it has drained the channel. The sync.Once guarantees the channel is
// closed only once.
func (r *Consumer) stopHandlers() {
	r.stopHandler.Do(func() {
		r.log(LogLevelInfo, "stopping handlers")
		close(r.incomingMessages)
	})
}
1070
// AddHandler sets the Handler for messages received by this Consumer. This can be called
// multiple times to add additional handlers. Handler will have a 1:1 ratio to message handling goroutines.
//
// It is equivalent to AddConcurrentHandlers(handler, 1).
//
// This panics if called after connecting to NSQD or NSQ Lookupd
//
// (see Handler or HandlerFunc for details on implementing this interface)
func (r *Consumer) AddHandler(handler Handler) {
	r.AddConcurrentHandlers(handler, 1)
}
1080
1081// AddConcurrentHandlers sets the Handler for messages received by this Consumer.  It
1082// takes a second argument which indicates the number of goroutines to spawn for
1083// message handling.
1084//
1085// This panics if called after connecting to NSQD or NSQ Lookupd
1086//
1087// (see Handler or HandlerFunc for details on implementing this interface)
1088func (r *Consumer) AddConcurrentHandlers(handler Handler, concurrency int) {
1089	if atomic.LoadInt32(&r.connectedFlag) == 1 {
1090		panic("already connected")
1091	}
1092
1093	atomic.AddInt32(&r.runningHandlers, int32(concurrency))
1094	for i := 0; i < concurrency; i++ {
1095		go r.handlerLoop(handler)
1096	}
1097}
1098
1099func (r *Consumer) handlerLoop(handler Handler) {
1100	r.log(LogLevelDebug, "starting Handler")
1101
1102	for {
1103		message, ok := <-r.incomingMessages
1104		if !ok {
1105			goto exit
1106		}
1107
1108		if r.shouldFailMessage(message, handler) {
1109			message.Finish()
1110			continue
1111		}
1112
1113		err := handler.HandleMessage(message)
1114		if err != nil {
1115			r.log(LogLevelError, "Handler returned error (%s) for msg %s", err, message.ID)
1116			if !message.IsAutoResponseDisabled() {
1117				message.Requeue(-1)
1118			}
1119			continue
1120		}
1121
1122		if !message.IsAutoResponseDisabled() {
1123			message.Finish()
1124		}
1125	}
1126
1127exit:
1128	r.log(LogLevelDebug, "stopping Handler")
1129	if atomic.AddInt32(&r.runningHandlers, -1) == 0 {
1130		r.exit()
1131	}
1132}
1133
1134func (r *Consumer) shouldFailMessage(message *Message, handler interface{}) bool {
1135	// message passed the max number of attempts
1136	if r.config.MaxAttempts > 0 && message.Attempts > r.config.MaxAttempts {
1137		r.log(LogLevelWarning, "msg %s attempted %d times, giving up",
1138			message.ID, message.Attempts)
1139
1140		logger, ok := handler.(FailedMessageLogger)
1141		if ok {
1142			logger.LogFailedMessage(message)
1143		}
1144
1145		return true
1146	}
1147	return false
1148}
1149
// exit performs the final shutdown exactly once: it closes exitChan (which
// ends rdyLoop and lookupdLoop), waits for those goroutines through the wait
// group and then closes StopChan to release anyone blocked on it.
func (r *Consumer) exit() {
	r.exitHandler.Do(func() {
		close(r.exitChan)
		r.wg.Wait()
		close(r.StopChan)
	})
}
1157
1158func (r *Consumer) log(lvl LogLevel, line string, args ...interface{}) {
1159	logger, logLvl := r.getLogger()
1160
1161	if logger == nil {
1162		return
1163	}
1164
1165	if logLvl > lvl {
1166		return
1167	}
1168
1169	logger.Output(2, fmt.Sprintf("%-4s %3d [%s/%s] %s",
1170		lvl, r.id, r.topic, r.channel,
1171		fmt.Sprintf(line, args...)))
1172}
1173