"""Time step schedulers for diffusion, denoising, and sampling (module ``diffusionlab.schedulers``)."""

from dataclasses import dataclass
from typing import Any
from jax import Array, numpy as jnp


@dataclass(frozen=True)
class Scheduler:
    """
    Abstract base for diffusion time step schedulers.

    A scheduler turns a set of keyword hyperparameters into the sequence of
    time steps used for diffusion, denoising, and sampling. The base class
    declares no fields; each subclass decides which keyword arguments its
    ``get_ts`` consumes, which keeps the interface open to new schedulers.
    """

    def get_ts(self, **ts_hparams: Any) -> Array:
        """
        Produce the schedule of time steps (to be overridden).

        Subclasses must override this method and document the keys they
        read from ``**ts_hparams``.

        Args:
            **ts_hparams (``Dict[str, Any]``): Scheduler-specific keyword
                parameters controlling how the time steps are generated.

        Returns:
            ``Array``: The time steps, ordered from largest to smallest.

        Raises:
            NotImplementedError: Always, when called on this base class.
            KeyError: (in subclasses) when a required key is absent from
                ``**ts_hparams``.
        """
        # No schedule exists at this level; concrete subclasses supply one.
        raise NotImplementedError
@dataclass(frozen=True)
class UniformScheduler(Scheduler):
    """
    A scheduler that generates uniformly spaced time steps.

    Requires ``t_min``, ``t_max``, and ``num_steps`` to be passed to the
    ``get_ts`` method via keyword arguments. The number of points generated
    will be ``num_steps + 1``.
    """

    def get_ts(self, **ts_hparams: Any) -> Array:
        """
        Generate uniformly spaced time steps.

        Args:
            **ts_hparams (``Dict[str, Any]``): Keyword arguments must contain

                - ``t_min`` (``float``): The minimum time value, typically
                  close to 0.
                - ``t_max`` (``float``): The maximum time value, typically
                  close to 1.
                - ``num_steps`` (``int``): The number of diffusion steps. The
                  function will generate ``num_steps + 1`` time points.

        Returns:
            ``Array[num_steps+1]``: A JAX array containing uniformly spaced
            time steps in descending order (from ``t_max`` to ``t_min``).

        Raises:
            KeyError: If ``t_min``, ``t_max``, or ``num_steps`` is not found
                in ``ts_hparams``.
            AssertionError: If ``t_min`` or ``t_max`` lies outside
                ``[0, 1]``, if ``t_min > t_max``, or if ``num_steps`` < 1.
                These checks are explicit ``raise`` statements, so they are
                still enforced when Python runs with ``-O``.
        """
        try:
            t_min = ts_hparams["t_min"]
            t_max = ts_hparams["t_max"]
            num_steps = ts_hparams["num_steps"]
        except KeyError as e:
            raise KeyError(
                f"Missing required parameter for UniformScheduler.get_ts: {e}"
            ) from e

        # Explicit raises rather than ``assert``: asserts are stripped under
        # ``python -O``, which would silently accept invalid hyperparameters.
        # AssertionError is kept because it is the documented exception type.
        if not (0 <= t_min <= 1 and 0 <= t_max <= 1):
            raise AssertionError("t_min and t_max must be in the range [0, 1]")
        if not t_min <= t_max:
            raise AssertionError("t_min must be less than or equal to t_max")
        if not num_steps >= 1:
            raise AssertionError("num_steps must be at least 1")

        # linspace yields ascending points from t_min to t_max (both endpoints
        # included); reverse so the schedule runs from t_max down to t_min.
        ts = jnp.linspace(t_min, t_max, num_steps + 1)[::-1]
        return ts