"""
Convolutional LSTM demonstration script.

The network built here is trained to predict the next frame of an
artificially generated movie which contains moving squares.
"""
from keras.models import Sequential
from keras.layers.convolutional import Conv3D
from keras.layers.convolutional_recurrent import ConvLSTM2D
from keras.layers.normalization import BatchNormalization
import numpy as np
import pylab as plt

# The model takes movies of shape (n_frames, width, height, channels)
# as input and returns a movie of identical shape.

seq = Sequential()

# Four identical ConvLSTM2D + BatchNormalization stages. Only the first
# stage declares the input shape: an unspecified number of 40x40
# single-channel frames. return_sequences=True makes every stage emit
# one output per input frame, so the frame axis is preserved.
for stage in range(4):
    first_stage_kwargs = (
        {'input_shape': (None, 40, 40, 1)} if stage == 0 else {}
    )
    seq.add(ConvLSTM2D(filters=40, kernel_size=(3, 3),
                       padding='same', return_sequences=True,
                       **first_stage_kwargs))
    seq.add(BatchNormalization())

# A final 3D convolution maps the 40 feature maps back to a single
# channel; the sigmoid keeps every output pixel in [0, 1], which matches
# the binary cross-entropy loss used below.
seq.add(Conv3D(filters=1, kernel_size=(3, 3, 3),
               activation='sigmoid',
               padding='same', data_format='channels_last'))
seq.compile(loss='binary_crossentropy', optimizer='adadelta')


# Artificial data generation:
# Generate movies with 3 to 7 moving squares inside.
# The squares have a side of 4 or 6 pixels (half-width 2 or 3)
# and move linearly over time.
# For convenience we first create movies with bigger width and height (80x80)
# and at the end we select a 40x40 window.
def generate_movies(n_samples=1200, n_frames=15):
    """Generate random movies of moving squares and their next-frame targets.

    Each movie is drawn on an 80x80 canvas and contains 3 to 7 squares
    (half-width 2 or 3 pixels) that each move by -1, 0 or +1 pixel per
    frame along both axes. The central 40x40 window is returned.

    Args:
        n_samples: number of movies to generate.
        n_frames: number of frames per movie.

    Returns:
        A tuple ``(noisy_movies, shifted_movies)`` of float64 arrays of
        shape ``(n_samples, n_frames, 40, 40, 1)``.

        * ``noisy_movies`` holds the input frames; squares are randomly
          surrounded/overlaid by a +/-0.1 perturbation.
        * ``shifted_movies`` holds the noise-free frames one time step
          ahead (frame ``t`` shows the squares at time ``t + 1``).

        Values greater than or equal to 1 are clipped to 1 in both arrays.

    Note:
        Positions are not bounds-checked against the 80x80 canvas. The
        default of 15 frames keeps every square inside it; considerably
        longer movies can move squares past the border, where the slice
        assignments silently cover fewer (or no) pixels.
    """
    row = 80
    col = 80
    # ``np.float`` was a plain alias of the builtin ``float`` and has been
    # removed from NumPy; ``np.float64`` is the identical dtype and works
    # on every NumPy version.
    noisy_movies = np.zeros((n_samples, n_frames, row, col, 1),
                            dtype=np.float64)
    shifted_movies = np.zeros((n_samples, n_frames, row, col, 1),
                              dtype=np.float64)

    for i in range(n_samples):
        # Add 3 to 7 moving squares
        n = np.random.randint(3, 8)

        for _ in range(n):
            # Initial position
            xstart = np.random.randint(20, 60)
            ystart = np.random.randint(20, 60)
            # Direction of motion: -1, 0 or +1 pixel per frame on each axis
            directionx = np.random.randint(0, 3) - 1
            directiony = np.random.randint(0, 3) - 1

            # Half-width of the square (2 or 3 pixels)
            w = np.random.randint(2, 4)

            for t in range(n_frames):
                x_shift = xstart + directionx * t
                y_shift = ystart + directiony * t
                noisy_movies[i, t, x_shift - w: x_shift + w,
                             y_shift - w: y_shift + w, 0] += 1

                # Make it more robust by adding noise.
                # The idea is that if during inference,
                # the value of the pixel is not exactly one,
                # we need to train the network to be robust and still
                # consider it as a pixel belonging to a square.
                if np.random.randint(0, 2):
                    # Random sign: +1 or -1
                    noise_f = (-1)**np.random.randint(0, 2)
                    # The perturbed area is one pixel wider than the
                    # square on every side.
                    noisy_movies[i, t,
                                 x_shift - w - 1: x_shift + w + 1,
                                 y_shift - w - 1: y_shift + w + 1,
                                 0] += noise_f * 0.1

                # Shift the ground truth by 1
                x_shift = xstart + directionx * (t + 1)
                y_shift = ystart + directiony * (t + 1)
                shifted_movies[i, t, x_shift - w: x_shift + w,
                               y_shift - w: y_shift + w, 0] += 1

    # Cut to a 40x40 window
    noisy_movies = noisy_movies[::, ::, 20:60, 20:60, ::]
    shifted_movies = shifted_movies[::, ::, 20:60, 20:60, ::]
    # Overlapping squares add up to more than 1; clip them back to 1.
    noisy_movies[noisy_movies >= 1] = 1
    shifted_movies[shifted_movies >= 1] = 1
    return noisy_movies, shifted_movies


# Train the network
noisy_movies, shifted_movies = generate_movies(n_samples=1200)
seq.fit(noisy_movies[:1000], shifted_movies[:1000], batch_size=10,
        epochs=300, validation_split=0.05)

# Testing the network on one movie
# feed it with the first 7 positions and then
# predict the new positions
which = 1004  # index >= 1000, i.e. a movie that was not used for fitting
track = noisy_movies[which][:7, ::, ::, ::]

for _ in range(16):
    # Predict on the whole track so far and append only the newest frame,
    # so each prediction is fed back as input for the next one.
    new_pos = seq.predict(track[np.newaxis, ::, ::, ::, ::])
    new = new_pos[::, -1, ::, ::, ::]
    track = np.concatenate((track, new), axis=0)


# And then compare the predictions
# to the ground truth
track2 = noisy_movies[which][::, ::, ::, ::]
for i in range(15):
    fig = plt.figure(figsize=(10, 5))

    # Left panel: the 7 seed frames, then the network's own predictions.
    ax = fig.add_subplot(121)

    if i >= 7:
        ax.text(1, 3, 'Predictions !', fontsize=20, color='w')
    else:
        ax.text(1, 3, 'Initial trajectory', fontsize=20)

    toplot = track[i, ::, ::, 0]

    plt.imshow(toplot)

    # Right panel: the ground truth for the same time step.
    ax = fig.add_subplot(122)
    plt.text(1, 3, 'Ground truth', fontsize=20)

    toplot = track2[i, ::, ::, 0]
    if i >= 2:
        # shifted_movies is one step ahead, so index i - 1 shows time i
        # without the training noise.
        toplot = shifted_movies[which][i - 1, ::, ::, 0]

    plt.imshow(toplot)
    plt.savefig('%i_animate.png' % (i + 1))
    # Release the figure once it is on disk; otherwise all 15 figures
    # stay open until the script exits.
    plt.close(fig)