Skip to content

Commit

Permalink
Fix pytype errors
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 661197728
Change-Id: I9c4bd108442ec4a2d368093c4183b56593127012
  • Loading branch information
jagapiou authored and copybara-github committed Aug 9, 2024
1 parent 59ad795 commit 5514bad
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 11 deletions.
15 changes: 7 additions & 8 deletions concordia/components/agent/v2/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,15 @@ class ObservationSummary(action_spec_ignored.ActionSpecIgnored):

def __init__(
self,
*,
model: language_model.LanguageModel,
clock_now: Callable[[], datetime.datetime],
timeframe_delta_from: datetime.timedelta,
timeframe_delta_until: datetime.timedelta,
memory_component_name: str = (
memory_component.DEFAULT_MEMORY_COMPONENT_NAME),
components: Mapping[str, action_spec_ignored.ActionSpecIgnored] = (
types.MappingProxyType({})
memory_component.DEFAULT_MEMORY_COMPONENT_NAME
),
component_labels: Mapping[str, str] = types.MappingProxyType({}),
prompt: str | None = None,
display_timeframe: bool = True,
pre_act_key: str = DEFAULT_OBSERVATION_SUMMARY_PRE_ACT_KEY,
Expand All @@ -127,7 +127,7 @@ def __init__(
segment to summarise.
memory_component_name: Name of the memory component from which to retrieve
observations to summarize.
components: Components to summarise along with the observations.
component_labels: Mapping from component name to the label to give it.
prompt: Language prompt for summarising memories and components.
display_timeframe: Whether to display the time interval as text.
pre_act_key: Prefix to add to the output of the component when called
Expand All @@ -140,7 +140,7 @@ def __init__(
self._timeframe_delta_from = timeframe_delta_from
self._timeframe_delta_until = timeframe_delta_until
self._memory_component_name = memory_component_name
self._components = dict(components)
self._component_labels = dict(component_labels)

self._prompt = prompt or (
'Summarize the observations above into one or two sentences.'
Expand All @@ -153,8 +153,8 @@ def _make_pre_act_value(self) -> str:
agent_name = self.get_entity().name
context = '\n'.join([
f"{agent_name}'s"
f' {prefix}:\n{self.get_named_component_pre_act_value(key)}'
for key, prefix in self._components.items()
f' {label}:\n{self.get_named_component_pre_act_value(key)}'
for key, label in self._component_labels.items()
])

segment_start = self._clock_now() - self._timeframe_delta_from
Expand Down Expand Up @@ -206,4 +206,3 @@ def _make_pre_act_value(self) -> str:
})

return result

5 changes: 3 additions & 2 deletions concordia/components/game_master/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,10 +238,11 @@ def _get_nonplayer_characters(

return nonplayer_characters

def _generate_convo_summary(self, convo: list[str]):
def _generate_convo_summary(self, convo: Sequence[str]):
summary = self._model.sample_text(
'\n'.join(
convo + ['Summarize the conversation above in one sentence.'],
*convo,
'Summarize the conversation above in one sentence.',
),
max_tokens=2000,
terminators=(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def build_agent(
clock_now=clock.now,
timeframe_delta_from=datetime.timedelta(hours=4),
timeframe_delta_until=datetime.timedelta(hours=0),
components={_get_class_name(somatic_state): somatic_state_label},
component_labels={_get_class_name(somatic_state): somatic_state_label},
pre_act_key=observation_summary_label,
logging_channel=measurements.get_channel('ObservationSummary').on_next,
)
Expand Down

0 comments on commit 5514bad

Please sign in to comment.