From 0bd791befe2b7e843dd28fdf0bb4259042bb0e46 Mon Sep 17 00:00:00 2001 From: Vuong Nguyen Date: Wed, 31 Jan 2024 10:24:34 +0100 Subject: [PATCH] methods/utils/st/nnst: update utils for nnst method --- augmentare/methods/utils/style_transfer/utils_nnst.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/augmentare/methods/utils/style_transfer/utils_nnst.py b/augmentare/methods/utils/style_transfer/utils_nnst.py index eeb2365..492df0a 100644 --- a/augmentare/methods/utils/style_transfer/utils_nnst.py +++ b/augmentare/methods/utils/style_transfer/utils_nnst.py @@ -283,7 +283,8 @@ def replace_features(src, ref): """ # Move style features to gpu (necessary to mostly store on cpu for gpus w/ # < 12GB of memory) - ref_flat = to_device(flatten_grid(ref)) + device = "cuda" if torch.cuda.is_available() else "cpu" + ref_flat = flatten_grid(ref).to(device) rplc = [] for j in range(src.size(0)): # How many rows of the distance matrix to compute at once, can be