Planning#

Planning is a method of simulating a sequence of actions in an environment model before actually taking an action in the real environment.
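As a rough sketch of the idea, the toy example below plans by random shooting in a hand-made model (`model_step`, the goal at x = 5, and all numbers are made up for illustration): candidate action sequences are simulated in the model, and only the first action of the best imagined sequence is executed.

import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy model of the environment: the state is a 1-D position,
# actions are steps in {-1, 0, +1}, and reward increases as we approach x = 5.
def model_step(x, a):
    next_x = x + a
    return next_x, -abs(5 - next_x)

def plan(x0, horizon=5, n_candidates=100):
    # Random-shooting planner: simulate candidate action sequences in the model
    # and return the first action of the sequence with the best imagined return.
    candidates = rng.integers(-1, 2, size=(n_candidates, horizon))
    returns = []
    for seq in candidates:
        x, total = x0, 0.0
        for a in seq:
            x, r = model_step(x, a)
            total += r
        returns.append(total)
    return candidates[int(np.argmax(returns))][0]

x = 0
for t in range(8):
    a = plan(x)              # plan in the model...
    x, _ = model_step(x, a)  # ...then take only the first action (here the model doubles as the real environment)
print(x)  # should end up near the goal x = 5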

Concepts covered:

  1. Cross entropy method (CEM)

  2. Monte Carlo tree search (MCTS)

  3. Probabilistic ensembles with trajectory sampling (PETS)

References:

  • Deep Reinforcement Learning in a Handful of Trials using Probabilistic Dynamics Models

  • Exploring Model-based Planning with Policy Networks

  • Learning Latent Dynamics for Planning from Pixels

Cross Entropy Method#

The Cross Entropy Method (CEM) is a gradient-free optimization method commonly used for planning in model-based reinforcement learning.

CEM Algorithm

  1. Create a Gaussian distribution \(N(\mu,\sigma)\) over the parameters \(\theta\) of the policy (e.g. the weights of a neural network).

  2. Sample a batch of \(N\) candidate parameters \(\theta\) from the Gaussian.

  3. Evaluate all \(N\) candidates with an objective function, e.g. the total reward from running a trial in the environment.

  4. Select the top-performing fraction of the candidates (the elites) and compute the new \(\mu\) and \(\sigma\) that parameterise the next Gaussian distribution.

  5. Repeat steps 2-4 until convergence.

import numpy as np
import tensorflow_probability as tfp
tfd = tfp.distributions
import gym
import warnings
warnings.filterwarnings("ignore")
# RL Gym environment (assumes the classic gym API:
# reset() returns obs, step() returns (obs, reward, done, info))
env = gym.make('CartPole-v1')

# Initialisation
n = 10  # number of candidate policies sampled per iteration
top_k = 0.40  # fraction of top candidates kept to refit the Gaussian
mean = np.zeros((5,2))  # shape = (n_parameters, n_actions): 4 obs dims + 1 bias, 2 actions
stddev = np.ones((5,2))  # shape = (n_parameters, n_actions)
def get_batch_weights(mean, stddev, n):
    # Sample n candidate weight matrices from the diagonal Gaussian
    mvn = tfd.MultivariateNormalDiag(
        loc=mean,
        scale_diag=stddev)
    return mvn.sample(n).numpy()

def policy(obs, weights):
    # Linear policy: score each action as obs @ W + b and act greedily
    return np.argmax(obs @ weights[:4,:] + weights[4])

def run_trial(weights, render=False):
    # Roll out one episode with the given weights and return the total reward
    obs = env.reset()
    done = False
    reward = 0
    while not done:
        a = policy(obs, weights)
        obs, r, done, _ = env.step(a)
        reward += r
        if render:
            env.render()
    return reward

def get_new_mean_stddev(rewards, batch_weights):
    # Keep the top-performing candidates (the elites) and refit the Gaussian to them
    idx = np.argsort(rewards)[::-1][:int(n*top_k)]
    mean = np.mean(batch_weights[idx], axis=0)
    stddev = np.std(batch_weights[idx], axis=0)
    return mean, stddev
# CEM loop: sample candidates, evaluate them in the environment, refit the Gaussian
for i in range(20):
    batch_weights = get_batch_weights(mean, stddev, n)
    rewards = [run_trial(weights) for weights in batch_weights]
    mean, stddev = get_new_mean_stddev(rewards, batch_weights)
    print(rewards)
[18.0, 9.0, 9.0, 10.0, 8.0, 9.0, 9.0, 8.0, 9.0, 12.0]
[11.0, 33.0, 10.0, 10.0, 19.0, 11.0, 19.0, 8.0, 13.0, 15.0]
[10.0, 123.0, 9.0, 26.0, 16.0, 27.0, 48.0, 24.0, 19.0, 35.0]
[19.0, 68.0, 35.0, 72.0, 28.0, 73.0, 47.0, 42.0, 41.0, 63.0]
[43.0, 52.0, 163.0, 100.0, 33.0, 67.0, 88.0, 56.0, 40.0, 30.0]
[93.0, 112.0, 96.0, 121.0, 71.0, 59.0, 104.0, 28.0, 47.0, 74.0]
[133.0, 49.0, 67.0, 94.0, 70.0, 80.0, 69.0, 73.0, 54.0, 45.0]
[54.0, 63.0, 113.0, 113.0, 81.0, 125.0, 107.0, 84.0, 82.0, 135.0]
[54.0, 120.0, 53.0, 431.0, 93.0, 97.0, 54.0, 133.0, 140.0, 142.0]
[59.0, 63.0, 73.0, 128.0, 64.0, 82.0, 52.0, 93.0, 63.0, 87.0]
[82.0, 65.0, 66.0, 68.0, 81.0, 139.0, 73.0, 80.0, 92.0, 131.0]
[123.0, 189.0, 115.0, 81.0, 77.0, 120.0, 130.0, 62.0, 189.0, 174.0]
[81.0, 99.0, 162.0, 65.0, 167.0, 76.0, 84.0, 176.0, 85.0, 146.0]
[232.0, 120.0, 156.0, 118.0, 105.0, 108.0, 72.0, 73.0, 126.0, 114.0]
[102.0, 190.0, 89.0, 190.0, 118.0, 81.0, 139.0, 79.0, 76.0, 100.0]
[74.0, 98.0, 169.0, 82.0, 78.0, 99.0, 112.0, 160.0, 391.0, 263.0]
[182.0, 98.0, 109.0, 252.0, 124.0, 122.0, 186.0, 106.0, 241.0, 162.0]
[176.0, 76.0, 74.0, 110.0, 182.0, 85.0, 111.0, 84.0, 75.0, 75.0]
[100.0, 90.0, 201.0, 68.0, 164.0, 88.0, 119.0, 159.0, 73.0, 127.0]
[210.0, 83.0, 86.0, 104.0, 112.0, 78.0, 102.0, 122.0, 74.0, 76.0]
mean, stddev
(array([[-0.08092788, -3.90679289],
        [ 0.17187715,  1.6467033 ],
        [ 0.95738557, -0.27835679],
        [-1.69261409,  1.18980922],
        [-0.11475664, -0.16217023]]),
 array([[2.71812173e-04, 2.84475699e-03],
        [5.27674917e-04, 1.34500278e-03],
        [1.18191726e-04, 2.32261000e-03],
        [4.25802605e-03, 8.76518625e-03],
        [5.10162382e-05, 1.15022676e-03]]))
# Sample a single policy from the converged distribution and evaluate it
best_weights = get_batch_weights(mean, stddev, 1)[0]
run_trial(best_weights, render=False)
72.0

Probabilistic Ensembles with Trajectory Sampling#

Probabilistic ensembles with trajectory sampling (PETS) is a model-based reinforcement learning algorithm that combines probabilistic networks, which capture aleatoric uncertainty (the inherent stochasticity of the environment), with ensembles, which capture epistemic uncertainty (uncertainty from limited data).

PETS is used for model predictive control (MPC), which plans and optimizes a sequence of actions at every step, executes only the first action, and then replans.

Instead of random shooting, PETS uses CEM to sample action sequences from a distribution that is iteratively refitted towards previous samples that yielded high reward.
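A minimal sketch of the PETS planning loop, assuming the probabilistic ensemble has already been trained: the linear-Gaussian `EnsembleMember` class, the quadratic cost, and all dimensions below are hypothetical stand-ins for the learned networks and the task at hand. Each CEM candidate is an action sequence, its cost is estimated by propagating particles through randomly chosen ensemble members (trajectory sampling), and MPC would execute only the first action before replanning.

import numpy as np

rng = np.random.default_rng(0)
state_dim, action_dim, horizon = 3, 1, 10

class EnsembleMember:
    # Stand-in for a trained probabilistic network: predicts a Gaussian over the
    # next state. Each member has slightly different weights (epistemic disagreement).
    def __init__(self):
        self.A = np.eye(state_dim) + 0.01 * rng.normal(size=(state_dim, state_dim))
        self.B = 0.1 * rng.normal(size=(state_dim, action_dim))

    def predict(self, s, a):
        mean = s @ self.A.T + a @ self.B.T
        stddev = 0.01 * np.ones_like(mean)  # aleatoric noise (kept fixed for simplicity)
        return mean, stddev

ensemble = [EnsembleMember() for _ in range(5)]

def cost(s):
    return np.sum(s ** 2, axis=-1)  # made-up cost: drive the state towards the origin

def evaluate(action_seqs, s0, n_particles=20):
    # Trajectory sampling (roughly the TS1 scheme): at every step each particle
    # follows a randomly chosen ensemble member and samples from its Gaussian.
    n_candidates = action_seqs.shape[0]
    s = np.tile(s0, (n_candidates, n_particles, 1))
    total = np.zeros((n_candidates, n_particles))
    for t in range(horizon):
        a = np.repeat(action_seqs[:, t:t + 1, :], n_particles, axis=1)
        member_idx = rng.integers(len(ensemble), size=n_particles)
        mean, stddev = np.empty_like(s), np.empty_like(s)
        for i, member in enumerate(ensemble):
            mask = member_idx == i
            mean[:, mask], stddev[:, mask] = member.predict(s[:, mask], a[:, mask])
        s = mean + stddev * rng.normal(size=mean.shape)
        total += cost(s)
    return total.mean(axis=1)  # average cost over particles for each candidate

def cem_plan(s0, n_candidates=100, n_elites=10, n_iters=5):
    # CEM over action sequences: refit the Gaussian to the lowest-cost elites
    mu = np.zeros((horizon, action_dim))
    sigma = np.ones((horizon, action_dim))
    for _ in range(n_iters):
        seqs = mu + sigma * rng.normal(size=(n_candidates, horizon, action_dim))
        costs = evaluate(seqs, s0)
        elites = seqs[np.argsort(costs)[:n_elites]]
        mu, sigma = elites.mean(axis=0), elites.std(axis=0)
    return mu[0]  # MPC: execute only this first action, then replan from the new state

print(cem_plan(np.array([1.0, -0.5, 0.2])))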