# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from functools import partial

import numpy as np

from jax.api import jit, linear_transpose, ShapeDtypeStruct
from jax.core import Primitive, ShapedArray
from jax.interpreters import xla
from jax._src.util import prod
from jax import dtypes, lax
from jax.lib import xla_client
from jax.interpreters import ad
from jax.interpreters import batching
from jax.lib import pocketfft

xops = xla_client.ops

__all__ = [
  "fft",
  "fft_p",
]


def _promote_to_complex(arg):
  """Converts `arg` to the result type of its own dtype and complex64."""
  dtype = dtypes.result_type(arg, np.complex64)
  return lax.convert_element_type(arg, dtype)


def _promote_to_real(arg):
  """Converts `arg` to the result type of its own dtype and float32."""
  dtype = dtypes.result_type(arg, np.float32)
  return lax.convert_element_type(arg, dtype)


def fft(x, fft_type, fft_lengths):
  """Binds the `fft_p` primitive after promoting `x` to a suitable dtype.

  Args:
    x: the input array. Must not be complex when `fft_type` is RFFT.
    fft_type: an `xla_client.FftType` value selecting the transform.
    fft_lengths: sequence of transform lengths; the shape rule of `fft_p`
      applies them to the trailing `len(fft_lengths)` dimensions of `x`.

  Returns:
    The result of binding `fft_p`, or the dtype-promoted `x` itself when
    `fft_lengths` is empty.

  Raises:
    ValueError: if `fft_type` is RFFT and `x` is complex valued.
  """
  if fft_type == xla_client.FftType.RFFT:
    if np.iscomplexobj(x):
      raise ValueError("only real valued inputs supported for rfft")
    x = _promote_to_real(x)
  else:
    x = _promote_to_complex(x)
  if len(fft_lengths) == 0:
    # XLA FFT doesn't support 0-rank.
    return x
  # Primitive parameters are passed as a tuple regardless of the sequence
  # type the caller supplied.
  fft_lengths = tuple(fft_lengths)
  return fft_p.bind(x, fft_type=fft_type, fft_lengths=fft_lengths)


def fft_impl(x, fft_type, fft_lengths):
  """Eager implementation of `fft_p`: dispatches through `xla.apply_primitive`."""
  return xla.apply_primitive(fft_p, x, fft_type=fft_type, fft_lengths=fft_lengths)


def _complex_dtype(dtype):
  """Returns the dtype NumPy produces when adding `dtype` and complex64 scalars."""
  return (np.zeros((), dtype) + np.zeros((), np.complex64)).dtype


def _real_dtype(dtype):
  """Returns the dtype of the real part of a NumPy scalar of `dtype`."""
  return np.zeros((), dtype).real.dtype


def _is_even(x):
  """Returns whether `x` is divisible by two."""
  return x % 2 == 0


def fft_abstract_eval(x, fft_type, fft_lengths):
  """Shape and dtype rule for `fft_p`.

  RFFT keeps the leading dimensions and all but the last FFT length, and
  shrinks the last transformed dimension to `fft_lengths[-1] // 2 + 1` with a
  complex dtype. IRFFT yields the leading dimensions followed by
  `fft_lengths` with a real dtype. Every other type preserves the shape and
  dtype of `x`.
  """
  if fft_type == xla_client.FftType.RFFT:
    shape = (x.shape[:-len(fft_lengths)] + fft_lengths[:-1]
             + (fft_lengths[-1] // 2 + 1,))
    dtype = _complex_dtype(x.dtype)
  elif fft_type == xla_client.FftType.IRFFT:
    shape = x.shape[:-len(fft_lengths)] + fft_lengths
    dtype = _real_dtype(x.dtype)
  else:
    shape = x.shape
    dtype = x.dtype
  return ShapedArray(shape, dtype)


def fft_translation_rule(c, x, fft_type, fft_lengths):
  """XLA translation rule for `fft_p`; the builder argument `c` is not used."""
  return xops.Fft(x, fft_type, fft_lengths)


def _naive_rfft(x, fft_lengths):
  """Computes an RFFT as a full FFT truncated to `n // 2 + 1` trailing entries."""
  y = fft(x, xla_client.FftType.FFT, fft_lengths)
  n = fft_lengths[-1]
  return y[..., : n//2 + 1]


@partial(jit, static_argnums=1)
def _rfft_transpose(t, fft_lengths):
  """Transposes RFFT applied to cotangent `t`; `fft_lengths` is a static argument."""
  # The transpose of RFFT can't be expressed only in terms of irfft. Instead of
  # manually building up larger twiddle matrices (which would increase the
  # asymptotic complexity and is also rather complicated), we rely on JAX to
  # transpose a naive RFFT implementation.
  dummy_shape = t.shape[:-len(fft_lengths)] + fft_lengths
  dummy_primal = ShapeDtypeStruct(dummy_shape, _real_dtype(t.dtype))
  transpose = linear_transpose(
      partial(_naive_rfft, fft_lengths=fft_lengths), dummy_primal)
  result, = transpose(t)
  assert result.dtype == _real_dtype(t.dtype), (result.dtype, t.dtype)
  return result


def _irfft_transpose(t, fft_lengths):
  """Transposes IRFFT applied to the real cotangent `t`."""
  # The transpose of IRFFT is the RFFT of the cotangent times a scaling
  # factor and a mask. The mask scales the cotangent for the Hermitian
  # symmetric components of the RFFT by a factor of two, since these components
  # are de-duplicated in the RFFT.
  x = fft(t, xla_client.FftType.RFFT, fft_lengths)
  n = x.shape[-1]
  is_odd = fft_lengths[-1] % 2
  full = partial(lax.full_like, t, dtype=t.dtype)
  # Weight 1 for the first component, 2 for the interior ones, and 1 for the
  # last component when the final FFT length is even (that segment is empty
  # when the length is odd).
  mask = lax.concatenate(
      [full(1.0, shape=(1,)),
       full(2.0, shape=(n - 2 + is_odd,)),
       full(1.0, shape=(1 - is_odd,))],
      dimension=0)
  scale = 1 / prod(fft_lengths)
  out = scale * mask * x
  assert out.dtype == _complex_dtype(t.dtype), (out.dtype, t.dtype)
  # IRFFT is only real-linear in its complex input. Under JAX's convention
  # for complex cotangents (plain transpose, with `conj` transposing to
  # `conj`), the result is the conjugate of the scaled, masked RFFT; without
  # it the imaginary part of the cotangent has the wrong sign.
  return lax.conj(out)


def fft_transpose_rule(t, operand, fft_type, fft_lengths):
  """Transpose rule for `fft_p`; returns a one-element tuple of cotangents.

  RFFT and IRFFT use dedicated helpers; the remaining types are transposed by
  applying the same transform to the cotangent `t`.
  """
  if fft_type == xla_client.FftType.RFFT:
    result = _rfft_transpose(t, fft_lengths)
  elif fft_type == xla_client.FftType.IRFFT:
    result = _irfft_transpose(t, fft_lengths)
  else:
    result = fft(t, fft_type, fft_lengths)
  return result,


def fft_batching_rule(batched_args, batch_dims, fft_type, fft_lengths):
  """Batching rule for `fft_p`: moves the batch axis to 0 and applies `fft`."""
  x, = batched_args
  bd, = batch_dims
  # With the batch axis leading, the transform (which the shape rule applies
  # to trailing dimensions) leaves it in place, so the output batch dim is 0.
  x = batching.moveaxis(x, bd, 0)
  return fft(x, fft_type, fft_lengths), 0


fft_p = Primitive('fft')
fft_p.def_impl(fft_impl)
fft_p.def_abstract_eval(fft_abstract_eval)
xla.translations[fft_p] = fft_translation_rule
ad.deflinear2(fft_p, fft_transpose_rule)
batching.primitive_batchers[fft_p] = fft_batching_rule
# Use the pocketfft-based CPU translation only when jax.lib provides it.
if pocketfft:
  xla.backend_specific_translations['cpu'][fft_p] = pocketfft.pocketfft