Skip to content

Commit

Permalink
Refactor TextGenerator API
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu committed Jun 16, 2023
1 parent d23bec9 commit a9d516c
Show file tree
Hide file tree
Showing 19 changed files with 330 additions and 1,161 deletions.
7 changes: 0 additions & 7 deletions api/src/main/java/ai/djl/engine/Engine.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,7 @@

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.modality.nlp.generate.GPTConfig;
import ai.djl.modality.nlp.generate.LMBlock;
import ai.djl.ndarray.NDManager;
import ai.djl.nn.Block;
import ai.djl.nn.SymbolBlock;
import ai.djl.training.GradientCollector;
import ai.djl.training.LocalParameterServer;
Expand Down Expand Up @@ -308,10 +305,6 @@ public SymbolBlock newSymbolBlock(NDManager manager) {
*/
public abstract NDManager newBaseManager(Device device);

/**
* Returns a new {@link LMBlock} for the requested language model.
*
* <p>This implementation does not build anything: it unconditionally throws {@link
* UnsupportedOperationException}. NOTE(review): presumably engines that support language-model
* blocks override this method; confirm against the engine implementations.
*
* @param languageModel identifier of the language model (accepted values are not visible here;
*     confirm with callers)
* @param gptConfig the GPT configuration for the block
* @param blocks the blocks to build the language-model block from (presumably the underlying
*     model blocks; TODO confirm)
* @return never returns normally in this implementation
* @throws UnsupportedOperationException always, with the message "Not supported."
*/
public LMBlock newLMBlock(String languageModel, GPTConfig gptConfig, Block[] blocks) {
throw new UnsupportedOperationException("Not supported.");
}

/**
* Returns a new instance of {@link GradientCollector}.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,28 @@
*/
package ai.djl.modality.nlp.generate;

import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDScope;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.TranslateException;

import java.util.Arrays;
import java.util.function.Function;
import java.util.stream.Collectors;

public class ContrastiveSeqBatchScheduler extends SeqBatchScheduler {

public ContrastiveSeqBatchScheduler(LMBlock lmBlock, SearchConfig config) {
public ContrastiveSeqBatchScheduler(
Predictor<NDList, CausalLMOutput> lmBlock, SearchConfig config) {
super(lmBlock, config);
}

@Override
public SeqBatcher initForward(NDArray inputIds, NDArray batchUids) {
public SeqBatcher initForward(NDArray inputIds, NDArray batchUids) throws TranslateException {
try (NDScope scope = new NDScope()) {
scope.suppressNotUsedWarning();
manager = inputIds.getManager();
Expand All @@ -39,8 +42,7 @@ public SeqBatcher initForward(NDArray inputIds, NDArray batchUids) {
NDArray positionIds = computePositionIds(inputIds, initOffSets, 0, 1);

CausalLMOutput output =
lmBlock.forward(
new NDList(inputIds, positionIds, attentionMask), null, manager);
predictor.predict(new NDList(inputIds, positionIds, attentionMask));
NDArray lastLogits = output.getLogits().get(":, -1, :");

// Used to mark the sequence dimension's ordinal number for each tensor in the
Expand Down Expand Up @@ -71,7 +73,7 @@ public SeqBatcher initForward(NDArray inputIds, NDArray batchUids) {
}

@Override
public NDArray inferenceCall() {
public NDArray inferenceCall() throws TranslateException {
NDArray outputIds;
try (NDScope scope = new NDScope()) {
scope.suppressNotUsedWarning();
Expand Down Expand Up @@ -116,14 +118,10 @@ public NDArray inferenceCall() {
seqBatcher.offSets,
searchState.getPastOutputIds().getShape().getLastDimension(),
config.getK());
CausalLMOutput candidateOutput =
lmBlock.forward(
new NDList(
candidateInputIds,
candidatePositionIds,
kCopyPastAttentionMask),
kCopyPastKeyValues,
manager);
NDList modelInputs =
new NDList(candidateInputIds, candidatePositionIds, kCopyPastAttentionMask);
modelInputs.addAll(kCopyPastKeyValues);
CausalLMOutput candidateOutput = predictor.predict(modelInputs);

NDList generatedOutput =
StepGeneration.constrastiveStepGeneration(
Expand Down
57 changes: 0 additions & 57 deletions api/src/main/java/ai/djl/modality/nlp/generate/GPTConfig.java

This file was deleted.

89 changes: 0 additions & 89 deletions api/src/main/java/ai/djl/modality/nlp/generate/LMBlock.java

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@
*/
package ai.djl.modality.nlp.generate;

import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.TranslateException;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -32,7 +35,7 @@
public abstract class SeqBatchScheduler {
private static final Logger logger = LoggerFactory.getLogger(SeqBatchScheduler.class);

LMBlock lmBlock;
Predictor<NDList, CausalLMOutput> predictor;
SeqBatcher seqBatcher;

NDManager manager;
Expand All @@ -41,8 +44,8 @@ public abstract class SeqBatchScheduler {

Map<Long, NDArray> results;

public SeqBatchScheduler(LMBlock lmBlock, SearchConfig config) {
this.lmBlock = lmBlock;
public SeqBatchScheduler(Predictor<NDList, CausalLMOutput> lmBlock, SearchConfig config) {
this.predictor = lmBlock;
this.config = config;
results = new ConcurrentHashMap<>();
}
Expand All @@ -52,14 +55,15 @@ public SeqBatchScheduler(LMBlock lmBlock, SearchConfig config) {
*
* @return SeqBatcher. Stores the search state and operates on the BatchTensorList.
*/
public abstract SeqBatcher initForward(NDArray inputIds, NDArray batchUids);
public abstract SeqBatcher initForward(NDArray inputIds, NDArray batchUids)
throws TranslateException;

/**
* Go forward for a given number of iterations.
*
* @return boolean. Indicates whether the batch is empty.
*/
public boolean incrementForward(int count) {
public boolean incrementForward(int count) throws TranslateException {
int i = 0;
while (i++ < count) {
if (seqBatcher == null || seqBatcher.getData() == null) {
Expand All @@ -78,10 +82,10 @@ public boolean incrementForward(int count) {
return false;
}

abstract NDArray inferenceCall();
abstract NDArray inferenceCall() throws TranslateException;

/** Adds a new batch. */
public void addRequest(NDArray inputIds, NDArray batchUids) {
public void addRequest(NDArray inputIds, NDArray batchUids) throws TranslateException {
SeqBatcher seqBatcherNew = initForward(inputIds, batchUids);
if (seqBatcher == null) {
seqBatcher = seqBatcherNew;
Expand Down
Loading

0 comments on commit a9d516c

Please sign in to comment.