Home
last modified time | relevance | path

Searched refs:jnp (Results 1 – 25 of 927) sorted by relevance

12345678910>>...38

/dports/math/py-jax/jax-0.2.9/jax/_src/numpy/
H A Dlinalg.py39 return type if jnp.issubdtype(type, jnp.inexact) else jnp.float_
104 return jnp.any(M != 0).astype(jnp.int32)
135 jnp.sum(jnp.log(jnp.abs(diag)), axis=-1))
144 if jnp.issubdtype(jnp._dtype(x), jnp.complexfloating):
227 d = jnp.any(jnp.logical_or(jnp.isnan(d), jnp.isinf(d)), axis=-1)
307 s = jnp.where(s > cutoff, s, jnp.inf)
308 res = jnp.matmul(_T(v), jnp.divide(_T(u), s[..., jnp.newaxis]))
349 return jnp.sqrt(jnp.sum(jnp.real(x * jnp.conj(x)), keepdims=keepdims))
359 return jnp.sqrt(jnp.sum(jnp.real(x * jnp.conj(x)), axis=axis,
383 return jnp.sqrt(jnp.sum(jnp.real(x * jnp.conj(x)), axis=axis,
[all …]
H A Dpolynomial.py18 from . import lax_numpy as jnp unknown
27 return type if jnp.issubdtype(type, jnp.inexact) else jnp.float_
40 A = jnp.diag(jnp.ones((p.size - 2,), p.dtype), -1)
50 start = jnp.argmin(is_zero)
51 end = is_zero.size - jnp.argmin(is_zero[::-1])
76 p = jnp.atleast_1d(p)
85 return jnp.array([])
87 if jnp.all(p == 0):
88 return jnp.array([])
99 return jnp.zeros(trailing_zeros, p.dtype)
[all …]
/dports/math/py-jax/jax-0.2.9/jax/_src/scipy/optimize/
H A Dbfgs.py19 import jax.numpy as jnp namespace
54 x_k: jnp.ndarray
55 f_k: jnp.ndarray
56 g_k: jnp.ndarray
57 H_k: jnp.ndarray
68 x0: jnp.ndarray,
70 norm=jnp.inf,
143 sy_k = s_k[:, jnp.newaxis] * y_k[jnp.newaxis, :]
146 + rho_k * s_k[:, jnp.newaxis] * s_k[jnp.newaxis, :])
147 H_kp1 = jnp.where(jnp.isfinite(rho_k), H_kp1, state.H_k)
[all …]
H A Dline_search.py17 import jax.numpy as jnp namespace
29 A, B = jnp.dot(d1, jnp.array([fb - fa - C * db, fc - fa - C * dc])) / denom
58 j: Union[int, jnp.ndarray]
206 i: Union[int, jnp.ndarray]
215 g_star: jnp.ndarray
239 f_k: jnp.ndarray
240 g_k: jnp.ndarray
264 dphi = jnp.dot(g, pk)
386 status = jnp.where(
389 jnp.where(
[all …]
H A Dminimize.py17 import jax.numpy as jnp namespace
35 x: jnp.ndarray
36 success: Union[bool, jnp.ndarray]
37 status: Union[int, jnp.ndarray]
39 fun: jnp.ndarray
40 jac: jnp.ndarray
41 hess_inv: jnp.ndarray
42 nfev: Union[int, jnp.ndarray]
43 njev: Union[int, jnp.ndarray]
44 nit: Union[int, jnp.ndarray]
[all …]
/dports/math/py-jax/jax-0.2.9/jax/_src/image/
H A Dscale.py26 y = radius * jnp.sin(np.pi * x) * jnp.sin(np.pi * x / radius)
44 return jnp.maximum(0, 1 - jnp.abs(x))
60 jnp.abs(sample_f[jnp.newaxis, :] -
61 jnp.arange(input_size, dtype=sample_f.dtype)[:, jnp.newaxis]) /
67 weights = jnp.where(
73 return jnp.where(
210 if not jnp.issubdtype(image.dtype, jnp.inexact):
212 if not jnp.issubdtype(scale.dtype, jnp.inexact):
214 if not jnp.issubdtype(translation.dtype, jnp.inexact):
216 translation, jnp.result_type(translation, jnp.float32))
[all …]
/dports/math/py-jax/jax-0.2.9/jax/experimental/
H A Dode.py33 import jax.numpy as jnp namespace
58 dps_c_mid = jnp.array([
77 scale = atol + jnp.abs(y0) * rtol
78 d0 = jnp.linalg.norm(y0 / scale)
79 d1 = jnp.linalg.norm(f0 / scale)
96 beta = jnp.array([
118 y1 = dt * jnp.dot(c_sol, k) + y0
124 if jnp.iscomplexobj(x):
130 err_tol = atol + rtol * jnp.maximum(jnp.abs(y0), jnp.abs(y1))
202 return map(partial(jnp.where, jnp.all(error_ratios <= 1.)), new, old)
[all …]
H A Doptimizers.py89 import jax.numpy as jnp namespace
244 v0 = jnp.zeros_like(x0)
271 v0 = jnp.zeros_like(x0)
303 m = jnp.zeros_like(x0)
308 g_sq += jnp.square(g)
309 g_sq_inv_sqrt = jnp.where(g_sq > 0, 1. / jnp.sqrt(g_sq), 0.0)
402 m0 = jnp.zeros_like(x0)
403 v0 = jnp.zeros_like(x0)
487 accum = functools.reduce(jnp.minimum, vs) + jnp.square(g)
488 accum_inv_sqrt = jnp.where(accum > 0, 1. / jnp.sqrt(accum), 0)
[all …]
/dports/math/py-jax/jax-0.2.9/jax/_src/nn/
H A Dfunctions.py27 import jax.numpy as jnp namespace
42 return jnp.maximum(x, 0)
53 return jnp.logaddexp(x, 0)
63 return x / (jnp.abs(x) + 1)
109 return jnp.where(x > 0, x, alpha * jnp.expm1(safe_x))
138 return jnp.where(x > 1, 1, jnp.where(x < -1, -1, x))
154 return jnp.where(x > 0, x, alpha * jnp.expm1(x / alpha))
233 return shifted - jnp.log(jnp.sum(jnp.exp(shifted), axis, keepdims=True))
265 variance = jnp.mean(jnp.square(x), axis, keepdims=True) - jnp.square(mean)
295 x = jnp.asarray(x)
[all …]
H A Dinitializers.py25 import jax.numpy as jnp namespace
31 def zeros(key, shape, dtype=jnp.float32): return jnp.zeros(shape, dtype)
32 def ones(key, shape, dtype=jnp.float32): return jnp.ones(shape, dtype)
34 def uniform(scale=1e-2, dtype=jnp.float32):
39 def normal(stddev=1e-2, dtype=jnp.float32):
59 variance = jnp.array(scale / denominator, dtype=dtype)
62 stddev = jnp.sqrt(variance) / jnp.array(.87962566103423978, dtype)
92 Q, R = jnp.linalg.qr(A)
93 diag_sign = lax.broadcast_to_rank(jnp.sign(jnp.diag(R)), rank=Q.ndim)
97 Q = jnp.moveaxis(Q, -1, column_axis)
[all …]
/dports/math/py-jax/jax-0.2.9/jax/_src/scipy/
H A Dsignal.py21 from jax._src.numpy import lax_numpy as jnp unknown
69 …if jnp.issubdtype(in1.dtype, jnp.complexfloating) or jnp.issubdtype(in2.dtype, jnp.complexfloating…
79 …if jnp.issubdtype(in1.dtype, jnp.complexfloating) or jnp.issubdtype(in2.dtype, jnp.complexfloating…
81 if jnp.ndim(in1) != 2 or jnp.ndim(in2) != 2:
91 …if jnp.issubdtype(in1.dtype, jnp.complexfloating) or jnp.issubdtype(in2.dtype, jnp.complexfloating…
101 …if jnp.issubdtype(in1.dtype, jnp.complexfloating) or jnp.issubdtype(in2.dtype, jnp.complexfloating…
103 if jnp.ndim(in1) != 2 or jnp.ndim(in2) != 2:
114 data, = _promote_dtypes_inexact(jnp.asarray(data))
123 data = jnp.moveaxis(data, axis, 0)
128 A = jnp.vstack([
[all …]
H A Dspecial.py113 sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype)
133 return jnp.where(x_ok, lax.mul(safe_x, lax.log(safe_y)), jnp.zeros_like(x))
161 res = jnp.sum(gammaln(jnp.expand_dims(a, axis=-1) -
199 s_, a_ = jnp.expand_dims(s, -1), jnp.expand_dims(a, -1)
209 T1 = jnp.clip(T1, a_max=jnp.finfo(dtype).max)
218 assert jnp.issubdtype(lax.dtype(n), jnp.integer)
330 x = jnp.asarray(x)
332 if dtype not in (jnp.float32, jnp.float64):
372 if dtype not in (jnp.float32, jnp.float64):
556 x = jnp.asarray(x)
[all …]
H A Dlinalg.py48 c, b = np_linalg._promote_arg_dtypes(jnp.asarray(c), jnp.asarray(b))
126 m, n = jnp.shape(a)
127 p = jnp.real(jnp.array(permutation == jnp.arange(m)[:, None], dtype=dtype))
129 l = jnp.tril(lu, -1)[:, :k] + jnp.eye(m, k, dtype=dtype)
170 a, b = np_linalg._promote_arg_dtypes(jnp.asarray(a), jnp.asarray(b))
206 a, b = np_linalg._promote_arg_dtypes(jnp.asarray(a), jnp.asarray(b))
209 b_is_vector = jnp.ndim(a) == jnp.ndim(b) + 1
263 return jnp.full_like(A, jnp.nan)
276 A = jnp.asarray(A)
287 n_squarings = jnp.maximum(0, jnp.floor(jnp.log2(A_L1 / maxnorm)))
[all …]
/dports/math/py-flax/flax-0.3.3/flax/core/nn/
H A Dnormalization.py19 import jax.numpy as jnp namespace
31 dtype=jnp.float32,
37 x = jnp.asarray(x, jnp.float32)
43 m = jnp.mean(x, redux, keepdims=True)
50 mean2 = pmean(jnp.square(x))
51 var = mean2 - jnp.square(mean)
76 return jnp.asarray(y, dtype)
83 dtype=jnp.float32,
125 dtype=jnp.float32,
155 x = jnp.asarray(x, jnp.float32)
[all …]
H A Dlinear.py27 import jax.numpy as jnp namespace
46 dtype=jnp.float32,
66 inputs = jnp.asarray(inputs, dtype)
94 return jnp.reshape(kernel, shape)
99 kernel = jnp.asarray(kernel, dtype)
122 bias = jnp.asarray(bias, dtype)
131 dtype=jnp.float32,
157 bias = jnp.asarray(bias, dtype)
181 dtype=jnp.float32,
241 bias = jnp.asarray(bias, dtype)
[all …]
/dports/math/py-flax/flax-0.3.3/flax/nn/
H A Dlinear.py28 import jax.numpy as jnp namespace
53 dtype=jnp.float32,
73 inputs = jnp.asarray(inputs, dtype)
129 bias = jnp.asarray(bias, dtype)
145 dtype=jnp.float32,
171 bias = jnp.asarray(bias, dtype)
206 dtype=jnp.float32,
274 y = jnp.squeeze(y, axis=0)
297 dtype=jnp.float32,
348 y = jnp.squeeze(y, axis=0)
[all …]
H A Dnormalization.py21 import jax.numpy as jnp namespace
45 dtype=jnp.float32,
81 x = jnp.asarray(x, jnp.float32)
106 mean, mean2 = jnp.split(
126 return jnp.asarray(y, dtype)
142 dtype=jnp.float32,
169 x = jnp.asarray(x, jnp.float32)
181 return jnp.asarray(y, dtype)
196 dtype=jnp.float32,
231 x = jnp.asarray(x, jnp.float32)
[all …]
/dports/math/py-flax/flax-0.3.3/flax/optim/
H A Dadafactor.py27 import jax.numpy as jnp namespace
111 t = jnp.array(i, jnp.float32) + 1.0
142 state['v_row'] = jnp.zeros(vr_shape, dtype=jnp.float32)
143 state['v_col'] = jnp.zeros(vc_shape, dtype=jnp.float32)
145 state['v'] = jnp.zeros(param.shape, dtype=jnp.float32)
161 grad = grad.astype(jnp.float32)
167 update_scale *= jnp.maximum(
168 jnp.sqrt(jnp.mean(param * param)), epsilon2)
186 jnp.expand_dims(row_factor, axis=d0) *
187 jnp.expand_dims(col_factor, axis=d1))
[all …]
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/tests/
H A Djax2tf_test.py25 from jax import numpy as jnp unknown
39 f_jax = lambda x: jnp.sin(jnp.cos(x))
48 x = (jnp.float_(.7), {"a": jnp.float_(.8), "b": jnp.float_(.9)})
52 f_jax = lambda x: jnp.sin(jnp.cos(x))
59 f_jax = jax.jit(lambda x: jnp.sin(jnp.cos(x)))
63 f_jax = jax.jit(lambda x: jnp.sin(jax.jit(jnp.cos)(x)))
115 f_jax = jax.jit(lambda x: jnp.sin(jnp.cos(x)))
319 x2 = jnp.sin(x1)
320 x3 = jnp.sin(x2)
321 x4 = jnp.sin(x3)
[all …]
/dports/math/py-flax/flax-0.3.3/flax/linen/
H A Dattention.py23 import jax.numpy as jnp namespace
45 dtype: Dtype = jnp.float32,
87 query = query / jnp.sqrt(depth).astype(dtype)
109 jnp.asarray(keep_prob, dtype=dtype))
142 dtype: Dtype = jnp.float32
210 lambda: jnp.array(0, dtype=jnp.int32))
234 jnp.broadcast_to(jnp.arange(max_length) <= cur_index,
312 mask = jnp.expand_dims(mask, axis=-3)
319 dtype: Dtype = jnp.float32):
335 idxs = jnp.broadcast_to(jnp.arange(x.shape[-1], dtype=jnp.int32), x.shape)
[all …]
H A Dnormalization.py21 import jax.numpy as jnp namespace
67 dtype: Dtype = jnp.float32
89 x = jnp.asarray(x, jnp.float32)
100 lambda s: jnp.zeros(s, jnp.float32),
103 lambda s: jnp.ones(s, jnp.float32),
160 dtype: Any = jnp.float32
176 x = jnp.asarray(x, jnp.float32)
188 y = y + jnp.asarray(
220 dtype: Any = jnp.float32
238 x = jnp.asarray(x, jnp.float32)
[all …]
/dports/games/diaspora/Diaspora_R1_Linux/Diaspora/fs2_open/code/jumpnode/
H A Djumpnode.cpp313 SCP_list<CJumpNode>::iterator jnp; in jumpnode_get_by_name() local
315 for (jnp = Jump_nodes.begin(); jnp != Jump_nodes.end(); ++jnp) { in jumpnode_get_by_name()
316 if(!stricmp(jnp->GetName(), name)) in jumpnode_get_by_name()
317 return &(*jnp); in jumpnode_get_by_name()
332 SCP_list<CJumpNode>::iterator jnp; in jumpnode_get_which_in() local
335 for (jnp = Jump_nodes.begin(); jnp != Jump_nodes.end(); ++jnp) { in jumpnode_get_which_in()
336 if(jnp->GetModelNumber() < 0) in jumpnode_get_which_in()
342 return &(*jnp); in jumpnode_get_which_in()
356 SCP_list<CJumpNode>::iterator jnp; in jumpnode_render_all() local
358 for (jnp = Jump_nodes.begin(); jnp != Jump_nodes.end(); ++jnp) { in jumpnode_render_all()
[all …]
/dports/math/py-jax/jax-0.2.9/jax/_src/scipy/sparse/
H A Dlinalg.py19 import jax.numpy as jnp namespace
44 if jnp.iscomplexobj(x) or jnp.iscomplexobj(y):
95 atol2 = jnp.maximum(jnp.square(tol) * bs, jnp.square(atol))
184 x0 = tree_map(jnp.zeros_like, b)
283 xnorm_scaled = xnorm / jnp.sqrt(2)
326 eps = jnp.finfo(jnp.result_type(*tree_leaves(V))).eps
356 t = -jnp.where(a_lt_b, a, b) / jnp.where(a_lt_b, b, a)
358 cs = jnp.where(b_zero, 1, jnp.where(a_lt_b, r * t, r))
359 sn = jnp.where(b_zero, 0, jnp.where(a_lt_b, r, r * t))
461 return jnp.logical_and(k < restart, jnp.logical_not(breakdown))
[all …]
/dports/games/fs2open/fs2open.github.com-release_21_4_1/qtfred/src/mission/dialogs/
H A DWaypointEditorDialogModel.cpp48 SCP_list<CJumpNode>::iterator jnp; in apply() local
145 for (jnp = Jump_nodes.begin(); jnp != Jump_nodes.end(); ++jnp) { in apply()
158 _currentName = jnp->GetName(); in apply()
172 _currentName = jnp->GetName(); in apply()
189 _currentName = jnp->GetName(); in apply()
200 _currentName = jnp->GetName(); in apply()
210 _currentName = jnp->GetName(); in apply()
221 _currentName = jnp->GetName(); in apply()
260 for (jnp = Jump_nodes.begin(); jnp != Jump_nodes.end(); ++jnp) { in initializeData()
266 _currentName = jnp->GetName(); in initializeData()
[all …]
/dports/math/py-jax/jax-0.2.9/jax/_src/third_party/numpy/
H A Dlinalg.py3 from jax._src.numpy import lax_numpy as jnp unknown
58 orig_nan_check = jnp.full_like(r, ~jnp.isnan(r).any())
59 nan_mask = jnp.logical_and(jnp.isnan(r), ~jnp.isnan(x).any(axis=(-2, -1)))
60 r = jnp.where(orig_nan_check, jnp.where(nan_mask, jnp.inf, r), r)
66 a = jnp.asarray(a)
82 a = jnp.asarray(a)
83 b = jnp.asarray(b)
102 res = jnp.asarray(la.solve(a, b))
159 return jnp.dot(jnp.dot(A, B, precision=precision), C, precision=precision)
161 return jnp.dot(A, jnp.dot(B, C, precision=precision), precision=precision)
[all …]

12345678910>>...38