1// Copyright 2020 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package regtest
6
7import (
8	"context"
9	"fmt"
10	"strings"
11	"sync"
12	"testing"
13
14	"golang.org/x/tools/internal/jsonrpc2/servertest"
15	"golang.org/x/tools/internal/lsp/fake"
16	"golang.org/x/tools/internal/lsp/protocol"
17)
18
19// Env holds an initialized fake Editor, Workspace, and Server, which may be
20// used for writing tests. It also provides adapter methods that call t.Fatal
21// on any error, so that tests for the happy path may be written without
22// checking errors.
23type Env struct {
24	T   testing.TB
25	Ctx context.Context
26
27	// Most tests should not need to access the scratch area, editor, server, or
28	// connection, but they are available if needed.
29	Sandbox *fake.Sandbox
30	Editor  *fake.Editor
31	Server  servertest.Connector
32
33	// mu guards the fields below, for the purpose of checking conditions on
34	// every change to diagnostics.
35	mu sync.Mutex
36	// For simplicity, each waiter gets a unique ID.
37	nextWaiterID int
38	state        State
39	waiters      map[int]*condition
40}
41
42// State encapsulates the server state TODO: explain more
43type State struct {
44	// diagnostics are a map of relative path->diagnostics params
45	diagnostics        map[string]*protocol.PublishDiagnosticsParams
46	logs               []*protocol.LogMessageParams
47	showMessage        []*protocol.ShowMessageParams
48	showMessageRequest []*protocol.ShowMessageRequestParams
49
50	registrations   []*protocol.RegistrationParams
51	unregistrations []*protocol.UnregistrationParams
52
53	// outstandingWork is a map of token->work summary. All tokens are assumed to
54	// be string, though the spec allows for numeric tokens as well.  When work
55	// completes, it is deleted from this map.
56	outstandingWork map[protocol.ProgressToken]*workProgress
57	startedWork     map[string]uint64
58	completedWork   map[string]uint64
59}
60
// workProgress summarizes a single unit of work reported through progress
// notifications.
type workProgress struct {
	// title is taken from the "begin" report; it stays empty until then.
	title string
	// msg is the most recent "message" value from a "begin" or "report".
	msg string
	// percent is the most recent "percentage" value from a "report".
	percent float64
}
65
66func (s State) String() string {
67	var b strings.Builder
68	b.WriteString("#### log messages (see RPC logs for full text):\n")
69	for _, msg := range s.logs {
70		summary := fmt.Sprintf("%v: %q", msg.Type, msg.Message)
71		if len(summary) > 60 {
72			summary = summary[:57] + "..."
73		}
74		// Some logs are quite long, and since they should be reproduced in the RPC
75		// logs on any failure we include here just a short summary.
76		fmt.Fprint(&b, "\t"+summary+"\n")
77	}
78	b.WriteString("\n")
79	b.WriteString("#### diagnostics:\n")
80	for name, params := range s.diagnostics {
81		fmt.Fprintf(&b, "\t%s (version %d):\n", name, int(params.Version))
82		for _, d := range params.Diagnostics {
83			fmt.Fprintf(&b, "\t\t(%d, %d): %s\n", int(d.Range.Start.Line), int(d.Range.Start.Character), d.Message)
84		}
85	}
86	b.WriteString("\n")
87	b.WriteString("#### outstanding work:\n")
88	for token, state := range s.outstandingWork {
89		name := state.title
90		if name == "" {
91			name = fmt.Sprintf("!NO NAME(token: %s)", token)
92		}
93		fmt.Fprintf(&b, "\t%s: %.2f\n", name, state.percent)
94	}
95	b.WriteString("#### completed work:\n")
96	for name, count := range s.completedWork {
97		fmt.Fprintf(&b, "\t%s: %d\n", name, count)
98	}
99	return b.String()
100}
101
102// A condition is satisfied when all expectations are simultaneously
103// met. At that point, the 'met' channel is closed. On any failure, err is set
104// and the failed channel is closed.
105type condition struct {
106	expectations []Expectation
107	verdict      chan Verdict
108}
109
110// NewEnv creates a new test environment using the given scratch environment
111// and gopls server.
112func NewEnv(ctx context.Context, tb testing.TB, sandbox *fake.Sandbox, ts servertest.Connector, editorConfig fake.EditorConfig, withHooks bool) *Env {
113	tb.Helper()
114	conn := ts.Connect(ctx)
115	env := &Env{
116		T:       tb,
117		Ctx:     ctx,
118		Sandbox: sandbox,
119		Server:  ts,
120		state: State{
121			diagnostics:     make(map[string]*protocol.PublishDiagnosticsParams),
122			outstandingWork: make(map[protocol.ProgressToken]*workProgress),
123			startedWork:     make(map[string]uint64),
124			completedWork:   make(map[string]uint64),
125		},
126		waiters: make(map[int]*condition),
127	}
128	var hooks fake.ClientHooks
129	if withHooks {
130		hooks = fake.ClientHooks{
131			OnDiagnostics:            env.onDiagnostics,
132			OnLogMessage:             env.onLogMessage,
133			OnWorkDoneProgressCreate: env.onWorkDoneProgressCreate,
134			OnProgress:               env.onProgress,
135			OnShowMessage:            env.onShowMessage,
136			OnShowMessageRequest:     env.onShowMessageRequest,
137			OnRegistration:           env.onRegistration,
138			OnUnregistration:         env.onUnregistration,
139		}
140	}
141	editor, err := fake.NewEditor(sandbox, editorConfig).Connect(ctx, conn, hooks)
142	if err != nil {
143		tb.Fatal(err)
144	}
145	env.Editor = editor
146	return env
147}
148
149func (e *Env) onDiagnostics(_ context.Context, d *protocol.PublishDiagnosticsParams) error {
150	e.mu.Lock()
151	defer e.mu.Unlock()
152
153	pth := e.Sandbox.Workdir.URIToPath(d.URI)
154	e.state.diagnostics[pth] = d
155	e.checkConditionsLocked()
156	return nil
157}
158
159func (e *Env) onShowMessage(_ context.Context, m *protocol.ShowMessageParams) error {
160	e.mu.Lock()
161	defer e.mu.Unlock()
162
163	e.state.showMessage = append(e.state.showMessage, m)
164	e.checkConditionsLocked()
165	return nil
166}
167
168func (e *Env) onShowMessageRequest(_ context.Context, m *protocol.ShowMessageRequestParams) error {
169	e.mu.Lock()
170	defer e.mu.Unlock()
171
172	e.state.showMessageRequest = append(e.state.showMessageRequest, m)
173	e.checkConditionsLocked()
174	return nil
175}
176
177func (e *Env) onLogMessage(_ context.Context, m *protocol.LogMessageParams) error {
178	e.mu.Lock()
179	defer e.mu.Unlock()
180
181	e.state.logs = append(e.state.logs, m)
182	e.checkConditionsLocked()
183	return nil
184}
185
186func (e *Env) onWorkDoneProgressCreate(_ context.Context, m *protocol.WorkDoneProgressCreateParams) error {
187	e.mu.Lock()
188	defer e.mu.Unlock()
189
190	e.state.outstandingWork[m.Token] = &workProgress{}
191	return nil
192}
193
194func (e *Env) onProgress(_ context.Context, m *protocol.ProgressParams) error {
195	e.mu.Lock()
196	defer e.mu.Unlock()
197	work, ok := e.state.outstandingWork[m.Token]
198	if !ok {
199		panic(fmt.Sprintf("got progress report for unknown report %v: %v", m.Token, m))
200	}
201	v := m.Value.(map[string]interface{})
202	switch kind := v["kind"]; kind {
203	case "begin":
204		work.title = v["title"].(string)
205		e.state.startedWork[work.title] = e.state.startedWork[work.title] + 1
206		if msg, ok := v["message"]; ok {
207			work.msg = msg.(string)
208		}
209	case "report":
210		if pct, ok := v["percentage"]; ok {
211			work.percent = pct.(float64)
212		}
213		if msg, ok := v["message"]; ok {
214			work.msg = msg.(string)
215		}
216	case "end":
217		title := e.state.outstandingWork[m.Token].title
218		e.state.completedWork[title] = e.state.completedWork[title] + 1
219		delete(e.state.outstandingWork, m.Token)
220	}
221	e.checkConditionsLocked()
222	return nil
223}
224
225func (e *Env) onRegistration(_ context.Context, m *protocol.RegistrationParams) error {
226	e.mu.Lock()
227	defer e.mu.Unlock()
228
229	e.state.registrations = append(e.state.registrations, m)
230	e.checkConditionsLocked()
231	return nil
232}
233
234func (e *Env) onUnregistration(_ context.Context, m *protocol.UnregistrationParams) error {
235	e.mu.Lock()
236	defer e.mu.Unlock()
237
238	e.state.unregistrations = append(e.state.unregistrations, m)
239	e.checkConditionsLocked()
240	return nil
241}
242
243func (e *Env) checkConditionsLocked() {
244	for id, condition := range e.waiters {
245		if v, _ := checkExpectations(e.state, condition.expectations); v != Unmet {
246			delete(e.waiters, id)
247			condition.verdict <- v
248		}
249	}
250}
251
252// checkExpectations reports whether s meets all expectations.
253func checkExpectations(s State, expectations []Expectation) (Verdict, string) {
254	finalVerdict := Met
255	var summary strings.Builder
256	for _, e := range expectations {
257		v := e.Check(s)
258		if v > finalVerdict {
259			finalVerdict = v
260		}
261		summary.WriteString(fmt.Sprintf("\t%v: %s\n", v, e.Description()))
262	}
263	return finalVerdict, summary.String()
264}
265
266// DiagnosticsFor returns the current diagnostics for the file. It is useful
267// after waiting on AnyDiagnosticAtCurrentVersion, when the desired diagnostic
268// is not simply described by DiagnosticAt.
269func (e *Env) DiagnosticsFor(name string) *protocol.PublishDiagnosticsParams {
270	e.mu.Lock()
271	defer e.mu.Unlock()
272	return e.state.diagnostics[name]
273}
274
275// Await waits for all expectations to simultaneously be met. It should only be
276// called from the main test goroutine.
277func (e *Env) Await(expectations ...Expectation) {
278	e.T.Helper()
279	e.mu.Lock()
280	// Before adding the waiter, we check if the condition is currently met or
281	// failed to avoid a race where the condition was realized before Await was
282	// called.
283	switch verdict, summary := checkExpectations(e.state, expectations); verdict {
284	case Met:
285		e.mu.Unlock()
286		return
287	case Unmeetable:
288		failure := fmt.Sprintf("unmeetable expectations:\n%s\nstate:\n%v", summary, e.state)
289		e.mu.Unlock()
290		e.T.Fatal(failure)
291	}
292	cond := &condition{
293		expectations: expectations,
294		verdict:      make(chan Verdict),
295	}
296	e.waiters[e.nextWaiterID] = cond
297	e.nextWaiterID++
298	e.mu.Unlock()
299
300	var err error
301	select {
302	case <-e.Ctx.Done():
303		err = e.Ctx.Err()
304	case v := <-cond.verdict:
305		if v != Met {
306			err = fmt.Errorf("condition has final verdict %v", v)
307		}
308	}
309	e.mu.Lock()
310	defer e.mu.Unlock()
311	_, summary := checkExpectations(e.state, expectations)
312
313	// Debugging an unmet expectation can be tricky, so we put some effort into
314	// nicely formatting the failure.
315	if err != nil {
316		e.T.Fatalf("waiting on:\n%s\nerr:%v\n\nstate:\n%v", summary, err, e.state)
317	}
318}
319