Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

- attempt to add support for n_views >=3 #67

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,14 @@
help='Log every n steps')
parser.add_argument('--temperature', default=0.07, type=float,
help='softmax temperature (default: 0.07)')
parser.add_argument('--n-views', default=2, type=int, metavar='N',
parser.add_argument('--n-views', default=4, type=int, metavar='N',
help='Number of views for contrastive learning training.')
parser.add_argument('--gpu-index', default=0, type=int, help='Gpu index.')


def main():
args = parser.parse_args()
assert args.n_views == 2, "Only two view training is supported. Please use --n-views 2."
# assert args.n_views == 2, "Only two view training is supported. Please use --n-views 2."
# check if gpu training is available
if not args.disable_cuda and torch.cuda.is_available():
args.device = torch.device('cuda')
Expand Down
69 changes: 41 additions & 28 deletions simclr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys

import torch
from torch import nn
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
from torch.utils.tensorboard import SummaryWriter
Expand All @@ -12,48 +13,60 @@
torch.manual_seed(0)


class SimCLR(object):

def __init__(self, *args, **kwargs):
self.args = kwargs['args']
self.model = kwargs['model'].to(self.args.device)
self.optimizer = kwargs['optimizer']
self.scheduler = kwargs['scheduler']
self.writer = SummaryWriter()
logging.basicConfig(filename=os.path.join(self.writer.log_dir, 'training.log'), level=logging.DEBUG)
self.criterion = torch.nn.CrossEntropyLoss().to(self.args.device)

def info_nce_loss(self, features):
class InfoNCELoss(nn.Module):
    """Build InfoNCE logits and targets for a batch holding several views per image.

    The module does not reduce to a scalar loss itself: ``forward`` returns a
    ``(logits, labels)`` pair intended to be passed to a cross-entropy
    criterion. Column 0 of ``logits`` holds a positive-pair similarity and the
    remaining columns hold the negatives of the same anchor, so every target
    label is 0.
    """

    def __init__(self, batch_size, n_views, temperature):
        """Store the batch layout and the softmax temperature.

        Args:
            batch_size: number of distinct images per batch.
            n_views: number of augmented views per image (at least 2).
            temperature: divisor applied to the similarity logits.
        """
        super().__init__()
        self.batch_size = batch_size
        self.n_views = n_views
        self.temperature = temperature

    def forward(self, features):
        """Return ``(logits, labels)`` for ``features`` using the stored settings."""
        return InfoNCELoss.loss_forward(features, self.batch_size, self.n_views, self.temperature)

    @staticmethod
    def loss_forward(features: torch.Tensor, batch_size: int, n_views: int, temperature: float):
        """Compute InfoNCE logits and their all-zero targets.

        Args:
            features: 2-D tensor with ``n_views * batch_size`` rows. Row ``i``
                is treated as a view of image ``i % batch_size``, i.e. the
                views are stacked as ``n_views`` consecutive blocks of
                ``batch_size`` rows.
            batch_size: number of distinct images in the batch.
            n_views: number of views per image; must be at least 2.
            temperature: divisor applied to the cosine similarities.

        Returns:
            A ``(logits, labels)`` tuple. ``logits`` has one row per ordered
            positive pair, ``n_views * batch_size * (n_views - 1)`` rows in
            total, and ``1 + (batch_size - 1) * n_views`` columns: the positive
            similarity followed by the anchor's negatives. ``labels`` is a
            ``torch.long`` vector of zeros with one entry per logits row.

        Raises:
            ValueError: if ``n_views < 2`` or the number of feature rows is not
                ``batch_size * n_views``.
        """
        if n_views < 2:
            raise ValueError(f"n_views must be at least 2, got {n_views}")
        n_samples = batch_size * n_views
        if features.shape[0] != n_samples:
            raise ValueError(
                f"expected {n_samples} feature rows (batch_size * n_views), got {features.shape[0]}"
            )
        device = features.device

        # same_image[i, j] is True when rows i and j are views of the same image.
        image_ids = torch.arange(batch_size, device=device).repeat(n_views)
        same_image = image_ids.unsqueeze(0) == image_ids.unsqueeze(1)

        # Cosine similarity between every pair of rows.
        features = F.normalize(features, dim=1)
        similarity_matrix = torch.matmul(features, features.T)

        # Discard the main diagonal (self-similarity) from both matrices.
        off_diagonal = ~torch.eye(n_samples, dtype=torch.bool, device=device)
        same_image = same_image[off_diagonal].view(n_samples, n_samples - 1)
        similarity_matrix = similarity_matrix[off_diagonal].view(n_samples, n_samples - 1)

        # Boolean indexing walks the matrix row by row, so the positives come
        # out anchor-major: all positives of row 0, then all of row 1, ...
        positives = similarity_matrix[same_image].view(n_samples * (n_views - 1), 1)

        # One row of negatives per anchor ...
        negatives = similarity_matrix[~same_image].view(n_samples, n_samples - n_views)
        # ... duplicated so each positive of an anchor is paired with that same
        # anchor's negatives. repeat_interleave keeps the anchor-major order of
        # the positives; tiling with repeat() would pair positives with the
        # negatives of a different anchor whenever n_views > 2.
        negatives = negatives.repeat_interleave(n_views - 1, dim=0)

        # The positive sits in column 0, hence every target label is 0.
        logits = torch.cat([positives, negatives], dim=1) / temperature
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=device)
        return logits, labels


class SimCLR(object):

def __init__(self, *args, **kwargs):
self.args = kwargs['args']
self.model = kwargs['model'].to(self.args.device)
self.optimizer = kwargs['optimizer']
self.scheduler = kwargs['scheduler']
self.writer = SummaryWriter()
logging.basicConfig(filename=os.path.join(self.writer.log_dir, 'training.log'), level=logging.DEBUG)
self.criterion = torch.nn.CrossEntropyLoss().to(self.args.device)
self.info_nce_loss = InfoNCELoss(self.args.batch_size, self.args.n_views, self.args.temperature)

def train(self, train_loader):

scaler = GradScaler(enabled=self.args.fp16_precision)
Expand Down