diff --git a/api/src/main/java/ai/djl/engine/Engine.java b/api/src/main/java/ai/djl/engine/Engine.java index 3b29a9295ab..5cc21617db7 100644 --- a/api/src/main/java/ai/djl/engine/Engine.java +++ b/api/src/main/java/ai/djl/engine/Engine.java @@ -14,7 +14,10 @@ import ai.djl.Device; import ai.djl.Model; +import ai.djl.modality.nlp.generate.GPTConfig; +import ai.djl.modality.nlp.generate.LMBlock; import ai.djl.ndarray.NDManager; +import ai.djl.nn.Block; import ai.djl.nn.SymbolBlock; import ai.djl.training.GradientCollector; import ai.djl.training.LocalParameterServer; @@ -302,6 +305,10 @@ public SymbolBlock newSymbolBlock(NDManager manager) { */ public abstract NDManager newBaseManager(Device device); + public LMBlock newLMBlock(String languageModel, GPTConfig gptConfig, Block[] blocks) { + throw new UnsupportedOperationException("Not supported."); + } + /** * Returns a new instance of {@link GradientCollector}. * diff --git a/api/src/main/java/ai/djl/modality/nlp/generate/BatchTensorList.java b/api/src/main/java/ai/djl/modality/nlp/generate/BatchTensorList.java new file mode 100644 index 00000000000..f7eedaf5533 --- /dev/null +++ b/api/src/main/java/ai/djl/modality/nlp/generate/BatchTensorList.java @@ -0,0 +1,111 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.modality.nlp.generate; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; + +// BatchTensorList represents a search state, and the NDArrays inside are updated in each iteration +// of the +// autoregressive loop. +// It is a struct consisting of NDArrays, whose first dimension is batch, and also contains +// sequence dimension (whose position in tensor's shape is specified by seqDimOrder). +// The SeqBatcher batch operations will operate on these two dimensions. +public abstract class BatchTensorList { + // [batch, seq_past]. seq-dim-size == |past_seq| + |inputIds|. Will grow. + private NDArray pastOutputIds; + + // [batch, seq_past] + // The cache of past attentionMask. seq-dim-size == |past_seq| + |inputIds|. Will grow. + private NDArray pastAttentionMask; + + // (k, v) * numLayer, + // kv: [batch, heads, seq_past, kvfeature] + // The cache of past sequence. seq-dim-size == |past_seq| + |inputIds|. Will grow. + private NDList pastKeyValues; + + // Sequence dimension order among all dimensions for each element in the batch list. 
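+    // Each entry records the ordinal position of the sequence dimension of the corresponding
+    // tensor in getList(); an entry <= 0 means that tensor only has a batch dimension and no
+    // sequence dimension (see how SeqBatcher.merge and SeqBatcher.collectAndTrim use it).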
+ private long[] seqDimOrder; + + BatchTensorList() {} + + BatchTensorList(NDList list, long[] seqDimOrder) { + this.seqDimOrder = seqDimOrder; + pastOutputIds = list.get(0); + pastAttentionMask = list.get(1); + pastKeyValues = list.subNDList(2); + } + + BatchTensorList( + NDArray pastOutputIds, + NDArray pastAttentionMask, + NDList pastKeyValues, + long[] seqDimOrder) { + this.pastKeyValues = pastKeyValues; + this.pastOutputIds = pastOutputIds; + this.pastAttentionMask = pastAttentionMask; + this.seqDimOrder = seqDimOrder; + } + + public abstract BatchTensorList fromList(NDList inputList, long[] seqDimOrder); + + // The pastOutputIds has to be the first in the output list + public abstract NDList getList(); + + public long[] getSeqDimOrder() { + return seqDimOrder; + } + + /** + * Gets the value of the pastOutputIds. + * + * @return the value of pastOutputIds + */ + public NDArray getPastOutputIds() { + return pastOutputIds; + } + + public void setPastOutputIds(NDArray pastOutputIds) { + this.pastOutputIds = pastOutputIds; + } + + /** + * Gets the value of the pastAttentionMask. + * + * @return the value of pastAttentionMask + */ + public NDArray getPastAttentionMask() { + return pastAttentionMask; + } + + public void setPastAttentionMask(NDArray pastAttentionMask) { + this.pastAttentionMask = pastAttentionMask; + } + + /** + * Gets the value of the pastKeyValues. + * + * @return the value of pastKeyValues + */ + public NDList getPastKeyValues() { + return pastKeyValues; + } + + public void setPastKeyValues(NDList pastKeyValues) { + this.pastKeyValues = pastKeyValues; + } + + public void setSeqDimOrder(long[] seqDimOrder) { + this.seqDimOrder = seqDimOrder; + } +} diff --git a/api/src/main/java/ai/djl/modality/nlp/generate/CausalLMOutput.java b/api/src/main/java/ai/djl/modality/nlp/generate/CausalLMOutput.java new file mode 100644 index 00000000000..1e72e9293f4 --- /dev/null +++ b/api/src/main/java/ai/djl/modality/nlp/generate/CausalLMOutput.java @@ -0,0 +1,76 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.modality.nlp.generate; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; + +/** CausalLMOuput is used to contain multiple output of a language model. */ +public class CausalLMOutput { + + // [batch, seq, feature] + // The prob. conditional on a sequence that ends at an element in seq-dim. seq-dim-size = + // |inputIds| + private NDArray logits; + + // [batch, seq, dim] * (layers+1) -> take -1 + // The vec. rep. of a sequence that ends at an element in seq-dim. seq-dim-size = |inputIds| + private NDArray hiddenStates; + + // (k, v) * numLayer, + // kv: [batch, heads, seq_past, feature] + // The cache of past sequence. 
seq-dim-size == |seq_past| + |inputIds| + private NDList pastKeyValuesList; + + public CausalLMOutput(NDArray logits, NDList pastKeyValues) { + this.logits = logits; + this.pastKeyValuesList = pastKeyValues; + } + + public CausalLMOutput(NDArray logits, NDArray hiddenState, NDList pastKeyValueList) { + this.logits = logits; + this.pastKeyValuesList = pastKeyValueList; + this.hiddenStates = hiddenState; + } + + /** + * Gets the value of the logits. + * + * @return the value of logits + */ + public NDArray getLogits() { + return logits; + } + + public void setLogits(NDArray logits) { + this.logits = logits; + } + + /** + * Gets the value of the allHiddenStates. + * + * @return the value of allHiddenStates + */ + public NDArray getHiddenState() { + return hiddenStates; + } + + /** + * Gets the value of the pastKeyValuesList. + * + * @return the value of pastKeyValuesList + */ + public NDList getPastKeyValuesList() { + return pastKeyValuesList; + } +} diff --git a/api/src/main/java/ai/djl/modality/nlp/generate/ContrastiveBatchTensorList.java b/api/src/main/java/ai/djl/modality/nlp/generate/ContrastiveBatchTensorList.java new file mode 100644 index 00000000000..602d0025cf1 --- /dev/null +++ b/api/src/main/java/ai/djl/modality/nlp/generate/ContrastiveBatchTensorList.java @@ -0,0 +1,87 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.modality.nlp.generate; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; + +class ContrastiveBatchTensorList extends BatchTensorList { + // [batch, seq_past, hiddenDim] + // The embed vector of the past seq. seq-dim-size = |past_seq|. Will grow. + private NDArray pastHiddenStates; + + // [batch, vacabSize]. Only the last logits, used to recall candidate token. + private NDArray logits; + + ContrastiveBatchTensorList(NDList list, long[] seqDimOrder) { + super(list.get(0), list.get(1), list.subNDList(4), seqDimOrder); + pastHiddenStates = list.get(2); + logits = list.get(3); + } + + ContrastiveBatchTensorList( + NDArray pastOutputIds, + NDArray pastAttentionMask, + NDArray pastHiddenStates, + NDArray logits, + NDList pastKeyValues, + long[] seqDimOrder) { + super(pastOutputIds, pastAttentionMask, pastKeyValues, seqDimOrder); + this.pastHiddenStates = pastHiddenStates; + this.logits = logits; + } + + public ContrastiveBatchTensorList() {} + + @Override + public ContrastiveBatchTensorList fromList(NDList inputList, long[] seqDimOrder) { + return new ContrastiveBatchTensorList(inputList, seqDimOrder); + } + + @Override + public NDList getList() { + // The pastOutputIds has to be the first in the output list + return new NDList( + getPastOutputIds(), + getPastAttentionMask(), + getPastHiddenStates(), + getLogits()) + .addAll(getPastKeyValues()); + } + + /** + * Gets the value of the pastHiddenStates. 
+ * + * @return the value of pastHiddenStates + */ + public NDArray getPastHiddenStates() { + return pastHiddenStates; + } + + public void setPastHiddenStates(NDArray pastHiddenStates) { + this.pastHiddenStates = pastHiddenStates; + } + + /** + * Gets the value of the logits. + * + * @return the value of logits + */ + public NDArray getLogits() { + return logits; + } + + public void setLogits(NDArray logits) { + this.logits = logits; + } +} diff --git a/api/src/main/java/ai/djl/modality/nlp/generate/GPTConfig.java b/api/src/main/java/ai/djl/modality/nlp/generate/GPTConfig.java new file mode 100644 index 00000000000..ff4d3fa3b62 --- /dev/null +++ b/api/src/main/java/ai/djl/modality/nlp/generate/GPTConfig.java @@ -0,0 +1,57 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.modality.nlp.generate; + +/** GPTConfig is used to store the GPT parameters used to select different versions of GPT. */ +public class GPTConfig { + private int numAttentionHeads; + private int numLayers; + private long kvDim; + + public GPTConfig() { + numAttentionHeads = 12; + numLayers = 12; + kvDim = 64; + } + + /** + * Gets the value of the numAttentionHeads. + * + * @return the value of numAttentionHeads + */ + public int getNumAttentionHeads() { + return numAttentionHeads; + } + + /** + * Gets the value of the numLayers. + * + * @return the value of numLayers + */ + public int getNumLayers() { + return numLayers; + } + + public void setNumLayers(int numLayers) { + this.numLayers = numLayers; + } + + /** + * Gets the value of the kvDim. + * + * @return the value of kvDim + */ + public long getKvDim() { + return kvDim; + } +} diff --git a/api/src/main/java/ai/djl/modality/nlp/generate/GreedyBatchTensorList.java b/api/src/main/java/ai/djl/modality/nlp/generate/GreedyBatchTensorList.java new file mode 100644 index 00000000000..cce02f671cb --- /dev/null +++ b/api/src/main/java/ai/djl/modality/nlp/generate/GreedyBatchTensorList.java @@ -0,0 +1,116 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.modality.nlp.generate; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; + +class GreedyBatchTensorList extends BatchTensorList { + // [batch, 1] + private NDArray nextInputIds; + + // [batch, seq_past + new_seq] + // The cache of past attentionMask. seq-dim-size == |past_seq| + |inputIds|. Will grow. + private NDArray pastAttentionMask; + + /* Variables below are one time step behind the above state variables. 
Ie, they contain all the past sequence but excludes the time step that corresponds to the above input. */ + + // [batch, seq_past]. seq-dim-size == |past_seq| + |inputIds|. Will grow. + private NDArray pastOutputIds; + + // (k, v) * numLayer, + // kv: [batch, heads, seq_past, kvfeature] + // The cache of past sequence. seq-dim-size == |past_seq| + |inputIds|. Will grow. + private NDList pastKeyValues; + + GreedyBatchTensorList( + NDArray nextInputIds, + NDArray pastOutputIds, + NDList pastKeyValues, + NDArray pastAttentionMask) { + this.nextInputIds = nextInputIds; + this.pastKeyValues = pastKeyValues; + this.pastOutputIds = pastOutputIds; + this.pastAttentionMask = pastAttentionMask; + } + + public GreedyBatchTensorList() {} + + @Override + public BatchTensorList fromList(NDList inputList, long[] seqDimOrder) { + return new GreedyBatchTensorList(); + } + + @Override + public NDList getList() { + return new NDList(); + } + + /** + * Gets the value of the nextInputIds. + * + * @return the value of nextInputIds + */ + public NDArray getNextInputIds() { + return nextInputIds; + } + + public void setNextInputIds(NDArray nextInputIds) { + this.nextInputIds = nextInputIds; + } + + /** + * Gets the value of the pastAttentionMask. + * + * @return the value of pastAttentionMask + */ + @Override + public NDArray getPastAttentionMask() { + return pastAttentionMask; + } + + @Override + public void setPastAttentionMask(NDArray pastAttentionMask) { + this.pastAttentionMask = pastAttentionMask; + } + + /** + * Gets the value of the pastOutputIds. + * + * @return the value of pastOutputIds + */ + @Override + public NDArray getPastOutputIds() { + return pastOutputIds; + } + + @Override + public void setPastOutputIds(NDArray pastOutputIds) { + this.pastOutputIds = pastOutputIds; + } + + /** + * Gets the value of the pastKeyValues. + * + * @return the value of pastKeyValues + */ + @Override + public NDList getPastKeyValues() { + return pastKeyValues; + } + + @Override + public void setPastKeyValues(NDList pastKeyValues) { + this.pastKeyValues = pastKeyValues; + } +} diff --git a/api/src/main/java/ai/djl/modality/nlp/generate/LMBlock.java b/api/src/main/java/ai/djl/modality/nlp/generate/LMBlock.java new file mode 100644 index 00000000000..c58704def65 --- /dev/null +++ b/api/src/main/java/ai/djl/modality/nlp/generate/LMBlock.java @@ -0,0 +1,75 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.modality.nlp.generate; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import ai.djl.nn.AbstractBlock; +import ai.djl.training.ParameterStore; +import ai.djl.util.PairList; + +/** + * This is a wrapper over the model files from different sources, e.g. gpt2.pt, gpt2.onnx, etc. This + * interface is an abstraction of the causal language model, which in essence is a conditional + * probability function: p_\theta(v_t | x_{< t})}, v_t \in V, i.e. 
given the past tokens up to a + * certain time x_{< t}, the probability that the next token is v, taken from a vocabulary set V. + * \theta is the model's weight. This function can take an input sequence `inputIds`, whose length + * can be greater than one. In this case, the output is still p_\theta(v_i | x_{< i})}, i in + * range(|inputIds|). This means for each i, the output probability is conditional on the past + * sequence up to i. + */ +public abstract class LMBlock extends AbstractBlock { + + /** + * @param input input + * @param pastKeyValues past_key_values + * @param manager manager + * @return CausalLMOutput + */ + public abstract CausalLMOutput forward(NDList input, NDList pastKeyValues, NDManager manager); + + /** {@inheritDoc} */ + @Override + protected NDList forwardInternal( + ParameterStore parameterStore, + NDList inputs, + boolean training, + PairList params) { + // inputIds, positionIds, attentionMask + CausalLMOutput output = + forward(inputs.subNDList(0, 3), inputs.subNDList(3), inputs.getManager()); + return new NDList(output.getLogits()) + .addAll(new NDList(output.getHiddenState())) // allHiddenStates could be null + .addAll(output.getPastKeyValuesList()); + } + + /** {@inheritDoc} */ + @Override + public Shape[] getOutputShapes(Shape[] inputShapes) { + try (NDManager manager = NDManager.newBaseManager()) { + NDArray inputIds = manager.ones(inputShapes[0], DataType.INT64); + NDArray positionIds = + manager.arange(0, inputIds.getShape().size(-1), 1, DataType.INT64) + .reshape(1, -1) + .repeat(0, inputIds.getShape().get(0)); + NDArray attentionMask = manager.ones(positionIds.getShape(), DataType.INT64); + NDList input = new NDList(inputIds, positionIds, attentionMask); + + NDList result = forwardInternal(new ParameterStore(manager, false), input, false, null); + return result.stream().map(NDArray::getShape).toArray(Shape[]::new); + } + } +} diff --git a/api/src/main/java/ai/djl/modality/nlp/generate/LMSearch.java b/api/src/main/java/ai/djl/modality/nlp/generate/LMSearch.java new file mode 100644 index 00000000000..999b6e41990 --- /dev/null +++ b/api/src/main/java/ai/djl/modality/nlp/generate/LMSearch.java @@ -0,0 +1,368 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ +package ai.djl.modality.nlp.generate; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.NDScope; +import ai.djl.ndarray.index.NDIndex; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import ai.djl.nn.AbstractBlock; +import ai.djl.training.ParameterStore; +import ai.djl.util.PairList; + +import java.util.function.Function; +import java.util.stream.Collectors; + +public class LMSearch extends AbstractBlock { + + private String searchName; + private SearchConfig config; + private LMBlock lmBlock; + + private NDArray positionOffset; + + public LMSearch(LMBlock lmBlock, String searchName, SearchConfig searchConfig) { + this.lmBlock = lmBlock; + this.searchName = searchName; + this.config = searchConfig; + } + + public NDArray greedySearch(NDArray inputIds) { + NDArray attentionMask = prepareAttentionMaskOffset(inputIds, config); + NDManager manager = inputIds.getManager(); + GreedyBatchTensorList searchState = + new GreedyBatchTensorList(inputIds, null, null, attentionMask); + while (true) { + try (NDScope scope = new NDScope()) { + scope.suppressNotUsedWarning(); + + long pastSeqLength = + searchState.getPastOutputIds() == null + ? 0 + : searchState.getPastOutputIds().getShape().getLastDimension(); + NDList modelInput = + prepareInput( + searchState.getNextInputIds(), + searchState.getPastAttentionMask(), + pastSeqLength, + 1); + CausalLMOutput modelOutput = + lmBlock.forward(modelInput, searchState.getPastKeyValues(), manager); + + NDArray outputIds = StepGeneration.greedyStepGen(modelOutput.getLogits()); + + // Update searchState + if (searchState.getPastOutputIds() == null) { + searchState.setPastOutputIds(searchState.getNextInputIds()); + } else { + searchState.setPastOutputIds( + searchState + .getPastOutputIds() + .concat(searchState.getNextInputIds(), 1)); + } + searchState.setNextInputIds(outputIds); + searchState.setPastKeyValues(modelOutput.getPastKeyValuesList()); + searchState.setPastAttentionMask( + searchState + .getPastAttentionMask() + .concat( + manager.ones( + new Shape(inputIds.getShape().get(0), 1), + DataType.INT64), + 1)); + + // memory management + NDScope.unregister( + searchState.getNextInputIds(), + searchState.getPastAttentionMask(), + searchState.getPastOutputIds()); + NDScope.unregister(searchState.getPastKeyValues()); + } + + // Termination Criteria + // TODO: , delete the sentence and add it to result. + if (searchState.getPastOutputIds().getShape().get(1) + 1 >= config.getMaxSeqLength()) { + break; + } + } + return searchState.getPastOutputIds().concat(searchState.getNextInputIds(), 1); + } + + // https://huggingface.co/blog/introducing-csearch + public NDArray contrastiveSearch(NDArray inputIds) { + // inputIds: [batchSize, seqLength: t_init] + // attentionMask: [batchSize, pastSeq]. seq-dim-size = |past_seq| + |inputIds|. 
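+        // Contrastive search, in brief: each step recalls the top-k candidate tokens from the
+        // model confidence p(v | x_{<t}), then re-ranks them with a degeneration penalty, i.e.
+        // the maximum cosine similarity between the candidate's hidden state and all past
+        // hidden states. The selected token maximizes
+        //     (1 - alpha) * p(v | x_{<t}) - alpha * maxCosineSimilarity(h_v, pastHiddenStates),
+        // as computed in StepGeneration.constrastiveStepGeneration.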
+ + NDManager manager = inputIds.getManager(); + NDArray attentionMask = prepareAttentionMaskOffset(inputIds, config); + ContrastiveBatchTensorList searchState = new ContrastiveBatchTensorList(); + while (true) { + if (searchState.getPastKeyValues() == null) { + NDList modelInput = prepareInput(inputIds, attentionMask, 0, 1); + CausalLMOutput output = lmBlock.forward(modelInput, null, manager); + NDArray lastLogits = output.getLogits().get(":, -1, :"); + searchState = + new ContrastiveBatchTensorList( + inputIds, + attentionMask, + output.getHiddenState(), + lastLogits, + output.getPastKeyValuesList(), + new long[] {}); + } + + /* Contrastive search loop main part */ + // (1) candidate tokens recall; + // (2) candidate re-rank by degeneration penalty + + try (NDScope scope = new NDScope()) { + scope.suppressNotUsedWarning(); + + NDArray topKIds = + searchState + .getLogits() + .topK(config.getK(), -1, true, false) + .get(1); // [batch, topK] + + // Generate model inputs and put candidates together into batch + // [batch, topK] -> [batch * [topK]] -> [[batch * [topK]], seqLength=1] + NDArray candidateInputIds = topKIds.flatten().reshape(-1, 1); + assert candidateInputIds.getDataType() == DataType.INT64 + : "inputIds datatype should be int64"; + assert candidateInputIds.getShape().getShape().length == 2 : "shape not right"; + + // [batch, heads, seq_past, feature] -> [batch * topK, head, seq_past, feature] + NDList kCopyPastKeyValues = + new NDList( + searchState.getPastKeyValues().stream() + .map(ndarray -> ndarray.repeat(0, config.getK())) + .collect(Collectors.toList())); + assert kCopyPastKeyValues.get(0).getDataType() == DataType.FLOAT32 + : "inputIds datatype should be Float32"; + + // [batch, seq_past] -> [batch * topK, seq_past] -> [batch * topK, seq_past + 1] + long numBatch = topKIds.getShape().get(0); + NDArray kCopyPastAttentionMask = + searchState.getPastAttentionMask().repeat(0, config.getK()); + kCopyPastAttentionMask = + kCopyPastAttentionMask.concat( + manager.ones( + new Shape(numBatch * config.getK(), 1), DataType.INT64), + 1); + assert kCopyPastKeyValues.get(0).getShape().get(2) + 1 + == kCopyPastAttentionMask.getShape().getLastDimension() + : "attentionMask_seq = past_seq + new_input_seq"; + + // Forward with candidates in batch input + NDList candidateModelInput = + prepareInput( + candidateInputIds, + kCopyPastAttentionMask, + searchState.getPastOutputIds().getShape().getLastDimension(), + config.getK()); + CausalLMOutput candidateOutput = + lmBlock.forward(candidateModelInput, kCopyPastKeyValues, manager); + + NDList generatedOutput = + StepGeneration.constrastiveStepGeneration( + topKIds, + searchState.getLogits(), + searchState.getPastHiddenStates(), + candidateOutput.getHiddenState(), + positionOffset, + config.getAlpha()); + + // Update searchState for next loop + searchState = + updateSearchState(searchState, candidateOutput, generatedOutput, manager); + + // Memory + NDScope.unregister( + searchState.getPastOutputIds(), + searchState.getPastAttentionMask(), + searchState.getLogits(), + searchState.getPastHiddenStates()); + NDScope.unregister(searchState.getPastKeyValues()); + } + + // TODO: , delete the sentence and add it to result. 
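+            // For now the loop only terminates on maxSeqLength; per-sequence eosTokenId
+            // termination (see SeqBatcher.exitCriteria) is not wired in here yet.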
+ if (searchState.getPastOutputIds().getShape().get(1) >= config.getMaxSeqLength()) { + break; + } + } + + return searchState.getPastOutputIds(); + } + + private static ContrastiveBatchTensorList updateSearchState( + ContrastiveBatchTensorList searchState, + CausalLMOutput candidateOutput, + NDList generatedOutput, + NDManager manager) { + // Update searchState for next iteration + assert candidateOutput.getLogits().getShape().get(1) == 1 + : "dimension check: here, outputLogits corresponds to inputSeq == 1"; + long numBatch = searchState.getLogits().getShape().get(0); + long logitsDim = searchState.getLogits().getShape().get(1); + long pastSeqLengthPriorUpdate = searchState.getPastOutputIds().getShape().get(1); + long numHeads = searchState.getPastKeyValues().get(0).getShape().get(1); + long kvDim = searchState.getPastKeyValues().get(0).getShape().get(3); + long hiddenDim = searchState.getPastHiddenStates().getShape().get(2); + long k = candidateOutput.getLogits().getShape().get(0) / numBatch; + + // [batch, 1] + NDArray select = generatedOutput.get(1); + NDIndex selectIndex = + new NDIndex( + "{}, {}, ...", + manager.arange(0, numBatch, 1, DataType.INT64), + select.flatten()); + + // Take from candidateOutput + // [batch, k, inputSeq=1, logitsDim] --select--> [batch, logitDim] + NDArray nextLogits = + candidateOutput.getLogits().reshape(numBatch, k, logitsDim).get(selectIndex); + + // Take from candidateOutput + // [batch * k, heads, seq_past, feature] --select--> [batch, heads, seq_past, feature] + Function fn = + ndarray -> + ndarray.reshape(numBatch, k, numHeads, pastSeqLengthPriorUpdate + 1, kvDim) + .get(selectIndex); + NDList nextPastKeyValue = + new NDList( + candidateOutput.getPastKeyValuesList().stream() + .map(fn) + .collect(Collectors.toList())); + + // To be concatenated into searchState.pastHiddenStates + // [batch * k, inputSeq=1, hiddenDim] + NDArray newHiddenState = candidateOutput.getHiddenState(); + assert newHiddenState.getManager() == manager : "possible leaky memory"; + NDArray nextPastHiddenStates = + searchState + .getPastHiddenStates() + .concat( + newHiddenState.reshape(numBatch, k, 1, hiddenDim).get(selectIndex), + 1); + + // To be concatenated into searchState.outputIds + // [batch, seq_past] + NDArray outputIds = generatedOutput.get(0); + NDArray nextOutputIds = searchState.getPastOutputIds().concat(outputIds, 1); + + // [batch, seq_past] + NDArray nextPastAttentionMask = + searchState + .getPastAttentionMask() + .concat(manager.ones(new Shape(numBatch, 1), DataType.INT64), 1); + + return new ContrastiveBatchTensorList( + nextOutputIds, + nextPastAttentionMask, + nextPastHiddenStates, + nextLogits, + nextPastKeyValue, + new long[] {}); + } + + private NDArray prepareAttentionMaskOffset(NDArray inputIds, SearchConfig config) { + // prepare attentionMask and positionOffset + // Used to initialize the search + boolean suffixPadding = config.isSuffixPadding(); + NDManager manager = inputIds.getManager(); + int numBatch = Math.toIntExact(inputIds.getShape().get(0)); + int initSeqSize = Math.toIntExact(inputIds.getShape().get(1)); + NDArray attentionMask = + manager.ones(new Shape(1, inputIds.getShape().getLastDimension()), DataType.INT64) + .reshape(1, -1) + .repeat(0, numBatch); + + // Linear search from left to find the first position that's not padTokenId. 
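+        // With suffixPadding the scan instead stops at the first padTokenId and zeroes the mask
+        // over [idx, initSeqSize); otherwise it stops at the first non-pad token, zeroes the
+        // mask over [0, idx), and records idx as the position offset of that sequence.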
+ long[][] offset = new long[numBatch][1]; + for (int i = 0; i < numBatch; i++) { + long[] aSequence = inputIds.get("{},:", i).toLongArray(); + int idx = 0; + while (idx < initSeqSize) { + if (suffixPadding && aSequence[idx] == config.getPadTokenId() + || !suffixPadding && aSequence[idx] != config.getPadTokenId()) { + break; + } + idx++; + } + attentionMask.set( + new NDIndex( + "{},{}:{}", + i, + suffixPadding ? idx : 0, + suffixPadding ? initSeqSize : idx), + 0); + if (!suffixPadding) { + offset[i][0] = idx; + } + } + positionOffset = manager.create(offset); + return attentionMask; + } + + private NDList prepareInput( + NDArray inputIds, NDArray attentionMask, long pastSeqLength, int repeat) { + // Pack the model input + NDArray positionIds = + inputIds.getManager() + .arange( + pastSeqLength, + pastSeqLength + inputIds.getShape().getLastDimension(), + 1, + DataType.INT64) + .expandDims(0) + .repeat(0, inputIds.getShape().get(0)); + + NDArray positionIdsShifted = positionIds.subi(positionOffset.repeat(0, repeat)); + positionIds = positionIdsShifted.maximum(positionIdsShifted.zerosLike()); + + return new NDList(inputIds, positionIds, attentionMask); + } + + public NDArray forward(NDArray inputIds) { + switch (searchName) { + case "greedy": + return greedySearch(inputIds); + case "contrastive": + return contrastiveSearch(inputIds); + default: + throw new IllegalArgumentException( + "searchName not correctly specified. Please choose among: {greedy, beam," + + " contrastive}"); + } + } + + /** {@inheritDoc} */ + @Override + protected NDList forwardInternal( + ParameterStore parameterStore, + NDList inputs, + boolean training, + PairList params) { + return new NDList(forward(inputs.get(0))); + } + + /** {@inheritDoc} */ + @Override + public Shape[] getOutputShapes(Shape[] inputShapes) { + return new Shape[] {}; + } +} diff --git a/api/src/main/java/ai/djl/modality/nlp/generate/SearchConfig.java b/api/src/main/java/ai/djl/modality/nlp/generate/SearchConfig.java new file mode 100644 index 00000000000..8169f223c63 --- /dev/null +++ b/api/src/main/java/ai/djl/modality/nlp/generate/SearchConfig.java @@ -0,0 +1,122 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.modality.nlp.generate; + +public class SearchConfig { + + private int k; + private float alpha; + private int beam; + private int maxSeqLength; + private long padTokenId; + private long eosTokenId; + private boolean suffixPadding; + + /** Constructs a new ContrastiveSearchConfig object with default values. */ + public SearchConfig() { + this.k = 4; + this.alpha = 0.6f; + this.beam = 3; + this.maxSeqLength = 30; + this.eosTokenId = 50256; + this.padTokenId = 50256; + this.suffixPadding = true; + } + + /** + * Gets the value of the k. + * + * @return the value of k + */ + public int getK() { + return k; + } + + public void setK(int k) { + this.k = k; + } + + /** + * Gets the value of the alpha. 
+ * + * @return the value of alpha + */ + public float getAlpha() { + return alpha; + } + + public void setAlpha(float alpha) { + this.alpha = alpha; + } + + /** + * Gets the value of the beam. + * + * @return the value of beam + */ + public int getBeam() { + return beam; + } + + public void setBeam(int beam) { + this.beam = beam; + } + + /** + * Gets the value of the maxSeqLength. + * + * @return the value of maxSeqLength + */ + public int getMaxSeqLength() { + return maxSeqLength; + } + + public void setMaxSeqLength(int maxSeqLength) { + this.maxSeqLength = maxSeqLength; + } + + /** + * Gets the value of the padTokenId. + * + * @return the value of padTokenId + */ + public long getPadTokenId() { + return padTokenId; + } + + public void setPadTokenId(long padTokenId) { + this.padTokenId = padTokenId; + } + + /** + * Gets the value of the eosTokenId. + * + * @return the value of eosTokenId + */ + public long getEosTokenId() { + return eosTokenId; + } + + /** + * Gets the value of the suffixPadding. + * + * @return the value of suffixPadding + */ + public boolean isSuffixPadding() { + return suffixPadding; + } + + public void setSuffixPadding(boolean suffixPadding) { + this.suffixPadding = suffixPadding; + } +} diff --git a/api/src/main/java/ai/djl/modality/nlp/generate/SeqBatcher.java b/api/src/main/java/ai/djl/modality/nlp/generate/SeqBatcher.java new file mode 100644 index 00000000000..04d8ae0e947 --- /dev/null +++ b/api/src/main/java/ai/djl/modality/nlp/generate/SeqBatcher.java @@ -0,0 +1,267 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.modality.nlp.generate; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.NDScope; +import ai.djl.ndarray.index.NDIndex; +import ai.djl.ndarray.types.Shape; + +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; + +// This stores the search state (BatchTensorList), the control variables (e.g. seqLength, offSets, +// etc), and batch operations (merge, trim, exitCriteria, etc) on BatchTensorList. +public class SeqBatcher { + NDManager manager; + + long batchSize; + + long seqLength; + + // [batch] stores the uid and is trimmed or enhanced correspondingly to the batch. + NDArray batchUid; + + // Minheap with lazy removal + // map: batchIdx -> offset + NDArray offSets; + + // This is a struct that contains NDArrays with batch dimension + BatchTensorList data; + + // batchIndex -> seqEndPosition + private Map exitIndexEndPosition; + + SeqBatcher(BatchTensorList data, NDArray batchUid, NDArray offSets, NDManager manager) { + this.manager = manager.newSubManager(); + this.data = data; + this.batchUid = batchUid.getShape().dimension() == 2 ? batchUid : batchUid.reshape(-1, 1); + this.offSets = offSets.getShape().hashCode() == 2 ? 
offSets : offSets.reshape(-1, 1); + batchSize = data.getPastOutputIds().getShape().get(0); + seqLength = data.getPastOutputIds().getShape().get(1); + exitIndexEndPosition = new ConcurrentHashMap<>(); + } + + public BatchTensorList getData() { + return data; + } + + /** Add new batch. Modify the batch dimension and the left padding. */ + public void addBatch(SeqBatcher seqBatcherNew) { + merge(this, seqBatcherNew, seqLength - seqBatcherNew.seqLength); + // manager and finishedSequences stay the same; + } + + /** Merge two batchers together. Modify the batch dimension and the left padding. */ + private void merge(SeqBatcher seqBatcher1, SeqBatcher seqBatcher2, long seqDelta) { + if (seqDelta < 0) { + SeqBatcher swapTmp = seqBatcher1; + seqBatcher1 = seqBatcher2; + seqBatcher2 = swapTmp; + seqDelta = -seqDelta; + } + + try (NDScope scope = new NDScope()) { + scope.suppressNotUsedWarning(); + + NDList list1 = seqBatcher1.data.getList(); + NDList list2 = seqBatcher2.data.getList(); + NDList merged = new NDList(list1.size()); + long[] seqDimOrder = seqBatcher1.data.getSeqDimOrder(); + for (int i = 0; i < list1.size(); i++) { + NDArray batch1 = list1.get(i); + NDArray batch2 = list2.get(i); + if (seqDelta == 0) { + // no need to pad + batch1 = batch1.concat(batch2, 0); + merged.add(batch1); + continue; + } + + long[] shape1 = batch1.getShape().getShape(); + long[] shape2 = batch2.getShape().getShape(); + long padTokenId = 220; + + // Augment the larger, batch1 + long[] shapeDelta = batch1.getShape().getShape(); + shapeDelta[0] = shape2[0]; + NDArray deltaArray; + if (i == 0) { + // The outputTokenIds is padded with padTokenId + deltaArray = + manager.full(new Shape(shapeDelta), padTokenId, batch1.getDataType()); + } else { + // The rest e.g. attentionMask, kvCache, hiddenStates are padded with 0 + deltaArray = manager.zeros(new Shape(shapeDelta), batch1.getDataType()); + } + batch1 = batch1.concat(deltaArray, 0); + + // Get the ndIndex used to set the extended part of batch1 to be batch2. + NDIndex ndIndex; + // Find the ordinal number of the sequence dimension + if (seqDimOrder[i] > 0) { + // Has a valid sequence dimension + ndIndex = new NDIndex("{}:", seqBatcher1.batchSize); + int order = 1; + while (order < seqDimOrder[i]) { + ndIndex = ndIndex.addAllDim(); + order++; + } + assert seqDelta + shape2[order] == shape1[order] + : "Wrong shapes. batch1 and batch2 are not mergable"; + ndIndex = ndIndex.addSliceDim(seqDelta, shape1[order]).addEllipseDim(); + } else { + // Only batch dimension, no valid sequence dimension + ndIndex = new NDIndex("{}:, ...", seqBatcher1.batchSize); + } + + // Copy batch2 to the extended part in batch1 + batch1.set(ndIndex, batch2); + merged.add(batch1); + } + data = data.fromList(merged, data.getSeqDimOrder()); + + batchSize = seqBatcher1.batchSize + seqBatcher2.batchSize; + batchUid = seqBatcher1.batchUid.concat(seqBatcher2.batchUid, 0); + offSets = seqBatcher1.offSets.concat(seqBatcher2.offSets.addi(seqDelta), 0); + seqLength = seqBatcher1.seqLength; + + // memory + NDScope.unregister(batchUid, offSets); + NDScope.unregister(merged); + } + } + + /** + * Check which batch needs to exit, according certain criteria like EOS or maxLength. It is an + * iteration over batch and is thus also considered as batch operation. 
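+     * A batch entry exits when its latest output token equals eosTokenId or when its effective
+     * length (seqLength - offSet) reaches maxLength; its end position is then recorded for
+     * collectAndTrim.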
+ */ + public void exitCriteria(NDArray outputIds, long maxLength, long eosTokenId) { + long[] outputIdsArray = outputIds.toLongArray(); + long[] offSetsArray = offSets.toLongArray(); + for (int i = 0; i < outputIdsArray.length; i++) { + if (seqLength - offSetsArray[i] >= maxLength || outputIdsArray[i] == eosTokenId) { + if (!exitIndexEndPosition.containsKey((long) i)) { + exitIndexEndPosition.put((long) i, seqLength); + } + } + } + } + + /** Collect the finished sequences and trim the left padding. */ + public Map collectAndTrim() { + if (exitIndexEndPosition.isEmpty()) { + return new ConcurrentHashMap<>(); + } + Map finishedSequences = new ConcurrentHashMap<>(); + + try (NDScope scope = new NDScope()) { + scope.suppressNotUsedWarning(); + // Collect the results into finishedSequences + Set exitIndices = new HashSet<>(); + for (Map.Entry entry : exitIndexEndPosition.entrySet()) { + // batchIndex -> seqEndPosition + long batchIndex = entry.getKey(); + long seqEndPosition = entry.getValue(); + long uid = batchUid.getLong(batchIndex); + long offSet = offSets.getLong(batchIndex); + NDArray output = + data.getPastOutputIds() + .get("{}, {}:{}", batchIndex, offSet, seqEndPosition); + finishedSequences.put(uid, output); + exitIndices.add(batchIndex); + + NDScope.unregister(output); + } + + // Find the batch indices of the non-finished sequences. + long[] keepIndices = new long[Math.toIntExact(batchSize) - exitIndices.size()]; + int j = 0; + for (long i = 0; i < batchSize; i++) { + if (!exitIndices.contains(i)) { + keepIndices[j++] = i; + } + } + + if (keepIndices.length == 0) { + batchUid = manager.create(new Shape(0, 1), batchUid.getDataType()); + offSets = manager.create(new Shape(0, 1), offSets.getDataType()); + data = null; + batchSize = 0; + seqLength = 0; + exitIndexEndPosition = new ConcurrentHashMap<>(); + + NDScope.unregister(batchUid, offSets); + return finishedSequences; + } + + NDIndex ndIndex = new NDIndex("{}", manager.create(keepIndices)); + batchUid = batchUid.get(ndIndex).reshape(-1, 1); + offSets = offSets.get(ndIndex).reshape(-1, 1); + long trimSeq = offSets.min(new int[] {0}).toLongArray()[0]; + offSets = offSets.subi(trimSeq); + + // Trim batch, and sequence dim if needed + NDList list = data.getList(); + NDList newList = new NDList(list.size()); + long[] seqDimOrder = data.getSeqDimOrder(); + for (int i = 0; i < list.size(); i++) { + NDArray batch = list.get(i); + if (trimSeq == 0) { + // no need to trim + ndIndex = new NDIndex("{}, ...", manager.create(keepIndices)); + newList.add(batch.get(ndIndex)); + continue; + } + + // Get the ndIndex used to keep the entries and trim the rest + // Find the ordinal number of the sequence dimension + if (seqDimOrder[i] > 0) { + // Has a valid sequence dimension + ndIndex = new NDIndex("{}", manager.create(keepIndices)); + int order = 1; + while (order < seqDimOrder[i]) { + ndIndex = ndIndex.addAllDim(); + order++; + } + ndIndex = ndIndex.addSliceDim(trimSeq, seqLength).addEllipseDim(); + } else { + // Only batch dimension, no valid sequence dimension + ndIndex = new NDIndex("{}, ...", manager.create(keepIndices)); + } + // Keep the indexed entries and trim the rest + newList.add(batch.get(ndIndex)); + } + data = data.fromList(newList, data.getSeqDimOrder()); + + batchSize -= exitIndexEndPosition.size(); + seqLength -= trimSeq; + + exitIndexEndPosition = new ConcurrentHashMap<>(); + + // memory + NDScope.unregister(newList); + NDScope.unregister(batchUid, offSets); + + return finishedSequences; + } + } + + public boolean 
sequenceComplete() { + return !exitIndexEndPosition.isEmpty(); + } +} diff --git a/api/src/main/java/ai/djl/modality/nlp/generate/StepGeneration.java b/api/src/main/java/ai/djl/modality/nlp/generate/StepGeneration.java new file mode 100644 index 00000000000..bfafa822eb5 --- /dev/null +++ b/api/src/main/java/ai/djl/modality/nlp/generate/StepGeneration.java @@ -0,0 +1,150 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.modality.nlp.generate; + +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.index.NDIndex; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; + +public final class StepGeneration { + + private StepGeneration() {} + + public static NDList constrastiveStepGeneration( + NDArray topKIds, + NDArray logits, + NDArray contextHiddenStates, + NDArray topkHiddenStates, + NDArray offSets, + float alpha) { + /* + topKIds: [batch, topK] + attentionMask: [batch, past_seq] + logits: [batch, vocabSize] + contextHiddenStates: [batch, past_seq, dim] + topkHiddenStates: [batch*topK, seq=1, dim] + attentionMaskSlice: [batch, 2]: (startPosition, endPosition) + */ + + long batch = topKIds.getShape().get(0); + long topK = topKIds.getShape().get(1); + long hiddenDim = topkHiddenStates.getShape().getLastDimension(); + + // [batch*topK, seq=1, dim] -> [batch, topK, dim] + topkHiddenStates = topkHiddenStates.reshape(batch, topK, hiddenDim); + + // [batch, topK, dim] * [batch, past_seq, dim] -> [batch, topK, past_seq] + topkHiddenStates = topkHiddenStates.normalize(2, 2); + contextHiddenStates = contextHiddenStates.normalize(2, 2); + NDArray cosSimilarity = + topkHiddenStates.batchMatMul(contextHiddenStates.transpose(0, 2, 1)); + + // Deactivate entries (batch_idx, :, zero_attention_idx_slice) in max{cosSim} step + long[] offSetsArray = offSets.toLongArray(); + for (int i = 0; i < offSetsArray.length; i++) { + cosSimilarity.set(new NDIndex("{}, :, {}:{}", i, 0, offSetsArray[i]), -1); + } + + // [batch, topK, past_seq] -> [batch, topK] + NDArray topkScorePart1 = cosSimilarity.max(new int[] {2}); + assert topkScorePart1.getShape().getShape().length == 2 : "Wrong output size"; + // [batch, logitDim].gather([batch, topK) -> [batch, topK] + NDArray topkScorePart2 = logits.softmax(1).gather(topKIds, 1); + NDArray topkScore = topkScorePart2.muli(1 - alpha).subi(topkScorePart1.muli(alpha)); + + // [batch, topK] => [batch, 1] + NDArray select = topkScore.argMax(1); + NDIndex selectIndex = + new NDIndex( + "{}, {}, ...", + logits.getManager().arange(0, topKIds.getShape().get(0), 1, DataType.INT64), + select); + NDArray outputIds = topKIds.get(selectIndex).reshape(-1, 1); + return new NDList(outputIds, select); + } + + // TODO: add support of Einstein summation: + // a = torch.randn(batch, past_seq, dim) + // b = torch.randn(batch, topK, dim) + // result = torch.einsum('bik,bjk->bij', a, b) + + public static NDArray greedyStepGen(NDArray logits) { + // logits: [batch, seq, probDim] + assert 
logits.getShape().getShape().length == 3 : "unexpected input"; + logits = logits.get(":, -1, :"); + return logits.argMax(-1).expandDims(1); // [batch, vacDim] + } + + public static NDList beamStepGeneration( + NDArray lastProbs, NDArray logits, long numBatch, long numBeam) { + // [batch * beamSource, seq, probDim] -> [batch, beamSource, probDim] + NDArray allProbs = logits.get(":, -1, :").softmax(1).reshape(numBatch, numBeam, -1); + + // Argmax over the probs in the prob dimension. + // [batch, beamSource, probDim] -> [batch, beamSource, beamChild] + NDList topK = allProbs.topK(Math.toIntExact(numBeam), -1, true, false); + NDArray outputIs = topK.get(1); + NDArray stepProbs = topK.get(0); + + // Chain the probability + // [batch, beamSource] -> [batch, beamSource, 1] + lastProbs = lastProbs.reshape(numBatch, numBeam, 1); + // [batch, beamSource, beamChild] + NDArray newProbs = stepProbs.muli(lastProbs); + + // Argmax over the (beamSource * beamChild) dimension + topK = + newProbs.reshape(numBatch, numBeam * numBeam) + .topK(Math.toIntExact(numBeam), -1, true, false); + + // The select indices act on (beamSource, beamChild) dimension. Decides how the new + // generated tokenIds correspond to the past tokenIds. + // [batch, beamNew]. + NDArray select = topK.get(1); + // Act on [batch, beam, ...] dimension and the output will be [batch, beam, ...] + NDIndex selectIndex = + new NDIndex( + "{}, {}, ...", + logits.getManager() + .arange(0, numBatch, 1, DataType.INT64) + .expandDims(1) + .repeat(1, numBeam), + select); + + // [batch, beamNew] + outputIs = outputIs.reshape(numBatch, numBeam * numBeam).get(selectIndex).expandDims(2); + // [batch, beamNew] + newProbs = newProbs.reshape(numBatch, numBeam * numBeam).get(selectIndex).normalize(1, 1); + + /* During the beam selection process, some source beams are selected several times while + some source beams are not selected even once. The pastOutputs should be reselected to + have the right correspondence to the newInputIds. + */ + // [batch, beamNew] + assert select.getDataType() == DataType.INT64 : "Wrong output! Expect integer division"; + assert select.getShape().getShape().length == 2 : "Wrong size. Expect [batch, beamNew]"; + // For each batch, convert the index1 in beamSource*beamChild dimension to its index2 in + // beamSource dimension: index2 = index1 / numBeam. + long[] index = select.toLongArray(); + for (int i = 0; i < index.length; i++) { + index[i] = Math.floorDiv(index[i], numBeam); + } + NDArray sourceBeamSelected = + logits.getManager().create(index, new Shape(numBatch, numBeam)); + + return new NDList(outputIs, newProbs, sourceBeamSelected); + } + // TODO: implement pytorch floor_divide. +} diff --git a/api/src/main/java/ai/djl/modality/nlp/generate/package-info.java b/api/src/main/java/ai/djl/modality/nlp/generate/package-info.java new file mode 100644 index 00000000000..aa997f6a2b7 --- /dev/null +++ b/api/src/main/java/ai/djl/modality/nlp/generate/package-info.java @@ -0,0 +1,15 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions + * and limitations under the License. + */ + +/** Contains utility classes for image manipulation. */ +package ai.djl.modality.nlp.generate; diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/GPT2PtLMBlock.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/GPT2PtLMBlock.java new file mode 100644 index 00000000000..bd4284923b5 --- /dev/null +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/GPT2PtLMBlock.java @@ -0,0 +1,102 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.pytorch.engine; + +import ai.djl.modality.nlp.generate.CausalLMOutput; +import ai.djl.modality.nlp.generate.GPTConfig; +import ai.djl.modality.nlp.generate.LMBlock; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDList; +import ai.djl.ndarray.NDManager; +import ai.djl.ndarray.index.NDIndex; +import ai.djl.ndarray.types.DataType; +import ai.djl.ndarray.types.Shape; +import ai.djl.nn.Block; +import ai.djl.pytorch.jni.IValue; +import ai.djl.pytorch.jni.IValueUtils; + +import java.util.Arrays; +import java.util.Collections; +import java.util.stream.Collectors; + +public class GPT2PtLMBlock extends LMBlock { + Block[] blocks; + GPTConfig config; + + public GPT2PtLMBlock(GPTConfig gptConfig, Block[] blocks) { + config = gptConfig; + this.blocks = blocks; + } + + private NDList dummyPastKeyValues(NDArray inputIds, NDManager manager) { + long numBatch = inputIds.getShape().get(0); + long kvDim = config.getKvDim(); + int numAttentionHeads = config.getNumAttentionHeads(); + int numLayers = config.getNumLayers(); + + NDArray keyOrValue = manager.zeros(new Shape(numBatch, numAttentionHeads, 1, kvDim)); + NDList output = new NDList(); + output.addAll(Collections.nCopies(2 * numLayers, keyOrValue)); + return output; + } + + /** {@inheritDoc} */ + @Override + public CausalLMOutput forward(NDList input, NDList pastKeyValues, NDManager manager) { + // inputIds, positionIds, attentionMask + long batchSize = input.get(0).getShape().get(0); + boolean flagDummyKvCach = pastKeyValues == null; + if (flagDummyKvCach) { + pastKeyValues = dummyPastKeyValues(input.get(0), manager); + NDArray attentionMask = input.get(2); + attentionMask = + manager.zeros(new Shape(batchSize, 1), DataType.INT64) + .concat(attentionMask, -1); + input = new NDList(input.get(0), input.get(1), attentionMask); + } + + IValue[] inputNative = + input.stream() + .map(object -> IValue.from((PtNDArray) object)) + .toArray(IValue[]::new); + IValue resultIValue = + ((PtSymbolBlock) blocks[0]) + .forward( + inputNative[0], + inputNative[1], + inputNative[2], + IValueUtils.toTupleIValue( + pastKeyValues, new long[] {config.getNumLayers(), 2})); + + NDList output = resultIValue.toNDList((PtNDManager) manager); + Arrays.stream(inputNative).forEach(IValue::close); + + NDArray logitsOutput = output.get(0); + NDList pastKeyValuesOutput = output.subNDList(1, config.getNumLayers() * 2 + 1); + 
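+        // Expected layout of the native output: index 0 is the logits, indices 1 .. 2*numLayers
+        // are the past key/value tensors; when the output is long enough, hidden states are read
+        // starting at index 2*numLayers + 2, otherwise a scalar placeholder is used below.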
NDArray hiddenStatesOutput = manager.zeros(new Shape(1)); + if (output.size() > config.getNumLayers() * 2 + 2) { + hiddenStatesOutput = output.subNDList(config.getNumLayers() * 2 + 2).get(0); + } + + if (flagDummyKvCach) { + NDIndex index2 = new NDIndex(":, :, 1:, ..."); + pastKeyValuesOutput = + new NDList( + pastKeyValuesOutput.stream() + .map(object -> object.get(index2)) + .collect(Collectors.toList())); + } + + return new CausalLMOutput(logitsOutput, hiddenStatesOutput, pastKeyValuesOutput); + } +} diff --git a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngine.java b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngine.java index 7f48547b4d3..30d2631c0a8 100644 --- a/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngine.java +++ b/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/engine/PtEngine.java @@ -16,7 +16,10 @@ import ai.djl.Model; import ai.djl.engine.Engine; import ai.djl.engine.EngineException; +import ai.djl.modality.nlp.generate.GPTConfig; +import ai.djl.modality.nlp.generate.LMBlock; import ai.djl.ndarray.NDManager; +import ai.djl.nn.Block; import ai.djl.nn.SymbolBlock; import ai.djl.pytorch.jni.JniUtils; import ai.djl.pytorch.jni.LibUtils; @@ -145,6 +148,16 @@ public NDManager newBaseManager(Device device) { return PtNDManager.getSystemManager().newSubManager(device); } + /** {@inheritDoc} */ + @Override + public LMBlock newLMBlock(String languageModel, GPTConfig gptConfig, Block[] blocks) { + if ("GPT2".equals(languageModel)) { + return new GPT2PtLMBlock(gptConfig, blocks); + } else { + throw new UnsupportedOperationException("Not supported."); + } + } + /** {@inheritDoc} */ @Override public GradientCollector newGradientCollector() { diff --git a/examples/src/main/java/ai/djl/examples/inference/AutoRegressiveSearch.java b/examples/src/main/java/ai/djl/examples/inference/AutoRegressiveSearch.java new file mode 100644 index 00000000000..c58dc6a43e4 --- /dev/null +++ b/examples/src/main/java/ai/djl/examples/inference/AutoRegressiveSearch.java @@ -0,0 +1,115 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ +package ai.djl.examples.inference; + +import ai.djl.MalformedModelException; +import ai.djl.Model; +import ai.djl.modality.nlp.generate.LMBlock; +import ai.djl.modality.nlp.generate.LMSearch; +import ai.djl.modality.nlp.generate.SearchConfig; +import ai.djl.ndarray.NDArray; +import ai.djl.ndarray.NDManager; +import ai.djl.nn.Block; +import ai.djl.repository.zoo.ModelNotFoundException; +import ai.djl.util.Pair; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.List; + +public final class AutoRegressiveSearch { + + LMBlock lmBlockPt; + + List modelsPt; + + public AutoRegressiveSearch() + throws ModelNotFoundException, MalformedModelException, IOException { + String[] modelUrls = {"https://djl-misc.s3.amazonaws.com/test/models/gpt2/gpt2.pt.zip"}; + Pair> result = LLMBlock.getLMBlock(modelUrls, "PyTorch", "GPT2"); + lmBlockPt = (LMBlock) result.getKey(); + modelsPt = result.getValue(); + } + + public void main(String[] args) + throws ModelNotFoundException, MalformedModelException, IOException { + mainContrastivePt(args); + mainGreedyPt(args); + } + + public boolean mainContrastivePt(String[] args) { + LMBlock lmBlock = lmBlockPt; + try (NDManager manager = NDManager.newBaseManager()) { + SearchConfig config = new SearchConfig(); + config.setMaxSeqLength(60); + config.setAlpha(0.6f); + config.setK(3); + + // [r'DeepMind Company is', + // r'Memories follow me left and right. I can'] + NDArray inputIds = + manager.create( + new long[][] { + {220, 220, 220, 220, 220, 220, 29744, 28478, 5834, 318}, + {13579, 1749, 1061, 502, 1364, 290, 826, 13, 314, 460} + }); + config.setPadTokenId(220); + config.setSuffixPadding(false); + + LMSearch lmSearch; + lmSearch = new LMSearch(lmBlock, "constrastive", config); + + NDArray output = lmSearch.contrastiveSearch(inputIds); + NDArray expected = + manager.create( + new long[][] { + {284, 8494, 3716, 2761, 11, 884, 355, 1692, 1535, 11}, + {4436, 329, 257, 2910, 1332, 13, 632, 373, 257, 3487} + }); + return output.get(":, -10:").equals(expected); + } + } + + public boolean mainGreedyPt(String[] args) { + LMBlock lmBlock = lmBlockPt; + try (NDManager manager = NDManager.newBaseManager()) { + + SearchConfig config = new SearchConfig(); + config.setMaxSeqLength(60); + + // [r'DeepMind Company is', + // r'Memories follow me left and right. I can'] + NDArray inputIds = + manager.create( + new long[][] { + {220, 220, 220, 220, 220, 220, 29744, 28478, 5834, 318}, + {13579, 1749, 1061, 502, 1364, 290, 826, 13, 314, 460} + }); + config.setPadTokenId(220); + config.setSuffixPadding(false); + + LMSearch lmSearch = new LMSearch(lmBlock, "greedy", config); + + NDArray output = lmSearch.greedySearch(inputIds); + NDArray expected = + manager.create( + new long[][] { + {389, 635, 257, 3756, 10131, 286, 6190, 9552, 8136, 329}, + {257, 6576, 13, 314, 460, 470, 3505, 262, 938, 640} + }); + return output.get(":, -10:").equals(expected); + } + } +} diff --git a/examples/src/main/java/ai/djl/examples/inference/GPTInference.java b/examples/src/main/java/ai/djl/examples/inference/GPTInference.java new file mode 100644 index 00000000000..7ebc66c9c7c --- /dev/null +++ b/examples/src/main/java/ai/djl/examples/inference/GPTInference.java @@ -0,0 +1,129 @@ +/* + * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. 
+
+            // [r'DeepMind Company is',
+            // r'Memories follow me left and right. I can']
+            NDArray inputIds =
+                    manager.create(
+                            new long[][] {
+                                {220, 220, 220, 220, 220, 220, 29744, 28478, 5834, 318},
+                                {13579, 1749, 1061, 502, 1364, 290, 826, 13, 314, 460}
+                            });
+            config.setPadTokenId(220);
+            config.setSuffixPadding(false);
+
+            LMSearch lmSearch = new LMSearch(lmBlock, "contrastive", config);
+
+            NDArray output = lmSearch.contrastiveSearch(inputIds);
+            NDArray expected =
+                    manager.create(
+                            new long[][] {
+                                {284, 8494, 3716, 2761, 11, 884, 355, 1692, 1535, 11},
+                                {4436, 329, 257, 2910, 1332, 13, 632, 373, 257, 3487}
+                            });
+            return output.get(":, -10:").equals(expected);
+        }
+    }
+
+    public boolean mainGreedyPt(String[] args) {
+        LMBlock lmBlock = lmBlockPt;
+        try (NDManager manager = NDManager.newBaseManager()) {
+            SearchConfig config = new SearchConfig();
+            config.setMaxSeqLength(60);
+
+            // [r'DeepMind Company is',
+            // r'Memories follow me left and right. I can']
+            NDArray inputIds =
+                    manager.create(
+                            new long[][] {
+                                {220, 220, 220, 220, 220, 220, 29744, 28478, 5834, 318},
+                                {13579, 1749, 1061, 502, 1364, 290, 826, 13, 314, 460}
+                            });
+            config.setPadTokenId(220);
+            config.setSuffixPadding(false);
+
+            LMSearch lmSearch = new LMSearch(lmBlock, "greedy", config);
+
+            NDArray output = lmSearch.greedySearch(inputIds);
+            NDArray expected =
+                    manager.create(
+                            new long[][] {
+                                {389, 635, 257, 3756, 10131, 286, 6190, 9552, 8136, 329},
+                                {257, 6576, 13, 314, 460, 470, 3505, 262, 938, 640}
+                            });
+            return output.get(":, -10:").equals(expected);
+        }
+    }
+}
diff --git a/examples/src/main/java/ai/djl/examples/inference/GPTInference.java b/examples/src/main/java/ai/djl/examples/inference/GPTInference.java
new file mode 100644
index 00000000000..7ebc66c9c7c
--- /dev/null
+++ b/examples/src/main/java/ai/djl/examples/inference/GPTInference.java
@@ -0,0 +1,129 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.examples.inference;
+
+import ai.djl.MalformedModelException;
+import ai.djl.Model;
+import ai.djl.huggingface.tokenizers.Encoding;
+import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
+import ai.djl.inference.Predictor;
+import ai.djl.modality.nlp.generate.LMBlock;
+import ai.djl.modality.nlp.generate.LMSearch;
+import ai.djl.modality.nlp.generate.SearchConfig;
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+import ai.djl.nn.Block;
+import ai.djl.repository.zoo.ModelNotFoundException;
+import ai.djl.translate.NoBatchifyTranslator;
+import ai.djl.translate.TranslateException;
+import ai.djl.translate.TranslatorContext;
+import ai.djl.util.Pair;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.List;
+
+/** An example of end-to-end GPT-2 text generation through a {@link Predictor} and a tokenizer translator. */
+public final class GPTInference {
+
+    private static final Logger logger = LoggerFactory.getLogger(GPTInference.class);
+
+    private GPTInference() {}
+
+    public static void main(String[] args)
+            throws ModelNotFoundException, MalformedModelException, IOException,
+                    TranslateException {
+        testPt();
+    }
+
+    private static void testPt()
+            throws ModelNotFoundException, MalformedModelException, IOException,
+                    TranslateException {
+        String[] modelUrls = {
+            "https://djl-misc.s3.amazonaws.com/test/models/gpt2/gpt2_init.pt.zip",
+        };
+        Pair<Block, List<Model>> result = LLMBlock.getLMBlock(modelUrls, "PyTorch", "GPT2");
+        LMBlock lmBlock = (LMBlock) result.getKey();
+        // An adapter class (LMBlock) and its forward call are needed because, as noted in the
+        // comments around L168-170, the search code should stay generic rather than be tied to a
+        // particular model.
+
+        SearchConfig config = new SearchConfig();
+        config.setMaxSeqLength(60);
+
+        String[] input = {"DeepMind Company is"};
+        try (Model model = Model.newInstance("GPT2PtGreedy")) {
+            // Change "greedy" to "contrastive" to use contrastive search instead
+            model.setBlock(new LMSearch(lmBlock, "greedy", config));
+
+            try (Predictor<String[], String> predictor =
+                    model.newPredictor(new GPTTranslator())) {
+                // As discussed in code review, the translator's pre/post-processing only covers
+                // the tokenizer's encoding and decoding. This is also why LMSearch extends
+                // AbstractBlock: it can be wrapped in a Model and reuse the translator.
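+                // For reference, the data flow through the predictor is:
+                //   String[] prompt -> processInput: tokenizer.encode -> token ids, shape [1, seq]
+                //   -> LMSearch.forward, which runs the search loop up to the configured
+                //      maxSeqLength -> generated ids -> processOutput: tokenizer.decode -> String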
+                String output = predictor.predict(input);
+
+                String expected =
+                        "DeepMind Company is a global leader in the field of artificial"
+                                + " intelligence and artificial intelligence. We are a leading provider"
+                                + " of advanced AI solutions for the automotive industry, including the"
+                                + " latest in advanced AI solutions for the automotive industry. We are"
+                                + " also a leading provider of advanced AI solutions for the automotive"
+                                + " industry, including the";
+
+                logger.info("{}", expected.equals(output));
+            }
+        }
+        result.getValue().forEach(Model::close);
+    }
+
+    private static class GPTTranslator implements NoBatchifyTranslator<String[], String> {
+
+        HuggingFaceTokenizer tokenizer;
+
+        GPTTranslator() {
+            tokenizer = HuggingFaceTokenizer.newInstance("gpt2");
+        }
+
+        /** {@inheritDoc} */
+        @Override
+        public String processOutput(TranslatorContext ctx, NDList list) {
+            long[] output = list.singletonOrThrow().toLongArray();
+            return tokenizer.decode(output);
+        }
+
+        /** {@inheritDoc} */
+        @Override
+        public NDList processInput(TranslatorContext ctx, String[] input) {
+            Encoding encoding = tokenizer.encode(input);
+            long[] inputIdsLong = encoding.getIds();
+            NDArray inputIds = ctx.getNDManager().create(inputIdsLong);
+            return new NDList(inputIds.expandDims(0));
+        }
+    }
+}
diff --git a/examples/src/main/java/ai/djl/examples/inference/LLMBlock.java b/examples/src/main/java/ai/djl/examples/inference/LLMBlock.java
new file mode 100644
index 00000000000..4bdf31470e0
--- /dev/null
+++ b/examples/src/main/java/ai/djl/examples/inference/LLMBlock.java
@@ -0,0 +1,128 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */
+package ai.djl.examples.inference;
+
+import ai.djl.MalformedModelException;
+import ai.djl.Model;
+import ai.djl.engine.Engine;
+import ai.djl.modality.nlp.generate.CausalLMOutput;
+import ai.djl.modality.nlp.generate.GPTConfig;
+import ai.djl.modality.nlp.generate.LMBlock;
+import ai.djl.ndarray.NDArray;
+import ai.djl.ndarray.NDList;
+import ai.djl.ndarray.NDManager;
+import ai.djl.ndarray.types.DataType;
+import ai.djl.ndarray.types.Shape;
+import ai.djl.nn.Block;
+import ai.djl.repository.zoo.Criteria;
+import ai.djl.repository.zoo.ModelNotFoundException;
+import ai.djl.training.util.ProgressBar;
+import ai.djl.util.Pair;
+
+import java.io.IOException;
+import java.util.LinkedList;
+import java.util.List;
+
+/** An example of building an engine-specific {@link LMBlock} and running it with and without a KV cache. */
+public final class LLMBlock {
+
+    private LLMBlock() {}
+
+    public static int main(String[] args)
+            throws ModelNotFoundException, MalformedModelException, IOException {
+        mainPt();
+        return 0;
+    }
+
+    public static Pair<Block, List<Model>> getLMBlock(
+            String[] modelUrls, String engine, String modelName)
+            throws ModelNotFoundException, MalformedModelException, IOException {
+        List<Model> models = new LinkedList<>();
+        // The model URLs can be replaced with local model files
+        Block[] blocks = new Block[modelUrls.length];
+        for (int i = 0; i < modelUrls.length; i++) {
+            Criteria<NDList, NDList> criteria =
+                    Criteria.builder()
+                            .setTypes(NDList.class, NDList.class)
+                            .optModelUrls(modelUrls[i])
+                            .optEngine(engine)
+                            .optProgress(new ProgressBar())
+                            .build();
+            Model model = criteria.loadModel();
+            blocks[i] = model.getBlock();
+            models.add(model);
+        }
+
+        return new Pair<>(
+                // Creating an LMBlock requires the engine-specific GPT2PtLMBlock, whose package
+                // (pytorch-engine) cannot be referenced directly from this module.
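+                // The construction is therefore dispatched through Engine#newLMBlock: the base
+                // Engine implementation throws UnsupportedOperationException, while PtEngine
+                // overrides it to return a GPT2PtLMBlock when the language model is "GPT2".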
+                Engine.getEngine(engine).newLMBlock(modelName, new GPTConfig(), blocks), models);
+    }
+
+    public static void mainPt()
+            throws ModelNotFoundException, MalformedModelException, IOException {
+        String[] modelUrls = {"https://djl-misc.s3.amazonaws.com/test/models/gpt2/gpt2.pt.zip"};
+
+        Pair<Block, List<Model>> result = LLMBlock.getLMBlock(modelUrls, "PyTorch", "GPT2");
+        LMBlock generator = (LMBlock) result.getKey();
+        List<Model> models = result.getValue();
+
+        try (NDManager manager = NDManager.newBaseManager()) {
+            /////////////////////////////////////////////
+            // Inference without cached key_values input
+            /////////////////////////////////////////////
+
+            int[] inputArray = {40, 2883, 6155, 351, 616, 13779};
+            int numBatch = 2;
+
+            NDArray inputIds = manager.create(inputArray, new Shape(2, inputArray.length / 2));
+
+            NDArray positionIds =
+                    manager.arange(0, inputIds.getShape().size(1), 1, DataType.INT64)
+                            .reshape(1, -1)
+                            .repeat(0, numBatch);
+
+            NDArray attentionMask = manager.ones(positionIds.getShape());
+
+            CausalLMOutput outInit =
+                    generator.forward(
+                            new NDList(inputIds, positionIds, attentionMask), null, manager);
+
+            /////////////////////////////////////////////
+            // Inference with cached key_values input
+            /////////////////////////////////////////////
+
+            long pastSeqLen = outInit.getPastKeyValuesList().get(0).getShape().size(2);
+            inputIds = manager.create(new int[] {404, 403, 402, 401}, new Shape(numBatch, 2));
+            positionIds =
+                    manager.arange(
+                                    pastSeqLen,
+                                    pastSeqLen + inputIds.getShape().getLastDimension(),
+                                    1,
+                                    DataType.INT64)
+                            .reshape(1, -1)
+                            .repeat(0, numBatch);
+            attentionMask =
+                    manager.ones(new Shape(1, pastSeqLen + inputIds.getShape().getLastDimension()))
+                            .reshape(1, -1)
+                            .repeat(0, numBatch);
+
+            generator.forward(
+                    new NDList(inputIds, positionIds, attentionMask),
+                    outInit.getPastKeyValuesList(),
+                    manager);
+        }
+        models.forEach(Model::close);
+    }
+}
diff --git a/examples/src/test/java/ai/djl/examples/inference/TextGenerationTest.java b/examples/src/test/java/ai/djl/examples/inference/TextGenerationTest.java
new file mode 100644
index 00000000000..166bfa333e1
--- /dev/null
+++ b/examples/src/test/java/ai/djl/examples/inference/TextGenerationTest.java
@@ -0,0 +1,41 @@
+/*
+ * Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
+ * with the License. A copy of the License is located at
+ *
+ * http://aws.amazon.com/apache2.0/
+ *
+ * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
+ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
+ * and limitations under the License.
+ */ +package ai.djl.examples.inference; + +import ai.djl.MalformedModelException; +import ai.djl.repository.zoo.ModelNotFoundException; +import ai.djl.testing.TestRequirements; + +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.io.IOException; + +public class TextGenerationTest { + + @Test + public void testTextGeneration() + throws ModelNotFoundException, MalformedModelException, IOException { + TestRequirements.engine("PyTorch"); + + String[] args = new String[] {}; + + // LMBlock + Assert.assertEquals(LLMBlock.main(args), 0); + + // LMSearch + AutoRegressiveSearch search = new AutoRegressiveSearch(); + Assert.assertTrue(search.mainContrastivePt(args)); + Assert.assertTrue(search.mainGreedyPt(args)); + } +}