Searched refs:device_put_p (Results 1 – 7 of 7) sorted by relevance
/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | xla.py | 1392 device_put_p = core.Primitive('device_put') variable 1393 device_put_p.def_impl(_device_put_impl) 1394 device_put_p.def_abstract_eval(lambda x, device=None: x) 1395 translations[device_put_p] = lambda c, x, device=None: x 1396 ad.deflinear2(device_put_p, lambda cotangent, _, **kwargs: [cotangent]) 1397 masking.defvectorized(device_put_p)
|
H A D | batching.py | 381 defvectorized(xla.device_put_p)
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/ |
H A D | doubledouble.py | 276 _def_passthrough(xla.device_put_p)
|
H A D | jet.py | 245 deflinear(xla.device_put_p)
|
/dports/math/py-jax/jax-0.2.9/jax/ |
H A D | api.py | 2119 return tree_map(lambda y: xla.device_put_p.bind(y, device=device), x)
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/tests/ |
H A D | primitive_harness.py | 622 lambda x: xla.device_put_p.bind(x, device=_device_fn()),
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/ |
H A D | jax2tf.py | 849 tf_impl[xla.device_put_p] = lambda x, device=None: x
|