
Interpretability in realistic environments #14

Status: Open. Wants to merge 12 commits into base: master.
3 changes: 2 additions & 1 deletion requirements.txt
@@ -1,13 +1,14 @@
benchmark-environments @ git+https://github.com/HumanCompatibleAI/benchmark-environments.git
imitation @ git+https://github.com/HumanCompatibleAI/imitation.git@e99844
stable-baselines @ git+https://github.com/hill-a/stable-baselines.git
gym[mujoco]
gym[box2d,mujoco]
matplotlib
numpy
pandas
pymdptoolbox
seaborn
setuptools
scipy
#TODO(adam): upgrade to 1.15?
tensorflow>=1.13,<1.14
xarray
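
The `box2d` extra is added because the new LunarLander environments below depend on Box2D. A minimal sanity check that the dependency is importable, assuming a stock Gym install (this snippet is illustrative and not part of the PR):

import gym

# LunarLanderContinuous-v2 is the upstream Gym Box2D environment; constructing
# it fails with an informative error if gym[box2d] is missing.
env = gym.make("LunarLanderContinuous-v2")
obs = env.reset()
env.close()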
65 changes: 65 additions & 0 deletions runners/sparsify_point_maze.sh
@@ -0,0 +1,65 @@
#!/usr/bin/env bash
# Copyright 2020 Adam Gleave
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Script to sparsify pretrained reward models generated by `transfer_point_maze.sh`

DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
. ${DIR}/common.sh

ENV_TRAIN="imitation/PointMazeLeftVel-v0"
TRANSITION_P=0.05

if [[ ${fast} == "true" ]]; then
# intended for debugging
COMPARISON_TIMESTEPS="fast"
PM_OUTPUT=${OUTPUT_ROOT}/transfer_point_maze_fast
SPARSE_OUTPUT=${OUTPUT_ROOT}/sparse_point_maze_fast
else
COMPARISON_TIMESTEPS=""
EVAL_TIMESTEPS=100000
PM_OUTPUT=${OUTPUT_ROOT}/transfer_point_maze
SPARSE_OUTPUT=${OUTPUT_ROOT}/sparse_point_maze
fi

MIXED_POLICY_PATH=${TRANSITION_P}:random:dummy:ppo2:${PM_OUTPUT}/expert/train/policies/final
for name in comparison_expert comparison_mixture comparison_random; do
if [[ ${name} == "comparison_expert" ]]; then
extra_flags="dataset_factory_kwargs.policy_type=ppo2 \
dataset_factory_kwargs.policy_path=${PM_OUTPUT}/expert/train/policies/final"
elif [[ ${name} == "comparison_mixture" ]]; then
extra_flags="dataset_factory_kwargs.policy_type=mixture \
dataset_factory_kwargs.policy_path=${MIXED_POLICY_PATH}"
elif [[ ${name} == "comparison_random" ]]; then
extra_flags=""
else
echo "BUG: unknown name ${name}"
exit 1
fi
parallel --header : --results ${SPARSE_OUTPUT}/parallel/${name} \
$(call_script "model_comparison" "with") \
env_name=${ENV_TRAIN} ${extra_flags} \
ellp_loss no_rescale target_reward_type=evaluating_rewards/Zero-v0 \
seed={seed} source_reward_type={source_reward_type} \
source_reward_path=${PM_OUTPUT}/reward/{source_reward_path}/{source_reward_suffix} \
${COMPARISON_TIMESTEPS} log_dir=${SPARSE_OUTPUT}/${name}/{source_reward_path}/{seed} \
::: source_reward_type evaluating_rewards/PointMazeGroundTruthWithCtrl-v0 \
evaluating_rewards/PointMazeGroundTruthNoCtrl-v0 \
evaluating_rewards/RewardModel-v0 evaluating_rewards/RewardModel-v0 \
imitation/RewardNet_unshaped-v0 imitation/RewardNet_unshaped-v0 \
:::+ source_reward_path withctrl noctrl preferences regress irl_state_only irl_state_action \
:::+ source_reward_suffix dummy dummy model model checkpoints/final/discrim/reward_net \
checkpoints/final/discrim/reward_net \
::: seed 0 1 2
done
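
The GNU parallel call above crosses `:::` argument groups (Cartesian product), while `:::+` links its values one-to-one with the group it follows, and `--header :` takes the first token of each group as the variable name. A rough Python sketch of the resulting job grid for one comparison setting, assuming standard GNU parallel semantics (the variable names here are illustrative only):

import itertools

# Each tuple is one (source_reward_type, source_reward_path, source_reward_suffix)
# combination: the two `:::+` groups are zipped element-wise with the types.
source_rewards = list(zip(
    ["evaluating_rewards/PointMazeGroundTruthWithCtrl-v0",
     "evaluating_rewards/PointMazeGroundTruthNoCtrl-v0",
     "evaluating_rewards/RewardModel-v0",
     "evaluating_rewards/RewardModel-v0",
     "imitation/RewardNet_unshaped-v0",
     "imitation/RewardNet_unshaped-v0"],
    ["withctrl", "noctrl", "preferences", "regress",
     "irl_state_only", "irl_state_action"],
    ["dummy", "dummy", "model", "model",
     "checkpoints/final/discrim/reward_net",
     "checkpoints/final/discrim/reward_net"],
))
seeds = [0, 1, 2]

# Seeds form a separate `:::` group, so they are crossed with the reward specs:
# 6 reward specs x 3 seeds = 18 jobs per comparison setting.
for (rtype, rpath, rsuffix), seed in itertools.product(source_rewards, seeds):
    print(rtype, rpath, rsuffix, seed)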
1 change: 1 addition & 0 deletions setup.cfg
@@ -22,6 +22,7 @@ ignore = W503,E203
known_first_party=evaluating_rewards,tests
default_section=THIRDPARTY
multi_line_output=3
include_trailing_comma=True
force_sort_within_sections=True
line_length=100

8 changes: 4 additions & 4 deletions src/evaluating_rewards/analysis/plot_pm_reward.py
@@ -68,7 +68,7 @@ def default_config():

@plot_pm_reward_ex.config
def logging_config(log_root, models, reward_type, reward_path):
data_root = os.path.join(log_root, "model_comparison")
data_root = os.path.join(serialize.get_output_dir(), "model_comparison")
if models is None:
log_dir = os.path.join(
log_root, reward_type.replace("/", "_"), reward_path.replace("/", "_")
@@ -101,17 +101,17 @@ def dense_no_ctrl_sparsified():
pos_lim = 0.15
# Use lists of tuples rather than OrderedDict as Sacred reorders dictionaries
models = [
("Dense", "evaluating_rewards/PointMassDenseNoCtrl-v0", "dummy"),
("Dense\n(Manual)", "evaluating_rewards/PointMassDenseNoCtrl-v0", "dummy"),
(
"Sparsified",
"Sparsified\n(Learned)",
"evaluating_rewards/RewardModel-v0",
os.path.join(
"evaluating_rewards_PointMassLine-v0",
"20190921_190606_58935eb0a51849508381daf1055d0360",
"model",
),
),
("Sparse", "evaluating_rewards/PointMassSparseNoCtrl-v0", "dummy"),
("Sparse\n(Manual)", "evaluating_rewards/PointMassSparseNoCtrl-v0", "dummy"),
]
_ = locals() # quieten flake8 unused variable warning
del _
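
As the comment above notes, `models` is a list of `(label, reward_type, reward_path)` tuples rather than a dict, since Sacred reorders dictionaries. A minimal sketch of consuming such a list while preserving the declared order (the actual plotting code is outside this diff, so this is illustrative only):

import collections

models = [
    ("Dense\n(Manual)", "evaluating_rewards/PointMassDenseNoCtrl-v0", "dummy"),
    ("Sparse\n(Manual)", "evaluating_rewards/PointMassSparseNoCtrl-v0", "dummy"),
]
# Rebuild an ordered mapping at the point of use; iteration then follows the
# order the config declared.
ordered = collections.OrderedDict(
    (label, (rtype, rpath)) for label, rtype, rpath in models
)
for label, (rtype, rpath) in ordered.items():
    print(label, rtype, rpath)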
18 changes: 17 additions & 1 deletion src/evaluating_rewards/envs/__init__.py
@@ -17,7 +17,11 @@
import gym
import imitation.envs.examples # noqa: F401 pylint:disable=unused-import

from evaluating_rewards.envs import mujoco, point_mass # noqa: F401 pylint:disable=unused-import
from evaluating_rewards.envs import ( # noqa: F401 pylint:disable=unused-import
lunar_lander,
mujoco,
point_mass,
)

PROJECT_ROOT = "evaluating_rewards.envs"
PM_ROOT = f"{PROJECT_ROOT}.point_mass"
@@ -74,3 +78,15 @@ def register_mujoco():


register_mujoco()

gym.register(
id="evaluating_rewards/LunarLanderContinuous-v0",
entry_point=f"{PROJECT_ROOT}.lunar_lander:LunarLanderContinuousObservable",
reward_threshold=200,
)
gym.register(
id="evaluating_rewards/LunarLanderContinuousOriginalShaping-v0",
entry_point=f"{PROJECT_ROOT}.lunar_lander:LunarLanderContinuousObservable",
kwargs=dict(fix_shaping=False),
reward_threshold=200,
)
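
Once this module is imported, the new IDs are available through the standard Gym registry. A minimal usage sketch, assuming the package is installed and Box2D is available via the gym[box2d] extra added in requirements.txt:

import gym

import evaluating_rewards.envs  # noqa: F401  side effect: runs the gym.register calls above

env = gym.make("evaluating_rewards/LunarLanderContinuous-v0")
obs = env.reset()
# 8 original LunarLander dimensions plus the game_over, lander.awake and
# time_up flags appended by LunarLanderContinuousObservable.
assert obs.shape == (11,)
env.close()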
65 changes: 65 additions & 0 deletions src/evaluating_rewards/envs/core.py
@@ -0,0 +1,65 @@
# Copyright 2019 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Base classes for environment rewards."""

import abc

import gym
from imitation.util import serialize
import tensorflow as tf

from evaluating_rewards import rewards


class HardcodedReward(rewards.BasicRewardModel, serialize.LayersSerializable):
"""Hardcoded (non-trainable) reward model for a Gym environment."""

def __init__(self, observation_space: gym.Space, action_space: gym.Space, **kwargs):
"""Constructs the reward model.

Args:
observation_space: The observation space of the environment.
action_space: The action space of the environment.
**kwargs: Extra parameters to serialize and store in the instance,
accessible as attributes.
"""
rewards.BasicRewardModel.__init__(self, observation_space, action_space)
serialize.LayersSerializable.__init__(
self,
layers={},
observation_space=observation_space,
action_space=action_space,
**kwargs,
)
self._reward = self.build_reward()

def __getattr__(self, name):
try:
return self._kwargs[name]
except KeyError:
raise AttributeError(f"Attribute '{name}' not present in self._kwargs")

@abc.abstractmethod
def build_reward(self) -> tf.Tensor:
"""Computes reward from observation, action and next observation.

Returns:
A tensor containing reward, shape (batch_size,).
"""

@property
def reward(self):
"""Reward tensor, shape (batch_size,)."""
return self._reward
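
Subclasses of `HardcodedReward` only need to implement `build_reward`, reading from the processed observation/action tensors provided by `rewards.BasicRewardModel`. A sketch of the pattern, not part of this PR: the reward itself is made up, and the `_proc_next_obs` attribute name follows its use in `lunar_lander.py` below.

import gym
import tensorflow as tf

from evaluating_rewards.envs import core


class OriginDistanceReward(core.HardcodedReward):
    """Toy hardcoded reward: negative distance of the next state from the origin."""

    def __init__(self, observation_space: gym.Space, action_space: gym.Space,
                 scale: float = 1.0):
        # Extra kwargs are serialized and exposed as attributes via __getattr__.
        super().__init__(observation_space, action_space, scale=scale)

    def build_reward(self) -> tf.Tensor:
        dist = tf.norm(self._proc_next_obs[:, 0:2], axis=-1)
        return -self.scale * dist  # shape (batch_size,)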
169 changes: 169 additions & 0 deletions src/evaluating_rewards/envs/lunar_lander.py
@@ -0,0 +1,169 @@
# Copyright 2020 Adam Gleave
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Reward function for Gym LunarLander-v2 environment."""

from gym import spaces
from gym.envs.box2d import lunar_lander
from imitation.util import registry
import numpy as np
import tensorflow as tf

from evaluating_rewards import serialize as reward_serialize
from evaluating_rewards.envs import core

TERMINAL_POTENTIAL = -174 # chosen to be similar to initial potential value


class LunarLanderContinuousObservable(lunar_lander.LunarLanderContinuous):
"""LunarLander environment lightly modified from Gym to make reward a function of observation.

Adds `self.game_over` and `self.lander.awake` flags to state, which are used by Gym
internally to compute the reward. They are computed by the Box2D simulator, and cannot easily
be derived from the rest of the state.

`game_over` is set based on contact forces on the lunar lander. The `lander.awake` flag is
set when the body is not "asleep":
"When Box2D determines that a body [...] has come to rest, the body enters a sleep state"
(see https://box2d.org/documentation/md__d_1__git_hub_box2d_docs_dynamics.html).
"""

def __init__(self, time_limit: int = 1000, fix_shaping: bool = True):
# Need to set self.time_limit before super().__init__() since __init__() calls reset()
self.time_limit = time_limit
super().__init__()
self.observation_space = spaces.Box(-np.inf, np.inf, shape=(11,), dtype=np.float32)
self.fix_shaping = fix_shaping
self.time_remaining = None

def step(self, action):
prev_shaping = self.prev_shaping
self.time_remaining -= 1
obs, rew, done, info = super().step(action)
time_up = self.time_remaining <= 0
extra_obs = [
1.0 if self.game_over else 0.0,
1.0 if self.lander.awake else 0.0,
1.0 if time_up else 0.0,
]
obs = np.concatenate((obs, extra_obs))

done = done | time_up
if done and self.fix_shaping:
# Gym does not apply shaping or control cost to final reward.
# Omitting the control cost is odd but harmless; omitting shaping is problematic, so we
# fix it to satisfy Ng et al (1999)'s conditions.
# Take the final state to always have potential TERMINAL_POTENTIAL.
# This constant doesn't actually affect the RL policy, but makes the reward look less odd.
rew += TERMINAL_POTENTIAL - prev_shaping

return obs, rew, done, info

def reset(self):
self.time_remaining = self.time_limit + 1 # step() gets called once during reset
# NOTE: do not need to change observations here since super().reset() calls step()
return super().reset()


def _potential(obs: tf.Tensor) -> tf.Tensor:
"""Potential function used to compute shaping.

Based on `shaping` variable in `LunarLander.step()`.
"""
leg_contact = obs[:, 6] + obs[:, 7]
l2 = tf.sqrt(tf.math.square(obs[:, 0]) + tf.math.square(obs[:, 1]))
l2 += tf.sqrt(tf.math.square(obs[:, 2]) + tf.math.square(obs[:, 3]))
return 10 * leg_contact - 100 * l2 - 100 * tf.abs(obs[:, 4])


class LunarLanderContinuousGroundTruthReward(core.HardcodedReward):
"""Reward for LunarLanderContinuousObservable. Matches ground truth with default settings."""

def __init__(
self,
observation_space: spaces.Space,
action_space: spaces.Space,
ctrl_coef: float = 1.0,
shaping_coef: float = 1.0,
):
"""Constructs the reward model.

Args:
observation_space: The observation space of the environment.
action_space: The action space of the environment.
ctrl_coef: Multiplier for the control cost. 1.0 equals ground truth; 0.0 disables.
shaping_coef: Multiplier for potential shaping. 1.0 equals ground truth; 0.0 disables.
"""
super().__init__(
observation_space=observation_space,
action_space=action_space,
ctrl_coef=ctrl_coef,
shaping_coef=shaping_coef,
)

def build_reward(self) -> tf.Tensor:
"""Intended to match the reward returned by gym.LunarLander.

Known differences:
- Will disagree on states *after* episode termination due to non-Markovian leg contact
condition in Gym.

Returns:
A Tensor containing predicted rewards.
"""
# Sparse reward
game_over = (tf.abs(self._proc_next_obs[:, 0]) >= 1.0) | (self._proc_next_obs[:, 8] > 0)
landed_safely = self._proc_next_obs[:, 9] == 0.0
time_up = self._proc_next_obs[:, 10] == 0.0
done = game_over | landed_safely | time_up
# Note time out is neither penalized nor rewarded by sparse_reward
sparse_reward = -100.0 * tf.cast(game_over, tf.float32)
sparse_reward += 100.0 * tf.cast(landed_safely, tf.float32)

# Control cost
m_thrust = self._proc_act[:, 0] > 0
m_power_when_act = 0.5 * (tf.clip_by_value(self._proc_act[:, 0], 0.0, 1.0) + 1.0)
m_power = tf.where(m_thrust, m_power_when_act, 0.0 * m_power_when_act)
abs_side_act = tf.abs(self._proc_act[:, 1])
s_thrust = abs_side_act > 0.5
s_power_when_act = tf.clip_by_value(abs_side_act, 0.5, 1.0)
s_power = tf.where(s_thrust, s_power_when_act, 0.0 * s_power_when_act)
ctrl_cost = -0.3 * m_power - 0.03 * s_power
# Gym does not apply control cost to final step. (Seems weird, but OK.)
ctrl_cost = tf.where(done, 0 * ctrl_cost, ctrl_cost)

# Shaping
# Note this assumes no discount (matching Gym implementation), which will make it
# not *quite* potential shaping for any RL algorithm using discounting.
shaping = (1 - tf.cast(done, tf.float32)) * _potential(self._proc_next_obs)
shaping += TERMINAL_POTENTIAL * tf.cast(done, tf.float32)
shaping -= _potential(self._proc_obs)

return sparse_reward + self.shaping_coef * shaping + self.ctrl_coef * ctrl_cost


def _register_rewards():
density = {"Dense": {}, "Sparse": {"shaping_coef": 0.0}}
control = {"WithCtrl": {}, "NoCtrl": {"ctrl_coef": 0.0}}
for k1, cfg1 in density.items():
for k2, cfg2 in control.items():
fn = registry.build_loader_fn_require_space(
LunarLanderContinuousGroundTruthReward, **cfg1, **cfg2,
)
reward_serialize.reward_registry.register(
key=f"evaluating_rewards/LunarLanderContinuous{k1}{k2}-v0", value=fn,
)


_register_rewards()
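
After `_register_rewards()` runs, four registry keys are available, e.g. `evaluating_rewards/LunarLanderContinuousDenseWithCtrl-v0` and `evaluating_rewards/LunarLanderContinuousSparseNoCtrl-v0`. A construction sketch that builds the ground-truth model directly from an environment's spaces rather than via the registry loader (how the reward tensor is then evaluated depends on `evaluating_rewards.rewards.BasicRewardModel`, which is outside this diff):

import gym
import tensorflow as tf

import evaluating_rewards.envs  # noqa: F401  registers the env IDs
from evaluating_rewards.envs import lunar_lander

env = gym.make("evaluating_rewards/LunarLanderContinuous-v0")
with tf.Graph().as_default():
    # shaping_coef=0.0 with the default ctrl_coef=1.0 corresponds to the
    # "SparseWithCtrl" variant registered above.
    model = lunar_lander.LunarLanderContinuousGroundTruthReward(
        env.observation_space, env.action_space, shaping_coef=0.0,
    )
    reward_tensor = model.reward  # shape (batch_size,)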