Source code for diffusionlab.distributions.gmm.low_rank_gmm

from typing import Tuple, cast
from dataclasses import dataclass
from jax import numpy as jnp, Array
import jax
from diffusionlab.distributions.base import Distribution
from diffusionlab.distributions.gmm.utils import (
    _logdet_psd,
    _lstsq,
    create_gmm_vector_field_fns,
)
from diffusionlab.dynamics import DiffusionProcess


[docs] @dataclass(frozen=True) class LowRankGMM(Distribution): """ Implements a low-rank Gaussian Mixture Model (GMM) distribution. The probability measure is given by: ``μ(A) = sum_{i=1}^{num_components} priors[i] * N(A; means[i], cov_factors[i] @ cov_factors[i].T)`` This class provides methods for sampling from the GMM and computing various vector fields (``VectorFieldType.SCORE``, ``VectorFieldType.X0``, ``VectorFieldType.EPS``, ``VectorFieldType.V``) related to the distribution under a given diffusion process. Attributes: dist_params (``Dict[str, Array]``): Dictionary containing the core low-rank GMM parameters. - ``means`` (``Array[num_components, data_dim]``): The means of the GMM components. - ``cov_factors`` (``Array[num_components, data_dim, rank]``): The low-rank covariance matrix factors of the GMM components. - ``priors`` (``Array[num_components]``): The prior probabilities (mixture weights) of the GMM components. dist_hparams (``Dict[str, Any]``): Dictionary for storing hyperparameters (currently unused). """ def __init__(self, means: Array, cov_factors: Array, priors: Array): """ Initializes the low-rank GMM distribution. Args: means (``Array[num_components, data_dim]``): Means for each Gaussian component. cov_factors (``Array[num_components, data_dim, rank]``): Low-rank covariance matrices for each Gaussian component. priors (``Array[num_components]``): Mixture weights for each component. Must sum to 1. """ eps = cast(float, jnp.finfo(cov_factors.dtype).eps) assert means.ndim == 2 num_components, data_dim, rank = cov_factors.shape assert means.shape == (num_components, data_dim) assert priors.shape == (num_components,) assert jnp.isclose(jnp.sum(priors), 1.0, atol=eps) super().__init__( dist_params={ "means": means, "cov_factors": cov_factors, "priors": priors, }, dist_hparams={}, )
[docs] def sample(self, key: Array, num_samples: int) -> Tuple[Array, Array]: """ Draws samples from the low-rank GMM distribution. Args: key (``Array``): JAX PRNG key for random sampling. num_samples (``int``): The total number of samples to generate. Returns: ``Tuple[Array[num_samples, data_dim], Array[num_samples]]``: A tuple ``(samples, component_indices)`` containing the drawn samples and the index of the GMM component from which each sample was drawn. """ num_components, data_dim = self.dist_params["means"].shape key, key_cat, key_norm = jax.random.split(key, 3) component_indices = jax.random.categorical( key_cat, jnp.log(self.dist_params["priors"]), shape=(num_samples,) ) # (num_samples,) chosen_means = self.dist_params["means"][ component_indices ] # (num_samples, data_dim) chosen_cov_factors = self.dist_params["cov_factors"][ component_indices ] # (num_samples, data_dim, rank) sample_keys = jax.random.split(key_norm, num_samples) # (num_samples, ) def sample_one(mean: Array, cov_factor: Array, single_key: Array) -> Array: data_dim, rank = cov_factor.shape noise = jax.random.multivariate_normal( single_key, jnp.zeros((rank,)), jnp.eye(rank), shape=() ) return mean + cov_factor @ noise samples = jax.vmap(sample_one)( chosen_means, chosen_cov_factors, sample_keys ) # (num_samples, data_dim) return samples, component_indices
[docs] def score(self, x_t: Array, t: Array, diffusion_process: DiffusionProcess) -> Array: """ Computes the score vector field ``∇_x log p_t(x_t)`` for the low-rank GMM distribution. This is calculated with respect to the perturbed distribution ``p_t`` induced by the ``diffusion_process`` at time ``t``. Args: x_t (``Array[data_dim]``): The noisy state tensor at time ``t``. t (``Array[]``): The time step (scalar). diffusion_process (``DiffusionProcess``): The diffusion process definition. Returns: ``Array[data_dim]``: The score vector field evaluated at ``x_t`` and ``t``. """ return low_rank_gmm_score( x_t, t, diffusion_process, self.dist_params["means"], self.dist_params["cov_factors"], self.dist_params["priors"], )
[docs] def x0(self, x_t: Array, t: Array, diffusion_process: DiffusionProcess) -> Array: """ Computes the denoised prediction x0 = E[x_0 | x_t] for the low-rank GMM distribution. This represents the expected original sample ``x_0`` given the noisy observation ``x_t`` at time ``t`` under the ``diffusion_process``. Args: x_t (``Array[data_dim]``): The noisy state tensor at time ``t``. t (``Array[]``): The time step (scalar). diffusion_process (``DiffusionProcess``): The diffusion process definition. Returns: ``Array[data_dim]``: The denoised prediction vector field ``x0`` evaluated at ``x_t`` and ``t``. """ return low_rank_gmm_x0( x_t, t, diffusion_process, self.dist_params["means"], self.dist_params["cov_factors"], self.dist_params["priors"], )
[docs] def eps(self, x_t: Array, t: Array, diffusion_process: DiffusionProcess) -> Array: """ Computes the noise prediction ε for the low-rank GMM distribution. This predicts the noise that was added to the original sample ``x_0`` to obtain ``x_t`` at time ``t`` under the ``diffusion_process``. Args: x_t (``Array[data_dim]``): The noisy state tensor at time ``t``. t (``Array[]``): The time step (scalar). diffusion_process (``DiffusionProcess``): The diffusion process definition. Returns: ``Array[data_dim]``: The noise prediction vector field ``ε`` evaluated at ``x_t`` and ``t``. """ return low_rank_gmm_eps( x_t, t, diffusion_process, self.dist_params["means"], self.dist_params["cov_factors"], self.dist_params["priors"], )
[docs] def v(self, x_t: Array, t: Array, diffusion_process: DiffusionProcess) -> Array: """ Computes the velocity vector field ``v`` for the low-rank GMM distribution. This is the conditional velocity ``E[dx_t/dt | x_t]`` under the ``diffusion_process``. Args: x_t (``Array[data_dim]``): The noisy state tensor at time ``t``. t (``Array[]``): The time step (scalar). diffusion_process (``DiffusionProcess``): The diffusion process definition. Returns: ``Array[data_dim]``: The velocity vector field ``v`` evaluated at ``x_t`` and ``t``. """ return low_rank_gmm_v( x_t, t, diffusion_process, self.dist_params["means"], self.dist_params["cov_factors"], self.dist_params["priors"], )
[docs] def low_rank_gmm_x0( x_t: Array, t: Array, diffusion_process: DiffusionProcess, means: Array, cov_factors: Array, priors: Array, ) -> Array: """ Computes the denoised prediction ``x0 = E[x_0 | x_t]`` for a low-rank GMM. This implements the closed-form solution for the conditional expectation ``E[x_0 | x_t]`` where ``x_t ~ N(α_t x_0, σ_t^2 I)`` and ``x_0`` follows the low-rank GMM distribution defined by ``means``, ``cov_factors``, and ``priors``. Args: x_t (``Array[data_dim]``): The noisy state tensor at time ``t``. t (``Array[]``): The time step (scalar). diffusion_process (``DiffusionProcess``): Provides ``α(t)`` and ``σ(t)``. means (``Array[num_components, data_dim]``): Means of the GMM components. cov_factors (``Array[num_components, data_dim, rank]``): Low-rank covariance matrices of the GMM components. priors (``Array[num_components]``): Mixture weights of the GMM components. Returns: ``Array[data_dim]``: The denoised prediction ``x0`` evaluated at ``x_t`` and ``t``. """ num_components, data_dim, rank = cov_factors.shape alpha_t = diffusion_process.alpha(t) sigma_t = diffusion_process.sigma(t) means_t = jax.vmap(lambda mean: alpha_t * mean)(means) # (num_components, data_dim) inner_covs = jax.vmap(lambda cov_factor: cov_factor.T @ cov_factor)( cov_factors ) # (num_components, rank, rank) xbars_t = jax.vmap(lambda mean_t: x_t - mean_t)( means_t ) # (num_components, data_dim) covs_t_inverse_xbars_t = jax.vmap( lambda cov_factor, inner_cov, xbar_t: (1 / sigma_t**2) * ( xbar_t - cov_factor @ _lstsq( inner_cov + (sigma_t / alpha_t) ** 2 * jnp.eye(rank), cov_factor.T @ xbar_t, ) ) )(cov_factors, inner_covs, xbars_t) # (num_components, data_dim) logdets_covs_t = jax.vmap( lambda inner_cov: _logdet_psd( (alpha_t / sigma_t) ** 2 * inner_cov + jnp.eye(rank) ) )(inner_covs) + 2 * data_dim * jnp.log(sigma_t) # (num_components,) log_likelihoods_unnormalized = jax.vmap( lambda xbar_t, covs_t_inverse_xbar_t, logdet_covs_t: -(1 / 2) * (jnp.sum(xbar_t * covs_t_inverse_xbar_t) + logdet_covs_t) )(xbars_t, covs_t_inverse_xbars_t, logdets_covs_t) # (num_components,) log_posterior_unnormalized = ( jnp.log(priors) + log_likelihoods_unnormalized ) # (num_components,) posterior_probs = jax.nn.softmax( log_posterior_unnormalized, axis=0 ) # (num_components,) posterior_means = jax.vmap( lambda mean, cov_factor, covs_t_inverse_xbar_t: mean + alpha_t * cov_factor @ (cov_factor.T @ covs_t_inverse_xbar_t) )(means, cov_factors, covs_t_inverse_xbars_t) # (num_components, data_dim) x0_pred = jnp.sum(posterior_probs[:, None] * posterior_means, axis=0) # (data_dim,) return x0_pred
# Generate eps, score, v functions from low_rank_gmm_x0 low_rank_gmm_eps, low_rank_gmm_score, low_rank_gmm_v = create_gmm_vector_field_fns( low_rank_gmm_x0 )