1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18use strict;
19use warnings;
20use Test::More tests => 18;
21use AI::MXNet qw(mx);
22use AI::MXNet::TestUtils qw(mlp2);
23
# Check selected inferred argument shapes of a symbol.
#
#   $sym             - symbol whose list_arguments() supplies the argument
#                      names; $arg_shapes is assumed to be in that same order
#   $arg_shapes      - array ref of shapes (the first value returned by
#                      infer_shape())
#   %expected_shapes - argument name => expected shape (array ref)
#
# Emits exactly one is_deeply() test per entry of %expected_shapes, so the
# caller's test count is unchanged.  Names are visited in sorted order so
# the TAP output is identical from run to run (Perl randomizes hash order),
# and each check is named after its argument so a failure says which one.
sub _test_shapes
{
    my ($sym, $arg_shapes, %expected_shapes) = @_;
    my %arg_shape_dict;
    @arg_shape_dict{ @{ $sym->list_arguments() } } = @{ $arg_shapes };
    for my $name (sort keys %expected_shapes)
    {
        # An unknown $name yields undef here, which is_deeply reports as a failure.
        is_deeply($arg_shape_dict{$name}, $expected_shapes{$name}, "inferred shape of '$name'");
    }
}
34
# Forward shape inference on the network returned by mlp2(): given only the
# data shape [100, 100], infer_shape() must produce a single [100, 10]
# output and the fc1/fc2 weight and bias shapes listed below.
# Emits 6 tests (2 here + 4 via _test_shapes).
sub test_mlp2_infer_shape
{
    # Build MLP
    my $out = mlp2();
    # Infer shapes from the data shape alone; aux shapes are not checked.
    my $data_shape = [100, 100];
    my ($arg_shapes, $out_shapes) = $out->infer_shape(data => $data_shape);
    # is() instead of ok(... == 1) so a failure reports the actual count.
    is(scalar(@$out_shapes), 1, 'mlp2 has exactly one output');
    is_deeply($out_shapes->[0], [100, 10], 'mlp2 output shape');
    my %true_shapes = (
        fc2_bias   => [10],
        fc2_weight => [10, 1000],
        fc1_bias   => [1000],
        fc1_weight => [1000, 100],
    );
    _test_shapes($out, $arg_shapes, %true_shapes);
}
52
# Shape-inconsistency case: forcing fc1_weight to [1, 100] on the mlp2()
# network must make infer_shape() die with a "Shape inconsistent" error.
# Emits 1 test.
sub test_mlp2_infer_error
{
    my $out          = mlp2();
    my $weight_shape = [1, 100];
    my $data_shape   = [100, 100];
    eval { $out->infer_shape(data => $data_shape, fc1_weight => $weight_shape) };
    # Copy $@ immediately: it is a global that any later eval would clobber.
    # If infer_shape() did not die, $error is '' and the like() below fails.
    my $error = $@;
    like($error, qr/Shape inconsistent/, 'inconsistent fc1_weight shape is rejected');
}
62
# Backward shape inference: "weight" is declared without a shape and must be
# inferred from the FullyConnected layer that consumes it.  With data
# [7, 100] and num_hidden 11 the test expects weight to be [11, 100].
sub test_backward_infer
{
    my $weight = mx->sym->Variable("weight");
    my $shift  = mx->sym->Variable("wshift", shape=>[1]);
    my $input  = mx->sym->Variable("data");

    # broadcast_add alone cannot deduce $weight's shape from its result;
    # _identity_with_attr_like_rhs constrains the result to $weight's shape,
    # which is what lets the shape propagate backward.  (The inner call is
    # evaluated first, so ops are still created in the original order.)
    my $tied = mx->sym->_identity_with_attr_like_rhs(
        mx->sym->broadcast_add($weight, $shift),
        $weight
    );

    my $fc = mx->sym->FullyConnected(data=>$input, weight=>$tied, num_hidden=>11, no_bias=>1);
    my ($shapes_of_args) = $fc->infer_shape(data=>[7, 100]);
    _test_shapes($fc, $shapes_of_args, weight=>[11,100]);
}
77
# Elementwise add with partially known shapes (0 marks a dimension left for
# inference): [0, 10] + [12, 0] should resolve both inputs to [12, 10].
sub test_incomplete_infer_elewise
{
    my $lhs = mx->sym->Variable('a', shape=>[0, 10]);
    my $rhs = mx->sym->Variable('b', shape=>[12, 0]);
    my $sum = $lhs + $rhs;
    my ($shapes_of_args) = $sum->infer_shape();
    _test_shapes($sum, $shapes_of_args, a=>[12,10], b=>[12,10]);
}
86
# Partially known input to a FullyConnected layer: the batch size missing
# from 'a' ([0, 10]) and the width missing from 'c' ([5, 0]) are recovered
# from each other through the add, giving a => [5, 10] and c => [5, 21].
sub test_incomplete_infer_mlp
{
    my $input  = mx->sym->Variable('a', shape=>[0, 10]);
    my $hidden = mx->sym->FullyConnected(data=>$input, num_hidden=>21);
    my $addend = mx->sym->Variable('c', shape=>[5, 0]);
    my $sum    = $hidden + $addend;
    my ($shapes_of_args) = $sum->infer_shape();
    _test_shapes($sum, $shapes_of_args, a=>[5,10], c=>[5,21]);
}
96
# SliceChannel with a partially known input (0 marks a dimension left for
# inference).  Adding one output of the split to a fully shaped variable
# must let infer_shape() fill in the unknown dimensions of the split input.
# Emits 2 tests.
sub test_incomplete_infer_slicechannel
{
    # Split 'a' ([0, 10]) along axis 1 into 10 squeezed outputs; adding
    # output 1 to 'c' ([5]) should give a => [5, 10].
    my $input  = mx->sym->Variable('a', shape=>[0, 10]);
    my $pieces = mx->sym->SliceChannel(data=>$input, num_outputs=>10, axis=>1, squeeze_axis=>1);
    my $addend = mx->sym->Variable('c', shape=>[5]);
    # Plain element access to output 1 (was the one-element slice @{$b}[1]).
    my $sum    = $pieces->[1] + $addend;
    my ($shapes_of_args) = $sum->infer_shape();
    _test_shapes($sum, $shapes_of_args, a=>[5,10]);

    # Unsqueezed 3-way split with no explicit axis: matching an output
    # against 'c' ([3, 5, 2]) is expected to give a => [3, 15, 2].
    my $input3d  = mx->sym->Variable('a', shape=>[0, 15, 0]);
    my $pieces3d = mx->sym->SliceChannel(data=>$input3d, num_outputs=>3, squeeze_axis=>0);
    my $addend3d = mx->sym->Variable('c', shape=>[3, 5, 2]);
    my $sum3d    = $pieces3d->[1] + $addend3d;
    my ($shapes_of_args3d) = $sum3d->infer_shape();
    _test_shapes($sum3d, $shapes_of_args3d, a=>[3,15,2]);
}
113
# Convolution with a partially known input: 'a' only fixes 10 channels.
# Adding the conv output to a fully shaped [5, 21, 32, 32] variable should
# let infer_shape() resolve 'a' to [5, 10, 32, 32] (3x3 kernel, pad 1).
sub test_incomplete_infer_convolution
{
    my $input = mx->sym->Variable('a', shape=>[0, 10, 0, 0]);
    my $conv  = mx->sym->Convolution(
        data       => $input,
        num_filter => 21,
        kernel     => [3, 3],
        dilate     => [1, 1],
        pad        => [1, 1],
    );
    my $addend = mx->sym->Variable('c', shape=>[5, 21, 32, 32]);
    my $sum    = $conv + $addend;
    my ($shapes_of_args) = $sum->infer_shape();
    _test_shapes($sum, $shapes_of_args, a=>[5, 10, 32, 32]);
}
123
# Concat of two inputs with unknown batch size along dim 1 (10 + 5 = 15
# columns), added to 'd' whose width is unknown: the batch size 2 and the
# width 15 should propagate so that a => [2, 10], b => [2, 5], d => [2, 15].
sub test_incomplete_infer_concat
{
    my $left   = mx->sym->Variable('a', shape=>[0, 10]);
    my $right  = mx->sym->Variable('b', shape=>[0, 5]);
    my $joined = mx->sym->Concat($left, $right, num_args=>2, dim=>1);
    my $sum    = mx->sym->Variable('d', shape=>[2, 0]) + $joined;
    my ($shapes_of_args) = $sum->infer_shape();
    _test_shapes($sum, $shapes_of_args, a=>[2,10], b=>[2,5], d=>[2,15]);
}
134
# Run every test in declaration order; together they must produce the
# number of assertions declared in the Test::More plan at the top (18).
for my $test_case (
    \&test_mlp2_infer_shape,
    \&test_mlp2_infer_error,
    \&test_backward_infer,
    \&test_incomplete_infer_elewise,
    \&test_incomplete_infer_mlp,
    \&test_incomplete_infer_slicechannel,
    \&test_incomplete_infer_convolution,
    \&test_incomplete_infer_concat,
)
{
    $test_case->();
}
143