From dc06b415f9f053708c0cae997475834ef2878fa7 Mon Sep 17 00:00:00 2001 From: "Joel Z. Leibo" Date: Mon, 16 Sep 2024 02:12:16 -0700 Subject: [PATCH] Improve info printed on error in the together_ai model wrapper. PiperOrigin-RevId: 675060263 Change-Id: Id04b343c49c9c186d1177141b58bef52486cb5c4 --- concordia/language_model/together_ai.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/concordia/language_model/together_ai.py b/concordia/language_model/together_ai.py index b5e94fa8..03f80f31 100644 --- a/concordia/language_model/together_ai.py +++ b/concordia/language_model/together_ai.py @@ -154,6 +154,7 @@ def sample_choice( ) -> tuple[int, str, dict[str, float]]: def _sample_choice(response: str) -> float: + augmented_prompt = prompt + response messages = [ { 'role': 'system', @@ -176,7 +177,7 @@ def _sample_choice(response: str) -> float: ), }, {'role': 'assistant', 'content': 'sleeping.'}, - {'role': 'user', 'content': prompt + response}, + {'role': 'user', 'content': augmented_prompt}, ] result = None @@ -198,6 +199,7 @@ def _sample_choice(response: str) -> float: except together.error.RateLimitError as err: if attempts >= _NUM_SILENT_ATTEMPTS: print(f' Exception: {err}') + print(f' Exception prompt: {augmented_prompt}') continue else: break @@ -205,7 +207,8 @@ def _sample_choice(response: str) -> float: if result: lp = sum(result.choices[0].logprobs.token_logprobs) else: - raise ValueError('Failed to get logprobs.') + raise ValueError( + f'Failed to get logprobs.\nException prompt: {augmented_prompt}') return lp