1"""Utilities for extracting common archive formats"""
2
3import zipfile
4import tarfile
5import os
6import shutil
7import posixpath
8import contextlib
9from distutils.errors import DistutilsError
10
11from pkg_resources import ensure_directory, ContextualZipFile
12
# Public names exported by a star-import of this module.
__all__ = [
    "unpack_archive",
    "unpack_zipfile",
    "unpack_tarfile",
    "default_filter",
    "UnrecognizedFormat",
    "extraction_drivers",
    "unpack_directory",
]
17
18
class UnrecognizedFormat(DistutilsError):
    """Signals that an archive's format could not be identified"""
21
22
def default_filter(src, dst):
    """The default progress/filter callback; accepts every item

    `src` is ignored and `dst` is handed back unchanged, so no item is
    skipped and no extraction path is altered.
    """
    return dst
26
27
def unpack_archive(filename, extract_dir, progress_filter=default_filter,
        drivers=None):
    """Unpack `filename` to `extract_dir`, or raise ``UnrecognizedFormat``

    `progress_filter` is a callable of two arguments: the item's path inside
    the archive ('/'-separated) and the filesystem path it is about to be
    extracted to.  It must return the path to extract to (possibly the very
    one it was given), or ``None`` to have that item skipped.  This makes it
    usable for progress reporting, for filtering what gets extracted, and for
    redirecting where items land.

    `drivers`, when given, is a sequence of functions taking the same
    arguments as this one (without `drivers`).  Each should raise
    ``UnrecognizedFormat`` when it cannot handle the archive.  They are tried
    in order and the first one that completes without raising
    ``UnrecognizedFormat`` ends the search; any other exception a driver
    raises propagates to the caller.  When `drivers` is omitted (or empty),
    the module-level ``extraction_drivers`` is used instead, which tries
    ``unpack_directory``, ``unpack_zipfile`` and ``unpack_tarfile``, in that
    order.  ``UnrecognizedFormat`` is raised if every driver declines.
    """
    for extract in drivers or extraction_drivers:
        try:
            extract(filename, extract_dir, progress_filter)
        except UnrecognizedFormat:
            # this driver does not handle the format; move on to the next
            continue
        return

    # every driver declined the archive
    raise UnrecognizedFormat(
        "Not a recognized archive type: %s" % filename
    )
61
62
def unpack_directory(filename, extract_dir, progress_filter=default_filter):
    """"Unpack" a directory, using the same interface as for archives

    The tree rooted at `filename` is walked and each file in it is copied
    (contents, then stat information) to the matching location under
    `extract_dir`.  `progress_filter` is called for files only, with a
    '/'-separated path relative to `filename` and the proposed destination;
    a false return value skips that file.

    Raises ``UnrecognizedFormat`` if `filename` is not a directory
    """
    if not os.path.isdir(filename):
        raise UnrecognizedFormat("%s is not a directory" % filename)

    # walked directory -> (archive-style prefix, destination directory)
    locations = {filename: ('', extract_dir)}

    for root, subdirs, filenames in os.walk(filename):
        prefix, dest_dir = locations[root]

        # register the children now so they can be looked up when walked
        for subdir in subdirs:
            locations[os.path.join(root, subdir)] = (
                prefix + subdir + '/',
                os.path.join(dest_dir, subdir),
            )

        for fname in filenames:
            destination = progress_filter(
                prefix + fname, os.path.join(dest_dir, fname))
            if not destination:
                # the filter asked for this file to be left out
                continue
            # ensure_directory comes from pkg_resources; presumably it
            # creates the destination's parent directory -- TODO confirm
            ensure_directory(destination)
            source = os.path.join(root, fname)
            shutil.copyfile(source, destination)
            shutil.copystat(source, destination)
88
89
def unpack_zipfile(filename, extract_dir, progress_filter=default_filter):
    """Unpack zip `filename` to `extract_dir`

    Raises ``UnrecognizedFormat`` if `filename` is not a zipfile (as determined
    by ``zipfile.is_zipfile()``).  See ``unpack_archive()`` for an explanation
    of the `progress_filter` argument.

    Entries whose names are absolute or contain a ``..`` component are
    skipped.  When the upper 16 bits of an entry's ``external_attr`` are
    non-zero they are applied to the extracted path with ``os.chmod()``.
    """
    if not zipfile.is_zipfile(filename):
        raise UnrecognizedFormat("%s is not a zip file" % (filename,))

    with ContextualZipFile(filename) as archive:
        for entry in archive.infolist():
            member_name = entry.filename
            segments = member_name.split('/')

            # refuse anything that could land outside of extract_dir
            if member_name.startswith('/') or '..' in segments:
                continue

            destination = progress_filter(
                member_name, os.path.join(extract_dir, *segments))
            if not destination:
                continue

            # needed for directory and file entries alike
            ensure_directory(destination)
            if not member_name.endswith('/'):
                # a file entry: write its contents out
                payload = archive.read(entry.filename)
                with open(destination, 'wb') as out:
                    out.write(payload)

            mode = entry.external_attr >> 16
            if mode:
                os.chmod(destination, mode)
125
126
def unpack_tarfile(filename, extract_dir, progress_filter=default_filter):
    """Unpack tar/tar.gz/tar.bz2 `filename` to `extract_dir`

    Raises ``UnrecognizedFormat`` if `filename` is not a tarfile (as determined
    by ``tarfile.open()``).  See ``unpack_archive()`` for an explanation
    of the `progress_filter` argument.

    Members with absolute names or ``..`` path components are skipped.  Hard
    and symbolic links are not recreated: each one is followed, inside the
    archive, to the member it refers to, and that member is extracted under
    the link's name.  Links that are dangling or that loop back on
    themselves are skipped, as is anything that does not resolve to a file
    or a directory.

    Returns ``True`` once the archive has been processed.
    """
    try:
        tarobj = tarfile.open(filename)
    except tarfile.TarError:
        raise UnrecognizedFormat(
            "%s is not a compressed or uncompressed tar file" % (filename,)
        )
    with contextlib.closing(tarobj):
        # don't do any chowning!
        tarobj.chown = lambda *args: None
        for member in tarobj:
            name = member.name
            # don't extract absolute paths or ones with .. in them
            if name.startswith('/') or '..' in name.split('/'):
                continue
            prelim_dst = os.path.join(extract_dir, *name.split('/'))

            # resolve any links and extract the link targets as normal files
            visited = set()
            while member is not None and (member.islnk() or member.issym()):
                linkpath = member.linkname
                if member.issym():
                    # symlink targets are relative to the link's directory
                    base = posixpath.dirname(member.name)
                    linkpath = posixpath.join(base, linkpath)
                    linkpath = posixpath.normpath(linkpath)
                if linkpath in visited:
                    # The chain of links has come back to a target already
                    # followed (e.g. "a -> a" or "a -> b -> a").  Treat it as
                    # unresolvable instead of looping forever.
                    member = None
                    break
                visited.add(linkpath)
                member = tarobj._getmember(linkpath)

            # only files and directories are extracted
            if member is None or not (member.isfile() or member.isdir()):
                continue

            final_dst = progress_filter(name, prelim_dst)
            if not final_dst:
                continue
            if final_dst.endswith(os.sep):
                final_dst = final_dst[:-1]
            try:
                # XXX Ugh -- private tarfile API, used so the resolved member
                # can be written to a destination path of our choosing
                tarobj._extract_member(member, final_dst)
            except tarfile.ExtractError:
                # chown/chmod/mkfifo/mknode/makedev failed
                pass
        return True
171
172
# Drivers that ``unpack_archive()`` falls back on, in the order they are tried.
extraction_drivers = (
    unpack_directory,
    unpack_zipfile,
    unpack_tarfile,
)
174