1// Copyright 2011 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5//go:build go1.16
6// +build go1.16
7
8// This file implements export filtering of an AST.
9
10package doc
11
12import (
13	"go/ast"
14	"go/token"
15)
16
// filterIdentList compacts list in place so that only identifiers with
// exported names remain, keeping their relative order. The returned
// slice shares its backing array with list.
func filterIdentList(list []*ast.Ident) []*ast.Ident {
	kept := list[:0]
	for _, id := range list {
		if !token.IsExported(id.Name) {
			continue
		}
		kept = append(kept, id)
	}
	return kept
}
30
// underscore is the blank identifier. This one shared node is what
// updateIdentList stores in place of every unexported name.
var underscore = &ast.Ident{Name: "_"}
32
33func filterCompositeLit(lit *ast.CompositeLit, filter Filter, export bool) {
34	n := len(lit.Elts)
35	lit.Elts = filterExprList(lit.Elts, filter, export)
36	if len(lit.Elts) < n {
37		lit.Incomplete = true
38	}
39}
40
41func filterExprList(list []ast.Expr, filter Filter, export bool) []ast.Expr {
42	j := 0
43	for _, exp := range list {
44		switch x := exp.(type) {
45		case *ast.CompositeLit:
46			filterCompositeLit(x, filter, export)
47		case *ast.KeyValueExpr:
48			if x, ok := x.Key.(*ast.Ident); ok && !filter(x.Name) {
49				continue
50			}
51			if x, ok := x.Value.(*ast.CompositeLit); ok {
52				filterCompositeLit(x, filter, export)
53			}
54		}
55		list[j] = exp
56		j++
57	}
58	return list[0:j]
59}
60
61// updateIdentList replaces all unexported identifiers with underscore
62// and reports whether at least one exported name exists.
63func updateIdentList(list []*ast.Ident) (hasExported bool) {
64	for i, x := range list {
65		if token.IsExported(x.Name) {
66			hasExported = true
67		} else {
68			list[i] = underscore
69		}
70	}
71	return hasExported
72}
73
// hasExportedName reports whether at least one identifier in list has an
// exported name. It stops at the first match.
func hasExportedName(list []*ast.Ident) bool {
	found := false
	for i := 0; i < len(list) && !found; i++ {
		found = list[i].IsExported()
	}
	return found
}
84
85// removeErrorField removes anonymous fields named "error" from an interface.
86// This is called when "error" has been determined to be a local name,
87// not the predeclared type.
88//
89func removeErrorField(ityp *ast.InterfaceType) {
90	list := ityp.Methods.List // we know that ityp.Methods != nil
91	j := 0
92	for _, field := range list {
93		keepField := true
94		if n := len(field.Names); n == 0 {
95			// anonymous field
96			if fname, _ := baseTypeName(field.Type); fname == "error" {
97				keepField = false
98			}
99		}
100		if keepField {
101			list[j] = field
102			j++
103		}
104	}
105	if j < len(list) {
106		ityp.Incomplete = true
107	}
108	ityp.Methods.List = list[0:j]
109}
110
111// filterFieldList removes unexported fields (field names) from the field list
112// in place and reports whether fields were removed. Anonymous fields are
113// recorded with the parent type. filterType is called with the types of
114// all remaining fields.
115//
116func (r *reader) filterFieldList(parent *namedType, fields *ast.FieldList, ityp *ast.InterfaceType) (removedFields bool) {
117	if fields == nil {
118		return
119	}
120	list := fields.List
121	j := 0
122	for _, field := range list {
123		keepField := false
124		if n := len(field.Names); n == 0 {
125			// anonymous field
126			fname := r.recordAnonymousField(parent, field.Type)
127			if token.IsExported(fname) {
128				keepField = true
129			} else if ityp != nil && fname == "error" {
130				// possibly the predeclared error interface; keep
131				// it for now but remember this interface so that
132				// it can be fixed if error is also defined locally
133				keepField = true
134				r.remember(ityp)
135			}
136		} else {
137			field.Names = filterIdentList(field.Names)
138			if len(field.Names) < n {
139				removedFields = true
140			}
141			if len(field.Names) > 0 {
142				keepField = true
143			}
144		}
145		if keepField {
146			r.filterType(nil, field.Type)
147			list[j] = field
148			j++
149		}
150	}
151	if j < len(list) {
152		removedFields = true
153	}
154	fields.List = list[0:j]
155	return
156}
157
158// filterParamList applies filterType to each parameter type in fields.
159//
160func (r *reader) filterParamList(fields *ast.FieldList) {
161	if fields != nil {
162		for _, f := range fields.List {
163			r.filterType(nil, f.Type)
164		}
165	}
166}
167
168// filterType strips any unexported struct fields or method types from typ
169// in place. If fields (or methods) have been removed, the corresponding
170// struct or interface type has the Incomplete field set to true.
171//
172func (r *reader) filterType(parent *namedType, typ ast.Expr) {
173	switch t := typ.(type) {
174	case *ast.Ident:
175		// nothing to do
176	case *ast.ParenExpr:
177		r.filterType(nil, t.X)
178	case *ast.ArrayType:
179		r.filterType(nil, t.Elt)
180	case *ast.StructType:
181		if r.filterFieldList(parent, t.Fields, nil) {
182			t.Incomplete = true
183		}
184	case *ast.FuncType:
185		r.filterParamList(t.Params)
186		r.filterParamList(t.Results)
187	case *ast.InterfaceType:
188		if r.filterFieldList(parent, t.Methods, t) {
189			t.Incomplete = true
190		}
191	case *ast.MapType:
192		r.filterType(nil, t.Key)
193		r.filterType(nil, t.Value)
194	case *ast.ChanType:
195		r.filterType(nil, t.Value)
196	}
197}
198
199func (r *reader) filterSpec(spec ast.Spec) bool {
200	switch s := spec.(type) {
201	case *ast.ImportSpec:
202		// always keep imports so we can collect them
203		return true
204	case *ast.ValueSpec:
205		s.Values = filterExprList(s.Values, token.IsExported, true)
206		if len(s.Values) > 0 || s.Type == nil && len(s.Values) == 0 {
207			// If there are values declared on RHS, just replace the unexported
208			// identifiers on the LHS with underscore, so that it matches
209			// the sequence of expression on the RHS.
210			//
211			// Similarly, if there are no type and values, then this expression
212			// must be following an iota expression, where order matters.
213			if updateIdentList(s.Names) {
214				r.filterType(nil, s.Type)
215				return true
216			}
217		} else {
218			s.Names = filterIdentList(s.Names)
219			if len(s.Names) > 0 {
220				r.filterType(nil, s.Type)
221				return true
222			}
223		}
224	case *ast.TypeSpec:
225		if name := s.Name.Name; token.IsExported(name) {
226			r.filterType(r.lookupType(s.Name.Name), s.Type)
227			return true
228		} else if name == "error" {
229			// special case: remember that error is declared locally
230			r.errorDecl = true
231		}
232	}
233	return false
234}
235
// copyConstType returns a fresh copy of typ positioned at pos.
// typ must be a valid constant type; in practice that means a plain or a
// package-qualified identifier. For a qualified identifier only the
// qualifier receives pos. Any other expression yields nil rather than a
// panic.
func copyConstType(typ ast.Expr, pos token.Pos) ast.Expr {
	if id, ok := typ.(*ast.Ident); ok {
		return &ast.Ident{Name: id.Name, NamePos: pos}
	}
	if sel, ok := typ.(*ast.SelectorExpr); ok {
		if pkg, ok := sel.X.(*ast.Ident); ok {
			// presumably a qualified identifier
			qualifier := &ast.Ident{Name: pkg.Name, NamePos: pos}
			return &ast.SelectorExpr{X: qualifier, Sel: ast.NewIdent(sel.Sel.Name)}
		}
	}
	return nil // shouldn't happen, but be conservative and don't panic
}
255
256func (r *reader) filterSpecList(list []ast.Spec, tok token.Token) []ast.Spec {
257	if tok == token.CONST {
258		// Propagate any type information that would get lost otherwise
259		// when unexported constants are filtered.
260		var prevType ast.Expr
261		for _, spec := range list {
262			spec := spec.(*ast.ValueSpec)
263			if spec.Type == nil && len(spec.Values) == 0 && prevType != nil {
264				// provide current spec with an explicit type
265				spec.Type = copyConstType(prevType, spec.Pos())
266			}
267			if hasExportedName(spec.Names) {
268				// exported names are preserved so there's no need to propagate the type
269				prevType = nil
270			} else {
271				prevType = spec.Type
272			}
273		}
274	}
275
276	j := 0
277	for _, s := range list {
278		if r.filterSpec(s) {
279			list[j] = s
280			j++
281		}
282	}
283	return list[0:j]
284}
285
286func (r *reader) filterDecl(decl ast.Decl) bool {
287	switch d := decl.(type) {
288	case *ast.GenDecl:
289		d.Specs = r.filterSpecList(d.Specs, d.Tok)
290		return len(d.Specs) > 0
291	case *ast.FuncDecl:
292		// ok to filter these methods early because any
293		// conflicting method will be filtered here, too -
294		// thus, removing these methods early will not lead
295		// to the false removal of possible conflicts
296		return token.IsExported(d.Name.Name)
297	}
298	return false
299}
300
301// fileExports removes unexported declarations from src in place.
302//
303func (r *reader) fileExports(src *ast.File) {
304	j := 0
305	for _, d := range src.Decls {
306		if r.filterDecl(d) {
307			src.Decls[j] = d
308			j++
309		}
310	}
311	src.Decls = src.Decls[0:j]
312}
313