1package cassandra
2
3import (
4	"context"
5	"fmt"
6	"os"
7	"strconv"
8	"testing"
9	"time"
10
11	"github.com/gocql/gocql"
12	"github.com/hashicorp/errwrap"
13	"github.com/hashicorp/vault/helper/testhelpers/docker"
14	"github.com/hashicorp/vault/sdk/database/dbplugin"
15	"github.com/ory/dockertest"
16)
17
18func prepareCassandraTestContainer(t *testing.T) (func(), string, int) {
19	if os.Getenv("CASSANDRA_HOST") != "" {
20		return func() {}, os.Getenv("CASSANDRA_HOST"), 0
21	}
22
23	pool, err := dockertest.NewPool("")
24	if err != nil {
25		t.Fatalf("Failed to connect to docker: %s", err)
26	}
27
28	cwd, _ := os.Getwd()
29	cassandraMountPath := fmt.Sprintf("%s/test-fixtures/:/etc/cassandra/", cwd)
30
31	ro := &dockertest.RunOptions{
32		Repository: "cassandra",
33		Tag:        "latest",
34		Env:        []string{"CASSANDRA_BROADCAST_ADDRESS=127.0.0.1"},
35		Mounts:     []string{cassandraMountPath},
36	}
37	resource, err := pool.RunWithOptions(ro)
38	if err != nil {
39		t.Fatalf("Could not start local cassandra docker container: %s", err)
40	}
41
42	cleanup := func() {
43		docker.CleanupResource(t, pool, resource)
44	}
45
46	port, _ := strconv.Atoi(resource.GetPort("9042/tcp"))
47	address := fmt.Sprintf("127.0.0.1:%d", port)
48
49	// exponential backoff-retry
50	if err = pool.Retry(func() error {
51		clusterConfig := gocql.NewCluster(address)
52		clusterConfig.Authenticator = gocql.PasswordAuthenticator{
53			Username: "cassandra",
54			Password: "cassandra",
55		}
56		clusterConfig.ProtoVersion = 4
57		clusterConfig.Port = port
58
59		session, err := clusterConfig.CreateSession()
60		if err != nil {
61			return errwrap.Wrapf("error creating session: {{err}}", err)
62		}
63		defer session.Close()
64		return nil
65	}); err != nil {
66		cleanup()
67		t.Fatalf("Could not connect to cassandra docker container: %s", err)
68	}
69	return cleanup, address, port
70}
71
72func TestCassandra_Initialize(t *testing.T) {
73	if os.Getenv("VAULT_ACC") == "" {
74		t.SkipNow()
75	}
76	cleanup, address, port := prepareCassandraTestContainer(t)
77	defer cleanup()
78
79	connectionDetails := map[string]interface{}{
80		"hosts":            address,
81		"port":             port,
82		"username":         "cassandra",
83		"password":         "cassandra",
84		"protocol_version": 4,
85	}
86
87	db := new()
88	_, err := db.Init(context.Background(), connectionDetails, true)
89	if err != nil {
90		t.Fatalf("err: %s", err)
91	}
92
93	if !db.Initialized {
94		t.Fatal("Database should be initialized")
95	}
96
97	err = db.Close()
98	if err != nil {
99		t.Fatalf("err: %s", err)
100	}
101
102	// test a string protocol
103	connectionDetails = map[string]interface{}{
104		"hosts":            address,
105		"port":             strconv.Itoa(port),
106		"username":         "cassandra",
107		"password":         "cassandra",
108		"protocol_version": "4",
109	}
110
111	_, err = db.Init(context.Background(), connectionDetails, true)
112	if err != nil {
113		t.Fatalf("err: %s", err)
114	}
115}
116
117func TestCassandra_CreateUser(t *testing.T) {
118	if os.Getenv("VAULT_ACC") == "" {
119		t.SkipNow()
120	}
121	cleanup, address, port := prepareCassandraTestContainer(t)
122	defer cleanup()
123
124	connectionDetails := map[string]interface{}{
125		"hosts":            address,
126		"port":             port,
127		"username":         "cassandra",
128		"password":         "cassandra",
129		"protocol_version": 4,
130	}
131
132	db := new()
133	_, err := db.Init(context.Background(), connectionDetails, true)
134	if err != nil {
135		t.Fatalf("err: %s", err)
136	}
137
138	statements := dbplugin.Statements{
139		Creation: []string{testCassandraRole},
140	}
141
142	usernameConfig := dbplugin.UsernameConfig{
143		DisplayName: "test",
144		RoleName:    "test",
145	}
146
147	username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
148	if err != nil {
149		t.Fatalf("err: %s", err)
150	}
151
152	if err := testCredsExist(t, address, port, username, password); err != nil {
153		t.Fatalf("Could not connect with new credentials: %s", err)
154	}
155}
156
157func TestMyCassandra_RenewUser(t *testing.T) {
158	if os.Getenv("VAULT_ACC") == "" {
159		t.SkipNow()
160	}
161	cleanup, address, port := prepareCassandraTestContainer(t)
162	defer cleanup()
163
164	connectionDetails := map[string]interface{}{
165		"hosts":            address,
166		"port":             port,
167		"username":         "cassandra",
168		"password":         "cassandra",
169		"protocol_version": 4,
170	}
171
172	db := new()
173	_, err := db.Init(context.Background(), connectionDetails, true)
174	if err != nil {
175		t.Fatalf("err: %s", err)
176	}
177
178	statements := dbplugin.Statements{
179		Creation: []string{testCassandraRole},
180	}
181
182	usernameConfig := dbplugin.UsernameConfig{
183		DisplayName: "test",
184		RoleName:    "test",
185	}
186
187	username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
188	if err != nil {
189		t.Fatalf("err: %s", err)
190	}
191
192	if err := testCredsExist(t, address, port, username, password); err != nil {
193		t.Fatalf("Could not connect with new credentials: %s", err)
194	}
195
196	err = db.RenewUser(context.Background(), statements, username, time.Now().Add(time.Minute))
197	if err != nil {
198		t.Fatalf("err: %s", err)
199	}
200}
201
202func TestCassandra_RevokeUser(t *testing.T) {
203	if os.Getenv("VAULT_ACC") == "" {
204		t.SkipNow()
205	}
206	cleanup, address, port := prepareCassandraTestContainer(t)
207	defer cleanup()
208
209	connectionDetails := map[string]interface{}{
210		"hosts":            address,
211		"port":             port,
212		"username":         "cassandra",
213		"password":         "cassandra",
214		"protocol_version": 4,
215	}
216
217	db := new()
218	_, err := db.Init(context.Background(), connectionDetails, true)
219	if err != nil {
220		t.Fatalf("err: %s", err)
221	}
222
223	statements := dbplugin.Statements{
224		Creation: []string{testCassandraRole},
225	}
226
227	usernameConfig := dbplugin.UsernameConfig{
228		DisplayName: "test",
229		RoleName:    "test",
230	}
231
232	username, password, err := db.CreateUser(context.Background(), statements, usernameConfig, time.Now().Add(time.Minute))
233	if err != nil {
234		t.Fatalf("err: %s", err)
235	}
236
237	if err = testCredsExist(t, address, port, username, password); err != nil {
238		t.Fatalf("Could not connect with new credentials: %s", err)
239	}
240
241	// Test default revoke statements
242	err = db.RevokeUser(context.Background(), statements, username)
243	if err != nil {
244		t.Fatalf("err: %s", err)
245	}
246
247	if err = testCredsExist(t, address, port, username, password); err == nil {
248		t.Fatal("Credentials were not revoked")
249	}
250}
251
252func TestCassandra_RotateRootCredentials(t *testing.T) {
253	if os.Getenv("VAULT_ACC") == "" {
254		t.SkipNow()
255	}
256	cleanup, address, port := prepareCassandraTestContainer(t)
257	defer cleanup()
258
259	connectionDetails := map[string]interface{}{
260		"hosts":            address,
261		"port":             port,
262		"username":         "cassandra",
263		"password":         "cassandra",
264		"protocol_version": 4,
265	}
266
267	db := new()
268
269	connProducer := db.cassandraConnectionProducer
270
271	_, err := db.Init(context.Background(), connectionDetails, true)
272	if err != nil {
273		t.Fatalf("err: %s", err)
274	}
275
276	if !connProducer.Initialized {
277		t.Fatal("Database should be initialized")
278	}
279
280	newConf, err := db.RotateRootCredentials(context.Background(), nil)
281	if err != nil {
282		t.Fatalf("err: %v", err)
283	}
284	if newConf["password"] == "cassandra" {
285		t.Fatal("password was not updated")
286	}
287
288	err = db.Close()
289	if err != nil {
290		t.Fatalf("err: %s", err)
291	}
292}
293
294func testCredsExist(t testing.TB, address string, port int, username, password string) error {
295	clusterConfig := gocql.NewCluster(address)
296	clusterConfig.Authenticator = gocql.PasswordAuthenticator{
297		Username: username,
298		Password: password,
299	}
300	clusterConfig.ProtoVersion = 4
301	clusterConfig.Port = port
302
303	session, err := clusterConfig.CreateSession()
304	if err != nil {
305		return errwrap.Wrapf("error creating session: {{err}}", err)
306	}
307	defer session.Close()
308	return nil
309}
310
// testCassandraRole holds the creation statements used by the user-lifecycle
// tests, as two newline-separated CQL statements:
//   - create a non-superuser login from the {{username}}/{{password}} templates;
//   - grant that user every permission on all keyspaces.
//
// The placeholders are presumably substituted by the plugin's CreateUser
// before execution; confirm against the plugin implementation.
const testCassandraRole = "CREATE USER '{{username}}' WITH PASSWORD '{{password}}' NOSUPERUSER;\n" +
	"GRANT ALL PERMISSIONS ON ALL KEYSPACES TO {{username}};"
313