From 5edb6f759e9c99007ea6200c3f50309a4ddecb11 Mon Sep 17 00:00:00 2001
From: Neverbolt
Date: Fri, 26 Jul 2024 13:09:48 +0200
Subject: [PATCH] Switches logging to Logger functionality from Database

Also changes default value for DbStorage.connection_string from ":memory:" to "wintermute.sqlite3"
---
 src/hackingBuddyGPT/cli/wintermute.py         | 12 +--
 src/hackingBuddyGPT/usecases/agents.py        |  2 +-
 src/hackingBuddyGPT/usecases/base.py          | 89 +++++++++++++------
 src/hackingBuddyGPT/usecases/minimal/agent.py |  2 +-
 .../usecases/privesc/common.py                |  6 +-
 src/hackingBuddyGPT/usecases/privesc/linux.py | 16 ++--
 src/hackingBuddyGPT/usecases/web/simple.py    |  2 +-
 .../usecases/web/with_explanation.py          | 16 +---
 .../simple_openapi_documentation.py           |  2 +-
 .../web_api_testing/simple_web_api_testing.py |  4 +-
 .../utils/db_storage/db_storage.py            | 19 ++--
 .../utils/openai/openai_lib.py                | 11 ++-
 12 files changed, 109 insertions(+), 72 deletions(-)

diff --git a/src/hackingBuddyGPT/cli/wintermute.py b/src/hackingBuddyGPT/cli/wintermute.py
index 4f6f0c1..c0badf1 100644
--- a/src/hackingBuddyGPT/cli/wintermute.py
+++ b/src/hackingBuddyGPT/cli/wintermute.py
@@ -8,15 +8,15 @@ def main():
     parser = argparse.ArgumentParser()
     subparser = parser.add_subparsers(required=True)
 
     for name, use_case in use_cases.items():
-        subb = subparser.add_parser(
+        use_case.build_parser(subparser.add_parser(
             name=use_case.name,
             help=use_case.description
-        )
-        use_case.build_parser(subb)
-    x= sys.argv[1:]
-    parsed = parser.parse_args(x)
+        ))
+
+    parsed = parser.parse_args(sys.argv[1:])
+    configuration = {k: v for k, v in vars(parsed).items() if k != "use_case"}
 
     instance = parsed.use_case(parsed)
-    instance.init()
+    instance.init(configuration=configuration)
     instance.run()
diff --git a/src/hackingBuddyGPT/usecases/agents.py b/src/hackingBuddyGPT/usecases/agents.py
index a018b58..0985e64 100644
--- a/src/hackingBuddyGPT/usecases/agents.py
+++ b/src/hackingBuddyGPT/usecases/agents.py
@@ -93,7 +93,7 @@ def perform_round(self, turn:int) -> bool:
         result, got_root = capability(cmd)
 
         # log and output the command and its result
-        self._log.log_db.add_log_query(self._log.run_id, turn, cmd, result, answer)
+        self._log.add_log_query(turn, cmd, result, answer)
         self._state.update(capability, cmd, result) # TODO output/log new state
         self._log.console.print(Panel(result, title=f"[bold cyan]{cmd}"))
 
diff --git a/src/hackingBuddyGPT/usecases/base.py b/src/hackingBuddyGPT/usecases/base.py
index 459db92..3bb992e 100644
--- a/src/hackingBuddyGPT/usecases/base.py
+++ b/src/hackingBuddyGPT/usecases/base.py
@@ -1,10 +1,12 @@
 import abc
+import json
 import argparse
 import typing
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from rich.panel import Panel
 from typing import Dict, Type
 
+from hackingBuddyGPT.utils import LLMResult
 from hackingBuddyGPT.utils.configurable import ParameterDefinitions, build_parser, get_arguments, get_class_parameters, transparent
 from hackingBuddyGPT.utils.console.console import Console
 from hackingBuddyGPT.utils.db_storage.db_storage import DbStorage
@@ -14,8 +16,40 @@
 class Logger:
     log_db: DbStorage
     console: Console
+    model: str = ""
     tag: str = ""
-    run_id: int = 0
+    configuration: str = ""
+
+    run_id: int = field(init=False, default=None)
+
+    def __post_init__(self):
+        self.run_id = self.log_db.create_new_run(self.model, self.tag, self.configuration)
+
+    def add_log_query(self, turn: int, command: str, result: str, answer: LLMResult):
+        self.log_db.add_log_query(self.run_id, turn, command, result, answer)
+
+    def 
add_log_message(self, role: str, content: str, tokens_query: int, tokens_response: int, duration: float) -> int: + return self.log_db.add_log_message(self.run_id, role, content, tokens_query, tokens_response, duration) + + def add_log_tool_call(self, message_id: int, tool_call_id: str, function_name: str, arguments: str, result_text: str, duration: float): + self.console.print(f"\n[bold green on gray3]{' ' * self.console.width}\nTOOL RESPONSE:[/bold green on gray3]") + self.console.print(result_text) + self.log_db.add_log_tool_call(self.run_id, message_id, tool_call_id, function_name, arguments, result_text, duration) + + def add_log_analyze_response(self, turn: int, command: str, result: str, answer: LLMResult): + self.log_db.add_log_analyze_response(self.run_id, turn, command, result, answer) + + def add_log_update_state(self, turn: int, command: str, result: str, answer: LLMResult): + self.log_db.add_log_update_state(self.run_id, turn, command, result, answer) + + def run_was_success(self): + self.log_db.run_was_success(self.run_id) + + def run_was_failure(self, reason: str): + self.log_db.run_was_failure(self.run_id, reason) + + def status_message(self, message: str): + self.log_db.add_log_message(self.run_id, "status", message, 0, 0, 0) @dataclass @@ -34,17 +68,18 @@ class UseCase(abc.ABC): console: Console tag: str = "" - _run_id: int = 0 _log: Logger = None - def init(self): + def init(self, configuration): """ The init method is called before the run method. It is used to initialize the UseCase, and can be used to perform any dynamic setup that is needed before the run method is called. One of the most common use cases is setting up the llm capabilities from the tools that were injected. """ - self._run_id = self.log_db.create_new_run(self.get_name(), self.tag) - self._log = Logger(self.log_db, self.console, self.tag, self._run_id) + self._log = Logger(self.log_db, self.console, self.get_name(), self.tag, self.serialize_configuration(configuration)) + + def serialize_configuration(self, configuration) -> str: + return json.dumps(configuration) @abc.abstractmethod def run(self): @@ -85,26 +120,30 @@ def run(self): self.before_run() turn = 1 - while turn <= self.max_turns and not self._got_root: - self._log.console.log(f"[yellow]Starting turn {turn} of {self.max_turns}") + try: + while turn <= self.max_turns and not self._got_root: + self._log.console.log(f"[yellow]Starting turn {turn} of {self.max_turns}") - self._got_root = self.perform_round(turn) + self._got_root = self.perform_round(turn) - # finish turn and commit logs to storage - self._log.log_db.commit() - turn += 1 + # finish turn and commit logs to storage + self._log.log_db.commit() + turn += 1 - self.after_run() + self.after_run() - # write the final result to the database and console - if self._got_root: - self._log.log_db.run_was_success(self._run_id, turn) - self._log.console.print(Panel("[bold green]Got Root!", title="Run finished")) - else: - self._log.log_db.run_was_failure(self._run_id, turn) - self._log.console.print(Panel("[green]maximum turn number reached", title="Run finished")) + # write the final result to the database and console + if self._got_root: + self._log.run_was_success() + self._log.console.print(Panel("[bold green]Got Root!", title="Run finished")) + else: + self._log.run_was_failure("maximum turn number reached") + self._log.console.print(Panel("[green]maximum turn number reached", title="Run finished")) - return self._got_root + return self._got_root + except Exception as e: + 
self._log.run_was_failure(f"exception occurred: {e}") + raise @dataclass @@ -149,17 +188,17 @@ def __class_getitem__(cls, item): class AutonomousAgentUseCase(AutonomousUseCase): agent: transparent(item) = None - def init(self): - super().init() + def init(self, configuration): + super().init(configuration) self.agent._log = self._log self.agent.init() def get_name(self) -> str: return self.__class__.__name__ - + def before_run(self): return self.agent.before_run() - + def after_run(self): return self.agent.after_run() diff --git a/src/hackingBuddyGPT/usecases/minimal/agent.py b/src/hackingBuddyGPT/usecases/minimal/agent.py index e7e6442..0b2e1f0 100644 --- a/src/hackingBuddyGPT/usecases/minimal/agent.py +++ b/src/hackingBuddyGPT/usecases/minimal/agent.py @@ -40,7 +40,7 @@ def perform_round(self, turn: int) -> bool: result, got_root = self.get_capability(cmd.split(" ", 1)[0])(cmd) # log and output the command and its result - self._log.log_db.add_log_query(self._log.run_id, turn, cmd, result, answer) + self._log.add_log_query(turn, cmd, result, answer) self._sliding_history.add_command(cmd, result) self._log.console.print(Panel(result, title=f"[bold cyan]{cmd}")) diff --git a/src/hackingBuddyGPT/usecases/privesc/common.py b/src/hackingBuddyGPT/usecases/privesc/common.py index 0760082..4dd20d7 100644 --- a/src/hackingBuddyGPT/usecases/privesc/common.py +++ b/src/hackingBuddyGPT/usecases/privesc/common.py @@ -73,7 +73,7 @@ def perform_round(self, turn: int) -> bool: capability, cmd, (result, got_root) = output[0] # log and output the command and its result - self._log.log_db.add_log_query(self._log.run_id, turn, cmd, result, answer) + self._log.add_log_query(turn, cmd, result, answer) if self._sliding_history: self._sliding_history.add_command(cmd, result) @@ -83,7 +83,7 @@ def perform_round(self, turn: int) -> bool: if self.enable_explanation: with self._log.console.status("[bold green]Analyze its result..."): answer = self.analyze_result(cmd, result) - self._log.log_db.add_log_analyze_response(self._log.run_id, turn, cmd, answer.result, answer) + self._log.add_log_analyze_response(turn, cmd, answer.result, answer) # .. and let our local model update its state if self.enable_update_state: @@ -91,7 +91,7 @@ def perform_round(self, turn: int) -> bool: # status processing time in the table.. with self._log.console.status("[bold green]Updating fact list.."): state = self.update_state(cmd, result) - self._log.log_db.add_log_update_state(self._log.run_id, turn, "", state.result, state) + self._log.add_log_update_state(turn, "", state.result, state) # Output Round Data.. 
self._log.console.print(ui.get_history_table(self.enable_explanation, self.enable_update_state, self._log.run_id, self._log.log_db, turn)) diff --git a/src/hackingBuddyGPT/usecases/privesc/linux.py b/src/hackingBuddyGPT/usecases/privesc/linux.py index 7d2a7c4..f3639ce 100644 --- a/src/hackingBuddyGPT/usecases/privesc/linux.py +++ b/src/hackingBuddyGPT/usecases/privesc/linux.py @@ -31,8 +31,8 @@ class LinuxPrivescUseCase(AutonomousAgentUseCase[LinuxPrivesc]): class LinuxPrivescWithHintFileUseCase(AutonomousAgentUseCase[LinuxPrivesc]): hints: str = None - def init(self): - super().init() + def init(self, configuration=None): + super().init(configuration) self.agent.hint = self.read_hint() # simple helper that reads the hints file and returns the hint @@ -64,8 +64,8 @@ class LinuxPrivescWithLSEUseCase(UseCase): # use either an use-case or an agent to perform the privesc _use_use_case: bool = False - def init(self): - super().init() + def init(self, configuration=None): + super().init(configuration) # simple helper that uses lse.sh to get hints from the system def call_lse_against_host(self): @@ -79,11 +79,11 @@ def call_lse_against_host(self): cmd = self.llm.get_response(template_lse, lse_output=result, number=3) self.console.print("[yellow]got the cmd: " + cmd.result) - return [x for x in cmd.result.splitlines() if x.strip()] + return [x for x in cmd.result.splitlines() if x.strip()] def get_name(self) -> str: return self.__class__.__name__ - + def run(self): # get the hints through running LSE on the target system hints = self.call_lse_against_host() @@ -114,7 +114,7 @@ def run_using_usecases(self, hint, turns_per_hint): ) linux_privesc.init() return linux_privesc.run() - + def run_using_agent(self, hint, turns_per_hint): # init agent agent = LinuxPrivesc( @@ -138,7 +138,7 @@ def run_using_agent(self, hint, turns_per_hint): if agent.perform_round(turn) is True: got_root = True turn += 1 - + # cleanup and finish agent.after_run() return got_root diff --git a/src/hackingBuddyGPT/usecases/web/simple.py b/src/hackingBuddyGPT/usecases/web/simple.py index 22152b5..16f0bd4 100644 --- a/src/hackingBuddyGPT/usecases/web/simple.py +++ b/src/hackingBuddyGPT/usecases/web/simple.py @@ -76,7 +76,7 @@ def perform_round(self, turn: int): self._log.console.print(Panel(result, title="tool")) self._prompt_history.append(tool_message(result, tool_call_id)) - self._log.log_db.add_log_query(self._log.run_id, turn, command, result, answer) + self._log.add_log_query(turn, command, result, answer) return self._all_flags_found diff --git a/src/hackingBuddyGPT/usecases/web/with_explanation.py b/src/hackingBuddyGPT/usecases/web/with_explanation.py index 96dd657..f2eba71 100644 --- a/src/hackingBuddyGPT/usecases/web/with_explanation.py +++ b/src/hackingBuddyGPT/usecases/web/with_explanation.py @@ -3,7 +3,6 @@ from typing import List, Any, Union, Dict from openai.types.chat import ChatCompletionMessageParam, ChatCompletionMessage -from rich.panel import Panel from hackingBuddyGPT.capabilities import Capability from hackingBuddyGPT.capabilities.http_request import HTTPRequest @@ -49,19 +48,15 @@ def init(self): } def all_flags_found(self): - self._log.console.print(Panel("All flags found! Congratulations!", title="system")) + self._log.status_message("All flags found! 
Congratulations!") self._all_flags_found = True def perform_round(self, turn: int): prompt = self._prompt_history # TODO: in the future, this should do some context truncation - result: LLMResult = None - stream = self.llm.stream_response(prompt, self._log.console, capabilities=self._capabilities) - for part in stream: - result = part - + result: LLMResult = self.llm.stream_response(prompt, self._log.console, capabilities=self._capabilities) message: ChatCompletionMessage = result.result - message_id = self._log.log_db.add_log_message(self._log.run_id, message.role, message.content, result.tokens_query, result.tokens_response, result.duration) + message_id = self._log.add_log_message(message.role, message.content, result.tokens_query, result.tokens_response, result.duration) self._prompt_history.append(result.result) if message.tool_calls is not None: @@ -69,11 +64,8 @@ def perform_round(self, turn: int): tic = time.perf_counter() tool_call_result = self._capabilities[tool_call.function.name].to_model().model_validate_json(tool_call.function.arguments).execute() toc = time.perf_counter() - - self._log.console.print(f"\n[bold green on gray3]{' '*self._log.console.width}\nTOOL RESPONSE:[/bold green on gray3]") - self._log.console.print(tool_call_result) self._prompt_history.append(tool_message(tool_call_result, tool_call.id)) - self._log.log_db.add_log_tool_call(self._log.run_id, message_id, tool_call.id, tool_call.function.name, tool_call.function.arguments, tool_call_result, toc - tic) + self._log.add_log_tool_call(message_id, tool_call.id, tool_call.function.name, tool_call.function.arguments, tool_call_result, toc - tic) return self._all_flags_found diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/simple_openapi_documentation.py b/src/hackingBuddyGPT/usecases/web_api_testing/simple_openapi_documentation.py index 84589cf..153d3e6 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/simple_openapi_documentation.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/simple_openapi_documentation.py @@ -131,4 +131,4 @@ def has_no_numbers(self, path): @use_case("Minimal implementation of a web API testing use case") class SimpleWebAPIDocumentationUseCase(AutonomousAgentUseCase[SimpleWebAPIDocumentation]): - pass \ No newline at end of file + pass diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/simple_web_api_testing.py b/src/hackingBuddyGPT/usecases/web_api_testing/simple_web_api_testing.py index 3f8e1dd..20b3af6 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/simple_web_api_testing.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/simple_web_api_testing.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass, field +from dataclasses import field from typing import List, Any, Union, Dict import pydantic_core @@ -137,4 +137,4 @@ def _handle_response(self, completion, response): return self.all_http_methods_found() @use_case("Minimal implementation of a web API testing use case") class SimpleWebAPITestingUseCase(AutonomousAgentUseCase[SimpleWebAPITesting]): - pass \ No newline at end of file + pass diff --git a/src/hackingBuddyGPT/utils/db_storage/db_storage.py b/src/hackingBuddyGPT/utils/db_storage/db_storage.py index 497c023..7c30ef1 100644 --- a/src/hackingBuddyGPT/utils/db_storage/db_storage.py +++ b/src/hackingBuddyGPT/utils/db_storage/db_storage.py @@ -5,7 +5,7 @@ @configurable("db_storage", "Stores the results of the experiments in a SQLite database") class DbStorage: - def __init__(self, connection_string: str = parameter(desc="sqlite3 database 
connection string for logs", default=":memory:")):
+    def __init__(self, connection_string: str = parameter(desc="sqlite3 database connection string for logs", default="wintermute.sqlite3")):
         self.connection_string = connection_string
 
     def init(self):
@@ -37,7 +37,6 @@ def setup_db(self):
             tag TEXT,
             started_at text,
             stopped_at text,
-            rounds INTEGER,
             configuration TEXT
         )""")
         self.cursor.execute("""CREATE TABLE IF NOT EXISTS commands (
@@ -80,10 +79,10 @@ def setup_db(self):
         self.analyze_response_id = self.insert_or_select_cmd('analyze_response')
         self.state_update_id = self.insert_or_select_cmd('update_state')
 
-    def create_new_run(self, model, tag):
+    def create_new_run(self, model: str, tag: str, configuration: str) -> int:
         self.cursor.execute(
-            "INSERT INTO runs (model, state, tag, started_at) VALUES (?, ?, ?, datetime('now'))",
-            (model, "in progress", tag))
+            "INSERT INTO runs (model, state, tag, started_at, configuration) VALUES (?, ?, ?, datetime('now'), ?)",
+            (model, "in progress", tag, configuration))
         return self.cursor.lastrowid
 
     def add_log_query(self, run_id, round, cmd, result, answer):
@@ -194,14 +193,14 @@ def get_cmd_history(self, run_id):
 
         return result
 
-    def run_was_success(self, run_id, round):
-        self.cursor.execute("update runs set state=?,stopped_at=datetime('now'), rounds=? where id = ?",
-                            ("got root", round, run_id))
+    def run_was_success(self, run_id):
+        self.cursor.execute("update runs set state=?,stopped_at=datetime('now') where id = ?",
+                            ("got root", run_id))
         self.db.commit()
 
-    def run_was_failure(self, run_id, round):
-        self.cursor.execute("update runs set state=?, stopped_at=datetime('now'), rounds=? where id = ?",
-        ("reached max runs", round, run_id))
+    def run_was_failure(self, run_id: int, reason: str):
+        self.cursor.execute("update runs set state=?, stopped_at=datetime('now') where id = ?",
+        (reason, run_id))
         self.db.commit()
 
     def commit(self):
diff --git a/src/hackingBuddyGPT/utils/openai/openai_lib.py b/src/hackingBuddyGPT/utils/openai/openai_lib.py
index 3e6f8da..4ab8ed1 100644
--- a/src/hackingBuddyGPT/utils/openai/openai_lib.py
+++ b/src/hackingBuddyGPT/utils/openai/openai_lib.py
@@ -75,7 +75,15 @@ def get_response(self, prompt, *, capabilities: Dict[str, Capability]=None, **kw
             response.usage.completion_tokens,
         )
 
-    def stream_response(self, prompt: Iterable[ChatCompletionMessageParam], console: Console, capabilities: Dict[str, Capability] = None) -> Iterable[Union[ChatCompletionChunk, LLMResult]]:
+    def stream_response(self, prompt: Iterable[ChatCompletionMessageParam], console: Console, capabilities: Dict[str, Capability] = None, get_individual_updates=False) -> Union[LLMResult, Iterable[Union[ChatCompletionChunk, LLMResult]]]:
+        generator = self._stream_response(prompt, console, capabilities)
+
+        if get_individual_updates:
+            return generator
+
+        return list(generator)[-1]
+
+    def _stream_response(self, prompt: Iterable[ChatCompletionMessageParam], console: Console, capabilities: Dict[str, Capability] = None) -> Iterable[Union[ChatCompletionChunk, LLMResult]]:
         tools = None
         if capabilities:
             tools = capabilities_to_tools(capabilities)
@@ -149,7 +157,6 @@ def stream_response(self, prompt: Iterable[ChatCompletionMessageParam], console:
                 usage.prompt_tokens,
                 usage.completion_tokens,
             )
-            pass
 
     def encode(self, query) -> list[int]:
         return tiktoken.encoding_for_model(self.model).encode(query)
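
Reviewer note (not part of the patch): a minimal sketch of the call pattern this change moves the code to. Use-case and agent code now logs through the Logger wrapper, which creates the run row in __post_init__ and keeps the run_id itself, instead of calling DbStorage methods with an explicit run_id. The wiring below (explicit connection string, argument-less Console(), the model/tag/configuration values) is illustrative only and not taken from the patch.

    from hackingBuddyGPT.usecases.base import Logger
    from hackingBuddyGPT.utils.console.console import Console
    from hackingBuddyGPT.utils.db_storage.db_storage import DbStorage

    # illustrative wiring; in real runs these objects are injected by the use-case configuration
    db = DbStorage(connection_string="wintermute.sqlite3")  # the parameter's default now points here instead of ":memory:"
    db.init()
    console = Console()  # constructor arguments may differ in practice

    # constructing the Logger creates the run row and remembers its id (see __post_init__);
    # model carries the use case's get_name(), configuration the JSON-serialized CLI arguments
    log = Logger(db, console, model="LinuxPrivescUseCase", tag="demo", configuration="{}")

    log.status_message("starting demo run")              # stored as a "status" log message for this run
    log.run_was_failure("maximum turn number reached")   # free-form reason replaces the old rounds counter
    db.commit()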