1"""
2Copyright 2019 Brain Electrophysiology Laboratory Company LLC
3
Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this module except in compliance with the License.
6You may obtain a copy of the License at:
7
http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an
12"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
13ANY KIND, either express or implied.
14"""
15from datetime import datetime
16from typing import Tuple, Dict, List
17
18from deprecated import deprecated
19import numpy as np
20
21from .cached_property import cached_property
22
23from . import xml_files
24from .xml_files import XML, Categories, Epochs
25from . import bin_files
26from .mffdir import get_directory
27from base64 import b64encode
28
29
def object_to_bytes(object, encoding='utf-8'):
    """
    Serialize the string representation of an object to raw bytes.

    The object is first rendered with `str()`, and the resulting text is
    then encoded with the requested codec.
    :param object: The value whose `str()` form should be encoded.
    :param encoding: Name of the text encoding to apply.
                     Defaults to 'utf-8'.
    :return: `bytes` holding the encoded string form of `object`.
    """
    text = str(object)
    return text.encode(encoding)
40
41
class Reader:
    """
    Create an .mff reader

    class `Reader` is the main entry point to `mffpy`'s functionality.

    :throws: ValueError if the passed filename
             does not point to a valid MFF file.

    Example use:
    ```python
    import mffpy
    fo = mffpy.Reader('./examples/example_1.mff')
    fo.set_unit('EEG', 'uV')
    X = fo.get_physical_samples_from_epoch(
            fo.epochs[0], channels=['EEG'])
    ```
    """

    def __init__(self, filename: str):
        """open the MFF located at `filename`

        The directory handle returned by `get_directory` is stored in
        `self.directory`; all other accessors read their files through it.
        """
        self.directory = get_directory(filename)

    @cached_property
    @deprecated(version='0.6.3', reason='Use ".mff_flavor" instead')
    def flavor(self) -> str:
        """deprecated.  Use `.mff_flavor` instead

        Returns the flavor reported by 'history.xml' if that file exists,
        and 'continuous' otherwise.
        """
        if 'history.xml' in self.directory.listdir():
            with self.directory.filepointer('history') as fp:
                history = XML.from_file(fp)
                return history.mff_flavor()
        else:
            return 'continuous'

    @cached_property
    def mff_flavor(self) -> str:
        """returns flavor of the MFF

        The flavor is either 'continuous' or 'segmented'.  A file has flavor
        'segmented' if 'categories.xml' exists.
        """
        if 'categories.xml' in self.directory.listdir():
            return 'segmented'

        return 'continuous'

    @cached_property
    def categories(self) -> Categories:
        """
        ```python
        Reader.categories
        ```
        categories present in a loaded MFF file

        Return dictionary of categories names and the segments of
        data associated with each category. If this is a continuous
        MFF file, this method results in a ValueError.

        An `AssertionError` is raised if 'categories.xml' does not parse
        into an `xml_files.Categories` object.
        """
        with self.directory.filepointer('categories') as fp:
            categories = XML.from_file(fp)
        assert isinstance(categories, xml_files.Categories), f"""
            .xml file 'categories.xml' of wrong type {type(categories)}"""
        return categories

    @cached_property
    def epochs(self) -> Epochs:
        """
        ```python
        Reader.epochs
        ```
        return all epochs in MFF file

        Return a list of `epoch.Epoch` objects containing information
        about each epoch in the MFF file. If categories information
        is present, this is used to fill in `Epoch.name` for each epoch.
        """
        with self.directory.filepointer('epochs') as fp:
            epochs = XML.from_file(fp)
        assert isinstance(epochs, xml_files.Epochs), f"""
            .xml file 'epochs.xml' of wrong type {type(epochs)}"""
        # Attempt to add category names to the `Epoch` objects in `epochs`.
        # Missing or malformed categories are not fatal: the epochs are
        # returned without category names in that case.
        try:
            categories = self.categories
        except (ValueError, AssertionError):
            print('categories.xml not found or of wrong type. '
                  '`Epoch.name` will default to "epoch" for all epochs.')
            return epochs
        # Let the epochs pick up their names from the category information
        epochs.associate_categories(categories)
        return epochs

    @cached_property
    def sampling_rates(self) -> Dict[str, float]:
        """
        ```python
        Reader.sampling_rates
        ```
        sampling rates by channel type

        Return dictionary of sampling rate by channel type.  Each
        sampling rate is returned in Hz as a float.
        """
        return {
            fn: bin_file.sampling_rate
            for fn, bin_file in self._blobs.items()
        }

    @cached_property
    def durations(self) -> Dict[str, float]:
        """
        ```python
        Reader.durations
        ```
        recorded durations by channel type

        Return dictionary of duration by channel type.  Each
        duration is returned in seconds as a float.
        """
        return {
            fn: bin_file.duration
            for fn, bin_file in self._blobs.items()
        }

    @cached_property
    def startdatetime(self) -> datetime:
        """
        ```python
        Reader.startdatetime
        ```
        UTC start date and time of the recording

        Return UTC start date and time of the recording.  The
        returned object is of type `datetime.datetime`.
        """
        with self.directory.filepointer('info') as fp:
            info = XML.from_file(fp)
        assert isinstance(info, xml_files.FileInfo), f"""
        .xml file 'info.xml' of wrong type {type(info)}"""
        return info.recordTime

    @property
    def units(self) -> Dict[str, str]:
        """
        ```python
        Reader.units
        ```

        Return dictionary of units by channel type.  Each unit is returned as a
        `str` of SI units (micro: `'u'`).

        This is deliberately not cached: the unit of a channel type can be
        changed at any time through `set_unit`.
        """
        return {
            fn: bin_file.unit
            for fn, bin_file in self._blobs.items()
        }

    @cached_property
    def num_channels(self) -> Dict[str, int]:
        """
        ```python
        Reader.num_channels
        ```

        Return dictionary of number of channels by channel type.  Each
        number is returned as an `int`.
        """
        return {
            fn: bin_file.num_channels
            for fn, bin_file in self._blobs.items()
        }

    # The readers must be cached.  `set_unit` and `set_calibration` assign
    # attributes on the `BinFile` objects held in this dictionary; if the
    # dictionary were rebuilt on every access, those assignments would land
    # on throw-away objects and never affect `units` or the returned samples.
    @cached_property
    def _blobs(self) -> Dict[str, bin_files.BinFile]:
        """return dictionary of `BinFile` data readers by signal type

        One `BinFile` is created per (signal, info) pair reported by
        `self.directory.signals_with_info()`.  The dictionary is built once
        per `Reader` and the same `BinFile` instances are returned on every
        subsequent access.
        """
        __blobs = {}
        for si in self.directory.signals_with_info():
            with self.directory.filepointer(si.info) as fp:
                info = XML.from_file(fp)
            bf = bin_files.BinFile(si.signal, info,
                                   info.generalInformation['channel_type'])
            __blobs[bf.signal_type] = bf
        return __blobs

    def set_unit(self, channel_type: str, unit: str):
        """set output units for a type of channels

        Set physical unit of a channel type.  The allowed conversion
        values for `unit` depend on the original unit.  We allow all
        combinations of conversions of 'V', 'mV', 'uV'.

        **Arguments**

        * **`channel_type`**: `str` with the channel type.
        * **`unit`**: `str` with the unit you would like to convert to.

        **Raises**

        * `KeyError` if `channel_type` is not among the available signals.

        **Example use**

        ```python
        import mffpy
        fo = mffpy.Reader('./examples/example_1.mff')
        fo.set_unit('EEG', 'uV')
        ```
        """
        self._blobs[channel_type].unit = unit

    def set_calibration(self, channel_type: str, cal: str):
        """set calibration of a channel type

        **Arguments**

        * **`channel_type`**: `str` with the channel type.
        * **`cal`**: calibration assigned to the `BinFile` reader of that
        channel type.

        **Raises**

        * `KeyError` if `channel_type` is not among the available signals.
        """
        self._blobs[channel_type].calibration = cal

    def get_physical_samples(self, t0: float = 0.0, dt: float = None,
                             channels: List[str] = None,
                             block_slice: slice = None
                             ) -> Dict[str, Tuple[np.ndarray, float]]:
        """return signal data in the range `(t0, t0+dt)` in seconds from `channels`

        Use `get_physical_samples_from_epoch` instead.

        **Arguments**

        * **`t0`**: `float` start of the requested range in seconds.
        * **`dt`**: `float` duration of the requested range in seconds, or
        `None`; the value is passed on to each `BinFile` unchanged.
        * **`channels`**: `list` of channel-type `str`.  `None` defaults to
        all available channel types.  Unknown types are silently skipped.
        * **`block_slice`**: `slice` passed on to each `BinFile`.

        **Returns**

        `dict` mapping each requested channel type to whatever
        `BinFile.get_physical_samples` returns for it.
        """
        if channels is None:
            channels = list(self._blobs.keys())

        return {
            typ: blob.get_physical_samples(t0, dt, block_slice=block_slice)
            for typ, blob in self._blobs.items()
            if typ in channels
        }

    def get_physical_samples_from_epoch(self, epoch: xml_files.Epoch,
                                        t0: float = 0.0, dt: float = None,
                                        channels: List[str] = None
                                        ) -> Dict[str,
                                                  Tuple[np.ndarray, float]]:
        """
        return samples and start time by channels of an epoch

        Returns a `dict` of tuples of [0] signal samples by channel names given
        and [1] the start time in seconds, with keys from the list `channels`.
        The samples will be in the range `(t0, t0+dt)` taken relative to
        `epoch.t0`.

        **Arguments**

        * **`epoch`**: `xml_files.Epoch` from which you would like to get data.

        * **`t0`**: `float` with relative offset in seconds into the data from
        epoch start.

        * **`dt`**: `float` with requested signal duration in seconds.  Value
        `None` defaults to maximum starting at `t0`.

        * **`channels`**: `list` of channel-type `str` each of which will be a
        key in the returned `dict`.  `None` defaults to all available channels.

        **Note**

        * The start time of the returned data is `epoch.t0` seconds from
        recording start.

        * Only the epoch data can be requested.  If you want to pad these, do
        it yourself.

        * No interpolation will be performed to correct for the fact that `t0`
        falls in between samples.

        **Example use**

        ```python
        import mffpy
        fo = mffpy.Reader('./examples/example_1.mff')
        X = fo.get_physical_samples_from_epoch(fo.epochs[0], t0, dt)
        eeg, t0_eeg = X['EEG']
        ```
        """
        assert isinstance(epoch, xml_files.Epoch), f"""
        argument epoch of type {type(epoch)} [requires {xml_files.Epoch}]"""
        assert t0 >= 0.0, "Only non-negative `t0` allowed [%s]" % t0
        # A `dt` that is negative or reaches past the end of the epoch is
        # replaced by `None`, i.e. "everything from `t0` onwards".
        dt = dt if dt is None or 0.0 <= dt < epoch.dt-t0 else None
        return self.get_physical_samples(
            t0, dt, channels, block_slice=epoch.block_slice)

    def get_mff_content(self):
        """return the content of an mff file.

        The output of this function is supposed to return a dictionary
        containing one serializable object per valid .xml file. Valid
        .xml files are those whose types belongs to one of the available
        XMLType sub-classes.

        For 'categories' content, every segment additionally receives an
        'eegData' field holding the base64-encoded string form of the EEG
        samples of that segment.  The keys 'samplingRate', 'durations',
        'units' and 'numChannels' are filled with the values of the 'EEG'
        channel type.

        **Returns**
        mff_content: dict
            dictionary containing the content of an mff file.
        """

        # Root tags corresponding to available XMLType sub-classes
        xml_root_tags = xml_files.XMLType.xml_root_tags()
        # Create the dictionary that will be returned by this function
        mff_content = {tag: {} for tag in xml_root_tags}

        # Iterate over existing .xml files
        for xmlfile in self.directory.files_by_type['.xml']:
            with self.directory.filepointer(xmlfile) as fp:
                # NOTE(review): any `KeyError` raised inside this `try` is
                # reported as an invalid .xml type, including lookups such
                # as `samples['EEG']` -- confirm this is intended.
                try:
                    obj = XML.from_file(fp)
                    content = obj.get_serializable_content()

                    if obj.xml_root_tag == 'categories':
                        # Add EEG data to each segment of each category
                        for category in content['categories'].values():
                            # Iterate over each segment
                            for segment in category:
                                # Multiply time values by 1e-6
                                # because "get_physical_samples" function
                                # expects time values to be in seconds.
                                t0 = segment['beginTime'] * 1e-6
                                dt = (segment['endTime'] -
                                      segment['beginTime']) * 1e-6
                                # Get samples from current segment; the
                                # returned start time is not needed here.
                                samples = self.get_physical_samples(
                                    t0=t0, dt=dt, channels=['EEG'])
                                eeg, _ = samples['EEG']
                                # Insert an EEG data field into each segment.
                                # Compress EEG data using a
                                # base64 encoding scheme.
                                segment['eegData'] = str(
                                    b64encode(object_to_bytes(eeg.tolist())),
                                    encoding='utf-8')

                    mff_content[obj.xml_root_tag] = content
                except KeyError as e:
                    print(f'{e} is not one of the valid .xml types')

        # Add extra info to the returned dictionary
        mff_content['samplingRate'] = self.sampling_rates['EEG']
        mff_content['durations'] = self.durations['EEG']
        mff_content['units'] = self.units['EEG']
        mff_content['numChannels'] = self.num_channels['EEG']

        return mff_content
376