Home
last modified time | relevance | path

Searched refs:out_axes (Results 1 – 17 of 17) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/experimental/
H A Dmaps.py188 out_axes, argument
202 if isinstance(out_axes, list):
203 out_axes = tuple(out_axes)
206 out_axes, out_axes_entries = _prepare_axes(out_axes, "out_axes")
252 closure=out_axes)
287 out_axes = out_axes_thunk()
292 f = plan.vectorize(f, in_axes, out_axes)
386 out_axes = params['out_axes_thunk']()
395 new_params = dict(params, in_axes=new_in_axes, out_axes=out_axes,
472 f = plan.vectorize(f, in_axes, out_axes)
[all …]
H A Dhost_callback.py1244 out_axes=eqn.params["out_axes"] + (0, 0)
/dports/math/py-flax/flax-0.3.3/flax/core/
H A Daxes_scan.py40 out_axes: Any,
116 out_axes, ys)
120 out_axes, ys)
143 ys = jax.tree_multimap(transpose_from_front, out_axes, ys)
145 out_axes, constants_out, ys)
H A Dlift.py302 out_axes = {k: unpack(v) for k, v in xs.items() if not isinstance(v, In)}
303 return in_axes, out_axes
313 in_axes=0, out_axes=0, argument
397 out_axes=(out_axes, variable_out_axes),
421 in_axes=0, out_axes=0, argument
522 out_axes=(out_axes, variable_out_axes),
/dports/math/py-flax/flax-0.3.3/flax/linen/
H A Dtransforms.py215 in_axes=0, out_axes=0, argument
271 in_axes=in_axes, out_axes=out_axes,
370 in_axes=0, out_axes=0, argument
445 in_axes=in_axes, out_axes=out_axes,
/dports/math/py-jax/jax-0.2.9/jax/interpreters/
H A Dpxla.py739 out_axes = out_axes_thunk()
741 assert len(out_sharded_avals) == len(out_axes), (len(out_sharded_avals), len(out_axes))
743 assert len(out_pvals) == len(out_axes), (len(out_pvals), len(out_axes))
744 assert all(out_axis == 0 for out_axis in out_axes)
1399 for name, axis in out_axes.items():
1408 out_axes: Sequence[ArrayMapping],
1438 assert len(out_axes) == len(out_jaxpr_avals)
1615 out_axes = out_axes_thunk()
1616 assert all(out_axis == 0 for out_axis in out_axes)
1710 out_vals, out_axes = [out_vals], [out_axes]
[all …]
H A Dpartial_eval.py186 out_axes = out_axes_thunk()
187 return out_axes + (0,) * (num_outputs() - len(out_axes))
196 out_axes = params['out_axes_thunk']()
200 for pval, out_axis in zip(out_pvals, out_axes)]
254 out_axes = params['out_axes_thunk']()
257 for pv, ax in zip(out_pvs, out_axes)]
272 new_params = dict(new_params, in_axes=new_in_axes, out_axes=out_axes)
285 def out_axes_transform(out_axes): argument
286 return out_axes + (0,) * nconsts
1099 out_axes = params['out_axes_thunk']()
[all …]
H A Dad.py310 out_axes = out_axes_thunk()
311 return (*out_axes, *(ax for ax, nz in zip(out_axes, nz_tangents_out()) if nz))
333 def out_axes_transform(out_axes): argument
334 return (*out_axes, *(ax for ax, nz in zip(out_axes, tangents_nz) if nz))
599 in_axes, out_axes = params['in_axes'], params['out_axes']
602 *[axis for axis, x in zip(out_axes, ct)
606 assert all(out_axis is not None for out_axis in out_axes), out_axes
H A Dbatching.py226 def out_axes_transform(out_axes): argument
228 for out_axis, d in zip(out_axes, dims))
/dports/science/py-chainer/chainer-7.8.0/chainerx_cc/chainerx/native/native_device/
H A Dconv.cc120 Axes out_axes{0}; in Call() local
123 out_axes.emplace_back(int64_t{2 + i}); in Call()
126 return TensorDot(gy, col, out_axes, col_axes, w_dtype); in Call()
/dports/math/py-jax/jax-0.2.9/jax/
H A Dapi.py902 y, jac = vmap(pushfwd, out_axes=(None, -1))(_std_basis(dyn_args))
1073 def vmap(fun: F, in_axes=0, out_axes=0, axis_name=None) -> F: argument
1205 in_axes_, out_axes_ = tree_leaves(in_axes), tree_leaves(out_axes)
1224 lambda: flatten_axes("vmap out_axes", out_tree(), out_axes)
1283 out_axes=0, argument
1546 if not config.omnistaging_enabled and out_axes != 0:
1550 if any(out_axis is None for out_axis in tree_flatten(out_axes)):
1555 if out_axes == 0:
1562 closure=out_axes)
1565 lambda: tuple(flatten_axes("pmap out_axes", out_tree(), out_axes)),
[all …]
H A Dcore.py1254 out_axes = out_axes_thunk()
1256 out_axes = t(out_axes)
1257 return out_axes
1517 out_axes = params["out_axes"]
1534 for aval, out_axis in zip(mapped_out_avals, out_axes)]
1742 out_axes = out_axes_thunk()
1744 out_axes = t(out_axes)
1745 return out_axes
H A Dlax_reference.py323 view, view_axes, rhs_axes, out_axes = _conv_view(
326 view, view_axes, rhs, rhs_axes, out_axes, use_blas=True)
365 out_axes = [0, view.ndim] + list(range(1, dim+1))
367 return view, view_axes, rhs_axes, out_axes
/dports/science/py-chainer/chainer-7.8.0/chainer/functions/connection/
H A Dconvolution_nd.py295 out_axes = (0,) + tuple(moves.range(2, self.ndim + 2))
315 gW = xp.tensordot(gy, col, (out_axes, col_axes)).astype(
/dports/math/py-numpy/numpy-1.20.3/numpy/core/src/umath/
H A Dufunc_object.c1029 PyObject **out_axes, /* type: List[Tuple[T]] */ in get_ufunc_arguments() argument
1048 if (out_axes != NULL) { in get_ufunc_arguments()
1049 *out_axes = NULL; in get_ufunc_arguments()
1122 _new_reference, out_axes, in get_ufunc_arguments()
1183 if (out_axes != NULL && out_axis != NULL && in get_ufunc_arguments()
1184 *out_axes != NULL && *out_axis != NULL) { in get_ufunc_arguments()
1217 if (out_axes != NULL) { in get_ufunc_arguments()
1218 Py_XDECREF(*out_axes); in get_ufunc_arguments()
/dports/math/py-jax/jax-0.2.9/jax/_src/numpy/
H A Dlax_numpy.py3461 func = jax.vmap(func, in_axes=i, out_axes=-1)
3463 func = jax.vmap(func, in_axes=0, out_axes=0)
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/
H A Dlax.py3127 out_axes = np.argsort(list(x_batch) + x_kept + x_contract_sorted_by_y)
3129 tuple(out_axes))