1# Copyright 2019 Google LLC 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# https://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14 15# flake8: noqa: F401 16from jax._src.lax.lax import ( 17 ConvDimensionNumbers, 18 ConvGeneralDilatedDimensionNumbers, 19 DotDimensionNumbers, 20 GatherDimensionNumbers, 21 Precision, 22 RoundingMethod, 23 ScatterDimensionNumbers, 24 abs, 25 abs_p, 26 acos, 27 acos_p, 28 acosh, 29 acosh_p, 30 abs, 31 abs_p, 32 acos, 33 acosh, 34 acosh_p, 35 add, 36 add_p, 37 after_all, 38 after_all_p, 39 and_p, 40 argmax, 41 argmax_p, 42 argmin, 43 argmin_p, 44 asin, 45 asin_p, 46 asinh, 47 asinh_p, 48 atan, 49 atan_p, 50 atan2, 51 atan2_p, 52 atanh, 53 atanh_p, 54 batch_matmul, 55 bessel_i0e, 56 bessel_i0e_p, 57 bessel_i1e, 58 bessel_i1e_p, 59 betainc, 60 bitcast_convert_type, 61 bitcast_convert_type_p, 62 bitwise_and, 63 bitwise_not, 64 bitwise_or, 65 bitwise_xor, 66 broadcast, 67 broadcast_p, 68 broadcast_in_dim, 69 broadcast_in_dim_p, 70 broadcast_shapes, 71 broadcast_to_rank, 72 broadcasted_iota, 73 ceil, 74 ceil_p, 75 clamp, 76 clamp_p, 77 collapse, 78 complex, 79 complex_p, 80 concatenate, 81 concatenate_p, 82 conj, 83 conj_p, 84 conv, 85 conv_dimension_numbers, 86 conv_general_dilated, 87 conv_general_dilated_p, 88 conv_general_permutations, 89 conv_general_shape_tuple, 90 conv_shape_tuple, 91 conv_transpose, 92 conv_transpose_shape_tuple, 93 conv_with_general_padding, 94 convert_element_type, 95 convert_element_type_p, 96 cos, 97 cos_p, 98 cosh, 99 cosh_p, 100 create_token, 101 create_token_p, 102 digamma, 103 digamma_p, 104 div, 105 div_p, 106 dot, 107 dot_general, 108 dot_general_p, 109 dtype, 110 dtypes, 111 dynamic_index_in_dim, 112 dynamic_slice, 113 dynamic_slice_in_dim, 114 dynamic_slice_p, 115 dynamic_update_index_in_dim, 116 dynamic_update_slice, 117 dynamic_update_slice_in_dim, 118 dynamic_update_slice_p, 119 eq, 120 eq_p, 121 erf, 122 erf_inv, 123 erf_inv_p, 124 erf_p, 125 erfc, 126 erfc_p, 127 exp, 128 exp_p, 129 expand_dims, 130 expm1, 131 expm1_p, 132 floor, 133 floor_p, 134 full, 135 full_like, 136 gather, 137 gather_p, 138 ge, 139 ge_p, 140 gt, 141 gt_p, 142 igamma, 143 igamma_grad_a, 144 igamma_grad_a_p, 145 igamma_p, 146 igammac, 147 igammac_p, 148 imag, 149 imag_p, 150 index_in_dim, 151 index_take, 152 infeed, 153 infeed_p, 154 integer_pow, 155 integer_pow_p, 156 iota, 157 iota_p, 158 is_finite, 159 is_finite_p, 160 itertools, 161 le, 162 le_p, 163 lgamma, 164 lgamma_p, 165 log, 166 log1p, 167 log1p_p, 168 log_p, 169 lt, 170 lt_p, 171 max, 172 max_p, 173 min, 174 min_p, 175 mul, 176 mul_p, 177 naryop, 178 naryop_dtype_rule, 179 ne, 180 ne_p, 181 neg, 182 neg_p, 183 nextafter, 184 nextafter_p, 185 not_p, 186 or_p, 187 outfeed, 188 outfeed_p, 189 pad, 190 pad_p, 191 padtype_to_pads, 192 partial, 193 population_count, 194 population_count_p, 195 pow, 196 pow_p, 197 prod, 198 random_gamma_grad, 199 random_gamma_grad_p, 200 real, 201 real_p, 202 reciprocal, 203 reduce, 204 reduce_and_p, 205 reduce_max_p, 206 reduce_min_p, 207 reduce_or_p, 208 reduce_p, 209 reduce_prod_p, 210 reduce_sum_p, 211 reduce_window, 212 reduce_window_max_p, 213 reduce_window_min_p, 214 reduce_window_p, 215 reduce_window_shape_tuple, 216 reduce_window_sum_p, 217 regularized_incomplete_beta_p, 218 rem, 219 rem_p, 220 reshape, 221 reshape_p, 222 rev, 223 rev_p, 224 rng_uniform, 225 rng_uniform_p, 226 round, 227 round_p, 228 rsqrt, 229 rsqrt_p, 230 scatter, 231 scatter_add, 232 scatter_add_p, 233 scatter_max, 234 scatter_max_p, 235 scatter_min, 236 scatter_min_p, 237 scatter_mul, 238 scatter_mul_p, 239 scatter_p, 240 select, 241 select_and_gather_add_p, 242 select_and_scatter_add_p, 243 select_and_scatter_p, 244 select_p, 245 shift_left, 246 shift_left_p, 247 shift_right_arithmetic, 248 shift_right_arithmetic_p, 249 shift_right_logical, 250 shift_right_logical_p, 251 sign, 252 sign_p, 253 sin, 254 sin_p, 255 sinh, 256 sinh_p, 257 slice, 258 slice_in_dim, 259 slice_p, 260 sort, 261 sort_key_val, 262 sort_p, 263 sqrt, 264 sqrt_p, 265 square, 266 squeeze, 267 squeeze_p, 268 standard_abstract_eval, 269 standard_naryop, 270 standard_primitive, 271 standard_translate, 272 standard_unop, 273 stop_gradient, 274 sub, 275 sub_p, 276 tan, 277 tan_p, 278 tanh, 279 tanh_p, 280 tie_in, 281 tie_in_p, 282 top_k, 283 top_k_p, 284 transpose, 285 transpose_p, 286 unop, 287 unop_dtype_rule, 288 xor_p, 289 zeros_like_array, 290) 291from jax._src.lax.lax import (_reduce_sum, _reduce_max, _reduce_min, _reduce_or, 292 _reduce_and, _reduce_window_sum, _reduce_window_max, 293 _reduce_window_min, _reduce_window_prod, 294 _select_and_gather_add, 295 _select_and_scatter_add, _float, _complex, _input_dtype, 296 _const, _eq_meet, _broadcasting_select, 297 _check_user_dtype_supported, _one, _zero, _const, 298 _upcast_fp16_for_computation, _broadcasting_shape_rule, 299 _eye, _tri, _delta, _ones, _zeros, _dilate_shape) 300from jax._src.lax.control_flow import ( 301 associative_scan, 302 cond, 303 cond_p, 304 cummax, 305 cummax_p, 306 cummin, 307 cummin_p, 308 cumprod, 309 cumprod_p, 310 cumsum, 311 cumsum_p, 312 custom_linear_solve, 313 custom_root, 314 fori_loop, 315 linear_solve_p, 316 map, 317 scan, 318 scan_bind, 319 scan_p, 320 switch, 321 while_loop, 322 while_p, 323) 324from jax._src.lax.fft import ( 325 fft, 326 fft_p, 327) 328from jax._src.lax.parallel import ( 329 all_gather, 330 all_to_all, 331 all_to_all_p, 332 axis_index, 333 axis_index_p, 334 pmax, 335 pmax_p, 336 pmean, 337 pmin, 338 pmin_p, 339 ppermute, 340 ppermute_p, 341 pshuffle, 342 psum, 343 psum_p, 344 pswapaxes, 345 pdot, 346 xeinsum, 347) 348from jax._src.lax.other import ( 349 conv_general_dilated_patches 350) 351from . import linalg 352