[api] Enhancement features for LMSearch #2642

Merged · 3 commits · Jun 9, 2023
7 changes: 5 additions & 2 deletions api/src/main/java/ai/djl/ndarray/BaseNDManager.java
@@ -303,11 +303,14 @@ public String toString() {

/** {@inheritDoc} */
@Override
-    public synchronized void attachInternal(String resourceId, AutoCloseable resource) {
+    public synchronized void attachInternal(String resourceId, AutoCloseable... resources) {
         if (capped.get()) {
             throw new IllegalStateException("NDManager is capped for addition of resources.");
         }
-        attachUncappedInternal(resourceId, resource);
+        for (int i = 0; i < resources.length; i++) {
+            attachUncappedInternal(
+                    resources.length == 1 ? resourceId : resourceId + "_" + i, resources[i]);
+        }
}

/** {@inheritDoc} */
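As a review aid, a jshell-style sketch of the new varargs behavior. This is illustration only, not production usage: attachInternal is an internal SPI normally invoked by the manager itself, and the resource id "temp" and the arrays are made up.

    import ai.djl.ndarray.NDArray;
    import ai.djl.ndarray.NDManager;

    try (NDManager manager = NDManager.newBaseManager()) {
        NDArray a = manager.create(new float[] {1f, 2f});
        NDArray b = manager.create(new float[] {3f, 4f});
        // One call attaches both; they are stored under "temp_0" and "temp_1".
        // A single-resource call keeps the plain id "temp".
        manager.attachInternal("temp", a, b);
    }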
3 changes: 2 additions & 1 deletion api/src/main/java/ai/djl/ndarray/NDArray.java
@@ -4206,7 +4206,8 @@ default NDArray argSort(int axis) {
 * jshell&gt; NDArray array = manager.create(new float[] {0f, 1f, 2f, 3f}, new Shape(2, 2));
 * jshell&gt; array.repeat(1, 2);
 * ND: (2, 4) cpu() float32
-* [0., 0., 1., 1., 2., 2.]
+* [[0., 0., 1., 1.],
+*  [2., 2., 3., 3.]]
 * </pre>
*
* @param axis the axis to repeat
14 changes: 13 additions & 1 deletion api/src/main/java/ai/djl/ndarray/NDList.java
@@ -270,7 +270,19 @@ public NDList addAll(NDList other) {
* @return a view of the portion of this NDList
*/
public NDList subNDList(int fromIndex) {
-        return new NDList(subList(fromIndex, size()));
+        return subNDList(fromIndex, size());
}

+    /**
+     * Returns a view of the portion of this NDList between the specified fromIndex, inclusive, and
+     * toIndex, exclusive.
+     *
+     * @param fromIndex the start index (inclusive)
+     * @param toIndex the end index (exclusive)
+     * @return a view of the portion of this NDList
+     */
+    public NDList subNDList(int fromIndex, int toIndex) {
+        return new NDList(subList(fromIndex, toIndex));
+    }

/**
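A short sketch of both overloads, assuming an open NDManager named manager; the contents are arbitrary:

    NDList list = new NDList(
            manager.create(0f), manager.create(1f), manager.create(2f), manager.create(3f));
    NDList middle = list.subNDList(1, 3); // view of the elements at indices 1 and 2
    NDList tail = list.subNDList(2);      // unchanged behavior: indices 2 through the end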
2 changes: 1 addition & 1 deletion api/src/main/java/ai/djl/ndarray/NDManager.java
@@ -1589,7 +1589,7 @@ default NDArray hanningWindow(long numPoints) {
* @param resourceId the unique resourceId
* @param resource the {@link AutoCloseable} resource to be attached
*/
-    void attachInternal(String resourceId, AutoCloseable resource);
+    void attachInternal(String resourceId, AutoCloseable... resource);

/**
* Attaches a resource to this {@code NDManager} circumventing any cap protection.
20 changes: 20 additions & 0 deletions api/src/main/java/ai/djl/ndarray/NDScope.java
@@ -61,6 +61,26 @@ public static void unregister(NDArray array) {
queue.getLast().resources.remove(array);
}

+    /**
+     * Unregisters the given {@link NDArray} objects from this scope.
+     *
+     * @param arrays the {@link NDArray} objects to unregister
+     */
+    public static void unregister(NDArray... arrays) {
+        for (NDArray array : arrays) {
+            unregister(array);
+        }
+    }
+
+    /**
+     * Unregisters all {@link NDArray}s in the given {@link NDList} from this scope.
+     *
+     * @param ndlist the {@link NDList} whose arrays are unregistered
+     */
+    public static void unregister(NDList ndlist) {
+        ndlist.forEach(NDScope::unregister);
+    }

/** {@inheritDoc} */
@Override
public void close() {
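A sketch of how the new overloads slot into the usual NDScope pattern: arrays created inside the scope are tracked and closed with it, so survivors must be unregistered. Here manager is an assumed open NDManager.

    NDList output;
    try (NDScope scope = new NDScope()) {
        NDArray x = manager.create(new float[] {1f, 2f, 3f});
        NDArray y = x.mul(2f); // x and y are both tracked by the scope
        output = new NDList(y);
        // Previously this needed one unregister(NDArray) call per survivor;
        // now a whole NDList (or varargs of arrays) goes in one call.
        NDScope.unregister(output);
    } // x (and anything else still registered) is closed here; y survives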
10 changes: 10 additions & 0 deletions api/src/main/java/ai/djl/ndarray/index/NDIndex.java
@@ -258,6 +258,16 @@ public NDIndex addBooleanIndex(NDArray index) {
return this;
}

+    /**
+     * Appends an ellipsis ("...") index at the current position.
+     *
+     * @return the updated {@link NDIndex}
+     */
+    public NDIndex addEllipseDim() {
+        ellipsisIndex = indices.size();
+        return this;
+    }

/**
* Appends a new index to get all values in the dimension.
*
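A sketch of the new builder call, the equivalent of NumPy's x[..., 0]; the 3-D array is hypothetical and manager is an assumed open NDManager:

    NDArray x = manager.ones(new Shape(2, 3, 4));
    // Keep all leading axes, pick index 0 on the last axis: x[..., 0]
    NDArray first = x.get(new NDIndex().addEllipseDim().addIndices(0));
    // first.getShape() is (2, 3)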
11 changes: 11 additions & 0 deletions api/src/main/java/ai/djl/ndarray/types/Shape.java
@@ -138,6 +138,17 @@ public long get(int dimension) {
return shape[dimension];
}

+    /**
+     * Returns the shape in the given dimension, allowing negative indices to wrap around.
+     *
+     * @param dimension the dimension to get the shape in; a negative value counts from the end
+     * @return the shape in the given dimension
+     */
+    public long getWrap(int dimension) {
+        dimension = dimension + (dimension < 0 ? shape.length : 0);
+        return shape[dimension];
+    }

/**
* Returns the layout type in the given dimension.
*
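The wrapping semantics, spelled out; the values follow directly from the implementation above:

    Shape shape = new Shape(2, 3, 4);
    shape.getWrap(0);  // 2, identical to get(0) for non-negative indices
    shape.getWrap(-1); // 4, the last dimension
    shape.getWrap(-3); // 2, wraps back to the first dimension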
2 changes: 1 addition & 1 deletion api/src/main/java/ai/djl/nn/AbstractBaseBlock.java
@@ -72,8 +72,8 @@ public final NDList forward(
             NDList inputs,
             boolean training,
             PairList<String, Object> params) {
-        NDManager paramsManager = parameterStore.getManager();
         if (training && !isInitialized()) {
+            NDManager paramsManager = parameterStore.getManager();
             initialize(paramsManager, DataType.FLOAT32, inputs.getShapes());
         }
         return forwardInternal(parameterStore, inputs, training, params);
3 changes: 2 additions & 1 deletion api/src/main/java/ai/djl/repository/zoo/Criteria.java
@@ -182,7 +182,8 @@ public ZooModel<I, O> loadModel()
}
}
         throw new ModelNotFoundException(
-                "No matching model with specified Input/Output type found.", lastException);
+                "No model with the specified URI or the matching Input/Output type is found.",
+                lastException);
}

/**
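A sketch of a call site that can surface the reworded message; the model URL is hypothetical:

    Criteria<NDList, NDList> criteria =
            Criteria.builder()
                    .setTypes(NDList.class, NDList.class)
                    .optModelUrls("file:///path/to/model") // hypothetical location
                    .build();
    try (ZooModel<NDList, NDList> model = criteria.loadModel()) {
        // use the model
    } catch (ModelException | IOException e) {
        // ModelNotFoundException now reads:
        // "No model with the specified URI or the matching Input/Output type is found."
    }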
@@ -293,7 +293,7 @@ public List<NDArray> getManagedArrays() {

/** {@inheritDoc} */
@Override
-    public void attachInternal(String resourceId, AutoCloseable resource) {}
+    public void attachInternal(String resourceId, AutoCloseable... resource) {}

/** {@inheritDoc} */
@Override
@@ -91,7 +91,7 @@ public void testPassthrough() {
Assert.assertEquals(manager.getName(), "PassthroughNDManager");
Assert.assertTrue(manager.isOpen());
Assert.assertNotNull(manager.getParentManager());
-        manager.attachInternal(null, null);
+        manager.attachInternal(null, (AutoCloseable) null);
manager.attachUncappedInternal(null, null);
manager.tempAttachInternal(null, null, null);
manager.detachInternal(null);
@@ -66,7 +66,7 @@ public static NDList forward(PtSymbolBlock block, NDList inputs, boolean isTrain
* @param inputs the input {@link IValue}
* @return the result {@link IValue}
*/
-    public static IValue forward(PtSymbolBlock block, IValue... inputs) {
+    public static IValue forward(PtSymbolBlock block, IValue[] inputs) {
return runMethod(block, "forward", inputs);
}

@@ -79,9 +79,10 @@ public static IValue forward(PtSymbolBlock block, IValue... inputs) {
* @return the result {@link IValue}
*/
public static IValue runMethod(PtSymbolBlock block, String methodName, IValue... inputs) {
-        long[] handles = Arrays.stream(inputs).mapToLong(IValue::getHandle).toArray();
+        long[] iValueHandles = Arrays.stream(inputs).mapToLong(IValue::getHandle).toArray();
         return new IValue(
-                PyTorchLibrary.LIB.moduleRunMethod(block.getHandle(), methodName, handles, false));
+                PyTorchLibrary.LIB.moduleRunMethod(
+                        block.getHandle(), methodName, iValueHandles, false));
}

private static int addToMap(
@@ -146,4 +147,48 @@ static Pair<IValue[], String> getInputs(NDList ndList) {
}
return new Pair<>(ret, methodName);
}

+    /**
+     * Converts an {@link NDList} to a nested tuple {@link IValue}.
+     *
+     * @param ndList the NDList to convert
+     * @param dims the nesting shape of the output tuple
+     * @return the result {@link IValue}
+     */
+    public static IValue toTupleIValue(NDList ndList, long[] dims) {
+        return toTupleIValueRecur(ndList, dims, 0, 0).getKey();
+    }
+
+    /**
+     * Recursively builds one level of the nested tuple.
+     *
+     * @param ndList the NDList to convert
+     * @param dims the nesting shape of the output tuple
+     * @param startCount the index in ndList where this recursion level starts
+     * @param level the recursion level
+     * @return the tuple for this level and the index where the next sibling starts
+     */
+    private static Pair<IValue, Integer> toTupleIValueRecur(
+            NDList ndList, long[] dims, int startCount, int level) {
+        if (startCount > ndList.size()) {
+            throw new IllegalArgumentException(
+                    "startCount " + startCount + " exceeds ndList size " + ndList.size());
+        }
+        if (dims.length - 1 == level) {
+            // Leaf level: wrap the next dims[level] arrays into a flat tuple.
+            long dim = dims[level];
+            List<PtNDArray> vector = new ArrayList<>();
+            for (int i = startCount; i < startCount + dim; i++) {
+                vector.add((PtNDArray) ndList.get(i));
+            }
+            IValue[] output = vector.stream().map(IValue::from).toArray(IValue[]::new);
+            return new Pair<>(IValue.tupleFrom(output), Math.toIntExact(startCount + dim));
+        }
+
+        // Inner level: build dims[level] sub-tuples, advancing startCount as we go.
+        IValue[] output = new IValue[Math.toIntExact(dims[level])];
+        for (int j = 0; j < dims[level]; j++) {
+            Pair<IValue, Integer> p = toTupleIValueRecur(ndList, dims, startCount, level + 1);
+            startCount = p.getValue();
+            output[j] = p.getKey();
+        }
+        return new Pair<>(IValue.tupleFrom(output), startCount);
+    }
 }
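Finally, a sketch of driving the new tuple converter. The arrays and nesting shape are made up, and a PyTorch-engine manager is assumed so the list holds PtNDArrays:

    try (NDManager manager = NDManager.newBaseManager(Device.cpu(), "PyTorch")) {
        NDList flat = new NDList();
        for (int i = 0; i < 6; i++) {
            flat.add(manager.create((float) i));
        }
        // Nesting shape {2, 3} yields ((a0, a1, a2), (a3, a4, a5))
        IValue nested = IValueUtils.toTupleIValue(flat, new long[] {2, 3});
        nested.close(); // IValues hold native handles and must be closed
    }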
}