1// (c) Copyright 2016 Hewlett Packard Enterprise Development LP
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package rules
16
17import (
18	"go/ast"
19	"regexp"
20	"strings"
21
22	"github.com/securego/gosec/v2"
23)
24
// sqlStatement holds the state shared by the SQL rules in this file: the rule
// metadata (ID, severity, confidence, description), a gosec.CallList of calls
// to inspect, and the regular expressions a string must satisfy before it is
// treated as SQL.
type sqlStatement struct {
	gosec.MetaData
	gosec.CallList

	// Contains a list of patterns which must all match for the rule to match
	// (see MatchPatterns).
	patterns []*regexp.Regexp
}
32
33func (s *sqlStatement) ID() string {
34	return s.MetaData.ID
35}
36
37// See if the string matches the patterns for the statement.
38func (s *sqlStatement) MatchPatterns(str string) bool {
39	for _, pattern := range s.patterns {
40		if !pattern.MatchString(str) {
41			return false
42		}
43	}
44	return true
45}
46
// sqlStrConcat is the rule that flags database/sql calls whose query argument
// is built by concatenating a SQL-looking string literal with operands that
// are not literals or constant-like identifiers.
type sqlStrConcat struct {
	sqlStatement
}
50
51func (s *sqlStrConcat) ID() string {
52	return s.MetaData.ID
53}
54
55// see if we can figure out what it is
56func (s *sqlStrConcat) checkObject(n *ast.Ident, c *gosec.Context) bool {
57	if n.Obj != nil {
58		return n.Obj.Kind != ast.Var && n.Obj.Kind != ast.Fun
59	}
60
61	// Try to resolve unresolved identifiers using other files in same package
62	for _, file := range c.PkgFiles {
63		if node, ok := file.Scope.Objects[n.String()]; ok {
64			return node.Kind != ast.Var && node.Kind != ast.Fun
65		}
66	}
67	return false
68}
69
70// checkQuery verifies if the query parameters is a string concatenation
71func (s *sqlStrConcat) checkQuery(call *ast.CallExpr, ctx *gosec.Context) (*gosec.Issue, error) {
72	_, fnName, err := gosec.GetCallInfo(call, ctx)
73	if err != nil {
74		return nil, err
75	}
76	var query ast.Node
77	if strings.HasSuffix(fnName, "Context") {
78		query = call.Args[1]
79	} else {
80		query = call.Args[0]
81	}
82
83	if be, ok := query.(*ast.BinaryExpr); ok {
84		operands := gosec.GetBinaryExprOperands(be)
85		if start, ok := operands[0].(*ast.BasicLit); ok {
86			if str, e := gosec.GetString(start); e == nil {
87				if !s.MatchPatterns(str) {
88					return nil, nil
89				}
90			}
91			for _, op := range operands[1:] {
92				if _, ok := op.(*ast.BasicLit); ok {
93					continue
94				}
95				if op, ok := op.(*ast.Ident); ok && s.checkObject(op, ctx) {
96					continue
97				}
98				return gosec.NewIssue(ctx, be, s.ID(), s.What, s.Severity, s.Confidence), nil
99			}
100		}
101	}
102
103	return nil, nil
104}
105
106// Checks SQL query concatenation issues such as "SELECT * FROM table WHERE " + " ' OR 1=1"
107func (s *sqlStrConcat) Match(n ast.Node, ctx *gosec.Context) (*gosec.Issue, error) {
108	switch stmt := n.(type) {
109	case *ast.AssignStmt:
110		for _, expr := range stmt.Rhs {
111			if sqlQueryCall, ok := expr.(*ast.CallExpr); ok && s.ContainsCallExpr(expr, ctx) != nil {
112				return s.checkQuery(sqlQueryCall, ctx)
113			}
114		}
115	case *ast.ExprStmt:
116		if sqlQueryCall, ok := stmt.X.(*ast.CallExpr); ok && s.ContainsCallExpr(stmt.X, ctx) != nil {
117			return s.checkQuery(sqlQueryCall, ctx)
118		}
119	}
120	return nil, nil
121}
122
123// NewSQLStrConcat looks for cases where we are building SQL strings via concatenation
124func NewSQLStrConcat(id string, conf gosec.Config) (gosec.Rule, []ast.Node) {
125	rule := &sqlStrConcat{
126		sqlStatement: sqlStatement{
127			patterns: []*regexp.Regexp{
128				regexp.MustCompile(`(?i)(SELECT|DELETE|INSERT|UPDATE|INTO|FROM|WHERE) `),
129			},
130			MetaData: gosec.MetaData{
131				ID:         id,
132				Severity:   gosec.Medium,
133				Confidence: gosec.High,
134				What:       "SQL string concatenation",
135			},
136			CallList: gosec.NewCallList(),
137		},
138	}
139
140	rule.AddAll("*database/sql.DB", "Query", "QueryContext", "QueryRow", "QueryRowContext")
141	rule.AddAll("*database/sql.Tx", "Query", "QueryContext", "QueryRow", "QueryRowContext")
142	return rule, []ast.Node{(*ast.AssignStmt)(nil), (*ast.ExprStmt)(nil)}
143}
144
// sqlStrFormat is the rule that flags database/sql calls whose query argument
// was produced by a fmt formatting call with a SQL-looking format string.
type sqlStrFormat struct {
	// CallList holds the database calls to inspect. Because it is embedded at
	// this level, it shadows the CallList embedded in sqlStatement. That inner
	// list is not populated by NewSQLStrFormat.
	gosec.CallList
	sqlStatement
	fmtCalls      gosec.CallList // fmt functions whose format string is checked
	noIssue       gosec.CallList // Fprintf writers (os.Stdout/os.Stderr) that never build SQL
	noIssueQuoted gosec.CallList // calls whose result is treated as a safely quoted value
}
152
153// see if we can figure out what it is
154func (s *sqlStrFormat) constObject(e ast.Expr, c *gosec.Context) bool {
155	n, ok := e.(*ast.Ident)
156	if !ok {
157		return false
158	}
159
160	if n.Obj != nil {
161		return n.Obj.Kind == ast.Con
162	}
163
164	// Try to resolve unresolved identifiers using other files in same package
165	for _, file := range c.PkgFiles {
166		if node, ok := file.Scope.Objects[n.String()]; ok {
167			return node.Kind == ast.Con
168		}
169	}
170	return false
171}
172
173func (s *sqlStrFormat) checkQuery(call *ast.CallExpr, ctx *gosec.Context) (*gosec.Issue, error) {
174	_, fnName, err := gosec.GetCallInfo(call, ctx)
175	if err != nil {
176		return nil, err
177	}
178	var query ast.Node
179	if strings.HasSuffix(fnName, "Context") {
180		query = call.Args[1]
181	} else {
182		query = call.Args[0]
183	}
184
185	if ident, ok := query.(*ast.Ident); ok && ident.Obj != nil {
186		decl := ident.Obj.Decl
187		if assign, ok := decl.(*ast.AssignStmt); ok {
188			for _, expr := range assign.Rhs {
189				issue, err := s.checkFormatting(expr, ctx)
190				if issue != nil {
191					return issue, err
192				}
193			}
194		}
195	}
196
197	return nil, nil
198}
199
200func (s *sqlStrFormat) checkFormatting(n ast.Node, ctx *gosec.Context) (*gosec.Issue, error) {
201	// argIndex changes the function argument which gets matched to the regex
202	argIndex := 0
203	if node := s.fmtCalls.ContainsPkgCallExpr(n, ctx, false); node != nil {
204		// if the function is fmt.Fprintf, search for SQL statement in Args[1] instead
205		if sel, ok := node.Fun.(*ast.SelectorExpr); ok {
206			if sel.Sel.Name == "Fprintf" {
207				// if os.Stderr or os.Stdout is in Arg[0], mark as no issue
208				if arg, ok := node.Args[0].(*ast.SelectorExpr); ok {
209					if ident, ok := arg.X.(*ast.Ident); ok {
210						if s.noIssue.Contains(ident.Name, arg.Sel.Name) {
211							return nil, nil
212						}
213					}
214				}
215				// the function is Fprintf so set argIndex = 1
216				argIndex = 1
217			}
218		}
219
220		// no formatter
221		if len(node.Args) == 0 {
222			return nil, nil
223		}
224
225		var formatter string
226
227		// concats callexpr arg strings together if needed before regex evaluation
228		if argExpr, ok := node.Args[argIndex].(*ast.BinaryExpr); ok {
229			if fullStr, ok := gosec.ConcatString(argExpr); ok {
230				formatter = fullStr
231			}
232		} else if arg, e := gosec.GetString(node.Args[argIndex]); e == nil {
233			formatter = arg
234		}
235		if len(formatter) <= 0 {
236			return nil, nil
237		}
238
239		// If all formatter args are quoted or constant, then the SQL construction is safe
240		if argIndex+1 < len(node.Args) {
241			allSafe := true
242			for _, arg := range node.Args[argIndex+1:] {
243				if n := s.noIssueQuoted.ContainsPkgCallExpr(arg, ctx, true); n == nil && !s.constObject(arg, ctx) {
244					allSafe = false
245					break
246				}
247			}
248			if allSafe {
249				return nil, nil
250			}
251		}
252		if s.MatchPatterns(formatter) {
253			return gosec.NewIssue(ctx, n, s.ID(), s.What, s.Severity, s.Confidence), nil
254		}
255	}
256	return nil, nil
257}
258
259// Check SQL query formatting issues such as "fmt.Sprintf("SELECT * FROM foo where '%s', userInput)"
260func (s *sqlStrFormat) Match(n ast.Node, ctx *gosec.Context) (*gosec.Issue, error) {
261	switch stmt := n.(type) {
262	case *ast.AssignStmt:
263		for _, expr := range stmt.Rhs {
264			if sqlQueryCall, ok := expr.(*ast.CallExpr); ok && s.ContainsCallExpr(expr, ctx) != nil {
265				return s.checkQuery(sqlQueryCall, ctx)
266			}
267		}
268	case *ast.ExprStmt:
269		if sqlQueryCall, ok := stmt.X.(*ast.CallExpr); ok && s.ContainsCallExpr(stmt.X, ctx) != nil {
270			return s.checkQuery(sqlQueryCall, ctx)
271		}
272	}
273	return nil, nil
274}
275
276// NewSQLStrFormat looks for cases where we're building SQL query strings using format strings
277func NewSQLStrFormat(id string, conf gosec.Config) (gosec.Rule, []ast.Node) {
278	rule := &sqlStrFormat{
279		CallList:      gosec.NewCallList(),
280		fmtCalls:      gosec.NewCallList(),
281		noIssue:       gosec.NewCallList(),
282		noIssueQuoted: gosec.NewCallList(),
283		sqlStatement: sqlStatement{
284			patterns: []*regexp.Regexp{
285				regexp.MustCompile("(?i)(SELECT|DELETE|INSERT|UPDATE|INTO|FROM|WHERE) "),
286				regexp.MustCompile("%[^bdoxXfFp]"),
287			},
288			MetaData: gosec.MetaData{
289				ID:         id,
290				Severity:   gosec.Medium,
291				Confidence: gosec.High,
292				What:       "SQL string formatting",
293			},
294		},
295	}
296	rule.AddAll("*database/sql.DB", "Query", "QueryContext", "QueryRow", "QueryRowContext")
297	rule.AddAll("*database/sql.Tx", "Query", "QueryContext", "QueryRow", "QueryRowContext")
298	rule.fmtCalls.AddAll("fmt", "Sprint", "Sprintf", "Sprintln", "Fprintf")
299	rule.noIssue.AddAll("os", "Stdout", "Stderr")
300	rule.noIssueQuoted.Add("github.com/lib/pq", "QuoteIdentifier")
301
302	return rule, []ast.Node{(*ast.AssignStmt)(nil), (*ast.ExprStmt)(nil)}
303}
304