1package devicemanager
2
3import (
4	"context"
5	"fmt"
6	"sync"
7	"time"
8
9	log "github.com/hashicorp/go-hclog"
10	multierror "github.com/hashicorp/go-multierror"
11	"github.com/hashicorp/nomad/helper/pluginutils/loader"
12	"github.com/hashicorp/nomad/helper/pluginutils/singleton"
13	"github.com/hashicorp/nomad/nomad/structs"
14	"github.com/hashicorp/nomad/plugins/base"
15	bstructs "github.com/hashicorp/nomad/plugins/base/structs"
16	"github.com/hashicorp/nomad/plugins/device"
17)
18
// Backoff tuning for device statistics collection.
const (
	// statsBackoffBaseline is the starting delay of the exponential backoff
	// applied when collecting device stats returns an error.
	statsBackoffBaseline = time.Second * 5

	// statsBackoffLimit is the upper bound the exponential backoff is clamped
	// to while retrying device statistics collection.
	statsBackoffLimit = time.Minute * 30
)
28
29// instanceManagerConfig configures a device instance manager
30type instanceManagerConfig struct {
31	// Logger is the logger used by the device instance manager
32	Logger log.Logger
33
34	// Ctx is used to shutdown the device instance manager
35	Ctx context.Context
36
37	// Loader is the plugin loader
38	Loader loader.PluginCatalog
39
40	// StoreReattach is used to store a plugins reattach config
41	StoreReattach StorePluginReattachFn
42
43	// PluginConfig is the config passed to the launched plugins
44	PluginConfig *base.AgentConfig
45
46	// Id is the ID of the plugin being managed
47	Id *loader.PluginID
48
49	// FingerprintOutCh is used to emit new fingerprinted devices
50	FingerprintOutCh chan<- struct{}
51
52	// StatsInterval is the interval at which we collect statistics.
53	StatsInterval time.Duration
54}
55
56// instanceManager is used to manage a single device plugin
57type instanceManager struct {
58	// logger is the logger used by the device instance manager
59	logger log.Logger
60
61	// ctx is used to shutdown the device manager
62	ctx context.Context
63
64	// cancel is used to shutdown management of this device plugin
65	cancel context.CancelFunc
66
67	// loader is the plugin loader
68	loader loader.PluginCatalog
69
70	// storeReattach is used to store a plugins reattach config
71	storeReattach StorePluginReattachFn
72
73	// pluginConfig is the config passed to the launched plugins
74	pluginConfig *base.AgentConfig
75
76	// id is the ID of the plugin being managed
77	id *loader.PluginID
78
79	// fingerprintOutCh is used to emit new fingerprinted devices
80	fingerprintOutCh chan<- struct{}
81
82	// plugin is the plugin instance being managed
83	plugin loader.PluginInstance
84
85	// device is the device plugin being managed
86	device device.DevicePlugin
87
88	// pluginLock locks access to the device and plugin
89	pluginLock sync.Mutex
90
91	// shutdownLock is used to serialize attempts to shutdown
92	shutdownLock sync.Mutex
93
94	// devices is the set of fingerprinted devices
95	devices    []*device.DeviceGroup
96	deviceLock sync.RWMutex
97
98	// statsInterval is the interval at which we collect statistics.
99	statsInterval time.Duration
100
101	// deviceStats is the set of statistics objects per devices
102	deviceStats     []*device.DeviceGroupStats
103	deviceStatsLock sync.RWMutex
104
105	// firstFingerprintCh is used to trigger that we have successfully
106	// fingerprinted once. It is used to gate launching the stats collection.
107	firstFingerprintCh chan struct{}
108	hasFingerprinted   bool
109}
110
111// newInstanceManager returns a new device instance manager. It is expected that
112// the context passed in the configuration is cancelled in order to shutdown
113// launched goroutines.
114func newInstanceManager(c *instanceManagerConfig) *instanceManager {
115
116	ctx, cancel := context.WithCancel(c.Ctx)
117	i := &instanceManager{
118		logger:             c.Logger.With("plugin", c.Id.Name),
119		ctx:                ctx,
120		cancel:             cancel,
121		loader:             c.Loader,
122		storeReattach:      c.StoreReattach,
123		pluginConfig:       c.PluginConfig,
124		id:                 c.Id,
125		fingerprintOutCh:   c.FingerprintOutCh,
126		statsInterval:      c.StatsInterval,
127		firstFingerprintCh: make(chan struct{}),
128	}
129
130	go i.run()
131	return i
132}
133
134// HasDevices returns if the instance is managing the passed devices
135func (i *instanceManager) HasDevices(d *structs.AllocatedDeviceResource) bool {
136	i.deviceLock.RLock()
137	defer i.deviceLock.RUnlock()
138
139OUTER:
140	for _, dev := range i.devices {
141		if dev.Name != d.Name || dev.Type != d.Type || dev.Vendor != d.Vendor {
142			continue
143		}
144
145		// Check that we have all the requested devices
146		ids := make(map[string]struct{}, len(dev.Devices))
147		for _, inst := range dev.Devices {
148			ids[inst.ID] = struct{}{}
149		}
150
151		for _, reqID := range d.DeviceIDs {
152			if _, ok := ids[reqID]; !ok {
153				continue OUTER
154			}
155		}
156
157		return true
158	}
159
160	return false
161}
162
163// AllStats returns all the device statistics returned by the device plugin.
164func (i *instanceManager) AllStats() []*device.DeviceGroupStats {
165	i.deviceStatsLock.RLock()
166	defer i.deviceStatsLock.RUnlock()
167	return i.deviceStats
168}
169
170// DeviceStats returns the device statistics for the request devices.
171func (i *instanceManager) DeviceStats(d *structs.AllocatedDeviceResource) *device.DeviceGroupStats {
172	i.deviceStatsLock.RLock()
173	defer i.deviceStatsLock.RUnlock()
174
175	// Find the device in question and then gather the instance statistics we
176	// are interested in
177	for _, group := range i.deviceStats {
178		if group.Vendor != d.Vendor || group.Type != d.Type || group.Name != d.Name {
179			continue
180		}
181
182		// We found the group we want so now grab the instance stats
183		out := &device.DeviceGroupStats{
184			Vendor:        d.Vendor,
185			Type:          d.Type,
186			Name:          d.Name,
187			InstanceStats: make(map[string]*device.DeviceStats, len(d.DeviceIDs)),
188		}
189
190		for _, id := range d.DeviceIDs {
191			out.InstanceStats[id] = group.InstanceStats[id]
192		}
193
194		return out
195	}
196
197	return nil
198}
199
200// Reserve reserves the given devices
201func (i *instanceManager) Reserve(d *structs.AllocatedDeviceResource) (*device.ContainerReservation, error) {
202	// Get a device plugin
203	devicePlugin, err := i.dispense()
204	if err != nil {
205		i.logger.Error("dispensing plugin failed", "error", err)
206		return nil, err
207	}
208
209	// Send the reserve request
210	return devicePlugin.Reserve(d.DeviceIDs)
211}
212
213// Devices returns the detected devices.
214func (i *instanceManager) Devices() []*device.DeviceGroup {
215	i.deviceLock.RLock()
216	defer i.deviceLock.RUnlock()
217	return i.devices
218}
219
220// WaitForFirstFingerprint waits until either the plugin fingerprints, the
221// passed context is done, or the plugin instance manager is shutdown.
222func (i *instanceManager) WaitForFirstFingerprint(ctx context.Context) {
223	select {
224	case <-i.ctx.Done():
225	case <-ctx.Done():
226	case <-i.firstFingerprintCh:
227	}
228}
229
230// run is a long lived goroutine that starts the fingerprinting and stats
231// collection goroutine and then shutsdown the plugin on exit.
232func (i *instanceManager) run() {
233	// Dispense once to ensure we are given a valid plugin
234	if _, err := i.dispense(); err != nil {
235		i.logger.Error("dispensing initial plugin failed", "error", err)
236		return
237	}
238
239	// Create a waitgroup to block on shutdown for all created goroutines to
240	// exit
241	var wg sync.WaitGroup
242
243	// Start the fingerprinter
244	wg.Add(1)
245	go func() {
246		i.fingerprint()
247		wg.Done()
248	}()
249
250	// Wait for a valid result before starting stats collection
251	select {
252	case <-i.ctx.Done():
253		goto DONE
254	case <-i.firstFingerprintCh:
255	}
256
257	// Start stats
258	wg.Add(1)
259	go func() {
260		i.collectStats()
261		wg.Done()
262	}()
263
264	// Do a final cleanup
265DONE:
266	wg.Wait()
267	i.cleanup()
268}
269
270// dispense is used to dispense a plugin.
271func (i *instanceManager) dispense() (plugin device.DevicePlugin, err error) {
272	i.pluginLock.Lock()
273	defer i.pluginLock.Unlock()
274
275	// See if we already have a running instance
276	if i.plugin != nil && !i.plugin.Exited() {
277		return i.device, nil
278	}
279
280	// Get an instance of the plugin
281	pluginInstance, err := i.loader.Dispense(i.id.Name, i.id.PluginType, i.pluginConfig, i.logger)
282	if err != nil {
283		// Retry as the error just indicates the singleton has exited
284		if err == singleton.SingletonPluginExited {
285			pluginInstance, err = i.loader.Dispense(i.id.Name, i.id.PluginType, i.pluginConfig, i.logger)
286		}
287
288		// If we still have an error there is a real problem
289		if err != nil {
290			return nil, fmt.Errorf("failed to start plugin: %v", err)
291		}
292	}
293
294	// Convert to a fingerprint plugin
295	device, ok := pluginInstance.Plugin().(device.DevicePlugin)
296	if !ok {
297		pluginInstance.Kill()
298		return nil, fmt.Errorf("plugin loaded does not implement the driver interface")
299	}
300
301	// Store the plugin and device
302	i.plugin = pluginInstance
303	i.device = device
304
305	// Store the reattach config
306	if c, ok := pluginInstance.ReattachConfig(); ok {
307		i.storeReattach(c)
308	}
309
310	return device, nil
311}
312
313// cleanup shutsdown the plugin
314func (i *instanceManager) cleanup() {
315	i.shutdownLock.Lock()
316	i.pluginLock.Lock()
317	defer i.pluginLock.Unlock()
318	defer i.shutdownLock.Unlock()
319
320	if i.plugin != nil && !i.plugin.Exited() {
321		i.plugin.Kill()
322		i.storeReattach(nil)
323	}
324}
325
326// fingerprint is a long lived routine used to fingerprint the device
327func (i *instanceManager) fingerprint() {
328START:
329	// Get a device plugin
330	devicePlugin, err := i.dispense()
331	if err != nil {
332		i.logger.Error("dispensing plugin failed", "error", err)
333		i.cancel()
334		return
335	}
336
337	// Start fingerprinting
338	fingerprintCh, err := devicePlugin.Fingerprint(i.ctx)
339	if err == device.ErrPluginDisabled {
340		i.logger.Info("fingerprinting failed: plugin is not enabled")
341		i.handleFingerprintError()
342		return
343	} else if err != nil {
344		i.logger.Error("fingerprinting failed", "error", err)
345		i.handleFingerprintError()
346		return
347	}
348
349	var fresp *device.FingerprintResponse
350	var ok bool
351	for {
352		select {
353		case <-i.ctx.Done():
354			return
355		case fresp, ok = <-fingerprintCh:
356		}
357
358		if !ok {
359			i.logger.Trace("exiting since fingerprinting gracefully shutdown")
360			i.handleFingerprintError()
361			return
362		}
363
364		// Guard against error by the plugin
365		if fresp == nil {
366			continue
367		}
368
369		// Handle any errors
370		if fresp.Error != nil {
371			if fresp.Error == bstructs.ErrPluginShutdown {
372				i.logger.Error("plugin exited unexpectedly")
373				goto START
374			}
375
376			i.logger.Error("fingerprinting returned an error", "error", fresp.Error)
377			i.handleFingerprintError()
378			return
379		}
380
381		if err := i.handleFingerprint(fresp); err != nil {
382			// Cancel the context so we cleanup all goroutines
383			i.logger.Error("returned devices failed fingerprinting", "error", err)
384			i.handleFingerprintError()
385		}
386	}
387}
388
389// handleFingerprintError exits the manager and shutsdown the plugin.
390func (i *instanceManager) handleFingerprintError() {
391	// Clear out the devices and trigger a node update
392	i.deviceLock.Lock()
393	defer i.deviceLock.Unlock()
394
395	// If we have fingerprinted before clear it out
396	if i.hasFingerprinted {
397		// Store the new devices
398		i.devices = nil
399
400		// Trigger that the we have new devices
401		select {
402		case i.fingerprintOutCh <- struct{}{}:
403		default:
404		}
405	}
406
407	// Cancel the context so we cleanup all goroutines
408	i.cancel()
409}
410
411// handleFingerprint stores the new devices and triggers the fingerprint output
412// channel. An error is returned if the passed devices don't pass validation.
413func (i *instanceManager) handleFingerprint(f *device.FingerprintResponse) error {
414	// Validate the received devices
415	var validationErr multierror.Error
416	for i, d := range f.Devices {
417		if err := d.Validate(); err != nil {
418			multierror.Append(&validationErr, multierror.Prefix(err, fmt.Sprintf("device group %d: ", i)))
419		}
420	}
421
422	if err := validationErr.ErrorOrNil(); err != nil {
423		return err
424	}
425
426	i.deviceLock.Lock()
427	defer i.deviceLock.Unlock()
428
429	// Store the new devices
430	i.devices = f.Devices
431
432	// Mark that we have received data
433	if !i.hasFingerprinted {
434		close(i.firstFingerprintCh)
435		i.hasFingerprinted = true
436	}
437
438	// Trigger that we have data to pull
439	select {
440	case i.fingerprintOutCh <- struct{}{}:
441	default:
442	}
443
444	return nil
445}
446
447// collectStats is a long lived goroutine for collecting device statistics. It
448// handles errors by backing off exponentially and retrying.
449func (i *instanceManager) collectStats() {
450	attempt := 0
451
452START:
453	// Get a device plugin
454	devicePlugin, err := i.dispense()
455	if err != nil {
456		i.logger.Error("dispensing plugin failed", "error", err)
457		i.cancel()
458		return
459	}
460
461	// Start stats collection
462	statsCh, err := devicePlugin.Stats(i.ctx, i.statsInterval)
463	if err != nil {
464		i.logger.Error("stats collection failed", "error", err)
465		return
466	}
467
468	var sresp *device.StatsResponse
469	var ok bool
470	for {
471		select {
472		case <-i.ctx.Done():
473			return
474		case sresp, ok = <-statsCh:
475		}
476
477		if !ok {
478			i.logger.Trace("exiting since stats gracefully shutdown")
479			return
480		}
481
482		// Guard against error by the plugin
483		if sresp == nil {
484			continue
485		}
486
487		// Handle any errors
488		if sresp.Error != nil {
489			if sresp.Error == bstructs.ErrPluginShutdown {
490				i.logger.Error("plugin exited unexpectedly")
491				goto START
492			}
493
494			// Retry with an exponential backoff
495			backoff := (1 << (2 * uint64(attempt))) * statsBackoffBaseline
496			if backoff > statsBackoffLimit {
497				backoff = statsBackoffLimit
498			}
499			attempt++
500
501			i.logger.Error("stats returned an error", "error", err, "retry", backoff)
502
503			select {
504			case <-i.ctx.Done():
505				return
506			case <-time.After(backoff):
507				goto START
508			}
509		}
510
511		// Reset the attempt since we got statistics
512		attempt = 0
513
514		// Store the new stats
515		if sresp.Groups != nil {
516			i.deviceStatsLock.Lock()
517			i.deviceStats = sresp.Groups
518			i.deviceStatsLock.Unlock()
519		}
520	}
521}
522