Source code for diffusionlab.distributions.empirical

from dataclasses import dataclass
from typing import Iterable, Tuple

import jax
from jax import Array, numpy as jnp

from diffusionlab.dynamics import DiffusionProcess
from diffusionlab.distributions.base import Distribution
from diffusionlab.vector_fields import VectorFieldType, convert_vector_field_type


@dataclass(frozen=True)
class EmpiricalDistribution(Distribution):
    """
    An empirical distribution, i.e., the uniform distribution over a dataset.

    The probability measure is defined as:

    ``μ(A) = (1/N) * sum_{i=1}^{N} delta(x_i in A)``

    where ``x_i`` is the ith data point in the dataset and ``N`` is the number of data points.

    This class provides methods for sampling from the empirical distribution and for computing
    various vector fields (``VectorFieldType.SCORE``, ``VectorFieldType.X0``,
    ``VectorFieldType.EPS``, ``VectorFieldType.V``) related to the distribution under a given
    diffusion process.

    Attributes:
        dist_params (``Dict[str, Array]``): Dictionary containing distribution parameters
            (currently unused).
        dist_hparams (``Dict[str, Any]``): Dictionary for storing hyperparameters. It may
            contain the following keys:

            - ``labeled_data`` (``Iterable[Tuple[Array, Array]] | Iterable[Tuple[Array, None]]``):
              An iterable of data whose elements (samples) are tuples of
              ``(data batch, label batch)``. The label batch may be ``None`` if the data is
              unlabelled.
    """

    def __init__(
        self,
        labeled_data: Iterable[Tuple[Array, Array]] | Iterable[Tuple[Array, None]],
    ):
        super().__init__(
            dist_params={},
            dist_hparams={"labeled_data": labeled_data},
        )

    def sample(
        self, key: Array, num_samples: int
    ) -> Tuple[Array, Array] | Tuple[Array, None]:
        """
        Sample from the empirical distribution using reservoir sampling.

        Assumes all batches in ``labeled_data`` are consistent: either all have labels
        (``Array``) or none have labels (``None``).

        Args:
            key (``Array``): The JAX PRNG key to use for sampling.
            num_samples (``int``): The number of samples to draw.

        Returns:
            ``Tuple[Array[num_samples, *data_dims], Array[num_samples, *label_dims]] | Tuple[Array[num_samples, *data_dims], None]``:
                A tuple ``(samples, labels)`` containing the samples and the corresponding
                labels (stacked into an ``Array``), or ``(samples, None)`` if the data is
                unlabelled.
        """
        data_iterator = iter(self.dist_hparams["labeled_data"])  # Get an iterator

        # Initialize the reservoir
        reservoir_samples = []
        reservoir_labels = []  # Will store labels if present, otherwise remains empty
        items_seen = 0
        is_labeled = None  # Determined from the first batch

        for X_batch, y_batch in data_iterator:
            # Determine whether the data is labeled based on the first batch encountered
            if is_labeled is None:
                is_labeled = y_batch is not None
                if is_labeled:
                    # Basic validation for the first labeled batch
                    if (
                        not isinstance(y_batch, jnp.ndarray)
                        or y_batch.shape[0] != X_batch.shape[0]
                    ):
                        raise ValueError(
                            f"First labeled batch has inconsistent shapes. "
                            f"X shape: {X_batch.shape}, y shape: {getattr(y_batch, 'shape', 'N/A')}"
                        )
                # else: y_batch is None, so is_labeled remains False

            current_batch_size = X_batch.shape[0]

            # Reservoir sampling
            for i in range(current_batch_size):
                x = X_batch[i]
                y = y_batch[i] if is_labeled else None
                if items_seen < num_samples:
                    reservoir_samples.append(x)
                    if is_labeled:
                        reservoir_labels.append(y)
                else:
                    key, subkey = jax.random.split(key)
                    j = jax.random.randint(
                        subkey, shape=(), minval=0, maxval=items_seen + 1
                    )
                    if j < num_samples:
                        reservoir_samples[j] = x
                        if is_labeled:
                            reservoir_labels[j] = y
                items_seen += 1

        # Final checks and return
        if items_seen < num_samples:
            raise ValueError(
                f"Requested {num_samples} samples, but only {items_seen} items are available in the dataset."
            )

        # Stack samples into a single array
        stacked_samples = jnp.stack(reservoir_samples)

        # Stack labels if the data was labeled, otherwise return None
        if is_labeled:
            stacked_labels = jnp.stack(reservoir_labels)
            return stacked_samples, stacked_labels
        else:
            return stacked_samples, None
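
    # Note on ``sample`` above: it performs single-pass reservoir sampling (Algorithm R).
    # The first ``num_samples`` items fill the reservoir; each later item replaces a random
    # reservoir slot with probability ``num_samples / (items_seen + 1)``, so every item in
    # the stream ends up in the final reservoir with equal probability. This allows drawing
    # a uniform subset from an iterable dataset without knowing its total length in advance.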

    def score(
        self,
        x_t: Array,
        t: Array,
        diffusion_process: DiffusionProcess,
    ) -> Array:
        """
        Computes the score function (``∇_x log p_t(x)``) of the empirical distribution at time
        ``t``, given the noisy state ``x_t`` and the diffusion process.

        Args:
            x_t (``Array[*data_dims]``): The noisy state tensor at time ``t``.
            t (``Array[]``): The time tensor.
            diffusion_process (``DiffusionProcess``): The diffusion process.

        Returns:
            ``Array[*data_dims]``: The score of the empirical distribution at ``(x_t, t)``.
        """
        x0_x_t = self.x0(x_t, t, diffusion_process)
        alpha_t = diffusion_process.alpha(t)
        sigma_t = diffusion_process.sigma(t)
        alpha_prime_t = diffusion_process.alpha_prime(t)
        sigma_prime_t = diffusion_process.sigma_prime(t)
        score_x_t = convert_vector_field_type(
            x_t,
            x0_x_t,
            alpha_t,
            sigma_t,
            alpha_prime_t,
            sigma_prime_t,
            VectorFieldType.X0,
            VectorFieldType.SCORE,
        )
        return score_x_t
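
    # Note on ``score`` above: under the forward model x_t = alpha_t * x_0 + sigma_t * eps
    # (the convention implied by the Gaussian weights in ``x0`` below), the X0 -> SCORE
    # conversion is Tweedie's formula,
    #     score(x_t, t) = (alpha_t * E[x_0 | x_t] - x_t) / sigma_t**2.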

    def x0(
        self,
        x_t: Array,
        t: Array,
        diffusion_process: DiffusionProcess,
    ) -> Array:
        """
        Computes the denoiser ``E[x_0 | x_t]`` for an empirical distribution w.r.t. a given
        diffusion process.

        This method computes the denoiser by performing a weighted average of the dataset
        samples, where the weights are determined by the likelihood of ``x_t`` given each
        sample.

        Args:
            x_t (``Array[*data_dims]``): The input tensor.
            t (``Array[]``): The time tensor.
            diffusion_process (``DiffusionProcess``): The diffusion process.

        Returns:
            ``Array[*data_dims]``: The prediction of ``x_0``.
        """
        data = self.dist_hparams["labeled_data"]
        alpha_t = diffusion_process.alpha(t)
        sigma_t = diffusion_process.sigma(t)

        # Initialize accumulators for a numerically stable online softmax
        max_exponent = -jnp.inf
        weighted_sum_x0 = jnp.zeros_like(x_t)
        sum_weights = jnp.zeros(())

        for X_batch, _ in data:  # The label batch is unused
            squared_dists = jax.vmap(lambda x: jnp.sum((x_t - alpha_t * x) ** 2))(
                X_batch
            )  # (batch_size, )
            exponents = -squared_dists / (2 * sigma_t**2)  # (batch_size, )

            # Rebalance by the running maximum exponent
            current_max_exponent = jnp.max(exponents)
            new_max_exponent = jnp.maximum(max_exponent, current_max_exponent)

            # Rescale the previous sums if the maximum exponent increased
            rescale_factor = jnp.exp(max_exponent - new_max_exponent)
            sum_weights = sum_weights * rescale_factor
            weighted_sum_x0 = weighted_sum_x0 * rescale_factor  # (*data_dims)

            # Compute the current batch weights, scaled by the new maximum exponent
            current_weights = jnp.exp(exponents - new_max_exponent)  # (batch_size, )

            # Update the running sums
            sum_weights = sum_weights + jnp.sum(current_weights)  # ()
            weighted_sum_x0 = weighted_sum_x0 + jnp.sum(
                jax.vmap(lambda xi, wi: xi * wi)(
                    X_batch, current_weights
                ),  # (batch_size, *data_dims)
                axis=0,  # Sum over the batch dim -> (*data_dims)
            )

            # Update the overall maximum exponent
            max_exponent = new_max_exponent

        # Final computation, with division-by-zero protection
        x0_hat = jnp.where(
            sum_weights == 0,
            jnp.zeros_like(weighted_sum_x0),
            weighted_sum_x0 / sum_weights,
        )
        return x0_hat
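
    # Note on ``x0`` above: the loop is a streaming (numerically stable) softmax, so the
    # result equals the posterior mean under a uniform prior over the dataset,
    #     E[x_0 | x_t] = sum_i w_i * x_i,
    #     w_i = softmax_i( -||x_t - alpha_t * x_i||^2 / (2 * sigma_t^2) ),
    # i.e., a weighted average of the data points with weights proportional to the Gaussian
    # likelihood N(x_t; alpha_t * x_i, sigma_t^2 * I) of each data point.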

    def eps(
        self,
        x_t: Array,
        t: Array,
        diffusion_process: DiffusionProcess,
    ) -> Array:
        """
        Computes the noise field ``eps(x_t, t)`` for an empirical distribution w.r.t. a given
        diffusion process.

        Args:
            x_t (``Array[*data_dims]``): The input tensor.
            t (``Array[]``): The time tensor.
            diffusion_process (``DiffusionProcess``): The diffusion process.

        Returns:
            ``Array[*data_dims]``: The noise field at ``(x_t, t)``.
        """
        x0_x_t = self.x0(x_t, t, diffusion_process)
        alpha_t = diffusion_process.alpha(t)
        sigma_t = diffusion_process.sigma(t)
        alpha_prime_t = diffusion_process.alpha_prime(t)
        sigma_prime_t = diffusion_process.sigma_prime(t)
        eps_x_t = convert_vector_field_type(
            x_t,
            x0_x_t,
            alpha_t,
            sigma_t,
            alpha_prime_t,
            sigma_prime_t,
            VectorFieldType.X0,
            VectorFieldType.EPS,
        )
        return eps_x_t
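
    # Note on ``eps`` above: under the same convention, the X0 -> EPS conversion amounts to
    #     eps(x_t, t) = (x_t - alpha_t * E[x_0 | x_t]) / sigma_t,
    # the conditional expectation of the noise that produced ``x_t``.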

    def v(
        self,
        x_t: Array,
        t: Array,
        diffusion_process: DiffusionProcess,
    ) -> Array:
        """
        Computes the velocity field ``v(x_t, t)`` for an empirical distribution w.r.t. a given
        diffusion process.

        Args:
            x_t (``Array[*data_dims]``): The input tensor.
            t (``Array[]``): The time tensor.
            diffusion_process (``DiffusionProcess``): The diffusion process.

        Returns:
            ``Array[*data_dims]``: The velocity field at ``(x_t, t)``.
        """
        x0_x_t = self.x0(x_t, t, diffusion_process)
        alpha_t = diffusion_process.alpha(t)
        sigma_t = diffusion_process.sigma(t)
        alpha_prime_t = diffusion_process.alpha_prime(t)
        sigma_prime_t = diffusion_process.sigma_prime(t)
        v_x_t = convert_vector_field_type(
            x_t,
            x0_x_t,
            alpha_t,
            sigma_t,
            alpha_prime_t,
            sigma_prime_t,
            VectorFieldType.X0,
            VectorFieldType.V,
        )
        return v_x_t
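

# Minimal usage sketch (illustrative only, not part of the library): build an empirical
# distribution over a small unlabelled dataset and draw uniform samples from it. Given a
# ``DiffusionProcess`` from ``diffusionlab.dynamics``, the vector fields can then be
# evaluated as ``dist.x0(x_t, t, process)``, ``dist.score(...)``, and so on.
if __name__ == "__main__":
    key = jax.random.PRNGKey(0)
    key, data_key = jax.random.split(key)
    X = jax.random.normal(data_key, (100, 2))  # 100 unlabelled 2-D data points
    batches = [(X[i : i + 10], None) for i in range(0, 100, 10)]  # 10 batches of 10

    dist = EmpiricalDistribution(batches)

    key, sample_key = jax.random.split(key)
    samples, labels = dist.sample(sample_key, num_samples=16)
    print(samples.shape, labels)  # (16, 2) None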