Skip to content

Commit

Permalink
[timeseries] Updates timeseries README (#2667)
Browse files Browse the repository at this point in the history
Fixes: #2641
  • Loading branch information
frankfliu authored Jun 20, 2023
1 parent df06528 commit 892d8a5
Showing 1 changed file with 57 additions and 67 deletions.
124 changes: 57 additions & 67 deletions extensions/timeseries/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,44 +136,33 @@ Here we define how to get `TimeSeriesData` from the dataset.
```java
public static class AirPassengers {

private Path path;
private AirPassengerData data;

public AirPassengers(Path path) {
this.path = path;
prepare();
}

public TimeSeriesData get(NDManager manager) {
LocalDateTime start =
data.start.toInstant().atZone(ZoneId.systemDefault()).toLocalDateTime();
NDArray target = manager.create(data.target);
TimeSeriesData ret = new TimeSeriesData(10);
// A TimeSeriesData must contain start time and target value.
ret.setStartTime(start);
ret.setField(FieldName.TARGET, target);
return ret;
private static TimeSeriesData getTimeSeriesData(NDManager manager, URL url) throws IOException {
try (Reader reader = new InputStreamReader(url.openStream(), StandardCharsets.UTF_8)) {
AirPassengers passengers =
new GsonBuilder()
.setDateFormat("yyyy-MM")
.create()
.fromJson(reader, AirPassengers.class);

LocalDateTime start =
passengers.start.toInstant().atZone(ZoneId.systemDefault()).toLocalDateTime();
NDArray target = manager.create(passengers.target);
TimeSeriesData data = new TimeSeriesData(10);
data.setStartTime(start);
data.setField(FieldName.TARGET, target);
return data;
}
}

/** prepare the file data */
private void prepare() {
Path filePath = path.resolve("test").resolve("data.json.gz");
try {
URL url = filePath.toUri().toURL();
try (GZIPInputStream is = new GZIPInputStream(url.openStream())) {
Reader reader = new InputStreamReader(is);
data =
new GsonBuilder()
.setDateFormat("yyyy-MM")
.create()
.fromJson(reader, AirPassengerData.class);
}
} catch (IOException e) {
throw new IllegalArgumentException("Invalid url: " + filePath, e);
private static void saveNDArray(NDArray array) throws IOException {
Path path = Paths.get("build").resolve(array.getName() + ".npz");
try (OutputStream os = Files.newOutputStream(path)) {
new NDList(new NDList(array)).encode(os, true);
}
}

private static class AirPassengerData {
private static final class AirPassengers {

Date start;
float[] target;
}
Expand All @@ -185,43 +174,44 @@ public static class AirPassengers {
In djl we need to define `Translator` to help us with data pre- and post-processing.

```java
public static float[] predict() throws IOException, TranslateException, ModelException {
Map<String, Object> arguments = new ConcurrentHashMap<>();
// set parameter
arguments.put("prediction_length", 12);
arguments.put("freq", "M");
arguments.put("use_" + FieldName.FEAT_DYNAMIC_REAL.name().toLowerCase(), false);
arguments.put("use_" + FieldName.FEAT_STATIC_CAT.name().toLowerCase(), false);
arguments.put("use_" + FieldName.FEAT_STATIC_REAL.name().toLowerCase(), false);

// build translator
DeepARTranslator translator = DeepARTranslator.builder(arguments).build();

// create criteria
public static float[] predict() throws IOException, TranslateException, ModelException {
Criteria<TimeSeriesData, Forecast> criteria =
Criteria.builder()
.setTypes(TimeSeriesData.class, Forecast.class)
.optModelPath(Paths.get(modelUrl))
.optTranslator(translator)
.optProgress(new ProgressBar())
.build();

// load model
try (ZooModel<TimeSeriesData, Forecast> model = criteria.loadModel();
Predictor<TimeSeriesData, Forecast> predictor = model.newPredictor()) {
NDManager manager = model.getNDManager();

AirPassengers ap = new AirPassengers(Paths.get("Not implemented"));
TimeSeriesData data = ap.get(manager);

// prediction
Forecast forecast = predictor.predict(data);

return forecast.mean().toFloatArray();
}
}
Criteria.builder()
.setTypes(TimeSeriesData.class, Forecast.class)
.optModelUrls("djl://ai.djl.mxnet/deepar/0.0.1/airpassengers")
.optEngine("MXNet")
.optTranslatorFactory(new DeferredTranslatorFactory())
.optArgument("prediction_length", 12)
.optArgument("freq", "M")
.optArgument("use_feat_dynamic_real", false)
.optArgument("use_feat_static_cat", false)
.optArgument("use_feat_static_real", false)
.optProgress(new ProgressBar())
.build();

String url = "https://resources.djl.ai/test-models/mxnet/timeseries/air_passengers.json";

try (ZooModel<TimeSeriesData, Forecast> model = criteria.loadModel();
Predictor<TimeSeriesData, Forecast> predictor = model.newPredictor();
NDManager manager = NDManager.newBaseManager("MXNet")) {
TimeSeriesData input = getTimeSeriesData(manager, new URL(url));

// save data for plotting
NDArray target = input.get(FieldName.TARGET);
target.setName("target");
saveNDArray(target);

Forecast forecast = predictor.predict(input);

// save data for plotting. Please see the corresponding python script from
// https://gist.github.com/Carkham/a5162c9298bc51fec648a458a3437008
NDArray samples = ((SampleForecast) forecast).getSortedSamples();
samples.setName("samples");
saveNDArray(samples);
return forecast.mean().toFloatArray();
}
```

### Visualize

![simple_forecast](https://resources.djl.ai/images/timeseries/simple_forecast.png)
Expand Down

0 comments on commit 892d8a5

Please sign in to comment.