1#!/usr/bin/env python3
2
3from __future__ import print_function
4
5import argparse
6import generate_amalgamation
7import hashlib
8import os
9import string
10import subprocess
11import sys
12
13
class UnstagedFiles(Exception):
    """Raised when unstaged or unmerged files prevent continuing a pick."""
16
17
class UnknownHash(Exception):
    """Raised when a manifest value is not a recognized sha1/sha3 hash."""
20
21
class IncorrectType(Exception):
    """Raised when a manifest entry type is not a single character."""
24
25
class bcolors:
    """ANSI escape sequences used to colorize console output."""

    # Foreground colors.
    HEADER = '\033[95m'
    OKBLUE = '\033[94m'
    OKGREEN = '\033[92m'
    WARNING = '\033[93m'
    FAIL = '\033[91m'

    # Resets all attributes; emit after colored text.
    ENDC = '\033[0m'
33
34
def _print_command(cmd):
    """Echo a command to the console before it is executed.

    The command may be given as a single string or as a list of arguments,
    which is joined with spaces for display. It is printed in blue so that
    it stands out from the output of the commands themselves.
    """
    text = ' '.join(cmd) if isinstance(cmd, list) else cmd
    colored = '{}{}{}'.format(bcolors.OKBLUE, text, bcolors.ENDC)
    print(colored)
44
45
class ManifestEntry(object):
    """Represents a single entry in a SQLite manifest.

    An entry is one manifest line split on whitespace: a single-character
    entry type followed by its items. For file ('F') entries this class
    treats items[0] as the file path and items[1] as the file's hash.
    """

    def __init__(self, entry_type, items):
        """Create an entry from its type character and list of items.

        Raises IncorrectType if entry_type is not exactly one character.
        """
        if len(entry_type) != 1:
            raise IncorrectType(entry_type)
        self.entry_type = entry_type
        self.items = items

    @staticmethod
    def _hash_type_of(value):
        """Return 'sha1' or 'sha3' depending on the hex digest's length.

        Raises UnknownHash if value contains non-hex characters or is
        neither 40 (sha1) nor 64 (sha3) characters long.
        """
        if not all(c in string.hexdigits for c in value):
            print(
                '"{}" doesn\'t appear to be a hash.'.format(value),
                file=sys.stderr)
            raise UnknownHash()
        if len(value) == 40:
            return 'sha1'
        if len(value) == 64:
            return 'sha3'
        raise UnknownHash('Incorrect length {} for {}'.format(
            len(value), value))

    def get_hash_type(self):
        """Return the type of hash stored as the last item of this entry.

        Raises UnknownHash if the last item is not a sha1/sha3 hex digest.
        """
        return ManifestEntry._hash_type_of(self.items[-1])

    @staticmethod
    def calc_hash(data, method):
        """Return the string sha1 or sha3 hash digest for the given data.

        method must be 'sha1' or 'sha3' (SHA3-256); any other value raises
        UnknownHash.
        """
        if method == 'sha3':
            h = hashlib.sha3_256()
        elif method == 'sha1':
            h = hashlib.sha1()
        else:
            # An explicit raise (rather than an assert) so that the check
            # still fires when Python runs with assertions disabled (-O).
            raise UnknownHash('Unsupported hash method "{}"'.format(method))
        h.update(data)
        return h.hexdigest()

    @staticmethod
    def calc_file_hash(fname, method):
        """Return the string sha1 or sha3 hash digest for the given file."""
        # Binary mode: the digest must cover the file's exact bytes.
        with open(fname, 'rb') as input_file:
            return ManifestEntry.calc_hash(input_file.read(), method)

    def update_file_hash(self):
        """Calculates a new file hash for this entry.

        Reads the file named by items[0] and stores its digest in items[1],
        using the same hash type as the digest currently stored there.
        """
        # Take the hash type from items[1], the value being replaced, not
        # from the last item: Fossil file cards may carry extra fields
        # after the hash (permissions, previous name) that are not hashes.
        method = ManifestEntry._hash_type_of(self.items[1])
        self.items[1] = ManifestEntry.calc_file_hash(self.items[0], method)

    def __str__(self):
        """Return the entry serialized as a manifest line (no newline)."""
        return '{} {}'.format(self.entry_type, ' '.join(self.items))
96
97
class Manifest(object):
    """In-memory form of a SQLite manifest: a list of its entries."""

    def __init__(self):
        # Entry objects, in the order they were appended.
        self.entries = []

    def find_file_entry(self, fname):
        """Look up the file ('F') entry whose path equals fname.

        Returns the first matching entry, or None when there is none.
        """
        matches = (
            candidate for candidate in self.entries
            if candidate.entry_type == 'F' and candidate.items[0] == fname)
        return next(matches, None)
110
111
class ManifestSerializer(object):
    """Reads and writes SQLite manifests as text."""

    @staticmethod
    def read_stream(input_stream):
        """Parse a manifest from a text stream and return a Manifest.

        Each non-blank line is split on whitespace; the first field becomes
        the entry type and the remaining fields become the entry's items.
        """
        parsed = Manifest()
        for raw_line in input_stream.readlines():
            fields = raw_line.split()
            if fields:
                entry_type, values = fields[0], fields[1:]
                parsed.entries.append(ManifestEntry(entry_type, values))
        return parsed

    @staticmethod
    def read_file(fname):
        """Parse the manifest file at fname and return a Manifest."""
        with open(fname) as manifest_file:
            result = ManifestSerializer.read_stream(manifest_file)
        return result

    @staticmethod
    def write_stream(manifest, output_stream):
        """Write every entry of manifest to output_stream, one per line."""
        for entry in manifest.entries:
            output_stream.write(str(entry) + '\n')

    @staticmethod
    def write_file(manifest, fname):
        """Write manifest to the file at fname, replacing its contents."""
        with open(fname, 'w') as manifest_file:
            ManifestSerializer.write_stream(manifest, manifest_file)
144
145
class Git(object):
    """Queries `git status` for the repository in the current directory."""

    @staticmethod
    def _get_status():
        """Return the lines of `git status --porcelain` as text."""
        raw_output = subprocess.check_output(['git', 'status', '--porcelain'])
        return [raw_line.decode('utf-8') for raw_line in raw_output.splitlines()]

    @staticmethod
    def _paths_with_code(code):
        """Return the paths whose two-character status code equals code."""
        return [
            status_line.split()[1] for status_line in Git._get_status()
            if status_line[0:2] == code
        ]

    @staticmethod
    def get_staged_changes():
        """Return paths modified in the index only (status 'M ')."""
        return Git._paths_with_code('M ')

    @staticmethod
    def get_unstaged_changes():
        """Return paths modified in the work tree only (status ' M')."""
        return Git._paths_with_code(' M')

    @staticmethod
    def get_unmerged_changes():
        """Return paths with status 'UU'."""
        return Git._paths_with_code('UU')
181
182
class CherryPicker(object):
    """Class to cherry pick commits in a SQLite Git repository."""

    # The binary file extensions for files committed to the SQLite repository.
    # This is used as a simple way of detecting files that cannot (simply) be
    # resolved in a merge conflict. This script will automatically ignore
    # all conflicted files with any of these extensions. If, in the future, new
    # binary types are added then a conflict will arise during cherry-pick and
    # the user will need to resolve it.
    binary_extensions = (
        '.data',
        '.db',
        '.ico',
        '.jpg',
        '.png',
    )

    def __init__(self):
        # When true, each Git command is echoed before it is run.
        self._print_cmds = True
        # When true, the amalgamation is regenerated and staged before the
        # cherry-pick is continued. (The spelling of the attribute name is
        # kept as is so that existing references keep working.)
        self._update_amangamation = True

    def _echo(self, cmd):
        """Print cmd (an argument list) if command echoing is enabled."""
        if self._print_cmds:
            _print_command(' '.join(cmd))

    def _take_head_version(self, file_path):
        """Overwrite file_path with its HEAD version and stage the result.

        Failures of the Git commands are deliberately not fatal.
        """
        # Arguments are passed as a list (no shell) so that paths containing
        # spaces or shell metacharacters are handled correctly.
        with open(file_path, 'wb') as output_file:
            subprocess.call(
                ['git', 'show', 'HEAD:{}'.format(file_path)],
                stdout=output_file)
        subprocess.call(['git', 'add', '--', file_path])

    @staticmethod
    def _is_binary_file(file_path):
        """Return True if file_path has one of the binary_extensions."""
        _, file_extension = os.path.splitext(file_path)
        return file_extension in CherryPicker.binary_extensions

    @staticmethod
    def _append_cherry_pick_comments(comments):
        """Placeholder for recording notes about the cherry-pick."""
        # TODO(cmumford): Figure out how to append comments on cherry picks
        pass

    def _cherry_pick_git_commit(self, commit_id):
        """Cherry-pick a given Git commit into the current branch.

        Binary files left in conflict are resolved automatically by keeping
        the branch (HEAD) version. Raises UnstagedFiles (from
        continue_cherry_pick) if other conflicts remain.
        """
        cmd = ['git', 'cherry-pick', '-x', commit_id]
        self._echo(cmd)
        # The exit status is intentionally ignored: a cherry-pick that stops
        # on conflicts exits non-zero, and conflicts are handled below.
        subprocess.call(cmd)
        # The manifest and manifest.uuid contain Fossil hashes. Restore to
        # HEAD version and update only when all conflicts have been resolved.
        self._take_head_version('manifest')
        self._take_head_version('manifest.uuid')
        comments = []
        for unmerged_file in Git.get_unmerged_changes():
            if not CherryPicker._is_binary_file(unmerged_file):
                continue
            print('{} is a binary file, keeping branch version.'.format(
                unmerged_file))
            self._take_head_version(unmerged_file)
            if not comments:
                comments = [
                    'Cherry-pick notes', '=============================='
                ]
            comments.append(
                '{} is binary file (with conflict). Keeping branch version'
                .format(unmerged_file))
        if comments:
            CherryPicker._append_cherry_pick_comments(comments)
        self.continue_cherry_pick()

    @staticmethod
    def _is_git_commit_id(commit_id):
        """Return True if commit_id has the length of a full Git commit id.

        NOTE(review): only the length (40) is checked, so any other
        40-character id is treated as a Git commit id as well.
        """
        return len(commit_id) == 40

    def _find_git_commit_id(self, fossil_commit_id):
        """Return the Git commit id whose log message mentions the Fossil id.

        The first commit listed by `git log --grep` is returned. Raises
        LookupError if no commit matches.
        """
        cmd = [
            'git', '--no-pager', 'log', '--color=never', '--all',
            '--pretty=format:%H', '--grep={}'.format(fossil_commit_id),
            'origin/master'
        ]
        self._echo(cmd)
        matches = subprocess.check_output(cmd).splitlines()
        if not matches:
            # An explicit raise (rather than an assert) so that a missing
            # commit is still reported when assertions are disabled (-O).
            raise LookupError(
                'No Git commit found for Fossil commit "{}"'.format(
                    fossil_commit_id))
        return matches[0].decode('utf-8')

    def cherry_pick(self, commit_id):
        """Cherry-pick a given commit into the current branch.

        Can cherry-pick a given Git or a Fossil commit.
        """
        if not CherryPicker._is_git_commit_id(commit_id):
            commit_id = self._find_git_commit_id(commit_id)
        self._cherry_pick_git_commit(commit_id)

    def _generate_amalgamation(self):
        """Regenerate the amalgamation for each build configuration."""
        for config_name in ['chromium', 'dev']:
            generate_amalgamation.make_aggregate(config_name)
            generate_amalgamation.extract_sqlite_api(config_name)

    def _add_amalgamation(self):
        """Stage the amalgamation directory of each build configuration.

        Note that this changes the process's current working directory.
        """
        os.chdir(generate_amalgamation._SQLITE_SRC_DIR)
        for config_name in ['chromium', 'dev']:
            cmd = [
                'git', 'add',
                generate_amalgamation.get_amalgamation_dir(config_name)
            ]
            self._echo(cmd)
            subprocess.check_call(cmd)

    def _update_manifests(self):
        """Update the SQLite's Fossil manifest files.

        This isn't strictly necessary as the manifest isn't used during
        any build, and manifest.uuid is the Fossil commit ID (which
        has no meaning in a Git repo). However, keeping these updated
        helps make it more obvious that a commit originated in
        Git and not Fossil.

        Exits the process with status 1 if a staged file has no entry in
        the manifest.
        """
        manifest = ManifestSerializer.read_file('manifest')
        files_not_in_manifest = ('manifest', 'manifest.uuid')
        for fname in Git.get_staged_changes():
            if fname in files_not_in_manifest:
                continue
            entry = manifest.find_file_entry(fname)
            if not entry:
                print(
                    'Cannot find manifest entry for "{}"'.format(fname),
                    file=sys.stderr)
                sys.exit(1)
            entry.update_file_hash()
        ManifestSerializer.write_file(manifest, 'manifest')
        cmd = ['git', 'add', 'manifest']
        self._echo(cmd)
        subprocess.check_call(cmd)
        # manifest.uuid contains the hash from the Fossil repository which
        # doesn't make sense in a Git branch. Just write all zeros.
        with open('manifest.uuid', 'w') as output_file:
            print('0' * 64, file=output_file)
        cmd = ['git', 'add', 'manifest.uuid']
        self._echo(cmd)
        subprocess.check_call(cmd)

    def continue_cherry_pick(self):
        """Finish an in-progress cherry-pick.

        Updates the manifest files, optionally regenerates and stages the
        amalgamation, then runs `git cherry-pick --continue`. Raises
        UnstagedFiles if unstaged or unmerged files remain.
        """
        if Git.get_unstaged_changes() or Git.get_unmerged_changes():
            raise UnstagedFiles()
        self._update_manifests()
        if self._update_amangamation:
            self._generate_amalgamation()
            self._add_amalgamation()
        cmd = ['git', 'cherry-pick', '--continue']
        self._echo(cmd)
        subprocess.check_call(cmd)
334
335
if __name__ == '__main__':
    # Command-line entry point: either continue an interrupted cherry-pick
    # (--continue) or cherry-pick the given commits in order.
    desc = 'A script for cherry-picking commits from the SQLite repo.'
    parser = argparse.ArgumentParser(description=desc)
    parser.add_argument(
        'commit', nargs='*', help='The commit ids to cherry pick (in order)')
    parser.add_argument(
        '--continue',
        dest='cont',
        action='store_true',
        help='Continue the cherry-pick once conflicts have been resolved')
    namespace = parser.parse_args()
    cherry_picker = CherryPicker()
    if namespace.cont:
        try:
            cherry_picker.continue_cherry_pick()
        except UnstagedFiles:
            print(
                'There are still unstaged files to resolve before continuing.')
            sys.exit(1)
        sys.exit(0)
    num_commits = len(namespace.commit)
    # num_picked is the number of commits successfully picked so far, which
    # is also the position of the commit currently being picked.
    for num_picked, commit_id in enumerate(namespace.commit):
        try:
            cherry_picker.cherry_pick(commit_id)
        except UnstagedFiles:
            print(
                '\nThis cherry-pick contains conflicts. Please resolve them ')
            print('(e.g git mergetool) and rerun this script '
                  '`sqlite_cherry_picker.py --continue`')
            print('or `git cherry-pick --abort`.')
            # Compare by position rather than by value, so the note is still
            # shown when the same commit id is listed more than once.
            if num_picked < num_commits - 1:
                msg = (
                    'NOTE: You have only successfully cherry-picked {} out of '
                    '{} commits.')
                print(msg.format(num_picked, num_commits))
            sys.exit(1)
373