1package client
2
3import (
4	"sync"
5	"time"
6
7	hclog "github.com/hashicorp/go-hclog"
8	"github.com/hashicorp/nomad/nomad/structs"
9)
10
11type heartbeatStop struct {
12	lastOk        time.Time
13	startupGrace  time.Time
14	allocInterval map[string]time.Duration
15	allocHookCh   chan *structs.Allocation
16	getRunner     func(string) (AllocRunner, error)
17	logger        hclog.InterceptLogger
18	shutdownCh    chan struct{}
19	lock          *sync.RWMutex
20}
21
22func newHeartbeatStop(
23	getRunner func(string) (AllocRunner, error),
24	timeout time.Duration,
25	logger hclog.InterceptLogger,
26	shutdownCh chan struct{}) *heartbeatStop {
27
28	h := &heartbeatStop{
29		startupGrace:  time.Now().Add(timeout),
30		allocInterval: make(map[string]time.Duration),
31		allocHookCh:   make(chan *structs.Allocation),
32		getRunner:     getRunner,
33		logger:        logger,
34		shutdownCh:    shutdownCh,
35		lock:          &sync.RWMutex{},
36	}
37
38	return h
39}
40
41// allocHook is called after (re)storing a new AllocRunner in the client. It registers the
42// allocation to be stopped if the taskgroup is configured appropriately
43func (h *heartbeatStop) allocHook(alloc *structs.Allocation) {
44	tg := allocTaskGroup(alloc)
45	if tg.StopAfterClientDisconnect != nil {
46		h.allocHookCh <- alloc
47	}
48}
49
50// shouldStop is called on a restored alloc to determine if lastOk is sufficiently in the
51// past that it should be prevented from restarting
52func (h *heartbeatStop) shouldStop(alloc *structs.Allocation) bool {
53	tg := allocTaskGroup(alloc)
54	if tg.StopAfterClientDisconnect != nil {
55		return h.shouldStopAfter(time.Now(), *tg.StopAfterClientDisconnect)
56	}
57	return false
58}
59
60func (h *heartbeatStop) shouldStopAfter(now time.Time, interval time.Duration) bool {
61	lastOk := h.getLastOk()
62	if lastOk.IsZero() {
63		return now.After(h.startupGrace)
64	}
65	return now.After(lastOk.Add(interval))
66}
67
68// watch is a loop that checks for allocations that should be stopped. It also manages the
69// registration of allocs to be stopped in a single thread.
70func (h *heartbeatStop) watch() {
71	// If we never manage to successfully contact the server, we want to stop our allocs
72	// after duration + start time
73	h.lastOk = time.Now()
74	stop := make(chan string, 1)
75	var now time.Time
76	var interval time.Duration
77	checkAllocs := false
78
79	for {
80		// minimize the interval
81		interval = 5 * time.Second
82		for _, t := range h.allocInterval {
83			if t < interval {
84				interval = t
85			}
86		}
87
88		checkAllocs = false
89		timeout := time.After(interval)
90
91		select {
92		case allocID := <-stop:
93			if err := h.stopAlloc(allocID); err != nil {
94				h.logger.Warn("error stopping on heartbeat timeout", "alloc", allocID, "error", err)
95				continue
96			}
97			delete(h.allocInterval, allocID)
98
99		case alloc := <-h.allocHookCh:
100			tg := allocTaskGroup(alloc)
101			if tg.StopAfterClientDisconnect != nil {
102				h.allocInterval[alloc.ID] = *tg.StopAfterClientDisconnect
103			}
104
105		case <-timeout:
106			checkAllocs = true
107
108		case <-h.shutdownCh:
109			return
110		}
111
112		if !checkAllocs {
113			continue
114		}
115
116		now = time.Now()
117		for allocID, d := range h.allocInterval {
118			if h.shouldStopAfter(now, d) {
119				stop <- allocID
120			}
121		}
122	}
123}
124
125// setLastOk sets the last known good heartbeat time to the current time, and persists that time to disk
126func (h *heartbeatStop) setLastOk(t time.Time) {
127	h.lock.Lock()
128	defer h.lock.Unlock()
129	h.lastOk = t
130}
131
132func (h *heartbeatStop) getLastOk() time.Time {
133	h.lock.RLock()
134	defer h.lock.RUnlock()
135	return h.lastOk
136}
137
138// stopAlloc actually stops the allocation
139func (h *heartbeatStop) stopAlloc(allocID string) error {
140	runner, err := h.getRunner(allocID)
141	if err != nil {
142		return err
143	}
144
145	h.logger.Debug("stopping alloc for stop_after_client_disconnect", "alloc", allocID)
146
147	runner.Destroy()
148	return nil
149}
150
151func allocTaskGroup(alloc *structs.Allocation) *structs.TaskGroup {
152	for _, tg := range alloc.Job.TaskGroups {
153		if tg.Name == alloc.TaskGroup {
154			return tg
155		}
156	}
157	return nil
158}
159