Models

Models module for compound profiling.

Model Loading

models.load_model.load_pretrained_model(model_name: Literal['base_resnet', 'simclr', 'wsdino'], weight_path='/scratch/cv-course2025/group8/model_weights')

Load the pretrained ResNet50-based model selected by model_name ('base_resnet', 'simclr' or 'wsdino').

models.load_model.load_pretrained_resnet50(weights: str = 'IMAGENET1K_V2') → object

Load a ResNet50 model initialized with pretrained weights (ImageNet weights by default).

Parameters:

weights – Name of the pretrained weights to use (default: 'IMAGENET1K_V2')

models.load_model.load_pretrained_model_from_weights(model_name: str, weight_path: str) → Module

Load a pretrained ResNet50 model from a custom weights file.

Parameters:
  • model_name – Name of the model to load

  • weight_path – Path to the weights file

Returns:

Pretrained ResNet50 model

Return type:

nn.Module

models.load_model.create_feature_extractor(pretrained_model: Module) → Module

Create a feature extractor from a pretrained ResNet50 model.

Parameters:

pretrained_model – Pretrained ResNet50 model

Returns:

Feature extractor model without the final classification layer

Return type:

nn.Module

SimCLR Vanilla

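class models.simclr_vanilla.SimCLRProjectionHead(input_dim=2048, hidden_dim=512, output_dim=128)

Bases: Module

Projection head for SimCLR. Maps input_dim-dimensional backbone features (2048 for ResNet50) to an output_dim-dimensional embedding through a hidden layer of size hidden_dim.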

__init__(input_dim=2048, hidden_dim=512, output_dim=128)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
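The SimCLR paper uses a small MLP with one hidden ReLU layer as its projection head. A sketch assuming SimCLRProjectionHead follows that pattern with the documented default sizes; the ReLU and the absence of normalization layers are assumptions, and ProjectionHeadSketch is a hypothetical name.

```python
import torch
import torch.nn as nn


class ProjectionHeadSketch(nn.Module):
    """Two-layer MLP projection head: input_dim -> hidden_dim -> output_dim."""

    def __init__(self, input_dim: int = 2048, hidden_dim: int = 512, output_dim: int = 128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


head = ProjectionHeadSketch()
# Call the module itself rather than head.forward(...) so registered hooks run.
z = head(torch.randn(4, 2048))
```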

class models.simclr_vanilla.SimCLRModel(backbone_model, projection_dim=128)

Bases: Module

Our model for vanilla SimCLR, built on a ResNet50 backbone.

__init__(backbone_model, projection_dim=128)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
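How a backbone and a projection head typically fit together in SimCLR, as a sketch: the backbone representation h is what gets used downstream, while the projection z only feeds the contrastive loss. SimCLRModelSketch, its feature_dim argument and the (h, z) return value are assumptions, not the documented interface; the toy backbone stands in for a ResNet50 feature extractor.

```python
import torch
import torch.nn as nn


class SimCLRModelSketch(nn.Module):
    """Backbone followed by a projection head; returns both representations."""

    def __init__(self, backbone_model: nn.Module, feature_dim: int = 2048, projection_dim: int = 128):
        super().__init__()
        self.backbone = backbone_model  # assumed to output (N, feature_dim) features
        self.projection_head = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, projection_dim),
        )

    def forward(self, x: torch.Tensor):
        h = self.backbone(x)         # representation kept for downstream use
        z = self.projection_head(h)  # embedding fed to the contrastive loss
        return h, z


# Toy stand-in backbone for illustration only.
toy_backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 2048))
model = SimCLRModelSketch(toy_backbone)
h, z = model(torch.randn(5, 3, 8, 8))
```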

class models.simclr_vanilla.SimCLRVanillaDataset(root_path, transform=None, compound_aware=False)

Bases: Dataset

Dataset for (vanilla) SimCLR: returns two augmentations of the same image. If compound_aware is True, it also returns the compound label of each image for compound-aware training.

__init__(root_path, transform=None, compound_aware=False)

models.simclr_vanilla.contrastive_loss_vanilla(z1, z2, temperature=0.5)

Standard SimCLR NT-Xent loss for vanilla training. No labels are needed: positive pairs are two augmentations of the same image.
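A minimal NT-Xent sketch over the 2N views of a batch, assuming z1[i] and z2[i] are the two views of image i. nt_xent_loss is a hypothetical name; the module's own implementation may differ in details such as where normalization happens.

```python
import torch
import torch.nn.functional as F


def nt_xent_loss(z1: torch.Tensor, z2: torch.Tensor, temperature: float = 0.5) -> torch.Tensor:
    """NT-Xent over 2N embeddings: each view's positive is the other view of the same image."""
    n = z1.shape[0]
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)  # (2N, d), unit length
    sim = z @ z.T / temperature                          # cosine similarities / tau
    # A view must never be its own negative: mask the diagonal out of the softmax.
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, float("-inf"))
    # Row i (first view) has its positive at i + n, and row i + n at i.
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    return F.cross_entropy(sim, targets)


torch.manual_seed(0)
a = torch.randn(8, 128)
aligned = nt_xent_loss(a, a + 0.01 * torch.randn(8, 128))  # near-identical views
unrelated = nt_xent_loss(a, torch.randn(8, 128))           # unrelated "views"
```

Each view has 2N - 2 negatives: both views of every other image in the batch.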

models.simclr_vanilla.contrastive_loss_vanilla_compound_aware(z1, z2, compound_labels, temperature=0.5)

Compound-aware SimCLR NT-Xent loss for vanilla training. Same-compound pairs are excluded from the negatives. This is my attempt at a less aggressive weakly supervised (WS) variant than the one in simclr_ws.py, because I couldn’t believe our labels would not help at all.

Parameters:
  • z1 – Projected features of the first augmentation of each image.

  • z2 – Projected features of the second augmentation of each image.

  • compound_labels – Compound labels for each image in the batch.

  • temperature – Temperature parameter for softmax.
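A sketch of how same-compound pairs could be excluded, assuming the implementation removes them from the softmax denominator while keeping each image's own second augmentation as its positive. nt_xent_compound_aware is a hypothetical name. With all-distinct labels it reduces to standard NT-Xent.

```python
import torch
import torch.nn.functional as F


def nt_xent_compound_aware(z1, z2, compound_labels, temperature=0.5):
    """NT-Xent in which other images of the same compound are not used as negatives."""
    n = z1.shape[0]
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)  # (2N, d), unit length
    sim = z @ z.T / temperature
    labels = torch.cat([compound_labels, compound_labels]).to(z.device)  # label per view
    idx = torch.arange(2 * n, device=z.device)
    pos_idx = torch.cat([idx[n:], idx[:n]])  # view i is paired with view i +/- N
    same_compound = labels.unsqueeze(0) == labels.unsqueeze(1)  # (2N, 2N) boolean
    is_positive = torch.zeros_like(same_compound)
    is_positive[idx, pos_idx] = True
    # Mask self-similarities and same-compound non-positives out of the denominator.
    # The diagonal is covered because every view shares its own label.
    exclude = same_compound & ~is_positive
    sim = sim.masked_fill(exclude, float("-inf"))
    return F.cross_entropy(sim, pos_idx)


torch.manual_seed(0)
z1, z2 = torch.randn(6, 32), torch.randn(6, 32)
loss_distinct = nt_xent_compound_aware(z1, z2, torch.arange(6))
loss_grouped = nt_xent_compound_aware(z1, z2, torch.tensor([0, 0, 1, 1, 2, 2]))
loss_single = nt_xent_compound_aware(z1, z2, torch.zeros(6, dtype=torch.long))
```

Removing same-compound negatives can only shrink each denominator, so loss_grouped is lower than loss_distinct; with a single compound, only the positive remains and the loss is zero. The same masking idea is described for simclr_ws.contrastive_loss, where z1 and z2 instead come from different images of the same compound.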

SimCLR Weakly-Supervised

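class models.simclr_ws.SimCLRProjectionHead(input_dim=2048, hidden_dim=512, output_dim=256)

Bases: Module

Projection head for SimCLR. Maps input_dim-dimensional backbone features to an output_dim-dimensional embedding through a hidden layer of size hidden_dim.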

__init__(input_dim=2048, hidden_dim=512, output_dim=256)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class models.simclr_ws.SimCLRModel(backbone_model, projection_dim=256)

Bases: Module

Our approach to weakly supervised SimCLR (WS-SimCLR), in which weak labels are used to form positive pairs. Specifically, we use the compound labels: each positive pair consists of two images of the same compound taken from different wells or plates, which adds some noise and prevents the model from learning plate- or well-specific features. All negative pairs are images of different compounds.

__init__(backbone_model, projection_dim=256)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class models.simclr_ws.SimCLRDataset(root_path, transform=None)

Bases: Dataset

Dataset for SimCLR training with compound-based positive pairs.

__init__(root_path, transform=None)
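A sketch of compound-based positive-pair sampling as described for SimCLRModel above: both images share a compound but come from different (plate, well) locations. The record layout (dicts with path, compound, plate and well keys) and sample_positive_pair are hypothetical; the toy records are illustrative only.

```python
import random
from collections import defaultdict


def sample_positive_pair(records, compound, rng=random):
    """Pick two images of `compound` from two different (plate, well) locations."""
    by_location = defaultdict(list)
    for r in records:
        if r["compound"] == compound:
            by_location[(r["plate"], r["well"])].append(r["path"])
    locations = list(by_location)
    if len(locations) < 2:
        raise ValueError(f"compound {compound!r} appears in fewer than two wells/plates")
    loc_a, loc_b = rng.sample(locations, 2)  # two distinct locations
    return rng.choice(by_location[loc_a]), rng.choice(by_location[loc_b])


# Toy records: compound_A appears in two different wells.
records = [
    {"path": "img0.tif", "compound": "compound_A", "plate": "P1", "well": "B02"},
    {"path": "img1.tif", "compound": "compound_A", "plate": "P1", "well": "B02"},
    {"path": "img2.tif", "compound": "compound_A", "plate": "P2", "well": "C05"},
    {"path": "img3.tif", "compound": "compound_B", "plate": "P1", "well": "D07"},
]
pair = sample_positive_pair(records, "compound_A", rng=random.Random(0))
```

Sampling locations first, rather than images, guarantees the two images never come from the same well of the same plate.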

models.simclr_ws.contrastive_loss(z1, z2, labels, temperature=0.1)

Our version of the NT-Xent loss, the contrastive loss used in SimCLR. It computes the loss from the similarity of the projected features of positive pairs. We made one small change: images of the same compound are ignored as negatives, so the model does not try to separate them in feature space. Positive pairs are images of the same compound from different wells or plates.

Parameters:
  • z1 – Projected features of the first image in each positive pair.

  • z2 – Projected features of the second image in each positive pair (same compound, different well or plate).

  • labels – Compound labels for each pair in the batch.

  • temperature – Temperature parameter for softmax.

class models.simclr_ws.LARS(params, lr=1.0, momentum=0.9, weight_decay=0.0001, eta=0.001)

Bases: Optimizer

LARS (Layer-wise Adaptive Rate Scaling) optimizer. LARS is designed for very large batch sizes; we implemented it anyway because it is the optimizer used in the original SimCLR paper.

Given our computational resources, however, we use AdamW instead, as it suits our setup better.

__init__(params, lr=1.0, momentum=0.9, weight_decay=0.0001, eta=0.001)

step(closure=None)

Perform a single optimization step to update the parameters.

Parameters:

closure (Callable) – A closure that reevaluates the model and returns the loss. Optional for most optimizers.
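The core of LARS is a per-layer trust ratio that rescales the global learning rate. A minimal single-tensor sketch of the update rule from the LARS paper (You et al., 2017, "Large Batch Training of Convolutional Networks"); whether the class above uses exactly this form, for example how it handles weight decay or excludes biases, is an assumption, and lars_update is a hypothetical name.

```python
import torch


@torch.no_grad()
def lars_update(param, grad, momentum_buf, lr=1.0, momentum=0.9, weight_decay=1e-4, eta=1e-3):
    """One LARS step for a single parameter tensor (updates param and momentum_buf in place)."""
    w_norm = torch.norm(param)
    g_norm = torch.norm(grad)
    # Layer-wise trust ratio: eta * ||w|| / (||g|| + weight_decay * ||w||); fall back to 1.
    if w_norm > 0 and g_norm > 0:
        trust_ratio = eta * w_norm / (g_norm + weight_decay * w_norm)
    else:
        trust_ratio = torch.tensor(1.0)
    update = grad + weight_decay * param
    momentum_buf.mul_(momentum).add_(lr * trust_ratio * update)
    param.sub_(momentum_buf)


w = torch.tensor([3.0, 4.0])  # ||w|| = 5
g = torch.tensor([0.6, 0.8])  # ||g|| = 1
buf = torch.zeros_like(w)
lars_update(w, g, buf, lr=1.0, momentum=0.9, weight_decay=0.0, eta=1e-3)
# trust ratio = 1e-3 * 5 / 1 = 0.005, so w moves by 0.005 * g
```

Because the step size scales with ||w|| / ||g||, every layer moves by a similar relative amount regardless of its gradient magnitude, which is what keeps very large batches stable.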

WS-DINO ResNet

class models.wsdino_resnet.BBBC021WeakLabelDataset(bbbc021, transform=None)

Bases: Dataset

PyTorch-compatible dataset wrapper for BBBC021 using weak labels (compound IDs).

Each sample consists of:
  • An image tensor (optionally transformed)

  • A weak label: the compound ID (as an index)

Filters out samples with unknown MoA (‘null’).

__init__(bbbc021, transform=None)

Parameters:
  • bbbc021 – An instance of pybbbc.BBBC021

  • transform – A torchvision transform to apply to each image
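A sketch of the two steps described above: filtering out samples whose MoA is 'null' and mapping compound IDs to integer indices. It assumes a hypothetical list of (image, compound, moa) tuples and does not use the pybbbc API; build_weak_labels is a hypothetical name.

```python
def build_weak_labels(samples):
    """Keep samples with a known MoA and map each compound name to an integer index.

    `samples` is a hypothetical list of (image, compound, moa) tuples.
    """
    kept = [(img, cpd) for img, cpd, moa in samples if moa != "null"]
    compounds = sorted({cpd for _, cpd in kept})  # deterministic ordering
    compound_to_idx = {cpd: i for i, cpd in enumerate(compounds)}
    return [(img, compound_to_idx[cpd]) for img, cpd in kept], compound_to_idx


samples = [
    ("img_a", "compound_B", "moa_1"),
    ("img_b", "compound_A", "moa_2"),
    ("img_c", "compound_C", "null"),  # unknown MoA: filtered out
    ("img_d", "compound_B", "moa_1"),
]
items, compound_to_idx = build_weak_labels(samples)
```

Sorting the compound names makes the index assignment reproducible across runs.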

class models.wsdino_resnet.DINOProjectionHead(in_dim, proj_dim=256, hidden_dim=2048)

Bases: Module

DINO-style Projection Head for Self-Supervised Learning.

This module implements a 3-layer MLP with:
  • Two hidden layers of size hidden_dim (default: 2048) with GELU activations

  • An output layer of configurable dimension (default: 256) without activation

  • L2 normalization applied to the output

  • A weight-normalized linear layer as the final projection

Parameters:
  • in_dim (int) – Input feature dimension (e.g., 2048 for ResNet-50).

  • hidden_dim (int) – Hidden layer dimension (default: 2048).

  • proj_dim (int) – Output projection dimension (default: 256).

__init__(in_dim, proj_dim=256, hidden_dim=2048)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
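A sketch of the head described above, in the style of the original DINO head. The width of the weight-normalized last layer is not documented, so out_dim is a hypothetical parameter (the original DINO uses 65536); DINOHeadSketch is a hypothetical name.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Prefer the parametrization-based weight_norm (PyTorch >= 2.1); fall back to the older helper.
_weight_norm = getattr(nn.utils.parametrizations, "weight_norm", None) or nn.utils.weight_norm


class DINOHeadSketch(nn.Module):
    def __init__(self, in_dim, proj_dim=256, hidden_dim=2048, out_dim=4096):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden_dim), nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim), nn.GELU(),
            nn.Linear(hidden_dim, proj_dim),  # bottleneck output, no activation
        )
        # out_dim is hypothetical: the documented signature does not expose it.
        self.last_layer = _weight_norm(nn.Linear(proj_dim, out_dim, bias=False))

    def forward(self, x):
        x = self.mlp(x)
        x = F.normalize(x, dim=-1)  # L2-normalize the bottleneck output
        return self.last_layer(x)


head = DINOHeadSketch(in_dim=2048)
out = head(torch.randn(3, 2048))
```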

models.wsdino_resnet.get_resnet50(num_classes=None, use_projection_head=True, proj_dim=256, model_type='base_resnet')

Loads a ResNet-50 backbone with the final layer replaced for classification. Optionally adds a DINO-style projection head.

Parameters:
  • num_classes – Number of output classes (e.g., the number of unique MoAs); used only if use_projection_head is False

  • use_projection_head – If True, attach DINO-style projection head

  • proj_dim – Output dimension of projection head

  • model_type – Either “base_resnet” or “wsdino”

Returns:

A torch.nn.Module (ResNet-50 with custom head)

models.wsdino_resnet.dino_loss(student_out, teacher_out, temp=0.07)

Computes the DINO distillation loss between student and teacher outputs.

Parameters:
  • student_out – Logits from the student network

  • teacher_out – Logits from the teacher network (detached)

  • temp – Temperature scaling parameter

Returns:

KL divergence loss between student and teacher probability distributions
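A sketch of a KL-based distillation loss matching the description above, assuming a single temperature shared by student and teacher, as the signature suggests. dino_loss_sketch is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def dino_loss_sketch(student_out, teacher_out, temp=0.07):
    """KL(teacher || student) between temperature-scaled softmax distributions."""
    student_log_probs = F.log_softmax(student_out / temp, dim=-1)
    teacher_probs = F.softmax(teacher_out.detach() / temp, dim=-1)  # no gradient to the teacher
    # kl_div expects log-probabilities as input and probabilities as target.
    return F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")


torch.manual_seed(0)
logits = torch.randn(4, 16)
same = dino_loss_sketch(logits, logits)                    # identical distributions
different = dino_loss_sketch(logits, torch.randn(4, 16))   # mismatched distributions
```

The original DINO loss is written as a cross-entropy, with a sharper teacher temperature and a centering term on the teacher outputs. Cross-entropy and KL(teacher || student) differ only by the teacher's entropy, which is constant with respect to the student, so both give the student the same gradients.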

models.wsdino_resnet.update_teacher(student, teacher, m=0.996)

Updates teacher weights using exponential moving average of student weights.

Parameters:
  • student – Student model (nn.Module)

  • teacher – Teacher model (nn.Module)

  • m – Momentum factor (closer to 1 = slower update)

Returns:

None (teacher model is updated in place)
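A sketch of the exponential moving average update, assuming student and teacher have identical architectures so that their parameters line up one to one. ema_update is a hypothetical name.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def ema_update(student: nn.Module, teacher: nn.Module, m: float = 0.996) -> None:
    """teacher <- m * teacher + (1 - m) * student, parameter by parameter, in place."""
    for p_s, p_t in zip(student.parameters(), teacher.parameters()):
        p_t.mul_(m).add_(p_s.detach(), alpha=1.0 - m)


student = nn.Linear(2, 2, bias=False)
teacher = nn.Linear(2, 2, bias=False)
nn.init.ones_(student.weight)
nn.init.zeros_(teacher.weight)
ema_update(student, teacher, m=0.9)  # teacher weights become 0.9 * 0 + 0.1 * 1
```

This averages parameters only; buffers such as BatchNorm running statistics are left untouched, which may or may not match the function above.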

class models.wsdino_resnet.MultiCropTransform(global_transform, local_transform, num_local_crops)

Bases: object

Generate two global crops and num_local_crops local crops from the same image.

__init__(global_transform, local_transform, num_local_crops)
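A sketch of the multi-crop wrapper, assuming the crops are returned as a list with the two global crops first. In the original DINO recipe the teacher sees only the global crops while the student sees all of them. MultiCropSketch is a hypothetical name, and the lambda transforms are toy stand-ins for real augmentation pipelines such as torchvision's RandomResizedCrop.

```python
class MultiCropSketch:
    """Return [global_1, global_2, local_1, ..., local_N] views of one image."""

    def __init__(self, global_transform, local_transform, num_local_crops):
        self.global_transform = global_transform
        self.local_transform = local_transform
        self.num_local_crops = num_local_crops

    def __call__(self, image):
        # With random transforms, each call re-samples the augmentation, so the crops differ.
        crops = [self.global_transform(image) for _ in range(2)]
        crops += [self.local_transform(image) for _ in range(self.num_local_crops)]
        return crops


multi_crop = MultiCropSketch(lambda x: ("global", x), lambda x: ("local", x), num_local_crops=4)
views = multi_crop("img")
```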