PAX Basics
pax.Module: The base class for all PAX modules.
Mark an uninitialized or deleted pytree node.
pax.pure: Make a function pure by copying the inputs.
Call a module and return the updated module.
Set
Return a random rng key.
PAX’s Module
class pax.Module(*, training=True, name=None)
The base class for all PAX modules.
Example:
>>> class Counter(pax.Module):
...     def __init__(self):
...         super().__init__()
...         self.count = jnp.array(0)
...
...     def step(self, x):
...         self.count += 1
property training
Return True if a module is in training mode.
>>> net = pax.Linear(1, 1)
>>> net.training
True
>>> net = net.eval()
>>> net.training
False
Return type: bool
replace(**kwargs)
Return a new module with some attributes replaced.
>>> net = pax.Linear(2, 2)
>>> net = net.replace(bias=jnp.zeros((2,)))
Return type: ~T
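The copy-and-replace behavior described above resembles `dataclasses.replace` from the Python standard library. The block below is a minimal sketch of that pattern only, not PAX's implementation; `TinyLinear` is a hypothetical stand-in for a module and does not exist in PAX.

```python
from dataclasses import dataclass, replace


# Hypothetical stand-in for a module; not a PAX class.
@dataclass(frozen=True)
class TinyLinear:
    weight: tuple
    bias: tuple


net = TinyLinear(weight=(1.0, 2.0), bias=(0.5, 0.5))

# `replace` builds a new object; the original is left untouched.
new_net = replace(net, bias=(0.0, 0.0))

print(net.bias)      # (0.5, 0.5)
print(new_net.bias)  # (0.0, 0.0)
```

Because the original object is never mutated, the result has to be rebound to a name, which is why the doctest above writes `net = net.replace(...)`.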
replace_node(node, value)
Replace a node of the pytree by a new value.
Example:
>>> mod = pax.Sequential(
...     pax.Linear(2,2),
...     jax.nn.relu
... )
>>> mod = mod.replace_node(mod[0].weight, jnp.zeros((2, 3)))
>>> print(mod[0].weight.shape)
(2, 3)
Return type: ~T
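One way to picture node replacement is a recursive walk over nested containers that swaps out a single target object. The sketch below uses plain dicts and lists instead of PAX modules, and it assumes the target is matched by identity (`is`) rather than by equality; how PAX matches nodes internally is not stated here.

```python
def replace_node(tree, node, value):
    """Return a copy of `tree` with the object `node` swapped for `value`.

    Assumption: nodes are matched by identity (`is`), not by equality.
    """
    if tree is node:
        return value
    if isinstance(tree, dict):
        return {k: replace_node(v, node, value) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(replace_node(v, node, value) for v in tree)
    return tree  # a leaf that is not the target


weight = [1.0, 2.0]
tree = {"layers": [{"weight": weight, "bias": [0.0]}]}

# Build a new tree; the original tree still holds the old weight.
new_tree = replace_node(tree, weight, [0.0, 0.0, 0.0])
print(new_tree["layers"][0]["weight"])  # [0.0, 0.0, 0.0]
```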
summary(return_list=False)
Summarize a module as a tree of its submodules.
Parameters: return_list (bool) – return a list of lines instead of a joined string.
>>> net = pax.Sequential(pax.Linear(2, 3), jax.nn.relu, pax.Linear(3, 4))
>>> print(net.summary())
Sequential
├── Linear(in_dim=2, out_dim=3, with_bias=True)
├── x => relu(x)
└── Linear(in_dim=3, out_dim=4, with_bias=True)
Return type: Union[str, List[str]]
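The tree layout shown above can be produced by a short recursive function. The following is a sketch under stated assumptions rather than PAX's code: `summarize` is a hypothetical helper, and each node is represented as a `(label, children)` pair instead of a real module.

```python
def summarize(node, return_list=False):
    """Render `node` = (label, [children]) as an indented tree."""
    label, children = node
    lines = [label]
    for i, child in enumerate(children):
        last = i == len(children) - 1
        # The last child gets a corner; earlier children get a tee and
        # a vertical bar that continues down past their own subtree.
        head, tail = ("└── ", "    ") if last else ("├── ", "│   ")
        child_lines = summarize(child, return_list=True)
        lines.append(head + child_lines[0])
        lines.extend(tail + line for line in child_lines[1:])
    return lines if return_list else "\n".join(lines)


net = (
    "Sequential",
    [("Linear(in_dim=2, out_dim=3)", []),
     ("x => relu(x)", []),
     ("Linear(in_dim=3, out_dim=4)", [])],
)
print(summarize(net))
```

Returning the list of lines from the recursive calls is what makes the `return_list` option cheap to offer: the joined string is just the final step.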
apply(apply_fn)
Apply a function to all submodules.
>>> def print_param_count(mod):
...     count = sum(jax.tree_util.tree_leaves(jax.tree_util.tree_map(jnp.size, mod)))
...     print(f"{count} {mod}")
...     return mod
...
>>> net = pax.Sequential(pax.Linear(1, 1), jax.nn.relu)
>>> net = net.apply(print_param_count)
2 Linear(in_dim=1, out_dim=1, with_bias=True)
0 Lambda(relu)
2 Sequential
Parameters: apply_fn – a function which inputs a module and outputs a transformed module.
Return type: ~T
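The doctest output above lists the children before their parent, which suggests a bottom-up traversal. The sketch below illustrates that order with a hypothetical `Node` class and a free function `apply`; both are illustrative stand-ins, and the bottom-up order is an inference from the printed output, not a documented guarantee.

```python
from dataclasses import dataclass, replace


# Hypothetical stand-in for a module with submodules; not a PAX class.
@dataclass(frozen=True)
class Node:
    name: str
    children: tuple = ()


def apply(node, fn):
    """Apply `fn` to every child first, then to the node itself."""
    new_children = tuple(apply(child, fn) for child in node.children)
    return fn(replace(node, children=new_children))


visited = []


def record(node):
    # Note the visit order and return the node unchanged.
    visited.append(node.name)
    return node


tree = Node("Sequential", (Node("Linear"), Node("Lambda")))
result = apply(tree, record)
print(visited)  # ['Linear', 'Lambda', 'Sequential']
```

Since the function must return a module, a transformation such as swapping one layer type for another could be written by returning a different node for the matching cases and the input node otherwise.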
class pax.ParameterModule(*, training=True, name=None)
A PAX module that registers attributes as parameters by default.
Purify functions and methods
pax.pure(func)
Make a function pure by copying the inputs.
Any modification to the copy does not affect the original inputs.
Note: only functions wrapped by pax.pure are allowed to modify PAX’s Modules.
Example:
>>> f = pax.Linear(3,3)
>>> f.a_list = []
Traceback (most recent call last):
  ...
ValueError: Cannot modify a module in immutable mode. Please do this computation inside a function decorated by `pax.pure`.
>>>
>>> @pax.pure
... def add_list(m):
...     m.a_list = []
...     return m
...
>>> f = add_list(f)
>>> print(f.a_list)
[]
Parameters: func (Callable) – A function.
Returns: A pure function.
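The "copy the inputs" idea can be sketched with a small decorator built on `copy.deepcopy`. This is an illustration of the copying step only and is not PAX's implementation: the real `pax.pure` also controls whether modules may be modified, as the note above describes, which this sketch does not attempt. The names `pure` and `add_item` here are local to the example.

```python
import copy
import functools


def pure(func):
    """Wrap `func` so it only ever sees deep copies of its arguments."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        args, kwargs = copy.deepcopy((args, kwargs))
        return func(*args, **kwargs)
    return wrapper


@pure
def add_item(items):
    items.append(1)  # mutates the copy only
    return items


original = []
result = add_item(original)
print(original)  # []
print(result)    # [1]
```

The caller's object is never touched, so any change has to be returned and rebound, mirroring `f = add_list(f)` in the doctest above.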
Random Number Generator
Set
Return a random rng key.