Shortcuts

Source code for grl.rl_modules.value_network.one_shot_value_function

import copy
from typing import Tuple, Union

import torch
import torch.nn as nn
from easydict import EasyDict
from tensordict import TensorDict

from grl.rl_modules.value_network.value_network import DoubleVNetwork


class OneShotValueFunction(nn.Module):
    """
    Overview:
        State value network for the one-shot setting: the network is fitted \
        directly to externally supplied values, so training requires no Bellman backup.
    Interfaces:
        ``__init__``, ``forward``, ``compute_double_v``, ``v_loss``
    """

    def __init__(self, config: EasyDict):
        """
        Overview:
            Build the one-shot value network together with a frozen copy of it.
        Arguments:
            config (:obj:`EasyDict`): The configuration dict. It must provide ``v_alpha`` and \
                ``DoubleVNetwork`` (the latter is forwarded to ``DoubleVNetwork``).
        """

        super().__init__()
        self.config = config
        self.v_alpha = config.v_alpha

        # Trainable double value network.
        self.v = DoubleVNetwork(config.DoubleVNetwork)

        # Independent copy of ``self.v`` with gradients disabled on its parameters.
        # It is not referenced by the other methods shown in this class.
        frozen_copy = copy.deepcopy(self.v)
        frozen_copy.requires_grad_(False)
        self.v_target = frozen_copy

    def forward(
        self,
        state: Union[torch.Tensor, TensorDict],
        condition: Union[torch.Tensor, TensorDict] = None,
    ) -> torch.Tensor:
        """
        Overview:
            Evaluate the value network ``self.v`` on the given input.
        Arguments:
            state (:obj:`Union[torch.Tensor, TensorDict]`): The input state.
            condition (:obj:`Union[torch.Tensor, TensorDict]`): The input condition, ``None`` by default.
        Returns:
            value (:obj:`torch.Tensor`): Whatever ``self.v(state, condition)`` returns.
        """

        value = self.v(state, condition)
        return value

    def compute_double_v(
        self,
        state: Union[torch.Tensor, TensorDict],
        condition: Union[torch.Tensor, TensorDict] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Return both outputs of the underlying double value network by delegating \
            to ``self.v.compute_double_v``.
        Arguments:
            state (:obj:`Union[torch.Tensor, TensorDict]`): The input state.
            condition (:obj:`Union[torch.Tensor, TensorDict]`): The input condition, ``None`` by default.
        Returns:
            v1 (:obj:`Union[torch.Tensor, TensorDict]`): The output of the first value network.
            v2 (:obj:`Union[torch.Tensor, TensorDict]`): The output of the second value network.
        """

        double_v = self.v.compute_double_v(state, condition=condition)
        return double_v

    def v_loss(
        self,
        state: Union[torch.Tensor, TensorDict],
        value: Union[torch.Tensor, TensorDict],
        condition: Union[torch.Tensor, TensorDict] = None,
    ) -> torch.Tensor:
        """
        Overview:
            Compute the value loss: the average of the two MSE losses between each \
            value network output and the given ``value``.
        Arguments:
            state (:obj:`Union[torch.Tensor, TensorDict]`): The input state.
            value (:obj:`Union[torch.Tensor, TensorDict]`): The regression target for both value outputs.
            condition (:obj:`Union[torch.Tensor, TensorDict]`): The input condition, ``None`` by default.
        Returns:
            v_loss (:obj:`torch.Tensor`): The v loss.
        """

        mse = torch.nn.functional.mse_loss

        # The supplied value is used as the target as-is; no bootstrapped target is built.
        pred_first, pred_second = self.v.compute_double_v(state, condition=condition)
        loss_first = mse(pred_first, value)
        loss_second = mse(pred_second, value)

        return (loss_first + loss_second) / 2