PAX Basics

Module([name])

The base class for all PAX modules.

pure(func)

Make a function pure by copying the inputs.

module_and_value(module_or_method)

Return a pure function that executes a module’s method.

seed_rng_key(seed)

Seed the global random state, i.e. set _rng = random.Random(seed).

next_rng_key()

Return a random rng key.

PAX’s Module

class pax.Module(name=None)[source]

The base class for all PAX modules.

Example:

>>> class Counter(pax.Module):
...     def __init__(self):
...         super().__init__()
...         self.count = jnp.array(0)
...
...     def step(self, x):
...         self.count += 1
...         return x

__init__(name=None)[source]

Initialize the module. The optional name is shown when the module is printed:

>>> linear = pax.Linear(3, 3, name="input_layer")
>>> print(linear)
(input_layer) Linear(in_dim=3, out_dim=3, with_bias=True)

property training

Return True if a module is in training mode.

>>> net = pax.Linear(1, 1)
>>> net.training
True
>>> net = net.eval()
>>> net.training
False

Return type

bool

parameters()[source]

Return a new module with trainable weights only.

train()[source]

Return a module in training mode.

Return type

~T

eval()[source]

Return a module in evaluation mode.

Return type

~T
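PAX modules are used functionally here: eval() hands back a module instead of flipping a flag in place, which is why the training example reassigns net = net.eval(). The sketch below models that pattern with a frozen dataclass. ToyModule is a hypothetical stand-in, not PAX's implementation.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyModule:
    """Hypothetical stand-in for a module; frozen, so fields cannot be reassigned."""

    name: str
    training: bool = True

    def train(self) -> "ToyModule":
        # Return a copy with the flag set; the original object is untouched.
        return replace(self, training=True)

    def eval(self) -> "ToyModule":
        return replace(self, training=False)


net = ToyModule("linear")
eval_net = net.eval()
print(net.training, eval_net.training)  # True False
```

Because the original is never mutated, a training-mode and an evaluation-mode version of the same module can coexist safely.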

update_parameters(params)[source]

Return a new module with updated parameters.

Return type

~T
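A typical reason to pair parameters() with update_parameters() is a training step: extract the trainable weights, update them, and write them back while non-trainable state is kept. The sketch below models this with a hypothetical ToyLinear dataclass whose steps field plays the role of a state; the field names and the hand-written gradient values are assumptions for illustration only.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass(frozen=True)
class ToyLinear:
    """Hypothetical module: weight and bias are trainable, steps is a state."""

    weight: Optional[float]
    bias: Optional[float]
    steps: Optional[int]

    def parameters(self) -> "ToyLinear":
        # Keep the trainable weights only; the state is set to None.
        return replace(self, steps=None)

    def update_parameters(self, params: "ToyLinear") -> "ToyLinear":
        # Take the weights from params and keep this module's own state.
        return replace(self, weight=params.weight, bias=params.bias)


net = ToyLinear(weight=0.5, bias=0.0, steps=10)
params = net.parameters()

# One SGD step with made-up gradients, applied to the parameters only.
lr, grad_w, grad_b = 0.1, 2.0, -1.0
new_params = replace(
    params,
    weight=params.weight - lr * grad_w,
    bias=params.bias - lr * grad_b,
)
net = net.update_parameters(new_params)
print(net)  # weight and bias changed, steps still 10
```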

replace(**kwargs)[source]

Return a new module with some attributes replaced.

>>> net = pax.Linear(2, 2)
>>> net = net.replace(bias=jnp.zeros((2,)))

Return type

~T

replace_node(node, value)[source]

Replace a node of the module's pytree with a new value.

Example:

>>> mod = pax.Sequential(
...     pax.Linear(2,2),
...     jax.nn.relu
... )
>>> mod = mod.replace_node(mod[0].weight, jnp.zeros((2, 3)))
>>> print(mod[0].weight.shape)
(2, 3)

Return type

~T
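The idea behind replace_node can be sketched on plain nested containers. The helper below is hypothetical, not PAX's code, and it assumes nodes are matched by identity (is) rather than by equality, so a leaf that merely has the same values is left alone. Like the example above, it does not check that the new value has the old shape.

```python
from typing import Any


def replace_node(tree: Any, node: Any, value: Any) -> Any:
    """Return a copy of a nested dict/list/tuple tree with `node` swapped for `value`.

    Matching uses identity (`is`), so only that exact object is replaced.
    """
    if tree is node:
        return value
    if isinstance(tree, dict):
        return {k: replace_node(v, node, value) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        # Rebuild the container; other leaves are carried over unchanged.
        return type(tree)(replace_node(v, node, value) for v in tree)
    return tree


weight = [[1.0, 1.0], [1.0, 1.0]]
other = [[1.0, 1.0], [1.0, 1.0]]  # equal to weight, but a different object
model = {"layers": [{"weight": weight}, {"weight": other}]}

new_model = replace_node(model, weight, [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
print(new_model["layers"][0]["weight"])  # [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
print(new_model["layers"][1]["weight"] == other)  # True: not replaced
print(model["layers"][0]["weight"] is weight)  # True: the input is untouched
```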

summary(return_list=False)[source]

Summarize a module as a tree of its submodules.

Parameters

return_list (bool) – return a list of lines instead of a joined string.

>>> net = pax.Sequential(pax.Linear(2, 3), jax.nn.relu, pax.Linear(3, 4))
>>> print(net.summary())
Sequential
├── Linear(in_dim=2, out_dim=3, with_bias=True)
├── x => relu(x)
└── Linear(in_dim=3, out_dim=4, with_bias=True)

Return type

Union[str, List[str]]
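The tree layout printed by summary() can be reproduced with a short recursive function. This hypothetical sketch represents each module as a (label, children) pair instead of a PAX module, and only mirrors the output format shown above, including the return_list switch.

```python
from typing import List, Tuple, Union

# Hypothetical tree node: (label, list of child nodes).
Node = Tuple[str, list]


def summary(node: Node, return_list: bool = False) -> Union[str, List[str]]:
    label, children = node
    lines = [label]
    for i, child in enumerate(children):
        last = i == len(children) - 1
        child_lines = summary(child, return_list=True)
        # The child's first line gets a branch marker...
        lines.append(("└── " if last else "├── ") + child_lines[0])
        # ...and its deeper lines are indented under that branch.
        prefix = "    " if last else "│   "
        lines.extend(prefix + line for line in child_lines[1:])
    return lines if return_list else "\n".join(lines)


net = ("Sequential", [
    ("Linear(in_dim=2, out_dim=3, with_bias=True)", []),
    ("x => relu(x)", []),
    ("Linear(in_dim=3, out_dim=4, with_bias=True)", []),
])
print(summary(net))
```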

apply(apply_fn)[source]

Apply a function to every submodule, bottom-up, and finally to the module itself.

>>> def print_param_count(mod):
...     count = sum(jax.tree_util.tree_leaves(jax.tree_util.tree_map(jnp.size, mod)))
...     print(f"{count} {mod}")
...     return mod
...
>>> net = pax.Sequential(pax.Linear(1, 1), jax.nn.relu)
>>> net = net.apply(print_param_count)
2 Linear(in_dim=1, out_dim=1, with_bias=True)
0 Lambda(relu)
2 Sequential

Parameters
  • apply_fn – a function that takes a module and returns a transformed module.

  • check_treedef – check treedef before applying the function.

Return type

~T
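The print order in the apply() example (both layers before Sequential) is what a post-order traversal produces: children are transformed first, then the parent. A minimal sketch of that traversal, assuming a hypothetical ToyModule that stores its submodules in a children list:

```python
from dataclasses import dataclass, field, replace
from typing import Callable, List


@dataclass(frozen=True)
class ToyModule:
    """Hypothetical module with a name and a list of submodules."""

    name: str
    children: List["ToyModule"] = field(default_factory=list)

    def apply(self, apply_fn: Callable[["ToyModule"], "ToyModule"]) -> "ToyModule":
        # Transform the children first (post-order), then this module.
        new_children = [child.apply(apply_fn) for child in self.children]
        return apply_fn(replace(self, children=new_children))


visited = []


def record(mod: ToyModule) -> ToyModule:
    visited.append(mod.name)
    return mod


net = ToyModule("Sequential", [ToyModule("Linear"), ToyModule("Lambda(relu)")])
net = net.apply(record)
print(visited)  # ['Linear', 'Lambda(relu)', 'Sequential']
```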

__mod__(args)[source]

An alternative to pax.module_and_value: module % x is shorthand for pax.module_and_value(module)(x).

>>> bn = pax.BatchNorm1D(3)
>>> x = jnp.ones((5, 8, 3))
>>> bn, y = bn % x
>>> bn
BatchNorm1D(num_channels=3, ...)

Return type

Tuple[~T, Any]
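The % shorthand is ordinary Python operator overloading. The sketch below, using a hypothetical ToyCounter class, shows the pattern: __mod__ runs the call on a copy and returns (updated copy, output), leaving the original object unchanged. It illustrates the idea only and is not PAX's implementation.

```python
import copy
from typing import Any, Tuple


class ToyCounter:
    """Hypothetical stateful callable that counts how often it is called."""

    def __init__(self) -> None:
        self.count = 0

    def __call__(self, x: Any) -> Any:
        self.count += 1
        return x

    def __mod__(self, x: Any) -> Tuple["ToyCounter", Any]:
        # Call a copy so `self` keeps its old state, then return both the
        # updated copy and the output.
        new = copy.deepcopy(self)
        y = new(x)
        return new, y


counter = ToyCounter()
new_counter, y = counter % 5
print(counter.count, new_counter.count, y)  # 0 1 5
```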

__or__(other)[source]

Merge two modules. The merge is not commutative: a | b and b | a generally differ, as the example shows.

>>> a = pax.Linear(2, 2)
>>> b = pax.Linear(2, 2)
>>> c = a | b
>>> d = b | a
>>> c == d
False

Return type

~T
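One way to see why a merge is order-dependent: when both operands hold a value for the same leaf, one side has to win. The sketch below uses plain dicts and assumes the right operand wins unless its leaf is None; that precedence rule and the merge helper are assumptions for illustration, not taken from PAX.

```python
from typing import Dict, Optional

Leaves = Dict[str, Optional[float]]


def merge(a: Leaves, b: Leaves) -> Leaves:
    """Hypothetical merge: take b's leaf unless it is None, else fall back to a's."""
    assert a.keys() == b.keys(), "both sides must have the same structure"
    return {k: (a[k] if b[k] is None else b[k]) for k in a}


a = {"weight": 1.0, "bias": 2.0}
b = {"weight": 3.0, "bias": None}
print(merge(a, b))  # {'weight': 3.0, 'bias': 2.0}
print(merge(b, a))  # {'weight': 1.0, 'bias': 2.0}
```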

class pax.ParameterModule(name=None)[source]

A PAX module that registers attributes as parameters by default.

parameters()[source]

Return a new module with trainable weights only.

class pax.StateModule(name=None)[source]

A PAX module that registers attributes as states by default.
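The two classes differ only in the default kind given to new attributes. The sketch below models that with a hypothetical ToyModule base class that records a kind per attribute in __setattr__; the class names, the _kinds dict, and the dict returned by parameters() are assumptions, and PAX's real bookkeeping may differ.

```python
from typing import Any, Dict


class ToyModule:
    """Hypothetical base class that tags every attribute with a kind."""

    default_kind: str  # set by subclasses: "parameter" or "state"

    def __init__(self) -> None:
        # Bypass our own __setattr__ for the bookkeeping dict itself.
        object.__setattr__(self, "_kinds", {})

    def __setattr__(self, name: str, value: Any) -> None:
        # A new attribute gets the class's default kind.
        self._kinds.setdefault(name, self.default_kind)
        object.__setattr__(self, name, value)

    def parameters(self) -> Dict[str, Any]:
        # Keep only the attributes tagged as parameters.
        return {k: getattr(self, k) for k, kind in self._kinds.items() if kind == "parameter"}


class ToyParameterModule(ToyModule):
    default_kind = "parameter"


class ToyStateModule(ToyModule):
    default_kind = "state"


class Affine(ToyParameterModule):
    def __init__(self) -> None:
        super().__init__()
        self.weight = 2.0
        self.bias = 1.0


class Counter(ToyStateModule):
    def __init__(self) -> None:
        super().__init__()
        self.count = 0


print(Affine().parameters())   # {'weight': 2.0, 'bias': 1.0}
print(Counter().parameters())  # {}
```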

Purify functions and methods

pax.pure(func)[source]

Make a function pure by copying the inputs.

Any modification of the copies does not affect the original inputs.

Note: only functions wrapped by pax.pure are allowed to modify PAX modules.

Example:

>>> f = pax.Linear(3,3)
>>> f.a_list = []
Traceback (most recent call last):
  ...
ValueError: Cannot modify a module in immutable mode.
Please do this computation inside a function decorated by `pax.pure`.
>>>
>>> @pax.pure
... def add_list(m):
...     m.a_list = []
...     return m
...
>>> f = add_list(f)
>>> print(f.a_list)
[]

Parameters

func (Callable) – A function.

Returns

A pure function.
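The mechanism described above (immutable modules, plus a decorator that copies its inputs and unlocks only the copies) can be sketched in plain Python. ToyModule and pure below are hypothetical; only top-level positional arguments are handled, and the error message is copied from the example above.

```python
import copy
import functools
from typing import Any, Callable


class ToyModule:
    """Hypothetical module that rejects attribute writes unless unlocked."""

    def __init__(self) -> None:
        object.__setattr__(self, "_mutable", False)

    def __setattr__(self, name: str, value: Any) -> None:
        if not self._mutable:
            raise ValueError("Cannot modify a module in immutable mode.")
        object.__setattr__(self, name, value)


def pure(func: Callable) -> Callable:
    """Hypothetical sketch of the idea behind pax.pure."""

    @functools.wraps(func)
    def wrapper(*args: Any) -> Any:
        # Deep-copy the inputs so the caller's objects are never modified.
        args = copy.deepcopy(args)
        modules = [a for a in args if isinstance(a, ToyModule)]
        for m in modules:
            object.__setattr__(m, "_mutable", True)  # unlock the copies only
        try:
            return func(*args)
        finally:
            for m in modules:
                object.__setattr__(m, "_mutable", False)  # lock them again

    return wrapper


f = ToyModule()
try:
    f.a_list = []
except ValueError as err:
    print(err)  # Cannot modify a module in immutable mode.


@pure
def add_list(m: ToyModule) -> ToyModule:
    m.a_list = []
    return m


g = add_list(f)
print(g.a_list, hasattr(f, "a_list"))  # [] False
```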

pax.module_and_value(module_or_method)[source]

Return a pure function that executes a module’s method.

The pure function returns a tuple (module, output): the updated module followed by the method's output.

Example:

>>> net = pax.Linear(1, 1)
>>> x = jnp.ones((32, 1))
>>> net, y = pax.module_and_value(net)(x)  # note: `net` is also returned.

Parameters

module_or_method (Callable[…, ~O]) – Either a PAX module or a method of a PAX module.

Return type

Callable[…, Tuple[~T, ~O]]

Returns

A pure function.
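A stripped-down model of module_and_value: copy the module, run the requested method on the copy, and return both. ToyCounter and this module_and_value are hypothetical sketches; a bound method is recognised through its standard __self__ attribute, but PAX's actual dispatch and purity checks may differ.

```python
import copy
from typing import Any, Callable, Tuple


class ToyCounter:
    """Hypothetical stateful module."""

    def __init__(self) -> None:
        self.count = 0

    def __call__(self, x: int) -> int:
        self.count += 1
        return x * 2

    def reset(self) -> int:
        old = self.count
        self.count = 0
        return old


def module_and_value(module_or_method: Callable) -> Callable[..., Tuple[Any, Any]]:
    """Hypothetical sketch: run a module or bound method on a copy of the module."""
    # A bound method carries its module in __self__; a module is called directly.
    module = getattr(module_or_method, "__self__", module_or_method)
    name = getattr(module_or_method, "__name__", "__call__")

    def run(*args: Any, **kwargs: Any) -> Tuple[Any, Any]:
        new_module = copy.deepcopy(module)
        output = getattr(new_module, name)(*args, **kwargs)
        return new_module, output

    return run


net = ToyCounter()
net2, y = module_and_value(net)(21)
print(net.count, net2.count, y)  # 0 1 42
net3, old = module_and_value(net2.reset)()
print(net2.count, net3.count, old)  # 1 0 1
```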

Random Number Generator

seed_rng_key

pax.seed_rng_key(seed)[source]

Seed the global random state, i.e. set _rng = random.Random(seed).

Parameters

seed (int) – an integer seed.

Return type

None

next_rng_key

pax.next_rng_key()[source]

Return a new random RNG key and advance the global random state.

Return type

Union[Any, ndarray]
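The seed-then-draw behaviour of seed_rng_key and next_rng_key can be modelled with Python's random module alone. The sketch assumes a hypothetical ToyRngState class holding _rng; unlike pax.next_rng_key, which returns a JAX key, it returns a plain integer (one could pass such an integer to jax.random.PRNGKey if a JAX key is needed). Reseeding with the same value replays the same keys.

```python
import random


class ToyRngState:
    """Hypothetical holder of a global random state (_rng)."""

    def __init__(self, seed: int = 42) -> None:
        # 42 is an arbitrary default chosen for this sketch.
        self._rng = random.Random(seed)

    def seed_rng_key(self, seed: int) -> None:
        # Mirrors the docstring: replace _rng with a freshly seeded generator.
        self._rng = random.Random(seed)

    def next_rng_key(self) -> int:
        # Drawing advances _rng, so consecutive calls generally give different keys.
        return self._rng.randint(1, 2**31 - 1)


state = ToyRngState()
state.seed_rng_key(0)
first = [state.next_rng_key() for _ in range(3)]
state.seed_rng_key(0)
second = [state.next_rng_key() for _ in range(3)]
print(first == second)  # True: the same seed replays the same keys
```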