Showing 2 changed files with 190 additions and 0 deletions.
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
from peft import PeftModel
import argparse
import json, os

parser = argparse.ArgumentParser()
parser.add_argument('--base_model', default=None, type=str, required=True)
parser.add_argument('--lora_model', default=None, type=str, help="If None, perform inference on the base model")
parser.add_argument('--tokenizer_path', default=None, type=str)
parser.add_argument('--data_file', default=None, type=str, help="file that contains instructions (one instruction per line).")
parser.add_argument('--with_prompt', action='store_true')
parser.add_argument('--interactive', action='store_true')
parser.add_argument('--predictions_file', default='./predictions.json', type=str)
args = parser.parse_args()

generation_config = dict(
    temperature=0.2,
    top_k=40,
    top_p=0.9,
    do_sample=True,
    num_beams=1,
    repetition_penalty=1.3,  # HF generate() expects `repetition_penalty`; `repeat_penalty` is the llama.cpp name
    max_new_tokens=400
)

# The prompt template below is taken from llama.cpp
# and is slightly different from the one used in training,
# but we find it gives better results.
prompt_input = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n\n{instruction}\n\n### Response:\n\n"
)
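
# For illustration, with the hypothetical instruction "List three colors."
# (not an example from this commit), the filled-in template is:
#
#   Below is an instruction that describes a task. Write a response that appropriately completes the request.
#
#   ### Instruction:
#
#   List three colors.
#
#   ### Response:
#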

sample_data = ["为什么要减少污染,保护环境?"]  # "Why should we reduce pollution and protect the environment?"


def generate_prompt(instruction, input=None):
    if input:
        instruction = instruction + '\n' + input
    return prompt_input.format_map({'instruction': instruction})


if __name__ == '__main__':
    load_type = torch.float16
    if torch.cuda.is_available():
        device = torch.device(0)
    else:
        device = torch.device('cpu')
    if args.tokenizer_path is None:
        args.tokenizer_path = args.lora_model
        if args.lora_model is None:  # no LoRA model either, so fall back to the base model's tokenizer
            args.tokenizer_path = args.base_model
    tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_path)

    base_model = LlamaForCausalLM.from_pretrained(
        args.base_model,
        load_in_8bit=False,
        torch_dtype=load_type,
        low_cpu_mem_usage=True,
    )

    model_vocab_size = base_model.get_input_embeddings().weight.size(0)
    tokenizer_vocab_size = len(tokenizer)
    print(f"Vocab of the base model: {model_vocab_size}")
    print(f"Vocab of the tokenizer: {tokenizer_vocab_size}")
    if model_vocab_size != tokenizer_vocab_size:
        assert tokenizer_vocab_size > model_vocab_size
        print("Resize model embeddings to fit tokenizer")
        base_model.resize_token_embeddings(tokenizer_vocab_size)
    if args.lora_model is not None:
        print("loading peft model")
        model = PeftModel.from_pretrained(base_model, args.lora_model, torch_dtype=load_type)
    else:
        model = base_model

    if device == torch.device('cpu'):
        model.float()

    # test data
    if args.data_file is None:
        examples = sample_data
    else:
        with open(args.data_file, 'r') as f:
            examples = [l.strip() for l in f.readlines()]
        print("first 10 examples:")
        for example in examples[:10]:
            print(example)

    model.to(device)
    model.eval()

    with torch.no_grad():
        if args.interactive:
            while True:
                raw_input_text = input("Input:")
                if len(raw_input_text.strip()) == 0:
                    break
                if args.with_prompt:
                    input_text = generate_prompt(instruction=raw_input_text)
                else:
                    input_text = raw_input_text
                inputs = tokenizer(input_text, return_tensors="pt")  # add_special_tokens=False ?
                generation_output = model.generate(
                    input_ids=inputs["input_ids"].to(device),
                    attention_mask=inputs['attention_mask'].to(device),
                    eos_token_id=tokenizer.eos_token_id,
                    pad_token_id=tokenizer.pad_token_id,
                    **generation_config
                )
                s = generation_output[0]
                output = tokenizer.decode(s, skip_special_tokens=True)
                if args.with_prompt:
                    response = output.split("### Response:")[1].strip()
                else:
                    response = output
                print("Response: ", response)
                print("\n")
        else:
            results = []
            for index, example in enumerate(examples):
                if args.with_prompt:
                    input_text = generate_prompt(instruction=example)
                else:
                    input_text = example
                inputs = tokenizer(input_text, return_tensors="pt")  # add_special_tokens=False ?
                generation_output = model.generate(
                    input_ids=inputs["input_ids"].to(device),
                    attention_mask=inputs['attention_mask'].to(device),
                    eos_token_id=tokenizer.eos_token_id,
                    pad_token_id=tokenizer.pad_token_id,
                    **generation_config
                )
                s = generation_output[0]
                output = tokenizer.decode(s, skip_special_tokens=True)
                if args.with_prompt:
                    response = output.split("### Response:")[1].strip()
                else:
                    response = output
                print(f"======={index}=======")
                print(f"Input: {example}\n")
                print(f"Output: {response}\n")

                results.append({"Input": input_text, "Output": response})

            dirname = os.path.dirname(args.predictions_file) or '.'  # guard against a bare filename with no directory part
            os.makedirs(dirname, exist_ok=True)
            with open(args.predictions_file, 'w') as f:
                json.dump(results, f, ensure_ascii=False, indent=2)
            with open(os.path.join(dirname, 'generation_config.json'), 'w') as f:
                json.dump(generation_config, f, ensure_ascii=False, indent=2)
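
A minimal sketch of how the script might be invoked, assuming it is saved as inference_hf.py (the file name is not shown in this diff) and with placeholder model paths:

    # Batch mode: read instructions from a file (one per line) and write predictions.json
    python inference_hf.py \
        --base_model path/to/llama-hf \
        --lora_model path/to/lora-weights \
        --with_prompt \
        --data_file instructions.txt \
        --predictions_file ./output/predictions.json

    # Interactive mode: type an instruction at the prompt; an empty input line exits the loop
    python inference_hf.py \
        --base_model path/to/llama-hf \
        --lora_model path/to/lora-weights \
        --with_prompt \
        --interactive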