# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# Shape-inference tests for AI::MXNet symbols:
#   - forward inference through an MLP built by mlp2(),
#   - rejection of an inconsistent user-supplied shape,
#   - backward inference enabled by a shape constraint,
#   - "incomplete" inference, where variables are declared with 0 in
#     some dimensions and the expected results show those dimensions
#     being filled in from the rest of the graph.
#
# The plan below must equal the total number of assertions emitted by
# the test subs (each expected entry passed to _test_shapes is one test).

use strict;
use warnings;
use Test::More tests => 18;
use AI::MXNet qw(mx);
use AI::MXNet::TestUtils qw(mlp2);

# Compare inferred argument shapes against expected ones.
#
#   $sym        - symbol whose list_arguments() names, in order, the
#                 entries of $arg_shapes
#   $arg_shapes - array ref of shapes returned by $sym->infer_shape(),
#                 positionally aligned with list_arguments()
#   %expected   - argument name => expected shape (array ref)
#
# Emits exactly one is_deeply() test per entry in %expected. Names are
# visited in sorted order so the TAP output is stable between runs.
sub _test_shapes
{
    my ($sym, $arg_shapes, %expected_shapes) = @_;
    my %arg_shape_of;
    @arg_shape_of{ @{ $sym->list_arguments() } } = @{ $arg_shapes };
    for my $name (sort keys %expected_shapes)
    {
        is_deeply(
            $arg_shape_of{$name},
            $expected_shapes{$name},
            "inferred shape of '$name'"
        );
    }
    return;
}

# Forward inference through the mlp2() network from the data shape alone.
# Emits 2 tests plus 4 from _test_shapes.
sub test_mlp2_infer_shape
{
    # Build MLP
    my $out = mlp2();
    # infer shape
    my $data_shape = [100, 100];
    my ($arg_shapes, $out_shapes) = $out->infer_shape(data => $data_shape);
    is(scalar @{ $out_shapes }, 1, 'mlp2 has a single output');
    is_deeply($out_shapes->[0], [100, 10], 'mlp2 output shape');
    my %true_shapes = (
        fc2_bias   => [10],
        fc2_weight => [10, 1000],
        fc1_bias   => [1000],
        fc1_weight => [1000, 100],
    );
    _test_shapes($out, $arg_shapes, %true_shapes);
    return;
}

# infer_shape must die when a supplied weight shape contradicts the
# shape implied by the data. Emits 1 test.
sub test_mlp2_infer_error
{
    # Test shape inconsistent case
    my $out          = mlp2();
    my $weight_shape = [1, 100];
    my $data_shape   = [100, 100];
    eval { $out->infer_shape(data => $data_shape, fc1_weight => $weight_shape) };
    like($@, qr/Shape inconsistent/, 'inconsistent fc1_weight shape is rejected');
    return;
}

# The weight's shape is only recoverable by inferring backward from the
# FullyConnected layer that consumes it. Emits 1 test.
sub test_backward_infer
{
    my $weight = mx->sym->Variable("weight");
    my $wshift = mx->sym->Variable("wshift", shape => [1]);
    my $data   = mx->sym->Variable("data");
    # broadcast add here, not being able to deduce shape correctly
    my $shifted = mx->sym->broadcast_add($weight, $wshift);
    # shape constraint, this is what enables backward shape inference
    $shifted = mx->sym->_identity_with_attr_like_rhs($shifted, $weight);
    my $net = mx->sym->FullyConnected(
        data       => $data,
        weight     => $shifted,
        num_hidden => 11,
        no_bias    => 1,
    );
    my $data_shape = [7, 100];
    my ($arg_shapes) = $net->infer_shape(data => $data_shape);
    _test_shapes($net, $arg_shapes, weight => [11, 100]);
    return;
}

# Elementwise addition: each operand supplies the dimension the other
# leaves as 0. Emits 2 tests.
sub test_incomplete_infer_elewise
{
    my $lhs = mx->sym->Variable('a', shape => [0, 10]);
    my $rhs = mx->sym->Variable('b', shape => [12, 0]);
    my $sum = $lhs + $rhs;
    my ($arg_shapes) = $sum->infer_shape();
    _test_shapes($sum, $arg_shapes, a => [12, 10], b => [12, 10]);
    return;
}

# FullyConnected output added to a partially known variable: the batch
# size flows back to the input, num_hidden flows forward. Emits 2 tests.
sub test_incomplete_infer_mlp
{
    my $data = mx->sym->Variable('a', shape => [0, 10]);
    my $fc   = mx->sym->FullyConnected(data => $data, num_hidden => 21);
    my $rhs  = mx->sym->Variable('c', shape => [5, 0]);
    my $sum  = $fc + $rhs;
    my ($arg_shapes) = $sum->infer_shape();
    _test_shapes($sum, $arg_shapes, a => [5, 10], c => [5, 21]);
    return;
}

# One output of SliceChannel is combined with a known-shape variable,
# and the unknown input dimensions are recovered from it, both with and
# without squeeze_axis. Emits 2 tests.
sub test_incomplete_infer_slicechannel
{
    my $data   = mx->sym->Variable('a', shape => [0, 10]);
    my $slices = mx->sym->SliceChannel(
        data         => $data,
        num_outputs  => 10,
        axis         => 1,
        squeeze_axis => 1,
    );
    my $rhs = mx->sym->Variable('c', shape => [5]);
    my $sum = $slices->[1] + $rhs;
    my ($arg_shapes) = $sum->infer_shape();
    _test_shapes($sum, $arg_shapes, a => [5, 10]);

    $data   = mx->sym->Variable('a', shape => [0, 15, 0]);
    $slices = mx->sym->SliceChannel(
        data         => $data,
        num_outputs  => 3,
        squeeze_axis => 0,
    );
    $rhs = mx->sym->Variable('c', shape => [3, 5, 2]);
    $sum = $slices->[1] + $rhs;
    ($arg_shapes) = $sum->infer_shape();
    _test_shapes($sum, $arg_shapes, a => [3, 15, 2]);
    return;
}

# A same-padded 3x3 convolution added to a fully known tensor: batch and
# spatial dimensions of the input are recovered from it. Emits 1 test.
sub test_incomplete_infer_convolution
{
    my $data = mx->sym->Variable('a', shape => [0, 10, 0, 0]);
    my $conv = mx->sym->Convolution(
        data       => $data,
        num_filter => 21,
        kernel     => [3, 3],
        dilate     => [1, 1],
        pad        => [1, 1],
    );
    my $rhs = mx->sym->Variable('c', shape => [5, 21, 32, 32]);
    my $sum = $conv + $rhs;
    my ($arg_shapes) = $sum->infer_shape();
    _test_shapes($sum, $arg_shapes, a => [5, 10, 32, 32]);
    return;
}

# Concat along dim 1 added to a partially known variable: the batch size
# flows back to both concat inputs and the concatenated width flows to
# the other operand. Emits 3 tests.
sub test_incomplete_infer_concat
{
    my $left   = mx->sym->Variable('a', shape => [0, 10]);
    my $right  = mx->sym->Variable('b', shape => [0, 5]);
    my $concat = mx->sym->Concat($left, $right, num_args => 2, dim => 1);
    my $other  = mx->sym->Variable('d', shape => [2, 0]);
    my $sum    = $other + $concat;
    my ($arg_shapes) = $sum->infer_shape();
    _test_shapes($sum, $arg_shapes, a => [2, 10], b => [2, 5], d => [2, 15]);
    return;
}

test_mlp2_infer_shape();
test_mlp2_infer_error();
test_backward_infer();
test_incomplete_infer_elewise();
test_incomplete_infer_mlp();
test_incomplete_infer_slicechannel();
test_incomplete_infer_convolution();
test_incomplete_infer_concat();