# This file is part of the CLBlast project. The project is licensed under Apache Version 2.0. This file follows the
# PEP8 Python style guide and uses a max-width of 120 characters per line.
#
# Author(s):
#   Cedric Nugteren <www.cedricnugteren.nl>

import re
import json

try:
    from urllib.request import urlopen  # Python 3
except ImportError:
    from urllib2 import urlopen  # Python 2


def download_database(filename, database_url):
    """Downloads a database and saves it to disk

    The bytes read from 'database_url' are written unmodified to 'filename' (opened in binary mode).
    """
    print("[database] Downloading database from '" + database_url + "'...")
    database = urlopen(database_url)
    try:
        with open(filename, "wb") as f:
            f.write(database.read())
    finally:
        # Explicit close (rather than a 'with' on the response) works on both Python 2 and Python 3
        database.close()


def load_database(filename):
    """Loads a database from disk

    The file is parsed as JSON and passed through decompress_database() before being returned.
    """
    print("[database] Loading database from '" + filename + "'")
    with open(filename) as f:
        database = json.load(f)
    return decompress_database(database)


def save_database(database, filename):
    """Saves a database to disk

    The database is passed through compress_database() first and then written as indented JSON with sorted keys.
    """
    compressed_db = compress_database(database)
    print("[database] Saving database to '" + filename + "'")
    with open(filename, "w") as f:
        json.dump(compressed_db, f, sort_keys=True, indent=2, separators=(',', ': '))


def compress_database(database):
    """Moves certain common fields up in the hierarchy, transforms dicts into lists

    For every section, the per-result "parameters" dicts are replaced by a single sorted "parameter_names" list on the
    section plus, per result, a two-element list: the comma-joined parameter values (in the order of
    "parameter_names") and the result's "time". All other section fields are copied as-is. The input is not modified;
    a new {"sections": [...]} dict is returned.

    Raises an AssertionError if the results of a section do not all have the same set of parameter names (this
    includes a section with an empty "results" list).
    """
    new_sections = []
    for section in database["sections"]:
        new_section = {}
        for field in section:
            if field == "results":
                results = section["results"]

                # Sorted lists (rather than dict views) are JSON-serialisable and independent of each dict's own order
                parameter_names = [sorted(result["parameters"]) for result in results]
                assert len(set(tuple(names) for names in parameter_names)) == 1
                names = parameter_names[0]  # they are all the same
                new_section["parameter_names"] = names

                # Looks the values up by name such that they always line up with 'parameter_names'
                new_section[field] = [[",".join([str(result["parameters"][name]) for name in names]),
                                       result["time"]]
                                      for result in results]
            elif field == "parameter_names" and "results" in section:
                # A left-over from an earlier decompression: it is re-generated from the results above and must not
                # overwrite the fresh list
                continue
            else:
                new_section[field] = section[field]
        new_sections.append(new_section)
    return {"sections": new_sections}


def decompress_database(database):
    """Undo the above compression

    Replaces, in place, every compressed [values, time] result by a {"parameters": {...}, "time": ...} dict, pairing
    the section's "parameter_names" with the comma-separated values. The values are kept as strings, exactly as
    split from the compressed form, and the "parameter_names" field is left in the section. Returns the same
    (modified) database object.
    """
    for section in database["sections"]:
        names = section["parameter_names"]
        section["results"] = [{"parameters": dict(zip(names, result[0].split(","))),
                               "time": result[1]}
                              for result in section["results"]]
    return database


def load_tuning_results(filename):
    """Loads JSON data from file and pre-processes it

    The pre-processing strips every '_<digits>' part from "kernel_family", moves the "kernel" name from the
    individual results to the top level, drops the redundant 'PRECISION' parameter from each result, and rewrites
    two known "arg_alpha"/"arg_beta" spellings. Raises an AssertionError if there are no results, if the results
    disagree on the kernel name, or if a result's 'PRECISION' does not match the top-level "precision".
    """
    with open(filename) as f:
        json_data = json.load(f)

    # Removes the numbering following the kernel family name
    json_data["kernel_family"] = re.sub(r'_\d+', '', json_data["kernel_family"])

    # Adds the kernel name to the section instead of to the individual results
    assert len(json_data["results"]) > 0
    json_data["kernel"] = json_data["results"][0]["kernel"]
    for result in json_data["results"]:
        assert json_data["kernel"] == result["kernel"]
        result.pop("kernel", None)

    # Removes the 'PRECISION' parameter from the individual results: it is redundant
    for result in json_data["results"]:
        assert json_data["precision"] == str(result["parameters"]["PRECISION"])
        result["parameters"].pop("PRECISION", None)

    # Fixes the scalar argument values
    for value, replacement in zip(["2.00", "2.00+0.50i"], ["2.000000", "2+0.5i"]):
        for field in ["arg_alpha", "arg_beta"]:
            if field in json_data and json_data[field] == value:
                json_data[field] = replacement

    # All done
    return json_data