1// Copyright 2016 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package pipeline
6
7import (
8	"bytes"
9	"errors"
10	"fmt"
11	"go/ast"
12	"go/constant"
13	"go/format"
14	"go/token"
15	"go/types"
16	"path/filepath"
17	"sort"
18	"strings"
19	"unicode"
20	"unicode/utf8"
21
22	fmtparser "golang.org/x/text/internal/format"
23	"golang.org/x/tools/go/callgraph"
24	"golang.org/x/tools/go/callgraph/cha"
25	"golang.org/x/tools/go/loader"
26	"golang.org/x/tools/go/ssa"
27	"golang.org/x/tools/go/ssa/ssautil"
28)
29
// debug is a compile-time switch. When true, (*extracter).debug prints a
// trace line for every visited value and visitFormats prints a warning when
// a format string is passed to more than one parameter.
const debug = false
31
32// TODO:
33// - merge information into existing files
34// - handle different file formats (PO, XLIFF)
35// - handle features (gender, plural)
36// - message rewriting
37
38// - `msg:"etc"` tags
39
40// Extract extracts all strings form the package defined in Config.
41func Extract(c *Config) (*State, error) {
42	x, err := newExtracter(c)
43	if err != nil {
44		return nil, wrap(err, "")
45	}
46
47	if err := x.seedEndpoints(); err != nil {
48		return nil, err
49	}
50	x.extractMessages()
51
52	return &State{
53		Config:  *c,
54		program: x.iprog,
55		Extracted: Messages{
56			Language: c.SourceLanguage,
57			Messages: x.messages,
58		},
59	}, nil
60}
61
62type extracter struct {
63	conf      loader.Config
64	iprog     *loader.Program
65	prog      *ssa.Program
66	callGraph *callgraph.Graph
67
68	// Calls and other expressions to collect.
69	globals  map[token.Pos]*constData
70	funcs    map[token.Pos]*callData
71	messages []Message
72}
73
74func newExtracter(c *Config) (x *extracter, err error) {
75	x = &extracter{
76		conf:    loader.Config{},
77		globals: map[token.Pos]*constData{},
78		funcs:   map[token.Pos]*callData{},
79	}
80
81	x.iprog, err = loadPackages(&x.conf, c.Packages)
82	if err != nil {
83		return nil, wrap(err, "")
84	}
85
86	x.prog = ssautil.CreateProgram(x.iprog, ssa.GlobalDebug|ssa.BareInits)
87	x.prog.Build()
88
89	x.callGraph = cha.CallGraph(x.prog)
90
91	return x, nil
92}
93
94func (x *extracter) globalData(pos token.Pos) *constData {
95	cd := x.globals[pos]
96	if cd == nil {
97		cd = &constData{}
98		x.globals[pos] = cd
99	}
100	return cd
101}
102
103func (x *extracter) seedEndpoints() error {
104	pkgInfo := x.iprog.Package("golang.org/x/text/message")
105	if pkgInfo == nil {
106		return errors.New("pipeline: golang.org/x/text/message is not imported")
107	}
108	pkg := x.prog.Package(pkgInfo.Pkg)
109	typ := types.NewPointer(pkg.Type("Printer").Type())
110
111	x.processGlobalVars()
112
113	x.handleFunc(x.prog.LookupMethod(typ, pkg.Pkg, "Printf"), &callData{
114		formatPos: 1,
115		argPos:    2,
116		isMethod:  true,
117	})
118	x.handleFunc(x.prog.LookupMethod(typ, pkg.Pkg, "Sprintf"), &callData{
119		formatPos: 1,
120		argPos:    2,
121		isMethod:  true,
122	})
123	x.handleFunc(x.prog.LookupMethod(typ, pkg.Pkg, "Fprintf"), &callData{
124		formatPos: 2,
125		argPos:    3,
126		isMethod:  true,
127	})
128	return nil
129}
130
131// processGlobalVars finds string constants that are assigned to global
132// variables.
133func (x *extracter) processGlobalVars() {
134	for _, p := range x.prog.AllPackages() {
135		m, ok := p.Members["init"]
136		if !ok {
137			continue
138		}
139		for _, b := range m.(*ssa.Function).Blocks {
140			for _, i := range b.Instrs {
141				s, ok := i.(*ssa.Store)
142				if !ok {
143					continue
144				}
145				a, ok := s.Addr.(*ssa.Global)
146				if !ok {
147					continue
148				}
149				t := a.Type()
150				for {
151					p, ok := t.(*types.Pointer)
152					if !ok {
153						break
154					}
155					t = p.Elem()
156				}
157				if b, ok := t.(*types.Basic); !ok || b.Kind() != types.String {
158					continue
159				}
160				x.visitInit(a, s.Val)
161			}
162		}
163	}
164}
165
166type constData struct {
167	call   *callData // to provide a signature for the constants
168	values []constVal
169	others []token.Pos // Assigned to other global data.
170}
171
172func (d *constData) visit(x *extracter, f func(c constant.Value)) {
173	for _, v := range d.values {
174		f(v.value)
175	}
176	for _, p := range d.others {
177		if od, ok := x.globals[p]; ok {
178			od.visit(x, f)
179		}
180	}
181}
182
// constVal pairs a constant with the source position it was found at.
type constVal struct {
	value constant.Value // the constant itself
	pos   token.Pos      // position of the SSA constant it came from
}
187
188type callData struct {
189	call    ssa.CallInstruction
190	expr    *ast.CallExpr
191	formats []constant.Value
192
193	callee    *callData
194	isMethod  bool
195	formatPos int
196	argPos    int   // varargs at this position in the call
197	argTypes  []int // arguments extractable from this position
198}
199
200func (c *callData) callFormatPos() int {
201	c = c.callee
202	if c.isMethod {
203		return c.formatPos - 1
204	}
205	return c.formatPos
206}
207
208func (c *callData) callArgsStart() int {
209	c = c.callee
210	if c.isMethod {
211		return c.argPos - 1
212	}
213	return c.argPos
214}
215
216func (c *callData) Pos() token.Pos      { return c.call.Pos() }
217func (c *callData) Pkg() *types.Package { return c.call.Parent().Pkg.Pkg }
218
219func (x *extracter) handleFunc(f *ssa.Function, fd *callData) {
220	for _, e := range x.callGraph.Nodes[f].In {
221		if e.Pos() == 0 {
222			continue
223		}
224
225		call := e.Site
226		caller := x.funcs[call.Pos()]
227		if caller != nil {
228			// TODO: theoretically a format string could be passed to multiple
229			// arguments of a function. Support this eventually.
230			continue
231		}
232		x.debug(call, "CALL", f.String())
233
234		caller = &callData{
235			call:      call,
236			callee:    fd,
237			formatPos: -1,
238			argPos:    -1,
239		}
240		// Offset by one if we are invoking an interface method.
241		offset := 0
242		if call.Common().IsInvoke() {
243			offset = -1
244		}
245		x.funcs[call.Pos()] = caller
246		if fd.argPos >= 0 {
247			x.visitArgs(caller, call.Common().Args[fd.argPos+offset])
248		}
249		x.visitFormats(caller, call.Common().Args[fd.formatPos+offset])
250	}
251}
252
253type posser interface {
254	Pos() token.Pos
255	Parent() *ssa.Function
256}
257
258func (x *extracter) debug(v posser, header string, args ...interface{}) {
259	if debug {
260		pos := ""
261		if p := v.Parent(); p != nil {
262			pos = posString(&x.conf, p.Package().Pkg, v.Pos())
263		}
264		if header != "CALL" && header != "INSERT" {
265			header = "  " + header
266		}
267		fmt.Printf("%-32s%-10s%-15T ", pos+fmt.Sprintf("@%d", v.Pos()), header, v)
268		for _, a := range args {
269			fmt.Printf(" %v", a)
270		}
271		fmt.Println()
272	}
273}
274
275// visitInit evaluates and collects values assigned to global variables in an
276// init function.
277func (x *extracter) visitInit(global *ssa.Global, v ssa.Value) {
278	if v == nil {
279		return
280	}
281	x.debug(v, "GLOBAL", v)
282
283	switch v := v.(type) {
284	case *ssa.Phi:
285		for _, e := range v.Edges {
286			x.visitInit(global, e)
287		}
288
289	case *ssa.Const:
290		// Only record strings with letters.
291		if str := constant.StringVal(v.Value); isMsg(str) {
292			cd := x.globalData(global.Pos())
293			cd.values = append(cd.values, constVal{v.Value, v.Pos()})
294		}
295		// TODO: handle %m-directive.
296
297	case *ssa.Global:
298		cd := x.globalData(global.Pos())
299		cd.others = append(cd.others, v.Pos())
300
301	case *ssa.FieldAddr, *ssa.Field:
302		// TODO: mark field index v.Field of v.X.Type() for extraction. extract
303		// an example args as to give parameters for the translator.
304
305	case *ssa.Slice:
306		if v.Low == nil && v.High == nil && v.Max == nil {
307			x.visitInit(global, v.X)
308		}
309
310	case *ssa.Alloc:
311		if ref := v.Referrers(); ref == nil {
312			for _, r := range *ref {
313				values := []ssa.Value{}
314				for _, o := range r.Operands(nil) {
315					if o == nil || *o == v {
316						continue
317					}
318					values = append(values, *o)
319				}
320				// TODO: return something different if we care about multiple
321				// values as well.
322				if len(values) == 1 {
323					x.visitInit(global, values[0])
324				}
325			}
326		}
327
328	case ssa.Instruction:
329		rands := v.Operands(nil)
330		if len(rands) == 1 && rands[0] != nil {
331			x.visitInit(global, *rands[0])
332		}
333	}
334	return
335}
336
337// visitFormats finds the original source of the value. The returned index is
338// position of the argument if originated from a function argument or -1
339// otherwise.
340func (x *extracter) visitFormats(call *callData, v ssa.Value) {
341	if v == nil {
342		return
343	}
344	x.debug(v, "VALUE", v)
345
346	switch v := v.(type) {
347	case *ssa.Phi:
348		for _, e := range v.Edges {
349			x.visitFormats(call, e)
350		}
351
352	case *ssa.Const:
353		// Only record strings with letters.
354		if isMsg(constant.StringVal(v.Value)) {
355			x.debug(call.call, "FORMAT", v.Value.ExactString())
356			call.formats = append(call.formats, v.Value)
357		}
358		// TODO: handle %m-directive.
359
360	case *ssa.Global:
361		x.globalData(v.Pos()).call = call
362
363	case *ssa.FieldAddr, *ssa.Field:
364		// TODO: mark field index v.Field of v.X.Type() for extraction. extract
365		// an example args as to give parameters for the translator.
366
367	case *ssa.Slice:
368		if v.Low == nil && v.High == nil && v.Max == nil {
369			x.visitFormats(call, v.X)
370		}
371
372	case *ssa.Parameter:
373		// TODO: handle the function for the index parameter.
374		f := v.Parent()
375		for i, p := range f.Params {
376			if p == v {
377				if call.formatPos < 0 {
378					call.formatPos = i
379					// TODO: is there a better way to detect this is calling
380					// a method rather than a function?
381					call.isMethod = len(f.Params) > f.Signature.Params().Len()
382					x.handleFunc(v.Parent(), call)
383				} else if debug && i != call.formatPos {
384					// TODO: support this.
385					fmt.Printf("WARNING:%s: format string passed to arg %d and %d\n",
386						posString(&x.conf, call.Pkg(), call.Pos()),
387						call.formatPos, i)
388				}
389			}
390		}
391
392	case *ssa.Alloc:
393		if ref := v.Referrers(); ref == nil {
394			for _, r := range *ref {
395				values := []ssa.Value{}
396				for _, o := range r.Operands(nil) {
397					if o == nil || *o == v {
398						continue
399					}
400					values = append(values, *o)
401				}
402				// TODO: return something different if we care about multiple
403				// values as well.
404				if len(values) == 1 {
405					x.visitFormats(call, values[0])
406				}
407			}
408		}
409
410		// TODO:
411	// case *ssa.Index:
412	// 	// Get all values in the array if applicable
413	// case *ssa.IndexAddr:
414	// 	// Get all values in the slice or *array if applicable.
415	// case *ssa.Lookup:
416	// 	// Get all values in the map if applicable.
417
418	case *ssa.FreeVar:
419		// TODO: find the link between free variables and parameters:
420		//
421		// func freeVar(p *message.Printer, str string) {
422		// 	fn := func(p *message.Printer) {
423		// 		p.Printf(str)
424		// 	}
425		// 	fn(p)
426		// }
427
428	case *ssa.Call:
429
430	case ssa.Instruction:
431		rands := v.Operands(nil)
432		if len(rands) == 1 && rands[0] != nil {
433			x.visitFormats(call, *rands[0])
434		}
435	}
436}
437
438// Note: a function may have an argument marked as both format and passthrough.
439
440// visitArgs collects information on arguments. For wrapped functions it will
441// just determine the position of the variable args slice.
442func (x *extracter) visitArgs(fd *callData, v ssa.Value) {
443	if v == nil {
444		return
445	}
446	x.debug(v, "ARGV", v)
447	switch v := v.(type) {
448
449	case *ssa.Slice:
450		if v.Low == nil && v.High == nil && v.Max == nil {
451			x.visitArgs(fd, v.X)
452		}
453
454	case *ssa.Parameter:
455		// TODO: handle the function for the index parameter.
456		f := v.Parent()
457		for i, p := range f.Params {
458			if p == v {
459				fd.argPos = i
460			}
461		}
462
463	case *ssa.Alloc:
464		if ref := v.Referrers(); ref == nil {
465			for _, r := range *ref {
466				values := []ssa.Value{}
467				for _, o := range r.Operands(nil) {
468					if o == nil || *o == v {
469						continue
470					}
471					values = append(values, *o)
472				}
473				// TODO: return something different if we care about
474				// multiple values as well.
475				if len(values) == 1 {
476					x.visitArgs(fd, values[0])
477				}
478			}
479		}
480
481	case ssa.Instruction:
482		rands := v.Operands(nil)
483		if len(rands) == 1 && rands[0] != nil {
484			x.visitArgs(fd, *rands[0])
485		}
486	}
487}
488
489// print returns Go syntax for the specified node.
490func (x *extracter) print(n ast.Node) string {
491	var buf bytes.Buffer
492	format.Node(&buf, x.conf.Fset, n)
493	return buf.String()
494}
495
496type packageExtracter struct {
497	f    *ast.File
498	x    *extracter
499	info *loader.PackageInfo
500	cmap ast.CommentMap
501}
502
503func (px packageExtracter) getComment(n ast.Node) string {
504	cs := px.cmap.Filter(n).Comments()
505	if len(cs) > 0 {
506		return strings.TrimSpace(cs[0].Text())
507	}
508	return ""
509}
510
511func (x *extracter) extractMessages() {
512	prog := x.iprog
513	keys := make([]*types.Package, 0, len(x.iprog.AllPackages))
514	for k := range x.iprog.AllPackages {
515		keys = append(keys, k)
516	}
517	sort.Slice(keys, func(i, j int) bool { return keys[i].Path() < keys[j].Path() })
518	files := []packageExtracter{}
519	for _, k := range keys {
520		info := x.iprog.AllPackages[k]
521		for _, f := range info.Files {
522			// Associate comments with nodes.
523			px := packageExtracter{
524				f, x, info,
525				ast.NewCommentMap(prog.Fset, f, f.Comments),
526			}
527			files = append(files, px)
528		}
529	}
530	for _, px := range files {
531		ast.Inspect(px.f, func(n ast.Node) bool {
532			switch v := n.(type) {
533			case *ast.CallExpr:
534				if d := x.funcs[v.Lparen]; d != nil {
535					d.expr = v
536				}
537			}
538			return true
539		})
540	}
541	for _, px := range files {
542		ast.Inspect(px.f, func(n ast.Node) bool {
543			switch v := n.(type) {
544			case *ast.CallExpr:
545				return px.handleCall(v)
546			case *ast.ValueSpec:
547				return px.handleGlobal(v)
548			}
549			return true
550		})
551	}
552}
553
554func (px packageExtracter) handleGlobal(spec *ast.ValueSpec) bool {
555	comment := px.getComment(spec)
556
557	for _, ident := range spec.Names {
558		data, ok := px.x.globals[ident.Pos()]
559		if !ok {
560			continue
561		}
562		name := ident.Name
563		var arguments []argument
564		if data.call != nil {
565			arguments = px.getArguments(data.call)
566		} else if !strings.HasPrefix(name, "msg") && !strings.HasPrefix(name, "Msg") {
567			continue
568		}
569		data.visit(px.x, func(c constant.Value) {
570			px.addMessage(spec.Pos(), []string{name}, c, comment, arguments)
571		})
572	}
573
574	return true
575}
576
577func (px packageExtracter) handleCall(call *ast.CallExpr) bool {
578	x := px.x
579	data := x.funcs[call.Lparen]
580	if data == nil || len(data.formats) == 0 {
581		return true
582	}
583	if data.expr != call {
584		panic("invariant `data.call != call` failed")
585	}
586	x.debug(data.call, "INSERT", data.formats)
587
588	argn := data.callFormatPos()
589	if argn >= len(call.Args) {
590		return true
591	}
592	format := call.Args[argn]
593
594	arguments := px.getArguments(data)
595
596	comment := ""
597	key := []string{}
598	if ident, ok := format.(*ast.Ident); ok {
599		key = append(key, ident.Name)
600		if v, ok := ident.Obj.Decl.(*ast.ValueSpec); ok && v.Comment != nil {
601			// TODO: get comment above ValueSpec as well
602			comment = v.Comment.Text()
603		}
604	}
605	if c := px.getComment(call.Args[0]); c != "" {
606		comment = c
607	}
608
609	formats := data.formats
610	for _, c := range formats {
611		px.addMessage(call.Lparen, key, c, comment, arguments)
612	}
613	return true
614}
615
616func (px packageExtracter) getArguments(data *callData) []argument {
617	arguments := []argument{}
618	x := px.x
619	info := px.info
620	if data.callArgsStart() >= 0 {
621		args := data.expr.Args[data.callArgsStart():]
622		for i, arg := range args {
623			expr := x.print(arg)
624			val := ""
625			if v := info.Types[arg].Value; v != nil {
626				val = v.ExactString()
627				switch arg.(type) {
628				case *ast.BinaryExpr, *ast.UnaryExpr:
629					expr = val
630				}
631			}
632			arguments = append(arguments, argument{
633				ArgNum:         i + 1,
634				Type:           info.Types[arg].Type.String(),
635				UnderlyingType: info.Types[arg].Type.Underlying().String(),
636				Expr:           expr,
637				Value:          val,
638				Comment:        px.getComment(arg),
639				Position:       posString(&x.conf, info.Pkg, arg.Pos()),
640				// TODO report whether it implements
641				// interfaces plural.Interface,
642				// gender.Interface.
643			})
644		}
645	}
646	return arguments
647}
648
649func (px packageExtracter) addMessage(
650	pos token.Pos,
651	key []string,
652	c constant.Value,
653	comment string,
654	arguments []argument) {
655	x := px.x
656	fmtMsg := constant.StringVal(c)
657
658	ph := placeholders{index: map[string]string{}}
659
660	trimmed, _, _ := trimWS(fmtMsg)
661
662	p := fmtparser.Parser{}
663	simArgs := make([]interface{}, len(arguments))
664	for i, v := range arguments {
665		simArgs[i] = v
666	}
667	msg := ""
668	p.Reset(simArgs)
669	for p.SetFormat(trimmed); p.Scan(); {
670		name := ""
671		var arg *argument
672		switch p.Status {
673		case fmtparser.StatusText:
674			msg += p.Text()
675			continue
676		case fmtparser.StatusSubstitution,
677			fmtparser.StatusBadWidthSubstitution,
678			fmtparser.StatusBadPrecSubstitution:
679			arguments[p.ArgNum-1].used = true
680			arg = &arguments[p.ArgNum-1]
681			name = getID(arg)
682		case fmtparser.StatusBadArgNum, fmtparser.StatusMissingArg:
683			arg = &argument{
684				ArgNum:   p.ArgNum,
685				Position: posString(&x.conf, px.info.Pkg, pos),
686			}
687			name, arg.UnderlyingType = verbToPlaceholder(p.Text(), p.ArgNum)
688		}
689		sub := p.Text()
690		if !p.HasIndex {
691			r, sz := utf8.DecodeLastRuneInString(sub)
692			sub = fmt.Sprintf("%s[%d]%c", sub[:len(sub)-sz], p.ArgNum, r)
693		}
694		msg += fmt.Sprintf("{%s}", ph.addArg(arg, name, sub))
695	}
696	key = append(key, msg)
697
698	// Add additional Placeholders that can be used in translations
699	// that are not present in the string.
700	for _, arg := range arguments {
701		if arg.used {
702			continue
703		}
704		ph.addArg(&arg, getID(&arg), fmt.Sprintf("%%[%d]v", arg.ArgNum))
705	}
706
707	x.messages = append(x.messages, Message{
708		ID:      key,
709		Key:     fmtMsg,
710		Message: Text{Msg: msg},
711		// TODO(fix): this doesn't get the before comment.
712		Comment:      comment,
713		Placeholders: ph.slice,
714		Position:     posString(&x.conf, px.info.Pkg, pos),
715	})
716}
717
718func posString(conf *loader.Config, pkg *types.Package, pos token.Pos) string {
719	p := conf.Fset.Position(pos)
720	file := fmt.Sprintf("%s:%d:%d", filepath.Base(p.Filename), p.Line, p.Column)
721	return filepath.Join(pkg.Path(), file)
722}
723
724func getID(arg *argument) string {
725	s := getLastComponent(arg.Expr)
726	s = strip(s)
727	s = strings.Replace(s, " ", "", -1)
728	// For small variable names, use user-defined types for more info.
729	if len(s) <= 2 && arg.UnderlyingType != arg.Type {
730		s = getLastComponent(arg.Type)
731	}
732	return strings.Title(s)
733}
734
// strip is a dirty hack to convert function calls to placeholder IDs.
//
// Whitespace and '-' become '_', and every other rune that is not a letter,
// mark or number is dropped. A leading "Get" or "get" is then removed when
// it is followed by a rune that is neither a lowercase letter nor a mark, so
// "GetName" and "getName" become "Name" while "Getter" is left alone.
func strip(s string) string {
	s = strings.Map(func(r rune) rune {
		if unicode.IsSpace(r) || r == '-' {
			return '_'
		}
		if !unicode.In(r, unicode.Letter, unicode.Mark, unicode.Number) {
			return -1
		}
		return r
	}, s)
	// Strip "Get" from getter functions.
	if strings.HasPrefix(s, "Get") || strings.HasPrefix(s, "get") {
		if rest := s[len("get"):]; rest != "" {
			// Look at the rune after the prefix: a lowercase letter or
			// mark means "get" is part of a longer word.
			r, _ := utf8.DecodeRuneInString(rest)
			if !unicode.In(r, unicode.Ll, unicode.M) { // not lower or mark
				s = rest
			}
		}
	}
	return s
}
757
// verbToPlaceholder picks a placeholder name and an underlying type from the
// verb that ends the substitution sub. It is a fallback for substitutions
// that have no argument, and hence no type information, to go by. Verbs
// without a dedicated name are called "Arg_<pos>".
func verbToPlaceholder(sub string, pos int) (name, underlying string) {
	verb, _ := utf8.DecodeLastRuneInString(sub)
	switch verb {
	case 'd':
		return "Integer", "int"
	case 'e', 'f', 'g':
		return "Number", "float64"
	case 'm':
		return "Message", "string"
	}

	name = fmt.Sprintf("Arg_%d", pos)
	if verb == 's' || verb == 'q' {
		return name, "string"
	}
	return name, "interface{}"
}
781
782type placeholders struct {
783	index map[string]string
784	slice []Placeholder
785}
786
787func (p *placeholders) addArg(arg *argument, name, sub string) (id string) {
788	id = name
789	alt, ok := p.index[id]
790	for i := 1; ok && alt != sub; i++ {
791		id = fmt.Sprintf("%s_%d", name, i)
792		alt, ok = p.index[id]
793	}
794	p.index[id] = sub
795	p.slice = append(p.slice, Placeholder{
796		ID:             id,
797		String:         sub,
798		Type:           arg.Type,
799		UnderlyingType: arg.UnderlyingType,
800		ArgNum:         arg.ArgNum,
801		Expr:           arg.Expr,
802		Comment:        arg.Comment,
803	})
804	return id
805}
806
// getLastComponent returns the part of s after its last '.', or all of s if
// it contains no '.'.
func getLastComponent(s string) string {
	if dot := strings.LastIndexByte(s, '.'); dot >= 0 {
		return s[dot+1:]
	}
	return s
}
810
// isMsg reports whether s should be translated, which is the case when it
// contains at least one letter.
func isMsg(s string) bool {
	// TODO: parse as format string and omit strings that contain letters
	// coming from format verbs.
	for _, ch := range s {
		if !unicode.Is(unicode.L, ch) {
			continue
		}
		return true
	}
	return false
}
822