1// Copyright 2011 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// +build go1.16
6
7// This file implements export filtering of an AST.
8
9package doc
10
11import (
12	"go/ast"
13	"go/token"
14)
15
// filterIdentList keeps only the exported identifiers of list, compacting
// them to the front of list's backing array (preserving their order), and
// returns that prefix.
func filterIdentList(list []*ast.Ident) []*ast.Ident {
	exported := list[:0]
	for _, id := range list {
		if !token.IsExported(id.Name) {
			continue
		}
		exported = append(exported, id)
	}
	return exported
}
29
// underscore is the blank identifier that updateIdentList substitutes for
// unexported names. The same node (which carries no position) is reused
// for every substitution.
var underscore = &ast.Ident{Name: "_"}
31
32func filterCompositeLit(lit *ast.CompositeLit, filter Filter, export bool) {
33	n := len(lit.Elts)
34	lit.Elts = filterExprList(lit.Elts, filter, export)
35	if len(lit.Elts) < n {
36		lit.Incomplete = true
37	}
38}
39
40func filterExprList(list []ast.Expr, filter Filter, export bool) []ast.Expr {
41	j := 0
42	for _, exp := range list {
43		switch x := exp.(type) {
44		case *ast.CompositeLit:
45			filterCompositeLit(x, filter, export)
46		case *ast.KeyValueExpr:
47			if x, ok := x.Key.(*ast.Ident); ok && !filter(x.Name) {
48				continue
49			}
50			if x, ok := x.Value.(*ast.CompositeLit); ok {
51				filterCompositeLit(x, filter, export)
52			}
53		}
54		list[j] = exp
55		j++
56	}
57	return list[0:j]
58}
59
60// updateIdentList replaces all unexported identifiers with underscore
61// and reports whether at least one exported name exists.
62func updateIdentList(list []*ast.Ident) (hasExported bool) {
63	for i, x := range list {
64		if token.IsExported(x.Name) {
65			hasExported = true
66		} else {
67			list[i] = underscore
68		}
69	}
70	return hasExported
71}
72
// hasExportedName reports whether at least one identifier in list is
// exported. It stops scanning at the first exported name.
func hasExportedName(list []*ast.Ident) bool {
	found := false
	for i := 0; i < len(list) && !found; i++ {
		found = ast.IsExported(list[i].Name)
	}
	return found
}
83
84// removeErrorField removes anonymous fields named "error" from an interface.
85// This is called when "error" has been determined to be a local name,
86// not the predeclared type.
87//
88func removeErrorField(ityp *ast.InterfaceType) {
89	list := ityp.Methods.List // we know that ityp.Methods != nil
90	j := 0
91	for _, field := range list {
92		keepField := true
93		if n := len(field.Names); n == 0 {
94			// anonymous field
95			if fname, _ := baseTypeName(field.Type); fname == "error" {
96				keepField = false
97			}
98		}
99		if keepField {
100			list[j] = field
101			j++
102		}
103	}
104	if j < len(list) {
105		ityp.Incomplete = true
106	}
107	ityp.Methods.List = list[0:j]
108}
109
110// filterFieldList removes unexported fields (field names) from the field list
111// in place and reports whether fields were removed. Anonymous fields are
112// recorded with the parent type. filterType is called with the types of
113// all remaining fields.
114//
115func (r *reader) filterFieldList(parent *namedType, fields *ast.FieldList, ityp *ast.InterfaceType) (removedFields bool) {
116	if fields == nil {
117		return
118	}
119	list := fields.List
120	j := 0
121	for _, field := range list {
122		keepField := false
123		if n := len(field.Names); n == 0 {
124			// anonymous field
125			fname := r.recordAnonymousField(parent, field.Type)
126			if token.IsExported(fname) {
127				keepField = true
128			} else if ityp != nil && fname == "error" {
129				// possibly the predeclared error interface; keep
130				// it for now but remember this interface so that
131				// it can be fixed if error is also defined locally
132				keepField = true
133				r.remember(ityp)
134			}
135		} else {
136			field.Names = filterIdentList(field.Names)
137			if len(field.Names) < n {
138				removedFields = true
139			}
140			if len(field.Names) > 0 {
141				keepField = true
142			}
143		}
144		if keepField {
145			r.filterType(nil, field.Type)
146			list[j] = field
147			j++
148		}
149	}
150	if j < len(list) {
151		removedFields = true
152	}
153	fields.List = list[0:j]
154	return
155}
156
157// filterParamList applies filterType to each parameter type in fields.
158//
159func (r *reader) filterParamList(fields *ast.FieldList) {
160	if fields != nil {
161		for _, f := range fields.List {
162			r.filterType(nil, f.Type)
163		}
164	}
165}
166
167// filterType strips any unexported struct fields or method types from typ
168// in place. If fields (or methods) have been removed, the corresponding
169// struct or interface type has the Incomplete field set to true.
170//
171func (r *reader) filterType(parent *namedType, typ ast.Expr) {
172	switch t := typ.(type) {
173	case *ast.Ident:
174		// nothing to do
175	case *ast.ParenExpr:
176		r.filterType(nil, t.X)
177	case *ast.ArrayType:
178		r.filterType(nil, t.Elt)
179	case *ast.StructType:
180		if r.filterFieldList(parent, t.Fields, nil) {
181			t.Incomplete = true
182		}
183	case *ast.FuncType:
184		r.filterParamList(t.Params)
185		r.filterParamList(t.Results)
186	case *ast.InterfaceType:
187		if r.filterFieldList(parent, t.Methods, t) {
188			t.Incomplete = true
189		}
190	case *ast.MapType:
191		r.filterType(nil, t.Key)
192		r.filterType(nil, t.Value)
193	case *ast.ChanType:
194		r.filterType(nil, t.Value)
195	}
196}
197
198func (r *reader) filterSpec(spec ast.Spec) bool {
199	switch s := spec.(type) {
200	case *ast.ImportSpec:
201		// always keep imports so we can collect them
202		return true
203	case *ast.ValueSpec:
204		s.Values = filterExprList(s.Values, token.IsExported, true)
205		if len(s.Values) > 0 || s.Type == nil && len(s.Values) == 0 {
206			// If there are values declared on RHS, just replace the unexported
207			// identifiers on the LHS with underscore, so that it matches
208			// the sequence of expression on the RHS.
209			//
210			// Similarly, if there are no type and values, then this expression
211			// must be following an iota expression, where order matters.
212			if updateIdentList(s.Names) {
213				r.filterType(nil, s.Type)
214				return true
215			}
216		} else {
217			s.Names = filterIdentList(s.Names)
218			if len(s.Names) > 0 {
219				r.filterType(nil, s.Type)
220				return true
221			}
222		}
223	case *ast.TypeSpec:
224		if name := s.Name.Name; token.IsExported(name) {
225			r.filterType(r.lookupType(s.Name.Name), s.Type)
226			return true
227		} else if name == "error" {
228			// special case: remember that error is declared locally
229			r.errorDecl = true
230		}
231	}
232	return false
233}
234
// copyConstType returns a fresh copy of the constant type typ, with its
// leading identifier positioned at pos. Only a plain identifier (T) or a
// qualified identifier (pkg.T) is supported; for any other expression,
// including nil, it returns nil rather than panicking.
func copyConstType(typ ast.Expr, pos token.Pos) ast.Expr {
	if id, ok := typ.(*ast.Ident); ok {
		return &ast.Ident{Name: id.Name, NamePos: pos}
	}
	sel, ok := typ.(*ast.SelectorExpr)
	if !ok {
		return nil
	}
	pkg, ok := sel.X.(*ast.Ident)
	if !ok {
		return nil
	}
	// presumably a qualified identifier: only the package part gets pos
	return &ast.SelectorExpr{
		X:   &ast.Ident{Name: pkg.Name, NamePos: pos},
		Sel: ast.NewIdent(sel.Sel.Name),
	}
}
254
255func (r *reader) filterSpecList(list []ast.Spec, tok token.Token) []ast.Spec {
256	if tok == token.CONST {
257		// Propagate any type information that would get lost otherwise
258		// when unexported constants are filtered.
259		var prevType ast.Expr
260		for _, spec := range list {
261			spec := spec.(*ast.ValueSpec)
262			if spec.Type == nil && len(spec.Values) == 0 && prevType != nil {
263				// provide current spec with an explicit type
264				spec.Type = copyConstType(prevType, spec.Pos())
265			}
266			if hasExportedName(spec.Names) {
267				// exported names are preserved so there's no need to propagate the type
268				prevType = nil
269			} else {
270				prevType = spec.Type
271			}
272		}
273	}
274
275	j := 0
276	for _, s := range list {
277		if r.filterSpec(s) {
278			list[j] = s
279			j++
280		}
281	}
282	return list[0:j]
283}
284
285func (r *reader) filterDecl(decl ast.Decl) bool {
286	switch d := decl.(type) {
287	case *ast.GenDecl:
288		d.Specs = r.filterSpecList(d.Specs, d.Tok)
289		return len(d.Specs) > 0
290	case *ast.FuncDecl:
291		// ok to filter these methods early because any
292		// conflicting method will be filtered here, too -
293		// thus, removing these methods early will not lead
294		// to the false removal of possible conflicts
295		return token.IsExported(d.Name.Name)
296	}
297	return false
298}
299
300// fileExports removes unexported declarations from src in place.
301//
302func (r *reader) fileExports(src *ast.File) {
303	j := 0
304	for _, d := range src.Decls {
305		if r.filterDecl(d) {
306			src.Decls[j] = d
307			j++
308		}
309	}
310	src.Decls = src.Decls[0:j]
311}
312