class VectorFieldType(enum.Enum):
    """
    Enum representing the type of a vector field.
    A vector field is a function that takes in ``x_t`` (``Array[*data_dims]``) and ``t`` (``Array[]``)
    and returns a vector of the same shape as ``x_t`` (``Array[*data_dims]``).

    DiffusionLab supports the following vector field types:

    - ``VectorFieldType.SCORE``: The score of the distribution of ``x_t``, i.e. the gradient of its log-density.
    - ``VectorFieldType.X0``: The denoised state (an estimate of the clean data ``x_0``).
    - ``VectorFieldType.EPS``: The noise (an estimate of the noise that was added to form ``x_t``).
    - ``VectorFieldType.V``: The velocity field (the expected time derivative of ``x_t`` given ``x_t``).
    """

    SCORE = enum.auto()
    X0 = enum.auto()
    EPS = enum.auto()
    V = enum.auto()
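To make the four members concrete: for one clean sample `x0` and one noise draw `eps`, under the data model `x_t = alpha * x0 + sigma * eps` used by `convert_vector_field_type`, each type corresponds to a different regression target. The sketch below is illustrative only. It assumes a linear schedule `alpha(t) = 1 - t`, `sigma(t) = t` (not something the enum prescribes), uses plain Python floats in place of arrays, and `regression_targets` is a hypothetical helper rather than part of DiffusionLab. The `SCORE` entry is the conditional score of `p(x_t | x0)`. Its expectation over `x0` given `x_t` is the marginal score.

```python
def regression_targets(x0, eps, t):
    """Return x_t and the per-type targets for a single sample (hypothetical helper)."""
    # Assumed schedule for illustration only: alpha(t) = 1 - t, sigma(t) = t,
    # so alpha'(t) = -1 and sigma'(t) = 1.
    alpha, sigma = 1.0 - t, t
    alpha_prime, sigma_prime = -1.0, 1.0
    x_t = alpha * x0 + sigma * eps
    targets = {
        "X0": x0,                                   # the clean (denoised) state
        "EPS": eps,                                 # the noise that was added
        "V": alpha_prime * x0 + sigma_prime * eps,  # d x_t / dt along this sample
        "SCORE": -eps / sigma,                      # grad_x log p(x_t | x0)
    }
    return x_t, targets


x_t, targets = regression_targets(x0=2.0, eps=-0.5, t=0.3)
```

Averaging each target over `x0` given `x_t` yields `E[x0 | x_t]`, `E[eps | x_t]`, `E[v | x_t]` and the marginal score. Under the standard regression objectives, these are what a network of the corresponding type estimates.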
def convert_vector_field_type(
    x: Array,
    f_x: Array,
    alpha: Array,
    sigma: Array,
    alpha_prime: Array,
    sigma_prime: Array,
    in_type: VectorFieldType,
    out_type: VectorFieldType,
) -> Array:
    """
    Converts the output of a vector field from one type to another.

    Arguments:
        x (``Array[*data_dims]``): The input tensor (the noisy state at which the vector field is evaluated).
        f_x (``Array[*data_dims]``): The output of the vector field f evaluated at x.
        alpha (``Array[]``): A scalar representing the scale parameter.
        sigma (``Array[]``): A scalar representing the noise level parameter.
        alpha_prime (``Array[]``): A scalar representing the scale derivative parameter.
        sigma_prime (``Array[]``): A scalar representing the noise level derivative parameter.
        in_type (``VectorFieldType``): The type of the input vector field
            (e.g. ``VectorFieldType.SCORE``, ``VectorFieldType.X0``, ``VectorFieldType.EPS``, ``VectorFieldType.V``).
        out_type (``VectorFieldType``): The type of the output vector field.

    Returns:
        ``Array[*data_dims]``: The converted output of the vector field.
    """
    """
    Derivation:
    ----------------------------
    Define certain quantities:

        alpha_r = alpha' / alpha
        sigma_r = sigma' / sigma
        diff_r  = sigma_r - alpha_r

    and note that diff_r > 0 (so the divisions by diff_r below are well defined)
    since alpha' < 0 and all other terms are > 0.

    Under the data model

        (1) x := alpha * x0 + sigma * eps

    it holds that

        (2) x = alpha * E[x0 | x] + sigma * E[eps | x]

    Therefore

        (3) E[x0 | x] = (x - sigma * E[eps | x]) / alpha
        (4) E[eps | x] = (x - alpha * E[x0 | x]) / sigma

    Furthermore, from (1) it holds that

        (5) v := x' = alpha' * x0 + sigma' * eps,

    or in particular

        (6) E[v | x] = alpha' * E[x0 | x] + sigma' * E[eps | x]

    Substituting (3) into (6) gives

        (7) E[v | x] = alpha_r * (x - sigma * E[eps | x]) + sigma' * E[eps | x]
            => E[v | x] = alpha'/alpha * x + (sigma' - sigma * alpha'/alpha) * E[eps | x]
            => E[v | x] = alpha'/alpha * x + sigma * (sigma'/sigma - alpha'/alpha) * E[eps | x]
            => E[v | x] = alpha_r * x + sigma * diff_r * E[eps | x]

    so that

        (8) E[eps | x] = (E[v | x] - alpha_r * x) / (sigma * diff_r)

    and, similarly, substituting (4) into (6) gives

        (9) E[v | x] = alpha' * E[x0 | x] + sigma'/sigma * (x - alpha * E[x0 | x])
            => E[v | x] = sigma'/sigma * x + (alpha' - alpha * sigma'/sigma) * E[x0 | x]
            => E[v | x] = sigma'/sigma * x + alpha * (alpha'/alpha - sigma'/sigma) * E[x0 | x]
            => E[v | x] = sigma_r * x - alpha * diff_r * E[x0 | x]

    so that

        (10) E[x0 | x] = (sigma_r * x - E[v | x]) / (alpha * diff_r)

    To connect the score function to the other types, we use Tweedie's formula,
    where score(x, alpha, sigma) denotes the gradient of the log-density of x:

        (11) alpha * E[x0 | x] = x + sigma^2 * score(x, alpha, sigma)

    Therefore, from (11):

        (12) E[x0 | x] = (x + sigma^2 * score(x, alpha, sigma)) / alpha

    From (12):

        (13) score(x, alpha, sigma) = (alpha * E[x0 | x] - x) / sigma^2

    From (11) and (4):

        (14) E[eps | x] = -sigma * score(x, alpha, sigma)

    From (14):

        (15) score(x, alpha, sigma) = -E[eps | x] / sigma

    From (14) and (7):

        (16) E[v | x] = alpha_r * x - sigma^2 * diff_r * score(x, alpha, sigma)

    From (16):

        (17) score(x, alpha, sigma) = (alpha_r * x - E[v | x]) / (sigma^2 * diff_r)
    """
    # alpha_r, sigma_r and diff_r from the derivation above.
    alpha_ratio = alpha_prime / alpha
    sigma_ratio = sigma_prime / sigma
    ratio_diff = sigma_ratio - alpha_ratio

    # Identity conversion when in_type == out_type.
    converted_fx = f_x

    if in_type == VectorFieldType.SCORE:
        if out_type == VectorFieldType.X0:
            converted_fx = (x + sigma**2 * f_x) / alpha  # From equation (12)
        elif out_type == VectorFieldType.EPS:
            converted_fx = -sigma * f_x  # From equation (14)
        elif out_type == VectorFieldType.V:
            converted_fx = alpha_ratio * x - sigma**2 * ratio_diff * f_x  # From equation (16)

    elif in_type == VectorFieldType.X0:
        if out_type == VectorFieldType.SCORE:
            converted_fx = (alpha * f_x - x) / sigma**2  # From equation (13)
        elif out_type == VectorFieldType.EPS:
            converted_fx = (x - alpha * f_x) / sigma  # From equation (4)
        elif out_type == VectorFieldType.V:
            converted_fx = sigma_ratio * x - alpha * ratio_diff * f_x  # From equation (9)

    elif in_type == VectorFieldType.EPS:
        if out_type == VectorFieldType.SCORE:
            converted_fx = -f_x / sigma  # From equation (15)
        elif out_type == VectorFieldType.X0:
            converted_fx = (x - sigma * f_x) / alpha  # From equation (3)
        elif out_type == VectorFieldType.V:
            converted_fx = alpha_ratio * x + sigma * ratio_diff * f_x  # From equation (7)

    elif in_type == VectorFieldType.V:
        if out_type == VectorFieldType.SCORE:
            converted_fx = (alpha_ratio * x - f_x) / (sigma**2 * ratio_diff)  # From equation (17)
        elif out_type == VectorFieldType.X0:
            converted_fx = (sigma_ratio * x - f_x) / (alpha * ratio_diff)  # From equation (10)
        elif out_type == VectorFieldType.EPS:
            converted_fx = (f_x - alpha_ratio * x) / (sigma * ratio_diff)  # From equation (8)

    return converted_fx
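The twelve direct formulas in `convert_vector_field_type` can be cross-checked by reaching every conversion along a different route. The following minimal, self-contained sketch uses plain Python floats and imports neither DiffusionLab nor JAX. It first recovers the pair `(E[x0 | x], E[eps | x])` from any input type using equations (3), (10), (12) and (4). It then maps that pair to any output type using equations (6) and (15). The names `Schedule`, `to_x0_eps`, `from_x0_eps` and `convert` are hypothetical helpers for illustration, not DiffusionLab's API. The schedule values (`alpha = 0.7`, `sigma = 0.3`, `alpha' = -1`, `sigma' = 1`, i.e. `alpha(t) = 1 - t`, `sigma(t) = t` at `t = 0.3`) are an arbitrary assumption.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class Schedule:
    """Hypothetical container for alpha, sigma and their time derivatives at one t."""
    alpha: float
    sigma: float
    alpha_prime: float
    sigma_prime: float


def to_x0_eps(x, f_x, s, in_type):
    """Recover (E[x0 | x], E[eps | x]) from a prediction of type in_type."""
    if in_type == "X0":
        x0 = f_x
    elif in_type == "EPS":
        x0 = (x - s.sigma * f_x) / s.alpha  # equation (3)
    elif in_type == "SCORE":
        x0 = (x + s.sigma**2 * f_x) / s.alpha  # equation (12)
    elif in_type == "V":
        sigma_r = s.sigma_prime / s.sigma
        diff_r = sigma_r - s.alpha_prime / s.alpha
        x0 = (sigma_r * x - f_x) / (s.alpha * diff_r)  # equation (10)
    else:
        raise ValueError(f"unknown vector field type: {in_type}")
    eps = (x - s.alpha * x0) / s.sigma  # equation (4)
    return x0, eps


def from_x0_eps(x0, eps, s, out_type):
    """Map (E[x0 | x], E[eps | x]) to a prediction of type out_type."""
    if out_type == "X0":
        return x0
    if out_type == "EPS":
        return eps
    if out_type == "SCORE":
        return -eps / s.sigma  # equation (15)
    if out_type == "V":
        return s.alpha_prime * x0 + s.sigma_prime * eps  # equation (6)
    raise ValueError(f"unknown vector field type: {out_type}")


def convert(x, f_x, s, in_type, out_type):
    """Convert f_x from in_type to out_type by routing through the (x0, eps) pair."""
    x0, eps = to_x0_eps(x, f_x, s, in_type)
    return from_x0_eps(x0, eps, s, out_type)


# Any (x0, eps) consistent with (1) satisfies (2) exactly, so a single clean
# sample can stand in for the conditional expectations in this check.
s = Schedule(alpha=0.7, sigma=0.3, alpha_prime=-1.0, sigma_prime=1.0)
x0, eps = 2.0, -0.5
x = s.alpha * x0 + s.sigma * eps
truth = {
    "X0": x0,
    "EPS": eps,
    "SCORE": -eps / s.sigma,
    "V": s.alpha_prime * x0 + s.sigma_prime * eps,
}
# Every (in_type, out_type) pair, including the identity, must reproduce the truth.
for in_type, f_x in truth.items():
    for out_type, expected in truth.items():
        got = convert(x, f_x, s, in_type, out_type)
        assert math.isclose(got, expected, rel_tol=1e-9), (in_type, out_type, got)
```

Routing through `(x0, eps)` needs only four "in" and four "out" formulas instead of twelve pairwise ones, which makes the relations easy to audit. The direct pairwise formulas avoid computing the intermediate pair.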