Skip to content

Commit

Permalink
Switches logging to Logger functionality from Database
Browse files Browse the repository at this point in the history
Also changes default value for DbStorage.connection_string from ":memory:" to "wintermute.sqlite3"
  • Loading branch information
Neverbolt committed Aug 6, 2024
1 parent 3e52a55 commit 5edb6f7
Show file tree
Hide file tree
Showing 12 changed files with 109 additions and 72 deletions.
12 changes: 6 additions & 6 deletions src/hackingBuddyGPT/cli/wintermute.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ def main():
parser = argparse.ArgumentParser()
subparser = parser.add_subparsers(required=True)
for name, use_case in use_cases.items():
subb = subparser.add_parser(
use_case.build_parser(subparser.add_parser(
name=use_case.name,
help=use_case.description
)
use_case.build_parser(subb)
x= sys.argv[1:]
parsed = parser.parse_args(x)
))

parsed = parser.parse_args(sys.argv[1:])
configuration = {k: v for k, v in vars(parsed).items() if k != "use_case"}
instance = parsed.use_case(parsed)
instance.init()
instance.init(configuration=configuration)
instance.run()


Expand Down
2 changes: 1 addition & 1 deletion src/hackingBuddyGPT/usecases/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def perform_round(self, turn:int) -> bool:
result, got_root = capability(cmd)

# log and output the command and its result
self._log.log_db.add_log_query(self._log.run_id, turn, cmd, result, answer)
self._log.add_log_query(turn, cmd, result, answer)
self._state.update(capability, cmd, result)
# TODO output/log new state
self._log.console.print(Panel(result, title=f"[bold cyan]{cmd}"))
Expand Down
89 changes: 64 additions & 25 deletions src/hackingBuddyGPT/usecases/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import abc
import json
import argparse
import typing
from dataclasses import dataclass
from dataclasses import dataclass, field
from rich.panel import Panel
from typing import Dict, Type

from hackingBuddyGPT.utils import LLMResult
from hackingBuddyGPT.utils.configurable import ParameterDefinitions, build_parser, get_arguments, get_class_parameters, transparent
from hackingBuddyGPT.utils.console.console import Console
from hackingBuddyGPT.utils.db_storage.db_storage import DbStorage
Expand All @@ -14,8 +16,40 @@
class Logger:
log_db: DbStorage
console: Console
model: str = ""
tag: str = ""
run_id: int = 0
configuration: str = ""

run_id: int = field(init=False, default=None)

def __post_init__(self):
self.run_id = self.log_db.create_new_run(self.model, self.tag, self.configuration)

def add_log_query(self, turn: int, command: str, result: str, answer: LLMResult):
self.log_db.add_log_query(self.run_id, turn, command, result, answer)

def add_log_message(self, role: str, content: str, tokens_query: int, tokens_response: int, duration: float) -> int:
return self.log_db.add_log_message(self.run_id, role, content, tokens_query, tokens_response, duration)

def add_log_tool_call(self, message_id: int, tool_call_id: str, function_name: str, arguments: str, result_text: str, duration: float):
self.console.print(f"\n[bold green on gray3]{' ' * self.console.width}\nTOOL RESPONSE:[/bold green on gray3]")
self.console.print(result_text)
self.log_db.add_log_tool_call(self.run_id, message_id, tool_call_id, function_name, arguments, result_text, duration)

def add_log_analyze_response(self, turn: int, command: str, result: str, answer: LLMResult):
self.log_db.add_log_analyze_response(self.run_id, turn, command, result, answer)

def add_log_update_state(self, turn: int, command: str, result: str, answer: LLMResult):
self.log_db.add_log_update_state(self.run_id, turn, command, result, answer)

def run_was_success(self):
self.log_db.run_was_success(self.run_id)

def run_was_failure(self, reason: str):
self.log_db.run_was_failure(self.run_id, reason)

def status_message(self, message: str):
self.log_db.add_log_message(self.run_id, "status", message, 0, 0, 0)


@dataclass
Expand All @@ -34,17 +68,18 @@ class UseCase(abc.ABC):
console: Console
tag: str = ""

_run_id: int = 0
_log: Logger = None

def init(self):
def init(self, configuration):
"""
The init method is called before the run method. It is used to initialize the UseCase, and can be used to
perform any dynamic setup that is needed before the run method is called. One of the most common use cases is
setting up the llm capabilities from the tools that were injected.
"""
self._run_id = self.log_db.create_new_run(self.get_name(), self.tag)
self._log = Logger(self.log_db, self.console, self.tag, self._run_id)
self._log = Logger(self.log_db, self.console, self.get_name(), self.tag, self.serialize_configuration(configuration))

def serialize_configuration(self, configuration) -> str:
return json.dumps(configuration)

@abc.abstractmethod
def run(self):
Expand Down Expand Up @@ -85,26 +120,30 @@ def run(self):
self.before_run()

turn = 1
while turn <= self.max_turns and not self._got_root:
self._log.console.log(f"[yellow]Starting turn {turn} of {self.max_turns}")
try:
while turn <= self.max_turns and not self._got_root:
self._log.console.log(f"[yellow]Starting turn {turn} of {self.max_turns}")

self._got_root = self.perform_round(turn)
self._got_root = self.perform_round(turn)

# finish turn and commit logs to storage
self._log.log_db.commit()
turn += 1
# finish turn and commit logs to storage
self._log.log_db.commit()
turn += 1

self.after_run()
self.after_run()

# write the final result to the database and console
if self._got_root:
self._log.log_db.run_was_success(self._run_id, turn)
self._log.console.print(Panel("[bold green]Got Root!", title="Run finished"))
else:
self._log.log_db.run_was_failure(self._run_id, turn)
self._log.console.print(Panel("[green]maximum turn number reached", title="Run finished"))
# write the final result to the database and console
if self._got_root:
self._log.run_was_success()
self._log.console.print(Panel("[bold green]Got Root!", title="Run finished"))
else:
self._log.run_was_failure("maximum turn number reached")
self._log.console.print(Panel("[green]maximum turn number reached", title="Run finished"))

return self._got_root
return self._got_root
except Exception as e:
self._log.run_was_failure(f"exception occurred: {e}")
raise


@dataclass
Expand Down Expand Up @@ -149,17 +188,17 @@ def __class_getitem__(cls, item):
class AutonomousAgentUseCase(AutonomousUseCase):
agent: transparent(item) = None

def init(self):
super().init()
def init(self, configuration):
super().init(configuration)
self.agent._log = self._log
self.agent.init()

def get_name(self) -> str:
return self.__class__.__name__

def before_run(self):
return self.agent.before_run()

def after_run(self):
return self.agent.after_run()

Expand Down
2 changes: 1 addition & 1 deletion src/hackingBuddyGPT/usecases/minimal/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def perform_round(self, turn: int) -> bool:
result, got_root = self.get_capability(cmd.split(" ", 1)[0])(cmd)

# log and output the command and its result
self._log.log_db.add_log_query(self._log.run_id, turn, cmd, result, answer)
self._log.add_log_query(turn, cmd, result, answer)
self._sliding_history.add_command(cmd, result)
self._log.console.print(Panel(result, title=f"[bold cyan]{cmd}"))

Expand Down
6 changes: 3 additions & 3 deletions src/hackingBuddyGPT/usecases/privesc/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def perform_round(self, turn: int) -> bool:
capability, cmd, (result, got_root) = output[0]

# log and output the command and its result
self._log.log_db.add_log_query(self._log.run_id, turn, cmd, result, answer)
self._log.add_log_query(turn, cmd, result, answer)
if self._sliding_history:
self._sliding_history.add_command(cmd, result)

Expand All @@ -83,15 +83,15 @@ def perform_round(self, turn: int) -> bool:
if self.enable_explanation:
with self._log.console.status("[bold green]Analyze its result..."):
answer = self.analyze_result(cmd, result)
self._log.log_db.add_log_analyze_response(self._log.run_id, turn, cmd, answer.result, answer)
self._log.add_log_analyze_response(turn, cmd, answer.result, answer)

# .. and let our local model update its state
if self.enable_update_state:
# this must happen before the table output as we might include the
# status processing time in the table..
with self._log.console.status("[bold green]Updating fact list.."):
state = self.update_state(cmd, result)
self._log.log_db.add_log_update_state(self._log.run_id, turn, "", state.result, state)
self._log.add_log_update_state(turn, "", state.result, state)

# Output Round Data..
self._log.console.print(ui.get_history_table(self.enable_explanation, self.enable_update_state, self._log.run_id, self._log.log_db, turn))
Expand Down
16 changes: 8 additions & 8 deletions src/hackingBuddyGPT/usecases/privesc/linux.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ class LinuxPrivescUseCase(AutonomousAgentUseCase[LinuxPrivesc]):
class LinuxPrivescWithHintFileUseCase(AutonomousAgentUseCase[LinuxPrivesc]):
hints: str = None

def init(self):
super().init()
def init(self, configuration=None):
super().init(configuration)
self.agent.hint = self.read_hint()

# simple helper that reads the hints file and returns the hint
Expand Down Expand Up @@ -64,8 +64,8 @@ class LinuxPrivescWithLSEUseCase(UseCase):
# use either an use-case or an agent to perform the privesc
_use_use_case: bool = False

def init(self):
super().init()
def init(self, configuration=None):
super().init(configuration)

# simple helper that uses lse.sh to get hints from the system
def call_lse_against_host(self):
Expand All @@ -79,11 +79,11 @@ def call_lse_against_host(self):
cmd = self.llm.get_response(template_lse, lse_output=result, number=3)
self.console.print("[yellow]got the cmd: " + cmd.result)

return [x for x in cmd.result.splitlines() if x.strip()]
return [x for x in cmd.result.splitlines() if x.strip()]

def get_name(self) -> str:
return self.__class__.__name__

def run(self):
# get the hints through running LSE on the target system
hints = self.call_lse_against_host()
Expand Down Expand Up @@ -114,7 +114,7 @@ def run_using_usecases(self, hint, turns_per_hint):
)
linux_privesc.init()
return linux_privesc.run()

def run_using_agent(self, hint, turns_per_hint):
# init agent
agent = LinuxPrivesc(
Expand All @@ -138,7 +138,7 @@ def run_using_agent(self, hint, turns_per_hint):
if agent.perform_round(turn) is True:
got_root = True
turn += 1

# cleanup and finish
agent.after_run()
return got_root
2 changes: 1 addition & 1 deletion src/hackingBuddyGPT/usecases/web/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def perform_round(self, turn: int):
self._log.console.print(Panel(result, title="tool"))
self._prompt_history.append(tool_message(result, tool_call_id))

self._log.log_db.add_log_query(self._log.run_id, turn, command, result, answer)
self._log.add_log_query(turn, command, result, answer)
return self._all_flags_found


Expand Down
16 changes: 4 additions & 12 deletions src/hackingBuddyGPT/usecases/web/with_explanation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import List, Any, Union, Dict

from openai.types.chat import ChatCompletionMessageParam, ChatCompletionMessage
from rich.panel import Panel

from hackingBuddyGPT.capabilities import Capability
from hackingBuddyGPT.capabilities.http_request import HTTPRequest
Expand Down Expand Up @@ -49,31 +48,24 @@ def init(self):
}

def all_flags_found(self):
self._log.console.print(Panel("All flags found! Congratulations!", title="system"))
self._log.status_message("All flags found! Congratulations!")
self._all_flags_found = True

def perform_round(self, turn: int):
prompt = self._prompt_history # TODO: in the future, this should do some context truncation

result: LLMResult = None
stream = self.llm.stream_response(prompt, self._log.console, capabilities=self._capabilities)
for part in stream:
result = part

result: LLMResult = self.llm.stream_response(prompt, self._log.console, capabilities=self._capabilities)
message: ChatCompletionMessage = result.result
message_id = self._log.log_db.add_log_message(self._log.run_id, message.role, message.content, result.tokens_query, result.tokens_response, result.duration)
message_id = self._log.add_log_message(message.role, message.content, result.tokens_query, result.tokens_response, result.duration)
self._prompt_history.append(result.result)

if message.tool_calls is not None:
for tool_call in message.tool_calls:
tic = time.perf_counter()
tool_call_result = self._capabilities[tool_call.function.name].to_model().model_validate_json(tool_call.function.arguments).execute()
toc = time.perf_counter()

self._log.console.print(f"\n[bold green on gray3]{' '*self._log.console.width}\nTOOL RESPONSE:[/bold green on gray3]")
self._log.console.print(tool_call_result)
self._prompt_history.append(tool_message(tool_call_result, tool_call.id))
self._log.log_db.add_log_tool_call(self._log.run_id, message_id, tool_call.id, tool_call.function.name, tool_call.function.arguments, tool_call_result, toc - tic)
self._log.add_log_tool_call(message_id, tool_call.id, tool_call.function.name, tool_call.function.arguments, tool_call_result, toc - tic)

return self._all_flags_found

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,4 +131,4 @@ def has_no_numbers(self, path):

@use_case("Minimal implementation of a web API testing use case")
class SimpleWebAPIDocumentationUseCase(AutonomousAgentUseCase[SimpleWebAPIDocumentation]):
pass
pass
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dataclasses import dataclass, field
from dataclasses import field
from typing import List, Any, Union, Dict

import pydantic_core
Expand Down Expand Up @@ -137,4 +137,4 @@ def _handle_response(self, completion, response):
return self.all_http_methods_found()
@use_case("Minimal implementation of a web API testing use case")
class SimpleWebAPITestingUseCase(AutonomousAgentUseCase[SimpleWebAPITesting]):
pass
pass
19 changes: 9 additions & 10 deletions src/hackingBuddyGPT/utils/db_storage/db_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

@configurable("db_storage", "Stores the results of the experiments in a SQLite database")
class DbStorage:
def __init__(self, connection_string: str = parameter(desc="sqlite3 database connection string for logs", default=":memory:")):
def __init__(self, connection_string: str = parameter(desc="sqlite3 database connection string for logs", default="wintermute.sqlite3")):
self.connection_string = connection_string

def init(self):
Expand Down Expand Up @@ -37,7 +37,6 @@ def setup_db(self):
tag TEXT,
started_at text,
stopped_at text,
rounds INTEGER,
configuration TEXT
)""")
self.cursor.execute("""CREATE TABLE IF NOT EXISTS commands (
Expand Down Expand Up @@ -80,10 +79,10 @@ def setup_db(self):
self.analyze_response_id = self.insert_or_select_cmd('analyze_response')
self.state_update_id = self.insert_or_select_cmd('update_state')

def create_new_run(self, model, tag):
def create_new_run(self, model: str, tag: str, configuration: str) -> int:
self.cursor.execute(
"INSERT INTO runs (model, state, tag, started_at) VALUES (?, ?, ?, datetime('now'))",
(model, "in progress", tag))
"INSERT INTO runs (model, state, tag, started_at, configuration) VALUES (?, ?, ?, datetime('now'), ?)",
(model, "in progress", tag, configuration))
return self.cursor.lastrowid

def add_log_query(self, run_id, round, cmd, result, answer):
Expand Down Expand Up @@ -194,14 +193,14 @@ def get_cmd_history(self, run_id):

return result

def run_was_success(self, run_id, round):
self.cursor.execute("update runs set state=?,stopped_at=datetime('now'), rounds=? where id = ?",
def run_was_success(self, run_id):
    """Mark the run identified by ``run_id`` as successful and commit.

    Sets the row's ``state`` to ``"got root"`` and its ``stopped_at`` to the
    current time (SQLite ``datetime('now')``), then commits the connection.

    Args:
        run_id: primary key of the row in the ``runs`` table to update.
    """
    # The statement has exactly two ``?`` placeholders (state, id). This
    # method no longer takes a ``round`` argument, so a leftover ``round`` in
    # the bindings would resolve to the builtin function and supply three
    # values for two placeholders, making sqlite3 raise.
    self.cursor.execute(
        "update runs set state=?,stopped_at=datetime('now') where id = ?",
        ("got root", run_id),
    )
    self.db.commit()

def run_was_failure(self, run_id, round):
self.cursor.execute("update runs set state=?, stopped_at=datetime('now'), rounds=? where id = ?",
("reached max runs", round, run_id))
def run_was_failure(self, run_id: int, reason: str):
self.cursor.execute("update runs set state=?, stopped_at=datetime('now') where id = ?",
(reason, run_id))
self.db.commit()

def commit(self):
Expand Down
Loading

0 comments on commit 5edb6f7

Please sign in to comment.