1#!/usr/bin/env python3
2"""Generate a test model for frugally-deep.
3"""
4
5import sys
6
7import numpy as np
8from tensorflow.keras.layers import BatchNormalization, Concatenate
9from tensorflow.keras.layers import Bidirectional, TimeDistributed
10from tensorflow.keras.layers import Conv1D, ZeroPadding1D, Cropping1D
11from tensorflow.keras.layers import Conv2D, ZeroPadding2D, Cropping2D
12from tensorflow.keras.layers import Embedding, Normalization
13from tensorflow.keras.layers import GlobalAveragePooling1D, GlobalMaxPooling1D
14from tensorflow.keras.layers import GlobalAveragePooling2D, GlobalMaxPooling2D
15from tensorflow.keras.layers import Input, Dense, Dropout, Flatten, Activation
16from tensorflow.keras.layers import LSTM, GRU
17from tensorflow.keras.layers import LeakyReLU, ELU, PReLU, ReLU
18from tensorflow.keras.layers import MaxPooling1D, AveragePooling1D, UpSampling1D
19from tensorflow.keras.layers import MaxPooling2D, AveragePooling2D, UpSampling2D
20from tensorflow.keras.layers import Multiply, Add, Subtract, Average, Maximum
21from tensorflow.keras.layers import Permute, Reshape, RepeatVector
22from tensorflow.keras.layers import SeparableConv2D, DepthwiseConv2D
23from tensorflow.keras.models import Model, load_model, Sequential
24
25__author__ = "Tobias Hermann"
26__copyright__ = "Copyright 2017, Tobias Hermann"
27__license__ = "MIT"
28__maintainer__ = "Tobias Hermann, https://github.com/Dobiasd/frugally-deep"
29__email__ = "editgym@gmail.com"
30
31
def replace_none_with(value, shape):
    """Return *shape* as a tuple with every ``None`` entry replaced by *value*.

    Keras uses ``None`` for dimensions of variable extent; random test data
    needs a concrete size there.

    :param value: concrete size to substitute for ``None``
    :param shape: iterable of dimension sizes (ints or ``None``)
    :return: tuple of ints
    """
    # A generator expression inside tuple() replaces the original
    # tuple(list(map(lambda ...))) chain; the intermediate list was redundant.
    return tuple(value if x is None else x for x in shape)
35
36
def get_shape_for_random_data(data_size, shape):
    """Include size of data to generate into shape.

    Prepends the sample count *data_size* to *shape*, e.g.
    ``(2, (3, 4)) -> (2, 3, 4)``.

    :param data_size: number of samples (leading dimension)
    :param shape: per-sample shape; rank 1 through 5 is supported
    :return: tuple ``(data_size, *shape)``
    :raises ValueError: if *shape* has rank 0 or rank greater than 5
    """
    # The original spelled out one branch per rank; every branch computed
    # the same concatenation, so a single guarded expression suffices.
    if 1 <= len(shape) <= 5:
        return (data_size,) + tuple(shape)
    raise ValueError('can not determine shape for random data')
50
51
def generate_random_data(data_size, shape):
    """Return uniform random floats shaped ``(data_size, *shape)``.

    ``None`` dimensions in *shape* are fixed to 42 before sampling.
    """
    concrete_shape = replace_none_with(42, shape)
    full_shape = get_shape_for_random_data(data_size, concrete_shape)
    return np.random.random(size=full_shape)
56
57
def generate_input_data(data_size, input_shapes):
    """Return one random training array per entry of *input_shapes*."""
    return [generate_random_data(data_size, single_shape)
            for single_shape in input_shapes]
62
63
def generate_integer_random_data(data_size, low, high, shape):
    """Return random integers in ``[low, high)`` shaped ``(data_size, *shape)``.

    ``None`` dimensions in *shape* are fixed to 42 before sampling.
    """
    concrete_shape = replace_none_with(42, shape)
    full_shape = get_shape_for_random_data(data_size, concrete_shape)
    return np.random.randint(low=low, high=high, size=full_shape)
68
69
def generate_integer_input_data(data_size, low, highs, input_shapes):
    """Return one random integer array per (high, shape) pair.

    *highs* and *input_shapes* are iterated in lockstep; *low* is shared.
    """
    result = []
    for upper_bound, single_shape in zip(highs, input_shapes):
        result.append(
            generate_integer_random_data(data_size, low, upper_bound, single_shape))
    return result
74
75
def as_list(value_or_values):
    """Leave lists untouched, convert non-list types to a singleton list"""
    # A conditional expression replaces the original if/return pair.
    return value_or_values if isinstance(value_or_values, list) \
        else [value_or_values]
81
82
def generate_output_data(data_size, outputs):
    """Return random target arrays matching each output's per-sample shape.

    The leading (batch) dimension of every output is dropped and replaced
    by *data_size*.
    """
    result = []
    for single_output in as_list(outputs):
        per_sample_shape = single_output.shape[1:]
        result.append(generate_random_data(data_size, per_sample_shape))
    return result
87
88
def get_test_model_exhaustive():
    """Returns an exhaustive test model exercising most supported layer types.

    The model has one input per entry of ``input_shapes`` (ranks 1 to 5 plus
    several degenerate all-ones shapes) and a long list of outputs, each the
    result of applying one layer configuration. It is compiled and briefly
    fitted on random data so all weights are initialized.
    """
    input_shapes = [
        (2, 3, 4, 5, 6),
        (2, 3, 4, 5, 6),
        (7, 8, 9, 10),
        (7, 8, 9, 10),
        (11, 12, 13),
        (11, 12, 13),
        (14, 15),
        (14, 15),
        (16,),
        (16,),
        (2,),
        (1,),
        (2,),
        (1,),
        (1, 3),
        (1, 4),
        (1, 1, 3),
        (1, 1, 4),
        (1, 1, 1, 3),
        (1, 1, 1, 4),
        (1, 1, 1, 1, 3),
        (1, 1, 1, 1, 4),
        (26, 28, 3),
        (4, 4, 3),
        (4, 4, 3),
        (4,),
        (2, 3),
        (1,),
        (1,),
        (1,),
        (2, 3),
        (9, 16, 1),
        (1, 9, 16)
    ]

    inputs = [Input(shape=s) for s in input_shapes]

    outputs = []

    # --- 1D convolution, padding/cropping and pooling variants ---
    outputs.append(Conv1D(1, 3, padding='valid')(inputs[6]))
    outputs.append(Conv1D(2, 1, padding='same')(inputs[6]))
    outputs.append(Conv1D(3, 4, padding='causal', dilation_rate=2)(inputs[6]))
    outputs.append(ZeroPadding1D(2)(inputs[6]))
    outputs.append(Cropping1D((2, 3))(inputs[6]))
    outputs.append(MaxPooling1D(2)(inputs[6]))
    outputs.append(MaxPooling1D(2, strides=2, padding='same')(inputs[6]))
    outputs.append(MaxPooling1D(2, data_format="channels_first")(inputs[6]))
    outputs.append(AveragePooling1D(2)(inputs[6]))
    outputs.append(AveragePooling1D(2, strides=2, padding='same')(inputs[6]))
    outputs.append(AveragePooling1D(2, data_format="channels_first")(inputs[6]))
    outputs.append(GlobalMaxPooling1D()(inputs[6]))
    outputs.append(GlobalMaxPooling1D(data_format="channels_first")(inputs[6]))
    outputs.append(GlobalAveragePooling1D()(inputs[6]))
    outputs.append(GlobalAveragePooling1D(data_format="channels_first")(inputs[6]))

    # --- Normalization on every axis of the rank-5 input, plus scalar stats ---
    for axis in range(1, 6):
        shape = input_shapes[0][axis - 1]
        outputs.append(Normalization(axis=axis,
                                     mean=np.random.rand(shape),
                                     variance=np.random.rand(shape)
                                     )(inputs[0]))
    outputs.append(Normalization(axis=None, mean=2.1, variance=2.2)(inputs[4]))
    outputs.append(Normalization(axis=-1, mean=2.1, variance=2.2)(inputs[6]))

    # --- 2D convolution variants (bias, strides, dilation) ---
    outputs.append(Conv2D(4, (3, 3))(inputs[4]))
    outputs.append(Conv2D(4, (3, 3), use_bias=False)(inputs[4]))
    outputs.append(Conv2D(4, (2, 4), strides=(2, 3), padding='same')(inputs[4]))
    outputs.append(Conv2D(4, (2, 4), padding='same', dilation_rate=(2, 3))(inputs[4]))

    outputs.append(SeparableConv2D(3, (3, 3))(inputs[4]))
    outputs.append(DepthwiseConv2D((3, 3))(inputs[4]))
    outputs.append(DepthwiseConv2D((1, 2))(inputs[4]))

    # --- 2D pooling ---
    outputs.append(MaxPooling2D((2, 2))(inputs[4]))
    # todo: check if TensorFlow >= 2.8 supports this
    # outputs.append(MaxPooling2D((2, 2), data_format="channels_first")(inputs[4]))
    outputs.append(MaxPooling2D((1, 3), strides=(2, 3), padding='same')(inputs[4]))
    outputs.append(AveragePooling2D((2, 2))(inputs[4]))
    # todo: check if TensorFlow >= 2.8 supports this
    # outputs.append(AveragePooling2D((2, 2), data_format="channels_first")(inputs[4]))
    outputs.append(AveragePooling2D((1, 3), strides=(2, 3), padding='same')(inputs[4]))

    outputs.append(GlobalAveragePooling2D()(inputs[4]))
    outputs.append(GlobalAveragePooling2D(data_format="channels_first")(inputs[4]))
    outputs.append(GlobalMaxPooling2D()(inputs[4]))
    outputs.append(GlobalMaxPooling2D(data_format="channels_first")(inputs[4]))

    # --- Permute on every rank, including identity-like single-axis case ---
    outputs.append(Permute((3, 4, 1, 5, 2))(inputs[0]))
    outputs.append(Permute((1, 5, 3, 2, 4))(inputs[0]))
    outputs.append(Permute((3, 4, 1, 2))(inputs[2]))
    outputs.append(Permute((2, 1, 3))(inputs[4]))
    outputs.append(Permute((2, 1))(inputs[6]))
    outputs.append(Permute((1,))(inputs[8]))

    # Permutations of the (9, 16, 1) and (1, 9, 16) inputs, with and
    # without a following BatchNormalization.
    outputs.append(Permute((3, 1, 2))(inputs[31]))
    outputs.append(Permute((3, 1, 2))(inputs[32]))
    outputs.append(BatchNormalization()(Permute((3, 1, 2))(inputs[31])))
    outputs.append(BatchNormalization()(Permute((3, 1, 2))(inputs[32])))

    # --- BatchNormalization over many input ranks and axes ---
    outputs.append(BatchNormalization()(inputs[0]))
    outputs.append(BatchNormalization(axis=1)(inputs[0]))
    outputs.append(BatchNormalization(axis=2)(inputs[0]))
    outputs.append(BatchNormalization(axis=3)(inputs[0]))
    outputs.append(BatchNormalization(axis=4)(inputs[0]))
    outputs.append(BatchNormalization(axis=5)(inputs[0]))
    outputs.append(BatchNormalization()(inputs[2]))
    outputs.append(BatchNormalization(axis=1)(inputs[2]))
    outputs.append(BatchNormalization(axis=2)(inputs[2]))
    outputs.append(BatchNormalization(axis=3)(inputs[2]))
    outputs.append(BatchNormalization(axis=4)(inputs[2]))
    outputs.append(BatchNormalization()(inputs[4]))
    # todo: check if TensorFlow >= 2.1 supports this
    # outputs.append(BatchNormalization(axis=1)(inputs[4])) # tensorflow.python.framework.errors_impl.InternalError:  The CPU implementation of FusedBatchNorm only supports NHWC tensor format for now.
    outputs.append(BatchNormalization(axis=2)(inputs[4]))
    outputs.append(BatchNormalization(axis=3)(inputs[4]))
    outputs.append(BatchNormalization()(inputs[6]))
    outputs.append(BatchNormalization(axis=1)(inputs[6]))
    outputs.append(BatchNormalization(axis=2)(inputs[6]))
    outputs.append(BatchNormalization()(inputs[8]))
    outputs.append(BatchNormalization(axis=1)(inputs[8]))
    outputs.append(BatchNormalization()(inputs[27]))
    outputs.append(BatchNormalization(axis=1)(inputs[27]))
    outputs.append(BatchNormalization()(inputs[14]))
    outputs.append(BatchNormalization(axis=1)(inputs[14]))
    outputs.append(BatchNormalization(axis=2)(inputs[14]))
    outputs.append(BatchNormalization()(inputs[16]))
    # todo: check if TensorFlow >= 2.1 supports this
    # outputs.append(BatchNormalization(axis=1)(inputs[16])) # tensorflow.python.framework.errors_impl.InternalError:  The CPU implementation of FusedBatchNorm only supports NHWC tensor format for now.
    outputs.append(BatchNormalization(axis=2)(inputs[16]))
    outputs.append(BatchNormalization(axis=3)(inputs[16]))
    outputs.append(BatchNormalization()(inputs[18]))
    outputs.append(BatchNormalization(axis=1)(inputs[18]))
    outputs.append(BatchNormalization(axis=2)(inputs[18]))
    outputs.append(BatchNormalization(axis=3)(inputs[18]))
    outputs.append(BatchNormalization(axis=4)(inputs[18]))
    outputs.append(BatchNormalization()(inputs[20]))
    outputs.append(BatchNormalization(axis=1)(inputs[20]))
    outputs.append(BatchNormalization(axis=2)(inputs[20]))
    outputs.append(BatchNormalization(axis=3)(inputs[20]))
    outputs.append(BatchNormalization(axis=4)(inputs[20]))
    outputs.append(BatchNormalization(axis=5)(inputs[20]))

    outputs.append(Dropout(0.5)(inputs[4]))

    # --- 2D padding/cropping: symmetric, per-side, and per-edge forms ---
    outputs.append(ZeroPadding2D(2)(inputs[4]))
    outputs.append(ZeroPadding2D((2, 3))(inputs[4]))
    outputs.append(ZeroPadding2D(((1, 2), (3, 4)))(inputs[4]))
    outputs.append(Cropping2D(2)(inputs[4]))
    outputs.append(Cropping2D((2, 3))(inputs[4]))
    outputs.append(Cropping2D(((1, 2), (3, 4)))(inputs[4]))

    # --- Dense applied to inputs of several ranks ---
    outputs.append(Dense(3, use_bias=True)(inputs[13]))
    outputs.append(Dense(3, use_bias=True)(inputs[14]))
    outputs.append(Dense(4, use_bias=False)(inputs[16]))
    outputs.append(Dense(4, use_bias=False, activation='tanh')(inputs[18]))
    outputs.append(Dense(4, use_bias=False)(inputs[20]))

    # --- Reshape: flatten/unflatten the rank-5 input step by step ---
    outputs.append(Reshape(((2 * 3 * 4 * 5 * 6),))(inputs[0]))
    outputs.append(Reshape((2, 3 * 4 * 5 * 6))(inputs[0]))
    outputs.append(Reshape((2, 3, 4 * 5 * 6))(inputs[0]))
    outputs.append(Reshape((2, 3, 4, 5 * 6))(inputs[0]))
    outputs.append(Reshape((2, 3, 4, 5, 6))(inputs[0]))

    outputs.append(Reshape((16,))(inputs[8]))
    outputs.append(Reshape((2, 8))(inputs[8]))
    outputs.append(Reshape((2, 2, 4))(inputs[8]))
    outputs.append(Reshape((2, 2, 2, 2))(inputs[8]))
    outputs.append(Reshape((2, 2, 1, 2, 2))(inputs[8]))

    outputs.append(RepeatVector(3)(inputs[8]))

    # --- 2D upsampling with both interpolation modes ---
    outputs.append(UpSampling2D(size=(1, 2), interpolation='nearest')(inputs[4]))
    outputs.append(UpSampling2D(size=(5, 3), interpolation='nearest')(inputs[4]))
    outputs.append(UpSampling2D(size=(1, 2), interpolation='bilinear')(inputs[4]))
    outputs.append(UpSampling2D(size=(5, 3), interpolation='bilinear')(inputs[4]))

    outputs.append(ReLU()(inputs[0]))

    # --- Concatenate over every valid positive and negative axis per rank ---
    for axis in [-5, -4, -3, -2, -1, 1, 2, 3, 4, 5]:
        outputs.append(Concatenate(axis=axis)([inputs[0], inputs[1]]))
    for axis in [-4, -3, -2, -1, 1, 2, 3, 4]:
        outputs.append(Concatenate(axis=axis)([inputs[2], inputs[3]]))
    for axis in [-3, -2, -1, 1, 2, 3]:
        outputs.append(Concatenate(axis=axis)([inputs[4], inputs[5]]))
    for axis in [-2, -1, 1, 2]:
        outputs.append(Concatenate(axis=axis)([inputs[6], inputs[7]]))
    for axis in [-1, 1]:
        outputs.append(Concatenate(axis=axis)([inputs[8], inputs[9]]))
    for axis in [-1, 2]:
        outputs.append(Concatenate(axis=axis)([inputs[14], inputs[15]]))
    for axis in [-1, 3]:
        outputs.append(Concatenate(axis=axis)([inputs[16], inputs[17]]))
    for axis in [-1, 4]:
        outputs.append(Concatenate(axis=axis)([inputs[18], inputs[19]]))
    for axis in [-1, 5]:
        outputs.append(Concatenate(axis=axis)([inputs[20], inputs[21]]))

    outputs.append(UpSampling1D(size=2)(inputs[6]))
    # outputs.append(UpSampling1D(size=2)(inputs[8])) # ValueError: Input 0 of layer up_sampling1d_1 is incompatible with the layer: expected ndim=3, found ndim=2. Full shape received: [None, 16]

    # --- Multiply with broadcasting between (2,) and (1,) inputs ---
    outputs.append(Multiply()([inputs[10], inputs[11]]))
    outputs.append(Multiply()([inputs[11], inputs[10]]))
    outputs.append(Multiply()([inputs[11], inputs[13]]))
    outputs.append(Multiply()([inputs[10], inputs[11], inputs[12]]))
    outputs.append(Multiply()([inputs[11], inputs[12], inputs[13]]))

    # --- Weight sharing: the same Conv2D instance applied to two tensors ---
    shared_conv = Conv2D(1, (1, 1),
                         padding='valid', name='shared_conv', activation='relu')

    up_scale_2 = UpSampling2D((2, 2))
    x1 = shared_conv(up_scale_2(inputs[23]))  # (1, 8, 8)
    x2 = shared_conv(up_scale_2(inputs[24]))  # (1, 8, 8)
    x3 = Conv2D(1, (1, 1), padding='valid')(up_scale_2(inputs[24]))  # (1, 8, 8)
    x = Concatenate()([x1, x2, x3])  # (3, 8, 8)
    outputs.append(x)

    x = Conv2D(3, (1, 1), padding='same', use_bias=False)(x)  # (3, 8, 8)
    outputs.append(x)
    x = Dropout(0.5)(x)
    outputs.append(x)
    x = Concatenate()([
        MaxPooling2D((2, 2))(x),
        AveragePooling2D((2, 2))(x)])  # (6, 4, 4)
    outputs.append(x)

    x = Flatten()(x)  # (1, 1, 96)
    x = Dense(4, use_bias=False)(x)
    outputs.append(x)
    x = Dense(3)(x)  # (1, 1, 3)
    outputs.append(x)

    # --- Element-wise merge layers, broadcasting (1,) against (2, 3) ---
    outputs.append(Add()([inputs[26], inputs[30], inputs[30]]))
    outputs.append(Subtract()([inputs[26], inputs[30]]))
    outputs.append(Multiply()([inputs[26], inputs[30], inputs[30]]))
    outputs.append(Average()([inputs[26], inputs[30], inputs[30]]))
    outputs.append(Maximum()([inputs[26], inputs[30], inputs[30]]))
    outputs.append(Concatenate()([inputs[26], inputs[30], inputs[30]]))

    # --- Nested models: a functional sub-model used as a layer ---
    intermediate_input_shape = (3,)
    intermediate_in = Input(intermediate_input_shape)
    intermediate_x = intermediate_in
    intermediate_x = Dense(8)(intermediate_x)
    # 'duplicate_layer_name' is intentionally reused in intermediate_model_2
    # below to test name handling across sub-models.
    intermediate_x = Dense(5, name='duplicate_layer_name')(intermediate_x)
    intermediate_model = Model(
        inputs=[intermediate_in], outputs=[intermediate_x],
        name='intermediate_model')
    intermediate_model.compile(loss='mse', optimizer='nadam')

    x = intermediate_model(x)  # (1, 1, 5)

    # --- A Sequential sub-model used as a layer ---
    intermediate_model_2 = Sequential()
    intermediate_model_2.add(Dense(7, input_shape=(5,)))
    intermediate_model_2.add(Dense(5, name='duplicate_layer_name'))
    intermediate_model_2.compile(optimizer='rmsprop',
                                 loss='categorical_crossentropy')

    x = intermediate_model_2(x)  # (1, 1, 5)

    # --- A Sequential sub-model nested inside another Sequential ---
    intermediate_model_3_nested = Sequential()
    intermediate_model_3_nested.add(Dense(7, input_shape=(6,)))
    intermediate_model_3_nested.compile(optimizer='rmsprop', loss='categorical_crossentropy')

    intermediate_model_3 = Sequential()
    intermediate_model_3.add(Dense(6, input_shape=(5,)))
    intermediate_model_3.add(intermediate_model_3_nested)
    intermediate_model_3.add(Dense(8))
    intermediate_model_3.compile(optimizer='rmsprop', loss='categorical_crossentropy')

    x = intermediate_model_3(x)  # (1, 1, 8)

    x = Dense(3)(x)  # (1, 1, 3)

    # The same Activation instance is applied to two tensors below.
    shared_activation = Activation('tanh')

    # --- Activation functions and advanced activation layers ---
    outputs = outputs + [
        Activation('tanh')(inputs[25]),
        Activation('hard_sigmoid')(inputs[25]),
        Activation('selu')(inputs[25]),
        Activation('sigmoid')(inputs[25]),
        Activation('softplus')(inputs[25]),
        Activation('softmax')(inputs[25]),
        Activation('relu')(inputs[25]),
        Activation('relu6')(inputs[25]),
        Activation('swish')(inputs[25]),
        Activation('exponential')(inputs[25]),
        Activation('gelu')(inputs[25]),
        Activation('softsign')(inputs[25]),
        LeakyReLU()(inputs[25]),
        ReLU()(inputs[25]),
        ReLU(max_value=0.4, negative_slope=1.1, threshold=0.3)(inputs[25]),
        ELU()(inputs[25]),
        PReLU()(inputs[24]),
        PReLU()(inputs[25]),
        PReLU()(inputs[26]),
        shared_activation(inputs[25]),
        Activation('linear')(inputs[26]),
        Activation('linear')(inputs[23]),
        x,
        shared_activation(x),
    ]

    model = Model(inputs=inputs, outputs=outputs, name='test_model_exhaustive')
    model.compile(loss='mse', optimizer='nadam')

    # fit to dummy data
    training_data_size = 2
    data_in = generate_input_data(training_data_size, input_shapes)
    initial_data_out = model.predict(data_in)
    data_out = generate_output_data(training_data_size, initial_data_out)
    model.fit(data_in, data_out, epochs=10)
    return model
403
404
def get_test_model_embedding():
    """Returns a minimalistic test model for the embedding layer.

    Builds one Embedding -> LSTM branch per input, compiles the model, and
    fits it for one epoch on random integer data.
    """

    input_dims = [
        1023,  # maximum integer value in input data
        255
    ]
    input_shapes = [
        (100,),  # must be single-element tuple (for sequence length)
        (1000,)
    ]
    assert len(input_dims) == len(input_shapes)
    output_dims = [8, 3]  # embedding dimension

    inputs = [Input(shape=s) for s in input_shapes]

    outputs = []
    for k in range(0, len(input_shapes)):
        embedding = Embedding(input_dim=input_dims[k], output_dim=output_dims[k])(inputs[k])
        # return_sequences=False: only the final hidden state is emitted.
        lstm = LSTM(
            units=4,
            recurrent_activation='sigmoid',
            return_sequences=False
        )(embedding)

        outputs.append(lstm)

    model = Model(inputs=inputs, outputs=outputs, name='test_model_embedding')
    model.compile(loss='mse', optimizer='adam')

    # fit to dummy data; inputs are integer indices below each input_dim
    training_data_size = 2
    data_in = generate_integer_input_data(training_data_size, 0, input_dims, input_shapes)
    initial_data_out = model.predict(data_in)
    data_out = generate_output_data(training_data_size, initial_data_out)
    model.fit(data_in, data_out, epochs=1)
    return model
442
443
def get_test_model_recurrent():
    """Returns a minimalistic test model for recurrent layers.

    Covers stacked (bidirectional) LSTMs with several merge modes,
    Conv1D -> LSTM, TimeDistributed over layers and over whole sub-models,
    then compiles and briefly fits the model on random data.
    """
    input_shapes = [
        (17, 4),
        (1, 10),
        (20, 40),
        (6, 7, 10, 3)
    ]

    outputs = []

    inputs = [Input(shape=s) for s in input_shapes]

    inp = PReLU()(inputs[0])

    # Stack of bidirectional LSTMs with different merge modes, ending in a
    # plain LSTM that returns only the final state.
    lstm = Bidirectional(LSTM(units=4,
                              return_sequences=True,
                              bias_initializer='random_uniform',  # default is zero use random to test computation
                              activation='tanh',
                              recurrent_activation='relu'), merge_mode='concat')(inp)

    lstm2 = Bidirectional(LSTM(units=6,
                               return_sequences=True,
                               bias_initializer='random_uniform',
                               activation='elu',
                               recurrent_activation='hard_sigmoid'), merge_mode='sum')(lstm)

    lstm3 = LSTM(units=10,
                 return_sequences=False,
                 bias_initializer='random_uniform',
                 activation='selu',
                 recurrent_activation='sigmoid')(lstm2)

    outputs.append(lstm3)

    # Conv1D feeding an LSTM, followed by a Dense head.
    conv1 = Conv1D(2, 1, activation='sigmoid')(inputs[1])
    lstm4 = LSTM(units=15,
                 return_sequences=False,
                 bias_initializer='random_uniform',
                 activation='tanh',
                 recurrent_activation='elu')(conv1)

    dense = (Dense(23, activation='sigmoid'))(lstm4)
    outputs.append(dense)

    # TimeDistributed wrapping of plain layers over the rank-4 input.
    time_dist_1 = TimeDistributed(Conv2D(2, (3, 3), use_bias=True))(inputs[3])
    flatten_1 = TimeDistributed(Flatten())(time_dist_1)

    outputs.append(Bidirectional(LSTM(units=6,
                                      return_sequences=True,
                                      bias_initializer='random_uniform',
                                      activation='tanh',
                                      recurrent_activation='sigmoid'), merge_mode='ave')(flatten_1))

    outputs.append(TimeDistributed(MaxPooling2D(2, 2))(inputs[3]))
    outputs.append(TimeDistributed(AveragePooling2D(2, 2))(inputs[3]))
    outputs.append(TimeDistributed(BatchNormalization())(inputs[3]))

    # TimeDistributed wrapping a whole functional sub-model.
    nested_inputs = Input(shape=input_shapes[0][1:])
    nested_x = Dense(5, activation='relu')(nested_inputs)
    nested_predictions = Dense(3, activation='softmax')(nested_x)
    nested_model = Model(inputs=nested_inputs, outputs=nested_predictions)
    nested_model.compile(loss='categorical_crossentropy', optimizer='nadam')
    outputs.append(TimeDistributed(nested_model)(inputs[0]))

    # TimeDistributed wrapping a Sequential sub-model.
    nested_sequential_model = Sequential()
    nested_sequential_model.add(Flatten(input_shape=input_shapes[0][1:]))
    nested_sequential_model.compile(optimizer='rmsprop',
                                    loss='categorical_crossentropy')
    outputs.append(TimeDistributed(nested_sequential_model)(inputs[0]))

    model = Model(inputs=inputs, outputs=outputs, name='test_model_recurrent')
    model.compile(loss='mse', optimizer='nadam')

    # fit to dummy data
    training_data_size = 2
    data_in = generate_input_data(training_data_size, input_shapes)
    initial_data_out = model.predict(data_in)
    data_out = generate_output_data(training_data_size, initial_data_out)
    model.fit(data_in, data_out, epochs=10)
    return model
525
526
def get_test_model_lstm():
    """Returns a test model for Long Short-Term Memory (LSTM) layers.

    For each of the first two inputs: stacked LSTMs, an LSTM returning its
    states, and stacked bidirectional LSTMs. Also tests an LSTM fed with an
    explicit initial state (inputs[3] and inputs[4]).
    """

    input_shapes = [
        (17, 4),
        (1, 10),
        (None, 4),   # variable sequence length
        (12,),       # initial hidden state for the initial_state test below
        (12,)        # initial cell state
    ]
    inputs = [Input(shape=s) for s in input_shapes]
    outputs = []

    for inp in inputs[:2]:
        lstm_sequences = LSTM(
            units=8,
            recurrent_activation='relu',
            return_sequences=True
        )(inp)
        lstm_regular = LSTM(
            units=3,
            recurrent_activation='sigmoid',
            return_sequences=False
        )(lstm_sequences)
        outputs.append(lstm_regular)
        # return_state=True yields (output, hidden state, cell state).
        lstm_state, state_h, state_c = LSTM(
            units=3,
            recurrent_activation='sigmoid',
            return_state=True
        )(inp)
        outputs.append(lstm_state)
        outputs.append(state_h)
        outputs.append(state_c)

        lstm_bidi_sequences = Bidirectional(
            LSTM(
                units=4,
                recurrent_activation='hard_sigmoid',
                return_sequences=True
            )
        )(inp)
        lstm_bidi = Bidirectional(
            LSTM(
                units=6,
                recurrent_activation='linear',
                return_sequences=False
            )
        )(lstm_bidi_sequences)
        outputs.append(lstm_bidi)

        # Parameters matching the CuDNN-compatible fast path.
        lstm_gpu_regular = LSTM(
            units=3,
            activation='tanh',
            recurrent_activation='sigmoid',
            use_bias=True
        )(inp)

        lstm_gpu_bidi = Bidirectional(
            LSTM(
                units=3,
                activation='tanh',
                recurrent_activation='sigmoid',
                use_bias=True
            )
        )(inp)
    # NOTE(review): these two appends sit outside the loop, so only the
    # tensors built from the last iterated input reach the outputs —
    # presumably intentional, but worth confirming.
    outputs.append(lstm_gpu_regular)
    outputs.append(lstm_gpu_bidi)

    # LSTM with an explicitly supplied initial [hidden, cell] state.
    outputs.extend(LSTM(units=12, return_sequences=True,
                        return_state=True)(inputs[2], initial_state=[inputs[3], inputs[4]]))

    model = Model(inputs=inputs, outputs=outputs, name='test_model_lstm')
    model.compile(loss='mse', optimizer='nadam')

    # fit to dummy data
    training_data_size = 2
    data_in = generate_input_data(training_data_size, input_shapes)
    initial_data_out = model.predict(data_in)
    data_out = generate_output_data(training_data_size, initial_data_out)
    model.fit(data_in, data_out, epochs=10)
    return model
608
609
def get_test_model_gru():
    """Returns the GRU test model with stateful mode disabled."""
    return get_test_model_gru_stateful_optional(False)
612
613
def get_test_model_gru_stateful():
    """Returns the GRU test model with stateful mode enabled."""
    return get_test_model_gru_stateful_optional(True)
616
617
def get_test_model_gru_stateful_optional(stateful):
    """Returns a test model for Gated Recurrent Unit (GRU) layers.

    :param stateful: passed through to every GRU layer; stateful layers
        require a fixed batch size, hence the ``batch_shape`` inputs below.
    """
    input_shapes = [
        (17, 4),
        (1, 10)
    ]
    # Stateful RNNs need a fixed batch dimension.
    stateful_batch_size = 1
    inputs = [Input(batch_shape=(stateful_batch_size,) + s) for s in input_shapes]
    outputs = []

    for inp in inputs:
        # Stacked GRUs varying reset_after and use_bias.
        gru_sequences = GRU(
            stateful=stateful,
            units=8,
            recurrent_activation='relu',
            reset_after=True,
            return_sequences=True,
            use_bias=True
        )(inp)
        gru_regular = GRU(
            stateful=stateful,
            units=3,
            recurrent_activation='sigmoid',
            reset_after=True,
            return_sequences=False,
            use_bias=False
        )(gru_sequences)
        outputs.append(gru_regular)

        # Stacked bidirectional GRUs.
        gru_bidi_sequences = Bidirectional(
            GRU(
                stateful=stateful,
                units=4,
                recurrent_activation='hard_sigmoid',
                reset_after=False,
                return_sequences=True,
                use_bias=True
            )
        )(inp)
        gru_bidi = Bidirectional(
            GRU(
                stateful=stateful,
                units=6,
                recurrent_activation='sigmoid',
                reset_after=True,
                return_sequences=False,
                use_bias=False
            )
        )(gru_bidi_sequences)
        outputs.append(gru_bidi)

        # Parameters matching the CuDNN-compatible fast path.
        gru_gpu_regular = GRU(
            stateful=stateful,
            units=3,
            activation='tanh',
            recurrent_activation='sigmoid',
            reset_after=True,
            use_bias=True
        )(inp)

        gru_gpu_bidi = Bidirectional(
            GRU(
                stateful=stateful,
                units=3,
                activation='tanh',
                recurrent_activation='sigmoid',
                reset_after=True,
                use_bias=True
            )
        )(inp)
        outputs.append(gru_gpu_regular)
        outputs.append(gru_gpu_bidi)

    model = Model(inputs=inputs, outputs=outputs, name='test_model_gru')
    model.compile(loss='mse', optimizer='nadam')

    # fit to dummy data; batch size must match the fixed stateful batch size
    training_data_size = stateful_batch_size
    data_in = generate_input_data(training_data_size, input_shapes)
    initial_data_out = model.predict(data_in)
    data_out = generate_output_data(training_data_size, initial_data_out)
    model.fit(data_in, data_out, batch_size=stateful_batch_size, epochs=10)
    return model
701
702
def get_test_model_variable():
    """Returns a model with variably shaped input tensors.

    Spatial/temporal dimensions are ``None``; layers that tolerate unknown
    extents (convolutions, global/local pooling, Reshape with -1, PReLU with
    shared axes) are applied, then the model is fitted on random data where
    every ``None`` is materialized as 42.
    """

    input_shapes = [
        (None, None, 1),
        (None, None, 3),
        (None, 4),
    ]

    inputs = [Input(shape=s) for s in input_shapes]

    outputs = []

    # same as axis=-1
    outputs.append(Concatenate()([inputs[0], inputs[1]]))
    outputs.append(Conv2D(8, (3, 3), padding='same', activation='elu')(inputs[0]))
    outputs.append(Conv2D(8, (3, 3), padding='same', activation='relu')(inputs[1]))
    outputs.append(GlobalMaxPooling2D()(inputs[0]))
    # Reshape with -1 infers the unknown dimension at runtime.
    outputs.append(Reshape((2, -1))(inputs[2]))
    outputs.append(Reshape((-1, 2))(inputs[2]))
    outputs.append(MaxPooling2D()(inputs[1]))
    outputs.append(AveragePooling1D()(inputs[2]))

    # shared_axes lets PReLU share its slope over the variable dimensions.
    outputs.append(PReLU(shared_axes=[1, 2])(inputs[0]))
    outputs.append(PReLU(shared_axes=[1, 2])(inputs[1]))
    outputs.append(PReLU(shared_axes=[1, 2, 3])(inputs[1]))
    outputs.append(PReLU(shared_axes=[1])(inputs[2]))

    model = Model(inputs=inputs, outputs=outputs, name='test_model_variable')
    model.compile(loss='mse', optimizer='nadam')

    # fit to dummy data
    training_data_size = 2
    data_in = generate_input_data(training_data_size, input_shapes)
    initial_data_out = model.predict(data_in)
    data_out = generate_output_data(training_data_size, initial_data_out)
    model.fit(data_in, data_out, epochs=10)
    return model
741
742
def get_test_model_sequential():
    """Returns a typical (VGG-like) sequential test model.

    Conv/pool blocks followed by a small Dense classifier head; the Permute
    pair sandwiching the first pooling exercises channels-first pooling via
    layout transposition.
    """
    model = Sequential()
    model.add(Conv2D(8, (3, 3), activation='relu', input_shape=(32, 32, 3)))
    model.add(Conv2D(8, (3, 3), activation='relu'))
    # Transpose to channels-first, pool, transpose back.
    model.add(Permute((3, 1, 2)))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Permute((2, 3, 1)))
    model.add(Dropout(0.25))

    model.add(Conv2D(16, (3, 3), activation='elu'))
    model.add(Conv2D(16, (3, 3)))
    # ELU applied as a standalone layer instead of an activation argument.
    model.add(ELU())

    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.25))

    # Classifier head.
    model.add(Flatten())
    model.add(Dense(64, activation='sigmoid'))
    model.add(Dropout(0.5))
    model.add(Dense(10, activation='softmax'))

    model.compile(loss='categorical_crossentropy', optimizer='sgd')

    # fit to dummy data
    training_data_size = 2
    data_in = [np.random.random(size=(training_data_size, 32, 32, 3))]
    data_out = [np.random.random(size=(training_data_size, 10))]
    model.fit(data_in, data_out, epochs=10)
    return model
773
774
def get_test_model_lstm_stateful():
    """Returns a test model exercising stateful/non-stateful LSTMs,
    returned states, Bidirectional wrappers and explicit initial states.
    """
    stateful_batch_size = 1
    input_shapes = [
        (17, 4),
        (1, 10),
        (None, 4),
        (12,),
        (12,)
    ]

    # Stateful layers require a fixed batch size, hence batch_shape.
    inputs = [Input(batch_shape=(stateful_batch_size,) + shape)
              for shape in input_shapes]
    outputs = []
    for idx, branch_input in enumerate(inputs[:2]):
        # Alternate statefulness between layers/branches so both
        # settings of each layer variant are covered.
        flag = bool((idx + 1) % 2)
        sequences = LSTM(
            stateful=flag,
            units=8,
            recurrent_activation='relu',
            return_sequences=True,
            name='lstm_sequences_' + str(idx) + '_st-' + str(flag)
        )(branch_input)

        flag = bool(idx % 2)
        regular = LSTM(
            stateful=flag,
            units=3,
            recurrent_activation='sigmoid',
            return_sequences=False,
            name='lstm_regular_' + str(idx) + '_st-' + str(flag)
        )(sequences)
        outputs.append(regular)

        flag = bool((idx + 1) % 2)
        with_state, state_h, state_c = LSTM(
            stateful=flag,
            units=3,
            recurrent_activation='sigmoid',
            return_state=True,
            name='lstm_state_return_' + str(idx) + '_st-' + str(flag)
        )(branch_input)
        outputs.extend([with_state, state_h, state_c])

        flag = bool((idx + 1) % 2)
        bidi_sequences = Bidirectional(
            LSTM(
                stateful=flag,
                units=4,
                recurrent_activation='hard_sigmoid',
                return_sequences=True,
                name='bi-lstm1_' + str(idx) + '_st-' + str(flag)
            )
        )(branch_input)

        flag = bool(idx % 2)
        bidi = Bidirectional(
            LSTM(
                stateful=flag,
                units=6,
                recurrent_activation='linear',
                return_sequences=False,
                name='bi-lstm2_' + str(idx) + '_st-' + str(flag)
            )
        )(bidi_sequences)
        outputs.append(bidi)

    # LSTMs fed with explicit initial hidden/cell states (inputs[3]/[4]).
    outputs.extend(
        LSTM(units=12, return_sequences=True, stateful=True,
             return_state=True,
             name='initial_state_stateful')(inputs[2],
                                            initial_state=[inputs[3],
                                                           inputs[4]]))
    outputs.extend(
        LSTM(units=12, return_sequences=False, stateful=False,
             return_state=True,
             name='initial_state_not_stateful')(inputs[2],
                                                initial_state=[inputs[3],
                                                               inputs[4]]))

    model = Model(inputs=inputs, outputs=outputs)
    model.compile(loss='mean_squared_error', optimizer='nadam')

    # fit to dummy data
    training_data_size = stateful_batch_size
    data_in = generate_input_data(training_data_size, input_shapes)
    initial_data_out = model.predict(data_in)
    data_out = generate_output_data(training_data_size, initial_data_out)

    model.fit(data_in, data_out, batch_size=stateful_batch_size, epochs=10)
    return model
856
857
def main():
    """Generate different test models and save them to the given directory.

    Expects two command-line arguments: the model name (a key of
    ``get_model_functions`` below) and the destination file path.
    Exits with status 1 on wrong argument count, 2 on unknown model name.
    """
    # Guard clause instead of wrapping everything in an else-branch.
    if len(sys.argv) != 3:
        print('usage: [model name] [destination file path]')
        sys.exit(1)

    model_name = sys.argv[1]
    dest_path = sys.argv[2]

    get_model_functions = {
        'exhaustive': get_test_model_exhaustive,
        'embedding': get_test_model_embedding,
        'recurrent': get_test_model_recurrent,
        'lstm': get_test_model_lstm,
        'gru': get_test_model_gru,
        'variable': get_test_model_variable,
        'sequential': get_test_model_sequential,
        'lstm_stateful': get_test_model_lstm_stateful,
        'gru_stateful': get_test_model_gru_stateful
    }

    # Idiomatic membership test ("x not in d" rather than "not x in d").
    if model_name not in get_model_functions:
        print('unknown model name: ', model_name)
        sys.exit(2)

    # Fixed seed so generated models/weights are reproducible across runs.
    np.random.seed(0)

    model_func = get_model_functions[model_name]
    model = model_func()
    model.save(dest_path, include_optimizer=False)

    # Make sure models can be loaded again,
    # see https://github.com/fchollet/keras/issues/7682
    model = load_model(dest_path)
    model.summary()
    # plot_model(model, to_file= str(model_name) + '.png', show_shapes=True, show_layer_names=True)  #### DEBUG stateful
894
895
# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
898