1# Copyright 2019 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15# flake8: noqa: F401
16from jax._src.lax.lax import (
17  ConvDimensionNumbers,
18  ConvGeneralDilatedDimensionNumbers,
19  DotDimensionNumbers,
20  GatherDimensionNumbers,
21  Precision,
22  RoundingMethod,
23  ScatterDimensionNumbers,
24  abs,
25  abs_p,
26  acos,
27  acos_p,
28  acosh,
29  acosh_p,
30  abs,
31  abs_p,
32  acos,
33  acosh,
34  acosh_p,
35  add,
36  add_p,
37  after_all,
38  after_all_p,
39  and_p,
40  argmax,
41  argmax_p,
42  argmin,
43  argmin_p,
44  asin,
45  asin_p,
46  asinh,
47  asinh_p,
48  atan,
49  atan_p,
50  atan2,
51  atan2_p,
52  atanh,
53  atanh_p,
54  batch_matmul,
55  bessel_i0e,
56  bessel_i0e_p,
57  bessel_i1e,
58  bessel_i1e_p,
59  betainc,
60  bitcast_convert_type,
61  bitcast_convert_type_p,
62  bitwise_and,
63  bitwise_not,
64  bitwise_or,
65  bitwise_xor,
66  broadcast,
67  broadcast_p,
68  broadcast_in_dim,
69  broadcast_in_dim_p,
70  broadcast_shapes,
71  broadcast_to_rank,
72  broadcasted_iota,
73  ceil,
74  ceil_p,
75  clamp,
76  clamp_p,
77  collapse,
78  complex,
79  complex_p,
80  concatenate,
81  concatenate_p,
82  conj,
83  conj_p,
84  conv,
85  conv_dimension_numbers,
86  conv_general_dilated,
87  conv_general_dilated_p,
88  conv_general_permutations,
89  conv_general_shape_tuple,
90  conv_shape_tuple,
91  conv_transpose,
92  conv_transpose_shape_tuple,
93  conv_with_general_padding,
94  convert_element_type,
95  convert_element_type_p,
96  cos,
97  cos_p,
98  cosh,
99  cosh_p,
100  create_token,
101  create_token_p,
102  digamma,
103  digamma_p,
104  div,
105  div_p,
106  dot,
107  dot_general,
108  dot_general_p,
109  dtype,
110  dtypes,
111  dynamic_index_in_dim,
112  dynamic_slice,
113  dynamic_slice_in_dim,
114  dynamic_slice_p,
115  dynamic_update_index_in_dim,
116  dynamic_update_slice,
117  dynamic_update_slice_in_dim,
118  dynamic_update_slice_p,
119  eq,
120  eq_p,
121  erf,
122  erf_inv,
123  erf_inv_p,
124  erf_p,
125  erfc,
126  erfc_p,
127  exp,
128  exp_p,
129  expand_dims,
130  expm1,
131  expm1_p,
132  floor,
133  floor_p,
134  full,
135  full_like,
136  gather,
137  gather_p,
138  ge,
139  ge_p,
140  gt,
141  gt_p,
142  igamma,
143  igamma_grad_a,
144  igamma_grad_a_p,
145  igamma_p,
146  igammac,
147  igammac_p,
148  imag,
149  imag_p,
150  index_in_dim,
151  index_take,
152  infeed,
153  infeed_p,
154  integer_pow,
155  integer_pow_p,
156  iota,
157  iota_p,
158  is_finite,
159  is_finite_p,
160  itertools,
161  le,
162  le_p,
163  lgamma,
164  lgamma_p,
165  log,
166  log1p,
167  log1p_p,
168  log_p,
169  lt,
170  lt_p,
171  max,
172  max_p,
173  min,
174  min_p,
175  mul,
176  mul_p,
177  naryop,
178  naryop_dtype_rule,
179  ne,
180  ne_p,
181  neg,
182  neg_p,
183  nextafter,
184  nextafter_p,
185  not_p,
186  or_p,
187  outfeed,
188  outfeed_p,
189  pad,
190  pad_p,
191  padtype_to_pads,
192  partial,
193  population_count,
194  population_count_p,
195  pow,
196  pow_p,
197  prod,
198  random_gamma_grad,
199  random_gamma_grad_p,
200  real,
201  real_p,
202  reciprocal,
203  reduce,
204  reduce_and_p,
205  reduce_max_p,
206  reduce_min_p,
207  reduce_or_p,
208  reduce_p,
209  reduce_prod_p,
210  reduce_sum_p,
211  reduce_window,
212  reduce_window_max_p,
213  reduce_window_min_p,
214  reduce_window_p,
215  reduce_window_shape_tuple,
216  reduce_window_sum_p,
217  regularized_incomplete_beta_p,
218  rem,
219  rem_p,
220  reshape,
221  reshape_p,
222  rev,
223  rev_p,
224  rng_uniform,
225  rng_uniform_p,
226  round,
227  round_p,
228  rsqrt,
229  rsqrt_p,
230  scatter,
231  scatter_add,
232  scatter_add_p,
233  scatter_max,
234  scatter_max_p,
235  scatter_min,
236  scatter_min_p,
237  scatter_mul,
238  scatter_mul_p,
239  scatter_p,
240  select,
241  select_and_gather_add_p,
242  select_and_scatter_add_p,
243  select_and_scatter_p,
244  select_p,
245  shift_left,
246  shift_left_p,
247  shift_right_arithmetic,
248  shift_right_arithmetic_p,
249  shift_right_logical,
250  shift_right_logical_p,
251  sign,
252  sign_p,
253  sin,
254  sin_p,
255  sinh,
256  sinh_p,
257  slice,
258  slice_in_dim,
259  slice_p,
260  sort,
261  sort_key_val,
262  sort_p,
263  sqrt,
264  sqrt_p,
265  square,
266  squeeze,
267  squeeze_p,
268  standard_abstract_eval,
269  standard_naryop,
270  standard_primitive,
271  standard_translate,
272  standard_unop,
273  stop_gradient,
274  sub,
275  sub_p,
276  tan,
277  tan_p,
278  tanh,
279  tanh_p,
280  tie_in,
281  tie_in_p,
282  top_k,
283  top_k_p,
284  transpose,
285  transpose_p,
286  unop,
287  unop_dtype_rule,
288  xor_p,
289  zeros_like_array,
290)
291from jax._src.lax.lax import (_reduce_sum, _reduce_max, _reduce_min, _reduce_or,
292                  _reduce_and, _reduce_window_sum, _reduce_window_max,
293                  _reduce_window_min, _reduce_window_prod,
294                  _select_and_gather_add,
295                  _select_and_scatter_add, _float, _complex, _input_dtype,
296                  _const, _eq_meet, _broadcasting_select,
297                  _check_user_dtype_supported, _one, _zero, _const,
298                  _upcast_fp16_for_computation, _broadcasting_shape_rule,
299                  _eye, _tri, _delta, _ones, _zeros, _dilate_shape)
300from jax._src.lax.control_flow import (
301  associative_scan,
302  cond,
303  cond_p,
304  cummax,
305  cummax_p,
306  cummin,
307  cummin_p,
308  cumprod,
309  cumprod_p,
310  cumsum,
311  cumsum_p,
312  custom_linear_solve,
313  custom_root,
314  fori_loop,
315  linear_solve_p,
316  map,
317  scan,
318  scan_bind,
319  scan_p,
320  switch,
321  while_loop,
322  while_p,
323)
324from jax._src.lax.fft import (
325  fft,
326  fft_p,
327)
328from jax._src.lax.parallel import (
329  all_gather,
330  all_to_all,
331  all_to_all_p,
332  axis_index,
333  axis_index_p,
334  pmax,
335  pmax_p,
336  pmean,
337  pmin,
338  pmin_p,
339  ppermute,
340  ppermute_p,
341  pshuffle,
342  psum,
343  psum_p,
344  pswapaxes,
345  pdot,
346  xeinsum,
347)
348from jax._src.lax.other import (
349  conv_general_dilated_patches
350)
351from . import linalg
352