Utilities
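parameters_method | Return a parameters method.
grad | Compute the gradient with respect to the trainable parameters of the first argument.
value_and_grad | A PAX-compatible version of jax.value_and_grad.
scan | jax.lax.scan with an additional time_major=False mode.
build_update_fn | Build a simple update function.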
parameters_method
pax.parameters_method(*trainable_attributes, submodules=True)
Return a parameters method.
>>> import jax.numpy as jnp
>>> import pax
>>> class Linear(pax.Module):
...     parameters = pax.parameters_method("weight")
...     def __init__(self):
...         self.weight = jnp.array(1.0)
>>> fc = Linear()
>>> fc == fc.parameters()
True
Parameters
trainable_attributes – names of the trainable attributes.
submodules – include submodules if True.
Returns
A method that returns a module with only trainable weights.
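To make the idea concrete without depending on PAX internals, the pure-Python sketch below shows one plausible reading of a parameters method: the returned method copies the object, keeps the attributes named as trainable, recurses into child objects when submodules is true, and sets everything else to None. Node, make_parameters_method, Linear and Model are hypothetical names; this is not PAX's implementation.

```python
import copy


def make_parameters_method(*trainable_attributes, submodules=True):
    """Hypothetical stand-in for pax.parameters_method; not PAX's implementation."""

    def parameters(self):
        out = copy.copy(self)  # shallow copy, so the original object is left untouched
        for name, value in vars(self).items():
            if submodules and isinstance(value, Node):
                # Keep child objects, reduced to their own trainable attributes.
                setattr(out, name, value.parameters())
            elif name not in trainable_attributes:
                # Anything not listed as trainable is dropped.
                setattr(out, name, None)
        return out

    return parameters


class Node:
    """Hypothetical minimal base class standing in for pax.Module."""

    parameters = make_parameters_method()  # by default, nothing is trainable


class Linear(Node):
    parameters = make_parameters_method("weight", "bias")

    def __init__(self):
        self.weight = 1.0
        self.bias = 0.0
        self.step = 0  # non-trainable state


class Model(Node):
    def __init__(self):
        self.fc = Linear()
        self.name = "model"


model = Model()
params = model.parameters()
print(params.name, params.fc.weight, params.fc.step)  # None 1.0 None
```

Because the sketch works on a copy, calling the method never changes the original object.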
grad
pax.grad(fun, *args, **kwargs)
Compute the gradient with respect to the trainable parameters of the first argument.
Example:
>>> import jax.numpy as jnp
>>> import pax
>>> @pax.pure
... def loss_fn(model: pax.Linear, x, y):
...     y_hat = model(x)
...     loss = jnp.mean(jnp.square(y - y_hat))
...     return loss, (loss, model)
...
>>> grad_fn = pax.grad(loss_fn, has_aux=True)
>>> net = pax.Linear(1, 1)
>>> x = jnp.ones((3, 1))
>>> grads, (loss, net) = grad_fn(net, x, x)
Return type
Callable[..., Tuple[~T, ~C]]
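For the MSE loss in the example, and assuming the layer computes y_hat = x @ W + b, the gradients with respect to the trainable weights have a closed form: with residual r = y_hat - y and n the number of output entries, dL/dW = 2 x^T r / n and dL/db = 2 (sum of r over the batch) / n. The numpy sketch below checks that formula against central finite differences. It does not call PAX; W, b, mse_grads and numerical_grad are hypothetical names.

```python
import numpy as np


def mse_loss(W, b, x, y):
    """Same loss as the example, for an explicit linear layer y_hat = x @ W + b."""
    y_hat = x @ W + b
    return np.mean(np.square(y - y_hat))


def mse_grads(W, b, x, y):
    """Closed-form gradients of mse_loss with respect to W and b."""
    r = x @ W + b - y  # residual y_hat - y, shape (batch, out)
    n = r.size  # np.mean averages over every output entry
    return 2.0 * x.T @ r / n, 2.0 * r.sum(axis=0) / n


def numerical_grad(f, p, eps=1e-6):
    """Central finite differences of a scalar function f with respect to array p."""
    g = np.zeros_like(p)
    for idx in np.ndindex(p.shape):
        step = np.zeros_like(p)
        step[idx] = eps
        g[idx] = (f(p + step) - f(p - step)) / (2 * eps)
    return g


rng = np.random.default_rng(0)
x, y = rng.normal(size=(3, 2)), rng.normal(size=(3, 1))
W, b = rng.normal(size=(2, 1)), rng.normal(size=(1,))
dW, db = mse_grads(W, b, x, y)
num_dW = numerical_grad(lambda w: mse_loss(w, b, x, y), W)
num_db = numerical_grad(lambda c: mse_loss(W, c, x, y), b)
print(np.allclose(dW, num_dW), np.allclose(db, num_db))  # True True
```

An autodiff-based grad should agree with these values up to floating-point error; the finite-difference check is only a way to verify the formula without a framework.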
value_and_grad
A PAX-compatible version of jax.value_and_grad.
scan
pax.scan(func, init, xs, length=None, unroll=1, time_major=True)
jax.lax.scan with an additional time_major=False mode.
The semantics of scan in the default time_major=True mode are given roughly by this Python implementation:
>>> import numpy as np
>>> def scan(f, init, xs, length=None):
...     if xs is None:
...         xs = [None] * length
...     carry = init
...     ys = []
...     for x in xs:
...         carry, y = f(carry, x)
...         ys.append(y)
...     return carry, np.stack(ys)
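The reference loop describes the time-major layout, where the leading axis of xs is time. Assuming that time_major=False means xs arrives as (batch, time, ...), a minimal sketch of that mode swaps the first two axes, runs the same loop, and swaps the stacked outputs back. scan_batch_major is a hypothetical name, and the sketch handles a single array rather than a pytree.

```python
import numpy as np


def scan(f, init, xs, length=None):
    """Reference semantics of the time-major scan, as in the docstring."""
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)


def scan_batch_major(f, init, xs):
    """Hypothetical sketch of time_major=False: xs has shape (batch, time, ...)."""
    xs_t = np.swapaxes(xs, 0, 1)  # -> (time, batch, ...)
    carry, ys_t = scan(f, init, xs_t)  # loop over the time axis
    return carry, np.swapaxes(ys_t, 0, 1)  # ys back to (batch, time, ...)


def step(carry, x):
    """Running sum over time; the carry holds one value per batch element."""
    carry = carry + x
    return carry, carry


xs = np.array([[1, 2, 3], [10, 20, 30]])  # batch=2, time=3
carry, ys = scan_batch_major(step, np.zeros(2), xs)
print(carry)  # [ 6. 60.]
print(ys)  # [[ 1.  3.  6.]
           #  [10. 30. 60.]]
```

Swapping axes rather than writing a second loop over axis 1 lets the time-major loop be reused unchanged.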
build_update_fn
pax.build_update_fn(loss_fn, *, scan_mode=False)
Build a simple update function.
Note: The output of loss_fn must be (loss, (aux, model)).
Parameters
loss_fn – The loss function.
scan_mode (bool) – If True, use (model, optimizer) as a single argument.
Example:
>>> import jax.numpy as jnp
>>> import opax
>>> import pax
>>> def mse_loss(model, x, y):
...     y_hat = model(x)
...     loss = jnp.mean(jnp.square(y - y_hat))
...     return loss, (loss, model)
...
>>> update_fn = pax.utils.build_update_fn(mse_loss)
>>> net = pax.Linear(2, 2)
>>> optimizer = opax.adam(1e-4)(net.parameters())
>>> x = jnp.ones((32, 2))
>>> y = jnp.zeros((32, 2))
>>> net, optimizer, loss = update_fn(net, optimizer, x, y)
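The update function in the example takes one optimization step and returns the new model, the new optimizer and the auxiliary output. The numpy sketch below mimics that shape under stated assumptions: the model is a dict of arrays, the optimizer is a hypothetical plain-SGD state {"lr", "step"} rather than an opax optimizer, gradients come from the closed-form MSE gradient instead of automatic differentiation, and the scan_mode variant is assumed to take (model, optimizer) as one carry and return ((model, optimizer), aux) so it can act as a scan step. build_update_fn_sketch and loss_and_grads are hypothetical names.

```python
import numpy as np


def loss_and_grads(model, x, y):
    """MSE loss of y_hat = x @ W + b plus closed-form gradients (stands in for autodiff)."""
    r = x @ model["W"] + model["b"] - y  # residual y_hat - y
    loss = np.mean(r**2)
    grads = {"W": 2.0 * x.T @ r / r.size, "b": 2.0 * r.sum(axis=0) / r.size}
    # Follows the (loss, (aux, model)) convention from the note, with grads appended.
    return loss, (loss, model), grads


def build_update_fn_sketch(loss_and_grads_fn, *, scan_mode=False):
    """Hypothetical sketch, not PAX's implementation: plain SGD instead of opax."""

    def update_fn(model, optimizer, *inputs):
        loss, (aux, model), grads = loss_and_grads_fn(model, *inputs)
        lr = optimizer["lr"]
        model = {k: v - lr * grads[k] for k, v in model.items()}  # one SGD step
        optimizer = {**optimizer, "step": optimizer["step"] + 1}
        return model, optimizer, aux

    if not scan_mode:
        return update_fn

    def scan_update_fn(model_and_optimizer, inputs):
        # (model, optimizer) travel together as one carry, like a scan step function.
        model, optimizer, aux = update_fn(*model_and_optimizer, *inputs)
        return (model, optimizer), aux

    return scan_update_fn


rng = np.random.default_rng(0)
x = rng.normal(size=(32, 2))
y = x @ np.array([[1.0, -1.0], [0.5, 2.0]])
model = {"W": np.zeros((2, 2)), "b": np.zeros(2)}
optimizer = {"lr": 0.1, "step": 0}

update_fn = build_update_fn_sketch(loss_and_grads)
losses = []
for _ in range(100):
    model, optimizer, loss = update_fn(model, optimizer, x, y)
    losses.append(loss)

scan_step = build_update_fn_sketch(loss_and_grads, scan_mode=True)
(model2, optimizer2), loss2 = scan_step((model, optimizer), (x, y))
```

Returning new dicts instead of updating in place mirrors the functional style of the example, where the caller rebinds net and optimizer after every step.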