def tokenize(s):
    """Split protobuf text-format content into a flat list of string tokens.

    Whitespace, ``:``, ``;`` and ``,`` separate tokens and are dropped.
    ``{``, ``}``, ``[`` and ``]`` are emitted as single-character tokens.

    A quoted string (``"..."`` or ``'...'``) becomes one token without its
    quotes, even when empty. It ends only at the same quote character that
    opened it. Everything inside it is kept verbatim, including separators,
    brackets, ``#`` and the other quote character. A backslash inside a
    string escapes the next character; both characters are kept as-is, so
    the escape sequence survives a later write-back.

    ``#`` outside a string starts a comment that runs to the end of the line.

    Tokens carry no quoting information, so a string whose value is, e.g.,
    ``{`` is indistinguishable from a bracket token downstream.

    Args:
        s: the text to tokenize.

    Returns:
        list of str tokens, in input order.
    """
    separators = ' \t\r\n:;,'
    brackets = '{}[]'
    tokens = []
    token = ""
    quote = None      # quote character that opened the current string, or None
    escaped = False   # previous character inside the string was a backslash
    isComment = False
    for symbol in s:
        if quote is not None:
            if escaped:
                escaped = False
            elif symbol == '\\':
                escaped = True
            elif symbol == quote:
                # Closing quote: emit the string even if empty ("" is a value).
                tokens.append(token)
                token = ""
                quote = None
                continue
            token += symbol
            continue

        if isComment:
            if symbol != '\n':
                continue
            # The newline ends the comment and is then handled as a separator.
            isComment = False
        elif symbol == '#':
            isComment = True
            continue

        if symbol in '"\'':
            if token:
                tokens.append(token)
                token = ""
            quote = symbol
        elif symbol in separators or symbol in brackets:
            if token:
                tokens.append(token)
                token = ""
            if symbol in brackets:
                tokens.append(symbol)
        else:
            token += symbol

    if token:
        tokens.append(token)
    return tokens
36
37
def parseMessage(tokens, idx):
    """Parse one ``{ ... }`` message from *tokens*, starting at *idx*.

    ``tokens[idx]`` must be ``'{'``. Field names map to lists of values,
    because protobuf fields may repeat. Each value is either a raw string
    token or a nested message dict. The list syntax ``name: [a, b]`` is
    flattened into the same list as repeated ``name: a`` entries.

    Args:
        tokens: token list as produced by ``tokenize``.
        idx: index of the opening ``'{'``.

    Returns:
        ``(msg, idx)``, where *idx* is the index of the matching ``'}'``.
        Returns ``None`` if the tokens run out before the message (or any
        message nested inside it) is closed.
    """
    assert(tokens[idx] == '{')
    msg = {}

    isArray = False
    while True:
        if not isArray:
            idx += 1
            if idx >= len(tokens):
                return None
            fieldName = tokens[idx]
            if fieldName == '}':
                break

        idx += 1
        if idx >= len(tokens):
            # Input ended after a field name: report truncation rather than
            # raising IndexError.
            return None
        fieldValue = tokens[idx]

        if fieldValue == '{':
            parsed = parseMessage(tokens, idx)
            if parsed is None:
                # Propagate truncation from the nested message instead of
                # failing to unpack None.
                return None
            embeddedMsg, idx = parsed
            msg.setdefault(fieldName, []).append(embeddedMsg)
        elif fieldValue == '[':
            isArray = True
        elif fieldValue == ']':
            isArray = False
        else:
            msg.setdefault(fieldName, []).append(fieldValue)
    return msg, idx
72
73
def readTextMessage(filePath):
    """Read a protobuf text-format file into nested dicts.

    The file content is wrapped in ``{ }`` so the whole file parses as one
    message; see ``parseMessage`` for the structure of the result.

    Returns:
        the top-level dict. Returns ``{}`` when *filePath* is empty/None or
        when ``parseMessage`` reports that it could not finish (returns None).
    """
    if not filePath:
        return {}

    with open(filePath, 'rt') as f:
        wrapped = '{' + f.read() + '}'

    parsed = parseMessage(tokenize(wrapped), 0)
    if not parsed:
        return {}
    return parsed[0]
83
84
def listToTensor(values):
    """Build a ``{'tensor': {...}}`` attribute value holding 1-D *values*.

    The dtype and value field are chosen from the Python element types:

    * all ``float`` -> ``DT_FLOAT`` / ``float_val``
    * all ``bool``  -> ``DT_BOOL``  / ``bool_val``
    * all ``int``   -> ``DT_INT32`` / ``int_val``

    An empty list is treated as float. The shape is a single dimension of
    size ``len(values)``. *values* is stored by reference, not copied.

    Raises:
        Exception: if the elements do not all share one of these types.
    """
    if all(isinstance(v, float) for v in values):
        dtype, field = 'DT_FLOAT', 'float_val'
    elif all(isinstance(v, bool) for v in values):
        # bool is a subclass of int, so test it before int. Otherwise
        # [True, False] would be labelled DT_INT32 and saved by
        # GraphDef.save as "int_val: true".
        dtype, field = 'DT_BOOL', 'bool_val'
    elif all(isinstance(v, int) for v in values):
        dtype, field = 'DT_INT32', 'int_val'
    else:
        raise Exception('Wrong values types')

    msg = {
        'tensor': {
            'dtype': dtype,
            'tensor_shape': {
                'dim': {
                    'size': len(values)
                }
            }
        }
    }
    msg['tensor'][field] = values
    return msg
107
108
def addConstNode(name, values, graph_def):
    """Append a ``Const`` node called *name* to *graph_def*.

    *values* becomes the node's ``value`` attribute via ``NodeDef.addAttr``:
    a list is turned into a tensor, and a scalar into a plain attribute.
    """
    constNode = NodeDef()
    constNode.name, constNode.op = name, 'Const'
    constNode.addAttr('value', values)
    graph_def.node.extend([constNode])
115
116
def addSlice(inp, out, begins, sizes, graph_def):
    """Append a ``Slice`` node *out* reading from *inp* to *graph_def*.

    Two ``Const`` nodes, ``out + '/begins'`` and ``out + '/sizes'``, hold
    *begins* and *sizes*. They are appended first, in that order, and are
    wired as the second and third inputs of the ``Slice`` node.
    """
    def makeConst(suffix, value):
        # Create, register and return the name of one companion Const node.
        const = NodeDef()
        const.name = out + suffix
        const.op = 'Const'
        const.addAttr('value', value)
        graph_def.node.extend([const])
        return const.name

    beginsName = makeConst('/begins', begins)
    sizesName = makeConst('/sizes', sizes)

    sliceNode = NodeDef()
    sliceNode.name, sliceNode.op = out, 'Slice'
    sliceNode.input.extend([inp, beginsName, sizesName])
    graph_def.node.extend([sliceNode])
137
138
def addReshape(inp, out, shape, graph_def):
    """Append a ``Reshape`` of *inp* to *shape*, named *out*, to *graph_def*.

    The target shape is stored in a ``Const`` node named ``out + '/shape'``.
    That node is appended before the ``Reshape`` node and used as its
    second input.
    """
    shapeName = out + '/shape'

    const = NodeDef()
    const.name, const.op = shapeName, 'Const'
    const.addAttr('value', shape)
    graph_def.node.extend([const])

    reshapeNode = NodeDef()
    reshapeNode.name, reshapeNode.op = out, 'Reshape'
    reshapeNode.input.extend([inp, shapeName])
    graph_def.node.extend([reshapeNode])
152
153
def addSoftMax(inp, out, graph_def):
    """Append a ``Softmax`` node *out* over *inp*, with ``axis`` = -1."""
    softmaxNode = NodeDef()
    softmaxNode.name, softmaxNode.op = out, 'Softmax'
    softmaxNode.input.append(inp)
    softmaxNode.addAttr('axis', -1)
    graph_def.node.extend([softmaxNode])
161
162
def addFlatten(inp, out, graph_def):
    """Append a ``Flatten`` node named *out* that reads from *inp*."""
    flattenNode = NodeDef()
    flattenNode.name, flattenNode.op = out, 'Flatten'
    flattenNode.input.append(inp)
    graph_def.node.extend([flattenNode])
169
170
class NodeDef:
    """Minimal in-memory stand-in for a TensorFlow ``NodeDef``.

    Attributes:
        input: list of input node names.
        name: node name.
        op: operation type string, e.g. ``'Const'``.
        attr: dict mapping attribute name to a text-format value dict.
    """

    def __init__(self):
        self.input = []
        self.name = ""
        self.op = ""
        self.attr = {}

    def addAttr(self, key, value):
        """Add attribute *key*, wrapping *value* by its Python type.

        ``bool`` -> ``{'b': v}``, ``int`` -> ``{'i': v}``,
        ``float`` -> ``{'f': v}``, ``str`` -> ``{'s': v}``,
        ``list`` -> tensor built by ``listToTensor``. ``bool`` is tested
        before ``int`` because it is an ``int`` subclass.

        Raises:
            Exception: if *key* is already set, or *value* has an
                unsupported type.
        """
        # Explicit check instead of ``assert`` so it is not stripped under -O.
        if key in self.attr:
            raise Exception('Attribute ' + key + ' is already set')
        if isinstance(value, bool):
            self.attr[key] = {'b': value}
        elif isinstance(value, int):
            self.attr[key] = {'i': value}
        elif isinstance(value, float):
            self.attr[key] = {'f': value}
        elif isinstance(value, str):
            self.attr[key] = {'s': value}
        elif isinstance(value, list):
            self.attr[key] = listToTensor(value)
        else:
            raise Exception('Unknown type of attribute ' + key)

    def Clear(self):
        """Reset the node to the empty state produced by ``__init__``."""
        self.input = []
        self.name = ""
        self.op = ""
        self.attr = {}
198
199
class GraphDef:
    """Minimal in-memory stand-in for a TensorFlow ``GraphDef``.

    ``node`` is an ordered list of node objects. ``save`` reads only their
    ``name``, ``op``, ``input`` and ``attr`` attributes.
    """

    def __init__(self):
        self.node = []

    def save(self, filePath):
        """Write the graph to *filePath* in protobuf text format.

        Nodes are written in list order. Attribute keys, and the keys of
        every nested dict, are sorted case-insensitively.

        Values are written as follows:

        * dicts become nested ``key { ... }`` blocks;
        * lists repeat the key once per element;
        * bools and the strings ``'true'``/``'false'`` are written bare;
        * other strings are quoted, unless they start with ``'DT_'`` or
          parse as a float (presumably because values read back via
          ``tokenize`` have lost their quotes, so enum names and numbers
          arrive as plain strings);
        * anything else is written with ``str()``.

        String contents are not escaped.
        """
        def formatScalar(v):
            # Render one non-dict value as it should appear after "key: ".
            if isinstance(v, bool):
                return 'true' if v else 'false'
            if v == 'true' or v == 'false':
                return v
            if isinstance(v, str) and not v.startswith('DT_'):
                try:
                    float(v)
                except ValueError:
                    return '"%s"' % v
            return str(v)

        with open(filePath, 'wt') as f:

            def printAttr(d, indent):
                # Recursively write dict d, with every line prefixed by
                # `indent` spaces.
                pad = ' ' * indent
                for key, value in sorted(d.items(), key=lambda x: x[0].lower()):
                    for v in (value if isinstance(value, list) else [value]):
                        if isinstance(v, dict):
                            f.write(pad + key + ' {\n')
                            printAttr(v, indent + 2)
                            f.write(pad + '}\n')
                        else:
                            f.write(pad + key + ': ' + formatScalar(v) + '\n')

            for node in self.node:
                f.write('node {\n')
                f.write('  name: "%s"\n' % node.name)
                f.write('  op: "%s"\n' % node.op)
                for inp in node.input:
                    f.write('  input: "%s"\n' % inp)
                for key, value in sorted(node.attr.items(), key=lambda x: x[0].lower()):
                    f.write('  attr {\n')
                    f.write('    key: "%s"\n' % key)
                    f.write('    value {\n')
                    printAttr(value, 6)
                    f.write('    }\n')
                    f.write('  }\n')
                f.write('}\n')
248
249
def parseTextGraph(filePath):
    """Load a text-format GraphDef file into a ``GraphDef`` of ``NodeDef``s.

    Each parsed ``node`` message becomes one ``NodeDef``:

    * ``name`` and ``op`` take the first value of those fields;
    * ``input`` keeps all values, in order;
    * each ``attr { key: ... value { ... } }`` entry is stored as
      ``attr[key] = <value dict>``.

    A file with no ``node`` entries, or an empty *filePath* (for which
    ``readTextMessage`` returns ``{}``), yields an empty graph instead of
    raising KeyError.
    """
    msg = readTextMessage(filePath)

    graph = GraphDef()
    for node in msg.get('node', []):
        graphNode = NodeDef()
        graphNode.name = node['name'][0]
        graphNode.op = node['op'][0]
        graphNode.input = node.get('input', [])

        for attr in node.get('attr', []):
            graphNode.attr[attr['key'][0]] = attr['value'][0]

        graph.node.append(graphNode)
    return graph
266
267
def removeIdentity(graph_def):
    """Remove ``Identity``/``IdentityN`` nodes and rewire their consumers.

    Each identity node is replaced by its first input. Any input that names
    a removed identity is redirected along the chain of identities, in any
    graph order, to the first non-identity producer. Only exact name
    matches are rewired. ``graph_def.node`` and each node's ``input`` list
    are modified in place.
    """
    identities = {}
    for node in graph_def.node:
        if node.op == 'Identity' or node.op == 'IdentityN':
            identities[node.name] = node.input[0]

    # Delete by index from the back. Calling remove() while iterating the
    # same list skips the node right after each removal, which left
    # back-to-back identities in the graph.
    for i in reversed(range(len(graph_def.node))):
        if graph_def.node[i].op in ('Identity', 'IdentityN'):
            del graph_def.node[i]

    def resolve(name):
        # Follow identity chains to the real producer; `seen` guards cycles.
        seen = set()
        while name in identities and name not in seen:
            seen.add(name)
            name = identities[name]
        return name

    for node in graph_def.node:
        for i, inp in enumerate(node.input):
            if inp in identities:
                node.input[i] = resolve(inp)
284
285
def removeUnusedNodesAndAttrs(to_remove, graph_def):
    """Delete nodes selected by *to_remove* and strip unused attributes.

    Args:
        to_remove: predicate ``(name, op) -> bool``. It is called once per
            node, last node first; nodes for which it returns true are
            deleted.
        graph_def: graph whose ``node`` list is modified in place.

    Surviving nodes lose any attribute listed in ``unusedAttrs``. Inputs
    that exactly match the name of a removed non-``Const`` node are dropped
    from every surviving node. References to removed ``Const`` nodes are
    deliberately left in place.
    """
    unusedAttrs = ('T', 'Tshape', 'N', 'Tidx', 'Tdim', 'use_cudnn_on_gpu',
                   'Index', 'Tperm', 'is_training', 'Tpaddings')

    # A set keeps the per-input membership test below O(1); a list made it
    # O(removed nodes) for every input of every node.
    removedNodes = set()

    for i in reversed(range(len(graph_def.node))):
        node = graph_def.node[i]
        op = node.op
        name = node.name

        if to_remove(name, op):
            if op != 'Const':
                removedNodes.add(name)

            del graph_def.node[i]
        else:
            for attr in unusedAttrs:
                if attr in node.attr:
                    del node.attr[attr]

    # Remove references to removed nodes, except Const nodes.
    for node in graph_def.node:
        for i in reversed(range(len(node.input))):
            if node.input[i] in removedNodes:
                del node.input[i]
311
312
def writeTextGraph(modelPath, outputPath, outNodes):
    """Convert the binary TensorFlow graph *modelPath* to text at *outputPath*.

    Tries OpenCV's ``cv.dnn.writeTextGraph`` first. If that raises for any
    reason (OpenCV not importable, the function missing, or the call
    failing), falls back to TensorFlow:

    * the graph is transformed with ``sort_by_execution_order``, using the
      input ``'image_tensor'`` and the outputs *outNodes*;
    * the raw ``tensor_content`` of every ``Const`` node is cleared
      (presumably to keep the text output small);
    * the result is written with ``tf.train.write_graph(as_text=True)``.

    NOTE(review): the fallback uses ``tf.gfile`` / ``tf.GraphDef``, which
    look like TF 1.x-style APIs. Confirm the installed TensorFlow provides
    them.
    """
    try:
        import cv2 as cv

        cv.dnn.writeTextGraph(modelPath, outputPath)
    except Exception:
        # Narrowed from a bare ``except:``, so KeyboardInterrupt/SystemExit
        # propagate instead of silently triggering the TensorFlow fallback.
        import tensorflow as tf
        from tensorflow.tools.graph_transforms import TransformGraph

        with tf.gfile.FastGFile(modelPath, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())

            graph_def = TransformGraph(graph_def, ['image_tensor'], outNodes, ['sort_by_execution_order'])

            for node in graph_def.node:
                if node.op == 'Const':
                    if 'value' in node.attr and node.attr['value'].tensor.tensor_content:
                        node.attr['value'].tensor.tensor_content = b''

        tf.train.write_graph(graph_def, "", outputPath, as_text=True)
334