1# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 3from typing import Callable, List, Tuple, Union 4 5import numpy as np 6import torch 7import torch.nn as nn 8from pytorchvideo.layers.utils import set_attributes 9from pytorchvideo.models.head import create_res_basic_head, create_res_roi_pooling_head 10from pytorchvideo.models.net import Net, DetectionBBoxNetwork 11from pytorchvideo.models.stem import ( 12 create_acoustic_res_basic_stem, 13 create_res_basic_stem, 14) 15 16 17def create_bottleneck_block( 18 *, 19 # Convolution configs. 20 dim_in: int, 21 dim_inner: int, 22 dim_out: int, 23 conv_a_kernel_size: Tuple[int] = (3, 1, 1), 24 conv_a_stride: Tuple[int] = (2, 1, 1), 25 conv_a_padding: Tuple[int] = (1, 0, 0), 26 conv_a: Callable = nn.Conv3d, 27 conv_b_kernel_size: Tuple[int] = (1, 3, 3), 28 conv_b_stride: Tuple[int] = (1, 2, 2), 29 conv_b_padding: Tuple[int] = (0, 1, 1), 30 conv_b_num_groups: int = 1, 31 conv_b_dilation: Tuple[int] = (1, 1, 1), 32 conv_b: Callable = nn.Conv3d, 33 conv_c: Callable = nn.Conv3d, 34 # Norm configs. 35 norm: Callable = nn.BatchNorm3d, 36 norm_eps: float = 1e-5, 37 norm_momentum: float = 0.1, 38 # Activation configs. 39 activation: Callable = nn.ReLU, 40) -> nn.Module: 41 """ 42 Bottleneck block: a sequence of spatiotemporal Convolution, Normalization, 43 and Activations repeated in the following order: 44 45 :: 46 47 Conv3d (conv_a) 48 ↓ 49 Normalization (norm_a) 50 ↓ 51 Activation (act_a) 52 ↓ 53 Conv3d (conv_b) 54 ↓ 55 Normalization (norm_b) 56 ↓ 57 Activation (act_b) 58 ↓ 59 Conv3d (conv_c) 60 ↓ 61 Normalization (norm_c) 62 63 Normalization examples include: BatchNorm3d and None (no normalization). 64 Activation examples include: ReLU, Softmax, Sigmoid, and None (no activation). 65 66 Args: 67 dim_in (int): input channel size to the bottleneck block. 68 dim_inner (int): intermediate channel size of the bottleneck. 69 dim_out (int): output channel size of the bottleneck. 70 conv_a_kernel_size (tuple): convolutional kernel size(s) for conv_a. 71 conv_a_stride (tuple): convolutional stride size(s) for conv_a. 72 conv_a_padding (tuple): convolutional padding(s) for conv_a. 73 conv_a (callable): a callable that constructs the conv_a conv layer, examples 74 include nn.Conv3d, OctaveConv, etc 75 conv_b_kernel_size (tuple): convolutional kernel size(s) for conv_b. 76 conv_b_stride (tuple): convolutional stride size(s) for conv_b. 77 conv_b_padding (tuple): convolutional padding(s) for conv_b. 78 conv_b_num_groups (int): number of groups for groupwise convolution for 79 conv_b. 80 conv_b_dilation (tuple): dilation for 3D convolution for conv_b. 81 conv_b (callable): a callable that constructs the conv_b conv layer, examples 82 include nn.Conv3d, OctaveConv, etc 83 conv_c (callable): a callable that constructs the conv_c conv layer, examples 84 include nn.Conv3d, OctaveConv, etc 85 86 norm (callable): a callable that constructs normalization layer, examples 87 include nn.BatchNorm3d, None (not performing normalization). 88 norm_eps (float): normalization epsilon. 89 norm_momentum (float): normalization momentum. 90 91 activation (callable): a callable that constructs activation layer, examples 92 include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None (not performing 93 activation). 94 95 Returns: 96 (nn.Module): resnet bottleneck block. 97 """ 98 conv_a = conv_a( 99 in_channels=dim_in, 100 out_channels=dim_inner, 101 kernel_size=conv_a_kernel_size, 102 stride=conv_a_stride, 103 padding=conv_a_padding, 104 bias=False, 105 ) 106 norm_a = ( 107 None 108 if norm is None 109 else norm(num_features=dim_inner, eps=norm_eps, momentum=norm_momentum) 110 ) 111 act_a = None if activation is None else activation() 112 113 conv_b = conv_b( 114 in_channels=dim_inner, 115 out_channels=dim_inner, 116 kernel_size=conv_b_kernel_size, 117 stride=conv_b_stride, 118 padding=conv_b_padding, 119 bias=False, 120 groups=conv_b_num_groups, 121 dilation=conv_b_dilation, 122 ) 123 norm_b = ( 124 None 125 if norm is None 126 else norm(num_features=dim_inner, eps=norm_eps, momentum=norm_momentum) 127 ) 128 act_b = None if activation is None else activation() 129 130 conv_c = conv_c( 131 in_channels=dim_inner, out_channels=dim_out, kernel_size=(1, 1, 1), bias=False 132 ) 133 norm_c = ( 134 None 135 if norm is None 136 else norm(num_features=dim_out, eps=norm_eps, momentum=norm_momentum) 137 ) 138 139 return BottleneckBlock( 140 conv_a=conv_a, 141 norm_a=norm_a, 142 act_a=act_a, 143 conv_b=conv_b, 144 norm_b=norm_b, 145 act_b=act_b, 146 conv_c=conv_c, 147 norm_c=norm_c, 148 ) 149 150 151def create_acoustic_bottleneck_block( 152 *, 153 # Convolution configs. 154 dim_in: int, 155 dim_inner: int, 156 dim_out: int, 157 conv_a_kernel_size: Tuple[int] = (3, 1, 1), 158 conv_a_stride: Tuple[int] = (2, 1, 1), 159 conv_a_padding: Tuple[int] = (1, 0, 0), 160 conv_a: Callable = nn.Conv3d, 161 # Conv b f configs. 162 conv_b_kernel_size: Tuple[int] = (1, 1, 1), 163 conv_b_stride: Tuple[int] = (1, 1, 1), 164 conv_b_padding: Tuple[int] = (0, 0, 0), 165 conv_b_num_groups: int = 1, 166 conv_b_dilation: Tuple[int] = (1, 1, 1), 167 conv_b: Callable = nn.Conv3d, 168 conv_c: Callable = nn.Conv3d, 169 # Norm configs. 170 norm: Callable = nn.BatchNorm3d, 171 norm_eps: float = 1e-5, 172 norm_momentum: float = 0.1, 173 # Activation configs. 174 activation: Callable = nn.ReLU, 175) -> nn.Module: 176 """ 177 Acoustic Bottleneck block: a sequence of spatiotemporal Convolution, Normalization, 178 and Activations repeated in the following order: 179 180 :: 181 182 Conv3d (conv_a) 183 ↓ 184 Normalization (norm_a) 185 ↓ 186 Activation (act_a) 187 ↓ 188 --------------------------------- 189 ↓ ↓ 190 Temporal Conv3d (conv_b) Spatial Conv3d (conv_b) 191 ↓ ↓ 192 Normalization (norm_b) Normalization (norm_b) 193 ↓ ↓ 194 Activation (act_b) Activation (act_b) 195 ↓ ↓ 196 --------------------------------- 197 ↓ 198 Conv3d (conv_c) 199 ↓ 200 Normalization (norm_c) 201 202 Normalization examples include: BatchNorm3d and None (no normalization). 203 Activation examples include: ReLU, Softmax, Sigmoid, and None (no activation). 204 205 Args: 206 dim_in (int): input channel size to the bottleneck block. 207 dim_inner (int): intermediate channel size of the bottleneck. 208 dim_out (int): output channel size of the bottleneck. 209 conv_a_kernel_size (tuple): convolutional kernel size(s) for conv_a. 210 conv_a_stride (tuple): convolutional stride size(s) for conv_a. 211 conv_a_padding (tuple): convolutional padding(s) for conv_a. 212 conv_a (callable): a callable that constructs the conv_a conv layer, examples 213 include nn.Conv3d, OctaveConv, etc 214 conv_b_kernel_size (tuple): convolutional kernel size(s) for conv_b. 215 conv_b_stride (tuple): convolutional stride size(s) for conv_b. 216 conv_b_padding (tuple): convolutional padding(s) for conv_b. 217 conv_b_num_groups (int): number of groups for groupwise convolution for 218 conv_b. 219 conv_b_dilation (tuple): dilation for 3D convolution for conv_b. 220 conv_b (callable): a callable that constructs the conv_b conv layer, examples 221 include nn.Conv3d, OctaveConv, etc 222 conv_c (callable): a callable that constructs the conv_c conv layer, examples 223 include nn.Conv3d, OctaveConv, etc 224 225 norm (callable): a callable that constructs normalization layer, examples 226 include nn.BatchNorm3d, None (not performing normalization). 227 norm_eps (float): normalization epsilon. 228 norm_momentum (float): normalization momentum. 229 230 activation (callable): a callable that constructs activation layer, examples 231 include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None (not performing 232 activation). 233 234 Returns: 235 (nn.Module): resnet acoustic bottleneck block. 236 """ 237 conv_a = conv_a( 238 in_channels=dim_in, 239 out_channels=dim_inner, 240 kernel_size=conv_a_kernel_size, 241 stride=conv_a_stride, 242 padding=conv_a_padding, 243 bias=False, 244 ) 245 norm_a = ( 246 None 247 if norm is None 248 else norm(num_features=dim_inner, eps=norm_eps, momentum=norm_momentum) 249 ) 250 act_a = None if activation is None else activation() 251 252 conv_b_1_kernel_size = [conv_b_kernel_size[0], 1, 1] 253 conv_b_1_stride = conv_b_stride 254 conv_b_1_padding = [conv_b_padding[0], 0, 0] 255 256 conv_b_2_kernel_size = [1, conv_b_kernel_size[1], conv_b_kernel_size[2]] 257 conv_b_2_stride = conv_b_stride 258 conv_b_2_padding = [0, conv_b_padding[1], conv_b_padding[2]] 259 260 conv_b_1_num_groups, conv_b_2_num_groups = (conv_b_num_groups,) * 2 261 conv_b_1_dilation = [conv_b_dilation[0], 1, 1] 262 conv_b_2_dilation = [1, conv_b_dilation[1], conv_b_dilation[2]] 263 264 conv_b_1 = conv_b( 265 in_channels=dim_inner, 266 out_channels=dim_inner, 267 kernel_size=conv_b_1_kernel_size, 268 stride=conv_b_1_stride, 269 padding=conv_b_1_padding, 270 bias=False, 271 groups=conv_b_1_num_groups, 272 dilation=conv_b_1_dilation, 273 ) 274 norm_b_1 = ( 275 None 276 if norm is None 277 else norm(num_features=dim_inner, eps=norm_eps, momentum=norm_momentum) 278 ) 279 act_b_1 = None if activation is None else activation() 280 281 conv_b_2 = conv_b( 282 in_channels=dim_inner, 283 out_channels=dim_inner, 284 kernel_size=conv_b_2_kernel_size, 285 stride=conv_b_2_stride, 286 padding=conv_b_2_padding, 287 bias=False, 288 groups=conv_b_2_num_groups, 289 dilation=conv_b_2_dilation, 290 ) 291 norm_b_2 = ( 292 None 293 if norm is None 294 else norm(num_features=dim_inner, eps=norm_eps, momentum=norm_momentum) 295 ) 296 act_b_2 = None if activation is None else activation() 297 298 conv_c = conv_c( 299 in_channels=dim_inner, out_channels=dim_out, kernel_size=(1, 1, 1), bias=False 300 ) 301 norm_c = ( 302 None 303 if norm is None 304 else norm(num_features=dim_out, eps=norm_eps, momentum=norm_momentum) 305 ) 306 307 return SeparableBottleneckBlock( 308 conv_a=conv_a, 309 norm_a=norm_a, 310 act_a=act_a, 311 conv_b=nn.ModuleList([conv_b_2, conv_b_1]), 312 norm_b=nn.ModuleList([norm_b_2, norm_b_1]), 313 act_b=nn.ModuleList([act_b_2, act_b_1]), 314 conv_c=conv_c, 315 norm_c=norm_c, 316 ) 317 318 319def create_res_block( 320 *, 321 # Bottleneck Block configs. 322 dim_in: int, 323 dim_inner: int, 324 dim_out: int, 325 bottleneck: Callable, 326 use_shortcut: bool = False, 327 branch_fusion: Callable = lambda x, y: x + y, 328 # Conv configs. 329 conv_a_kernel_size: Tuple[int] = (3, 1, 1), 330 conv_a_stride: Tuple[int] = (2, 1, 1), 331 conv_a_padding: Tuple[int] = (1, 0, 0), 332 conv_a: Callable = nn.Conv3d, 333 conv_b_kernel_size: Tuple[int] = (1, 3, 3), 334 conv_b_stride: Tuple[int] = (1, 2, 2), 335 conv_b_padding: Tuple[int] = (0, 1, 1), 336 conv_b_num_groups: int = 1, 337 conv_b_dilation: Tuple[int] = (1, 1, 1), 338 conv_b: Callable = nn.Conv3d, 339 conv_c: Callable = nn.Conv3d, 340 conv_skip: Callable = nn.Conv3d, 341 # Norm configs. 342 norm: Callable = nn.BatchNorm3d, 343 norm_eps: float = 1e-5, 344 norm_momentum: float = 0.1, 345 # Activation configs. 346 activation_bottleneck: Callable = nn.ReLU, 347 activation_block: Callable = nn.ReLU, 348) -> nn.Module: 349 """ 350 Residual block. Performs a summation between an identity shortcut in branch1 and a 351 main block in branch2. When the input and output dimensions are different, a 352 convolution followed by a normalization will be performed. 353 354 :: 355 356 357 Input 358 |-------+ 359 ↓ | 360 Block | 361 ↓ | 362 Summation ←-+ 363 ↓ 364 Activation 365 366 Normalization examples include: BatchNorm3d and None (no normalization). 367 Activation examples include: ReLU, Softmax, Sigmoid, and None (no activation). 368 Transform examples include: BottleneckBlock. 369 370 Args: 371 dim_in (int): input channel size to the bottleneck block. 372 dim_inner (int): intermediate channel size of the bottleneck. 373 dim_out (int): output channel size of the bottleneck. 374 bottleneck (callable): a callable that constructs bottleneck block layer. 375 Examples include: create_bottleneck_block. 376 use_shortcut (bool): If true, use conv and norm layers in skip connection. 377 branch_fusion (callable): a callable that constructs summation layer. 378 Examples include: lambda x, y: x + y, OctaveSum. 379 380 conv_a_kernel_size (tuple): convolutional kernel size(s) for conv_a. 381 conv_a_stride (tuple): convolutional stride size(s) for conv_a. 382 conv_a_padding (tuple): convolutional padding(s) for conv_a. 383 conv_a (callable): a callable that constructs the conv_a conv layer, examples 384 include nn.Conv3d, OctaveConv, etc 385 conv_b_kernel_size (tuple): convolutional kernel size(s) for conv_b. 386 conv_b_stride (tuple): convolutional stride size(s) for conv_b. 387 conv_b_padding (tuple): convolutional padding(s) for conv_b. 388 conv_b_num_groups (int): number of groups for groupwise convolution for 389 conv_b. 390 conv_b_dilation (tuple): dilation for 3D convolution for conv_b. 391 conv_b (callable): a callable that constructs the conv_b conv layer, examples 392 include nn.Conv3d, OctaveConv, etc 393 conv_c (callable): a callable that constructs the conv_c conv layer, examples 394 include nn.Conv3d, OctaveConv, etc 395 conv_skip (callable): a callable that constructs the conv_skip conv layer, 396 examples include nn.Conv3d, OctaveConv, etc 397 398 norm (callable): a callable that constructs normalization layer. Examples 399 include nn.BatchNorm3d, None (not performing normalization). 400 norm_eps (float): normalization epsilon. 401 norm_momentum (float): normalization momentum. 402 403 activation_bottleneck (callable): a callable that constructs activation layer in 404 bottleneck. Examples include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None 405 (not performing activation). 406 activation_block (callable): a callable that constructs activation layer used 407 at the end of the block. Examples include: nn.ReLU, nn.Softmax, nn.Sigmoid, 408 and None (not performing activation). 409 410 Returns: 411 (nn.Module): resnet basic block layer. 412 """ 413 branch1_conv_stride = tuple(map(np.prod, zip(conv_a_stride, conv_b_stride))) 414 norm_model = None 415 if use_shortcut or ( 416 norm is not None and (dim_in != dim_out or np.prod(branch1_conv_stride) != 1) 417 ): 418 norm_model = norm(num_features=dim_out, eps=norm_eps, momentum=norm_momentum) 419 420 return ResBlock( 421 branch1_conv=conv_skip( 422 dim_in, 423 dim_out, 424 kernel_size=(1, 1, 1), 425 stride=branch1_conv_stride, 426 bias=False, 427 ) 428 if (dim_in != dim_out or np.prod(branch1_conv_stride) != 1) or use_shortcut 429 else None, 430 branch1_norm=norm_model, 431 branch2=bottleneck( 432 dim_in=dim_in, 433 dim_inner=dim_inner, 434 dim_out=dim_out, 435 conv_a_kernel_size=conv_a_kernel_size, 436 conv_a_stride=conv_a_stride, 437 conv_a_padding=conv_a_padding, 438 conv_a=conv_a, 439 conv_b_kernel_size=conv_b_kernel_size, 440 conv_b_stride=conv_b_stride, 441 conv_b_padding=conv_b_padding, 442 conv_b_num_groups=conv_b_num_groups, 443 conv_b_dilation=conv_b_dilation, 444 conv_b=conv_b, 445 conv_c=conv_c, 446 norm=norm, 447 norm_eps=norm_eps, 448 norm_momentum=norm_momentum, 449 activation=activation_bottleneck, 450 ), 451 activation=None if activation_block is None else activation_block(), 452 branch_fusion=branch_fusion, 453 ) 454 455 456def create_res_stage( 457 *, 458 # Stage configs. 459 depth: int, 460 # Bottleneck Block configs. 461 dim_in: int, 462 dim_inner: int, 463 dim_out: int, 464 bottleneck: Callable, 465 # Conv configs. 466 conv_a_kernel_size: Union[Tuple[int], List[Tuple[int]]] = (3, 1, 1), 467 conv_a_stride: Tuple[int] = (2, 1, 1), 468 conv_a_padding: Union[Tuple[int], List[Tuple[int]]] = (1, 0, 0), 469 conv_a: Callable = nn.Conv3d, 470 conv_b_kernel_size: Tuple[int] = (1, 3, 3), 471 conv_b_stride: Tuple[int] = (1, 2, 2), 472 conv_b_padding: Tuple[int] = (0, 1, 1), 473 conv_b_num_groups: int = 1, 474 conv_b_dilation: Tuple[int] = (1, 1, 1), 475 conv_b: Callable = nn.Conv3d, 476 conv_c: Callable = nn.Conv3d, 477 # Norm configs. 478 norm: Callable = nn.BatchNorm3d, 479 norm_eps: float = 1e-5, 480 norm_momentum: float = 0.1, 481 # Activation configs. 482 activation: Callable = nn.ReLU, 483) -> nn.Module: 484 """ 485 Create Residual Stage, which composes sequential blocks that make up a ResNet. These 486 blocks could be, for example, Residual blocks, Non-Local layers, or 487 Squeeze-Excitation layers. 488 489 :: 490 491 492 Input 493 ↓ 494 ResBlock 495 ↓ 496 . 497 . 498 . 499 ↓ 500 ResBlock 501 502 Normalization examples include: BatchNorm3d and None (no normalization). 503 Activation examples include: ReLU, Softmax, Sigmoid, and None (no activation). 504 Bottleneck examples include: create_bottleneck_block. 505 506 Args: 507 depth (init): number of blocks to create. 508 509 dim_in (int): input channel size to the bottleneck block. 510 dim_inner (int): intermediate channel size of the bottleneck. 511 dim_out (int): output channel size of the bottleneck. 512 bottleneck (callable): a callable that constructs bottleneck block layer. 513 Examples include: create_bottleneck_block. 514 515 conv_a_kernel_size (tuple or list of tuple): convolutional kernel size(s) 516 for conv_a. If conv_a_kernel_size is a tuple, use it for all blocks in 517 the stage. If conv_a_kernel_size is a list of tuple, the kernel sizes 518 will be repeated until having same length of depth in the stage. For 519 example, for conv_a_kernel_size = [(3, 1, 1), (1, 1, 1)], the kernel 520 size for the first 6 blocks would be [(3, 1, 1), (1, 1, 1), (3, 1, 1), 521 (1, 1, 1), (3, 1, 1)]. 522 conv_a_stride (tuple): convolutional stride size(s) for conv_a. 523 conv_a_padding (tuple or list of tuple): convolutional padding(s) for 524 conv_a. If conv_a_padding is a tuple, use it for all blocks in 525 the stage. If conv_a_padding is a list of tuple, the padding sizes 526 will be repeated until having same length of depth in the stage. 527 conv_a (callable): a callable that constructs the conv_a conv layer, examples 528 include nn.Conv3d, OctaveConv, etc 529 conv_b_kernel_size (tuple): convolutional kernel size(s) for conv_b. 530 conv_b_stride (tuple): convolutional stride size(s) for conv_b. 531 conv_b_padding (tuple): convolutional padding(s) for conv_b. 532 conv_b_num_groups (int): number of groups for groupwise convolution for 533 conv_b. 534 conv_b_dilation (tuple): dilation for 3D convolution for conv_b. 535 conv_b (callable): a callable that constructs the conv_b conv layer, examples 536 include nn.Conv3d, OctaveConv, etc 537 conv_c (callable): a callable that constructs the conv_c conv layer, examples 538 include nn.Conv3d, OctaveConv, etc 539 540 norm (callable): a callable that constructs normalization layer. Examples 541 include nn.BatchNorm3d, and None (not performing normalization). 542 norm_eps (float): normalization epsilon. 543 norm_momentum (float): normalization momentum. 544 545 activation (callable): a callable that constructs activation layer. Examples 546 include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None (not performing 547 activation). 548 549 Returns: 550 (nn.Module): resnet basic stage layer. 551 """ 552 res_blocks = [] 553 if isinstance(conv_a_kernel_size[0], int): 554 conv_a_kernel_size = [conv_a_kernel_size] 555 if isinstance(conv_a_padding[0], int): 556 conv_a_padding = [conv_a_padding] 557 # Repeat conv_a kernels until having same length of depth in the stage. 558 conv_a_kernel_size = (conv_a_kernel_size * depth)[:depth] 559 conv_a_padding = (conv_a_padding * depth)[:depth] 560 561 for ind in range(depth): 562 block = create_res_block( 563 dim_in=dim_in if ind == 0 else dim_out, 564 dim_inner=dim_inner, 565 dim_out=dim_out, 566 bottleneck=bottleneck, 567 conv_a_kernel_size=conv_a_kernel_size[ind], 568 conv_a_stride=conv_a_stride if ind == 0 else (1, 1, 1), 569 conv_a_padding=conv_a_padding[ind], 570 conv_a=conv_a, 571 conv_b_kernel_size=conv_b_kernel_size, 572 conv_b_stride=conv_b_stride if ind == 0 else (1, 1, 1), 573 conv_b_padding=conv_b_padding, 574 conv_b_num_groups=conv_b_num_groups, 575 conv_b_dilation=conv_b_dilation, 576 conv_b=conv_b, 577 conv_c=conv_c, 578 norm=norm, 579 norm_eps=norm_eps, 580 norm_momentum=norm_momentum, 581 activation_bottleneck=activation, 582 activation_block=activation, 583 ) 584 res_blocks.append(block) 585 return ResStage(res_blocks=nn.ModuleList(res_blocks)) 586 587 588# Number of blocks for different stages given the model depth. 589_MODEL_STAGE_DEPTH = {50: (3, 4, 6, 3), 101: (3, 4, 23, 3), 152: (3, 8, 36, 3)} 590 591 592def create_resnet( 593 *, 594 # Input clip configs. 595 input_channel: int = 3, 596 # Model configs. 597 model_depth: int = 50, 598 model_num_class: int = 400, 599 dropout_rate: float = 0.5, 600 # Normalization configs. 601 norm: Callable = nn.BatchNorm3d, 602 # Activation configs. 603 activation: Callable = nn.ReLU, 604 # Stem configs. 605 stem_dim_out: int = 64, 606 stem_conv_kernel_size: Tuple[int] = (3, 7, 7), 607 stem_conv_stride: Tuple[int] = (1, 2, 2), 608 stem_pool: Callable = nn.MaxPool3d, 609 stem_pool_kernel_size: Tuple[int] = (1, 3, 3), 610 stem_pool_stride: Tuple[int] = (1, 2, 2), 611 stem: Callable = create_res_basic_stem, 612 # Stage configs. 613 stage1_pool: Callable = None, 614 stage1_pool_kernel_size: Tuple[int] = (2, 1, 1), 615 stage_conv_a_kernel_size: Union[Tuple[int], Tuple[Tuple[int]]] = ( 616 (1, 1, 1), 617 (1, 1, 1), 618 (3, 1, 1), 619 (3, 1, 1), 620 ), 621 stage_conv_b_kernel_size: Union[Tuple[int], Tuple[Tuple[int]]] = ( 622 (1, 3, 3), 623 (1, 3, 3), 624 (1, 3, 3), 625 (1, 3, 3), 626 ), 627 stage_conv_b_num_groups: Tuple[int] = (1, 1, 1, 1), 628 stage_conv_b_dilation: Union[Tuple[int], Tuple[Tuple[int]]] = ( 629 (1, 1, 1), 630 (1, 1, 1), 631 (1, 1, 1), 632 (1, 1, 1), 633 ), 634 stage_spatial_h_stride: Tuple[int] = (1, 2, 2, 2), 635 stage_spatial_w_stride: Tuple[int] = (1, 2, 2, 2), 636 stage_temporal_stride: Tuple[int] = (1, 1, 1, 1), 637 bottleneck: Union[Tuple[Callable], Callable] = create_bottleneck_block, 638 # Head configs. 639 head: Callable = create_res_basic_head, 640 head_pool: Callable = nn.AvgPool3d, 641 head_pool_kernel_size: Tuple[int] = (4, 7, 7), 642 head_output_size: Tuple[int] = (1, 1, 1), 643 head_activation: Callable = None, 644 head_output_with_global_average: bool = True, 645) -> nn.Module: 646 """ 647 Build ResNet style models for video recognition. ResNet has three parts: 648 Stem, Stages and Head. Stem is the first Convolution layer (Conv1) with an 649 optional pooling layer. Stages are grouped residual blocks. There are usually 650 multiple stages and each stage may include multiple residual blocks. Head 651 may include pooling, dropout, a fully-connected layer and global spatial 652 temporal averaging. The three parts are assembled in the following order: 653 654 :: 655 656 Input 657 ↓ 658 Stem 659 ↓ 660 Stage 1 661 ↓ 662 . 663 . 664 . 665 ↓ 666 Stage N 667 ↓ 668 Head 669 670 Args: 671 672 input_channel (int): number of channels for the input video clip. 673 674 model_depth (int): the depth of the resnet. Options include: 50, 101, 152. 675 model_num_class (int): the number of classes for the video dataset. 676 dropout_rate (float): dropout rate. 677 678 679 norm (callable): a callable that constructs normalization layer. 680 681 activation (callable): a callable that constructs activation layer. 682 683 stem_dim_out (int): output channel size to stem. 684 stem_conv_kernel_size (tuple): convolutional kernel size(s) of stem. 685 stem_conv_stride (tuple): convolutional stride size(s) of stem. 686 stem_pool (callable): a callable that constructs resnet head pooling layer. 687 stem_pool_kernel_size (tuple): pooling kernel size(s). 688 stem_pool_stride (tuple): pooling stride size(s). 689 stem (callable): a callable that constructs stem layer. 690 Examples include: create_res_video_stem. 691 692 stage_conv_a_kernel_size (tuple): convolutional kernel size(s) for conv_a. 693 stage_conv_b_kernel_size (tuple): convolutional kernel size(s) for conv_b. 694 stage_conv_b_num_groups (tuple): number of groups for groupwise convolution 695 for conv_b. 1 for ResNet, and larger than 1 for ResNeXt. 696 stage_conv_b_dilation (tuple): dilation for 3D convolution for conv_b. 697 stage_spatial_h_stride (tuple): the spatial height stride for each stage. 698 stage_spatial_w_stride (tuple): the spatial width stride for each stage. 699 stage_temporal_stride (tuple): the temporal stride for each stage. 700 bottleneck (callable): a callable that constructs bottleneck block layer. 701 Examples include: create_bottleneck_block. 702 703 head (callable): a callable that constructs the resnet-style head. 704 Ex: create_res_basic_head 705 head_pool (callable): a callable that constructs resnet head pooling layer. 706 head_pool_kernel_size (tuple): the pooling kernel size. 707 head_output_size (tuple): the size of output tensor for head. 708 head_activation (callable): a callable that constructs activation layer. 709 head_output_with_global_average (bool): if True, perform global averaging on 710 the head output. 711 712 Returns: 713 (nn.Module): basic resnet. 714 """ 715 716 torch._C._log_api_usage_once("PYTORCHVIDEO.model.create_resnet") 717 718 # Given a model depth, get the number of blocks for each stage. 719 assert ( 720 model_depth in _MODEL_STAGE_DEPTH.keys() 721 ), f"{model_depth} is not in {_MODEL_STAGE_DEPTH.keys()}" 722 stage_depths = _MODEL_STAGE_DEPTH[model_depth] 723 724 # Broadcast single element to tuple if given. 725 if isinstance(stage_conv_a_kernel_size[0], int): 726 stage_conv_a_kernel_size = (stage_conv_a_kernel_size,) * len(stage_depths) 727 728 if isinstance(stage_conv_b_kernel_size[0], int): 729 stage_conv_b_kernel_size = (stage_conv_b_kernel_size,) * len(stage_depths) 730 731 if isinstance(stage_conv_b_dilation[0], int): 732 stage_conv_b_dilation = (stage_conv_b_dilation,) * len(stage_depths) 733 734 if isinstance(bottleneck, Callable): 735 bottleneck = [ 736 bottleneck, 737 ] * len(stage_depths) 738 739 blocks = [] 740 # Create stem for resnet. 741 stem = stem( 742 in_channels=input_channel, 743 out_channels=stem_dim_out, 744 conv_kernel_size=stem_conv_kernel_size, 745 conv_stride=stem_conv_stride, 746 conv_padding=[size // 2 for size in stem_conv_kernel_size], 747 pool=stem_pool, 748 pool_kernel_size=stem_pool_kernel_size, 749 pool_stride=stem_pool_stride, 750 pool_padding=[size // 2 for size in stem_pool_kernel_size], 751 norm=norm, 752 activation=activation, 753 ) 754 blocks.append(stem) 755 756 stage_dim_in = stem_dim_out 757 stage_dim_out = stage_dim_in * 4 758 759 # Create each stage for resnet. 760 for idx in range(len(stage_depths)): 761 stage_dim_inner = stage_dim_out // 4 762 depth = stage_depths[idx] 763 764 stage_conv_a_kernel = stage_conv_a_kernel_size[idx] 765 stage_conv_a_stride = (stage_temporal_stride[idx], 1, 1) 766 stage_conv_a_padding = ( 767 [size // 2 for size in stage_conv_a_kernel] 768 if isinstance(stage_conv_a_kernel[0], int) 769 else [[size // 2 for size in sizes] for sizes in stage_conv_a_kernel] 770 ) 771 772 stage_conv_b_stride = ( 773 1, 774 stage_spatial_h_stride[idx], 775 stage_spatial_w_stride[idx], 776 ) 777 778 stage = create_res_stage( 779 depth=depth, 780 dim_in=stage_dim_in, 781 dim_inner=stage_dim_inner, 782 dim_out=stage_dim_out, 783 bottleneck=bottleneck[idx], 784 conv_a_kernel_size=stage_conv_a_kernel, 785 conv_a_stride=stage_conv_a_stride, 786 conv_a_padding=stage_conv_a_padding, 787 conv_b_kernel_size=stage_conv_b_kernel_size[idx], 788 conv_b_stride=stage_conv_b_stride, 789 conv_b_padding=( 790 stage_conv_b_kernel_size[idx][0] // 2, 791 stage_conv_b_dilation[idx][1] 792 if stage_conv_b_dilation[idx][1] > 1 793 else stage_conv_b_kernel_size[idx][1] // 2, 794 stage_conv_b_dilation[idx][2] 795 if stage_conv_b_dilation[idx][2] > 1 796 else stage_conv_b_kernel_size[idx][2] // 2, 797 ), 798 conv_b_num_groups=stage_conv_b_num_groups[idx], 799 conv_b_dilation=stage_conv_b_dilation[idx], 800 norm=norm, 801 activation=activation, 802 ) 803 804 blocks.append(stage) 805 stage_dim_in = stage_dim_out 806 stage_dim_out = stage_dim_out * 2 807 808 if idx == 0 and stage1_pool is not None: 809 blocks.append( 810 stage1_pool( 811 kernel_size=stage1_pool_kernel_size, 812 stride=stage1_pool_kernel_size, 813 padding=(0, 0, 0), 814 ) 815 ) 816 if head is not None: 817 head = head( 818 in_features=stage_dim_in, 819 out_features=model_num_class, 820 pool=head_pool, 821 output_size=head_output_size, 822 pool_kernel_size=head_pool_kernel_size, 823 dropout_rate=dropout_rate, 824 activation=head_activation, 825 output_with_global_average=head_output_with_global_average, 826 ) 827 blocks.append(head) 828 return Net(blocks=nn.ModuleList(blocks)) 829 830 831def create_resnet_with_roi_head( 832 *, 833 # Input clip configs. 834 input_channel: int = 3, 835 # Model configs. 836 model_depth: int = 50, 837 model_num_class: int = 80, 838 dropout_rate: float = 0.5, 839 # Normalization configs. 840 norm: Callable = nn.BatchNorm3d, 841 # Activation configs. 842 activation: Callable = nn.ReLU, 843 # Stem configs. 844 stem_dim_out: int = 64, 845 stem_conv_kernel_size: Tuple[int] = (1, 7, 7), 846 stem_conv_stride: Tuple[int] = (1, 2, 2), 847 stem_pool: Callable = nn.MaxPool3d, 848 stem_pool_kernel_size: Tuple[int] = (1, 3, 3), 849 stem_pool_stride: Tuple[int] = (1, 2, 2), 850 stem: Callable = create_res_basic_stem, 851 # Stage configs. 852 stage1_pool: Callable = None, 853 stage1_pool_kernel_size: Tuple[int] = (2, 1, 1), 854 stage_conv_a_kernel_size: Union[Tuple[int], Tuple[Tuple[int]]] = ( 855 (1, 1, 1), 856 (1, 1, 1), 857 (3, 1, 1), 858 (3, 1, 1), 859 ), 860 stage_conv_b_kernel_size: Union[Tuple[int], Tuple[Tuple[int]]] = ( 861 (1, 3, 3), 862 (1, 3, 3), 863 (1, 3, 3), 864 (1, 3, 3), 865 ), 866 stage_conv_b_num_groups: Tuple[int] = (1, 1, 1, 1), 867 stage_conv_b_dilation: Union[Tuple[int], Tuple[Tuple[int]]] = ( 868 (1, 1, 1), 869 (1, 1, 1), 870 (1, 1, 1), 871 (1, 2, 2), 872 ), 873 stage_spatial_h_stride: Tuple[int] = (1, 2, 2, 1), 874 stage_spatial_w_stride: Tuple[int] = (1, 2, 2, 1), 875 stage_temporal_stride: Tuple[int] = (1, 1, 1, 1), 876 bottleneck: Union[Tuple[Callable], Callable] = create_bottleneck_block, 877 # Head configs. 878 head: Callable = create_res_roi_pooling_head, 879 head_pool: Callable = nn.AvgPool3d, 880 head_pool_kernel_size: Tuple[int] = (4, 1, 1), 881 head_output_size: Tuple[int] = (1, 1, 1), 882 head_activation: Callable = nn.Sigmoid, 883 head_output_with_global_average: bool = False, 884 head_spatial_resolution: Tuple[int] = (7, 7), 885 head_spatial_scale: float = 1.0 / 16.0, 886 head_sampling_ratio: int = 0, 887) -> nn.Module: 888 """ 889 Build ResNet style models for video detection. ResNet has three parts: 890 Stem, Stages and Head. Stem is the first Convolution layer (Conv1) with an 891 optional pooling layer. Stages are grouped residual blocks. There are usually 892 multiple stages and each stage may include multiple residual blocks. Head 893 may include pooling, dropout, a fully-connected layer and global spatial 894 temporal averaging. The three parts are assembled in the following order: 895 896 :: 897 898 Input Clip Input Bounding Boxes 899 ↓ ↓ 900 Stem ↓ 901 ↓ ↓ 902 Stage 1 ↓ 903 ↓ ↓ 904 . ↓ 905 . ↓ 906 . ↓ 907 ↓ ↓ 908 Stage N ↓ 909 ↓--------> Head <-------↓ 910 911 Args: 912 913 input_channel (int): number of channels for the input video clip. 914 915 model_depth (int): the depth of the resnet. Options include: 50, 101, 152. 916 model_num_class (int): the number of classes for the video dataset. 917 dropout_rate (float): dropout rate. 918 919 920 norm (callable): a callable that constructs normalization layer. 921 922 activation (callable): a callable that constructs activation layer. 923 924 stem_dim_out (int): output channel size to stem. 925 stem_conv_kernel_size (tuple): convolutional kernel size(s) of stem. 926 stem_conv_stride (tuple): convolutional stride size(s) of stem. 927 stem_pool (callable): a callable that constructs resnet head pooling layer. 928 stem_pool_kernel_size (tuple): pooling kernel size(s). 929 stem_pool_stride (tuple): pooling stride size(s). 930 stem (callable): a callable that constructs stem layer. 931 Examples include: create_res_video_stem. 932 933 stage_conv_a_kernel_size (tuple): convolutional kernel size(s) for conv_a. 934 stage_conv_b_kernel_size (tuple): convolutional kernel size(s) for conv_b. 935 stage_conv_b_num_groups (tuple): number of groups for groupwise convolution 936 for conv_b. 1 for ResNet, and larger than 1 for ResNeXt. 937 stage_conv_b_dilation (tuple): dilation for 3D convolution for conv_b. 938 stage_spatial_h_stride (tuple): the spatial height stride for each stage. 939 stage_spatial_w_stride (tuple): the spatial width stride for each stage. 940 stage_temporal_stride (tuple): the temporal stride for each stage. 941 bottleneck (callable): a callable that constructs bottleneck block layer. 942 Examples include: create_bottleneck_block. 943 944 head (callable): a callable that constructs the detection head which can 945 take in the additional input of bounding boxes. 946 Ex: create_res_roi_pooling_head 947 head_pool (callable): a callable that constructs resnet head pooling layer. 948 head_pool_kernel_size (tuple): the pooling kernel size. 949 head_output_size (tuple): the size of output tensor for head. 950 head_activation (callable): a callable that constructs activation layer. 951 head_output_with_global_average (bool): if True, perform global averaging on 952 the head output. 953 head_spatial_resolution (tuple): h, w sizes of the RoI interpolation. 954 head_spatial_scale (float): scale the input boxes by this number. 955 head_sampling_ratio (int): number of inputs samples to take for each output 956 sample interpolation. 0 to take samples densely. 957 958 Returns: 959 (nn.Module): basic resnet. 960 """ 961 962 model = create_resnet( 963 # Input clip configs. 964 input_channel=input_channel, 965 # Model configs. 966 model_depth=model_depth, 967 model_num_class=model_num_class, 968 dropout_rate=dropout_rate, 969 # Normalization configs. 970 norm=norm, 971 # Activation configs. 972 activation=activation, 973 # Stem configs. 974 stem_dim_out=stem_dim_out, 975 stem_conv_kernel_size=stem_conv_kernel_size, 976 stem_conv_stride=stem_conv_stride, 977 stem_pool=stem_pool, 978 stem_pool_kernel_size=stem_pool_kernel_size, 979 stem_pool_stride=stem_pool_stride, 980 # Stage configs. 981 stage1_pool=stage1_pool, 982 stage_conv_a_kernel_size=stage_conv_a_kernel_size, 983 stage_conv_b_kernel_size=stage_conv_b_kernel_size, 984 stage_conv_b_num_groups=stage_conv_b_num_groups, 985 stage_conv_b_dilation=stage_conv_b_dilation, 986 stage_spatial_h_stride=stage_spatial_h_stride, 987 stage_spatial_w_stride=stage_spatial_w_stride, 988 stage_temporal_stride=stage_temporal_stride, 989 bottleneck=bottleneck, 990 # Head configs. 991 head=None, 992 ) 993 head = head( 994 in_features=stem_dim_out * 2 ** (len(_MODEL_STAGE_DEPTH[model_depth]) + 1), 995 out_features=model_num_class, 996 pool=head_pool, 997 output_size=head_output_size, 998 pool_kernel_size=head_pool_kernel_size, 999 dropout_rate=dropout_rate, 1000 activation=head_activation, 1001 output_with_global_average=head_output_with_global_average, 1002 resolution=head_spatial_resolution, 1003 spatial_scale=head_spatial_scale, 1004 sampling_ratio=head_sampling_ratio, 1005 ) 1006 return DetectionBBoxNetwork(model, head) 1007 1008 1009def create_acoustic_resnet( 1010 *, 1011 # Input clip configs. 1012 input_channel: int = 1, 1013 # Model configs. 1014 model_depth: int = 50, 1015 model_num_class: int = 400, 1016 dropout_rate: float = 0.5, 1017 # Normalization configs. 1018 norm: Callable = nn.BatchNorm3d, 1019 # Activation configs. 1020 activation: Callable = nn.ReLU, 1021 # Stem configs. 1022 stem_dim_out: int = 64, 1023 stem_conv_kernel_size: Tuple[int] = (9, 1, 9), 1024 stem_conv_stride: Tuple[int] = (1, 1, 3), 1025 stem_pool: Callable = None, 1026 stem_pool_kernel_size: Tuple[int] = (3, 1, 3), 1027 stem_pool_stride: Tuple[int] = (2, 1, 2), 1028 stem: Callable = create_acoustic_res_basic_stem, 1029 # Stage configs. 1030 stage1_pool: Callable = None, 1031 stage1_pool_kernel_size: Tuple[int] = (2, 1, 1), 1032 stage_conv_a_kernel_size: Union[Tuple[int], Tuple[Tuple[int]]] = (3, 1, 1), 1033 stage_conv_b_kernel_size: Union[Tuple[int], Tuple[Tuple[int]]] = (3, 1, 3), 1034 stage_conv_b_num_groups: Tuple[int] = (1, 1, 1, 1), 1035 stage_conv_b_dilation: Union[Tuple[int], Tuple[Tuple[int]]] = (1, 1, 1), 1036 stage_spatial_h_stride: Tuple[int] = (1, 1, 1, 1), 1037 stage_spatial_w_stride: Tuple[int] = (1, 2, 2, 2), 1038 stage_temporal_stride: Tuple[int] = (1, 2, 2, 2), 1039 bottleneck: Union[Tuple[Callable], Callable] = ( 1040 create_acoustic_bottleneck_block, 1041 create_acoustic_bottleneck_block, 1042 create_bottleneck_block, 1043 create_bottleneck_block, 1044 ), 1045 # Head configs. 1046 head_pool: Callable = nn.AvgPool3d, 1047 head_pool_kernel_size: Tuple[int] = (4, 1, 2), 1048 head_output_size: Tuple[int] = (1, 1, 1), 1049 head_activation: Callable = None, 1050 head_output_with_global_average: bool = True, 1051) -> nn.Module: 1052 """ 1053 Build ResNet style models for acoustic recognition. ResNet has three parts: 1054 Stem, Stages and Head. Stem is the first Convolution layer (Conv1) with an 1055 optional pooling layer. Stages are grouped residual blocks. There are usually 1056 multiple stages and each stage may include multiple residual blocks. Head 1057 may include pooling, dropout, a fully-connected layer and global spatial 1058 temporal averaging. The three parts are assembled in the following order: 1059 1060 :: 1061 1062 Input 1063 ↓ 1064 Stem 1065 ↓ 1066 Stage 1 1067 ↓ 1068 . 1069 . 1070 . 1071 ↓ 1072 Stage N 1073 ↓ 1074 Head 1075 1076 Args: 1077 1078 input_channel (int): number of channels for the input video clip. 1079 1080 model_depth (int): the depth of the resnet. Options include: 50, 101, 152. 1081 model_num_class (int): the number of classes for the video dataset. 1082 dropout_rate (float): dropout rate. 1083 1084 1085 norm (callable): a callable that constructs normalization layer. 1086 1087 activation (callable): a callable that constructs activation layer. 1088 1089 stem_dim_out (int): output channel size to stem. 1090 stem_conv_kernel_size (tuple): convolutional kernel size(s) of stem. 1091 stem_conv_stride (tuple): convolutional stride size(s) of stem. 1092 stem_pool (callable): a callable that constructs resnet head pooling layer. 1093 stem_pool_kernel_size (tuple): pooling kernel size(s). 1094 stem_pool_stride (tuple): pooling stride size(s). 1095 stem (callable): a callable that constructs stem layer. 1096 Examples include: create_res_video_stem. 1097 1098 stage_conv_a_kernel_size (tuple): convolutional kernel size(s) for conv_a. 1099 stage_conv_b_kernel_size (tuple): convolutional kernel size(s) for conv_b. 1100 stage_conv_b_num_groups (tuple): number of groups for groupwise convolution 1101 for conv_b. 1 for ResNet, and larger than 1 for ResNeXt. 1102 stage_conv_b_dilation (tuple): dilation for 3D convolution for conv_b. 1103 stage_spatial_h_stride (tuple): the spatial height stride for each stage. 1104 stage_spatial_w_stride (tuple): the spatial width stride for each stage. 1105 stage_temporal_stride (tuple): the temporal stride for each stage. 1106 bottleneck (callable): a callable that constructs bottleneck block layer. 1107 Examples include: create_bottleneck_block. 1108 1109 head_pool (callable): a callable that constructs resnet head pooling layer. 1110 head_pool_kernel_size (tuple): the pooling kernel size. 1111 head_output_size (tuple): the size of output tensor for head. 1112 head_activation (callable): a callable that constructs activation layer. 1113 head_output_with_global_average (bool): if True, perform global averaging on 1114 the head output. 1115 1116 Returns: 1117 (nn.Module): audio resnet, that takes spectragram image input with 1118 shape: (B, C, T, 1, F), where T is the time dimension and F is the 1119 frequency dimension. 1120 """ 1121 return create_resnet(**locals()) 1122 1123 1124class ResBlock(nn.Module): 1125 """ 1126 Residual block. Performs a summation between an identity shortcut in branch1 and a 1127 main block in branch2. When the input and output dimensions are different, a 1128 convolution followed by a normalization will be performed. 1129 1130 :: 1131 1132 1133 Input 1134 |-------+ 1135 ↓ | 1136 Block | 1137 ↓ | 1138 Summation ←-+ 1139 ↓ 1140 Activation 1141 1142 The builder can be found in `create_res_block`. 1143 """ 1144 1145 def __init__( 1146 self, 1147 branch1_conv: nn.Module = None, 1148 branch1_norm: nn.Module = None, 1149 branch2: nn.Module = None, 1150 activation: nn.Module = None, 1151 branch_fusion: Callable = None, 1152 ) -> nn.Module: 1153 """ 1154 Args: 1155 branch1_conv (torch.nn.modules): convolutional module in branch1. 1156 branch1_norm (torch.nn.modules): normalization module in branch1. 1157 branch2 (torch.nn.modules): bottleneck block module in branch2. 1158 activation (torch.nn.modules): activation module. 1159 branch_fusion: (Callable): A callable or layer that combines branch1 1160 and branch2. 1161 """ 1162 super().__init__() 1163 set_attributes(self, locals()) 1164 assert self.branch2 is not None 1165 1166 def forward(self, x) -> torch.Tensor: 1167 if self.branch1_conv is None: 1168 x = self.branch_fusion(x, self.branch2(x)) 1169 else: 1170 shortcut = self.branch1_conv(x) 1171 if self.branch1_norm is not None: 1172 shortcut = self.branch1_norm(shortcut) 1173 x = self.branch_fusion(shortcut, self.branch2(x)) 1174 if self.activation is not None: 1175 x = self.activation(x) 1176 return x 1177 1178 1179class SeparableBottleneckBlock(nn.Module): 1180 """ 1181 Separable Bottleneck block: a sequence of spatiotemporal Convolution, Normalization, 1182 and Activations repeated in the following order. Requires a tuple of models to be 1183 provided to conv_b, norm_b, act_b to perform Convolution, Normalization, and 1184 Activations in parallel Separably. 1185 1186 :: 1187 1188 1189 Conv3d (conv_a) 1190 ↓ 1191 Normalization (norm_a) 1192 ↓ 1193 Activation (act_a) 1194 ↓ 1195 Conv3d(s) (conv_b), ... 1196 ↓ (↓) 1197 Normalization(s) (norm_b), ... 1198 ↓ (↓) 1199 Activation(s) (act_b), ... 1200 ↓ (↓) 1201 Reduce (sum or cat) 1202 ↓ 1203 Conv3d (conv_c) 1204 ↓ 1205 Normalization (norm_c) 1206 """ 1207 1208 def __init__( 1209 self, 1210 *, 1211 conv_a: nn.Module, 1212 norm_a: nn.Module, 1213 act_a: nn.Module, 1214 conv_b: nn.ModuleList, 1215 norm_b: nn.ModuleList, 1216 act_b: nn.ModuleList, 1217 conv_c: nn.Module, 1218 norm_c: nn.Module, 1219 reduce_method: str = "sum", 1220 ) -> None: 1221 """ 1222 Args: 1223 conv_a (torch.nn.modules): convolutional module. 1224 norm_a (torch.nn.modules): normalization module. 1225 act_a (torch.nn.modules): activation module. 1226 conv_b (torch.nn.modules_list): convolutional module(s). 1227 norm_b (torch.nn.modules_list): normalization module(s). 1228 act_b (torch.nn.modules_list): activation module(s). 1229 conv_c (torch.nn.modules): convolutional module. 1230 norm_c (torch.nn.modules): normalization module. 1231 reduce_method (str): if multiple conv_b is used, reduce the output with 1232 `sum`, or `cat`. 1233 """ 1234 super().__init__() 1235 set_attributes(self, locals()) 1236 assert all( 1237 op is not None for op in (self.conv_b, self.conv_c) 1238 ), f"{self.conv_a}, {self.conv_b}, {self.conv_c} has None" 1239 assert reduce_method in ["sum", "cat"] 1240 if self.norm_c is not None: 1241 # This flag is used for weight initialization. 1242 self.norm_c.block_final_bn = True 1243 1244 def forward(self, x: torch.Tensor) -> torch.Tensor: 1245 # Explicitly forward every layer. 1246 # Branch2a, for example Tx1x1, BN, ReLU. 1247 if self.conv_a is not None: 1248 x = self.conv_a(x) 1249 if self.norm_a is not None: 1250 x = self.norm_a(x) 1251 if self.act_a is not None: 1252 x = self.act_a(x) 1253 1254 # Branch2b, for example 1xHxW, BN, ReLU. 1255 output = [] 1256 for ind in range(len(self.conv_b)): 1257 x_ = self.conv_b[ind](x) 1258 if self.norm_b[ind] is not None: 1259 x_ = self.norm_b[ind](x_) 1260 if self.act_b[ind] is not None: 1261 x_ = self.act_b[ind](x_) 1262 output.append(x_) 1263 if self.reduce_method == "sum": 1264 x = torch.stack(output, dim=0).sum(dim=0, keepdim=False) 1265 elif self.reduce_method == "cat": 1266 x = torch.cat(output, dim=1) 1267 1268 # Branch2c, for example 1x1x1, BN. 1269 x = self.conv_c(x) 1270 if self.norm_c is not None: 1271 x = self.norm_c(x) 1272 return x 1273 1274 1275class BottleneckBlock(nn.Module): 1276 """ 1277 Bottleneck block: a sequence of spatiotemporal Convolution, Normalization, 1278 and Activations repeated in the following order: 1279 1280 :: 1281 1282 1283 Conv3d (conv_a) 1284 ↓ 1285 Normalization (norm_a) 1286 ↓ 1287 Activation (act_a) 1288 ↓ 1289 Conv3d (conv_b) 1290 ↓ 1291 Normalization (norm_b) 1292 ↓ 1293 Activation (act_b) 1294 ↓ 1295 Conv3d (conv_c) 1296 ↓ 1297 Normalization (norm_c) 1298 1299 The builder can be found in `create_bottleneck_block`. 1300 """ 1301 1302 def __init__( 1303 self, 1304 *, 1305 conv_a: nn.Module = None, 1306 norm_a: nn.Module = None, 1307 act_a: nn.Module = None, 1308 conv_b: nn.Module = None, 1309 norm_b: nn.Module = None, 1310 act_b: nn.Module = None, 1311 conv_c: nn.Module = None, 1312 norm_c: nn.Module = None, 1313 ) -> None: 1314 """ 1315 Args: 1316 conv_a (torch.nn.modules): convolutional module. 1317 norm_a (torch.nn.modules): normalization module. 1318 act_a (torch.nn.modules): activation module. 1319 conv_b (torch.nn.modules): convolutional module. 1320 norm_b (torch.nn.modules): normalization module. 1321 act_b (torch.nn.modules): activation module. 1322 conv_c (torch.nn.modules): convolutional module. 1323 norm_c (torch.nn.modules): normalization module. 1324 """ 1325 super().__init__() 1326 set_attributes(self, locals()) 1327 assert all(op is not None for op in (self.conv_a, self.conv_b, self.conv_c)) 1328 if self.norm_c is not None: 1329 # This flag is used for weight initialization. 1330 self.norm_c.block_final_bn = True 1331 1332 def forward(self, x: torch.Tensor) -> torch.Tensor: 1333 # Explicitly forward every layer. 1334 # Branch2a, for example Tx1x1, BN, ReLU. 1335 x = self.conv_a(x) 1336 if self.norm_a is not None: 1337 x = self.norm_a(x) 1338 if self.act_a is not None: 1339 x = self.act_a(x) 1340 1341 # Branch2b, for example 1xHxW, BN, ReLU. 1342 x = self.conv_b(x) 1343 if self.norm_b is not None: 1344 x = self.norm_b(x) 1345 if self.act_b is not None: 1346 x = self.act_b(x) 1347 1348 # Branch2c, for example 1x1x1, BN. 1349 x = self.conv_c(x) 1350 if self.norm_c is not None: 1351 x = self.norm_c(x) 1352 return x 1353 1354 1355class ResStage(nn.Module): 1356 """ 1357 ResStage composes sequential blocks that make up a ResNet. These blocks could be, 1358 for example, Residual blocks, Non-Local layers, or Squeeze-Excitation layers. 1359 1360 :: 1361 1362 1363 Input 1364 ↓ 1365 ResBlock 1366 ↓ 1367 . 1368 . 1369 . 1370 ↓ 1371 ResBlock 1372 1373 The builder can be found in `create_res_stage`. 1374 """ 1375 1376 def __init__(self, res_blocks: nn.ModuleList) -> nn.Module: 1377 """ 1378 Args: 1379 res_blocks (torch.nn.module_list): ResBlock module(s). 1380 """ 1381 super().__init__() 1382 self.res_blocks = res_blocks 1383 1384 def forward(self, x: torch.Tensor) -> torch.Tensor: 1385 for _, res_block in enumerate(self.res_blocks): 1386 x = res_block(x) 1387 return x 1388