Searched refs:_CppDeviceArray (Results 1 – 4 of 4) sorted by relevance
/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | xla.py | 1023 _CppDeviceArray: DeviceArrayProtocol = xc.Buffer variable 1041 assert isinstance(device_buffer, _CppDeviceArray) 1055 return type_x is _DeviceArray or type_x is _CppDeviceArray 1169 for device_array in [_DeviceArray, _CppDeviceArray]: 1289 for device_array in [_CppDeviceArray, _DeviceArray]: 1302 xb.register_constant_handler(_CppDeviceArray, _device_array_constant_handler) 1307 device_put_handlers[_CppDeviceArray] = _device_put_device_array
|
H A D | pxla.py | 390 shard_arg_handlers[xla._CppDeviceArray] = _shard_device_array
|
/dports/math/py-jax/jax-0.2.9/jax/_src/numpy/ |
H A D | lax_numpy.py | 47 from jax.interpreters.xla import DeviceArray, _DeviceArray, _CppDeviceArray 5205 for device_array in [_DeviceArray, _CppDeviceArray]: 5233 setattr(_CppDeviceArray, "__array_module__", __array_module__) 5240 for device_array in [_DeviceArray, _CppDeviceArray]: 5250 setattr(_CppDeviceArray, "compress", _compress_method) 5270 setattr(_CppDeviceArray, "_multi_slice", _multi_slice) 5392 setattr(_CppDeviceArray, "at", property(_IndexUpdateHelper))
|
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/ |
H A D | lax.py | 1958 [xla._CppDeviceArray, xla._DeviceArray, pxla.ShardedDeviceArray]): 1961 ad_util.jaxval_zeros_likers[xla._CppDeviceArray] = zeros_like_array
|