1# pylint: disable=line-too-long
2"""Encode/Decode helper function for compressed quantized models"""
3import zlib
4import base64
5
def encode_json(json_file, is_print=False):
    r"""Compress the text of a JSON file and return it base64-encoded.

    Parameters
    ----------
    json_file : str
        Path of the JSON file whose contents are encoded.
    is_print : bool
        When True, the encoded value is also printed to stdout.

    Returns
    -------
    bytes
        Base64 encoding of the zlib-compressed, UTF-8 encoded file text.
    """
    # Read in text mode so the content is decoded as UTF-8, then re-encode
    # to bytes for zlib.
    with open(json_file, encoding='utf-8') as json_fh:
        raw_bytes = json_fh.read().encode('utf-8')
    encoded = base64.b64encode(zlib.compress(raw_bytes))
    if is_print:
        print(encoded)
    return encoded
23
def decode_b64(b64_str, is_print=False):
    r"""Decode a compressed base64 string back into JSON text.

    Parameters
    ----------
    b64_str : str or bytes
        Base64 encoding of zlib-compressed UTF-8 text, such as the value
        produced by ``encode_json``.
    is_print : bool
        When True, the decoded text is also printed to stdout.

    Returns
    -------
    str
        The decompressed text, decoded as UTF-8.
    """
    compressed = base64.b64decode(b64_str)
    text = zlib.decompress(compressed).decode('utf-8')
    if is_print:
        print(text)
    return text
38
def get_compressed_model(model_name, compressed_json):
    r"""Get a compressed (INT8) model's JSON text from ``compressed_json``.

    Parameters
    ----------
    model_name : str
        Name of the compressed (INT8) model to look up.
    compressed_json : dict
        Mapping from (INT8) model name to the compressed base64 string of
        that model's JSON.

    Returns
    -------
    str
        The decoded JSON text of the requested model (via ``decode_b64``).

    Raises
    ------
    ValueError
        If ``model_name`` has no entry in ``compressed_json``, or its entry
        is empty/None. The message lists the available model names.
    """
    b64_str = compressed_json.get(model_name)
    # Truthiness check: a missing key (None) and an empty entry are both
    # treated as "not found".
    if b64_str:
        return decode_b64(b64_str)
    # str() each key so building the error message cannot itself fail with
    # TypeError when the mapping contains non-string keys.
    available = '\n'.join(str(name) for name in compressed_json)
    raise ValueError('Model: {} is not found. Available compressed models are:\n{}'.format(model_name, available))
54