Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[timeseries] Updates timeseries README #2667

Merged
merged 1 commit into from
Jun 20, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 57 additions & 67 deletions extensions/timeseries/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,44 +136,33 @@ Here we define how to get `TimeSeriesData` from the dataset.
```java
public static class AirPassengers {

private Path path;
private AirPassengerData data;

public AirPassengers(Path path) {
this.path = path;
prepare();
}

public TimeSeriesData get(NDManager manager) {
LocalDateTime start =
data.start.toInstant().atZone(ZoneId.systemDefault()).toLocalDateTime();
NDArray target = manager.create(data.target);
TimeSeriesData ret = new TimeSeriesData(10);
// A TimeSeriesData must contain start time and target value.
ret.setStartTime(start);
ret.setField(FieldName.TARGET, target);
return ret;
private static TimeSeriesData getTimeSeriesData(NDManager manager, URL url) throws IOException {
try (Reader reader = new InputStreamReader(url.openStream(), StandardCharsets.UTF_8)) {
AirPassengers passengers =
new GsonBuilder()
.setDateFormat("yyyy-MM")
.create()
.fromJson(reader, AirPassengers.class);

LocalDateTime start =
passengers.start.toInstant().atZone(ZoneId.systemDefault()).toLocalDateTime();
NDArray target = manager.create(passengers.target);
TimeSeriesData data = new TimeSeriesData(10);
data.setStartTime(start);
data.setField(FieldName.TARGET, target);
return data;
}
}

/** prepare the file data */
private void prepare() {
Path filePath = path.resolve("test").resolve("data.json.gz");
try {
URL url = filePath.toUri().toURL();
try (GZIPInputStream is = new GZIPInputStream(url.openStream())) {
Reader reader = new InputStreamReader(is);
data =
new GsonBuilder()
.setDateFormat("yyyy-MM")
.create()
.fromJson(reader, AirPassengerData.class);
}
} catch (IOException e) {
throw new IllegalArgumentException("Invalid url: " + filePath, e);
private static void saveNDArray(NDArray array) throws IOException {
Path path = Paths.get("build").resolve(array.getName() + ".npz");
try (OutputStream os = Files.newOutputStream(path)) {
new NDList(new NDList(array)).encode(os, true);
}
}

private static class AirPassengerData {
private static final class AirPassengers {

Date start;
float[] target;
}
Expand All @@ -185,43 +174,44 @@ public static class AirPassengers {
In djl we need to define `Translator` to help us with data pre- and post-processing.

```java
public static float[] predict() throws IOException, TranslateException, ModelException {
Map<String, Object> arguments = new ConcurrentHashMap<>();
// set parameter
arguments.put("prediction_length", 12);
arguments.put("freq", "M");
arguments.put("use_" + FieldName.FEAT_DYNAMIC_REAL.name().toLowerCase(), false);
arguments.put("use_" + FieldName.FEAT_STATIC_CAT.name().toLowerCase(), false);
arguments.put("use_" + FieldName.FEAT_STATIC_REAL.name().toLowerCase(), false);

// build translator
DeepARTranslator translator = DeepARTranslator.builder(arguments).build();

// create criteria
public static float[] predict() throws IOException, TranslateException, ModelException {
Criteria<TimeSeriesData, Forecast> criteria =
Criteria.builder()
.setTypes(TimeSeriesData.class, Forecast.class)
.optModelPath(Paths.get(modelUrl))
.optTranslator(translator)
.optProgress(new ProgressBar())
.build();

// load model
try (ZooModel<TimeSeriesData, Forecast> model = criteria.loadModel();
Predictor<TimeSeriesData, Forecast> predictor = model.newPredictor()) {
NDManager manager = model.getNDManager();

AirPassengers ap = new AirPassengers(Paths.get("Not implemented"));
TimeSeriesData data = ap.get(manager);

// prediction
Forecast forecast = predictor.predict(data);

return forecast.mean().toFloatArray();
}
}
Criteria.builder()
.setTypes(TimeSeriesData.class, Forecast.class)
.optModelUrls("djl://ai.djl.mxnet/deepar/0.0.1/airpassengers")
.optEngine("MXNet")
.optTranslatorFactory(new DeferredTranslatorFactory())
.optArgument("prediction_length", 12)
.optArgument("freq", "M")
.optArgument("use_feat_dynamic_real", false)
.optArgument("use_feat_static_cat", false)
.optArgument("use_feat_static_real", false)
.optProgress(new ProgressBar())
.build();

String url = "https://resources.djl.ai/test-models/mxnet/timeseries/air_passengers.json";

try (ZooModel<TimeSeriesData, Forecast> model = criteria.loadModel();
Predictor<TimeSeriesData, Forecast> predictor = model.newPredictor();
NDManager manager = NDManager.newBaseManager("MXNet")) {
TimeSeriesData input = getTimeSeriesData(manager, new URL(url));

// save data for plotting
NDArray target = input.get(FieldName.TARGET);
target.setName("target");
saveNDArray(target);

Forecast forecast = predictor.predict(input);

// save data for plotting. Please see the corresponding python script from
// https://gist.github.com/Carkham/a5162c9298bc51fec648a458a3437008
NDArray samples = ((SampleForecast) forecast).getSortedSamples();
samples.setName("samples");
saveNDArray(samples);
return forecast.mean().toFloatArray();
}
```

### Visualize

![simple_forecast](https://resources.djl.ai/images/timeseries/simple_forecast.png)
Expand Down