// Copyright (C) 2019 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package parser implements a SPIR-V assembly parser.
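//
// A minimal usage sketch (the assembly source, variable names and error
// handling below are illustrative only):
//
//	res, err := parser.Parse("OpCapability Shader\n")
//	if err != nil {
//		log.Fatal(err)
//	}
//	for _, d := range res.Diagnostics {
//		fmt.Println(d.Range, d.Severity, d.Message)
//	}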
package parser

import (
	"fmt"
	"io"
	"log"
	"strings"
	"unicode"
	"unicode/utf8"

	"github.com/KhronosGroup/SPIRV-Tools/utils/vscode/src/schema"
)

// Type is an enumerator of token types.
type Type int

// Type enumerators
const (
	Ident  Type = iota // Foo
	PIdent             // %32, %foo
	Integer
	Float
	String
	Operator
	Comment
	Newline
)

func (t Type) String() string {
	switch t {
	case Ident:
		return "Ident"
	case PIdent:
		return "PIdent"
	case Integer:
		return "Integer"
	case Float:
		return "Float"
	case String:
		return "String"
	case Operator:
		return "Operator"
	case Comment:
		return "Comment"
	case Newline:
		return "Newline"
	default:
		return "<unknown>"
	}
}

// Token represents a single lexed token.
type Token struct {
	Type  Type
	Range Range
}

func (t Token) String() string { return fmt.Sprintf("{%v %v}", t.Type, t.Range) }

// Text returns the token's text from the source.
func (t Token) Text(lines []string) string { return t.Range.Text(lines) }

// Range represents an interval in a text file.
type Range struct {
	Start Position
	End   Position
}

func (r Range) String() string { return fmt.Sprintf("[%v %v]", r.Start, r.End) }

// Text returns the text for the given Range in the provided lines.
func (r Range) Text(lines []string) string {
	sl, sc := r.Start.Line-1, r.Start.Column-1
	if sl < 0 || sc < 0 || sl >= len(lines) || sc > len(lines[sl]) {
		return fmt.Sprintf("<invalid start position %v>", r.Start)
	}
	el, ec := r.End.Line-1, r.End.Column-1
	if el < 0 || ec < 0 || el >= len(lines) || ec > len(lines[el]) {
		return fmt.Sprintf("<invalid end position %v>", r.End)
	}

	sb := strings.Builder{}
	if sl != el {
		sb.WriteString(lines[sl][sc:])
		for l := sl + 1; l < el; l++ {
			sb.WriteString(lines[l])
		}
		sb.WriteString(lines[el][:ec])
	} else {
		sb.WriteString(lines[sl][sc:ec])
	}
	return sb.String()
}

// Contains returns true if p is in r.
func (r Range) Contains(p Position) bool {
	return !(p.LessThan(r.Start) || p.GreaterThan(r.End))
}

func (r *Range) grow(o Range) {
	if !r.Start.IsValid() || o.Start.LessThan(r.Start) {
		r.Start = o.Start
	}
	if !r.End.IsValid() || o.End.GreaterThan(r.End) {
		r.End = o.End
	}
}

// Position holds a line and column position in a text file.
type Position struct {
	Line, Column int
}

func (p Position) String() string { return fmt.Sprintf("%v:%v", p.Line, p.Column) }

// IsValid returns true if the position has a line and column greater than zero.
func (p Position) IsValid() bool { return p.Line > 0 && p.Column > 0 }

// LessThan returns true iff p comes before o.
func (p Position) LessThan(o Position) bool {
	switch {
	case !p.IsValid() || !o.IsValid():
		return false
	case p.Line < o.Line:
		return true
	case p.Line > o.Line:
		return false
	case p.Column < o.Column:
		return true
	default:
		return false
	}
}

// GreaterThan returns true iff p comes after o.
func (p Position) GreaterThan(o Position) bool {
	switch {
	case !p.IsValid() || !o.IsValid():
		return false
	case p.Line > o.Line:
		return true
	case p.Line < o.Line:
		return false
	case p.Column > o.Column:
		return true
	default:
		return false
	}
}

type lexer struct {
	source string // the source being lexed
	lexerState
	diags []Diagnostic // lexer emitted diagnostics
	e     error        // error state; set to io.EOF once the source is consumed
}

type lexerState struct {
	offset int      // byte offset in source
	toks   []*Token // all the lexed tokens
	pos    Position // current position
}

// err appends a fmt.Sprintf-style error to l.diags for the given token.
func (l *lexer) err(tok *Token, msg string, args ...interface{}) {
	rng := Range{}
	if tok != nil {
		rng = tok.Range
	}
	l.diags = append(l.diags, Diagnostic{
		Range:    rng,
		Severity: SeverityError,
		Message:  fmt.Sprintf(msg, args...),
	})
}

// next returns the next rune, or 0 after setting l.e to io.EOF once the last
// rune has been consumed.
func (l *lexer) next() rune {
	if l.offset >= len(l.source) {
		l.e = io.EOF
		return 0
	}
	r, n := utf8.DecodeRuneInString(l.source[l.offset:])
	l.offset += n
	if n == 0 {
		l.e = io.EOF
		return 0
	}
	if r == '\n' {
		l.pos.Line++
		l.pos.Column = 1
	} else {
		l.pos.Column++
	}
	return r
}

// save returns the current lexerState.
func (l *lexer) save() lexerState {
	return l.lexerState
}

// restore restores the lexer state to s.
func (l *lexer) restore(s lexerState) {
	l.lexerState = s
}

// pident processes the PIdent token at the current position.
// The lexer *must* know the next token is a PIdent before calling.
func (l *lexer) pident() {
	tok := &Token{Type: PIdent, Range: Range{Start: l.pos, End: l.pos}}
	if r := l.next(); r != '%' {
		log.Fatalf("lexer expected '%%', got '%v'", r)
		return
	}
	for l.e == nil {
		s := l.save()
		r := l.next()
		if !isAlphaNumeric(r) && r != '_' {
			l.restore(s)
			break
		}
	}
	tok.Range.End = l.pos
	l.toks = append(l.toks, tok)
}

// numberOrIdent processes the Ident, Float or Integer token at the current
// position.
func (l *lexer) numberOrIdent() {
	const Unknown Type = -1
	tok := &Token{Type: Unknown, Range: Range{Start: l.pos, End: l.pos}}
loop:
	for l.e == nil {
		s := l.save()
		r := l.next()
		switch {
		case r == '-', r == '+', isNumeric(r):
			continue
		case isAlpha(r), r == '_':
			switch tok.Type {
			case Unknown:
				tok.Type = Ident
			case Float, Integer:
				l.err(tok, "invalid number")
				return
			}
		case r == '.':
			switch tok.Type {
			case Unknown:
				tok.Type = Float
			default:
				l.restore(s)
				break loop
			}
		default:
			if tok.Type == Unknown {
				tok.Type = Integer
			}
			l.restore(s)
			break loop
		}
	}
	tok.Range.End = l.pos
	l.toks = append(l.toks, tok)
}

// string processes the String token at the current position.
// The lexer *must* know the next token is a String before calling.
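// Backslash escapes are honored, so an escaped quote (\") does not end the
// string.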
func (l *lexer) string() {
	tok := &Token{Type: String, Range: Range{Start: l.pos, End: l.pos}}
	if r := l.next(); r != '"' {
		log.Fatalf("lexer expected '\"', got '%v'", r)
		return
	}
	escape := false
	for l.e == nil {
		switch l.next() {
		case '"':
			if !escape {
				tok.Range.End = l.pos
				l.toks = append(l.toks, tok)
				return
			}
			escape = false
		case '\\':
			escape = !escape
		default:
			escape = false
		}
	}
}


// operator processes the Operator token at the current position.
// The lexer *must* know the next token is an Operator before calling.
func (l *lexer) operator() {
	tok := &Token{Type: Operator, Range: Range{Start: l.pos, End: l.pos}}
	for l.e == nil {
		switch l.next() {
		case '=', '|':
			tok.Range.End = l.pos
			l.toks = append(l.toks, tok)
			return
		}
	}
}

// lineComment processes the Comment token at the current position.
// The lexer *must* know the next token is a Comment before calling.
func (l *lexer) lineComment() {
	tok := &Token{Type: Comment, Range: Range{Start: l.pos, End: l.pos}}
	if r := l.next(); r != ';' {
		log.Fatalf("lexer expected ';', got '%v'", r)
		return
	}
	for l.e == nil {
		s := l.save()
		switch l.next() {
		case '\n':
			l.restore(s)
			tok.Range.End = l.pos
			l.toks = append(l.toks, tok)
			return
		}
	}
}

// newline processes the Newline token at the current position.
// The lexer *must* know the next token is a Newline before calling.
func (l *lexer) newline() {
	tok := &Token{Type: Newline, Range: Range{Start: l.pos, End: l.pos}}
	if r := l.next(); r != '\n' {
		log.Fatalf("lexer expected '\\n', got '%v'", r)
		return
	}
	tok.Range.End = l.pos
	l.toks = append(l.toks, tok)
}

// lex returns all the tokens and diagnostics after lexing source.
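// Lexing dispatches on the first rune of each token; runes that do not start
// a known token (such as spaces and tabs) are consumed without emitting one.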
func lex(source string) ([]*Token, []Diagnostic, error) {
	l := lexer{source: source, lexerState: lexerState{pos: Position{1, 1}}}

	lastPos := Position{}
	for l.e == nil {
		// Integrity check that the lexer is making progress
		if l.pos == lastPos {
			log.Panicf("Lexing stuck at %v", l.pos)
		}
		lastPos = l.pos

		s := l.save()
		r := l.next()
		switch {
		case r == '%':
			l.restore(s)
			l.pident()
		case r == '+' || r == '-' || r == '_' || isAlphaNumeric(r):
			l.restore(s)
			l.numberOrIdent()
		case r == '"':
			l.restore(s)
			l.string()
		case r == '=', r == '|':
			l.restore(s)
			l.operator()
		case r == ';':
			l.restore(s)
			l.lineComment()
		case r == '\n':
			l.restore(s)
			l.newline()
		}
	}
	if l.e != nil && l.e != io.EOF {
		return nil, nil, l.e
	}
	return l.toks, l.diags, nil
}

func isNumeric(r rune) bool      { return unicode.IsDigit(r) }
func isAlpha(r rune) bool        { return unicode.IsLetter(r) }
func isAlphaNumeric(r rune) bool { return isAlpha(r) || isNumeric(r) }

type parser struct {
	lines          []string                    // all source lines
	toks           []*Token                    // all tokens
	diags          []Diagnostic                // parser emitted diagnostics
	idents         map[string]*Identifier      // identifiers by name
	mappings       map[*Token]interface{}      // tokens to semantic map
	extInstImports map[string]schema.OpcodeMap // extension imports by identifier
	insts          []*Instruction              // all instructions
}

func (p *parser) parse() error {
	for i := 0; i < len(p.toks); {
		if p.newline(i) || p.comment(i) {
			i++
			continue
		}
		if n := p.instruction(i); n > 0 {
			i += n
		} else {
			p.unexpected(i)
			i++
		}
	}
	return nil
}

// instruction parses the instruction starting at the i'th token.
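// It returns the number of tokens consumed, or 0 if the tokens at i do not
// form an instruction.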
func (p *parser) instruction(i int) (n int) {
	inst := &Instruction{}

	switch {
	case p.opcode(i) != nil:
		inst.Opcode = p.opcode(i)
		inst.Tokens = []*Token{p.tok(i)}
		p.mappings[p.tok(i)] = inst
		n++
	case p.opcode(i+2) != nil: // try '%id' '='
		inst.Result, inst.Opcode = p.pident(i), p.opcode(i+2)
		if inst.Result == nil || p.operator(i+1) != "=" {
			return 0
		}
		n += 3
		inst.Tokens = []*Token{p.tok(i), p.tok(i + 1), p.tok(i + 2)}
		p.mappings[p.tok(i+2)] = inst
	default:
		return
	}

	expectsResult := len(inst.Opcode.Operands) > 0 && IsResult(inst.Opcode.Operands[0].Kind)
	operands := inst.Opcode.Operands
	switch {
	case inst.Result != nil && !expectsResult:
		p.err(inst.Result, "'%s' does not have a result", inst.Opcode.Opname)
		return
	case inst.Result == nil && expectsResult:
		p.err(p.tok(i), "'%s' expects a result", inst.Opcode.Opname)
		return
	case inst.Result != nil && expectsResult:
		// Check the result is of the correct type
		o := inst.Opcode.Operands[0]
		p.operand(o.Name, o.Kind, i, false)
		operands = operands[1:]
		p.addIdentDef(inst.Result.Text(p.lines), inst, p.tok(i))
	}

	processOperand := func(o schema.Operand) bool {
		if p.newline(i + n) {
			return false
		}

		switch o.Quantifier {
		case schema.Once:
			if op, c := p.operand(o.Name, o.Kind, i+n, false); op != nil {
				inst.Tokens = append(inst.Tokens, op.Tokens...)
				n += c
			}
		case schema.ZeroOrOnce:
			if op, c := p.operand(o.Name, o.Kind, i+n, true); op != nil {
				inst.Tokens = append(inst.Tokens, op.Tokens...)
				n += c
			}
		case schema.ZeroOrMany:
			for !p.newline(i + n) {
				if op, c := p.operand(o.Name, o.Kind, i+n, true); op != nil {
					inst.Tokens = append(inst.Tokens, op.Tokens...)
					n += c
				} else {
					return false
				}
			}
		}
		return true
	}

	for _, o := range operands {
		if !processOperand(o) {
			break
		}

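		// OpExtInst: once the leading tokens have been consumed (n == 4), the
		// next two tokens name the extended instruction set and the extended
		// opcode; the extended opcode then supplies the remaining operands.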
		if inst.Opcode == schema.OpExtInst && n == 4 {
			extImportTok, extNameTok := p.tok(i+n), p.tok(i+n+1)
			extImport := extImportTok.Text(p.lines)
			if extOpcodes, ok := p.extInstImports[extImport]; ok {
				extName := extNameTok.Text(p.lines)
				if extOpcode, ok := extOpcodes[extName]; ok {
					n += 2 // skip ext import, ext name
					for _, o := range extOpcode.Operands {
						if !processOperand(o) {
							break
						}
					}
				} else {
					p.err(extNameTok, "Unknown extension opcode '%s'", extName)
				}
			} else {
				p.err(extImportTok, "Expected the result id of an OpExtInstImport")
			}
		}
	}

	for _, t := range inst.Tokens {
		inst.Range.grow(t.Range)
	}

	p.insts = append(p.insts, inst)

	if inst.Opcode == schema.OpExtInstImport && len(inst.Tokens) >= 4 {
		// Instruction is an OpExtInstImport. Keep track of this.
		extTok := inst.Tokens[3]
		extName := strings.Trim(extTok.Text(p.lines), `"`)
		extOpcodes, ok := schema.ExtOpcodes[extName]
		if !ok {
			p.err(extTok, "Unknown extension '%s'", extName)
		}
		extImport := inst.Result.Text(p.lines)
		p.extInstImports[extImport] = extOpcodes
	}

	return
}

// operand parses the operand with the name n, kind k, starting at the i'th
// token.
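// It returns the parsed operand and the number of tokens consumed, or nil and
// 0 if the tokens at i do not match the operand kind.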
func (p *parser) operand(n string, k *schema.OperandKind, i int, optional bool) (*Operand, int) {
	tok := p.tok(i)
	if tok == nil {
		return nil, 0
	}

	op := &Operand{
		Name:   n,
		Kind:   k,
		Tokens: []*Token{tok},
	}
	p.mappings[tok] = op

	switch k.Category {
	case schema.OperandCategoryBitEnum, schema.OperandCategoryValueEnum:
		s := tok.Text(p.lines)
		for _, e := range k.Enumerants {
			if e.Enumerant == s {
				count := 1
				for _, param := range e.Parameters {
					p, c := p.operand(param.Name, param.Kind, i+count, false)
					if p != nil {
						op.Tokens = append(op.Tokens, p.Tokens...)
						op.Parameters = append(op.Parameters, p)
					}
					count += c
				}

				// Handle bitfield '|' chains
				if t := p.tok(i + count); t != nil && t.Text(p.lines) == "|" {
					count++ // '|'
					p, c := p.operand(n, k, i+count, false)
					if p != nil {
						op.Tokens = append(op.Tokens, p.Tokens...)
						op.Parameters = append(op.Parameters, p)
					}
					count += c
				}

				return op, count
			}
		}
		if !optional {
			p.err(p.tok(i), "invalid operand value '%s'", s)
		}

		return nil, 0

	case schema.OperandCategoryID:
		id := p.pident(i)
		if id != nil {
			p.addIdentRef(p.tok(i))
			return op, 1
		}
		if !optional {
			p.err(p.tok(i), "operand requires id, got '%s'", tok.Text(p.lines))
		}
		return nil, 0

	case schema.OperandCategoryLiteral:
		switch tok.Type {
		case String, Integer, Float, Ident:
			return op, 1
		}
		if !optional {
			p.err(p.tok(i), "operand requires literal, got '%s'", tok.Text(p.lines))
		}
		return nil, 0

	case schema.OperandCategoryComposite:
		n := 1
		for _, b := range k.Bases {
			o, c := p.operand(b.Kind, b, i+n, optional)
			if o != nil {
				op.Tokens = append(op.Tokens, o.Tokens...)
			}
			n += c
		}
		return op, n

	default:
		p.err(p.tok(i), "OperandKind '%s' has unexpected category '%s'", k.Kind, k.Category)
		return nil, 0
	}
}

// tok returns the i'th token, or nil if i is out of bounds.
func (p *parser) tok(i int) *Token {
	if i < 0 || i >= len(p.toks) {
		return nil
	}
	return p.toks[i]
}

// opcode returns the schema.Opcode for the i'th token, or nil if the i'th token
// does not represent an opcode.
func (p *parser) opcode(i int) *schema.Opcode {
	if tok := p.ident(i); tok != nil {
		name := tok.Text(p.lines)
		if inst, found := schema.Opcodes[name]; found {
			return inst
		}
	}
	return nil
}

// operator returns the operator for the i'th token, or an empty string if the
// i'th token is not an operator.
func (p *parser) operator(i int) string {
	if tok := p.tok(i); tok != nil && tok.Type == Operator {
		return tok.Text(p.lines)
	}
	return ""
}

// ident returns the i'th token if it is an Ident, otherwise nil.
func (p *parser) ident(i int) *Token {
	if tok := p.tok(i); tok != nil && tok.Type == Ident {
		return tok
	}
	return nil
}

// pident returns the i'th token if it is a PIdent, otherwise nil.
func (p *parser) pident(i int) *Token {
	if tok := p.tok(i); tok != nil && tok.Type == PIdent {
		return tok
	}
	return nil
}

// comment returns true if the i'th token is a Comment, otherwise false.
func (p *parser) comment(i int) bool {
	if tok := p.tok(i); tok != nil && tok.Type == Comment {
		return true
	}
	return false
}

// newline returns true if the i'th token is a Newline, otherwise false.
func (p *parser) newline(i int) bool {
	if tok := p.tok(i); tok != nil && tok.Type == Newline {
		return true
	}
	return false
}

// unexpected emits an 'unexpected token' error for the i'th token.
func (p *parser) unexpected(i int) {
	p.err(p.toks[i], "syntax error: unexpected '%s'", p.toks[i].Text(p.lines))
}

// addIdentDef records the token definition for the instruction inst with the
// given id.
func (p *parser) addIdentDef(id string, inst *Instruction, def *Token) {
	i, existing := p.idents[id]
	if !existing {
		i = &Identifier{}
		p.idents[id] = i
	}
	if i.Definition == nil {
		i.Definition = inst
	} else {
		p.err(def, "id '%v' redeclared", id)
	}
}

// addIdentRef adds an identifier reference for the token ref.
func (p *parser) addIdentRef(ref *Token) {
	id := ref.Text(p.lines)
	i, existing := p.idents[id]
	if !existing {
		i = &Identifier{}
		p.idents[id] = i
	}
	i.References = append(i.References, ref)
}

// err appends a fmt.Sprintf-style error to p.diags for the given token.
func (p *parser) err(tok *Token, msg string, args ...interface{}) {
	rng := Range{}
	if tok != nil {
		rng = tok.Range
	}
	p.diags = append(p.diags, Diagnostic{
		Range:    rng,
		Severity: SeverityError,
		Message:  fmt.Sprintf(msg, args...),
	})
}

// Parse parses the SPIR-V assembly string source, returning the parse results.
func Parse(source string) (Results, error) {
	toks, diags, err := lex(source)
	if err != nil {
		return Results{}, err
	}
	lines := strings.SplitAfter(source, "\n")
	p := parser{
		lines:          lines,
		toks:           toks,
		idents:         map[string]*Identifier{},
		mappings:       map[*Token]interface{}{},
		extInstImports: map[string]schema.OpcodeMap{},
	}
	if err := p.parse(); err != nil {
		return Results{}, err
	}
	diags = append(diags, p.diags...)
	return Results{
		Lines:       lines,
		Tokens:      toks,
		Diagnostics: diags,
		Identifiers: p.idents,
		Mappings:    p.mappings,
	}, nil
}

// IsResult returns true if k is used to store the result of an instruction.
func IsResult(k *schema.OperandKind) bool {
	switch k {
	case schema.OperandKindIdResult, schema.OperandKindIdResultType:
		return true
	default:
		return false
	}
}

// Results holds the output of Parse().
type Results struct {
	Lines       []string
	Tokens      []*Token
	Diagnostics []Diagnostic
	Identifiers map[string]*Identifier // identifiers by name
	Mappings    map[*Token]interface{} // tokens to semantic map
}

// Instruction describes a single instruction instance
type Instruction struct {
	Tokens   []*Token       // all the tokens that make up the instruction
	Result   *Token         // the token that represents the result of the instruction, or nil
	Operands []*Operand     // the operands of the instruction
	Range    Range          // the textual range of the instruction
	Opcode   *schema.Opcode // the opcode for the instruction
}

// Operand describes a single operand instance
type Operand struct {
	Name       string              // name of the operand
	Kind       *schema.OperandKind // kind of the operand
	Tokens     []*Token            // all the tokens that make up the operand
	Parameters []*Operand          // all the parameters for the operand
}

// Identifier describes a single, unique SPIR-V identifier (e.g. %32)
type Identifier struct {
	Definition *Instruction // where the identifier was defined
	References []*Token     // all the places the identifier was referenced
}

// Severity is an enumerator of diagnostic severities
type Severity int

// Severity levels
const (
	SeverityError Severity = iota
	SeverityWarning
	SeverityInformation
	SeverityHint
)

// Diagnostic holds a single diagnostic message that was generated while
// parsing.
type Diagnostic struct {
	Range    Range
	Severity Severity
	Message  string
}
