import logging
import platform
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
from enum import Enum
from typing import Any, TypeAlias, cast
import gymnasium as gym
import gymnasium.spaces
import numpy as np
from gymnasium import Env
from sensai.util.pickle import setstate
from sensai.util.string import ToStringMixin
from tianshou.env import (
BaseVectorEnv,
DummyVectorEnv,
RayVectorEnv,
SubprocVectorEnv,
)
from tianshou.highlevel.persistence import Persistence
from tianshou.utils.net.common import TActionShape
TObservationShape: TypeAlias = int | Sequence[int]
log = logging.getLogger(__name__)
class EnvType(Enum):
    """Enumeration of environment types."""

    CONTINUOUS = "continuous"
    DISCRETE = "discrete"

    def is_discrete(self) -> bool:
        """:return: whether this is the discrete environment type"""
        return self == EnvType.DISCRETE

    def is_continuous(self) -> bool:
        """:return: whether this is the continuous environment type"""
        return self == EnvType.CONTINUOUS

    def assert_continuous(self, requiring_entity: Any) -> None:
        """Raises an error unless this is the continuous environment type.

        :param requiring_entity: the entity requiring continuous environments (used in the error message)
        :raises AssertionError: if the environment type is not continuous
        """
        # explicit raise rather than an `assert` statement so the check survives `python -O`
        if not self.is_continuous():
            raise AssertionError(f"{requiring_entity} requires continuous environments")

    def assert_discrete(self, requiring_entity: Any) -> None:
        """Raises an error unless this is the discrete environment type.

        :param requiring_entity: the entity requiring discrete environments (used in the error message)
        :raises AssertionError: if the environment type is not discrete
        """
        if not self.is_discrete():
            raise AssertionError(f"{requiring_entity} requires discrete environments")

    @staticmethod
    def from_env(env: Env) -> "EnvType":
        """Determines the environment type from the environment's action space.

        :param env: the environment
        :return: `DISCRETE` for a `Discrete` action space, `CONTINUOUS` for a `Box` action space
        :raises ValueError: for any other kind of action space
        """
        action_space = env.action_space
        if isinstance(action_space, gymnasium.spaces.Discrete):
            return EnvType.DISCRETE
        if isinstance(action_space, gymnasium.spaces.Box):
            return EnvType.CONTINUOUS
        # ValueError is a subclass of Exception, so existing `except Exception` handlers still apply
        raise ValueError(f"Unsupported environment type with action space {env.action_space}")
class EnvMode(Enum):
    """Indicates the purpose for which an environment is created."""

    # environments used for training
    TRAINING = "training"
    # environments used for testing/evaluation
    TEST = "test"
    # environment used for watching the agent
    WATCH = "watch"
class VectorEnvType(Enum):
    """Type of vectorized environment, determining how environments are parallelized."""

    DUMMY = "dummy"
    """Vectorized environment without parallelization; environments are processed sequentially"""
    SUBPROC = "subproc"
    """Parallelization based on `subprocess`"""
    SUBPROC_SHARED_MEM_DEFAULT_CONTEXT = "shmem"
    """Parallelization based on `subprocess` with shared memory"""
    SUBPROC_SHARED_MEM_FORK_CONTEXT = "shmem_fork"
    """Parallelization based on `subprocess` with shared memory and fork context (relevant for macOS, which uses `spawn`
    by default https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods)"""
    RAY = "ray"
    """Parallelization based on the `ray` library"""
    SUBPROC_SHARED_MEM_AUTO = "subproc_shared_mem_auto"
    """Parallelization based on `subprocess` with shared memory, using default context on windows and fork context otherwise"""

    def create_venv(
        self,
        factories: Sequence[Callable[[], gym.Env]],
    ) -> BaseVectorEnv:
        """Creates a vectorized environment of this type.

        :param factories: one factory function per environment to create
        :return: the vectorized environment
        :raises NotImplementedError: if the type is not handled
        """
        resolved = self
        # the AUTO type is first resolved to a concrete shared-memory type based on the platform
        if resolved == VectorEnvType.SUBPROC_SHARED_MEM_AUTO:
            on_windows = platform.system().lower() == "windows"
            resolved = (
                VectorEnvType.SUBPROC_SHARED_MEM_DEFAULT_CONTEXT
                if on_windows
                else VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT
            )
        if resolved == VectorEnvType.DUMMY:
            return DummyVectorEnv(factories)
        if resolved == VectorEnvType.SUBPROC:
            return SubprocVectorEnv(factories)
        if resolved == VectorEnvType.SUBPROC_SHARED_MEM_DEFAULT_CONTEXT:
            return SubprocVectorEnv(factories, share_memory=True)
        if resolved == VectorEnvType.SUBPROC_SHARED_MEM_FORK_CONTEXT:
            return SubprocVectorEnv(factories, share_memory=True, context="fork")
        if resolved == VectorEnvType.RAY:
            return RayVectorEnv(factories)
        raise NotImplementedError(self)
class Environments(ToStringMixin, ABC):
    """Represents (vectorized) environments for a learning process."""

    def __init__(
        self,
        env: gym.Env,
        training_envs: BaseVectorEnv,
        test_envs: BaseVectorEnv,
        watch_env: BaseVectorEnv | None = None,
    ):
        # a single (non-vectorized) instance, queried for the action/observation spaces
        self.env = env
        self.training_envs = training_envs
        self.test_envs = test_envs
        self.watch_env = watch_env
        # persistence handlers; replaced via `set_persistence`
        self.persistence: Sequence[Persistence] = []

    @staticmethod
    def from_factory_and_type(
        factory_fn: Callable[[EnvMode], gym.Env],
        env_type: EnvType,
        venv_type: VectorEnvType,
        num_training_envs: int,
        num_test_envs: int,
        create_watch_env: bool = False,
    ) -> "Environments":
        """Creates a suitable subtype instance from a factory function that creates a single instance and the type of environment (continuous/discrete).

        :param factory_fn: the factory for a single environment instance
        :param env_type: the type of environments created by `factory_fn`
        :param venv_type: the vector environment type to use for parallelization
        :param num_training_envs: the number of training environments to create
        :param num_test_envs: the number of test environments to create
        :param create_watch_env: whether to create an environment for watching the agent
        :return: the instance
        """

        def make_training_env() -> gym.Env:
            return factory_fn(EnvMode.TRAINING)

        def make_test_env() -> gym.Env:
            return factory_fn(EnvMode.TEST)

        def make_watch_env() -> gym.Env:
            return factory_fn(EnvMode.WATCH)

        training_envs = venv_type.create_venv([make_training_env] * num_training_envs)
        test_envs = venv_type.create_venv([make_test_env] * num_test_envs)
        # the watch environment is always a sequential single-env venv
        watch_env = (
            VectorEnvType.DUMMY.create_venv([make_watch_env]) if create_watch_env else None
        )
        env = factory_fn(EnvMode.TRAINING)
        if env_type == EnvType.CONTINUOUS:
            return ContinuousEnvironments(env, training_envs, test_envs, watch_env)
        if env_type == EnvType.DISCRETE:
            return DiscreteEnvironments(env, training_envs, test_envs, watch_env)
        raise ValueError(f"Environment type {env_type} not handled")

    def _tostring_includes(self) -> list[str]:
        # no attributes are included by default; the string representation is built from `info()`
        return []

    def _tostring_additional_entries(self) -> dict[str, Any]:
        return self.info()

    def info(self) -> dict[str, Any]:
        """:return: a dictionary describing the environments' action and observation (state) shapes"""
        summary: dict[str, Any] = {}
        summary["action_shape"] = self.get_action_shape()
        summary["state_shape"] = self.get_observation_shape()
        return summary

    def set_persistence(self, *p: Persistence) -> None:
        """Associates the given persistence handlers which may persist and restore environment-specific information.

        :param p: persistence handlers
        """
        self.persistence = p

    @abstractmethod
    def get_action_shape(self) -> TActionShape:
        """:return: the shape of the action space"""

    @abstractmethod
    def get_observation_shape(self) -> TObservationShape:
        """:return: the shape of the observation space"""

    def get_action_space(self) -> gym.Space:
        """:return: the action space of the single environment instance"""
        return self.env.action_space

    def get_observation_space(self) -> gym.Space:
        """:return: the observation space of the single environment instance"""
        return self.env.observation_space

    @abstractmethod
    def get_type(self) -> EnvType:
        """:return: the type of the environments (continuous/discrete)"""
class ContinuousEnvironments(Environments):
    """Represents (vectorized) continuous environments."""

    def __init__(
        self,
        env: gym.Env,
        training_envs: BaseVectorEnv,
        test_envs: BaseVectorEnv,
        watch_env: BaseVectorEnv | None = None,
    ):
        super().__init__(env, training_envs, test_envs, watch_env)
        self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env)

    @staticmethod
    def from_factory(
        factory_fn: Callable[[EnvMode], gym.Env],
        venv_type: VectorEnvType,
        num_training_envs: int,
        num_test_envs: int,
        create_watch_env: bool = False,
    ) -> "ContinuousEnvironments":
        """Creates an instance from a factory function that creates a single instance.

        :param factory_fn: the factory for a single environment instance
        :param venv_type: the vector environment type to use for parallelization
        :param num_training_envs: the number of training environments to create
        :param num_test_envs: the number of test environments to create
        :param create_watch_env: whether to create an environment for watching the agent
        :return: the instance
        """
        return cast(
            ContinuousEnvironments,
            Environments.from_factory_and_type(
                factory_fn,
                EnvType.CONTINUOUS,
                venv_type,
                num_training_envs,
                num_test_envs,
                create_watch_env,
            ),
        )

    def info(self) -> dict[str, Any]:
        """:return: the base information extended by the maximum action value"""
        d = super().info()
        d["max_action"] = self.max_action
        return d

    @staticmethod
    def _get_continuous_env_info(
        env: gym.Env,
    ) -> tuple[TObservationShape, tuple[int, ...], float]:
        """Extracts the observation shape, the action shape and the maximum action value.

        :param env: the environment; its action space must be a `gym.spaces.Box`
        :return: a tuple `(state_shape, action_shape, max_action)`, where `state_shape` is the
            observation space's shape or, if that is empty/undefined, its number of elements `n`,
            and `max_action` is the upper bound of the first action dimension
        :raises ValueError: if the action space is not a `Box` or the observation shape cannot be determined
        """
        if not isinstance(env.action_space, gym.spaces.Box):
            raise ValueError(
                "Only environments with continuous action space are supported here. "
                f"But got env with action space: {env.action_space.__class__}.",
            )
        observation_space = env.observation_space
        # Spaces with an empty shape (e.g. `Discrete`) are described by their number of elements `n`.
        # Spaces providing neither yield None here and thus the ValueError below (instead of an AttributeError).
        state_shape = observation_space.shape or getattr(observation_space, "n", None)
        if not state_shape:
            raise ValueError("Observation space shape is not defined")
        action_shape = env.action_space.shape
        # NOTE: only the first dimension's upper bound is used, i.e. the bounds are assumed to be
        # identical across all action dimensions (TODO confirm with callers)
        max_action = env.action_space.high[0]
        return state_shape, action_shape, max_action

    def get_action_shape(self) -> TActionShape:
        return self.action_shape

    def get_observation_shape(self) -> TObservationShape:
        return self.state_shape

    def get_type(self) -> EnvType:
        return EnvType.CONTINUOUS
class DiscreteEnvironments(Environments):
    """Represents (vectorized) discrete environments."""

    def __init__(
        self,
        env: gym.Env,
        training_envs: BaseVectorEnv,
        test_envs: BaseVectorEnv,
        watch_env: BaseVectorEnv | None = None,
    ):
        super().__init__(env, training_envs, test_envs, watch_env)
        # spaces with an empty shape (e.g. `Discrete`) are described by their number of elements `n`
        obs_space = env.observation_space
        self.observation_shape = obs_space.shape or obs_space.n  # type: ignore
        act_space = env.action_space
        self.action_shape = act_space.shape or act_space.n  # type: ignore

    @staticmethod
    def from_factory(
        factory_fn: Callable[[EnvMode], gym.Env],
        venv_type: VectorEnvType,
        num_training_envs: int,
        num_test_envs: int,
        create_watch_env: bool = False,
    ) -> "DiscreteEnvironments":
        """Creates an instance from a factory function that creates a single instance.

        :param factory_fn: the factory for a single environment instance
        :param venv_type: the vector environment type to use for parallelization
        :param num_training_envs: the number of training environments to create
        :param num_test_envs: the number of test environments to create
        :param create_watch_env: whether to create an environment for watching the agent
        :return: the instance
        """
        environments = Environments.from_factory_and_type(
            factory_fn=factory_fn,
            env_type=EnvType.DISCRETE,
            venv_type=venv_type,
            num_training_envs=num_training_envs,
            num_test_envs=num_test_envs,
            create_watch_env=create_watch_env,
        )
        return cast(DiscreteEnvironments, environments)

    def get_action_shape(self) -> TActionShape:
        return self.action_shape

    def get_observation_shape(self) -> TObservationShape:
        return self.observation_shape

    def get_type(self) -> EnvType:
        return EnvType.DISCRETE
class EnvPoolFactory:
    """A factory for the creation of envpool-based vectorized environments, which can be used in conjunction
    with :class:`EnvFactoryRegistered`.
    """

    def _transform_task(self, task: str) -> str:
        """Maps a gymnasium task identifier to the envpool task identifier (identity by default).

        :param task: the gymnasium task identifier
        :return: the envpool task identifier
        """
        return task

    def _transform_kwargs(self, kwargs: dict, mode: EnvMode) -> dict:
        """Transforms gymnasium keyword arguments to be envpool-compatible.

        :param kwargs: keyword arguments that would normally be passed to `gymnasium.make`.
        :param mode: the environment mode
        :return: the transformed keyword arguments
        """
        # work on a copy so the caller's dictionary is left untouched
        envpool_kwargs = dict(kwargs)
        envpool_kwargs.pop("render_mode", None)
        return envpool_kwargs

    def create_venv(
        self,
        task: str,
        num_envs: int,
        mode: EnvMode,
        seed: int,
        kwargs: dict,
    ) -> BaseVectorEnv:
        """Creates an envpool-based vectorized environment.

        :param task: the gymnasium task identifier
        :param num_envs: the number of environments
        :param mode: the environment mode
        :param seed: the seed passed on to `envpool.make_gymnasium`
        :param kwargs: the gymnasium keyword arguments (transformed via `_transform_kwargs`)
        :return: the vectorized environment
        """
        # imported here so that envpool is only required when this factory is actually used
        import envpool

        return envpool.make_gymnasium(
            self._transform_task(task),
            num_envs=num_envs,
            seed=seed,
            **self._transform_kwargs(kwargs, mode),
        )
[docs]
class EnvFactory(ToStringMixin, ABC):
def __init__(self, venv_type: VectorEnvType):
"""Main interface for the creation of environments (in various forms).
:param venv_type: the type of vectorized environment to use for train and test environments.
`WATCH` environments are always created as `DUMMY` vector environments.
"""
self.venv_type = venv_type
@staticmethod
def _create_rng(seed: int | None) -> np.random.Generator:
"""
Creates a random number generator with the given seed.
:param seed: the seed to use; if None, a random seed will be used
:return: the random number generator
"""
return np.random.default_rng(seed=seed)
@staticmethod
def _next_seed(rng: np.random.Generator) -> int:
"""
Samples a random seed from the given random number generator.
:param rng: the random number generator
:return: the sampled random seed
"""
# int32 is needed for envpool compatibility
return int(rng.integers(0, 2**31, dtype=np.int32))
@abstractmethod
def _create_env(self, mode: EnvMode) -> Env:
"""Creates a single environment for the given mode.
:param mode: the mode
:return: an environment
"""
[docs]
def create_env(self, mode: EnvMode, seed: int | None = None) -> Env:
"""
Creates a single environment for the given mode.
:param mode: the mode
:param seed: the random seed to use for the environment; if None, the seed will not be specified,
and gymnasium will use a random seed.
:return: the environment
"""
env = self._create_env(mode)
# initialize the environment with the given seed (if any)
if seed is not None:
rng = self._create_rng(seed)
env.np_random = rng
# also set the seed member within the environment such that it can be retrieved
# (gymnasium's random seed handling is, unfortunately, broken)
if hasattr(env, "_np_random_seed"):
env._np_random_seed = seed
return env
[docs]
def create_venv(self, num_envs: int, mode: EnvMode, seed: int | None = None) -> BaseVectorEnv:
"""Create vectorized environments.
:param num_envs: the number of environments
:param mode: the mode for which to create.
In `WATCH` mode the resulting venv will always be of type `DUMMY` with a single env.
:return: the vectorized environments
"""
rng = self._create_rng(seed)
def create_factory_fn() -> Callable[[], Env]:
# create a factory function that uses a sampled random seed
return lambda random_seed=self._next_seed(rng): self.create_env(mode, seed=random_seed) # type: ignore
# create the vectorized environment, seeded appropriately
if mode == EnvMode.WATCH:
venv = VectorEnvType.DUMMY.create_venv([create_factory_fn()])
else:
venv = self.venv_type.create_venv([create_factory_fn() for _ in range(num_envs)])
# seed the action samplers
venv.seed([self._next_seed(rng) for _ in range(num_envs)])
return venv
[docs]
def create_envs(
self,
num_training_envs: int,
num_test_envs: int,
create_watch_env: bool = False,
seed: int | None = None,
) -> Environments:
"""Create environments for learning.
:param num_training_envs: the number of training environments
:param num_test_envs: the number of test environments
:param create_watch_env: whether to create an environment for watching the agent
:param seed: the random seed to use for environment creation
:return: the environments
"""
rng = self._create_rng(seed)
env = self.create_env(EnvMode.TRAINING)
training_envs = self.create_venv(
num_training_envs, EnvMode.TRAINING, seed=self._next_seed(rng)
)
test_envs = self.create_venv(num_test_envs, EnvMode.TEST, seed=self._next_seed(rng))
watch_env = (
self.create_venv(1, EnvMode.WATCH, seed=self._next_seed(rng))
if create_watch_env
else None
)
match EnvType.from_env(env):
case EnvType.DISCRETE:
return DiscreteEnvironments(env, training_envs, test_envs, watch_env)
case EnvType.CONTINUOUS:
return ContinuousEnvironments(env, training_envs, test_envs, watch_env)
case _:
raise ValueError
class EnvFactoryRegistered(EnvFactory):
    """Factory for environments that are registered with gymnasium and thus can be created via `gymnasium.make`
    (or via `envpool.make_gymnasium`).
    """

    def __init__(
        self,
        *,
        task: str,
        venv_type: VectorEnvType,
        envpool_factory: EnvPoolFactory | None = None,
        render_mode_training: str | None = None,
        render_mode_test: str | None = None,
        render_mode_watch: str = "human",
        **make_kwargs: Any,
    ):
        """:param task: the gymnasium task/environment identifier
        :param venv_type: the type of vectorized environment to use (if `envpool_factory` is not specified)
        :param envpool_factory: the factory to use for vectorized environment creation based on envpool; envpool must be installed.
        :param render_mode_training: the render mode to use for training environments
        :param render_mode_test: the render mode to use for test environments
        :param render_mode_watch: the render mode to use for environments that are used to watch agent performance
        :param make_kwargs: additional keyword arguments to pass on to `gymnasium.make`. If envpool is used, the gymnasium parameters will be appropriately translated for use with `envpool.make_gymnasium`.
        """
        super().__init__(venv_type)
        self.task = task
        self.envpool_factory = envpool_factory
        # render mode per environment mode, passed on as `render_mode` (see `_create_kwargs`)
        self.render_modes = {
            EnvMode.TRAINING: render_mode_training,
            EnvMode.TEST: render_mode_test,
            EnvMode.WATCH: render_mode_watch,
        }
        self.make_kwargs = make_kwargs

    def __setstate__(self, state: dict) -> None:
        """Restores the object from a pickled state, converting seed-related entries.

        A single `seed` entry is converted into `test_seed` and `training_seed` entries, and a
        `train_seed` entry is renamed to `training_seed`.

        :param state: the pickled state
        :raises RuntimeError: if the state contains `seed` together with `test_seed`/`training_seed`
        """
        if "seed" in state:
            if "test_seed" in state or "training_seed" in state:
                raise RuntimeError(
                    f"Cannot have both 'seed' and 'test_seed'/'training_seed' in state. "
                    f"Something went wrong during serialization/deserialization: "
                    f"{state=}",
                )
            state["test_seed"] = state["seed"]
            state["training_seed"] = state["seed"]
            del state["seed"]
        if "train_seed" in state:
            state["training_seed"] = state["train_seed"]
            del state["train_seed"]
        setstate(EnvFactoryRegistered, self, state)

    def _create_kwargs(self, mode: EnvMode) -> dict:
        """Adapts the keyword arguments for the given mode.

        :param mode: the mode
        :return: a copy of the `make_kwargs` with `render_mode` set according to the mode
        """
        kwargs = dict(self.make_kwargs)
        kwargs["render_mode"] = self.render_modes.get(mode)
        return kwargs

    def _create_env(self, mode: EnvMode) -> Env:
        """Creates a single environment for the given mode.

        :param mode: the mode
        :return: an environment
        """
        kwargs = self._create_kwargs(mode)
        return gymnasium.make(self.task, **kwargs)

    def create_venv(self, num_envs: int, mode: EnvMode, seed: int | None = None) -> BaseVectorEnv:
        """Create vectorized environments, using envpool if an `envpool_factory` was specified.

        :param num_envs: the number of environments.
            NOTE: when envpool is used, `num_envs` is passed on as-is, also in `WATCH` mode.
        :param mode: the mode for which to create
        :param seed: the seed from which the environment seed(s) are derived; if None, a random seed is used
        :return: the vectorized environments
        """
        if self.envpool_factory is not None:
            rng = self._create_rng(seed)
            return self.envpool_factory.create_venv(
                self.task,
                num_envs,
                mode,
                self._next_seed(rng),
                self._create_kwargs(mode),
            )
        else:
            return super().create_venv(num_envs, mode, seed=seed)