1// Copyright 2011 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// This file implements export filtering of an AST.
6
7package doc
8
9import (
10	"go/ast"
11	"go/token"
12)
13
// filterIdentList compacts list in place so that it holds only the exported
// identifiers, in their original relative order, and returns the shortened
// slice. The backing array of list is reused; no new slice is allocated.
func filterIdentList(list []*ast.Ident) []*ast.Ident {
	kept := list[:0]
	for _, id := range list {
		if !ast.IsExported(id.Name) {
			continue
		}
		kept = append(kept, id)
	}
	return kept
}
27
// underscore is the blank identifier that updateIdentList substitutes for
// unexported names. The same instance is used for every substitution; it
// has no source position and no associated object.
var underscore = &ast.Ident{Name: "_"}
29
30func filterCompositeLit(lit *ast.CompositeLit, filter Filter, export bool) {
31	n := len(lit.Elts)
32	lit.Elts = filterExprList(lit.Elts, filter, export)
33	if len(lit.Elts) < n {
34		lit.Incomplete = true
35	}
36}
37
38func filterExprList(list []ast.Expr, filter Filter, export bool) []ast.Expr {
39	j := 0
40	for _, exp := range list {
41		switch x := exp.(type) {
42		case *ast.CompositeLit:
43			filterCompositeLit(x, filter, export)
44		case *ast.KeyValueExpr:
45			if x, ok := x.Key.(*ast.Ident); ok && !filter(x.Name) {
46				continue
47			}
48			if x, ok := x.Value.(*ast.CompositeLit); ok {
49				filterCompositeLit(x, filter, export)
50			}
51		}
52		list[j] = exp
53		j++
54	}
55	return list[0:j]
56}
57
58// updateIdentList replaces all unexported identifiers with underscore
59// and reports whether at least one exported name exists.
60func updateIdentList(list []*ast.Ident) (hasExported bool) {
61	for i, x := range list {
62		if ast.IsExported(x.Name) {
63			hasExported = true
64		} else {
65			list[i] = underscore
66		}
67	}
68	return hasExported
69}
70
// hasExportedName reports whether at least one identifier in list is
// exported. A nil or empty list has no exported names.
func hasExportedName(list []*ast.Ident) bool {
	for i := range list {
		if ast.IsExported(list[i].Name) {
			return true
		}
	}
	return false
}
81
82// removeErrorField removes anonymous fields named "error" from an interface.
83// This is called when "error" has been determined to be a local name,
84// not the predeclared type.
85//
86func removeErrorField(ityp *ast.InterfaceType) {
87	list := ityp.Methods.List // we know that ityp.Methods != nil
88	j := 0
89	for _, field := range list {
90		keepField := true
91		if n := len(field.Names); n == 0 {
92			// anonymous field
93			if fname, _ := baseTypeName(field.Type); fname == "error" {
94				keepField = false
95			}
96		}
97		if keepField {
98			list[j] = field
99			j++
100		}
101	}
102	if j < len(list) {
103		ityp.Incomplete = true
104	}
105	ityp.Methods.List = list[0:j]
106}
107
108// filterFieldList removes unexported fields (field names) from the field list
109// in place and reports whether fields were removed. Anonymous fields are
110// recorded with the parent type. filterType is called with the types of
111// all remaining fields.
112//
113func (r *reader) filterFieldList(parent *namedType, fields *ast.FieldList, ityp *ast.InterfaceType) (removedFields bool) {
114	if fields == nil {
115		return
116	}
117	list := fields.List
118	j := 0
119	for _, field := range list {
120		keepField := false
121		if n := len(field.Names); n == 0 {
122			// anonymous field
123			fname := r.recordAnonymousField(parent, field.Type)
124			if ast.IsExported(fname) {
125				keepField = true
126			} else if ityp != nil && fname == "error" {
127				// possibly the predeclared error interface; keep
128				// it for now but remember this interface so that
129				// it can be fixed if error is also defined locally
130				keepField = true
131				r.remember(ityp)
132			}
133		} else {
134			field.Names = filterIdentList(field.Names)
135			if len(field.Names) < n {
136				removedFields = true
137			}
138			if len(field.Names) > 0 {
139				keepField = true
140			}
141		}
142		if keepField {
143			r.filterType(nil, field.Type)
144			list[j] = field
145			j++
146		}
147	}
148	if j < len(list) {
149		removedFields = true
150	}
151	fields.List = list[0:j]
152	return
153}
154
155// filterParamList applies filterType to each parameter type in fields.
156//
157func (r *reader) filterParamList(fields *ast.FieldList) {
158	if fields != nil {
159		for _, f := range fields.List {
160			r.filterType(nil, f.Type)
161		}
162	}
163}
164
165// filterType strips any unexported struct fields or method types from typ
166// in place. If fields (or methods) have been removed, the corresponding
167// struct or interface type has the Incomplete field set to true.
168//
169func (r *reader) filterType(parent *namedType, typ ast.Expr) {
170	switch t := typ.(type) {
171	case *ast.Ident:
172		// nothing to do
173	case *ast.ParenExpr:
174		r.filterType(nil, t.X)
175	case *ast.ArrayType:
176		r.filterType(nil, t.Elt)
177	case *ast.StructType:
178		if r.filterFieldList(parent, t.Fields, nil) {
179			t.Incomplete = true
180		}
181	case *ast.FuncType:
182		r.filterParamList(t.Params)
183		r.filterParamList(t.Results)
184	case *ast.InterfaceType:
185		if r.filterFieldList(parent, t.Methods, t) {
186			t.Incomplete = true
187		}
188	case *ast.MapType:
189		r.filterType(nil, t.Key)
190		r.filterType(nil, t.Value)
191	case *ast.ChanType:
192		r.filterType(nil, t.Value)
193	}
194}
195
196func (r *reader) filterSpec(spec ast.Spec) bool {
197	switch s := spec.(type) {
198	case *ast.ImportSpec:
199		// always keep imports so we can collect them
200		return true
201	case *ast.ValueSpec:
202		s.Values = filterExprList(s.Values, ast.IsExported, true)
203		if len(s.Values) > 0 || s.Type == nil && len(s.Values) == 0 {
204			// If there are values declared on RHS, just replace the unexported
205			// identifiers on the LHS with underscore, so that it matches
206			// the sequence of expression on the RHS.
207			//
208			// Similarly, if there are no type and values, then this expression
209			// must be following an iota expression, where order matters.
210			if updateIdentList(s.Names) {
211				r.filterType(nil, s.Type)
212				return true
213			}
214		} else {
215			s.Names = filterIdentList(s.Names)
216			if len(s.Names) > 0 {
217				r.filterType(nil, s.Type)
218				return true
219			}
220		}
221	case *ast.TypeSpec:
222		if name := s.Name.Name; ast.IsExported(name) {
223			r.filterType(r.lookupType(s.Name.Name), s.Type)
224			return true
225		} else if name == "error" {
226			// special case: remember that error is declared locally
227			r.errorDecl = true
228		}
229	}
230	return false
231}
232
// copyConstType returns a fresh copy of the constant type expression typ,
// positioned at pos. Only identifiers (T) and qualified identifiers (pkg.T)
// are expected. For a qualified identifier, the package name receives pos
// and the selector has no position. For any other expression it returns
// nil instead of panicking.
func copyConstType(typ ast.Expr, pos token.Pos) ast.Expr {
	if id, ok := typ.(*ast.Ident); ok {
		return &ast.Ident{Name: id.Name, NamePos: pos}
	}
	sel, ok := typ.(*ast.SelectorExpr)
	if !ok {
		return nil
	}
	pkg, ok := sel.X.(*ast.Ident)
	if !ok {
		return nil
	}
	// presumably a qualified identifier
	return &ast.SelectorExpr{
		X:   &ast.Ident{Name: pkg.Name, NamePos: pos},
		Sel: ast.NewIdent(sel.Sel.Name),
	}
}
252
253func (r *reader) filterSpecList(list []ast.Spec, tok token.Token) []ast.Spec {
254	if tok == token.CONST {
255		// Propagate any type information that would get lost otherwise
256		// when unexported constants are filtered.
257		var prevType ast.Expr
258		for _, spec := range list {
259			spec := spec.(*ast.ValueSpec)
260			if spec.Type == nil && len(spec.Values) == 0 && prevType != nil {
261				// provide current spec with an explicit type
262				spec.Type = copyConstType(prevType, spec.Pos())
263			}
264			if hasExportedName(spec.Names) {
265				// exported names are preserved so there's no need to propagate the type
266				prevType = nil
267			} else {
268				prevType = spec.Type
269			}
270		}
271	}
272
273	j := 0
274	for _, s := range list {
275		if r.filterSpec(s) {
276			list[j] = s
277			j++
278		}
279	}
280	return list[0:j]
281}
282
283func (r *reader) filterDecl(decl ast.Decl) bool {
284	switch d := decl.(type) {
285	case *ast.GenDecl:
286		d.Specs = r.filterSpecList(d.Specs, d.Tok)
287		return len(d.Specs) > 0
288	case *ast.FuncDecl:
289		// ok to filter these methods early because any
290		// conflicting method will be filtered here, too -
291		// thus, removing these methods early will not lead
292		// to the false removal of possible conflicts
293		return ast.IsExported(d.Name.Name)
294	}
295	return false
296}
297
298// fileExports removes unexported declarations from src in place.
299//
300func (r *reader) fileExports(src *ast.File) {
301	j := 0
302	for _, d := range src.Decls {
303		if r.filterDecl(d) {
304			src.Decls[j] = d
305			j++
306		}
307	}
308	src.Decls = src.Decls[0:j]
309}
310