1package state
2
3import (
4	"sync"
5
6	hclog "github.com/hashicorp/go-hclog"
7	"github.com/hashicorp/nomad/client/allocrunner/taskrunner/state"
8	dmstate "github.com/hashicorp/nomad/client/devicemanager/state"
9	driverstate "github.com/hashicorp/nomad/client/pluginmanager/drivermanager/state"
10	"github.com/hashicorp/nomad/nomad/structs"
11)
12
13// MemDB implements a StateDB that stores data in memory and should only be
14// used for testing. All methods are safe for concurrent use.
15type MemDB struct {
16	// alloc_id -> value
17	allocs map[string]*structs.Allocation
18
19	// alloc_id -> value
20	deployStatus map[string]*structs.AllocDeploymentStatus
21
22	// alloc_id -> task_name -> value
23	localTaskState map[string]map[string]*state.LocalState
24	taskState      map[string]map[string]*structs.TaskState
25
26	// devicemanager -> plugin-state
27	devManagerPs *dmstate.PluginState
28
29	// drivermanager -> plugin-state
30	driverManagerPs *driverstate.PluginState
31
32	logger hclog.Logger
33
34	mu sync.RWMutex
35}
36
37func NewMemDB(logger hclog.Logger) *MemDB {
38	logger = logger.Named("memdb")
39	return &MemDB{
40		allocs:         make(map[string]*structs.Allocation),
41		deployStatus:   make(map[string]*structs.AllocDeploymentStatus),
42		localTaskState: make(map[string]map[string]*state.LocalState),
43		taskState:      make(map[string]map[string]*structs.TaskState),
44		logger:         logger,
45	}
46}
47
48func (m *MemDB) Name() string {
49	return "memdb"
50}
51
52func (m *MemDB) Upgrade() error {
53	return nil
54}
55
56func (m *MemDB) GetAllAllocations() ([]*structs.Allocation, map[string]error, error) {
57	m.mu.RLock()
58	defer m.mu.RUnlock()
59
60	allocs := make([]*structs.Allocation, 0, len(m.allocs))
61	for _, v := range m.allocs {
62		allocs = append(allocs, v)
63	}
64
65	return allocs, map[string]error{}, nil
66}
67
68func (m *MemDB) PutAllocation(alloc *structs.Allocation) error {
69	m.mu.Lock()
70	defer m.mu.Unlock()
71	m.allocs[alloc.ID] = alloc
72	return nil
73}
74
75func (m *MemDB) GetDeploymentStatus(allocID string) (*structs.AllocDeploymentStatus, error) {
76	m.mu.Lock()
77	defer m.mu.Unlock()
78	return m.deployStatus[allocID], nil
79}
80
81func (m *MemDB) PutDeploymentStatus(allocID string, ds *structs.AllocDeploymentStatus) error {
82	m.mu.Lock()
83	m.deployStatus[allocID] = ds
84	defer m.mu.Unlock()
85	return nil
86}
87
88func (m *MemDB) GetTaskRunnerState(allocID string, taskName string) (*state.LocalState, *structs.TaskState, error) {
89	m.mu.RLock()
90	defer m.mu.RUnlock()
91
92	var ls *state.LocalState
93	var ts *structs.TaskState
94
95	// Local Task State
96	allocLocalTS := m.localTaskState[allocID]
97	if len(allocLocalTS) != 0 {
98		ls = allocLocalTS[taskName]
99	}
100
101	// Task State
102	allocTS := m.taskState[allocID]
103	if len(allocTS) != 0 {
104		ts = allocTS[taskName]
105	}
106
107	return ls, ts, nil
108}
109
110func (m *MemDB) PutTaskRunnerLocalState(allocID string, taskName string, val *state.LocalState) error {
111	m.mu.Lock()
112	defer m.mu.Unlock()
113
114	if alts, ok := m.localTaskState[allocID]; ok {
115		alts[taskName] = val.Copy()
116		return nil
117	}
118
119	m.localTaskState[allocID] = map[string]*state.LocalState{
120		taskName: val.Copy(),
121	}
122
123	return nil
124}
125
126func (m *MemDB) PutTaskState(allocID string, taskName string, state *structs.TaskState) error {
127	m.mu.Lock()
128	defer m.mu.Unlock()
129
130	if ats, ok := m.taskState[allocID]; ok {
131		ats[taskName] = state.Copy()
132		return nil
133	}
134
135	m.taskState[allocID] = map[string]*structs.TaskState{
136		taskName: state.Copy(),
137	}
138
139	return nil
140}
141
142func (m *MemDB) DeleteTaskBucket(allocID, taskName string) error {
143	m.mu.Lock()
144	defer m.mu.Unlock()
145
146	if ats, ok := m.taskState[allocID]; ok {
147		delete(ats, taskName)
148	}
149
150	if alts, ok := m.localTaskState[allocID]; ok {
151		delete(alts, taskName)
152	}
153
154	return nil
155}
156
157func (m *MemDB) DeleteAllocationBucket(allocID string) error {
158	m.mu.Lock()
159	defer m.mu.Unlock()
160
161	delete(m.allocs, allocID)
162	delete(m.taskState, allocID)
163	delete(m.localTaskState, allocID)
164
165	return nil
166}
167
168func (m *MemDB) PutDevicePluginState(ps *dmstate.PluginState) error {
169	m.mu.Lock()
170	defer m.mu.Unlock()
171	m.devManagerPs = ps
172	return nil
173}
174
175// GetDevicePluginState stores the device manager's plugin state or returns an
176// error.
177func (m *MemDB) GetDevicePluginState() (*dmstate.PluginState, error) {
178	m.mu.Lock()
179	defer m.mu.Unlock()
180	return m.devManagerPs, nil
181}
182
183func (m *MemDB) GetDriverPluginState() (*driverstate.PluginState, error) {
184	m.mu.Lock()
185	defer m.mu.Unlock()
186	return m.driverManagerPs, nil
187}
188
189func (m *MemDB) PutDriverPluginState(ps *driverstate.PluginState) error {
190	m.mu.Lock()
191	defer m.mu.Unlock()
192	m.driverManagerPs = ps
193	return nil
194}
195
196func (m *MemDB) Close() error {
197	m.mu.Lock()
198	defer m.mu.Unlock()
199
200	// Set everything to nil to blow up on further use
201	m.allocs = nil
202	m.taskState = nil
203	m.localTaskState = nil
204
205	return nil
206}
207