Skip to content

Commit

Permalink
POMCP [WIP]
Browse files Browse the repository at this point in the history
  • Loading branch information
Limmen committed Jan 17, 2024
1 parent 2ffca0a commit c9de2d7
Show file tree
Hide file tree
Showing 10 changed files with 137 additions and 62 deletions.
137 changes: 98 additions & 39 deletions examples/manual_play/cyborg_test.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,106 @@
import random
from gym_csle_cyborg.dao.csle_cyborg_config import CSLECyborgConfig
from gym_csle_cyborg.dao.red_agent_type import RedAgentType
from gym_csle_cyborg.envs.cyborg_scenario_two_defender import CyborgScenarioTwoDefender


def info_to_vec(info, decoy_state, hosts):
"""
Creates the state vector
:param info: the info
:param decoy_state: the decoy state
:param hosts: the host list
:return: the state vector
"""
state_vec = []
for host in hosts:
known = info[host][3]
known = int(known)
scanned = info[host][4]
scanned = int(scanned)
access = info[host][5]
if access == "None":
access = 0
elif access == "User":
access = 1
else:
access = 2
d_state = len(decoy_state[host])
state_vec.append([known, scanned, access, d_state])
return state_vec


def state_vec_to_id(state_vec):
"""
Converts a state vector to an id
:param state_vec: the state vector to convert
:return: the id
"""
bin_id = ""
for host_vec in state_vec:
host_bin_str = ""
for i, elem in enumerate(host_vec):
if i == 0:
host_bin_str += format(elem, '01b')
if i == 1:
host_bin_str += format(elem, '01b')
if i == 2:
host_bin_str += format(elem, '02b')
if i == 3:
host_bin_str += format(elem, '03b')
bin_id += host_bin_str
id = int(bin_id, 2)
return id


def id_to_state_vec(id: int):
"""
Converts an id to a state vector
:param id: the id to convert
:return: the state vector
"""
bin_str = format(id, "091b")
host_bins = [bin_str[i:i + 7] for i in range(0, len(bin_str), 7)]
state_vec = []
for host_bin in host_bins:
known = int(host_bin[0:1], 2)
scanned = int(host_bin[1:2], 2)
access = int(host_bin[2:4], 2)
decoy = int(host_bin[4:7], 2)
host_vec = [known, scanned, access, decoy]
state_vec.append(host_vec)
return state_vec


if __name__ == '__main__':
config = CSLECyborgConfig(
gym_env_name="csle-cyborg-scenario-two-v1", scenario=2, baseline_red_agents=[RedAgentType.B_LINE_AGENT],
maximum_steps=100, red_agent_distribution=[1.0], reduced_action_space=False, decoy_state=False,
scanned_state=False, decoy_optimization=False)
maximum_steps=100, red_agent_distribution=[1.0], reduced_action_space=True, decoy_state=True,
scanned_state=True, decoy_optimization=False)
csle_cyborg_env = CyborgScenarioTwoDefender(config=config)
print(csle_cyborg_env.get_table())
# print(csle_cyborg_env.cyborg_action_id_to_type_and_host[24])
# print(csle_cyborg_env.cyborg_action_id_to_type_and_host[14])
# print(csle_cyborg_env.cyborg_action_id_to_type_and_host[22])
# print(csle_cyborg_env.cyborg_action_id_to_type_and_host[141])
# print(csle_cyborg_env.cyborg_action_id_to_type_and_host[133])
# print(csle_cyborg_env.cyborg_action_id_to_type_and_host[144])
# print(csle_cyborg_env.cyborg_action_id_to_type_and_host[43])
# print(csle_cyborg_env.cyborg_action_id_to_type_and_host[131])
# print(csle_cyborg_env.cyborg_action_id_to_type_and_host[38])
# print(csle_cyborg_env.cyborg_action_id_to_type_and_host[28])
# print(csle_cyborg_env.cyborg_action_id_to_type_and_host[119])
# print(csle_cyborg_env.cyborg_action_id_to_type_and_host[55])
# print(csle_cyborg_env.cyborg_action_id_to_type_and_host[107])
# print(csle_cyborg_env.cyborg_action_id_to_type_and_host[120])
# print(csle_cyborg_env.cyborg_action_id_to_type_and_host[29])
# print(csle_cyborg_env.cyborg_action_id_to_type_and_host[43])
# print(csle_cyborg_env.cyborg_action_id_to_type_and_host[44])
# print(csle_cyborg_env.cyborg_action_id_to_type_and_host[61])
# print(csle_cyborg_env.cyborg_action_id_to_type_and_host[35])
# print(csle_cyborg_env.cyborg_action_id_to_type_and_host[113])
# print(csle_cyborg_env.cyborg_action_id_to_type_and_host[126])
# csle_cyborg_env.step(1)
# csle_cyborg_env.step(1)
# print(csle_cyborg_env.get_table())
# print(csle_cyborg_env.get_true_table())
# print(csle_cyborg_env.cyborg_action_id_to_type_and_host[44])
# print(csle_cyborg_env.cyborg_action_id_to_type_and_host[107])
# print(csle_cyborg_env.cyborg_action_id_to_type_and_host[104])
# print(csle_cyborg_env.cyborg_action_id_to_type_and_host[126])
# print(csle_cyborg_env.cyborg_action_id_to_type_and_host[120])
# print(csle_cyborg_env.cyborg_action_id_to_type_and_host[50])
# print(csle_cyborg_env.cyborg_action_id_to_type_and_host[57])
# print(csle_cyborg_env.cyborg_action_id_to_type_and_host[2])
# print(csle_cyborg_env.cyborg_action_id_to_type_and_host[23])
# print(csle_cyborg_env.cyborg_action_id_to_type_and_host[6])
# print(csle_cyborg_env.cyborg_action_id_to_type_and_host[133])
str_info = str(csle_cyborg_env.cyborg_challenge_env.env.env.env.info)
states = {}
state_idx = 0
host_state_lookup = host_state_to_id(hostnames=csle_cyborg_env.cyborg_hostnames)
host_ids = list(csle_cyborg_env.cyborg_hostname_to_id.values())

for i in range(100000):
done = False
csle_cyborg_env.reset()
actions = list(csle_cyborg_env.action_id_to_type_and_host.keys())
state_key = str(csle_cyborg_env.cyborg_challenge_env.env.env.env.info)
if state_key not in states:
states[state_key] = state_idx
state_idx += 1

while not done:
a = random.choice(actions)
o, r, done, _, info = csle_cyborg_env.step(a)
state_vec = info_to_vec(csle_cyborg_env.get_true_table().rows, csle_cyborg_env.decoy_state,
host_state_lookup, host_ids)
state_key = state_vec_to_id(state_vec=state_vec)
stv = id_to_state_vec(id=state_key)
assert stv == state_vec
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
value=rollout_policy, name=agents_constants.POMCP.ROLLOUT_POLICY,
descr="the policy to use for rollouts"),
agents_constants.POMCP.VALUE_FUNCTION: HParam(
value=None, name=agents_constants.POMCP.VALUE_FUNCTION,
value=lambda x: 0, name=agents_constants.POMCP.VALUE_FUNCTION,
descr="the value function to use for truncated rollouts"),
agents_constants.POMCP.S: HParam(value=S, name=agents_constants.POMCP.S, descr="the state space"),
agents_constants.POMCP.O: HParam(value=O, name=agents_constants.POMCP.O, descr="the observation space"),
Expand All @@ -82,13 +82,13 @@
descr="the discount factor"),
agents_constants.POMCP.INITIAL_BELIEF: HParam(value=b1, name=agents_constants.POMCP.INITIAL_BELIEF,
descr="the initial belief"),
agents_constants.POMCP.PLANNING_TIME: HParam(value=60, name=agents_constants.POMCP.PLANNING_TIME,
agents_constants.POMCP.PLANNING_TIME: HParam(value=120, name=agents_constants.POMCP.PLANNING_TIME,
descr="the planning time"),
agents_constants.POMCP.MAX_PARTICLES: HParam(value=100, name=agents_constants.POMCP.MAX_PARTICLES,
descr="the maximum number of belief particles"),
agents_constants.POMCP.MAX_DEPTH: HParam(value=500, name=agents_constants.POMCP.MAX_DEPTH,
descr="the maximum depth for planning"),
agents_constants.POMCP.C: HParam(value=0.2, name=agents_constants.POMCP.C,
agents_constants.POMCP.C: HParam(value=0.35, name=agents_constants.POMCP.C,
descr="the weighting factor for UCB exploration"),
agents_constants.POMCP.LOG_STEP_FREQUENCY: HParam(
value=1, name=agents_constants.POMCP.LOG_STEP_FREQUENCY, descr="frequency of logging time-steps"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ def sample_state(self) -> int:
:return: the sampled state
"""
return POMCPUtil.rand_choice(self.particles)
sample = POMCPUtil.rand_choice(self.particles)
return int(sample)

def add_particle(self, particle: Union[int, List[int]]) -> None:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def add(self, history: List[int], parent: Union[ActionNode, BeliefNode, None], a
new_node: Node = ActionNode(self.tree_size, history, parent=parent, action=action)
else:
if observation is None:
raise ValueError("Invalid observation")
observation = 0
new_node = BeliefNode(self.tree_size, history, parent=parent, observation=observation)

if particle is not None and isinstance(new_node, BeliefNode):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Union
from typing import List, Union, Callable, Any
import time
import numpy as np
from csle_common.dao.simulation_config.base_env import BaseEnv
Expand All @@ -17,7 +17,8 @@ class POMCP:

def __init__(self, S: List[int], O: List[int], A: List[int], gamma: float, env: BaseEnv, c: float,
initial_belief: List[float], planning_time: float = 0.5, max_particles: int = 350,
reinvigorated_particles_ratio: float = 0.1, rollout_policy: Union[Policy, None] = None) -> None:
reinvigorated_particles_ratio: float = 0.1, rollout_policy: Union[Policy, None] = None,
value_function: Union[Callable[[Any], float], None] = None) -> None:
"""
Initializes the solver
Expand All @@ -43,6 +44,7 @@ def __init__(self, S: List[int], O: List[int], A: List[int], gamma: float, env:
self.max_particles = max_particles
self.reinvigorated_particles_ratio = reinvigorated_particles_ratio
self.rollout_policy = rollout_policy
self.value_function = value_function
root_particles = POMCPUtil.generate_particles(
states=self.S, num_particles=self.max_particles, probability_vector=initial_belief)
self.tree = BeliefTree(root_particles=root_particles)
Expand Down Expand Up @@ -71,7 +73,11 @@ def rollout(self, state: int, history: List[int], depth: int, max_depth: int) ->
:return: the estimated value of the root node
"""
if depth > max_depth:
return 0
if self.value_function is not None:
o = self.env.get_observation_from_history(history=history)
return self.value_function(o)
else:
return 0
if self.rollout_policy is None or self.env.is_state_terminal(state):
a = POMCPUtil.rand_choice(self.A)
else:
Expand Down Expand Up @@ -99,10 +105,16 @@ def simulate(self, state: int, max_depth: int, c: float, history: List[int], dep

# Check if we have reached the maximum depth of the tree
if depth > max_depth:
return 0
if len(history) > 0 and self.value_function is not None:
o = self.env.get_observation_from_history(history=history)
return self.value_function(o)
else:
return 0

# Check if the new history has already been visited in the past of should be added as a new node to the tree
observation = history[-1]
observation = -1
if len(history) > 0:
observation = history[-1]
current_node = self.tree.find_or_create(history=history, parent=parent, observation=observation)

# If a new node was created, then it has no children, in which case we should stop the search and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,8 @@ def pomcp(self, exp_result: ExperimentResult, seed: int,
:return: the updated experiment result and the trained policy
"""
start: float = time.time()
# objective_type = self.experiment_config.hparams[agents_constants.POMCP.OBJECTIVE_TYPE].value
rollout_policy = self.experiment_config.hparams[agents_constants.POMCP.ROLLOUT_POLICY].value
# value_function = self.experiment_config.hparams[agents_constants.POMCP.VALUE_FUNCTION].value
value_function = self.experiment_config.hparams[agents_constants.POMCP.VALUE_FUNCTION].value
log_steps_frequency = self.experiment_config.hparams[agents_constants.POMCP.LOG_STEP_FREQUENCY].value
max_env_steps = self.experiment_config.hparams[agents_constants.COMMON.MAX_ENV_STEPS].value
N = self.experiment_config.hparams[agents_constants.POMCP.N].value
Expand All @@ -210,7 +209,8 @@ def pomcp(self, exp_result: ExperimentResult, seed: int,
train_env.reset()
belief = b1.copy()
pomcp = POMCP(S=S, O=O, A=A, gamma=gamma, env=train_env, c=c, initial_belief=belief,
planning_time=planning_time, max_particles=max_particles, rollout_policy=rollout_policy)
planning_time=planning_time, max_particles=max_particles, rollout_policy=rollout_policy,
value_function=value_function)
R = 0
t = 1
if t % log_steps_frequency == 0:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Dict, Union
from typing import List, Dict, Union, Any
import numpy as np
from csle_agents.agents.pomcp.node import Node
from collections import Counter
Expand All @@ -23,14 +23,14 @@ def sample_from_distribution(probability_vector: List[float]) -> int:
return int(sample)

@staticmethod
def rand_choice(candidates: List[int]) -> int:
def rand_choice(candidates: List[int]) -> Any:
"""
Selects an element from a given list uniformly at random
:param candidates: the list to sample from
:return: the sample
"""
return int(np.random.choice(candidates))
return np.random.choice(candidates)

@staticmethod
def convert_samples_to_distribution(samples) -> Dict[int, float]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ class EmulationAttackerActionId(IntEnum):
MYSQL_SAME_USER_PASS_DICTIONARY_HOST = 16
SMTP_SAME_USER_PASS_DICTIONARY_HOST = 17
POSTGRES_SAME_USER_PASS_DICTIONARY_HOST = 18

TCP_SYN_STEALTH_SCAN_ALL = 19
PING_SCAN_ALL = 20
UDP_PORT_SCAN_ALL = 21
Expand All @@ -44,10 +43,8 @@ class EmulationAttackerActionId(IntEnum):
MYSQL_SAME_USER_PASS_DICTIONARY_ALL = 35
SMTP_SAME_USER_PASS_DICTIONARY_ALL = 36
POSTGRES_SAME_USER_PASS_DICTIONARY_ALL = 37

NETWORK_SERVICE_LOGIN = 38
FIND_FLAG = 39

NIKTO_WEB_HOST_SCAN = 40
MASSCAN_HOST_SCAN = 41
MASSCAN_ALL_SCAN = 42
Expand All @@ -59,10 +56,8 @@ class EmulationAttackerActionId(IntEnum):
HTTP_GREP_ALL = 48
FINGER_HOST = 49
FINGER_ALL = 50

INSTALL_TOOLS = 51
SSH_BACKDOOR = 52

SAMBACRY_EXPLOIT = 53
SHELLSHOCK_EXPLOIT = 54
DVWA_SQL_INJECTION = 55
Expand All @@ -71,6 +66,5 @@ class EmulationAttackerActionId(IntEnum):
CVE_2016_10033_EXPLOIT = 58
CVE_2010_0426_PRIV_ESC = 59
CVE_2015_5602_PRIV_ESC = 60

STOP = 61
CONTINUE = 62
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,15 @@ def set_state(self, state: Any) -> None:
"""
raise NotImplementedError("This environment does not support the set_state method")

def get_observation_from_history(self, history: List[int]) -> List[Any]:
"""
Utility function to get a defender observation from a history
:param history: the history to get the observation form
:return: the observation
"""
raise NotImplementedError("This environment does not support the get_observation_from_history method")

def manual_play(self) -> None:
"""
An interactive loop to test the environment manually
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,7 @@ def get_observation_from_history(self, history: List[int], pi2: npt.NDArray[Any]
a1 = history[t]
o = history[t + 1]
b = StoppingGameUtil.next_belief(o=o, a1=a1, b=b, pi2=pi2, config=self.config, l=l, a2=0)
l = l - a1
l = max(l - a1, 0)
t += 2
return [l, b[1]]

Expand Down

0 comments on commit c9de2d7

Please sign in to comment.