diff --git a/api/src/main/java/ai/djl/Device.java b/api/src/main/java/ai/djl/Device.java index b99e5a64826..597d7d9be02 100644 --- a/api/src/main/java/ai/djl/Device.java +++ b/api/src/main/java/ai/djl/Device.java @@ -15,6 +15,7 @@ import ai.djl.engine.Engine; import java.util.Arrays; +import java.util.Collections; import java.util.Comparator; import java.util.List; import java.util.Map; @@ -162,6 +163,15 @@ public boolean isGpu() { return Type.GPU.equals(deviceType); } + /** + * Returns the sub devices if present (such as a {@link MultiDevice}), otherwise this. + * + * @return the sub devices if present (such as a {@link MultiDevice}), otherwise this. + */ + public List getDevices() { + return Collections.singletonList(this); + } + /** {@inheritDoc} */ @Override public String toString() { @@ -276,11 +286,8 @@ public MultiDevice(List devices) { this.devices = devices; } - /** - * Returns the sub devices. - * - * @return the sub devices - */ + /** {@inheritDoc} */ + @Override public List getDevices() { return devices; } diff --git a/api/src/test/java/ai/djl/DeviceTest.java b/api/src/test/java/ai/djl/DeviceTest.java index 63572810875..a69a502739b 100644 --- a/api/src/test/java/ai/djl/DeviceTest.java +++ b/api/src/test/java/ai/djl/DeviceTest.java @@ -39,6 +39,7 @@ public void testDevice() { System.setProperty("test_key", "test"); Engine.debugEnvironment(); + Assert.assertEquals(1, Device.cpu().getDevices().size()); Assert.assertEquals(2, new MultiDevice(Device.gpu(1), Device.gpu(2)).getDevices().size()); }