
Getting started

This tutorial introduces basic concepts and practices of PAX.

[1]:
import jax
import jax.numpy as jnp
import pax
import matplotlib.pyplot as plt
from absl import logging
[2]:
logging.set_verbosity(logging.FATAL)
plt.rcParams["figure.figsize"] = (3, 2)

To demonstrate the basics of pax.Module, we will define a simple Linear module with an additional forward pass counter.

[3]:
class Linear(pax.Module):
    def __init__(self):
        super().__init__()
        self.weight = jnp.array(1.0)
        self.bias = jnp.array(0.0)
        self.counter = jnp.array(0)

    parameters = pax.parameters_method("weight", "bias")

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        self.counter = self.counter + 1
        return self.weight * x + self.bias

    def __repr__(self):
        return f"Linear(weight={self.weight:.3f}, bias={self.bias:.3f}, counter={self.counter})"

The function call pax.parameters_method("weight", "bias") returns a parameters method that treats weight and bias as trainable parameters. The counter attribute is not in that list, so it is not a trainable parameter.
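Besides counter not being something we want to learn, there is a practical reason to keep it out of the trainable set: jax.grad only differentiates floating-point (or complex) inputs and raises TypeError on integer leaves. The sketch below uses a plain dictionary as a hypothetical stand-in for the module's state; it is not how PAX stores module state.

```python
import jax
import jax.numpy as jnp

# Hypothetical plain-dict stand-in for the state of the Linear module above.
state = {"weight": jnp.array(1.0), "bias": jnp.array(0.0), "counter": jnp.array(0)}


def loss(tree, x=2.0):
    # Uses only the trainable leaves; "counter" is ignored.
    return tree["weight"] * x + tree["bias"]


try:
    jax.grad(loss)(state)  # fails: "counter" is an integer leaf
    grad_on_full_state_failed = False
except TypeError:
    grad_on_full_state_failed = True

# Differentiating only the trainable subset works.
params = {k: state[k] for k in ("weight", "bias")}
grads = jax.grad(loss)(params)  # grads["weight"] == 2.0, grads["bias"] == 1.0
print(grad_on_full_state_failed)  # True
```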

[4]:
net = Linear()
print(net)
Linear(weight=1.000, bias=0.000, counter=0)

Next, we will create a toy dataset for training purposes.

[5]:
def create_data(a=-3.0, b=1.5):
    x = jax.random.uniform(pax.next_rng_key(), (128, 1))
    noise = jax.random.normal(pax.next_rng_key(), x.shape) * 0.2
    y = a * x + b + noise
    return x, y
[6]:
pax.seed_rng_key(42)
x, y = create_data()

PAX keeps a global random key for each thread, and pax.seed_rng_key(42) seeds it. Each call to pax.next_rng_key() returns a new random key and advances the global key, so create_data draws x and noise with different keys.
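The "return a key and advance the state" behavior can be sketched with jax.random.split. KeyStream is a hypothetical class for illustration only, not PAX's implementation; PAX additionally keeps one such key per thread, which a `threading.local` could model.

```python
import jax
import jax.numpy as jnp


class KeyStream:
    """Hypothetical stand-in for a global RNG key; not PAX's implementation."""

    def __init__(self, seed: int):
        self._key = jax.random.PRNGKey(seed)

    def next_key(self):
        # Split the current key: keep one half as the new state, hand out the other.
        self._key, subkey = jax.random.split(self._key)
        return subkey


stream = KeyStream(42)
k1 = stream.next_key()
k2 = stream.next_key()
print(bool(jnp.all(k1 == k2)))  # False: each call yields a fresh key
```

Re-seeding with the same value reproduces the same sequence of keys, which is what makes `pax.seed_rng_key(42)` useful for reproducible data.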

[7]:
def plot_data(x, y):
    plt.scatter(x, y, label="data")
    plt.legend()
    plt.show()
[8]:
plot_data(x, y)
(Figure: scatter plot of the toy dataset.)

OK, let’s plot the initial predictions of our Linear model.

[9]:
def plot_prediction(x, y, y_hat):
    plt.scatter(x, y, label="data")
    plt.scatter(x, y_hat, label="prediction")
    plt.legend()
[10]:
y_hat = net(x)
plot_prediction(x, y, y_hat)
print(net)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-10-236420d1cd7e> in <module>
----> 1 y_hat = net(x)
      2 plot_prediction(x, y, y_hat)
      3 print(net)

<ipython-input-3-26b95aa36ade> in __call__(self, x)
     10
     11     def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
---> 12         self.counter = self.counter + 1
     13         return self.weight * x + self.bias
     14

/workspaces/pax/pax/_src/core/base.py in __setattr__(self, name, value)
    172     def __setattr__(self, name: str, value: Any) -> None:
    173         """Setting `_pax` attribute is forbidden."""
--> 174         self._assert_mutability()
    175
    176         if name == "_pax":

/workspaces/pax/pax/_src/core/base.py in _assert_mutability(self)
    154     def _assert_mutability(self):
    155         if not is_mutable(self):
--> 156             raise ValueError(
    157                 "Cannot modify a module in immutable mode.\n"
    158                 "Please do this computation inside a function decorated by `pax.pure`."

ValueError: Cannot modify a module in immutable mode.
Please do this computation inside a function decorated by `pax.pure`.

Oops! PAX prevents us from updating self.counter because PAX modules are immutable by default. The only way to modify a module is to pass it as an argument to a function decorated by pax.pure. Below is a working implementation:

[11]:
@pax.pure
def run(net, x):
    y_hat = net(x)
    return net, y_hat

Note that pax.pure only gives run access to a copy of its input module. This is PAX's mechanism for purifying a function: any modification to the copy does not affect the original module. Therefore, we have to return the updated net as part of the output.
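The observable contract of immutable modules plus copy-on-entry can be sketched in plain Python. FrozenModule, Counter and pure_sketch are hypothetical names; the sketch only mimics what we see above (the caller's module is untouched, the returned copy carries the update) and does not show how pax.pure is actually implemented.

```python
import copy
import functools


class FrozenModule:
    """Hypothetical base class: attribute writes raise unless the module is mutable."""

    def _set_mutable(self, flag: bool) -> None:
        object.__setattr__(self, "_mutable", flag)

    def __setattr__(self, name, value):
        if not getattr(self, "_mutable", False):
            raise ValueError("Cannot modify a module in immutable mode.")
        object.__setattr__(self, name, value)


class Counter(FrozenModule):
    def __init__(self):
        self._set_mutable(True)   # allow writes while constructing
        self.count = 0
        self._set_mutable(False)  # immutable from now on

    def __call__(self, x):
        self.count = self.count + 1
        return 2 * x


def pure_sketch(fn):
    """Hypothetical decorator: run fn on mutable deep copies of its module arguments."""

    @functools.wraps(fn)
    def wrapper(*args):
        args = copy.deepcopy(args)  # the caller's objects are never touched
        modules = [a for a in args if isinstance(a, FrozenModule)]
        for m in modules:
            m._set_mutable(True)
        try:
            return fn(*args)
        finally:
            for m in modules:
                m._set_mutable(False)  # freeze the copies again before returning

    return wrapper


@pure_sketch
def run(module, x):
    y = module(x)
    return module, y


c = Counter()
c2, y = run(c, 3)
print(c.count, c2.count, y)  # 0 1 6
```

Calling `c(3)` directly would raise the same kind of ValueError as the traceback above, while `run` succeeds because it only ever mutates a copy.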

[12]:
net, y_hat = run(net, x)
plot_prediction(x, y, y_hat)
print(net)
Linear(weight=1.000, bias=0.000, counter=1)
(Figure: scatter plot of the data and the initial predictions.)

Manually creating a function and decorating it with pax.pure every time we want to call a module is inconvenient. PAX provides the utility function pax.purecall, which does this for us. Below is a more convenient implementation:

[13]:
net, y_hat = pax.purecall(net, x)
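Conceptually, `pax.purecall(net, x)` follows the same contract as the `run` function above: call the module on a copy and return the updated copy together with the output. The plain-Python sketch below captures only that contract; `purecall_sketch` and `Counter` are hypothetical, and the real pax.purecall may differ in its details.

```python
import copy


class Counter:
    """Hypothetical module-like object that counts its calls."""

    def __init__(self):
        self.count = 0

    def __call__(self, x):
        self.count += 1
        return 2 * x


def purecall_sketch(module, *args, **kwargs):
    """Call a deep copy of `module`; return (updated copy, output)."""
    module = copy.deepcopy(module)
    output = module(*args, **kwargs)
    return module, output


c = Counter()
c, y = purecall_sketch(c, 3)
c, y = purecall_sketch(c, 4)
print(c.count, y)  # 2 8
```

Rebinding `c` to the returned module on every call is what keeps the count; discarding the first return value would silently drop the update.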

To train our model, we need a loss function. In this case, we will define a mean squared error (MSE) loss function as follows:

[14]:
def mse_loss(model: Linear, x, y):
    model, y_hat = pax.purecall(model, x)
    loss = jnp.mean(jnp.square(y - y_hat))
    return loss, model

It is important that mse_loss returns the updated model. Because pax.purecall works on a copy, the incremented counter exists only in the returned model, and returning it is the only way to make it available outside of the loss function.

Next, we will implement a simple stochastic gradient descent (SGD) optimizer to update trainable parameters of our model.

[15]:
def sgd(params: Linear, gradients: Linear, lr: float = 1e-1):
    updated_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, gradients)
    return updated_params
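Because sgd relies on jax.tree_util.tree_map, it works with any pytree whose structure matches that of the gradients; this is what lets it take PAX modules directly. In the sketch below, a hypothetical LinearParams class registered as a pytree stands in for the module; it is not PAX's own pytree registration.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class LinearParams:
    """Hypothetical pytree holding only the trainable leaves of a linear model."""

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias

    def tree_flatten(self):
        # Children are the leaves tree_map visits; no static auxiliary data.
        return (self.weight, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def sgd(params, gradients, lr: float = 1e-1):
    # Same update rule as above: p <- p - lr * g, leaf by leaf.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, gradients)


params = LinearParams(jnp.array(1.0), jnp.array(0.0))
grads = LinearParams(jnp.array(0.5), jnp.array(-1.0))
new_params = sgd(params, grads)
# new_params.weight is about 0.95, new_params.bias is about 0.1
```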

Lastly, we need an update function that represents a single training step.

[16]:
@jax.jit
def update_fn(net: Linear, x, y):
    (loss, net), grads = pax.value_and_grad(mse_loss, has_aux=True)(net, x, y)
    params = net.parameters()
    params = sgd(params, grads)
    net = net.update_parameters(params)
    return net, loss

The function pax.value_and_grad is a thin wrapper around jax.value_and_grad that computes gradients with respect to the trainable parameters of the model. We use it to transform mse_loss into a function that returns both (loss, net) and the gradients grads; has_aux=True marks net as an auxiliary output that is passed through rather than differentiated.
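The `has_aux=True` pattern can be seen in isolation with jax.value_and_grad itself. In this hypothetical pure-JAX analogue of mse_loss, the parameters are a dict and the auxiliary output is an incremented counter, mirroring how the updated net is returned; the hand-computed values in the comment follow from y_hat = [1, 2] and residuals [2, 3].

```python
import jax
import jax.numpy as jnp


def mse_loss(params, counter, x, y):
    """Hypothetical pure-JAX analogue of mse_loss: returns (loss, aux)."""
    y_hat = params["weight"] * x + params["bias"]
    loss = jnp.mean(jnp.square(y - y_hat))
    return loss, counter + 1  # aux output: the updated counter


params = {"weight": jnp.array(1.0), "bias": jnp.array(0.0)}
x = jnp.array([1.0, 2.0])
y = jnp.array([3.0, 5.0])

# Gradients are taken w.r.t. the first argument only (argnums=0 by default);
# has_aux=True passes the second return value through untouched.
(loss, counter), grads = jax.value_and_grad(mse_loss, has_aux=True)(
    params, jnp.array(0), x, y
)
# Expected: loss = 6.5, counter = 1, dL/dweight = -8.0, dL/dbias = -5.0
```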

Also, net.parameters() returns trainable parameters of the model and net.update_parameters(params) returns a new model with parameters updated.
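The parameters() / update_parameters(params) pair can be pictured as a split followed by a merge: the optimizer only sees the trainable leaves, and the merge writes the new values back while leaving counter untouched. The dict-based helpers below are hypothetical analogues, not PAX's methods.

```python
import jax
import jax.numpy as jnp

# Hypothetical dict state; "counter" is not trainable.
state = {"weight": jnp.array(1.0), "bias": jnp.array(0.0), "counter": jnp.array(7)}


def parameters(state):
    """Hypothetical analogue of net.parameters(): keep trainable leaves only."""
    return {k: state[k] for k in ("weight", "bias")}


def update_parameters(state, params):
    """Hypothetical analogue of net.update_parameters(): return a NEW state."""
    return {**state, **params}


params = parameters(state)
params = jax.tree_util.tree_map(lambda p: p - 0.5, params)  # stand-in optimizer step
new_state = update_parameters(state, params)
# new_state: weight 0.5, bias -0.5, counter 7; the original `state` is unchanged.
```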

OK, let’s train the model for 500 steps:

[17]:
losses = []

for step in range(500):
    net, loss = update_fn(net, x, y)
    losses.append(loss)

plt.plot(losses)
plt.xlabel("step")
plt.ylabel("loss")
plt.show()
(Figure: training loss versus step.)
[18]:
net
[18]:
Linear(weight=-3.023, bias=1.496, counter=502)
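The counter value matches the forward passes made so far: one through run, one through pax.purecall, and one inside mse_loss at each of the 500 training steps, giving 1 + 1 + 500 = 502. The learned weight and bias are close to the values a=-3.0 and b=1.5 used by create_data.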

Finally, let’s plot the predictions of our trained model.

[19]:
net, y_hat = pax.purecall(net, x)
plot_prediction(x, y, y_hat)
(Figure: scatter plot of the data and the trained model's predictions.)