Source code for tianshou.highlevel.params.collector

from abc import ABC, abstractmethod

from tianshou.algorithm import Algorithm
from tianshou.data import BaseCollector, Collector, ReplayBuffer
from tianshou.env import BaseVectorEnv


class CollectorFactory(ABC):
    """Abstract factory responsible for building the collector used to gather experience."""

    @abstractmethod
    def create_collector(
        self,
        algorithm: Algorithm,
        vector_env: BaseVectorEnv,
        buffer: ReplayBuffer | None = None,
        exploration_noise: bool = False,
    ) -> BaseCollector:
        """
        Build a collector that pairs the given algorithm with a vectorized environment.

        :param algorithm: the algorithm for which data shall be collected
        :param vector_env: the vectorized environment to step through
        :param buffer: the replay buffer the collector shall store data in; if None,
            a new buffer will be created with default parameters
        :param exploration_noise: whether the policy's exploration noise shall be
            applied to the selected actions
        :return: the newly created collector
        """
class CollectorFactoryDefault(CollectorFactory):
    """Collector factory that produces a plain :class:`Collector` around the algorithm's policy."""

    def create_collector(
        self,
        algorithm: Algorithm,
        vector_env: BaseVectorEnv,
        buffer: ReplayBuffer | None = None,
        exploration_noise: bool = False,
    ) -> BaseCollector:
        """
        Wrap the algorithm's policy and the vectorized environment in a :class:`Collector`.

        :param algorithm: the algorithm whose ``policy`` attribute is handed to the collector
        :param vector_env: the vectorized environment passed to the collector
        :param buffer: forwarded unchanged to the collector's ``buffer`` argument
        :param exploration_noise: forwarded unchanged to the collector's
            ``exploration_noise`` argument
        :return: the constructed :class:`Collector`
        """
        policy = algorithm.policy
        collector = Collector(
            policy,
            vector_env,
            buffer=buffer,
            exploration_noise=exploration_noise,
        )
        return collector