Commit

fix linter

Limmen committed Jan 24, 2024
1 parent da199a3 commit 90ede17

Showing 11 changed files with 34 additions and 208 deletions.
2 changes: 1 addition & 1 deletion examples/manual_play/cyborg_action_space.py
@@ -11,7 +11,7 @@
     maximum_steps=100, red_agent_distribution=[1.0], reduced_action_space=True, decoy_state=True,
     scanned_state=True, decoy_optimization=False)
 csle_cyborg_env = CyborgScenarioTwoDefender(config=config)
-for k, v in csle_cyborg_env.action_id_to_type_and_host.items():
+for k, v in csle_cyborg_env.cyborg_action_id_to_type_and_host.items():
     action_id = k
     type, host = v
     print(f"{action_id}, {BlueAgentActionType(type).name}, {host}")
41 changes: 0 additions & 41 deletions examples/manual_play/cyborg_parallel_policy_evaluation.py

This file was deleted.

62 changes: 0 additions & 62 deletions examples/manual_play/cyborg_rollout_test.py

This file was deleted.

58 changes: 0 additions & 58 deletions examples/manual_play/cyborg_rollout_three.py

This file was deleted.

7 changes: 4 additions & 3 deletions examples/manual_play/cyborg_rollout_two.py
@@ -17,6 +17,7 @@
 actions = list(csle_cyborg_env.action_id_to_type_and_host.keys())
 # for i in range(25):
 import torch
+
 torch.multiprocessing.set_start_method('spawn')
 action_sequence = []
 returns = []
@@ -39,8 +40,8 @@
     R = 0
     for fictitious_state, prob in belief.items():
         r = csle_cyborg_env.parallel_rollout(policy_id=15, num_processes=1, num_evals_per_process=1,
-                                         max_horizon=1, state_id=fictitious_state)
-        R += r*prob
+                                             max_horizon=1, state_id=fictitious_state)
+        R += r * prob
     action_values.append(R)
     print(action_values)
     a_idx = np.argmax(action_values)
@@ -55,4 +56,4 @@
         o=o_id, env=csle_cyborg_env, action_sequence=action_sequence, num_particles=10, verbose=True)
     belief = POMCPUtil.convert_samples_to_distribution(particles)
     returns.append(total_R)
-print(f"average return: {np.mean(returns)}")
\ No newline at end of file
+print(f"average return: {np.mean(returns)}")
11 changes: 4 additions & 7 deletions examples/manual_play/learn_model.py
@@ -1,13 +1,10 @@
 import numpy as np
 import io
 from csle_common.metastore.metastore_facade import MetastoreFacade
 from gym_csle_cyborg.dao.csle_cyborg_config import CSLECyborgConfig
 from gym_csle_cyborg.dao.red_agent_type import RedAgentType
 from gym_csle_cyborg.envs.cyborg_scenario_two_defender import CyborgScenarioTwoDefender
 import csle_agents.constants.constants as constants
 import json
 from csle_agents.agents.pomcp.pomcp_util import POMCPUtil
 import math
 
 if __name__ == '__main__':
     ppo_policy = MetastoreFacade.get_ppo_policy(id=22)
@@ -42,13 +39,15 @@
                 transition_probabilities[",".join([str(s), str(s_prime), str(a)])] = 1
                 new_transitions += 1
             else:
-                transition_probabilities[",".join([str(s), str(s_prime), str(a)])] = transition_probabilities[",".join([str(s), str(s_prime), str(a)])] + 1
+                transition_probabilities[",".join([str(s), str(s_prime), str(a)])] = transition_probabilities[",".join(
+                    [str(s), str(s_prime), str(a)])] + 1
             if ",".join([str(s), str(s_prime), str(a)]) not in reward_function:
                 reward_function[",".join([str(s), str(s_prime), str(a)])] = r
             if ",".join([str(s_prime), str(oid)]) not in observation_probabilities:
                 observation_probabilities[",".join([str(s_prime), str(oid)])] = 1
             else:
-                observation_probabilities[",".join([str(s_prime), str(oid)])] = observation_probabilities[",".join([str(s_prime), str(oid)])] + 1
+                observation_probabilities[",".join([str(s_prime), str(oid)])] = observation_probabilities[",".join(
+                    [str(s_prime), str(oid)])] + 1
             t_count += 1
     print(f"new transitions: {new_transitions}")
 
@@ -63,5 +62,3 @@
         json_str = json.dumps(model, indent=4, sort_keys=True)
         with io.open(f"/home/kim/cyborg_model_{i}.json", 'w', encoding='utf-8') as f:
             f.write(json_str)
-
-
@@ -73,7 +73,7 @@
             descr="maximum number of negative samples when filling belief particles"),
         agents_constants.POMCP.PARALLEL_ROLLOUT: HParam(
             value=False, name=agents_constants.POMCP.PARALLEL_ROLLOUT, descr="boolean flag indicating whether "
-                                                                              "parallel rollout should be used"),
+                                                                             "parallel rollout should be used"),
         agents_constants.POMCP.NUM_PARALLEL_PROCESSES: HParam(
             value=5, name=agents_constants.POMCP.NUM_PARALLEL_PROCESSES, descr="number of parallel processes"),
         agents_constants.POMCP.NUM_EVALS_PER_PROCESS: HParam(
@@ -99,6 +99,7 @@
         player_type=PlayerType.DEFENDER, player_idx=0
     )
     import torch
+
     torch.multiprocessing.set_start_method('spawn')
     agent = POMCPAgent(emulation_env_config=emulation_env_config, simulation_env_config=simulation_env_config,
                        experiment_config=experiment_config, save_to_metastore=False)
@@ -20,7 +20,7 @@
         raise ValueError(f"Could not find a simulation with name: {simulation_name}")
     experiment_config = ExperimentConfig(
         output_dir=f"{constants.LOGGING.DEFAULT_LOG_DIR}ppo_test",
-        title="Cardiff PPO Cyborg BLine", random_seeds=[399], agent_type=AgentType.PPO,
+        title="Cardiff PPO Cyborg Meander", random_seeds=[399], agent_type=AgentType.PPO,
         log_every=1,
         hparams={
             constants.NEURAL_NETWORKS.NUM_NEURONS_PER_HIDDEN_LAYER: HParam(
@@ -93,7 +93,7 @@ def ucb_acquisition_function(action: "Node", c: float, rollout_policy: Union[Pol
         if action.visit_count == 0:
             return np.inf
         else:
-            return action.value + (prior_weight*prior_weight)/action.visit_count
+            return action.value + (prior_weight * prior_weight) / action.visit_count
             # prior = 1.0
             # if rollout_policy is not None:
             #     prior = rollout_policy.probability(o=o, a=action.action)
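For context, the classic acquisition rule in POMCP-style tree search is UCB1, which adds an exploration bonus that shrinks as an action's visit count grows; the function above computes a prior-weighted variant of the same idea. A generic sketch of plain UCB1 (not this repository's exact implementation):

    import math

    def ucb1(action_value: float, action_visits: int, parent_visits: int,
             c: float = 1.5) -> float:
        # unvisited actions are tried first
        if action_visits == 0:
            return float('inf')
        # exploit the empirical value, explore rarely-visited actions
        return action_value + c * math.sqrt(math.log(parent_visits) / action_visits)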
@@ -14,7 +14,6 @@
 from csle_common.dao.simulation_config.base_env import BaseEnv
 from csle_common.dao.simulation_config.simulation_trace import SimulationTrace
 from csle_common.metastore.metastore_facade import MetastoreFacade
-from csle_common.logging.log import Logger
 import gym_csle_cyborg.constants.constants as env_constants
 from gym_csle_cyborg.dao.csle_cyborg_config import CSLECyborgConfig
 from gym_csle_cyborg.dao.blue_agent_action_type import BlueAgentActionType
@@ -56,9 +56,8 @@ def update_red_agent(config: CSLECyborgConfig, current_red_agent: RedAgentType,
 
     @staticmethod
     def setup_cyborg_env(config: CSLECyborgConfig) \
-            -> Tuple[str, ChallengeWrapper, List[str], Dict[str, int], List[str], Dict[str, int],
-                     Dict[int, Tuple[BlueAgentActionType, str]],
-                     Dict[Tuple[BlueAgentActionType, str], int], RedAgentType]:
+            -> Tuple[str, ChallengeWrapper, List[str], Dict[str, int], List[str], Dict[str, int], Dict[
+                int, Tuple[BlueAgentActionType, str]], Dict[Tuple[BlueAgentActionType, str], int], RedAgentType]:
         """
         Sets up the cyborg environment and associated metadata
@@ -346,11 +345,9 @@ def state_to_vector(state: List[List[Any]], decoy_state: List[List[BlueAgentActi
                 host_access = 3
             host_decoy_state = len(decoy_state[host_id])
             if not observation:
-                state_vector.append([host_access])
-                # state_vector.append([host_known, host_scanned, host_access, host_decoy_state])
+                state_vector.append([host_known, host_scanned, host_access, host_decoy_state])
             else:
-                state_vector.append([activity, host_access])
-                # state_vector.append([activity, host_scanned, host_access, host_decoy_state])
+                state_vector.append([activity, host_scanned, host_access, host_decoy_state])
         return state_vector
 
     @staticmethod
@@ -370,10 +367,10 @@ def state_vector_to_state_id(state_vector: List[List[int]], observation: bool =
             if not observation:
-                if i == 0:
-                    host_binary_id_str += format(elem, '02b')
-                # if i == 0:
-                #     host_binary_id_str += format(elem, '01b')
-                # if i == 1:
-                #     host_binary_id_str += format(elem, '01b')
+                if i == 0:
+                    host_binary_id_str += format(elem, '01b')
+                if i == 1:
+                    host_binary_id_str += format(elem, '01b')
             else:
                 if i == 0:
                     host_binary_id_str += format(elem, '02b')
@@ -398,33 +395,25 @@ def state_id_to_state_vector(state_id: int, observation: bool = False) -> List[L
         :return: the state vector
         """
         if not observation:
-            # binary_id_str = format(state_id, "091b")
-            binary_id_str = format(state_id, "026b")
-            host_binary_ids_str = [binary_id_str[i:i + 2] for i in range(0, len(binary_id_str), 2)]
+            binary_id_str = format(state_id, "091b")
+            host_binary_ids_str = [binary_id_str[i:i + 7] for i in range(0, len(binary_id_str), 7)]
         else:
-            # binary_id_str = format(state_id, "0117b")
-            binary_id_str = format(state_id, "052b")
-            # host_binary_ids_str = [binary_id_str[i:i + 9] for i in range(0, len(binary_id_str), 9)]
-            host_binary_ids_str = [binary_id_str[i:i + 4] for i in range(0, len(binary_id_str), 4)]
+            binary_id_str = format(state_id, "0117b")
+            host_binary_ids_str = [binary_id_str[i:i + 9] for i in range(0, len(binary_id_str), 9)]
         state_vector = []
         for host_bin in host_binary_ids_str:
             if not observation:
-                access = int(host_bin[0:2], 2)
-                # known = int(host_bin[0:1], 2)
-                # scanned = int(host_bin[1:2], 2)
-                # access = int(host_bin[2:4], 2)
-                # decoy = int(host_bin[4:7], 2)
-                host_vector = [access]
-                # host_vector = [known, scanned, access, decoy]
+                known = int(host_bin[0:1], 2)
+                scanned = int(host_bin[1:2], 2)
+                access = int(host_bin[2:4], 2)
+                decoy = int(host_bin[4:7], 2)
+                host_vector = [known, scanned, access, decoy]
             else:
-                activity = int(host_bin[0:2], 2)
-                access = int(host_bin[2:4], 2)
-                # activity = int(host_bin[0:2], 2)
-                # scanned = int(host_bin[2:4], 2)
-                # access = int(host_bin[4:6], 2)
-                # decoy = int(host_bin[6:9], 2)
-                host_vector = [activity, access]
-                # host_vector = [activity, scanned, access, decoy]
+                activity = int(host_bin[0:2], 2)
+                scanned = int(host_bin[2:4], 2)
+                access = int(host_bin[4:6], 2)
+                decoy = int(host_bin[6:9], 2)
+                host_vector = [activity, scanned, access, decoy]
             state_vector.append(host_vector)
         return state_vector
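Taken together, these hunks restore the full per-host encoding: in the state encoding each host contributes 7 bits (1 known + 1 scanned + 2 access + 3 decoy, matching format(state_id, "091b") for 13 hosts), and in the observation encoding 9 bits (2 activity + 2 scanned + 2 access + 3 decoy, giving 117 bits). A standalone sketch of the round trip implied by the bit slices above:

    def encode_host_state(known: int, scanned: int, access: int, decoy: int) -> str:
        # 1 + 1 + 2 + 3 = 7 bits per host
        return (format(known, '01b') + format(scanned, '01b') +
                format(access, '02b') + format(decoy, '03b'))

    def decode_host_state(bits: str) -> list:
        # inverse of the slices used in state_id_to_state_vector
        return [int(bits[0:1], 2), int(bits[1:2], 2),
                int(bits[2:4], 2), int(bits[4:7], 2)]

    host = [1, 0, 3, 5]  # known, scanned, access, decoy
    bits = encode_host_state(*host)
    assert len(bits) == 7 and decode_host_state(bits) == host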
