-
Notifications
You must be signed in to change notification settings - Fork 1
/
validate.py
91 lines (71 loc) · 2.14 KB
/
validate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import logging
import torch
from tqdm import tqdm
import wandb
from loggers.logs import log_predictions
from utils.utils import get_device
def validate_fn(loader, model, loss_fn, scheduler, global_metrics, label_metrics, config):
    """Run one validation pass over ``loader`` and return the mean batch loss.

    Each batch is evaluated under ``torch.no_grad()``. Its loss is shown in the
    progress bar and logged to wandb as "batch validation loss". On the final
    epoch only, every batch whose global step
    (``epoch * len(loader) + idx``) is a multiple of
    ``config.project.val_interval`` is forwarded to ``log_predictions``.

    After the pass, two things happen with the mean loss: the scheduler is
    stepped once with it, and it is logged to wandb as "validation loss".

    Args:
        loader: Iterable yielding ``(data, targets)`` batches.
        model: Model to evaluate; it is put in eval mode and left there.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a scalar
            tensor.
        scheduler: Object whose ``step`` takes a metric value (presumably
            ``torch.optim.lr_scheduler.ReduceLROnPlateau`` -- TODO confirm).
        global_metrics: Passed through to ``log_predictions``.
        label_metrics: Passed through to ``log_predictions``.
        config: Provides ``project.epoch``, ``project.num_epochs`` and
            ``project.val_interval``; also passed to ``get_device`` and
            ``log_predictions``.

    Returns:
        float: Mean of the per-batch validation losses.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    device = get_device(config)
    logging.info("Validating results...")
    # Predictions are only logged on the final epoch; decide that once.
    is_last_epoch = config.project.epoch == config.project.num_epochs - 1

    total_loss = 0.0
    num_batches = 0
    model.eval()
    # The context manager closes the progress bar even if the loop raises.
    with tqdm(
        loader,
        position=2,
        leave=False,
        postfix={"val_loss": 0.0},
        desc="Validating Epoch: ",
    ) as loop:
        for idx, (data, targets) in enumerate(loop):
            data, targets = data.to(device), targets.long().to(device)
            # forward
            with torch.no_grad():
                predictions = model(data)
                loss = loss_fn(predictions, targets)
            # Read the scalar once per batch: every .item() forces a host sync.
            batch_loss = loss.item()
            # update tqdm loop
            loop.set_postfix(val_loss=batch_loss)
            # wandb logging
            wandb.log({"batch validation loss": batch_loss})
            total_loss += batch_loss
            num_batches += 1
            # system logging
            if not is_last_epoch:
                continue
            if (config.project.epoch * len(loader) + idx) % config.project.val_interval != 0:
                continue
            log_predictions(
                data,
                targets,
                predictions,
                global_metrics,
                label_metrics,
                config,
                idx,
            )

    if num_batches == 0:
        raise ValueError("validation loader yielded no batches")

    # Average over batches (not over batch_size): each batch_loss is already
    # a per-batch value, so the batch count is the correct denominator.
    mean_loss = total_loss / num_batches
    # Step once per validation pass so plateau patience counts epochs,
    # not individual validation batches.
    scheduler.step(mean_loss)
    wandb.log({"validation loss": mean_loss})
    return mean_loss
def early_stop_validation(loader, model, global_metrics, label_metrics, config):
    """Evaluate ``model`` on every batch of ``loader`` and log each prediction.

    Unlike ``validate_fn``, no loss is computed here. Every batch's inputs,
    integer-cast targets and model outputs are handed to ``log_predictions``
    together with the batch index. The model is put in eval mode and left
    there. Nothing is returned.

    Args:
        loader: Iterable yielding ``(data, targets)`` batches.
        model: Model to evaluate.
        global_metrics: Passed through to ``log_predictions``.
        label_metrics: Passed through to ``log_predictions``.
        config: Passed to ``get_device`` and ``log_predictions``.
    """
    device = get_device(config)
    logging.info("Early stopping model...")
    progress = tqdm(loader, desc="Early Stopping: ", position=2, leave=False)
    model.eval()
    for batch_idx, batch in enumerate(progress):
        inputs, labels = batch
        inputs = inputs.to(device)
        labels = labels.long().to(device)
        # Inference only: no autograd graph is needed for the forward pass.
        with torch.no_grad():
            outputs = model(inputs)
        log_predictions(
            inputs,
            labels,
            outputs,
            global_metrics,
            label_metrics,
            config,
            batch_idx,
        )
    progress.close()