1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18import pytest
19from mxnet.gluon import nn
20from gluonnlp import initializer
21
22
def test_truncnorm_string_alias_works():
    """The string alias ``'truncnorm'`` is accepted as a weight initializer.

    A ``Dense`` layer is built with ``weight_initializer='truncnorm'`` and then
    initialized; a ``RuntimeError`` raised by either step is reported as a test
    failure with a short, readable message.
    """
    failure_message = "Layer couldn't be initialized"
    try:
        dense = nn.Dense(in_units=1, units=1, prefix="test_layer",
                         weight_initializer='truncnorm')
        dense.initialize()
    except RuntimeError:
        pytest.fail(failure_message)
29
30
def test_truncnorm_all_values_inside_boundaries():
    """Every weight drawn by ``TruncNorm(mean, std)`` lies in ``[mean - 2*std, mean + 2*std]``.

    The bounds are expressed relative to ``mean`` so the check remains valid if
    the test parameters are changed to a non-zero mean.
    """
    mean = 0
    std = 0.01
    lower = mean - 2 * std
    upper = mean + 2 * std
    layer = nn.Dense(prefix="test_layer", in_units=1, units=1000)
    layer.initialize(init=initializer.TruncNorm(mean, std))
    # Fetch the weights once and reuse them for both comparisons.
    weights = layer.weight.data()
    # Element-wise comparisons give 0/1 masks; summing them counts out-of-range samples.
    num_outside = ((weights > upper).sum() + (weights < lower).sum()).asscalar()
    assert num_outside == 0, \
        '{} of {} samples fell outside [{}, {}]'.format(num_outside, weights.size, lower, upper)
38
39
def test_truncnorm_generates_values_with_defined_mean_and_std():
    """Weights drawn by ``TruncNorm`` should follow scipy's truncated normal.

    A Kolmogorov-Smirnov test compares the flattened weights against
    ``scipy.stats.truncnorm`` with location ``mean`` and scale ``std``,
    clipped at two scales on either side; the test fails when the p-value
    is not above 1e-4.
    """
    # Imported here rather than at module level (presumably to keep scipy
    # optional for the rest of the module -- confirm).
    from scipy import stats

    mean, std = 10, 5
    dense = nn.Dense(prefix="test_layer", in_units=1, units=100000)
    dense.initialize(init=initializer.TruncNorm(mean, std))
    flat_weights = dense.weight.data().reshape((-1, )).asnumpy()

    # scipy's truncnorm takes its clip points (a, b) in units of the scale,
    # so (-2, 2) corresponds to [mean - 2 * std, mean + 2 * std].
    lower_clip, upper_clip = -2, 2
    ks_result = stats.kstest(flat_weights, 'truncnorm',
                             args=(lower_clip, upper_clip, mean, std))
    assert ks_result.pvalue > 0.0001
51
52