Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Greedy search and beam search #2557

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions api/src/main/java/ai/djl/engine/Engine.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,17 @@
package ai.djl.engine;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.ndarray.NDManager;
import ai.djl.nn.SymbolBlock;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.training.GradientCollector;
import ai.djl.training.LocalParameterServer;
import ai.djl.training.ParameterServer;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.translate.GPTConfig;
import ai.djl.translate.LMAdapter;
import ai.djl.util.Ec2Utils;
import ai.djl.util.RandomUtils;
import ai.djl.util.Utils;
Expand Down Expand Up @@ -302,6 +306,11 @@ public SymbolBlock newSymbolBlock(NDManager manager) {
*/
public abstract NDManager newBaseManager(Device device);

public LMAdapter newLMAdapter(String languageModel, GPTConfig gptConfig)
throws ModelNotFoundException, MalformedModelException, IOException {
throw new UnsupportedOperationException("Not supported.");
}

/**
* Returns a new instance of {@link GradientCollector}.
*
Expand Down
7 changes: 5 additions & 2 deletions api/src/main/java/ai/djl/ndarray/BaseNDManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -303,11 +303,14 @@ public String toString() {

/** {@inheritDoc} */
@Override
public synchronized void attachInternal(String resourceId, AutoCloseable resource) {
public synchronized void attachInternal(String resourceId, AutoCloseable... resources) {
if (capped.get()) {
throw new IllegalStateException("NDManager is capped for addition of resources.");
}
attachUncappedInternal(resourceId, resource);
for (int i = 0; i < resources.length; i++) {
attachUncappedInternal(
resources.length == 1 ? resourceId : resourceId + "_" + i, resources[i]);
}
}

/** {@inheritDoc} */
Expand Down
9 changes: 8 additions & 1 deletion api/src/main/java/ai/djl/ndarray/NDArray.java
Original file line number Diff line number Diff line change
Expand Up @@ -4137,7 +4137,8 @@ default NDArray argSort(int axis) {
* jshell&gt; NDArray array = manager.create(new float[] {0f, 1f, 2f, 3f}, new Shape(2, 2));
* jshell&gt; array.repeat(1, 2);
* ND: (6) cpu() float32
* [0., 0., 1., 1., 2., 2.]
* [[0., 0., 1., 1.],
* [2., 2., 3., 3.]]
* </pre>
*
* @param axis the axis to repeat
Expand Down Expand Up @@ -4544,6 +4545,12 @@ default NDArray broadcast(long... shape) {
*/
NDArray argMax(int axis);

default NDList topK(int k, int axis) {
return topK(k, axis, true, true);
}

NDList topK(int k, int axis, boolean largest, boolean sorted);

/**
* Returns the indices of the minimum values into the flattened {@code NDArray}.
*
Expand Down
6 changes: 6 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDArrayAdapter.java
Original file line number Diff line number Diff line change
Expand Up @@ -1078,6 +1078,12 @@ public NDArray argMax(int axis) {
return getAlternativeArray().argMax(axis);
}

/** {@inheritDoc} */
@Override
public NDList topK(int k, int axis, boolean largest, boolean sorted) {
throw new UnsupportedOperationException("Not implemented yet.");
}

/** {@inheritDoc} */
@Override
public NDArray argMin() {
Expand Down
21 changes: 20 additions & 1 deletion api/src/main/java/ai/djl/ndarray/NDList.java
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,26 @@ public NDList addAll(NDList other) {
* @return a view of the portion of this NDList
*/
public NDList subNDList(int fromIndex) {
return new NDList(subList(fromIndex, size()));
if (fromIndex > size()) {
return null;
}
return subList(fromIndex, size());
}

/** {@inheritDoc} */
@Override
public NDArray get(int index) {
index = index + (index < 0 ? size() : 0);
return super.get(index);
}

/** {@inheritDoc} */
@Override
public NDList subList(int fromIndex, int toIndex) {
List<NDArray> subList = super.subList(fromIndex, toIndex);
NDList output = new NDList();
output.addAll(subList);
return output;
}

/**
Expand Down
2 changes: 1 addition & 1 deletion api/src/main/java/ai/djl/ndarray/NDManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -1589,7 +1589,7 @@ default NDArray hanningWindow(long numPoints) {
* @param resourceId the unique resourceId
* @param resource the {@link AutoCloseable} resource to be attached
*/
void attachInternal(String resourceId, AutoCloseable resource);
void attachInternal(String resourceId, AutoCloseable... resource);

/**
* Attaches a resource to this {@code NDManager} circumventing any cap protection.
Expand Down
5 changes: 2 additions & 3 deletions api/src/main/java/ai/djl/ndarray/types/Shape.java
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ public long[] getShape() {
* @return the shape in the given dimension
*/
public long get(int dimension) {
dimension = dimension + (dimension < 0 ? shape.length : 0);
return shape[dimension];
}

Expand All @@ -158,9 +159,7 @@ public LayoutType getLayoutType(int dimension) {
public long size(int... dimensions) {
long total = 1;
for (long d : dimensions) {
if (d < 0 || d >= shape.length) {
throw new IllegalArgumentException("Invalid dimension " + d);
}
d = d + (d < 0 ? shape.length : 0);
if (shape[Math.toIntExact(d)] == -1) {
return -1;
}
Expand Down
13 changes: 11 additions & 2 deletions api/src/main/java/ai/djl/nn/AbstractBaseBlock.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import ai.djl.ndarray.types.Shape;
import ai.djl.training.ParameterStore;
import ai.djl.training.initializer.Initializer;
import ai.djl.util.NativeResource;
import ai.djl.util.Pair;
import ai.djl.util.PairList;

Expand Down Expand Up @@ -72,13 +73,17 @@ public final NDList forward(
NDList inputs,
boolean training,
PairList<String, Object> params) {
NDManager paramsManager = parameterStore.getManager();
if (training && !isInitialized()) {
initialize(paramsManager, DataType.FLOAT32, inputs.getShapes());
initialize(parameterStore.getManager(), DataType.FLOAT32, inputs.getShapes());
}
return forwardInternal(parameterStore, inputs, training, params);
}

@Override
public final NativeResource<Long> forward(NativeResource<Long>[] inputs) {
return forwardInternal(inputs);
}

/** {@inheritDoc} */
@Override
public NDList forward(
Expand Down Expand Up @@ -109,6 +114,10 @@ protected abstract NDList forwardInternal(
boolean training,
PairList<String, Object> params);

protected NativeResource<Long> forwardInternal(NativeResource<Long>[] inputs) {
return null;
}

/**
* A helper for {@link Block#forward(ParameterStore, NDList, NDList, PairList)} after
* initialization.
Expand Down
3 changes: 3 additions & 0 deletions api/src/main/java/ai/djl/nn/Block.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import ai.djl.nn.convolutional.Conv2d;
import ai.djl.training.ParameterStore;
import ai.djl.training.initializer.Initializer;
import ai.djl.util.NativeResource;
import ai.djl.util.PairList;

import java.io.DataInputStream;
Expand Down Expand Up @@ -127,6 +128,8 @@ default NDList forward(ParameterStore parameterStore, NDList inputs, boolean tra
return forward(parameterStore, inputs, training, null);
}

NativeResource<Long> forward(NativeResource<Long>[] inputs);

/**
* Applies the operating function of the block once. This method should be called only on blocks
* that are initialized.
Expand Down
3 changes: 2 additions & 1 deletion api/src/main/java/ai/djl/repository/zoo/Criteria.java
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ public ZooModel<I, O> loadModel()
}
}
throw new ModelNotFoundException(
"No matching model with specified Input/Output type found.", lastException);
"No model with the specified URI or the matching Input/Output type is found.",
lastException);
}

/**
Expand Down
33 changes: 33 additions & 0 deletions api/src/main/java/ai/djl/translate/CausalLMOutput.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package ai.djl.translate;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;

/** CausalLMOuput is used to contain multiple output of a language model. */
public class CausalLMOutput {

// [batch, seq, feature]
// The prob. conditional on a sequence that ends at an element in seq-dim. seq-dim-size =
// |inputIds|
public NDArray logits;

// [batch, seq, dim] * (layers+1) -> take -1
// The vec. rep. of a sequence that ends at an element in seq-dim. seq-dim-size = |inputIds|
public NDList allHiddenStates;

// (k, v) * numLayer,
// kv: [batch, heads, seq_past, feature]
// The cache of past sequence. seq-dim-size == |seq_past| + |inputIds|
public NDList pastKeyValuesList;

public CausalLMOutput(NDArray logits, NDList pastKeyValues) {
this.logits = logits;
this.pastKeyValuesList = pastKeyValues;
}

public CausalLMOutput(NDArray logits, NDList... optionalOutput) {
this.logits = logits;
this.pastKeyValuesList = optionalOutput[0];
this.allHiddenStates = optionalOutput[1];
}
}
20 changes: 20 additions & 0 deletions api/src/main/java/ai/djl/translate/GPTConfig.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package ai.djl.translate;

/** GPTConfig is used to store the GPT parameters used to select different versions of GPT. */
public class GPTConfig {
public String[] modelUrls;
public int numAttentionHeads;
public int numLayers;
public long hiddenStateDim;
public long logitsDim;
public long kvDim;

public GPTConfig(String[] modelUrls) {
this.modelUrls = modelUrls;
numAttentionHeads = 12;
numLayers = 12;
hiddenStateDim = 768;
logitsDim = 50257;
kvDim = 64;
}
}
43 changes: 43 additions & 0 deletions api/src/main/java/ai/djl/translate/LMAdapter.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.translate;

import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;

/**
* This is a wrapper over the model files from different sources, e.g. gpt2.pt, gpt2.onnx, etc. This
* interface is an abstraction of the causal language model, which in essence is a conditional
* probability function: p_\theta(v_t | x_{< t})}, v_t \in V, i.e. given the past tokens up to a
* certain time x_{< t}, the probability that the next token is v, taken from a vocabulary set V.
* \theta is the model's weight. This function can take an input sequence `inputIds`, whose length
* can be greater than one. In this case, the output is still p_\theta(v_i | x_{< i})}, i in
* range(|inputIds|). This means for each i, the output probability is conditional on the past
* sequence up to i.
*/
public interface LMAdapter {

/**
* @param input input
* @param pastKeyValues past_key_values
* @param manager manager
* @return CausalLMOutput
*/
default CausalLMOutput forward(NDList input, NDList pastKeyValues, NDManager manager) {
return null;
}

// /** {@inheritDoc} */
// @Override
// void close();
}
Loading