Home
last modified time | relevance | path

Searched refs:_CppDeviceArray (Results 1 – 4 of 4) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/interpreters/
H A Dxla.py1023 _CppDeviceArray: DeviceArrayProtocol = xc.Buffer variable
1041 assert isinstance(device_buffer, _CppDeviceArray)
1055 return type_x is _DeviceArray or type_x is _CppDeviceArray
1169 for device_array in [_DeviceArray, _CppDeviceArray]:
1289 for device_array in [_CppDeviceArray, _DeviceArray]:
1302 xb.register_constant_handler(_CppDeviceArray, _device_array_constant_handler)
1307 device_put_handlers[_CppDeviceArray] = _device_put_device_array
H A Dpxla.py390 shard_arg_handlers[xla._CppDeviceArray] = _shard_device_array
/dports/math/py-jax/jax-0.2.9/jax/_src/numpy/
H A Dlax_numpy.py47 from jax.interpreters.xla import DeviceArray, _DeviceArray, _CppDeviceArray
5205 for device_array in [_DeviceArray, _CppDeviceArray]:
5233 setattr(_CppDeviceArray, "__array_module__", __array_module__)
5240 for device_array in [_DeviceArray, _CppDeviceArray]:
5250 setattr(_CppDeviceArray, "compress", _compress_method)
5270 setattr(_CppDeviceArray, "_multi_slice", _multi_slice)
5392 setattr(_CppDeviceArray, "at", property(_IndexUpdateHelper))
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/
H A Dlax.py1958 [xla._CppDeviceArray, xla._DeviceArray, pxla.ShardedDeviceArray]):
1961 ad_util.jaxval_zeros_likers[xla._CppDeviceArray] = zeros_like_array