Utilities

parameters_method(*trainable_attributes[, …])

Return a parameters method.

grad(fun, *args, **kwargs)

Compute gradient with respect to trainable parameters of the first argument.

value_and_grad(func[, has_aux])

A PAX-compatible version of jax.value_and_grad.

scan(func, init, xs[, length, unroll, …])

jax.lax.scan with an additional time_major=False mode.

build_update_fn(loss_fn, *[, scan_mode])

Build a simple update function.

parameters_method

pax.parameters_method(*trainable_attributes, submodules=True)

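Return a parameters method: a method that, when assigned to a module class, returns the module with only the listed trainable attributes.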

>>> import jax.numpy as jnp
>>> import pax
>>> class Linear(pax.Module):
...     parameters = pax.parameters_method("weight")
...     def __init__(self):
...         self.weight = jnp.array(1.0)
>>> fc = Linear()
>>> fc == fc.parameters()
True

Parameters
  • trainable_attributes – the names of the trainable attributes.

  • submodules – if true, also include submodules.

Returns

A method that returns a module with only trainable weights.
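
The factory pattern described above can be illustrated with a toy, dependency-free sketch. This is not PAX's implementation: ToyLinear, the calls attribute, and the choice to mask non-trainable attributes with None are all assumptions made for illustration. It only shows how a function called in the class body can return a method that keeps the named attributes.

```python
import copy


def parameters_method(*trainable_attributes):
    """Toy factory (hypothetical, not PAX's code): build a `parameters` method."""

    def _parameters(self):
        # Work on a shallow copy so the original object is left untouched.
        params = copy.copy(self)
        for name in vars(self):
            if name not in trainable_attributes:
                # Assumption: non-trainable attributes are masked with None.
                setattr(params, name, None)
        return params

    return _parameters


class ToyLinear:
    # The factory runs once, at class-definition time.
    parameters = parameters_method("weight")

    def __init__(self):
        self.weight = 1.0  # trainable
        self.calls = 3     # non-trainable bookkeeping


fc = ToyLinear()
params = fc.parameters()
```

In this sketch params keeps weight and masks calls, while fc itself is unchanged.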

grad

pax.grad(fun, *args, **kwargs)

Compute gradient with respect to trainable parameters of the first argument.

Example:

>>> @pax.pure
... def loss_fn(model: pax.Linear, x, y):
...     y_hat = model(x)
...     loss = jnp.mean(jnp.square(y - y_hat))
...     return loss, (loss, model)
...
>>> grad_fn = pax.grad(loss_fn, has_aux=True)
>>> net = pax.Linear(1, 1)
>>> x = jnp.ones((3, 1))
>>> grads, (loss, net) = grad_fn(net, x, x)

Return type

Callable[…, Tuple[~T, ~C]]

value_and_grad

pax.value_and_grad(func, has_aux=False)

A PAX-compatible version of jax.value_and_grad.

This version computes gradients w.r.t. trainable parameters of a PAX module.

scan

pax.scan(func, init, xs, length=None, unroll=1, time_major=True)

jax.lax.scan with an additional time_major=False mode.

The semantics of scan are given roughly by this Python implementation:

>>> import numpy as np
>>> def scan(f, init, xs, length=None):
...     if xs is None:
...         xs = [None] * length
...     carry = init
...     ys = []
...     for x in xs:
...         carry, y = f(carry, x)
...         ys.append(y)
...     return carry, np.stack(ys)
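
The reference loop above covers the time-major case only. As a rough sketch of what a time_major=False mode could look like, the example below assumes that the inputs are laid out as (batch, time, ...) and that the first two axes are swapped before and after the loop. That reading is an assumption, not a statement of PAX's implementation. The helpers swap_leading_axes, scan_batch_major, and running_sum are hypothetical names, and nested lists stand in for arrays to keep the example dependency-free.

```python
def scan(f, init, xs, length=None):
    # Reference loop from the text, returning a list instead of np.stack(ys).
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys


def swap_leading_axes(rows):
    # Transpose the first two axes of a nested list: (N, L) <-> (L, N).
    return [list(col) for col in zip(*rows)]


def scan_batch_major(f, init, xs):
    # Assumption: batch-major input is converted to time-major, scanned,
    # and the stacked outputs are converted back to batch-major.
    carry, ys = scan(f, init, swap_leading_axes(xs))
    return carry, swap_leading_axes(ys)


def running_sum(carry, x):
    # carry and x each hold one value per batch element.
    new_carry = [c + v for c, v in zip(carry, x)]
    return new_carry, new_carry


xs = [[1, 2, 3], [10, 20, 30]]  # 2 sequences, 3 time steps each
final, ys = scan_batch_major(running_sum, [0, 0], xs)
# final == [6, 60]; ys == [[1, 3, 6], [10, 30, 60]]
```

The outputs come back in the same batch-major layout as the inputs, so the caller never has to transpose by hand.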

build_update_fn

pax.build_update_fn(loss_fn, *, scan_mode=False)

Build a simple update function.

Note: loss_fn must return (loss, (aux, model)): the loss first, then a pair holding the auxiliary output and the model.

Parameters
  • loss_fn – The loss function.

  • scan_mode (bool) – If true, the update function takes (model, optimizer) as a single tuple argument instead of two separate arguments.

Example:

>>> import opax
>>> def mse_loss(model, x, y):
...     y_hat = model(x)
...     loss = jnp.mean(jnp.square(y - y_hat))
...     return loss, (loss, model)
...
>>> update_fn = pax.utils.build_update_fn(mse_loss)
>>> net = pax.Linear(2, 2)
>>> optimizer = opax.adam(1e-4)(net.parameters())
>>> x = jnp.ones((32, 2))
>>> y = jnp.zeros((32, 2))
>>> net, optimizer, loss = update_fn(net, optimizer, x, y)
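
The scan_mode flag can be pictured as a thin adapter around the regular update function. The sketch below is a toy, dependency-free illustration. It assumes that the scan-mode function returns ((model, optimizer), aux), so that it has the (carry, y) shape a scan step needs. That return shape is an assumption, and to_scan_mode and toy_update_fn are hypothetical names rather than PAX functions.

```python
def to_scan_mode(update_fn):
    # Wrap update_fn(model, optimizer, *inputs) -> (model, optimizer, aux)
    # so that (model, optimizer) travels as one carry value.
    def update_fn_scan(model_and_optimizer, *inputs):
        model, optimizer = model_and_optimizer
        model, optimizer, aux = update_fn(model, optimizer, *inputs)
        return (model, optimizer), aux

    return update_fn_scan


def toy_update_fn(model, optimizer, x):
    # Stand-in for a real update step: the "model" is a single float weight,
    # the "optimizer" is just a step counter, and the target is x.
    loss = (model - x) ** 2
    grad = 2 * (model - x)
    return model - 0.5 * grad, optimizer + 1, loss


step = to_scan_mode(toy_update_fn)
carry = (0.0, 0)
losses = []
for x in [1.0, 1.0, 1.0]:
    carry, loss = step(carry, x)
    losses.append(loss)
# carry == (1.0, 3); losses == [1.0, 0.0, 0.0]
```

Packing the model and optimizer into one value lets the same step function be handed to a scan-style loop, which threads a single carry between iterations.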