Skip to content

Commit e573ea8

Browse files
committed
updated epsilon greedy method
1 parent c8ec50d commit e573ea8

File tree

2 files changed

+16
-8
lines changed
  • BlackJackMonteCarlo/bjack_src
  • CliffWalkingTemporalDifference/td_src

2 files changed

+16
-8
lines changed

BlackJackMonteCarlo/bjack_src/mc.py

+12 -4
Original file line number | Diff line number | Diff line change
@@ -134,10 +134,18 @@ def epsilon_greedy(Q, state, nA, epsilon=0.1):
134134
With probability (1 − epsilon) choose the greedy action.
135135
With probability epsilon choose an action at random.
136136
"""
137-
A = np.ones(nA) * epsilon / float(nA)
138-
best_action = np.argmax(Q[state])
139-
A[best_action] += (1.0 - epsilon)
140-
return np.random.choice(np.arange(len(A)), p=A)
137+
# A = np.ones(nA) * epsilon / float(nA)
138+
# best_action, prob_for_best_action = np.argmax(Q[state]), max(Q[state])
139+
# if prob_for_best_action > epsilon:
140+
# A[best_action] += (1.0 - epsilon)
141+
# return np.random.choice(np.arange(len(A)), p=A)
142+
# else:
143+
# return np.random.choice(np.arange(len(A)))
144+
145+
actions = np.ones(nA) * epsilon / float(nA)
146+
best_current_action = np.argmax(Q[state])
147+
actions[best_current_action] += (1.0 - epsilon)
148+
return np.random.choice(np.arange(len(actions)), p=actions)
141149

142150

143151
def generate_random_episode_greedy(Q, nA, epsilon, env):

CliffWalkingTemporalDifference/td_src/td.py

+4 -4
Original file line number | Diff line number | Diff line change
@@ -45,10 +45,10 @@ def epsilon_greedy(Q, state, nA, epsilon=0.1):
4545
With probability (1 − epsilon) choose the greedy action.
4646
With probability epsilon choose an action at random.
4747
"""
48-
A = np.ones(nA) * epsilon / float(nA)
49-
best_action = np.argmax(Q[state])
50-
A[best_action] += (1.0 - epsilon)
51-
return np.random.choice(np.arange(len(A)), p=A)
48+
actions = np.ones(nA) * epsilon / float(nA)
49+
best_current_action = np.argmax(Q[state])
50+
actions[best_current_action] += (1.0 - epsilon)
51+
return np.random.choice(np.arange(len(actions)), p=actions)
5252

5353

5454
def sarsa(env, n_episodes, gamma=1.0, alpha=0.5, epsilon=0.1):

0 commit comments

Comments
 (0)