1#!/usr/bin/env python3
"""Save the application models listed in the Keras documentation
and convert each of them to frugally-deep's JSON format.
"""
4
5import convert_model
6import tensorflow as tf
7
8__author__ = "Tobias Hermann"
9__copyright__ = "Copyright 2017, Tobias Hermann"
10__license__ = "MIT"
11__maintainer__ = "Tobias Hermann, https://github.com/Dobiasd/frugally-deep"
12__email__ = "editgym@gmail.com"
13
14
def save_model(file_name_base, model):
    """Write `model` to `<file_name_base>.h5`, then convert it to `<file_name_base>.json`.

    The Keras file is written without optimizer state and is then handed to
    `convert_model.convert`, which produces the frugally-deep JSON file.
    Progress messages are printed before and after each step.
    """
    h5_path = f'{file_name_base}.h5'
    json_path = f'{file_name_base}.json'

    # Step 1: persist the Keras model (optimizer state is not needed for inference).
    print(f'Saving {h5_path}')
    model.save(h5_path, include_optimizer=False)

    # Step 2: translate the saved Keras file into the frugally-deep format.
    print(f'Converting {h5_path} to {json_path}.')
    convert_model.convert(h5_path, json_path)
    print(f'Conversion of model {h5_path} to {json_path} done.')
24
25
def main():
    """Save famous example models in Keras-h5 and fdeep-json format.

    Every entry of the table below is instantiated, saved and converted via
    `save_model`, one model at a time and in table order. After each model
    the global Keras state is cleared with `tf.keras.backend.clear_session()`.
    Without that, memory keeps growing while many (partly very large)
    application models are built in the same process.
    """
    print('Saving application examples')
    apps = tf.keras.applications
    # Each entry: (output file name base, model constructor, constructor kwargs).
    # Excluded: inceptionresnetv2 (InceptionResNetV2, input_shape=(299, 299, 3)).
    # The model contains Lambda layers, which are presumably not supported
    # by convert_model.
    models_to_save = [
        ('densenet121', apps.densenet.DenseNet121, {}),
        ('densenet169', apps.densenet.DenseNet169, {}),
        ('densenet201', apps.densenet.DenseNet201, {}),
        ('inceptionv3', apps.inception_v3.InceptionV3, {'input_shape': (299, 299, 3)}),
        ('mobilenet', apps.mobilenet.MobileNet, {}),
        ('mobilenetv2', apps.mobilenet_v2.MobileNetV2, {}),
        ('nasnetlarge', apps.nasnet.NASNetLarge, {'input_shape': (331, 331, 3)}),
        ('nasnetmobile', apps.nasnet.NASNetMobile, {'input_shape': (224, 224, 3)}),
        ('resnet101', apps.ResNet101, {}),
        ('resnet101v2', apps.ResNet101V2, {}),
        ('resnet152', apps.ResNet152, {}),
        ('resnet152v2', apps.ResNet152V2, {}),
        ('resnet50', apps.ResNet50, {}),
        ('resnet50v2', apps.ResNet50V2, {}),
        ('vgg16', apps.vgg16.VGG16, {}),
        ('vgg19', apps.vgg19.VGG19, {}),
        ('xception', apps.xception.Xception, {'input_shape': (299, 299, 3)}),
    ]
    for file_name_base, constructor, kwargs in models_to_save:
        save_model(file_name_base, constructor(**kwargs))
        # Keras accumulates global state for every model it builds; release it
        # before constructing the next (possibly very large) model.
        tf.keras.backend.clear_session()
47
48
if __name__ == '__main__':
    # Run the export only when executed as a script, not when imported.
    main()
51