Skip to content

Commit

Permalink
LMSearchOnPt
Browse files Browse the repository at this point in the history
  • Loading branch information
KexinFeng committed Jun 9, 2023
1 parent 9d7737c commit a5f322f
Show file tree
Hide file tree
Showing 18 changed files with 1,979 additions and 0 deletions.
7 changes: 7 additions & 0 deletions api/src/main/java/ai/djl/engine/Engine.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.modality.nlp.generate.GPTConfig;
import ai.djl.modality.nlp.generate.LMBlock;
import ai.djl.ndarray.NDManager;
import ai.djl.nn.Block;
import ai.djl.nn.SymbolBlock;
import ai.djl.training.GradientCollector;
import ai.djl.training.LocalParameterServer;
Expand Down Expand Up @@ -302,6 +305,10 @@ public SymbolBlock newSymbolBlock(NDManager manager) {
*/
public abstract NDManager newBaseManager(Device device);

public LMBlock newLMBlock(String languageModel, GPTConfig gptConfig, Block[] blocks) {
throw new UnsupportedOperationException("Not supported.");
}

/**
* Returns a new instance of {@link GradientCollector}.
*
Expand Down
111 changes: 111 additions & 0 deletions api/src/main/java/ai/djl/modality/nlp/generate/BatchTensorList.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.modality.nlp.generate;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;

// BatchTensorList represents a search state, and the NDArrays inside are updated in each iteration
// of the
// autoregressive loop.
// It is a struct consisting of NDArrays, whose first dimension is batch, and also contains
// sequence dimension (whose position in tensor's shape is specified by seqDimOrder).
// The SeqBatcher batch operations will operate on these two dimensions.
public abstract class BatchTensorList {
// [batch, seq_past]. seq-dim-size == |past_seq| + |inputIds|. Will grow.
private NDArray pastOutputIds;

// [batch, seq_past]
// The cache of past attentionMask. seq-dim-size == |past_seq| + |inputIds|. Will grow.
private NDArray pastAttentionMask;

// (k, v) * numLayer,
// kv: [batch, heads, seq_past, kvfeature]
// The cache of past sequence. seq-dim-size == |past_seq| + |inputIds|. Will grow.
private NDList pastKeyValues;

// Sequence dimension order among all dimensions for each element in the batch list.
private long[] seqDimOrder;

BatchTensorList() {}

BatchTensorList(NDList list, long[] seqDimOrder) {
this.seqDimOrder = seqDimOrder;
pastOutputIds = list.get(0);
pastAttentionMask = list.get(1);
pastKeyValues = list.subNDList(2);
}

BatchTensorList(
NDArray pastOutputIds,
NDArray pastAttentionMask,
NDList pastKeyValues,
long[] seqDimOrder) {
this.pastKeyValues = pastKeyValues;
this.pastOutputIds = pastOutputIds;
this.pastAttentionMask = pastAttentionMask;
this.seqDimOrder = seqDimOrder;
}

public abstract BatchTensorList fromList(NDList inputList, long[] seqDimOrder);

// The pastOutputIds has to be the first in the output list
public abstract NDList getList();

public long[] getSeqDimOrder() {
return seqDimOrder;
}

/**
* Gets the value of the pastOutputIds.
*
* @return the value of pastOutputIds
*/
public NDArray getPastOutputIds() {
return pastOutputIds;
}

public void setPastOutputIds(NDArray pastOutputIds) {
this.pastOutputIds = pastOutputIds;
}

/**
* Gets the value of the pastAttentionMask.
*
* @return the value of pastAttentionMask
*/
public NDArray getPastAttentionMask() {
return pastAttentionMask;
}

public void setPastAttentionMask(NDArray pastAttentionMask) {
this.pastAttentionMask = pastAttentionMask;
}

/**
* Gets the value of the pastKeyValues.
*
* @return the value of pastKeyValues
*/
public NDList getPastKeyValues() {
return pastKeyValues;
}

public void setPastKeyValues(NDList pastKeyValues) {
this.pastKeyValues = pastKeyValues;
}

public void setSeqDimOrder(long[] seqDimOrder) {
this.seqDimOrder = seqDimOrder;
}
}
76 changes: 76 additions & 0 deletions api/src/main/java/ai/djl/modality/nlp/generate/CausalLMOutput.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.modality.nlp.generate;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;

/** CausalLMOuput is used to contain multiple output of a language model. */
public class CausalLMOutput {

// [batch, seq, feature]
// The prob. conditional on a sequence that ends at an element in seq-dim. seq-dim-size =
// |inputIds|
private NDArray logits;

// [batch, seq, dim] * (layers+1) -> take -1
// The vec. rep. of a sequence that ends at an element in seq-dim. seq-dim-size = |inputIds|
private NDArray hiddenStates;

// (k, v) * numLayer,
// kv: [batch, heads, seq_past, feature]
// The cache of past sequence. seq-dim-size == |seq_past| + |inputIds|
private NDList pastKeyValuesList;

public CausalLMOutput(NDArray logits, NDList pastKeyValues) {
this.logits = logits;
this.pastKeyValuesList = pastKeyValues;
}

public CausalLMOutput(NDArray logits, NDArray hiddenState, NDList pastKeyValueList) {
this.logits = logits;
this.pastKeyValuesList = pastKeyValueList;
this.hiddenStates = hiddenState;
}

/**
* Gets the value of the logits.
*
* @return the value of logits
*/
public NDArray getLogits() {
return logits;
}

public void setLogits(NDArray logits) {
this.logits = logits;
}

/**
* Gets the value of the allHiddenStates.
*
* @return the value of allHiddenStates
*/
public NDArray getHiddenState() {
return hiddenStates;
}

/**
* Gets the value of the pastKeyValuesList.
*
* @return the value of pastKeyValuesList
*/
public NDList getPastKeyValuesList() {
return pastKeyValuesList;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.modality.nlp.generate;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;

class ContrastiveBatchTensorList extends BatchTensorList {
// [batch, seq_past, hiddenDim]
// The embed vector of the past seq. seq-dim-size = |past_seq|. Will grow.
private NDArray pastHiddenStates;

// [batch, vacabSize]. Only the last logits, used to recall candidate token.
private NDArray logits;

ContrastiveBatchTensorList(NDList list, long[] seqDimOrder) {
super(list.get(0), list.get(1), list.subNDList(4), seqDimOrder);
pastHiddenStates = list.get(2);
logits = list.get(3);
}

ContrastiveBatchTensorList(
NDArray pastOutputIds,
NDArray pastAttentionMask,
NDArray pastHiddenStates,
NDArray logits,
NDList pastKeyValues,
long[] seqDimOrder) {
super(pastOutputIds, pastAttentionMask, pastKeyValues, seqDimOrder);
this.pastHiddenStates = pastHiddenStates;
this.logits = logits;
}

public ContrastiveBatchTensorList() {}

@Override
public ContrastiveBatchTensorList fromList(NDList inputList, long[] seqDimOrder) {
return new ContrastiveBatchTensorList(inputList, seqDimOrder);
}

@Override
public NDList getList() {
// The pastOutputIds has to be the first in the output list
return new NDList(
getPastOutputIds(),
getPastAttentionMask(),
getPastHiddenStates(),
getLogits())
.addAll(getPastKeyValues());
}

/**
* Gets the value of the pastHiddenStates.
*
* @return the value of pastHiddenStates
*/
public NDArray getPastHiddenStates() {
return pastHiddenStates;
}

public void setPastHiddenStates(NDArray pastHiddenStates) {
this.pastHiddenStates = pastHiddenStates;
}

/**
* Gets the value of the logits.
*
* @return the value of logits
*/
public NDArray getLogits() {
return logits;
}

public void setLogits(NDArray logits) {
this.logits = logits;
}
}
57 changes: 57 additions & 0 deletions api/src/main/java/ai/djl/modality/nlp/generate/GPTConfig.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.modality.nlp.generate;

/** GPTConfig is used to store the GPT parameters used to select different versions of GPT. */
public class GPTConfig {
private int numAttentionHeads;
private int numLayers;
private long kvDim;

public GPTConfig() {
numAttentionHeads = 12;
numLayers = 12;
kvDim = 64;
}

/**
* Gets the value of the numAttentionHeads.
*
* @return the value of numAttentionHeads
*/
public int getNumAttentionHeads() {
return numAttentionHeads;
}

/**
* Gets the value of the numLayers.
*
* @return the value of numLayers
*/
public int getNumLayers() {
return numLayers;
}

public void setNumLayers(int numLayers) {
this.numLayers = numLayers;
}

/**
* Gets the value of the kvDim.
*
* @return the value of kvDim
*/
public long getKvDim() {
return kvDim;
}
}
Loading

0 comments on commit a5f322f

Please sign in to comment.