1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18
19import pytest
20import mxnet as mx
21from gluonnlp.model import BiLMEncoder
22
23
@pytest.mark.parametrize('hybridize', [False, True])
def test_bilm_encoder_output_shape_lstm(hybridize):
    """A plain-LSTM BiLMEncoder must emit (num_layers, seq_len, batch, 2*hidden)."""
    layers, steps, units, in_dim, batch = 2, 7, 100, 100, 2

    encoder = BiLMEncoder(mode='lstm',
                          num_layers=layers,
                          input_size=in_dim,
                          hidden_size=units,
                          dropout=0.1,
                          skip_connection=False)

    out = run_bi_lm_encoding(encoder, batch, in_dim, steps, hybridize)
    # Forward and backward directions are concatenated on the last axis,
    # hence the 2 * hidden_size trailing dimension.
    assert out.shape == (layers, steps, batch, 2 * units), out.shape
41
42
@pytest.mark.parametrize('hybridize', [False, True])
def test_bilm_encoder_output_shape_lstmpc(hybridize):
    """A projected-cell BiLMEncoder must emit (num_layers, seq_len, batch, 2*proj)."""
    layers, steps, units, in_dim, batch = 2, 7, 100, 100, 2
    proj = 15

    encoder = BiLMEncoder(mode='lstmpc',
                          num_layers=layers,
                          input_size=in_dim,
                          hidden_size=units,
                          dropout=0.1,
                          skip_connection=False,
                          proj_size=proj)

    out = run_bi_lm_encoding(encoder, batch, in_dim, steps, hybridize)
    # With projection cells the per-direction output width is proj_size,
    # not hidden_size, so the concatenated last axis is 2 * proj_size.
    assert out.shape == (layers, steps, batch, 2 * proj), out.shape
62
63
def run_bi_lm_encoding(encoder, batch_size, input_size, seq_len, hybridize):
    """Initialize ``encoder``, run one forward pass on random data, return the output.

    Parameters
    ----------
    encoder : BiLMEncoder
        An uninitialized encoder instance.
    batch_size : int
        Number of sequences in the batch.
    input_size : int
        Feature dimension of each time step.
    seq_len : int
        Number of time steps per sequence.
    hybridize : bool
        If True, hybridize the encoder before the forward pass.

    Returns
    -------
    The encoder output (first element of the encoder's return tuple).
    """
    encoder.initialize()

    if hybridize:
        encoder.hybridize()

    # Inputs are time-major: (seq_len, batch_size, input_size).
    inputs = mx.random.uniform(shape=(seq_len, batch_size, input_size))
    # Random boolean validity mask. The original comparison was `> 1`, which a
    # uniform(-1, 1) sample can never satisfy, so the mask was constantly
    # all-False (every position masked) and the masking path was never truly
    # exercised; `> 0` yields a genuinely random ~50% mask.
    inputs_mask = mx.random.uniform(-1, 1, shape=(batch_size, seq_len)) > 0

    state = encoder.begin_state(batch_size=batch_size, func=mx.ndarray.zeros)
    output, _ = encoder(inputs, state, inputs_mask)
    return output
76