Source code for training.wsdino_resnet_train

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as T
import os
from pybbbc import BBBC021
from models.wsdino_resnet import (
    BBBC021WeakLabelDataset,
    MultiCropTransform,
    get_resnet50,
    SimpleDINOLoss,
    update_teacher
)
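
The helpers imported above are defined in models.wsdino_resnet and are not shown on this page. For orientation, the call update_teacher(student, teacher, m=momentum) in the training loop below is assumed to perform the standard DINO-style exponential moving average (EMA) update; the following is a minimal sketch of that assumed behavior, not the project's actual definition:

# Illustrative sketch only: assumed EMA behavior of update_teacher,
# not the definition from models.wsdino_resnet.
@torch.no_grad()
def update_teacher_sketch(student, teacher, m=0.99):
    """Blend each teacher parameter toward the student: t <- m * t + (1 - m) * s."""
    for p_s, p_t in zip(student.parameters(), teacher.parameters()):
        p_t.data.mul_(m).add_((1.0 - m) * p_s.data)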

[docs]
def train_wsdino(
    root_path="/scratch/cv-course2025/group8",
    epochs=200,
    batch_size=16,
    lr=4e-6,
    momentum=0.99,
    temperature=0.04,
    proj_dim=128,
    save_every=50
):
    """
    Train WS-DINO with a ResNet50 backbone on the BBBC021 dataset.

    Args:
        root_path: Path to the BBBC021 dataset
        epochs: Number of training epochs
        batch_size: Batch size for training
        lr: Learning rate for the optimizer
        momentum: EMA momentum for the teacher update and the loss centering
        temperature: Teacher temperature for the DINO-style loss
        proj_dim: Output dimension of the projection head
        save_every: Save model weights every N epochs
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # Use DataParallel for multi-GPU training
    num_gpus = torch.cuda.device_count()
    print(f"Using {num_gpus} GPUs for training")

    # Load data
    # bbbc = BBBC021(root_path=os.path.join(root_path, "bbbc021"))
    bbbc = BBBC021(root_path)
    print("loaded data")

    # Global crops: large views seen by the teacher (and student)
    global_transform = T.Compose([
        T.RandomResizedCrop(224, scale=(0.4, 1.0)),
        T.RandomHorizontalFlip(p=0.5),
        T.RandomVerticalFlip(p=0.5),
        # T.ColorJitter(0.4, 0.4, 0.4, 0.1),
        # T.ToTensor(),
        # T.Resize((224, 224)),
        T.Normalize(mean=[0.485] * 3, std=[0.229] * 3)
    ])

    # Local crops: small views seen only by the student
    local_transform = T.Compose([
        T.RandomResizedCrop(96, scale=(0.05, 0.4)),
        T.RandomHorizontalFlip(p=0.5),
        T.RandomVerticalFlip(p=0.5),
        # T.ColorJitter(0.4, 0.4, 0.4, 0.1),
        T.Normalize(mean=[0.485] * 3, std=[0.229] * 3)
    ])

    # Combine into a multi-crop transform
    transform = MultiCropTransform(global_transform, local_transform, num_local_crops=6)

    # Create dataset and dataloader
    dataset = BBBC021WeakLabelDataset(bbbc, transform=transform)
    print("made dataset")

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=min(10, os.cpu_count()),
        pin_memory=True,
        drop_last=True
    )
    print("dataloader created")

    num_compounds = len(dataset.compound_to_idx)

    # Create models
    student = get_resnet50(num_classes=num_compounds, proj_dim=proj_dim, model_type="base_resnet")
    print("student initialized")
    teacher = get_resnet50(num_classes=num_compounds, proj_dim=proj_dim, model_type="base_resnet")
    print("teacher initialized")

    # Alternatively, load pretrained weights from file:
    # student = get_resnet50(num_classes=num_compounds, proj_dim=proj_dim, model_type="wsdino")
    # teacher = get_resnet50(num_classes=num_compounds, proj_dim=proj_dim, model_type="wsdino")

    # Use DataParallel for multi-GPU training
    if num_gpus > 1:
        student = nn.DataParallel(student)
        teacher = nn.DataParallel(teacher)

    student = student.to(device)
    teacher = teacher.to(device)

    # Teacher starts as a copy of the student and is never updated by backprop
    teacher.load_state_dict(student.state_dict())
    for p in teacher.parameters():
        p.requires_grad = False

    # Optimizer
    # optimizer = torch.optim.Adam(student.parameters(), lr=lr)
    optimizer = torch.optim.AdamW(
        student.parameters(),
        lr=lr,
        weight_decay=0.0,
        betas=(0.9, 0.999)
    )

    criterion = SimpleDINOLoss(
        out_dim=proj_dim,
        ncrops=8,  # 2 global + 6 local crops
        student_temp=0.1,
        teacher_temp=temperature,
        center_momentum=momentum
    )

    # Create save directory
    save_dir = os.path.join(root_path, "model_weights")
    os.makedirs(save_dir, exist_ok=True)

    print("starting training loop")

    # Training loop
    for epoch in range(epochs):
        student.train()
        total_loss = 0

        for crops1, crops2, _, _ in dataloader:
            crops1 = [c.to(device) for c in crops1]  # teacher views
            crops2 = [c.to(device) for c in crops2]  # student views

            # student_outputs = torch.cat([student(crop) for crop in crops2], dim=0)
            # with torch.no_grad():
            #     teacher_outputs = torch.cat([teacher(crop) for crop in crops1], dim=0)

            # Flatten all views into a single batch; this requires all views in
            # each list to share the same spatial size
            all_student_views = torch.cat(crops2, dim=0)  # shape [num_crops * B, C, H, W]
            all_teacher_views = torch.cat(crops1, dim=0)  # shape [2 * B, C, H, W]

            # Forward passes
            student_outputs = student(all_student_views)
            with torch.no_grad():
                teacher_outputs = teacher(all_teacher_views)

            # Compute DINO loss
            loss = criterion(student_outputs, teacher_outputs)

            # Backward & optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update teacher as an EMA of the student
            update_teacher(student, teacher, m=momentum)

            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")

        # Save a checkpoint every save_every epochs (plus an early
        # sanity-check save after epoch 2)
        if (epoch + 1) % save_every == 0 or epoch == 1:
            model_to_save = student.module if isinstance(student, nn.DataParallel) else student
            backbone_state = model_to_save.state_dict()
            save_path = os.path.join(save_dir, f"resnet50_wsdino_epoch_{epoch+1}.pth")
            torch.save(backbone_state, save_path)
            print(f"Model saved to {save_path}")

        torch.cuda.empty_cache()

    # Save the final model, unwrapped so the state_dict keys carry no
    # "module." prefix (consistent with the periodic checkpoints above)
    model_to_save = student.module if isinstance(student, nn.DataParallel) else student
    torch.save(model_to_save.state_dict(), os.path.join(save_dir, "resnet50_wsdino.pth"))
    print("Model saved to resnet50_wsdino.pth")
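
SimpleDINOLoss is likewise imported rather than defined on this page. As a reference for how such a loss typically works, the sketch below follows the loss from the original DINO recipe: teacher outputs are centered and sharpened with a low temperature, student outputs use a higher temperature, and the loss averages the cross-entropy over teacher/student crop pairs that do not show the same view, with the center updated by EMA. The class name and internals here are assumptions; the project's actual implementation may differ.

# Illustrative sketch of a DINO-style loss; SimpleDINOLoss in
# models.wsdino_resnet may be implemented differently.
import torch
import torch.nn as nn
import torch.nn.functional as F

class DINOLossSketch(nn.Module):
    def __init__(self, out_dim, ncrops, student_temp=0.1,
                 teacher_temp=0.04, center_momentum=0.99):
        super().__init__()
        self.ncrops = ncrops
        self.student_temp = student_temp
        self.teacher_temp = teacher_temp
        self.center_momentum = center_momentum
        # Running center of teacher outputs, used to avoid collapse
        self.register_buffer("center", torch.zeros(1, out_dim))

    def forward(self, student_output, teacher_output):
        # Student logits, one chunk per crop: ncrops tensors of shape (B, D)
        student_out = (student_output / self.student_temp).chunk(self.ncrops)
        # Teacher probabilities for the 2 global crops, centered and sharpened
        teacher_out = F.softmax(
            (teacher_output - self.center) / self.teacher_temp, dim=-1
        ).detach().chunk(2)

        total_loss, n_terms = 0.0, 0
        for t_idx, t in enumerate(teacher_out):
            for s_idx, s in enumerate(student_out):
                if s_idx == t_idx:
                    continue  # skip pairs where both networks see the same view
                total_loss += torch.sum(-t * F.log_softmax(s, dim=-1), dim=-1).mean()
                n_terms += 1
        self.update_center(teacher_output)
        return total_loss / n_terms

    @torch.no_grad()
    def update_center(self, teacher_output):
        # EMA update of the center with the current batch statistics
        batch_center = teacher_output.mean(dim=0, keepdim=True)
        self.center = (self.center * self.center_momentum
                       + batch_center * (1 - self.center_momentum))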
if __name__ == "__main__":
    train_wsdino(
        root_path="/scratch/cv-course2025/group8",
        epochs=100,
        batch_size=16,
        lr=4e-6,
        momentum=0.99,
        temperature=0.04,
        proj_dim=128,
        save_every=50
    )
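
A hypothetical usage example for the checkpoints written above: the saved state_dict is unwrapped (no "module." prefix), so it can be loaded straight into a freshly built model for evaluation or feature extraction. The path and num_compounds value below are placeholders, not values taken from this module.

# Hypothetical loading sketch; the path and num_compounds are placeholders.
import torch
from models.wsdino_resnet import get_resnet50

num_compounds = 38  # placeholder: use len(dataset.compound_to_idx) from training
model = get_resnet50(num_classes=num_compounds, proj_dim=128, model_type="base_resnet")
state = torch.load("model_weights/resnet50_wsdino.pth", map_location="cpu")
model.load_state_dict(state)
model.eval()  # disable dropout and freeze batch-norm statistics for inference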