1# Copyright (C) 2007-2011 Canonical Ltd
2#
3# This program is free software; you can redistribute it and/or modify
4# it under the terms of the GNU General Public License as published by
5# the Free Software Foundation; either version 2 of the License, or
6# (at your option) any later version.
7#
8# This program is distributed in the hope that it will be useful,
9# but WITHOUT ANY WARRANTY; without even the implied warranty of
10# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11# GNU General Public License for more details.
12#
13# You should have received a copy of the GNU General Public License
14# along with this program; if not, write to the Free Software
15# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16
17import base64
18import contextlib
19from io import BytesIO
20import re
21
22from . import lazy_import
23lazy_import.lazy_import(globals(), """
24from breezy import (
25    branch as _mod_branch,
26    diff,
27    email_message,
28    gpg,
29    hooks,
30    registry,
31    revision as _mod_revision,
32    rio,
33    timestamp,
34    trace,
35    )
36from breezy.bzr import (
37    testament,
38    )
39from breezy.bzr.bundle import (
40    serializer as bundle_serializer,
41    )
42""")
43from . import (
44    errors,
45    )
46
47
class IllegalMergeDirectivePayload(errors.BzrError):
    """A merge directive contained something other than a patch or bundle.

    :ivar start: The unexpected line that was found where a payload section
        header was expected; referenced by ``_fmt``.
    """

    _fmt = "Bad merge directive payload %(start)r"

    def __init__(self, start):
        """Constructor.

        :param start: The line found in place of a payload section header
        """
        # This has to be the explicit base-class initialiser call:
        # ``errors.BzrError(self)`` would only construct a throwaway
        # exception object and leave this instance uninitialised.
        errors.BzrError.__init__(self)
        self.start = start
56
57
class MergeRequestBodyParams(object):
    """Bundle of values handed to each merge_request_body hook callback."""

    def __init__(self, body, orig_body, directive, to, basename, subject,
                 branch, tree=None):
        """Record every argument as an attribute of the same name.

        :param body: The message body as it currently stands
        :param orig_body: The message body before any hook was run
        :param directive: The merge directive the request is composed for
        :param to: The address the request is composed to
        :param basename: The basename generated for the directive
        :param subject: The subject line of the request
        :param branch: The Branch that was used to produce the directive
        :param tree: The Tree (if any) for that Branch
        """
        # Message text, before and after hook processing.
        self.body, self.orig_body = body, orig_body
        self.directive = directive
        # Where the directive came from.
        self.branch, self.tree = branch, tree
        # Addressing and naming of the outgoing request.
        self.to, self.basename, self.subject = to, basename, subject
71
72
class MergeDirectiveHooks(hooks.Hooks):
    """Hook points available on merge directive classes."""

    def __init__(self):
        """Create the hook collection and declare its hook points."""
        super(MergeDirectiveHooks, self).__init__(
            "breezy.merge_directive", "BaseMergeDirective.hooks")
        merge_request_body_help = (
            "Called with a MergeRequestBodyParams when a body is needed "
            "for a merge request.  Callbacks must return a body.  "
            "If more than one callback is registered, the output of one "
            "callback is provided to the next.")
        self.add_hook('merge_request_body', merge_request_body_help,
                      (1, 15, 0))
85
86
class BaseMergeDirective(object):
    """A request to merge a revision into a branch.

    Every merge directive implementation should derive from this base
    class.

    :cvar hooks: The MergeDirectiveHooks instance shared by directives
    :cvar multiple_output_files: Whether or not this merge directive
        stores a set of revisions in more than one file
    """

    hooks = MergeDirectiveHooks()

    multiple_output_files = False

    def __init__(self, revision_id, testament_sha1, time, timezone,
                 target_branch, patch=None, source_branch=None,
                 message=None, bundle=None):
        """Store the fields common to every directive format.

        :param revision_id: The revision to merge
        :param testament_sha1: The sha1 of the testament of the revision to
            merge.
        :param time: The current POSIX timestamp time
        :param timezone: The timezone offset
        :param target_branch: Location of branch to apply the merge to
        :param patch: The text of a diff or bundle
        :param source_branch: A public location to merge the revision from
        :param message: The message to use when committing this merge
        :param bundle: Accepted but not stored by this base constructor
        """
        self.revision_id, self.testament_sha1 = revision_id, testament_sha1
        self.time, self.timezone = time, timezone
        self.target_branch = target_branch
        self.patch, self.source_branch = patch, source_branch
        self.message = message

    def to_lines(self):
        """Serialize as a list of lines.

        Not implemented on the base class.

        :return: a list of lines
        """
        raise NotImplementedError(self.to_lines)

    def to_files(self):
        """Serialize as a set of files.

        Not implemented on the base class.

        :return: List of tuples with filename and contents as lines
        """
        raise NotImplementedError(self.to_files)

    def get_raw_bundle(self):
        """Return the bundle for this merge directive.

        The base implementation never has a bundle.

        :return: bundle text or None if there is no bundle
        """
        return None

    def _to_lines(self, base_revision=False):
        """Serialize the directive header as a list of lines.

        :param base_revision: If true, also emit ``base_revision_id``
        :return: a list of lines: the format marker, the header stanza in
            patch form, and a closing ``# `` line
        """
        formatted_time = timestamp.format_patch_date(self.time, self.timezone)
        stanza = rio.Stanza(revision_id=self.revision_id,
                            timestamp=formatted_time,
                            target_branch=self.target_branch,
                            testament_sha1=self.testament_sha1)
        # Optional fields are only written when they carry a value.
        instance_values = vars(self)
        for key in ('source_branch', 'message'):
            value = instance_values[key]
            if value is None:
                continue
            stanza.add(key, value)
        if base_revision:
            stanza.add('base_revision_id', self.base_revision_id)
        format_marker = b'# ' + self._format_string + b'\n'
        return [format_marker] + list(rio.to_patch_lines(stanza)) + [b'# \n']

    def write_to_directory(self, path):
        """Write this merge directive to a series of files in a directory.

        Not implemented on the base class.

        :param path: Filesystem path to write to
        """
        raise NotImplementedError(self.write_to_directory)

    @classmethod
    def from_objects(klass, repository, revision_id, time, timezone,
                     target_branch, patch_type='bundle',
                     local_target_branch=None, public_branch=None, message=None):
        """Generate a merge directive from various objects.

        :param repository: The repository containing the revision
        :param revision_id: The revision to merge
        :param time: The POSIX timestamp of the date the request was issued.
        :param timezone: The timezone of the request
        :param target_branch: The url of the branch to merge into
        :param patch_type: 'bundle', 'diff' or None, depending on the type of
            patch desired.
        :param local_target_branch: the submit branch, either itself or a
            local copy
        :param public_branch: location of a public branch containing
            the target revision.
        :param message: Message to use when committing the merge
        :return: The merge directive

        The public branch is always used if supplied.  If the patch_type is
        not 'bundle', the public branch must be supplied, and will be verified.

        If the message is not supplied, the message from revision_id will be
        used for the commit.
        """
        # The testament is built from None rather than the null revision.
        if revision_id == _mod_revision.NULL_REVISION:
            testament_revision_id = None
        else:
            testament_revision_id = revision_id
        revision_testament = testament.StrictTestament3.from_revision(
            repository, testament_revision_id)

        submit_branch = local_target_branch
        if submit_branch is None:
            submit_branch = _mod_branch.Branch.open(target_branch)
        # Advertise the submit branch's public location when it has one.
        if submit_branch.get_public_branch() is not None:
            target_branch = submit_branch.get_public_branch()

        patch = None
        if patch_type is not None:
            submit_revision_id = _mod_revision.ensure_null(
                submit_branch.last_revision())
            repository.fetch(submit_branch.repository, submit_revision_id)
            ancestor_id = repository.get_graph().find_unique_lca(
                revision_id, submit_revision_id)
            generators = {'bundle': klass._generate_bundle,
                          'diff': klass._generate_diff,
                          None: lambda x, y, z: None}
            generate = generators[patch_type]
            patch = generate(repository, revision_id, ancestor_id)

        # Without a bundle the recipient must fetch from the public branch,
        # so make sure the revision is really there.
        if public_branch is not None and patch_type != 'bundle':
            public_branch_obj = _mod_branch.Branch.open(public_branch)
            if not public_branch_obj.repository.has_revision(revision_id):
                raise errors.PublicBranchOutOfDate(public_branch,
                                                   revision_id)

        return klass(revision_id, revision_testament.as_sha1(), time,
                     timezone, target_branch, patch, patch_type,
                     public_branch, message)

    def get_disk_name(self, branch):
        """Generate a suitable basename for storing this directive on disk.

        The name is the branch nick (runs of non-word characters replaced
        by ``-``) followed by the revno of the directive's revision, or
        ``merge`` when the branch cannot map the revision to a revno.

        :param branch: The Branch this merge directive was generated from
        :return: A string
        """
        tip_revno, tip_revision_id = branch.last_revision_info()
        if self.revision_id == tip_revision_id:
            revno_parts = [tip_revno]
        else:
            try:
                revno_parts = branch.revision_id_to_dotted_revno(
                    self.revision_id)
            except errors.NoSuchRevision:
                revno_parts = ['merge']
        safe_nick = re.sub('(\\W+)', '-', branch.nick).strip('-')
        revno_text = '.'.join(str(part) for part in revno_parts)
        return '%s-%s' % (safe_nick, revno_text)

    @staticmethod
    def _generate_diff(repository, revision_id, ancestor_id):
        """Return the diff from ancestor_id to revision_id as bytes.

        :param repository: The repository providing both revision trees
        :param revision_id: The revision on the new side of the diff
        :param ancestor_id: The revision on the old side of the diff
        """
        old_tree = repository.revision_tree(ancestor_id)
        new_tree = repository.revision_tree(revision_id)
        output = BytesIO()
        diff.show_diff_trees(old_tree, new_tree, output,
                             old_label='', new_label='')
        return output.getvalue()

    @staticmethod
    def _generate_bundle(repository, revision_id, ancestor_id):
        """Return a bundle of revision_id relative to ancestor_id as bytes.

        :param repository: The repository to read the revisions from
        :param revision_id: The revision to bundle
        :param ancestor_id: The base revision passed to the bundle writer
        """
        output = BytesIO()
        bundle_serializer.write_bundle(repository, revision_id,
                                       ancestor_id, output)
        return output.getvalue()

    def to_signed(self, branch):
        """Serialize as a signed string.

        :param branch: The source branch, to get the signing strategy
        :return: the clear-signed serialization of this directive
        """
        strategy = gpg.GPGStrategy(branch.get_config_stack())
        content = b''.join(self.to_lines())
        return strategy.sign(content, gpg.MODE_CLEAR)

    def to_email(self, mail_to, branch, sign=False):
        """Serialize as an email message.

        The subject is the directive's message, or the message of the
        revision being merged when the directive has none.

        :param mail_to: The address to mail the message to
        :param branch: The source branch, to get the signing strategy and
            source email address
        :param sign: If True, gpg-sign the email
        :return: an email message
        """
        mail_from = branch.get_config_stack().get('email')
        if self.message is None:
            revision = branch.repository.get_revision(self.revision_id)
            subject = revision.message
        else:
            subject = self.message
        if sign:
            body = self.to_signed(branch)
        else:
            body = b''.join(self.to_lines())
        return email_message.EmailMessage(mail_from, mail_to, subject, body)

    def install_revisions(self, target_repo):
        """Install revisions and return the target revision.

        :param target_repo: The repository to install the revisions into
        :return: The revision id this directive asks to merge
        """
        if target_repo.has_revision(self.revision_id):
            return self.revision_id

        if self.patch_type != 'bundle':
            # No bundle: the revision has to come from the source branch.
            source_branch = _mod_branch.Branch.open(self.source_branch)
            target_repo.fetch(source_branch.repository, self.revision_id)
            return self.revision_id

        info = bundle_serializer.read_bundle(BytesIO(self.get_raw_bundle()))
        # We don't use the bundle's target revision, because
        # MergeDirective.revision_id is authoritative.
        try:
            info.install_revisions(target_repo, stream_input=False)
        except errors.RevisionNotPresent:
            # At least one dependency isn't present.  Try installing
            # missing revisions from the submit branch
            try:
                submit_branch = _mod_branch.Branch.open(self.target_branch)
            except errors.NotBranchError:
                raise errors.TargetNotBranch(self.target_branch)
            missing_parents = []
            bundled_ids = set(r.revision_id for r in info.real_revisions)
            for revision in info.real_revisions:
                for parent_id in revision.parent_ids:
                    if parent_id in bundled_ids:
                        continue
                    if target_repo.has_revision(parent_id):
                        continue
                    missing_parents.append(parent_id)
            # Walk the missing parents backwards to try to get heads
            # first, keeping only the first sighting of each one.
            fetch_order = []
            already_queued = set()
            for parent_id in reversed(missing_parents):
                if parent_id not in already_queued:
                    fetch_order.append(parent_id)
                    already_queued.add(parent_id)
            for parent_id in fetch_order:
                target_repo.fetch(submit_branch.repository, parent_id)
            # Second attempt, now that the dependencies were fetched.
            info.install_revisions(target_repo, stream_input=False)
        return self.revision_id

    def compose_merge_request(self, mail_client, to, body, branch, tree=None):
        """Compose a request to merge this directive.

        :param mail_client: The mail client to use for composing this request.
        :param to: The address to compose the request to.
        :param body: The initial message body.  It is passed through the
            merge_request_body hooks when the mail client supports bodies.
        :param branch: The Branch that was used to produce this directive.
        :param tree: The Tree (if any) for the Branch used to produce this
            directive.
        """
        basename = self.get_disk_name(branch)
        if self.message is not None:
            summary = self.message
        else:
            revision = branch.repository.get_revision(self.revision_id)
            summary = revision.get_summary()
        subject = '[MERGE] ' + summary
        if getattr(mail_client, 'supports_body', False):
            orig_body = body
            # Each hook receives the previous hook's output as its body.
            for hook in self.hooks['merge_request_body']:
                body = hook(MergeRequestBodyParams(
                    body, orig_body, self, to, basename, subject, branch,
                    tree))
        elif len(self.hooks['merge_request_body']) > 0:
            trace.warning(
                'Cannot run merge_request_body hooks because mail client %s '
                'does not support message bodies.',
                mail_client.__class__.__name__)
        directive_text = b''.join(self.to_lines())
        mail_client.compose_merge_request(to, subject, directive_text,
                                          basename, body)
368
369
class MergeDirective(BaseMergeDirective):

    """A request to perform a merge into a branch (format 1).

    Meant to be serialized and mailed.  It carries everything needed to
    perform a merge automatically: at minimum a revision bundle or the
    location of a branch.

    The serialization format is robust against certain common forms of
    deterioration caused by mailing.

    The format is also designed to be patch-compatible.  If the directive
    includes a diff or revision bundle, it should be possible to apply it
    directly using the standard patch program.
    """

    _format_string = b'Bazaar merge directive format 1'

    def __init__(self, revision_id, testament_sha1, time, timezone,
                 target_branch, patch=None, patch_type=None,
                 source_branch=None, message=None, bundle=None):
        """Constructor.

        :param revision_id: The revision to merge
        :param testament_sha1: The sha1 of the testament of the revision to
            merge.
        :param time: The current POSIX timestamp time
        :param timezone: The timezone offset
        :param target_branch: Location of the branch to apply the merge to
        :param patch: The text of a diff or bundle
        :param patch_type: None, "diff" or "bundle", depending on the contents
            of patch
        :param source_branch: A public location to merge the revision from
        :param message: The message to use when committing this merge
        :param bundle: Accepted but not used by this constructor
        :raises ValueError: if patch_type is not None, "diff" or "bundle"
        :raises errors.NoMergeSource: if there is neither a bundle payload
            nor a source branch
        :raises errors.PatchMissing: if a patch_type is given without a patch
        """
        super(MergeDirective, self).__init__(
            revision_id, testament_sha1, time, timezone, target_branch,
            patch=patch, source_branch=source_branch, message=message)
        if patch_type not in (None, 'diff', 'bundle'):
            raise ValueError(patch_type)
        if patch_type != 'bundle' and source_branch is None:
            raise errors.NoMergeSource()
        if patch_type is not None and patch is None:
            raise errors.PatchMissing(patch_type)
        self.patch_type = patch_type

    def clear_payload(self):
        """Drop the patch and forget its type."""
        self.patch = self.patch_type = None

    def get_raw_bundle(self):
        """Return the bundle text, or None if the payload is not a bundle."""
        return self.bundle

    def _bundle(self):
        """Return the patch when it is a bundle, otherwise None."""
        return self.patch if self.patch_type == 'bundle' else None

    bundle = property(_bundle)

    @classmethod
    def from_lines(klass, lines):
        """Deserialize a MergeRequest from an iterable of lines.

        Lines are skipped until one starting with the merge directive
        format marker is found; the class registered for that format then
        parses the remaining lines.

        :param lines: An iterable of lines
        :return: a MergeRequest
        :raises errors.NotAMergeDirective: if no format marker line is found
        """
        marker = b'# Bazaar merge directive format '
        remaining = iter(lines)
        # First non-empty stripped line seen, used in the error report.
        first_content = b""
        for line in remaining:
            if not line.startswith(marker):
                if not first_content:
                    first_content = line.strip()
                continue
            # Drop the leading "# " to obtain the registry key.
            format_name = line[2:].rstrip()
            directive_class = _format_registry.get(format_name)
            return directive_class._from_lines(remaining)
        raise errors.NotAMergeDirective(first_content)

    @classmethod
    def _from_lines(klass, line_iter):
        """Build a MergeDirective from the lines after the format marker.

        Whatever follows the header stanza is the payload.  It is
        classified as a 'bundle' when the bundle serializer can read it,
        and as a 'diff' when reading raises NotABundle, BundleNotSupported
        or BadBundle.

        :param line_iter: An iterator positioned just after the format line
        :return: a MergeDirective
        """
        stanza = rio.read_patch_stanza(line_iter)
        payload_lines = list(line_iter)
        if not payload_lines:
            patch = patch_type = None
        else:
            patch = b''.join(payload_lines)
            try:
                bundle_serializer.read_bundle(BytesIO(patch))
            except (errors.NotABundle, errors.BundleNotSupported,
                    errors.BadBundle):
                patch_type = 'diff'
            else:
                patch_type = 'bundle'
        time, timezone = timestamp.parse_patch_date(stanza.get('timestamp'))
        # Copy over whichever of the known fields the stanza provides.
        fields = {}
        for name in ('revision_id', 'testament_sha1', 'target_branch',
                     'source_branch', 'message'):
            try:
                fields[name] = stanza.get(name)
            except KeyError:
                pass
        fields['revision_id'] = fields['revision_id'].encode('utf-8')
        if 'testament_sha1' in fields:
            fields['testament_sha1'] = fields['testament_sha1'].encode('ascii')
        return MergeDirective(time=time, timezone=timezone,
                              patch_type=patch_type, patch=patch, **fields)

    def to_lines(self):
        """Serialize as a list of lines: the header, then any patch."""
        serialized = self._to_lines()
        if self.patch is None:
            return serialized
        serialized.extend(self.patch.splitlines(True))
        return serialized

    @staticmethod
    def _generate_bundle(repository, revision_id, ancestor_id):
        """Return a bundle written in the '0.9' bundle format, as bytes.

        :param repository: The repository to read the revisions from
        :param revision_id: The revision to bundle
        :param ancestor_id: The base revision passed to the bundle writer
        """
        output = BytesIO()
        bundle_serializer.write_bundle(repository, revision_id,
                                       ancestor_id, output, '0.9')
        return output.getvalue()

    def get_merge_request(self, repository):
        """Provide data for performing a merge.

        This format suggests no base and never verifies a patch.

        :param repository: Unused by this format
        :return: suggested base (always None), suggested target, and patch
            verification status (always 'inapplicable')
        """
        suggested_base = None
        return suggested_base, self.revision_id, 'inapplicable'
495
496
class MergeDirective2(BaseMergeDirective):
    """A merge directive in format 2.

    Unlike format 1, the preview patch and the bundle are stored
    separately, and a base revision id is recorded.  The bundle is held
    base64-encoded.
    """

    _format_string = b'Bazaar merge directive format 2 (Bazaar 0.90)'

    def __init__(self, revision_id, testament_sha1, time, timezone,
                 target_branch, patch=None, source_branch=None, message=None,
                 bundle=None, base_revision_id=None):
        """Constructor.

        :param revision_id: The revision to merge
        :param testament_sha1: The sha1 of the testament of the revision to
            merge.
        :param time: The current POSIX timestamp time
        :param timezone: The timezone offset
        :param target_branch: Location of the branch to apply the merge to
        :param patch: The text of a preview diff
        :param source_branch: A public location to merge the revision from
        :param message: The message to use when committing this merge
        :param bundle: The base64-encoded bundle
        :param base_revision_id: The revision the patch is relative to
        :raises errors.NoMergeSource: if neither a source branch nor a
            bundle is given
        """
        if bundle is None and source_branch is None:
            raise errors.NoMergeSource()
        super(MergeDirective2, self).__init__(
            revision_id, testament_sha1, time, timezone, target_branch,
            patch=patch, source_branch=source_branch, message=message)
        self.bundle = bundle
        self.base_revision_id = base_revision_id

    def _patch_type(self):
        """Return 'bundle', 'diff' or None according to the payload held."""
        if self.bundle is not None:
            return 'bundle'
        if self.patch is not None:
            return 'diff'
        return None

    patch_type = property(_patch_type)

    def clear_payload(self):
        """Drop both the patch and the bundle."""
        self.patch = None
        self.bundle = None

    def get_raw_bundle(self):
        """Return the base64-decoded bundle, or None if there is none."""
        encoded = self.bundle
        return None if encoded is None else base64.b64decode(encoded)

    @classmethod
    def _from_lines(klass, line_iter):
        """Build a directive from the lines after the format marker.

        After the header stanza an optional ``# Begin patch`` section may
        appear, followed by an optional ``# Begin bundle`` section.

        :param line_iter: An iterator positioned just after the format line
        :return: an instance of this class
        :raises IllegalMergeDirectivePayload: if a line that is neither
            section header appears where a section header is expected
        """
        stanza = rio.read_patch_stanza(line_iter)
        patch = bundle = None
        # None means the input ended and no section header is pending.
        header = next(line_iter, None)
        if header is not None and header.startswith(b'# Begin patch'):
            collected = []
            header = None
            for line in line_iter:
                if line.startswith(b'# Begin bundle'):
                    # The patch ends where the bundle section begins.
                    header = line
                    break
                collected.append(line)
            patch = b''.join(collected)
        if header is not None:
            if not header.startswith(b'# Begin bundle'):
                raise IllegalMergeDirectivePayload(header)
            bundle = b''.join(line_iter)
        time, timezone = timestamp.parse_patch_date(stanza.get('timestamp'))
        # Copy over whichever of the known fields the stanza provides.
        fields = {}
        for name in ('revision_id', 'testament_sha1', 'target_branch',
                     'source_branch', 'message', 'base_revision_id'):
            try:
                fields[name] = stanza.get(name)
            except KeyError:
                pass
        fields['revision_id'] = fields['revision_id'].encode('utf-8')
        fields['base_revision_id'] = (
            fields['base_revision_id'].encode('utf-8'))
        if 'testament_sha1' in fields:
            fields['testament_sha1'] = fields['testament_sha1'].encode('ascii')
        return klass(time=time, timezone=timezone, patch=patch, bundle=bundle,
                     **fields)

    def to_lines(self):
        """Serialize as a list of lines.

        :return: the header (including the base revision id), then the
            patch section and the bundle section for whichever are present
        """
        serialized = self._to_lines(base_revision=True)
        sections = ((b'# Begin patch\n', self.patch),
                    (b'# Begin bundle\n', self.bundle))
        for section_header, content in sections:
            if content is None:
                continue
            serialized.append(section_header)
            serialized.extend(content.splitlines(True))
        return serialized

    @classmethod
    def from_objects(klass, repository, revision_id, time, timezone,
                     target_branch, include_patch=True, include_bundle=True,
                     local_target_branch=None, public_branch=None, message=None,
                     base_revision_id=None):
        """Generate a merge directive from various objects.

        :param repository: The repository containing the revision
        :param revision_id: The revision to merge
        :param time: The POSIX timestamp of the date the request was issued.
        :param timezone: The timezone of the request
        :param target_branch: The url of the branch to merge into
        :param include_patch: If true, include a preview patch
        :param include_bundle: If true, include a bundle
        :param local_target_branch: the target branch, either itself or a
            local copy
        :param public_branch: location of a public branch containing
            the target revision.
        :param message: Message to use when committing the merge
        :param base_revision_id: The revision the preview patch is made
            against; when None, the unique common ancestor of revision_id
            and the target branch tip is used
        :return: The merge directive

        The public branch is always used if supplied.  If no bundle is
        included, the public branch must be supplied, and will be verified.

        If the message is not supplied, the message from revision_id will be
        used for the commit.
        """
        # All locks taken here are released together when the block exits.
        with contextlib.ExitStack() as locks:
            locks.enter_context(repository.lock_write())
            # The testament is built from None rather than the null revision.
            if revision_id == b'null:':
                testament_revision_id = None
            else:
                testament_revision_id = revision_id
            revision_testament = testament.StrictTestament3.from_revision(
                repository, testament_revision_id)
            submit_branch = local_target_branch
            if submit_branch is None:
                submit_branch = _mod_branch.Branch.open(target_branch)
            locks.enter_context(submit_branch.lock_read())
            # Advertise the submit branch's public location when it has one.
            if submit_branch.get_public_branch() is not None:
                target_branch = submit_branch.get_public_branch()
            submit_revision_id = _mod_revision.ensure_null(
                submit_branch.last_revision())
            graph = repository.get_graph(submit_branch.repository)
            ancestor_id = graph.find_unique_lca(revision_id,
                                                submit_revision_id)
            if base_revision_id is None:
                base_revision_id = ancestor_id
            if (include_patch, include_bundle) != (False, False):
                repository.fetch(submit_branch.repository, submit_revision_id)

            patch = None
            if include_patch:
                patch = klass._generate_diff(repository, revision_id,
                                             base_revision_id)

            bundle = None
            if include_bundle:
                # The bundle is always made against the common ancestor,
                # whatever base was requested for the patch.
                raw_bundle = klass._generate_bundle(repository, revision_id,
                                                    ancestor_id)
                bundle = base64.b64encode(raw_bundle)

            # Without a bundle the recipient must fetch from the public
            # branch, so make sure the revision is really there.
            if public_branch is not None and not include_bundle:
                public_branch_obj = _mod_branch.Branch.open(public_branch)
                locks.enter_context(public_branch_obj.lock_read())
                if not public_branch_obj.repository.has_revision(
                        revision_id):
                    raise errors.PublicBranchOutOfDate(public_branch,
                                                       revision_id)
            testament_sha1 = revision_testament.as_sha1()
        return klass(revision_id, testament_sha1, time, timezone,
                     target_branch, patch, public_branch, message, bundle,
                     base_revision_id)

    def _verify_patch(self, repository):
        """Check the stored patch against a freshly generated one.

        Both texts have their line endings converted to UNIX style and
        trailing spaces stripped before being compared.

        :param repository: The repository to regenerate the diff from
        :return: True if the normalised texts are equal
        """
        def normalise(text):
            # Convert line-endings to UNIX
            text = re.sub(b'\r\n?', b'\n', text)
            # Strip trailing whitespace
            return re.sub(b' *\n', b'\n', text)

        regenerated = self._generate_diff(repository, self.revision_id,
                                          self.base_revision_id)
        return normalise(regenerated) == normalise(self.patch)

    def get_merge_request(self, repository):
        """Provide data for performing a merge.

        :param repository: The repository used to verify the patch
        :return: suggested base, suggested target, and patch verification
            status
        """
        verification_status = self._maybe_verify(repository)
        return self.base_revision_id, self.revision_id, verification_status

    def _maybe_verify(self, repository):
        """Verify the patch, if there is one.

        :param repository: The repository used to verify the patch
        :return: 'inapplicable' when there is no patch, otherwise
            'verified' or 'failed'
        """
        if self.patch is None:
            return 'inapplicable'
        if self._verify_patch(repository):
            return 'verified'
        return 'failed'
682
683
class MergeDirectiveFormatRegistry(registry.Registry):
    """Registry of merge directive classes keyed by format string."""

    def register(self, directive, format_string=None):
        """Register a merge directive class.

        :param directive: The merge directive class to register
        :param format_string: The key to register it under; when None,
            the class's own ``_format_string`` is used
        """
        if format_string is None:
            key = directive._format_string
        else:
            key = format_string
        super(MergeDirectiveFormatRegistry, self).register(key, directive)
690
691
# Formats that MergeDirective.from_lines can dispatch to, keyed by the
# format string found on the directive's first line.
_format_registry = MergeDirectiveFormatRegistry()
_format_registry.register(directive=MergeDirective)
_format_registry.register(directive=MergeDirective2)
# There never was a 0.19 release: it was renamed to 0.90.  Merge directives
# carrying the 0.19 format string were already in the wild by then, so the
# old string is registered as well to keep those directives readable.
_format_registry.register(
    directive=MergeDirective2,
    format_string=b'Bazaar merge directive format 2 (Bazaar 0.19)')
700