1// Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
2// See LICENSE.txt for license information.
3
4package localcachelayer
5
6import (
7	"os"
8	"sync"
9	"testing"
10
11	"github.com/mattermost/mattermost-server/v6/model"
12	"github.com/mattermost/mattermost-server/v6/store"
13	"github.com/mattermost/mattermost-server/v6/store/sqlstore"
14	"github.com/mattermost/mattermost-server/v6/store/storetest"
15)
16
// storeType describes one database backend that the local cache layer tests
// are run against.
type storeType struct {
	// Name is used as the subtest name passed to t.Run, e.g. "LocalCache+MySQL".
	Name string
	// SqlSettings is built by storetest.MakeSqlSettings in newStoreType and is
	// handed to storetest.CleanupSqlSettings in tearDownStores.
	SqlSettings *model.SqlSettings
	// SqlStore is the raw SQL store created from SqlSettings in initStores.
	SqlStore *sqlstore.SqlStore
	// Store is the local cache layer built on top of SqlStore in initStores;
	// it is the store handed to the test functions.
	Store store.Store
}

// storeTypes lists the configured backends. It is filled by initStores and
// iterated by StoreTest, StoreTestWithSqlStore and tearDownStores.
var storeTypes []*storeType
25
26func newStoreType(name, driver string) *storeType {
27	return &storeType{
28		Name:        name,
29		SqlSettings: storetest.MakeSqlSettings(driver, false),
30	}
31}
32
33func StoreTest(t *testing.T, f func(*testing.T, store.Store)) {
34	defer func() {
35		if err := recover(); err != nil {
36			tearDownStores()
37			panic(err)
38		}
39	}()
40	for _, st := range storeTypes {
41		st := st
42		t.Run(st.Name, func(t *testing.T) {
43			if testing.Short() {
44				t.SkipNow()
45			}
46			f(t, st.Store)
47		})
48	}
49}
50
51func StoreTestWithSqlStore(t *testing.T, f func(*testing.T, store.Store, storetest.SqlStore)) {
52	defer func() {
53		if err := recover(); err != nil {
54			tearDownStores()
55			panic(err)
56		}
57	}()
58	for _, st := range storeTypes {
59		st := st
60		t.Run(st.Name, func(t *testing.T) {
61			if testing.Short() {
62				t.SkipNow()
63			}
64			f(t, st.Store, st.SqlStore)
65		})
66	}
67}
68
69func initStores() {
70	if testing.Short() {
71		return
72	}
73
74	// In CI, we already run the entire test suite for both mysql and postgres in parallel.
75	// So we just run the tests for the current database set.
76	if os.Getenv("IS_CI") == "true" {
77		switch os.Getenv("MM_SQLSETTINGS_DRIVERNAME") {
78		case "mysql":
79			storeTypes = append(storeTypes, newStoreType("LocalCache+MySQL", model.DatabaseDriverMysql))
80		case "postgres":
81			storeTypes = append(storeTypes, newStoreType("LocalCache+PostgreSQL", model.DatabaseDriverPostgres))
82		}
83	} else {
84		storeTypes = append(storeTypes, newStoreType("LocalCache+MySQL", model.DatabaseDriverMysql),
85			newStoreType("LocalCache+PostgreSQL", model.DatabaseDriverPostgres))
86	}
87
88	defer func() {
89		if err := recover(); err != nil {
90			tearDownStores()
91			panic(err)
92		}
93	}()
94	var wg sync.WaitGroup
95	for _, st := range storeTypes {
96		st := st
97		wg.Add(1)
98		go func() {
99			var err error
100			defer wg.Done()
101			st.SqlStore = sqlstore.New(*st.SqlSettings, nil)
102			st.Store, err = NewLocalCacheLayer(st.SqlStore, nil, nil, getMockCacheProvider())
103			if err != nil {
104				panic(err)
105			}
106			st.Store.DropAllTables()
107			st.Store.MarkSystemRanUnitTests()
108		}()
109	}
110	wg.Wait()
111}
112
113var tearDownStoresOnce sync.Once
114
115func tearDownStores() {
116	if testing.Short() {
117		return
118	}
119	tearDownStoresOnce.Do(func() {
120		var wg sync.WaitGroup
121		wg.Add(len(storeTypes))
122		for _, st := range storeTypes {
123			st := st
124			go func() {
125				if st.Store != nil {
126					st.Store.Close()
127				}
128				if st.SqlSettings != nil {
129					storetest.CleanupSqlSettings(st.SqlSettings)
130				}
131				wg.Done()
132			}()
133		}
134		wg.Wait()
135	})
136}
137