Cross Entropy Method

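The Cross Entropy Method (CEM) is a gradient-free optimization method commonly used for planning in model-based reinforcement learning. Here it is used for direct policy search: finding the weights of a linear policy for the CartPole environment.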

CEM Algorithm

  1. Initialise a Gaussian distribution $\mathcal{N}(\mu, \sigma)$ over the policy parameters $\theta$ (here, the weights of a linear policy).
  2. Draw a batch of $N$ samples of $\theta$ from the Gaussian.
  3. Evaluate each of the $N$ samples with an objective function, e.g. the total reward obtained by running one episode (trial) with those weights.
  4. Select the top-performing fraction of the samples (the "elite" set) and use their mean and standard deviation as the new $\mu$ and $\sigma$ of the Gaussian.
  5. Repeat steps 2–4 until convergence.
import numpy as np
import tensorflow_probability as tfp
tfd = tfp.distributions
import gym
# RL Gym
env = gym.make('CartPole-v0')

# Initialisation
n = 10  # number of candidate policies sampled per CEM iteration
top_k = 0.40  # fraction of the candidates kept as the elite set
# The linear policy scores each action as obs @ W + b, so the parameters form
# a matrix with one row per observation feature plus a bias row, and one
# column per action: shape (obs_dim + 1, n_actions). Deriving the sizes from
# the environment keeps them in sync with it (for CartPole this is (5, 2)).
n_obs = env.observation_space.shape[0]
n_actions = env.action_space.n
mean = np.zeros((n_obs + 1, n_actions))  # initial mean of every weight
stddev = np.ones((n_obs + 1, n_actions))  # initial stddev of every weight
def get_batch_weights(mean, stddev, n):
    """Sample ``n`` candidate weight arrays from a diagonal Gaussian.

    Args:
        mean: Per-weight means, passed to the distribution as ``loc``.
        stddev: Per-weight standard deviations, passed as ``scale_diag``;
            same shape as ``mean``.
        n: How many samples to draw.

    Returns:
        A NumPy array of ``n`` samples stacked along a new leading axis,
        i.e. shaped ``(n, *mean.shape)``.
    """
    # With a diagonal scale, every weight is drawn independently from its own
    # normal distribution N(mean[i, j], stddev[i, j]).
    gaussian = tfd.MultivariateNormalDiag(loc=mean, scale_diag=stddev)
    draws = gaussian.sample(n)
    return draws.numpy()

def policy(obs, weights):
    """Choose a discrete action with a linear scoring policy.

    ``weights`` holds one column per action. Every row except the last is
    multiplied by the observation, and the last row is a per-action bias.
    The action with the highest score ``obs @ W + b`` is returned.

    Args:
        obs: 1-D observation vector of length ``weights.shape[0] - 1``.
        weights: Array of shape ``(len(obs) + 1, n_actions)``.

    Returns:
        Index of the highest-scoring action (``np.argmax`` returns the first
        index on ties).
    """
    # Split off the bias row generically instead of hard-coding CartPole's
    # 4-dimensional observation, so the policy works for any observation size.
    linear, bias = weights[:-1], weights[-1]
    return np.argmax(obs @ linear + bias)

def run_trial(weights, render=False):
    """Play one episode of the module-level ``env`` and return its total reward.

    At every step the action is chosen by ``policy(observation, weights)``.
    The episode runs until the environment reports ``done``. When ``render``
    is true, the environment is drawn after each step.

    This uses the older Gym API, in which ``reset()`` returns only the
    observation and ``step()`` returns a 4-tuple ``(obs, reward, done, info)``.
    """
    total = 0
    observation = env.reset()
    finished = False
    while not finished:
        observation, step_reward, finished, _ = env.step(policy(observation, weights))
        total += step_reward
        if render:
            env.render()
    return total

def get_new_mean_stddev(rewards, batch_weights, elite_frac=None):
    """Refit the Gaussian to the elite (highest-reward) samples.

    Args:
        rewards: Sequence of episode rewards, one per sample in
            ``batch_weights``.
        batch_weights: Array of sampled weights whose first axis indexes the
            samples, i.e. shaped ``(len(rewards), ...)``.
        elite_frac: Fraction of the samples kept as the elite set. When
            omitted, the module-level ``top_k`` is used.

    Returns:
        ``(mean, stddev)``: the element-wise mean and (population) standard
        deviation of the elite samples, each shaped like a single sample.

    Raises:
        ValueError: if ``rewards`` is empty.
    """
    if elite_frac is None:
        elite_frac = top_k
    rewards = np.asarray(rewards)
    batch_weights = np.asarray(batch_weights)
    if len(rewards) == 0:
        raise ValueError("rewards must contain at least one sample")
    # Size the elite set from the batch actually passed in (not the global
    # ``n``), and keep at least one sample. A zero-sized selection would make
    # the mean/std NaN and poison every later iteration.
    n_elite = max(1, int(len(rewards) * elite_frac))
    # Sample indices ordered by descending reward; keep the best n_elite.
    elite_idx = np.argsort(rewards)[::-1][:n_elite]
    elite = batch_weights[elite_idx]
    return elite.mean(axis=0), elite.std(axis=0)
# Run 20 CEM iterations: sample a population of weights, score each member
# with one episode, then refit the Gaussian to the best-scoring members.
# NOTE(review): stddev is never floored or re-inflated, so the search can
# collapse onto a single point once the elites agree.
for generation in range(20):
    batch_weights = get_batch_weights(mean, stddev, n)
    rewards = list(map(run_trial, batch_weights))
    mean, stddev = get_new_mean_stddev(rewards, batch_weights)
    print(rewards)
[20.0, 10.0, 9.0, 16.0, 22.0, 10.0, 10.0, 10.0, 10.0, 9.0]
[30.0, 56.0, 26.0, 125.0, 13.0, 9.0, 9.0, 114.0, 28.0, 8.0]
[89.0, 111.0, 69.0, 9.0, 200.0, 69.0, 200.0, 105.0, 12.0, 31.0]
[94.0, 128.0, 57.0, 30.0, 122.0, 107.0, 69.0, 37.0, 37.0, 141.0]
[200.0, 200.0, 89.0, 200.0, 140.0, 91.0, 102.0, 149.0, 21.0, 81.0]
[200.0, 154.0, 10.0, 112.0, 114.0, 187.0, 200.0, 200.0, 136.0, 149.0]
[200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 149.0, 200.0, 200.0, 200.0]
[200.0, 200.0, 134.0, 200.0, 200.0, 200.0, 180.0, 200.0, 200.0, 200.0]
[200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0]
[200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 160.0, 131.0, 200.0]
[200.0, 152.0, 163.0, 200.0, 153.0, 200.0, 200.0, 131.0, 200.0, 200.0]
[200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0]
[200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0]
[200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0]
[200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0]
[200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0]
[200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0]
[200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0]
[200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0]
[200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0]
# Inspect the learned distribution: the per-weight mean and standard deviation
# after the final CEM update (shown as the notebook output that follows).
mean, stddev
(array([[-0.48842902, -0.20315496],
        [ 1.05925976,  1.55983425],
        [-0.83255259,  1.6572544 ],
        [-3.46168438, -0.27580643],
        [ 0.16817479, -0.15037121]]),
 array([[0.00026762, 0.00022525],
        [0.00595117, 0.00055989],
        [0.00042871, 0.09129609],
        [0.00033094, 0.00030441],
        [0.00055258, 0.00365766]]))
# Evaluate the learned policy once, with rendering. Use the distribution's
# mean rather than a random draw from it: the mean is the CEM point estimate
# (the Gaussian's highest-density point), and a draw only adds noise to it.
best_weights = mean
final_reward = run_trial(best_weights, render=True)
env.close()  # close the window opened by env.render()
print(final_reward)
200.0