"""C3D, implemented in Gluon. https://arxiv.org/abs/1412.0767"""
# pylint: disable=arguments-differ,unused-argument

__all__ = ['C3D', 'c3d_kinetics400']

from mxnet import init
from mxnet.context import cpu
from mxnet.gluon.block import HybridBlock
from mxnet.gluon import nn


class C3D(HybridBlock):
    r"""
    The Convolutional 3D network (C3D).
    Learning Spatiotemporal Features with 3D Convolutional Networks.
    ICCV, 2015. https://arxiv.org/abs/1412.0767

    Parameters
    ----------
    nclass : int
        Number of classes in the training dataset (output units of fc8).
    dropout_ratio : float, default is 0.5.
        Dropout rate applied after the fc6 and fc7 dense layers to reduce overfitting.
    num_segments : int, default is 1.
        Number of segments used to evenly divide a video.
    num_crop : int, default is 1.
        Number of crops used during evaluation, choices are 1, 3 or 10.
    feat_ext : bool, default is False.
        If True, the forward pass stops after fc6 (+ ReLU + dropout) and returns
        that activation as the feature vector; fc7 and fc8 are skipped.
        If False, a complete forward pass returning fc8 scores is done.
    init_std : float, default is 0.001.
        Standard deviation of the normal initializer used for the fc6, fc7
        and fc8 weights.
    ctx : Context or None, default is None.
        Accepted for API compatibility only; it is not used by this constructor.
    **kwargs
        Accepted and ignored (not forwarded to ``HybridBlock``).
    """

    def __init__(self, nclass, dropout_ratio=0.5,
                 num_segments=1, num_crop=1, feat_ext=False,
                 init_std=0.001, ctx=None, **kwargs):
        super(C3D, self).__init__()
        self.num_segments = num_segments
        self.num_crop = num_crop
        self.feat_ext = feat_ext
        # Flattened size of the pool5 output fed to fc6. With the pooling
        # settings below, a 16-frame 112x112 clip yields 512 x 1 x 4 x 4 = 8192;
        # other input sizes would need a different value (assumed input size).
        self.feat_dim = 8192

        # NOTE: the creation order of the layers below determines Gluon's
        # auto-generated parameter names; keep it stable so existing
        # checkpoints keep matching.
        with self.name_scope():
            self.conv1 = nn.Conv3D(in_channels=3, channels=64,
                                   kernel_size=(3, 3, 3), padding=(1, 1, 1))
            # pool1 only pools spatially, keeping the temporal length.
            self.pool1 = nn.MaxPool3D(pool_size=(1, 2, 2), strides=(1, 2, 2))

            self.conv2 = nn.Conv3D(in_channels=64, channels=128,
                                   kernel_size=(3, 3, 3), padding=(1, 1, 1))
            self.pool2 = nn.MaxPool3D(pool_size=(2, 2, 2), strides=(2, 2, 2))

            self.conv3a = nn.Conv3D(in_channels=128, channels=256,
                                    kernel_size=(3, 3, 3), padding=(1, 1, 1))
            self.conv3b = nn.Conv3D(in_channels=256, channels=256,
                                    kernel_size=(3, 3, 3), padding=(1, 1, 1))
            self.pool3 = nn.MaxPool3D(pool_size=(2, 2, 2), strides=(2, 2, 2))

            self.conv4a = nn.Conv3D(in_channels=256, channels=512,
                                    kernel_size=(3, 3, 3), padding=(1, 1, 1))
            self.conv4b = nn.Conv3D(in_channels=512, channels=512,
                                    kernel_size=(3, 3, 3), padding=(1, 1, 1))
            self.pool4 = nn.MaxPool3D(pool_size=(2, 2, 2), strides=(2, 2, 2))

            self.conv5a = nn.Conv3D(in_channels=512, channels=512,
                                    kernel_size=(3, 3, 3), padding=(1, 1, 1))
            self.conv5b = nn.Conv3D(in_channels=512, channels=512,
                                    kernel_size=(3, 3, 3), padding=(1, 1, 1))
            # Spatial padding of 1 rounds an odd spatial size (e.g. 7) up to 4.
            self.pool5 = nn.MaxPool3D(pool_size=(2, 2, 2), strides=(2, 2, 2), padding=(0, 1, 1))

            # fc6 consumes the flattened pool5 features, so its input width
            # must agree with the reshape in hybrid_forward.
            self.fc6 = nn.Dense(in_units=self.feat_dim, units=4096,
                                weight_initializer=init.Normal(sigma=init_std))
            self.fc7 = nn.Dense(in_units=4096, units=4096,
                                weight_initializer=init.Normal(sigma=init_std))
            self.fc8 = nn.Dense(in_units=4096, units=nclass,
                                weight_initializer=init.Normal(sigma=init_std))
            # A single dropout block and a single ReLU block are shared by all layers.
            self.dropout = nn.Dropout(rate=dropout_ratio)
            self.relu = nn.Activation('relu')

    def hybrid_forward(self, F, x):
        """Hybrid forward of C3D net.

        Parameters
        ----------
        F : module
            ``mxnet.nd`` or ``mxnet.sym``, supplied by Gluon.
        x : NDArray or Symbol
            Input clips. Presumably laid out as (N * num_segments * num_crop,
            3, T, H, W) -- TODO confirm against the data pipeline.

        Returns
        -------
        NDArray or Symbol
            fc6 features of width 4096 when ``feat_ext`` is True, otherwise
            fc8 scores with ``nclass`` units.
        """
        x = self.relu(self.conv1(x))
        x = self.pool1(x)

        x = self.relu(self.conv2(x))
        x = self.pool2(x)

        x = self.relu(self.conv3a(x))
        x = self.relu(self.conv3b(x))
        x = self.pool3(x)

        x = self.relu(self.conv4a(x))
        x = self.relu(self.conv4b(x))
        x = self.pool4(x)

        x = self.relu(self.conv5a(x))
        x = self.relu(self.conv5b(x))
        x = self.pool5(x)

        # segmental consensus: group the per-segment/per-crop features of each
        # video together and average them into one feat_dim vector per video.
        x = F.reshape(x, shape=(-1, self.num_segments * self.num_crop, self.feat_dim))
        x = F.mean(x, axis=1)

        x = self.relu(self.fc6(x))
        x = self.dropout(x)

        if self.feat_ext:
            return x

        x = self.relu(self.fc7(x))
        x = self.dropout(x)
        x = self.fc8(x)
        return x


def c3d_kinetics400(nclass=400, pretrained=False, ctx=cpu(),
                    root='~/.mxnet/models', num_segments=1, num_crop=1,
                    feat_ext=False, **kwargs):
    r"""The Convolutional 3D network (C3D) trained on Kinetics400 dataset.
    Learning Spatiotemporal Features with 3D Convolutional Networks.
    ICCV, 2015. https://arxiv.org/abs/1412.0767

    Parameters
    ----------
    nclass : int, default is 400.
        Number of categories in the dataset.
    pretrained : bool or str, default is False.
        Boolean value controls whether to load the default pretrained weights for model.
        String value represents the hashtag for a certain version of pretrained weights.
    ctx : Context, default CPU.
        The context in which to initialize the model and load the pretrained weights.
    root : str, default '~/.mxnet/models'.
        Location for keeping the model parameters.
    num_segments : int, default is 1.
        Number of segments used to evenly divide a video.
    num_crop : int, default is 1.
        Number of crops used during evaluation, choices are 1, 3 or 10.
    feat_ext : bool, default is False.
        If True, the model returns fc6 features instead of class scores;
        otherwise a complete forward pass is done.
    **kwargs
        Forwarded to :class:`C3D` (e.g. ``dropout_ratio``, ``init_std``).

    Returns
    -------
    C3D
        An initialized model. When ``pretrained`` is truthy, the weights are
        loaded from the model store and ``model.classes`` holds the
        Kinetics400 class names.
    """

    model = C3D(nclass=nclass, ctx=ctx, num_segments=num_segments,
                num_crop=num_crop, feat_ext=feat_ext, **kwargs)
    # Initialize every parameter first; pretrained weights (if any) overwrite them.
    model.initialize(init.MSRAPrelu(), ctx=ctx)

    if pretrained:
        # Imported lazily so the package-relative helpers are only needed
        # when pretrained weights are requested.
        from ..model_store import get_model_file
        model.load_parameters(get_model_file('c3d_kinetics400',
                                             tag=pretrained, root=root), ctx=ctx)
        from ...data import Kinetics400Attr
        attrib = Kinetics400Attr()
        model.classes = attrib.classes
    model.collect_params().reset_ctx(ctx)

    return model