# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import math
from typing import Callable, Tuple

import numpy as np
import torch
import torch.nn as nn
from fvcore.nn.squeeze_excitation import SqueezeExcitation
from pytorchvideo.layers.convolutions import Conv2plus1d
from pytorchvideo.layers.swish import Swish
from pytorchvideo.layers.utils import round_repeats, round_width, set_attributes
from pytorchvideo.models.head import ResNetBasicHead
from pytorchvideo.models.net import Net
from pytorchvideo.models.resnet import BottleneckBlock, ResBlock, ResStage
from pytorchvideo.models.stem import ResNetBasicStem


def create_x3d_stem(
    *,
    # Conv configs.
    in_channels: int,
    out_channels: int,
    conv_kernel_size: Tuple[int] = (5, 3, 3),
    conv_stride: Tuple[int] = (1, 2, 2),
    conv_padding: Tuple[int] = (2, 1, 1),
    # BN configs.
    norm: Callable = nn.BatchNorm3d,
    norm_eps: float = 1e-5,
    norm_momentum: float = 0.1,
    # Activation configs.
    activation: Callable = nn.ReLU,
) -> nn.Module:
    """
    Creates the stem layer for X3D. It performs spatial Conv, temporal Conv, BN, and ReLU.

    ::

                                        Conv_xy
                                           ↓
                                        Conv_t
                                           ↓
                                     Normalization
                                           ↓
                                       Activation

    Args:
        in_channels (int): input channel size of the convolution.
        out_channels (int): output channel size of the convolution.
        conv_kernel_size (tuple): convolutional kernel size(s), ordered (T, H, W).
        conv_stride (tuple): convolutional stride size(s), ordered (T, H, W).
        conv_padding (tuple): convolutional padding size(s), ordered (T, H, W).

        norm (callable): a callable that constructs normalization layer, options
            include nn.BatchNorm3d, None (not performing normalization).
        norm_eps (float): normalization epsilon.
        norm_momentum (float): normalization momentum.

        activation (callable): a callable that constructs activation layer, options
            include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None (not performing
            activation).

    Returns:
        (nn.Module): X3D stem layer.
    """
    # Spatial-only conv: kernel/stride/padding take the H and W entries; T is 1/1/0.
    conv_xy_module = nn.Conv3d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=(1, conv_kernel_size[1], conv_kernel_size[2]),
        stride=(1, conv_stride[1], conv_stride[2]),
        padding=(0, conv_padding[1], conv_padding[2]),
        bias=False,
    )
    # Temporal-only conv, depthwise (groups == channels).
    conv_t_module = nn.Conv3d(
        in_channels=out_channels,
        out_channels=out_channels,
        kernel_size=(conv_kernel_size[0], 1, 1),
        stride=(conv_stride[0], 1, 1),
        padding=(conv_padding[0], 0, 0),
        bias=False,
        groups=out_channels,
    )
    # NOTE: the keyword names are intentionally crossed. Assuming Conv2plus1d
    # runs conv_t before conv_xy, this makes the spatial conv execute first,
    # as in the diagram above (confirm against Conv2plus1d.forward).
    stacked_conv_module = Conv2plus1d(
        conv_t=conv_xy_module,
        norm=None,
        activation=None,
        conv_xy=conv_t_module,
    )

    norm_module = (
        None
        if norm is None
        else norm(num_features=out_channels, eps=norm_eps, momentum=norm_momentum)
    )
    activation_module = None if activation is None else activation()

    return ResNetBasicStem(
        conv=stacked_conv_module,
        norm=norm_module,
        activation=activation_module,
        pool=None,
    )


def create_x3d_bottleneck_block(
    *,
    # Convolution configs.
    dim_in: int,
    dim_inner: int,
    dim_out: int,
    conv_kernel_size: Tuple[int] = (3, 3, 3),
    conv_stride: Tuple[int] = (1, 2, 2),
    # Norm configs.
    norm: Callable = nn.BatchNorm3d,
    norm_eps: float = 1e-5,
    norm_momentum: float = 0.1,
    se_ratio: float = 0.0625,
    # Activation configs.
    activation: Callable = nn.ReLU,
    inner_act: Callable = Swish,
) -> nn.Module:
    """
    Bottleneck block for X3D: a sequence of Conv, Normalization with optional SE block,
    and Activations repeated in the following order:

    ::

                                    Conv3d (conv_a)
                                           ↓
                                 Normalization (norm_a)
                                           ↓
                                   Activation (act_a)
                                           ↓
                                    Conv3d (conv_b)
                                           ↓
                                 Normalization (norm_b)
                                           ↓
                                 Squeeze-and-Excitation
                                           ↓
                                   Activation (act_b)
                                           ↓
                                    Conv3d (conv_c)
                                           ↓
                                 Normalization (norm_c)

    Args:
        dim_in (int): input channel size to the bottleneck block.
        dim_inner (int): intermediate channel size of the bottleneck.
        dim_out (int): output channel size of the bottleneck.
        conv_kernel_size (tuple): convolutional kernel size(s) for conv_b.
        conv_stride (tuple): convolutional stride size(s) for conv_b.

        norm (callable): a callable that constructs normalization layer, examples
            include nn.BatchNorm3d, None (not performing normalization).
        norm_eps (float): normalization epsilon.
        norm_momentum (float): normalization momentum.
        se_ratio (float): if > 0, apply SE to the 3x3x3 conv, with the SE
            channel dimensionality being se_ratio times the 3x3x3 conv dim.

        activation (callable): a callable that constructs activation layer, examples
            include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None (not performing
            activation).
        inner_act (callable): a callable that constructs the activation used for
            act_b (Swish by default); None disables act_b.

    Returns:
        (nn.Module): X3D bottleneck block.
    """
    # 1x1x1 Conv
    conv_a = nn.Conv3d(
        in_channels=dim_in, out_channels=dim_inner, kernel_size=(1, 1, 1), bias=False
    )
    norm_a = (
        None
        if norm is None
        else norm(num_features=dim_inner, eps=norm_eps, momentum=norm_momentum)
    )
    act_a = None if activation is None else activation()

    # 3x3x3 Conv: depthwise (groups == dim_inner), "same"-style padding of k // 2.
    conv_b = nn.Conv3d(
        in_channels=dim_inner,
        out_channels=dim_inner,
        kernel_size=conv_kernel_size,
        stride=conv_stride,
        padding=[size // 2 for size in conv_kernel_size],
        bias=False,
        groups=dim_inner,
        dilation=(1, 1, 1),
    )
    se = (
        SqueezeExcitation(
            num_channels=dim_inner,
            num_channels_reduced=round_width(dim_inner, se_ratio),
            is_3d=True,
        )
        if se_ratio > 0.0
        else nn.Identity()
    )
    # norm_b bundles the normalization and the (optional) SE module so the
    # bottleneck's norm_b slot applies both, in that order.
    norm_b = nn.Sequential(
        (
            nn.Identity()
            if norm is None
            else norm(num_features=dim_inner, eps=norm_eps, momentum=norm_momentum)
        ),
        se,
    )
    act_b = None if inner_act is None else inner_act()

    # 1x1x1 Conv
    conv_c = nn.Conv3d(
        in_channels=dim_inner, out_channels=dim_out, kernel_size=(1, 1, 1), bias=False
    )
    norm_c = (
        None
        if norm is None
        else norm(num_features=dim_out, eps=norm_eps, momentum=norm_momentum)
    )

    return BottleneckBlock(
        conv_a=conv_a,
        norm_a=norm_a,
        act_a=act_a,
        conv_b=conv_b,
        norm_b=norm_b,
        act_b=act_b,
        conv_c=conv_c,
        norm_c=norm_c,
    )


def create_x3d_res_block(
    *,
    # Bottleneck Block configs.
    dim_in: int,
    dim_inner: int,
    dim_out: int,
    bottleneck: Callable = create_x3d_bottleneck_block,
    use_shortcut: bool = True,
    # Conv configs.
    conv_kernel_size: Tuple[int] = (3, 3, 3),
    conv_stride: Tuple[int] = (1, 2, 2),
    # Norm configs.
    norm: Callable = nn.BatchNorm3d,
    norm_eps: float = 1e-5,
    norm_momentum: float = 0.1,
    se_ratio: float = 0.0625,
    # Activation configs.
    activation: Callable = nn.ReLU,
    inner_act: Callable = Swish,
) -> nn.Module:
    """
    Residual block for X3D. Performs a summation between an identity shortcut in branch1 and a
    main block in branch2. When the input and output dimensions are different, a
    convolution followed by a normalization will be performed.

    ::


                                         Input
                                           |-------+
                                           ↓       |
                                         Block     |
                                           ↓       |
                                       Summation ←-+
                                           ↓
                                       Activation

    Args:
        dim_in (int): input channel size to the bottleneck block.
        dim_inner (int): intermediate channel size of the bottleneck.
        dim_out (int): output channel size of the bottleneck.
        bottleneck (callable): a callable for create_x3d_bottleneck_block.
        use_shortcut (bool): if False, the branch1 projection conv and norm are
            never built.

        conv_kernel_size (tuple): convolutional kernel size(s) for conv_b.
        conv_stride (tuple): convolutional stride size(s) for conv_b.

        norm (callable): a callable that constructs normalization layer, examples
            include nn.BatchNorm3d, None (not performing normalization).
        norm_eps (float): normalization epsilon.
        norm_momentum (float): normalization momentum.
        se_ratio (float): if > 0, apply SE to the 3x3x3 conv, with the SE
            channel dimensionality being se_ratio times the 3x3x3 conv dim.

        activation (callable): a callable that constructs activation layer, examples
            include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None (not performing
            activation).
        inner_act (callable): a callable that constructs the activation used for
            act_b of the bottleneck (Swish by default); None disables it.

    Returns:
        (nn.Module): X3D block layer.
    """

    # The shortcut norm honors norm_eps / norm_momentum like every other norm
    # built in this file (previously they were silently dropped here).
    norm_model = None
    if norm is not None and dim_in != dim_out:
        norm_model = norm(num_features=dim_out, eps=norm_eps, momentum=norm_momentum)

    # A projection conv is built when channels change OR the block is strided;
    # the shortcut norm only when channels change. This asymmetry is kept
    # as-is because changing it would alter the module structure.
    return ResBlock(
        branch1_conv=nn.Conv3d(
            dim_in,
            dim_out,
            kernel_size=(1, 1, 1),
            stride=conv_stride,
            bias=False,
        )
        if (dim_in != dim_out or np.prod(conv_stride) > 1) and use_shortcut
        else None,
        branch1_norm=norm_model if dim_in != dim_out and use_shortcut else None,
        branch2=bottleneck(
            dim_in=dim_in,
            dim_inner=dim_inner,
            dim_out=dim_out,
            conv_kernel_size=conv_kernel_size,
            conv_stride=conv_stride,
            norm=norm,
            norm_eps=norm_eps,
            norm_momentum=norm_momentum,
            se_ratio=se_ratio,
            activation=activation,
            inner_act=inner_act,
        ),
        activation=None if activation is None else activation(),
        branch_fusion=lambda x, y: x + y,
    )


def create_x3d_res_stage(
    *,
    # Stage configs.
    depth: int,
    # Bottleneck Block configs.
    dim_in: int,
    dim_inner: int,
    dim_out: int,
    bottleneck: Callable = create_x3d_bottleneck_block,
    # Conv configs.
    conv_kernel_size: Tuple[int] = (3, 3, 3),
    conv_stride: Tuple[int] = (1, 2, 2),
    # Norm configs.
    norm: Callable = nn.BatchNorm3d,
    norm_eps: float = 1e-5,
    norm_momentum: float = 0.1,
    se_ratio: float = 0.0625,
    # Activation configs.
    activation: Callable = nn.ReLU,
    inner_act: Callable = Swish,
) -> nn.Module:
    """
    Create Residual Stage, which composes sequential blocks that make up X3D.

    ::

                                        Input
                                           ↓
                                       ResBlock
                                           ↓
                                           .
                                           .
                                           .
                                           ↓
                                       ResBlock

    Args:

        depth (int): number of blocks to create.

        dim_in (int): input channel size to the bottleneck block.
        dim_inner (int): intermediate channel size of the bottleneck.
        dim_out (int): output channel size of the bottleneck.
        bottleneck (callable): a callable for create_x3d_bottleneck_block.

        conv_kernel_size (tuple): convolutional kernel size(s) for conv_b.
        conv_stride (tuple): convolutional stride size(s) for conv_b. Only the
            first block of the stage is strided; the rest use (1, 1, 1).

        norm (callable): a callable that constructs normalization layer, examples
            include nn.BatchNorm3d, None (not performing normalization).
        norm_eps (float): normalization epsilon.
        norm_momentum (float): normalization momentum.
        se_ratio (float): if > 0, apply SE to the 3x3x3 conv, with the SE
            channel dimensionality being se_ratio times the 3x3x3 conv dim.
            SE is only used in every other block (block indices 0, 2, 4, ...).

        activation (callable): a callable that constructs activation layer, examples
            include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None (not performing
            activation).
        inner_act (callable): a callable that constructs the activation used for
            act_b of each bottleneck (Swish by default); None disables it.

    Returns:
        (nn.Module): X3D stage layer.
    """
    res_blocks = []
    for idx in range(depth):
        block = create_x3d_res_block(
            dim_in=dim_in if idx == 0 else dim_out,
            dim_inner=dim_inner,
            dim_out=dim_out,
            bottleneck=bottleneck,
            conv_kernel_size=conv_kernel_size,
            conv_stride=conv_stride if idx == 0 else (1, 1, 1),
            norm=norm,
            norm_eps=norm_eps,
            norm_momentum=norm_momentum,
            # (idx + 1) % 2 is truthy for even idx, so SE lands on blocks 0, 2, 4, ...
            se_ratio=(se_ratio if (idx + 1) % 2 else 0.0),
            activation=activation,
            inner_act=inner_act,
        )
        res_blocks.append(block)

    return ResStage(res_blocks=nn.ModuleList(res_blocks))


def create_x3d_head(
    *,
    # Projection configs.
    dim_in: int,
    dim_inner: int,
    dim_out: int,
    num_classes: int,
    # Pooling configs.
    pool_act: Callable = nn.ReLU,
    pool_kernel_size: Tuple[int] = (13, 5, 5),
    # BN configs.
    norm: Callable = nn.BatchNorm3d,
    norm_eps: float = 1e-5,
    norm_momentum: float = 0.1,
    bn_lin5_on: bool = False,
    # Dropout configs.
    dropout_rate: float = 0.5,
    # Activation configs.
    activation: Callable = nn.Softmax,
    # Output configs.
    output_with_global_average: bool = True,
) -> nn.Module:
    """
    Creates X3D head. This layer performs a projected pooling operation followed
    by a dropout, a fully-connected projection, an activation layer and a global
    spatiotemporal averaging.

    ::

                                     ProjectedPool
                                           ↓
                                        Dropout
                                           ↓
                                       Projection
                                           ↓
                                       Activation
                                           ↓
                                       Averaging

    Args:
        dim_in (int): input channel size of the X3D head.
        dim_inner (int): intermediate channel size of the X3D head.
        dim_out (int): output channel size of the X3D head.
        num_classes (int): the number of classes for the video dataset.

        pool_act (callable): a callable that constructs resnet pool activation
            layer such as nn.ReLU; None disables both pre- and post-pool activations.
        pool_kernel_size (tuple): pooling kernel size(s) when not using adaptive
            pooling; None selects adaptive average pooling to (1, 1, 1).

        norm (callable): a callable that constructs normalization layer, examples
            include nn.BatchNorm3d, None (not performing normalization).
        norm_eps (float): normalization epsilon.
        norm_momentum (float): normalization momentum.
        bn_lin5_on (bool): if True, perform normalization on the features
            before the classifier (ignored when norm is None).

        dropout_rate (float): dropout rate; 0 disables dropout.

        activation (callable): a callable that constructs resnet head activation
            layer; only nn.Softmax (applied over dim 1), nn.Sigmoid and None (not
            applying activation) are supported.

        output_with_global_average (bool): if True, perform global averaging on temporal
            and spatial dimensions and reshape output to batch_size x out_features.

    Returns:
        (nn.Module): X3D head layer.

    Raises:
        NotImplementedError: if activation is not one of the supported options.
    """
    pre_conv_module = nn.Conv3d(
        in_channels=dim_in, out_channels=dim_inner, kernel_size=(1, 1, 1), bias=False
    )

    # Honor norm=None as documented; ProjectedPool skips a None pre_norm.
    pre_norm_module = (
        None
        if norm is None
        else norm(num_features=dim_inner, eps=norm_eps, momentum=norm_momentum)
    )
    pre_act_module = None if pool_act is None else pool_act()

    if pool_kernel_size is None:
        pool_module = nn.AdaptiveAvgPool3d((1, 1, 1))
    else:
        pool_module = nn.AvgPool3d(pool_kernel_size, stride=1)

    post_conv_module = nn.Conv3d(
        in_channels=dim_inner, out_channels=dim_out, kernel_size=(1, 1, 1), bias=False
    )

    if bn_lin5_on and norm is not None:
        post_norm_module = norm(
            num_features=dim_out, eps=norm_eps, momentum=norm_momentum
        )
    else:
        post_norm_module = None
    post_act_module = None if pool_act is None else pool_act()

    projected_pool_module = ProjectedPool(
        pre_conv=pre_conv_module,
        pre_norm=pre_norm_module,
        pre_act=pre_act_module,
        pool=pool_module,
        post_conv=post_conv_module,
        post_norm=post_norm_module,
        post_act=post_act_module,
    )

    if activation is None:
        activation_module = None
    elif activation == nn.Softmax:
        activation_module = activation(dim=1)
    elif activation == nn.Sigmoid:
        activation_module = activation()
    else:
        raise NotImplementedError(
            "{} is not supported as an activation function.".format(activation)
        )

    if output_with_global_average:
        output_pool = nn.AdaptiveAvgPool3d(1)
    else:
        output_pool = None

    return ResNetBasicHead(
        proj=nn.Linear(dim_out, num_classes, bias=True),
        activation=activation_module,
        pool=projected_pool_module,
        dropout=nn.Dropout(dropout_rate) if dropout_rate > 0 else None,
        output_pool=output_pool,
    )


def create_x3d(
    *,
    # Input clip configs.
    input_channel: int = 3,
    input_clip_length: int = 13,
    input_crop_size: int = 160,
    # Model configs.
    model_num_class: int = 400,
    dropout_rate: float = 0.5,
    width_factor: float = 2.0,
    depth_factor: float = 2.2,
    # Normalization configs.
    norm: Callable = nn.BatchNorm3d,
    norm_eps: float = 1e-5,
    norm_momentum: float = 0.1,
    # Activation configs.
    activation: Callable = nn.ReLU,
    # Stem configs.
    stem_dim_in: int = 12,
    stem_conv_kernel_size: Tuple[int] = (5, 3, 3),
    stem_conv_stride: Tuple[int] = (1, 2, 2),
    # Stage configs.
    stage_conv_kernel_size: Tuple[Tuple[int]] = (
        (3, 3, 3),
        (3, 3, 3),
        (3, 3, 3),
        (3, 3, 3),
    ),
    stage_spatial_stride: Tuple[int] = (2, 2, 2, 2),
    stage_temporal_stride: Tuple[int] = (1, 1, 1, 1),
    bottleneck: Callable = create_x3d_bottleneck_block,
    bottleneck_factor: float = 2.25,
    se_ratio: float = 0.0625,
    inner_act: Callable = Swish,
    # Head configs.
    head_dim_out: int = 2048,
    head_pool_act: Callable = nn.ReLU,
    head_bn_lin5_on: bool = False,
    head_activation: Callable = nn.Softmax,
    head_output_with_global_average: bool = True,
) -> nn.Module:
    """
    X3D model builder. It builds an X3D network backbone, which is a ResNet.

    Christoph Feichtenhofer.
    "X3D: Expanding Architectures for Efficient Video Recognition."
    https://arxiv.org/abs/2004.04730

    ::

                                         Input
                                           ↓
                                         Stem
                                           ↓
                                         Stage 1
                                           ↓
                                           .
                                           .
                                           .
                                           ↓
                                         Stage N
                                           ↓
                                         Head

    Args:
        input_channel (int): number of channels for the input video clip.
        input_clip_length (int): length of the input video clip. Value for
            different models: X3D-XS: 4; X3D-S: 13; X3D-M: 16; X3D-L: 16.
        input_crop_size (int): spatial resolution of the input video clip.
            Value for different models: X3D-XS: 160; X3D-S: 160; X3D-M: 224;
            X3D-L: 312.

        model_num_class (int): the number of classes for the video dataset.
        dropout_rate (float): dropout rate.
        width_factor (float): width expansion factor.
        depth_factor (float): depth expansion factor. Value for different
            models: X3D-XS: 2.2; X3D-S: 2.2; X3D-M: 2.2; X3D-L: 5.0.

        norm (callable): a callable that constructs normalization layer.
        norm_eps (float): normalization epsilon.
        norm_momentum (float): normalization momentum.

        activation (callable): a callable that constructs activation layer.

        stem_dim_in (int): input channel size for stem before expansion.
        stem_conv_kernel_size (tuple): convolutional kernel size(s) of stem.
        stem_conv_stride (tuple): convolutional stride size(s) of stem.

        stage_conv_kernel_size (tuple): convolutional kernel size(s) for conv_b.
        stage_spatial_stride (tuple): the spatial stride for each stage.
        stage_temporal_stride (tuple): the temporal stride for each stage.
        bottleneck_factor (float): bottleneck expansion factor for the 3x3x3 conv.
        se_ratio (float): if > 0, apply SE to the 3x3x3 conv, with the SE
            channel dimensionality being se_ratio times the 3x3x3 conv dim.
        inner_act (callable): a callable that constructs the activation used for
            act_b of each bottleneck (Swish by default); None disables it.

        head_dim_out (int): output channel size of the X3D head.
        head_pool_act (callable): a callable that constructs resnet pool activation
            layer such as nn.ReLU.
        head_bn_lin5_on (bool): if True, perform normalization on the features
            before the classifier.
        head_activation (callable): a callable that constructs activation layer.
        head_output_with_global_average (bool): if True, perform global averaging on
            the head output.

    Returns:
        (nn.Module): the X3D network.
    """

    torch._C._log_api_usage_once("PYTORCHVIDEO.model.create_x3d")

    blocks = []
    # Create stem for X3D.
    stem_dim_out = round_width(stem_dim_in, width_factor)
    stem = create_x3d_stem(
        in_channels=input_channel,
        out_channels=stem_dim_out,
        conv_kernel_size=stem_conv_kernel_size,
        conv_stride=stem_conv_stride,
        conv_padding=[size // 2 for size in stem_conv_kernel_size],
        norm=norm,
        norm_eps=norm_eps,
        norm_momentum=norm_momentum,
        activation=activation,
    )
    blocks.append(stem)

    # Compute the depth and dimension for each stage
    stage_depths = [1, 2, 5, 3]
    exp_stage = 2.0
    stage_dim1 = stem_dim_in
    stage_dim2 = round_width(stage_dim1, exp_stage, divisor=8)
    stage_dim3 = round_width(stage_dim2, exp_stage, divisor=8)
    stage_dim4 = round_width(stage_dim3, exp_stage, divisor=8)
    stage_dims = [stage_dim1, stage_dim2, stage_dim3, stage_dim4]

    dim_in = stem_dim_out
    # Create each stage for X3D.
    for idx in range(len(stage_depths)):
        dim_out = round_width(stage_dims[idx], width_factor)
        dim_inner = int(bottleneck_factor * dim_out)
        depth = round_repeats(stage_depths[idx], depth_factor)

        stage_conv_stride = (
            stage_temporal_stride[idx],
            stage_spatial_stride[idx],
            stage_spatial_stride[idx],
        )

        stage = create_x3d_res_stage(
            depth=depth,
            dim_in=dim_in,
            dim_inner=dim_inner,
            dim_out=dim_out,
            bottleneck=bottleneck,
            conv_kernel_size=stage_conv_kernel_size[idx],
            conv_stride=stage_conv_stride,
            norm=norm,
            norm_eps=norm_eps,
            norm_momentum=norm_momentum,
            se_ratio=se_ratio,
            activation=activation,
            inner_act=inner_act,
        )
        blocks.append(stage)
        dim_in = dim_out

    # Create head for X3D.
    total_spatial_stride = stem_conv_stride[1] * np.prod(stage_spatial_stride)
    total_temporal_stride = stem_conv_stride[0] * np.prod(stage_temporal_stride)

    assert (
        input_clip_length >= total_temporal_stride
    ), "Clip length doesn't match temporal stride!"
    assert (
        input_crop_size >= total_spatial_stride
    ), "Crop size doesn't match spatial stride!"

    # Pool kernel sized to the final feature map: temporal uses floor division,
    # spatial uses ceil.
    head_pool_kernel_size = (
        input_clip_length // total_temporal_stride,
        int(math.ceil(input_crop_size / total_spatial_stride)),
        int(math.ceil(input_crop_size / total_spatial_stride)),
    )

    # dim_out / dim_inner here are the values from the last stage iteration.
    head = create_x3d_head(
        dim_in=dim_out,
        dim_inner=dim_inner,
        dim_out=head_dim_out,
        num_classes=model_num_class,
        pool_act=head_pool_act,
        pool_kernel_size=head_pool_kernel_size,
        norm=norm,
        norm_eps=norm_eps,
        norm_momentum=norm_momentum,
        bn_lin5_on=head_bn_lin5_on,
        dropout_rate=dropout_rate,
        activation=head_activation,
        output_with_global_average=head_output_with_global_average,
    )
    blocks.append(head)
    return Net(blocks=nn.ModuleList(blocks))


class ProjectedPool(nn.Module):
    """
    A pooling module augmented with Conv, Normalization and Activation both
    before and after pooling for the head layer of X3D.

    ::

                                    Conv3d (pre_conv)
                                           ↓
                                 Normalization (pre_norm)
                                           ↓
                                   Activation (pre_act)
                                           ↓
                                        Pool3d
                                           ↓
                                    Conv3d (post_conv)
                                           ↓
                                 Normalization (post_norm)
                                           ↓
                                   Activation (post_act)
    """

    def __init__(
        self,
        *,
        pre_conv: nn.Module = None,
        pre_norm: nn.Module = None,
        pre_act: nn.Module = None,
        pool: nn.Module = None,
        post_conv: nn.Module = None,
        post_norm: nn.Module = None,
        post_act: nn.Module = None,
    ) -> None:
        """
        Args:
            pre_conv (torch.nn.modules): convolutional module (required).
            pre_norm (torch.nn.modules): normalization module, or None to skip.
            pre_act (torch.nn.modules): activation module, or None to skip.
            pool (torch.nn.modules): pooling module (required).
            post_conv (torch.nn.modules): convolutional module (required).
            post_norm (torch.nn.modules): normalization module, or None to skip.
            post_act (torch.nn.modules): activation module, or None to skip.
        """
        super().__init__()
        # Stores every constructor argument as an attribute of the same name.
        set_attributes(self, locals())
        assert self.pre_conv is not None
        assert self.pool is not None
        assert self.post_conv is not None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pre_conv(x)

        if self.pre_norm is not None:
            x = self.pre_norm(x)
        if self.pre_act is not None:
            x = self.pre_act(x)

        x = self.pool(x)
        x = self.post_conv(x)

        if self.post_norm is not None:
            x = self.post_norm(x)
        if self.post_act is not None:
            x = self.post_act(x)
        return x