Shortcuts

Source code for grl.numerical_methods.ode

from typing import Callable, Union

from torch import nn


class ODE:
    """
    Overview:
        Base class for ordinary differential equations. The ODE is defined as:

        .. math::
            dx = f(x, t)dt

        where f(x, t) is the drift term.

    Interfaces:
        ``__init__``
    """

    def __init__(
        self,
        drift: Union[nn.Module, Callable, None] = None,
    ):
        """
        Overview:
            Initialize the ODE by storing its drift term on the instance.

        Arguments:
            drift (:obj:`Union[nn.Module, Callable, None]`): The drift term f(x, t), given as a \
                ``torch.nn.Module`` or any callable. Defaults to ``None``, in which case \
                ``self.drift`` is ``None``.
        """
        # The annotation includes ``None`` explicitly: the default is ``None``, and
        # PEP 484 no longer allows an implicit Optional.
        self.drift = drift