1package state
2
3import (
4	"fmt"
5	"time"
6
7	"github.com/hashicorp/consul/agent/structs"
8	"github.com/hashicorp/consul/api"
9	"github.com/hashicorp/go-memdb"
10)
11
12// sessionsTableSchema returns a new table schema used for storing session
13// information.
14func sessionsTableSchema() *memdb.TableSchema {
15	return &memdb.TableSchema{
16		Name: "sessions",
17		Indexes: map[string]*memdb.IndexSchema{
18			"id": &memdb.IndexSchema{
19				Name:         "id",
20				AllowMissing: false,
21				Unique:       true,
22				Indexer: &memdb.UUIDFieldIndex{
23					Field: "ID",
24				},
25			},
26			"node": &memdb.IndexSchema{
27				Name:         "node",
28				AllowMissing: false,
29				Unique:       false,
30				Indexer: &memdb.StringFieldIndex{
31					Field:     "Node",
32					Lowercase: true,
33				},
34			},
35		},
36	}
37}
38
39// sessionChecksTableSchema returns a new table schema used for storing session
40// checks.
41func sessionChecksTableSchema() *memdb.TableSchema {
42	return &memdb.TableSchema{
43		Name: "session_checks",
44		Indexes: map[string]*memdb.IndexSchema{
45			"id": &memdb.IndexSchema{
46				Name:         "id",
47				AllowMissing: false,
48				Unique:       true,
49				Indexer: &memdb.CompoundIndex{
50					Indexes: []memdb.Indexer{
51						&memdb.StringFieldIndex{
52							Field:     "Node",
53							Lowercase: true,
54						},
55						&memdb.StringFieldIndex{
56							Field:     "CheckID",
57							Lowercase: true,
58						},
59						&memdb.UUIDFieldIndex{
60							Field: "Session",
61						},
62					},
63				},
64			},
65			"node_check": &memdb.IndexSchema{
66				Name:         "node_check",
67				AllowMissing: false,
68				Unique:       false,
69				Indexer: &memdb.CompoundIndex{
70					Indexes: []memdb.Indexer{
71						&memdb.StringFieldIndex{
72							Field:     "Node",
73							Lowercase: true,
74						},
75						&memdb.StringFieldIndex{
76							Field:     "CheckID",
77							Lowercase: true,
78						},
79					},
80				},
81			},
82			"session": &memdb.IndexSchema{
83				Name:         "session",
84				AllowMissing: false,
85				Unique:       false,
86				Indexer: &memdb.UUIDFieldIndex{
87					Field: "Session",
88				},
89			},
90		},
91	}
92}
93
94func init() {
95	registerSchema(sessionsTableSchema)
96	registerSchema(sessionChecksTableSchema)
97}
98
99// Sessions is used to pull the full list of sessions for use during snapshots.
100func (s *Snapshot) Sessions() (memdb.ResultIterator, error) {
101	iter, err := s.tx.Get("sessions", "id")
102	if err != nil {
103		return nil, err
104	}
105	return iter, nil
106}
107
108// Session is used when restoring from a snapshot. For general inserts, use
109// SessionCreate.
110func (s *Restore) Session(sess *structs.Session) error {
111	// Insert the session.
112	if err := s.tx.Insert("sessions", sess); err != nil {
113		return fmt.Errorf("failed inserting session: %s", err)
114	}
115
116	// Insert the check mappings.
117	for _, checkID := range sess.Checks {
118		mapping := &sessionCheck{
119			Node:    sess.Node,
120			CheckID: checkID,
121			Session: sess.ID,
122		}
123		if err := s.tx.Insert("session_checks", mapping); err != nil {
124			return fmt.Errorf("failed inserting session check mapping: %s", err)
125		}
126	}
127
128	// Update the index.
129	if err := indexUpdateMaxTxn(s.tx, sess.ModifyIndex, "sessions"); err != nil {
130		return fmt.Errorf("failed updating index: %s", err)
131	}
132
133	return nil
134}
135
136// SessionCreate is used to register a new session in the state store.
137func (s *Store) SessionCreate(idx uint64, sess *structs.Session) error {
138	tx := s.db.Txn(true)
139	defer tx.Abort()
140
141	// This code is technically able to (incorrectly) update an existing
142	// session but we never do that in practice. The upstream endpoint code
143	// always adds a unique ID when doing a create operation so we never hit
144	// an existing session again. It isn't worth the overhead to verify
145	// that here, but it's worth noting that we should never do this in the
146	// future.
147
148	// Call the session creation
149	if err := s.sessionCreateTxn(tx, idx, sess); err != nil {
150		return err
151	}
152
153	tx.Commit()
154	return nil
155}
156
157// sessionCreateTxn is the inner method used for creating session entries in
158// an open transaction. Any health checks registered with the session will be
159// checked for failing status. Returns any error encountered.
160func (s *Store) sessionCreateTxn(tx *memdb.Txn, idx uint64, sess *structs.Session) error {
161	// Check that we have a session ID
162	if sess.ID == "" {
163		return ErrMissingSessionID
164	}
165
166	// Verify the session behavior is valid
167	switch sess.Behavior {
168	case "":
169		// Release by default to preserve backwards compatibility
170		sess.Behavior = structs.SessionKeysRelease
171	case structs.SessionKeysRelease:
172	case structs.SessionKeysDelete:
173	default:
174		return fmt.Errorf("Invalid session behavior: %s", sess.Behavior)
175	}
176
177	// Assign the indexes. ModifyIndex likely will not be used but
178	// we set it here anyways for sanity.
179	sess.CreateIndex = idx
180	sess.ModifyIndex = idx
181
182	// Check that the node exists
183	node, err := tx.First("nodes", "id", sess.Node)
184	if err != nil {
185		return fmt.Errorf("failed node lookup: %s", err)
186	}
187	if node == nil {
188		return ErrMissingNode
189	}
190
191	// Go over the session checks and ensure they exist.
192	for _, checkID := range sess.Checks {
193		check, err := tx.First("checks", "id", sess.Node, string(checkID))
194		if err != nil {
195			return fmt.Errorf("failed check lookup: %s", err)
196		}
197		if check == nil {
198			return fmt.Errorf("Missing check '%s' registration", checkID)
199		}
200
201		// Check that the check is not in critical state
202		status := check.(*structs.HealthCheck).Status
203		if status == api.HealthCritical {
204			return fmt.Errorf("Check '%s' is in %s state", checkID, status)
205		}
206	}
207
208	// Insert the session
209	if err := tx.Insert("sessions", sess); err != nil {
210		return fmt.Errorf("failed inserting session: %s", err)
211	}
212
213	// Insert the check mappings
214	for _, checkID := range sess.Checks {
215		mapping := &sessionCheck{
216			Node:    sess.Node,
217			CheckID: checkID,
218			Session: sess.ID,
219		}
220		if err := tx.Insert("session_checks", mapping); err != nil {
221			return fmt.Errorf("failed inserting session check mapping: %s", err)
222		}
223	}
224
225	// Update the index
226	if err := tx.Insert("index", &IndexEntry{"sessions", idx}); err != nil {
227		return fmt.Errorf("failed updating index: %s", err)
228	}
229
230	return nil
231}
232
233// SessionGet is used to retrieve an active session from the state store.
234func (s *Store) SessionGet(ws memdb.WatchSet, sessionID string) (uint64, *structs.Session, error) {
235	tx := s.db.Txn(false)
236	defer tx.Abort()
237
238	// Get the table index.
239	idx := maxIndexTxn(tx, "sessions")
240
241	// Look up the session by its ID
242	watchCh, session, err := tx.FirstWatch("sessions", "id", sessionID)
243	if err != nil {
244		return 0, nil, fmt.Errorf("failed session lookup: %s", err)
245	}
246	ws.Add(watchCh)
247	if session != nil {
248		return idx, session.(*structs.Session), nil
249	}
250	return idx, nil, nil
251}
252
253// SessionList returns a slice containing all of the active sessions.
254func (s *Store) SessionList(ws memdb.WatchSet) (uint64, structs.Sessions, error) {
255	tx := s.db.Txn(false)
256	defer tx.Abort()
257
258	// Get the table index.
259	idx := maxIndexTxn(tx, "sessions")
260
261	// Query all of the active sessions.
262	sessions, err := tx.Get("sessions", "id")
263	if err != nil {
264		return 0, nil, fmt.Errorf("failed session lookup: %s", err)
265	}
266	ws.Add(sessions.WatchCh())
267
268	// Go over the sessions and create a slice of them.
269	var result structs.Sessions
270	for session := sessions.Next(); session != nil; session = sessions.Next() {
271		result = append(result, session.(*structs.Session))
272	}
273	return idx, result, nil
274}
275
276// NodeSessions returns a set of active sessions associated
277// with the given node ID. The returned index is the highest
278// index seen from the result set.
279func (s *Store) NodeSessions(ws memdb.WatchSet, nodeID string) (uint64, structs.Sessions, error) {
280	tx := s.db.Txn(false)
281	defer tx.Abort()
282
283	// Get the table index.
284	idx := maxIndexTxn(tx, "sessions")
285
286	// Get all of the sessions which belong to the node
287	sessions, err := tx.Get("sessions", "node", nodeID)
288	if err != nil {
289		return 0, nil, fmt.Errorf("failed session lookup: %s", err)
290	}
291	ws.Add(sessions.WatchCh())
292
293	// Go over all of the sessions and return them as a slice
294	var result structs.Sessions
295	for session := sessions.Next(); session != nil; session = sessions.Next() {
296		result = append(result, session.(*structs.Session))
297	}
298	return idx, result, nil
299}
300
301// SessionDestroy is used to remove an active session. This will
302// implicitly invalidate the session and invoke the specified
303// session destroy behavior.
304func (s *Store) SessionDestroy(idx uint64, sessionID string) error {
305	tx := s.db.Txn(true)
306	defer tx.Abort()
307
308	// Call the session deletion.
309	if err := s.deleteSessionTxn(tx, idx, sessionID); err != nil {
310		return err
311	}
312
313	tx.Commit()
314	return nil
315}
316
317// deleteSessionTxn is the inner method, which is used to do the actual
318// session deletion and handle session invalidation, etc.
319func (s *Store) deleteSessionTxn(tx *memdb.Txn, idx uint64, sessionID string) error {
320	// Look up the session.
321	sess, err := tx.First("sessions", "id", sessionID)
322	if err != nil {
323		return fmt.Errorf("failed session lookup: %s", err)
324	}
325	if sess == nil {
326		return nil
327	}
328
329	// Delete the session and write the new index.
330	if err := tx.Delete("sessions", sess); err != nil {
331		return fmt.Errorf("failed deleting session: %s", err)
332	}
333	if err := tx.Insert("index", &IndexEntry{"sessions", idx}); err != nil {
334		return fmt.Errorf("failed updating index: %s", err)
335	}
336
337	// Enforce the max lock delay.
338	session := sess.(*structs.Session)
339	delay := session.LockDelay
340	if delay > structs.MaxLockDelay {
341		delay = structs.MaxLockDelay
342	}
343
344	// Snag the current now time so that all the expirations get calculated
345	// the same way.
346	now := time.Now()
347
348	// Get an iterator over all of the keys with the given session.
349	entries, err := tx.Get("kvs", "session", sessionID)
350	if err != nil {
351		return fmt.Errorf("failed kvs lookup: %s", err)
352	}
353	var kvs []interface{}
354	for entry := entries.Next(); entry != nil; entry = entries.Next() {
355		kvs = append(kvs, entry)
356	}
357
358	// Invalidate any held locks.
359	switch session.Behavior {
360	case structs.SessionKeysRelease:
361		for _, obj := range kvs {
362			// Note that we clone here since we are modifying the
363			// returned object and want to make sure our set op
364			// respects the transaction we are in.
365			e := obj.(*structs.DirEntry).Clone()
366			e.Session = ""
367			if err := s.kvsSetTxn(tx, idx, e, true); err != nil {
368				return fmt.Errorf("failed kvs update: %s", err)
369			}
370
371			// Apply the lock delay if present.
372			if delay > 0 {
373				s.lockDelay.SetExpiration(e.Key, now, delay)
374			}
375		}
376	case structs.SessionKeysDelete:
377		for _, obj := range kvs {
378			e := obj.(*structs.DirEntry)
379			if err := s.kvsDeleteTxn(tx, idx, e.Key); err != nil {
380				return fmt.Errorf("failed kvs delete: %s", err)
381			}
382
383			// Apply the lock delay if present.
384			if delay > 0 {
385				s.lockDelay.SetExpiration(e.Key, now, delay)
386			}
387		}
388	default:
389		return fmt.Errorf("unknown session behavior %#v", session.Behavior)
390	}
391
392	// Delete any check mappings.
393	mappings, err := tx.Get("session_checks", "session", sessionID)
394	if err != nil {
395		return fmt.Errorf("failed session checks lookup: %s", err)
396	}
397	{
398		var objs []interface{}
399		for mapping := mappings.Next(); mapping != nil; mapping = mappings.Next() {
400			objs = append(objs, mapping)
401		}
402
403		// Do the delete in a separate loop so we don't trash the iterator.
404		for _, obj := range objs {
405			if err := tx.Delete("session_checks", obj); err != nil {
406				return fmt.Errorf("failed deleting session check: %s", err)
407			}
408		}
409	}
410
411	// Delete any prepared queries.
412	queries, err := tx.Get("prepared-queries", "session", sessionID)
413	if err != nil {
414		return fmt.Errorf("failed prepared query lookup: %s", err)
415	}
416	{
417		var ids []string
418		for wrapped := queries.Next(); wrapped != nil; wrapped = queries.Next() {
419			ids = append(ids, toPreparedQuery(wrapped).ID)
420		}
421
422		// Do the delete in a separate loop so we don't trash the iterator.
423		for _, id := range ids {
424			if err := s.preparedQueryDeleteTxn(tx, idx, id); err != nil {
425				return fmt.Errorf("failed prepared query delete: %s", err)
426			}
427		}
428	}
429
430	return nil
431}
432