1
2# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This file follows the
3# PEP8 Python style guide and uses a max-width of 120 characters per line.
4#
5# Author(s):
6#   Cedric Nugteren <www.cedricnugteren.nl>
7
8import re
9import json
10
11try:
12    from urllib.request import urlopen  # Python 3
13except ImportError:
14    from urllib2 import urlopen  # Python 2
15
16
def download_database(filename, database_url):
    """Downloads a database and saves it to disk.

    The whole response is read before `filename` is opened, so a failed download does not truncate an existing
    file at that path. The response object is closed once it has been read.

    filename: local path the downloaded bytes are written to (overwritten if it already exists)
    database_url: any URL accepted by `urlopen`
    """
    # `closing` only needs a .close() method, so it also works for Python 2 urllib2 responses (not context managers)
    from contextlib import closing

    print("[database] Downloading database from '" + database_url + "'...")
    with closing(urlopen(database_url)) as response:
        contents = response.read()
    with open(filename, "wb") as f:
        f.write(contents)
23
24
def load_database(filename):
    """Reads a compressed JSON database from disk and returns it in decompressed form"""
    print("[database] Loading database from '" + filename + "'")
    with open(filename) as json_file:
        compressed = json.load(json_file)
    return decompress_database(compressed)
31
32
def save_database(database, filename):
    """Compresses a database and writes it to disk as indented JSON with sorted keys"""
    compressed = compress_database(database)
    print("[database] Saving database to '" + filename + "'")
    with open(filename, "w") as out_file:
        json.dump(compressed, out_file, indent=2, separators=(',', ': '), sort_keys=True)
39
40
def compress_database(database):
    """Moves certain common fields up in the hierarchy, transforms dicts into lists.

    For every section that has a "results" list, the parameter names of the results are stored once in a
    "parameter_names" list and each result dict becomes a two-element list:
    [comma-separated parameter values (in "parameter_names" order), time].
    All other section fields are copied as-is. The input database is not modified.

    database: dict with a "sections" list; each result is expected to hold a "parameters" dict and a "time" entry
    returns: a new dict {"sections": [...]} in the compressed format. decompress_database reverses this, with the
             parameter values coming back as strings.
    """
    new_sections = []
    for section in database["sections"]:
        new_section = {}
        for field in section:
            if field == "results":
                results = section["results"]
                # A plain list (not a Python 3 dict_keys view) so that json.dump can serialise it
                parameter_names = list(results[0]["parameters"].keys()) if results else []
                # All results must share the same parameter names, in the same order
                assert all(list(result["parameters"].keys()) == parameter_names for result in results)
                new_section["parameter_names"] = parameter_names
                new_section[field] = [[",".join(str(result["parameters"][name]) for name in parameter_names),
                                       result["time"]]
                                      for result in results]
            elif field == "parameter_names" and "results" in section:
                # Stale copy (e.g. left behind by decompress_database): it is recomputed from the results instead,
                # so it must not overwrite the fresh value when it happens to come after "results"
                continue
            else:
                new_section[field] = section[field]
        new_sections.append(new_section)
    return {"sections": new_sections}
59
60
def decompress_database(database):
    """Reverses the compression done by compress_database, modifying `database` in place.

    Each compressed result [comma-separated values, time] is turned back into {"parameters": {...}, "time": ...},
    pairing the values with the section's "parameter_names". Parameter values come back as strings. The
    "parameter_names" field itself is left in each section. Returns the same (mutated) database object.
    """
    for section in database["sections"]:
        # Built completely before assignment, so a failure leaves this section's results untouched
        section["results"] = [
            {"parameters": dict(zip(section["parameter_names"], packed[0].split(","))), "time": packed[1]}
            for packed in section["results"]
        ]
    return database
76
77
def load_tuning_results(filename):
    """Reads a tuning-results JSON file and normalises it before it is merged into the database.

    - strips every "_<digits>" part from "kernel_family"
    - hoists the (shared) "kernel" name from the individual results to the top level
    - drops the redundant "PRECISION" parameter, which must match the top-level "precision"
    - rewrites a few known spellings of the "arg_alpha"/"arg_beta" scalar values
    """
    with open(filename) as json_file:
        tuning = json.load(json_file)

    # The kernel family name may carry a numeric suffix; remove it
    tuning["kernel_family"] = re.sub(r'_\d+', '', tuning["kernel_family"])

    # All results must come from one kernel: store its name once, at the top level
    results = tuning["results"]
    assert len(results) > 0
    kernel_name = results[0]["kernel"]
    tuning["kernel"] = kernel_name
    for entry in results:
        assert tuning["kernel"] == entry["kernel"]
        entry.pop("kernel", None)

    # The per-result PRECISION parameter duplicates the top-level precision field
    for entry in results:
        assert tuning["precision"] == str(entry["parameters"]["PRECISION"])
        entry["parameters"].pop("PRECISION", None)

    # Normalise the textual form of the scalar arguments
    scalar_fixes = (("2.00", "2.000000"), ("2.00+0.50i", "2+0.5i"))
    for field in ("arg_alpha", "arg_beta"):
        if field not in tuning:
            continue
        for old_text, new_text in scalar_fixes:
            if tuning[field] == old_text:
                tuning[field] = new_text

    return tuning
106