NumPy vs JAX vs CuPy

In this blog post, I describe the benefits and drawbacks of using NumPy, JAX, or CuPy to work with multidimensional arrays in Python.

NUMPY

NumPy is the basis for JAX and CuPy: both libraries largely mirror NumPy's array API. NumPy lets us work with multidimensional arrays, called ndarrays, on the CPU. NumPy itself cannot run operations on a GPU; libraries such as JAX and CuPy provide GPU implementations of the same interface. GPU implementations pay off mainly when working with large datasets (e.g., offline processing). Some of the requirements for ndarrays are:

The full NumPy documentation can be found here.
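A minimal sketch of creating an ndarray on the CPU and inspecting its attributes. The array values are arbitrary; the example only illustrates that an ndarray has a fixed shape and a single dtype shared by all elements, and that operations on it are vectorized.

```python
import numpy as np

# A 2-D ndarray: the shape is fixed at creation and all elements share one dtype
a = np.array([[1, 2, 3],
              [4, 5, 6]], dtype=np.float64)

print(a.shape)  # (2, 3)
print(a.dtype)  # float64
print(a.ndim)   # 2

# Mixing ints and floats upcasts everything to a common dtype (float64 here)
mixed = np.array([1, 2.5, 3])
print(mixed.dtype)  # float64

# Vectorized arithmetic runs in compiled code on the CPU, without Python loops
b = a * 2 + 1
col_sums = b.sum(axis=0)
print(col_sums)  # [12. 16. 20.]
```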

JAX

JAX provides a NumPy-like API (jax.numpy) that can run on CPU, GPU, and TPU systems. JAX uses XLA (Accelerated Linear Algebra) to compile and run NumPy-style code on accelerators (i.e., GPUs and TPUs). Compilation is usually just-in-time: a function wrapped with jax.jit (or decorated with @jit) is traced and compiled on its first call, and later calls with the same input shapes and dtypes reuse the compiled code. Functions can also be compiled ahead of time, before the first call. JAX offers only a subset of SciPy's functionality (through jax.scipy). Some characteristics of JAX are:

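import jax.numpy as jnp
from jax import lax

# Permissive: jax.numpy applies NumPy-style type promotion (int + float -> float)
jnp.add(1, 1.0)

# Strict, lower-level API: jax.lax requires operands with matching dtypes
try:
    lax.add(1, 1.0)               # raises TypeError (mismatched dtypes)
except TypeError as err:
    print(err)
lax.add(jnp.float32(1), 1.0)      # runs: both operands are treated as float32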
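To make the compilation behavior concrete, here is a minimal sketch using jax.jit. The selu function, the array size, and the timing code are illustrative assumptions; absolute timings depend on the hardware and backend, so no numbers are shown. The first call includes tracing and XLA compilation, the second call reuses the cached executable, and the last part compiles ahead of time with jax.jit(...).lower(...).compile().

```python
import time

import jax
import jax.numpy as jnp


def selu(x, alpha=1.67, lmbda=1.05):
    # Plain jax.numpy code: without jit, each operation is dispatched separately
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)


selu_jit = jax.jit(selu)  # equivalent to decorating selu with @jax.jit

x = jnp.arange(1_000_000, dtype=jnp.float32) / 1e6

# First call: JAX traces selu and compiles it with XLA for this shape/dtype.
# block_until_ready() waits for the asynchronous computation to finish.
t0 = time.perf_counter()
selu_jit(x).block_until_ready()
t_first = time.perf_counter() - t0

# Second call with the same shape/dtype reuses the cached compiled version
t0 = time.perf_counter()
y = selu_jit(x).block_until_ready()
t_second = time.perf_counter() - t0

print(f"first call (trace + compile): {t_first:.4f} s")
print(f"second call (cached):         {t_second:.4f} s")

# Ahead-of-time: lower and compile for x's shape/dtype before any call
compiled = jax.jit(selu).lower(x).compile()
z = compiled(x)
```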

The full JAX documentation can be found here.

CUPY

CuPy is a NumPy/SciPy-compatible array library for using NVIDIA GPUs in Python. The main object in CuPy is cupy.ndarray. Some characteristics of CuPy are described below:

import numpy as np
import cupy as cp

def fun(x):
    # get_array_module returns numpy for NumPy arrays and cupy for CuPy arrays
    xp = cp.get_array_module(x)
    print(f"Using {xp.__name__} module")

fun(np.array(1))  # Using numpy module
fun(cp.array(1))  # Using cupy module
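The get_array_module pattern is mostly useful for writing one function that runs on either device. A minimal sketch, assuming CuPy may or may not be installed: the helper names get_xp and softmax are hypothetical, and the GPU branch only runs when cupy.cuda.is_available() reports a CUDA device. cp.asarray copies a NumPy array to the GPU and cp.asnumpy copies the result back to the host.

```python
import numpy as np

try:
    import cupy as cp  # optional: only usable with an NVIDIA GPU and CUDA
except ImportError:
    cp = None


def get_xp(x):
    """Hypothetical helper: return the array module (numpy or cupy) that owns x."""
    if cp is not None:
        return cp.get_array_module(x)
    return np


def softmax(x):
    """Numerically stable softmax along the last axis, for NumPy or CuPy arrays."""
    xp = get_xp(x)
    shifted = x - xp.max(x, axis=-1, keepdims=True)
    exp = xp.exp(shifted)
    return exp / xp.sum(exp, axis=-1, keepdims=True)


x_cpu = np.array([[1.0, 2.0, 3.0],
                  [0.0, 0.0, 0.0]])
y_cpu = softmax(x_cpu)  # runs on the CPU with NumPy
print(y_cpu)

if cp is not None and cp.cuda.is_available():
    x_gpu = cp.asarray(x_cpu)     # copy host -> device
    y_gpu = softmax(x_gpu)        # same code, now executed by CuPy on the GPU
    print(cp.asnumpy(y_gpu))      # copy device -> host
```

Keeping the function body free of direct np/cp references means the same code path is exercised on both devices; the price is that host-to-device copies must be managed explicitly by the caller.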

The full CuPy documentation can be found here.

Conclusion