
JAX transformations

In this tutorial, we give some advice on mixing PAX with JAX transformations.

[1]:
import jax
import jax.numpy as jnp
import pax
from typing import Dict
from absl import logging
[2]:
%xmode Minimal
logging.set_verbosity(logging.FATAL)
Exception reporting mode: Minimal

JAX transformations affect a function in much the same way as pax.pure does. Inside the transformed function, we only have access to a copy of the inputs, so any modification to that copy does not affect the original inputs.

Let’s try with a simple example:

[3]:
def print_id_and_value(c: Dict[str, int], msg=""):
    print(f'({msg}) id {id(c)}  counter {c["count"]}')
[4]:
@jax.jit
def increase_counter(c):
    c["count"] += 1  # increase counter
    print_id_and_value(c, "inside")
[5]:
c = {"count": 1}
print_id_and_value(c, "before")
increase_counter(c)
print_id_and_value(c, "after ")
(before) id 140438671691264  counter 1
(inside) id 140437229406400  counter Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
(after ) id 140438671691264  counter 1

Note that inside the jitted function increase_counter, the counter c is a different object (it has a different id) from the counter c outside the function, and its value is a tracer rather than a concrete number. Therefore, modifying c inside increase_counter does not affect the c outside.
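If we do want the incremented value, the usual JAX idiom is to return the modified copy from the jitted function and rebind it in the caller. The following is a minimal sketch using plain JAX only (no PAX). It assumes the same dict-based counter as above.

```python
import jax
import jax.numpy as jnp

@jax.jit
def increase_counter(c):
    # `c` is a copy rebuilt by jax.jit, so this mutation stays local to the trace.
    c["count"] += 1
    # Return the modified copy so the caller can rebind it.
    return c

c = {"count": 1}
new_c = increase_counter(c)
print(c["count"], int(new_c["count"]))  # 1 2
```

The original dict is untouched. The caller decides whether to replace it with the returned copy.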

This behavior is very similar to pax.pure. In fact, pax.pure mimics this behavior of JAX transformations.
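To make the analogy concrete, here is a plain-Python sketch of copy-in/copy-out semantics. The helper pure_call_sketch is hypothetical and is not PAX's implementation. It only mirrors the calling convention used with pax.purecall in this tutorial, which returns the updated module together with the output (for example, `rnn, ys = pax.purecall(rnn, xs)`).

```python
import copy
from typing import Any, Callable, Dict, Tuple

def pure_call_sketch(fn: Callable, state: Dict[str, Any], *args) -> Tuple[Dict[str, Any], Any]:
    """Hypothetical helper: run `fn` on a deep copy of `state`.

    Returns the (possibly modified) copy and `fn`'s output; `state` itself is never touched.
    """
    state_copy = copy.deepcopy(state)
    out = fn(state_copy, *args)
    return state_copy, out

def increase_counter(c):
    c["count"] += 1  # mutates only the copy handed in by pure_call_sketch
    return c["count"]

c = {"count": 1}
new_c, out = pure_call_sketch(increase_counter, c)
print(c["count"], new_c["count"], out)  # 1 2 2
```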

Now, things get complicated when we use JAX transformations inside a function decorated by pax.pure.

In the following toy example, we have an RNN module that uses jax.lax.scan to scan over the inputs with the function scan_fn. In addition, scan_fn updates the internal state of a Counter module.

[6]:
class Counter(pax.StateModule):
    def __init__(self):
        self.count = jnp.array(0)

    def __call__(self, x):
        self.count = self.count + 1
        return x
[7]:
class RNN(pax.Module):
    def __init__(self):
        self.counter = Counter()

    def __call__(self, xs):
        def scan_fn(c: Counter, x):
            y = c(x)
            return c, y

        _, y = jax.lax.scan(scan_fn, init=self.counter, xs=xs)
        return y
[8]:
rnn = RNN()
xs = jnp.arange(0, 10)
rnn, ys = pax.purecall(rnn, xs)
UnfilteredStackTrace: ValueError: Cannot modify a module in immutable mode.
Please do this computation inside a function decorated by `pax.pure`.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------


The above exception was the direct cause of the following exception:

ValueError: Cannot modify a module in immutable mode.
Please do this computation inside a function decorated by `pax.pure`.

Oops! PAX prevents us from updating the counter even though we ran rnn with pax.purecall.

This is because jax.lax.scan, like jax.jit, executes scan_fn on a copy of its input modules, and in our case that copy is immutable.

Note

Only input modules of functions decorated by pax.pure are mutable. A copy of an input module is still immutable.

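In this case, we have to call the counter with pax.purecall inside scan_fn, return the updated counter as the scan carry, and assign the final carry back to self.counter. Below is a working implementation: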

[9]:
class RNN(pax.Module):
    def __init__(self):
        self.counter = Counter()

    def __call__(self, xs):
        def scan_fn(c: Counter, x):
            c, y = pax.purecall(c, x)
            return c, y

        self.counter, y = jax.lax.scan(scan_fn, init=self.counter, xs=xs)
        return y
[10]:
rnn = RNN()
xs = jnp.arange(0, 10)
rnn, ys = pax.purecall(rnn, xs)
print(f"Count = {rnn.counter.count}")
Count = 10
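For comparison, here is the same carry-threading pattern in plain JAX, with a dict standing in for the Counter module. This is an illustrative assumption and does not involve PAX. The body builds a new carry instead of mutating the one jax.lax.scan passes in, and the caller keeps the final carry.

```python
import jax
import jax.numpy as jnp

def scan_fn(carry, x):
    # Build a new carry rather than mutating the copy that scan passes in.
    new_carry = {"count": carry["count"] + 1}
    return new_carry, x

init = {"count": jnp.array(0)}
xs = jnp.arange(0, 10)
final, ys = jax.lax.scan(scan_fn, init, xs)
print(int(final["count"]))  # 10
```

The carry must keep the same structure, shape, and dtype across iterations, which is why the counter starts as a JAX array.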

Now, let’s try another example. In the following, a jitted function fn calls self.counter directly, capturing it from the enclosing scope rather than receiving it as an argument.

[11]:
class BadModule(pax.Module):
    def __init__(self):
        self.counter = Counter()

    def __call__(self, x):
        @jax.jit
        def fn(x):
            y = self.counter(x)
            return y

        y = fn(x)
        return y
[12]:
mod = BadModule()
x = jnp.array(0.0)
mod, y = pax.purecall(mod, x)
UnfilteredStackTrace: ValueError: Cannot modify a module in immutable mode.
Please do this computation inside a function decorated by `pax.pure`.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------


The above exception was the direct cause of the following exception:

ValueError: Cannot modify a module in immutable mode.
Please do this computation inside a function decorated by `pax.pure`.

In this example, PAX again prevents fn from modifying self.counter. This is PAX’s mechanism for preventing leaks: a traced function at a higher level of abstraction is not allowed to modify a module created at a lower level of abstraction.

Note

All modules created at lower levels of abstraction than the current level are immutable.
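For intuition, here is a plain-JAX sketch of the kind of leak this rule guards against. The names state and leaky are hypothetical. JAX itself does not stop the write at trace time. Instead, the closed-over dict ends up holding a tracer that has escaped the jitted function, and using that value later will typically fail with jax.errors.UnexpectedTracerError.

```python
import jax
import jax.numpy as jnp

# State created outside (at a lower level than) the jitted function.
state = {"count": jnp.array(0)}
concrete_type = type(state["count"])

@jax.jit
def leaky(x):
    # Anti-pattern: writing into closed-over state during tracing stores a
    # tracer in `state`, which then escapes ("leaks") out of the trace.
    state["count"] = state["count"] + 1
    return x

leaky(jnp.array(0.0))
# The dict no longer holds a concrete array but a leaked tracer.
print(type(state["count"]) is concrete_type)  # False
```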

A correct implementation passes self.counter to fn as an argument and returns the updated counter from fn.

[13]:
class GoodModule(pax.Module):
    def __init__(self):
        self.counter = Counter()

    def __call__(self, x):
        @jax.jit
        def fn(c: Counter, x):
            c, y = pax.purecall(c, x)
            return c, y

        self.counter, y = fn(self.counter, x)
        return y
[14]:
mod = GoodModule()
x = jnp.array(0.0)
print(f"Count = {mod.counter.count}")
mod, y = pax.purecall(mod, x)
print(f"Count = {mod.counter.count}")
Count = 0
Count = 1
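The same explicit-state pattern also works in plain JAX, shown here as a minimal sketch with a bare array standing in for the counter module (an illustrative assumption). The state enters as an argument and leaves as an output, so nothing is captured by closure and no tracer can escape.

```python
import jax
import jax.numpy as jnp

@jax.jit
def fn(count, x):
    # State is an input and an output of the jitted function: no closure, no leak.
    return count + 1, x

count = jnp.array(0)
count, y = fn(count, jnp.array(0.0))
print(int(count))  # 1
```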