Skip to content

Commit

Permalink
greedy_and_beam
Browse files Browse the repository at this point in the history
  • Loading branch information
KexinFeng committed May 11, 2023
1 parent 1a8b39b commit 6df14ac
Show file tree
Hide file tree
Showing 9 changed files with 570 additions and 54 deletions.
3 changes: 2 additions & 1 deletion api/src/main/java/ai/djl/repository/zoo/Criteria.java
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ public ZooModel<I, O> loadModel()
}
}
throw new ModelNotFoundException(
"No matching model with specified Input/Output type found.", lastException);
"No model with the specified URI or the matching Input/Output type is found.",
lastException);
}

/**
Expand Down
2 changes: 1 addition & 1 deletion api/src/main/java/ai/djl/translate/CausalLMOutput.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public class CausalLMOutput {

// (k, v) * numLayer,
// kv: [batch, heads, seq_past, feature]
// The cache of past sequence. seq-dim-size == |past_seq| + |inputIds|
// The cache of past sequence. seq-dim-size == |seq_past| + |inputIds|
public NDList pastKeyValuesList;

public CausalLMOutput(NDArray logits, NDList pastKeyValues) {
Expand Down
8 changes: 4 additions & 4 deletions api/src/main/java/ai/djl/translate/LMAdapter.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
* range(|inputIds|). This means for each i, the output probability is conditional on the past
* sequence up to i.
*/
public interface LMAdapter extends AutoCloseable {
public interface LMAdapter {

/**
* @param input input
Expand All @@ -37,7 +37,7 @@ default CausalLMOutput forward(NDList input, NDList pastKeyValues, NDManager man
return null;
}

/** {@inheritDoc} */
@Override
void close();
// /** {@inheritDoc} */
// @Override
// void close();
}
397 changes: 383 additions & 14 deletions api/src/main/java/ai/djl/translate/LMSearch.java

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions api/src/main/java/ai/djl/translate/SearchConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,20 @@ public class SearchConfig {

public int k;
public float alpha;
public int beam;
public int maxSeqLength;
public long padTokenId;
public long eosTokenId;
public boolean suffixPadding;

    /** Constructs a new {@code SearchConfig} object with default values. */
public SearchConfig() {
this.k = 4;
this.alpha = 0.6f;
this.beam = 3;
this.maxSeqLength = 30;
this.eosTokenId = 50256;
this.padTokenId = 50256;
this.suffixPadding = true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import java.util.Collections;
import java.util.List;

public class GPT2OrtLMAdapter implements LMAdapter {
public class GPT2OrtLMAdapter implements LMAdapter, AutoCloseable {
Block[] blocks;
List<ZooModel<NDList, NDList>> models;
GPTConfig config;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import java.util.Arrays;
import java.util.List;

public class GPT2PtLMAdapter implements LMAdapter {
public class GPT2PtLMAdapter implements LMAdapter, AutoCloseable {
Block[] blocks;
List<ZooModel<NDList, NDList>> models;
GPTConfig config;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ public static void mainOnnx(String[] args) {
"/Users/fenkexin/Desktop/tasks/HuggingFaceQa_relavant/gpt2_onnx/decoder_model_merged.onnx"
};

try (LMAdapter generator = Engine.getEngine("OnnxRuntime").newLMAdapter("GPT2", new GPTConfig(modelUrls));
NDManager manager = NDManager.newBaseManager()) {

try (NDManager manager = NDManager.newBaseManager()) {
LMAdapter generator =
Engine.getEngine("PyTorch").newLMAdapter("GPT2", new GPTConfig(modelUrls));
/////////////////////////////////////////////
// Inference without cached key_values input
/////////////////////////////////////////////
Expand Down Expand Up @@ -93,7 +93,6 @@ public static void mainOnnx(String[] args) {

System.out.println(out.logits);
System.out.println(out.pastKeyValuesList);

} catch (ModelNotFoundException | MalformedModelException | IOException e) {
throw new RuntimeException(e);
}
Expand All @@ -105,9 +104,9 @@ public static void mainPt(String[] args) {
"/Users/fenkexin/Desktop/tasks/HuggingFaceQa_relavant/transformer/traced_GPT2_hidden.pt"
};

try (LMAdapter generator = Engine.getEngine("PyTorch").newLMAdapter("GPT2", new GPTConfig(modelUrls));
NDManager manager = NDManager.newBaseManager()) {

try (NDManager manager = NDManager.newBaseManager()) {
LMAdapter generator =
Engine.getEngine("PyTorch").newLMAdapter("GPT2", new GPTConfig(modelUrls));
/////////////////////////////////////////////
// Inference without cached key_values input
/////////////////////////////////////////////
Expand Down
193 changes: 168 additions & 25 deletions examples/src/main/java/ai/djl/examples/inference/TestLMSearch.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import ai.djl.MalformedModelException;
import ai.djl.engine.Engine;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.index.NDIndex;
Expand All @@ -26,49 +27,59 @@
import ai.djl.translate.SearchConfig;

import java.io.IOException;
import java.nio.file.Paths;

public final class TestLMSearch {

private TestLMSearch() {}

public static void main(String[] args) {
mainPt(args);
mainContrastivePt(args);
mainGreedy(args);
mainBeam(args);
mainBeamOnnx(args);
}

public static void mainPt(String[] args) {
// String[] modelUrls = {
// "/Users/fenkexin/Desktop/tasks/HuggingFaceQa_relavant/transformer/traced_GPT2_init_hidden.pt",
// "/Users/fenkexin/Desktop/tasks/HuggingFaceQa_relavant/transformer/traced_GPT2_hidden.pt"
// };
// GPTConfig gptConfig = new GPTConfig(modelUrls);
// gptConfig.numAttentionHeads = 20;
// gptConfig.numLayers = 36;
// gptConfig.hiddenStateDim = 768;
// gptConfig.logitsDim = 50257;
// gptConfig.kvDim = 64;
public static void mainContrastivePt(String[] args) {
// String[] modelUrls = {
//
// "/Users/fenkexin/Desktop/tasks/HuggingFaceQa_relavant/transformer/traced_GPT2_init_hidden.pt",
//
// "/Users/fenkexin/Desktop/tasks/HuggingFaceQa_relavant/transformer/traced_GPT2_hidden.pt"
// };
// GPTConfig gptConfig = new GPTConfig(modelUrls);
// gptConfig.numAttentionHeads = 20;
// gptConfig.numLayers = 36;
// gptConfig.hiddenStateDim = 768;
// gptConfig.logitsDim = 50257;
// gptConfig.kvDim = 64;

String[] modelUrls = {
"/Users/fenkexin/Desktop/tasks/HuggingFaceQa_relavant/transformer/models/traced_GPT2_init_hidden.pt",
"/Users/fenkexin/Desktop/tasks/HuggingFaceQa_relavant/transformer/models/traced_GPT2_hidden.pt"
"/Users/fenkexin/Desktop/tasks/HuggingFaceQa_relavant/transformer/traced_GPT2_init_hidden.pt",
"/Users/fenkexin/Desktop/tasks/HuggingFaceQa_relavant/transformer/traced_GPT2_hidden.pt"
};
GPTConfig gptConfig = new GPTConfig(modelUrls);

try (LMAdapter lmAdapter = Engine.getEngine("PyTorch").newLMAdapter("GPT2", gptConfig);
NDManager manager = NDManager.newBaseManager()) {
try (NDManager manager = NDManager.newBaseManager()) {
LMAdapter lmAdapter = Engine.getEngine("PyTorch").newLMAdapter("GPT2", gptConfig);

LMSearch lmSearch;
lmSearch = new LMSearch(lmAdapter);
SearchConfig config = new SearchConfig();
config.maxSeqLength = 50;
config.maxSeqLength = 60;
config.alpha = 0.6f;
config.k = 3;

// [r'DeepMind Company is',
// r'Memories follow me left and right. I can']
NDArray inputIds = manager.create(new long[][] {
{29744, 28478, 5834, 318, 220, 220, 220, 220, 220, 220},
{13579, 1749, 1061, 502, 1364, 290, 826, 13, 314, 460}
});
NDArray inputIds =
manager.create(
new long[][] {
{220, 220, 220, 220, 220, 220, 29744, 28478, 5834, 318},
{13579, 1749, 1061, 502, 1364, 290, 826, 13, 314, 460}
});
config.padTokenId = 220;
config.suffixPadding = false;

int numBatch = (int) inputIds.getShape().get(0);
int initSeqSize = (int) inputIds.getShape().get(1);
Expand All @@ -83,23 +94,155 @@ public static void mainPt(String[] args) {
long[] aSequence = inputIds.get("{},:", i).toLongArray();
int idx = 0;
while (idx < initSeqSize) {
if (suffixPadding && aSequence[idx] == config.padTokenId || !suffixPadding && aSequence[idx] != config.padTokenId) {
if (suffixPadding && aSequence[idx] == config.padTokenId
|| !suffixPadding && aSequence[idx] != config.padTokenId) {
break;
}
idx++;
}
attentionMaskSlice[i][0] = suffixPadding ? idx : 0;
attentionMaskSlice[i][1] = suffixPadding ? initSeqSize : idx;
attentionMask.set(new NDIndex("{},{}:{}", i, suffixPadding ? idx : 0, suffixPadding ? initSeqSize : idx), 0);
attentionMask.set(
new NDIndex(
"{},{}:{}",
i,
suffixPadding ? idx : 0,
suffixPadding ? initSeqSize : idx),
0);
}

NDArray output =
lmSearch.contrastiveSearch(
manager, inputIds, attentionMask, attentionMaskSlice, config);
lmSearch.contrastiveSearch(manager, inputIds, attentionMaskSlice, config);
System.out.println(output.toDebugString(1000, 10, 10, 100, true));

printDecode(output);
} catch (ModelNotFoundException | MalformedModelException | IOException e) {
throw new RuntimeException(e);
}
}

public static void mainGreedy(String[] args) {
String[] modelUrls = {
"/Users/fenkexin/Desktop/tasks/HuggingFaceQa_relavant/transformer/traced_GPT2_init_hidden.pt",
"/Users/fenkexin/Desktop/tasks/HuggingFaceQa_relavant/transformer/traced_GPT2_hidden.pt"
};
GPTConfig gptConfig = new GPTConfig(modelUrls);

try (NDManager manager = NDManager.newBaseManager()) {
LMAdapter lmAdapter = Engine.getEngine("PyTorch").newLMAdapter("GPT2", gptConfig);

LMSearch lmSearch;
lmSearch = new LMSearch(lmAdapter);
SearchConfig config = new SearchConfig();
config.maxSeqLength = 60;

// [r'DeepMind Company is',
// r'Memories follow me left and right. I can']
NDArray inputIds =
manager.create(
new long[][] {
{220, 220, 220, 220, 220, 220, 29744, 28478, 5834, 318},
{13579, 1749, 1061, 502, 1364, 290, 826, 13, 314, 460}
});
config.padTokenId = 220;
config.suffixPadding = false;

NDArray output = lmSearch.greedySearch(inputIds, config);
System.out.println(output.toDebugString(1000, 10, 10, 100, true));

printDecode(output);

} catch (ModelNotFoundException | MalformedModelException | IOException e) {
throw new RuntimeException(e);
}
}

public static void mainBeam(String[] args) {
String[] modelUrls = {
"/Users/fenkexin/Desktop/tasks/HuggingFaceQa_relavant/transformer/traced_GPT2_init_hidden.pt",
"/Users/fenkexin/Desktop/tasks/HuggingFaceQa_relavant/transformer/traced_GPT2_hidden.pt"
};
GPTConfig gptConfig = new GPTConfig(modelUrls);

try (NDManager manager = NDManager.newBaseManager()) {
LMAdapter lmAdapter = Engine.getEngine("PyTorch").newLMAdapter("GPT2", gptConfig);

LMSearch lmSearch;
lmSearch = new LMSearch(lmAdapter);
SearchConfig config = new SearchConfig();
config.maxSeqLength = 60;
config.beam = 3;

// [r'DeepMind Company is',
// r'Memories follow me left and right. I can']
NDArray inputIds =
manager.create(
new long[][] {
{50256, 50256, 50256, 50256, 50256, 50256, 29744, 28478, 5834, 318},
{13579, 1749, 1061, 502, 1364, 290, 826, 13, 314, 460}
});
config.padTokenId = 50256;
config.suffixPadding = false;

NDArray output = lmSearch.beamSearch(inputIds, config);
System.out.println(output.toDebugString(1000, 10, 10, 100, true));

printDecode(output);

} catch (ModelNotFoundException | MalformedModelException | IOException e) {
throw new RuntimeException(e);
}
}

public static void mainBeamOnnx(String[] args) {
String[] modelUrls = {
"/Users/fenkexin/Desktop/tasks/HuggingFaceQa_relavant/gpt2_onnx/decoder_model_merged.onnx"
};
GPTConfig gptConfig = new GPTConfig(modelUrls);

try (NDManager manager = NDManager.newBaseManager()) {
LMAdapter lmAdapter = Engine.getEngine("OnnxRuntime").newLMAdapter("GPT2", gptConfig);

LMSearch lmSearch;
lmSearch = new LMSearch(lmAdapter);
SearchConfig config = new SearchConfig();
config.maxSeqLength = 60;
config.beam = 3;

// [r'DeepMind Company is',
// r'Memories follow me left and right. I can']
NDArray inputIds =
manager.create(
new long[][] {
{220, 220, 220, 220, 220, 220, 29744, 28478, 5834, 318},
{13579, 1749, 1061, 502, 1364, 290, 826, 13, 314, 460}
// {220, 29744, 28478, 5834, 318}
});
config.padTokenId = 220;
config.suffixPadding = false;
// The positionIds is not effective in onnx model traced from huggingface optimum.

NDArray output = lmSearch.beamSearch(inputIds, config);
System.out.println(output.toDebugString(1000, 10, 10, 100, true));

printDecode(output);
} catch (ModelNotFoundException | MalformedModelException | IOException e) {
throw new RuntimeException(e);
}
}

private static void printDecode(NDArray output) throws IOException {
// Decoding
String tokenizerJson =
"/Users/fenkexin/Desktop/tasks/HuggingFaceQa_relavant/gpt2_onnx/tokenizer.json";
HuggingFaceTokenizer tokenizer = HuggingFaceTokenizer.newInstance(Paths.get(tokenizerJson));

System.out.println('\n');
for (int i = 0; i < output.getShape().get(0); i++) {
System.out.println(i + ":");
long[] aSequence = output.get("{},:", i).toLongArray();
System.out.println(tokenizer.decode(aSequence));
}
System.out.println('\n');
}
}

0 comments on commit 6df14ac

Please sign in to comment.