diff --git a/extensions/timeseries/README.md b/extensions/timeseries/README.md index 101c40c8f20..e5547833d74 100644 --- a/extensions/timeseries/README.md +++ b/extensions/timeseries/README.md @@ -136,44 +136,33 @@ Here we define how to get `TimeSeriesData` from the dataset. ```java public static class AirPassengers { - private Path path; - private AirPassengerData data; - - public AirPassengers(Path path) { - this.path = path; - prepare(); - } - - public TimeSeriesData get(NDManager manager) { - LocalDateTime start = - data.start.toInstant().atZone(ZoneId.systemDefault()).toLocalDateTime(); - NDArray target = manager.create(data.target); - TimeSeriesData ret = new TimeSeriesData(10); - // A TimeSeriesData must contain start time and target value. - ret.setStartTime(start); - ret.setField(FieldName.TARGET, target); - return ret; + private static TimeSeriesData getTimeSeriesData(NDManager manager, URL url) throws IOException { + try (Reader reader = new InputStreamReader(url.openStream(), StandardCharsets.UTF_8)) { + AirPassengers passengers = + new GsonBuilder() + .setDateFormat("yyyy-MM") + .create() + .fromJson(reader, AirPassengers.class); + + LocalDateTime start = + passengers.start.toInstant().atZone(ZoneId.systemDefault()).toLocalDateTime(); + NDArray target = manager.create(passengers.target); + TimeSeriesData data = new TimeSeriesData(10); + data.setStartTime(start); + data.setField(FieldName.TARGET, target); + return data; + } } - /** prepare the file data */ - private void prepare() { - Path filePath = path.resolve("test").resolve("data.json.gz"); - try { - URL url = filePath.toUri().toURL(); - try (GZIPInputStream is = new GZIPInputStream(url.openStream())) { - Reader reader = new InputStreamReader(is); - data = - new GsonBuilder() - .setDateFormat("yyyy-MM") - .create() - .fromJson(reader, AirPassengerData.class); - } - } catch (IOException e) { - throw new IllegalArgumentException("Invalid url: " + filePath, e); + private static void saveNDArray(NDArray array) throws IOException { + Path path = Paths.get("build").resolve(array.getName() + ".npz"); + try (OutputStream os = Files.newOutputStream(path)) { + new NDList(new NDList(array)).encode(os, true); } } - private static class AirPassengerData { + private static final class AirPassengers { + Date start; float[] target; } @@ -185,43 +174,44 @@ public static class AirPassengers { In djl we need to define `Translator` to help us with data pre- and post-processing. ```java -public static float[] predict() throws IOException, TranslateException, ModelException { - Map arguments = new ConcurrentHashMap<>(); - // set parameter - arguments.put("prediction_length", 12); - arguments.put("freq", "M"); - arguments.put("use_" + FieldName.FEAT_DYNAMIC_REAL.name().toLowerCase(), false); - arguments.put("use_" + FieldName.FEAT_STATIC_CAT.name().toLowerCase(), false); - arguments.put("use_" + FieldName.FEAT_STATIC_REAL.name().toLowerCase(), false); - - // build translator - DeepARTranslator translator = DeepARTranslator.builder(arguments).build(); - - // create criteria + public static float[] predict() throws IOException, TranslateException, ModelException { Criteria criteria = - Criteria.builder() - .setTypes(TimeSeriesData.class, Forecast.class) - .optModelPath(Paths.get(modelUrl)) - .optTranslator(translator) - .optProgress(new ProgressBar()) - .build(); - - // load model - try (ZooModel model = criteria.loadModel(); - Predictor predictor = model.newPredictor()) { - NDManager manager = model.getNDManager(); - - AirPassengers ap = new AirPassengers(Paths.get("Not implemented")); - TimeSeriesData data = ap.get(manager); - - // prediction - Forecast forecast = predictor.predict(data); - - return forecast.mean().toFloatArray(); - } -} + Criteria.builder() + .setTypes(TimeSeriesData.class, Forecast.class) + .optModelUrls("djl://ai.djl.mxnet/deepar/0.0.1/airpassengers") + .optEngine("MXNet") + .optTranslatorFactory(new DeferredTranslatorFactory()) + .optArgument("prediction_length", 12) + .optArgument("freq", "M") + .optArgument("use_feat_dynamic_real", false) + .optArgument("use_feat_static_cat", false) + .optArgument("use_feat_static_real", false) + .optProgress(new ProgressBar()) + .build(); + + String url = "https://resources.djl.ai/test-models/mxnet/timeseries/air_passengers.json"; + try (ZooModel model = criteria.loadModel(); + Predictor predictor = model.newPredictor(); + NDManager manager = NDManager.newBaseManager("MXNet")) { + TimeSeriesData input = getTimeSeriesData(manager, new URL(url)); + + // save data for plotting + NDArray target = input.get(FieldName.TARGET); + target.setName("target"); + saveNDArray(target); + + Forecast forecast = predictor.predict(input); + + // save data for plotting. Please see the corresponding python script from + // https://gist.github.com/Carkham/a5162c9298bc51fec648a458a3437008 + NDArray samples = ((SampleForecast) forecast).getSortedSamples(); + samples.setName("samples"); + saveNDArray(samples); + return forecast.mean().toFloatArray(); + } ``` + ### Visualize ![simple_forecast](https://resources.djl.ai/images/timeseries/simple_forecast.png)