Source code for evaluation.extractor
import pickle
from typing import Dict, Optional, Tuple
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import numpy as np
import os
from models.load_model import load_pretrained_model, create_feature_extractor, ModelName
from pybbbc import BBBC021, constants
from experiments.tvn import TypicalVariationNormalizer # Import TVN module
def extract_moa_features(
    model_name: ModelName,
    device: torch.device,
    batch_size: int = 16,
    data_root: str = "/scratch/cv-course2025/group8",
    compounds: Optional[list[str]] = None,
    tvn: bool = False,
) -> None:
"""
Extract features for the BBBC021 dataset using a pretrained ResNet50 model.
Args:
model_name: Name of the model to use. Is of type ModelName.
device: Device to run the model on.
batch_size: Batch size for data loading.
data_root: Root directory where the BBBC021 dataset is stored.
compounds: List of compounds to process. If None, all compounds will be processed.
tvn: If True, apply Typical Variation Normalization to features before averaging and saving.
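
    Example (sketch; the compound name and device are illustrative):
        >>> extract_moa_features(
        ...     model_name="simclr",
        ...     device=torch.device("cuda"),
        ...     batch_size=64,
        ...     compounds=["taxol"],
        ...     tvn=True,
        ... )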
"""
    # Load the pretrained backbone selected via model_name
    pretrained_model = load_pretrained_model(model_name)
    # Create feature extractor
    feature_extractor = create_feature_extractor(pretrained_model)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),  # ImageNet normalization
    ])
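    # Note: the ImageNet statistics above assume 3-channel inputs scaled to
    # [0, 1]; the BBBC021 images (DAPI/tubulin/actin channels) are assumed
    # here to arrive from pybbbc already in that range.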
    if not compounds:
        compounds = constants.COMPOUNDS
    else:
        for compound in compounds:
            if compound not in constants.COMPOUNDS:
                raise ValueError(
                    f"Compound '{compound}' is not valid. Valid compounds: {constants.COMPOUNDS}"
                )
    # Output directory
    output_dir = os.path.join(data_root, "bbbc021_features", model_name)
    os.makedirs(output_dir, exist_ok=True)
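    # One pickle per (compound, concentration) pair is written into this directory.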
    # Move the model to the target device and switch to inference mode
    feature_extractor = feature_extractor.to(device)
    feature_extractor.eval()
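    # eval() freezes batch-norm statistics and disables dropout, so the
    # extracted features are deterministic; gradients are disabled separately
    # with torch.no_grad() in the extraction loop below.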
    # Collect per-compound features
    compound_features = {}
    tvn_features = []
    # Process each compound dynamically
    for compound in compounds:
        data = BBBC021(root_path=data_root, compound=compound)
        print(f"Processing Compound: {compound} with {len(data.images)} images")
        # Dictionary to store images grouped by (compound, concentration, moa)
        image_groups: Dict[Tuple[str, float, str], list[torch.Tensor]] = {}
        # Collect images for this compound
        for image, metadata in data:
            if metadata.compound.moa == 'null':
                print(f"Skipping image with null MOA for compound {compound}.")
                continue
            key = (metadata.compound.compound,
                   metadata.compound.concentration,
                   metadata.compound.moa)
            if key not in image_groups:
                image_groups[key] = []
            # Convert numpy array to tensor if needed
            if isinstance(image, np.ndarray):
                image = torch.from_numpy(image).float()
            image = transform(image)
            image_groups[key].append(image)
        # Process each group for this compound immediately
        for key, images in image_groups.items():
            compound_name, concentration, moa = key
            if len(images) == 0:
                print(f"Warning: No images for group {compound_name}_{concentration}. Skipping...")
                continue
            print(f"Extracting features for {compound_name}@{concentration}({moa}) - {len(images)} images")
            try:
                # Create DataLoader for this group
                dataset = torch.utils.data.TensorDataset(torch.stack(images))
                dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
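                # shuffle=False: the mean over the group is order-invariant,
                # but deterministic batching keeps runs reproducible.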
                # Extract features
                all_features = []
                with torch.no_grad():
                    for batch in dataloader:
                        batch_images = batch[0].to(device)
                        features = feature_extractor(batch_images)
                        features = features.squeeze()  # Remove spatial dimensions
                        if len(features.shape) == 1:
                            # A batch of size 1 collapses to a 1-D tensor; restore the batch dim
                            features = features.unsqueeze(0)
                        all_features.append(features.cpu())
                all_features = torch.cat(all_features, dim=0)
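                # all_features now has shape (num_images_in_group, feature_dim)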
                if tvn and compound_name == "DMSO":
                    tvn_features.append(all_features)  # Store for TVN fitting
                compound_features[key] = all_features  # Store per-image features for every treatment
            except Exception as e:
                print(f"Error processing group {compound_name}_{concentration}: {e}. Skipping...")
                continue
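    # Note: per-image features for all compounds are held in memory until the
    # end of the run; with the full compound list this can be sizeable.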
    # TVN logic
    if tvn:
        tvn_output_dir = os.path.join(data_root, "bbbc021_features", model_name, "tvn")
        os.makedirs(tvn_output_dir, exist_ok=True)
        print("\nFitting TVN from DMSO images...")
        if not tvn_features:
            raise RuntimeError("No DMSO features found to fit TVN.")
        dmso_concat = torch.cat(tvn_features, dim=0)
        # Use a separate name so the boolean `tvn` flag is not shadowed
        tvn_normalizer = TypicalVariationNormalizer()
        tvn_normalizer.fit(dmso_concat)
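        # TVN (Ando et al., 2017) fits a whitening transform on the negative
        # control (DMSO) embeddings and applies it to all treatments, removing
        # nuisance variation that is shared with the controls.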
print("Applying TVN and saving averaged features...")
for key, features in compound_features.items():
transformed = tvn.transform(features)
avg_feature = torch.mean(transformed, dim=0)
result = (key, avg_feature)
compound_name, concentration, _ = key
filename = f"{compound_name}_{concentration}_tvn.pkl".replace(" ", "_").replace("/", "_")
filepath = os.path.join(tvn_output_dir, filename)
with open(filepath, 'wb') as f:
pickle.dump(result, f)
print(f"Saved averaged TVN features to {filepath}")
    else:
        print("Saving non-TVN averaged features...")
        for key, features in compound_features.items():
            avg_feature = torch.mean(features, dim=0)
            result = (key, avg_feature)
            compound_name, concentration, _ = key
            filename = f"{compound_name}_{concentration}.pkl".replace(" ", "_").replace("/", "_")
            filepath = os.path.join(output_dir, filename)
            with open(filepath, 'wb') as f:
                pickle.dump(result, f)
            print(f"Saved averaged features to {filepath}")
# Main entry point for running the feature extraction
if __name__ == "__main__":
    extract_moa_features(
        model_name="simclr",
        # model_name="wsdino",
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        batch_size=256,
        data_root="/scratch/cv-course2025/group8",
        compounds=constants.COMPOUNDS,
        tvn=False,  # set to True to fit and apply TVN
    )