# utils.py -- Test utilities for Dulwich.
# Copyright (C) 2010 Google, Inc.
#
# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
# General Public License as public by the Free Software Foundation; version 2.0
# or (at your option) any later version. You can redistribute it and/or
# modify it under the terms of either of these two licenses.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# You should have received a copy of the licenses; if not, see
# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
# License, Version 2.0.
#

"""Utility functions common to Dulwich tests."""


import datetime
import os
import shutil
import tempfile
import time
import types

import warnings

from dulwich.index import (
    commit_tree,
    )
from dulwich.objects import (
    FixedSha,
    Commit,
    Tag,
    object_class,
    )
from dulwich.pack import (
    OFS_DELTA,
    REF_DELTA,
    DELTA_TYPES,
    obj_sha,
    SHA1Writer,
    write_pack_header,
    write_pack_object,
    create_delta,
    )
from dulwich.repo import Repo
from dulwich.tests import (  # noqa: F401
    skipIf,
    SkipTest,
    )


# Plain files are very frequently used in tests, so let the mode be very short.
F = 0o100644  # Shorthand mode for Files.


def open_repo(name, temp_dir=None):
    """Open a copy of a repo in a temporary directory.

    Use this function for accessing repos in dulwich/tests/data/repos to avoid
    accidentally or intentionally modifying those repos in place. Use
    tear_down_repo to delete any temp files created.

    Args:
      name: The name of the repository, relative to
        dulwich/tests/data/repos
      temp_dir: temporary directory to initialize to. If not provided, a
        temporary directory will be created.
    Returns: An initialized Repo object that lives in a temporary directory.
    """
    created_temp_dir = temp_dir is None
    if created_temp_dir:
        temp_dir = tempfile.mkdtemp()
    repo_dir = os.path.join(os.path.dirname(__file__), 'data', 'repos', name)
    temp_repo_dir = os.path.join(temp_dir, name)
    try:
        shutil.copytree(repo_dir, temp_repo_dir, symlinks=True)
        return Repo(temp_repo_dir)
    except BaseException:
        # Don't leak the temporary directory we created ourselves if the copy
        # or the Repo construction fails (e.g. an unknown repository name).
        if created_temp_dir:
            shutil.rmtree(temp_dir, ignore_errors=True)
        raise


def tear_down_repo(repo):
    """Tear down a test repository.

    Closes the repository and removes the directory that contains it, i.e.
    the temporary directory that open_repo copied the repository into.
    NOTE: when open_repo was given an explicit temp_dir, that directory is
    the one removed here.
    """
    repo.close()
    temp_dir = os.path.dirname(repo.path.rstrip(os.sep))
    shutil.rmtree(temp_dir)


def make_object(cls, **attrs):
    """Make an object for testing and assign some members.

    This method creates a new subclass to allow arbitrary attribute
    reassignment, which is not otherwise possible with objects having
    __slots__.

    Args:
      cls: The class to instantiate (a subclass of it is actually used).
      attrs: dict of attributes to set on the new object. The special key
        'id' overrides the object's SHA instead of being set directly.
    Returns: A newly initialized object of type cls.
    """

    class TestObject(cls):
        """Class that inherits from the given class, but without __slots__.

        Note that classes with __slots__ can't have arbitrary attributes
        monkey-patched in, so this is a class that is exactly the same only
        with a __dict__ instead of __slots__.
        """
        pass
    TestObject.__name__ = 'TestObject_' + cls.__name__

    obj = TestObject()
    for name, value in attrs.items():
        if name == 'id':
            # id property is read-only, so we overwrite sha instead.
            sha = FixedSha(value)
            obj.sha = lambda: sha
        else:
            setattr(obj, name, value)
    return obj


def make_commit(**attrs):
    """Make a Commit object with a default set of members.

    Args:
      attrs: dict of attributes to overwrite from the default values.
    Returns: A newly initialized Commit object.
    """
    default_time = 1262304000  # 2010-01-01 00:00:00
    all_attrs = {'author': b'Test Author <test@nodomain.com>',
                 'author_time': default_time,
                 'author_timezone': 0,
                 'committer': b'Test Committer <test@nodomain.com>',
                 'commit_time': default_time,
                 'commit_timezone': 0,
                 'message': b'Test message.',
                 'parents': [],
                 'tree': b'0' * 40}
    all_attrs.update(attrs)
    return make_object(Commit, **all_attrs)


def make_tag(target, **attrs):
    """Make a Tag object with a default set of values.

    Args:
      target: object to be tagged (Commit, Blob, Tree, etc)
      attrs: dict of attributes to overwrite from the default values.
    Returns: A newly initialized Tag object.
    """
    target_id = target.id
    target_type = object_class(target.type_name)
    # NOTE(review): time.mktime interprets the tuple as *local* time, so this
    # default depends on the host timezone (unlike make_commit, which uses a
    # fixed epoch value). Kept as-is because existing tests may rely on it.
    default_time = int(time.mktime(datetime.datetime(2010, 1, 1).timetuple()))
    all_attrs = {'tagger': b'Test Author <test@nodomain.com>',
                 'tag_time': default_time,
                 'tag_timezone': 0,
                 'message': b'Test message.',
                 'object': (target_type, target_id),
                 'name': b'Test Tag',
                 }
    all_attrs.update(attrs)
    return make_object(Tag, **all_attrs)


def functest_builder(method, func):
    """Generate a test method that tests the given function.

    Args:
      method: The method to run; it is called as method(self, func).
      func: The function implementation to pass to method.
    Returns: A test method taking only self.
    """

    def do_test(self):
        method(self, func)

    return do_test


def ext_functest_builder(method, func):
    """Generate a test method that tests the given extension function.

    This is intended to generate test methods that test both a pure-Python
    version and an extension version using common test code. The extension test
    will raise SkipTest if the extension is not found.

    Sample usage:

      class MyTest(TestCase):
          def _do_some_test(self, func_impl):
              self.assertEqual('foo', func_impl())

          test_foo = functest_builder(_do_some_test, foo_py)
          test_foo_extension = ext_functest_builder(_do_some_test, _foo_c)

    Args:
      method: The method to run. It must take two parameters, self and the
        function implementation to test.
      func: The function implementation to pass to method.
    Returns: A test method taking only self.
    """

    def do_test(self):
        # A compiled extension function is a builtin; anything else means the
        # pure-Python fallback was imported instead of the extension.
        if not isinstance(func, types.BuiltinFunctionType):
            raise SkipTest("%s extension not found" % func)
        method(self, func)

    return do_test


def build_pack(f, objects_spec, store=None):
    """Write test pack data from a concise spec.

    Args:
      f: A file-like object to write the pack to.
      objects_spec: A list of (type_num, obj). For non-delta types, obj
        is the bytes of that object's data.
        For delta types, obj is a tuple of (base, data), where:

        * base can be either an index in objects_spec of the base for that
          delta; or for a ref delta, a SHA, in which case the resulting pack
          will be thin and the base will be an external ref.
        * data is the full, non-deltified data for that object.

        Note that offsets/refs and deltas are computed within this function.
      store: An optional ObjectStore for looking up external refs.
    Returns: A list of tuples in the order specified by objects_spec:
        (offset, type num, data, sha, CRC32)
      f is rewound to position 0 before returning.
    Raises:
      ValueError: If the base index of some delta can never be resolved
        (out of range, or the deltas form a cycle).
    """
    sf = SHA1Writer(f)
    num_objects = len(objects_spec)
    write_pack_header(sf, num_objects)

    full_objects = {}
    offsets = {}
    crc32s = {}

    # Resolve the full (undeltified) type, data and SHA of every object.
    # Delta bases given by index may come later in objects_spec, so keep
    # making passes until everything is resolved.
    while len(full_objects) < num_objects:
        made_progress = False
        for i, (type_num, data) in enumerate(objects_spec):
            if i in full_objects:
                continue
            if type_num not in DELTA_TYPES:
                full_objects[i] = (type_num, data,
                                   obj_sha(type_num, [data]))
                made_progress = True
                continue
            base, data = data
            if isinstance(base, int):
                if base not in full_objects:
                    continue
                base_type_num, _, _ = full_objects[base]
            else:
                base_type_num, _ = store.get_raw(base)
            full_objects[i] = (base_type_num, data,
                               obj_sha(base_type_num, [data]))
            made_progress = True
        if not made_progress:
            # A pass that resolves nothing can never resolve anything; the
            # original loop would spin forever here.
            unresolved = sorted(set(range(num_objects)) - set(full_objects))
            raise ValueError(
                'Cannot resolve delta bases for objects %r' % (unresolved,))

    for i, (type_num, obj) in enumerate(objects_spec):
        offset = f.tell()
        if type_num == OFS_DELTA:
            base_index, data = obj
            base = offset - offsets[base_index]
            _, base_data, _ = full_objects[base_index]
            obj = (base, create_delta(base_data, data))
        elif type_num == REF_DELTA:
            base_ref, data = obj
            if isinstance(base_ref, int):
                _, base_data, base = full_objects[base_ref]
            else:
                base_type_num, base_data = store.get_raw(base_ref)
                base = obj_sha(base_type_num, [base_data])
            obj = (base, create_delta(base_data, data))

        crc32 = write_pack_object(sf, type_num, obj)
        offsets[i] = offset
        crc32s[i] = crc32

    expected = []
    for i in range(num_objects):
        type_num, data, sha = full_objects[i]
        assert len(sha) == 20
        expected.append((offsets[i], type_num, data, sha, crc32s[i]))

    sf.write_sha()
    f.seek(0)
    return expected


def build_commit_graph(object_store, commit_spec, trees=None, attrs=None):
    """Build a commit graph from a concise specification.

    Sample usage:
    >>> c1, c2, c3 = build_commit_graph(store, [[1], [2, 1], [3, 1, 2]])
    >>> store[store[c3].parents[0]] == c1
    True
    >>> store[store[c3].parents[1]] == c2
    True

    If not otherwise specified, commits will refer to the empty tree and have
    commit times increasing in the same order as the commit spec.

    Args:
      object_store: An ObjectStore to commit objects to.
      commit_spec: An iterable of sequences of ints defining the commit
        graph. Each entry defines one commit, and entries must be in
        topological order. The first element of each entry is a commit number,
        and the remaining elements are its parents. The commit numbers are only
        meaningful for the call to build_commit_graph; since real commit
        objects are created, they will get created with real, opaque SHAs.
      trees: An optional dict of commit number -> tree spec for building
        trees for commits. The tree spec is an iterable of (path, blob, mode)
        or (path, blob) entries; if mode is omitted, it defaults to the normal
        file mode (0o100644).
      attrs: A dict of commit number -> (dict of attribute -> value) for
        assigning additional values to the commits.
    Returns: The list of commit objects created.
    Raises:
      ValueError: If an undefined commit identifier is listed as a parent.
    """
    if trees is None:
        trees = {}
    if attrs is None:
        attrs = {}
    commit_time = 0
    nums = {}
    commits = []

    for commit in commit_spec:
        commit_num = commit[0]
        try:
            parent_ids = [nums[pn] for pn in commit[1:]]
        except KeyError as e:
            missing_parent, = e.args
            raise ValueError('Unknown parent %i' % missing_parent) from e

        blobs = []
        for entry in trees.get(commit_num, []):
            if len(entry) == 2:
                path, blob = entry
                entry = (path, blob, F)
            path, blob, mode = entry
            blobs.append((path, blob.id, mode))
            object_store.add_object(blob)
        tree_id = commit_tree(object_store, blobs)

        commit_attrs = {
            'message': ('Commit %i' % commit_num).encode('ascii'),
            'parents': parent_ids,
            'tree': tree_id,
            'commit_time': commit_time,
            }
        commit_attrs.update(attrs.get(commit_num, {}))
        commit_obj = make_commit(**commit_attrs)

        # By default, increment the time by a lot. Out-of-order commits should
        # be closer together than this because their main cause is clock skew.
        commit_time = commit_attrs['commit_time'] + 100
        nums[commit_num] = commit_obj.id
        object_store.add_object(commit_obj)
        commits.append(commit_obj)

    return commits


def setup_warning_catcher():
    """Wrap warnings.showwarning with code that records warnings.

    Returns: A tuple (caught_warnings, restore_showwarning): a list that each
      shown warning's message is appended to, and a callable that reinstates
      the original warnings.showwarning.
    """

    caught_warnings = []
    original_showwarning = warnings.showwarning

    def custom_showwarning(*args, **kwargs):
        # The first positional argument is the warning message.
        caught_warnings.append(args[0])

    warnings.showwarning = custom_showwarning

    def restore_showwarning():
        warnings.showwarning = original_showwarning

    return caught_warnings, restore_showwarning