1
2from __future__ import absolute_import
3from os.path import realpath
4from functools import partial
5
6from bup import client, git, vfs
7
8
# Last id handed out by _repo_id(); ids start at 1.
_next_repo_id = 0
# Maps a repository key (local path or remote address) to its id.
_repo_ids = {}

def _repo_id(key):
    """Return the process-local integer id for *key*.

    The same key always maps to the same id.  Each previously unseen key
    is assigned the next integer, starting at 1.
    """
    # Only the counter is rebound; _repo_ids is merely mutated and needs
    # no global declaration.
    global _next_repo_id
    repo_id = _repo_ids.get(key)
    # Compare against None rather than truthiness, so that a stored id
    # is never mistaken for a cache miss.
    if repo_id is not None:
        return repo_id
    _next_repo_id += 1
    _repo_ids[key] = _next_repo_id
    return _next_repo_id
20
class LocalRepo:
    """Repository access backed by a directory on the local filesystem."""

    def __init__(self, repo_dir=None):
        # Use git.repo() when no directory is given, then canonicalize the
        # path so that different spellings of one directory share an id.
        path = realpath(repo_dir or git.repo())
        self.repo_dir = path
        self._cp = git.cp(path)
        self.update_ref = partial(git.update_ref, repo_dir=path)
        self.rev_list = partial(git.rev_list, repo_dir=path)
        self._id = _repo_id(path)

    def close(self):
        """Do nothing.

        This exists so that LocalRepo offers the same close() and
        context-manager interface as RemoteRepo.
        """
        return None

    def __del__(self):
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        # No return value (None), so exceptions from the with-body propagate.
        self.close()

    def id(self):
        """Return this repository's process-local integer id.

        Repositories opened on the same resolved directory get the same
        id; any other repository gets a different one.
        """
        return self._id

    def is_remote(self):
        """Return False: this repository lives on the local filesystem."""
        return False

    def new_packwriter(self, compression_level=1,
                       max_pack_size=None, max_pack_objects=None):
        """Return a git.PackWriter that writes packs into this repository."""
        limits = dict(max_pack_size=max_pack_size,
                      max_pack_objects=max_pack_objects)
        return git.PackWriter(repo_dir=self.repo_dir,
                              compression_level=compression_level,
                              **limits)

    def cat(self, ref):
        """Yield an (oidx, type, size) header for ref, then ref's data.

        When ref does not exist, the header is (None, None, None) and
        nothing follows it.
        """
        source = self._cp.get(ref)
        header = next(source)
        # Unpack before yielding so that a malformed header fails immediately.
        oidx, _, _ = header
        yield header
        if oidx:
            for chunk in source:
                yield chunk
        # The getter must have nothing left once the data (if any) is read.
        assert not next(source, None)

    def join(self, ref):
        """Delegate to the git.cp object's join() for ref."""
        return self._cp.join(ref)

    def refs(self, patterns=None, limit_to_heads=False, limit_to_tags=False):
        """Yield the entries git.list_refs() reports for this repository."""
        listing = git.list_refs(patterns=patterns,
                                limit_to_heads=limit_to_heads,
                                limit_to_tags=limit_to_tags,
                                repo_dir=self.repo_dir)
        for entry in listing:
            yield entry

    # vfs.resolve() is handed this repo object, so the vfs must not call
    # back into this method (presumably it would recurse without end).
    def resolve(self, path, parent=None, want_meta=True, follow=True):
        """Resolve path within this repository via vfs.resolve()."""
        ## FIXME: mode_only=?
        return vfs.resolve(self, path, parent=parent,
                           want_meta=want_meta, follow=follow)
86
87
class RemoteRepo:
    """Repository access through a bup client.Client connection."""

    def __init__(self, address):
        self.address = address
        # Define the attribute before connecting, so that close() (and
        # therefore __del__) still works if client.Client() raises.
        self.client = None
        self.client = client.Client(address)
        self.new_packwriter = self.client.new_packwriter
        self.update_ref = self.client.update_ref
        self.rev_list = self.client.rev_list
        self._id = _repo_id(self.address)

    def close(self):
        """Close the client connection; subsequent calls do nothing."""
        # getattr also covers instances whose __init__ never ran.
        conn = getattr(self, 'client', None)
        if conn is not None:
            # Drop the reference first, so that a close() which raises is
            # not retried again from __del__.
            self.client = None
            conn.close()

    def __del__(self):
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.close()

    def id(self):
        """Return this repository's process-local integer id.

        Repositories opened with the same address get the same id; any
        other repository gets a different one.
        """
        return self._id

    def is_remote(self):
        """Return True: this repository is accessed through a client."""
        return True

    def cat(self, ref):
        """If ref does not exist, yield (None, None, None).  Otherwise yield
        (oidx, type, size), and then all of the data associated with
        ref.

        """
        # Yield all the data here so that we don't finish the
        # cat_batch iterator (triggering its cleanup) until all of the
        # data has been read.  Otherwise we'd be out of sync with the
        # server.
        items = self.client.cat_batch((ref,))
        oidx, typ, size, it = info = next(items)
        yield info[:-1]
        if oidx:
            for data in it:
                yield data
        assert not next(items, None)

    def join(self, ref):
        """Delegate to the client's join() for ref."""
        return self.client.join(ref)

    def refs(self, patterns=None, limit_to_heads=False, limit_to_tags=False):
        """Yield the refs the client reports for the remote repository."""
        for ref in self.client.refs(patterns=patterns,
                                    limit_to_heads=limit_to_heads,
                                    limit_to_tags=limit_to_tags):
            yield ref

    def resolve(self, path, parent=None, want_meta=True, follow=True):
        """Resolve path on the remote side via the client's resolve()."""
        ## FIXME: mode_only=?
        return self.client.resolve(path, parent=parent, want_meta=want_meta,
                                   follow=follow)
151