1package shared 2 3import ( 4 "context" 5 "errors" 6 "fmt" 7 "net/http" 8 "net/url" 9 "regexp" 10 "sort" 11 "strconv" 12 "strings" 13 14 "github.com/cli/cli/v2/api" 15 remotes "github.com/cli/cli/v2/context" 16 "github.com/cli/cli/v2/git" 17 "github.com/cli/cli/v2/internal/ghinstance" 18 "github.com/cli/cli/v2/internal/ghrepo" 19 "github.com/cli/cli/v2/pkg/cmdutil" 20 "github.com/cli/cli/v2/pkg/set" 21 graphql "github.com/cli/shurcooL-graphql" 22 "github.com/shurcooL/githubv4" 23 "golang.org/x/sync/errgroup" 24) 25 26type PRFinder interface { 27 Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, error) 28} 29 30type progressIndicator interface { 31 StartProgressIndicator() 32 StopProgressIndicator() 33} 34 35type finder struct { 36 baseRepoFn func() (ghrepo.Interface, error) 37 branchFn func() (string, error) 38 remotesFn func() (remotes.Remotes, error) 39 httpClient func() (*http.Client, error) 40 branchConfig func(string) git.BranchConfig 41 progress progressIndicator 42 43 repo ghrepo.Interface 44 prNumber int 45 branchName string 46} 47 48func NewFinder(factory *cmdutil.Factory) PRFinder { 49 if runCommandFinder != nil { 50 f := runCommandFinder 51 runCommandFinder = &mockFinder{err: errors.New("you must use a RunCommandFinder to stub PR lookups")} 52 return f 53 } 54 55 return &finder{ 56 baseRepoFn: factory.BaseRepo, 57 branchFn: factory.Branch, 58 remotesFn: factory.Remotes, 59 httpClient: factory.HttpClient, 60 progress: factory.IOStreams, 61 branchConfig: git.ReadBranchConfig, 62 } 63} 64 65var runCommandFinder PRFinder 66 67// RunCommandFinder is the NewMockFinder substitute to be used ONLY in runCommand-style tests. 68func RunCommandFinder(selector string, pr *api.PullRequest, repo ghrepo.Interface) *mockFinder { 69 finder := NewMockFinder(selector, pr, repo) 70 runCommandFinder = finder 71 return finder 72} 73 74type FindOptions struct { 75 // Selector can be a number with optional `#` prefix, a branch name with optional `<owner>:` prefix, or 76 // a PR URL. 77 Selector string 78 // Fields lists the GraphQL fields to fetch for the PullRequest. 79 Fields []string 80 // BaseBranch is the name of the base branch to scope the PR-for-branch lookup to. 81 BaseBranch string 82 // States lists the possible PR states to scope the PR-for-branch lookup to. 83 States []string 84} 85 86func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, error) { 87 if len(opts.Fields) == 0 { 88 return nil, nil, errors.New("Find error: no fields specified") 89 } 90 91 if repo, prNumber, err := f.parseURL(opts.Selector); err == nil { 92 f.prNumber = prNumber 93 f.repo = repo 94 } 95 96 if f.repo == nil { 97 repo, err := f.baseRepoFn() 98 if err != nil { 99 return nil, nil, fmt.Errorf("could not determine base repo: %w", err) 100 } 101 f.repo = repo 102 } 103 104 if opts.Selector == "" { 105 if branch, prNumber, err := f.parseCurrentBranch(); err != nil { 106 return nil, nil, err 107 } else if prNumber > 0 { 108 f.prNumber = prNumber 109 } else { 110 f.branchName = branch 111 } 112 } else if f.prNumber == 0 { 113 if prNumber, err := strconv.Atoi(strings.TrimPrefix(opts.Selector, "#")); err == nil { 114 f.prNumber = prNumber 115 } else { 116 f.branchName = opts.Selector 117 } 118 } 119 120 httpClient, err := f.httpClient() 121 if err != nil { 122 return nil, nil, err 123 } 124 125 // TODO(josebalius): Should we be guarding here? 126 if f.progress != nil { 127 f.progress.StartProgressIndicator() 128 defer f.progress.StopProgressIndicator() 129 } 130 131 fields := set.NewStringSet() 132 fields.AddValues(opts.Fields) 133 numberFieldOnly := fields.Len() == 1 && fields.Contains("number") 134 fields.Add("id") // for additional preload queries below 135 136 var pr *api.PullRequest 137 if f.prNumber > 0 { 138 if numberFieldOnly { 139 // avoid hitting the API if we already have all the information 140 return &api.PullRequest{Number: f.prNumber}, f.repo, nil 141 } 142 pr, err = findByNumber(httpClient, f.repo, f.prNumber, fields.ToSlice()) 143 } else { 144 pr, err = findForBranch(httpClient, f.repo, opts.BaseBranch, f.branchName, opts.States, fields.ToSlice()) 145 } 146 if err != nil { 147 return pr, f.repo, err 148 } 149 150 g, _ := errgroup.WithContext(context.Background()) 151 if fields.Contains("reviews") { 152 g.Go(func() error { 153 return preloadPrReviews(httpClient, f.repo, pr) 154 }) 155 } 156 if fields.Contains("comments") { 157 g.Go(func() error { 158 return preloadPrComments(httpClient, f.repo, pr) 159 }) 160 } 161 if fields.Contains("statusCheckRollup") { 162 g.Go(func() error { 163 return preloadPrChecks(httpClient, f.repo, pr) 164 }) 165 } 166 167 return pr, f.repo, g.Wait() 168} 169 170var pullURLRE = regexp.MustCompile(`^/([^/]+)/([^/]+)/pull/(\d+)`) 171 172func (f *finder) parseURL(prURL string) (ghrepo.Interface, int, error) { 173 if prURL == "" { 174 return nil, 0, fmt.Errorf("invalid URL: %q", prURL) 175 } 176 177 u, err := url.Parse(prURL) 178 if err != nil { 179 return nil, 0, err 180 } 181 182 if u.Scheme != "https" && u.Scheme != "http" { 183 return nil, 0, fmt.Errorf("invalid scheme: %s", u.Scheme) 184 } 185 186 m := pullURLRE.FindStringSubmatch(u.Path) 187 if m == nil { 188 return nil, 0, fmt.Errorf("not a pull request URL: %s", prURL) 189 } 190 191 repo := ghrepo.NewWithHost(m[1], m[2], u.Hostname()) 192 prNumber, _ := strconv.Atoi(m[3]) 193 return repo, prNumber, nil 194} 195 196var prHeadRE = regexp.MustCompile(`^refs/pull/(\d+)/head$`) 197 198func (f *finder) parseCurrentBranch() (string, int, error) { 199 prHeadRef, err := f.branchFn() 200 if err != nil { 201 return "", 0, err 202 } 203 204 branchConfig := f.branchConfig(prHeadRef) 205 206 // the branch is configured to merge a special PR head ref 207 if m := prHeadRE.FindStringSubmatch(branchConfig.MergeRef); m != nil { 208 prNumber, _ := strconv.Atoi(m[1]) 209 return "", prNumber, nil 210 } 211 212 var branchOwner string 213 if branchConfig.RemoteURL != nil { 214 // the branch merges from a remote specified by URL 215 if r, err := ghrepo.FromURL(branchConfig.RemoteURL); err == nil { 216 branchOwner = r.RepoOwner() 217 } 218 } else if branchConfig.RemoteName != "" { 219 // the branch merges from a remote specified by name 220 rem, _ := f.remotesFn() 221 if r, err := rem.FindByName(branchConfig.RemoteName); err == nil { 222 branchOwner = r.RepoOwner() 223 } 224 } 225 226 if branchOwner != "" { 227 if strings.HasPrefix(branchConfig.MergeRef, "refs/heads/") { 228 prHeadRef = strings.TrimPrefix(branchConfig.MergeRef, "refs/heads/") 229 } 230 // prepend `OWNER:` if this branch is pushed to a fork 231 if !strings.EqualFold(branchOwner, f.repo.RepoOwner()) { 232 prHeadRef = fmt.Sprintf("%s:%s", branchOwner, prHeadRef) 233 } 234 } 235 236 return prHeadRef, 0, nil 237} 238 239func findByNumber(httpClient *http.Client, repo ghrepo.Interface, number int, fields []string) (*api.PullRequest, error) { 240 type response struct { 241 Repository struct { 242 PullRequest api.PullRequest 243 } 244 } 245 246 query := fmt.Sprintf(` 247 query PullRequestByNumber($owner: String!, $repo: String!, $pr_number: Int!) { 248 repository(owner: $owner, name: $repo) { 249 pullRequest(number: $pr_number) {%s} 250 } 251 }`, api.PullRequestGraphQL(fields)) 252 253 variables := map[string]interface{}{ 254 "owner": repo.RepoOwner(), 255 "repo": repo.RepoName(), 256 "pr_number": number, 257 } 258 259 var resp response 260 client := api.NewClientFromHTTP(httpClient) 261 err := client.GraphQL(repo.RepoHost(), query, variables, &resp) 262 if err != nil { 263 return nil, err 264 } 265 266 return &resp.Repository.PullRequest, nil 267} 268 269func findForBranch(httpClient *http.Client, repo ghrepo.Interface, baseBranch, headBranch string, stateFilters, fields []string) (*api.PullRequest, error) { 270 type response struct { 271 Repository struct { 272 PullRequests struct { 273 Nodes []api.PullRequest 274 } 275 DefaultBranchRef struct { 276 Name string 277 } 278 } 279 } 280 281 fieldSet := set.NewStringSet() 282 fieldSet.AddValues(fields) 283 // these fields are required for filtering below 284 fieldSet.AddValues([]string{"state", "baseRefName", "headRefName", "isCrossRepository", "headRepositoryOwner"}) 285 286 query := fmt.Sprintf(` 287 query PullRequestForBranch($owner: String!, $repo: String!, $headRefName: String!, $states: [PullRequestState!]) { 288 repository(owner: $owner, name: $repo) { 289 pullRequests(headRefName: $headRefName, states: $states, first: 30, orderBy: { field: CREATED_AT, direction: DESC }) { 290 nodes {%s} 291 } 292 defaultBranchRef { name } 293 } 294 }`, api.PullRequestGraphQL(fieldSet.ToSlice())) 295 296 branchWithoutOwner := headBranch 297 if idx := strings.Index(headBranch, ":"); idx >= 0 { 298 branchWithoutOwner = headBranch[idx+1:] 299 } 300 301 variables := map[string]interface{}{ 302 "owner": repo.RepoOwner(), 303 "repo": repo.RepoName(), 304 "headRefName": branchWithoutOwner, 305 "states": stateFilters, 306 } 307 308 var resp response 309 client := api.NewClientFromHTTP(httpClient) 310 err := client.GraphQL(repo.RepoHost(), query, variables, &resp) 311 if err != nil { 312 return nil, err 313 } 314 315 prs := resp.Repository.PullRequests.Nodes 316 sort.SliceStable(prs, func(a, b int) bool { 317 return prs[a].State == "OPEN" && prs[b].State != "OPEN" 318 }) 319 320 for _, pr := range prs { 321 if pr.HeadLabel() == headBranch && (baseBranch == "" || pr.BaseRefName == baseBranch) && (pr.State == "OPEN" || resp.Repository.DefaultBranchRef.Name != headBranch) { 322 return &pr, nil 323 } 324 } 325 326 return nil, &NotFoundError{fmt.Errorf("no pull requests found for branch %q", headBranch)} 327} 328 329func preloadPrReviews(httpClient *http.Client, repo ghrepo.Interface, pr *api.PullRequest) error { 330 if !pr.Reviews.PageInfo.HasNextPage { 331 return nil 332 } 333 334 type response struct { 335 Node struct { 336 PullRequest struct { 337 Reviews api.PullRequestReviews `graphql:"reviews(first: 100, after: $endCursor)"` 338 } `graphql:"...on PullRequest"` 339 } `graphql:"node(id: $id)"` 340 } 341 342 variables := map[string]interface{}{ 343 "id": githubv4.ID(pr.ID), 344 "endCursor": githubv4.String(pr.Reviews.PageInfo.EndCursor), 345 } 346 347 gql := graphql.NewClient(ghinstance.GraphQLEndpoint(repo.RepoHost()), httpClient) 348 349 for { 350 var query response 351 err := gql.QueryNamed(context.Background(), "ReviewsForPullRequest", &query, variables) 352 if err != nil { 353 return err 354 } 355 356 pr.Reviews.Nodes = append(pr.Reviews.Nodes, query.Node.PullRequest.Reviews.Nodes...) 357 pr.Reviews.TotalCount = len(pr.Reviews.Nodes) 358 359 if !query.Node.PullRequest.Reviews.PageInfo.HasNextPage { 360 break 361 } 362 variables["endCursor"] = githubv4.String(query.Node.PullRequest.Reviews.PageInfo.EndCursor) 363 } 364 365 pr.Reviews.PageInfo.HasNextPage = false 366 return nil 367} 368 369func preloadPrComments(client *http.Client, repo ghrepo.Interface, pr *api.PullRequest) error { 370 if !pr.Comments.PageInfo.HasNextPage { 371 return nil 372 } 373 374 type response struct { 375 Node struct { 376 PullRequest struct { 377 Comments api.Comments `graphql:"comments(first: 100, after: $endCursor)"` 378 } `graphql:"...on PullRequest"` 379 } `graphql:"node(id: $id)"` 380 } 381 382 variables := map[string]interface{}{ 383 "id": githubv4.ID(pr.ID), 384 "endCursor": githubv4.String(pr.Comments.PageInfo.EndCursor), 385 } 386 387 gql := graphql.NewClient(ghinstance.GraphQLEndpoint(repo.RepoHost()), client) 388 389 for { 390 var query response 391 err := gql.QueryNamed(context.Background(), "CommentsForPullRequest", &query, variables) 392 if err != nil { 393 return err 394 } 395 396 pr.Comments.Nodes = append(pr.Comments.Nodes, query.Node.PullRequest.Comments.Nodes...) 397 pr.Comments.TotalCount = len(pr.Comments.Nodes) 398 399 if !query.Node.PullRequest.Comments.PageInfo.HasNextPage { 400 break 401 } 402 variables["endCursor"] = githubv4.String(query.Node.PullRequest.Comments.PageInfo.EndCursor) 403 } 404 405 pr.Comments.PageInfo.HasNextPage = false 406 return nil 407} 408 409func preloadPrChecks(client *http.Client, repo ghrepo.Interface, pr *api.PullRequest) error { 410 if len(pr.StatusCheckRollup.Nodes) == 0 { 411 return nil 412 } 413 statusCheckRollup := &pr.StatusCheckRollup.Nodes[0].Commit.StatusCheckRollup.Contexts 414 if !statusCheckRollup.PageInfo.HasNextPage { 415 return nil 416 } 417 418 endCursor := statusCheckRollup.PageInfo.EndCursor 419 420 type response struct { 421 Node *api.PullRequest 422 } 423 424 query := fmt.Sprintf(` 425 query PullRequestStatusChecks($id: ID!, $endCursor: String!) { 426 node(id: $id) { 427 ...on PullRequest { 428 %s 429 } 430 } 431 }`, api.StatusCheckRollupGraphQL("$endCursor")) 432 433 variables := map[string]interface{}{ 434 "id": pr.ID, 435 } 436 437 apiClient := api.NewClientFromHTTP(client) 438 for { 439 variables["endCursor"] = endCursor 440 var resp response 441 err := apiClient.GraphQL(repo.RepoHost(), query, variables, &resp) 442 if err != nil { 443 return err 444 } 445 446 result := resp.Node.StatusCheckRollup.Nodes[0].Commit.StatusCheckRollup.Contexts 447 statusCheckRollup.Nodes = append( 448 statusCheckRollup.Nodes, 449 result.Nodes..., 450 ) 451 452 if !result.PageInfo.HasNextPage { 453 break 454 } 455 endCursor = result.PageInfo.EndCursor 456 } 457 458 statusCheckRollup.PageInfo.HasNextPage = false 459 return nil 460} 461 462type NotFoundError struct { 463 error 464} 465 466func (err *NotFoundError) Unwrap() error { 467 return err.error 468} 469 470func NewMockFinder(selector string, pr *api.PullRequest, repo ghrepo.Interface) *mockFinder { 471 var err error 472 if pr == nil { 473 err = &NotFoundError{errors.New("no pull requests found")} 474 } 475 return &mockFinder{ 476 expectSelector: selector, 477 pr: pr, 478 repo: repo, 479 err: err, 480 } 481} 482 483type mockFinder struct { 484 called bool 485 expectSelector string 486 expectFields []string 487 pr *api.PullRequest 488 repo ghrepo.Interface 489 err error 490} 491 492func (m *mockFinder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, error) { 493 if m.err != nil { 494 return nil, nil, m.err 495 } 496 if m.expectSelector != opts.Selector { 497 return nil, nil, fmt.Errorf("mockFinder: expected selector %q, got %q", m.expectSelector, opts.Selector) 498 } 499 if len(m.expectFields) > 0 && !isEqualSet(m.expectFields, opts.Fields) { 500 return nil, nil, fmt.Errorf("mockFinder: expected fields %v, got %v", m.expectFields, opts.Fields) 501 } 502 if m.called { 503 return nil, nil, errors.New("mockFinder used more than once") 504 } 505 m.called = true 506 507 if m.pr.HeadRepositoryOwner.Login == "" { 508 // pose as same-repo PR by default 509 m.pr.HeadRepositoryOwner.Login = m.repo.RepoOwner() 510 } 511 512 return m.pr, m.repo, nil 513} 514 515func (m *mockFinder) ExpectFields(fields []string) { 516 m.expectFields = fields 517} 518 519func isEqualSet(a, b []string) bool { 520 if len(a) != len(b) { 521 return false 522 } 523 524 aCopy := make([]string, len(a)) 525 copy(aCopy, a) 526 bCopy := make([]string, len(b)) 527 copy(bCopy, b) 528 sort.Strings(aCopy) 529 sort.Strings(bCopy) 530 531 for i := range aCopy { 532 if aCopy[i] != bCopy[i] { 533 return false 534 } 535 } 536 return true 537} 538