# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

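"""Tests for graph partitioning with subgraph backends: each network_structure_*
symbol exercises a different partitioning corner case, and each check_subgraph_exe*
helper partitions it with a given backend and op-name list, then verifies that the
partitioned executor produces the same outputs as the original one."""
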
import ctypes
import mxnet as mx
from mxnet.base import SymbolHandle, check_call, _LIB, mx_uint, c_str_array, c_str, mx_real_t
from mxnet.symbol import Symbol
import numpy as np
from mxnet.test_utils import assert_almost_equal, environment
from mxnet import gluon
from mxnet.gluon import nn
from mxnet import nd

def network_structure_1():
    data1 = mx.sym.var('data1', shape=(2, 3, 10, 10))
    data2 = mx.sym.var('data2')
    conv1 = mx.sym.Convolution(data=data1, weight=data2, no_bias=True, kernel=(2, 2), num_filter=1)
    conv2 = mx.sym.Convolution(data=data2, no_bias=True, kernel=(1, 1), num_filter=1)
    out = mx.sym.Group([conv1, conv2])
    return (out, ['data1'], [(2, 3, 10, 10)])

def network_structure_2():
    # this tests whether the partitioning algorithm can deal with cycles
    data = mx.sym.var('data', shape=(2, 3, 10, 10))
    ret = mx.sym.exp(data)
    ret1 = mx.sym.cos(ret)
    ret2 = mx.sym.sin(ret)
    ret = ret1 + ret2
    return (ret, ['data'], [(2, 3, 10, 10)])

def network_structure_3():
    # this tests whether the partitioned sym can distinguish in_args and aux_states
    data = mx.sym.var('data', shape=(2, 3, 10, 10))
    ret = mx.sym.exp(data)
    ret1 = mx.sym.cos(ret)
    ret2 = mx.sym.sin(ret)
    ret = ret1 + ret2
    ret = mx.sym.BatchNorm(ret)
    ret = mx.sym.BatchNorm(ret)
    # return the names and shapes of 'data' and the auxiliary states
    return (ret, ['data'] + ret.list_auxiliary_states(), [(2, 3, 10, 10), (3,), (3,), (3,), (3,)])

def network_structure_4():
    # the last op has multiple duplicate outputs
    data = mx.sym.var('data', shape=(2, 3, 10, 10))
    ret = mx.sym.exp(data)
    ret = mx.sym.Group([ret, ret, ret])
    return (ret, ['data'], [(2, 3, 10, 10)])

def network_structure_5():
    # the subgraph has two duplicate input entries
    data = mx.sym.var('data', shape=(2, 3, 10, 10))
    ret = data + data
    return (ret, ['data'], [(2, 3, 10, 10)])

def network_structure_6():
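    # data2's second dim is left unknown (0) so that shape inference has to
    # complete it; sin(data2) then feeds the convolution as its weight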
    data1 = mx.sym.Variable('data1', shape=(3, 3, 10, 10), dtype=np.float32)
    data2 = mx.sym.Variable('data2', shape=(1, 0, 2, 2))
    data3 = mx.sym.sin(data2)
    conv = mx.sym.Convolution(data=data1, weight=data3, kernel=(2, 2), num_filter=1)
    return (conv, ['data1'], [(3, 3, 10, 10)])

def network_structure_7():
    # in this graph, the subgraph node and the other two external nodes form a cycle
    data = mx.sym.Variable('data', shape=(1,))
    ret1 = mx.sym.sin(data)
    ret2 = mx.sym.cos(ret1)
    for _ in range(5):
        ret2 = mx.sym.cos(ret2)
    ret = ret1 + ret2
    return (ret, ['data'], [(1,)])

def network_structure_8():
    # in this graph, two nodes in the subgraph consume the same input, and
    # two nodes outside the subgraph consume a single output from the subgraph
    data = mx.sym.Variable('data', shape=(1,))
    sin1 = mx.sym.sin(data)
    sin2 = mx.sym.sin(data)
    plus = sin1 + sin2
    ret1 = mx.sym.cos(plus)
    ret2 = mx.sym.cos(plus)
    ret = ret1 - ret2
    return (ret, ['data'], [(1,)])

def get_graphs():
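    # each entry pairs a network tuple (sym, input names, input shapes) with the
    # list of op types that the subgraph property is allowed to group together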
    return [
            (network_structure_1(), ['Convolution']),
            (network_structure_2(), ['exp', 'sin', '_Plus', 'elemwise_add', '_plus']),
            (network_structure_2(), ['exp', 'cos', '_Plus', 'elemwise_add', '_plus']),
            (network_structure_3(), ['exp', 'sin', '_Plus', 'elemwise_add', '_plus']),
            (network_structure_3(), ['exp', 'cos', '_Plus', 'elemwise_add', '_plus']),
            (network_structure_3(), ['exp', 'sin', '_Plus', 'elemwise_add', '_plus', 'BatchNorm']),
            (network_structure_3(), ['exp', 'cos', '_Plus', 'elemwise_add', '_plus', 'BatchNorm']),
            (network_structure_3(), ['exp', 'BatchNorm']),
            (network_structure_3(), ['BatchNorm']),
            (network_structure_4(), ['exp']),
            (network_structure_5(), ['_plus', '_Plus', 'elemwise_add']),
            (network_structure_6(), []),
            (network_structure_6(), [mx.sym.sin.__name__]),
            (network_structure_6(), [mx.sym.Convolution.__name__]),
            (network_structure_6(), [mx.sym.sin.__name__, mx.sym.Convolution.__name__]),
            (network_structure_7(), ['sin', 'elemwise_add', '_plus', '_Plus']),
            (network_structure_8(), ['sin', 'elemwise_add'])
            ]

def check_subgraph_exe1(sym, subgraph_backend, op_names):
    """Use the partitioned sym to simple_bind an executor and compare the outputs
    with those of the original executor."""
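    # partition via the C API: ops whose type names appear in op_names are
    # grouped into subgraph nodes owned by the given backend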
    out = SymbolHandle()
    check_call(_LIB.MXBuildSubgraphByOpNames(sym.handle, c_str(subgraph_backend), mx_uint(len(op_names)),
                                             c_str_array(op_names), ctypes.byref(out)))

    partitioned_sym = Symbol(out)
    assert partitioned_sym.list_inputs() == sym.list_inputs()
    assert partitioned_sym.list_arguments() == sym.list_arguments()
    assert partitioned_sym.list_auxiliary_states() == sym.list_auxiliary_states()
    exe = sym.simple_bind(ctx=mx.current_context(), grad_req='null')
    partitioned_exe = partitioned_sym.simple_bind(ctx=mx.current_context(), grad_req='null')
    input_names = sym.list_inputs()
    for name in input_names:
        if name in exe.arg_dict:
            exe.arg_dict[name][:] = mx.nd.random.uniform(shape=exe.arg_dict[name].shape)
            partitioned_exe.arg_dict[name][:] = exe.arg_dict[name]
        else:
            assert name in exe.aux_dict
            exe.aux_dict[name][:] = mx.nd.random.uniform(shape=exe.aux_dict[name].shape)
            partitioned_exe.aux_dict[name][:] = exe.aux_dict[name]
    exe.forward()
    partitioned_exe.forward()
    assert len(exe.outputs) == len(partitioned_exe.outputs)
    for i in range(len(exe.outputs)):
        assert_almost_equal((exe.outputs[i] - partitioned_exe.outputs[i]).abs().sum().asnumpy(),
                            np.zeros(shape=(1,)))

def check_subgraph_exe2(sym, subgraph_backend, op_names):
    """Use the env var MXNET_SUBGRAPH_BACKEND to trigger graph partitioning in simple_bind
    and compare results of the partitioned sym and the original sym."""
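    # with MXNET_SUBGRAPH_BACKEND set, simple_bind partitions the graph itself,
    # so no explicit MXBuildSubgraphByOpNames call is needed here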
    def get_executor(sym, subgraph_backend=None, op_names=None, original_exec=None):
        exe = sym.simple_bind(ctx=mx.current_context(), grad_req='null')
        input_names = sym.list_inputs()
        for name in input_names:
            if name in exe.arg_dict:
                exe.arg_dict[name][:] = mx.nd.random.uniform(shape=exe.arg_dict[name].shape)\
                    if original_exec is None else original_exec.arg_dict[name]
            else:
                assert name in exe.aux_dict
                exe.aux_dict[name][:] = mx.nd.random.uniform(shape=exe.aux_dict[name].shape)\
                    if original_exec is None else original_exec.aux_dict[name]
        exe.forward()
        return exe

    original_exec = get_executor(sym)
    with environment('MXNET_SUBGRAPH_BACKEND', subgraph_backend):
        check_call(_LIB.MXSetSubgraphPropertyOpNames(c_str(subgraph_backend), mx_uint(len(op_names)),
                                                     c_str_array(op_names)))
        partitioned_exec = get_executor(sym, subgraph_backend, op_names, original_exec)
        check_call(_LIB.MXRemoveSubgraphPropertyOpNames(c_str(subgraph_backend)))
    outputs1 = original_exec.outputs
    outputs2 = partitioned_exec.outputs
    assert len(outputs1) == len(outputs2)
    for i in range(len(outputs1)):
        assert_almost_equal((outputs1[i] - outputs2[i]).abs().sum().asnumpy(), np.zeros(shape=(1,)))

def check_subgraph_exe3(sym, subgraph_backend, op_names):
    """Use the partitioned sym to bind an executor and compare the outputs
    with those of the original executor."""
    out = SymbolHandle()
    check_call(_LIB.MXBuildSubgraphByOpNames(sym.handle, c_str(subgraph_backend), mx_uint(len(op_names)),
                                             c_str_array(op_names), ctypes.byref(out)))

    partitioned_sym = Symbol(out)
    input_names = sym.list_inputs()
    arg_names = sym.list_arguments()
    aux_names = sym.list_auxiliary_states()
    assert partitioned_sym.list_inputs() == input_names
    assert partitioned_sym.list_arguments() == arg_names
    assert partitioned_sym.list_auxiliary_states() == aux_names
    arg_shapes, _, aux_shapes = sym.infer_shape()
    arg_array = [mx.nd.random.uniform(shape=shape) for shape in arg_shapes]
    aux_array = [mx.nd.random.uniform(shape=shape) for shape in aux_shapes]
    exe = sym.bind(ctx=mx.current_context(), args=arg_array, aux_states=aux_array, grad_req='null')
    partitioned_exe = partitioned_sym.bind(ctx=mx.current_context(), args=arg_array,
                                           aux_states=aux_array, grad_req='null')
    exe.forward()
    partitioned_exe.forward()
    assert len(exe.outputs) == len(partitioned_exe.outputs)
    for i in range(len(exe.outputs)):
        assert_almost_equal((exe.outputs[i] - partitioned_exe.outputs[i]).abs().sum().asnumpy(),
                            np.zeros(shape=(1,)))

def check_subgraph_exe4(sym, subgraph_backend, op_names):
    """Use the env var MXNET_SUBGRAPH_BACKEND to trigger graph partitioning in bind
    and compare results of the partitioned sym and the original sym."""
    def get_executor(sym, subgraph_backend=None, op_names=None, original_exec=None):
        arg_shapes, _, aux_shapes = sym.infer_shape()
        if subgraph_backend is None:
            arg_array = [mx.nd.random.uniform(shape=shape) for shape in arg_shapes]
            aux_array = [mx.nd.random.uniform(shape=shape) for shape in aux_shapes]
        else:
            arg_array = None
            aux_array = None
        exe = sym.bind(ctx=mx.current_context(),
                       args=arg_array if subgraph_backend is None else original_exec.arg_arrays,
                       aux_states=aux_array if subgraph_backend is None else original_exec.aux_arrays,
                       grad_req='null')
        exe.forward()
        return exe

    original_exec = get_executor(sym)
    with environment('MXNET_SUBGRAPH_BACKEND', subgraph_backend):
        check_call(_LIB.MXSetSubgraphPropertyOpNames(c_str(subgraph_backend), mx_uint(len(op_names)),
                                                     c_str_array(op_names)))
        partitioned_exec = get_executor(sym, subgraph_backend, op_names, original_exec)
        check_call(_LIB.MXRemoveSubgraphPropertyOpNames(c_str(subgraph_backend)))
    outputs1 = original_exec.outputs
    outputs2 = partitioned_exec.outputs
    assert len(outputs1) == len(outputs2)
    for i in range(len(outputs1)):
        assert_almost_equal((outputs1[i] - outputs2[i]).abs().sum().asnumpy(), np.zeros(shape=(1,)))

def set_random_inputs(exe1, input_names):
    """Set random values for exe1's args and auxs."""
    for name in input_names:
        if name in exe1.arg_dict:
            exe1.arg_dict[name][:] = mx.nd.random.uniform(shape=exe1.arg_dict[name].shape)
        else:
            assert name in exe1.aux_dict
            exe1.aux_dict[name][:] = mx.nd.random.uniform(shape=exe1.aux_dict[name].shape)

def copy_inputs_between_executors(exe1, exe2, input_names):
    """Copy values of args and auxs from exe1 to exe2."""
    for name in input_names:
        if name in exe2.arg_dict:
            exe2.arg_dict[name][:] = exe1.arg_dict[name]
        else:
            assert name in exe2.aux_dict
            exe2.aux_dict[name][:] = exe1.aux_dict[name]

def check_subgraph_exe5(sym, subgraph_backend, op_names):
    """Call optimize_for to trigger graph partitioning (without inferring shapes/types first),
    then simple_bind and compare results of the partitioned sym and the original sym."""
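    # the *V2 C APIs register the op-name-based property used by sym.optimize_for,
    # which partitions the symbol directly, with no executor or env var involved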
    # simple_bind
    exe1 = sym.simple_bind(ctx=mx.current_context(), grad_req='null')
    input_names = sym.list_inputs()
    set_random_inputs(exe1, input_names)
    exe1.forward()

    # partition before simple_bind
    check_call(_LIB.MXSetSubgraphPropertyOpNamesV2(c_str(subgraph_backend), mx_uint(len(op_names)),
                                                   c_str_array(op_names)))
    part_sym = sym.optimize_for(subgraph_backend)
    check_call(_LIB.MXRemoveSubgraphPropertyOpNamesV2(c_str(subgraph_backend)))

    exe2 = part_sym.simple_bind(ctx=mx.current_context(), grad_req='null')
    copy_inputs_between_executors(exe1, exe2, input_names)
    exe2.forward()

    # compare outputs
    outputs1 = exe1.outputs
    outputs2 = exe2.outputs
    assert len(outputs1) == len(outputs2)
    for i in range(len(outputs1)):
        assert_almost_equal((outputs1[i] - outputs2[i]).abs().sum().asnumpy(), np.zeros(shape=(1,)))

def check_subgraph_exe6(sym, subgraph_backend, op_names):
    """Call optimize_for to trigger graph partitioning with shapes/types, then simple_bind
    and compare results of the partitioned sym and the original sym."""
    # simple_bind
    exe1 = sym.simple_bind(ctx=mx.current_context(), grad_req='null')
    input_names = sym.list_inputs()
    set_random_inputs(exe1, input_names)
    exe1.forward()

    # pass args/auxs so shapes and types are inferred during partitioning, before simple_bind
    check_call(_LIB.MXSetSubgraphPropertyOpNamesV2(c_str(subgraph_backend), mx_uint(len(op_names)),
                                                   c_str_array(op_names)))
    part_sym = sym.optimize_for(subgraph_backend, exe1.arg_dict, exe1.aux_dict)
    check_call(_LIB.MXRemoveSubgraphPropertyOpNamesV2(c_str(subgraph_backend)))

    exe2 = part_sym.simple_bind(ctx=mx.current_context(), grad_req='null')
    copy_inputs_between_executors(exe1, exe2, input_names)
    exe2.forward()

    # compare outputs
    outputs1 = exe1.outputs
    outputs2 = exe2.outputs
    assert len(outputs1) == len(outputs2)
    for i in range(len(outputs1)):
        assert_almost_equal((outputs1[i] - outputs2[i]).abs().sum().asnumpy(), np.zeros(shape=(1,)))

def check_subgraph_exe7(sym, subgraph_backend, op_names):
    """Call optimize_for to trigger graph partitioning (without inferring shapes/types first),
    then bind and compare results of the partitioned sym and the original sym."""
    # bind
    arg_shapes, _, aux_shapes = sym.infer_shape()
    arg_array = [mx.nd.random.uniform(shape=shape) for shape in arg_shapes]
    aux_array = [mx.nd.random.uniform(shape=shape) for shape in aux_shapes]
    exe1 = sym.bind(ctx=mx.current_context(), args=arg_array, aux_states=aux_array, grad_req='null')
    exe1.forward()

    # partition before bind
    check_call(_LIB.MXSetSubgraphPropertyOpNamesV2(c_str(subgraph_backend), mx_uint(len(op_names)),
                                                   c_str_array(op_names)))
    part_sym = sym.optimize_for(subgraph_backend)
    check_call(_LIB.MXRemoveSubgraphPropertyOpNamesV2(c_str(subgraph_backend)))

    exe2 = part_sym.bind(ctx=mx.current_context(), args=arg_array, aux_states=aux_array, grad_req='null')
    exe2.forward()

    # compare outputs
    outputs1 = exe1.outputs
    outputs2 = exe2.outputs
    assert len(outputs1) == len(outputs2)
    for i in range(len(outputs1)):
        assert_almost_equal((outputs1[i] - outputs2[i]).abs().sum().asnumpy(), np.zeros(shape=(1,)))

def check_subgraph_exe8(sym, subgraph_backend, op_names):
    """Call optimize_for to infer shapes and types followed by graph partitioning,
    then bind and compare results of the partitioned sym and the original sym."""
    # bind
    arg_shapes, _, aux_shapes = sym.infer_shape()
    arg_names = sym.list_arguments()
    aux_names = sym.list_auxiliary_states()
    arg_dict = {name: mx.nd.random.uniform(shape=shape) for name, shape in zip(arg_names, arg_shapes)}
    aux_dict = {name: mx.nd.random.uniform(shape=shape) for name, shape in zip(aux_names, aux_shapes)}
    exe1 = sym.bind(ctx=mx.current_context(), args=arg_dict, aux_states=aux_dict, grad_req='null')
    exe1.forward()

    # pass args/auxs so shapes and types are inferred during partitioning, before bind
    check_call(_LIB.MXSetSubgraphPropertyOpNamesV2(c_str(subgraph_backend), mx_uint(len(op_names)),
                                                   c_str_array(op_names)))
    part_sym = sym.optimize_for(subgraph_backend, arg_dict, aux_dict)
    check_call(_LIB.MXRemoveSubgraphPropertyOpNamesV2(c_str(subgraph_backend)))

    exe2 = part_sym.bind(ctx=mx.current_context(), args=arg_dict, aux_states=aux_dict, grad_req='null')
    exe2.forward()

    # compare outputs
    outputs1 = exe1.outputs
    outputs2 = exe2.outputs
    assert len(outputs1) == len(outputs2)
    for i in range(len(outputs1)):
        assert_almost_equal((outputs1[i] - outputs2[i]).abs().sum().asnumpy(), np.zeros(shape=(1,)))

def check_subgraph_exe9(sym, subgraph_backend, op_names):
    """Call hybridize() to partition the graph, and then compare results of the partitioned
    sym and the original sym. Running inference before hybridizing with the subgraph_backend
    means shapes/types are known and get passed to the backend."""
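    # unlike the other checks, this one takes the whole network tuple
    # (sym, input names, input shapes) because SymbolBlock needs the input
    # names and the forward pass needs the input shapes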
    # create Gluon block for given symbol
    inputs = [mx.sym.var(i, dtype=mx_real_t) for i in sym[1]]
    sym_block = nn.SymbolBlock(sym[0], inputs)
    sym_block.initialize(ctx=mx.current_context())
    x = [mx.nd.random.uniform(shape=s, ctx=mx.current_context()) for s in sym[2]]
    # hybridize and export to get baseline
    sym_block.hybridize()
    outputs1 = sym_block(*x)
    sym_block.export('check_subgraph_exe9')

    # load model and partition
    sym_block = nn.SymbolBlock.imports('check_subgraph_exe9-symbol.json', sym[1],
                                       'check_subgraph_exe9-0000.params', ctx=mx.current_context())
    check_call(_LIB.MXSetSubgraphPropertyOpNamesV2(c_str(subgraph_backend), mx_uint(len(op_names)),
                                                   c_str_array(op_names)))
    sym_block.hybridize(backend=subgraph_backend)
    outputs2 = sym_block(*x)
    check_call(_LIB.MXRemoveSubgraphPropertyOpNamesV2(c_str(subgraph_backend)))

    # compare outputs
    assert len(outputs1) == len(outputs2)
    for i in range(len(outputs1)):
        assert_almost_equal((outputs1[i] - outputs2[i]).abs().sum().asnumpy(), np.zeros(shape=(1,)))

def check_subgraph_exe10(sym, subgraph_backend, op_names):
    """Call optimize_for to infer shapes and types followed by graph partitioning with
    subgraph dedup, then bind and compare results of the partitioned sym and the original sym."""
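    # dedup_subgraph=True asks optimize_for to reuse a single subgraph node for
    # identical subgraphs instead of emitting duplicates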
    # bind
    arg_shapes, _, aux_shapes = sym.infer_shape()
    arg_names = sym.list_arguments()
    aux_names = sym.list_auxiliary_states()
    arg_dict = {name: mx.nd.random.uniform(shape=shape) for name, shape in zip(arg_names, arg_shapes)}
    aux_dict = {name: mx.nd.random.uniform(shape=shape) for name, shape in zip(aux_names, aux_shapes)}
    exe1 = sym.bind(ctx=mx.current_context(), args=arg_dict, aux_states=aux_dict, grad_req='null')
    exe1.forward()

    # pass args/auxs so shapes and types are inferred during partitioning, before bind
    check_call(_LIB.MXSetSubgraphPropertyOpNamesV2(c_str(subgraph_backend), mx_uint(len(op_names)),
                                                   c_str_array(op_names)))
    print(sym.tojson())
    part_sym = sym.optimize_for(subgraph_backend, arg_dict, aux_dict, dedup_subgraph=True)
    print(part_sym.tojson())
    check_call(_LIB.MXRemoveSubgraphPropertyOpNamesV2(c_str(subgraph_backend)))

    exe2 = part_sym.bind(ctx=mx.current_context(), args=arg_dict, aux_states=aux_dict, grad_req='null')
    exe2.forward()

    # compare outputs
    outputs1 = exe1.outputs
    outputs2 = exe2.outputs
    assert len(outputs1) == len(outputs2)
    for i in range(len(outputs1)):
        assert_almost_equal((outputs1[i] - outputs2[i]).abs().sum().asnumpy(), np.zeros(shape=(1,)))

def check_subgraph(subgraph_backend):
    for sym, op_names in get_graphs():
        check_subgraph_exe1(sym[0], subgraph_backend, op_names)
        check_subgraph_exe2(sym[0], subgraph_backend, op_names)
        check_subgraph_exe3(sym[0], subgraph_backend, op_names)
        check_subgraph_exe4(sym[0], subgraph_backend, op_names)

def check_subgraph_backend_sym(subgraph_backend):
    for sym, op_names in get_graphs():
        check_subgraph_exe5(sym[0], subgraph_backend, op_names)
        check_subgraph_exe6(sym[0], subgraph_backend, op_names)
        check_subgraph_exe7(sym[0], subgraph_backend, op_names)
        check_subgraph_exe8(sym[0], subgraph_backend, op_names)
        check_subgraph_exe10(sym[0], subgraph_backend, op_names)

def check_subgraph_backend_gluon(subgraph_backend):
    for sym, op_names in get_graphs():
        check_subgraph_exe9(sym, subgraph_backend, op_names)

# Test graph partitioning for the 'default' backend.
def test_subgraph():
    check_subgraph('default')

# Test graph partitioning for the 'default_v2' backend.
def test_subgraph_v2():
    check_subgraph('default_v2')

# Test the enhanced Python and C APIs for graph partitioning with the 'default' backend.
def test_subgraph_backend_sym():
    check_subgraph_backend_sym('default')

# Test the enhanced Python and C APIs for graph partitioning with the 'default_v2' backend.
def test_subgraph_backend_sym_v2():
    check_subgraph_backend_sym('default_v2')

# Test Gluon HybridBlocks for graph partitioning with the 'default' backend.
def test_subgraph_backend_gluon():
    check_subgraph_backend_gluon('default')

# Test Gluon HybridBlocks for graph partitioning with the 'default_v2' backend.
def test_subgraph_backend_gluon_v2():
    check_subgraph_backend_gluon('default_v2')

# Test Gluon HybridBlocks for graph partitioning of a network created with HybridSequential.
def test_subgraph_backend_gluon_ext1():
    def get_net():
        net = nn.HybridSequential()  # Here we use the class HybridSequential.
        net.add(nn.Dense(256, activation='relu'),
                nn.Dense(128, activation='relu'),
                nn.Dense(2))
        return net

    # regular inference
    x = nd.random.normal(shape=(1, 512), ctx=mx.current_context())
    net = get_net()
    net.collect_params().initialize(ctx=mx.current_context())
    outputs1 = net(x)
    net.save_parameters('test_subgraph_backend_gluon_ext1.params')

    # after partitioning
    net = get_net()
    net.load_parameters('test_subgraph_backend_gluon_ext1.params', ctx=mx.current_context())
    subgraph_backend = 'default'
    op_names = ['FullyConnected']
    check_call(_LIB.MXSetSubgraphPropertyOpNamesV2(c_str(subgraph_backend), mx_uint(len(op_names)),
                                                   c_str_array(op_names)))
    net.hybridize(backend=subgraph_backend)
    outputs2 = net(x)
    check_call(_LIB.MXRemoveSubgraphPropertyOpNamesV2(c_str(subgraph_backend)))

    # compare outputs
    assert len(outputs1) == len(outputs2)
    for i in range(len(outputs1)):
        assert_almost_equal((outputs1[i] - outputs2[i]).abs().sum().asnumpy(), np.zeros(shape=(1,)))

# Test Gluon HybridBlocks for graph partitioning of a network created with HybridBlock.
def test_subgraph_backend_gluon_ext2():
    class Net(gluon.HybridBlock):
        def __init__(self, **kwargs):
            super(Net, self).__init__(**kwargs)
            with self.name_scope():
                self.fc1 = nn.Dense(256)
                self.fc2 = nn.Dense(128)
                self.fc3 = nn.Dense(2)

        def hybrid_forward(self, F, x):
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            return self.fc3(x)

    # regular inference
    x = nd.random.normal(shape=(1, 512), ctx=mx.current_context())
    net = Net()
    net.collect_params().initialize(ctx=mx.current_context())
    outputs1 = net(x)
    net.save_parameters('test_subgraph_backend_gluon_ext2.params')

    # after partitioning
    net = Net()
    net.load_parameters('test_subgraph_backend_gluon_ext2.params', ctx=mx.current_context())
    subgraph_backend = 'default'
    op_names = ['FullyConnected']
    check_call(_LIB.MXSetSubgraphPropertyOpNamesV2(c_str(subgraph_backend), mx_uint(len(op_names)),
                                                   c_str_array(op_names)))
    net.hybridize(backend=subgraph_backend)
    outputs2 = net(x)
    check_call(_LIB.MXRemoveSubgraphPropertyOpNamesV2(c_str(subgraph_backend)))

    # compare outputs
    assert len(outputs1) == len(outputs2)
    for i in range(len(outputs1)):
        assert_almost_equal((outputs1[i] - outputs2[i]).abs().sum().asnumpy(), np.zeros(shape=(1,)))

if __name__ == '__main__':
    import nose
    nose.runmodule()