Skip to content

Commit

Permalink
fix expert bug, cover side_detector and lane_line_detector setting (#663
Browse files Browse the repository at this point in the history
)
  • Loading branch information
CarlDegio authored Mar 20, 2024
1 parent e2999f0 commit 6bf4267
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
8 changes: 4 additions & 4 deletions metadrive/examples/ppo_expert/numpy_expert.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ def expert(vehicle, deterministic=False, need_obs=False):
global _expert_observation
expert_obs_cfg = dict(
lidar=dict(num_lasers=240, distance=50, num_others=4, gaussian_noise=0.0, dropout_prob=0.0),
side_detector=dict(num_lasers=0, distance=50, gaussian_noise=0.0, dropout_prob=0.0),
lane_line_detector=dict(num_lasers=0, distance=20, gaussian_noise=0.0, dropout_prob=0.0),
random_agent_model=False
)
origin_obs_cfg = dict(
lidar=dict(num_lasers=240, distance=50, num_others=0, gaussian_noise=0.0, dropout_prob=0.0),
random_agent_model=False
)
origin_obs_cfg = vehicle.config.copy()
# TODO: some setting in origin cfg will not be covered, then they may change the obs shape

if _expert_weights is None:
_expert_weights = np.load(ckpt_path)
Expand Down
8 changes: 4 additions & 4 deletions metadrive/examples/ppo_expert/torch_expert.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,12 @@ def torch_expert(vehicle, deterministic=False, need_obs=False):
global _expert_observation
expert_obs_cfg = dict(
lidar=dict(num_lasers=240, distance=50, num_others=4, gaussian_noise=0.0, dropout_prob=0.0),
side_detector=dict(num_lasers=0, distance=50, gaussian_noise=0.0, dropout_prob=0.0),
lane_line_detector=dict(num_lasers=0, distance=20, gaussian_noise=0.0, dropout_prob=0.0),
random_agent_model=False
)
origin_obs_cfg = dict(
lidar=dict(num_lasers=240, distance=50, num_others=0, gaussian_noise=0.0, dropout_prob=0.0),
random_agent_model=False
)
origin_obs_cfg = vehicle.config.copy()
# TODO: some setting in origin cfg will not be covered, then they may change the obs shape
with torch.no_grad(): # Disable gradient computation
if _expert_weights is None:
_expert_weights = numpy_to_torch(np.load(ckpt_path), device)
Expand Down

0 comments on commit 6bf4267

Please sign in to comment.