1"""
# Trains a ResNet on the CIFAR10 dataset.
3
4ResNet v1:
5[Deep Residual Learning for Image Recognition
6](https://arxiv.org/pdf/1512.03385.pdf)
7
8ResNet v2:
9[Identity Mappings in Deep Residual Networks
10](https://arxiv.org/pdf/1603.05027.pdf)
11
12
13Model|n|200-epoch accuracy|Original paper accuracy |sec/epoch GTX1080Ti
14:------------|--:|-------:|-----------------------:|---:
15ResNet20   v1|  3| 92.16 %|                 91.25 %|35
16ResNet32   v1|  5| 92.46 %|                 92.49 %|50
17ResNet44   v1|  7| 92.50 %|                 92.83 %|70
18ResNet56   v1|  9| 92.71 %|                 93.03 %|90
19ResNet110  v1| 18| 92.65 %|            93.39+-.16 %|165
20ResNet164  v1| 27|     - %|                 94.07 %|  -
21ResNet1001 v1|N/A|     - %|                 92.39 %|  -
22
23 
24
25Model|n|200-epoch accuracy|Original paper accuracy |sec/epoch GTX1080Ti
26:------------|--:|-------:|-----------------------:|---:
27ResNet20   v2|  2|     - %|                     - %|---
28ResNet32   v2|N/A| NA    %|            NA         %| NA
29ResNet44   v2|N/A| NA    %|            NA         %| NA
30ResNet56   v2|  6| 93.01 %|            NA         %|100
31ResNet110  v2| 12| 93.15 %|            93.63      %|180
32ResNet164  v2| 18|     - %|            94.54      %|  -
33ResNet1001 v2|111|     - %|            95.08+-.14 %|  -
34"""
35
36from __future__ import print_function
37import keras
38from keras.layers import Dense, Conv2D, BatchNormalization, Activation
39from keras.layers import AveragePooling2D, Input, Flatten
40from keras.optimizers import Adam
41from keras.callbacks import ModelCheckpoint, LearningRateScheduler
42from keras.callbacks import ReduceLROnPlateau
43from keras.preprocessing.image import ImageDataGenerator
44from keras.regularizers import l2
45from keras import backend as K
46from keras.models import Model
47from keras.datasets import cifar10
48import numpy as np
49import os
50
# Training parameters
batch_size = 32  # orig paper trained all networks with batch_size=128
epochs = 200
data_augmentation = True
num_classes = 10

# Subtracting pixel mean improves accuracy
subtract_pixel_mean = True

# Model parameter
# ----------------------------------------------------------------------------
#           |      | 200-epoch | Orig Paper| 200-epoch | Orig Paper| sec/epoch
# Model     |  n   | ResNet v1 | ResNet v1 | ResNet v2 | ResNet v2 | GTX1080Ti
#           |v1(v2)| %Accuracy | %Accuracy | %Accuracy | %Accuracy | v1 (v2)
# ----------------------------------------------------------------------------
# ResNet20  | 3 (2)| 92.16     | 91.25     | -----     | -----     | 35 (---)
# ResNet32  | 5(NA)| 92.46     | 92.49     | NA        | NA        | 50 ( NA)
# ResNet44  | 7(NA)| 92.50     | 92.83     | NA        | NA        | 70 ( NA)
# ResNet56  | 9 (6)| 92.71     | 93.03     | 93.01     | NA        | 90 (100)
# ResNet110 |18(12)| 92.65     | 93.39+-.16| 93.15     | 93.63     | 165(180)
# ResNet164 |27(18)| -----     | 94.07     | -----     | 94.54     | ---(---)
# ResNet1001| (111)| -----     | 92.39     | -----     | 95.08+-.14| ---(---)
# ----------------------------------------------------------------------------
n = 3

# Model version
# Orig paper: version = 1 (ResNet v1), Improved ResNet: version = 2 (ResNet v2)
version = 1

# Computed depth from supplied model parameter n:
# v1 uses two 3x3 convs per residual unit (6n+2 layers over 3 stages),
# v2 uses three convs per bottleneck unit (9n+2 layers over 3 stages).
if version == 1:
    depth = n * 6 + 2
elif version == 2:
    depth = n * 9 + 2
else:
    # Without this guard an unsupported version leaves `depth` undefined and
    # the script fails later with an unrelated-looking NameError.
    raise ValueError('version should be 1 or 2, got %r' % (version,))

# Model name, depth and version
model_type = 'ResNet%dv%d' % (depth, version)

# Load the CIFAR10 data.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Input image dimensions.
input_shape = x_train.shape[1:]

# Normalize data.
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

# If subtract pixel mean is enabled, centre both splits with the *training*
# mean so the test set sees exactly the same transformation.
if subtract_pixel_mean:
    x_train_mean = np.mean(x_train, axis=0)
    x_train -= x_train_mean
    x_test -= x_train_mean

print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')
print('y_train shape:', y_train.shape)

# Convert class vectors to binary class matrices.
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
113
114
def lr_schedule(epoch):
    """Step-decay learning-rate schedule.

    Starts at 1e-3 and drops once the epoch index passes 80, 120, 160 and
    180. Intended to be invoked once per epoch through a
    LearningRateScheduler callback; it also prints the chosen rate.

    # Arguments
        epoch (int): 0-based index of the epoch about to start

    # Returns
        lr (float): learning rate to use for that epoch
    """
    base_lr = 1e-3
    # (boundary, multiplier) pairs, checked from the latest boundary down so
    # the first match is the most recent decay step already passed.
    decay_steps = ((180, 0.5e-3), (160, 1e-3), (120, 1e-2), (80, 1e-1))
    for boundary, multiplier in decay_steps:
        if epoch > boundary:
            lr = base_lr * multiplier
            break
    else:
        lr = base_lr
    print('Learning rate: ', lr)
    return lr
138
139
def resnet_layer(inputs,
                 num_filters=16,
                 kernel_size=3,
                 strides=1,
                 activation='relu',
                 batch_normalization=True,
                 conv_first=True):
    """Build one Conv2D / BatchNormalization / Activation unit.

    The convolution always uses 'same' padding, He-normal initialisation and
    an L2(1e-4) kernel penalty. Batch normalization and the activation are
    each optional; together they are placed either after the convolution
    (conv-bn-activation) or before it (bn-activation-conv).

    # Arguments
        inputs (tensor): tensor from the input image or the previous layer
        num_filters (int): number of Conv2D filters
        kernel_size (int): side length of the square Conv2D kernel
        strides (int): side length of the square Conv2D stride
        activation (string): activation name, or None to skip the activation
        batch_normalization (bool): whether to insert BatchNormalization
        conv_first (bool): conv-bn-activation (True) or
            bn-activation-conv (False)

    # Returns
        x (tensor): output tensor, to be fed to the next layer
    """
    convolution = Conv2D(num_filters,
                         kernel_size=kernel_size,
                         strides=strides,
                         padding='same',
                         kernel_initializer='he_normal',
                         kernel_regularizer=l2(1e-4))

    def normalize_and_activate(tensor):
        # Optional BN followed by optional activation, always in that order.
        if batch_normalization:
            tensor = BatchNormalization()(tensor)
        if activation is not None:
            tensor = Activation(activation)(tensor)
        return tensor

    if conv_first:
        return normalize_and_activate(convolution(inputs))
    # Pre-activation ordering (used by ResNet v2).
    return convolution(normalize_and_activate(inputs))
183
184
def resnet_v1(input_shape, depth, num_classes=10):
    """ResNet Version 1 Model builder [a]

    [a] He et al., "Deep Residual Learning for Image Recognition",
        https://arxiv.org/pdf/1512.03385.pdf

    Stacks of 2 x (3 x 3) Conv2D-BN-ReLU.
    The last ReLU of each residual unit is applied after the shortcut addition.
    At the beginning of every stage except the first, the feature map size is
    halved (downsampled) by a convolutional layer with strides=2, while the
    number of filters is doubled. Within each stage, the layers have the same
    number of filters and the same feature map sizes.
    Feature map sizes (for CIFAR10's 32x32 inputs):
    stage 0: 32x32, 16
    stage 1: 16x16, 32
    stage 2:  8x8,  64
    The number of parameters is approx the same as Table 6 of [a]:
    ResNet20 0.27M
    ResNet32 0.46M
    ResNet44 0.66M
    ResNet56 0.85M
    ResNet110 1.7M

    # Arguments
        input_shape (tuple): shape of a single input image (no batch axis)
        depth (int): number of core convolutional layers; must be 6n+2
            with n >= 1
        num_classes (int): number of classes (CIFAR10 has 10)

    # Returns
        model (Model): Keras model instance

    # Raises
        ValueError: if depth is not of the form 6n+2 with n >= 1
    """
    # depth == 2 (n == 0) also passes the modulo test but would build a
    # network with no residual units at all, so reject it explicitly.
    if depth < 8 or (depth - 2) % 6 != 0:
        raise ValueError('depth should be 6n+2 (eg 20, 32, 44 in [a])')
    # Start model definition.
    num_filters = 16
    num_res_blocks = (depth - 2) // 6

    inputs = Input(shape=input_shape)
    x = resnet_layer(inputs=inputs)
    # Instantiate the stack of residual units
    for stack in range(3):
        for res_block in range(num_res_blocks):
            # The first unit of every stack after the first downsamples.
            downsample = stack > 0 and res_block == 0
            strides = 2 if downsample else 1
            y = resnet_layer(inputs=x,
                             num_filters=num_filters,
                             strides=strides)
            y = resnet_layer(inputs=y,
                             num_filters=num_filters,
                             activation=None)
            if downsample:
                # linear projection residual shortcut connection to match
                # the changed spatial size and filter count of y
                x = resnet_layer(inputs=x,
                                 num_filters=num_filters,
                                 kernel_size=1,
                                 strides=strides,
                                 activation=None,
                                 batch_normalization=False)
            x = keras.layers.add([x, y])
            x = Activation('relu')(x)
        num_filters *= 2

    # Add classifier on top.
    # v1 does not use BN after last shortcut connection-ReLU
    x = AveragePooling2D(pool_size=8)(x)
    y = Flatten()(x)
    outputs = Dense(num_classes,
                    activation='softmax',
                    kernel_initializer='he_normal')(y)

    # Instantiate model.
    model = Model(inputs=inputs, outputs=outputs)
    return model
257
258
def resnet_v2(input_shape, depth, num_classes=10):
    """ResNet Version 2 Model builder [b]

    [b] He et al., "Identity Mappings in Deep Residual Networks",
        https://arxiv.org/pdf/1603.05027.pdf

    Stacks of (1 x 1)-(3 x 3)-(1 x 1) BN-ReLU-Conv2D units, also known as
    bottleneck layers.
    The first shortcut connection in each stage is a 1 x 1 Conv2D projection;
    the remaining shortcut connections in that stage are identity.
    At the beginning of every stage except the first, the feature map size is
    halved (downsampled) by a convolutional layer with strides=2, while the
    number of filter maps is doubled. Within each stage, the layers have the
    same number of filters and the same feature map sizes.
    Feature map sizes (for CIFAR10's 32x32 inputs):
    conv1  : 32x32,  16
    stage 0: 32x32,  64
    stage 1: 16x16, 128
    stage 2:  8x8,  256

    # Arguments
        input_shape (tuple): shape of a single input image (no batch axis)
        depth (int): number of core convolutional layers; must be 9n+2
            with n >= 1
        num_classes (int): number of classes (CIFAR10 has 10)

    # Returns
        model (Model): Keras model instance

    # Raises
        ValueError: if depth is not of the form 9n+2 with n >= 1
    """
    # depth == 2 (n == 0) also passes the modulo test, but with zero units
    # per stage `num_filters_out` is never assigned and the stage loop below
    # would crash with UnboundLocalError; reject it up front instead.
    if depth < 11 or (depth - 2) % 9 != 0:
        raise ValueError('depth should be 9n+2 (eg 56 or 110 in [b])')
    # Start model definition.
    num_filters_in = 16
    num_res_blocks = (depth - 2) // 9

    inputs = Input(shape=input_shape)
    # v2 performs Conv2D with BN-ReLU on input before splitting into 2 paths
    x = resnet_layer(inputs=inputs,
                     num_filters=num_filters_in,
                     conv_first=True)

    # Instantiate the stack of residual units
    for stage in range(3):
        for res_block in range(num_res_blocks):
            activation = 'relu'
            batch_normalization = True
            strides = 1
            if stage == 0:
                num_filters_out = num_filters_in * 4
                if res_block == 0:  # first layer and first stage
                    # The stem conv above already ended in BN-ReLU, so the
                    # first pre-activation is skipped rather than repeated.
                    activation = None
                    batch_normalization = False
            else:
                num_filters_out = num_filters_in * 2
                if res_block == 0:  # first layer but not first stage
                    strides = 2    # downsample

            # bottleneck residual unit
            y = resnet_layer(inputs=x,
                             num_filters=num_filters_in,
                             kernel_size=1,
                             strides=strides,
                             activation=activation,
                             batch_normalization=batch_normalization,
                             conv_first=False)
            y = resnet_layer(inputs=y,
                             num_filters=num_filters_in,
                             conv_first=False)
            y = resnet_layer(inputs=y,
                             num_filters=num_filters_out,
                             kernel_size=1,
                             conv_first=False)
            if res_block == 0:
                # linear projection residual shortcut connection to match
                # the changed filter count (and, after stage 0, spatial size)
                x = resnet_layer(inputs=x,
                                 num_filters=num_filters_out,
                                 kernel_size=1,
                                 strides=strides,
                                 activation=None,
                                 batch_normalization=False)
            x = keras.layers.add([x, y])

        num_filters_in = num_filters_out

    # Add classifier on top.
    # v2 has BN-ReLU before Pooling
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = AveragePooling2D(pool_size=8)(x)
    y = Flatten()(x)
    outputs = Dense(num_classes,
                    activation='softmax',
                    kernel_initializer='he_normal')(y)

    # Instantiate model.
    model = Model(inputs=inputs, outputs=outputs)
    return model
353
354
if version == 2:
    model = resnet_v2(input_shape=input_shape, depth=depth)
else:
    model = resnet_v1(input_shape=input_shape, depth=depth)

model.compile(loss='categorical_crossentropy',
              optimizer=Adam(learning_rate=lr_schedule(0)),
              metrics=['accuracy'])
model.summary()
print(model_type)

# Prepare model saving directory.
save_dir = os.path.join(os.getcwd(), 'saved_models')
model_name = 'cifar10_%s_model.{epoch:03d}.h5' % model_type
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)
filepath = os.path.join(save_dir, model_name)

# Prepare callbacks for model saving and for learning rate adjustment.
# With metrics=['accuracy'], Keras >= 2.3 (already required here by Adam's
# `learning_rate` argument) logs validation accuracy as 'val_accuracy'.
# Monitoring the old 'val_acc' key matches nothing, so no checkpoint would
# ever be written.
checkpoint = ModelCheckpoint(filepath=filepath,
                             monitor='val_accuracy',
                             mode='max',
                             verbose=1,
                             save_best_only=True)

lr_scheduler = LearningRateScheduler(lr_schedule)

# NOTE(review): LearningRateScheduler resets the rate to lr_schedule(epoch)
# at the start of every epoch, so a reduction applied by ReduceLROnPlateau at
# the end of an epoch is overwritten on the next one. Kept for parity with
# the original configuration.
lr_reducer = ReduceLROnPlateau(factor=np.sqrt(0.1),
                               cooldown=0,
                               patience=5,
                               min_lr=0.5e-6)

callbacks = [checkpoint, lr_reducer, lr_scheduler]

# Run training, with or without data augmentation.
if not data_augmentation:
    print('Not using data augmentation.')
    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              validation_data=(x_test, y_test),
              shuffle=True,
              callbacks=callbacks)
else:
    print('Using real-time data augmentation.')
    # This will do preprocessing and realtime data augmentation:
    datagen = ImageDataGenerator(
        # set input mean to 0 over the dataset
        featurewise_center=False,
        # set each sample mean to 0
        samplewise_center=False,
        # divide inputs by std of dataset
        featurewise_std_normalization=False,
        # divide each input by its std
        samplewise_std_normalization=False,
        # apply ZCA whitening
        zca_whitening=False,
        # epsilon for ZCA whitening
        zca_epsilon=1e-06,
        # randomly rotate images in the range (deg 0 to 180)
        rotation_range=0,
        # randomly shift images horizontally
        width_shift_range=0.1,
        # randomly shift images vertically
        height_shift_range=0.1,
        # set range for random shear
        shear_range=0.,
        # set range for random zoom
        zoom_range=0.,
        # set range for random channel shifts
        channel_shift_range=0.,
        # set mode for filling points outside the input boundaries
        fill_mode='nearest',
        # value used for fill_mode = "constant"
        cval=0.,
        # randomly flip images horizontally
        horizontal_flip=True,
        # randomly flip images vertically
        vertical_flip=False,
        # set rescaling factor (applied before any other transformation)
        rescale=None,
        # set function that will be applied on each input
        preprocessing_function=None,
        # image data format, either "channels_first" or "channels_last"
        data_format=None,
        # fraction of images reserved for validation (strictly between 0 and 1)
        validation_split=0.0)

    # Compute quantities required for featurewise normalization
    # (std, mean, and principal components if ZCA whitening is applied).
    datagen.fit(x_train)

    # Fit the model on the batches generated by datagen.flow().
    # Model.fit accepts this iterator directly; fit_generator is deprecated
    # and has been removed from newer Keras releases.
    model.fit(datagen.flow(x_train, y_train, batch_size=batch_size),
              validation_data=(x_test, y_test),
              epochs=epochs, verbose=1, workers=4,
              callbacks=callbacks)

# Score trained model (weights from the final epoch, not the best checkpoint).
scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])
456