NumPy vs JAX vs CuPy
In this blog I will describe the benefits and drawbacks of using NumPy, JAX, or CuPy for working with multidimensional arrays in Python.
NumPy
NumPy is the foundation on which both JAX and CuPy build their array APIs. NumPy lets us work with multidimensional arrays on the CPU; these arrays are known as ndarrays. NumPy operations cannot run on the GPU directly; libraries such as JAX and CuPy are needed for that. The benefit of GPU implementations shows up when working with big datasets (e.g., offline processing). Some of the requirements for ndarrays are listed below (a short example follows the list):
- They have a fixed size at creation, unlike Python lists (which can grow dynamically). Changing the size of an ndarray will create a new array and delete the original.
- The elements in a NumPy array are all required to be of the same data type, and thus will be the same size in memory. The exception: one can have arrays of (Python, including NumPy) objects, thereby allowing for arrays of different-sized elements.
- NumPy arrays facilitate advanced mathematical and other types of operations on large numbers of data. Typically, such operations are executed more efficiently and with less code than is possible using Python’s built-in sequences.
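The snippet below is a minimal sketch of these properties; the array names are arbitrary and chosen only for illustration:
import numpy as np
# All elements share one dtype, and the size is fixed at creation
a = np.array([1, 2, 3], dtype=np.float64)
print(a.dtype, a.shape)  # float64 (3,)
# "Resizing" actually builds a new array; the original stays intact
b = np.append(a, 4.0)
print(a.shape, b.shape)  # (3,) (4,)
# Vectorized operations replace explicit Python loops
c = 2.0 * a + 1.0
print(c)  # [3. 5. 7.]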
The full documentation of NumPy can be accessed here.
JAX
JAX is an XLA-compiled implementation of the NumPy API that can run on CPU, GPU, and TPU systems. JAX uses XLA (Accelerated Linear Algebra) to compile and run NumPy code on accelerators (i.e., GPUs and TPUs). Compilation happens just in time: a function decorated with @jit is traced and compiled on its first call, and subsequent calls reuse the compiled code. There is limited access to SciPy functions in JAX (through jax.scipy). Some characteristics of JAX are:
- Compiled JAX code will try to run on a GPU or TPU first. If no accelerator is available, it falls back to the CPU.
- jax.numpy arrays are immutable (i.e., they cannot change their values or size). If you want to change a previously made array, you can use the .at[x].set(y) call to make a copy of the array with the updated values.
- jax.numpy is a relaxed API that allows different data types to interact. A faster, lower-level API is jax.lax; it is strict and requires operands of the same data type. See the example below:
# Relaxed, slower method
import jax.numpy as jnp
jnp.add(1, 1.0)
# Strict, faster method
from jax import lax
lax.add(1, 1.0) # Will give an error
lax.add(jnp.float32(1), 1.0) # Will run
- When using the @jit decorator to compile code, the input and output shapes of a function cannot change: array shapes must be static and known at compile time.
- JAX is meant to work with pure functions (i.e., functions where all the input data is passed through the function parameters and all the results are returned as function outputs). This means that any data used in the function must be an input or be defined within the function; you cannot rely on variables from outside the function that are not inputs (e.g., global variables). A short sketch combining these points follows this list.
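Below is a minimal, hedged sketch of these ideas; the function name scale_and_update is purely illustrative and not part of the JAX API:
import jax
import jax.numpy as jnp
@jax.jit  # traced and compiled by XLA on the first call; later calls reuse the compiled code
def scale_and_update(x, y):
    # Arrays are immutable: .at[...].set(...) returns an updated copy
    x = x.at[0].set(10.0)
    # Pure function: everything it needs arrives through its arguments
    return x * 2.0 + y
x = jnp.arange(4, dtype=jnp.float32)
y = jnp.ones(4, dtype=jnp.float32)
print(scale_and_update(x, y))  # [21.  3.  5.  7.]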
The full JAX documentation can be found here.
CuPy
CuPy is a NumPy/SciPy-compatible array library for using NVIDIA GPUs in Python. The main object used in CuPy is cupy.ndarray. Some characteristics of CuPy are described below:
- CPU/GPU-agnostic CuPy code can be written as follows:
import numpy as np
import cupy as cp
def fun(x):
    # get_array_module returns numpy or cupy depending on where x lives
    xp = cp.get_array_module(x)
    print(f"Using {xp.__name__} module")
fun(np.array(1))  # Using numpy module
fun(cp.array(1))  # Using cupy module
- Data needs to be transferred from host RAM to GPU memory (i.e., y = cp.asarray(x)) or back again (i.e., x = cp.asnumpy(y)); see the short example after this list.
- For faster execution on the GPU, the cp.float32 data type should be used.
- There is a scikit-learn-style implementation built on CUDA, but not one written in CuPy. However, that tool might be usable together with CuPy to run ML kernels on the GPU.
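As a rough sketch of the host/device round trip (assuming a CUDA-capable GPU is available):
import numpy as np
import cupy as cp
x = np.random.rand(1024).astype(np.float32)  # host (RAM) array, single precision
y = cp.asarray(x)                            # copy host -> GPU memory
z = cp.sqrt(y) + 1.0                         # computed on the GPU
back = cp.asnumpy(z)                         # copy GPU -> host as a numpy.ndarray
print(type(y), type(back))                   # cupy.ndarray, numpy.ndarray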
The full CuPy documentation can be found here.
Conclusion
- GPUs are more efficient with big datasets
- JAX can be used to run NumPy routines faster
- CuPy allows writing CPU/GPU-agnostic code
- Neither JAX nor CuPy has a direct scikit-learn integration. However, there is a tool with a CUDA implementation.