Source code for grl.numerical_methods.probability_path

from typing import Union

import torch
from easydict import EasyDict
from tensordict import TensorDict


class GaussianConditionalProbabilityPath:
    r"""
    Overview:
        Gaussian conditional probability path.

        General case:

        .. math::
            p(x(t)|x(0))=\mathcal{N}(x(t);\mu(t,x(0)),\sigma^2(t,x(0))I)

        If written in the form of an SDE:

        .. math::
            \mathrm{d}x=f(x,t)\mathrm{d}t+g(t)\mathrm{d}w_{t}

        where :math:`f(x,t)` is the drift term, :math:`g(t)` is the diffusion term, and :math:`w_{t}` is the Wiener process.

        For the diffusion model:

        .. math::
            p(x(t)|x(0))=\mathcal{N}(x(t);s(t)x(0),\sigma^2(t)I)

        or,

        .. math::
            p(x(t)|x(0))=\mathcal{N}(x(t);s(t)x(0),s^2(t)e^{-2\lambda(t)}I)

        If written in the form of an SDE:

        .. math::
            \mathrm{d}x=\frac{s'(t)}{s(t)}x(t)\mathrm{d}t+s^2(t)\sqrt{2(\frac{s'(t)}{s(t)}-\lambda'(t))}e^{-\lambda(t)}\mathrm{d}w_{t}

        or,

        .. math::
            \mathrm{d}x=f(t)x(t)\mathrm{d}t+g(t)\mathrm{d}w_{t}

        where :math:`s(t)` is the scale factor, :math:`\sigma^2(t)I` is the covariance matrix, \
        :math:`\sigma(t)` is the standard deviation with the scale factor, \
        :math:`e^{-2\lambda(t)}I` is the covariance matrix without the scale factor, \
        and :math:`\lambda(t)` is the half-logSNR, i.e., the difference between the log scale factor \
        and the log standard deviation, :math:`\lambda(t)=\log(s(t))-\log(\sigma(t))`.

        For the VP SDE:

        .. math::
            p(x(t)|x(0))=\mathcal{N}(x(t);x(0)e^{-\frac{1}{2}\int_{0}^{t}{\beta(s)\mathrm{d}s}},(1-e^{-\int_{0}^{t}{\beta(s)\mathrm{d}s}})I)

        For the linear VP SDE:

        .. math::
            p(x(t)|x(0))=\mathcal{N}(x(t);x(0)e^{-\frac{\beta(1)-\beta(0)}{4}t^2-\frac{\beta(0)}{2}t},(1-e^{-\frac{\beta(1)-\beta(0)}{2}t^2-\beta(0)t})I)

        # TODO: add more details for the Cosine VP SDE, General VE SDE, and OPT-Flow.
    """

    def __init__(self, config: EasyDict) -> None:
        """
        Overview:
            Initialize the Gaussian conditional probability path.
        Arguments:
            config (:obj:`EasyDict`): The configuration of the Gaussian conditional probability path.
        """
        self.config = config
        self.type = config.type
        self.t_max = 1.0 if not hasattr(config, "t_max") else config.t_max
        assert self.type in [
            "diffusion",
            "vp_sde",
            "linear_vp_sde",
            "cosine_vp_sde",
            "general_ve_sde",
            "op_flow",
            "linear",
            "gvp",
        ], "Unknown type of Gaussian conditional probability path {}".format(self.type)

    def drift_coefficient(
        self,
        t: torch.Tensor,
    ) -> Union[torch.Tensor, TensorDict]:
        r"""
        Overview:
            Return the drift coefficient of the Gaussian conditional probability path, which is

            .. math::
                f(t)

            and satisfies the SDE:

            .. math::
                \mathrm{d}x=f(t)x(t)\mathrm{d}t+g(t)\mathrm{d}w_{t}

        Arguments:
            t (:obj:`torch.Tensor`): The input time.
        Returns:
            drift_coefficient (:obj:`Union[torch.Tensor, TensorDict]`): The output drift coefficient.
        """
        if self.type == "linear_vp_sde":
            # TODO: make it compatible with TensorDict
            # f(t) = -beta(t) / 2 with beta(t) = beta_0 + t * (beta_1 - beta_0).
            return -0.5 * (
                self.config.beta_0 + t * (self.config.beta_1 - self.config.beta_0)
            )
        elif self.type == "linear":
            # f(t) = s'(t) / s(t) = -1 / (1 - t); the small constant avoids division by zero at t = 1.
            return -torch.ones_like(t) / (1.0000001 - t)
        elif self.type == "gvp":
            # f(t) = s'(t) / s(t) with s(t) = cos(pi * t / 2).
            return -0.5 * torch.pi * torch.tan(torch.pi * t / 2.0)
        else:
            raise NotImplementedError(
                "Drift coefficient term for type {} is not implemented".format(
                    self.type
                )
            )

    def drift(
        self,
        t: torch.Tensor,
        x: Union[torch.Tensor, TensorDict] = None,
    ) -> Union[torch.Tensor, TensorDict]:
        r"""
        Overview:
            Return the drift term of the Gaussian conditional probability path, which is

            .. math::
                f(x,t)=f(t)x(t)

        Arguments:
            t (:obj:`torch.Tensor`): The input time.
            x (:obj:`Union[torch.Tensor, TensorDict]`): The input state.
        Returns:
            drift (:obj:`Union[torch.Tensor, TensorDict]`): The output drift term.
        """
        if self.type in ("linear_vp_sde", "linear", "gvp"):
            # TODO: make it compatible with TensorDict
            # Scale each sample x_i by its drift coefficient f(t_i).
            return torch.einsum("i...,i->i...", x, self.drift_coefficient(t))
        else:
            raise NotImplementedError(
                "Drift term for type {} is not implemented".format(self.type)
            )

    def diffusion(
        self,
        t: torch.Tensor,
    ) -> Union[torch.Tensor, TensorDict]:
        r"""
        Overview:
            Return the diffusion term of the Gaussian conditional probability path, which is

            .. math::
                g(t)

        Arguments:
            t (:obj:`torch.Tensor`): The input time.
        Returns:
            diffusion (:obj:`Union[torch.Tensor, TensorDict]`): The output diffusion term.
        """
        if self.type == "linear_vp_sde":
            return torch.sqrt(
                self.config.beta_0 + t * (self.config.beta_1 - self.config.beta_0)
            )
        elif self.type == "linear":
            # The small constant avoids division by zero at t = 1.
            return torch.sqrt(2 * t + 2 * t * t / (1.0000001 - t))
        elif self.type == "gvp":
            # g^2(t) = d Sigma(t)/dt - 2 f(t) Sigma(t) with Sigma(t) = sin^2(pi * t / 2)
            # and f(t) = -(pi / 2) * tan(pi * t / 2).
            first = (
                torch.pi * torch.sin(torch.pi * t * 0.5) * torch.cos(torch.pi * t * 0.5)
            )
            second = (
                torch.pi
                * torch.sin(torch.pi * t * 0.5)
                * torch.sin(torch.pi * t * 0.5)
                * torch.tan(torch.pi * t * 0.5)
            )
            return torch.sqrt(first + second)
        else:
            raise NotImplementedError(
                "Diffusion term for type {} is not implemented".format(self.type)
            )

    def diffusion_squared(
        self,
        t: torch.Tensor,
    ) -> Union[torch.Tensor, TensorDict]:
        r"""
        Overview:
            Return the squared diffusion term of the Gaussian conditional probability path, which is

            .. math::
                g^2(t)

        Arguments:
            t (:obj:`torch.Tensor`): The input time.
        Returns:
            diffusion_squared (:obj:`Union[torch.Tensor, TensorDict]`): The output squared diffusion term.
        """
        if self.type == "linear_vp_sde":
            return self.config.beta_0 + t * (self.config.beta_1 - self.config.beta_0)
        elif self.type == "linear":
            # The small constant avoids division by zero at t = 1.
            return 2 * t + 2 * t * t / (1.0000001 - t)
        elif self.type == "gvp":
            # g^2(t) = d Sigma(t)/dt - 2 f(t) Sigma(t) with Sigma(t) = sin^2(pi * t / 2)
            # and f(t) = -(pi / 2) * tan(pi * t / 2).
            first = (
                torch.pi * torch.sin(torch.pi * t * 0.5) * torch.cos(torch.pi * t * 0.5)
            )
            second = (
                torch.pi
                * torch.sin(torch.pi * t * 0.5)
                * torch.sin(torch.pi * t * 0.5)
                * torch.tan(torch.pi * t * 0.5)
            )
            return first + second
        else:
            raise NotImplementedError(
                "Squared diffusion term for type {} is not implemented".format(
                    self.type
                )
            )

    def scale(self, t: torch.Tensor) -> torch.Tensor:
        r"""
        Overview:
            Compute the scale factor of the Gaussian conditional probability path, which is

            .. math::
                s(t)

        Arguments:
            t (:obj:`torch.Tensor`): The input time.
        Returns:
            scale (:obj:`torch.Tensor`): The scale factor.
        """
        # TODO: implement the scale factor for other Gaussian conditional probability paths
        if self.type == "linear_vp_sde":
            return torch.exp(
                -0.25 * t**2 * (self.config.beta_1 - self.config.beta_0)
                - 0.5 * t * self.config.beta_0
            )
        elif self.type == "linear":
            return 1 - t
        elif self.type == "gvp":
            return torch.cos(0.5 * torch.pi * t)
        else:
            raise NotImplementedError(
                "Scale factor for type {} is not implemented".format(self.type)
            )

    def log_scale(self, t: torch.Tensor) -> torch.Tensor:
        r"""
        Overview:
            Compute the log scale factor of the Gaussian conditional probability path, which is

            .. math::
                \log(s(t))

        Arguments:
            t (:obj:`torch.Tensor`): The input time.
        Returns:
            log_scale (:obj:`torch.Tensor`): The log scale factor.
        """
        # TODO: implement the log scale factor for other Gaussian conditional probability paths
        if self.type == "linear_vp_sde":
            return (
                -0.25 * t**2 * (self.config.beta_1 - self.config.beta_0)
                - 0.5 * t * self.config.beta_0
            )
        elif self.type == "linear":
            return torch.log(1.0 - t)
        elif self.type == "gvp":
            return torch.log(torch.cos(0.5 * torch.pi * t))
        else:
            raise NotImplementedError(
                "Log scale factor for type {} is not implemented".format(self.type)
            )

    def d_log_scale_dt(
        self,
        t: torch.Tensor,
    ) -> Union[torch.Tensor, TensorDict]:
        r"""
        Overview:
            Compute the time derivative of the log scale factor of the Gaussian conditional probability path, which is

            .. math::
                \frac{\mathrm{d}\log(s(t))}{\mathrm{d}t}

        Arguments:
            t (:obj:`torch.Tensor`): The input time.
        Returns:
            d_log_scale_dt (:obj:`Union[torch.Tensor, TensorDict]`): The time derivative of the log scale factor.
        """
        if self.type == "linear_vp_sde":
            # d log(s(t))/dt = -0.5 * (beta_0 + t * (beta_1 - beta_0)), the derivative of log_scale(t).
            return -0.5 * (
                self.config.beta_0 + t * (self.config.beta_1 - self.config.beta_0)
            )
        elif self.type == "linear":
            # The small constant avoids division by zero at t = 1.
            return -1.0 / (1.0000001 - t)
        elif self.type == "gvp":
            return -0.5 * torch.pi * torch.tan(0.5 * torch.pi * t)
        else:
            raise NotImplementedError(
                "Time derivative of the log scale factor for type {} is not implemented".format(
                    self.type
                )
            )

    def d_scale_dt(
        self,
        t: torch.Tensor,
    ) -> Union[torch.Tensor, TensorDict]:
        r"""
        Overview:
            Compute the time derivative of the scale factor of the Gaussian conditional probability path, which is

            .. math::
                s'(t)

        Arguments:
            t (:obj:`torch.Tensor`): The input time.
        Returns:
            d_scale_dt (:obj:`Union[torch.Tensor, TensorDict]`): The time derivative of the scale factor.
        """
        if self.type == "linear_vp_sde":
            # s'(t) = (d log(s(t))/dt) * s(t).
            return self.d_log_scale_dt(t) * self.scale(t)
        elif self.type == "linear":
            return -1.0 * torch.ones_like(t, dtype=torch.float32)
        elif self.type == "gvp":
            return -0.5 * torch.pi * torch.sin(0.5 * torch.pi * t)
        else:
            raise NotImplementedError(
                "Time derivative of the scale factor for type {} is not implemented".format(
                    self.type
                )
            )

    def std(self, t: torch.Tensor) -> torch.Tensor:
        r"""
        Overview:
            Compute the standard deviation of the Gaussian conditional probability path, which is

            .. math::
                \sqrt{\Sigma(t)}

            or

            .. math::
                \sigma(t)

            or

            .. math::
                s(t)e^{-\lambda(t)}

        Arguments:
            t (:obj:`torch.Tensor`): The input time.
        Returns:
            std (:obj:`torch.Tensor`): The standard deviation.
        """
        if self.type == "linear_vp_sde":
            return torch.sqrt(1.0 - self.scale(t) ** 2)
        elif self.type == "linear":
            return t
        elif self.type == "gvp":
            return torch.sin(0.5 * torch.pi * t)
        else:
            raise NotImplementedError(
                "Standard deviation for type {} is not implemented".format(self.type)
            )

    def d_std_dt(self, t: torch.Tensor) -> torch.Tensor:
        r"""
        Overview:
            Compute the time derivative of the standard deviation of the Gaussian conditional probability path, which is

            .. math::
                \frac{\mathrm{d}\sigma(t)}{\mathrm{d}t}

        Arguments:
            t (:obj:`torch.Tensor`): The input time.
        Returns:
            d_std_dt (:obj:`torch.Tensor`): The time derivative of the standard deviation.
        """
        if self.type == "linear_vp_sde":
            # d sigma(t)/dt = -s(t) * s'(t) / sigma(t) since sigma(t) = sqrt(1 - s(t)^2).
            return -self.d_scale_dt(t) * self.scale(t) / self.std(t)
        elif self.type == "linear":
            return torch.ones_like(t, dtype=torch.float32)
        elif self.type == "gvp":
            return 0.5 * torch.pi * torch.cos(0.5 * torch.pi * t)
        else:
            raise NotImplementedError(
                "Time derivative of standard deviation for type {} is not implemented".format(
                    self.type
                )
            )

    def covariance(self, t: torch.Tensor) -> torch.Tensor:
        r"""
        Overview:
            Compute the covariance matrix of the Gaussian conditional probability path, which is

            .. math::
                \Sigma(t)

            or

            .. math::
                \sigma^2(t)I

            or

            .. math::
                s^2(t)e^{-2\lambda(t)}I

        Arguments:
            t (:obj:`torch.Tensor`): The input time.
        Returns:
            covariance (:obj:`torch.Tensor`): The covariance matrix.
        """
        if self.type == "linear_vp_sde":
            return 1.0 - self.scale(t) ** 2
        elif self.type == "linear":
            return t * t
        elif self.type == "gvp":
            return torch.sin(0.5 * torch.pi * t) * torch.sin(0.5 * torch.pi * t)
        else:
            raise NotImplementedError(
                "Covariance for type {} is not implemented".format(self.type)
            )

    def d_covariance_dt(self, t: torch.Tensor) -> torch.Tensor:
        r"""
        Overview:
            Compute the time derivative of the covariance matrix of the Gaussian conditional probability path, which is

            .. math::
                \frac{\mathrm{d}\Sigma(t)}{\mathrm{d}t}

        Arguments:
            t (:obj:`torch.Tensor`): The input time.
        Returns:
            d_covariance_dt (:obj:`torch.Tensor`): The time derivative of the covariance matrix.
        """
        if self.type == "linear_vp_sde":
            return -2.0 * self.scale(t) * self.d_scale_dt(t)
        elif self.type == "linear":
            return 2.0 * t
        elif self.type == "gvp":
            return (
                torch.pi * torch.sin(torch.pi * t * 0.5) * torch.cos(torch.pi * t * 0.5)
            )
        else:
            raise NotImplementedError(
                "Time derivative of covariance for type {} is not implemented".format(
                    self.type
                )
            )

    def HalfLogSNR(self, t: torch.Tensor) -> torch.Tensor:
        r"""
        Overview:
            Compute the half-logSNR of the Gaussian conditional probability path, which is

            .. math::
                \log(s(t))-\log(\sigma(t))

        Arguments:
            t (:obj:`torch.Tensor`): The input time.
        Returns:
            HalfLogSNR (:obj:`torch.Tensor`): The half-logSNR.
        """
        if self.type == "linear_vp_sde":
            return self.log_scale(t) - 0.5 * torch.log(1.0 - self.scale(t) ** 2)
        else:
            raise NotImplementedError(
                "Half-logSNR for type {} is not implemented".format(self.type)
            )

    def InverseHalfLogSNR(self, HalfLogSNR: torch.Tensor) -> torch.Tensor:
        r"""
        Overview:
            Compute the inverse function of the half-logSNR of the Gaussian conditional probability path.
            Since the half-logSNR is an invertible function of time, we can recover the time t from the half-logSNR.
            For the linear VP SDE, the inverse function is

            .. math::
                t(\lambda)=\frac{1}{\beta_1-\beta_0}\left(\sqrt{\beta_0^2+2(\beta_1-\beta_0)\log{(e^{-2\lambda}+1)}}-\beta_0\right)

            or, equivalently,

            .. math::
                t(\lambda)=\frac{2\log{(e^{-2\lambda}+1)}}{\sqrt{\beta_0^2+2(\beta_1-\beta_0)\log{(e^{-2\lambda}+1)}}+\beta_0}

        Arguments:
            HalfLogSNR (:obj:`torch.Tensor`): The input half-logSNR.
        Returns:
            t (:obj:`torch.Tensor`): The time.
        """
        if self.type == "linear_vp_sde":
            # log(e^{-2 lambda} + 1) computed stably via logaddexp.
            numerator = 2.0 * torch.logaddexp(
                -2.0 * HalfLogSNR, torch.zeros((1,)).to(HalfLogSNR)
            )
            denominator = (
                torch.sqrt(
                    self.config.beta_0**2
                    + (self.config.beta_1 - self.config.beta_0) * numerator
                )
                + self.config.beta_0
            )
            return numerator / denominator
        else:
            raise NotImplementedError(
                "Inverse function of half-logSNR for type {} is not implemented".format(
                    self.type
                )
            )

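
# Illustrative usage sketch, not part of the original module. It assumes an EasyDict
# config with ``type="linear_vp_sde"`` and the ``beta_0``/``beta_1`` fields used above;
# the numeric values and the function name are placeholders for demonstration only.
# The check verifies two identities implied by the formulas in this class: the
# conditional variance satisfies dSigma(t)/dt = 2 * f(t) * Sigma(t) + g^2(t), and
# InverseHalfLogSNR inverts HalfLogSNR.
def _example_linear_vp_sde_consistency_check():
    config = EasyDict(type="linear_vp_sde", beta_0=0.1, beta_1=20.0)
    path = GaussianConditionalProbabilityPath(config)
    t = torch.linspace(0.05, 0.95, 10, dtype=torch.float64)

    # Variance evolution of dx = f(t) x dt + g(t) dw_t: dSigma/dt = 2 f Sigma + g^2.
    lhs = path.d_covariance_dt(t)
    rhs = (
        2.0 * path.drift_coefficient(t) * path.covariance(t)
        + path.diffusion_squared(t)
    )
    assert torch.allclose(lhs, rhs, atol=1e-8)

    # Recovering time from the half-logSNR should give back t.
    t_recovered = path.InverseHalfLogSNR(path.HalfLogSNR(t))
    assert torch.allclose(t, t_recovered, atol=1e-6)
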
class ConditionalProbabilityPath:
    """
    Overview:
        Conditional probability path for general continuous-time normalizing flows.
    """

    def __init__(self, config) -> None:
        self.config = config

    def std(self, t: torch.Tensor) -> torch.Tensor:
        # The standard deviation is a constant sigma, independent of time.
        return torch.tensor(self.config.sigma, device=t.device)


class SchrodingerBridgePath:
    """
    Overview:
        Probability path of the Schrodinger bridge, whose standard deviation follows a Brownian-bridge profile.
    """

    def __init__(self, config) -> None:
        self.config = config

    def std(self, t: torch.Tensor) -> torch.Tensor:
        # sigma(t) = sigma * sqrt(t * (1 - t))
        return self.config.sigma * torch.sqrt(t * (1 - t))

    def lambd(self, t: torch.Tensor) -> torch.Tensor:
        return 2 * self.std(t) / (self.config.sigma**2)

    def std_prime(self, t: torch.Tensor) -> torch.Tensor:
        # d sigma(t)/dt = sigma * (1 - 2 * t) / (2 * sqrt(t * (1 - t)))
        return self.config.sigma * (1 - 2 * t) / (2 * torch.sqrt(t * (1 - t)))
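
# Illustrative usage sketch, not part of the original module. It assumes an EasyDict
# config with a ``sigma`` field, as read by the two classes above; the value 1.0 and
# the function name are placeholders. It evaluates the bridge standard deviation and
# checks ``std_prime`` against a central finite-difference approximation of d std(t)/dt.
def _example_schrodinger_bridge_path():
    config = EasyDict(sigma=1.0)
    bridge = SchrodingerBridgePath(config)
    t = torch.linspace(0.1, 0.9, 9, dtype=torch.float64)

    # std(t) = sigma * sqrt(t * (1 - t)) vanishes at both endpoints and peaks at t = 0.5.
    print(bridge.std(t))

    # Central finite difference of std(t) should match the analytic derivative.
    eps = 1e-6
    finite_difference = (bridge.std(t + eps) - bridge.std(t - eps)) / (2 * eps)
    assert torch.allclose(bridge.std_prime(t), finite_difference, atol=1e-6)


if __name__ == "__main__":
    _example_linear_vp_sde_consistency_check()
    _example_schrodinger_bridge_path()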