Skip to content

Commit

Permalink
Added the MC Joint CoT adapter
Browse files Browse the repository at this point in the history
  • Loading branch information
liamjxu committed Oct 25, 2024
1 parent e38f69e commit 3d82b96
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 1 deletion.
6 changes: 6 additions & 0 deletions src/helm/benchmark/adaptation/adapters/adapter_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
ADAPT_GENERATION_MULTIMODAL,
ADAPT_LANGUAGE_MODELING,
ADAPT_MULTIPLE_CHOICE_JOINT,
ADAPT_MULTIPLE_CHOICE_JOINT_CHAIN_OF_THOUGHT,
ADAPT_MULTIPLE_CHOICE_JOINT_MULTIMODAL,
ADAPT_MULTIPLE_CHOICE_SEPARATE_CALIBRATED,
ADAPT_MULTIPLE_CHOICE_SEPARATE_ORIGINAL,
Expand All @@ -19,6 +20,9 @@
)
from helm.benchmark.adaptation.adapters.multiple_choice_calibrated_adapter import MultipleChoiceCalibratedAdapter
from helm.benchmark.adaptation.adapters.multiple_choice_joint_adapter import MultipleChoiceJointAdapter
from helm.benchmark.adaptation.adapters.multiple_choice_joint_chain_of_thought_adapter import (
MultipleChoiceJointChainOfThoughtAdapter,
)
from helm.benchmark.adaptation.adapters.multiple_choice_separate_adapter import MultipleChoiceSeparateAdapter
from helm.benchmark.window_services.tokenizer_service import TokenizerService

Expand All @@ -38,6 +42,8 @@ def get_adapter(adapter_spec: AdapterSpec, tokenizer_service: TokenizerService)
adapter = LanguageModelingAdapter(adapter_spec, tokenizer_service)
elif method == ADAPT_MULTIPLE_CHOICE_JOINT:
adapter = MultipleChoiceJointAdapter(adapter_spec, tokenizer_service)
elif method == ADAPT_MULTIPLE_CHOICE_JOINT_CHAIN_OF_THOUGHT:
adapter = MultipleChoiceJointChainOfThoughtAdapter(adapter_spec, tokenizer_service)
elif method == ADAPT_MULTIPLE_CHOICE_SEPARATE_ORIGINAL:
adapter = MultipleChoiceSeparateAdapter(adapter_spec, tokenizer_service)
elif method == ADAPT_MULTIPLE_CHOICE_SEPARATE_CALIBRATED:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class MultipleChoiceJointAdapter(InContextLearningAdapter):

@staticmethod
def get_prefix_char(prefix: str) -> str:
    """Return the first alphanumeric character of a reference prefix.

    Leading whitespace and punctuation are skipped, so prefixes such as
    ``"A. "``, ``" A) "`` and ``"(A) "`` all yield ``"A"``.

    Raises:
        IndexError: if `prefix` contains no alphanumeric character.
    """
    # Stop at the first match instead of materialising every alphanumeric
    # character only to keep element 0.
    for char in prefix:
        if char.isalnum():
            return char
    # Same exception type as indexing an empty list, but with a message that
    # points at the offending prefix.
    raise IndexError(f"Reference prefix {prefix!r} contains no alphanumeric character")

@staticmethod
def get_reference_prefix(prefix: str, i: int) -> str:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from typing import Optional

from helm.benchmark.scenarios.scenario import Instance
from helm.benchmark.adaptation.adapters.multiple_choice_joint_adapter import MultipleChoiceJointAdapter


class MultipleChoiceJointChainOfThoughtAdapter(MultipleChoiceJointAdapter):
    """
    Joint multiple-choice adapter whose in-context examples carry a chain of thought
    between the output prefix and the answer.

    Each `Instance` in a `Scenario` looks like this:

        <input> -> <reference1>
                   <reference2>
                   <reference3> [correct]
                   <reference4>

        <instance_chain_of_thought>

    We can define a label (e.g., letter) for each reference:

        <global_prefix>
        <instructions>
        <input_prefix>
        <input>                  # train
        <input_suffix>
        A. <reference1>
        B. <reference2>
        C. <reference3>
        D. <reference4>
        <output_prefix>
        <chain_of_thought_prefix>
        <instance_chain_of_thought>
        <chain_of_thought_suffix>
        <output>
        <output_suffix>

        <input_prefix>
        <input>                  # test
        <input_suffix>
        A. <reference1>
        B. <reference2>
        C. <reference3>
        D. <reference4>
        <output_prefix>
        <chain_of_thought_prefix>
        <instance_chain_of_thought>
        <chain_of_thought_suffix>
        <output>
        <output_suffix>
        <global_suffix>

    In general, each example is:

        <input_prefix><input><input_suffix>
        <reference_prefixes[index]><reference>      (repeated for every reference)
        <output_prefix><chain_of_thought_prefix><chain_of_thought><chain_of_thought_suffix><output><output_suffix>
    """

    def construct_example_prompt(self, instance: Instance, include_output: bool, reference_index: Optional[int]) -> str:
        """Return the prompt text for one example (a single part of the full prompt).

        Args:
            instance: The example to render. Its input text and every reference are
                always written out; `extra_data["chain_of_thought"]` is used only when
                `include_output` is true.
            include_output: If true, append the output prefix, the chain-of-thought
                block and the answer label(s), followed by the output suffix. If
                false, append only the right-stripped output prefix.
            reference_index: Not used by this adapter; accepted so the signature stays
                compatible with the adapter it overrides.

        Returns:
            The rendered example as a single string.
        """
        spec = self.adapter_spec
        delimiter = ", "
        no_correct_references = "n/a"

        # Input
        result: str = spec.input_prefix + instance.input.text + spec.input_suffix

        # Include the references, collecting the label(s) of the correct one(s).
        # The loop index has its own name so the `reference_index` parameter is not clobbered.
        output = no_correct_references
        for index, reference in enumerate(instance.references):
            prefix = self.get_reference_prefix(spec.reference_prefix, index)
            result += prefix + reference.output.text + spec.reference_suffix
            if not reference.is_correct:
                continue
            if output == no_correct_references:
                # First correct reference: its prefix becomes the answer.
                output = prefix
            elif spec.multi_label:
                # Further correct references are listed only in multi-label mode.
                output += delimiter + prefix

        if include_output:
            # `extra_data` may be absent (None) on instances without a chain of thought;
            # treat that the same as a missing "chain_of_thought" key.
            extra_data = instance.extra_data or {}
            chain_of_thought = (
                spec.chain_of_thought_prefix + extra_data.get("chain_of_thought", "") + spec.chain_of_thought_suffix
            )
            result += spec.output_prefix + chain_of_thought + output + spec.output_suffix
        else:
            # Drop trailing whitespace from the output prefix so the prompt does not
            # end on a dangling space/newline.
            result += spec.output_prefix.rstrip()

        return result

0 comments on commit 3d82b96

Please sign in to comment.