"""Cross Entropy Method (CEM) on CartPole-v0.

The Cross Entropy Method is a gradient-free optimisation method commonly
used for planning in model-based reinforcement learning.

Algorithm:
  1. Create a Gaussian distribution N(mu, sigma) that describes the
     weights theta of the (linear) policy.
  2. Sample N batches of theta from the Gaussian.
  3. Evaluate all N samples of theta using the value function,
     e.g. by running trials.
  4. Select the top % of the samples of theta and compute the new mu
     and sigma to parameterise the new Gaussian distribution.
  5. Repeat steps 1-4 until convergence.
"""
import numpy as np
import tensorflow_probability as tfp

tfd = tfp.distributions

import gym  # RL Gym

env = gym.make('CartPole-v0')

# Initialisation
n = 10        # number of candidate policies sampled per iteration
top_k = 0.40  # fraction of top performers ("elites") kept to refit the Gaussian
mean = np.zeros((5, 2))   # shape = (n_parameters, n_actions)
stddev = np.ones((5, 2))  # shape = (n_parameters, n_actions)


def get_batch_weights(mean, stddev, n):
    """Draw `n` weight samples from a diagonal Gaussian.

    Args:
        mean: array of per-parameter means, shape (n_parameters, n_actions).
        stddev: array of per-parameter standard deviations, same shape.
        n: number of samples to draw.

    Returns:
        numpy array of shape (n, n_parameters, n_actions).
    """
    mvn = tfd.MultivariateNormalDiag(loc=mean, scale_diag=stddev)
    return mvn.sample(n).numpy()


def policy(obs, weights):
    """Greedy linear policy: pick the action maximising obs @ W + b.

    The first 4 rows of `weights` multiply the 4-dim CartPole observation;
    the last row acts as a per-action bias.
    """
    return np.argmax(obs @ weights[:4, :] + weights[4])


def run_trial(weights, render=False):
    """Run one full episode with `weights` and return the total reward.

    NOTE(review): uses the gym < 0.26 API (reset() -> obs,
    step() -> (obs, reward, done, info)) — confirm the installed version.
    """
    obs = env.reset()
    done = False
    reward = 0.0  # gym rewards are floats; accumulate as float
    while not done:
        action = policy(obs, weights)
        obs, r, done, _ = env.step(action)
        reward += r
        if render:
            env.render()
    return reward


def get_new_mean_stddev(rewards, batch_weights, n_candidates=None, elite_frac=None):
    """Refit the Gaussian to the elite (highest-reward) weight samples.

    Args:
        rewards: sequence of episode rewards, one per candidate.
        batch_weights: array of sampled weights, shape (n, n_parameters,
            n_actions), aligned with `rewards`.
        n_candidates: number of candidates in the batch; defaults to the
            module-level `n` (kept for backward compatibility).
        elite_frac: fraction of top candidates to keep; defaults to the
            module-level `top_k`.

    Returns:
        (mean, stddev) of the elite samples, each with the same shape as
        one weight sample.
    """
    if n_candidates is None:
        n_candidates = n
    if elite_frac is None:
        elite_frac = top_k
    # Keep at least one elite, otherwise the mean/std of an empty slice is NaN.
    k = max(1, int(n_candidates * elite_frac))
    elite_idx = np.argsort(rewards)[::-1][:k]  # indices of the top-k rewards
    elite = batch_weights[elite_idx]
    return np.mean(elite, axis=0), np.std(elite, axis=0)


# CEM training loop: sample, evaluate, refit — repeated for 20 iterations.
for i in range(20):
    batch_weights = get_batch_weights(mean, stddev, n)
    rewards = [run_trial(weights) for weights in batch_weights]
    mean, stddev = get_new_mean_stddev(rewards, batch_weights)
    print(rewards)

# Example output — rewards per iteration, converging to the CartPole-v0
# maximum episode reward of 200:
# [20.0, 10.0, 9.0, 16.0, 22.0, 10.0, 10.0, 10.0, 10.0, 9.0]
# [30.0, 56.0, 26.0, 125.0, 13.0, 9.0, 9.0, 114.0, 28.0, 8.0]
# [89.0, 111.0, 69.0, 9.0, 200.0, 69.0, 200.0, 105.0, 12.0, 31.0]
# [94.0, 128.0, 57.0, 30.0, 122.0, 107.0, 69.0, 37.0, 37.0, 141.0]
# [200.0, 200.0, 89.0, 200.0, 140.0, 91.0, 102.0, 149.0, 21.0, 81.0]
# [200.0, 154.0, 10.0, 112.0, 114.0, 187.0, 200.0, 200.0, 136.0, 149.0]
# [200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 149.0, 200.0, 200.0, 200.0]
# ... (all-200.0 rows thereafter, with occasional dips) ...
# [200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0, 200.0]
#
# Final fitted distribution (mean, stddev):
# (array([[-0.48842902, -0.20315496],
#         [ 1.05925976,  1.55983425],
#         [-0.83255259,  1.6572544 ],
#         [-3.46168438, -0.27580643],
#         [ 0.16817479, -0.15037121]]),
#  array([[0.00026762, 0.00022525],
#         [0.00595117, 0.00055989],
#         [0.00042871, 0.09129609],
#         [0.00033094, 0.00030441],
#         [0.00055258, 0.00365766]]))

# Evaluate a single policy drawn from the converged distribution.
best_weights = get_batch_weights(mean, stddev, 1)[0]
run_trial(best_weights, render=True)  # -> 200.0