1package stats
2
3import (
4	"context"
5	"io"
6	"log"
7	"net"
8	"testing"
9
10	"github.com/stretchr/testify/require"
11	"google.golang.org/grpc"
12	"google.golang.org/grpc/test/bufconn"
13
14	"github.com/grafana/loki/pkg/logproto"
15)
16
17const bufSize = 1024 * 1024
18
19var lis *bufconn.Listener
20var server *grpc.Server
21
22func init() {
23	lis = bufconn.Listen(bufSize)
24	server = grpc.NewServer()
25}
26
27func bufDialer(context.Context, string) (net.Conn, error) {
28	return lis.Dial()
29}
30
31func TestCollectTrailer(t *testing.T) {
32	ctx := context.Background()
33	conn, err := grpc.DialContext(ctx, "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure())
34	if err != nil {
35		t.Fatalf("Failed to dial bufnet: %v", err)
36	}
37	defer conn.Close()
38	ing := ingesterFn(func(s grpc.ServerStream) error {
39		ingCtx := NewContext(s.Context())
40		defer SendAsTrailer(ingCtx, s)
41		GetIngesterData(ingCtx).TotalChunksMatched++
42		GetIngesterData(ingCtx).TotalBatches = +2
43		GetIngesterData(ingCtx).TotalLinesSent = +3
44		GetChunkData(ingCtx).HeadChunkBytes++
45		GetChunkData(ingCtx).HeadChunkLines++
46		GetChunkData(ingCtx).DecompressedBytes++
47		GetChunkData(ingCtx).DecompressedLines++
48		GetChunkData(ingCtx).CompressedBytes++
49		GetChunkData(ingCtx).TotalDuplicates++
50		return nil
51	})
52	logproto.RegisterQuerierServer(server, ing)
53	go func() {
54		if err := server.Serve(lis); err != nil {
55			log.Fatalf("Server exited with error: %v", err)
56		}
57	}()
58
59	ingClient := logproto.NewQuerierClient(conn)
60
61	ctx = NewContext(ctx)
62
63	// query the ingester twice once for logs , once for samples.
64	clientStream, err := ingClient.Query(ctx, &logproto.QueryRequest{}, CollectTrailer(ctx))
65	if err != nil {
66		t.Fatal(err)
67	}
68	_, err = clientStream.Recv()
69	if err != nil && err != io.EOF {
70		t.Fatal(err)
71	}
72	clientSamples, err := ingClient.QuerySample(ctx, &logproto.SampleQueryRequest{}, CollectTrailer(ctx))
73	if err != nil {
74		t.Fatal(err)
75	}
76	_, err = clientSamples.Recv()
77	if err != nil && err != io.EOF {
78		t.Fatal(err)
79	}
80	err = clientSamples.CloseSend()
81	if err != nil {
82		t.Fatal(err)
83	}
84	res := decodeTrailers(ctx)
85	require.Equal(t, int32(2), res.Ingester.TotalReached)
86	require.Equal(t, int64(2), res.Ingester.TotalChunksMatched)
87	require.Equal(t, int64(4), res.Ingester.TotalBatches)
88	require.Equal(t, int64(6), res.Ingester.TotalLinesSent)
89	require.Equal(t, int64(2), res.Ingester.HeadChunkBytes)
90	require.Equal(t, int64(2), res.Ingester.HeadChunkLines)
91	require.Equal(t, int64(2), res.Ingester.DecompressedBytes)
92	require.Equal(t, int64(2), res.Ingester.DecompressedLines)
93	require.Equal(t, int64(2), res.Ingester.CompressedBytes)
94	require.Equal(t, int64(2), res.Ingester.TotalDuplicates)
95}
96
97type ingesterFn func(grpc.ServerStream) error
98
99func (i ingesterFn) Query(_ *logproto.QueryRequest, s logproto.Querier_QueryServer) error {
100	return i(s)
101}
102
103func (i ingesterFn) QuerySample(_ *logproto.SampleQueryRequest, s logproto.Querier_QuerySampleServer) error {
104	return i(s)
105}
106func (ingesterFn) Label(context.Context, *logproto.LabelRequest) (*logproto.LabelResponse, error) {
107	return nil, nil
108}
109func (ingesterFn) Tail(*logproto.TailRequest, logproto.Querier_TailServer) error { return nil }
110func (ingesterFn) Series(context.Context, *logproto.SeriesRequest) (*logproto.SeriesResponse, error) {
111	return nil, nil
112}
113func (ingesterFn) TailersCount(context.Context, *logproto.TailersCountRequest) (*logproto.TailersCountResponse, error) {
114	return nil, nil
115}
116
117func (i ingesterFn) GetChunkIDs(ctx context.Context, request *logproto.GetChunkIDsRequest) (*logproto.GetChunkIDsResponse, error) {
118	return nil, nil
119}
120