API Reference

diffusionlab

diffusionlab.distributions

diffusionlab.distributions.base

class diffusionlab.distributions.base.Distribution(dist_params: Dict[str, Array], dist_hparams: Dict[str, Any])

Bases: object

Base class for all distributions.

Subclass it to define a distribution whose ground-truth scores, denoisers, noise predictors, or velocity estimators you want to use.

Each distribution implementation provides functions to sample from it and compute various vector fields related to a diffusion process, such as denoising (x0), noise prediction (eps), velocity estimation (v), and score estimation (score).
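Under the forward process x_t = α(t) x_0 + σ(t) ε defined for DiffusionProcess, the four vector fields are linked, so any one of them determines the others. The identities below follow from Gaussian conditioning (the score line is Tweedie's formula). They sketch the relationships only; whether a given subclass computes its fields exactly this way is an assumption.

```latex
\begin{aligned}
\hat{x}_0(x_t, t) &= \mathbb{E}[x_0 \mid x_t], \\
\hat{\varepsilon}(x_t, t) &= \mathbb{E}[\varepsilon \mid x_t] = \frac{x_t - \alpha(t)\,\hat{x}_0(x_t, t)}{\sigma(t)}, \\
\nabla_{x} \log p_t(x_t) &= \frac{\alpha(t)\,\hat{x}_0(x_t, t) - x_t}{\sigma(t)^2} = -\frac{\hat{\varepsilon}(x_t, t)}{\sigma(t)}, \\
\hat{v}(x_t, t) &= \mathbb{E}\!\left[\tfrac{d x_t}{dt} \,\middle|\, x_t\right] = \alpha'(t)\,\hat{x}_0(x_t, t) + \sigma'(t)\,\hat{\varepsilon}(x_t, t).
\end{aligned}
```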

dist_params

Dictionary containing distribution parameters as JAX arrays. Shapes depend on the specific distribution.

Type:

Dict[str, Array]

dist_hparams

Dictionary containing distribution hyperparameters (non-array values).

Type:

Dict[str, Any]

eps(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array

Predict the noise component ε corresponding to the noisy state x_t at time t, given the diffusion_process.

Parameters:
  • x_t (Array[*data_dims]) – The noisy state tensor at time t.

  • t (Array[]) – The time step.

  • diffusion_process (DiffusionProcess) – The diffusion process definition.

Returns:

The predicted noise ε.

Return type:

Array[*data_dims]

get_vector_field(vector_field_type: VectorFieldType) → Callable[[Array, Array, DiffusionProcess], Array]

Get the vector field function of a given type associated with this distribution.

Parameters:

vector_field_type (VectorFieldType) – The type of vector field to retrieve (e.g., VectorFieldType.SCORE, VectorFieldType.X0, VectorFieldType.EPS, VectorFieldType.V).

Returns:

The requested vector field function. It takes the current state x_t (Array[*data_dims]), time t (Array[]), and the diffusion_process as input and returns the corresponding vector field value (Array[*data_dims]).

Return type:

Callable[[Array[*data_dims], Array[], DiffusionProcess], Array[*data_dims]]
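This is a minimal sketch of how get_vector_field can dispatch on VectorFieldType. It assumes each field is exposed as a method with the (x_t, t, diffusion_process) signature documented here. VectorFieldType, ToyProcess, and StandardGaussian are hypothetical stand-ins written in NumPy, whereas the library uses JAX arrays. The toy distribution x_0 ~ N(0, I) is chosen because its denoiser α x_t / (α² + σ²) has a closed form.

```python
from enum import Enum
from typing import Callable

import numpy as np


class VectorFieldType(Enum):
    # Hypothetical stand-in mirroring the enum members named in this reference.
    SCORE = "score"
    X0 = "x0"
    EPS = "eps"
    V = "v"


class ToyProcess:
    """Hypothetical schedule with alpha(t) = 1 - t and sigma(t) = t."""

    def alpha(self, t):
        return 1.0 - t

    def sigma(self, t):
        return t

    def alpha_prime(self, t):
        return -1.0

    def sigma_prime(self, t):
        return 1.0


class StandardGaussian:
    """Toy distribution x_0 ~ N(0, I), for which E[x_0 | x_t] has a closed form."""

    def x0(self, x_t, t, process):
        a, s = process.alpha(t), process.sigma(t)
        return a * x_t / (a**2 + s**2)

    def eps(self, x_t, t, process):
        return (x_t - process.alpha(t) * self.x0(x_t, t, process)) / process.sigma(t)

    def score(self, x_t, t, process):
        return -self.eps(x_t, t, process) / process.sigma(t)

    def v(self, x_t, t, process):
        return (process.alpha_prime(t) * self.x0(x_t, t, process)
                + process.sigma_prime(t) * self.eps(x_t, t, process))

    def get_vector_field(self, vector_field_type: VectorFieldType) -> Callable:
        # Map each enum member to the bound method that computes that field.
        return {
            VectorFieldType.SCORE: self.score,
            VectorFieldType.X0: self.x0,
            VectorFieldType.EPS: self.eps,
            VectorFieldType.V: self.v,
        }[vector_field_type]


score_fn = StandardGaussian().get_vector_field(VectorFieldType.SCORE)
print(score_fn(np.array([1.0, -2.0]), 0.5, ToyProcess()))  # [-2.  4.]
```

Returning bound methods keeps the returned callable's signature identical to the documented (x_t, t, diffusion_process) form.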

sample(key: Array, num_samples: int) → Tuple[Array, Any]

Sample from the distribution.

Parameters:
  • key (Array) – The JAX PRNG key to use for sampling.

  • num_samples (int) – The number of samples to draw.

Returns:

A tuple containing the samples and any additional information.

Return type:

Tuple[Array[num_samples, *data_dims], Any]

score(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array

Compute the score function (∇_x log p_t(x)) of the distribution at time t, given the noisy state x_t and the diffusion_process.

Parameters:
  • x_t (Array[*data_dims]) – The noisy state tensor at time t.

  • t (Array[]) – The time step.

  • diffusion_process (DiffusionProcess) – The diffusion process definition.

Returns:

The score of the distribution at (x_t, t).

Return type:

Array[*data_dims]

v(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array

Compute the velocity field v(x_t, t) corresponding to the noisy state x_t at time t, given the diffusion_process.

Parameters:
  • x_t (Array[*data_dims]) – The noisy state tensor at time t.

  • t (Array[]) – The time step.

  • diffusion_process (DiffusionProcess) – The diffusion process definition.

Returns:

The computed velocity field v.

Return type:

Array[*data_dims]

x0(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array

Predict the initial state x0 (denoised sample) from the noisy state x_t at time t, given the diffusion_process.

Parameters:
  • x_t (Array[*data_dims]) – The noisy state tensor at time t.

  • t (Array[]) – The time step.

  • diffusion_process (DiffusionProcess) – The diffusion process definition.

Returns:

The predicted initial state x0.

Return type:

Array[*data_dims]

diffusionlab.distributions.empirical

class diffusionlab.distributions.empirical.EmpiricalDistribution(labeled_data: Iterable[Tuple[Array, Array]] | Iterable[Tuple[Array, None]])

Bases: Distribution

An empirical distribution, i.e., the uniform distribution over a dataset. The probability measure is defined as:

μ(A) = (1/N) * sum_{i=1}^{N} δ_{x_i}(A)

where x_i is the i-th data point in the dataset, N is the number of data points, and δ_{x_i}(A) is 1 if x_i ∈ A and 0 otherwise.

This class provides methods for sampling from the empirical distribution and computing various vector fields (VectorFieldType.SCORE, VectorFieldType.X0, VectorFieldType.EPS, VectorFieldType.V) related to the distribution under a given diffusion process.

dist_params

Dictionary containing distribution parameters (currently unused).

Type:

Dict[str, Array]

dist_hparams

Dictionary for storing hyperparameters. It may contain the following keys:

  • labeled_data (Iterable[Tuple[Array, Array]] | Iterable[Tuple[Array, None]]): An iterable of batches, each a tuple (data batch, label batch). The label batch is None if the data is unlabelled.

Type:

Dict[str, Any]

eps(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array

Computes the noise field eps(x_t, t) for an empirical distribution w.r.t. a given diffusion process.

Parameters:
  • x_t (Array[*data_dims]) – The input tensor.

  • t (Array[]) – The time tensor.

  • diffusion_process (DiffusionProcess) – The diffusion process.

Returns:

The noise field at (x_t, t).

Return type:

Array[*data_dims]

sample(key: Array, num_samples: int) → Tuple[Array, Array] | Tuple[Array, None]

Sample from the empirical distribution using reservoir sampling. Assumes all batches in labeled_data are consistent: either all have labels (Array) or none have labels (None).

Parameters:
  • key (Array) – The JAX PRNG key to use for sampling.

  • num_samples (int) – The number of samples to draw.

Returns:

A tuple (samples, labels) containing the samples and corresponding labels (stacked into an Array), or (samples, None) if the data is unlabelled.

Return type:

Tuple[Array[num_samples, *data_dims], Array[num_samples, *label_dims]] | Tuple[Array[num_samples, *data_dims], None]
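Reservoir sampling draws uniformly from a dataset that is only available as an iterable of batches, without materializing the whole dataset. The sketch below is Algorithm R (sampling without replacement) in NumPy, with a numpy.random.Generator in place of a JAX PRNG key. Two behaviors are assumptions rather than documented facts: whether the library samples with or without replacement, and how it handles num_samples larger than the dataset. The function name reservoir_sample is hypothetical.

```python
from typing import Iterable, Optional, Tuple

import numpy as np


def reservoir_sample(
    labeled_data: Iterable[Tuple[np.ndarray, Optional[np.ndarray]]],
    num_samples: int,
    rng: np.random.Generator,
) -> Tuple[np.ndarray, Optional[np.ndarray]]:
    """Uniformly sample num_samples items (without replacement) from batched data."""
    data_reservoir, label_reservoir = [], []
    has_labels = None
    seen = 0  # number of individual items streamed so far
    for data_batch, label_batch in labeled_data:
        if has_labels is None:
            has_labels = label_batch is not None
        for j in range(data_batch.shape[0]):
            label = label_batch[j] if has_labels else None
            if seen < num_samples:
                # Fill the reservoir with the first num_samples items.
                data_reservoir.append(data_batch[j])
                label_reservoir.append(label)
            else:
                # Keep item `seen` with probability num_samples / (seen + 1),
                # overwriting a uniformly chosen slot.
                r = rng.integers(0, seen + 1)
                if r < num_samples:
                    data_reservoir[r] = data_batch[j]
                    label_reservoir[r] = label
            seen += 1
    if seen < num_samples:
        raise ValueError(f"dataset has only {seen} items, fewer than {num_samples}")
    samples = np.stack(data_reservoir)
    labels = np.stack(label_reservoir) if has_labels else None
    return samples, labels
```

Algorithm R keeps memory at O(num_samples) regardless of dataset size. The cost is one full pass over the batches per call.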

score(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array

Computes the score function (∇_x log p_t(x)) of the empirical distribution at time t, given the noisy state x_t and the diffusion process.

Parameters:
  • x_t (Array[*data_dims]) – The noisy state tensor at time t.

  • t (Array[]) – The time tensor.

  • diffusion_process (DiffusionProcess) – The diffusion process.

Returns:

The score of the empirical distribution at (x_t, t).

Return type:

Array[*data_dims]

v(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array

Computes the velocity field v(x_t, t) for an empirical distribution w.r.t. a given diffusion process.

Parameters:
  • x_t (Array[*data_dims]) – The input tensor.

  • t (Array[]) – The time tensor.

  • diffusion_process (DiffusionProcess) – The diffusion process.

Returns:

The velocity field at (x_t, t).

Return type:

Array[*data_dims]

x0(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array

Computes the denoiser E[x_0 | x_t] for an empirical distribution w.r.t. a given diffusion process.

This method computes the denoiser by performing a weighted average of the dataset samples, where the weights are determined by the likelihood of x_t given each sample.

Parameters:
  • x_t (Array[*data_dims]) – The input tensor.

  • t (Array[]) – The time tensor.

  • diffusion_process (DiffusionProcess) – The diffusion process.

Returns:

The prediction of x_0.

Return type:

Array[*data_dims]
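For the empirical distribution, the posterior over data points is a softmax of Gaussian log-likelihoods, so E[x_0 | x_t] is a weighted average of the dataset. The NumPy sketch below makes two simplifying assumptions: the whole dataset fits in one array `data` of shape (N, *data_dims), and α(t), σ(t) are passed in as scalars. The library instead streams labeled_data and reads α and σ from the diffusion_process. empirical_x0 is a hypothetical name.

```python
import numpy as np


def empirical_x0(x_t: np.ndarray, alpha_t: float, sigma_t: float, data: np.ndarray) -> np.ndarray:
    """E[x_0 | x_t] for the uniform distribution over the rows of `data`."""
    # Log-likelihood of x_t under N(alpha_t * x_i, sigma_t^2 I), up to a constant
    # shared by all data points (so it cancels in the softmax).
    diffs = (x_t[None, ...] - alpha_t * data).reshape(data.shape[0], -1)
    log_w = -np.sum(diffs**2, axis=1) / (2.0 * sigma_t**2)
    # Numerically stable softmax over the N data points.
    log_w -= log_w.max()
    w = np.exp(log_w)
    w /= w.sum()
    # Weighted average of the data points, shape (*data_dims).
    return np.tensordot(w, data, axes=1)
```

As σ(t) → 0 the weights concentrate on the nearest data point, so the denoiser snaps x_t onto the training set.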

diffusionlab.distributions.gmm

diffusionlab.distributions.gmm.gmm

class diffusionlab.distributions.gmm.gmm.GMM(means: Array, covs: Array, priors: Array)

Bases: Distribution

Implements a Gaussian Mixture Model (GMM) distribution.

The probability measure is given by:

μ(A) = sum_{i=1}^{num_components} priors[i] * N(A; means[i], covs[i])

This class provides methods for sampling from the GMM and computing various vector fields (VectorFieldType.SCORE, VectorFieldType.X0, VectorFieldType.EPS, VectorFieldType.V) related to the distribution under a given diffusion process.

dist_params

Dictionary containing the core GMM parameters.

  • means (Array[num_components, data_dim]): The means of the GMM components.

  • covs (Array[num_components, data_dim, data_dim]): The covariance matrices of the GMM components.

  • priors (Array[num_components]): The prior probabilities (mixture weights) of the GMM components.

Type:

Dict[str, Array]

dist_hparams

Dictionary for storing hyperparameters (currently unused).

Type:

Dict[str, Any]

eps(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array

Computes the noise prediction ε for the GMM distribution.

This predicts the noise that was added to the original sample x_0 to obtain x_t at time t under the diffusion_process.

Parameters:
  • x_t (Array[data_dim]) – The noisy state tensor at time t.

  • t (Array[]) – The time step (scalar).

  • diffusion_process (DiffusionProcess) – The diffusion process definition.

Returns:

The noise prediction vector field ε evaluated at x_t and t.

Return type:

Array[data_dim]

sample(key: Array, num_samples: int) → Tuple[Array, Array]

Draws samples from the GMM distribution.

Parameters:
  • key (Array) – JAX PRNG key for random sampling.

  • num_samples (int) – The total number of samples to generate.

Returns:

A tuple (samples, component_indices) containing the drawn samples and the index of the GMM component from which each sample was drawn.

Return type:

Tuple[Array[num_samples, data_dim], Array[num_samples]]

score(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array

Computes the score vector field (∇_x log p_t(x_t)) for the GMM distribution.

This is calculated with respect to the perturbed distribution p_t induced by the diffusion_process at time t.

Parameters:
  • x_t (Array[data_dim]) – The noisy state tensor at time t.

  • t (Array[]) – The time step (scalar).

  • diffusion_process (DiffusionProcess) – The diffusion process definition.

Returns:

The score vector field evaluated at x_t and t.

Return type:

Array[data_dim]

v(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array

Computes the velocity vector field v for the GMM distribution.

This relates to the conditional velocity E[dx_t/dt | x_t] under the diffusion_process.

Parameters:
  • x_t (Array[data_dim]) – The noisy state tensor at time t.

  • t (Array[]) – The time step (scalar).

  • diffusion_process (DiffusionProcess) – The diffusion process definition.

Returns:

The velocity vector field v evaluated at x_t and t.

Return type:

Array[data_dim]

x0(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array

Computes the denoised prediction x0 = E[x_0 | x_t] for the GMM distribution.

This represents the expected original sample x_0 given the noisy observation x_t at time t under the diffusion_process.

Parameters:
  • x_t (Array[data_dim]) – The noisy state tensor at time t.

  • t (Array[]) – The time step (scalar).

  • diffusion_process (DiffusionProcess) – The diffusion process definition.

Returns:

The denoised prediction vector field x0 evaluated at x_t and t.

Return type:

Array[data_dim]

diffusionlab.distributions.gmm.gmm.gmm_x0(x_t: Array, t: Array, diffusion_process: DiffusionProcess, means: Array, covs: Array, priors: Array) → Array

Computes the denoised prediction x0 = E[x_0 | x_t] for a GMM.

This implements the closed-form solution for the conditional expectation E[x_0 | x_t] where x_t ~ N(α_t x_0, σ_t^2 I) and x_0 follows the GMM distribution defined by means, covs, and priors.

Parameters:
  • x_t (Array[data_dim]) – The noisy state tensor at time t.

  • t (Array[]) – The time step (scalar).

  • diffusion_process (DiffusionProcess) – Provides α(t) and σ(t).

  • means (Array[num_components, data_dim]) – Means of the GMM components.

  • covs (Array[num_components, data_dim, data_dim]) – Covariances of the GMM components.

  • priors (Array[num_components]) – Mixture weights of the GMM components.

Returns:

The denoised prediction x0 evaluated at x_t and t.

Return type:

Array[data_dim]
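This NumPy sketch illustrates the closed form. It assumes α(t) and σ(t) are passed in as scalars rather than read from a diffusion_process. Component i gives x_t the marginal N(α μ_i, α² Σ_i + σ² I). The posterior responsibilities are a softmax of the log priors plus these Gaussian log-likelihoods, and each component contributes its Gaussian posterior mean μ_i + α Σ_i (α² Σ_i + σ² I)^{-1} (x_t - α μ_i). The function name mirrors gmm_x0, but this is an illustration, not the library's code.

```python
import numpy as np


def gmm_x0(x_t, alpha_t, sigma_t, means, covs, priors):
    """Closed-form E[x_0 | x_t] for a full-covariance GMM prior.

    x_t: (d,), means: (K, d), covs: (K, d, d), priors: (K,).
    """
    K, d = means.shape
    eye = np.eye(d)
    log_w = np.empty(K)
    cond_means = np.empty((K, d))
    for i in range(K):
        # Marginal of x_t under component i: N(alpha mu_i, alpha^2 Sigma_i + sigma^2 I).
        C = alpha_t**2 * covs[i] + sigma_t**2 * eye
        resid = x_t - alpha_t * means[i]
        C_inv_resid = np.linalg.solve(C, resid)
        _, logdet = np.linalg.slogdet(C)
        # log pi_i + log N(x_t; alpha mu_i, C), dropping the shared -d/2 log(2 pi).
        log_w[i] = np.log(priors[i]) - 0.5 * (logdet + resid @ C_inv_resid)
        # Per-component posterior mean: mu_i + alpha Sigma_i C^{-1} (x_t - alpha mu_i).
        cond_means[i] = means[i] + alpha_t * covs[i] @ C_inv_resid
    # Posterior responsibilities via a numerically stable softmax.
    w = np.exp(log_w - log_w.max())
    w /= w.sum()
    return w @ cond_means
```

np.linalg.solve is used instead of forming C^{-1} explicitly, which is cheaper and better conditioned.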

diffusionlab.distributions.gmm.iso_gmm

class diffusionlab.distributions.gmm.iso_gmm.IsoGMM(means: Array, variances: Array, priors: Array)

Bases: Distribution

Implements an isotropic Gaussian Mixture Model (GMM) distribution.

The probability measure is given by:

μ(A) = sum_{i=1}^{num_components} priors[i] * N(A; means[i], variances[i] * I)

This class provides methods for sampling from the GMM and computing various vector fields (VectorFieldType.SCORE, VectorFieldType.X0, VectorFieldType.EPS, VectorFieldType.V) related to the distribution under a given diffusion process.

dist_params

Dictionary containing the core GMM parameters.

  • means (Array[num_components, data_dim]): The means of the GMM components.

  • variances (Array[num_components]): The variances of the GMM components.

  • priors (Array[num_components]): The prior probabilities (mixture weights) of the GMM components.

Type:

Dict[str, Array]

dist_hparams

Dictionary for storing hyperparameters (currently unused).

Type:

Dict[str, Any]

eps(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array

Computes the noise prediction ε for the isotropic GMM distribution.

This predicts the noise that was added to the original sample x_0 to obtain x_t at time t under the diffusion_process.

Parameters:
  • x_t (Array[data_dim]) – The noisy state tensor at time t.

  • t (Array[]) – The time step (scalar).

  • diffusion_process (DiffusionProcess) – The diffusion process definition.

Returns:

The noise prediction vector field ε evaluated at x_t and t.

Return type:

Array[data_dim]

sample(key: Array, num_samples: int) → Tuple[Array, Array]

Draws samples from the isotropic GMM distribution.

Parameters:
  • key (Array) – JAX PRNG key for random sampling.

  • num_samples (int) – The total number of samples to generate.

Returns:

A tuple (samples, component_indices) containing the drawn samples and the index of the GMM component from which each sample was drawn.

Return type:

Tuple[Array[num_samples, data_dim], Array[num_samples]]

score(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array

Computes the score vector field ∇_x log p_t(x_t) for the isotropic GMM distribution.

This is calculated with respect to the perturbed distribution p_t induced by the diffusion_process at time t.

Parameters:
  • x_t (Array[data_dim]) – The noisy state tensor at time t.

  • t (Array[]) – The time step (scalar).

  • diffusion_process (DiffusionProcess) – The diffusion process definition.

Returns:

The score vector field evaluated at x_t and t.

Return type:

Array[data_dim]

v(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array

Computes the velocity vector field v for the isotropic GMM distribution.

This is the conditional velocity E[dx_t/dt | x_t] under the diffusion_process.

Parameters:
  • x_t (Array[data_dim]) – The noisy state tensor at time t.

  • t (Array[]) – The time step (scalar).

  • diffusion_process (DiffusionProcess) – The diffusion process definition.

Returns:

The velocity vector field v evaluated at x_t and t.

Return type:

Array[data_dim]

x0(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array

Computes the denoised prediction x0 = E[x_0 | x_t] for the isotropic GMM distribution.

This represents the expected original sample x_0 given the noisy observation x_t at time t under the diffusion_process.

Parameters:
  • x_t (Array[data_dim]) – The noisy state tensor at time t.

  • t (Array[]) – The time step (scalar).

  • diffusion_process (DiffusionProcess) – The diffusion process definition.

Returns:

The denoised prediction vector field x0 evaluated at x_t and t.

Return type:

Array[data_dim]

diffusionlab.distributions.gmm.iso_gmm.iso_gmm_x0(x_t: Array, t: Array, diffusion_process: DiffusionProcess, means: Array, variances: Array, priors: Array) → Array

Computes the denoised prediction x0 = E[x_0 | x_t] for an isotropic GMM.

This implements the closed-form solution for the conditional expectation E[x_0 | x_t] where x_t ~ N(α_t x_0, σ_t^2 I) and x_0 follows the isotropic GMM distribution defined by means, variances, and priors.

Parameters:
  • x_t (Array[data_dim]) – The noisy state tensor at time t.

  • t (Array[]) – The time step (scalar).

  • diffusion_process (DiffusionProcess) – Provides α(t) and σ(t).

  • means (Array[num_components, data_dim]) – Means of the GMM components.

  • variances (Array[num_components]) – Variances of the GMM components (component i has covariance variances[i] * I).

  • priors (Array[num_components]) – Mixture weights of the GMM components.

Returns:

The denoised prediction x0 evaluated at x_t and t.

Return type:

Array[data_dim]
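With Σ_i = v_i I, every matrix solve in the general closed form reduces to a scalar division, so the computation is vectorized over components. This NumPy sketch passes α(t) and σ(t) as scalars (an assumption about the call interface) and is illustrative only.

```python
import numpy as np


def iso_gmm_x0(x_t, alpha_t, sigma_t, means, variances, priors):
    """Closed-form E[x_0 | x_t] when component i has covariance variances[i] * I."""
    d = means.shape[1]
    # Per-coordinate marginal variance of x_t under each component.
    marg_var = alpha_t**2 * variances + sigma_t**2            # (K,)
    resid = x_t[None, :] - alpha_t * means                      # (K, d)
    sq_dist = np.sum(resid**2, axis=1)                          # (K,)
    # log pi_i + log N(x_t; alpha mu_i, marg_var_i I), dropping -d/2 log(2 pi).
    log_w = np.log(priors) - 0.5 * d * np.log(marg_var) - 0.5 * sq_dist / marg_var
    w = np.exp(log_w - log_w.max())
    w /= w.sum()
    # Per-component posterior means: mu_i + (alpha v_i / marg_var_i) * resid_i.
    cond_means = means + (alpha_t * variances / marg_var)[:, None] * resid
    return w @ cond_means
```

The -d/2 log(marg_var_i) term matters here because components with larger variances spread their likelihood more thinly.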

diffusionlab.distributions.gmm.iso_hom_gmm

class diffusionlab.distributions.gmm.iso_hom_gmm.IsoHomGMM(means: Array, variance: Array, priors: Array)

Bases: Distribution

Implements an isotropic homoscedastic Gaussian Mixture Model (GMM) distribution.

The probability measure is given by:

μ(A) = sum_{i=1}^{num_components} priors[i] * N(A; means[i], variance * I)

This class provides methods for sampling from the isotropic homoscedastic GMM and computing various vector fields (VectorFieldType.SCORE, VectorFieldType.X0, VectorFieldType.EPS, VectorFieldType.V) related to the distribution under a given diffusion process.

dist_params

Dictionary containing the core GMM parameters.

  • means (Array[num_components, data_dim]): The means of the GMM components.

  • variance (Array[]): The variance of the GMM components.

  • priors (Array[num_components]): The prior probabilities (mixture weights) of the GMM components.

Type:

Dict[str, Array]

dist_hparams

Dictionary for storing hyperparameters (currently unused).

Type:

Dict[str, Any]

eps(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array

Computes the noise prediction ε for the isotropic homoscedastic GMM distribution.

This predicts the noise that was added to the original sample x_0 to obtain x_t at time t under the diffusion_process.

Parameters:
  • x_t (Array[data_dim]) – The noisy state tensor at time t.

  • t (Array[]) – The time step (scalar).

  • diffusion_process (DiffusionProcess) – The diffusion process definition.

Returns:

The noise prediction vector field ε evaluated at x_t and t.

Return type:

Array[data_dim]

sample(key: Array, num_samples: int) → Tuple[Array, Array]

Draws samples from the isotropic homoscedastic GMM distribution.

Parameters:
  • key (Array) – JAX PRNG key for random sampling.

  • num_samples (int) – The total number of samples to generate.

Returns:

A tuple (samples, component_indices) containing the drawn samples and the index of the GMM component from which each sample was drawn.

Return type:

Tuple[Array[num_samples, data_dim], Array[num_samples]]

score(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array

Computes the score vector field ∇_x log p_t(x_t) for the isotropic homoscedastic GMM distribution.

This is calculated with respect to the perturbed distribution p_t induced by the diffusion_process at time t.

Parameters:
  • x_t (Array[data_dim]) – The noisy state tensor at time t.

  • t (Array[]) – The time step (scalar).

  • diffusion_process (DiffusionProcess) – The diffusion process definition.

Returns:

The score vector field evaluated at x_t and t.

Return type:

Array[data_dim]

v(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array

Computes the velocity vector field v for the isotropic homoscedastic GMM distribution.

This is the conditional velocity E[dx_t/dt | x_t] under the diffusion_process.

Parameters:
  • x_t (Array[data_dim]) – The noisy state tensor at time t.

  • t (Array[]) – The time step (scalar).

  • diffusion_process (DiffusionProcess) – The diffusion process definition.

Returns:

The velocity vector field v evaluated at x_t and t.

Return type:

Array[data_dim]

x0(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array

Computes the denoised prediction x0 = E[x_0 | x_t] for the isotropic homoscedastic GMM distribution.

This represents the expected original sample x_0 given the noisy observation x_t at time t under the diffusion_process.

Parameters:
  • x_t (Array[data_dim]) – The noisy state tensor at time t.

  • t (Array[]) – The time step (scalar).

  • diffusion_process (DiffusionProcess) – The diffusion process definition.

Returns:

The denoised prediction vector field x0 evaluated at x_t and t.

Return type:

Array[data_dim]

diffusionlab.distributions.gmm.iso_hom_gmm.iso_hom_gmm_x0(x_t: Array, t: Array, diffusion_process: DiffusionProcess, means: Array, variance: Array, priors: Array) → Array

Computes the denoised prediction x0 = E[x_0 | x_t] for an isotropic homoscedastic GMM.

This implements the closed-form solution for the conditional expectation E[x_0 | x_t] where x_t ~ N(α_t x_0, σ_t^2 I) and x_0 follows the GMM distribution defined by means, variance, and priors.

Parameters:
  • x_t (Array[data_dim]) – The noisy state tensor at time t.

  • t (Array[]) – The time step (scalar).

  • diffusion_process (DiffusionProcess) – Provides α(t) and σ(t).

  • means (Array[num_components, data_dim]) – Means of the GMM components.

  • variance (Array[]) – Shared scalar variance of the GMM components (every component has covariance variance * I).

  • priors (Array[num_components]) – Mixture weights of the GMM components.

Returns:

The denoised prediction x0 evaluated at x_t and t.

Return type:

Array[data_dim]
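With one shared variance v, two simplifications follow. The log-determinant term is identical across components, so it cancels in the softmax. All components also share one shrinkage factor c = α v / (α² v + σ²), so E[x_0 | x_t] = (1 - α c) Σ_i w_i μ_i + c x_t. The NumPy sketch below uses that form and assumes α(t), σ(t) are passed as scalars. It is not the library's implementation.

```python
import numpy as np


def iso_hom_gmm_x0(x_t, alpha_t, sigma_t, means, variance, priors):
    """Closed-form E[x_0 | x_t] when every component has covariance variance * I."""
    marg_var = alpha_t**2 * variance + sigma_t**2               # shared scalar
    sq_dist = np.sum((x_t[None, :] - alpha_t * means) ** 2, axis=1)
    # The -d/2 log(2 pi marg_var) term is the same for all components, so it
    # cancels in the softmax and only the prior and squared distance remain.
    log_w = np.log(priors) - 0.5 * sq_dist / marg_var
    w = np.exp(log_w - log_w.max())
    w /= w.sum()
    # With a shared shrinkage factor c, the mixture of per-component posterior
    # means collapses to (1 - alpha c) * (weighted mean of mu_i) + c * x_t.
    c = alpha_t * variance / marg_var
    return (1.0 - alpha_t * c) * (w @ means) + c * x_t
```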

diffusionlab.distributions.gmm.low_rank_gmm

class diffusionlab.distributions.gmm.low_rank_gmm.LowRankGMM(means: Array, cov_factors: Array, priors: Array)

Bases: Distribution

Implements a low-rank Gaussian Mixture Model (GMM) distribution.

The probability measure is given by:

μ(A) = sum_{i=1}^{num_components} priors[i] * N(A; means[i], cov_factors[i] @ cov_factors[i].T)

This class provides methods for sampling from the GMM and computing various vector fields (VectorFieldType.SCORE, VectorFieldType.X0, VectorFieldType.EPS, VectorFieldType.V) related to the distribution under a given diffusion process.

dist_params

Dictionary containing the core low-rank GMM parameters.

  • means (Array[num_components, data_dim]): The means of the GMM components.

  • cov_factors (Array[num_components, data_dim, rank]): The low-rank covariance matrix factors of the GMM components.

  • priors (Array[num_components]): The prior probabilities (mixture weights) of the GMM components.

Type:

Dict[str, Array]

dist_hparams

Dictionary for storing hyperparameters (currently unused).

Type:

Dict[str, Any]

eps(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array

Computes the noise prediction ε for the low-rank GMM distribution.

This predicts the noise that was added to the original sample x_0 to obtain x_t at time t under the diffusion_process.

Parameters:
  • x_t (Array[data_dim]) – The noisy state tensor at time t.

  • t (Array[]) – The time step (scalar).

  • diffusion_process (DiffusionProcess) – The diffusion process definition.

Returns:

The noise prediction vector field ε evaluated at x_t and t.

Return type:

Array[data_dim]

sample(key: Array, num_samples: int) → Tuple[Array, Array]

Draws samples from the low-rank GMM distribution.

Parameters:
  • key (Array) – JAX PRNG key for random sampling.

  • num_samples (int) – The total number of samples to generate.

Returns:

A tuple (samples, component_indices) containing the drawn samples and the index of the GMM component from which each sample was drawn.

Return type:

Tuple[Array[num_samples, data_dim], Array[num_samples]]

score(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array

Computes the score vector field ∇_x log p_t(x_t) for the low-rank GMM distribution.

This is calculated with respect to the perturbed distribution p_t induced by the diffusion_process at time t.

Parameters:
  • x_t (Array[data_dim]) – The noisy state tensor at time t.

  • t (Array[]) – The time step (scalar).

  • diffusion_process (DiffusionProcess) – The diffusion process definition.

Returns:

The score vector field evaluated at x_t and t.

Return type:

Array[data_dim]

v(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array

Computes the velocity vector field v for the low-rank GMM distribution.

This is the conditional velocity E[dx_t/dt | x_t] under the diffusion_process.

Parameters:
  • x_t (Array[data_dim]) – The noisy state tensor at time t.

  • t (Array[]) – The time step (scalar).

  • diffusion_process (DiffusionProcess) – The diffusion process definition.

Returns:

The velocity vector field v evaluated at x_t and t.

Return type:

Array[data_dim]

x0(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array

Computes the denoised prediction x0 = E[x_0 | x_t] for the low-rank GMM distribution.

This represents the expected original sample x_0 given the noisy observation x_t at time t under the diffusion_process.

Parameters:
  • x_t (Array[data_dim]) – The noisy state tensor at time t.

  • t (Array[]) – The time step (scalar).

  • diffusion_process (DiffusionProcess) – The diffusion process definition.

Returns:

The denoised prediction vector field x0 evaluated at x_t and t.

Return type:

Array[data_dim]

diffusionlab.distributions.gmm.low_rank_gmm.low_rank_gmm_x0(x_t: Array, t: Array, diffusion_process: DiffusionProcess, means: Array, cov_factors: Array, priors: Array) → Array

Computes the denoised prediction x0 = E[x_0 | x_t] for a low-rank GMM.

This implements the closed-form solution for the conditional expectation E[x_0 | x_t] where x_t ~ N(α_t x_0, σ_t^2 I) and x_0 follows the low-rank GMM distribution defined by means, cov_factors, and priors.

Parameters:
  • x_t (Array[data_dim]) – The noisy state tensor at time t.

  • t (Array[]) – The time step (scalar).

  • diffusion_process (DiffusionProcess) – Provides α(t) and σ(t).

  • means (Array[num_components, data_dim]) – Means of the GMM components.

  • cov_factors (Array[num_components, data_dim, rank]) – Low-rank covariance factors of the GMM components (component i has covariance cov_factors[i] @ cov_factors[i].T).

  • priors (Array[num_components]) – Mixture weights of the GMM components.

Returns:

The denoised prediction x0 evaluated at x_t and t.

Return type:

Array[data_dim]
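When rank ≪ data_dim, the d × d solves can be avoided with three identities. The Woodbury identity gives the Mahalanobis term, the matrix determinant lemma gives the log-determinant, and the push-through identity U Uᵀ (α² U Uᵀ + σ² I)^{-1} = U (σ² I + α² UᵀU)^{-1} Uᵀ gives the posterior mean. Together they cut the per-component cost from O(d³) to roughly O(d r² + r³). This NumPy sketch shows that approach with α(t), σ(t) passed as scalars. Whether the library uses these identities is not stated in this reference.

```python
import numpy as np


def low_rank_gmm_x0(x_t, alpha_t, sigma_t, means, cov_factors, priors):
    """E[x_0 | x_t] for components with covariance U_i U_i^T, U_i of shape (d, r).

    Works in the r-dimensional factor space, so no d x d matrix is formed.
    """
    K, d = means.shape
    r = cov_factors.shape[2]
    s2, a2 = sigma_t**2, alpha_t**2
    log_w = np.empty(K)
    cond_means = np.empty((K, d))
    for i in range(K):
        U = cov_factors[i]                                    # (d, r)
        resid = x_t - alpha_t * means[i]                      # (d,)
        M = s2 * np.eye(r) + a2 * U.T @ U                     # (r, r)
        Ut_resid = U.T @ resid                                # (r,)
        M_inv_Ut_resid = np.linalg.solve(M, Ut_resid)
        # Woodbury: resid^T C^{-1} resid = (|resid|^2 - a2 * Ut_resid . M^{-1} Ut_resid) / s2
        mahal = (resid @ resid - a2 * Ut_resid @ M_inv_Ut_resid) / s2
        # Determinant lemma: log det(a2 U U^T + s2 I) = (d - r) log s2 + log det M
        _, logdet_M = np.linalg.slogdet(M)
        logdet = (d - r) * np.log(s2) + logdet_M
        log_w[i] = np.log(priors[i]) - 0.5 * (logdet + mahal)
        # Push-through: alpha U U^T C^{-1} resid = alpha U M^{-1} U^T resid
        cond_means[i] = means[i] + alpha_t * U @ M_inv_Ut_resid
    w = np.exp(log_w - log_w.max())
    w /= w.sum()
    return w @ cond_means
```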

diffusionlab.distributions.gmm.utils

diffusionlab.distributions.gmm.utils.create_gmm_vector_field_fns(x0_fn: Callable[[Array, Array, DiffusionProcess, Array, Array, Array], Array]) → Tuple[Callable[[Array, Array, DiffusionProcess, Array, Array, Array], Array], Callable[[Array, Array, DiffusionProcess, Array, Array, Array], Array], Callable[[Array, Array, DiffusionProcess, Array, Array, Array], Array]]

Factory to create eps, score, and v functions from a given x0 function.

Parameters:

x0_fn – The specific x0 calculation function (e.g., gmm_x0, iso_gmm_x0). It must accept (x_t, t, diffusion_process, means, specific_cov, priors).

Returns:

A tuple containing the generated (eps_fn, score_fn, v_fn). These functions will have the same signature as x0_fn, accepting (x_t, t, diffusion_process, means, specific_cov, priors).

Return type:

Tuple[Callable[[Array, Array, DiffusionProcess, Array, Array, Array], Array], Callable[[Array, Array, DiffusionProcess, Array, Array, Array], Array], Callable[[Array, Array, DiffusionProcess, Array, Array, Array], Array]]
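This is a hypothetical re-creation of the factory idea: given an x0 function, derive eps, score, and v from the identities ε̂ = (x_t - α x̂_0)/σ, score = (α x̂_0 - x_t)/σ², and v = α' x̂_0 + σ' ε̂. The process object here is a SimpleNamespace with alpha, sigma, alpha_prime, and sigma_prime callables, which is an assumption about the attribute interface. make_vector_field_fns and toy_x0 are illustrative names, not library functions.

```python
from types import SimpleNamespace
from typing import Callable, Tuple

import numpy as np


def make_vector_field_fns(x0_fn: Callable) -> Tuple[Callable, Callable, Callable]:
    """Hypothetical sketch: derive (eps_fn, score_fn, v_fn) from an x0 function."""

    def eps_fn(x_t, t, process, means, specific_cov, priors):
        x0 = x0_fn(x_t, t, process, means, specific_cov, priors)
        return (x_t - process.alpha(t) * x0) / process.sigma(t)

    def score_fn(x_t, t, process, means, specific_cov, priors):
        x0 = x0_fn(x_t, t, process, means, specific_cov, priors)
        return (process.alpha(t) * x0 - x_t) / process.sigma(t) ** 2

    def v_fn(x_t, t, process, means, specific_cov, priors):
        x0 = x0_fn(x_t, t, process, means, specific_cov, priors)
        eps = (x_t - process.alpha(t) * x0) / process.sigma(t)
        return process.alpha_prime(t) * x0 + process.sigma_prime(t) * eps

    return eps_fn, score_fn, v_fn


def toy_x0(x_t, t, process, means, specific_cov, priors):
    # Toy x0_fn: a single Gaussian N(means[0], specific_cov * I); priors are ignored.
    a, s = process.alpha(t), process.sigma(t)
    c = a * specific_cov / (a**2 * specific_cov + s**2)
    return means[0] + c * (x_t - a * means[0])


fm = SimpleNamespace(alpha=lambda t: 1 - t, sigma=lambda t: t,
                     alpha_prime=lambda t: -1.0, sigma_prime=lambda t: 1.0)
eps_fn, score_fn, v_fn = make_vector_field_fns(toy_x0)
```

Because every derived function shares the x0_fn signature, a single closed-form denoiser per GMM variant is enough to obtain all four vector fields.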

diffusionlab.dynamics

class diffusionlab.dynamics.DiffusionProcess(alpha: Callable[[Array], Array], sigma: Callable[[Array], Array])

Bases: object

Base class for implementing various diffusion processes.

A diffusion process defines how data evolves over time when noise is added according to specific dynamics operating on scalar time inputs. This class provides a framework to implement diffusion processes based on a schedule defined by α(t) and σ(t).

The diffusion is parameterized by two scalar functions of scalar time t:

  • α(t): Controls how much of the original signal is preserved at time t.

  • σ(t): Controls how much noise is added at time t.

The forward process for a single data point x_0 is defined as:

x_t = α(t) * x_0 + σ(t) * ε

where:

  • x_0 is the original data (Array[*data_dims])

  • x_t is the noised data at time t (Array[*data_dims])

  • ε is random noise sampled from a standard Gaussian distribution (Array[*data_dims])

  • t is the scalar diffusion time parameter (Array[])

alpha

Function mapping scalar time t -> scalar signal coefficient α(t).

Type:

Callable[[Array[]], Array[]]

sigma

Function mapping scalar time t -> scalar noise coefficient σ(t).

Type:

Callable[[Array[]], Array[]]

alpha_prime

Derivative of α w.r.t. scalar time t.

Type:

Callable[[Array[]], Array[]]

sigma_prime

Derivative of σ w.r.t. scalar time t.

Type:

Callable[[Array[]], Array[]]

forward(x: Array, t: Array, eps: Array) → Array

Applies the forward diffusion process to a data tensor x at time t using noise ε.

Computes x_t = α(t) * x + σ(t) * ε.

Parameters:
  • x (Array[*data_dims]) – The input data tensor x_0.

  • t (Array[]) – The scalar time parameter t.

  • eps (Array[*data_dims]) – The Gaussian noise tensor ε, matching the shape of x.

Returns:

The noised data tensor x_t at time t.

Return type:

Array[*data_dims]
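The sketch below is a hypothetical NumPy analogue of DiffusionProcess, showing how a schedule (α, σ) and its derivatives drive forward. It uses central finite differences for alpha_prime and sigma_prime. A JAX implementation could differentiate exactly with jax.grad instead; how the library actually obtains these derivatives is not stated here. The cosine schedule in the usage lines is also a hypothetical choice.

```python
from dataclasses import dataclass
from typing import Callable

import numpy as np


@dataclass
class SimpleDiffusionProcess:
    """Hypothetical analogue of DiffusionProcess: x_t = alpha(t) x_0 + sigma(t) eps."""

    alpha: Callable[[float], float]
    sigma: Callable[[float], float]
    h: float = 1e-5  # step size for the central finite differences below

    def alpha_prime(self, t: float) -> float:
        # Central finite-difference approximation of d alpha / dt.
        return (self.alpha(t + self.h) - self.alpha(t - self.h)) / (2 * self.h)

    def sigma_prime(self, t: float) -> float:
        # Central finite-difference approximation of d sigma / dt.
        return (self.sigma(t + self.h) - self.sigma(t - self.h)) / (2 * self.h)

    def forward(self, x: np.ndarray, t: float, eps: np.ndarray) -> np.ndarray:
        # x_t = alpha(t) * x_0 + sigma(t) * eps; scalar coefficients broadcast over data dims.
        return self.alpha(t) * x + self.sigma(t) * eps


# Hypothetical cosine schedule: alpha(t) = cos(pi t / 2), sigma(t) = sin(pi t / 2).
cosine = SimpleDiffusionProcess(alpha=lambda t: np.cos(np.pi * t / 2),
                                sigma=lambda t: np.sin(np.pi * t / 2))
x_t = cosine.forward(np.array([1.0, -1.0]), 0.5, np.array([0.2, 0.4]))
```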

class diffusionlab.dynamics.FlowMatchingProcess

Bases: DiffusionProcess

Implements a diffusion process based on Flow Matching principles.

This process defines dynamics that linearly interpolate between the data distribution at t=0 and a noise distribution (standard Gaussian) at t=1.

Uses the following scalar dynamics:

  • α(t) = 1 - t

  • σ(t) = t

Forward process:

x_t = (1 - t) * x_0 + t * ε.

alpha

Function mapping scalar time t -> scalar signal coefficient α(t). Set to 1 - t.

Type:

Callable[[Array[]], Array[]]

sigma

Function mapping scalar time t -> scalar noise coefficient σ(t). Set to t.

Type:

Callable[[Array[]], Array[]]

alpha_prime

Derivative of α w.r.t. scalar time t. Set to -1.

Type:

Callable[[Array[]], Array[]]

sigma_prime

Derivative of σ w.r.t. scalar time t. Set to 1.

Type:

Callable[[Array[]], Array[]]
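Because α'(t) = -1 and σ'(t) = 1 are constant, the per-sample velocity dx_t/dt = ε - x_0 does not depend on t, so each path is a straight line from x_0 (at t=0) to ε (at t=1). The two-function NumPy sketch below makes this concrete. The names fm_forward and fm_velocity_target are hypothetical.

```python
import numpy as np


def fm_forward(x0, t, eps):
    # Flow matching schedule from this reference: alpha(t) = 1 - t, sigma(t) = t.
    return (1.0 - t) * x0 + t * eps


def fm_velocity_target(x0, eps):
    # d x_t / dt = alpha'(t) x0 + sigma'(t) eps = -x0 + eps, independent of t.
    return eps - x0
```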

class diffusionlab.dynamics.VarianceExplodingProcess(sigma: Callable[[Array], Array])

Bases: DiffusionProcess

Implements a Variance Exploding (VE) diffusion process.

In this process, the signal component is constant (α(t) = 1), while the noise component increases over time according to the provided σ(t) function. The variance of the noised data x_t explodes as t increases.

Forward process:

x_t = x_0 + σ(t) * ε.

This process uses:

  • α(t) = 1

  • σ(t) = Provided by the user

alpha

Function mapping scalar time t -> scalar signal coefficient α(t). Set to 1.

Type:

Callable[[Array[]], Array[]]

sigma

Function mapping scalar time t -> scalar noise coefficient σ(t). Provided by the user.

Type:

Callable[[Array[]], Array[]]

alpha_prime

Derivative of α w.r.t. scalar time t. Set to 0.

Type:

Callable[[Array[]], Array[]]

sigma_prime

Derivative of σ w.r.t. scalar time t.

Type:

Callable[[Array[]], Array[]]
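Here σ(t) is user-supplied. As a hypothetical example, the sketch below uses a geometric schedule σ(t) = σ_min (σ_max/σ_min)^t with assumed endpoints σ_min = 0.01 and σ_max = 50. Its derivative is σ(t) log(σ_max/σ_min). Because α(t) = 1 and ε is independent of x_0, Var[x_t] = Var[x_0] + σ(t)² per coordinate, which is the exploding variance the name refers to.

```python
import numpy as np

SIGMA_MIN, SIGMA_MAX = 0.01, 50.0  # hypothetical schedule endpoints


def ve_sigma(t):
    # Geometric schedule: grows exponentially from SIGMA_MIN at t=0 to SIGMA_MAX at t=1.
    return SIGMA_MIN * (SIGMA_MAX / SIGMA_MIN) ** t


def ve_sigma_prime(t):
    # d/dt [sigma_min * ratio^t] = sigma(t) * log(ratio)
    return ve_sigma(t) * np.log(SIGMA_MAX / SIGMA_MIN)


def ve_forward(x0, t, eps):
    # alpha(t) = 1, so the signal is never scaled down; only noise is added.
    return x0 + ve_sigma(t) * eps
```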

class diffusionlab.dynamics.VariancePreservingProcess

Bases: DiffusionProcess

Implements a Variance Preserving (VP) diffusion process, often used in DDPMs.

This process keeps the variance of the noised data x_t equal to 1 throughout the diffusion, assuming x_0 has unit variance and is independent of the unit-variance noise ε, because α(t)² + σ(t)² = (1 - t²) + t² = 1.

Uses the following scalar dynamics:

  • α(t) = sqrt(1 - t²)

  • σ(t) = t

Forward process:

x_t = sqrt(1 - t²) * x_0 + t * ε.

alpha

Function mapping scalar time t -> scalar signal coefficient α(t). Set to sqrt(1 - t²).

Type:

Callable[[Array[]], Array[]]

sigma

Function mapping scalar time t -> scalar noise coefficient σ(t). Set to t.

Type:

Callable[[Array[]], Array[]]

alpha_prime

Derivative of α w.r.t. scalar time t. Set to -t / sqrt(1 - t²).

Type:

Callable[[Array[]], Array[]]

sigma_prime

Derivative of σ w.r.t. scalar time t. Set to 1.

Type:

Callable[[Array[]], Array[]]
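The following NumPy check reproduces the VP coefficients listed above. It confirms that α(t)² + σ(t)² = 1 and compares the closed-form α′(t) = -t / sqrt(1 - t²) against a finite difference. Times stay below 1 because α′ is undefined at t = 1. NumPy stands in for JAX here only to keep the example short.

```python
import numpy as np


def alpha(t):
    # VP signal coefficient: alpha(t) = sqrt(1 - t^2).
    return np.sqrt(1.0 - t**2)


def sigma(t):
    # VP noise coefficient: sigma(t) = t.
    return t


def alpha_prime(t):
    # d/dt sqrt(1 - t^2) = -t / sqrt(1 - t^2); undefined at t = 1.
    return -t / np.sqrt(1.0 - t**2)


def sigma_prime(t):
    # d/dt t = 1.
    return np.ones_like(t)


def vp_forward(x_0, eps, t):
    # x_t = alpha(t) * x_0 + sigma(t) * eps.
    return alpha(t) * x_0 + sigma(t) * eps


ts = np.linspace(0.0, 0.99, 100)
# alpha^2 + sigma^2 = 1 at every t, which keeps Var[x_t] at 1.
assert np.allclose(alpha(ts) ** 2 + sigma(ts) ** 2, 1.0)

# Closed-form alpha_prime agrees with a central finite difference.
h = 1e-6
fd = (alpha(ts[1:-1] + h) - alpha(ts[1:-1] - h)) / (2 * h)
assert np.allclose(fd, alpha_prime(ts[1:-1]), atol=1e-5)
```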

diffusionlab.losses

class diffusionlab.losses.DiffusionLoss(diffusion_process: DiffusionProcess, vector_field_type: VectorFieldType, num_noise_draws_per_sample: int)[source]

Bases: object

Loss function for training diffusion models.

This dataclass implements various loss functions for diffusion models based on the specified target type. The loss is computed as the mean squared error between the model’s prediction and the target, which depends on the chosen vector field type.

The loss supports different target types:

  • VectorFieldType.X0: Learn to predict the original clean data x_0

  • VectorFieldType.EPS: Learn to predict the noise component eps

  • VectorFieldType.V: Learn to predict the velocity field v

  • VectorFieldType.SCORE: Not directly supported (raises ValueError)

diffusion_process

The diffusion process defining the forward dynamics.

Type:

DiffusionProcess

vector_field_type

The type of target to learn to estimate via minimizing the loss function.

Type:

VectorFieldType

num_noise_draws_per_sample

The number of noise draws per sample to use for the batchwise loss.

Type:

int

target

Function that computes the regression target based on the specified vector_field_type.

Signature: (x_t: Array[*data_dims], f_x_t: Array[*data_dims], x_0: Array[*data_dims], eps: Array[*data_dims], t: Array[]) -> Array[*data_dims]

Type:

Callable[[Array, Array, Array, Array, Array], Array]
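The sketch below shows one way the target function could be chosen from vector_field_type, together with the mean-squared-error prediction_loss. The names make_target and prediction_loss (with an explicit target argument) are hypothetical, VectorFieldType is a local stand-in for the library enum, and NumPy replaces JAX. The V target assumes v = α′(t)x_0 + σ′(t)ε, which is the time derivative of x_t = α(t)x_0 + σ(t)ε. As documented above, SCORE raises ValueError.

```python
from enum import Enum

import numpy as np


class VectorFieldType(Enum):
    # Local stand-in for diffusionlab.vector_fields.VectorFieldType.
    SCORE = "score"
    X0 = "x0"
    EPS = "eps"
    V = "v"


def make_target(vector_field_type, alpha_prime, sigma_prime):
    """Return target(x_t, f_x_t, x_0, eps, t) for the given type (hypothetical helper)."""
    if vector_field_type == VectorFieldType.X0:
        return lambda x_t, f_x_t, x_0, eps, t: x_0
    if vector_field_type == VectorFieldType.EPS:
        return lambda x_t, f_x_t, x_0, eps, t: eps
    if vector_field_type == VectorFieldType.V:
        # Velocity d x_t / dt for x_t = alpha(t) x_0 + sigma(t) eps.
        return lambda x_t, f_x_t, x_0, eps, t: alpha_prime(t) * x_0 + sigma_prime(t) * eps
    # Matches the documented behavior: SCORE is not a supported training target.
    raise ValueError(f"Unsupported vector field type: {vector_field_type}")


def prediction_loss(target, x_t, f_x_t, x_0, eps, t):
    # Mean squared error between the prediction and the target over all data dims.
    return np.mean((f_x_t - target(x_t, f_x_t, x_0, eps, t)) ** 2)
```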

loss(key: Array, vector_field: Callable[[Array, Array], Array], x_0: Array, t: Array) Array[source]

Compute the average loss over multiple noise draws for a single data point and time.

This method estimates the expected loss at time t for a clean data sample x_0. It draws num_noise_draws_per_sample noise vectors eps and generates the corresponding noisy samples x_t with the diffusion_process. It then computes the model's predictions f_x_t with the provided vector_field (vmapped internally) and evaluates prediction_loss for each draw. The returned loss is the average over these draws.

Parameters:
  • key (Array) – The PRNG key for noise generation.

  • vector_field (Callable[[Array, Array], Array]) –

    The vector field function that takes a single noisy data sample x_t and its corresponding time t, and returns the model’s prediction f_x_t. This function will be vmapped internally over the batch dimension created by num_noise_draws_per_sample.

    Signature: (x_t: Array[*data_dims], t: Array[]) -> Array[*data_dims].

  • x_0 (Array[*data_dims]) – The original clean data sample.

  • t (Array[]) – The scalar time parameter.

Returns:

The scalar loss value, averaged over num_noise_draws_per_sample noise instances.

Return type:

Array[]
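A NumPy sketch of the Monte Carlo average described above, with the EPS target and the VP coefficients used for concreteness. The helper mc_loss is hypothetical. A NumPy Generator takes the place of the JAX PRNG key, and an explicit loop takes the place of the internal vmap.

```python
import numpy as np


def alpha(t):
    return np.sqrt(1.0 - t**2)  # VP signal coefficient


def sigma(t):
    return t  # VP noise coefficient


def mc_loss(rng, vector_field, x_0, t, num_noise_draws_per_sample):
    """Average EPS-prediction MSE over several noise draws (hypothetical helper)."""
    # One noise draw per row: shape (num_noise_draws_per_sample, *x_0.shape).
    eps = rng.standard_normal((num_noise_draws_per_sample,) + x_0.shape)
    # Forward process applied to every draw at once via broadcasting.
    x_t = alpha(t) * x_0 + sigma(t) * eps
    # Evaluate the single-sample vector field on each noisy sample
    # (the library vmaps this step instead of looping).
    f_x_t = np.stack([vector_field(x, t) for x in x_t])
    # Per-draw MSE against the EPS target, then the average over draws.
    per_draw = np.mean((f_x_t - eps) ** 2, axis=tuple(range(1, eps.ndim)))
    return per_draw.mean()
```

For example, an oracle that recovers ε exactly from x_t has zero loss, while a predictor that always outputs zeros has a loss of about E[ε²] = 1.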

prediction_loss(x_t: Array, f_x_t: Array, x_0: Array, eps: Array, t: Array) Array[source]

Compute the loss given a prediction and inputs/targets.

This method calculates the mean squared error between the model’s prediction (f_x_t) and the target value determined by the target_type (self.target).

Parameters:
  • x_t (Array[*data_dims]) – The noised data at time t.

  • f_x_t (Array[*data_dims]) – The model’s prediction at time t.

  • x_0 (Array[*data_dims]) – The original clean data.

  • eps (Array[*data_dims]) – The noise used to generate x_t.

  • t (Array[]) – The scalar time parameter.

Returns:

The scalar loss value for the given sample.

Return type:

Array[]

diffusionlab.samplers

class diffusionlab.samplers.DDMSampler(diffusion_process: DiffusionProcess, vector_field: Callable[[Array, Array], Array], vector_field_type: VectorFieldType, use_stochastic_sampler: bool)[source]

Bases: Sampler

Class for sampling from diffusion models using the Denoising Diffusion Probabilistic Models (DDPM) or Denoising Diffusion Implicit Models (DDIM) sampling strategy.

This sampler first converts any given vector field type (VectorFieldType.SCORE, VectorFieldType.X0, VectorFieldType.EPS, VectorFieldType.V) provided by vector_field into an equivalent x0 prediction using the convert_vector_field_type utility. Then, it applies the DDPM (if use_stochastic_sampler is True) or DDIM (if use_stochastic_sampler is False) update rule based on this x0 prediction.

diffusion_process

The diffusion process defining the forward dynamics.

Type:

DiffusionProcess

vector_field

The function predicting the vector field.

Type:

Callable[[Array[*data_dims], Array[]], Array[*data_dims]]

vector_field_type

The type of the vector field predicted by vector_field.

Type:

VectorFieldType

use_stochastic_sampler

If True, uses DDPM (stochastic); otherwise, uses DDIM (deterministic).

Type:

bool

sample_step

The DDPM or DDIM step function.

Type:

Callable[[int, Array, Array, Array], Array]

get_sample_step_function() Callable[[int, Array, Array, Array], Array][source]

Get the appropriate DDPM/DDIM sampling step function based on stochasticity.

Returns:

The DDPM (stochastic) or DDIM (deterministic) step function, which has signature:

(idx: int, x: Array[*data_dims], zs: Array[num_steps, *data_dims], ts: Array[num_steps+1]) -> Array[*data_dims]

Return type:

Callable[[int, Array, Array, Array], Array]
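The deterministic (DDIM) branch can be sketched as follows. The step converts the x0 prediction into an implied noise estimate and then re-noises the x0 estimate to the next time with that same noise. The sketch assumes the VP coefficients and uses a hypothetical make_ddim_step helper that follows the documented (idx, x, zs, ts) signature. The stochastic DDPM branch, which would also consume zs[idx], is not shown.

```python
import numpy as np


def alpha(t):
    return np.sqrt(1.0 - t**2)  # VP signal coefficient


def sigma(t):
    return t  # VP noise coefficient


def make_ddim_step(x0_predictor):
    """Build a deterministic DDIM step from an x0 predictor (hypothetical helper)."""

    def step(idx, x, zs, ts):
        t, s = ts[idx], ts[idx + 1]  # current time and next (smaller) time
        x0_hat = x0_predictor(x, t)
        # Noise estimate implied by x_t = alpha(t) x0 + sigma(t) eps.
        eps_hat = (x - alpha(t) * x0_hat) / sigma(t)
        # Re-noise the x0 estimate to time s with the same eps estimate; zs is unused.
        return alpha(s) * x0_hat + sigma(s) * eps_hat

    return step
```

With a perfect x0 predictor for a point-mass distribution, each step lands exactly on the forward-process trajectory α(s)x_0 + σ(s)ε.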

class diffusionlab.samplers.EulerMaruyamaSampler(diffusion_process: DiffusionProcess, vector_field: Callable[[Array, Array], Array], vector_field_type: VectorFieldType, use_stochastic_sampler: bool)[source]

Bases: Sampler

Class for sampling from diffusion models using the first-order Euler-Maruyama integrator for the reverse process SDE/ODE.

This sampler implements the step function based on the Euler-Maruyama discretization of the reverse SDE (if use_stochastic_sampler is True) or the corresponding probability flow ODE (if use_stochastic_sampler is False). It supports all vector field types (VectorFieldType.SCORE, VectorFieldType.X0, VectorFieldType.EPS, VectorFieldType.V).

diffusion_process

The diffusion process defining the forward dynamics.

Type:

DiffusionProcess

vector_field

The function predicting the vector field. Takes the current state x_t and time t as input.

Type:

Callable[[Array[*data_dims], Array[]], Array[*data_dims]]

vector_field_type

The type of the vector field predicted by vector_field.

Type:

VectorFieldType

use_stochastic_sampler

Whether to use a stochastic or deterministic reverse process.

Type:

bool

sample_step

The specific function used to perform one sampling step. Takes step index idx, current state x_t, noise array zs, and time schedule ts as input. Set during initialization based on the sampler type and use_stochastic_sampler.

Type:

Callable[[int, Array, Array, Array], Array]

get_sample_step_function() Callable[[int, Array, Array, Array], Array][source]

Get the appropriate Euler-Maruyama sampling step function based on the vector field type and stochasticity.

Returns:

The specific Euler-Maruyama step function to use.

Signature: (idx: int, x_t: Array[*data_dims], zs: Array[num_steps, *data_dims], ts: Array[num_steps+1]) -> Array[*data_dims]

Return type:

Callable[[int, Array, Array, Array], Array]
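When use_stochastic_sampler is False, the step integrates the probability flow ODE, whose velocity for x_t = α(t)x_0 + σ(t)ε is v = α′(t)x̂_0 + σ′(t)ε̂. The NumPy sketch below implements that explicit Euler step from an x0 prediction. It assumes the VP coefficients and uses a hypothetical make_euler_ode_step helper. The stochastic (reverse-SDE) branch is not shown.

```python
import numpy as np


def alpha(t):
    return np.sqrt(1.0 - t**2)  # VP signal coefficient


def sigma(t):
    return t  # VP noise coefficient


def alpha_prime(t):
    return -t / np.sqrt(1.0 - t**2)


def sigma_prime(t):
    return 1.0


def make_euler_ode_step(x0_predictor):
    """Build a deterministic Euler step of the probability flow ODE (hypothetical helper)."""

    def step(idx, x_t, zs, ts):
        t, s = ts[idx], ts[idx + 1]
        x0_hat = x0_predictor(x_t, t)
        eps_hat = (x_t - alpha(t) * x0_hat) / sigma(t)
        # Velocity v = alpha'(t) x0 + sigma'(t) eps, converted from the x0 prediction.
        v_hat = alpha_prime(t) * x0_hat + sigma_prime(t) * eps_hat
        # Explicit Euler update; s < t, so (s - t) is negative. zs is unused.
        return x_t + (s - t) * v_hat

    return step
```

Unlike the DDIM step, this update is only first-order accurate, so it approaches the exact trajectory as the number of steps grows.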

class diffusionlab.samplers.Sampler(diffusion_process: DiffusionProcess, vector_field: Callable[[Array, Array], Array], vector_field_type: VectorFieldType, use_stochastic_sampler: bool)[source]

Bases: object

Base class for sampling from diffusion models using various vector field types.

A Sampler combines a diffusion process and a vector field prediction function. It generates samples from a trained diffusion model by running the reverse (denoising) process over a time schedule ts, such as one produced by a Scheduler.

The sampler supports different vector field types (VectorFieldType.SCORE, VectorFieldType.X0, VectorFieldType.EPS, VectorFieldType.V). It can perform either stochastic or deterministic sampling, depending on the subclass implementation and the use_stochastic_sampler flag.

diffusion_process

The diffusion process defining the forward dynamics.

Type:

DiffusionProcess

vector_field

The function predicting the vector field. Takes the current state x_t and time t as input.

Type:

Callable[[Array[*data_dims], Array[]], Array[*data_dims]]

vector_field_type

The type of the vector field predicted by vector_field.

Type:

VectorFieldType

use_stochastic_sampler

Whether to use a stochastic or deterministic reverse process.

Type:

bool

sample_step

The specific function used to perform one sampling step. Takes step index idx, current state x_t, noise array zs, and time schedule ts as input. Set during initialization based on the sampler type and use_stochastic_sampler.

Type:

Callable[[int, Array, Array, Array], Array]

get_sample_step_function() Callable[[int, Array, Array, Array], Array][source]

Abstract method to get the appropriate sampling step function.

Subclasses must implement this method to return the specific function used for performing one step of the reverse process, based on the sampler’s implementation details (e.g., integrator type) and the use_stochastic_sampler flag.

Returns:

The sampling step function, which has signature:

(idx: int, x_t: Array[*data_dims], zs: Array[num_steps, *data_dims], ts: Array[num_steps+1]) -> Array[*data_dims]

Return type:

Callable[[int, Array, Array, Array], Array]

sample(x_init: Array, zs: Array, ts: Array) Array[source]

Sample from the model using the reverse diffusion process.

This method generates a final sample by iteratively applying the sample_step function, starting from an initial state x_init and using the provided noise zs and time schedule ts.

Parameters:
  • x_init (Array[*data_dims]) – The initial noisy tensor from which to initialize sampling (typically sampled from the prior distribution at ts[0]).

  • zs (Array[num_steps, *data_dims]) – The noise tensors used at each step for stochastic sampling. Unused for deterministic samplers.

  • ts (Array[num_steps+1]) – The time schedule for sampling. A sorted decreasing array of times from t_max to t_min.

Returns:

The generated sample at the final time ts[-1].

Return type:

Array[*data_dims]

sample_trajectory(x_init: Array, zs: Array, ts: Array) Array[source]

Sample a trajectory from the model using the reverse diffusion process.

This method generates the entire trajectory of intermediate samples by iteratively applying the sample_step function.

Parameters:
  • x_init (Array[*data_dims]) – The initial noisy tensor from which to start sampling (at time ts[0]).

  • zs (Array[num_steps, *data_dims]) – The noise tensors used at each step for stochastic sampling. Unused for deterministic samplers.

  • ts (Array[num_steps+1]) – The time schedule for sampling. A sorted decreasing array of times from t_max to t_min.

Returns:

The complete generated trajectory including the initial state x_init.

Return type:

Array[num_steps+1, *data_dims]
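The relationship between sample_step, sample, and sample_trajectory can be sketched with a plain Python loop. Each step maps the state at ts[idx] to the state at ts[idx + 1]. The trajectory stacks all num_steps + 1 states, and the final sample is its last entry. These free functions are hypothetical, and a JAX implementation would more likely use a traced loop such as jax.lax.scan.

```python
import numpy as np


def sample_trajectory(sample_step, x_init, zs, ts):
    """Apply sample_step over every interval of ts and keep each state (sketch)."""
    num_steps = ts.shape[0] - 1
    x = x_init
    xs = [x_init]
    for idx in range(num_steps):
        # Map the state at ts[idx] to the state at ts[idx + 1].
        x = sample_step(idx, x, zs, ts)
        xs.append(x)
    return np.stack(xs)  # shape (num_steps + 1, *data_dims)


def sample(sample_step, x_init, zs, ts):
    # The final sample is the last state of the trajectory, at time ts[-1].
    return sample_trajectory(sample_step, x_init, zs, ts)[-1]
```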

diffusionlab.schedulers

class diffusionlab.schedulers.Scheduler[source]

Bases: object

Base class for time step schedulers used in diffusion, denoising, and sampling.

Subclasses can define their own initialization and accept their own time-step generation parameters as keyword arguments (**ts_hparams) to get_ts.

get_ts(**ts_hparams: Any) Array[source]

Generate the sequence of time steps.

This is an abstract method that must be implemented by subclasses. Subclasses should define the specific keyword arguments they expect within **ts_hparams.

Parameters:

**ts_hparams (Dict[str, Any]) – Keyword arguments containing parameters for generating time steps.

Returns:

A tensor containing the sequence of time steps in descending order.

Return type:

Array

Raises:
  • NotImplementedError – If the subclass does not implement this method.

  • KeyError – If a required parameter is missing in **ts_hparams (in subclass).

class diffusionlab.schedulers.UniformScheduler[source]

Bases: Scheduler

A scheduler that generates uniformly spaced time steps.

Requires t_min, t_max, and num_steps to be passed to the get_ts method via keyword arguments. The number of points generated will be num_steps + 1.

get_ts(**ts_hparams: Any) Array[source]

Generate uniformly spaced time steps.

Parameters:

**ts_hparams (Dict[str, Any]) –

Keyword arguments must contain

  • t_min (float): The minimum time value, typically close to 0.

  • t_max (float): The maximum time value, typically close to 1.

  • num_steps (int): The number of diffusion steps. The function will generate num_steps + 1 time points.

Returns:

A JAX array containing uniformly spaced time steps in descending order (from t_max to t_min).

Return type:

Array[num_steps+1]

Raises:
  • KeyError – If t_min, t_max, or num_steps is not found in ts_hparams.

  • AssertionError – If t_min/t_max constraints are violated or num_steps < 1.
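The NumPy sketch below mirrors the documented UniformScheduler behavior. It reads t_min, t_max, and num_steps from **ts_hparams, so a missing key raises KeyError, and it returns num_steps + 1 points in descending order. The exact constraint checks are not specified above, so 0 ≤ t_min ≤ t_max ≤ 1 and num_steps ≥ 1 are assumptions. Both classes are local stand-ins for the library classes.

```python
from typing import Any

import numpy as np


class Scheduler:
    """Stand-in for the documented base class."""

    def get_ts(self, **ts_hparams: Any) -> np.ndarray:
        raise NotImplementedError


class UniformScheduler(Scheduler):
    def get_ts(self, **ts_hparams: Any) -> np.ndarray:
        # A missing key raises KeyError, as documented.
        t_min = ts_hparams["t_min"]
        t_max = ts_hparams["t_max"]
        num_steps = ts_hparams["num_steps"]
        # Assumed constraints; the library's exact checks may differ.
        assert 0.0 <= t_min <= t_max <= 1.0
        assert num_steps >= 1
        # num_steps + 1 evenly spaced points, descending from t_max to t_min.
        return np.linspace(t_max, t_min, num_steps + 1)
```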

diffusionlab.vector_fields

class diffusionlab.vector_fields.VectorFieldType(*values)[source]

Bases: Enum

Enum representing the type of a vector field. A vector field is a function that takes in x_t (Array[*data_dims]) and t (Array[]) and returns a vector of the same shape as x_t (Array[*data_dims]).

DiffusionLab supports the following vector field types:

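  • VectorFieldType.SCORE: The score of the noised distribution, ∇ log p_t(x_t).

  • VectorFieldType.X0: The denoised state, an estimate of the clean data x_0.

  • VectorFieldType.EPS: The noise ε added by the forward process.

  • VectorFieldType.V: The velocity field, an estimate of dx_t/dt = α′(t)x_0 + σ′(t)ε.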

diffusionlab.vector_fields.convert_vector_field_type(x: Array, f_x: Array, alpha: Array, sigma: Array, alpha_prime: Array, sigma_prime: Array, in_type: VectorFieldType, out_type: VectorFieldType) Array[source]

Converts the output of a vector field from one type to another.

Parameters:
  • x (Array[*data_dims]) – The input tensor (the noisy state x_t).

  • f_x (Array[*data_dims]) – The output of the vector field f evaluated at x.

  • alpha (Array[]) – The scalar signal coefficient α(t) at the current time.

  • sigma (Array[]) – The scalar noise coefficient σ(t) at the current time.

  • alpha_prime (Array[]) – The scalar derivative α′(t) of the signal coefficient.

  • sigma_prime (Array[]) – The scalar derivative σ′(t) of the noise coefficient.

  • in_type (VectorFieldType) – The type of the input vector field (e.g. VectorFieldType.SCORE, VectorFieldType.X0, VectorFieldType.EPS, VectorFieldType.V).

  • out_type (VectorFieldType) – The type of the output vector field.

Returns:

The converted output of the vector field, expressed as out_type.

Return type:

Array[*data_dims]
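One way to implement such a conversion is to first recover the (x_0, ε) pair implied by f_x using x = αx_0 + σε, and then express that pair as the requested output type. The sketch assumes score = -ε/σ and v = α′x_0 + σ′ε. It requires α ≠ 0, σ ≠ 0, and σα′ - σ′α ≠ 0. The function names and the local VectorFieldType enum are hypothetical stand-ins, not the library's implementation.

```python
from enum import Enum

import numpy as np


class VectorFieldType(Enum):
    # Local stand-in for diffusionlab.vector_fields.VectorFieldType.
    SCORE = "score"
    X0 = "x0"
    EPS = "eps"
    V = "v"


def to_x0_eps(x, f_x, alpha, sigma, alpha_prime, sigma_prime, in_type):
    """Recover (x0, eps) implied by f_x, using x = alpha * x0 + sigma * eps."""
    if in_type == VectorFieldType.X0:
        x0 = f_x
        eps = (x - alpha * x0) / sigma
    elif in_type == VectorFieldType.EPS:
        eps = f_x
        x0 = (x - sigma * eps) / alpha
    elif in_type == VectorFieldType.SCORE:
        # Gaussian-noise relation: score = -eps / sigma.
        eps = -sigma * f_x
        x0 = (x - sigma * eps) / alpha
    elif in_type == VectorFieldType.V:
        # Solve v = alpha' x0 + sigma' eps jointly with x = alpha x0 + sigma eps.
        x0 = (sigma * f_x - sigma_prime * x) / (sigma * alpha_prime - sigma_prime * alpha)
        eps = (x - alpha * x0) / sigma
    else:
        raise ValueError(in_type)
    return x0, eps


def convert_vector_field_sketch(x, f_x, alpha, sigma, alpha_prime, sigma_prime, in_type, out_type):
    x0, eps = to_x0_eps(x, f_x, alpha, sigma, alpha_prime, sigma_prime, in_type)
    if out_type == VectorFieldType.X0:
        return x0
    if out_type == VectorFieldType.EPS:
        return eps
    if out_type == VectorFieldType.SCORE:
        return -eps / sigma
    if out_type == VectorFieldType.V:
        return alpha_prime * x0 + sigma_prime * eps
    raise ValueError(out_type)
```

Routing every conversion through (x_0, ε) keeps the number of formulas linear in the number of types instead of requiring one formula per input/output pair.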