diff --git a/scripts/inference/hf_chat.py b/scripts/inference/hf_chat.py index e692d4d2b6..963733d7d4 100644 --- a/scripts/inference/hf_chat.py +++ b/scripts/inference/hf_chat.py @@ -334,9 +334,9 @@ def main(args: Namespace) -> None: # Chat format model_name = model.config.model_type - if 'llama2' in model_name: + if 'llama2' in model_name.lower(): chat_format = Llama2ChatFormatter(system=args.system_prompt) - elif 'mpt' in model_name: + elif 'mpt' in model_name.lower(): chat_format = ChatMLFormatter(system=args.system_prompt) else: chat_format = ChatFormatter(system=args.system_prompt,