Google JAX is a machine learning framework for transforming numerical functions.
It is described as bringing together a modified version of autograd (automatic obtaining of the gradient function through differentiation of a function) and TensorFlow's XLA (Accelerated Linear Algebra). It is designed to follow the structure and workflow of
NumPy as closely as possible and works with various existing frameworks such as
TensorFlow and PyTorch. The primary functions of JAX are:
# grad: automatic differentiation
# jit: compilation
# vmap: auto-vectorization
# pmap: SPMD programming
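These transformations can be composed with one another. The following is a minimal sketch of such a composition; the function square_sum and the sample data are illustrative assumptions rather than part of any JAX API.

```python
import jax.numpy as jnp
from jax import grad, jit, vmap

# illustrative scalar-valued function of a vector: f(x) = sum(x_i ** 2)
def square_sum(x):
    return jnp.sum(x ** 2)

# grad differentiates f, vmap maps the gradient over the leading (batch) axis,
# and jit compiles the resulting batched gradient function
batched_grad = jit(vmap(grad(square_sum)))

# three 2-element vectors, one per row
batch = jnp.arange(6.0).reshape(3, 2)

# the gradient of f is 2 * x, so each output row is twice the matching input row
result = batched_grad(batch)
print(result)
```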
grad
The code below demonstrates the grad function's automatic differentiation.
# imports
from jax import grad
import jax.numpy as jnp
# define the logistic function
def logistic(x):
    return jnp.exp(x) / (jnp.exp(x) + 1)
# obtain the gradient function of the logistic function
grad_logistic = grad(logistic)
# evaluate the gradient of the logistic function at x = 1
grad_log_out = grad_logistic(1.0)
print(grad_log_out)
The final line should output:
0.19661194
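As a sanity check, the result can be compared with the closed-form derivative of the logistic function, s(x)(1 - s(x)). The sketch below does this; the helper logistic_derivative is introduced here purely for illustration. It also shows that grad can be applied to its own output to obtain a second derivative.

```python
import jax.numpy as jnp
from jax import grad

# the logistic function from the example above
def logistic(x):
    return jnp.exp(x) / (jnp.exp(x) + 1)

# illustrative helper: closed-form derivative s(x) * (1 - s(x))
def logistic_derivative(x):
    s = logistic(x)
    return s * (1 - s)

# derivative obtained by automatic differentiation
autodiff_value = grad(logistic)(1.0)
# derivative obtained from the closed form
analytic_value = logistic_derivative(1.0)
# grad composes with itself, giving the second derivative
second_derivative = grad(grad(logistic))(1.0)

print(autodiff_value, analytic_value, second_derivative)
```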
jit
The code below demonstrates the jit function's optimization through fusion.
# imports
from jax import jit
import jax.numpy as jnp
# define the cube function
def cube(x):
    return x * x * x
# generate data
x = jnp.ones((10000, 10000))
# create the jit version of the cube function
jit_cube = jit(cube)
# apply the cube and jit_cube functions to the same data for speed comparison
cube(x)
jit_cube(x)
The computation time for jit_cube(x) should be noticeably shorter than that for cube(x) once jit_cube has been compiled on its first call. Increasing the dimensions of x will further widen the difference.
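One way to measure the difference is sketched below. It assumes a smaller array than the example above so that it runs quickly, and the helper time_call is hypothetical. Because JAX dispatches work asynchronously, the sketch calls block_until_ready() before stopping the clock, and it runs jit_cube once beforehand so that tracing and compilation are not counted in the comparison.

```python
import time

import jax.numpy as jnp
from jax import jit

def cube(x):
    return x * x * x

jit_cube = jit(cube)

# smaller than the 10000 x 10000 array above, chosen only to keep the sketch fast
x = jnp.ones((1000, 1000))

# hypothetical helper: time a single call, waiting for the result to be computed
def time_call(fn, arg):
    start = time.perf_counter()
    out = fn(arg).block_until_ready()
    return out, time.perf_counter() - start

# the first call to jit_cube includes tracing and compilation
_, first_call_seconds = time_call(jit_cube, x)

# later calls reuse the compiled code
eager_result, eager_seconds = time_call(cube, x)
jit_result, jit_seconds = time_call(jit_cube, x)

print("jit first call:", first_call_seconds)
print("cube:", eager_seconds)
print("jit_cube:", jit_seconds)
```

The measured times depend on the hardware and on the array size, so no particular figures are implied here.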
vmap
The code below demonstrates the vmap function's vectorization.
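# imports
from functools import partial
import jax
import numpy as np
# define a method of a network class that computes per-example gradients;
# self._net_grads, self._net_params and self._flatten_batch are assumed to be defined on that class
def grads(self, inputs):
    # fix the network parameters so that only the input varies
    in_grad_partial = partial(self._net_grads, self._net_params)
    # vectorize the gradient function over the leading (batch) axis of inputs
    grad_vmap = jax.vmap(in_grad_partial)
    rich_grads = grad_vmap(inputs)
    flat_grads = np.asarray(self._flatten_batch(rich_grads))
    assert flat_grads.ndim == 2 and flat_grads.shape[0] == inputs.shape[0]
    return flat_grads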
The GIF on the right of this section illustrates the notion of vectorized addition.
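A self-contained illustration of the same idea is sketched below. The function predict and its data are assumptions made for the example: predict is written for a single input vector, and vmap turns it into a function over a batch without an explicit loop. The in_axes argument states that the weights are shared (None) while the inputs are mapped over their leading axis (0).

```python
import jax.numpy as jnp
from jax import vmap

# illustrative function written for a single input vector x
def predict(weights, x):
    return jnp.dot(weights, x)

weights = jnp.array([1.0, 2.0, 3.0])
# a batch of four input vectors, one per row
inputs = jnp.arange(12.0).reshape(4, 3)

# share the weights across the batch and map over the rows of inputs
batched_predict = vmap(predict, in_axes=(None, 0))
outputs = batched_predict(weights, inputs)

print(outputs)  # one value per row of inputs
```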
pmap
The code below demonstrates the pmap function's parallelization for matrix multiplication.
# import pmap and random from JAX; import JAX NumPy
from jax import pmap, random
import jax.numpy as jnp
# generate 2 random matrices of dimensions 5000 x 6000, one per device
random_keys = random.split(random.PRNGKey(0), 2)
matrices = pmap(lambda key: random.normal(key, (5000, 6000)))(random_keys)
# without data transfer, in parallel, perform a local matrix multiplication on each CPU/GPU
outputs = pmap(lambda x: jnp.dot(x, x.T))(matrices)
# without data transfer, in parallel, obtain the mean for both matrices on each CPU/GPU separately
means = pmap(jnp.mean)(outputs)
print(means)
The final line should print the values:
[1.1566595 1.1805978]
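The example above maps over two keys and therefore needs at least two devices, since pmap places each entry of the mapped axis on its own device. The sketch below is a variant that sizes the mapped axis from jax.local_device_count(), so it should also run on a machine with a single device. The smaller matrix shape is an assumption chosen to keep it fast.

```python
import jax
import jax.numpy as jnp
from jax import pmap, random

# number of devices available to this process (1 on a typical CPU-only machine)
n_devices = jax.local_device_count()

# one random key per device
random_keys = random.split(random.PRNGKey(0), n_devices)

# generate one 500 x 600 random matrix on each device
matrices = pmap(lambda key: random.normal(key, (500, 600)))(random_keys)

# on each device, multiply the local matrix by its transpose and take the mean
means = pmap(lambda x: jnp.mean(jnp.dot(x, x.T)))(matrices)

print(means)  # one mean per device
```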
Libraries using JAX
Several Python libraries use JAX as a backend, including:
* Flax, a high-level neural network library initially developed by Google Brain.
* Haiku, an object-oriented library for neural networks developed by DeepMind.
* Equinox, a library that revolves around the idea of representing parameterised functions (including neural networks) as PyTrees. It was created by Patrick Kidger.
* Optax, a library for gradient processing and optimisation developed by DeepMind.
* RLax, a library for developing reinforcement learning agents developed by DeepMind.
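The PyTrees mentioned in the Equinox entry are nested containers (such as dictionaries, lists and tuples) of arrays that JAX transformations can traverse. The sketch below uses only core JAX, not the API of any of the libraries listed above, and its parameter names, loss function and step size are illustrative assumptions. It shows that grad returns gradients with the same structure as the parameters, and that jax.tree_util.tree_map can apply an update to every leaf.

```python
import jax
import jax.numpy as jnp

# illustrative parameters stored as a dictionary, which JAX treats as a pytree
params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}

# illustrative scalar loss: a linear function of the input
def loss(params, x):
    return jnp.sum(params["w"] * x) + params["b"]

x = jnp.array([3.0, 4.0])

# the gradients come back as a dictionary with the same keys as params
grads = jax.grad(loss)(params, x)

# apply one gradient-descent step to every leaf (step size 0.1 is an assumption)
updated = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

print(grads)
print(updated)
```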
See also
* NumPy
* TensorFlow
* PyTorch
* CUDA
* Automatic differentiation
* Just-in-time compilation
* Vectorization
* Automatic parallelization
External links
* Documentation
* Colab (Jupyter/iPython) Quickstart Guide
* TensorFlow's XLA (Accelerated Linear Algebra)
* Original paper