1// Copyright (C) 2019 Google Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
// langsvr implements a Language Server for the SPIR-V assembly language.
16package main
17
18import (
19	"context"
20	"fmt"
21	"io"
22	"io/ioutil"
23	"log"
24	"os"
25	"path"
26	"sort"
27	"strings"
28	"sync"
29	"unicode/utf8"
30
31	"github.com/KhronosGroup/SPIRV-Tools/utils/vscode/src/parser"
32	"github.com/KhronosGroup/SPIRV-Tools/utils/vscode/src/schema"
33
34	"github.com/KhronosGroup/SPIRV-Tools/utils/vscode/src/lsp/jsonrpc2"
35	lsp "github.com/KhronosGroup/SPIRV-Tools/utils/vscode/src/lsp/protocol"
36)
37
// enableDebugLogging selects, at compile time, whether main sends log output
// to a log.txt file next to the executable instead of discarding it.
const enableDebugLogging = false
41
// rSpy decorates an io.Reader: each Read is forwarded to the wrapped reader r
// and the bytes that were read are written to the log, tagged with prefix.
type rSpy struct {
	prefix string
	r      io.Reader
}

// Read fills p from the wrapped reader, logs the bytes that were actually
// read, and hands back the wrapped reader's count and error unchanged.
func (s rSpy) Read(p []byte) (int, error) {
	count, readErr := s.r.Read(p)
	received := string(p[:count])
	log.Printf("%v %v", s.prefix, received)
	return count, readErr
}
54
// wSpy is a writer 'spy' that wraps an io.Writer, and logs all data that
// passes through it.
type wSpy struct {
	prefix string
	w      io.Writer
}

// Write forwards p to the wrapped writer and returns its count and error
// unchanged. Only the bytes the wrapped writer reports as written (p[:n]) are
// logged, so a short write is not logged as if all of p had been sent. This
// mirrors rSpy.Read, which logs p[:n].
func (s wSpy) Write(p []byte) (n int, err error) {
	n, err = s.w.Write(p)
	log.Printf("%v %v", s.prefix, string(p[:n]))
	return n, err
}
67
68// main entry point.
69func main() {
70	log.SetOutput(ioutil.Discard)
71	if enableDebugLogging {
72		// create a log file in the executable's directory.
73		if logfile, err := os.Create(path.Join(path.Dir(os.Args[0]), "log.txt")); err == nil {
74			defer logfile.Close()
75			log.SetOutput(logfile)
76		}
77	}
78
79	log.Println("language server started")
80
81	stream := jsonrpc2.NewHeaderStream(rSpy{"IDE", os.Stdin}, wSpy{"LS", os.Stdout})
82	s := server{
83		files: map[string]*file{},
84	}
85	s.ctx, s.conn, s.client = lsp.NewServer(context.Background(), stream, &s)
86	if err := s.conn.Run(s.ctx); err != nil {
87		log.Panicln(err)
88		os.Exit(1)
89	}
90
91	log.Println("language server stopped")
92}
93
// server holds the state of the language server. Its methods are the handlers
// passed to lsp.NewServer in main.
type server struct {
	// ctx, conn and client are the values returned by lsp.NewServer in main.
	ctx    context.Context
	conn   *jsonrpc2.Conn
	client lsp.Client

	// files maps a document URI to its most recent parse results.
	files map[string]*file
	// filesMutex is held while files is read or written.
	filesMutex sync.Mutex
}
102
103// file represents a source file
104type file struct {
105	fullRange parser.Range
106	res       parser.Results
107}
108
109// tokAt returns the parser token at the given position lp
110func (f *file) tokAt(lp lsp.Position) *parser.Token {
111	toks := f.res.Tokens
112	p := parser.Position{Line: int(lp.Line) + 1, Column: int(lp.Character) + 1}
113	i := sort.Search(len(toks), func(i int) bool { return p.LessThan(toks[i].Range.End) })
114	if i == len(toks) {
115		return nil
116	}
117	if toks[i].Range.Contains(p) {
118		return toks[i]
119	}
120	return nil
121}
122
// DidChangeWorkspaceFolders is a no-op handler: it logs the call and returns
// nil without inspecting p.
func (s *server) DidChangeWorkspaceFolders(ctx context.Context, p *lsp.DidChangeWorkspaceFoldersParams) error {
	log.Println("server.DidChangeWorkspaceFolders()")
	return nil
}
// Initialized is a no-op handler: it logs the call and returns nil.
func (s *server) Initialized(ctx context.Context, p *lsp.InitializedParams) error {
	log.Println("server.Initialized()")
	return nil
}
// Exit logs the call and returns nil. It does not itself terminate the
// process or close the connection.
func (s *server) Exit(ctx context.Context) error {
	log.Println("server.Exit()")
	return nil
}
// DidChangeConfiguration is a no-op handler: it logs the call and returns nil
// without reading any settings from p.
func (s *server) DidChangeConfiguration(ctx context.Context, p *lsp.DidChangeConfigurationParams) error {
	log.Println("server.DidChangeConfiguration()")
	return nil
}
139func (s *server) DidOpen(ctx context.Context, p *lsp.DidOpenTextDocumentParams) error {
140	log.Println("server.DidOpen()")
141	return s.processFile(ctx, p.TextDocument.URI, p.TextDocument.Text)
142}
143func (s *server) DidChange(ctx context.Context, p *lsp.DidChangeTextDocumentParams) error {
144	log.Println("server.DidChange()")
145	return s.processFile(ctx, p.TextDocument.URI, p.ContentChanges[0].Text)
146}
147func (s *server) DidClose(ctx context.Context, p *lsp.DidCloseTextDocumentParams) error {
148	log.Println("server.DidClose()")
149	return nil
150}
// DidSave is a no-op handler: it logs the call and returns nil.
func (s *server) DidSave(ctx context.Context, p *lsp.DidSaveTextDocumentParams) error {
	log.Println("server.DidSave()")
	return nil
}
// WillSave is a no-op handler: it logs the call and returns nil.
func (s *server) WillSave(ctx context.Context, p *lsp.WillSaveTextDocumentParams) error {
	log.Println("server.WillSave()")
	return nil
}
// DidChangeWatchedFiles is a no-op handler: it logs the call and returns nil
// without inspecting p.
func (s *server) DidChangeWatchedFiles(ctx context.Context, p *lsp.DidChangeWatchedFilesParams) error {
	log.Println("server.DidChangeWatchedFiles()")
	return nil
}
// Progress is a no-op handler: it logs the call and returns nil.
func (s *server) Progress(ctx context.Context, p *lsp.ProgressParams) error {
	log.Println("server.Progress()")
	return nil
}
// SetTraceNotification is a no-op handler: it logs the call and returns nil
// without changing any trace setting.
func (s *server) SetTraceNotification(ctx context.Context, p *lsp.SetTraceParams) error {
	log.Println("server.SetTraceNotification()")
	return nil
}
// LogTraceNotification is a no-op handler: it logs the call and returns nil.
// The contents of p are not logged.
func (s *server) LogTraceNotification(ctx context.Context, p *lsp.LogTraceParams) error {
	log.Println("server.LogTraceNotification()")
	return nil
}
// Implementation is not implemented: it logs the call and returns no
// locations and no error.
func (s *server) Implementation(ctx context.Context, p *lsp.ImplementationParams) ([]lsp.Location, error) {
	log.Println("server.Implementation()")
	return nil, nil
}
// TypeDefinition is not implemented: it logs the call and returns no
// locations and no error.
func (s *server) TypeDefinition(ctx context.Context, p *lsp.TypeDefinitionParams) ([]lsp.Location, error) {
	log.Println("server.TypeDefinition()")
	return nil, nil
}
// DocumentColor is not implemented: it logs the call and returns a nil slice
// and no error.
func (s *server) DocumentColor(ctx context.Context, p *lsp.DocumentColorParams) ([]lsp.ColorInformation, error) {
	log.Println("server.DocumentColor()")
	return nil, nil
}
// ColorPresentation is not implemented: it logs the call and returns a nil
// slice and no error.
func (s *server) ColorPresentation(ctx context.Context, p *lsp.ColorPresentationParams) ([]lsp.ColorPresentation, error) {
	log.Println("server.ColorPresentation()")
	return nil, nil
}
// FoldingRange is not implemented: it logs the call and returns a nil slice
// and no error.
func (s *server) FoldingRange(ctx context.Context, p *lsp.FoldingRangeParams) ([]lsp.FoldingRange, error) {
	log.Println("server.FoldingRange()")
	return nil, nil
}
// Declaration is not implemented: it logs the call and returns a nil slice
// and no error.
func (s *server) Declaration(ctx context.Context, p *lsp.DeclarationParams) ([]lsp.DeclarationLink, error) {
	log.Println("server.Declaration()")
	return nil, nil
}
// SelectionRange is not implemented: it logs the call and returns a nil slice
// and no error.
func (s *server) SelectionRange(ctx context.Context, p *lsp.SelectionRangeParams) ([]lsp.SelectionRange, error) {
	log.Println("server.SelectionRange()")
	return nil, nil
}
203func (s *server) Initialize(ctx context.Context, p *lsp.ParamInitia) (*lsp.InitializeResult, error) {
204	log.Println("server.Initialize()")
205	res := lsp.InitializeResult{
206		Capabilities: lsp.ServerCapabilities{
207			TextDocumentSync: lsp.TextDocumentSyncOptions{
208				OpenClose: true,
209				Change:    lsp.Full, // TODO: Implement incremental
210			},
211			HoverProvider:              true,
212			DefinitionProvider:         true,
213			ReferencesProvider:         true,
214			RenameProvider:             true,
215			DocumentFormattingProvider: true,
216		},
217	}
218	return &res, nil
219}
// Shutdown logs the call and returns nil. It releases no resources and does
// not stop the connection itself.
func (s *server) Shutdown(ctx context.Context) error {
	log.Println("server.Shutdown()")
	return nil
}
// WillSaveWaitUntil is not implemented: it logs the call and returns no text
// edits and no error.
func (s *server) WillSaveWaitUntil(ctx context.Context, p *lsp.WillSaveTextDocumentParams) ([]lsp.TextEdit, error) {
	log.Println("server.WillSaveWaitUntil()")
	return nil, nil
}
// Completion is not implemented: it logs the call and returns a nil
// completion list and no error.
func (s *server) Completion(ctx context.Context, p *lsp.CompletionParams) (*lsp.CompletionList, error) {
	log.Println("server.Completion()")
	return nil, nil
}
// Resolve is not implemented: it logs the call and returns a nil completion
// item and no error. The item passed in is not returned.
func (s *server) Resolve(ctx context.Context, p *lsp.CompletionItem) (*lsp.CompletionItem, error) {
	log.Println("server.Resolve()")
	return nil, nil
}
236func (s *server) Hover(ctx context.Context, p *lsp.HoverParams) (*lsp.Hover, error) {
237	log.Println("server.Hover()")
238	f := s.getFile(p.TextDocument.URI)
239	if f == nil {
240		return nil, fmt.Errorf("Unknown file")
241	}
242
243	if tok := f.tokAt(p.Position); tok != nil {
244		sb := strings.Builder{}
245		switch v := f.res.Mappings[tok].(type) {
246		default:
247			sb.WriteString(fmt.Sprintf("<Unhandled type '%T'>", v))
248		case *parser.Instruction:
249			sb.WriteString(fmt.Sprintf("```\n%v\n```", v.Opcode.Opname))
250		case *parser.Identifier:
251			sb.WriteString(fmt.Sprintf("```\n%v\n```", v.Definition.Range.Text(f.res.Lines)))
252		case *parser.Operand:
253			if v.Name != "" {
254				sb.WriteString(strings.Trim(v.Name, `'`))
255				sb.WriteString("\n\n")
256			}
257
258			switch v.Kind.Category {
259			case schema.OperandCategoryBitEnum:
260			case schema.OperandCategoryValueEnum:
261				sb.WriteString("```\n")
262				sb.WriteString(strings.Trim(v.Kind.Kind, `'`))
263				sb.WriteString("\n```")
264			case schema.OperandCategoryID:
265				if s := tok.Text(f.res.Lines); s != "" {
266					if id, ok := f.res.Identifiers[s]; ok && id.Definition != nil {
267						sb.WriteString("```\n")
268						sb.WriteString(id.Definition.Range.Text(f.res.Lines))
269						sb.WriteString("\n```")
270					}
271				}
272			case schema.OperandCategoryLiteral:
273			case schema.OperandCategoryComposite:
274			}
275		case nil:
276		}
277
278		if sb.Len() > 0 {
279			res := lsp.Hover{
280				Contents: lsp.MarkupContent{
281					Kind:  "markdown",
282					Value: sb.String(),
283				},
284			}
285			return &res, nil
286		}
287	}
288
289	return nil, nil
290}
// SignatureHelp is not implemented: it logs the call and returns a nil result
// and no error.
func (s *server) SignatureHelp(ctx context.Context, p *lsp.SignatureHelpParams) (*lsp.SignatureHelp, error) {
	log.Println("server.SignatureHelp()")
	return nil, nil
}
295func (s *server) Definition(ctx context.Context, p *lsp.DefinitionParams) ([]lsp.Location, error) {
296	log.Println("server.Definition()")
297	if f := s.getFile(p.TextDocument.URI); f != nil {
298		if tok := f.tokAt(p.Position); tok != nil {
299			if s := tok.Text(f.res.Lines); s != "" {
300				if id, ok := f.res.Identifiers[s]; ok {
301					loc := lsp.Location{
302						URI:   p.TextDocument.URI,
303						Range: rangeToLSP(id.Definition.Range),
304					}
305					return []lsp.Location{loc}, nil
306				}
307			}
308		}
309	}
310	return nil, nil
311}
312func (s *server) References(ctx context.Context, p *lsp.ReferenceParams) ([]lsp.Location, error) {
313	log.Println("server.References()")
314	if f := s.getFile(p.TextDocument.URI); f != nil {
315		if tok := f.tokAt(p.Position); tok != nil {
316			if s := tok.Text(f.res.Lines); s != "" {
317				if id, ok := f.res.Identifiers[s]; ok {
318					locs := make([]lsp.Location, len(id.References))
319					for i, r := range id.References {
320						locs[i] = lsp.Location{
321							URI:   p.TextDocument.URI,
322							Range: rangeToLSP(r.Range),
323						}
324					}
325					return locs, nil
326				}
327			}
328		}
329	}
330	return nil, nil
331}
// DocumentHighlight is not implemented: it logs the call and returns a nil
// slice and no error.
func (s *server) DocumentHighlight(ctx context.Context, p *lsp.DocumentHighlightParams) ([]lsp.DocumentHighlight, error) {
	log.Println("server.DocumentHighlight()")
	return nil, nil
}
// DocumentSymbol is not implemented: it logs the call and returns a nil slice
// and no error.
func (s *server) DocumentSymbol(ctx context.Context, p *lsp.DocumentSymbolParams) ([]lsp.DocumentSymbol, error) {
	log.Println("server.DocumentSymbol()")
	return nil, nil
}
// CodeAction is not implemented: it logs the call and returns a nil slice and
// no error.
func (s *server) CodeAction(ctx context.Context, p *lsp.CodeActionParams) ([]lsp.CodeAction, error) {
	log.Println("server.CodeAction()")
	return nil, nil
}
// Symbol is not implemented: it logs the call and returns a nil slice and no
// error.
func (s *server) Symbol(ctx context.Context, p *lsp.WorkspaceSymbolParams) ([]lsp.SymbolInformation, error) {
	log.Println("server.Symbol()")
	return nil, nil
}
// CodeLens is not implemented: it logs the call and returns a nil slice and
// no error.
func (s *server) CodeLens(ctx context.Context, p *lsp.CodeLensParams) ([]lsp.CodeLens, error) {
	log.Println("server.CodeLens()")
	return nil, nil
}
// ResolveCodeLens is not implemented: it logs the call and returns a nil code
// lens and no error. The lens passed in is not returned.
func (s *server) ResolveCodeLens(ctx context.Context, p *lsp.CodeLens) (*lsp.CodeLens, error) {
	log.Println("server.ResolveCodeLens()")
	return nil, nil
}
// DocumentLink is not implemented: it logs the call and returns a nil slice
// and no error.
func (s *server) DocumentLink(ctx context.Context, p *lsp.DocumentLinkParams) ([]lsp.DocumentLink, error) {
	log.Println("server.DocumentLink()")
	return nil, nil
}
// ResolveDocumentLink is not implemented: it logs the call and returns a nil
// link and no error. The link passed in is not returned.
func (s *server) ResolveDocumentLink(ctx context.Context, p *lsp.DocumentLink) (*lsp.DocumentLink, error) {
	log.Println("server.ResolveDocumentLink()")
	return nil, nil
}
// Formatting re-emits the whole document from its token stream and returns
// the result as a single text edit covering the file's full range.
//
// Tokens on a line are re-joined with single spaces, and each line that does
// not start with a comment or newline token is prefixed with padding intended
// to line up the first opcode of every line. The output always ends with
// exactly one newline. Returns nil, nil if the document is unknown to the
// server; the error is always nil.
func (s *server) Formatting(ctx context.Context, p *lsp.DocumentFormattingParams) ([]lsp.TextEdit, error) {
	log.Println("server.Formatting()")
	if f := s.getFile(p.TextDocument.URI); f != nil {
		// Start by measuring the distance from the start of each line to the
		// first opcode on that line.
		//
		//   lineInstOffsets - one recorded instOffset per line, in order.
		//   maxInstOffset   - the largest instOffset seen at a newline token.
		//   instOffset      - offset recorded for the current line; 0 means
		//                     "no opcode recorded yet".
		//   curOffset       - running offset within the current line; starts
		//                     at -1 so the first increment lands on 0.
		lineInstOffsets, maxInstOffset, instOffset, curOffset := []int{}, 0, 0, -1
		for _, t := range f.res.Tokens {
			curOffset++ // whitespace between tokens
			switch t.Type {
			case parser.Ident:
				// Record the offset of the first identifier on the line that
				// is a key of schema.Opcodes.
				// NOTE(review): an Ident that is not recorded here does not
				// add its text length to curOffset (only the default case
				// does) - confirm this is intended.
				if _, isInst := schema.Opcodes[t.Text(f.res.Lines)]; isInst && instOffset == 0 {
					instOffset = curOffset
					continue
				}
			case parser.Newline:
				// End of line: store this line's offset and reset the
				// per-line counters.
				lineInstOffsets = append(lineInstOffsets, instOffset)
				if instOffset > maxInstOffset {
					maxInstOffset = instOffset
				}
				curOffset, instOffset = -1, 0
			default:
				curOffset += utf8.RuneCountInString(t.Text(f.res.Lines))
			}
		}
		// Store the offset for the tokens after the last newline token.
		// NOTE(review): this final entry is not folded into maxInstOffset.
		lineInstOffsets = append(lineInstOffsets, instOffset)

		// Now rewrite each of the lines, adding padding at the start of the
		// line for alignment.
		sb, newline := strings.Builder{}, true
		for _, t := range f.res.Tokens {
			if newline {
				// First token of a line: consume this line's recorded offset
				// and pad by its distance from the maximum.
				newline = false
				indent := maxInstOffset - lineInstOffsets[0]
				lineInstOffsets = lineInstOffsets[1:]
				switch t.Type {
				case parser.Newline, parser.Comment:
					// Lines that start with a newline or comment token get
					// no padding.
				default:
					for s := 0; s < indent; s++ {
						sb.WriteRune(' ')
					}
				}
			} else if t.Type != parser.Newline {
				// Single space between tokens on the same line, but none
				// before the newline token.
				sb.WriteString(" ")
			}

			sb.WriteString(t.Text(f.res.Lines))
			if t.Type == parser.Newline {
				newline = true
			}
		}

		formatted := sb.String()

		// Every good file ends with a single new line.
		formatted = strings.TrimRight(formatted, "\n") + "\n"

		return []lsp.TextEdit{
			{
				Range:   rangeToLSP(f.fullRange),
				NewText: formatted,
			},
		}, nil
	}
	return nil, nil
}
// RangeFormatting is not implemented: it logs the call and returns no text
// edits and no error.
func (s *server) RangeFormatting(ctx context.Context, p *lsp.DocumentRangeFormattingParams) ([]lsp.TextEdit, error) {
	log.Println("server.RangeFormatting()")
	return nil, nil
}
// OnTypeFormatting is not implemented: it logs the call and returns no text
// edits and no error.
func (s *server) OnTypeFormatting(ctx context.Context, p *lsp.DocumentOnTypeFormattingParams) ([]lsp.TextEdit, error) {
	log.Println("server.OnTypeFormatting()")
	return nil, nil
}
437func (s *server) Rename(ctx context.Context, p *lsp.RenameParams) (*lsp.WorkspaceEdit, error) {
438	log.Println("server.Rename()")
439	if f := s.getFile(p.TextDocument.URI); f != nil {
440		if tok := f.tokAt(p.Position); tok != nil {
441			if s := tok.Text(f.res.Lines); s != "" {
442				if id, ok := f.res.Identifiers[s]; ok {
443					changes := make([]lsp.TextEdit, len(id.References))
444					for i, r := range id.References {
445						changes[i].Range = rangeToLSP(r.Range)
446						changes[i].NewText = p.NewName
447					}
448					m := map[string][]lsp.TextEdit{}
449					m[p.TextDocument.URI] = changes
450					return &lsp.WorkspaceEdit{Changes: &m}, nil
451				}
452			}
453		}
454	}
455	return nil, nil
456}
// PrepareRename is not implemented: it logs the call and returns a nil range
// and no error.
func (s *server) PrepareRename(ctx context.Context, p *lsp.PrepareRenameParams) (*lsp.Range, error) {
	log.Println("server.PrepareRename()")
	return nil, nil
}
// ExecuteCommand is not implemented: it logs the call and returns a nil
// result and no error, whatever command p names.
func (s *server) ExecuteCommand(ctx context.Context, p *lsp.ExecuteCommandParams) (interface{}, error) {
	log.Println("server.ExecuteCommand()")
	return nil, nil
}
465
466func (s *server) processFile(ctx context.Context, uri, source string) error {
467	log.Println("server.DidOpen()")
468	res, err := parser.Parse(source)
469	if err != nil {
470		return err
471	}
472	fullRange := parser.Range{
473		Start: parser.Position{Line: 1, Column: 1},
474		End:   parser.Position{Line: len(res.Lines), Column: utf8.RuneCountInString(res.Lines[len(res.Lines)-1]) + 1},
475	}
476
477	s.filesMutex.Lock()
478	s.files[uri] = &file{
479		fullRange: fullRange,
480		res:       res,
481	}
482	s.filesMutex.Unlock()
483
484	dp := lsp.PublishDiagnosticsParams{URI: uri, Diagnostics: make([]lsp.Diagnostic, len(res.Diagnostics))}
485	for i, d := range res.Diagnostics {
486		dp.Diagnostics[i] = diagnosticToLSP(d)
487	}
488	s.client.PublishDiagnostics(ctx, &dp)
489	return nil
490}
491
492func (s *server) getFile(uri string) *file {
493	s.filesMutex.Lock()
494	defer s.filesMutex.Unlock()
495	return s.files[uri]
496}
497
498func diagnosticToLSP(d parser.Diagnostic) lsp.Diagnostic {
499	return lsp.Diagnostic{
500		Range:    rangeToLSP(d.Range),
501		Severity: severityToLSP(d.Severity),
502		Message:  d.Message,
503	}
504}
505
506func severityToLSP(s parser.Severity) lsp.DiagnosticSeverity {
507	switch s {
508	case parser.SeverityError:
509		return lsp.SeverityError
510	case parser.SeverityWarning:
511		return lsp.SeverityWarning
512	case parser.SeverityInformation:
513		return lsp.SeverityInformation
514	case parser.SeverityHint:
515		return lsp.SeverityHint
516	default:
517		log.Panicf("Invalid severity '%d'", int(s))
518		return lsp.SeverityError
519	}
520}
521
522func rangeToLSP(r parser.Range) lsp.Range {
523	return lsp.Range{
524		Start: positionToLSP(r.Start),
525		End:   positionToLSP(r.End),
526	}
527}
528
529func positionToLSP(r parser.Position) lsp.Position {
530	return lsp.Position{
531		Line:      float64(r.Line - 1),
532		Character: float64(r.Column - 1),
533	}
534}
535