PAX Basics

Module(*[, training, name])

The base class for all PAX modules.

EmptyNode()

Mark an uninitialized or deleted pytree node.

pure(func)

Make a function pure by copying the inputs.

purecall(module, *args, **kwargs)

Call a module and return the updated module.

seed_rng_key(seed)

Seed the global random state (internally, set self._rng = random.Random(seed)).

next_rng_key()

Return a random rng key.

PAX’s Module

class pax.Module(*, training=True, name=None)

The base class for all PAX modules.

Example:

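>>> import jax.numpy as jnp
>>> import pax
>>> class Counter(pax.Module):
...     def __init__(self):
...         super().__init__()
...         self.count = jnp.array(0)
...
...     def __call__(self, x):
...         self.count += 1
...         return x
...
>>> counter = Counter()
>>> counter, y = pax.purecall(counter, jnp.array(1.0))
>>> int(counter.count)
1

__init__(*, training=True, name=None)

Initialize the module.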

property training

Return True if the module is in training mode.

>>> net = pax.Linear(1, 1)
>>> net.training
True
>>> net = net.eval()
>>> net.training
False

Return type

bool

parameters()

Return a new module with trainable weights only.

train()

Return a new module in training mode.

Return type

~T

eval()

Return a new module in evaluation mode.

Return type

~T
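Because train() and eval() return a module instead of switching a flag in place, the training example reassigns net = net.eval(). The stdlib-only sketch below imitates that copy-on-change pattern with a frozen dataclass. TinyModule is hypothetical and only illustrates the behaviour; it is not how PAX implements modules.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class TinyModule:
    """Hypothetical stand-in for a PAX module: an immutable record with a mode flag."""

    weight: float = 1.0
    training: bool = True

    def train(self) -> "TinyModule":
        # Return a new object in training mode; `self` is left unchanged.
        return replace(self, training=True)

    def eval(self) -> "TinyModule":
        # Return a new object in evaluation mode; `self` is left unchanged.
        return replace(self, training=False)


net = TinyModule()
eval_net = net.eval()
print(net.training, eval_net.training)  # True False
```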

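update_parameters(params)

Return a new module whose trainable parameters are replaced by those in params (typically a module obtained from parameters() and then updated, e.g. by an optimizer).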

Return type

~T
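parameters() and update_parameters() pair naturally in a training step: take the trainable weights, update them, and write them back into a new module. The sketch below models that round trip with plain dictionaries. The EMPTY placeholder (loosely analogous to pax.EmptyNode), both helper functions, and the rule that blank entries keep the old value are assumptions for illustration, not PAX's implementation.

```python
from typing import Any, Dict, FrozenSet

EMPTY = object()  # hypothetical placeholder, loosely analogous to pax.EmptyNode


def parameters(module: Dict[str, Any], trainable: FrozenSet[str]) -> Dict[str, Any]:
    # Keep the trainable weights and blank out everything else.
    return {k: (v if k in trainable else EMPTY) for k, v in module.items()}


def update_parameters(module: Dict[str, Any], params: Dict[str, Any]) -> Dict[str, Any]:
    # Build a new dict: non-empty entries of `params` overwrite the old values.
    return {
        k: (v if params.get(k, EMPTY) is EMPTY else params[k])
        for k, v in module.items()
    }


net = {"weight": 1.0, "bias": 0.0, "running_mean": 3.0}
trainable = frozenset({"weight", "bias"})

params = parameters(net, trainable)
# A toy "optimizer step" that touches the trainable entries only.
params = {k: (v if v is EMPTY else v - 0.5) for k, v in params.items()}
net = update_parameters(net, params)
print(net)  # {'weight': 0.5, 'bias': -0.5, 'running_mean': 3.0}
```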

replace(**kwargs)

Return a new module with some attributes replaced.

>>> net = pax.Linear(2, 2)
>>> net = net.replace(bias=jnp.zeros((2,)))

Return type

~T
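replace() returns a new module rather than mutating the existing one. A minimal way to get that behaviour in plain Python is a shallow copy followed by attribute assignment on the copy, sketched below. ToyLinear and this replace helper are hypothetical; whether PAX copies shallowly in exactly this way is an assumption.

```python
import copy
from typing import Any, TypeVar

T = TypeVar("T")


class ToyLinear:
    """Hypothetical toy layer with plain-list weights."""

    def __init__(self, weight: list, bias: list) -> None:
        self.weight = weight
        self.bias = bias


def replace(module: T, **kwargs: Any) -> T:
    # Shallow-copy the module, then overwrite the named attributes on the copy only.
    new = copy.copy(module)
    for name, value in kwargs.items():
        setattr(new, name, value)
    return new


net = ToyLinear(weight=[[1.0, 0.0], [0.0, 1.0]], bias=[1.0, 1.0])
net2 = replace(net, bias=[0.0, 0.0])
print(net.bias, net2.bias)  # [1.0, 1.0] [0.0, 0.0]
print(net2.weight is net.weight)  # True: untouched attributes are shared, not copied
```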

replace_node(node, value)

Replace a node of the pytree by a new value.

Example:

>>> mod = pax.Sequential(
...     pax.Linear(2,2),
...     jax.nn.relu
... )
>>> mod = mod.replace_node(mod[0].weight, jnp.zeros((2, 3)))
>>> print(mod[0].weight.shape)
(2, 3)

Return type

~T
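The example locates the node to replace by passing the node itself (mod[0].weight). The sketch below rebuilds a nested list/dict tree and swaps in the new value wherever a leaf is that exact object, mirroring the example's shapes. Array and replace_node here are hypothetical, and matching by identity (is) rather than by equality is an assumption of this sketch.

```python
from typing import Any


class Array:
    """Hypothetical stand-in for an array leaf; only its shape matters here."""

    def __init__(self, shape: tuple) -> None:
        self.shape = shape


def replace_node(tree: Any, node: Any, value: Any) -> Any:
    # Rebuild the containers, swapping in `value` wherever the leaf *is* `node`.
    if tree is node:
        return value
    if isinstance(tree, dict):
        return {k: replace_node(v, node, value) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(replace_node(v, node, value) for v in tree)
    return tree


mod = [{"weight": Array((2, 2)), "bias": Array((2,))}, "relu"]
new_mod = replace_node(mod, mod[0]["weight"], Array((2, 3)))
print(new_mod[0]["weight"].shape, mod[0]["weight"].shape)  # (2, 3) (2, 2)
```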

summary(return_list=False)

Summarize a module as a tree of its submodules.

Parameters

return_list (bool) – return a list of lines instead of a joined string.

>>> net = pax.Sequential(pax.Linear(2, 3), jax.nn.relu, pax.Linear(3, 4))
>>> print(net.summary())
Sequential
├── Linear(in_dim=2, out_dim=3, with_bias=True)
├── x => relu(x)
└── Linear(in_dim=3, out_dim=4, with_bias=True)

Return type

Union[str, List[str]]
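The tree drawing in the summary output can be produced by a short recursion: each child's first line gets a branch (├── or └──) and its deeper lines get a guide rail. The sketch below reproduces the example's output from a hypothetical (label, children) tuple tree; it is not PAX's own code.

```python
from typing import List, Tuple, Union

# Hypothetical module tree: (label, list of child nodes).
Node = Tuple[str, list]


def summary(node: Node, return_list: bool = False) -> Union[str, List[str]]:
    label, children = node
    lines = [label]
    for i, child in enumerate(children):
        last = i == len(children) - 1
        sub = summary(child, return_list=True)
        # First line of the child gets a branch; deeper lines get a guide rail.
        lines.append(("└── " if last else "├── ") + sub[0])
        lines.extend(("    " if last else "│   ") + line for line in sub[1:])
    return lines if return_list else "\n".join(lines)


net = ("Sequential", [
    ("Linear(in_dim=2, out_dim=3, with_bias=True)", []),
    ("x => relu(x)", []),
    ("Linear(in_dim=3, out_dim=4, with_bias=True)", []),
])
print(summary(net))
```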

apply(apply_fn)

Apply a function to all submodules and then to the module itself (children first, as the printed order in the example shows).

>>> def print_param_count(mod):
...     count = sum(jax.tree_util.tree_leaves(jax.tree_util.tree_map(jnp.size, mod)))
...     print(f"{count} {mod}")
...     return mod
...
>>> net = pax.Sequential(pax.Linear(1, 1), jax.nn.relu)
>>> net = net.apply(print_param_count)
2 Linear(in_dim=1, out_dim=1, with_bias=True)
0 Lambda(relu)
2 Sequential

Parameters

apply_fn – a function that takes a module and returns a transformed module.

Return type

~T
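The printed order in the apply example (the Linear and Lambda submodules before Sequential) is what a post-order traversal produces. The sketch below implements that traversal on a hypothetical Mod class: children are transformed first, then the rebuilt parent is passed to apply_fn. It illustrates the ordering only and is not PAX's implementation.

```python
from typing import Callable, List, Optional


class Mod:
    """Hypothetical module: a name plus a list of submodules."""

    def __init__(self, name: str, children: Optional[List["Mod"]] = None) -> None:
        self.name = name
        self.children = list(children or [])


def apply(mod: Mod, apply_fn: Callable[[Mod], Mod]) -> Mod:
    # Transform the submodules first, then pass the rebuilt parent to apply_fn.
    new = Mod(mod.name, [apply(child, apply_fn) for child in mod.children])
    return apply_fn(new)


visited: List[str] = []


def record(m: Mod) -> Mod:
    visited.append(m.name)
    return m


net = apply(Mod("Sequential", [Mod("Linear"), Mod("Lambda(relu)")]), record)
print(visited)  # ['Linear', 'Lambda(relu)', 'Sequential']
```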

state_dict()

Return the module’s state dictionary.

Return type

Dict[str, Any]

load_state_dict(state_dict)

Return a new module with its values loaded from the state dictionary.

Return type

~T
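state_dict() and load_state_dict() form a save/restore round trip. The documentation only states that the state dictionary is a Dict[str, Any]; the flat, dot-separated key layout used in the sketch below is an assumption chosen for illustration, and both helpers operate on hypothetical nested dicts rather than PAX modules.

```python
from typing import Any, Dict


def state_dict(tree: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
    # Flatten nested dicts into one mapping with dotted keys.
    out: Dict[str, Any] = {}
    for key, value in tree.items():
        name = prefix + key
        if isinstance(value, dict):
            out.update(state_dict(value, name + "."))
        else:
            out[name] = value
    return out


def load_state_dict(tree: Dict[str, Any], sd: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
    # Return a new tree with the same structure, with leaves read from `sd`.
    new: Dict[str, Any] = {}
    for key, value in tree.items():
        name = prefix + key
        new[key] = load_state_dict(value, sd, name + ".") if isinstance(value, dict) else sd[name]
    return new


net = {"fc1": {"weight": 1.0, "bias": 0.0}, "fc2": {"weight": 2.0}}
sd = state_dict(net)
print(sd)  # {'fc1.weight': 1.0, 'fc1.bias': 0.0, 'fc2.weight': 2.0}

fresh = {"fc1": {"weight": 0.0, "bias": 0.0}, "fc2": {"weight": 0.0}}
restored = load_state_dict(fresh, sd)
print(restored == net)  # True
```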

__mod__(args)

Call the module on args and return a tuple (updated module, output); an alternative to pax.module_and_value.

>>> bn = pax.BatchNorm1D(3)
>>> x = jnp.ones((5, 8, 3))
>>> bn, y = bn % x
>>> bn
BatchNorm1D(num_channels=3, ...)

Return type

Tuple[~T, Any]
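The % operator packs "call the module, and also hand back the updated module" into one expression, as in bn, y = bn % x. The sketch below gives a hypothetical RunningCount class a __mod__ that calls a deep copy of itself and returns (updated copy, output), so the original object is untouched. The copy-then-call strategy is an assumption about how such an operator can work, not PAX's exact code.

```python
import copy
from typing import Any, Tuple


class RunningCount:
    """Hypothetical stateful module: counts how many inputs it has seen."""

    def __init__(self) -> None:
        self.count = 0

    def __call__(self, x: Any) -> Any:
        self.count += 1
        return x

    def __mod__(self, x: Any) -> Tuple["RunningCount", Any]:
        # Call a copy so `self` is untouched; return (updated module, output).
        new = copy.deepcopy(self)
        y = new(x)
        return new, y


m0 = RunningCount()
m1, y = m0 % 5
print(m0.count, m1.count, y)  # 0 1 5
```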

__or__(other)

Merge two modules. The operation is not commutative: a | b and b | a can differ, as the example shows.

>>> a = pax.Linear(2, 2)
>>> b = pax.Linear(2, 2)
>>> c = a | b
>>> d = b | a
>>> c == d
False

Return type

~T

class pax.ParameterModule(*, training=True, name=None)

A PAX module that registers attributes as parameters by default.

parameters()

Return a new module with trainable weights only.

class pax.StateModule(*, training=True, name=None)

A PAX module that registers attributes as states by default.

class pax.EmptyNode

Mark an uninitialized or deleted pytree node.

tree_flatten()

Flatten an empty node.

classmethod tree_unflatten(aux, children)

Unflatten an empty node.

__eq__(o)

Return self == o.

Return type

bool
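An empty node is a pytree node with no children, so tree utilities see zero leaves in it. The sketch below shows the tree_flatten / tree_unflatten contract for such a placeholder. In JAX, a class like this would be registered with jax.tree_util.register_pytree_node_class; the decorator is omitted so the sketch runs without JAX. EmptyNodeSketch is hypothetical, and treating all empty nodes as equal is an assumption.

```python
from typing import Any, Tuple


class EmptyNodeSketch:
    """Hypothetical placeholder leaf that carries no data."""

    def tree_flatten(self) -> Tuple[tuple, None]:
        # No children and no auxiliary data: tree utilities see zero leaves.
        return (), None

    @classmethod
    def tree_unflatten(cls, aux: Any, children: tuple) -> "EmptyNodeSketch":
        # Nothing to restore; just build a fresh placeholder.
        return cls()

    def __eq__(self, o: object) -> bool:
        # Assumption: every empty node compares equal to every other.
        return isinstance(o, EmptyNodeSketch)

    def __hash__(self) -> int:
        return hash(EmptyNodeSketch)


e = EmptyNodeSketch()
children, aux = e.tree_flatten()
restored = EmptyNodeSketch.tree_unflatten(aux, children)
print(children, restored == e)  # () True
```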

Purify functions and methods

pax.pure(func)

Make a function pure by copying its inputs.

Any modification of the copies does not affect the original inputs.

Note: only functions wrapped by pax.pure are allowed to modify PAX modules.

Example:

>>> f = pax.Linear(3,3)
>>> f.a_list = []
Traceback (most recent call last):
  ...
ValueError: Cannot modify a module in immutable mode.
Please do this computation inside a function decorated by `pax.pure`.
>>>
>>> @pax.pure
... def add_list(m):
...     m.a_list = []
...     return m
...
>>> f = add_list(f)
>>> print(f.a_list)
[]

Parameters

func (Callable) – A function.

Returns

A pure function.
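The core idea of pax.pure, calling the wrapped function on copies of its inputs so the caller's objects never change, can be sketched with copy.deepcopy. This sketch does not model PAX's immutable mode (the ValueError raised outside pax.pure); the pure decorator and the Box class here are hypothetical.

```python
import copy
import functools
from typing import Any, Callable


def pure(func: Callable[..., Any]) -> Callable[..., Any]:
    """Sketch: call `func` on deep copies of its arguments."""

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        args, kwargs = copy.deepcopy((args, kwargs))
        return func(*args, **kwargs)

    return wrapper


class Box:
    def __init__(self) -> None:
        self.items = []


@pure
def add_item(box: Box) -> Box:
    box.items.append(1)  # mutates the copy, never the caller's object
    return box


original = Box()
updated = add_item(original)
print(original.items, updated.items)  # [] [1]
```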

pax.purecall(module, *args, **kwargs)

Call a module and return a tuple (updated module, output).

A shortcut for pax.pure(lambda f, x: [f, f(x)]), generalized to arbitrary positional and keyword arguments.

Return type

Tuple[Any, ~O]
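Because purecall returns both the updated module and the output, stateful modules are threaded through successive calls by reassignment. The sketch below approximates purecall with a deep copy of the module and its arguments; purecall here and the Counter class are hypothetical stand-ins, not PAX objects.

```python
import copy
from typing import Any, Tuple


def purecall(module: Any, *args: Any, **kwargs: Any) -> Tuple[Any, Any]:
    # Work on deep copies so the caller's module is never modified.
    module, args, kwargs = copy.deepcopy((module, args, kwargs))
    output = module(*args, **kwargs)
    return module, output


class Counter:
    def __init__(self) -> None:
        self.count = 0

    def __call__(self, x: int) -> int:
        self.count += 1
        return x * self.count


c0 = Counter()
c1, y1 = purecall(c0, 10)
c2, y2 = purecall(c1, 10)
print(c0.count, c1.count, c2.count, y1, y2)  # 0 1 2 10 20
```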

Random Number Generator

seed_rng_key

pax.seed_rng_key(seed)

Seed the global random state (internally, set self._rng = random.Random(seed)).

Parameters

seed (int) – an integer seed.

Return type

None

next_rng_key

pax.next_rng_key()

Return a new random PRNG key and advance (renew) the global random state.

Return type

Union[Any, ndarray]
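seed_rng_key and next_rng_key together give reproducible key sequences: reseeding the global random.Random state replays the same draws. The sketch below keeps a module-level _rng as the descriptions above suggest, but returns plain integers instead of JAX keys; the default seed, the integer range, and converting the integer with jax.random.PRNGKey are assumptions for illustration.

```python
import random

# Hypothetical global state mirroring the internal `_rng` described above.
_rng = random.Random(42)  # the default seed here is an arbitrary choice


def seed_rng_key(seed: int) -> None:
    # Replace the global generator so later draws are reproducible.
    global _rng
    _rng = random.Random(seed)


def next_rng_key() -> int:
    # Draw a fresh integer seed; each call advances the global generator.
    # With JAX installed, this could become a key via jax.random.PRNGKey(seed).
    return _rng.randint(0, 2**31 - 1)


seed_rng_key(0)
first = [next_rng_key() for _ in range(3)]
seed_rng_key(0)
again = [next_rng_key() for _ in range(3)]
print(first == again)  # True
```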