1package checkout
2
3import (
4	"fmt"
5	"net/http"
6	"os"
7	"os/exec"
8	"strings"
9
10	"github.com/cli/cli/v2/api"
11	"github.com/cli/cli/v2/context"
12	"github.com/cli/cli/v2/git"
13	"github.com/cli/cli/v2/internal/config"
14	"github.com/cli/cli/v2/internal/ghrepo"
15	"github.com/cli/cli/v2/internal/run"
16	"github.com/cli/cli/v2/pkg/cmd/pr/shared"
17	"github.com/cli/cli/v2/pkg/cmdutil"
18	"github.com/cli/cli/v2/pkg/iostreams"
19	"github.com/cli/safeexec"
20	"github.com/spf13/cobra"
21)
22
// CheckoutOptions holds the injected dependencies and the parsed command-line
// input used by checkoutRun.
type CheckoutOptions struct {
	// Dependencies supplied by the command factory in NewCmdCheckout.
	// Branch returns the name of the currently checked-out git branch; it is
	// consulted when no local remote points at the PR's head repository.
	HttpClient func() (*http.Client, error)
	Config     func() (config.Config, error)
	IO         *iostreams.IOStreams
	Remotes    func() (context.Remotes, error)
	Branch     func() (string, error)

	// Finder resolves SelectorArg to a pull request and its base repository.
	Finder shared.PRFinder

	// SelectorArg is the positional argument: a PR number, URL, or branch.
	// The remaining fields hold the --recurse-submodules, --force, --detach
	// and --branch flag values respectively.
	SelectorArg       string
	RecurseSubmodules bool
	Force             bool
	Detach            bool
	BranchName        string
}
38
39func NewCmdCheckout(f *cmdutil.Factory, runF func(*CheckoutOptions) error) *cobra.Command {
40	opts := &CheckoutOptions{
41		IO:         f.IOStreams,
42		HttpClient: f.HttpClient,
43		Config:     f.Config,
44		Remotes:    f.Remotes,
45		Branch:     f.Branch,
46	}
47
48	cmd := &cobra.Command{
49		Use:   "checkout {<number> | <url> | <branch>}",
50		Short: "Check out a pull request in git",
51		Args:  cmdutil.ExactArgs(1, "argument required"),
52		RunE: func(cmd *cobra.Command, args []string) error {
53			opts.Finder = shared.NewFinder(f)
54
55			if len(args) > 0 {
56				opts.SelectorArg = args[0]
57			}
58
59			if runF != nil {
60				return runF(opts)
61			}
62			return checkoutRun(opts)
63		},
64	}
65
66	cmd.Flags().BoolVarP(&opts.RecurseSubmodules, "recurse-submodules", "", false, "Update all submodules after checkout")
67	cmd.Flags().BoolVarP(&opts.Force, "force", "f", false, "Reset the existing local branch to the latest state of the pull request")
68	cmd.Flags().BoolVarP(&opts.Detach, "detach", "", false, "Checkout PR with a detached HEAD")
69	cmd.Flags().StringVarP(&opts.BranchName, "branch", "b", "", "Local branch name to use (default: the name of the head branch)")
70
71	return cmd
72}
73
74func checkoutRun(opts *CheckoutOptions) error {
75	findOptions := shared.FindOptions{
76		Selector: opts.SelectorArg,
77		Fields:   []string{"number", "headRefName", "headRepository", "headRepositoryOwner", "isCrossRepository", "maintainerCanModify"},
78	}
79	pr, baseRepo, err := opts.Finder.Find(findOptions)
80	if err != nil {
81		return err
82	}
83
84	cfg, err := opts.Config()
85	if err != nil {
86		return err
87	}
88	protocol, _ := cfg.Get(baseRepo.RepoHost(), "git_protocol")
89
90	remotes, err := opts.Remotes()
91	if err != nil {
92		return err
93	}
94	baseRemote, _ := remotes.FindByRepo(baseRepo.RepoOwner(), baseRepo.RepoName())
95	baseURLOrName := ghrepo.FormatRemoteURL(baseRepo, protocol)
96	if baseRemote != nil {
97		baseURLOrName = baseRemote.Name
98	}
99
100	headRemote := baseRemote
101	if pr.HeadRepository == nil {
102		headRemote = nil
103	} else if pr.IsCrossRepository {
104		headRemote, _ = remotes.FindByRepo(pr.HeadRepositoryOwner.Login, pr.HeadRepository.Name)
105	}
106
107	if strings.HasPrefix(pr.HeadRefName, "-") {
108		return fmt.Errorf("invalid branch name: %q", pr.HeadRefName)
109	}
110
111	var cmdQueue [][]string
112
113	if headRemote != nil {
114		cmdQueue = append(cmdQueue, cmdsForExistingRemote(headRemote, pr, opts)...)
115	} else {
116		httpClient, err := opts.HttpClient()
117		if err != nil {
118			return err
119		}
120		apiClient := api.NewClientFromHTTP(httpClient)
121
122		defaultBranch, err := api.RepoDefaultBranch(apiClient, baseRepo)
123		if err != nil {
124			return err
125		}
126		cmdQueue = append(cmdQueue, cmdsForMissingRemote(pr, baseURLOrName, baseRepo.RepoHost(), defaultBranch, protocol, opts)...)
127	}
128
129	if opts.RecurseSubmodules {
130		cmdQueue = append(cmdQueue, []string{"git", "submodule", "sync", "--recursive"})
131		cmdQueue = append(cmdQueue, []string{"git", "submodule", "update", "--init", "--recursive"})
132	}
133
134	err = executeCmds(cmdQueue)
135	if err != nil {
136		return err
137	}
138
139	return nil
140}
141
142func cmdsForExistingRemote(remote *context.Remote, pr *api.PullRequest, opts *CheckoutOptions) [][]string {
143	var cmds [][]string
144	remoteBranch := fmt.Sprintf("%s/%s", remote.Name, pr.HeadRefName)
145
146	refSpec := fmt.Sprintf("+refs/heads/%s", pr.HeadRefName)
147	if !opts.Detach {
148		refSpec += fmt.Sprintf(":refs/remotes/%s", remoteBranch)
149	}
150
151	cmds = append(cmds, []string{"git", "fetch", remote.Name, refSpec})
152
153	localBranch := pr.HeadRefName
154	if opts.BranchName != "" {
155		localBranch = opts.BranchName
156	}
157
158	switch {
159	case opts.Detach:
160		cmds = append(cmds, []string{"git", "checkout", "--detach", "FETCH_HEAD"})
161	case localBranchExists(localBranch):
162		cmds = append(cmds, []string{"git", "checkout", localBranch})
163		if opts.Force {
164			cmds = append(cmds, []string{"git", "reset", "--hard", fmt.Sprintf("refs/remotes/%s", remoteBranch)})
165		} else {
166			// TODO: check if non-fast-forward and suggest to use `--force`
167			cmds = append(cmds, []string{"git", "merge", "--ff-only", fmt.Sprintf("refs/remotes/%s", remoteBranch)})
168		}
169	default:
170		cmds = append(cmds, []string{"git", "checkout", "-b", localBranch, "--track", remoteBranch})
171	}
172
173	return cmds
174}
175
176func cmdsForMissingRemote(pr *api.PullRequest, baseURLOrName, repoHost, defaultBranch, protocol string, opts *CheckoutOptions) [][]string {
177	var cmds [][]string
178	ref := fmt.Sprintf("refs/pull/%d/head", pr.Number)
179
180	if opts.Detach {
181		cmds = append(cmds, []string{"git", "fetch", baseURLOrName, ref})
182		cmds = append(cmds, []string{"git", "checkout", "--detach", "FETCH_HEAD"})
183		return cmds
184	}
185
186	localBranch := pr.HeadRefName
187	if opts.BranchName != "" {
188		localBranch = opts.BranchName
189	} else if pr.HeadRefName == defaultBranch {
190		// avoid naming the new branch the same as the default branch
191		localBranch = fmt.Sprintf("%s/%s", pr.HeadRepositoryOwner.Login, localBranch)
192	}
193
194	currentBranch, _ := opts.Branch()
195	if localBranch == currentBranch {
196		// PR head matches currently checked out branch
197		cmds = append(cmds, []string{"git", "fetch", baseURLOrName, ref})
198		if opts.Force {
199			cmds = append(cmds, []string{"git", "reset", "--hard", "FETCH_HEAD"})
200		} else {
201			// TODO: check if non-fast-forward and suggest to use `--force`
202			cmds = append(cmds, []string{"git", "merge", "--ff-only", "FETCH_HEAD"})
203		}
204	} else {
205		if opts.Force {
206			cmds = append(cmds, []string{"git", "fetch", baseURLOrName, fmt.Sprintf("%s:%s", ref, localBranch), "--force"})
207		} else {
208			// TODO: check if non-fast-forward and suggest to use `--force`
209			cmds = append(cmds, []string{"git", "fetch", baseURLOrName, fmt.Sprintf("%s:%s", ref, localBranch)})
210		}
211
212		cmds = append(cmds, []string{"git", "checkout", localBranch})
213	}
214
215	remote := baseURLOrName
216	mergeRef := ref
217	if pr.MaintainerCanModify && pr.HeadRepository != nil {
218		headRepo := ghrepo.NewWithHost(pr.HeadRepositoryOwner.Login, pr.HeadRepository.Name, repoHost)
219		remote = ghrepo.FormatRemoteURL(headRepo, protocol)
220		mergeRef = fmt.Sprintf("refs/heads/%s", pr.HeadRefName)
221	}
222	if missingMergeConfigForBranch(localBranch) {
223		// .remote is needed for `git pull` to work
224		// .pushRemote is needed for `git push` to work, if user has set `remote.pushDefault`.
225		// see https://git-scm.com/docs/git-config#Documentation/git-config.txt-branchltnamegtremote
226		cmds = append(cmds, []string{"git", "config", fmt.Sprintf("branch.%s.remote", localBranch), remote})
227		cmds = append(cmds, []string{"git", "config", fmt.Sprintf("branch.%s.pushRemote", localBranch), remote})
228		cmds = append(cmds, []string{"git", "config", fmt.Sprintf("branch.%s.merge", localBranch), mergeRef})
229	}
230
231	return cmds
232}
233
234func missingMergeConfigForBranch(b string) bool {
235	mc, err := git.Config(fmt.Sprintf("branch.%s.merge", b))
236	return err != nil || mc == ""
237}
238
239func localBranchExists(b string) bool {
240	_, err := git.ShowRefs("refs/heads/" + b)
241	return err == nil
242}
243
244func executeCmds(cmdQueue [][]string) error {
245	for _, args := range cmdQueue {
246		// TODO: reuse the result of this lookup across loop iteration
247		exe, err := safeexec.LookPath(args[0])
248		if err != nil {
249			return err
250		}
251		cmd := exec.Command(exe, args[1:]...)
252		cmd.Stdout = os.Stdout
253		cmd.Stderr = os.Stderr
254		if err := run.PrepareCmd(cmd).Run(); err != nil {
255			return err
256		}
257	}
258	return nil
259}
260