Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Creates MultiDevice #2819

Merged
merged 1 commit into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 102 additions & 3 deletions api/src/main/java/ai/djl/Device.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,16 @@

import ai.djl.engine.Engine;

import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
* The {@code Device} class provides the specified assignment for CPU/GPU processing on the {@code
Expand All @@ -30,7 +35,7 @@
* @see <a href="https://d2l.djl.ai/chapter_deep-learning-computation/use-gpu.html">The D2L chapter
* on GPU devices</a>
*/
public final class Device {
public class Device {

private static final Map<String, Device> CACHE = new ConcurrentHashMap<>();

Expand All @@ -39,8 +44,8 @@ public final class Device {

private static final Pattern DEVICE_NAME = Pattern.compile("([a-z]+)([0-9]*)");

private String deviceType;
private int deviceId;
protected String deviceType;
protected int deviceId;

/**
* Creates a {@code Device} with basic information.
Expand Down Expand Up @@ -101,6 +106,13 @@ public static Device fromName(String deviceName, Engine engine) {
return engine.defaultDevice();
}

if (deviceName.contains("+")) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need think of the following use cases:

  1. A specific device id (existing Device implementation)
  2. A continuous range of device: GPU[1-3]
  3. Arbitrary device list: GPU1;GPU3
  4. Number of device at any free device id exclusively: GPU{2}
  5. All available devices exclusively: GPU+
  6. All devices sharable: GPU*

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We actually have two device naming systems. One is the base system used in DJL Device.fromName(). The other is the system used in Serving getLoadOnDevices(). For example, * exists in Serving but not in DJL. The main idea seems to be that all the ones in DJL are absolute descriptions of a device and the ones in serving also contain relative ones. In that case and with your list: DJL would contain 1, 2, 3, 5 and Serving would contain 4, 6.

First, I want to talk about the structure of Device. Here, I changed it to represent anything "Device-like", either real, virtual, a combination of devices, or parts of devices. Device is now open for interpretation. I think this works very well with respect to how it opens possibilities throughout all of the API, even if many would not be supported for now. It helps a lot with multi-device usage, tensor parallel, device sharing, and distributed training. I would support having a clearer recognition of physical devices, though. Would it help to either add a function device.isPhysicalDevice() or a class PhysicalDevice extends Device?

Also for your list, you need to deal with both levels of lists of device considering tensor parallel. That is, you need something equivalent to "gpu0+gpu1;gpu2+gpu3". Which is, two workers of TP 2. I could also see {gpu0;gpu1};{gpu2;gpu3}. We also don't want to use , because it is used elsewhere. Then, would we want to have ranges like gpu[0-3/2] which would allow for TP? Also, with the current system we could still use a + without anything else even with the current system similarly to how we are using *. Both of these infer the device.

String[] split = deviceName.split("\\+");
List<Device> subDevices =
Arrays.stream(split).map(n -> fromName(n, engine)).collect(Collectors.toList());
return new MultiDevice(subDevices);
}

Matcher matcher = DEVICE_NAME.matcher(deviceName);
if (matcher.matches()) {
String deviceType = matcher.group(1);
Expand Down Expand Up @@ -214,4 +226,91 @@ public interface Type {
String CPU = "cpu";
String GPU = "gpu";
}

/** A combined {@link Device} representing the composition of multiple other devices. */
public static class MultiDevice extends Device {

List<Device> devices;

/**
* Constructs a {@link MultiDevice} with a range of new devices.
*
* @param deviceType the type of the sub-devices
* @param startInclusive the start (inclusive) of the devices range
* @param endExclusive the end (exclusive) of the devices range
*/
public MultiDevice(String deviceType, int startInclusive, int endExclusive) {
this(
IntStream.range(startInclusive, endExclusive)
.mapToObj(i -> Device.of(deviceType, i))
.collect(Collectors.toList()));
}

/**
* Constructs a {@link MultiDevice} from sub devices.
*
* @param devices the sub devices
*/
public MultiDevice(Device... devices) {
this(Arrays.asList(devices));
}

/**
* Constructs a {@link MultiDevice} from sub devices.
*
* @param devices the sub devices
*/
public MultiDevice(List<Device> devices) {
super(null, -1);
devices.sort(
Comparator.comparing(Device::getDeviceType, String.CASE_INSENSITIVE_ORDER)
.thenComparingInt(Device::getDeviceId));
this.deviceType =
String.join(
"+",
(Iterable<String>)
() ->
devices.stream()
.map(d -> d.getDeviceType() + d.getDeviceId())
.iterator());
this.devices = devices;
}

/**
* Returns the sub devices.
*
* @return the sub devices
*/
public List<Device> getDevices() {
return devices;
}

/** {@inheritDoc} */
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
if (!super.equals(o)) {
return false;
}
MultiDevice that = (MultiDevice) o;
return Objects.equals(devices, that.devices);
}

/** {@inheritDoc} */
@Override
public int hashCode() {
return Objects.hash(super.hashCode(), devices);
}

/** {@inheritDoc} */
@Override
public String toString() {
return deviceType + "()";
}
}
}
5 changes: 5 additions & 0 deletions api/src/main/java/ai/djl/training/ParameterStore.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package ai.djl.training;

import ai.djl.Device;
import ai.djl.Device.MultiDevice;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.nn.Parameter;
Expand Down Expand Up @@ -64,6 +65,10 @@ public void setParameterServer(ParameterServer parameterServer, Device[] devices
this.parameterServer = parameterServer;
deviceMap.clear();
for (int i = 0; i < devices.length; ++i) {
if (devices[i] instanceof MultiDevice) {
throw new IllegalArgumentException(
"The parameter store does not support MultiDevices");
}
if (deviceMap.put(devices[i], i) != null) {
throw new IllegalArgumentException("Duplicated devices are not allowed.");
}
Expand Down
7 changes: 7 additions & 0 deletions api/src/test/java/ai/djl/DeviceTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

package ai.djl;

import ai.djl.Device.MultiDevice;
import ai.djl.engine.Engine;

import org.testng.Assert;
Expand All @@ -37,6 +38,8 @@ public void testDevice() {

System.setProperty("test_key", "test");
Engine.debugEnvironment();

Assert.assertEquals(2, new MultiDevice(Device.gpu(1), Device.gpu(2)).getDevices().size());
}

@Test
Expand All @@ -54,5 +57,9 @@ public void testDeviceName() {
Device defaultDevice = Engine.getInstance().defaultDevice();
Assert.assertEquals(Device.fromName(""), defaultDevice);
Assert.assertEquals(Device.fromName(null), defaultDevice);

Assert.assertEquals(
Device.fromName("gpu1+gpu2"), new MultiDevice(Device.gpu(2), Device.gpu(1)));
Assert.assertEquals(Device.fromName("gpu1+gpu2"), new MultiDevice("gpu", 1, 3));
}
}
Loading