grl.numerical_methods

ODE

class grl.numerical_methods.ODE(drift=None)
Overview:

Base class for ordinary differential equations. The ODE is defined as:

\[dx = f(x, t)dt\]

where f(x, t) is the drift term.

Interfaces:

__init__

__init__(drift=None)

SDE

class grl.numerical_methods.SDE(drift=None, diffusion=None)
Overview:

Base class for stochastic differential equations. The SDE is defined as:

\[dx = f(x, t)dt + g(x, t)dW\]

where f(x, t) is the drift term, g(x, t) is the diffusion term, and dW is the Wiener process.
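To show how the drift, the diffusion, and the Wiener increment combine, here is a minimal Euler–Maruyama sketch for a scalar SDE. It is plain Python, the helper name `euler_maruyama` is hypothetical, and it is not the grl or torchsde implementation. Each step adds \(f(x,t)\,\Delta t\) plus \(g(x,t)\) times a Gaussian increment with variance \(\Delta t\).

```python
import math
import random
from typing import Callable, List


def euler_maruyama(
    drift: Callable[[float, float], float],
    diffusion: Callable[[float, float], float],
    x0: float,
    t0: float,
    t1: float,
    n_steps: int,
    rng: random.Random,
) -> List[float]:
    """Simulate dx = f(x, t) dt + g(x, t) dW on [t0, t1] with a fixed step."""
    dt = (t1 - t0) / n_steps
    xs = [x0]
    x, t = x0, t0
    for _ in range(n_steps):
        dw = rng.gauss(0.0, math.sqrt(dt))  # Wiener increment, dW ~ N(0, dt)
        x = x + drift(x, t) * dt + diffusion(x, t) * dw
        t = t + dt
        xs.append(x)
    return xs


# Ornstein-Uhlenbeck example: f(x, t) = -x, g(x, t) = 0.5
path = euler_maruyama(lambda x, t: -x, lambda x, t: 0.5, x0=1.0, t0=0.0,
                      t1=1.0, n_steps=100, rng=random.Random(0))
print(len(path))  # 101: the initial state plus one state per step
```

With \(g=0\) the update reduces to the explicit Euler method for the ODE \(\mathrm{d}x=f(x,t)\mathrm{d}t\).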

Interfaces:

__init__

__init__(drift=None, diffusion=None)

DPMSolver

class grl.numerical_methods.DPMSolver(order, device, atol=1e-05, rtol=1e-05, steps=None, type='dpm_solver', method='singlestep', solver_type='dpm_solver', skip_type='time_uniform', denoise=False)
Overview:

The DPM-Solver for sampling from the diffusion process.

Interfaces:

__init__, integrate

__init__(order, device, atol=1e-05, rtol=1e-05, steps=None, type='dpm_solver', method='singlestep', solver_type='dpm_solver', skip_type='time_uniform', denoise=False)
Overview:

Initialize the DPM-Solver.

Parameters:
  • order (int) – The order of the DPM-Solver, which should be 1, 2, or 3.

  • device (str) – The device for the computation.

  • denoise (bool) – Whether to denoise at the final step.

  • atol (float) – The absolute tolerance for the adaptive solver.

  • rtol (float) – The relative tolerance for the adaptive solver.

  • steps (int) – The total number of function evaluations (NFE).

  • type (str) – The type for the DPM-Solver, which should be ‘dpm_solver’ or ‘dpm_solver++’.

  • method (str) – The method for the DPM-Solver, which should be ‘singlestep’, ‘multistep’, ‘singlestep_fixed’, or ‘adaptive’.

  • solver_type (str) – The type for the high-order solvers, which should be ‘dpm_solver’ or ‘taylor’. The choice has only a slight effect on performance; we recommend the ‘dpm_solver’ type.

  • skip_type (str) – The type for the spacing of the time steps, which should be ‘logSNR’, ‘time_uniform’, or ‘time_quadratic’.
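The three skip_type values decide where the step times fall between the start time \(t_T\) (noise) and the end time \(t_0\) (data). The sketch below is modeled on the spacing rules of the original DPM-Solver reference implementation, assuming a linear VP SDE with \(\beta_0=0.1\), \(\beta_1=20\) and \(t_0=10^{-3}\). These values and all function names are assumptions for illustration, not grl's API.

```python
import math
from typing import List

BETA_0, BETA_1 = 0.1, 20.0  # assumed linear VP SDE parameters


def half_log_snr(t: float) -> float:
    # lambda(t) = log s(t) - log sigma(t) = -0.5 * log(exp(B(t)) - 1), B = integral of beta
    big_b = BETA_0 * t + 0.5 * (BETA_1 - BETA_0) * t ** 2
    return -0.5 * math.log(math.expm1(big_b))


def inverse_half_log_snr(lam: float) -> float:
    # t(lambda) for the linear VP SDE, in its cancellation-free form
    log_term = math.log1p(math.exp(-2.0 * lam))
    return 2.0 * log_term / (math.sqrt(BETA_0 ** 2 + 2.0 * (BETA_1 - BETA_0) * log_term) + BETA_0)


def time_steps(skip_type: str, t_T: float, t_0: float, n: int) -> List[float]:
    """Return n + 1 decreasing times from t_T to t_0."""
    if skip_type == "time_uniform":
        return [t_T + (t_0 - t_T) * i / n for i in range(n + 1)]
    if skip_type == "time_quadratic":
        # uniform in sqrt(t), then squared, so steps are denser near t_0
        a, b = math.sqrt(t_T), math.sqrt(t_0)
        return [(a + (b - a) * i / n) ** 2 for i in range(n + 1)]
    if skip_type == "logSNR":
        # uniform in the half-logSNR, mapped back to time
        l_T, l_0 = half_log_snr(t_T), half_log_snr(t_0)
        return [inverse_half_log_snr(l_T + (l_0 - l_T) * i / n) for i in range(n + 1)]
    raise ValueError(f"unknown skip_type: {skip_type}")


for kind in ("time_uniform", "time_quadratic", "logSNR"):
    print(kind, [round(t, 4) for t in time_steps(kind, 1.0, 1e-3, 5)])
```

For this schedule both ‘time_quadratic’ and ‘logSNR’ place more steps near the data end than ‘time_uniform’ does.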

integrate(diffusion_process, noise_function, data_prediction_function, x, condition=None, steps=None, save_intermediate=False)
Overview:

Integrate the diffusion process by the DPM-Solver.

Parameters:
  • diffusion_process (DiffusionProcess) – The diffusion process.

  • noise_function (Callable) – The noise prediction model.

  • data_prediction_function (Callable) – The data prediction model.

  • x (Union[torch.Tensor, TensorDict]) – The initial value at time t_start.

  • condition (Union[torch.Tensor, TensorDict]) – The condition for the data prediction model.

  • steps (int) – The total number of function evaluations (NFE).

  • save_intermediate (bool) – If true, also return the intermediate model values.

Returns:

The approximated solution at time t_end.

Return type:

x_end (torch.Tensor)
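For order=1 the DPM-Solver update has a closed form: with \(h=\lambda(t_i)-\lambda(t_{i-1})\), \(x_{t_i}=\frac{s(t_i)}{s(t_{i-1})}x_{t_{i-1}}-\sigma(t_i)(e^{h}-1)\hat{\epsilon}(x_{t_{i-1}},t_{i-1})\), where \(\hat{\epsilon}\) is the noise prediction. The sketch below applies it with uniform time steps on a linear VP SDE (assumed \(\beta_0=0.1\), \(\beta_1=20\)) and a toy data distribution concentrated at a single point C, for which the exact noise prediction is known. Helper names are hypothetical, and this is not grl's integrate.

```python
import math
from typing import Callable

BETA_0, BETA_1 = 0.1, 20.0  # assumed linear VP SDE parameters


def scale(t: float) -> float:
    # s(t) = exp(-0.5 * integral of beta over [0, t])
    return math.exp(-0.5 * (BETA_0 * t + 0.5 * (BETA_1 - BETA_0) * t ** 2))


def std(t: float) -> float:
    # sigma(t) = sqrt(1 - s(t)^2) for a VP SDE
    return math.sqrt(1.0 - scale(t) ** 2)


def half_log_snr(t: float) -> float:
    # lambda(t) = log s(t) - log sigma(t)
    return math.log(scale(t)) - math.log(std(t))


def dpm_solver_1(noise_fn: Callable[[float, float], float], x: float,
                 t_start: float, t_end: float, steps: int) -> float:
    """First-order DPM-Solver update with uniform time steps."""
    for i in range(steps):
        t_prev = t_start + (t_end - t_start) * i / steps
        t_next = t_start + (t_end - t_start) * (i + 1) / steps
        h = half_log_snr(t_next) - half_log_snr(t_prev)
        eps = noise_fn(x, t_prev)  # noise prediction at the current state
        x = scale(t_next) / scale(t_prev) * x - std(t_next) * math.expm1(h) * eps
    return x


# Toy data distribution: every sample equals C, so the exact noise prediction is known.
C = 2.0


def exact_noise(x: float, t: float) -> float:
    return (x - scale(t) * C) / std(t)


x_T = 0.7  # arbitrary starting state at t = 1
x_end = dpm_solver_1(exact_noise, x_T, 1.0, 1e-3, steps=10)
print(x_end)  # close to C = 2.0
```

With the exact noise prediction, each step maps \(x-s(t_{i-1})C\) to \(\frac{\sigma(t_i)}{\sigma(t_{i-1})}(x-s(t_{i-1})C)\), so in this toy case the result does not depend on the number of steps.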

ODESolver

class grl.numerical_methods.ODESolver(ode_solver='euler', dt=0.01, atol=1e-05, rtol=1e-05, library='torchdyn', **kwargs)
Overview:

The ODE solver class.

Interfaces:

__init__, integrate

__init__(ode_solver='euler', dt=0.01, atol=1e-05, rtol=1e-05, library='torchdyn', **kwargs)
Overview:

Initialize the ODE solver using the torchdiffeq, torchdyn, or torchode library.

Parameters:
  • ode_solver (str) – The ODE solver to use.

  • dt (float) – The time step.

  • atol (float) – The absolute tolerance.

  • rtol (float) – The relative tolerance.

  • library (str) – The library to use for the ODE solver. Currently, it supports ‘torchdiffeq’, ‘torchdyn’ and ‘torchode’.

  • **kwargs – Additional arguments for the ODE solver.

integrate(drift, x0, t_span, **kwargs)
Overview:

Integrate the ODE.

Parameters:
  • drift (Union[nn.Module, Callable]) – The drift term of the ODE.

  • x0 (Union[torch.Tensor, TensorDict]) – The input initial state.

  • t_span (torch.Tensor) – The times at which to evaluate the ODE. The first element is the initial time and the last element is the final time, for example t_span = torch.tensor([0.0, 1.0]).

Returns:

The output trajectory of the ODE, which has the same data type as x0 and the shape of (len(t_span), *x0.shape).

Return type:

trajectory (Union[torch.Tensor, TensorDict])
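The return shape (len(t_span), *x0.shape) means the state is recorded at every time listed in t_span, while a fixed-step solver may take several internal steps of size at most dt between consecutive entries. Below is a minimal plain-Python sketch of that contract using explicit Euler for a scalar state. The helper name and the drift(t, x) argument order are assumptions, and this is not how grl delegates to torchdyn, torchdiffeq, or torchode.

```python
import math
from typing import Callable, List, Sequence


def euler_trajectory(drift: Callable[[float, float], float], x0: float,
                     t_span: Sequence[float], dt: float) -> List[float]:
    """Explicit Euler; returns the state at every time in t_span."""
    trajectory = [x0]
    x = x0
    for t_a, t_b in zip(t_span[:-1], t_span[1:]):
        # split each interval into equal sub-steps no larger than dt
        n_sub = max(1, math.ceil(abs(t_b - t_a) / dt))
        h = (t_b - t_a) / n_sub
        t = t_a
        for _ in range(n_sub):
            x = x + h * drift(t, x)
            t = t + h
        trajectory.append(x)
    return trajectory


# dx/dt = -x has the exact solution x0 * exp(-t)
traj = euler_trajectory(lambda t, x: -x, 1.0, [0.0, 0.5, 1.0], dt=0.01)
print(len(traj), traj[-1], math.exp(-1.0))  # 3 entries; the last is close to exp(-1)
```

The trajectory has one entry per element of t_span, independent of how many internal steps were taken.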

SDESolver

class grl.numerical_methods.SDESolver(sde_solver='euler', sde_noise_type='diagonal', sde_type='ito', dt=0.001, atol=1e-05, rtol=1e-05, library='torchsde', **kwargs)
Overview:

The SDE solver class.

Interfaces:

__init__, integrate

__init__(sde_solver='euler', sde_noise_type='diagonal', sde_type='ito', dt=0.001, atol=1e-05, rtol=1e-05, library='torchsde', **kwargs)
Overview:

Initialize the SDE solver using the torchsde library.

Parameters:
  • sde_solver (str) – The SDE solver to use.

  • sde_noise_type (str) – The type of noise of the SDE. It can be ‘diagonal’, ‘general’, ‘scalar’ or ‘additive’.

  • sde_type (str) – The type of the SDE. It can be ‘ito’ or ‘stratonovich’.

  • dt (float) – The time step.

  • atol (float) – The absolute tolerance.

  • rtol (float) – The relative tolerance.

  • library (str) – The library to use for the SDE solver. Currently, it supports ‘torchsde’.

  • **kwargs – Additional arguments for the SDE solver.
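The sde_type parameter matters because the same drift defines different processes under the Itô and Stratonovich interpretations. For scalar or diagonal noise, a standard identity converts an Itô drift \(f\) into the equivalent Stratonovich drift \(f-\tfrac{1}{2}g\,\partial g/\partial x\). The sketch below applies it with a central finite difference; the helper name and the (t, x) argument order are assumptions, not part of grl.

```python
from typing import Callable


def ito_to_stratonovich_drift(drift: Callable[[float, float], float],
                              diffusion: Callable[[float, float], float],
                              eps: float = 1e-6) -> Callable[[float, float], float]:
    """Return the Stratonovich drift f - 0.5 * g * dg/dx for scalar/diagonal noise."""
    def strat_drift(t: float, x: float) -> float:
        # central finite difference for dg/dx
        dg_dx = (diffusion(t, x + eps) - diffusion(t, x - eps)) / (2.0 * eps)
        return drift(t, x) - 0.5 * diffusion(t, x) * dg_dx
    return strat_drift


# Geometric Brownian motion: f = mu * x, g = sigma * x,
# so the Stratonovich drift is (mu - sigma^2 / 2) * x.
mu, sigma = 0.3, 0.5
f_strat = ito_to_stratonovich_drift(lambda t, x: mu * x, lambda t, x: sigma * x)
print(f_strat(0.0, 2.0))  # (0.3 - 0.125) * 2.0 = 0.35 up to finite-difference error
```

When \(g\) does not depend on \(x\) (additive noise), the correction vanishes and both interpretations share the same drift.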

integrate(drift, diffusion, x0, t_span, logqp=False, adaptive=False)
Overview:

Integrate the SDE.

Parameters:
  • drift (nn.Module) – The drift term of the SDE.

  • diffusion (nn.Module) – The diffusion term of the SDE.

GaussianConditionalProbabilityPath

class grl.numerical_methods.GaussianConditionalProbabilityPath(config)
Overview:

Gaussian conditional probability path.

General case:

\[p(x(t)|x(0))=\mathcal{N}(x(t);\mu(t,x(0)),\sigma^2(t,x(0))I)\]

Written in the form of an SDE:

\[\mathrm{d}x=f(x,t)\mathrm{d}t+g(t)\mathrm{d}w_{t}\]

where \(f(x,t)\) is the drift term, \(g(t)\) is the diffusion term, and \(w_{t}\) is the Wiener process.

For diffusion models:

\[p(x(t)|x(0))=\mathcal{N}(x(t);s(t)x(0),\sigma^2(t)I)\]

or,

\[p(x(t)|x(0))=\mathcal{N}(x(t);s(t)x(0),s^2(t)e^{-2\lambda(t)}I)\]

Written in the form of an SDE:

\[\mathrm{d}x=\frac{s'(t)}{s(t)}x(t)\mathrm{d}t+s(t)e^{-\lambda(t)}\sqrt{-2\lambda'(t)}\mathrm{d}w_{t}\]

or,

\[\mathrm{d}x=f(t)x(t)\mathrm{d}t+g(t)\mathrm{d}w_{t}\]

where \(s(t)\) is the scale factor, \(\sigma^2(t)I\) is the covariance matrix, \(\sigma(t)\) is the standard deviation including the scale factor, \(e^{-2\lambda(t)}I\) is the covariance matrix without the scale factor, and \(\lambda(t)\) is the half-logSNR, i.e. the difference between the log scale factor and the log standard deviation, \(\lambda(t)=\log(s(t))-\log(\sigma(t))\). In the last form, \(f(t)=s'(t)/s(t)\).
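The SDE coefficients follow from matching the first two moments of the linear SDE \(\mathrm{d}x=f(t)x(t)\mathrm{d}t+g(t)\mathrm{d}w_{t}\) to the marginal \(\mathcal{N}(x(t);s(t)x(0),\sigma^2(t)I)\) with \(\sigma(t)=s(t)e^{-\lambda(t)}\). A short derivation, per component:

\[
\begin{aligned}
\frac{\mathrm{d}}{\mathrm{d}t}\mathbb{E}[x(t)]=f(t)\,\mathbb{E}[x(t)],\;\mathbb{E}[x(t)]=s(t)x(0)
&\;\Rightarrow\; f(t)=\frac{s'(t)}{s(t)},\\
\frac{\mathrm{d}\sigma^2(t)}{\mathrm{d}t}=2f(t)\sigma^2(t)+g^2(t)
&\;\Rightarrow\; g^2(t)=\frac{\mathrm{d}\sigma^2(t)}{\mathrm{d}t}-2\frac{s'(t)}{s(t)}\sigma^2(t),\\
\sigma^2(t)=s^2(t)e^{-2\lambda(t)}
&\;\Rightarrow\; \frac{\mathrm{d}\sigma^2(t)}{\mathrm{d}t}=2\frac{s'(t)}{s(t)}\sigma^2(t)-2\lambda'(t)\sigma^2(t),\\
&\;\Rightarrow\; g^2(t)=-2\lambda'(t)s^2(t)e^{-2\lambda(t)},\quad g(t)=s(t)e^{-\lambda(t)}\sqrt{-2\lambda'(t)}.
\end{aligned}
\]

Assuming \(\lambda(t)\) decreases in t (the SNR falls as noise is added), \(-2\lambda'(t)\) is positive and the square root is real.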

For the VP SDE:

\[p(x(t)|x(0))=\mathcal{N}(x(t);x(0)e^{-\frac{1}{2}\int_{0}^{t}{\beta(s)\mathrm{d}s}},(1-e^{-\int_{0}^{t}{\beta(s)\mathrm{d}s}})I)\]

For the linear VP SDE:

\[p(x(t)|x(0))=\mathcal{N}(x(t);x(0)e^{-\frac{\beta(1)-\beta(0)}{4}t^2-\frac{\beta(0)}{2}t},(1-e^{-\frac{\beta(1)-\beta(0)}{2}t^2-\beta(0)t})I)\]
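A minimal sketch that evaluates the linear VP SDE scale factor \(s(t)\) and standard deviation \(\sigma(t)\) from the formula above, checks the variance-preserving property \(s^2(t)+\sigma^2(t)=1\), and draws \(x(t)=s(t)x(0)+\sigma(t)\epsilon\) with \(\epsilon\sim\mathcal{N}(0,1)\). It is plain Python with assumed values \(\beta(0)=0.1\) and \(\beta(1)=20\); the function names echo the method names on this page for readability but are stand-alone helpers, not grl code.

```python
import math
import random

BETA_0, BETA_1 = 0.1, 20.0  # assumed beta(0) and beta(1)


def integral_beta(t: float) -> float:
    # integral of beta(u) = beta(0) + (beta(1) - beta(0)) u over [0, t]
    return BETA_0 * t + 0.5 * (BETA_1 - BETA_0) * t ** 2


def scale(t: float) -> float:
    # s(t) = exp(-0.5 * integral of beta)
    return math.exp(-0.5 * integral_beta(t))


def std(t: float) -> float:
    # sigma(t) = sqrt(1 - exp(-integral of beta))
    return math.sqrt(-math.expm1(-integral_beta(t)))


def sample_xt(x0: float, t: float, rng: random.Random) -> float:
    # reparameterization: x(t) = s(t) x(0) + sigma(t) eps, eps ~ N(0, 1)
    return scale(t) * x0 + std(t) * rng.gauss(0.0, 1.0)


for t in (0.1, 0.5, 1.0):
    print(t, scale(t), std(t), scale(t) ** 2 + std(t) ** 2)  # last column is 1 up to rounding

rng = random.Random(0)
samples = [sample_xt(1.0, 0.5, rng) for _ in range(20000)]
print(sum(samples) / len(samples), scale(0.5))  # empirical mean vs s(0.5)
```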

TODO: add more details for the cosine VP SDE, the general VE SDE, and OPT-Flow.

HalfLogSNR(t)
Overview:

Compute the half-logSNR of the Gaussian conditional probability path, which is

\[\log(s(t))-\log(\sigma(t))\]
Parameters:

t (torch.Tensor) – The input time.

Returns:

The half-logSNR.

Return type:

HalfLogSNR (torch.Tensor)

InverseHalfLogSNR(HalfLogSNR)
Overview:

Compute the inverse function of the half-logSNR of the Gaussian conditional probability path. Since the half-logSNR is invertible, the time t can be recovered from it. For the linear VP SDE, with \(\beta_0=\beta(0)\) and \(\beta_1=\beta(1)\), the inverse function is

\[t(\lambda)=\frac{1}{\beta_1-\beta_0}(\sqrt{\beta_0^2+2(\beta_1-\beta_0)\log{(e^{-2\lambda}+1)}}-\beta_0)\]

or, equivalently,

\[t(\lambda)=\frac{2\log{(e^{-2\lambda}+1)}}{\sqrt{\beta_0^2+2(\beta_1-\beta_0)\log{(e^{-2\lambda}+1)}}+\beta_0}\]

Parameters:

HalfLogSNR (torch.Tensor) – The input half-logSNR.

Returns:

The time.

Return type:

t (torch.Tensor)
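For the linear VP SDE, with \(L=\log(e^{-2\lambda}+1)\), the inverse can be written as \((\sqrt{\beta_0^2+2(\beta_1-\beta_0)L}-\beta_0)/(\beta_1-\beta_0)\) or, after multiplying by the conjugate, as \(2L/(\sqrt{\beta_0^2+2(\beta_1-\beta_0)L}+\beta_0)\). The second form avoids subtracting two nearly equal numbers and stays finite when \(\beta_1=\beta_0\). The sketch below checks the round trip \(t\to\lambda(t)\to t\) for both forms; \(\beta_0=0.1\), \(\beta_1=20\) and all function names are assumptions, not grl's methods.

```python
import math

BETA_0, BETA_1 = 0.1, 20.0  # assumed linear VP SDE parameters


def half_log_snr(t: float, b0: float = BETA_0, b1: float = BETA_1) -> float:
    """lambda(t) = log s(t) - log sigma(t) for the linear VP SDE."""
    big_b = b0 * t + 0.5 * (b1 - b0) * t ** 2
    # s^2 = exp(-B) and sigma^2 = 1 - exp(-B), so lambda = -0.5 * log(exp(B) - 1)
    return -0.5 * math.log(math.expm1(big_b))


def inverse_direct(lam: float, b0: float = BETA_0, b1: float = BETA_1) -> float:
    # first form: (sqrt(b0^2 + 2 (b1 - b0) L) - b0) / (b1 - b0)
    log_term = math.log1p(math.exp(-2.0 * lam))
    return (math.sqrt(b0 ** 2 + 2.0 * (b1 - b0) * log_term) - b0) / (b1 - b0)


def inverse_stable(lam: float, b0: float = BETA_0, b1: float = BETA_1) -> float:
    # second form: 2 L / (sqrt(b0^2 + 2 (b1 - b0) L) + b0), no cancellation
    log_term = math.log1p(math.exp(-2.0 * lam))
    return 2.0 * log_term / (math.sqrt(b0 ** 2 + 2.0 * (b1 - b0) * log_term) + b0)


for t in (1e-3, 0.25, 0.5, 1.0):
    lam = half_log_snr(t)
    print(t, lam, inverse_direct(lam), inverse_stable(lam))

# With beta(1) == beta(0) the first form divides by zero; the second does not.
print(inverse_stable(half_log_snr(0.5, 1.0, 1.0), 1.0, 1.0))  # recovers t = 0.5
```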

__init__(config)
Overview:

Initialize the Gaussian conditional probability path.

Parameters:

config (EasyDict) – The configuration of the Gaussian conditional probability path.

covariance(t)
Overview:

Compute the covariance matrix of the Gaussian conditional probability path, which is

\[\Sigma(t)\]

or

\[\sigma^2(t)I\]

or

\[s^2(t)e^{-2\lambda(t)}I\]
Parameters:

t (torch.Tensor) – The input time.

Returns:

The covariance matrix.

Return type:

covariance (torch.Tensor)

d_covariance_dt(t)
Overview:

Compute the time derivative of the covariance matrix of the Gaussian conditional probability path, which is

\[\frac{\mathrm{d}\Sigma(t)}{\mathrm{d}t}\]
Parameters:

t (torch.Tensor) – The input time.

Returns:

The time derivative of the covariance matrix.

Return type:

d_covariance_dt (torch.Tensor)

d_log_scale_dt(t)
Overview:

Compute the time derivative of the log scale factor of the Gaussian conditional probability path, which is

\[\frac{\mathrm{d}\log(s(t))}{\mathrm{d}t}=\frac{s'(t)}{s(t)}\]

Parameters:

t (torch.Tensor) – The input time.

Returns:

The time derivative of the log scale factor.

Return type:

d_log_scale_dt (Union[torch.Tensor, TensorDict])

d_scale_dt(t)
Overview:

Compute the time derivative of the scale factor of the Gaussian conditional probability path, which is

\[s'(t)\]
Parameters:

t (torch.Tensor) – The input time.

Returns:

The time derivative of the scale factor.

Return type:

d_scale_dt (Union[torch.Tensor, TensorDict])
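For the linear VP SDE, \(\log s(t)=-\frac{\beta(1)-\beta(0)}{4}t^2-\frac{\beta(0)}{2}t\), so \(\frac{\mathrm{d}\log s(t)}{\mathrm{d}t}=-\beta(t)/2\) and \(s'(t)=s(t)\frac{\mathrm{d}\log s(t)}{\mathrm{d}t}\). The sketch below compares these closed forms with central finite differences; \(\beta(0)=0.1\), \(\beta(1)=20\) and the helper names are assumptions, not grl's implementation.

```python
import math

BETA_0, BETA_1 = 0.1, 20.0  # assumed beta(0) and beta(1)


def beta(t: float) -> float:
    return BETA_0 + (BETA_1 - BETA_0) * t


def log_scale(t: float) -> float:
    # log s(t) = -(beta(1) - beta(0)) t^2 / 4 - beta(0) t / 2
    return -0.25 * (BETA_1 - BETA_0) * t ** 2 - 0.5 * BETA_0 * t


def d_log_scale_dt(t: float) -> float:
    # d/dt log s(t) = s'(t) / s(t) = -beta(t) / 2
    return -0.5 * beta(t)


def d_scale_dt(t: float) -> float:
    # s'(t) = s(t) * d/dt log s(t)
    return math.exp(log_scale(t)) * d_log_scale_dt(t)


h = 1e-5
for t in (0.1, 0.5, 0.9):
    fd_log = (log_scale(t + h) - log_scale(t - h)) / (2 * h)
    fd_scale = (math.exp(log_scale(t + h)) - math.exp(log_scale(t - h))) / (2 * h)
    print(t, d_log_scale_dt(t), fd_log, d_scale_dt(t), fd_scale)
```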

d_std_dt(t)
Overview:

Compute the time derivative of the standard deviation of the Gaussian conditional probability path, which is

\[\frac{\mathrm{d}\sigma(t)}{\mathrm{d}t}\]
Parameters:

t (torch.Tensor) – The input time.

Returns:

The time derivative of the standard deviation.

Return type:

d_std_dt (torch.Tensor)

diffusion(t)
Overview:

Return the diffusion term of the Gaussian conditional probability path. The diffusion term is given by the following:

\[g(t)\]

Parameters:

t (torch.Tensor) – The input time.

Returns:

The output diffusion term.

Return type:

diffusion (Union[torch.Tensor, TensorDict])

diffusion_squared(t)
Overview:

Return the squared diffusion term of the Gaussian conditional probability path, which is given by:

\[g^2(t)\]

Parameters:

t (torch.Tensor) – The input time.

Returns:

The squared diffusion term.

Return type:

diffusion_squared (Union[torch.Tensor, TensorDict])
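For \(\mathrm{d}x=f(t)x(t)\mathrm{d}t+g(t)\mathrm{d}w_{t}\), matching the variance gives \(g^2(t)=\frac{\mathrm{d}\sigma^2(t)}{\mathrm{d}t}-2f(t)\sigma^2(t)\) with \(f(t)=\frac{\mathrm{d}\log s(t)}{\mathrm{d}t}\); for a VP SDE this reduces to \(g^2(t)=\beta(t)\). The sketch below checks that numerically for the linear VP SDE with assumed \(\beta(0)=0.1\) and \(\beta(1)=20\). It illustrates the identity only and is not necessarily how grl computes diffusion_squared.

```python
import math

BETA_0, BETA_1 = 0.1, 20.0  # assumed beta(0) and beta(1)


def beta(t: float) -> float:
    return BETA_0 + (BETA_1 - BETA_0) * t


def integral_beta(t: float) -> float:
    return BETA_0 * t + 0.5 * (BETA_1 - BETA_0) * t ** 2


def covariance(t: float) -> float:
    # sigma^2(t) = 1 - exp(-integral of beta)
    return -math.expm1(-integral_beta(t))


def d_covariance_dt(t: float) -> float:
    # d sigma^2 / dt = beta(t) * exp(-integral of beta)
    return beta(t) * math.exp(-integral_beta(t))


def drift_coefficient(t: float) -> float:
    # f(t) = d log s / dt = -beta(t) / 2
    return -0.5 * beta(t)


def diffusion_squared(t: float) -> float:
    # g^2(t) = d sigma^2 / dt - 2 f(t) sigma^2(t)
    return d_covariance_dt(t) - 2.0 * drift_coefficient(t) * covariance(t)


for t in (0.1, 0.5, 1.0):
    print(t, diffusion_squared(t), beta(t))  # the two columns agree up to rounding
```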

drift(t, x=None)
Overview:

Return the drift term of the Gaussian conditional probability path. The drift term is given by the following:

\[f(x,t)\]
Parameters:
  • t (torch.Tensor) – The input time.

  • x (Union[torch.Tensor, TensorDict]) – The input state.

Returns:

The output drift term.

Return type:

drift (Union[torch.Tensor, TensorDict])

drift_coefficient(t)
Overview:

Return the drift coefficient of the Gaussian conditional probability path, which is given by:

\[f(t)\]

which satisfies the following SDE:

\[\mathrm{d}x=f(t)x(t)\mathrm{d}t+g(t)\mathrm{d}w_{t}\]

Parameters:

t (torch.Tensor) – The input time.

Returns:

The drift coefficient.

Return type:

drift_coefficient (Union[torch.Tensor, TensorDict])

log_scale(t)
Overview:

Compute the log scale factor of the Gaussian conditional probability path, which is

\[\log(s(t))\]
Parameters:

t (torch.Tensor) – The input time.

Returns:

The log scale factor.

Return type:

log_scale (torch.Tensor)

scale(t)
Overview:

Compute the scale factor of the Gaussian conditional probability path, which is

\[s(t)\]
Parameters:

t (torch.Tensor) – The input time.

Returns:

The scale factor.

Return type:

scale (torch.Tensor)

std(t)
Overview:

Compute the standard deviation of the Gaussian conditional probability path, which is

\[\sqrt{\Sigma(t)}\]

or

\[\sigma(t)\]

or

\[s(t)e^{-\lambda(t)}\]
Parameters:

t (torch.Tensor) – The input time.

Returns:

The standard deviation.

Return type:

std (torch.Tensor)