
Improve performance

Even though jax.jit eliminates almost all of the performance penalty related to PAX, a small cost remains: every call to a jitted function runs tree_flatten on its inputs and tree_unflatten on its outputs.
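The following is a minimal sketch of where this cost comes from, written in plain JAX without PAX. It assumes a hypothetical `Params` class, registered as a custom pytree node, that stands in for a PAX module. The code counts how often its Python flatten and unflatten callbacks run. Even when `step` hits the compilation cache, each call still flattens the input and unflattens the output.

```python
import jax
import jax.numpy as jnp
from jax import tree_util

# Counters for the Python callbacks of the custom pytree node below.
calls = {"flatten": 0, "unflatten": 0}


class Params:
    """Hypothetical stand-in for a PAX module: a plain Python object holding arrays."""

    def __init__(self, w, b):
        self.w = w
        self.b = b


def _flatten(p):
    calls["flatten"] += 1
    return (p.w, p.b), None  # (children, aux_data)


def _unflatten(aux_data, children):
    calls["unflatten"] += 1
    return Params(*children)


tree_util.register_pytree_node(Params, _flatten, _unflatten)


@jax.jit
def step(p):
    return Params(p.w * 0.9, p.b * 0.9)


p = Params(jnp.ones(3), jnp.zeros(3))
p = step(p)  # first call: trace and compile
before = dict(calls)
p = step(p)  # cached call: no retracing, but the pytree callbacks still run
print("flatten calls on a cached call:", calls["flatten"] - before["flatten"])
print("unflatten calls on a cached call:", calls["unflatten"] - before["unflatten"])
```

A PAX model contains many such nodes (one per submodule), so this per-call Python work grows with the size of the network.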

In this tutorial, we measure PAX’s performance overhead and introduce practices that help reduce it.

Note

PAX’s performance penalties are usually less than 1% of the training time, so most of the time we can ignore them.

Let us start with a simple program that trains a ResNet50 classifier.

[1]:
# uncomment the following line to install pax
# !pip install -q git+https://github.com/NTT123/pax.git
[2]:
import pax, jax, opax
import jax.numpy as jnp
from pax.nets import ResNet50

pax.seed_rng_key(42)


def loss_fn(model: ResNet50, inputs):
    images, labels = inputs
    model, logits = pax.purecall(model, images)
    log_pr = jax.nn.log_softmax(logits, axis=-1)
    # Cross-entropy: the negative mean log-probability of the true labels.
    loss = -jnp.mean(jnp.sum(jax.nn.one_hot(labels, num_classes=10) * log_pr, axis=-1))
    return loss, (loss, model)


@jax.jit
def update(model, optimizer, inputs):
    grads, (loss, model) = jax.grad(loss_fn, has_aux=True)(model, inputs)
    model, optimizer = opax.apply_gradients(model, optimizer, grads=grads)
    return model, optimizer, loss


net = ResNet50(3, 10)
optimizer = opax.adam(1e-4)(net.parameters())

rng_key = jax.random.PRNGKey(42)
img = jax.random.normal(rng_key, (1, 3, 64, 64))
label = jax.random.randint(rng_key, (1,), 0, 10)
# net, optimizer, loss = update(net, optimizer, (img, label))
[3]:
import time

start = time.perf_counter()
for i in range(10_000):
    a, b = jax.tree_util.tree_flatten((net, optimizer))
    net, optimizer = jax.tree_util.tree_unflatten(b, a)
end = time.perf_counter()
print("Duration:", end - start)
Duration: 40.00314049300505

It takes 40.0 seconds to execute 10,000 iterations of tree_flatten and tree_unflatten, or about 4 ms per iteration.

This is approximately the extra time we pay when training a ResNet50 network with an opax.adam optimizer for 10,000 iterations, because each call to the jitted update function flattens its inputs and unflattens its outputs.

Flatten optimizer

One easy way to reduce the time is to use the flatten mode supported by opax optimizers.

[4]:
optimizer = opax.adam(1e-4)(net.parameters(), flatten=True)

In this mode, the optimizer automatically flattens the parameters and gradients into a list of leaves instead of working with the full tree structure. This reduces the optimizer’s flatten and unflatten time to almost zero.

However, we can no longer access the optimizer’s pytree objects directly. Fortunately, this is rarely needed, and the flattened list can easily be converted back to a pytree with the jax.tree_util.tree_unflatten function.
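The conversion between a pytree and a flat list can be sketched in plain JAX. The nested dictionary below is a hypothetical stand-in for an optimizer state; it is not the layout opax actually uses in flatten mode. A flat list holds the same arrays with fewer pytree nodes to visit, and the saved treedef is enough to rebuild the original structure.

```python
import jax
import jax.numpy as jnp

# Hypothetical nested optimizer state: one sub-dict per layer.
state = {f"layer{i}": {"mu": jnp.zeros(4), "nu": jnp.ones(4)} for i in range(50)}

leaves, treedef = jax.tree_util.tree_flatten(state)
flat_def = jax.tree_util.tree_structure(leaves)  # structure of a plain list

print("leaves:", treedef.num_leaves, "vs", flat_def.num_leaves)
print("nodes: ", treedef.num_nodes, "vs", flat_def.num_nodes)

# Converting the flat list back to the original pytree when it is needed.
restored = jax.tree_util.tree_unflatten(treedef, leaves)
```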

[5]:
import time

start = time.perf_counter()
for i in range(10_000):
    a, b = jax.tree_util.tree_flatten((net, optimizer))
    net, optimizer = jax.tree_util.tree_unflatten(b, a)
end = time.perf_counter()
print("Duration:", end - start)
Duration: 8.336228522995953
[6]:
start = time.perf_counter()
for i in range(10_000):
    a, b = jax.tree_util.tree_flatten(optimizer)
    optimizer = jax.tree_util.tree_unflatten(b, a)
end = time.perf_counter()
print("Duration:", end - start)
Duration: 0.3621630070047104

With flatten=True, the total time drops to 8.3 seconds, and flattening/unflattening the optimizer alone takes only 0.36 seconds, which is close to zero.

Multi-step update function

Another solution to reduce the time for flatten/unflatten is to execute multiple update steps inside a jitted function.

[7]:
num_steps = 10


@jax.jit
def multistep_update(model, optimizer, inputs):
    def _step(m_o, i):
        m, o, aux = update(*m_o, i)
        return (m, o), aux

    (model, optimizer), losses = pax.scan(_step, (model, optimizer), inputs)
    return model, optimizer, jnp.mean(losses)


multistep_img = jax.random.normal(rng_key, (num_steps, 1, 3, 64, 64))
multistep_label = jax.random.randint(rng_key, (num_steps, 1), 0, 10)
# net, optimizer, loss = multistep_update(net, optimizer, (multistep_img, multistep_label))

The multistep_update function executes num_steps update steps in a single call, so the flatten/unflatten cost is paid once per call instead of once per step. With num_steps=10, this overhead shrinks by a factor of 10.
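The same pattern can be checked on a toy problem in plain JAX. This sketch assumes a hypothetical linear-regression model with plain SGD instead of ResNet50 and opax. It runs ten steps with jax.lax.scan inside one jitted call and compares the result with ten separate jitted calls.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, batch):
    x, y = batch
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


def sgd_step(params, batch, lr=0.1):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss


single_update = jax.jit(sgd_step)


@jax.jit
def multistep_update(params, batches):
    # All steps run inside one jitted call, so params cross the jit boundary
    # (and are flattened/unflattened) once per call, not once per step.
    params, losses = jax.lax.scan(sgd_step, params, batches)
    return params, jnp.mean(losses)


num_steps = 10
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
xs = jax.random.normal(key_x, (num_steps, 8, 3))  # num_steps batches of 8 examples
ys = jax.random.normal(key_y, (num_steps, 8))
params = {"w": jnp.zeros(3), "b": jnp.zeros(())}

p_multi, mean_loss = multistep_update(params, (xs, ys))

p_single = params
for i in range(num_steps):
    p_single, _ = single_update(p_single, (xs[i], ys[i]))
```

Both paths apply the same sequence of updates; only the number of host-to-jit round trips differs.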

Note

The practice of executing multiple update steps inside a jitted function is also very useful for TPU training, as it significantly reduces the communication cost between the CPU host and the TPU cores.

Flatten model

We have reduced the time to flatten/unflatten the optimizer to almost zero. We can do the same thing for the model too.

The idea is simple: move the flatten and unflatten operations inside the jitted update function, so that only a flat list of the model’s leaves crosses the jit boundary.
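In plain JAX, without PAX or opax, the idea can be sketched as follows. The function name `flat_sgd_update` and the toy linear model are hypothetical. The treedef is passed as a static argument, which works because PyTreeDef objects are hashable.

```python
from functools import partial

import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    pred = x @ params["dense"]["w"] + params["dense"]["b"]
    return jnp.mean((pred - y) ** 2)


@partial(jax.jit, static_argnums=0)
def flat_sgd_update(treedef, leaves, x, y, lr=0.1):
    params = jax.tree_util.tree_unflatten(treedef, leaves)  # rebuild inside jit
    grads = jax.grad(loss_fn)(params, x, y)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return jax.tree_util.tree_leaves(params)  # only a flat list leaves the jit


params0 = {"dense": {"w": jnp.zeros(3), "b": jnp.zeros(())}}
leaves, treedef = jax.tree_util.tree_flatten(params0)
x = jnp.ones((4, 3))
y = jnp.full((4,), 2.0)

for _ in range(100):
    leaves = flat_sgd_update(treedef, leaves, x, y)

params = jax.tree_util.tree_unflatten(treedef, leaves)  # recreate the pytree when needed
```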

[8]:
from functools import partial


@partial(jax.jit, static_argnums=0)  # model_def (a PyTreeDef) is hashable, so it can be static
def flatten_update(model_def, model_leaves, optimizer, inputs):
    # Rebuild the model inside the jitted function; only its leaves cross the jit boundary.
    model = jax.tree_util.tree_unflatten(model_def, model_leaves)
    grads, (loss, model) = jax.grad(loss_fn, has_aux=True)(model, inputs)
    model, optimizer = opax.apply_gradients(model, optimizer, grads=grads)
    # Return the updated model as a flat list of leaves.
    return jax.tree_util.tree_leaves(model), optimizer, loss
[9]:
net_leaves, net_def = jax.tree_util.tree_flatten(net)
# net_leaves, optimizer, loss = flatten_update(net_def, net_leaves, optimizer, (img, label))
[10]:
start = time.perf_counter()
for i in range(10_000):
    a, b = jax.tree_util.tree_flatten((net_leaves, optimizer))
    (net_leaves, optimizer) = jax.tree_util.tree_unflatten(b, a)
end = time.perf_counter()
print("Duration:", end - start)
Duration: 0.62999291500455

The extra time is now only 0.63 seconds when training a ResNet50 for 10,000 steps.
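To put the measured totals side by side, the per-step overhead and the speedup over the first measurement can be computed as below. The numbers are the rounded durations reported in this tutorial; the helper name `overhead_summary` is hypothetical.

```python
def overhead_summary(measured_seconds, steps):
    """Return {name: (milliseconds per step, speedup over the first entry)}."""
    baseline = next(iter(measured_seconds.values()))
    return {
        name: (total / steps * 1000.0, baseline / total)
        for name, total in measured_seconds.items()
    }


# Totals for 10,000 flatten/unflatten round trips, as measured in this tutorial.
measured = {
    "baseline": 40.0,
    "flatten=True optimizer": 8.3,
    "flattened model and optimizer": 0.63,
}

for name, (ms_per_step, speedup) in overhead_summary(measured, 10_000).items():
    print(f"{name:32s} {ms_per_step:6.3f} ms/step  {speedup:5.1f}x")
```

This puts the overhead at about 4 ms per step for the baseline and about 0.063 ms per step once both the model and the optimizer are flattened, a reduction of roughly 63 times.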

However, we have to recreate the model manually from its leaves and treedef.

[11]:
net = jax.tree_util.tree_unflatten(net_def, net_leaves)

PAX provides similar functionality with pax.experimental.Flattener, which creates a new module in which all parameters and states are flattened.
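The idea behind such a flattener can be sketched in plain JAX with a hypothetical `FlatBox` class. This is an illustration under our own assumptions, not PAX’s implementation of Flattener. The box registers itself as a pytree whose only child is the flat list of leaves and keeps the original treedef as static auxiliary data, so flattening the box does not walk the nested structure.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class FlatBox:
    """Hypothetical sketch (not PAX's Flattener): stores a pytree as leaves + treedef."""

    def __init__(self, leaves, treedef):
        self.leaves = leaves
        self.treedef = treedef

    @classmethod
    def wrap(cls, tree):
        leaves, treedef = tree_util.tree_flatten(tree)
        return cls(leaves, treedef)

    def unwrap(self):
        return tree_util.tree_unflatten(self.treedef, self.leaves)

    def tree_flatten(self):
        # Children: the flat list of arrays. Aux data: the hashable treedef.
        return (self.leaves,), self.treedef

    @classmethod
    def tree_unflatten(cls, treedef, children):
        (leaves,) = children
        return cls(leaves, treedef)


@jax.jit
def scale(box):
    tree = box.unwrap()  # rebuild the nested structure inside jit
    tree = jax.tree_util.tree_map(lambda a: a * 2.0, tree)
    return FlatBox.wrap(tree)


tree = {f"layer{i}": {"w": jnp.ones(2), "b": jnp.zeros(2)} for i in range(20)}
box = FlatBox.wrap(tree)
out = scale(box).unwrap()

tree_def = jax.tree_util.tree_structure(tree)
box_def = jax.tree_util.tree_structure(box)
print("nodes visited per flatten:", tree_def.num_nodes, "vs", box_def.num_nodes)
```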

[12]:
@jax.jit
def flat_update_wrapper(flat_mods: pax.experimental.Flattener, inputs):
    model, optimizer = flat_mods.model, flat_mods.optimizer
    model, optimizer, loss = update(model, optimizer, inputs)
    return flat_mods.update(model=model, optimizer=optimizer), loss


flat_mods = pax.experimental.Flattener(model=net, optimizer=optimizer)
# flat_mods, loss = flat_update_wrapper(flat_mods, (img, label))
[13]:
start = time.perf_counter()
for i in range(10_000):
    a, b = jax.tree_util.tree_flatten(flat_mods)
    flat_mods = jax.tree_util.tree_unflatten(b, a)
end = time.perf_counter()
print("Duration:", end - start)
Duration: 0.6704415560016059