1import logging 2import warnings 3import xml.etree.ElementTree as ET 4from datetime import datetime 5from collections import defaultdict 6import numpy as np 7from typing import Tuple, Dict, List, Any, Union, IO 8from .cached_property import cached_property 9from .dict2xml import TEXT, ATTR 10from .epoch import Epoch 11import copy 12""" 13Copyright 2019 Brain Electrophysiology Laboratory Company LLC 14 15Licensed under the ApacheLicense, Version 2.0(the "License"); 16you may not use this module except in compliance with the License. 17You may obtain a copy of the License at: 18 19http: // www.apache.org / licenses / LICENSE - 2.0 20 21Unless required by applicable law or agreed to in writing, software 22distributed under the License is distributed on an 23"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF 24ANY KIND, either express or implied. 25""" 26 27"""Parsing for all xml files""" 28 29 30FilePointer = Union[str, IO[bytes]] 31 32 33class XMLType(type): 34 """`XMLType` registers all xml types 35 36 Spawn the _right_ XMLType sub-class with `from_file`. To be registered, 37 sub-classes need to implement class attributes `_xmlns` and 38 `_xmlroottag`.""" 39 40 _registry: Dict[str, Any] = {} 41 _tag_registry: Dict[str, Any] = {} 42 _logger = logging.getLogger(name='XMLType') 43 _extensions = ['.xml', '.XML'] 44 _supported_versions: Tuple[str, ...] = ('',) 45 _time_format = "%Y-%m-%dT%H:%M:%S.%f%z" 46 47 def __new__(typ, name, bases, dct): 48 new_xml_type = super().__new__(typ, name, bases, dct) 49 typ.register(new_xml_type) 50 return new_xml_type 51 52 @classmethod 53 def register(typ, xml_type): 54 try: 55 ns_tag = xml_type._xmlns + xml_type._xmlroottag 56 if ns_tag in typ._registry: 57 typ._logger.warn("overwritting %s in registry" % 58 typ._registry[ns_tag]) 59 typ._registry[ns_tag] = xml_type 60 # add another key for the same type 61 typ._tag_registry[xml_type._xmlroottag] = xml_type 62 return True 63 except (AttributeError, TypeError): 64 typ._logger.info("type %s cannot be registered" % xml_type) 65 return False 66 67 @classmethod 68 def from_file(typ, filepointer: FilePointer): 69 """return new `XMLType` instance of the appropriate sub-class 70 71 **Parameters** 72 *filepointer*: str or IO[bytes] 73 pointer to the xml file 74 """ 75 xml_root = ET.parse(filepointer).getroot() 76 return typ._registry[xml_root.tag](xml_root) 77 78 @classmethod 79 def todict(typ, xmltype, **kwargs) -> Dict[str, Any]: 80 """return dict of `kwargs` specific for `xmltype` 81 82 The output of this function is supposed to fit perfectly into 83 `mffpy.json2xml.dict2xml` returning a valid .xml file for the specific 84 xml type. For this, each of these types needs to implement a method 85 `content` that takes `**kwargs` as argument. 86 """ 87 assert xmltype in typ._tag_registry, f""" 88 {xmltype} is not one of the valid .xml types: 89 {typ._tag_registry.keys()}""" 90 T = typ._tag_registry[xmltype] 91 return { 92 'content': T.content(**kwargs), 93 'rootname': T._xmlroottag, 94 'filename': T._default_filename, 95 # remove the '}'/'{' characters 96 'namespace': T._xmlns[1:-1] 97 } 98 99 @classmethod 100 def xml_root_tags(cls): 101 """return list of root tags of supported xml files""" 102 return list(cls._tag_registry.keys()) 103 104 105class XML(metaclass=XMLType): 106 107 _default_filename: str = '' 108 109 def __init__(self, xml_root): 110 self.root = xml_root 111 112 @classmethod 113 def _parse_time_str(cls, txt): 114 # convert time string "2003-04-17T13:35:22.000000-08:00" 115 # to "2003-04-17T13:35:22.000000-0800" .. 116 if txt.count(':') == 3: 117 txt = txt[::-1].replace(':', '', 1)[::-1] 118 return datetime.strptime(txt, cls._time_format) 119 120 @classmethod 121 def _dump_datetime(cls, dt): 122 assert dt.tzinfo is not None, f""" 123 Timezone required for date/time {dt}""" 124 txt = dt.strftime(cls._time_format) 125 return txt[:-2] + ':' + txt[-2:] 126 127 def find(self, tag, root=None): 128 root = root or self.root 129 return root.find(self._xmlns+tag) 130 131 def findall(self, tag, root=None): 132 root = root or self.root 133 return root.findall(self._xmlns+tag) 134 135 def nsstrip(self, tag): 136 return tag[len(self._xmlns):] 137 138 @property 139 def xml_root_tag(self): 140 return self._xmlroottag 141 142 @classmethod 143 def content(typ, *args, **kwargs): 144 """checks and returns `**kwargs` as a formatted dict""" 145 raise NotImplementedError(f""" 146 json converter not implemented for type {typ}""") 147 148 149class FileInfo(XML): 150 151 _xmlns = '{http://www.egi.com/info_mff}' 152 _xmlroottag = 'fileInfo' 153 _default_filename = 'info.xml' 154 _supported_versions = ('3',) 155 156 @cached_property 157 def mffVersion(self): 158 el = self.find('mffVersion') 159 return None if el is None else el.text 160 161 @property 162 def version(self): 163 """return mffVersion""" 164 warnings.warn(".version is deprecated, use .mffVersion instead", 165 DeprecationWarning) 166 return self.mffVersion 167 168 @cached_property 169 def acquisitionVersion(self): 170 """return content of acquisitionVersion field""" 171 el = self.find('acquisitionVersion') 172 return None if el is None else el.text 173 174 @cached_property 175 def ampType(self): 176 """return content of ampType field""" 177 el = self.find('ampType') 178 return None if el is None else el.text 179 180 @cached_property 181 def recordTime(self): 182 el = self.find('recordTime') 183 return self._parse_time_str(el.text) if el is not None else None 184 185 @classmethod 186 def content(cls, recordTime: datetime, # type: ignore 187 mffVersion: str = '3', 188 acquisitionVersion: str = None, 189 ampType: str = None) -> dict: 190 """returns MFF file information 191 192 Only Version '3' is supported. 193 """ 194 assert mffVersion in cls._supported_versions, f""" 195 version {mffVersion} not supported""" 196 content = { 197 'mffVersion': { 198 TEXT: mffVersion 199 }, 200 'recordTime': { 201 TEXT: cls._dump_datetime(recordTime) 202 } 203 } 204 if acquisitionVersion: 205 content.update(acquisitionVersion={TEXT: acquisitionVersion}) 206 207 if ampType: 208 content.update(ampType={TEXT: ampType}) 209 210 return content 211 212 def get_content(self): 213 """return dictionary of MFF file information 214 215 The info are comprised of: 216 - MFF version number 217 - time of start of recording 218 - acquisition version (optional) 219 - amplifier type (optional) 220 """ 221 content = { 222 'mffVersion': self.mffVersion, 223 'recordTime': self.recordTime 224 } 225 if self.acquisitionVersion: 226 content.update(acquisitionVersion=self.acquisitionVersion) 227 228 if self.ampType: 229 content.update(ampType=self.ampType) 230 231 return content 232 233 def get_serializable_content(self): 234 """return serializable dictionary of MFF file information""" 235 content = copy.deepcopy(self.get_content()) 236 content['recordTime'] = XML._dump_datetime(content['recordTime']) 237 return content 238 239 240class DataInfo(XML): 241 242 _xmlns = r'{http://www.egi.com/info_n_mff}' 243 _xmlroottag = r'dataInfo' 244 _default_filename = 'info1.xml' 245 246 @cached_property 247 def generalInformation(self): 248 el = self.find('fileDataType', self.find('generalInformation')) 249 el = el[0] 250 info = {} 251 info['channel_type'] = self.nsstrip(el.tag) 252 for el_i in el: 253 info[self.nsstrip(el_i.tag)] = el_i.text 254 return info 255 256 @cached_property 257 def filters(self): 258 filters = [] 259 if self.find('filters') is not None: 260 filters = [self._parse_filter(f) for f in self.find('filters')] 261 return filters 262 263 def _parse_filter(self, f): 264 ans = {} 265 for prop in (('beginTime', float), 'method', 'type'): 266 prop, conv = prop if len(prop) == 2 else (prop, lambda x: x) 267 ans[prop] = conv(self.find(prop, f).text) 268 269 el = self.find('cutoffFrequency', f) 270 ans['cutoffFrequency'] = (float(el.text), el.get('units')) 271 return ans 272 273 @cached_property 274 def calibrations(self): 275 calibrations = self.find('calibrations') 276 ans = {} 277 if calibrations is not None: 278 for cali in calibrations: 279 typ = self.find('type', cali) 280 ans[typ.text] = self._parse_calibration(cali) 281 return ans 282 283 def _parse_calibration(self, cali): 284 ans = {} 285 ans['beginTime'] = float(self.find('beginTime', cali).text) 286 ans['channels'] = { 287 int(el.get('n')): np.float32(el.text) 288 for el in self.find('channels', cali) 289 } 290 return ans 291 292 @classmethod 293 def content(cls, fileDataType: str, # type: ignore 294 dataTypeProps: dict = None, 295 filters: List[dict] = None, 296 calibrations: List[dict] = None) -> dict: 297 """returns info on the associated (data) .bin file 298 299 **Parameters** 300 301 *fileDataType*: indicates the type of data 302 *dataTypeProps*: indicates the recording device 303 *filters*: lists all applied filters 304 *calibrations*: lists a number of calibrations 305 """ 306 dataTypeProps = dataTypeProps or {} 307 calibrations = calibrations or [] 308 filters = filters or [] 309 return { 310 'generalInformation': { 311 TEXT: { 312 'fileDataType': { 313 TEXT: { 314 fileDataType: { 315 TEXT: { 316 k: {TEXT: v} 317 for k, v in dataTypeProps.items() 318 } 319 } 320 } 321 } 322 } 323 }, 324 'filters': { 325 'filter': [ 326 { 327 'beginTime': {TEXT: f['beginTime']}, 328 'method': {TEXT: f['method']}, 329 'type': {TEXT: f['type']}, 330 'cutoffFrequency': {TEXT: f['cutoffFrequency'], 331 ATTR: {'units': 'Hz'}} 332 } for f in filters 333 ] 334 }, 335 'calibrations': { 336 'calibration': [ 337 { 338 'beginTime': {TEXT: cal['beginTime']}, 339 'type': {TEXT: cal['type']}, 340 'channels': { 341 'ch': [ 342 { 343 TEXT: str(v), 344 ATTR: {'n': str(n)} 345 } 346 ] for k, (v, n) in cal.items() 347 } 348 } for cal in calibrations 349 ] 350 } 351 } 352 353 def get_content(self): 354 """return info on the associated (data) .bin file""" 355 return { 356 'generalInformation': self.generalInformation, 357 'filters': self.filters, 358 'calibrations': self.calibrations 359 } 360 361 def get_serializable_content(self): 362 """return a serializable object containing 363 info on the associated (data) .bin file""" 364 content = copy.deepcopy(self.get_content()) 365 # Convert np.float32 values to float built-in type 366 for value in content['calibrations'].values(): 367 channels = value['channels'] 368 for channel in channels.keys(): 369 channels[channel] = float(channels[channel]) 370 return content 371 372 373class Patient(XML): 374 375 _xmlns = r'{http://www.egi.com/subject_mff}' 376 _xmlroottag = r'patient' 377 _default_filename = 'subject.xml' 378 379 _type_converter = { 380 'string': str, 381 None: lambda x: x 382 } 383 384 @cached_property 385 def fields(self): 386 ans = {} 387 for field in self.find('fields'): 388 assert self.nsstrip(field.tag) == 'field', f""" 389 Unknown field with tag {self.nsstrip(field.tag)}""" 390 name = self.find('name', field).text 391 data = self.find('data', field) 392 data = self._type_converter[data.get('dataType')](data.text) 393 ans[name] = data 394 return ans 395 396 @classmethod 397 def content(self, name, data, dataType='string'): 398 return { 399 'fields': { 400 'field': [ 401 { 402 TEXT: { 403 'name': {TEXT: name}, 404 'data': {TEXT: data, 405 ATTR: {'dataType': dataType}} 406 } 407 } 408 ] 409 } 410 } 411 412 def get_content(self): 413 """return patient related info""" 414 return { 415 'fields': self.fields 416 } 417 418 def get_serializable_content(self): 419 """return a serializable object 420 containing patient related info""" 421 return copy.deepcopy(self.get_content()) 422 423 424class SensorLayout(XML): 425 426 _xmlns = r'{http://www.egi.com/sensorLayout_mff}' 427 _xmlroottag = r'sensorLayout' 428 _default_filename = 'sensorLayout.xml' 429 430 _type_converter = { 431 'name': str, 432 'number': int, 433 'type': int, 434 'identifier': int, 435 'x': np.float32, 436 'y': np.float32, 437 'z': np.float32, 438 'originalNumber': int 439 } 440 441 @cached_property 442 def sensors(self): 443 return dict([ 444 self._parse_sensor(sensor) 445 for sensor in self.find('sensors') 446 ]) 447 448 def _parse_sensor(self, el): 449 assert self.nsstrip(el.tag) == 'sensor', f""" 450 Unknown sensor with tag '{self.nsstrip(el.tag)}'""" 451 ans = {} 452 for e in el: 453 tag = self.nsstrip(e.tag) 454 ans[tag] = self._type_converter[tag](e.text) 455 return ans['number'], ans 456 457 @cached_property 458 def name(self): 459 el = self.find('name') 460 return 'UNK' if el is None else el.text 461 462 @cached_property 463 def threads(self): 464 ans = [] 465 if self.find('threads') is not None: 466 for thread in self.find('threads'): 467 assert self.nsstrip(thread.tag) == 'thread', f""" 468 Unknown thread with tag {self.nsstrip(thread.tag)}""" 469 ans.append(tuple(map(int, thread.text.split(',')))) 470 return ans 471 472 @cached_property 473 def tilingSets(self): 474 ans = [] 475 if self.find('tilingSets') is not None: 476 for tilingSet in self.find('tilingSets'): 477 assert self.nsstrip(tilingSet.tag) == 'tilingSet', f""" 478 Unknown tilingSet with tag {self.nsstrip(tilingSet.tag)}""" 479 ans.append(list(map(int, tilingSet.text.split()))) 480 return ans 481 482 @cached_property 483 def neighbors(self): 484 ans = {} 485 if self.find('neighbors') is not None: 486 for ch in self.find('neighbors'): 487 assert self.nsstrip(ch.tag) == 'ch', f""" 488 Unknown ch with tag {self.nsstrip(ch.tag)}""" 489 key = int(ch.get('n')) 490 ans[key] = list(map(int, ch.text.split())) 491 return ans 492 493 @property 494 def mappings(self): 495 raise NotImplementedError("No method to parse mappings.") 496 497 def get_content(self): 498 """return info on the sensor 499 net used for the recording""" 500 return { 501 'name': self.name, 502 'sensors': self.sensors, 503 'threads': self.threads, 504 'tilingSets': self.tilingSets, 505 'neighbors': self.neighbors 506 } 507 508 def get_serializable_content(self): 509 """return a serializable object containing 510 info on the sensor net used for the recording""" 511 content = copy.deepcopy(self.get_content()) 512 for field in ['sensors', 'neighbors']: 513 # Stringify integer keys 514 content[field] = { 515 str(key): value 516 for key, value in content[field].items() 517 } 518 # Convert np.float32 values to float built-in type 519 for value in content['sensors'].values(): 520 for coord in ['x', 'y', 'z']: 521 value[coord] = float(value[coord]) 522 # Convert list of tuples into a list of list 523 content['threads'] = list(map(list, content['threads'])) 524 return content 525 526 527class Coordinates(XML): 528 529 _xmlns = r'{http://www.egi.com/coordinates_mff}' 530 _xmlroottag = r'coordinates' 531 _default_filename = 'coordinates.xml' 532 _type_converter = { 533 'name': str, 534 'number': int, 535 'type': int, 536 'identifier': int, 537 'x': np.float32, 538 'y': np.float32, 539 'z': np.float32, 540 } 541 542 @cached_property 543 def acqTime(self): 544 txt = self.find("acqTime").text 545 return self._parse_time_str(txt) 546 547 @cached_property 548 def acqMethod(self): 549 el = self.find("acqMethod") 550 return el.text 551 552 @cached_property 553 def name(self): 554 el = self.find('name', self.find('sensorLayout')) 555 return 'UNK' if el is None else el.text 556 557 @cached_property 558 def defaultSubject(self): 559 return bool(self.find('defaultSubject').text) 560 561 @cached_property 562 def sensors(self): 563 sensorLayout = self.find('sensorLayout') 564 return dict([ 565 self._parse_sensor(sensor) 566 for sensor in self.find('sensors', sensorLayout) 567 ]) 568 569 def _parse_sensor(self, el): 570 assert self.nsstrip(el.tag) == 'sensor', f""" 571 Unknown sensor with tag {self.nsstrip(el.tag)}""" 572 ans = {} 573 for e in el: 574 tag = self.nsstrip(e.tag) 575 ans[tag] = self._type_converter[tag](e.text) 576 return ans['number'], ans 577 578 def get_content(self): 579 """return info on the acquisition time and method, 580 sensor net name and default subject""" 581 return { 582 'acqTime': self.acqTime, 583 'acqMethod': self.acqMethod, 584 'name': self.name, 585 'defaultSubject': self.defaultSubject, 586 'sensors': self.sensors 587 } 588 589 def get_serializable_content(self): 590 """return a serializable object containing 591 info on the acquisition time and method, 592 sensor net name and default subject""" 593 content = copy.deepcopy(self.get_content()) 594 content['acqTime'] = XML._dump_datetime(content['acqTime']) 595 # Stringify integer keys 596 content['sensors'] = { 597 str(key): value 598 for key, value in content['sensors'].items() 599 } 600 # Convert np.float32 values to float built-in type 601 for value in content['sensors'].values(): 602 for coord in ['x', 'y', 'z']: 603 value[coord] = float(value[coord]) 604 return content 605 606 607class Epochs(XML): 608 609 _xmlns = r'{http://www.egi.com/epochs_mff}' 610 _xmlroottag = r'epochs' 611 _default_filename = 'epochs.xml' 612 _type_converter = { 613 'beginTime': int, 614 'endTime': int, 615 'firstBlock': int, 616 'lastBlock': int, 617 } 618 619 def __getitem__(self, n): 620 """If `n` is an int, interpret as index and return the 621 corresponding epoch in the list. If `n` is a str, return 622 a list of all epochs with name `n`, or the individual 623 epoch if only one epoch with name `n`.""" 624 if isinstance(n, int): 625 return self.epochs[n] 626 elif isinstance(n, str): 627 matched = list(filter(lambda epoch: epoch.name == n, self.epochs)) 628 return matched[0] if len(matched) == 1 else matched 629 else: 630 raise ValueError(f"Unsupported argument type '{n}': {type(n)}") 631 632 def __len__(self): 633 return len(self.epochs) 634 635 @cached_property 636 def epochs(self): 637 return [ 638 self._parse_epoch(epoch) 639 for epoch in self.root 640 ] 641 642 def _parse_epoch(self, el): 643 assert self.nsstrip(el.tag) == 'epoch', f""" 644 Unknown epoch with tag {self.nsstrip(el.tag)}""" 645 646 def elem2KeyVal(e): 647 key = self.nsstrip(e.tag) 648 val = self._type_converter[key](e.text) 649 return key, val 650 651 return Epoch(**{key: val 652 for key, val in map(elem2KeyVal, el)}) 653 654 @classmethod 655 def content(cls, epochs: List[Epoch]) -> dict: # type: ignore 656 return { 657 'epoch': [ 658 epoch.content 659 for epoch in epochs 660 ] 661 } 662 663 def get_content(self): 664 """return begin and end time of each epoch as 665 well as the number of first and last block""" 666 epochs = [] 667 for epch in self.epochs: 668 epochs.append({ 669 'beginTime': epch.beginTime, 670 'endTime': epch.endTime, 671 'firstBlock': epch.firstBlock, 672 'lastBlock': epch.lastBlock 673 }) 674 return epochs 675 676 def get_serializable_content(self): 677 """return a serializable object containing 678 begin and end time of each epoch as well 679 as the number of first and last block""" 680 return copy.deepcopy(self.get_content()) 681 682 def associate_categories(self, categories): 683 """ 684 populate epoch.name for each epoch with its corresponding category name 685 686 Retrieve category names from each epoch from a sorted list of 687 categories and set epoch.name for each corresponding epoch. If number 688 of categories does not match number of epochs, epoch names are 689 unchanged. 690 691 **Arguments** 692 693 * **`categories`**: `Categories` from which to extract category names 694 """ 695 # Sort categories 696 sorted_categories = categories.sort_categories_by_starttime() 697 # Add category names to epochs 698 if len(sorted_categories) == len(self): 699 for epoch, category in zip(self.epochs, sorted_categories): 700 epoch.name = category['category'] 701 else: 702 print(f'Number of categories ({len(sorted_categories)}) does not ' 703 f'match number of epochs ({len(self)}). `Epoch.name` will ' 704 'default to "epoch" for all epochs.') 705 706 707class EventTrack(XML): 708 709 _xmlns = r'{http://www.egi.com/event_mff}' 710 _xmlroottag = r'eventTrack' 711 _default_filename = 'Events.xml' 712 _event_type_reverter = { 713 'beginTime': XML._dump_datetime, 714 'duration': str, 715 'relativeBeginTime': str, 716 'segmentationEvent': lambda t: ('true' if t else 'false'), 717 'code': str, 718 'label': str, 719 'description': str, 720 'sourceDevice': str 721 } 722 723 def __init__(self, *args, **kwargs): 724 super().__init__(*args, **kwargs) 725 self._event_type_converter = { 726 'beginTime': lambda e: self._parse_time_str(e.text), 727 'duration': lambda e: int(e.text), 728 'relativeBeginTime': lambda e: int(e.text), 729 'segmentationEvent': lambda e: e.text == 'true', 730 'code': lambda e: str(e.text), 731 'label': lambda e: str(e.text), 732 'description': lambda e: str(e.text), 733 'sourceDevice': lambda e: str(e.text), 734 'keys': self._parse_keys 735 } 736 self._key_type_converter = { 737 'short': np.int16, 738 'long': np.int64, 739 'string': str, 740 'TEXT': str, 741 } 742 743 @cached_property 744 def name(self): 745 return self.find('name').text 746 747 @cached_property 748 def trackType(self): 749 return self.find('trackType').text 750 751 @cached_property 752 def events(self): 753 return [ 754 self._parse_event(event) 755 for event in self.findall('event') 756 ] 757 758 def _parse_event(self, events_el): 759 assert self.nsstrip(events_el.tag) == 'event', f""" 760 Unknown event with tag {self.nsstrip(events_el.tag)}""" 761 return { 762 tag: self._event_type_converter[tag](el) 763 for tag, el in map(lambda e: (self.nsstrip(e.tag), e), events_el) 764 } 765 766 def _parse_keys(self, keys_el): 767 return dict([self._parse_key(key_el) 768 for key_el in keys_el]) 769 770 def _parse_key(self, key): 771 """ 772 Attributes : 773 key (ElementTree.XMLElement) : parsed from a structure 774 ``` 775 <key> 776 <keyCode>cel#</keyCode> 777 <data dataType="short">1</data> 778 </key> 779 ``` 780 """ 781 code = self.find('keyCode', key).text 782 data = self.find('data', key) 783 val = self._key_type_converter[data.get('dataType')](data.text) 784 return code, val 785 786 @classmethod 787 def content(cls, name: str, trackType: str, # type: ignore 788 events: List[dict]) -> dict: 789 """return content in xml-convertible json format 790 791 Note 792 ---- 793 `events` is a list dicts with specials keys, none of which are 794 required, for example: 795 ``` 796 events = [ 797 { 798 'beginTime': <datetime object>, 799 'duration': <int in ms>, 800 'relativeBeginTime': <int in ms>, 801 'code': <str>, 802 'label': <str> 803 } 804 ] 805 ``` 806 """ 807 formatted_events = [] 808 for event in events: 809 formatted = {} 810 for k, v in event.items(): 811 assert k in cls._event_type_reverter, f"event property '{k}' " 812 "not serializable. Needs to be on of " 813 "{list(cls._event_type_reverter.keys())}" 814 formatted[k] = { 815 TEXT: cls._event_type_reverter[k](v) # type: ignore 816 } 817 formatted_events.append({TEXT: formatted}) 818 return { 819 'name': {TEXT: name}, 820 'trackType': {TEXT: trackType}, 821 'event': formatted_events 822 } 823 824 def get_content(self): 825 """return the name, type and info on 826 the events read from the .xml""" 827 return { 828 'name': self.name, 829 'trackType': self.trackType, 830 'event': self.events 831 } 832 833 def get_serializable_content(self): 834 """return a serializable object containing the name, 835 type and info on the events read from the .xml""" 836 content = copy.deepcopy(self.get_content()) 837 for evt in content['event']: 838 evt['beginTime'] = XML._dump_datetime(evt['beginTime']) 839 return content 840 841 842class Categories(XML): 843 """Parser for 'categories.xml' file 844 845 These files have the following structure: 846 ``` 847 <?xml version="1.0" encoding="UTF-8" standalone="yes" ?> 848 <categories xmlns="http://www.egi.com/categories_mff" 849 xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"> 850 <cat> 851 <name>ULRN</name> 852 <segments> 853 <seg status="bad"> 854 <faults> 855 <fault>eyeb</fault> 856 <fault>eyem</fault> 857 <fault>badc</fault> 858 </faults> 859 <beginTime>0</beginTime> 860 <endTime>1200000</endTime> 861 <evtBegin>201981</evtBegin> 862 <evtEnd>201981</evtEnd> 863 <channelStatus> 864 <channels signalBin="1" exclusion="badChannels"> 865 1 12 15 50 251 253</channels> 866 </channelStatus> 867 <keys /> 868 </seg> 869 ... 870 ``` 871 """ 872 873 _xmlns = r'{http://www.egi.com/categories_mff}' 874 _xmlroottag = r'categories' 875 _default_filename = 'categories.xml' 876 _type_converter = { 877 'long': int, 878 } 879 880 def __init__(self, *args, **kwargs): 881 super().__init__(*args, **kwargs) 882 self._segment_converter = { 883 'beginTime': lambda e: int(e.text), 884 'endTime': lambda e: int(e.text), 885 'evtBegin': lambda e: int(e.text), 886 'evtEnd': lambda e: int(e.text), 887 'channelStatus': self._parse_channel_status, 888 } 889 self._optional_segment_converter = { 890 'name': lambda e: str(e.text), 891 'keys': self._parse_keys, 892 'faults': self._parse_faults, 893 } 894 self._channel_prop_converter = { 895 'signalBin': int, 896 'exclusion': str, 897 } 898 899 @cached_property 900 def categories(self): 901 return dict(self._parse_cat(cat) for cat in self.findall('cat')) 902 903 def __getitem__(self, k): 904 return self.categories[k] 905 906 def __contains__(self, k): 907 return k in self.categories 908 909 def __len__(self): 910 return len(self.categories) 911 912 def _parse_cat(self, cat_el) -> Tuple[str, List[Dict[str, Any]]]: 913 """parse and return element <cat> 914 915 Contains <name /> and a <segments /> 916 """ 917 assert self.nsstrip(cat_el.tag) == 'cat', f""" 918 Unknown cat with tag {self.nsstrip(cat_el.tag)}""" 919 name = self.find('name', cat_el).text 920 segment_els = self.findall('seg', self.find('segments', cat_el)) 921 segments = [self._parse_segment(seg_el) for seg_el in segment_els] 922 return name, segments 923 924 def _parse_channel_status(self, status_el): 925 """parse element <channelStatus> 926 927 Contains <channels /> 928 """ 929 def parse_channel_element(element): 930 """return parsed channel element""" 931 text = element.text or '' 932 indices = list(map(int, text.split())) 933 channel = {'channels': indices} 934 for prop, converter in self._channel_prop_converter.items(): 935 channel[prop] = converter(element.get(prop)) 936 937 return channel 938 939 channels = self.findall('channels', status_el) 940 ret = list(map(parse_channel_element, channels)) 941 return ret or None 942 943 def _parse_faults(self, faults_el): 944 """parse element <faults> 945 946 Contains a bunch of <fault />""" 947 faults = [el.text for el in self.findall('fault', faults_el)] 948 return faults or None 949 950 def _parse_keys(self, keys_el): 951 keys = {} 952 for key_el in self.findall('key', keys_el): 953 keyCode = self.find('keyCode', key_el).text 954 data_el = self.find('data', key_el) 955 dtype = data_el.get('dataType') 956 data = self._type_converter.get(dtype, lambda s: s)(data_el.text) 957 keys[keyCode] = { 958 'data': data, 959 'type': dtype 960 } 961 962 return keys or None 963 964 def _parse_segment(self, seg_el): 965 """parse element <seg> 966 967 A <seg> element is expected to contain all elements in 968 `self._segment_converter.keys()`, and can additionally contain elements 969 in `self._optional_segment_converter.keys()`. 970 """ 971 ret = {'status': seg_el.get('status', None)} 972 for tag, converter in self._segment_converter.items(): 973 val = converter(self.find(tag, seg_el)) 974 ret[tag] = converter(self.find(tag, seg_el)) 975 976 for tag, converter in self._optional_segment_converter.items(): 977 el = self.find(tag, seg_el) 978 val = converter(el) if el is not None else None 979 if val: 980 ret[tag] = val 981 982 return ret 983 984 def get_content(self): 985 """return categories related info""" 986 return { 987 'categories': self.categories 988 } 989 990 def get_serializable_content(self): 991 """return a serializable object 992 containing categories related info""" 993 return copy.deepcopy(self.get_content()) 994 995 def sort_categories_by_starttime(self) -> List[dict]: 996 """return a list of dict `{category: name, t0: starttime}` 997 for each data block""" 998 sorted_categories = [] 999 for name, cat in self.categories.items(): 1000 for block in cat: 1001 sorted_categories.append( 1002 {'category': name, 't0': block['beginTime']}) 1003 sorted_categories.sort(key=lambda b: b['t0']) 1004 return sorted_categories 1005 1006 @classmethod 1007 def content(cls, categories): 1008 """return content of `categories` ready for dict2xml 1009 1010 **Arguments** 1011 1012 * **`categories`**: dict containing all infos for "categories.xml" 1013 1014 **Returns** 1015 1016 dict that can be passed into `dict2xml.dict2xml` to convert the 1017 information to an .xml file that follows the specification in 1018 "schemata/categories.xsd". 1019 1020 **Example** 1021 1022 Here's an example dict for `categories`: 1023 1024 ```python 1025 expected_categories = { 1026 'first category': [ 1027 { 1028 'status': 'bad', 1029 'name': 'Average', 1030 'faults': ['eyeb'], 1031 'beginTime': 0, 1032 'endTime': 1200000, 1033 'evtBegin': 205135, 1034 'evtEnd': 310153, 1035 'channelStatus': [ 1036 { 1037 'signalBin': 1, 1038 'exclusion': 'badChannels', 1039 'channels': [1, 12, 25, 55] 1040 } 1041 ], 1042 'keys': { 1043 '#seg': { 1044 'type': 'long', 1045 'data': 3 1046 }, 1047 'subj': { 1048 'type': 'person', 1049 'data': 'RM271_noise_test' 1050 } 1051 } 1052 } 1053 ], 1054 } 1055 ``` 1056 """ 1057 return {'cat': [ 1058 cls.serialize_category(name, segments) 1059 for name, segments in categories.items() 1060 ]} 1061 1062 @classmethod 1063 def serialize_category(cls, name, segments): 1064 """return serialized category `name` with `segments`""" 1065 name = {TEXT: str(name)} 1066 seg = list(map(cls.serialize_segment, segments)) 1067 segments = {TEXT: {'seg': seg}} 1068 return { 1069 TEXT: { 1070 'name': name, 1071 'segments': segments 1072 } 1073 } 1074 1075 @staticmethod 1076 def serialize_segment(segment): 1077 """return serialized segment""" 1078 text = {} 1079 output = {TEXT: text} 1080 # In the following we'll modify `text` 1081 1082 required_integer_props = [ 1083 'beginTime', 1084 'endTime', 1085 'evtBegin', 1086 'evtEnd' 1087 ] 1088 for prop in required_integer_props: 1089 text[prop] = {TEXT: str(int(segment[prop]))} 1090 1091 # Add optionals: 1092 # 1093 # - status 1094 # - name 1095 # - faults 1096 # - channelStatus 1097 # - keys 1098 if 'status' in segment: 1099 output[ATTR] = {'status': segment['status']} 1100 1101 if 'name' in segment: 1102 text['name'] = {TEXT: str(segment['name'])} 1103 1104 if 'faults' in segment: 1105 fault_list = [ 1106 {TEXT: fault} for fault in segment['faults'] 1107 ] 1108 text['faults'] = { 1109 TEXT: {'fault': fault_list} 1110 } 1111 1112 if 'channelStatus' in segment: 1113 channels_list = [] 1114 for status in segment['channelStatus']: 1115 attributes = { 1116 'signalBin': str(int(status['signalBin'])), 1117 'exclusion': status['exclusion'] 1118 } 1119 channels = ' '.join(map(str, status['channels'])) 1120 channels = {ATTR: attributes, TEXT: channels} 1121 channels_list.append(channels) 1122 text['channelStatus'] = {TEXT: {'channels': channels_list}} 1123 1124 if 'keys' in segment: 1125 # convert xml element 'data' 1126 keys_by_code = { 1127 keyCode: { 1128 ATTR: {'dataType': key['type']}, 1129 TEXT: str(key['data']) 1130 } for keyCode, key in segment['keys'].items() 1131 } 1132 # convert xml element 'keyCode' 1133 key_list = [{ 1134 'keyCode': {TEXT: keyCode}, 1135 'data': data 1136 } for keyCode, data in keys_by_code.items()] 1137 # convert xml element list 1138 key_list = [{TEXT: item} for item in key_list] 1139 # add to output 1140 text['keys'] = {TEXT: {'key': key_list}} 1141 1142 return output 1143 1144 1145class DipoleSet(XML): 1146 """Parser for 'dipoleSet.xml' file 1147 1148 These files have the following structure: 1149 ``` 1150 <?xml version="1.0" encoding="UTF-8" standalone="yes" ?> 1151 <dipoleSet xmlns="http://www.egi.com/dipoleSet_mff" 1152 xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"> 1153 <name>SWS_003_IHM</name> 1154 <type>Dense</type> 1155 <dipoles> 1156 <dipole> 1157 <computationCoordinate>64,1.2e+02,1.5e+02</computationCoordinate> 1158 <visualizationCoordinate>61,1.4e+02,1.5e+02</visualizationCoordinate> 1159 <orientationVector>0.25,0.35,0.9</orientationVector> 1160 </dipole> 1161 <dipole> 1162 ... 1163 ``` 1164 """ 1165 1166 _xmlns = r'{http://www.egi.com/dipoleSet_mff}' 1167 _xmlroottag = r'dipoleSet' 1168 _default_filename = 'dipoleSet.xml' 1169 1170 @property 1171 def computationCoordinate(self) -> np.ndarray: 1172 """return computation coordinates""" 1173 return self.dipoles['computationCoordinate'] 1174 1175 @property 1176 def visualizationCoordinate(self) -> np.ndarray: 1177 """return visualization coordinates""" 1178 return self.dipoles['visualizationCoordinate'] 1179 1180 @property 1181 def orientationVector(self) -> np.ndarray: 1182 """return orientation vectors of dipoles""" 1183 return self.dipoles['orientationVector'] 1184 1185 def __len__(self) -> int: 1186 """return number of dipoles""" 1187 return self.dipoles['computationCoordinate'].shape[0] 1188 1189 @cached_property 1190 def name(self) -> str: 1191 """return value of the name tag""" 1192 return self.find('name').text 1193 1194 @cached_property 1195 def type(self) -> str: 1196 """return value of the type tag""" 1197 return self.find('type').text 1198 1199 @cached_property 1200 def dipoles(self) -> Dict[str, np.ndarray]: 1201 """return dipoles read from the .xml 1202 1203 Dipole elements are expected to have a homogenuous number of elements 1204 such as 'computationCoordinate', 'visualizationCoordinate', and 1205 'orientationVector'. The text of each element is expected to be three 1206 comma-separated floats in scientific notation.""" 1207 dipoles_tag = self.find('dipoles') 1208 dipole_tags = self.findall('dipole', root=dipoles_tag) 1209 dipoles: Dict[str, list] = defaultdict(list) 1210 for tag in dipole_tags: 1211 for attr in tag.findall('*'): 1212 tag = self.nsstrip(attr.tag) 1213 v3 = list(map(float, attr.text.split(','))) 1214 dipoles[tag].append(v3) 1215 d_arrays = { 1216 tag: np.array(lists, dtype=np.float32) 1217 for tag, lists in dipoles.items() 1218 } 1219 1220 # check that all dipole attributes have same lengths and 3 components 1221 shp = (len(dipole_tags), 3) 1222 assert all(v.shape == shp for v in d_arrays.values()), f""" 1223 Parsing dipoles result in broken shape. Found {[(k, v.shape) for k, v 1224 in d_arrays.items()]}""" 1225 return d_arrays 1226 1227 def get_content(self): 1228 """return name, type and coordinates 1229 of the dipole set read from the .xml""" 1230 return { 1231 'name': self.name, 1232 'type': self.type, 1233 'dipoles': self.dipoles 1234 } 1235 1236 def get_serializable_content(self): 1237 """return a serializable object containing 1238 the name, type and coordinates of the dipole 1239 set read from the .xml""" 1240 content = copy.deepcopy(self.get_content()) 1241 content['dipoles'] = { 1242 key: value.tolist() 1243 for key, value in content['dipoles'].items() 1244 } 1245 return content 1246 1247 1248class History(XML): 1249 """Parser for 'history.xml' files 1250 1251 These files have the following structure: 1252 ``` 1253 <?xml version="1.0" encoding="UTF-8" standalone="yes" ?> 1254 <historyEntries xmlns="http://www.egi.com/history_mff" 1255 xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"> 1256 <entry> 1257 <tool> 1258 <name>example</name> 1259 <method>Segmentation</method> 1260 <version>5.4.3-R</version> 1261 <beginTime>2020-08-27T13:32:26.008693-07:00</beginTime> 1262 <endTime>2020-08-27T13:32:26.113988-07:00</endTime> 1263 <sourceFiles> 1264 <filePath type="" creator="">/Users/egi/Desktop/ 1265 RM271_noise_test_20190501_105754.mff</filePath> 1266 </sourceFiles> 1267 <settings> 1268 <setting> 1: Rules for category 1269 "Category A"</setting> 1270 ... 1271 ``` 1272 """ 1273 1274 _xmlns = '{http://www.egi.com/history_mff}' 1275 _xmlroottag = 'historyEntries' 1276 _default_filename = 'history.xml' 1277 _entry_type_reverter = { 1278 'name': str, 1279 'kind': str, 1280 'method': str, 1281 'version': str, 1282 'beginTime': XML._dump_datetime, 1283 'endTime': XML._dump_datetime, 1284 'sourceFiles': lambda e: {'filePath': [{TEXT: filepath} 1285 for filepath in e]}, 1286 'settings': lambda e: {'setting': [{TEXT: setting} for setting in e]}, 1287 'results': lambda e: {'result': [{TEXT: result} for result in e]} 1288 } 1289 1290 def __init__(self, *args, **kwargs): 1291 super().__init__(*args, **kwargs) 1292 self._entry_type_converter = { 1293 'name': lambda e: str(e.text), 1294 'kind': lambda e: str(e.text), 1295 'method': lambda e: str(e.text), 1296 'version': lambda e: str(e.text), 1297 'beginTime': lambda e: self._parse_time_str(e.text), 1298 'endTime': lambda e: self._parse_time_str(e.text), 1299 'sourceFiles': lambda e: [filepath.text for filepath in 1300 self.findall('filePath', e)], 1301 'settings': lambda e: [setting.text for setting in 1302 self.findall('setting', e)], 1303 'results': lambda e: [result.text for result in 1304 self.findall('result', e)] 1305 } 1306 1307 def __getitem__(self, idx): 1308 return self.entries[idx] 1309 1310 def __len__(self): 1311 return len(self.entries) 1312 1313 @cached_property 1314 def entries(self): 1315 return [self._parse_entry(entry) for entry in self.findall('entry')] 1316 1317 def _parse_entry(self, entry_el): 1318 assert self.nsstrip(entry_el.tag) == 'entry', f""" 1319 Unknown tool with tag {self.nsstrip(entry_el.tag)}""" 1320 tool_el = self.find('tool', entry_el) 1321 return { 1322 tag: self._entry_type_converter[tag](el) 1323 for tag, el in map(lambda e: (self.nsstrip(e.tag), e), tool_el) 1324 } 1325 1326 @classmethod 1327 def content(cls, entries: List[dict]) -> dict: # type: ignore 1328 """return content in xml-convertible json format 1329 1330 Note 1331 ---- 1332 `entries` is a list of dicts with several keys, none of which are 1333 required. `entries` should have the following structure: 1334 ``` 1335 entries = [ 1336 { 1337 'name': <str>, 1338 'kind': <str>, 1339 'method': <str>, 1340 'version': <str>, 1341 'beginTime': <datetime object>, 1342 'endTime': <datetime object>, 1343 'sourceFiles': <List[str]>, 1344 'settings': <List[str]>, 1345 'results': <List[str]> 1346 } 1347 ] 1348 ``` 1349 """ 1350 formatted_entries = [] 1351 for entry in entries: 1352 formatted = {} 1353 for tag, text in entry.items(): 1354 assert tag in cls._entry_type_reverter, "entry property " 1355 f"'{text}' not serializable. Needs to be one of " 1356 f"{list(cls._entry_type_reverter.keys())}." 1357 formatted[tag] = { 1358 TEXT: cls._entry_type_reverter[tag](text) # type: ignore 1359 } 1360 formatted_entries.append({TEXT: formatted}) 1361 return { 1362 'entry': [ 1363 { 1364 TEXT: { 1365 'tool': e 1366 } 1367 } 1368 for e in formatted_entries 1369 ] 1370 } 1371 1372 def get_content(self): 1373 """return history entries""" 1374 formatted_entries = [] 1375 for entry in self.entries: 1376 entry['beginTime'] = self._dump_datetime(entry['beginTime']) 1377 entry['endTime'] = self._dump_datetime(entry['endTime']) 1378 formatted_entries.append(entry) 1379 return formatted_entries 1380 1381 def get_serializable_content(self): 1382 """return a serializable object containing history entries""" 1383 return copy.deepcopy(self.get_content()) 1384 1385 def mff_flavor(self) -> str: 1386 """return either 'continuous', 'segmented', 1387 or 'averaged' representing mff flavor""" 1388 methods = [entry['method'].lower() for entry in self.entries] 1389 if 'averaging' in methods: 1390 return 'averaged' 1391 elif 'segmentation' in methods: 1392 return 'segmented' 1393 else: 1394 return 'continuous' 1395