import argparse
import numpy as np
from tf_text_graph_common import *

# Command-line interface. All three options are mandatory.
parser = argparse.ArgumentParser(
    description='Run this script to get a text graph of '
                'Mask-RCNN model from TensorFlow Object Detection API. '
                'Then pass it with .pb file to cv::dnn::readNetFromTensorflow function.')
for _flag, _helpText in (
        ('--input', 'Path to frozen TensorFlow graph.'),
        ('--output', 'Path to output text graph.'),
        ('--config', 'Path to a *.config file is used for training.')):
    parser.add_argument(_flag, required=True, help=_helpText)
args = parser.parse_args()

# Node-name prefixes matched with str.startswith() by to_remove():
# a node whose name starts with none of these is reported as removable.
scopesToKeep = (
    'FirstStageFeatureExtractor', 'Conv',
    'FirstStageBoxPredictor/BoxEncodingPredictor',
    'FirstStageBoxPredictor/ClassPredictor',
    'CropAndResize',
    'MaxPool2D',
    'SecondStageFeatureExtractor',
    'SecondStageBoxPredictor',
    'Preprocessor/sub',
    'Preprocessor/mul',
    'image_tensor',
)

# Node-name prefixes that to_remove() reports as removable even when the
# name also matches one of the scopesToKeep prefixes.
scopesToIgnore = (
    'FirstStageFeatureExtractor/Assert',
    'FirstStageFeatureExtractor/Shape',
    'FirstStageFeatureExtractor/strided_slice',
    'FirstStageFeatureExtractor/GreaterEqual',
    'FirstStageFeatureExtractor/LogicalAnd',
    'Conv/required_space_to_batch_paddings',
)

# Load a config file.
# Only the faster_rcnn section of the parsed config is used below.
config = readTextMessage(args.config)['model'][0]['faster_rcnn'][0]
num_classes = int(config['num_classes'][0])

# Anchor generator settings of the first stage.
grid_anchor_generator = config['first_stage_anchor_generator'][0]['grid_anchor_generator'][0]
scales = list(map(float, grid_anchor_generator['scales']))
aspect_ratios = list(map(float, grid_anchor_generator['aspect_ratios']))
width_stride = float(grid_anchor_generator['width_stride'][0])
height_stride = float(grid_anchor_generator['height_stride'][0])
features_stride = float(config['feature_extractor'][0]['first_stage_features_stride'][0])
first_stage_nms_iou_threshold = float(config['first_stage_nms_iou_threshold'][0])
first_stage_max_proposals = int(config['first_stage_max_proposals'][0])

# Echo the values that were read so the user can check them.
for _summaryLine in ('Number of classes: %d' % num_classes,
                     'Scales: %s' % str(scales),
                     'Aspect ratios: %s' % str(aspect_ratios),
                     'Width stride: %f' % width_stride,
                     'Height stride: %f' % height_stride,
                     'Features stride: %f' % features_stride):
    print(_summaryLine)

# Read the graph.
writeTextGraph(args.input, args.output,
               ['num_detections', 'detection_scores', 'detection_boxes',
                'detection_classes', 'detection_masks'])
graph_def = parseTextGraph(args.output)

removeIdentity(graph_def)

# Names listed here are never reported as removable by to_remove().
nodesToKeep = []


def to_remove(name, op):
    """Tell whether the node with the given name and op type should be dropped.

    Returns False for every name registered in nodesToKeep. Otherwise returns
    True when the op is 'Const', when the name starts with a scopesToIgnore
    prefix, when it starts with no scopesToKeep prefix, or when the name
    starts with 'CropAndResize' but the op itself is not 'CropAndResize'.
    """
    if name in nodesToKeep:
        return False
    if op == 'Const':
        return True
    if name.startswith(scopesToIgnore):
        return True
    if not name.startswith(scopesToKeep):
        return True
    return name.startswith('CropAndResize') and op != 'CropAndResize'

# Fuse atrous convolutions (with dilations).
# Name -> node lookup built before the graph is modified.
nodesMap = {node.name: node for node in graph_def.node}
# NOTE: graph_def.node is modified (insert) while it is walked in reverse,
# so the statement order inside this loop matters.
for node in reversed(graph_def.node):
    if node.op == 'BatchToSpaceND':
        # Drop the third input of BatchToSpaceND.
        del node.input[2]
        # First input of BatchToSpaceND, and the first input of that node.
        # Presumably the pattern is SpaceToBatchND -> conv -> BatchToSpaceND;
        # the ops of these two nodes are not checked here.
        conv = nodesMap[node.input[0]]
        spaceToBatchND = nodesMap[conv.input[0]]

        # New Const node '<conv name>/paddings' holding [2, 2, 2, 2].
        paddingsNode = NodeDef()
        paddingsNode.name = conv.name + '/paddings'
        paddingsNode.op = 'Const'
        paddingsNode.addAttr('value', [2, 2, 2, 2])
        # Placed right before the node that will reference it.
        graph_def.node.insert(graph_def.node.index(spaceToBatchND), paddingsNode)
        # to_remove() reports every 'Const' op unless its name is listed here.
        nodesToKeep.append(paddingsNode.name)

        # Third input of the SpaceToBatchND node now points to the new constant.
        spaceToBatchND.input[2] = paddingsNode.name

removeUnusedNodesAndAttrs(to_remove, graph_def)


# Connect input node to the first layer
assert(graph_def.node[0].op == 'Placeholder')
graph_def.node[1].input.insert(0, graph_def.node[0].name)

# Temporarily remove top nodes.
# Nodes are popped from the end of the graph until two CropAndResize nodes
# have been taken; topNodes therefore holds them in reverse graph order.
topNodes = []
numCropAndResize = 0
while True:
    node = graph_def.node.pop()
    topNodes.append(node)
    if node.op == 'CropAndResize':
        numCropAndResize += 1
        if numCropAndResize == 2:
            break

# First-stage class scores: reshape -> softmax -> flatten.
# NOTE(review): addReshape/addSoftMax/addFlatten come from
# tf_text_graph_common; they presumably append a node named by the second
# argument that consumes the node named by the first - confirm there.
addReshape('FirstStageBoxPredictor/ClassPredictor/BiasAdd',
           'FirstStageBoxPredictor/ClassPredictor/reshape_1', [0, -1, 2], graph_def)

addSoftMax('FirstStageBoxPredictor/ClassPredictor/reshape_1',
           'FirstStageBoxPredictor/ClassPredictor/softmax', graph_def)  # Compare with Reshape_4

addFlatten('FirstStageBoxPredictor/ClassPredictor/softmax',
           'FirstStageBoxPredictor/ClassPredictor/softmax/flatten', graph_def)

# Compare with FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd
addFlatten('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd',
           'FirstStageBoxPredictor/BoxEncodingPredictor/flatten', graph_def)

# 'PriorBox' node named 'proposals'. Its inputs are the box-encoding
# predictor output and the first graph node (asserted above to be the
# Placeholder).
proposals = NodeDef()
proposals.name = 'proposals'  # Compare with ClipToWindow/Gather/Gather (NOTE: normalized)
proposals.op = 'PriorBox'
proposals.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd')
proposals.input.append(graph_def.node[0].name)  # image_tensor

proposals.addAttr('flip', False)
# Remaining scalar/list attributes of the 'proposals' PriorBox node.
for _key, _value in (('clip', True),
                     ('step', features_stride),
                     ('offset', 0.0),
                     ('variance', [0.1, 0.1, 0.2, 0.2])):
    proposals.addAttr(_key, _value)

# One (scale, sqrt(aspect ratio)) pair per anchor, aspect ratios outermost.
# NOTE(review): the squared stride acts as the base anchor size here;
# confirm this matches the anchor generator settings used for training.
_anchorParams = [(s, np.sqrt(a)) for a in aspect_ratios for s in scales]
widths = [(width_stride**2) * s * ar for s, ar in _anchorParams]
heights = [(height_stride**2) * s / ar for s, ar in _anchorParams]

proposals.addAttr('width', widths)
proposals.addAttr('height', heights)

graph_def.node.extend([proposals])

# First-stage 'DetectionOutput' node fed by the flattened box encodings,
# the flattened class scores and the proposals.
# Compare with Reshape_5
detectionOut = NodeDef()
detectionOut.name = 'detection_out'
detectionOut.op = 'DetectionOutput'

for _inputName in ('FirstStageBoxPredictor/BoxEncodingPredictor/flatten',
                   'FirstStageBoxPredictor/ClassPredictor/softmax/flatten',
                   'proposals'):
    detectionOut.input.append(_inputName)

for _key, _value in (('num_classes', 2),
                     ('share_location', True),
                     ('background_label_id', 0),
                     ('nms_threshold', first_stage_nms_iou_threshold),
                     ('top_k', 6000),
                     ('code_type', "CENTER_SIZE"),
                     ('keep_top_k', first_stage_max_proposals),
                     ('clip', True)):
    detectionOut.addAttr(_key, _value)

graph_def.node.extend([detectionOut])

# Re-append the temporarily removed top nodes, stopping at the second
# CropAndResize node that is met.
# Walk topNodes from its end and put the nodes back into the graph.
# topNodes.pop() drops the element that was just re-appended, because the
# iteration runs from the end of the list.
# The loop stops at the second CropAndResize node it meets: that node's name
# is recorded, but the node itself stays in topNodes.
cropAndResizeNodesNames = []
for node in reversed(topNodes):
    if node.op != 'CropAndResize':
        graph_def.node.extend([node])
        topNodes.pop()
    else:
        cropAndResizeNodesNames.append(node.name)
        if numCropAndResize == 1:
            break
        else:
            graph_def.node.extend([node])
            topNodes.pop()
        numCropAndResize -= 1

# Second-stage class scores: softmax -> slice -> reshape.
# NOTE(review): the slice begins at index 1 on the last axis, presumably to
# skip a background class column - confirm against addSlice semantics.
addSoftMax('SecondStageBoxPredictor/Reshape_1', 'SecondStageBoxPredictor/Reshape_1/softmax', graph_def)

addSlice('SecondStageBoxPredictor/Reshape_1/softmax',
         'SecondStageBoxPredictor/Reshape_1/slice',
         [0, 0, 1], [-1, -1, -1], graph_def)

addReshape('SecondStageBoxPredictor/Reshape_1/slice',
           'SecondStageBoxPredictor/Reshape_1/Reshape', [1, -1], graph_def)

# Replace Flatten subgraph onto a single node.
# Indices are visited from high to low so that deleting node i does not
# shift the indices that are still to be visited.
for i in reversed(range(len(graph_def.node))):
    if graph_def.node[i].op == 'CropAndResize':
        # 'detection_out' becomes the second input of CropAndResize.
        graph_def.node[i].input.insert(1, 'detection_out')

    if graph_def.node[i].name == 'SecondStageBoxPredictor/Reshape':
        addConstNode('SecondStageBoxPredictor/Reshape/shape2', [1, -1, 4], graph_def)

        # Swap the last input of the Reshape node for the new constant.
        graph_def.node[i].input.pop()
        graph_def.node[i].input.append('SecondStageBoxPredictor/Reshape/shape2')

    # Helper nodes of the two Flatten subgraphs are removed from the graph.
    if graph_def.node[i].name in ['SecondStageBoxPredictor/Flatten/flatten/Shape',
                                  'SecondStageBoxPredictor/Flatten/flatten/strided_slice',
                                  'SecondStageBoxPredictor/Flatten/flatten/Reshape/shape',
                                  'SecondStageBoxPredictor/Flatten_1/flatten/Shape',
                                  'SecondStageBoxPredictor/Flatten_1/flatten/strided_slice',
                                  'SecondStageBoxPredictor/Flatten_1/flatten/Reshape/shape']:
        del graph_def.node[i]

for node in graph_def.node:
    # The Reshape at the end of each Flatten subgraph becomes a 'Flatten'
    # op and loses its last input.
    if node.name == 'SecondStageBoxPredictor/Flatten/flatten/Reshape' or \
       node.name == 'SecondStageBoxPredictor/Flatten_1/flatten/Reshape':
        node.op = 'Flatten'
        node.input.pop()

    if node.name in ['FirstStageBoxPredictor/BoxEncodingPredictor/Conv2D',
                     'SecondStageBoxPredictor/BoxEncodingPredictor/MatMul']:
        node.addAttr('loc_pred_transposed', True)

    # A MaxPool2D node is rewired to read from the first recorded
    # CropAndResize node; that name is then consumed from the list.
    if node.name.startswith('MaxPool2D'):
        assert(node.op == 'MaxPool')
        assert(len(cropAndResizeNodesNames) == 2)
        node.input = [cropAndResizeNodesNames[0]]
        del cropAndResizeNodesNames[0]

################################################################################
### Postprocessing
################################################################################
addSlice('detection_out', 'detection_out/slice', [0, 0, 0, 3], [-1, -1, -1, 4], graph_def)

# Const node holding the four variance values.
variance = NodeDef()
variance.name = 'proposals/variance'
variance.op = 'Const'
variance.addAttr('value', [0.1, 0.1, 0.2, 0.2])
graph_def.node.extend([variance])

# 'Mul' node that takes 'SecondStageBoxPredictor/Reshape' and the variance
# constant as inputs, with attribute axis=2.
varianceEncoder = NodeDef()
varianceEncoder.name = 'variance_encoded'
varianceEncoder.op = 'Mul'
varianceEncoder.input.append('SecondStageBoxPredictor/Reshape')
varianceEncoder.input.append(variance.name)
varianceEncoder.addAttr('axis', 2)
graph_def.node.extend([varianceEncoder])

addReshape('detection_out/slice', 'detection_out/slice/reshape', [1, 1, -1], graph_def)
addFlatten('variance_encoded', 'variance_encoded/flatten', graph_def)

# Final 'DetectionOutput' node; the name detectionOut is rebound to it.
# NOTE(review): nms_threshold, keep_top_k and confidence_threshold are fixed
# values here and are not read from the config file.
detectionOut = NodeDef()
detectionOut.name = 'detection_out_final'
detectionOut.op = 'DetectionOutput'

detectionOut.input.append('variance_encoded/flatten')
detectionOut.input.append('SecondStageBoxPredictor/Reshape_1/Reshape')
detectionOut.input.append('detection_out/slice/reshape')

detectionOut.addAttr('num_classes', num_classes)
detectionOut.addAttr('share_location', False)
detectionOut.addAttr('background_label_id', num_classes + 1)
detectionOut.addAttr('nms_threshold', 0.6)
detectionOut.addAttr('code_type', "CENTER_SIZE")
detectionOut.addAttr('keep_top_k',100)
detectionOut.addAttr('clip', True)
detectionOut.addAttr('variance_encoded_in_target', True)
detectionOut.addAttr('confidence_threshold', 0.3)
detectionOut.addAttr('group_by_classes', False)
graph_def.node.extend([detectionOut])

# Put back the nodes still held in topNodes, last element first.
for heldBack in reversed(topNodes):
    graph_def.node.extend([heldBack])

    if not heldBack.name.startswith('MaxPool2D'):
        continue
    # A MaxPool2D node is rewired to the one CropAndResize name still recorded.
    assert(heldBack.op == 'MaxPool')
    assert(len(cropAndResizeNodesNames) == 1)
    heldBack.input = [cropAndResizeNodesNames[0]]

# The CropAndResize node closest to the end of the graph gets
# 'detection_out_final' as its second input.
for candidate in reversed(graph_def.node):
    if candidate.op == 'CropAndResize':
        candidate.input.insert(1, 'detection_out_final')
        break

# The last node becomes the 'detection_masks' Sigmoid and loses its last input.
lastNode = graph_def.node[-1]
lastNode.name = 'detection_masks'
lastNode.op = 'Sigmoid'
lastNode.input.pop()


def getUnconnectedNodes():
    """Return the names of the graph nodes that no node lists as an input.

    Starts from all node names in graph order and, for every input entry of
    every node, removes the first matching name that is still in the list.
    """
    remaining = [entry.name for entry in graph_def.node]
    for consumer in graph_def.node:
        for source in consumer.input:
            if source in remaining:
                remaining.remove(source)
    return remaining


# Delete unconsumed nodes until only the last node is unconsumed. Each
# deletion can leave further nodes unconsumed, hence the repeated passes.
while True:
    unconnectedNodes = getUnconnectedNodes()
    unconnectedNodes.remove(graph_def.node[-1].name)
    if not unconnectedNodes:
        break

    for name in unconnectedNodes:
        for i, entry in enumerate(graph_def.node):
            if entry.name == name:
                del graph_def.node[i]
                break

# Save as text.
graph_def.save(args.output)