Models#
Models module for compound profiling.
Model Loading#
- models.load_model.load_pretrained_model(model_name: Literal['base_resnet', 'simclr', 'wsdino'], weight_path='/scratch/cv-course2025/group8/model_weights')[source]#
Load a pretrained model ('base_resnet', 'simclr', or 'wsdino') from the saved model weights.
- models.load_model.load_pretrained_resnet50(weights: str = 'IMAGENET1K_V2') → object [source]#
Load pretrained ResNet50 model.
- Parameters:
weights – Model weights to use
Note
See https://docs.pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html#torchvision.models.ResNet50_Weights for available weights.
- models.load_model.load_pretrained_model_from_weights(model_name: str, weight_path: str) → Module [source]#
Load pretrained ResNet50 model from custom weights.
- Parameters:
model_name – Name of the model to load
weight_path – Path to the weights file
- Returns:
Pretrained ResNet50 model
- Return type:
nn.Module
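A minimal usage sketch, assuming the package is importable as models, the default weight directory holds the trained checkpoints, the returned model can be called like an nn.Module, and a 224×224 RGB input (the input size is an assumption):

import torch
from models.load_model import load_pretrained_model, load_pretrained_resnet50

# Plain ImageNet backbone; see the torchvision ResNet50_Weights page linked above.
backbone = load_pretrained_resnet50(weights="IMAGENET1K_V2")

# One of our trained models ('base_resnet', 'simclr', or 'wsdino').
model = load_pretrained_model("base_resnet")
model.eval()

with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))  # assumed input size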
SimCLR Vanilla#
- class models.simclr_vanilla.SimCLRProjectionHead(input_dim=2048, hidden_dim=512, output_dim=128)[source]#
Bases: Module
Projection head for SimCLR
- __init__(input_dim=2048, hidden_dim=512, output_dim=128)[source]#
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class models.simclr_vanilla.SimCLRModel(backbone_model, projection_dim=128)[source]#
Bases: Module
This is our model for vanilla SimCLR, with a ResNet50 backbone.
- __init__(backbone_model, projection_dim=128)[source]#
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class models.simclr_vanilla.SimCLRVanillaDataset(root_path, transform=None, compound_aware=False)[source]#
Bases: Dataset
Dataset for (vanilla) SimCLR: returns two augmentations of the same image. Optionally returns compound labels for compound-aware training.
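A sketch of how the dataset might be consumed; the root path and augmentation pipeline are placeholders, and the per-item layout (two views, plus a compound label when compound_aware=True) is assumed from the description above:

from torch.utils.data import DataLoader
from torchvision import transforms
from models.simclr_vanilla import SimCLRVanillaDataset

augment = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

dataset = SimCLRVanillaDataset("/path/to/images", transform=augment, compound_aware=True)
loader = DataLoader(dataset, batch_size=64, shuffle=True)

for view1, view2, compound_labels in loader:  # assumed item layout
    pass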
- models.simclr_vanilla.contrastive_loss_vanilla(z1, z2, temperature=0.5)[source]#
Standard SimCLR NT-Xent loss for vanilla training. No labels are needed; positive pairs are augmentations of the same image.
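A minimal training-step sketch, assuming the ImageNet ResNet50 from load_pretrained_resnet50 can serve directly as the backbone_model and that the model's forward returns the projected embeddings; random tensors stand in for the two augmented views:

import torch
from models.load_model import load_pretrained_resnet50
from models.simclr_vanilla import SimCLRModel, contrastive_loss_vanilla

model = SimCLRModel(load_pretrained_resnet50(), projection_dim=128)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

view1 = torch.randn(32, 3, 224, 224)  # first augmentation of the batch
view2 = torch.randn(32, 3, 224, 224)  # second augmentation of the same images

z1, z2 = model(view1), model(view2)   # assumed to return projected features
loss = contrastive_loss_vanilla(z1, z2, temperature=0.5)
loss.backward()
optimizer.step()
optimizer.zero_grad()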
- models.simclr_vanilla.contrastive_loss_vanilla_compound_aware(z1, z2, compound_labels, temperature=0.5)[source]#
Compound-aware SimCLR NT-Xent loss for vanilla training. Excludes same-compound pairs from being treated as negatives. This is my attempt at a less aggressive weakly-supervised version than the simclr_ws.py one, as I couldn't believe our labels would not help at all.
- Parameters:
z1 – Projected features from positive pairs (augmentations of same image).
z2 – Projected features from positive pairs (augmentations of same image).
compound_labels – Compound labels for each image in the batch.
temperature – Temperature parameter for softmax.
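Conceptually, the compound-aware exclusion amounts to masking out same-compound entries of the similarity matrix so they never act as negatives; a small illustration of such a mask (not necessarily the exact implementation):

import torch
import torch.nn.functional as F

z = F.normalize(torch.randn(8, 128), dim=1)              # projected features for 8 images
compound_labels = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3])  # compound label per image

sim = z @ z.T                                             # pairwise cosine similarities
same_compound = compound_labels[:, None] == compound_labels[None, :]
negative_mask = ~same_compound                            # only cross-compound pairs count as negatives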
SimCLR Weakly-Supervised#
- class models.simclr_ws.SimCLRProjectionHead(input_dim=2048, hidden_dim=512, output_dim=256)[source]#
Bases: Module
Projection head for SimCLR
- __init__(input_dim=2048, hidden_dim=512, output_dim=256)[source]#
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class models.simclr_ws.SimCLRModel(backbone_model, projection_dim=256)[source]#
Bases: Module
This is our approach to WS-SimCLR, where we use weak labels to compose positive pairs. In particular, we use the compound labels to create positive pairs, where each pair consists of two images from the same compound but from different wells or plates (to introduce some noise and prevent the model from learning plate/well-specific features). All negative pairs are images from different compounds.
- __init__(backbone_model, projection_dim=256)[source]#
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class models.simclr_ws.SimCLRDataset(root_path, transform=None)[source]#
Bases: Dataset
Dataset for SimCLR training with compound-based positive pairs
- models.simclr_ws.contrastive_loss(z1, z2, labels, temperature=0.1)[source]#
This was our implementation of the NT-Xent loss function, the contrastive loss used in SimCLR. It computes the loss based on the similarity of projected features from positive pairs. We made a small change to the loss function so that images from the same compound are ignored as negatives, so the model does not try to separate them in the feature space. Positive pairs are images from the same compound, but from different wells/plates.
- Parameters:
z1 – Projected features from positive pairs.
z2 – Projected features from positive pairs.
labels – Compound labels for each pair in the batch.
temperature – Temperature parameter for softmax.
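A sketch of a weakly-supervised training step, assuming the forward returns projected features, that a torchvision ResNet50 can serve as backbone_model, and that each (x1, x2) pair holds two images of the same compound from different wells or plates:

import torch
from models.load_model import load_pretrained_resnet50
from models.simclr_ws import SimCLRModel, contrastive_loss

model = SimCLRModel(load_pretrained_resnet50(), projection_dim=256)

x1 = torch.randn(16, 3, 224, 224)      # image of compound c (well/plate A)
x2 = torch.randn(16, 3, 224, 224)      # image of the same compound (well/plate B)
labels = torch.randint(0, 8, (16,))    # compound label per pair

z1, z2 = model(x1), model(x2)          # assumed to return projected features
loss = contrastive_loss(z1, z2, labels, temperature=0.1)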
- class models.simclr_ws.LARS(params, lr=1.0, momentum=0.9, weight_decay=0.0001, eta=0.001)[source]#
Bases: Optimizer
LARS optimizer implementation. Although this optimizer is designed for large batch sizes, we still wanted to implement it since it is the optimizer used in the original SimCLR paper.
However, given our computational resources, we use AdamW instead, which is more suitable for our setup.
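The class follows the usual torch.optim interface; a sketch of constructing it with the documented defaults, alongside the AdamW setup actually used (the AdamW hyperparameters here are placeholders):

import torch
from models.simclr_ws import LARS

params = [torch.nn.Parameter(torch.randn(10, 10))]

lars = LARS(params, lr=1.0, momentum=0.9, weight_decay=0.0001, eta=0.001)
adamw = torch.optim.AdamW(params, lr=3e-4, weight_decay=1e-4)  # what we use in practice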
WS-DINO ResNet#
- class models.wsdino_resnet.BBBC021WeakLabelDataset(bbbc021, transform=None)[source]#
Bases: Dataset
PyTorch-compatible dataset wrapper for BBBC021 using weak labels (compound IDs).
- Each sample consists of:
An image tensor (optionally transformed)
A weak label: the compound ID (as an index)
Filters out samples with unknown MoA (‘null’).
- class models.wsdino_resnet.DINOProjectionHead(in_dim, proj_dim=256, hidden_dim=2048)[source]#
Bases: Module
DINO-style Projection Head for Self-Supervised Learning.
- This module implements a 3-layer MLP with:
Two hidden layers of size 2048 with GELU activations
An output layer of configurable dimension (default: 256) without activation
L2 normalization applied to the output
A weight-normalized linear layer as the final projection
- Parameters:
in_dim – Input feature dimension
proj_dim – Output dimension of the projection head (default: 256)
hidden_dim – Hidden layer size (default: 2048)
- __init__(in_dim, proj_dim=256, hidden_dim=2048)[source]#
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
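A sketch of a head matching the description above (illustrative only; the project's exact layer ordering and dimensions may differ):

import torch
import torch.nn as nn
import torch.nn.functional as F

class DINOHeadSketch(nn.Module):
    """Illustrative 3-layer MLP head in the DINO style described above."""

    def __init__(self, in_dim, proj_dim=256, hidden_dim=2048):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden_dim), nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim), nn.GELU(),
            nn.Linear(hidden_dim, proj_dim),            # output layer, no activation
        )
        # weight-normalised final projection
        self.last_layer = nn.utils.weight_norm(nn.Linear(proj_dim, proj_dim, bias=False))

    def forward(self, x):
        x = self.mlp(x)
        x = F.normalize(x, dim=-1)                      # L2 normalisation of the output
        return self.last_layer(x)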
- models.wsdino_resnet.get_resnet50(num_classes=None, use_projection_head=True, proj_dim=256, model_type='base_resnet')[source]#
Loads a ResNet-50 backbone with the final layer replaced for classification. Optionally adds a DINO-style projection head.
- Parameters:
num_classes – Number of output classes (e.g., number of unique MoAs); used only if use_projection_head is False
use_projection_head – If True, attach DINO-style projection head
proj_dim – Output dimension of projection head
model_type – should be “base_resnet” or “wsdino”
- Returns:
A torch.nn.Module (ResNet-50 with custom head)
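A sketch of building a student/teacher pair for WS-DINO with this helper; the weight copy and teacher freezing reflect standard DINO practice and are assumptions, not documented behaviour:

from models.wsdino_resnet import get_resnet50

student = get_resnet50(use_projection_head=True, proj_dim=256, model_type="wsdino")
teacher = get_resnet50(use_projection_head=True, proj_dim=256, model_type="wsdino")

teacher.load_state_dict(student.state_dict())  # start the teacher from the student
for p in teacher.parameters():
    p.requires_grad = False                    # the teacher is updated only via EMA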
- models.wsdino_resnet.dino_loss(student_out, teacher_out, temp=0.07)[source]#
Computes the DINO distillation loss between student and teacher outputs.
- Parameters:
student_out – Logits from the student network
teacher_out – Logits from the teacher network (detached)
temp – Temperature scaling parameter
- Returns:
KL divergence loss between student and teacher probability distributions
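For reference, a common form of such a distillation loss, sketched under the assumption of a single shared temperature and no centering (the project's implementation may differ):

import torch.nn.functional as F

def dino_loss_sketch(student_out, teacher_out, temp=0.07):
    teacher_probs = F.softmax(teacher_out.detach() / temp, dim=-1)
    student_logp = F.log_softmax(student_out / temp, dim=-1)
    # cross-entropy of the student against the teacher; equal to the KL divergence
    # up to a constant (the teacher entropy), so the gradients match
    return -(teacher_probs * student_logp).sum(dim=-1).mean()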
- models.wsdino_resnet.update_teacher(student, teacher, m=0.996)[source]#
Updates teacher weights using exponential moving average of student weights.
- Parameters:
student – Student model (nn.Module)
teacher – Teacher model (nn.Module)
m – Momentum factor (closer to 1 = slower update)
- Returns:
None (teacher model is updated in place)
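The documented behaviour corresponds to a standard exponential-moving-average update; a sketch of the update rule and where it sits in a training step (the commented loop is illustrative):

import torch

@torch.no_grad()
def update_teacher_sketch(student, teacher, m=0.996):
    # p_teacher <- m * p_teacher + (1 - m) * p_student, applied in place
    for p_s, p_t in zip(student.parameters(), teacher.parameters()):
        p_t.data.mul_(m).add_(p_s.data, alpha=1.0 - m)

# Typical placement in a training step (illustrative):
# loss = dino_loss(student(view1), teacher(view2))
# loss.backward(); optimizer.step(); optimizer.zero_grad()
# update_teacher(student, teacher, m=0.996)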