1# Copyright 2020 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15
16from typing import Sequence, Tuple, Union
17from jax._src.numpy import lax_numpy as jnp
18from jax._src.util import prod
19from . import lax
20
21
def conv_general_dilated_patches(
    lhs: lax.Array,
    filter_shape: Sequence[int],
    window_strides: Sequence[int],
    padding: Union[str, Sequence[Tuple[int, int]]],
    lhs_dilation: Sequence[int] = None,
    rhs_dilation: Sequence[int] = None,
    dimension_numbers: lax.ConvGeneralDilatedDimensionNumbers = None,
    precision: lax.PrecisionType = None,
) -> lax.Array:
  """Extract patches subject to the receptive field of `conv_general_dilated`.

  Runs the input through a convolution with the given parameters. The kernel
  of the convolution is constructed such that the output channel dimension
  `"C"` contains flattened image patches: a single `"C"` dimension represents,
  for example, the three dimensions `"chw"` collapsed into one. The order of
  these dimensions is `"c" + ''.join(c for c in rhs_spec if c not in 'OI')`,
  where `rhs_spec == dimension_numbers[1]`, and the size of this `"C"`
  dimension is therefore the size of each patch, i.e.
  `np.prod(filter_shape) * lhs.shape[lhs_spec.index('C')]`, where
  `lhs_spec == dimension_numbers[0]`.

  Docstring below adapted from `jax.lax.conv_general_dilated`.

  See Also:
    https://www.tensorflow.org/xla/operation_semantics#conv_convolution

  Args:
    lhs: a rank `n+2` dimensional input array.
    filter_shape: a sequence of `n` integers, representing the receptive window
      spatial shape in the order as specified in
      `rhs_spec = dimension_numbers[1]`.
    window_strides: a sequence of `n` integers, representing the inter-window
      strides.
    padding: either the string `'SAME'`, the string `'VALID'`, or a sequence of
      `n` `(low, high)` integer pairs that give the padding to apply before and
      after each spatial dimension.
    lhs_dilation: `None`, or a sequence of `n` integers, giving the
      dilation factor to apply in each spatial dimension of `lhs`. LHS dilation
      is also known as transposed convolution.
    rhs_dilation: `None`, or a sequence of `n` integers, giving the
      dilation factor to apply in each spatial dimension of `rhs`. RHS dilation
      is also known as atrous convolution.
    dimension_numbers: either `None`, or a 3-tuple
      `(lhs_spec, rhs_spec, out_spec)`, where each element is a string
      of length `n+2`. `None` defaults to
      `("NCHWD...", "OIHWD...", "NCHWD...")`.
    precision: Optional. Either ``None``, which means the default precision for
      the backend, or a ``lax.Precision`` enum value (``Precision.DEFAULT``,
      ``Precision.HIGH`` or ``Precision.HIGHEST``). A non-``None`` value is
      applied to `lhs` only; the generated kernel always uses
      ``Precision.DEFAULT``.

  Returns:
    A rank `n+2` array containing the flattened image patches in the output
    channel (`"C"`) dimension. For example if
    `dimension_numbers = ("NcHW", "OIwh", "CNHW")`, the output has dimension
    numbers `"CNHW" = "{cwh}NHW"`, with the size of dimension `"C"` equal to
    the size of each patch
    (`np.prod(filter_shape) * lhs.shape[lhs_spec.index('C')]`).

  """
  filter_shape = tuple(filter_shape)
  # Only the rank of the kernel shape matters for resolving the dimension
  # numbers, so placeholder sizes of 1 stand in for the feature dimensions.
  dimension_numbers = lax.conv_dimension_numbers(
      lhs.shape, (1, 1) + filter_shape, dimension_numbers)

  # The output spec is consumed by `conv_general_dilated` through
  # `dimension_numbers`; only the input and kernel specs are needed here.
  lhs_spec, rhs_spec, _ = dimension_numbers

  spatial_size = prod(filter_shape)
  # `lhs_spec[1]` is the index of the feature ("C") dimension of `lhs`.
  n_channels = lhs.shape[lhs_spec[1]]

  # Move separate `lhs` spatial locations into separate `rhs` channels: output
  # channel `k` of the identity kernel is one-hot at flat window position `k`.
  rhs = jnp.eye(spatial_size, dtype=lhs.dtype)
  rhs = rhs.reshape((spatial_size, 1) + filter_shape)
  # Repeat the kernel once per input channel; together with
  # `feature_group_count=n_channels` below, each input channel is convolved
  # with its own copy, which yields the "c"-major ordering of the output.
  rhs = jnp.tile(rhs, (n_channels,) + (1,) * (rhs.ndim - 1))
  # Put the out/in feature axes where `rhs_spec` expects them; the spatial
  # axes keep their relative order.
  rhs = jnp.moveaxis(rhs, (0, 1), (rhs_spec[0], rhs_spec[1]))

  # The kernel contains only zeros and ones, so the caller's precision is
  # requested for `lhs` only and the kernel is left at default precision.
  conv_precision = (
      None if precision is None else (precision, lax.Precision.DEFAULT))

  return lax.conv_general_dilated(
      lhs=lhs,
      rhs=rhs,
      window_strides=window_strides,
      padding=padding,
      lhs_dilation=lhs_dilation,
      rhs_dilation=rhs_dilation,
      dimension_numbers=dimension_numbers,
      precision=conv_precision,
      feature_group_count=n_channels,
  )
110