Skip to content

Commit

Permalink
Improve whodunnit
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 633613807
Change-Id: I9b3223ce963f9ce2e647bc93ae9a985023a51967
  • Loading branch information
jzleibo authored and copybara-github committed May 14, 2024
1 parent b5a0be1 commit 79c2043
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def update(self) -> None:
),
max_characters=2000,
max_tokens=1000,
terminators=(),
)
what_effect_it_had = what_they_did_chain_of_thought.open_question(
question=(
Expand All @@ -143,6 +144,7 @@ def update(self) -> None:
),
max_characters=2000,
max_tokens=1000,
terminators=(),
)
# Now consider how to justify the voluntary actions for all audiences.
justification_chain_of_thought = interactive_document.InteractiveDocument(
Expand All @@ -169,6 +171,7 @@ def update(self) -> None:
),
max_characters=3000,
max_tokens=2000,
terminators=(),
)
most_salient_justification = justification_chain_of_thought.open_question(
question=(
Expand All @@ -181,6 +184,7 @@ def update(self) -> None:
answer_prefix=f'{self._agent_name}',
max_characters=2000,
max_tokens=1000,
terminators=(),
)
salient_justification = (
f'[thought] {self._agent_name} {most_salient_justification}')
Expand Down
2 changes: 1 addition & 1 deletion concordia/utils/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def plot_df_pie(df: pd.DataFrame,
group_by: Group data by this field, plot each one in its own figure.
value: The name of the value to aggregate for the pie chart regions.
"""
cmap = mpl.colormaps['Paired']
cmap = mpl.colormaps['Paired'] # pylint: disable=unsubscriptable-object
colours = cmap(range(len(scale)))
scale_to_colour = dict(zip(scale, colours))

Expand Down
32 changes: 24 additions & 8 deletions examples/whodunnit.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,13 @@
" ),\n",
" name='role playing instructions\\n')\n",
"\n",
" if agent_config.extras.get('murderer', False):\n",
" fact = generic_components.constant.ConstantComponent(\n",
" state=f'{agent_name} murdered {VICTIM}.', name='fact')\n",
" else:\n",
" fact = generic_components.constant.ConstantComponent(\n",
" state=f'{agent_name} did not kill {VICTIM}.', name='fact')\n",
"\n",
" time = generic_components.report_function.ReportFunction(\n",
" name='Current time',\n",
" function=clock.current_time_interval_str,\n",
Expand Down Expand Up @@ -434,6 +441,7 @@
" agent_config.name,\n",
" clock_now=clock.now,\n",
" components=[instructions,\n",
" fact,\n",
" initial_goal_component,\n",
" relevant_memories,\n",
" persona,\n",
Expand Down Expand Up @@ -470,6 +478,7 @@
" clock=clock,\n",
" verbose=True,\n",
" components=[instructions,\n",
" fact,\n",
" persona,\n",
" justification,\n",
" reflection,\n",
Expand All @@ -494,7 +503,7 @@
" question='What is {opining_player}\\'s opinion of {of_player}?',\n",
" )\n",
" agent.add_component(reputation_metric)\n",
" return agent\n"
" return agent"
]
},
{
Expand Down Expand Up @@ -533,6 +542,7 @@
" knowledge_of_scandal.get(WHO_FOUND_BODY, '')),\n",
" traits = make_random_big_five(),\n",
" formative_ages = sorted(random.sample(range(5, 40), 5)),\n",
" extras={'murderer': False},\n",
" ),\n",
" formative_memories.AgentConfig(\n",
" name=MURDERER,\n",
Expand All @@ -542,6 +552,7 @@
" knowledge_of_scandal.get(MURDERER, '')),\n",
" traits = make_random_big_five(),\n",
" formative_ages = sorted(random.sample(range(5, 40), 5)),\n",
" extras={'murderer': True},\n",
" ),\n",
" formative_memories.AgentConfig(\n",
" name='Donald',\n",
Expand All @@ -552,6 +563,7 @@
" knowledge_of_scandal.get('Donald', '')),\n",
" traits = make_random_big_five(),\n",
" formative_ages = sorted(random.sample(range(5, 40), 5)),\n",
" extras={'murderer': False},\n",
" ),\n",
" formative_memories.AgentConfig(\n",
" name='Ellen',\n",
Expand All @@ -562,6 +574,7 @@
" knowledge_of_scandal.get('Ellen', '')),\n",
" traits = make_random_big_five(),\n",
" formative_ages = sorted(random.sample(range(5, 40), 5)),\n",
" extras={'murderer': False},\n",
" ),\n",
"]"
]
Expand Down Expand Up @@ -682,12 +695,13 @@
"source": [
"# @title Create the game master's thought chain\n",
"account_for_agency_of_others = thought_chains_lib.AccountForAgencyOfOthers(\n",
" model=model, players=players, verbose=True)\n",
"\n",
" model=model, players=players, verbose=False)\n",
"thought_chain = [\n",
" thought_chains_lib.extract_direct_quote,\n",
" thought_chains_lib.attempt_to_most_likely_outcome,\n",
" thought_chains_lib.result_to_effect_caused_by_active_player,\n",
" account_for_agency_of_others\n",
" account_for_agency_of_others,\n",
" thought_chains_lib.restore_direct_quote,\n",
"]"
]
},
Expand Down Expand Up @@ -828,10 +842,12 @@
"fig, ax = plt.subplots(1, len(available_channels), figsize=(6, 2))\n",
"tb = [channel for channel in available_channels]\n",
"for idx, channel in enumerate(available_channels):\n",
" plotting.plot_line_measurement_channel(measurements, channel,\n",
" group_by=group_by[channel],\n",
" xaxis='time_str',\n",
" ax=ax[idx])\n",
" plotting.plot_line_measurement_channel(\n",
" measurements,\n",
" channel,\n",
" group_by=group_by[channel],\n",
" xaxis='time_str',\n",
" ax=ax[idx])\n",
" ax[idx].set_title(channel)\n",
"\n",
"fig.set_constrained_layout(constrained=True)"
Expand Down

0 comments on commit 79c2043

Please sign in to comment.