1# This Source Code Form is subject to the terms of the Mozilla Public
2# License, v. 2.0. If a copy of the MPL was not distributed with this
3# file, You can obtain one at http://mozilla.org/MPL/2.0/.
4
5import os
6import re
7import subprocess
8import sys
9import tempfile
10from six.moves import input
11from six.moves.urllib import parse as urlparse
12
13from wptrunner.update.tree import get_unique_name
14from wptrunner.update.base import Step, StepRunner, exit_clean, exit_unclean
15
16from .tree import Commit, GitTree, Patch
17from .github import GitHub
18
19
def rewrite_patch(patch, strip_dir):
    """Take a Patch and convert it to a different repository by stripping a
    prefix from the file paths named in the diff headers. The commit message
    is rewritten with rewrite_message().

    :param patch: the Patch to convert
    :param strip_dir: the path prefix to remove
    :returns: a new Patch with the same author and email, the rewritten
              message and the rewritten diff
    """

    if not strip_dir.startswith("/"):
        strip_dir = "/%s" % strip_dir

    # Prefixes of the diff lines that name files, and whether the text removed
    # from such a line includes the leading slash of strip_dir.
    line_starts = [
        ("diff ", True),
        ("+++ ", True),
        ("--- ", True),
        ("rename from ", False),
        ("rename to ", False),
    ]

    new_diff = []
    for line in patch.diff.split("\n"):
        for start, leading_slash in line_starts:
            strip = strip_dir if leading_slash else strip_dir[1:]
            if line.startswith(start):
                # Keep the rewritten line as text. Encoding it to bytes here
                # would mix bytes and str in new_diff and make the join below
                # raise TypeError on Python 3.
                new_diff.append(line.replace(strip, ""))
                break
        else:
            # Not a header line that names a file; copy it through unchanged.
            new_diff.append(line)

    new_diff = "\n".join(new_diff)

    # NOTE(review): this compares the new diff text against the Patch object
    # rather than patch.diff. It is presumably meant as a sanity check that
    # something was rewritten, but MovePatches expects empty diffs to pass
    # through this function, so confirm the intent before tightening it.
    assert new_diff != patch

    return Patch(patch.author, patch.email, rewrite_message(patch), new_diff)
54
55
def rewrite_message(patch):
    """Build the commit message used for the upstream copy of a patch.

    If the original message has a bug number, the result is the summary and
    the body followed by a blank line and a link to the bug. Otherwise it is
    the full summary followed by the body. Both forms carry a "[ci skip]"
    marker.

    :param patch: the Patch whose message is rewritten
    :returns: the new message text
    """
    message = patch.message

    if message.bug is None:
        # No bug to link to: keep the full summary and tag the body.
        tagged_body = "%s\n[ci skip]\n" % message.body
        return "\n".join((message.full_summary, tagged_body))

    bug_link = (
        "Upstreamed from https://bugzilla.mozilla.org/show_bug.cgi?id=%s [ci skip]"
        % message.bug  # noqa E501
    )
    return "\n".join((message.summary, message.body, "", bug_link))
71
72
class SyncToUpstream(Step):
    """Step that syncs changes from the local tree to upstream.

    Does nothing unless the "upstream" option is set. Exits cleanly when the
    local tree is not a GitTree or the requests module cannot be imported.
    """

    def create(self, state):
        if not state.kwargs["upstream"]:
            return

        if not isinstance(state.local_tree, GitTree):
            message = "Cannot sync with upstream from a non-Git checkout."
            self.logger.error(message)
            return exit_clean

        # Only checks that requests is importable; it is not used directly here.
        try:
            import requests  # noqa F401
        except ImportError:
            message = "Upstream sync requires the requests module to be installed"
            self.logger.error(message)
            return exit_clean

        if not state.sync_tree:
            # No sync tree yet: create its directory and wrap it in a GitTree.
            sync_path = state.sync["path"]
            os.makedirs(sync_path)
            state.sync_tree = GitTree(root=sync_path)

        kwargs = state.kwargs
        carried_over = ["local_tree", "sync_tree", "tests_path", "metadata_path", "sync"]
        with state.push(carried_over):
            state.token = kwargs["token"]
            SyncToUpstreamRunner(self.logger, state).run()
103
104
class GetLastSyncData(Step):
    """Find the gecko commit at which we last performed a sync with upstream and the upstream
    commit that was synced.

    The data is read from the "mozilla-sync" file in the metadata directory,
    which holds one "key: value" pair per line. The "local" and "upstream"
    keys are required.
    """

    provides = ["sync_data_path", "last_sync_commit", "old_upstream_rev"]

    def create(self, state):
        self.logger.info("Looking for last sync commit")
        state.sync_data_path = os.path.join(state.metadata_path, "mozilla-sync")
        items = {}
        with open(state.sync_data_path) as f:
            for line in f:
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing newline added by
                    # hand) instead of failing to unpack key and value.
                    continue
                # Split on the first colon only, so values may contain colons.
                key, value = [item.strip() for item in line.split(":", 1)]
                items[key] = value

        # The "local" entry is converted with rev_from_hg() before being
        # wrapped in a Commit for the local tree.
        state.last_sync_commit = Commit(
            state.local_tree, state.local_tree.rev_from_hg(items["local"])
        )
        state.old_upstream_rev = items["upstream"]

        if not state.local_tree.contains_commit(state.last_sync_commit):
            self.logger.error(
                "Could not find last sync commit %s" % state.last_sync_commit.sha1
            )
            return exit_clean

        self.logger.info(
            "Last sync to web-platform-tests happened in %s"
            % state.last_sync_commit.sha1
        )
135
136
class CheckoutBranch(Step):
    """Update the sync tree from the remote, then check out a new, uniquely
    named branch at the upstream revision recorded by the last sync."""

    provides = ["branch"]

    def create(self, state):
        remote_url = state.sync["remote_url"]
        self.logger.info("Updating sync tree from %s" % remote_url)

        sync_tree = state.sync_tree
        branch_prefix = "outbound_update_%s" % state.old_upstream_rev
        state.branch = sync_tree.unique_branch_name(branch_prefix)

        sync_tree.update(remote_url, state.sync["branch"], state.branch)
        sync_tree.checkout(state.old_upstream_rev, state.branch, force=True)
152
153
class GetBaseCommit(Step):
    """Look up the current commit of the remote branch that we are syncing with."""

    provides = ["base_commit"]

    def create(self, state):
        sync = state.sync
        base_commit = state.sync_tree.get_remote_sha1(
            sync["remote_url"], sync["branch"]
        )
        state.base_commit = base_commit
        self.logger.debug("New base commit is %s" % state.base_commit.sha1)
164
165
class LoadCommits(Step):
    """Get a list of commits in the gecko tree that need to be upstreamed.

    Commits whose message matches the "Update web-platform-tests to revision"
    pattern are dropped from the list. The step returns exit_unclean when a
    remaining commit that is not a backout has no bug number.
    """

    provides = ["source_commits", "has_backouts"]

    def create(self, state):
        state.source_commits = state.local_tree.log(
            state.last_sync_commit, state.tests_path
        )

        # Raw string so that "\d" reaches the regex engine as written instead
        # of being read as an (invalid) string escape sequence.
        update_regexp = re.compile(
            r"Bug \d+ - Update web-platform-tests to revision [0-9a-f]{40}"
        )

        state.has_backouts = False

        # Iterate over a copy because matching commits are removed from
        # state.source_commits inside the loop.
        for i, commit in enumerate(state.source_commits[:]):
            if update_regexp.match(commit.message.text):
                # This is a previous update commit so ignore it
                state.source_commits.remove(commit)
                continue

            elif commit.message.backouts:
                # TODO: Add support for collapsing backouts
                state.has_backouts = True

            elif not commit.message.bug:
                # i counts positions in the unfiltered list, reported 1-based.
                self.logger.error(
                    "Commit %i (%s) doesn't have an associated bug number."
                    % (i + 1, commit.sha1)
                )
                return exit_unclean

        self.logger.debug("Source commits: %s" % state.source_commits)
200
201
class SelectCommits(Step):
    """Interactively choose which of the source commits get upstreamed.

    The commits are listed with their index, the user names the indices to
    drop, and the prompt repeats until the remaining selection is confirmed
    with "y".
    """

    def create(self, state):
        while True:
            candidates = list(state.source_commits)
            for index, commit in enumerate(candidates):
                print("{}:\t{}".format(index, commit.message.summary))

            answer = input(
                "Provide a space-separated list of any commits numbers "
                "to remove from the list to upstream:\n"
            ).strip()

            # Collect the valid indices; anything that is not an integer or is
            # out of range is ignored.
            dropped = set()
            for token in answer.split(" "):
                try:
                    index = int(token)
                except ValueError:
                    continue
                if 0 <= index < len(candidates):
                    dropped.add(index)

            kept = [
                (index, commit)
                for index, commit in enumerate(candidates)
                if index not in dropped
            ]
            # TODO: consider printed removed commits
            print("Selected the following commits to keep:")
            for index, commit in kept:
                print("{}:\t{}".format(index, commit.message.summary))

            reply = input("Keep the above commits? y/n\n").strip().lower()
            if reply == "y":
                state.source_commits = [commit for _, commit in kept]
                return
237
238
class MovePatches(Step):
    """Convert gecko commits into patches against upstream and commit these to the sync tree.

    Progress is tracked in state.commits_loaded so that a later run resumes
    after the last commit that was imported. When a patch fails to apply, its
    diff is saved to a temporary file, the (filename, patch) pair is stored in
    state.patch and the process exits with status 1. On the next run the diff
    is re-read from that file, if it still exists, in place of a fresh export.
    """

    provides = ["commits_loaded"]

    def create(self, state):
        if not hasattr(state, "commits_loaded"):
            state.commits_loaded = 0

        # Path of the tests directory relative to the local tree root; this is
        # the prefix removed from the paths in each patch.
        strip_path = os.path.relpath(state.tests_path, state.local_tree.root)
        self.logger.debug("Stripping patch %s" % strip_path)

        if not hasattr(state, "patch"):
            state.patch = None

        for commit in state.source_commits[state.commits_loaded :]:
            i = state.commits_loaded + 1
            self.logger.info("Moving commit %i: %s" % (i, commit.message.full_summary))
            stripped_patch = None
            if state.patch:
                # A previous run failed on this commit; pick up the diff file
                # that was saved then, if it still exists.
                filename, stripped_patch = state.patch
                if not os.path.exists(filename):
                    stripped_patch = None
                else:
                    with open(filename) as f:
                        stripped_patch.diff = f.read()
            state.patch = None
            if not stripped_patch:
                patch = commit.export_patch(state.tests_path)
                stripped_patch = rewrite_patch(patch, strip_path)
            if not stripped_patch.diff:
                self.logger.info("Skipping empty patch")
                state.commits_loaded = i
                continue
            try:
                state.sync_tree.import_patch(stripped_patch)
            except Exception:
                # Open in text mode: the diff is text (it is read back with a
                # text-mode open() above), and the default binary mode of
                # NamedTemporaryFile rejects str on Python 3.
                with tempfile.NamedTemporaryFile(
                    mode="w", delete=False, suffix=".diff"
                ) as f:
                    f.write(stripped_patch.diff)
                    print(
                        "Patch failed to apply. Diff saved in {}\n"
                        "Fix this file so it applies and run with --continue".format(
                            f.name
                        )
                    )
                    state.patch = (f.name, stripped_patch)
                    print(state.patch)
                sys.exit(1)
            state.commits_loaded = i
        input("Check for differences with upstream")
289
290
class RebaseCommits(Step):
    """Rebase the commits on the current branch onto the upstream base commit.

    A rebase that hits merge conflicts fails this step. The conflicts can then
    be resolved locally and the sync restarted with --continue, in which case
    the rebase that is already in progress is continued.
    """

    def create(self, state):
        self.logger.info("Rebasing local commits")

        # Either of these directories under .git is taken to mean that a
        # rebase is already in progress.
        git_dir = os.path.join(state.sync_tree.root, ".git")
        continue_rebase = any(
            os.path.exists(os.path.join(git_dir, marker))
            for marker in ("rebase-merge", "rebase-apply")
        )

        try:
            state.sync_tree.rebase(state.base_commit, continue_rebase=continue_rebase)
        except subprocess.CalledProcessError:
            script = sys.argv[0]
            self.logger.info(
                "Rebase failed, fix merge and run %s again with --continue" % script
            )
            raise
        else:
            self.logger.info("Rebase successful")
317
318
class CheckRebase(Step):
    """Record the commits left after the rebase and stop cleanly if there are none."""

    provides = ["rebased_commits"]

    def create(self, state):
        state.rebased_commits = state.sync_tree.log(state.base_commit)
        if state.rebased_commits:
            return None

        self.logger.info("Nothing to upstream, exiting")
        return exit_clean
329
330
class MergeUpstream(Step):
    """Run the steps that push each local commit as a separate PR and merge it upstream.

    state.merge_index counts the commits already handled, so that a later run
    starts from the first commit that has not been merged yet.
    """

    provides = ["merge_index", "gh_repo"]

    def create(self, state):
        gh = GitHub(state.token)
        if "merge_index" not in state:
            state.merge_index = 0

        # The remote URL path is expected to be "/<org>/<name>", optionally
        # with a ".git" suffix on the name.
        remote_path = urlparse.urlsplit(state.sync["remote_url"]).path
        org, name = remote_path[1:].split("/")
        if name.endswith(".git"):
            name = name[: -len(".git")]
        state.gh_repo = gh.repo(org, name)

        pending = state.rebased_commits[state.merge_index :]
        for commit in pending:
            with state.push(["gh_repo", "sync_tree"]):
                state.commit = commit
                result = PRMergeRunner(self.logger, state).run()
                if result is not None:
                    # The sub-runner stopped early; pass its result on.
                    return result
            state.merge_index = state.merge_index + 1
353
354
class UpdateLastSyncData(Step):
    """Update the gecko commit at which we last performed a sync with upstream.

    Rewrites the sync data file with one "key: value" line for the current
    local revision (converted with rev_to_hg()) and one for the current
    revision of the sync tree.
    """

    provides = []

    def create(self, state):
        self.logger.info("Updating last sync commit")
        data = {
            "local": state.local_tree.rev_to_hg(state.local_tree.rev),
            "upstream": state.sync_tree.rev,
        }
        with open(state.sync_data_path, "w") as f:
            # items() exists on both Python 2 and 3; iteritems() is Python 2
            # only and raises AttributeError on Python 3.
            for key, value in data.items():
                f.write("%s: %s\n" % (key, value))
        # This gets added to the patch later on
370
371
class MergeLocalBranch(Step):
    """Create a uniquely named branch in the sync tree at the commit being upstreamed."""

    provides = ["local_branch"]

    def create(self, state):
        commit = state.commit
        sync_tree = state.sync_tree

        branch_name = sync_tree.unique_branch_name("sync_%s" % commit.sha1)
        sync_tree.create_branch(branch_name, commit)
        state.local_branch = branch_name
383
384
class MergeRemoteBranch(Step):
    """Pick a branch name for the PR that is not yet used on the remote."""

    provides = ["remote_branch"]

    def create(self, state):
        wanted_name = "sync_%s" % state.commit.sha1

        # Gather the names of the existing remote branches, with the
        # "refs/heads/" part of each ref removed.
        prefix_length = len("refs/heads/")
        existing = []
        for _sha1, ref in state.sync_tree.list_remote(state.gh_repo.url):
            if ref.startswith("refs/heads"):
                existing.append(ref[prefix_length:])

        state.remote_branch = get_unique_name(existing, wanted_name)
398
399
class PushUpstream(Step):
    """Push the local branch to the chosen branch name on the remote."""

    def create(self, state):
        self.logger.info("Pushing commit upstream")
        remote_url = state.gh_repo.url
        state.sync_tree.push(remote_url, state.local_branch, state.remote_branch)
406
407
class CreatePR(Step):
    """Open a PR from the remote branch against "master"."""

    provides = ["pr"]

    def create(self, state):
        self.logger.info("Creating a PR")
        message = state.commit.message
        title = message.full_summary
        # Fall back to an empty description when the commit has no body.
        description = message.body or ""
        state.pr = state.gh_repo.create_pr(
            title, state.remote_branch, "master", description
        )
422
423
class PRAddComment(Step):
    """Comment on the PR's issue to say that the code was already reviewed."""

    def create(self, state):
        issue = state.pr.issue
        issue.add_comment("Code reviewed upstream.")
429
430
class MergePR(Step):
    """Merge the PR stored in the state."""

    def create(self, state):
        self.logger.info("Merging PR")
        pull_request = state.pr
        pull_request.merge()
437
438
class PRDeleteBranch(Step):
    """Remove the PR's branch from the remote."""

    def create(self, state):
        self.logger.info("Deleting remote branch")
        # An empty local ref is pushed to the remote branch name.
        empty_ref = ""
        state.sync_tree.push(state.gh_repo.url, empty_ref, state.remote_branch)
445
446
class SyncToUpstreamRunner(StepRunner):
    """Runner for syncing local changes to upstream.

    Lists the Step classes that make up a full upstream sync, from reading
    the last sync data through to recording the new sync point.
    """

    # NOTE(review): presumably StepRunner runs these in list order; it is
    # defined in wptrunner.update.base, so confirm there before reordering.
    steps = [
        GetLastSyncData,
        CheckoutBranch,
        GetBaseCommit,
        LoadCommits,
        SelectCommits,
        MovePatches,
        RebaseCommits,
        CheckRebase,
        MergeUpstream,
        UpdateLastSyncData,
    ]
462
463
class PRMergeRunner(StepRunner):
    """(Sub)Runner for creating and merging a PR.

    MergeUpstream instantiates and runs this once for each commit, with that
    commit stored in state.commit.
    """

    # NOTE(review): presumably StepRunner runs these in list order; later
    # steps read state that earlier ones provide (local_branch, remote_branch,
    # pr), so confirm before reordering.
    steps = [
        MergeLocalBranch,
        MergeRemoteBranch,
        PushUpstream,
        CreatePR,
        PRAddComment,
        MergePR,
        PRDeleteBranch,
    ]
476