1# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
3from typing import Callable, List, Tuple, Union
4
5import numpy as np
6import torch
7import torch.nn as nn
8from pytorchvideo.layers.utils import set_attributes
9from pytorchvideo.models.head import create_res_basic_head, create_res_roi_pooling_head
10from pytorchvideo.models.net import Net, DetectionBBoxNetwork
11from pytorchvideo.models.stem import (
12    create_acoustic_res_basic_stem,
13    create_res_basic_stem,
14)
15
16
def create_bottleneck_block(
    *,
    # Convolution configs.
    dim_in: int,
    dim_inner: int,
    dim_out: int,
    conv_a_kernel_size: Tuple[int] = (3, 1, 1),
    conv_a_stride: Tuple[int] = (2, 1, 1),
    conv_a_padding: Tuple[int] = (1, 0, 0),
    conv_a: Callable = nn.Conv3d,
    conv_b_kernel_size: Tuple[int] = (1, 3, 3),
    conv_b_stride: Tuple[int] = (1, 2, 2),
    conv_b_padding: Tuple[int] = (0, 1, 1),
    conv_b_num_groups: int = 1,
    conv_b_dilation: Tuple[int] = (1, 1, 1),
    conv_b: Callable = nn.Conv3d,
    conv_c: Callable = nn.Conv3d,
    # Norm configs.
    norm: Callable = nn.BatchNorm3d,
    norm_eps: float = 1e-5,
    norm_momentum: float = 0.1,
    # Activation configs.
    activation: Callable = nn.ReLU,
) -> nn.Module:
    """
    Build a bottleneck block: three convolutions, each followed by a
    normalization, with an activation after the first two, in this order:

    ::

                                    Conv3d (conv_a)
                                           ↓
                                 Normalization (norm_a)
                                           ↓
                                   Activation (act_a)
                                           ↓
                                    Conv3d (conv_b)
                                           ↓
                                 Normalization (norm_b)
                                           ↓
                                   Activation (act_b)
                                           ↓
                                    Conv3d (conv_c)
                                           ↓
                                 Normalization (norm_c)

    All three convolutions are built with ``bias=False``. conv_c always uses
    a 1x1x1 kernel to project ``dim_inner`` channels to ``dim_out``.

    Normalization examples include: BatchNorm3d and None (no normalization).
    Activation examples include: ReLU, Softmax, Sigmoid, and None (no activation).

    Args:
        dim_in (int): input channel size to the bottleneck block.
        dim_inner (int): intermediate channel size of the bottleneck.
        dim_out (int): output channel size of the bottleneck.
        conv_a_kernel_size (tuple): convolutional kernel size(s) for conv_a.
        conv_a_stride (tuple): convolutional stride size(s) for conv_a.
        conv_a_padding (tuple): convolutional padding(s) for conv_a.
        conv_a (callable): a callable that constructs the conv_a conv layer, examples
            include nn.Conv3d, OctaveConv, etc
        conv_b_kernel_size (tuple): convolutional kernel size(s) for conv_b.
        conv_b_stride (tuple): convolutional stride size(s) for conv_b.
        conv_b_padding (tuple): convolutional padding(s) for conv_b.
        conv_b_num_groups (int): number of groups for groupwise convolution for
            conv_b.
        conv_b_dilation (tuple): dilation for 3D convolution for conv_b.
        conv_b (callable): a callable that constructs the conv_b conv layer, examples
            include nn.Conv3d, OctaveConv, etc
        conv_c (callable): a callable that constructs the conv_c conv layer, examples
            include nn.Conv3d, OctaveConv, etc

        norm (callable): a callable that constructs normalization layer, examples
            include nn.BatchNorm3d, None (not performing normalization).
        norm_eps (float): normalization epsilon.
        norm_momentum (float): normalization momentum.

        activation (callable): a callable that constructs activation layer, examples
            include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None (not performing
            activation).

    Returns:
        (nn.Module): resnet bottleneck block.
    """

    def build_norm(num_channels: int):
        # ``norm=None`` disables normalization for this slot.
        if norm is None:
            return None
        return norm(num_features=num_channels, eps=norm_eps, momentum=norm_momentum)

    def build_act():
        # ``activation=None`` disables the nonlinearity for this slot.
        return activation() if activation is not None else None

    # Layers are created in execution order (conv_a, conv_b, conv_c) so that
    # weight initialization consumes the RNG in the same sequence. They are
    # keyed by the keyword names BottleneckBlock expects.
    layers = {}

    layers["conv_a"] = conv_a(
        in_channels=dim_in,
        out_channels=dim_inner,
        kernel_size=conv_a_kernel_size,
        stride=conv_a_stride,
        padding=conv_a_padding,
        bias=False,
    )
    layers["norm_a"] = build_norm(dim_inner)
    layers["act_a"] = build_act()

    layers["conv_b"] = conv_b(
        in_channels=dim_inner,
        out_channels=dim_inner,
        kernel_size=conv_b_kernel_size,
        stride=conv_b_stride,
        padding=conv_b_padding,
        bias=False,
        groups=conv_b_num_groups,
        dilation=conv_b_dilation,
    )
    layers["norm_b"] = build_norm(dim_inner)
    layers["act_b"] = build_act()

    # Pointwise projection back to the block's output width.
    layers["conv_c"] = conv_c(
        in_channels=dim_inner,
        out_channels=dim_out,
        kernel_size=(1, 1, 1),
        bias=False,
    )
    layers["norm_c"] = build_norm(dim_out)

    return BottleneckBlock(**layers)
149
150
def create_acoustic_bottleneck_block(
    *,
    # Convolution configs.
    dim_in: int,
    dim_inner: int,
    dim_out: int,
    conv_a_kernel_size: Tuple[int] = (3, 1, 1),
    conv_a_stride: Tuple[int] = (2, 1, 1),
    conv_a_padding: Tuple[int] = (1, 0, 0),
    conv_a: Callable = nn.Conv3d,
    # Conv b configs.
    conv_b_kernel_size: Tuple[int] = (1, 1, 1),
    conv_b_stride: Tuple[int] = (1, 1, 1),
    conv_b_padding: Tuple[int] = (0, 0, 0),
    conv_b_num_groups: int = 1,
    conv_b_dilation: Tuple[int] = (1, 1, 1),
    conv_b: Callable = nn.Conv3d,
    conv_c: Callable = nn.Conv3d,
    # Norm configs.
    norm: Callable = nn.BatchNorm3d,
    norm_eps: float = 1e-5,
    norm_momentum: float = 0.1,
    # Activation configs.
    activation: Callable = nn.ReLU,
) -> nn.Module:
    """
    Build an acoustic bottleneck block. It matches the standard bottleneck
    except that conv_b is factorized into two parallel branches:

    - a temporal conv that keeps only the first axis of conv_b's kernel,
      padding and dilation;
    - a spatial conv that keeps only the last two axes.

    ::

                                    Conv3d (conv_a)
                                           ↓
                                 Normalization (norm_a)
                                           ↓
                                   Activation (act_a)
                                           ↓
                           ---------------------------------
                           ↓                               ↓
                Temporal Conv3d (conv_b)        Spatial Conv3d (conv_b)
                           ↓                               ↓
                 Normalization (norm_b)         Normalization (norm_b)
                           ↓                               ↓
                   Activation (act_b)              Activation (act_b)
                           ↓                               ↓
                           ---------------------------------
                                           ↓
                                    Conv3d (conv_c)
                                           ↓
                                 Normalization (norm_c)

    Both branches use the full ``conv_b_stride`` and ``conv_b_num_groups``.
    They are handed to SeparableBottleneckBlock as ModuleLists ordered
    (spatial, temporal). How the two outputs are merged is decided by
    SeparableBottleneckBlock, which is not visible here.

    Normalization examples include: BatchNorm3d and None (no normalization).
    Activation examples include: ReLU, Softmax, Sigmoid, and None (no activation).

    Args:
        dim_in (int): input channel size to the bottleneck block.
        dim_inner (int): intermediate channel size of the bottleneck.
        dim_out (int): output channel size of the bottleneck.
        conv_a_kernel_size (tuple): convolutional kernel size(s) for conv_a.
        conv_a_stride (tuple): convolutional stride size(s) for conv_a.
        conv_a_padding (tuple): convolutional padding(s) for conv_a.
        conv_a (callable): a callable that constructs the conv_a conv layer, examples
            include nn.Conv3d, OctaveConv, etc
        conv_b_kernel_size (tuple): convolutional kernel size(s) for conv_b.
        conv_b_stride (tuple): convolutional stride size(s) for conv_b.
        conv_b_padding (tuple): convolutional padding(s) for conv_b.
        conv_b_num_groups (int): number of groups for groupwise convolution for
            conv_b.
        conv_b_dilation (tuple): dilation for 3D convolution for conv_b.
        conv_b (callable): a callable that constructs the conv_b conv layer, examples
            include nn.Conv3d, OctaveConv, etc
        conv_c (callable): a callable that constructs the conv_c conv layer, examples
            include nn.Conv3d, OctaveConv, etc

        norm (callable): a callable that constructs normalization layer, examples
            include nn.BatchNorm3d, None (not performing normalization).
        norm_eps (float): normalization epsilon.
        norm_momentum (float): normalization momentum.

        activation (callable): a callable that constructs activation layer, examples
            include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None (not performing
            activation).

    Returns:
        (nn.Module): resnet acoustic bottleneck block.
    """

    def _norm_or_none(channels):
        # ``norm=None`` disables normalization for this slot.
        return (
            None
            if norm is None
            else norm(num_features=channels, eps=norm_eps, momentum=norm_momentum)
        )

    def _act_or_none():
        # ``activation=None`` disables the nonlinearity for this slot.
        if activation is None:
            return None
        return activation()

    first_conv = conv_a(
        in_channels=dim_in,
        out_channels=dim_inner,
        kernel_size=conv_a_kernel_size,
        stride=conv_a_stride,
        padding=conv_a_padding,
        bias=False,
    )
    first_norm = _norm_or_none(dim_inner)
    first_act = _act_or_none()

    # Split conv_b's per-axis settings: the temporal branch keeps axis 0 and
    # collapses the other two; the spatial branch does the opposite.
    temporal_kernel = [conv_b_kernel_size[0], 1, 1]
    temporal_padding = [conv_b_padding[0], 0, 0]
    spatial_kernel = [1, conv_b_kernel_size[1], conv_b_kernel_size[2]]
    spatial_padding = [0, conv_b_padding[1], conv_b_padding[2]]
    temporal_dilation = [conv_b_dilation[0], 1, 1]
    spatial_dilation = [1, conv_b_dilation[1], conv_b_dilation[2]]

    temporal_conv = conv_b(
        in_channels=dim_inner,
        out_channels=dim_inner,
        kernel_size=temporal_kernel,
        stride=conv_b_stride,
        padding=temporal_padding,
        bias=False,
        groups=conv_b_num_groups,
        dilation=temporal_dilation,
    )
    temporal_norm = _norm_or_none(dim_inner)
    temporal_act = _act_or_none()

    spatial_conv = conv_b(
        in_channels=dim_inner,
        out_channels=dim_inner,
        kernel_size=spatial_kernel,
        stride=conv_b_stride,
        padding=spatial_padding,
        bias=False,
        groups=conv_b_num_groups,
        dilation=spatial_dilation,
    )
    spatial_norm = _norm_or_none(dim_inner)
    spatial_act = _act_or_none()

    # Pointwise projection back to the block's output width.
    last_conv = conv_c(
        in_channels=dim_inner, out_channels=dim_out, kernel_size=(1, 1, 1), bias=False
    )
    last_norm = _norm_or_none(dim_out)

    # Branch order inside every ModuleList is (spatial, temporal).
    return SeparableBottleneckBlock(
        conv_a=first_conv,
        norm_a=first_norm,
        act_a=first_act,
        conv_b=nn.ModuleList([spatial_conv, temporal_conv]),
        norm_b=nn.ModuleList([spatial_norm, temporal_norm]),
        act_b=nn.ModuleList([spatial_act, temporal_act]),
        conv_c=last_conv,
        norm_c=last_norm,
    )
317
318
def create_res_block(
    *,
    # Bottleneck Block configs.
    dim_in: int,
    dim_inner: int,
    dim_out: int,
    bottleneck: Callable,
    use_shortcut: bool = False,
    branch_fusion: Callable = lambda x, y: x + y,
    # Conv configs.
    conv_a_kernel_size: Tuple[int] = (3, 1, 1),
    conv_a_stride: Tuple[int] = (2, 1, 1),
    conv_a_padding: Tuple[int] = (1, 0, 0),
    conv_a: Callable = nn.Conv3d,
    conv_b_kernel_size: Tuple[int] = (1, 3, 3),
    conv_b_stride: Tuple[int] = (1, 2, 2),
    conv_b_padding: Tuple[int] = (0, 1, 1),
    conv_b_num_groups: int = 1,
    conv_b_dilation: Tuple[int] = (1, 1, 1),
    conv_b: Callable = nn.Conv3d,
    conv_c: Callable = nn.Conv3d,
    conv_skip: Callable = nn.Conv3d,
    # Norm configs.
    norm: Callable = nn.BatchNorm3d,
    norm_eps: float = 1e-5,
    norm_momentum: float = 0.1,
    # Activation configs.
    activation_bottleneck: Callable = nn.ReLU,
    activation_block: Callable = nn.ReLU,
) -> nn.Module:
    """
    Residual block. Fuses a shortcut (branch1) with a main bottleneck block
    (branch2) via ``branch_fusion``, then applies ``activation_block``.

    branch1 is the identity unless a projection is needed, which happens when:

    - the channel count changes (``dim_in != dim_out``), or
    - branch2 changes resolution (any combined conv_a * conv_b stride != 1), or
    - ``use_shortcut`` is True.

    In that case branch1 is a 1x1x1 ``conv_skip``, followed by a ``norm``
    layer when ``norm`` is not None.

    ::

                                         Input
                                           |-------+
                                           ↓       |
                                         Block     |
                                           ↓       |
                                       Summation ←-+
                                           ↓
                                       Activation

    Normalization examples include: BatchNorm3d and None (no normalization).
    Activation examples include: ReLU, Softmax, Sigmoid, and None (no activation).
    Transform examples include: BottleneckBlock.

    Args:
        dim_in (int): input channel size to the bottleneck block.
        dim_inner (int): intermediate channel size of the bottleneck.
        dim_out (int): output channel size of the bottleneck.
        bottleneck (callable): a callable that constructs bottleneck block layer.
            Examples include: create_bottleneck_block.
        use_shortcut (bool): If true, always use a conv (and, when ``norm`` is
            given, a norm) layer in the skip connection, even when the input and
            output shapes already match.
        branch_fusion (callable): a callable that constructs summation layer.
            Examples include: lambda x, y: x + y, OctaveSum.

        conv_a_kernel_size (tuple): convolutional kernel size(s) for conv_a.
        conv_a_stride (tuple): convolutional stride size(s) for conv_a.
        conv_a_padding (tuple): convolutional padding(s) for conv_a.
        conv_a (callable): a callable that constructs the conv_a conv layer, examples
            include nn.Conv3d, OctaveConv, etc
        conv_b_kernel_size (tuple): convolutional kernel size(s) for conv_b.
        conv_b_stride (tuple): convolutional stride size(s) for conv_b.
        conv_b_padding (tuple): convolutional padding(s) for conv_b.
        conv_b_num_groups (int): number of groups for groupwise convolution for
            conv_b.
        conv_b_dilation (tuple): dilation for 3D convolution for conv_b.
        conv_b (callable): a callable that constructs the conv_b conv layer, examples
            include nn.Conv3d, OctaveConv, etc
        conv_c (callable): a callable that constructs the conv_c conv layer, examples
            include nn.Conv3d, OctaveConv, etc
        conv_skip (callable): a callable that constructs the conv_skip conv layer,
            examples include nn.Conv3d, OctaveConv, etc

        norm (callable): a callable that constructs normalization layer. Examples
            include nn.BatchNorm3d, None (not performing normalization).
        norm_eps (float): normalization epsilon.
        norm_momentum (float): normalization momentum.

        activation_bottleneck (callable): a callable that constructs activation layer in
            bottleneck. Examples include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None
            (not performing activation).
        activation_block (callable): a callable that constructs activation layer used
            at the end of the block. Examples include: nn.ReLU, nn.Softmax, nn.Sigmoid,
            and None (not performing activation).

    Returns:
        (nn.Module): resnet basic block layer.
    """
    # The shortcut must downsample exactly as much as branch2 does, so its
    # stride is the element-wise product of the conv_a and conv_b strides.
    # Cast to plain ``int`` so the skip conv does not store numpy scalars.
    branch1_conv_stride = tuple(
        int(np.prod(pair)) for pair in zip(conv_a_stride, conv_b_stride)
    )
    # Replace the identity shortcut with a projection when shapes change or
    # when the caller explicitly asks for one.
    use_projection = (
        use_shortcut or dim_in != dim_out or np.prod(branch1_conv_stride) != 1
    )

    # Build the shortcut norm only when there is a projection AND
    # normalization is enabled. Checking ``norm is not None`` here avoids
    # calling ``None(...)`` when norm=None and use_shortcut=True.
    norm_model = None
    if use_projection and norm is not None:
        norm_model = norm(num_features=dim_out, eps=norm_eps, momentum=norm_momentum)

    branch1_conv = None
    if use_projection:
        branch1_conv = conv_skip(
            dim_in,
            dim_out,
            kernel_size=(1, 1, 1),
            stride=branch1_conv_stride,
            bias=False,
        )

    return ResBlock(
        branch1_conv=branch1_conv,
        branch1_norm=norm_model,
        branch2=bottleneck(
            dim_in=dim_in,
            dim_inner=dim_inner,
            dim_out=dim_out,
            conv_a_kernel_size=conv_a_kernel_size,
            conv_a_stride=conv_a_stride,
            conv_a_padding=conv_a_padding,
            conv_a=conv_a,
            conv_b_kernel_size=conv_b_kernel_size,
            conv_b_stride=conv_b_stride,
            conv_b_padding=conv_b_padding,
            conv_b_num_groups=conv_b_num_groups,
            conv_b_dilation=conv_b_dilation,
            conv_b=conv_b,
            conv_c=conv_c,
            norm=norm,
            norm_eps=norm_eps,
            norm_momentum=norm_momentum,
            activation=activation_bottleneck,
        ),
        activation=None if activation_block is None else activation_block(),
        branch_fusion=branch_fusion,
    )
454
455
def create_res_stage(
    *,
    # Stage configs.
    depth: int,
    # Bottleneck Block configs.
    dim_in: int,
    dim_inner: int,
    dim_out: int,
    bottleneck: Callable,
    # Conv configs.
    conv_a_kernel_size: Union[Tuple[int], List[Tuple[int]]] = (3, 1, 1),
    conv_a_stride: Tuple[int] = (2, 1, 1),
    conv_a_padding: Union[Tuple[int], List[Tuple[int]]] = (1, 0, 0),
    conv_a: Callable = nn.Conv3d,
    conv_b_kernel_size: Tuple[int] = (1, 3, 3),
    conv_b_stride: Tuple[int] = (1, 2, 2),
    conv_b_padding: Tuple[int] = (0, 1, 1),
    conv_b_num_groups: int = 1,
    conv_b_dilation: Tuple[int] = (1, 1, 1),
    conv_b: Callable = nn.Conv3d,
    conv_c: Callable = nn.Conv3d,
    # Norm configs.
    norm: Callable = nn.BatchNorm3d,
    norm_eps: float = 1e-5,
    norm_momentum: float = 0.1,
    # Activation configs.
    activation: Callable = nn.ReLU,
) -> nn.Module:
    """
    Create Residual Stage, which composes sequential blocks that make up a ResNet. These
    blocks could be, for example, Residual blocks, Non-Local layers, or
    Squeeze-Excitation layers.

    Only the first block maps ``dim_in`` to ``dim_out`` and applies
    ``conv_a_stride`` / ``conv_b_stride``. Every later block maps ``dim_out``
    to ``dim_out`` with stride (1, 1, 1).

    ::


                                        Input
                                           ↓
                                       ResBlock
                                           ↓
                                           .
                                           .
                                           .
                                           ↓
                                       ResBlock

    Normalization examples include: BatchNorm3d and None (no normalization).
    Activation examples include: ReLU, Softmax, Sigmoid, and None (no activation).
    Bottleneck examples include: create_bottleneck_block.

    Args:
        depth (int): number of blocks to create.

        dim_in (int): input channel size to the bottleneck block.
        dim_inner (int): intermediate channel size of the bottleneck.
        dim_out (int): output channel size of the bottleneck.
        bottleneck (callable): a callable that constructs bottleneck block layer.
            Examples include: create_bottleneck_block.

        conv_a_kernel_size (tuple or list of tuple): convolutional kernel size(s)
            for conv_a. If conv_a_kernel_size is a tuple, use it for all blocks in
            the stage. If conv_a_kernel_size is a list of tuple, the kernel sizes
            will be repeated until having same length of depth in the stage. For
            example, with conv_a_kernel_size = [(3, 1, 1), (1, 1, 1)] and
            depth = 5, the kernel sizes of the five blocks would be
            [(3, 1, 1), (1, 1, 1), (3, 1, 1), (1, 1, 1), (3, 1, 1)].
        conv_a_stride (tuple): convolutional stride size(s) for conv_a, applied
            to the first block only.
        conv_a_padding (tuple or list of tuple): convolutional padding(s) for
            conv_a. If conv_a_padding is a tuple, use it for all blocks in
            the stage. If conv_a_padding is a list of tuple, the padding sizes
            will be repeated until having same length of depth in the stage.
        conv_a (callable): a callable that constructs the conv_a conv layer, examples
            include nn.Conv3d, OctaveConv, etc
        conv_b_kernel_size (tuple): convolutional kernel size(s) for conv_b.
        conv_b_stride (tuple): convolutional stride size(s) for conv_b, applied
            to the first block only.
        conv_b_padding (tuple): convolutional padding(s) for conv_b.
        conv_b_num_groups (int): number of groups for groupwise convolution for
            conv_b.
        conv_b_dilation (tuple): dilation for 3D convolution for conv_b.
        conv_b (callable): a callable that constructs the conv_b conv layer, examples
            include nn.Conv3d, OctaveConv, etc
        conv_c (callable): a callable that constructs the conv_c conv layer, examples
            include nn.Conv3d, OctaveConv, etc

        norm (callable): a callable that constructs normalization layer. Examples
            include nn.BatchNorm3d, and None (not performing normalization).
        norm_eps (float): normalization epsilon.
        norm_momentum (float): normalization momentum.

        activation (callable): a callable that constructs activation layer. Examples
            include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None (not performing
            activation). Used both inside each bottleneck and after each block.

    Returns:
        (nn.Module): resnet basic stage layer.
    """
    # Wrap a single kernel/padding tuple into a one-element list so both the
    # "same for all blocks" and "per-block list" forms are handled uniformly.
    if isinstance(conv_a_kernel_size[0], int):
        conv_a_kernel_size = [conv_a_kernel_size]
    if isinstance(conv_a_padding[0], int):
        conv_a_padding = [conv_a_padding]
    # Cycle the kernels/paddings until they cover exactly `depth` blocks.
    conv_a_kernel_size = (conv_a_kernel_size * depth)[:depth]
    conv_a_padding = (conv_a_padding * depth)[:depth]

    res_blocks = []
    for ind in range(depth):
        is_first = ind == 0
        block = create_res_block(
            # Only the first block changes the channel count and downsamples.
            dim_in=dim_in if is_first else dim_out,
            dim_inner=dim_inner,
            dim_out=dim_out,
            bottleneck=bottleneck,
            conv_a_kernel_size=conv_a_kernel_size[ind],
            conv_a_stride=conv_a_stride if is_first else (1, 1, 1),
            conv_a_padding=conv_a_padding[ind],
            conv_a=conv_a,
            conv_b_kernel_size=conv_b_kernel_size,
            conv_b_stride=conv_b_stride if is_first else (1, 1, 1),
            conv_b_padding=conv_b_padding,
            conv_b_num_groups=conv_b_num_groups,
            conv_b_dilation=conv_b_dilation,
            conv_b=conv_b,
            conv_c=conv_c,
            norm=norm,
            norm_eps=norm_eps,
            norm_momentum=norm_momentum,
            activation_bottleneck=activation,
            activation_block=activation,
        )
        res_blocks.append(block)
    return ResStage(res_blocks=nn.ModuleList(res_blocks))
586
587
# Number of residual blocks in each of the four stages, keyed by the overall
# model depth (ResNet-50, ResNet-101, ResNet-152).
_MODEL_STAGE_DEPTH = {
    50: (3, 4, 6, 3),
    101: (3, 4, 23, 3),
    152: (3, 8, 36, 3),
}
590
591
592def create_resnet(
593    *,
594    # Input clip configs.
595    input_channel: int = 3,
596    # Model configs.
597    model_depth: int = 50,
598    model_num_class: int = 400,
599    dropout_rate: float = 0.5,
600    # Normalization configs.
601    norm: Callable = nn.BatchNorm3d,
602    # Activation configs.
603    activation: Callable = nn.ReLU,
604    # Stem configs.
605    stem_dim_out: int = 64,
606    stem_conv_kernel_size: Tuple[int] = (3, 7, 7),
607    stem_conv_stride: Tuple[int] = (1, 2, 2),
608    stem_pool: Callable = nn.MaxPool3d,
609    stem_pool_kernel_size: Tuple[int] = (1, 3, 3),
610    stem_pool_stride: Tuple[int] = (1, 2, 2),
611    stem: Callable = create_res_basic_stem,
612    # Stage configs.
613    stage1_pool: Callable = None,
614    stage1_pool_kernel_size: Tuple[int] = (2, 1, 1),
615    stage_conv_a_kernel_size: Union[Tuple[int], Tuple[Tuple[int]]] = (
616        (1, 1, 1),
617        (1, 1, 1),
618        (3, 1, 1),
619        (3, 1, 1),
620    ),
621    stage_conv_b_kernel_size: Union[Tuple[int], Tuple[Tuple[int]]] = (
622        (1, 3, 3),
623        (1, 3, 3),
624        (1, 3, 3),
625        (1, 3, 3),
626    ),
627    stage_conv_b_num_groups: Tuple[int] = (1, 1, 1, 1),
628    stage_conv_b_dilation: Union[Tuple[int], Tuple[Tuple[int]]] = (
629        (1, 1, 1),
630        (1, 1, 1),
631        (1, 1, 1),
632        (1, 1, 1),
633    ),
634    stage_spatial_h_stride: Tuple[int] = (1, 2, 2, 2),
635    stage_spatial_w_stride: Tuple[int] = (1, 2, 2, 2),
636    stage_temporal_stride: Tuple[int] = (1, 1, 1, 1),
637    bottleneck: Union[Tuple[Callable], Callable] = create_bottleneck_block,
638    # Head configs.
639    head: Callable = create_res_basic_head,
640    head_pool: Callable = nn.AvgPool3d,
641    head_pool_kernel_size: Tuple[int] = (4, 7, 7),
642    head_output_size: Tuple[int] = (1, 1, 1),
643    head_activation: Callable = None,
644    head_output_with_global_average: bool = True,
645) -> nn.Module:
646    """
647    Build ResNet style models for video recognition. ResNet has three parts:
648    Stem, Stages and Head. Stem is the first Convolution layer (Conv1) with an
649    optional pooling layer. Stages are grouped residual blocks. There are usually
650    multiple stages and each stage may include multiple residual blocks. Head
651    may include pooling, dropout, a fully-connected layer and global spatial
652    temporal averaging. The three parts are assembled in the following order:
653
654    ::
655
656                                         Input
657658                                         Stem
659660                                         Stage 1
661662                                           .
663                                           .
664                                           .
665666                                         Stage N
667668                                         Head
669
670    Args:
671
672        input_channel (int): number of channels for the input video clip.
673
674        model_depth (int): the depth of the resnet. Options include: 50, 101, 152.
675        model_num_class (int): the number of classes for the video dataset.
676        dropout_rate (float): dropout rate.
677
678
679        norm (callable): a callable that constructs normalization layer.
680
681        activation (callable): a callable that constructs activation layer.
682
683        stem_dim_out (int): output channel size to stem.
684        stem_conv_kernel_size (tuple): convolutional kernel size(s) of stem.
685        stem_conv_stride (tuple): convolutional stride size(s) of stem.
686        stem_pool (callable): a callable that constructs resnet head pooling layer.
687        stem_pool_kernel_size (tuple): pooling kernel size(s).
688        stem_pool_stride (tuple): pooling stride size(s).
689        stem (callable): a callable that constructs stem layer.
690            Examples include: create_res_video_stem.
691
692        stage_conv_a_kernel_size (tuple): convolutional kernel size(s) for conv_a.
693        stage_conv_b_kernel_size (tuple): convolutional kernel size(s) for conv_b.
694        stage_conv_b_num_groups (tuple): number of groups for groupwise convolution
695            for conv_b. 1 for ResNet, and larger than 1 for ResNeXt.
696        stage_conv_b_dilation (tuple): dilation for 3D convolution for conv_b.
697        stage_spatial_h_stride (tuple): the spatial height stride for each stage.
698        stage_spatial_w_stride (tuple): the spatial width stride for each stage.
699        stage_temporal_stride (tuple): the temporal stride for each stage.
700        bottleneck (callable): a callable that constructs bottleneck block layer.
701            Examples include: create_bottleneck_block.
702
703        head (callable): a callable that constructs the resnet-style head.
704            Ex: create_res_basic_head
705        head_pool (callable): a callable that constructs resnet head pooling layer.
706        head_pool_kernel_size (tuple): the pooling kernel size.
707        head_output_size (tuple): the size of output tensor for head.
708        head_activation (callable): a callable that constructs activation layer.
709        head_output_with_global_average (bool): if True, perform global averaging on
710            the head output.
711
712    Returns:
713        (nn.Module): basic resnet.
714    """
715
716    torch._C._log_api_usage_once("PYTORCHVIDEO.model.create_resnet")
717
718    # Given a model depth, get the number of blocks for each stage.
719    assert (
720        model_depth in _MODEL_STAGE_DEPTH.keys()
721    ), f"{model_depth} is not in {_MODEL_STAGE_DEPTH.keys()}"
722    stage_depths = _MODEL_STAGE_DEPTH[model_depth]
723
724    # Broadcast single element to tuple if given.
725    if isinstance(stage_conv_a_kernel_size[0], int):
726        stage_conv_a_kernel_size = (stage_conv_a_kernel_size,) * len(stage_depths)
727
728    if isinstance(stage_conv_b_kernel_size[0], int):
729        stage_conv_b_kernel_size = (stage_conv_b_kernel_size,) * len(stage_depths)
730
731    if isinstance(stage_conv_b_dilation[0], int):
732        stage_conv_b_dilation = (stage_conv_b_dilation,) * len(stage_depths)
733
734    if isinstance(bottleneck, Callable):
735        bottleneck = [
736            bottleneck,
737        ] * len(stage_depths)
738
739    blocks = []
740    # Create stem for resnet.
741    stem = stem(
742        in_channels=input_channel,
743        out_channels=stem_dim_out,
744        conv_kernel_size=stem_conv_kernel_size,
745        conv_stride=stem_conv_stride,
746        conv_padding=[size // 2 for size in stem_conv_kernel_size],
747        pool=stem_pool,
748        pool_kernel_size=stem_pool_kernel_size,
749        pool_stride=stem_pool_stride,
750        pool_padding=[size // 2 for size in stem_pool_kernel_size],
751        norm=norm,
752        activation=activation,
753    )
754    blocks.append(stem)
755
756    stage_dim_in = stem_dim_out
757    stage_dim_out = stage_dim_in * 4
758
759    # Create each stage for resnet.
760    for idx in range(len(stage_depths)):
761        stage_dim_inner = stage_dim_out // 4
762        depth = stage_depths[idx]
763
764        stage_conv_a_kernel = stage_conv_a_kernel_size[idx]
765        stage_conv_a_stride = (stage_temporal_stride[idx], 1, 1)
766        stage_conv_a_padding = (
767            [size // 2 for size in stage_conv_a_kernel]
768            if isinstance(stage_conv_a_kernel[0], int)
769            else [[size // 2 for size in sizes] for sizes in stage_conv_a_kernel]
770        )
771
772        stage_conv_b_stride = (
773            1,
774            stage_spatial_h_stride[idx],
775            stage_spatial_w_stride[idx],
776        )
777
778        stage = create_res_stage(
779            depth=depth,
780            dim_in=stage_dim_in,
781            dim_inner=stage_dim_inner,
782            dim_out=stage_dim_out,
783            bottleneck=bottleneck[idx],
784            conv_a_kernel_size=stage_conv_a_kernel,
785            conv_a_stride=stage_conv_a_stride,
786            conv_a_padding=stage_conv_a_padding,
787            conv_b_kernel_size=stage_conv_b_kernel_size[idx],
788            conv_b_stride=stage_conv_b_stride,
789            conv_b_padding=(
790                stage_conv_b_kernel_size[idx][0] // 2,
791                stage_conv_b_dilation[idx][1]
792                if stage_conv_b_dilation[idx][1] > 1
793                else stage_conv_b_kernel_size[idx][1] // 2,
794                stage_conv_b_dilation[idx][2]
795                if stage_conv_b_dilation[idx][2] > 1
796                else stage_conv_b_kernel_size[idx][2] // 2,
797            ),
798            conv_b_num_groups=stage_conv_b_num_groups[idx],
799            conv_b_dilation=stage_conv_b_dilation[idx],
800            norm=norm,
801            activation=activation,
802        )
803
804        blocks.append(stage)
805        stage_dim_in = stage_dim_out
806        stage_dim_out = stage_dim_out * 2
807
808        if idx == 0 and stage1_pool is not None:
809            blocks.append(
810                stage1_pool(
811                    kernel_size=stage1_pool_kernel_size,
812                    stride=stage1_pool_kernel_size,
813                    padding=(0, 0, 0),
814                )
815            )
816    if head is not None:
817        head = head(
818            in_features=stage_dim_in,
819            out_features=model_num_class,
820            pool=head_pool,
821            output_size=head_output_size,
822            pool_kernel_size=head_pool_kernel_size,
823            dropout_rate=dropout_rate,
824            activation=head_activation,
825            output_with_global_average=head_output_with_global_average,
826        )
827        blocks.append(head)
828    return Net(blocks=nn.ModuleList(blocks))
829
830
def create_resnet_with_roi_head(
    *,
    # Input clip configs.
    input_channel: int = 3,
    # Model configs.
    model_depth: int = 50,
    model_num_class: int = 80,
    dropout_rate: float = 0.5,
    # Normalization configs.
    norm: Callable = nn.BatchNorm3d,
    # Activation configs.
    activation: Callable = nn.ReLU,
    # Stem configs.
    stem_dim_out: int = 64,
    stem_conv_kernel_size: Tuple[int] = (1, 7, 7),
    stem_conv_stride: Tuple[int] = (1, 2, 2),
    stem_pool: Callable = nn.MaxPool3d,
    stem_pool_kernel_size: Tuple[int] = (1, 3, 3),
    stem_pool_stride: Tuple[int] = (1, 2, 2),
    stem: Callable = create_res_basic_stem,
    # Stage configs.
    stage1_pool: Callable = None,
    stage1_pool_kernel_size: Tuple[int] = (2, 1, 1),
    stage_conv_a_kernel_size: Union[Tuple[int], Tuple[Tuple[int]]] = (
        (1, 1, 1),
        (1, 1, 1),
        (3, 1, 1),
        (3, 1, 1),
    ),
    stage_conv_b_kernel_size: Union[Tuple[int], Tuple[Tuple[int]]] = (
        (1, 3, 3),
        (1, 3, 3),
        (1, 3, 3),
        (1, 3, 3),
    ),
    stage_conv_b_num_groups: Tuple[int] = (1, 1, 1, 1),
    stage_conv_b_dilation: Union[Tuple[int], Tuple[Tuple[int]]] = (
        (1, 1, 1),
        (1, 1, 1),
        (1, 1, 1),
        (1, 2, 2),
    ),
    stage_spatial_h_stride: Tuple[int] = (1, 2, 2, 1),
    stage_spatial_w_stride: Tuple[int] = (1, 2, 2, 1),
    stage_temporal_stride: Tuple[int] = (1, 1, 1, 1),
    bottleneck: Union[Tuple[Callable], Callable] = create_bottleneck_block,
    # Head configs.
    head: Callable = create_res_roi_pooling_head,
    head_pool: Callable = nn.AvgPool3d,
    head_pool_kernel_size: Tuple[int] = (4, 1, 1),
    head_output_size: Tuple[int] = (1, 1, 1),
    head_activation: Callable = nn.Sigmoid,
    head_output_with_global_average: bool = False,
    head_spatial_resolution: Tuple[int] = (7, 7),
    head_spatial_scale: float = 1.0 / 16.0,
    head_sampling_ratio: int = 0,
) -> nn.Module:
    """
    Build ResNet style models for video detection. ResNet has three parts:
    Stem, Stages and Head. Stem is the first Convolution layer (Conv1) with an
    optional pooling layer. Stages are grouped residual blocks. There are usually
    multiple stages and each stage may include multiple residual blocks. Head
    may include pooling, dropout, a fully-connected layer and global spatial
    temporal averaging. The three parts are assembled in the following order:

    ::

                            Input Clip    Input Bounding Boxes
                              ↓                       ↓
                            Stem                      ↓
                              ↓                       ↓
                            Stage 1                   ↓
                              ↓                       ↓
                              .                       ↓
                              .                       ↓
                              .                       ↓
                              ↓                       ↓
                            Stage N                   ↓
                              ↓--------> Head <-------↓

    Args:

        input_channel (int): number of channels for the input video clip.

        model_depth (int): the depth of the resnet. Options include: 50, 101, 152.
        model_num_class (int): the number of classes for the video dataset.
        dropout_rate (float): dropout rate.


        norm (callable): a callable that constructs normalization layer.

        activation (callable): a callable that constructs activation layer.

        stem_dim_out (int): output channel size to stem.
        stem_conv_kernel_size (tuple): convolutional kernel size(s) of stem.
        stem_conv_stride (tuple): convolutional stride size(s) of stem.
        stem_pool (callable): a callable that constructs the stem pooling layer.
        stem_pool_kernel_size (tuple): pooling kernel size(s).
        stem_pool_stride (tuple): pooling stride size(s).
        stem (callable): a callable that constructs stem layer.
            Examples include: create_res_basic_stem.

        stage1_pool (callable): a callable that constructs a pooling layer that is
            inserted right after the first stage. None to skip it.
        stage1_pool_kernel_size (tuple): kernel size of stage1_pool; it is also
            used as that layer's stride.
        stage_conv_a_kernel_size (tuple): convolutional kernel size(s) for conv_a.
            A single kernel size is broadcast to every stage.
        stage_conv_b_kernel_size (tuple): convolutional kernel size(s) for conv_b.
            A single kernel size is broadcast to every stage.
        stage_conv_b_num_groups (tuple): number of groups for groupwise convolution
            for conv_b. 1 for ResNet, and larger than 1 for ResNeXt.
        stage_conv_b_dilation (tuple): dilation for 3D convolution for conv_b.
            A single dilation is broadcast to every stage.
        stage_spatial_h_stride (tuple): the spatial height stride for each stage.
        stage_spatial_w_stride (tuple): the spatial width stride for each stage.
        stage_temporal_stride (tuple): the temporal stride for each stage.
        bottleneck (callable): a callable that constructs bottleneck block layer,
            or a tuple with one such callable per stage.
            Examples include: create_bottleneck_block.

        head (callable): a callable that constructs the detection head which can
            take in the additional input of bounding boxes.
            Ex: create_res_roi_pooling_head
        head_pool (callable): a callable that constructs resnet head pooling layer.
        head_pool_kernel_size (tuple): the pooling kernel size.
        head_output_size (tuple): the size of output tensor for head.
        head_activation (callable): a callable that constructs activation layer.
        head_output_with_global_average (bool): if True, perform global averaging on
            the head output.
        head_spatial_resolution (tuple): h, w sizes of the RoI interpolation.
        head_spatial_scale (float): scale the input boxes by this number.
            NOTE: the default 1/16 presumably mirrors the default overall spatial
            stride (stem conv 2 x stem pool 2 x stages 1, 2, 2, 1); revisit it
            if those strides are changed.
        head_sampling_ratio (int): number of inputs samples to take for each output
                sample interpolation. 0 to take samples densely.

    Returns:
        (nn.Module): a DetectionBBoxNetwork combining the ResNet backbone (built
            without a classification head) and the detection head.
    """

    # Backbone only: the detection head needs extra constructor arguments, so it
    # is built separately below instead of by create_resnet.
    model = create_resnet(
        # Input clip configs.
        input_channel=input_channel,
        # Model configs.
        model_depth=model_depth,
        model_num_class=model_num_class,
        dropout_rate=dropout_rate,
        # Normalization configs.
        norm=norm,
        # Activation configs.
        activation=activation,
        # Stem configs.
        stem_dim_out=stem_dim_out,
        stem_conv_kernel_size=stem_conv_kernel_size,
        stem_conv_stride=stem_conv_stride,
        stem_pool=stem_pool,
        stem_pool_kernel_size=stem_pool_kernel_size,
        stem_pool_stride=stem_pool_stride,
        stem=stem,
        # Stage configs.
        stage1_pool=stage1_pool,
        stage1_pool_kernel_size=stage1_pool_kernel_size,
        stage_conv_a_kernel_size=stage_conv_a_kernel_size,
        stage_conv_b_kernel_size=stage_conv_b_kernel_size,
        stage_conv_b_num_groups=stage_conv_b_num_groups,
        stage_conv_b_dilation=stage_conv_b_dilation,
        stage_spatial_h_stride=stage_spatial_h_stride,
        stage_spatial_w_stride=stage_spatial_w_stride,
        stage_temporal_stride=stage_temporal_stride,
        bottleneck=bottleneck,
        # Head configs.
        head=None,
    )

    # create_resnet starts its first stage at stem_dim_out * 4 channels and
    # doubles the width after every stage, so the last stage emits
    # stem_dim_out * 2 ** (num_stages + 1) channels.
    num_stages = len(_MODEL_STAGE_DEPTH[model_depth])
    head_in_features = stem_dim_out * 2 ** (num_stages + 1)

    head = head(
        in_features=head_in_features,
        out_features=model_num_class,
        pool=head_pool,
        output_size=head_output_size,
        pool_kernel_size=head_pool_kernel_size,
        dropout_rate=dropout_rate,
        activation=head_activation,
        output_with_global_average=head_output_with_global_average,
        resolution=head_spatial_resolution,
        spatial_scale=head_spatial_scale,
        sampling_ratio=head_sampling_ratio,
    )
    return DetectionBBoxNetwork(model, head)
1007
1008
def create_acoustic_resnet(
    *,
    # Input clip configs.
    input_channel: int = 1,
    # Model configs.
    model_depth: int = 50,
    model_num_class: int = 400,
    dropout_rate: float = 0.5,
    # Normalization configs.
    norm: Callable = nn.BatchNorm3d,
    # Activation configs.
    activation: Callable = nn.ReLU,
    # Stem configs.
    stem_dim_out: int = 64,
    stem_conv_kernel_size: Tuple[int] = (9, 1, 9),
    stem_conv_stride: Tuple[int] = (1, 1, 3),
    stem_pool: Callable = None,
    stem_pool_kernel_size: Tuple[int] = (3, 1, 3),
    stem_pool_stride: Tuple[int] = (2, 1, 2),
    stem: Callable = create_acoustic_res_basic_stem,
    # Stage configs.
    stage1_pool: Callable = None,
    stage1_pool_kernel_size: Tuple[int] = (2, 1, 1),
    stage_conv_a_kernel_size: Union[Tuple[int], Tuple[Tuple[int]]] = (3, 1, 1),
    stage_conv_b_kernel_size: Union[Tuple[int], Tuple[Tuple[int]]] = (3, 1, 3),
    stage_conv_b_num_groups: Tuple[int] = (1, 1, 1, 1),
    stage_conv_b_dilation: Union[Tuple[int], Tuple[Tuple[int]]] = (1, 1, 1),
    stage_spatial_h_stride: Tuple[int] = (1, 1, 1, 1),
    stage_spatial_w_stride: Tuple[int] = (1, 2, 2, 2),
    stage_temporal_stride: Tuple[int] = (1, 2, 2, 2),
    bottleneck: Union[Tuple[Callable], Callable] = (
        create_acoustic_bottleneck_block,
        create_acoustic_bottleneck_block,
        create_bottleneck_block,
        create_bottleneck_block,
    ),
    # Head configs.
    head_pool: Callable = nn.AvgPool3d,
    head_pool_kernel_size: Tuple[int] = (4, 1, 2),
    head_output_size: Tuple[int] = (1, 1, 1),
    head_activation: Callable = None,
    head_output_with_global_average: bool = True,
) -> nn.Module:
    """
    Build ResNet style models for acoustic recognition. ResNet has three parts:
    Stem, Stages and Head. Stem is the first Convolution layer (Conv1) with an
    optional pooling layer. Stages are grouped residual blocks. There are usually
    multiple stages and each stage may include multiple residual blocks. Head
    may include pooling, dropout, a fully-connected layer and global spatial
    temporal averaging. The three parts are assembled in the following order:

    ::

                                         Input
                                           ↓
                                         Stem
                                           ↓
                                         Stage 1
                                           ↓
                                           .
                                           .
                                           .
                                           ↓
                                         Stage N
                                           ↓
                                         Head

    All arguments are forwarded unchanged to create_resnet. No ``head`` argument
    is passed, so create_resnet's default head constructor is used.

    Args:

        input_channel (int): number of channels for the input spectrogram.

        model_depth (int): the depth of the resnet. Options include: 50, 101, 152.
        model_num_class (int): the number of classes for the dataset.
        dropout_rate (float): dropout rate.


        norm (callable): a callable that constructs normalization layer.

        activation (callable): a callable that constructs activation layer.

        stem_dim_out (int): output channel size to stem.
        stem_conv_kernel_size (tuple): convolutional kernel size(s) of stem.
        stem_conv_stride (tuple): convolutional stride size(s) of stem.
        stem_pool (callable): a callable that constructs the stem pooling layer.
            None (the default) for no stem pooling.
        stem_pool_kernel_size (tuple): pooling kernel size(s).
        stem_pool_stride (tuple): pooling stride size(s).
        stem (callable): a callable that constructs stem layer.
            Examples include: create_acoustic_res_basic_stem.

        stage1_pool (callable): a callable that constructs a pooling layer that is
            inserted right after the first stage. None to skip it.
        stage1_pool_kernel_size (tuple): kernel size (and stride) of stage1_pool.
        stage_conv_a_kernel_size (tuple): convolutional kernel size(s) for conv_a.
            A single kernel size is broadcast to every stage.
        stage_conv_b_kernel_size (tuple): convolutional kernel size(s) for conv_b.
            A single kernel size is broadcast to every stage.
        stage_conv_b_num_groups (tuple): number of groups for groupwise convolution
            for conv_b. 1 for ResNet, and larger than 1 for ResNeXt.
        stage_conv_b_dilation (tuple): dilation for 3D convolution for conv_b.
            A single dilation is broadcast to every stage.
        stage_spatial_h_stride (tuple): the spatial height stride for each stage.
        stage_spatial_w_stride (tuple): the spatial width stride for each stage.
        stage_temporal_stride (tuple): the temporal stride for each stage.
        bottleneck (callable or tuple): a callable that constructs bottleneck block
            layer, or a tuple with one callable per stage. The default uses
            create_acoustic_bottleneck_block for the first two stages and
            create_bottleneck_block for the last two.

        head_pool (callable): a callable that constructs resnet head pooling layer.
        head_pool_kernel_size (tuple): the pooling kernel size.
        head_output_size (tuple): the size of output tensor for head.
        head_activation (callable): a callable that constructs activation layer.
        head_output_with_global_average (bool): if True, perform global averaging on
            the head output.

    Returns:
        (nn.Module): audio resnet, that takes spectrogram image input with
            shape: (B, C, T, 1, F), where T is the time dimension and F is the
            frequency dimension.
    """
    # locals() holds exactly the keyword arguments above at this point, so every
    # argument is forwarded to create_resnet unchanged. Keep this the first
    # statement: any local variable defined before it would be forwarded too.
    return create_resnet(**locals())
1122
1123
class ResBlock(nn.Module):
    """
    Residual block. Performs a summation between an identity shortcut in branch1 and a
    main block in branch2. When the input and output dimensions are different, a
    convolution followed by a normalization will be performed.

    ::


                                         Input
                                           |-------+
                                           ↓       |
                                         Block     |
                                           ↓       |
                                       Summation ←-+
                                           ↓
                                       Activation

    The builder can be found in `create_res_block`.
    """

    def __init__(
        self,
        branch1_conv: nn.Module = None,
        branch1_norm: nn.Module = None,
        branch2: nn.Module = None,
        activation: nn.Module = None,
        branch_fusion: Callable = None,
    ) -> None:
        """
        Args:
            branch1_conv (torch.nn.modules): convolutional module in branch1. None
                keeps the identity shortcut.
            branch1_norm (torch.nn.modules): normalization module in branch1. It is
                only applied when branch1_conv is given.
            branch2 (torch.nn.modules): bottleneck block module in branch2. Required.
            activation (torch.nn.modules): activation module applied after fusion.
                None to skip it.
            branch_fusion: (Callable): A callable or layer that combines branch1
                and branch2 as ``branch_fusion(shortcut, branch2_output)``. When
                None, the two outputs are summed elementwise.
        """
        super().__init__()
        set_attributes(self, locals())
        assert self.branch2 is not None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Branch1: identity shortcut, or a projection (conv + optional norm).
        if self.branch1_conv is None:
            shortcut = x
        else:
            shortcut = self.branch1_conv(x)
            if self.branch1_norm is not None:
                shortcut = self.branch1_norm(shortcut)
        residual = self.branch2(x)

        if self.branch_fusion is None:
            # Fall back to the plain summation this block is documented to do,
            # instead of failing with "'NoneType' object is not callable".
            x = shortcut + residual
        else:
            x = self.branch_fusion(shortcut, residual)

        if self.activation is not None:
            x = self.activation(x)
        return x
1177
1178
class SeparableBottleneckBlock(nn.Module):
    """
    Separable bottleneck block. Like a regular bottleneck, except that the middle
    stage (conv_b, norm_b, act_b) is a collection of parallel paths whose outputs
    are merged by summation or channel concatenation before conv_c:

    ::


                                    Conv3d (conv_a)
                                           ↓
                                 Normalization (norm_a)
                                           ↓
                                   Activation (act_a)
                                           ↓
                                 Conv3d(s) (conv_b), ...
                                         ↓ (↓)
                              Normalization(s) (norm_b), ...
                                         ↓ (↓)
                                 Activation(s) (act_b), ...
                                         ↓ (↓)
                                  Reduce (sum or cat)
                                           ↓
                                    Conv3d (conv_c)
                                           ↓
                                 Normalization (norm_c)
    """

    def __init__(
        self,
        *,
        conv_a: nn.Module,
        norm_a: nn.Module,
        act_a: nn.Module,
        conv_b: nn.ModuleList,
        norm_b: nn.ModuleList,
        act_b: nn.ModuleList,
        conv_c: nn.Module,
        norm_c: nn.Module,
        reduce_method: str = "sum",
    ) -> None:
        """
        Args:
            conv_a (torch.nn.modules): first convolution; may be None to skip it.
            norm_a (torch.nn.modules): normalization after conv_a; may be None.
            act_a (torch.nn.modules): activation after norm_a; may be None.
            conv_b (torch.nn.modules_list): one convolution per parallel path.
            norm_b (torch.nn.modules_list): per-path normalization, indexed like
                conv_b; individual entries may be None.
            act_b (torch.nn.modules_list): per-path activation, indexed like
                conv_b; individual entries may be None.
            conv_c (torch.nn.modules): final convolution (required).
            norm_c (torch.nn.modules): final normalization; may be None.
            reduce_method (str): how the parallel conv_b outputs are merged:
                `sum` (elementwise) or `cat` (along dim 1).
        """
        super().__init__()
        set_attributes(self, locals())
        assert (
            self.conv_b is not None and self.conv_c is not None
        ), f"{self.conv_a}, {self.conv_b}, {self.conv_c} has None"
        assert reduce_method in ("sum", "cat")
        if self.norm_c is not None:
            # Marks the last norm of the block for weight initialization.
            self.norm_c.block_final_bn = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Stage a (e.g. Tx1x1 conv, BN, ReLU): every layer is optional.
        for layer in (self.conv_a, self.norm_a, self.act_a):
            if layer is not None:
                x = layer(x)

        # Stage b (e.g. 1xHxW conv, BN, ReLU), run once per parallel path.
        path_outputs = []
        for path_idx, path_conv in enumerate(self.conv_b):
            y = path_conv(x)
            path_norm = self.norm_b[path_idx]
            if path_norm is not None:
                y = path_norm(y)
            path_act = self.act_b[path_idx]
            if path_act is not None:
                y = path_act(y)
            path_outputs.append(y)

        if self.reduce_method == "sum":
            x = torch.stack(path_outputs, dim=0).sum(dim=0)
        elif self.reduce_method == "cat":
            x = torch.cat(path_outputs, dim=1)

        # Stage c (e.g. 1x1x1 conv, BN): no activation at the end of the block.
        x = self.conv_c(x)
        return x if self.norm_c is None else self.norm_c(x)
1273
1274
class BottleneckBlock(nn.Module):
    """
    Bottleneck block: three convolution stages, each optionally followed by a
    normalization and (for the first two stages) an activation:

    ::


                                    Conv3d (conv_a)
                                           ↓
                                 Normalization (norm_a)
                                           ↓
                                   Activation (act_a)
                                           ↓
                                    Conv3d (conv_b)
                                           ↓
                                 Normalization (norm_b)
                                           ↓
                                   Activation (act_b)
                                           ↓
                                    Conv3d (conv_c)
                                           ↓
                                 Normalization (norm_c)

    The builder can be found in `create_bottleneck_block`.
    """

    def __init__(
        self,
        *,
        conv_a: nn.Module = None,
        norm_a: nn.Module = None,
        act_a: nn.Module = None,
        conv_b: nn.Module = None,
        norm_b: nn.Module = None,
        act_b: nn.Module = None,
        conv_c: nn.Module = None,
        norm_c: nn.Module = None,
    ) -> None:
        """
        Args:
            conv_a (torch.nn.modules): first convolution (required).
            norm_a (torch.nn.modules): normalization after conv_a; may be None.
            act_a (torch.nn.modules): activation after norm_a; may be None.
            conv_b (torch.nn.modules): second convolution (required).
            norm_b (torch.nn.modules): normalization after conv_b; may be None.
            act_b (torch.nn.modules): activation after norm_b; may be None.
            conv_c (torch.nn.modules): third convolution (required).
            norm_c (torch.nn.modules): normalization after conv_c; may be None.
        """
        super().__init__()
        set_attributes(self, locals())
        assert (
            self.conv_a is not None
            and self.conv_b is not None
            and self.conv_c is not None
        )
        if self.norm_c is not None:
            # Marks the last norm of the block for weight initialization.
            self.norm_c.block_final_bn = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (conv, norm, act) per stage: a is e.g. Tx1x1, b is e.g. 1xHxW and
        # c is e.g. 1x1x1. The last stage intentionally has no activation.
        stages = (
            (self.conv_a, self.norm_a, self.act_a),
            (self.conv_b, self.norm_b, self.act_b),
            (self.conv_c, self.norm_c, None),
        )
        for conv, norm, act in stages:
            x = conv(x)
            if norm is not None:
                x = norm(x)
            if act is not None:
                x = act(x)
        return x
1353
1354
class ResStage(nn.Module):
    """
    ResStage composes sequential blocks that make up a ResNet. These blocks could be,
    for example, Residual blocks, Non-Local layers, or Squeeze-Excitation layers.

    ::


                                        Input
                                           ↓
                                       ResBlock
                                           ↓
                                           .
                                           .
                                           .
                                           ↓
                                       ResBlock

    The builder can be found in `create_res_stage`.
    """

    def __init__(self, res_blocks: nn.ModuleList) -> None:
        """
        Args:
            res_blocks (torch.nn.module_list): ResBlock module(s), applied in the
                order they appear. An empty list makes the stage an identity.
        """
        super().__init__()
        self.res_blocks = res_blocks

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Feed the output of each block into the next one.
        for res_block in self.res_blocks:
            x = res_block(x)
        return x
1388