API Reference
diffusionlab
diffusionlab.distributions
diffusionlab.distributions.base
- class diffusionlab.distributions.base.Distribution(dist_params: Dict[str, Array], dist_hparams: Dict[str, Any])
Bases: object
Base class for all distributions.
Subclass this class to define a distribution that provides ground-truth scores, denoisers, noise predictors, or velocity estimators.
Each distribution implementation provides functions to sample from it and to compute vector fields related to a diffusion process: denoising (x0), noise prediction (eps), velocity estimation (v), and score estimation (score).
- dist_params
Dictionary containing distribution parameters as JAX arrays. Shapes depend on the specific distribution.
- Type:
Dict[str, Array]
- dist_hparams¶
Dictionary containing distribution hyperparameters (non-array values).
- Type:
Dict[str, Any]
- eps(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Predict the noise component ε corresponding to the noisy state x_t at time t, given the diffusion_process.
- Parameters:
x_t (Array[*data_dims]) – The noisy state tensor at time t.
t (Array[]) – The time step.
diffusion_process (DiffusionProcess) – The diffusion process definition.
- Returns:
The predicted noise ε.
- Return type:
Array[*data_dims]
- get_vector_field(vector_field_type: VectorFieldType) → Callable[[Array, Array, DiffusionProcess], Array]
Get the vector field function of a given type associated with this distribution.
- Parameters:
vector_field_type (VectorFieldType) – The type of vector field to retrieve (e.g., VectorFieldType.SCORE, VectorFieldType.X0, VectorFieldType.EPS, VectorFieldType.V).
- Returns:
The requested vector field function. It takes the current state x_t (Array[*data_dims]), the time t (Array[]), and the diffusion_process as input and returns the corresponding vector field value (Array[*data_dims]).
- Return type:
Callable[[Array[*data_dims], Array[], DiffusionProcess], Array[*data_dims]]
- sample(key: Array, num_samples: int) → Tuple[Array, Any]
Sample from the distribution.
- Parameters:
key (Array) – The JAX PRNG key to use for sampling.
num_samples (int) – The number of samples to draw.
- Returns:
A tuple containing the samples and any additional information.
- Return type:
Tuple[Array[num_samples, *data_dims], Any]
- score(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Compute the score function (∇_x log p_t(x)) of the distribution at time t, given the noisy state x_t and the diffusion_process.
- Parameters:
x_t (Array[*data_dims]) – The noisy state tensor at time t.
t (Array[]) – The time step.
diffusion_process (DiffusionProcess) – The diffusion process definition.
- Returns:
The score of the distribution at (x_t, t).
- Return type:
Array[*data_dims]
- v(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Compute the velocity field v(x_t, t) corresponding to the noisy state x_t at time t, given the diffusion_process.
- Parameters:
x_t (Array[*data_dims]) – The noisy state tensor at time t.
t (Array[]) – The time step.
diffusion_process (DiffusionProcess) – The diffusion process definition.
- Returns:
The computed velocity field v.
- Return type:
Array[*data_dims]
- x0(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Predict the initial state x0 (denoised sample) from the noisy state x_t at time t, given the diffusion_process.
- Parameters:
x_t (Array[*data_dims]) – The noisy state tensor at time t.
t (Array[]) – The time step.
diffusion_process (DiffusionProcess) – The diffusion process definition.
- Returns:
The predicted initial state x0.
- Return type:
Array[*data_dims]
diffusionlab.distributions.empirical
- class diffusionlab.distributions.empirical.EmpiricalDistribution(labeled_data: Iterable[Tuple[Array, Array]] | Iterable[Tuple[Array, None]])
Bases: Distribution
An empirical distribution, i.e., the uniform distribution over a dataset. The probability measure is defined as:
μ(A) = (1/N) * sum_{i=1}^{N} delta(x_i in A)
where x_i is the i-th data point in the dataset and N is the number of data points.
This class provides methods for sampling from the empirical distribution and computing various vector fields (VectorFieldType.SCORE, VectorFieldType.X0, VectorFieldType.EPS, VectorFieldType.V) related to the distribution under a given diffusion process.
- dist_params
Dictionary containing distribution parameters (currently unused).
- Type:
Dict[str, Array]
- dist_hparams¶
Dictionary for storing hyperparameters. It may contain the following keys:
labeled_data (Iterable[Tuple[Array, Array]] | Iterable[Tuple[Array, None]]): An iterable whose elements are tuples of (data batch, label batch). The label batch can be None if the data is unlabelled.
- Type:
Dict[str, Any]
- eps(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Computes the noise field eps(x_t, t) for an empirical distribution w.r.t. a given diffusion process.
- Parameters:
x_t (Array[*data_dims]) – The input tensor.
t (Array[]) – The time tensor.
diffusion_process (DiffusionProcess) – The diffusion process.
- Returns:
The noise field at (x_t, t).
- Return type:
Array[*data_dims]
- sample(key: Array, num_samples: int) → Tuple[Array, Array] | Tuple[Array, None]
Sample from the empirical distribution using reservoir sampling. Assumes all batches in labeled_data are consistent: either all have labels (Array) or none have labels (None).
- Parameters:
key (Array) – The JAX PRNG key to use for sampling.
num_samples (int) – The number of samples to draw.
- Returns:
A tuple (samples, labels) containing the samples and corresponding labels (stacked into an Array), or (samples, None) if the data is unlabelled.
- Return type:
Tuple[Array[num_samples, *data_dims], Array[num_samples, *label_dims]] | Tuple[Array[num_samples, *data_dims], None]
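For readers unfamiliar with reservoir sampling, the following is a minimal, dependency-free sketch of one standard variant (Algorithm R), which draws k items uniformly without replacement from a stream in a single pass. It is an illustration only: the function name reservoir_sample, the seed argument, and the toy stream are assumptions, and the library's own routine is driven by a JAX PRNG key and may differ in detail.

```python
import random
from typing import Iterable, List, TypeVar

T = TypeVar("T")


def reservoir_sample(stream: Iterable[T], k: int, seed: int = 0) -> List[T]:
    """Algorithm R: keep k items chosen uniformly from a stream of unknown length."""
    rng = random.Random(seed)
    reservoir: List[T] = []
    for n, item in enumerate(stream):
        if n < k:
            reservoir.append(item)      # the first k items fill the reservoir
        else:
            j = rng.randint(0, n)       # uniform index in [0, n], both ends inclusive
            if j < k:
                reservoir[j] = item     # keep the new item with probability k / (n + 1)
    return reservoir


# Hypothetical stream of (data, label) pairs, as if flattened from batches.
stream = ((float(i), i % 2) for i in range(100))
picked = reservoir_sample(stream, k=5)
```

The appeal of this scheme for an iterable dataset is that it never needs the dataset length in advance and holds only k items in memory.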
- score(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Computes the score function (∇_x log p_t(x)) of the empirical distribution at time t, given the noisy state x_t and the diffusion process.
- Parameters:
x_t (Array[*data_dims]) – The noisy state tensor at time t.
t (Array[]) – The time tensor.
diffusion_process (DiffusionProcess) – The diffusion process.
- Returns:
The score of the empirical distribution at (x_t, t).
- Return type:
Array[*data_dims]
- v(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Computes the velocity field v(x_t, t) for an empirical distribution w.r.t. a given diffusion process.
- Parameters:
x_t (Array[*data_dims]) – The input tensor.
t (Array[]) – The time tensor.
diffusion_process (DiffusionProcess) – The diffusion process.
- Returns:
The velocity field at (x_t, t).
- Return type:
Array[*data_dims]
- x0(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Computes the denoiser E[x_0 | x_t] for an empirical distribution w.r.t. a given diffusion process.
This method computes the denoiser by performing a weighted average of the dataset samples, where the weights are determined by the likelihood of x_t given each sample.
- Parameters:
x_t (Array[*data_dims]) – The input tensor.
t (Array[]) – The time tensor.
diffusion_process (DiffusionProcess) – The diffusion process.
- Returns:
The prediction of x_0.
- Return type:
Array[*data_dims]
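The likelihood-weighted average described above can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the library's implementation: the name empirical_x0 is hypothetical, alpha and sigma are passed as already-evaluated scalars α(t) and σ(t), and the noising model is taken to be x_t | x_i ~ N(α x_i, σ² I), so the weights are a softmax over negative squared distances.

```python
import math
from typing import List, Sequence


def empirical_x0(x_t: Sequence[float], data: Sequence[Sequence[float]],
                 alpha: float, sigma: float) -> List[float]:
    """E[x_0 | x_t] for the uniform distribution over the rows of `data`."""
    # Log-likelihood of x_t under N(alpha * x_i, sigma^2 I), up to a shared constant.
    logits = [
        -sum((xt - alpha * xi) ** 2 for xt, xi in zip(x_t, x_i)) / (2.0 * sigma ** 2)
        for x_i in data
    ]
    m = max(logits)                                 # subtract the max for numerical stability
    weights = [math.exp(l - m) for l in logits]
    total = sum(weights)
    return [
        sum(w * x_i[d] for w, x_i in zip(weights, data)) / total
        for d in range(len(x_t))
    ]


data = [[-1.0, 0.0], [1.0, 0.0]]
# Equidistant from both scaled data points: the denoiser returns their average.
mid = empirical_x0([0.0, 0.0], data, alpha=0.5, sigma=1.0)
# Low noise and x_t equal to alpha * data[1]: the denoiser collapses onto data[1].
near = empirical_x0([0.5, 0.0], data, alpha=0.5, sigma=0.05)
```

As σ shrinks, the weights concentrate on the nearest scaled data point, which is why the exact denoiser of an empirical distribution memorizes the dataset at low noise levels.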
diffusionlab.distributions.gmm
diffusionlab.distributions.gmm.gmm
- class diffusionlab.distributions.gmm.gmm.GMM(means: Array, covs: Array, priors: Array)
Bases: Distribution
Implements a Gaussian Mixture Model (GMM) distribution.
The probability measure is given by:
μ(A) = sum_{i=1}^{num_components} priors[i] * N(A; means[i], covs[i])
This class provides methods for sampling from the GMM and computing various vector fields (VectorFieldType.SCORE, VectorFieldType.X0, VectorFieldType.EPS, VectorFieldType.V) related to the distribution under a given diffusion process.
- dist_params
Dictionary containing the core GMM parameters.
means (Array[num_components, data_dim]): The means of the GMM components.
covs (Array[num_components, data_dim, data_dim]): The covariance matrices of the GMM components.
priors (Array[num_components]): The prior probabilities (mixture weights) of the GMM components.
- Type:
Dict[str, Array]
- dist_hparams¶
Dictionary for storing hyperparameters (currently unused).
- Type:
Dict[str, Any]
- eps(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Computes the noise prediction ε for the GMM distribution.
This predicts the noise that was added to the original sample x_0 to obtain x_t at time t under the diffusion_process.
- Parameters:
x_t (Array[data_dim]) – The noisy state tensor at time t.
t (Array[]) – The time step (scalar).
diffusion_process (DiffusionProcess) – The diffusion process definition.
- Returns:
The noise prediction vector field ε evaluated at x_t and t.
- Return type:
Array[data_dim]
- sample(key: Array, num_samples: int) → Tuple[Array, Array]
Draws samples from the GMM distribution.
- Parameters:
key (Array) – JAX PRNG key for random sampling.
num_samples (int) – The total number of samples to generate.
- Returns:
A tuple (samples, component_indices) containing the drawn samples and the index of the GMM component from which each sample was drawn.
- Return type:
Tuple[Array[num_samples, data_dim], Array[num_samples]]
- score(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Computes the score vector field ∇_x log p_t(x_t) for the GMM distribution.
This is calculated with respect to the perturbed distribution p_t induced by the diffusion_process at time t.
- Parameters:
x_t (Array[data_dim]) – The noisy state tensor at time t.
t (Array[]) – The time step (scalar).
diffusion_process (DiffusionProcess) – The diffusion process definition.
- Returns:
The score vector field evaluated at x_t and t.
- Return type:
Array[data_dim]
- v(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Computes the velocity vector field v for the GMM distribution.
This is the conditional velocity E[dx_t/dt | x_t] under the diffusion_process.
- Parameters:
x_t (Array[data_dim]) – The noisy state tensor at time t.
t (Array[]) – The time step (scalar).
diffusion_process (DiffusionProcess) – The diffusion process definition.
- Returns:
The velocity vector field v evaluated at x_t and t.
- Return type:
Array[data_dim]
- x0(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Computes the denoised prediction x0 = E[x_0 | x_t] for the GMM distribution.
This represents the expected original sample x_0 given the noisy observation x_t at time t under the diffusion_process.
- Parameters:
x_t (Array[data_dim]) – The noisy state tensor at time t.
t (Array[]) – The time step (scalar).
diffusion_process (DiffusionProcess) – The diffusion process definition.
- Returns:
The denoised prediction vector field x0 evaluated at x_t and t.
- Return type:
Array[data_dim]
- diffusionlab.distributions.gmm.gmm.gmm_x0(x_t: Array, t: Array, diffusion_process: DiffusionProcess, means: Array, covs: Array, priors: Array) → Array
Computes the denoised prediction x0 = E[x_0 | x_t] for a GMM.
This implements the closed-form solution for the conditional expectation E[x_0 | x_t], where x_t | x_0 ~ N(α_t x_0, σ_t^2 I) and x_0 follows the GMM distribution defined by means, covs, and priors.
- Parameters:
x_t (Array[data_dim]) – The noisy state tensor at time t.
t (Array[]) – The time step (scalar).
diffusion_process (DiffusionProcess) – Provides α(t) and σ(t).
means (Array[num_components, data_dim]) – Means of the GMM components.
covs (Array[num_components, data_dim, data_dim]) – Covariances of the GMM components.
priors (Array[num_components]) – Mixture weights of the GMM components.
- Returns:
The denoised prediction x0 evaluated at x_t and t.
- Return type:
Array[data_dim]
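For reference, the closed form mentioned above follows from standard Gaussian conditioning applied per mixture component. Writing π_i, μ_i, Σ_i for priors[i], means[i], covs[i], one way to state it is the following. The notation is ours, and the library may organize the computation differently.

```latex
\begin{aligned}
w_i(x_t) &= \frac{\pi_i \,\mathcal{N}\!\left(x_t;\ \alpha_t \mu_i,\ \alpha_t^2 \Sigma_i + \sigma_t^2 I\right)}
                 {\sum_{j} \pi_j \,\mathcal{N}\!\left(x_t;\ \alpha_t \mu_j,\ \alpha_t^2 \Sigma_j + \sigma_t^2 I\right)}, \\
\mathbb{E}[x_0 \mid x_t] &= \sum_{i} w_i(x_t)\left[\mu_i + \alpha_t \Sigma_i
    \left(\alpha_t^2 \Sigma_i + \sigma_t^2 I\right)^{-1}\left(x_t - \alpha_t \mu_i\right)\right].
\end{aligned}
```

Here w_i(x_t) is the posterior probability that x_t came from component i, and the bracketed term is the conditional mean of x_0 within that component.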
diffusionlab.distributions.gmm.iso_gmm
- class diffusionlab.distributions.gmm.iso_gmm.IsoGMM(means: Array, variances: Array, priors: Array)
Bases: Distribution
Implements an isotropic Gaussian Mixture Model (GMM) distribution.
The probability measure is given by:
μ(A) = sum_{i=1}^{num_components} priors[i] * N(A; means[i], variances[i] * I)
This class provides methods for sampling from the isotropic GMM and computing various vector fields (VectorFieldType.SCORE, VectorFieldType.X0, VectorFieldType.EPS, VectorFieldType.V) related to the distribution under a given diffusion process.
- dist_params
Dictionary containing the core GMM parameters.
means (Array[num_components, data_dim]): The means of the GMM components.
variances (Array[num_components]): The variances of the GMM components.
priors (Array[num_components]): The prior probabilities (mixture weights) of the GMM components.
- Type:
Dict[str, Array]
- dist_hparams¶
Dictionary for storing hyperparameters (currently unused).
- Type:
Dict[str, Any]
- eps(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Computes the noise prediction ε for the isotropic GMM distribution.
This predicts the noise that was added to the original sample x_0 to obtain x_t at time t under the diffusion_process.
- Parameters:
x_t (Array[data_dim]) – The noisy state tensor at time t.
t (Array[]) – The time step (scalar).
diffusion_process (DiffusionProcess) – The diffusion process definition.
- Returns:
The noise prediction vector field ε evaluated at x_t and t.
- Return type:
Array[data_dim]
- sample(key: Array, num_samples: int) → Tuple[Array, Array]
Draws samples from the isotropic GMM distribution.
- Parameters:
key (Array) – JAX PRNG key for random sampling.
num_samples (int) – The total number of samples to generate.
- Returns:
A tuple (samples, component_indices) containing the drawn samples and the index of the GMM component from which each sample was drawn.
- Return type:
Tuple[Array[num_samples, data_dim], Array[num_samples]]
- score(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Computes the score vector field ∇_x log p_t(x_t) for the isotropic GMM distribution.
This is calculated with respect to the perturbed distribution p_t induced by the diffusion_process at time t.
- Parameters:
x_t (Array[data_dim]) – The noisy state tensor at time t.
t (Array[]) – The time step (scalar).
diffusion_process (DiffusionProcess) – The diffusion process definition.
- Returns:
The score vector field evaluated at x_t and t.
- Return type:
Array[data_dim]
- v(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Computes the velocity vector field v for the isotropic GMM distribution.
This is the conditional velocity E[dx_t/dt | x_t] under the diffusion_process.
- Parameters:
x_t (Array[data_dim]) – The noisy state tensor at time t.
t (Array[]) – The time step (scalar).
diffusion_process (DiffusionProcess) – The diffusion process definition.
- Returns:
The velocity vector field v evaluated at x_t and t.
- Return type:
Array[data_dim]
- x0(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Computes the denoised prediction x0 = E[x_0 | x_t] for the isotropic GMM distribution.
This represents the expected original sample x_0 given the noisy observation x_t at time t under the diffusion_process.
- Parameters:
x_t (Array[data_dim]) – The noisy state tensor at time t.
t (Array[]) – The time step (scalar).
diffusion_process (DiffusionProcess) – The diffusion process definition.
- Returns:
The denoised prediction vector field x0 evaluated at x_t and t.
- Return type:
Array[data_dim]
- diffusionlab.distributions.gmm.iso_gmm.iso_gmm_x0(x_t: Array, t: Array, diffusion_process: DiffusionProcess, means: Array, variances: Array, priors: Array) → Array
Computes the denoised prediction x0 = E[x_0 | x_t] for an isotropic GMM.
This implements the closed-form solution for the conditional expectation E[x_0 | x_t], where x_t | x_0 ~ N(α_t x_0, σ_t^2 I) and x_0 follows the isotropic GMM distribution defined by means, variances, and priors.
- Parameters:
x_t (Array[data_dim]) – The noisy state tensor at time t.
t (Array[]) – The time step (scalar).
diffusion_process (DiffusionProcess) – Provides α(t) and σ(t).
means (Array[num_components, data_dim]) – Means of the GMM components.
variances (Array[num_components]) – Variances of the GMM components.
priors (Array[num_components]) – Mixture weights of the GMM components.
- Returns:
The denoised prediction x0 evaluated at x_t and t.
- Return type:
Array[data_dim]
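In the isotropic case the conditional expectation needs no matrix inverse, which makes it easy to sketch without any array library. The function below is a plain-Python illustration, not the library's code: the name iso_gmm_x0_sketch is hypothetical, alpha and sigma are the already-evaluated scalars α(t) and σ(t), and all priors are assumed strictly positive.

```python
import math
from typing import List, Sequence


def iso_gmm_x0_sketch(x_t: Sequence[float], means: Sequence[Sequence[float]],
                      variances: Sequence[float], priors: Sequence[float],
                      alpha: float, sigma: float) -> List[float]:
    """E[x_0 | x_t] when x_0 ~ sum_i priors[i] * N(means[i], variances[i] * I)."""
    dim = len(x_t)
    logits, cond_means = [], []
    for mu, var, pi in zip(means, variances, priors):
        s = alpha ** 2 * var + sigma ** 2               # per-coordinate variance of x_t in this component
        resid = [xt - alpha * m for xt, m in zip(x_t, mu)]
        sq = sum(r * r for r in resid)
        # Log of priors[i] * N(x_t; alpha * mu, s * I), dropping the shared 2*pi constant.
        logits.append(math.log(pi) - 0.5 * dim * math.log(s) - 0.5 * sq / s)
        gain = alpha * var / s                          # scalar shrinkage factor
        cond_means.append([m + gain * r for m, r in zip(mu, resid)])
    mx = max(logits)                                    # stabilize the softmax
    w = [math.exp(l - mx) for l in logits]
    z = sum(w)
    return [sum(wi * cm[d] for wi, cm in zip(w, cond_means)) / z for d in range(dim)]


# One standard-normal component with alpha^2 + sigma^2 = 1: the denoiser is alpha * x_t.
single = iso_gmm_x0_sketch([2.0, -1.0], [[0.0, 0.0]], [1.0], [1.0], alpha=0.6, sigma=0.8)
# Two mirrored components with equal priors: the denoiser at the origin is the origin.
mirrored = iso_gmm_x0_sketch([0.0], [[-2.0], [2.0]], [0.5, 0.5], [0.5, 0.5], alpha=0.6, sigma=0.8)
```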
diffusionlab.distributions.gmm.iso_hom_gmm
- class diffusionlab.distributions.gmm.iso_hom_gmm.IsoHomGMM(means: Array, variance: Array, priors: Array)
Bases: Distribution
Implements an isotropic homoscedastic Gaussian Mixture Model (GMM) distribution.
The probability measure is given by:
μ(A) = sum_{i=1}^{num_components} priors[i] * N(A; means[i], variance * I)
This class provides methods for sampling from the isotropic homoscedastic GMM and computing various vector fields (VectorFieldType.SCORE, VectorFieldType.X0, VectorFieldType.EPS, VectorFieldType.V) related to the distribution under a given diffusion process.
- dist_params
Dictionary containing the core GMM parameters.
means (Array[num_components, data_dim]): The means of the GMM components.
variance (Array[]): The variance shared by all GMM components.
priors (Array[num_components]): The prior probabilities (mixture weights) of the GMM components.
- Type:
Dict[str, Array]
- dist_hparams¶
Dictionary for storing hyperparameters (currently unused).
- Type:
Dict[str, Any]
- eps(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Computes the noise prediction ε for the isotropic homoscedastic GMM distribution.
This predicts the noise that was added to the original sample x_0 to obtain x_t at time t under the diffusion_process.
- Parameters:
x_t (Array[data_dim]) – The noisy state tensor at time t.
t (Array[]) – The time step (scalar).
diffusion_process (DiffusionProcess) – The diffusion process definition.
- Returns:
The noise prediction vector field ε evaluated at x_t and t.
- Return type:
Array[data_dim]
- sample(key: Array, num_samples: int) → Tuple[Array, Array]
Draws samples from the isotropic homoscedastic GMM distribution.
- Parameters:
key (Array) – JAX PRNG key for random sampling.
num_samples (int) – The total number of samples to generate.
- Returns:
A tuple (samples, component_indices) containing the drawn samples and the index of the GMM component from which each sample was drawn.
- Return type:
Tuple[Array[num_samples, data_dim], Array[num_samples]]
- score(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Computes the score vector field ∇_x log p_t(x_t) for the isotropic homoscedastic GMM distribution.
This is calculated with respect to the perturbed distribution p_t induced by the diffusion_process at time t.
- Parameters:
x_t (Array[data_dim]) – The noisy state tensor at time t.
t (Array[]) – The time step (scalar).
diffusion_process (DiffusionProcess) – The diffusion process definition.
- Returns:
The score vector field evaluated at x_t and t.
- Return type:
Array[data_dim]
- v(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Computes the velocity vector field v for the isotropic homoscedastic GMM distribution.
This is the conditional velocity E[dx_t/dt | x_t] under the diffusion_process.
- Parameters:
x_t (Array[data_dim]) – The noisy state tensor at time t.
t (Array[]) – The time step (scalar).
diffusion_process (DiffusionProcess) – The diffusion process definition.
- Returns:
The velocity vector field v evaluated at x_t and t.
- Return type:
Array[data_dim]
- x0(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Computes the denoised prediction x0 = E[x_0 | x_t] for the isotropic homoscedastic GMM distribution.
This represents the expected original sample x_0 given the noisy observation x_t at time t under the diffusion_process.
- Parameters:
x_t (Array[data_dim]) – The noisy state tensor at time t.
t (Array[]) – The time step (scalar).
diffusion_process (DiffusionProcess) – The diffusion process definition.
- Returns:
The denoised prediction vector field x0 evaluated at x_t and t.
- Return type:
Array[data_dim]
- diffusionlab.distributions.gmm.iso_hom_gmm.iso_hom_gmm_x0(x_t: Array, t: Array, diffusion_process: DiffusionProcess, means: Array, variance: Array, priors: Array) → Array
Computes the denoised prediction x0 = E[x_0 | x_t] for an isotropic homoscedastic GMM.
This implements the closed-form solution for the conditional expectation E[x_0 | x_t], where x_t | x_0 ~ N(α_t x_0, σ_t^2 I) and x_0 follows the isotropic homoscedastic GMM distribution defined by means, variance, and priors.
- Parameters:
x_t (Array[data_dim]) – The noisy state tensor at time t.
t (Array[]) – The time step (scalar).
diffusion_process (DiffusionProcess) – Provides α(t) and σ(t).
means (Array[num_components, data_dim]) – Means of the GMM components.
variance (Array[]) – Variance shared by all GMM components.
priors (Array[num_components]) – Mixture weights of the GMM components.
- Returns:
The denoised prediction x0 evaluated at x_t and t.
- Return type:
Array[data_dim]
diffusionlab.distributions.gmm.low_rank_gmm
- class diffusionlab.distributions.gmm.low_rank_gmm.LowRankGMM(means: Array, cov_factors: Array, priors: Array)
Bases: Distribution
Implements a low-rank Gaussian Mixture Model (GMM) distribution.
The probability measure is given by:
μ(A) = sum_{i=1}^{num_components} priors[i] * N(A; means[i], cov_factors[i] @ cov_factors[i].T)
This class provides methods for sampling from the low-rank GMM and computing various vector fields (VectorFieldType.SCORE, VectorFieldType.X0, VectorFieldType.EPS, VectorFieldType.V) related to the distribution under a given diffusion process.
- dist_params
Dictionary containing the core low-rank GMM parameters.
means (Array[num_components, data_dim]): The means of the GMM components.
cov_factors (Array[num_components, data_dim, rank]): The low-rank covariance matrix factors of the GMM components.
priors (Array[num_components]): The prior probabilities (mixture weights) of the GMM components.
- Type:
Dict[str, Array]
- dist_hparams¶
Dictionary for storing hyperparameters (currently unused).
- Type:
Dict[str, Any]
- eps(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Computes the noise prediction ε for the low-rank GMM distribution.
This predicts the noise that was added to the original sample x_0 to obtain x_t at time t under the diffusion_process.
- Parameters:
x_t (Array[data_dim]) – The noisy state tensor at time t.
t (Array[]) – The time step (scalar).
diffusion_process (DiffusionProcess) – The diffusion process definition.
- Returns:
The noise prediction vector field ε evaluated at x_t and t.
- Return type:
Array[data_dim]
- sample(key: Array, num_samples: int) → Tuple[Array, Array]
Draws samples from the low-rank GMM distribution.
- Parameters:
key (Array) – JAX PRNG key for random sampling.
num_samples (int) – The total number of samples to generate.
- Returns:
A tuple (samples, component_indices) containing the drawn samples and the index of the GMM component from which each sample was drawn.
- Return type:
Tuple[Array[num_samples, data_dim], Array[num_samples]]
- score(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Computes the score vector field ∇_x log p_t(x_t) for the low-rank GMM distribution.
This is calculated with respect to the perturbed distribution p_t induced by the diffusion_process at time t.
- Parameters:
x_t (Array[data_dim]) – The noisy state tensor at time t.
t (Array[]) – The time step (scalar).
diffusion_process (DiffusionProcess) – The diffusion process definition.
- Returns:
The score vector field evaluated at x_t and t.
- Return type:
Array[data_dim]
- v(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Computes the velocity vector field v for the low-rank GMM distribution.
This is the conditional velocity E[dx_t/dt | x_t] under the diffusion_process.
- Parameters:
x_t (Array[data_dim]) – The noisy state tensor at time t.
t (Array[]) – The time step (scalar).
diffusion_process (DiffusionProcess) – The diffusion process definition.
- Returns:
The velocity vector field v evaluated at x_t and t.
- Return type:
Array[data_dim]
- x0(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Computes the denoised prediction x0 = E[x_0 | x_t] for the low-rank GMM distribution.
This represents the expected original sample x_0 given the noisy observation x_t at time t under the diffusion_process.
- Parameters:
x_t (Array[data_dim]) – The noisy state tensor at time t.
t (Array[]) – The time step (scalar).
diffusion_process (DiffusionProcess) – The diffusion process definition.
- Returns:
The denoised prediction vector field x0 evaluated at x_t and t.
- Return type:
Array[data_dim]
- diffusionlab.distributions.gmm.low_rank_gmm.low_rank_gmm_x0(x_t: Array, t: Array, diffusion_process: DiffusionProcess, means: Array, cov_factors: Array, priors: Array) → Array
Computes the denoised prediction x0 = E[x_0 | x_t] for a low-rank GMM.
This implements the closed-form solution for the conditional expectation E[x_0 | x_t], where x_t | x_0 ~ N(α_t x_0, σ_t^2 I) and x_0 follows the low-rank GMM distribution defined by means, cov_factors, and priors.
- Parameters:
x_t (Array[data_dim]) – The noisy state tensor at time t.
t (Array[]) – The time step (scalar).
diffusion_process (DiffusionProcess) – Provides α(t) and σ(t).
means (Array[num_components, data_dim]) – Means of the GMM components.
cov_factors (Array[num_components, data_dim, rank]) – Low-rank covariance factors of the GMM components; the covariance of component i is cov_factors[i] @ cov_factors[i].T.
priors (Array[num_components]) – Mixture weights of the GMM components.
- Returns:
The denoised prediction x0 evaluated at x_t and t.
- Return type:
Array[data_dim]
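A note on why the low-rank parameterization helps: with component covariance A_i A_i^T (A_i = cov_factors[i], of shape data_dim × rank), the matrix that must be inverted in the Gaussian conditioning step is α_t² A_i A_i^T + σ_t² I. The Woodbury identity reduces this to a rank × rank inverse. Whether the library uses exactly this route is an assumption; the identity itself is standard and holds for σ_t > 0.

```latex
\left(\alpha_t^2 A_i A_i^{\top} + \sigma_t^2 I_d\right)^{-1}
  = \frac{1}{\sigma_t^2}\left[\, I_d - \alpha_t^2 A_i
    \left(\sigma_t^2 I_r + \alpha_t^2 A_i^{\top} A_i\right)^{-1} A_i^{\top} \right]
```

Here d stands for data_dim and r for rank, so the cost of the inverse drops from cubic in d to cubic in r.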
diffusionlab.distributions.gmm.utils
- diffusionlab.distributions.gmm.utils.create_gmm_vector_field_fns(x0_fn: Callable[[Array, Array, DiffusionProcess, Array, Array, Array], Array]) → Tuple[Callable[[Array, Array, DiffusionProcess, Array, Array, Array], Array], Callable[[Array, Array, DiffusionProcess, Array, Array, Array], Array], Callable[[Array, Array, DiffusionProcess, Array, Array, Array], Array]]
Factory to create eps, score, and v functions from a given x0 function.
- Parameters:
x0_fn – The specific x0 calculation function (e.g., gmm_x0, iso_gmm_x0). It must accept (x_t, t, diffusion_process, means, specific_cov, priors), where specific_cov is the covariance parameter of the GMM variant in question (covs, variances, variance, or cov_factors).
- Returns:
A tuple containing the generated (eps_fn, score_fn, v_fn). These functions have the same signature as x0_fn, accepting (x_t, t, diffusion_process, means, specific_cov, priors).
- Return type:
Tuple[Callable[[Array, Array, DiffusionProcess, Array, Array, Array], Array], Callable[[Array, Array, DiffusionProcess, Array, Array, Array], Array], Callable[[Array, Array, DiffusionProcess, Array, Array, Array], Array]]
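To make the idea of such a factory concrete, here is a scalar, plain-Python sketch that derives eps, score, and v from an x0 function using the standard identities implied by x_t = α x_0 + σ ε: eps = (x_t − α x0) / σ, score = −eps / σ, and v = α' x0 + σ' eps. Everything named here (ToyProcess, make_vector_field_fns, single_gaussian_x0) is hypothetical and only mimics the documented signature; the library's factory operates on JAX arrays and its internals are not shown in this reference.

```python
import math
from typing import Callable, Tuple


class ToyProcess:
    """Hypothetical stand-in exposing alpha, sigma and their time derivatives."""

    def __init__(self, alpha, sigma, alpha_prime, sigma_prime):
        self.alpha, self.sigma = alpha, sigma
        self.alpha_prime, self.sigma_prime = alpha_prime, sigma_prime


def make_vector_field_fns(x0_fn: Callable[..., float]) -> Tuple[Callable, Callable, Callable]:
    """Build (eps_fn, score_fn, v_fn) sharing the signature of x0_fn."""

    def eps_fn(x_t, t, process, means, specific_cov, priors):
        x0 = x0_fn(x_t, t, process, means, specific_cov, priors)
        return (x_t - process.alpha(t) * x0) / process.sigma(t)

    def score_fn(x_t, t, process, means, specific_cov, priors):
        return -eps_fn(x_t, t, process, means, specific_cov, priors) / process.sigma(t)

    def v_fn(x_t, t, process, means, specific_cov, priors):
        x0 = x0_fn(x_t, t, process, means, specific_cov, priors)
        eps = (x_t - process.alpha(t) * x0) / process.sigma(t)
        return process.alpha_prime(t) * x0 + process.sigma_prime(t) * eps

    return eps_fn, score_fn, v_fn


def single_gaussian_x0(x_t, t, process, means, specific_cov, priors):
    """Denoiser for one scalar Gaussian N(means, specific_cov); priors is unused."""
    a, s = process.alpha(t), process.sigma(t)
    return means + a * specific_cov / (a * a * specific_cov + s * s) * (x_t - a * means)


# Variance-preserving schedule as documented: alpha = sqrt(1 - t^2), sigma = t.
vp = ToyProcess(
    alpha=lambda t: math.sqrt(1.0 - t * t),
    sigma=lambda t: t,
    alpha_prime=lambda t: -t / math.sqrt(1.0 - t * t),
    sigma_prime=lambda t: 1.0,
)
eps_fn, score_fn, v_fn = make_vector_field_fns(single_gaussian_x0)
args = (1.5, 0.6, vp, 0.0, 1.0, 1.0)   # x_t, t, process, mean, variance, prior
score_val = score_fn(*args)
v_val = v_fn(*args)
```

With standard-normal data under this schedule, the noised distribution stays standard normal for every t, so the score at x_t is −x_t and the velocity is zero, which makes the example easy to check by hand.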
diffusionlab.dynamics
- class diffusionlab.dynamics.DiffusionProcess(alpha: Callable[[Array], Array], sigma: Callable[[Array], Array])
Bases: object
Base class for implementing various diffusion processes.
A diffusion process defines how data evolves over time when noise is added according to specific dynamics operating on scalar time inputs. This class provides a framework to implement diffusion processes based on a schedule defined by α(t) and σ(t).
The diffusion is parameterized by two scalar functions of scalar time t:
α(t): Controls how much of the original signal is preserved at time t.
σ(t): Controls how much noise is added at time t.
The forward process for a single data point x_0 is defined as:
x_t = α(t) * x_0 + σ(t) * ε
where:
x_0 is the original data (Array[*data_dims])
x_t is the noised data at time t (Array[*data_dims])
ε is random noise sampled from a standard Gaussian distribution (Array[*data_dims])
t is the scalar diffusion time parameter (Array[])
- alpha
Function mapping scalar time t -> scalar signal coefficient α(t).
- Type:
Callable[[Array[]], Array[]]
- sigma
Function mapping scalar time t -> scalar noise coefficient σ(t).
- Type:
Callable[[Array[]], Array[]]
- alpha_prime
Derivative of α w.r.t. scalar time t.
- Type:
Callable[[Array[]], Array[]]
- sigma_prime
Derivative of σ w.r.t. scalar time t.
- Type:
Callable[[Array[]], Array[]]
- forward(x: Array, t: Array, eps: Array) → Array
Applies the forward diffusion process to a data tensor x at time t using noise ε.
Computes x_t = α(t) * x + σ(t) * ε.
- Parameters:
x (Array[*data_dims]) – The input data tensor x_0.
t (Array[]) – The scalar time parameter t.
eps (Array[*data_dims]) – The Gaussian noise tensor ε, matching the shape of x.
- Returns:
The noised data tensor x_t at time t.
- Return type:
Array[*data_dims]
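The forward rule x_t = α(t) * x + σ(t) * ε is simple enough to sketch without any array library. The snippet below is an illustration in plain Python on lists of floats; the free function forward and the schedule names fm_alpha and fm_sigma are hypothetical and are not the library's method, which operates on JAX arrays. The schedule used is the linear one, α(t) = 1 − t and σ(t) = t.

```python
from typing import Callable, List, Sequence


def forward(x_0: Sequence[float], t: float, eps: Sequence[float],
            alpha: Callable[[float], float],
            sigma: Callable[[float], float]) -> List[float]:
    """Elementwise x_t = alpha(t) * x_0 + sigma(t) * eps."""
    a, s = alpha(t), sigma(t)
    return [a * x + s * e for x, e in zip(x_0, eps)]


# Linear schedule: pure data at t = 0, pure noise at t = 1.
fm_alpha = lambda t: 1.0 - t
fm_sigma = lambda t: t

x_0 = [2.0, -1.0]
eps = [0.5, 0.5]        # a fixed "noise" vector, chosen by hand for reproducibility
at_start = forward(x_0, 0.0, eps, fm_alpha, fm_sigma)
halfway = forward(x_0, 0.5, eps, fm_alpha, fm_sigma)
at_end = forward(x_0, 1.0, eps, fm_alpha, fm_sigma)
```

At t = 0 the output equals x_0, at t = 1 it equals eps, and in between it is the straight-line interpolation of the two.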
- class diffusionlab.dynamics.FlowMatchingProcess
Bases: DiffusionProcess
Implements a diffusion process based on Flow Matching principles.
This process defines dynamics that linearly interpolate between the data distribution at t=0 and a noise distribution (standard Gaussian) at t=1.
Uses the following scalar dynamics:
α(t) = 1 - t
σ(t) = t
Forward process: x_t = (1 - t) * x_0 + t * ε.
- alpha
Function mapping scalar time t -> scalar signal coefficient α(t). Set to 1 - t.
- Type:
Callable[[Array[]], Array[]]
- sigma
Function mapping scalar time t -> scalar noise coefficient σ(t). Set to t.
- Type:
Callable[[Array[]], Array[]]
- alpha_prime
Derivative of α w.r.t. scalar time t. Set to -1.
- Type:
Callable[[Array[]], Array[]]
- sigma_prime
Derivative of σ w.r.t. scalar time t. Set to 1.
- Type:
Callable[[Array[]], Array[]]
- class diffusionlab.dynamics.VarianceExplodingProcess(sigma: Callable[[Array], Array])
Bases: DiffusionProcess
Implements a Variance Exploding (VE) diffusion process.
In this process, the signal component is constant (α(t) = 1), while the noise component increases over time according to the provided σ(t) function. The variance of the noised data x_t explodes as t increases.
Forward process: x_t = x_0 + σ(t) * ε.
This process uses:
α(t) = 1
σ(t) = Provided by the user
- alpha
Function mapping scalar time t -> scalar signal coefficient α(t). Set to 1.
- Type:
Callable[[Array[]], Array[]]
- sigma
Function mapping scalar time t -> scalar noise coefficient σ(t). Provided by the user.
- Type:
Callable[[Array[]], Array[]]
- alpha_prime
Derivative of α w.r.t. scalar time t. Set to 0.
- Type:
Callable[[Array[]], Array[]]
- sigma_prime
Derivative of σ w.r.t. scalar time t.
- Type:
Callable[[Array[]], Array[]]
- class diffusionlab.dynamics.VariancePreservingProcess
Bases: DiffusionProcess
Implements a Variance Preserving (VP) diffusion process, often used in DDPMs.
This process keeps the variance of the noised data x_t equal to 1 throughout the diffusion (assuming x_0 and ε are independent with unit variance), because the signal and noise coefficients satisfy α(t)² + σ(t)² = 1.
Uses the following scalar dynamics:
α(t) = sqrt(1 - t²)
σ(t) = t
Forward process: x_t = sqrt(1 - t²) * x_0 + t * ε.
- alpha
Function mapping scalar time t -> scalar signal coefficient α(t). Set to sqrt(1 - t²).
- Type:
Callable[[Array[]], Array[]]
- sigma
Function mapping scalar time t -> scalar noise coefficient σ(t). Set to t.
- Type:
Callable[[Array[]], Array[]]
- alpha_prime
Derivative of α w.r.t. scalar time t. Set to -t / sqrt(1 - t²).
- Type:
Callable[[Array[]], Array[]]
- sigma_prime
Derivative of σ w.r.t. scalar time t. Set to 1.
- Type:
Callable[[Array[]], Array[]]
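The two facts stated for the variance-preserving schedule, namely that α(t)² + σ(t)² = 1 and that α'(t) = −t / sqrt(1 − t²), can be checked numerically with a few lines of plain Python. This is only a sanity-check sketch: the helper central_diff and the chosen time points are assumptions, and the library presumably obtains derivatives differently.

```python
import math


def alpha(t: float) -> float:
    return math.sqrt(1.0 - t * t)


def sigma(t: float) -> float:
    return t


def alpha_prime(t: float) -> float:
    return -t / math.sqrt(1.0 - t * t)


def central_diff(f, t: float, h: float = 1e-6) -> float:
    """Second-order finite-difference approximation of f'(t)."""
    return (f(t + h) - f(t - h)) / (2.0 * h)


ts = [0.1, 0.5, 0.9]   # interior points; alpha_prime is singular at t = 1
var_sum = [alpha(t) ** 2 + sigma(t) ** 2 for t in ts]
deriv_err = [abs(central_diff(alpha, t) - alpha_prime(t)) for t in ts]
```

Note that the derivative formula divides by sqrt(1 − t²), so it is only finite for t strictly below 1.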
diffusionlab.losses
- class diffusionlab.losses.DiffusionLoss(diffusion_process: DiffusionProcess, vector_field_type: VectorFieldType, num_noise_draws_per_sample: int)
Bases: object
Loss function for training diffusion models.
This dataclass implements loss functions for diffusion models based on the specified vector field type. The loss is computed as the mean squared error between the model’s prediction and the target, which depends on the chosen vector field type.
The loss supports different target types:
VectorFieldType.X0: Learn to predict the original clean data x_0
VectorFieldType.EPS: Learn to predict the noise component eps
VectorFieldType.V: Learn to predict the velocity field v
VectorFieldType.SCORE: Not directly supported (raises ValueError)
- diffusion_process¶
The diffusion process defining the forward dynamics
- Type:
DiffusionProcess
- vector_field_type¶
The type of target to learn to estimate via minimizing the loss function.
- Type:
VectorFieldType
- num_noise_draws_per_sample¶
The number of noise draws per sample to use for the batchwise loss.
- Type:
int
- target¶
Function that computes the target based on the specified vector_field_type.
Signature:
(x_t: Array[*data_dims], f_x_t: Array[*data_dims], x_0: Array[*data_dims], eps: Array[*data_dims], t: Array[]) -> Array[*data_dims]
- Type:
Callable[[Array, Array, Array, Array, Array], Array]
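As an illustration of how a target function with this signature might be selected, here is a minimal sketch. `FieldType` and `make_target` are hypothetical stand-ins, not the library's own names, and the velocity target assumes that v is the time derivative of x_t = α(t) x_0 + σ(t) ε, which this page does not state explicitly.

```python
from enum import Enum

import jax.numpy as jnp

class FieldType(Enum):
    """Hypothetical stand-in for VectorFieldType."""
    SCORE = "score"
    X0 = "x0"
    EPS = "eps"
    V = "v"

def make_target(field_type, alpha_prime, sigma_prime):
    """Return a target function with the documented five-argument signature."""
    if field_type is FieldType.X0:
        return lambda x_t, f_x_t, x_0, eps, t: x_0
    if field_type is FieldType.EPS:
        return lambda x_t, f_x_t, x_0, eps, t: eps
    if field_type is FieldType.V:
        # Assumption: v = alpha'(t) * x_0 + sigma'(t) * eps.
        return lambda x_t, f_x_t, x_0, eps, t: alpha_prime(t) * x_0 + sigma_prime(t) * eps
    # The documentation says the SCORE target raises ValueError.
    raise ValueError(f"Unsupported target type: {field_type}")

# VP derivatives from the dynamics section.
vp_alpha_prime = lambda t: -t / jnp.sqrt(1.0 - t**2)
vp_sigma_prime = lambda t: jnp.ones_like(t)

x_0 = jnp.array([1.0, 2.0])
eps = jnp.array([0.5, -0.5])
t = jnp.asarray(0.6)
x_t = jnp.sqrt(1.0 - t**2) * x_0 + t * eps

v_target_fn = make_target(FieldType.V, vp_alpha_prime, vp_sigma_prime)
v_target = v_target_fn(x_t, None, x_0, eps, t)
```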
- loss(key: Array, vector_field: Callable[[Array, Array], Array], x_0: Array, t: Array) Array[source]¶
Compute the average loss over multiple noise draws for a single data point and time.
This method estimates the expected loss at a given time t for a clean data sample x_0. It does this by drawing num_noise_draws_per_sample noise vectors (eps), generating the corresponding noisy samples x_t using the diffusion_process, predicting the target quantity f_x_t using the provided vector_field (vmapped internally), and then calculating the prediction_loss for each noise sample. The final loss is the average over these samples.
- Parameters:
key (Array) – The PRNG key for noise generation.
vector_field (Callable[[Array, Array], Array]) – The vector field function that takes a single noisy data sample x_t and its corresponding time t, and returns the model’s prediction f_x_t. This function will be vmapped internally over the batch dimension created by num_noise_draws_per_sample. Signature: (x_t: Array[*data_dims], t: Array[]) -> Array[*data_dims].
x_0 (Array[*data_dims]) – The original clean data sample.
t (Array[]) – The scalar time parameter.
- Returns:
The scalar loss value, averaged over num_noise_draws_per_sample noise instances.
- Return type:
Array[]
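The procedure described for loss can be sketched as follows. This is an illustrative re-implementation under stated assumptions, not the library's code: it hard-codes the VP coefficients and the X0 target, `mc_loss` is a hypothetical name, and taking the mean over data dimensions (rather than the sum) is an assumption about the reduction.

```python
import jax
import jax.numpy as jnp

def alpha(t):
    return jnp.sqrt(1.0 - t**2)

def sigma(t):
    return t

def mc_loss(key, vector_field, x_0, t, num_noise_draws_per_sample):
    """Average an x_0-prediction MSE over several noise draws (hypothetical helper)."""
    # One noise vector per draw, each shaped like x_0.
    eps = jax.random.normal(key, (num_noise_draws_per_sample,) + x_0.shape)
    # Noisy samples for every draw; x_0 is broadcast along the draw axis.
    x_t = alpha(t) * x_0[None] + sigma(t) * eps
    # Map the single-sample vector field over the draw axis; t is shared.
    f_x_t = jax.vmap(vector_field, in_axes=(0, None))(x_t, t)
    # Squared error per draw, averaged over data dims (assumed reduction).
    data_axes = tuple(range(1, x_t.ndim))
    per_draw = jnp.mean((f_x_t - x_0[None]) ** 2, axis=data_axes)
    # Average over the noise draws.
    return jnp.mean(per_draw)

key = jax.random.PRNGKey(0)
x_0 = jnp.array([1.0, -2.0, 0.5])
t = jnp.asarray(0.5)

# A predictor that always outputs zeros, and an oracle that returns x_0.
zero_loss = mc_loss(key, lambda x_t, t: jnp.zeros_like(x_t), x_0, t, 8)
oracle_loss = mc_loss(key, lambda x_t, t: jnp.zeros_like(x_t) + x_0, x_0, t, 8)
```

With the zero predictor the loss reduces to the mean of x_0 squared regardless of the noise, and with the oracle it is zero, which makes both cases easy to verify by hand.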
- prediction_loss(x_t: Array, f_x_t: Array, x_0: Array, eps: Array, t: Array) Array[source]¶
Compute the loss given a prediction and inputs/targets.
This method calculates the mean squared error between the model’s prediction (f_x_t) and the target value determined by vector_field_type (self.target).
- Parameters:
x_t (Array[*data_dims]) – The noised data at time t.
f_x_t (Array[*data_dims]) – The model’s prediction at time t.
x_0 (Array[*data_dims]) – The original clean data.
eps (Array[*data_dims]) – The noise used to generate x_t.
t (Array[]) – The scalar time parameter.
- Returns:
The scalar loss value for the given sample.
- Return type:
Array[]
diffusionlab.samplers¶
- class diffusionlab.samplers.DDMSampler(diffusion_process: DiffusionProcess, vector_field: Callable[[Array, Array], Array], vector_field_type: VectorFieldType, use_stochastic_sampler: bool)[source]¶
Bases: Sampler
Class for sampling from diffusion models using the Denoising Diffusion Probabilistic Models (DDPM) or Denoising Diffusion Implicit Models (DDIM) sampling strategy.
This sampler first converts any given vector field type (VectorFieldType.SCORE, VectorFieldType.X0, VectorFieldType.EPS, VectorFieldType.V) provided by vector_field into an equivalent x0 prediction using the convert_vector_field_type utility. Then, it applies the DDPM (if use_stochastic_sampler is True) or DDIM (if use_stochastic_sampler is False) update rule based on this x0 prediction.
- diffusion_process¶
The diffusion process defining the forward dynamics.
- Type:
DiffusionProcess
- vector_field¶
The function predicting the vector field.
- Type:
Callable[[Array[*data_dims], Array[]], Array[*data_dims]]
- vector_field_type¶
The type of the vector field predicted by vector_field.
- Type:
VectorFieldType
- use_stochastic_sampler¶
If True, uses DDPM (stochastic); otherwise, uses DDIM (deterministic).
- Type:
bool
- sample_step¶
The DDPM or DDIM step function.
- Type:
Callable[[int, Array, Array, Array], Array]
- get_sample_step_function() Callable[[int, Array, Array, Array], Array][source]¶
Get the appropriate DDPM/DDIM sampling step function based on stochasticity.
- Returns:
The DDPM (stochastic) or DDIM (deterministic) step function, which has signature:
(idx: int, x: Array[*data_dims], zs: Array[num_steps, *data_dims], ts: Array[num_steps+1]) -> Array[*data_dims]
- Return type:
Callable[[int, Array, Array, Array], Array]
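For intuition, the deterministic (DDIM) update from an x0 prediction can be sketched as below. This uses the standard DDIM rule with the VP coefficients and is only an assumption about what the library's step computes; `make_ddim_step` and `x0_predictor` are hypothetical names, and the stochastic DDPM variant is not shown.

```python
import jax.numpy as jnp

def alpha(t):
    return jnp.sqrt(1.0 - t**2)

def sigma(t):
    return t

def make_ddim_step(x0_predictor):
    """Build a step function with the documented (idx, x, zs, ts) signature."""
    def step(idx, x, zs, ts):
        t, s = ts[idx], ts[idx + 1]          # current time and next (smaller) time
        x0_hat = x0_predictor(x, t)          # predicted clean sample
        # Noise implied by the current state and the x0 prediction.
        eps_hat = (x - alpha(t) * x0_hat) / sigma(t)
        # Re-noise the prediction to the next time level; zs is unused here.
        return alpha(s) * x0_hat + sigma(s) * eps_hat
    return step

x_0 = jnp.array([1.0, -1.0])
eps = jnp.array([0.3, 2.0])
ts = jnp.array([0.8, 0.6])
zs = jnp.zeros((1, 2))

x_t = alpha(ts[0]) * x_0 + sigma(ts[0]) * eps
# With an oracle predictor, one step should land exactly on the t = 0.6 marginal sample.
step = make_ddim_step(lambda x, t: x_0)
x_s = step(0, x_t, zs, ts)
expected = alpha(ts[1]) * x_0 + sigma(ts[1]) * eps
```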
- class diffusionlab.samplers.EulerMaruyamaSampler(diffusion_process: DiffusionProcess, vector_field: Callable[[Array, Array], Array], vector_field_type: VectorFieldType, use_stochastic_sampler: bool)[source]¶
Bases: Sampler
Class for sampling from diffusion models using the first-order Euler-Maruyama integrator for the reverse process SDE/ODE.
This sampler implements the step function based on the Euler-Maruyama discretization of the reverse SDE (if use_stochastic_sampler is True) or the corresponding probability flow ODE (if use_stochastic_sampler is False). It supports all vector field types (VectorFieldType.SCORE, VectorFieldType.X0, VectorFieldType.EPS, VectorFieldType.V).
- diffusion_process¶
The diffusion process defining the forward dynamics.
- Type:
DiffusionProcess
- vector_field¶
The function predicting the vector field. Takes the current state x_t and time t as input.
- Type:
Callable[[Array[*data_dims], Array[]], Array[*data_dims]]
- vector_field_type¶
The type of the vector field predicted by vector_field.
- Type:
VectorFieldType
- use_stochastic_sampler¶
Whether to use a stochastic or deterministic reverse process.
- Type:
bool
- sample_step¶
The specific function used to perform one sampling step. Takes step index idx, current state x_t, noise array zs, and time schedule ts as input. Set during initialization based on the sampler type and use_stochastic_sampler.
- Type:
Callable[[int, Array, Array, Array], Array]
- get_sample_step_function() Callable[[int, Array, Array, Array], Array][source]¶
Get the appropriate Euler-Maruyama sampling step function based on the vector field type and stochasticity.
- Returns:
The specific Euler-Maruyama step function to use.
Signature:
(idx: int, x_t: Array[*data_dims], zs: Array[num_steps, *data_dims], ts: Array[num_steps+1]) -> Array[*data_dims]
- Return type:
Callable[[int, Array, Array, Array], Array]
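A score-based Euler-Maruyama step can be sketched as follows. It assumes the usual forward SDE for x_t = α(t) x_0 + σ(t) ε, with drift coefficient f = α'/α and squared diffusion coefficient g² = 2σσ' - 2fσ². The library's own parametrization may differ, and `make_euler_step` and `score_fn` are hypothetical names.

```python
import jax.numpy as jnp

# VP coefficients and derivatives from the dynamics section.
def alpha(t):
    return jnp.sqrt(1.0 - t**2)

def sigma(t):
    return t

def alpha_prime(t):
    return -t / jnp.sqrt(1.0 - t**2)

def sigma_prime(t):
    return jnp.ones_like(t)

def make_euler_step(score_fn, stochastic):
    """Build a step with the documented (idx, x_t, zs, ts) signature."""
    def step(idx, x_t, zs, ts):
        t, s = ts[idx], ts[idx + 1]
        dt = s - t                                   # negative, since time decreases
        f = alpha_prime(t) / alpha(t)                # drift coefficient
        g2 = 2.0 * sigma(t) * sigma_prime(t) - 2.0 * f * sigma(t) ** 2
        score = score_fn(x_t, t)
        if stochastic:
            # Reverse SDE: drift f*x - g^2*score, plus noise scaled by g*sqrt(|dt|).
            drift = f * x_t - g2 * score
            return x_t + dt * drift + jnp.sqrt(g2 * (-dt)) * zs[idx]
        # Probability flow ODE: drift f*x - 0.5*g^2*score, no noise.
        drift = f * x_t - 0.5 * g2 * score
        return x_t + dt * drift
    return step

# If the data is standard normal, every VP marginal is standard normal and the score is -x.
score_fn = lambda x, t: -x
ts = jnp.array([0.8, 0.6, 0.4])
zs = jnp.ones((2, 2))
x = jnp.array([1.0, -2.0])

x_ode = make_euler_step(score_fn, stochastic=False)(0, x, zs, ts)
x_sde = make_euler_step(score_fn, stochastic=True)(0, x, zs, ts)
```

In the standard normal case the ODE drift cancels exactly, so the deterministic step should leave x unchanged up to floating-point error.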
- class diffusionlab.samplers.Sampler(diffusion_process: DiffusionProcess, vector_field: Callable[[Array, Array], Array], vector_field_type: VectorFieldType, use_stochastic_sampler: bool)[source]¶
Bases: object
Base class for sampling from diffusion models using various vector field types.
A Sampler combines a diffusion process, a vector field prediction function, and a time schedule (such as one produced by a Scheduler) to generate samples from a trained diffusion model using the reverse process (denoising/sampling).
The sampler supports different vector field types (VectorFieldType.SCORE, VectorFieldType.X0, VectorFieldType.EPS, VectorFieldType.V) and can perform both stochastic and deterministic sampling based on the subclass implementation and the use_stochastic_sampler flag.
- diffusion_process¶
The diffusion process defining the forward dynamics.
- Type:
DiffusionProcess
- vector_field¶
The function predicting the vector field. Takes the current state x_t and time t as input.
- Type:
Callable[[Array[*data_dims], Array[]], Array[*data_dims]]
- vector_field_type¶
The type of the vector field predicted by vector_field.
- Type:
VectorFieldType
- use_stochastic_sampler¶
Whether to use a stochastic or deterministic reverse process.
- Type:
bool
- sample_step¶
The specific function used to perform one sampling step. Takes step index idx, current state x_t, noise array zs, and time schedule ts as input. Set during initialization based on the sampler type and use_stochastic_sampler.
- Type:
Callable[[int, Array, Array, Array], Array]
- get_sample_step_function() Callable[[int, Array, Array, Array], Array][source]¶
Abstract method to get the appropriate sampling step function.
Subclasses must implement this method to return the specific function used for performing one step of the reverse process, based on the sampler’s implementation details (e.g., integrator type) and the use_stochastic_sampler flag.
- Returns:
The sampling step function, which has signature:
(idx: int, x_t: Array[*data_dims], zs: Array[num_steps, *data_dims], ts: Array[num_steps+1]) -> Array[*data_dims]
- Return type:
Callable[[int, Array, Array, Array], Array]
- sample(x_init: Array, zs: Array, ts: Array) Array[source]¶
Sample from the model using the reverse diffusion process.
This method generates a final sample by iteratively applying the sample_step function, starting from an initial state x_init and using the provided noise zs and time schedule ts.
- Parameters:
x_init (Array[*data_dims]) – The initial noisy tensor from which to initialize sampling (typically sampled from the prior distribution at ts[0]).
zs (Array[num_steps, *data_dims]) – The noise tensors used at each step for stochastic sampling. Unused for deterministic samplers.
ts (Array[num_steps+1]) – The time schedule for sampling. A sorted decreasing array of times from t_max to t_min.
- Returns:
The generated sample at the final time ts[-1].
- Return type:
Array[*data_dims]
- sample_trajectory(x_init: Array, zs: Array, ts: Array) Array[source]¶
Sample a trajectory from the model using the reverse diffusion process.
This method generates the entire trajectory of intermediate samples by iteratively applying the sample_step function.
- Parameters:
x_init (Array[*data_dims]) – The initial noisy tensor from which to start sampling (at time ts[0]).
zs (Array[num_steps, *data_dims]) – The noise tensors used at each step for stochastic sampling. Unused for deterministic samplers.
ts (Array[num_steps+1]) – The time schedule for sampling. A sorted decreasing array of times from t_max to t_min.
- Returns:
The complete generated trajectory including the initial state x_init.
- Return type:
Array[num_steps+1, *data_dims]
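The iteration that sample and sample_trajectory describe can be sketched with JAX control-flow primitives. Whether the library itself uses `jax.lax.fori_loop` or `jax.lax.scan` is an assumption; the free functions and the toy step below are hypothetical and only mirror the documented shapes.

```python
import jax
import jax.numpy as jnp

def sample(sample_step, x_init, zs, ts):
    """Apply sample_step num_steps times and return only the final state."""
    num_steps = zs.shape[0]
    return jax.lax.fori_loop(
        0, num_steps, lambda idx, x: sample_step(idx, x, zs, ts), x_init
    )

def sample_trajectory(sample_step, x_init, zs, ts):
    """Apply sample_step num_steps times and return all states, x_init included."""
    num_steps = zs.shape[0]

    def body(x, idx):
        x_next = sample_step(idx, x, zs, ts)
        return x_next, x_next            # carry the state and also record it

    _, xs = jax.lax.scan(body, x_init, jnp.arange(num_steps))
    return jnp.concatenate([x_init[None], xs], axis=0)

# Toy deterministic step (hypothetical): rescale the state by the ratio of times.
def toy_step(idx, x, zs, ts):
    return x * (ts[idx + 1] / ts[idx])

num_steps = 4
ts = jnp.linspace(1.0, 0.2, num_steps + 1)   # decreasing schedule
x_init = jnp.ones((3,))
zs = jnp.zeros((num_steps, 3))

x_final = sample(toy_step, x_init, zs, ts)
trajectory = sample_trajectory(toy_step, x_init, zs, ts)
```

The trajectory has one more entry than there are steps because it includes the initial state, matching the documented return shape.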
diffusionlab.schedulers¶
- class diffusionlab.schedulers.Scheduler[source]¶
Bases: object
Base class for time step schedulers used in diffusion, denoising, and sampling.
Allows for extensible scheduler implementations where subclasses can define their own initialization and time step generation parameters via **kwargs.
- get_ts(**ts_hparams: Any) Array[source]¶
Generate the sequence of time steps.
This is an abstract method that must be implemented by subclasses. Subclasses should define the specific keyword arguments they expect within **ts_hparams.
- Parameters:
**ts_hparams (Dict[str, Any]) – Keyword arguments containing parameters for generating time steps.
- Returns:
A tensor containing the sequence of time steps in descending order.
- Return type:
Array
- Raises:
NotImplementedError – If the subclass does not implement this method.
KeyError – If a required parameter is missing in **ts_hparams (in subclass).
- class diffusionlab.schedulers.UniformScheduler[source]¶
Bases: Scheduler
A scheduler that generates uniformly spaced time steps.
Requires t_min, t_max, and num_steps to be passed to the get_ts method via keyword arguments. The number of points generated will be num_steps + 1.
- get_ts(**ts_hparams: Any) Array[source]¶
Generate uniformly spaced time steps.
- Parameters:
**ts_hparams (Dict[str, Any]) – Keyword arguments must contain:
t_min (float): The minimum time value, typically close to 0.
t_max (float): The maximum time value, typically close to 1.
num_steps (int): The number of diffusion steps. The function will generate num_steps + 1 time points.
- Returns:
A JAX array containing uniformly spaced time steps in descending order (from t_max to t_min).
- Return type:
Array[num_steps+1]
- Raises:
KeyError – If t_min, t_max, or num_steps is not found in ts_hparams.
AssertionError – If t_min/t_max constraints are violated or num_steps < 1.
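A uniform schedule of this kind can be sketched in a few lines. The free function `uniform_ts` is hypothetical, and the exact t_min/t_max constraint checked below is an assumption, since this page does not spell it out.

```python
import jax.numpy as jnp

def uniform_ts(**ts_hparams):
    """Return num_steps + 1 uniformly spaced times from t_max down to t_min."""
    # Plain dictionary indexing raises KeyError when a parameter is missing.
    t_min = ts_hparams["t_min"]
    t_max = ts_hparams["t_max"]
    num_steps = ts_hparams["num_steps"]
    # Assumed constraints; the library's actual checks may be stricter.
    assert t_min <= t_max, "t_min must not exceed t_max"
    assert num_steps >= 1, "num_steps must be at least 1"
    return jnp.linspace(t_max, t_min, num_steps + 1)

ts = uniform_ts(t_min=0.01, t_max=0.99, num_steps=4)
```

The resulting array is in the form that the sample and sample_trajectory methods expect for ts: sorted in decreasing order, with one more entry than there are steps.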
diffusionlab.vector_fields¶
- class diffusionlab.vector_fields.VectorFieldType(*values)[source]¶
Bases: Enum
Enum representing the type of a vector field. A vector field is a function that takes in x_t (Array[*data_dims]) and t (Array[]) and returns a vector of the same shape as x_t (Array[*data_dims]).
DiffusionLab supports the following vector field types:
VectorFieldType.SCORE: The score of the distribution.
VectorFieldType.X0: The denoised state.
VectorFieldType.EPS: The noise.
VectorFieldType.V: The velocity field.
- diffusionlab.vector_fields.convert_vector_field_type(x: Array, f_x: Array, alpha: Array, sigma: Array, alpha_prime: Array, sigma_prime: Array, in_type: VectorFieldType, out_type: VectorFieldType) Array[source]¶
Converts the output of a vector field from one type to another.
- Parameters:
x (Array[*data_dims]) – The input tensor.
f_x (Array[*data_dims]) – The output of the vector field f evaluated at x.
alpha (Array[]) – A scalar representing the scale parameter.
sigma (Array[]) – A scalar representing the noise level parameter.
alpha_prime (Array[]) – A scalar representing the scale derivative parameter.
sigma_prime (Array[]) – A scalar representing the noise level derivative parameter.
in_type (VectorFieldType) – The type of the input vector field (e.g. VectorFieldType.SCORE, VectorFieldType.X0, VectorFieldType.EPS, VectorFieldType.V).
out_type (VectorFieldType) – The type of the output vector field.
- Returns:
The converted output of the vector field
- Return type:
Array[*data_dims]
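The conversions can be derived from three relations, assumed here rather than taken from the library source: x = α x0 + σ ε, score = -ε / σ, and v = α' x0 + σ' ε. The sketch below routes every conversion through x0. The string type tags and the function names are hypothetical stand-ins for VectorFieldType and convert_vector_field_type.

```python
import jax.numpy as jnp

def to_x0(x, f_x, alpha, sigma, alpha_prime, sigma_prime, in_type):
    """Convert any supported field value to an x0 prediction."""
    if in_type == "x0":
        return f_x
    if in_type == "eps":
        return (x - sigma * f_x) / alpha
    if in_type == "score":
        return (x + sigma**2 * f_x) / alpha
    if in_type == "v":
        # Solve x = alpha*x0 + sigma*eps and v = alpha'*x0 + sigma'*eps for x0.
        return (sigma_prime * x - sigma * f_x) / (alpha * sigma_prime - alpha_prime * sigma)
    raise ValueError(f"Unknown input type: {in_type}")

def from_x0(x, x0, alpha, sigma, alpha_prime, sigma_prime, out_type):
    """Convert an x0 prediction to any supported field value."""
    eps = (x - alpha * x0) / sigma
    if out_type == "x0":
        return x0
    if out_type == "eps":
        return eps
    if out_type == "score":
        return -eps / sigma
    if out_type == "v":
        return alpha_prime * x0 + sigma_prime * eps
    raise ValueError(f"Unknown output type: {out_type}")

def convert(x, f_x, alpha, sigma, alpha_prime, sigma_prime, in_type, out_type):
    x0 = to_x0(x, f_x, alpha, sigma, alpha_prime, sigma_prime, in_type)
    return from_x0(x, x0, alpha, sigma, alpha_prime, sigma_prime, out_type)

# VP coefficients at t = 0.6: alpha = 0.8, sigma = 0.6, alpha' = -0.75, sigma' = 1.
coeffs = (jnp.asarray(0.8), jnp.asarray(0.6), jnp.asarray(-0.75), jnp.asarray(1.0))
x0_true = jnp.array([1.0, -1.0])
eps_true = jnp.array([0.5, 2.0])
x = 0.8 * x0_true + 0.6 * eps_true
v_true = -0.75 * x0_true + 1.0 * eps_true

v_from_eps = convert(x, eps_true, *coeffs, "eps", "v")
x0_from_v = convert(x, v_true, *coeffs, "v", "x0")
score_from_x0 = convert(x, x0_true, *coeffs, "x0", "score")
```

Routing through x0 keeps the number of formulas linear in the number of field types, at the cost of dividing by α and σ, so the sketch is only meaningful where both are non-zero.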