
How to not throw errors with single-class batches #309

Answered by SkafteNicki
cmcmaster1 asked this question in Q&A

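Hi @cmcmaster1,
AUROC is undefined for a batch that contains only one class, so instead of computing it in every `validation_step`, only call `update()` per batch and call `compute()` once at the end of the epoch, when the metric has seen the whole validation set (and therefore, assuming the set contains both classes, both of them). Here is how you could do your validation loop: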

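import pytorch_lightning as pl
import torchmetrics


class MyModel(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model  # hypothetical argument: the network being validated
        ...  # any other setup
        # torchmetrics >= 0.11 requires AUROC(task="binary")
        self.val_metric = torchmetrics.AUROC()

    def validation_step(self, batch, batch_idx):
        x, y = batch
        yhat = self.model(x)
        # Only accumulate state here; no per-batch AUROC, so single-class batches are fine
        self.val_metric.update(yhat, y)

    def validation_epoch_end(self, outputs):
        # Removed in Lightning 2.0; use on_validation_epoch_end(self) there instead
        auroc_val = self.val_metric.compute()  # AUROC over the whole validation set
        self.log("val_auroc", auroc_val)
        self.val_metric.reset()  # clear accumulated state before the next epoch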
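To see why deferring `compute()` avoids the error, here is a minimal pure-Python sketch of the same update/compute/reset pattern. `EpochAUROC` is a hypothetical stand-in, not torchmetrics' implementation: it computes AUROC as the probability that a random positive scores higher than a random negative (ties count as 0.5), which equals the area under the ROC curve.

```python
from typing import List, Sequence


class EpochAUROC:
    """Hypothetical accumulating AUROC metric (not the torchmetrics class)."""

    def __init__(self) -> None:
        self.scores: List[float] = []
        self.targets: List[int] = []

    def update(self, scores: Sequence[float], targets: Sequence[int]) -> None:
        # Only store state; a single-class batch is harmless at this point.
        self.scores.extend(float(s) for s in scores)
        self.targets.extend(int(t) for t in targets)

    def compute(self) -> float:
        pos = [s for s, t in zip(self.scores, self.targets) if t == 1]
        neg = [s for s, t in zip(self.scores, self.targets) if t == 0]
        if not pos or not neg:
            # AUROC is undefined unless both classes have been seen.
            raise ValueError("AUROC needs at least one positive and one negative target")
        # Count positive/negative pairs where the positive outranks the negative.
        wins = 0.0
        for p in pos:
            for n in neg:
                if p > n:
                    wins += 1.0
                elif p == n:
                    wins += 0.5
        return wins / (len(pos) * len(neg))

    def reset(self) -> None:
        # Clear accumulated state, e.g. at the start of a new epoch.
        self.scores.clear()
        self.targets.clear()


metric = EpochAUROC()
batches = [
    ([0.9, 0.8], [1, 1]),  # only positives: compute() here would raise
    ([0.3, 0.6], [0, 0]),  # only negatives
]
for scores, targets in batches:
    metric.update(scores, targets)
print(metric.compute())  # → 1.0, both classes are present once accumulated
metric.reset()
```

Each batch on its own would fail, but the accumulated epoch contains both classes, so the single `compute()` call succeeds.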

Hope this helps :]

Answer selected by Borda