/dports/math/py-jax/jax-0.2.9/jax/_src/numpy/ |
H A D | linalg.py | 39 return type if jnp.issubdtype(type, jnp.inexact) else jnp.float_ 104 return jnp.any(M != 0).astype(jnp.int32) 135 jnp.sum(jnp.log(jnp.abs(diag)), axis=-1)) 144 if jnp.issubdtype(jnp._dtype(x), jnp.complexfloating): 227 d = jnp.any(jnp.logical_or(jnp.isnan(d), jnp.isinf(d)), axis=-1) 307 s = jnp.where(s > cutoff, s, jnp.inf) 308 res = jnp.matmul(_T(v), jnp.divide(_T(u), s[..., jnp.newaxis])) 349 return jnp.sqrt(jnp.sum(jnp.real(x * jnp.conj(x)), keepdims=keepdims)) 359 return jnp.sqrt(jnp.sum(jnp.real(x * jnp.conj(x)), axis=axis, 383 return jnp.sqrt(jnp.sum(jnp.real(x * jnp.conj(x)), axis=axis, [all …]
|
H A D | polynomial.py | 18 from . import lax_numpy as jnp unknown 27 return type if jnp.issubdtype(type, jnp.inexact) else jnp.float_ 40 A = jnp.diag(jnp.ones((p.size - 2,), p.dtype), -1) 50 start = jnp.argmin(is_zero) 51 end = is_zero.size - jnp.argmin(is_zero[::-1]) 76 p = jnp.atleast_1d(p) 85 return jnp.array([]) 87 if jnp.all(p == 0): 88 return jnp.array([]) 99 return jnp.zeros(trailing_zeros, p.dtype) [all …]
|
/dports/math/py-jax/jax-0.2.9/jax/_src/scipy/optimize/ |
H A D | bfgs.py | 19 import jax.numpy as jnp namespace 54 x_k: jnp.ndarray 55 f_k: jnp.ndarray 56 g_k: jnp.ndarray 57 H_k: jnp.ndarray 68 x0: jnp.ndarray, 70 norm=jnp.inf, 143 sy_k = s_k[:, jnp.newaxis] * y_k[jnp.newaxis, :] 146 + rho_k * s_k[:, jnp.newaxis] * s_k[jnp.newaxis, :]) 147 H_kp1 = jnp.where(jnp.isfinite(rho_k), H_kp1, state.H_k) [all …]
|
H A D | line_search.py | 17 import jax.numpy as jnp namespace 29 A, B = jnp.dot(d1, jnp.array([fb - fa - C * db, fc - fa - C * dc])) / denom 58 j: Union[int, jnp.ndarray] 206 i: Union[int, jnp.ndarray] 215 g_star: jnp.ndarray 239 f_k: jnp.ndarray 240 g_k: jnp.ndarray 264 dphi = jnp.dot(g, pk) 386 status = jnp.where( 389 jnp.where( [all …]
|
H A D | minimize.py | 17 import jax.numpy as jnp namespace 35 x: jnp.ndarray 36 success: Union[bool, jnp.ndarray] 37 status: Union[int, jnp.ndarray] 39 fun: jnp.ndarray 40 jac: jnp.ndarray 41 hess_inv: jnp.ndarray 42 nfev: Union[int, jnp.ndarray] 43 njev: Union[int, jnp.ndarray] 44 nit: Union[int, jnp.ndarray] [all …]
|
/dports/math/py-jax/jax-0.2.9/jax/_src/image/ |
H A D | scale.py | 26 y = radius * jnp.sin(np.pi * x) * jnp.sin(np.pi * x / radius) 44 return jnp.maximum(0, 1 - jnp.abs(x)) 60 jnp.abs(sample_f[jnp.newaxis, :] - 61 jnp.arange(input_size, dtype=sample_f.dtype)[:, jnp.newaxis]) / 67 weights = jnp.where( 73 return jnp.where( 210 if not jnp.issubdtype(image.dtype, jnp.inexact): 212 if not jnp.issubdtype(scale.dtype, jnp.inexact): 214 if not jnp.issubdtype(translation.dtype, jnp.inexact): 216 translation, jnp.result_type(translation, jnp.float32)) [all …]
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/ |
H A D | ode.py | 33 import jax.numpy as jnp namespace 58 dps_c_mid = jnp.array([ 77 scale = atol + jnp.abs(y0) * rtol 78 d0 = jnp.linalg.norm(y0 / scale) 79 d1 = jnp.linalg.norm(f0 / scale) 96 beta = jnp.array([ 118 y1 = dt * jnp.dot(c_sol, k) + y0 124 if jnp.iscomplexobj(x): 130 err_tol = atol + rtol * jnp.maximum(jnp.abs(y0), jnp.abs(y1)) 202 return map(partial(jnp.where, jnp.all(error_ratios <= 1.)), new, old) [all …]
|
H A D | optimizers.py | 89 import jax.numpy as jnp namespace 244 v0 = jnp.zeros_like(x0) 271 v0 = jnp.zeros_like(x0) 303 m = jnp.zeros_like(x0) 308 g_sq += jnp.square(g) 309 g_sq_inv_sqrt = jnp.where(g_sq > 0, 1. / jnp.sqrt(g_sq), 0.0) 402 m0 = jnp.zeros_like(x0) 403 v0 = jnp.zeros_like(x0) 487 accum = functools.reduce(jnp.minimum, vs) + jnp.square(g) 488 accum_inv_sqrt = jnp.where(accum > 0, 1. / jnp.sqrt(accum), 0) [all …]
|
/dports/math/py-jax/jax-0.2.9/jax/_src/nn/ |
H A D | functions.py | 27 import jax.numpy as jnp namespace 42 return jnp.maximum(x, 0) 53 return jnp.logaddexp(x, 0) 63 return x / (jnp.abs(x) + 1) 109 return jnp.where(x > 0, x, alpha * jnp.expm1(safe_x)) 138 return jnp.where(x > 1, 1, jnp.where(x < -1, -1, x)) 154 return jnp.where(x > 0, x, alpha * jnp.expm1(x / alpha)) 233 return shifted - jnp.log(jnp.sum(jnp.exp(shifted), axis, keepdims=True)) 265 variance = jnp.mean(jnp.square(x), axis, keepdims=True) - jnp.square(mean) 295 x = jnp.asarray(x) [all …]
|
H A D | initializers.py | 25 import jax.numpy as jnp namespace 31 def zeros(key, shape, dtype=jnp.float32): return jnp.zeros(shape, dtype) 32 def ones(key, shape, dtype=jnp.float32): return jnp.ones(shape, dtype) 34 def uniform(scale=1e-2, dtype=jnp.float32): 39 def normal(stddev=1e-2, dtype=jnp.float32): 59 variance = jnp.array(scale / denominator, dtype=dtype) 62 stddev = jnp.sqrt(variance) / jnp.array(.87962566103423978, dtype) 92 Q, R = jnp.linalg.qr(A) 93 diag_sign = lax.broadcast_to_rank(jnp.sign(jnp.diag(R)), rank=Q.ndim) 97 Q = jnp.moveaxis(Q, -1, column_axis) [all …]
|
/dports/math/py-jax/jax-0.2.9/jax/_src/scipy/ |
H A D | signal.py | 21 from jax._src.numpy import lax_numpy as jnp unknown 69 …if jnp.issubdtype(in1.dtype, jnp.complexfloating) or jnp.issubdtype(in2.dtype, jnp.complexfloating… 79 …if jnp.issubdtype(in1.dtype, jnp.complexfloating) or jnp.issubdtype(in2.dtype, jnp.complexfloating… 81 if jnp.ndim(in1) != 2 or jnp.ndim(in2) != 2: 91 …if jnp.issubdtype(in1.dtype, jnp.complexfloating) or jnp.issubdtype(in2.dtype, jnp.complexfloating… 101 …if jnp.issubdtype(in1.dtype, jnp.complexfloating) or jnp.issubdtype(in2.dtype, jnp.complexfloating… 103 if jnp.ndim(in1) != 2 or jnp.ndim(in2) != 2: 114 data, = _promote_dtypes_inexact(jnp.asarray(data)) 123 data = jnp.moveaxis(data, axis, 0) 128 A = jnp.vstack([ [all …]
|
H A D | special.py | 113 sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype) 133 return jnp.where(x_ok, lax.mul(safe_x, lax.log(safe_y)), jnp.zeros_like(x)) 161 res = jnp.sum(gammaln(jnp.expand_dims(a, axis=-1) - 199 s_, a_ = jnp.expand_dims(s, -1), jnp.expand_dims(a, -1) 209 T1 = jnp.clip(T1, a_max=jnp.finfo(dtype).max) 218 assert jnp.issubdtype(lax.dtype(n), jnp.integer) 330 x = jnp.asarray(x) 332 if dtype not in (jnp.float32, jnp.float64): 372 if dtype not in (jnp.float32, jnp.float64): 556 x = jnp.asarray(x) [all …]
|
H A D | linalg.py | 48 c, b = np_linalg._promote_arg_dtypes(jnp.asarray(c), jnp.asarray(b)) 126 m, n = jnp.shape(a) 127 p = jnp.real(jnp.array(permutation == jnp.arange(m)[:, None], dtype=dtype)) 129 l = jnp.tril(lu, -1)[:, :k] + jnp.eye(m, k, dtype=dtype) 170 a, b = np_linalg._promote_arg_dtypes(jnp.asarray(a), jnp.asarray(b)) 206 a, b = np_linalg._promote_arg_dtypes(jnp.asarray(a), jnp.asarray(b)) 209 b_is_vector = jnp.ndim(a) == jnp.ndim(b) + 1 263 return jnp.full_like(A, jnp.nan) 276 A = jnp.asarray(A) 287 n_squarings = jnp.maximum(0, jnp.floor(jnp.log2(A_L1 / maxnorm))) [all …]
|
/dports/math/py-flax/flax-0.3.3/flax/core/nn/ |
H A D | normalization.py | 19 import jax.numpy as jnp namespace 31 dtype=jnp.float32, 37 x = jnp.asarray(x, jnp.float32) 43 m = jnp.mean(x, redux, keepdims=True) 50 mean2 = pmean(jnp.square(x)) 51 var = mean2 - jnp.square(mean) 76 return jnp.asarray(y, dtype) 83 dtype=jnp.float32, 125 dtype=jnp.float32, 155 x = jnp.asarray(x, jnp.float32) [all …]
|
H A D | linear.py | 27 import jax.numpy as jnp namespace 46 dtype=jnp.float32, 66 inputs = jnp.asarray(inputs, dtype) 94 return jnp.reshape(kernel, shape) 99 kernel = jnp.asarray(kernel, dtype) 122 bias = jnp.asarray(bias, dtype) 131 dtype=jnp.float32, 157 bias = jnp.asarray(bias, dtype) 181 dtype=jnp.float32, 241 bias = jnp.asarray(bias, dtype) [all …]
|
/dports/math/py-flax/flax-0.3.3/flax/nn/ |
H A D | linear.py | 28 import jax.numpy as jnp namespace 53 dtype=jnp.float32, 73 inputs = jnp.asarray(inputs, dtype) 129 bias = jnp.asarray(bias, dtype) 145 dtype=jnp.float32, 171 bias = jnp.asarray(bias, dtype) 206 dtype=jnp.float32, 274 y = jnp.squeeze(y, axis=0) 297 dtype=jnp.float32, 348 y = jnp.squeeze(y, axis=0) [all …]
|
H A D | normalization.py | 21 import jax.numpy as jnp namespace 45 dtype=jnp.float32, 81 x = jnp.asarray(x, jnp.float32) 106 mean, mean2 = jnp.split( 126 return jnp.asarray(y, dtype) 142 dtype=jnp.float32, 169 x = jnp.asarray(x, jnp.float32) 181 return jnp.asarray(y, dtype) 196 dtype=jnp.float32, 231 x = jnp.asarray(x, jnp.float32) [all …]
|
/dports/math/py-flax/flax-0.3.3/flax/optim/ |
H A D | adafactor.py | 27 import jax.numpy as jnp namespace 111 t = jnp.array(i, jnp.float32) + 1.0 142 state['v_row'] = jnp.zeros(vr_shape, dtype=jnp.float32) 143 state['v_col'] = jnp.zeros(vc_shape, dtype=jnp.float32) 145 state['v'] = jnp.zeros(param.shape, dtype=jnp.float32) 161 grad = grad.astype(jnp.float32) 167 update_scale *= jnp.maximum( 168 jnp.sqrt(jnp.mean(param * param)), epsilon2) 186 jnp.expand_dims(row_factor, axis=d0) * 187 jnp.expand_dims(col_factor, axis=d1)) [all …]
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/tests/ |
H A D | jax2tf_test.py | 25 from jax import numpy as jnp unknown 39 f_jax = lambda x: jnp.sin(jnp.cos(x)) 48 x = (jnp.float_(.7), {"a": jnp.float_(.8), "b": jnp.float_(.9)}) 52 f_jax = lambda x: jnp.sin(jnp.cos(x)) 59 f_jax = jax.jit(lambda x: jnp.sin(jnp.cos(x))) 63 f_jax = jax.jit(lambda x: jnp.sin(jax.jit(jnp.cos)(x))) 115 f_jax = jax.jit(lambda x: jnp.sin(jnp.cos(x))) 319 x2 = jnp.sin(x1) 320 x3 = jnp.sin(x2) 321 x4 = jnp.sin(x3) [all …]
|
/dports/math/py-flax/flax-0.3.3/flax/linen/ |
H A D | attention.py | 23 import jax.numpy as jnp namespace 45 dtype: Dtype = jnp.float32, 87 query = query / jnp.sqrt(depth).astype(dtype) 109 jnp.asarray(keep_prob, dtype=dtype)) 142 dtype: Dtype = jnp.float32 210 lambda: jnp.array(0, dtype=jnp.int32)) 234 jnp.broadcast_to(jnp.arange(max_length) <= cur_index, 312 mask = jnp.expand_dims(mask, axis=-3) 319 dtype: Dtype = jnp.float32): 335 idxs = jnp.broadcast_to(jnp.arange(x.shape[-1], dtype=jnp.int32), x.shape) [all …]
|
H A D | normalization.py | 21 import jax.numpy as jnp namespace 67 dtype: Dtype = jnp.float32 89 x = jnp.asarray(x, jnp.float32) 100 lambda s: jnp.zeros(s, jnp.float32), 103 lambda s: jnp.ones(s, jnp.float32), 160 dtype: Any = jnp.float32 176 x = jnp.asarray(x, jnp.float32) 188 y = y + jnp.asarray( 220 dtype: Any = jnp.float32 238 x = jnp.asarray(x, jnp.float32) [all …]
|
/dports/games/diaspora/Diaspora_R1_Linux/Diaspora/fs2_open/code/jumpnode/ |
H A D | jumpnode.cpp | 313 SCP_list<CJumpNode>::iterator jnp; in jumpnode_get_by_name() local 315 for (jnp = Jump_nodes.begin(); jnp != Jump_nodes.end(); ++jnp) { in jumpnode_get_by_name() 316 if(!stricmp(jnp->GetName(), name)) in jumpnode_get_by_name() 317 return &(*jnp); in jumpnode_get_by_name() 332 SCP_list<CJumpNode>::iterator jnp; in jumpnode_get_which_in() local 335 for (jnp = Jump_nodes.begin(); jnp != Jump_nodes.end(); ++jnp) { in jumpnode_get_which_in() 336 if(jnp->GetModelNumber() < 0) in jumpnode_get_which_in() 342 return &(*jnp); in jumpnode_get_which_in() 356 SCP_list<CJumpNode>::iterator jnp; in jumpnode_render_all() local 358 for (jnp = Jump_nodes.begin(); jnp != Jump_nodes.end(); ++jnp) { in jumpnode_render_all() [all …]
|
/dports/math/py-jax/jax-0.2.9/jax/_src/scipy/sparse/ |
H A D | linalg.py | 19 import jax.numpy as jnp namespace 44 if jnp.iscomplexobj(x) or jnp.iscomplexobj(y): 95 atol2 = jnp.maximum(jnp.square(tol) * bs, jnp.square(atol)) 184 x0 = tree_map(jnp.zeros_like, b) 283 xnorm_scaled = xnorm / jnp.sqrt(2) 326 eps = jnp.finfo(jnp.result_type(*tree_leaves(V))).eps 356 t = -jnp.where(a_lt_b, a, b) / jnp.where(a_lt_b, b, a) 358 cs = jnp.where(b_zero, 1, jnp.where(a_lt_b, r * t, r)) 359 sn = jnp.where(b_zero, 0, jnp.where(a_lt_b, r, r * t)) 461 return jnp.logical_and(k < restart, jnp.logical_not(breakdown)) [all …]
|
/dports/games/fs2open/fs2open.github.com-release_21_4_1/qtfred/src/mission/dialogs/ |
H A D | WaypointEditorDialogModel.cpp | 48 SCP_list<CJumpNode>::iterator jnp; in apply() local 145 for (jnp = Jump_nodes.begin(); jnp != Jump_nodes.end(); ++jnp) { in apply() 158 _currentName = jnp->GetName(); in apply() 172 _currentName = jnp->GetName(); in apply() 189 _currentName = jnp->GetName(); in apply() 200 _currentName = jnp->GetName(); in apply() 210 _currentName = jnp->GetName(); in apply() 221 _currentName = jnp->GetName(); in apply() 260 for (jnp = Jump_nodes.begin(); jnp != Jump_nodes.end(); ++jnp) { in initializeData() 266 _currentName = jnp->GetName(); in initializeData() [all …]
|
/dports/math/py-jax/jax-0.2.9/jax/_src/third_party/numpy/ |
H A D | linalg.py | 3 from jax._src.numpy import lax_numpy as jnp unknown 58 orig_nan_check = jnp.full_like(r, ~jnp.isnan(r).any()) 59 nan_mask = jnp.logical_and(jnp.isnan(r), ~jnp.isnan(x).any(axis=(-2, -1))) 60 r = jnp.where(orig_nan_check, jnp.where(nan_mask, jnp.inf, r), r) 66 a = jnp.asarray(a) 82 a = jnp.asarray(a) 83 b = jnp.asarray(b) 102 res = jnp.asarray(la.solve(a, b)) 159 return jnp.dot(jnp.dot(A, B, precision=precision), C, precision=precision) 161 return jnp.dot(A, jnp.dot(B, C, precision=precision), precision=precision) [all …]
|