Linear Dynamical System (LDS)

LDS Models

class ssm.lds.base.LDS(initial_condition, dynamics, emissions)

The LDS base class.

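Parameters
  • initial_condition (InitialCondition) – the initial state distribution, p(x_1).

  • dynamics (Dynamics) – the dynamics (state-transition) distribution, p(x_{t+1} | x_t, u_{t+1}).

  • emissions (Emissions) – the emissions distribution, p(y_t | x_t, u_t).
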
property emissions_shape

Returns the shape of a single emission, y_t.

Returns

A tuple or tree of tuples giving the emission shape(s).

initial_distribution(covariates=None, metadata=None)

The distribution over the initial state of the SSM.

p(x_1)

Parameters
  • covariates (PyTree, optional) – optional covariates with leaf shape (B, T, …). Defaults to None.

  • metadata (PyTree, optional) – optional metadata with leaf shape (B, …). Defaults to None.

Returns

initial_distribution (tfp.distributions.Distribution) – A distribution over initial states in the SSM.

dynamics_distribution(state, covariates=None, metadata=None)

The dynamics (or state-transition) distribution conditioned on the current state.

p(x_{t+1} | x_t, u_{t+1})

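Parameters
  • state – the current state on which to condition the dynamics.

  • covariates (PyTree, optional) – optional covariates with leaf shape (B, T, …). Defaults to None.

  • metadata (PyTree, optional) – optional metadata with leaf shape (B, …). Defaults to None.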

Returns

dynamics_distribution (tfp.distributions.Distribution) – The distribution over states conditioned on the current state.

emissions_distribution(state, covariates=None, metadata=None)

The emissions (or observation) distribution conditioned on the current state.

p(y_t | x_t, u_t)

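Parameters
  • state – the current state on which to condition the emissions.

  • covariates (PyTree, optional) – optional covariates with leaf shape (B, T, …). Defaults to None.

  • metadata (PyTree, optional) – optional metadata with leaf shape (B, …). Defaults to None.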

Returns

emissions_distribution (tfp.distributions.Distribution) – The emissions distribution conditioned on the provided state.
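Together, the initial, dynamics, and emissions distributions define the generative model. The sketch below draws one trajectory by ancestral sampling in plain NumPy, assuming linear-Gaussian dynamics and emissions as in GaussianLDS. The function `sample_gaussian_lds` and the parameter names (`A`, `b`, `Q_tril`, `C`, `d`, `R_tril`) are hypothetical and are not part of the ssm API.

```python
import numpy as np

def sample_gaussian_lds(T, m0, S0_tril, A, b, Q_tril, C, d, R_tril, rng):
    """Ancestral sampling from a linear-Gaussian LDS (hypothetical helper)."""
    D, N = A.shape[0], C.shape[0]
    xs = np.zeros((T, D))
    ys = np.zeros((T, N))
    for t in range(T):
        if t == 0:
            # initial distribution: x_1 ~ N(m0, S0_tril @ S0_tril.T)
            xs[t] = m0 + S0_tril @ rng.standard_normal(D)
        else:
            # dynamics distribution: x_{t+1} ~ N(A x_t + b, Q_tril @ Q_tril.T)
            xs[t] = A @ xs[t - 1] + b + Q_tril @ rng.standard_normal(D)
        # emissions distribution: y_t ~ N(C x_t + d, R_tril @ R_tril.T)
        ys[t] = C @ xs[t] + d + R_tril @ rng.standard_normal(N)
    return xs, ys

# Example: a slowly decaying 2-D rotation observed through 3 noisy channels.
rng = np.random.default_rng(0)
theta = np.pi / 20
A = 0.99 * np.array([[np.cos(theta), -np.sin(theta)],
                     [np.sin(theta), np.cos(theta)]])
C = rng.standard_normal((3, 2))
xs, ys = sample_gaussian_lds(50, np.zeros(2), np.eye(2), A, np.zeros(2),
                             0.1 * np.eye(2), C, np.zeros(3), np.eye(3), rng)
```

Setting the noise factors to zero makes the trajectory deterministic, which is a quick way to check that each step applies the intended mean function.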

initialize(dataset, covariates=None, metadata=None, key=None, method=None)

Initialize the LDS parameters. NOTE: Not yet implemented.

m_step(data, posterior, covariates=None, metadata=None, key=None)

Update the model in a (potentially approximate) M step.

Parameters
  • data (np.ndarray) – observed data with shape (B, T, D)

  • posterior (LDSPosterior) – LDS posterior object with leaf shapes (B, …).

  • covariates (PyTree, optional) – optional covariates with leaf shape (B, T, …). Defaults to None.

  • metadata (PyTree, optional) – optional metadata with leaf shape (B, …). Defaults to None.

  • key (jr.PRNGKey, optional) – random seed. Defaults to None.

Returns

lds (LDS) – updated LDS object

Return type

LDS

fit(data, covariates=None, metadata=None, method='laplace_em', rng=None, num_iters=100, tol=0.0001, verbosity=Verbosity.DEBUG)

Fit the LDS to a dataset using the specified method.

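In general, exact EM is intractable for an LDS with arbitrary emissions, because the posterior over the latent states is no longer Gaussian. For an LDS with generalized linear model (GLM) emissions, however, we can perform Laplace EM, which approximates the posterior with a Gaussian centered at its mode and uses that approximation in the M step.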

Parameters
  • data (np.ndarray) – observed data of shape ([batch], num_timesteps, emissions_dim)

  • covariates (PyTree, optional) – optional covariates with leaf shape (B, T, …). Defaults to None.

  • metadata (PyTree, optional) – optional metadata with leaf shape (B, …). Defaults to None.

  • method (str, optional) – model fit method. Must be one of [“laplace_em”]. Defaults to “laplace_em”.

  • rng (jr.PRNGKey, optional) – Random seed. Defaults to None.

  • num_iters (int, optional) – number of fit iterations. Defaults to 100.

  • tol (float, optional) – tolerance in log probability to determine convergence. Defaults to 1e-4.

  • verbosity (Verbosity, optional) – print verbosity. Defaults to Verbosity.DEBUG.

Raises

ValueError – if fit method is not recognized

Returns
  • elbos (np.ndarray) – ELBO at each fit iteration

  • model (LDS) – the fitted model

  • posteriors (LDSPosterior) – the fitted posteriors
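As a rough illustration of the E step in Laplace EM, the sketch below finds the posterior mode of x_{1:T} with Newton's method and uses the inverse negative Hessian at the mode as the Gaussian covariance. It assumes Poisson emissions with rates exp(C x_t + d) and linear-Gaussian dynamics, and it uses dense matrices for clarity; an efficient implementation would exploit the block-tridiagonal structure instead. `laplace_posterior` and its arguments are hypothetical names, not the ssm API.

```python
import numpy as np

def laplace_posterior(ys, m0, S0, A, b, Q, C, d, num_iters=20):
    """Laplace approximation to p(x_{1:T} | y_{1:T}) for a Poisson-emission LDS.

    Assumes x_1 ~ N(m0, S0), x_{t+1} ~ N(A x_t + b, Q), y_tn ~ Poisson(exp(C x_t + d)_n).
    Returns the posterior mode (T, D), the dense covariance (T*D, T*D), and the
    final gradient, which should be close to zero at the mode.
    """
    T, D = ys.shape[0], m0.shape[0]
    S0i, Qi = np.linalg.inv(S0), np.linalg.inv(Q)
    # Gaussian prior on x_{1:T} in natural form: -1/2 x^T J x + h^T x
    J = np.zeros((T * D, T * D))
    h = np.zeros(T * D)
    J[:D, :D] += S0i
    h[:D] += S0i @ m0
    for t in range(T - 1):
        i, j = slice(t * D, (t + 1) * D), slice((t + 1) * D, (t + 2) * D)
        J[i, i] += A.T @ Qi @ A
        J[j, j] += Qi
        J[i, j] -= A.T @ Qi
        J[j, i] -= Qi @ A
        h[i] -= A.T @ Qi @ b
        h[j] += Qi @ b

    def grad_and_neg_hessian(x):
        rates = np.exp(x.reshape(T, D) @ C.T + d)        # (T, N) Poisson rates
        grad = h - J @ x + ((ys - rates) @ C).ravel()    # gradient of log p(x, y)
        H = J.copy()
        for t in range(T):
            s = slice(t * D, (t + 1) * D)
            H[s, s] += C.T @ (rates[t][:, None] * C)     # C^T diag(rate_t) C
        return grad, H

    x = np.zeros(T * D)
    for _ in range(num_iters):
        grad, H = grad_and_neg_hessian(x)
        x = x + np.linalg.solve(H, grad)                 # Newton step
    grad, H = grad_and_neg_hessian(x)
    return x.reshape(T, D), np.linalg.inv(H), grad

rng = np.random.default_rng(0)
T, D, N = 10, 2, 3
C = 0.5 * rng.standard_normal((N, D))
ys = rng.poisson(1.0, size=(T, N)).astype(float)
mode, cov, grad = laplace_posterior(ys, np.zeros(D), np.eye(D), 0.9 * np.eye(D),
                                    np.zeros(D), 0.1 * np.eye(D), C, np.zeros(N))
```

Because the Poisson log likelihood is concave in x and the prior is Gaussian, the log joint has a single mode, which is why a plain Newton iteration is a reasonable choice here.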

class ssm.lds.models.GaussianLDS(num_latent_dims, num_emission_dims, initial_state_mean=None, initial_state_scale_tril=None, dynamics_weights=None, dynamics_bias=None, dynamics_scale_tril=None, emission_weights=None, emission_bias=None, emission_scale_tril=None, seed=None)

LDS with Gaussian emissions.

y_t \mid x_t \sim \mathcal{N}(\mu_{x_t}, \Sigma_{x_t})

The GaussianLDS can be initialized by specifying each parameter explicitly, or you can simply specify the num_latent_dims, num_emission_dims, and seed to create a GaussianLDS with generic, randomly initialized parameters.

Parameters
  • num_latent_dims (int) – number of latent dims.

  • num_emission_dims (int) – number of emission dims.

  • initial_state_mean (np.ndarray, optional) – initial state mean. Defaults to zero vector.

  • initial_state_scale_tril (np.ndarray, optional) – initial state lower-triangular factor of covariance. Defaults to identity matrix.

  • dynamics_weights (np.ndarray, optional) – weights in dynamics GLM. Defaults to a random rotation.

  • dynamics_bias (np.ndarray, optional) – bias in dynamics GLM. Defaults to zero vector.

  • dynamics_scale_tril (np.ndarray, optional) – lower-triangular factor of the dynamics noise covariance. Defaults to 0.1**2 * identity matrix.

  • emission_weights (np.ndarray, optional) – weights in emissions GLM. Defaults to a random rotation.

  • emission_bias (np.ndarray, optional) – bias in emissions GLM. Defaults to zero vector.

  • emission_scale_tril (np.ndarray, optional) – lower-triangular factor of the emissions noise covariance. Defaults to the identity matrix.

  • seed (jr.PRNGKey, optional) – random seed. Defaults to None.

e_step(data, covariates=None, metadata=None)

Compute the exact posterior by extracting the natural parameters of the LDS, namely the block tridiagonal precision matrix (J) and the linear coefficient (h).

Parameters
  • data (np.ndarray) – the observed data of shape (B, T, D)

  • covariates (PyTree, optional) – optional covariates with leaf shape (B, T, …). Defaults to None.

  • metadata (PyTree, optional) – optional metadata with leaf shape (B, …). Defaults to None.

Returns

posterior (LDSPosterior) – the exact posterior over the latent states.

Return type

MultivariateNormalBlockTridiag
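For intuition, here is a dense NumPy sketch of the same computation for a single sequence: it assembles the block-tridiagonal precision J and linear term h from the initial, dynamics, and emission terms, then forms the posterior mean J^{-1} h and covariance J^{-1}. The function name and the parameters (A, b, Q for the dynamics; C, d, R for the emissions) are assumptions for illustration; a practical implementation would use message passing rather than a dense inverse.

```python
import numpy as np

def gaussian_lds_posterior(ys, m0, S0, A, b, Q, C, d, R):
    """Exact posterior p(x_{1:T} | y_{1:T}) from the natural parameters (J, h)."""
    T, D = ys.shape[0], m0.shape[0]
    S0i, Qi, Ri = np.linalg.inv(S0), np.linalg.inv(Q), np.linalg.inv(R)
    J = np.zeros((T * D, T * D))
    h = np.zeros(T * D)
    # initial condition: x_1 ~ N(m0, S0)
    J[:D, :D] += S0i
    h[:D] += S0i @ m0
    for t in range(T):
        i = slice(t * D, (t + 1) * D)
        # emissions: y_t ~ N(C x_t + d, R) only touch the diagonal block
        J[i, i] += C.T @ Ri @ C
        h[i] += C.T @ Ri @ (ys[t] - d)
        if t < T - 1:
            j = slice((t + 1) * D, (t + 2) * D)
            # dynamics: x_{t+1} ~ N(A x_t + b, Q) couple neighboring blocks
            J[i, i] += A.T @ Qi @ A
            J[j, j] += Qi
            J[i, j] -= A.T @ Qi
            J[j, i] -= Qi @ A
            h[i] -= A.T @ Qi @ b
            h[j] += Qi @ b
    cov = np.linalg.inv(J)
    mean = (cov @ h).reshape(T, D)
    return mean, cov, J

rng = np.random.default_rng(0)
D, N = 2, 3
m0, S0 = np.ones(D), np.eye(D)
A, b, Q = 0.9 * np.eye(D), np.zeros(D), 0.1 * np.eye(D)
C, d, R = rng.standard_normal((N, D)), np.zeros(N), np.eye(N)
ys = rng.standard_normal((4, N))
mean, cov, J = gaussian_lds_posterior(ys, m0, S0, A, b, Q, C, d, R)
```

With a single time step the result reduces to ordinary Gaussian conditioning, which gives a simple correctness check.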

marginal_likelihood(data, posterior=None, covariates=None, metadata=None)

The exact marginal likelihood of the observed data.

For a Gaussian LDS, we can compute the exact marginal likelihood of the data (y) given the posterior p(x | y) via Bayes’ rule:

\log p(y) = \log p(y, x) - \log p(x | y)

This equality holds for any choice of x. We’ll use the posterior mean.

Parameters
  • data (np.ndarray) – the observed data.

  • posterior (LDSPosterior, optional) – the posterior distribution on latent states. If None, the posterior is computed via message passing. Defaults to None.

  • covariates (PyTree, optional) – optional covariates with leaf shape (B, T, …). Defaults to None.

  • metadata (PyTree, optional) – optional metadata with leaf shape (B, …). Defaults to None.

Returns

lp (float) – The marginal log likelihood of the data.
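A minimal numerical check of this identity, assuming a single time step (x ~ N(m, S), y | x ~ N(C x + d, R)) so that every term has a closed form. The helper names are hypothetical. The log marginal computed from the identity matches the direct formula log N(y | C m + d, C S C^T + R) whether x is the posterior mean or some other point.

```python
import numpy as np

def mvn_logpdf(x, mean, cov):
    """log N(x | mean, cov) via a Cholesky factor."""
    L = np.linalg.cholesky(cov)
    z = np.linalg.solve(L, x - mean)
    return -0.5 * z @ z - np.log(np.diag(L)).sum() - 0.5 * len(x) * np.log(2 * np.pi)

def posterior(y, m, S, C, d, R):
    """Exact Gaussian posterior p(x | y) for x ~ N(m, S), y | x ~ N(C x + d, R)."""
    K = S @ C.T @ np.linalg.inv(C @ S @ C.T + R)
    cov = S - K @ C @ S
    return m + K @ (y - C @ m - d), 0.5 * (cov + cov.T)

def log_marginal(y, x, m, S, C, d, R):
    """log p(y) = log p(y, x) - log p(x | y); valid for any x."""
    post_mean, post_cov = posterior(y, m, S, C, d, R)
    log_joint = mvn_logpdf(x, m, S) + mvn_logpdf(y, C @ x + d, R)
    return log_joint - mvn_logpdf(x, post_mean, post_cov)

rng = np.random.default_rng(0)
m, S = np.zeros(2), np.eye(2)
C, d, R = rng.standard_normal((3, 2)), np.zeros(3), 0.5 * np.eye(3)
y = rng.standard_normal(3)
x_star, _ = posterior(y, m, S, C, d, R)
lp_at_mean = log_marginal(y, x_star, m, S, C, d, R)
lp_elsewhere = log_marginal(y, x_star + 1.0, m, S, C, d, R)
lp_direct = mvn_logpdf(y, C @ m + d, C @ S @ C.T + R)
```

The posterior mean is a natural evaluation point because that is where p(x | y) is largest, so log p(x | y) is least likely to be evaluated deep in a tail.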

fit(data, covariates=None, metadata=None, method='em', key=None, num_iters=100, tol=0.0001, verbosity=Verbosity.DEBUG)

Fit the GaussianLDS to a dataset using the specified method.

Note: because the observations are Gaussian, the model is conjugate and we can perform exact EM for a GaussianLDS.

Parameters
  • data (np.ndarray) – observed data of shape ([batch], num_timesteps, emissions_dim)

  • covariates (PyTree, optional) – optional covariates with leaf shape (B, T, …). Defaults to None.

  • metadata (PyTree, optional) – optional metadata with leaf shape (B, …). Defaults to None.

  • method (str, optional) – model fit method. Must be one of [“em”, “laplace_em”]. Defaults to “em”.

  • key (jr.PRNGKey, optional) – Random seed. Defaults to None.

  • num_iters (int, optional) – number of fit iterations. Defaults to 100.

  • tol (float, optional) – tolerance in log probability to determine convergence. Defaults to 1e-4.

  • verbosity (Verbosity, optional) – print verbosity. Defaults to Verbosity.DEBUG.

Raises

ValueError – if fit method is not recognized

Returns
  • elbos (np.ndarray) – ELBO at each fit iteration

  • model (LDS) – the fitted model

  • posteriors (LDSPosterior) – the fitted posteriors
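The num_iters and tol arguments control an iterative loop of E and M steps. The sketch below shows one plausible structure, assuming convergence is declared when the absolute change in the ELBO between consecutive iterations falls below tol; `fit_em`, `e_step`, and `m_step` are hypothetical callables, not the ssm implementation. The toy usage replaces the model with a scalar parameter so the loop can run on its own.

```python
import numpy as np

def fit_em(params, e_step, m_step, num_iters=100, tol=1e-4):
    """Generic EM loop with early stopping (hypothetical helper, not the ssm API).

    e_step(params) -> (elbo, posterior)
    m_step(params, posterior) -> new params
    Stops when the ELBO changes by less than `tol` between consecutive iterations.
    """
    elbos = []
    posterior = None
    for _ in range(num_iters):
        elbo, posterior = e_step(params)
        converged = len(elbos) > 0 and abs(elbo - elbos[-1]) < tol
        elbos.append(elbo)
        if converged:
            break
        params = m_step(params, posterior)
    return np.array(elbos), params, posterior

# Toy usage: each "M step" moves a scalar parameter halfway toward 3.0.
elbos, theta, _ = fit_em(
    0.0,
    e_step=lambda p: (-(p - 3.0) ** 2, p),
    m_step=lambda p, post: post + 0.5 * (3.0 - post),
)
```

For exact EM the ELBO sequence is non-decreasing, so a decrease between iterations usually signals a bug rather than a convergence issue.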

class ssm.lds.models.PoissonLDS(num_latent_dims, num_emission_dims, initial_state_mean=None, initial_state_scale_tril=None, dynamics_weights=None, dynamics_bias=None, dynamics_scale_tril=None, emission_weights=None, emission_bias=None, emission_scale_tril=None, seed=None)

LDS with Poisson emissions.

y_t \mid x_t \sim \text{Po}(\lambda = \lambda_{x_t})

The PoissonLDS can be initialized by specifying each parameter explicitly, or you can simply specify the num_latent_dims, num_emission_dims, and seed to create a PoissonLDS with generic, randomly initialized parameters.

Parameters
  • num_latent_dims (int) – number of latent dims.

  • num_emission_dims (int) – number of emission dims.

  • initial_state_mean (np.ndarray, optional) – initial state mean. Defaults to zero vector.

  • initial_state_scale_tril (np.ndarray, optional) – initial state lower-triangular factor of covariance. Defaults to identity matrix.

  • dynamics_weights (np.ndarray, optional) – weights in dynamics GLM. Defaults to a random rotation.

  • dynamics_bias (np.ndarray, optional) – bias in dynamics GLM. Defaults to zero vector.

  • dynamics_scale_tril (np.ndarray, optional) – lower-triangular factor of the dynamics noise covariance. Defaults to 0.1**2 * identity matrix.

  • emission_weights (np.ndarray, optional) – weights in emissions GLM. Defaults to a random matrix.

  • emission_bias (np.ndarray, optional) – bias in emissions GLM. Defaults to zero vector.

  • seed (jr.PRNGKey, optional) – random seed. Defaults to None.

  • emission_scale_tril (Array) –

LDS Components

LDS Initials

class ssm.lds.initial.InitialCondition

Base class for initial state distributions of an LDS.

p(x_1 \mid u_1)

where u_1 are optional covariates at the first time step.

distribution(covariates=None, metadata=None)

Return the distribution of x_1 (potentially given covariates u_1).

Parameters
  • covariates (PyTree, optional) – optional covariates with leaf shape (B, T, …). Defaults to None.

  • metadata (PyTree, optional) – optional metadata with leaf shape (B, …). Defaults to None.

Returns

distribution (tfd.Distribution) – distribution of x_1

class ssm.lds.initial.StandardInitialCondition(initial_mean=None, initial_scale_tril=None, initial_distribution=None, initial_distribution_prior=None)

The standard initial condition is a multivariate normal distribution whose covariance is parameterized by a lower-triangular scale factor: cov = scale_tril @ scale_tril.T.

Parameters
  • initial_distribution (ssmd.MultivariateNormalTriL) –

  • initial_distribution_prior (ssmd.NormalInverseWishart) –
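Because the covariance is parameterized as scale_tril @ scale_tril.T, both sampling and density evaluation can work directly with the triangular factor, without ever forming or inverting the covariance. The NumPy sketch below illustrates this; the function names are hypothetical, and ssmd.MultivariateNormalTriL may be implemented differently.

```python
import numpy as np

def mvn_tril_logpdf(x, mean, scale_tril):
    """log N(x | mean, scale_tril @ scale_tril.T) without forming the covariance."""
    z = np.linalg.solve(scale_tril, x - mean)                  # whitened residual
    log_det = 2.0 * np.log(np.abs(np.diag(scale_tril))).sum()  # log det(cov)
    return -0.5 * (z @ z + log_det + len(x) * np.log(2 * np.pi))

def mvn_tril_sample(mean, scale_tril, rng, n):
    """Draw n samples as mean + scale_tril @ eps with eps ~ N(0, I)."""
    eps = rng.standard_normal((n, len(mean)))
    return mean + eps @ scale_tril.T

scale_tril = np.array([[1.0, 0.0], [0.5, 2.0]])
mean = np.array([0.1, 0.2])
cov = scale_tril @ scale_tril.T  # the covariance implied by scale_tril
samples = mvn_tril_sample(mean, scale_tril, np.random.default_rng(0), 200_000)
```

Working with the factor also guarantees a positive semi-definite covariance by construction, which is convenient when parameters are updated by an optimizer.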

distribution(covariates=None, metadata=None)

Return the distribution of x_1.

Parameters
  • covariates (PyTree, optional) – optional covariates with leaf shape (B, T, …). Defaults to None.

  • metadata (PyTree, optional) – optional metadata with leaf shape (B, …). Defaults to None.

Returns

distribution (tfd.Distribution) – distribution of x_1

m_step(dataset, posteriors, covariates=None, metadata=None)

Update the initial distribution in an M step given posteriors over the latent states.

Update is performed in place.

Parameters
  • dataset (np.ndarray) – the observed dataset with shape (B, T, D)

  • posteriors (LDSPosterior) – posteriors over the latent states with leaf shape (B, …)

  • covariates (PyTree, optional) – optional covariates with leaf shape (B, T, …). Defaults to None.

  • metadata (PyTree, optional) – optional metadata with leaf shape (B, …). Defaults to None.

Returns

initial_condition (StandardInitialCondition) – updated initial condition object

Return type

StandardInitialCondition
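A sketch of a maximum-likelihood version of this update, assuming the posteriors provide the mean and covariance of x_1 for each sequence in the batch. It ignores the NormalInverseWishart prior (which would turn this into a MAP update), and `initial_condition_m_step` is a hypothetical name.

```python
import numpy as np

def initial_condition_m_step(post_means, post_covs):
    """Maximum-likelihood update of N(x_1 | mean, scale_tril @ scale_tril.T).

    post_means: (B, D)    posterior means E[x_1] for each sequence
    post_covs:  (B, D, D) posterior covariances Cov[x_1] for each sequence
    """
    mean = post_means.mean(axis=0)
    # E[x_1 x_1^T] = Cov[x_1] + E[x_1] E[x_1]^T, averaged over the batch
    outer = np.einsum("bi,bj->bij", post_means, post_means)
    second_moment = (post_covs + outer).mean(axis=0)
    cov = second_moment - np.outer(mean, mean)
    return mean, np.linalg.cholesky(cov)
```

Note that the posterior covariances matter: with a single sequence the spread of the posterior means is zero, and the update still yields a valid covariance only because Cov[x_1] is included.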

LDS Dynamics

class ssm.lds.dynamics.Dynamics

Base class for LDS dynamics models,

p_t(x_t \mid x_{t-1}, u_t)

where u_t are optional covariates at time t.

distribution(state, covariates=None, metadata=None)

Return the conditional distribution of x_t given state x_{t-1}

Parameters
  • state (float) – state x_{t-1}

  • covariates (PyTree, optional) – optional covariates with leaf shape (B, T, …). Defaults to None.

  • metadata (PyTree, optional) – optional metadata with leaf shape (B, …). Defaults to None.

Returns

distribution (tfd.Distribution) – conditional distribution of x_t given state x_{t-1}.

m_step(dataset, posteriors)

Update the transition parameters in an M step given posteriors over the latent states.

Parameters
  • dataset (np.ndarray) – the observed dataset with shape (B, T, D)

  • posteriors (LDSPosterior) – posteriors over the latent states with leaf shape (B, …)

  • covariates (PyTree, optional) – optional covariates with leaf shape (B, T, …). Defaults to None.

  • metadata (PyTree, optional) – optional metadata with leaf shape (B, …). Defaults to None.

Returns

dynamics (Dynamics) – updated dynamics object

Return type

Dynamics

class ssm.lds.dynamics.StationaryDynamics(weights=None, bias=None, scale_tril=None, dynamics_distribution=None, dynamics_distribution_prior=None)

Basic dynamics model for LDS.

Parameters
  • dynamics_distribution (ssmd.GaussianLinearRegression) –

  • dynamics_distribution_prior (ssmd.GaussianLinearRegressionPrior) –

distribution(state, covariates=None, metadata=None)

Return the conditional distribution of x_t given state x_{t-1}

Parameters
  • state (float) – state x_{t-1}

  • covariates (PyTree, optional) – optional covariates with leaf shape (B, T, …). Defaults to None.

  • metadata (PyTree, optional) – optional metadata with leaf shape (B, …). Defaults to None.

Returns

distribution (tfd.Distribution) – conditional distribution of x_t given state x_{t-1}.

m_step(batched_data, batched_posteriors, batched_covariates=None, batched_metadata=None)

Update the transition parameters in an M step given posteriors over the latent states.

Parameters
  • batched_data (np.ndarray) – the observed dataset with shape (B, T, D)

  • batched_posteriors (LDSPosterior) – posteriors over the latent states with leaf shape (B, …)

  • batched_covariates (PyTree, optional) – optional covariates with leaf shape (B, T, …). Defaults to None.

  • batched_metadata (PyTree, optional) – optional metadata with leaf shape (B, …). Defaults to None.

Returns

dynamics (Dynamics) – updated dynamics object

Return type

StationaryDynamics
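With linear-Gaussian dynamics, the M step is a Gaussian linear regression of x_{t+1} on x_t that uses expected sufficient statistics in place of observed inputs. The sketch below performs the maximum-likelihood version for a single sequence, assuming the posterior supplies marginal means, marginal covariances, and the cross-covariances Cov[x_{t+1}, x_t]. It ignores the GaussianLinearRegressionPrior, and `dynamics_m_step` is a hypothetical name.

```python
import numpy as np

def dynamics_m_step(means, covs, cross_covs):
    """Maximum-likelihood update of x_{t+1} ~ N(A x_t + b, Q) from posterior moments.

    means:      (T, D)       E[x_t]
    covs:       (T, D, D)    Cov[x_t]
    cross_covs: (T-1, D, D)  Cov[x_{t+1}, x_t]
    """
    T, D = means.shape
    Exx = covs + np.einsum("ti,tj->tij", means, means)                   # E[x_t x_t^T]
    Exnx = cross_covs + np.einsum("ti,tj->tij", means[1:], means[:-1])  # E[x_{t+1} x_t^T]
    # Regress on z_t = [x_t; 1] so that the weights are W = [A, b]
    Szz = np.zeros((D + 1, D + 1))
    Szz[:D, :D] = Exx[:-1].sum(axis=0)
    Szz[:D, D] = means[:-1].sum(axis=0)
    Szz[D, :D] = means[:-1].sum(axis=0)
    Szz[D, D] = T - 1
    Sxz = np.hstack([Exnx.sum(axis=0), means[1:].sum(axis=0)[:, None]])  # sum E[x_{t+1} z_t^T]
    W = np.linalg.solve(Szz, Sxz.T).T
    Q = (Exx[1:].sum(axis=0) - W @ Sxz.T) / (T - 1)
    return W[:, :D], W[:, D], 0.5 * (Q + Q.T)

# Sanity check on a noise-free trajectory (posterior covariances set to zero).
theta = 0.3
A_true = 0.9 * np.array([[np.cos(theta), -np.sin(theta)],
                         [np.sin(theta), np.cos(theta)]])
b_true = np.array([1.0, -0.5])
xs = [np.array([2.0, 0.0])]
for _ in range(19):
    xs.append(A_true @ xs[-1] + b_true)
xs = np.array(xs)
A_hat, b_hat, Q_hat = dynamics_m_step(xs, np.zeros((20, 2, 2)), np.zeros((19, 2, 2)))
```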

LDS Emissions

class ssm.lds.emissions.Emissions

Base class for the emission distribution of an LDS,

p_t(y_t \mid x_t, u_t)

where u_t are optional covariates.

distribution(state, covariates=None, metadata=None)

Return the conditional distribution of emission y_t given state x_t and (optionally) covariates u_t.

Parameters
  • state (float) – continuous state

  • covariates (PyTree, optional) – optional covariates with leaf shape (B, T, …). Defaults to None.

  • metadata (PyTree, optional) – optional metadata with leaf shape (B, …). Defaults to None.

Returns

emissions distribution (tfd.MultivariateNormalLinearOperator) – emissions distribution at given state

m_step(data, posterior, covariates=None, metadata=None, num_samples=1, key=None)

Update the emissions distribution using an M-step.

Operates over a batch of data (posterior must have the same batch dim).

Parameters
  • data (np.ndarray) – the observed data

  • posterior (LDSPosterior) – posteriors over the latent states

  • covariates (PyTree, optional) – optional covariates with leaf shape (B, T, …). Defaults to None.

  • metadata (PyTree, optional) – optional metadata with leaf shape (B, …). Defaults to None.

  • num_samples (int) – number of samples from posterior to use in a generic update

  • key (jr.PRNGKey) – random seed

Returns

emissions (Emissions) – updated emissions object

Return type

Emissions

class ssm.lds.emissions.GaussianEmissions(weights=None, bias=None, scale_tril=None, emissions_distribution=None, emissions_distribution_prior=None)
Parameters
  • emissions_distribution (tfd.Distribution) –

  • emissions_distribution_prior (tfd.Distribution) –

distribution(state, covariates=None, metadata=None)

Return the conditional distribution of emission y_t given state x_t and (optionally) covariates u_t.

Parameters
  • state (float) – continuous state

  • covariates (PyTree, optional) – optional covariates with leaf shape (B, T, …). Defaults to None.

  • metadata (PyTree, optional) – optional metadata with leaf shape (B, …). Defaults to None.

Returns

emissions distribution (tfd.MultivariateNormalLinearOperator) – emissions distribution at given state

m_step(data, posterior, covariates=None, metadata=None, key=None)

Update the emissions distribution using an exact M-step.

Operates over a batch of data (posterior must have the same batch dim).

Parameters
  • data (np.ndarray) – the observed data

  • posterior (LDSPosterior) – posteriors over the latent states

  • covariates (PyTree, optional) – optional covariates with leaf shape (B, T, …). Defaults to None.

  • metadata (PyTree, optional) – optional metadata with leaf shape (B, …). Defaults to None.

  • key (jr.PRNGKey) – random seed

Returns

emissions (GaussianEmissions) – updated emissions object

Return type

GaussianEmissions
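The exact M step for Gaussian emissions is likewise a linear regression, here of the observed y_t on the latent x_t. Because x_t is uncertain, the second moment E[x_t x_t^T] = Cov[x_t] + E[x_t] E[x_t]^T appears in place of x_t x_t^T. The sketch assumes a single sequence and maximum likelihood (no prior); `gaussian_emissions_m_step` is a hypothetical name.

```python
import numpy as np

def gaussian_emissions_m_step(ys, means, covs):
    """Maximum-likelihood update of y_t ~ N(C x_t + d, R) from posterior moments.

    ys: (T, N) observations; means: (T, D) E[x_t]; covs: (T, D, D) Cov[x_t].
    Returns C (N, D), d (N,), R (N, N).
    """
    T, D = means.shape
    Exx = (covs + np.einsum("ti,tj->tij", means, means)).sum(axis=0)  # sum E[x_t x_t^T]
    s = means.sum(axis=0)
    # Sufficient statistics for the augmented input z_t = [x_t; 1]
    Szz = np.block([[Exx, s[:, None]], [s[None, :], np.array([[float(T)]])]])
    Syz = np.hstack([ys.T @ means, ys.sum(axis=0)[:, None]])          # sum y_t E[z_t]^T
    W = np.linalg.solve(Szz, Syz.T).T                                 # W = [C, d]
    R = (ys.T @ ys - W @ Syz.T) / T
    return W[:, :D], W[:, D], 0.5 * (R + R.T)
```

When all posterior covariances are zero, this reduces to ordinary least squares of y_t on [E[x_t], 1], with R equal to the average outer product of the residuals.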

class ssm.lds.emissions.PoissonEmissions(weights=None, bias=None, emissions_distribution=None, emissions_distribution_prior=None)
Parameters
  • emissions_distribution (glm.PoissonGLM) –

  • emissions_distribution_prior (tfd.Distribution) –

distribution(state, covariates=None, metadata=None)

Return the conditional distribution of emission y_t given state x_t and (optionally) covariates u_t.

Parameters
  • state (float) – continuous state

  • covariates (PyTree, optional) – optional covariates with leaf shape (B, T, …). Defaults to None.

  • metadata (PyTree, optional) – optional metadata with leaf shape (B, …). Defaults to None.

Returns

emissions distribution (tfd.Distribution) – Poisson emissions distribution at the given state

m_step(data, posterior, covariates=None, metadata=None, num_samples=1, key=None)

Update the emissions distribution using an M-step.

Operates over a batch of data (posterior must have the same batch dim).

Parameters
  • data (np.ndarray) – the observed data

  • posterior (LDSPosterior) – posteriors over the latent states

  • covariates (PyTree, optional) – optional covariates with leaf shape (B, T, …). Defaults to None.

  • metadata (PyTree, optional) – optional metadata with leaf shape (B, …). Defaults to None.

  • num_samples (int) – number of samples from posterior to use in a generic update

  • key (jr.PRNGKey) – random seed

Returns

emissions (Emissions) – updated emissions object

Return type

PoissonEmissions
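For Poisson emissions there is no closed-form update, which is where a sample-based ("generic") update and num_samples come in. The sketch below shows one such approach, assuming Gaussian posterior marginals and rates exp(c_n^T x_t + d_n): draw num_samples latent trajectories from the posterior marginals, then maximize the sampled log likelihood for each output with Newton's method. The function name and arguments are hypothetical, and the library may use a different optimizer.

```python
import numpy as np

def poisson_emissions_m_step(ys, means, covs, rng, num_samples=10, num_newton=25):
    """Sample-based M step for y_tn ~ Poisson(exp(c_n^T x_t + d_n)).

    ys: (T, N) counts; means: (T, D) and covs: (T, D, D) are Gaussian posterior marginals.
    Assumes every output has at least one nonzero count. Returns C (N, D) and d (N,).
    """
    T, D = means.shape
    N = ys.shape[1]
    # Draw num_samples latent trajectories from the posterior marginals
    L = np.linalg.cholesky(covs)                                     # (T, D, D)
    eps = rng.standard_normal((num_samples, T, D))
    xs = means + np.einsum("tij,stj->sti", L, eps)                   # (S, T, D)
    Z = np.concatenate([xs, np.ones((num_samples, T, 1))], axis=-1).reshape(-1, D + 1)
    Y = np.tile(ys, (num_samples, 1))                                # (S*T, N), same order as Z
    W = np.zeros((N, D + 1))
    for n in range(N):
        w = np.zeros(D + 1)
        w[D] = np.log(Y[:, n].mean())                                # start at the mean rate
        for _ in range(num_newton):
            rate = np.exp(Z @ w)
            grad = Z.T @ (Y[:, n] - rate)                            # gradient of Poisson log likelihood
            neg_hess = Z.T @ (rate[:, None] * Z)
            w = w + np.linalg.solve(neg_hess, grad)                  # Newton step
        W[n] = w
    return W[:, :D], W[:, D]
```

Starting each output at the log of its mean count keeps the first Newton steps small; starting from zero can overshoot badly when the typical rate is far from one.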