Common Modules

Linear(in_dim, out_dim[, with_bias, w_init, …])

A linear transformation is applied over the last dimension of the input.

Conv1D(in_features, out_features, kernel_shape)

1D Convolution Module.

Conv2D(in_features, out_features, kernel_shape)

2D Convolution Module.

Conv1DTranspose(in_features, out_features, …)

1D Convolution Transpose Module.

Conv2DTranspose(in_features, out_features, …)

2D Convolution Transpose Module.

BatchNorm1D(num_channels[, create_scale, …])

The 1D version of BatchNorm.

BatchNorm2D(num_channels[, create_scale, …])

The 2D version of BatchNorm.

LayerNorm(num_channels, axis, create_scale, …)

LayerNorm module.

GroupNorm(groups, num_channels[, axis, …])

Group normalization module.

Sequential(*layers[, name])

Execute layers in order.

VanillaRNN(input_dim, hidden_dim, *[, …])

Basic recurrent neural network.

LSTM(input_dim, hidden_dim[, w_init, …])

Long Short Term Memory (LSTM) RNN module.

GRU(input_dim, hidden_dim, *[, rng_key, name])

This class implements the “fully gated unit” GRU.

MultiHeadAttention(num_heads, key_size, …)

Multi-headed attention mechanism.

Identity(*[, training, name])

Identity function as a module.

avg_pool(value, window_shape, strides, padding)

Average pool.

max_pool(value, window_shape, strides, padding)

Max pool.

Linear

class pax.Linear(in_dim, out_dim, with_bias=True, w_init=None, b_init=None, *, rng_key=None, name=None)[source]

A linear transformation is applied over the last dimension of the input.

__init__(in_dim, out_dim, with_bias=True, w_init=None, b_init=None, *, rng_key=None, name=None)[source]
Parameters
  • in_dim (int) – the number of input features.

  • out_dim (int) – the number of output features.

  • with_bias (bool) – whether to add a bias to the output (default: True).

  • w_init – initializer function for the weight matrix.

  • b_init – initializer function for the bias.

  • rng_key (Union[Any, ndarray, None]) – the key to generate initial parameters.

  • name (Optional[str]) – module name.

__call__(x)[source]

Applies a linear transformation to the inputs along the last dimension.

Parameters

x (ndarray) – The nd-array to be transformed.

Return type

ndarray
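To illustrate what "over the last dimension" means, the sketch below applies y = xW + b to the innermost axis of a nested list and leaves all leading dimensions untouched. It is plain Python so that it runs without JAX or pax; `linear_last_dim`, `w`, and `b` are hypothetical names with made-up values, not part of the pax API.

```python
def linear_last_dim(x, w, b=None):
    """Apply y = x @ w + b along the last dimension of a nested list.

    x: nested list whose innermost lists have length in_dim.
    w: in_dim x out_dim weight matrix, given as a list of rows.
    b: optional bias of length out_dim.
    """
    if isinstance(x[0], list):
        # Not yet at the last dimension: recurse and keep leading dims intact.
        return [linear_last_dim(row, w, b) for row in x]
    out_dim = len(w[0])
    y = [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(out_dim)]
    if b is not None:
        y = [y_j + b_j for y_j, b_j in zip(y, b)]
    return y


w = [[1.0, 0.0, 2.0],
     [0.0, 1.0, 3.0]]            # in_dim=2, out_dim=3
b = [0.5, 0.5, 0.5]
x = [[[1.0, 2.0], [3.0, 4.0]]]   # shape (1, 2, 2)
y = linear_last_dim(x, w, b)     # shape (1, 2, 3)
```

Only the last axis changes size (from in_dim to out_dim); the leading (1, 2) dimensions pass through unchanged.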

Dropout

class pax.Dropout(dropout_rate, *, name=None)[source]

A Dropout Module.

Dropout module stores an internal state rng_key. It refreshes rng_key whenever a forward pass is executed.

__init__(dropout_rate, *, name=None)[source]

Create a dropout module.

Parameters
  • dropout_rate (float) – the probability of dropping an element.

  • name (Optional[str]) – the module name.

__call__(x)[source]

Randomly drop elements of x.

Return the input x if in eval mode or dropout_rate=0.
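The behaviour described above can be sketched in plain Python. This is not the pax implementation: `random.Random` stands in for the module's internal rng_key, and rescaling the surviving elements by 1 / (1 - dropout_rate) is an assumption (the common "inverted dropout" convention) that the text above does not state.

```python
import random


def dropout(x, dropout_rate, training, rng):
    """Inverted dropout on a flat list (a sketch, not the pax implementation)."""
    if not training or dropout_rate == 0.0:
        return list(x)  # eval mode or rate 0: the input passes through unchanged
    keep_prob = 1.0 - dropout_rate
    # Drop each element with probability dropout_rate; rescale survivors
    # (assumed convention) so that the expected value is preserved.
    return [v / keep_prob if rng.random() < keep_prob else 0.0 for v in x]


rng = random.Random(0)  # stands in for the module's internal rng_key
x = [1.0, 2.0, 3.0, 4.0]
eval_out = dropout(x, 0.5, training=False, rng=rng)
train_out = dropout(x, 0.5, training=True, rng=rng)
```

With dropout_rate=0.5, every element of `train_out` is either 0.0 or twice the corresponding input.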

Embed

class pax.Embed(vocab_size, embed_dim, w_init=None, *, rng_key=None, name=None)[source]

Embed module maps integer values to real vectors. The embedded vectors are trainable.

__init__(vocab_size, embed_dim, w_init=None, *, rng_key=None, name=None)[source]

An embed module.

Parameters
  • vocab_size (int) – the number of embedded vectors.

  • embed_dim (int) – the size of embedded vectors.

  • w_init (Optional[Callable]) – weight initializer. Default: truncated_normal.

  • name (Optional[str]) – module name.

__call__(x)[source]

Return embedded vectors indexed by x.

Convolution

Conv1D

class pax.Conv1D(in_features, out_features, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NWC', feature_group_count=1, *, name=None, rng_key=None)[source]

1D Convolution Module.

__init__(in_features, out_features, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NWC', feature_group_count=1, *, name=None, rng_key=None)[source]

Initializes the module.

(Haiku documentation)

Parameters
  • in_features (int) – Number of input channels.

  • out_features (int) – Number of output channels.

  • kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length 1.

  • stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length 1. Defaults to 1.

  • rate (Union[int, Sequence[int]]) – Optional kernel dilation rate. Either an integer or a sequence of length 1. 1 corresponds to standard ND convolution, rate > 1 corresponds to dilated convolution. Defaults to 1.

  • padding (Union[str, Sequence[Tuple[int, int]]]) – Optional padding. Either VALID or SAME or sequence of Tuple[int, int] representing the padding before and after for each spatial dimension. Defaults to SAME.

  • with_bias (bool) – Whether to add a bias. By default, true.

  • w_init (Optional[Callable]) – Optional weight initialization. By default, truncated normal.

  • b_init (Optional[Callable]) – Optional bias initialization. By default, zeros.

  • data_format (str) – The data format of the input. Either NWC or NCW. By default, NWC.

  • feature_group_count (int) – Optional number of groups in group convolution. Default value of 1 corresponds to normal dense convolution. If a higher value is used, convolutions are applied separately to that many groups, then stacked together.

  • name (Optional[str]) – The name of the module.

  • rng_key (Union[Any, ndarray, None]) – The random key.
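The interaction of kernel_shape, stride, rate, and padding determines the spatial size of the output. The helper below is a sketch that assumes the usual XLA/TensorFlow meaning of SAME and VALID; `conv_output_length` is a hypothetical name, and explicit padding sequences are not covered.

```python
import math


def conv_output_length(length, kernel_shape, stride=1, rate=1, padding="SAME"):
    """Spatial output length of a convolution along one dimension."""
    if padding == "SAME":
        # The input is padded so that only the stride shrinks the output.
        return math.ceil(length / stride)
    if padding == "VALID":
        # A dilated kernel spans (kernel_shape - 1) * rate + 1 input positions.
        effective_kernel = (kernel_shape - 1) * rate + 1
        return max(0, (length - effective_kernel) // stride + 1)
    raise ValueError(f"unsupported padding: {padding!r}")


same_len = conv_output_length(10, 3, stride=2)                    # 5
valid_len = conv_output_length(10, 3, padding="VALID")            # 8
dilated_len = conv_output_length(10, 3, rate=2, padding="VALID")  # 6
```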

Conv2D

class pax.Conv2D(in_features, out_features, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', feature_group_count=1, *, name=None, rng_key=None)[source]

2D Convolution Module.

__init__(in_features, out_features, kernel_shape, stride=1, rate=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', feature_group_count=1, *, name=None, rng_key=None)[source]

Initializes the module.

(Haiku documentation)

Parameters
  • in_features (int) – Number of input channels.

  • out_features (int) – Number of output channels.

  • kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length 2.

  • stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length 2. Defaults to 1.

  • rate (Union[int, Sequence[int]]) – Optional kernel dilation rate. Either an integer or a sequence of length 2. 1 corresponds to standard ND convolution, rate > 1 corresponds to dilated convolution. Defaults to 1.

  • padding (Union[str, Sequence[Tuple[int, int]]]) – Optional padding. Either VALID or SAME or sequence of Tuple[int, int] representing the padding before and after for each spatial dimension. Defaults to SAME.

  • with_bias (bool) – Whether to add a bias. By default, true.

  • w_init (Optional[Callable]) – Optional weight initialization. By default, truncated normal.

  • b_init (Optional[Callable]) – Optional bias initialization. By default, zeros.

  • data_format (str) – The data format of the input. Either NHWC or NCHW. By default, NHWC.

  • feature_group_count (int) – Optional number of groups in group convolution. Default value of 1 corresponds to normal dense convolution. If a higher value is used, convolutions are applied separately to that many groups, then stacked together.

  • name (Optional[str]) – The name of the module.

  • rng_key (Union[Any, ndarray, None]) – The random key.
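One way to see the effect of feature_group_count is to count trainable values. The sketch below assumes the standard grouped-convolution layout, in which each output channel is connected to only in_features / feature_group_count input channels; `conv2d_param_count` is a hypothetical helper, not a pax function.

```python
def conv2d_param_count(in_features, out_features, kernel_shape,
                       feature_group_count=1, with_bias=True):
    """Number of trainable values in a (grouped) 2D convolution."""
    if isinstance(kernel_shape, int):
        kernel_shape = (kernel_shape, kernel_shape)
    if in_features % feature_group_count or out_features % feature_group_count:
        raise ValueError("channel counts must be divisible by feature_group_count")
    kernel_h, kernel_w = kernel_shape
    # Each output channel only sees in_features / feature_group_count inputs.
    weights = kernel_h * kernel_w * (in_features // feature_group_count) * out_features
    return weights + (out_features if with_bias else 0)


dense = conv2d_param_count(32, 64, 3)                            # 18496
grouped = conv2d_param_count(32, 64, 3, feature_group_count=32)  # 640
```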

Conv1DTranspose

class pax.Conv1DTranspose(in_features, out_features, kernel_shape, stride=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NWC', *, name=None, rng_key=None)[source]

1D Convolution Transpose Module.

__init__(in_features, out_features, kernel_shape, stride=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NWC', *, name=None, rng_key=None)[source]

Initializes the module.

(Haiku documentation)

Parameters
  • in_features (int) – Number of input channels.

  • out_features (int) – Number of output channels.

  • kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length 1.

  • stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length 1. Defaults to 1.

  • padding (Union[str, Sequence[Tuple[int, int]]]) – Optional padding algorithm. Either VALID or SAME. Defaults to SAME.

  • with_bias (bool) – Whether to add a bias. By default, true.

  • w_init (Optional[Callable]) – Optional weight initialization. By default, truncated normal.

  • b_init (Optional[Callable]) – Optional bias initialization. By default, zeros.

  • data_format (str) – The data format of the input. Either NWC or NCW. By default, NWC.

  • name (Optional[str]) – The name of the module.

  • rng_key (Union[Any, ndarray, None]) – The random key.

Conv2DTranspose

class pax.Conv2DTranspose(in_features, out_features, kernel_shape, stride=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', *, name=None, rng_key=None)[source]

2D Convolution Transpose Module.

__init__(in_features, out_features, kernel_shape, stride=1, padding='SAME', with_bias=True, w_init=None, b_init=None, data_format='NHWC', *, name=None, rng_key=None)[source]

Initializes the module.

(Haiku documentation)

Parameters
  • in_features (int) – Number of input channels.

  • out_features (int) – Number of output channels.

  • kernel_shape (Union[int, Sequence[int]]) – The shape of the kernel. Either an integer or a sequence of length 2.

  • stride (Union[int, Sequence[int]]) – Optional stride for the kernel. Either an integer or a sequence of length 2. Defaults to 1.

  • padding (Union[str, Sequence[Tuple[int, int]]]) – Optional padding algorithm. Either VALID or SAME. Defaults to SAME.

  • with_bias (bool) – Whether to add a bias. By default, true.

  • w_init (Optional[Callable]) – Optional weight initialization. By default, truncated normal.

  • b_init (Optional[Callable]) – Optional bias initialization. By default, zeros.

  • data_format (str) – The data format of the input. Either NHWC or NCHW. By default, NHWC.

  • name (Optional[str]) – The name of the module.

  • rng_key (Union[Any, ndarray, None]) – The random key.
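For the transposed convolutions, the stride enlarges the output instead of shrinking it. The sketch below assumes the JAX/TensorFlow convention for SAME and VALID padding in transposed convolutions; `conv_transpose_output_length` is a hypothetical name and the sizes are illustrative.

```python
def conv_transpose_output_length(length, kernel_shape, stride=1, padding="SAME"):
    """Spatial output length of a transposed convolution along one dimension."""
    if padding == "SAME":
        # The output is exactly stride times larger than the input.
        return length * stride
    if padding == "VALID":
        # Nothing is cropped, so a kernel wider than the stride adds a border.
        return length * stride + max(kernel_shape - stride, 0)
    raise ValueError(f"unsupported padding: {padding!r}")


upsampled = conv_transpose_output_length(16, 4, stride=2)                 # 32
uncropped = conv_transpose_output_length(16, 4, stride=2, padding="VALID")  # 34
```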

Normalization

BatchNorm1D

class pax.BatchNorm1D(num_channels, create_scale=True, create_offset=True, decay_rate=0.9, eps=1e-05, data_format='NWC', *, name=None)[source]

The 1D version of BatchNorm.

__init__(num_channels, create_scale=True, create_offset=True, decay_rate=0.9, eps=1e-05, data_format='NWC', *, name=None)[source]

Create a new BatchNorm module.

Parameters
  • num_channels (int) – the number of filters.

  • create_scale (bool) – create a trainable scale parameter.

  • create_offset (bool) – create a trainable offset parameter.

  • decay_rate (float) – the decay rate for tracking the averaged mean and the averaged variance.

  • eps (float) – a small positive number to avoid division by zero.

  • data_format (str) – the data format [“NHWC”, “NCHW”, “NWC”, “NCW”].

  • reduced_axes – list of axes that will be reduced in the jnp.mean computation.

  • param_shape – the shape of parameters.

BatchNorm2D

class pax.BatchNorm2D(num_channels, create_scale=True, create_offset=True, decay_rate=0.9, eps=1e-05, data_format='NHWC', *, name=None)[source]

The 2D version of BatchNorm.

__init__(num_channels, create_scale=True, create_offset=True, decay_rate=0.9, eps=1e-05, data_format='NHWC', *, name=None)[source]

Create a new BatchNorm module.

Parameters
  • num_channels (int) – the number of filters.

  • create_scale (bool) – create a trainable scale parameter.

  • create_offset (bool) – create a trainable offset parameter.

  • decay_rate (float) – the decay rate for tracking the averaged mean and the averaged variance.

  • eps (float) – a small positive number to avoid division by zero.

  • data_format (str) – the data format [“NHWC”, “NCHW”, “NWC”, “NCW”].

  • reduced_axes – list of axes that will be reduced in the jnp.mean computation.

  • param_shape – the shape of parameters.

LayerNorm

class pax.LayerNorm(num_channels, axis, create_scale, create_offset, eps=1e-05, scale_init=None, offset_init=None, *, rng_key=None, name=None)[source]

LayerNorm module. See: https://arxiv.org/abs/1607.06450.

__init__(num_channels, axis, create_scale, create_offset, eps=1e-05, scale_init=None, offset_init=None, *, rng_key=None, name=None)[source]

Initialize module.

__call__(inputs, scale=None, offset=None)[source]

Returns normalized inputs.

Parameters
  • inputs (ndarray) – An array, where the data format is [N, ..., C].

  • scale (Optional[ndarray]) – An array up to n-D. The shape of this tensor must be broadcastable to the shape of inputs. This is the scale applied to the normalized inputs. This cannot be passed in if the module was constructed with create_scale=True.

  • offset (Optional[ndarray]) – An array up to n-D. The shape of this tensor must be broadcastable to the shape of inputs. This is the offset applied to the normalized inputs. This cannot be passed in if the module was constructed with create_offset=True.

Return type

ndarray

Returns

The array, normalized.

GroupNorm

class pax.GroupNorm(groups, num_channels, axis=slice(1, None, None), create_scale=True, create_offset=True, eps=1e-05, scale_init=None, offset_init=None, data_format='channels_last', *, rng_key=None, name=None)[source]

Group normalization module.

This applies group normalization to the input x. This involves splitting the channels into groups before calculating the mean and variance. The default behavior is to compute the mean and variance over the spatial dimensions and the grouped channels. The mean and variance will never be computed over the created groups axis.

It transforms the input x into:

\[\mathrm{outputs} = \mathrm{scale} \dfrac{x - \mu}{\sigma + \epsilon} + \mathrm{offset}\]

Where \(\mu\) and \(\sigma\) are respectively the mean and standard deviation of x.

There are many different variations for how users want to manage scale and offset if they require them at all. These are:

  • No scale/offset in which case create_* should be set to False and scale/offset aren’t passed when the module is called.

  • Trainable scale/offset in which case create_* should be set to True and again scale/offset aren’t passed when the module is called. In this case this module creates and owns the scale/offset parameters.

  • Externally generated scale/offset, such as for conditional normalization, in which case create_* should be set to False and then the values fed in at call time.
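The grouping and the optional scale/offset can be sketched in plain Python for a single vector of channel values. This is a simplification, not the pax implementation: there are no spatial dimensions, `group_norm` is a hypothetical function, and eps is added to the variance before the square root (as the eps parameter description suggests), which differs slightly from the formula above.

```python
import math


def group_norm(x, groups, scale=None, offset=None, eps=1e-5):
    """Group-normalize one vector of channel values.

    x: per-channel values at a single position (len(x) = num_channels).
    scale, offset: optional per-channel lists applied after normalization.
    """
    num_channels = len(x)
    if num_channels % groups != 0:
        raise ValueError("num_channels must be divisible by groups")
    size = num_channels // groups
    out = []
    for g in range(groups):
        chunk = x[g * size:(g + 1) * size]
        mean = sum(chunk) / size
        var = sum((v - mean) ** 2 for v in chunk) / size
        # Assumption: eps is added to the variance before the square root.
        out.extend((v - mean) / math.sqrt(var + eps) for v in chunk)
    if scale is not None:
        out = [o * s for o, s in zip(out, scale)]
    if offset is not None:
        out = [o + b for o, b in zip(out, offset)]
    return out


x = [1.0, 3.0, 10.0, 30.0]
plain = group_norm(x, groups=2)                      # no scale/offset
shifted = group_norm(x, groups=2, offset=[1.0] * 4)  # externally supplied offset
```

Each group is normalized independently, so both [1, 3] and [10, 30] map to approximately [-1, 1] despite their different magnitudes.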

__init__(groups, num_channels, axis=slice(1, None, None), create_scale=True, create_offset=True, eps=1e-05, scale_init=None, offset_init=None, data_format='channels_last', *, rng_key=None, name=None)[source]

Constructs a GroupNorm module.

Parameters
  • groups (int) – number of groups to divide the channels by. The number of channels must be divisible by this.

  • num_channels (int) – number of channels.

  • axis (Union[int, slice, Sequence[int]]) – int, slice or sequence of ints representing the axes which should be normalized across. By default this is all but the first dimension. For time series data use slice(2, None) to average over the non-batch, non-time dimensions.

  • create_scale (bool) – whether to create a trainable scale per channel applied after the normalization.

  • create_offset (bool) – whether to create a trainable offset per channel applied after normalization and scaling.

  • eps (float) – Small epsilon to add to the variance to avoid division by zero. Defaults to 1e-5.

  • scale_init (Optional[Callable]) – Optional initializer for the scale parameter. Can only be set if create_scale=True. By default scale is initialized to 1.

  • offset_init (Optional[Callable]) – Optional initializer for the offset parameter. Can only be set if create_offset=True. By default offset is initialized to 0.

  • data_format (str) – The data format of the input. Can be either channels_first, channels_last, N...C or NC.... By default it is channels_last.

  • name (Optional[str]) – Name of the module.

__call__(x, scale=None, offset=None)[source]

Returns normalized inputs.

Parameters
  • x (ndarray) – An n-D tensor of the data_format specified in the constructor on which the transformation is performed.

  • scale (Optional[ndarray]) – A tensor up to n-D. The shape of this tensor must be broadcastable to the shape of x. This is the scale applied to the normalized x. This cannot be passed in if the module was constructed with create_scale=True.

  • offset (Optional[ndarray]) – A tensor up to n-D. The shape of this tensor must be broadcastable to the shape of x. This is the offset applied to the normalized x. This cannot be passed in if the module was constructed with create_offset=True.

Return type

ndarray

Returns

An n-d tensor of the same shape as x that has been normalized.

Recurrent

VanillaRNN

class pax.VanillaRNN(input_dim, hidden_dim, *, rng_key=None, name=None)[source]

Basic recurrent neural network.

__init__(input_dim, hidden_dim, *, rng_key=None, name=None)[source]

Create a vanilla RNN module.

Parameters
  • input_dim (int) – input dimension.

  • hidden_dim (int) – hidden dimension.

  • rng_key (Union[Any, ndarray, None]) – random key.

  • name (Optional[str]) – module name.

__call__(state, x)[source]

Do a single RNN step.

Return type

Tuple[VanillaRNNState, ndarray]

LSTM

class pax.LSTM(input_dim, hidden_dim, w_init=None, forget_gate_bias=0.0, *, rng_key=None, name=None)[source]

Long Short Term Memory (LSTM) RNN module.

__init__(input_dim, hidden_dim, w_init=None, forget_gate_bias=0.0, *, rng_key=None, name=None)[source]

Create an LSTM module.

Parameters
  • input_dim (int) – The input dimension.

  • hidden_dim (int) – The number of LSTM cells.

  • w_init (Optional[Callable]) – weight initializer.

  • forget_gate_bias (float) – bias applied to the forget gate. Default: 0.

  • rng_key (Union[Any, ndarray, None]) – random key.

  • name (Optional[str]) – module name.

__call__(state, x)[source]

Do a single LSTM step.

Parameters
  • state (LSTMState) – The current LSTM state.

  • x (ndarray) – The input.

Return type

Tuple[LSTMState, ndarray]

GRU

class pax.GRU(input_dim, hidden_dim, *, rng_key=None, name=None)[source]

This class implements the “fully gated unit” GRU.

Reference: https://en.wikipedia.org/wiki/Gated_recurrent_unit

__init__(input_dim, hidden_dim, *, rng_key=None, name=None)[source]

Create a GRU module.

Parameters
  • input_dim (int) – the input size.

  • hidden_dim (int) – the number of GRU cells.

initial_state(batch_size)[source]

Create an all zeros initial state.

Return type

GRUState

__call__(state, x)[source]

Do a single GRU step.

Parameters
  • state (GRUState) – The current GRU state.

  • x – The input.

Return type

Tuple[GRUState, ndarray]
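VanillaRNN, LSTM, and GRU share one calling convention: `__call__(state, x)` performs a single step and returns `(new_state, output)`. The sketch below shows how such a step function can be unrolled over a sequence. `ToyCell` and `unroll` are hypothetical stand-ins written in plain Python; they are not pax modules and do not reproduce any of the cells documented here.

```python
import math


class ToyCell:
    """Hypothetical scalar cell with the (state, x) -> (state, y) signature."""

    def __init__(self, w_x, w_h):
        self.w_x = w_x
        self.w_h = w_h

    def initial_state(self):
        return 0.0  # all-zeros initial state

    def __call__(self, state, x):
        new_state = math.tanh(self.w_x * x + self.w_h * state)
        return new_state, new_state  # the output equals the new hidden state


def unroll(cell, state, inputs):
    """Apply a single-step cell to each input in order, threading the state."""
    outputs = []
    for x in inputs:
        state, y = cell(state, x)
        outputs.append(y)
    return state, outputs


cell = ToyCell(w_x=1.0, w_h=0.5)
final_state, outputs = unroll(cell, cell.initial_state(), [1.0, 0.0, -1.0])
```

The step has the same (carry, input) -> (carry, output) shape that `jax.lax.scan` expects of its step function.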

Pool

avg_pool

pax.avg_pool(value, window_shape, strides, padding, channel_axis=-1)[source]

Average pool.

Parameters
  • value (ndarray) – Value to pool.

  • window_shape (Union[int, Sequence[int]]) – Shape of the pooling window, an int or same rank as value.

  • strides (Union[int, Sequence[int]]) – Strides of the pooling window, an int or same rank as value.

  • padding (str) – Padding algorithm. Either VALID or SAME.

  • channel_axis (Optional[int]) – Axis of the spatial channels for which pooling is skipped, used to infer window_shape or strides if they are an integer.

Return type

ndarray

Returns

Pooled result. Same rank as value.

Raises

ValueError – If the padding is not valid.

max_pool

pax.max_pool(value, window_shape, strides, padding, channel_axis=-1)[source]

Max pool.

Parameters
  • value (ndarray) – Value to pool.

  • window_shape (Union[int, Sequence[int]]) – Shape of the pooling window, an int or same rank as value.

  • strides (Union[int, Sequence[int]]) – Strides of the pooling window, an int or same rank as value.

  • padding (str) – Padding algorithm. Either VALID or SAME.

  • channel_axis (Optional[int]) – Axis of the spatial channels for which pooling is skipped, used to infer window_shape or strides if they are an integer.

Return type

ndarray

Returns

Pooled result. Same rank as value.
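Both pooling functions slide a window over the input and reduce each window to one value. A one-dimensional sketch with VALID padding, in plain Python, is shown below; `pool_1d` and `mean` are hypothetical helpers, and batch and channel axes are left out for brevity.

```python
def pool_1d(values, window, stride, reduce_fn):
    """Slide a window over a list with VALID padding and reduce each window."""
    return [
        reduce_fn(values[start:start + window])
        for start in range(0, len(values) - window + 1, stride)
    ]


def mean(window_values):
    return sum(window_values) / len(window_values)


signal = [1.0, 3.0, 2.0, 8.0, 4.0, 6.0]
max_pooled = pool_1d(signal, window=2, stride=2, reduce_fn=max)   # [3.0, 8.0, 6.0]
avg_pooled = pool_1d(signal, window=2, stride=2, reduce_fn=mean)  # [2.0, 5.0, 5.0]
```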

MultiHeadAttention

class pax.MultiHeadAttention(num_heads, key_size, w_init_scale)[source]

Multi-headed attention mechanism. As described in the vanilla Transformer paper: “Attention is all you need” https://arxiv.org/abs/1706.03762

__init__(num_heads, key_size, w_init_scale)[source]

Initialize module.

__call__(query, key, value, mask=None)[source]

Compute (optionally masked) MHA with queries, keys & values.

Return type

ndarray
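The core computation of each head in the referenced paper is scaled dot-product attention, softmax(QKᵀ / √d) V. The sketch below implements it for a single head on plain Python lists. It is not the pax implementation: the learned projections and the combination of heads are omitted, `attention` is a hypothetical function, and the mask convention (True means "may attend") is an assumption.

```python
import math


def attention(queries, keys, values, mask=None):
    """Scaled dot-product attention for one head, on lists of vectors.

    mask[i][j] is True where query i may attend to key j (assumed convention).
    """
    scale = 1.0 / math.sqrt(len(keys[0]))
    outputs = []
    for i, q in enumerate(queries):
        logits = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
        if mask is not None:
            # Masked-out positions get a large negative logit (zero weight).
            logits = [logit if mask[i][j] else -1e30
                      for j, logit in enumerate(logits)]
        peak = max(logits)
        exps = [math.exp(logit - peak) for logit in logits]  # stable softmax
        total = sum(exps)
        weights = [e / total for e in exps]
        outputs.append([
            sum(w * v[d] for w, v in zip(weights, values))
            for d in range(len(values[0]))
        ])
    return outputs


queries = [[1.0, 0.0], [0.0, 1.0]]
keys = [[1.0, 0.0], [0.0, 1.0]]
values = [[10.0, 0.0], [0.0, 10.0]]
causal_mask = [[True, False], [True, True]]
out = attention(queries, keys, values, mask=causal_mask)
```

With the causal mask, the first query can only attend to the first key, so its output is exactly the first value vector; the second query mixes both values and favours the second.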

Utilities

Sequential

class pax.Sequential(*layers, name=None)[source]

Execute layers in order.

Supports pax.Module objects (callable pytrees) and any JAX function.

For example:

>>> net = pax.Sequential(
...              pax.Linear(2, 32),
...              jax.nn.relu,
...              pax.Linear(32, 3)
... )
>>> print(net.summary())
Sequential
├── Linear(in_dim=2, out_dim=32, with_bias=True)
├── x => relu(x)
└── Linear(in_dim=32, out_dim=3, with_bias=True)
>>> x = jnp.empty((3, 2))
>>> y = net(x)
>>> y.shape
(3, 3)

__init__(*layers, name=None)[source]

Create a Sequential module.

__call__(x)[source]

Call layers in order.

__getitem__(index)[source]

Get an item from the modules list.

Return type

~T

set(index, value)[source]

Set an item in the modules list.

Return type

~T

RngSeq

class pax.RngSeq(seed=None, rng_key=None)[source]

A module which generates an infinite sequence of rng keys.

__init__(seed=None, rng_key=None)[source]

Initialize a random key sequence.

Note: rng_key has a higher priority than seed.

Parameters
  • seed (Optional[int]) – an integer seed.

  • rng_key (Union[Any, ndarray, None]) – a jax random key.

next_rng_key(num_keys=1)[source]

Return the next random key of the sequence.

Note:

  • Return a key if num_keys is 1,

  • Return a list of keys if num_keys is greater than 1.

  • This is not a deterministic sequence if values of num_keys are mixed randomly.

Parameters

num_keys (int) – the number of keys to return.

Return type

Union[Any, ndarray, Sequence[Union[Any, ndarray]]]

Lambda

class pax.Lambda(func, name=None)[source]

Convert a function to a module.

Example:

>>> net = pax.Lambda(jax.nn.relu)
>>> print(net.summary())
x => relu(x)
>>> y = net(jnp.array(-1))
>>> y
DeviceArray(0, dtype=int32, weak_type=True)

Identity

class pax.Identity(*, training=True, name=None)[source]

Identity function as a module.

__call__(x)[source]

Return x.

EMA

class pax.EMA(initial_value, decay_rate, debias=False, allow_int=False)[source]

Exponential Moving Average (EMA) Module

__init__(initial_value, decay_rate, debias=False, allow_int=False)[source]

Create a new EMA module.

If allow_int=True, integer leaves are updated to the newest values instead of averaging.

Parameters
  • initial_value – the initial value.

  • decay_rate (float) – the decay rate.

  • debias (bool) – ignore the initial value to avoid biased estimates.

  • allow_int (bool) – allow integer values.

__call__(xs)[source]

Return the EMA of xs and update the internal state.
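A scalar sketch of the update rule, in plain Python, is given below. It is not pax.EMA: `ScalarEMA` is a hypothetical class, the update average = decay_rate * average + (1 - decay_rate) * x is the standard exponential moving average, and the handling of debias (replace the initial value with the first observation) is one plausible reading of "ignore the initial value", stated here as an assumption.

```python
class ScalarEMA:
    """Exponential moving average of a scalar (a sketch, not pax.EMA)."""

    def __init__(self, initial_value, decay_rate, debias=False):
        self.average = initial_value
        self.decay_rate = decay_rate
        self.debias = debias

    def __call__(self, x):
        if self.debias:
            # Assumed behaviour: discard the initial value on the first call.
            self.average = x
            self.debias = False
        self.average = (self.decay_rate * self.average
                        + (1.0 - self.decay_rate) * x)
        return self.average


biased = ScalarEMA(0.0, decay_rate=0.9)
debiased = ScalarEMA(0.0, decay_rate=0.9, debias=True)
biased_trace = [biased(5.0) for _ in range(2)]      # starts near 0 and creeps up
debiased_trace = [debiased(5.0) for _ in range(2)]  # stays at 5 from the first step
```

Without debias, the zero initial value pulls the early estimates towards 0 (about 0.5, then 0.95); with debias, the estimate matches the observations immediately.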