fix linter
Limmen committed Jan 29, 2024
1 parent: 2ccd30d, commit: 0a0e37c
Showing 8 changed files with 221 additions and 66 deletions.
@@ -10,7 +10,6 @@
from gym_csle_stopping_game.util.stopping_game_util import StoppingGameUtil
from csle_agents.common.objective_type import ObjectiveType
from csle_common.dao.training.random_policy import RandomPolicy
from csle_common.dao.training.multi_threshold_stopping_policy import MultiThresholdStoppingPolicy
from gym_csle_stopping_game.dao.stopping_game_config import StoppingGameConfig
from gym_csle_stopping_game.dao.stopping_game_defender_pomdp_config import StoppingGameDefenderPomdpConfig

@@ -157,15 +157,14 @@ def simulate(self, state: int, max_rollout_depth: int, c: float, history: List[i
# If a new node was created, then it has no children, in which case we should stop the search and
# do a Monte-Carlo rollout with a given base policy to estimate the value of the node
if not current_node.children:
import torch
import math
# since the node does not have any children, we first add them to the node
obs_vector = self.env.get_observation_from_history(current_node.history)
dist = self.rollout_policy.model.policy.get_distribution(obs=torch.tensor([obs_vector]).to(self.rollout_policy.model.device)).log_prob(
torch.tensor(self.A).to(self.rollout_policy.model.device)).cpu().detach().numpy()
dist = list(map(lambda i: (math.exp(dist[i]), self.A[i]), list(range(len(dist)))))
rollout_actions = list(map(lambda x: x[1], sorted(dist, reverse=True, key=lambda x: x[0])[:3]))
# rollout_actions = [27, 28, 29, 30, 31, 32, 33, 34, 35]
# obs_vector = self.env.get_observation_from_history(current_node.history)
# dist = self.rollout_policy.model.policy.get_distribution(
# obs=torch.tensor([obs_vector]).to(self.rollout_policy.model.device)).log_prob(
# torch.tensor(self.A).to(self.rollout_policy.model.device)).cpu().detach().numpy()
# dist = list(map(lambda i: (math.exp(dist[i]), self.A[i]), list(range(len(dist)))))
# rollout_actions = list(map(lambda x: x[1], sorted(dist, reverse=True, key=lambda x: x[0])[:3]))
rollout_actions = self.A
# for action in self.A:
for action in rollout_actions:
self.tree.add(history + [action], parent=current_node, action=action, value=self.default_node_value)
@@ -179,8 +178,8 @@
return float(R), depth
else:
self.env.set_state(state=state)
return self.rollout(state=state, history=history, depth=depth, max_rollout_depth=max_rollout_depth), \
depth
return (self.rollout(state=state, history=history, depth=depth, max_rollout_depth=max_rollout_depth),
depth)

# If we have not yet reached a new node, we select the next action according to the
# UCB strategy
@@ -245,8 +244,9 @@ def solve(self, max_rollout_depth: int, max_planning_depth: int) -> None:
while time.time() - begin < self.planning_time:
n += 1
state = self.tree.root.sample_state()
_, depth = self.simulate(state=state, max_rollout_depth=max_rollout_depth, history=self.tree.root.history, c=self.c,
parent=self.tree.root, max_planning_depth=max_planning_depth, depth=0)
_, depth = self.simulate(state=state, max_rollout_depth=max_rollout_depth, history=self.tree.root.history,
c=self.c,
parent=self.tree.root, max_planning_depth=max_planning_depth, depth=0)
if self.verbose:
action_values = np.zeros((len(self.A),))
best_action_idx = 0
@@ -262,7 +262,7 @@ def solve(self, max_rollout_depth: int, max_planning_depth: int) -> None:
f"value: {self.tree.root.children[best_action_idx].value}, "
f"count: {self.tree.root.children[best_action_idx].visit_count}, "
f"planning depth: {depth}")
#, 31:{action_values[31]}
# , 31:{action_values[31]}

def get_action(self) -> int:
"""
Expand Down Expand Up @@ -340,7 +340,6 @@ def update_tree_with_new_samples(self, action_sequence: List[int], observation:
Logger.__call__().get_logger().info(f"{observation in self.particle_model}, {observation}")
if self.particle_model is not None and observation in self.particle_model:
particles = self.particle_model[observation]
Logger.__call__().get_logger().info(f"Got particles from particle model")
else:
# fill particles by Monte-Carlo using reject sampling
while len(particles) < particle_slots:
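For context on the rejection-sampling step referenced in the hunk above ("fill particles by Monte-Carlo using reject sampling"), the sketch below illustrates the general idea. It is a minimal standalone example; the function and parameter names are hypothetical and the environment's sampling interface is assumed, not taken from this repository.

import random
from typing import Callable, List


def fill_particles_by_rejection_sampling(candidate_states: List[int], last_action: int, real_observation: int,
                                         sample_observation: Callable[[int, int], int], particle_slots: int,
                                         max_attempts: int = 100000) -> List[int]:
    """
    Keeps sampling candidate states and simulating the last action until enough states are found whose
    simulated observation matches the observation that was actually received (rejection sampling).
    """
    particles: List[int] = []
    attempts = 0
    while len(particles) < particle_slots and attempts < max_attempts:
        attempts += 1
        s = random.choice(candidate_states)  # propose a state particle
        o = sample_observation(s, last_action)  # simulate one step from the proposed state
        if o == real_observation:  # accept the particle only if the observation matches
            particles.append(s)
    return particles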
@@ -100,7 +100,7 @@ def ucb_acquisition_function(action: "Node", c: float, rollout_policy: Union[Pol
prior = 1.0
# if rollout_policy is not None:
# prior = rollout_policy.probability(o=o, a=action.action)
return float(action.value + prior*prior_weight
return float(action.value + prior * prior_weight
+ c * POMCPUtil.ucb(action.parent.visit_count, action.visit_count))

@staticmethod
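The edit to ucb_acquisition_function above is a spacing fix only. For readers unfamiliar with the score it computes, the sketch below shows a typical UCB1-style acquisition value; the square-root exploration term is an assumption about what POMCPUtil.ucb computes and is included here purely for illustration.

import math


def ucb_acquisition(value: float, prior: float, prior_weight: float, c: float,
                    parent_visit_count: int, visit_count: int) -> float:
    """
    Exploitation term (node value plus a prior bonus) plus a UCB1-style exploration bonus
    that favors rarely visited children of a frequently visited parent.
    """
    if visit_count == 0:
        return math.inf  # unvisited children are tried first
    exploration = math.sqrt(math.log(max(parent_visit_count, 1)) / visit_count)  # assumed form of POMCPUtil.ucb
    return float(value + prior * prior_weight + c * exploration)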
@@ -56,7 +56,8 @@ def __init__(self, model: Union[None, PPO, PPONetwork], simulation_name: str, sa
self.avg_R = avg_R
self.policy_type = PolicyType.PPO

def action(self, o: Union[List[float], List[int]], deterministic: bool = True) -> Union[int, float, npt.NDArray[Any]]:
def action(self, o: Union[List[float], List[int]], deterministic: bool = True) \
-> Union[int, float, npt.NDArray[Any]]:
"""
Multi-threshold stopping policy
@@ -0,0 +1,64 @@
from typing import Dict, Any
from csle_common.dao.simulation_config.simulation_env_input_config import SimulationEnvInputConfig


class CSLECyborgWrapperConfig(SimulationEnvInputConfig):
"""
DTO representing the input configuration to a gym-csle-cyborg environment
"""

def __init__(self, maximum_steps: int, gym_env_name: str, save_trace: bool = False):
"""
Initializes the DTO
:param maximum_steps: the maximum number of steps in the environment
:param gym_env_name: the name of the gym environment
:param save_trace: boolean flag indicating whether traces should be saved
"""
self.maximum_steps = maximum_steps
self.gym_env_name = gym_env_name
self.save_trace = save_trace

def to_dict(self) -> Dict[str, Any]:
"""
Converts the object to a dict representation
:return: a dict representation of the object
"""
d: Dict[str, Any] = {}
d["baseline_red_agents"] = self.maximum_steps
d["gym_env_name"] = self.gym_env_name
d["save_trace"] = self.save_trace
return d

@staticmethod
def from_dict(d: Dict[str, Any]) -> "CSLECyborgWrapperConfig":
"""
Converts a dict representation to an instance
:param d: the dict to convert
:return: the created instance
"""
obj = CSLECyborgWrapperConfig(gym_env_name=d["gym_env_name"], maximum_steps=d["maximum_steps"],
save_trace=d["save_trace"])
return obj

def __str__(self) -> str:
"""
:return: a string representation of the object
"""
return f"gym_env_name: {self.gym_env_name}, maximum_steps: {self.maximum_steps}, save_trace: {self.save_trace}"

@staticmethod
def from_json_file(json_file_path: str) -> "CSLECyborgWrapperConfig":
"""
Reads a json file and converts it to a DTO
:param json_file_path: the json file path
:return: the converted DTO
"""
import io
import json
with io.open(json_file_path, 'r') as f:
json_str = f.read()
return CSLECyborgWrapperConfig.from_dict(json.loads(json_str))
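A short usage sketch of the new DTO added above; the environment name string is a placeholder and the import path is assumed rather than shown in this diff.

# Hypothetical import; this diff does not show the module path of the new class.
# from gym_csle_cyborg.dao.csle_cyborg_wrapper_config import CSLECyborgWrapperConfig

config = CSLECyborgWrapperConfig(gym_env_name="my-cyborg-env-v1", maximum_steps=100, save_trace=False)  # placeholder env name
print(config)  # gym_env_name: my-cyborg-env-v1, maximum_steps: 100, save_trace: False
restored = CSLECyborgWrapperConfig.from_dict({"gym_env_name": "my-cyborg-env-v1", "maximum_steps": 100,
                                              "save_trace": False})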
@@ -122,16 +122,14 @@ def step(self, action: int) -> Tuple[npt.NDArray[Any], float, bool, bool, Dict[s

o, r, done, _, info = self.cyborg_challenge_env.step(action=action)
if not self.config.decoy_optimization:
info, _, scan_state = CyborgScenarioTwoDefender.populate_info(info=dict(info), obs=o, trace=self.trace,
env=self.cyborg_challenge_env,
cyborg_hostnames=self.cyborg_hostnames,
scan_state=self.scan_state,
decoy_state=self.decoy_state, config=self.config,
cyborg_hostname_to_id=self.cyborg_hostname_to_id,
visited_cyborg_states=self.visited_cyborg_states,
visited_scanned_states=self.visited_scanned_states,
visited_decoy_states=self.visited_decoy_states,
reset=False)
info, _, scan_state = \
CyborgScenarioTwoDefender.populate_info(
info=dict(info), obs=o, trace=self.trace, env=self.cyborg_challenge_env,
cyborg_hostnames=self.cyborg_hostnames, scan_state=self.scan_state, decoy_state=self.decoy_state,
config=self.config, cyborg_hostname_to_id=self.cyborg_hostname_to_id,
visited_cyborg_states=self.visited_cyborg_states,
visited_scanned_states=self.visited_scanned_states, visited_decoy_states=self.visited_decoy_states,
reset=False)
self.scan_state = scan_state
o, observation_id_to_tensor = CyborgScenarioTwoDefender.encode_observation(
config=self.config, info=info, decoy_state=self.decoy_state, scan_state=self.scan_state,