1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17import numpy as np
18import nnvm
19import tvm
20from tvm.contrib import graph_runtime
21from nnvm.testing.config import ctx_list
22import keras
23
24# prevent keras from using up all gpu memory
25import tensorflow as tf
26from keras.backend.tensorflow_backend import set_session
# Create a TensorFlow session whose GPU memory use is bounded by
# per_process_gpu_memory_fraction=0.5, and install it as the session Keras
# uses, so the TVM runs in this process can still get GPU memory.
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.5
set_session(tf.Session(config=config))
30
31
def verify_keras_frontend(keras_model, need_transpose=True):
    """Check that NNVM/TVM reproduces a Keras model's predictions.

    Random inputs drawn uniformly from [-1, 1) are fed to
    ``keras_model.predict`` and to the graph produced by
    ``nnvm.frontend.from_keras``, compiled for every (target, ctx) pair
    yielded by ``ctx_list()``. Every output must match within
    ``rtol=atol=1e-5``, and TVM must produce exactly as many outputs as Keras.

    Parameters
    ----------
    keras_model : keras.models.Model
        Model under test. Any ``None`` input dimension (e.g. the batch axis)
        is replaced with 1 when generating inputs.
    need_transpose : bool
        When True, inputs are moved from channels-last to channels-first
        before being fed to TVM, and each TVM output is moved back to
        channels-last before comparison. Pass False for models whose tensors
        are not laid out that way (dense / recurrent tests in this file).
    """
    # Keras frontend currently supports tensorflow backend only.
    assert keras.backend.backend() == 'tensorflow'

    # Concrete input shapes: unknown (None) dims such as batch become 1.
    in_shapes = [
        tuple(dim.value if dim.value is not None else 1 for dim in layer.input.shape)
        for layer in keras_model._input_layers
    ]

    def get_keras_output(xs):
        return keras_model.predict(xs)

    def get_tvm_output(xs, target, ctx, dtype='float32'):
        sym, params = nnvm.frontend.from_keras(keras_model)
        shape_dict = {name: x.shape for (name, x) in zip(keras_model.input_names, xs)}
        with nnvm.compiler.build_config(opt_level=2):
            graph, lib, params = nnvm.compiler.build(sym, target, shape_dict, params=params)
        m = graph_runtime.create(graph, lib, ctx)
        for name, x in zip(keras_model.input_names, xs):
            m.set_input(name, tvm.nd.array(x.astype(dtype)))
        m.set_input(**params)
        m.run()

        return [m.get_output(i).asnumpy() for i in range(m.get_num_outputs())]

    def to_channels_first(arr):
        # (N, d1, ..., dk, C) -> (N, C, d1, ..., dk)
        return arr.transpose([0, -1] + list(range(1, arr.ndim - 1)))

    def to_channels_last(arr):
        # (N, C, d1, ..., dk) -> (N, d1, ..., dk, C)
        return arr.transpose([0] + list(range(2, arr.ndim)) + [1])

    xs = [np.random.uniform(size=shape, low=-1.0, high=1.0) for shape in in_shapes]
    keras_out = get_keras_output(xs)
    # predict() returns a bare array for single-output models; normalize to a list.
    keras_out = keras_out if isinstance(keras_out, list) else [keras_out]

    # The TVM-side inputs are the same for every target, so prepare them once.
    tvm_xs = [to_channels_first(x) for x in xs] if need_transpose else xs
    for target, ctx in ctx_list():
        tvm_out = get_tvm_output(tvm_xs, target, ctx)
        # zip() below would silently drop unmatched outputs, hiding a frontend
        # that emits too few or too many; check the count explicitly first.
        assert len(tvm_out) == len(keras_out), (
            "output count mismatch on target %s: keras=%d, tvm=%d"
            % (target, len(keras_out), len(tvm_out)))
        for kout, tout in zip(keras_out, tvm_out):
            if need_transpose:
                tout = to_channels_last(tout)
            tvm.testing.assert_allclose(kout, tout, rtol=1e-5, atol=1e-5)
72
def test_forward_elemwise_add():
    """Element-wise `add` of two and of three tensors from a conv chain."""
    data = keras.layers.Input(shape=(32,32,3))
    # Three stacked 3x3 convolutions; every intermediate activation is kept.
    activations = []
    feat = data
    for _ in range(3):
        feat = keras.layers.Conv2D(8, (3, 3), padding="same")(feat)
        activations.append(feat)
    first, second, third = activations

    # Nested pairwise adds: (third + first) + second.
    nested = keras.layers.add([keras.layers.add([third, first]), second])
    pooled = keras.layers.GlobalAveragePooling2D()(nested)
    verify_keras_frontend(keras.models.Model(data, pooled))

    # One add over all three tensors at once.
    flat = keras.layers.add([third, first, second])
    pooled = keras.layers.GlobalAveragePooling2D()(flat)
    verify_keras_frontend(keras.models.Model(data, pooled))
91
92
def _test_forward_dense():
    """Flatten -> Dropout -> Dense(relu) on a 32x32x1 input."""
    inp = keras.layers.Input(shape=(32,32,1))
    flattened = keras.layers.Flatten()(inp)
    dropped = keras.layers.Dropout(0.5)(flattened)
    dense_out = keras.layers.Dense(10, activation='relu', kernel_initializer='uniform')(dropped)
    verify_keras_frontend(keras.models.Model(inp, dense_out))
100
def _test_forward_dense_with_3d_inp():
    """Dense(relu) applied to a rank-3 input of shape (batch, 1, 20)."""
    inp = keras.layers.Input(shape=(1, 20))
    dense = keras.layers.Dense(10, activation='relu', kernel_initializer='uniform')
    model = keras.models.Model(inp, dense(inp))
    # Not an image tensor, so no channels-first/last conversion.
    verify_keras_frontend(model, need_transpose=False)
106
def test_forward_dense():
    """Run both Dense-layer conversion checks in order."""
    for check in (_test_forward_dense, _test_forward_dense_with_3d_inp):
        check()
110
def test_forward_pool():
    """Max pooling, then average pooling: 3x3 window, stride 1, 'same' padding."""
    data = keras.layers.Input(shape=(32,32,1))
    for pool_layer in (keras.layers.MaxPooling2D, keras.layers.AveragePooling2D):
        pooled = pool_layer((3, 3), strides=(1, 1), padding='same')(data)
        verify_keras_frontend(keras.models.Model(data, pooled))
121
122
def test_forward_conv():
    """Each convolution variant applied directly to a 32x32x3 input."""
    data = keras.layers.Input(shape=(32,32,3))
    conv_layers = (
        keras.layers.Conv2D(filters=10, kernel_size=(3,3), strides=(2,2), padding='same'),
        keras.layers.Conv2D(filters=10, kernel_size=(3,3), dilation_rate=(2,2), padding='same'),
        keras.layers.DepthwiseConv2D(kernel_size=(3,3), padding='same'),
        keras.layers.Conv2DTranspose(filters=10, kernel_size=(3,3), padding='valid'),
        keras.layers.SeparableConv2D(filters=10, kernel_size=(3,3), padding='same'),
    )
    for conv in conv_layers:
        verify_keras_frontend(keras.models.Model(data, conv(data)))
136
137
def test_forward_upsample():
    """UpSampling2D with size=(3, 3) on a 32x32x3 input."""
    data = keras.layers.Input(shape=(32,32,3))
    upsampled = keras.layers.UpSampling2D(size=(3,3))(data)
    verify_keras_frontend(keras.models.Model(data, upsampled))
143
144
def test_forward_reshape():
    """Reshape whose target shape equals the input shape (32, 32, 3)."""
    data = keras.layers.Input(shape=(32,32,3))
    reshaped = keras.layers.Reshape(target_shape=(32,32,3))(data)
    verify_keras_frontend(keras.models.Model(data, reshaped))
150
151
def test_forward_crop():
    """Chain of Cropping2D layers using each form of the `cropping` argument."""
    data = keras.layers.Input(shape=(32,32,3))
    croppings = (
        ((1, 1), (1, 1)),  # pair of pairs
        (1, 1),            # pair of ints
        1,                 # single int
        ((0, 1), (1, 0)),  # asymmetric pair of pairs
        (1, 0),            # pair of ints, one zero
        0,                 # zero crop
    )
    x = data
    for cropping in croppings:
        x = keras.layers.Cropping2D(cropping=cropping)(x)
    # Finish with an element-wise add of the cropped tensor to itself.
    x = keras.layers.Add()([x, x])
    verify_keras_frontend(keras.models.Model(data, x))
163
164
def test_forward_vgg16():
    """Full VGG16 application with ImageNet weights at 224x224x3."""
    keras_model = keras.applications.vgg16.VGG16(
        include_top=True,
        weights='imagenet',
        input_shape=(224,224,3),
        classes=1000,
    )
    verify_keras_frontend(keras_model)
169
170
def test_forward_xception():
    """Full Xception application with ImageNet weights at 299x299x3."""
    keras_model = keras.applications.xception.Xception(
        include_top=True,
        weights='imagenet',
        input_shape=(299,299,3),
        classes=1000,
    )
    verify_keras_frontend(keras_model)
175
176
def test_forward_resnet50():
    """Full ResNet50 application with ImageNet weights at 224x224x3."""
    keras_model = keras.applications.resnet50.ResNet50(
        include_top=True,
        weights='imagenet',
        input_shape=(224,224,3),
        classes=1000,
    )
    verify_keras_frontend(keras_model)
181
182
def test_forward_mobilenet():
    """Full MobileNet application with ImageNet weights at 224x224x3."""
    keras_model = keras.applications.mobilenet.MobileNet(
        include_top=True,
        weights='imagenet',
        input_shape=(224,224,3),
        classes=1000,
    )
    verify_keras_frontend(keras_model)
187
188
def test_forward_activations():
    """Check each activation layer on its own, applied to a 32x32x3 input."""
    data = keras.layers.Input(shape=(32,32,3))
    # Keras' `weights` argument is a *list* of arrays. PReLU here has a single
    # alpha array shaped like the non-batch input, (32, 32, 3). The previous
    # code passed a bare (1, 32, 32, 3) array, which only worked because
    # iterating its leading axis yields one (32, 32, 3) array. This draws the
    # same random values but states the intended structure explicitly.
    prelu_alpha = np.random.rand(32, 32, 3)
    act_funcs = [keras.layers.Activation('softmax'),
                 keras.layers.Activation('softplus'),
                 keras.layers.ReLU(),
                 keras.layers.ReLU(max_value=6.),
                 keras.layers.LeakyReLU(alpha=0.3),
                 keras.layers.PReLU(weights=[prelu_alpha], alpha_initializer="zero"),
                 keras.layers.ELU(alpha=0.5),
                 keras.layers.Activation('selu'),
                 keras.layers.ThresholdedReLU(theta=0.5),
                 keras.layers.Activation('softsign'),
                 keras.layers.Activation('hard_sigmoid'),
                 keras.layers.Activation('sigmoid'),
                 keras.layers.Activation('tanh'),
                 keras.layers.Activation('linear')]
    for act_func in act_funcs:
        x = act_func(data)
        keras_model = keras.models.Model(data, x)
        verify_keras_frontend(keras_model)
210
211
def test_forward_multi_inputs():
    """Two inputs, each with its own conv, summed and globally pooled."""
    model_inputs = [keras.layers.Input(shape=(32,32,3)) for _ in range(2)]
    branches = [keras.layers.Conv2D(8, (3, 3), padding="same")(inp) for inp in model_inputs]
    merged = keras.layers.add(branches)
    pooled = keras.layers.GlobalAveragePooling2D()(merged)
    verify_keras_frontend(keras.models.Model(model_inputs, pooled))
221
222
def test_forward_multi_outputs():
    """One input feeding two independent conv + global-pool output heads."""
    data = keras.layers.Input(shape=(32,32,3))
    heads = []
    for _ in range(2):
        conv_out = keras.layers.Conv2D(8, (3, 3), padding="same")(data)
        heads.append(keras.layers.GlobalAveragePooling2D()(conv_out))
    verify_keras_frontend(keras.models.Model(data, heads))
231
232
def test_forward_reuse_layers():
    """Single layer instances invoked more than once within one model."""
    # Case 1: the same Conv2D applied twice to one input, results summed.
    data = keras.layers.Input(shape=(32,32,3))
    shared_conv = keras.layers.Conv2D(8, (3, 3), padding="same")
    summed = keras.layers.add([shared_conv(data), shared_conv(data)])
    pooled = keras.layers.GlobalAveragePooling2D()(summed)
    verify_keras_frontend(keras.models.Model(data, pooled))

    # Case 2: the same Add layer applied twice in sequence.
    data = keras.layers.Input(shape=(32,32,3))
    feat = keras.layers.Conv2D(8, (3, 3), padding="same")(data)
    shared_add = keras.layers.Add()
    for _ in range(2):
        feat = shared_add([feat, feat])
    pooled = keras.layers.GlobalAveragePooling2D()(feat)
    verify_keras_frontend(keras.models.Model(data, pooled))
253
def _test_LSTM(time_steps, inputs, hidden, return_state=True):
    """Single LSTM layer over a (time_steps, inputs) sequence.

    `hidden` is the number of units; `return_state` is forwarded to the layer.
    """
    data = keras.layers.Input(shape=(time_steps, inputs))
    lstm = keras.layers.LSTM(hidden,
                             activation='tanh',
                             recurrent_activation='sigmoid',
                             return_state=return_state)
    keras_model = keras.models.Model(data, lstm(data))
    verify_keras_frontend(keras_model, need_transpose=False)
263
def _test_LSTM_MultiLayer(inputs, hidden):
    """Two stacked LSTMs; the second is seeded with the first's returned states."""
    data = keras.layers.Input(shape=(1, inputs))
    first = keras.layers.LSTM(hidden,
                              return_state=True,
                              return_sequences=True,
                              activation='tanh',
                              recurrent_activation='sigmoid')
    first_outputs = first(data)
    sequence, states = first_outputs[0], first_outputs[1:]
    second = keras.layers.LSTM(hidden, activation='tanh', recurrent_activation='sigmoid')
    final = second(sequence, initial_state=states)
    verify_keras_frontend(keras.models.Model(data, final), need_transpose=False)
275
276
def test_forward_LSTM():
    """LSTM conversion at several sizes, with/without states, then stacked."""
    cases = ((1, 8, 8, True),
             (1, 4, 4, False),
             (20, 16, 256, False))
    for time_steps, n_in, n_hidden, with_state in cases:
        _test_LSTM(time_steps, n_in, n_hidden, return_state=with_state)
    _test_LSTM_MultiLayer(4, 4)
282
def _test_RNN(inputs, units):
    """SimpleRNN (tanh, return_state=True) over a length-1 sequence."""
    data = keras.layers.Input(shape=(1, inputs))
    rnn = keras.layers.SimpleRNN(units, activation='tanh', return_state=True)
    keras_model = keras.models.Model(data, rnn(data))
    verify_keras_frontend(keras_model, need_transpose=False)
290
def _test_RNN_MultiLayer(inputs, units):
    """Two stacked SimpleRNNs; the second is seeded with the first's state."""
    data = keras.layers.Input(shape=(1, inputs))
    first = keras.layers.SimpleRNN(units,
                                   return_state=True,
                                   return_sequences=True,
                                   activation='tanh')
    first_outputs = first(data)
    sequence, states = first_outputs[0], first_outputs[1:]
    second = keras.layers.SimpleRNN(units, activation='tanh')
    final = second(sequence, initial_state=states)
    verify_keras_frontend(keras.models.Model(data, final), need_transpose=False)
300
def test_forward_RNN():
    """SimpleRNN conversion: single layer at two sizes, then a stacked pair."""
    for n_in, n_units in ((2, 4), (4, 3)):
        _test_RNN(n_in, n_units)
    _test_RNN_MultiLayer(4, 12)
305
def _test_GRU(inputs, units):
    """GRU (return_state=True) over a length-1 sequence."""
    data = keras.layers.Input(shape=(1, inputs))
    gru = keras.layers.GRU(units,
                           activation='tanh',
                           recurrent_activation='sigmoid',
                           return_state=True)
    keras_model = keras.models.Model(data, gru(data))
    verify_keras_frontend(keras_model, need_transpose=False)
315
def _test_GRU_MultiLayer(inputs, units):
    """Two stacked GRUs; the second is seeded with the first's returned state."""
    data = keras.layers.Input(shape=(1, inputs))
    first = keras.layers.GRU(units,
                             return_state=True,
                             return_sequences=True,
                             activation='tanh',
                             recurrent_activation='sigmoid')
    first_outputs = first(data)
    sequence, states = first_outputs[0], first_outputs[1:]
    second = keras.layers.GRU(units, activation='tanh', recurrent_activation='sigmoid')
    final = second(sequence, initial_state=states)
    verify_keras_frontend(keras.models.Model(data, final), need_transpose=False)
329
def test_forward_GRU():
    """GRU conversion: single layer at two sizes, then a stacked pair."""
    for n_in, n_units in ((2, 4), (4, 3)):
        _test_GRU(n_in, n_units)
    _test_GRU_MultiLayer(4, 4)
334
if __name__ == '__main__':
    # Run every test in the same order as before: single-layer and
    # application tests first, then graph-structure and recurrent tests.
    for _test_fn in (
            test_forward_elemwise_add,
            test_forward_activations,
            test_forward_dense,
            test_forward_pool,
            test_forward_conv,
            test_forward_upsample,
            test_forward_reshape,
            test_forward_crop,
            test_forward_vgg16,
            test_forward_xception,
            test_forward_resnet50,
            test_forward_mobilenet,
            test_forward_multi_inputs,
            test_forward_multi_outputs,
            test_forward_reuse_layers,
            test_forward_LSTM,
            test_forward_RNN,
            test_forward_GRU,
    ):
        _test_fn()
355