1// Copyright 2016 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package pipeline
6
7import (
8	"bytes"
9	"errors"
10	"fmt"
11	"go/ast"
12	"go/constant"
13	"go/format"
14	"go/token"
15	"go/types"
16	"path/filepath"
17	"strings"
18	"unicode"
19	"unicode/utf8"
20
21	fmtparser "golang.org/x/text/internal/format"
22	"golang.org/x/tools/go/callgraph"
23	"golang.org/x/tools/go/callgraph/cha"
24	"golang.org/x/tools/go/loader"
25	"golang.org/x/tools/go/ssa"
26	"golang.org/x/tools/go/ssa/ssautil"
27)
28
// debug switches on the tracing output written by extracter.debug and the
// duplicate-format-argument warning printed in visitFormats.
const debug = false
30
31// TODO:
32// - merge information into existing files
33// - handle different file formats (PO, XLIFF)
34// - handle features (gender, plural)
35// - message rewriting
36
37// - `msg:"etc"` tags
38
39// Extract extracts all strings form the package defined in Config.
40func Extract(c *Config) (*State, error) {
41	x, err := newExtracter(c)
42	if err != nil {
43		return nil, wrap(err, "")
44	}
45
46	if err := x.seedEndpoints(); err != nil {
47		return nil, err
48	}
49	x.extractMessages()
50
51	return &State{
52		Config:  *c,
53		program: x.iprog,
54		Extracted: Messages{
55			Language: c.SourceLanguage,
56			Messages: x.messages,
57		},
58	}, nil
59}
60
// extracter holds the state of one extraction run.
//
// conf is the loader configuration handed to loadPackages; its file set is
// used for printing nodes and positions. iprog is the loaded program, prog the
// SSA program created from it and callGraph the call graph computed from prog
// (see newExtracter). globals and funcs are keyed by source position and are
// filled while tracing values; messages accumulates the extracted messages.
type extracter struct {
	conf      loader.Config
	iprog     *loader.Program
	prog      *ssa.Program
	callGraph *callgraph.Graph

	// Calls and other expressions to collect.
	globals  map[token.Pos]*constData
	funcs    map[token.Pos]*callData
	messages []Message
}
72
73func newExtracter(c *Config) (x *extracter, err error) {
74	x = &extracter{
75		conf:    loader.Config{},
76		globals: map[token.Pos]*constData{},
77		funcs:   map[token.Pos]*callData{},
78	}
79
80	x.iprog, err = loadPackages(&x.conf, c.Packages)
81	if err != nil {
82		return nil, wrap(err, "")
83	}
84
85	x.prog = ssautil.CreateProgram(x.iprog, ssa.GlobalDebug|ssa.BareInits)
86	x.prog.Build()
87
88	x.callGraph = cha.CallGraph(x.prog)
89
90	return x, nil
91}
92
93func (x *extracter) globalData(pos token.Pos) *constData {
94	cd := x.globals[pos]
95	if cd == nil {
96		cd = &constData{}
97		x.globals[pos] = cd
98	}
99	return cd
100}
101
102func (x *extracter) seedEndpoints() error {
103	pkgInfo := x.iprog.Package("golang.org/x/text/message")
104	if pkgInfo == nil {
105		return errors.New("pipeline: golang.org/x/text/message is not imported")
106	}
107	pkg := x.prog.Package(pkgInfo.Pkg)
108	typ := types.NewPointer(pkg.Type("Printer").Type())
109
110	x.processGlobalVars()
111
112	x.handleFunc(x.prog.LookupMethod(typ, pkg.Pkg, "Printf"), &callData{
113		formatPos: 1,
114		argPos:    2,
115		isMethod:  true,
116	})
117	x.handleFunc(x.prog.LookupMethod(typ, pkg.Pkg, "Sprintf"), &callData{
118		formatPos: 1,
119		argPos:    2,
120		isMethod:  true,
121	})
122	x.handleFunc(x.prog.LookupMethod(typ, pkg.Pkg, "Fprintf"), &callData{
123		formatPos: 2,
124		argPos:    3,
125		isMethod:  true,
126	})
127	return nil
128}
129
130// processGlobalVars finds string constants that are assigned to global
131// variables.
132func (x *extracter) processGlobalVars() {
133	for _, p := range x.prog.AllPackages() {
134		m, ok := p.Members["init"]
135		if !ok {
136			continue
137		}
138		for _, b := range m.(*ssa.Function).Blocks {
139			for _, i := range b.Instrs {
140				s, ok := i.(*ssa.Store)
141				if !ok {
142					continue
143				}
144				a, ok := s.Addr.(*ssa.Global)
145				if !ok {
146					continue
147				}
148				t := a.Type()
149				for {
150					p, ok := t.(*types.Pointer)
151					if !ok {
152						break
153					}
154					t = p.Elem()
155				}
156				if b, ok := t.(*types.Basic); !ok || b.Kind() != types.String {
157					continue
158				}
159				x.visitInit(a, s.Val)
160			}
161		}
162	}
163}
164
// constData collects what is known about a global variable: the constants
// assigned to it (values), the positions of other globals assigned to it
// (others) and, if the global is used as a format string, the call that uses
// it (call).
type constData struct {
	call   *callData // to provide a signature for the constants
	values []constVal
	others []token.Pos // Assigned to other global data.
}
170
171func (d *constData) visit(x *extracter, f func(c constant.Value)) {
172	for _, v := range d.values {
173		f(v.value)
174	}
175	for _, p := range d.others {
176		if od, ok := x.globals[p]; ok {
177			od.visit(x, f)
178		}
179	}
180}
181
// constVal is a constant together with the position recorded for it when it
// was collected in visitInit.
type constVal struct {
	value constant.Value
	pos   token.Pos
}
186
// callData describes either a call site or the signature of a function that
// receives a format string.
//
// For call sites created in handleFunc, call is the SSA instruction, callee
// describes the called function, expr is the matching AST call (filled in by
// extractMessages) and formats holds the constant format strings found.
// formatPos and argPos are parameter positions of the described function;
// when isMethod is set they count the receiver, which callFormatPos and
// callArgsStart subtract again. Call sites start out with both positions set
// to -1.
type callData struct {
	call    ssa.CallInstruction
	expr    *ast.CallExpr
	formats []constant.Value

	callee    *callData
	isMethod  bool
	formatPos int
	argPos    int   // varargs at this position in the call
	argTypes  []int // arguments extractable from this position
}
198
199func (c *callData) callFormatPos() int {
200	c = c.callee
201	if c.isMethod {
202		return c.formatPos - 1
203	}
204	return c.formatPos
205}
206
207func (c *callData) callArgsStart() int {
208	c = c.callee
209	if c.isMethod {
210		return c.argPos - 1
211	}
212	return c.argPos
213}
214
215func (c *callData) Pos() token.Pos      { return c.call.Pos() }
216func (c *callData) Pkg() *types.Package { return c.call.Parent().Pkg.Pkg }
217
218func (x *extracter) handleFunc(f *ssa.Function, fd *callData) {
219	for _, e := range x.callGraph.Nodes[f].In {
220		if e.Pos() == 0 {
221			continue
222		}
223
224		call := e.Site
225		caller := x.funcs[call.Pos()]
226		if caller != nil {
227			// TODO: theoretically a format string could be passed to multiple
228			// arguments of a function. Support this eventually.
229			continue
230		}
231		x.debug(call, "CALL", f.String())
232
233		caller = &callData{
234			call:      call,
235			callee:    fd,
236			formatPos: -1,
237			argPos:    -1,
238		}
239		// Offset by one if we are invoking an interface method.
240		offset := 0
241		if call.Common().IsInvoke() {
242			offset = -1
243		}
244		x.funcs[call.Pos()] = caller
245		if fd.argPos >= 0 {
246			x.visitArgs(caller, call.Common().Args[fd.argPos+offset])
247		}
248		x.visitFormats(caller, call.Common().Args[fd.formatPos+offset])
249	}
250}
251
// posser is the subset of SSA value and instruction methods that
// extracter.debug needs to report where something is located.
type posser interface {
	Pos() token.Pos
	Parent() *ssa.Function
}
256
257func (x *extracter) debug(v posser, header string, args ...interface{}) {
258	if debug {
259		pos := ""
260		if p := v.Parent(); p != nil {
261			pos = posString(&x.conf, p.Package().Pkg, v.Pos())
262		}
263		if header != "CALL" && header != "INSERT" {
264			header = "  " + header
265		}
266		fmt.Printf("%-32s%-10s%-15T ", pos+fmt.Sprintf("@%d", v.Pos()), header, v)
267		for _, a := range args {
268			fmt.Printf(" %v", a)
269		}
270		fmt.Println()
271	}
272}
273
274// visitInit evaluates and collects values assigned to global variables in an
275// init function.
276func (x *extracter) visitInit(global *ssa.Global, v ssa.Value) {
277	if v == nil {
278		return
279	}
280	x.debug(v, "GLOBAL", v)
281
282	switch v := v.(type) {
283	case *ssa.Phi:
284		for _, e := range v.Edges {
285			x.visitInit(global, e)
286		}
287
288	case *ssa.Const:
289		// Only record strings with letters.
290		if str := constant.StringVal(v.Value); isMsg(str) {
291			cd := x.globalData(global.Pos())
292			cd.values = append(cd.values, constVal{v.Value, v.Pos()})
293		}
294		// TODO: handle %m-directive.
295
296	case *ssa.Global:
297		cd := x.globalData(global.Pos())
298		cd.others = append(cd.others, v.Pos())
299
300	case *ssa.FieldAddr, *ssa.Field:
301		// TODO: mark field index v.Field of v.X.Type() for extraction. extract
302		// an example args as to give parameters for the translator.
303
304	case *ssa.Slice:
305		if v.Low == nil && v.High == nil && v.Max == nil {
306			x.visitInit(global, v.X)
307		}
308
309	case *ssa.Alloc:
310		if ref := v.Referrers(); ref == nil {
311			for _, r := range *ref {
312				values := []ssa.Value{}
313				for _, o := range r.Operands(nil) {
314					if o == nil || *o == v {
315						continue
316					}
317					values = append(values, *o)
318				}
319				// TODO: return something different if we care about multiple
320				// values as well.
321				if len(values) == 1 {
322					x.visitInit(global, values[0])
323				}
324			}
325		}
326
327	case ssa.Instruction:
328		rands := v.Operands(nil)
329		if len(rands) == 1 && rands[0] != nil {
330			x.visitInit(global, *rands[0])
331		}
332	}
333	return
334}
335
336// visitFormats finds the original source of the value. The returned index is
337// position of the argument if originated from a function argument or -1
338// otherwise.
339func (x *extracter) visitFormats(call *callData, v ssa.Value) {
340	if v == nil {
341		return
342	}
343	x.debug(v, "VALUE", v)
344
345	switch v := v.(type) {
346	case *ssa.Phi:
347		for _, e := range v.Edges {
348			x.visitFormats(call, e)
349		}
350
351	case *ssa.Const:
352		// Only record strings with letters.
353		if isMsg(constant.StringVal(v.Value)) {
354			x.debug(call.call, "FORMAT", v.Value.ExactString())
355			call.formats = append(call.formats, v.Value)
356		}
357		// TODO: handle %m-directive.
358
359	case *ssa.Global:
360		x.globalData(v.Pos()).call = call
361
362	case *ssa.FieldAddr, *ssa.Field:
363		// TODO: mark field index v.Field of v.X.Type() for extraction. extract
364		// an example args as to give parameters for the translator.
365
366	case *ssa.Slice:
367		if v.Low == nil && v.High == nil && v.Max == nil {
368			x.visitFormats(call, v.X)
369		}
370
371	case *ssa.Parameter:
372		// TODO: handle the function for the index parameter.
373		f := v.Parent()
374		for i, p := range f.Params {
375			if p == v {
376				if call.formatPos < 0 {
377					call.formatPos = i
378					// TODO: is there a better way to detect this is calling
379					// a method rather than a function?
380					call.isMethod = len(f.Params) > f.Signature.Params().Len()
381					x.handleFunc(v.Parent(), call)
382				} else if debug && i != call.formatPos {
383					// TODO: support this.
384					fmt.Printf("WARNING:%s: format string passed to arg %d and %d\n",
385						posString(&x.conf, call.Pkg(), call.Pos()),
386						call.formatPos, i)
387				}
388			}
389		}
390
391	case *ssa.Alloc:
392		if ref := v.Referrers(); ref == nil {
393			for _, r := range *ref {
394				values := []ssa.Value{}
395				for _, o := range r.Operands(nil) {
396					if o == nil || *o == v {
397						continue
398					}
399					values = append(values, *o)
400				}
401				// TODO: return something different if we care about multiple
402				// values as well.
403				if len(values) == 1 {
404					x.visitFormats(call, values[0])
405				}
406			}
407		}
408
409		// TODO:
410	// case *ssa.Index:
411	// 	// Get all values in the array if applicable
412	// case *ssa.IndexAddr:
413	// 	// Get all values in the slice or *array if applicable.
414	// case *ssa.Lookup:
415	// 	// Get all values in the map if applicable.
416
417	case *ssa.FreeVar:
418		// TODO: find the link between free variables and parameters:
419		//
420		// func freeVar(p *message.Printer, str string) {
421		// 	fn := func(p *message.Printer) {
422		// 		p.Printf(str)
423		// 	}
424		// 	fn(p)
425		// }
426
427	case *ssa.Call:
428
429	case ssa.Instruction:
430		rands := v.Operands(nil)
431		if len(rands) == 1 && rands[0] != nil {
432			x.visitFormats(call, *rands[0])
433		}
434	}
435}
436
437// Note: a function may have an argument marked as both format and passthrough.
438
439// visitArgs collects information on arguments. For wrapped functions it will
440// just determine the position of the variable args slice.
441func (x *extracter) visitArgs(fd *callData, v ssa.Value) {
442	if v == nil {
443		return
444	}
445	x.debug(v, "ARGV", v)
446	switch v := v.(type) {
447
448	case *ssa.Slice:
449		if v.Low == nil && v.High == nil && v.Max == nil {
450			x.visitArgs(fd, v.X)
451		}
452
453	case *ssa.Parameter:
454		// TODO: handle the function for the index parameter.
455		f := v.Parent()
456		for i, p := range f.Params {
457			if p == v {
458				fd.argPos = i
459			}
460		}
461
462	case *ssa.Alloc:
463		if ref := v.Referrers(); ref == nil {
464			for _, r := range *ref {
465				values := []ssa.Value{}
466				for _, o := range r.Operands(nil) {
467					if o == nil || *o == v {
468						continue
469					}
470					values = append(values, *o)
471				}
472				// TODO: return something different if we care about
473				// multiple values as well.
474				if len(values) == 1 {
475					x.visitArgs(fd, values[0])
476				}
477			}
478		}
479
480	case ssa.Instruction:
481		rands := v.Operands(nil)
482		if len(rands) == 1 && rands[0] != nil {
483			x.visitArgs(fd, *rands[0])
484		}
485	}
486}
487
488// print returns Go syntax for the specified node.
489func (x *extracter) print(n ast.Node) string {
490	var buf bytes.Buffer
491	format.Node(&buf, x.conf.Fset, n)
492	return buf.String()
493}
494
// packageExtracter bundles one source file (f) with the extracter (x), the
// type information of the package the file belongs to (info) and the comment
// map built for the file (cmap).
//
// extractMessages constructs it with an unkeyed literal, so the field order
// must not change.
type packageExtracter struct {
	f    *ast.File
	x    *extracter
	info *loader.PackageInfo
	cmap ast.CommentMap
}
501
502func (px packageExtracter) getComment(n ast.Node) string {
503	cs := px.cmap.Filter(n).Comments()
504	if len(cs) > 0 {
505		return strings.TrimSpace(cs[0].Text())
506	}
507	return ""
508}
509
510func (x *extracter) extractMessages() {
511	prog := x.iprog
512	files := []packageExtracter{}
513	for _, info := range x.iprog.AllPackages {
514		for _, f := range info.Files {
515			// Associate comments with nodes.
516			px := packageExtracter{
517				f, x, info,
518				ast.NewCommentMap(prog.Fset, f, f.Comments),
519			}
520			files = append(files, px)
521		}
522	}
523	for _, px := range files {
524		ast.Inspect(px.f, func(n ast.Node) bool {
525			switch v := n.(type) {
526			case *ast.CallExpr:
527				if d := x.funcs[v.Lparen]; d != nil {
528					d.expr = v
529				}
530			}
531			return true
532		})
533	}
534	for _, px := range files {
535		ast.Inspect(px.f, func(n ast.Node) bool {
536			switch v := n.(type) {
537			case *ast.CallExpr:
538				return px.handleCall(v)
539			case *ast.ValueSpec:
540				return px.handleGlobal(v)
541			}
542			return true
543		})
544	}
545}
546
547func (px packageExtracter) handleGlobal(spec *ast.ValueSpec) bool {
548	comment := px.getComment(spec)
549
550	for _, ident := range spec.Names {
551		data, ok := px.x.globals[ident.Pos()]
552		if !ok {
553			continue
554		}
555		name := ident.Name
556		var arguments []argument
557		if data.call != nil {
558			arguments = px.getArguments(data.call)
559		} else if !strings.HasPrefix(name, "msg") && !strings.HasPrefix(name, "Msg") {
560			continue
561		}
562		data.visit(px.x, func(c constant.Value) {
563			px.addMessage(spec.Pos(), []string{name}, c, comment, arguments)
564		})
565	}
566
567	return true
568}
569
570func (px packageExtracter) handleCall(call *ast.CallExpr) bool {
571	x := px.x
572	data := x.funcs[call.Lparen]
573	if data == nil || len(data.formats) == 0 {
574		return true
575	}
576	if data.expr != call {
577		panic("invariant `data.call != call` failed")
578	}
579	x.debug(data.call, "INSERT", data.formats)
580
581	argn := data.callFormatPos()
582	if argn >= len(call.Args) {
583		return true
584	}
585	format := call.Args[argn]
586
587	arguments := px.getArguments(data)
588
589	comment := ""
590	key := []string{}
591	if ident, ok := format.(*ast.Ident); ok {
592		key = append(key, ident.Name)
593		if v, ok := ident.Obj.Decl.(*ast.ValueSpec); ok && v.Comment != nil {
594			// TODO: get comment above ValueSpec as well
595			comment = v.Comment.Text()
596		}
597	}
598	if c := px.getComment(call.Args[0]); c != "" {
599		comment = c
600	}
601
602	formats := data.formats
603	for _, c := range formats {
604		px.addMessage(call.Lparen, key, c, comment, arguments)
605	}
606	return true
607}
608
609func (px packageExtracter) getArguments(data *callData) []argument {
610	arguments := []argument{}
611	x := px.x
612	info := px.info
613	if data.callArgsStart() >= 0 {
614		args := data.expr.Args[data.callArgsStart():]
615		for i, arg := range args {
616			expr := x.print(arg)
617			val := ""
618			if v := info.Types[arg].Value; v != nil {
619				val = v.ExactString()
620				switch arg.(type) {
621				case *ast.BinaryExpr, *ast.UnaryExpr:
622					expr = val
623				}
624			}
625			arguments = append(arguments, argument{
626				ArgNum:         i + 1,
627				Type:           info.Types[arg].Type.String(),
628				UnderlyingType: info.Types[arg].Type.Underlying().String(),
629				Expr:           expr,
630				Value:          val,
631				Comment:        px.getComment(arg),
632				Position:       posString(&x.conf, info.Pkg, arg.Pos()),
633				// TODO report whether it implements
634				// interfaces plural.Interface,
635				// gender.Interface.
636			})
637		}
638	}
639	return arguments
640}
641
642func (px packageExtracter) addMessage(
643	pos token.Pos,
644	key []string,
645	c constant.Value,
646	comment string,
647	arguments []argument) {
648	x := px.x
649	fmtMsg := constant.StringVal(c)
650
651	ph := placeholders{index: map[string]string{}}
652
653	trimmed, _, _ := trimWS(fmtMsg)
654
655	p := fmtparser.Parser{}
656	simArgs := make([]interface{}, len(arguments))
657	for i, v := range arguments {
658		simArgs[i] = v
659	}
660	msg := ""
661	p.Reset(simArgs)
662	for p.SetFormat(trimmed); p.Scan(); {
663		name := ""
664		var arg *argument
665		switch p.Status {
666		case fmtparser.StatusText:
667			msg += p.Text()
668			continue
669		case fmtparser.StatusSubstitution,
670			fmtparser.StatusBadWidthSubstitution,
671			fmtparser.StatusBadPrecSubstitution:
672			arguments[p.ArgNum-1].used = true
673			arg = &arguments[p.ArgNum-1]
674			name = getID(arg)
675		case fmtparser.StatusBadArgNum, fmtparser.StatusMissingArg:
676			arg = &argument{
677				ArgNum:   p.ArgNum,
678				Position: posString(&x.conf, px.info.Pkg, pos),
679			}
680			name, arg.UnderlyingType = verbToPlaceholder(p.Text(), p.ArgNum)
681		}
682		sub := p.Text()
683		if !p.HasIndex {
684			r, sz := utf8.DecodeLastRuneInString(sub)
685			sub = fmt.Sprintf("%s[%d]%c", sub[:len(sub)-sz], p.ArgNum, r)
686		}
687		msg += fmt.Sprintf("{%s}", ph.addArg(arg, name, sub))
688	}
689	key = append(key, msg)
690
691	// Add additional Placeholders that can be used in translations
692	// that are not present in the string.
693	for _, arg := range arguments {
694		if arg.used {
695			continue
696		}
697		ph.addArg(&arg, getID(&arg), fmt.Sprintf("%%[%d]v", arg.ArgNum))
698	}
699
700	x.messages = append(x.messages, Message{
701		ID:      key,
702		Key:     fmtMsg,
703		Message: Text{Msg: msg},
704		// TODO(fix): this doesn't get the before comment.
705		Comment:      comment,
706		Placeholders: ph.slice,
707		Position:     posString(&x.conf, px.info.Pkg, pos),
708	})
709}
710
711func posString(conf *loader.Config, pkg *types.Package, pos token.Pos) string {
712	p := conf.Fset.Position(pos)
713	file := fmt.Sprintf("%s:%d:%d", filepath.Base(p.Filename), p.Line, p.Column)
714	return filepath.Join(pkg.Path(), file)
715}
716
717func getID(arg *argument) string {
718	s := getLastComponent(arg.Expr)
719	s = strip(s)
720	s = strings.Replace(s, " ", "", -1)
721	// For small variable names, use user-defined types for more info.
722	if len(s) <= 2 && arg.UnderlyingType != arg.Type {
723		s = getLastComponent(arg.Type)
724	}
725	return strings.Title(s)
726}
727
// strip is a dirty hack to convert function calls to placeholder IDs.
//
// Whitespace and '-' are replaced by '_'; every other rune that is not a
// letter, mark or number is dropped. Afterwards a leading "Get" or "get" is
// removed if it is followed by a rune that is neither a lowercase letter nor a
// mark, so "GetName" and "getName" become "Name" while "Getaway" is kept.
func strip(s string) string {
	s = strings.Map(func(r rune) rune {
		if unicode.IsSpace(r) || r == '-' {
			return '_'
		}
		if !unicode.In(r, unicode.Letter, unicode.Mark, unicode.Number) {
			return -1
		}
		return r
	}, s)
	// Strip "Get" from getter functions.
	const prefixLen = len("get")
	if len(s) > prefixLen && (strings.HasPrefix(s, "Get") || strings.HasPrefix(s, "get")) {
		// Look at the rune that follows the prefix, not at the prefix itself,
		// to tell a getter ("GetName") from an ordinary word ("Getaway").
		r, _ := utf8.DecodeRuneInString(s[prefixLen:])
		if !unicode.In(r, unicode.Ll, unicode.M) { // not lower or mark
			s = s[prefixLen:]
		}
	}
	return s
}
750
// verbToPlaceholder picks a placeholder name and an underlying type from the
// verb, i.e. the last rune, of the substitution sub. It is meant as a
// fallback for when no type information is available for the argument.
// Verbs without a dedicated name are called "Arg_<pos>".
func verbToPlaceholder(sub string, pos int) (name, underlying string) {
	verb, _ := utf8.DecodeLastRuneInString(sub)
	switch verb {
	case 'd':
		return "Integer", "int"
	case 'e', 'f', 'g':
		return "Number", "float64"
	case 'm':
		return "Message", "string"
	}

	name = fmt.Sprintf("Arg_%d", pos)
	if verb == 's' || verb == 'q' {
		return name, "string"
	}
	return name, "interface{}"
}
774
// placeholders collects the placeholders of one message. index maps each
// placeholder ID handed out so far to its substitution string; slice holds
// the placeholders in the order they were added.
type placeholders struct {
	index map[string]string
	slice []Placeholder
}
779
780func (p *placeholders) addArg(arg *argument, name, sub string) (id string) {
781	id = name
782	alt, ok := p.index[id]
783	for i := 1; ok && alt != sub; i++ {
784		id = fmt.Sprintf("%s_%d", name, i)
785		alt, ok = p.index[id]
786	}
787	p.index[id] = sub
788	p.slice = append(p.slice, Placeholder{
789		ID:             id,
790		String:         sub,
791		Type:           arg.Type,
792		UnderlyingType: arg.UnderlyingType,
793		ArgNum:         arg.ArgNum,
794		Expr:           arg.Expr,
795		Comment:        arg.Comment,
796	})
797	return id
798}
799
// getLastComponent returns the part of s after its last '.', or all of s if
// it contains no '.'.
func getLastComponent(s string) string {
	if dot := strings.LastIndexByte(s, '.'); dot >= 0 {
		return s[dot+1:]
	}
	return s
}
803
// isMsg returns whether s should be translated, which is the case as soon as
// it contains at least one letter.
func isMsg(s string) bool {
	// TODO: parse as format string and omit strings that contain letters
	// coming from format verbs.
	hasLetter := false
	for _, r := range s {
		if unicode.IsLetter(r) {
			hasLetter = true
			break
		}
	}
	return hasLetter
}
815