Understanding RL agents using generative visualisation and differentiable environment simulation

This repository contains the models and experiments that are used in the article https://interpreting-rl-behavior.github.io/. The code was originally forked from https://github.com/jbkjr/train-procgen-pytorch which contains code to run the Procgen environments (https://openai.com/blog/procgen-benchmark/) and the PPO agent we interpret in this work.

This README provides instructions for how to replicate the results in our paper.

Overview of steps:

  • Train agent on procgen task
  • Record dataset of real agent-environment rollouts
  • Train generative model on recorded dataset of real agent-environment rollouts
  • Run analyses on recorded dataset of real agent-environment rollouts
  • Record dataset of simulated agent-environment rollouts from the generative model
  • Run analyses on the recorded simulated rollouts
  • Analyse the prediction quality over time

All scripts should be run from the root dir.

First, install the package locally in editable mode:

pip install -e .

To train the agent on coinrun:

python train.py --exp_name [agent_training_experiment_name] --env_name coinrun --param_name hard-rec --num_levels 1000000 --distribution_mode hard --num_timesteps 200000000 --num_checkpoints 500

This will save training data and model checkpoints in

logs/procgen/coinrun/[agent_training_experiment_name]/

Each training run has a unique seed, so each seed gets its own directory in the above folder like so:

logs/procgen/coinrun/[agent_training_experiment_name]/[agent_training_unique_seed]

Then to plot the training curve for that training run:

python plot_training_csv.py --datapath="logs/procgen/coinrun/[agent_training_experiment_name]/[agent_training_unique_seed]"
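For reference, here is a minimal sketch of what such a plot involves, assuming the run writes a CSV log; the file name log.csv and the column names below are assumptions rather than the script's actual schema:

import pandas as pd
import matplotlib.pyplot as plt

# Minimal sketch of plotting a training curve from a logged CSV.
# The file name ("log.csv") and column names ("timesteps", "mean_episode_rewards")
# are assumptions; check the files written by train.py for the actual names.
df = pd.read_csv(
    "logs/procgen/coinrun/[agent_training_experiment_name]/[agent_training_unique_seed]/log.csv"
)
plt.plot(df["timesteps"], df["mean_episode_rewards"])
plt.xlabel("Environment timesteps")
plt.ylabel("Mean episode reward")
plt.show()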

You can render your trained agent to see what its behaviour looks like:

python render.py --exp_name=[agent_rendering_experiment_name] --env_name="coinrun" --distribution_mode="hard" --param_name="hard-local-dev-rec" --device="cpu" --model_file="logs/procgen/coinrun/[agent_training_experiment_name]/[agent_training_unique_seed]/[agent_name].pth"

Assuming your agent is behaving as you'd like it to, we can now start interpreting it.

Making recordings and training the generative model

To begin interpretation, we need to record a bunch of agent-environment rollouts in order to train the generative model:

python record.py --exp_name [recording_experiment_name] --env_name coinrun --param_name hard-rec --num_levels 1000000 --distribution_mode hard --num_checkpoints 200 --model_file="logs/procgen/coinrun/[agent_training_experiment_name]/[agent_training_unique_seed]/[agent_name].pth" --logdir="[path_to_rollout_data_save_dir]"

For example:

python record.py --model_file=./logs/procgen/coinrun/trainhx_1Mlvls/seed_498_07-06-2021_23-26-27/model_80412672.pth --logdir=./ --env_name coinrun --param_name hard-rec-record --num_levels 1000000 --distribution_mode hard --num_checkpoints 200

Note that --logdir should point to a location with plenty of storage space (hundreds of GB).

With this recorded data, we can start to train the generative model on agent-environment rollouts:

python train_gen_model.py --agent_file=./logs/procgen/coinrun/trainhx_1Mlvls/seed_498_07-06-2021_23-26-27/model_80412672.pth --gen_mod_exp_name=dev --model_file="generative/results/rssm53_largepos_sim_penalty_extraconverterlayers/20220106_181406/model_epoch3_batch20000.pt"

That'll take 1-4 days to train on a single GPU. Once it's trained, we'll record some agent-environment rollouts from the model. This will enable us to compare the simulations to the true rollouts and will help us better understand our generative model (which includes the agent we want to interpret). This is how we record samples from the generative model:

python record_gen_samples.py --agent_file=./logs/procgen/coinrun/trainhx_1Mlvls/seed_498_07-06-2021_23-26-27/model_80412672.pth --gen_mod_exp_name=dev --model_file="generative/results/rssm53_largepos_sim_penalty_extraconverterlayers/20220106_181406/model_epoch3_batch20000.pt"

Now we're ready to start some analysis.

Analysis

The generative model is a VAE, and therefore consists of an encoder and a decoder. The decoder is the part we want to interpret because it simulates agent-environment rollouts. It is informative, therefore, to get a picture of what's going on inside the latent vector of the VAE, since this is the input to the decoder.
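Purely as an illustration of this structure (the architecture, layer sizes, and shapes below are made up and are not the repository's model), a generic VAE forward pass shows why the latent vector is the natural object to analyse: everything the decoder produces depends only on it.

import torch
import torch.nn as nn

# Generic, illustrative VAE (NOT the repo's architecture; sizes are made up).
latent_dim = 128
encoder = nn.Sequential(nn.Flatten(), nn.Linear(64 * 64 * 3, 2 * latent_dim))
decoder = nn.Sequential(nn.Linear(latent_dim, 64 * 64 * 3), nn.Unflatten(1, (3, 64, 64)))

obs = torch.randn(8, 3, 64, 64)                       # a batch of fake observations
mu, logvar = encoder(obs).chunk(2, dim=1)             # encoder outputs the latent distribution
z = mu + torch.randn_like(mu) * (0.5 * logvar).exp()  # reparameterisation trick
reconstruction = decoder(z)                           # the decoder sees only z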

Analysis of the agent's hidden states

We'll next analyse the agent's hidden state with a few dimensionality reduction methods. First we precompute the dimensionality reduction analyses:

python analysis/hidden_analysis_precompute.py

Run this with 10,000 episodes (not samples); increase your memory and compute-time requests to cope with more episodes.

The script saves the analysis data in analysis/hx_analysis_precomp/
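As a rough illustration of the kind of computation involved (this is not the script's code; the file names and array layout are assumptions), applying PCA to a stack of recorded hidden states looks roughly like this:

import numpy as np
from sklearn.decomposition import PCA

# Illustrative only: file names and shapes are assumptions. The real script
# runs several dimensionality reduction methods and writes its outputs to
# analysis/hx_analysis_precomp/.
hx = np.load("recorded_rollouts/agent_hx.npy")   # hypothetical path; shape (num_steps, hx_dim)

pca = PCA(n_components=16)
hx_pcs = pca.fit_transform(hx)                   # per-timestep principal-component scores
np.save("analysis/hx_analysis_precomp/hx_pca_16.npy", hx_pcs)
print("cumulative explained variance:", pca.explained_variance_ratio_.cumsum())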

Next we'll make some plots from the precomputed analyses of the agent's hidden states:

python analysis/hidden_analysis_plotting.py

These depict what the agent is 'thinking' during many episodes, visualised using several different dimensionality reduction and clustering methods.

Analysis of environment hidden states

The equivalent precompute step for the environment's hidden states is:

python analysis/env_h_analysis_precompute.py

Run this with 20,000 samples of length 24; increase your memory and compute-time requests to cope with more samples.

Then plot the results:

python analysis/env_h_analysis_plotting.py

Calculating saliency maps

Saliency maps show the gradient (averaged over noised samples) of some network quantity (e.g. the agent's value function output) with respect to inputs or intermediate network activations. We can thus estimate how important each dimension of the generated observations or of the agent's hidden state is for, say, the value function.
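As a hedged sketch of the general technique (not the implementation in saliency_experiments.py), averaging gradients of a scalar quantity over noised copies of the input looks like this; value_fn and x are stand-ins for, e.g., the agent's value head and an observation or hidden state:

import torch

def smoothed_saliency(value_fn, x, n_samples=20, noise_std=0.1):
    """Average gradient of a scalar-valued value_fn w.r.t. noised copies of x."""
    grads = torch.zeros_like(x)
    for _ in range(n_samples):
        noised = (x + noise_std * torch.randn_like(x)).requires_grad_(True)
        value_fn(noised).sum().backward()
        grads += noised.grad
    return grads / n_samples

# Toy usage with a stand-in "value function":
x = torch.randn(1, 3, 64, 64)
saliency_map = smoothed_saliency(lambda obs: obs.pow(2).mean(dim=(1, 2, 3)), x)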

Say we wanted to generate saliency maps with respect to value and leftwards actions, specifically for the generated samples numbered 33, 39, 56, and 84. We'd use the following command:

python saliency_experiments.py --agent_file=./logs/procgen/coinrun/trainhx_1Mlvls/seed_498_07-06-2021_23-26-27/model_80412672.pth --gen_mod_exp_name=dev --model_file="generative/results/rssm53_largepos_sim_penalty_extraconverterlayers/20220106_181406/model_epoch3_batch20000.pt"

If we wanted to generate saliency maps for the same quantities but combine those samples into one sample by taking their mean latent space vector (instead of iterating over each sample individually), we'd add the flag --combine_samples_not_iterate

If we wanted to generate saliency maps for all samples from 0 to 100, we'd replace the --sample_ids 33 39 56 84 flag with --sample_ids 0 to 100.

Identifying causal stories for behaviours

After we've calculated the saliency maps, we can use them to identify the causal structure of the control algorithm used by the agent.

First we cluster the agent-environment dynamics. These clusters correspond to behaviours.

python analysis/combined_agent_env_hx_analysis_precompute.py
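Conceptually (and only as an assumption-laden sketch with made-up file names and an arbitrary choice of clustering algorithm, not the script's actual method), clustering combined agent and environment hidden-state features looks something like this:

import numpy as np
from sklearn.cluster import KMeans

# Conceptual sketch only: file names, shapes, and the use of KMeans are
# assumptions for illustration.
agent_hx = np.load("analysis/hx_analysis_precomp/hx_pca_16.npy")      # hypothetical
env_h = np.load("analysis/env_h_analysis_precomp/env_pca_16.npy")     # hypothetical
features = np.concatenate([agent_hx, env_h], axis=1)

behaviour_labels = KMeans(n_clusters=20, n_init=10).fit_predict(features)
np.save("analysis/behaviour_labels.npy", behaviour_labels)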

(Now would be a good time to look at the interpretability panel since we've just generated everything it needs to run.)

Next we summarise the IC dynamics for each behaviour and plot them using cross-correlation (xcorr) plots between ICs at each timestep:

python xcorr_and_xcaus_plots.py
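As a rough sketch of the underlying computation (not the script itself), a time-lagged cross-correlation between two IC time series can be computed like this:

import numpy as np

def lagged_xcorr(ic_a, ic_b, lag=1):
    """Correlation between ic_a at time t and ic_b at time t + lag."""
    if lag > 0:
        ic_a, ic_b = ic_a[:-lag], ic_b[lag:]
    return np.corrcoef(ic_a, ic_b)[0, 1]

# Toy usage: two ICs over a 24-step sample.
ic_a, ic_b = np.random.randn(2, 24)
print(lagged_xcorr(ic_a, ic_b, lag=1))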

Then we compare the magnitude and sign of the corresponding entries in the cross-correlation and Jacobian matrices to identify where gradients are consistent with the dynamics, both with and without passing gradients through the environment.

python analysis/dynamics_grads_consistency_plots.py
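The comparison itself amounts to checking whether corresponding entries agree in sign (and roughly in magnitude). A minimal sketch, assuming the two matrices have already been computed and saved (the file names here are assumptions):

import numpy as np

# File names are assumptions for illustration.
xcorr = np.load("analysis/xcorr_matrix.npy")
jacobian = np.load("analysis/jacobian_matrix.npy")

sign_agreement = np.sign(xcorr) == np.sign(jacobian)
print(f"{sign_agreement.mean():.1%} of entries agree in sign")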

Validating hypotheses by controlling the dynamics

If our hypotheses about the role of different directions in hidden-state space are correct, we should be able to make predictions about how the agent should behave when those directions are altered.

We can record the hidden states while either swapping different directions in hidden-state-space or collapsing directions into the nullspace so that the agent can't use those directions.

We can use the record_informinit_gen_samples.py script to do this.

By default, the CLI arguments for --swap_directions_from and --swap_directions_to are empty. If we want to swap the 10th hx direction with the 12th hx direction and at the same time collapse the 5th hx direction into the nullspace, we simply add the arguments

--swap_directions_from 10 5 --swap_directions_to 12 None
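Under the hood this kind of intervention is just a linear transformation of the hidden state. As a conceptual sketch (not the script's code; the basis below is a toy stand-in for, e.g., directions found by PCA/ICA over recorded hidden states):

import numpy as np

def intervene(hx, directions, swap_from, swap_to, collapse):
    """Swap two direction projections and zero a third ("collapse into the nullspace")."""
    coords = directions @ hx                                     # project hx onto the basis
    coords[[swap_from, swap_to]] = coords[[swap_to, swap_from]]  # swap the two directions
    coords[collapse] = 0.0                                       # the agent can't use this one
    return directions.T @ coords                                 # map back to hidden-state space

hx = np.random.randn(64)                      # toy hidden state
q, _ = np.linalg.qr(np.random.randn(64, 64))  # toy orthonormal basis; rows of q.T are directions
hx_altered = intervene(hx, q.T, swap_from=10, swap_to=12, collapse=5)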

It's also advisable to change the directory that the recordings are saved to, so as not to overwrite previous data from the unaltered agent hx dynamics. To do this, add something like:

--data_save_dir=generative/recorded_validations_swapping