A script for inference using T5 in Java #17432
-
Hello, I am having a little trouble building a script to perform inference in Java using the T5 model. I am using T5 for the summarization use case.

1️⃣ Model
Luckily the

2️⃣ The code I am using currently

```java
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;

import java.nio.LongBuffer;
import java.util.Arrays;
import java.util.Map;

public class t5_onnx {
    public static void main(String[] args) throws OrtException {
        // Load the encoder and decoder models and create the inference sessions
        System.out.println("Loading the models");
        String encoder = "encoder_model.onnx";
        String decoder = "decoder_model.onnx";
        OrtEnvironment env = OrtEnvironment.getEnvironment();
        OrtSession encoderSession = env.createSession(encoder);
        OrtSession decoderSession = env.createSession(decoder);
        String prompt = "yo!";
        String generatedText = generate(prompt, env, encoderSession, decoderSession);
    }

    static String generate(String prompt, OrtEnvironment env, OrtSession encoderSession, OrtSession decoderSession) throws OrtException {
        // Get the input and output names for the encoder and decoder
        String encoderInputName = encoderSession.getInputNames().iterator().next();
        String encoderOutputName = encoderSession.getOutputNames().iterator().next();
        String decoderInputName = decoderSession.getInputNames().iterator().next();
        String decoderOutputName = decoderSession.getOutputNames().iterator().next();

        // Encoding
        // INPUT IDS (random placeholder ids until a real tokenizer is wired in)
        long[] inputData = new long[10];
        for (int i = 0; i < inputData.length; i++) {
            inputData[i] = (long) (Math.random() * 10);
        }
        long[] input_ids_shape = new long[]{1, inputData.length}; // Shape of the input data

        // ATTENTION MASK FOR INPUT IDS (all ones: no padding)
        long[] attention_mask = new long[inputData.length];
        Arrays.fill(attention_mask, 1);
        long[] attention_mask_shape = new long[]{1, inputData.length}; // Shape of the mask

        // OnnxTensors for the encoder inputs
        OnnxTensor inputTensor = OnnxTensor.createTensor(env, LongBuffer.wrap(inputData), input_ids_shape);
        OnnxTensor attentionTensor = OnnxTensor.createTensor(env, LongBuffer.wrap(attention_mask), attention_mask_shape);
        Map<String, OnnxTensor> encoder_inputs = Map.of("input_ids", inputTensor, "attention_mask", attentionTensor);
        OrtSession.Result encoderOutput = encoderSession.run(encoder_inputs);
        OnnxTensor encoderOutputTensor = (OnnxTensor) encoderOutput.get(encoderOutputName).get();

        // TODO: run the decoder in a loop to generate tokens
        // OrtSession.Result decoderOutput = decoderSession.run(inputMap);
        // OnnxTensor decoderOutputTensor = (OnnxTensor) decoderOutput.get(decoderOutputName).get();
        return "STATIC RETURN";
    }
}
```

As you can see, the decoding step is still stubbed out.
🙏🏻 A small request
Can you please share the correct way to go about this? In causal LM models we loop until the maximum number of tokens is reached, but I am not sure what I should do for this seq-to-seq model, or how "doSample" and "top K" can be applied here. Please share a script to do this... I would appreciate your help.
-
You'll need an implementation of the T5 tokenizer, which looks like it's SentencePiece, so either sentencepiece-jni or the wrapper in DJL will work. I'm working on a pure Java implementation of SentencePiece which will live in Tribuo, but it's not finished yet.

The decoder expects the input ids for the decoder prompt and the hidden state output from the encoder to generate the first token; then you want to use decoder_with_past_model to generate subsequent tokens, keeping the key & value cache. If you inspect the models you can see the expected input and output names.

Once you get logits out you can run the sampling procedure yourself: either greedy (taking the largest logit), nucleus sampling, or any other sampling procedure. It's not baked into the model, so you need to do it by hand. However, there are tools which will bake a beam search and other operations into an ONNX model (e.g. this one for GPT-2 & T5), so you don't need to do the sampling in Java, but I've not used them.
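The greedy case is simple to do by hand in plain Java: take the logits row for the last sequence position and pick the argmax as the next token id. A minimal sketch (the class and method names here are illustrative; in practice the `float[]` comes out of the decoder's logits tensor):

```java
public class GreedySampler {
    // Greedy sampling: return the index of the largest logit, which is the
    // next token id. Ties resolve to the lowest index.
    public static int greedyNextToken(float[] logits) {
        int best = 0;
        for (int i = 1; i < logits.length; i++) {
            if (logits[i] > logits[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        float[] logits = {0.1f, 2.5f, -1.0f, 2.4f};
        System.out.println("next token id = " + greedyNextToken(logits));
    }
}
```

Top-k or nucleus sampling would replace the argmax with a sort/filter over the logits followed by a draw from the renormalized distribution, but the shape of the loop is the same.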
For the next step the input ids should be your start-of-sequence token (which you can get by loading the sentencepiece protobuf and querying it for the start-of-sequence id), then whatever token you sampled from the logits. The encoder_hidden_states should be the same; those won't change. You'll want the decoder version that accepts past_key_values, and supply those from the outputs of the previous decoder run.
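The overall loop can be sketched without the ONNX plumbing. This is a sketch under assumptions: `decodeStep` is a hypothetical stand-in for running the decoder (plus past_key_values) on the tokens so far and sampling one new id from the logits; the toy lambda in `main` exists only to make the sketch runnable.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

public class GenerationLoop {
    // Autoregressive decoding for a seq2seq model: start from the decoder
    // start token, append each sampled token, and stop at EOS or maxNewTokens.
    public static List<Long> generate(long startTokenId, long eosTokenId, int maxNewTokens,
                                      Function<List<Long>, Long> decodeStep) {
        List<Long> tokens = new ArrayList<>();
        tokens.add(startTokenId);
        for (int i = 0; i < maxNewTokens; i++) {
            long next = decodeStep.apply(tokens); // decoder run + sampling goes here
            if (next == eosTokenId) {
                break; // model signalled end of sequence
            }
            tokens.add(next);
        }
        return tokens;
    }

    public static void main(String[] args) {
        // Toy decode step that just counts up, hitting the EOS id 4 on step 4.
        List<Long> out = generate(0L, 4L, 10, toks -> (long) toks.size());
        System.out.println(out); // [0, 1, 2, 3]
    }
}
```

With the cached-decoder variant, the first iteration runs the plain decoder and every later iteration runs decoder_with_past_model, feeding the previous step's key/value outputs back in, so each step only processes the single newest token.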
It's a little messy, as you don't want to reallocate the buffer every time. It might be better to allocate a single direct LongBuffer (with
ByteBuffer.allocateDirect(seqLength*8).order(ByteOrder.nativeOrder()).asLongBuffer()
) then set the position to 0 and increment the limit each time you wrap it int…
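A minimal sketch of that buffer-reuse pattern, using only `java.nio` (the `appendAndSnapshot` name and the simulated token ids are illustrative; the commented-out line marks where the tensor wrap would go):

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.LongBuffer;
import java.util.Arrays;

public class TokenBuffer {
    // Allocate one direct LongBuffer up front; each step, grow the limit by
    // one, write the new token with an absolute put (which leaves position
    // untouched), and rewind position to 0 so a tensor wrap sees the whole
    // prefix [0, limit). Returns the tokens visible after the last step.
    public static long[] appendAndSnapshot(int maxSeqLength, long[] tokensToAppend) {
        LongBuffer ids = ByteBuffer.allocateDirect(maxSeqLength * 8) // 8 bytes per long
                .order(ByteOrder.nativeOrder())
                .asLongBuffer();
        for (int step = 0; step < tokensToAppend.length; step++) {
            ids.limit(step + 1);                 // expose one more token slot
            ids.put(step, tokensToAppend[step]); // absolute put: write new token id
            ids.position(0);                     // rewind before wrapping
            // OnnxTensor.createTensor(env, ids, new long[]{1, step + 1}) would go here
        }
        long[] out = new long[ids.limit()];
        ids.get(out);
        return out;
    }

    public static void main(String[] args) {
        long[] seen = appendAndSnapshot(8, new long[]{100, 101, 102});
        System.out.println(Arrays.toString(seen)); // [100, 101, 102]
    }
}
```

The absolute `put(index, value)` is the key detail: it avoids disturbing the position, so the only per-step bookkeeping is bumping the limit and rewinding before each wrap, with no reallocation until the sequence outgrows `maxSeqLength`.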