Skip to content

Commit

Permalink
feat(autofix): Support iterative user feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
jennmueng committed Jun 7, 2024
1 parent 73ad8af commit cd3887b
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 51 deletions.
6 changes: 5 additions & 1 deletion src/seer/automation/autofix/components/planner/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ def format_default_msg(
</issue>
You have to break the below task into steps:
<task>
{task_str}
</task>
Think step-by-step inside the <thoughts> tag then output a concise and simple list of steps to perform in the output format provided in the system message."""
).format(
Expand Down Expand Up @@ -81,8 +83,10 @@ def format_instruction_msg(
The following changes have been made to the codebase to fix the issue:
{changes_str}
You are given the following instruction and you have to break it into steps:
You are given the following instruction in relationship to the above changes and you have to break it into steps:
<instruction>
{instruction}
</instruction>
Think step-by-step inside the <thoughts> tag then output a concise and simple list of steps to perform in the output format provided in the system message."""
).format(
Expand Down
38 changes: 18 additions & 20 deletions src/seer/automation/autofix/event_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,20 @@ class AutofixEventManager:
@property
def root_cause_analysis_processing_step(self) -> DefaultStep:
return DefaultStep(
id="root_cause_analysis_processing",
key="root_cause_analysis_processing",
title="Analyze Issue",
)

@property
def root_cause_analysis_step(self) -> RootCauseStep:
return RootCauseStep(
id="root_cause_analysis",
key="root_cause_analysis",
title="Root Cause Analysis",
)

@property
def indexing_step(self) -> DefaultStep:
return DefaultStep(
id="codebase_indexing",
key="codebase_indexing",
title="Codebase Indexing",
)
Expand All @@ -69,6 +66,7 @@ def user_response_step(self) -> UserResponseStep:
return UserResponseStep(
title="User",
text="",
user_id=-1,
key="user_response",
)

Expand All @@ -91,7 +89,7 @@ def send_root_cause_analysis_result(self, root_cause_output: RootCauseAnalysisOu
with self.state.update() as cur:
root_cause_processing_step = cur.find_or_add(self.root_cause_analysis_processing_step)
root_cause_processing_step.status = AutofixStatus.COMPLETED
root_cause_step = cur.find_or_add(self.root_cause_analysis_step)
root_cause_step = cur.add_step(self.root_cause_analysis_step)
if root_cause_output and root_cause_output.causes:
root_cause_step.status = AutofixStatus.COMPLETED
root_cause_step.causes = root_cause_output.causes
Expand All @@ -103,14 +101,14 @@ def send_root_cause_analysis_result(self, root_cause_output: RootCauseAnalysisOu

def send_codebase_indexing_start(self):
with self.state.update() as cur:
indexing_step = cur.find_or_add(self.indexing_step)
indexing_step = cur.add_step(self.indexing_step.model_copy())
indexing_step.status = AutofixStatus.PROCESSING

cur.status = AutofixStatus.PROCESSING

def send_codebase_indexing_complete_if_exists(self):
with self.state.update() as cur:
indexing_step = cur.find_step(id=self.indexing_step.id)
indexing_step = cur.find_step(key=self.indexing_step.key)

if indexing_step:
indexing_step.status = AutofixStatus.COMPLETED
Expand All @@ -122,23 +120,19 @@ def set_selected_root_cause(self, selection: RootCauseSelection):

cur.status = AutofixStatus.PROCESSING

def send_planning_pending(self):
def send_planning_start(self, is_update: bool = False):
with self.state.update() as cur:
root_cause_step = cur.find_or_add(self.plan_step)
root_cause_step.status = AutofixStatus.PENDING

cur.status = AutofixStatus.PROCESSING

def send_planning_start(self):
with self.state.update() as cur:
plan_step = cur.find_or_add(self.plan_step, method="key")
plan_step = cur.last_or_add(self.plan_step)
plan_step.status = AutofixStatus.PROCESSING

if is_update:
plan_step.title = "Update Fix"

cur.status = AutofixStatus.PROCESSING

def send_planning_result(self, result: PlanningOutput | None):
with self.state.update() as cur:
plan_step = cur.find_or_add(self.plan_step, method="key")
plan_step = cur.find_or_add(self.plan_step)
plan_step.status = AutofixStatus.PROCESSING if result else AutofixStatus.ERROR

if result:
Expand All @@ -157,7 +151,7 @@ def send_planning_result(self, result: PlanningOutput | None):

def send_execution_step_start(self, execution_id: int):
with self.state.update() as cur:
plan_step = cur.find_or_add(self.plan_step, method="key")
plan_step = cur.find_or_add(self.plan_step)
execution_step = plan_step.find_child(id=str(execution_id))
if execution_step:
execution_step.status = AutofixStatus.PROCESSING
Expand All @@ -167,7 +161,7 @@ def send_execution_step_result(
self, execution_id: int, status: Literal[AutofixStatus.COMPLETED, AutofixStatus.ERROR]
):
with self.state.update() as cur:
plan_step = cur.find_or_add(self.plan_step, method="key")
plan_step = cur.find_or_add(self.plan_step)
execution_step = plan_step.find_child(id=str(execution_id))
if execution_step:
execution_step.status = status
Expand All @@ -182,7 +176,7 @@ def send_execution_complete(self, codebase_changes: list[CodebaseChange]):
with self.state.update() as cur:
cur.mark_all_steps_completed()

changes_step = cur.find_or_add(self.changes_step, method="key")
changes_step = cur.add_step(self.changes_step)
changes_step.status = AutofixStatus.COMPLETED
changes_step.changes = codebase_changes

Expand All @@ -200,11 +194,15 @@ def send_pr_creation_complete(self):
changes_step.status = AutofixStatus.COMPLETED
cur.status = AutofixStatus.COMPLETED

def send_user_response_step(self, text: str):
def send_user_response_step(self, user_id: int, text: str):
with self.state.update() as cur:
step = cur.add_step(self.user_response_step)
step.user_id = user_id
step.text = text
step.status = AutofixStatus.COMPLETED

cur.actor_ids = list(set(cur.actor_ids + [user_id]))

cur.status = AutofixStatus.PROCESSING

def add_log(self, message: str):
Expand Down
52 changes: 27 additions & 25 deletions src/seer/automation/autofix/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
import enum
import hashlib
import os
from typing import Annotated, Any, Literal, Optional, Union

from johen import gen
Expand Down Expand Up @@ -108,8 +109,10 @@ class StepType(str, enum.Enum):


class BaseStep(BaseModel):
id: str = Field(default_factory=lambda: hashlib.sha1().hexdigest())
key: str
id: str = Field(
default_factory=lambda: hashlib.sha1(os.urandom(16)).hexdigest()
) # Unique identifier for this step
key: str # Identifier for a type of step
title: str
type: StepType = StepType.DEFAULT

Expand Down Expand Up @@ -160,6 +163,7 @@ class UserResponseStep(BaseStep):
type: Literal[StepType.USER_RESPONSE] = StepType.USER_RESPONSE

text: str
user_id: int


Step = Union[DefaultStep, RootCauseStep, ChangesStep, UserResponseStep]
Expand All @@ -185,6 +189,7 @@ class AutofixGroupState(BaseModel):
] = None
completed_at: datetime.datetime | None = None
signals: list[str] = Field(default_factory=list)
actor_ids: list[int] = Field(default_factory=list)


class AutofixStateRequest(BaseModel):
Expand Down Expand Up @@ -273,6 +278,7 @@ class AutofixInstructionPayload(BaseModel):

class AutofixUpdateRequest(BaseModel):
run_id: int
invoking_user: AutofixUserDetails
payload: Union[
AutofixRootCauseUpdatePayload, AutofixCreatePrUpdatePayload, AutofixInstructionPayload
] = Field(discriminator="type")
Expand All @@ -281,37 +287,33 @@ class AutofixUpdateRequest(BaseModel):
class AutofixContinuation(AutofixGroupState):
request: AutofixRequest

def find_step(self, *, id: str) -> Step | None:
for step in self.steps:
if step.id == id:
def find_step(self, *, key: str | None = None, id: str | None = None) -> Step | None:
for step in reversed(self.steps):
if step.key == key or step.id == id:
return step
return None

def add_step(self, step: Step):
step.index = len(self.steps)
self.steps.append(step)

return step
def add_step(self, base_step: Step):
base_step = base_step.model_copy()
base_step.index = len(self.steps)
self.steps.append(base_step)

def find_last_with_key(self, *, key: str) -> Step | None:
for step in reversed(self.steps):
if step.key == key:
return step
return None
return base_step

def find_or_add(self, base_step: Step, method: Literal["id", "key"] = "id") -> Step:
existing = None
if method == "id":
existing = self.find_step(id=base_step.id)
elif method == "key":
existing = self.find_last_with_key(key=base_step.key)
def find_or_add(self, base_step: Step) -> Step:
existing = self.find_step(key=base_step.key)

if existing:
return existing

base_step = base_step.model_copy()
self.add_step(base_step)
return base_step
return self.add_step(base_step)

def last_or_add(self, base_step: Step) -> Step:
last_step = self.steps[-1]
if last_step and (last_step.id == base_step.id or last_step.key == base_step.key):
return last_step

return self.add_step(base_step)

def make_step_latest(self, step: Step):
if step in self.steps:
Expand Down Expand Up @@ -348,7 +350,7 @@ def set_last_step_completed_message(self, message: str):
self.steps[-1].completedMessage = message

def get_selected_root_cause_and_fix(self) -> RootCauseAnalysisItem | str | None:
root_cause_step = self.find_step(id="root_cause_analysis")
root_cause_step = self.find_step(key="root_cause_analysis")
if root_cause_step and isinstance(root_cause_step, RootCauseStep):
if root_cause_step.selection:
if isinstance(root_cause_step.selection, SuggestedFixRootCauseSelection):
Expand Down
5 changes: 1 addition & 4 deletions src/seer/automation/autofix/steps/planning_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,7 @@ def get_task():
def _invoke(self, **kwargs):
self.context.event_manager.send_codebase_indexing_complete_if_exists()

if self.request.instruction:
self.context.event_manager.send_update_planning_start()
else:
self.context.event_manager.send_planning_start()
self.context.event_manager.send_planning_start(is_update=bool(self.request.instruction))

if self.context.has_missing_codebase_indexes():
raise ValueError("Codebase indexes must be created before planning")
Expand Down
6 changes: 5 additions & 1 deletion src/seer/automation/autofix/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,14 +178,18 @@ def run_autofix_instruction(request: AutofixUpdateRequest):
cur.mark_triggered()

event_manager = AutofixEventManager(state)

context = AutofixContext(
state=state,
sentry_client=get_sentry_client(),
event_manager=event_manager,
skip_loading_codebase=True,
)

context.event_manager.send_user_response_step(request.payload.content.text)
context.event_manager.send_user_response_step(
request.invoking_user.id, request.payload.content.text
)
event_manager.send_planning_start(is_update=True)

AutofixPlanningStep.get_signature(
AutofixPlanningStepRequest(
Expand Down

0 comments on commit cd3887b

Please sign in to comment.