/dports/math/py-jax/jax-0.2.9/jax/experimental/ |
H A D | maps.py | 188 out_axes, argument 202 if isinstance(out_axes, list): 203 out_axes = tuple(out_axes) 206 out_axes, out_axes_entries = _prepare_axes(out_axes, "out_axes") 252 closure=out_axes) 287 out_axes = out_axes_thunk() 292 f = plan.vectorize(f, in_axes, out_axes) 386 out_axes = params['out_axes_thunk']() 395 new_params = dict(params, in_axes=new_in_axes, out_axes=out_axes, 472 f = plan.vectorize(f, in_axes, out_axes) [all …]
|
H A D | host_callback.py | 1244 out_axes=eqn.params["out_axes"] + (0, 0)
|
/dports/math/py-flax/flax-0.3.3/flax/core/ |
H A D | axes_scan.py | 40 out_axes: Any, 116 out_axes, ys) 120 out_axes, ys) 143 ys = jax.tree_multimap(transpose_from_front, out_axes, ys) 145 out_axes, constants_out, ys)
|
H A D | lift.py | 302 out_axes = {k: unpack(v) for k, v in xs.items() if not isinstance(v, In)} 303 return in_axes, out_axes 313 in_axes=0, out_axes=0, argument 397 out_axes=(out_axes, variable_out_axes), 421 in_axes=0, out_axes=0, argument 522 out_axes=(out_axes, variable_out_axes),
|
/dports/math/py-flax/flax-0.3.3/flax/linen/ |
H A D | transforms.py | 215 in_axes=0, out_axes=0, argument 271 in_axes=in_axes, out_axes=out_axes, 370 in_axes=0, out_axes=0, argument 445 in_axes=in_axes, out_axes=out_axes,
|
/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | pxla.py | 739 out_axes = out_axes_thunk() 741 assert len(out_sharded_avals) == len(out_axes), (len(out_sharded_avals), len(out_axes)) 743 assert len(out_pvals) == len(out_axes), (len(out_pvals), len(out_axes)) 744 assert all(out_axis == 0 for out_axis in out_axes) 1399 for name, axis in out_axes.items(): 1408 out_axes: Sequence[ArrayMapping], 1438 assert len(out_axes) == len(out_jaxpr_avals) 1615 out_axes = out_axes_thunk() 1616 assert all(out_axis == 0 for out_axis in out_axes) 1710 out_vals, out_axes = [out_vals], [out_axes] [all …]
|
H A D | partial_eval.py | 186 out_axes = out_axes_thunk() 187 return out_axes + (0,) * (num_outputs() - len(out_axes)) 196 out_axes = params['out_axes_thunk']() 200 for pval, out_axis in zip(out_pvals, out_axes)] 254 out_axes = params['out_axes_thunk']() 257 for pv, ax in zip(out_pvs, out_axes)] 272 new_params = dict(new_params, in_axes=new_in_axes, out_axes=out_axes) 285 def out_axes_transform(out_axes): argument 286 return out_axes + (0,) * nconsts 1099 out_axes = params['out_axes_thunk']() [all …]
|
H A D | ad.py | 310 out_axes = out_axes_thunk() 311 return (*out_axes, *(ax for ax, nz in zip(out_axes, nz_tangents_out()) if nz)) 333 def out_axes_transform(out_axes): argument 334 return (*out_axes, *(ax for ax, nz in zip(out_axes, tangents_nz) if nz)) 599 in_axes, out_axes = params['in_axes'], params['out_axes'] 602 *[axis for axis, x in zip(out_axes, ct) 606 assert all(out_axis is not None for out_axis in out_axes), out_axes
|
H A D | batching.py | 226 def out_axes_transform(out_axes): argument 228 for out_axis, d in zip(out_axes, dims))
|
/dports/science/py-chainer/chainer-7.8.0/chainerx_cc/chainerx/native/native_device/ |
H A D | conv.cc | 120 Axes out_axes{0}; in Call() local 123 out_axes.emplace_back(int64_t{2 + i}); in Call() 126 return TensorDot(gy, col, out_axes, col_axes, w_dtype); in Call()
|
/dports/math/py-jax/jax-0.2.9/jax/ |
H A D | api.py | 902 y, jac = vmap(pushfwd, out_axes=(None, -1))(_std_basis(dyn_args)) 1073 def vmap(fun: F, in_axes=0, out_axes=0, axis_name=None) -> F: argument 1205 in_axes_, out_axes_ = tree_leaves(in_axes), tree_leaves(out_axes) 1224 lambda: flatten_axes("vmap out_axes", out_tree(), out_axes) 1283 out_axes=0, argument 1546 if not config.omnistaging_enabled and out_axes != 0: 1550 if any(out_axis is None for out_axis in tree_flatten(out_axes)): 1555 if out_axes == 0: 1562 closure=out_axes) 1565 lambda: tuple(flatten_axes("pmap out_axes", out_tree(), out_axes)), [all …]
|
H A D | core.py | 1254 out_axes = out_axes_thunk() 1256 out_axes = t(out_axes) 1257 return out_axes 1517 out_axes = params["out_axes"] 1534 for aval, out_axis in zip(mapped_out_avals, out_axes)] 1742 out_axes = out_axes_thunk() 1744 out_axes = t(out_axes) 1745 return out_axes
|
H A D | lax_reference.py | 323 view, view_axes, rhs_axes, out_axes = _conv_view( 326 view, view_axes, rhs, rhs_axes, out_axes, use_blas=True) 365 out_axes = [0, view.ndim] + list(range(1, dim+1)) 367 return view, view_axes, rhs_axes, out_axes
|
/dports/science/py-chainer/chainer-7.8.0/chainer/functions/connection/ |
H A D | convolution_nd.py | 295 out_axes = (0,) + tuple(moves.range(2, self.ndim + 2)) 315 gW = xp.tensordot(gy, col, (out_axes, col_axes)).astype(
|
/dports/math/py-numpy/numpy-1.20.3/numpy/core/src/umath/ |
H A D | ufunc_object.c | 1029 PyObject **out_axes, /* type: List[Tuple[T]] */ in get_ufunc_arguments() argument 1048 if (out_axes != NULL) { in get_ufunc_arguments() 1049 *out_axes = NULL; in get_ufunc_arguments() 1122 _new_reference, out_axes, in get_ufunc_arguments() 1183 if (out_axes != NULL && out_axis != NULL && in get_ufunc_arguments() 1184 *out_axes != NULL && *out_axis != NULL) { in get_ufunc_arguments() 1217 if (out_axes != NULL) { in get_ufunc_arguments() 1218 Py_XDECREF(*out_axes); in get_ufunc_arguments()
|
/dports/math/py-jax/jax-0.2.9/jax/_src/numpy/ |
H A D | lax_numpy.py | 3461 func = jax.vmap(func, in_axes=i, out_axes=-1) 3463 func = jax.vmap(func, in_axes=0, out_axes=0)
|
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/ |
H A D | lax.py | 3127 out_axes = np.argsort(list(x_batch) + x_kept + x_contract_sorted_by_y) 3129 tuple(out_axes))
|