Skip to content

Commit

Permalink
improve dass questionnaire metric
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 588388291
Change-Id: I4c31c50b08646a40b843971c077c596f801e5183
  • Loading branch information
jzleibo authored and copybara-github committed Dec 6, 2023
1 parent 4d2e457 commit 941ab4f
Showing 1 changed file with 21 additions and 6 deletions.
27 changes: 21 additions & 6 deletions concordia/metrics/dass_questionnaire.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
"""

from collections.abc import Callable
import concurrent
from typing import Any

from concordia.document import interactive_document
Expand All @@ -33,6 +34,7 @@
from concordia.typing import component
from concordia.utils import measurements as measurements_lib
import numpy as np
import termcolor


AGREEMENT_SCALE_CHOICES = [
Expand All @@ -57,6 +59,7 @@ def __init__(
verbose: bool = False,
measurements: measurements_lib.Measurements | None = None,
channel: str = 'unspecified_subscale',
log_color='green',
):
"""Initializes the metric.
Expand All @@ -72,6 +75,7 @@ def __init__(
verbose: Whether to print the metric.
measurements: The measurements to use.
channel: The name of the channel to push data
log_color: color for debug logging
"""
self._model = model
self._name = name
Expand All @@ -81,9 +85,11 @@ def __init__(
self._player_name = player_name
self._measurements = measurements
self._channel = channel
self._log_color = log_color

self._timestep = 0

# Note: the DASS questionnaire normally asks about the previous week.
self._preprompt = (
'Please indicate the extent to which the following statement applied ' +
f'to {self._player_name} over the past week:\n')
Expand All @@ -95,15 +101,17 @@ def name(
"""See base class."""
return self._name

def _log(self, entry: str):
print(termcolor.colored(entry, self._log_color), end='')

def update(self) -> None:
"""See base class."""

prompt = interactive_document.InteractiveDocument(self._model)
parent_state = self._context_fn()
prompt.statement(parent_state)

numeric_results = []
for item in self._questionnaire:

def respond(item: dict[str, Any]) -> None:
prompt = interactive_document.InteractiveDocument(self._model)
prompt.statement(parent_state)
prompt.statement(self._preprompt)

answer = prompt.multiple_choice_question(
Expand All @@ -115,8 +123,15 @@ def update(self) -> None:
else:
reversed_choices = item['choices'].reverse()
numeric_result = float(answer) / float(len(reversed_choices) - 1)

numeric_results.append(numeric_result)

if self._verbose:
self._log('\n' + prompt.view().text() + '\n')

with concurrent.futures.ThreadPoolExecutor() as executor:
executor.map(respond, self._questionnaire)

final_result = np.mean(numeric_results)
datum = {
'time_str': self._clock.now().strftime('%H:%M:%S'),
Expand Down Expand Up @@ -316,7 +331,7 @@ def __init__(
'ascending_scale': True},

{'statement': (
'I had a feeling of faintness .'),
'I had a feeling of faintness.'),
'choices': AGREEMENT_SCALE_CHOICES,
'ascending_scale': True},

Expand Down

0 comments on commit 941ab4f

Please sign in to comment.