1package cloudwatch
2
3import (
4	"context"
5	"encoding/json"
6	"fmt"
7	"sync"
8	"time"
9
10	"github.com/aws/aws-sdk-go/aws"
11	"github.com/aws/aws-sdk-go/aws/request"
12	"github.com/aws/aws-sdk-go/aws/session"
13	"github.com/aws/aws-sdk-go/service/cloudwatchlogs"
14	"github.com/aws/aws-sdk-go/service/servicequotas"
15	"github.com/aws/aws-sdk-go/service/servicequotas/servicequotasiface"
16	"github.com/google/uuid"
17	"github.com/grafana/grafana-plugin-sdk-go/backend"
18	"github.com/grafana/grafana-plugin-sdk-go/data"
19	"github.com/grafana/grafana/pkg/components/simplejson"
20	"github.com/grafana/grafana/pkg/models"
21	"github.com/grafana/grafana/pkg/setting"
22	"github.com/grafana/grafana/pkg/util/retryer"
23	"golang.org/x/sync/errgroup"
24)
25
// defaultConcurrentQueries is the capacity given to a live-query queue when the
// CloudWatch Logs concurrent queries quota cannot be read from the Service Quotas API
// (see fetchConcurrentQueriesQuota).
const defaultConcurrentQueries = 4
27
// LogQueryRunnerSupplier hands out live channel handlers (see GetHandlerForPath) that
// stream CloudWatch Logs query results to subscribers.
type LogQueryRunnerSupplier struct {
	// Publisher publishes JSON-encoded query responses to a live channel.
	Publisher models.ChannelPublisher
	// Service owns the response channels that live log queries write into.
	Service   *LogsService
}
32
// logQueryRunner is the live channel handler created by GetHandlerForPath. It forwards
// responses from a LogsService response channel to the matching live channel.
type logQueryRunner struct {
	// channelName is the path the handler was created for.
	channelName string
	publish     models.ChannelPublisher
	// running holds the channels that currently have a publishResults goroutine;
	// it must only be accessed with runningMu held.
	running     map[string]bool
	runningMu   sync.Mutex
	service     *LogsService
}
40
// Polling parameters passed to retryer.Retry when startLiveQuery polls GetQueryResults.
// NOTE(review): the exact meaning of maxAttempts (consecutive failures vs. total polls)
// is defined by retryer.Retry — confirm there before tuning.
const (
	maxAttempts   = 8
	minRetryDelay = 500 * time.Millisecond
	maxRetryDelay = 30 * time.Second
)
46
47// GetHandlerForPath gets the channel handler for a certain path.
48func (s *LogQueryRunnerSupplier) GetHandlerForPath(path string) (models.ChannelHandler, error) {
49	return &logQueryRunner{
50		channelName: path,
51		publish:     s.Publisher,
52		running:     make(map[string]bool),
53		service:     s.Service,
54	}, nil
55}
56
57// OnSubscribe publishes results from the corresponding CloudWatch Logs query to the provided channel
58func (r *logQueryRunner) OnSubscribe(ctx context.Context, user *models.SignedInUser, e models.SubscribeEvent) (models.SubscribeReply, backend.SubscribeStreamStatus, error) {
59	r.runningMu.Lock()
60	defer r.runningMu.Unlock()
61
62	if _, ok := r.running[e.Channel]; ok {
63		return models.SubscribeReply{}, backend.SubscribeStreamStatusOK, nil
64	}
65
66	r.running[e.Channel] = true
67	go func() {
68		if err := r.publishResults(user.OrgId, e.Channel); err != nil {
69			plog.Error(err.Error())
70		}
71	}()
72
73	return models.SubscribeReply{}, backend.SubscribeStreamStatusOK, nil
74}
75
76// OnPublish checks if a message from the websocket can be broadcast on this channel
77func (r *logQueryRunner) OnPublish(ctx context.Context, user *models.SignedInUser, e models.PublishEvent) (models.PublishReply, backend.PublishStreamStatus, error) {
78	return models.PublishReply{}, backend.PublishStreamStatusPermissionDenied, nil
79}
80
81func (r *logQueryRunner) publishResults(orgID int64, channelName string) error {
82	defer func() {
83		r.service.DeleteResponseChannel(channelName)
84		r.runningMu.Lock()
85		delete(r.running, channelName)
86		r.runningMu.Unlock()
87	}()
88
89	responseChannel, err := r.service.GetResponseChannel(channelName)
90	if err != nil {
91		return err
92	}
93
94	for response := range responseChannel {
95		responseBytes, err := json.Marshal(response)
96		if err != nil {
97			return err
98		}
99
100		if err := r.publish(orgID, channelName, responseBytes); err != nil {
101			return err
102		}
103	}
104
105	return nil
106}
107
108// executeLiveLogQuery executes a CloudWatch Logs query with live updates over WebSocket.
109// A WebSocket channel is created, which goroutines send responses over.
110func (e *cloudWatchExecutor) executeLiveLogQuery(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) {
111	responseChannelName := uuid.New().String()
112	responseChannel := make(chan *backend.QueryDataResponse)
113	if err := e.logsService.AddResponseChannel("plugin/cloudwatch/"+responseChannelName, responseChannel); err != nil {
114		close(responseChannel)
115		return nil, err
116	}
117
118	go e.sendLiveQueriesToChannel(req, responseChannel)
119
120	response := &backend.QueryDataResponse{
121		Responses: backend.Responses{
122			"A": {
123				Frames: data.Frames{data.NewFrame("A").SetMeta(&data.FrameMeta{
124					Custom: map[string]interface{}{
125						"channelName": responseChannelName,
126					},
127				})},
128			},
129		},
130	}
131
132	return response, nil
133}
134
135func (e *cloudWatchExecutor) sendLiveQueriesToChannel(req *backend.QueryDataRequest, responseChannel chan *backend.QueryDataResponse) {
136	defer close(responseChannel)
137
138	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute)
139	defer cancel()
140	eg, ectx := errgroup.WithContext(ctx)
141
142	for _, query := range req.Queries {
143		query := query
144		eg.Go(func() error {
145			return e.startLiveQuery(ectx, responseChannel, query, query.TimeRange, req.PluginContext)
146		})
147	}
148
149	if err := eg.Wait(); err != nil {
150		plog.Error(err.Error())
151	}
152}
153
154func (e *cloudWatchExecutor) getQueue(queueKey string, pluginCtx backend.PluginContext) (chan bool, error) {
155	e.logsService.queueLock.Lock()
156	defer e.logsService.queueLock.Unlock()
157
158	if queue, ok := e.logsService.queues[queueKey]; ok {
159		return queue, nil
160	}
161
162	concurrentQueriesQuota := e.fetchConcurrentQueriesQuota(queueKey, pluginCtx)
163
164	queueChannel := make(chan bool, concurrentQueriesQuota)
165	e.logsService.queues[queueKey] = queueChannel
166
167	return queueChannel, nil
168}
169
170func (e *cloudWatchExecutor) fetchConcurrentQueriesQuota(region string, pluginCtx backend.PluginContext) int {
171	sess, err := e.newSession(region, pluginCtx)
172	if err != nil {
173		plog.Warn("Could not get service quota client")
174		return defaultConcurrentQueries
175	}
176
177	client := newQuotasClient(sess)
178
179	concurrentQueriesQuota, err := client.GetServiceQuota(&servicequotas.GetServiceQuotaInput{
180		ServiceCode: aws.String("logs"),
181		QuotaCode:   aws.String("L-32C48FBB"),
182	})
183	if err != nil {
184		plog.Warn("Could not get service quota")
185		return defaultConcurrentQueries
186	}
187
188	if concurrentQueriesQuota != nil && concurrentQueriesQuota.Quota != nil && concurrentQueriesQuota.Quota.Value != nil {
189		return int(*concurrentQueriesQuota.Quota.Value)
190	}
191
192	plog.Warn("Could not get service quota")
193
194	defaultConcurrentQueriesQuota, err := client.GetAWSDefaultServiceQuota(&servicequotas.GetAWSDefaultServiceQuotaInput{
195		ServiceCode: aws.String("logs"),
196		QuotaCode:   aws.String("L-32C48FBB"),
197	})
198	if err != nil {
199		plog.Warn("Could not get default service quota")
200		return defaultConcurrentQueries
201	}
202
203	if defaultConcurrentQueriesQuota != nil && defaultConcurrentQueriesQuota.Quota != nil &&
204		defaultConcurrentQueriesQuota.Quota.Value != nil {
205		return int(*defaultConcurrentQueriesQuota.Quota.Value)
206	}
207
208	plog.Warn("Could not get default service quota")
209	return defaultConcurrentQueries
210}
211
212func (e *cloudWatchExecutor) startLiveQuery(ctx context.Context, responseChannel chan *backend.QueryDataResponse, query backend.DataQuery, timeRange backend.TimeRange, pluginCtx backend.PluginContext) error {
213	model, err := simplejson.NewJson(query.JSON)
214	if err != nil {
215		return err
216	}
217
218	dsInfo, err := e.getDSInfo(pluginCtx)
219	if err != nil {
220		return err
221	}
222
223	defaultRegion := dsInfo.region
224	region := model.Get("region").MustString(defaultRegion)
225	logsClient, err := e.getCWLogsClient(region, pluginCtx)
226	if err != nil {
227		return err
228	}
229
230	queue, err := e.getQueue(fmt.Sprintf("%s-%d", region, dsInfo.datasourceID), pluginCtx)
231	if err != nil {
232		return err
233	}
234
235	// Wait until there are no more active workers than the concurrent queries quota
236	queue <- true
237	defer func() { <-queue }()
238
239	startQueryOutput, err := e.executeStartQuery(ctx, logsClient, model, timeRange)
240	if err != nil {
241		responseChannel <- &backend.QueryDataResponse{
242			Responses: backend.Responses{
243				query.RefID: {Error: err},
244			},
245		}
246		return err
247	}
248
249	queryResultsInput := &cloudwatchlogs.GetQueryResultsInput{
250		QueryId: startQueryOutput.QueryId,
251	}
252
253	recordsMatched := 0.0
254	return retryer.Retry(func() (retryer.RetrySignal, error) {
255		getQueryResultsOutput, err := logsClient.GetQueryResultsWithContext(ctx, queryResultsInput)
256		if err != nil {
257			return retryer.FuncError, err
258		}
259
260		retryNeeded := *getQueryResultsOutput.Statistics.RecordsMatched <= recordsMatched
261		recordsMatched = *getQueryResultsOutput.Statistics.RecordsMatched
262
263		dataFrame, err := logsResultsToDataframes(getQueryResultsOutput)
264		if err != nil {
265			return retryer.FuncError, err
266		}
267
268		dataFrame.Name = query.RefID
269		dataFrame.RefID = query.RefID
270		dataFrames, err := groupResponseFrame(dataFrame, model.Get("statsGroups").MustStringArray())
271		if err != nil {
272			return retryer.FuncError, fmt.Errorf("failed to group dataframe response: %v", err)
273		}
274
275		responseChannel <- &backend.QueryDataResponse{
276			Responses: backend.Responses{
277				query.RefID: {
278					Frames: dataFrames,
279				},
280			},
281		}
282
283		if isTerminated(*getQueryResultsOutput.Status) {
284			return retryer.FuncComplete, nil
285		} else if retryNeeded {
286			return retryer.FuncFailure, nil
287		}
288
289		return retryer.FuncSuccess, nil
290	}, maxAttempts, minRetryDelay, maxRetryDelay)
291}
292
293func groupResponseFrame(frame *data.Frame, statsGroups []string) (data.Frames, error) {
294	var dataFrames data.Frames
295
296	// When a query of the form "stats ... by ..." is made, we want to return
297	// one series per group defined in the query, but due to the format
298	// the query response is in, there does not seem to be a way to tell
299	// by the response alone if/how the results should be grouped.
300	// Because of this, if the frontend sees that a "stats ... by ..." query is being made
301	// the "statsGroups" parameter is sent along with the query to the backend so that we
302	// can correctly group the CloudWatch logs response.
303	// Check if we have time field though as it makes sense to split only for time series.
304	if hasTimeField(frame) {
305		if len(statsGroups) > 0 && len(frame.Fields) > 0 {
306			groupedFrames, err := groupResults(frame, statsGroups)
307			if err != nil {
308				return nil, err
309			}
310
311			dataFrames = groupedFrames
312		} else {
313			setPreferredVisType(frame, "logs")
314			dataFrames = data.Frames{frame}
315		}
316	} else {
317		dataFrames = data.Frames{frame}
318	}
319	return dataFrames, nil
320}
321
322func hasTimeField(frame *data.Frame) bool {
323	for _, field := range frame.Fields {
324		if field.Type() == data.FieldTypeNullableTime {
325			return true
326		}
327	}
328	return false
329}
330
331func setPreferredVisType(frame *data.Frame, visType data.VisType) {
332	if frame.Meta != nil {
333		frame.Meta.PreferredVisualization = visType
334	} else {
335		frame.Meta = &data.FrameMeta{
336			PreferredVisualization: visType,
337		}
338	}
339}
340
341// Service quotas client factory.
342//
343// Stubbable by tests.
344var newQuotasClient = func(sess *session.Session) servicequotasiface.ServiceQuotasAPI {
345	client := servicequotas.New(sess)
346	client.Handlers.Send.PushFront(func(r *request.Request) {
347		r.HTTPRequest.Header.Set("User-Agent", fmt.Sprintf("Grafana/%s", setting.BuildVersion))
348	})
349
350	return client
351}
352