Commit 4bcd8bb
Added stopping on substring for HF Transformers.
Abhishek Divekar committed Aug 13, 2024
1 parent e17c57c commit 4bcd8bb
Showing 3 changed files with 81 additions and 3 deletions.
79 changes: 77 additions & 2 deletions src/synthesizrr/base/algorithm/huggingface/transformers.py
@@ -26,7 +26,7 @@
from transformers.models.auto.modeling_auto import _BaseAutoModelClass, MODEL_MAPPING_NAMES, \
    MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, \
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
-from transformers.generation.utils import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput
+from transformers.generation.utils import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, StoppingCriteria
from transformers import (
    LogitsProcessorList,
    MinLengthLogitsProcessor, TemperatureLogitsWarper,
@@ -507,6 +507,51 @@ class HFGenerativeLMTokenizerConfig(HFTokenizerConfig):
    truncation_side: Literal['left', 'right'] = 'left'  ## Keeps tokens at the end of the string, useful for LLMs
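
The `truncation_side='left'` default above matters for generative LMs: when a prompt exceeds the context window, the most recent tokens are the ones worth keeping. A quick illustration of the effect, assuming a standard HF tokenizer (the model name is an illustrative placeholder, not from this repo):

## Illustration (not part of the commit): left-side truncation keeps the end of
## an over-long prompt, which is what generation needs.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('gpt2')
tok.truncation_side = 'left'
ids = tok('a b c d e', truncation=True, max_length=3).input_ids
print(tok.decode(ids))  ## ' c d e' -- the start of the prompt was dropped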


class HFSubstringMatchStoppingCriteria(StoppingCriteria):
    def __init__(
            self,
            *,
            stop_sequences: List[str],
            tokenizer: Any,
            tokenizer_decode_dict: Dict,
            prompt_input_ids: Tensor,
    ):
        self.tokenizer: PreTrainedTokenizerBase = tokenizer
        self.tokenizer_decode_dict: Dict = tokenizer_decode_dict
        self.stop_sequences: List[str] = as_list(stop_sequences)
        self.prompt_input_ids: Tensor = prompt_input_ids

    def __call__(self, input_ids: Tensor, scores: Tensor, **kwargs) -> bool:
        ## Decode only the newly-generated tokens, i.e. everything after the prompt:
        generated_texts: List[str] = self.tokenizer.batch_decode(
            input_ids[:, self.prompt_input_ids.shape[1]:],
            **self.tokenizer_decode_dict,
        )
        ## Stop only when EVERY generated text in the batch contains at least one
        ## of the stop sequences:
        should_stop_generating: List[bool] = []
        for generated_text in generated_texts:
            should_stop_generating.append(False)
            for stop_seq in self.stop_sequences:
                if stop_seq in generated_text:
                    should_stop_generating[-1] = True
                    break
        if bool(all(should_stop_generating)):
            # print('=' * 40)
            # print(f'Stopped at this point:')
            # print('=' * 40)
            # for generated_text in generated_texts:
            #     print(generated_text, end='\n\n')
            # print('=' * 40)
            return True  ## Stop generation
        return False  ## Continue generation

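    ## Why __len__/__iter__ below: HF's generate() expects `stopping_criteria` to
    ## behave like a StoppingCriteriaList (it takes its length and iterates over it),
    ## so these methods let a bare instance of this class be passed in directly.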
    def __len__(self):
        return len(self.stop_sequences)

    def __iter__(self):
        yield self
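
A minimal usage sketch of the new criteria on its own (not part of the commit): the model name, prompt, and stop sequence are illustrative placeholders, and the class above, plus the `as_list` helper it calls, is assumed importable from this module. Wrapping it in an explicit StoppingCriteriaList is the conventional way to pass it:

## Usage sketch (illustrative values; assumes HFSubstringMatchStoppingCriteria
## and its as_list helper are importable):
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList

tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2')
prompt_input_ids = tokenizer('Q: What is 2+2?\nA:', return_tensors='pt').input_ids
stopping_criteria = StoppingCriteriaList([
    HFSubstringMatchStoppingCriteria(
        stop_sequences=['\nQ:'],  ## Stop once the model starts a new question
        tokenizer=tokenizer,
        tokenizer_decode_dict=dict(skip_special_tokens=True),
        prompt_input_ids=prompt_input_ids,
    ),
])
out = model.generate(prompt_input_ids, max_new_tokens=64, stopping_criteria=stopping_criteria)
print(tokenizer.batch_decode(out[:, prompt_input_ids.shape[1]:], skip_special_tokens=True))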


class HFPyTorchGenerativeLMMixin(GenerativeLM, HFPyTorchTextModel, ABC):
    class Hyperparameters(HFPyTorchTextModel.Hyperparameters):
        prompt_prefix: str = ''
@@ -529,6 +574,14 @@ def set_generative_lm_params(cls, params: Dict) -> Dict:
    def max_num_generated_tokens(self) -> int:
        return self.hyperparams.generation_params.max_new_tokens

    @property
    def tokenizer_decode_dict(self) -> Dict:
        return self.hyperparams.tokenizer_decode.dict()

    @property
    def stop_sequences(self) -> Optional[List[str]]:
        return self.hyperparams.generation_params.stop_sequences

    def _task_preprocess(self, batch: Prompts, **kwargs) -> Prompts:
        batch: Prompts = super(HFPyTorchGenerativeLMMixin, self)._task_preprocess(
            batch,
@@ -539,12 +592,20 @@ def _task_preprocess(self, batch: Prompts, **kwargs) -> Prompts:
    def forward(self, input: Dict, **kwargs) -> Dict:
        ## Feed the input_ids and masks to the model:
        input.pop('token_type_ids', None)
        input_ids: Tensor = input['input_ids']
        with disable_hf_logging():
            gen_kwargs: Dict = {
                **input,
                **self.hyperparams.generation_params.hf_dict(),
                **dict(return_dict_in_generate=True),  ## Always return a *DecoderOnlyOutput
            }
            if self.stop_sequences is not None:
                gen_kwargs['stopping_criteria'] = HFSubstringMatchStoppingCriteria(
                    stop_sequences=self.stop_sequences,
                    tokenizer=self.tokenizer,
                    tokenizer_decode_dict=self.tokenizer_decode_dict,
                    prompt_input_ids=input_ids,
                )
            out: Union[GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput] = self.model.generate(**gen_kwargs)
        return dict(out)

@@ -566,8 +627,22 @@ def prepare_predictions(self, output: Dict, input: Dict, **kwargs) -> Any:
        num_generated_tokens: int = generated_sequences.shape[1]
        generated_texts: List[str] = self.tokenizer.batch_decode(
            generated_sequences,
-            **self.hyperparams.tokenizer_decode.dict(),
+            **self.tokenizer_decode_dict,
        )
        ## Post-process stop-sequences: the batch-level criteria only halts once ALL
        ## sequences contain a stop sequence, so individual texts may overshoot their
        ## own stop sequence; truncate each text at the earliest match found.
        if self.stop_sequences is not None:
            for gen_text_i, generated_text in enumerate(generated_texts):
                earliest_stop_idx: Optional[int] = None
                for stop_seq in self.stop_sequences:
                    stop_idx: int = generated_text.find(stop_seq)
                    if stop_idx != -1:
                        if earliest_stop_idx is None:
                            earliest_stop_idx: int = stop_idx
                        else:
                            earliest_stop_idx: int = min(earliest_stop_idx, stop_idx)
                if earliest_stop_idx is not None:
                    generated_texts[gen_text_i]: str = generated_text[:earliest_stop_idx]

        predictions: Dict = {
            GENERATED_TEXTS_COL: generated_texts
        }
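
The truncation loop above is equivalent to cutting each text at the earliest occurrence of any stop sequence. A standalone sketch of the same logic (the function name and values are illustrative, not this repo's API):

## Standalone equivalent of the truncation above (illustrative only):
def truncate_at_stop_sequences(text: str, stop_sequences: list) -> str:
    stop_idxs: list = [idx for idx in (text.find(s) for s in stop_sequences) if idx != -1]
    return text[:min(stop_idxs)] if stop_idxs else text

assert truncate_at_stop_sequences('4\nQ: next question?', ['\nQ:', '</s>']) == '4'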
2 changes: 1 addition & 1 deletion src/synthesizrr/base/framework/evaluator/LocalEvaluator.py
@@ -17,7 +17,7 @@ class LocalEvaluator(Evaluator):
    aliases = ['local', 'SimpleEvaluator', 'simple']

    ## Cache model locally for 3 hours:
-    cache_timeout: Optional[Union[Timeout, confloat(gt=0)]] = Timeout24Hr(timeout=60 * 15)
+    cache_timeout: Optional[Union[Timeout, confloat(gt=0)]] = Timeout24Hr(timeout=3 * 60 * 60)

    def _load_model(
            self,
3 changes: 3 additions & 0 deletions src/synthesizrr/base/framework/task/text_generation.py
@@ -553,6 +553,9 @@ def set_gen_params(cls, params: Dict) -> Dict:
            params['output_scores_tolerance']: Optional[float] = None  ## Do not filter out any tokens.
        else:
            raise NotImplementedError(f'Unsupported `output_scores_format`: "{params["output_scores_format"]}"')

        if params.get('stop_sequences') is not None:
            params['stop_sequences']: List[str] = as_list(params['stop_sequences'])
        return params
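
The `as_list` call above is this repo's utility; presumably it normalizes a bare string into a one-element list, so callers can pass either `stop_sequences='</s>'` or `stop_sequences=['</s>', '\nQ:']` and downstream code always sees a list. A sketch of the assumed behavior (NOT the repo's actual implementation):

## Assumed as_list semantics (a sketch for illustration only):
def as_list(x):
    return list(x) if isinstance(x, (list, tuple, set)) else [x]

assert as_list('</s>') == ['</s>']
assert as_list(['</s>', '\nQ:']) == ['</s>', '\nQ:']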

    def hf_dict(self) -> Dict:
