# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.


import pytest
import mxnet as mx
from gluonnlp.model import BiLMEncoder


@pytest.mark.parametrize('hybridize', [False, True])
def test_bilm_encoder_output_shape_lstm(hybridize):
    """Check the output shape of a ``BiLMEncoder`` built with ``mode='lstm'``.

    The encoder is run once through ``run_bi_lm_encoding`` and the result is
    expected to have shape
    ``(num_layers, seq_len, batch_size, 2 * hidden_size)``.

    Parameters
    ----------
    hybridize : bool
        Forwarded to ``run_bi_lm_encoding``; when true the encoder is
        hybridized before the forward pass.
    """
    layer_count = 2
    time_steps = 7
    hidden_units = 100
    feature_dim = 100
    batch = 2

    encoder = BiLMEncoder(
        mode='lstm',
        num_layers=layer_count,
        input_size=feature_dim,
        hidden_size=hidden_units,
        dropout=0.1,
        skip_connection=False,
    )

    expected_shape = (layer_count, time_steps, batch, 2 * hidden_units)
    output = run_bi_lm_encoding(encoder, batch, feature_dim, time_steps, hybridize)
    # The actual shape is the assertion message, so a failure shows what came back.
    assert output.shape == expected_shape, output.shape


@pytest.mark.parametrize('hybridize', [False, True])
def test_bilm_encoder_output_shape_lstmpc(hybridize):
    """Check the output shape of a ``BiLMEncoder`` built with ``mode='lstmpc'``.

    Same setup as the plain LSTM test, except that a ``proj_size`` is passed
    to the encoder and the last output dimension is expected to be
    ``2 * proj_size`` instead of ``2 * hidden_size``.

    Parameters
    ----------
    hybridize : bool
        Forwarded to ``run_bi_lm_encoding``; when true the encoder is
        hybridized before the forward pass.
    """
    layer_count = 2
    time_steps = 7
    hidden_units = 100
    feature_dim = 100
    batch = 2
    projection_dim = 15

    encoder = BiLMEncoder(
        mode='lstmpc',
        num_layers=layer_count,
        input_size=feature_dim,
        hidden_size=hidden_units,
        dropout=0.1,
        skip_connection=False,
        proj_size=projection_dim,
    )

    expected_shape = (layer_count, time_steps, batch, 2 * projection_dim)
    output = run_bi_lm_encoding(encoder, batch, feature_dim, time_steps, hybridize)
    # The actual shape is the assertion message, so a failure shows what came back.
    assert output.shape == expected_shape, output.shape


def run_bi_lm_encoding(encoder, batch_size, input_size, seq_len, hybridize):
    """Initialize ``encoder``, run one forward pass and return its first output.

    Parameters
    ----------
    encoder
        Encoder under test; it must provide ``initialize``, ``hybridize``,
        ``begin_state`` and be callable as ``encoder(inputs, state, mask)``.
    batch_size : int
        Batch dimension of the generated inputs, mask and begin state.
    input_size : int
        Size of the last dimension of the generated inputs.
    seq_len : int
        Sequence dimension of the generated inputs and mask.
    hybridize : bool
        When true, ``encoder.hybridize()`` is called before the forward pass.

    Returns
    -------
    The first element of the pair returned by the encoder call; the second
    element is discarded.
    """
    encoder.initialize()
    if hybridize:
        encoder.hybridize()

    # Inputs are laid out as (seq_len, batch_size, input_size); the mask is
    # laid out as (batch_size, seq_len).
    data = mx.random.uniform(shape=(seq_len, batch_size, input_size))
    # NOTE(review): the samples are drawn with bounds (-1, 1), so `> 1` looks
    # like it can never hold and the mask would be all zeros -- confirm
    # whether `> 0` was intended.
    mask = mx.random.uniform(-1, 1, shape=(batch_size, seq_len)) > 1

    initial_state = encoder.begin_state(batch_size=batch_size, func=mx.ndarray.zeros)
    encoded, _ = encoder(data, initial_state, mask)
    return encoded