diff --git a/reinforcement_learning/reinforce.py b/reinforcement_learning/reinforce.py index a222ff804c..9b6cea3345 100644 --- a/reinforcement_learning/reinforce.py +++ b/reinforcement_learning/reinforce.py @@ -68,10 +68,9 @@ def finish_episode(): returns.insert(0, R) returns = torch.tensor(returns) returns = (returns - returns.mean()) / (returns.std() + eps) - for log_prob, R in zip(policy.saved_log_probs, returns): - policy_loss.append(-log_prob * R) - optimizer.zero_grad() - policy_loss = torch.cat(policy_loss).sum() + l_probs = policy.saved_log_probs.sum() + rews = returns.sum() + policy_loss = rews + l_probs policy_loss.backward() optimizer.step() del policy.rewards[:]