
Understanding PAX’s module

This tutorial shows how to build a PAX-like module for neural network training from scratch.

Pytree

[1]:
from copy import copy
import jax
import numpy as np
import jax.numpy as jnp
from absl import logging
[2]:
logging.set_verbosity(logging.FATAL)

First, let’s talk about pytrees.

Pytrees are tree-like structures built out of Python containers such as lists, tuples, and dicts. Here are a few examples of pytrees:

[3]:
a = 123
b = [1, 2, 3]
c = (1, 2, 3)
d = {"1": 1, "2": 2, "3": 3}
e = [(1, 2), "123", {"1": 1, "2": [4, 5]}]

JAX provides the jax.tree_util.tree_flatten function, which decomposes a pytree into two parts:

  • leaves: a list of tree leaves.

  • treedef: information about the structure of the tree.

[4]:
leaves, treedef = jax.tree_util.tree_flatten(e)
print("Leaves:", leaves)
print("TreeDef:", treedef)
Leaves: [1, 2, '123', 1, 4, 5]
TreeDef: PyTreeDef([(*, *), *, {'1': *, '2': [*, *]}])

Note

Even though a pytree can have any object at its leaves, many JAX transformations such as jax.jit, jax.lax.scan, and jax.grad only accept pytrees whose leaves are arrays (or values, such as Python scalars, that can be converted to arrays).
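As a quick illustration of this restriction, the sketch below jit-compiles a hypothetical helper, total, that sums every leaf of a pytree. A dict of arrays is accepted, while a pytree with a string leaf is rejected with a TypeError when the compiled function is called.

```python
import jax
import jax.numpy as jnp


@jax.jit
def total(tree):
    # Sum every leaf of the pytree; under jit each leaf must be an array-like value.
    return sum(jnp.sum(leaf) for leaf in jax.tree_util.tree_leaves(tree))


params = {"w": jnp.ones((2, 2)), "b": [jnp.array(1.0), jnp.array(2.0)]}
print(total(params))  # four ones + 1.0 + 2.0

string_rejected = False
try:
    total({"name": "abc"})  # a str leaf is not a valid JAX type
except TypeError:
    string_rejected = True
print("string leaf rejected:", string_rejected)
```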

We can reverse the jax.tree_util.tree_flatten transformation with jax.tree_util.tree_unflatten:

[5]:
jax.tree_util.tree_unflatten(treedef=treedef, leaves=leaves)
[5]:
[(1, 2), '123', {'1': 1, '2': [4, 5]}]
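Because the treedef does not depend on the leaf values, it can be reused with a different list of leaves of the same length. The sketch below (the names tree and doubled are illustrative) rebuilds the same structure with every leaf doubled, which is what jax.tree_util.tree_map does in a single call.

```python
import jax

tree = [(1, 2), {"a": 3}]
leaves, treedef = jax.tree_util.tree_flatten(tree)

# Reuse the structure with new leaves (same count, same order).
doubled = jax.tree_util.tree_unflatten(treedef, [2 * leaf for leaf in leaves])
print(doubled)  # [(2, 4), {'a': 6}]

# tree_map performs the flatten / transform / unflatten round trip in one call.
assert doubled == jax.tree_util.tree_map(lambda leaf: 2 * leaf, tree)
```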

A simple PAX module

Now let’s try to build a simple PAX module. The core idea here is that:

A module is also a pytree.

To let JAX know how to flatten and unflatten a pytree module:

  1. The module class must implement two methods: tree_flatten and tree_unflatten.

  2. The class must be registered as a pytree node.
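The same two requirements can also be met without a decorator, by passing the flatten and unflatten functions to jax.tree_util.register_pytree_node. In the sketch below, Point, point_flatten, and point_unflatten are hypothetical names; x and y become children (leaves), while name is kept as auxiliary data in the treedef.

```python
import jax


class Point:
    # Hypothetical container: x and y are leaves, name is static metadata.
    def __init__(self, x, y, name):
        self.x, self.y, self.name = x, y, name


def point_flatten(p):
    # Return (children, aux_data); aux_data should be hashable.
    return (p.x, p.y), p.name


def point_unflatten(name, children):
    # Rebuild a Point from the aux_data and the (possibly transformed) children.
    return Point(*children, name)


jax.tree_util.register_pytree_node(Point, point_flatten, point_unflatten)

p = Point(1.0, 2.0, "p")
leaves, treedef = jax.tree_util.tree_flatten(p)
print(leaves)  # [1.0, 2.0]

q = jax.tree_util.tree_map(lambda v: v + 1, p)
print(q.x, q.y, q.name)  # 2.0 3.0 p
```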

[6]:
@jax.tree_util.register_pytree_node_class
class ModuleV0:
    def __init__(self, mylist):
        self.mylist = mylist
        self.is_training = True

    def tree_flatten(self):
        children = [self.mylist]
        aux_info = {"is_training": self.is_training}
        return children, aux_info

    @classmethod
    def tree_unflatten(cls, aux_info, children):
        new_object = cls.__new__(cls)
        new_object.mylist = children[0]
        new_object.is_training = aux_info["is_training"]
        return new_object

    def __repr__(self):
        name = self.__class__.__name__
        info = f"mylist={self.mylist}, is_training={self.is_training}"
        return f"{name}({info})"

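The decorator jax.tree_util.register_pytree_node_class registers ModuleV0 as a pytree node class, so JAX calls its tree_flatten and tree_unflatten methods whenever it flattens or rebuilds a ModuleV0 object.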

Let’s try to flatten and unflatten a module.

[7]:
mod = ModuleV0([1, 2, 3])
print(mod)
leaves, tree_def = jax.tree_util.tree_flatten(mod)
print(leaves, tree_def)
new_mod = jax.tree_util.tree_unflatten(tree_def, leaves)
new_mod
ModuleV0(mylist=[1, 2, 3], is_training=True)
[1, 2, 3] PyTreeDef(CustomNode(<class '__main__.ModuleV0'>[{'is_training': True}], [[*, *, *]]))
[7]:
ModuleV0(mylist=[1, 2, 3], is_training=True)

Note: is_training is not a leaf; it is stored as auxiliary data in the PyTreeDef.
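One consequence is that two modules differing only in is_training have different treedefs, and tree operations such as jax.tree_util.tree_map leave the flag untouched. The sketch below uses a hypothetical Flagged class that returns its auxiliary data as a tuple, since the JAX documentation asks for hashable auxiliary data (a dict, as used in ModuleV0, is not hashable).

```python
import jax


@jax.tree_util.register_pytree_node_class
class Flagged:
    # Hypothetical minimal module: `values` is a child, `is_training` is aux data.
    def __init__(self, values, is_training=True):
        self.values = values
        self.is_training = is_training

    def tree_flatten(self):
        return [self.values], (self.is_training,)  # tuple aux data is hashable

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], *aux_data)


a = Flagged([1, 2], is_training=True)
b = Flagged([1, 2], is_training=False)
_, treedef_a = jax.tree_util.tree_flatten(a)
_, treedef_b = jax.tree_util.tree_flatten(b)
print(treedef_a == treedef_b)  # False: the flag is part of the treedef

c = jax.tree_util.tree_map(lambda x: x * 10, a)
print(c.values, c.is_training)  # [10, 20] True
```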

Introducing the register_subtree method

OK, but our pytree module only supports the mylist and is_training attributes, whereas a real module for neural network training can have an arbitrary number of attributes.

Moreover, how can our module know that mylist belongs to the tree's children while is_training belongs to the tree definition?

One solution is to:

  1. keep a set (named tree_part_names) that records which attributes are part of the tree;

  2. require users to register each attribute that is part of the tree;

  3. treat every unregistered attribute as part of the tree definition.

[8]:
@jax.tree_util.register_pytree_node_class
class ModuleV1(ModuleV0):
    def __init__(self):
        self.tree_part_names = frozenset()
        self.is_training = True

    def tree_flatten(self):
        children = []
        others = []
        children_names = []

        for name, value in vars(self).items():
            if name in self.tree_part_names:
                children.append(value)
                children_names.append(name)
            else:
                others.append((name, value))
        return children, (children_names, others)

    @classmethod
    def tree_unflatten(cls, aux_info, children):
        children_names, others = aux_info
        new_object = cls.__new__(cls)
        new_object.__dict__.update(others)
        new_object.__dict__.update(zip(children_names, children))
        return new_object

    def register_subtree(self, name, value):
        self.__dict__[name] = value
        self.tree_part_names = self.tree_part_names.union([name])

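    def __init_subclass__(cls, **kwargs):
        # Automatically register every subclass as a pytree node class.
        super().__init_subclass__(**kwargs)
        jax.tree_util.register_pytree_node_class(cls)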

Our new module has a register_subtree method that sets an attribute and adds its name to the tree_part_names set.

The tree_flatten method iterates over all attributes of the object and checks whether each name is in tree_part_names. If it is, the value is appended to the children list; otherwise, the (name, value) pair is appended to the others list.

The tree_unflatten method combines information from others, children_names, and children to reconstruct the module.

Note:

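  1. We purposely use a frozenset so that tree_part_names can only be changed by assigning a new set. Unflattened or copied modules share attribute objects with the original, so an in-place update to a mutable set in one module would silently affect the others. (Other attributes of the module do not get this protection.)

  2. __init_subclass__ ensures that every subclass of ModuleV1 is automatically registered as a pytree node class.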
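The sketch below illustrates the first point with a hypothetical Holder class. A shallow copy shares attribute objects with the original (tree_unflatten in ModuleV1 likewise reuses the objects stored in others): mutating a shared set in place changes both objects, whereas rebinding a frozenset with union only affects the copy.

```python
from copy import copy


class Holder:
    """Hypothetical stand-in for a module's attribute storage."""


a = Holder()
a.names = {"w"}          # mutable set
b = copy(a)              # shallow copy: b.names is the same set object
b.names.add("b")         # in-place update is visible through a as well
print(a.names == {"w", "b"})  # True

a.frozen = frozenset({"w"})
c = copy(a)
c.frozen = c.frozen.union(["b"])  # builds a new frozenset and rebinds it on c only
print(a.frozen == frozenset({"w"}), c.frozen == frozenset({"w", "b"}))  # True True
```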

Let’s try our module with a simple counter:

[9]:
class Counter(ModuleV1):
    def __init__(self):
        super().__init__()

        self.register_subtree("count", 0)

    def step(self):
        self.count = self.count + 1

    def __repr__(self):
        return f"{self.__class__.__name__}(count={self.count})"
[10]:
counter = Counter()
print(counter)
counter.step()
print(counter)
leaves, treedef = jax.tree_util.tree_flatten(counter)
print((leaves, treedef))
new_counter = jax.tree_util.tree_unflatten(treedef, leaves)
print(new_counter)
Counter(count=0)
Counter(count=1)
([1], PyTreeDef(CustomNode(<class '__main__.Counter'>[(['count'], [('tree_part_names', frozenset({'count'})), ('is_training', True)])], [*])))
Counter(count=1)

A custom parameters method

Our module does not have a way to select trainable parameters. We need this feature for gradient computation.
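For context, jax.grad differentiates with respect to every leaf of the argument selected by argnums (the first argument by default). A minimal sketch, assuming a plain dict-based model rather than a module, keeps the trainable values in the differentiated argument and everything else in a separate one; the names trainable, state, and loss_fn are illustrative.

```python
import jax
import jax.numpy as jnp

# Hypothetical split: only `trainable` is differentiated (argnums=0 by default).
trainable = {"weight": jnp.array(1.0), "bias": jnp.array(0.0)}
state = {"count": jnp.array(0)}  # non-trainable; passed along but not differentiated


def loss_fn(trainable, state, x, y):
    pred = x * trainable["weight"] + trainable["bias"]
    return (pred - y) ** 2


grads = jax.grad(loss_fn)(trainable, state, 2.0, 3.0)
print(grads)  # gradients for weight and bias only
```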

PAX’s solution is to let users implement a parameters() method themselves. For example:

[11]:
class ModuleV2(ModuleV1):
    def replace(self, **kwargs):
        mod = copy(self)
        for name, value in kwargs.items():
            setattr(mod, name, value)
        return mod

    def parameters(self):
        weights = {}
        for name in self.tree_part_names:
            value = getattr(self, name)
            value = value.parameters() if isinstance(value, ModuleV2) else None
            weights[name] = value
        return self.replace(**weights)
[12]:
class Linear(ModuleV2):
    def __init__(self):
        super().__init__()
        self.register_subtree("weight", jnp.array(1.0))
        self.register_subtree("bias", jnp.array(0.0))
        self.register_subtree("count", jnp.array(0))

    def parameters(self):
        return super().parameters().replace(weight=self.weight, bias=self.bias)

    def __call__(self, x):
        self.count += 1
        return x * self.weight + self.bias

    def __repr__(self):
        return f"{self.__class__.__name__}(weight={self.weight}, bias={self.bias}, count={self.count})"
[13]:
fc = Linear()
x = 2.0
y = fc(x)
print(fc)
print(fc.parameters())
Linear(weight=1.0, bias=0.0, count=1)
Linear(weight=1.0, bias=0.0, count=None)
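Replacing non-trainable attributes with None works because JAX treats None as an empty pytree node with no leaves. The sketch below uses a plain dict as a stand-in for the output of parameters(): count does not appear among the leaves, and the gradient returned by jax.grad keeps None in that position.

```python
import jax
import jax.numpy as jnp

# Stand-in for the result of parameters(): non-trainable entries are None.
params = {"weight": jnp.array(1.0), "bias": jnp.array(0.0), "count": None}
print(len(jax.tree_util.tree_leaves(params)))  # 2: None contributes no leaves


def loss(p, x):
    return (x * p["weight"] + p["bias"]) ** 2


grads = jax.grad(loss)(params, 2.0)
print(grads["count"])  # None: the gradient has the same structure as params
```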

However, it is a bit inconvenient that we have to implement a parameters method ourselves. Below is a utility function that does the job for us.

[14]:
def parameters_method(*trainable_weights):
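    def _parameters(self):
        # Start from the parent's parameters (plain subtrees set to None,
        # submodules reduced to their own parameters), then restore the trainable weights.
        values = {name: getattr(self, name) for name in trainable_weights}
        # Caveat: super(self.__class__, self) is resolved from the instance's runtime
        # class, so this only works if the class using it is not subclassed further.
        return super(self.__class__, self).parameters().replace(**values)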

    return _parameters
[15]:
class Linear(ModuleV2):
    def __init__(self):
        super().__init__()
        self.register_subtree("weight", jnp.array(1.0))
        self.register_subtree("bias", jnp.array(0.0))
        self.register_subtree("count", jnp.array(0))

    parameters = parameters_method("weight", "bias")

    def __call__(self, x):
        self.count += 1
        return x * self.weight + self.bias

    def __repr__(self):
        return f"{self.__class__.__name__}(weight={self.weight}, bias={self.bias}, count={self.count})"
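One caveat with parameters_method: super(self.__class__, self) looks up the parent of the instance's runtime class, not of the class that defined parameters. If Linear were subclassed, the call would resolve back to Linear.parameters and recurse forever. The sketch below reproduces this with hypothetical classes and shows one possible workaround, a descriptor (HelloMethod, also hypothetical) that records its defining class through __set_name__.

```python
class Base:
    def hello(self):
        return ["Base"]


def hello_method(tag):
    # Same pattern as parameters_method: super() is resolved from self.__class__.
    def _hello(self):
        return super(self.__class__, self).hello() + [tag]

    return _hello


class A(Base):
    hello = hello_method("A")


class B(A):  # subclass of a class that uses the factory
    pass


print(A().hello())  # ['Base', 'A']

b_recursed = False
try:
    B().hello()  # super(B, self).hello is A.hello again, so it never reaches Base
except RecursionError:
    b_recursed = True
print("B().hello() recursed:", b_recursed)


class HelloMethod:
    """Hypothetical descriptor that remembers the class it was defined in."""

    def __init__(self, tag):
        self.tag = tag

    def __set_name__(self, owner, name):
        self.owner, self.name = owner, name  # the defining class and attribute name

    def __get__(self, obj, objtype=None):
        if obj is None:
            return self

        def bound():
            parent = getattr(super(self.owner, obj), self.name)
            return parent() + [self.tag]

        return bound


class C(Base):
    hello = HelloMethod("C")


class D(C):
    pass


print(D().hello())  # ['Base', 'C']
```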

Find and register subtrees

Registering subtrees manually is inconvenient. Instead, we can write a method that detects subtree attributes for us.

[16]:
class ModuleV3(ModuleV2):
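    def find_and_register_subtree(self):
        # A value belongs to the tree if it is, or contains, an array or a module.
        is_pytree = lambda x: isinstance(x, (np.ndarray, jnp.ndarray, ModuleV3))
        # Iterate over a snapshot, since register_subtree writes to self.__dict__.
        for name, value in list(self.__dict__.items()):
            leaves, _ = jax.tree_util.tree_flatten(value, is_leaf=is_pytree)
            if any(map(is_pytree, leaves)):
                self.register_subtree(name, value)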
[17]:
class Linear(ModuleV3):
    def __init__(self):
        super().__init__()
        self.weight = jnp.array(1.0)
        self.bias = jnp.array(0.0)
        self.count = jnp.array(0)
        self.find_and_register_subtree()

    parameters = parameters_method("weight", "bias")

    def __call__(self, x):
        self.count += 1
        return x * self.weight + self.bias

    def __repr__(self):
        return f"{self.__class__.__name__}(weight={self.weight}, bias={self.bias}, count={self.count})"
[18]:
fc = Linear()
fc.tree_part_names
[18]:
frozenset({'bias', 'count', 'weight'})
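The detection rule can be tried on its own. The hypothetical helper contains_array below applies the same is_leaf technique as find_and_register_subtree, restricted to NumPy and JAX arrays (the module check is left out so that the sketch is self-contained).

```python
import jax
import jax.numpy as jnp
import numpy as np


def contains_array(value):
    # is_leaf stops flattening at matching values; arrays are leaves anyway, but the
    # same trick lets find_and_register_subtree stop at nested modules.
    is_array = lambda x: isinstance(x, (np.ndarray, jnp.ndarray))
    leaves = jax.tree_util.tree_leaves(value, is_leaf=is_array)
    return any(map(is_array, leaves))


print(contains_array(jnp.zeros(3)))                 # True
print(contains_array([1.0, {"w": np.ones(2)}]))     # True: an array is nested inside
print(contains_array(True))                         # False
print(contains_array(frozenset({"w"})))             # False: unregistered type is one leaf
```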

Metaclass

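We can avoid calling self.find_and_register_subtree() explicitly by using a metaclass: its __call__ method runs every time a module is constructed, so it can call find_and_register_subtree() right after __init__ finishes.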

[19]:
class ModuleMetaclass(type):
    def __call__(cls, *args, **kwargs):
        module = cls.__new__(cls, *args, **kwargs)
        cls.__init__(module, *args, **kwargs)
        module.find_and_register_subtree()
        return module
[20]:
class ModuleV4(ModuleV3, metaclass=ModuleMetaclass):
    pass
[21]:
class Linear(ModuleV4):
    def __init__(self):
        super().__init__()
        self.weight = jnp.array(1.0)
        self.bias = jnp.array(0.0)
        self.count = jnp.array(0)

    parameters = parameters_method("weight", "bias")

    def __call__(self, x):
        self.count += 1
        return x * self.weight + self.bias

    def __repr__(self):
        return f"{self.__class__.__name__}(weight={self.weight}, bias={self.bias}, count={self.count})"
[22]:
fc = Linear()
fc.tree_part_names
[22]:
frozenset({'bias', 'count', 'weight'})
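ModuleMetaclass calls __new__ and __init__ by hand. A similar sketch delegates that work to type.__call__ through super().__call__ and then runs a hook; PostInitMeta and after_init are hypothetical names used only for illustration.

```python
class PostInitMeta(type):
    """Hypothetical metaclass that runs a hook after __init__."""

    def __call__(cls, *args, **kwargs):
        obj = super().__call__(*args, **kwargs)  # type.__call__ runs __new__ and __init__
        obj.after_init()
        return obj


class Base(metaclass=PostInitMeta):
    def after_init(self):
        # Runs once, after the subclass __init__ has set all attributes.
        self.fields = sorted(vars(self))


class Child(Base):
    def __init__(self):
        super().__init__()
        self.weight = 1.0
        self.bias = 0.0


print(Child().fields)  # ['bias', 'weight']
```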