1import os
2from functools import lru_cache
3
4import numpy as np
5
6
@lru_cache(1)
def _read_data():
    """Parse ``accuracies.txt`` (located next to this module) once and cache it.

    Expected file format: each classifier block starts with a line holding
    the classifier name, followed by one line per data set of the form
    ``<dataset-name> <score> <score> ...``.  A line without scores is the
    next classifier's name; end of file terminates the last block.

    Returns:
        tuple: ``(data, classifiers, datasets)`` where ``data`` is a 3-d
        ndarray of shape (n_classifiers, n_datasets, n_scores), and
        ``classifiers`` / ``datasets`` are lists of names in file order.

    Raises:
        AssertionError: if any classifier block lists the data sets in a
            different order than the first block.
    """
    data = []
    datasets = []  # canonical data-set order, set by the first block
    classifiers = []
    basedir = os.path.split(__file__)[0]
    with open(os.path.join(basedir, "accuracies.txt")) as f:
        classifier = f.readline().strip()
        while classifier:  # one iteration per classifier block
            classifiers.append(classifier)
            data.append([])
            block_datasets = []  # data-set order seen in this block
            while True:  # one iteration per data-set line
                line = f.readline().strip()
                dataset, *scores = line.split() if line else ("",)
                if not scores:
                    # A score-less line is the next classifier's name
                    # ("" at EOF, which also ends the outer loop).
                    if not datasets:
                        # First block defines the canonical order.
                        datasets = block_datasets
                    else:
                        # All other blocks must match it exactly.
                        assert datasets == block_datasets, \
                            "data set order differs between classifiers"
                    classifier = dataset
                    break
                data[-1].append([float(x) for x in scores])
                block_datasets.append(dataset)
    return np.array(data), classifiers, datasets
32
33
def get_data(classifier=..., dataset=..., aggregate=False):
    """Return accuracy scores for the selected classifiers and data sets.

    ``classifier`` / ``dataset`` may each be ``...`` (everything), a single
    name, or a sequence of names.  With ``aggregate=True`` the per-fold
    scores are averaged.  Singleton axes are squeezed out of the result.
    """
    data, classifiers, datasets = _read_data()

    def _indices(selection, names):
        # Ellipsis selects everything; a bare string selects one entry.
        if selection is ...:
            return np.arange(len(names), dtype=int)
        if isinstance(selection, str):
            selection = [selection]
        return np.array([names.index(s) for s in selection])

    rows = _indices(classifier, classifiers)
    cols = _indices(dataset, datasets)
    data = data[np.ix_(rows, cols)]
    if aggregate:
        data = data.mean(axis=2)
    return data.squeeze()
50
51
def get_classifiers():
    """Return the list of classifier names, in file order."""
    _, classifiers, _ = _read_data()
    return classifiers
54
55
def get_datasets():
    """Return the list of data-set names, in file order."""
    _, _, datasets = _read_data()
    return datasets
58