Base Model Prototype

Base state-space model class.

In SSM-JAX, a state-space model is represented as a class whose methods define the initial_distribution, dynamics_distribution, and emissions_distribution.

The base SSM object provides template functionality for a state-space model.

class ssm.base.SSM

A generic state-space model base class.
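As a rough illustration of the three-method structure, the NumPy sketch below defines a toy scalar linear Gaussian model. It does not subclass ssm.base.SSM or use tfp: `Normal` and `ToyGaussianSSM` are hypothetical stand-ins, and the parameters `a`, `q`, and `r` are assumptions made for the example.

```python
import numpy as np


class Normal:
    """Hypothetical stand-in for a tfp distribution (scalar Gaussian)."""

    def __init__(self, loc, scale):
        self.loc = np.asarray(loc, dtype=float)
        self.scale = np.asarray(scale, dtype=float)  # standard deviation

    def sample(self, rng):
        return rng.normal(self.loc, self.scale)

    def log_prob(self, value):
        z = (np.asarray(value) - self.loc) / self.scale
        return -0.5 * z**2 - np.log(self.scale) - 0.5 * np.log(2.0 * np.pi)


class ToyGaussianSSM:
    """Hypothetical toy model with standard deviations q (dynamics) and r (emissions)."""

    def __init__(self, a=0.9, q=0.5, r=0.2):
        self.a, self.q, self.r = a, q, r

    def initial_distribution(self, covariates=None, metadata=None):
        return Normal(0.0, 1.0)  # p(x_1)

    def dynamics_distribution(self, state, covariates=None, metadata=None):
        return Normal(self.a * state, self.q)  # p(x_{t+1} | x_t)

    def emissions_distribution(self, state, covariates=None, metadata=None):
        return Normal(state, self.r)  # p(y_t | x_t)


model = ToyGaussianSSM()
x1 = model.initial_distribution().sample(np.random.default_rng(0))
next_dist = model.dynamics_distribution(x1)  # distribution over x_2 given x_1
```

The toy model ignores covariates and metadata. A real subclass would presumably return tfp.distributions.Distribution objects, as the return types documented below indicate.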

initial_distribution(covariates=None, metadata=None)

The distribution over the initial state of the SSM.

p(x_1)

Parameters
  • covariates (PyTree, optional) – optional covariates with leaf shape [B, T, …]. Defaults to None.

  • metadata (PyTree, optional) – optional metadata with leaf shape [B, …]. Defaults to None.

Returns

initial_distribution (tfp.distributions.Distribution) – A distribution over initial states in the SSM.

Return type

Distribution

dynamics_distribution(state, covariates=None, metadata=None)

The dynamics (or state-transition) distribution conditioned on the current state.

p(x_{t+1} | x_t, u_{t+1})

Parameters

state (float) – The current state on which to condition the dynamics.

Returns

dynamics_distribution (tfp.distributions.Distribution) – The distribution over states conditioned on the current state.

Return type

Distribution

emissions_distribution(state, covariates=None, metadata=None)

The emissions (or observation) distribution conditioned on the current state.

p(y_t | x_t, u_t)

Parameters

state (float) – The current state on which to condition the emissions.

Returns

emissions_distribution (tfp.distributions.Distribution) – The emissions distribution conditioned on the provided state.

Return type

Distribution

property emissions_shape

Returns the shape of a single emission, y_t.

Returns

A tuple or tree of tuples giving the emission shape(s).
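One possible reading of "tree of tuples", assuming emissions are stored as a dict-valued PyTree whose leaves have shape [B, T, …], is sketched below with plain NumPy. The leaf names and sizes are hypothetical, and the library may compute this property differently.

```python
import numpy as np

# Hypothetical batch of emissions with two leaves: B=4 trials, T=100 time steps.
emissions = {
    "spikes": np.zeros((4, 100, 12)),
    "position": np.zeros((4, 100, 2)),
}

# Drop the leading batch and time axes to get the shape of a single y_t per leaf.
single_emission_shape = {name: leaf.shape[2:] for name, leaf in emissions.items()}
```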

log_probability(states, data, covariates=None, metadata=None)

Computes the log joint probability of a set of states and data (observations).

\log p(x, y) = \log p(x_1) + \sum_{t=1}^{T-1} \log p(x_{t+1} | x_t) + \sum_{t=1}^{T} \log p(y_t | x_t)

Parameters
  • states – latent states x_{1:T} of shape (\text{[batch]} , \text{num\_timesteps} , \text{latent\_dim})

  • data – observed data y_{1:T} of shape (\text{[batch]} , \text{num\_timesteps} , \text{emissions\_dim})

  • covariates (PyTree, optional) – optional covariates with leaf shape [B, T, …]. Defaults to None.

  • metadata (PyTree, optional) – optional metadata with leaf shape [B, …]. Defaults to None.

Returns

lp – log joint probability \log p(x, y) of shape (\text{[batch]},)
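The three terms of the log joint can be sketched in NumPy for a toy scalar linear Gaussian model. This is an illustration of the formula rather than the library's implementation: `normal_logpdf`, `toy_log_probability`, and the parameters `a`, `q`, and `r` are hypothetical, and the latent and emission dimensions are dropped because the toy model is scalar.

```python
import numpy as np


def normal_logpdf(value, loc, scale):
    """Log density of a Gaussian with standard deviation `scale`."""
    z = (value - loc) / scale
    return -0.5 * z**2 - np.log(scale) - 0.5 * np.log(2.0 * np.pi)


def toy_log_probability(states, data, a=0.9, q=0.5, r=0.2):
    """Log joint for arrays of shape ([batch], num_timesteps)."""
    # log p(x_1)
    lp = normal_logpdf(states[..., 0], 0.0, 1.0)
    # T-1 dynamics terms: log p(x_{t+1} | x_t)
    lp = lp + normal_logpdf(states[..., 1:], a * states[..., :-1], q).sum(axis=-1)
    # T emission terms: log p(y_t | x_t)
    lp = lp + normal_logpdf(data, states, r).sum(axis=-1)
    return lp


states = np.array([0.0, 0.1, -0.2])
data = np.array([0.05, 0.0, -0.1])
lp = toy_log_probability(states, data)
```

Passing arrays with a leading batch axis returns one value per batch element, matching the ([batch],) shape documented for lp.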

elbo(key, data, posterior, covariates=None, metadata=None, num_samples=1)

Compute an evidence lower bound (ELBO) using the joint probability and an approximate posterior q(x) \approx p(x | y):

\mathcal{L}(q) = \mathbb{E}_{q(x)} [\log p(x, y) - \log q(x)] \leq \log p(y)

While in some cases the expectation can be computed in closed form, in general we will approximate it with ordinary Monte Carlo.

Parameters
  • key (jr.PRNGKey) – random seed

  • data (PyTree) – observed data with leaf shape ([B, T, D])

  • posterior – the approximate posterior q(x) over the latent states

  • covariates (PyTree, optional) – optional covariates with leaf shape ([B, T, …]). Defaults to None.

  • metadata (PyTree, optional) – optional metadata with leaf shape ([B, …]). Defaults to None.

  • num_samples (int) – number of samples to evaluate the ELBO

Returns

elbo – the evidence lower bound of shape ([B,])
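The ordinary Monte Carlo approximation can be sketched as follows: draw num_samples latent trajectories from q, then average \log p(x, y) - \log q(x). The sketch uses a toy scalar linear Gaussian model and a fully factorized Gaussian posterior, and a NumPy generator in place of a JAX key. All of the names here (`log_joint`, `monte_carlo_elbo`, `post_mean`, `post_scale`) are hypothetical and not part of the SSM API.

```python
import numpy as np


def normal_logpdf(value, loc, scale):
    """Log density of a Gaussian with standard deviation `scale`."""
    z = (value - loc) / scale
    return -0.5 * z**2 - np.log(scale) - 0.5 * np.log(2.0 * np.pi)


def log_joint(states, data, a=0.9, q=0.5, r=0.2):
    """Toy log p(x, y) for states of shape ([samples], T) and data of shape (T,)."""
    lp = normal_logpdf(states[..., 0], 0.0, 1.0)
    lp = lp + normal_logpdf(states[..., 1:], a * states[..., :-1], q).sum(axis=-1)
    return lp + normal_logpdf(data, states, r).sum(axis=-1)


def monte_carlo_elbo(rng, data, post_mean, post_scale, num_samples=1):
    """Average log p(x, y) - log q(x) over samples x ~ q."""
    samples = rng.normal(post_mean, post_scale, size=(num_samples,) + data.shape)
    log_p = log_joint(samples, data)
    log_q = normal_logpdf(samples, post_mean, post_scale).sum(axis=-1)
    return np.mean(log_p - log_q)


data = np.array([0.05, 0.0, -0.1])
elbo_estimate = monte_carlo_elbo(
    np.random.default_rng(0),
    data,
    post_mean=data,
    post_scale=np.full(3, 0.2),
    num_samples=100,
)
```

When q equals the exact posterior, every sample gives \log p(x, y) - \log q(x) = \log p(y), so the estimate has zero variance. Otherwise, increasing num_samples reduces the variance of the estimate.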

sample(key, num_steps, initial_state=None, covariates=None, metadata=None, num_samples=1)

Sample from the joint distribution defined by the state space model.

x, y \sim p(x, y)

Parameters
  • key (jr.PRNGKey) – A JAX pseudorandom number generator key.

  • num_steps (int) – Number of steps for which to sample.

  • initial_state – Optional state on which to condition the sampled trajectory. Defaults to None, in which case the initial state is sampled from the initial distribution.

  • covariates (PyTree, optional) – optional covariates with leaf shape ([B, T, …]). Defaults to None.

  • metadata (PyTree, optional) – optional metadata with leaf shape ([B, …]). Defaults to None.

  • num_samples (int) – Number of independent samples (defines the batch dimension).

Returns
  • states – an array of latent states across time x_{1:T} of shape (\text{[batch]} , \text{num\_timesteps} , \text{latent\_dim})

  • emissions – an array of observations across time y_{1:T} of shape (\text{[batch]} , \text{num\_timesteps} , \text{emissions\_dim})
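Sampling from the joint distribution amounts to ancestral sampling: draw (or fix) the initial state, roll the dynamics forward, then draw an emission for each state. The NumPy sketch below does this for a toy scalar linear Gaussian model. `toy_sample` and its parameters `a`, `q`, and `r` are hypothetical, a NumPy generator stands in for the JAX key, and the trailing latent and emission dimensions are dropped because the model is scalar.

```python
import numpy as np


def toy_sample(rng, num_steps, initial_state=None, num_samples=1, a=0.9, q=0.5, r=0.2):
    """Return states and emissions, each of shape (num_samples, num_steps)."""
    states = np.empty((num_samples, num_steps))
    if initial_state is None:
        # Draw x_1 from the initial distribution N(0, 1).
        states[:, 0] = rng.normal(0.0, 1.0, size=num_samples)
    else:
        # Condition the trajectory on the supplied initial state.
        states[:, 0] = initial_state
    for t in range(1, num_steps):
        # x_{t+1} ~ N(a * x_t, q)
        states[:, t] = rng.normal(a * states[:, t - 1], q)
    # y_t ~ N(x_t, r), drawn for every time step at once
    emissions = rng.normal(states, r)
    return states, emissions


states, emissions = toy_sample(np.random.default_rng(0), num_steps=5, num_samples=3)
```

A JAX implementation would more likely split the key per step and use a scan than a Python loop. The loop is used here only to keep the sketch short.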