psrl
Source code: tianshou/policy/modelbased/psrl.py
- class PSRLModel(trans_count_prior: ndarray, rew_mean_prior: ndarray, rew_std_prior: ndarray, discount_factor: float, epsilon: float)
Implementation of Posterior Sampling Reinforcement Learning Model.
- Parameters:
trans_count_prior – Dirichlet prior (alphas), with shape (n_state, n_action, n_state).
rew_mean_prior – means of the normal priors of rewards, with shape (n_state, n_action).
rew_std_prior – standard deviations of the normal priors of rewards, with shape (n_state, n_action).
discount_factor – in [0, 1].
epsilon – for precision control in value iteration.
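PSRL acts greedily with respect to a single MDP drawn from the posterior. The following is a minimal NumPy sketch of that draw, assuming the posterior is a Dirichlet over next states per (state, action) pair, parameterized like trans_count_prior, plus an independent normal reward per pair, parameterized like rew_mean_prior and rew_std_prior. The function name sample_mdp is hypothetical and not part of tianshou.

```python
import numpy as np


def sample_mdp(
    trans_count: np.ndarray,  # Dirichlet alphas, shape (n_state, n_action, n_state)
    rew_mean: np.ndarray,  # shape (n_state, n_action)
    rew_std: np.ndarray,  # shape (n_state, n_action)
    rng: np.random.Generator,
) -> tuple[np.ndarray, np.ndarray]:
    """Hypothetical helper: draw one MDP (transition probs, rewards) from the posterior."""
    n_state, n_action, _ = trans_count.shape
    # Flatten to (n_state * n_action, n_state) so each row parameterizes one Dirichlet.
    alphas = trans_count.reshape(-1, n_state)
    trans_prob = np.stack([rng.dirichlet(a) for a in alphas])
    trans_prob = trans_prob.reshape(n_state, n_action, n_state)
    # One Gaussian reward draw per (state, action) pair.
    rew = rng.normal(rew_mean, rew_std)
    return trans_prob, rew


# Usage: uniform Dirichlet prior, zero-mean unit-variance reward prior.
n_state, n_action = 3, 2
rng = np.random.default_rng(0)
trans_count_prior = np.ones((n_state, n_action, n_state))
rew_mean_prior = np.zeros((n_state, n_action))
rew_std_prior = np.ones((n_state, n_action))
P, R = sample_mdp(trans_count_prior, rew_mean_prior, rew_std_prior, rng)
```

Each row P[s, a, :] is a valid probability distribution over next states, so the sampled pair (P, R) can be handed directly to a value iteration solver.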
- observe(trans_count: ndarray, rew_sum: ndarray, rew_square_sum: ndarray, rew_count: ndarray) → None
Add newly observed data to the model's statistics.
For rewards, we start with a normal prior. Once rewards have been observed for a given state-action pair, the mean of those observations replaces the prior mean as the posterior mean, and the posterior standard deviation shrinks as the number of corresponding observations grows.
- Parameters:
trans_count – the number of observed transitions (s, a, s'), with shape (n_state, n_action, n_state).
rew_sum – total rewards, with shape (n_state, n_action).
rew_square_sum – sum of squared rewards, with shape (n_state, n_action).
rew_count – the number of rewards, with shape (n_state, n_action).
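The reward update described above can be illustrated from the four summary statistics. This is a hedged sketch, not tianshou's exact code: it keeps the prior mean for unseen pairs and uses the sample mean otherwise, as the text says, and it assumes the standard deviation comes from a normal-normal conjugate update (posterior precision = prior precision + count / sample variance). The helper name reward_posterior and the min_var floor are hypothetical.

```python
import numpy as np


def reward_posterior(
    rew_mean_prior: np.ndarray,
    rew_std_prior: np.ndarray,
    rew_sum: np.ndarray,
    rew_square_sum: np.ndarray,
    rew_count: np.ndarray,
    min_var: float = 1e-6,  # hypothetical floor so the sample variance stays positive
) -> tuple[np.ndarray, np.ndarray]:
    """Hypothetical helper: posterior mean and std of rewards per (state, action)."""
    seen = rew_count > 0
    safe_count = np.where(seen, rew_count, 1)  # avoid dividing by zero for unseen pairs
    sample_mean = rew_sum / safe_count
    # Sample variance E[r^2] - E[r]^2, floored at min_var.
    sample_var = np.maximum(rew_square_sum / safe_count - sample_mean**2, min_var)
    # Assumed conjugate update: precision grows linearly with the observation count.
    precision = 1.0 / rew_std_prior**2 + rew_count / sample_var
    post_std = np.sqrt(1.0 / precision)
    # Unseen pairs keep the prior mean; seen pairs use the sample mean.
    post_mean = np.where(seen, sample_mean, rew_mean_prior)
    return post_mean, post_std


# Usage: state 0 / action 0 observed rewards 1.0 and 3.0; state 1 never visited.
prior_mean = np.zeros((2, 1))
prior_std = np.ones((2, 1))
rew_sum = np.array([[4.0], [0.0]])
rew_square_sum = np.array([[10.0], [0.0]])
rew_count = np.array([[2.0], [0.0]])
mean, std = reward_posterior(prior_mean, prior_std, rew_sum, rew_square_sum, rew_count)
```

With no observations the sketch returns the prior unchanged; each extra observation adds precision, so the standard deviation decreases roughly like one over the square root of the count. For transitions, Dirichlet-categorical conjugacy means the posterior alphas are simply trans_count_prior plus the accumulated trans_count.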
- static value_iteration(trans_prob: ndarray, rew: ndarray, discount_factor: float, eps: float, value: ndarray) → tuple[ndarray, ndarray]
Value iteration solver for MDPs.
- Parameters:
trans_prob – transition probabilities, with shape (n_state, n_action, n_state).
rew – rewards, with shape (n_state, n_action).
eps – for precision control.
discount_factor – in [0, 1].
value – the initial value array, with shape (n_state, ).
- Returns:
a tuple (policy, value): the optimal (greedy) policy with shape (n_state, ) and the value array with shape (n_state, ).
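The Bellman optimality backup that such a solver iterates can be sketched in NumPy as follows. This is a hedged illustration, not tianshou's source: it stops when the largest change in the value array drops below eps (the library may use a different convergence test), adds a hypothetical max_iter safety cap, and returns the greedy policy together with the values, matching the tuple[ndarray, ndarray] return annotation.

```python
import numpy as np


def value_iteration(
    trans_prob: np.ndarray,  # (n_state, n_action, n_state)
    rew: np.ndarray,  # (n_state, n_action)
    discount_factor: float,
    eps: float,
    value: np.ndarray,  # (n_state,) initial guess
    max_iter: int = 10_000,  # hypothetical safety cap, e.g. for discount_factor == 1
) -> tuple[np.ndarray, np.ndarray]:
    """Sketch: return (greedy policy, converged state values)."""
    for _ in range(max_iter):
        # Q(s, a) = r(s, a) + gamma * sum_{s'} P(s' | s, a) * V(s')
        q = rew + discount_factor * trans_prob.dot(value)
        new_value = q.max(axis=1)
        converged = np.max(np.abs(new_value - value)) < eps
        value = new_value
        if converged:
            break
    return q.argmax(axis=1), value


# Usage: a deterministic two-state MDP.
P = np.zeros((2, 2, 2))
P[0, 0, 0] = 1.0  # state 0, action 0: stay in 0 (reward 0)
P[0, 1, 1] = 1.0  # state 0, action 1: move to 1 (reward 1)
P[1, 0, 1] = 1.0  # state 1, action 0: stay in 1 (reward 2)
P[1, 1, 0] = 1.0  # state 1, action 1: move to 0 (reward 0)
R = np.array([[0.0, 1.0], [2.0, 0.0]])
policy, values = value_iteration(P, R, discount_factor=0.9, eps=1e-8, value=np.zeros(2))
```

Staying in state 1 forever is worth 2 / (1 - 0.9) = 20, and moving there from state 0 is worth 1 + 0.9 * 20 = 19, so the greedy policy is [1, 0].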
- class PSRLPolicy(*, trans_count_prior: ndarray, rew_mean_prior: ndarray, rew_std_prior: ndarray, action_space: Discrete, discount_factor: float = 0.99, epsilon: float = 0.01, add_done_loop: bool = False, observation_space: Space | None = None, lr_scheduler: LRScheduler | MultipleLRSchedulers | None = None)
Implementation of Posterior Sampling Reinforcement Learning.
Reference: Strens, M. A Bayesian framework for reinforcement learning. In Proceedings of ICML 2000, pp. 943-950.
- Parameters:
trans_count_prior – Dirichlet prior (alphas), with shape (n_state, n_action, n_state).
rew_mean_prior – means of the normal priors of rewards, with shape (n_state, n_action).
rew_std_prior – standard deviations of the normal priors of rewards, with shape (n_state, n_action).
action_space – Env’s action_space.
discount_factor – in [0, 1].
epsilon – for precision control in value iteration.
add_done_loop – whether to add an extra self-loop for the terminal state in the MDP. Defaults to False.
observation_space – Env’s observation space.
lr_scheduler – if not None, will be called in policy.update().
See also
Please refer to BasePolicy for a more detailed explanation.
- forward(batch: ObsBatchProtocol, state: dict | BatchProtocol | ndarray | None = None, **kwargs: Any) → ActBatchProtocol
Compute actions for the given batch of data using the PSRL model.
- Returns:
A Batch with “act” key containing the action.
See also
Please refer to forward() for a more detailed explanation.
- learn(batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) → TPSRLTrainingStats
Update policy with a given batch of data.
- Returns:
A dataclass object, including the data needed to be logged (e.g., loss).
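One way to picture learn for PSRL is as turning a batch of transitions into the summary statistics accepted by PSRLModel.observe. The sketch below assumes tabular integer observations and actions stored as flat NumPy arrays. count_transitions is a hypothetical helper, and its handling of add_done_loop (a self-loop from each terminal next state under every action, recorded as an extra zero-reward observation) is an assumption about the intent of that flag, not a copy of tianshou's code.

```python
import numpy as np


def count_transitions(
    obs: np.ndarray,  # (batch,) int states
    act: np.ndarray,  # (batch,) int actions
    rew: np.ndarray,  # (batch,) float rewards
    obs_next: np.ndarray,  # (batch,) int next states
    done: np.ndarray,  # (batch,) bool terminal flags
    n_state: int,
    n_action: int,
    add_done_loop: bool = False,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Hypothetical helper: build (trans_count, rew_sum, rew_square_sum, rew_count)."""
    trans_count = np.zeros((n_state, n_action, n_state))
    rew_sum = np.zeros((n_state, n_action))
    rew_square_sum = np.zeros((n_state, n_action))
    rew_count = np.zeros((n_state, n_action))
    # np.add.at accumulates correctly even when an index repeats within the batch.
    np.add.at(trans_count, (obs, act, obs_next), 1)
    np.add.at(rew_sum, (obs, act), rew)
    np.add.at(rew_square_sum, (obs, act), rew**2)
    np.add.at(rew_count, (obs, act), 1)
    if add_done_loop:
        # Assumption: make each terminal next state absorbing. Every action loops
        # back to it, counted as an extra observation with zero reward.
        for s in obs_next[done]:
            trans_count[s, :, s] += 1
            rew_count[s, :] += 1
    return trans_count, rew_sum, rew_square_sum, rew_count


# Usage: three transitions, the last of which ends the episode in state 2.
obs = np.array([0, 0, 1])
act = np.array([1, 1, 0])
rew = np.array([1.0, 3.0, 0.5])
obs_next = np.array([1, 1, 2])
done = np.array([False, False, True])
tc, rs, rss, rc = count_transitions(
    obs, act, rew, obs_next, done, n_state=3, n_action=2, add_done_loop=True
)
```

np.add.at is used instead of fancy-index assignment such as trans_count[obs, act, obs_next] += 1, because the latter applies only one increment when the same index appears several times in a batch.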
Note
In order to distinguish the collecting state, updating state and testing state, you can check the policy state by self.training and self.updating. Please refer to States for policy for a more detailed explanation.
Warning
If you use torch.distributions.Normal and torch.distributions.Categorical to calculate the log_prob, please be careful about the shape: a Categorical distribution gives a “[batch_size]” shape while a Normal distribution gives a “[batch_size, 1]” shape. Automatic broadcasting in numerical operations on torch tensors will silently amplify this error.
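The shape pitfall can be reproduced with plain NumPy, whose broadcasting rules torch tensors follow. The arrays below merely stand in for log-probabilities; no torch.distributions objects are created, so the values are illustrative assumptions. Adding a [batch_size] array to a [batch_size, 1] array silently produces a [batch_size, batch_size] matrix instead of raising an error.

```python
import numpy as np

batch_size = 4
# Stand-in for Categorical(...).log_prob(act): shape [batch_size].
cat_log_prob = np.log(np.full(batch_size, 0.25))
# Stand-in for Normal(...).log_prob(act): shape [batch_size, 1].
normal_log_prob = np.zeros((batch_size, 1))

# Broadcasting turns (4,) + (4, 1) into a (4, 4) matrix without any error.
wrong = cat_log_prob + normal_log_prob
# Dropping the trailing singleton dimension keeps the intended (4,) shape.
right = cat_log_prob + normal_log_prob.squeeze(-1)

# Any reduction over `wrong` counts every term batch_size times.
amplification = wrong.sum() / right.sum()
```

This is why the error is amplified rather than merely shifted: a loss summed or averaged over the broadcast matrix scales with batch_size.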