Article · Wikipedia archive · Last revised May 31, 2026

JAX (software)


Original author: Google
Developers: Google and JAX developers
Written in: Python, C++, CUDA
Operating system: Linux, macOS, Windows
Platform: x86-64, ARM, GPU, TPU
Type: Numerical computing, machine learning
License: Apache 2.0
Website: jax.dev
Repository: github.com/jax-ml/jax

JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning. It is developed by Google with contributions from Nvidia and other community contributors.[1][2][3]
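JAX's program transformations operate on a traced representation of an ordinary Python function, not on its source code. The following is a minimal sketch that assumes a recent JAX release is installed (for example via `pip install jax`). It uses `jax.make_jaxpr` to print the intermediate representation, called a "jaxpr", that JAX records when it traces a function. The function `f` is a hypothetical example, and the exact printed form of a jaxpr differs between JAX versions.

```python
import jax
import jax.numpy as jnp

def f(x):
    # A plain Python function written with jax.numpy, JAX's NumPy-like API
    # (hypothetical example, not taken from the cited sources).
    return jnp.sin(x) * 2.0 + x

# make_jaxpr traces f using an abstract value shaped like the example input
# and returns the recorded sequence of primitive operations (a jaxpr).
jaxpr = jax.make_jaxpr(f)(1.0)
print(jaxpr)

# Calling f directly still evaluates it numerically.
print(f(1.0))
```

The printed jaxpr lists primitives such as `sin`, `mul` and `add`. Transformations such as differentiation and compilation work on representations of this kind.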

It is described as bringing together a modified version of the automatic differentiation system autograd[4] and OpenXLA's XLA (Accelerated Linear Algebra). It is designed to follow the structure and workflow of NumPy as closely as possible and works with various existing frameworks such as TensorFlow and PyTorch.[5][6] The primary features of JAX are:[7]

  1. Providing a unified NumPy-like interface to computations that run on CPU, GPU, or TPU, in local or distributed settings.
  2. Built-in Just-In-Time (JIT) compilation via OpenXLA, an open-source machine learning compiler ecosystem.
  3. Efficient evaluation of gradients via its automatic differentiation transformations.
  4. Automatic vectorization to efficiently map functions over arrays representing batches of inputs.
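These features are designed to compose. The following is a minimal sketch that assumes a recent JAX release running on its default backend, which is the CPU when no accelerator is present. It first differentiates a hypothetical squared-error loss for a linear model with `jax.grad`. It then maps that gradient over a small batch with `jax.vmap` and compiles the result with `jax.jit`. The model, data and function names are illustrative only and are not drawn from the cited sources.

```python
import jax
import jax.numpy as jnp

# Feature 1: the same code runs on whichever backend JAX detects.
print(jax.devices())

def loss(w, x, y):
    # Squared error of a linear model on a single example (hypothetical).
    pred = jnp.dot(x, w)
    return (pred - y) ** 2

# Feature 3: jax.grad returns a new function that computes d(loss)/dw,
# the gradient with respect to the first argument by default.
grad_loss = jax.grad(loss)

# Feature 4: jax.vmap maps grad_loss over the leading axis of x and y,
# while in_axes=None shares the same weights w across the whole batch.
batched_grads = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# Feature 2: jax.jit compiles the composed function with XLA on its first call.
fast_batched_grads = jax.jit(batched_grads)

w = jnp.array([1.0, -2.0, 0.5])
xs = jnp.array([[1.0, 0.0, 2.0],
                [0.0, 1.0, 1.0]])
ys = jnp.array([1.0, 0.0])

g = fast_batched_grads(w, xs, ys)
print(g.shape)  # (2, 3): one gradient per example
```

Each transformation returns an ordinary function, so the transformations can be applied in a different order. For example, `jax.vmap(jax.jit(grad_loss), in_axes=(None, 0, 0))` gives the same result. For this data, the per-example gradients are `[[2, 0, 4], [0, -3, -3]]`.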

Libraries using JAX

  • Flax[8]
  • Equinox[9]
  • Optax[10]
References

  1. Bradbury, James; Frostig, Roy; Hawkins, Peter; Johnson, Matthew James; Leary, Chris; MacLaurin, Dougal; Necula, George; Paszke, Adam; Vanderplas, Jake; Wanderman-Milne, Skye; Zhang, Qiao (2022-06-18), "JAX: Autograd and XLA", Astrophysics Source Code Library, Google, Bibcode:2021ascl.soft11002B, archived from the original on 2022-06-18, retrieved 2022-06-18
  2. Frostig, Roy; Johnson, Matthew James; Leary, Chris (2018-02-02). "Compiling machine learning programs via high-level tracing" (PDF). MLsys: 1–3. Archived (PDF) from the original on 2022-06-21.
  3. "Using JAX to accelerate our research". www.deepmind.com. Archived from the original on 2022-06-18. Retrieved 2022-06-18.
  4. HIPS/autograd, Harvard Intelligent Probabilistic Systems Group (now at Princeton), 2026-02-17, retrieved 2026-02-17
  5. Lynley, Matthew. "Google is quietly replacing the backbone of its AI product strategy after its last big push for dominance got overshadowed by Meta". Business Insider. Archived from the original on 2022-06-21. Retrieved 2022-06-21.
  6. "Why is Google's JAX so popular?". Analytics India Magazine. 2022-04-25. Archived from the original on 2022-06-18. Retrieved 2022-06-18.
  7. "Quickstart — JAX documentation".
  8. Flax: A neural network library and ecosystem for JAX designed for flexibility, Google, 2022-07-29, archived from the original on 2022-09-03, retrieved 2022-07-29
  9. Kidger, Patrick (2022-07-29), Equinox, archived from the original on 2023-09-19, retrieved 2022-07-29
  10. Optax, DeepMind, 2022-07-28, archived from the original on 2023-06-07, retrieved 2022-07-29