
How to parse Seq2Seq model input and output #2565

Open
ling976 opened this issue Apr 24, 2023 · 6 comments
Labels
enhancement New feature or request

Comments

ling976 commented Apr 24, 2023

How do I parse the input and output of a Seq2Seq model?

    public static class TextTranslator implements Translator<String, String> {

        private final HuggingFaceTokenizer tokenizer;

        TextTranslator(HuggingFaceTokenizer tokenizer) {
            this.tokenizer = tokenizer;
        }

        /** {@inheritDoc} */
        @Override
        public NDList processInput(TranslatorContext ctx, String input) {
            Encoding encoding = tokenizer.encode(input);
            NDArray attention = ctx.getNDManager().create(encoding.getAttentionMask());
            NDArray inputIds = ctx.getNDManager().create(encoding.getIds());
            NDArray tokenTypes = ctx.getNDManager().create(encoding.getTypeIds());
            return new NDList(inputIds, tokenTypes, attention);
        }

        /** {@inheritDoc} */
        @Override
        public String processOutput(TranslatorContext ctx, NDList list) {
            return list.toString();
        }
    }

I am using this TextTranslator to process the input and output. The NDList I currently get looks like this:

    NDList size: 98
    0 : (128, 32128) float32
    1 : (16, 128, 64) float32
    2 : (16, 128, 64) float32
    3 : (16, 128, 64) float32
    4 : (16, 128, 64) float32
    5 : (16, 128, 64) float32
    6 : (16, 128, 64) float32
    7 : (16, 128, 64) float32
    ......
    91 : (16, 128, 64) float32
    92 : (16, 128, 64) float32
    93 : (16, 128, 64) float32
    94 : (16, 128, 64) float32
    95 : (16, 128, 64) float32
    96 : (16, 128, 64) float32
    97 : (128, 1024) float32

How do I convert this data into the correct output string?

Below is the corresponding Python inference code:

    def answer_fn(text, sample=False, top_k=50):
        encoding = tokenizer(text=[text], truncation=True, padding=True, max_length=256, return_tensors="pt").to(device)
        if not sample:  # no sampling
            out = model_trained.generate(**encoding, return_dict_in_generate=True, max_length=512, num_beams=4, temperature=0.5, repetition_penalty=10.0, remove_invalid_values=True)
        else:  # sampling (generation)
            out = model_trained.generate(**encoding, return_dict_in_generate=True, max_length=512, temperature=0.6, do_sample=True, repetition_penalty=3.0, top_k=top_k)
        out_text = tokenizer.batch_decode(out["sequences"], skip_special_tokens=True)
        if out_text[0] == '':
            return "I'm just a language model; I can't answer this question."
        return postprocess(out_text[0])

    text_list = []
    text = input('Please enter a question: ')
    result = answer_fn(text, sample=True, top_k=100)
    print('Model output:', result)
@ling976 ling976 added the enhancement New feature or request label Apr 24, 2023
KexinFeng (Contributor) commented Apr 24, 2023

The postprocessing is done in

    /** {@inheritDoc} */
    @Override
    public String processOutput(TranslatorContext ctx, NDList list) {
        return list.toString();
    }

First of all, list.toString() does not parse the output; it is roughly the equivalent of print(list) in Python.

To parse the output, first make sure you understand what the output NDArrays are. Find the output token ids, then feed them to HuggingFaceTokenizer's decoder, something like the following:

        String tokenizerJson = "some_directory/gpt2_onnx/tokenizer.json";
        HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance(Paths.get(tokenizerJson));
        System.out.println(tokenizer.decode(tokenIds));

HuggingFaceTokenizer is documented in extensions/tokenizers/README.md. You can also take a look at the source code.
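To connect the two steps: if the first NDArray in the output is logits of shape (sequence_length, vocab_size), the greedy token ids are the argmax along the vocab axis, and those ids are what the decoder expects. Here is a minimal plain-Java sketch of the argmax step (in DJL itself, `list.get(0).argMax(1).toLongArray()` would do the same thing; the decode call is left as a comment because it needs a real tokenizer):

```java
import java.util.Arrays;

public class ArgmaxDecode {

    /** Greedy token selection: argmax over the vocab axis at each position. */
    static long[] argmaxIds(float[][] logits) {
        long[] ids = new long[logits.length];
        for (int t = 0; t < logits.length; t++) {
            int best = 0;
            for (int v = 1; v < logits[t].length; v++) {
                if (logits[t][v] > logits[t][best]) {
                    best = v;
                }
            }
            ids[t] = best;
        }
        return ids;
    }

    public static void main(String[] args) {
        // Toy logits: 2 positions, vocab of 3.
        float[][] logits = {
            {0.1f, 2.0f, 0.3f},
            {1.5f, 0.2f, 0.1f},
        };
        long[] ids = argmaxIds(logits);
        System.out.println(Arrays.toString(ids)); // [1, 0]
        // With DJL and a real tokenizer this would be roughly:
        //   long[] ids = list.get(0).argMax(1).toLongArray();
        //   return tokenizer.decode(ids);
    }
}
```

Note this greedy argmax only matches the Python code's non-sampling path; beam search or sampling would pick ids differently.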


ling976 commented Apr 25, 2023

    String tokenizerJson = "some_directory/gpt2_onnx/tokenizer.json";
    HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance(Paths.get(tokenizerJson));
    System.out.println(tokenizer.decode(tokenIds));

Where do the tokenIds here come from?

KexinFeng (Contributor)

To parse the output, first make sure you understand what the output NDArrays are. Find the output token_ids.

tokenIds is contained in or computed from the output NDArray.


ling976 commented Apr 25, 2023

 public String processOutput(TranslatorContext ctx, NDList list) {}

How do I parse them out from here?

KexinFeng (Contributor)

This requires some knowledge of the model output. For example, for GPT-2 you would need to look up its outputs in the following documents:
https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2LMHeadModel
https://huggingface.co/transformers/v2.2.0/pretrained_models.html

Or just use a search engine to find out what the model outputs are. The shapes themselves are also informative:

    NDList size: 98
    0 : (128, 32128) float32
    1 : (16, 128, 64) float32

Understanding the model output is necessary to parse it.
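As an illustration of reading shapes (my interpretation, not confirmed by a model card): (128, 32128) reads as (sequence_length, vocab_size), (16, 128, 64) as (num_heads, sequence_length, head_dim), and (128, 1024) as final hidden states, since 16 heads × 64 head-dim = 1024. The 96 cache tensors would fit 24 decoder layers with 4 cached tensors each (self-attention K/V plus cross-attention K/V), consistent with a T5-large-style configuration. The arithmetic:

```java
public class ShapeCheck {
    public static void main(String[] args) {
        int numHeads = 16, headDim = 64, hiddenSize = 1024;
        int cacheTensors = 96;
        int perLayer = 4; // self-attention K/V + cross-attention K/V per decoder layer

        // Heads times head dimension should reproduce the hidden size.
        System.out.println("heads * headDim = " + (numHeads * headDim));
        // Number of decoder layers implied by the cache tensor count.
        System.out.println("decoder layers = " + (cacheTensors / perLayer));
    }
}
```

If these numbers line up for your model, index 0 (logits) is the only tensor you need for decoding; the rest can be ignored in a single-step processOutput.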

@KexinFeng KexinFeng pinned this issue Aug 26, 2023
@KexinFeng KexinFeng unpinned this issue Aug 26, 2023
KexinFeng (Contributor)

I just took a look at your output NDArrays:

    NDList size: 98
    0 : (128, 32128) float32
    1 : (16, 128, 64) float32
    2 : (16, 128, 64) float32
    3 : (16, 128, 64) float32
    4 : (16, 128, 64) float32
    5 : (16, 128, 64) float32
    6 : (16, 128, 64) float32
    7 : (16, 128, 64) float32
    ......
    91 : (16, 128, 64) float32
    92 : (16, 128, 64) float32
    93 : (16, 128, 64) float32
    94 : (16, 128, 64) float32
    95 : (16, 128, 64) float32
    96 : (16, 128, 64) float32
    97 : (128, 1024) float32

Index 0 looks like the logits; indices 1 through 96 are all KV cache; I am not sure about 97. This appears to be only a single step of the transformer model's output, corresponding to one token. To generate a sequence of tokens, you also need an autoregressive loop. See #2637 and #2723 for how to call a Translator to generate text.
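The autoregressive loop mentioned above can be sketched in plain Java. Everything here is a stand-in: the `nextToken` function abstracts one forward pass (in DJL this would be a Predictor call that also feeds the KV cache back in), and the start-token id 0 and EOS id 1 are assumptions in the style of T5-family models:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

public class GreedyLoop {

    static final long START = 0; // assumed decoder start token id
    static final long EOS = 1;   // assumed end-of-sequence token id

    /**
     * Autoregressive generation: repeatedly ask the model for the next token
     * given everything generated so far, stopping at EOS or the length limit.
     */
    static List<Long> generate(Function<List<Long>, Long> nextToken, int maxNewTokens) {
        List<Long> out = new ArrayList<>();
        out.add(START);
        for (int i = 0; i < maxNewTokens; i++) {
            long next = nextToken.apply(out); // one "model forward pass"
            out.add(next);
            if (next == EOS) {
                break;
            }
        }
        return out;
    }

    public static void main(String[] args) {
        // Stub model: emits increasing ids, then EOS once 3 tokens are out.
        Function<List<Long>, Long> stub =
                seq -> seq.size() >= 4 ? EOS : (long) seq.size() + 4;
        System.out.println(generate(stub, 10)); // [0, 5, 6, 7, 1]
    }
}
```

The resulting id list (minus the start and EOS tokens) is what you would hand to tokenizer.decode; a single call to processOutput, by contrast, only ever sees one step's logits.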
