Base Model Prototype
Base state-space model class.
In SSM-JAX, a state-space model is represented as a class whose methods define its initial distribution, dynamics distribution, and emissions distribution (initial_distribution, dynamics_distribution, and emissions_distribution, respectively).
The base SSM class provides template functionality for a state-space model.
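The sketch below shows what a concrete model built on this template might look like. It assumes a scalar linear-Gaussian system with x_1 ~ N(0, 1), x_{t+1} ~ N(a x_t, q^2) and y_t ~ N(x_t, r^2). To stay self-contained it does not import SSM-JAX. `Normal` is a hypothetical stand-in for `tfp.distributions.Normal` that exposes only `sample(seed=...)` and `log_prob`. `ToyLinearGaussianSSM` and its parameters `a`, `q`, `r` are illustrative names. A real model would subclass `ssm.base.SSM` and return `tfp.distributions.Distribution` objects.

```python
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.stats import norm


class Normal:
    """Hypothetical stand-in for tfp.distributions.Normal (scalar, float32 only)."""

    def __init__(self, loc, scale):
        self.loc = jnp.asarray(loc, dtype=jnp.float32)
        self.scale = jnp.asarray(scale, dtype=jnp.float32)

    def sample(self, seed):
        # Reparameterised draw: loc + scale * standard normal noise.
        return self.loc + self.scale * jr.normal(seed, jnp.shape(self.loc))

    def log_prob(self, x):
        return norm.logpdf(x, self.loc, self.scale)


class ToyLinearGaussianSSM:
    """Hypothetical model: x_1 ~ N(0, 1), x_{t+1} ~ N(a x_t, q^2), y_t ~ N(x_t, r^2)."""

    def __init__(self, a=0.9, q=0.1, r=0.5):
        self.a, self.q, self.r = a, q, r  # q and r are standard deviations

    def initial_distribution(self, covariates=None, metadata=None):
        return Normal(0.0, 1.0)

    def dynamics_distribution(self, state, covariates=None, metadata=None):
        return Normal(self.a * state, self.q)

    def emissions_distribution(self, state, covariates=None, metadata=None):
        return Normal(state, self.r)
```

Each method returns a distribution rather than a draw. The same three methods therefore serve both for sampling trajectories and for scoring given states and observations.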
- class ssm.base.SSM
A generic state-space model base class.
- initial_distribution(covariates=None, metadata=None)
The distribution over the initial state of the SSM.
- Parameters
covariates (PyTree, optional) – optional covariates with leaf shape [B, T, …]. Defaults to None.
metadata (PyTree, optional) – optional metadata with leaf shape [B, …]. Defaults to None.
- Returns
initial_distribution (tfp.distributions.Distribution) – A distribution over initial states in the SSM.
- Return type
Distribution
- dynamics_distribution(state, covariates=None, metadata=None)
The dynamics (or state-transition) distribution conditioned on the current state.
- Parameters
state (float) – The current state on which to condition the dynamics.
- Returns
dynamics_distribution (tfp.distributions.Distribution) – The distribution over the next state conditioned on the current state.
- Return type
Distribution
- emissions_distribution(state, covariates=None, metadata=None)
The emissions (or observation) distribution conditioned on the current state.
- Parameters
state (float) – The current state on which to condition the emissions.
- Returns
emissions_distribution (tfp.distributions.Distribution) – The emissions distribution conditioned on the provided state.
- Return type
Distribution
- property emissions_shape
Returns the shape of a single emission, $y_t$.
- Returns
A tuple or tree of tuples giving the emission shape(s).
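Here is one possible implementation of `emissions_shape`. It assumes each emission is a D-dimensional vector y_t = C x_t + noise. `jax.eval_shape` traces a single emission draw abstractly, without doing any floating-point work. Mapping `.shape` over the resulting tree also covers emissions that are PyTrees. `ToyEmissionsModel`, `sample_emission`, `readout` and `noise_scale` are hypothetical names. `sample_emission` stands in for `emissions_distribution(state).sample(seed=key)`.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


class ToyEmissionsModel:
    """Hypothetical model: y_t = C x_t + noise_scale * eps, eps ~ N(0, I)."""

    def __init__(self, readout, noise_scale=0.1):
        self.readout = jnp.asarray(readout)  # C, shape (D, latent_dim)
        self.noise_scale = noise_scale
        self.latent_dim = self.readout.shape[1]

    def sample_emission(self, key, state):
        # Stand-in for emissions_distribution(state).sample(seed=key).
        mean = self.readout @ state
        return mean + self.noise_scale * jr.normal(key, mean.shape)

    @property
    def emissions_shape(self):
        # Evaluate one draw abstractly; only shapes and dtypes are computed.
        out = jax.eval_shape(
            self.sample_emission, jr.PRNGKey(0), jnp.zeros(self.latent_dim)
        )
        # A single array yields one tuple; a PyTree of emissions yields a tree of tuples.
        return jax.tree_util.tree_map(lambda s: s.shape, out)
```

Deriving the shape from the sampling code keeps the property consistent with the emissions distribution if the model changes.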
- log_probability(states, data, covariates=None, metadata=None)
Computes the log joint probability of a set of states and data (observations).
- Parameters
states – latent states $x_{1:T}$ of shape [B, T, …]
data – observed data $y_{1:T}$ of shape [B, T, D]
covariates (PyTree, optional) – optional covariates with leaf shape [B, T, …]. Defaults to None.
metadata (PyTree, optional) – optional metadata with leaf shape [B, …]. Defaults to None.
- Returns
lp – log joint probability $\log p(x_{1:T}, y_{1:T})$ of shape [B,]
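The dynamics and emissions distributions each condition only on the current state. The log joint of a single sequence therefore splits into one initial term, T-1 transition terms and T emission terms:
$$\log p(x_{1:T}, y_{1:T}) = \log p(x_1) + \sum_{t=2}^{T} \log p(x_t \mid x_{t-1}) + \sum_{t=1}^{T} \log p(y_t \mid x_t)$$
Below is a minimal JAX sketch of this sum. It assumes the hypothetical scalar model x_1 ~ N(0, 1), x_{t+1} ~ N(a x_t, q^2), y_t ~ N(x_t, r^2), and uses `jax.vmap` over the batch dimension B. Covariates and metadata are omitted. The function names and default parameters are illustrative, not the library's implementation.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def log_probability_single(states, data, a=0.9, q=0.1, r=0.5):
    """log p(x_{1:T}, y_{1:T}) for one sequence of scalar states and emissions, shape [T]."""
    lp_init = norm.logpdf(states[0], 0.0, 1.0)                  # log p(x_1)
    lp_dyn = norm.logpdf(states[1:], a * states[:-1], q).sum()  # sum over t of log p(x_t | x_{t-1})
    lp_emit = norm.logpdf(data, states, r).sum()                # sum over t of log p(y_t | x_t)
    return lp_init + lp_dyn + lp_emit


def log_probability(states, data):
    """Batched version: states and data have shape [B, T]; returns shape [B,]."""
    return jax.vmap(log_probability_single)(states, data)
```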
- elbo(key, data, posterior, covariates=None, metadata=None, num_samples=1)
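Compute an evidence lower bound (ELBO) using the joint probability and an approximate posterior $q(x) \approx p(x \mid y)$:
$$\log p(y) \geq \mathbb{E}_{q(x)}\left[\log p(x, y) - \log q(x)\right]$$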
While in some cases the expectation can be computed in closed form, in general we will approximate it with ordinary Monte Carlo.
- Parameters
key (jr.PRNGKey) – random seed
data (PyTree) – observed data with leaf shape [B, T, D]
posterior – the approximate posterior over the latent states
covariates (PyTree, optional) – optional covariates with leaf shape [B, T, …]. Defaults to None.
metadata (PyTree, optional) – optional metadata with leaf shape [B, …]. Defaults to None.
num_samples (int) – number of Monte Carlo samples used to estimate the ELBO
- Returns
elbo – the evidence lower bound of shape [B,]
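A Monte Carlo estimate of the bound can be sketched as follows. The sketch assumes the hypothetical scalar model x_1 ~ N(0, 1), x_{t+1} ~ N(a x_t, q^2), y_t ~ N(x_t, r^2). It also assumes a mean-field Gaussian posterior q(x_{1:T}) = prod_t N(x_t; m_t, s_t^2). `posterior_loc` and `posterior_scale` stand in for the `posterior` argument, whose actual type in SSM-JAX may differ. Each estimate averages log p(x, y) - log q(x) over `num_samples` draws from q, and `jax.vmap` maps over the batch.

```python
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.stats import norm


def log_joint(states, data, a=0.9, q=0.1, r=1.0):
    """Hypothetical scalar model for one sequence; states and data have shape [T]."""
    lp = norm.logpdf(states[0], 0.0, 1.0)
    lp += norm.logpdf(states[1:], a * states[:-1], q).sum()
    return lp + norm.logpdf(data, states, r).sum()


def elbo_single(key, data, posterior_loc, posterior_scale, num_samples=1):
    """Monte Carlo ELBO for one sequence under q(x) = prod_t N(x_t; loc_t, scale_t^2)."""
    eps = jr.normal(key, (num_samples,) + posterior_loc.shape)
    samples = posterior_loc + posterior_scale * eps  # x ~ q, shape [S, T]
    log_q = norm.logpdf(samples, posterior_loc, posterior_scale).sum(axis=-1)
    log_p = jax.vmap(log_joint, in_axes=(0, None))(samples, data)
    return jnp.mean(log_p - log_q)  # estimate of E_q[log p(x, y) - log q(x)]


def elbo(key, data, posterior_loc, posterior_scale, num_samples=1):
    """Batched over a leading dimension B; returns an array of shape [B,]."""
    keys = jr.split(key, data.shape[0])
    single = lambda k, y, m, s: elbo_single(k, y, m, s, num_samples)
    return jax.vmap(single)(keys, data, posterior_loc, posterior_scale)
```

As a sanity check, suppose q equals the exact posterior. Then log p(x, y) - log q(x) equals log p(y) for every draw, so the estimate is exact even with one sample. Any other q gives a bound that is lower in expectation by KL(q || p(x | y)).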
- sample(key, num_steps, initial_state=None, covariates=None, metadata=None, num_samples=1)
Sample from the joint distribution defined by the state space model.
- Parameters
key (jr.PRNGKey) – A JAX pseudorandom number generator key.
num_steps (int) – Number of steps for which to sample.
initial_state – Optional state on which to condition the sampled trajectory. Defaults to None, in which case the initial state is sampled from the initial distribution.
covariates (PyTree, optional) – optional covariates with leaf shape [B, T, …]. Defaults to None.
metadata (PyTree, optional) – optional metadata with leaf shape [B, …]. Defaults to None.
num_samples (int) – Number of independent samples (defines the batch dimension).
- Returns
states – an array of latent states across time $x_{1:T}$ of shape [B, T, …], where B = num_samples and T = num_steps
emissions – an array of observations across time $y_{1:T}$ of shape [B, T, D]
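Ancestral sampling for a model of this form can be sketched with `jax.lax.scan`. The sketch assumes the hypothetical scalar model x_1 ~ N(0, 1), x_{t+1} ~ N(a x_t, q^2), y_t ~ N(x_t, r^2). The scan carries the current state, emits y_t and then draws x_{t+1`}`; `jax.vmap` over split keys adds the num_samples batch dimension. Treating `initial_state` as x_1 is an assumption. Covariates and metadata are omitted, and the function names are illustrative.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def sample_single(key, num_steps, initial_state=None, a=0.9, q=0.1, r=0.5):
    """Draw one trajectory (states, emissions), each of shape [num_steps]."""
    init_key, scan_key = jr.split(key)
    if initial_state is None:
        initial_state = jr.normal(init_key)  # x_1 ~ N(0, 1)
    # Assumption: a user-supplied initial_state is used as x_1.
    initial_state = jnp.asarray(initial_state, dtype=jnp.float32)

    def step(state, step_key):
        emit_key, dyn_key = jr.split(step_key)
        emission = state + r * jr.normal(emit_key)       # y_t ~ N(x_t, r^2)
        next_state = a * state + q * jr.normal(dyn_key)  # x_{t+1} ~ N(a x_t, q^2)
        return next_state, (state, emission)

    step_keys = jr.split(scan_key, num_steps)
    _, (states, emissions) = jax.lax.scan(step, initial_state, step_keys)
    return states, emissions


def sample(key, num_steps, initial_state=None, num_samples=1):
    """Independent trajectories stacked along a leading batch axis: shapes [B, T]."""
    keys = jr.split(key, num_samples)
    return jax.vmap(lambda k: sample_single(k, num_steps, initial_state))(keys)
```

`lax.scan` traces the step function once instead of unrolling a Python loop. The traced program size therefore does not grow with num_steps.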