# utils.py -- Test utilities for Dulwich.
# Copyright (C) 2010 Google, Inc.
#
# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
# General Public License as published by the Free Software Foundation; version
# 2.0 or (at your option) any later version. You can redistribute it and/or
# modify it under the terms of either of these two licenses.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# You should have received a copy of the licenses; if not, see
# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
# License, Version 2.0.
#

"""Utility functions common to Dulwich tests."""


import datetime
import os
import shutil
import tempfile
import time
import types

import warnings

from dulwich.index import (
    commit_tree,
    )
from dulwich.objects import (
    FixedSha,
    Commit,
    Tag,
    object_class,
    )
from dulwich.pack import (
    OFS_DELTA,
    REF_DELTA,
    DELTA_TYPES,
    obj_sha,
    SHA1Writer,
    write_pack_header,
    write_pack_object,
    create_delta,
    )
from dulwich.repo import Repo
from dulwich.tests import (  # noqa: F401
    skipIf,
    SkipTest,
    )


# Plain files are very frequently used in tests, so let the mode be very short.
F = 0o100644  # Shorthand mode for Files.


def open_repo(name, temp_dir=None):
    """Open a copy of a repo in a temporary directory.

    Use this function for accessing repos in dulwich/tests/data/repos to avoid
    accidentally or intentionally modifying those repos in place. Use
    tear_down_repo to delete any temp files created.

    Args:
      name: The name of the repository, relative to
        dulwich/tests/data/repos
      temp_dir: temporary directory to initialize to. If not provided, a
        temporary directory will be created.
    Returns: An initialized Repo object that lives in a temporary directory.
    """
    if temp_dir is None:
        temp_dir = tempfile.mkdtemp()
    repo_dir = os.path.join(os.path.dirname(__file__), 'data', 'repos', name)
    temp_repo_dir = os.path.join(temp_dir, name)
    shutil.copytree(repo_dir, temp_repo_dir, symlinks=True)
    return Repo(temp_repo_dir)


def tear_down_repo(repo):
    """Tear down a test repository."""
    repo.close()
    temp_dir = os.path.dirname(repo.path.rstrip(os.sep))
    shutil.rmtree(temp_dir)
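

# A minimal usage sketch (not part of the original module): how open_repo and
# tear_down_repo are typically paired in a test. The repository name 'a.git'
# is an assumed example; substitute any repo under dulwich/tests/data/repos.
def _example_open_repo_usage(test_case):
    # Copy the named test repo into a fresh temporary directory.
    repo = open_repo('a.git')
    # Ensure the temporary copy is removed even if the test fails.
    test_case.addCleanup(tear_down_repo, repo)
    return repo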


def make_object(cls, **attrs):
    """Make an object for testing and assign some members.

    This function creates a new subclass to allow arbitrary attribute
    reassignment, which is not otherwise possible with objects having
    __slots__.

    Args:
      attrs: dict of attributes to set on the new object.
    Returns: A newly initialized object of type cls.
    """

    class TestObject(cls):
        """Class that inherits from the given class, but without __slots__.

        Note that classes with __slots__ can't have arbitrary attributes
        monkey-patched in, so this is a class that is exactly the same only
        with a __dict__ instead of __slots__.
        """
        pass
    TestObject.__name__ = 'TestObject_' + cls.__name__

    obj = TestObject()
    for name, value in attrs.items():
        if name == 'id':
            # id property is read-only, so we overwrite sha instead.
            sha = FixedSha(value)
            obj.sha = lambda: sha
        else:
            setattr(obj, name, value)
    return obj
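

# An illustrative sketch (not in the original module) of make_object: it can
# set arbitrary attributes, including a fixed id, on otherwise __slots__-only
# objects. The Blob import and example values are assumptions for the demo.
def _example_make_object():
    from dulwich.objects import Blob
    # Forcing a known id works because make_object swaps in a FixedSha.
    blob = make_object(Blob, data=b'some data', id=b'a' * 40)
    assert blob.id == b'a' * 40
    return blob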


def make_commit(**attrs):
    """Make a Commit object with a default set of members.

    Args:
      attrs: dict of attributes to overwrite from the default values.
    Returns: A newly initialized Commit object.
    """
    default_time = 1262304000  # 2010-01-01 00:00:00
    all_attrs = {'author': b'Test Author <test@nodomain.com>',
                 'author_time': default_time,
                 'author_timezone': 0,
                 'committer': b'Test Committer <test@nodomain.com>',
                 'commit_time': default_time,
                 'commit_timezone': 0,
                 'message': b'Test message.',
                 'parents': [],
                 'tree': b'0' * 40}
    all_attrs.update(attrs)
    return make_object(Commit, **all_attrs)
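

# A brief sketch (not part of the original module) showing that make_commit
# only overrides the attributes it is given; the message and parent below are
# arbitrary example values.
def _example_make_commit():
    c1 = make_commit()
    c2 = make_commit(message=b'Second commit.', parents=[c1.id])
    return c1, c2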


def make_tag(target, **attrs):
    """Make a Tag object with a default set of values.

    Args:
      target: object to be tagged (Commit, Blob, Tree, etc)
      attrs: dict of attributes to overwrite from the default values.
    Returns: A newly initialized Tag object.
    """
    target_id = target.id
    target_type = object_class(target.type_name)
    default_time = int(time.mktime(datetime.datetime(2010, 1, 1).timetuple()))
    all_attrs = {'tagger': b'Test Author <test@nodomain.com>',
                 'tag_time': default_time,
                 'tag_timezone': 0,
                 'message': b'Test message.',
                 'object': (target_type, target_id),
                 'name': b'Test Tag',
                 }
    all_attrs.update(attrs)
    return make_object(Tag, **all_attrs)
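

# A brief usage sketch (not in the original module): tagging a commit built
# with make_commit. The tag name is an arbitrary example value.
def _example_make_tag():
    commit = make_commit()
    return make_tag(commit, name=b'v1.0')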


def functest_builder(method, func):
    """Generate a test method that tests the given function."""

    def do_test(self):
        method(self, func)

    return do_test


def ext_functest_builder(method, func):
    """Generate a test method that tests the given extension function.

    This is intended to generate test methods that test both a pure-Python
    version and an extension version using common test code. The extension
    test will raise SkipTest if the extension is not found.

    Sample usage:

    class MyTest(TestCase):
        def _do_some_test(self, func_impl):
            self.assertEqual('foo', func_impl())

        test_foo = functest_builder(_do_some_test, foo_py)
        test_foo_extension = ext_functest_builder(_do_some_test, _foo_c)

    Args:
      method: The method to run. It must take two parameters, self and the
        function implementation to test.
      func: The function implementation to pass to method.
    """

    def do_test(self):
        if not isinstance(func, types.BuiltinFunctionType):
            raise SkipTest("%s extension not found" % func)
        method(self, func)

    return do_test
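

# An illustrative sketch (not part of the original module) of the intended
# pattern: import an optional C implementation defensively, fall back to None,
# and let ext_functest_builder raise SkipTest when the extension is missing.
# The names _parse_tree_py and dulwich._objects.parse_tree are assumptions;
# adjust them to whichever extension function is actually under test.
def _example_ext_functest_pattern(test_case_cls):
    from dulwich.objects import _parse_tree_py
    try:
        from dulwich._objects import parse_tree as _parse_tree_c
    except ImportError:
        _parse_tree_c = None  # ext_functest_builder will skip in this case.

    def _do_test(self, func_impl):
        self.assertEqual([], list(func_impl(b'')))

    # Attach one test per implementation to the given TestCase subclass.
    test_case_cls.test_parse_tree = functest_builder(_do_test, _parse_tree_py)
    test_case_cls.test_parse_tree_ext = ext_functest_builder(
        _do_test, _parse_tree_c)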


def build_pack(f, objects_spec, store=None):
    """Write test pack data from a concise spec.

    Args:
      f: A file-like object to write the pack to.
      objects_spec: A list of (type_num, obj). For non-delta types, obj
        is the string of that object's data.
        For delta types, obj is a tuple of (base, data), where:

        * base can be either an index in objects_spec of the base for that
          delta; or for a ref delta, a SHA, in which case the resulting pack
          will be thin and the base will be an external ref.
        * data is a string of the full, non-deltified data for that object.

        Note that offsets/refs and deltas are computed within this function.
      store: An optional ObjectStore for looking up external refs.
    Returns: A list of tuples in the order specified by objects_spec:
        (offset, type num, data, sha, CRC32)
    """
    sf = SHA1Writer(f)
    num_objects = len(objects_spec)
    write_pack_header(sf, num_objects)

    full_objects = {}
    offsets = {}
    crc32s = {}

    while len(full_objects) < num_objects:
        for i, (type_num, data) in enumerate(objects_spec):
            if type_num not in DELTA_TYPES:
                full_objects[i] = (type_num, data,
                                   obj_sha(type_num, [data]))
                continue
            base, data = data
            if isinstance(base, int):
                if base not in full_objects:
                    continue
                base_type_num, _, _ = full_objects[base]
            else:
                base_type_num, _ = store.get_raw(base)
            full_objects[i] = (base_type_num, data,
                               obj_sha(base_type_num, [data]))

    for i, (type_num, obj) in enumerate(objects_spec):
        offset = f.tell()
        if type_num == OFS_DELTA:
            base_index, data = obj
            base = offset - offsets[base_index]
            _, base_data, _ = full_objects[base_index]
            obj = (base, create_delta(base_data, data))
        elif type_num == REF_DELTA:
            base_ref, data = obj
            if isinstance(base_ref, int):
                _, base_data, base = full_objects[base_ref]
            else:
                base_type_num, base_data = store.get_raw(base_ref)
                base = obj_sha(base_type_num, [base_data])
            obj = (base, create_delta(base_data, data))

        crc32 = write_pack_object(sf, type_num, obj)
        offsets[i] = offset
        crc32s[i] = crc32

    expected = []
    for i in range(num_objects):
        type_num, data, sha = full_objects[i]
        assert len(sha) == 20
        expected.append((offsets[i], type_num, data, sha, crc32s[i]))

    sf.write_sha()
    f.seek(0)
    return expected
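

# A minimal sketch (not part of the original module) of driving build_pack
# with an in-memory file: one full blob plus an offset delta against it. The
# blob contents below are arbitrary example data.
def _example_build_pack():
    from io import BytesIO
    from dulwich.objects import Blob

    f = BytesIO()
    entries = build_pack(f, [
        (Blob.type_num, b'blob content'),
        # (base, data): entry 0 is the delta base; the delta itself is
        # computed inside build_pack.
        (OFS_DELTA, (0, b'blob content, slightly changed')),
        ])
    # Each entry is (offset, type_num, data, sha, crc32), in spec order.
    return f, entries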


def build_commit_graph(object_store, commit_spec, trees=None, attrs=None):
    """Build a commit graph from a concise specification.

    Sample usage:
    >>> c1, c2, c3 = build_commit_graph(store, [[1], [2, 1], [3, 1, 2]])
    >>> store[store[c3].parents[0]] == c1
    True
    >>> store[store[c3].parents[1]] == c2
    True

    If not otherwise specified, commits will refer to the empty tree and have
    commit times increasing in the same order as the commit spec.

    Args:
      object_store: An ObjectStore to commit objects to.
      commit_spec: An iterable of iterables of ints defining the commit
        graph. Each entry defines one commit, and entries must be in
        topological order. The first element of each entry is a commit number,
        and the remaining elements are its parents. The commit numbers are
        only meaningful within this call; the real commit objects that are
        created get real, opaque SHAs.
      trees: An optional dict of commit number -> tree spec for building
        trees for commits. The tree spec is an iterable of (path, blob, mode)
        or (path, blob) entries; if mode is omitted, it defaults to the normal
        file mode (0100644).
      attrs: A dict of commit number -> (dict of attribute -> value) for
        assigning additional values to the commits.
    Returns: The list of commit objects created.
    Raises:
      ValueError: If an undefined commit identifier is listed as a parent.
    """
    if trees is None:
        trees = {}
    if attrs is None:
        attrs = {}
    commit_time = 0
    nums = {}
    commits = []

    for commit in commit_spec:
        commit_num = commit[0]
        try:
            parent_ids = [nums[pn] for pn in commit[1:]]
        except KeyError as e:
            missing_parent, = e.args
            raise ValueError('Unknown parent %i' % missing_parent)

        blobs = []
        for entry in trees.get(commit_num, []):
            if len(entry) == 2:
                path, blob = entry
                entry = (path, blob, F)
            path, blob, mode = entry
            blobs.append((path, blob.id, mode))
            object_store.add_object(blob)
        tree_id = commit_tree(object_store, blobs)

        commit_attrs = {
            'message': ('Commit %i' % commit_num).encode('ascii'),
            'parents': parent_ids,
            'tree': tree_id,
            'commit_time': commit_time,
            }
        commit_attrs.update(attrs.get(commit_num, {}))
        commit_obj = make_commit(**commit_attrs)

        # By default, increment the time by a lot. Out-of-order commits should
        # be closer together than this because their main cause is clock skew.
        commit_time = commit_attrs['commit_time'] + 100
        nums[commit_num] = commit_obj.id
        object_store.add_object(commit_obj)
        commits.append(commit_obj)

    return commits
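

# A short sketch (not in the original module) combining build_commit_graph
# with a tree spec. MemoryObjectStore and Blob come from dulwich itself; the
# paths and contents are arbitrary example values.
def _example_build_commit_graph():
    from dulwich.object_store import MemoryObjectStore
    from dulwich.objects import Blob

    store = MemoryObjectStore()
    blob = make_object(Blob, data=b'file contents')
    # Commit 1 has no parents; commit 2 has commit 1 as its parent and a
    # one-file tree (mode defaults to F).
    c1, c2 = build_commit_graph(
        store, [[1], [2, 1]],
        trees={2: [(b'a-file', blob)]})
    return store, c1, c2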


def setup_warning_catcher():
    """Wrap warnings.showwarning with code that records warnings."""

    caught_warnings = []
    original_showwarning = warnings.showwarning

    def custom_showwarning(*args, **kwargs):
        caught_warnings.append(args[0])

    warnings.showwarning = custom_showwarning

    def restore_showwarning():
        warnings.showwarning = original_showwarning

    return caught_warnings, restore_showwarning
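

# A minimal usage sketch (not part of the original module): capture warnings
# raised during a test and restore the original handler via addCleanup.
def _example_setup_warning_catcher(test_case):
    caught_warnings, restore = setup_warning_catcher()
    test_case.addCleanup(restore)
    warnings.warn('something is deprecated here', UserWarning)
    # caught_warnings now holds the Warning instances seen while patched.
    return caught_warnings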