Experimental
mutable                   A context manager that allows a copy of a module to be mutable inside the context.
Flattener                 Flatten PAX modules for better performance.
LazyModule                A lazy module is a module that only creates submodules when needed.
build_graph_module        Build a graph module from a forward function.
default_mp_policy         A default mixed-precision policy.
apply_scaled_gradients    Update the model, optimizer, and loss scale.
save_weights_to_dict      Save module weights to a dictionary.
load_weights_from_dict    Load module weights from a dictionary.
Mutable
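The one-line summary describes `mutable` as a context manager that lets a copy of a module be modified inside the context. The dependency-free sketch below illustrates that idea under assumptions: `FrozenModule` and `mutable_copy` are hypothetical names, and the copy-then-refreeze behaviour is inferred from the summary rather than taken from PAX's source.

```python
# Hypothetical sketch of a "mutable copy" context manager; not PAX's implementation.
import copy
from contextlib import contextmanager


class FrozenModule:
    """Toy immutable module: attribute assignment fails unless _mutable is True."""

    def __init__(self, **fields):
        object.__setattr__(self, "_mutable", True)
        for name, value in fields.items():
            setattr(self, name, value)
        object.__setattr__(self, "_mutable", False)

    def __setattr__(self, name, value):
        if not self._mutable:
            raise AttributeError(f"cannot set {name!r} on a frozen module")
        object.__setattr__(self, name, value)


@contextmanager
def mutable_copy(module):
    """Yield a mutable deep copy of `module`; freeze the copy again on exit."""
    clone = copy.deepcopy(module)  # the original module is never modified
    object.__setattr__(clone, "_mutable", True)
    try:
        yield clone
    finally:
        object.__setattr__(clone, "_mutable", False)


net = FrozenModule(weight=1.0)
with mutable_copy(net) as new_net:
    new_net.weight = 2.0  # allowed: new_net is a mutable copy

try:
    new_net.weight = 3.0  # rejected: the copy is frozen again after the block
    frozen_again = False
except AttributeError:
    frozen_again = True
```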
Flattener
class pax.experimental.Flattener(*args, **kwargs)
Flatten PAX modules for better performance.
Example:
>>> net = pax.Linear(3, 3)
>>> opt = opax.adam(1e-3)(net.parameters())
>>> flat_mods = pax.experimental.Flattener(model=net, optimizer=opt)
>>> net, opt = flat_mods.model, flat_mods.optimizer
>>> print(net.summary())
Linear(in_dim=3, out_dim=3, with_bias=True)
>>> print(opt.summary())
chain.<locals>.Chain
├── scale_by_adam.<locals>.ScaleByAdam
│   ├── Linear(in_dim=3, out_dim=3, with_bias=True)
│   └── Linear(in_dim=3, out_dim=3, with_bias=True)
└── scale.<locals>.Scale
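The docstring says only that flattening improves performance. A plausible reading (an assumption, not confirmed by this page) is that storing each tree as a flat list of leaves plus a structure description avoids re-walking nested modules on every call; in JAX itself, `jax.tree_util.tree_flatten` and `tree_unflatten` play this role. The pure-Python sketch below uses hypothetical names (`flatten`, `unflatten`, `FlattenerSketch`) and plain dicts in place of modules.

```python
# Hypothetical sketch of the flatten-once, rebuild-on-access idea; not PAX's code.
from typing import Any, List, Tuple


def flatten(tree: Any) -> Tuple[List[Any], Any]:
    """Split a nested dict/list/tuple into a flat list of leaves and a structure spec."""
    if isinstance(tree, dict):
        keys = sorted(tree)
        children = [flatten(tree[k]) for k in keys]
        leaves = [leaf for sub_leaves, _ in children for leaf in sub_leaves]
        return leaves, ("dict", keys, [spec for _, spec in children])
    if isinstance(tree, (list, tuple)):
        children = [flatten(item) for item in tree]
        leaves = [leaf for sub_leaves, _ in children for leaf in sub_leaves]
        return leaves, (type(tree), None, [spec for _, spec in children])
    return [tree], None  # anything else is a leaf


def unflatten(spec: Any, leaves: List[Any]) -> Any:
    """Rebuild the nested structure described by `spec`, consuming `leaves` in order."""
    it = iter(leaves)

    def build(s):
        if s is None:
            return next(it)
        kind, keys, child_specs = s
        values = [build(c) for c in child_specs]
        return dict(zip(keys, values)) if kind == "dict" else kind(values)

    return build(spec)


class FlattenerSketch:
    """Keep each named tree as (leaves, spec) and rebuild it on attribute access."""

    def __init__(self, **trees: Any):
        self._flat = {name: flatten(tree) for name, tree in trees.items()}

    def __getattr__(self, name: str) -> Any:
        flat = self.__dict__.get("_flat", {})
        if name not in flat:
            raise AttributeError(name)
        leaves, spec = flat[name]
        return unflatten(spec, leaves)


flat_mods = FlattenerSketch(
    model={"weight": 0.5, "bias": 0.1},
    optimizer={"step": 0, "mu": (0.0, 0.0)},
)
model, optimizer = flat_mods.model, flat_mods.optimizer
```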
Graph API
class pax.experimental.graph.Node(parents, fx, value)
A node is an object that stores:
parents, its parent nodes;
fx, a PAX module (or a function);
value, the node's value.
For example:
>>> x = pax.experimental.graph.Node((), lambda x: x, jnp.array(0))
>>> x.parents, x.fx, x.value
((), Lambda..., DeviceArray(0, dtype=int32, weak_type=True))
__rshift__(fn)
Create a new node by applying fn to the node’s value.
Example:
>>> import jax, pax, jax.numpy as jnp
>>> from functools import partial
>>> x = pax.experimental.graph.Node((), lambda x: x, jnp.array(1.))
>>> y = x >> partial(jax.lax.add, 1.)
>>> y.value
DeviceArray(2., dtype=float32, weak_type=True)
__and__(other)
Combine two nodes into a single node whose value is the tuple of their values.
Example:
>>> x = pax.experimental.graph.InputNode(1)
>>> y = pax.experimental.graph.InputNode(2)
>>> z = x & y
>>> z.value
(1, 2)
__or__(other)
Combine two nodes into a multi-argument node: a function applied to it with >> receives the two values as separate positional arguments.
Example:
>>> x = pax.experimental.graph.InputNode(1)
>>> y = pax.experimental.graph.InputNode(2)
>>> z = (x | y) >> jax.lax.add
>>> z.value
DeviceArray(3, dtype=int32, weak_type=True)
binary_ops(fn, other)
Create a new node by applying the binary operator fn to this node's value and the value of other.
Example:
>>> x = pax.experimental.graph.InputNode(1)
>>> y = pax.experimental.graph.InputNode(2)
>>> z = x.binary_ops(jax.lax.sub, y)
>>> z.value
DeviceArray(-1, dtype=int32, weak_type=True)
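To make the difference between `>>`, `&` and `|` concrete, here is a dependency-free toy node that reproduces the values in the examples above. `ToyNode` and `toy_input` are hypothetical, `operator.add` and `operator.sub` stand in for `jax.lax.add` and `jax.lax.sub`, and the `multi_args` flag is only one possible way to tell a tuple node from a multi-argument node; it is not PAX's implementation.

```python
# Hypothetical toy node mimicking the operator semantics above; not PAX's code.
import operator
from typing import Any, Callable, Tuple


class ToyNode:
    """Stores parent nodes, the function that produced it, and its value."""

    def __init__(self, parents: Tuple["ToyNode", ...], fx: Callable, value: Any,
                 multi_args: bool = False):
        self.parents = parents
        self.fx = fx
        self.value = value
        self.multi_args = multi_args  # True: a following >> unpacks the tuple as *args

    def __rshift__(self, fn: Callable) -> "ToyNode":
        value = fn(*self.value) if self.multi_args else fn(self.value)
        return ToyNode((self,), fn, value)

    def __and__(self, other: "ToyNode") -> "ToyNode":
        # x & y: one node whose value is the tuple (x.value, y.value)
        return ToyNode((self, other), lambda a, b: (a, b), (self.value, other.value))

    def __or__(self, other: "ToyNode") -> "ToyNode":
        # x | y: the same tuple, flagged so that >> passes two positional arguments
        return ToyNode((self, other), lambda a, b: (a, b), (self.value, other.value),
                       multi_args=True)

    def binary_ops(self, fn: Callable, other: "ToyNode") -> "ToyNode":
        return ToyNode((self, other), fn, fn(self.value, other.value))


def toy_input(value: Any) -> ToyNode:
    return ToyNode((), lambda v: v, value)


x, y = toy_input(1), toy_input(2)
pair = (x & y).value                          # (1, 2)
pair_len = ((x & y) >> len).value             # len receives the whole tuple
total = ((x | y) >> operator.add).value       # operator.add(1, 2)
diff = x.binary_ops(operator.sub, y).value    # 1 - 2
shifted = (x >> (lambda v: v + 1)).value      # 1 + 1
```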
property shape
Return the shape of the node's value.
Example:
>>> x = pax.experimental.graph.InputNode(jnp.empty((3, 4)))
>>> x.shape
(3, 4)
property dtype
Return the dtype of the node's value.
Example:
>>> x = pax.experimental.graph.InputNode(jnp.empty((1,), dtype=jnp.int32))
>>> x.dtype
dtype('int32')
class pax.experimental.graph.InputNode(value, fx=<function InputNode.<lambda>>)
An InputNode object represents an input argument of a GraphModule.
class pax.experimental.graph.GraphModule(inputs, output, name=None)
A module that uses a directed graph to represent its computation.
pax.experimental.graph.build_graph_module(func)
Build a graph module from a forward function.
Example:
>>> def residual_forward(x):
...     y = x >> pax.Linear(x.shape[-1], x.shape[-1])
...     y >>= jax.nn.relu
...     z = (x | y) >> jax.lax.add
...     return z
...
>>> from pax.experimental.graph import build_graph_module
>>> net = build_graph_module(residual_forward)(jnp.empty((3, 8)))
>>> print(net.summary())
GraphModule
├── Linear(in_dim=8, out_dim=8, with_bias=True)
├── x => relu(x)
├── x => identity(x)
└── x => add(x)
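One way to picture `build_graph_module` (a sketch under assumptions, not PAX's code): calling the forward function once on input nodes records a directed acyclic graph, and the resulting module replays that graph on new inputs by evaluating nodes in topological order. `Node`, `combine`, `ToyGraphModule` and `toy_build_graph_module` are hypothetical; `combine` replaces the `x | y` idiom, and plain lambdas replace `pax.Linear` and `jax.nn.relu`.

```python
# Hypothetical trace-then-replay sketch of a graph module; not PAX's implementation.
import operator
from typing import Any, Callable, List, Optional, Tuple


class Node:
    """Minimal node: parent nodes, a function of their values, and a traced value."""

    def __init__(self, parents: Tuple["Node", ...], fx: Optional[Callable], value: Any):
        self.parents, self.fx, self.value = parents, fx, value

    def __rshift__(self, fn: Callable) -> "Node":
        return Node((self,), fn, fn(self.value))


def combine(fn: Callable, *nodes: Node) -> Node:
    """One node computed from several parents, fn(*parent_values)."""
    return Node(nodes, fn, fn(*(n.value for n in nodes)))


def topological_order(output: Node) -> List[Node]:
    """Depth-first walk from the output so every node comes after its parents."""
    order: List[Node] = []
    seen = set()

    def visit(node: Node) -> None:
        if id(node) in seen:
            return
        seen.add(id(node))
        for parent in node.parents:
            visit(parent)
        order.append(node)

    visit(output)
    return order


class ToyGraphModule:
    def __init__(self, inputs: List[Node], output: Node):
        self.inputs, self.output = inputs, output
        self.order = topological_order(output)

    def __call__(self, *args: Any) -> Any:
        # Replay the recorded graph on new input values.
        values = {id(n): a for n, a in zip(self.inputs, args)}
        for node in self.order:
            if id(node) not in values:
                values[id(node)] = node.fx(*(values[id(p)] for p in node.parents))
        return values[id(self.output)]


def toy_build_graph_module(func: Callable) -> Callable[..., ToyGraphModule]:
    def build(*example_args: Any) -> ToyGraphModule:
        inputs = [Node((), None, a) for a in example_args]
        return ToyGraphModule(inputs, func(*inputs))  # tracing is one call on nodes
    return build


def residual_forward(x: Node) -> Node:
    y = x >> (lambda v: 3.0 * v)              # stand-in for pax.Linear
    y = y >> (lambda v: max(v, 0.0))          # stand-in for jax.nn.relu
    return combine(operator.add, x, y)        # residual connection x + y


net = toy_build_graph_module(residual_forward)(1.0)
```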
Lazy Module
class pax.experimental.LazyModule(*, training=True, name=None)
A lazy module is a module that only creates submodules when needed.
Example:
>>> from dataclasses import dataclass
>>> @dataclass
... class MLP(pax.experimental.LazyModule):
...     features: list
...
...     def __call__(self, x):
...         sizes = zip(self.features[:-1], self.features[1:])
...         for i, (in_dim, out_dim) in enumerate(sizes):
...             fc = self.get_or_create(f"fc_{i}", lambda: pax.Linear(in_dim, out_dim))
...             x = jax.nn.relu(fc(x))
...         return x
...
>>> mlp, _ = MLP([1, 2, 3]) % jnp.ones((1, 1))
>>> print(mlp.summary())
MLP(features=[1, 2, 3])
├── Linear(in_dim=1, out_dim=2, with_bias=True)
└── Linear(in_dim=2, out_dim=3, with_bias=True)
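The caching behind `get_or_create` can be sketched in plain Python. `ToyLazyModule`, `ToyMLP` and `ToyLinear` are hypothetical stand-ins; the sketch mutates the module in place for brevity, whereas the PAX example above obtains the updated module from the `%` operator.

```python
# Hypothetical sketch of lazy submodule creation; not PAX's implementation.
from typing import Any, Callable, Dict, List


class ToyLinear:
    """Placeholder layer mapping a length-in_dim list to a length-out_dim list."""

    def __init__(self, in_dim: int, out_dim: int):
        self.in_dim, self.out_dim = in_dim, out_dim

    def __call__(self, x: List[float]) -> List[float]:
        assert len(x) == self.in_dim
        return [sum(x)] * self.out_dim  # stands in for x @ W + b


class ToyLazyModule:
    """Submodules are created on first request and cached by name."""

    def __init__(self):
        self.submodules: Dict[str, Any] = {}
        self.num_created = 0

    def get_or_create(self, name: str, create_fn: Callable[[], Any]) -> Any:
        if name not in self.submodules:
            self.submodules[name] = create_fn()  # runs only the first time
            self.num_created += 1
        return self.submodules[name]


class ToyMLP(ToyLazyModule):
    def __init__(self, features: List[int]):
        super().__init__()
        self.features = features

    def __call__(self, x: List[float]) -> List[float]:
        sizes = zip(self.features[:-1], self.features[1:])
        for i, (in_dim, out_dim) in enumerate(sizes):
            # create_fn is called immediately, so capturing the loop variables is safe
            fc = self.get_or_create(f"fc_{i}", lambda: ToyLinear(in_dim, out_dim))
            x = [max(v, 0.0) for v in fc(x)]  # relu
        return x


mlp = ToyMLP([1, 2, 3])
first = mlp([1.0])   # creates fc_0 and fc_1
second = mlp([1.0])  # reuses the cached layers
```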
Mixed Precision
pax.experimental.default_mp_policy(module)
A default mixed-precision policy:
Linear layers are computed in half precision.
Normalization layers are kept in full precision.
Example:
>>> net = pax.Sequential(pax.Linear(3, 3), pax.BatchNorm1D(3))
>>> net = net.apply(pax.experimental.default_mp_policy)
>>> print(net.summary())
Sequential
├── Linear(in_dim=3, out_dim=3, with_bias=True, mp_policy=FHF)
└── BatchNorm1D(num_channels=3, ..., mp_policy=FFF)
Return type: ~T
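A toy version of the policy rule above, assuming the three letters in `mp_policy=FHF` denote parameter, compute and output dtypes (F = full, H = half) as in jmp's policies. `ToyPolicy`, `toy_default_policy` and `to_float16` are hypothetical, and leaving other layers unchanged is an assumption. The last helper uses Python's `struct` `'e'` (IEEE 754 half) format to show the limited precision and range that are one common reason to keep normalization layers in full precision.

```python
# Hypothetical sketch of a per-layer mixed-precision rule; not PAX's implementation.
import struct
from typing import NamedTuple, Optional


class ToyPolicy(NamedTuple):
    param_dtype: str
    compute_dtype: str
    output_dtype: str

    def short(self) -> str:
        # Abbreviate each dtype: "F" = float32 (full), "H" = float16 (half)
        return "".join("F" if d == "float32" else "H" for d in self)


def toy_default_policy(layer_name: str) -> Optional[ToyPolicy]:
    """Half-precision compute for Linear, full precision for normalization layers,
    None (leave unchanged) for everything else."""
    if layer_name == "Linear":
        return ToyPolicy("float32", "float16", "float32")
    if "Norm" in layer_name:
        return ToyPolicy("float32", "float32", "float32")
    return None


def to_float16(x: float) -> float:
    """Round a Python float to IEEE 754 half precision and back."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


policies = {name: toy_default_policy(name) for name in ["Linear", "BatchNorm1D", "relu"]}
```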
pax.experimental.apply_scaled_gradients(model, optimizer, loss_scale, grads)
Update the model, optimizer, and loss scale from gradients of the scaled loss.
Example:
>>> import jmp
>>> from pax.experimental import apply_scaled_gradients
>>> net = pax.Linear(2, 2)
>>> opt = opax.adam(1e-4)(net.parameters())
>>> loss_scale = jmp.DynamicLossScale(jmp.half_dtype()(2**15))
>>> grads = net.parameters()
>>> net, opt, loss_scale = apply_scaled_gradients(net, opt, loss_scale, grads)
>>> print(loss_scale.loss_scale)
32770.0
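The example above prints an unchanged loss scale after one finite step (32768 displayed in half precision). The sketch below models the dynamic loss-scaling rule this function is assumed to follow, as in jmp's `DynamicLossScale`: unscale the gradients; if any is non-finite, skip the update and divide the scale by `factor`; otherwise apply the update and multiply the scale by `factor` after `period` consecutive finite steps. All names are hypothetical, plain SGD replaces the optimizer, and the defaults `period=2000`, `factor=2.0` and `min_loss_scale=1.0` are assumptions.

```python
# Hypothetical sketch of dynamic loss scaling; not PAX's or jmp's implementation.
import math
from dataclasses import dataclass, replace
from typing import List, Tuple


@dataclass(frozen=True)
class ToyDynamicLossScale:
    loss_scale: float
    counter: int = 0
    period: int = 2000
    factor: float = 2.0
    min_loss_scale: float = 1.0

    def adjust(self, grads_finite: bool) -> "ToyDynamicLossScale":
        if not grads_finite:
            # Overflow: shrink the scale (not below the minimum) and reset the counter.
            new_scale = max(self.min_loss_scale, self.loss_scale / self.factor)
            return replace(self, loss_scale=new_scale, counter=0)
        if self.counter + 1 == self.period:
            # `period` finite steps in a row: grow the scale.
            return replace(self, loss_scale=self.loss_scale * self.factor, counter=0)
        return replace(self, counter=self.counter + 1)


def toy_apply_scaled_gradients(
    params: List[float], loss_scale: ToyDynamicLossScale,
    scaled_grads: List[float], lr: float = 0.1,
) -> Tuple[List[float], ToyDynamicLossScale]:
    grads = [g / loss_scale.loss_scale for g in scaled_grads]  # unscale
    finite = all(math.isfinite(g) for g in grads)
    if finite:
        # Plain SGD stands in for the optimizer; the update is skipped on overflow.
        params = [p - lr * g for p, g in zip(params, grads)]
    return params, loss_scale.adjust(finite)


scale = ToyDynamicLossScale(loss_scale=2.0 ** 15)
params, scale = toy_apply_scaled_gradients([1.0, -1.0], scale, [0.5 * 2 ** 15, -0.5 * 2 ** 15])
skipped, halved = toy_apply_scaled_gradients(params, scale, [math.inf, 0.0])
```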
Save and load weights
pax.experimental.save_weights_to_dict(module)
Save module weights to a dictionary.
>>> net = pax.Sequential(pax.Linear(1, 2), jax.nn.relu, pax.Linear(2, 3))
>>> weights = pax.experimental.save_weights_to_dict(net)
>>> weights
{'modules': ({'weight': ..., 'bias': ...}, {}, {'weight': ..., 'bias': ...})}
Return type: Dict[str, Any]
pax.experimental.load_weights_from_dict(module, state_dict)
Load module weights from a dictionary.
>>> a = pax.Sequential(pax.Linear(1, 2), jax.nn.relu, pax.Linear(2, 3))
>>> weights = pax.experimental.save_weights_to_dict(a)
>>> b = pax.Sequential(pax.Linear(1, 2), jax.nn.relu, pax.Linear(2, 3))
>>> b = pax.experimental.load_weights_from_dict(b, weights)
>>> assert a == b
Return type: ~T
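The round trip performed by `save_weights_to_dict` and `load_weights_from_dict` can be illustrated with a hypothetical toy module tree. The nested `{'modules': (...)}` layout mirrors the dictionary printed in the save example; `ToyLinear`, `ToySequential`, `save_weights` and `load_weights` are illustrative names, not PAX's implementation.

```python
# Hypothetical sketch of a weights-to-dict round trip; not PAX's implementation.
from typing import Any, Dict


class ToyLinear:
    def __init__(self, weight: Any, bias: Any):
        self.weight, self.bias = weight, bias


class ToySequential:
    def __init__(self, *modules: Any):
        self.modules = tuple(modules)


def save_weights(module: Any) -> Dict[str, Any]:
    """Mirror the module tree as nested dicts/tuples holding only weights."""
    if isinstance(module, ToySequential):
        return {"modules": tuple(save_weights(m) for m in module.modules)}
    if isinstance(module, ToyLinear):
        return {"weight": module.weight, "bias": module.bias}
    return {}  # weightless entries, such as a plain relu function


def load_weights(module: Any, state: Dict[str, Any]) -> Any:
    """Return a new module with the same structure, filled from `state`."""
    if isinstance(module, ToySequential):
        pairs = zip(module.modules, state["modules"])
        return ToySequential(*(load_weights(m, s) for m, s in pairs))
    if isinstance(module, ToyLinear):
        return ToyLinear(state["weight"], state["bias"])
    return module


def relu(v: float) -> float:
    return max(v, 0.0)


a = ToySequential(ToyLinear([[0.1, 0.2]], [0.0, 0.0]), relu, ToyLinear([[1.0], [2.0]], [0.5]))
weights = save_weights(a)
b = load_weights(ToySequential(ToyLinear(None, None), relu, ToyLinear(None, None)), weights)
```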