atari_network


layer_init(layer: Module, std: float = 1.4142135623730951, bias_const: float = 0.0) -> Module

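Initializes the given layer and returns it: the weights are initialized orthogonally with gain std and the bias is set to the constant bias_const. The default std is sqrt(2).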

class ScaledObsInputActionReprNet(module: ActionReprNetWithVectorOutput, denom: float = 255.0)

Bases: ActionReprNetWithVectorOutput

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(obs: Tensor | ndarray | BatchProtocol, state: T | None = None, info: dict[str, T] | None = None) -> tuple[Tensor | Sequence[Tensor], T | None]

The main method for tianshou to compute action representations (such as actions, inputs of distributions, Q-values, etc.) from env observations. Implementations will always make use of the preprocess_net as the first processing step.

Parameters:
  • obs – the observations from the environment as retrieved from ObsBatchProtocol.obs. If the environment is a dict env, this will be an instance of Batch, otherwise it will be an array (or tensor if your env returns tensors).

  • state – the hidden state of the RNN, if applicable

  • info – the info object from the environment step

Returns:

a tuple (action_repr, hidden_state), where action_repr is either an actual action for the environment or a representation from which it can be retrieved/sampled (e.g., mean and std for a Gaussian distribution), and hidden_state is the new hidden state of the RNN, if applicable.
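As a rough illustration of the role of denom, the sketch below assumes the wrapper simply divides the raw observations by denom before handing them to the wrapped module. The helper name scale_obs is hypothetical and not part of tianshou, and plain Python lists stand in for tensors.

```python
def scale_obs(pixels: list[int], denom: float = 255.0) -> list[float]:
    """Divide raw pixel intensities by denom (assumed behavior of the wrapper)."""
    return [p / denom for p in pixels]


# Raw uint8-style pixel values, as an Atari frame would provide them.
scaled = scale_obs([0, 51, 255])
print(scaled)
```

With the default denom of 255.0, 8-bit pixel values end up in the range [0, 1].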

class DQNet(c: int, h: int, w: int, action_shape: Sequence[int] | int, features_only: bool = False, output_dim_added_layer: int | None = None, layer_init: Callable[[Module], Module] = <function DQNet.<lambda>>)

Bases: ActionReprNetWithVectorOutput[Any]

Reference: Human-level control through deep reinforcement learning.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(obs: Tensor | ndarray | BatchProtocol, state: T | None = None, info: dict[str, T] | None = None) -> tuple[Tensor, T | None]

Mapping: s -> Q(s, *).

For more info, see docstring of parent.
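To see how c, h and w relate to the size of the feature vector, the following sketch computes the flattened output size of a convolutional stack. It assumes the layer configuration described in the referenced paper (8x8 kernels with stride 4, 4x4 with stride 2, 3x3 with stride 1, no padding, 64 output channels in the last layer); whether DQNet uses exactly these values is an assumption, and the function names are hypothetical.

```python
def conv_out(size: int, kernel: int, stride: int) -> int:
    """Output length of an unpadded convolution along one spatial dimension."""
    return (size - kernel) // stride + 1


def flat_feature_dim(h: int, w: int) -> int:
    """Flattened feature count after the three assumed conv layers."""
    # (kernel, stride) pairs as described in the referenced paper; assumed here.
    layers = [(8, 4), (4, 2), (3, 1)]
    out_channels = 64  # channels of the last assumed conv layer
    for kernel, stride in layers:
        h = conv_out(h, kernel, stride)
        w = conv_out(w, kernel, stride)
    return out_channels * h * w


# An 84x84 frame shrinks to 20x20, then 9x9, then 7x7.
print(flat_feature_dim(84, 84))  # 64 * 7 * 7 = 3136
```

Note that the input channel count c (typically the number of stacked frames) does not affect this number, since only the spatial dimensions and the last layer's channel count enter the calculation.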

class C51Net(*, c: int, h: int, w: int, action_shape: Sequence[int], num_atoms: int = 51)

Bases: DQNet

Reference: A distributional perspective on reinforcement learning.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(obs: Tensor | ndarray | BatchProtocol, state: T | None = None, info: dict[str, T] | None = None) -> tuple[Tensor, T | None]

Mapping: x -> Z(x, *).
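The distribution Z(x, a) is represented by num_atoms probabilities over a fixed support. The sketch below shows, for a single action, how a scalar value could be recovered from such a distribution as its expectation. The support bounds v_min and v_max are assumptions (they are not parameters of this class), the function names are hypothetical, and plain Python lists stand in for tensors.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax over one action's atom logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def expected_value(probs: list[float], v_min: float, v_max: float) -> float:
    """Expectation of a categorical distribution on an evenly spaced support."""
    n = len(probs)
    delta = (v_max - v_min) / (n - 1)
    support = [v_min + i * delta for i in range(n)]
    return sum(p * z for p, z in zip(probs, support))


num_atoms = 51
probs = softmax([0.0] * num_atoms)  # uniform distribution over the atoms
q = expected_value(probs, v_min=-10.0, v_max=10.0)  # assumed support bounds
print(q)
```

A uniform distribution over a support that is symmetric around zero has an expectation of approximately zero, up to floating-point error.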

class RainbowNet(*, c: int, h: int, w: int, action_shape: Sequence[int], num_atoms: int = 51, noisy_std: float = 0.5, is_dueling: bool = True, is_noisy: bool = True)

Bases: DQNet

Reference: Rainbow: Combining Improvements in Deep Reinforcement Learning.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(obs: Tensor | ndarray | BatchProtocol, state: T | None = None, info: dict[str, Any] | None = None) -> tuple[Tensor, T | None]

Mapping: s -> Q(s, *).

For more info, see docstring of parent.
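The is_dueling flag refers to a dueling architecture, in which a state-value stream and an advantage stream are combined. The sketch below shows the standard aggregation from the dueling-network literature for scalar values; whether RainbowNet combines its streams in exactly this way (for instance per atom) is an assumption, and the function name is hypothetical.

```python
def dueling_q(value: float, advantages: list[float]) -> list[float]:
    """Combine a state value and per-action advantages into Q-values.

    Subtracting the mean advantage keeps the decomposition identifiable.
    """
    mean_adv = sum(advantages) / len(advantages)
    return [value + a - mean_adv for a in advantages]


print(dueling_q(0.5, [1.0, 2.0, 3.0]))
```

By construction, the mean of the resulting Q-values equals the state value.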

class QRDQNet(*, c: int, h: int, w: int, action_shape: Sequence[int] | int, num_quantiles: int = 200)

Bases: DQNet

Reference: Distributional Reinforcement Learning with Quantile Regression.

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(obs: Tensor | ndarray | BatchProtocol, state: T | None = None, info: dict[str, Any] | None = None) -> tuple[Tensor, T | None]

Mapping: s -> Q(s, *).

For more info, see docstring of parent.
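In quantile regression DQN, each action is described by num_quantiles quantile estimates of its return, and a scalar Q-value is commonly obtained by averaging them. The sketch below illustrates that reduction under the assumption that the network output is laid out as one list of quantiles per action; the function names are hypothetical and plain Python lists stand in for tensors.

```python
def q_values_from_quantiles(quantiles: list[list[float]]) -> list[float]:
    """Average each action's quantile estimates into a scalar Q-value."""
    return [sum(qs) / len(qs) for qs in quantiles]


def greedy_action(q_values: list[float]) -> int:
    """Index of the action with the largest Q-value."""
    return max(range(len(q_values)), key=q_values.__getitem__)


# Two actions with three quantile estimates each (toy numbers).
quantiles = [[1.0, 2.0, 3.0], [0.0, 0.0, 9.0]]
q_values = q_values_from_quantiles(quantiles)
print(q_values, greedy_action(q_values))
```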

class ActorFactoryAtariDQN(scale_obs: bool = True, features_only: bool = False, output_dim_added_layer: int | None = None)

Bases: ActorFactory

USE_SOFTMAX_OUTPUT = False

create_module(envs: Environments, device: str | device) -> DiscreteActor

create_dist_fn(envs: Environments) -> Callable[[tuple[Tensor, Tensor]], Distribution] | Callable[[Tensor], Distribution] | None

Parameters:

envs – the environments

Returns:

the distribution function, which converts the actor’s output into a distribution, or None if the actor does not output distribution parameters
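To make the idea of a distribution function concrete, the sketch below turns an actor output into a categorical distribution and samples from it. It assumes that, with USE_SOFTMAX_OUTPUT = False, the actor emits unnormalized logits that still need to be normalized; this interpretation, as well as the function names, is an assumption and not taken from the tianshou implementation.

```python
import math
import random


def categorical_probs(logits: list[float]) -> list[float]:
    """Normalize unnormalized logits into action probabilities (stable softmax)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_action(logits: list[float], rng: random.Random) -> int:
    """Draw one action index from the categorical distribution over the logits."""
    probs = categorical_probs(logits)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


rng = random.Random(0)  # seeded for reproducibility
print(categorical_probs([0.0, 0.0]))
print(sample_action([2.0, 0.5, -1.0], rng))
```

If the actor already applied a softmax, the normalization step would be skipped and its outputs used as probabilities directly.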

class IntermediateModuleFactoryAtariDQN(features_only: bool = False, net_only: bool = False)

Bases: IntermediateModuleFactory

create_intermediate_module(envs: Environments, device: str | device) -> IntermediateModule

class IntermediateModuleFactoryAtariDQNFeatures

Bases: IntermediateModuleFactoryAtariDQN