Experimental

mutable(module)

A context manager that makes a copy of the module mutable inside the context.

Flattener(*args, **kwargs)

Flatten PAX modules for better performance.

LazyModule(*[, training, name])

A lazy module is a module that only creates submodules when needed.

graph.build_graph_module(func)

Build a graph module from a forward function.

default_mp_policy(module)

A default mixed precision policy.

apply_scaled_gradients(model, optimizer, …)

Update model, optimizer and loss scale.

save_weights_to_dict(module)

Save module weights to a dictionary.

load_weights_from_dict(module, state_dict)

Load module weights from a dictionary.

Mutable

pax.experimental.mutable(module)

A context manager that makes a copy of the module mutable inside the context.

>>> net = pax.Linear(1, 2)
>>> with pax.experimental.mutable(net) as net:
...     net.bias = jnp.array(0.)
>>> assert net.bias.item() == 0.
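
The copy-then-mutate behaviour can be illustrated without PAX. This is a minimal, library-free sketch and not PAX's implementation. FrozenModule is a hypothetical stand-in for an immutable module. The sketched mutable deep-copies the module, lifts the freeze inside the with block, and restores it on exit, so the original object is never modified.

```python
import copy
from contextlib import contextmanager


class FrozenModule:
    """Hypothetical stand-in for an immutable module (not a PAX class)."""

    def __init__(self, **fields):
        object.__setattr__(self, "_frozen", False)
        for name, value in fields.items():
            setattr(self, name, value)
        object.__setattr__(self, "_frozen", True)

    def __setattr__(self, name, value):
        # Reject attribute assignment while the module is frozen.
        if self._frozen:
            raise AttributeError(f"cannot set {name!r} on a frozen module")
        object.__setattr__(self, name, value)


@contextmanager
def mutable(module):
    """Yield a deep copy of `module` that may be modified inside the block."""
    clone = copy.deepcopy(module)
    object.__setattr__(clone, "_frozen", False)
    try:
        yield clone
    finally:
        # Freeze the copy again once the block exits.
        object.__setattr__(clone, "_frozen", True)


net = FrozenModule(weight=2.0, bias=1.0)
with mutable(net) as new_net:
    new_net.bias = 0.0
print(net.bias, new_net.bias)  # 1.0 0.0
```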

Flattener

class pax.experimental.Flattener(*args, **kwargs)

Flatten PAX modules for better performance.

Example:

>>> net = pax.Linear(3, 3)
>>> opt = opax.adam(1e-3)(net.parameters())
>>> flat_mods = pax.experimental.Flattener(model=net, optimizer=opt)
>>> net, opt = flat_mods.model, flat_mods.optimizer
>>> print(net.summary())
Linear(in_dim=3, out_dim=3, with_bias=True)
>>> print(opt.summary())
chain.<locals>.Chain
├── scale_by_adam.<locals>.ScaleByAdam
│   ├── Linear(in_dim=3, out_dim=3, with_bias=True)
│   └── Linear(in_dim=3, out_dim=3, with_bias=True)
└── scale.<locals>.Scale

__init__(**kwargs)

Create a new flattener.

update(**kwargs)

Update the flattener.

Example:

>>> net = pax.Linear(3, 3)
>>> flats = pax.experimental.Flattener(net=net)
>>> flats = flats.update(net=pax.Linear(4, 4))
>>> print(flats.net.summary())
Linear(in_dim=4, out_dim=4, with_bias=True)

Return type

~T

parameters()

Raise an error.

The original module must be reconstructed before its parameters can be accessed.

Return type

~T
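
The flatten-once idea can be illustrated with plain containers. This sketch rests on an assumption and is not PAX's code. We assume the speed-up comes from storing a module tree as a flat list of leaves plus a structure description, similar in spirit to jax.tree_util.tree_flatten and tree_unflatten. The leaves can then be passed around cheaply, and the tree is rebuilt only when needed. The flatten and unflatten helpers below are hypothetical and handle only dicts, lists and tuples.

```python
from typing import Any, List, Tuple

LEAF = object()  # marker for a leaf position inside the structure


def flatten(tree: Any) -> Tuple[List[Any], Any]:
    """Split a nested dict/list/tuple into (leaves, structure)."""
    leaves: List[Any] = []

    def walk(node: Any) -> Any:
        if isinstance(node, dict):
            # Sorted keys give a deterministic leaf order.
            return {k: walk(node[k]) for k in sorted(node)}
        if isinstance(node, (list, tuple)):
            return type(node)(walk(child) for child in node)
        leaves.append(node)
        return LEAF

    structure = walk(tree)
    return leaves, structure


def unflatten(structure: Any, leaves: List[Any]) -> Any:
    """Rebuild the nested container by filling leaf markers in order."""
    it = iter(leaves)

    def fill(node: Any) -> Any:
        if node is LEAF:
            return next(it)
        if isinstance(node, dict):
            return {k: fill(node[k]) for k in sorted(node)}
        if isinstance(node, (list, tuple)):
            return type(node)(fill(child) for child in node)
        raise TypeError(f"unexpected node {node!r}")

    return fill(structure)


tree = {"model": {"weight": [1.0, 2.0], "bias": 0.5}, "step": (3,)}
leaves, structure = flatten(tree)
print(leaves)  # [0.5, 1.0, 2.0, 3]
print(unflatten(structure, [x * 2 for x in leaves]))
```

Because only the leaves survive in flattened form, anything that needs the tree, such as collecting parameters, has to rebuild it first. This is consistent with parameters() raising an error.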

Graph API

class pax.experimental.graph.Node(parents, fx, value)

A node is an object that stores:

  • parent nodes,

  • a PAX module (or a function),

  • and a value.

For example:

>>> x = pax.experimental.graph.Node((), lambda x: x, jnp.array(0))
>>> x.parents, x.fx, x.value
((), Lambda..., DeviceArray(0, dtype=int32, weak_type=True))

__rshift__(fn)

Create a new node by applying fn to the node’s value.

Example:

>>> import jax, pax, jax.numpy as jnp
>>> from functools import partial
>>> x = pax.experimental.graph.Node((), lambda x: x, jnp.array(1.))
>>> y = x >> partial(jax.lax.add, 1.)
>>> y.value
DeviceArray(2., dtype=float32, weak_type=True)

__and__(other)

Concatenate two nodes to create a tuple.

Example:

>>> x = pax.experimental.graph.InputNode(1)
>>> y = pax.experimental.graph.InputNode(2)
>>> z = x & y
>>> z.value
(1, 2)
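
__or__(other)

Concatenate two nodes into a multi-argument node: when a function is applied to it with >>, the function receives the two values as separate positional arguments.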

Example:

>>> x = pax.experimental.graph.InputNode(1)
>>> y = pax.experimental.graph.InputNode(2)
>>> z = (x | y) >> jax.lax.add
>>> z.value
DeviceArray(3, dtype=int32, weak_type=True)

binary_ops(fn, other)

Create a new node by applying the binary function fn to the values of this node and other.

Example:

>>> x = pax.experimental.graph.InputNode(1)
>>> y = pax.experimental.graph.InputNode(2)
>>> z = x.binary_ops(jax.lax.sub, y)
>>> z.value
DeviceArray(-1, dtype=int32, weak_type=True)
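
The operator semantics described above can be mirrored by a small, hypothetical Node class: >> applies a function, & packs values into a tuple, | passes values as separate arguments, and binary_ops combines two values. Like the examples above, it computes values eagerly and records parents. It is a plain-Python sketch rather than PAX's implementation, and the helper class MultiArgsNode is an assumption made for illustration.

```python
import operator
from typing import Any, Callable, Tuple


class Node:
    """Hypothetical graph node holding parent nodes, a function fx and a value."""

    def __init__(self, parents: Tuple["Node", ...], fx: Callable, value: Any):
        self.parents = parents
        self.fx = fx
        self.value = value

    def __rshift__(self, fn: Callable) -> "Node":
        # x >> fn: apply fn to this node's value; this node becomes the parent.
        return Node((self,), fn, fn(self.value))

    def __and__(self, other: "Node") -> "Node":
        # x & y: a single node whose value is the tuple of both values.
        return Node((self, other), lambda a, b: (a, b), (self.value, other.value))

    def __or__(self, other: "Node") -> "MultiArgsNode":
        # x | y: the values are later unpacked into separate arguments.
        return MultiArgsNode(self, other)

    def binary_ops(self, fn: Callable, other: "Node") -> "Node":
        # Combine this node's value with other's value using fn.
        return Node((self, other), fn, fn(self.value, other.value))


class MultiArgsNode(Node):
    """Node whose tuple value is unpacked when a function is applied with >>."""

    def __init__(self, *parents: Node):
        super().__init__(parents, lambda *xs: xs, tuple(p.value for p in parents))

    def __rshift__(self, fn: Callable) -> Node:
        return Node(self.parents, fn, fn(*self.value))


x = Node((), lambda v: v, 1)
y = Node((), lambda v: v, 2)
print((x & y).value)                        # (1, 2)
print(((x | y) >> operator.add).value)      # 3
print(x.binary_ops(operator.sub, y).value)  # -1
```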

property shape

Return the shape of the node's value.

Example:

>>> x = pax.experimental.graph.InputNode(jnp.empty((3, 4)))
>>> x.shape
(3, 4)

property dtype

Return the dtype of the node's value.

Example:

>>> x = pax.experimental.graph.InputNode(jnp.empty((1,), dtype=jnp.int32))
>>> x.dtype
dtype('int32')

__eq__(o)

Return self==value.

Return type

bool

__hash__()

Return hash(self).

Return type

int

class pax.experimental.graph.InputNode(value, fx=<function InputNode.<lambda>>)

An InputNode object represents an input argument of a GraphModule.

__init__(value, fx=<function InputNode.<lambda>>)

Create an InputNode object from a value.

class pax.experimental.graph.GraphModule(inputs, output, name=None)

A module that uses a directed graph to represent its computation.

__init__(inputs, output, name=None)

Initialize module.

__call__(*xs)

Call self as a function.

pax.experimental.graph.build_graph_module(func)

Build a graph module from a forward function.

Example:

>>> def residual_forward(x):
...     y = x >> pax.Linear(x.shape[-1], x.shape[-1])
...     y >>= jax.nn.relu
...     z = (x | y) >> jax.lax.add
...     return z
...
>>> from pax.experimental.graph import build_graph_module
>>> net = build_graph_module(residual_forward)(jnp.empty((3, 8)))
>>> print(net.summary())
GraphModule
├── Linear(in_dim=8, out_dim=8, with_bias=True)
├── x => relu(x)
├── x => identity(x)
└── x => add(x)
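
Conceptually, build_graph_module runs the forward function once on input nodes and records how each node was produced. The resulting GraphModule then replays that record on new inputs. The sketch below shows this trace-then-replay pattern in plain Python. Node, ArgPack, GraphModule and build_graph_module here are hypothetical stand-ins, with scalars instead of arrays and a lambda instead of pax.Linear. It illustrates the general technique and is not the PAX source. A cache ensures that nodes shared by several branches, like x in the residual connection, are evaluated only once.

```python
import operator
from typing import Any, Callable, Dict, Sequence, Tuple


class Node:
    """Records its parents, the function fx that produced it, and a value."""

    def __init__(self, parents: Tuple["Node", ...], fx: Callable, value: Any):
        self.parents, self.fx, self.value = parents, fx, value

    def __rshift__(self, fn: Callable) -> "Node":
        return Node((self,), fn, fn(self.value))

    def __or__(self, other: "Node") -> "ArgPack":
        return ArgPack((self, other))


class ArgPack:
    """Groups nodes so that `>> fn` calls fn with one argument per node."""

    def __init__(self, nodes: Tuple[Node, ...]):
        self.nodes = nodes

    def __rshift__(self, fn: Callable) -> Node:
        return Node(self.nodes, fn, fn(*(n.value for n in self.nodes)))


class GraphModule:
    def __init__(self, inputs: Sequence[Node], output: Node):
        self.inputs, self.output = list(inputs), output

    def __call__(self, *xs: Any) -> Any:
        # Seed the cache with the new input values, keyed by node identity.
        cache: Dict[int, Any] = {id(n): x for n, x in zip(self.inputs, xs)}

        def evaluate(node: Node) -> Any:
            if id(node) not in cache:
                args = [evaluate(p) for p in node.parents]
                cache[id(node)] = node.fx(*args)
            return cache[id(node)]

        return evaluate(self.output)


def build_graph_module(func: Callable) -> Callable[..., GraphModule]:
    def builder(*example_inputs: Any) -> GraphModule:
        inputs = [Node((), lambda v: v, x) for x in example_inputs]
        return GraphModule(inputs, func(*inputs))  # trace func once
    return builder


def residual_forward(x):
    y = x >> (lambda v: 3 * v)   # stands in for pax.Linear
    y >>= lambda v: max(v, 0.0)  # scalar ReLU
    return (x | y) >> operator.add


net = build_graph_module(residual_forward)(1.0)
print(net(2.0), net(-1.0))  # 8.0 -1.0
```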

Lazy Module

class pax.experimental.LazyModule(*, training=True, name=None)

A lazy module is a module that only creates submodules when needed.

Example:

>>> from dataclasses import dataclass
>>> @dataclass
... class MLP(pax.experimental.LazyModule):
...     features: list
...
...     def __call__(self, x):
...         sizes = zip(self.features[:-1], self.features[1:])
...         for i, (in_dim, out_dim) in enumerate(sizes):
...             fc = self.get_or_create(f"fc_{i}", lambda: pax.Linear(in_dim, out_dim))
...             x = jax.nn.relu(fc(x))
...         return x
...
...
>>> mlp, _ = MLP([1, 2, 3]) % jnp.ones((1, 1))
>>> print(mlp.summary())
MLP(features=[1, 2, 3])
├── Linear(in_dim=1, out_dim=2, with_bias=True)
└── Linear(in_dim=2, out_dim=3, with_bias=True)

get_or_create(name, create_fn)

If no attribute called name exists yet, create it with create_fn and register it.

Return the attribute.

Return type

~T
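
The get_or_create pattern can be sketched as a name-to-submodule cache. create_fn runs only the first time a name is requested, and later calls return the stored object. The classes below are hypothetical plain-Python stand-ins, and ToyLinear uses fixed all-ones weights. PAX returns an updated module from MLP(...) % x, whereas this simplified version mutates the module in place.

```python
from typing import Any, Callable, Dict, List


class LazyModule:
    """Hypothetical stand-in: submodules are created on first use, then reused."""

    def __init__(self) -> None:
        self._submodules: Dict[str, Any] = {}

    def get_or_create(self, name: str, create_fn: Callable[[], Any]) -> Any:
        # Call create_fn only if no submodule is registered under `name` yet.
        if name not in self._submodules:
            self._submodules[name] = create_fn()
        return self._submodules[name]


class ToyLinear:
    """Dense layer on Python lists with all-ones weights and no bias."""

    def __init__(self, in_dim: int, out_dim: int) -> None:
        self.in_dim, self.out_dim = in_dim, out_dim
        self.weight = [[1.0] * out_dim for _ in range(in_dim)]

    def __call__(self, x: List[float]) -> List[float]:
        return [sum(x[i] * self.weight[i][j] for i in range(self.in_dim))
                for j in range(self.out_dim)]


class MLP(LazyModule):
    def __init__(self, features: List[int]) -> None:
        super().__init__()
        self.features = features

    def __call__(self, x: List[float]) -> List[float]:
        sizes = zip(self.features[:-1], self.features[1:])
        for i, (in_dim, out_dim) in enumerate(sizes):
            # The lambda is called immediately, so it sees this iteration's sizes.
            fc = self.get_or_create(f"fc_{i}", lambda: ToyLinear(in_dim, out_dim))
            x = [max(v, 0.0) for v in fc(x)]  # ReLU
        return x


mlp = MLP([1, 2, 3])
print(sorted(mlp._submodules))  # []
print(mlp([1.0]))               # [2.0, 2.0, 2.0]
print(sorted(mlp._submodules))  # ['fc_0', 'fc_1']
```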

Mixed Precision

pax.experimental.default_mp_policy(module)

A default mixed precision policy.

  • Linear layers are in half precision.

  • Normalization layers are in full precision.

Example:

>>> net = pax.Sequential(pax.Linear(3, 3), pax.BatchNorm1D(3))
>>> net = net.apply(pax.experimental.default_mp_policy)
>>> print(net.summary())
Sequential
├── Linear(in_dim=3, out_dim=3, with_bias=True, mp_policy=FHF)
└── BatchNorm1D(num_channels=3, ..., mp_policy=FFF)

Return type

~T
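
A policy function like default_mp_policy can be thought of as a mapping from layer type to a precision triple. The sketch below is hypothetical: Policy, Layer and default_policy are illustrative names. It also assumes that the three letters in the summary above denote parameter, compute and output precision, with F for full and H for half. Under that reading, FHF keeps parameters and outputs in full precision while computing in half precision.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Policy:
    param: str    # precision of stored parameters: "F" (full) or "H" (half)
    compute: str  # precision used inside the layer's computation
    output: str   # precision of the layer's output

    def __str__(self) -> str:
        return self.param + self.compute + self.output


FULL = Policy("F", "F", "F")
HALF_COMPUTE = Policy("F", "H", "F")
NORM_LAYERS = ("BatchNorm1D", "BatchNorm2D", "LayerNorm", "GroupNorm")


@dataclass(frozen=True)
class Layer:
    kind: str
    mp_policy: Policy = FULL


def default_policy(layer: Layer) -> Layer:
    """Return a copy of `layer` carrying the precision policy for its type."""
    if layer.kind == "Linear":
        return replace(layer, mp_policy=HALF_COMPUTE)
    if layer.kind in NORM_LAYERS:
        # Keep normalization layers in full precision.
        return replace(layer, mp_policy=FULL)
    return layer  # other layers are left unchanged


net = [Layer("Linear"), Layer("BatchNorm1D")]
net = [default_policy(layer) for layer in net]
print([str(layer.mp_policy) for layer in net])  # ['FHF', 'FFF']
```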

pax.experimental.apply_scaled_gradients(model, optimizer, loss_scale, grads)

Update model, optimizer and loss scale.

Example:

>>> import jmp
>>> from pax.experimental import apply_scaled_gradients
>>> net = pax.Linear(2, 2)
>>> opt = opax.adam(1e-4)(net.parameters())
>>> loss_scale = jmp.DynamicLossScale(jmp.half_dtype()(2**15))
>>> grads = net.parameters()
>>> net, opt, loss_scale = apply_scaled_gradients(net, opt, loss_scale, grads)
>>> print(loss_scale.loss_scale)
32770.0
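
The bookkeeping behind a function like apply_scaled_gradients can be sketched with dynamic loss scaling in plain Python. The gradients are divided by the loss scale. If any gradient is non-finite, the update is skipped and the scale shrinks. Otherwise the update is applied, and the scale grows after a run of finite steps. This mirrors our understanding of jmp.DynamicLossScale (default period 2000, factor 2). However, the code below is a hypothetical sketch with plain SGD standing in for the optimizer, not PAX's implementation. The 32770.0 printed above is most likely the float16 value 2**15 = 32768, which NumPy displays as the shortest decimal string that round-trips.

```python
import math
from dataclasses import dataclass, replace
from typing import List, Tuple


@dataclass(frozen=True)
class DynamicLossScale:
    loss_scale: float
    counter: int = 0
    period: int = 2000
    factor: float = 2.0
    min_loss_scale: float = 1.0

    def unscale(self, grads: List[float]) -> List[float]:
        return [g / self.loss_scale for g in grads]

    def adjust(self, grads_finite: bool) -> "DynamicLossScale":
        if not grads_finite:
            # Overflow: shrink the scale and restart the count of good steps.
            smaller = max(self.min_loss_scale, self.loss_scale / self.factor)
            return replace(self, loss_scale=smaller, counter=0)
        if self.counter + 1 >= self.period:
            # `period` finite steps in a row: try a larger scale.
            return replace(self, loss_scale=self.loss_scale * self.factor, counter=0)
        return replace(self, counter=self.counter + 1)


def apply_scaled_gradients(
    params: List[float], lr: float, loss_scale: DynamicLossScale, scaled_grads: List[float]
) -> Tuple[List[float], DynamicLossScale]:
    grads = loss_scale.unscale(scaled_grads)
    finite = all(math.isfinite(g) for g in grads)
    if finite:
        params = [p - lr * g for p, g in zip(params, grads)]  # plain SGD step
    return params, loss_scale.adjust(finite)


scale = DynamicLossScale(loss_scale=2.0 ** 15)
params, scale = apply_scaled_gradients([1.0, 2.0], 0.5, scale, [2.0 ** 15, -(2.0 ** 16)])
print(params, scale.loss_scale)  # [0.5, 3.0] 32768.0
params, scale = apply_scaled_gradients(params, 0.5, scale, [math.inf, 0.0])
print(params, scale.loss_scale)  # [0.5, 3.0] 16384.0
```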

Save and load weights

pax.experimental.save_weights_to_dict(module)

Save module weights to a dictionary.

>>> net = pax.Sequential(pax.Linear(1, 2), jax.nn.relu, pax.Linear(2, 3))
>>> weights = pax.experimental.save_weights_to_dict(net)
>>> weights
{'modules': ({'weight': ..., 'bias': ...}, {}, {'weight': ..., 'bias': ...})}

Return type

Dict[str, Any]

pax.experimental.load_weights_from_dict(module, state_dict)

Load module weights from a dictionary.

>>> a = pax.Sequential(pax.Linear(1, 2), jax.nn.relu, pax.Linear(2, 3))
>>> weights = pax.experimental.save_weights_to_dict(a)
>>> b = pax.Sequential(pax.Linear(1, 2), jax.nn.relu, pax.Linear(2, 3))
>>> b = pax.experimental.load_weights_from_dict(b, weights)
>>> assert a == b

Return type

~T
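
The saved dictionary mirrors the module tree. Containers map to nested entries, layers with weights map to their arrays, and weightless functions such as jax.nn.relu map to an empty dict. The sketch below reproduces that round trip using hypothetical ToyLinear, ToyReLU and ToySequential classes, with plain lists instead of arrays. It illustrates the idea and does not reflect how PAX traverses modules. Loading returns a new module instead of mutating the old one, matching the b = load_weights_from_dict(b, weights) usage above.

```python
from typing import Any, Dict, List


class ToyLinear:
    def __init__(self, weight: List[List[float]], bias: List[float]) -> None:
        self.weight, self.bias = weight, bias


class ToyReLU:
    """A weightless layer, saved as an empty dict."""


class ToySequential:
    def __init__(self, *modules: Any) -> None:
        self.modules = tuple(modules)


def save_weights(module: Any) -> Dict[str, Any]:
    if isinstance(module, ToySequential):
        return {"modules": tuple(save_weights(m) for m in module.modules)}
    if isinstance(module, ToyLinear):
        return {"weight": module.weight, "bias": module.bias}
    return {}  # no weights to save


def load_weights(module: Any, state: Dict[str, Any]) -> Any:
    """Return a new module with the structure of `module` and weights from `state`."""
    if isinstance(module, ToySequential):
        pairs = zip(module.modules, state["modules"])
        return ToySequential(*(load_weights(m, s) for m, s in pairs))
    if isinstance(module, ToyLinear):
        # No shape checking in this sketch.
        return ToyLinear(state["weight"], state["bias"])
    return module


a = ToySequential(ToyLinear([[1.0, 2.0]], [0.0, 0.5]), ToyReLU(), ToyLinear([[1.0] * 3] * 2, [0.1] * 3))
b = ToySequential(ToyLinear([[0.0, 0.0]], [0.0, 0.0]), ToyReLU(), ToyLinear([[0.0] * 3] * 2, [0.0] * 3))
weights = save_weights(a)
b = load_weights(b, weights)
print(save_weights(b) == weights)  # True
```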