Cross-Entropy Method
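The Cross-Entropy Method (CEM) is a gradient-free optimization method commonly used for planning in model-based reinforcement learning. Here it is used for direct policy search: it optimises the weights of a linear policy for the CartPole environment.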
CEM Algorithm
1. Create a Gaussian distribution \(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2))\) over the policy weights \(\theta\) (here, the weights of a linear policy).
2. Sample a batch of \(N\) candidate weight vectors \(\theta\) from the Gaussian.
3. Evaluate each of the \(N\) candidates, e.g. by the total reward of a trial (episode) run with those weights.
4. Select the top-scoring fraction of the candidates (the elite set, e.g. the top 40%) and compute the new \(\mu\) and \(\sigma\) as the per-weight mean and standard deviation of the elite set; these parameterise the next Gaussian.
5. Repeat steps 2-4 until convergence.
import numpy as np
import tensorflow_probability as tfp
tfd = tfp.distributions
import gym
# RL Gym
env = gym.make('CartPole-v0')
# Initialisation
n = 10  # number of candidate policies sampled per CEM iteration
top_k = 0.40  # fraction (not percent) of candidates kept as the elite set
# The linear policy uses one weight row per observation feature followed by a
# bias row, and one column per action. Derive the shape from the environment
# instead of hard-coding (5, 2); assumes a Discrete action space (CartPole's is).
n_obs = env.observation_space.shape[0]
n_actions = env.action_space.n
mean = np.zeros((n_obs + 1, n_actions))  # shape = (n_obs + 1, n_actions)
stddev = np.ones((n_obs + 1, n_actions))  # shape = (n_obs + 1, n_actions)
def get_batch_weights(mean, stddev, n, rng=None):
    """Draw ``n`` candidate weight arrays from an axis-aligned Gaussian.

    Every element is sampled independently as
    ``Normal(mean[idx], stddev[idx])``. This is the same distribution as a
    multivariate normal with diagonal scale, but it only needs NumPy.

    Args:
        mean: array of per-element means (e.g. shape (n_obs + 1, n_actions)).
        stddev: array of per-element standard deviations, same shape as
            ``mean``; must be non-negative (zero yields ``mean`` exactly).
        n: number of candidates to draw.
        rng: optional ``numpy.random.Generator`` (or ``RandomState``) for
            reproducible sampling; defaults to NumPy's global random state,
            so ``np.random.seed`` controls it.

    Returns:
        Float array of shape ``(n,) + mean.shape``.
    """
    sampler = np.random if rng is None else rng
    mean = np.asarray(mean, dtype=float)
    stddev = np.asarray(stddev, dtype=float)
    # loc/scale broadcast against the leading batch axis given by ``size``.
    return sampler.normal(loc=mean, scale=stddev, size=(n,) + mean.shape)
def policy(obs, weights):
    """Choose a discrete action with a linear scoring policy.

    ``weights`` stacks a weight matrix on top of a bias row. All rows except
    the last multiply the observation, and the last row is added as a
    per-action bias. For CartPole this is shape (5, 2): 4 observation
    features plus 1 bias row, by 2 actions. Other observation sizes work as
    long as ``weights`` has ``len(obs) + 1`` rows.

    Args:
        obs: 1-D observation vector.
        weights: array of shape (len(obs) + 1, n_actions).

    Returns:
        Index of the highest-scoring action, as a Python ``int``.
    """
    weights = np.asarray(weights)
    scores = np.asarray(obs) @ weights[:-1] + weights[-1]
    return int(np.argmax(scores))
def run_trial(weights, render=False, environment=None):
    """Run one episode with ``policy(obs, weights)`` and return its total reward.

    Args:
        weights: policy weights, passed straight to ``policy``.
        render: if True, call ``render()`` on the environment after each step.
        environment: environment to roll out in; defaults to the module-level
            ``env`` so existing calls keep working.

    Returns:
        Sum of the per-step rewards collected over the episode.

    Two Gym conventions are accepted. The first is the old API used by this
    script: ``reset() -> obs`` and ``step() -> (obs, reward, done, info)``.
    The second is the newer Gym >= 0.26 / Gymnasium API:
    ``reset() -> (obs, info)`` and
    ``step() -> (obs, reward, terminated, truncated, info)``.
    """
    if environment is None:
        environment = env
    obs = environment.reset()
    # New API: reset() returns an (observation, info-dict) pair.
    if isinstance(obs, tuple) and len(obs) == 2 and isinstance(obs[1], dict):
        obs = obs[0]
    done = False
    total_reward = 0
    while not done:
        action = policy(obs, weights)
        step_result = environment.step(action)
        if len(step_result) == 5:
            # New API: the episode ends on either termination or truncation.
            obs, r, terminated, truncated, _ = step_result
            done = terminated or truncated
        else:
            obs, r, done, _ = step_result
        total_reward += r
        if render:
            environment.render()
    return total_reward
def get_new_mean_stddev(rewards, batch_weights, elite_frac=None):
    """Refit the sampling Gaussian to the elite (highest-reward) candidates.

    Args:
        rewards: one scalar score per candidate (length N).
        batch_weights: array of shape (N, ...) holding the candidates, in the
            same order as ``rewards``.
        elite_frac: fraction of candidates kept as the elite set; defaults to
            the module-level ``top_k``.

    Returns:
        ``(mean, stddev)``: per-element mean and (population) standard
        deviation of the elite candidates, each of shape
        ``batch_weights.shape[1:]``.

    Raises:
        ValueError: if ``rewards`` is empty or its length differs from the
            number of candidates in ``batch_weights``.
    """
    if elite_frac is None:
        elite_frac = top_k
    rewards = np.asarray(rewards)
    batch_weights = np.asarray(batch_weights)
    if rewards.size == 0:
        raise ValueError("rewards must not be empty")
    if len(batch_weights) != len(rewards):
        raise ValueError(
            f"got {len(rewards)} rewards for {len(batch_weights)} candidates"
        )
    # Size the elite set from the actual batch (not the global ``n``), and keep
    # at least one candidate so the mean is never taken over an empty set.
    n_elite = max(1, int(len(rewards) * elite_frac))
    elite_idx = np.argsort(rewards)[::-1][:n_elite]
    elite = batch_weights[elite_idx]
    return elite.mean(axis=0), elite.std(axis=0)
# Main CEM loop: 20 rounds of sample -> evaluate -> refit.
for i in range(20):
    # Draw n candidate weight arrays from the current search distribution.
    batch_weights = get_batch_weights(mean, stddev, n)
    # Score each candidate by the total reward of one episode (no rendering).
    rewards = list(map(run_trial, batch_weights))
    # Refit the search distribution to the best-scoring candidates.
    mean, stddev = get_new_mean_stddev(rewards, batch_weights)
    # Log this round's scores so progress is visible.
    print(rewards)
[10.0, 16.0, 10.0, 28.0, 8.0, 136.0, 9.0, 9.0, 9.0, 19.0]
[43.0, 21.0, 9.0, 58.0, 51.0, 29.0, 91.0, 27.0, 8.0, 29.0]
[43.0, 42.0, 73.0, 62.0, 51.0, 107.0, 89.0, 28.0, 61.0, 28.0]
[64.0, 32.0, 78.0, 77.0, 42.0, 150.0, 45.0, 49.0, 61.0, 65.0]
[62.0, 129.0, 76.0, 62.0, 153.0, 200.0, 56.0, 46.0, 57.0, 60.0]
[81.0, 52.0, 84.0, 67.0, 60.0, 104.0, 56.0, 61.0, 200.0, 78.0]
[51.0, 51.0, 104.0, 84.0, 84.0, 197.0, 200.0, 89.0, 145.0, 92.0]
[66.0, 65.0, 126.0, 115.0, 69.0, 69.0, 85.0, 131.0, 94.0, 55.0]
[200.0, 157.0, 88.0, 110.0, 72.0, 130.0, 72.0, 82.0, 127.0, 194.0]
[118.0, 139.0, 109.0, 77.0, 148.0, 117.0, 62.0, 98.0, 110.0, 178.0]
[121.0, 123.0, 117.0, 104.0, 172.0, 89.0, 74.0, 168.0, 117.0, 200.0]
[163.0, 69.0, 163.0, 200.0, 158.0, 89.0, 128.0, 200.0, 138.0, 167.0]
[200.0, 96.0, 76.0, 88.0, 200.0, 90.0, 108.0, 108.0, 83.0, 153.0]
[179.0, 200.0, 200.0, 79.0, 111.0, 81.0, 151.0, 200.0, 147.0, 74.0]
[200.0, 124.0, 144.0, 200.0, 78.0, 83.0, 150.0, 147.0, 154.0, 93.0]
[116.0, 117.0, 200.0, 79.0, 101.0, 89.0, 105.0, 117.0, 200.0, 109.0]
[109.0, 132.0, 136.0, 200.0, 98.0, 200.0, 94.0, 106.0, 100.0, 200.0]
[147.0, 155.0, 107.0, 200.0, 123.0, 167.0, 155.0, 200.0, 126.0, 142.0]
[131.0, 73.0, 200.0, 200.0, 89.0, 91.0, 156.0, 200.0, 149.0, 154.0]
[105.0, 112.0, 98.0, 115.0, 200.0, 187.0, 94.0, 89.0, 180.0, 93.0]
# Display the fitted search distribution: per-weight mean and standard deviation.
(mean, stddev)
(array([[ 0.50994191, 0.1539059 ],
[ 0.5951551 , 0.350758 ],
[ 0.05404847, 0.39404 ],
[-2.6360502 , 0.22584736],
[ 0.16629546, 0.16006943]]),
array([[2.42495267e-04, 1.78857471e-04],
[1.01732459e-04, 2.03337057e-01],
[3.68772396e-04, 8.44550279e-04],
[1.10036801e-03, 2.23453546e-04],
[6.53310535e-05, 9.51211939e-05]]))
# Evaluate the learned policy with rendering. CEM's point estimate of the
# optimum is the mean of the search distribution, so use it directly rather
# than drawing one more random sample from the distribution.
best_weights = mean
run_trial(best_weights, render=True)
83.0