1package typematch 2 3import ( 4 "fmt" 5 "go/ast" 6 "go/parser" 7 "go/token" 8 "go/types" 9 "strconv" 10 "strings" 11 12 "github.com/quasilyte/go-ruleguard/internal/xtypes" 13) 14 15//go:generate stringer -type=patternOp 16type patternOp int 17 18const ( 19 opBuiltinType patternOp = iota 20 opPointer 21 opVar 22 opVarSeq 23 opSlice 24 opArray 25 opMap 26 opChan 27 opFunc 28 opStructNoSeq 29 opStruct 30 opNamed 31) 32 33type Pattern struct { 34 typeMatches map[string]types.Type 35 int64Matches map[string]int64 36 37 root *pattern 38} 39 40type pattern struct { 41 value interface{} 42 op patternOp 43 subs []*pattern 44} 45 46func (pat pattern) String() string { 47 if len(pat.subs) == 0 { 48 return fmt.Sprintf("<%s %#v>", pat.op, pat.value) 49 } 50 parts := make([]string, len(pat.subs)) 51 for i, sub := range pat.subs { 52 parts[i] = sub.String() 53 } 54 return fmt.Sprintf("<%s %#v (%s)>", pat.op, pat.value, strings.Join(parts, ", ")) 55} 56 57type ImportsTab struct { 58 imports []map[string]string 59} 60 61func NewImportsTab(initial map[string]string) *ImportsTab { 62 return &ImportsTab{imports: []map[string]string{initial}} 63} 64 65func (itab *ImportsTab) Lookup(pkgName string) (string, bool) { 66 for i := len(itab.imports) - 1; i >= 0; i-- { 67 pkgPath, ok := itab.imports[i][pkgName] 68 if ok { 69 return pkgPath, true 70 } 71 } 72 return "", false 73} 74 75func (itab *ImportsTab) Load(pkgName, pkgPath string) { 76 itab.imports[len(itab.imports)-1][pkgName] = pkgPath 77} 78 79func (itab *ImportsTab) EnterScope() { 80 itab.imports = append(itab.imports, map[string]string{}) 81} 82 83func (itab *ImportsTab) LeaveScope() { 84 itab.imports = itab.imports[:len(itab.imports)-1] 85} 86 87type Context struct { 88 Itab *ImportsTab 89} 90 91const ( 92 varPrefix = `ᐸvarᐳ` 93 varSeqPrefix = `ᐸvar_seqᐳ` 94) 95 96func Parse(ctx *Context, s string) (*Pattern, error) { 97 noDollars := strings.ReplaceAll(s, "$*", varSeqPrefix) 98 noDollars = strings.ReplaceAll(noDollars, "$", varPrefix) 99 n, err := parser.ParseExpr(noDollars) 100 if err != nil { 101 return nil, err 102 } 103 root := parseExpr(ctx, n) 104 if root == nil { 105 return nil, fmt.Errorf("can't convert %s type expression", s) 106 } 107 p := &Pattern{ 108 typeMatches: map[string]types.Type{}, 109 int64Matches: map[string]int64{}, 110 root: root, 111 } 112 return p, nil 113} 114 115var ( 116 builtinTypeByName = map[string]types.Type{ 117 "bool": types.Typ[types.Bool], 118 "int": types.Typ[types.Int], 119 "int8": types.Typ[types.Int8], 120 "int16": types.Typ[types.Int16], 121 "int32": types.Typ[types.Int32], 122 "int64": types.Typ[types.Int64], 123 "uint": types.Typ[types.Uint], 124 "uint8": types.Typ[types.Uint8], 125 "uint16": types.Typ[types.Uint16], 126 "uint32": types.Typ[types.Uint32], 127 "uint64": types.Typ[types.Uint64], 128 "uintptr": types.Typ[types.Uintptr], 129 "float32": types.Typ[types.Float32], 130 "float64": types.Typ[types.Float64], 131 "complex64": types.Typ[types.Complex64], 132 "complex128": types.Typ[types.Complex128], 133 "string": types.Typ[types.String], 134 135 "error": types.Universe.Lookup("error").Type(), 136 137 // Aliases. 138 "byte": types.Typ[types.Uint8], 139 "rune": types.Typ[types.Int32], 140 } 141 142 efaceType = types.NewInterfaceType(nil, nil) 143) 144 145func parseExpr(ctx *Context, e ast.Expr) *pattern { 146 switch e := e.(type) { 147 case *ast.Ident: 148 basic, ok := builtinTypeByName[e.Name] 149 if ok { 150 return &pattern{op: opBuiltinType, value: basic} 151 } 152 if strings.HasPrefix(e.Name, varPrefix) { 153 name := strings.TrimPrefix(e.Name, varPrefix) 154 return &pattern{op: opVar, value: name} 155 } 156 if strings.HasPrefix(e.Name, varSeqPrefix) { 157 name := strings.TrimPrefix(e.Name, varSeqPrefix) 158 // Only unnamed seq are supported right now. 159 if name == "_" { 160 return &pattern{op: opVarSeq, value: name} 161 } 162 } 163 164 case *ast.SelectorExpr: 165 pkg, ok := e.X.(*ast.Ident) 166 if !ok { 167 return nil 168 } 169 pkgPath, ok := ctx.Itab.Lookup(pkg.Name) 170 if !ok { 171 return nil 172 } 173 return &pattern{op: opNamed, value: [2]string{pkgPath, e.Sel.Name}} 174 175 case *ast.StarExpr: 176 elem := parseExpr(ctx, e.X) 177 if elem == nil { 178 return nil 179 } 180 return &pattern{op: opPointer, subs: []*pattern{elem}} 181 182 case *ast.ArrayType: 183 elem := parseExpr(ctx, e.Elt) 184 if elem == nil { 185 return nil 186 } 187 if e.Len == nil { 188 return &pattern{ 189 op: opSlice, 190 subs: []*pattern{elem}, 191 } 192 } 193 if id, ok := e.Len.(*ast.Ident); ok && strings.HasPrefix(id.Name, varPrefix) { 194 name := strings.TrimPrefix(id.Name, varPrefix) 195 return &pattern{ 196 op: opArray, 197 value: name, 198 subs: []*pattern{elem}, 199 } 200 } 201 lit, ok := e.Len.(*ast.BasicLit) 202 if !ok || lit.Kind != token.INT { 203 return nil 204 } 205 length, err := strconv.ParseInt(lit.Value, 10, 64) 206 if err != nil { 207 return nil 208 } 209 return &pattern{ 210 op: opArray, 211 value: length, 212 subs: []*pattern{elem}, 213 } 214 215 case *ast.MapType: 216 keyType := parseExpr(ctx, e.Key) 217 if keyType == nil { 218 return nil 219 } 220 valType := parseExpr(ctx, e.Value) 221 if valType == nil { 222 return nil 223 } 224 return &pattern{ 225 op: opMap, 226 subs: []*pattern{keyType, valType}, 227 } 228 229 case *ast.ChanType: 230 valType := parseExpr(ctx, e.Value) 231 if valType == nil { 232 return nil 233 } 234 var dir types.ChanDir 235 switch { 236 case e.Dir&ast.SEND != 0 && e.Dir&ast.RECV != 0: 237 dir = types.SendRecv 238 case e.Dir&ast.SEND != 0: 239 dir = types.SendOnly 240 case e.Dir&ast.RECV != 0: 241 dir = types.RecvOnly 242 default: 243 return nil 244 } 245 return &pattern{ 246 op: opChan, 247 value: dir, 248 subs: []*pattern{valType}, 249 } 250 251 case *ast.ParenExpr: 252 return parseExpr(ctx, e.X) 253 254 case *ast.FuncType: 255 var params []*pattern 256 var results []*pattern 257 if e.Params != nil { 258 for _, field := range e.Params.List { 259 p := parseExpr(ctx, field.Type) 260 if p == nil { 261 return nil 262 } 263 if len(field.Names) != 0 { 264 return nil 265 } 266 params = append(params, p) 267 } 268 } 269 if e.Results != nil { 270 for _, field := range e.Results.List { 271 p := parseExpr(ctx, field.Type) 272 if p == nil { 273 return nil 274 } 275 if len(field.Names) != 0 { 276 return nil 277 } 278 results = append(results, p) 279 } 280 } 281 return &pattern{ 282 op: opFunc, 283 value: len(params), 284 subs: append(params, results...), 285 } 286 287 case *ast.StructType: 288 hasSeq := false 289 members := make([]*pattern, 0, len(e.Fields.List)) 290 for _, field := range e.Fields.List { 291 p := parseExpr(ctx, field.Type) 292 if p == nil { 293 return nil 294 } 295 if len(field.Names) != 0 { 296 return nil 297 } 298 if p.op == opVarSeq { 299 hasSeq = true 300 } 301 members = append(members, p) 302 } 303 op := opStructNoSeq 304 if hasSeq { 305 op = opStruct 306 } 307 return &pattern{ 308 op: op, 309 subs: members, 310 } 311 312 case *ast.InterfaceType: 313 if len(e.Methods.List) == 0 { 314 return &pattern{op: opBuiltinType, value: efaceType} 315 } 316 } 317 318 return nil 319} 320 321// MatchIdentical returns true if the go typ matches pattern p. 322func (p *Pattern) MatchIdentical(typ types.Type) bool { 323 p.reset() 324 return p.matchIdentical(p.root, typ) 325} 326 327func (p *Pattern) reset() { 328 if len(p.int64Matches) != 0 { 329 p.int64Matches = map[string]int64{} 330 } 331 if len(p.typeMatches) != 0 { 332 p.typeMatches = map[string]types.Type{} 333 } 334} 335 336func (p *Pattern) matchIdenticalFielder(subs []*pattern, f fielder) bool { 337 // TODO: do backtracking. 338 339 numFields := f.NumFields() 340 fieldsMatched := 0 341 342 if len(subs) == 0 && numFields != 0 { 343 return false 344 } 345 346 matchAny := false 347 348 i := 0 349 for i < len(subs) { 350 pat := subs[i] 351 352 if pat.op == opVarSeq { 353 matchAny = true 354 } 355 356 fieldsLeft := numFields - fieldsMatched 357 if matchAny { 358 switch { 359 // "Nothing left to match" stop condition. 360 case fieldsLeft == 0: 361 matchAny = false 362 i++ 363 // Lookahead for non-greedy matching. 364 case i+1 < len(subs) && p.matchIdentical(subs[i+1], f.Field(fieldsMatched).Type()): 365 matchAny = false 366 i += 2 367 fieldsMatched++ 368 default: 369 fieldsMatched++ 370 } 371 continue 372 } 373 374 if fieldsLeft == 0 || !p.matchIdentical(pat, f.Field(fieldsMatched).Type()) { 375 return false 376 } 377 i++ 378 fieldsMatched++ 379 } 380 381 return numFields == fieldsMatched 382} 383 384func (p *Pattern) matchIdentical(sub *pattern, typ types.Type) bool { 385 switch sub.op { 386 case opVar: 387 name := sub.value.(string) 388 if name == "_" { 389 return true 390 } 391 y, ok := p.typeMatches[name] 392 if !ok { 393 p.typeMatches[name] = typ 394 return true 395 } 396 if y == nil { 397 return typ == nil 398 } 399 return xtypes.Identical(typ, y) 400 401 case opBuiltinType: 402 return xtypes.Identical(typ, sub.value.(types.Type)) 403 404 case opPointer: 405 typ, ok := typ.(*types.Pointer) 406 if !ok { 407 return false 408 } 409 return p.matchIdentical(sub.subs[0], typ.Elem()) 410 411 case opSlice: 412 typ, ok := typ.(*types.Slice) 413 if !ok { 414 return false 415 } 416 return p.matchIdentical(sub.subs[0], typ.Elem()) 417 418 case opArray: 419 typ, ok := typ.(*types.Array) 420 if !ok { 421 return false 422 } 423 var wantLen int64 424 switch v := sub.value.(type) { 425 case string: 426 if v == "_" { 427 wantLen = typ.Len() 428 break 429 } 430 length, ok := p.int64Matches[v] 431 if ok { 432 wantLen = length 433 } else { 434 p.int64Matches[v] = typ.Len() 435 wantLen = typ.Len() 436 } 437 case int64: 438 wantLen = v 439 } 440 return wantLen == typ.Len() && p.matchIdentical(sub.subs[0], typ.Elem()) 441 442 case opMap: 443 typ, ok := typ.(*types.Map) 444 if !ok { 445 return false 446 } 447 return p.matchIdentical(sub.subs[0], typ.Key()) && 448 p.matchIdentical(sub.subs[1], typ.Elem()) 449 450 case opChan: 451 typ, ok := typ.(*types.Chan) 452 if !ok { 453 return false 454 } 455 dir := sub.value.(types.ChanDir) 456 return dir == typ.Dir() && p.matchIdentical(sub.subs[0], typ.Elem()) 457 458 case opNamed: 459 typ, ok := typ.(*types.Named) 460 if !ok { 461 return false 462 } 463 obj := typ.Obj() 464 pkg := obj.Pkg() 465 // pkg can be nil for builtin named types. 466 // There is no point in checking anything else as we never 467 // generate the opNamed for such types. 468 if pkg == nil { 469 return false 470 } 471 pkgPath := sub.value.([2]string)[0] 472 typeName := sub.value.([2]string)[1] 473 // obj.Pkg().Path() may be in a vendor directory. 474 path := strings.SplitAfter(obj.Pkg().Path(), "/vendor/") 475 return path[len(path)-1] == pkgPath && typeName == obj.Name() 476 477 case opFunc: 478 typ, ok := typ.(*types.Signature) 479 if !ok { 480 return false 481 } 482 numParams := sub.value.(int) 483 params := sub.subs[:numParams] 484 results := sub.subs[numParams:] 485 if typ.Params().Len() != len(params) { 486 return false 487 } 488 if typ.Results().Len() != len(results) { 489 return false 490 } 491 for i := 0; i < typ.Params().Len(); i++ { 492 if !p.matchIdentical(params[i], typ.Params().At(i).Type()) { 493 return false 494 } 495 } 496 for i := 0; i < typ.Results().Len(); i++ { 497 if !p.matchIdentical(results[i], typ.Results().At(i).Type()) { 498 return false 499 } 500 } 501 return true 502 503 case opStructNoSeq: 504 typ, ok := typ.(*types.Struct) 505 if !ok { 506 return false 507 } 508 if typ.NumFields() != len(sub.subs) { 509 return false 510 } 511 for i, member := range sub.subs { 512 if !p.matchIdentical(member, typ.Field(i).Type()) { 513 return false 514 } 515 } 516 return true 517 518 case opStruct: 519 typ, ok := typ.(*types.Struct) 520 if !ok { 521 return false 522 } 523 if !p.matchIdenticalFielder(sub.subs, typ) { 524 return false 525 } 526 return true 527 528 default: 529 return false 530 } 531} 532 533type fielder interface { 534 Field(i int) *types.Var 535 NumFields() int 536} 537