Cross Entropy Method for Planning in Reinforcement Learning
Cross Entropy Method
The Cross Entropy Method (CEM) is a gradient-free optimisation method commonly used for planning in model-based reinforcement learning. Rather than following gradients, it repeatedly samples candidate solutions from a distribution and refits that distribution to the best-performing candidates.
CEM Algorithm
1. Create a Gaussian distribution $\mathcal{N}(\mu,\sigma)$ over the policy parameters $\theta$ (e.g. the weights of a neural network).
2. Sample a batch of $N$ candidate parameter vectors $\theta$ from the Gaussian.
3. Evaluate all $N$ candidates with an objective (fitness) function, e.g. by running trials and recording the total reward each one obtains.
4. Select the top $k\%$ of candidates by reward and compute a new $\mu$ and $\sigma$ from them to parameterise the next Gaussian distribution.
5. Repeat steps 2-4 until convergence. A minimal sketch of this loop on a toy problem follows below; the CartPole implementation comes after that.
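To make the loop concrete before applying it to CartPole, here is a minimal NumPy-only sketch of CEM maximising a toy one-dimensional objective. The function $f$, the elite fraction and the iteration count are illustrative choices for this sketch, not part of the CartPole code that follows.

import numpy as np

# Toy objective with its maximum at x = 3
def f(x):
    return -(x - 3.0) ** 2

mu, sigma = 0.0, 1.0           # initial Gaussian over the single parameter
n_samples, top_frac = 50, 0.2  # candidates per iteration and elite fraction

for _ in range(30):
    samples = np.random.normal(mu, sigma, n_samples)  # sample candidates (step 2)
    scores = f(samples)                               # evaluate them (step 3)
    elite = samples[np.argsort(scores)[::-1][:int(n_samples * top_frac)]]
    mu, sigma = elite.mean(), elite.std()             # refit the Gaussian (step 4)

print(mu)  # converges towards 3.0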
import numpy as np
import tensorflow_probability as tfp
import gym

tfd = tfp.distributions

# RL Gym environment
env = gym.make('CartPole-v0')

# Initialisation
n = 10        # number of candidate policies per iteration
top_k = 0.40  # fraction of candidates kept to refit the distribution
mean = np.zeros((5, 2))   # shape = (n_parameters, n_actions): 4 observation dims + 1 bias, 2 actions
stddev = np.ones((5, 2))  # shape = (n_parameters, n_actions)

# Sample n sets of policy weights from the diagonal Gaussian
def get_batch_weights(mean, stddev, n):
    mvn = tfd.MultivariateNormalDiag(
        loc=mean,
        scale_diag=stddev)
    return mvn.sample(n).numpy()

# Linear policy: score each action and take the argmax
def policy(obs, weights):
    return np.argmax(obs @ weights[:4, :] + weights[4])

# Run a single episode with the given weights and return the total reward
def run_trial(weights, render=False):
    obs = env.reset()
    done = False
    reward = 0
    while not done:
        a = policy(obs, weights)
        obs, r, done, _ = env.step(a)
        reward += r
        if render:
            env.render()
    return reward

# Refit the Gaussian to the top-performing candidates
def get_new_mean_stddev(rewards, batch_weights):
    idx = np.argsort(rewards)[::-1][:int(n * top_k)]
    mean = np.mean(batch_weights[idx], axis=0)
    stddev = np.std(batch_weights[idx], axis=0)
    return mean, stddev

for i in range(20):
    batch_weights = get_batch_weights(mean, stddev, n)
    rewards = [run_trial(weights) for weights in batch_weights]
    mean, stddev = get_new_mean_stddev(rewards, batch_weights)
    print(rewards)
[20.0, 10.0, 9.0, 16.0, 22.0, 10.0, 10.0, 10.0, 10.0, 9.0]
[30.0, 56.0, 26.0, 125.0, 13.0, 9.0, 9.0, 114.0, 28.0, 8.0]
[89.0, 111.0, 69.0, 9.0, 200.0, 69.0, 200.0, 105.0, 12.0, 31.0]
[94.0, 128.0, 57.0, 30.0, 122.0, 107.0, 69.0, 37.0, 37.0, 141.0]
[200.0, 200.0, 89.0, 200.0, 140.0, 91.0, 102.0, 149.0, 21.0, 81.0]
[200.0, 154.0, 10.0, 112.0, 114.0, 187.0, 200.0, 200.0, 136.0, 149.0]
[200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 149.0, 200.0, 200.0, 200.0]
[200.0, 200.0, 134.0, 200.0, 200.0, 200.0, 180.0, 200.0, 200.0, 200.0]
[200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0]
[200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 160.0, 131.0, 200.0]
[200.0, 152.0, 163.0, 200.0, 153.0, 200.0, 200.0, 131.0, 200.0, 200.0]
[200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0]
[200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0]
[200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0]
[200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0]
[200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0]
[200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0]
[200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0]
[200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0]
[200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0]
mean, stddev
(array([[-0.48842902, -0.20315496],
[ 1.05925976, 1.55983425],
[-0.83255259, 1.6572544 ],
[-3.46168438, -0.27580643],
[ 0.16817479, -0.15037121]]),
array([[0.00026762, 0.00022525],
[0.00595117, 0.00055989],
[0.00042871, 0.09129609],
[0.00033094, 0.00030441],
[0.00055258, 0.00365766]]))
best_weights = get_batch_weights(mean, stddev, 1)[0]
run_trial(best_weights, render=True)
200.0
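Because the standard deviations above have collapsed to almost zero, sampling a final candidate is effectively deterministic. As a possible follow-up (not part of the original run), the final mean can be used directly as the policy weights, and the environment closed afterwards:

# Use the final mean of the distribution as the policy weights;
# with the near-zero stddev this behaves the same as sampling.
run_trial(mean, render=True)
env.close()  # release the render window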