1// Copyright (c) 2016, Daniel Martí <mvdan@mvdan.cc>
2// See LICENSE for licensing information
3
4package syntax
5
6import (
7	"bufio"
8	"bytes"
9	"fmt"
10	"io"
11	"strings"
12	"unicode"
13)
14
15// Indent sets the number of spaces used for indentation. If set to 0,
16// tabs will be used instead.
17func Indent(spaces uint) func(*Printer) {
18	return func(p *Printer) { p.indentSpaces = spaces }
19}
20
21// BinaryNextLine will make binary operators appear on the next line
22// when a binary command, such as a pipe, spans multiple lines. A
23// backslash will be used.
24func BinaryNextLine(p *Printer) { p.binNextLine = true }
25
26// SwitchCaseIndent will make switch cases be indented. As such, switch
27// case bodies will be two levels deeper than the switch itself.
28func SwitchCaseIndent(p *Printer) { p.swtCaseIndent = true }
29
30// SpaceRedirects will put a space after most redirection operators. The
31// exceptions are '>&', '<&', '>(', and '<('.
32func SpaceRedirects(p *Printer) { p.spaceRedirects = true }
33
34// KeepPadding will keep most nodes and tokens in the same column that
35// they were in the original source. This allows the user to decide how
36// to align and pad their code with spaces.
37//
38// Note that this feature is best-effort and will only keep the
39// alignment stable, so it may need some human help the first time it is
40// run.
41func KeepPadding(p *Printer) {
42	p.keepPadding = true
43	p.cols.Writer = p.bufWriter.(*bufio.Writer)
44	p.bufWriter = &p.cols
45}
46
47// Minify will print programs in a way to save the most bytes possible.
48// For example, indentation and comments are skipped, and extra
49// whitespace is avoided when possible.
50func Minify(p *Printer) { p.minify = true }
51
52// NewPrinter allocates a new Printer and applies any number of options.
53func NewPrinter(options ...func(*Printer)) *Printer {
54	p := &Printer{
55		bufWriter:   bufio.NewWriter(nil),
56		lenPrinter:  new(Printer),
57		tabsPrinter: new(Printer),
58	}
59	for _, opt := range options {
60		opt(p)
61	}
62	return p
63}
64
65// Print "pretty-prints" the given syntax tree node to the given writer. Writes
66// to w are buffered.
67//
68// The node types supported at the moment are *File, *Stmt, *Word, any Command
69// node, and any WordPart node. A trailing newline will only be printed when a
70// *File is used.
71func (p *Printer) Print(w io.Writer, node Node) error {
72	p.reset()
73	p.bufWriter.Reset(w)
74	switch x := node.(type) {
75	case *File:
76		p.stmtList(x.StmtList)
77		p.newline(x.End())
78	case *Stmt:
79		p.stmtList(StmtList{Stmts: []*Stmt{x}})
80	case Command:
81		p.line = x.Pos().Line()
82		p.command(x, nil)
83	case *Word:
84		p.word(x)
85	case WordPart:
86		p.wordPart(x, nil)
87	default:
88		return fmt.Errorf("unsupported node type: %T", x)
89	}
90	p.flushHeredocs()
91	p.flushComments()
92	return p.bufWriter.Flush()
93}
94
// bufWriter is the minimal buffered-writer API that the printer writes
// through. *bufio.Writer satisfies it, as do the column- and byte-counting
// wrappers in this file.
type bufWriter interface {
	WriteByte(byte) error
	WriteString(string) (int, error)
	Write([]byte) (int, error)
	Flush() error
	Reset(io.Writer)
}
102
// colCounter wraps a *bufio.Writer and tracks the column at which the next
// byte will be written, so that KeepPadding can line output up with the
// columns of the original source.
//
// column is 1-based and counted in bytes; NOTE(review): this assumes Pos
// columns are byte-based too. lineStart reports whether only spaces and
// tabs have been written since the last newline (or since Reset).
type colCounter struct {
	*bufio.Writer
	column    int
	lineStart bool
}

// addByte updates the column state as if b had just been written.
func (c *colCounter) addByte(b byte) {
	switch b {
	case '\n':
		c.column = 0
		c.lineStart = true
	case '\t', ' ':
		// Indentation and padding keep us at the start of the line.
	default:
		c.lineStart = false
	}
	c.column++
}

func (c *colCounter) WriteByte(b byte) error {
	c.addByte(b)
	return c.Writer.WriteByte(b)
}

// WriteString accounts for every byte of s exactly like WriteByte does.
// A string that ends in a newline (such as a backslash continuation)
// therefore leaves lineStart set, and multi-byte characters advance the
// column by their byte length.
func (c *colCounter) WriteString(s string) (int, error) {
	for i := 0; i < len(s); i++ {
		c.addByte(s[i])
	}
	return c.Writer.WriteString(s)
}

// Write accounts for the bytes in b. Without it, the promoted
// (*bufio.Writer).Write would bypass the column bookkeeping.
func (c *colCounter) Write(b []byte) (int, error) {
	for _, x := range b {
		c.addByte(x)
	}
	return c.Writer.Write(b)
}

func (c *colCounter) Reset(w io.Writer) {
	c.column = 1
	c.lineStart = true
	c.Writer.Reset(w)
}
138
// Printer holds the internal state of the printing mechanism of a
// program. Create one with NewPrinter; the option functions set its
// configuration fields, and each Print call first resets its printing
// state via reset.
type Printer struct {
	// bufWriter receives all output. NewPrinter sets it to a
	// *bufio.Writer; KeepPadding replaces it with &cols, which wraps that
	// writer to track the output column.
	bufWriter
	cols colCounter

	// Configuration set by the option functions.
	indentSpaces   uint
	binNextLine    bool
	swtCaseIndent  bool
	spaceRedirects bool
	keepPadding    bool
	minify         bool

	// wantSpace and wantNewline request a space or a newline before the
	// next token. wroteSemi means a ';' or a case operator such as ";;"
	// was just written, so semiRsrv and semiOrNewl must not add another
	// ';'.
	wantSpace   bool
	wantNewline bool
	wroteSemi   bool

	// commentPadding is the number of extra spaces written before an
	// inline comment, to align it with the comments on nearby lines.
	commentPadding uint

	// pendingComments are any comments in the current line or statement
	// that we have yet to print. This is useful because that way, we can
	// ensure that all comments are written immediately before a newline.
	// Otherwise, in some edge cases we might wrongly place words after a
	// comment in the same line, breaking programs.
	pendingComments []Comment

	// firstLine means we are still writing the first line
	firstLine bool
	// line is the current line number
	line uint

	// lastLevel is the last level of indentation that was used.
	lastLevel uint
	// level is the current level of indentation.
	level uint
	// levelIncs records which indentation level increments actually
	// took place, to revert them once their section ends.
	levelIncs []bool

	// nestedBinary is set while printing the right-hand side of a
	// multi-line binary command when that side is itself a binary
	// command, so that the chain only adds one indentation level.
	nestedBinary bool

	// pendingHdocs is the list of pending heredocs to write.
	pendingHdocs []*Redirect

	// used in stmtCols to align comments
	lenPrinter *Printer
	lenCounter byteCounter

	// used when printing <<- heredocs with tab indentation
	tabsPrinter *Printer
}
190
191func (p *Printer) reset() {
192	p.wantSpace, p.wantNewline = false, false
193	p.commentPadding = 0
194	p.pendingComments = p.pendingComments[:0]
195
196	// minification uses its own newline logic
197	p.firstLine = !p.minify
198	p.line = 0
199
200	p.lastLevel, p.level = 0, 0
201	p.levelIncs = p.levelIncs[:0]
202	p.nestedBinary = false
203	p.pendingHdocs = p.pendingHdocs[:0]
204}
205
206func (p *Printer) spaces(n uint) {
207	for i := uint(0); i < n; i++ {
208		p.WriteByte(' ')
209	}
210}
211
212func (p *Printer) space() {
213	p.WriteByte(' ')
214	p.wantSpace = false
215}
216
217func (p *Printer) spacePad(pos Pos) {
218	if p.wantSpace {
219		p.WriteByte(' ')
220		p.wantSpace = false
221	}
222	if p.cols.lineStart {
223		// Never add padding at the start of a line, since this may
224		// result in broken indentation or mixing of spaces and tabs.
225		return
226	}
227	for !p.cols.lineStart && p.cols.column > 0 && p.cols.column < int(pos.col) {
228		p.WriteByte(' ')
229	}
230}
231
232func (p *Printer) bslashNewl() {
233	if p.wantSpace {
234		p.space()
235	}
236	p.WriteString("\\\n")
237	p.line++
238	p.indent()
239}
240
241func (p *Printer) spacedString(s string, pos Pos) {
242	p.spacePad(pos)
243	p.WriteString(s)
244	p.wantSpace = true
245}
246
247func (p *Printer) spacedToken(s string, pos Pos) {
248	if p.minify {
249		p.WriteString(s)
250		p.wantSpace = false
251		return
252	}
253	p.spacePad(pos)
254	p.WriteString(s)
255	p.wantSpace = true
256}
257
258func (p *Printer) semiOrNewl(s string, pos Pos) {
259	if p.wantNewline {
260		p.newline(pos)
261		p.indent()
262	} else {
263		if !p.wroteSemi {
264			p.WriteByte(';')
265		}
266		if !p.minify {
267			p.space()
268		}
269		p.line = pos.Line()
270	}
271	p.WriteString(s)
272	p.wantSpace = true
273}
274
// incLevel opens a new indentation section and pushes onto p.levelIncs
// whether this section is responsible for a level increment, so that
// decLevel can undo exactly that.
//
// If the level was already raised past the last level actually used by
// indent (p.level > p.lastLevel) and the enclosing section owns an
// increment, that increment is transferred to this section instead of
// adding another level. Nested sections opened before any indentation is
// written therefore collapse into a single level.
func (p *Printer) incLevel() {
	inc := false
	if p.level <= p.lastLevel || len(p.levelIncs) == 0 {
		p.level++
		inc = true
	} else if last := &p.levelIncs[len(p.levelIncs)-1]; *last {
		// Take over the enclosing section's increment; the level then
		// drops as soon as this inner section is closed.
		*last = false
		inc = true
	}
	p.levelIncs = append(p.levelIncs, inc)
}
286
287func (p *Printer) decLevel() {
288	if p.levelIncs[len(p.levelIncs)-1] {
289		p.level--
290	}
291	p.levelIncs = p.levelIncs[:len(p.levelIncs)-1]
292}
293
294func (p *Printer) indent() {
295	if p.minify {
296		return
297	}
298	p.lastLevel = p.level
299	switch {
300	case p.level == 0:
301	case p.indentSpaces == 0:
302		for i := uint(0); i < p.level; i++ {
303			p.WriteByte('\t')
304		}
305	default:
306		p.spaces(p.indentSpaces * p.level)
307	}
308}
309
310func (p *Printer) newline(pos Pos) {
311	p.flushHeredocs()
312	p.flushComments()
313	p.WriteByte('\n')
314	p.wantNewline, p.wantSpace = false, false
315	if p.line < pos.Line() {
316		p.line++
317	}
318}
319
// flushHeredocs writes the bodies of all pending heredocs, each followed by
// its closing delimiter, and empties the queue. newline calls it before the
// newline character itself is written, and Print calls it once at the end.
//
// If the first pending comment is on the current line (the line holding
// the heredoc operators), it is written there first. The remaining pending
// comments are held back and restored once the bodies are written.
func (p *Printer) flushHeredocs() {
	if len(p.pendingHdocs) == 0 {
		return
	}
	// NOTE(review): hdocs shares its backing array with p.pendingHdocs. If
	// printing a body queues new heredocs (e.g. through a command
	// substitution inside the body), entries of hdocs not yet visited
	// could be overwritten. Confirm this cannot happen, or copy the slice.
	hdocs := p.pendingHdocs
	p.pendingHdocs = p.pendingHdocs[:0]
	coms := p.pendingComments
	p.pendingComments = nil
	if len(coms) > 0 {
		c := coms[0]
		if c.Pos().Line() == p.line {
			p.pendingComments = append(p.pendingComments, c)
			p.flushComments()
			coms = coms[1:]
		}
	}

	// Reuse the last indentation level, as
	// indentation levels are usually changed before
	// newlines are printed along with their
	// subsequent indentation characters.
	newLevel := p.level
	p.level = p.lastLevel

	for _, r := range hdocs {
		p.line++
		p.WriteByte('\n')
		p.wantNewline, p.wantSpace = false, false
		// With tab indentation, '<<-' bodies go through an extraIndenter
		// so they sit at least one level deeper than the current line.
		// The helper Printer built here has a nil tabsPrinter, so this
		// path does not nest.
		if r.Op == DashHdoc && p.indentSpaces == 0 &&
			!p.minify && p.tabsPrinter != nil {
			if r.Hdoc != nil {
				extra := extraIndenter{
					bufWriter:   p.bufWriter,
					baseIndent:  int(p.level + 1),
					firstIndent: -1,
				}
				*p.tabsPrinter = Printer{
					bufWriter: &extra,
				}
				p.tabsPrinter.line = r.Hdoc.Pos().Line()
				p.tabsPrinter.word(r.Hdoc)
				// Indent the closing delimiter to the current level.
				p.indent()
				p.line = r.Hdoc.End().Line()
			} else {
				p.indent()
			}
		} else if r.Hdoc != nil {
			// Any other body is printed as-is, with no re-indentation.
			p.word(r.Hdoc)
			p.line = r.Hdoc.End().Line()
		}
		// The closing delimiter, with its quoting removed.
		p.unquotedWord(r.Word)
		p.wantSpace = false
	}
	p.level = newLevel
	p.pendingComments = coms
}
376
377func (p *Printer) newlines(pos Pos) {
378	if p.firstLine && len(p.pendingComments) == 0 {
379		p.firstLine = false
380		return // no empty lines at the top
381	}
382	if !p.wantNewline && pos.Line() <= p.line {
383		return
384	}
385	p.newline(pos)
386	if pos.Line() > p.line {
387		if !p.minify {
388			// preserve single empty lines
389			p.WriteByte('\n')
390		}
391		p.line++
392	}
393	p.indent()
394}
395
396func (p *Printer) rightParen(pos Pos) {
397	if !p.minify {
398		p.newlines(pos)
399	}
400	p.WriteByte(')')
401	p.wantSpace = true
402}
403
404func (p *Printer) semiRsrv(s string, pos Pos) {
405	if p.wantNewline || pos.Line() > p.line {
406		p.newlines(pos)
407	} else {
408		if !p.wroteSemi {
409			p.WriteByte(';')
410		}
411		if !p.minify {
412			p.spacePad(pos)
413		}
414	}
415	p.WriteString(s)
416	p.wantSpace = true
417}
418
// flushComments writes all pending comments and clears the queue.
//
// The first comment starts a new, indented line if it comes from a later
// source line (and we are past line 0). Otherwise, if a space is pending,
// it follows the current line after padding: the KeepPadding column, or
// commentPadding+1 spaces. Every later comment starts a new line. An extra
// empty line is written when the source had a gap before the comment.
// Trailing whitespace in comment text is dropped, and wantNewline is left
// set so that nothing else is written after a comment on the same line.
func (p *Printer) flushComments() {
	for i, c := range p.pendingComments {
		p.firstLine = false
		// We can't call any of the newline methods, as they call this
		// function and we'd recurse forever.
		cline := c.Hash.Line()
		switch {
		case i > 0, cline > p.line && p.line > 0:
			p.WriteByte('\n')
			if cline > p.line+1 {
				p.WriteByte('\n')
			}
			p.indent()
		case p.wantSpace:
			if p.keepPadding {
				p.spacePad(c.Pos())
			} else {
				p.spaces(p.commentPadding + 1)
			}
		}
		// don't go back one line, which may happen in some edge cases
		if p.line < cline {
			p.line = cline
		}
		p.WriteByte('#')
		p.WriteString(strings.TrimRightFunc(c.Text, unicode.IsSpace))
		p.wantNewline = true
	}
	p.pendingComments = nil
}
449
450func (p *Printer) comments(comments ...Comment) {
451	if p.minify {
452		return
453	}
454	p.pendingComments = append(p.pendingComments, comments...)
455}
456
457func (p *Printer) wordParts(wps []WordPart) {
458	for i, n := range wps {
459		var next WordPart
460		if i+1 < len(wps) {
461			next = wps[i+1]
462		}
463		p.wordPart(n, next)
464	}
465}
466
// wordPart prints a single part of a word. next is the part that follows
// it within the same word, or nil. It is only consulted to decide whether
// a parameter expansion can drop its braces when minifying.
func (p *Printer) wordPart(wp, next WordPart) {
	switch x := wp.(type) {
	case *Lit:
		p.WriteString(x.Value)
	case *SglQuoted:
		if x.Dollar {
			p.WriteByte('$')
		}
		p.WriteByte('\'')
		p.WriteString(x.Value)
		p.WriteByte('\'')
		// The quoted value may span lines; keep p.line in sync.
		p.line = x.End().Line()
	case *DblQuoted:
		p.dblQuoted(x)
	case *CmdSubst:
		p.line = x.Pos().Line()
		switch {
		case x.TempFile:
			// On a single line this prints e.g. "${ foo;}".
			p.WriteString("${")
			p.wantSpace = true
			p.nestedStmts(x.StmtList, x.Right)
			p.wantSpace = false
			p.semiRsrv("}", x.Right)
		case x.ReplyVar:
			// On a single line this prints e.g. "${|foo;}".
			p.WriteString("${|")
			p.nestedStmts(x.StmtList, x.Right)
			p.wantSpace = false
			p.semiRsrv("}", x.Right)
		default:
			// "$(stmts)". A space is needed if the first statement starts
			// with '(', as "$((" would begin an arithmetic expansion.
			p.WriteString("$(")
			p.wantSpace = len(x.Stmts) > 0 && startsWithLparen(x.Stmts[0])
			p.nestedStmts(x.StmtList, x.Right)
			p.rightParen(x.Right)
		}
	case *ParamExp:
		// litCont is the first byte of the following literal, if any. The
		// default ";" stands for "nothing that could extend the name".
		litCont := ";"
		if nextLit, ok := next.(*Lit); ok && nextLit.Value != "" {
			litCont = nextLit.Value[:1]
		}
		name := x.Param.Value
		// When minifying, "${name}" is shortened to "$name", unless one of
		// these cases applies: any operator or index is used, the name is
		// a multi-character non-name like "10" ("$10" means "${1}0"), or
		// the following literal would extend the name.
		switch {
		case !p.minify:
		case x.Excl, x.Length, x.Width:
		case x.Index != nil, x.Slice != nil:
		case x.Repl != nil, x.Exp != nil:
		case len(name) > 1 && !ValidName(name): // ${10}
		case ValidName(name + litCont): // ${var}cont
		default:
			x2 := *x
			x2.Short = true
			p.paramExp(&x2)
			return
		}
		p.paramExp(x)
	case *ArithmExp:
		p.WriteString("$((")
		if x.Unsigned {
			p.WriteString("# ")
		}
		p.arithmExpr(x.X, false, false)
		p.WriteString("))")
	case *ExtGlob:
		p.WriteString(x.Op.String())
		p.WriteString(x.Pattern.Value)
		p.WriteByte(')')
	case *ProcSubst:
		// avoid conflict with << and others
		if p.wantSpace {
			p.space()
		}
		p.WriteString(x.Op.String())
		p.nestedStmts(x.StmtList, x.Rparen)
		p.rightParen(x.Rparen)
	}
}
542
543func (p *Printer) dblQuoted(dq *DblQuoted) {
544	if dq.Dollar {
545		p.WriteByte('$')
546	}
547	p.WriteByte('"')
548	if len(dq.Parts) > 0 {
549		p.wordParts(dq.Parts)
550		p.line = dq.Parts[len(dq.Parts)-1].End().Line()
551	}
552	p.WriteByte('"')
553}
554
555func (p *Printer) wroteIndex(index ArithmExpr) bool {
556	if index == nil {
557		return false
558	}
559	p.WriteByte('[')
560	p.arithmExpr(index, false, false)
561	p.WriteByte(']')
562	return true
563}
564
// paramExp prints a parameter expansion. When pe.nakedIndex() reports
// true, only "name[index]" is printed, with no '$' or braces. When
// pe.Short is set, it prints "$name". Otherwise it prints "${...}" with
// these pieces:
//   - an optional prefix: '#' (Length), '%' (Width) or '!' (Excl);
//   - the name and an optional index;
//   - at most one suffix: a slice, a replacement, a Names operator, or an
//     expansion operator with its word.
func (p *Printer) paramExp(pe *ParamExp) {
	if pe.nakedIndex() { // arr[x]
		p.WriteString(pe.Param.Value)
		p.wroteIndex(pe.Index)
		return
	}
	if pe.Short { // $var
		p.WriteByte('$')
		p.WriteString(pe.Param.Value)
		return
	}
	// ${var...}
	p.WriteString("${")
	switch {
	case pe.Length:
		p.WriteByte('#')
	case pe.Width:
		p.WriteByte('%')
	case pe.Excl:
		p.WriteByte('!')
	}
	p.WriteString(pe.Param.Value)
	p.wroteIndex(pe.Index)
	switch {
	case pe.Slice != nil:
		// The offset is printed with spacePlusMinus so that a negative
		// offset becomes "${a: -1}" rather than "${a:-1}", which would be
		// the default-value operator.
		p.WriteByte(':')
		p.arithmExpr(pe.Slice.Offset, true, true)
		if pe.Slice.Length != nil {
			p.WriteByte(':')
			p.arithmExpr(pe.Slice.Length, true, false)
		}
	case pe.Repl != nil:
		// "${a/orig/with}", or "${a//orig/with}" to replace all matches.
		if pe.Repl.All {
			p.WriteByte('/')
		}
		p.WriteByte('/')
		if pe.Repl.Orig != nil {
			p.word(pe.Repl.Orig)
		}
		p.WriteByte('/')
		if pe.Repl.With != nil {
			p.word(pe.Repl.With)
		}
	case pe.Names != 0:
		p.WriteString(pe.Names.String())
	case pe.Exp != nil:
		p.WriteString(pe.Exp.Op.String())
		if pe.Exp.Word != nil {
			p.word(pe.Exp.Word)
		}
	}
	p.WriteByte('}')
}
618
619func (p *Printer) loop(loop Loop) {
620	switch x := loop.(type) {
621	case *WordIter:
622		p.WriteString(x.Name.Value)
623		if len(x.Items) > 0 {
624			p.spacedString(" in", Pos{})
625			p.wordJoin(x.Items)
626		}
627	case *CStyleLoop:
628		p.WriteString("((")
629		if x.Init == nil {
630			p.space()
631		}
632		p.arithmExpr(x.Init, false, false)
633		p.WriteString("; ")
634		p.arithmExpr(x.Cond, false, false)
635		p.WriteString("; ")
636		p.arithmExpr(x.Post, false, false)
637		p.WriteString("))")
638	}
639}
640
// arithmExpr prints an arithmetic expression.
//
// compact omits the spaces around binary operators; it is forced on when
// minifying. spacePlusMinus writes a space before a leading unary '+' or
// '-'. It is only passed down to the leftmost operand (the X side of
// binary and postfix expressions), so it only affects the start of expr.
func (p *Printer) arithmExpr(expr ArithmExpr, compact, spacePlusMinus bool) {
	if p.minify {
		compact = true
	}
	switch x := expr.(type) {
	case *Word:
		p.word(x)
	case *BinaryArithm:
		if compact {
			p.arithmExpr(x.X, compact, spacePlusMinus)
			p.WriteString(x.Op.String())
			p.arithmExpr(x.Y, compact, false)
		} else {
			p.arithmExpr(x.X, compact, spacePlusMinus)
			// Commas get no space before them: "a, b".
			if x.Op != Comma {
				p.space()
			}
			p.WriteString(x.Op.String())
			p.space()
			p.arithmExpr(x.Y, compact, false)
		}
	case *UnaryArithm:
		if x.Post {
			p.arithmExpr(x.X, compact, spacePlusMinus)
			p.WriteString(x.Op.String())
		} else {
			if spacePlusMinus {
				switch x.Op {
				case Plus, Minus:
					p.space()
				}
			}
			p.WriteString(x.Op.String())
			p.arithmExpr(x.X, compact, false)
		}
	case *ParenArithm:
		// NOTE(review): the parenthesised expression is always printed
		// non-compact (unless minifying), even when the caller asked for
		// compact output (e.g. let clauses, slice offsets). Confirm that
		// spaces inside parentheses are acceptable in those contexts.
		p.WriteByte('(')
		p.arithmExpr(x.X, false, false)
		p.WriteByte(')')
	}
}
682
683func (p *Printer) testExpr(expr TestExpr) {
684	switch x := expr.(type) {
685	case *Word:
686		p.word(x)
687	case *BinaryTest:
688		p.testExpr(x.X)
689		p.space()
690		p.WriteString(x.Op.String())
691		p.space()
692		p.testExpr(x.Y)
693	case *UnaryTest:
694		p.WriteString(x.Op.String())
695		p.space()
696		p.testExpr(x.X)
697	case *ParenTest:
698		p.WriteByte('(')
699		p.testExpr(x.X)
700		p.WriteByte(')')
701	}
702}
703
704func (p *Printer) word(w *Word) {
705	p.wordParts(w.Parts)
706	p.wantSpace = true
707}
708
709func (p *Printer) unquotedWord(w *Word) {
710	for _, wp := range w.Parts {
711		switch x := wp.(type) {
712		case *SglQuoted:
713			p.WriteString(x.Value)
714		case *DblQuoted:
715			p.wordParts(x.Parts)
716		case *Lit:
717			for i := 0; i < len(x.Value); i++ {
718				if b := x.Value[i]; b == '\\' {
719					if i++; i < len(x.Value) {
720						p.WriteByte(x.Value[i])
721					}
722				} else {
723					p.WriteByte(b)
724				}
725			}
726		}
727	}
728}
729
730func (p *Printer) wordJoin(ws []*Word) {
731	anyNewline := false
732	for _, w := range ws {
733		if pos := w.Pos(); pos.Line() > p.line {
734			if !anyNewline {
735				p.incLevel()
736				anyNewline = true
737			}
738			p.bslashNewl()
739		} else {
740			p.spacePad(w.Pos())
741		}
742		p.word(w)
743	}
744	if anyNewline {
745		p.decLevel()
746	}
747}
748
749func (p *Printer) casePatternJoin(pats []*Word) {
750	anyNewline := false
751	for i, w := range pats {
752		if i > 0 {
753			p.spacedToken("|", Pos{})
754		}
755		if pos := w.Pos(); pos.Line() > p.line {
756			if !anyNewline {
757				p.incLevel()
758				anyNewline = true
759			}
760			p.bslashNewl()
761		} else {
762			p.spacePad(w.Pos())
763		}
764		p.word(w)
765	}
766	if anyNewline {
767		p.decLevel()
768	}
769}
770
// elemJoin prints the elements of an array literal one indentation level
// deeper. An element that started on a later source line goes on a new
// line; otherwise elements are separated by a space. last holds comments
// following the final element; they are flushed here, before the caller
// prints the closing parenthesis.
func (p *Printer) elemJoin(elems []*ArrayElem, last []Comment) {
	p.incLevel()
	for _, el := range elems {
		// Comments before the element are queued now. The first one after
		// it is kept in left and queued after the element is printed; any
		// further ones in el.Comments are not printed.
		var left []Comment
		for _, c := range el.Comments {
			if c.Pos().After(el.Pos()) {
				left = append(left, c)
				break
			}
			p.comments(c)
		}
		if el.Pos().Line() > p.line {
			p.newline(el.Pos())
			p.indent()
		} else if p.wantSpace {
			p.space()
		}
		// Either "[index]=value" or just "value".
		if p.wroteIndex(el.Index) {
			p.WriteByte('=')
		}
		p.word(el.Value)
		p.comments(left...)
	}
	if len(last) > 0 {
		p.comments(last...)
		p.flushComments()
	}
	p.decLevel()
}
800
// stmt prints a single statement. It writes, in order:
//   - an optional "!" negation;
//   - the command;
//   - any redirections the command did not already print;
//   - a trailing "&" or "|&", or a ";" that the source placed on a later
//     line.
//
// A redirection whose operator is on a later source line is moved there
// with a backslash continuation, one level deeper. For heredocs, only the
// delimiter word is printed here; the redirection is queued in
// p.pendingHdocs, and flushHeredocs writes its body later.
func (p *Printer) stmt(s *Stmt) {
	p.wroteSemi = false
	if s.Negated {
		p.spacedString("!", s.Pos())
	}
	var startRedirs int
	if s.Cmd != nil {
		// command may print some leading redirections itself.
		startRedirs = p.command(s.Cmd, s.Redirs)
	}
	p.incLevel()
	for _, r := range s.Redirs[startRedirs:] {
		if r.OpPos.Line() > p.line {
			p.bslashNewl()
		}
		if p.wantSpace {
			p.spacePad(r.Pos())
		}
		if r.N != nil {
			p.WriteString(r.N.Value)
		}
		p.WriteString(r.Op.String())
		// DplIn and DplOut never get a space after the operator, even
		// with SpaceRedirects.
		if p.spaceRedirects && (r.Op != DplIn && r.Op != DplOut) {
			p.space()
		} else {
			p.wantSpace = true
		}
		p.word(r.Word)
		if r.Op == Hdoc || r.Op == DashHdoc {
			p.pendingHdocs = append(p.pendingHdocs, r)
		}
	}
	switch {
	case s.Semicolon.IsValid() && s.Semicolon.Line() > p.line:
		// Keep a ';' that the source had on a later line.
		p.bslashNewl()
		p.WriteByte(';')
		p.wroteSemi = true
	case s.Background:
		if !p.minify {
			p.space()
		}
		p.WriteString("&")
	case s.Coprocess:
		if !p.minify {
			p.space()
		}
		p.WriteString("|&")
	}
	p.decLevel()
}
850
// command prints cmd and returns how many of the leading entries in redirs
// it already printed; the caller (stmt) prints the rest. Only *CallExpr
// prints redirections itself. Those appearing in the source before the
// second argument (e.g. "foo >out bar") are written right after the first
// word.
func (p *Printer) command(cmd Command, redirs []*Redirect) (startRedirs int) {
	p.spacePad(cmd.Pos())
	switch x := cmd.(type) {
	case *CallExpr:
		p.assigns(x.Assigns)
		if len(x.Args) <= 1 {
			p.wordJoin(x.Args)
			return 0
		}
		p.wordJoin(x.Args[:1])
		// NOTE(review): assumes redirs is ordered by source position.
		// The loop stops at the first heredoc so that stmt, which queues
		// heredocs in pendingHdocs, is the one to print it.
		for _, r := range redirs {
			if r.Pos().After(x.Args[1].Pos()) || r.Op == Hdoc || r.Op == DashHdoc {
				break
			}
			if p.wantSpace {
				p.spacePad(r.Pos())
			}
			if r.N != nil {
				p.WriteString(r.N.Value)
			}
			p.WriteString(r.Op.String())
			if p.spaceRedirects && (r.Op != DplIn && r.Op != DplOut) {
				p.space()
			} else {
				p.wantSpace = true
			}
			p.word(r.Word)
			startRedirs++
		}
		p.wordJoin(x.Args[1:])
	case *Block:
		p.WriteByte('{')
		p.wantSpace = true
		p.nestedStmts(x.StmtList, x.Rbrace)
		p.semiRsrv("}", x.Rbrace)
	case *IfClause:
		p.ifClause(x, false)
	case *Subshell:
		p.WriteByte('(')
		// "( (" needs the space, or it could be read as "((" arithmetic.
		p.wantSpace = len(x.Stmts) > 0 && startsWithLparen(x.Stmts[0])
		p.spacePad(x.StmtList.pos())
		p.nestedStmts(x.StmtList, x.Rparen)
		p.wantSpace = false
		p.spacePad(x.Rparen)
		p.rightParen(x.Rparen)
	case *WhileClause:
		if x.Until {
			p.spacedString("until", x.Pos())
		} else {
			p.spacedString("while", x.Pos())
		}
		p.nestedStmts(x.Cond, Pos{})
		p.semiOrNewl("do", x.DoPos)
		p.nestedStmts(x.Do, x.DonePos)
		p.semiRsrv("done", x.DonePos)
	case *ForClause:
		if x.Select {
			p.WriteString("select ")
		} else {
			p.WriteString("for ")
		}
		p.loop(x.Loop)
		p.semiOrNewl("do", x.DoPos)
		p.nestedStmts(x.Do, x.DonePos)
		p.semiRsrv("done", x.DonePos)
	case *BinaryCmd:
		p.stmt(x.X)
		if p.minify || x.Y.Pos().Line() <= p.line {
			// Both sides stay on one line: "X op Y".
			// leave p.nestedBinary untouched
			p.spacedToken(x.Op.String(), x.OpPos)
			p.line = x.Y.Pos().Line()
			p.stmt(x.Y)
			break
		}
		// Y starts on a later line. Only add an indentation level if we
		// are not already inside a multi-line binary chain; when Y is
		// itself a *BinaryCmd, p.nestedBinary tells it to reuse this one.
		indent := !p.nestedBinary
		if indent {
			p.incLevel()
		}
		if p.binNextLine {
			// The operator starts the next line: "X \<newline> op Y".
			// NOTE(review): no backslash-newline is written while
			// heredocs are pending; presumably their bodies must follow
			// this line directly. Confirm.
			if len(p.pendingHdocs) == 0 {
				p.bslashNewl()
			}
			p.spacedToken(x.Op.String(), x.OpPos)
			if len(x.Y.Comments) > 0 {
				// Y's comments go on their own line(s) after the operator.
				p.wantSpace = false
				p.newline(Pos{})
				p.indent()
				p.comments(x.Y.Comments...)
				p.newline(Pos{})
				p.indent()
			}
		} else {
			// The operator ends the line: "X op<newline> Y".
			p.spacedToken(x.Op.String(), x.OpPos)
			p.line = x.OpPos.Line()
			p.comments(x.Y.Comments...)
			p.newline(Pos{})
			p.indent()
		}
		p.line = x.Y.Pos().Line()
		_, p.nestedBinary = x.Y.Cmd.(*BinaryCmd)
		p.stmt(x.Y)
		if indent {
			p.decLevel()
		}
		p.nestedBinary = false
	case *FuncDecl:
		if x.RsrvWord {
			p.WriteString("function ")
		}
		p.WriteString(x.Name.Value)
		p.WriteString("()")
		if !p.minify {
			p.space()
		}
		p.line = x.Body.Pos().Line()
		p.comments(x.Body.Comments...)
		p.stmt(x.Body)
	case *CaseClause:
		p.WriteString("case ")
		p.word(x.Word)
		p.WriteString(" in")
		if p.swtCaseIndent {
			p.incLevel()
		}
		for i, ci := range x.Items {
			// Comments before the item are queued now; those from the
			// first one after the item's start are kept in last and
			// printed after the item's operator.
			var last []Comment
			for i, c := range ci.Comments {
				if c.Pos().After(ci.Pos()) {
					last = ci.Comments[i:]
					break
				}
				p.comments(c)
			}
			p.newlines(ci.Pos())
			p.casePatternJoin(ci.Patterns)
			p.WriteByte(')')
			p.wantSpace = !p.minify
			// sep: the operator goes on its own line if the body has
			// several statements, starts on a later line, or the
			// operator came after the body's last line in the source.
			sep := len(ci.Stmts) > 1 || ci.StmtList.pos().Line() > p.line ||
				(!ci.StmtList.empty() && ci.OpPos.Line() > ci.StmtList.end().Line())
			p.nestedStmts(ci.StmtList, ci.OpPos)
			// Indent the operator (e.g. ";;") like the item's body.
			p.level++
			// Minified output omits the last item's operator.
			if !p.minify || i != len(x.Items)-1 {
				if sep {
					p.newlines(ci.OpPos)
					p.wantNewline = true
				}
				p.spacedToken(ci.Op.String(), ci.OpPos)
				// avoid ; directly after tokens like ;;
				p.wroteSemi = true
			}
			p.comments(last...)
			p.flushComments()
			p.level--
		}
		p.comments(x.Last...)
		if p.swtCaseIndent {
			p.flushComments()
			p.decLevel()
		}
		p.semiRsrv("esac", x.Esac)
	case *ArithmCmd:
		p.WriteString("((")
		if x.Unsigned {
			p.WriteString("# ")
		}
		p.arithmExpr(x.X, false, false)
		p.WriteString("))")
	case *TestClause:
		p.WriteString("[[ ")
		p.testExpr(x.X)
		p.spacedString("]]", x.Right)
	case *DeclClause:
		p.spacedString(x.Variant.Value, x.Pos())
		for _, w := range x.Opts {
			p.space()
			p.word(w)
		}
		p.assigns(x.Assigns)
	case *TimeClause:
		p.spacedString("time", x.Pos())
		if x.PosixFormat {
			p.spacedString("-p", x.Pos())
		}
		if x.Stmt != nil {
			p.stmt(x.Stmt)
		}
	case *CoprocClause:
		p.spacedString("coproc", x.Pos())
		if x.Name != nil {
			p.space()
			p.WriteString(x.Name.Value)
		}
		p.space()
		p.stmt(x.Stmt)
	case *LetClause:
		p.spacedString("let", x.Pos())
		// Each expression is printed compactly, since each one must stay
		// a single word (argument) to let.
		for _, n := range x.Exprs {
			p.space()
			p.arithmExpr(n, true, false)
		}
	}
	return startRedirs
}
1054
// ifClause prints an if clause. When elif is true, the caller has already
// written the "elif" keyword, so only the condition onwards is printed.
// An else branch that is just another if clause (ic.FollowedByElif) is
// printed as "elif" by recursing with elif set, so only the innermost
// clause prints "fi".
func (p *Printer) ifClause(ic *IfClause, elif bool) {
	if !elif {
		p.spacedString("if", ic.Pos())
	}
	p.nestedStmts(ic.Cond, Pos{})
	p.semiOrNewl("then", ic.ThenPos)
	p.nestedStmts(ic.Then, ic.bodyEndPos())

	// ElseComments before the else/elif keyword are queued now. The first
	// one after it is kept in left, to be printed after "else".
	var left []Comment
	for _, c := range ic.ElseComments {
		if c.Pos().After(ic.ElsePos) {
			left = append(left, c)
			break
		}
		p.comments(c)
	}
	if ic.FollowedByElif() {
		// NOTE(review): left is never printed on this path; confirm that
		// no ElseComments can follow ElsePos when the else is an elif.
		s := ic.Else.Stmts[0]
		p.comments(s.Comments...)
		p.semiRsrv("elif", ic.ElsePos)
		p.ifClause(s.Cmd.(*IfClause), true)
		return
	}
	if !ic.Else.empty() {
		p.semiRsrv("else", ic.ElsePos)
		p.comments(left...)
		p.nestedStmts(ic.Else, ic.FiPos)
	} else if ic.ElsePos.IsValid() {
		// An "else" with an empty body is not printed; just move past it.
		p.line = ic.ElsePos.Line()
	}
	p.comments(ic.FiComments...)
	p.semiRsrv("fi", ic.FiPos)
}
1088
1089func startsWithLparen(s *Stmt) bool {
1090	switch x := s.Cmd.(type) {
1091	case *Subshell:
1092		return true
1093	case *BinaryCmd:
1094		return startsWithLparen(x.X)
1095	}
1096	return false
1097}
1098
1099func (p *Printer) hasInline(s *Stmt) bool {
1100	for _, c := range s.Comments {
1101		if c.Pos().Line() == s.End().Line() {
1102			return true
1103		}
1104	}
1105	return false
1106}
1107
// stmtList prints the statements in sl together with their comments and
// sl's trailing comments (sl.Last).
//
// Each statement's comments are handled in three groups:
//   - comments before the statement are queued first;
//   - comments starting inside it (midComs) are queued just before it is
//     printed;
//   - the first comment ending after it (endComs) is queued afterwards,
//     so it is written inline.
//
// Inline comments on a run of consecutive lines are aligned. The widest
// statement in the run, as measured by stmtCols, sets the column, and
// commentPadding pads the other comments out to it.
func (p *Printer) stmtList(sl StmtList) {
	// sep records whether the list starts on its own line, either because
	// a newline was already wanted or because it starts on a later line.
	sep := p.wantNewline ||
		(len(sl.Stmts) > 0 && sl.Stmts[0].Pos().Line() > p.line)
	inlineIndent := 0
	lastIndentedLine := uint(0)
	for i, s := range sl.Stmts {
		pos := s.Pos()
		var midComs, endComs []Comment
		for _, c := range s.Comments {
			if c.End().After(s.End()) {
				endComs = append(endComs, c)
				break
			}
			if c.Pos().After(s.Pos()) {
				midComs = append(midComs, c)
				continue
			}
			p.comments(c)
		}
		// When minifying, the separating newline is only written if a
		// space was pending.
		if !p.minify || p.wantSpace {
			p.newlines(pos)
		}
		p.line = pos.Line()
		if !p.hasInline(s) {
			// No inline comment: end any alignment run and print plainly.
			inlineIndent = 0
			p.commentPadding = 0
			p.comments(midComs...)
			p.stmt(s)
			p.wantNewline = true
			continue
		}
		p.comments(midComs...)
		p.stmt(s)
		// A gap of more than one line also ends the alignment run.
		if s.Pos().Line() > lastIndentedLine+1 {
			inlineIndent = 0
		}
		if inlineIndent == 0 {
			// Starting a run: measure every following statement that
			// also has an inline comment.
			for _, s2 := range sl.Stmts[i:] {
				if !p.hasInline(s2) {
					break
				}
				if l := p.stmtCols(s2); l > inlineIndent {
					inlineIndent = l
				}
			}
		}
		if inlineIndent > 0 {
			// stmtCols is -1 for multi-line statements; in that case
			// commentPadding keeps its previous value.
			if l := p.stmtCols(s); l > 0 {
				p.commentPadding = uint(inlineIndent - l)
			}
			lastIndentedLine = p.line
		}
		p.comments(endComs...)
		p.wantNewline = true
	}
	// A lone statement that did not start on a new line does not force
	// whatever follows (e.g. a closing "}") onto a new line.
	if len(sl.Stmts) == 1 && !sep {
		p.wantNewline = false
	}
	p.comments(sl.Last...)
}
1168
// byteCounter is a bufWriter that discards its output and only counts the
// bytes written to it. Once a newline is written it sticks at -1, meaning
// "spans more than one line", until Reset. stmtCols uses it to measure how
// wide a statement is when printed.
type byteCounter int

func (c *byteCounter) WriteByte(b byte) error {
	switch {
	case *c < 0:
	case b == '\n':
		*c = -1
	default:
		*c++
	}
	return nil
}

// Write counts p like WriteString does, without converting it to a string.
// It reports len(p) bytes written, as io.Writer requires when err is nil.
func (c *byteCounter) Write(p []byte) (int, error) {
	switch {
	case *c < 0:
	case bytes.IndexByte(p, '\n') >= 0:
		*c = -1
	default:
		*c += byteCounter(len(p))
	}
	return len(p), nil
}

// WriteString counts s and reports len(s) bytes written.
func (c *byteCounter) WriteString(s string) (int, error) {
	switch {
	case *c < 0:
	case strings.IndexByte(s, '\n') >= 0:
		*c = -1
	default:
		*c += byteCounter(len(s))
	}
	return len(s), nil
}

func (c *byteCounter) Reset(io.Writer) { *c = 0 }
func (c *byteCounter) Flush() error    { return nil }
1196
1197// extraIndenter ensures that all lines in a '<<-' heredoc body have at least
1198// baseIndent leading tabs. Those that had more tab indentation than the first
1199// heredoc line will keep that relative indentation.
1200type extraIndenter struct {
1201	bufWriter
1202	baseIndent int
1203
1204	firstIndent int
1205	firstChange int
1206	curLine     []byte
1207}
1208
1209func (e *extraIndenter) WriteByte(b byte) error {
1210	e.curLine = append(e.curLine, b)
1211	if b != '\n' {
1212		return nil
1213	}
1214	trimmed := bytes.TrimLeft(e.curLine, "\t")
1215	lineIndent := len(e.curLine) - len(trimmed)
1216	if e.firstIndent < 0 {
1217		e.firstIndent = lineIndent
1218		e.firstChange = e.baseIndent - lineIndent
1219		lineIndent = e.baseIndent
1220	} else {
1221		if lineIndent < e.firstIndent {
1222			lineIndent = e.firstIndent
1223		} else {
1224			lineIndent += e.firstChange
1225		}
1226	}
1227	for i := 0; i < lineIndent; i++ {
1228		e.bufWriter.WriteByte('\t')
1229	}
1230	e.bufWriter.Write(trimmed)
1231	e.curLine = e.curLine[:0]
1232	return nil
1233}
1234
1235func (e *extraIndenter) WriteString(s string) (int, error) {
1236	for i := 0; i < len(s); i++ {
1237		e.WriteByte(s[i])
1238	}
1239	return len(s), nil
1240}
1241
1242// stmtCols reports the length that s will take when formatted in a
1243// single line. If it will span multiple lines, stmtCols will return -1.
1244func (p *Printer) stmtCols(s *Stmt) int {
1245	if p.lenPrinter == nil {
1246		return -1 // stmtCols call within stmtCols, bail
1247	}
1248	*p.lenPrinter = Printer{
1249		bufWriter: &p.lenCounter,
1250		line:      s.Pos().Line(),
1251	}
1252	p.lenPrinter.bufWriter.Reset(nil)
1253	p.lenPrinter.stmt(s)
1254	return int(p.lenCounter)
1255}
1256
// nestedStmts prints sl as the body of a compound command, one indentation
// level deeper. closing is the position of the token that ends the body
// (such as "}" or "done"), or Pos{} when there is none, as for a while
// loop's condition. It first decides whether the body must start on a new
// line. When closing is valid, pending comments are flushed before
// returning, so they come before the closing token.
func (p *Printer) nestedStmts(sl StmtList, closing Pos) {
	p.incLevel()
	switch {
	case len(sl.Stmts) > 1:
		// Force a newline if we find:
		//     { stmt; stmt; }
		p.wantNewline = true
	case closing.Line() > p.line && len(sl.Stmts) > 0 &&
		sl.end().Line() < closing.Line():
		// Force a newline if we find:
		//     { stmt
		//     }
		p.wantNewline = true
	case len(p.pendingComments) > 0 && len(sl.Stmts) > 0:
		// Force a newline if we find:
		//     for i in a b # stmt
		//     do foo; done
		p.wantNewline = true
	}
	p.stmtList(sl)
	if closing.IsValid() {
		p.flushComments()
	}
	p.decLevel()
}
1282
1283func (p *Printer) assigns(assigns []*Assign) {
1284	p.incLevel()
1285	for _, a := range assigns {
1286		if a.Pos().Line() > p.line {
1287			p.bslashNewl()
1288		} else {
1289			p.spacePad(a.Pos())
1290		}
1291		if a.Name != nil {
1292			p.WriteString(a.Name.Value)
1293			p.wroteIndex(a.Index)
1294			if a.Append {
1295				p.WriteByte('+')
1296			}
1297			if !a.Naked {
1298				p.WriteByte('=')
1299			}
1300		}
1301		if a.Value != nil {
1302			p.word(a.Value)
1303		} else if a.Array != nil {
1304			p.wantSpace = false
1305			p.WriteByte('(')
1306			p.elemJoin(a.Array.Elems, a.Array.Last)
1307			p.rightParen(a.Array.Rparen)
1308		}
1309		p.wantSpace = true
1310	}
1311	p.decLevel()
1312}
1313