Module Transformations

A module transformation is a pure function that takes a PAX module as input and returns a PAX module.
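A minimal sketch of what "pure" means here, using a hypothetical frozen-dataclass `ToyModule` and a hypothetical transformation `toy_scale_leaves` (neither is part of PAX). The transformation builds and returns a new module instead of mutating its input, so transformations compose like ordinary functions.

```python
from dataclasses import dataclass, field, replace
from typing import Dict


@dataclass(frozen=True)
class ToyModule:
    """Hypothetical stand-in for a module: named float leaves, immutable."""
    leaves: Dict[str, float] = field(default_factory=dict)


def toy_scale_leaves(mod: ToyModule, factor: float) -> ToyModule:
    """Hypothetical pure transformation: returns a new module, never mutates `mod`."""
    return replace(mod, leaves={name: v * factor for name, v in mod.leaves.items()})


net = ToyModule({"w": 2.0, "b": -1.0})
half = toy_scale_leaves(net, 0.5)
print(net.leaves)   # {'w': 2.0, 'b': -1.0}  (input unchanged)
print(half.leaves)  # {'w': 1.0, 'b': -0.5}
# Transformations compose like ordinary functions.
print(toy_scale_leaves(toy_scale_leaves(net, 2.0), 0.5) == net)  # True
```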

update_parameters(mod, *, params)

Return a copy of mod whose trainable parameters are replaced by those in params.

enable_train_mode(mod)

Return a module in training mode.

enable_eval_mode(mod)

Return a module in evaluation mode.

select_parameters(mod)

Return a copy of the module that keeps only its PARAMETER (trainable) leaves.

freeze_parameters(mod)

Return a copy of the module in which all trainable parameters are converted to non-trainable states.

unfreeze_parameters(mod, *, origin)

Return a copy of mod in which frozen parameters are converted back to trainable parameters; origin, the module before freezing, determines which leaves are trainable.

apply_mp_policy(module, mp_policy)

Create a mixed-precision module.

unwrap_mp_policy(module)

Unwrap a mixed-precision module to recreate the original module.

update_parameters

pax.update_parameters(mod, *, params)

Return a copy of mod whose trainable parameters are replaced by those in params.

Return type: T (same type as mod)
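A toy sketch of the update step, assuming a module can be modeled as a hypothetical dict mapping each leaf name to a `(kind, value)` pair, and that params mirrors the structure of mod. Only PARAMETER leaves are taken from params; non-trainable states keep their values from mod. `toy_update_parameters` is illustrative only, not PAX's implementation.

```python
from typing import Dict, Optional, Tuple

PARAMETER, STATE = "PARAMETER", "STATE"
# Hypothetical toy module: leaf name -> (kind, value).
ToyModule = Dict[str, Tuple[str, Optional[float]]]


def toy_update_parameters(mod: ToyModule, *, params: ToyModule) -> ToyModule:
    """Return a new module whose PARAMETER leaves come from `params`."""
    return {
        name: (kind, params[name][1] if kind == PARAMETER else value)
        for name, (kind, value) in mod.items()
    }


net = {"w": (PARAMETER, 1.0), "b": (PARAMETER, 0.0), "count": (STATE, 7.0)}
# `new_params` mirrors `net`; its non-parameter leaves are ignored.
new_params = {"w": (PARAMETER, 0.9), "b": (PARAMETER, 0.1), "count": (STATE, None)}
net = toy_update_parameters(net, params=new_params)
print(net)  # {'w': ('PARAMETER', 0.9), 'b': ('PARAMETER', 0.1), 'count': ('STATE', 7.0)}
```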

enable_train_mode

pax.enable_train_mode(mod)

Return a module in training mode.

Return type: T (same type as mod)

enable_eval_mode

pax.enable_eval_mode(mod)

Return a module in evaluation mode.

Return type: T (same type as mod)
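enable_train_mode and enable_eval_mode are mirror images, so one toy sketch covers both. It assumes a hypothetical `ToyModule` that holds a `training` flag and a tuple of submodules, and assumes the mode is applied to every submodule rather than only the root. The `toy_*` functions are illustrative, not PAX's implementation.

```python
from dataclasses import dataclass, replace
from typing import Tuple


@dataclass(frozen=True)
class ToyModule:
    """Hypothetical toy module: a mode flag plus nested submodules."""
    name: str
    training: bool = True
    children: Tuple["ToyModule", ...] = ()


def _set_mode(mod: ToyModule, training: bool) -> ToyModule:
    # Rebuild the tree so every submodule, not only the root, gets the new mode.
    return replace(
        mod,
        training=training,
        children=tuple(_set_mode(c, training) for c in mod.children),
    )


def toy_enable_train_mode(mod: ToyModule) -> ToyModule:
    return _set_mode(mod, True)


def toy_enable_eval_mode(mod: ToyModule) -> ToyModule:
    return _set_mode(mod, False)


net = ToyModule("mlp", children=(ToyModule("linear"), ToyModule("dropout")))
eval_net = toy_enable_eval_mode(net)
print([c.training for c in eval_net.children])  # [False, False]
print(net.training)                              # True: the input is unchanged
```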

select_parameters

pax.select_parameters(mod)

Return a copy of the module that keeps only its PARAMETER (trainable) leaves.

Return type: T (same type as mod)
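A toy sketch of selecting parameters, using a hypothetical dict mapping leaf names to `(kind, value)` pairs. In this sketch, non-parameter leaves are replaced with None rather than dropped, so the result keeps the module's structure; that design choice is an assumption of the toy, not a statement about PAX internals.

```python
from typing import Dict, Optional, Tuple

PARAMETER, STATE = "PARAMETER", "STATE"
ToyModule = Dict[str, Tuple[str, Optional[float]]]  # hypothetical: name -> (kind, value)


def toy_select_parameters(mod: ToyModule) -> ToyModule:
    """Keep PARAMETER leaves; replace every other leaf value with None."""
    return {
        name: (kind, value if kind == PARAMETER else None)
        for name, (kind, value) in mod.items()
    }


net = {"w": (PARAMETER, 1.0), "count": (STATE, 7.0)}
params = toy_select_parameters(net)
print(params)  # {'w': ('PARAMETER', 1.0), 'count': ('STATE', None)}
```

Keeping the same structure means the selected parameters can line up leaf by leaf with the full module, for example when gradients are computed only with respect to the parameters.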

freeze_parameters

pax.freeze_parameters(mod)

Return a copy of the module in which all trainable parameters are converted to non-trainable states.

Return type: T (same type as mod)
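A toy sketch of freezing, again modeling a module as a hypothetical dict of leaf name to `(kind, value)`. Every PARAMETER leaf is relabeled as STATE while its value is kept, and the input is not modified. `toy_freeze_parameters` is illustrative only.

```python
from typing import Dict, Optional, Tuple

PARAMETER, STATE = "PARAMETER", "STATE"
ToyModule = Dict[str, Tuple[str, Optional[float]]]  # hypothetical: name -> (kind, value)


def toy_freeze_parameters(mod: ToyModule) -> ToyModule:
    """Return a new module in which every PARAMETER leaf becomes a STATE leaf."""
    return {
        name: (STATE if kind == PARAMETER else kind, value)
        for name, (kind, value) in mod.items()
    }


net = {"w": (PARAMETER, 1.0), "count": (STATE, 7.0)}
frozen = toy_freeze_parameters(net)
print(frozen)  # {'w': ('STATE', 1.0), 'count': ('STATE', 7.0)}
```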

unfreeze_parameters

pax.unfreeze_parameters(mod, *, origin)

Return a copy of mod in which frozen parameters are converted back to trainable parameters; origin, the module before freezing, determines which leaves are trainable.

Return type: T (same type as mod)
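A toy sketch of why origin is needed: once frozen, a former parameter looks exactly like any other state, so the frozen module alone cannot say which leaves to make trainable again. Assuming the same hypothetical dict-of-`(kind, value)` model, the sketch takes leaf kinds from origin and current values from mod. `toy_unfreeze_parameters` is illustrative only.

```python
from typing import Dict, Optional, Tuple

PARAMETER, STATE = "PARAMETER", "STATE"
ToyModule = Dict[str, Tuple[str, Optional[float]]]  # hypothetical: name -> (kind, value)


def toy_unfreeze_parameters(mod: ToyModule, *, origin: ToyModule) -> ToyModule:
    """Return a new module with `mod`'s values and `origin`'s leaf kinds."""
    return {name: (origin[name][0], value) for name, (_, value) in mod.items()}


origin = {"w": (PARAMETER, 1.0), "count": (STATE, 7.0)}
# A frozen copy of `origin` whose state has since been updated.
frozen = {"w": (STATE, 1.0), "count": (STATE, 8.0)}
restored = toy_unfreeze_parameters(frozen, origin=origin)
print(restored)  # {'w': ('PARAMETER', 1.0), 'count': ('STATE', 8.0)}
```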

apply_mp_policy

pax.apply_mp_policy(module, mp_policy)

Create a mixed-precision module.

Create a subclass on the fly to enforce the mixed-precision policy.

>>> import jmp
>>> import pax
>>> mp_policy = jmp.get_policy("params=float32,compute=float16,output=float32")
>>> net = pax.Linear(3, 3)
>>> net = pax.apply_mp_policy(net, mp_policy)
>>> print(net.summary())
Linear(in_dim=3, out_dim=3, with_bias=True, mp_policy=FHF)

Return type: T (same type as module)
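A toy sketch of the "subclass on the fly" technique in plain Python, not PAX's implementation. Assumptions: the jmp policy is replaced by three hypothetical cast functions (param, compute, output), precision is emulated by rounding Python floats through `struct` ('e' is IEEE half precision, 'f' is single precision), and the arithmetic itself still runs in double precision. `ToyLinear` and `toy_apply_mp_policy` are hypothetical names.

```python
import copy
import struct


def to_float16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


def to_float32(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 single-precision value."""
    return struct.unpack("f", struct.pack("f", x))[0]


class ToyLinear:
    """Hypothetical toy layer y = w * x + b; every attribute is a float parameter."""

    def __init__(self, w: float, b: float):
        self.w, self.b = w, b

    def __call__(self, x: float) -> float:
        return self.w * x + self.b


def toy_apply_mp_policy(module, param_cast, compute_cast, output_cast):
    """Return a copy of `module` whose class is a subclass created on the fly."""
    base_cls = type(module)

    def __call__(self, x):
        casted = copy.copy(self)
        # Round parameters and input to compute precision before the forward pass.
        vars(casted).update({k: compute_cast(v) for k, v in vars(self).items()})
        y = base_cls.__call__(casted, compute_cast(x))
        return output_cast(y)  # round the result to output precision

    mp_cls = type(f"MixedPrecision{base_cls.__name__}", (base_cls,),
                  {"__call__": __call__, "_original_cls": base_cls})
    new_module = copy.copy(module)  # the input module is left untouched
    vars(new_module).update({k: param_cast(v) for k, v in vars(module).items()})
    new_module.__class__ = mp_cls
    return new_module


net = ToyLinear(0.1, 0.2)
# Mirrors "params=float32,compute=float16,output=float32".
mp_net = toy_apply_mp_policy(net, to_float32, to_float16, to_float32)
print(type(mp_net).__name__)                          # MixedPrecisionToyLinear
print(isinstance(mp_net, ToyLinear), type(net).__name__)  # True ToyLinear
```

Because the wrapper is a subclass, the result still passes isinstance checks against the original class, and remembering that class in `_original_cls` is one way an unwrap step could restore it.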

unwrap_mp_policy

pax.unwrap_mp_policy(module)

Unwrap a mixed-precision module to recreate the original module.

>>> import jmp
>>> import pax
>>> mp_policy = jmp.get_policy("params=float32,compute=float16,output=float32")
>>> net = pax.Linear(3, 3)
>>> net = pax.apply_mp_policy(net, mp_policy)
>>> print(net.summary())
Linear(in_dim=3, out_dim=3, with_bias=True, mp_policy=FHF)
>>> net = pax.unwrap_mp_policy(net)
>>> print(net.summary())
Linear(in_dim=3, out_dim=3, with_bias=True)

Return type: T (same type as module)
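A toy sketch of unwrapping, assuming (hypothetically) that the on-the-fly mixed-precision class remembers its base class in a `_original_cls` attribute. Unwrapping then copies the module and switches the copy back to that base class. `ToyLinear`, `MixedPrecisionToyLinear`, and `toy_unwrap_mp_policy` are illustrative names, not PAX's API, and the ValueError is a behavior of this toy only.

```python
import copy


class ToyLinear:
    """Hypothetical toy layer y = w * x + b."""

    def __init__(self, w: float, b: float):
        self.w, self.b = w, b

    def __call__(self, x: float) -> float:
        return self.w * x + self.b


# Stand-in for a class created on the fly by a mixed-precision transformation:
# it subclasses the original class and remembers it in `_original_cls`.
MixedPrecisionToyLinear = type(
    "MixedPrecisionToyLinear", (ToyLinear,), {"_original_cls": ToyLinear}
)


def toy_unwrap_mp_policy(module):
    """Return a copy of `module` whose class is the original, unwrapped class."""
    original_cls = getattr(type(module), "_original_cls", None)
    if original_cls is None:
        raise ValueError("module is not a mixed-precision module")
    new_module = copy.copy(module)  # leave the wrapped module untouched
    new_module.__class__ = original_cls
    return new_module


wrapped = MixedPrecisionToyLinear(2.0, 1.0)
net = toy_unwrap_mp_policy(wrapped)
print(type(net).__name__, net(3.0))  # ToyLinear 7.0
```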