# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Optional, Sequence, Tuple, Union
from jax._src.numpy import lax_numpy as jnp
from jax._src.util import prod
from . import lax


def conv_general_dilated_patches(
    lhs: lax.Array,
    filter_shape: Sequence[int],
    window_strides: Sequence[int],
    padding: Union[str, Sequence[Tuple[int, int]]],
    lhs_dilation: Optional[Sequence[int]] = None,
    rhs_dilation: Optional[Sequence[int]] = None,
    dimension_numbers: lax.ConvGeneralDilatedDimensionNumbers = None,
    precision: lax.PrecisionType = None,
) -> lax.Array:
  """Extract patches subject to the receptive field of `conv_general_dilated`.

  Runs the input through a convolution with given parameters. The kernel of the
  convolution is constructed such that the output channel dimension `"C"`
  contains flattened image patches, so that a single `"C"` dimension
  represents, for example, three dimensions `"chw"` collapsed. The order of
  these dimensions is `"c" + ''.join(c for c in rhs_spec if c not in 'OI')`,
  where `rhs_spec == dimension_numbers[1]`, and the size of this `"C"`
  dimension is therefore the size of each patch, i.e.
  `np.prod(filter_shape) * lhs.shape[lhs_spec.index('C')]`, where
  `lhs_spec == dimension_numbers[0]`.

  Docstring below adapted from `jax.lax.conv_general_dilated`.

  See Also:
    https://www.tensorflow.org/xla/operation_semantics#conv_convolution

  Args:
    lhs: a rank `n+2` dimensional input array. Array-like inputs are converted
      with `jnp.asarray` before use.
    filter_shape: a sequence of `n` integers, representing the receptive window
      spatial shape in the order as specified in
      `rhs_spec = dimension_numbers[1]`.
    window_strides: a sequence of `n` integers, representing the inter-window
      strides.
    padding: either the string `'SAME'`, the string `'VALID'`, or a sequence of
      `n` `(low, high)` integer pairs that give the padding to apply before and
      after each spatial dimension.
    lhs_dilation: `None`, or a sequence of `n` integers, giving the
      dilation factor to apply in each spatial dimension of `lhs`. LHS dilation
      is also known as transposed convolution.
    rhs_dilation: `None`, or a sequence of `n` integers, giving the
      dilation factor to apply in each spatial dimension of `rhs`. RHS dilation
      is also known as atrous convolution.
    dimension_numbers: either `None`, or a 3-tuple
      `(lhs_spec, rhs_spec, out_spec)`, where each element is a string
      of length `n+2`. `None` defaults to
      `("NCHWD...", "OIHWD...", "NCHWD...")`.
    precision: Optional. Either ``None``, which means the default precision for
      the backend, or a ``lax.Precision`` enum value (``Precision.DEFAULT``,
      ``Precision.HIGH`` or ``Precision.HIGHEST``). When given, it is applied
      to `lhs` only; the generated kernel always uses ``Precision.DEFAULT``.

  Returns:
    A rank `n+2` array containing the flattened image patches in the output
    channel (`"C"`) dimension. For example if
    `dimension_numbers = ("NcHW", "OIwh", "CNHW")`, the output has dimension
    numbers `"CNHW" = "{cwh}NHW"`, with the size of dimension `"C"` equal to
    the size of each patch
    (`np.prod(filter_shape) * lhs.shape[lhs_spec.index('C')]`).

  """
  # Normalise array-likes up front so `.shape` / `.dtype` are always available.
  lhs = jnp.asarray(lhs)
  filter_shape = tuple(filter_shape)

  # The kernel shape passed here is a placeholder `(1, 1) + filter_shape`; the
  # real kernel is built below once the parsed axis positions are known.
  dimension_numbers = lax.conv_dimension_numbers(
      lhs.shape, (1, 1) + filter_shape, dimension_numbers)

  # The output spec is not needed locally: it reaches the convolution through
  # `dimension_numbers` itself.
  lhs_spec, rhs_spec, _ = dimension_numbers

  spatial_size = prod(filter_shape)
  # `lhs_spec[1]` is used as the position of the input feature axis in `lhs`.
  n_channels = lhs.shape[lhs_spec[1]]

  # Move separate `lhs` spatial locations into separate `rhs` channels.
  # Each row of the identity matrix, reshaped to `filter_shape`, is a one-hot
  # filter selecting exactly one location of the receptive window.
  rhs = jnp.eye(spatial_size, dtype=lhs.dtype)
  rhs = rhs.reshape((spatial_size, 1) + filter_shape)

  # Repeat the bank of one-hot filters once per input channel; together with
  # `feature_group_count=n_channels` below, every channel is convolved with
  # its own copy, giving channel-major ordering in the output feature axis.
  rhs = jnp.tile(rhs, (n_channels,) + (1,) * (rhs.ndim - 1))

  # Place the leading two kernel axes where `rhs_spec` expects them.
  rhs = jnp.moveaxis(rhs, (0, 1), (rhs_spec[0], rhs_spec[1]))

  # The kernel holds only zeros and ones, so the requested precision is applied
  # to `lhs` alone and the kernel keeps the default precision.
  if precision is None:
    conv_precision = None
  else:
    conv_precision = (precision, lax.Precision.DEFAULT)

  return lax.conv_general_dilated(
      lhs=lhs,
      rhs=rhs,
      window_strides=window_strides,
      padding=padding,
      lhs_dilation=lhs_dilation,
      rhs_dilation=rhs_dilation,
      dimension_numbers=dimension_numbers,
      precision=conv_precision,
      feature_group_count=n_channels,
  )