<div align="center">
<img src="https://raw.githubusercontent.com/google/jax/master/images/jax_logo_250px.png" alt="logo">
</div>

# JAX: Autograd and XLA

![Continuous integration](https://github.com/google/jax/workflows/Continuous%20integration/badge.svg)
![PyPI version](https://img.shields.io/pypi/v/jax)

[**Quickstart**](#quickstart-colab-in-the-cloud)
| [**Transformations**](#transformations)
| [**Install guide**](#installation)
| [**Neural net libraries**](#neural-network-libraries)
| [**Change logs**](https://jax.readthedocs.io/en/latest/CHANGELOG.html)
| [**Reference docs**](https://jax.readthedocs.io/en/latest/)
| [**Code search**](https://cs.opensource.google/jax/jax)


**News:** [JAX tops largest-scale MLPerf Training 0.7 benchmarks!](https://cloud.google.com/blog/products/ai-machine-learning/google-breaks-ai-performance-records-in-mlperf-with-worlds-fastest-training-supercomputer)

## What is JAX?

JAX is [Autograd](https://github.com/hips/autograd) and
[XLA](https://www.tensorflow.org/xla),
brought together for high-performance machine learning research.

With its updated version of [Autograd](https://github.com/hips/autograd),
JAX can automatically differentiate native
Python and NumPy functions. It can differentiate through loops, branches,
recursion, and closures, and it can take derivatives of derivatives of
derivatives. It supports reverse-mode differentiation (a.k.a. backpropagation)
via [`grad`](#automatic-differentiation-with-grad) as well as forward-mode differentiation,
and the two can be composed arbitrarily to any order.

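As a tiny illustration of mixing the two modes, composing forward-mode over reverse-mode gives an efficient Hessian-vector product. This is a minimal sketch; the function `f` and the helper name `hvp` below are purely illustrative:

```python
import jax.numpy as jnp
from jax import grad, jvp

def f(x):
  return jnp.sum(jnp.tanh(x) ** 2)

def hvp(f, x, v):
  # forward-mode (jvp) applied to a reverse-mode gradient: a Hessian-vector product
  return jvp(grad(f), (x,), (v,))[1]

print(hvp(f, jnp.arange(3.), jnp.ones(3)))
```
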
What’s new is that JAX uses
[XLA](https://www.tensorflow.org/xla)
to compile and run your NumPy programs on GPUs and TPUs. Compilation happens
under the hood by default, with library calls getting just-in-time compiled and
executed. But JAX also lets you just-in-time compile your own Python functions
into XLA-optimized kernels using a one-function API,
[`jit`](#compilation-with-jit). Compilation and automatic differentiation can be
composed arbitrarily, so you can express sophisticated algorithms and get
maximal performance without leaving Python. You can even program multiple GPUs
or TPU cores at once using [`pmap`](#spmd-programming-with-pmap), and
differentiate through the whole thing.

Dig a little deeper, and you'll see that JAX is really an extensible system for
[composable function transformations](#transformations). Both
[`grad`](#automatic-differentiation-with-grad) and [`jit`](#compilation-with-jit)
are instances of such transformations. Others are
[`vmap`](#auto-vectorization-with-vmap) for automatic vectorization and
[`pmap`](#spmd-programming-with-pmap) for single-program multiple-data (SPMD)
parallel programming of multiple accelerators, with more to come.

This is a research project, not an official Google product. Expect bugs and
[sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html).
Please help by trying it out, [reporting
bugs](https://github.com/google/jax/issues), and letting us know what you
think!

```python
import jax.numpy as jnp
from jax import grad, jit, vmap

def predict(params, inputs):
  for W, b in params:
    outputs = jnp.dot(inputs, W) + b
    inputs = jnp.tanh(outputs)
  return outputs

def logprob_fun(params, inputs, targets):
  preds = predict(params, inputs)
  return jnp.sum((preds - targets)**2)

grad_fun = jit(grad(logprob_fun))  # compiled gradient evaluation function
perex_grads = jit(vmap(grad_fun, in_axes=(None, 0, 0)))  # fast per-example grads
```
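
To make the snippet above concrete, here is one hypothetical way to call it; the layer sizes and random initialization below are illustrative and not part of the original example:

```python
import numpy.random as npr

layer_sizes = [3, 4, 2]
params = [(npr.randn(m, n), npr.randn(n))
          for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]
inputs, targets = npr.randn(8, 3), npr.randn(8, 2)   # a batch of 8 examples

batch_grad = grad_fun(params, inputs, targets)       # gradient over the whole batch
per_example = perex_grads(params, inputs, targets)   # each leaf has a leading axis of 8
```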

### Contents
* [Quickstart: Colab in the Cloud](#quickstart-colab-in-the-cloud)
* [Transformations](#transformations)
* [Current gotchas](#current-gotchas)
* [Installation](#installation)
* [Neural net libraries](#neural-network-libraries)
* [Citing JAX](#citing-jax)
* [Reference documentation](#reference-documentation)

## Quickstart: Colab in the Cloud
Jump right in using a notebook in your browser, connected to a Google Cloud GPU.
Here are some starter notebooks:
- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html)
- [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/google/jax/blob/master/docs/notebooks/neural_network_with_tfds_data.ipynb)

**JAX now runs on Cloud TPUs.** To try out the preview, see the [Cloud TPU
Colabs](https://github.com/google/jax/tree/master/cloud_tpu_colabs).

For a deeper dive into JAX:
- [The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)
- [Common gotchas and sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html)
- See the [full list of
notebooks](https://github.com/google/jax/tree/master/docs/notebooks).

You can also take a look at [the mini-libraries in
`jax.experimental`](https://github.com/google/jax/tree/master/jax/experimental/README.md),
like [`stax` for building neural
networks](https://github.com/google/jax/tree/master/jax/experimental/README.md#neural-net-building-with-stax)
and [`optimizers` for first-order stochastic
optimization](https://github.com/google/jax/tree/master/jax/experimental/README.md#first-order-optimization),
or the [examples](https://github.com/google/jax/tree/master/examples).

## Transformations

At its core, JAX is an extensible system for transforming numerical functions.
Here are four of primary interest: `grad`, `jit`, `vmap`, and `pmap`.

### Automatic differentiation with `grad`

JAX has roughly the same API as [Autograd](https://github.com/hips/autograd).
The most popular function is
[`grad`](https://jax.readthedocs.io/en/latest/jax.html#jax.grad)
for reverse-mode gradients:

```python
from jax import grad
import jax.numpy as jnp

def tanh(x):  # Define a function
  y = jnp.exp(-2.0 * x)
  return (1.0 - y) / (1.0 + y)

grad_tanh = grad(tanh)  # Obtain its gradient function
print(grad_tanh(1.0))   # Evaluate it at x = 1.0
# prints 0.4199743
```

You can differentiate to any order with `grad`.

```python
print(grad(grad(grad(tanh)))(1.0))
# prints 0.62162673
```

For more advanced autodiff, you can use
[`jax.vjp`](https://jax.readthedocs.io/en/latest/jax.html#jax.vjp) for
reverse-mode vector-Jacobian products and
[`jax.jvp`](https://jax.readthedocs.io/en/latest/jax.html#jax.jvp) for
forward-mode Jacobian-vector products. The two can be composed arbitrarily with
one another, and with other JAX transformations. Here's one way to compose those
to make a function that efficiently computes [full Hessian
matrices](https://jax.readthedocs.io/en/latest/jax.html#jax.hessian):

```python
from jax import jit, jacfwd, jacrev

def hessian(fun):
  return jit(jacfwd(jacrev(fun)))
```
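
As a quick sanity check of the sketch above (the cubic `f` below is just an illustrative test function), the Hessian of a coordinate-wise cubic is diagonal:

```python
import jax.numpy as jnp

def f(x):
  return jnp.sum(x ** 3)

print(hessian(f)(jnp.arange(3.)))
# The Hessian of sum(x**3) is diag(6 * x), so at x = [0., 1., 2.] this prints
# a diagonal matrix with entries [0., 6., 12.].
```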

As with [Autograd](https://github.com/hips/autograd), you're free to use
differentiation with Python control structures:

```python
def abs_val(x):
  if x > 0:
    return x
  else:
    return -x

abs_val_grad = grad(abs_val)
print(abs_val_grad(1.0))   # prints 1.0
print(abs_val_grad(-1.0))  # prints -1.0 (abs_val is re-evaluated)
```

See the [reference docs on automatic
differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
and the [JAX Autodiff
Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)
for more.

### Compilation with `jit`

You can use XLA to compile your functions end-to-end with
[`jit`](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit),
used either as an `@jit` decorator or as a higher-order function.

```python
import jax.numpy as jnp
from jax import jit

def slow_f(x):
  # Element-wise ops see a large benefit from fusion
  return x * x + x * 2.0

x = jnp.ones((5000, 5000))
fast_f = jit(slow_f)
%timeit -n10 -r3 fast_f(x)  # ~ 4.5 ms / loop on Titan X
%timeit -n10 -r3 slow_f(x)  # ~ 14.5 ms / loop (also on GPU via JAX)
```
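
The decorator form works the same way. A minimal sketch, where the `selu` activation is just an illustrative choice of function:

```python
import jax.numpy as jnp
from jax import jit

@jit  # the decorated function is traced and compiled on its first call
def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

print(selu(jnp.arange(5.0)))
```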

You can mix `jit` and `grad` and any other JAX transformation however you like.

Using `jit` puts constraints on the kind of Python control flow
the function can use; see
the [Gotchas
Notebook](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-+-JIT)
for more.

### Auto-vectorization with `vmap`

[`vmap`](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) is
the vectorizing map.
It has the familiar semantics of mapping a function along array axes, but
instead of keeping the loop on the outside, it pushes the loop down into a
function’s primitive operations for better performance.

Using `vmap` can save you from having to carry around batch dimensions in your
code. For example, consider this simple *unbatched* neural network prediction
function:

```python
def predict(params, input_vec):
  assert input_vec.ndim == 1
  activations = input_vec
  for W, b in params:
    outputs = jnp.dot(W, activations) + b  # `activations` on the right-hand side!
    activations = jnp.tanh(outputs)
  return outputs
```

We often instead write `jnp.dot(inputs, W)` to allow for a batch dimension on the
left side of `inputs`, but we’ve written this particular prediction function to
apply only to single input vectors. If we wanted to apply this function to a
batch of inputs at once, semantically we could just write

```python
from functools import partial
predictions = jnp.stack(list(map(partial(predict, params), input_batch)))
```

But pushing one example through the network at a time would be slow! It’s better
to vectorize the computation, so that at every layer we’re doing matrix-matrix
multiplication rather than matrix-vector multiplication.

The `vmap` function does that transformation for us. That is, if we write

```python
from jax import vmap
predictions = vmap(partial(predict, params))(input_batch)
# or, alternatively
predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)
```

then the `vmap` function will push the outer loop inside the function, and our
machine will end up executing matrix-matrix multiplications exactly as if we’d
done the batching by hand.

It’s easy enough to manually batch a simple neural network without `vmap`, but
in other cases manual vectorization can be impractical or impossible. Take the
problem of efficiently computing per-example gradients: that is, for a fixed set
of parameters, we want to compute the gradient of our loss function evaluated
separately at each example in a batch. With `vmap`, it’s easy:

```python
per_example_gradients = vmap(partial(grad(loss), params))(inputs, targets)
```
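
Spelled out with concrete, purely illustrative shapes and a toy squared-error `loss`, that one-liner looks like this:

```python
import jax.numpy as jnp
from functools import partial
from jax import grad, vmap

def loss(params, x, y):
  W, b = params
  return jnp.sum((jnp.dot(W, x) + b - y) ** 2)

params = (jnp.ones((2, 3)), jnp.zeros(2))
inputs, targets = jnp.ones((5, 3)), jnp.zeros((5, 2))  # a batch of 5 examples

per_example_gradients = vmap(partial(grad(loss), params))(inputs, targets)
# a tuple of arrays with shapes (5, 2, 3) and (5, 2): one gradient per example
```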

Of course, `vmap` can be arbitrarily composed with `jit`, `grad`, and any other
JAX transformation! We use `vmap` with both forward- and reverse-mode automatic
differentiation for fast Jacobian and Hessian matrix calculations in
`jax.jacfwd`, `jax.jacrev`, and `jax.hessian`.

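For intuition, here is a sketch of how a reverse-mode Jacobian can be built from `vmap` and `vjp` by pulling back a basis of output cotangents; this mirrors the idea behind `jax.jacrev` but is not its exact implementation:

```python
import jax.numpy as jnp
from jax import vjp, vmap

def jacrev_sketch(f):
  def jacfun(x):
    y, pullback = vjp(f, x)
    # pull back one standard basis vector per output to get the rows of the Jacobian
    J, = vmap(pullback)(jnp.eye(y.size, dtype=y.dtype))
    return J
  return jacfun

print(jacrev_sketch(lambda x: x ** 2)(jnp.arange(1., 4.)))
# prints the diagonal matrix diag(2 * x) = diag([2., 4., 6.])
```
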
### SPMD programming with `pmap`

For parallel programming of multiple accelerators, like multiple GPUs, use
[`pmap`](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap).
With `pmap` you write single-program multiple-data (SPMD) programs, including
fast parallel collective communication operations. Applying `pmap` will mean
that the function you write is compiled by XLA (similarly to `jit`), then
replicated and executed in parallel across devices.

Here's an example on an 8-GPU machine:

```python
from jax import random, pmap
import jax.numpy as jnp

# Create 8 random 5000 x 6000 matrices, one per GPU
keys = random.split(random.PRNGKey(0), 8)
mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)

# Run a local matmul on each device in parallel (no data transfer)
result = pmap(lambda x: jnp.dot(x, x.T))(mats)  # result.shape is (8, 5000, 5000)

# Compute the mean on each device in parallel and print the result
print(pmap(jnp.mean)(result))
# prints [1.1566595 1.1805978 ... 1.2321935 1.2015157]
```

In addition to expressing pure maps, you can use fast [collective communication
operations](https://jax.readthedocs.io/en/latest/jax.lax.html#parallel-operators)
between devices:

```python
from functools import partial
from jax import lax

@partial(pmap, axis_name='i')
def normalize(x):
  return x / lax.psum(x, 'i')

print(normalize(jnp.arange(4.)))
# prints [0.         0.16666667 0.33333334 0.5       ]
```

You can even [nest `pmap` functions](https://colab.sandbox.google.com/github/google/jax/blob/master/cloud_tpu_colabs/Pmap_Cookbook.ipynb#scrollTo=MdRscR5MONuN) for more
sophisticated communication patterns.

It all composes, so you're free to differentiate through parallel computations:

```python
from jax import grad

@pmap
def f(x):
  y = jnp.sin(x)
  @pmap
  def g(z):
    return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()
  return grad(lambda w: jnp.sum(g(w)))(x)

# Here `x` is assumed to be an array of shape (4, 2), so the nested pmaps
# together map over 8 devices.
print(f(x))
# [[ 0.        , -0.7170853 ],
#  [-3.1085174 , -0.4824318 ],
#  [10.366636  , 13.135289  ],
#  [ 0.22163185, -0.52112055]]

print(grad(lambda x: jnp.sum(f(x)))(x))
# [[ -3.2369726,  -1.6356447],
#  [  4.7572474,  11.606951 ],
#  [-98.524414 ,  42.76499  ],
#  [ -1.6007166,  -1.2568436]]
```

When reverse-mode differentiating a `pmap` function (e.g. with `grad`), the
backward pass of the computation is parallelized just like the forward pass.

See the [SPMD
Cookbook](https://colab.sandbox.google.com/github/google/jax/blob/master/cloud_tpu_colabs/Pmap_Cookbook.ipynb)
and the [SPMD MNIST classifier from scratch
example](https://github.com/google/jax/blob/master/examples/spmd_mnist_classifier_fromscratch.py)
for more.

## Current gotchas

For a more thorough survey of current gotchas, with examples and explanations,
we highly recommend reading the [Gotchas
Notebook](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html).
Some standouts:

1. JAX transformations only work on [pure functions](https://en.wikipedia.org/wiki/Pure_function), which don't have side-effects and respect [referential transparency](https://en.wikipedia.org/wiki/Referential_transparency) (i.e. object identity testing with `is` isn't preserved). If you use a JAX transformation on an impure Python function, you might see an error like `Exception: Can't lift Traced...` or `Exception: Different traces at same level`.
1. [In-place mutating updates of
   arrays](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-In-Place-Updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://jax.readthedocs.io/en/latest/jax.ops.html) (see the sketch after this list). Under a `jit`, those functional alternatives will reuse buffers in-place automatically.
1. [Random numbers are
   different](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Random-Numbers), but for [good reasons](https://github.com/google/jax/blob/master/design_notes/prng.md).
1. If you're looking for [convolution
   operators](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Convolutions),
   they're in the `jax.lax` package.
1. JAX enforces single-precision (32-bit, e.g. `float32`) values by default, and
   [to enable
   double-precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#Double-(64bit)-precision)
   (64-bit, e.g. `float64`) one needs to set the `jax_enable_x64` variable at
   startup (or set the environment variable `JAX_ENABLE_X64=True`).
1. Some of NumPy's dtype promotion semantics involving a mix of Python scalars
   and NumPy types aren't preserved, namely `np.add(1, np.array([2],
   np.float32)).dtype` is `float64` rather than `float32`.
1. Some transformations, like `jit`, [constrain how you can use Python control
   flow](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Control-Flow).
   You'll always get loud errors if something goes wrong. You might have to use
   [`jit`'s `static_argnums`
   parameter](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit),
   [structured control flow
   primitives](https://jax.readthedocs.io/en/latest/jax.lax.html#control-flow-operators)
   like
   [`lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan),
   or just use `jit` on smaller subfunctions.

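For the in-place update gotcha above, here is a minimal sketch of the functional alternative, using the indexed-update helpers in the `jax.ops` module linked in that item:

```python
import jax.numpy as jnp
from jax import ops

x = jnp.zeros(5)
# x[2] += 3.0 would fail on a JAX array; the functional form returns a new array:
y = ops.index_add(x, ops.index[2], 3.0)
print(y)  # [0. 0. 3. 0. 0.]
```
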
## Installation

JAX is written in pure Python, but it depends on XLA, which needs to be
installed as the `jaxlib` package. Use the following instructions to install a
binary package with `pip`, or to build JAX from source.

We support installing or building `jaxlib` on Linux (Ubuntu 16.04 or later) and
macOS (10.12 or later) platforms. Windows users can use JAX on CPU and GPU via
the
[Windows Subsystem for Linux](https://docs.microsoft.com/en-us/windows/wsl/about).
There is some initial native Windows support, but since it is still somewhat
immature, there are no binary releases and it must be
[built from source](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-jaxlib-from-source-on-windows).

### pip installation

To install a CPU-only version, which might be useful for doing local
development on a laptop, you can run

```bash
pip install --upgrade pip
pip install --upgrade jax jaxlib  # CPU-only version
```

On Linux, it is often necessary to first update `pip` to a version that supports
`manylinux2010` wheels.

If you want to install JAX with both CPU and NVidia GPU support, you must first
install [CUDA](https://developer.nvidia.com/cuda-downloads) and
[CuDNN](https://developer.nvidia.com/CUDNN),
if they have not already been installed. Unlike some other popular deep
learning systems, JAX does not bundle CUDA or CuDNN as part of the `pip`
package. The CUDA 10 JAX wheels require CuDNN 7, whereas the CUDA 11 wheels of
JAX require CuDNN 8. Other combinations of CUDA and CuDNN are possible but
require building from source.

Next, run

```bash
pip install --upgrade pip
pip install --upgrade jax jaxlib==0.1.59+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html
```

The jaxlib version must correspond to the version of the existing CUDA
installation you want to use, with `cuda111` for CUDA 11.1, `cuda110` for CUDA
11.0, `cuda102` for CUDA 10.2, and `cuda101` for CUDA 10.1. You can find your
CUDA version with the command:

```bash
nvcc --version
```

Note that some GPU functionality expects the CUDA installation to be at
`/usr/local/cuda-X.X`, where X.X should be replaced with the CUDA version number
(e.g. `cuda-10.2`). If CUDA is installed elsewhere on your system, you can either
create a symlink:

```bash
sudo ln -s /path/to/cuda /usr/local/cuda-X.X
```

or set the following environment variable before importing JAX:

```bash
XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda
```

Please let us know on [the issue tracker](https://github.com/google/jax/issues)
if you run into any errors or problems with the prebuilt wheels.

### Building JAX from source
See [Building JAX from
source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source).

## Neural network libraries

Multiple Google research groups develop and share libraries for training neural
networks in JAX. If you want a fully featured library for neural network
training with examples and how-to guides, try
[Flax](https://github.com/google/flax). Another option is
[Trax](https://github.com/google/trax), a combinator-based framework focused on
ease-of-use and end-to-end single-command examples, especially for sequence
models and reinforcement learning. Finally,
[Objax](https://github.com/google/objax) is a minimalist object-oriented
framework with a PyTorch-like interface.

DeepMind has open-sourced an ecosystem of libraries around JAX including
[Haiku](https://github.com/deepmind/dm-haiku) for neural network modules,
[Optax](https://github.com/deepmind/optax) for gradient processing and
optimization, [RLax](https://github.com/deepmind/rlax) for RL algorithms, and
[chex](https://github.com/deepmind/chex) for reliable code and testing.

## Citing JAX

To cite this repository:

```
@software{jax2018github,
  author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
  title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
  url = {http://github.com/google/jax},
  version = {0.2.5},
  year = {2018},
}
```

In the above bibtex entry, names are in alphabetical order, the version number
is intended to be that from [jax/version.py](../master/jax/version.py), and
the year corresponds to the project's open-source release.

A nascent version of JAX, supporting only automatic differentiation and
compilation to XLA, was described in a [paper that appeared at SysML
2018](https://mlsys.org/Conferences/2019/doc/2018/146.pdf). We're currently working on
covering JAX's ideas and capabilities in a more comprehensive and up-to-date
paper.

## Reference documentation

For details about the JAX API, see the
[reference documentation](https://jax.readthedocs.io/).

For getting started as a JAX developer, see the
[developer documentation](https://jax.readthedocs.io/en/latest/developer.html).