1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18module TestSymbolicNode
19
20using MXNet
21using Test
22
23using ..Main: mlp2, mlpchain, exec
24
25################################################################################
26# Test Implementations
27################################################################################
# Check argument, output and auxiliary-state names of the `mlp2` fixture model.
function test_basic()
  @info("SymbolicNode::basic")

  net = mlp2()
  expected_args = [:data, :fc1_weight, :fc1_bias, :fc2_weight, :fc2_bias]

  @test mx.list_arguments(net) == expected_args
  @test mx.list_outputs(net) == [:fc2_output]
  # The fixture is expected to carry no auxiliary states at all.
  @test mx.list_auxiliary_states(net) == Symbol[]
end
36
# Check that `@mx.chain` composes layers, both for the `mlpchain` fixture and
# for an `mx.MLP` built with a custom name prefix.
function test_chain()
  @info("SymbolicNode::chain")

  chained = mlpchain()
  @test mx.list_arguments(chained) == [:data, :fc1_weight, :fc1_bias, :fc2_weight, :fc2_bias]
  @test mx.list_outputs(chained) == [:fc2_output]
  @test mx.list_auxiliary_states(chained) == Symbol[]

  # Three fully-connected layers named with the `magic_` prefix, then a
  # regression head that introduces the `:label` argument.
  let sizes = [20, 10, 6]
    net = @mx.chain mx.Variable(:data) =>
      mx.MLP(sizes, prefix=:magic_) =>
      mx.LinearRegressionOutput(mx.Variable(:label))

    # magic_fc1_weight, magic_fc1_bias, ..., magic_fc3_bias (layer-major order)
    layer_params = [Symbol("magic_fc$(i)_$(kind)") for i in 1:3 for kind in ("weight", "bias")]
    @test mx.list_arguments(net) == [:data; layer_params; :label]
  end
end
58
# Check that an intermediate output pulled out via `get_internals` keeps the
# argument list of the sub-graph that produced it.
function test_internal()
  @info("SymbolicNode::internal")

  input  = mx.Variable(:data)
  hidden = mx.FullyConnected(input, name=:fc1, num_hidden=10)
  net    = mx.FullyConnected(hidden, name=:fc2, num_hidden=100)

  @test mx.list_arguments(net) == [:data, :fc1_weight, :fc1_bias, :fc2_weight, :fc2_bias]

  extracted = mx.get_internals(net)[:fc1_output]
  @test mx.list_arguments(extracted) == mx.list_arguments(hidden)
end
72
# Check `mx.get_children` on a composite node (returns its inputs as a
# multi-output symbol) and on a leaf variable (returns `nothing`).
function test_get_children()
  @info("SymbolicNode::get_children")

  let x = mx.Variable(:x), y = mx.Variable(:y)
    z = x + y
    @test length(mx.list_outputs(z)) == 1
    @test length(mx.list_outputs(mx.get_children(z))) == 2
    @test mx.list_outputs(mx.get_children(z)) == [:x, :y]
  end

  @info("SymbolicNode::get_children::on leaf")
  let x = mx.Variable(:x)
    # `nothing` is a singleton: compare by identity, not with `==`.
    @test mx.get_children(x) === nothing
  end
end  # test_get_children
88
89
# Compose two separately built networks by binding the first one to the open
# `fc3_data` input of the second, then group both outputs.
function test_compose()
  @info("SymbolicNode::compose")

  # Bottom network: data -> fc1 -> fc2.
  data   = mx.Variable(:data)
  bottom = mx.FullyConnected(mx.FullyConnected(data, name=:fc1, num_hidden=10),
                             name=:fc2, num_hidden=100)

  # Top network: the first layer is created from the `mx.SymbolicNode` type
  # rather than a node, leaving an open `fc3_data` input that is bound below.
  top = mx.FullyConnected(mx.SymbolicNode, name=:fc3, num_hidden=10)
  top = mx.Activation(top, act_type=:relu)
  top = mx.FullyConnected(top, name=:fc4, num_hidden=20)

  composed = top(fc3_data=bottom, name=:composed)
  grouped  = mx.Group(composed, bottom)
  @test mx.list_outputs(grouped) == [:composed_output, :fc2_output]
end
105
# Infer parameter and output shapes of `mlp2` from a (100, 100) data input.
function test_infer_shape()
  @info("SymbolicNode::infer_shape::mlp2")

  net = mlp2()
  data_shape = (100, 100)
  arg_shapes, out_shapes, _ = mx.infer_shape(net, data=data_shape)

  inferred = Dict{Symbol,Tuple}(zip(mx.list_arguments(net), arg_shapes))
  expected = Dict{Symbol,Tuple}(
    :data       => data_shape,
    :fc1_weight => (100, 1000),
    :fc1_bias   => (1000,),
    :fc2_weight => (1000, 10),
    :fc2_bias   => (10,),
  )
  @test inferred == expected

  @test length(out_shapes) == 1
  @test out_shapes[1] == (10, 100)
end
119
# Shape inference with an `fc1_weight` shape that conflicts with the data
# shape is expected to raise `mx.MXError`.
function test_infer_shape_error()
  @info("SymbolicNode::infer_shape::throws")

  net = mlp2()
  @test_throws mx.MXError mx.infer_shape(net, data=(100, 100), fc1_weight=(100, 1))
end
128
# Round-trip `mlp2` through `mx.save`/`mx.load` and compare the JSON
# serializations of the original and the reloaded symbol.
function test_saveload()
  @info("SymbolicNode::saveload::mlp2")

  model = mlp2()
  fname = tempname()
  try
    mx.save(fname, model)
    model2 = mx.load(fname, mx.SymbolicNode)
    @test mx.to_json(model) == mx.to_json(model2)
  finally
    # Clean up even if save/load throws. `force=true` tolerates a file that was
    # never created.
    rm(fname, force=true)
  end
end
140
# Exercise symbol names and string attributes: get/set on a variable,
# attributes passed at construction, and rejection of non-string values.
function test_attrs()
  @info("SymbolicNode::Attributes")

  data = mx.Variable(:data)
  @test mx.get_name(data) == :data

  # An attribute that was never set reads back as `missing`.
  unset = mx.get_attr(data, :test)
  @test ismissing(unset)

  mx.set_attr(data, :test, "1.0")
  stored = mx.get_attr(data, :test)
  @test !ismissing(stored)
  @test stored == "1.0"

  # Attributes can also be supplied when the variable is created.
  data2 = mx.Variable(:data2, attrs = Dict(:test => "hallo!"))
  @test mx.get_attr(data2, :test) == "hallo!"

  conv = mx.Convolution(data2, kernel = (1,1), num_filter = 1)
  @test ismissing(mx.get_attr(conv, :b))
  @test mx.get_name(conv) isa Symbol

  # A non-`String` attribute value (here `1.0`) is expected to raise
  # `MethodError`. A fresh Dict is built for every call.
  mixed_attrs() = Dict(:test => "1.0", :test2 => 1.0)
  @test_throws MethodError mx.Variable(:data3, attrs = mixed_attrs())
  @test_throws MethodError mx.Convolution(data2, kernel = (1,1), num_filter = 1, attrs = mixed_attrs())
end
164
# Applying `mx.sum` to a symbolic variable must yield another `SymbolicNode`.
function test_functions()
  @info("SymbolicNode::Functions")
  data = mx.Variable(:data)
  # This check must be wrapped in `@test`; a bare comparison is discarded and
  # can never fail.
  @test mx.sum(data) isa mx.SymbolicNode
end
170
# Exercise `mx.reshape` on symbols: varargs and tuple forms, the `reverse`
# keyword, and the special dimension codes 0, -1, -2, -3 and -4.
function test_reshape()
  # Bind `sym` to `input` as the `:x` argument on the CPU, run one forward
  # pass and return the first output.
  function forward_output(sym, input)
    ex = mx.bind(sym, mx.cpu(), Dict(:x => input))
    mx.forward(ex)
    return ex.outputs[1]
  end

  # The varargs and tuple forms must both keep element order.
  value_cases = (
    ("SymbolicNode::reshape(sym, dim...)", s -> mx.reshape(s, 2, 3, 4)),
    ("SymbolicNode::reshape(sym, dim)",    s -> mx.reshape(s, (2, 3, 4))),
  )
  for (msg, build) in value_cases
    @info(msg)

    A = mx.NDArray(collect(1:24))
    x = mx.Variable(:x)
    out = forward_output(build(x), A)

    @test size(out) == (2, 3, 4)
    @test copy(out) == reshape(1:24, 2, 3, 4)
  end

  # Shape-only cases: (log message, input dims, reshape builder, expected size).
  shape_cases = (
    ("SymbolicNode::reshape::reverse", (10, 5, 4),   s -> mx.reshape(s, -1, 0, reverse = true), (50, 4)),
    ("SymbolicNode::reshape::0",       (2, 3, 4),    s -> mx.reshape(s, 4, 0, 2),               (4, 3, 2)),
    ("SymbolicNode::reshape::-1",      (2, 3, 4),    s -> mx.reshape(s, 6, 1, -1),              (6, 1, 4)),
    ("SymbolicNode::reshape::-2",      (2, 3, 4, 2), s -> mx.reshape(s, 3, 2, -2),              (3, 2, 4, 2)),
    ("SymbolicNode::reshape::-3",      (2, 3, 4, 5), s -> mx.reshape(s, -3, -3),                (6, 20)),
    ("SymbolicNode::reshape::-4",      (2, 3, 4),    s -> mx.reshape(s, 0, 0, -4, 2, 2),        (2, 3, 2, 2)),
  )
  for (msg, dims, build, expected) in shape_cases
    @info(msg)

    A = mx.zeros(dims...)
    x = mx.Variable(:x)
    out = forward_output(build(x), A)

    @test size(out) == expected
  end
end
262
# Multiply a 100×2 all-ones operand by a 2×200 all-ones operand with `mx.dot`.
# Every entry of the product is therefore 2.
function test_dot()
  @info("SymbolicNode::dot")
  lhs, rhs = mx.Variable(:x), mx.Variable(:y)
  product_sym = mx.dot(lhs, rhs)

  inputs = Dict(:x => mx.ones((100, 2)), :y => mx.ones((2, 200)))
  executor = mx.bind(product_sym, context = mx.cpu(), args = inputs)
  mx.forward(executor)

  result = copy(executor.outputs[1])
  @test size(result) == (100, 200)
  @test result ≈ fill(2.0, 100, 200)
end
276
# Printing a symbolic variable must produce non-empty text.
function test_print()
  @info("SymbolicNode::print")
  text = sprint(print, mx.Variable(:x))
  @test !isempty(text)
end
283
# Miscellaneous regression checks.
function test_misc()
  @info("SymbolicNode::Miscellaneous")
  # Test for #189. It uses variables created from `String` names; confirm the
  # exact scope of the issue upstream. The construction must succeed and yield
  # a symbol.
  a = mx.Variable("a")
  b = mx.Variable("b")
  symb = mx.ElementWiseSum(a, b)
  @test symb isa mx.SymbolicNode
end
291
# Element-wise addition on symbols. Each closure is applied both to the
# symbolic input (then executed) and to the plain Julia array, and the
# outputs are compared.
function test_add()
  @info("SymbolicNode::elementwise add")

  # Scalar combined with one symbol, on either side.
  let x = mx.Variable(:x), A = Float32[1 2; 3 4]
    for op in (v -> v .+ 42, v -> 42 .+ v, v -> -1 .+ v .+ 42)
      y = exec(op(x); :x => A)[]
      @test size(y) == size(A)
      @test copy(y) == op(A)
    end
  end

  # Two symbols, in both operand orders.
  let A = Float32[1 2; 3 4], B = Float32[2 4; 6 8]
    x = mx.Variable(:x)
    y = mx.Variable(:y)

    for op in ((u, v) -> u .+ v, (u, v) -> v .+ u)
      z = exec(op(x, y); :x => A, :y => B)[]
      @test size(z) == size(A)
      @test copy(z) == op(A, B)
    end
  end
end  # function test_add
330
# Element-wise subtraction and unary negation on symbols. Each closure is
# applied to the symbol (then executed) and to the reference array.
function test_minus()
  @info("SymbolicNode::elementwise minus")

  # Scalar combined with one symbol, plus unary minus.
  let x = mx.Variable(:x), A = Float32[1 2; 3 4]
    for op in (v -> v .- 42, v -> 42 .- v, v -> -1 .- v .- 42, v -> -v)
      y = exec(op(x); :x => A)[]
      @test size(y) == size(A)
      @test copy(y) == op(A)
    end
  end

  # Two symbols, in both operand orders.
  let A = Float32[1 2; 3 4], B = Float32[2 4; 6 8]
    x = mx.Variable(:x)
    y = mx.Variable(:y)

    for op in ((u, v) -> u .- v, (u, v) -> v .- u)
      z = exec(op(x, y); :x => A, :y => B)[]
      @test size(z) == size(A)
      @test copy(z) == op(A, B)
    end
  end
end  # function test_minus
374
# Element-wise multiplication on symbols. Each closure is applied to the
# symbol (then executed) and to the reference array.
function test_mul()
  @info("SymbolicNode::elementwise mul")

  # Scalar combined with one symbol, on either side.
  let x = mx.Variable(:x), A = Float32[1 2; 3 4]
    for op in (v -> v .* 42, v -> 42 .* v, v -> -1 .* v .* 42)
      y = exec(op(x); :x => A)[]
      @test size(y) == size(A)
      @test copy(y) == op(A)
    end
  end

  # Two symbols, in both operand orders.
  let A = Float32[1 2; 3 4], B = Float32[2 4; 6 8]
    x = mx.Variable(:x)
    y = mx.Variable(:y)

    for op in ((u, v) -> u .* v, (u, v) -> v .* u)
      z = exec(op(x, y); :x => A, :y => B)[]
      @test size(z) == size(A)
      @test copy(z) == op(A, B)
    end
  end
end  # function test_mul
413
# Element-wise division on symbols. Each closure is applied to the symbol
# (then executed) and to the reference array. Division results are compared
# approximately.
function test_div()
  @info("SymbolicNode::elementwise div")

  # Scalar combined with one symbol, on either side.
  let x = mx.Variable(:x), A = Float32[1 2; 3 4]
    for op in (v -> v ./ 42, v -> 42 ./ v, v -> -1 ./ v ./ 42)
      y = exec(op(x); :x => A)[]
      @test size(y) == size(A)
      @test copy(y) ≈ op(A)
    end
  end

  # Two symbols, in both operand orders.
  let A = Float32[1 2; 3 4], B = Float32[2 4; 6 8]
    x = mx.Variable(:x)
    y = mx.Variable(:y)

    for op in ((u, v) -> u ./ v, (u, v) -> v ./ u)
      z = exec(op(x, y); :x => A, :y => B)[]
      @test size(z) == size(A)
      @test copy(z) ≈ op(A, B)
    end
  end
end  # function test_div
452
# Element-wise power on symbols, including the irrational bases/exponents ℯ
# and π. Each closure is applied to the symbol (then executed) and to the
# reference array. Literal exponents stay inside the closures, so the
# expressions lower exactly as if written inline.
function test_power()
  @info("SymbolicNode::elementwise power")

  # Scalar exponent and scalar base.
  let x = mx.Variable(:x), A = Float32[1 2; 3 4]
    for op in (v -> v .^ 42, v -> 42 .^ v)
      y = exec(op(x); :x => A)[]
      @test size(y) == size(A)
      @test copy(y) ≈ op(A)
    end
  end

  # Two symbols, in both operand orders.
  let A = Float32[1 2; 3 4], B = Float32[2 4; 6 8]
    x = mx.Variable(:x)
    y = mx.Variable(:y)

    for op in ((u, v) -> u .^ v, (u, v) -> v .^ u)
      z = exec(op(x, y); :x => A, :y => B)[]
      @test size(z) == size(A)
      @test copy(z) ≈ op(A, B)
    end
  end

  @info("SymbolicNode::power::e .^ x::x .^ e")
  # ℯ raised to an all-zero input gives all ones.
  let x = mx.Variable(:x), A = [0 0 0; 0 0 0]
    y = exec(ℯ .^ x; :x => A)[]
    @test copy(y) ≈ fill(1, size(A))
  end

  let x = mx.Variable(:x), A = Float32[1 2; 3 4]
    for op in (v -> ℯ .^ v, v -> v .^ ℯ)
      z = exec(op(x); :x => A)[]
      @test copy(z) ≈ op(A)
    end
  end

  @info("SymbolicNode::power::π .^ x::x .^ π")
  let x = mx.Variable(:x), A = Float32[1 2; 3 4]
    for op in (v -> π .^ v, v -> v .^ π)
      z = exec(op(x); :x => A)[]
      @test copy(z) ≈ op(A)
    end
  end
end  # function test_power
517
# Calling `get_name` on the result of `get_internals` must not throw. The
# returned name is expected to contain a pointer description ("Ptr").
function test_get_name()
  @info("SymbolicNode::get_name::with get_internals")
  internals = mx.get_internals(mlp2())
  label = mx.get_name(internals)
  @test occursin("Ptr", label)
end  # function test_get_name
523
# Check the `@mx.var` convenience macro for one and for several names.
function test_var()
  @info("SymbolicNode::var")
  x = @mx.var x
  @test x isa mx.SymbolicNode

  # Each use of the macro must create a distinct node, even for the same name.
  x′ = @mx.var x
  @test x.handle != x′.handle

  # Several names yield one node per name.
  x, y, z = @mx.var x y z
  for node in (x, y, z)
    @test node isa mx.SymbolicNode
  end
end  # test_var
537
538
539################################################################################
540# Run tests
541################################################################################
# Run every test function defined in this module inside one top-level test set.
@testset "SymbolicNode Test" begin
  test_basic()
  test_chain()
  test_internal()
  test_get_children()
  test_compose()
  test_infer_shape()
  test_infer_shape_error()
  test_saveload()
  test_attrs()
  test_functions()
  test_reshape()
  test_dot()
  test_print()
  test_misc()
  test_add()
  test_minus()
  test_mul()
  test_div()
  test_power()
  test_get_name()
  test_var()
end
564
565end
566