Utils and Plotting Functions
Utils
Useful utility functions.
- class ssm.utils.Verbosity(value)
Convenience alias class for Verbosity values.
Currently, any value >= 1 corresponds to displaying progress bars for various function calls through JAX-SSM.
0: OFF
1: QUIET
2: LOUD
3: DEBUG
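The level mapping above can be mirrored with a standard-library `IntEnum`, so that levels compare directly against plain integers. This is a minimal sketch, not the library's class: `VerbositySketch` and `shows_progress_bars` are hypothetical names, and the ">= 1 shows progress bars" rule is taken from the description above.

```python
from enum import IntEnum


class VerbositySketch(IntEnum):
    """Hypothetical stand-in mirroring the documented Verbosity levels."""

    OFF = 0
    QUIET = 1
    LOUD = 2
    DEBUG = 3


def shows_progress_bars(verbose: int) -> bool:
    # Per the description above: any value >= 1 displays progress bars.
    return verbose >= VerbositySketch.QUIET


print(shows_progress_bars(VerbositySketch.OFF))   # False
print(shows_progress_bars(VerbositySketch.LOUD))  # True
```

Because `IntEnum` members are integers, callers can pass either `VerbositySketch.LOUD` or `2`.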
- ssm.utils.tree_get(tree, idx)
Index the leaves of the PyTree.
- Parameters
tree (PyTree) – The PyTree whose leaves are indexed.
idx – Index applied to every leaf.
- Returns
PyTree – A PyTree with the same structure, where each leaf is leaf[idx].
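Indexing every leaf of a PyTree can be expressed with `jax.tree_util.tree_map`. This is a minimal sketch under that assumption; `tree_get_sketch` is a hypothetical name, not the library function, and the library's implementation may differ.

```python
import jax.numpy as jnp
import jax.tree_util as jtu


def tree_get_sketch(tree, idx):
    """Apply the same index to every leaf of a PyTree (hypothetical helper)."""
    return jtu.tree_map(lambda leaf: leaf[idx], tree)


params = {
    "means": jnp.arange(6.0).reshape(3, 2),
    "weights": jnp.array([0.2, 0.3, 0.5]),
}
# Select the parameters of the first component from every leaf at once.
first = tree_get_sketch(params, 0)
print(first["means"], first["weights"])
```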
- ssm.utils.tree_concatenate(tree1, tree2, axis=0)
Concatenate leaves of two pytrees along specified axis.
- Parameters
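tree1 (PyTree) – The first PyTree.
tree2 (PyTree) – The second PyTree, with the same structure as tree1.
axis (int) – Axis along which the leaves are concatenated. Defaults to 0.
- Returns
PyTree – A PyTree whose leaves are the concatenated leaves of tree1 and tree2.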
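A sketch of leaf-wise concatenation, assuming both trees share the same structure (`tree_map` over two trees raises an error otherwise). `tree_concatenate_sketch` is a hypothetical name used only for illustration.

```python
import jax.numpy as jnp
import jax.tree_util as jtu


def tree_concatenate_sketch(tree1, tree2, axis=0):
    """Concatenate matching leaves of two PyTrees along `axis` (hypothetical helper)."""
    return jtu.tree_map(lambda a, b: jnp.concatenate([a, b], axis=axis), tree1, tree2)


batch_a = {"x": jnp.zeros((2, 3)), "y": jnp.zeros((2,))}
batch_b = {"x": jnp.ones((4, 3)), "y": jnp.ones((4,))}
merged = tree_concatenate_sketch(batch_a, batch_b)
print(merged["x"].shape, merged["y"].shape)  # (6, 3) (6,)
```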
- ssm.utils.tree_all_equal(tree1, tree2)
Check PyTree equality when the tree leaves are arrays.
- Parameters
tree1 (PyTree) – The first PyTree, with array leaves.
tree2 (PyTree) – The second PyTree, with array leaves.
- Returns
isEqual (bool) – Whether the array PyTrees are equal.
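One way to express this check is to compare tree structures first and then compare leaves pairwise with `jnp.array_equal`, which also returns False on shape mismatch. A minimal sketch under that assumption; `tree_all_equal_sketch` is a hypothetical name.

```python
import jax.numpy as jnp
import jax.tree_util as jtu


def tree_all_equal_sketch(tree1, tree2):
    """True if both PyTrees share a structure and every pair of leaves is equal."""
    if jtu.tree_structure(tree1) != jtu.tree_structure(tree2):
        return False
    return all(
        bool(jnp.array_equal(a, b))
        for a, b in zip(jtu.tree_leaves(tree1), jtu.tree_leaves(tree2))
    )


a = {"w": jnp.ones(3)}
b = {"w": jnp.ones(3)}
c = {"w": jnp.zeros(3)}
print(tree_all_equal_sketch(a, b), tree_all_equal_sketch(a, c))  # True False
```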
- ssm.utils.ssm_pbar(num_iters, verbose, description, *args)
Return either progress bar or regular range for iterating depending on verbosity.
- Parameters
num_iters (int) – The number of iterations for the iterator.
verbose (int) – If verbose == 2, return a trange object; otherwise, return range.
description (str) – Description for the progress bar.
args – Format arguments for description.
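The behavior described above can be sketched as a function that returns either a `tqdm` progress bar or a plain `range`, so the calling loop is identical in both cases. This is a hypothetical re-implementation (`ssm_pbar_sketch`), assuming `tqdm.auto.trange` as the progress bar; the `verbose == 2` condition is taken from the parameter description.

```python
def ssm_pbar_sketch(num_iters, verbose, description="", *args):
    """Return a tqdm progress bar when verbose == 2, else a plain range.

    tqdm is imported lazily, only when a progress bar is requested.
    """
    if verbose == 2:
        from tqdm.auto import trange

        return trange(num_iters, desc=description.format(*args))
    return range(num_iters)


total = 0
# With verbose=0 this is just range(5); the loop body is unchanged either way.
for i in ssm_pbar_sketch(5, 0, "iter {}", 1):
    total += i
print(total)  # 10
```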
- ssm.utils.compute_state_overlap(z1, z2, K1=None, K2=None)
Compute a matrix describing the state-wise overlap between two state vectors z1 and z2.
The state vectors should both be of shape (T,) and be integer typed.
- Parameters
z1 (Sequence[int]) – The first state vector.
z2 (Sequence[int]) – The second state vector.
K1 (Optional[int]) – Optional upper bound of states to consider for z1.
K2 (Optional[int]) – Optional upper bound of states to consider for z2.
- Returns
overlap matrix – Matrix of cumulative overlap events.
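An overlap matrix of this kind counts co-occurrences: entry (i, j) is the number of time steps at which z1 is in state i and z2 is in state j. The sketch below computes it with one-hot encodings; `compute_state_overlap_sketch` is a hypothetical name, and treating labels >= K as ignored is an assumption about how the "upper bound" K1/K2 behaves.

```python
import numpy as np


def compute_state_overlap_sketch(z1, z2, K1=None, K2=None):
    """overlap[i, j] = number of time steps t with z1[t] == i and z2[t] == j."""
    z1 = np.asarray(z1, dtype=int)
    z2 = np.asarray(z2, dtype=int)
    assert z1.shape == z2.shape and z1.ndim == 1, "expected two (T,) state vectors"
    K1 = int(z1.max()) + 1 if K1 is None else K1
    K2 = int(z2.max()) + 1 if K2 is None else K2
    # One-hot encodings of shape (T, K); labels >= K match no column and are ignored.
    onehot1 = (z1[:, None] == np.arange(K1)).astype(int)
    onehot2 = (z2[:, None] == np.arange(K2)).astype(int)
    # (K1, T) @ (T, K2) sums the co-occurrences over time.
    return onehot1.T @ onehot2


z1 = [0, 0, 1, 1, 2]
z2 = [1, 1, 0, 2, 2]
print(compute_state_overlap_sketch(z1, z2))
```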
- ssm.utils.find_permutation(z1, z2, K1=None, K2=None)
Find the permutation between state vectors z1 and z2 that results in the most overlap.
Useful for recovering the “true” state identities for a discrete-state SSM.
- Parameters
z1 (Sequence[int]) – The first state vector.
z2 (Sequence[int]) – The second state vector.
K1 (Optional[int]) – Optional upper bound of states to consider for z1.
K2 (Optional[int]) – Optional upper bound of states to consider for z2.
- Returns
perm – Permutation of the state indices that maximizes the overlap between z1 and z2.
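Finding the best-matching permutation is an assignment problem on the overlap matrix. The sketch below swaps in a brute-force search over all K! permutations, which is only practical for small K; a Hungarian-algorithm solver scales far better. `find_permutation_sketch`, the single shared `K`, and the convention that `perm[k]` is the z2 state matched to z1 state k are all assumptions for illustration.

```python
import itertools

import numpy as np


def find_permutation_sketch(z1, z2, K):
    """Brute-force the perm that maximizes sum_k overlap[k, perm[k]].

    Assumes both vectors use labels 0..K-1.
    """
    z1, z2 = np.asarray(z1), np.asarray(z2)
    states = np.arange(K)
    # overlap[i, j] = number of time steps with z1 == i and z2 == j.
    overlap = (z1[:, None] == states).T.astype(int) @ (z2[:, None] == states).astype(int)
    best = max(
        itertools.permutations(range(K)),
        key=lambda perm: overlap[states, list(perm)].sum(),
    )
    return np.array(best)


true_z = [0, 0, 1, 1, 2, 2]
inferred_z = [2, 2, 0, 0, 1, 1]
perm = find_permutation_sketch(true_z, inferred_z, K=3)
print(perm)  # [2 0 1]: true state k corresponds to inferred state perm[k]
# Relabel the inferred states so they line up with the true ones.
relabeled = np.argsort(perm)[np.asarray(inferred_z)]
print(relabeled)  # [0 0 1 1 2 2]
```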
- ssm.utils.random_rotation(seed, n, theta=None)
Helper function to create a rotating linear system.
- Parameters
seed (jax.random.PRNGKey) – JAX random seed.
n (int) – Dimension of the rotation matrix.
theta (float, optional) – If specified, this is the angle of the rotation; otherwise, a random angle is sampled from a standard Gaussian and scaled by π/2. Defaults to None.
- Returns
rotation ((n, n)-array) – The rotation matrix.
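A common way to build such a matrix is to place a 2-D rotation by theta in the top-left block of the identity and conjugate it by a random orthogonal matrix, which keeps the result orthogonal with determinant 1. This is a NumPy sketch under that assumption, using a NumPy `Generator` in place of a JAX PRNGKey; `random_rotation_sketch` is hypothetical and the library's construction may differ.

```python
import numpy as np


def random_rotation_sketch(rng, n, theta=None):
    """Random n x n rotation: a 2-D rotation by theta in a random orthonormal basis."""
    if theta is None:
        # As described above: a standard Gaussian draw scaled by pi / 2.
        theta = np.pi / 2 * rng.standard_normal()
    block = np.eye(n)
    if n >= 2:
        c, s = np.cos(theta), np.sin(theta)
        block[:2, :2] = [[c, -s], [s, c]]
    # Random orthogonal change of basis from a QR decomposition.
    q, _ = np.linalg.qr(rng.standard_normal((n, n)))
    return q @ block @ q.T


rng = np.random.default_rng(0)
A = random_rotation_sketch(rng, n=4, theta=0.1)
print(np.allclose(A @ A.T, np.eye(4)))  # True
```

Used as a dynamics matrix, a small theta gives a slow rotation in one random plane while leaving the remaining directions fixed.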
- ssm.utils.ensure_has_batch_dim(batched_args=('data', 'posterior', 'covariates', 'metadata'), model_arg='self')
Decorator to automatically add a batch dim to args defined by batched_args.
Note: this decorator makes some strong assumptions about what is passed into the function. Please see details below.
Checks the shape of the PyTree leaves inside the data argument and compares them to the shape of emissions as defined by the model. A batch dimension is added if the leaves have exactly one more dimension than the emissions (num_timesteps).
Naively assumes that if data needs a batch dim, then so do the rest of the batched_args.
- Parameters
batched_args (tuple, optional) – Names of the function arguments to batch. ‘data’ must be an element. Defaults to (“data”, “posterior”, “covariates”, “metadata”).
model_arg (str, optional) – The name of the argument of the model class. Used to extract information about the emissions shape. Defaults to “self”.
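The shape rule above can be illustrated on a single array: unbatched data has shape (num_timesteps, *emissions_shape), one dimension more than the emissions, so a leading batch axis of size 1 is added. `add_batch_dim_if_needed` is a hypothetical helper; the real decorator applies the rule to PyTree leaves and to every argument in batched_args.

```python
import numpy as np


def add_batch_dim_if_needed(data, emissions_shape):
    """Add a leading batch axis when data has exactly one extra (time) dimension."""
    data = np.asarray(data)
    if data.ndim == len(emissions_shape) + 1:
        return data[None, ...]
    return data


emissions_shape = (3,)
print(add_batch_dim_if_needed(np.zeros((100, 3)), emissions_shape).shape)     # (1, 100, 3)
print(add_batch_dim_if_needed(np.zeros((8, 100, 3)), emissions_shape).shape)  # (8, 100, 3)
```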
- ssm.utils.auto_batch(batched_args=('data', 'posterior', 'covariates', 'metadata', 'states'), model_arg='self', map_function=<function vmap>)
Decorator to automatically “map” the wrapped function along a batch if a batch dim is detected in the data. By default, “map” means vmap.
Note: this decorator makes some strong assumptions about what is passed into the function. Please see details below.
Checks the shape of the PyTree leaves inside the data argument and compares them to the shape of emissions as defined by the model. The data is considered batched if it has two additional dimensions compared to the emissions (batch_dim and num_timesteps).
Batch dimensions should always be the leading dimension. For example, data should have shape (<batch>, <time>, <emissions_shape>), where the batch dim is optional.
Naively assumes that if data has a batch dim, then so do the rest of the batched_args.
- Parameters
batched_args (tuple, optional) – Names of the function arguments that may be batched. ‘data’ must be an element. Defaults to (“data”, “posterior”, “covariates”, “metadata”, “states”).
model_arg (str, optional) – The name of the argument of the model class. Used to extract information about the emissions shape. Defaults to “self”.
map_function (Callable, optional) – Type of map operation applied to func. Defaults to vmap.
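The detection rule above (batched data has two more dimensions than the emissions) can be sketched as a small decorator that applies `jax.vmap` only when a batch axis is present. This simplified version handles a single `data` argument and takes the emissions rank as a parameter instead of reading it from the model; `auto_batch_sketch` and `total_emission` are hypothetical names.

```python
import functools

import jax
import jax.numpy as jnp


def auto_batch_sketch(emissions_ndim, map_function=jax.vmap):
    """Map `func` over a leading batch axis when data has two extra dims."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(data):
            # (batch, time, *emissions_shape) -> map over the batch axis.
            if data.ndim == emissions_ndim + 2:
                return map_function(func)(data)
            # (time, *emissions_shape) -> call directly.
            return func(data)
        return wrapper
    return decorator


@auto_batch_sketch(emissions_ndim=1)
def total_emission(data):
    # Sum over time and emission dimensions of a single sequence.
    return jnp.sum(data)


print(total_emission(jnp.ones((10, 3))))           # 30.0
print(total_emission(jnp.ones((4, 10, 3))).shape)  # (4,)
```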
- ssm.utils.logspace_tensordot(tensor, matrix, axis)
- Parameters
tensor ((..., m, ...)-array) –
matrix ((m, n)-array) –
axis (int) –
- Returns
result ((…, n, …)-array)
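Judging from the shapes above, this appears to be the log-space analogue of `tensordot`: result[..., n, ...] = log Σ_m exp(tensor[..., m, ...] + matrix[m, n]). The NumPy sketch below assumes that reading and computes it stably with the max-subtraction trick, then checks it against the naive formula; `logspace_tensordot_sketch` is a hypothetical name.

```python
import numpy as np


def logspace_tensordot_sketch(tensor, matrix, axis):
    """log(tensordot(exp(tensor), exp(matrix))) contracting `axis`, computed stably."""
    # Move the contracted axis to the end: (..., m).
    t = np.moveaxis(tensor, axis, -1)
    # Broadcast to (..., m, n) and subtract the per-column max before exponentiating.
    summand = t[..., :, None] + matrix
    peak = np.max(summand, axis=-2, keepdims=True)
    out = np.log(np.sum(np.exp(summand - peak), axis=-2)) + np.squeeze(peak, axis=-2)
    # Put the new axis of size n where the contracted axis was.
    return np.moveaxis(out, -1, axis)


rng = np.random.default_rng(0)
tensor = rng.normal(size=(2, 4, 5))  # m = 4 on axis 1
matrix = rng.normal(size=(4, 3))     # (m, n)
result = logspace_tensordot_sketch(tensor, matrix, axis=1)
naive = np.tensordot(np.exp(tensor), np.exp(matrix), axes=([1], [0]))  # (2, 5, 3)
reference = np.log(np.moveaxis(naive, -1, 1))                           # (2, 3, 5)
print(result.shape, np.allclose(result, reference))  # (2, 3, 5) True
```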
- ssm.utils.test_and_find_inequality(obj_a, obj_b, check_name='shape', mode='input', sig=None)
Iterates through zipped components of obj_a and obj_b to find inequality.
Prints a message and returns the indices of unequal components in obj_a and obj_b.
- Parameters
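obj_a – First collection of components to compare (e.g., from the previous call).
obj_b – Second collection of components to compare (e.g., from the current call).
check_name (str, optional) – Name of the property being compared, used in the printed message. Defaults to “shape”.
mode (str, optional) – “input” or “output”. Defaults to “input”.
sig (inspect.FullArgSpec, optional) – Optional function signature. Used for better debug description. Defaults to None.
- Returns
Indices of the unequal components in obj_a and obj_b.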
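The zip-and-compare loop described above can be sketched in a few lines of plain Python. `find_unequal_components` is a hypothetical helper; it returns a single list of indices (the same positions in both inputs) and omits the mode and sig arguments.

```python
def find_unequal_components(obj_a, obj_b, check_name="shape"):
    """Print and return the indices where zipped components of obj_a and obj_b differ."""
    unequal = []
    for idx, (a, b) in enumerate(zip(obj_a, obj_b)):
        if a != b:
            print(f"{check_name} mismatch at index {idx}: {a} != {b}")
            unequal.append(idx)
    return unequal


prev_shapes = [(10, 3), (10,), (3, 3)]
curr_shapes = [(10, 3), (12,), (3, 3)]
print(find_unequal_components(prev_shapes, curr_shapes))  # [1]
```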
- ssm.utils.check_pytree_structure_match(obj_a, obj_b, mode='input', sig=None)
Checks whether pytrees A and B have the same structure. Used for debugging re-jit problems (see debug_rejit decorator).
- Parameters
obj_a – pytree obj A (prev)
obj_b – pytree obj B (curr)
mode (str, optional) – “input” or “output”. Defaults to “input”.
sig (inspect.FullArgSpec, optional) – optional function signature. Used for better debug description. Defaults to None.
- ssm.utils.check_pytree_shape_match(obj_a, obj_b, mode='input', sig=None)
Checks whether pytrees A and B have the same leaf shapes. Used for debugging re-jit problems (see debug_rejit decorator).
- Parameters
obj_a (jaxlib.xla_extension.PyTreeDef) – pytree obj A (prev)
obj_b (jaxlib.xla_extension.PyTreeDef) – pytree obj B (curr)
mode (str, optional) – “input” or “output”. Defaults to “input”.
sig (inspect.FullArgSpec, optional) – doesn’t support signature yet.
- ssm.utils.check_pytree_weak_type_match(obj_a, obj_b, mode='input', sig=None)
Checks whether pytrees A and B have the same weak_typing. Used for debugging re-jit problems (see debug_rejit decorator).
- ssm.utils.check_pytree_dtype_match(obj_a, obj_b, mode='input', sig=None)
Checks whether pytrees A and B have the same dtype. Used for debugging re-jit problems (see debug_rejit decorator).
- ssm.utils.check_pytree_match(obj_a, obj_b, mode='input', sig=None)
Checks whether pytrees A and B are the same by checking shape, structure, weak_typing, and dtype.
Used for debugging re-jit problems (see debug_rejit decorator).
- Parameters
obj_a (jaxlib.xla_extension.PyTreeDef) – pytree structure A (prev)
obj_b (jaxlib.xla_extension.PyTreeDef) – pytree structure B (curr)
mode (str, optional) – “input” or “output”. Defaults to “input”.
sig (inspect.FullArgSpec, optional) – optional function signature. Used for better debug description. Defaults to None.
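The checks above (structure, then per-leaf shape and dtype) can be combined into one helper that lists what changed between two calls. This is a simplified sketch built on `jax.tree_util`; `pytree_mismatches` is a hypothetical name, and it omits the weak-typing check and the printed descriptions of the real functions.

```python
import jax.numpy as jnp
import jax.tree_util as jtu


def pytree_mismatches(obj_a, obj_b):
    """List the differences between two pytrees that would force a re-jit."""
    if jtu.tree_structure(obj_a) != jtu.tree_structure(obj_b):
        return ["structure"]
    problems = []
    for i, (a, b) in enumerate(zip(jtu.tree_leaves(obj_a), jtu.tree_leaves(obj_b))):
        a, b = jnp.asarray(a), jnp.asarray(b)
        if a.shape != b.shape:
            problems.append(f"shape of leaf {i}")
        if a.dtype != b.dtype:
            problems.append(f"dtype of leaf {i}")
    return problems


prev = {"x": jnp.zeros((3,)), "y": jnp.zeros((2, 2))}
curr = {"x": jnp.zeros((4,)), "y": jnp.zeros((2, 2), dtype=jnp.int32)}
print(pytree_mismatches(prev, curr))  # ['shape of leaf 0', 'dtype of leaf 1']
```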
- ssm.utils.debug_rejit(func)
Decorator to debug re-jitting errors.
You can also set the JAX flag jax.config.update("jax_log_compiles", True) to log each compilation.
Checks whether the input and output pytrees are consistent across multiple calls to func (otherwise, func will need to be re-compiled).
Example:
    from jax import jit
    from ssm.utils import debug_rejit

    @debug_rejit
    @jit
    def fn(inputs):
        return 2 * inputs
    # ==> prints a useful description when the input/output
    # pytrees mismatch (i.e., when fn will re-jit)
Plotting
Useful plotting utility functions.
- ssm.plots.gradient_cmap(colors, nsteps=256, bounds=None)
Return a colormap that interpolates between a set of colors. Ported from HIPS-LIB plotting functions [https://github.com/HIPS/hips-lib]
- Parameters
colors (list) – List of color values (RGB or RGBA tuples).
nsteps (int, optional) – Number of steps in the gradient. Defaults to 256.
bounds (list, optional) – Positions in [0, 1] at which each color is placed. Defaults to None (evenly spaced).
- Returns
cmap – The gradient colormap.
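A gradient colormap of this kind can be built with Matplotlib's `LinearSegmentedColormap.from_list`, which accepts (position, color) pairs. This is a minimal sketch, not the ported HIPS-LIB code; `gradient_cmap_sketch` is hypothetical, and placing colors evenly in [0, 1] when bounds is None is an assumption.

```python
import numpy as np
from matplotlib.colors import LinearSegmentedColormap


def gradient_cmap_sketch(colors, nsteps=256, bounds=None):
    """Linear colormap through `colors`, placed at `bounds` (evenly spaced if None)."""
    if bounds is None:
        bounds = np.linspace(0, 1, len(colors))
    # from_list accepts (position, color) pairs; positions must run from 0 to 1.
    return LinearSegmentedColormap.from_list(
        "grad_colormap", list(zip(bounds, colors)), N=nsteps
    )


red, white, blue = (1.0, 0.0, 0.0), (1.0, 1.0, 1.0), (0.0, 0.0, 1.0)
cmap = gradient_cmap_sketch([red, white, blue])
print(cmap(0.0), cmap(1.0))
```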
- ssm.plots.plot_dynamics_2d(dynamics_matrix, bias_vector, mins=(-40, -40), maxs=(40, 40), npts=20, axis=None, **kwargs)
Utility to visualize the dynamics of a 2-dimensional dynamical system.
- Parameters
dynamics_matrix – 2x2 numpy array. “A” matrix for the system.
bias_vector – “b” vector for the system. Has size (2,).
mins – Tuple of minimums for the quiver plot.
maxs – Tuple of maximums for the quiver plot.
npts – Number of arrows to show along each axis of the grid.
axis – Axis to use for plotting. Defaults to None, and returns a new axis.
kwargs – keyword args passed to plt.quiver.
- Returns
q – quiver object returned by pyplot
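A sketch of such a plot, assuming discrete-time dynamics x_{t+1} = A x_t + b with each arrow showing the one-step displacement A x + b - x on an npts × npts grid. `plot_dynamics_2d_sketch` is hypothetical; it falls back to the current pyplot axis when `axis` is None and uses the off-screen Agg backend so it runs without a display.

```python
import matplotlib

matplotlib.use("Agg")  # render off-screen so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np


def plot_dynamics_2d_sketch(A, b, mins=(-40, -40), maxs=(40, 40), npts=20, axis=None, **kwargs):
    """Quiver plot of the one-step displacement A x + b - x on an npts x npts grid."""
    x_grid, y_grid = np.meshgrid(
        np.linspace(mins[0], maxs[0], npts), np.linspace(mins[1], maxs[1], npts)
    )
    xy = np.column_stack((x_grid.ravel(), y_grid.ravel()))  # (npts**2, 2)
    dxy = xy @ A.T + b - xy
    ax = axis if axis is not None else plt.gca()
    return ax.quiver(xy[:, 0], xy[:, 1], dxy[:, 0], dxy[:, 1], **kwargs)


theta = np.pi / 20
# A slowly decaying rotation: trajectories spiral toward the origin.
A = 0.98 * np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
q = plot_dynamics_2d_sketch(A, np.zeros(2), npts=15, color="gray")
```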