import base64
from google.protobuf import json_format
from importlib import import_module
import json
import numpy as np
import os
import sys

from mmdnn.conversion.caffe.errors import ConversionError
from mmdnn.conversion.caffe.common_graph import fetch_attr_value
from mmdnn.conversion.caffe.utils import get_lower_case, get_upper_case, get_real_name


class JsonFormatter(object):
    '''Dump a DL graph into a JSON file.'''

    def __init__(self, graph):
        # Snapshot the graph as a protobuf message once; `dump` serializes it.
        self.graph_def = graph.as_graph_def()

    def dump(self, json_path):
        '''Write the graph definition to `json_path` as pretty-printed JSON.

        The protobuf JSON text is re-parsed and re-dumped so that the file
        uses a 4-space indent and sorted keys.
        '''
        json_txt = json_format.MessageToJson(self.graph_def)
        parsed = json.loads(json_txt)
        formatted = json.dumps(parsed, indent=4, sort_keys=True)
        with open(json_path, 'w') as f:
            f.write(formatted)


class PyWriter(object):
    '''Dump a DL graph into a Python script.'''

    # Supported (lower-cased) target names -> base class the generated
    # network class derives from.
    _NETWORK_CLASSES = {
        'tensorflow': 'TensorFlowNetwork',
        'keras': 'KerasNetwork',
        'caffe': 'CaffeNetwork',
    }

    def __init__(self, graph, data, target):
        '''
        Args:
            graph: graph to emit; must provide `name`, `node_dict`,
                `get_node()` and `topologically_sorted()`.
            data: parameter data written next to the script with `np.save`.
            target: target framework name (case-insensitive): 'tensorflow',
                'keras' or 'caffe'.

        Raises:
            ConversionError: if `target` is not supported.
        '''
        self.graph = graph
        self.data = data
        self.tab = ' ' * 4
        self.prefix = ''
        target = target.lower()
        if target not in self._NETWORK_CLASSES:
            raise ConversionError('Target %s is not supported yet.' % target)
        self.target = target
        self.net = self._NETWORK_CLASSES[target]

    def indent(self):
        '''Increase the indentation of subsequently emitted statements.'''
        self.prefix += self.tab

    def outdent(self):
        '''Decrease the indentation of subsequently emitted statements.'''
        self.prefix = self.prefix[:-len(self.tab)]

    def statement(self, s):
        '''Return `s` as one source line at the current indentation.'''
        return self.prefix + s + '\n'

    def emit_imports(self):
        '''Return the import line for the target network base class.'''
        return self.statement('from dlconv.%s import %s\n' % (self.target, self.net))

    def emit_class_def(self, name):
        '''Return the header of the generated network class `name`.'''
        return self.statement('class %s(%s):' % (name, self.net))

    def emit_setup_def(self):
        '''Return the header of the generated `setup` method.'''
        return self.statement('def setup(self):')

    def emit_node(self, node):
        '''Emit the Python source line for this node.

        The line has the form
        ``<out> = self.<op>(<parent outputs...>, <attr>=<value>..., name='<node>')``.

        Raises:
            ConversionError: if an input refers to a node missing from the graph.
        '''

        def pair(key, value):
            return '%s=%s' % (key, value)

        args = []
        for ref in node.input:
            # Inputs look like '<node name>:<output index>'. Only the last ':'
            # separates the index, so names containing ':' survive intact.
            # A bare name refers to output 0.
            name, sep, idx = ref.strip().rpartition(':')
            if not sep:
                name, idx = idx, '0'
            if name not in self.graph.node_dict:
                raise ConversionError('Input %s of node %s is not in the graph.'
                                      % (name, node.name))
            parent = self.graph.get_node(name)
            args.append(parent.output[int(idx)])
        # FIXME: only the first output is bound in the generated code.
        output = [node.output[0]]
        attrs = node.attr
        if hasattr(attrs, 'items'):
            # Iterating a mapping (dict / protobuf map field) yields only keys.
            # Use (key, value) pairs, sorted by key so the output is stable.
            attrs = sorted(attrs.items(), key=lambda kv: kv[0])
        for k, v in attrs:
            value = fetch_attr_value(v)
            if k == 'cell_type':
                # cell_type is emitted as a quoted string literal.
                value = "'" + value + "'"
            args.append(pair(k, value))
        args.append(pair('name', "'" + node.name + "'"))  # Set the node name
        return self.statement('%s = self.%s(%s)'
                              % (', '.join(output), node.op, ', '.join(args)))

    def dump(self, code_output_dir):
        '''Write the generated script and its parameter data to `code_output_dir`.

        Returns:
            (code_output_path, data_output_path): paths of the written '.py'
            and '.npy' files. Both are named after the lower-cased graph name.
        '''
        os.makedirs(code_output_dir, exist_ok=True)
        file_name = get_lower_case(self.graph.name)
        code_output_path = os.path.join(code_output_dir, file_name + '.py')
        data_output_path = os.path.join(code_output_dir, file_name + '.npy')
        with open(code_output_path, 'w') as f:
            f.write(self.emit())
        with open(data_output_path, 'wb') as f:
            np.save(f, self.data)
        return code_output_path, data_output_path

    def emit(self):
        '''Return the Python source of a network class built from the graph.

        Nodes are grouped into chains. A node with exactly one input, whose
        producer is the tail of an existing chain, extends that chain. Every
        other node starts a new chain. Each chain becomes one block of
        statements in `setup`, and blocks are separated by a blank line.
        '''
        # Decompose the DAG into chains. Each chain is indexed by the name of
        # its current tail node.
        chains = []
        chain_by_tail = {}
        for node in self.graph.topologically_sorted():
            chain = None
            if len(node.input) == 1:
                # Node is part of an existing chain if its producer is a tail.
                chain = chain_by_tail.pop(get_real_name(node.input[0]), None)
            if chain is None:  # Start a new chain for this node.
                chain = []
                chains.append(chain)
            chain.append(node)
            chain_by_tail[node.name] = chain

        # Generate Python code line by line.
        source = self.emit_imports()
        source += self.emit_class_def(self.graph.name)
        self.indent()
        try:
            source += self.emit_setup_def()
            self.indent()
            try:
                blocks = []
                for chain in chains:
                    # Drop each block's trailing newline; the join adds blank lines.
                    blocks.append(''.join(self.emit_node(node) for node in chain)[:-1])
                source += '\n\n'.join(blocks)
            finally:
                self.outdent()
        finally:
            # Restore the indentation so repeated calls yield identical output.
            self.outdent()
        return source + '\n'


class ModelSaver(object):
    '''Build framework model files from a generated network script.'''

    def __init__(self, code_output_path, data_output_path):
        self.code_output_path = code_output_path
        self.data_output_path = data_output_path

    def dump(self, model_output_dir):
        '''Return the file path containing graph in generated model files.

        The generated script is imported as a module. Its class named
        `get_upper_case(<module name>)` is looked up, and that class's `dump`
        is called with the parameter data path and `model_output_dir`.
        '''
        os.makedirs(model_output_dir, exist_ok=True)
        code_dir = os.path.dirname(self.code_output_path)
        # Make the generated script importable without accumulating duplicate
        # sys.path entries on repeated calls.
        if code_dir not in sys.path:
            sys.path.append(code_dir)
        file_name = os.path.splitext(os.path.basename(self.code_output_path))[0]
        module = import_module(file_name)
        net = getattr(module, get_upper_case(file_name))
        return net.dump(self.data_output_path, model_output_dir)


class GraphDrawer(object):
    '''Render a model graph of a given toolkit for visualization.'''

    def __init__(self, toolkit, meta_path):
        self.toolkit = toolkit.lower()
        self.meta_path = meta_path

    def dump(self, graph_path):
        '''Render the graph to `graph_path`.

        TensorFlow graphs can only be written as HTML ('.html'/'.htm').
        Keras graphs can be written as PNG ('.png'), or as HTML with the PNG
        embedded inline.

        Raises:
            NotImplementedError: for an unsupported toolkit or output format.
        '''
        if self.toolkit == 'tensorflow':
            from dlconv.tensorflow.visualizer import TensorFlowVisualizer
            if self._is_web_page(graph_path):
                TensorFlowVisualizer(self.meta_path).dump_html(graph_path)
            else:
                raise NotImplementedError('Image format of %s is unsupported!' % graph_path)
        elif self.toolkit == 'keras':
            from dlconv.keras.visualizer import KerasVisualizer
            png_path, html_path = (None, None)
            if graph_path.endswith('.png'):
                png_path = graph_path
            elif self._is_web_page(graph_path):
                # Render to a temporary PNG next to the target, then inline it.
                png_path = graph_path + ".png"
                html_path = graph_path
            else:
                raise NotImplementedError('Image format of %s is unsupported!' % graph_path)
            KerasVisualizer(self.meta_path).dump_png(png_path)
            if html_path:
                try:
                    self._png_to_html(png_path, html_path)
                finally:
                    # Remove the temporary PNG even if the conversion fails.
                    os.remove(png_path)
        else:
            raise NotImplementedError('Visualization of %s is unsupported!' % self.toolkit)

    def _is_web_page(self, path):
        '''Return True if `path` has an '.html' or '.htm' extension (any case).'''
        return os.path.splitext(path)[1].lower() in ('.html', '.htm')

    def _png_to_html(self, png_path, html_path):
        '''Write an HTML page to `html_path` embedding `png_path` as a data URI.'''
        with open(png_path, "rb") as f:
            encoded = base64.b64encode(f.read()).decode('utf-8')
        source = """<!DOCTYPE html>
<html>
    <head>
        <meta charset="utf-8">
        <title>Keras</title>
    </head>
    <body>
        <img alt="Model Graph" src="data:image/png;base64,{base64_str}" />
    </body>
</html>""".format(base64_str=encoded)
        with open(html_path, 'w', encoding='utf-8') as f:
            f.write(source)