1package checkout 2 3import ( 4 "fmt" 5 "net/http" 6 "os" 7 "os/exec" 8 "strings" 9 10 "github.com/cli/cli/v2/api" 11 "github.com/cli/cli/v2/context" 12 "github.com/cli/cli/v2/git" 13 "github.com/cli/cli/v2/internal/config" 14 "github.com/cli/cli/v2/internal/ghrepo" 15 "github.com/cli/cli/v2/internal/run" 16 "github.com/cli/cli/v2/pkg/cmd/pr/shared" 17 "github.com/cli/cli/v2/pkg/cmdutil" 18 "github.com/cli/cli/v2/pkg/iostreams" 19 "github.com/cli/safeexec" 20 "github.com/spf13/cobra" 21) 22 23type CheckoutOptions struct { 24 HttpClient func() (*http.Client, error) 25 Config func() (config.Config, error) 26 IO *iostreams.IOStreams 27 Remotes func() (context.Remotes, error) 28 Branch func() (string, error) 29 30 Finder shared.PRFinder 31 32 SelectorArg string 33 RecurseSubmodules bool 34 Force bool 35 Detach bool 36 BranchName string 37} 38 39func NewCmdCheckout(f *cmdutil.Factory, runF func(*CheckoutOptions) error) *cobra.Command { 40 opts := &CheckoutOptions{ 41 IO: f.IOStreams, 42 HttpClient: f.HttpClient, 43 Config: f.Config, 44 Remotes: f.Remotes, 45 Branch: f.Branch, 46 } 47 48 cmd := &cobra.Command{ 49 Use: "checkout {<number> | <url> | <branch>}", 50 Short: "Check out a pull request in git", 51 Args: cmdutil.ExactArgs(1, "argument required"), 52 RunE: func(cmd *cobra.Command, args []string) error { 53 opts.Finder = shared.NewFinder(f) 54 55 if len(args) > 0 { 56 opts.SelectorArg = args[0] 57 } 58 59 if runF != nil { 60 return runF(opts) 61 } 62 return checkoutRun(opts) 63 }, 64 } 65 66 cmd.Flags().BoolVarP(&opts.RecurseSubmodules, "recurse-submodules", "", false, "Update all submodules after checkout") 67 cmd.Flags().BoolVarP(&opts.Force, "force", "f", false, "Reset the existing local branch to the latest state of the pull request") 68 cmd.Flags().BoolVarP(&opts.Detach, "detach", "", false, "Checkout PR with a detached HEAD") 69 cmd.Flags().StringVarP(&opts.BranchName, "branch", "b", "", "Local branch name to use (default: the name of the head branch)") 70 71 return cmd 72} 73 74func checkoutRun(opts *CheckoutOptions) error { 75 findOptions := shared.FindOptions{ 76 Selector: opts.SelectorArg, 77 Fields: []string{"number", "headRefName", "headRepository", "headRepositoryOwner", "isCrossRepository", "maintainerCanModify"}, 78 } 79 pr, baseRepo, err := opts.Finder.Find(findOptions) 80 if err != nil { 81 return err 82 } 83 84 cfg, err := opts.Config() 85 if err != nil { 86 return err 87 } 88 protocol, _ := cfg.Get(baseRepo.RepoHost(), "git_protocol") 89 90 remotes, err := opts.Remotes() 91 if err != nil { 92 return err 93 } 94 baseRemote, _ := remotes.FindByRepo(baseRepo.RepoOwner(), baseRepo.RepoName()) 95 baseURLOrName := ghrepo.FormatRemoteURL(baseRepo, protocol) 96 if baseRemote != nil { 97 baseURLOrName = baseRemote.Name 98 } 99 100 headRemote := baseRemote 101 if pr.HeadRepository == nil { 102 headRemote = nil 103 } else if pr.IsCrossRepository { 104 headRemote, _ = remotes.FindByRepo(pr.HeadRepositoryOwner.Login, pr.HeadRepository.Name) 105 } 106 107 if strings.HasPrefix(pr.HeadRefName, "-") { 108 return fmt.Errorf("invalid branch name: %q", pr.HeadRefName) 109 } 110 111 var cmdQueue [][]string 112 113 if headRemote != nil { 114 cmdQueue = append(cmdQueue, cmdsForExistingRemote(headRemote, pr, opts)...) 115 } else { 116 httpClient, err := opts.HttpClient() 117 if err != nil { 118 return err 119 } 120 apiClient := api.NewClientFromHTTP(httpClient) 121 122 defaultBranch, err := api.RepoDefaultBranch(apiClient, baseRepo) 123 if err != nil { 124 return err 125 } 126 cmdQueue = append(cmdQueue, cmdsForMissingRemote(pr, baseURLOrName, baseRepo.RepoHost(), defaultBranch, protocol, opts)...) 127 } 128 129 if opts.RecurseSubmodules { 130 cmdQueue = append(cmdQueue, []string{"git", "submodule", "sync", "--recursive"}) 131 cmdQueue = append(cmdQueue, []string{"git", "submodule", "update", "--init", "--recursive"}) 132 } 133 134 err = executeCmds(cmdQueue) 135 if err != nil { 136 return err 137 } 138 139 return nil 140} 141 142func cmdsForExistingRemote(remote *context.Remote, pr *api.PullRequest, opts *CheckoutOptions) [][]string { 143 var cmds [][]string 144 remoteBranch := fmt.Sprintf("%s/%s", remote.Name, pr.HeadRefName) 145 146 refSpec := fmt.Sprintf("+refs/heads/%s", pr.HeadRefName) 147 if !opts.Detach { 148 refSpec += fmt.Sprintf(":refs/remotes/%s", remoteBranch) 149 } 150 151 cmds = append(cmds, []string{"git", "fetch", remote.Name, refSpec}) 152 153 localBranch := pr.HeadRefName 154 if opts.BranchName != "" { 155 localBranch = opts.BranchName 156 } 157 158 switch { 159 case opts.Detach: 160 cmds = append(cmds, []string{"git", "checkout", "--detach", "FETCH_HEAD"}) 161 case localBranchExists(localBranch): 162 cmds = append(cmds, []string{"git", "checkout", localBranch}) 163 if opts.Force { 164 cmds = append(cmds, []string{"git", "reset", "--hard", fmt.Sprintf("refs/remotes/%s", remoteBranch)}) 165 } else { 166 // TODO: check if non-fast-forward and suggest to use `--force` 167 cmds = append(cmds, []string{"git", "merge", "--ff-only", fmt.Sprintf("refs/remotes/%s", remoteBranch)}) 168 } 169 default: 170 cmds = append(cmds, []string{"git", "checkout", "-b", localBranch, "--track", remoteBranch}) 171 } 172 173 return cmds 174} 175 176func cmdsForMissingRemote(pr *api.PullRequest, baseURLOrName, repoHost, defaultBranch, protocol string, opts *CheckoutOptions) [][]string { 177 var cmds [][]string 178 ref := fmt.Sprintf("refs/pull/%d/head", pr.Number) 179 180 if opts.Detach { 181 cmds = append(cmds, []string{"git", "fetch", baseURLOrName, ref}) 182 cmds = append(cmds, []string{"git", "checkout", "--detach", "FETCH_HEAD"}) 183 return cmds 184 } 185 186 localBranch := pr.HeadRefName 187 if opts.BranchName != "" { 188 localBranch = opts.BranchName 189 } else if pr.HeadRefName == defaultBranch { 190 // avoid naming the new branch the same as the default branch 191 localBranch = fmt.Sprintf("%s/%s", pr.HeadRepositoryOwner.Login, localBranch) 192 } 193 194 currentBranch, _ := opts.Branch() 195 if localBranch == currentBranch { 196 // PR head matches currently checked out branch 197 cmds = append(cmds, []string{"git", "fetch", baseURLOrName, ref}) 198 if opts.Force { 199 cmds = append(cmds, []string{"git", "reset", "--hard", "FETCH_HEAD"}) 200 } else { 201 // TODO: check if non-fast-forward and suggest to use `--force` 202 cmds = append(cmds, []string{"git", "merge", "--ff-only", "FETCH_HEAD"}) 203 } 204 } else { 205 if opts.Force { 206 cmds = append(cmds, []string{"git", "fetch", baseURLOrName, fmt.Sprintf("%s:%s", ref, localBranch), "--force"}) 207 } else { 208 // TODO: check if non-fast-forward and suggest to use `--force` 209 cmds = append(cmds, []string{"git", "fetch", baseURLOrName, fmt.Sprintf("%s:%s", ref, localBranch)}) 210 } 211 212 cmds = append(cmds, []string{"git", "checkout", localBranch}) 213 } 214 215 remote := baseURLOrName 216 mergeRef := ref 217 if pr.MaintainerCanModify && pr.HeadRepository != nil { 218 headRepo := ghrepo.NewWithHost(pr.HeadRepositoryOwner.Login, pr.HeadRepository.Name, repoHost) 219 remote = ghrepo.FormatRemoteURL(headRepo, protocol) 220 mergeRef = fmt.Sprintf("refs/heads/%s", pr.HeadRefName) 221 } 222 if missingMergeConfigForBranch(localBranch) { 223 // .remote is needed for `git pull` to work 224 // .pushRemote is needed for `git push` to work, if user has set `remote.pushDefault`. 225 // see https://git-scm.com/docs/git-config#Documentation/git-config.txt-branchltnamegtremote 226 cmds = append(cmds, []string{"git", "config", fmt.Sprintf("branch.%s.remote", localBranch), remote}) 227 cmds = append(cmds, []string{"git", "config", fmt.Sprintf("branch.%s.pushRemote", localBranch), remote}) 228 cmds = append(cmds, []string{"git", "config", fmt.Sprintf("branch.%s.merge", localBranch), mergeRef}) 229 } 230 231 return cmds 232} 233 234func missingMergeConfigForBranch(b string) bool { 235 mc, err := git.Config(fmt.Sprintf("branch.%s.merge", b)) 236 return err != nil || mc == "" 237} 238 239func localBranchExists(b string) bool { 240 _, err := git.ShowRefs("refs/heads/" + b) 241 return err == nil 242} 243 244func executeCmds(cmdQueue [][]string) error { 245 for _, args := range cmdQueue { 246 // TODO: reuse the result of this lookup across loop iteration 247 exe, err := safeexec.LookPath(args[0]) 248 if err != nil { 249 return err 250 } 251 cmd := exec.Command(exe, args[1:]...) 252 cmd.Stdout = os.Stdout 253 cmd.Stderr = os.Stderr 254 if err := run.PrepareCmd(cmd).Run(); err != nil { 255 return err 256 } 257 } 258 return nil 259} 260