Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
KexinFeng committed May 19, 2023
1 parent 34a8105 commit bc9597f
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 7 deletions.
3 changes: 2 additions & 1 deletion api/src/main/java/ai/djl/translate/BatchTensorList.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;

// BatchTensorList represents a search state, and the NDArrays inside are updated in each iteration of the
// BatchTensorList represents a search state, and the NDArrays inside are updated in each iteration
// of the
// autoregressive loop.
// It is a struct consisting of NDArrays, whose first dimension is batch, and also contains
// sequence dimension (whose position in tensor's shape is specified by seqDimOrder).
Expand Down
2 changes: 1 addition & 1 deletion api/src/main/java/ai/djl/translate/LMBlock.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ protected NDList forwardInternal(
CausalLMOutput output =
forward(inputs.subList(0, 3), inputs.subNDList(3), inputs.getManager());
return new NDList(output.logits)
.addAll(output.allHiddenStates) // allHiddenStates could be null
.addAll(output.allHiddenStates) // allHiddenStates could be null
.addAll(output.pastKeyValuesList);
}

Expand Down
4 changes: 3 additions & 1 deletion api/src/main/java/ai/djl/translate/LMSearch.java
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,9 @@ public NDArray forward(NDArray inputIds) {
case "contrastive":
return contrastiveSearch(inputIds);
default:
throw new IllegalArgumentException("searchName not correctly specified. Please choose among: {greedy, beam, contrastive}");
throw new IllegalArgumentException(
"searchName not correctly specified. Please choose among: {greedy, beam,"
+ " contrastive}");
}
}

Expand Down
9 changes: 6 additions & 3 deletions api/src/main/java/ai/djl/translate/SeqBatcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

// This stores the search state (BatchTensorList), the control variables (e.g. seqLength, offSets, etc),
// This stores the search state (BatchTensorList), the control variables (e.g. seqLength, offSets,
// etc),
// and batch operations (merge, trim, exitCriteria, etc) on BatchTensorList.
public class SeqBatcher {
NDManager manager;
Expand Down Expand Up @@ -62,9 +63,11 @@ public void addBatch(SeqBatcher seqBatcherNew) {
}

/** Merge two batchers together. Modify the batch dimension and the left padding. */
private void merge(SeqBatcher seqBatcher1, SeqBatcher seqBatcher2, long seqDelta) throws IllegalArgumentException {
private void merge(SeqBatcher seqBatcher1, SeqBatcher seqBatcher2, long seqDelta)
throws IllegalArgumentException {
if (seqBatcher1.seqLength < seqBatcher2.seqLength) {
throw new IllegalArgumentException("seqBatcher1.seqLength should >= seqBatcher2.seqLength.");
throw new IllegalArgumentException(
"seqBatcher1.seqLength should >= seqBatcher2.seqLength.");
}
try (NDScope scope = new NDScope()) {
scope.suppressNotUsedWarning();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import ai.djl.MalformedModelException;
import ai.djl.repository.zoo.ModelNotFoundException;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testng.annotations.Test;
Expand All @@ -25,7 +26,8 @@ public class TextGenerationTest {
private static final Logger logger = LoggerFactory.getLogger(TextGenerationTest.class);

@Test
public void testTextGeneration() throws ModelNotFoundException, MalformedModelException, IOException {
public void testTextGeneration()
throws ModelNotFoundException, MalformedModelException, IOException {
String[] args = new String[] {};

// LMBlock
Expand Down

0 comments on commit bc9597f

Please sign in to comment.