1# Copyright 2019 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15
16from functools import partial
17
18import numpy as np
19
20from jax.api import jit, linear_transpose, ShapeDtypeStruct
21from jax.core import Primitive, ShapedArray
22from jax.interpreters import xla
23from jax._src.util import prod
24from jax import dtypes, lax
25from jax.lib import xla_client
26from jax.interpreters import ad
27from jax.interpreters import batching
28from jax.lib import pocketfft
29
# Short alias for `xla_client.ops`, used by the XLA translation rule below.
xops = xla_client.ops

# Names exported when this module is star-imported.
__all__ = [
  "fft",
  "fft_p",
]
36
def _promote_to_complex(arg):
  """Convert `arg` to the dtype obtained by promoting it with complex64."""
  return lax.convert_element_type(
      arg, dtypes.result_type(arg, np.complex64))
40
def _promote_to_real(arg):
  """Convert `arg` to the dtype obtained by promoting it with float32."""
  return lax.convert_element_type(
      arg, dtypes.result_type(arg, np.float32))
44
def fft(x, fft_type, fft_lengths):
  """Bind the `fft_p` primitive on `x` after promoting its dtype.

  Args:
    x: operand. It is promoted to a real dtype when `fft_type` is RFFT and
      to a complex dtype for every other `fft_type`.
    fft_type: transform kind, compared against `xla_client.FftType` members.
    fft_lengths: sequence of transform lengths; converted to a tuple before
      being passed to the primitive.

  Returns:
    The result of `fft_p.bind`, or the promoted `x` itself when `fft_lengths`
    is empty.

  Raises:
    ValueError: if `fft_type` is RFFT and `x` is complex valued.
  """
  is_rfft = fft_type == xla_client.FftType.RFFT
  if is_rfft and np.iscomplexobj(x):
    raise ValueError("only real valued inputs supported for rfft")
  x = _promote_to_real(x) if is_rfft else _promote_to_complex(x)
  if len(fft_lengths) == 0:
    # XLA FFT doesn't support 0-rank, so the promoted operand is returned.
    return x
  return fft_p.bind(x, fft_type=fft_type, fft_lengths=tuple(fft_lengths))
57
def fft_impl(x, fft_type, fft_lengths):
  """Implementation rule for `fft_p`: forwards to `xla.apply_primitive`."""
  params = dict(fft_type=fft_type, fft_lengths=fft_lengths)
  return xla.apply_primitive(fft_p, x, **params)
60
def _complex_dtype(dtype):
  """Return the dtype of a `dtype` scalar added to a complex64 scalar."""
  summed = np.zeros((), dtype) + np.zeros((), np.complex64)
  return summed.dtype

def _real_dtype(dtype):
  """Return the dtype of the real part of a zero-dimensional `dtype` array."""
  scalar = np.zeros((), dtype)
  return scalar.real.dtype

def _is_even(x):
  """Return `x % 2 == 0`."""
  remainder = x % 2
  return remainder == 0
64
def fft_abstract_eval(x, fft_type, fft_lengths):
  """Abstract evaluation rule for `fft_p`: output shape and dtype.

  * RFFT: the trailing `len(fft_lengths)` dimensions of `x` are replaced by
    `fft_lengths`, with the last one shrunk to `fft_lengths[-1] // 2 + 1`;
    the dtype becomes the complex counterpart of `x.dtype`.
  * IRFFT: the trailing dimensions are replaced by `fft_lengths`; the dtype
    becomes the real counterpart of `x.dtype`.
  * Any other type: shape and dtype are those of `x`.
  """
  kinds = xla_client.FftType
  if fft_type == kinds.RFFT:
    leading = x.shape[:-len(fft_lengths)]
    # Only the non-redundant half of the last transformed axis is produced.
    out_shape = leading + fft_lengths[:-1] + (fft_lengths[-1] // 2 + 1,)
    return ShapedArray(out_shape, _complex_dtype(x.dtype))
  if fft_type == kinds.IRFFT:
    out_shape = x.shape[:-len(fft_lengths)] + fft_lengths
    return ShapedArray(out_shape, _real_dtype(x.dtype))
  return ShapedArray(x.shape, x.dtype)
77
def fft_translation_rule(c, x, fft_type, fft_lengths):
  """XLA translation rule for `fft_p`: build an `Fft` op on operand `x`.

  The argument `c` is accepted as part of the rule signature but is not used
  in the body.
  """
  build_fft = xops.Fft
  return build_fft(x, fft_type, fft_lengths)
80
def _naive_rfft(x, fft_lengths):
  """Compute an RFFT as a full FFT truncated along the last axis.

  Only the first `fft_lengths[-1] // 2 + 1` entries of the last axis of the
  full FFT are kept.
  """
  full_spectrum = fft(x, xla_client.FftType.FFT, fft_lengths)
  keep = fft_lengths[-1] // 2 + 1
  return full_spectrum[..., :keep]
85
@partial(jit, static_argnums=1)
def _rfft_transpose(t, fft_lengths):
  """Transpose of RFFT applied to the cotangent `t`.

  `fft_lengths` is a static argument of the jitted function.
  """
  # The transpose of RFFT can't be expressed only in terms of irfft. Instead of
  # manually building up larger twiddle matrices (which would increase the
  # asymptotic complexity and is also rather complicated), we rely on JAX to
  # transpose a naive RFFT implementation.
  real_dtype = _real_dtype(t.dtype)
  # Shape/dtype stand-in for the primal input of the naive RFFT; no data needed.
  primal_shape = t.shape[:-len(fft_lengths)] + fft_lengths
  primal_spec = ShapeDtypeStruct(primal_shape, real_dtype)
  naive = partial(_naive_rfft, fft_lengths=fft_lengths)
  (ct,) = linear_transpose(naive, primal_spec)(t)
  assert ct.dtype == _real_dtype(t.dtype), (ct.dtype, t.dtype)
  return ct
99
def _irfft_transpose(t, fft_lengths):
  """Transpose of IRFFT applied to the cotangent `t`.

  The result is the RFFT of `t`, multiplied by `1 / prod(fft_lengths)` and by
  a mask along the last axis. The mask doubles the entries that correspond to
  Hermitian-symmetric components, since those components are de-duplicated
  in the RFFT output.
  """
  spectrum = fft(t, xla_client.FftType.RFFT, fft_lengths)
  num_bins = spectrum.shape[-1]
  odd = fft_lengths[-1] % 2

  def _filled(value, length):
    # 1-D array of `length` copies of `value`, using the dtype of `t`.
    return lax.full_like(t, value, dtype=t.dtype, shape=(length,))

  # First entry has weight 1; an even last length adds one trailing entry of
  # weight 1, an odd one adds none. Everything in between has weight 2.
  weights = lax.concatenate(
      [_filled(1.0, 1),
       _filled(2.0, num_bins - 2 + odd),
       _filled(1.0, 1 - odd)],
      dimension=0)
  normalizer = 1 / prod(fft_lengths)
  result = normalizer * weights * spectrum
  assert result.dtype == _complex_dtype(t.dtype), (result.dtype, t.dtype)
  return result
118
def fft_transpose_rule(t, operand, fft_type, fft_lengths):
  """Transpose rule for `fft_p`; returns a one-element tuple.

  RFFT and IRFFT dispatch to their dedicated transpose helpers. Every other
  `fft_type` is transposed by applying `fft` with the same type to `t`.
  `operand` is not used.
  """
  kinds = xla_client.FftType
  if fft_type == kinds.RFFT:
    return (_rfft_transpose(t, fft_lengths),)
  if fft_type == kinds.IRFFT:
    return (_irfft_transpose(t, fft_lengths),)
  return (fft(t, fft_type, fft_lengths),)
127
def fft_batching_rule(batched_args, batch_dims, fft_type, fft_lengths):
  """Batching rule for `fft_p`.

  Moves the batch dimension of the single operand to axis 0, applies `fft`,
  and reports axis 0 as the batch dimension of the output.
  """
  (operand,) = batched_args
  (batch_axis,) = batch_dims
  front_batched = batching.moveaxis(operand, batch_axis, 0)
  return fft(front_batched, fft_type, fft_lengths), 0
133
# Create the `fft` primitive and attach its rules: implementation, abstract
# evaluation, XLA translation, transpose (via `ad.deflinear2`) and batching.
fft_p = Primitive('fft')
fft_p.def_impl(fft_impl)
fft_p.def_abstract_eval(fft_abstract_eval)
xla.translations[fft_p] = fft_translation_rule
ad.deflinear2(fft_p, fft_transpose_rule)
batching.primitive_batchers[fft_p] = fft_batching_rule
# When `pocketfft` is truthy, register its translation for the 'cpu' backend.
if pocketfft:
  xla.backend_specific_translations['cpu'][fft_p] = pocketfft.pocketfft
142