def tokenize(s):
    """Split protobuf text-format content into a flat list of tokens.

    Outside of strings, whitespace and the characters ``:``, ``;`` and ``,``
    separate tokens and are dropped. Each of ``{ } [ ]`` becomes a token of
    its own. A single or double quote opens a string, and the next quote of
    either kind closes it. The text in between becomes one token with the
    quotes stripped (an empty string yields an empty token). No escape
    sequences are interpreted. An unquoted ``#`` starts a comment that lasts
    until the end of the line.

    :param s: text to tokenize.
    :return: list of token strings.
    """
    tokens = []
    token = ""
    isString = False
    isComment = False
    for symbol in s:
        # A comment starts at an unquoted '#' and lasts until the newline.
        isComment = (isComment and symbol != '\n') or (not isString and symbol == '#')
        if isComment:
            continue

        if symbol in ' \t\r\n:;,\'"':
            isQuote = symbol in '\'"'
            if isQuote and isString:
                # Closing quote: emit the string, even when it is empty.
                tokens.append(token)
                token = ""
            elif isString:
                # Separators inside a string belong to its value.
                token += symbol
            elif token:
                tokens.append(token)
                token = ""
            # Any quote character toggles string mode.
            isString = isQuote ^ isString

        elif symbol in '{}[]':
            if token:
                tokens.append(token)
                token = ""
            tokens.append(symbol)
        else:
            token += symbol
    if token:
        tokens.append(token)
    return tokens


def parseMessage(tokens, idx):
    """Parse one ``{ ... }`` message from a token list.

    Every field maps to a list of values, because text-format fields may
    repeat. Values are raw token strings or, for nested messages, dicts of
    the same shape. Inside a ``[ ... ]`` list, every element is appended
    under the field name that preceded the ``[``.

    :param tokens: tokens produced by :func:`tokenize`.
    :param idx: index of the opening ``'{'`` token.
    :return: ``(msg, idx)``, where ``idx`` is the index of the matching
        ``'}'``. Returns ``None`` if the tokens run out before the message
        (or any nested message) is closed.
    """
    msg = {}
    assert tokens[idx] == '{'

    isArray = False
    while True:
        # Inside a list, the field name is reused for every element.
        if not isArray:
            idx += 1
            if idx >= len(tokens):
                return None
            fieldName = tokens[idx]
            if fieldName == '}':
                break

        idx += 1
        if idx >= len(tokens):
            return None
        fieldValue = tokens[idx]

        if fieldValue == '{':
            parsed = parseMessage(tokens, idx)
            if parsed is None:
                # Propagate truncation instead of failing to unpack None.
                return None
            embeddedMsg, idx = parsed
            msg.setdefault(fieldName, []).append(embeddedMsg)
        elif fieldValue == '[':
            isArray = True
        elif fieldValue == ']':
            isArray = False
        else:
            msg.setdefault(fieldName, []).append(fieldValue)
    return msg, idx


def readTextMessage(filePath):
    """Read a text-format message file into a nested dict.

    The file content is wrapped in braces so that its top-level fields parse
    as a single message.

    :return: the parsed dict. Returns ``{}`` when ``filePath`` is falsy or
        the content could not be parsed as a closed message.
    """
    if not filePath:
        return {}
    with open(filePath, 'rt') as f:
        content = f.read()

    tokens = tokenize('{' + content + '}')
    msg = parseMessage(tokens, 0)
    return msg[0] if msg else {}


def listToTensor(values):
    """Wrap a 1-D list of numbers in a tensor-attribute dict.

    An all-float list becomes ``DT_FLOAT``/``float_val``, and an all-int list
    becomes ``DT_INT32``/``int_val``. An empty list satisfies the float
    check. ``bool`` items count as ints, since ``bool`` subclasses ``int``.

    :raises Exception: if the values are neither all floats nor all ints.
    """
    if all(isinstance(v, float) for v in values):
        dtype = 'DT_FLOAT'
        field = 'float_val'
    elif all(isinstance(v, int) for v in values):
        dtype = 'DT_INT32'
        field = 'int_val'
    else:
        raise Exception('Wrong values types')

    msg = {
        'tensor': {
            'dtype': dtype,
            'tensor_shape': {
                'dim': {
                    'size': len(values)
                }
            }
        }
    }
    msg['tensor'][field] = values
    return msg


def addConstNode(name, values, graph_def):
    """Append a ``Const`` node named ``name`` whose ``value`` attr is ``values``."""
    node = NodeDef()
    node.name = name
    node.op = 'Const'
    node.addAttr('value', values)
    graph_def.node.extend([node])


def addSlice(inp, out, begins, sizes, graph_def):
    """Append a ``Slice`` node ``out`` reading ``inp``.

    Also appends its ``out/begins`` and ``out/sizes`` Const inputs.
    """
    beginsNode = NodeDef()
    beginsNode.name = out + '/begins'
    beginsNode.op = 'Const'
    beginsNode.addAttr('value', begins)
    graph_def.node.extend([beginsNode])

    sizesNode = NodeDef()
    sizesNode.name = out + '/sizes'
    sizesNode.op = 'Const'
    sizesNode.addAttr('value', sizes)
    graph_def.node.extend([sizesNode])

    sliced = NodeDef()
    sliced.name = out
    sliced.op = 'Slice'
    sliced.input.append(inp)
    sliced.input.append(beginsNode.name)
    sliced.input.append(sizesNode.name)
    graph_def.node.extend([sliced])


def addReshape(inp, out, shape, graph_def):
    """Append a ``Reshape`` node ``out`` reading ``inp``.

    Also appends its ``out/shape`` Const input.
    """
    shapeNode = NodeDef()
    shapeNode.name = out + '/shape'
    shapeNode.op = 'Const'
    shapeNode.addAttr('value', shape)
    graph_def.node.extend([shapeNode])

    reshape = NodeDef()
    reshape.name = out
    reshape.op = 'Reshape'
    reshape.input.append(inp)
    reshape.input.append(shapeNode.name)
    graph_def.node.extend([reshape])


def addSoftMax(inp, out, graph_def):
    """Append a ``Softmax`` node ``out`` over the last axis of ``inp``."""
    softmax = NodeDef()
    softmax.name = out
    softmax.op = 'Softmax'
    softmax.addAttr('axis', -1)
    softmax.input.append(inp)
    graph_def.node.extend([softmax])


def addFlatten(inp, out, graph_def):
    """Append a ``Flatten`` node ``out`` reading ``inp``."""
    flatten = NodeDef()
    flatten.name = out
    flatten.op = 'Flatten'
    flatten.input.append(inp)
    graph_def.node.extend([flatten])


class NodeDef:
    """Plain-Python graph node with ``name``, ``op``, ``input`` and ``attr``."""

    def __init__(self):
        self.input = []  # names of input nodes
        self.name = ""
        self.op = ""
        self.attr = {}  # attribute name -> value dict (see addAttr)

    def addAttr(self, key, value):
        """Add attribute ``key``, encoding ``value`` by its Python type.

        bool -> ``{'b': v}``, int -> ``{'i': v}``, float -> ``{'f': v}``,
        str -> ``{'s': v}``, list -> tensor dict from :func:`listToTensor`.
        ``bool`` is tested first because it is a subclass of ``int``.

        :raises Exception: for any other value type.
        """
        assert key not in self.attr
        if isinstance(value, bool):
            self.attr[key] = {'b': value}
        elif isinstance(value, int):
            self.attr[key] = {'i': value}
        elif isinstance(value, float):
            self.attr[key] = {'f': value}
        elif isinstance(value, str):
            self.attr[key] = {'s': value}
        elif isinstance(value, list):
            self.attr[key] = listToTensor(value)
        else:
            raise Exception('Unknown type of attribute ' + key)

    def Clear(self):
        """Reset every field to its empty default."""
        self.input = []
        self.name = ""
        self.op = ""
        self.attr = {}


class GraphDef:
    """Ordered list of :class:`NodeDef` objects, savable as text format."""

    def __init__(self):
        self.node = []

    def save(self, filePath):
        """Write every node to ``filePath`` in protobuf text format.

        Attributes and nested dict keys are written in case-insensitive
        sorted order. A list value is written as a repeated field, one line
        (or sub-message) per element.
        """
        with open(filePath, 'wt') as f:

            def printAttr(d, indent):
                # Recursively write dict ``d`` with ``indent`` leading spaces.
                indent = ' ' * indent
                for key, value in sorted(d.items(), key=lambda x: x[0].lower()):
                    value = value if isinstance(value, list) else [value]
                    for v in value:
                        if isinstance(v, dict):
                            f.write(indent + key + ' {\n')
                            printAttr(v, len(indent) + 2)
                            f.write(indent + '}\n')
                            continue

                        # Quote strings, except dtype enums (DT_*) and
                        # strings that parse as numbers.
                        isString = False
                        if isinstance(v, str) and not v.startswith('DT_'):
                            try:
                                float(v)
                            except ValueError:
                                isString = True

                        if isinstance(v, bool):
                            printed = 'true' if v else 'false'
                        elif v == 'true' or v == 'false':
                            printed = 'true' if v == 'true' else 'false'
                        elif isString:
                            # No escaping: tokenize() does not unescape either.
                            printed = '"%s"' % v
                        else:
                            printed = str(v)
                        f.write(indent + key + ': ' + printed + '\n')

            for node in self.node:
                f.write('node {\n')
                f.write('  name: "%s"\n' % node.name)
                f.write('  op: "%s"\n' % node.op)
                for inp in node.input:
                    f.write('  input: "%s"\n' % inp)
                for key, value in sorted(node.attr.items(), key=lambda x: x[0].lower()):
                    f.write('  attr {\n')
                    f.write('    key: "%s"\n' % key)
                    f.write('    value {\n')
                    printAttr(value, 6)
                    f.write('    }\n')
                    f.write('  }\n')
                f.write('}\n')


def parseTextGraph(filePath):
    """Load a text-format graph file into a :class:`GraphDef`.

    ``name`` and ``op`` take the first parsed value of the field. Each
    ``attr`` entry is stored as ``attr[key] = value``, where the value is
    the parsed dict, whose fields are lists of raw strings or nested dicts.

    :raises KeyError: if the file has no ``node`` entries.
    """
    msg = readTextMessage(filePath)

    graph = GraphDef()
    for node in msg['node']:
        graphNode = NodeDef()
        graphNode.name = node['name'][0]
        graphNode.op = node['op'][0]
        graphNode.input = node['input'] if 'input' in node else []

        if 'attr' in node:
            for attr in node['attr']:
                graphNode.attr[attr['key'][0]] = attr['value'][0]

        graph.node.append(graphNode)
    return graph


def removeIdentity(graph_def):
    """Remove ``Identity``/``IdentityN`` nodes and rewire their consumers.

    Any input naming a removed node is replaced by the node found by
    following first inputs through the Identity chain. Only ``input[0]`` of
    an ``IdentityN`` is forwarded.
    """
    identities = {}
    # Iterate over a snapshot: removing from the list being iterated would
    # skip the node right after each removed one.
    for node in list(graph_def.node):
        if node.op == 'Identity' or node.op == 'IdentityN':
            identities[node.name] = node.input[0]
            graph_def.node.remove(node)

    def resolve(name):
        # Follow Identity -> Identity chains independent of node order.
        # ``seen`` stops a (malformed) cycle from looping forever.
        seen = set()
        while name in identities and name not in seen:
            seen.add(name)
            name = identities[name]
        return name

    for node in graph_def.node:
        for i, inp in enumerate(node.input):
            if inp in identities:
                node.input[i] = resolve(inp)


def removeUnusedNodesAndAttrs(to_remove, graph_def):
    """Delete nodes selected by ``to_remove`` and strip unused attributes.

    :param to_remove: callable ``(name, op) -> bool``; nodes for which it
        returns true are deleted.
    :param graph_def: graph whose ``node`` list is edited in place.

    Kept nodes lose the attributes listed in ``unusedAttrs``. Inputs that
    reference removed nodes are dropped, except inputs referencing removed
    ``Const`` nodes, which are left in place (presumably because callers
    re-add those constants separately; TODO confirm).
    """
    unusedAttrs = ['T', 'Tshape', 'N', 'Tidx', 'Tdim', 'use_cudnn_on_gpu',
                   'Index', 'Tperm', 'is_training', 'Tpaddings']

    removedNodes = set()

    # Walk backwards so that deleting by index does not shift pending items.
    for i in reversed(range(len(graph_def.node))):
        op = graph_def.node[i].op
        name = graph_def.node[i].name

        if to_remove(name, op):
            if op != 'Const':
                removedNodes.add(name)

            del graph_def.node[i]
        else:
            for attr in unusedAttrs:
                if attr in graph_def.node[i].attr:
                    del graph_def.node[i].attr[attr]

    # Remove references to removed nodes except Const nodes.
    for node in graph_def.node:
        for i in reversed(range(len(node.input))):
            if node.input[i] in removedNodes:
                del node.input[i]


def writeTextGraph(modelPath, outputPath, outNodes):
    """Convert a binary frozen graph at ``modelPath`` to text at ``outputPath``.

    Tries OpenCV's ``cv.dnn.writeTextGraph`` first. If that raises (for
    example, OpenCV is missing), falls back to TensorFlow 1.x APIs. The
    fallback sorts the graph by execution order between ``image_tensor`` and
    ``outNodes``, and blanks the ``tensor_content`` of Const nodes to keep
    the text file small.
    """
    try:
        import cv2 as cv

        cv.dnn.writeTextGraph(modelPath, outputPath)
    except Exception:
        import tensorflow as tf
        from tensorflow.tools.graph_transforms import TransformGraph

        with tf.gfile.FastGFile(modelPath, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())

        graph_def = TransformGraph(graph_def, ['image_tensor'], outNodes, ['sort_by_execution_order'])

        for node in graph_def.node:
            if node.op == 'Const':
                if 'value' in node.attr and node.attr['value'].tensor.tensor_content:
                    node.attr['value'].tensor.tensor_content = b''

        tf.train.write_graph(graph_def, "", outputPath, as_text=True)