1import base64
2from google.protobuf import json_format
3from importlib import import_module
4import json
5import numpy as np
6import os
7import sys
8
9from mmdnn.conversion.caffe.errors import ConversionError
10from mmdnn.conversion.caffe.common_graph import fetch_attr_value
11from mmdnn.conversion.caffe.utils import get_lower_case, get_upper_case, get_real_name
12
13
class JsonFormatter(object):
    '''Dump a DL graph's protobuf definition into a pretty-printed JSON file.'''

    def __init__(self, graph):
        # The protobuf message is captured once, when the formatter is built.
        self.graph_def = graph.as_graph_def()

    def dump(self, json_path):
        '''Write the graph definition to ``json_path``.

        The message is rendered with ``json_format.MessageToJson`` and then
        re-serialized with a 4-space indent and sorted keys, so the file layout
        does not depend on protobuf's own JSON formatting.
        '''
        as_dict = json.loads(json_format.MessageToJson(self.graph_def))
        pretty = json.dumps(as_dict, sort_keys=True, indent=4)
        with open(json_path, 'w') as out_file:
            out_file.write(pretty)
26
27
class PyWriter(object):
    '''Dump a DL graph into a Python script.

    The generated script imports a network base class from ``dlconv.<target>``
    and defines a subclass, named after the graph, whose ``setup`` method
    emits one statement per graph node.
    '''

    # Supported (lower-cased) targets and the dlconv base class each one uses.
    _NETWORK_CLASSES = {
        'tensorflow': 'TensorFlowNetwork',
        'keras': 'KerasNetwork',
        'caffe': 'CaffeNetwork',
    }

    def __init__(self, graph, data, target):
        '''Create a writer for ``graph``.

        Args:
            graph: graph to emit; must provide ``name``, ``node_dict``,
                ``get_node`` and ``topologically_sorted``.
            data: object written next to the script by ``dump`` via ``np.save``.
            target: target framework name (case-insensitive).

        Raises:
            ConversionError: if ``target`` is not supported.
        '''
        self.graph = graph
        self.data = data
        self.tab = ' ' * 4
        self.prefix = ''
        target = target.lower()
        if target not in self._NETWORK_CLASSES:
            raise ConversionError('Target %s is not supported yet.' % target)
        self.target = target
        self.net = self._NETWORK_CLASSES[target]

    def indent(self):
        '''Increase the indentation used by ``statement`` by one tab.'''
        self.prefix += self.tab

    def outdent(self):
        '''Decrease the indentation used by ``statement`` by one tab.'''
        self.prefix = self.prefix[:-len(self.tab)]

    def statement(self, s):
        '''Return ``s`` as one line at the current indentation.'''
        return self.prefix + s + '\n'

    def emit_imports(self):
        return self.statement('from dlconv.%s import %s\n' % (self.target, self.net))

    def emit_class_def(self, name):
        return self.statement('class %s(%s):' % (name, self.net))

    def emit_setup_def(self):
        return self.statement('def setup(self):')

    @staticmethod
    def _split_tensor_name(tensor_name):
        '''Split ``"node:idx"`` into ``(node, idx)``; a bare ``"node"`` means output 0.'''
        tensor_name = tensor_name.strip()
        # rpartition keeps any ':' that is part of the node name itself.
        node_name, sep, idx = tensor_name.rpartition(':')
        if not sep:
            return tensor_name, 0
        return node_name, int(idx)

    def emit_node(self, node):
        '''Emit the Python source line that instantiates ``node``.

        Each entry of ``node.input`` is a tensor name ``<node>:<output index>``
        and is replaced by the matching entry of the parent's ``output`` list.

        Raises:
            ConversionError: if an input refers to a node missing from the graph.
        '''

        def pair(key, value):
            return '%s=%s' % (key, value)

        args = []
        for tensor_name in node.input:
            name, idx = self._split_tensor_name(tensor_name)
            if name not in self.graph.node_dict:
                raise ConversionError('Node %s has unknown input %s.' % (node.name, tensor_name))
            parent = self.graph.get_node(name)
            args.append(parent.output[idx])
        # FIXME: only the first output is bound; other outputs are not assigned.
        output = [node.output[0]]
        # NOTE(review): this assumes ``node.attr`` iterates as (key, value)
        # pairs rather than as a mapping's keys - confirm against the graph type.
        for k, v in node.attr:
            value = fetch_attr_value(v)
            if k == 'cell_type':
                # cell_type is emitted as a string literal.
                value = "'" + value + "'"
            args.append(pair(k, value))
        args.append(pair('name', "'" + node.name + "'"))  # Set the node name
        return self.statement('%s = self.%s(%s)' % (', '.join(output), node.op, ', '.join(args)))

    def _decompose_into_chains(self):
        '''Group the topologically sorted nodes into linear chains.

        A node with exactly one input joins the first chain whose current tail
        is that input's node; every other node starts a new chain.
        '''
        chains = []
        for node in self.graph.topologically_sorted():
            attach_to_chain = None
            if len(node.input) == 1:
                parent = get_real_name(node.input[0])
                for chain in chains:
                    if chain[-1].name == parent:
                        attach_to_chain = chain
                        break
            if attach_to_chain is None:
                attach_to_chain = []
                chains.append(attach_to_chain)
            attach_to_chain.append(node)
        return chains

    def emit(self):
        '''Return the complete Python source for the graph.

        Chains are emitted as blocks separated by a blank line. The indentation
        state is restored afterwards, so repeated calls give identical output.
        '''
        chains = self._decompose_into_chains()
        saved_prefix = self.prefix
        try:
            source = self.emit_imports()
            source += self.emit_class_def(self.graph.name)
            self.indent()
            source += self.emit_setup_def()
            self.indent()
            # [:-1] drops each block's trailing newline; '\n\n' re-adds one plus a blank line.
            blocks = [''.join(self.emit_node(node) for node in chain)[:-1] for chain in chains]
            if not blocks:
                # An empty graph would otherwise produce a bodyless, invalid setup().
                blocks.append(self.statement('pass')[:-1])
            return source + '\n\n'.join(blocks)
        finally:
            self.prefix = saved_prefix

    def dump(self, code_output_dir):
        '''Write the generated script and its data into ``code_output_dir``.

        Returns:
            ``(code_output_path, data_output_path)``, i.e. ``<name>.py`` and
            ``<name>.npy`` where ``<name>`` is ``get_lower_case(graph.name)``.
        '''
        os.makedirs(code_output_dir, exist_ok=True)
        file_name = get_lower_case(self.graph.name)
        code_output_path = os.path.join(code_output_dir, file_name + '.py')
        data_output_path = os.path.join(code_output_dir, file_name + '.npy')
        # Python source is read back as UTF-8 by default, so write it as UTF-8.
        with open(code_output_path, 'w', encoding='utf-8') as f:
            f.write(self.emit())
        with open(data_output_path, 'wb') as f:
            np.save(f, self.data)
        return code_output_path, data_output_path
134
135
class ModelSaver(object):
    '''Load a generated network script and let its class export the model.'''

    def __init__(self, code_output_path, data_output_path):
        # Presumably the paths returned by PyWriter.dump - confirm with callers.
        self.code_output_path = code_output_path
        self.data_output_path = data_output_path

    def dump(self, model_output_dir):
        '''Return the file path containing graph in generated model files.

        Loads the script at ``code_output_path``, looks up the class named
        ``get_upper_case(<script file stem>)``, and returns the result of
        calling its ``dump(data_output_path, model_output_dir)``.

        Raises:
            ImportError: if the script cannot be loaded from its path.
        '''
        import importlib.util

        os.makedirs(model_output_dir, exist_ok=True)
        file_name = os.path.splitext(os.path.basename(self.code_output_path))[0]
        # Load the script directly from its path instead of by name via
        # sys.path. A name lookup can resolve to a different module with the
        # same name (e.g. a stdlib module, or a stale copy cached in
        # sys.modules), and appending to sys.path on every call grows it
        # without bound.
        spec = importlib.util.spec_from_file_location(file_name, self.code_output_path)
        if spec is None or spec.loader is None:
            raise ImportError('Cannot load generated code from %s' % self.code_output_path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        class_name = get_upper_case(file_name)
        net = getattr(module, class_name)
        return net.dump(self.data_output_path, model_output_dir)
152
153
class GraphDrawer(object):
    '''Render a model graph to an image or HTML page via the dlconv visualizers.'''

    def __init__(self, toolkit, meta_path):
        self.toolkit = toolkit.lower()
        self.meta_path = meta_path

    def dump(self, graph_path):
        '''Render the graph described by ``meta_path`` into ``graph_path``.

        TensorFlow supports ``.html``/``.htm`` output. Keras supports ``.png``
        and ``.html``/``.htm``; for HTML, a temporary PNG is rendered, embedded
        into the page, and then deleted. Extensions are matched case-insensitively.

        Raises:
            NotImplementedError: for an unsupported toolkit or output extension.
        '''
        if self.toolkit == 'tensorflow':
            from dlconv.tensorflow.visualizer import TensorFlowVisualizer
            if not self._is_web_page(graph_path):
                raise NotImplementedError('Image format of %s is unsupported!' % graph_path)
            TensorFlowVisualizer(self.meta_path).dump_html(graph_path)
        elif self.toolkit == 'keras':
            from dlconv.keras.visualizer import KerasVisualizer
            if os.path.splitext(graph_path)[1].lower() == '.png':
                png_path, html_path = graph_path, None
            elif self._is_web_page(graph_path):
                png_path, html_path = graph_path + '.png', graph_path
            else:
                raise NotImplementedError('Image format of %s is unsupported!' % graph_path)
            KerasVisualizer(self.meta_path).dump_png(png_path)
            if html_path:
                try:
                    self._png_to_html(png_path, html_path)
                finally:
                    # Remove the intermediate PNG even if the HTML write fails.
                    os.remove(png_path)
        else:
            raise NotImplementedError('Visualization of %s is unsupported!' % self.toolkit)

    def _is_web_page(self, path):
        '''Return True if ``path`` ends in ``.html`` or ``.htm`` (any case).'''
        return os.path.splitext(path)[1].lower() in ('.html', '.htm')

    def _png_to_html(self, png_path, html_path):
        '''Write an HTML page to ``html_path`` embedding ``png_path`` as a base64 data URI.'''
        with open(png_path, "rb") as f:
            encoded = base64.b64encode(f.read()).decode('utf-8')
        source = """<!DOCTYPE html>
<html>
    <head>
        <meta charset="utf-8">
        <title>Keras</title>
    </head>
    <body>
        <img alt="Model Graph" src="data:image/png;base64,{base64_str}" />
    </body>
</html>""".format(base64_str=encoded)
        with open(html_path, 'w', encoding='utf-8') as f:
            f.write(source)