Replies: 3 comments 2 replies
-
您好,我这边没有复现出这个报错,我的环境如下:
请您参考Issue: Bug report模板给出复现环境及步骤: Describe the bug(问题描述) To Reproduce(复现步骤)
Operating environment(运行环境):
Additional context |
Beta Was this translation helpful? Give feedback.
-
哦我可能知道了,您是直接用的 文档里这个只是个示例,实际使用时是需要输入linear_feature_columns和dnn_feature_columns的,您可以参考下run_classification_criteo.py里初始化模型的用法。 |
Beta Was this translation helpful? Give feedback.
-
感谢我已按照格式更新 - 跟随着run_classification_criteo.py exmaple 同样还是报错 |
Beta Was this translation helpful? Give feedback.
-
首先非常感谢这个deepctr torch这个package 可以非常快速的试各种model 但是我在读取train好的model 使用predict这个function会报错
Operating environment(运行环境):
python version 3.8.8
torch version 1.8.1
deepctr-torch version 0.2.7
请您参考Issue: Bug report模板给出复现环境及步骤:
Describe the bug(问题描述)
使用读取存储的模型用predict这个function的时候会有error
具体的error 信息:
NotImplementedError Traceback (most recent call last)
in
----> 1 reload_model.predict(train_model_input)
/databricks/python/lib/python3.8/site-packages/deepctr_torch/models/basemodel.py in predict(self, x, batch_size)
340 x = x_test[0].to(self.device).float()
341
--> 342 y_pred = model(x).cpu().data.numpy() # .squeeze()
343 pred_ans.append(y_pred)
344
/databricks/python/lib/python3.8/site-packages/torch/nn/modules/module.py in call(self, *input, **kwargs)
548 result = self._slow_forward(*input, **kwargs)
549 else:
--> 550 result = self.forward(*input, **kwargs)
551 for hook in self._forward_hooks.values():
552 hook_result = hook(self, input, result)
/databricks/python/lib/python3.8/site-packages/deepctr_torch/models/deepfm.py in forward(self, X)
76
77 if self.use_dnn:
---> 78 dnn_input = combined_dnn_input(
79 sparse_embedding_list, dense_value_list)
80 dnn_output = self.dnn(dnn_input)
/databricks/python/lib/python3.8/site-packages/deepctr_torch/inputs.py in combined_dnn_input(sparse_embedding_list, dense_value_list)
136 return torch.flatten(torch.cat(dense_value_list, dim=-1), start_dim=1)
137 else:
--> 138 raise NotImplementedError
139
140
NotImplementedError:
To Reproduce(复现步骤)
跟着run_classification_criteo.py的example
在这之后加入save/load model的步骤 以下是在example code
新增加的那部分,其他都保持一致
torch.save(model, "test.h5")
reload_model = torch.load("test.h5")
reload_model.predict(train_model_input) ## 此处报错
Additional context
同时也试过
都是同样的报错
Beta Was this translation helpful? Give feedback.
All reactions