Utils and Plotting Functions

Utils

Useful utility functions.

class ssm.utils.Verbosity(value)

Convenience alias class for Verbosity values.

Currently, any value >= 1 corresponds to displaying progress bars for various function calls throughout JAX-SSM.

  • 0: OFF

  • 1: QUIET

  • 2: LOUD

  • 3: DEBUG

ssm.utils.tree_get(tree, idx)

Index into the leaves of the PyTree.

Parameters
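  • tree (PyTree) – PyTree whose leaves are arrays.

  • idx – Index applied to every leaf (e.g. an int or a slice).

Returns

PyTree – PyTree with the same structure, where each leaf is replaced by leaf[idx].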
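A minimal sketch of the behavior described above, assuming the same index is applied to every leaf with `jax.tree_util.tree_map`. The name `tree_get_sketch` is hypothetical and not the library's implementation.

```python
import jax
import jax.numpy as jnp


def tree_get_sketch(tree, idx):
    # Apply the same index to every array leaf; the PyTree structure is preserved.
    return jax.tree_util.tree_map(lambda leaf: leaf[idx], tree)


# Usage: pick out the first entry along the leading axis of every leaf.
params = {"means": jnp.arange(6.0).reshape(3, 2), "weights": jnp.array([0.2, 0.3, 0.5])}
first = tree_get_sketch(params, 0)  # {"means": [0., 1.], "weights": 0.2}
```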

ssm.utils.tree_concatenate(tree1, tree2, axis=0)

Concatenate leaves of two pytrees along specified axis.

Parameters
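  • tree1 (PyTree) – First PyTree; its leaves are arrays.

  • tree2 (PyTree) – Second PyTree, with the same structure as tree1.

  • axis (int) – Axis along which each pair of leaves is concatenated. Defaults to 0.

Returns

PyTree – PyTree with the same structure, whose leaves are the concatenated arrays.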
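A minimal sketch, assuming the two trees share a structure so that `jax.tree_util.tree_map` can pair their leaves. `tree_concatenate_sketch` is a hypothetical name used only for illustration.

```python
import jax
import jax.numpy as jnp


def tree_concatenate_sketch(tree1, tree2, axis=0):
    # tree_map walks both trees in lockstep and concatenates matching leaves.
    return jax.tree_util.tree_map(
        lambda a, b: jnp.concatenate([a, b], axis=axis), tree1, tree2
    )


# Usage: stack two batches of sequences along the leading (batch) axis.
merged = tree_concatenate_sketch({"x": jnp.zeros((2, 3))}, {"x": jnp.ones((4, 3))})
```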

ssm.utils.tree_all_equal(tree1, tree2)

Check PyTree equality when the tree leaves are arrays.

Parameters
  • tree1 (PyTree) – First PyTree with array leaves.

  • tree2 (PyTree) – Second PyTree with array leaves.

Returns

isEqual (bool) – Whether the two array PyTrees are equal.
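A sketch of one way to test equality of array PyTrees, assuming "equal" means identical structure plus element-wise equal leaves of the same shape. `tree_all_equal_sketch` is hypothetical.

```python
import jax
import jax.numpy as jnp


def tree_all_equal_sketch(tree1, tree2):
    # Different structures can never be equal (and tree_map would raise on them).
    if jax.tree_util.tree_structure(tree1) != jax.tree_util.tree_structure(tree2):
        return False
    # jnp.array_equal returns False on shape mismatch instead of broadcasting.
    leaf_equal = jax.tree_util.tree_map(
        lambda a, b: bool(jnp.array_equal(a, b)), tree1, tree2
    )
    return all(jax.tree_util.tree_leaves(leaf_equal))
```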

ssm.utils.ssm_pbar(num_iters, verbose, description, *args)

Return either a progress bar or a regular range to iterate over, depending on verbosity.

Parameters
  • num_iters (int) – The number of iterations for the iterator.

  • verbose (int) – If verbose == 2, returns a trange object; otherwise returns a regular range.

  • description (str) – Description for the progress bar.

  • args – Format arguments for description.
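A sketch of the documented behavior, assuming `trange` comes from tqdm and that `description` is filled in with `str.format(*args)`; both are assumptions, and `ssm_pbar_sketch` is a hypothetical name.

```python
def ssm_pbar_sketch(num_iters, verbose, description, *args):
    """Hypothetical re-implementation of the behavior documented above."""
    if verbose == 2:
        # Imported lazily so tqdm is only needed when a progress bar is requested.
        from tqdm.auto import trange

        return trange(num_iters, desc=description.format(*args))
    return range(num_iters)


# Usage: the loop body is identical whether or not a progress bar is shown.
for itr in ssm_pbar_sketch(3, 0, "EM iteration {}", 0):
    pass
```

Returning a plain `range` for low verbosity keeps callers free of any tqdm dependency.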

ssm.utils.compute_state_overlap(z1, z2, K1=None, K2=None)

Compute a matrix describing the state-wise overlap between two state vectors z1 and z2.

Both state vectors should have shape (T,) and an integer dtype.

Parameters
  • z1 (Sequence[int]) – The first state vector.

  • z2 (Sequence[int]) – The second state vector.

  • K1 (Optional[int]) – Optional upper bound on the number of states to consider for z1.

  • K2 (Optional[int]) – Optional upper bound on the number of states to consider for z2.

Returns

overlap matrix – Matrix of cumulative overlap events.
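A NumPy sketch of the overlap computation, assuming entry (i, j) counts the time steps where z1 == i and z2 == j, and that K1/K2 default to the largest observed state plus one (an assumption). States at or above the bound are simply dropped. `compute_state_overlap_sketch` is hypothetical.

```python
import numpy as np


def compute_state_overlap_sketch(z1, z2, K1=None, K2=None):
    z1, z2 = np.asarray(z1), np.asarray(z2)
    assert z1.shape == z2.shape and z1.ndim == 1, "expected two (T,) state vectors"
    # Assumed default: largest observed state + 1.
    K1 = int(z1.max()) + 1 if K1 is None else K1
    K2 = int(z2.max()) + 1 if K2 is None else K2
    onehot1 = (z1[:, None] == np.arange(K1)).astype(int)  # (T, K1); states >= K1 dropped
    onehot2 = (z2[:, None] == np.arange(K2)).astype(int)  # (T, K2)
    return onehot1.T @ onehot2  # (K1, K2) co-occurrence counts


overlap = compute_state_overlap_sketch([0, 0, 1, 1, 2], [1, 1, 0, 0, 2])
```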

ssm.utils.find_permutation(z1, z2, K1=None, K2=None)

Find the permutation between state vectors z1 and z2 that results in the most overlap.

Useful for recovering the “true” state identities for a discrete-state SSM.

Parameters
  • z1 (Sequence[int]) – The first state vector.

  • z2 (Sequence[int]) – The second state vector.

  • K1 (Optional[int]) – Optional upper bound on the number of states to consider for z1.

  • K2 (Optional[int]) – Optional upper bound on the number of states to consider for z2.

Returns

perm – Permutation of the state indices that maximizes the overlap between z1 and z2.
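A brute-force sketch of finding the best-matching permutation, assuming both vectors use the same number of states K. It searches all K! matchings, so it is only practical for small K; a scalable approach would solve the same assignment problem with the Hungarian algorithm (e.g. `scipy.optimize.linear_sum_assignment`). `find_permutation_sketch` is hypothetical, and here `perm[k]` is the z2 state matched to z1 state k.

```python
import itertools

import numpy as np


def find_permutation_sketch(z1, z2, K=None):
    z1, z2 = np.asarray(z1), np.asarray(z2)
    K = int(max(z1.max(), z2.max())) + 1 if K is None else K
    # overlap[i, j] = number of time steps with z1 == i and z2 == j
    onehot1 = (z1[:, None] == np.arange(K)).astype(int)
    onehot2 = (z2[:, None] == np.arange(K)).astype(int)
    overlap = onehot1.T @ onehot2
    # Exhaustive search over all matchings; keeps the first one with maximal overlap.
    best = max(
        itertools.permutations(range(K)),
        key=lambda perm: sum(overlap[k, perm[k]] for k in range(K)),
    )
    return np.array(best)


z_true = [0, 0, 1, 1, 2]
z_inferred = [1, 1, 0, 0, 2]
perm = find_permutation_sketch(z_true, z_inferred)  # array([1, 0, 2])
relabeled = np.argsort(perm)[z_inferred]            # mapped back onto z_true's labels
```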

ssm.utils.random_rotation(seed, n, theta=None)

Helper function to create a rotating linear system.

Parameters
  • seed (jax.random.PRNGKey) – JAX random seed.

  • n (int) – Dimension of the rotation matrix.

  • theta (float, optional) – Angle of the rotation. If None, a random angle is sampled from a standard Gaussian and scaled by π/2. Defaults to None.

Returns

array – An (n, n) rotation matrix.
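A JAX sketch of one common construction, assumed here rather than taken from the library: place a 2×2 rotation by theta in the top-left block of an identity matrix, then conjugate by a random orthogonal matrix from a QR decomposition. It assumes n >= 2, and `random_rotation_sketch` is hypothetical.

```python
import jax
import jax.numpy as jnp


def random_rotation_sketch(seed, n, theta=None):
    # Assumes n >= 2 so the 2x2 rotation block fits.
    key_theta, key_q = jax.random.split(seed)
    if theta is None:
        # Standard Gaussian angle scaled by pi / 2, as described above.
        theta = 0.5 * jnp.pi * jax.random.normal(key_theta)
    rot = jnp.array([[jnp.cos(theta), -jnp.sin(theta)],
                     [jnp.sin(theta), jnp.cos(theta)]])
    out = jnp.eye(n).at[:2, :2].set(rot)
    # Random orthogonal basis; conjugating keeps the matrix orthogonal with det 1.
    q, _ = jnp.linalg.qr(jax.random.normal(key_q, (n, n)))
    return q @ out @ q.T


A = random_rotation_sketch(jax.random.PRNGKey(0), 4, theta=0.3)
```

Conjugation preserves the spectrum, so the matrix rotates by exactly theta within a random 2D subspace and leaves the orthogonal complement fixed.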

ssm.utils.ensure_has_batch_dim(batched_args=('data', 'posterior', 'covariates', 'metadata'), model_arg='self')

Decorator to automatically add a batch dim to args defined by batched_args.

Note: this decorator makes some strong assumptions about what is passed into the function. Please see details below.

Checks the shape of the PyTree leaves inside the data argument and compares them to the emissions shape defined by the model. A batch dimension is added if the data has exactly one dimension more than the emissions (num_timesteps).

Naively assumes that if data needs a batch dim, then so do the rest of the batched_args.

Parameters
  • batched_args (tuple, optional) – Names of the function arguments to batch. ‘data’ must be an element. Defaults to (“data”, “posterior”, “covariates”, “metadata”).

  • model_arg (str, optional) – The name of the argument of the model class. Used to extract information about the emissions shape. Defaults to “self”.
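A simplified sketch of the shape check described above, assuming the model exposes a hypothetical `emissions_shape` attribute and that only the first leaf of data needs inspecting. `ensure_has_batch_dim_sketch` and `ToyModel` are illustrative names, not the library's code.

```python
import functools
import inspect

import jax
import jax.numpy as jnp


def ensure_has_batch_dim_sketch(
    batched_args=("data", "posterior", "covariates", "metadata"), model_arg="self"
):
    def decorator(func):
        sig = inspect.signature(func)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()
            model = bound.arguments[model_arg]
            emissions_ndim = len(model.emissions_shape)  # hypothetical attribute
            data_leaf = jax.tree_util.tree_leaves(bound.arguments["data"])[0]
            # (num_timesteps, *emissions_shape): exactly one extra dim, so unbatched.
            if data_leaf.ndim == emissions_ndim + 1:
                # Naive assumption: every other batched arg needs a batch dim too.
                for name in batched_args:
                    if bound.arguments.get(name) is not None:
                        bound.arguments[name] = jax.tree_util.tree_map(
                            lambda x: jnp.expand_dims(x, 0), bound.arguments[name]
                        )
            return func(*bound.args, **bound.kwargs)

        return wrapper

    return decorator


class ToyModel:
    emissions_shape = (2,)

    @ensure_has_batch_dim_sketch()
    def data_shape(self, data, covariates=None):
        return data.shape
```

Calling `ToyModel().data_shape(jnp.ones((5, 2)))` would see shape (1, 5, 2), while already-batched data of shape (3, 5, 2) passes through unchanged.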

ssm.utils.auto_batch(batched_args=('data', 'posterior', 'covariates', 'metadata', 'states'), model_arg='self', map_function=<function vmap>)

Decorator to automatically “map” the wrapped function along a batch if a batch dim is detected in the data. By default, “map” means vmap.

Note: this decorator makes some strong assumptions about what is passed into the function. Please see details below.

Checks the shape of the PyTree leaves inside the data argument and compares them to the shape of emissions as defined by the model. The data is considered batched if it has two additional dimensions compared to the emissions (batch_dim and num_timesteps).

The batch dimension should always be the leading dimension, e.g. data should have shape (<batch>, <time>, <emissions_shape>), where the batch dimension is optional.

Naively assumes that if data has a batch dim, then so do the rest of the batched_args.

Parameters
  • batched_args (tuple, optional) – Names of the function arguments that may be batched. ‘data’ must be an element. Defaults to (“data”, “posterior”, “covariates”, “metadata”, “states”).

  • model_arg (str, optional) – The name of the argument of the model class. Used to extract information about the emissions shape. Defaults to “self”.

  • map_function (Callable, optional) – Type of map operation applied to func. Defaults to vmap.
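A simplified sketch of the batch detection and mapping, assuming a hypothetical `emissions_shape` attribute on the model and `jax.vmap` as the map. Only arguments that are present (not None) are mapped; everything else is closed over. `auto_batch_sketch` and `ToyModel` are illustrative names.

```python
import functools
import inspect

import jax
import jax.numpy as jnp


def auto_batch_sketch(
    batched_args=("data", "posterior", "covariates", "metadata", "states"),
    model_arg="self",
    map_function=jax.vmap,
):
    def decorator(func):
        sig = inspect.signature(func)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()
            model = bound.arguments[model_arg]
            data_leaf = jax.tree_util.tree_leaves(bound.arguments["data"])[0]
            # Batched data has two extra dims: (batch, num_timesteps, *emissions_shape).
            if data_leaf.ndim != len(model.emissions_shape) + 2:
                return func(*bound.args, **bound.kwargs)
            batched = {n: bound.arguments[n] for n in batched_args
                       if bound.arguments.get(n) is not None}

            def single(batch_slice):
                # Re-insert one slice of each batched arg, keep the rest unchanged.
                return func(**dict(bound.arguments, **batch_slice))

            return map_function(single)(batched)

        return wrapper

    return decorator


class ToyModel:
    emissions_shape = (2,)

    @auto_batch_sketch()
    def total(self, data):
        return data.sum()
```

With this sketch, `ToyModel().total(jnp.ones((5, 2)))` returns a scalar, while a (3, 5, 2) input returns one total per batch element.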

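ssm.utils.logspace_tensordot(tensor, matrix, axis)

Log-space analogue of a tensordot: contracts dimension axis of tensor (of size m) with the first dimension of matrix.

Parameters
  • tensor ((..., m, ...)-array) – Array whose dimension axis has size m.

  • matrix ((m, n)-array) – Matrix contracted against tensor.

  • axis (int) – Axis of tensor to contract.

Returns

result ((..., n, ...)-array) – Result, with the contracted axis replaced by one of size n.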
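A sketch of a log-space tensordot under a stated assumption: both inputs hold log-values, so the result equals log(tensordot(exp(tensor), exp(matrix))) with the new axis of size n moved back to position axis. The reference above does not say whether matrix is in log space, so treat this as illustrative; `logspace_tensordot_sketch` is hypothetical.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def logspace_tensordot_sketch(log_tensor, log_matrix, axis):
    # Assumption: both inputs are log-values.
    t = jnp.moveaxis(log_tensor, axis, -1)                   # (..., m)
    # Add log-values (i.e. multiply), then log-sum-exp over m for numerical stability.
    out = logsumexp(t[..., :, None] + log_matrix, axis=-2)   # (..., n)
    return jnp.moveaxis(out, -1, axis)                       # put n where m was


t = jnp.linspace(-1.0, 1.0, 24).reshape(2, 3, 4)
M = jnp.linspace(-2.0, 0.5, 15).reshape(3, 5)
out = logspace_tensordot_sketch(t, M, axis=1)  # shape (2, 5, 4)
```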

ssm.utils.test_and_find_inequality(obj_a, obj_b, check_name='shape', mode='input', sig=None)

Iterates through zipped components of obj_a and obj_b to find unequal components.

Prints a message and returns the indices of unequal components in obj_a and obj_b.

Parameters
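  • obj_a – First object (e.g. the previous inputs or outputs).

  • obj_b – Second object, compared component-wise with obj_a.

  • check_name (str, optional) – Name of the property being checked, used in the printed message. Defaults to “shape”.

  • mode (str, optional) – “input” or “output”. Defaults to “input”.

  • sig (inspect.FullArgSpec, optional) – Optional function signature. Used for better debug description. Defaults to None.

Returns

Indices of the unequal components in obj_a and obj_b.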

ssm.utils.check_pytree_structure_match(obj_a, obj_b, mode='input', sig=None)

Checks whether pytrees A and B have the same structure. Used for debugging re-jit problems (see debug_rejit decorator).

Parameters
  • obj_a – pytree obj A (prev)

  • obj_b – pytree obj B (curr)

  • mode (str, optional) – “input” or “output”. Defaults to “input”.

  • sig (inspect.FullArgSpec, optional) – optional function signature. Used for better debug description. Defaults to None.

ssm.utils.check_pytree_shape_match(obj_a, obj_b, mode='input', sig=None)

Checks whether pytrees A and B have the same leaf shapes. Used for debugging re-jit problems (see debug_rejit decorator).

Parameters
  • obj_a (jaxlib.xla_extension.PyTreeDef) – pytree obj A (prev)

  • obj_b (jaxlib.xla_extension.PyTreeDef) – pytree obj B (curr)

  • mode (str, optional) – “input” or “output”. Defaults to “input”.

  • sig (inspect.FullArgSpec, optional) – Signatures are not supported yet.

ssm.utils.check_pytree_weak_type_match(obj_a, obj_b, mode='input', sig=None)

Checks whether pytrees A and B have the same weak_typing. Used for debugging re-jit problems (see debug_rejit decorator).

ssm.utils.check_pytree_dtype_match(obj_a, obj_b, mode='input', sig=None)

Checks whether pytrees A and B have the same dtype. Used for debugging re-jit problems (see debug_rejit decorator).

ssm.utils.check_pytree_match(obj_a, obj_b, mode='input', sig=None)

Checks whether pytrees A and B are the same by checking shape, structure, weak_typing, and dtype.

Used for debugging re-jit problems (see debug_rejit decorator).

Parameters
  • obj_a (jaxlib.xla_extension.PyTreeDef) – pytree structure A (prev)

  • obj_b (jaxlib.xla_extension.PyTreeDef) – pytree structure B (curr)

  • mode (str, optional) – “input” or “output”. Defaults to “input”.

  • sig (inspect.FullArgSpec, optional) – optional function signature. Used for better debug description. Defaults to None.
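A sketch of the kind of comparison these checks perform, limited to structure, shape and dtype (weak typing is omitted here). `describe_pytree_mismatch` is a hypothetical helper, not part of ssm.utils; it returns human-readable reasons why a jitted function would retrace when called with obj_b after obj_a.

```python
import jax
import jax.numpy as jnp


def describe_pytree_mismatch(obj_a, obj_b):
    leaves_a, treedef_a = jax.tree_util.tree_flatten(obj_a)
    leaves_b, treedef_b = jax.tree_util.tree_flatten(obj_b)
    # A structure change alone forces a retrace; leaf-wise checks would be meaningless.
    if treedef_a != treedef_b:
        return [f"structure: {treedef_a} != {treedef_b}"]
    problems = []
    for i, (a, b) in enumerate(zip(leaves_a, leaves_b)):
        a, b = jnp.asarray(a), jnp.asarray(b)
        if a.shape != b.shape:
            problems.append(f"leaf {i} shape: {a.shape} != {b.shape}")
        if a.dtype != b.dtype:
            problems.append(f"leaf {i} dtype: {a.dtype} != {b.dtype}")
    return problems
```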

ssm.utils.debug_rejit(func)

Decorator to debug unexpected re-compilation (re-jitting) of jitted functions.

Alternatively, you can set the JAX flag jax.config.update("jax_log_compiles", True) to log every compilation.

Checks whether the input and output pytrees are consistent across multiple calls to func; if they are not, func has to be re-compiled.

Example:

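import jax.numpy as jnp
from jax import jit
from ssm.utils import debug_rejit

@debug_rejit
@jit
def fn(inputs):
    return 2.0 * inputs

fn(jnp.ones(3))
fn(jnp.ones(4))  # ==> prints a useful description, since the input
                 #     pytrees mismatch (i.e. fn will re-jit)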

Plotting

Useful plotting utility functions.

ssm.plots.gradient_cmap(colors, nsteps=256, bounds=None)

Return a colormap that interpolates between a set of colors. Ported from HIPS-LIB plotting functions [https://github.com/HIPS/hips-lib]

Parameters
  • colors (list) – List of color values (RGB or RGBA tuples).

  • nsteps (int, optional) – Number of steps in the gradient. Defaults to 256.

  • bounds ([type], optional) – [description]. Defaults to None.

Returns

cmap – The gradient colormap.
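A sketch of an interpolating colormap built with matplotlib's `LinearSegmentedColormap.from_list`, swapped in for illustration; this is not the hips-lib port itself, and it places the colors evenly (the `bounds` argument is not modeled). `gradient_cmap_sketch` is hypothetical.

```python
from matplotlib.colors import LinearSegmentedColormap


def gradient_cmap_sketch(colors, nsteps=256):
    # Evenly spaced color stops, linearly interpolated into an nsteps-entry lookup table.
    return LinearSegmentedColormap.from_list("gradient", colors, N=nsteps)


# Usage: a red-to-blue gradient; cmap(0.0) is red and cmap(1.0) is blue (alpha 1).
cmap = gradient_cmap_sketch([(1.0, 0.0, 0.0), (0.0, 0.0, 1.0)])
```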

ssm.plots.plot_dynamics_2d(dynamics_matrix, bias_vector, mins=(-40, -40), maxs=(40, 40), npts=20, axis=None, **kwargs)

Utility to visualize the dynamics of a two-dimensional dynamical system.

Parameters
  • dynamics_matrix – 2x2 numpy array. “A” matrix for the system.

  • bias_vector – “b” vector for the system. Has size (2,).

  • mins – Tuple of minimums for the quiver plot.

  • maxs – Tuple of maximums for the quiver plot.

  • npts – Number of arrows to show.

  • axis – Axis to use for plotting. Defaults to None, in which case a new axis is used.

  • kwargs – keyword args passed to plt.quiver.

Returns

q – quiver object returned by pyplot
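A NumPy sketch of the vector field such a plot could show, assuming a discrete-time system x_{t+1} = A x_t + b whose arrows are the one-step displacements x_{t+1} - x_t on an npts × npts grid. Both the displacement convention and `dynamics_field_sketch` are assumptions for illustration; the outputs can be passed straight to `plt.quiver`.

```python
import numpy as np


def dynamics_field_sketch(dynamics_matrix, bias_vector, mins=(-40, -40), maxs=(40, 40), npts=20):
    A = np.asarray(dynamics_matrix, dtype=float)
    b = np.asarray(bias_vector, dtype=float)
    assert A.shape == (2, 2) and b.shape == (2,)
    x, y = np.meshgrid(np.linspace(mins[0], maxs[0], npts),
                       np.linspace(mins[1], maxs[1], npts))
    xy = np.column_stack([x.ravel(), y.ravel()])  # (npts**2, 2) grid points
    dxy = xy @ A.T + b - xy                       # one-step displacement x_{t+1} - x_t
    return x, y, dxy[:, 0].reshape(x.shape), dxy[:, 1].reshape(x.shape)


# To draw it: plt.quiver(*dynamics_field_sketch(A, b))
```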