Module Transformations
A module transformation is a pure function that takes PAX modules as input and returns PAX modules as output.
update_parameters | Return a module that uses the trainable parameters in params. |
enable_train_mode | Return a module in training mode. |
enable_eval_mode | Return a module in evaluation mode. |
| Select PARAMETER leaves only. |
freeze_parameters | Return a copy of the module with all trainable parameters converted to non-trainable states. |
unfreeze_parameters | Return a copy of the module with frozen parameters converted back to trainable parameters. |
apply_mp_policy | Create a mixed-precision module. |
unwrap_mp_policy | Unwrap a mixed-precision module to recreate the original module. |
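As a rough illustration of what "pure function" means for the transformations listed above, the following sketch uses a toy, immutable stand-in for a module. `ToyDropout`, `toy_enable_eval_mode`, and `toy_enable_train_mode` are hypothetical names invented for this example; they are not part of PAX and do not reflect how PAX implements its modules.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyDropout:
    """Hypothetical stand-in for a module (not a PAX class)."""

    rate: float
    training: bool = True


def toy_enable_eval_mode(module: ToyDropout) -> ToyDropout:
    # Build and return a new object; the input module is never mutated.
    return replace(module, training=False)


def toy_enable_train_mode(module: ToyDropout) -> ToyDropout:
    # Same idea in the other direction.
    return replace(module, training=True)


train_net = ToyDropout(rate=0.1)
eval_net = toy_enable_eval_mode(train_net)

# The original module is unchanged; the transformed one is a separate object.
print(train_net.training, eval_net.training)
```

Because the input is left untouched, the caller decides which version to keep, which is the usual pattern of rebinding the name to the returned module (for example `net = pax.enable_eval_mode(net)`).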
update_parameters
enable_train_mode
enable_eval_mode
freeze_parameters
unfreeze_parameters
apply_mp_policy
pax.apply_mp_policy(module, mp_policy)
Create a mixed-precision module. A subclass is created on the fly to enforce the mixed-precision policy.
>>> import jmp
>>> import pax
>>> mp_policy = jmp.get_policy("params=float32,compute=float16,output=float32")
>>> net = pax.Linear(3, 3)
>>> net = pax.apply_mp_policy(net, mp_policy)
>>> print(net.summary())
Linear(in_dim=3, out_dim=3, with_bias=True, mp_policy=FHF)
- Return type: T
unwrap_mp_policy
pax.unwrap_mp_policy(module)
Unwrap a mixed-precision module to recreate the original module.
>>> import jmp
>>> import pax
>>> mp_policy = jmp.get_policy("params=float32,compute=float16,output=float32")
>>> net = pax.Linear(3, 3)
>>> net = pax.apply_mp_policy(net, mp_policy)
>>> print(net.summary())
Linear(in_dim=3, out_dim=3, with_bias=True, mp_policy=FHF)
>>> net = pax.unwrap_mp_policy(net)
>>> print(net.summary())
Linear(in_dim=3, out_dim=3, with_bias=True)
- Return type: T
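The "subclass on the fly" idea behind apply_mp_policy and its inverse can be sketched in plain Python. This is only a conceptual sketch under stated assumptions, not PAX's implementation: `ToyLinear`, `toy_apply_policy`, `toy_unwrap_policy`, and the `_toy_base` attribute are hypothetical names, and rounding the input stands in for casting to a lower-precision dtype.

```python
import copy


class ToyLinear:
    """Hypothetical stand-in for a module (not a PAX class)."""

    def __init__(self, scale):
        self.scale = scale

    def __call__(self, x):
        return self.scale * x


def toy_apply_policy(module, digits):
    """Return a copy of `module` whose class is a dynamically created subclass."""
    base = type(module)

    def call(self, x):
        # Crude stand-in for "cast to compute precision": round the input,
        # then delegate to the original class's __call__.
        return base.__call__(self, round(x, digits))

    # Create the subclass at runtime and remember the original class on it.
    wrapped_cls = type(base.__name__, (base,), {"__call__": call, "_toy_base": base})
    new_module = copy.copy(module)  # leave the input module untouched
    new_module.__class__ = wrapped_cls
    return new_module


def toy_unwrap_policy(module):
    """Return a copy of `module` restored to its original class."""
    new_module = copy.copy(module)
    new_module.__class__ = type(module)._toy_base
    return new_module


net = ToyLinear(2.0)
mp_net = toy_apply_policy(net, digits=0)
restored = toy_unwrap_policy(mp_net)

print(type(mp_net) is ToyLinear, isinstance(mp_net, ToyLinear), type(restored) is ToyLinear)
```

In this sketch the wrapped object is still an instance of the original class, so code that checks `isinstance` keeps working, while unwrapping only needs to swap the class back on a copy.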