1import logging
2import warnings
3import xml.etree.ElementTree as ET
4from datetime import datetime
5from collections import defaultdict
6import numpy as np
7from typing import Tuple, Dict, List, Any, Union, IO
8from .cached_property import cached_property
9from .dict2xml import TEXT, ATTR
10from .epoch import Epoch
11import copy
12"""
13Copyright 2019 Brain Electrophysiology Laboratory Company LLC
14
Licensed under the Apache License, Version 2.0 (the "License");
16you may not use this module except in compliance with the License.
17You may obtain a copy of the License at:
18
http://www.apache.org/licenses/LICENSE-2.0
20
21Unless required by applicable law or agreed to in writing, software
22distributed under the License is distributed on an
23"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
24ANY KIND, either express or implied.
25"""
26
27"""Parsing for all xml files"""
28
29
# Type of the argument handed to `XMLType.from_file` (and from there to
# `ET.parse`): either a `str` or a binary file object.
FilePointer = Union[str, IO[bytes]]
31
32
class XMLType(type):
    """`XMLType` registers all xml types

    Spawn the _right_ XMLType sub-class with `from_file`.  To be registered,
    sub-classes need to implement class attributes `_xmlns` and
    `_xmlroottag`."""

    # maps the namespaced root tag (`_xmlns + _xmlroottag`) to the xml type
    _registry: Dict[str, Any] = {}
    # maps the bare root tag (`_xmlroottag`) to the xml type
    _tag_registry: Dict[str, Any] = {}
    _logger = logging.getLogger(name='XMLType')
    _extensions = ['.xml', '.XML']
    _supported_versions: Tuple[str, ...] = ('',)
    _time_format = "%Y-%m-%dT%H:%M:%S.%f%z"

    def __new__(typ, name, bases, dct):
        """create the new class and attempt to register it"""
        new_xml_type = super().__new__(typ, name, bases, dct)
        typ.register(new_xml_type)
        return new_xml_type

    @classmethod
    def register(typ, xml_type):
        """add `xml_type` to the registries and return `True` on success

        Returns `False` (and logs at info level) if `xml_type` does not
        provide usable `_xmlns` and `_xmlroottag` attributes.  An entry
        that already exists for the same namespaced root tag is replaced
        after logging a warning.
        """
        try:
            ns_tag = xml_type._xmlns + xml_type._xmlroottag
            if ns_tag in typ._registry:
                # `Logger.warn` is a deprecated alias of `Logger.warning`
                typ._logger.warning("overwritting %s in registry",
                                    typ._registry[ns_tag])
            typ._registry[ns_tag] = xml_type
            # add another key for the same type
            typ._tag_registry[xml_type._xmlroottag] = xml_type
            return True
        except (AttributeError, TypeError):
            typ._logger.info("type %s cannot be registered", xml_type)
            return False

    @classmethod
    def from_file(typ, filepointer: "FilePointer"):
        """return new `XMLType` instance of the appropriate sub-class

        **Parameters**
        *filepointer*: str or IO[bytes]
            pointer to the xml file

        **Raises**
        *KeyError*: if the root tag of the file is not in the registry
        """
        xml_root = ET.parse(filepointer).getroot()
        return typ._registry[xml_root.tag](xml_root)

    @classmethod
    def todict(typ, xmltype, **kwargs) -> Dict[str, Any]:
        """return dict of `kwargs` specific for `xmltype`

        The output of this function is supposed to fit perfectly into
        `mffpy.json2xml.dict2xml` returning a valid .xml file for the specific
        xml type.  For this, each of these types needs to implement a method
        `content` that takes `**kwargs` as argument.

        `xmltype` is the bare root tag of a registered type.  The returned
        dict has the keys 'content', 'rootname', 'filename' and 'namespace'.
        """
        assert xmltype in typ._tag_registry, f"""
        {xmltype} is not one of the valid .xml types:
        {typ._tag_registry.keys()}"""
        T = typ._tag_registry[xmltype]
        return {
            'content': T.content(**kwargs),
            'rootname': T._xmlroottag,
            'filename': T._default_filename,
            # remove the '}'/'{' characters
            'namespace': T._xmlns[1:-1]
        }

    @classmethod
    def xml_root_tags(cls):
        """return list of root tags of supported xml files"""
        return list(cls._tag_registry.keys())
103
104
class XML(metaclass=XMLType):
    """base class of the parsers for the individual xml file types

    Sub-classes define the class attributes `_xmlns` and `_xmlroottag`.
    `find`, `findall` and `nsstrip` use `_xmlns` to address namespaced
    elements, by default below `self.root`.
    """

    _default_filename: str = ''

    def __init__(self, xml_root):
        self.root = xml_root

    @classmethod
    def _parse_time_str(cls, txt):
        """return `datetime` object parsed from time string `txt`"""
        # convert time string "2003-04-17T13:35:22.000000-08:00"
        # to "2003-04-17T13:35:22.000000-0800" ..
        if txt.count(':') == 3:
            # drop the last colon, i.e. the one inside the UTC offset
            head, _, tail = txt.rpartition(':')
            txt = head + tail
        return datetime.strptime(txt, cls._time_format)

    @classmethod
    def _dump_datetime(cls, dt):
        """return time string of timezone-aware `datetime` object `dt`

        The UTC offset is written with a colon, e.g. "-08:00".
        """
        assert dt.tzinfo is not None, f"""
        Timezone required for date/time {dt}"""
        txt = dt.strftime(cls._time_format)
        return txt[:-2] + ':' + txt[-2:]

    def find(self, tag, root=None):
        """return first child `tag` of `root` (default: `self.root`)"""
        # Compare with `None` explicitly: an element without children is
        # falsy, so `root or self.root` would silently search the document
        # root instead of the element that was passed in.
        root = self.root if root is None else root
        return root.find(self._xmlns+tag)

    def findall(self, tag, root=None):
        """return all children `tag` of `root` (default: `self.root`)"""
        root = self.root if root is None else root
        return root.findall(self._xmlns+tag)

    def nsstrip(self, tag):
        """return `tag` without the leading `len(self._xmlns)` characters"""
        return tag[len(self._xmlns):]

    @property
    def xml_root_tag(self):
        """return the root tag of this xml type"""
        return self._xmlroottag

    @classmethod
    def content(typ, *args, **kwargs):
        """checks and returns `**kwargs` as a formatted dict"""
        raise NotImplementedError(f"""
        json converter not implemented for type {typ}""")
147
148
class FileInfo(XML):
    """Parser for the file information xml (default name 'info.xml')"""

    _xmlns = '{http://www.egi.com/info_mff}'
    _xmlroottag = 'fileInfo'
    _default_filename = 'info.xml'
    _supported_versions = ('3',)

    def _text_of(self, tag):
        """return text of child `tag` of the root, `None` if it is missing"""
        el = self.find(tag)
        return None if el is None else el.text

    @cached_property
    def mffVersion(self):
        """return content of mffVersion field"""
        return self._text_of('mffVersion')

    @property
    def version(self):
        """return mffVersion"""
        # `stacklevel=2` attributes the warning to the code accessing
        # `.version` instead of to this property.
        warnings.warn(".version is deprecated, use .mffVersion instead",
                      DeprecationWarning, stacklevel=2)
        return self.mffVersion

    @cached_property
    def acquisitionVersion(self):
        """return content of acquisitionVersion field"""
        return self._text_of('acquisitionVersion')

    @cached_property
    def ampType(self):
        """return content of ampType field"""
        return self._text_of('ampType')

    @cached_property
    def recordTime(self):
        """return recordTime field as `datetime`, `None` if it is missing"""
        el = self.find('recordTime')
        return self._parse_time_str(el.text) if el is not None else None

    @classmethod
    def content(cls, recordTime: datetime,  # type: ignore
                mffVersion: str = '3',
                acquisitionVersion: str = None,
                ampType: str = None) -> dict:
        """returns MFF file information

        Only Version '3' is supported.

        **Parameters**

        *recordTime*: timezone-aware time of the recording
        *mffVersion*: has to be one of `_supported_versions`
        *acquisitionVersion*: only added to the output if truthy
        *ampType*: only added to the output if truthy
        """
        assert mffVersion in cls._supported_versions, f"""
        version {mffVersion} not supported"""
        content = {
            'mffVersion': {
                TEXT: mffVersion
            },
            'recordTime': {
                TEXT: cls._dump_datetime(recordTime)
            }
        }
        if acquisitionVersion:
            content.update(acquisitionVersion={TEXT: acquisitionVersion})

        if ampType:
            content.update(ampType={TEXT: ampType})

        return content

    def get_content(self):
        """return dictionary of MFF file information

        The info are comprised of:
        - MFF version number
        - time of start of recording
        - acquisition version (optional)
        - amplifier type (optional)
        """
        content = {
            'mffVersion': self.mffVersion,
            'recordTime': self.recordTime
        }
        if self.acquisitionVersion:
            content.update(acquisitionVersion=self.acquisitionVersion)

        if self.ampType:
            content.update(ampType=self.ampType)

        return content

    def get_serializable_content(self):
        """return serializable dictionary of MFF file information

        'recordTime' is converted to a time string; it stays `None` when
        the file has no recordTime field.
        """
        content = copy.deepcopy(self.get_content())
        # `recordTime` is `None` for files without that field and
        # `_dump_datetime` cannot handle `None`
        if content['recordTime'] is not None:
            content['recordTime'] = XML._dump_datetime(content['recordTime'])
        return content
238
239
class DataInfo(XML):
    """Parser for the data information xml (default name 'info1.xml')"""

    _xmlns = r'{http://www.egi.com/info_n_mff}'
    _xmlroottag = r'dataInfo'
    _default_filename = 'info1.xml'

    @cached_property
    def generalInformation(self):
        """return dict describing the first child of <fileDataType>

        'channel_type' holds the tag of that child; every element below it
        is added as `tag: text`.
        """
        el = self.find('fileDataType', self.find('generalInformation'))
        el = el[0]
        info = {}
        info['channel_type'] = self.nsstrip(el.tag)
        for el_i in el:
            info[self.nsstrip(el_i.tag)] = el_i.text
        return info

    @cached_property
    def filters(self):
        """return list of parsed filters, empty if there is no <filters>"""
        filters_el = self.find('filters')
        if filters_el is None:
            return []
        return [self._parse_filter(f) for f in filters_el]

    def _parse_filter(self, f):
        """return dict of the properties of filter element `f`

        'beginTime' is converted to `float`, 'method' and 'type' are kept
        as text, and 'cutoffFrequency' is a tuple `(float, units)`.
        """
        ans = {
            'beginTime': float(self.find('beginTime', f).text),
            'method': self.find('method', f).text,
            'type': self.find('type', f).text,
        }
        el = self.find('cutoffFrequency', f)
        ans['cutoffFrequency'] = (float(el.text), el.get('units'))
        return ans

    @cached_property
    def calibrations(self):
        """return dict mapping calibration type to parsed calibration"""
        calibrations = self.find('calibrations')
        ans = {}
        if calibrations is not None:
            for cali in calibrations:
                typ = self.find('type', cali)
                ans[typ.text] = self._parse_calibration(cali)
        return ans

    def _parse_calibration(self, cali):
        """return 'beginTime' and 'channels' of calibration element `cali`

        'channels' maps the integer attribute `n` of each channel element
        to its value as `np.float32`.
        """
        ans = {}
        ans['beginTime'] = float(self.find('beginTime', cali).text)
        ans['channels'] = {
            int(el.get('n')): np.float32(el.text)
            for el in self.find('channels', cali)
        }
        return ans

    @classmethod
    def content(cls, fileDataType: str,  # type: ignore
                dataTypeProps: dict = None,
                filters: List[dict] = None,
                calibrations: List[dict] = None) -> dict:
        """returns info on the associated (data) .bin file

        **Parameters**

        *fileDataType*: indicates the type of data
        *dataTypeProps*: indicates the recording device
        *filters*: lists all applied filters; each dict needs the keys
            'beginTime', 'method', 'type' and 'cutoffFrequency'
        *calibrations*: lists a number of calibrations; each dict needs the
            keys 'beginTime' and 'type', and may have a key 'channels'
            mapping channel number to calibration value (the same layout
            the `calibrations` property parses)
        """
        dataTypeProps = dataTypeProps or {}
        calibrations = calibrations or []
        filters = filters or []
        return {
            'generalInformation': {
                TEXT: {
                    'fileDataType': {
                        TEXT: {
                            fileDataType: {
                                TEXT: {
                                    k: {TEXT: v}
                                    for k, v in dataTypeProps.items()
                                }
                            }
                        }
                    }
                }
            },
            'filters': {
                'filter': [
                    {
                        'beginTime': {TEXT: f['beginTime']},
                        'method': {TEXT: f['method']},
                        'type': {TEXT: f['type']},
                        'cutoffFrequency': {TEXT: f['cutoffFrequency'],
                                            ATTR: {'units': 'Hz'}}
                    } for f in filters
                ]
            },
            'calibrations': {
                'calibration': [
                    {
                        'beginTime': {TEXT: cal['beginTime']},
                        'type': {TEXT: cal['type']},
                        'channels': {
                            # One <ch n="..."> entry per channel.  Iterating
                            # `cal.items()` and unpacking each value as a
                            # pair, as was done before, fails on the scalar
                            # 'beginTime'/'type' entries.
                            'ch': [
                                {
                                    TEXT: str(v),
                                    ATTR: {'n': str(n)}
                                } for n, v in cal.get('channels', {}).items()
                            ]
                        }
                    } for cal in calibrations
                ]
            }
        }

    def get_content(self):
        """return info on the associated (data) .bin file"""
        return {
            'generalInformation': self.generalInformation,
            'filters': self.filters,
            'calibrations': self.calibrations
        }

    def get_serializable_content(self):
        """return a serializable object containing
        info on the associated (data) .bin file"""
        content = copy.deepcopy(self.get_content())
        # Convert np.float32 values to float built-in type
        for value in content['calibrations'].values():
            channels = value['channels']
            for channel in channels.keys():
                channels[channel] = float(channels[channel])
        return content
371
372
class Patient(XML):
    """Parser for the patient xml (default name 'subject.xml')"""

    _xmlns = r'{http://www.egi.com/subject_mff}'
    _xmlroottag = r'patient'
    _default_filename = 'subject.xml'

    # converters for <data> text, selected by the attribute `dataType`;
    # data without that attribute are passed through unchanged
    _type_converter = {
        'string': str,
        None: lambda x: x
    }

    @cached_property
    def fields(self):
        """return dict mapping the name of each field to its data

        The dict is empty if the file has no <fields> element.
        """
        fields_el = self.find('fields')
        if fields_el is None:
            return {}
        ans = {}
        for field in fields_el:
            assert self.nsstrip(field.tag) == 'field', f"""
            Unknown field with tag {self.nsstrip(field.tag)}"""
            name = self.find('name', field).text
            data = self.find('data', field)
            data = self._type_converter[data.get('dataType')](data.text)
            ans[name] = data
        return ans

    @classmethod
    def content(cls, name, data, dataType='string'):
        """return a single patient field in xml-convertible format

        **Parameters**

        *name*: name of the field
        *data*: data of the field
        *dataType*: written as attribute `dataType` of the data
        """
        return {
            'fields': {
                'field': [
                    {
                        TEXT: {
                            'name': {TEXT: name},
                            'data': {TEXT: data,
                                     ATTR: {'dataType': dataType}}
                        }
                    }
                ]
            }
        }

    def get_content(self):
        """return patient related info"""
        return {
            'fields': self.fields
        }

    def get_serializable_content(self):
        """return a serializable object
        containing patient related info"""
        return copy.deepcopy(self.get_content())
422
423
class SensorLayout(XML):
    """Parser for the sensor layout xml (default name 'sensorLayout.xml')"""

    _xmlns = r'{http://www.egi.com/sensorLayout_mff}'
    _xmlroottag = r'sensorLayout'
    _default_filename = 'sensorLayout.xml'

    # converters applied to the text of the children of a <sensor>
    _type_converter = {
        'name': str,
        'number': int,
        'type': int,
        'identifier': int,
        'x': np.float32,
        'y': np.float32,
        'z': np.float32,
        'originalNumber': int
    }

    @cached_property
    def sensors(self):
        """return dict mapping sensor number to the sensor properties"""
        by_number = {}
        for sensor_el in self.find('sensors'):
            number, props = self._parse_sensor(sensor_el)
            by_number[number] = props
        return by_number

    def _parse_sensor(self, el):
        """return tuple `(number, properties)` of sensor element `el`"""
        assert self.nsstrip(el.tag) == 'sensor', f"""
        Unknown sensor with tag '{self.nsstrip(el.tag)}'"""
        props = {}
        for child in el:
            key = self.nsstrip(child.tag)
            convert = self._type_converter[key]
            props[key] = convert(child.text)
        return props['number'], props

    @cached_property
    def name(self):
        """return text of the <name> element, 'UNK' if it is missing"""
        name_el = self.find('name')
        if name_el is None:
            return 'UNK'
        return name_el.text

    @cached_property
    def threads(self):
        """return list with one tuple of ints per <thread> element"""
        parsed = []
        threads_el = self.find('threads')
        if threads_el is not None:
            for thread in threads_el:
                assert self.nsstrip(thread.tag) == 'thread', f"""
                Unknown thread with tag {self.nsstrip(thread.tag)}"""
                numbers = [int(txt) for txt in thread.text.split(',')]
                parsed.append(tuple(numbers))
        return parsed

    @cached_property
    def tilingSets(self):
        """return list with one list of ints per <tilingSet> element"""
        parsed = []
        tiling_sets_el = self.find('tilingSets')
        if tiling_sets_el is not None:
            for tilingSet in tiling_sets_el:
                assert self.nsstrip(tilingSet.tag) == 'tilingSet', f"""
                Unknown tilingSet with tag {self.nsstrip(tilingSet.tag)}"""
                parsed.append([int(txt) for txt in tilingSet.text.split()])
        return parsed

    @cached_property
    def neighbors(self):
        """return dict mapping attribute `n` of each <ch> to a list of ints"""
        parsed = {}
        neighbors_el = self.find('neighbors')
        if neighbors_el is not None:
            for ch in neighbors_el:
                assert self.nsstrip(ch.tag) == 'ch', f"""
                Unknown ch with tag {self.nsstrip(ch.tag)}"""
                number = int(ch.get('n'))
                parsed[number] = [int(txt) for txt in ch.text.split()]
        return parsed

    @property
    def mappings(self):
        """not implemented; raises `NotImplementedError`"""
        raise NotImplementedError("No method to parse mappings.")

    def get_content(self):
        """return info on the sensor
        net used for the recording"""
        return dict(
            name=self.name,
            sensors=self.sensors,
            threads=self.threads,
            tilingSets=self.tilingSets,
            neighbors=self.neighbors,
        )

    def get_serializable_content(self):
        """return a serializable object containing
        info on the sensor net used for the recording"""
        snapshot = copy.deepcopy(self.get_content())
        # Stringify integer keys
        for section in ('sensors', 'neighbors'):
            stringified = {}
            for number, entry in snapshot[section].items():
                stringified[str(number)] = entry
            snapshot[section] = stringified
        # Convert np.float32 values to float built-in type
        for sensor in snapshot['sensors'].values():
            sensor.update({axis: float(sensor[axis]) for axis in 'xyz'})
        # Convert list of tuples into a list of list
        snapshot['threads'] = [list(thread) for thread in snapshot['threads']]
        return snapshot
525
526
class Coordinates(XML):
    """Parser for the coordinates xml (default name 'coordinates.xml')"""

    _xmlns = r'{http://www.egi.com/coordinates_mff}'
    _xmlroottag = r'coordinates'
    _default_filename = 'coordinates.xml'
    # converters applied to the text of the children of a <sensor>
    _type_converter = {
        'name': str,
        'number': int,
        'type': int,
        'identifier': int,
        'x': np.float32,
        'y': np.float32,
        'z': np.float32,
    }

    @cached_property
    def acqTime(self):
        """return acqTime field as `datetime` object"""
        txt = self.find("acqTime").text
        return self._parse_time_str(txt)

    @cached_property
    def acqMethod(self):
        """return content of acqMethod field"""
        el = self.find("acqMethod")
        return el.text

    @cached_property
    def name(self):
        """return name of the sensor layout, 'UNK' if it is missing"""
        el = self.find('name', self.find('sensorLayout'))
        return 'UNK' if el is None else el.text

    @cached_property
    def defaultSubject(self):
        """return `True` if the defaultSubject field reads 'true' or '1'"""
        # `bool(text)` is `True` for every non-empty string, including
        # 'false', so the text has to be compared explicitly.
        txt = self.find('defaultSubject').text or ''
        return txt.strip().lower() in ('true', '1')

    @cached_property
    def sensors(self):
        """return dict mapping sensor number to the sensor properties"""
        sensorLayout = self.find('sensorLayout')
        return dict([
            self._parse_sensor(sensor)
            for sensor in self.find('sensors', sensorLayout)
        ])

    def _parse_sensor(self, el):
        """return tuple `(number, properties)` of sensor element `el`"""
        assert self.nsstrip(el.tag) == 'sensor', f"""
        Unknown sensor with tag {self.nsstrip(el.tag)}"""
        ans = {}
        for e in el:
            tag = self.nsstrip(e.tag)
            ans[tag] = self._type_converter[tag](e.text)
        return ans['number'], ans

    def get_content(self):
        """return info on the acquisition time and method,
        sensor net name and default subject"""
        return {
            'acqTime': self.acqTime,
            'acqMethod': self.acqMethod,
            'name': self.name,
            'defaultSubject': self.defaultSubject,
            'sensors': self.sensors
        }

    def get_serializable_content(self):
        """return a serializable object containing
        info on the acquisition time and method,
        sensor net name and default subject"""
        content = copy.deepcopy(self.get_content())
        content['acqTime'] = XML._dump_datetime(content['acqTime'])
        # Stringify integer keys
        content['sensors'] = {
            str(key): value
            for key, value in content['sensors'].items()
        }
        # Convert np.float32 values to float built-in type
        for value in content['sensors'].values():
            for coord in ['x', 'y', 'z']:
                value[coord] = float(value[coord])
        return content
605
606
class Epochs(XML):
    """Parser for the epochs xml (default name 'epochs.xml')"""

    _xmlns = r'{http://www.egi.com/epochs_mff}'
    _xmlroottag = r'epochs'
    _default_filename = 'epochs.xml'
    # converters applied to the text of the children of an <epoch>
    _type_converter = {
        'beginTime': int,
        'endTime': int,
        'firstBlock': int,
        'lastBlock': int,
    }

    def __getitem__(self, n):
        """return epoch(s) selected by index or by name

        An int `n` is used as index into the list of epochs.  A str `n`
        selects all epochs with that name: a single match is returned by
        itself, any other number of matches as a list.  Every other type
        raises a `ValueError`."""
        if isinstance(n, int):
            return self.epochs[n]
        if isinstance(n, str):
            matched = [epoch for epoch in self.epochs if epoch.name == n]
            if len(matched) == 1:
                return matched[0]
            return matched
        raise ValueError(f"Unsupported argument type '{n}': {type(n)}")

    def __len__(self):
        """return number of epochs"""
        return len(self.epochs)

    @cached_property
    def epochs(self):
        """return list with one `Epoch` per child of the root"""
        return list(map(self._parse_epoch, self.root))

    def _parse_epoch(self, el):
        """return `Epoch` built from the children of epoch element `el`"""
        assert self.nsstrip(el.tag) == 'epoch', f"""
        Unknown epoch with tag {self.nsstrip(el.tag)}"""

        props = {}
        for child in el:
            key = self.nsstrip(child.tag)
            props[key] = self._type_converter[key](child.text)
        return Epoch(**props)

    @classmethod
    def content(cls, epochs: List[Epoch]) -> dict:  # type: ignore
        """return the `content` of each of `epochs` under the key 'epoch'"""
        formatted = []
        for epoch in epochs:
            formatted.append(epoch.content)
        return {'epoch': formatted}

    def get_content(self):
        """return begin and end time of each epoch as
        well as the number of first and last block"""
        return [
            {
                'beginTime': epch.beginTime,
                'endTime': epch.endTime,
                'firstBlock': epch.firstBlock,
                'lastBlock': epch.lastBlock
            }
            for epch in self.epochs
        ]

    def get_serializable_content(self):
        """return a serializable object containing
        begin and end time of each epoch as well
        as the number of first and last block"""
        return copy.deepcopy(self.get_content())

    def associate_categories(self, categories):
        """
        populate epoch.name for each epoch with its corresponding category name

        The categories are taken in the order returned by
        `categories.sort_categories_by_starttime()` and paired with the
        epochs one by one.  If the number of categories differs from the
        number of epochs, a message is printed and the epoch names are
        left unchanged.

        **Arguments**

        * **`categories`**: `Categories` from which to extract category names
        """
        # Sort categories
        sorted_categories = categories.sort_categories_by_starttime()
        if len(sorted_categories) != len(self):
            print(f'Number of categories ({len(sorted_categories)}) does not '
                  f'match number of epochs ({len(self)}). `Epoch.name` will '
                  'default to "epoch" for all epochs.')
            return
        # Add category names to epochs
        for epoch, category in zip(self.epochs, sorted_categories):
            epoch.name = category['category']
705
706
class EventTrack(XML):
    """Parser for event track xml files (default name 'Events.xml')"""

    _xmlns = r'{http://www.egi.com/event_mff}'
    _xmlroottag = r'eventTrack'
    _default_filename = 'Events.xml'
    # converters from event property values to the text written by `content`
    _event_type_reverter = {
        'beginTime': XML._dump_datetime,
        'duration': str,
        'relativeBeginTime': str,
        'segmentationEvent': lambda t: ('true' if t else 'false'),
        'code': str,
        'label': str,
        'description': str,
        'sourceDevice': str
    }

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # converters from the child elements of an <event> to their values
        self._event_type_converter = {
            'beginTime': lambda e: self._parse_time_str(e.text),
            'duration': lambda e: int(e.text),
            'relativeBeginTime': lambda e: int(e.text),
            'segmentationEvent': lambda e: e.text == 'true',
            'code': lambda e: str(e.text),
            'label': lambda e: str(e.text),
            'description': lambda e: str(e.text),
            'sourceDevice': lambda e: str(e.text),
            'keys': self._parse_keys
        }
        # converters for <data> text, selected by the attribute `dataType`
        self._key_type_converter = {
            'short': np.int16,
            'long': np.int64,
            'string': str,
            'TEXT': str,
        }

    @cached_property
    def name(self):
        """return content of name field"""
        return self.find('name').text

    @cached_property
    def trackType(self):
        """return content of trackType field"""
        return self.find('trackType').text

    @cached_property
    def events(self):
        """return list with one dict per <event> below the root"""
        return [
            self._parse_event(event)
            for event in self.findall('event')
        ]

    def _parse_event(self, events_el):
        """return dict mapping the tag of each child of event element
        `events_el` to its converted value"""
        assert self.nsstrip(events_el.tag) == 'event', f"""
        Unknown event with tag {self.nsstrip(events_el.tag)}"""
        return {
            tag: self._event_type_converter[tag](el)
            for tag, el in map(lambda e: (self.nsstrip(e.tag), e), events_el)
        }

    def _parse_keys(self, keys_el):
        """return dict mapping key code to data of the children of
        `keys_el`"""
        return dict([self._parse_key(key_el)
                     for key_el in keys_el])

    def _parse_key(self, key):
        """
        Attributes :
            key (ElementTree.XMLElement) : parsed from a structure
                ```
                <key>
                    <keyCode>cel#</keyCode>
                    <data dataType="short">1</data>
                </key>
                ```
        """
        code = self.find('keyCode', key).text
        data = self.find('data', key)
        val = self._key_type_converter[data.get('dataType')](data.text)
        return code, val

    @classmethod
    def content(cls, name: str, trackType: str,  # type: ignore
                events: List[dict]) -> dict:
        """return content in xml-convertible json format

        Note
        ----
        `events` is a list dicts with specials keys, none of which are
        required, for example:
        ```
        events = [
            {
                'beginTime': <datetime object>,
                'duration': <int in ms>,
                'relativeBeginTime': <int in ms>,
                'code': <str>,
                'label': <str>
            }
        ]
        ```
        Only the keys of `_event_type_reverter` are allowed.
        """
        formatted_events = []
        for event in events:
            formatted = {}
            for k, v in event.items():
                # The message needs the parentheses to be one expression:
                # without them only its first part belongs to the `assert`.
                assert k in cls._event_type_reverter, (
                    f"event property '{k}' not serializable.  Needs to be "
                    f"one of {list(cls._event_type_reverter.keys())}")
                formatted[k] = {
                    TEXT: cls._event_type_reverter[k](v)  # type: ignore
                }
            formatted_events.append({TEXT: formatted})
        return {
            'name': {TEXT: name},
            'trackType': {TEXT: trackType},
            'event': formatted_events
        }

    def get_content(self):
        """return the name, type and info on
        the events read from the .xml"""
        return {
            'name': self.name,
            'trackType': self.trackType,
            'event': self.events
        }

    def get_serializable_content(self):
        """return a serializable object containing the name,
        type and info on the events read from the .xml"""
        content = copy.deepcopy(self.get_content())
        for evt in content['event']:
            # events are parsed from whatever children are present, so
            # 'beginTime' and 'keys' may both be missing
            if 'beginTime' in evt:
                evt['beginTime'] = XML._dump_datetime(evt['beginTime'])
            if evt.get('keys'):
                # Convert numpy scalars (np.int16/np.int64 key data) to
                # built-in types
                evt['keys'] = {
                    code: val.item() if isinstance(val, np.generic) else val
                    for code, val in evt['keys'].items()
                }
        return content
840
841
842class Categories(XML):
843    """Parser for 'categories.xml' file
844
845    These files have the following structure:
846    ```
847    <?xml version="1.0" encoding="UTF-8" standalone="yes" ?>
848    <categories xmlns="http://www.egi.com/categories_mff"
849        xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance">
850        <cat>
851            <name>ULRN</name>
852            <segments>
853                <seg status="bad">
854                    <faults>
855                        <fault>eyeb</fault>
856                        <fault>eyem</fault>
857                        <fault>badc</fault>
858                    </faults>
859                    <beginTime>0</beginTime>
860                    <endTime>1200000</endTime>
861                    <evtBegin>201981</evtBegin>
862                    <evtEnd>201981</evtEnd>
863                    <channelStatus>
864                        <channels signalBin="1" exclusion="badChannels">
865                        1 12 15 50 251 253</channels>
866                    </channelStatus>
867                    <keys />
868                </seg>
869                ...
870    ```
871    """
872
    # Namespace and root tag of this xml type; `XMLType` needs both to
    # register the class.
    _xmlns = r'{http://www.egi.com/categories_mff}'
    _xmlroottag = r'categories'
    _default_filename = 'categories.xml'
    # Converters for the text of <data> elements, selected by the
    # attribute `dataType`; `_parse_keys` keeps the text as is for data
    # types that are not listed here.
    _type_converter = {
        'long': int,
    }
879
880    def __init__(self, *args, **kwargs):
881        super().__init__(*args, **kwargs)
882        self._segment_converter = {
883            'beginTime': lambda e: int(e.text),
884            'endTime': lambda e: int(e.text),
885            'evtBegin': lambda e: int(e.text),
886            'evtEnd': lambda e: int(e.text),
887            'channelStatus': self._parse_channel_status,
888        }
889        self._optional_segment_converter = {
890            'name': lambda e: str(e.text),
891            'keys': self._parse_keys,
892            'faults': self._parse_faults,
893        }
894        self._channel_prop_converter = {
895            'signalBin': int,
896            'exclusion': str,
897        }
898
899    @cached_property
900    def categories(self):
901        return dict(self._parse_cat(cat) for cat in self.findall('cat'))
902
903    def __getitem__(self, k):
904        return self.categories[k]
905
906    def __contains__(self, k):
907        return k in self.categories
908
909    def __len__(self):
910        return len(self.categories)
911
912    def _parse_cat(self, cat_el) -> Tuple[str, List[Dict[str, Any]]]:
913        """parse and return element <cat>
914
915        Contains <name /> and a <segments />
916        """
917        assert self.nsstrip(cat_el.tag) == 'cat', f"""
918        Unknown cat with tag {self.nsstrip(cat_el.tag)}"""
919        name = self.find('name', cat_el).text
920        segment_els = self.findall('seg', self.find('segments', cat_el))
921        segments = [self._parse_segment(seg_el) for seg_el in segment_els]
922        return name, segments
923
924    def _parse_channel_status(self, status_el):
925        """parse element <channelStatus>
926
927        Contains <channels />
928        """
929        def parse_channel_element(element):
930            """return parsed channel element"""
931            text = element.text or ''
932            indices = list(map(int, text.split()))
933            channel = {'channels': indices}
934            for prop, converter in self._channel_prop_converter.items():
935                channel[prop] = converter(element.get(prop))
936
937            return channel
938
939        channels = self.findall('channels', status_el)
940        ret = list(map(parse_channel_element, channels))
941        return ret or None
942
943    def _parse_faults(self, faults_el):
944        """parse element <faults>
945
946        Contains a bunch of <fault />"""
947        faults = [el.text for el in self.findall('fault', faults_el)]
948        return faults or None
949
950    def _parse_keys(self, keys_el):
951        keys = {}
952        for key_el in self.findall('key', keys_el):
953            keyCode = self.find('keyCode', key_el).text
954            data_el = self.find('data', key_el)
955            dtype = data_el.get('dataType')
956            data = self._type_converter.get(dtype, lambda s: s)(data_el.text)
957            keys[keyCode] = {
958                'data': data,
959                'type': dtype
960            }
961
962        return keys or None
963
964    def _parse_segment(self, seg_el):
965        """parse element <seg>
966
967        A <seg> element is expected to contain all elements in
968        `self._segment_converter.keys()`, and can additionally contain elements
969        in `self._optional_segment_converter.keys()`.
970        """
971        ret = {'status': seg_el.get('status', None)}
972        for tag, converter in self._segment_converter.items():
973            val = converter(self.find(tag, seg_el))
974            ret[tag] = converter(self.find(tag, seg_el))
975
976        for tag, converter in self._optional_segment_converter.items():
977            el = self.find(tag, seg_el)
978            val = converter(el) if el is not None else None
979            if val:
980                ret[tag] = val
981
982        return ret
983
984    def get_content(self):
985        """return categories related info"""
986        return {
987            'categories': self.categories
988        }
989
990    def get_serializable_content(self):
991        """return a serializable object
992        containing categories related info"""
993        return copy.deepcopy(self.get_content())
994
995    def sort_categories_by_starttime(self) -> List[dict]:
996        """return a list of dict `{category: name, t0: starttime}`
997        for each data block"""
998        sorted_categories = []
999        for name, cat in self.categories.items():
1000            for block in cat:
1001                sorted_categories.append(
1002                    {'category': name, 't0': block['beginTime']})
1003        sorted_categories.sort(key=lambda b: b['t0'])
1004        return sorted_categories
1005
1006    @classmethod
1007    def content(cls, categories):
1008        """return content of `categories` ready for dict2xml
1009
1010        **Arguments**
1011
1012        * **`categories`**: dict containing all infos for "categories.xml"
1013
1014        **Returns**
1015
1016        dict that can be passed into `dict2xml.dict2xml` to convert the
1017        information to an .xml file that follows the specification in
1018        "schemata/categories.xsd".
1019
1020        **Example**
1021
1022        Here's an example dict for `categories`:
1023
1024        ```python
1025        expected_categories = {
1026            'first category': [
1027                {
1028                    'status': 'bad',
1029                    'name': 'Average',
1030                    'faults': ['eyeb'],
1031                    'beginTime': 0,
1032                    'endTime': 1200000,
1033                    'evtBegin': 205135,
1034                    'evtEnd': 310153,
1035                    'channelStatus': [
1036                        {
1037                            'signalBin': 1,
1038                            'exclusion': 'badChannels',
1039                            'channels': [1, 12, 25, 55]
1040                        }
1041                    ],
1042                    'keys': {
1043                        '#seg': {
1044                            'type': 'long',
1045                            'data': 3
1046                        },
1047                        'subj': {
1048                            'type': 'person',
1049                            'data': 'RM271_noise_test'
1050                        }
1051                    }
1052                }
1053            ],
1054        }
1055        ```
1056        """
1057        return {'cat': [
1058            cls.serialize_category(name, segments)
1059            for name, segments in categories.items()
1060        ]}
1061
1062    @classmethod
1063    def serialize_category(cls, name, segments):
1064        """return serialized category `name` with `segments`"""
1065        name = {TEXT: str(name)}
1066        seg = list(map(cls.serialize_segment, segments))
1067        segments = {TEXT: {'seg': seg}}
1068        return {
1069            TEXT: {
1070                'name': name,
1071                'segments': segments
1072            }
1073        }
1074
1075    @staticmethod
1076    def serialize_segment(segment):
1077        """return serialized segment"""
1078        text = {}
1079        output = {TEXT: text}
1080        # In the following we'll modify `text`
1081
1082        required_integer_props = [
1083            'beginTime',
1084            'endTime',
1085            'evtBegin',
1086            'evtEnd'
1087        ]
1088        for prop in required_integer_props:
1089            text[prop] = {TEXT: str(int(segment[prop]))}
1090
1091        # Add optionals:
1092        #
1093        # - status
1094        # - name
1095        # - faults
1096        # - channelStatus
1097        # - keys
1098        if 'status' in segment:
1099            output[ATTR] = {'status': segment['status']}
1100
1101        if 'name' in segment:
1102            text['name'] = {TEXT: str(segment['name'])}
1103
1104        if 'faults' in segment:
1105            fault_list = [
1106                {TEXT: fault} for fault in segment['faults']
1107            ]
1108            text['faults'] = {
1109                TEXT: {'fault': fault_list}
1110            }
1111
1112        if 'channelStatus' in segment:
1113            channels_list = []
1114            for status in segment['channelStatus']:
1115                attributes = {
1116                    'signalBin': str(int(status['signalBin'])),
1117                    'exclusion': status['exclusion']
1118                }
1119                channels = ' '.join(map(str, status['channels']))
1120                channels = {ATTR: attributes, TEXT: channels}
1121                channels_list.append(channels)
1122            text['channelStatus'] = {TEXT: {'channels': channels_list}}
1123
1124        if 'keys' in segment:
1125            # convert xml element 'data'
1126            keys_by_code = {
1127                keyCode: {
1128                    ATTR: {'dataType': key['type']},
1129                    TEXT: str(key['data'])
1130                } for keyCode, key in segment['keys'].items()
1131            }
1132            # convert xml element 'keyCode'
1133            key_list = [{
1134                'keyCode': {TEXT: keyCode},
1135                'data': data
1136            } for keyCode, data in keys_by_code.items()]
1137            # convert xml element list
1138            key_list = [{TEXT: item} for item in key_list]
1139            # add to output
1140            text['keys'] = {TEXT: {'key': key_list}}
1141
1142        return output
1143
1144
class DipoleSet(XML):
    """Parser for 'dipoleSet.xml' file

    These files have the following structure:
    ```
    <?xml version="1.0" encoding="UTF-8" standalone="yes" ?>
    <dipoleSet xmlns="http://www.egi.com/dipoleSet_mff"
    xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance">
        <name>SWS_003_IHM</name>
        <type>Dense</type>
        <dipoles>
            <dipole>
                <computationCoordinate>64,1.2e+02,1.5e+02</computationCoordinate>
                <visualizationCoordinate>61,1.4e+02,1.5e+02</visualizationCoordinate>
                <orientationVector>0.25,0.35,0.9</orientationVector>
            </dipole>
            <dipole>
            ...
    ```
    """

    _xmlns = r'{http://www.egi.com/dipoleSet_mff}'
    _xmlroottag = r'dipoleSet'
    _default_filename = 'dipoleSet.xml'

    @property
    def computationCoordinate(self) -> np.ndarray:
        """return computation coordinates"""
        return self.dipoles['computationCoordinate']

    @property
    def visualizationCoordinate(self) -> np.ndarray:
        """return visualization coordinates"""
        return self.dipoles['visualizationCoordinate']

    @property
    def orientationVector(self) -> np.ndarray:
        """return orientation vectors of dipoles"""
        return self.dipoles['orientationVector']

    def __len__(self) -> int:
        """return number of dipoles

        Returns 0 when no dipole property was parsed at all.  Otherwise the
        count is the first dimension of 'computationCoordinate', and a
        `KeyError` is raised if that property is missing."""
        arrays = self.dipoles
        if not arrays:
            # `dipoles` is an empty dict when nothing was parsed
            return 0
        return arrays['computationCoordinate'].shape[0]

    @cached_property
    def name(self) -> str:
        """return value of the name tag"""
        return self.find('name').text

    @cached_property
    def type(self) -> str:
        """return value of the type tag"""
        return self.find('type').text

    @cached_property
    def dipoles(self) -> Dict[str, np.ndarray]:
        """return dipoles read from the .xml

        Dipole elements are expected to have a homogeneous number of elements
        such as 'computationCoordinate', 'visualizationCoordinate', and
        'orientationVector'.  The text of each element is expected to be three
        comma-separated floats in scientific notation.

        Returns a dict mapping each property tag (namespace stripped) to a
        float32 array with one row per <dipole> element.  An `AssertionError`
        is raised if any array does not have shape `(number of dipoles, 3)`.
        """
        dipoles_tag = self.find('dipoles')
        dipole_tags = self.findall('dipole', root=dipoles_tag)
        dipoles: Dict[str, list] = defaultdict(list)
        for dipole_el in dipole_tags:
            for prop_el in dipole_el.findall('*'):
                prop = self.nsstrip(prop_el.tag)
                v3 = list(map(float, prop_el.text.split(',')))
                dipoles[prop].append(v3)
        d_arrays = {
            tag: np.array(lists, dtype=np.float32)
            for tag, lists in dipoles.items()
        }

        # check that all dipole attributes have same lengths and 3 components
        shp = (len(dipole_tags), 3)
        assert all(v.shape == shp for v in d_arrays.values()), f"""
        Parsing dipoles result in broken shape.  Found {[(k, v.shape) for k, v
        in d_arrays.items()]}"""
        return d_arrays

    def get_content(self):
        """return name, type and coordinates
        of the dipole set read from the .xml"""
        return {
            'name': self.name,
            'type': self.type,
            'dipoles': self.dipoles
        }

    def get_serializable_content(self):
        """return a serializable object containing
        the name, type and coordinates of the dipole
        set read from the .xml

        The arrays under 'dipoles' are converted to nested lists."""
        content = copy.deepcopy(self.get_content())
        content['dipoles'] = {
            key: value.tolist()
            for key, value in content['dipoles'].items()
        }
        return content
1246
1247
class History(XML):
    """Parser for 'history.xml' files

    These files have the following structure:
    ```
    <?xml version="1.0" encoding="UTF-8" standalone="yes" ?>
    <historyEntries xmlns="http://www.egi.com/history_mff"
        xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance">
        <entry>
            <tool>
                <name>example</name>
                <method>Segmentation</method>
                <version>5.4.3-R</version>
                <beginTime>2020-08-27T13:32:26.008693-07:00</beginTime>
                <endTime>2020-08-27T13:32:26.113988-07:00</endTime>
                <sourceFiles>
                    <filePath type="" creator="">/Users/egi/Desktop/
                        RM271_noise_test_20190501_105754.mff</filePath>
                </sourceFiles>
                <settings>
                    <setting>  1: Rules for category
                        &quot;Category A&quot;</setting>
                    ...
    ```
    """

    _xmlns = '{http://www.egi.com/history_mff}'
    _xmlroottag = 'historyEntries'
    _default_filename = 'history.xml'
    # Maps each entry key accepted by `content` to the callable that turns
    # its value into the dict2xml representation.
    _entry_type_reverter = {
        'name': str,
        'kind': str,
        'method': str,
        'version': str,
        'beginTime': XML._dump_datetime,
        'endTime': XML._dump_datetime,
        'sourceFiles': lambda e: {'filePath': [{TEXT: filepath}
                                               for filepath in e]},
        'settings': lambda e: {'setting': [{TEXT: setting} for setting in e]},
        'results': lambda e: {'result': [{TEXT: result} for result in e]}
    }

    def __init__(self, *args, **kwargs):
        """forward all arguments to the parent constructor and set up
        `_entry_type_converter`, which maps each child tag of a <tool>
        element to the callable that parses it"""
        super().__init__(*args, **kwargs)
        self._entry_type_converter = {
            'name': lambda e: str(e.text),
            'kind': lambda e: str(e.text),
            'method': lambda e: str(e.text),
            'version': lambda e: str(e.text),
            'beginTime': lambda e: self._parse_time_str(e.text),
            'endTime': lambda e: self._parse_time_str(e.text),
            'sourceFiles': lambda e: [filepath.text for filepath in
                                      self.findall('filePath', e)],
            'settings': lambda e: [setting.text for setting in
                                   self.findall('setting', e)],
            'results': lambda e: [result.text for result in
                                  self.findall('result', e)]
        }

    def __getitem__(self, idx):
        """return the parsed history entry at position `idx`"""
        return self.entries[idx]

    def __len__(self):
        """return the number of history entries"""
        return len(self.entries)

    @cached_property
    def entries(self):
        """return the list of parsed <entry> elements"""
        return [self._parse_entry(entry) for entry in self.findall('entry')]

    def _parse_entry(self, entry_el):
        """return the <tool> child of an <entry> element as a dict

        Every child of <tool> is converted with the matching callable in
        `self._entry_type_converter`.  An `AssertionError` is raised if
        `entry_el` is not an <entry> element, and a `KeyError` if <tool>
        contains a tag without a converter."""
        assert self.nsstrip(entry_el.tag) == 'entry', f"""
        Unknown tool with tag {self.nsstrip(entry_el.tag)}"""
        tool_el = self.find('tool', entry_el)
        parsed = {}
        for el in tool_el:
            tag = self.nsstrip(el.tag)
            parsed[tag] = self._entry_type_converter[tag](el)
        return parsed

    @classmethod
    def content(cls, entries: List[dict]) -> dict:  # type: ignore
        """return content in xml-convertible json format

        An `AssertionError` is raised if an entry contains a key that is not
        listed in `_entry_type_reverter`.

        Note
        ----
        `entries` is a list of dicts with several keys, none of which are
        required. `entries` should have the following structure:
        ```
        entries = [
            {
                'name': <str>,
                'kind': <str>,
                'method': <str>,
                'version': <str>,
                'beginTime': <datetime object>,
                'endTime': <datetime object>,
                'sourceFiles': <List[str]>,
                'settings': <List[str]>,
                'results': <List[str]>
            }
        ]
        ```
        """
        formatted_entries = []
        for entry in entries:
            formatted = {}
            for tag, text in entry.items():
                # The message is parenthesized so that all of its parts
                # belong to the assert statement.
                assert tag in cls._entry_type_reverter, (
                    f"entry property '{tag}' not serializable. "
                    "Needs to be one of "
                    f"{list(cls._entry_type_reverter.keys())}."
                )
                formatted[tag] = {
                    TEXT: cls._entry_type_reverter[tag](text)  # type: ignore
                }
            formatted_entries.append({TEXT: formatted})
        return {
            'entry': [
                {
                    TEXT: {
                        'tool': e
                    }
                }
                for e in formatted_entries
            ]
        }

    def get_content(self):
        """return history entries

        Returns a list with one dict per entry in which 'beginTime' and
        'endTime', where present, are converted with `_dump_datetime`.  The
        cached `self.entries` are left untouched, so the method can be
        called repeatedly."""
        formatted_entries = []
        for entry in self.entries:
            # shallow copy: only the two time keys are replaced
            formatted = dict(entry)
            for key in ('beginTime', 'endTime'):
                if key in formatted:
                    formatted[key] = self._dump_datetime(formatted[key])
            formatted_entries.append(formatted)
        return formatted_entries

    def get_serializable_content(self):
        """return a serializable object containing history entries"""
        return copy.deepcopy(self.get_content())

    def mff_flavor(self) -> str:
        """return either 'continuous', 'segmented',
        or 'averaged' representing mff flavor

        The flavor is derived from the lower-cased 'method' of the history
        entries: 'averaging' wins over 'segmentation', and 'continuous' is
        returned if neither occurs.  Entries without a 'method' are ignored.
        """
        methods = {
            entry['method'].lower()
            for entry in self.entries
            if 'method' in entry
        }
        if 'averaging' in methods:
            return 'averaged'
        if 'segmentation' in methods:
            return 'segmented'
        return 'continuous'
1395