
Commit 728bb1b

Merge remote-tracking branch 'origin/master' into hidim-interpretability

2 parents: aab2933 + 7489905

32 files changed: +496 -255 lines

runners/common.sh

Lines changed: 5 additions & 3 deletions

@@ -23,12 +23,12 @@ function call_script {
 function learnt_model {
     if [[ $# -ne 1 ]]; then
         echo "usage: $0 <model prefix>"
-        echo "model prefix must be relative to ${OUTPUT_ROOT}"
+        echo "model prefix must be relative to ${EVAL_OUTPUT_ROOT}"
         exit 1
     fi

     model_prefix=$1
-    learnt_model_dir=${OUTPUT_ROOT}/${model_prefix}
+    learnt_model_dir=${EVAL_OUTPUT_ROOT}/${model_prefix}

     case ${model_prefix} in
         train_adversarial)
@@ -53,4 +53,6 @@ eval "$(${ENV_REWARD_CMD} 2>/dev/null)"
 ENVS="${!REWARDS_BY_ENV[@]}"
 echo "Loaded mappings for environments ${ENVS}"

-OUTPUT_ROOT=/mnt/eval_reward/data
+if [[ "${EVAL_OUTPUT_ROOT}" == "" ]]; then
+    EVAL_OUTPUT_ROOT=$HOME/output
+fi
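The new fallback assigns a default only when EVAL_OUTPUT_ROOT is unset or empty, so the previously hard-coded /mnt/eval_reward/data root can now be supplied from the caller's environment. The same guard can be written with bash's `:-` default expansion; a one-line sketch of the equivalent idiom (an alternative, not what the commit uses):

    # Keep any caller-provided value; otherwise fall back to $HOME/output.
    EVAL_OUTPUT_ROOT=${EVAL_OUTPUT_ROOT:-$HOME/output}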

runners/comparison/hardcoded.sh

Lines changed: 9 additions & 3 deletions

@@ -26,12 +26,18 @@ for env_name in "${!REWARDS_BY_ENV[@]}"; do
     types=${REWARDS_BY_ENV[$env_name]}
     env_name_sanitized=$(echo ${env_name} | sed -e 's/\//_/g')
     types_sanitized=$(echo ${types} | sed -e 's/\//_/g')
-    parallel --header : --results $HOME/output/parallel/comparison/hardcoded_mujoco \
-        ${TRAIN_CMD} env_name=${env_name} \
+
+    named_configs=""
+    if [[ ${env_name} == "evaluating_rewards/PointMassLine-v0" ]]; then
+        named_configs="dataset_random_transition"
+    fi
+
+    parallel --header : --results ${EVAL_OUTPUT_ROOT}/parallel/comparison/hardcoded_mujoco \
+        ${TRAIN_CMD} env_name=${env_name} ${named_configs} \
         seed={seed} \
         source_reward_type={source_reward_type} \
         target_reward_type={target_reward_type} \
-        log_dir=${HOME}/output/comparison/hardcoded/${env_name_sanitized}/{source_reward_type_sanitized}_vs_{target_reward_type_sanitized}_seed{seed} \
+        log_dir=${EVAL_OUTPUT_ROOT}/comparison/hardcoded/${env_name_sanitized}/{source_reward_type_sanitized}_vs_{target_reward_type_sanitized}_seed{seed} \
         ::: source_reward_type ${types} \
         :::+ source_reward_type_sanitized ${types_sanitized} \
         ::: target_reward_type ${types} \
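Throughout these runners, `--header :` tells GNU parallel to treat the first value of each `:::` input source as the placeholder name, and `:::+` zips its values pairwise with the previous source instead of taking the cross product. A minimal illustration (not part of the commit; the values are made up):

    # 2 envs x 2 seeds = 4 jobs; {env} and {seed} are named via --header :
    parallel --header : echo "env={env} seed={seed}" \
        ::: env PointMassLine-v0 PointMassDense-v0 \
        ::: seed 0 1

    # :::+ pairs each type with its sanitized form one-to-one (2 jobs, not 4):
    parallel --header : echo "{type} logs to {type_sanitized}" \
        ::: type evaluating_rewards/PointMassLine-v0 evaluating_rewards/PointMassDense-v0 \
        :::+ type_sanitized evaluating_rewards_PointMassLine-v0 evaluating_rewards_PointMassDense-v0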

runners/comparison/learnt.sh

Lines changed: 2 additions & 2 deletions

@@ -35,12 +35,12 @@ for env_name in ${ENVS}; do
     echo "Models: ${MODELS}"
     echo "Hardcoded rewards: ${types}"

-    parallel --header : --results ${OUTPUT_ROOT}/parallel/comparison/learnt/${env_name_sanitized} \
+    parallel --header : --results ${EVAL_OUTPUT_ROOT}/parallel/comparison/learnt/${env_name_sanitized} \
         ${TRAIN_CMD} env_name=${env_name} seed={seed} \
         source_reward_type=${source_reward_type} \
         source_reward_path=${learnt_model_dir}/${env_name_sanitized}/{source_reward}/${model_name} \
         target_reward_type={target_reward} {named_config} \
-        log_dir=${OUTPUT_ROOT}/comparison/${model_prefix}/${env_name_sanitized}/{source_reward}/match_{named_config}_to_{target_reward_sanitized}_seed{seed} \
+        log_dir=${EVAL_OUTPUT_ROOT}/comparison/${model_prefix}/${env_name_sanitized}/{source_reward}/match_{named_config}_to_{target_reward_sanitized}_seed{seed} \
         ::: source_reward ${MODELS} \
         ::: target_reward ${types} \
         :::+ target_reward_sanitized ${types_sanitized} \

runners/eval/greedy_pm_hardcoded.sh

Lines changed: 1 addition & 1 deletion

@@ -20,7 +20,7 @@ GREEDY_REWARD_MODELS="PointMassGroundTruth-v0:None \
 PointMassSparse-v0:None \
 PointMassDense-v0:None"

-parallel --header : --results $HOME/output/parallel/greedy_pm_hardcoded \
+parallel --header : --results ${EVAL_OUTPUT_ROOT}/parallel/greedy_pm_hardcoded \
     ${EVAL_POLICY_CMD} policy_type=evaluating_rewards/MCGreedy-v0 \
     env_name={env} policy_path={policy_path} \
     ::: env ${PM_ENVS} \

runners/eval/greedy_pm_irl.sh

Lines changed: 2 additions & 2 deletions

@@ -18,13 +18,13 @@ DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"

 for env in ${ENVS}; do
     env_sanitized=$(echo ${env} | sed -e 's/\//_/g')
-    reward_paths=$HOME/output/train_adversarial/${env_sanitized}/*/final/discrim/reward_net
+    reward_paths=${EVAL_OUTPUT_ROOT}/train_adversarial/${env_sanitized}/*/final/discrim/reward_net
     policy_paths=""
     for rew_path in ${reward_paths}; do
         policy_paths="${policy_paths} BasicShapedRewardNet_shaped:${rew_path}"
         policy_paths="${policy_paths} BasicShapedRewardNet_unshaped:${rew_path}"
     done
-    parallel --header : --results $HOME/output/parallel/greedy_pm_irl \
+    parallel --header : --results ${EVAL_OUTPUT_ROOT}/parallel/greedy_pm_irl \
         ${EVAL_POLICY_CMD} env_name=${env} policy_type=evaluating_rewards/MCGreedy-v0 \
         policy_path={policy_path} \
         ::: policy_path ${policy_paths}
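The `sed -e 's/\//_/g'` pipeline used here (and in the other runners) to turn slash-separated environment names into path-safe directory names could also be done with bash's built-in pattern substitution, avoiding a subshell; a sketch with a hypothetical value:

    env="evaluating_rewards/PointMassLine-v0"  # example value
    env_sanitized=${env//\//_}                 # replace every / with _
    echo "${env_sanitized}"                    # evaluating_rewards_PointMassLine-v0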

runners/eval/learnt.sh

Lines changed: 4 additions & 4 deletions

@@ -18,12 +18,12 @@ DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"

 if [[ $# -ne 1 ]]; then
     echo "usage: $0 <policy prefix>"
-    echo "policy prefix must be relative to ${OUTPUT_ROOT}"
+    echo "policy prefix must be relative to ${EVAL_OUTPUT_ROOT}"
     exit 1
 fi

 policy_prefix=$1
-policy_dir=${OUTPUT_ROOT}/${policy_prefix}
+policy_dir=${EVAL_OUTPUT_ROOT}/${policy_prefix}
 model_name="policies/final"

 for env_name in ${ENVS}; do
@@ -38,11 +38,11 @@ for env_name in ${ENVS}; do
     echo "Policies: ${policies}"
     echo "Hardcoded rewards: ${types}"

-    parallel --header : --results $HOME/output/parallel/learnt \
+    parallel --header : --results ${EVAL_OUTPUT_ROOT}/parallel/learnt \
         ${EVAL_POLICY_CMD} env_name=${env_name} policy_type=ppo2 \
         reward_type={reward_type} \
         policy_path=${policy_dir}/${env_name_sanitized}/{policy_path}/${model_name} \
-        log_dir=${OUTPUT_ROOT}/eval/${policy_prefix}/${env_name_sanitized}/{policy_path}/eval_under_{reward_type_sanitized} \
+        log_dir=${EVAL_OUTPUT_ROOT}/eval/${policy_prefix}/${env_name_sanitized}/{policy_path}/eval_under_{reward_type_sanitized} \
         ::: reward_type ${types} \
         :::+ reward_type_sanitized ${types_sanitized} \
         ::: policy_path ${policies}

runners/eval/static.sh

Lines changed: 1 addition & 1 deletion

@@ -18,7 +18,7 @@ DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"

 POLICY_TYPES="random zero"

-parallel --header : --results $HOME/output/parallel/static \
+parallel --header : --results ${EVAL_OUTPUT_ROOT}/parallel/static \
     ${EVAL_POLICY_CMD} env_name={env} policy_type={policy_type} \
     ::: env ${ENVS} \
     ::: policy_type ${POLICY_TYPES}
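Assuming these runners source runners/common.sh (which now defaults EVAL_OUTPUT_ROOT to $HOME/output), a caller can redirect all results per invocation, e.g. back to the old hard-coded location; a hypothetical usage sketch:

    # One-off run writing under the previous /mnt/eval_reward/data root:
    EVAL_OUTPUT_ROOT=/mnt/eval_reward/data runners/eval/static.sh

    # Default behaviour, equivalent to EVAL_OUTPUT_ROOT=$HOME/output:
    runners/eval/static.sh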

runners/irl/train_irl.sh

Lines changed: 2 additions & 2 deletions

@@ -20,11 +20,11 @@ TRAIN_CMD=$(call_script "train_adversarial" "with")

 for env in ${ENVS}; do
     env_sanitized=$(echo ${env} | sed -e 's/\//_/g')
-    parallel --header : --results $HOME/output/parallel/train_irl \
+    parallel --header : --results ${EVAL_OUTPUT_ROOT}/parallel/train_irl \
         ${TRAIN_CMD} env_name=${env} seed={seed} \
         init_trainer_kwargs.reward_kwargs.state_only={state_only} \
         rollout_path={data_path}/rollouts/final.pkl \
-        ::: data_path $HOME/output/expert_demos/${env_sanitized}/* \
+        ::: data_path ${EVAL_OUTPUT_ROOT}/expert_demos/${env_sanitized}/* \
         ::: state_only True False \
         ::: seed 0 1 2
 done

runners/preferences/hyper_sweep.sh

Lines changed: 2 additions & 2 deletions

@@ -23,12 +23,12 @@ PointMassDense-v0
 PointMassSparse-v0
 "

-parallel --header : --results $HOME/output/parallel/train_preferences_hyper \
+parallel --header : --results ${EVAL_OUTPUT_ROOT}/parallel/train_preferences_hyper \
     ${TRAIN_CMD} env_name=evaluating_rewards/PointMassLine-v0 \
     seed={seed} target_reward_type=evaluating_rewards/{target_reward} \
     batch_timesteps={batch_timesteps} trajectory_length={trajectory_length} \
     learning_rate={lr} total_timesteps=5e6 \
-    log_dir=${HOME}/output/train_preferences_hyper/{target_reward}/batch{batch_timesteps}_of_{trajectory_length}_lr{lr}/{seed} \
+    log_dir=${EVAL_OUTPUT_ROOT}/train_preferences_hyper/{target_reward}/batch{batch_timesteps}_of_{trajectory_length}_lr{lr}/{seed} \
     ::: target_reward ${TARGET_REWARDS} \
     ::: batch_timesteps 500 2500 10000 50000 250000 \
     ::: trajectory_length 1 5 25 100 \

runners/preferences/train_preferences.sh

Lines changed: 2 additions & 2 deletions

@@ -25,10 +25,10 @@ for env_name in "${!REWARDS_BY_ENV[@]}"; do
     env_name_sanitized=$(echo ${env_name} | sed -e 's/\//_/g')
     types_sanitized=$(echo ${types} | sed -e 's/\//_/g')

-    parallel --header : --results $HOME/output/parallel/train_preferences/${env_name} \
+    parallel --header : --results ${EVAL_OUTPUT_ROOT}/parallel/train_preferences/${env_name} \
         ${TRAIN_CMD} env_name=${env_name} \
         seed={seed} target_reward_type={target_reward} \
-        log_dir=${HOME}/output/train_preferences/${env_name_sanitized}/{target_reward_sanitized}/{seed} \
+        log_dir=${EVAL_OUTPUT_ROOT}/train_preferences/${env_name_sanitized}/{target_reward_sanitized}/{seed} \
         ::: target_reward ${types} \
         :::+ target_reward_sanitized ${types_sanitized} \
         ::: seed 0 1 2
