1package shared
2
3import (
4	"context"
5	"errors"
6	"fmt"
7	"net/http"
8	"net/url"
9	"regexp"
10	"sort"
11	"strconv"
12	"strings"
13
14	"github.com/cli/cli/v2/api"
15	remotes "github.com/cli/cli/v2/context"
16	"github.com/cli/cli/v2/git"
17	"github.com/cli/cli/v2/internal/ghinstance"
18	"github.com/cli/cli/v2/internal/ghrepo"
19	"github.com/cli/cli/v2/pkg/cmdutil"
20	"github.com/cli/cli/v2/pkg/set"
21	graphql "github.com/cli/shurcooL-graphql"
22	"github.com/shurcooL/githubv4"
23	"golang.org/x/sync/errgroup"
24)
25
// PRFinder looks up a pull request described by FindOptions.
// Find returns the pull request, a repository, and any error encountered.
type PRFinder interface {
	Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, error)
}
29
// progressIndicator is the pair of calls the finder uses to show and hide a
// progress indicator while a lookup is running. NewFinder passes
// factory.IOStreams for it.
type progressIndicator interface {
	StartProgressIndicator()
	StopProgressIndicator()
}
34
// finder is the PRFinder implementation that NewFinder returns when no stub
// has been registered.
type finder struct {
	// Dependencies; NewFinder fills these in from the command factory.
	baseRepoFn   func() (ghrepo.Interface, error)
	branchFn     func() (string, error)
	remotesFn    func() (remotes.Remotes, error)
	httpClient   func() (*http.Client, error)
	branchConfig func(string) git.BranchConfig
	progress     progressIndicator

	// Lookup state that Find fills in while interpreting the selector.
	repo       ghrepo.Interface
	prNumber   int
	branchName string
}
47
48func NewFinder(factory *cmdutil.Factory) PRFinder {
49	if runCommandFinder != nil {
50		f := runCommandFinder
51		runCommandFinder = &mockFinder{err: errors.New("you must use a RunCommandFinder to stub PR lookups")}
52		return f
53	}
54
55	return &finder{
56		baseRepoFn:   factory.BaseRepo,
57		branchFn:     factory.Branch,
58		remotesFn:    factory.Remotes,
59		httpClient:   factory.HttpClient,
60		progress:     factory.IOStreams,
61		branchConfig: git.ReadBranchConfig,
62	}
63}
64
// runCommandFinder, when non-nil, is what NewFinder returns instead of a real finder.
var runCommandFinder PRFinder
66
67// RunCommandFinder is the NewMockFinder substitute to be used ONLY in runCommand-style tests.
68func RunCommandFinder(selector string, pr *api.PullRequest, repo ghrepo.Interface) *mockFinder {
69	finder := NewMockFinder(selector, pr, repo)
70	runCommandFinder = finder
71	return finder
72}
73
// FindOptions describes which pull request Find should look up and which
// fields to fetch for it.
type FindOptions struct {
	// Selector can be a number with optional `#` prefix, a branch name with optional `<owner>:` prefix, or
	// a PR URL.
	Selector string
	// Fields lists the GraphQL fields to fetch for the PullRequest.
	Fields []string
	// BaseBranch is the name of the base branch to scope the PR-for-branch lookup to.
	BaseBranch string
	// States lists the possible PR states to scope the PR-for-branch lookup to.
	States []string
}
85
86func (f *finder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, error) {
87	if len(opts.Fields) == 0 {
88		return nil, nil, errors.New("Find error: no fields specified")
89	}
90
91	if repo, prNumber, err := f.parseURL(opts.Selector); err == nil {
92		f.prNumber = prNumber
93		f.repo = repo
94	}
95
96	if f.repo == nil {
97		repo, err := f.baseRepoFn()
98		if err != nil {
99			return nil, nil, fmt.Errorf("could not determine base repo: %w", err)
100		}
101		f.repo = repo
102	}
103
104	if opts.Selector == "" {
105		if branch, prNumber, err := f.parseCurrentBranch(); err != nil {
106			return nil, nil, err
107		} else if prNumber > 0 {
108			f.prNumber = prNumber
109		} else {
110			f.branchName = branch
111		}
112	} else if f.prNumber == 0 {
113		if prNumber, err := strconv.Atoi(strings.TrimPrefix(opts.Selector, "#")); err == nil {
114			f.prNumber = prNumber
115		} else {
116			f.branchName = opts.Selector
117		}
118	}
119
120	httpClient, err := f.httpClient()
121	if err != nil {
122		return nil, nil, err
123	}
124
125	// TODO(josebalius): Should we be guarding here?
126	if f.progress != nil {
127		f.progress.StartProgressIndicator()
128		defer f.progress.StopProgressIndicator()
129	}
130
131	fields := set.NewStringSet()
132	fields.AddValues(opts.Fields)
133	numberFieldOnly := fields.Len() == 1 && fields.Contains("number")
134	fields.Add("id") // for additional preload queries below
135
136	var pr *api.PullRequest
137	if f.prNumber > 0 {
138		if numberFieldOnly {
139			// avoid hitting the API if we already have all the information
140			return &api.PullRequest{Number: f.prNumber}, f.repo, nil
141		}
142		pr, err = findByNumber(httpClient, f.repo, f.prNumber, fields.ToSlice())
143	} else {
144		pr, err = findForBranch(httpClient, f.repo, opts.BaseBranch, f.branchName, opts.States, fields.ToSlice())
145	}
146	if err != nil {
147		return pr, f.repo, err
148	}
149
150	g, _ := errgroup.WithContext(context.Background())
151	if fields.Contains("reviews") {
152		g.Go(func() error {
153			return preloadPrReviews(httpClient, f.repo, pr)
154		})
155	}
156	if fields.Contains("comments") {
157		g.Go(func() error {
158			return preloadPrComments(httpClient, f.repo, pr)
159		})
160	}
161	if fields.Contains("statusCheckRollup") {
162		g.Go(func() error {
163			return preloadPrChecks(httpClient, f.repo, pr)
164		})
165	}
166
167	return pr, f.repo, g.Wait()
168}
169
// pullURLRE matches a URL path that starts with /<owner>/<repo>/pull/<number>
// and captures those three parts. The pattern is not anchored at the end, so
// extra path segments after the number still match.
var pullURLRE = regexp.MustCompile(`^/([^/]+)/([^/]+)/pull/(\d+)`)
171
172func (f *finder) parseURL(prURL string) (ghrepo.Interface, int, error) {
173	if prURL == "" {
174		return nil, 0, fmt.Errorf("invalid URL: %q", prURL)
175	}
176
177	u, err := url.Parse(prURL)
178	if err != nil {
179		return nil, 0, err
180	}
181
182	if u.Scheme != "https" && u.Scheme != "http" {
183		return nil, 0, fmt.Errorf("invalid scheme: %s", u.Scheme)
184	}
185
186	m := pullURLRE.FindStringSubmatch(u.Path)
187	if m == nil {
188		return nil, 0, fmt.Errorf("not a pull request URL: %s", prURL)
189	}
190
191	repo := ghrepo.NewWithHost(m[1], m[2], u.Hostname())
192	prNumber, _ := strconv.Atoi(m[3])
193	return repo, prNumber, nil
194}
195
// prHeadRE matches a ref of the exact form refs/pull/<number>/head and captures the number.
var prHeadRE = regexp.MustCompile(`^refs/pull/(\d+)/head$`)
197
198func (f *finder) parseCurrentBranch() (string, int, error) {
199	prHeadRef, err := f.branchFn()
200	if err != nil {
201		return "", 0, err
202	}
203
204	branchConfig := f.branchConfig(prHeadRef)
205
206	// the branch is configured to merge a special PR head ref
207	if m := prHeadRE.FindStringSubmatch(branchConfig.MergeRef); m != nil {
208		prNumber, _ := strconv.Atoi(m[1])
209		return "", prNumber, nil
210	}
211
212	var branchOwner string
213	if branchConfig.RemoteURL != nil {
214		// the branch merges from a remote specified by URL
215		if r, err := ghrepo.FromURL(branchConfig.RemoteURL); err == nil {
216			branchOwner = r.RepoOwner()
217		}
218	} else if branchConfig.RemoteName != "" {
219		// the branch merges from a remote specified by name
220		rem, _ := f.remotesFn()
221		if r, err := rem.FindByName(branchConfig.RemoteName); err == nil {
222			branchOwner = r.RepoOwner()
223		}
224	}
225
226	if branchOwner != "" {
227		if strings.HasPrefix(branchConfig.MergeRef, "refs/heads/") {
228			prHeadRef = strings.TrimPrefix(branchConfig.MergeRef, "refs/heads/")
229		}
230		// prepend `OWNER:` if this branch is pushed to a fork
231		if !strings.EqualFold(branchOwner, f.repo.RepoOwner()) {
232			prHeadRef = fmt.Sprintf("%s:%s", branchOwner, prHeadRef)
233		}
234	}
235
236	return prHeadRef, 0, nil
237}
238
239func findByNumber(httpClient *http.Client, repo ghrepo.Interface, number int, fields []string) (*api.PullRequest, error) {
240	type response struct {
241		Repository struct {
242			PullRequest api.PullRequest
243		}
244	}
245
246	query := fmt.Sprintf(`
247	query PullRequestByNumber($owner: String!, $repo: String!, $pr_number: Int!) {
248		repository(owner: $owner, name: $repo) {
249			pullRequest(number: $pr_number) {%s}
250		}
251	}`, api.PullRequestGraphQL(fields))
252
253	variables := map[string]interface{}{
254		"owner":     repo.RepoOwner(),
255		"repo":      repo.RepoName(),
256		"pr_number": number,
257	}
258
259	var resp response
260	client := api.NewClientFromHTTP(httpClient)
261	err := client.GraphQL(repo.RepoHost(), query, variables, &resp)
262	if err != nil {
263		return nil, err
264	}
265
266	return &resp.Repository.PullRequest, nil
267}
268
269func findForBranch(httpClient *http.Client, repo ghrepo.Interface, baseBranch, headBranch string, stateFilters, fields []string) (*api.PullRequest, error) {
270	type response struct {
271		Repository struct {
272			PullRequests struct {
273				Nodes []api.PullRequest
274			}
275			DefaultBranchRef struct {
276				Name string
277			}
278		}
279	}
280
281	fieldSet := set.NewStringSet()
282	fieldSet.AddValues(fields)
283	// these fields are required for filtering below
284	fieldSet.AddValues([]string{"state", "baseRefName", "headRefName", "isCrossRepository", "headRepositoryOwner"})
285
286	query := fmt.Sprintf(`
287	query PullRequestForBranch($owner: String!, $repo: String!, $headRefName: String!, $states: [PullRequestState!]) {
288		repository(owner: $owner, name: $repo) {
289			pullRequests(headRefName: $headRefName, states: $states, first: 30, orderBy: { field: CREATED_AT, direction: DESC }) {
290				nodes {%s}
291			}
292			defaultBranchRef { name }
293		}
294	}`, api.PullRequestGraphQL(fieldSet.ToSlice()))
295
296	branchWithoutOwner := headBranch
297	if idx := strings.Index(headBranch, ":"); idx >= 0 {
298		branchWithoutOwner = headBranch[idx+1:]
299	}
300
301	variables := map[string]interface{}{
302		"owner":       repo.RepoOwner(),
303		"repo":        repo.RepoName(),
304		"headRefName": branchWithoutOwner,
305		"states":      stateFilters,
306	}
307
308	var resp response
309	client := api.NewClientFromHTTP(httpClient)
310	err := client.GraphQL(repo.RepoHost(), query, variables, &resp)
311	if err != nil {
312		return nil, err
313	}
314
315	prs := resp.Repository.PullRequests.Nodes
316	sort.SliceStable(prs, func(a, b int) bool {
317		return prs[a].State == "OPEN" && prs[b].State != "OPEN"
318	})
319
320	for _, pr := range prs {
321		if pr.HeadLabel() == headBranch && (baseBranch == "" || pr.BaseRefName == baseBranch) && (pr.State == "OPEN" || resp.Repository.DefaultBranchRef.Name != headBranch) {
322			return &pr, nil
323		}
324	}
325
326	return nil, &NotFoundError{fmt.Errorf("no pull requests found for branch %q", headBranch)}
327}
328
329func preloadPrReviews(httpClient *http.Client, repo ghrepo.Interface, pr *api.PullRequest) error {
330	if !pr.Reviews.PageInfo.HasNextPage {
331		return nil
332	}
333
334	type response struct {
335		Node struct {
336			PullRequest struct {
337				Reviews api.PullRequestReviews `graphql:"reviews(first: 100, after: $endCursor)"`
338			} `graphql:"...on PullRequest"`
339		} `graphql:"node(id: $id)"`
340	}
341
342	variables := map[string]interface{}{
343		"id":        githubv4.ID(pr.ID),
344		"endCursor": githubv4.String(pr.Reviews.PageInfo.EndCursor),
345	}
346
347	gql := graphql.NewClient(ghinstance.GraphQLEndpoint(repo.RepoHost()), httpClient)
348
349	for {
350		var query response
351		err := gql.QueryNamed(context.Background(), "ReviewsForPullRequest", &query, variables)
352		if err != nil {
353			return err
354		}
355
356		pr.Reviews.Nodes = append(pr.Reviews.Nodes, query.Node.PullRequest.Reviews.Nodes...)
357		pr.Reviews.TotalCount = len(pr.Reviews.Nodes)
358
359		if !query.Node.PullRequest.Reviews.PageInfo.HasNextPage {
360			break
361		}
362		variables["endCursor"] = githubv4.String(query.Node.PullRequest.Reviews.PageInfo.EndCursor)
363	}
364
365	pr.Reviews.PageInfo.HasNextPage = false
366	return nil
367}
368
369func preloadPrComments(client *http.Client, repo ghrepo.Interface, pr *api.PullRequest) error {
370	if !pr.Comments.PageInfo.HasNextPage {
371		return nil
372	}
373
374	type response struct {
375		Node struct {
376			PullRequest struct {
377				Comments api.Comments `graphql:"comments(first: 100, after: $endCursor)"`
378			} `graphql:"...on PullRequest"`
379		} `graphql:"node(id: $id)"`
380	}
381
382	variables := map[string]interface{}{
383		"id":        githubv4.ID(pr.ID),
384		"endCursor": githubv4.String(pr.Comments.PageInfo.EndCursor),
385	}
386
387	gql := graphql.NewClient(ghinstance.GraphQLEndpoint(repo.RepoHost()), client)
388
389	for {
390		var query response
391		err := gql.QueryNamed(context.Background(), "CommentsForPullRequest", &query, variables)
392		if err != nil {
393			return err
394		}
395
396		pr.Comments.Nodes = append(pr.Comments.Nodes, query.Node.PullRequest.Comments.Nodes...)
397		pr.Comments.TotalCount = len(pr.Comments.Nodes)
398
399		if !query.Node.PullRequest.Comments.PageInfo.HasNextPage {
400			break
401		}
402		variables["endCursor"] = githubv4.String(query.Node.PullRequest.Comments.PageInfo.EndCursor)
403	}
404
405	pr.Comments.PageInfo.HasNextPage = false
406	return nil
407}
408
409func preloadPrChecks(client *http.Client, repo ghrepo.Interface, pr *api.PullRequest) error {
410	if len(pr.StatusCheckRollup.Nodes) == 0 {
411		return nil
412	}
413	statusCheckRollup := &pr.StatusCheckRollup.Nodes[0].Commit.StatusCheckRollup.Contexts
414	if !statusCheckRollup.PageInfo.HasNextPage {
415		return nil
416	}
417
418	endCursor := statusCheckRollup.PageInfo.EndCursor
419
420	type response struct {
421		Node *api.PullRequest
422	}
423
424	query := fmt.Sprintf(`
425	query PullRequestStatusChecks($id: ID!, $endCursor: String!) {
426		node(id: $id) {
427			...on PullRequest {
428				%s
429			}
430		}
431	}`, api.StatusCheckRollupGraphQL("$endCursor"))
432
433	variables := map[string]interface{}{
434		"id": pr.ID,
435	}
436
437	apiClient := api.NewClientFromHTTP(client)
438	for {
439		variables["endCursor"] = endCursor
440		var resp response
441		err := apiClient.GraphQL(repo.RepoHost(), query, variables, &resp)
442		if err != nil {
443			return err
444		}
445
446		result := resp.Node.StatusCheckRollup.Nodes[0].Commit.StatusCheckRollup.Contexts
447		statusCheckRollup.Nodes = append(
448			statusCheckRollup.Nodes,
449			result.Nodes...,
450		)
451
452		if !result.PageInfo.HasNextPage {
453			break
454		}
455		endCursor = result.PageInfo.EndCursor
456	}
457
458	statusCheckRollup.PageInfo.HasNextPage = false
459	return nil
460}
461
// NotFoundError wraps the error reported when no matching pull request was found.
type NotFoundError struct {
	error
}
465
// Unwrap returns the wrapped error so that errors.Is and errors.As can inspect it.
func (err *NotFoundError) Unwrap() error {
	return err.error
}
469
470func NewMockFinder(selector string, pr *api.PullRequest, repo ghrepo.Interface) *mockFinder {
471	var err error
472	if pr == nil {
473		err = &NotFoundError{errors.New("no pull requests found")}
474	}
475	return &mockFinder{
476		expectSelector: selector,
477		pr:             pr,
478		repo:           repo,
479		err:            err,
480	}
481}
482
// mockFinder is a PRFinder for tests that returns a preset result from Find.
//
// expectSelector and expectFields are what Find checks its options against,
// pr, repo and err are what it returns, and called records that Find has
// already succeeded once.
type mockFinder struct {
	called         bool
	expectSelector string
	expectFields   []string
	pr             *api.PullRequest
	repo           ghrepo.Interface
	err            error
}
491
492func (m *mockFinder) Find(opts FindOptions) (*api.PullRequest, ghrepo.Interface, error) {
493	if m.err != nil {
494		return nil, nil, m.err
495	}
496	if m.expectSelector != opts.Selector {
497		return nil, nil, fmt.Errorf("mockFinder: expected selector %q, got %q", m.expectSelector, opts.Selector)
498	}
499	if len(m.expectFields) > 0 && !isEqualSet(m.expectFields, opts.Fields) {
500		return nil, nil, fmt.Errorf("mockFinder: expected fields %v, got %v", m.expectFields, opts.Fields)
501	}
502	if m.called {
503		return nil, nil, errors.New("mockFinder used more than once")
504	}
505	m.called = true
506
507	if m.pr.HeadRepositoryOwner.Login == "" {
508		// pose as same-repo PR by default
509		m.pr.HeadRepositoryOwner.Login = m.repo.RepoOwner()
510	}
511
512	return m.pr, m.repo, nil
513}
514
// ExpectFields makes Find fail unless it is called with exactly these fields,
// in any order. An empty list leaves the fields unchecked.
func (m *mockFinder) ExpectFields(fields []string) {
	m.expectFields = fields
}
518
// isEqualSet reports whether a and b contain the same strings, each the same
// number of times, regardless of order. Neither argument is modified.
func isEqualSet(a, b []string) bool {
	if len(a) != len(b) {
		return false
	}

	// Sort private copies so the callers' slices keep their order.
	left := append([]string(nil), a...)
	right := append([]string(nil), b...)
	sort.Strings(left)
	sort.Strings(right)

	for i, s := range left {
		if right[i] != s {
			return false
		}
	}
	return true
}
538