1'''
# Train a simple deep CNN on the CIFAR10 small images dataset.
3
4It gets to 75% validation accuracy in 25 epochs, and 79% after 50 epochs.
5(it's still underfitting at that point, though).
6'''
7
8from __future__ import print_function
9import keras
10from keras.datasets import cifar10
11from keras.preprocessing.image import ImageDataGenerator
12from keras.models import Sequential
13from keras.layers import Dense, Dropout, Activation, Flatten
14from keras.layers import Conv2D, MaxPooling2D
15import os
16
# --- Training configuration -------------------------------------------------
num_classes = 10            # number of target classes (width of the softmax)
batch_size = 32             # samples per gradient update
epochs = 100                # passes over the training data
data_augmentation = True    # True: train on batches from ImageDataGenerator
num_predictions = 20        # not referenced in the visible part of the script

# --- Output location --------------------------------------------------------
# The trained model is written to <current working dir>/saved_models/<model_name>.
model_name = 'keras_cifar10_trained_model.h5'
save_dir = os.path.join(os.getcwd(), 'saved_models')
24
# Load CIFAR10, which arrives already divided into a training and a test set.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Report the dataset dimensions; the first axis indexes the samples.
print('x_train shape:', x_train.shape)
print(len(x_train), 'train samples')
print(len(x_test), 'test samples')

# One-hot encode the class labels of both splits (num_classes columns each).
y_train, y_test = (
    keras.utils.to_categorical(labels, num_classes)
    for labels in (y_train, y_test)
)
34
# Network layout: two convolutional stages (32 filters, then 64 filters), each
# made of conv -> relu -> conv -> relu -> max-pool -> dropout, followed by a
# 512-unit dense layer and a softmax over num_classes outputs.
model = Sequential([
    # Stage 1: 32 filters. The input shape is taken from the training images
    # with the leading sample axis dropped.
    Conv2D(32, (3, 3), padding='same', input_shape=x_train.shape[1:]),
    Activation('relu'),
    Conv2D(32, (3, 3)),
    Activation('relu'),
    MaxPooling2D(pool_size=(2, 2)),
    Dropout(0.25),

    # Stage 2: 64 filters.
    Conv2D(64, (3, 3), padding='same'),
    Activation('relu'),
    Conv2D(64, (3, 3)),
    Activation('relu'),
    MaxPooling2D(pool_size=(2, 2)),
    Dropout(0.25),

    # Classifier head.
    Flatten(),
    Dense(512),
    Activation('relu'),
    Dropout(0.5),
    Dense(num_classes),
    Activation('softmax'),
])
57
# Optimizer: RMSprop with learning rate 1e-4 and a learning-rate decay of 1e-6.
opt = keras.optimizers.RMSprop(decay=1e-6, learning_rate=0.0001)

# Configure training: categorical cross-entropy matches the one-hot labels and
# the softmax output; accuracy is reported alongside the loss.
model.compile(optimizer=opt,
              loss='categorical_crossentropy',
              metrics=['accuracy'])
65
# Cast the images to float32 and divide by 255
# (presumably mapping 8-bit pixel values into [0, 1] -- confirm against the dataset).
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255
70
# Train the model, either on batches augmented on the fly or on the raw arrays.
# In both branches the test split doubles as the validation set.
if data_augmentation:
    print('Using real-time data augmentation.')
    # This will do preprocessing and realtime data augmentation. Only small
    # horizontal/vertical shifts and horizontal flips are actually enabled;
    # every other option is spelled out at its "off" value for easy tweaking.
    datagen = ImageDataGenerator(
        featurewise_center=False,  # set input mean to 0 over the dataset
        samplewise_center=False,  # set each sample mean to 0
        featurewise_std_normalization=False,  # divide inputs by std of the dataset
        samplewise_std_normalization=False,  # divide each input by its std
        zca_whitening=False,  # apply ZCA whitening
        zca_epsilon=1e-06,  # epsilon for ZCA whitening
        rotation_range=0,  # randomly rotate images in the range (degrees, 0 to 180)
        # randomly shift images horizontally (fraction of total width)
        width_shift_range=0.1,
        # randomly shift images vertically (fraction of total height)
        height_shift_range=0.1,
        shear_range=0.,  # set range for random shear
        zoom_range=0.,  # set range for random zoom
        channel_shift_range=0.,  # set range for random channel shifts
        # set mode for filling points outside the input boundaries
        fill_mode='nearest',
        cval=0.,  # value used for fill_mode = "constant"
        horizontal_flip=True,  # randomly flip images horizontally
        vertical_flip=False,  # randomly flip images vertically
        # set rescaling factor (applied before any other transformation)
        rescale=None,
        # set function that will be applied on each input
        preprocessing_function=None,
        # image data format, either "channels_first" or "channels_last"
        data_format=None,
        # fraction of images reserved for validation (strictly between 0 and 1)
        validation_split=0.0)

    # Compute quantities required for feature-wise normalization
    # (std, mean, and principal components if ZCA whitening is applied).
    # With the flags above all disabled this has no effect on the batches;
    # it is kept so that enabling one of those options works without
    # further changes.
    datagen.fit(x_train)

    # Fit the model on the batches generated by datagen.flow().
    # Model.fit accepts the generator directly; fit_generator is deprecated
    # in favour of it and takes the same arguments used here.
    model.fit(datagen.flow(x_train, y_train,
                           batch_size=batch_size),
              epochs=epochs,
              validation_data=(x_test, y_test),
              workers=4)
else:
    print('Not using data augmentation.')
    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              validation_data=(x_test, y_test),
              shuffle=True)
120
# Persist the trained model (architecture and weights) under save_dir,
# creating that directory first when it is missing.
model_path = os.path.join(save_dir, model_name)
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)
model.save(model_path)
print('Saved trained model at %s ' % model_path)
127
# Evaluate the trained model on the test split. With the single 'accuracy'
# metric configured at compile time, the result holds the loss first and
# the accuracy second.
scores = model.evaluate(x_test, y_test, verbose=1)
test_loss, test_accuracy = scores[0], scores[1]
print('Test loss:', test_loss)
print('Test accuracy:', test_accuracy)
132