grl.algorithms¶
QGPOCritic¶
- class grl.algorithms.QGPOCritic(config)[source]¶
- Overview:
Critic network for QGPO algorithm.
- Interfaces:
__init__, forward
- __init__(config)[source]¶
- Overview:
Initialization of QGPO critic network.
- Parameters:
  - config (EasyDict) – The configuration dict.
- compute_double_q(action, state=None)[source]¶
- Overview:
Return the output of two Q networks.
- Parameters:
  - action (Union[torch.Tensor, TensorDict]) – The input action.
  - state (Union[torch.Tensor, TensorDict]) – The input state.
- Returns:
  - q1 (Union[torch.Tensor, TensorDict]): The output of the first Q network.
  - q2 (Union[torch.Tensor, TensorDict]): The output of the second Q network.
- forward(action, state=None)[source]¶
- Overview:
Return the output of QGPO critic.
- Parameters:
  - action (torch.Tensor) – The input action.
  - state (torch.Tensor) – The input state.
- Return type:
Tensor
- q_loss(action, state, reward, next_state, done, fake_next_action, discount_factor=1.0)[source]¶
- Overview:
Calculate the Q loss.
- Parameters:
  - action (torch.Tensor) – The input action.
  - state (torch.Tensor) – The input state.
  - reward (torch.Tensor) – The input reward.
  - next_state (torch.Tensor) – The input next state.
  - done (torch.Tensor) – The input done flag.
  - fake_next_action (torch.Tensor) – The input fake next action.
  - discount_factor (float) – The discount factor.
- Return type:
Tensor
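A minimal usage sketch for QGPOCritic (not taken from the library's documentation): it assumes an already constructed critic and a batch dict whose keys mirror the q_loss signature above; shapes and the discount value are illustrative assumptions.

    import torch
    from grl.algorithms import QGPOCritic

    def qgpo_critic_loss(critic: QGPOCritic, batch: dict, discount: float = 0.99) -> torch.Tensor:
        # Both Q heads, e.g. for a conservative (clipped double-Q) estimate.
        q1, q2 = critic.compute_double_q(batch["action"], batch["state"])
        q_min = torch.min(q1, q2)  # not used by q_loss itself; shown only to illustrate the interface
        # Bellman-style critic loss; fake_next_action holds sampled candidate next actions.
        return critic.q_loss(
            batch["action"], batch["state"], batch["reward"], batch["next_state"],
            batch["done"], batch["fake_next_action"], discount_factor=discount,
        )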
QGPOPolicy¶
- class grl.algorithms.QGPOPolicy(config)[source]¶
- Overview:
QGPO policy network.
- Interfaces:
__init__, forward, sample, behaviour_policy_sample, compute_q, behaviour_policy_loss, energy_guidance_loss, q_loss
- __init__(config)[source]¶
- Overview:
Initialize the QGPO policy network.
- Parameters:
  - config (EasyDict) – The configuration dict.
- behaviour_policy_loss(action, state)[source]¶
- Overview:
Calculate the behaviour policy loss.
- Parameters:
  - action (torch.Tensor) – The input action.
  - state (torch.Tensor) – The input state.
- behaviour_policy_sample(state, batch_size=None, solver_config=None, t_span=None)[source]¶
- Overview:
Return the output of behaviour policy, which is the action conditioned on the state.
- Parameters:
  - state (Union[torch.Tensor, TensorDict]) – The input state.
  - solver_config (EasyDict) – The configuration for the ODE solver.
  - t_span (torch.Tensor) – The time span for the ODE solver or SDE solver.
- Returns:
The output action.
- Return type:
action (Union[torch.Tensor, TensorDict])
- compute_q(state, action)[source]¶
- Overview:
Calculate the Q value.
- Parameters:
  - state (Union[torch.Tensor, TensorDict]) – The input state.
  - action (Union[torch.Tensor, TensorDict]) – The input action.
- Returns:
The Q value.
- Return type:
q (torch.Tensor)
- energy_guidance_loss(state, fake_next_action)[source]¶
- Overview:
Calculate the energy guidance loss of QGPO.
- Parameters:
  - state (Union[torch.Tensor, TensorDict]) – The input state.
  - fake_next_action (Union[torch.Tensor, TensorDict]) – The input fake next action.
- Return type:
Tensor
- forward(state)[source]¶
- Overview:
Return the output of QGPO policy, which is the action conditioned on the state.
- Parameters:
  - state (Union[torch.Tensor, TensorDict]) – The input state.
- Returns:
The output action.
- Return type:
action (Union[torch.Tensor, TensorDict])
- q_loss(action, state, reward, next_state, done, fake_next_action, discount_factor=1.0)[source]¶
- Overview:
Calculate the Q loss.
- Parameters:
  - action (torch.Tensor) – The input action.
  - state (torch.Tensor) – The input state.
  - reward (torch.Tensor) – The input reward.
  - next_state (torch.Tensor) – The input next state.
  - done (torch.Tensor) – The input done flag.
  - fake_next_action (torch.Tensor) – The input fake next action.
  - discount_factor (float) – The discount factor.
- Return type:
Tensor
- sample(state, batch_size=None, guidance_scale=tensor(1.), solver_config=None, t_span=None)[source]¶
- Overview:
Return the output of QGPO policy, which is the action conditioned on the state.
- Parameters:
  - state (Union[torch.Tensor, TensorDict]) – The input state.
  - guidance_scale (Union[torch.Tensor, float]) – The guidance scale.
  - solver_config (EasyDict) – The configuration for the ODE solver.
  - t_span (torch.Tensor) – The time span for the ODE solver or SDE solver.
- Returns:
The output action.
- Return type:
action (Union[torch.Tensor, TensorDict])
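A hypothetical sketch of how the documented QGPOPolicy interfaces fit together during training and acting; the batch keys, discount factor, and guidance scale are illustrative assumptions, not library defaults beyond the signatures above.

    import torch
    from grl.algorithms import QGPOPolicy

    def qgpo_policy_losses(policy: QGPOPolicy, batch: dict) -> dict:
        # The three QGPO training objectives exposed by this class.
        return {
            "behaviour_policy": policy.behaviour_policy_loss(batch["action"], batch["state"]),
            "energy_guidance": policy.energy_guidance_loss(batch["state"], batch["fake_next_action"]),
            "q": policy.q_loss(
                batch["action"], batch["state"], batch["reward"], batch["next_state"],
                batch["done"], batch["fake_next_action"], discount_factor=0.99,
            ),
        }

    def act(policy: QGPOPolicy, state: torch.Tensor, scale: float = 1.5) -> torch.Tensor:
        # A larger guidance_scale puts more weight on the energy (Q) guidance when sampling.
        return policy.sample(state, guidance_scale=torch.tensor(scale))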
QGPOAlgorithm¶
- class grl.algorithms.QGPOAlgorithm(config=None, simulator=None, dataset=None, model=None)[source]¶
- Overview:
Q-guided policy optimization (QGPO) algorithm, which is an offline reinforcement learning algorithm that uses energy-based diffusion model for policy modeling.
- Interfaces:
__init__, train, deploy
- __init__(config=None, simulator=None, dataset=None, model=None)[source]¶
- Overview:
Initialize the QGPO algorithm.
- Parameters:
  - config (EasyDict) – The configuration, which must contain the following keys:
    - train (EasyDict): The training configuration.
    - deploy (EasyDict): The deployment configuration.
  - simulator (object) – The environment simulator.
  - dataset (QGPODataset) – The dataset.
  - model (Union[torch.nn.Module, torch.nn.ModuleDict]) – The model.
- Interface:
__init__, train, deploy
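A workflow sketch for QGPOAlgorithm, under the assumption that config comes from one of the library's QGPO example configurations (its exact fields are not reproduced here).

    from easydict import EasyDict
    from grl.algorithms import QGPOAlgorithm

    def run_qgpo(config: EasyDict) -> QGPOAlgorithm:
        # config must carry the `train` and `deploy` sub-dicts described above.
        algo = QGPOAlgorithm(config)  # simulator, dataset and model can also be injected explicitly
        algo.train()                  # offline training driven by config.train
        algo.deploy()                 # deployment driven by config.deploy
        return algo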
SRPOCritic¶
SRPOPolicy¶
- class grl.algorithms.SRPOPolicy(config)[source]¶
- Overview:
The SRPO policy network.
- Interfaces:
__init__, forward, sample, behaviour_policy_loss, srpo_actor_loss
- __init__(config)[source]¶
- Overview:
Initialize the SRPO policy network.
- Parameters:
  - config (EasyDict) – The configuration.
- behaviour_policy_loss(action, state)[source]¶
- Overview:
Calculate the behaviour policy loss.
- Parameters:
  - action (torch.Tensor) – The input action.
  - state (torch.Tensor) – The input state.
- forward(state)[source]¶
- Overview:
Return the output of SRPO policy, which is the action conditioned on the state.
- Parameters:
  - state (Union[torch.Tensor, TensorDict]) – The input state.
- Returns:
The output action.
- Return type:
action (Union[torch.Tensor, TensorDict])
- sample(state, batch_size=None, solver_config=None, t_span=None)[source]¶
- Overview:
Return the output of SRPO policy, which is the action conditioned on the state.
- Parameters:
  - state (Union[torch.Tensor, TensorDict]) – The input state.
  - solver_config (EasyDict) – The configuration for the ODE solver.
  - t_span (torch.Tensor) – The time span for the ODE solver or SDE solver.
- Returns:
The output action.
- Return type:
action (Union[torch.Tensor, TensorDict])
- srpo_actor_loss(state)[source]¶
- Overview:
Calculate the SRPO actor loss.
- Parameters:
  - state (torch.Tensor) – The input state.
- Return type:
Tensor
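A sketch combining the two documented SRPOPolicy losses; the shapes and batching are assumptions about the caller's data, not library requirements beyond the signatures above.

    import torch
    from grl.algorithms import SRPOPolicy

    def srpo_losses(policy: SRPOPolicy, action: torch.Tensor, state: torch.Tensor) -> dict:
        bc_loss = policy.behaviour_policy_loss(action, state)  # diffusion behaviour-cloning term
        actor_loss = policy.srpo_actor_loss(state)             # actor objective on the same states
        return {"behaviour_policy": bc_loss, "srpo_actor": actor_loss}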
SRPOAlgorithm¶
- class grl.algorithms.SRPOAlgorithm(config=None, simulator=None, dataset=None, model=None)[source]¶
- __init__(config=None, simulator=None, dataset=None, model=None)[source]¶
- Overview:
Initialize the SRPO algorithm.
- Parameters:
  - config (EasyDict) – The configuration, which must contain the following keys:
    - train (EasyDict): The training configuration.
    - deploy (EasyDict): The deployment configuration.
  - simulator (object) – The environment simulator.
  - dataset (Dataset) – The dataset.
  - model (Union[torch.nn.Module, torch.nn.ModuleDict]) – The model.
- Interface:
__init__, train, deploy
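A sketch for SRPOAlgorithm that injects the simulator and dataset explicitly; their concrete types must match what the library expects (assumed here, not shown).

    from easydict import EasyDict
    from grl.algorithms import SRPOAlgorithm

    def run_srpo(config: EasyDict, simulator=None, dataset=None) -> SRPOAlgorithm:
        algo = SRPOAlgorithm(config, simulator=simulator, dataset=dataset)
        algo.train()
        algo.deploy()
        return algo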
GMPOCritic¶
- class grl.algorithms.GMPOCritic(config)[source]¶
- Overview:
Critic network for GMPO algorithm.
- Interfaces:
__init__, forward
- __init__(config)[source]¶
- Overview:
Initialization of GMPO critic network.
- Parameters:
  - config (EasyDict) – The configuration dict.
- compute_double_q(action, state=None)[source]¶
- Overview:
Return the output of two Q networks.
- Parameters:
  - action (Union[torch.Tensor, TensorDict]) – The input action.
  - state (Union[torch.Tensor, TensorDict]) – The input state.
- Returns:
  - q1 (Union[torch.Tensor, TensorDict]): The output of the first Q network.
  - q2 (Union[torch.Tensor, TensorDict]): The output of the second Q network.
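A short sketch of the GMPOCritic double-Q interface used for a conservative value estimate; this is an illustration of the clipped double-Q idea, not a prescribed GMPO update.

    import torch
    from grl.algorithms import GMPOCritic

    def conservative_q(critic: GMPOCritic, action: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
        # Element-wise minimum of the two Q heads, in the clipped double-Q style.
        q1, q2 = critic.compute_double_q(action, state)
        return torch.min(q1, q2)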
GMPOPolicy¶
- class grl.algorithms.GMPOPolicy(config)[source]¶
- Overview:
GMPO policy network for the GMPO algorithm, which includes the base model (optional), the guided model, and the critic.
- Interfaces:
__init__, forward, sample, compute_q, behaviour_policy_loss, policy_optimization_loss_by_advantage_weighted_regression, policy_optimization_loss_by_advantage_weighted_regression_softmax
- __init__(config)[source]¶
- Overview:
Initialize the GMPO policy network.
- Parameters:
  - config (EasyDict) – The configuration dict.
- behaviour_policy_loss(action, state, maximum_likelihood=False)[source]¶
- Overview:
Calculate the behaviour policy loss.
- Parameters:
  - action (torch.Tensor) – The input action.
  - state (torch.Tensor) – The input state.
- behaviour_policy_sample(state, batch_size=None, solver_config=None, t_span=None, with_grad=False)[source]¶
- Overview:
Return the output of behaviour policy, which is the action conditioned on the state.
- Parameters:
  - state (Union[torch.Tensor, TensorDict]) – The input state.
  - batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.
  - solver_config (EasyDict) – The configuration for the ODE solver.
  - t_span (torch.Tensor) – The time span for the ODE solver or SDE solver.
  - with_grad (bool) – Whether to calculate the gradient.
- Returns:
The output action.
- Return type:
action (Union[torch.Tensor, TensorDict])
- compute_q(state, action)[source]¶
- Overview:
Calculate the Q value.
- Parameters:
  - state (Union[torch.Tensor, TensorDict]) – The input state.
  - action (Union[torch.Tensor, TensorDict]) – The input action.
- Returns:
The Q value.
- Return type:
q (torch.Tensor)
- forward(state)[source]¶
- Overview:
Return the output of GMPO policy, which is the action conditioned on the state.
- Parameters:
  - state (Union[torch.Tensor, TensorDict]) – The input state.
- Returns:
The output action.
- Return type:
action (Union[torch.Tensor, TensorDict])
- policy_optimization_loss_by_advantage_weighted_regression(action, state, maximum_likelihood=False, beta=1.0, weight_clamp=100.0)[source]¶
- Overview:
Calculate the policy optimization loss using advantage-weighted regression.
- Parameters:
  - action (torch.Tensor) – The input action.
  - state (torch.Tensor) – The input state.
- policy_optimization_loss_by_advantage_weighted_regression_softmax(state, fake_action, maximum_likelihood=False, beta=1.0)[source]¶
- Overview:
Calculate the policy optimization loss using advantage-weighted regression with softmax-normalized weights.
- Parameters:
  - state (torch.Tensor) – The input state.
  - fake_action (torch.Tensor) – The input fake actions.
- sample(state, batch_size=None, solver_config=None, t_span=None, with_grad=False)[source]¶
- Overview:
Return the output of GMPO policy, which is the action conditioned on the state.
- Parameters:
  - state (Union[torch.Tensor, TensorDict]) – The input state.
  - batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.
  - solver_config (EasyDict) – The configuration for the ODE solver.
  - t_span (torch.Tensor) – The time span for the ODE solver or SDE solver.
- Returns:
The output action.
- Return type:
action (Union[torch.Tensor, TensorDict])
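A sketch combining the documented GMPOPolicy losses; the batch keys and the beta/weight_clamp/maximum_likelihood values are illustrative assumptions, not library defaults beyond the signatures above.

    import torch
    from grl.algorithms import GMPOPolicy

    def gmpo_losses(policy: GMPOPolicy, batch: dict, beta: float = 1.0) -> dict:
        # Behaviour-cloning pretraining term for the base generative model.
        bc = policy.behaviour_policy_loss(
            batch["action"], batch["state"], maximum_likelihood=True,
        )
        # Advantage-weighted regression objectives for policy extraction.
        awr = policy.policy_optimization_loss_by_advantage_weighted_regression(
            batch["action"], batch["state"], beta=beta, weight_clamp=100.0,
        )
        awr_softmax = policy.policy_optimization_loss_by_advantage_weighted_regression_softmax(
            batch["state"], batch["fake_action"], beta=beta,
        )
        return {"behaviour_policy": bc, "awr": awr, "awr_softmax": awr_softmax}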
GMPOAlgorithm¶
- class grl.algorithms.GMPOAlgorithm(config=None, simulator=None, dataset=None, model=None, seed=None)[source]¶
- Overview:
The Generative Model Policy Optimization (GMPO) algorithm.
- Interfaces:
__init__, train, deploy
- __init__(config=None, simulator=None, dataset=None, model=None, seed=None)[source]¶
- Overview:
Initialize the GMPO algorithm.
- Parameters:
  - config (EasyDict) – The configuration, which must contain the following keys:
    - train (EasyDict): The training configuration.
    - deploy (EasyDict): The deployment configuration.
  - simulator (object) – The environment simulator.
  - dataset (GPDataset) – The dataset.
  - model (Union[torch.nn.Module, torch.nn.ModuleDict]) – The model.
- Interface:
__init__, train, deploy
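A workflow sketch for GMPOAlgorithm, assuming a GMPO example configuration with the `train` and `deploy` sub-dicts described above; the seed value is illustrative.

    from easydict import EasyDict
    from grl.algorithms import GMPOAlgorithm

    def run_gmpo(config: EasyDict, seed: int = 0) -> GMPOAlgorithm:
        algo = GMPOAlgorithm(config, seed=seed)  # fixing the seed makes the run reproducible
        algo.train()
        algo.deploy()
        return algo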
GMPGCritic¶
- class grl.algorithms.GMPGCritic(config)[source]¶
- Overview:
Critic network for the GMPG algorithm.
- Interfaces:
__init__, forward
- __init__(config)[source]¶
- Overview:
Initialization of the GMPG critic network.
- Parameters:
  - config (EasyDict) – The configuration dict.
- compute_double_q(action, state=None)[source]¶
- Overview:
Return the output of two Q networks.
- Parameters:
  - action (Union[torch.Tensor, TensorDict]) – The input action.
  - state (Union[torch.Tensor, TensorDict]) – The input state.
- Returns:
  - q1 (Union[torch.Tensor, TensorDict]): The output of the first Q network.
  - q2 (Union[torch.Tensor, TensorDict]): The output of the second Q network.
- forward(action, state=None)[source]¶
- Overview:
Return the output of the GMPG critic.
- Parameters:
  - action (torch.Tensor) – The input action.
  - state (torch.Tensor) – The input state.
- Return type:
Tensor
- in_support_ql_loss(action, state, reward, next_state, done, fake_next_action, discount_factor=1.0)[source]¶
- Overview:
Calculate the in-support Q-learning loss.
- Parameters:
  - action (torch.Tensor) – The input action.
  - state (torch.Tensor) – The input state.
  - reward (torch.Tensor) – The input reward.
  - next_state (torch.Tensor) – The input next state.
  - done (torch.Tensor) – The input done flag.
  - fake_next_action (torch.Tensor) – The input fake next action.
  - discount_factor (float) – The discount factor.
- Return type:
Tensor
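A call sketch for GMPGCritic's in-support Q-learning loss; the batch layout, in particular the candidate actions carried by fake_next_action, is an assumption about the caller's data.

    import torch
    from grl.algorithms import GMPGCritic

    def gmpg_critic_loss(critic: GMPGCritic, batch: dict, discount: float = 0.99) -> torch.Tensor:
        # fake_next_action supplies the candidate next actions used for the target (assumed semantics).
        return critic.in_support_ql_loss(
            batch["action"], batch["state"], batch["reward"], batch["next_state"],
            batch["done"], batch["fake_next_action"], discount_factor=discount,
        )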
GMPGPolicy¶
- class grl.algorithms.GMPGPolicy(config)[source]¶
- __init__(config)[source]¶
- Overview:
Initialize the GMPG policy network.
- Parameters:
  - config (EasyDict) – The configuration dict.
- behaviour_policy_loss(action, state, maximum_likelihood=False)[source]¶
- Overview:
Calculate the behaviour policy loss.
- Parameters:
  - action (torch.Tensor) – The input action.
  - state (torch.Tensor) – The input state.
- behaviour_policy_sample(state, batch_size=None, solver_config=None, t_span=None, with_grad=False)[source]¶
- Overview:
Return the output of behaviour policy, which is the action conditioned on the state.
- Parameters:
  - state (Union[torch.Tensor, TensorDict]) – The input state.
  - batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.
  - solver_config (EasyDict) – The configuration for the ODE solver.
  - t_span (torch.Tensor) – The time span for the ODE solver or SDE solver.
  - with_grad (bool) – Whether to calculate the gradient.
- Returns:
The output action.
- Return type:
action (Union[torch.Tensor, TensorDict])
- compute_q(state, action)[source]¶
- Overview:
Calculate the Q value.
- Parameters:
  - state (Union[torch.Tensor, TensorDict]) – The input state.
  - action (Union[torch.Tensor, TensorDict]) – The input action.
- Returns:
The Q value.
- Return type:
q (torch.Tensor)
- forward(state)[source]¶
- Overview:
Return the output of the GMPG policy, which is the action conditioned on the state.
- Parameters:
  - state (Union[torch.Tensor, TensorDict]) – The input state.
- Returns:
The output action.
- Return type:
action (Union[torch.Tensor, TensorDict])
- sample(state, batch_size=None, solver_config=None, t_span=None, with_grad=False)[source]¶
- Overview:
Return the output of the GMPG policy, which is the action conditioned on the state.
- Parameters:
  - state (Union[torch.Tensor, TensorDict]) – The input state.
  - batch_size (Union[torch.Size, int, Tuple[int], List[int]]) – The batch size.
  - solver_config (EasyDict) – The configuration for the ODE solver.
  - t_span (torch.Tensor) – The time span for the ODE solver or SDE solver.
- Returns:
The output action.
- Return type:
action (Union[torch.Tensor, TensorDict])
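A sketch of differentiable sampling with the documented with_grad flag of GMPGPolicy; it illustrates a generic policy-gradient-through-sampling objective and is not the library's exact GMPG training step.

    import torch
    from grl.algorithms import GMPGPolicy

    def gmpg_actor_objective(policy: GMPGPolicy, state: torch.Tensor) -> torch.Tensor:
        # with_grad=True keeps the sampling path in the autograd graph,
        # so the negative Q value can be backpropagated into the policy.
        action = policy.sample(state, with_grad=True)
        q = policy.compute_q(state, action)
        return -q.mean()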
GMPGAlgorithm¶
- class grl.algorithms.GMPGAlgorithm(config=None, simulator=None, dataset=None, model=None, seed=None)[source]¶
- Overview:
The Generative Model Policy Gradient (GMPG) algorithm.
- Interfaces:
__init__, train, deploy
- __init__(config=None, simulator=None, dataset=None, model=None, seed=None)[source]¶
- Overview:
Initialize the GMPG algorithm.
- Parameters:
  - config (EasyDict) – The configuration, which must contain the following keys:
    - train (EasyDict): The training configuration.
    - deploy (EasyDict): The deployment configuration.
  - simulator (object) – The environment simulator.
  - dataset (GPDataset) – The dataset.
  - model (Union[torch.nn.Module, torch.nn.ModuleDict]) – The model.
- Interface:
__init__, train, deploy
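A workflow sketch for GMPGAlgorithm; the configuration, dataset, and seed follow the library's GMPG example setups, which are assumed rather than reproduced here.

    from easydict import EasyDict
    from grl.algorithms import GMPGAlgorithm

    def run_gmpg(config: EasyDict, dataset=None, seed: int = 0) -> GMPGAlgorithm:
        # dataset (a GPDataset) and model may be injected explicitly, as documented above.
        algo = GMPGAlgorithm(config, dataset=dataset, seed=seed)
        algo.train()
        algo.deploy()
        return algo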