# pylint: disable=line-too-long
"""Encode/Decode helper functions for compressed quantized models."""
import zlib
import base64


def encode_json(json_file, is_print=False):
    r"""Encode the contents of a JSON file to a zlib-compressed base64 value.

    Parameters
    ----------
    json_file : str
        Path to the JSON file to read (decoded as UTF-8).
    is_print : bool
        Whether to print the encoded base64 value.

    Returns
    -------
    bytes
        Base64 encoding of the zlib-compressed UTF-8 file contents. Note that
        this is ``bytes`` (as returned by ``base64.b64encode``), not ``str``.
    """
    with open(json_file, encoding='utf-8') as fh:
        data = fh.read()
    zipped_str = zlib.compress(data.encode('utf-8'))
    b64_str = base64.b64encode(zipped_str)
    if is_print:
        print(b64_str)
    return b64_str


def decode_b64(b64_str, is_print=False):
    r"""Decode a compressed base64 value back to the original JSON text.

    Parameters
    ----------
    b64_str : str or bytes
        Compressed base64 value, e.g. as produced by ``encode_json``.
    is_print : bool
        Whether to print the decoded JSON string.

    Returns
    -------
    str
        The decompressed text, decoded as UTF-8. It is not parsed as JSON.

    Raises
    ------
    binascii.Error
        If ``b64_str`` is not valid base64 (e.g. incorrect padding).
    zlib.error
        If the decoded bytes are not a valid zlib stream.
    UnicodeDecodeError
        If the decompressed bytes are not valid UTF-8.
    """
    json_str = zlib.decompress(base64.b64decode(b64_str)).decode('utf-8')
    if is_print:
        print(json_str)
    return json_str


def get_compressed_model(model_name, compressed_json):
    r"""Get a compressed (INT8) model's JSON text from the ``compressed_json`` dict.

    Parameters
    ----------
    model_name : str
        Name of the compressed (INT8) model to look up.
    compressed_json : dict
        Mapping from (INT8) model name to that model's compressed base64 value.

    Returns
    -------
    str
        The decoded JSON text of the requested model.

    Raises
    ------
    ValueError
        If ``model_name`` is absent from ``compressed_json`` or maps to an
        empty/falsy value. The message lists the available model names.
    """
    b64_str = compressed_json.get(model_name)
    if b64_str:
        return decode_b64(b64_str)
    # map(str, ...) keeps the intended ValueError even when some keys are not
    # strings; a bare '\n'.join would raise TypeError here instead.
    available = '\n'.join(map(str, compressed_json))
    raise ValueError(
        'Model: {} is not found. Available compressed models are:\n{}'.format(
            model_name, available))