
Source code for grl.rl_modules.value_network.q_network

from typing import Tuple, Union

import torch
import torch.nn as nn
from easydict import EasyDict
from tensordict import TensorDict

from grl.neural_network import get_module
from grl.neural_network.encoders import get_encoder


class QNetwork(nn.Module):
    """
    Overview:
        Q network, which is used to approximate the Q value.
    Interfaces:
        ``__init__``, ``forward``
    """
    def __init__(self, config: EasyDict):
        """
        Overview:
            Initialization of Q network.
        Arguments:
            config (:obj:`EasyDict`): The configuration dict.
        """
        super().__init__()
        self.config = config
        self.model = torch.nn.ModuleDict()
        if hasattr(config, "action_encoder"):
            self.model["action_encoder"] = get_encoder(config.action_encoder.type)(
                **config.action_encoder.args
            )
        else:
            self.model["action_encoder"] = torch.nn.Identity()
        if hasattr(config, "state_encoder"):
            self.model["state_encoder"] = get_encoder(config.state_encoder.type)(
                **config.state_encoder.args
            )
        else:
            self.model["state_encoder"] = torch.nn.Identity()
        # TODO: specific backbone network
        self.model["backbone"] = get_module(config.backbone.type)(
            **config.backbone.args
        )
    def forward(
        self,
        action: Union[torch.Tensor, TensorDict],
        state: Union[torch.Tensor, TensorDict],
    ) -> torch.Tensor:
        """
        Overview:
            Return the output of the Q network.
        Arguments:
            action (:obj:`Union[torch.Tensor, TensorDict]`): The input action.
            state (:obj:`Union[torch.Tensor, TensorDict]`): The input state.
        Returns:
            q (:obj:`Union[torch.Tensor, TensorDict]`): The output of the Q network.
        """
        action_embedding = self.model["action_encoder"](action)
        state_embedding = self.model["state_encoder"](state)
        return self.model["backbone"](action_embedding, state_embedding)
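The backbone named by ``config.backbone.type`` must accept the action and state embeddings as two positional arguments, mirroring the call in ``QNetwork.forward``. Below is a minimal, self-contained sketch of that contract in plain PyTorch; ``ConcatMLPBackbone`` is hypothetical and not part of grl, where backbones are instead looked up through ``grl.neural_network.get_module``.

    import torch
    import torch.nn as nn

    class ConcatMLPBackbone(nn.Module):
        # Hypothetical backbone: receives the action and state embeddings as
        # two positional arguments, as QNetwork.forward requires.
        def __init__(self, action_dim: int, state_dim: int, hidden_dim: int = 64):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(action_dim + state_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, 1),  # one scalar Q value per sample
            )

        def forward(self, action_embedding, state_embedding):
            # Concatenate along the feature dimension, then score.
            return self.net(torch.cat([action_embedding, state_embedding], dim=-1))

    action = torch.randn(8, 3)   # batch of 8 actions, 3-dim each
    state = torch.randn(8, 11)   # batch of 8 states, 11-dim each
    backbone = ConcatMLPBackbone(action_dim=3, state_dim=11)
    q = backbone(action, state)  # shape (8, 1)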
class DoubleQNetwork(nn.Module):
    """
    Overview:
        Double Q network, which has two Q networks.
    Interfaces:
        ``__init__``, ``forward``, ``compute_double_q``, ``compute_mininum_q``
    """
    def __init__(self, config: EasyDict):
        """
        Overview:
            Initialization of the double Q network, which builds two
            independent Q networks from the same configuration.
        Arguments:
            config (:obj:`EasyDict`): The configuration dict.
        """
        super().__init__()
        self.model = torch.nn.ModuleDict()
        self.model["q1"] = QNetwork(config)
        self.model["q2"] = QNetwork(config)
    def compute_double_q(
        self,
        action: Union[torch.Tensor, TensorDict],
        state: Union[torch.Tensor, TensorDict],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Return the outputs of both Q networks.
        Arguments:
            action (:obj:`Union[torch.Tensor, TensorDict]`): The input action.
            state (:obj:`Union[torch.Tensor, TensorDict]`): The input state.
        Returns:
            q1 (:obj:`Union[torch.Tensor, TensorDict]`): The output of the first Q network.
            q2 (:obj:`Union[torch.Tensor, TensorDict]`): The output of the second Q network.
        """
        return self.model["q1"](action, state), self.model["q2"](action, state)
    def compute_mininum_q(
        self,
        action: Union[torch.Tensor, TensorDict],
        state: Union[torch.Tensor, TensorDict],
    ) -> torch.Tensor:
        """
        Overview:
            Return the minimum of the outputs of the two Q networks.
        Arguments:
            action (:obj:`Union[torch.Tensor, TensorDict]`): The input action.
            state (:obj:`Union[torch.Tensor, TensorDict]`): The input state.
        Returns:
            minimum_q (:obj:`Union[torch.Tensor, TensorDict]`): The element-wise minimum of the two Q network outputs.
        """
        return torch.min(*self.compute_double_q(action, state))
    def forward(
        self,
        action: Union[torch.Tensor, TensorDict],
        state: Union[torch.Tensor, TensorDict],
    ) -> torch.Tensor:
        """
        Overview:
            Return the minimum of the outputs of the two Q networks.
        Arguments:
            action (:obj:`Union[torch.Tensor, TensorDict]`): The input action.
            state (:obj:`Union[torch.Tensor, TensorDict]`): The input state.
        Returns:
            minimum_q (:obj:`Union[torch.Tensor, TensorDict]`): The element-wise minimum of the two Q network outputs.
        """
        return self.compute_mininum_q(action, state)
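In practice, two critics of this kind are typically used for clipped double Q-learning (as in TD3 or SAC): the minimum of the two target estimates forms a conservative TD target, and both critics are regressed onto it. The sketch below illustrates the interface above and is not code from the library; it assumes ``double_q`` is an already-constructed ``DoubleQNetwork`` and that ``action``, ``state``, ``next_action``, ``next_state``, ``reward``, ``done``, and ``gamma`` come from a replay batch.

    import torch
    import torch.nn.functional as F

    with torch.no_grad():
        # Conservative target: element-wise minimum over the two critics.
        target_q = double_q.compute_mininum_q(next_action, next_state)
        td_target = reward + gamma * (1.0 - done) * target_q

    # Regress both critics onto the shared clipped target.
    q1, q2 = double_q.compute_double_q(action, state)
    critic_loss = F.mse_loss(q1, td_target) + F.mse_loss(q2, td_target)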