Laplace EM Documentation

Laplace EM (for non-conjugate LDS models such as GLM-LDS)

ssm.inference.laplace_em.laplace_approximation(model, data, initial_states, laplace_mode_fit_method='L-BFGS', num_laplace_mode_iters=10)

Laplace approximation to the posterior distribution for state space models with continuous latent states.

ssm.inference.laplace_em.laplace_em(key, model, data, num_iters=100, num_elbo_samples=1, num_approx_m_iters=100, laplace_mode_fit_method='L-BFGS', num_laplace_mode_iters=100, tol=0.0001, verbosity=Verbosity.DEBUG)

Fit state space models, such as an LDS with GLM emissions, using Laplace EM. The state space models must have continuous latent states. Ideally, the log joint probability should also be concave in the latent states so that the posterior mode is unique.

Laplace EM approximates the posterior as a Gaussian whose mean is set to the mode of the posterior and whose covariance is set to the inverse of the negative Hessian of the log posterior (its curvature) evaluated at that mode.
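The following is a minimal sketch of a Laplace approximation on a toy model, not the library's implementation. It assumes a 2-dimensional latent x with a standard normal prior and Poisson observations with rates exp(C x); the matrix C, the counts y, and the plain Newton iteration are all illustrative choices (the library instead exposes L-BFGS, BFGS, and Adam for finding the mode).

```python
import jax
import jax.numpy as jnp

# Hypothetical toy problem: 3 Poisson counts driven by a 2-D latent state.
C = jnp.array([[1.0, 0.5], [-0.5, 1.0], [0.3, 0.3]])  # assumed emission matrix
y = jnp.array([3.0, 1.0, 2.0])                        # assumed observed counts


def log_joint(x):
    """log p(x, y) up to a constant in x; concave in x for this model."""
    log_prior = -0.5 * jnp.dot(x, x)             # x ~ N(0, I)
    activations = C @ x
    log_lik = jnp.dot(y, activations) - jnp.sum(jnp.exp(activations))
    return log_prior + log_lik


grad_fn = jax.grad(log_joint)
hess_fn = jax.hessian(log_joint)

# Find the posterior mode with Newton's method (illustrative optimizer choice).
x = jnp.zeros(2)
for _ in range(25):
    x = x - jnp.linalg.solve(hess_fn(x), grad_fn(x))

mode = x                                  # mean of the Gaussian approximation
cov = jnp.linalg.inv(-hess_fn(mode))      # covariance = inverse negative Hessian
```

In a time-series model the Hessian of the log joint with respect to the full latent trajectory is block tridiagonal, which is presumably why the returned posterior is a MultivariateNormalBlockTridiag rather than a dense Gaussian as in this sketch.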

Note that because Laplace EM does not use the true posterior in the E-step, the marginal log probability is not guaranteed to increase at every iteration, unlike in exact EM.

Parameters
  • key (jax.random.PRNGKey) – JAX random key.

  • model (SSM) – The SSM model object to be fit.

  • data (array, (num_timesteps, obs_dim)) – The observed data.

  • num_iters (int, optional) – Number of iterations to run the Laplace EM algorithm. Defaults to 100.

  • num_elbo_samples (int, optional) – Number of Monte Carlo samples used to compute the ELBO expectation. Defaults to 1.

  • laplace_mode_fit_method (str, optional) – Optimization method used to compute the mode for the Laplace approximation. Must be one of [“L-BFGS”, “BFGS”, “Adam”]. Defaults to “L-BFGS”.

  • num_laplace_mode_iters (int, optional) – Only relevant when laplace_mode_fit_method is “Adam”. Specifies the number of Adam update iterations to run. Large values make JIT compilation slow. Defaults to 100.

  • tol (float, optional) – Tolerance to determine convergence of ELBO. Defaults to 1e-4.

  • verbosity (Verbosity, optional) – See Verbosity. Defaults to Verbosity.DEBUG.

  • num_approx_m_iters (int, optional) – Number of iterations to run for each approximate M-step. Defaults to 100.

Returns
  • elbos (array, (num_iters,)) – The ELBO objective per iteration. Ideally, this should increase as the model is fit.

  • ssm (SSM) – The fitted SSM object after running Laplace EM.

  • posterior (MultivariateNormalBlockTridiag) – The corresponding posterior distribution of the fitted SSM.
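
To illustrate what num_elbo_samples controls, here is a rough sketch of a Monte Carlo ELBO estimate, E_q[log p(x, y) - log q(x)], for a one-dimensional toy model. It is not the library's code: the model (standard normal prior, one Poisson count with rate exp(x)), the hand-picked values q_mean and q_var, and the helper elbo_estimate are all assumptions made for the example.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln
from jax.scipy.stats import norm

y = 3.0  # a single observed count (toy data)


def log_joint(x):
    """log p(x, y) with prior x ~ N(0, 1) and likelihood y ~ Poisson(exp(x))."""
    log_prior = norm.logpdf(x, loc=0.0, scale=1.0)
    log_lik = y * x - jnp.exp(x) - gammaln(y + 1.0)
    return log_prior + log_lik


# Gaussian approximation q(x) = N(q_mean, q_var). These values are picked by
# hand to sit near the posterior mode of this toy model (an assumption).
q_mean, q_var = 0.79, 0.31


def elbo_estimate(key, num_samples):
    """Average of log p(x, y) - log q(x) over samples x ~ q."""
    x = q_mean + jnp.sqrt(q_var) * jax.random.normal(key, (num_samples,))
    log_q = norm.logpdf(x, loc=q_mean, scale=jnp.sqrt(q_var))
    return jnp.mean(log_joint(x) - log_q)


key_one, key_many = jax.random.split(jax.random.PRNGKey(0))
elbo_one = elbo_estimate(key_one, 1)         # analogous to num_elbo_samples=1
elbo_many = elbo_estimate(key_many, 10_000)  # lower-variance estimate
```

With a single sample the estimate is cheap but noisy, so in a setting like this one some of the iteration-to-iteration fluctuation in a recorded ELBO can come from sampling noise rather than from the fit itself; increasing the sample count trades compute for a smoother curve.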