1// Licensed to Elasticsearch B.V. under one or more contributor
2// license agreements. See the NOTICE file distributed with
3// this work for additional information regarding copyright
4// ownership. Elasticsearch B.V. licenses this file to you under
5// the Apache License, Version 2.0 (the "License"); you may
6// not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9//     http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18// +build go1.9
19
20package apmgocql_test
21
22import (
23	"context"
24	"errors"
25	"os"
26	"testing"
27	"time"
28
29	"github.com/gocql/gocql"
30	"github.com/stretchr/testify/assert"
31	"github.com/stretchr/testify/require"
32
33	"go.elastic.co/apm/apmtest"
34	"go.elastic.co/apm/model"
35	"go.elastic.co/apm/module/apmgocql"
36)
37
// createKeyspaceStatement is the CQL executed by the integration tests to
// make sure the "foo" keyspace exists before tables are created in it. The
// tests also compare it verbatim against the statement recorded on the span,
// so its exact bytes (including the leading newline) matter.
const createKeyspaceStatement = `
CREATE KEYSPACE IF NOT EXISTS foo
WITH REPLICATION = {
	'class' : 'SimpleStrategy',
	'replication_factor' : 1
};`
46
// cassandraHost is the value of the CASSANDRA_HOST environment variable, read
// once at package initialization. It is the address handed to
// gocql.NewCluster by newSession; when it is empty, newSession skips the
// calling test.
var cassandraHost = os.Getenv("CASSANDRA_HOST")
48
49func TestQueryObserver(t *testing.T) {
50	var start time.Time
51	observer := apmgocql.NewObserver()
52	_, spans, errors := apmtest.WithTransaction(func(ctx context.Context) {
53		start = time.Now()
54		observer.ObserveQuery(ctx, gocql.ObservedQuery{
55			Start:     start,
56			End:       start.Add(3 * time.Second),
57			Keyspace:  "quay ",
58			Statement: "SELECT * FROM foo.bar",
59			Err:       errors.New("baz"),
60		})
61	})
62
63	require.Len(t, spans, 1)
64	assert.Equal(t, "db", spans[0].Type)
65	assert.Equal(t, "cassandra", spans[0].Subtype)
66	assert.Equal(t, "query", spans[0].Action)
67	assert.Equal(t, "SELECT FROM foo.bar", spans[0].Name)
68	assert.WithinDuration(t,
69		time.Time(spans[0].Timestamp).Add(time.Duration(spans[0].Duration*1000000)),
70		start.Add(3*time.Second),
71		100*time.Millisecond, // allow some leeway for slow systems
72	)
73	assert.Equal(t, &model.SpanContext{
74		Database: &model.DatabaseSpanContext{
75			Type:      "cassandra",
76			Instance:  "quay ",
77			Statement: "SELECT * FROM foo.bar",
78		},
79	}, spans[0].Context)
80
81	require.Len(t, errors, 1)
82	assert.Equal(t, "TestQueryObserver.func1", errors[0].Culprit)
83}
84
85func TestBatchObserver(t *testing.T) {
86	var start time.Time
87	observer := apmgocql.NewObserver()
88	_, spans, errors := apmtest.WithTransaction(func(ctx context.Context) {
89		start = time.Now()
90		observer.ObserveBatch(ctx, gocql.ObservedBatch{
91			Start:    start,
92			End:      start.Add(3 * time.Second),
93			Keyspace: "quay ",
94			Statements: []string{
95				"INSERT INTO foo.bar(id) VALUES(1)",
96				"UPDATE foo.bar SET id=2",
97			},
98			Err: errors.New("baz"),
99		})
100	})
101
102	require.Len(t, spans, 3)
103	assert.Equal(t, "db", spans[2].Type)
104	assert.Equal(t, "cassandra", spans[2].Subtype)
105	assert.Equal(t, "batch", spans[2].Action) // sent last
106	for _, span := range spans[:2] {
107		assert.Equal(t, spans[2].ID, span.ParentID)
108		assert.Equal(t, spans[2].TraceID, span.TraceID)
109		assert.Equal(t, "db", span.Type)
110		assert.Equal(t, "cassandra", span.Subtype)
111		assert.Equal(t, "query", span.Action)
112	}
113
114	assert.Equal(t, "INSERT INTO foo.bar", spans[0].Name)
115	assert.Equal(t, "UPDATE foo.bar", spans[1].Name)
116	assert.Equal(t, "BATCH", spans[2].Name)
117
118	assert.Equal(t, &model.SpanContext{
119		Database: &model.DatabaseSpanContext{
120			Type:     "cassandra",
121			Instance: "quay ",
122		},
123	}, spans[2].Context)
124
125	assert.Equal(t, &model.SpanContext{
126		Database: &model.DatabaseSpanContext{
127			Type:      "cassandra",
128			Instance:  "quay ",
129			Statement: "INSERT INTO foo.bar(id) VALUES(1)",
130		},
131	}, spans[0].Context)
132
133	require.Len(t, errors, 1)
134	assert.Equal(t, "TestBatchObserver.func1", errors[0].Culprit)
135}
136
137func TestQueryObserverIntegration(t *testing.T) {
138	session := newSession(t)
139	defer session.Close()
140
141	_, spans, _ := apmtest.WithTransaction(func(ctx context.Context) {
142		err := execQuery(ctx, session, createKeyspaceStatement)
143		assert.NoError(t, err)
144
145		err = execQuery(ctx, session, `CREATE TABLE IF NOT EXISTS foo.bar (id int, PRIMARY KEY(id));`)
146		assert.NoError(t, err)
147
148		err = execQuery(ctx, session, "INSERT INTO foo.bar(id) VALUES(1)")
149		assert.NoError(t, err)
150	})
151
152	require.Len(t, spans, 3)
153	for _, span := range spans {
154		assert.Equal(t, "db", span.Type)
155		assert.Equal(t, "cassandra", span.Subtype)
156		assert.Equal(t, "query", span.Action)
157	}
158	assert.Equal(t, "CREATE", spans[0].Name)
159	assert.Equal(t, &model.SpanContext{
160		Database: &model.DatabaseSpanContext{
161			Type:      "cassandra",
162			Statement: createKeyspaceStatement,
163		},
164	}, spans[0].Context)
165	assert.Equal(t, "CREATE", spans[1].Name)
166	assert.Equal(t, "INSERT INTO foo.bar", spans[2].Name)
167}
168
169func TestBatchObserverIntegration(t *testing.T) {
170	session := newSession(t)
171	defer session.Close()
172
173	err := execQuery(context.Background(), session, createKeyspaceStatement)
174	assert.NoError(t, err)
175
176	err = execQuery(context.Background(), session, `CREATE TABLE IF NOT EXISTS foo.bar (id int, PRIMARY KEY(id));`)
177	assert.NoError(t, err)
178
179	tx, spans, _ := apmtest.WithTransaction(func(ctx context.Context) {
180		batch := session.NewBatch(gocql.LoggedBatch).WithContext(ctx)
181		batch.Query("INSERT INTO foo.bar(id) VALUES(1)")
182		batch.Query("INSERT INTO foo.bar(id) VALUES(2)")
183		err := session.ExecuteBatch(batch)
184		assert.NoError(t, err)
185	})
186
187	require.Len(t, spans, 3)
188	assert.Equal(t, tx.ID, spans[2].ParentID)
189	assert.Equal(t, tx.TraceID, spans[2].TraceID)
190	for _, span := range spans[:2] {
191		assert.Equal(t, spans[2].ID, span.ParentID)
192		assert.Equal(t, spans[2].TraceID, span.TraceID)
193	}
194
195	assert.Equal(t, "INSERT INTO foo.bar", spans[0].Name)
196	assert.Equal(t, "INSERT INTO foo.bar", spans[1].Name)
197	assert.Equal(t, "BATCH", spans[2].Name)
198
199	assert.Equal(t, &model.SpanContext{
200		Database: &model.DatabaseSpanContext{
201			Type: "cassandra",
202		},
203	}, spans[2].Context)
204
205	assert.Equal(t, &model.SpanContext{
206		Database: &model.DatabaseSpanContext{
207			Type:      "cassandra",
208			Statement: "INSERT INTO foo.bar(id) VALUES(1)",
209		},
210	}, spans[0].Context)
211}
212
213func TestQueryObserverErrorIntegration(t *testing.T) {
214	session := newSession(t)
215	defer session.Close()
216
217	var queryError error
218	_, spans, errors := apmtest.WithTransaction(func(ctx context.Context) {
219		queryError = execQuery(ctx, session, "ZINGA")
220	})
221	require.Len(t, errors, 1)
222	require.Len(t, spans, 1)
223
224	assert.Equal(t, errors[0].Culprit, "execQuery")
225	assert.EqualError(t, queryError, errors[0].Exception.Message)
226}
227
228func execQuery(ctx context.Context, session *gocql.Session, query string) error {
229	return session.Query(query).WithContext(ctx).Exec()
230}
231
232func newSession(t *testing.T) *gocql.Session {
233	if cassandraHost == "" {
234		t.Skipf("CASSANDRA_HOST not specified")
235	}
236	observer := apmgocql.NewObserver()
237	config := gocql.NewCluster(cassandraHost)
238	config.QueryObserver = observer
239	config.BatchObserver = observer
240	session, err := config.CreateSession()
241	require.NoError(t, err)
242	return session
243}
244