import os
from functools import lru_cache

import numpy as np


@lru_cache(1)
def _read_data():
    """Parse ``accuracies.txt`` (located next to this module) once and cache it.

    File format: a classifier name on its own line, followed by one line per
    data set of the form ``<dataset> <score> <score> ...``.  A line with no
    scores starts the next classifier.  Every classifier must list the same
    data sets in the same order.

    Returns:
        tuple: ``(data, classifiers, datasets)`` where ``data`` is an ndarray
        of shape ``(n_classifiers, n_datasets, n_scores)``.

    Raises:
        ValueError: if two classifiers list data sets in different orders.
    """
    data = []
    datasets = []
    classifiers = []
    basedir = os.path.dirname(__file__)
    with open(os.path.join(basedir, "accuracies.txt")) as f:
        classifier = f.readline().strip()
        while classifier:  # loop over classifiers
            classifiers.append(classifier)
            data.append([])
            # For the first classifier, fill the shared `datasets` list itself;
            # for later ones, collect into a scratch list and compare against
            # `datasets` to verify the ordering is identical.
            t_datasets = [] if datasets else datasets
            while True:  # loop over data sets
                line = f.readline().strip()
                dataset, *scores = line.split() if line else ("",)
                if not scores:
                    # A score-less line (or EOF) terminates this classifier.
                    # Raise instead of assert so the consistency check is not
                    # stripped under `python -O`.
                    if datasets != t_datasets:
                        raise ValueError(
                            "inconsistent data set order for classifier "
                            + classifiers[-1])
                    classifier = dataset  # next classifier name ("" at EOF)
                    break
                data[-1].append([float(x) for x in scores])
                t_datasets.append(dataset)
    return np.array(data), classifiers, datasets


def get_data(classifier=..., dataset=..., aggregate=False):
    """Return accuracy scores, optionally filtered and aggregated.

    Args:
        classifier: a classifier name, a sequence of names, or ``...`` for all.
        dataset: a data set name, a sequence of names, or ``...`` for all.
        aggregate: if True, average scores over the last (score) axis.

    Returns:
        np.ndarray: the selected scores with singleton axes squeezed out.
    """
    def get_indices(names, pool):
        # Map Ellipsis / a single name / a sequence of names to an index array.
        if names is ...:
            return np.arange(len(pool), dtype=int)
        if isinstance(names, str):
            return np.array([pool.index(names)])
        return np.array([pool.index(name) for name in names])

    data, classifiers, datasets = _read_data()
    data = data[np.ix_(get_indices(classifier, classifiers),
                       get_indices(dataset, datasets))]
    if aggregate:
        data = np.mean(data, axis=2)
    return data.squeeze()


def get_classifiers():
    """Return the list of classifier names, in file order."""
    return _read_data()[1]


def get_datasets():
    """Return the list of data set names, in file order."""
    return _read_data()[2]