
Commit

address the comments
KexinFeng committed May 19, 2023
1 parent 4b3267b commit 2187786
Showing 16 changed files with 69 additions and 63 deletions.
2 changes: 1 addition & 1 deletion api/src/main/java/ai/djl/engine/Engine.java
@@ -307,7 +307,7 @@ public SymbolBlock newSymbolBlock(NDManager manager) {
*/
public abstract NDManager newBaseManager(Device device);

public LMBlock newLMAdapter(String languageModel, GPTConfig gptConfig, Block[] blocks)
public LMBlock newLMBlock(String languageModel, GPTConfig gptConfig, Block[] blocks)
throws ModelNotFoundException, MalformedModelException, IOException {
throw new UnsupportedOperationException("Not supported.");
}
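The factory method keeps its signature apart from the rename, so existing callers only swap newLMAdapter for newLMBlock. A minimal sketch, assuming the GPT-2 blocks have already been traced and loaded elsewhere; loadGpt2Blocks is a hypothetical placeholder (the real loading is shown in TestLMBlock.getLMBlock further down), and the call declares ModelNotFoundException, MalformedModelException, and IOException:

    Block[] blocks = loadGpt2Blocks(); // hypothetical helper, stands in for Criteria-based model loading
    LMBlock lmBlock =
            Engine.getEngine("PyTorch").newLMBlock("GPT2", new GPTConfig(), blocks);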
2 changes: 1 addition & 1 deletion api/src/main/java/ai/djl/ndarray/NDList.java
@@ -270,7 +270,7 @@ public NDList addAll(NDList other) {
* @return a view of the portion of this NDList
*/
public NDList subNDList(int fromIndex) {
if (fromIndex > size()) {
if (fromIndex >= size()) {
return null;
}
return subList(fromIndex, size());
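Tightening the bound means fromIndex == size() now yields null instead of an empty view. A small self-contained sketch of the new behavior, assuming null (rather than an empty NDList) is the intended contract for an out-of-range start index:

    import ai.djl.ndarray.NDList;
    import ai.djl.ndarray.NDManager;
    import ai.djl.ndarray.types.Shape;

    public final class SubNDListDemo {
        public static void main(String[] args) {
            try (NDManager manager = NDManager.newBaseManager()) {
                NDList list = new NDList(
                        manager.zeros(new Shape(2)),
                        manager.ones(new Shape(2)),
                        manager.ones(new Shape(2)));
                System.out.println(list.subNDList(2)); // view holding only the last array
                System.out.println(list.subNDList(3)); // fromIndex == size(): now null, not an empty view
            }
        }
    }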
5 changes: 0 additions & 5 deletions api/src/main/java/ai/djl/nn/AbstractBaseBlock.java
@@ -19,7 +19,6 @@
import ai.djl.ndarray.types.Shape;
import ai.djl.training.ParameterStore;
import ai.djl.training.initializer.Initializer;
import ai.djl.util.NativeResource;
import ai.djl.util.Pair;
import ai.djl.util.PairList;

@@ -109,10 +108,6 @@ protected abstract NDList forwardInternal(
boolean training,
PairList<String, Object> params);

protected NativeResource<Long> forwardInternal(NativeResource<Long>[] inputs) {
return null;
}

/**
* A helper for {@link Block#forward(ParameterStore, NDList, NDList, PairList)} after
* initialization.
@@ -3,12 +3,13 @@
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;

// BatchList represents a search state, and the NDArrays inside are updated in each iteration of the
// BatchTensorList represents a search state, and the NDArrays inside are updated in each iteration
// of the
// autoregressive loop.
// It is a struct consisting of NDArrays, whose first dimension is batch, and also contains
// sequence dimension (whose position in tensor's shape is specified by seqDimOrder).
// The SeqBatcher batch operations will operate on these two dimensions.
public abstract class BatchList {
public abstract class BatchTensorList {
// [batch, seq_past]. seq-dim-size == |past_seq| + |inputIds|. Will grow.
public NDArray pastOutputIds;

@@ -24,16 +25,16 @@ public abstract class BatchList {
// Sequence dimension order among all dimensions for each element in the batch list.
public long[] seqDimOrder;

BatchList() {}
BatchTensorList() {}

BatchList(NDList list, long[] seqDimOrder) {
BatchTensorList(NDList list, long[] seqDimOrder) {
this.seqDimOrder = seqDimOrder;
pastOutputIds = list.get(0);
pastAttentionMask = list.get(1);
pastKeyValues = list.subNDList(2);
}

BatchList(
BatchTensorList(
NDArray pastOutputIds,
NDArray pastAttentionMask,
NDList pastKeyValues,
@@ -44,7 +45,7 @@ public abstract class BatchList {
this.seqDimOrder = seqDimOrder;
}

public abstract BatchList fromList(NDList inputList, long[] seqDimOrder);
public abstract BatchTensorList fromList(NDList inputList, long[] seqDimOrder);

// The pastOutputIds has to be the first in the output list
public abstract NDList getList();
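As the comment above describes, every BatchTensorList flattens to an NDList whose first element is pastOutputIds, whose second is pastAttentionMask, and whose remainder is pastKeyValues, with seqDimOrder recording where the sequence axis sits in each tensor. A hedged illustration of that round trip; the concrete subclass instance and the seqDimOrder values are assumptions, not taken from this diff (a -1 entry would mark a tensor with no sequence axis, as the scheduler code below does):

    long[] seqDimOrder = {1, 1, 2, 2}; // hypothetical per-tensor positions of the sequence axis
    NDList flattened = new NDList(pastOutputIds, pastAttentionMask).addAll(pastKeyValues);
    BatchTensorList state = batchTensorList.fromList(flattened, seqDimOrder);
    // By the contract noted above, pastOutputIds stays first when flattening back out:
    NDArray first = state.getList().get(0);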
@@ -38,7 +38,7 @@ public SeqBatcher initForward(NDArray inputIds, NDArray batchUids, SearchConfig
seqDimOrder[3] = -1; // -1 means no sequence dimension
Arrays.fill(seqDimOrder, 4, seqDimOrder.length, 2);

BatchList batchList =
BatchTensorList batchList =
new ContrastiveBatchList(
inputIds,
attentionMask,
@@ -202,7 +202,7 @@ public NDArray inferenceCall() {
}
}

class ContrastiveBatchList extends BatchList {
class ContrastiveBatchList extends BatchTensorList {
// [batch, seq_past, hiddenDim]
// The embed vector of the past seq. seq-dim-size = |past_seq|. Will grow.
public NDArray pastHiddenStates;
18 changes: 15 additions & 3 deletions api/src/main/java/ai/djl/translate/LMBlock.java
@@ -12,8 +12,10 @@
*/
package ai.djl.translate;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.training.ParameterStore;
@@ -50,14 +52,24 @@ protected NDList forwardInternal(
CausalLMOutput output =
forward(inputs.subList(0, 3), inputs.subNDList(3), inputs.getManager());
return new NDList(output.logits)
.addAll(output.allHiddenStates)
.addAll(output.allHiddenStates) // allHiddenStates could be null
.addAll(output.pastKeyValuesList);
}

/** {@inheritDoc} */
@Override
public Shape[] getOutputShapes(Shape[] inputShapes) {
// TODO: define the outputShapes of LMBlock
return null;
try (NDManager manager = NDManager.newBaseManager()) {
NDArray inputIds = manager.ones(inputShapes[0], DataType.INT64);
NDArray positionIds =
manager.arange(0, inputIds.getShape().size(-1), 1, DataType.INT64)
.reshape(1, -1)
.repeat(0, inputIds.getShape().get(0));
NDArray attentionMask = manager.ones(positionIds.getShape(), DataType.INT64);
NDList input = new NDList(inputIds, positionIds, attentionMask);

NDList result = forwardInternal(new ParameterStore(manager, false), input, false, null);
return result.stream().map(NDArray::getShape).toArray(Shape[]::new);
}
}
}
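getOutputShapes now derives the shapes by running a dummy forward pass with synthetic input ids, position ids, and an attention mask, so it adapts to whatever the traced decoder returns instead of hard-coding a layout. A hedged usage sketch; the batch size and sequence length are arbitrary, and lmBlock is assumed to come from Engine.newLMBlock as above:

    Shape[] inputShapes = {new Shape(1, 5)}; // [batch, seqLength] of the synthetic input ids
    Shape[] outputShapes = lmBlock.getOutputShapes(inputShapes);
    // outputShapes[0] is the logits shape; the rest follow forwardInternal's order:
    // the (possibly absent) hidden states, then the past key/value tensors.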
4 changes: 3 additions & 1 deletion api/src/main/java/ai/djl/translate/LMSearch.java
@@ -488,7 +488,9 @@ public NDArray forward(NDArray inputIds) {
case "contrastive":
return contrastiveSearch(inputIds);
default:
return greedySearch(inputIds);
throw new IllegalArgumentException(
"searchName not correctly specified. Please choose among: {greedy, beam,"
+ " contrastive}");
}
}

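With the silent fallback to greedy search removed, a mistyped search name now fails immediately. A short hedged sketch of the failure path; lmSearch and inputIds are assumed to exist, and the unsupported name is illustrative:

    try {
        lmSearch.forward(inputIds); // lmSearch assumed constructed with searchName "topk"
    } catch (IllegalArgumentException e) {
        // "searchName not correctly specified. Please choose among: {greedy, beam, contrastive}"
        System.err.println(e.getMessage());
    }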
7 changes: 3 additions & 4 deletions api/src/main/java/ai/djl/translate/SeqBatchScheduler.java
@@ -9,7 +9,7 @@
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// This is a scheduler, serving as an API to the consumer of the systme, allowing for three major
// This is a scheduler, serving as an API to the consumer of the system, allowing for three major
// actions: initForward, addBatch, fastForward, collectResults.
// An optimal control sequence should be solved, after considering the time consumption of each
// action, the batch size and sequence length of queueing requests. Such optimal control solver
@@ -32,7 +32,7 @@ public SeqBatchScheduler(LMBlock lmBlock) {
/**
* Initialize the iteration and SeqBatcher
*
* @return SeqBatcher. Stores the search state and operate on the BatchList.
* @return SeqBatcher. Stores the search state and operate on the BatchTensorList.
*/
public abstract SeqBatcher initForward(
NDArray inputIds, NDArray batchUids, SearchConfig config);
@@ -42,10 +42,9 @@ public void incrementForward(int count) throws NullPointerException {
int i = 0;
while (i++ < count) {
if (seqBatcher == null || seqBatcher.getData() == null) {
System.out.println(
throw new IllegalArgumentException(
"seqBatcher not set. Please call addBatch. Current inference order is "
+ i);
break;
}

NDArray intermediate = inferenceCall();
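incrementForward now fails fast when no batch has been registered, which makes the intended call order explicit. A hedged driver sketch of that order, based on the class comment above; initForward and incrementForward appear in this diff, while the ContrastiveSeqBatchScheduler constructor, the default SearchConfig constructor, addBatch, and collectResults are assumptions and may differ in signature:

    SeqBatchScheduler scheduler = new ContrastiveSeqBatchScheduler(lmBlock); // constructor assumed
    SeqBatcher seqBatcher = scheduler.initForward(inputIds, batchUids, new SearchConfig());
    // scheduler.addBatch(...) would register the batch here (API named in the comment, signature assumed)
    scheduler.incrementForward(10); // ten autoregressive steps; throws if nothing was added
    // scheduler.collectResults() would then gather the finished sequences (assumed API)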
18 changes: 12 additions & 6 deletions api/src/main/java/ai/djl/translate/SeqBatcher.java
@@ -12,8 +12,9 @@
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

// This stores the search state (BatchList), the control variables (e.g. seqLength, offSets, etc),
// and batch operations (merge, trim, exitCriteria, etc) on BatchList.
// This stores the search state (BatchTensorList), the control variables (e.g. seqLength, offSets,
// etc),
// and batch operations (merge, trim, exitCriteria, etc) on BatchTensorList.
public class SeqBatcher {
NDManager manager;

@@ -29,14 +30,14 @@ public class SeqBatcher {
NDArray offSets;

// This is a struct that contains NDArrays with batch dimension
BatchList data;
BatchTensorList data;

// batchIndex -> seqEndPosition
private Map<Long, Long> exitIndexEndPosition;

static long padTokenId = 220;

SeqBatcher(BatchList data, NDArray batchUid, NDArray offSets, NDManager manager) {
SeqBatcher(BatchTensorList data, NDArray batchUid, NDArray offSets, NDManager manager) {
this.manager = manager.newSubManager();
this.data = data;
this.batchUid = batchUid.getShape().dimension() == 2 ? batchUid : batchUid.reshape(-1, 1);
@@ -46,7 +47,7 @@ public class SeqBatcher {
exitIndexEndPosition = new ConcurrentHashMap<>();
}

public BatchList getData() {
public BatchTensorList getData() {
return data;
}

@@ -62,7 +63,12 @@ public void addBatch(SeqBatcher seqBatcherNew) {
}

/** Merge two batchers together. Modify the batch dimension and the left padding. */
private void merge(SeqBatcher seqBatcher1, SeqBatcher seqBatcher2, long seqDelta) {
private void merge(SeqBatcher seqBatcher1, SeqBatcher seqBatcher2, long seqDelta)
throws IllegalArgumentException {
if (seqBatcher1.seqLength < seqBatcher2.seqLength) {
throw new IllegalArgumentException(
"seqBatcher1.seqLength should >= seqBatcher2.seqLength.");
}
try (NDScope scope = new NDScope()) {
scope.suppressNotUsedWarning();

@@ -143,7 +143,7 @@ public NDManager newBaseManager(Device device) {

/** {@inheritDoc} */
@Override
public LMBlock newLMAdapter(String languageModel, GPTConfig gptConfig, Block[] blocks)
public LMBlock newLMBlock(String languageModel, GPTConfig gptConfig, Block[] blocks)
throws ModelNotFoundException, MalformedModelException, IOException {
if ("GPT2".equals(languageModel)) {
return new GPT2OrtLMBlock(gptConfig, blocks);
@@ -150,7 +150,7 @@ public NDManager newBaseManager(Device device) {

/** {@inheritDoc} */
@Override
public LMBlock newLMAdapter(String languageModel, GPTConfig gptConfig, Block[] blocks) {
public LMBlock newLMBlock(String languageModel, GPTConfig gptConfig, Block[] blocks) {
if ("GPT2".equals(languageModel)) {
return new GPT2PtLMBlock(gptConfig, blocks);
} else {
@@ -155,21 +155,6 @@ protected NDList forwardInternal(
return IValueUtils.forward(this, inputs, training);
}

/** {@inheritDoc} */
@Override
protected IValue forwardInternal(NativeResource<Long>[] inputs) {
if (System.getProperty("ai.djl.pytorch.graph_optimizer") != null) {
boolean setOptimizer = Boolean.getBoolean("ai.djl.pytorch.graph_optimizer");
JniUtils.setGraphExecutorOptimize(setOptimizer);
}
inputDescriptions = new PairList<>();
outputDescriptions = new PairList<>();
inputDescriptions.add("nested IValue", null);
outputDescriptions.add("nested IValue", null);

return IValueUtils.forward(this, (IValue[]) inputs);
}

/** {@inheritDoc} */
@Override
public PairList<String, Shape> describeInput() {
@@ -35,11 +35,11 @@
import java.util.LinkedList;
import java.util.List;

public final class TestLMAdapter {
public final class TestLMBlock {

private static final Logger logger = LoggerFactory.getLogger(TestLMAdapter.class);
private static final Logger logger = LoggerFactory.getLogger(TestLMBlock.class);

private TestLMAdapter() {}
private TestLMBlock() {}

public static void main(String[] args)
throws ModelNotFoundException, MalformedModelException, IOException {
@@ -65,7 +65,7 @@ public static List<Object> getLMBlock(String[] modelUrls, String engine, String
blocks[i] = model.getBlock();
models.add(model);
}
result.add(Engine.getEngine(engine).newLMAdapter(modelName, new GPTConfig(), blocks));
result.add(Engine.getEngine(engine).newLMBlock(modelName, new GPTConfig(), blocks));
result.add(models);
return result;
}
@@ -77,7 +77,7 @@ public static void mainOnnx()
"/Users/fenkexin/Desktop/tasks/HuggingFaceQa_relavant/gpt2_onnx/decoder_model_merged.onnx"
};

List<Object> result = TestLMAdapter.getLMBlock(modelUrls, "OnnxRuntime", "GPT2");
List<Object> result = TestLMBlock.getLMBlock(modelUrls, "OnnxRuntime", "GPT2");
LMBlock generator = (LMBlock) result.remove(0);
Object models = result.remove(0);

@@ -146,7 +146,7 @@ public static void mainPt()
"/Users/fenkexin/Desktop/tasks/HuggingFaceQa_relavant/transformer/traced_GPT2_hidden.pt"
};

List<Object> result = TestLMAdapter.getLMBlock(modelUrls, "PyTorch", "GPT2");
List<Object> result = TestLMBlock.getLMBlock(modelUrls, "PyTorch", "GPT2");
LMBlock generator = (LMBlock) result.remove(0);
Object models = result.remove(0);

@@ -36,9 +36,9 @@ private TestLMSearch() {}

public static void main(String[] args)
throws ModelNotFoundException, MalformedModelException, IOException {
// mainContrastivePt(args);
// mainGreedyPt(args);
// mainBeamPt(args);
mainContrastivePt(args);
mainGreedyPt(args);
mainBeamPt(args);
mainBeamOnnx(args);
logger.info(
"Notic: with OnnxRuntime model, it doesn't take positionId yet (only attentionMask"
@@ -51,7 +51,7 @@ public static void mainContrastivePt(String[] args)
"/Users/fenkexin/Desktop/tasks/HuggingFaceQa_relavant/transformer/traced_GPT2_init_hidden.pt",
"/Users/fenkexin/Desktop/tasks/HuggingFaceQa_relavant/transformer/traced_GPT2_hidden.pt"
};
List<Object> result = TestLMAdapter.getLMBlock(modelUrls, "PyTorch", "GPT2");
List<Object> result = TestLMBlock.getLMBlock(modelUrls, "PyTorch", "GPT2");
LMBlock lmBlock = (LMBlock) result.remove(0);
Object models = result.remove(0);

@@ -89,7 +89,7 @@ public static void mainGreedyPt(String[] args)
"/Users/fenkexin/Desktop/tasks/HuggingFaceQa_relavant/transformer/traced_GPT2_hidden.pt"
};

List<Object> result = TestLMAdapter.getLMBlock(modelUrls, "PyTorch", "GPT2");
List<Object> result = TestLMBlock.getLMBlock(modelUrls, "PyTorch", "GPT2");
LMBlock lmBlock = (LMBlock) result.remove(0);
Object models = result.remove(0);

@@ -123,7 +123,7 @@ public static void mainBeamPt(String[] args)
"/Users/fenkexin/Desktop/tasks/HuggingFaceQa_relavant/transformer/traced_GPT2_hidden.pt"
};

List<Object> result = TestLMAdapter.getLMBlock(modelUrls, "PyTorch", "GPT2");
List<Object> result = TestLMBlock.getLMBlock(modelUrls, "PyTorch", "GPT2");
LMBlock lmBlock = (LMBlock) result.remove(0);
Object models = result.remove(0);

@@ -157,7 +157,7 @@ public static void mainBeamOnnx(String[] args)
"/Users/fenkexin/Desktop/tasks/HuggingFaceQa_relavant/gpt2_onnx/decoder_model_merged.onnx"
};

List<Object> result = TestLMAdapter.getLMBlock(modelUrls, "OnnxRuntime", "GPT2");
List<Object> result = TestLMBlock.getLMBlock(modelUrls, "OnnxRuntime", "GPT2");
LMBlock lmBlock = (LMBlock) result.remove(0);
Object models = result.remove(0);

@@ -40,7 +40,7 @@ public static void mainContrastivePt(String[] args)
"/Users/fenkexin/Desktop/tasks/HuggingFaceQa_relavant/transformer/traced_GPT2_hidden.pt"
};

List<Object> result = TestLMAdapter.getLMBlock(modelUrls, "PyTorch", "GPT2");
List<Object> result = TestLMBlock.getLMBlock(modelUrls, "PyTorch", "GPT2");
LMBlock lmBlock = (LMBlock) result.remove(0);
Object models = result.remove(0);
