1"""Utilities for extracting common archive formats"""
2
3import zipfile
4import tarfile
5import os
6import shutil
7import posixpath
8import contextlib
9from distutils.errors import DistutilsError
10
11from pkg_resources import ensure_directory
12
# Public API of this module, in the same order as originally declared.
__all__ = [
    "unpack_archive",
    "unpack_zipfile",
    "unpack_tarfile",
    "default_filter",
    "UnrecognizedFormat",
    "extraction_drivers",
    "unpack_directory",
]
17
18
class UnrecognizedFormat(DistutilsError):
    """Raised when an archive's type cannot be recognized.

    Extraction drivers raise this to signal "not my format"; ``unpack_archive``
    catches it and moves on to the next driver.
    """
21
22
def default_filter(src, dst):
    """The default progress/filter callback; extracts every entry unchanged.

    :param src: path of the entry inside the archive ('/'-separated); unused.
    :param dst: filesystem path the entry would be extracted to.
    :returns: `dst` itself, so no entry is skipped or relocated.
    """
    return dst
26
27
def unpack_archive(
        filename, extract_dir, progress_filter=default_filter,
        drivers=None):
    """Unpack `filename` to `extract_dir`, or raise ``UnrecognizedFormat``

    `progress_filter` is a function taking two arguments: a source path
    internal to the archive ('/'-separated), and a filesystem path where it
    will be extracted.  The callback must return the desired extract path
    (which may be the same as the one passed in), or else a false value such
    as ``None`` to skip that entry.  The callback can thus be used to report
    on the progress of the extraction, as well as to filter the items
    extracted or alter their extraction paths.

    `drivers`, if supplied, should be a non-empty sequence of functions with
    the same signature as this function (minus the `drivers` argument), that
    raise ``UnrecognizedFormat`` if they do not support extracting the
    designated archive type.  The drivers are tried in sequence: a driver
    raising ``UnrecognizedFormat`` causes the next one to be tried, the first
    driver that returns normally ends the search, and any other exception
    propagates to the caller.  If every driver rejects the archive,
    ``UnrecognizedFormat`` is raised.

    If `drivers` is not supplied (or is empty), the module's
    ``extraction_drivers`` constant is used, which means that
    ``unpack_directory``, ``unpack_zipfile`` and ``unpack_tarfile`` will be
    tried, in that order.
    """
    for driver in drivers or extraction_drivers:
        try:
            driver(filename, extract_dir, progress_filter)
        except UnrecognizedFormat:
            # This driver does not handle the format; try the next one.
            continue
        # The first driver that accepted the archive has extracted it.
        return
    raise UnrecognizedFormat(
        "Not a recognized archive type: %s" % filename
    )
62
63
def unpack_directory(filename, extract_dir, progress_filter=default_filter):
    """Copy the tree under `filename` into `extract_dir`, archive-style.

    Uses the same interface as the archive drivers so that a plain directory
    can be "unpacked" like any archive.  Every non-directory entry reported
    by ``os.walk`` is offered to `progress_filter` under its '/'-separated
    path relative to `filename`.  Each entry the filter keeps is copied,
    contents via ``shutil.copyfile`` and metadata via ``shutil.copystat``.
    Directories themselves are never passed to the filter.

    Raises ``UnrecognizedFormat`` if `filename` is not a directory
    """
    if not os.path.isdir(filename):
        raise UnrecognizedFormat("%s is not a directory" % filename)

    # Walked directory -> (its '/'-separated prefix inside the "archive",
    # the destination directory its entries map to).
    layout = {filename: ('', extract_dir)}
    for root, subdirs, entries in os.walk(filename):
        prefix, dest_root = layout[root]
        for sub in subdirs:
            layout[os.path.join(root, sub)] = (
                prefix + sub + '/', os.path.join(dest_root, sub))
        for entry in entries:
            target = progress_filter(
                prefix + entry, os.path.join(dest_root, entry))
            if not target:
                # the filter asked for this entry to be skipped
                continue
            ensure_directory(target)
            source = os.path.join(root, entry)
            shutil.copyfile(source, target)
            shutil.copystat(source, target)
89
90
def unpack_zipfile(filename, extract_dir, progress_filter=default_filter):
    """Unpack zip `filename` to `extract_dir`

    Raises ``UnrecognizedFormat`` if `filename` is not a zipfile (as determined
    by ``zipfile.is_zipfile()``).  See ``unpack_archive()`` for an explanation
    of the `progress_filter` argument.

    Entries whose names are absolute or contain a ``..`` component are
    skipped.  File contents are streamed to disk in chunks rather than read
    into memory whole.  When an entry's upper 16 ``external_attr`` bits are
    non-zero they are applied to the extracted path with ``os.chmod``.
    """
    if not zipfile.is_zipfile(filename):
        raise UnrecognizedFormat("%s is not a zip file" % (filename,))

    with zipfile.ZipFile(filename) as z:
        for info in z.infolist():
            name = info.filename

            # don't extract absolute paths or ones with .. in them
            if name.startswith('/') or '..' in name.split('/'):
                continue

            target = os.path.join(extract_dir, *name.split('/'))
            target = progress_filter(name, target)
            if not target:
                continue

            # Called for directory entries (trailing '/') and files alike.
            ensure_directory(target)
            if not name.endswith('/'):
                # Open this exact entry (not a same-named duplicate) and
                # stream it, so large members need not fit in memory.
                with z.open(info) as src, open(target, 'wb') as dst:
                    shutil.copyfileobj(src, dst)

            unix_attributes = info.external_attr >> 16
            if unix_attributes:
                os.chmod(target, unix_attributes)
126
127
def _resolve_tar_link(tarobj, member):
    """Follow hard/symbolic links from `member` to the entry they point at.

    Returns the first non-link member reached. Returns ``None`` if a link
    target is not found in the archive, or if the links form a cycle.
    """
    visited = set()
    while member is not None and (member.islnk() or member.issym()):
        if id(member) in visited:
            # e.g. "a -> a" or "a -> b", "b -> a": following it would loop
            # forever, so treat the link as unresolvable.
            return None
        visited.add(id(member))
        linkpath = member.linkname
        if member.issym():
            # symlink targets are interpreted relative to the link's own
            # directory; hard link names are used as-is
            base = posixpath.dirname(member.name)
            linkpath = posixpath.normpath(posixpath.join(base, linkpath))
        member = tarobj._getmember(linkpath)
    return member


def unpack_tarfile(filename, extract_dir, progress_filter=default_filter):
    """Unpack tar/tar.gz/tar.bz2 `filename` to `extract_dir`

    Raises ``UnrecognizedFormat`` if `filename` is not a tarfile (as determined
    by ``tarfile.open()``).  See ``unpack_archive()`` for an explanation
    of the `progress_filter` argument.

    Entries whose names are absolute or contain a ``..`` component are
    skipped.  Links are resolved within the archive and their targets are
    extracted as regular files or directories. Cyclic or dangling links are
    skipped, as are entries that are neither files nor directories.
    Ownership is never changed.  Returns ``True``.
    """
    try:
        tarobj = tarfile.open(filename)
    except tarfile.TarError as e:
        raise UnrecognizedFormat(
            "%s is not a compressed or uncompressed tar file" % (filename,)
        ) from e
    with contextlib.closing(tarobj):
        # don't do any chowning!
        tarobj.chown = lambda *args: None
        for member in tarobj:
            name = member.name
            # don't extract absolute paths or ones with .. in them
            if name.startswith('/') or '..' in name.split('/'):
                continue
            prelim_dst = os.path.join(extract_dir, *name.split('/'))

            # resolve any links so their targets are extracted as normal
            # files/directories under this entry's name
            resolved = _resolve_tar_link(tarobj, member)
            if resolved is None or not (resolved.isfile() or resolved.isdir()):
                continue

            final_dst = progress_filter(name, prelim_dst)
            if not final_dst:
                continue
            if final_dst.endswith(os.sep):
                final_dst = final_dst[:-1]
            try:
                # XXX Ugh -- relies on a private TarFile method
                tarobj._extract_member(resolved, final_dst)
            except tarfile.ExtractError:
                # chown/chmod/mkfifo/mknode/makedev failed; best-effort
                pass
        return True
173
174
# Default drivers tried, in this order, by ``unpack_archive`` when the caller
# does not pass a (non-empty) ``drivers`` sequence.
extraction_drivers = (unpack_directory, unpack_zipfile, unpack_tarfile)
176