# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# Unit tests for MXNet.jl `SymbolicNode`: graph construction, introspection,
# shape inference, serialization, attributes, reshape semantics and
# element-wise arithmetic. Fixtures `mlp2`, `mlpchain` and `exec` are provided
# by the enclosing test runner (`Main`).
module TestSymbolicNode

using MXNet
using Test

using ..Main: mlp2, mlpchain, exec

################################################################################
# Test Implementations
################################################################################

# Argument / output / auxiliary-state names of the `mlp2` fixture.
function test_basic()
  @info("SymbolicNode::basic")

  model = mlp2()
  @test mx.list_arguments(model) == [:data,:fc1_weight,:fc1_bias,:fc2_weight,:fc2_bias]
  @test mx.list_outputs(model) == [:fc2_output]
  @test mx.list_auxiliary_states(model) == Symbol[]
end

# `@mx.chain` composition: the `mlpchain` fixture, plus an `MLP` with a custom
# name prefix feeding a `LinearRegressionOutput`.
function test_chain()
  @info("SymbolicNode::chain")

  model = mlpchain()
  @test mx.list_arguments(model) == [:data,:fc1_weight,:fc1_bias,:fc2_weight,:fc2_bias]
  @test mx.list_outputs(model) == [:fc2_output]
  @test mx.list_auxiliary_states(model) == Symbol[]

  let layerconfig = [20, 10, 6]
    model = @mx.chain mx.Variable(:data) =>
      mx.MLP(layerconfig, prefix=:magic_) =>
      mx.LinearRegressionOutput(mx.Variable(:label))

    @test mx.list_arguments(model) == [
      :data,
      :magic_fc1_weight, :magic_fc1_bias,
      :magic_fc2_weight, :magic_fc2_bias,
      :magic_fc3_weight, :magic_fc3_bias,
      :label]
  end
end

# An internal node fetched via `get_internals` exposes the same arguments as the
# node it was originally built from.
function test_internal()
  @info("SymbolicNode::internal")

  data  = mx.Variable(:data)
  oldfc = mx.FullyConnected(data, name=:fc1, num_hidden=10)
  net1  = mx.FullyConnected(oldfc, name=:fc2, num_hidden=100)

  @test mx.list_arguments(net1) == [:data,:fc1_weight,:fc1_bias,:fc2_weight,:fc2_bias]

  internal = mx.get_internals(net1)
  fc1      = internal[:fc1_output]
  @test mx.list_arguments(fc1) == mx.list_arguments(oldfc)
end

# `get_children` returns the inputs of an op node, and `nothing` for a leaf.
function test_get_children()
  @info("SymbolicNode::get_children")

  let x = mx.Variable(:x), y = mx.Variable(:y)
    z = x + y
    @test length(mx.list_outputs(z)) == 1
    @test length(mx.list_outputs(mx.get_children(z))) == 2
    @test mx.list_outputs(mx.get_children(z)) == [:x, :y]
  end

  @info("SymbolicNode::get_children::on leaf")
  let x = mx.Variable(:x)
    # `===` is the idiomatic (and non-overloadable) check for `nothing`
    @test mx.get_children(x) === nothing
  end
end  # test_get_children


# Composing a graph built on an unbound placeholder (`fc3_data`) with another
# graph, then grouping the outputs of both.
function test_compose()
  @info("SymbolicNode::compose")

  data = mx.Variable(:data)
  net1 = mx.FullyConnected(data, name=:fc1, num_hidden=10)
  net1 = mx.FullyConnected(net1, name=:fc2, num_hidden=100)

  net2 = mx.FullyConnected(mx.SymbolicNode, name=:fc3, num_hidden=10)
  net2 = mx.Activation(net2, act_type=:relu)
  net2 = mx.FullyConnected(net2, name=:fc4, num_hidden=20)

  composed  = net2(fc3_data=net1, name=:composed)
  multi_out = mx.Group(composed, net1)
  @test mx.list_outputs(multi_out) == [:composed_output, :fc2_output]
end

# Shape inference on `mlp2` given only the data shape.
function test_infer_shape()
  @info("SymbolicNode::infer_shape::mlp2")

  model = mlp2()
  data_shape = (100, 100)
  arg_shapes, out_shapes, aux_shapes = mx.infer_shape(model, data=data_shape)
  arg_shape_dict = Dict{Symbol,Tuple}(zip(mx.list_arguments(model), arg_shapes))
  @test arg_shape_dict == Dict{Symbol,Tuple}(:fc2_bias => (10,),:fc2_weight => (1000,10),
                                             :fc1_bias => (1000,), :fc1_weight => (100, 1000),
                                             :data => data_shape)
  @test length(out_shapes) == 1
  @test out_shapes[1] == (10, 100)
end

# Inconsistent shapes must surface as an `mx.MXError`.
function test_infer_shape_error()
  @info("SymbolicNode::infer_shape::throws")

  model = mlp2()
  weight_shape = (100, 1)
  data_shape = (100, 100)
  @test_throws mx.MXError mx.infer_shape(model, data=data_shape, fc1_weight=weight_shape)
end

# Round-trip `mlp2` through `mx.save` / `mx.load` and compare JSON encodings.
function test_saveload()
  @info("SymbolicNode::saveload::mlp2")

  model = mlp2()
  fname = tempname()
  try
    mx.save(fname, model)
    model2 = mx.load(fname, mx.SymbolicNode)
    @test mx.to_json(model) == mx.to_json(model2)
  finally
    # clean up even if save/load throws; `force` tolerates a never-created file
    rm(fname, force = true)
  end
end

# Getting/setting string attributes, attrs passed at construction time, and
# rejection of non-string attribute values.
function test_attrs()
  @info("SymbolicNode::Attributes")

  data = mx.Variable(:data)

  @test mx.get_name(data) == :data
  result = mx.get_attr(data, :test)
  @test ismissing(result)
  mx.set_attr(data, :test, "1.0")
  result = mx.get_attr(data, :test)
  @test !ismissing(result)
  @test result == "1.0"

  data2 = mx.Variable(:data2, attrs = Dict(:test => "hallo!"))
  @test mx.get_attr(data2, :test) == "hallo!"

  conv = mx.Convolution(data2, kernel = (1,1), num_filter = 1)
  @test ismissing(mx.get_attr(conv, :b))
  @test mx.get_name(conv) isa Symbol

  @test_throws MethodError mx.Variable(:data3, attrs = Dict(:test => "1.0", :test2 => 1.0))
  @test_throws MethodError mx.Convolution(data2, kernel = (1,1), num_filter = 1, attrs = Dict(:test => "1.0", :test2 => 1.0))
end

# Operator functions applied to a `SymbolicNode` yield a `SymbolicNode`.
function test_functions()
  @info("SymbolicNode::Functions")
  data = mx.Variable(:data)
  # previously a bare comparison whose result was discarded, so it could never fail
  @test mx.sum(data) isa mx.SymbolicNode
end

# Bind `mx.reshape(Variable(:x), dims...; kwargs...)` to input `A` on the CPU,
# run a forward pass and return the first output NDArray.
function _reshape_and_forward(A, dims...; kwargs...)
  x = mx.Variable(:x)
  y = mx.reshape(x, dims...; kwargs...)
  e = mx.bind(y, mx.cpu(), Dict(:x => A))
  mx.forward(e)
  return e.outputs[1]
end

# `mx.reshape` with explicit dims (varargs and tuple forms) and with the
# special codes 0, -1, -2, -3, -4 and `reverse = true`.
function test_reshape()
  @info("SymbolicNode::reshape(sym, dim...)")

  out = _reshape_and_forward(mx.NDArray(collect(1:24)), 2, 3, 4)
  @test size(out) == (2, 3, 4)
  @test copy(out) == reshape(1:24, 2, 3, 4)

  @info("SymbolicNode::reshape(sym, dim)")

  out = _reshape_and_forward(mx.NDArray(collect(1:24)), (2, 3, 4))
  @test size(out) == (2, 3, 4)
  @test copy(out) == reshape(1:24, 2, 3, 4)

  @info("SymbolicNode::reshape::reverse")

  out = _reshape_and_forward(mx.zeros(10, 5, 4), -1, 0, reverse = true)
  @test size(out) == (50, 4)

  @info("SymbolicNode::reshape::0")

  out = _reshape_and_forward(mx.zeros(2, 3, 4), 4, 0, 2)
  @test size(out) == (4, 3, 2)

  @info("SymbolicNode::reshape::-1")

  out = _reshape_and_forward(mx.zeros(2, 3, 4), 6, 1, -1)
  @test size(out) == (6, 1, 4)

  @info("SymbolicNode::reshape::-2")

  out = _reshape_and_forward(mx.zeros(2, 3, 4, 2), 3, 2, -2)
  @test size(out) == (3, 2, 4, 2)

  @info("SymbolicNode::reshape::-3")

  out = _reshape_and_forward(mx.zeros(2, 3, 4, 5), -3, -3)
  @test size(out) == (6, 20)

  @info("SymbolicNode::reshape::-4")

  out = _reshape_and_forward(mx.zeros(2, 3, 4), 0, 0, -4, 2, 2)
  @test size(out) == (2, 3, 2, 2)
end

# Symbolic matrix product of all-ones (100x2) and (2x200) inputs.
function test_dot()
  @info("SymbolicNode::dot")
  x = mx.Variable(:x)
  y = mx.Variable(:y)
  z = mx.dot(x, y)
  z_exec = mx.bind(z, context = mx.cpu(),
                   args = Dict(:x => mx.ones((100, 2)), :y => mx.ones((2, 200))))
  mx.forward(z_exec)

  ret = copy(z_exec.outputs[1])
  @test size(ret) == (100, 200)
  @test ret ≈ 2*ones(100, 200)
end

# `print` on a `SymbolicNode` writes a non-empty representation.
function test_print()
  @info("SymbolicNode::print")
  io = IOBuffer()
  print(io, mx.Variable(:x))
  @test !isempty(String(take!(io)))
end

function test_misc()
  @info("SymbolicNode::Miscellaneous")
  # Test for #189: variables created from String names must compose
  a = mx.Variable("a")
  b = mx.Variable("b")
  symb = mx.ElementWiseSum(a, b)
  @test symb isa mx.SymbolicNode
end

# Broadcasted `.+` between symbols and scalars, and between two symbols.
function test_add()
  @info("SymbolicNode::elementwise add")
  let x = mx.Variable(:x), A = Float32[1 2; 3 4]
    let y = exec(x .+ 42; :x => A)[]
      @test size(y) == size(A)
      @test copy(y) == A .+ 42
    end

    let y = exec(42 .+ x; :x => A)[]
      @test size(y) == size(A)
      @test copy(y) == 42 .+ A
    end

    let y = exec(-1 .+ x .+ 42; :x => A)[]
      @test size(y) == size(A)
      @test copy(y) == -1 .+ A .+ 42
    end
  end

  let A = Float32[1 2; 3 4], B = Float32[2 4; 6 8]
    x = mx.Variable(:x)
    y = mx.Variable(:y)

    let z = x .+ y
      z = exec(z; :x => A, :y => B)[]

      @test size(z) == size(A)
      @test copy(z) == A .+ B
    end

    let z = y .+ x
      z = exec(z; :x => A, :y => B)[]

      @test size(z) == size(A)
      @test copy(z) == B .+ A
    end
  end
end  # function test_add

# Broadcasted `.-` and unary negation.
function test_minus()
  @info("SymbolicNode::elementwise minus")
  let x = mx.Variable(:x), A = Float32[1 2; 3 4]
    let y = exec(x .- 42; :x => A)[]
      @test size(y) == size(A)
      @test copy(y) == A .- 42
    end

    let y = exec(42 .- x; :x => A)[]
      @test size(y) == size(A)
      @test copy(y) == 42 .- A
    end

    let y = exec(-1 .- x .- 42; :x => A)[]
      @test size(y) == size(A)
      @test copy(y) == -1 .- A .- 42
    end

    let y = exec(-x; :x => A)[]
      @test size(y) == size(A)
      @test copy(y) == -A
    end
  end

  let A = Float32[1 2; 3 4], B = Float32[2 4; 6 8]
    x = mx.Variable(:x)
    y = mx.Variable(:y)

    let z = x .- y
      z = exec(z; :x => A, :y => B)[]

      @test size(z) == size(A)
      @test copy(z) == A .- B
    end

    let z = y .- x
      z = exec(z; :x => A, :y => B)[]

      @test size(z) == size(A)
      @test copy(z) == B .- A
    end
  end
end  # function test_minus

# Broadcasted `.*`.
function test_mul()
  @info("SymbolicNode::elementwise mul")
  let x = mx.Variable(:x), A = Float32[1 2; 3 4]
    let y = exec(x .* 42; :x => A)[]
      @test size(y) == size(A)
      @test copy(y) == A .* 42
    end

    let y = exec(42 .* x; :x => A)[]
      @test size(y) == size(A)
      @test copy(y) == 42 .* A
    end

    let y = exec(-1 .* x .* 42; :x => A)[]
      @test size(y) == size(A)
      @test copy(y) == -1 .* A .* 42
    end
  end

  let A = Float32[1 2; 3 4], B = Float32[2 4; 6 8]
    x = mx.Variable(:x)
    y = mx.Variable(:y)

    let z = x .* y
      z = exec(z; :x => A, :y => B)[]

      @test size(z) == size(A)
      @test copy(z) == A .* B
    end

    let z = y .* x
      z = exec(z; :x => A, :y => B)[]

      @test size(z) == size(A)
      @test copy(z) == B .* A
    end
  end
end  # function test_mul

# Broadcasted `./`; compared approximately since results are floating point.
function test_div()
  @info("SymbolicNode::elementwise div")
  let x = mx.Variable(:x), A = Float32[1 2; 3 4]
    let y = exec(x ./ 42; :x => A)[]
      @test size(y) == size(A)
      @test copy(y) ≈ A ./ 42
    end

    let y = exec(42 ./ x; :x => A)[]
      @test size(y) == size(A)
      @test copy(y) ≈ 42 ./ A
    end

    let y = exec(-1 ./ x ./ 42; :x => A)[]
      @test size(y) == size(A)
      @test copy(y) ≈ -1 ./ A ./ 42
    end
  end

  let A = Float32[1 2; 3 4], B = Float32[2 4; 6 8]
    x = mx.Variable(:x)
    y = mx.Variable(:y)

    let z = x ./ y
      z = exec(z; :x => A, :y => B)[]

      @test size(z) == size(A)
      @test copy(z) ≈ A ./ B
    end

    let z = y ./ x
      z = exec(z; :x => A, :y => B)[]

      @test size(z) == size(A)
      @test copy(z) ≈ B ./ A
    end
  end
end  # function test_div

# Broadcasted `.^` with scalar and symbol operands, including the irrational
# constants `ℯ` and `π` on either side.
function test_power()
  @info("SymbolicNode::elementwise power")
  let x = mx.Variable(:x), A = Float32[1 2; 3 4]
    let y = exec(x .^ 42; :x => A)[]
      @test size(y) == size(A)
      @test copy(y) ≈ A .^ 42
    end

    let y = exec(42 .^ x; :x => A)[]
      @test size(y) == size(A)
      @test copy(y) ≈ 42 .^ A
    end
  end

  let A = Float32[1 2; 3 4], B = Float32[2 4; 6 8]
    x = mx.Variable(:x)
    y = mx.Variable(:y)

    let z = x .^ y
      z = exec(z; :x => A, :y => B)[]

      @test size(z) == size(A)
      @test copy(z) ≈ A .^ B
    end

    let z = y .^ x
      z = exec(z; :x => A, :y => B)[]

      @test size(z) == size(A)
      @test copy(z) ≈ B .^ A
    end
  end

  @info("SymbolicNode::power::e .^ x::x .^ e")
  let x = mx.Variable(:x), A = [0 0 0; 0 0 0]
    y = exec(ℯ .^ x; :x => A)[]
    @test copy(y) ≈ fill(1, size(A))
  end

  let x = mx.Variable(:x), A = Float32[1 2; 3 4]
    let y = ℯ .^ x
      z = exec(y; :x => A)[]
      @test copy(z) ≈ ℯ .^ A
    end

    let y = x .^ ℯ
      z = exec(y; :x => A)[]
      @test copy(z) ≈ A .^ ℯ
    end
  end

  @info("SymbolicNode::power::π .^ x::x .^ π")
  let x = mx.Variable(:x), A = Float32[1 2; 3 4]
    let y = π .^ x
      z = exec(y; :x => A)[]
      @test copy(z) ≈ π .^ A
    end

    let y = x .^ π
      z = exec(y; :x => A)[]
      @test copy(z) ≈ A .^ π
    end
  end
end  # function test_power

# `get_name` on the node returned by `get_internals` must not error; the
# assertion expects a name containing "Ptr".
function test_get_name()
  @info("SymbolicNode::get_name::with get_internals")
  name = mx.get_name(mx.get_internals(mlp2()))  # no error
  @test occursin("Ptr", name)
end  # function test_get_name

# `@mx.var` creates fresh `SymbolicNode`s (distinct handles per call) and
# supports declaring several variables at once.
function test_var()
  @info("SymbolicNode::var")
  x = @mx.var x
  @test x isa mx.SymbolicNode

  x′ = @mx.var x
  @test x.handle != x′.handle

  x, y, z = @mx.var x y z
  @test x isa mx.SymbolicNode
  @test y isa mx.SymbolicNode
  @test z isa mx.SymbolicNode
end  # test_var


################################################################################
# Run tests
################################################################################
@testset "SymbolicNode Test" begin
  test_basic()
  test_chain()
  test_internal()
  test_get_children()  # was defined but never run
  test_compose()
  test_infer_shape()
  test_infer_shape_error()
  test_saveload()
  test_attrs()
  test_functions()
  test_reshape()
  test_dot()
  test_print()
  test_misc()
  test_add()
  test_minus()
  test_mul()
  test_div()
  test_power()
  test_get_name()
  test_var()
end

end