Source code for tianshou.highlevel.module.intermediate

from abc import ABC, abstractmethod
from dataclasses import dataclass

import torch
from sensai.util.string import ToStringMixin

from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.core import ModuleFactory, TDevice
from tianshou.utils.net.common import ModuleWithVectorOutput


@dataclass
class IntermediateModule:
    """Pairs a torch module producing an intermediate representation with the size of that output."""

    # The wrapped module that computes the intermediate representation.
    module: torch.nn.Module
    # The known dimension of the representation computed by `module`.
    output_dim: int

    def get_module_with_vector_output(self) -> ModuleWithVectorOutput:
        """Return the wrapped module as a :class:`ModuleWithVectorOutput`.

        :return: `module` itself if it already is a `ModuleWithVectorOutput`;
            otherwise the result of `ModuleWithVectorOutput.from_module`
            called with `module` and `output_dim`.
        """
        wrapped = self.module
        if not isinstance(wrapped, ModuleWithVectorOutput):
            # Adapt the plain module, using the dimension stored alongside it.
            return ModuleWithVectorOutput.from_module(wrapped, self.output_dim)
        return wrapped
class IntermediateModuleFactory(ToStringMixin, ModuleFactory, ABC):
    """Abstract factory producing modules that compute an intermediate representation."""

    @abstractmethod
    def create_intermediate_module(self, envs: Environments, device: TDevice) -> IntermediateModule:
        """Create the intermediate module together with its output dimension.

        Must be implemented by subclasses.

        :param envs: the environments, as passed on by :meth:`create_module`
        :param device: the torch device specification, as passed on by :meth:`create_module`
        :return: the container holding the created module and its output dimension
        """

    def create_module(self, envs: Environments, device: TDevice) -> torch.nn.Module:
        """Create the intermediate module and return only the underlying torch module.

        :param envs: forwarded unchanged to :meth:`create_intermediate_module`
        :param device: forwarded unchanged to :meth:`create_intermediate_module`
        :return: the `module` attribute of the created :class:`IntermediateModule`
        """
        intermediate = self.create_intermediate_module(envs, device)
        return intermediate.module