Source code for diffusionlab.distributions.gmm.iso_hom_gmm

from dataclasses import dataclass
from typing import Tuple, cast
from jax import numpy as jnp, Array
import jax
from diffusionlab.distributions.base import Distribution
from diffusionlab.distributions.gmm.utils import create_gmm_vector_field_fns
from diffusionlab.dynamics import DiffusionProcess


@dataclass(frozen=True)
class IsoHomGMM(Distribution):
    """
    Implements an isotropic homoscedastic Gaussian Mixture Model (GMM) distribution.

    The probability measure is given by:

    ``μ(A) = sum_{i=1}^{num_components} priors[i] * N(A; means[i], variance * I)``

    All components share one scalar variance, so every component covariance is
    ``variance * I``. This class provides methods for sampling from the isotropic
    homoscedastic GMM and computing various vector fields (``VectorFieldType.SCORE``,
    ``VectorFieldType.X0``, ``VectorFieldType.EPS``, ``VectorFieldType.V``) related to
    the distribution under a given diffusion process.

    Attributes:
        dist_params (``Dict[str, Array]``): Dictionary containing the core GMM parameters.

            - ``means`` (``Array[num_components, data_dim]``): The means of the GMM
              components.
            - ``variance`` (``Array[]``): The variance shared by all GMM components.
            - ``priors`` (``Array[num_components]``): The prior probabilities (mixture
              weights) of the GMM components.

        dist_hparams (``Dict[str, Any]``): Dictionary for storing hyperparameters
            (currently unused).
    """

    def __init__(self, means: Array, variance: Array, priors: Array):
        """
        Initializes the isotropic homoscedastic GMM distribution.

        Args:
            means (``Array[num_components, data_dim]``): Means for each Gaussian component.
            variance (``Array[]``): Scalar variance shared by every Gaussian component.
                Must be non-negative up to round-off: values down to ``-eps`` (the
                machine epsilon of ``variance.dtype``) are accepted.
            priors (``Array[num_components]``): Mixture weights for each component.
                Must sum to 1.

        Raises:
            AssertionError: If ``means`` is not 2-D, ``variance`` is not a scalar,
                ``priors`` does not have shape ``(num_components,)``, ``priors`` does not
                sum to 1 (checked with ``jnp.isclose`` using ``atol=eps``), or
                ``variance`` is below ``-eps``.
        """
        # Tolerance for the numerical checks below, matched to the variance dtype.
        eps = cast(float, jnp.finfo(variance.dtype).eps)
        assert means.ndim == 2
        num_components, _ = means.shape
        assert variance.shape == ()
        assert priors.shape == (num_components,)
        assert jnp.isclose(jnp.sum(priors), 1.0, atol=eps)
        assert variance >= -eps
        super().__init__(
            dist_params={
                "means": means,
                "variance": variance,
                "priors": priors,
            },
            dist_hparams={},
        )

    def sample(self, key: Array, num_samples: int) -> Tuple[Array, Array]:
        """
        Draws samples from the isotropic homoscedastic GMM distribution.

        Each sample first draws a component index from ``priors``. It then adds
        isotropic Gaussian noise with standard deviation ``sqrt(max(variance, 0))``
        to that component's mean.

        Args:
            key (``Array``): JAX PRNG key for random sampling.
            num_samples (``int``): The total number of samples to generate.

        Returns:
            ``Tuple[Array[num_samples, data_dim], Array[num_samples]]``: A tuple
            ``(samples, component_indices)`` containing the drawn samples and the index
            of the GMM component from which each sample was drawn.
        """
        means = self.dist_params["means"]  # (num_components, data_dim)
        variance = self.dist_params["variance"]  # ()
        priors = self.dist_params["priors"]  # (num_components,)

        _, key_cat, key_norm = jax.random.split(key, 3)
        component_indices = jax.random.categorical(
            key_cat, jnp.log(priors), shape=(num_samples,)
        )  # (num_samples,)
        chosen_means = means[component_indices]  # (num_samples, data_dim)

        # ``__init__`` accepts variances slightly below zero (down to ``-eps``) as
        # round-off. Clamp before the square root so such a variance gives a zero
        # standard deviation instead of NaN samples.
        std = jnp.sqrt(jnp.maximum(variance, 0.0))  # ()

        sample_keys = jax.random.split(key_norm, num_samples)  # one key per sample

        def sample_one(mean: Array, single_key: Array) -> Array:
            # Standard normal noise is drawn via ``multivariate_normal`` with identity
            # covariance (rather than ``jax.random.normal``). This keeps the random
            # stream produced for a given key stable.
            noise = jax.random.multivariate_normal(
                single_key, jnp.zeros_like(mean), jnp.eye(mean.shape[0]), shape=()
            )
            return mean + std * noise

        samples = jax.vmap(sample_one)(
            chosen_means, sample_keys
        )  # (num_samples, data_dim)
        return samples, component_indices

    def score(self, x_t: Array, t: Array, diffusion_process: DiffusionProcess) -> Array:
        """
        Computes the score vector field ``∇_x log p_t(x_t)`` for the isotropic
        homoscedastic GMM distribution.

        The score is taken with respect to the perturbed distribution ``p_t`` that
        ``diffusion_process`` induces at time ``t``.

        Args:
            x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
            t (``Array[]``): The time step (scalar).
            diffusion_process (``DiffusionProcess``): The diffusion process definition.

        Returns:
            ``Array[data_dim]``: The score vector field evaluated at ``x_t`` and ``t``.
        """
        return iso_hom_gmm_score(
            x_t,
            t,
            diffusion_process,
            self.dist_params["means"],
            self.dist_params["variance"],
            self.dist_params["priors"],
        )

    def x0(self, x_t: Array, t: Array, diffusion_process: DiffusionProcess) -> Array:
        """
        Computes the denoised prediction ``x0 = E[x_0 | x_t]`` for the isotropic
        homoscedastic GMM distribution.

        This is the expected original sample ``x_0`` given the noisy observation
        ``x_t`` at time ``t`` under ``diffusion_process``.

        Args:
            x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
            t (``Array[]``): The time step (scalar).
            diffusion_process (``DiffusionProcess``): The diffusion process definition.

        Returns:
            ``Array[data_dim]``: The denoised prediction vector field ``x0`` evaluated
            at ``x_t`` and ``t``.
        """
        return iso_hom_gmm_x0(
            x_t,
            t,
            diffusion_process,
            self.dist_params["means"],
            self.dist_params["variance"],
            self.dist_params["priors"],
        )

    def eps(self, x_t: Array, t: Array, diffusion_process: DiffusionProcess) -> Array:
        """
        Computes the noise prediction ``ε`` for the isotropic homoscedastic GMM
        distribution.

        This predicts the noise that was added to the original sample ``x_0`` to
        obtain ``x_t`` at time ``t`` under ``diffusion_process``.

        Args:
            x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
            t (``Array[]``): The time step (scalar).
            diffusion_process (``DiffusionProcess``): The diffusion process definition.

        Returns:
            ``Array[data_dim]``: The noise prediction vector field ``ε`` evaluated at
            ``x_t`` and ``t``.
        """
        return iso_hom_gmm_eps(
            x_t,
            t,
            diffusion_process,
            self.dist_params["means"],
            self.dist_params["variance"],
            self.dist_params["priors"],
        )

    def v(self, x_t: Array, t: Array, diffusion_process: DiffusionProcess) -> Array:
        """
        Computes the velocity vector field ``v`` for the isotropic homoscedastic GMM
        distribution.

        This is the conditional velocity ``E[dx_t/dt | x_t]`` under
        ``diffusion_process``.

        Args:
            x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
            t (``Array[]``): The time step (scalar).
            diffusion_process (``DiffusionProcess``): The diffusion process definition.

        Returns:
            ``Array[data_dim]``: The velocity vector field ``v`` evaluated at ``x_t``
            and ``t``.
        """
        return iso_hom_gmm_v(
            x_t,
            t,
            diffusion_process,
            self.dist_params["means"],
            self.dist_params["variance"],
            self.dist_params["priors"],
        )
def iso_hom_gmm_x0(
    x_t: Array,
    t: Array,
    diffusion_process: DiffusionProcess,
    means: Array,
    variance: Array,
    priors: Array,
) -> Array:
    """
    Computes the denoised prediction ``x0 = E[x_0 | x_t]`` for an isotropic
    homoscedastic GMM.

    This implements the closed-form solution for the conditional expectation
    ``E[x_0 | x_t]``, where ``x_t ~ N(α_t x_0, σ_t^2 I)`` and ``x_0`` follows the GMM
    distribution defined by ``means``, ``variance``, and ``priors``.

    Let ``v_t = α_t^2 variance + σ_t^2``. Given component ``i``, ``x_t`` is distributed
    as ``N(α_t means[i], v_t I)``. The prediction is the posterior-weighted average
    of the per-component posterior means
    ``means[i] + α_t variance / v_t * (x_t - α_t means[i])``. The weights are
    ``softmax_i(log priors[i] - ||x_t - α_t means[i]||^2 / (2 v_t))``. The Gaussian
    normalizing constant is identical for every component, so it cancels in the
    softmax and is omitted.

    Args:
        x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
        t (``Array[]``): The time step (scalar).
        diffusion_process (``DiffusionProcess``): Provides ``α(t)`` and ``σ(t)``.
        means (``Array[num_components, data_dim]``): Means of the GMM components.
        variance (``Array[]``): Scalar variance shared by all GMM components.
        priors (``Array[num_components]``): Mixture weights of the GMM components.

    Returns:
        ``Array[data_dim]``: The denoised prediction ``x0`` evaluated at ``x_t`` and
        ``t``.
    """
    alpha_t = diffusion_process.alpha(t)
    sigma_t = diffusion_process.sigma(t)

    # Per-component marginal of x_t: mean alpha_t * means[i], variance variance_t.
    variance_t = alpha_t**2 * variance + sigma_t**2  # ()
    xbars_t = x_t[None, :] - alpha_t * means  # (num_components, data_dim)
    variance_t_inv_xbars_t = xbars_t / variance_t  # (num_components, data_dim)

    # Gaussian log-likelihoods up to a component-independent constant.
    log_likelihoods_unnormalized = -0.5 * jnp.sum(
        xbars_t * variance_t_inv_xbars_t, axis=1
    )  # (num_components,)
    log_posterior_unnormalized = (
        jnp.log(priors) + log_likelihoods_unnormalized
    )  # (num_components,)
    posterior_probs = jax.nn.softmax(
        log_posterior_unnormalized, axis=0
    )  # (num_components,), sums to 1

    # Per-component posterior means E[x_0 | x_t, component i].
    posterior_means = (
        means + alpha_t * variance * variance_t_inv_xbars_t
    )  # (num_components, data_dim)

    return jnp.sum(posterior_probs[:, None] * posterior_means, axis=0)  # (data_dim,)
# The noise (eps), score, and velocity (v) predictors are all derived from the
# closed-form denoiser iso_hom_gmm_x0 via the shared GMM helper.
(
    iso_hom_gmm_eps,
    iso_hom_gmm_score,
    iso_hom_gmm_v,
) = create_gmm_vector_field_fns(iso_hom_gmm_x0)