atari_network
Source code: tianshou/env/atari/atari_network.py
- layer_init(layer: Module, std: float = 1.4142135623730951, bias_const: float = 0.0) → Module
TODO.
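The docstring gives no details, but the default std of 1.4142135623730951 equals √2, the standard gain for ReLU layers, which suggests the common "orthogonal weights, constant bias" initialization scheme. The following is a minimal sketch under that assumption only; layer_init_sketch is a hypothetical name and is not claimed to match the library's implementation.

```python
import math

import torch
from torch import nn


def layer_init_sketch(layer: nn.Module, std: float = math.sqrt(2), bias_const: float = 0.0) -> nn.Module:
    """Hypothetical re-implementation (assumption): orthogonal weights, constant bias."""
    # Orthogonal init scaled by the gain `std`; sqrt(2) is the usual ReLU gain.
    nn.init.orthogonal_(layer.weight, gain=std)
    # Every bias entry is set to the same constant.
    nn.init.constant_(layer.bias, bias_const)
    # The layer is modified in place and also returned, so calls can be chained.
    return layer


# Usage: initialize a linear layer with unit gain and keep the returned reference.
fc = layer_init_sketch(nn.Linear(8, 4), std=1.0)
```

Returning the layer lets the call wrap a constructor inline, e.g. inside an nn.Sequential definition.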
- class ScaledObsInputActionReprNet(module: ActionReprNetWithVectorOutput, denom: float = 255.0)
Bases: ActionReprNetWithVectorOutput
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(obs: Tensor | ndarray | BatchProtocol, state: T | None = None, info: dict[str, T] | None = None) → tuple[Tensor | Sequence[Tensor], T | None]
The main method Tianshou uses to compute action representations (such as actions, inputs of distributions, or Q-values) from environment observations. Implementations always use the preprocess_net as the first processing step.
- Parameters:
obs – the observations from the environment as retrieved from ObsBatchProtocol.obs. If the environment is a dict env, this will be an instance of Batch, otherwise it will be an array (or tensor if your env returns tensors).
state – the hidden state of the RNN, if applicable
info – the info object from the environment step
- Returns:
a tuple (action_repr, hidden_state), where action_repr is either an actual action for the environment or a representation from which it can be retrieved/sampled (e.g., mean and std for a Gaussian distribution), and hidden_state is the new hidden state of the RNN, if applicable.
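The class name and the default denom of 255.0 suggest that observations are divided by denom (mapping uint8 pixel values into [0, 1]) before being passed to the wrapped module. A minimal sketch under that assumption; ScaledObsSketch and _IdentityReprNet are hypothetical names, and the toy inner network simply follows the (action_repr, hidden_state) return contract described above.

```python
import numpy as np
import torch
from torch import nn


class ScaledObsSketch(nn.Module):
    """Hypothetical wrapper (assumption): divide observations by `denom`, then delegate."""

    def __init__(self, module: nn.Module, denom: float = 255.0) -> None:
        super().__init__()
        self.module = module
        self.denom = denom

    def forward(self, obs, state=None, info=None):
        # Accept numpy arrays as well as tensors, as the documented signature does.
        obs_t = torch.as_tensor(obs, dtype=torch.float32)
        # Rescale pixel values before handing them to the wrapped network.
        return self.module(obs_t / self.denom, state=state, info=info)


class _IdentityReprNet(nn.Module):
    """Toy inner net that returns its input and the unchanged hidden state."""

    def forward(self, obs, state=None, info=None):
        return obs, state


# Usage: a stack of four all-white 84x84 frames ends up as all ones.
net = ScaledObsSketch(_IdentityReprNet())
out, hidden = net(np.full((1, 4, 84, 84), 255, dtype=np.uint8))
```

Keeping the scaling in a wrapper lets the same inner network be reused with or without normalization.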
- class DQNet(c: int, h: int, w: int, action_shape: Sequence[int] | int, features_only: bool = False, output_dim_added_layer: int | None = None, layer_init: Callable[[Module], Module] = <function DQNet.<lambda>>)
Bases: ActionReprNetWithVectorOutput[Any]
Reference: Human-level control through deep reinforcement learning.
- forward(obs: Tensor | ndarray | BatchProtocol, state: T | None = None, info: dict[str, T] | None = None) → tuple[Tensor, T | None]
Mapping: s -> Q(s, *).
For more info, see docstring of parent.
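The referenced paper describes a convolutional Q-network with three convolutional layers (32 filters 8×8 stride 4, 64 filters 4×4 stride 2, 64 filters 3×3 stride 1), a 512-unit hidden layer, and one output per action. A minimal PyTorch sketch of that architecture follows; NatureDQNSketch is a hypothetical name, and it is not claimed to be identical to DQNet (for example, it ignores features_only, output_dim_added_layer and layer_init).

```python
import torch
from torch import nn


class NatureDQNSketch(nn.Module):
    """Hypothetical sketch of the Nature-DQN convolutional Q-network."""

    def __init__(self, c: int, h: int, w: int, num_actions: int) -> None:
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(c, 32, kernel_size=8, stride=4), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(),
            nn.Flatten(),
        )
        # Infer the flattened size from a dummy (c, h, w) input instead of hard-coding it.
        with torch.no_grad():
            flat_dim = self.features(torch.zeros(1, c, h, w)).shape[1]
        self.head = nn.Sequential(
            nn.Linear(flat_dim, 512), nn.ReLU(),
            nn.Linear(512, num_actions),
        )

    def forward(self, obs, state=None, info=None):
        # Mapping s -> Q(s, *): one Q-value per discrete action; state is passed through.
        q_values = self.head(self.features(torch.as_tensor(obs, dtype=torch.float32)))
        return q_values, state
```

For 84×84 inputs the spatial size shrinks 84 → 20 → 9 → 7, so the flattened feature vector has 64 · 7 · 7 = 3136 entries.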
- class C51Net(*, c: int, h: int, w: int, action_shape: Sequence[int], num_atoms: int = 51)
Bases: DQNet
Reference: A distributional perspective on reinforcement learning.
- forward(obs: Tensor | ndarray | BatchProtocol, state: T | None = None, info: dict[str, T] | None = None) → tuple[Tensor, T | None]
Mapping: x -> Z(x, *).
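In C51, Z(x, a) is a categorical distribution over num_atoms fixed support points, and Q(x, a) is its expectation. A minimal sketch, assuming a flat network output of size num_actions · num_atoms that is reshaped and normalized with a softmax over atoms; c51_q_values is a hypothetical helper, and the support bounds v_min and v_max are assumed values.

```python
import torch


def c51_q_values(logits: torch.Tensor, num_actions: int, num_atoms: int,
                 v_min: float = -10.0, v_max: float = 10.0):
    """Hypothetical helper: turn flat C51 logits into Z(x, *) and Q(x, *)."""
    batch = logits.shape[0]
    # One categorical distribution over `num_atoms` support points per action.
    probs = logits.reshape(batch, num_actions, num_atoms).softmax(dim=-1)
    # Fixed, evenly spaced support z_1..z_N on [v_min, v_max] (assumed bounds).
    support = torch.linspace(v_min, v_max, num_atoms)
    # Q(x, a) is the expectation of Z(x, a): sum_i p_i(x, a) * z_i.
    q = (probs * support).sum(dim=-1)
    return probs, q


# Usage: all-zero logits give uniform distributions, whose mean on a symmetric support is 0.
probs, q = c51_q_values(torch.zeros(3, 4 * 51), num_actions=4, num_atoms=51)
```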
- class RainbowNet(*, c: int, h: int, w: int, action_shape: Sequence[int], num_atoms: int = 51, noisy_std: float = 0.5, is_dueling: bool = True, is_noisy: bool = True)
Bases: DQNet
Reference: Rainbow: Combining Improvements in Deep Reinforcement Learning.
- forward(obs: Tensor | ndarray | BatchProtocol, state: T | None = None, info: dict[str, Any] | None = None) → tuple[Tensor, T | None]
Mapping: s -> Q(s, *).
For more info, see docstring of parent.
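With is_dueling enabled, Rainbow combines a value stream and an advantage stream per atom before normalizing over atoms. A minimal sketch of that aggregation, assuming value logits of shape (batch, 1, num_atoms) and advantage logits of shape (batch, num_actions, num_atoms); dueling_distribution is a hypothetical helper, and the noisy layers controlled by is_noisy and noisy_std are not shown.

```python
import torch


def dueling_distribution(value_logits: torch.Tensor, adv_logits: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: dueling aggregation applied per atom.

    value_logits: (batch, 1, num_atoms); adv_logits: (batch, num_actions, num_atoms).
    """
    # Subtract the mean advantage over actions so value and advantage are identifiable.
    logits = value_logits + adv_logits - adv_logits.mean(dim=1, keepdim=True)
    # Normalize each action's logits into a probability distribution over atoms.
    return logits.softmax(dim=-1)


# Usage with random streams for 2 observations, 6 actions and 51 atoms.
torch.manual_seed(0)
v = torch.randn(2, 1, 51)
a = torch.randn(2, 6, 51)
p = dueling_distribution(v, a)
```

Because the mean advantage is subtracted, shifting every advantage by the same constant leaves the output unchanged.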
- class QRDQNet(*, c: int, h: int, w: int, action_shape: Sequence[int] | int, num_quantiles: int = 200)
Bases: DQNet
Reference: Distributional Reinforcement Learning with Quantile Regression.
- forward(obs: Tensor | ndarray | BatchProtocol, state: T | None = None, info: dict[str, Any] | None = None) → tuple[Tensor, T | None]
Mapping: s -> Q(s, *).
For more info, see docstring of parent.
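In QR-DQN, each action's return distribution is represented by num_quantiles equally weighted quantile estimates, so Q(s, a) is their mean. A minimal sketch, assuming a flat output of size num_actions · num_quantiles; quantile_q_values and quantile_midpoints are hypothetical helpers, and the midpoints are the quantile fractions (2i + 1) / (2N) used as regression targets in the referenced paper.

```python
import torch


def quantile_q_values(output: torch.Tensor, num_actions: int, num_quantiles: int):
    """Hypothetical helper: read a flat QR-DQN output as per-action quantiles."""
    batch = output.shape[0]
    # One set of `num_quantiles` return quantiles per action.
    quantiles = output.reshape(batch, num_actions, num_quantiles)
    # Each quantile carries weight 1/N, so Q(s, a) is their plain mean.
    q = quantiles.mean(dim=-1)
    return quantiles, q


def quantile_midpoints(num_quantiles: int) -> torch.Tensor:
    # Quantile fractions tau_hat_i = (2i + 1) / (2N) for i = 0..N-1.
    i = torch.arange(num_quantiles, dtype=torch.float32)
    return (2 * i + 1) / (2 * num_quantiles)


# Usage: 2 observations, 3 actions, 4 quantiles each.
quantiles, q = quantile_q_values(
    torch.arange(24, dtype=torch.float32).reshape(2, 12), num_actions=3, num_quantiles=4
)
```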
- class ActorFactoryAtariDQN(scale_obs: bool = True, features_only: bool = False, output_dim_added_layer: int | None = None)
Bases: ActorFactory
- USE_SOFTMAX_OUTPUT = False
- create_module(envs: Environments, device: str | device) → DiscreteActor
- create_dist_fn(envs: Environments) → Callable[[tuple[Tensor, Tensor]], Distribution] | Callable[[Tensor], Distribution] | None
- Parameters:
envs – the environments
- Returns:
the distribution function, which converts the actor’s output into a distribution, or None if the actor does not output distribution parameters
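For a discrete actor, a distribution function of the Callable[[Tensor], Distribution] form can be as small as wrapping the output in a categorical distribution. A minimal sketch, assuming (as the flag name suggests) that USE_SOFTMAX_OUTPUT = False means the actor emits unnormalized scores rather than probabilities; logits_dist_fn is a hypothetical name and not necessarily what create_dist_fn returns.

```python
import torch
from torch.distributions import Categorical, Distribution


def logits_dist_fn(logits: torch.Tensor) -> Distribution:
    """Hypothetical dist_fn: interpret the actor's raw output as logits."""
    # Categorical normalizes the logits internally, so the actor can skip its own softmax.
    return Categorical(logits=logits)


# Usage: one indifferent observation and one with a strongly preferred first action.
dist = logits_dist_fn(torch.tensor([[0.0, 0.0], [10.0, -10.0]]))
actions = dist.sample()
```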
- class IntermediateModuleFactoryAtariDQN(features_only: bool = False, net_only: bool = False)
Bases: IntermediateModuleFactory
- create_intermediate_module(envs: Environments, device: str | device) → IntermediateModule