1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18"""References:
19
20Simonyan, Karen, and Andrew Zisserman. "Very deep convolutional networks for
21large-scale image recognition." arXiv preprint arXiv:1409.1556 (2014).
22"""
23
24import mxnet as mx
25import numpy as np
26
27
def get_feature(internel_layer, layers, filters, batch_norm=False, **kwargs):
    """Build the VGG convolutional feature extractor.

    Stacks groups of 3x3 convolutions (pad 1), each optionally followed by
    BatchNorm, always followed by ReLU; every group ends with a 2x2
    max-pooling layer with stride 2.

    Parameters
    ----------
    internel_layer : mx.sym.Symbol
        Input symbol the feature stack is appended to.
    layers : list of int
        Number of convolution layers in each group.
    filters : list of int
        Number of filters for each group (same length as ``layers``).
    batch_norm : bool, default False
        Insert a BatchNorm layer after every convolution.

    Returns
    -------
    mx.sym.Symbol
        Symbol producing the feature map after the last pooling layer.
    """
    net = internel_layer
    for group_idx, block_count in enumerate(layers, start=1):
        for conv_idx in range(1, block_count + 1):
            suffix = "%s_%s" % (group_idx, conv_idx)
            net = mx.sym.Convolution(
                data=net,
                kernel=(3, 3),
                pad=(1, 1),
                num_filter=filters[group_idx - 1],
                name="conv" + suffix,
            )
            if batch_norm:
                net = mx.symbol.BatchNorm(data=net, name="bn" + suffix)
            net = mx.sym.Activation(data=net, act_type="relu", name="relu" + suffix)
        net = mx.sym.Pooling(
            data=net,
            pool_type="max",
            kernel=(2, 2),
            stride=(2, 2),
            name="pool%s" % group_idx,
        )
    return net
53
54
def get_classifier(input_data, num_classes, **kwargs):
    """Build the VGG fully-connected classifier head (fc6 -> fc8).

    Two 4096-unit FC layers with ReLU + Dropout(0.5), followed by the final
    ``num_classes``-unit FC layer. Layer names match the classic VGG naming
    (fc6, relu6, drop6, fc7, relu7, drop7, fc8).

    Parameters
    ----------
    input_data : mx.sym.Symbol
        Feature symbol to classify.
    num_classes : int
        Number of output classes (units of the final FC layer).

    Returns
    -------
    mx.sym.Symbol
        The fc8 symbol (pre-softmax logits).
    """

    def _build(fc_kwargs):
        # Build the fc6-fc8 stack, forwarding extra kwargs to FullyConnected.
        fc6 = mx.sym.FullyConnected(data=flatten, num_hidden=4096, name="fc6", **fc_kwargs)
        relu6 = mx.sym.Activation(data=fc6, act_type="relu", name="relu6")
        drop6 = mx.sym.Dropout(data=relu6, p=0.5, name="drop6")
        fc7 = mx.sym.FullyConnected(data=drop6, num_hidden=4096, name="fc7", **fc_kwargs)
        relu7 = mx.sym.Activation(data=fc7, act_type="relu", name="relu7")
        drop7 = mx.sym.Dropout(data=relu7, p=0.5, name="drop7")
        return mx.sym.FullyConnected(
            data=drop7, num_hidden=num_classes, name="fc8", **fc_kwargs
        )

    flatten = mx.sym.Flatten(data=input_data, name="flatten")
    try:
        # Newer MXNet supports `flatten=False` on FullyConnected; probe for it.
        return _build({"flatten": False})
    except Exception:
        # Older MXNet: FullyConnected has no `flatten` kwarg — fall back.
        # (Was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit.)
        return _build({})
74
75
def get_symbol(num_classes, num_layers=11, batch_norm=False, dtype="float32", **kwargs):
    """Build the full VGG network symbol (features + classifier + softmax).

    Parameters
    ----------
    num_classes : int
        Number of classification classes.
    num_layers : int, default 11
        Number of layers for the variant of VGG. Options are 11, 13, 16, 19.
    batch_norm : bool, default False
        Use batch normalization after each convolution.
    dtype : str, "float32" or "float16"
        Data precision; float16 casts the input down and the logits back up.

    Returns
    -------
    mx.sym.Symbol
        Softmax output symbol of the network.

    Raises
    ------
    ValueError
        If ``num_layers`` is not one of 11, 13, 16, 19.
    """
    # (conv layers per group, filters per group) for each VGG variant.
    vgg_spec = {
        11: ([1, 1, 2, 2, 2], [64, 128, 256, 512, 512]),
        13: ([2, 2, 2, 2, 2], [64, 128, 256, 512, 512]),
        16: ([2, 2, 3, 3, 3], [64, 128, 256, 512, 512]),
        19: ([2, 2, 4, 4, 4], [64, 128, 256, 512, 512]),
    }
    if num_layers not in vgg_spec:
        raise ValueError(
            "Invalid num_layers {}. Possible choices are 11, 13, 16, 19.".format(num_layers)
        )
    layers, filters = vgg_spec[num_layers]
    data = mx.sym.Variable(name="data")
    if dtype == "float16":
        # Compute the network body in half precision.
        data = mx.sym.Cast(data=data, dtype=np.float16)
    feature = get_feature(data, layers, filters, batch_norm)
    classifier = get_classifier(feature, num_classes)
    if dtype == "float16":
        # Cast logits back to float32 so softmax is numerically stable.
        classifier = mx.sym.Cast(data=classifier, dtype=np.float32)
    symbol = mx.sym.softmax(data=classifier, name="softmax")
    return symbol
109