Temporal Difference Control Methods - Part 1
In Monte Carlo control methods, we had to wait for an episode to finish before we could update the Q-table. In temporal-difference (TD) methods, we update the table while the agent is interacting with the environment, i.e., after every time step.
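Part 1: In this notebook, we implement the on-policy method SARSA (also called SARSA(0)) to estimate the optimal policy of the CliffWalking Gym environment. SARSA is guaranteed to converge to the optimal policy, provided that every state–action pair is visited infinitely often, the exploration rate decays so that the policy becomes greedy in the limit (GLIE), and the step sizes satisfy the usual stochastic-approximation (Robbins–Monro) conditions.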
Import Libraries
import gym
import sys
import numpy as np
import pandas as pd
from collections import defaultdict, deque
import matplotlib.pyplot as plt
%matplotlib inline
Create Gym Environment
# Simple environment with discrete state and action spaces
env = gym.make('CliffWalking-v0')

# Explore the state and action spaces
for label, space in (('State space', env.observation_space),
                     ('Action space', env.action_space)):
    print('{}: {}'.format(label, space))
State space: Discrete(48)
Action space: Discrete(4)
Let us see how well a random player performs.
# Random player: takes uniformly random actions until the episode ends.
# Start from a fresh episode rather than relying on leftover environment state.
env.reset()
score = 0
while True:
    action = np.random.randint(env.action_space.n)
    state, reward, done, info = env.step(action)
    score += reward
    if done:
        print('last reward: {}, total score: {}'.format(reward, score))
        break  # without this the loop never terminates
last reward: -1, total score: -10218
We continue to use an epsilon-greedy policy (the same as in the Monte Carlo implementation).
def eps_greedy(eps, Q, state, nA):
    """Choose an action for ``state`` using an epsilon-greedy rule.

    One uniform draw decides the mode. If it falls below ``eps``, a uniformly
    random action in ``[0, nA)`` is returned (explore). Otherwise the
    highest-valued action in ``Q[state]`` is returned (exploit; ties go to
    the lowest index, as with ``np.argmax``).
    """
    explore = np.random.rand() < eps
    if not explore:
        return np.argmax(Q[state])
    return np.random.randint(nA)
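Update rule (SARSA(0)):

$$Q(S_t, A_t) \leftarrow Q(S_t, A_t) + \alpha \left[ R_{t+1} + \gamma\, Q(S_{t+1}, A_{t+1}) - Q(S_t, A_t) \right]$$

where $Q(S_{t+1}, A_{t+1})$ is taken as $0$ when $S_{t+1}$ is terminal.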
def update_sarsa_Q(alpha, gamma, Q, state, action, next_state=None, next_action=None, reward=None):
    """Return the SARSA(0)-updated value of ``Q[state][action]``.

    Computes ``Q(s, a) + alpha * (r + gamma * Q(s', a') - Q(s, a))``. The
    bootstrap term ``Q(s', a')`` is taken as 0 when ``next_state`` is None,
    i.e. for the transition into a terminal state. ``Q`` itself is not
    modified; the caller stores the returned value.

    Args:
        alpha: step size.
        gamma: discount factor.
        Q: mapping from state to an indexable array of action values.
        state, action: the state-action pair being updated.
        next_state, next_action: the successor pair. Leave ``next_state`` as
            None for a terminal transition.
        reward: reward observed for this transition. It is a trailing keyword
            so existing positional calls keep working, but it must be passed
            explicitly. Reading an outer ``reward`` variable instead would
            silently use whatever value happens to be in scope.

    Raises:
        TypeError: if ``reward`` is not supplied.
    """
    if reward is None:
        raise TypeError("update_sarsa_Q() requires the 'reward' argument")
    # `is not None` (not truthiness): state 0 is a valid successor state.
    q_next = Q[next_state][next_action] if next_state is not None else 0
    q_sa = Q[state][action]
    return q_sa + alpha * (reward + gamma * q_next - q_sa)


def generate_sarsa_episode(env, Q, eps, alpha, gamma):
    """Play one episode with an epsilon-greedy policy, updating Q online with SARSA(0).

    Args:
        env: environment using the old Gym API. ``reset()`` returns the initial
            state and ``step(action)`` returns ``(next_state, reward, done, info)``.
        Q: mapping state -> action-value array. It is updated in place and
            also returned.
        eps: exploration rate passed to ``eps_greedy``.
        alpha: step size for the SARSA update.
        gamma: discount factor.

    Returns:
        ``(Q, score)``, where ``score`` is the undiscounted sum of rewards
        collected during the episode.
    """
    nA = env.action_space.n
    state = env.reset()
    action = eps_greedy(eps, Q, state, nA)
    score = 0
    while True:
        next_state, reward, done, info = env.step(action)
        score += reward
        if done:
            # Terminal transition: no bootstrap term, and the update targets the
            # pair that led into the terminal state (not the terminal state itself).
            Q[state][action] = update_sarsa_Q(alpha, gamma, Q, state, action, reward=reward)
            return Q, score
        # On-policy: the action we will actually take next is also the one
        # used in the update target.
        next_action = eps_greedy(eps, Q, next_state, nA)
        Q[state][action] = update_sarsa_Q(
            alpha, gamma, Q, state, action, next_state, next_action, reward=reward
        )
        state, action = next_state, next_action
# Play for a defined number of episodes
def train(env, num_episodes, eps=1.0, eps_min=0.01, eps_decay=0.9, alpha=0.01, gamma=1.0, plot_every=100):
    """Train a SARSA(0) agent for ``num_episodes`` episodes.

    Args:
        env: environment passed through to ``generate_sarsa_episode``.
        num_episodes: number of episodes to play.
        eps: starting exploration rate (used for the first episode).
        eps_min: lower bound on the exploration rate.
        eps_decay: multiplicative decay applied to ``eps`` after each episode.
        alpha: SARSA step size.
        gamma: discount factor.
        plot_every: window size (in episodes) of the moving average. Its mean
            is recorded in ``avg_rewards`` every ``plot_every`` episodes.

    Returns:
        ``(Q, all_rewards, avg_rewards)``:
        - ``Q``: the learned action-value table.
        - ``all_rewards``: a list of per-episode scores.
        - ``avg_rewards``: the mean score of each ``plot_every``-episode window.
    """
    Q = defaultdict(lambda: np.zeros(env.action_space.n))
    # Report progress roughly 10 times, but at least every episode for tiny runs.
    print_every = max(int(0.1 * num_episodes), 1)
    all_rewards = []
    tmp_rewards = deque(maxlen=plot_every)  # sliding window of recent scores
    avg_rewards = deque(maxlen=num_episodes)
    for i in range(1, num_episodes + 1):
        Q, score = generate_sarsa_episode(env, Q, eps, alpha, gamma)
        # Decay after the episode so the first one really uses the starting eps.
        eps = max(eps * eps_decay, eps_min)
        all_rewards.append(score)
        tmp_rewards.append(score)
        if i % plot_every == 0:
            avg_rewards.append(np.mean(tmp_rewards))
        if i % print_every == 0:
            print('\rProgress: {}/{}, average: {}'.format(i, num_episodes, np.mean(tmp_rewards)),
                  end='', flush=True)
    return Q, all_rewards, avg_rewards
Train the agent and try to recover the optimal (or a near-optimal) policy.
# Hyperparameters (RL is very susceptible to hyperparams)
eps, eps_min, eps_decay = 1.0, 0.01, 0.9  # starting epsilon, minimum epsilon, per-episode decay rate
alpha = 0.01  # Q value update step size
gamma = 1.0   # discount factor

Q_sarsa, score, avg_score = train(
    env,
    5000,
    eps=eps,
    eps_min=eps_min,
    eps_decay=eps_decay,
    alpha=alpha,
    gamma=gamma,
)
Progress: 5000/5000, average: -13.0
# Greedy action for each of the 4 x 12 grid cells; cells never seen in Q_sarsa get -1.
# UP: 0, RIGHT: 1, DOWN: 2, LEFT: 3, N/A: -1
policy = np.array(list(map(
    lambda cell: np.argmax(Q_sarsa[cell]) if cell in Q_sarsa else -1,
    np.arange(4 * 12),
)))
# Display policy
policy.reshape(4, 12)
array([[ 3, 1, 0, 3, 1, 1, 1, 1, 1, 1, 2, 2],
[ 0, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2],
[ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2],
[ 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0]])
Plot the average reward (over consecutive 100-episode windows) against the number of training episodes.
# Each entry of avg_score is the mean score of one window of episodes.
# Assumes train() ran with its default plot_every=100, so the k-th average is
# recorded at episode 100 * k. Plot it there, instead of spreading the points
# evenly over [0, 5000], which misplaces them.
window = 100
episodes = np.arange(1, len(avg_score) + 1) * window
plt.plot(episodes, np.asarray(avg_score))
plt.xlabel('Episode')
plt.ylabel('Average Reward (100 episode)')
plt.title('SARSA Agent Training')
plt.show()
Try it out
# Use the learned greedy policy to play one episode.
# Cap the episode length: a purely greedy policy can cycle without ever
# reaching the goal, which would otherwise hang the cell.
max_steps = 1000
state = env.reset()
score = 0
for _ in range(max_steps):
    action = np.argmax(Q_sarsa[state])
    state, reward, done, info = env.step(action)
    score += reward
    if done:
        print('last reward: {}, total score: {}'.format(reward, score))
        break  # without this the loop never terminates
else:
    print('episode did not finish within {} steps, score so far: {}'.format(max_steps, score))
last reward: -1, total score: -13