API Reference
diffusionlab
diffusionlab.distributions
diffusionlab.distributions.base
- class diffusionlab.distributions.base.Distribution(dist_params: Dict[str, Array], dist_hparams: Dict[str, Any])
Bases: object
Base class for all distributions.
Subclass it to implement a distribution whose ground-truth score, denoiser, noise predictor, or velocity estimator you want to use.
Each distribution implementation provides functions to sample from it and to compute the vector fields associated with a diffusion process: denoising (x0), noise prediction (eps), velocity estimation (v), and score estimation (score).
- dist_params
Dictionary containing distribution parameters as JAX arrays. Shapes depend on the specific distribution.
- Type:
Dict[str, Array]
- dist_hparams
Dictionary containing distribution hyperparameters (non-array values).
- Type:
Dict[str, Any]
- eps(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Predict the noise component ε corresponding to the noisy state x_t at time t, given the diffusion_process.
- Parameters:
x_t (Array[*data_dims]) – The noisy state tensor at time t.
t (Array[]) – The time step.
diffusion_process (DiffusionProcess) – The diffusion process definition.
- Returns:
The predicted noise ε.
- Return type:
Array[*data_dims]
- get_vector_field(vector_field_type: VectorFieldType) → Callable[[Array, Array, DiffusionProcess], Array]
Get the vector field function of a given type associated with this distribution.
- Parameters:
vector_field_type (VectorFieldType) – The type of vector field to retrieve (e.g., VectorFieldType.SCORE, VectorFieldType.X0, VectorFieldType.EPS, VectorFieldType.V).
- Returns:
The requested vector field function. It takes the current state x_t (Array[*data_dims]), time t (Array[]), and the diffusion_process as input and returns the corresponding vector field value (Array[*data_dims]).
- Return type:
Callable[[Array[*data_dims], Array[], DiffusionProcess], Array[*data_dims]]
- sample(key: Array, num_samples: int) → Tuple[Array, Any]
Sample from the distribution.
- Parameters:
key (Array) – The JAX PRNG key to use for sampling.
num_samples (int) – The number of samples to draw.
- Returns:
A tuple containing the samples and any additional information.
- Return type:
Tuple[Array[num_samples, *data_dims], Any]
- score(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Compute the score function (∇_x log p_t(x)) of the distribution at time t, given the noisy state x_t and the diffusion_process.
- Parameters:
x_t (Array[*data_dims]) – The noisy state tensor at time t.
t (Array[]) – The time step.
diffusion_process (DiffusionProcess) – The diffusion process definition.
- Returns:
The score of the distribution at (x_t, t).
- Return type:
Array[*data_dims]
- v(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Compute the velocity field v(x_t, t) corresponding to the noisy state x_t at time t, given the diffusion_process.
- Parameters:
x_t (Array[*data_dims]) – The noisy state tensor at time t.
t (Array[]) – The time step.
diffusion_process (DiffusionProcess) – The diffusion process definition.
- Returns:
The computed velocity field v.
- Return type:
Array[*data_dims]
- x0(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Predict the initial state x0 (denoised sample) from the noisy state x_t at time t, given the diffusion_process.
- Parameters:
x_t (Array[*data_dims]) – The noisy state tensor at time t.
t (Array[]) – The time step.
diffusion_process (DiffusionProcess) – The diffusion process definition.
- Returns:
The predicted initial state x0.
- Return type:
Array[*data_dims]
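The four fields above are linked through the forward process x_t = α(t) * x_0 + σ(t) * ε. Once E[x_0 | x_t] is known, E[ε | x_t] = (x_t - α(t) E[x_0 | x_t]) / σ(t) follows directly. The score is then -E[ε | x_t] / σ(t), and v = α'(t) E[x_0 | x_t] + σ'(t) E[ε | x_t].

The NumPy sketch below shows a subclass-plus-dispatch pattern built on these relations. It is not diffusionlab's implementation and rests on the following assumptions:
- FieldKind is a hypothetical stand-in for VectorFieldType.
- ToyProcess is a hypothetical minimal stand-in for DiffusionProcess.
- PointMass (all mass at a fixed vector c) is used only because its denoiser is trivially c.
- NumPy replaces JAX throughout.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Callable

import numpy as np


class FieldKind(Enum):
    # Hypothetical stand-in for diffusionlab's VectorFieldType.
    SCORE = "score"
    X0 = "x0"
    EPS = "eps"
    V = "v"


@dataclass
class ToyProcess:
    # Hypothetical minimal process: scalar schedules and their derivatives.
    alpha: Callable[[float], float]
    sigma: Callable[[float], float]
    alpha_prime: Callable[[float], float]
    sigma_prime: Callable[[float], float]


class PointMass:
    """Toy distribution with all of its mass at the vector c."""

    def __init__(self, c):
        self.c = np.asarray(c, dtype=float)

    def x0(self, x_t, t, proc):
        # Only one possible clean sample, so E[x_0 | x_t] = c.
        return self.c

    def eps(self, x_t, t, proc):
        # Invert x_t = alpha(t) * x_0 + sigma(t) * eps.
        return (x_t - proc.alpha(t) * self.x0(x_t, t, proc)) / proc.sigma(t)

    def score(self, x_t, t, proc):
        # Tweedie's formula: score = -E[eps | x_t] / sigma(t).
        return -self.eps(x_t, t, proc) / proc.sigma(t)

    def v(self, x_t, t, proc):
        # E[dx_t/dt | x_t] = alpha'(t) E[x_0 | x_t] + sigma'(t) E[eps | x_t].
        return (proc.alpha_prime(t) * self.x0(x_t, t, proc)
                + proc.sigma_prime(t) * self.eps(x_t, t, proc))

    def get_vector_field(self, kind):
        # Map the requested field type to the corresponding bound method.
        return {FieldKind.SCORE: self.score, FieldKind.X0: self.x0,
                FieldKind.EPS: self.eps, FieldKind.V: self.v}[kind]


fm = ToyProcess(lambda t: 1.0 - t, lambda t: t, lambda t: -1.0, lambda t: 1.0)
eps_fn = PointMass([1.0, -2.0]).get_vector_field(FieldKind.EPS)
print(eps_fn(np.array([0.5, 0.0]), 0.5, fm))  # [0. 2.]
```

Because every method has the same (x_t, t, process) signature, the dispatch table can hand back any of them interchangeably.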
diffusionlab.distributions.empirical
- class diffusionlab.distributions.empirical.EmpiricalDistribution(labeled_data: Iterable[Tuple[Array, Array]] | Iterable[Tuple[Array, None]])
Bases: Distribution
An empirical distribution, i.e., the uniform distribution over a dataset. The probability measure is defined as:
μ(A) = (1/N) * sum_{i=1}^{N} 1[x_i ∈ A]
where x_i is the i-th data point in the dataset and N is the number of data points.
This class provides methods for sampling from the empirical distribution and computing various vector fields (VectorFieldType.SCORE, VectorFieldType.X0, VectorFieldType.EPS, VectorFieldType.V) related to the distribution under a given diffusion process.
- dist_params
Dictionary containing distribution parameters (currently unused).
- Type:
Dict[str, Array]
- dist_hparams
Dictionary for storing hyperparameters. It may contain the following keys:
labeled_data (Iterable[Tuple[Array, Array]] | Iterable[Tuple[Array, None]]): An iterable whose elements are (data batch, label batch) tuples. The label batch can be None if the data is unlabelled.
- Type:
Dict[str, Any]
- eps(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Computes the noise field eps(x_t, t) for an empirical distribution w.r.t. a given diffusion process.
- Parameters:
x_t (Array[*data_dims]) – The input tensor.
t (Array[]) – The time tensor.
diffusion_process (DiffusionProcess) – The diffusion process.
- Returns:
The noise field at (x_t, t).
- Return type:
Array[*data_dims]
- sample(key: Array, num_samples: int) → Tuple[Array, Array] | Tuple[Array, None]
Sample from the empirical distribution using reservoir sampling. Assumes all batches in labeled_data are consistent: either all have labels (Array) or none have labels (None).
- Parameters:
key (Array) – The JAX PRNG key to use for sampling.
num_samples (int) – The number of samples to draw.
- Returns:
A tuple (samples, labels) containing the samples and corresponding labels (stacked into an Array), or (samples, None) if the data is unlabelled.
- Return type:
Tuple[Array[num_samples, *data_dims], Array[num_samples, *label_dims]] | Tuple[Array[num_samples, *data_dims], None]
- score(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Computes the score function (∇_x log p_t(x)) of the empirical distribution at time t, given the noisy state x_t and the diffusion process.
- Parameters:
x_t (Array[*data_dims]) – The noisy state tensor at time t.
t (Array[]) – The time tensor.
diffusion_process (DiffusionProcess) – The diffusion process.
- Returns:
The score of the empirical distribution at (x_t, t).
- Return type:
Array[*data_dims]
- v(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Computes the velocity field v(x_t, t) for an empirical distribution w.r.t. a given diffusion process.
- Parameters:
x_t (Array[*data_dims]) – The input tensor.
t (Array[]) – The time tensor.
diffusion_process (DiffusionProcess) – The diffusion process.
- Returns:
The velocity field at (x_t, t).
- Return type:
Array[*data_dims]
- x0(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Computes the denoiser E[x_0 | x_t] for an empirical distribution w.r.t. a given diffusion process.
This method computes the denoiser as a weighted average of the dataset samples. Each sample's weight is proportional to the likelihood of x_t given that sample, i.e., its posterior probability under the uniform prior.
- Parameters:
x_t (Array[*data_dims]) – The input tensor.
t (Array[]) – The time tensor.
diffusion_process (DiffusionProcess) – The diffusion process.
- Returns:
The prediction of x_0.
- Return type:
Array[*data_dims]
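The weighted-average description of x0 can be made concrete. Under x_t = α(t) * x_0 + σ(t) * ε, the likelihood of x_t given data point x_i is Gaussian with mean α(t) * x_i. The weights are therefore a softmax of -||x_t - α(t) x_i||² / (2σ(t)²).

The sketch below is a minimal NumPy version. It assumes the whole dataset fits in one (N, d) array and passes α(t) and σ(t) as plain numbers. empirical_x0 is a hypothetical name, not diffusionlab's function.

```python
import numpy as np


def empirical_x0(x_t, alpha_t, sigma_t, data):
    """E[x_0 | x_t] for the uniform distribution over the rows of data.

    Assumes x_t = alpha_t * x_0 + sigma_t * eps with eps ~ N(0, I).
    data has shape (N, d); x_t has shape (d,).
    """
    # Gaussian log-likelihood of x_t under each data point (shared constants cancel).
    sq_dists = np.sum((x_t[None, :] - alpha_t * data) ** 2, axis=1)
    logits = -sq_dists / (2.0 * sigma_t**2)
    # Numerically stable softmax over the N data points.
    weights = np.exp(logits - logits.max())
    weights /= weights.sum()
    # Posterior-weighted average of the data points.
    return weights @ data


data = np.array([[0.0, 0.0], [2.0, 0.0]])
print(empirical_x0(np.array([1.0, 0.0]), 1.0, 1.0, data))  # [1. 0.]
```

As σ(t) shrinks, the softmax concentrates on the nearest scaled data point. The denoiser then returns a training sample almost exactly.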
diffusionlab.distributions.gmm
diffusionlab.distributions.gmm.gmm
- class diffusionlab.distributions.gmm.gmm.GMM(means: Array, covs: Array, priors: Array)
Bases: Distribution
Implements a Gaussian Mixture Model (GMM) distribution.
The probability measure is given by:
μ(A) = sum_{i=1}^{num_components} priors[i] * N(A; means[i], covs[i])
This class provides methods for sampling from the GMM and computing various vector fields (VectorFieldType.SCORE, VectorFieldType.X0, VectorFieldType.EPS, VectorFieldType.V) related to the distribution under a given diffusion process.
- dist_params
Dictionary containing the core GMM parameters:
means (Array[num_components, data_dim]): The means of the GMM components.
covs (Array[num_components, data_dim, data_dim]): The covariance matrices of the GMM components.
priors (Array[num_components]): The prior probabilities (mixture weights) of the GMM components.
- Type:
Dict[str, Array]
- dist_hparams
Dictionary for storing hyperparameters (currently unused).
- Type:
Dict[str, Any]
- eps(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Computes the noise prediction ε for the GMM distribution. This predicts the noise that was added to the original sample x_0 to obtain x_t at time t under the diffusion_process.
- Parameters:
x_t (Array[data_dim]) – The noisy state tensor at time t.
t (Array[]) – The time step (scalar).
diffusion_process (DiffusionProcess) – The diffusion process definition.
- Returns:
The noise prediction vector field ε evaluated at x_t and t.
- Return type:
Array[data_dim]
- sample(key: Array, num_samples: int) → Tuple[Array, Array]
Draws samples from the GMM distribution.
- Parameters:
key (Array) – JAX PRNG key for random sampling.
num_samples (int) – The total number of samples to generate.
- Returns:
A tuple (samples, component_indices) containing the drawn samples and the index of the GMM component from which each sample was drawn.
- Return type:
Tuple[Array[num_samples, data_dim], Array[num_samples]]
- score(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Computes the score vector field ∇_x log p_t(x_t) for the GMM distribution. This is calculated with respect to the perturbed distribution p_t induced by the diffusion_process at time t.
- Parameters:
x_t (Array[data_dim]) – The noisy state tensor at time t.
t (Array[]) – The time step (scalar).
diffusion_process (DiffusionProcess) – The diffusion process definition.
- Returns:
The score vector field evaluated at x_t and t.
- Return type:
Array[data_dim]
- v(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Computes the velocity vector field v for the GMM distribution. This is the conditional velocity E[dx_t/dt | x_t] under the diffusion_process.
- Parameters:
x_t (Array[data_dim]) – The noisy state tensor at time t.
t (Array[]) – The time step (scalar).
diffusion_process (DiffusionProcess) – The diffusion process definition.
- Returns:
The velocity vector field v evaluated at x_t and t.
- Return type:
Array[data_dim]
- x0(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Computes the denoised prediction x0 = E[x_0 | x_t] for the GMM distribution. This represents the expected original sample x_0 given the noisy observation x_t at time t under the diffusion_process.
- Parameters:
x_t (Array[data_dim]) – The noisy state tensor at time t.
t (Array[]) – The time step (scalar).
diffusion_process (DiffusionProcess) – The diffusion process definition.
- Returns:
The denoised prediction vector field x0 evaluated at x_t and t.
- Return type:
Array[data_dim]
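Sampling from a GMM is a two-stage procedure: draw a component index from priors, then draw from that component's Gaussian. The NumPy sketch below follows that procedure. It uses a numpy.random.Generator instead of a JAX PRNG key, and a Cholesky factor to turn standard normal draws into correlated ones. sample_gmm is a hypothetical name, and GMM.sample may differ in its details.

```python
import numpy as np


def sample_gmm(rng, means, covs, priors, num_samples):
    """Draw (samples, component_indices) from a full-covariance GMM.

    means: (K, d); covs: (K, d, d), positive definite; priors: (K,), summing to 1.
    """
    K, d = means.shape
    # 1) Pick a component for each sample according to the mixture weights.
    idx = rng.choice(K, size=num_samples, p=priors)
    # 2) Push standard normal draws through each chosen component's Cholesky factor.
    chol = np.linalg.cholesky(covs)  # batched: (K, d, d)
    z = rng.standard_normal((num_samples, d))
    samples = means[idx] + np.einsum("nij,nj->ni", chol[idx], z)
    return samples, idx


rng = np.random.default_rng(0)
means = np.array([[-5.0, 0.0], [5.0, 0.0]])
covs = np.stack([0.01 * np.eye(2), 0.01 * np.eye(2)])
samples, idx = sample_gmm(rng, means, covs, np.array([0.5, 0.5]), 1000)
```

Returning idx alongside the samples mirrors the documented (samples, component_indices) return value.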
- diffusionlab.distributions.gmm.gmm.gmm_x0(x_t: Array, t: Array, diffusion_process: DiffusionProcess, means: Array, covs: Array, priors: Array) → Array
Computes the denoised prediction x0 = E[x_0 | x_t] for a GMM. This implements the closed-form solution for the conditional expectation E[x_0 | x_t], where x_t | x_0 ~ N(α_t x_0, σ_t^2 I) and x_0 follows the GMM distribution defined by means, covs, and priors.
- Parameters:
x_t (Array[data_dim]) – The noisy state tensor at time t.
t (Array[]) – The time step (scalar).
diffusion_process (DiffusionProcess) – Provides α(t) and σ(t).
means (Array[num_components, data_dim]) – Means of the GMM components.
covs (Array[num_components, data_dim, data_dim]) – Covariances of the GMM components.
priors (Array[num_components]) – Mixture weights of the GMM components.
- Returns:
The denoised prediction x0 evaluated at x_t and t.
- Return type:
Array[data_dim]
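The closed form rests on two Gaussian facts:
- Given component k, x_t is Gaussian with mean α_t means[k] and covariance α_t² covs[k] + σ_t² I.
- The posterior mean of x_0 within component k is means[k] + α_t covs[k] (α_t² covs[k] + σ_t² I)^{-1} (x_t - α_t means[k]).

E[x_0 | x_t] is the average of these per-component means, weighted by the posterior component probabilities. The NumPy sketch below takes α_t and σ_t as plain numbers instead of a DiffusionProcess. gmm_x0_sketch is a hypothetical name, not the library function.

```python
import numpy as np


def gmm_x0_sketch(x_t, alpha_t, sigma_t, means, covs, priors):
    """Closed-form E[x_0 | x_t] for a full-covariance GMM prior.

    Assumes x_t | x_0 ~ N(alpha_t x_0, sigma_t^2 I). Shapes: x_t (d,),
    means (K, d), covs (K, d, d), priors (K,).
    """
    K, d = means.shape
    # Covariance of x_t given component k: alpha^2 Sigma_k + sigma^2 I.
    marg_covs = alpha_t**2 * covs + sigma_t**2 * np.eye(d)
    resid = x_t[None, :] - alpha_t * means                          # (K, d)
    solved = np.linalg.solve(marg_covs, resid[..., None])[..., 0]   # (K, d)
    # log pi_k + log N(x_t; alpha mu_k, marg_cov_k), dropping the shared 2*pi term.
    _, logdets = np.linalg.slogdet(marg_covs)
    logits = np.log(priors) - 0.5 * logdets - 0.5 * np.sum(resid * solved, axis=1)
    w = np.exp(logits - logits.max())
    w /= w.sum()  # posterior component probabilities p(k | x_t)
    # Per-component posterior mean: mu_k + alpha Sigma_k marg_cov_k^{-1} resid_k.
    cond_means = means + alpha_t * np.einsum("kij,kj->ki", covs, solved)
    return w @ cond_means
```

With a single component this reduces to the Gaussian posterior mean. For example, for N(μ, s² I) it gives μ + α s² / (α² s² + σ²) (x_t - α μ).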
diffusionlab.distributions.gmm.iso_gmm
- class diffusionlab.distributions.gmm.iso_gmm.IsoGMM(means: Array, variances: Array, priors: Array)
Bases: Distribution
Implements an isotropic Gaussian Mixture Model (GMM) distribution.
The probability measure is given by:
μ(A) = sum_{i=1}^{num_components} priors[i] * N(A; means[i], variances[i] * I)
This class provides methods for sampling from the GMM and computing various vector fields (VectorFieldType.SCORE, VectorFieldType.X0, VectorFieldType.EPS, VectorFieldType.V) related to the distribution under a given diffusion process.
- dist_params
Dictionary containing the core GMM parameters:
means (Array[num_components, data_dim]): The means of the GMM components.
variances (Array[num_components]): The variances of the GMM components.
priors (Array[num_components]): The prior probabilities (mixture weights) of the GMM components.
- Type:
Dict[str, Array]
- dist_hparams
Dictionary for storing hyperparameters (currently unused).
- Type:
Dict[str, Any]
- eps(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Computes the noise prediction ε for the isotropic GMM distribution. This predicts the noise that was added to the original sample x_0 to obtain x_t at time t under the diffusion_process.
- Parameters:
x_t (Array[data_dim]) – The noisy state tensor at time t.
t (Array[]) – The time step (scalar).
diffusion_process (DiffusionProcess) – The diffusion process definition.
- Returns:
The noise prediction vector field ε evaluated at x_t and t.
- Return type:
Array[data_dim]
- sample(key: Array, num_samples: int) → Tuple[Array, Array]
Draws samples from the isotropic GMM distribution.
- Parameters:
key (Array) – JAX PRNG key for random sampling.
num_samples (int) – The total number of samples to generate.
- Returns:
A tuple (samples, component_indices) containing the drawn samples and the index of the GMM component from which each sample was drawn.
- Return type:
Tuple[Array[num_samples, data_dim], Array[num_samples]]
- score(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Computes the score vector field ∇_x log p_t(x_t) for the isotropic GMM distribution. This is calculated with respect to the perturbed distribution p_t induced by the diffusion_process at time t.
- Parameters:
x_t (Array[data_dim]) – The noisy state tensor at time t.
t (Array[]) – The time step (scalar).
diffusion_process (DiffusionProcess) – The diffusion process definition.
- Returns:
The score vector field evaluated at x_t and t.
- Return type:
Array[data_dim]
- v(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Computes the velocity vector field v for the isotropic GMM distribution. This is the conditional velocity E[dx_t/dt | x_t] under the diffusion_process.
- Parameters:
x_t (Array[data_dim]) – The noisy state tensor at time t.
t (Array[]) – The time step (scalar).
diffusion_process (DiffusionProcess) – The diffusion process definition.
- Returns:
The velocity vector field v evaluated at x_t and t.
- Return type:
Array[data_dim]
- x0(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Computes the denoised prediction x0 = E[x_0 | x_t] for the isotropic GMM distribution. This represents the expected original sample x_0 given the noisy observation x_t at time t under the diffusion_process.
- Parameters:
x_t (Array[data_dim]) – The noisy state tensor at time t.
t (Array[]) – The time step (scalar).
diffusion_process (DiffusionProcess) – The diffusion process definition.
- Returns:
The denoised prediction vector field x0 evaluated at x_t and t.
- Return type:
Array[data_dim]
- diffusionlab.distributions.gmm.iso_gmm.iso_gmm_x0(x_t: Array, t: Array, diffusion_process: DiffusionProcess, means: Array, variances: Array, priors: Array) → Array
Computes the denoised prediction x0 = E[x_0 | x_t] for an isotropic GMM. This implements the closed-form solution for the conditional expectation E[x_0 | x_t], where x_t | x_0 ~ N(α_t x_0, σ_t^2 I) and x_0 follows the isotropic GMM distribution defined by means, variances, and priors.
- Parameters:
x_t (Array[data_dim]) – The noisy state tensor at time t.
t (Array[]) – The time step (scalar).
diffusion_process (DiffusionProcess) – Provides α(t) and σ(t).
means (Array[num_components, data_dim]) – Means of the GMM components.
variances (Array[num_components]) – Variances of the GMM components; component i has covariance variances[i] * I.
priors (Array[num_components]) – Mixture weights of the GMM components.
- Returns:
The denoised prediction x0 evaluated at x_t and t.
- Return type:
Array[data_dim]
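With covs[k] = variances[k] * I, every matrix in the general GMM formula becomes a scalar multiple of the identity, so no linear solves are needed. The marginal variance of x_t given component k is α_t² variances[k] + σ_t². The per-component gain is α_t variances[k] / (α_t² variances[k] + σ_t²).

The NumPy sketch below takes α_t and σ_t as plain numbers. iso_gmm_x0_sketch is a hypothetical name.

```python
import numpy as np


def iso_gmm_x0_sketch(x_t, alpha_t, sigma_t, means, variances, priors):
    """Closed-form E[x_0 | x_t] when component k has covariance variances[k] * I."""
    d = means.shape[1]
    # Given component k, x_t ~ N(alpha mu_k, (alpha^2 v_k + sigma^2) I).
    marg_vars = alpha_t**2 * variances + sigma_t**2            # (K,)
    resid = x_t[None, :] - alpha_t * means                      # (K, d)
    # log pi_k + log N(x_t; alpha mu_k, marg_var_k I), without the shared 2*pi term.
    logits = (np.log(priors)
              - 0.5 * d * np.log(marg_vars)
              - 0.5 * np.sum(resid**2, axis=1) / marg_vars)
    w = np.exp(logits - logits.max())
    w /= w.sum()
    # Scalar gain per component replaces the matrix solve of the general case.
    gains = alpha_t * variances / marg_vars                     # (K,)
    cond_means = means + gains[:, None] * resid
    return w @ cond_means
```

Here the d/2 * log(marginal variance) term does not cancel, because the variances differ across components.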
diffusionlab.distributions.gmm.iso_hom_gmm
- class diffusionlab.distributions.gmm.iso_hom_gmm.IsoHomGMM(means: Array, variance: Array, priors: Array)
Bases: Distribution
Implements an isotropic homoscedastic Gaussian Mixture Model (GMM) distribution.
The probability measure is given by:
μ(A) = sum_{i=1}^{num_components} priors[i] * N(A; means[i], variance * I)
This class provides methods for sampling from the isotropic homoscedastic GMM and computing various vector fields (VectorFieldType.SCORE, VectorFieldType.X0, VectorFieldType.EPS, VectorFieldType.V) related to the distribution under a given diffusion process.
- dist_params
Dictionary containing the core GMM parameters:
means (Array[num_components, data_dim]): The means of the GMM components.
variance (Array[]): The variance shared by all GMM components.
priors (Array[num_components]): The prior probabilities (mixture weights) of the GMM components.
- Type:
Dict[str, Array]
- dist_hparams
Dictionary for storing hyperparameters (currently unused).
- Type:
Dict[str, Any]
- eps(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Computes the noise prediction ε for the isotropic homoscedastic GMM distribution. This predicts the noise that was added to the original sample x_0 to obtain x_t at time t under the diffusion_process.
- Parameters:
x_t (Array[data_dim]) – The noisy state tensor at time t.
t (Array[]) – The time step (scalar).
diffusion_process (DiffusionProcess) – The diffusion process definition.
- Returns:
The noise prediction vector field ε evaluated at x_t and t.
- Return type:
Array[data_dim]
- sample(key: Array, num_samples: int) → Tuple[Array, Array]
Draws samples from the isotropic homoscedastic GMM distribution.
- Parameters:
key (Array) – JAX PRNG key for random sampling.
num_samples (int) – The total number of samples to generate.
- Returns:
A tuple (samples, component_indices) containing the drawn samples and the index of the GMM component from which each sample was drawn.
- Return type:
Tuple[Array[num_samples, data_dim], Array[num_samples]]
- score(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Computes the score vector field ∇_x log p_t(x_t) for the isotropic homoscedastic GMM distribution. This is calculated with respect to the perturbed distribution p_t induced by the diffusion_process at time t.
- Parameters:
x_t (Array[data_dim]) – The noisy state tensor at time t.
t (Array[]) – The time step (scalar).
diffusion_process (DiffusionProcess) – The diffusion process definition.
- Returns:
The score vector field evaluated at x_t and t.
- Return type:
Array[data_dim]
- v(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Computes the velocity vector field v for the isotropic homoscedastic GMM distribution. This is the conditional velocity E[dx_t/dt | x_t] under the diffusion_process.
- Parameters:
x_t (Array[data_dim]) – The noisy state tensor at time t.
t (Array[]) – The time step (scalar).
diffusion_process (DiffusionProcess) – The diffusion process definition.
- Returns:
The velocity vector field v evaluated at x_t and t.
- Return type:
Array[data_dim]
- x0(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Computes the denoised prediction x0 = E[x_0 | x_t] for the isotropic homoscedastic GMM distribution. This represents the expected original sample x_0 given the noisy observation x_t at time t under the diffusion_process.
- Parameters:
x_t (Array[data_dim]) – The noisy state tensor at time t.
t (Array[]) – The time step (scalar).
diffusion_process (DiffusionProcess) – The diffusion process definition.
- Returns:
The denoised prediction vector field x0 evaluated at x_t and t.
- Return type:
Array[data_dim]
- diffusionlab.distributions.gmm.iso_hom_gmm.iso_hom_gmm_x0(x_t: Array, t: Array, diffusion_process: DiffusionProcess, means: Array, variance: Array, priors: Array) → Array
Computes the denoised prediction x0 = E[x_0 | x_t] for an isotropic homoscedastic GMM. This implements the closed-form solution for the conditional expectation E[x_0 | x_t], where x_t | x_0 ~ N(α_t x_0, σ_t^2 I) and x_0 follows the GMM distribution defined by means, variance, and priors.
- Parameters:
x_t (Array[data_dim]) – The noisy state tensor at time t.
t (Array[]) – The time step (scalar).
diffusion_process (DiffusionProcess) – Provides α(t) and σ(t).
means (Array[num_components, data_dim]) – Means of the GMM components.
variance (Array[]) – Variance shared by all GMM components; each component has covariance variance * I.
priors (Array[num_components]) – Mixture weights of the GMM components.
- Returns:
The denoised prediction x0 evaluated at x_t and t.
- Return type:
Array[data_dim]
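When all components share one variance, two simplifications follow:
- The log-determinant term is identical for every component, so it cancels from the posterior weights w_k.
- The gain is the same for every component.

As a result, E[x_0 | x_t] = (σ_t² / (α_t² variance + σ_t²)) * sum_k w_k means[k] + (α_t variance / (α_t² variance + σ_t²)) * x_t.

The NumPy sketch below implements this simplification, taking α_t and σ_t as plain numbers. iso_hom_gmm_x0_sketch is a hypothetical name.

```python
import numpy as np


def iso_hom_gmm_x0_sketch(x_t, alpha_t, sigma_t, means, variance, priors):
    """Closed-form E[x_0 | x_t] when every component has covariance variance * I."""
    marg_var = alpha_t**2 * variance + sigma_t**2   # shared scalar
    resid = x_t[None, :] - alpha_t * means          # (K, d)
    # The log-determinant term is the same for all components and cancels
    # in the softmax, leaving only priors and squared distances.
    logits = np.log(priors) - 0.5 * np.sum(resid**2, axis=1) / marg_var
    w = np.exp(logits - logits.max())
    w /= w.sum()
    gain = alpha_t * variance / marg_var
    # sum_k w_k (mu_k + gain * (x_t - alpha mu_k))
    #   = (sigma^2 / marg_var) * sum_k w_k mu_k + gain * x_t
    return (sigma_t**2 / marg_var) * (w @ means) + gain * x_t
```

The result is a convex-style blend of the posterior-averaged mean and the rescaled observation x_t.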
diffusionlab.distributions.gmm.low_rank_gmm
- class diffusionlab.distributions.gmm.low_rank_gmm.LowRankGMM(means: Array, cov_factors: Array, priors: Array)
Bases: Distribution
Implements a low-rank Gaussian Mixture Model (GMM) distribution.
The probability measure is given by:
μ(A) = sum_{i=1}^{num_components} priors[i] * N(A; means[i], cov_factors[i] @ cov_factors[i].T)
This class provides methods for sampling from the GMM and computing various vector fields (VectorFieldType.SCORE, VectorFieldType.X0, VectorFieldType.EPS, VectorFieldType.V) related to the distribution under a given diffusion process.
- dist_params
Dictionary containing the core low-rank GMM parameters:
means (Array[num_components, data_dim]): The means of the GMM components.
cov_factors (Array[num_components, data_dim, rank]): The low-rank covariance matrix factors of the GMM components.
priors (Array[num_components]): The prior probabilities (mixture weights) of the GMM components.
- Type:
Dict[str, Array]
- dist_hparams
Dictionary for storing hyperparameters (currently unused).
- Type:
Dict[str, Any]
- eps(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Computes the noise prediction ε for the low-rank GMM distribution. This predicts the noise that was added to the original sample x_0 to obtain x_t at time t under the diffusion_process.
- Parameters:
x_t (Array[data_dim]) – The noisy state tensor at time t.
t (Array[]) – The time step (scalar).
diffusion_process (DiffusionProcess) – The diffusion process definition.
- Returns:
The noise prediction vector field ε evaluated at x_t and t.
- Return type:
Array[data_dim]
- sample(key: Array, num_samples: int) → Tuple[Array, Array]
Draws samples from the low-rank GMM distribution.
- Parameters:
key (Array) – JAX PRNG key for random sampling.
num_samples (int) – The total number of samples to generate.
- Returns:
A tuple (samples, component_indices) containing the drawn samples and the index of the GMM component from which each sample was drawn.
- Return type:
Tuple[Array[num_samples, data_dim], Array[num_samples]]
- score(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Computes the score vector field ∇_x log p_t(x_t) for the low-rank GMM distribution. This is calculated with respect to the perturbed distribution p_t induced by the diffusion_process at time t.
- Parameters:
x_t (Array[data_dim]) – The noisy state tensor at time t.
t (Array[]) – The time step (scalar).
diffusion_process (DiffusionProcess) – The diffusion process definition.
- Returns:
The score vector field evaluated at x_t and t.
- Return type:
Array[data_dim]
- v(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Computes the velocity vector field v for the low-rank GMM distribution. This is the conditional velocity E[dx_t/dt | x_t] under the diffusion_process.
- Parameters:
x_t (Array[data_dim]) – The noisy state tensor at time t.
t (Array[]) – The time step (scalar).
diffusion_process (DiffusionProcess) – The diffusion process definition.
- Returns:
The velocity vector field v evaluated at x_t and t.
- Return type:
Array[data_dim]
- x0(x_t: Array, t: Array, diffusion_process: DiffusionProcess) → Array
Computes the denoised prediction x0 = E[x_0 | x_t] for the low-rank GMM distribution. This represents the expected original sample x_0 given the noisy observation x_t at time t under the diffusion_process.
- Parameters:
x_t (Array[data_dim]) – The noisy state tensor at time t.
t (Array[]) – The time step (scalar).
diffusion_process (DiffusionProcess) – The diffusion process definition.
- Returns:
The denoised prediction vector field x0 evaluated at x_t and t.
- Return type:
Array[data_dim]
- diffusionlab.distributions.gmm.low_rank_gmm.low_rank_gmm_x0(x_t: Array, t: Array, diffusion_process: DiffusionProcess, means: Array, cov_factors: Array, priors: Array) → Array
Computes the denoised prediction x0 = E[x_0 | x_t] for a low-rank GMM. This implements the closed-form solution for the conditional expectation E[x_0 | x_t], where x_t | x_0 ~ N(α_t x_0, σ_t^2 I) and x_0 follows the low-rank GMM distribution defined by means, cov_factors, and priors.
- Parameters:
x_t (Array[data_dim]) – The noisy state tensor at time t.
t (Array[]) – The time step (scalar).
diffusion_process (DiffusionProcess) – Provides α(t) and σ(t).
means (Array[num_components, data_dim]) – Means of the GMM components.
cov_factors (Array[num_components, data_dim, rank]) – Low-rank covariance factors of the GMM components; component i has covariance cov_factors[i] @ cov_factors[i].T.
priors (Array[num_components]) – Mixture weights of the GMM components.
- Returns:
The denoised prediction x0 evaluated at x_t and t.
- Return type:
Array[data_dim]
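Each covariance here is cov_factors[k] @ cov_factors[k].T, with rank r at most data_dim. That structure lets the data_dim × data_dim solve of the dense formula be replaced by an r × r one:
- the Woodbury identity handles the quadratic form and the posterior mean;
- the matrix determinant lemma handles the log-determinant.

This section does not state whether low_rank_gmm_x0 itself works this way. The NumPy sketch below only shows that the approach agrees with the dense formula. It takes α_t and σ_t as plain numbers, requires σ_t > 0, and uses the hypothetical name low_rank_gmm_x0_sketch.

```python
import numpy as np


def low_rank_gmm_x0_sketch(x_t, alpha_t, sigma_t, means, cov_factors, priors):
    """E[x_0 | x_t] when component k has covariance U_k U_k^T (U_k: d x r).

    Works in the rank-r space via the push-through identity
    U^T (a^2 U U^T + s^2 I_d)^{-1} = (a^2 U^T U + s^2 I_r)^{-1} U^T.
    """
    K, d, r = cov_factors.shape
    a2, s2 = alpha_t**2, sigma_t**2
    resid = x_t[None, :] - alpha_t * means                       # (K, d)
    # Small r x r system per component: M_k = s^2 I_r + a^2 U_k^T U_k.
    utu = np.einsum("kdi,kdj->kij", cov_factors, cov_factors)    # (K, r, r)
    small = s2 * np.eye(r) + a2 * utu
    proj = np.einsum("kdi,kd->ki", cov_factors, resid)           # U_k^T resid_k
    coef = np.linalg.solve(small, proj[..., None])[..., 0]       # M_k^{-1} U_k^T resid_k
    # Woodbury: resid^T (s^2 I + a^2 U U^T)^{-1} resid.
    quad = (np.sum(resid**2, axis=1) - a2 * np.sum(proj * coef, axis=1)) / s2
    # Determinant lemma: log det(s^2 I_d + a^2 U U^T) = (d - r) log s^2 + log det M_k.
    _, logdet_small = np.linalg.slogdet(small)
    logdets = (d - r) * np.log(s2) + logdet_small
    logits = np.log(priors) - 0.5 * logdets - 0.5 * quad
    w = np.exp(logits - logits.max())
    w /= w.sum()
    # Per-component posterior mean: mu_k + alpha U_k M_k^{-1} U_k^T resid_k.
    cond_means = means + alpha_t * np.einsum("kdi,ki->kd", cov_factors, coef)
    return w @ cond_means
```

The per-component cost is dominated by the r × r solve, which matters when rank is much smaller than data_dim.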
diffusionlab.distributions.gmm.utils
- diffusionlab.distributions.gmm.utils.create_gmm_vector_field_fns(x0_fn: Callable[[Array, Array, DiffusionProcess, Array, Array, Array], Array]) → Tuple[Callable[[Array, Array, DiffusionProcess, Array, Array, Array], Array], Callable[[Array, Array, DiffusionProcess, Array, Array, Array], Array], Callable[[Array, Array, DiffusionProcess, Array, Array, Array], Array]]
Factory to create eps, score, and v functions from a given x0 function.
- Parameters:
x0_fn – The specific x0 calculation function (e.g., gmm_x0, iso_gmm_x0). It must accept (x_t, t, diffusion_process, means, specific_cov, priors).
- Returns:
A tuple (eps_fn, score_fn, v_fn) of generated functions. Each has the same signature as x0_fn, accepting (x_t, t, diffusion_process, means, specific_cov, priors).
- Return type:
Tuple[Callable[[Array, Array, DiffusionProcess, Array, Array, Array], Array], Callable[[Array, Array, DiffusionProcess, Array, Array, Array], Array], Callable[[Array, Array, DiffusionProcess, Array, Array, Array], Array]]
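A factory like this works because, for x_t = α(t) * x_0 + σ(t) * ε with Gaussian ε, the other three fields are fixed functions of the denoiser:
- E[ε | x_t] = (x_t - α(t) E[x_0 | x_t]) / σ(t).
- The score is -E[ε | x_t] / σ(t) (Tweedie's formula).
- v = α'(t) E[x_0 | x_t] + σ'(t) E[ε | x_t].

The NumPy sketch below builds such a factory. It assumes a hypothetical Schedule tuple in place of DiffusionProcess, exposing alpha, sigma, alpha_prime, and sigma_prime as callables. make_vector_field_fns is likewise a hypothetical name.

```python
from typing import Callable, NamedTuple

import numpy as np


class Schedule(NamedTuple):
    # Hypothetical stand-in for the schedule part of a DiffusionProcess.
    alpha: Callable[[float], float]
    sigma: Callable[[float], float]
    alpha_prime: Callable[[float], float]
    sigma_prime: Callable[[float], float]


def make_vector_field_fns(x0_fn):
    """Derive (eps_fn, score_fn, v_fn) from x0_fn under x_t = alpha x_0 + sigma eps."""

    def eps_fn(x_t, t, proc, means, specific_cov, priors):
        x0 = x0_fn(x_t, t, proc, means, specific_cov, priors)
        return (x_t - proc.alpha(t) * x0) / proc.sigma(t)

    def score_fn(x_t, t, proc, means, specific_cov, priors):
        # Tweedie's formula: score = -E[eps | x_t] / sigma(t).
        return -eps_fn(x_t, t, proc, means, specific_cov, priors) / proc.sigma(t)

    def v_fn(x_t, t, proc, means, specific_cov, priors):
        x0 = x0_fn(x_t, t, proc, means, specific_cov, priors)
        eps = (x_t - proc.alpha(t) * x0) / proc.sigma(t)
        return proc.alpha_prime(t) * x0 + proc.sigma_prime(t) * eps

    return eps_fn, score_fn, v_fn
```

As a sanity check, consider a single Gaussian prior N(μ, s² I). The score derived this way equals -(x_t - α μ) / (α² s² + σ²), which is the score of the Gaussian marginal of x_t.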
diffusionlab.dynamics
- class diffusionlab.dynamics.DiffusionProcess(alpha: Callable[[Array], Array], sigma: Callable[[Array], Array])
Bases: object
Base class for implementing various diffusion processes.
A diffusion process defines how data evolves over time when noise is added according to specific dynamics operating on scalar time inputs. This class provides a framework to implement diffusion processes based on a schedule defined by α(t) and σ(t).
The diffusion is parameterized by two scalar functions of scalar time t:
α(t): Controls how much of the original signal is preserved at time t.
σ(t): Controls how much noise is added at time t.
The forward process for a single data point x_0 is defined as:
x_t = α(t) * x_0 + σ(t) * ε
where:
x_0 is the original data (Array[*data_dims])
x_t is the noised data at time t (Array[*data_dims])
ε is random noise sampled from a standard Gaussian distribution (Array[*data_dims])
t is the scalar diffusion time parameter (Array[])
- alpha
Function mapping scalar time t -> scalar signal coefficient α(t).
- Type:
Callable[[Array[]], Array[]]
- sigma
Function mapping scalar time t -> scalar noise coefficient σ(t).
- Type:
Callable[[Array[]], Array[]]
- alpha_prime
Derivative of α w.r.t. scalar time t.
- Type:
Callable[[Array[]], Array[]]
- sigma_prime
Derivative of σ w.r.t. scalar time t.
- Type:
Callable[[Array[]], Array[]]
- forward(x: Array, t: Array, eps: Array) → Array
Applies the forward diffusion process to a data tensor x at time t using noise ε. Computes x_t = α(t) * x + σ(t) * ε.
- Parameters:
x (Array[*data_dims]) – The input data tensor x_0.
t (Array[]) – The scalar time parameter t.
eps (Array[*data_dims]) – The Gaussian noise tensor ε, matching the shape of x.
- Returns:
The noised data tensor x_t at time t.
- Return type:
Array[*data_dims]
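A minimal NumPy sketch of a process object with the attributes and forward method listed above. This section does not document how diffusionlab obtains alpha_prime and sigma_prime. Here they are approximated by central finite differences purely for illustration. DiffusionProcessSketch and the cosine/sine schedule in the usage lines are hypothetical.

```python
from dataclasses import dataclass, field
from typing import Callable

import numpy as np


def central_diff(f: Callable, h: float = 1e-5) -> Callable:
    # Numerical derivative, used here instead of automatic differentiation.
    return lambda t: (f(t + h) - f(t - h)) / (2.0 * h)


@dataclass
class DiffusionProcessSketch:
    alpha: Callable
    sigma: Callable
    alpha_prime: Callable = field(init=False)
    sigma_prime: Callable = field(init=False)

    def __post_init__(self):
        self.alpha_prime = central_diff(self.alpha)
        self.sigma_prime = central_diff(self.sigma)

    def forward(self, x, t, eps):
        # x_t = alpha(t) * x + sigma(t) * eps; scalar coefficients broadcast over x.
        return self.alpha(t) * x + self.sigma(t) * eps


# Example schedule (illustrative only): alpha(t) = cos(t), sigma(t) = sin(t).
proc = DiffusionProcessSketch(alpha=np.cos, sigma=np.sin)
x_t = proc.forward(np.array([1.0, 2.0]), 0.5, np.array([0.5, -0.5]))
```

Subclasses such as the ones below only need to fix alpha and sigma (and, ideally, their exact derivatives).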
- class diffusionlab.dynamics.FlowMatchingProcess
Bases: DiffusionProcess
Implements a diffusion process based on Flow Matching principles.
This process defines dynamics that linearly interpolate between the data distribution at t=0 and a noise distribution (standard Gaussian) at t=1. It uses the following scalar dynamics:
α(t) = 1 - t
σ(t) = t
Forward process: x_t = (1 - t) * x_0 + t * ε.
- alpha
Function mapping scalar time t -> scalar signal coefficient α(t). Set to 1 - t.
- Type:
Callable[[Array[]], Array[]]
- sigma
Function mapping scalar time t -> scalar noise coefficient σ(t). Set to t.
- Type:
Callable[[Array[]], Array[]]
- alpha_prime
Derivative of α w.r.t. scalar time t. Set to -1.
- Type:
Callable[[Array[]], Array[]]
- sigma_prime
Derivative of σ w.r.t. scalar time t. Set to 1.
- Type:
Callable[[Array[]], Array[]]
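Because α and σ are linear in t, each path from x_0 to ε is a straight line traversed at constant speed. The path velocity α'(t) x_0 + σ'(t) ε therefore reduces to ε - x_0 at every t. The short NumPy check below confirms both facts. The schedule is written out by hand rather than taken from FlowMatchingProcess.

```python
import numpy as np


def alpha(t):
    return 1.0 - t


def sigma(t):
    return t


def alpha_prime(t):
    return -1.0


def sigma_prime(t):
    return 1.0


rng = np.random.default_rng(0)
x0 = rng.standard_normal(3)   # a "data" point
eps = rng.standard_normal(3)  # a noise draw


def x_t(t):
    return alpha(t) * x0 + sigma(t) * eps


# Endpoints: the path starts at the data point and ends at the noise draw.
print(np.allclose(x_t(0.0), x0), np.allclose(x_t(1.0), eps))  # True True

# The path velocity alpha'(t) x0 + sigma'(t) eps equals eps - x0 at every t.
velocity = alpha_prime(0.3) * x0 + sigma_prime(0.3) * eps
print(np.allclose(velocity, eps - x0))  # True
```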
- class diffusionlab.dynamics.VarianceExplodingProcess(sigma: Callable[[Array], Array])
Bases: DiffusionProcess
Implements a Variance Exploding (VE) diffusion process.
In this process, the signal component is constant (α(t) = 1), while the noise component increases over time according to the provided σ(t) function. The variance of the noised data x_t therefore explodes as t increases.
Forward process: x_t = x_0 + σ(t) * ε.
This process uses:
α(t) = 1
σ(t) = provided by the user
- alpha
Function mapping scalar time t -> scalar signal coefficient α(t). Set to 1.
- Type:
Callable[[Array[]], Array[]]
- sigma
Function mapping scalar time t -> scalar noise coefficient σ(t). Provided by the user.
- Type:
Callable[[Array[]], Array[]]
- alpha_prime
Derivative of α w.r.t. scalar time t. Set to 0.
- Type:
Callable[[Array[]], Array[]]
- sigma_prime
Derivative of σ w.r.t. scalar time t.
- Type:
Callable[[Array[]], Array[]]
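Here σ(t) is left to the user. One common choice is a geometric interpolation, σ(t) = SIGMA_MIN * (SIGMA_MAX / SIGMA_MIN)^t, whose derivative has the closed form σ(t) * log(SIGMA_MAX / SIGMA_MIN). It is used below purely as an illustration, and SIGMA_MIN and SIGMA_MAX are assumed values, not library defaults. A NumPy sketch:

```python
import numpy as np

# Illustrative endpoints, not diffusionlab defaults.
SIGMA_MIN, SIGMA_MAX = 0.01, 50.0


def sigma(t):
    # Geometric schedule: sigma(0) = SIGMA_MIN, sigma(1) = SIGMA_MAX.
    return SIGMA_MIN * (SIGMA_MAX / SIGMA_MIN) ** t


def sigma_prime(t):
    # d/dt [c * r**t] = c * r**t * log(r).
    return sigma(t) * np.log(SIGMA_MAX / SIGMA_MIN)


def ve_forward(x0, t, eps):
    # alpha(t) = 1: the signal is never rescaled, only noise is added.
    return x0 + sigma(t) * eps


x0 = np.array([1.0, -1.0])
eps = np.array([0.3, 0.2])
for t in (0.0, 0.5, 1.0):
    print(t, ve_forward(x0, t, eps))
```

The printed states drift further from x0 as σ(t) grows, which is the "exploding" behaviour described above.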
- class diffusionlab.dynamics.VariancePreservingProcess
Bases: DiffusionProcess
Implements a Variance Preserving (VP) diffusion process, often used in DDPMs.
This process keeps the variance of the noised data x_t equal to 1 throughout the diffusion, assuming x_0 and ε are independent with unit variance. This holds because the signal and noise coefficients satisfy α(t)² + σ(t)² = 1. It uses the following scalar dynamics:
α(t) = sqrt(1 - t²)
σ(t) = t
Forward process: x_t = sqrt(1 - t²) * x_0 + t * ε.
- alpha
Function mapping scalar time t -> scalar signal coefficient α(t). Set to sqrt(1 - t²).
- Type:
Callable[[Array[]], Array[]]
- sigma
Function mapping scalar time t -> scalar noise coefficient σ(t). Set to t.
- Type:
Callable[[Array[]], Array[]]
- alpha_prime
Derivative of α w.r.t. scalar time t. Set to -t / sqrt(1 - t²).
- Type:
Callable[[Array[]], Array[]]
- sigma_prime
Derivative of σ w.r.t. scalar time t. Set to 1.
- Type:
Callable[[Array[]], Array[]]
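The unit-variance property follows from α(t)² + σ(t)² = (1 - t²) + t² = 1. If x_0 and ε are independent with unit variance, then Var[x_t] = α(t)² + σ(t)² = 1.

The NumPy check below verifies this identity on a grid and gives a Monte Carlo estimate of Var[x_t] at t = 0.5. The schedule is written out by hand, not imported from diffusionlab.

```python
import numpy as np


def alpha(t):
    return np.sqrt(1.0 - t**2)


def sigma(t):
    return t


def alpha_prime(t):
    # -t / sqrt(1 - t^2): unbounded as t -> 1, so avoid evaluating at t = 1 exactly.
    return -t / np.sqrt(1.0 - t**2)


ts = np.linspace(0.0, 0.99, 50)
# alpha(t)^2 + sigma(t)^2 = (1 - t^2) + t^2 = 1 on the whole grid.
print(np.allclose(alpha(ts) ** 2 + sigma(ts) ** 2, 1.0))  # True

# Monte Carlo: unit-variance, independent x_0 and eps give Var[x_t] close to 1.
rng = np.random.default_rng(0)
x0 = rng.standard_normal(200_000)
eps = rng.standard_normal(200_000)
x_half = alpha(0.5) * x0 + sigma(0.5) * eps
print(x_half.var())
```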
diffusionlab.losses
- class diffusionlab.losses.DiffusionLoss(diffusion_process: DiffusionProcess, vector_field_type: VectorFieldType, num_noise_draws_per_sample: int)
Bases: object
Loss function for training diffusion models.
This dataclass implements various loss functions for diffusion models based on the specified target type. The loss is computed as the mean squared error between the model’s prediction and the target, which depends on the chosen vector field type.
The loss supports different target types:
VectorFieldType.X0: Learn to predict the original clean data x_0.
VectorFieldType.EPS: Learn to predict the noise component eps.
VectorFieldType.V: Learn to predict the velocity field v.
VectorFieldType.SCORE: Not directly supported (raises ValueError).
- diffusion_process
The diffusion process defining the forward dynamics.
- Type:
DiffusionProcess
- vector_field_type
The type of target the model learns to estimate by minimizing the loss function.
- Type:
VectorFieldType
- num_noise_draws_per_sample
The number of noise draws per sample to use for the batchwise loss.
- Type:
int
- target
Function that computes the target based on the specified vector_field_type.
Signature:
(x_t: Array[*data_dims], f_x_t: Array[*data_dims], x_0: Array[*data_dims], eps: Array[*data_dims], t: Array[]) -> Array[*data_dims]
- Type:
Callable[[Array, Array, Array, Array, Array], Array]
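A sketch of how such a target function could dispatch on the vector field type. It assumes the velocity target is the time derivative of the forward process, v = α'(t) x_0 + σ'(t) ε, and uses the VP derivatives for concreteness; the local `VectorFieldType` enum and the `make_target` helper are hypothetical stand-ins, not DiffusionLab's code.

```python
import enum
import jax.numpy as jnp

class VectorFieldType(enum.Enum):
    # Local stand-in for diffusionlab.vector_fields.VectorFieldType.
    SCORE = enum.auto()
    X0 = enum.auto()
    EPS = enum.auto()
    V = enum.auto()

# Assumed VP derivatives: alpha'(t) = -t / sqrt(1 - t^2), sigma'(t) = 1.
def alpha_prime(t):
    return -t / jnp.sqrt(1.0 - t**2)

def sigma_prime(t):
    return jnp.ones_like(t)

def make_target(vector_field_type):
    """Return target(x_t, f_x_t, x_0, eps, t) for the requested type (hypothetical helper)."""
    if vector_field_type == VectorFieldType.X0:
        return lambda x_t, f_x_t, x_0, eps, t: x_0
    if vector_field_type == VectorFieldType.EPS:
        return lambda x_t, f_x_t, x_0, eps, t: eps
    if vector_field_type == VectorFieldType.V:
        # Assumed velocity: time derivative of x_t = alpha(t) * x_0 + sigma(t) * eps.
        return lambda x_t, f_x_t, x_0, eps, t: alpha_prime(t) * x_0 + sigma_prime(t) * eps
    # Mirrors the documented behaviour: SCORE is not a supported training target.
    raise ValueError(f"Unsupported target type: {vector_field_type}")

target = make_target(VectorFieldType.V)
x_0 = jnp.array([1.0, 2.0])
eps = jnp.array([0.5, -0.5])
t = jnp.array(0.6)
print(target(None, None, x_0, eps, t))  # -0.75 * x_0 + eps
```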
- loss(key: Array, vector_field: Callable[[Array, Array], Array], x_0: Array, t: Array) -> Array
Compute the average loss over multiple noise draws for a single data point and time.
This method estimates the expected loss at a given time t for a clean data sample x_0. It does this by drawing num_noise_draws_per_sample noise vectors (eps), generating the corresponding noisy samples x_t using the diffusion_process, predicting the target quantity f_x_t using the provided vector_field (vmapped internally), and then calculating the prediction_loss for each noise sample. The final loss is the average over these samples.
- Parameters:
key (Array) – The PRNG key for noise generation.
vector_field (Callable[[Array, Array], Array]) – The vector field function that takes a single noisy data sample x_t and its corresponding time t, and returns the model’s prediction f_x_t. This function is vmapped internally over the batch dimension created by num_noise_draws_per_sample. Signature: (x_t: Array[*data_dims], t: Array[]) -> Array[*data_dims].
x_0 (Array[*data_dims]) – The original clean data sample.
t (Array[]) – The scalar time parameter.
- Returns:
The scalar loss value, averaged over num_noise_draws_per_sample noise instances.
- Return type:
Array[]
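The Monte Carlo estimate described above can be sketched in a few lines of JAX. This is an illustration under stated assumptions, not the library's method: it hard-codes VP coefficients and an EPS target, averages the squared error over data dimensions, and uses a hypothetical function name `mc_eps_loss`.

```python
import jax
import jax.numpy as jnp

def alpha(t):
    return jnp.sqrt(1.0 - t**2)  # assumed VP signal coefficient

def sigma(t):
    return t  # assumed VP noise coefficient

def mc_eps_loss(key, vector_field, x_0, t, num_noise_draws_per_sample):
    """Average EPS-prediction MSE over several noise draws for one (x_0, t) pair."""
    eps = jax.random.normal(key, (num_noise_draws_per_sample, *x_0.shape))
    x_t = alpha(t) * x_0[None] + sigma(t) * eps  # broadcast x_0 over the noise draws
    # vmap the single-sample vector field over the noise-draw axis; t is shared.
    f_x_t = jax.vmap(vector_field, in_axes=(0, None))(x_t, t)
    # Per-draw MSE over the data dimensions, then the average over draws.
    per_draw = jnp.mean((f_x_t - eps) ** 2, axis=tuple(range(1, eps.ndim)))
    return jnp.mean(per_draw)

# An oracle for a point mass at x_0 recovers eps exactly, so its loss is ~0.
x_0 = jnp.array([1.0, -1.0, 0.5])
oracle = lambda x_t, t: (x_t - alpha(t) * x_0) / sigma(t)
zero_model = lambda x_t, t: jnp.zeros_like(x_t)
key = jax.random.PRNGKey(0)
t = jnp.array(0.5)
print(mc_eps_loss(key, oracle, x_0, t, 8))
print(mc_eps_loss(key, zero_model, x_0, t, 8))
```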
- prediction_loss(x_t: Array, f_x_t: Array, x_0: Array, eps: Array, t: Array) -> Array
Compute the loss given a prediction and inputs/targets.
This method calculates the mean squared error between the model’s prediction (f_x_t) and the target value determined by vector_field_type (self.target).
- Parameters:
x_t (Array[*data_dims]) – The noised data at time t.
f_x_t (Array[*data_dims]) – The model’s prediction at time t.
x_0 (Array[*data_dims]) – The original clean data.
eps (Array[*data_dims]) – The noise used to generate x_t.
t (Array[]) – The scalar time parameter.
- Returns:
The scalar loss value for the given sample.
- Return type:
Array[]
diffusionlab.samplers
- class diffusionlab.samplers.DDMSampler(diffusion_process: DiffusionProcess, vector_field: Callable[[Array, Array], Array], vector_field_type: VectorFieldType, use_stochastic_sampler: bool)
Bases:
Sampler
Class for sampling from diffusion models using the Denoising Diffusion Probabilistic Models (DDPM) or Denoising Diffusion Implicit Models (DDIM) sampling strategy.
This sampler first converts the output of vector_field, whatever its type (VectorFieldType.SCORE, VectorFieldType.X0, VectorFieldType.EPS, or VectorFieldType.V), into an equivalent x0 prediction using the convert_vector_field_type utility. It then applies the DDPM update rule (if use_stochastic_sampler is True) or the DDIM update rule (if use_stochastic_sampler is False) to this x0 prediction.
- diffusion_process
The diffusion process defining the forward dynamics.
- Type:
DiffusionProcess
- vector_field
The function predicting the vector field.
- Type:
Callable[[Array[*data_dims], Array[]], Array[*data_dims]]
- vector_field_type
The type of the vector field predicted by vector_field.
- Type:
VectorFieldType
- use_stochastic_sampler
If True, uses DDPM (stochastic); otherwise, uses DDIM (deterministic).
- Type:
bool
- sample_step
The DDPM or DDIM step function.
- Type:
Callable[[int, Array, Array, Array], Array]
- get_sample_step_function() -> Callable[[int, Array, Array, Array], Array]
Get the appropriate DDPM/DDIM sampling step function based on stochasticity.
- Returns:
The DDPM (stochastic) or DDIM (deterministic) step function, which has signature:
(idx: int, x: Array[*data_dims], zs: Array[num_steps, *data_dims], ts: Array[num_steps+1]) -> Array[*data_dims]
- Return type:
Callable[[int, Array, Array, Array], Array]
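A sketch of a deterministic DDIM step written in terms of α and σ, assuming the common η = 0 form: recover ε̂ = (x_t - α(t) x̂_0) / σ(t) from the x0 prediction, then move to the next time s with x_s = α(s) x̂_0 + σ(s) ε̂. The VP coefficients and the `make_ddim_step` name are assumptions; the step follows the (idx, x, zs, ts) signature above, and the stochastic DDPM variant, which also uses zs, is not shown.

```python
import jax.numpy as jnp

def alpha(t):
    return jnp.sqrt(1.0 - t**2)  # assumed VP signal coefficient

def sigma(t):
    return t  # assumed VP noise coefficient

def make_ddim_step(x0_predictor):
    """Build a deterministic DDIM step from an x0-predicting vector field (hypothetical helper)."""
    def step(idx, x, zs, ts):
        t, s = ts[idx], ts[idx + 1]  # current time and next (smaller) time
        x0_hat = x0_predictor(x, t)
        eps_hat = (x - alpha(t) * x0_hat) / sigma(t)
        # zs is unused: the eta = 0 DDIM update is deterministic.
        return alpha(s) * x0_hat + sigma(s) * eps_hat
    return step

# For a point mass at c the exact x0 predictor is constant, and one step moves
# x_t exactly along the forward path x_t = alpha(t) * c + sigma(t) * eps.
c = jnp.array([2.0, -1.0])
step = make_ddim_step(lambda x, t: c)
ts = jnp.array([0.8, 0.3])
eps = jnp.array([0.5, 0.5])
x_t = alpha(ts[0]) * c + sigma(ts[0]) * eps
x_s = step(0, x_t, None, ts)
print(x_s)
```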
- class diffusionlab.samplers.EulerMaruyamaSampler(diffusion_process: DiffusionProcess, vector_field: Callable[[Array, Array], Array], vector_field_type: VectorFieldType, use_stochastic_sampler: bool)
Bases:
Sampler
Class for sampling from diffusion models using the first-order Euler-Maruyama integrator for the reverse process SDE/ODE.
This sampler implements the step function based on the Euler-Maruyama discretization of the reverse SDE (if use_stochastic_sampler is True) or of the corresponding probability flow ODE (if use_stochastic_sampler is False). It supports all vector field types (VectorFieldType.SCORE, VectorFieldType.X0, VectorFieldType.EPS, VectorFieldType.V).
- diffusion_process
The diffusion process defining the forward dynamics.
- Type:
DiffusionProcess
- vector_field
The function predicting the vector field. Takes the current state x_t and time t as input.
- Type:
Callable[[Array[*data_dims], Array[]], Array[*data_dims]]
- vector_field_type
The type of the vector field predicted by vector_field.
- Type:
VectorFieldType
- use_stochastic_sampler
Whether to use a stochastic or deterministic reverse process.
- Type:
bool
- sample_step
The specific function used to perform one sampling step. Takes step index idx, current state x_t, noise array zs, and time schedule ts as input. Set during initialization based on the sampler type and use_stochastic_sampler.
- Type:
Callable[[int, Array, Array, Array], Array]
- get_sample_step_function() -> Callable[[int, Array, Array, Array], Array]
Get the appropriate Euler-Maruyama sampling step function based on the vector field type and stochasticity.
- Returns:
The specific Euler-Maruyama step function to use.
Signature:
(idx: int, x_t: Array[*data_dims], zs: Array[num_steps, *data_dims], ts: Array[num_steps+1]) -> Array[*data_dims]
- Return type:
Callable[[int, Array, Array, Array], Array]
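A sketch of one reverse-time Euler-Maruyama step driven by a score model. It assumes the forward process x_t = α(t) x_0 + σ(t) ε is written as the linear SDE dx = f(t) x dt + g(t) dW with f = α'/α and g² = 2σσ' - 2(α'/α)σ²; the reverse SDE then has drift f x - g² ∇log p_t, and the probability flow ODE has drift f x - ½ g² ∇log p_t. The VP coefficients, the score-only formulation, and the names `sde_coefficients` and `make_em_step` are illustrative assumptions, not DiffusionLab's step functions.

```python
import jax
import jax.numpy as jnp

# Assumed VP coefficients (the VariancePreservingProcess formulas).
def alpha(t):
    return jnp.sqrt(1.0 - t**2)

def sigma(t):
    return t

def alpha_prime(t):
    return -t / jnp.sqrt(1.0 - t**2)

def sigma_prime(t):
    return jnp.ones_like(t)

def sde_coefficients(t):
    """f(t) and g(t)^2 of the assumed forward SDE dx = f(t) x dt + g(t) dW."""
    f = alpha_prime(t) / alpha(t)
    g2 = 2.0 * sigma(t) * sigma_prime(t) - 2.0 * f * sigma(t) ** 2
    return f, g2

def make_em_step(score_fn, stochastic):
    """Build step(idx, x, zs, ts) that moves x from ts[idx] to ts[idx + 1]."""
    def step(idx, x, zs, ts):
        t, s = ts[idx], ts[idx + 1]
        dt = s - t  # negative: the reverse process runs backwards in time
        f, g2 = sde_coefficients(t)
        score = score_fn(x, t)
        if stochastic:
            # Reverse SDE: drift f x - g^2 * score, plus noise with variance g^2 * |dt|.
            return x + (f * x - g2 * score) * dt + jnp.sqrt(g2 * -dt) * zs[idx]
        # Probability flow ODE: drift f x - g^2 * score / 2 and no noise.
        return x + (f * x - 0.5 * g2 * score) * dt
    return step

# Point mass at c: the exact score is -(x - alpha(t) * c) / sigma(t)^2.
c = jnp.array([2.0])
score_fn = lambda x, t: -(x - alpha(t) * c) / sigma(t) ** 2
ts = jnp.linspace(0.9, 0.01, 1001)
zs = jax.random.normal(jax.random.PRNGKey(0), (1000, 1))

ode_step = make_em_step(score_fn, stochastic=False)
x_ode = jax.lax.fori_loop(0, 1000, lambda i, x: ode_step(i, x, zs, ts), jnp.array([0.5]))
print(x_ode)  # ends close to c
```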
- class diffusionlab.samplers.Sampler(diffusion_process: DiffusionProcess, vector_field: Callable[[Array, Array], Array], vector_field_type: VectorFieldType, use_stochastic_sampler: bool)
Bases:
object
Base class for sampling from diffusion models using various vector field types.
A Sampler combines a diffusion process and a vector field prediction function to generate samples from a trained diffusion model using the reverse process (denoising/sampling). The time schedule ts, for example one produced by a Scheduler, is passed to sample or sample_trajectory.
The sampler supports different vector field types (VectorFieldType.SCORE, VectorFieldType.X0, VectorFieldType.EPS, VectorFieldType.V) and can perform either stochastic or deterministic sampling, depending on the subclass implementation and the use_stochastic_sampler flag.
- diffusion_process
The diffusion process defining the forward dynamics.
- Type:
DiffusionProcess
- vector_field
The function predicting the vector field. Takes the current state x_t and time t as input.
- Type:
Callable[[Array[*data_dims], Array[]], Array[*data_dims]]
- vector_field_type
The type of the vector field predicted by vector_field.
- Type:
VectorFieldType
- use_stochastic_sampler
Whether to use a stochastic or deterministic reverse process.
- Type:
bool
- sample_step
The specific function used to perform one sampling step. Takes step index idx, current state x_t, noise array zs, and time schedule ts as input. Set during initialization based on the sampler type and use_stochastic_sampler.
- Type:
Callable[[int, Array, Array, Array], Array]
- get_sample_step_function() -> Callable[[int, Array, Array, Array], Array]
Abstract method to get the appropriate sampling step function.
Subclasses must implement this method to return the specific function used for performing one step of the reverse process, based on the sampler’s implementation details (e.g., integrator type) and the use_stochastic_sampler flag.
- Returns:
The sampling step function, which has signature:
(idx: int, x_t: Array[*data_dims], zs: Array[num_steps, *data_dims], ts: Array[num_steps+1]) -> Array[*data_dims]
- Return type:
Callable[[int, Array, Array, Array], Array]
- sample(x_init: Array, zs: Array, ts: Array) -> Array
Sample from the model using the reverse diffusion process.
This method generates a final sample by iteratively applying the sample_step function, starting from an initial state x_init and using the provided noise zs and time schedule ts.
- Parameters:
x_init (Array[*data_dims]) – The initial noisy tensor from which to initialize sampling (typically sampled from the prior distribution at ts[0]).
zs (Array[num_steps, *data_dims]) – The noise tensors used at each step for stochastic sampling. Unused for deterministic samplers.
ts (Array[num_steps+1]) – The time schedule for sampling: a sorted, decreasing array of times from t_max to t_min.
- Returns:
The generated sample at the final time ts[-1].
- Return type:
Array[*data_dims]
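The iteration described for sample can be sketched with jax.lax.fori_loop, assuming only the (idx, x, zs, ts) step signature documented above. The standalone `sample` function and the toy rescaling step are hypothetical; a real step would come from get_sample_step_function.

```python
import jax
import jax.numpy as jnp

def sample(sample_step, x_init, zs, ts):
    """Apply sample_step for idx = 0, ..., num_steps - 1 and return the final state."""
    num_steps = ts.shape[0] - 1
    body = lambda idx, x: sample_step(idx, x, zs, ts)
    return jax.lax.fori_loop(0, num_steps, body, x_init)

# Hypothetical toy step that rescales x by ts[idx + 1] / ts[idx]; the product
# telescopes, so the final state is x_init * ts[-1] / ts[0].
toy_step = lambda idx, x, zs, ts: x * ts[idx + 1] / ts[idx]
ts = jnp.linspace(1.0, 0.1, 11)
zs = jnp.zeros((10, 3))
x_init = jnp.array([1.0, 2.0, 3.0])
print(sample(toy_step, x_init, zs, ts))  # approximately [0.1, 0.2, 0.3]
```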
- sample_trajectory(x_init: Array, zs: Array, ts: Array) -> Array
Sample a trajectory from the model using the reverse diffusion process.
This method generates the entire trajectory of intermediate samples by iteratively applying the sample_step function.
- Parameters:
x_init (Array[*data_dims]) – The initial noisy tensor from which to start sampling (at time ts[0]).
zs (Array[num_steps, *data_dims]) – The noise tensors used at each step for stochastic sampling. Unused for deterministic samplers.
ts (Array[num_steps+1]) – The time schedule for sampling: a sorted, decreasing array of times from t_max to t_min.
- Returns:
The complete generated trajectory, including the initial state x_init.
- Return type:
Array[num_steps+1, *data_dims]
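A sketch of the trajectory variant using jax.lax.scan, which records every intermediate state and then prepends x_init to give the documented Array[num_steps+1, *data_dims] shape. The `sample_trajectory` function and the toy step are hypothetical illustrations that assume only the step signature above.

```python
import jax
import jax.numpy as jnp

def sample_trajectory(sample_step, x_init, zs, ts):
    """Return all states [x_init, x_1, ..., x_num_steps] stacked along axis 0."""
    num_steps = ts.shape[0] - 1

    def body(x, idx):
        x_next = sample_step(idx, x, zs, ts)
        return x_next, x_next  # carry the new state and also record it

    _, xs = jax.lax.scan(body, x_init, jnp.arange(num_steps))
    return jnp.concatenate([x_init[None], xs], axis=0)

toy_step = lambda idx, x, zs, ts: x * ts[idx + 1] / ts[idx]  # hypothetical step
ts = jnp.linspace(1.0, 0.1, 11)
zs = jnp.zeros((10, 3))
traj = sample_trajectory(toy_step, jnp.ones(3), zs, ts)
print(traj.shape)  # (11, 3)
```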
diffusionlab.schedulers
- class diffusionlab.schedulers.Scheduler
Bases:
object
Base class for time step schedulers used in diffusion, denoising, and sampling.
Allows for extensible scheduler implementations where subclasses can define their own initialization and time step generation parameters via **kwargs.
- get_ts(**ts_hparams: Any) -> Array
Generate the sequence of time steps.
This is an abstract method that must be implemented by subclasses. Subclasses should define the specific keyword arguments they expect within **ts_hparams.
- Parameters:
**ts_hparams (Dict[str, Any]) – Keyword arguments containing parameters for generating time steps.
- Returns:
A tensor containing the sequence of time steps in descending order.
- Return type:
Array
- Raises:
NotImplementedError – If the subclass does not implement this method.
KeyError – If a required parameter is missing from **ts_hparams (raised in subclasses).
- class diffusionlab.schedulers.UniformScheduler
Bases:
Scheduler
A scheduler that generates uniformly spaced time steps.
Requires t_min, t_max, and num_steps to be passed to the get_ts method as keyword arguments. The number of points generated is num_steps + 1.
- get_ts(**ts_hparams: Any) -> Array
Generate uniformly spaced time steps.
- Parameters:
**ts_hparams (Dict[str, Any]) – Keyword arguments, which must contain:
t_min (float): The minimum time value, typically close to 0.
t_max (float): The maximum time value, typically close to 1.
num_steps (int): The number of diffusion steps. The function generates num_steps + 1 time points.
- Returns:
A JAX array containing uniformly spaced time steps in descending order (from t_max to t_min).
- Return type:
Array[num_steps+1]
- Raises:
KeyError – If t_min, t_max, or num_steps is not found in ts_hparams.
AssertionError – If the t_min/t_max constraints are violated or num_steps < 1.
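A minimal sketch of a uniform scheduler following the contract above: required keys read from **ts_hparams (so a missing key raises KeyError), num_steps + 1 points, descending from t_max to t_min. The local base class is a stand-in, and the exact t_min/t_max check (here only t_min < t_max) is an assumption, since the documented constraints are not spelled out.

```python
import jax.numpy as jnp

class Scheduler:
    """Local stand-in for the base class: subclasses implement get_ts."""
    def get_ts(self, **ts_hparams):
        raise NotImplementedError

class UniformScheduler(Scheduler):
    def get_ts(self, **ts_hparams):
        t_min = ts_hparams["t_min"]  # a missing key raises KeyError
        t_max = ts_hparams["t_max"]
        num_steps = ts_hparams["num_steps"]
        # Assumed constraint checks; DiffusionLab's exact assertions may differ.
        assert t_min < t_max, "expected t_min < t_max"
        assert num_steps >= 1, "expected num_steps >= 1"
        # num_steps + 1 points, descending from t_max to t_min.
        return jnp.linspace(t_max, t_min, num_steps + 1)

ts = UniformScheduler().get_ts(t_min=0.001, t_max=0.999, num_steps=4)
print(ts)  # 5 points from 0.999 down to 0.001
```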
diffusionlab.vector_fields
- class diffusionlab.vector_fields.VectorFieldType(*values)
Bases:
Enum
Enum representing the type of a vector field. A vector field is a function that takes in x_t (Array[*data_dims]) and t (Array[]) and returns a vector of the same shape as x_t (Array[*data_dims]).
DiffusionLab supports the following vector field types:
VectorFieldType.SCORE: The score of the distribution.
VectorFieldType.X0: The denoised state.
VectorFieldType.EPS: The noise.
VectorFieldType.V: The velocity field.
- diffusionlab.vector_fields.convert_vector_field_type(x: Array, f_x: Array, alpha: Array, sigma: Array, alpha_prime: Array, sigma_prime: Array, in_type: VectorFieldType, out_type: VectorFieldType) -> Array
Converts the output of a vector field from one type to another.
- Parameters:
x (Array[*data_dims]) – The input tensor.
f_x (Array[*data_dims]) – The output of the vector field f evaluated at x.
alpha (Array[]) – A scalar representing the scale parameter.
sigma (Array[]) – A scalar representing the noise level parameter.
alpha_prime (Array[]) – A scalar representing the scale derivative parameter.
sigma_prime (Array[]) – A scalar representing the noise level derivative parameter.
in_type (VectorFieldType) – The type of the input vector field (e.g. VectorFieldType.SCORE, VectorFieldType.X0, VectorFieldType.EPS, VectorFieldType.V).
out_type (VectorFieldType) – The type of the output vector field.
- Returns:
The converted output of the vector field.
- Return type:
Array[*data_dims]
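One way such a conversion can work is to route every type through an x0 prediction. This sketch assumes x = α x_0 + σ ε, the velocity v = α' x_0 + σ' ε, and Tweedie's relation score = (α x_0 - x) / σ² = -ε / σ; the local enum and the names `to_x0` and `convert` are hypothetical, and this is not necessarily how convert_vector_field_type is written.

```python
import enum
import jax.numpy as jnp

class VectorFieldType(enum.Enum):
    # Local stand-in for diffusionlab.vector_fields.VectorFieldType.
    SCORE = enum.auto()
    X0 = enum.auto()
    EPS = enum.auto()
    V = enum.auto()

def to_x0(x, f_x, alpha, sigma, alpha_prime, sigma_prime, in_type):
    """Express a supported vector field output as an x0 prediction."""
    if in_type == VectorFieldType.X0:
        return f_x
    if in_type == VectorFieldType.EPS:
        return (x - sigma * f_x) / alpha
    if in_type == VectorFieldType.SCORE:
        # Tweedie: score = (alpha * x0 - x) / sigma**2.
        return (x + sigma**2 * f_x) / alpha
    if in_type == VectorFieldType.V:
        # Solve x = alpha*x0 + sigma*eps and v = alpha'*x0 + sigma'*eps for x0.
        return (sigma * f_x - sigma_prime * x) / (alpha_prime * sigma - alpha * sigma_prime)
    raise ValueError(in_type)

def convert(x, f_x, alpha, sigma, alpha_prime, sigma_prime, in_type, out_type):
    """Hypothetical converter: map f_x to x0, then from x0 to the requested type."""
    x0 = to_x0(x, f_x, alpha, sigma, alpha_prime, sigma_prime, in_type)
    eps = (x - alpha * x0) / sigma
    if out_type == VectorFieldType.X0:
        return x0
    if out_type == VectorFieldType.EPS:
        return eps
    if out_type == VectorFieldType.SCORE:
        return -eps / sigma
    if out_type == VectorFieldType.V:
        return alpha_prime * x0 + sigma_prime * eps
    raise ValueError(out_type)

# VP coefficients at t = 0.6: alpha = 0.8, sigma = 0.6, alpha' = -0.75, sigma' = 1.
a, s, ap, sp = jnp.array(0.8), jnp.array(0.6), jnp.array(-0.75), jnp.array(1.0)
x0 = jnp.array([1.0, -2.0])
eps = jnp.array([0.3, 0.7])
x = a * x0 + s * eps
v = ap * x0 + sp * eps
print(convert(x, v, a, s, ap, sp, VectorFieldType.V, VectorFieldType.X0))  # recovers x0
```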