1import itertools
2import operator
3import os
4import pickle
5import sys
6import traceback
7import zlib
8from typing import (
9    Union,
10    Tuple,
11    Optional,
12    Dict,
13    cast,
14    Iterable,
15    Type,
16    BinaryIO,
17    FrozenSet,
18    Sequence,
19)
20
21from unrpa.errors import (
22    OutputDirectoryNotFoundError,
23    ErrorExtractingFile,
24    AmbiguousArchiveError,
25    UnknownArchiveError,
26)
27from unrpa.versions import official_rpa, alt, zix, unofficial_rpa
28from unrpa.versions.version import Version
29from unrpa.view import ArchiveView
30
# Type aliases describing the entries of an archive index.
#
# Simple part: (offset, length).
SimpleIndexPart = Tuple[int, int]
SimpleIndexEntry = Iterable[SimpleIndexPart]
# Complex part: (offset, length, prefix). UnRPA.normalise_entry converts simple
# parts to this form by appending an empty prefix (b"").
ComplexIndexPart = Tuple[int, int, bytes]
ComplexIndexEntry = Iterable[ComplexIndexPart]
# A part/entry as read from the archive, before normalisation: either form.
IndexPart = Union[SimpleIndexPart, ComplexIndexPart]
IndexEntry = Iterable[IndexPart]
39
40
class TreeNode:
    """A named node in a tree built from sequences of path components.

    Each item of ``children`` is one path, already split into components
    relative to this node. Paths sharing a first component are merged into a
    single child node, which is built recursively from the remaining
    components.

    ``itertools.groupby`` only merges *adjacent* items, so ``children`` must
    be ordered such that paths with the same first component are consecutive
    (for example by sorting the paths before splitting them).

    Attributes are ``name`` (this node's label) and ``children`` (a list of
    ``TreeNode``, empty for a leaf).
    """

    def __init__(self, name: str, children: Iterable[Sequence[str]]) -> None:
        self.name = name
        # An empty component sequence refers to this node itself, so there is
        # nothing to add for it; skipping it also keeps itemgetter(0) from
        # raising IndexError. A falsy ``children`` (e.g. an empty list) means
        # this node is a leaf.
        paths = (parts for parts in (children if children else ()) if parts)
        self.children = [
            TreeNode(
                first_component,
                # Strip the shared first component; a path consisting of only
                # that component contributes no descendants.
                [parts[1:] for parts in group if len(parts) > 1],
            )
            for first_component, group in itertools.groupby(
                paths, key=operator.itemgetter(0)
            )
        ]
60
61
class UnRPA:
    """Extraction tool for RPA archives.

    An instance is bound to one archive file and can extract its contents
    (``extract_files``), list them (``list_files``) or print them as a tree
    (``list_files_tree``).
    """

    name = "unrpa"

    # Message levels for ``log``. A message is shown only when the instance's
    # verbosity is strictly greater than the message's level.
    error = 0
    info = 1
    debug = 2

    # Built-in archive format handlers, in a fixed order.
    ordered_versions: Tuple[Type[Version], ...] = (
        *official_rpa.versions,
        *alt.versions,
        *zix.versions,
        *unofficial_rpa.versions,
    )

    # The same handlers as a set; combined with ``extra_versions`` per instance.
    provided_versions: FrozenSet[Type[Version]] = frozenset(ordered_versions)

    def __init__(
        self,
        filename: str,
        verbosity: int = -1,
        path: Optional[str] = None,
        mkdir: bool = False,
        version: Optional[Type[Version]] = None,
        continue_on_error: bool = False,
        offset_and_key: Optional[Tuple[int, int]] = None,
        extra_versions: FrozenSet[Type[Version]] = frozenset(),
    ) -> None:
        """Configure the tool for one archive.

        Args:
            filename: Path of the archive to read.
            verbosity: Messages with a level below this value are logged.
            path: Output directory; made absolute. Defaults to the current
                working directory when empty or ``None``.
            mkdir: Whether ``extract_files`` creates the output directory.
            version: Format handler class to use instead of auto-detection.
            continue_on_error: Log per-file extraction errors and carry on,
                rather than raising.
            offset_and_key: Index offset and deobfuscation key to use instead
                of asking the format handler to find them.
            extra_versions: Additional handler classes considered during
                auto-detection, on top of ``provided_versions``.
        """
        self.verbose = verbosity
        if path:
            self.path = os.path.abspath(path)
        else:
            self.path = os.getcwd()
        self.mkdir = mkdir
        self.version = version
        self.archive = filename
        self.continue_on_error = continue_on_error
        self.offset_and_key = offset_and_key
        self.tty = sys.stdout.isatty()
        self.versions = UnRPA.provided_versions | extra_versions

    def log(
        self,
        verbosity: int,
        human_message: str,
        machine_message: Optional[str] = None,
    ) -> None:
        """Print a message if the configured verbosity exceeds ``verbosity``.

        Error-level messages go to stderr, everything else to stdout.

        Args:
            verbosity: Level of this message (``error``, ``info``, ``debug``).
            human_message: Text shown on an interactive terminal.
            machine_message: Alternative text for non-terminal output.
        """
        # NOTE(review): because this condition requires ``self.tty``, nothing
        # is printed when stdout is not a terminal and the ``machine_message``
        # branch below can never be taken. Behaviour is kept as-is; confirm
        # whether non-tty output was meant to be logged.
        if self.tty and self.verbose > verbosity:
            print(
                human_message if self.tty else machine_message,
                file=sys.stderr if verbosity == UnRPA.error else sys.stdout,
            )

    def extract_files(self) -> None:
        """Extract every file in the archive into the output directory.

        Raises:
            OutputDirectoryNotFoundError: The output path is not a directory
                (after optionally creating it).
            ErrorExtractingFile: A file failed to extract and
                ``continue_on_error`` is not set.
        """
        self.log(UnRPA.error, f"Extracting files from {self.archive}.")
        if self.mkdir:
            self.make_directory_structure(self.path)
        if not os.path.isdir(self.path):
            raise OutputDirectoryNotFoundError(self.path)

        version = self.version() if self.version else self.detect_version()

        with open(self.archive, "rb") as archive:
            index = self.get_index(archive, version)
            total_files = len(index)
            for file_number, (path, data) in enumerate(index.items()):
                try:
                    # Paths come from the archive and are untrusted: refuse
                    # anything that would land outside the output directory.
                    self._ensure_inside_output_directory(path)
                    self.make_directory_structure(
                        os.path.join(self.path, os.path.split(path)[0])
                    )
                    file_view = self.extract_file(
                        path, data, file_number, total_files, archive
                    )
                    with open(os.path.join(self.path, path), "wb") as output_file:
                        version.postprocess(file_view, output_file)
                # Exception rather than BaseException, so KeyboardInterrupt
                # and SystemExit still stop the run when continuing on error.
                except Exception as error:
                    if not self.continue_on_error:
                        raise ErrorExtractingFile(traceback.format_exc()) from error
                    self.log(
                        UnRPA.error,
                        f"Error extracting from the archive, but directed to continue on error. Detail: "
                        f"{traceback.format_exc()}.",
                    )

    def _ensure_inside_output_directory(self, path: str) -> None:
        """Raise ``ValueError`` if ``path`` would escape the output directory.

        The check is lexical (``..`` segments and absolute paths); it does not
        resolve symbolic links.
        """
        root = os.path.abspath(self.path)
        destination = os.path.abspath(os.path.join(root, path))
        # commonpath itself raises ValueError when the paths cannot be
        # compared (e.g. different drives), which is also a rejection.
        if os.path.commonpath([root, destination]) != root:
            raise ValueError(
                f"Archive member {path!r} would be written outside of the "
                f"output directory {root!r}."
            )

    def list_files(self) -> None:
        """Print the path of every file in the archive, sorted, one per line."""
        self.log(UnRPA.info, f"Listing files in {self.archive}:")
        with open(self.archive, "rb") as archive:
            paths = sorted(self.get_index(archive))
        for path in paths:
            print(path)

    def list_files_tree(self) -> None:
        """Print the archive name followed by its contents drawn as a tree."""
        print(self.archive)
        for line in self.tree_lines():
            print(line)

    def tree(self) -> TreeNode:
        """Build a ``TreeNode`` hierarchy, rooted at the archive, of its paths."""
        with open(self.archive, "rb") as archive:
            # Sorted so that paths sharing a directory are adjacent, which is
            # what TreeNode's grouping relies on.
            paths = sorted(self.get_index(archive))
        return TreeNode(
            self.archive,
            # full_split yields components innermost-first; the tree wants
            # them outermost-first.
            [list(reversed(list(self.full_split(path)))) for path in paths],
        )

    @staticmethod
    def full_split(path: str) -> Iterable[str]:
        """Yield the components of ``path`` from the last to the first.

        Empty components (e.g. from a trailing separator) are skipped. If the
        path is rooted, the root itself is yielded as the final component.
        """
        while path:
            head, tail = os.path.split(path)
            if head == path:
                # os.path.split cannot reduce a bare root any further
                # (split("/") == ("/", "")); without this check the loop
                # would never terminate for rooted paths.
                yield head
                return
            if tail:
                yield tail
            path = head

    def tree_lines(
        self, current_node: Optional[TreeNode] = None, prefix: str = ""
    ) -> Iterable[str]:
        """Yield the lines of a tree drawing of ``current_node``'s descendants.

        Args:
            current_node: Node whose children are drawn; defaults to the tree
                of the whole archive.
            prefix: Text placed before every line, carrying the vertical
                guides of the enclosing levels.
        """
        if current_node is None:
            current_node = self.tree()
        # Every child but the last gets a tee and keeps the guide running.
        for child in current_node.children[:-1]:
            yield f"{prefix}├--- {child.name}"
            yield from self.tree_lines(child, f"{prefix}|    ")
        # The last child gets a corner and ends the guide.
        if current_node.children:
            child = current_node.children[-1]
            yield f"{prefix}└--- {child.name}"
            yield from self.tree_lines(child, f"{prefix}     ")

    def extract_file(
        self,
        name: str,
        data: ComplexIndexEntry,
        file_number: int,
        total_files: int,
        archive: BinaryIO,
    ) -> ArchiveView:
        """Log progress for one file and return a view onto its data.

        Args:
            name: Path of the file within the archive.
            data: Normalised index entry for the file.
            file_number: Zero-based position of the file, for progress output.
            total_files: Number of files in the archive, for progress output.
            archive: The open archive.

        Returns:
            An ``ArchiveView`` built from the first part of ``data``.
        """
        self.log(
            UnRPA.info, f"[{file_number / float(total_files):04.2%}] {name:>3}", name
        )
        # Only the first part of the entry is used.
        offset, length, start = next(iter(data))
        return ArchiveView(archive, offset, length, start)

    def make_directory_structure(self, name: str) -> None:
        """Create directory ``name``, and any missing parents, if absent."""
        self.log(UnRPA.debug, f"Creating directory structure: {name}")
        if not os.path.exists(name):
            # exist_ok covers the directory being created by someone else
            # between the check above and this call.
            os.makedirs(name, exist_ok=True)

    def get_index(
        self, archive: BinaryIO, version: Optional[Version] = None
    ) -> Dict[str, ComplexIndexEntry]:
        """Read the archive's index.

        Args:
            archive: The open archive.
            version: Format handler to use; chosen as in ``extract_files``
                (explicit version, else auto-detected) when omitted.

        Returns:
            A mapping from file path, using the platform's separator, to the
            file's normalised (and, if keyed, deobfuscated) index entry.
        """
        if not version:
            version = self.version() if self.version else self.detect_version()

        offset: int
        key: Optional[int]
        if self.offset_and_key:
            offset, key = self.offset_and_key
        else:
            offset, key = version.find_offset_and_key(archive)
        archive.seek(offset)
        # SECURITY: the index is a pickle, and unpickling can execute
        # arbitrary code. Only open archives from sources you trust.
        index: Dict[bytes, IndexEntry] = pickle.loads(
            zlib.decompress(archive.read()), encoding="bytes"
        )
        if key is not None:
            normal_index = UnRPA.deobfuscate_index(key, index)
        else:
            normal_index = UnRPA.normalise_index(index)

        return {
            UnRPA.ensure_str_path(path).replace("/", os.sep): data
            for path, data in normal_index.items()
        }

    def detect_version(self) -> Version:
        """Work out which format handler matches the archive.

        Each known handler is asked whether it recognises the archive's
        lower-cased file extension and first line.

        Raises:
            AmbiguousArchiveError: More than one handler matched.
            UnknownArchiveError: No handler matched.
        """
        ext = os.path.splitext(self.archive)[1].lower()
        with open(self.archive, "rb") as f:
            header = f.readline()
        detected = {
            candidate
            for candidate in (version() for version in self.versions)
            if candidate.detect(ext, header)
        }
        if len(detected) > 1:
            raise AmbiguousArchiveError(detected)
        if not detected:
            raise UnknownArchiveError(header)
        return next(iter(detected))

    @staticmethod
    def ensure_str_path(path: Union[str, bytes]) -> str:
        """Return ``path`` as ``str``, decoding bytes as UTF-8.

        Undecodable bytes are replaced rather than raising.
        """
        if isinstance(path, str):
            return path
        else:
            return path.decode("utf-8", "replace")

    @staticmethod
    def deobfuscate_index(
        key: int, index: Dict[bytes, IndexEntry]
    ) -> Dict[bytes, ComplexIndexEntry]:
        """Return ``index`` with every entry normalised and deobfuscated."""
        return {
            path: UnRPA.deobfuscate_entry(key, entry) for path, entry in index.items()
        }

    @staticmethod
    def deobfuscate_entry(key: int, entry: IndexEntry) -> ComplexIndexEntry:
        """Normalise ``entry`` and XOR each part's offset and length with ``key``.

        The prefix of each part is left untouched.
        """
        return [
            (offset ^ key, length ^ key, start)
            for offset, length, start in UnRPA.normalise_entry(entry)
        ]

    @staticmethod
    def normalise_index(
        index: Dict[bytes, IndexEntry]
    ) -> Dict[bytes, ComplexIndexEntry]:
        """Return ``index`` with every entry normalised to three-item parts."""
        return {path: UnRPA.normalise_entry(entry) for path, entry in index.items()}

    @staticmethod
    def normalise_entry(entry: IndexEntry) -> ComplexIndexEntry:
        """Return the parts of ``entry`` as (offset, length, prefix) tuples.

        Two-item parts gain an empty prefix; other parts are passed through
        unchanged.
        """
        return [
            (*cast(SimpleIndexPart, part), b"")
            if len(part) == 2
            else cast(ComplexIndexPart, part)
            for part in entry
        ]
278