# Source code for evaluation.evaluator
import torch
import pickle
import os
import numpy as np
from typing import Literal, Dict
import torch.nn.functional as F
from pybbbc import BBBC021, constants
from models.load_model import ModelName
from experiments.tvn import TypicalVariationNormalizer
# Distance measures supported by the 1-nearest-neighbor MOA evaluation.
DistanceMeasure = Literal["l1", "l2", "cosine"]
def evaluate_model(
    model_name: "ModelName",
    distance_measure: "DistanceMeasure" = "cosine",
    nsc_eval: bool = True,
    tvn: bool = False,
    features_base_dir: str = "/scratch/cv-course2025/group8/bbbc021_features",
) -> Dict[str, float]:
    """
    Evaluate MOA prediction with a 1-nearest-neighbor search over
    pre-extracted per-treatment feature vectors.

    Args:
        model_name: Name of the model whose pre-computed features are loaded.
        distance_measure: Metric for the 1-NN search ("l1", "l2", or "cosine").
        nsc_eval: If True, use not-same-compound evaluation: every treatment of
            the query's compound (all concentrations) is excluded from the
            reference set.
        tvn: If True, load the Typical-Variation-Normalized feature variant
            (the "tvn" sub-directory).
        features_base_dir: Root directory holding one sub-directory of .pkl
            feature files per model. Defaults to the original hard-coded
            cluster path; exposed as a parameter for portability and testing.

    Returns:
        Dict[str, float]: "total_accuracy" plus one "compound_<name>" entry
        with the per-compound accuracy.

    Raises:
        ValueError: If distance_measure is unknown, or no valid feature file
            could be loaded.
        FileNotFoundError: If the features directory does not exist.
    """
    # Fail fast on a bad metric instead of raising deep inside the 1-NN loop.
    if distance_measure not in ("l1", "l2", "cosine"):
        raise ValueError(f"Unknown distance measure: {distance_measure}")

    features_dir = os.path.join(features_base_dir, str(model_name))
    if tvn:
        features_dir = os.path.join(features_dir, "tvn")
    if not os.path.exists(features_dir):
        raise FileNotFoundError(f"Features directory not found: {features_dir}")

    # key -> feature vector; keys are (compound, concentration, moa) tuples.
    stored_features_dict = _load_stored_features(features_dir)
    if not stored_features_dict:
        raise ValueError("No valid feature files found")
    stored_keys = list(stored_features_dict.keys())

    total_correct = 0
    total_predictions = 0
    compound_results: Dict[str, Dict[str, int]] = {}

    for treatment_key in stored_keys:
        compound_name, concentration, true_moa = treatment_key
        # Negative controls are never used as queries.
        if compound_name == "DMSO":
            continue
        print(f"Evaluating treatment: {compound_name}@{concentration} ({true_moa})")

        # NSC: drop all treatments of the query compound from the references.
        if nsc_eval:
            reference_keys = [key for key in stored_keys if key[0] != compound_name]
            if not reference_keys:
                print(f"Warning: No reference features available for compound {compound_name} with NSC evaluation. Skipping...")
                continue
        else:
            reference_keys = stored_keys

        # treatment_key is iterated straight from the dict, so it is always
        # present -- no membership re-check needed.
        avg_treatment_features = stored_features_dict[treatment_key]
        best_idx = _nearest_reference(
            avg_treatment_features, reference_keys, stored_features_dict, distance_measure
        )
        predicted_moa = reference_keys[best_idx][2]

        is_correct = predicted_moa == true_moa
        total_correct += int(is_correct)
        total_predictions += 1
        stats = compound_results.setdefault(compound_name, {"correct": 0, "total": 0})
        stats["total"] += 1
        stats["correct"] += int(is_correct)
        print(f"  True MOA: {true_moa}, Predicted MOA: {predicted_moa}, Correct: {predicted_moa == true_moa}")

    total_accuracy = total_correct / total_predictions if total_predictions > 0 else 0.0
    compound_accuracies = {
        compound: (stats["correct"] / stats["total"] if stats["total"] > 0 else 0.0)
        for compound, stats in compound_results.items()
    }

    print(f"\n=== Evaluation Results ===")
    print(f"Model: {model_name}")
    print(f"Distance measure: {distance_measure}")
    print(f"NSC evaluation: {nsc_eval}")
    print(f"Total accuracy: {total_accuracy:.4f} ({total_correct}/{total_predictions})")
    print(f"\nPer-compound accuracies:")
    for compound, accuracy in sorted(compound_accuracies.items()):
        print(f"  {compound}: {accuracy:.4f} ({compound_results[compound]['correct']}/{compound_results[compound]['total']})")

    return {
        "total_accuracy": total_accuracy,
        **{f"compound_{compound}": accuracy for compound, accuracy in compound_accuracies.items()},
    }


def _load_stored_features(features_dir: str) -> Dict[tuple, torch.Tensor]:
    """Load every pickled (key, features) pair under features_dir, skipping unreadable files."""
    stored: Dict[tuple, torch.Tensor] = {}
    for filename in os.listdir(features_dir):
        if not filename.endswith('.pkl'):
            continue
        filepath = os.path.join(features_dir, filename)
        try:
            # NOTE(review): pickle is only safe on trusted, self-produced files.
            with open(filepath, 'rb') as f:
                key, features = pickle.load(f)
            stored[key] = features  # key = (compound, concentration, moa)
        except Exception as e:
            print(f"Error loading {filepath}: {e}")
    return stored


def _nearest_reference(query, reference_keys, stored_features_dict, distance_measure) -> int:
    """Return the index in reference_keys of the 1-NN of `query` under the given metric.

    Assumes `distance_measure` was validated by the caller.
    """
    if distance_measure == "cosine":
        # Hoisted: normalizing the query is loop-invariant.
        query_norm = F.normalize(query, p=2, dim=0)
        best_score, best_idx = float('-inf'), -1
        for i, ref_key in enumerate(reference_keys):
            ref_norm = F.normalize(stored_features_dict[ref_key], p=2, dim=0)
            score = torch.dot(ref_norm, query_norm).item()
            if score > best_score:
                best_score, best_idx = score, i
    else:
        p = 2 if distance_measure == "l2" else 1
        best_score, best_idx = float('inf'), -1
        for i, ref_key in enumerate(reference_keys):
            score = torch.norm(stored_features_dict[ref_key] - query, p=p).item()
            if score < best_score:
                best_score, best_idx = score, i
    return best_idx
if __name__ == "__main__":
    # Model whose cached features are evaluated; "simclr" is the other option.
    model_name = "wsdino"

    # 1-NN MOA evaluation with cosine distance and not-same-compound filtering.
    results = evaluate_model(model_name, distance_measure="cosine", nsc_eval=True, tvn=False)

    print("\nEvaluation completed. Results:")
    print("\n".join(f"{key}: {value:.4f}" for key, value in results.items()))