1import argparse
2import numpy as np
3from tf_text_graph_common import *
4
# Command-line interface: all three paths are mandatory.
parser = argparse.ArgumentParser(description='Run this script to get a text graph of '
                                             'a Mask-RCNN model from the TensorFlow Object Detection API. '
                                             'Then pass it along with the .pb file to the '
                                             'cv::dnn::readNetFromTensorflow function.')
parser.add_argument('--input', required=True, help='Path to frozen TensorFlow graph.')
parser.add_argument('--output', required=True, help='Path to output text graph.')
parser.add_argument('--config', required=True,
                    help='Path to the *.config file that was used for training.')
args = parser.parse_args()
12
# Name prefixes of the nodes that survive pruning. They are matched with
# str.startswith, which accepts a tuple of prefixes.
scopesToKeep = (
    'FirstStageFeatureExtractor',
    'Conv',
    'FirstStageBoxPredictor/BoxEncodingPredictor',
    'FirstStageBoxPredictor/ClassPredictor',
    'CropAndResize',
    'MaxPool2D',
    'SecondStageFeatureExtractor',
    'SecondStageBoxPredictor',
    'Preprocessor/sub',
    'Preprocessor/mul',
    'image_tensor',
)

# Prefixes that are pruned even though they start with a kept prefix above:
# bookkeeping ops under FirstStageFeatureExtractor plus one 'Conv' sub-scope
# (presumably the atrous-convolution padding computation).
scopesToIgnore = tuple(
    'FirstStageFeatureExtractor/' + suffix
    for suffix in ('Assert', 'Shape', 'strided_slice', 'GreaterEqual', 'LogicalAnd')
) + ('Conv/required_space_to_batch_paddings',)
30
# Load the training pipeline config; only its faster_rcnn model section is read.
config = readTextMessage(args.config)['model'][0]['faster_rcnn'][0]
num_classes = int(config['num_classes'][0])

# First-stage (RPN) anchor and proposal settings.
grid_anchor_generator = config['first_stage_anchor_generator'][0]['grid_anchor_generator'][0]
scales = list(map(float, grid_anchor_generator['scales']))
aspect_ratios = list(map(float, grid_anchor_generator['aspect_ratios']))
width_stride, height_stride = (float(grid_anchor_generator[key][0])
                               for key in ('width_stride', 'height_stride'))
features_stride = float(config['feature_extractor'][0]['first_stage_features_stride'][0])
first_stage_nms_iou_threshold = float(config['first_stage_nms_iou_threshold'][0])
first_stage_max_proposals = int(config['first_stage_max_proposals'][0])

print('Number of classes: %d' % num_classes)
print('Scales:            %s' % str(scales))
print('Aspect ratios:     %s' % str(aspect_ratios))
print('Width stride:      %f' % width_stride)
print('Height stride:     %f' % height_stride)
print('Features stride:   %f' % features_stride)

# Convert the frozen graph to a text graph and parse it back for editing.
# The list passed to writeTextGraph presumably names the output nodes to
# preserve (TODO confirm). removeIdentity presumably strips Identity nodes.
writeTextGraph(args.input, args.output,
               ['num_detections', 'detection_scores', 'detection_boxes',
                'detection_classes', 'detection_masks'])
graph_def = parseTextGraph(args.output)
removeIdentity(graph_def)
57
# Node names that must never be pruned by to_remove(), e.g. Const nodes this
# script creates itself (to_remove would otherwise drop every Const).
nodesToKeep = []


def to_remove(name, op):
    """Pruning predicate handed to removeUnusedNodesAndAttrs.

    Returns False for any name listed in nodesToKeep. Otherwise it returns True
    in these cases:

    - Const nodes;
    - nodes under a scopesToIgnore prefix or outside every scopesToKeep prefix;
    - helper nodes under 'CropAndResize' whose op is not CropAndResize itself.

    True presumably means "remove this node".
    """
    if name in nodesToKeep:
        return False
    if op == 'Const':
        return True
    if name.startswith(scopesToIgnore) or not name.startswith(scopesToKeep):
        return True
    return name.startswith('CropAndResize') and op != 'CropAndResize'
64
# Fuse atrous convolutions (with dilations). The code expects the chain
# SpaceToBatchND -> conv -> BatchToSpaceND, found by following input[0] links.
# BatchToSpaceND loses its third input, and SpaceToBatchND's third input is
# rewired to a new constant paddings node.
nodesMap = {node.name: node for node in graph_def.node}
# Iterate over a snapshot. The loop inserts paddings nodes into
# graph_def.node, and inserting into a list while reverse-iterating it shifts
# the elements, so the live iterator would skip the node right before each
# processed BatchToSpaceND.
for node in list(reversed(graph_def.node)):
    if node.op != 'BatchToSpaceND':
        continue

    del node.input[2]  # presumably the crops input; not needed by OpenCV
    conv = nodesMap[node.input[0]]
    spaceToBatchND = nodesMap[conv.input[0]]

    paddingsNode = NodeDef()
    paddingsNode.name = conv.name + '/paddings'
    paddingsNode.op = 'Const'
    # NOTE(review): fixed paddings [2, 2, 2, 2]; presumably match the dilation
    # used by the supported models. TODO confirm for other dilation rates.
    paddingsNode.addAttr('value', [2, 2, 2, 2])
    graph_def.node.insert(graph_def.node.index(spaceToBatchND), paddingsNode)
    # Protect the new Const from to_remove(), which prunes every other Const.
    nodesToKeep.append(paddingsNode.name)

    spaceToBatchND.input[2] = paddingsNode.name

removeUnusedNodesAndAttrs(to_remove, graph_def)
83
84
# Connect the input placeholder to the first remaining layer. After pruning,
# the placeholder must be the very first node. Raise explicitly rather than
# assert, because asserts are stripped under `python -O`.
if graph_def.node[0].op != 'Placeholder':
    raise ValueError('Expected the first node of the pruned graph to be a '
                     'Placeholder, got %s' % graph_def.node[0].op)
graph_def.node[1].input.insert(0, graph_def.node[0].name)

# Temporarily detach the tail of the graph. Pop nodes off the end until two
# CropAndResize nodes have been removed; the popped nodes are kept in topNodes
# (last graph node first) and re-attached after the RPN nodes are inserted.
topNodes = []
numCropAndResize = 0
while numCropAndResize < 2:
    if not graph_def.node:
        raise ValueError('Expected at least two CropAndResize nodes in the '
                         'graph, found %d' % numCropAndResize)
    node = graph_def.node.pop()
    topNodes.append(node)
    if node.op == 'CropAndResize':
        numCropAndResize += 1
99
# First stage (RPN): rebuild proposal generation from OpenCV-style layers.
# - Objectness logits are reshaped to pairs ([0, -1, 2]), soft-maxed and
#   flattened.
# - Box encodings are flattened.
# - A PriorBox node ('proposals') generates the anchors.
# - A DetectionOutput node ('detection_out', CENTER_SIZE decoding plus the
#   config's NMS settings) turns them into proposals; later these are wired
#   into the CropAndResize nodes.
addReshape('FirstStageBoxPredictor/ClassPredictor/BiasAdd',
           'FirstStageBoxPredictor/ClassPredictor/reshape_1', [0, -1, 2], graph_def)

addSoftMax('FirstStageBoxPredictor/ClassPredictor/reshape_1',
           'FirstStageBoxPredictor/ClassPredictor/softmax', graph_def)  # Compare with Reshape_4

addFlatten('FirstStageBoxPredictor/ClassPredictor/softmax',
           'FirstStageBoxPredictor/ClassPredictor/softmax/flatten', graph_def)

# Compare with FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd
addFlatten('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd',
           'FirstStageBoxPredictor/BoxEncodingPredictor/flatten', graph_def)

# Anchor generator. Its inputs are a first-stage feature map and the input
# placeholder (graph_def.node[0], checked earlier to be the Placeholder).
# PriorBox presumably uses them to infer the feature-map and image sizes
# (TODO confirm).
proposals = NodeDef()
proposals.name = 'proposals'  # Compare with ClipToWindow/Gather/Gather (NOTE: normalized)
proposals.op = 'PriorBox'
proposals.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd')
proposals.input.append(graph_def.node[0].name)  # image_tensor

proposals.addAttr('flip', False)
proposals.addAttr('clip', True)
proposals.addAttr('step', features_stride)
proposals.addAttr('offset', 0.0)
proposals.addAttr('variance', [0.1, 0.1, 0.2, 0.2])

# One anchor size per (aspect ratio, scale) pair, with aspect ratio in the
# outer loop: height = stride**2 * s / sqrt(a), width = stride**2 * s * sqrt(a).
# NOTE(review): stride**2 appears to stand in for the base anchor size
# (e.g. 16**2 == 256); verify against the model's base anchor settings.
widths = []
heights = []
for a in aspect_ratios:
    for s in scales:
        ar = np.sqrt(a)
        heights.append((height_stride**2) * s / ar)
        widths.append((width_stride**2) * s * ar)

proposals.addAttr('width', widths)
proposals.addAttr('height', heights)

graph_def.node.extend([proposals])

# Compare with Reshape_5
# RPN decoding and NMS settings:
# - two classes, with label 0 as background, and shared box locations;
# - NMS IoU threshold and kept-proposal count come from the config;
# - top_k is hard-coded.
detectionOut = NodeDef()
detectionOut.name = 'detection_out'
detectionOut.op = 'DetectionOutput'

detectionOut.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/flatten')
detectionOut.input.append('FirstStageBoxPredictor/ClassPredictor/softmax/flatten')
detectionOut.input.append('proposals')

detectionOut.addAttr('num_classes', 2)
detectionOut.addAttr('share_location', True)
detectionOut.addAttr('background_label_id', 0)
detectionOut.addAttr('nms_threshold', first_stage_nms_iou_threshold)
detectionOut.addAttr('top_k', 6000)
detectionOut.addAttr('code_type', "CENTER_SIZE")
detectionOut.addAttr('keep_top_k', first_stage_max_proposals)
detectionOut.addAttr('clip', True)

graph_def.node.extend([detectionOut])
157
# Re-attach the detached tail nodes in their original graph order.
# topNodes[-1] is the node that was earliest in the graph. Walk from that
# end, moving nodes back into graph_def, and stop at the last CropAndResize.
# That node, and every node after it, stays in topNodes for later. The names
# of both CropAndResize nodes are recorded.
cropAndResizeNodesNames = []
while topNodes:
    node = topNodes[-1]
    if node.op == 'CropAndResize':
        cropAndResizeNodesNames.append(node.name)
        if numCropAndResize == 1:
            break
        numCropAndResize -= 1
    graph_def.node.extend([node])
    topNodes.pop()
172
# Second-stage class scores: softmax, then drop index 0 of the last axis (the
# slice begins at 1; presumably the background class), then reshape to [1, -1].
addSoftMax('SecondStageBoxPredictor/Reshape_1', 'SecondStageBoxPredictor/Reshape_1/softmax', graph_def)

addSlice('SecondStageBoxPredictor/Reshape_1/softmax',
         'SecondStageBoxPredictor/Reshape_1/slice',
         [0, 0, 1], [-1, -1, -1], graph_def)

addReshape('SecondStageBoxPredictor/Reshape_1/slice',
          'SecondStageBoxPredictor/Reshape_1/Reshape', [1, -1], graph_def)

# Rewire the restored second-stage subgraph and replace each Flatten subgraph
# with a single node. Walking indices from the end keeps
# `del graph_def.node[i]` safe for the indices not yet visited.
# NOTE(review): this also assumes addConstNode appends to the end of
# graph_def.node, so index i still refers to the same node. Confirm.
for i in reversed(range(len(graph_def.node))):
    # Feed the first-stage detections, as input 1, to every CropAndResize
    # currently in the graph.
    if graph_def.node[i].op == 'CropAndResize':
        graph_def.node[i].input.insert(1, 'detection_out')

    # Replace the last (shape) input of the box-regression Reshape with a
    # constant [1, -1, 4] shape.
    if graph_def.node[i].name == 'SecondStageBoxPredictor/Reshape':
        addConstNode('SecondStageBoxPredictor/Reshape/shape2', [1, -1, 4], graph_def)

        graph_def.node[i].input.pop()
        graph_def.node[i].input.append('SecondStageBoxPredictor/Reshape/shape2')

    # Drop the dynamic shape computations of both Flatten subgraphs; their
    # Reshape nodes become Flatten layers in the loop below.
    if graph_def.node[i].name in ['SecondStageBoxPredictor/Flatten/flatten/Shape',
                                  'SecondStageBoxPredictor/Flatten/flatten/strided_slice',
                                  'SecondStageBoxPredictor/Flatten/flatten/Reshape/shape',
                                  'SecondStageBoxPredictor/Flatten_1/flatten/Shape',
                                  'SecondStageBoxPredictor/Flatten_1/flatten/strided_slice',
                                  'SecondStageBoxPredictor/Flatten_1/flatten/Reshape/shape']:
        del graph_def.node[i]

for node in graph_def.node:
    # Turn each Flatten subgraph's Reshape into a Flatten layer, dropping its
    # (now deleted) shape input.
    if node.name == 'SecondStageBoxPredictor/Flatten/flatten/Reshape' or \
       node.name == 'SecondStageBoxPredictor/Flatten_1/flatten/Reshape':
        node.op = 'Flatten'
        node.input.pop()

    # NOTE(review): loc_pred_transposed presumably tells OpenCV that these box
    # predictions are laid out transposed. Confirm against the importer.
    if node.name in ['FirstStageBoxPredictor/BoxEncodingPredictor/Conv2D',
                     'SecondStageBoxPredictor/BoxEncodingPredictor/MatMul']:
        node.addAttr('loc_pred_transposed', True)

    # The MaxPool2D restored so far reads from the first recorded
    # CropAndResize. The second name is kept for the tail re-attached later.
    if node.name.startswith('MaxPool2D'):
        assert(node.op == 'MaxPool')
        assert(len(cropAndResizeNodesNames) == 2)
        node.input = [cropAndResizeNodesNames[0]]
        del cropAndResizeNodesNames[0]
216
################################################################################
### Postprocessing
################################################################################
# Build the second-stage DetectionOutput ('detection_out_final') from the
# first-stage proposals and the second-stage predictions. Then re-attach the
# mask branch.

# Take 4 values starting at column 3 of the last axis of 'detection_out'.
# This assumes addSlice uses begin/size like tf.slice; the values are
# presumably each proposal's box coordinates (TODO confirm).
addSlice('detection_out', 'detection_out/slice', [0, 0, 0, 3], [-1, -1, -1, 4], graph_def)

# Constant variances (the same values as the first-stage PriorBox 'variance'
# attribute), multiplied into the second-stage box regressions. The final
# DetectionOutput is told about this via variance_encoded_in_target=True.
variance = NodeDef()
variance.name = 'proposals/variance'
variance.op = 'Const'
variance.addAttr('value', [0.1, 0.1, 0.2, 0.2])
graph_def.node.extend([variance])

varianceEncoder = NodeDef()
varianceEncoder.name = 'variance_encoded'
varianceEncoder.op = 'Mul'
varianceEncoder.input.append('SecondStageBoxPredictor/Reshape')
varianceEncoder.input.append(variance.name)
varianceEncoder.addAttr('axis', 2)  # NOTE(review): presumably the broadcast axis of Mul; confirm
graph_def.node.extend([varianceEncoder])

# Proposals reshaped to [1, 1, -1] to serve as prior boxes; regressions flattened.
addReshape('detection_out/slice', 'detection_out/slice/reshape', [1, 1, -1], graph_def)
addFlatten('variance_encoded', 'variance_encoded/flatten', graph_def)

# Inputs: encoded box regressions, class scores (background column already
# sliced off) and the proposals as priors.
detectionOut = NodeDef()
detectionOut.name = 'detection_out_final'
detectionOut.op = 'DetectionOutput'

detectionOut.input.append('variance_encoded/flatten')
detectionOut.input.append('SecondStageBoxPredictor/Reshape_1/Reshape')
detectionOut.input.append('detection_out/slice/reshape')

detectionOut.addAttr('num_classes', num_classes)
detectionOut.addAttr('share_location', False)  # presumably per-class boxes
# Set outside the num_classes label range, presumably so no real class is
# skipped as background (the background scores were sliced off above).
detectionOut.addAttr('background_label_id', num_classes + 1)
# NMS IoU, keep_top_k and confidence thresholds are hard-coded, not read from the config.
detectionOut.addAttr('nms_threshold', 0.6)
detectionOut.addAttr('code_type', "CENTER_SIZE")
detectionOut.addAttr('keep_top_k',100)
detectionOut.addAttr('clip', True)
detectionOut.addAttr('variance_encoded_in_target', True)
detectionOut.addAttr('confidence_threshold', 0.3)
detectionOut.addAttr('group_by_classes', False)
graph_def.node.extend([detectionOut])

# Re-attach the remaining tail (the last CropAndResize and everything after
# it) in original graph order. Its MaxPool2D reads from the remaining recorded
# CropAndResize name.
for node in reversed(topNodes):
    graph_def.node.extend([node])

    if node.name.startswith('MaxPool2D'):
        assert(node.op == 'MaxPool')
        assert(len(cropAndResizeNodesNames) == 1)
        node.input = [cropAndResizeNodesNames[0]]

# The last CropAndResize in the graph takes the final detections, inserted as
# input 1 (mirroring how 'detection_out' was wired into the first one).
for i in reversed(range(len(graph_def.node))):
    if graph_def.node[i].op == 'CropAndResize':
        graph_def.node[i].input.insert(1, 'detection_out_final')
        break

# The graph's last node becomes the 'detection_masks' output. Its op is
# replaced with Sigmoid and its last input is dropped.
graph_def.node[-1].name = 'detection_masks'
graph_def.node[-1].op = 'Sigmoid'
graph_def.node[-1].input.pop()
274graph_def.node[-1].input.pop()
275
def getUnconnectedNodes(graph=None):
    """Return the names of nodes whose output no node consumes.

    Args:
        graph: object with a ``node`` sequence whose items have ``name`` and
            ``input`` attributes. Defaults to the module-level ``graph_def``.

    Returns:
        List of node names, in graph order, that never appear verbatim in any
        node's ``input`` list. Inputs are compared as-is, so a ':N' suffix or
        '^' prefix is not stripped.
    """
    if graph is None:
        graph = graph_def
    # Collect every referenced input once so each membership test is O(1).
    # Calling list.remove() inside the nested loop was quadratic.
    consumed = {inp for node in graph.node for inp in node.input}
    return [node.name for node in graph.node if node.name not in consumed]
283
# Repeatedly prune dangling nodes (outputs nobody consumes) until only the
# final output node is left without consumers. Removing a node can leave its
# producers dangling, hence the loop.
while True:
    unconnectedNodes = set(getUnconnectedNodes())
    # The final node is the graph's output and must stay. Use discard rather
    # than remove so a consumed output node cannot raise ValueError.
    unconnectedNodes.discard(graph_def.node[-1].name)
    if not unconnectedNodes:
        break

    # One reverse pass per round. Deleting from the end keeps the indices
    # still to be visited valid, and avoids a linear name search per node.
    for i in reversed(range(len(graph_def.node))):
        if graph_def.node[i].name in unconnectedNodes:
            del graph_def.node[i]

# Save as text.
graph_def.save(args.output)
298