Skip to content

Commit

Permalink
Improve importance model.
Browse files Browse the repository at this point in the history
The importance model now can be contextualized by retrieving other memories to provide scale for the decision of how to assign importance to each new memory.

This change also removes a parameter on the formative memories generator which overrode the importance value assigned by memory.add. If that behavior is desirable, a more consistent way to achieve it would be to use a Constant importance model.

PiperOrigin-RevId: 657156830
Change-Id: Icf5425fed4c3519d80c50ba6f4b10ede7df84865
  • Loading branch information
jzleibo authored and copybara-github committed Jul 29, 2024
1 parent abb1375 commit b6b9e7e
Show file tree
Hide file tree
Showing 7 changed files with 173 additions and 62 deletions.
104 changes: 92 additions & 12 deletions concordia/associative_memory/associative_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,17 @@
preprint arXiv:2304.03442.
"""

from collections.abc import Callable, Iterable
from collections.abc import Callable, Iterable, Sequence
import datetime
import random
import threading

from concordia.associative_memory import importance_function
import numpy as np
import pandas as pd

_NUM_TO_RETRIEVE_TO_CONTEXTUALIZE_IMPORTANCE = 25


def _check_date_in_range(timestamp: datetime.datetime) -> None:
if timestamp < pd.Timestamp.min:
Expand All @@ -44,9 +47,11 @@ class AssociativeMemory:
def __init__(
self,
sentence_embedder: Callable[[str], np.ndarray],
importance: Callable[[str], float] | None = None,
importance: Callable[[str, Sequence[tuple[str, float]]],
float] | None = None,
clock: Callable[[], datetime.datetime] = datetime.datetime.now,
clock_step_size: datetime.timedelta | None = None,
seed: int | None = None,
):
"""Constructor.
Expand All @@ -57,9 +62,17 @@ def __init__(
clock: a callable to get time when adding memories
clock_step_size: sets the step size of the clock. If None, assumes precise
time
seed: the seed to use for the random number generator if None then use the
current time
"""
self._memory_bank_lock = threading.Lock()
if seed is None:
self._seed = random.seed(int(datetime.datetime.now().timestamp()))
else:
self._seed = seed
self._embedder = sentence_embedder
self._num_to_retrieve_to_contextualize_importance = (
_NUM_TO_RETRIEVE_TO_CONTEXTUALIZE_IMPORTANCE)
self._importance = (
importance or importance_function.ConstantImportanceModel().importance)

Expand All @@ -77,7 +90,7 @@ def add(
timestamp: datetime.datetime | None = None,
tags: Iterable[str] = (),
importance: float | None = None,
):
) -> None:
"""Adds nonduplicated entries (time, text, tags, importance) to the memory.
Args:
Expand All @@ -87,7 +100,13 @@ def add(
importance: optionally set the importance of the memory.
"""
if importance is None:
importance = self._importance(text)
with self._memory_bank_lock:
memory_size = len(self._memory_bank)
num_to_retrieve = self._num_to_retrieve_to_contextualize_importance
if memory_size < num_to_retrieve:
num_to_retrieve = memory_size
context = self.retrieve_random_with_importance(k=num_to_retrieve)
importance = self._importance(text, context)

if timestamp is None:
timestamp = self._clock_now()
Expand Down Expand Up @@ -119,7 +138,7 @@ def extend(
self,
texts: Iterable[str],
**kwargs,
):
) -> None:
"""Adds the texts to the memory.
Args:
Expand All @@ -129,7 +148,7 @@ def extend(
for text in texts:
self.add(text, **kwargs)

def get_data_frame(self):
def get_data_frame(self) -> pd.DataFrame:
with self._memory_bank_lock:
return self._memory_bank.copy()

Expand Down Expand Up @@ -202,7 +221,7 @@ def _pd_to_text(
data: pd.DataFrame,
add_time: bool = False,
sort_by_time: bool = True,
):
) -> Sequence[str]:
"""Formats a dataframe into list of strings.
Args:
Expand Down Expand Up @@ -240,7 +259,7 @@ def retrieve_associative(
use_importance: bool = True,
add_time: bool = True,
sort_by_time: bool = True,
):
) -> Sequence[str]:
"""Retrieve memories associatively.
Args:
Expand Down Expand Up @@ -270,7 +289,7 @@ def retrieve_by_regex(
regex: str,
add_time: bool = True,
sort_by_time: bool = True,
):
) -> Sequence[str]:
"""Retrieve memories matching a regex.
Args:
Expand All @@ -291,7 +310,7 @@ def retrieve_time_interval(
time_from: datetime.datetime,
time_until: datetime.datetime,
add_time: bool = False,
):
) -> Sequence[str]:
"""Retrieve memories within a time interval.
Args:
Expand All @@ -315,7 +334,7 @@ def retrieve_recent(
self,
k: int = 1,
add_time: bool = False,
):
) -> Sequence[str]:
"""Retrieve memories by recency.
Args:
Expand All @@ -333,7 +352,7 @@ def retrieve_recent_with_importance(
self,
k: int = 1,
add_time: bool = False,
):
) -> tuple[Sequence[str], Sequence[float]]:
"""Retrieve memories by recency and return importance alongside.
Args:
Expand All @@ -350,6 +369,40 @@ def retrieve_recent_with_importance(
list(data['importance']),
)

def retrieve_random(
self,
k: int = 1,
add_time: bool = False,
) -> Sequence[str]:
"""Retrieve random memories.
Args:
k: number of entries to retrieve
add_time: whether to add time stamp to the output
Returns:
List of strings corresponding to memories
"""
with self._memory_bank_lock:
data = self._memory_bank.sample(k, random_state=self._seed)
return self._pd_to_text(data, add_time=add_time, sort_by_time=True)

def retrieve_random_with_importance(
self,
k: int = 1,
) -> Sequence[tuple[str, float]]:
"""Retrieve random memories and return importance alongside.
Args:
k: number of entries to retrieve
Returns:
List of tuples of (memory, importance)
"""
with self._memory_bank_lock:
data = self._memory_bank.sample(k, random_state=self._seed)
return tuple(zip(list(data['text']), list(data['importance'])))

def __len__(self):
"""Returns the number of entries in the memory bank.
Expand All @@ -358,3 +411,30 @@ def __len__(self):
"""
with self._memory_bank_lock:
return len(self._memory_bank)

def get_mean_importance(self) -> float:
"""Returns the mean importance of the memories in the memory bank."""
with self._memory_bank_lock:
return self._memory_bank['importance'].mean()

def get_max_importance(self) -> float:
"""Returns the max importance of the memories in the memory bank."""
with self._memory_bank_lock:
return self._memory_bank['importance'].max()

def get_min_importance(self) -> float:
"""Returns the min importance of the memories in the memory bank."""
with self._memory_bank_lock:
return self._memory_bank['importance'].min()

def set_num_to_retrieve_to_contextualize_importance(
self, num_to_retrieve: int) -> None:
"""Sets the number of memories to retrieve for contextualizing importance.
Set this to 0 if you want to disable contextualization of importance.
Args:
num_to_retrieve: the number of memories to retrieve for contextualizing
importance.
"""
self._num_to_retrieve_to_contextualize_importance = num_to_retrieve
10 changes: 2 additions & 8 deletions concordia/associative_memory/formative_memories.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import re
from typing import Any
from concordia.associative_memory import associative_memory
from concordia.associative_memory import importance_function
from concordia.document import interactive_document
from concordia.language_model import language_model
from dateutil.relativedelta import relativedelta # pylint: disable=g-importing-member
Expand All @@ -31,7 +30,6 @@

DEFAULT_DOB = datetime.datetime(year=1984, month=7, day=3, hour=0, minute=0)
DEFAULT_FORMATIVE_AGES = (6, 9, 13, 16, 19, 21, 23)
DEFAULT_IMPORTANT_MODEL = importance_function.ConstantImportanceModel()


@dataclasses.dataclass(frozen=True, kw_only=True)
Expand All @@ -49,7 +47,6 @@ class AgentConfig:
goal: defines agents goal. Can be left blank if not used.
date_of_birth: the date of birth for the agent.
formative_ages: ages at which the formative episodes will be created
formative_memory_importance: the importance value of formative memories.
extras: a field for the user to keep any experiment specific data they need
to define an agent
"""
Expand All @@ -62,7 +59,6 @@ class AgentConfig:
goal: str = ''
date_of_birth: datetime.datetime = DEFAULT_DOB
formative_ages: Iterable[int] = DEFAULT_FORMATIVE_AGES
formative_memory_importance: float = 1.0
extras: dict[str, Any] = dataclasses.field(default_factory=dict)


Expand Down Expand Up @@ -229,7 +225,6 @@ def add_memories(
tags=['episode'],
timestamp=(
agent_config.date_of_birth + relativedelta(years=episode_age)),
importance=agent_config.formative_memory_importance,
)

if self._current_date:
Expand All @@ -238,7 +233,6 @@ def add_memories(
f'{agent_config.name} is {age} years old.',
tags=['info'],
timestamp=self._current_date,
importance=agent_config.formative_memory_importance,
)

def make_memories(
Expand All @@ -262,12 +256,12 @@ def make_memories(
context_items = context.split('\n')
for item in context_items:
if item:
mem.add(item, importance=agent_config.formative_memory_importance)
mem.add(item)

if agent_config.specific_memories:
specific_memories = agent_config.specific_memories.split('\n')
for item in specific_memories:
if item:
mem.add(item, importance=agent_config.formative_memory_importance)
mem.add(item)

return mem
Loading

0 comments on commit b6b9e7e

Please sign in to comment.