1// Copyright (c) 2017, Daniel Martí <mvdan@mvdan.cc> 2// See LICENSE for licensing information 3 4package gogrep 5 6import ( 7 "bytes" 8 "fmt" 9 "go/ast" 10 "go/parser" 11 "go/scanner" 12 "go/token" 13 "strings" 14 "text/template" 15) 16 17func transformSource(expr string) (string, []posOffset, error) { 18 toks, err := tokenize([]byte(expr)) 19 if err != nil { 20 return "", nil, fmt.Errorf("cannot tokenize expr: %v", err) 21 } 22 var offs []posOffset 23 lbuf := lineColBuffer{line: 1, col: 1} 24 lastLit := false 25 for _, t := range toks { 26 if lbuf.offs >= t.pos.Offset && lastLit && t.lit != "" { 27 _, _ = lbuf.WriteString(" ") 28 } 29 for lbuf.offs < t.pos.Offset { 30 _, _ = lbuf.WriteString(" ") 31 } 32 if t.lit == "" { 33 _, _ = lbuf.WriteString(t.tok.String()) 34 lastLit = false 35 continue 36 } 37 _, _ = lbuf.WriteString(t.lit) 38 lastLit = strings.TrimSpace(t.lit) != "" 39 } 40 // trailing newlines can cause issues with commas 41 return strings.TrimSpace(lbuf.String()), offs, nil 42} 43 44func parseExpr(fset *token.FileSet, expr string) (ast.Node, error) { 45 exprStr, offs, err := transformSource(expr) 46 if err != nil { 47 return nil, err 48 } 49 node, _, err := parseDetectingNode(fset, exprStr) 50 if err != nil { 51 err = subPosOffsets(err, offs...) 52 return nil, fmt.Errorf("cannot parse expr: %v", err) 53 } 54 return node, nil 55} 56 57type lineColBuffer struct { 58 bytes.Buffer 59 line, col, offs int 60} 61 62func (l *lineColBuffer) WriteString(s string) (n int, err error) { 63 for _, r := range s { 64 if r == '\n' { 65 l.line++ 66 l.col = 1 67 } else { 68 l.col++ 69 } 70 l.offs++ 71 } 72 return l.Buffer.WriteString(s) 73} 74 75var tmplDecl = template.Must(template.New("").Parse(`` + 76 `package p; {{ . }}`)) 77 78var tmplBlock = template.Must(template.New("").Parse(`` + 79 `package p; func _() { if true {{ . }} else {} }`)) 80 81var tmplExprs = template.Must(template.New("").Parse(`` + 82 `package p; var _ = []interface{}{ {{ . }}, }`)) 83 84var tmplStmts = template.Must(template.New("").Parse(`` + 85 `package p; func _() { {{ . }} }`)) 86 87var tmplType = template.Must(template.New("").Parse(`` + 88 `package p; var _ {{ . }}`)) 89 90var tmplValSpec = template.Must(template.New("").Parse(`` + 91 `package p; var {{ . }}`)) 92 93func execTmpl(tmpl *template.Template, src string) string { 94 var buf bytes.Buffer 95 if err := tmpl.Execute(&buf, src); err != nil { 96 panic(err) 97 } 98 return buf.String() 99} 100 101func noBadNodes(node ast.Node) bool { 102 any := false 103 ast.Inspect(node, func(n ast.Node) bool { 104 if any { 105 return false 106 } 107 switch n.(type) { 108 case *ast.BadExpr, *ast.BadDecl: 109 any = true 110 } 111 return true 112 }) 113 return !any 114} 115 116func parseType(fset *token.FileSet, src string) (ast.Expr, *ast.File, error) { 117 asType := execTmpl(tmplType, src) 118 f, err := parser.ParseFile(fset, "", asType, 0) 119 if err != nil { 120 err = subPosOffsets(err, posOffset{1, 1, 17}) 121 return nil, nil, err 122 } 123 vs := f.Decls[0].(*ast.GenDecl).Specs[0].(*ast.ValueSpec) 124 return vs.Type, f, nil 125} 126 127// parseDetectingNode tries its best to parse the ast.Node contained in src, as 128// one of: *ast.File, ast.Decl, ast.Expr, ast.Stmt, *ast.ValueSpec. 129// It also returns the *ast.File used for the parsing, so that the returned node 130// can be easily type-checked. 131func parseDetectingNode(fset *token.FileSet, src string) (ast.Node, *ast.File, error) { 132 file := fset.AddFile("", fset.Base(), len(src)) 133 scan := scanner.Scanner{} 134 scan.Init(file, []byte(src), nil, 0) 135 if _, tok, _ := scan.Scan(); tok == token.EOF { 136 return nil, nil, fmt.Errorf("empty source code") 137 } 138 var mainErr error 139 140 // first try as a whole file 141 if f, err := parser.ParseFile(fset, "", src, 0); err == nil && noBadNodes(f) { 142 return f, f, nil 143 } 144 145 // then as a single declaration, or many 146 asDecl := execTmpl(tmplDecl, src) 147 if f, err := parser.ParseFile(fset, "", asDecl, 0); err == nil && noBadNodes(f) { 148 if len(f.Decls) == 1 { 149 return f.Decls[0], f, nil 150 } 151 return f, f, nil 152 } 153 154 // then as a block; otherwise blocks might be mistaken for composite 155 // literals further below 156 asBlock := execTmpl(tmplBlock, src) 157 if f, err := parser.ParseFile(fset, "", asBlock, 0); err == nil && noBadNodes(f) { 158 bl := f.Decls[0].(*ast.FuncDecl).Body 159 if len(bl.List) == 1 { 160 ifs := bl.List[0].(*ast.IfStmt) 161 return ifs.Body, f, nil 162 } 163 } 164 165 // then as value expressions 166 asExprs := execTmpl(tmplExprs, src) 167 if f, err := parser.ParseFile(fset, "", asExprs, 0); err == nil && noBadNodes(f) { 168 vs := f.Decls[0].(*ast.GenDecl).Specs[0].(*ast.ValueSpec) 169 cl := vs.Values[0].(*ast.CompositeLit) 170 if len(cl.Elts) == 1 { 171 return cl.Elts[0], f, nil 172 } 173 return exprSlice(cl.Elts), f, nil 174 } 175 176 // then try as statements 177 asStmts := execTmpl(tmplStmts, src) 178 f, err := parser.ParseFile(fset, "", asStmts, 0) 179 if err == nil && noBadNodes(f) { 180 bl := f.Decls[0].(*ast.FuncDecl).Body 181 if len(bl.List) == 1 { 182 return bl.List[0], f, nil 183 } 184 return stmtSlice(bl.List), f, nil 185 } 186 // Statements is what covers most cases, so it will give 187 // the best overall error message. Show positions 188 // relative to where the user's code is put in the 189 // template. 190 mainErr = subPosOffsets(err, posOffset{1, 1, 22}) 191 192 // type expressions not yet picked up, for e.g. chans and interfaces 193 if typ, f, err := parseType(fset, src); err == nil && noBadNodes(f) { 194 return typ, f, nil 195 } 196 197 // value specs 198 asValSpec := execTmpl(tmplValSpec, src) 199 if f, err := parser.ParseFile(fset, "", asValSpec, 0); err == nil && noBadNodes(f) { 200 vs := f.Decls[0].(*ast.GenDecl).Specs[0].(*ast.ValueSpec) 201 return vs, f, nil 202 } 203 return nil, nil, mainErr 204} 205 206type posOffset struct { 207 atLine, atCol int 208 offset int 209} 210 211func subPosOffsets(err error, offs ...posOffset) error { 212 list, ok := err.(scanner.ErrorList) 213 if !ok { 214 return err 215 } 216 for i, err := range list { 217 for _, off := range offs { 218 if err.Pos.Line != off.atLine { 219 continue 220 } 221 if err.Pos.Column < off.atCol { 222 continue 223 } 224 err.Pos.Column -= off.offset 225 } 226 list[i] = err 227 } 228 return list 229} 230 231type fullToken struct { 232 pos token.Position 233 tok token.Token 234 lit string 235} 236 237type caseStatus uint 238 239const ( 240 caseNone caseStatus = iota 241 caseNeedBlock 242 caseHere 243) 244 245func tokenize(src []byte) ([]fullToken, error) { 246 var s scanner.Scanner 247 fset := token.NewFileSet() 248 file := fset.AddFile("", fset.Base(), len(src)) 249 250 var err error 251 onError := func(pos token.Position, msg string) { 252 switch msg { // allow certain extra chars 253 case `illegal character U+0024 '$'`: 254 case `illegal character U+007E '~'`: 255 default: 256 err = fmt.Errorf("%v: %s", pos, msg) 257 } 258 } 259 260 // we will modify the input source under the scanner's nose to 261 // enable some features such as regexes. 262 s.Init(file, src, onError, scanner.ScanComments) 263 264 next := func() fullToken { 265 pos, tok, lit := s.Scan() 266 return fullToken{fset.Position(pos), tok, lit} 267 } 268 269 caseStat := caseNone 270 271 var toks []fullToken 272 for t := next(); t.tok != token.EOF; t = next() { 273 switch t.lit { 274 case "$": // continues below 275 case "switch", "select", "case": 276 if t.lit == "case" { 277 caseStat = caseNone 278 } else { 279 caseStat = caseNeedBlock 280 } 281 fallthrough 282 default: // regular Go code 283 if t.tok == token.LBRACE && caseStat == caseNeedBlock { 284 caseStat = caseHere 285 } 286 toks = append(toks, t) 287 continue 288 } 289 wt, err := tokenizeWildcard(t.pos, next) 290 if err != nil { 291 return nil, err 292 } 293 if caseStat == caseHere { 294 toks = append(toks, fullToken{wt.pos, token.IDENT, "case"}) 295 } 296 toks = append(toks, wt) 297 if caseStat == caseHere { 298 toks = append(toks, fullToken{wt.pos, token.COLON, ""}) 299 toks = append(toks, fullToken{wt.pos, token.IDENT, "gogrep_body"}) 300 } 301 } 302 return toks, err 303} 304 305type varInfo struct { 306 Name string 307 Seq bool 308} 309 310func tokenizeWildcard(pos token.Position, next func() fullToken) (fullToken, error) { 311 t := next() 312 any := false 313 if t.tok == token.MUL { 314 t = next() 315 any = true 316 } 317 wildName := encodeWildName(t.lit, any) 318 wt := fullToken{pos, token.IDENT, wildName} 319 if t.tok != token.IDENT { 320 return wt, fmt.Errorf("%v: $ must be followed by ident, got %v", 321 t.pos, t.tok) 322 } 323 return wt, nil 324} 325 326const wildSeparator = "ᐸᐳ" 327 328func isWildName(s string) bool { 329 return strings.HasPrefix(s, wildSeparator) 330} 331 332func encodeWildName(name string, any bool) string { 333 suffix := "v" 334 if any { 335 suffix = "a" 336 } 337 return wildSeparator + name + wildSeparator + suffix 338} 339 340func decodeWildName(s string) varInfo { 341 s = s[len(wildSeparator):] 342 nameEnd := strings.Index(s, wildSeparator) 343 name := s[:nameEnd] 344 s = s[nameEnd:] 345 s = s[len(wildSeparator):] 346 kind := s 347 return varInfo{Name: name, Seq: kind == "a"} 348} 349 350func decodeWildNode(n ast.Node) varInfo { 351 switch n := n.(type) { 352 case *ast.ExprStmt: 353 return decodeWildNode(n.X) 354 case *ast.Ident: 355 if isWildName(n.Name) { 356 return decodeWildName(n.Name) 357 } 358 } 359 return varInfo{} 360} 361