1package main
2
3import (
4	"encoding/xml"
5	"flag"
6	"fmt"
7	"go/ast"
8	"go/parser"
9	"go/token"
10	"io"
11	"io/ioutil"
12	"os"
13	"path/filepath"
14	"regexp"
15	"strings"
16	"time"
17
18	"golang.org/x/tools/go/packages"
19)
20
// coberturaDTDDecl is the DOCTYPE line that convert writes immediately after
// the XML header. It points at the Cobertura coverage-04 DTD.
const coberturaDTDDecl string = `<!DOCTYPE coverage SYSTEM "http://cobertura.sourceforge.net/xml/coverage-04.dtd">`

// byFiles is bound to the -by-files flag in main. When set, fileVisitor.class
// names each class after its source file instead of the receiver type.
var byFiles bool
24
// fatal formats its arguments printf-style, writes the result to standard
// error and terminates the process with exit status 1. No newline is added;
// callers include one in format when they want it. Write errors are ignored
// because the process is exiting anyway.
func fatal(format string, a ...interface{}) {
	msg := fmt.Sprintf(format, a...)
	_, _ = os.Stderr.WriteString(msg)
	os.Exit(1)
}
29
30func main() {
31	var ignore Ignore
32
33	flag.BoolVar(&byFiles, "by-files", false, "code coverage by file, not class")
34	flag.BoolVar(&ignore.GeneratedFiles, "ignore-gen-files", false, "ignore generated files")
35	ignoreDirsRe := flag.String("ignore-dirs", "", "ignore dirs matching this regexp")
36	ignoreFilesRe := flag.String("ignore-files", "", "ignore files matching this regexp")
37
38	flag.Parse()
39
40	var err error
41	if *ignoreDirsRe != "" {
42		ignore.Dirs, err = regexp.Compile(*ignoreDirsRe)
43		if err != nil {
44			fatal("Bad -ignore-dirs regexp: %s\n", err)
45		}
46	}
47
48	if *ignoreFilesRe != "" {
49		ignore.Files, err = regexp.Compile(*ignoreFilesRe)
50		if err != nil {
51			fatal("Bad -ignore-files regexp: %s\n", err)
52		}
53	}
54
55	if err := convert(os.Stdin, os.Stdout, &ignore); err != nil {
56		fatal("code coverage conversion failed: %s", err)
57	}
58}
59
60func convert(in io.Reader, out io.Writer, ignore *Ignore) error {
61	profiles, err := ParseProfiles(in, ignore)
62	if err != nil {
63		return err
64	}
65
66	pkgs, err := getPackages(profiles)
67	if err != nil {
68		return err
69	}
70
71	sources := make([]*Source, 0)
72	pkgMap := make(map[string]*packages.Package)
73	for _, pkg := range pkgs {
74		sources = appendIfUnique(sources, pkg.Module.Dir)
75		pkgMap[pkg.ID] = pkg
76	}
77
78	coverage := Coverage{Sources: sources, Packages: nil, Timestamp: time.Now().UnixNano() / int64(time.Millisecond)}
79	if err := coverage.parseProfiles(profiles, pkgMap, ignore); err != nil {
80		return err
81	}
82
83	_, _ = fmt.Fprint(out, xml.Header)
84	_, _ = fmt.Fprintln(out, coberturaDTDDecl)
85
86	encoder := xml.NewEncoder(out)
87	encoder.Indent("", "  ")
88	if err := encoder.Encode(coverage); err != nil {
89		return err
90	}
91
92	_, _ = fmt.Fprintln(out)
93	return nil
94}
95
96func getPackages(profiles []*Profile) ([]*packages.Package, error) {
97	var pkgNames []string
98	for _, profile := range profiles {
99		pkgNames = append(pkgNames, getPackageName(profile.FileName))
100	}
101	return packages.Load(&packages.Config{Mode: packages.NeedFiles | packages.NeedModule}, pkgNames...)
102}
103
104func appendIfUnique(sources []*Source, dir string) []*Source {
105	for _, source := range sources {
106		if source.Path == dir {
107			return sources
108		}
109	}
110	return append(sources, &Source{dir})
111}
112
// getPackageName returns the directory part of a profile file name, which is
// used as the package pattern and lookup key. Trailing separators are
// stripped: first any run of backslashes, then any run of slashes. A bare
// file name yields "".
func getPackageName(filename string) string {
	// TODO(boumenot): Windows vs. Linux
	dir, _ := filepath.Split(filename)
	dir = strings.TrimRight(dir, "\\")
	dir = strings.TrimRight(dir, "/")
	return dir
}
118
119func findAbsFilePath(pkg *packages.Package, profileName string) string {
120	filename := filepath.Base(profileName)
121	for _, fullpath := range pkg.GoFiles {
122		if filepath.Base(fullpath) == filename {
123			return fullpath
124		}
125	}
126	return ""
127}
128
129func (cov *Coverage) parseProfiles(profiles []*Profile, pkgMap map[string]*packages.Package, ignore *Ignore) error {
130	cov.Packages = []*Package{}
131	for _, profile := range profiles {
132		pkgName := getPackageName(profile.FileName)
133		pkgPkg := pkgMap[pkgName]
134		if err := cov.parseProfile(profile, pkgPkg, ignore); err != nil {
135			return err
136		}
137	}
138	cov.LinesValid = cov.NumLines()
139	cov.LinesCovered = cov.NumLinesWithHits()
140	cov.LineRate = cov.HitRate()
141	return nil
142}
143
144func (cov *Coverage) parseProfile(profile *Profile, pkgPkg *packages.Package, ignore *Ignore) error {
145	if pkgPkg == nil || pkgPkg.Module == nil {
146		return fmt.Errorf("package required when using go modules")
147	}
148	fileName := profile.FileName[len(pkgPkg.Module.Path)+1:]
149	absFilePath := findAbsFilePath(pkgPkg, profile.FileName)
150	fset := token.NewFileSet()
151	parsed, err := parser.ParseFile(fset, absFilePath, nil, 0)
152	if err != nil {
153		return err
154	}
155	data, err := ioutil.ReadFile(absFilePath)
156	if err != nil {
157		return err
158	}
159
160	if ignore.Match(fileName, data) {
161		return nil
162	}
163
164	pkgPath, _ := filepath.Split(fileName)
165	pkgPath = strings.TrimRight(strings.TrimRight(pkgPath, "/"), "\\")
166	pkgPath = filepath.Join(pkgPkg.Module.Path, pkgPath)
167	// TODO(boumenot): package paths are not file paths, there is a consistent separator
168	pkgPath = strings.Replace(pkgPath, "\\", "/", -1)
169
170	var pkg *Package
171	for _, p := range cov.Packages {
172		if p.Name == pkgPath {
173			pkg = p
174		}
175	}
176	if pkg == nil {
177		pkg = &Package{Name: pkgPkg.ID, Classes: []*Class{}}
178		cov.Packages = append(cov.Packages, pkg)
179	}
180	visitor := &fileVisitor{
181		fset:     fset,
182		fileName: fileName,
183		fileData: data,
184		classes:  make(map[string]*Class),
185		pkg:      pkg,
186		profile:  profile,
187	}
188	ast.Walk(visitor, parsed)
189	pkg.LineRate = pkg.HitRate()
190	return nil
191}
192
// fileVisitor is the ast.Visitor used by parseProfile for one source file.
// For each function declaration it finds, it records the profile's coverage
// as a Method inside a Class, which it registers in pkg.
type fileVisitor struct {
	// fset holds position information for the parsed file.
	fset *token.FileSet
	// fileName is the module-relative file name stored on each Class.
	fileName string
	// fileData is the file's raw contents; recvName slices it by byte
	// offset, so it must be the same bytes that were parsed.
	fileData []byte
	// pkg is the report package that receives newly created classes.
	pkg *Package
	// classes holds the classes created for this file, keyed by class name.
	classes map[string]*Class
	// profile holds this file's coverage blocks.
	profile *Profile
}
201
202func (v *fileVisitor) Visit(node ast.Node) ast.Visitor {
203	switch n := node.(type) {
204	case *ast.FuncDecl:
205		class := v.class(n)
206		method := v.method(n)
207		method.LineRate = method.Lines.HitRate()
208		class.Methods = append(class.Methods, method)
209		for _, line := range method.Lines {
210			class.Lines = append(class.Lines, line)
211		}
212		class.LineRate = class.Lines.HitRate()
213	}
214	return v
215}
216
217func (v *fileVisitor) method(n *ast.FuncDecl) *Method {
218	method := &Method{Name: n.Name.Name}
219	method.Lines = []*Line{}
220
221	start := v.fset.Position(n.Pos())
222	end := v.fset.Position(n.End())
223	startLine := start.Line
224	startCol := start.Column
225	endLine := end.Line
226	endCol := end.Column
227	// The blocks are sorted, so we can stop counting as soon as we reach the end of the relevant block.
228	for _, b := range v.profile.Blocks {
229		if b.StartLine > endLine || (b.StartLine == endLine && b.StartCol >= endCol) {
230			// Past the end of the function.
231			break
232		}
233		if b.EndLine < startLine || (b.EndLine == startLine && b.EndCol <= startCol) {
234			// Before the beginning of the function
235			continue
236		}
237		for i := b.StartLine; i <= b.EndLine; i++ {
238			method.Lines.AddOrUpdateLine(i, int64(b.Count))
239		}
240	}
241	return method
242}
243
244func (v *fileVisitor) class(n *ast.FuncDecl) *Class {
245	var className string
246	if byFiles {
247		//className = filepath.Base(v.fileName)
248		//
249		// NOTE(boumenot): ReportGenerator creates links that collide if names are not distinct.
250		// This could be an issue in how I am generating the report, but I have not been able
251		// to figure it out.  The work around is to generate a fully qualified name based on
252		// the file path.
253		//
254		// src/lib/util/foo.go -> src.lib.util.foo.go
255		className = strings.Replace(v.fileName, "/", ".", -1)
256		className = strings.Replace(className, "\\", ".", -1)
257	} else {
258		className = v.recvName(n)
259	}
260	class := v.classes[className]
261	if class == nil {
262		class = &Class{Name: className, Filename: v.fileName, Methods: []*Method{}, Lines: []*Line{}}
263		v.classes[className] = class
264		v.pkg.Classes = append(v.pkg.Classes, class)
265	}
266	return class
267}
268
269func (v *fileVisitor) recvName(n *ast.FuncDecl) string {
270	if n.Recv == nil {
271		return "-"
272	}
273	recv := n.Recv.List[0].Type
274	start := v.fset.Position(recv.Pos())
275	end := v.fset.Position(recv.End())
276	name := string(v.fileData[start.Offset:end.Offset])
277	return strings.TrimSpace(strings.TrimLeft(name, "*"))
278}
279