# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import os
import ctypes
import mxnet as mx
from mxnet.base import SymbolHandle, check_call, _LIB, mx_uint, c_str_array, c_str, mx_real_t
from mxnet.symbol import Symbol
import numpy as np
from mxnet.test_utils import assert_almost_equal, environment
from mxnet import gluon
from mxnet.gluon import nn
from mxnet import nd

def network_structure_1():
    data1 = mx.sym.var('data1', shape=(2, 3, 10, 10))
    data2 = mx.sym.var('data2')
    conv1 = mx.sym.Convolution(data=data1, weight=data2, no_bias=True, kernel=(2, 2), num_filter=1)
    conv2 = mx.sym.Convolution(data=data2, no_bias=True, kernel=(1, 1), num_filter=1)
    out = mx.sym.Group([conv1, conv2])
    return (out, ['data1'], [(2, 3, 10, 10)])

def network_structure_2():
    # this tests whether the partitioning algorithm can deal with cycles
    data = mx.sym.var('data', shape=(2, 3, 10, 10))
    ret = mx.sym.exp(data)
    ret1 = mx.sym.cos(ret)
    ret2 = mx.sym.sin(ret)
    ret = ret1 + ret2
    return (ret, ['data'], [(2, 3, 10, 10)])

def network_structure_3():
    # this tests whether the partitioned sym can distinguish in_args and aux_states
    data = mx.sym.var('data', shape=(2, 3, 10, 10))
    ret = mx.sym.exp(data)
    ret1 = mx.sym.cos(ret)
    ret2 = mx.sym.sin(ret)
    ret = ret1 + ret2
    ret = mx.sym.BatchNorm(ret)
    ret = mx.sym.BatchNorm(ret)
    # Return the sym, the names of 'data' plus the auxiliary states, and their shapes
    return (ret, ['data'] + ret.list_auxiliary_states(), [(2, 3, 10, 10), (3,), (3,), (3,), (3,)])

def network_structure_4():
    # the last op has multiple duplicate outputs
    data = mx.sym.var('data', shape=(2, 3, 10, 10))
    ret = mx.sym.exp(data)
    ret = mx.sym.Group([ret, ret, ret])
    return (ret, ['data'], [(2, 3, 10, 10)])

def network_structure_5():
    # the subgraph has two duplicate input entries
    data = mx.sym.var('data', shape=(2, 3, 10, 10))
    ret = data + data
    return (ret, ['data'], [(2, 3, 10, 10)])

def network_structure_6():
    data1 = mx.sym.Variable('data1', shape=(3, 3, 10, 10), dtype=np.float32)
    data2 = mx.sym.Variable('data2', shape=(1, 0, 2, 2))
    data3 = mx.sym.sin(data2)
    conv = mx.sym.Convolution(data=data1, weight=data3, kernel=(2, 2), num_filter=1)
    return (conv, ['data1'], [(3, 3, 10, 10)])

def network_structure_7():
    # in this graph, the subgraph node and the other two external nodes form a cycle
    data = mx.sym.Variable('data', shape=(1,))
    ret1 = mx.sym.sin(data)
    ret2 = mx.sym.cos(ret1)
    for _ in range(5):
        ret2 = mx.sym.cos(ret2)
    ret = ret1 + ret2
    return (ret, ['data'], [(1,)])

def network_structure_8():
    # in this graph, two nodes in the subgraph consume the same input, and
    # two nodes outside the subgraph consume a single output from the subgraph
    data = mx.sym.Variable('data', shape=(1,))
    sin1 = mx.sym.sin(data)
    sin2 = mx.sym.sin(data)
    plus = sin1 + sin2
    ret1 = mx.sym.cos(plus)
    ret2 = mx.sym.cos(plus)
    ret = ret1 - ret2
    return (ret, ['data'], [(1,)])

def get_graphs():
    return [
            (network_structure_1(), ['Convolution']),
            (network_structure_2(), ['exp', 'sin', '_Plus', 'elemwise_add', '_plus']),
            (network_structure_2(), ['exp', 'cos', '_Plus', 'elemwise_add', '_plus']),
            (network_structure_3(), ['exp', 'sin', '_Plus', 'elemwise_add', '_plus']),
            (network_structure_3(), ['exp', 'cos', '_Plus', 'elemwise_add', '_plus']),
            (network_structure_3(), ['exp', 'sin', '_Plus', 'elemwise_add', '_plus', 'BatchNorm']),
            (network_structure_3(), ['exp', 'cos', '_Plus', 'elemwise_add', '_plus', 'BatchNorm']),
            (network_structure_3(), ['exp', 'BatchNorm']),
            (network_structure_3(), ['BatchNorm']),
            (network_structure_4(), ['exp']),
            (network_structure_5(), ['_plus', '_Plus', 'elemwise_add']),
            (network_structure_6(), []),
            (network_structure_6(), [mx.sym.sin.__name__]),
            (network_structure_6(), [mx.sym.Convolution.__name__]),
            (network_structure_6(), [mx.sym.sin.__name__, mx.sym.Convolution.__name__]),
            (network_structure_7(), ['sin', 'elemwise_add', '_plus', '_Plus']),
            (network_structure_8(), ['sin', 'elemwise_add'])
    ]
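
# Each check_subgraph_exe* helper below follows the same pattern: run the original
# symbol, build a partitioned version of it for `subgraph_backend` with partitioning
# restricted to `op_names` (via the MXBuildSubgraphByOpNames /
# MX(Set|Remove)SubgraphPropertyOpNames[V2] C APIs, the MXNET_SUBGRAPH_BACKEND
# environment variable, Symbol.optimize_for, or HybridBlock.hybridize), run the
# partitioned version on identical inputs, and assert that the outputs match.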
def check_subgraph_exe1(sym, subgraph_backend, op_names):
    """Use the partitioned sym to simple_bind an executor and compare the outputs
    with those of the original executor"""
    out = SymbolHandle()
    check_call(_LIB.MXBuildSubgraphByOpNames(sym.handle, c_str(subgraph_backend), mx_uint(len(op_names)),
                                             c_str_array(op_names), ctypes.byref(out)))

    partitioned_sym = Symbol(out)
    assert partitioned_sym.list_inputs() == sym.list_inputs()
    assert partitioned_sym.list_arguments() == sym.list_arguments()
    assert partitioned_sym.list_auxiliary_states() == sym.list_auxiliary_states()
    exe = sym.simple_bind(ctx=mx.current_context(), grad_req='null')
    partitioned_exe = partitioned_sym.simple_bind(ctx=mx.current_context(), grad_req='null')
    input_names = sym.list_inputs()
    for name in input_names:
        if name in exe.arg_dict:
            exe.arg_dict[name][:] = mx.nd.random.uniform(shape=exe.arg_dict[name].shape)
            partitioned_exe.arg_dict[name][:] = exe.arg_dict[name]
        else:
            assert name in exe.aux_dict
            exe.aux_dict[name][:] = mx.nd.random.uniform(shape=exe.aux_dict[name].shape)
            partitioned_exe.aux_dict[name][:] = exe.aux_dict[name]
    exe.forward()
    partitioned_exe.forward()
    assert len(exe.outputs) == len(partitioned_exe.outputs)
    for i in range(len(exe.outputs)):
        assert_almost_equal((exe.outputs[i] - partitioned_exe.outputs[i]).abs().sum().asnumpy(),
                            np.zeros(shape=(1,)))

def check_subgraph_exe2(sym, subgraph_backend, op_names):
    """Use env var MXNET_SUBGRAPH_BACKEND=default to trigger graph partitioning in simple_bind
    and compare results of the partitioned sym and the original sym."""
    def get_executor(sym, subgraph_backend=None, op_names=None, original_exec=None):
        if subgraph_backend is not None:
            with environment('MXNET_SUBGRAPH_BACKEND', subgraph_backend):
                check_call(_LIB.MXSetSubgraphPropertyOpNames(c_str(subgraph_backend), mx_uint(len(op_names)),
                                                             c_str_array(op_names)))
        exe = sym.simple_bind(ctx=mx.current_context(), grad_req='null')
        input_names = sym.list_inputs()
        for name in input_names:
            if name in exe.arg_dict:
                exe.arg_dict[name][:] = mx.nd.random.uniform(shape=exe.arg_dict[name].shape)\
                    if original_exec is None else original_exec.arg_dict[name]
            else:
                assert name in exe.aux_dict
                exe.aux_dict[name][:] = mx.nd.random.uniform(shape=exe.aux_dict[name].shape)\
                    if original_exec is None else original_exec.aux_dict[name]
        exe.forward()
        return exe
    original_exec = get_executor(sym)
    with environment('MXNET_SUBGRAPH_BACKEND', subgraph_backend):
        check_call(_LIB.MXSetSubgraphPropertyOpNames(c_str(subgraph_backend), mx_uint(len(op_names)),
                                                     c_str_array(op_names)))
        partitioned_exec = get_executor(sym, subgraph_backend, op_names, original_exec)
        check_call(_LIB.MXRemoveSubgraphPropertyOpNames(c_str(subgraph_backend)))
    outputs1 = original_exec.outputs
    outputs2 = partitioned_exec.outputs
    assert len(outputs1) == len(outputs2)
    for i in range(len(outputs1)):
        assert_almost_equal((outputs1[i] - outputs2[i]).abs().sum().asnumpy(), np.zeros(shape=(1,)))

def check_subgraph_exe3(sym, subgraph_backend, op_names):
    """Use the partitioned sym to bind an executor and compare the outputs
    with those of the original executor"""
    out = SymbolHandle()
    check_call(_LIB.MXBuildSubgraphByOpNames(sym.handle, c_str(subgraph_backend), mx_uint(len(op_names)),
                                             c_str_array(op_names), ctypes.byref(out)))

    partitioned_sym = Symbol(out)
    input_names = sym.list_inputs()
    arg_names = sym.list_arguments()
    aux_names = sym.list_auxiliary_states()
    assert partitioned_sym.list_inputs() == input_names
    assert partitioned_sym.list_arguments() == arg_names
    assert partitioned_sym.list_auxiliary_states() == aux_names
    arg_shapes, _, aux_shapes = sym.infer_shape()
    arg_array = [mx.nd.random.uniform(shape=shape) for shape in arg_shapes]
    aux_array = [mx.nd.random.uniform(shape=shape) for shape in aux_shapes]
    exe = sym.bind(ctx=mx.current_context(), args=arg_array, aux_states=aux_array, grad_req='null')
    partitioned_exe = partitioned_sym.bind(ctx=mx.current_context(), args=arg_array,
                                           aux_states=aux_array, grad_req='null')
    exe.forward()
    partitioned_exe.forward()
    assert len(exe.outputs) == len(partitioned_exe.outputs)
    for i in range(len(exe.outputs)):
        assert_almost_equal((exe.outputs[i] - partitioned_exe.outputs[i]).abs().sum().asnumpy(),
                            np.zeros(shape=(1,)))

def check_subgraph_exe4(sym, subgraph_backend, op_names):
    """Use env var MXNET_SUBGRAPH_BACKEND=default to trigger graph partitioning in bind
    and compare results of the partitioned sym and the original sym."""
    def get_executor(sym, subgraph_backend=None, op_names=None, original_exec=None):
        arg_shapes, _, aux_shapes = sym.infer_shape()
        if subgraph_backend is None:
            arg_array = [mx.nd.random.uniform(shape=shape) for shape in arg_shapes]
            aux_array = [mx.nd.random.uniform(shape=shape) for shape in aux_shapes]
        else:
            arg_array = None
            aux_array = None
        exe = sym.bind(ctx=mx.current_context(),
                       args=arg_array if subgraph_backend is None else original_exec.arg_arrays,
                       aux_states=aux_array if subgraph_backend is None else original_exec.aux_arrays,
                       grad_req='null')
        exe.forward()
        return exe

    original_exec = get_executor(sym)
    with environment('MXNET_SUBGRAPH_BACKEND', subgraph_backend):
        check_call(_LIB.MXSetSubgraphPropertyOpNames(c_str(subgraph_backend), mx_uint(len(op_names)),
                                                     c_str_array(op_names)))
        partitioned_exec = get_executor(sym, subgraph_backend, op_names, original_exec)
        check_call(_LIB.MXRemoveSubgraphPropertyOpNames(c_str(subgraph_backend)))
    outputs1 = original_exec.outputs
    outputs2 = partitioned_exec.outputs
    assert len(outputs1) == len(outputs2)
    for i in range(len(outputs1)):
        assert_almost_equal((outputs1[i] - outputs2[i]).abs().sum().asnumpy(), np.zeros(shape=(1,)))

def set_random_inputs(exe1, input_names):
    """Sets random values to exe1's args and auxs"""
    for name in input_names:
        if name in exe1.arg_dict:
            exe1.arg_dict[name][:] = mx.nd.random.uniform(shape=exe1.arg_dict[name].shape)
        else:
            assert name in exe1.aux_dict
            exe1.aux_dict[name][:] = mx.nd.random.uniform(shape=exe1.aux_dict[name].shape)

def copy_inputs_between_executors(exe1, exe2, input_names):
    """Copies values of args and auxs from exe1 to exe2"""
    for name in input_names:
        if name in exe2.arg_dict:
            exe2.arg_dict[name][:] = exe1.arg_dict[name]
        else:
            assert name in exe2.aux_dict
            exe2.aux_dict[name][:] = exe1.aux_dict[name]

def check_subgraph_exe5(sym, subgraph_backend, op_names):
    """Call optimize_for to trigger graph partitioning without inferring shapes/types first,
    then simple_bind and compare results of the partitioned sym and the original sym."""
    # simple_bind
    exe1 = sym.simple_bind(ctx=mx.current_context(), grad_req='null')
    input_names = sym.list_inputs()
    set_random_inputs(exe1, input_names)
    exe1.forward()

    # partition before simple_bind
    check_call(_LIB.MXSetSubgraphPropertyOpNamesV2(c_str(subgraph_backend), mx_uint(len(op_names)),
                                                   c_str_array(op_names)))
    part_sym = sym.optimize_for(subgraph_backend)
    check_call(_LIB.MXRemoveSubgraphPropertyOpNamesV2(c_str(subgraph_backend)))

    exe2 = part_sym.simple_bind(ctx=mx.current_context(), grad_req='null')
    copy_inputs_between_executors(exe1, exe2, input_names)
    exe2.forward()

    # compare outputs
    outputs1 = exe1.outputs
    outputs2 = exe2.outputs
    assert len(outputs1) == len(outputs2)
    for i in range(len(outputs1)):
        assert_almost_equal((outputs1[i] - outputs2[i]).abs().sum().asnumpy(), np.zeros(shape=(1,)))

def check_subgraph_exe6(sym, subgraph_backend, op_names):
    """Call optimize_for to trigger graph partitioning with shapes/types, then simple_bind
    and compare results of the partitioned sym and the original sym."""
    # simple_bind
    exe1 = sym.simple_bind(ctx=mx.current_context(), grad_req='null')
    input_names = sym.list_inputs()
    set_random_inputs(exe1, input_names)
    exe1.forward()

    # infer shape/type before partition before simple_bind
    check_call(_LIB.MXSetSubgraphPropertyOpNamesV2(c_str(subgraph_backend), mx_uint(len(op_names)),
                                                   c_str_array(op_names)))
    part_sym = sym.optimize_for(subgraph_backend, exe1.arg_dict, exe1.aux_dict)
    check_call(_LIB.MXRemoveSubgraphPropertyOpNamesV2(c_str(subgraph_backend)))

    exe2 = part_sym.simple_bind(ctx=mx.current_context(), grad_req='null')
    copy_inputs_between_executors(exe1, exe2, input_names)
    exe2.forward()

    # compare outputs
    outputs1 = exe1.outputs
    outputs2 = exe2.outputs
    assert len(outputs1) == len(outputs2)
    for i in range(len(outputs1)):
        assert_almost_equal((outputs1[i] - outputs2[i]).abs().sum().asnumpy(), np.zeros(shape=(1,)))
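
# check_subgraph_exe7 and check_subgraph_exe8 repeat the two optimize_for checks above,
# but create the executors with bind() on explicitly constructed arg/aux arrays instead
# of simple_bind().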
def check_subgraph_exe7(sym, subgraph_backend, op_names):
    """Call optimize_for to trigger graph partitioning without inferring shapes/types first,
    then bind and compare results of the partitioned sym and the original sym."""
    # bind
    arg_shapes, _, aux_shapes = sym.infer_shape()
    arg_array = [mx.nd.random.uniform(shape=shape) for shape in arg_shapes]
    aux_array = [mx.nd.random.uniform(shape=shape) for shape in aux_shapes]
    exe1 = sym.bind(ctx=mx.current_context(), args=arg_array, aux_states=aux_array, grad_req='null')
    exe1.forward()

    # partition before bind
    check_call(_LIB.MXSetSubgraphPropertyOpNamesV2(c_str(subgraph_backend), mx_uint(len(op_names)),
                                                   c_str_array(op_names)))
    part_sym = sym.optimize_for(subgraph_backend)
    check_call(_LIB.MXRemoveSubgraphPropertyOpNamesV2(c_str(subgraph_backend)))

    exe2 = part_sym.bind(ctx=mx.current_context(), args=arg_array, aux_states=aux_array, grad_req='null')
    exe2.forward()

    # compare outputs
    outputs1 = exe1.outputs
    outputs2 = exe2.outputs
    assert len(outputs1) == len(outputs2)
    for i in range(len(outputs1)):
        assert_almost_equal((outputs1[i] - outputs2[i]).abs().sum().asnumpy(), np.zeros(shape=(1,)))

def check_subgraph_exe8(sym, subgraph_backend, op_names):
    """Call optimize_for to infer shapes, types and dtypes followed by graph partitioning,
    then bind and compare results of the partitioned sym and the original sym."""
    # bind
    arg_shapes, _, aux_shapes = sym.infer_shape()
    arg_names = sym.list_arguments()
    aux_names = sym.list_auxiliary_states()
    arg_dict = {name: mx.nd.random.uniform(shape=shape) for name, shape in zip(arg_names, arg_shapes)}
    aux_dict = {name: mx.nd.random.uniform(shape=shape) for name, shape in zip(aux_names, aux_shapes)}
    exe1 = sym.bind(ctx=mx.current_context(), args=arg_dict, aux_states=aux_dict, grad_req='null')
    exe1.forward()

    # infer shape/type before partition before bind
    check_call(_LIB.MXSetSubgraphPropertyOpNamesV2(c_str(subgraph_backend), mx_uint(len(op_names)),
                                                   c_str_array(op_names)))
    part_sym = sym.optimize_for(subgraph_backend, arg_dict, aux_dict)
    check_call(_LIB.MXRemoveSubgraphPropertyOpNamesV2(c_str(subgraph_backend)))

    exe2 = part_sym.bind(ctx=mx.current_context(), args=arg_dict, aux_states=aux_dict, grad_req='null')
    exe2.forward()

    # compare outputs
    outputs1 = exe1.outputs
    outputs2 = exe2.outputs
    assert len(outputs1) == len(outputs2)
    for i in range(len(outputs1)):
        assert_almost_equal((outputs1[i] - outputs2[i]).abs().sum().asnumpy(), np.zeros(shape=(1,)))
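
# check_subgraph_exe9 runs the same comparison through the Gluon path: the symbol is
# wrapped in a SymbolBlock, exported and re-imported, and then hybridized with the
# subgraph backend.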
def check_subgraph_exe9(sym, subgraph_backend, op_names):
    """Call hybridize() to partition the graph, and then compare results of the partitioned
    sym and the original sym. Here we run an inference before hybridizing with the
    subgraph_backend, which means we will pass shapes/types."""
    # create Gluon block for given symbol
    inputs = [mx.sym.var(i, dtype=mx_real_t) for i in sym[1]]
    sym_block = nn.SymbolBlock(sym[0], inputs)
    sym_block.initialize(ctx=mx.current_context())
    x = [mx.nd.random.uniform(shape=s, ctx=mx.current_context()) for s in sym[2]]
    # hybridize and export to get baseline
    sym_block.hybridize()
    outputs1 = sym_block(*x)
    sym_block.export('check_subgraph_exe9')

    # load model and partition
    sym_block = nn.SymbolBlock.imports('check_subgraph_exe9-symbol.json', sym[1],
                                       'check_subgraph_exe9-0000.params', ctx=mx.current_context())
    check_call(_LIB.MXSetSubgraphPropertyOpNamesV2(c_str(subgraph_backend), mx_uint(len(op_names)),
                                                   c_str_array(op_names)))
    sym_block.hybridize(backend=subgraph_backend)
    outputs2 = sym_block(*x)
    check_call(_LIB.MXRemoveSubgraphPropertyOpNamesV2(c_str(subgraph_backend)))

    # compare outputs
    assert len(outputs1) == len(outputs2)
    for i in range(len(outputs1)):
        assert_almost_equal((outputs1[i] - outputs2[i]).abs().sum().asnumpy(), np.zeros(shape=(1,)))

def check_subgraph_exe10(sym, subgraph_backend, op_names):
    """Call optimize_for to infer shapes, types and dtypes followed by graph partitioning and
    dedup subgraph, then bind and compare results of the partitioned sym and the original sym."""
    # bind
    arg_shapes, _, aux_shapes = sym.infer_shape()
    arg_names = sym.list_arguments()
    aux_names = sym.list_auxiliary_states()
    arg_dict = {name: mx.nd.random.uniform(shape=shape) for name, shape in zip(arg_names, arg_shapes)}
    aux_dict = {name: mx.nd.random.uniform(shape=shape) for name, shape in zip(aux_names, aux_shapes)}
    exe1 = sym.bind(ctx=mx.current_context(), args=arg_dict, aux_states=aux_dict, grad_req='null')
    exe1.forward()

    # infer shape/type before partition before bind
    check_call(_LIB.MXSetSubgraphPropertyOpNamesV2(c_str(subgraph_backend), mx_uint(len(op_names)),
                                                   c_str_array(op_names)))
    print(sym.tojson())
    part_sym = sym.optimize_for(subgraph_backend, arg_dict, aux_dict, dedup_subgraph=True)
    print(part_sym.tojson())
    check_call(_LIB.MXRemoveSubgraphPropertyOpNamesV2(c_str(subgraph_backend)))

    exe2 = part_sym.bind(ctx=mx.current_context(), args=arg_dict, aux_states=aux_dict, grad_req='null')
    exe2.forward()

    # compare outputs
    outputs1 = exe1.outputs
    outputs2 = exe2.outputs
    assert len(outputs1) == len(outputs2)
    for i in range(len(outputs1)):
        assert_almost_equal((outputs1[i] - outputs2[i]).abs().sum().asnumpy(), np.zeros(shape=(1,)))

def check_subgraph(subgraph_backend):
    for sym, op_names in get_graphs():
        check_subgraph_exe1(sym[0], subgraph_backend, op_names)
        check_subgraph_exe2(sym[0], subgraph_backend, op_names)
        check_subgraph_exe3(sym[0], subgraph_backend, op_names)
        check_subgraph_exe4(sym[0], subgraph_backend, op_names)

def check_subgraph_backend_sym(subgraph_backend):
    for sym, op_names in get_graphs():
        check_subgraph_exe5(sym[0], subgraph_backend, op_names)
        check_subgraph_exe6(sym[0], subgraph_backend, op_names)
        check_subgraph_exe7(sym[0], subgraph_backend, op_names)
        check_subgraph_exe8(sym[0], subgraph_backend, op_names)
        check_subgraph_exe10(sym[0], subgraph_backend, op_names)

def check_subgraph_backend_gluon(subgraph_backend):
    for sym, op_names in get_graphs():
        check_subgraph_exe9(sym, subgraph_backend, op_names)

# Test graph partition for 'default' backend.
def test_subgraph():
    check_subgraph('default')

# Test graph partition for 'default_v2' backend.
def test_subgraph_v2():
    check_subgraph('default_v2')

# Test enhanced Python and C APIs for graph partitioning given 'default' backend.
def test_subgraph_backend_sym():
    check_subgraph_backend_sym('default')

# Test enhanced Python and C APIs for graph partitioning given 'default_v2' backend.
def test_subgraph_backend_sym_v2():
    check_subgraph_backend_sym('default_v2')

# Test Gluon HybridBlocks for graph partitioning given 'default' backend.
def test_subgraph_backend_gluon():
    check_subgraph_backend_gluon('default')

# Test Gluon HybridBlocks for graph partitioning given 'default_v2' backend.
def test_subgraph_backend_gluon_v2():
    check_subgraph_backend_gluon('default_v2')

# Test Gluon HybridBlocks for graph partitioning a network created by HybridSequential.
def test_subgraph_backend_gluon_ext1():
    def get_net():
        net = nn.HybridSequential()  # Here we use the class HybridSequential.
        net.add(nn.Dense(256, activation='relu'),
                nn.Dense(128, activation='relu'),
                nn.Dense(2))
        return net

    # regular inference
    x = nd.random.normal(shape=(1, 512), ctx=mx.current_context())
    net = get_net()
    net.collect_params().initialize(ctx=mx.current_context())
    outputs1 = net(x)
    net.save_parameters('test_subgraph_backend_gluon_ext1.params')

    # after partitioning
    net = get_net()
    net.load_parameters('test_subgraph_backend_gluon_ext1.params', ctx=mx.current_context())
    subgraph_backend = 'default'
    op_names = ['FullyConnected']
    check_call(_LIB.MXSetSubgraphPropertyOpNamesV2(c_str(subgraph_backend), mx_uint(len(op_names)),
                                                   c_str_array(op_names)))
    net.hybridize(backend=subgraph_backend)
    outputs2 = net(x)
    check_call(_LIB.MXRemoveSubgraphPropertyOpNamesV2(c_str(subgraph_backend)))

    # compare outputs
    assert len(outputs1) == len(outputs2)
    for i in range(len(outputs1)):
        assert_almost_equal((outputs1[i] - outputs2[i]).abs().sum().asnumpy(), np.zeros(shape=(1,)))

# Test Gluon HybridBlocks for graph partitioning a network created by HybridBlock.
def test_subgraph_backend_gluon_ext2():
    class Net(gluon.HybridBlock):
        def __init__(self, **kwargs):
            super(Net, self).__init__(**kwargs)
            with self.name_scope():
                self.fc1 = nn.Dense(256)
                self.fc2 = nn.Dense(128)
                self.fc3 = nn.Dense(2)

        def hybrid_forward(self, F, x):
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            return self.fc3(x)

    # regular inference
    x = nd.random.normal(shape=(1, 512), ctx=mx.current_context())
    net = Net()
    net.collect_params().initialize(ctx=mx.current_context())
    outputs1 = net(x)
    net.save_parameters('test_subgraph_backend_gluon_ext2.params')

    # after partitioning
    net = Net()
    net.load_parameters('test_subgraph_backend_gluon_ext2.params', ctx=mx.current_context())
    subgraph_backend = 'default'
    op_names = ['FullyConnected']
    check_call(_LIB.MXSetSubgraphPropertyOpNamesV2(c_str(subgraph_backend), mx_uint(len(op_names)),
                                                   c_str_array(op_names)))
    net.hybridize(backend=subgraph_backend)
    outputs2 = net(x)
    check_call(_LIB.MXRemoveSubgraphPropertyOpNamesV2(c_str(subgraph_backend)))

    # compare outputs
    assert len(outputs1) == len(outputs2)
    for i in range(len(outputs1)):
        assert_almost_equal((outputs1[i] - outputs2[i]).abs().sum().asnumpy(), np.zeros(shape=(1,)))

if __name__ == '__main__':
    import nose
    nose.runmodule()