Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
QuanyiLi committed Feb 17, 2024
1 parent de90f4a commit 036be49
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 19 deletions.
2 changes: 1 addition & 1 deletion metadrive/engine/base_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ def step(self, step_num: int = 1) -> None:

if self.force_fps.real_time_simulation and i < step_num - 1:
self.task_manager.step()

# Do rendering
self.task_manager.step()
if self.on_screen_message is not None:
Expand Down
10 changes: 6 additions & 4 deletions metadrive/manager/scenario_data_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,17 +136,19 @@ def _score(scenario_id):
id_score_scenarios = [(s_id, *_score(s_id)) for s_id in self.summary_lookup[start:end]]
id_score_scenarios = sorted(id_score_scenarios, key=lambda scenario: scenario[-2])
self.summary_lookup[start:end] = [id_score_scenario[0] for id_score_scenario in id_score_scenarios]
self.scenario_difficulty = {id_score_scenario[0]: id_score_scenario[1] for id_score_scenario in
id_score_scenarios}
self._scenarios = {i+start: id_score_scenario[-1] for i, id_score_scenario in enumerate(id_score_scenarios)}
self.scenario_difficulty = {
id_score_scenario[0]: id_score_scenario[1]
for id_score_scenario in id_score_scenarios
}
self._scenarios = {i + start: id_score_scenario[-1] for i, id_score_scenario in enumerate(id_score_scenarios)}

def clear_stored_scenarios(self):
self._scenarios = {}

@property
def current_scenario_difficulty(self):
return self.scenario_difficulty[self.summary_lookup[self.engine.global_random_seed]
] if self.scenario_difficulty is not None else 0
] if self.scenario_difficulty is not None else 0

@property
def current_scenario_id(self):
Expand Down
16 changes: 8 additions & 8 deletions metadrive/scenario/scenario_description.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,9 +235,9 @@ def sanity_check(cls, scenario_dict, check_self_type=False, valid_check=False):
)
# position heading check
assert ScenarioDescription.HEADING in obj_state[ScenarioDescription.STATE
], "heading is required for an object"
], "heading is required for an object"
assert ScenarioDescription.POSITION in obj_state[ScenarioDescription.STATE
], "position is required for an object"
], "position is required for an object"

# Check dynamic_map_state
assert isinstance(scenario_dict[cls.DYNAMIC_MAP_STATES], dict)
Expand All @@ -254,7 +254,7 @@ def sanity_check(cls, scenario_dict, check_self_type=False, valid_check=False):
"You lack these keys in metadata: {}".format(
cls.METADATA_KEYS.difference(set(scenario_dict[cls.METADATA].keys()))
)
assert scenario_dict[cls.METADATA][cls.TIMESTEP].shape == (scenario_length,)
assert scenario_dict[cls.METADATA][cls.TIMESTEP].shape == (scenario_length, )

@classmethod
def _check_map_features(cls, map_feature):
Expand Down Expand Up @@ -299,7 +299,7 @@ def _check_object_state_dict(cls, obj_state, scenario_length, object_id, valid_c
assert state_array.ndim in [1, 2], "Haven't implemented test array with dim {} yet".format(state_array.ndim)
if state_array.ndim == 2:
assert state_array.shape[
1] != 0, "Please convert all state with dim 1 to a 1D array instead of 2D array."
1] != 0, "Please convert all state with dim 1 to a 1D array instead of 2D array."

if state_key == "valid" and valid_check:
assert np.sum(state_array) >= 1, "No frame valid for this object. Consider removing it"
Expand Down Expand Up @@ -488,16 +488,16 @@ def get_number_summary(scenario):
dynamic_object_states_types.add(step_state)
dynamic_object_states_counter[step_state] += 1
number_summary_dict[ScenarioDescription.SUMMARY.NUM_TRAFFIC_LIGHTS
] = len(scenario[ScenarioDescription.DYNAMIC_MAP_STATES])
] = len(scenario[ScenarioDescription.DYNAMIC_MAP_STATES])
number_summary_dict[ScenarioDescription.SUMMARY.NUM_TRAFFIC_LIGHT_TYPES] = dynamic_object_states_types
number_summary_dict[ScenarioDescription.SUMMARY.NUM_TRAFFIC_LIGHTS_EACH_STEP
] = dict(dynamic_object_states_counter)
] = dict(dynamic_object_states_counter)

# map
number_summary_dict[ScenarioDescription.SUMMARY.NUM_MAP_FEATURES
] = len(scenario[ScenarioDescription.MAP_FEATURES])
] = len(scenario[ScenarioDescription.MAP_FEATURES])
number_summary_dict[ScenarioDescription.SUMMARY.MAP_HEIGHT_DIFF
] = ScenarioDescription.map_height_diff(scenario[ScenarioDescription.MAP_FEATURES])
] = ScenarioDescription.map_height_diff(scenario[ScenarioDescription.MAP_FEATURES])
return number_summary_dict

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,9 @@ def test_export_metadrive_scenario_easy(num_scenarios=5, render_export_env=False
shutil.rmtree(dir1)

for scenario_id in scenarios_restored:
o = scenarios_restored[scenario_id]["metadata"]["history_metadata"].get("old_origin_in_current_coordinate",
np.array([0, 0]))
o = scenarios_restored[scenario_id]["metadata"]["history_metadata"].get(
"old_origin_in_current_coordinate", np.array([0, 0])
)
scenarios_restored[scenario_id] = SD.offset_scenario_with_new_origin(scenarios_restored[scenario_id], o)

assert_scenario_equal(scenarios, scenarios_restored, only_compare_sdc=False)
Expand Down Expand Up @@ -175,8 +176,9 @@ def test_export_metadrive_scenario_hard(start_seed=0, num_scenarios=3, render_ex
shutil.rmtree(dir1)

for scenario_id in scenarios_restored:
o = scenarios_restored[scenario_id]["metadata"]["history_metadata"].get("old_origin_in_current_coordinate",
np.array([0, 0]))
o = scenarios_restored[scenario_id]["metadata"]["history_metadata"].get(
"old_origin_in_current_coordinate", np.array([0, 0])
)
scenarios_restored[scenario_id] = SD.offset_scenario_with_new_origin(scenarios_restored[scenario_id], o)

assert_scenario_equal(scenarios, scenarios_restored, only_compare_sdc=False)
Expand Down Expand Up @@ -391,8 +393,9 @@ def test_waymo_export_and_original_consistency(num_scenarios=3, render_export_en
policy, scenario_index=[i for i in range(num_scenarios)], verbose=True
)
for scenario_id in scenarios:
o = scenarios[scenario_id]["metadata"]["history_metadata"].get("old_origin_in_current_coordinate",
np.array([0, 0]))
o = scenarios[scenario_id]["metadata"]["history_metadata"].get(
"old_origin_in_current_coordinate", np.array([0, 0])
)
scenarios[scenario_id] = SD.offset_scenario_with_new_origin(scenarios[scenario_id], o)
compare_exported_scenario_with_origin(scenarios, env.engine.data_manager)
finally:
Expand Down

0 comments on commit 036be49

Please sign in to comment.