1# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
3import math
4from typing import Callable, Tuple
5
6import numpy as np
7import torch
8import torch.nn as nn
9from fvcore.nn.squeeze_excitation import SqueezeExcitation
10from pytorchvideo.layers.convolutions import Conv2plus1d
11from pytorchvideo.layers.swish import Swish
12from pytorchvideo.layers.utils import round_repeats, round_width, set_attributes
13from pytorchvideo.models.head import ResNetBasicHead
14from pytorchvideo.models.net import Net
15from pytorchvideo.models.resnet import BottleneckBlock, ResBlock, ResStage
16from pytorchvideo.models.stem import ResNetBasicStem
17
18
def create_x3d_stem(
    *,
    # Convolution settings.
    in_channels: int,
    out_channels: int,
    conv_kernel_size: Tuple[int] = (5, 3, 3),
    conv_stride: Tuple[int] = (1, 2, 2),
    conv_padding: Tuple[int] = (2, 1, 1),
    # Normalization settings.
    norm: Callable = nn.BatchNorm3d,
    norm_eps: float = 1e-5,
    norm_momentum: float = 0.1,
    # Activation settings.
    activation: Callable = nn.ReLU,
) -> nn.Module:
    """
    Build the X3D stem: a spatial convolution, a channel-wise temporal convolution,
    then optional normalization and activation.

    ::

                                        Conv_xy
                                           ↓
                                        Conv_t
                                           ↓
                                     Normalization
                                           ↓
                                       Activation

    Args:
        in_channels (int): number of channels fed into the spatial convolution.
        out_channels (int): number of channels produced by both convolutions.
        conv_kernel_size (tuple): (T, H, W) kernel size; index 0 is used by the
            temporal convolution, indices 1 and 2 by the spatial convolution.
        conv_stride (tuple): (T, H, W) stride, split the same way as the kernel.
        conv_padding (tuple): (T, H, W) padding, split the same way as the kernel.

        norm (callable): constructor of the normalization layer, e.g.
            nn.BatchNorm3d. Pass None to skip normalization.
        norm_eps (float): epsilon forwarded to the normalization layer.
        norm_momentum (float): momentum forwarded to the normalization layer.

        activation (callable): constructor of the activation layer, e.g. nn.ReLU.
            Pass None to skip the activation.

    Returns:
        (nn.Module): X3D stem layer.
    """
    # 1 x k x k convolution: mixes channels, touches only the spatial axes.
    spatial_conv = nn.Conv3d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=(1, conv_kernel_size[1], conv_kernel_size[2]),
        stride=(1, conv_stride[1], conv_stride[2]),
        padding=(0, conv_padding[1], conv_padding[2]),
        bias=False,
    )
    # k x 1 x 1 convolution with one group per channel (depthwise in time).
    temporal_conv = nn.Conv3d(
        in_channels=out_channels,
        out_channels=out_channels,
        kernel_size=(conv_kernel_size[0], 1, 1),
        stride=(conv_stride[0], 1, 1),
        padding=(conv_padding[0], 0, 0),
        bias=False,
        groups=out_channels,
    )

    # NOTE(review): the keyword names are crossed on purpose. The spatial conv
    # goes into the `conv_t` slot and the temporal conv into `conv_xy`;
    # presumably Conv2plus1d runs `conv_t` first, which gives the
    # Conv_xy -> Conv_t order documented above. Confirm against Conv2plus1d
    # before "fixing" the names (doing so would likely also rename the
    # parameters stored in existing checkpoints).
    factorized_conv = Conv2plus1d(
        conv_t=spatial_conv,
        norm=None,
        activation=None,
        conv_xy=temporal_conv,
    )

    if norm is None:
        stem_norm = None
    else:
        stem_norm = norm(
            num_features=out_channels, eps=norm_eps, momentum=norm_momentum
        )

    stem_activation = activation() if activation is not None else None

    return ResNetBasicStem(
        conv=factorized_conv,
        norm=stem_norm,
        activation=stem_activation,
        pool=None,
    )
103
104
def create_x3d_bottleneck_block(
    *,
    # Convolution settings.
    dim_in: int,
    dim_inner: int,
    dim_out: int,
    conv_kernel_size: Tuple[int] = (3, 3, 3),
    conv_stride: Tuple[int] = (1, 2, 2),
    # Normalization settings.
    norm: Callable = nn.BatchNorm3d,
    norm_eps: float = 1e-5,
    norm_momentum: float = 0.1,
    se_ratio: float = 0.0625,
    # Activation settings.
    activation: Callable = nn.ReLU,
    inner_act: Callable = Swish,
) -> nn.Module:
    """
    Build the X3D bottleneck: three convolutions, each followed by normalization,
    with an optional Squeeze-and-Excitation module after the middle one.

    ::

                                    Conv3d (conv_a)
                                           ↓
                                 Normalization (norm_a)
                                           ↓
                                   Activation (act_a)
                                           ↓
                                    Conv3d (conv_b)
                                           ↓
                                 Normalization (norm_b)
                                           ↓
                                 Squeeze-and-Excitation
                                           ↓
                                   Activation (act_b)
                                           ↓
                                    Conv3d (conv_c)
                                           ↓
                                 Normalization (norm_c)

    Args:
        dim_in (int): channels entering the block (input of conv_a).
        dim_inner (int): channels inside the bottleneck (conv_a output, conv_b
            input and output, conv_c input).
        dim_out (int): channels leaving the block (output of conv_c).
        conv_kernel_size (tuple): kernel size of conv_b; its padding is derived
            as ``size // 2`` per axis.
        conv_stride (tuple): stride of conv_b.

        norm (callable): constructor of the normalization layer, e.g.
            nn.BatchNorm3d. Pass None to skip normalization.
        norm_eps (float): epsilon forwarded to every normalization layer.
        norm_momentum (float): momentum forwarded to every normalization layer.
        se_ratio (float): if > 0, a SqueezeExcitation module is placed after
            norm_b, with ``round_width(dim_inner, se_ratio)`` reduced channels;
            otherwise an identity is used in its place.

        activation (callable): constructor of act_a, e.g. nn.ReLU. Pass None to
            skip it.
        inner_act (callable): constructor of act_b (Swish by default). Pass None
            to skip it.

    Returns:
        (nn.Module): X3D bottleneck block.
    """

    def build_norm(num_features: int) -> nn.Module:
        # Only called when `norm` is not None.
        return norm(num_features=num_features, eps=norm_eps, momentum=norm_momentum)

    pointwise_kernel = (1, 1, 1)

    # --- conv_a: 1x1x1 projection dim_in -> dim_inner -----------------------
    conv_a = nn.Conv3d(
        in_channels=dim_in,
        out_channels=dim_inner,
        kernel_size=pointwise_kernel,
        bias=False,
    )
    norm_a = build_norm(dim_inner) if norm is not None else None
    act_a = activation() if activation is not None else None

    # --- conv_b: channel-wise (groups == channels) spatiotemporal conv ------
    conv_b = nn.Conv3d(
        in_channels=dim_inner,
        out_channels=dim_inner,
        kernel_size=conv_kernel_size,
        stride=conv_stride,
        padding=[size // 2 for size in conv_kernel_size],
        bias=False,
        groups=dim_inner,
        dilation=(1, 1, 1),
    )
    # The SE module is created before the norm that precedes it in norm_b so
    # that modules are constructed in the same order as before.
    if se_ratio > 0.0:
        se = SqueezeExcitation(
            num_channels=dim_inner,
            num_channels_reduced=round_width(dim_inner, se_ratio),
            is_3d=True,
        )
    else:
        se = nn.Identity()
    # norm_b bundles the normalization and the SE module into one module.
    norm_b = nn.Sequential(
        build_norm(dim_inner) if norm is not None else nn.Identity(),
        se,
    )
    act_b = inner_act() if inner_act is not None else None

    # --- conv_c: 1x1x1 projection dim_inner -> dim_out ----------------------
    conv_c = nn.Conv3d(
        in_channels=dim_inner,
        out_channels=dim_out,
        kernel_size=pointwise_kernel,
        bias=False,
    )
    norm_c = build_norm(dim_out) if norm is not None else None

    return BottleneckBlock(
        conv_a=conv_a,
        norm_a=norm_a,
        act_a=act_a,
        conv_b=conv_b,
        norm_b=norm_b,
        act_b=act_b,
        conv_c=conv_c,
        norm_c=norm_c,
    )
229
230
def create_x3d_res_block(
    *,
    # Bottleneck Block configs.
    dim_in: int,
    dim_inner: int,
    dim_out: int,
    bottleneck: Callable = create_x3d_bottleneck_block,
    use_shortcut: bool = True,
    # Conv configs.
    conv_kernel_size: Tuple[int] = (3, 3, 3),
    conv_stride: Tuple[int] = (1, 2, 2),
    # Norm configs.
    norm: Callable = nn.BatchNorm3d,
    norm_eps: float = 1e-5,
    norm_momentum: float = 0.1,
    se_ratio: float = 0.0625,
    # Activation configs.
    activation: Callable = nn.ReLU,
    inner_act: Callable = Swish,
) -> nn.Module:
    """
    Residual block for X3D. Sums a shortcut (branch1) with a bottleneck block
    (branch2) and applies an activation to the result.

    When ``use_shortcut`` is True, branch1 gets a 1x1x1 convolution (strided by
    ``conv_stride``) if the channel count changes or the stride product is
    greater than 1. A normalization layer is added to branch1 only when the
    channel count changes.

    ::

                                         Input
                                           |-------+
                                           ↓       |
                                         Block     |
                                           ↓       |
                                       Summation ←-+
                                           ↓
                                       Activation

    Args:
        dim_in (int): input channel size to the bottleneck block.
        dim_inner (int): intermediate channel size of the bottleneck.
        dim_out (int): output channel size of the bottleneck.
        bottleneck (callable): a callable that builds the branch2 block, called
            with the same keyword arguments as create_x3d_bottleneck_block.
        use_shortcut (bool): if False, no convolution or normalization is
            created for branch1.

        conv_kernel_size (tuple): convolutional kernel size(s) for conv_b.
        conv_stride (tuple): convolutional stride size(s) for conv_b; also the
            stride of the branch1 convolution.

        norm (callable): a callable that constructs normalization layer, examples
            include nn.BatchNorm3d, None (not performing normalization).
        norm_eps (float): normalization epsilon, applied to both branches.
        norm_momentum (float): normalization momentum, applied to both branches.
        se_ratio (float): if > 0, apply SE to the 3x3x3 conv, with the SE
            channel dimensionality being se_ratio times the 3x3x3 conv dim.

        activation (callable): a callable that constructs activation layer, examples
            include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None (not performing
            activation).
        inner_act (callable): a callable that constructs act_b of the bottleneck
            (Swish by default), or None.

    Returns:
        (nn.Module): X3D block layer.
    """
    channels_change = dim_in != dim_out

    # Shortcut normalization. It receives the same eps/momentum as every other
    # normalization layer (previously it silently fell back to the layer's
    # defaults), and is only built when it will actually be used.
    branch1_norm = None
    if norm is not None and channels_change and use_shortcut:
        branch1_norm = norm(
            num_features=dim_out, eps=norm_eps, momentum=norm_momentum
        )

    # Shortcut projection: needed whenever branch2 changes the tensor shape.
    branch1_conv = None
    if use_shortcut and (channels_change or np.prod(conv_stride) > 1):
        branch1_conv = nn.Conv3d(
            dim_in,
            dim_out,
            kernel_size=(1, 1, 1),
            stride=conv_stride,
            bias=False,
        )

    return ResBlock(
        branch1_conv=branch1_conv,
        branch1_norm=branch1_norm,
        branch2=bottleneck(
            dim_in=dim_in,
            dim_inner=dim_inner,
            dim_out=dim_out,
            conv_kernel_size=conv_kernel_size,
            conv_stride=conv_stride,
            norm=norm,
            norm_eps=norm_eps,
            norm_momentum=norm_momentum,
            se_ratio=se_ratio,
            activation=activation,
            inner_act=inner_act,
        ),
        activation=None if activation is None else activation(),
        branch_fusion=lambda x, y: x + y,
    )
323
324
def create_x3d_res_stage(
    *,
    # Stage settings.
    depth: int,
    # Bottleneck block settings.
    dim_in: int,
    dim_inner: int,
    dim_out: int,
    bottleneck: Callable = create_x3d_bottleneck_block,
    # Convolution settings.
    conv_kernel_size: Tuple[int] = (3, 3, 3),
    conv_stride: Tuple[int] = (1, 2, 2),
    # Normalization settings.
    norm: Callable = nn.BatchNorm3d,
    norm_eps: float = 1e-5,
    norm_momentum: float = 0.1,
    se_ratio: float = 0.0625,
    # Activation settings.
    activation: Callable = nn.ReLU,
    inner_act: Callable = Swish,
) -> nn.Module:
    """
    Build one X3D residual stage: ``depth`` residual blocks applied in sequence.

    ::

                                        Input
                                           ↓
                                       ResBlock
                                           ↓
                                           .
                                           .
                                           .
                                           ↓
                                       ResBlock

    Only the first block receives ``dim_in`` and ``conv_stride``; every later
    block takes ``dim_out`` channels as input and uses stride (1, 1, 1).
    ``se_ratio`` is passed to the 1st, 3rd, 5th, ... block, and the blocks in
    between get 0.0.

    Args:

        depth (int): number of blocks to create.

        dim_in (int): input channel size of the first block.
        dim_inner (int): intermediate channel size of each bottleneck.
        dim_out (int): output channel size of each block.
        bottleneck (callable): a callable for create_x3d_bottleneck_block.

        conv_kernel_size (tuple): convolutional kernel size(s) for conv_b.
        conv_stride (tuple): convolutional stride size(s) for conv_b of the
            first block.

        norm (callable): constructor of the normalization layer, e.g.
            nn.BatchNorm3d, or None (not performing normalization).
        norm_eps (float): normalization epsilon.
        norm_momentum (float): normalization momentum.
        se_ratio (float): if > 0, apply SE to the 3x3x3 conv, with the SE
            channel dimensionality being se_ratio times the 3x3x3 conv dim.

        activation (callable): constructor of the activation layer, e.g.
            nn.ReLU, or None (not performing activation).
        inner_act (callable): constructor of act_b inside each bottleneck
            (Swish by default), or None.

    Returns:
        (nn.Module): X3D stage layer.
    """
    unit_stride = (1, 1, 1)

    stage_blocks = [
        create_x3d_res_block(
            dim_in=dim_out if block_idx else dim_in,
            dim_inner=dim_inner,
            dim_out=dim_out,
            bottleneck=bottleneck,
            conv_kernel_size=conv_kernel_size,
            conv_stride=unit_stride if block_idx else conv_stride,
            norm=norm,
            norm_eps=norm_eps,
            norm_momentum=norm_momentum,
            # SE on even zero-based positions only (1st, 3rd, ... block).
            se_ratio=se_ratio if block_idx % 2 == 0 else 0.0,
            activation=activation,
            inner_act=inner_act,
        )
        for block_idx in range(depth)
    ]

    return ResStage(res_blocks=nn.ModuleList(stage_blocks))
407
408
def create_x3d_head(
    *,
    # Projection configs.
    dim_in: int,
    dim_inner: int,
    dim_out: int,
    num_classes: int,
    # Pooling configs.
    pool_act: Callable = nn.ReLU,
    pool_kernel_size: Tuple[int] = (13, 5, 5),
    # BN configs.
    norm: Callable = nn.BatchNorm3d,
    norm_eps: float = 1e-5,
    norm_momentum: float = 0.1,
    bn_lin5_on: bool = False,
    # Dropout configs.
    dropout_rate: float = 0.5,
    # Activation configs.
    activation: Callable = nn.Softmax,
    # Output configs.
    output_with_global_average: bool = True,
) -> nn.Module:
    """
    Creates X3D head. This layer performs a projected pooling operation followed
    by a dropout, a fully-connected projection, an activation layer and a global
    spatiotemporal averaging.

    ::

                                     ProjectedPool
                                           ↓
                                        Dropout
                                           ↓
                                       Projection
                                           ↓
                                       Activation
                                           ↓
                                       Averaging

    Args:
        dim_in (int): input channel size of the X3D head.
        dim_inner (int): intermediate channel size of the X3D head.
        dim_out (int): output channel size of the X3D head.
        num_classes (int): the number of classes for the video dataset.

        pool_act (callable): a callable that constructs the activation used
            before and after pooling, such as nn.ReLU, or None.
        pool_kernel_size (tuple): pooling kernel size(s). If None, adaptive
            average pooling to (1, 1, 1) is used instead.

        norm (callable): a callable that constructs normalization layer, examples
            include nn.BatchNorm3d, None (not performing normalization).
        norm_eps (float): normalization epsilon.
        norm_momentum (float): normalization momentum.
        bn_lin5_on (bool): if True (and norm is not None), perform normalization
            on the features before the classifier.

        dropout_rate (float): dropout rate; no dropout layer is created if it is
            not greater than 0.

        activation (callable): a callable that constructs the head activation
            layer. Supported values are nn.Softmax (applied over dim 1),
            nn.Sigmoid, and None (not applying activation).

        output_with_global_average (bool): if True, perform global averaging on temporal
            and spatial dimensions and reshape output to batch_size x out_features.

    Returns:
        (nn.Module): X3D head layer.

    Raises:
        NotImplementedError: if ``activation`` is anything other than None,
            nn.Softmax or nn.Sigmoid.
    """
    # Resolve the output activation first so that an unsupported choice is
    # rejected before any layer is allocated.
    if activation is None:
        activation_module = None
    elif activation is nn.Softmax:
        activation_module = activation(dim=1)
    elif activation is nn.Sigmoid:
        activation_module = activation()
    else:
        raise NotImplementedError(
            "{} is not supported as an activation function.".format(activation)
        )

    pre_conv_module = nn.Conv3d(
        in_channels=dim_in, out_channels=dim_inner, kernel_size=(1, 1, 1), bias=False
    )

    # `norm=None` is a documented option; ProjectedPool skips a None pre_norm.
    pre_norm_module = (
        None
        if norm is None
        else norm(num_features=dim_inner, eps=norm_eps, momentum=norm_momentum)
    )
    pre_act_module = None if pool_act is None else pool_act()

    if pool_kernel_size is None:
        pool_module = nn.AdaptiveAvgPool3d((1, 1, 1))
    else:
        pool_module = nn.AvgPool3d(pool_kernel_size, stride=1)

    post_conv_module = nn.Conv3d(
        in_channels=dim_inner, out_channels=dim_out, kernel_size=(1, 1, 1), bias=False
    )

    if bn_lin5_on and norm is not None:
        post_norm_module = norm(
            num_features=dim_out, eps=norm_eps, momentum=norm_momentum
        )
    else:
        post_norm_module = None
    post_act_module = None if pool_act is None else pool_act()

    projected_pool_module = ProjectedPool(
        pre_conv=pre_conv_module,
        pre_norm=pre_norm_module,
        pre_act=pre_act_module,
        pool=pool_module,
        post_conv=post_conv_module,
        post_norm=post_norm_module,
        post_act=post_act_module,
    )

    if output_with_global_average:
        output_pool = nn.AdaptiveAvgPool3d(1)
    else:
        output_pool = None

    return ResNetBasicHead(
        proj=nn.Linear(dim_out, num_classes, bias=True),
        activation=activation_module,
        pool=projected_pool_module,
        dropout=nn.Dropout(dropout_rate) if dropout_rate > 0 else None,
        output_pool=output_pool,
    )
535
536
def create_x3d(
    *,
    # Input clip settings.
    input_channel: int = 3,
    input_clip_length: int = 13,
    input_crop_size: int = 160,
    # Model settings.
    model_num_class: int = 400,
    dropout_rate: float = 0.5,
    width_factor: float = 2.0,
    depth_factor: float = 2.2,
    # Normalization settings.
    norm: Callable = nn.BatchNorm3d,
    norm_eps: float = 1e-5,
    norm_momentum: float = 0.1,
    # Activation settings.
    activation: Callable = nn.ReLU,
    # Stem settings.
    stem_dim_in: int = 12,
    stem_conv_kernel_size: Tuple[int] = (5, 3, 3),
    stem_conv_stride: Tuple[int] = (1, 2, 2),
    # Stage settings.
    stage_conv_kernel_size: Tuple[Tuple[int]] = (
        (3, 3, 3),
        (3, 3, 3),
        (3, 3, 3),
        (3, 3, 3),
    ),
    stage_spatial_stride: Tuple[int] = (2, 2, 2, 2),
    stage_temporal_stride: Tuple[int] = (1, 1, 1, 1),
    bottleneck: Callable = create_x3d_bottleneck_block,
    bottleneck_factor: float = 2.25,
    se_ratio: float = 0.0625,
    inner_act: Callable = Swish,
    # Head settings.
    head_dim_out: int = 2048,
    head_pool_act: Callable = nn.ReLU,
    head_bn_lin5_on: bool = False,
    head_activation: Callable = nn.Softmax,
    head_output_with_global_average: bool = True,
) -> nn.Module:
    """
    Assemble a complete X3D network: one stem, four residual stages and a head.

    Christoph Feichtenhofer.
    "X3D: Expanding Architectures for Efficient Video Recognition."
    https://arxiv.org/abs/2004.04730

    ::

                                         Input
                                           ↓
                                         Stem
                                           ↓
                                         Stage 1
                                           ↓
                                           .
                                           .
                                           .
                                           ↓
                                         Stage N
                                           ↓
                                         Head

    Args:
        input_channel (int): number of channels for the input video clip.
        input_clip_length (int): length of the input video clip. Value for
            different models: X3D-XS: 4; X3D-S: 13; X3D-M: 16; X3D-L: 16.
        input_crop_size (int): spatial resolution of the input video clip.
            Value for different models: X3D-XS: 160; X3D-S: 160; X3D-M: 224;
            X3D-L: 312.

        model_num_class (int): the number of classes for the video dataset.
        dropout_rate (float): dropout rate used in the head.
        width_factor (float): width expansion factor applied to the stem and
            stage channel counts.
        depth_factor (float): depth expansion factor applied to the per-stage
            block counts. Value for different models: X3D-XS: 2.2; X3D-S: 2.2;
            X3D-M: 2.2; X3D-L: 5.0.

        norm (callable): a callable that constructs normalization layer.
        norm_eps (float): normalization epsilon.
        norm_momentum (float): normalization momentum.

        activation (callable): a callable that constructs activation layer.

        stem_dim_in (int): channel size of the stem before width expansion; also
            the base width of the first stage.
        stem_conv_kernel_size (tuple): convolutional kernel size(s) of stem; the
            stem padding is derived as ``size // 2`` per axis.
        stem_conv_stride (tuple): convolutional stride size(s) of stem.

        stage_conv_kernel_size (tuple): convolutional kernel size(s) for conv_b,
            one entry per stage.
        stage_spatial_stride (tuple): the spatial stride for each stage.
        stage_temporal_stride (tuple): the temporal stride for each stage.
        bottleneck (callable): a callable that builds a bottleneck block; it is
            forwarded to every stage.
        bottleneck_factor (float): bottleneck expansion factor for the 3x3x3 conv.
        se_ratio (float): if > 0, apply SE to the 3x3x3 conv, with the SE
            channel dimensionality being se_ratio times the 3x3x3 conv dim.
        inner_act (callable): a callable that constructs act_b of each
            bottleneck (Swish by default).

        head_dim_out (int): output channel size of the X3D head.
        head_pool_act (callable): a callable that constructs resnet pool activation
            layer such as nn.ReLU.
        head_bn_lin5_on (bool): if True, perform normalization on the features
            before the classifier.
        head_activation (callable): a callable that constructs activation layer.
        head_output_with_global_average (bool): if True, perform global averaging on
            the head output.

    Returns:
        (nn.Module): the X3D network.
    """

    torch._C._log_api_usage_once("PYTORCHVIDEO.model.create_x3d")

    # ---- Stem ---------------------------------------------------------------
    stem_dim_out = round_width(stem_dim_in, width_factor)
    blocks = [
        create_x3d_stem(
            in_channels=input_channel,
            out_channels=stem_dim_out,
            conv_kernel_size=stem_conv_kernel_size,
            conv_stride=stem_conv_stride,
            conv_padding=[size // 2 for size in stem_conv_kernel_size],
            norm=norm,
            norm_eps=norm_eps,
            norm_momentum=norm_momentum,
            activation=activation,
        )
    ]

    # ---- Stages -------------------------------------------------------------
    # Base depth per stage, and base widths obtained by widening the previous
    # stage's width by `exp_stage`, starting from the un-expanded stem width.
    stage_depths = [1, 2, 5, 3]
    exp_stage = 2.0
    stage_dims = [stem_dim_in]
    while len(stage_dims) < len(stage_depths):
        stage_dims.append(round_width(stage_dims[-1], exp_stage, divisor=8))

    dim_in = stem_dim_out
    for idx, (base_depth, base_dim) in enumerate(zip(stage_depths, stage_dims)):
        dim_out = round_width(base_dim, width_factor)
        dim_inner = int(bottleneck_factor * dim_out)
        depth = round_repeats(base_depth, depth_factor)

        temporal_stride = stage_temporal_stride[idx]
        spatial_stride = stage_spatial_stride[idx]

        blocks.append(
            create_x3d_res_stage(
                depth=depth,
                dim_in=dim_in,
                dim_inner=dim_inner,
                dim_out=dim_out,
                bottleneck=bottleneck,
                conv_kernel_size=stage_conv_kernel_size[idx],
                conv_stride=(temporal_stride, spatial_stride, spatial_stride),
                norm=norm,
                norm_eps=norm_eps,
                norm_momentum=norm_momentum,
                se_ratio=se_ratio,
                activation=activation,
                inner_act=inner_act,
            )
        )
        # The next stage consumes what this one produced.
        dim_in = dim_out

    # ---- Head ---------------------------------------------------------------
    # Overall downsampling of the backbone; used to size the head pooling
    # kernel from the input clip dimensions.
    total_spatial_stride = stem_conv_stride[1] * np.prod(stage_spatial_stride)
    total_temporal_stride = stem_conv_stride[0] * np.prod(stage_temporal_stride)

    assert (
        input_clip_length >= total_temporal_stride
    ), "Clip length doesn't match temporal stride!"
    assert (
        input_crop_size >= total_spatial_stride
    ), "Crop size doesn't match spatial stride!"

    temporal_pool_size = input_clip_length // total_temporal_stride
    spatial_pool_size = int(math.ceil(input_crop_size / total_spatial_stride))
    head_pool_kernel_size = (
        temporal_pool_size,
        spatial_pool_size,
        spatial_pool_size,
    )

    # `dim_out` / `dim_inner` still hold the values of the last stage.
    blocks.append(
        create_x3d_head(
            dim_in=dim_out,
            dim_inner=dim_inner,
            dim_out=head_dim_out,
            num_classes=model_num_class,
            pool_act=head_pool_act,
            pool_kernel_size=head_pool_kernel_size,
            norm=norm,
            norm_eps=norm_eps,
            norm_momentum=norm_momentum,
            bn_lin5_on=head_bn_lin5_on,
            dropout_rate=dropout_rate,
            activation=head_activation,
            output_with_global_average=head_output_with_global_average,
        )
    )
    return Net(blocks=nn.ModuleList(blocks))
738
739
class ProjectedPool(nn.Module):
    """
    Pooling block used by the X3D head. A convolution with optional
    normalization and activation runs before the pooling op, and a second
    convolution with optional normalization and activation runs after it.

    ::

                                    Conv3d (pre_conv)
                                           ↓
                                 Normalization (pre_norm)
                                           ↓
                                   Activation (pre_act)
                                           ↓
                                        Pool3d
                                           ↓
                                    Conv3d (post_conv)
                                           ↓
                                 Normalization (post_norm)
                                           ↓
                                   Activation (post_act)
    """

    def __init__(
        self,
        *,
        pre_conv: nn.Module = None,
        pre_norm: nn.Module = None,
        pre_act: nn.Module = None,
        pool: nn.Module = None,
        post_conv: nn.Module = None,
        post_norm: nn.Module = None,
        post_act: nn.Module = None,
    ) -> None:
        """
        Args:
            pre_conv (torch.nn.modules): convolution applied before pooling
                (required).
            pre_norm (torch.nn.modules): normalization after pre_conv, or None.
            pre_act (torch.nn.modules): activation after pre_norm, or None.
            pool (torch.nn.modules): pooling module (required).
            post_conv (torch.nn.modules): convolution applied after pooling
                (required).
            post_norm (torch.nn.modules): normalization after post_conv, or None.
            post_act (torch.nn.modules): activation after post_norm, or None.
        """
        super().__init__()
        # Must stay the first statement after super().__init__(): it receives
        # locals(), so no extra local variable may exist at this point.
        # Presumably it stores each keyword argument as an attribute of the
        # same name; the checks below and forward() rely on that.
        set_attributes(self, locals())
        for required in ("pre_conv", "pool", "post_conv"):
            assert getattr(self, required) is not None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Stage 1: projection ahead of the pooling op.
        out = self.pre_conv(x)
        if self.pre_norm is not None:
            out = self.pre_norm(out)
        if self.pre_act is not None:
            out = self.pre_act(out)

        # Stage 2: pooling followed by the second projection.
        out = self.post_conv(self.pool(out))
        if self.post_norm is not None:
            out = self.post_norm(out)
        if self.post_act is not None:
            out = self.post_act(out)

        return out
805