1package vault
2
3import (
4	"context"
5	"errors"
6	"fmt"
7	"strings"
8	"sync"
9	"time"
10
11	metrics "github.com/armon/go-metrics"
12	"github.com/golang/protobuf/ptypes"
13	"github.com/hashicorp/errwrap"
14	memdb "github.com/hashicorp/go-memdb"
15	uuid "github.com/hashicorp/go-uuid"
16	"github.com/hashicorp/vault/helper/identity"
17	"github.com/hashicorp/vault/helper/identity/mfa"
18	"github.com/hashicorp/vault/helper/namespace"
19	"github.com/hashicorp/vault/helper/storagepacker"
20	"github.com/hashicorp/vault/sdk/helper/consts"
21	"github.com/hashicorp/vault/sdk/helper/strutil"
22	"github.com/hashicorp/vault/sdk/logical"
23)
24
// errDuplicateIdentityName is returned while loading or upserting identity
// artifacts when two groups, two entities, or two aliases on one entity
// collide on name while lower-cased (case-insensitive) names are in effect.
// loadIdentityStoreArtifacts treats it as the signal to retry the load in
// case-sensitive mode.
var errDuplicateIdentityName = errors.New("duplicate identity name")

// SetLoadCaseSensitiveIdentityStore controls whether loadIdentityStoreArtifacts
// skips the lower-cased load attempt and goes straight to loading identity
// artifacts with case-sensitive names.
func (c *Core) SetLoadCaseSensitiveIdentityStore(caseSensitive bool) {
	c.loadCaseSensitiveIdentityStore = caseSensitive
}
30
31func (c *Core) loadIdentityStoreArtifacts(ctx context.Context) error {
32	if c.identityStore == nil {
33		c.logger.Warn("identity store is not setup, skipping loading")
34		return nil
35	}
36
37	loadFunc := func(context.Context) error {
38		err := c.identityStore.loadEntities(ctx)
39		if err != nil {
40			return err
41		}
42		return c.identityStore.loadGroups(ctx)
43	}
44
45	if !c.loadCaseSensitiveIdentityStore {
46		// Load everything when memdb is set to operate on lower cased names
47		err := loadFunc(ctx)
48		switch {
49		case err == nil:
50			// If it succeeds, all is well
51			return nil
52		case err != nil && !errwrap.Contains(err, errDuplicateIdentityName.Error()):
53			return err
54		}
55	}
56
57	c.identityStore.logger.Warn("enabling case sensitive identity names")
58
59	// Set identity store to operate on case sensitive identity names
60	c.identityStore.disableLowerCasedNames = true
61
62	// Swap the memdb instance by the one which operates on case sensitive
63	// names, hence obviating the need to unload anything that's already
64	// loaded.
65	if err := c.identityStore.resetDB(ctx); err != nil {
66		return err
67	}
68
69	// Attempt to load identity artifacts once more after memdb is reset to
70	// accept case sensitive names
71	return loadFunc(ctx)
72}
73
74func (i *IdentityStore) sanitizeName(name string) string {
75	if i.disableLowerCasedNames {
76		return name
77	}
78	return strings.ToLower(name)
79}
80
81func (i *IdentityStore) loadGroups(ctx context.Context) error {
82	i.logger.Debug("identity loading groups")
83	existing, err := i.groupPacker.View().List(ctx, groupBucketsPrefix)
84	if err != nil {
85		return fmt.Errorf("failed to scan for groups: %w", err)
86	}
87	i.logger.Debug("groups collected", "num_existing", len(existing))
88
89	for _, key := range existing {
90		bucket, err := i.groupPacker.GetBucket(ctx, groupBucketsPrefix+key)
91		if err != nil {
92			return err
93		}
94
95		if bucket == nil {
96			continue
97		}
98
99		for _, item := range bucket.Items {
100			group, err := i.parseGroupFromBucketItem(item)
101			if err != nil {
102				return err
103			}
104			if group == nil {
105				continue
106			}
107
108			ns, err := NamespaceByID(ctx, group.NamespaceID, i.core)
109			if err != nil {
110				return err
111			}
112			if ns == nil {
113				// Remove dangling groups
114				if !(i.core.ReplicationState().HasState(consts.ReplicationPerformanceSecondary) || i.core.perfStandby) {
115					// Group's namespace doesn't exist anymore but the group
116					// from the namespace still exists.
117					i.logger.Warn("deleting group and its any existing aliases", "name", group.Name, "namespace_id", group.NamespaceID)
118					err = i.groupPacker.DeleteItem(ctx, group.ID)
119					if err != nil {
120						return err
121					}
122				}
123				continue
124			}
125			nsCtx := namespace.ContextWithNamespace(ctx, ns)
126
127			// Ensure that there are no groups with duplicate names
128			groupByName, err := i.MemDBGroupByName(nsCtx, group.Name, false)
129			if err != nil {
130				return err
131			}
132			if groupByName != nil {
133				i.logger.Warn(errDuplicateIdentityName.Error(), "group_name", group.Name, "conflicting_group_name", groupByName.Name, "action", "merge the contents of duplicated groups into one and delete the other")
134				if !i.disableLowerCasedNames {
135					return errDuplicateIdentityName
136				}
137			}
138
139			if i.logger.IsDebug() {
140				i.logger.Debug("loading group", "name", group.Name, "id", group.ID)
141			}
142
143			txn := i.db.Txn(true)
144
145			// Before pull#5786, entity memberships in groups were not getting
146			// updated when respective entities were deleted. This is here to
147			// check that the entity IDs in the group are indeed valid, and if
148			// not remove them.
149			persist := false
150			for _, memberEntityID := range group.MemberEntityIDs {
151				entity, err := i.MemDBEntityByID(memberEntityID, false)
152				if err != nil {
153					txn.Abort()
154					return err
155				}
156				if entity == nil {
157					persist = true
158					group.MemberEntityIDs = strutil.StrListDelete(group.MemberEntityIDs, memberEntityID)
159				}
160			}
161
162			err = i.UpsertGroupInTxn(ctx, txn, group, persist)
163			if err != nil {
164				txn.Abort()
165				return fmt.Errorf("failed to update group in memdb: %w", err)
166			}
167
168			txn.Commit()
169		}
170	}
171
172	if i.logger.IsInfo() {
173		i.logger.Info("groups restored")
174	}
175
176	return nil
177}
178
179func (i *IdentityStore) loadEntities(ctx context.Context) error {
180	// Accumulate existing entities
181	i.logger.Debug("loading entities")
182	existing, err := i.entityPacker.View().List(ctx, storagepacker.StoragePackerBucketsPrefix)
183	if err != nil {
184		return fmt.Errorf("failed to scan for entities: %w", err)
185	}
186	i.logger.Debug("entities collected", "num_existing", len(existing))
187
188	// Make the channels used for the worker pool
189	broker := make(chan string)
190	quit := make(chan bool)
191
192	// Buffer these channels to prevent deadlocks
193	errs := make(chan error, len(existing))
194	result := make(chan *storagepacker.Bucket, len(existing))
195
196	// Use a wait group
197	wg := &sync.WaitGroup{}
198
199	// Create 64 workers to distribute work to
200	for j := 0; j < consts.ExpirationRestoreWorkerCount; j++ {
201		wg.Add(1)
202		go func() {
203			defer wg.Done()
204
205			for {
206				select {
207				case key, ok := <-broker:
208					// broker has been closed, we are done
209					if !ok {
210						return
211					}
212
213					bucket, err := i.entityPacker.GetBucket(ctx, storagepacker.StoragePackerBucketsPrefix+key)
214					if err != nil {
215						errs <- err
216						continue
217					}
218
219					// Write results out to the result channel
220					result <- bucket
221
222				// quit early
223				case <-quit:
224					return
225				}
226			}
227		}()
228	}
229
230	// Distribute the collected keys to the workers in a go routine
231	wg.Add(1)
232	go func() {
233		defer wg.Done()
234		for j, key := range existing {
235			if j%500 == 0 {
236				i.logger.Debug("entities loading", "progress", j)
237			}
238
239			select {
240			case <-quit:
241				return
242
243			default:
244				broker <- key
245			}
246		}
247
248		// Close the broker, causing worker routines to exit
249		close(broker)
250	}()
251
252	// Restore each key by pulling from the result chan
253	for j := 0; j < len(existing); j++ {
254		select {
255		case err := <-errs:
256			// Close all go routines
257			close(quit)
258
259			return err
260
261		case bucket := <-result:
262			// If there is no entry, nothing to restore
263			if bucket == nil {
264				continue
265			}
266
267			for _, item := range bucket.Items {
268				entity, err := i.parseEntityFromBucketItem(ctx, item)
269				if err != nil {
270					return err
271				}
272				if entity == nil {
273					continue
274				}
275
276				ns, err := NamespaceByID(ctx, entity.NamespaceID, i.core)
277				if err != nil {
278					return err
279				}
280				if ns == nil {
281					// Remove dangling entities
282					if !(i.core.ReplicationState().HasState(consts.ReplicationPerformanceSecondary) || i.core.perfStandby) {
283						// Entity's namespace doesn't exist anymore but the
284						// entity from the namespace still exists.
285						i.logger.Warn("deleting entity and its any existing aliases", "name", entity.Name, "namespace_id", entity.NamespaceID)
286						err = i.entityPacker.DeleteItem(ctx, entity.ID)
287						if err != nil {
288							return err
289						}
290					}
291					continue
292				}
293				nsCtx := namespace.ContextWithNamespace(ctx, ns)
294
295				// Ensure that there are no entities with duplicate names
296				entityByName, err := i.MemDBEntityByName(nsCtx, entity.Name, false)
297				if err != nil {
298					return nil
299				}
300				if entityByName != nil {
301					i.logger.Warn(errDuplicateIdentityName.Error(), "entity_name", entity.Name, "conflicting_entity_name", entityByName.Name, "action", "merge the duplicate entities into one")
302					if !i.disableLowerCasedNames {
303						return errDuplicateIdentityName
304					}
305				}
306
307				// Only update MemDB and don't hit the storage again
308				err = i.upsertEntity(nsCtx, entity, nil, false)
309				if err != nil {
310					return fmt.Errorf("failed to update entity in MemDB: %w", err)
311				}
312			}
313		}
314	}
315
316	// Let all go routines finish
317	wg.Wait()
318
319	if i.logger.IsInfo() {
320		i.logger.Info("entities restored")
321	}
322
323	return nil
324}
325
// upsertEntityInTxn either creates or updates an existing entity, together with
// its aliases, using the supplied MemDB write transaction. The caller owns txn
// and is responsible for committing or aborting it.
//
// If 'persist' is true, the entity is also written to storage through the
// entity storage packer, and so is previousEntity when it is given. If
// 'persist' is false, only MemDB is updated.
//
// When an alias is transferred from one entity to another, both the source and
// destination entities should get updated. In that case callers should send in
// both entity and previousEntity, and previousEntity must be in the same
// namespace as entity.
//
// If one of entity's aliases is found attached to a different entity, that
// entity is merged into this one via mergeEntity. This function then returns
// right after the merge, without processing the remaining aliases or the
// upserts below.
//
// Returns errDuplicateIdentityName if entity carries two aliases with the same
// sanitized name and mount accessor while lower-cased names are in effect.
func (i *IdentityStore) upsertEntityInTxn(ctx context.Context, txn *memdb.Txn, entity *identity.Entity, previousEntity *identity.Entity, persist bool) error {
	defer metrics.MeasureSince([]string{"identity", "upsert_entity_txn"}, time.Now())
	var err error

	if txn == nil {
		return errors.New("txn is nil")
	}

	if entity == nil {
		return errors.New("entity is nil")
	}

	if entity.NamespaceID == "" {
		entity.NamespaceID = namespace.RootNamespaceID
	}

	if previousEntity != nil && previousEntity.NamespaceID != entity.NamespaceID {
		return errors.New("entity and previous entity are not in the same namespace")
	}

	// aliasFactors records, for each alias already processed in this call, its
	// sanitized name concatenated with its mount accessor. It is used to detect
	// duplicate aliases within this single entity.
	// NOTE(review): the key is a plain concatenation with no separator, so two
	// different (name, accessor) pairs could in principle yield the same
	// string. Confirm that mount accessor formats rule this out.
	aliasFactors := make([]string, len(entity.Aliases))

	for index, alias := range entity.Aliases {
		// Verify that alias is not associated to a different one already
		// NOTE(review): MemDBAliasByFactors opens its own read transaction, so
		// this lookup does not see uncommitted changes already made through txn.
		aliasByFactors, err := i.MemDBAliasByFactors(alias.MountAccessor, alias.Name, false, false)
		if err != nil {
			return err
		}

		if alias.NamespaceID == "" {
			alias.NamespaceID = namespace.RootNamespaceID
		}

		switch {
		case aliasByFactors == nil:
			// Not found, no merging needed, just check namespace
			if alias.NamespaceID != entity.NamespaceID {
				return errors.New("alias and entity are not in the same namespace")
			}

		case aliasByFactors.CanonicalID == entity.ID:
			// Lookup found the same entity, so it's already attached to the
			// right place
			if aliasByFactors.NamespaceID != entity.NamespaceID {
				return errors.New("alias from factors and entity are not in the same namespace")
			}

		case previousEntity != nil && aliasByFactors.CanonicalID == previousEntity.ID:
			// previousEntity isn't upserted yet so may still contain the old
			// alias reference in memdb if it was just changed; validate
			// whether or not it's _actually_ still tied to the entity
			var found bool
			for _, prevEntAlias := range previousEntity.Aliases {
				if prevEntAlias.ID == alias.ID {
					found = true
					break
				}
			}
			// If we didn't find the alias still tied to previousEntity, we
			// shouldn't use the merging logic and should bail
			// (This break exits only the switch, not the for loop. The alias
			// then goes through the duplicate check and upsert below.)
			if !found {
				break
			}

			// Otherwise it's still tied to previousEntity and fall through
			// into merging. We don't need a namespace check here as existing
			// checks when creating the aliases should ensure that all line up.
			fallthrough

		default:
			// The alias currently belongs to another entity, so that entity is
			// merged into this one. See mergeEntity for the meaning of the
			// positional boolean arguments.
			i.logger.Warn("alias is already tied to a different entity; these entities are being merged", "alias_id", alias.ID, "other_entity_id", aliasByFactors.CanonicalID, "entity_aliases", entity.Aliases, "alias_by_factors", aliasByFactors)

			respErr, intErr := i.mergeEntity(ctx, txn, entity, []string{aliasByFactors.CanonicalID}, true, false, true, persist)
			switch {
			case respErr != nil:
				return respErr
			case intErr != nil:
				return intErr
			}

			// The entity and aliases will be loaded into memdb and persisted
			// as a result of the merge so we are done here
			// NOTE(review): on this path previousEntity (when it is not the
			// merged entity) is not upserted by this function. Confirm that
			// callers or mergeEntity cover it.
			return nil
		}

		// Reject a second alias on this entity with the same sanitized name
		// and mount accessor. In case-sensitive mode it is only logged, and the
		// alias is still upserted below.
		if strutil.StrListContains(aliasFactors, i.sanitizeName(alias.Name)+alias.MountAccessor) {
			i.logger.Warn(errDuplicateIdentityName.Error(), "alias_name", alias.Name, "mount_accessor", alias.MountAccessor, "entity_name", entity.Name, "action", "delete one of the duplicate aliases")
			if !i.disableLowerCasedNames {
				return errDuplicateIdentityName
			}
		}

		// Insert or update alias in MemDB using the transaction created above
		err = i.MemDBUpsertAliasInTxn(txn, alias, false)
		if err != nil {
			return err
		}

		aliasFactors[index] = i.sanitizeName(alias.Name) + alias.MountAccessor
	}

	// If previous entity is set, update it in MemDB and persist it
	// (the storage write only happens when persist is true).
	if previousEntity != nil {
		err = i.MemDBUpsertEntityInTxn(txn, previousEntity)
		if err != nil {
			return err
		}

		if persist {
			// Persist the previous entity object
			marshaledPreviousEntity, err := ptypes.MarshalAny(previousEntity)
			if err != nil {
				return err
			}
			err = i.entityPacker.PutItem(ctx, &storagepacker.Item{
				ID:      previousEntity.ID,
				Message: marshaledPreviousEntity,
			})
			if err != nil {
				return err
			}
		}
	}

	// Insert or update entity in MemDB using the transaction created above
	err = i.MemDBUpsertEntityInTxn(txn, entity)
	if err != nil {
		return err
	}

	if persist {
		entityAsAny, err := ptypes.MarshalAny(entity)
		if err != nil {
			return err
		}
		item := &storagepacker.Item{
			ID:      entity.ID,
			Message: entityAsAny,
		}

		// Persist the entity object
		err = i.entityPacker.PutItem(ctx, item)
		if err != nil {
			return err
		}
	}

	return nil
}
481
482// upsertEntity either creates or updates an existing entity. The operations
483// will be updated in both MemDB and storage. If 'persist' is set to false,
484// then storage will not be updated. When an alias is transferred from one
485// entity to another, both the source and destination entities should get
486// updated, in which case, callers should send in both entity and
487// previousEntity.
488func (i *IdentityStore) upsertEntity(ctx context.Context, entity *identity.Entity, previousEntity *identity.Entity, persist bool) error {
489	defer metrics.MeasureSince([]string{"identity", "upsert_entity"}, time.Now())
490
491	// Create a MemDB transaction to update both alias and entity
492	txn := i.db.Txn(true)
493	defer txn.Abort()
494
495	err := i.upsertEntityInTxn(ctx, txn, entity, previousEntity, persist)
496	if err != nil {
497		return err
498	}
499
500	txn.Commit()
501
502	return nil
503}
504
505func (i *IdentityStore) MemDBUpsertAliasInTxn(txn *memdb.Txn, alias *identity.Alias, groupAlias bool) error {
506	if txn == nil {
507		return fmt.Errorf("nil txn")
508	}
509
510	if alias == nil {
511		return fmt.Errorf("alias is nil")
512	}
513
514	if alias.NamespaceID == "" {
515		alias.NamespaceID = namespace.RootNamespaceID
516	}
517
518	tableName := entityAliasesTable
519	if groupAlias {
520		tableName = groupAliasesTable
521	}
522
523	aliasRaw, err := txn.First(tableName, "id", alias.ID)
524	if err != nil {
525		return fmt.Errorf("failed to lookup alias from memdb using alias ID: %w", err)
526	}
527
528	if aliasRaw != nil {
529		err = txn.Delete(tableName, aliasRaw)
530		if err != nil {
531			return fmt.Errorf("failed to delete alias from memdb: %w", err)
532		}
533	}
534
535	if err := txn.Insert(tableName, alias); err != nil {
536		return fmt.Errorf("failed to update alias into memdb: %w", err)
537	}
538
539	return nil
540}
541
542func (i *IdentityStore) MemDBAliasByIDInTxn(txn *memdb.Txn, aliasID string, clone bool, groupAlias bool) (*identity.Alias, error) {
543	if aliasID == "" {
544		return nil, fmt.Errorf("missing alias ID")
545	}
546
547	if txn == nil {
548		return nil, fmt.Errorf("txn is nil")
549	}
550
551	tableName := entityAliasesTable
552	if groupAlias {
553		tableName = groupAliasesTable
554	}
555
556	aliasRaw, err := txn.First(tableName, "id", aliasID)
557	if err != nil {
558		return nil, fmt.Errorf("failed to fetch alias from memdb using alias ID: %w", err)
559	}
560
561	if aliasRaw == nil {
562		return nil, nil
563	}
564
565	alias, ok := aliasRaw.(*identity.Alias)
566	if !ok {
567		return nil, fmt.Errorf("failed to declare the type of fetched alias")
568	}
569
570	if clone {
571		return alias.Clone()
572	}
573
574	return alias, nil
575}
576
577func (i *IdentityStore) MemDBAliasByID(aliasID string, clone bool, groupAlias bool) (*identity.Alias, error) {
578	if aliasID == "" {
579		return nil, fmt.Errorf("missing alias ID")
580	}
581
582	txn := i.db.Txn(false)
583
584	return i.MemDBAliasByIDInTxn(txn, aliasID, clone, groupAlias)
585}
586
587func (i *IdentityStore) MemDBAliasByFactors(mountAccessor, aliasName string, clone bool, groupAlias bool) (*identity.Alias, error) {
588	if aliasName == "" {
589		return nil, fmt.Errorf("missing alias name")
590	}
591
592	if mountAccessor == "" {
593		return nil, fmt.Errorf("missing mount accessor")
594	}
595
596	txn := i.db.Txn(false)
597
598	return i.MemDBAliasByFactorsInTxn(txn, mountAccessor, aliasName, clone, groupAlias)
599}
600
601func (i *IdentityStore) MemDBAliasByFactorsInTxn(txn *memdb.Txn, mountAccessor, aliasName string, clone bool, groupAlias bool) (*identity.Alias, error) {
602	if txn == nil {
603		return nil, fmt.Errorf("nil txn")
604	}
605
606	if aliasName == "" {
607		return nil, fmt.Errorf("missing alias name")
608	}
609
610	if mountAccessor == "" {
611		return nil, fmt.Errorf("missing mount accessor")
612	}
613
614	tableName := entityAliasesTable
615	if groupAlias {
616		tableName = groupAliasesTable
617	}
618
619	aliasRaw, err := txn.First(tableName, "factors", mountAccessor, aliasName)
620	if err != nil {
621		return nil, fmt.Errorf("failed to fetch alias from memdb using factors: %w", err)
622	}
623
624	if aliasRaw == nil {
625		return nil, nil
626	}
627
628	alias, ok := aliasRaw.(*identity.Alias)
629	if !ok {
630		return nil, fmt.Errorf("failed to declare the type of fetched alias")
631	}
632
633	if clone {
634		return alias.Clone()
635	}
636
637	return alias, nil
638}
639
640func (i *IdentityStore) MemDBDeleteAliasByIDInTxn(txn *memdb.Txn, aliasID string, groupAlias bool) error {
641	if aliasID == "" {
642		return nil
643	}
644
645	if txn == nil {
646		return fmt.Errorf("txn is nil")
647	}
648
649	alias, err := i.MemDBAliasByIDInTxn(txn, aliasID, false, groupAlias)
650	if err != nil {
651		return err
652	}
653
654	if alias == nil {
655		return nil
656	}
657
658	tableName := entityAliasesTable
659	if groupAlias {
660		tableName = groupAliasesTable
661	}
662
663	err = txn.Delete(tableName, alias)
664	if err != nil {
665		return fmt.Errorf("failed to delete alias from memdb: %w", err)
666	}
667
668	return nil
669}
670
671func (i *IdentityStore) MemDBAliases(ws memdb.WatchSet, groupAlias bool) (memdb.ResultIterator, error) {
672	txn := i.db.Txn(false)
673
674	tableName := entityAliasesTable
675	if groupAlias {
676		tableName = groupAliasesTable
677	}
678
679	iter, err := txn.Get(tableName, "id")
680	if err != nil {
681		return nil, err
682	}
683
684	ws.Add(iter.WatchCh())
685
686	return iter, nil
687}
688
689func (i *IdentityStore) MemDBUpsertEntityInTxn(txn *memdb.Txn, entity *identity.Entity) error {
690	if txn == nil {
691		return fmt.Errorf("nil txn")
692	}
693
694	if entity == nil {
695		return fmt.Errorf("entity is nil")
696	}
697
698	if entity.NamespaceID == "" {
699		entity.NamespaceID = namespace.RootNamespaceID
700	}
701
702	entityRaw, err := txn.First(entitiesTable, "id", entity.ID)
703	if err != nil {
704		return fmt.Errorf("failed to lookup entity from memdb using entity id: %w", err)
705	}
706
707	if entityRaw != nil {
708		err = txn.Delete(entitiesTable, entityRaw)
709		if err != nil {
710			return fmt.Errorf("failed to delete entity from memdb: %w", err)
711		}
712	}
713
714	if err := txn.Insert(entitiesTable, entity); err != nil {
715		return fmt.Errorf("failed to update entity into memdb: %w", err)
716	}
717
718	return nil
719}
720
721func (i *IdentityStore) MemDBEntityByIDInTxn(txn *memdb.Txn, entityID string, clone bool) (*identity.Entity, error) {
722	if entityID == "" {
723		return nil, fmt.Errorf("missing entity id")
724	}
725
726	if txn == nil {
727		return nil, fmt.Errorf("txn is nil")
728	}
729
730	entityRaw, err := txn.First(entitiesTable, "id", entityID)
731	if err != nil {
732		return nil, fmt.Errorf("failed to fetch entity from memdb using entity id: %w", err)
733	}
734
735	if entityRaw == nil {
736		return nil, nil
737	}
738
739	entity, ok := entityRaw.(*identity.Entity)
740	if !ok {
741		return nil, fmt.Errorf("failed to declare the type of fetched entity")
742	}
743
744	if clone {
745		return entity.Clone()
746	}
747
748	return entity, nil
749}
750
751func (i *IdentityStore) MemDBEntityByID(entityID string, clone bool) (*identity.Entity, error) {
752	if entityID == "" {
753		return nil, fmt.Errorf("missing entity id")
754	}
755
756	txn := i.db.Txn(false)
757
758	return i.MemDBEntityByIDInTxn(txn, entityID, clone)
759}
760
761func (i *IdentityStore) MemDBEntityByName(ctx context.Context, entityName string, clone bool) (*identity.Entity, error) {
762	if entityName == "" {
763		return nil, fmt.Errorf("missing entity name")
764	}
765
766	txn := i.db.Txn(false)
767
768	return i.MemDBEntityByNameInTxn(ctx, txn, entityName, clone)
769}
770
771func (i *IdentityStore) MemDBEntityByNameInTxn(ctx context.Context, txn *memdb.Txn, entityName string, clone bool) (*identity.Entity, error) {
772	if entityName == "" {
773		return nil, fmt.Errorf("missing entity name")
774	}
775
776	ns, err := namespace.FromContext(ctx)
777	if err != nil {
778		return nil, err
779	}
780
781	entityRaw, err := txn.First(entitiesTable, "name", ns.ID, entityName)
782	if err != nil {
783		return nil, fmt.Errorf("failed to fetch entity from memdb using entity name: %w", err)
784	}
785
786	if entityRaw == nil {
787		return nil, nil
788	}
789
790	entity, ok := entityRaw.(*identity.Entity)
791	if !ok {
792		return nil, fmt.Errorf("failed to declare the type of fetched entity")
793	}
794
795	if clone {
796		return entity.Clone()
797	}
798
799	return entity, nil
800}
801
802func (i *IdentityStore) MemDBEntitiesByBucketKeyInTxn(txn *memdb.Txn, bucketKey string) ([]*identity.Entity, error) {
803	if txn == nil {
804		return nil, fmt.Errorf("nil txn")
805	}
806
807	if bucketKey == "" {
808		return nil, fmt.Errorf("empty bucket key")
809	}
810
811	entitiesIter, err := txn.Get(entitiesTable, "bucket_key", bucketKey)
812	if err != nil {
813		return nil, fmt.Errorf("failed to lookup entities using bucket entry key hash: %w", err)
814	}
815
816	var entities []*identity.Entity
817	for entity := entitiesIter.Next(); entity != nil; entity = entitiesIter.Next() {
818		entities = append(entities, entity.(*identity.Entity))
819	}
820
821	return entities, nil
822}
823
824func (i *IdentityStore) MemDBEntityByMergedEntityID(mergedEntityID string, clone bool) (*identity.Entity, error) {
825	if mergedEntityID == "" {
826		return nil, fmt.Errorf("missing merged entity id")
827	}
828
829	txn := i.db.Txn(false)
830
831	entityRaw, err := txn.First(entitiesTable, "merged_entity_ids", mergedEntityID)
832	if err != nil {
833		return nil, fmt.Errorf("failed to fetch entity from memdb using merged entity id: %w", err)
834	}
835
836	if entityRaw == nil {
837		return nil, nil
838	}
839
840	entity, ok := entityRaw.(*identity.Entity)
841	if !ok {
842		return nil, fmt.Errorf("failed to declare the type of fetched entity")
843	}
844
845	if clone {
846		return entity.Clone()
847	}
848
849	return entity, nil
850}
851
852func (i *IdentityStore) MemDBEntityByAliasIDInTxn(txn *memdb.Txn, aliasID string, clone bool) (*identity.Entity, error) {
853	if aliasID == "" {
854		return nil, fmt.Errorf("missing alias ID")
855	}
856
857	if txn == nil {
858		return nil, fmt.Errorf("txn is nil")
859	}
860
861	alias, err := i.MemDBAliasByIDInTxn(txn, aliasID, false, false)
862	if err != nil {
863		return nil, err
864	}
865
866	if alias == nil {
867		return nil, nil
868	}
869
870	return i.MemDBEntityByIDInTxn(txn, alias.CanonicalID, clone)
871}
872
873func (i *IdentityStore) MemDBEntityByAliasID(aliasID string, clone bool) (*identity.Entity, error) {
874	if aliasID == "" {
875		return nil, fmt.Errorf("missing alias ID")
876	}
877
878	txn := i.db.Txn(false)
879
880	return i.MemDBEntityByAliasIDInTxn(txn, aliasID, clone)
881}
882
883func (i *IdentityStore) MemDBDeleteEntityByID(entityID string) error {
884	if entityID == "" {
885		return nil
886	}
887
888	txn := i.db.Txn(true)
889	defer txn.Abort()
890
891	err := i.MemDBDeleteEntityByIDInTxn(txn, entityID)
892	if err != nil {
893		return err
894	}
895
896	txn.Commit()
897
898	return nil
899}
900
901func (i *IdentityStore) MemDBDeleteEntityByIDInTxn(txn *memdb.Txn, entityID string) error {
902	if entityID == "" {
903		return nil
904	}
905
906	if txn == nil {
907		return fmt.Errorf("txn is nil")
908	}
909
910	entity, err := i.MemDBEntityByIDInTxn(txn, entityID, false)
911	if err != nil {
912		return err
913	}
914
915	if entity == nil {
916		return nil
917	}
918
919	err = txn.Delete(entitiesTable, entity)
920	if err != nil {
921		return fmt.Errorf("failed to delete entity from memdb: %w", err)
922	}
923
924	return nil
925}
926
927func (i *IdentityStore) sanitizeAlias(ctx context.Context, alias *identity.Alias) error {
928	var err error
929
930	if alias == nil {
931		return fmt.Errorf("alias is nil")
932	}
933
934	// Alias must always be tied to a canonical object
935	if alias.CanonicalID == "" {
936		return fmt.Errorf("missing canonical ID")
937	}
938
939	// Alias must have a name
940	if alias.Name == "" {
941		return fmt.Errorf("missing alias name %q", alias.Name)
942	}
943
944	// Alias metadata should always be map[string]string
945	err = validateMetadata(alias.Metadata)
946	if err != nil {
947		return fmt.Errorf("invalid alias metadata: %w", err)
948	}
949
950	// Create an ID if there isn't one already
951	if alias.ID == "" {
952		alias.ID, err = uuid.GenerateUUID()
953		if err != nil {
954			return fmt.Errorf("failed to generate alias ID")
955		}
956	}
957
958	if alias.NamespaceID == "" {
959		ns, err := namespace.FromContext(ctx)
960		if err != nil {
961			return err
962		}
963		alias.NamespaceID = ns.ID
964	}
965
966	ns, err := namespace.FromContext(ctx)
967	if err != nil {
968		return err
969	}
970	if ns.ID != alias.NamespaceID {
971		return errors.New("alias belongs to a different namespace")
972	}
973
974	// Set the creation and last update times
975	if alias.CreationTime == nil {
976		alias.CreationTime = ptypes.TimestampNow()
977		alias.LastUpdateTime = alias.CreationTime
978	} else {
979		alias.LastUpdateTime = ptypes.TimestampNow()
980	}
981
982	return nil
983}
984
985func (i *IdentityStore) sanitizeEntity(ctx context.Context, entity *identity.Entity) error {
986	var err error
987
988	if entity == nil {
989		return fmt.Errorf("entity is nil")
990	}
991
992	// Create an ID if there isn't one already
993	if entity.ID == "" {
994		entity.ID, err = uuid.GenerateUUID()
995		if err != nil {
996			return fmt.Errorf("failed to generate entity id")
997		}
998
999		// Set the storage bucket key in entity
1000		entity.BucketKey = i.entityPacker.BucketKey(entity.ID)
1001	}
1002
1003	ns, err := namespace.FromContext(ctx)
1004	if err != nil {
1005		return err
1006	}
1007	if entity.NamespaceID == "" {
1008		entity.NamespaceID = ns.ID
1009	}
1010	if ns.ID != entity.NamespaceID {
1011		return errors.New("entity does not belong to this namespace")
1012	}
1013
1014	// Create a name if there isn't one already
1015	if entity.Name == "" {
1016		entity.Name, err = i.generateName(ctx, "entity")
1017		if err != nil {
1018			return fmt.Errorf("failed to generate entity name")
1019		}
1020	}
1021
1022	// Entity metadata should always be map[string]string
1023	err = validateMetadata(entity.Metadata)
1024	if err != nil {
1025		return fmt.Errorf("invalid entity metadata: %w", err)
1026	}
1027
1028	// Set the creation and last update times
1029	if entity.CreationTime == nil {
1030		entity.CreationTime = ptypes.TimestampNow()
1031		entity.LastUpdateTime = entity.CreationTime
1032	} else {
1033		entity.LastUpdateTime = ptypes.TimestampNow()
1034	}
1035
1036	// Ensure that MFASecrets is non-nil at any time. This is useful when MFA
1037	// secret generation procedures try to append MFA info to entity.
1038	if entity.MFASecrets == nil {
1039		entity.MFASecrets = make(map[string]*mfa.Secret)
1040	}
1041
1042	return nil
1043}
1044
1045func (i *IdentityStore) sanitizeAndUpsertGroup(ctx context.Context, group *identity.Group, previousGroup *identity.Group, memberGroupIDs []string) error {
1046	var err error
1047
1048	if group == nil {
1049		return fmt.Errorf("group is nil")
1050	}
1051
1052	// Create an ID if there isn't one already
1053	if group.ID == "" {
1054		group.ID, err = uuid.GenerateUUID()
1055		if err != nil {
1056			return fmt.Errorf("failed to generate group id")
1057		}
1058
1059		// Set the hash value of the storage bucket key in group
1060		group.BucketKey = i.groupPacker.BucketKey(group.ID)
1061	}
1062
1063	if group.NamespaceID == "" {
1064		ns, err := namespace.FromContext(ctx)
1065		if err != nil {
1066			return err
1067		}
1068		group.NamespaceID = ns.ID
1069	}
1070	ns, err := namespace.FromContext(ctx)
1071	if err != nil {
1072		return err
1073	}
1074	if ns.ID != group.NamespaceID {
1075		return errors.New("group does not belong to this namespace")
1076	}
1077
1078	// Create a name if there isn't one already
1079	if group.Name == "" {
1080		group.Name, err = i.generateName(ctx, "group")
1081		if err != nil {
1082			return fmt.Errorf("failed to generate group name")
1083		}
1084	}
1085
1086	// Entity metadata should always be map[string]string
1087	err = validateMetadata(group.Metadata)
1088	if err != nil {
1089		return fmt.Errorf("invalid group metadata: %w", err)
1090	}
1091
1092	// Set the creation and last update times
1093	if group.CreationTime == nil {
1094		group.CreationTime = ptypes.TimestampNow()
1095		group.LastUpdateTime = group.CreationTime
1096	} else {
1097		group.LastUpdateTime = ptypes.TimestampNow()
1098	}
1099
1100	// Remove duplicate entity IDs and check if all IDs are valid
1101	group.MemberEntityIDs = strutil.RemoveDuplicates(group.MemberEntityIDs, false)
1102	for _, entityID := range group.MemberEntityIDs {
1103		entity, err := i.MemDBEntityByID(entityID, false)
1104		if err != nil {
1105			return fmt.Errorf("failed to validate entity ID %q: %w", entityID, err)
1106		}
1107		if entity == nil {
1108			return fmt.Errorf("invalid entity ID %q", entityID)
1109		}
1110	}
1111
1112	txn := i.db.Txn(true)
1113	defer txn.Abort()
1114
1115	var currentMemberGroupIDs []string
1116	var currentMemberGroups []*identity.Group
1117
1118	// If there are no member group IDs supplied, then it shouldn't be
1119	// processed. If an empty set of member group IDs are supplied, then it
1120	// should be processed. Hence the nil check instead of the length check.
1121	if memberGroupIDs == nil {
1122		goto ALIAS
1123	}
1124
1125	memberGroupIDs = strutil.RemoveDuplicates(memberGroupIDs, false)
1126
1127	// For those group member IDs that are removed from the list, remove current
1128	// group ID as their respective ParentGroupID.
1129
1130	// Get the current MemberGroups IDs for this group
1131	currentMemberGroups, err = i.MemDBGroupsByParentGroupID(group.ID, false)
1132	if err != nil {
1133		return err
1134	}
1135	for _, currentMemberGroup := range currentMemberGroups {
1136		currentMemberGroupIDs = append(currentMemberGroupIDs, currentMemberGroup.ID)
1137	}
1138
1139	// Update parent group IDs in the removed members
1140	for _, currentMemberGroupID := range currentMemberGroupIDs {
1141		if strutil.StrListContains(memberGroupIDs, currentMemberGroupID) {
1142			continue
1143		}
1144
1145		currentMemberGroup, err := i.MemDBGroupByID(currentMemberGroupID, true)
1146		if err != nil {
1147			return err
1148		}
1149		if currentMemberGroup == nil {
1150			return fmt.Errorf("invalid member group ID %q", currentMemberGroupID)
1151		}
1152
1153		// Remove group ID from the parent group IDs
1154		currentMemberGroup.ParentGroupIDs = strutil.StrListDelete(currentMemberGroup.ParentGroupIDs, group.ID)
1155
1156		err = i.UpsertGroupInTxn(ctx, txn, currentMemberGroup, true)
1157		if err != nil {
1158			return err
1159		}
1160	}
1161
1162	// After the group lock is held, make membership updates to all the
1163	// relevant groups
1164	for _, memberGroupID := range memberGroupIDs {
1165		memberGroup, err := i.MemDBGroupByID(memberGroupID, true)
1166		if err != nil {
1167			return err
1168		}
1169		if memberGroup == nil {
1170			return fmt.Errorf("invalid member group ID %q", memberGroupID)
1171		}
1172
1173		// Skip if memberGroupID is already a member of group.ID
1174		if strutil.StrListContains(memberGroup.ParentGroupIDs, group.ID) {
1175			continue
1176		}
1177
1178		// Ensure that adding memberGroupID does not lead to cyclic
1179		// relationships
1180		// Detect self loop
1181		if group.ID == memberGroupID {
1182			return fmt.Errorf("member group ID %q is same as the ID of the group", group.ID)
1183		}
1184
1185		groupByID, err := i.MemDBGroupByID(group.ID, true)
1186		if err != nil {
1187			return err
1188		}
1189
1190		// If group is nil, that means that a group doesn't already exist and its
1191		// okay to add any group as its member group.
1192		if groupByID != nil {
1193			// If adding the memberGroupID to groupID creates a cycle, then groupID must
1194			// be a hop in that loop. Start a DFS traversal from memberGroupID and see if
1195			// it reaches back to groupID. If it does, then it's a loop.
1196
1197			// Created a visited set
1198			visited := make(map[string]bool)
1199			cycleDetected, err := i.detectCycleDFS(visited, groupByID.ID, memberGroupID)
1200			if err != nil {
1201				return fmt.Errorf("failed to perform cyclic relationship detection for member group ID %q", memberGroupID)
1202			}
1203			if cycleDetected {
1204				return fmt.Errorf("cyclic relationship detected for member group ID %q", memberGroupID)
1205			}
1206		}
1207
1208		memberGroup.ParentGroupIDs = append(memberGroup.ParentGroupIDs, group.ID)
1209
1210		// This technically is not upsert. It is only update, only the method
1211		// name is upsert here.
1212		err = i.UpsertGroupInTxn(ctx, txn, memberGroup, true)
1213		if err != nil {
1214			// Ideally we would want to revert the whole operation in case of
1215			// errors while persisting in member groups. But there is no
1216			// storage transaction support yet. When we do have it, this will need
1217			// an update.
1218			return err
1219		}
1220	}
1221
1222ALIAS:
1223	// Sanitize the group alias
1224	if group.Alias != nil {
1225		group.Alias.CanonicalID = group.ID
1226		err = i.sanitizeAlias(ctx, group.Alias)
1227		if err != nil {
1228			return err
1229		}
1230	}
1231
1232	// If previousGroup is not nil, we are moving the alias from the previous
1233	// group to the new one. As a result we need to upsert both in the context
1234	// of this same transaction.
1235	if previousGroup != nil {
1236		err = i.UpsertGroupInTxn(ctx, txn, previousGroup, true)
1237		if err != nil {
1238			return err
1239		}
1240	}
1241
1242	err = i.UpsertGroupInTxn(ctx, txn, group, true)
1243	if err != nil {
1244		return err
1245	}
1246
1247	txn.Commit()
1248
1249	return nil
1250}
1251
1252func (i *IdentityStore) deleteAliasesInEntityInTxn(txn *memdb.Txn, entity *identity.Entity, aliases []*identity.Alias) error {
1253	if entity == nil {
1254		return fmt.Errorf("entity is nil")
1255	}
1256
1257	if txn == nil {
1258		return fmt.Errorf("txn is nil")
1259	}
1260
1261	var remainList []*identity.Alias
1262	var removeList []*identity.Alias
1263
1264	for _, item := range aliases {
1265		for _, alias := range entity.Aliases {
1266			if alias.ID == item.ID {
1267				removeList = append(removeList, alias)
1268			} else {
1269				remainList = append(remainList, alias)
1270			}
1271		}
1272	}
1273
1274	// Remove identity indices from aliases table for those that needs to
1275	// be removed
1276	for _, alias := range removeList {
1277		err := i.MemDBDeleteAliasByIDInTxn(txn, alias.ID, false)
1278		if err != nil {
1279			return err
1280		}
1281	}
1282
1283	// Update the entity with remaining items
1284	entity.Aliases = remainList
1285
1286	return nil
1287}
1288
1289// validateMeta validates a set of key/value pairs from the agent config
1290func validateMetadata(meta map[string]string) error {
1291	if len(meta) > metaMaxKeyPairs {
1292		return fmt.Errorf("metadata cannot contain more than %d key/value pairs", metaMaxKeyPairs)
1293	}
1294
1295	for key, value := range meta {
1296		if err := validateMetaPair(key, value); err != nil {
1297			return fmt.Errorf("failed to load metadata pair (%q, %q): %w", key, value, err)
1298		}
1299	}
1300
1301	return nil
1302}
1303
1304// validateMetaPair checks that the given key/value pair is in a valid format
1305func validateMetaPair(key, value string) error {
1306	if key == "" {
1307		return fmt.Errorf("key cannot be blank")
1308	}
1309	if !metaKeyFormatRegEx(key) {
1310		return fmt.Errorf("key contains invalid characters")
1311	}
1312	if len(key) > metaKeyMaxLength {
1313		return fmt.Errorf("key is too long (limit: %d characters)", metaKeyMaxLength)
1314	}
1315	if strings.HasPrefix(key, metaKeyReservedPrefix) {
1316		return fmt.Errorf("key prefix %q is reserved for internal use", metaKeyReservedPrefix)
1317	}
1318	if len(value) > metaValueMaxLength {
1319		return fmt.Errorf("value is too long (limit: %d characters)", metaValueMaxLength)
1320	}
1321	return nil
1322}
1323
1324func (i *IdentityStore) MemDBGroupByNameInTxn(ctx context.Context, txn *memdb.Txn, groupName string, clone bool) (*identity.Group, error) {
1325	if groupName == "" {
1326		return nil, fmt.Errorf("missing group name")
1327	}
1328
1329	if txn == nil {
1330		return nil, fmt.Errorf("txn is nil")
1331	}
1332
1333	ns, err := namespace.FromContext(ctx)
1334	if err != nil {
1335		return nil, err
1336	}
1337
1338	groupRaw, err := txn.First(groupsTable, "name", ns.ID, groupName)
1339	if err != nil {
1340		return nil, fmt.Errorf("failed to fetch group from memdb using group name: %w", err)
1341	}
1342
1343	if groupRaw == nil {
1344		return nil, nil
1345	}
1346
1347	group, ok := groupRaw.(*identity.Group)
1348	if !ok {
1349		return nil, fmt.Errorf("failed to declare the type of fetched group")
1350	}
1351
1352	if clone {
1353		return group.Clone()
1354	}
1355
1356	return group, nil
1357}
1358
1359func (i *IdentityStore) MemDBGroupByName(ctx context.Context, groupName string, clone bool) (*identity.Group, error) {
1360	if groupName == "" {
1361		return nil, fmt.Errorf("missing group name")
1362	}
1363
1364	txn := i.db.Txn(false)
1365
1366	return i.MemDBGroupByNameInTxn(ctx, txn, groupName, clone)
1367}
1368
1369func (i *IdentityStore) UpsertGroup(ctx context.Context, group *identity.Group, persist bool) error {
1370	defer metrics.MeasureSince([]string{"identity", "upsert_group"}, time.Now())
1371
1372	txn := i.db.Txn(true)
1373	defer txn.Abort()
1374
1375	err := i.UpsertGroupInTxn(ctx, txn, group, true)
1376	if err != nil {
1377		return err
1378	}
1379
1380	txn.Commit()
1381
1382	return nil
1383}
1384
1385func (i *IdentityStore) UpsertGroupInTxn(ctx context.Context, txn *memdb.Txn, group *identity.Group, persist bool) error {
1386	defer metrics.MeasureSince([]string{"identity", "upsert_group_txn"}, time.Now())
1387
1388	var err error
1389
1390	if txn == nil {
1391		return fmt.Errorf("txn is nil")
1392	}
1393
1394	if group == nil {
1395		return fmt.Errorf("group is nil")
1396	}
1397
1398	// Increment the modify index of the group
1399	group.ModifyIndex++
1400
1401	// Clear the old alias from memdb
1402	groupClone, err := i.MemDBGroupByID(group.ID, true)
1403	if err != nil {
1404		return err
1405	}
1406	if groupClone != nil && groupClone.Alias != nil {
1407		err = i.MemDBDeleteAliasByIDInTxn(txn, groupClone.Alias.ID, true)
1408		if err != nil {
1409			return err
1410		}
1411	}
1412
1413	// Add the new alias to memdb
1414	if group.Alias != nil {
1415		err = i.MemDBUpsertAliasInTxn(txn, group.Alias, true)
1416		if err != nil {
1417			return err
1418		}
1419	}
1420
1421	// Insert or update group in MemDB using the transaction created above
1422	err = i.MemDBUpsertGroupInTxn(txn, group)
1423	if err != nil {
1424		return err
1425	}
1426
1427	if persist {
1428		groupAsAny, err := ptypes.MarshalAny(group)
1429		if err != nil {
1430			return err
1431		}
1432
1433		item := &storagepacker.Item{
1434			ID:      group.ID,
1435			Message: groupAsAny,
1436		}
1437
1438		sent, err := sendGroupUpgrade(ctx, i, group)
1439		if err != nil {
1440			return err
1441		}
1442		if !sent {
1443			if err := i.groupPacker.PutItem(ctx, item); err != nil {
1444				return err
1445			}
1446		}
1447	}
1448
1449	return nil
1450}
1451
1452func (i *IdentityStore) MemDBUpsertGroupInTxn(txn *memdb.Txn, group *identity.Group) error {
1453	if txn == nil {
1454		return fmt.Errorf("nil txn")
1455	}
1456
1457	if group == nil {
1458		return fmt.Errorf("group is nil")
1459	}
1460
1461	if group.NamespaceID == "" {
1462		group.NamespaceID = namespace.RootNamespaceID
1463	}
1464
1465	groupRaw, err := txn.First(groupsTable, "id", group.ID)
1466	if err != nil {
1467		return fmt.Errorf("failed to lookup group from memdb using group id: %w", err)
1468	}
1469
1470	if groupRaw != nil {
1471		err = txn.Delete(groupsTable, groupRaw)
1472		if err != nil {
1473			return fmt.Errorf("failed to delete group from memdb: %w", err)
1474		}
1475	}
1476
1477	if err := txn.Insert(groupsTable, group); err != nil {
1478		return fmt.Errorf("failed to update group into memdb: %w", err)
1479	}
1480
1481	return nil
1482}
1483
1484func (i *IdentityStore) MemDBDeleteGroupByIDInTxn(txn *memdb.Txn, groupID string) error {
1485	if groupID == "" {
1486		return nil
1487	}
1488
1489	if txn == nil {
1490		return fmt.Errorf("txn is nil")
1491	}
1492
1493	group, err := i.MemDBGroupByIDInTxn(txn, groupID, false)
1494	if err != nil {
1495		return err
1496	}
1497
1498	if group == nil {
1499		return nil
1500	}
1501
1502	err = txn.Delete("groups", group)
1503	if err != nil {
1504		return fmt.Errorf("failed to delete group from memdb: %w", err)
1505	}
1506
1507	return nil
1508}
1509
1510func (i *IdentityStore) MemDBGroupByIDInTxn(txn *memdb.Txn, groupID string, clone bool) (*identity.Group, error) {
1511	if groupID == "" {
1512		return nil, fmt.Errorf("missing group ID")
1513	}
1514
1515	if txn == nil {
1516		return nil, fmt.Errorf("txn is nil")
1517	}
1518
1519	groupRaw, err := txn.First(groupsTable, "id", groupID)
1520	if err != nil {
1521		return nil, fmt.Errorf("failed to fetch group from memdb using group ID: %w", err)
1522	}
1523
1524	if groupRaw == nil {
1525		return nil, nil
1526	}
1527
1528	group, ok := groupRaw.(*identity.Group)
1529	if !ok {
1530		return nil, fmt.Errorf("failed to declare the type of fetched group")
1531	}
1532
1533	if clone {
1534		return group.Clone()
1535	}
1536
1537	return group, nil
1538}
1539
1540func (i *IdentityStore) MemDBGroupByID(groupID string, clone bool) (*identity.Group, error) {
1541	if groupID == "" {
1542		return nil, fmt.Errorf("missing group ID")
1543	}
1544
1545	txn := i.db.Txn(false)
1546
1547	return i.MemDBGroupByIDInTxn(txn, groupID, clone)
1548}
1549
1550func (i *IdentityStore) MemDBGroupsByParentGroupIDInTxn(txn *memdb.Txn, memberGroupID string, clone bool) ([]*identity.Group, error) {
1551	if memberGroupID == "" {
1552		return nil, fmt.Errorf("missing member group ID")
1553	}
1554
1555	groupsIter, err := txn.Get(groupsTable, "parent_group_ids", memberGroupID)
1556	if err != nil {
1557		return nil, fmt.Errorf("failed to lookup groups using member group ID: %w", err)
1558	}
1559
1560	var groups []*identity.Group
1561	for group := groupsIter.Next(); group != nil; group = groupsIter.Next() {
1562		entry := group.(*identity.Group)
1563		if clone {
1564			entry, err = entry.Clone()
1565			if err != nil {
1566				return nil, err
1567			}
1568		}
1569		groups = append(groups, entry)
1570	}
1571
1572	return groups, nil
1573}
1574
1575func (i *IdentityStore) MemDBGroupsByParentGroupID(memberGroupID string, clone bool) ([]*identity.Group, error) {
1576	if memberGroupID == "" {
1577		return nil, fmt.Errorf("missing member group ID")
1578	}
1579
1580	txn := i.db.Txn(false)
1581
1582	return i.MemDBGroupsByParentGroupIDInTxn(txn, memberGroupID, clone)
1583}
1584
1585func (i *IdentityStore) MemDBGroupsByMemberEntityID(entityID string, clone bool, externalOnly bool) ([]*identity.Group, error) {
1586	txn := i.db.Txn(false)
1587	defer txn.Abort()
1588
1589	return i.MemDBGroupsByMemberEntityIDInTxn(txn, entityID, clone, externalOnly)
1590}
1591
1592func (i *IdentityStore) MemDBGroupsByMemberEntityIDInTxn(txn *memdb.Txn, entityID string, clone bool, externalOnly bool) ([]*identity.Group, error) {
1593	if entityID == "" {
1594		return nil, fmt.Errorf("missing entity ID")
1595	}
1596
1597	groupsIter, err := txn.Get(groupsTable, "member_entity_ids", entityID)
1598	if err != nil {
1599		return nil, fmt.Errorf("failed to lookup groups using entity ID: %w", err)
1600	}
1601
1602	var groups []*identity.Group
1603	for group := groupsIter.Next(); group != nil; group = groupsIter.Next() {
1604		entry := group.(*identity.Group)
1605		if externalOnly && entry.Type == groupTypeInternal {
1606			continue
1607		}
1608		if clone {
1609			entry, err = entry.Clone()
1610			if err != nil {
1611				return nil, err
1612			}
1613		}
1614		groups = append(groups, entry)
1615	}
1616
1617	return groups, nil
1618}
1619
1620func (i *IdentityStore) groupPoliciesByEntityID(entityID string) (map[string][]string, error) {
1621	if entityID == "" {
1622		return nil, fmt.Errorf("empty entity ID")
1623	}
1624
1625	groups, err := i.MemDBGroupsByMemberEntityID(entityID, false, false)
1626	if err != nil {
1627		return nil, err
1628	}
1629
1630	visited := make(map[string]bool)
1631	policies := make(map[string][]string)
1632	for _, group := range groups {
1633		err := i.collectPoliciesReverseDFS(group, visited, policies)
1634		if err != nil {
1635			return nil, err
1636		}
1637	}
1638
1639	return policies, nil
1640}
1641
1642func (i *IdentityStore) groupsByEntityID(entityID string) ([]*identity.Group, []*identity.Group, error) {
1643	if entityID == "" {
1644		return nil, nil, fmt.Errorf("empty entity ID")
1645	}
1646
1647	groups, err := i.MemDBGroupsByMemberEntityID(entityID, true, false)
1648	if err != nil {
1649		return nil, nil, err
1650	}
1651
1652	visited := make(map[string]bool)
1653	var tGroups []*identity.Group
1654	for _, group := range groups {
1655		gGroups, err := i.collectGroupsReverseDFS(group, visited, nil)
1656		if err != nil {
1657			return nil, nil, err
1658		}
1659		tGroups = append(tGroups, gGroups...)
1660	}
1661
1662	// Remove duplicates
1663	groupMap := make(map[string]*identity.Group)
1664	for _, group := range tGroups {
1665		groupMap[group.ID] = group
1666	}
1667
1668	tGroups = make([]*identity.Group, 0, len(groupMap))
1669	for _, group := range groupMap {
1670		tGroups = append(tGroups, group)
1671	}
1672
1673	diff := diffGroups(groups, tGroups)
1674
1675	// For sanity
1676	// There should not be any group that gets deleted
1677	if len(diff.Deleted) != 0 {
1678		return nil, nil, fmt.Errorf("failed to diff group memberships")
1679	}
1680
1681	return diff.Unmodified, diff.New, nil
1682}
1683
1684func (i *IdentityStore) collectGroupsReverseDFS(group *identity.Group, visited map[string]bool, groups []*identity.Group) ([]*identity.Group, error) {
1685	if group == nil {
1686		return nil, fmt.Errorf("nil group")
1687	}
1688
1689	// If traversal for a groupID is performed before, skip it
1690	if visited[group.ID] {
1691		return groups, nil
1692	}
1693	visited[group.ID] = true
1694
1695	groups = append(groups, group)
1696
1697	// Traverse all the parent groups
1698	for _, parentGroupID := range group.ParentGroupIDs {
1699		parentGroup, err := i.MemDBGroupByID(parentGroupID, false)
1700		if err != nil {
1701			return nil, err
1702		}
1703		if parentGroup == nil {
1704			continue
1705		}
1706		groups, err = i.collectGroupsReverseDFS(parentGroup, visited, groups)
1707		if err != nil {
1708			return nil, fmt.Errorf("failed to collect group at parent group ID %q", parentGroup.ID)
1709		}
1710	}
1711
1712	return groups, nil
1713}
1714
1715func (i *IdentityStore) collectPoliciesReverseDFS(group *identity.Group, visited map[string]bool, policies map[string][]string) error {
1716	if group == nil {
1717		return fmt.Errorf("nil group")
1718	}
1719
1720	// If traversal for a groupID is performed before, skip it
1721	if visited[group.ID] {
1722		return nil
1723	}
1724	visited[group.ID] = true
1725
1726	policies[group.NamespaceID] = append(policies[group.NamespaceID], group.Policies...)
1727
1728	// Traverse all the parent groups
1729	for _, parentGroupID := range group.ParentGroupIDs {
1730		parentGroup, err := i.MemDBGroupByID(parentGroupID, false)
1731		if err != nil {
1732			return err
1733		}
1734		if parentGroup == nil {
1735			continue
1736		}
1737		err = i.collectPoliciesReverseDFS(parentGroup, visited, policies)
1738		if err != nil {
1739			return fmt.Errorf("failed to collect policies at parent group ID %q", parentGroup.ID)
1740		}
1741	}
1742
1743	return nil
1744}
1745
1746func (i *IdentityStore) detectCycleDFS(visited map[string]bool, startingGroupID, groupID string) (bool, error) {
1747	// If the traversal reaches the startingGroupID, a loop is detected
1748	if startingGroupID == groupID {
1749		return true, nil
1750	}
1751
1752	// If traversal for a groupID is performed before, skip it
1753	if visited[groupID] {
1754		return false, nil
1755	}
1756	visited[groupID] = true
1757
1758	group, err := i.MemDBGroupByID(groupID, true)
1759	if err != nil {
1760		return false, err
1761	}
1762	if group == nil {
1763		return false, nil
1764	}
1765
1766	// Fetch all groups in which groupID is present as a ParentGroupID. In
1767	// other words, find all the subgroups of groupID.
1768	memberGroups, err := i.MemDBGroupsByParentGroupID(groupID, false)
1769	if err != nil {
1770		return false, err
1771	}
1772
1773	// DFS traverse the member groups
1774	for _, memberGroup := range memberGroups {
1775		cycleDetected, err := i.detectCycleDFS(visited, startingGroupID, memberGroup.ID)
1776		if err != nil {
1777			return false, fmt.Errorf("failed to perform cycle detection at member group ID %q", memberGroup.ID)
1778		}
1779		if cycleDetected {
1780			return true, fmt.Errorf("cycle detected at member group ID %q", memberGroup.ID)
1781		}
1782	}
1783
1784	return false, nil
1785}
1786
1787func (i *IdentityStore) memberGroupIDsByID(groupID string) ([]string, error) {
1788	var memberGroupIDs []string
1789	memberGroups, err := i.MemDBGroupsByParentGroupID(groupID, false)
1790	if err != nil {
1791		return nil, err
1792	}
1793	for _, memberGroup := range memberGroups {
1794		memberGroupIDs = append(memberGroupIDs, memberGroup.ID)
1795	}
1796	return memberGroupIDs, nil
1797}
1798
1799func (i *IdentityStore) generateName(ctx context.Context, entryType string) (string, error) {
1800	var name string
1801OUTER:
1802	for {
1803		randBytes, err := uuid.GenerateRandomBytes(4)
1804		if err != nil {
1805			return "", err
1806		}
1807		name = fmt.Sprintf("%s_%s", entryType, fmt.Sprintf("%08x", randBytes[0:4]))
1808
1809		switch entryType {
1810		case "entity":
1811			entity, err := i.MemDBEntityByName(ctx, name, false)
1812			if err != nil {
1813				return "", err
1814			}
1815			if entity == nil {
1816				break OUTER
1817			}
1818		case "group":
1819			group, err := i.MemDBGroupByName(ctx, name, false)
1820			if err != nil {
1821				return "", err
1822			}
1823			if group == nil {
1824				break OUTER
1825			}
1826		default:
1827			return "", fmt.Errorf("unrecognized type %q", entryType)
1828		}
1829	}
1830
1831	return name, nil
1832}
1833
1834func (i *IdentityStore) MemDBGroupsByBucketKeyInTxn(txn *memdb.Txn, bucketKey string) ([]*identity.Group, error) {
1835	if txn == nil {
1836		return nil, fmt.Errorf("nil txn")
1837	}
1838
1839	if bucketKey == "" {
1840		return nil, fmt.Errorf("empty bucket key")
1841	}
1842
1843	groupsIter, err := txn.Get(groupsTable, "bucket_key", bucketKey)
1844	if err != nil {
1845		return nil, fmt.Errorf("failed to lookup groups using bucket entry key hash: %w", err)
1846	}
1847
1848	var groups []*identity.Group
1849	for group := groupsIter.Next(); group != nil; group = groupsIter.Next() {
1850		groups = append(groups, group.(*identity.Group))
1851	}
1852
1853	return groups, nil
1854}
1855
1856func (i *IdentityStore) MemDBGroupByAliasIDInTxn(txn *memdb.Txn, aliasID string, clone bool) (*identity.Group, error) {
1857	if aliasID == "" {
1858		return nil, fmt.Errorf("missing alias ID")
1859	}
1860
1861	if txn == nil {
1862		return nil, fmt.Errorf("txn is nil")
1863	}
1864
1865	alias, err := i.MemDBAliasByIDInTxn(txn, aliasID, false, true)
1866	if err != nil {
1867		return nil, err
1868	}
1869
1870	if alias == nil {
1871		return nil, nil
1872	}
1873
1874	return i.MemDBGroupByIDInTxn(txn, alias.CanonicalID, clone)
1875}
1876
1877func (i *IdentityStore) MemDBGroupByAliasID(aliasID string, clone bool) (*identity.Group, error) {
1878	if aliasID == "" {
1879		return nil, fmt.Errorf("missing alias ID")
1880	}
1881
1882	txn := i.db.Txn(false)
1883
1884	return i.MemDBGroupByAliasIDInTxn(txn, aliasID, clone)
1885}
1886
// refreshExternalGroupMembershipsByEntityID reconciles entityID's membership
// in external groups with the group aliases in groupAliases. These are
// presumably the aliases an auth backend reported for the entity; confirm
// against callers.
//
// Each groupAliases entry is resolved by (MountAccessor, Name) to a stored
// group alias, and from there to the group that owns it. Entries with no
// matching stored alias are ignored; a stored alias whose group is missing is
// an error. entityID is then:
//   - added to every resolved external group it is not yet a member of, and
//   - removed from every external group it currently belongs to that no
//     resolved alias maps to.
//
// When mountAccessor is non-empty, removal skips groups whose alias is on a
// different mount. Groups with no alias are still eligible for removal.
//
// The work runs at most twice. First comes a dry run (no lock, read-only txn)
// that stops at the first needed change. Only if a change was found is the
// update re-run under i.groupLock in a write txn. The returned slice holds
// the groupAliases entries that resolved to a group.
func (i *IdentityStore) refreshExternalGroupMembershipsByEntityID(ctx context.Context, entityID string, groupAliases []*logical.Alias, mountAccessor string) ([]*logical.Alias, error) {
	defer metrics.MeasureSince([]string{"identity", "refresh_external_groups"}, time.Now())

	if entityID == "" {
		return nil, fmt.Errorf("empty entity ID")
	}

	// refreshFunc computes the membership diff for entityID.
	//   - dryRun true: writes nothing; returns (true, nil, nil) at the first
	//     external group that would change, else (false, validAliases, nil).
	//   - dryRun false: applies all changes in a write txn while holding
	//     i.groupLock, and returns (false, validAliases, nil) on success.
	refreshFunc := func(dryRun bool) (bool, []*logical.Alias, error) {
		if !dryRun {
			i.groupLock.Lock()
			defer i.groupLock.Unlock()
		}

		// Read-only txn for the dry run, write txn for the real run.
		txn := i.db.Txn(!dryRun)
		defer txn.Abort()

		// External groups that currently list entityID as a member. They are
		// cloned because they may be modified and upserted below.
		oldGroups, err := i.MemDBGroupsByMemberEntityIDInTxn(txn, entityID, true, true)
		if err != nil {
			return false, nil, err
		}

		// Resolve each supplied alias to its group. Unknown aliases are
		// skipped; only resolvable ones are reported back as valid.
		var newGroups []*identity.Group
		var validAliases []*logical.Alias
		for _, alias := range groupAliases {
			aliasByFactors, err := i.MemDBAliasByFactorsInTxn(txn, alias.MountAccessor, alias.Name, true, true)
			if err != nil {
				return false, nil, err
			}
			if aliasByFactors == nil {
				continue
			}
			mappingGroup, err := i.MemDBGroupByAliasIDInTxn(txn, aliasByFactors.ID, true)
			if err != nil {
				return false, nil, err
			}
			if mappingGroup == nil {
				return false, nil, fmt.Errorf("group unavailable for a valid alias ID %q", aliasByFactors.ID)
			}

			newGroups = append(newGroups, mappingGroup)
			validAliases = append(validAliases, alias)
		}

		diff := diffGroups(oldGroups, newGroups)

		// Add the entity ID to all the new groups
		for _, group := range diff.New {
			if group.Type != groupTypeExternal {
				continue
			}

			// We need to update a group, if we are in a dry run we should
			// report back that a change needs to take place.
			if dryRun {
				return true, nil, nil
			}

			i.logger.Debug("adding member entity ID to external group", "member_entity_id", entityID, "group_id", group.ID)

			group.MemberEntityIDs = append(group.MemberEntityIDs, entityID)

			err = i.UpsertGroupInTxn(ctx, txn, group, true)
			if err != nil {
				return false, nil, err
			}
		}

		// Remove the entity ID from all the deleted groups
		for _, group := range diff.Deleted {
			if group.Type != groupTypeExternal {
				continue
			}

			// If the external group is from a different mount, don't remove the
			// entity ID from it.
			if mountAccessor != "" && group.Alias != nil && group.Alias.MountAccessor != mountAccessor {
				continue
			}

			// We need to update a group, if we are in a dry run we should
			// report back that a change needs to take place.
			if dryRun {
				return true, nil, nil
			}

			i.logger.Debug("removing member entity ID from external group", "member_entity_id", entityID, "group_id", group.ID)

			group.MemberEntityIDs = strutil.StrListDelete(group.MemberEntityIDs, entityID)

			err = i.UpsertGroupInTxn(ctx, txn, group, true)
			if err != nil {
				return false, nil, err
			}
		}

		// In the dry run this commits a read-only txn, which go-memdb treats
		// as a no-op.
		txn.Commit()
		return false, validAliases, nil
	}

	// dryRun
	needsUpdate, validAliases, err := refreshFunc(true)
	if err != nil {
		return nil, err
	}

	if needsUpdate || len(groupAliases) > 0 {
		i.logger.Debug("refreshing external group memberships", "entity_id", entityID, "group_aliases", groupAliases)
	}

	// Nothing to change: the dry run's validAliases are the answer.
	if !needsUpdate {
		return validAliases, nil
	}

	// Run the update. The dry run held no lock, so the diff is recomputed
	// from scratch under i.groupLock rather than reusing the dry-run result.
	_, validAliases, err = refreshFunc(false)
	if err != nil {
		return nil, err
	}

	return validAliases, nil
}
2008
2009// diffGroups is used to diff two sets of groups
2010func diffGroups(old, new []*identity.Group) *groupDiff {
2011	diff := &groupDiff{}
2012
2013	existing := make(map[string]*identity.Group)
2014	for _, group := range old {
2015		existing[group.ID] = group
2016	}
2017
2018	for _, group := range new {
2019		// Check if the entry in new is present in the old
2020		_, ok := existing[group.ID]
2021
2022		// If its not present, then its a new entry
2023		if !ok {
2024			diff.New = append(diff.New, group)
2025			continue
2026		}
2027
2028		// If its present, it means that its unmodified
2029		diff.Unmodified = append(diff.Unmodified, group)
2030
2031		// By deleting the unmodified from the old set, we could determine the
2032		// ones that are stale by looking at the remaining ones.
2033		delete(existing, group.ID)
2034	}
2035
2036	// Any remaining entries must have been deleted
2037	for _, me := range existing {
2038		diff.Deleted = append(diff.Deleted, me)
2039	}
2040
2041	return diff
2042}
2043
2044func (i *IdentityStore) handleAliasListCommon(ctx context.Context, groupAlias bool) (*logical.Response, error) {
2045	ns, err := namespace.FromContext(ctx)
2046	if err != nil {
2047		return nil, err
2048	}
2049
2050	tableName := entityAliasesTable
2051	if groupAlias {
2052		tableName = groupAliasesTable
2053	}
2054
2055	ws := memdb.NewWatchSet()
2056
2057	txn := i.db.Txn(false)
2058
2059	iter, err := txn.Get(tableName, "namespace_id", ns.ID)
2060	if err != nil {
2061		return nil, fmt.Errorf("failed to fetch iterator for aliases in memdb: %w", err)
2062	}
2063
2064	ws.Add(iter.WatchCh())
2065
2066	var aliasIDs []string
2067	aliasInfo := map[string]interface{}{}
2068
2069	type mountInfo struct {
2070		MountType string
2071		MountPath string
2072	}
2073	mountAccessorMap := map[string]mountInfo{}
2074
2075	for {
2076		raw := iter.Next()
2077		if raw == nil {
2078			break
2079		}
2080		alias := raw.(*identity.Alias)
2081		aliasIDs = append(aliasIDs, alias.ID)
2082		aliasInfoEntry := map[string]interface{}{
2083			"name":           alias.Name,
2084			"canonical_id":   alias.CanonicalID,
2085			"mount_accessor": alias.MountAccessor,
2086		}
2087
2088		mi, ok := mountAccessorMap[alias.MountAccessor]
2089		if ok {
2090			aliasInfoEntry["mount_type"] = mi.MountType
2091			aliasInfoEntry["mount_path"] = mi.MountPath
2092		} else {
2093			mi = mountInfo{}
2094			if mountValidationResp := i.core.router.validateMountByAccessor(alias.MountAccessor); mountValidationResp != nil {
2095				mi.MountType = mountValidationResp.MountType
2096				mi.MountPath = mountValidationResp.MountPath
2097				aliasInfoEntry["mount_type"] = mi.MountType
2098				aliasInfoEntry["mount_path"] = mi.MountPath
2099			}
2100			mountAccessorMap[alias.MountAccessor] = mi
2101		}
2102
2103		aliasInfo[alias.ID] = aliasInfoEntry
2104	}
2105
2106	return logical.ListResponseWithInfo(aliasIDs, aliasInfo), nil
2107}
2108
2109func (i *IdentityStore) countEntities() (int, error) {
2110	txn := i.db.Txn(false)
2111
2112	iter, err := txn.Get(entitiesTable, "id")
2113	if err != nil {
2114		return -1, err
2115	}
2116
2117	count := 0
2118	val := iter.Next()
2119	for val != nil {
2120		count++
2121		val = iter.Next()
2122	}
2123
2124	return count, nil
2125}
2126
2127// Sum up the number of entities belonging to each namespace (keyed by ID)
2128func (i *IdentityStore) countEntitiesByNamespace(ctx context.Context) (map[string]int, error) {
2129	txn := i.db.Txn(false)
2130	iter, err := txn.Get(entitiesTable, "id")
2131	if err != nil {
2132		return nil, err
2133	}
2134
2135	byNamespace := make(map[string]int)
2136	val := iter.Next()
2137	for val != nil {
2138		// Check if runtime exceeded.
2139		select {
2140		case <-ctx.Done():
2141			return byNamespace, errors.New("context cancelled")
2142		default:
2143			break
2144		}
2145
2146		// Count in the namespace attached to the entity.
2147		entity := val.(*identity.Entity)
2148		byNamespace[entity.NamespaceID] = byNamespace[entity.NamespaceID] + 1
2149		val = iter.Next()
2150	}
2151
2152	return byNamespace, nil
2153}
2154
2155// Sum up the number of entities belonging to each mount point (keyed by accessor)
2156func (i *IdentityStore) countEntitiesByMountAccessor(ctx context.Context) (map[string]int, error) {
2157	txn := i.db.Txn(false)
2158	iter, err := txn.Get(entitiesTable, "id")
2159	if err != nil {
2160		return nil, err
2161	}
2162
2163	byMountAccessor := make(map[string]int)
2164	val := iter.Next()
2165	for val != nil {
2166		// Check if runtime exceeded.
2167		select {
2168		case <-ctx.Done():
2169			return byMountAccessor, errors.New("context cancelled")
2170		default:
2171			break
2172		}
2173
2174		// Count each alias separately; will translate to mount point and type
2175		// in the caller.
2176		entity := val.(*identity.Entity)
2177		for _, alias := range entity.Aliases {
2178			byMountAccessor[alias.MountAccessor] = byMountAccessor[alias.MountAccessor] + 1
2179		}
2180		val = iter.Next()
2181	}
2182
2183	return byMountAccessor, nil
2184}
2185