Planning
Planning is a method of simulating a sequence of actions in an environment model before actually taking an action in the real environment.
Concepts covered:
Cross entropy method (CEM)
Monte Carlo tree search (MCTS)
Probabilistic ensembles with trajectory sampling (PETS)
References:
Deep Reinforcement Learning in a Handful of Trials using Probabilistic Dynamics Models
Exploring Model-based Planning with Policy Networks
Learning Latent Dynamics for Planning from Pixels
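To make the definition above concrete, here is a minimal sketch of planning by random shooting. It assumes a hypothetical one-step model, model(state, action) -> (next_state, reward), which is not defined on this page; the planner simulates random action sequences in that model and returns the first action of the best sequence. The later sections replace this naive search with more sample-efficient planners such as CEM and PETS.

import numpy as np

def plan_random_shooting(model, state, n_candidates=100, horizon=10, n_actions=2):
    # Simulate random action sequences in the (assumed) model and keep
    # the first action of the sequence with the highest simulated return.
    best_return, best_first_action = -np.inf, None
    for _ in range(n_candidates):
        action_seq = np.random.randint(n_actions, size=horizon)
        sim_state, total_reward = state, 0.0
        for a in action_seq:
            sim_state, r = model(sim_state, a)  # imagined step, not the real env
            total_reward += r
        if total_reward > best_return:
            best_return, best_first_action = total_reward, action_seq[0]
    return best_first_action  # execute this in the real env, then replan

At every real time step the planner is called again with the new state, so only the first action of each plan is ever executed.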
Cross Entropy Method
The Cross Entropy Method (CEM) is a gradient-free optimization method commonly used for planning in model-based reinforcement learning.
CEM Algorithm
1. Create a Gaussian distribution \(N(\mu,\sigma)\) that describes the parameters \(\theta\) of the policy (e.g. the weights of a neural network).
2. Sample a batch of \(N\) candidate parameter vectors \(\theta\) from the Gaussian.
3. Evaluate all \(N\) samples of \(\theta\) by estimating their return, e.g. by running trials in the environment.
4. Select the top k% of the samples of \(\theta\) and compute the new \(\mu\) and \(\sigma\) to parameterise the new Gaussian distribution.
5. Repeat steps 1-4 until convergence.
import numpy as np
import tensorflow_probability as tfp
tfd = tfp.distributions
import gym
import warnings
warnings.filterwarnings("ignore")

# RL Gym
env = gym.make('CartPole-v1')

# Initialisation
n = 10        # number of candidate policies per iteration
top_k = 0.40  # top fraction of candidates kept for the next iteration
mean = np.zeros((5, 2))   # shape = (n_parameters, n_actions)
stddev = np.ones((5, 2))  # shape = (n_parameters, n_actions)

def get_batch_weights(mean, stddev, n):
    # Sample n candidate weight matrices from the diagonal Gaussian.
    mvn = tfd.MultivariateNormalDiag(
        loc=mean,
        scale_diag=stddev)
    return mvn.sample(n).numpy()

def policy(obs, weights):
    # Linear policy: rows 0-3 are the weights, row 4 is the bias.
    return np.argmax(obs @ weights[:4, :] + weights[4])

def run_trial(weights, render=False):
    # Run one episode and return the total reward (old gym API).
    obs = env.reset()
    done = False
    reward = 0
    while not done:
        a = policy(obs, weights)
        obs, r, done, _ = env.step(a)
        reward += r
        if render:
            env.render()
    return reward

def get_new_mean_stddev(rewards, batch_weights):
    # Refit the Gaussian to the top k% of candidates by reward.
    idx = np.argsort(rewards)[::-1][:int(n * top_k)]
    mean = np.mean(batch_weights[idx], axis=0)
    stddev = np.sqrt(np.var(batch_weights[idx], axis=0))
    return mean, stddev

for i in range(20):
    batch_weights = get_batch_weights(mean, stddev, n)
    rewards = [run_trial(weights) for weights in batch_weights]
    mean, stddev = get_new_mean_stddev(rewards, batch_weights)
    print(rewards)
[18.0, 9.0, 9.0, 10.0, 8.0, 9.0, 9.0, 8.0, 9.0, 12.0]
[11.0, 33.0, 10.0, 10.0, 19.0, 11.0, 19.0, 8.0, 13.0, 15.0]
[10.0, 123.0, 9.0, 26.0, 16.0, 27.0, 48.0, 24.0, 19.0, 35.0]
[19.0, 68.0, 35.0, 72.0, 28.0, 73.0, 47.0, 42.0, 41.0, 63.0]
[43.0, 52.0, 163.0, 100.0, 33.0, 67.0, 88.0, 56.0, 40.0, 30.0]
[93.0, 112.0, 96.0, 121.0, 71.0, 59.0, 104.0, 28.0, 47.0, 74.0]
[133.0, 49.0, 67.0, 94.0, 70.0, 80.0, 69.0, 73.0, 54.0, 45.0]
[54.0, 63.0, 113.0, 113.0, 81.0, 125.0, 107.0, 84.0, 82.0, 135.0]
[54.0, 120.0, 53.0, 431.0, 93.0, 97.0, 54.0, 133.0, 140.0, 142.0]
[59.0, 63.0, 73.0, 128.0, 64.0, 82.0, 52.0, 93.0, 63.0, 87.0]
[82.0, 65.0, 66.0, 68.0, 81.0, 139.0, 73.0, 80.0, 92.0, 131.0]
[123.0, 189.0, 115.0, 81.0, 77.0, 120.0, 130.0, 62.0, 189.0, 174.0]
[81.0, 99.0, 162.0, 65.0, 167.0, 76.0, 84.0, 176.0, 85.0, 146.0]
[232.0, 120.0, 156.0, 118.0, 105.0, 108.0, 72.0, 73.0, 126.0, 114.0]
[102.0, 190.0, 89.0, 190.0, 118.0, 81.0, 139.0, 79.0, 76.0, 100.0]
[74.0, 98.0, 169.0, 82.0, 78.0, 99.0, 112.0, 160.0, 391.0, 263.0]
[182.0, 98.0, 109.0, 252.0, 124.0, 122.0, 186.0, 106.0, 241.0, 162.0]
[176.0, 76.0, 74.0, 110.0, 182.0, 85.0, 111.0, 84.0, 75.0, 75.0]
[100.0, 90.0, 201.0, 68.0, 164.0, 88.0, 119.0, 159.0, 73.0, 127.0]
[210.0, 83.0, 86.0, 104.0, 112.0, 78.0, 102.0, 122.0, 74.0, 76.0]
mean, stddev
(array([[-0.08092788, -3.90679289],
        [ 0.17187715,  1.6467033 ],
        [ 0.95738557, -0.27835679],
        [-1.69261409,  1.18980922],
        [-0.11475664, -0.16217023]]),
 array([[2.71812173e-04, 2.84475699e-03],
        [5.27674917e-04, 1.34500278e-03],
        [1.18191726e-04, 2.32261000e-03],
        [4.25802605e-03, 8.76518625e-03],
        [5.10162382e-05, 1.15022676e-03]]))
best_weights = get_batch_weights(mean, stddev, 1)[0]
run_trial(best_weights, render=False)
72.0
Monte Carlo Tree Search
Upcoming
Probabilistic Ensembles with Trajectory Sampling
Probabilistic ensembles with trajectory sampling (PETS) is a model-based reinforcement learning algorithm that combines probabilistic networks, which capture aleatoric uncertainty (noise inherent in the environment), with ensembles, which capture epistemic uncertainty (uncertainty due to limited data).
PETS is used for model predictive control (MPC), which plans and optimizes a sequence of actions at every step, executes only the first action, and then replans.
Instead of random shooting, PETS uses CEM to sample candidate action sequences from a distribution that is iteratively shifted towards previous samples that yielded high reward.
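The sketch below illustrates the shape of the PETS planning loop rather than the full algorithm from the paper: the ensemble of probabilistic dynamics models and the reward function are assumptions passed in as arguments, each ensemble member is assumed to return a Gaussian mean and standard deviation over the next state, and trajectory sampling is approximated by stepping each particle through a randomly chosen ensemble member at every step.

import numpy as np

def pets_cem_plan(ensemble, reward_fn, state, action_dim,
                  horizon=15, pop_size=200, n_elites=20,
                  n_particles=10, n_iters=5):
    # CEM over action sequences, evaluated by trajectory sampling
    # through an (assumed) ensemble of probabilistic dynamics models:
    #   ensemble[k](state, action) -> (next_state_mean, next_state_std)
    #   reward_fn(state, action)   -> scalar reward
    mu = np.zeros((horizon, action_dim))
    sigma = np.ones((horizon, action_dim))

    for _ in range(n_iters):
        # Sample candidate action sequences from the current Gaussian.
        actions = np.random.normal(mu, sigma,
                                   size=(pop_size, horizon, action_dim))
        returns = np.zeros(pop_size)

        for i, action_seq in enumerate(actions):
            # Trajectory sampling: several particles per candidate, each
            # stepped through a randomly chosen ensemble member so that
            # both aleatoric and epistemic uncertainty are reflected.
            particle_returns = []
            for _ in range(n_particles):
                s = np.array(state, dtype=float)
                total = 0.0
                for a in action_seq:
                    member = ensemble[np.random.randint(len(ensemble))]
                    next_mean, next_std = member(s, a)  # Gaussian over next state
                    s = np.random.normal(next_mean, next_std)
                    total += reward_fn(s, a)
                particle_returns.append(total)
            returns[i] = np.mean(particle_returns)

        # CEM update: refit the Gaussian to the elite action sequences.
        elite_idx = np.argsort(returns)[::-1][:n_elites]
        mu = actions[elite_idx].mean(axis=0)
        sigma = actions[elite_idx].std(axis=0)

    # MPC: return only the first action; replan at the next time step.
    return mu[0]

In the full algorithm the ensemble is retrained on all data collected so far after every trial, and rollouts are propagated in batch rather than with Python loops; the CEM update itself is the same refitting step used in the CartPole example above.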