Understanding PAX’s module
This tutorial shows how to build a PAX-like module for neural network training from scratch.
Pytree
[1]:
from copy import copy
import jax
import numpy as np
import jax.numpy as jnp
from absl import logging
[2]:
logging.set_verbosity(logging.FATAL)
First, let’s talk about pytree.
Pytrees are tree-like structures that are constructed from Python object containers. Here are a few examples of pytrees:
[3]:
a = 123
b = [1, 2, 3]
d = (1, 2, 3)
c = {"1": 1, "2": 2, "3": 3}
e = [(1, 2), "123", {"1": 1, "2": [4, 5]}]
JAX provides the jax.tree_util.tree_flatten function, which transforms an object into its tree representation. That representation consists of:
leaves: a list of the tree’s leaves.
treedef: information about the structure of the tree.
[4]:
leaves, treedef = jax.tree_util.tree_flatten(e)
print("Leaves:", leaves)
print("TreeDef:", treedef)
Leaves: [1, 2, '123', 1, 4, 5]
TreeDef: PyTreeDef([(*, *), *, {'1': *, '2': [*, *]}])
Note
Even though a pytree can have any object at its leaves, many JAX functions such as jax.jit, jax.lax.scan, jax.grad, etc. only support pytrees with ndarray leaves.
We can reverse the jax.tree_util.tree_flatten transformation with jax.tree_util.tree_unflatten:
[5]:
jax.tree_util.tree_unflatten(treedef=treedef, leaves=leaves)
[5]:
[(1, 2), '123', {'1': 1, '2': [4, 5]}]
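Flattening and unflattening are the building blocks of most pytree utilities. As a minimal sketch (the example tree and the variable names are made up for illustration), the following applies a function to every leaf by hand and compares the result with jax.tree_util.tree_map:

```python
import jax

# A small illustrative pytree with integer leaves.
tree = [(1, 2), {"a": 3, "b": [4, 5]}]

# Manual route: flatten, transform the leaves, then rebuild the structure.
leaves, treedef = jax.tree_util.tree_flatten(tree)
doubled_leaves = [2 * leaf for leaf in leaves]
manual = jax.tree_util.tree_unflatten(treedef, doubled_leaves)

# tree_map performs the same three steps in one call.
with_tree_map = jax.tree_util.tree_map(lambda leaf: 2 * leaf, tree)

print(manual)
print(with_tree_map)
```

Both results have the same structure as the input, with every leaf doubled.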
A simple PAX module
Now let’s try to build a simple PAX module. The core idea here is that:
A module is also a pytree.
To let JAX know how to flatten and unflatten a pytree module:
It needs to implement two methods: tree_flatten and tree_unflatten.
It is registered as a pytree node.
[6]:
@jax.tree_util.register_pytree_node_class
class ModuleV0:
    def __init__(self, mylist):
        self.mylist = mylist
        self.is_training = True

    def tree_flatten(self):
        # Children are traversed by JAX; aux_info is stored in the treedef.
        children = [self.mylist]
        aux_info = {"is_training": self.is_training}
        return children, aux_info

    @classmethod
    def tree_unflatten(cls, aux_info, children):
        # Create the object without calling __init__, then restore its attributes.
        new_object = cls.__new__(cls)
        new_object.mylist = children[0]
        new_object.is_training = aux_info["is_training"]
        return new_object

    def __repr__(self):
        name = self.__class__.__name__
        info = f"mylist={self.mylist}, is_training={self.is_training}"
        return f"{name}({info})"
The function jax.tree_util.register_pytree_node_class registers ModuleV0 as a class of pytree nodes.
Let’s try to flatten and unflatten a module.
[7]:
mod = ModuleV0([1, 2, 3])
print(mod)
leaves, tree_def = jax.tree_util.tree_flatten(mod)
print(leaves, tree_def)
new_mod = jax.tree_util.tree_unflatten(tree_def, leaves)
new_mod
ModuleV0(mylist=[1, 2, 3], is_training=True)
[1, 2, 3] PyTreeDef(CustomNode(<class '__main__.ModuleV0'>[{'is_training': True}], [[*, *, *]]))
[7]:
ModuleV0(mylist=[1, 2, 3], is_training=True)
Note: is_training is considered as part of the PyTreeDef.
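To illustrate what it means for a value to live in the PyTreeDef, here is a small sketch. The class name Flagged and its attributes are hypothetical and only serve this example. Two objects that differ only in their auxiliary data have identical leaves but different tree definitions:

```python
import jax


@jax.tree_util.register_pytree_node_class
class Flagged:
    """Hypothetical container: `values` is a child, `is_training` is aux data."""

    def __init__(self, values, is_training):
        self.values = values
        self.is_training = is_training

    def tree_flatten(self):
        return (self.values,), self.is_training

    @classmethod
    def tree_unflatten(cls, aux_info, children):
        return cls(children[0], aux_info)


leaves_a, treedef_a = jax.tree_util.tree_flatten(Flagged([1, 2], True))
leaves_b, treedef_b = jax.tree_util.tree_flatten(Flagged([1, 2], False))

print(leaves_a == leaves_b)    # the leaves are the same
print(treedef_a == treedef_b)  # the tree definitions differ
```

Because the flag is part of the tree definition, changing it changes the structure that JAX sees rather than the data flowing through it.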
Introducing the register_subtree method
OK, but our pytree module only supports the mylist and is_training attributes. A real module for neural network training can have an arbitrary number of attributes.
Moreover, how can our module know that mylist is part of the subtree while is_training belongs to the tree definition?
One solution is:
Keep a set (namely, tree_part_names) that tells whether an attribute is part of the tree.
Users need to register an attribute if it is part of the tree.
Any attribute that is not registered belongs to the tree definition.
[8]:
@jax.tree_util.register_pytree_node_class
class ModuleV1(ModuleV0):
    def __init__(self):
        self.tree_part_names = frozenset()
        self.is_training = True

    def tree_flatten(self):
        children = []
        others = []
        children_names = []
        # Registered attributes become children; the rest goes to the treedef.
        for name, value in vars(self).items():
            if name in self.tree_part_names:
                children.append(value)
                children_names.append(name)
            else:
                others.append((name, value))
        return children, (children_names, others)

    @classmethod
    def tree_unflatten(cls, aux_info, children):
        children_names, others = aux_info
        new_object = cls.__new__(cls)
        new_object.__dict__.update(others)
        new_object.__dict__.update(zip(children_names, children))
        return new_object

    def register_subtree(self, name, value):
        self.__dict__[name] = value
        # Build a new frozenset instead of mutating a shared one.
        self.tree_part_names = self.tree_part_names.union([name])

    def __init_subclass__(cls):
        # Register every subclass as a pytree node automatically.
        jax.tree_util.register_pytree_node_class(cls)
Our new module has a register_subtree method that adds an attribute’s name to the tree_part_names set.
The tree_flatten method iterates over all attributes of the object and checks whether each name is in tree_part_names. If it is, the value is added to the children list; otherwise, (name, value) is added to the others list.
The tree_unflatten method combines information from others, children_names, and children to reconstruct the module.
Note:
We purposely use frozenset to guarantee that any modification of tree_part_names in one module does not affect other modules. (However, this is not guaranteed for other attributes of the module.)
__init_subclass__ ensures any subclass of ModuleV1 is registered as a pytree node.
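The point about frozenset can be seen without JAX. The sketch below uses two hypothetical classes (WithSet and WithFrozenSet, not part of PAX) and a shallow copy, which shares attribute values in the same way that tree_unflatten does when it copies entries into __dict__:

```python
from copy import copy


class WithSet:
    """Hypothetical module that tracks names in a mutable set."""

    def __init__(self):
        self.names = set()

    def register(self, name):
        self.names.add(name)  # mutates the set in place


class WithFrozenSet:
    """Hypothetical module that tracks names in an immutable frozenset."""

    def __init__(self):
        self.names = frozenset()

    def register(self, name):
        self.names = self.names.union([name])  # rebinds to a new frozenset


a = WithSet()
b = copy(a)       # shallow copy: a.names and b.names are the same object
b.register("w")
print(a.names)    # {'w'}: the original was modified too

c = WithFrozenSet()
d = copy(c)
d.register("w")
print(c.names)    # frozenset(): the original is unaffected
```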
Let’s try our module with a simple counter:
[9]:
class Counter(ModuleV1):
    def __init__(self):
        super().__init__()
        self.register_subtree("count", 0)

    def step(self):
        self.count = self.count + 1

    def __repr__(self):
        return f"{self.__class__.__name__}(count={self.count})"
[10]:
counter = Counter()
print(counter)
counter.step()
print(counter)
leaves, treedef = jax.tree_util.tree_flatten(counter)
print((leaves, treedef))
new_counter = jax.tree_util.tree_unflatten(treedef, leaves)
print(new_counter)
Counter(count=0)
Counter(count=1)
([1], PyTreeDef(CustomNode(<class '__main__.Counter'>[(['count'], [('tree_part_names', frozenset({'count'})), ('is_training', True)])], [*])))
Counter(count=1)
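Since a module is a pytree, it can be passed into and returned from transformed functions. The following is a minimal, self-contained sketch of that idea; TinyCounter and step are illustrative names rather than PAX APIs, and the count is stored as an array so that it is a valid leaf for jax.jit:

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class TinyCounter:
    """Hypothetical stand-in for Counter with a single array leaf."""

    def __init__(self, count):
        self.count = count

    def tree_flatten(self):
        return (self.count,), None

    @classmethod
    def tree_unflatten(cls, aux_info, children):
        return cls(children[0])


@jax.jit
def step(counter):
    # Return a new counter instead of mutating the input in place.
    return TinyCounter(counter.count + 1)


counter = TinyCounter(jnp.array(0))
counter = step(counter)
counter = step(counter)
print(int(counter.count))  # 2
```

The function is written in a functional style: the updated module is returned, because changes made to the input object inside a jitted function are not visible to the caller.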
A custom parameters method
Our module does not have a way to select trainable parameters. We need this feature for gradient computation.
PAX’s solution is to let users implement a parameters() method themselves. For example:
[11]:
class ModuleV2(ModuleV1):
    def replace(self, **kwargs):
        # Work on a shallow copy so the original module is left untouched.
        mod = copy(self)
        for name, value in kwargs.items():
            setattr(mod, name, value)
        return mod

    def parameters(self):
        weights = {}
        # Recurse into submodules; replace every other subtree with None.
        for name in self.tree_part_names:
            value = getattr(self, name)
            value = value.parameters() if isinstance(value, ModuleV2) else None
            weights[name] = value
        return self.replace(**weights)
[12]:
class Linear(ModuleV2):
    def __init__(self):
        super().__init__()
        self.register_subtree("weight", jnp.array(1.0))
        self.register_subtree("bias", jnp.array(0.0))
        self.register_subtree("count", jnp.array(0))

    def parameters(self):
        # Start from the default (everything None) and restore the trainable weights.
        return super().parameters().replace(weight=self.weight, bias=self.bias)

    def __call__(self, x):
        self.count += 1
        return x * self.weight + self.bias

    def __repr__(self):
        return f"{self.__class__.__name__}(weight={self.weight}, bias={self.bias}, count={self.count})"
[13]:
fc = Linear()
x = 2.0
y = fc(x)
print(fc)
print(fc.parameters())
Linear(weight=1.0, bias=0.0, count=1)
Linear(weight=1.0, bias=0.0, count=None)
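Replacing non-trainable subtrees with None works because JAX treats None as a pytree with no leaves. The sketch below uses plain dictionaries as a stand-in for the module (an assumption made to keep the example short; the loss function is also made up) and shows that the None entry contributes no leaf and is carried through jax.grad unchanged:

```python
import jax
import jax.numpy as jnp

# Dictionary stand-ins for a module and for the result of parameters().
full = {"weight": jnp.array(1.0), "bias": jnp.array(0.0), "count": jnp.array(0)}
params = {**full, "count": None}

print(len(jax.tree_util.tree_leaves(full)))    # 3
print(len(jax.tree_util.tree_leaves(params)))  # 2


def loss(p, x):
    # Illustrative loss: squared output of a scalar linear map.
    return (x * p["weight"] + p["bias"]) ** 2


# Gradients are computed only for the remaining leaves.
grads = jax.grad(loss)(params, 2.0)
print(grads)
```

Masking the integer counter this way also matters for differentiation, since jax.grad expects floating-point (or complex) inputs.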
However, it is a bit inconvenient that we have to implement a parameters method ourselves. Below is a utility function that does the job for us.
[14]:
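def parameters_method(*trainable_weights):
    def _parameters(self):
        values = {name: getattr(self, name) for name in trainable_weights}
        # Note: super(self.__class__, self) is resolved from the instance's class,
        # so this helper assumes the class that uses it is not subclassed further.
        return super(self.__class__, self).parameters().replace(**values)

    return _parameters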
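One caveat of this helper is worth knowing about: super(self.__class__, self) looks up the parent relative to the class of the instance, not the class in which the method was defined. The following pure-Python sketch, with hypothetical classes Base, Child, and GrandChild, shows that the pattern works for a direct user of the helper but recurses forever once that class is subclassed again:

```python
class Base:
    def parameters(self):
        return "base"


def make_parameters():
    """Hypothetical helper that mimics the super(self.__class__, self) pattern."""

    def _parameters(self):
        return super(self.__class__, self).parameters()

    return _parameters


class Child(Base):
    parameters = make_parameters()


class GrandChild(Child):
    pass


print(Child().parameters())  # base

try:
    GrandChild().parameters()
    recursed = False
except RecursionError:
    # super(GrandChild, self).parameters resolves to the same _parameters again.
    recursed = True
print(recursed)  # True
```

For the single-level class hierarchies used in this tutorial, the helper behaves as intended.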
[15]:
class Linear(ModuleV2):
    def __init__(self):
        super().__init__()
        self.register_subtree("weight", jnp.array(1.0))
        self.register_subtree("bias", jnp.array(0.0))
        self.register_subtree("count", jnp.array(0))

    parameters = parameters_method("weight", "bias")

    def __call__(self, x):
        self.count += 1
        return x * self.weight + self.bias

    def __repr__(self):
        return f"{self.__class__.__name__}(weight={self.weight}, bias={self.bias}, count={self.count})"
Find and register subtrees
It is inconvenient that we have to register subtrees manually. Instead, we can add a method that detects subtree attributes for us.
[16]:
class ModuleV3(ModuleV2):
    def find_and_register_subtree(self):
        def is_pytree(x):
            # Arrays and modules are the values we want to treat as subtrees.
            return isinstance(x, (np.ndarray, jnp.ndarray, ModuleV3))

        # Iterate over a snapshot because register_subtree writes to __dict__.
        for name, value in list(self.__dict__.items()):
            leaves, _ = jax.tree_util.tree_flatten(value, is_leaf=is_pytree)
            if any(map(is_pytree, leaves)):
                self.register_subtree(name, value)
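The detection logic relies on the is_leaf argument, which stops the traversal at any value for which the predicate returns True. Here is a standalone sketch of the same check; contains_array is a hypothetical helper name, and for brevity it only looks for arrays rather than modules:

```python
import jax
import jax.numpy as jnp
import numpy as np


def contains_array(value):
    """Return True if `value` is, or contains, a NumPy or JAX array."""

    def is_array(x):
        return isinstance(x, (np.ndarray, jnp.ndarray))

    # is_leaf makes the traversal stop at arrays instead of looking inside them.
    leaves = jax.tree_util.tree_leaves(value, is_leaf=is_array)
    return any(map(is_array, leaves))


print(contains_array(jnp.zeros(2)))                    # True
print(contains_array([1, {"w": (np.ones(3), "tag")}]))  # True
print(contains_array("name"))                          # False
print(contains_array(True))                            # False
```

Attributes such as is_training or a name string therefore stay in the tree definition, while arrays nested inside lists, tuples, or dictionaries are picked up.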
[17]:
class Linear(ModuleV3):
    def __init__(self):
        super().__init__()
        self.weight = jnp.array(1.0)
        self.bias = jnp.array(0.0)
        self.count = jnp.array(0)
        self.find_and_register_subtree()

    parameters = parameters_method("weight", "bias")

    def __call__(self, x):
        self.count += 1
        return x * self.weight + self.bias

    def __repr__(self):
        return f"{self.__class__.__name__}(weight={self.weight}, bias={self.bias}, count={self.count})"
[18]:
fc = Linear()
fc.tree_part_names
[18]:
frozenset({'bias', 'count', 'weight'})
Metaclass
We can get rid of calling self.find_and_register_subtree() explicitly by using a metaclass.
[19]:
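class ModuleMetaclass(type):
    def __call__(cls, *args, **kwargs):
        # Runs whenever the class is called to create an instance.
        module = cls.__new__(cls, *args, **kwargs)
        cls.__init__(module, *args, **kwargs)
        # After __init__ has finished, register the subtrees automatically.
        module.find_and_register_subtree()
        return module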
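The mechanism can be shown in isolation, without JAX. In the sketch below, PostInitMeta, Base, Point, and post_init are hypothetical names; unlike the cell above, it delegates to super().__call__ (the default creation logic of type) instead of calling __new__ and __init__ by hand, and then runs a hook once the object is fully initialized:

```python
class PostInitMeta(type):
    """Hypothetical metaclass that runs a hook after __init__ completes."""

    def __call__(cls, *args, **kwargs):
        # type.__call__ runs __new__ and then __init__ as usual.
        obj = super().__call__(*args, **kwargs)
        obj.post_init()
        return obj


class Base(metaclass=PostInitMeta):
    def post_init(self):
        # Record the attributes that existed when __init__ finished.
        self.attribute_names = sorted(vars(self))


class Point(Base):
    def __init__(self, x, y):
        self.x = x
        self.y = y


point = Point(1, 2)
print(point.attribute_names)  # ['x', 'y']
```

Subclasses such as Point do not need to call the hook themselves, which is the same convenience the metaclass gives to modules here.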
[20]:
class ModuleV4(ModuleV3, metaclass=ModuleMetaclass):
    pass
[21]:
class Linear(ModuleV4):
    def __init__(self):
        super().__init__()
        self.weight = jnp.array(1.0)
        self.bias = jnp.array(0.0)
        self.count = jnp.array(0)

    parameters = parameters_method("weight", "bias")

    def __call__(self, x):
        self.count += 1
        return x * self.weight + self.bias

    def __repr__(self):
        return f"{self.__class__.__name__}(weight={self.weight}, bias={self.bias}, count={self.count})"
[22]:
fc = Linear()
fc.tree_part_names
[22]:
frozenset({'bias', 'count', 'weight'})