Common Modules

Module | Description
pax.Linear | A linear transformation is applied over the last dimension of the input.
pax.Conv1D | 1D Convolution Module.
pax.Conv2D | 2D Convolution Module.
pax.Conv1DTranspose | 1D Convolution Transpose Module.
pax.Conv2DTranspose | 2D Convolution Transpose Module.
pax.BatchNorm1D | The 1D version of BatchNorm.
pax.BatchNorm2D | The 2D version of BatchNorm.
pax.LayerNorm | LayerNorm module.
pax.GroupNorm | Group normalization module.
pax.Sequential | Execute layers in order.
pax.VanillaRNN | Basic recurrent neural network.
pax.LSTM | Long Short Term Memory (LSTM) RNN module.
pax.GRU | This class implements the “fully gated unit” GRU.
pax.MultiHeadAttention | Multi-headed attention mechanism.
pax.Identity | Identity function as a module.
pax.avg_pool | Average pool.
pax.max_pool | Max pool.
Linear

class pax.Linear(in_dim, out_dim, with_bias=True, w_init=None, b_init=None, *, rng_key=None, name=None)
A linear transformation is applied over the last dimension of the input.

__init__(in_dim, out_dim, with_bias=True, w_init=None, b_init=None, *, rng_key=None, name=None)
Parameters:
- in_dim (int) – the number of input features.
- out_dim (int) – the number of output features.
- with_bias (bool) – whether to add a bias to the output (default: True).
- w_init – initializer function for the weight matrix.
- b_init – initializer function for the bias.
- rng_key (Union[Any, ndarray, None]) – the key used to generate the initial parameters.
- name (Optional[str]) – module name.
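The following minimal plain-JAX sketch shows the computation a linear layer performs over the last dimension. The function name `linear` and the normal initializer are illustrative assumptions. They are not PAX's implementation or its default `w_init`.

```python
import jax
import jax.numpy as jnp


def linear(x, w, b=None):
    """Hypothetical helper: y = x @ w + b over the last dimension of x."""
    y = x @ w  # (..., in_dim) @ (in_dim, out_dim) -> (..., out_dim)
    if b is not None:
        y = y + b  # the bias broadcasts over all leading dimensions
    return y


in_dim, out_dim = 4, 3
# Illustrative initialization only; PAX's default w_init may differ.
w = jax.random.normal(jax.random.PRNGKey(0), (in_dim, out_dim)) / jnp.sqrt(in_dim)
b = jnp.zeros((out_dim,))

x = jnp.ones((2, 5, in_dim))  # any number of leading dimensions is allowed
y = linear(x, w, b)
print(y.shape)  # (2, 5, 3)
```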
Dropout

class pax.Dropout(dropout_rate, *, name=None)
A Dropout module.
The Dropout module stores an internal state, rng_key. It refreshes rng_key whenever a forward pass is executed.
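The refresh-on-every-call behavior can be sketched in plain JAX. This sketch assumes the common "inverted dropout" convention, where survivors are rescaled by 1 / (1 - dropout_rate). The function `dropout` and the explicit key-splitting loop are illustrative, not PAX's code.

```python
import jax
import jax.numpy as jnp


def dropout(rng_key, dropout_rate, x):
    """Hypothetical helper: zero each element with probability dropout_rate."""
    keep_rate = 1.0 - dropout_rate
    mask = jax.random.bernoulli(rng_key, keep_rate, x.shape)
    # Rescale the kept elements so the expected value of the output equals x.
    return jnp.where(mask, x / keep_rate, 0.0)


rng_key = jax.random.PRNGKey(42)
x = jnp.ones((1000,))
outputs = []
for _ in range(2):
    # Refresh the stored key before each forward pass, so every call
    # draws a different dropout mask.
    rng_key, sub_key = jax.random.split(rng_key)
    outputs.append(dropout(sub_key, 0.5, x))
```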
Embed

class pax.Embed(vocab_size, embed_dim, w_init=None, *, rng_key=None, name=None)
Embed module maps integer values to real vectors. The embedded vectors are trainable.

__init__(vocab_size, embed_dim, w_init=None, *, rng_key=None, name=None)
An embed module.
Parameters:
- vocab_size (int) – the number of embedded vectors.
- embed_dim (int) – the size of embedded vectors.
- w_init (Optional[Callable]) – weight initializer. Default: truncated_normal.
- name (Optional[str]) – module name.
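Conceptually, the lookup is plain integer indexing into a trainable `[vocab_size, embed_dim]` table. The sketch below assumes a truncated-normal table. The exact initializer scale PAX uses is not stated here, so the bounds below are illustrative.

```python
import jax
import jax.numpy as jnp

vocab_size, embed_dim = 10, 4
# One trainable row per integer id; the initializer bounds are illustrative.
table = jax.random.truncated_normal(
    jax.random.PRNGKey(0), -2.0, 2.0, (vocab_size, embed_dim)
)


def embed(table, ids):
    """Hypothetical helper: map ids of any shape to ids.shape + (embed_dim,)."""
    return table[ids]


ids = jnp.array([[1, 2, 3], [3, 0, 9]])
vectors = embed(table, ids)
print(vectors.shape)  # (2, 3, 4)
```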
Convolution

Conv1D

class pax.Conv1D(in_features, out_features, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NWC', feature_group_count=1, *, name=None, rng_key=None)
1D Convolution Module.

__init__(in_features, out_features, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NWC', feature_group_count=1, *, name=None, rng_key=None)
Initializes the module.
(Haiku documentation)
Parameters:
- in_features (int) – Number of input channels.
- out_features (int) – Number of output channels.
- kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length 1.
- stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length 1. Defaults to 1.
- rate (Union[int, Sequence[int]]) – Optional kernel dilation rate. Either an integer or a sequence of length 1. A rate of 1 corresponds to standard ND convolution; rate > 1 corresponds to dilated convolution. Defaults to 1.
- padding (Union[str, Sequence[Tuple[int, int]]]) – Optional padding. Either VALID, SAME, or a sequence of Tuple[int, int] giving the padding before and after each spatial dimension. Defaults to SAME.
- with_bias (bool) – Whether to add a bias. Defaults to True.
- w_init (Optional[Callable]) – Optional weight initialization. By default, truncated normal.
- b_init (Optional[Callable]) – Optional bias initialization. By default, zeros.
- data_format (str) – The data format of the input. Either NWC or NCW. By default, NWC.
- feature_group_count (int) – Optional number of groups in group convolution. The default value of 1 corresponds to normal dense convolution. If a higher value is used, convolutions are applied separately to that many groups, then stacked together.
- name (Optional[str]) – The name of the module.
- rng_key (Union[Any, ndarray, None]) – The random key.
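Because the parameters follow Haiku's convolution, the computation can be sketched with `jax.lax.conv_general_dilated`. This sketch assumes a channels-last (`NWC`) input and a `(width, in_features, out_features)` kernel layout. It illustrates the `stride`, `rate` and `SAME` padding semantics rather than reproducing PAX's code.

```python
import jax
import jax.numpy as jnp

batch, width, in_features, out_features = 2, 10, 3, 5
kernel_size, stride, rate = 3, 2, 2

x = jnp.ones((batch, width, in_features))  # NWC input
# Kernel layout "WIO": (kernel width, input channels, output channels).
w = jax.random.normal(jax.random.PRNGKey(0), (kernel_size, in_features, out_features))
b = jnp.zeros((out_features,))

y = jax.lax.conv_general_dilated(
    x,
    w,
    window_strides=(stride,),
    padding="SAME",
    rhs_dilation=(rate,),  # rate > 1 spreads the kernel taps apart (dilation)
    dimension_numbers=("NWC", "WIO", "NWC"),
) + b

# With SAME padding the output width is ceil(width / stride) = 5.
print(y.shape)  # (2, 5, 5)
```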
Conv2D

class pax.Conv2D(in_features, out_features, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', feature_group_count=1, *, name=None, rng_key=None)
2D Convolution Module.

__init__(in_features, out_features, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', feature_group_count=1, *, name=None, rng_key=None)
Initializes the module.
(Haiku documentation)
Parameters:
- in_features (int) – Number of input channels.
- out_features (int) – Number of output channels.
- kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length 2.
- stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length 2. Defaults to 1.
- rate (Union[int, Sequence[int]]) – Optional kernel dilation rate. Either an integer or a sequence of length 2. A rate of 1 corresponds to standard ND convolution; rate > 1 corresponds to dilated convolution. Defaults to 1.
- padding (Union[str, Sequence[Tuple[int, int]]]) – Optional padding. Either VALID, SAME, or a sequence of Tuple[int, int] giving the padding before and after each spatial dimension. Defaults to SAME.
- with_bias (bool) – Whether to add a bias. Defaults to True.
- w_init (Optional[Callable]) – Optional weight initialization. By default, truncated normal.
- b_init (Optional[Callable]) – Optional bias initialization. By default, zeros.
- data_format (str) – The data format of the input. Either NHWC or NCHW. By default, NHWC.
- feature_group_count (int) – Optional number of groups in group convolution. The default value of 1 corresponds to normal dense convolution. If a higher value is used, convolutions are applied separately to that many groups, then stacked together.
- name (Optional[str]) – The name of the module.
- rng_key (Union[Any, ndarray, None]) – The random key.
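The following small sketch makes `feature_group_count` concrete by calling `jax.lax.conv_general_dilated` directly. It assumes an `NHWC` input and an `HWIO` kernel. With `groups` groups, each group sees only `in_features // groups` input channels, so the kernel's input dimension shrinks accordingly.

```python
import jax
import jax.numpy as jnp

in_features, out_features, groups = 4, 6, 2
x = jnp.ones((1, 8, 8, in_features))  # NHWC input
# HWIO kernel: the I dimension is in_features // groups, and
# out_features must be divisible by groups.
w = jnp.ones((3, 3, in_features // groups, out_features))

y = jax.lax.conv_general_dilated(
    x,
    w,
    window_strides=(1, 1),
    padding="VALID",
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
    feature_group_count=groups,
)

# VALID padding: 8 - 3 + 1 = 6. Each output sums a 3 x 3 window over the
# 2 channels of its own group, so every value is 3 * 3 * 2 = 18.
print(y.shape)  # (1, 6, 6, 6)
```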
Conv1DTranspose

class pax.Conv1DTranspose(in_features, out_features, kernel_shape, stride=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NWC', *, name=None, rng_key=None)
1D Convolution Transpose Module.

__init__(in_features, out_features, kernel_shape, stride=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NWC', *, name=None, rng_key=None)
Initializes the module.
(Haiku documentation)
Parameters:
- in_features (int) – Number of input channels.
- out_features (int) – Number of output channels.
- kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length 1.
- stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length 1. Defaults to 1.
- padding (Union[str, Sequence[Tuple[int, int]]]) – Optional padding algorithm. Either VALID or SAME. Defaults to SAME.
- with_bias (bool) – Whether to add a bias. Defaults to True.
- w_init (Optional[Callable]) – Optional weight initialization. By default, truncated normal.
- b_init (Optional[Callable]) – Optional bias initialization. By default, zeros.
- data_format (str) – The data format of the input. Either NWC or NCW. By default, NWC.
- name (Optional[str]) – The name of the module.
- rng_key (Union[Any, ndarray, None]) – The random key.
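A transposed convolution upsamples along the spatial axis. The minimal sketch below uses `jax.lax.conv_transpose`, assuming an `NWC` input and a `(width, in_features, out_features)` kernel. With `SAME` padding the output width is the input width times the stride.

```python
import jax
import jax.numpy as jnp

batch, width, in_features, out_features = 2, 5, 3, 6
kernel_size, stride = 4, 2

x = jnp.ones((batch, width, in_features))  # NWC input
w = jax.random.normal(jax.random.PRNGKey(0), (kernel_size, in_features, out_features))

y = jax.lax.conv_transpose(
    x,
    w,
    strides=(stride,),
    padding="SAME",
    dimension_numbers=("NWC", "WIO", "NWC"),
)
print(y.shape)  # (2, 10, 6): the width is upsampled from 5 to 5 * 2
```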
Conv2DTranspose

class pax.Conv2DTranspose(in_features, out_features, kernel_shape, stride=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', *, name=None, rng_key=None)
2D Convolution Transpose Module.

__init__(in_features, out_features, kernel_shape, stride=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', *, name=None, rng_key=None)
Initializes the module.
(Haiku documentation)
Parameters:
- in_features (int) – Number of input channels.
- out_features (int) – Number of output channels.
- kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length 2.
- stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length 2. Defaults to 1.
- padding (Union[str, Sequence[Tuple[int, int]]]) – Optional padding algorithm. Either VALID or SAME. Defaults to SAME.
- with_bias (bool) – Whether to add a bias. Defaults to True.
- w_init (Optional[Callable]) – Optional weight initialization. By default, truncated normal.
- b_init (Optional[Callable]) – Optional bias initialization. By default, zeros.
- data_format (str) – The data format of the input. Either NHWC or NCHW. By default, NHWC.
- name (Optional[str]) – The name of the module.
- rng_key (Union[Any, ndarray, None]) – The random key.
Normalization

BatchNorm1D

class pax.BatchNorm1D(num_channels, create_scale=True, create_offset=True, decay_rate=0.9, eps=1e-05, data_format='NWC', *, name=None)
The 1D version of BatchNorm.

__init__(num_channels, create_scale=True, create_offset=True, decay_rate=0.9, eps=1e-05, data_format='NWC', *, name=None)
Create a new BatchNorm module.
Parameters:
- num_channels (int) – the number of filters.
- create_scale (bool) – create a trainable scale parameter.
- create_offset (bool) – create a trainable offset parameter.
- decay_rate (float) – the decay rate for tracking the averaged mean and the averaged variance.
- eps (float) – a small positive number to avoid division by zero.
- data_format (str) – the data format [“NHWC”, “NCHW”, “NWC”, “NCW”].
- reduced_axes – list of axes that will be reduced in the jnp.mean computation.
- param_shape – the shape of parameters.
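The roles of `decay_rate` and `eps` can be sketched in plain JAX. This sketch assumes the standard formulation: batch statistics during training and exponential moving averages at inference. It also assumes an `NWC` input reduced over the batch and width axes. The function names are hypothetical, and PAX's exact update order may differ.

```python
import jax.numpy as jnp


def batch_norm_train(x, ema_mean, ema_var, scale, offset,
                     decay_rate=0.9, eps=1e-5, reduced_axes=(0, 1)):
    """Hypothetical helper: normalize with batch statistics, update the averages."""
    mean = jnp.mean(x, axis=reduced_axes)  # per-channel mean
    var = jnp.var(x, axis=reduced_axes)    # per-channel variance
    y = (x - mean) / jnp.sqrt(var + eps) * scale + offset
    # Track the running statistics with an exponential moving average.
    new_mean = decay_rate * ema_mean + (1.0 - decay_rate) * mean
    new_var = decay_rate * ema_var + (1.0 - decay_rate) * var
    return y, new_mean, new_var


def batch_norm_eval(x, ema_mean, ema_var, scale, offset, eps=1e-5):
    """Hypothetical helper: at inference, use the tracked averages instead."""
    return (x - ema_mean) / jnp.sqrt(ema_var + eps) * scale + offset


num_channels = 3
x = jnp.arange(2 * 4 * num_channels, dtype=jnp.float32).reshape(2, 4, num_channels)
scale, offset = jnp.ones(num_channels), jnp.zeros(num_channels)
y, ema_mean, ema_var = batch_norm_train(
    x, jnp.zeros(num_channels), jnp.ones(num_channels), scale, offset
)
y_eval = batch_norm_eval(x, ema_mean, ema_var, scale, offset)
```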
BatchNorm2D

class pax.BatchNorm2D(num_channels, create_scale=True, create_offset=True, decay_rate=0.9, eps=1e-05, data_format='NHWC', *, name=None)
The 2D version of BatchNorm.

__init__(num_channels, create_scale=True, create_offset=True, decay_rate=0.9, eps=1e-05, data_format='NHWC', *, name=None)
Create a new BatchNorm module.
Parameters:
- num_channels (int) – the number of filters.
- create_scale (bool) – create a trainable scale parameter.
- create_offset (bool) – create a trainable offset parameter.
- decay_rate (float) – the decay rate for tracking the averaged mean and the averaged variance.
- eps (float) – a small positive number to avoid division by zero.
- data_format (str) – the data format [“NHWC”, “NCHW”, “NWC”, “NCW”].
- reduced_axes – list of axes that will be reduced in the jnp.mean computation.
- param_shape – the shape of parameters.
LayerNorm

class pax.LayerNorm(num_channels, axis, create_scale, create_offset, eps=1e-05, scale_init=None, offset_init=None, *, rng_key=None, name=None)
LayerNorm module. See: https://arxiv.org/abs/1607.06450.

__init__(num_channels, axis, create_scale, create_offset, eps=1e-05, scale_init=None, offset_init=None, *, rng_key=None, name=None)
Initialize module.

__call__(inputs, scale=None, offset=None)
Returns normalized inputs.
Parameters:
- inputs (ndarray) – An array, where the data format is [N, ..., C].
- scale (Optional[ndarray]) – An array up to n-D. The shape of this tensor must be broadcastable to the shape of inputs. This is the scale applied to the normalized inputs. This cannot be passed in if the module was constructed with create_scale=True.
- offset (Optional[ndarray]) – An array up to n-D. The shape of this tensor must be broadcastable to the shape of inputs. This is the offset applied to the normalized inputs. This cannot be passed in if the module was constructed with create_offset=True.
Return type: ndarray
Returns: The array, normalized.
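The following minimal plain-JAX sketch shows the normalization. It assumes `axis=-1` and `scale`/`offset` of shape `[C]`. If the module is built with `create_scale=False` / `create_offset=False`, the same arrays would instead be passed to `__call__`. The helper `layer_norm` is hypothetical.

```python
import jax.numpy as jnp


def layer_norm(x, scale, offset, axis=-1, eps=1e-5):
    """Hypothetical helper: normalize each example over `axis`."""
    mean = jnp.mean(x, axis=axis, keepdims=True)
    var = jnp.var(x, axis=axis, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps) * scale + offset


x = jnp.array([[1.0, 2.0, 3.0, 4.0, 5.0],
               [10.0, 20.0, 30.0, 40.0, 50.0]])  # [N, C]
scale, offset = jnp.ones((5,)), jnp.zeros((5,))
y = layer_norm(x, scale, offset)
# Statistics are computed per example, so both rows map to (nearly) the
# same normalized vector despite their different magnitudes.
```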
GroupNorm

class pax.GroupNorm(groups, num_channels, axis=slice(1, None, None), create_scale=True, create_offset=True, eps=1e-05, scale_init=None, offset_init=None, data_format='channels_last', *, rng_key=None, name=None)
Group normalization module.
This applies group normalization to the input x. It splits the channels into groups before calculating the mean and variance. By default, the mean and variance are computed over the spatial dimensions and the grouped channels. They are never computed over the created groups axis.
It transforms the input x into:
\[\mathrm{outputs} = \mathrm{scale} \cdot \dfrac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \mathrm{offset}\]
where \(\mu\) and \(\sigma\) are respectively the mean and standard deviation of x.
Users manage scale and offset in different ways, if they need them at all. The options are:
- No scale/offset, in which case create_* should be set to False and scale/offset are not passed when the module is called.
- Trainable scale/offset, in which case create_* should be set to True and, again, scale/offset are not passed when the module is called. In this case the module creates and owns the scale/offset parameters.
- Externally generated scale/offset, such as for conditional normalization, in which case create_* should be set to False and the values are passed in at call time.

__init__(groups, num_channels, axis=slice(1, None, None), create_scale=True, create_offset=True, eps=1e-05, scale_init=None, offset_init=None, data_format='channels_last', *, rng_key=None, name=None)
Constructs a GroupNorm module.
Parameters:
- groups (int) – number of groups to divide the channels by. The number of channels must be divisible by this.
- num_channels (int) – number of channels.
- axis (Union[int, slice, Sequence[int]]) – int, slice or sequence of ints representing the axes which should be normalized across. By default this is all but the first dimension. For time series data, use slice(2, None) to average over the non-batch, non-time dimensions.
- create_scale (bool) – whether to create a trainable scale per channel applied after the normalization.
- create_offset (bool) – whether to create a trainable offset per channel applied after normalization and scaling.
- eps (float) – Small epsilon to add to the variance to avoid division by zero. Defaults to 1e-5.
- scale_init (Optional[Callable]) – Optional initializer for the scale parameter. Can only be set if create_scale=True. By default, scale is initialized to 1.
- offset_init (Optional[Callable]) – Optional initializer for the offset parameter. Can only be set if create_offset=True. By default, offset is initialized to 0.
- data_format (str) – The data format of the input. Can be channels_first, channels_last, N...C or NC.... By default, it is channels_last.
- name (Optional[str]) – Name of the module.

__call__(x, scale=None, offset=None)
Returns normalized inputs.
Parameters:
- x (ndarray) – An n-D tensor of the data_format specified in the constructor on which the transformation is performed.
- scale (Optional[ndarray]) – A tensor up to n-D. The shape of this tensor must be broadcastable to the shape of x. This is the scale applied to the normalized x. This cannot be passed in if the module was constructed with create_scale=True.
- offset (Optional[ndarray]) – A tensor up to n-D. The shape of this tensor must be broadcastable to the shape of x. This is the offset applied to the normalized x. This cannot be passed in if the module was constructed with create_offset=True.
Return type: ndarray
Returns: An n-D tensor of the same shape as x that has been normalized.
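The grouping described above can be sketched in plain JAX for a channels-last input. This sketch assumes the default behavior: statistics are computed over the spatial axes and the channels within each group, never over the batch or group axes. The helper `group_norm` is hypothetical, not PAX's code.

```python
import jax
import jax.numpy as jnp


def group_norm(x, groups, scale, offset, eps=1e-5):
    """Hypothetical helper: group norm for a channels-last input [N, ..., C]."""
    c = x.shape[-1]
    assert c % groups == 0, "the number of channels must be divisible by groups"
    # Split the channel axis into (groups, channels_per_group).
    xg = x.reshape(x.shape[:-1] + (groups, c // groups))
    # Reduce over the spatial axes and the channels inside each group,
    # but not over the batch axis (0) or the group axis (ndim - 2).
    axes = tuple(range(1, xg.ndim - 2)) + (xg.ndim - 1,)
    mean = jnp.mean(xg, axis=axes, keepdims=True)
    var = jnp.var(xg, axis=axes, keepdims=True)
    xg = (xg - mean) / jnp.sqrt(var + eps)
    # Per-channel scale and offset are applied after restoring the shape.
    return xg.reshape(x.shape) * scale + offset


x = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 4, 6))  # NHWC, 6 channels
y = group_norm(x, groups=3, scale=jnp.ones((6,)), offset=jnp.zeros((6,)))
```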
Recurrent

VanillaRNN

class pax.VanillaRNN(input_dim, hidden_dim, *, rng_key=None, name=None)
Basic recurrent neural network.
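PAX's exact cell equations are not spelled out here, so the sketch below assumes the classic Elman recurrence. It unrolls that recurrence over a sequence with `jax.lax.scan`. The parameter layout, the `tanh` nonlinearity and the helper name are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def vanilla_rnn_step(params, h, x):
    """Hypothetical Elman step: h_new = tanh(x @ w_x + h @ w_h + b)."""
    w_x, w_h, b = params
    h_new = jnp.tanh(x @ w_x + h @ w_h + b)
    return h_new, h_new  # (carry, output), as expected by jax.lax.scan


input_dim, hidden_dim, seq_len = 3, 4, 6
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
params = (
    jax.random.normal(k1, (input_dim, hidden_dim)) * 0.1,
    jax.random.normal(k2, (hidden_dim, hidden_dim)) * 0.1,
    jnp.zeros((hidden_dim,)),
)
xs = jnp.ones((seq_len, input_dim))
h0 = jnp.zeros((hidden_dim,))
# Unroll the recurrence over the time axis of xs.
h_last, hs = jax.lax.scan(lambda h, x: vanilla_rnn_step(params, h, x), h0, xs)
```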
LSTM

class pax.LSTM(input_dim, hidden_dim, w_init=None, forget_gate_bias=0.0, *, rng_key=None, name=None)
Long Short Term Memory (LSTM) RNN module.

__init__(input_dim, hidden_dim, w_init=None, forget_gate_bias=0.0, *, rng_key=None, name=None)
Create an LSTM module.
Parameters:
- input_dim (int) – The input dimension.
- hidden_dim (int) – The number of LSTM cells.
- w_init (Optional[Callable]) – weight initializer.
- forget_gate_bias (float) – a bias added to the forget gate pre-activation; positive values push the cell toward keeping (rather than forgetting) its previous state. Defaults to 0.
- rng_key (Union[Any, ndarray, None]) – random key.
- name (Optional[str]) – module name.
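The following plain-JAX sketch of a single LSTM step shows where `forget_gate_bias` acts. The fused weight layout and the (input, cell, forget, output) gate order are assumptions borrowed from common implementations, not necessarily PAX's. With all-zero weights, the effect of the bias is easy to read off.

```python
import jax
import jax.numpy as jnp


def lstm_step(params, state, x, forget_gate_bias=0.0):
    """Hypothetical LSTM step with a fused weight matrix."""
    w, b = params  # w: (input_dim + hidden_dim, 4 * hidden_dim), b: (4 * hidden_dim,)
    h, c = state
    gates = jnp.concatenate([x, h], axis=-1) @ w + b
    i, g, f, o = jnp.split(gates, 4, axis=-1)
    # The bias shifts the forget-gate pre-activation; a positive value pushes
    # sigmoid(f) toward 1, so more of the previous cell state is kept.
    f = jax.nn.sigmoid(f + forget_gate_bias)
    c_new = f * c + jax.nn.sigmoid(i) * jnp.tanh(g)
    h_new = jax.nn.sigmoid(o) * jnp.tanh(c_new)
    return (h_new, c_new), h_new


input_dim, hidden_dim = 3, 4
params = (jnp.zeros((input_dim + hidden_dim, 4 * hidden_dim)), jnp.zeros((4 * hidden_dim,)))
state = (jnp.zeros((hidden_dim,)), jnp.ones((hidden_dim,)))  # (h, c)
x = jnp.ones((input_dim,))
(_, c_default), _ = lstm_step(params, state, x)                      # f = sigmoid(0) = 0.5
(_, c_biased), _ = lstm_step(params, state, x, forget_gate_bias=5.0)  # f = sigmoid(5)
```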
GRU

class pax.GRU(input_dim, hidden_dim, *, rng_key=None, name=None)
This class implements the “fully gated unit” GRU.
Reference: https://en.wikipedia.org/wiki/Gated_recurrent_unit
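The “fully gated unit” equations from the referenced page can be written as one step function and unrolled with `jax.lax.scan`. Parameter names and layout are hypothetical. Zero weights are used only so the result is easy to verify: every gate is 0.5 and the candidate state is 0, so each step halves the hidden state.

```python
import jax
import jax.numpy as jnp


def gru_step(params, h, x):
    """Hypothetical fully gated GRU step."""
    w_z, u_z, b_z, w_r, u_r, b_r, w_h, u_h, b_h = params
    z = jax.nn.sigmoid(x @ w_z + h @ u_z + b_z)      # update gate
    r = jax.nn.sigmoid(x @ w_r + h @ u_r + b_r)      # reset gate
    h_hat = jnp.tanh(x @ w_h + (r * h) @ u_h + b_h)  # candidate state
    h_new = (1.0 - z) * h + z * h_hat
    return h_new, h_new


input_dim, hidden_dim = 3, 4
w_in = jnp.zeros((input_dim, hidden_dim))
w_hid = jnp.zeros((hidden_dim, hidden_dim))
bias = jnp.zeros((hidden_dim,))
params = (w_in, w_hid, bias) * 3  # same zero weights for the z, r and h blocks

h0 = jnp.ones((hidden_dim,))
xs = jnp.ones((3, input_dim))
h_last, hs = jax.lax.scan(lambda h, x: gru_step(params, h, x), h0, xs)
```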
Pool

avg_pool

pax.avg_pool(value, window_shape, strides, padding, channel_axis=-1)
Average pool.
Parameters:
- value (ndarray) – Value to pool.
- window_shape (Union[int, Sequence[int]]) – Shape of the pooling window, an int or same rank as value.
- strides (Union[int, Sequence[int]]) – Strides of the pooling window, an int or same rank as value.
- padding (str) – Padding algorithm. Either VALID or SAME.
- channel_axis (Optional[int]) – Axis of the spatial channels for which pooling is skipped, used to infer window_shape or strides if they are an integer.
Return type: ndarray
Returns: Pooled result. Same rank as value.
Raises: ValueError – If the padding is not valid.
max_pool

pax.max_pool(value, window_shape, strides, padding, channel_axis=-1)
Max pool.
Parameters:
- value (ndarray) – Value to pool.
- window_shape (Union[int, Sequence[int]]) – Shape of the pooling window, an int or same rank as value.
- strides (Union[int, Sequence[int]]) – Strides of the pooling window, an int or same rank as value.
- padding (str) – Padding algorithm. Either VALID or SAME.
- channel_axis (Optional[int]) – Axis of the spatial channels for which pooling is skipped, used to infer window_shape or strides if they are an integer.
Return type: ndarray
Returns: Pooled result. Same rank as value.
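Both pooling functions can be expressed with `jax.lax.reduce_window`. The sketch assumes that an int `window_shape` or `strides` covers every axis except the batch axis and `channel_axis`, as in Haiku's pooling. The `expand` helper is hypothetical. Only `VALID` padding is shown, so the average is a plain division by the window size.

```python
import jax
import jax.numpy as jnp


def expand(size, ndim, channel_axis=-1):
    """Hypothetical helper: turn an int into a per-axis window, using 1 for
    the batch axis (0) and the channel axis."""
    channel_axis = channel_axis % ndim
    return tuple(1 if d in (0, channel_axis) else size for d in range(ndim))


x = jnp.arange(16.0).reshape(1, 4, 4, 1)  # NHWC
window = strides = expand(2, x.ndim)      # (1, 2, 2, 1)

max_pooled = jax.lax.reduce_window(x, -jnp.inf, jax.lax.max, window, strides, "VALID")
sum_pooled = jax.lax.reduce_window(x, 0.0, jax.lax.add, window, strides, "VALID")
avg_pooled = sum_pooled / 4.0  # divide by the window size (2 * 2) for VALID padding
```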
MultiHeadAttention

class pax.MultiHeadAttention(num_heads, key_size, w_init_scale)
Multi-headed attention mechanism, as described in the vanilla Transformer paper “Attention Is All You Need”: https://arxiv.org/abs/1706.03762
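The following compact sketch shows the scaled dot-product attention from the referenced paper, for unbatched `[T, D]` inputs. The projection matrices, their shapes and the helper name are assumptions. The sketch does not model how PAX uses `w_init_scale`.

```python
import jax
import jax.numpy as jnp


def multi_head_attention(query, key, value, params, num_heads, key_size):
    """Hypothetical helper: softmax(Q K^T / sqrt(key_size)) V, per head."""
    w_q, w_k, w_v, w_o = params

    def project(x, w):
        # [T, D] @ [D, H * K] -> [T, H, K]
        return (x @ w).reshape(x.shape[0], num_heads, key_size)

    q, k, v = project(query, w_q), project(key, w_k), project(value, w_v)
    logits = jnp.einsum("thd,shd->hts", q, k) / jnp.sqrt(key_size)
    weights = jax.nn.softmax(logits, axis=-1)  # normalize over source positions s
    attended = jnp.einsum("hts,shd->thd", weights, v)
    # Concatenate the heads and apply the output projection.
    return attended.reshape(query.shape[0], num_heads * key_size) @ w_o


num_heads, key_size, model_dim, seq_len = 2, 4, 8, 5
proj = num_heads * key_size
k_q, k_k, k_v, k_o, k_x = jax.random.split(jax.random.PRNGKey(0), 5)
params = (
    jax.random.normal(k_q, (model_dim, proj)),
    jax.random.normal(k_k, (model_dim, proj)),
    jax.random.normal(k_v, (model_dim, proj)),
    jax.random.normal(k_o, (proj, model_dim)),
)
x = jax.random.normal(k_x, (seq_len, model_dim))
y = multi_head_attention(x, x, x, params, num_heads, key_size)  # self-attention
print(y.shape)  # (5, 8)
```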
Utilities

Sequential

class pax.Sequential(*layers, name=None)
Execute layers in order.
Supports pax.Module instances (callable pytrees) and any JAX functions.
For example:
>>> import jax
>>> import jax.numpy as jnp
>>> import pax
>>> net = pax.Sequential(
...     pax.Linear(2, 32),
...     jax.nn.relu,
...     pax.Linear(32, 3)
... )
>>> print(net.summary())
Sequential
├── Linear(in_dim=2, out_dim=32, with_bias=True)
├── x => relu(x)
└── Linear(in_dim=32, out_dim=3, with_bias=True)
>>> x = jnp.empty((3, 2))
>>> y = net(x)
>>> y.shape
(3, 3)
RngSeq

class pax.RngSeq(seed=None, rng_key=None)
A module which generates an infinite sequence of rng keys.

__init__(seed=None, rng_key=None)
Initialize a random key sequence.
Note: rng_key has a higher priority than seed.
Parameters:
- seed (Optional[int]) – an integer seed.
- rng_key (Union[Any, ndarray, None]) – a jax random key.

next_rng_key(num_keys=1)
Return the next random key of the sequence.
Note:
- Returns a single key if num_keys is 1.
- Returns a list of keys if num_keys is greater than 1.
- The keys produced depend on the sequence of num_keys values used, so mixing different num_keys values in a random order does not give a reproducible sequence.
Parameters:
- num_keys (int) – the number of keys to return.
Return type: Union[Any, ndarray, Sequence[Union[Any, ndarray]]]
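The behavior documented above can be mimicked with `jax.random.split`. `KeySeq` below is a hypothetical stand-in written for illustration, not PAX's implementation. In particular, the fallback seed of 0 when neither argument is given is an arbitrary placeholder.

```python
import jax


class KeySeq:
    """Hypothetical stand-in for an rng-key sequence; not PAX's code."""

    def __init__(self, seed=None, rng_key=None):
        # rng_key takes priority over seed, mirroring the documented behavior.
        if rng_key is None:
            rng_key = jax.random.PRNGKey(0 if seed is None else seed)  # 0 is a placeholder
        self._key = rng_key

    def next_rng_key(self, num_keys=1):
        # Split off num_keys fresh keys and keep one extra key as the new state.
        keys = jax.random.split(self._key, num_keys + 1)
        self._key = keys[0]
        return keys[1] if num_keys == 1 else list(keys[1:])


seq = KeySeq(seed=42)
noise = jax.random.normal(seq.next_rng_key(), (2,))
more_keys = seq.next_rng_key(num_keys=3)  # a list of 3 keys
```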
Lambda

Identity
EMA

class pax.EMA(initial_value, decay_rate, debias=False, allow_int=False)
Exponential Moving Average (EMA) module.

__init__(initial_value, decay_rate, debias=False, allow_int=False)
Create a new EMA module.
If allow_int=True, integer leaves are updated to the newest values instead of being averaged.
Parameters:
- initial_value – the initial value.
- decay_rate (float) – the decay rate.
- debias (bool) – ignore the initial value to avoid biased estimates.
- allow_int (bool) – allow integer values.
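The following minimal sketch shows the update rule. It assumes the usual form `average = decay_rate * average + (1 - decay_rate) * x`. For `debias=True`, it assumes the Adam-style correction that starts from zero and divides by `1 - decay_rate ** t`. `EmaSketch` is hypothetical. The real module handles pytrees (see `allow_int`), whereas this sketch handles a single array.

```python
import jax.numpy as jnp


class EmaSketch:
    """Hypothetical exponential moving average of a single array."""

    def __init__(self, initial_value, decay_rate, debias=False):
        self.decay_rate = decay_rate
        self.debias = debias
        self.count = 0
        # With debias=True the initial value is ignored: the average starts
        # at zero and the zero-start bias is divided out on every read.
        self.average = jnp.zeros_like(initial_value) if debias else initial_value

    def __call__(self, x):
        self.count += 1
        self.average = self.decay_rate * self.average + (1.0 - self.decay_rate) * x
        if self.debias:
            return self.average / (1.0 - self.decay_rate ** self.count)
        return self.average


biased = EmaSketch(jnp.array(0.0), decay_rate=0.9)
debiased = EmaSketch(jnp.array(100.0), decay_rate=0.9, debias=True)
print(biased(jnp.array(3.0)))    # about 0.3: pulled toward the initial value 0
print(debiased(jnp.array(3.0)))  # about 3.0: the initial value 100 is ignored
```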