Source code for evaluation.visualization.visualize_embeddings

"""
Visualize BBBC021 compound embeddings using t-SNE and UMAP.

This script loads pre-extracted feature embeddings from .pkl files,
applies t-SNE and UMAP dimensionality reduction, and visualizes the
resulting 2D projections colored by Mechanism of Action (MoA).
"""

import os
import pickle
import numpy as np
import matplotlib.pyplot as plt
import umap
from sklearn.manifold import TSNE
from collections import defaultdict
import math
# Imports needed for the accuracy-vs-image-count plot
from pybbbc import BBBC021, constants
import sys
# Make the parent package importable when this file is run as a script
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from evaluator import evaluate_model

def load_model_features(model_name, data_root="/scratch/cv-course2025/group8"):
    """
    Load features for a single model.

    Args:
        model_name (str): Name of the model (e.g., 'base_resnet', 'simclr', 'wsdino')
        data_root (str): Root directory containing bbbc021_features

    Returns:
        tuple: (features, moas) where features is a numpy array and moas is a list
    """
    feature_dir = os.path.join(data_root, "bbbc021_features", model_name)
    if not os.path.exists(feature_dir):
        raise FileNotFoundError(f"Features directory not found: {feature_dir}")

    features = []
    moas = []

    for file in os.listdir(feature_dir):
        if not file.endswith(".pkl") or "DMSO" in file:
            continue

        filepath = os.path.join(feature_dir, file)

        # Each file stores ((compound_name, concentration, moa), feature_vector)
        with open(filepath, 'rb') as f:
            (compound_info, feat) = pickle.load(f)

        compound, conc, moa = compound_info

        # Skip samples with unknown MoA
        if moa == "null":
            continue

        features.append(feat.numpy() if hasattr(feat, 'numpy') else feat)
        moas.append(moa)

    if not features:
        raise ValueError(f"No valid features found for model {model_name}")

    return np.stack(features), moas
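# Hedged sketch (not part of the original pipeline): load_model_features assumes each
# .pkl file contains ((compound, concentration, moa), feature_vector). The hypothetical
# helper below shows how such a file could be written, e.g. to build a tiny fixture for
# testing the loader. Names, paths, and the feature dimension are illustrative only.
def _write_example_feature_file(feature_dir, compound="example-compound",
                                concentration=1.0, moa="example-moa", dim=512):
    """Write a single .pkl file in the layout expected by load_model_features."""
    os.makedirs(feature_dir, exist_ok=True)
    feat = np.random.rand(dim).astype(np.float32)  # stand-in for a model embedding
    filepath = os.path.join(feature_dir, f"{compound}_{concentration}.pkl")
    with open(filepath, 'wb') as f:
        pickle.dump(((compound, concentration, moa), feat), f)
    return filepath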
def plot_tsne_comparison(model_names, data_root="/scratch/cv-course2025/group8",
                         output_dir="/scratch/cv-course2025/group8/plots", figsize=None):
    """
    Create side-by-side t-SNE plots for multiple models.

    Args:
        model_names (list): List of model names to compare
        data_root (str): Root directory containing bbbc021_features
        output_dir (str): Directory to save plots
        figsize (tuple): Figure size (width, height)
    """
    n_models = len(model_names)
    n_cols = 2
    n_rows = math.ceil(n_models / n_cols)

    if figsize is None:
        figsize = (6 * n_cols, 5 * n_rows)

    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    axes = axes.flatten()

    # Get all unique MoAs across all models for consistent coloring
    all_moas = set()
    model_data = {}
    for model_name in model_names:
        features, moas = load_model_features(model_name, data_root)
        model_data[model_name] = (features, moas)
        all_moas.update(moas)

    # Create consistent color mapping
    unique_moas = sorted(all_moas)
    moa_colors = plt.cm.tab20(np.linspace(0, 1, len(unique_moas)))
    moa_to_color = {moa: moa_colors[i] for i, moa in enumerate(unique_moas)}

    for i, model_name in enumerate(model_names):
        features, moas = model_data[model_name]

        print(f"Running t-SNE for {model_name}...")
        tsne = TSNE(n_components=2, perplexity=30, random_state=42)
        tsne_proj = tsne.fit_transform(features)

        ax = axes[i]

        # Plot each MoA separately for proper legend
        for moa in unique_moas:
            mask = np.array([m == moa for m in moas])
            if np.any(mask):
                ax.scatter(
                    tsne_proj[mask, 0], tsne_proj[mask, 1],
                    c=[moa_to_color[moa]], label=moa, alpha=0.7, s=30
                )

        ax.set_title(f"t-SNE: {model_name}")
        ax.set_xticks([])
        ax.set_yticks([])

        # Add legend only to the last subplot
        if i == n_models - 1:
            ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

    # Hide any unused subplots
    for j in range(n_models, len(axes)):
        fig.delaxes(axes[j])

    plt.tight_layout()

    # Save plot
    output_path = os.path.join(output_dir, f"tsne_comparison_{'_'.join(model_names)}.png")
    plt.savefig(output_path, bbox_inches='tight', dpi=300)
    plt.show()
    print(f"t-SNE comparison saved to {output_path}")
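# Note (assumption about scikit-learn behaviour, not enforced above): TSNE requires the
# perplexity to be smaller than the number of samples, so the fixed perplexity=30 can
# fail on very small feature sets. A minimal sketch of a guard that clamps the value
# before fitting; this hypothetical helper is not used by plot_tsne_comparison.
def _safe_tsne_projection(features, perplexity=30, random_state=42):
    """Run t-SNE with a perplexity clamped below the number of samples."""
    n_samples = features.shape[0]
    safe_perplexity = min(perplexity, max(1, n_samples - 1))
    tsne = TSNE(n_components=2, perplexity=safe_perplexity, random_state=random_state)
    return tsne.fit_transform(features)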
def plot_umap_comparison(model_names, data_root="/scratch/cv-course2025/group8",
                         output_dir="/scratch/cv-course2025/group8/plots", figsize=None):
    """
    Create side-by-side UMAP plots for multiple models.

    Args:
        model_names (list): List of model names to compare
        data_root (str): Root directory containing bbbc021_features
        output_dir (str): Directory to save plots
        figsize (tuple): Figure size (width, height)
    """
    n_models = len(model_names)
    n_cols = 2
    n_rows = math.ceil(n_models / n_cols)

    if figsize is None:
        figsize = (6 * n_cols, 5 * n_rows)

    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    axes = axes.flatten()

    # Get all unique MoAs across all models for consistent coloring
    all_moas = set()
    model_data = {}
    for model_name in model_names:
        features, moas = load_model_features(model_name, data_root)
        model_data[model_name] = (features, moas)
        all_moas.update(moas)

    # Create consistent color mapping
    unique_moas = sorted(all_moas)
    moa_colors = plt.cm.tab20(np.linspace(0, 1, len(unique_moas)))
    moa_to_color = {moa: moa_colors[i] for i, moa in enumerate(unique_moas)}

    for i, model_name in enumerate(model_names):
        features, moas = model_data[model_name]

        print(f"Running UMAP for {model_name}...")
        umap_reducer = umap.UMAP(n_components=2, random_state=42)
        umap_proj = umap_reducer.fit_transform(features)

        ax = axes[i]

        # Plot each MoA separately for proper legend
        for moa in unique_moas:
            mask = np.array([m == moa for m in moas])
            if np.any(mask):
                ax.scatter(
                    umap_proj[mask, 0], umap_proj[mask, 1],
                    c=[moa_to_color[moa]], label=moa, alpha=0.7, s=30
                )

        ax.set_title(f"UMAP: {model_name}")
        ax.set_xticks([])
        ax.set_yticks([])

        # Add legend only to the last subplot
        if i == n_models - 1:
            ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

    # Hide any unused subplots
    for j in range(n_models, len(axes)):
        fig.delaxes(axes[j])

    plt.tight_layout()

    # Save plot
    output_path = os.path.join(output_dir, f"umap_comparison_{'_'.join(model_names)}.png")
    plt.savefig(output_path, bbox_inches='tight', dpi=300)
    plt.show()
    print(f"UMAP comparison saved to {output_path}")
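# Both comparison plots above build the same MoA -> color mapping so that a given
# mechanism keeps its color across models and subplots. A minimal sketch of that step,
# factored into a hypothetical helper (not used by the functions above), assuming the
# tab20 colormap provides enough distinct colors for the MoAs present.
def _moa_color_map(moa_lists):
    """Return a consistent {moa: RGBA} mapping for one or more lists of MoA labels."""
    unique_moas = sorted({moa for moas in moa_lists for moa in moas})
    colors = plt.cm.tab20(np.linspace(0, 1, len(unique_moas)))
    return {moa: colors[i] for i, moa in enumerate(unique_moas)}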
def plot_accuracy_vs_image_count(model_names, data_root="/scratch/cv-course2025/group8",
                                 output_dir="/scratch/cv-course2025/group8/plots",
                                 distance_measure="cosine", sort_by="accuracy",
                                 max_image_count=None):
    """
    Create a dual-axis plot showing per-compound accuracies with image counts as context.

    Args:
        model_names (list): List of model names to compare
        data_root (str): Root directory containing bbbc021_features and dataset
        output_dir (str): Directory to save plots
        distance_measure (str): Distance measure for evaluation ("cosine", "l1", "l2")
        sort_by (str): Sort compounds by "accuracy", "image_count", or "compound_name"
        max_image_count (int, optional): Maximum value for the image count y-axis (right side).
            If None, uses automatic scaling.
    """
    # First, collect image counts per compound
    print("Collecting image counts per compound...")
    compound_image_counts = defaultdict(int)

    # Get all compounds (excluding DMSO)
    compounds = [c for c in constants.COMPOUNDS if c != "DMSO"]
    moas = [m for m in constants.MOA if m not in ["DMSO", "null"]]

    # Count images per compound
    for compound in compounds:
        try:
            pybbbc = BBBC021(root_path=data_root, compound=compound, moa=moas)
            for image, meta in pybbbc:
                compound_name = meta.compound.compound
                if meta.compound.moa != 'null':  # Only count valid MoA samples
                    compound_image_counts[compound_name] += 1
        except Exception as e:
            print(f"Error processing compound {compound}: {e}")
            continue

    print(f"Found image counts for {len(compound_image_counts)} compounds")

    # Collect all model results first to determine a consistent ordering
    all_model_results = {}
    all_compounds = set()

    for model_name in model_names:
        print(f"Evaluating model: {model_name}")
        try:
            results = evaluate_model(model_name, distance_measure=distance_measure,
                                     nsc_eval=True, tvn=False)

            # Extract compound accuracies
            compound_accuracies = {}
            for key, value in results.items():
                if key.startswith("compound_"):
                    compound = key.replace("compound_", "")
                    if compound in compound_image_counts:  # Only include compounds with image counts
                        compound_accuracies[compound] = value
                        all_compounds.add(compound)

            all_model_results[model_name] = compound_accuracies
        except Exception as e:
            print(f"Error processing model {model_name}: {e}")
            continue

    if not all_compounds:
        print("No compounds found across all models")
        return

    # Create a reference ordering based on the sort_by parameter, using the first model
    first_model = model_names[0]
    if first_model in all_model_results:
        reference_compounds = list(all_compounds)
        reference_accuracies = [all_model_results[first_model].get(comp, 0) for comp in reference_compounds]
        reference_image_counts = [compound_image_counts[comp] for comp in reference_compounds]

        # Sort based on the reference model
        if sort_by == "accuracy":
            sorted_indices = sorted(range(len(reference_compounds)), key=lambda i: reference_accuracies[i])
        elif sort_by == "image_count":
            sorted_indices = sorted(range(len(reference_compounds)), key=lambda i: reference_image_counts[i])
        else:  # sort by compound name
            sorted_indices = sorted(range(len(reference_compounds)), key=lambda i: reference_compounds[i])

        # Apply the same ordering to every model for a consistent compound order
        compound_order = [reference_compounds[i] for i in sorted_indices]
    else:
        compound_order = sorted(all_compounds)  # Fall back to alphabetical order

    # Create plots for each model
    n_models = len(model_names)
    fig, axes = plt.subplots(n_models, 1, figsize=(15, 6 * n_models))
    if n_models == 1:
        axes = [axes]

    for i, model_name in enumerate(model_names):
        if model_name not in all_model_results:
            continue

        compound_accuracies = all_model_results[model_name]

        # Use the consistent compound order
        compounds_sorted = compound_order
        accuracies_sorted = [compound_accuracies.get(comp, 0) for comp in compounds_sorted]
        image_counts_sorted = [compound_image_counts[comp] for comp in compounds_sorted]

        # Create dual-axis plot
        ax1 = axes[i]
        ax2 = ax1.twinx()

        # Plot accuracy bars
        x_pos = np.arange(len(compounds_sorted))
        bars1 = ax1.bar(x_pos, accuracies_sorted, alpha=0.7, color='steelblue',
                        label='Accuracy', width=0.6)

        # Plot image count line
        line2 = ax2.plot(x_pos, image_counts_sorted, color='red', marker='o',
                         linewidth=2, markersize=4, label='Image Count')

        # Formatting
        ax1.set_xlabel('Compounds', fontsize=12)
        ax1.set_ylabel('Accuracy', fontsize=12, color='steelblue')
        ax2.set_ylabel('Number of Images', fontsize=12, color='red')
        ax1.set_title(f'{model_name}: Per-Compound Accuracy vs Image Count ({distance_measure} distance)',
                      fontsize=14, fontweight='bold')

        # Set x-axis labels
        ax1.set_xticks(x_pos)
        ax1.set_xticklabels(compounds_sorted, rotation=45, ha='right', fontsize=10)

        # Color the axes
        ax1.tick_params(axis='y', labelcolor='steelblue')
        ax2.tick_params(axis='y', labelcolor='red')

        # Set y-axis limits for image count if specified
        if max_image_count is not None:
            ax2.set_ylim(0, max_image_count)

        # Add grid
        ax1.grid(axis='y', alpha=0.3)

        # Add value labels on bars for readability
        for bar, acc in zip(bars1, accuracies_sorted):
            height = bar.get_height()
            ax1.annotate(f'{acc:.2f}',
                         xy=(bar.get_x() + bar.get_width() / 2, height),
                         xytext=(0, 3), textcoords="offset points",
                         ha='center', va='bottom', fontsize=8, rotation=90)

        # Add correlation coefficient
        correlation = np.corrcoef(accuracies_sorted, image_counts_sorted)[0, 1]
        ax1.text(0.02, 0.98, f'Correlation: {correlation:.3f}', transform=ax1.transAxes,
                 fontsize=12, verticalalignment='top',
                 bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))

        # Add legend
        lines1, labels1 = ax1.get_legend_handles_labels()
        lines2, labels2 = ax2.get_legend_handles_labels()
        ax1.legend(lines1 + lines2, labels1 + labels2, loc='upper left')

        print(f"Model {model_name}: Correlation between accuracy and image count: {correlation:.3f}")

    plt.tight_layout()

    # Save plot
    output_path = os.path.join(output_dir,
                               f"accuracy_vs_image_count_{'_'.join(model_names)}_{distance_measure}.png")
    plt.savefig(output_path, bbox_inches='tight', dpi=300)
    plt.show()
    print(f"Dual-axis plot saved to {output_path}")

    # Print summary statistics using the consistent compound order
    all_image_counts = [compound_image_counts[comp] for comp in compound_order]
    print("\n=== Summary Statistics ===")
    print(f"Total compounds analyzed: {len(compound_order)}")
    print(f"Image count range: {min(all_image_counts)} - {max(all_image_counts)}")
    print(f"Mean images per compound: {np.mean(all_image_counts):.1f}")
    print(f"Compounds with <50 images: {sum(1 for c in all_image_counts if c < 50)}")
    print(f"Compounds with >200 images: {sum(1 for c in all_image_counts if c > 200)}")
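# Hedged sketch: the per-subplot correlation reported above is a plain Pearson
# coefficient between per-compound accuracy and image count. The hypothetical helper
# below computes the same number directly from the two dictionaries without plotting,
# which can be useful for quick sanity checks. It is not called by the function above.
def _accuracy_image_count_correlation(compound_accuracies, compound_image_counts):
    """Pearson correlation between accuracy and image count over shared compounds."""
    shared = sorted(set(compound_accuracies) & set(compound_image_counts))
    if len(shared) < 2:
        return float('nan')  # correlation is undefined for fewer than two compounds
    accs = [compound_accuracies[c] for c in shared]
    counts = [compound_image_counts[c] for c in shared]
    return float(np.corrcoef(accs, counts)[0, 1])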
if __name__ == "__main__":
    # Example: Compare multiple models
    # model_names = ["base_resnet", "simclr_vanilla_ws"]
    model_names = ["base_resnet", "wsdino"]

    output_dir = os.path.join("/scratch/cv-course2025/group8/plots", "wsdino")
    os.makedirs(output_dir, exist_ok=True)

    plot_accuracy_vs_image_count(model_names, sort_by="image_count", max_image_count=200,
                                 output_dir=output_dir)
    plot_umap_comparison(model_names, output_dir=output_dir)
    plot_tsne_comparison(model_names, output_dir=output_dir)