• Docs >
  • Module code >
  • grl.generative_models.conditional_flow_model.independent_conditional_flow_model

Source code for grl.generative_models.conditional_flow_model.independent_conditional_flow_model

from typing import List, Tuple, Union

import torch
import torch.nn as nn
from torch.distributions import Independent, Normal
import treetensor
from easydict import EasyDict
from tensordict import TensorDict

from grl.generative_models.intrinsic_model import IntrinsicModel
from grl.generative_models.model_functions.velocity_function import VelocityFunction
from grl.generative_models.random_generator import gaussian_random_variable
from grl.generative_models.stochastic_process import StochasticProcess
from grl.generative_models.metric import compute_likelihood
from grl.numerical_methods.numerical_solvers import get_solver
from grl.numerical_methods.numerical_solvers.dpm_solver import DPMSolver
from grl.numerical_methods.numerical_solvers.ode_solver import (
    DictTensorODESolver,
    ODESolver,
)
from grl.numerical_methods.numerical_solvers.sde_solver import SDESolver
from grl.numerical_methods.probability_path import (
    ConditionalProbabilityPath,
)
from grl.utils import find_parameters

class IndependentConditionalFlowModel(nn.Module):
    """
    Overview:
        The independent conditional flow model, which is a flow model with independent conditional probability paths.
    Interfaces:
        ``__init__``, ``get_type``, ``sample``, ``sample_forward_process``, ``log_prob``, ``sample_with_log_prob``,
        ``flow_matching_loss``
    """

    def __init__(
        self,
        config: EasyDict,
    ):
        """
        Overview:
            Initialize the model.
        Arguments:
            config (:obj:`EasyDict`): The configuration of the model. Reads ``x_size``, ``device``, ``path``,
                ``model.type`` (must be ``"velocity_function"``) and ``model.args``; optionally ``use_tree_tensor``,
                ``reverse_path`` and ``solver`` (``solver.type`` / ``solver.args``).
        """
        super().__init__()
        self.config = config
        self.x_size = config.x_size
        self.device = config.device

        self.gaussian_generator = gaussian_random_variable(
            config.x_size,
            config.device,
            config.use_tree_tensor if hasattr(config, "use_tree_tensor") else False,
        )

        self.path = ConditionalProbabilityPath(config.path)
        if hasattr(config, "reverse_path"):
            self.reverse_path = ConditionalProbabilityPath(config.reverse_path)
        else:
            self.reverse_path = None
        self.model_type = config.model.type
        assert self.model_type in [
            "velocity_function",
        ], "Unknown type of model {}".format(self.model_type)
        self.model = IntrinsicModel(config.model.args)
        self.diffusion_process = StochasticProcess(self.path)
        self.velocity_function_ = VelocityFunction(
            self.model_type, self.diffusion_process
        )

        # A default solver is optional; without it every sampling call must pass ``solver_config``.
        if hasattr(config, "solver"):
            self.solver = get_solver(config.solver.type)(**config.solver.args)

    def get_type(self):
        return "IndependentConditionalFlowModel"

    def sample(
        self,
        t_span: torch.Tensor = None,
        batch_size: Union[torch.Size, int, Tuple[int], List[int]] = None,
        x_0: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None,
        condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None,
        with_grad: bool = False,
        solver_config: EasyDict = None,
    ):
        """
        Overview:
            Sample from the model, return the final state.
        Arguments:
            t_span (:obj:`torch.Tensor`): The time span.
            batch_size (:obj:`Union[torch.Size, int, Tuple[int], List[int]]`): The batch size.
            x_0 (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The initial state, if not provided, it will be sampled from the Gaussian distribution.
            condition (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input condition.
            with_grad (:obj:`bool`): Whether to return the gradient.
            solver_config (:obj:`EasyDict`): The configuration of the solver.
        Returns:
            x (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The sampled result, i.e. the last
                entry of the trajectory returned by ``sample_forward_process``.
        Shapes:
            t_span: :math:`(T)`, where :math:`T` is the number of time steps.
            batch_size: :math:`(B)`, where :math:`B` is the batch size of data, which could be a scalar or a tensor such as :math:`(B1, B2)`.
            x_0: :math:`(N, D)`, where :math:`N` is the batch size of data and :math:`D` is the dimension of the state, which could be a scalar or a tensor such as :math:`(D1, D2)`.
            condition: :math:`(N, D)`, where :math:`N` is the batch size of data and :math:`D` is the dimension of the condition, which could be a scalar or a tensor such as :math:`(D1, D2)`.
            x: :math:`(N, D)`, if extra batch size :math:`B` is provided, the shape will be :math:`(B, N, D)`. If x_0 is not provided, the shape will be :math:`(B, D)`. If x_0 and condition are not provided, the shape will be :math:`(D)`.
        """
        return self.sample_forward_process(
            t_span=t_span,
            batch_size=batch_size,
            x_0=x_0,
            condition=condition,
            with_grad=with_grad,
            solver_config=solver_config,
        )[-1]

    @staticmethod
    def _as_shape(x_size) -> Tuple[int, ...]:
        """
        Overview:
            Normalize a per-sample size (an int or a sequence of ints) into a shape tuple.
        Arguments:
            x_size (:obj:`Union[int, Tuple[int], List[int], torch.Size]`): The configured size of a single sample.
        Returns:
            shape (:obj:`Tuple[int, ...]`): The per-sample shape.
        Raises:
            AssertionError: If ``x_size`` is neither an int nor a tuple/list/``torch.Size``.
        """
        if isinstance(x_size, int):
            return (x_size,)
        if isinstance(x_size, (tuple, list, torch.Size)):
            return tuple(x_size)
        raise AssertionError("Invalid x_size")

    @staticmethod
    def _unflatten_samples(
        data: torch.Tensor,
        x_shape: Tuple[int, ...],
        extra_batch_dims: List[int],
        data_batch_size: int,
        squeeze_data_batch: bool,
        squeeze_extra_batch: bool,
    ) -> torch.Tensor:
        """
        Overview:
            Reshape a solver trajectory from the flat layout :math:`(T, B*N, D)` into :math:`(T, *B, N, D)` and drop
            the size-1 axes that the caller did not ask for.
        Arguments:
            data (:obj:`torch.Tensor`): The flat trajectory, :math:`(T, B*N, D)`.
            x_shape (:obj:`Tuple[int, ...]`): The per-sample shape :math:`D`.
            extra_batch_dims (:obj:`List[int]`): The extra batch axes :math:`B` (possibly several, e.g. ``[B1, B2]``).
            data_batch_size (:obj:`int`): The data batch size :math:`N`.
            squeeze_data_batch (:obj:`bool`): Drop the :math:`N` axis (only valid when :math:`N = 1`).
            squeeze_extra_batch (:obj:`bool`): Drop the single extra batch axis (only valid when it has size 1).
        Returns:
            data (:obj:`torch.Tensor`): The reshaped trajectory.
        """
        data = data.reshape(-1, *extra_batch_dims, data_batch_size, *x_shape)
        if squeeze_data_batch:
            # The N axis sits right after *all* extra batch axes, not always at index 2.
            data = data.squeeze(1 + len(extra_batch_dims))
        if squeeze_extra_batch:
            data = data.squeeze(1)
        return data

    def sample_forward_process(
        self,
        t_span: torch.Tensor = None,
        batch_size: Union[torch.Size, int, Tuple[int], List[int]] = None,
        x_0: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None,
        condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None,
        with_grad: bool = False,
        solver_config: EasyDict = None,
    ):
        """
        Overview:
            Sample from the diffusion model, return all intermediate states.
        Arguments:
            t_span (:obj:`torch.Tensor`): The time span.
            batch_size (:obj:`Union[torch.Size, int, Tuple[int], List[int]]`): An extra batch size used for repeated sampling with the same initial state.
            x_0 (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The initial state, if not provided, it will be sampled from the Gaussian distribution.
            condition (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input condition.
            with_grad (:obj:`bool`): Whether to return the gradient. When False, the solver runs under ``torch.no_grad()``.
            solver_config (:obj:`EasyDict`): The configuration of the solver. Overrides the solver from the model config.
        Returns:
            x (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The sampled result.
        Raises:
            AssertionError: On an invalid ``batch_size``/``x_size``, mismatched batch sizes of ``x_0`` and
                ``condition``, or when no solver is configured.
            NotImplementedError: For DPM/SDE solvers, unknown solver types, and ``TensorDict`` outputs.
        Shapes:
            t_span: :math:`(T)`, where :math:`T` is the number of time steps.
            batch_size: :math:`(B)`, where :math:`B` is the batch size of data, which could be a scalar or a tensor such as :math:`(B1, B2)`.
            x_0: :math:`(N, D)`, where :math:`N` is the batch size of data and :math:`D` is the dimension of the state, which could be a scalar or a tensor such as :math:`(D1, D2)`.
            condition: :math:`(N, D)`, where :math:`N` is the batch size of data and :math:`D` is the dimension of the condition, which could be a scalar or a tensor such as :math:`(D1, D2)`.
            x: :math:`(T, N, D)`, if extra batch size :math:`B` is provided, the shape will be :math:`(T, B, N, D)`. If x_0 is not provided, the shape will be :math:`(T, B, D)`. If x_0 and condition are not provided, the shape will be :math:`(T, D)`.
        """
        if t_span is not None:
            t_span = t_span.to(self.device)

        # ``extra_batch_size`` is always a 1-D tensor listing the extra batch axes.
        if batch_size is None:
            extra_batch_size = torch.tensor((1,), device=self.device)
        elif isinstance(batch_size, int):
            extra_batch_size = torch.tensor((batch_size,), device=self.device)
        elif isinstance(batch_size, (torch.Size, tuple, list)):
            extra_batch_size = torch.tensor(batch_size, device=self.device)
        else:
            raise AssertionError("Invalid batch size")
        total_repeats = torch.prod(extra_batch_size)

        if x_0 is not None and condition is not None:
            assert (
                x_0.shape[0] == condition.shape[0]
            ), "The batch size of x_0 and condition must be the same"
            data_batch_size = x_0.shape[0]
        elif x_0 is not None:
            data_batch_size = x_0.shape[0]
        elif condition is not None:
            data_batch_size = condition.shape[0]
        else:
            data_batch_size = 1

        if solver_config is not None:
            solver = get_solver(solver_config.type)(**solver_config.args)
        else:
            assert hasattr(
                self, "solver"
            ), "solver must be specified in config or solver_config"
            solver = self.solver

        if x_0 is None:
            x = self.gaussian_generator(batch_size=total_repeats * data_batch_size)
            # x.shape = (B*N, D)
        else:
            assert (
                torch.Size(self._as_shape(self.x_size)) == x_0[0].shape
            ), "The shape of x_0 must be the same as the x_size that is specified in the config"
            x = torch.repeat_interleave(x_0, total_repeats, dim=0)
            # x.shape = (B*N, D)

        if condition is not None:
            condition = torch.repeat_interleave(condition, total_repeats, dim=0)
            # condition.shape = (B*N, D)

        # TODO: make it compatible with TensorDict
        def drift(t, x):
            return self.model(t=t, x=x, condition=condition)

        def integrate(**kwargs):
            # Only build the autograd graph when gradients were requested.
            if with_grad:
                return solver.integrate(drift=drift, x0=x, t_span=t_span, **kwargs)
            with torch.no_grad():
                return solver.integrate(drift=drift, x0=x, t_span=t_span, **kwargs)

        if isinstance(solver, DPMSolver):
            raise NotImplementedError("Not implemented")
        elif isinstance(solver, ODESolver):
            if solver.library == "torchdiffeq_adjoint":
                # The adjoint method needs the parameters explicitly to back-propagate into them.
                data = integrate(adjoint_params=find_parameters(self.model))
            else:
                data = integrate()
        elif isinstance(solver, DictTensorODESolver):
            data = integrate(
                batch_size=total_repeats * data_batch_size,
                x_size=x.shape,
            )
        elif isinstance(solver, SDESolver):
            raise NotImplementedError("Not implemented")
        else:
            # Report the solver actually in use; ``self.config.solver`` may not exist when
            # the solver came from ``solver_config``.
            solver_type = (
                solver_config.type
                if solver_config is not None
                else self.config.solver.type
            )
            raise NotImplementedError(
                "Solver type {} is not implemented".format(solver_type)
            )

        # NOTE(review): ``repeat_interleave`` lays the flat batch out as index ``n * B + b`` (each initial
        # state / condition repeated B times consecutively), whereas the reshape to (T, *B, N, D) reads it
        # as ``b * N + n``. When both B > 1 and N > 1, ``data[:, b, n]`` is therefore not guaranteed to come
        # from ``x_0[n]`` / ``condition[n]``. Flattening the (B, N) axes back recovers the repeat_interleave
        # order, which callers may rely on, so the layout is left unchanged here -- confirm before changing.
        extra_batch_dims = extra_batch_size.tolist()
        unflatten_kwargs = dict(
            extra_batch_dims=extra_batch_dims,
            data_batch_size=data_batch_size,
            squeeze_data_batch=x_0 is None and condition is None,
            squeeze_extra_batch=batch_size is None,
        )
        if isinstance(data, torch.Tensor):
            # data.shape = (T, B*N, D) -> (T, B, N, D), then squeeze as requested
            data = self._unflatten_samples(
                data, self._as_shape(self.x_size), **unflatten_kwargs
            )
        elif isinstance(data, TensorDict):
            raise NotImplementedError("Not implemented")
        elif isinstance(data, treetensor.torch.Tensor):
            for key in data.keys():
                data[key] = self._unflatten_samples(
                    data[key], self._as_shape(self.x_size[key]), **unflatten_kwargs
                )
        else:
            raise NotImplementedError("Not implemented")
        return data

    def log_prob(
        self,
        x_1: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor],
        log_prob_x_0: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None,
        function_log_prob_x_0: Union[callable, nn.Module] = None,
        condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None,
        t: torch.Tensor = None,
        using_Hutchinson_trace_estimator: bool = True,
    ) -> torch.Tensor:
        """
        Overview:
            Compute the log probability of the final state given the initial state and the condition.
            The reversed velocity field is integrated from ``x_1`` back to ``x_0`` with an adjoint ODE solver,
            while the (negative) divergence of the drift is accumulated alongside it.
        Arguments:
            x_1 (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The final state.
            log_prob_x_0 (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The log probability of the initial state.
            function_log_prob_x_0 (:obj:`Union[callable, nn.Module]`): The function to compute the log probability of the initial state.
            condition (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The condition.
            t (:obj:`torch.Tensor`): The time span. Defaults to 1000 points in ``[1e-3, 1]``.
            using_Hutchinson_trace_estimator (:obj:`bool`): Whether to use Hutchinson trace estimator. It is an approximation of the trace of the Jacobian of the drift function, \
                which is faster but less accurate. We recommend setting it to True for high dimensional data.
        Returns:
            log_likelihood (:obj:`torch.Tensor`): The log likelihood of the final state given the initial state and
                the condition, one value per row of ``x_1``. If neither ``log_prob_x_0`` nor
                ``function_log_prob_x_0`` is given, a standard normal prior over the flattened ``x_0`` is used.
        """

        def model_drift(t, x):
            # Integrating backward in time: negate the velocity and evaluate it at 1 - t.
            return -self.model(1 - t, x, condition)

        model_params = find_parameters(self.model)

        def compute_trace_of_jacobian_general(dx, x):
            # Exact trace: one backward pass per element of a single sample, summed over the diagonal.
            # Assumes x has shape (B, D1, ..., Dn). ``create_graph`` is not set, so the result is not
            # itself differentiable.
            shape = x.shape[1:]  # get the shape of a single element in the batch
            outputs = torch.zeros(
                x.shape[0], device=x.device, dtype=x.dtype
            )  # trace for each batch
            # Iterate through each index in the product of dimensions
            for index in torch.cartesian_prod(*(torch.arange(s) for s in shape)):
                # cartesian_prod of a single range yields 0-dim entries; normalize to a tuple.
                if len(index.shape) > 0:
                    index = tuple(index)
                else:
                    index = (index,)
                grad_outputs = torch.zeros_like(x)
                grad_outputs[(slice(None), *index)] = (
                    1  # set one at the specific index across all batches
                )
                grads = torch.autograd.grad(
                    outputs=dx, inputs=x, grad_outputs=grad_outputs, retain_graph=True
                )[0]
                outputs += grads[(slice(None), *index)]
            return outputs

        def compute_trace_of_jacobian_by_Hutchinson_Skilling(dx, x, eps):
            """Estimate the per-sample trace of d(dx)/dx as ``eps^T J eps`` (Hutchinson-Skilling estimator)."""
            fn_eps = torch.sum(dx * eps)
            grad_fn_eps = torch.autograd.grad(fn_eps, x, create_graph=True)[0]
            outputs = torch.sum(grad_fn_eps * eps, dim=tuple(range(1, len(x.shape))))
            return outputs

        def composite_drift(t, x):
            # x is the augmented state (x_t, logp_xt_minus_logp_x0); return the drift of both parts.
            with torch.set_grad_enabled(True):
                t = t.detach()
                x_t = x[0].detach()
                logp_xt_minus_logp_x0 = x[1]

                x_t.requires_grad = True
                t.requires_grad = True

                dx = model_drift(t, x_t)
                if using_Hutchinson_trace_estimator:
                    noise = torch.randn_like(x_t, device=x_t.device)
                    logp_drift = -compute_trace_of_jacobian_by_Hutchinson_Skilling(
                        dx, x_t, noise
                    )
                else:
                    logp_drift = -compute_trace_of_jacobian_general(dx, x_t)

                return dx, logp_drift

        # x.shape = [batch_size, state_dim]
        x1_and_diff_logp = (
            x_1,
            torch.zeros(x_1.shape[0], device=x_1.device, dtype=x_1.dtype),
        )

        if t is None:
            eps = 1e-3
            # Bug fix: the original referenced an undefined ``x`` here.
            t_span = torch.linspace(eps, 1.0, 1000).to(x_1.device)
        else:
            t_span = t.to(x_1.device)

        solver = ODESolver(library="torchdiffeq_adjoint")

        x0_and_logpx0 = solver.integrate(
            drift=composite_drift,
            x0=x1_and_diff_logp,
            t_span=t_span,
            adjoint_params=model_params,
        )

        # Take the last time step of each component of the trajectory.
        logp_x0_minus_logp_x1 = x0_and_logpx0[1][-1]
        x0 = x0_and_logpx0[0][-1]

        if log_prob_x_0 is not None:
            log_likelihood = log_prob_x_0 - logp_x0_minus_logp_x1
        elif function_log_prob_x_0 is not None:
            log_prob_x_0 = function_log_prob_x_0(x0)
            log_likelihood = log_prob_x_0 - logp_x0_minus_logp_x1
        else:
            # Default prior: standard normal over the flattened initial state.
            x0_1d = x0.reshape(x0.shape[0], -1)
            log_prob_x_0 = Independent(
                Normal(
                    loc=torch.zeros_like(x0_1d, device=x0_1d.device),
                    scale=torch.ones_like(x0_1d, device=x0_1d.device),
                ),
                1,
            ).log_prob(x0_1d)
            log_likelihood = log_prob_x_0 - logp_x0_minus_logp_x1

        return log_likelihood

    def sample_with_log_prob(
        self,
        t_span: torch.Tensor = None,
        batch_size: Union[torch.Size, int, Tuple[int], List[int]] = None,
        x_0: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None,
        condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None,
        with_grad: bool = False,
        solver_config: EasyDict = None,
        using_Hutchinson_trace_estimator: bool = True,
    ) -> Tuple[
        Union[torch.Tensor, TensorDict, treetensor.torch.Tensor], torch.Tensor
    ]:
        """
        Overview:
            Sample from the model, return the final state and its log probability.
        Arguments:
            t_span (:obj:`torch.Tensor`): The time span.
            batch_size (:obj:`Union[torch.Size, int, Tuple[int], List[int]]`): The batch size.
            x_0 (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The initial state, if not provided, it will be sampled from the Gaussian distribution.
            condition (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input condition.
            with_grad (:obj:`bool`): Whether to return the gradient.
            solver_config (:obj:`EasyDict`): The configuration of the solver.
            using_Hutchinson_trace_estimator (:obj:`bool`): Whether to use Hutchinson trace estimator. It is an approximation of the trace of the Jacobian of the drift function, \
                which is faster but less accurate. We recommend setting it to True for high dimensional data.
        Returns:
            x (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The sampled result.
            log_prob_x_0 (:obj:`torch.Tensor`): The log probability of the sampled result, as computed by ``log_prob``.
        """
        # NOTE(review): ``log_prob`` treats the first axis of ``x_1`` as the batch axis, but when
        # ``batch_size`` is given ``x`` carries extra leading batch axes while ``condition`` is not
        # repeated to match -- confirm the intended shapes for that case.
        x = self.sample(
            t_span=t_span,
            batch_size=batch_size,
            x_0=x_0,
            condition=condition,
            with_grad=with_grad,
            solver_config=solver_config,
        )

        log_prob_x_0 = self.log_prob(
            x_1=x,
            condition=condition,
            t=t_span,
            using_Hutchinson_trace_estimator=using_Hutchinson_trace_estimator,
        )

        return x, log_prob_x_0

    def flow_matching_loss(
        self,
        x0: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor],
        x1: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor],
        condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None,
        average: bool = True,
    ) -> torch.Tensor:
        """
        Overview:
            Return the flow matching loss function of the model given the initial state and the condition.
            Delegates to ``VelocityFunction.flow_matching_loss_icfm``.
        Arguments:
            x0 (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The initial state.
            x1 (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The final state.
            condition (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The condition for the flow matching loss.
            average (:obj:`bool`): Whether to average the loss across the batch.
        Returns:
            loss (:obj:`torch.Tensor`): The flow matching loss.
        """
        return self.velocity_function_.flow_matching_loss_icfm(
            self.model, x0, x1, condition, average
        )