Searched refs:device_put (Results 1 – 12 of 12) sorted by relevance
/dports/math/py-flax/flax-0.3.3/flax/ |
H A D | jax_utils.py | 61 buffers = [xla.device_put(x, device=d) for d in devices] 165 buffers = [xla.device_put(x, devices[i])
|
/dports/math/py-jax/jax-0.2.9/jax/ |
H A D | __init__.py | 39 device_put,
|
H A D | api.py | 2101 def device_put(x, device: Optional[xc.Device] = None): function 2184 buffers = [buf for x, d in zip(xs, devices) for buf in xla.device_put(x, d)] 2226 buf, = xla.device_put(x, devices[0])
|
/dports/math/py-jax/jax-0.2.9/jax/_src/scipy/sparse/ |
H A D | linalg.py | 21 from jax import lax, device_put 186 b, x0 = device_put((b, x0)) 593 b, x0 = device_put((b, x0))
|
/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | xla.py | 127 def device_put(x, device: Optional[Device] = None) -> Tuple[Any]: function 356 input_bufs = list(it.chain.from_iterable(device_put(x, device) for x in args if x is not token)) 363 list(it.chain.from_iterable(device_put(x, device) for x in args if x is not token)) 840 input_bufs = list(it.chain.from_iterable(device_put(x, device) for x in args if x is not token)) 847 list(it.chain.from_iterable(device_put(x, device) for x in args if x is not token)) 860 else h(*device_put(x, device)) for h, x in zip(handlers, outs)] 1390 return aval_to_result_handler(device, a)(*device_put(x, device))
|
H A D | pxla.py | 378 lambda x, devices, _: device_put(core.unit, devices, replicate=True) 380 return device_put([x[i] for i in indices], devices) 388 return device_put(shards, devices) 1148 device_buffers = device_put(val, devices, replicate=True) 1903 def device_put(x, devices: Sequence[xb.xla_client.Device], replicate: bool=False) -> List[xb.xla_cl… function 1906 return list(it.chain.from_iterable(xla.device_put(x, device) for device in devices)) 1908 …return list(it.chain.from_iterable(xla.device_put(val, device) for val, device in safe_zip(x, devi…
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/ |
H A D | host_callback.py | 1580 x_on_dev = api.device_put(d_idx, device=d)
|
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/ |
H A D | lax.py | 1463 return xla.array_result_handler(None, aval)(*xla.device_put(x))
|
/dports/textproc/wiggle/wiggle-1.3/tests/contrib/series/ |
H A D | merge | 24860 patches.suse/nvme-core-Fix-extra-device_put-call-on-error-path.patch
|
H A D | orig | 24859 patches.suse/nvme-core-Fix-extra-device_put-call-on-error-path.patch
|
H A D | ldiff | 24870 - patches.suse/nvme-core-Fix-extra-device_put-call-on-error-path.patch
|
H A D | diff | 24863 | patches.suse/nvme-core-Fix-extra-device_put-call-on-error-path.patch
|