1"""
2#This script demonstrates the use of a convolutional LSTM network.
3
4This network is used to predict the next frame of an artificially
5generated movie which contains moving squares.
6"""
7from keras.models import Sequential
8from keras.layers.convolutional import Conv3D
9from keras.layers.convolutional_recurrent import ConvLSTM2D
10from keras.layers.normalization import BatchNormalization
11import numpy as np
12import pylab as plt
13
# We create a model which takes as input movies of shape
# (n_frames, width, height, channels) and returns a movie
# of identical shape.
17
# Four stacked ConvLSTM2D layers (40 filters, 3x3 kernels, 'same' padding,
# full sequences returned), each followed by batch normalization, and a
# final single-filter Conv3D with a sigmoid that maps the features back to
# one channel per pixel. Only the first layer declares the input shape:
# a variable number of 40x40 frames with 1 channel.
seq = Sequential([
    ConvLSTM2D(filters=40, kernel_size=(3, 3),
               input_shape=(None, 40, 40, 1),
               padding='same', return_sequences=True),
    BatchNormalization(),
    ConvLSTM2D(filters=40, kernel_size=(3, 3),
               padding='same', return_sequences=True),
    BatchNormalization(),
    ConvLSTM2D(filters=40, kernel_size=(3, 3),
               padding='same', return_sequences=True),
    BatchNormalization(),
    ConvLSTM2D(filters=40, kernel_size=(3, 3),
               padding='same', return_sequences=True),
    BatchNormalization(),
    Conv3D(filters=1, kernel_size=(3, 3, 3),
           activation='sigmoid',
           padding='same', data_format='channels_last'),
])
# NOTE(review): Adadelta's default learning rate differs between Keras
# versions -- confirm training still converges on the version in use.
seq.compile(optimizer='adadelta', loss='binary_crossentropy')
40
41
# Artificial data generation:
# Generate movies with 3 to 7 moving squares inside.
# The squares are 4x4 or 6x6 pixels (half-width 2 or 3) and move
# linearly over time, at most one pixel per frame along each axis.
# For convenience we first create movies with bigger width and height (80x80)
# and at the end we select the central 40x40 window.
48
def _clamped_span(center, half_width):
    """Return ``slice(center - half_width, center + half_width)`` clamped at 0.

    Negative slice bounds count from the far end of the axis, so without the
    clamp a square that has drifted past the low edge of the canvas would be
    painted near the opposite edge instead of simply leaving the picture.
    """
    return slice(max(center - half_width, 0), max(center + half_width, 0))


def generate_movies(n_samples=1200, n_frames=15):
    """Generate toy movies of moving squares and their next-frame targets.

    Movies are drawn on an 80x80 canvas and cropped to the central 40x40
    window (rows/cols 20:60) at the end, so squares can enter and leave
    the visible area.

    Parameters
    ----------
    n_samples : int
        Number of movies to generate.
    n_frames : int
        Number of frames per movie.

    Returns
    -------
    noisy_movies : np.ndarray, shape (n_samples, n_frames, 40, 40, 1)
        Network inputs. Square pixels are 1; in randomly chosen frames a
        +/-0.1 offset is added around each square. Values are capped at 1.
    shifted_movies : np.ndarray, shape (n_samples, n_frames, 40, 40, 1)
        Noise-free targets of 0/1 values: ``shifted_movies[:, t]`` shows the
        squares at the positions they occupy in frame ``t + 1``.
    """
    row = 80
    col = 80
    shape = (n_samples, n_frames, row, col, 1)
    # np.float was only an alias of the builtin float (i.e. float64) and was
    # removed in NumPy 1.24; np.float64 is the same dtype.
    noisy_movies = np.zeros(shape, dtype=np.float64)
    shifted_movies = np.zeros(shape, dtype=np.float64)

    for i in range(n_samples):
        # Add 3 to 7 moving squares
        n = np.random.randint(3, 8)

        for _ in range(n):
            # Initial position
            xstart = np.random.randint(20, 60)
            ystart = np.random.randint(20, 60)
            # Direction of motion: -1, 0 or +1 pixel per frame on each axis
            directionx = np.random.randint(0, 3) - 1
            directiony = np.random.randint(0, 3) - 1

            # Half-width of the square (the square is 2*w pixels wide)
            w = np.random.randint(2, 4)

            for t in range(n_frames):
                x_shift = xstart + directionx * t
                y_shift = ystart + directiony * t
                noisy_movies[i, t,
                             _clamped_span(x_shift, w),
                             _clamped_span(y_shift, w), 0] += 1

                # Make it more robust by adding noise.
                # The idea is that if during inference,
                # the value of the pixel is not exactly one,
                # we need to train the network to be robust and still
                # consider it as a pixel belonging to a square.
                if np.random.randint(0, 2):
                    noise_f = (-1) ** np.random.randint(0, 2)
                    noisy_movies[i, t,
                                 _clamped_span(x_shift, w + 1),
                                 _clamped_span(y_shift, w + 1),
                                 0] += noise_f * 0.1

                # Shift the ground truth by 1
                x_shift = xstart + directionx * (t + 1)
                y_shift = ystart + directiony * (t + 1)
                shifted_movies[i, t,
                               _clamped_span(x_shift, w),
                               _clamped_span(y_shift, w), 0] += 1

    # Cut to a 40x40 window
    noisy_movies = noisy_movies[::, ::, 20:60, 20:60, ::]
    shifted_movies = shifted_movies[::, ::, 20:60, 20:60, ::]
    # Overlapping squares add up; cap every pixel at 1.
    noisy_movies[noisy_movies >= 1] = 1
    shifted_movies[shifted_movies >= 1] = 1
    return noisy_movies, shifted_movies
101
# Build 1200 movies and train on the first 1000 of them (5% of those are
# used for validation). Movies 1000-1199 are never seen during training.
noisy_movies, shifted_movies = generate_movies(n_samples=1200)
seq.fit(x=noisy_movies[:1000], y=shifted_movies[:1000],
        batch_size=10, epochs=300, validation_split=0.05)
106
# Testing the network on one held-out movie:
# feed it the first 7 frames, then repeatedly predict the next frame and
# append it until the track is as long as the original movie. Only these
# frames are compared against the ground truth below, so predicting any
# further would be wasted work.
which = 1004
n_seed_frames = 7
track = noisy_movies[which][:n_seed_frames, ::, ::, ::]

for _ in range(noisy_movies.shape[1] - n_seed_frames):
    new_pos = seq.predict(track[np.newaxis, ::, ::, ::, ::])
    # The last output frame is the model's guess for the frame that follows
    # the current end of `track` (the model was trained on next-frame targets).
    new = new_pos[::, -1, ::, ::, ::]
    track = np.concatenate((track, new), axis=0)
117
118
# And then compare the predictions
# to the ground truth, saving one side-by-side image per frame.
track2 = noisy_movies[which][::, ::, ::, ::]
for i in range(15):
    fig = plt.figure(figsize=(10, 5))

    # Left panel: the network input (frames 0-6) or its own predictions
    # (frames 7 onward).
    ax = fig.add_subplot(121)
    if i >= 7:
        ax.text(1, 3, 'Predictions !', fontsize=20, color='w')
    else:
        ax.text(1, 3, 'Initial trajectory', fontsize=20)
    ax.imshow(track[i, ::, ::, 0])

    # Right panel: ground truth. shifted_movies[t] is the noise-free frame
    # t + 1, so a clean version exists for every frame except frame 0, which
    # is only available as the (possibly noisy) input.
    ax = fig.add_subplot(122)
    ax.text(1, 3, 'Ground truth', fontsize=20)
    if i >= 1:
        toplot = shifted_movies[which][i - 1, ::, ::, 0]
    else:
        toplot = track2[i, ::, ::, 0]
    ax.imshow(toplot)

    fig.savefig('%i_animate.png' % (i + 1))
    # Release the figure; otherwise every iteration leaves one open.
    plt.close(fig)
144