Source code for training.simclr_vanilla_train

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from tqdm import tqdm
import os
import gc

from models.load_model import load_pretrained_model
from models.simclr_vanilla import (SimCLRModel, 
                                   SimCLRVanillaDataset,
                                   contrastive_loss_vanilla,
                                   contrastive_loss_vanilla_compound_aware)
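
# ---------------------------------------------------------------------------
# Reference sketch (not this project's implementation): the imported
# contrastive_loss_vanilla is expected to behave like a standard NT-Xent
# loss over the two projection batches. The helper below is a hypothetical
# minimal version, included only to document the assumed behavior; the
# actual loss lives in models.simclr_vanilla and may differ in detail.
# ---------------------------------------------------------------------------
def _nt_xent_sketch(z1, z2, temperature=0.5):
    """Hypothetical NT-Xent sketch: z1[i] and z2[i] form the positive pair."""
    n = z1.size(0)
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)   # (2N, D) unit vectors
    sim = torch.mm(z, z.t()) / temperature               # (2N, 2N) logits
    sim.fill_diagonal_(float("-inf"))                    # never contrast with self
    # Row i's positive is its augmented twin at i+n (or i-n for the second half).
    targets = torch.cat([torch.arange(n, 2 * n),
                         torch.arange(0, n)]).to(z.device)
    return F.cross_entropy(sim, targets)
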

def train_simclr_vanilla(
    root_path="/scratch/cv-course2025/group8",
    epochs=200,
    batch_size=256,
    learning_rate=0.0003,
    temperature=0.5,
    projection_dim=128,
    device=None,
    save_every=50,
    save_dir="/scratch/cv-course2025/group8/model_weights/vanilla",
    compound_aware=False,
):
    """
    Train a vanilla SimCLR model on two augmentations of the same image,
    optionally with a compound-aware loss that excludes same-compound
    negatives.

    Args:
        root_path: Path to the BBBC021 dataset
        epochs: Number of training epochs
        batch_size: Batch size for training
        learning_rate: Learning rate for the optimizer
        temperature: Temperature parameter for the contrastive loss
        projection_dim: Output dimension of the projection head
        device: Device to train on
        save_every: Save model weights every N epochs
        save_dir: Directory to save model weights
        compound_aware: If True, use the compound-aware loss that excludes
            same-compound negatives
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    mode_str = "compound-aware" if compound_aware else "vanilla"
    print(f"Training {mode_str} SimCLR on device: {device}")

    # Strong augmentations for SimCLR. Resizing is handled in the dataset
    # initialization, and the dataset is assumed to yield image tensors
    # (Normalize requires tensor input), so only the augmentations are
    # applied here.
    simclr_transform = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomApply([
            transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                   saturation=0.4, hue=0.1)
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomApply([
            transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0))
        ], p=0.5),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # Create dataset and dataloader
    dataset = SimCLRVanillaDataset(root_path=root_path,
                                   transform=simclr_transform,
                                   compound_aware=compound_aware)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=min(4, os.cpu_count() or 1),
        drop_last=True,
    )

    # Load the pretrained ResNet50 backbone and wrap it in the SimCLR model
    backbone = load_pretrained_model("base_resnet")
    model = SimCLRModel(backbone, projection_dim=projection_dim)

    # Use DataParallel for multi-GPU training
    num_gpus = torch.cuda.device_count()
    if num_gpus > 1:
        print(f"Using {num_gpus} GPUs for training")
        model = nn.DataParallel(model)
    model = model.to(device)

    # Optimizer (no learning-rate scheduler is used)
    optimizer = optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        weight_decay=1e-4,
        betas=(0.9, 0.999),
    )

    # Create save directory
    os.makedirs(save_dir, exist_ok=True)

    # Training loop
    model.train()
    for epoch in range(epochs):
        epoch_loss = 0.0
        num_batches = 0

        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")
        for batch_data in progress_bar:
            if compound_aware:
                aug1, aug2, compound_labels = batch_data
            else:
                aug1, aug2 = batch_data
                compound_labels = None
            aug1, aug2 = aug1.to(device), aug2.to(device)

            optimizer.zero_grad()

            # Forward pass: features and projections for each augmentation
            _, z1 = model(aug1)
            _, z2 = model(aug2)

            # Compute contrastive loss
            if compound_aware:
                loss = contrastive_loss_vanilla_compound_aware(
                    z1, z2, compound_labels, temperature=temperature)
            else:
                loss = contrastive_loss_vanilla(z1, z2,
                                                temperature=temperature)

            # Backward pass
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1

            # Update progress bar
            progress_bar.set_postfix({'Loss': f'{loss.item():.4f}'})

        avg_loss = epoch_loss / num_batches
        print(f"Epoch {epoch+1}/{epochs}, Average Loss: {avg_loss:.4f}")

        # Feature collapse monitoring (every 10 epochs): if the backbone
        # collapses, all features become nearly identical, so off-diagonal
        # cosine similarities approach 1 and per-dimension variance drops.
        if (epoch + 1) % 10 == 0:
            model.eval()
            with torch.no_grad():
                # Probe with a small batch
                test_batch_data = next(iter(dataloader))
                if compound_aware:
                    aug1, _, _ = test_batch_data
                else:
                    aug1, _ = test_batch_data
                aug1 = aug1[:16].to(device)
                features1, _ = model(aug1)

                # Pairwise cosine similarities within the batch
                features1_norm = F.normalize(features1, dim=1)
                sim_matrix = torch.mm(features1_norm, features1_norm.t())
                off_diag = sim_matrix[
                    torch.eye(sim_matrix.size(0), device=device) == 0]

                print("  Feature similarity check:")
                print(f"    Mean off-diagonal similarity: {off_diag.mean():.4f}")
                print(f"    Std off-diagonal similarity: {off_diag.std():.4f}")
                if off_diag.mean() > 0.95:
                    print("  ⚠️ WARNING: High feature similarity detected!")

                # Check per-dimension standard deviation
                feature_std_by_dim = features1.std(dim=0)
                low_var_dims = (feature_std_by_dim < 0.01).sum().item()
                print(f"  Dimensions with very low variance: "
                      f"{low_var_dims}/{features1.shape[1]}")
            model.train()

        # Save backbone weights every save_every epochs
        if (epoch + 1) % save_every == 0:
            if isinstance(model, nn.DataParallel):
                backbone_state = model.module.backbone.state_dict()
            else:
                backbone_state = model.backbone.state_dict()
            save_path = os.path.join(
                save_dir, f"resnet50_simclr_vanilla_epoch_{epoch+1}.pth")
            torch.save(backbone_state, save_path)
            print(f"Model saved to {save_path}")

    # Save the final backbone weights
    if isinstance(model, nn.DataParallel):
        backbone_state = model.module.backbone.state_dict()
    else:
        backbone_state = model.backbone.state_dict()
    final_save_path = os.path.join(save_dir, "resnet50_simclr_vanilla.pth")
    torch.save(backbone_state, final_save_path)
    print(f"Final model saved to {final_save_path}")

    return model
if __name__ == "__main__":
    # Train the compound-aware SimCLR variant
    gc.collect()
    torch.cuda.empty_cache()

    model = train_simclr_vanilla(
        root_path="/scratch/cv-course2025/group8",
        epochs=200,
        batch_size=512,
        learning_rate=0.0006,
        temperature=0.3,
        projection_dim=128,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        save_every=100,
        compound_aware=True,
    )
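
# Note: only the backbone's state_dict is saved above; the projection head is
# discarded. A sketch of reloading the weights for downstream feature
# extraction, assuming load_pretrained_model("base_resnet") rebuilds the same
# architecture:
#
#     backbone = load_pretrained_model("base_resnet")
#     state = torch.load(
#         "/scratch/cv-course2025/group8/model_weights/vanilla/"
#         "resnet50_simclr_vanilla.pth",
#         map_location="cpu")
#     backbone.load_state_dict(state)
#     backbone.eval()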