Q: "S5 ComplexWarning: Casting complex values to real discards the imaginary part" intended? #3

Open
ConstantinRuhdorfer opened this issue Apr 5, 2024 · 0 comments

Hi 👋,

Thanks for the repo!
I am currently testing out your implementation of S5. Sadly, I am not very familiar with the S5 architecture.
When I run your code, I get this warning:

~/.local/lib/python3.10/site-packages/jax/_src/lax/lax.py:2652: ComplexWarning: Casting complex values to real discards the imaginary part
  x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type)

The warning originates in the PPO loss computation and is related to the complex parameters of the S5 model.
The command I am running is below.
Is this behavior intended? I tried reading up on the literature on S4 and S5, but the answer was not immediately obvious to me, so I have little intuition for what it means to cast complex parameters to float.
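To make the question concrete, here is a minimal JAX sketch of one way such a cast can appear during differentiation. It is only an assumption about the mechanism, not something confirmed from this repository: `real_readout`, `x`, and `lam` are hypothetical names, and the toy function merely multiplies a real input by a complex parameter and keeps the real part. Whether the warning is actually emitted may depend on the JAX version.

```python
import warnings

import jax
import jax.numpy as jnp


def real_readout(x, lam):
    """Hypothetical toy layer: complex parameter, real input, real output."""
    # x is real and lam is complex, so the product is promoted to complex64.
    z = lam * x
    # Keep only the real part explicitly, which makes the loss real-valued.
    return jnp.sum(jnp.real(z))


x = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)
lam = jnp.array([0.5 + 0.1j, -0.2 + 0.7j, 0.3 - 0.4j], dtype=jnp.complex64)

# Record warnings instead of printing them, so they can be inspected.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # The backward pass carries a complex cotangent back to the real input x,
    # which requires converting it to x's real dtype.
    grad_x = jax.grad(real_readout)(x, lam)

print("gradient dtype:", grad_x.dtype)
print("gradient:", grad_x)
print("warnings seen:", [str(w.message) for w in caught])
```

In this toy case the gradient with respect to the real input equals the real part of `lam`, which matches the derivative of Re(lam * x) with respect to x, so the discarded imaginary part does not change the result. That does not by itself show the same holds for the PPO loss in this repo.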

Feedback is appreciated! Thanks!

python3 -m minimax.train \
--seed=1 \
--agent_rl_algo=ppo \
--n_total_updates=30000 \
--train_runner=plr \
--n_devices=1 \
--student_model_name=default_student_cnn \
--env_name=Maze \
--verbose=False \
--log_dir=~/logs/minimax \
--log_interval=10 \
--from_last_checkpoint=True \
--checkpoint_interval=1000 \
--archive_interval=0 \
--archive_init_checkpoint=False \
--test_interval=100 \
--n_students=1 \
--n_parallel=32 \
--n_eval=1 \
--n_rollout_steps=256 \
--lr=3e-05 \
--lr_anneal_steps=0 \
--max_grad_norm=0.5 \
--adam_eps=1e-05 \
--track_env_metrics=True \
--discount=0.999 \
--n_unroll_rollout=10 \
--render=False \
--ued_score=max_mc \
--plr_replay_prob=0.5 \
--plr_buffer_size=4000 \
--plr_staleness_coef=0.3 \
--plr_temp=0.3 \
--plr_use_score_ranks=True \
--plr_min_fill_ratio=0.5 \
--plr_use_robust_plr=True \
--plr_use_parallel_eval=False \
--plr_force_unique=True \
--student_gae_lambda=0.98 \
--student_entropy_coef=0.001 \
--student_value_loss_coef=0.5 \
--student_n_unroll_update=5 \
--student_ppo_n_epochs=5 \
--student_ppo_n_minibatches=1 \
--student_ppo_clip_eps=0.2 \
--student_ppo_clip_value_loss=True \
--student_recurrent_arch=s5 \
--student_recurrent_hidden_dim=256 \
--student_hidden_dim=32 \
--student_n_hidden_layers=1 \
--student_n_conv_filters=16 \
--student_n_scalar_embeddings=4 \
--student_scalar_embed_dim=5 \
--student_s5_n_blocks=2 \
--student_s5_n_layers=2 \
--student_s5_layernorm_pos=pre \
--student_s5_activation=half_glu1 \
--maze_height=13 \
--maze_width=13 \
--maze_n_walls=60 \
--maze_replace_wall_pos=True \
--maze_sample_n_walls=False \
--maze_see_agent=False \
--maze_normalize_obs=True \
--maze_obs_agent_pos=False \
--maze_max_episode_steps=250 \
--test_n_episodes=10 \
--test_env_names=Maze-SixteenRooms,Maze-Labyrinth,Maze-StandardMaze \
--maze_test_see_agent=False \
--maze_test_normalize_obs=True \
--xpid=plr-maze13x13w60na_f-rf_p0.5b4000t0.3s0.3m0.5r_r1s_32p_1e_256t_ae1e-05_smm-ppo_lr3e-05g0.999cv0.5ce0.001e5mb1l0.98_pc0.2_h32cf16fc1se5ba_re_lpr_ahg1_s5_h256nb2nl2_0