Home
last modified time | relevance | path

Searched refs:device_put_p (Results 1 – 7 of 7) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/interpreters/
xla.py:1392 device_put_p = core.Primitive('device_put')  [variable]
1393 device_put_p.def_impl(_device_put_impl)
1394 device_put_p.def_abstract_eval(lambda x, device=None: x)
1395 translations[device_put_p] = lambda c, x, device=None: x
1396 ad.deflinear2(device_put_p, lambda cotangent, _, **kwargs: [cotangent])
1397 masking.defvectorized(device_put_p)
batching.py:381 defvectorized(xla.device_put_p)
/dports/math/py-jax/jax-0.2.9/jax/experimental/
doubledouble.py:276 _def_passthrough(xla.device_put_p)
jet.py:245 deflinear(xla.device_put_p)
/dports/math/py-jax/jax-0.2.9/jax/
api.py:2119 return tree_map(lambda y: xla.device_put_p.bind(y, device=device), x)
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/tests/
primitive_harness.py:622 lambda x: xla.device_put_p.bind(x, device=_device_fn()),
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/
jax2tf.py:849 tf_impl[xla.device_put_p] = lambda x, device=None: x