import h5py
import os
import shutil
from unittest.mock import patch

from ont_fast5_api.compression_settings import VBZ, GZIP
from ont_fast5_api.conversion_tools.check_file_compression import check_read_compression, check_compression
from ont_fast5_api.conversion_tools.compress_fast5 import compress_file, compress_single_read, compress_batch
from ont_fast5_api.fast5_file import Fast5File, EmptyFast5
from ont_fast5_api.fast5_info import ReadInfo
from ont_fast5_api.fast5_interface import get_fast5_file
from ont_fast5_api.multi_fast5 import MultiFast5File
from ont_fast5_api.static_data import OPTIONAL_READ_GROUPS
from test.helpers import TestFast5ApiHelper, test_data


class TestVbzReadWrite(TestFast5ApiHelper):
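    """Round-trip tests for VBZ-compressed raw signal, written both directly
    through h5py and via the ont_fast5_api file objects."""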
    run_id = "123abc"

    def test_write_vbz_directly(self):
        input_data = range(10)
        with h5py.File(os.path.join(self.save_path, 'h5py.fast5'), 'w') as fast5:
            fast5.create_dataset('Signal', data=input_data, **vars(VBZ))
            raw = fast5['Signal']

            self.assertIn(str(VBZ.compression), raw._filters)
            self.assertEqual(VBZ.compression_opts, raw._filters[str(VBZ.compression)])
            self.assertEqual(list(input_data), list(raw))

    def test_read_vbz_using_api(self):
        with MultiFast5File(os.path.join(test_data, 'vbz_reads', 'vbz_reads.fast5'), 'r') as fast5:
            read_count = 0
            for read in fast5.get_reads():
                # This input file was created to have 4 reads with 20 samples per read
                read_count += 1
                raw_data = read.get_raw_data()
                self.assertEqual(20, len(raw_data))
            self.assertEqual(4, read_count)

    def test_write_vbz_using_api(self):
        input_data = list(range(5))
        read_id = "0a1b2c3d"
        with MultiFast5File(self.generate_temp_filename(), 'w') as fast5:
            fast5.create_empty_read(read_id, self.run_id)
            read = fast5.get_read(read_id)
            read.add_raw_data(input_data, attrs={}, compression=VBZ)
            raw = read.get_raw_data()
            # First check the data comes back in an appropriate form
            self.assertEqual(input_data, list(raw))
            # Then check the types are as they should be under the hood
            filters = read.raw_compression_filters
            self.assertIn(str(VBZ.compression), filters)
            self.assertEqual(VBZ.compression_opts, filters[str(VBZ.compression)])

    def test_write_vbz_using_api_single_read(self):
        input_data = list(range(5))
        read_id = "0a1b2c3d"
        read_number = 0
        with Fast5File(self.generate_temp_filename(), 'w') as fast5:
            # Register the read metadata manually so add_raw_data has a read to attach the signal to
            fast5.status.read_number_map[read_number] = read_number
            fast5.status.read_info = [ReadInfo(read_number=read_number, read_id=read_id,
                                               start_time=0, duration=len(input_data))]
            fast5.add_raw_data(data=input_data, attrs={}, compression=VBZ)
            raw = fast5.get_raw_data()
            # First check the data comes back in an appropriate form
            self.assertEqual(input_data, list(raw))

            # Then check the types are as they should be under the hood
            filters = fast5.raw_compression_filters
            self.assertIn(str(VBZ.compression), filters)
            self.assertEqual(VBZ.compression_opts, filters[str(VBZ.compression)])


class TestVbzConvert(TestFast5ApiHelper):
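    """Tests for converting existing fast5 files to VBZ compression using the
    compress_fast5 conversion tools."""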
    run_id = "123abc"

    def assertCompressed(self, data_path, expected_compression, read_count, file_count):
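        """Assert every read under data_path uses expected_compression, and that the
        expected numbers of distinct reads and files were found."""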
        files = set()
        read_ids = set()
        for compression, read_id, filepath in check_compression(data_path, False, False, check_all_reads=True):
            self.assertEqual(expected_compression, compression)
            read_ids.add(read_id)
            files.add(filepath)
        self.assertEqual(read_count, len(read_ids))
        self.assertEqual(file_count, len(files))

    def test_add_read_from_multi(self):
        target_compression = VBZ
        with get_fast5_file(os.path.join(test_data, "multi_read", "batch_0.fast5"), "r") as input_f5, \
                MultiFast5File(self.generate_temp_filename(), 'w') as output_f5:
            read_id = input_f5.get_read_ids()[0]
            input_read = input_f5.get_read(read_id)

            # Input read should still have the default gzip compression on the way in:
            self.assertEqual(check_read_compression(input_read), GZIP)

            output_f5.add_existing_read(input_read, target_compression)

            output_read = output_f5.get_read(read_id)
            self.assertEqual(check_read_compression(output_read), VBZ)

    def test_compress_read_from_single(self):
        with get_fast5_file(os.path.join(test_data, "single_reads", "read0.fast5"), "r") as input_f5, \
                EmptyFast5(self.generate_temp_filename(), 'w') as output_f5:
            read_id = input_f5.get_read_ids()[0]
            input_read = input_f5.get_read(read_id)

            # Input read should still have the default gzip compression on the way in:
            self.assertEqual(check_read_compression(input_read), GZIP)

            compress_single_read(output_f5, input_read, target_compression=VBZ)

            output_read = output_f5.get_read(read_id)
            self.assertEqual(check_read_compression(output_read), VBZ)

    @patch('ont_fast5_api.conversion_tools.compress_fast5.get_progress_bar')
    def test_conversion_script_multi(self, mock_pbar):
        input_folder = os.path.join(test_data, 'multi_read')
        compress_batch(input_folder=input_folder, output_folder=self.save_path, target_compression=VBZ)
        self.assertCompressed(self.save_path, VBZ, read_count=4, file_count=1)

    @patch('ont_fast5_api.conversion_tools.compress_fast5.get_progress_bar')
    def test_conversion_script_single(self, mock_pbar):
        input_folder = os.path.join(test_data, 'single_reads')
        compress_batch(input_folder=input_folder, output_folder=self.save_path, target_compression=VBZ)
        self.assertCompressed(self.save_path, VBZ, read_count=4, file_count=4)

    @patch('ont_fast5_api.conversion_tools.compress_fast5.get_progress_bar')
    def test_compress_in_place(self, mock_pbar):
        for input_file in os.listdir(os.path.join(test_data, 'single_reads')):
            # We copy file by file because copytree won't copy into an existing directory
            shutil.copy(os.path.join(test_data, 'single_reads', input_file), self.save_path)

        self.assertCompressed(self.save_path, GZIP, read_count=4, file_count=4)
        in_files = set(os.listdir(self.save_path))
        compress_batch(self.save_path, output_folder=None, target_compression=VBZ, in_place=True)
        self.assertCompressed(self.save_path, VBZ, read_count=4, file_count=4)
        # In-place compression should rewrite the data without changing the set of filenames
        self.assertEqual(in_files, set(os.listdir(self.save_path)))


class TestSanitise(TestFast5ApiHelper):
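    """Tests that compressing with sanitize=True strips the groups listed in
    OPTIONAL_READ_GROUPS while leaving the remaining content intact."""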

    @staticmethod
    def list_groups(fname, single_multi='multi'):
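        """Return (all_groups, filtered_groups) for an HDF5 file, where filtered_groups
        excludes any entry under one of the OPTIONAL_READ_GROUPS.

        In multi-read files the optional groups sit one level below each read group,
        hence the different split index per layout."""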
        split_index = {'multi': 1, 'single': 0}
        all_groups = list()
        filtered_groups = list()

        def _add_group(name):
            all_groups.append(name)
            try:
                subgroup = name.split('/')[split_index[single_multi]]
            except IndexError:
                # Name is above the level at which optional groups appear, so keep it
                filtered_groups.append(name)
            else:
                if subgroup not in OPTIONAL_READ_GROUPS:
                    filtered_groups.append(name)

        with h5py.File(fname, 'r') as fh:
            fh.visit(_add_group)
        return all_groups, filtered_groups
    def _test(self, input_file, output_file, single_or_multi):
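        """The output should be the input minus its optional groups: the input had
        optional groups to strip, the output has none, and the non-optional groups
        are preserved unchanged."""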
        orig_all_groups, orig_filtered_groups = self.list_groups(input_file, single_or_multi)
        new_all_groups, new_filtered_groups = self.list_groups(output_file, single_or_multi)

        self.assertNotEqual(orig_all_groups, orig_filtered_groups)
        self.assertEqual(orig_filtered_groups, new_filtered_groups)
        self.assertEqual(new_all_groups, new_filtered_groups)

    def test_multi_to_multi(self):
        input_file = os.path.join(test_data, "multi_read_analyses", "batch_0.fast5")
        output_file = self.generate_temp_filename()
        compress_file(input_file, output_file, VBZ, sanitize=True)
        self._test(input_file, output_file, 'multi')

    def test_single_to_multi(self):
        input_file = os.path.join(test_data, "single_read_analyses", "read.fast5")
        output_file = self.generate_temp_filename()
        with Fast5File(input_file, 'r') as input_f5, \
                EmptyFast5(output_file, 'a') as output_f5:
            compress_single_read(output_f5, input_f5, VBZ, sanitize=True)
        self._test(input_file, output_file, 'single')