1import os
2import shutil
3import tarfile
4import tempfile
5import zipfile
6
7import numpy as np
8from PyQt5.QtCore import QDir
9from PyQt5.QtWidgets import QFileDialog, QMessageBox
10
11from urh.signalprocessing.IQArray import IQArray
12
# Maps the path of every file extracted from an archive to the archive it came from:
#     archives[extracted_filename] = archive_filename
# Filled by uncompress_archives(); save_data() looks files up here to rebuild their archive.
archives = {}  # type: dict

# Start directory of the file dialogs; ask_save_file_name() moves it to the last save location.
RECENT_PATH = QDir.homePath()
18
# Extension appended to a signal file name when saving IQ samples of the given numpy dtype.
SIGNAL_FILE_EXTENSIONS_BY_TYPE = {
    np.int8: ".complex16s",
    np.uint8: ".complex16u",
    np.int16: ".complex32s",
    np.uint16: ".complex32u",
    np.float32: ".complex",
    np.complex64: ".complex",
}

# Qt name filter per dtype; the integer formats also accept the short .cs8/.cu8/.cs16/.cu16 aliases.
SIGNAL_NAME_FILTERS_BY_TYPE = {
    np.int8: "Complex16 signed (*.complex16s *.cs8)",
    np.uint8: "Complex16 unsigned (*.complex16u *.cu8)",
    np.uint16: "Complex32 unsigned (*.complex32u *.cu16)",
    np.int16: "Complex32 signed (*.complex32s *.cs16)",
    np.float32: "Complex (*.complex)",
    np.complex64: "Complex (*.complex)",
}

# Catch-all filter.
EVERYTHING_FILE_FILTER = "All Files (*)"

# Distinct signal filters (float32 and complex64 share one entry), alphabetically sorted.
SIGNAL_NAME_FILTERS = sorted(set(SIGNAL_NAME_FILTERS_BY_TYPE.values()))

# Name filters for the other file types handled by the open/save dialogs.
COMPRESSED_COMPLEX_FILE_FILTER = "Compressed Complex File (*.coco)"
WAV_FILE_FILTER = "Waveform Audio File Format (*.wav *.wave)"
PROTOCOL_FILE_FILTER = "Protocol (*.proto.xml *.proto)"
BINARY_PROTOCOL_FILE_FILTER = "Binary Protocol (*.bin)"
PLAIN_BITS_FILE_FILTER = "Plain Bits (*.txt)"
FUZZING_FILE_FILTER = "Fuzzing Profile (*.fuzz.xml *.fuzz)"
SIMULATOR_FILE_FILTER = "Simulator Profile (*.sim.xml *.sim)"
TAR_FILE_FILTER = "Tar Archive (*.tar *.tar.gz *.tar.bz2)"
ZIP_FILE_FILTER = "Zip Archive (*.zip)"
50
51
def __get__name_filter_for_signals() -> str:
    """
    Build the Qt name-filter string offered for signal files.

    Order: the catch-all filter, every raw signal filter, then compressed complex and WAV.
    Entries are separated by ";;" as Qt expects.
    """
    parts = [EVERYTHING_FILE_FILTER]
    parts.extend(SIGNAL_NAME_FILTERS)
    parts.append(COMPRESSED_COMPLEX_FILE_FILTER)
    parts.append(WAV_FILE_FILTER)
    return ";;".join(parts)
54
55
def get_open_dialog(directory_mode=False, parent=None, name_filter="full") -> QFileDialog:
    """
    Create (without showing) a file dialog for opening files or a folder, starting in RECENT_PATH.

    :param directory_mode: if True the dialog selects a directory and ``name_filter`` is ignored
    :param parent: parent widget of the dialog
    :param name_filter: a preset ("full", "signals_only", "proto", "fuzz", "simulator");
                        any other value is passed to ``setNameFilter`` unchanged
    :return: the configured QFileDialog
    """
    dialog = QFileDialog(parent=parent, directory=RECENT_PATH)

    if directory_mode:
        dialog.setFileMode(QFileDialog.Directory)
        dialog.setWindowTitle("Open Folder")
        return dialog

    dialog.setFileMode(QFileDialog.ExistingFiles)
    dialog.setWindowTitle("Open Files")

    # Presets are built lazily so only the requested filter string is assembled.
    presets = {
        "full": lambda: ";;".join([__get__name_filter_for_signals(),
                                   PROTOCOL_FILE_FILTER, BINARY_PROTOCOL_FILE_FILTER, PLAIN_BITS_FILE_FILTER,
                                   FUZZING_FILE_FILTER, SIMULATOR_FILE_FILTER, TAR_FILE_FILTER, ZIP_FILE_FILTER]),
        "signals_only": __get__name_filter_for_signals,
        "proto": lambda: PROTOCOL_FILE_FILTER + ";;" + BINARY_PROTOCOL_FILE_FILTER,
        "fuzz": lambda: FUZZING_FILE_FILTER,
        "simulator": lambda: SIMULATOR_FILE_FILTER,
    }
    build_filter = presets.get(name_filter)
    dialog.setNameFilter(build_filter() if build_filter is not None else name_filter)

    return dialog
81
82
def ask_save_file_name(initial_name: str, caption="Save signal", selected_name_filter=None):
    """
    Run a "save" file dialog and return the path the user picked.

    The offered name filter is chosen from ``caption``; unknown captions fall back to
    EVERYTHING_FILE_FILTER. When a non-empty path is chosen, RECENT_PATH becomes its directory.

    :param initial_name: file name preselected in the dialog
    :param caption: dialog title; also determines the name filter
    :param selected_name_filter: name filter to preselect, or None
    :return: the selected path, or None when the dialog was cancelled
    """
    global RECENT_PATH

    if caption == "Save signal":
        name_filter = __get__name_filter_for_signals()
    else:
        filters_by_caption = {
            "Save fuzzing profile": FUZZING_FILE_FILTER,
            "Save encoding": "",
            "Save simulator profile": SIMULATOR_FILE_FILTER,
            "Export spectrogram": "Frequency Time (*.ft);;Frequency Time Amplitude (*.fta)",
            "Save protocol": ";;".join([PROTOCOL_FILE_FILTER, BINARY_PROTOCOL_FILE_FILTER]),
            "Export demodulated": WAV_FILE_FILTER,
        }
        name_filter = filters_by_caption.get(caption, EVERYTHING_FILE_FILTER)

    dialog = QFileDialog(directory=RECENT_PATH, caption=caption, filter=name_filter)
    dialog.setFileMode(QFileDialog.AnyFile)
    dialog.setLabelText(QFileDialog.Accept, "Save")
    dialog.setAcceptMode(QFileDialog.AcceptSave)

    if selected_name_filter is not None:
        dialog.selectNameFilter(selected_name_filter)
    dialog.selectFile(initial_name)

    chosen = dialog.selectedFiles()[0] if dialog.exec() else None
    if chosen:
        RECENT_PATH = os.path.dirname(chosen)
    return chosen
120
121
def _extensions_of_name_filter(name_filter: str) -> tuple:
    """Return the extensions listed in a Qt name filter, e.g. ``"Desc (*.a *.b)"`` -> ``(".a", ".b")``."""
    patterns = name_filter[name_filter.find("(") + 1:name_filter.rfind(")")]
    return tuple(pattern.lstrip("*") for pattern in patterns.split())


def ask_signal_file_name_and_save(signal_name: str, data, sample_rate=1e6, wav_only=False, parent=None) -> str:
    """
    Ask the user for a file name and save ``data`` there.

    If ``signal_name`` does not already end in a known signal extension, the extension matching
    ``data.dtype`` is appended and the corresponding name filter is preselected.

    :param signal_name: suggested file name
    :param data: IQ samples (anything ``save_data`` accepts); ``data.dtype`` selects the extension
    :param sample_rate: sample rate forwarded to ``save_data``
    :param wav_only: force a ``.wav``/``.wave`` name and preselect the WAV filter
    :param parent: parent widget for the error message box
    :return: the path the data was saved to, or None if cancelled or saving failed
    """
    if wav_only:
        if not signal_name.endswith((".wav", ".wave")):
            signal_name += ".wav"
        selected_name_filter = WAV_FILE_FILTER
    else:
        selected_name_filter = None
        # Compare against real extensions (parsed from the filters offered in the signal save
        # dialog), not against the filter strings themselves, which end with ")" and never match.
        signal_filters = list(SIGNAL_NAME_FILTERS_BY_TYPE.values()) + [COMPRESSED_COMPLEX_FILE_FILTER,
                                                                       WAV_FILE_FILTER]
        known_extensions = tuple(ext for name_filter in signal_filters
                                 for ext in _extensions_of_name_filter(name_filter))
        if not signal_name.endswith(known_extensions):
            dtype = next((d for d in SIGNAL_FILE_EXTENSIONS_BY_TYPE if d == data.dtype), None)
            if dtype is not None:
                signal_name += SIGNAL_FILE_EXTENSIONS_BY_TYPE[dtype]
                selected_name_filter = SIGNAL_NAME_FILTERS_BY_TYPE[dtype]

    filename = ask_save_file_name(signal_name, selected_name_filter=selected_name_filter)
    if not filename:
        return None

    try:
        save_data(data, filename, sample_rate=sample_rate)
    except Exception as e:
        # str(e) works for every exception; e.args[0] is missing for argument-less exceptions
        # and is an int errno for OSError, which the message box cannot display.
        QMessageBox.critical(parent, "Error saving signal", str(e) or type(e).__name__)
        return None

    return filename
150
151
def save_data(data, filename: str, sample_rate=1e6, num_channels=2):
    """
    Write IQ data to ``filename``, picking the on-disk format from the extension.

    ``.wav``/``.wave`` -> ``export_to_wav``; ``.coco`` -> ``save_compressed``; anything else -> ``tofile``.
    If ``filename`` was extracted from an archive (registered in ``archives``), that archive is
    rebuilt afterwards with ``rewrite_zip`` or ``rewrite_tar``.

    :param data: samples; wrapped in an IQArray if not already one
    :param filename: destination path
    :param sample_rate: sample rate used for WAV export
    :param num_channels: channel count used for WAV export
    """
    if not isinstance(data, IQArray):
        data = IQArray(data)

    # Both extensions listed in WAV_FILE_FILTER must produce a WAV file, not a raw dump.
    if filename.endswith((".wav", ".wave")):
        data.export_to_wav(filename, num_channels, sample_rate)
    elif filename.endswith(".coco"):
        data.save_compressed(filename)
    else:
        data.tofile(filename)

    archive = archives.get(filename)
    if archive is None:
        return

    if archive.endswith("zip"):
        rewrite_zip(archive)
    elif archive.endswith(("tar", "bz2", "gz")):
        rewrite_tar(archive)
169
170
def save_signal(signal):
    """Write ``signal``'s IQ samples back to ``signal.filename`` via save_data, using its sample rate."""
    samples = signal.iq_array.data
    save_data(samples, signal.filename, signal.sample_rate)
173
174
def rewrite_zip(zip_name):
    """
    Rebuild the zip archive ``zip_name`` from the files that were extracted from it.

    Every path registered for ``zip_name`` in ``archives`` is written into a new archive in a
    temporary directory, which is then moved over ``zip_name``.

    NOTE(review): ``ZipFile.write`` is called without ``arcname``. Members are therefore stored under
    their extracted (temporary-directory) path rather than their original archive names.
    Confirm whether this is intended.
    """
    files_in_archive = [path for path, archive in archives.items() if archive == zip_name]

    with tempfile.TemporaryDirectory() as tempdir:
        temp_name = os.path.join(tempdir, "new.zip")
        with zipfile.ZipFile(temp_name, "w") as zip_write:
            for filename in files_in_archive:
                zip_write.write(filename)
        # Replace the original only after the new archive is complete and closed.
        shutil.move(temp_name, zip_name)
186
187
def rewrite_tar(tar_name: str):
    """
    Rebuild the tar archive ``tar_name`` from the files that were extracted from it.

    The compression is inferred from the suffix of ``tar_name``: "gz", "bz2", or none. Every
    path registered for ``tar_name`` in ``archives`` is added to a new archive in a temporary
    directory, which is then moved over ``tar_name``.

    NOTE(review): ``TarFile.add`` is called without ``arcname``. Members are therefore stored under
    their extracted (temporary-directory) path rather than their original archive names.
    Confirm whether this is intended.
    """
    if tar_name.endswith("gz"):
        compression = "gz"
    elif tar_name.endswith("bz2"):
        compression = "bz2"
    else:
        compression = ""

    files_in_archive = [path for path, archive in archives.items() if archive == tar_name]

    with tempfile.TemporaryDirectory() as tempdir:
        suffix = "." + compression if compression else ""
        temp_name = os.path.join(tempdir, "new.tar" + suffix)
        with tarfile.open(temp_name, "w:" + compression) as tar_write:
            for path in files_in_archive:
                # Directory members are registered in `archives` individually, just like their files.
                # Adding directories recursively would store every contained file twice.
                tar_write.add(path, recursive=False)
        # Replace the original only after the new archive is complete and closed.
        shutil.move(temp_name, tar_name)
205
206
def _is_within_directory(directory: str, target: str) -> bool:
    """Return True if ``target`` resolves to ``directory`` itself or to a path below it."""
    directory = os.path.realpath(directory)
    target = os.path.realpath(target)
    try:
        return os.path.commonpath([directory, target]) == directory
    except ValueError:
        # e.g. paths on different Windows drives
        return False


def _tar_member_is_safe(member: tarfile.TarInfo, temp_dir: str) -> bool:
    """Return True if extracting ``member`` into ``temp_dir`` cannot write or link outside of it."""
    if not _is_within_directory(temp_dir, os.path.join(temp_dir, member.name)):
        return False
    if member.issym():
        # Symlink targets are relative to the directory that contains the link.
        link_target = os.path.join(temp_dir, os.path.dirname(member.name), member.linkname)
        return _is_within_directory(temp_dir, link_target)
    if member.islnk():
        # Hard-link targets are member names, i.e. relative to the archive root.
        return _is_within_directory(temp_dir, os.path.join(temp_dir, member.linkname))
    return True


def _extract_tar(filename: str, temp_dir: str) -> list:
    """Extract the safe members of tar archive ``filename`` into ``temp_dir``; return their paths."""
    extracted = []
    with tarfile.open(filename, "r") as tar:
        for member in tar.getmembers():
            if not _tar_member_is_safe(member, temp_dir):
                # Skip members that would escape temp_dir ("../x", absolute names, outward links).
                continue
            tar.extract(member, temp_dir)
            extracted_filename = os.path.join(temp_dir, member.name)
            extracted.append(extracted_filename)
            archives[extracted_filename] = filename
    return extracted


def _extract_zip(filename: str, temp_dir: str) -> list:
    """Extract every member of zip archive ``filename`` into ``temp_dir``; return their paths."""
    extracted = []
    with zipfile.ZipFile(filename) as zip_file:
        for info in zip_file.infolist():
            # ZipFile.extract sanitizes the member name and returns the path it actually wrote.
            extracted_filename = zip_file.extract(info, path=temp_dir)
            extracted.append(extracted_filename)
            archives[extracted_filename] = filename
    return extracted


def uncompress_archives(file_names, temp_dir):
    """
    Extract each archive from the list of filenames; non-archive files pass through untouched.

    Archives are recognised by their suffix (.tar, .tar.gz, .tar.bz2, .zip). Each extracted path
    is recorded in the module-level ``archives`` mapping (extracted path -> archive path), so
    saving the file later can rebuild the archive. Tar members that would be written or linked
    outside ``temp_dir`` are skipped.

    :type file_names: list of str
    :type temp_dir: str
    :return: extracted paths (in archive order) and untouched non-archive paths, in input order
    :rtype: list of str
    """
    result = []
    for filename in file_names:
        if filename.endswith((".tar", ".tar.gz", ".tar.bz2")):
            result.extend(_extract_tar(filename, temp_dir))
        elif filename.endswith(".zip"):
            result.extend(_extract_zip(filename, temp_dir))
        else:
            result.append(filename)

    return result
240
241
def get_directory():
    """
    Let the user choose an existing directory, starting in the home directory.

    Only directories are shown and symlinks are not resolved. Returns the value of
    ``QFileDialog.getExistingDirectory``; presumably an empty string on cancel (TODO confirm).
    """
    start_dir = QDir.homePath()
    options = QFileDialog.ShowDirsOnly | QFileDialog.DontResolveSymlinks
    return QFileDialog.getExistingDirectory(None, "Choose Directory", start_dir, options)
246