Skip to content

Commit

Permalink
Save wide and deep model
Browse files Browse the repository at this point in the history
Signed-off-by: David Davó <[email protected]>
  • Loading branch information
daviddavo committed Oct 6, 2024
1 parent 5829d16 commit 3df2dfe
Showing 1 changed file with 53 additions and 8 deletions.
61 changes: 53 additions & 8 deletions recommenders/models/wide_deep/wide_deep_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# Copyright (c) Recommenders contributors.
# Licensed under the MIT License.
from typing import Tuple, Dict, Optional, Any
from typing import Tuple, Dict, Optional, Any, Union
from dataclasses import dataclass, field
from pathlib import Path

import numpy as np
import pandas as pd
Expand All @@ -15,7 +16,7 @@
import recommenders.utils.python_utils as pu
import recommenders.utils.torch_utils as tu

@dataclass(kw_only=True, frozen=True)
@dataclass(frozen=True)
class WideAndDeepHyperParams:
user_dim: int = 32
item_dim: int = 32
Expand Down Expand Up @@ -182,12 +183,14 @@ def __init__(
n_items: Optional[int] = None,
epochs: int = 100,
batch_size: int = 128,
loss_fn: str | nn.Module = 'mse',
loss_fn: Union[str, nn.Module] = 'mse',
optimizer: str = 'sgd',
l1: float = 0.0001,
optimizer_params: dict[str, Any] = dict(),
disable_batch_progress: bool = False,
disable_iter_progress: bool = False,
model_dir: Optional[Union[str, Path]] = None,
save_model_iter: int = -1,
prediction_col: str = DEFAULT_PREDICTION_COL,
):
self.n_users = n_users or max(train.n_users, test.n_users)
Expand Down Expand Up @@ -230,16 +233,35 @@ def __init__(
self.current_epoch = 0
self.epochs = epochs

self.model_dir = Path(model_dir) if model_dir else None
self.save_model_iter = save_model_iter
self._check_save_model()

self.train_loss_history = list()
self.test_loss_history = list()

@property
def user_col(self) -> str:
return self.train.user_col

@property
def model_path(self) -> Path:
return self.model_dir / f'wide_deep_state_{self.current_epoch:05d}.pth'

@property
def item_col(self) -> str:
return self.train.item_col

def _check_save_model(self) -> bool:
# The two conditions should be True/False at the same time
if (self.save_model_iter == -1) != (self.model_dir is None):
raise ValueError('You should set both save_model_iter and model_dir at the same time')

if self.model_dir is not None:
# Check that save works
self.save()

return True

def fit(self):
if self.current_epoch >= self.epochs:
Expand All @@ -255,6 +277,26 @@ def fit(self):
test_loss=self.test_loss_history[-1],
)

if self.save_model_iter != -1 and self.current_epoch % self.save_model_iter == 0:
self.save()

def save(self, model_path=None):
model_path = Path(model_path) if model_path else self.model_path
model_path.parent.mkdir(exist_ok=True)

torch.save(self.model.state_dict(), model_path)

def load(self, model_path=None):
if model_path is None:
print('Model path not specified, automatically loading from model dir')
model_path = max(self.model_dir.glob('*.pth'), key=lambda f: int(f.stem.split('_')[-1]))
print(' Loading', model_path)
else:
model_path = Path(model_path)

self.model.load_state_dict(torch.load(model_path))
self.current_epoch = int(model_path.stem.split('_')[-1])

def fit_step(self):
self.model.train()

Expand Down Expand Up @@ -292,9 +334,7 @@ def fit_step(self):

self.current_epoch += 1

def recommend_k_items(
self, user_ids=None, item_ids=None, top_k=10, remove_seen=True,
):
def _get_uip_cont(self, user_ids, item_ids, remove_seen: bool):
if user_ids is None:
user_ids = np.arange(1, self.n_users)
if item_ids is None:
Expand All @@ -316,8 +356,13 @@ def recommend_k_items(
cont_features = torch.from_numpy(
np.stack(uip.map(lambda x: self.train._get_continuous_features(*x)).values)
)

uip = uip.to_frame(index=False)

return uip.to_frame(index=False), cont_features

def recommend_k_items(
self, user_ids=None, item_ids=None, top_k=10, remove_seen=True,
):
uip, cont_features = self._get_uip_cont(user_ids, item_ids, remove_seen)

with torch.no_grad():
uip[self.prediction_col] = self.model(
Expand Down

0 comments on commit 3df2dfe

Please sign in to comment.