1// Copyright 2020 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package lsp
6
7import (
8	"context"
9	"fmt"
10	"sync"
11	"testing"
12
13	"golang.org/x/tools/internal/lsp/protocol"
14)
15
16type fakeClient struct {
17	protocol.Client
18
19	token protocol.ProgressToken
20
21	mu                                        sync.Mutex
22	created, begun, reported, messages, ended int
23}
24
25func (c *fakeClient) checkToken(token protocol.ProgressToken) {
26	if token == nil {
27		panic("nil token in progress message")
28	}
29	if c.token != nil && c.token != token {
30		panic(fmt.Errorf("invalid token in progress message: got %v, want %v", token, c.token))
31	}
32}
33
34func (c *fakeClient) WorkDoneProgressCreate(ctx context.Context, params *protocol.WorkDoneProgressCreateParams) error {
35	c.mu.Lock()
36	defer c.mu.Unlock()
37	c.checkToken(params.Token)
38	c.created++
39	return nil
40}
41
42func (c *fakeClient) Progress(ctx context.Context, params *protocol.ProgressParams) error {
43	c.mu.Lock()
44	defer c.mu.Unlock()
45	c.checkToken(params.Token)
46	switch params.Value.(type) {
47	case *protocol.WorkDoneProgressBegin:
48		c.begun++
49	case *protocol.WorkDoneProgressReport:
50		c.reported++
51	case *protocol.WorkDoneProgressEnd:
52		c.ended++
53	default:
54		panic(fmt.Errorf("unknown progress value %T", params.Value))
55	}
56	return nil
57}
58
59func (c *fakeClient) ShowMessage(context.Context, *protocol.ShowMessageParams) error {
60	c.mu.Lock()
61	defer c.mu.Unlock()
62	c.messages++
63	return nil
64}
65
66func setup(token protocol.ProgressToken) (context.Context, *progressTracker, *fakeClient) {
67	c := &fakeClient{}
68	tracker := newProgressTracker(c)
69	tracker.supportsWorkDoneProgress = true
70	return context.Background(), tracker, c
71}
72
73func TestProgressTracker_Reporting(t *testing.T) {
74	for _, test := range []struct {
75		name                                            string
76		supported                                       bool
77		token                                           protocol.ProgressToken
78		wantReported, wantCreated, wantBegun, wantEnded int
79		wantMessages                                    int
80	}{
81		{
82			name:         "unsupported",
83			wantMessages: 2,
84		},
85		{
86			name:         "random token",
87			supported:    true,
88			wantCreated:  1,
89			wantBegun:    1,
90			wantReported: 1,
91			wantEnded:    1,
92		},
93		{
94			name:         "string token",
95			supported:    true,
96			token:        "token",
97			wantBegun:    1,
98			wantReported: 1,
99			wantEnded:    1,
100		},
101		{
102			name:         "numeric token",
103			supported:    true,
104			token:        1,
105			wantReported: 1,
106			wantBegun:    1,
107			wantEnded:    1,
108		},
109	} {
110		test := test
111		t.Run(test.name, func(t *testing.T) {
112			ctx, tracker, client := setup(test.token)
113			ctx, cancel := context.WithCancel(ctx)
114			defer cancel()
115			tracker.supportsWorkDoneProgress = test.supported
116			work := tracker.start(ctx, "work", "message", test.token, nil)
117			client.mu.Lock()
118			gotCreated, gotBegun := client.created, client.begun
119			client.mu.Unlock()
120			if gotCreated != test.wantCreated {
121				t.Errorf("got %d created tokens, want %d", gotCreated, test.wantCreated)
122			}
123			if gotBegun != test.wantBegun {
124				t.Errorf("got %d work begun, want %d", gotBegun, test.wantBegun)
125			}
126			// Ignore errors: this is just testing the reporting behavior.
127			work.report("report", 50)
128			client.mu.Lock()
129			gotReported := client.reported
130			client.mu.Unlock()
131			if gotReported != test.wantReported {
132				t.Errorf("got %d progress reports, want %d", gotReported, test.wantCreated)
133			}
134			work.end("done")
135			client.mu.Lock()
136			gotEnded, gotMessages := client.ended, client.messages
137			client.mu.Unlock()
138			if gotEnded != test.wantEnded {
139				t.Errorf("got %d ended reports, want %d", gotEnded, test.wantEnded)
140			}
141			if gotMessages != test.wantMessages {
142				t.Errorf("got %d messages, want %d", gotMessages, test.wantMessages)
143			}
144		})
145	}
146}
147
148func TestProgressTracker_Cancellation(t *testing.T) {
149	for _, token := range []protocol.ProgressToken{nil, 1, "a"} {
150		ctx, tracker, _ := setup(token)
151		var canceled bool
152		cancel := func() { canceled = true }
153		work := tracker.start(ctx, "work", "message", token, cancel)
154		if err := tracker.cancel(ctx, work.token); err != nil {
155			t.Fatal(err)
156		}
157		if !canceled {
158			t.Errorf("tracker.cancel(...): cancel not called")
159		}
160	}
161}
162