1import json
2import os
3import argparse
4
# Buffer holding all generated markdown; add_code appends to it and
# save_code writes it to disk.
markdown_code = ""

# Frameworks covered by the collection ('keras' is not supported yet).
framework_list = [
    'caffe',
    'cntk',
    'coreml',
    'darknet',
    'mxnet',
    'pytorch',
    'tensorflow',
]

# Per framework: file kind found in a model's 'link' dict -> text used as the
# label of the download link (GenerateModelBlock_v2). 'keras' is not
# supported yet.
frame_model_map = {
    'caffe': {'architecture': 'prototxt', 'weights': 'caffemodel'},
    'cntk': {'architecture': 'model'},
    'coreml': {'architecture': 'mlmodel'},
    'darknet': {'architecture': 'cfg', 'weights': 'weights'},
    'mxnet': {'architecture': 'json', 'weights': 'params'},
    'pytorch': {'architecture': 'pth'},
    'tensorflow': {'architecture': 'tgz'},
}

# Dataset names pre-seeded (in this order) by RegenerateJsonByDataset.
dataset_list = ['imagenet', 'imagenet11k', 'Pascal VOC', 'grocery100']
18
def add_code(code):
    """Append the string *code* to the module-level ``markdown_code`` buffer."""
    global markdown_code
    markdown_code = markdown_code + code
22
def add_header(level, code):
    """Append a markdown ATX header: *level* '#' characters, a space, the
    header text *code*, then a blank line."""
    hashes = "#" * level
    add_code("".join((hashes, " ", code, "\n\n")))
25
def draw_line(num):
    """Start a *num*-column markdown table.

    Emits an empty header row (``| | ... |``) followed by the separator row
    (``|-|- ... |``), each terminated by a newline.
    """
    empty_cells = "| " * num
    add_code(empty_cells + "|\n")
    separators = "|-" * num
    add_code(separators + "|\n")
29
def save_code(filepath):
    """Write the accumulated ``markdown_code`` buffer to *filepath* and
    print a success message.

    The file is written as UTF-8 explicitly, the same encoding LoadJson uses
    to read the model map. Without it, non-ASCII model names or URLs copied
    from the JSON depend on the platform's locale encoding and can fail to
    encode (e.g. on Windows).
    """
    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(markdown_code)
    print("Markdown generate succeeded!")
34
def LoadJson(json_path):
    """Read the UTF-8 encoded JSON file at *json_path* and return the
    decoded Python object."""
    with open(json_path, encoding='utf-8') as json_file:
        return json.load(json_file)
39
def RegenerateJsonByDataset(data):
    """Regroup the flat model list in *data* by dataset name.

    Args:
        data: parsed model-map JSON. Reads ``data['models']``, each entry of
            which must provide 'name', 'framework', 'source', 'link' and
            'dataset'.

    Returns:
        ``{'dataset': {dataset_name: [item, ...]}}``. Every name in
        ``dataset_list`` is present (with an empty list if no model uses
        it), in ``dataset_list`` order. Any other dataset names found in the
        input follow in first-seen order. Each item holds the model's
        'name', 'framework', 'source' and 'link', plus an empty 'version'.
    """
    # Pre-seed the known datasets so they always exist, even when empty, and
    # keep the order given by dataset_list.
    by_dataset = {name: [] for name in dataset_list}
    for model in data['models']:
        # setdefault: a dataset name missing from dataset_list used to raise
        # KeyError and abort the whole generation; it now gets its own bucket.
        by_dataset.setdefault(model['dataset'], []).append({
            'name': model['name'],
            'framework': model['framework'],
            'source': model['source'],
            'link': model['link'],
            'version': "",
        })
    return {'dataset': by_dataset}
58
def GenerateModelBlock_v2(model):
    """Append one markdown table cell describing *model*.

    The cell begins with '|' and contains:
      - the model name in bold;
      - its framework;
      - one download link per non-empty entry of ``model['link']``;
      - a "[Link]" to ``model['source']`` when that is non-empty.

    Args:
        model: dict with 'name', 'framework', 'link' (file kind -> URL) and
            'source' keys, as built by RegenerateJsonByDataset.
    """
    framework = model['framework']
    # Label shown for each file kind (e.g. 'weights' -> 'caffemodel'). A
    # framework or kind not listed in frame_model_map falls back to the raw
    # kind name instead of aborting generation with a KeyError.
    labels = frame_model_map.get(framework, {})

    # generate markdown script
    add_code('''|<b>{}</b><br />Framework: {}<br />Download: '''.format(
        model['name'],
        framework
    ))
    for kind, url in model['link'].items():
        if url:
            add_code("[{}]({}) ".format(labels.get(kind, kind), url))
    add_code("<br />Source: ")
    if model['source'] != "":
        add_code("[Link]({})".format(model['source']))
    add_code("<br />")
78
def DrawTableBlock(data, dataset_name):
    """Emit a level-3 header and a 3-column markdown table for one dataset.

    Models from ``data['dataset'][dataset_name]`` are rendered in order,
    one cell each, skipping keras models and models whose
    ``link['architecture']`` is empty. A newline closes the row after every
    third rendered cell, and one more newline is emitted at the end.
    """
    columns = 3
    add_header(3, dataset_name)
    draw_line(columns)
    shown = (
        entry for entry in data['dataset'][dataset_name]
        if entry['framework'] != 'keras' and entry['link']['architecture'] != ""
    )
    for position, entry in enumerate(shown, start=1):
        GenerateModelBlock_v2(entry)
        if not position % columns:
            add_code("\n")
    add_code("\n")
92
def GenerateModelsList_v2(data):
    """Emit the "Model Collection" section.

    One level-2 sub-section per task, each containing one table per dataset
    (drawn by DrawTableBlock), followed by a trailing newline.
    """
    sections = (
        ("Image Classification", ('imagenet', 'imagenet11k')),
        ("Object Detection", ('Pascal VOC', 'grocery100')),
    )

    add_header(1, "Model Collection")
    for title, dataset_names in sections:
        add_header(2, title)
        for ds_name in dataset_names:
            DrawTableBlock(data, ds_name)

    add_code("\n")
108
def GenerateIntroductionAndTutorial():
    """Append the "Introduction" section and the "Steps to Convert Model"
    tutorial (vgg19: Tensorflow -> IR -> PyTorch) to the markdown buffer."""
    # MMdnn introduction
    add_header(1, "Introduction")
    text_intro = '''This is a collection of pre-trained models in different deep learning frameworks.\n
You can download the model you want by simply clicking the download link.\n
With the downloaded model, you can convert it to different frameworks.\n
The next section shows an example of how to convert a pre-trained model between frameworks.\n\n'''
    add_code(text_intro)

    # Steps for model conversion. The example title must name the same target
    # framework that steps 4 and 5 actually convert to (PyTorch).
    # The ckpt globs are wrapped in code spans so the '*' characters cannot be
    # parsed as markdown emphasis.
    add_header(2, "Steps to Convert Model")
    text_example = '''**Example: Convert vgg19 model from Tensorflow to PyTorch**\n
1. Install the stable version of MMdnn
    ```bash
    pip install mmdnn
    ```
2. Download the Tensorflow pre-trained model
    - [x] **Method 1:** Download directly from the model collection below
    - [x] **Method 2:** Use the command line
    ```bash
        $ mmdownload -f tensorflow -n vgg19

        Downloading file [./vgg_19_2016_08_28.tar.gz] from [http://download.tensorflow.org/models/vgg_19_2016_08_28.tar.gz]
        progress: 520592.0 KB downloaded, 100%
        Model saved in file: ./imagenet_vgg19.ckpt
    ```
    **NOTICE:** _the model name after the **'-n'** argument must be one of the model names listed in the model collection below._

3. Convert the model architecture (`*.ckpt.meta`) and weights (`*.ckpt`) from Tensorflow to IR
    ```bash
    $ mmtoir -f tensorflow -d vgg19 -n imagenet_vgg19.ckpt.meta -w imagenet_vgg19.ckpt  --dstNodeName MMdnn_Output

    Parse file [imagenet_vgg19.ckpt.meta] with binary format successfully.
    Tensorflow model file [imagenet_vgg19.ckpt.meta] loaded successfully.
    Tensorflow checkpoint file [imagenet_vgg19.ckpt] loaded successfully. [38] variables loaded.
    IR network structure is saved as [vgg19.json].
    IR network structure is saved as [vgg19.pb].
    IR weights are saved as [vgg19.npy].
    ```
4. Convert the model from IR to a PyTorch code snippet and weights
    ```bash
    $ mmtocode -f pytorch -n vgg19.pb --IRWeightPath vgg19.npy --dstModelPath pytorch_vgg19.py -dw pytorch_vgg19.npy

    Parse file [vgg19.pb] with binary format successfully.
    Target network code snippet is saved as [pytorch_vgg19.py].
    Target weights are saved as [pytorch_vgg19.npy].
    ```
5. Generate the PyTorch model from the code snippet file and the weight file
    ```bash
    $ mmtomodel -f pytorch -in pytorch_vgg19.py -iw pytorch_vgg19.npy --o pytorch_vgg19.pth

    PyTorch model file is saved as [pytorch_vgg19.pth], generated by [pytorch_vgg19.py] and [pytorch_vgg19.npy].
    Notice that you may need [pytorch_vgg19.py] to load the model back.
    ```
'''
    add_code(text_example)
    add_code("\n\n")
166
def main():
    """Command-line entry point.

    Builds the markdown document in this order:
      1. the introduction/tutorial;
      2. the per-dataset model tables, from the JSON given by -f/--file;
      3. writes the result to -d/--distFile.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('-f', '--file', type=str, default="modelmap2.json",
                     help="the path of json file")
    cli.add_argument('-d', '--distFile', type=str, default="Collection_v2.md",
                     help="the path of the readme file")
    options = cli.parse_args()

    # The introduction is buffered before the JSON is even read, so it always
    # precedes the model tables in the output.
    GenerateIntroductionAndTutorial()

    grouped = RegenerateJsonByDataset(LoadJson(options.file))
    GenerateModelsList_v2(grouped)
    save_code(options.distFile)


if __name__ == "__main__":
    main()
184