# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

module TestAutoGrad

using MXNet
using Test

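# Helper: records the closure `f`, checks its forward value against `y`,
# verifies the attached gradient buffer starts out zeroed, then runs the
# backward pass and checks the accumulated gradient against `∇`.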
function checkgradient(f, x, y, ∇)
  ∇x = mx.attach_grad!(x)
  y′ = mx.record(f)
  @test copy(y′) ≈ y
  @test copy(∇x) |> sum == 0
  mx.backward!(y′)
  @test copy(mx.getgrad(x)) ≈ ∇
end  # function checkgradient


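# `getgrad` returns `nothing` for an NDArray without an attached gradient;
# after `attach_grad!` it returns the zero-initialized gradient buffer,
# which shares the array's eltype.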
function test_getgrad()
  @info("AutoGrad::getgrad")

  @info("AutoGrad::getgrad::unattached")
  @test nothing == mx.getgrad(mx.zeros(10))

  @info("AutoGrad::getgrad::attached")
  x = mx.NDArray([1 2; 3 4])
  grad = mx.attach_grad!(x)
  @test eltype(grad) ≡ Int
  @test copy(grad) == [0 0; 0 0]

  grad[:] = 42
  @test copy(mx.getgrad(x)) == [42 42; 42 42]
end


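# `mark_variables!` attaches caller-provided gradient buffers (with the
# given grad_reqs) to existing arrays; writes to a buffer become visible
# through `getgrad`. Mismatched argument lengths or unknown grad_req
# symbols raise an ArgumentError.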
function test_mark_variables!()
  @info("AutoGrad::mark_variables!")
  x = mx.zeros(4)
  ẋ = mx.zeros(4)
  y = mx.zeros(4)
  ẏ = mx.zeros(4)
  mx.mark_variables!([x, y], [ẋ, ẏ], [:nop, :nop])
  ẋ[:] = 42
  ẏ[:] = 24

  @test copy(mx.getgrad(x)) == [42, 42, 42, 42]
  @test copy(mx.getgrad(y)) == [24, 24, 24, 24]

  @info("AutoGrad::mark_variables!::invalid grad_reqs")
  x = mx.zeros(4)
  y = mx.zeros(4)
  @test_throws ArgumentError mx.mark_variables!(x, y, :magic)
  @test_throws ArgumentError mx.mark_variables!([x], [y], [:magic])

  @info("AutoGrad::mark_variables!::args length mismatch")
  x = mx.zeros(4)
  y = mx.zeros(4)
  z = mx.zeros(4)
  @test_throws ArgumentError mx.mark_variables!([x], [y, z])
  @test_throws ArgumentError mx.mark_variables!([x], [y], [:write, :nop])
end


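# `record` evaluates a closure with gradient recording (and, by default,
# train mode) enabled; `backward!` then propagates gradients into attached
# buffers, and `symbol` recovers a SymbolicNode from a recorded output.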
function test_record()
  let x = mx.NDArray([1 2; 3 4])
    @info("AutoGrad::record::backward!")

    y = [1 4; 9 16]
    ∇ = [2 4; 6 8]  # gradient is 2x
    checkgradient(x, y, ∇) do
      mx.square(x)
    end
  end

  let x = mx.NDArray([1 2; 3 4])
    @info("AutoGrad::record::symbol")

    mx.attach_grad!(x)
    y = mx.record() do
      mx.square(x)
    end

    @test copy(y) == [1 4; 9 16]

    @test isa(mx.symbol(y), mx.SymbolicNode)
  end

  let x = mx.NDArray([1 2; 3 4])
    @info("AutoGrad::record::backward!(retain_graph=true)")

    mx.attach_grad!(x)
    y = mx.record() do
      mx.square(x)
    end

    @test copy(y) == [1 4; 9 16]

    mx.backward!(y, retain_graph=true)
    # gradient is 2x
    @test copy(mx.getgrad(x)) == [2 4; 6 8]

    @test isa(mx.symbol(y), mx.SymbolicNode)
  end

  mx._record(nothing, nothing) do  # edge case: should not raise an error
    @test true
  end
end  # function test_record


function test_is_recording()
  @info("AutoGrad::is_recording")
  mx.record() do
    @test mx.is_recording()
  end
end  # function test_is_recording


function test_is_training()
  @info("AutoGrad::is_training")
  mx.record() do
    @test mx.is_training()
  end

  mx.record(false) do
    @test !mx.is_training()
  end
end  # function test_is_training


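# `pause` temporarily turns recording off inside a `record` block: the
# paused computation runs normally but contributes nothing to the gradient.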
function test_pause()
  @info("AutoGrad::pause")
  let x = mx.NDArray([1 2; 3 4])
    ∇ = mx.attach_grad!(x)
    y = mx.record() do
      y = mx.square(x)
      mx.pause() do
        z = mx.square(y)
        @test copy(z) == [1 16; 81 256]
      end
      y
    end

    @test copy(y) == [1 4; 9 16]

    mx.backward!(y)
    @test copy(∇) == [2 4; 6 8]
  end
end  # function test_pause


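# In train mode Dropout is active; with p = 1 every element is dropped and
# the train-time rescaling by 1/(1 - p) divides by zero, which is
# presumably why the output is all NaN.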
function test_train_mode()
  @info("AutoGrad::train_mode")
  let x = mx.NDArray(Float32[1 2; 3 4])
    y = mx.train_mode() do
      mx.Dropout(x, p = 1)
    end

    @test all(isnan.(copy(y)))
  end
end  # function test_train_mode


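# In predict mode Dropout is the identity, so the input passes through
# unchanged.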
function test_predict_mode()
  @info("AutoGrad::predict_mode")
  let x = mx.NDArray(Float32[1 2; 3 4])
    y = mx.predict_mode() do
      mx.Dropout(x, p = 1)
    end

    @test copy(y) ≈ Float32[1 2; 3 4]
  end
end  # function test_predict_mode


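# `backward!` also accepts head gradients, i.e. the incoming gradient of
# each output; `nothing` stands for an implicit all-ones head gradient
# (as the expected values below confirm).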
function test_backward!()
  @info("AutoGrad::backward!::with head_grad")
  let x = mx.NDArray(Float32[1 2; 3 4]), A = Float32[.2 .4; 0 .1]
    ∇ = mx.attach_grad!(x)
    y = mx.record() do
      mx.square(x)
    end
    mx.backward!(y, mx.NDArray(A))
    @test copy(∇) ≈ [2 4; 6 8] .* A
  end

  @info("AutoGrad::backward!::with head_grads")
  let x = mx.NDArray(Float32[1 2; 3 4])
    ∇ = mx.attach_grad!(x)
    mx.record() do
      x′ = mx.square(x)
      y = mx.square(x)
      z = mx.square(x) .+ 42
      mx.backward!([x′, y, z], [nothing,
                                mx.NDArray(Float32[.01 .01; 1 1]),
                                mx.NDArray(Float32[1 1; .1 .1])])
    end
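    # d(x²)/dx = 2x; each head contributes (head grad) .* 2x, with
    # `nothing` acting as ones:
    #   1 .* 2x .+ [.01 .01; 1 1] .* 2x .+ [1 1; .1 .1] .* 2x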
    ans = [4.02 8.04
           12.6 16.8]
    @test copy(∇) ≈ ans
  end

  @info("AutoGrad::backward!::ArgumentError")
  let x = mx.NDArray([42])
    @test_throws ArgumentError mx.backward!([x], [24])
  end
end  # function test_backward!


function test_symbol()
  @info("AutoGrad::symbol")

  let x = mx.zeros(4)
    mx.attach_grad!(x)
    @test isa(mx.symbol(x), mx.SymbolicNode)
  end
end


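# Elementwise arithmetic ops: each case checks the forward value and the
# gradient of a simple expression in x via `checkgradient`.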
function test_add()
  @info("AutoGrad::add")

  @info("AutoGrad::add::x")
  let x = mx.NDArray([1 2; 3 4])
    y = [1 2; 3 4]
    ∇ = [1 1; 1 1]  # gradient is 1
    checkgradient(x, y, ∇) do
      x
    end
  end

  @info("AutoGrad::add::+x")
  let x = mx.NDArray([1 2; 3 4])
    y = [1 2; 3 4]
    ∇ = [1 1; 1 1]  # gradient is 1
    checkgradient(x, y, ∇) do
      +x
    end
  end

  @info("AutoGrad::add::x .+ 42")
  let x = mx.NDArray([1 2; 3 4])
    y = [43 44; 45 46]
    ∇ = [1 1; 1 1]  # gradient is 1
    checkgradient(x, y, ∇) do
      x .+ 42
    end
  end

  @info("AutoGrad::add::42 .+ x")
  let x = mx.NDArray([1 2; 3 4])
    y = [43 44; 45 46]
    ∇ = [1 1; 1 1]
    checkgradient(x, y, ∇) do
      42 .+ x
    end
  end

  # TODO: @info("AutoGrad::add::x .+ y")
end  # function test_add


function test_sub()
  @info("AutoGrad::sub")

  @info("AutoGrad::sub::-x")
  let x = mx.NDArray([1 2; 3 4])
    y = [-1 -2; -3 -4]
    ∇ = [-1 -1; -1 -1]  # gradient is -1
    checkgradient(x, y, ∇) do
      -x
    end
  end

  @info("AutoGrad::sub::x .- 42")
  let x = mx.NDArray([1 2; 3 4])
    y = [-41 -40; -39 -38]
    ∇ = [1 1; 1 1]
    checkgradient(x, y, ∇) do
      x .- 42
    end
  end

  @info("AutoGrad::sub::42 .- x")
  let x = mx.NDArray([1 2; 3 4])
    y = [41 40; 39 38]
    ∇ = -[1 1; 1 1]
    checkgradient(x, y, ∇) do
      42 .- x
    end
  end

  # TODO: @info("AutoGrad::sub::x .- y")
end  # function test_sub


function test_mul()
  @info("AutoGrad::mul")

  @info("AutoGrad::mul::2x .* x")
  let x = mx.NDArray([1 2; 3 4])
    y = [2 8; 18 32]
    ∇ = [4 8; 12 16]  # 4x
    checkgradient(x, y, ∇) do
      2x .* x
    end
  end

  @info("AutoGrad::mul::x * 2 .* x")
  let x = mx.NDArray([1 2; 3 4])
    y = [2 8; 18 32]
    ∇ = [4 8; 12 16]  # 4x
    checkgradient(x, y, ∇) do
      x * 2 .* x
    end
  end
end


function test_div()
  @info("AutoGrad::div")

  @info("AutoGrad::div::x ./ 2")
  let x = mx.NDArray(Float32[1 2; 3 4])
    y = Float32[.5 1; 1.5 2]
    ∇ = [.5 .5; .5 .5]
    checkgradient(x, y, ∇) do
      x ./ 2
    end
  end

  @info("AutoGrad::rdiv::2 ./ x")
  let A = Float32[1 2; 3 4], x = mx.NDArray(A)
    y = 2 ./ A
    ∇ = @. -2 / A^2  # -2 / x²
    checkgradient(x, y, ∇) do
      2 ./ x
    end
  end
end  # function test_div


function test_power()
  @info("AutoGrad::power")

  @info("AutoGrad::power::x.^3")
  let A = Float32[1 2; 3 4]
    x = mx.NDArray(A)
    y = A.^3
    ∇ = 3(A.^2)
    checkgradient(x, y, ∇) do
      x.^3
    end
  end

  @info("AutoGrad::power::x.^.5")
  let A = Float32[1 2; 3 4]
    x = mx.NDArray(A)
    y = A.^.5
    ∇ = .5(A.^-.5)
    checkgradient(x, y, ∇) do
      x.^.5
    end
  end
end


@testset "AutoGrad Test" begin
  test_getgrad()
  test_mark_variables!()
  test_record()
  test_is_recording()
  test_is_training()
  test_pause()
  test_train_mode()
  test_predict_mode()
  test_backward!()
  test_symbol()
  test_add()
  test_sub()
  test_mul()
  test_div()
  test_power()
end


end  # module TestAutoGrad