
Running TimeSeries Demo, got the TranslateException: ai.djl.engine.EngineException #2641

Closed
mengzhizihu opened this issue Jun 8, 2023 · 4 comments · Fixed by #2667
Labels
bug Something isn't working

Comments


mengzhizihu commented Jun 8, 2023

Description

When I run the TimeSeries demo, I get the following exception:

Exception in thread "main" ai.djl.translate.TranslateException: ai.djl.engine.EngineException: Expected at most 7 argument(s) for operator 'forward', but received 8 argument(s). Declaration: forward(__torch__.gluonts.torch.model.deepar.module.DeepARModel self, Tensor feat_static_cat, Tensor feat_static_real, Tensor past_time_feat, Tensor past_target, Tensor past_observed_values, Tensor future_time_feat) -> Tensor

Expected Behavior

I expected the demo to run successfully and produce a forecast.

Error Message

Exception in thread "main" ai.djl.translate.TranslateException: ai.djl.engine.EngineException: Expected at most 7 argument(s) for operator 'forward', but received 8 argument(s). Declaration: forward(__torch__.gluonts.torch.model.deepar.module.DeepARModel self, Tensor feat_static_cat, Tensor feat_static_real, Tensor past_time_feat, Tensor past_target, Tensor past_observed_values, Tensor future_time_feat) -> Tensor
at ai.djl.inference.Predictor.batchPredict(Predictor.java:189)
at ai.djl.inference.Predictor.predict(Predictor.java:126)
at com.demo.djl.TimeSeriesTest.predict(TimeSeriesTest.java:68)
at com.demo.djl.TimeSeriesTest.main(TimeSeriesTest.java:34)
Caused by: ai.djl.engine.EngineException: Expected at most 7 argument(s) for operator 'forward', but received 8 argument(s). Declaration: forward(__torch__.gluonts.torch.model.deepar.module.DeepARModel self, Tensor feat_static_cat, Tensor feat_static_real, Tensor past_time_feat, Tensor past_target, Tensor past_observed_values, Tensor future_time_feat) -> Tensor
at ai.djl.pytorch.jni.PyTorchLibrary.moduleRunMethod(Native Method)
at ai.djl.pytorch.jni.IValueUtils.forward(IValueUtils.java:53)
at ai.djl.pytorch.engine.PtSymbolBlock.forwardInternal(PtSymbolBlock.java:145)
at ai.djl.nn.AbstractBaseBlock.forward(AbstractBaseBlock.java:79)
at ai.djl.nn.Block.forward(Block.java:127)
at ai.djl.inference.Predictor.predictInternal(Predictor.java:140)
at ai.djl.inference.Predictor.batchPredict(Predictor.java:180)
... 3 more
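Reading the message closely: the declaration lists self plus six tensors, which matches "at most 7 argument(s)", so self is counted as an argument. The call supplied 8, i.e. seven input tensors, one more than the traced DeepARModel.forward accepts. A small stdlib-only sketch of that count, parsing the declaration string from the error (the class name ForwardArity is hypothetical and only illustrates the arithmetic):

```java
import java.util.ArrayList;
import java.util.List;

public class ForwardArity {

    // Returns the comma-separated parameters between the first '(' and the next ')'.
    // Naive split: fine for this declaration, but it would break on types that contain commas.
    static List<String> parameters(String declaration) {
        int open = declaration.indexOf('(');
        int close = declaration.indexOf(')', open);
        List<String> params = new ArrayList<>();
        for (String p : declaration.substring(open + 1, close).split(",")) {
            params.add(p.trim());
        }
        return params;
    }

    public static void main(String[] args) {
        String declaration =
                "forward(__torch__.gluonts.torch.model.deepar.module.DeepARModel self, "
                        + "Tensor feat_static_cat, Tensor feat_static_real, Tensor past_time_feat, "
                        + "Tensor past_target, Tensor past_observed_values, Tensor future_time_feat) -> Tensor";
        int received = 8; // "but received 8 argument(s)" from the error message

        List<String> params = parameters(declaration);
        int expected = params.size(); // counts 'self' as well
        // Subtract 'self' on both sides to compare tensor inputs only.
        System.out.println("model accepts " + (expected - 1) + " tensors; caller passed " + (received - 1));
    }
}
```

Running it prints that the model accepts 6 tensors while the caller passed 7, which is the mismatch the translator and the traced model disagree on.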

How to Reproduce?


package com.demo.djl;

import ai.djl.Application;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.timeseries.Forecast;
import ai.djl.timeseries.TimeSeriesData;
import ai.djl.timeseries.dataset.FieldName;
import ai.djl.timeseries.translator.DeepARTranslator;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import com.google.gson.GsonBuilder;

import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.util.Date;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class TimeSeriesTest {

    public static void main(String[] args) throws TranslateException, ModelException, IOException {
        predict();
    }

    public static float[] predict() throws IOException, TranslateException, ModelException {
        Map<String, Object> arguments = new ConcurrentHashMap<>();
        // set parameter
        arguments.put("prediction_length", 12);
        arguments.put("freq", "M");
        arguments.put("use_" + FieldName.FEAT_DYNAMIC_REAL.name().toLowerCase(), false);
        arguments.put("use_" + FieldName.FEAT_STATIC_CAT.name().toLowerCase(), false);
        arguments.put("use_" + FieldName.FEAT_STATIC_REAL.name().toLowerCase(), false);

        // build translator
        DeepARTranslator translator = DeepARTranslator.builder(arguments).build();

        // create criteria
        Criteria<TimeSeriesData, Forecast> criteria =
                Criteria.builder()
                        .optApplication(Application.TimeSeries.FORECASTING)
                        .setTypes(TimeSeriesData.class, Forecast.class)
//                        .optModelPath(Paths.get(modelUrl))
                        .optTranslator(translator)
                        .optProgress(new ProgressBar())
                        .build();

        // load model
        try (ZooModel<TimeSeriesData, Forecast> model = criteria.loadModel();
             Predictor<TimeSeriesData, Forecast> predictor = model.newPredictor()) {
            NDManager manager = model.getNDManager();

            AirPassengers ap = new AirPassengers(Paths.get("Not implemented"));
            TimeSeriesData data = ap.get(manager);

            // prediction
            Forecast forecast = predictor.predict(data);

            return forecast.mean().toFloatArray();
        }
    }

    public static class AirPassengers {

        private Path path;
        private AirPassengerData data;

        public AirPassengers(Path path) {
            this.path = path;
            prepare();
        }

        public TimeSeriesData get(NDManager manager) {
            LocalDateTime start =
                    data.start.toInstant().atZone(ZoneId.systemDefault()).toLocalDateTime();
            NDArray target = manager.create(data.target);
            TimeSeriesData ret = new TimeSeriesData(10);
            // A TimeSeriesData must contain start time and target value.
            ret.setStartTime(start);
            ret.setField(FieldName.TARGET, target);
            return ret;
        }


        // Inline copy of the AirPassengers series, used because data.json.gz is not available locally
        private void prepare() {
            String airPassengers = "{\"start\":\"1949-01\",\"target\":[112.0,118.0,132.0,129.0,121.0,135.0,148.0,148.0,136.0,119.0,104.0,118.0,115.0,126.0,141.0,135.0,125.0,149.0,170.0,170.0,158.0,133.0,114.0,140.0,145.0,150.0,178.0,163.0,172.0,178.0,199.0,199.0,184.0,162.0,146.0,166.0,171.0,180.0,193.0,181.0,183.0,218.0,230.0,242.0,209.0,191.0,172.0,194.0,196.0,196.0,236.0,235.0,229.0,243.0,264.0,272.0,237.0,211.0,180.0,201.0,204.0,188.0,235.0,227.0,234.0,264.0,302.0,293.0,259.0,229.0,203.0,229.0,242.0,233.0,267.0,269.0,270.0,315.0,364.0,347.0,312.0,274.0,237.0,278.0,284.0,277.0,317.0,313.0,318.0,374.0,413.0,405.0,355.0,306.0,271.0,306.0,315.0,301.0,356.0,348.0,355.0,422.0,465.0,467.0,404.0,347.0,305.0,336.0,340.0,318.0,362.0,348.0,363.0,435.0,491.0,505.0,404.0,359.0,310.0,337.0,360.0,342.0,406.0,396.0,420.0,472.0,548.0,559.0,463.0,407.0,362.0,405.0,417.0,391.0,419.0,461.0,472.0,535.0,622.0,606.0,508.0,461.0,390.0,432.0]}";
            data =
                    new GsonBuilder()
                            .setDateFormat("yyyy-MM")
                            .create()
                            .fromJson(airPassengers, AirPassengerData.class);
        }

        private static class AirPassengerData {
            Date start;
            float[] target;
        }
    }

}
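Unrelated to the exception itself, but worth noting: in get(), the start month travels String → java.util.Date (via Gson's "yyyy-MM" date format) → LocalDateTime in the system default time zone. If you would rather skip the Date and time-zone round trip, java.time can parse the month directly. A minimal sketch; the class StartTime and its startOf helper are hypothetical:

```java
import java.time.LocalDateTime;
import java.time.YearMonth;

public class StartTime {

    // Parses a "yyyy-MM" month and returns midnight on its first day; no time zone is involved.
    static LocalDateTime startOf(String yearMonth) {
        return YearMonth.parse(yearMonth).atDay(1).atStartOfDay();
    }

    public static void main(String[] args) {
        System.out.println(startOf("1949-01")); // prints 1949-01-01T00:00
    }
}
```

The result could be passed straight to ret.setStartTime(...) in place of the Date conversion.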

Steps to reproduce


1. Run TimeSeriesTest.main().

What have you tried to solve it?

1. Upgraded the PyTorch version to 1.12.1.

Environment Info


    <properties>
        <maven.compiler.source>8</maven.compiler.source>
        <maven.compiler.target>8</maven.compiler.target>
        <djl.version>0.21.0</djl.version>
    </properties>

    <build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <configuration>
                    <source>11</source>
                    <target>11</target>
                </configuration>
            </plugin>
        </plugins>
    </build>

    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>ai.djl</groupId>
                <artifactId>bom</artifactId>
                <version>${djl.version}</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
        </dependencies>
    </dependencyManagement>

    <dependencies>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
            <version>${djl.version}</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-model-zoo</artifactId>
            <version>${djl.version}</version>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
            <version>${djl.version}</version>
            <scope>runtime</scope>
        </dependency>
<!--        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-native-auto</artifactId>
            <version>1.9.1</version>
        </dependency>-->
        <dependency>
            <groupId>ai.djl.timeseries</groupId>
            <artifactId>timeseries</artifactId>
            <version>0.22.1</version>
        </dependency>

        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
            <version>1.2.6</version>
        </dependency>

        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-core</artifactId>
            <version>1.2.6</version>
        </dependency>

        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>basicdataset</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>model-zoo</artifactId>
        </dependency>

        <dependency>
            <groupId>commons-cli</groupId>
            <artifactId>commons-cli</artifactId>
            <version>1.5.0</version>
        </dependency>

    </dependencies>
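In this pom, ai.djl:api, ai.djl.pytorch:pytorch-model-zoo and ai.djl.pytorch:pytorch-engine resolve to ${djl.version} = 0.21.0, while ai.djl.timeseries:timeseries is pinned to 0.22.1. The thread does not spell out the root cause, but since the DeepAR model presumably comes from the 0.21.0 PyTorch model zoo while DeepARTranslator comes from the 0.22.1 timeseries module, a release mismatch between the two is one plausible reason the translator passes one more tensor than the model accepts. A tiny stdlib-only sketch of a version-consistency check (class and method names are hypothetical; the versions are copied from this pom):

```java
import java.util.LinkedHashMap;
import java.util.Map;

public class DjlVersionCheck {

    // Returns every artifact whose version differs from the expected DJL release.
    static Map<String, String> mismatches(Map<String, String> artifacts, String expected) {
        Map<String, String> out = new LinkedHashMap<>();
        for (Map.Entry<String, String> e : artifacts.entrySet()) {
            if (!e.getValue().equals(expected)) {
                out.put(e.getKey(), e.getValue());
            }
        }
        return out;
    }

    public static void main(String[] args) {
        String djlVersion = "0.21.0"; // <djl.version> from the pom
        Map<String, String> artifacts = new LinkedHashMap<>();
        artifacts.put("ai.djl:api", djlVersion);
        artifacts.put("ai.djl.pytorch:pytorch-model-zoo", djlVersion);
        artifacts.put("ai.djl.pytorch:pytorch-engine", djlVersion);
        artifacts.put("ai.djl.timeseries:timeseries", "0.22.1"); // explicit version in the pom

        mismatches(artifacts, djlVersion).forEach((artifact, version) ->
                System.out.println("mismatch: " + artifact + " = " + version + " (expected " + djlVersion + ")"));
    }
}
```

Setting the timeseries dependency to ${djl.version}, or moving every DJL artifact to the same release, removes this variable before digging further.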
@mengzhizihu mengzhizihu added the bug Something isn't working label Jun 8, 2023
@mengzhizihu mengzhizihu changed the title Exception in thread "main" ai.djl.translate.TranslateException: ai.djl.engine.EngineException: Expected at most 7 argument(s) for operator 'forward', but received 8 argument(s). Declaration: forward(__torch__.gluonts.torch.model.deepar.module.DeepARModel self, Tensor feat_static_cat, Tensor feat_static_real, Tensor past_time_feat, Tensor past_target, Tensor past_observed_values, Tensor future_time_feat) -> Tensor Running TimeSeries Demo, got the TranslateException: ai.djl.engine.EngineException Jun 8, 2023
mengzhizihu (author) commented Jun 12, 2023

I only replaced the JSON-loading code, because I don't have the "data.json.gz" file.

Original:

 /** prepare the file data */
//        private void prepare() {
//            Path filePath = path.resolve("test").resolve("data.json.gz");
//            try {
//                URL url = filePath.toUri().toURL();
//                try (GZIPInputStream is = new GZIPInputStream(url.openStream())) {
//                    Reader reader = new InputStreamReader(is);
//                    data =
//                            new GsonBuilder()
//                                    .setDateFormat("yyyy-MM")
//                                    .create()
//                                    .fromJson(reader, AirPassengerData.class);
//                }
//            } catch (IOException e) {
//                throw new IllegalArgumentException("Invalid url: " + filePath, e);
//            }
//        }

Replacement: the inline-string prepare() method already shown in the reproduction code above.
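The original prepare() opens <path>/test/data.json.gz through a GZIPInputStream, and that file is what was missing here. If you would rather keep the original loader, one option is to generate the file yourself from the same JSON. A minimal stdlib sketch; the class name AirPassengersGz and the airpassengers root directory are hypothetical, and the JSON is truncated for brevity:

```java
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.Reader;
import java.io.UncheckedIOException;
import java.io.Writer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

public class AirPassengersGz {

    // Writes json to <root>/test/data.json.gz, the location the original prepare() resolves.
    static Path write(Path root, String json) {
        Path file = root.resolve("test").resolve("data.json.gz");
        try {
            Files.createDirectories(file.getParent());
            // Closing the writer also finishes and closes the GZIP stream.
            try (Writer w = new OutputStreamWriter(
                    new GZIPOutputStream(Files.newOutputStream(file)), StandardCharsets.UTF_8)) {
                w.write(json);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return file;
    }

    // Decompresses the file back into a String as a quick sanity check.
    static String read(Path file) {
        StringBuilder sb = new StringBuilder();
        try (Reader r = new InputStreamReader(
                new GZIPInputStream(Files.newInputStream(file)), StandardCharsets.UTF_8)) {
            char[] buf = new char[4096];
            int n;
            while ((n = r.read(buf)) != -1) {
                sb.append(buf, 0, n);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        // Truncated series; paste the full AirPassengers JSON string here.
        String json = "{\"start\":\"1949-01\",\"target\":[112.0,118.0,132.0]}";
        Path file = write(Paths.get("airpassengers"), json); // hypothetical root directory
        System.out.println(file + " round-trip ok: " + read(file).equals(json));
    }
}
```

With the file in place, the original gzip-based prepare() could be restored and the dataset constructed as new AirPassengers(Paths.get("airpassengers")).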

KexinFeng (contributor) commented

mengzhizihu (author) commented

That works, thanks a lot. It would be great if you could update the official demos on your website; this confused me for quite a while.
