Source code for spacr.deep_spacr

import os, torch, time, gc, datetime
torch.backends.cudnn.benchmark = True
import numpy as np
import pandas as pd
from torch.optim import Adagrad, AdamW
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.metrics import auc, precision_recall_curve
from IPython.display import display
from multiprocessing import cpu_count
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR

from sklearn.metrics import precision_recall_curve, auc, average_precision_score, confusion_matrix
    

from torchvision import transforms
from torch.utils.data import DataLoader

[docs] def apply_model(src, model_path, image_size=224, batch_size=64, normalize=True, n_jobs=10): """ Apply a trained PyTorch model to images in a directory. The function loads a saved model, builds a dataset from the input images, runs batched inference, and saves prediction scores to a CSV file. :param src: Path to the input image directory or collection of image paths. :type src: str or sequence :param model_path: Path to the saved PyTorch model. :type model_path: str :param image_size: Final square crop size used before inference. :type image_size: int :param batch_size: Number of images processed per batch. :type batch_size: int :param normalize: Whether to normalize the image channels using mean 0.5 and standard deviation 0.5. :type normalize: bool :param n_jobs: Number of worker processes used by the DataLoader. :type n_jobs: int :return: DataFrame with image paths and predicted positive-class probabilities. :rtype: pandas.DataFrame The returned DataFrame contains the columns ``path`` and ``pred``. Results are also written to a CSV file derived from ``model_path`` and the current date. The model output is interpreted as a binary logit and converted to probabilities with ``torch.sigmoid``. """ from .io import NoClassDataset from .utils import print_progress device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") if normalize: transform = transforms.Compose([ transforms.ToTensor(), transforms.CenterCrop(size=(image_size, image_size)), transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))]) else: transform = transforms.Compose([ transforms.ToTensor(), transforms.CenterCrop(size=(image_size, image_size))]) model = torch.load(model_path) print(model) print(f'Loading dataset in {src} with {len(src)} images') dataset = NoClassDataset(data_dir=src, transform=transform, shuffle=True, load_to_memory=False) data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=n_jobs) print(f'Loaded {len(src)} images') result_loc = os.path.splitext(model_path)[0]+datetime.date.today().strftime('%y%m%d')+'_'+os.path.splitext(model_path)[1]+'_test_result.csv' print(f'Results wil be saved in: {result_loc}') model.eval() model = model.to(device) prediction_pos_probs = [] filenames_list = [] time_ls = [] with torch.no_grad(): for batch_idx, (batch_images, filenames) in enumerate(data_loader, start=1): start = time.time() images = batch_images.to(torch.float).to(device) outputs = model(images) batch_prediction_pos_prob = torch.sigmoid(outputs).cpu().numpy() prediction_pos_probs.extend(batch_prediction_pos_prob.tolist()) filenames_list.extend(filenames) stop = time.time() duration = stop - start time_ls.append(duration) files_processed = batch_idx*batch_size files_to_process = len(data_loader) print_progress(files_processed, files_to_process, n_jobs=n_jobs, time_ls=time_ls, batch_size=batch_size, operation_type="Generating predictions") data = {'path':filenames_list, 'pred':prediction_pos_probs} df = pd.DataFrame(data, index=None) df.to_csv(result_loc, index=True, header=True, mode='w') torch.cuda.empty_cache() torch.cuda.memory.empty_cache() return df
[docs] def apply_model_to_tar(settings={}): """ Apply a trained PyTorch model to images stored in a tar archive. The function loads a saved model, reads images from a tar-based dataset, performs batched inference, post-processes prediction scores, and saves the results to a CSV file. :param settings: Dictionary of inference settings. Expected keys include ``tar_path``, ``model_path``, ``image_size``, ``batch_size``, ``normalize``, ``n_jobs``, ``verbose``, and ``score_threshold``. :type settings: dict :return: DataFrame with processed prediction results. :rtype: pandas.DataFrame The returned DataFrame contains at least the columns ``path`` and ``pred``. Additional columns may be added by ``process_vision_results``. If the model output has shape ``(N, 2)``, the probability of class 1 is computed with ``torch.softmax``. Otherwise, outputs are treated as binary logits and converted with ``torch.sigmoid``. """ from .io import TarImageDataset from .utils import process_vision_results, print_progress tar_path = settings['tar_path'] model_path = settings['model_path'] device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") if settings['normalize']: transform = transforms.Compose([ transforms.ToTensor(), transforms.CenterCrop(size=(settings['image_size'], settings['image_size'])), transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), ]) else: transform = transforms.Compose([ transforms.ToTensor(), transforms.CenterCrop(size=(settings['image_size'], settings['image_size'])), ]) if settings['verbose']: print(f"Loading model from {model_path}") print(f"Loading dataset from {tar_path}") # <<< key change: allow unpickling of your saved model object >>> model = torch.load(settings['model_path'], map_location=device, weights_only=False) dataset = TarImageDataset(tar_path, transform=transform) data_loader = DataLoader( dataset, batch_size=settings['batch_size'], shuffle=True, # fine for inference; set False if you want deterministic order num_workers=settings['n_jobs'], pin_memory=(device.type == 'cuda'), ) model_name = os.path.splitext(os.path.basename(model_path))[0] dataset_name = os.path.splitext(os.path.basename(settings['tar_path']))[0] date_name = datetime.date.today().strftime('%y%m%d') dst = os.path.dirname(tar_path) result_loc = f'{dst}/{date_name}_{dataset_name}_{model_name}_result.csv' model.eval() model = model.to(device) if settings['verbose']: print(model) print(f'Generated dataset with {len(dataset)} images') print(f'Generating loader from {len(data_loader)} batches') print(f'Results wil be saved in: {result_loc}') print(f'Model is in eval mode') print(f'Model loaded to device') prediction_pos_probs = [] filenames_list = [] time_ls = [] gc.collect() with torch.no_grad(): for batch_idx, (batch_images, filenames) in enumerate(data_loader, start=1): start = time.time() images = batch_images.to(torch.float).to(device) outputs = model(images) # robust positive-class probability handling if outputs.ndim == 2 and outputs.size(1) == 2: batch_prediction_pos_prob = torch.softmax(outputs, dim=1)[:, 1].cpu().numpy() else: # assume single-logit binary head batch_prediction_pos_prob = torch.sigmoid(outputs).squeeze(-1).cpu().numpy() prediction_pos_probs.extend(batch_prediction_pos_prob.tolist()) filenames_list.extend(filenames) stop = time.time() duration = stop - start time_ls.append(duration) files_processed = batch_idx * settings['batch_size'] files_to_process = len(data_loader) * settings['batch_size'] print_progress(files_processed, files_to_process, n_jobs=settings['n_jobs'], time_ls=time_ls, batch_size=settings['batch_size'], operation_type="Tar dataset") df = pd.DataFrame({'path': filenames_list, 'pred': prediction_pos_probs}, index=None) df = process_vision_results(df, settings['score_threshold']) df.to_csv(result_loc, index=True, header=True, mode='w') print(f"Saved results to {result_loc}") torch.cuda.empty_cache() return df
def _to_numpy_labels(target: torch.Tensor) -> np.ndarray: """ Convert targets to integer class ids: - if 1D float/bool tensor -> round and cast to int - if shape (N, C) one-hot -> argmax - else assume already (N,) int """ t = target.detach().cpu() if t.ndim == 2 and t.size(1) > 1: return t.argmax(dim=1).numpy().astype(int) if t.dtype.is_floating_point: return t.round().numpy().astype(int) return t.numpy().astype(int) def _binary_metrics(y_true: np.ndarray, pos_probs: np.ndarray) -> dict: """Metrics for binary classification.""" if y_true.ndim != 1: y_true = y_true.reshape(-1) # Precision-Recall AUC if len(np.unique(y_true)) >= 2: precision, recall, thresholds = precision_recall_curve(y_true, pos_probs, pos_label=1) pr_auc = auc(recall, precision) # F1-optimal threshold (optional; we still report 0.5 preds below) thresholds = np.append(thresholds, 1.0) with np.errstate(divide='ignore', invalid='ignore'): f1 = 2 * (precision * recall) / (precision + recall) opt_idx = np.nanargmax(f1) opt_thr = float(thresholds[opt_idx]) else: pr_auc = np.nan opt_thr = 0.5 # Discrete preds at 0.5 threshold for stability/readability pred = (pos_probs >= 0.5).astype(int) # Accuracies acc = (pred == y_true).mean() if len(y_true) else np.nan neg_mask = y_true == 0 pos_mask = y_true == 1 acc_neg = (pred[neg_mask] == 0).mean() if neg_mask.any() else np.nan acc_pos = (pred[pos_mask] == 1).mean() if pos_mask.any() else np.nan return { "accuracy": float(acc), "neg_accuracy": float(acc_neg), "pos_accuracy": float(acc_pos), "prauc": float(pr_auc), "optimal_threshold": float(opt_thr), } def _multiclass_metrics(y_true: np.ndarray, prob_mat: np.ndarray) -> dict: """ Metrics for multiclass (single-label): - overall accuracy - per-class accuracy (weighted by support) - macro average precision (one-vs-rest) """ preds = prob_mat.argmax(axis=1) acc = (preds == y_true).mean() if len(y_true) else np.nan # Per-class (diagonal / row sum) cm = confusion_matrix(y_true, preds, labels=np.arange(prob_mat.shape[1])) with np.errstate(divide='ignore', invalid='ignore'): per_class_acc = np.diag(cm) / cm.sum(axis=1, where=(cm.sum(axis=1) != 0), initial=1) # Average precision macro (one-vs-rest) # Build one-hot y_true C = prob_mat.shape[1] y_true_oh = np.zeros((len(y_true), C), dtype=int) if len(y_true): y_true_oh[np.arange(len(y_true)), y_true] = 1 try: ap_macro = average_precision_score(y_true_oh, prob_mat, average="macro") except Exception: ap_macro = np.nan # For compatibility with your logging keys: return { "accuracy": float(acc), "neg_accuracy": np.nan, # not meaningful in multiclass "pos_accuracy": np.nan, # not meaningful in multiclass "prauc": float(ap_macro), # reuse key for macro-AP "optimal_threshold": np.nan, "per_class_accuracy": per_class_acc.tolist(), "num_classes": int(C), }
[docs] def evaluate_model_performance(model, loader, epoch, loss_type='auto', loss_fn=None, num_classes=None): """ Evaluates performance for binary or multiclass models. Returns: data_dict (dict): metrics + loss + epoch [prediction_probs, all_labels] - binary: probs shape (N,) - multiclass: probs shape (N, C) """ from .utils import calculate_loss, build_loss # build_loss only used if loss_fn is None device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.eval().to(device) total_loss, total_samples = 0.0, 0 all_labels = [] prob_bucket = [] head_dim = None # infer from first batch binary_mode = None with torch.no_grad(): for data, target, _ in loader: data = data.to(device) logits = model(data) # infer head size/mode once if head_dim is None: head_dim = logits.size(1) if (logits.ndim == 2) else 1 binary_mode = (head_dim == 1) # ----- target normalization for loss/metrics ----- if binary_mode: # BCE-style targets: float {0,1}, allow (N,) or (N,1) target = target.to(device).float() y_true_batch = (target.view(-1) > 0.5).long().detach().cpu().numpy() else: # CE-style: class indices (N,) if target.ndim == 2: # handle one-hot inputs robustly target = target.argmax(dim=1) target = target.to(device).long() y_true_batch = target.view(-1).detach().cpu().numpy() # ----- choose loss (prefer training's loss_fn if provided) ----- local_loss_fn = loss_fn if local_loss_fn is None: # fallback: construct something reasonable matching the head local_loss_fn = build_loss(loss_type or 'auto', num_classes=head_dim, class_counts=None, label_smoothing=0.0, focal_gamma=2.0, focal_alpha=None, logit_adjust_tau=0.0) loss = local_loss_fn(logits, target) batch_size = data.size(0) total_loss += float(loss.item()) * batch_size total_samples += batch_size all_labels.extend(y_true_batch.tolist()) # ----- probabilities for metrics ----- if binary_mode: probs = torch.sigmoid(logits.view(-1)) prob_bucket.append(probs.detach().cpu().numpy()) else: probs = torch.softmax(logits, dim=1) prob_bucket.append(probs.detach().cpu().numpy()) # aggregate mean_loss = total_loss / max(1, total_samples) y_true = np.asarray(all_labels, dtype=int) if len(prob_bucket) == 0: # empty loader: synthesize empty array with correct rank if (num_classes or head_dim or 1) == 1: probs_np = np.empty((0,)) else: c = num_classes if num_classes is not None else (head_dim if head_dim is not None else 2) probs_np = np.empty((0, c)) else: probs_np = np.concatenate(prob_bucket, axis=0) # metrics (assumes _binary_metrics / _multiclass_metrics exist) if probs_np.ndim == 1: metrics = _binary_metrics(y_true, probs_np) else: metrics = _multiclass_metrics(y_true, probs_np) metrics["loss"] = float(mean_loss) metrics["epoch"] = int(epoch) metrics["Accuracy"] = metrics["accuracy"] return metrics, [probs_np, y_true.tolist()]
[docs] def test_model_core(model, loader, loader_name, epoch, loss_type): """ Core test loop returning both summary metrics and a row-per-image dataframe, compatible with binary & multiclass. """ from .utils import calculate_loss device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.eval().to(device) total_loss = 0.0 total_samples = 0 all_labels = [] probs_rows = [] filenames = [] with torch.no_grad(): for data, target, batch_filenames in loader: data = data.to(device) target = target.to(device) logits = model(data) batch_size = data.size(0) loss = calculate_loss(logits, target, prefer_focal=True) #loss = calculate_loss(logits, target, loss_type=loss_type) total_loss += float(loss.item()) * batch_size total_samples += batch_size # labels & filenames y_true = _to_numpy_labels(target) all_labels.extend(y_true) filenames.extend(list(batch_filenames)) # probs if logits.ndim == 1 or logits.size(-1) == 1: probs = torch.sigmoid(logits.view(-1)).detach().cpu().numpy() probs_rows.append(probs.reshape(-1, 1)) # keep 2D for uniform handling else: probs = torch.softmax(logits, dim=1).detach().cpu().numpy() probs_rows.append(probs) mean_loss = total_loss / max(1, total_samples) y_true = np.asarray(all_labels, dtype=int) prob_mat = np.vstack(probs_rows) if probs_rows else np.empty((0, 1)) C = prob_mat.shape[1] # metrics if C == 1: metrics = _binary_metrics(y_true, prob_mat.ravel()) else: metrics = _multiclass_metrics(y_true, prob_mat) metrics["loss"] = float(mean_loss) metrics["epoch"] = int(epoch) metrics["Accuracy"] = metrics["accuracy"] # Build per-file results dataframe df_dict = { "filename": filenames, "true_label": y_true.tolist(), "predicted_label": prob_mat.argmax(1).tolist() if C > 1 else (prob_mat.ravel() >= 0.5).astype(int).tolist(), } if C == 1: df_dict["class_1_probability"] = prob_mat.ravel().tolist() else: # add one column per class probs: prob_class_0, prob_class_1, ... for k in range(C): df_dict[f"prob_class_{k}"] = prob_mat[:, k].tolist() results_df = pd.DataFrame(df_dict) return metrics, (prob_mat if C > 1 else prob_mat.ravel()), y_true.tolist(), results_df
[docs] def test_model_performance(loaders, model, loader_name_list, epoch, loss_type): """ Wrapper kept for API compatibility with your caller. Returns (summary_metrics_dataframe, per_file_results_dataframe) """ start_time = time.time() data_dict, _, _, results_df = test_model_core( model=model, loader=loaders, loader_name=loader_name_list, epoch=epoch, loss_type=loss_type, ) # The old function returned a DataFrame in 'result'; emulate that: result_df = pd.DataFrame([data_dict]) return result_df, results_df
[docs] def train_test_model(settings): from .io import _copy_missclassified from .utils import pick_best_model, save_settings from .io import generate_loaders from .settings import get_train_test_model_settings settings = get_train_test_model_settings(settings) torch.cuda.empty_cache() torch.cuda.memory.empty_cache() gc.collect() src = settings['src'] channels_str = ''.join(settings['train_channels']) dst = os.path.join(src, 'model', settings['model_type'], channels_str, str(f"epochs_{settings['epochs']}")) os.makedirs(dst, exist_ok=True) settings['dst'] = dst num_classes = len(settings.get('classes', [])) if settings.get('classes') else 0 if num_classes <= 0: raise ValueError("No classes provided in settings['classes'].") if settings.get('loss_type') in (None, 'auto'): settings['loss_type'] = 'cross_entropy' if num_classes > 1 else 'binary_cross_entropy_with_logits' if settings['train']: if settings['train'] and settings['test']: save_settings(settings, name=f"train_test_{settings['model_type']}_{settings['epochs']}", show=True) elif settings['train'] is True: save_settings(settings, name=f"train_{settings['model_type']}_{settings['epochs']}", show=True) elif settings['test'] is True: save_settings(settings, name=f"test_{settings['model_type']}_{settings['epochs']}", show=True) model = None model_path = None if settings['train']: train, val, _ = generate_loaders( src, mode='train', image_size=settings['image_size'], batch_size=settings['batch_size'], classes=settings['classes'], n_jobs=settings['n_jobs'], validation_split=settings['val_split'], pin_memory=settings['pin_memory'], normalize=settings['normalize'], channels=settings['train_channels'], augment=settings['augment'], verbose=settings['verbose'] ) model, model_path = train_model( src=src, dst=settings['dst'], model_type=settings['model_type'], train_loaders=train, epochs=settings['epochs'], learning_rate=settings['learning_rate'], init_weights=settings['init_weights'], weight_decay=settings['weight_decay'], amsgrad=settings['amsgrad'], optimizer_type=settings['optimizer_type'], use_checkpoint=settings['use_checkpoint'], dropout_rate=settings['dropout_rate'], n_jobs=settings['n_jobs'], val_loaders=val, test_loaders=None, intermedeate_save=settings['intermedeate_save'], schedule=settings['schedule'], loss_type=settings['loss_type'], gradient_accumulation=settings['gradient_accumulation'], gradient_accumulation_steps=settings['gradient_accumulation_steps'], channels=settings['train_channels'], num_classes=num_classes, early_stopping_patience=settings.get('early_stopping_patience', 0), ) if settings['test']: test, _, _ = generate_loaders( src, mode='test', image_size=settings['image_size'], batch_size=settings['batch_size'], classes=settings['classes'], n_jobs=settings['n_jobs'], validation_split=0.0, pin_memory=settings['pin_memory'], normalize=settings['normalize'], channels=settings['train_channels'], augment=False, verbose=settings['verbose'] ) if model is None: model_path = pick_best_model(src + '/model') print(f'Best model: {model_path}') model = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=False) model_fldr = dst time_now = datetime.date.today().strftime('%y%m%d') result_loc = f"{model_fldr}/{settings['model_type']}_time_{time_now}_test_result.csv" acc_loc = f"{model_fldr}/{settings['model_type']}_time_{time_now}_test_acc.csv" print(f'Results will be saved in: {result_loc}') result, accuracy = test_model_performance(loaders=test, model=model, loader_name_list='test', epoch=1, loss_type=settings['loss_type']) result.to_csv(result_loc, index=True, header=True, mode='w') accuracy.to_csv(acc_loc, index=True, header=True, mode='w') _copy_missclassified(accuracy) torch.cuda.empty_cache() torch.cuda.memory.empty_cache() gc.collect() if settings['train']: return model_path if settings['test']: return result_loc
[docs] def train_test_model_v1(settings): from .io import _copy_missclassified from .utils import pick_best_model, save_settings from .io import generate_loaders from .settings import get_train_test_model_settings settings = get_train_test_model_settings(settings) torch.cuda.empty_cache() torch.cuda.memory.empty_cache() gc.collect() src = settings['src'] channels_str = ''.join(settings['train_channels']) dst = os.path.join(src,'model', settings['model_type'], channels_str, str(f"epochs_{settings['epochs']}")) os.makedirs(dst, exist_ok=True) settings['dst'] = dst # NEW: number of classes num_classes = len(settings.get('classes', [])) if settings.get('classes') else 0 if num_classes <= 0: raise ValueError("No classes provided in settings['classes'].") # NEW: pick loss automatically if desired if settings.get('loss_type') in (None, 'auto'): # Multiclass => cross entropy, Binary (1 class folder would be odd; 2 => CE works well) settings['loss_type'] = 'cross_entropy' if num_classes > 1 else 'binary_cross_entropy_with_logits' if settings['train']: if settings['train'] and settings['test']: save_settings(settings, name=f"train_test_{settings['model_type']}_{settings['epochs']}", show=True) elif settings['train'] is True: save_settings(settings, name=f"train_{settings['model_type']}_{settings['epochs']}", show=True) elif settings['test'] is True: save_settings(settings, name=f"test_{settings['model_type']}_{settings['epochs']}", show=True) if settings['train']: train, val, _ = generate_loaders( src, mode='train', image_size=settings['image_size'], batch_size=settings['batch_size'], classes=settings['classes'], n_jobs=settings['n_jobs'], validation_split=settings['val_split'], pin_memory=settings['pin_memory'], normalize=settings['normalize'], channels=settings['train_channels'], augment=settings['augment'], verbose=settings['verbose'] ) model, model_path = train_model( src=src, dst=settings['dst'], model_type=settings['model_type'], train_loaders=train, epochs=settings['epochs'], learning_rate=settings['learning_rate'], init_weights=settings['init_weights'], weight_decay=settings['weight_decay'], amsgrad=settings['amsgrad'], optimizer_type=settings['optimizer_type'], use_checkpoint=settings['use_checkpoint'], dropout_rate=settings['dropout_rate'], n_jobs=settings['n_jobs'], val_loaders=val, test_loaders=None, intermedeate_save=settings['intermedeate_save'], schedule=settings['schedule'], loss_type=settings['loss_type'], gradient_accumulation=settings['gradient_accumulation'], gradient_accumulation_steps=settings['gradient_accumulation_steps'], channels=settings['train_channels'], num_classes=num_classes # <-- NEW ) if settings['test']: test, _, _ = generate_loaders( src, mode='test', image_size=settings['image_size'], batch_size=settings['batch_size'], classes=settings['classes'], n_jobs=settings['n_jobs'], validation_split=0.0, pin_memory=settings['pin_memory'], normalize=settings['normalize'], channels=settings['train_channels'], augment=False, verbose=settings['verbose'] ) if model is None: model_path = pick_best_model(src+'/model') print(f'Best model: {model_path}') model = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=False) #model = torch.load(model_path, map_location=device, weights_only=False) model_fldr = dst time_now = datetime.date.today().strftime('%y%m%d') result_loc = f"{model_fldr}/{settings['model_type']}_time_{time_now}_test_result.csv" acc_loc = f"{model_fldr}/{settings['model_type']}_time_{time_now}_test_acc.csv" print(f'Results wil be saved in: {result_loc}') result, accuracy = test_model_performance(loaders=test, model=model, loader_name_list='test', epoch=1, loss_type=settings['loss_type']) result.to_csv(result_loc, index=True, header=True, mode='w') accuracy.to_csv(acc_loc, index=True, header=True, mode='w') _copy_missclassified(accuracy) torch.cuda.empty_cache() torch.cuda.memory.empty_cache() gc.collect() if settings['train']: return model_path if settings['test']: return result_loc
[docs] def train_model(src,dst, model_type, train_loaders, epochs=100, learning_rate=0.0001, weight_decay=0.05, amsgrad=False, optimizer_type='adamw', use_checkpoint=False, dropout_rate=0, n_jobs=20, val_loaders=None, test_loaders=None, init_weights='imagenet', intermedeate_save=None, chan_dict=None, schedule=None, loss_type='auto', gradient_accumulation=False, gradient_accumulation_steps=4, channels=['r', 'g', 'b'], verbose=False, num_classes=2, # add early stopping parameters early_stopping_patience=0, # 0 = disabled; e.g. 20 = stop after 20 epochs without val improvement ): """ Trains a model (supports 2-class and >2-class via CrossEntropy). New parameters: early_stopping_patience: number of epochs with no val improvement before stopping. Set to 0 to disable (original behavior). """ from .io import _save_model, _save_progress from .utils import choose_model, suggest_training_changes, build_loss, estimate_class_counts print(f'Train batches:{len(train_loaders)}, Validation batches:{len(val_loaders) if val_loaders else 0}') if test_loaders is not None: print(f'Test batches:{len(test_loaders)}') use_cuda = torch.cuda.is_available() device = torch.device("cuda" if use_cuda else "cpu") print(f'Using {device} for Torch') head_dim = max(1, int(num_classes)) #counts = estimate_class_counts(train_loaders, head_dim) if head_dim >= 2 else None train_data_dir = os.path.join(src, 'train') if os.path.isdir(train_data_dir): classes = sorted([d for d in os.listdir(train_data_dir) if os.path.isdir(os.path.join(train_data_dir, d)) and not d.startswith('.')]) else: classes = None counts = estimate_class_counts(train_loaders, head_dim, src=train_data_dir, classes=classes) if (head_dim >= 2 and classes) else None loss_fn = build_loss( loss_type=loss_type, num_classes=head_dim, class_counts=counts, label_smoothing=0.1, focal_gamma=2.0, focal_alpha=None, logit_adjust_tau=1.0 ) model = choose_model(model_type, device, init_weights, dropout_rate, use_checkpoint, verbose=verbose, num_classes=head_dim) if model is None: print(f'Model {model_type} not found') return print(f'Loading Model to {device}...') model.to(device) if optimizer_type == 'adamw': optimizer = AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), weight_decay=weight_decay, amsgrad=amsgrad) elif optimizer_type == 'adagrad': optimizer = Adagrad(model.parameters(), lr=learning_rate, eps=1e-8, weight_decay=weight_decay) else: raise ValueError(f"Unknown optimizer_type: {optimizer_type}") if schedule == 'step_lr': scheduler = StepLR(optimizer, step_size=max(1, int(epochs / 5)), gamma=0.75) elif schedule == 'reduce_lr_on_plateau': scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, mode='min', factor=0.1, patience=10, verbose=True) elif schedule == 'cosine': # FIX: new option — cosine annealing scheduler = CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-7) else: scheduler = None accumulated_train_dicts, accumulated_val_dicts, accumulated_test_dicts = [], [], [] # track the best validation accuracy and corresponding model path best_val_acc = -1.0 best_model_path = None epochs_without_improvement = 0 print('Training ...') for epoch in range(1, epochs + 1): model.train() start_time = time.time() if gradient_accumulation: optimizer.zero_grad(set_to_none=True) # record total number of batches so we can detect leftover gradients n_batches = len(train_loaders) for batch_idx, (data, target, filenames) in enumerate(train_loaders, start=1): data = data.to(device) logits = model(data) is_multiclass = (logits.ndim == 2 and logits.size(1) >= 2) if is_multiclass: if target.ndim == 2: target = target.argmax(dim=1) target = target.to(device).long() if not (logits.ndim == 2 and logits.size(1) == head_dim): raise RuntimeError( f"Expected logits (N,{head_dim}) for CE, got {tuple(logits.shape)}") else: target = target.to(device).float() loss = loss_fn(logits, target) if gradient_accumulation: loss = loss / gradient_accumulation_steps loss.backward() if (not gradient_accumulation) or (batch_idx % gradient_accumulation_steps == 0): optimizer.step() optimizer.zero_grad(set_to_none=True) # flush leftover accumulated gradients at the end of the epoch if gradient_accumulation and (n_batches % gradient_accumulation_steps != 0): optimizer.step() optimizer.zero_grad(set_to_none=True) # Epoch end: evaluate train_time = time.time() - start_time loop_loss_type = 'ce' if head_dim >= 2 else 'bce' train_dict, _ = evaluate_model_performance( model, train_loaders, epoch, loss_type=loop_loss_type, loss_fn=loss_fn, num_classes=head_dim ) train_dict['train_time'] = train_time accumulated_train_dicts.append(train_dict) # initialize val_dict to None so the variable always exists for _save_model val_dict = None if val_loaders is not None and len(val_loaders) > 0: val_dict, _ = evaluate_model_performance( model, val_loaders, epoch, loss_type=loop_loss_type, loss_fn=loss_fn, num_classes=head_dim ) accumulated_val_dicts.append(val_dict) if schedule == 'reduce_lr_on_plateau': scheduler.step(val_dict['loss']) print(f"Progress: {train_dict.get('epoch', epoch)}/{epochs}, operation_type: Training, " f"Train Loss: {train_dict.get('loss', float('nan')):.3f}, " f"Val Loss: {val_dict.get('loss', float('nan')):.3f}, " f"Train acc.: {train_dict.get('accuracy', float('nan')):.3f}, " f"Val acc.: {val_dict.get('accuracy', float('nan')):.3f}, " f"Train F1(macro): {train_dict.get('f1_macro', float('nan')):.3f}, " f"Val F1(macro): {val_dict.get('f1_macro', float('nan')):.3f}") # track best validation accuracy for early stopping and best-model selection current_val_acc = val_dict.get('accuracy', 0.0) if current_val_acc > best_val_acc: best_val_acc = current_val_acc epochs_without_improvement = 0 else: epochs_without_improvement += 1 else: print(f"Progress: {train_dict.get('epoch', epoch)}/{epochs}, operation_type: Training, " f"Train Loss: {train_dict.get('loss', float('nan')):.3f}, " f"Train acc.: {train_dict.get('accuracy', float('nan')):.3f}, " f"Train F1(macro): {train_dict.get('f1_macro', float('nan')):.3f}") if scheduler and schedule in ('step_lr', 'cosine'): # FIX: also step cosine scheduler here scheduler.step() # Save rolling CSVs if accumulated_train_dicts and accumulated_val_dicts: _save_progress(dst, pd.DataFrame(accumulated_train_dicts), pd.DataFrame(accumulated_val_dicts)) accumulated_train_dicts, accumulated_val_dicts = [], [] elif accumulated_train_dicts: _save_progress(dst, pd.DataFrame(accumulated_train_dicts), None) accumulated_train_dicts = [] elif accumulated_test_dicts: _save_progress(dst, pd.DataFrame(accumulated_test_dicts), None) accumulated_test_dicts = [] # pass val_dict to _save_model so checkpoint decisions use validation accuracy model_path = _save_model(model, model_type, train_dict, dst, epoch, epochs, intermedeate_save=[0.99, 0.98, 0.95, 0.94], channels=channels, val_dict=val_dict) # FIX: new argument # track the best model path based on validation accuracy if model_path is not None and val_dict is not None: if val_dict.get('accuracy', 0.0) >= best_val_acc: best_model_path = model_path elif model_path is not None and best_model_path is None: best_model_path = model_path # early stopping — break if val hasn't improved for `patience` epochs if early_stopping_patience > 0 and epochs_without_improvement >= early_stopping_patience: print(f"\nEarly stopping at epoch {epoch}: no val improvement for " f"{early_stopping_patience} epochs. Best val acc: {best_val_acc:.4f}") break # Periodic suggestions (every 25 epochs and final epoch) if (epoch % 25 == 0) or (epoch == epochs): try: report = suggest_training_changes(dst) print("== Summary ==") for k, v in report["summary"].items(): print(f"{k}: {v}") print("\n== Flags ==") print(", ".join(report["flags"]) or "none") print("\n== Suggestions ==") for i, s in enumerate(report["suggestions"], 1): print(f"{i}. {s}") except Exception as e: print(f"[suggest_training_changes] Skipped at epoch {epoch}: {e}") # return best_model_path if available, otherwise fall back to last model_path final_path = best_model_path if best_model_path is not None else model_path return model, final_path
[docs] def train_model_v1(src,dst, model_type, train_loaders, epochs=100, learning_rate=0.0001, weight_decay=0.05, amsgrad=False, optimizer_type='adamw', use_checkpoint=False, dropout_rate=0, n_jobs=20, val_loaders=None, test_loaders=None, init_weights='imagenet', intermedeate_save=None, chan_dict=None, schedule=None, loss_type='auto', gradient_accumulation=False, gradient_accumulation_steps=4, channels=['r','g','b'], verbose=False, num_classes=2): """ Trains a model (supports 2-class and >2-class via CrossEntropy; BCE only for true single-logit binary). """ import pandas as pd # ensure pd is available for _save_progress from .io import _save_model, _save_progress from .utils import choose_model, suggest_training_changes, build_loss, estimate_class_counts print(f'Train batches:{len(train_loaders)}, Validation batches:{len(val_loaders) if val_loaders else 0}') if test_loaders is not None: print(f'Test batches:{len(test_loaders)}') use_cuda = torch.cuda.is_available() device = torch.device("cuda" if use_cuda else "cpu") print(f'Using {device} for Torch') # Head dimension equals declared num_classes (2 or >2 => softmax head) head_dim = max(1, int(num_classes)) # optional: get global label stats for weighting / logit adjustment counts = estimate_class_counts(train_loaders, head_dim) if head_dim >= 2 else None loss_fn = build_loss( loss_type=loss_type, # 'auto' | 'ce' | 'ce_smooth' | 'ce_weighted' | 'focal_ce' | 'bce' | 'focal_bce' | 'logit_adjust_ce' | 'asl' num_classes=head_dim, class_counts=counts, # required for ce_weighted / logit_adjust_ce label_smoothing=0.1, # only used by ce_smooth focal_gamma=2.0, focal_alpha=None, # e.g. tensor of per-class weights or scalar logit_adjust_tau=1.0 # >0 enables prior-aware CE ) model = choose_model(model_type, device, init_weights, dropout_rate, use_checkpoint, verbose=verbose, num_classes=head_dim) if model is None: print(f'Model {model_type} not found') return print(f'Loading Model to {device}...') model.to(device) if optimizer_type == 'adamw': optimizer = AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), weight_decay=weight_decay, amsgrad=amsgrad) elif optimizer_type == 'adagrad': optimizer = Adagrad(model.parameters(), lr=learning_rate, eps=1e-8, weight_decay=weight_decay) else: raise ValueError(f"Unknown optimizer_type: {optimizer_type}") if schedule == 'step_lr': scheduler = StepLR(optimizer, step_size=max(1, int(epochs/5)), gamma=0.75) elif schedule == 'reduce_lr_on_plateau': scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',factor=0.1, patience=10, verbose=True) else: scheduler = None accumulated_train_dicts, accumulated_val_dicts, accumulated_test_dicts = [], [], [] print('Training ...') for epoch in range(1, epochs+1): model.train() start_time = time.time() if gradient_accumulation: optimizer.zero_grad(set_to_none=True) for batch_idx, (data, target, filenames) in enumerate(train_loaders, start=1): data = data.to(device) logits = model(data) # Decide task type from head shape is_multiclass = (logits.ndim == 2 and logits.size(1) >= 2) # --- Normalize targets to match the chosen head/loss --- if is_multiclass: # If labels are one-hot (N,C), convert to indices (N,) if target.ndim == 2: target = target.argmax(dim=1) target = target.to(device).long() # CE expects Long indices # shape check if not (logits.ndim == 2 and logits.size(1) == head_dim): raise RuntimeError(f"Expected logits (N,{head_dim}) for CE, got {tuple(logits.shape)}") else: target = target.to(device).float() # BCE expects float {0,1} loss = loss_fn(logits, target) if gradient_accumulation: loss = loss / gradient_accumulation_steps loss.backward() if (not gradient_accumulation) or (batch_idx % gradient_accumulation_steps == 0): optimizer.step() optimizer.zero_grad(set_to_none=True) # Epoch end: evaluate train_time = time.time() - start_time loop_loss_type = 'ce' if head_dim >= 2 else 'bce' # tag for logging only train_dict, _ = evaluate_model_performance( model, train_loaders, epoch, loss_type=loop_loss_type, loss_fn=loss_fn, num_classes=head_dim ) train_dict['train_time'] = train_time accumulated_train_dicts.append(train_dict) if val_loaders is not None and len(val_loaders) > 0: val_dict, _ = evaluate_model_performance( model, val_loaders, epoch, loss_type=loop_loss_type, loss_fn=loss_fn, num_classes=head_dim ) accumulated_val_dicts.append(val_dict) if schedule == 'reduce_lr_on_plateau': scheduler.step(val_dict['loss']) print(f"Progress: {train_dict.get('epoch', epoch)}/{epochs}, operation_type: Training, " f"Train Loss: {train_dict.get('loss', float('nan')):.3f}, " f"Val Loss: {val_dict.get('loss', float('nan')):.3f}, " f"Train acc.: {train_dict.get('accuracy', float('nan')):.3f}, " f"Val acc.: {val_dict.get('accuracy', float('nan')):.3f}, " f"Train F1(macro): {train_dict.get('f1_macro', float('nan')):.3f}, " f"Val F1(macro): {val_dict.get('f1_macro', float('nan')):.3f}") else: print(f"Progress: {train_dict.get('epoch', epoch)}/{epochs}, operation_type: Training, " f"Train Loss: {train_dict.get('loss', float('nan')):.3f}, " f"Train acc.: {train_dict.get('accuracy', float('nan')):.3f}, " f"Train F1(macro): {train_dict.get('f1_macro', float('nan')):.3f}") if scheduler and schedule == 'step_lr': scheduler.step() # Save rolling CSVs if accumulated_train_dicts and accumulated_val_dicts: _save_progress(dst, pd.DataFrame(accumulated_train_dicts), pd.DataFrame(accumulated_val_dicts)) accumulated_train_dicts, accumulated_val_dicts = [], [] elif accumulated_train_dicts: _save_progress(dst, pd.DataFrame(accumulated_train_dicts), None) accumulated_train_dicts = [] elif accumulated_test_dicts: _save_progress(dst, pd.DataFrame(accumulated_test_dicts), None) accumulated_test_dicts = [] # Save checkpoints model_path = _save_model(model, model_type, train_dict, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94], channels=channels) # ---- Periodic suggestions (every 25 epochs and final epoch) ---- if (epoch % 25 == 0) or (epoch == epochs): try: report = suggest_training_changes(dst) print("== Summary ==") for k, v in report["summary"].items(): print(f"{k}: {v}") print("\n== Flags ==") print(", ".join(report["flags"]) or "none") print("\n== Suggestions ==") for i, s in enumerate(report["suggestions"], 1): print(f"{i}. {s}") except Exception as e: print(f"[suggest_training_changes] Skipped at epoch {epoch}: {e}") return model, model_path
[docs] def generate_activation_map(settings): from .utils import SaliencyMapGenerator, GradCAMGenerator, SelectChannels, activation_maps_to_database, activation_correlations_to_database from .utils import print_progress, save_settings, calculate_activation_correlations from .io import TarImageDataset from .settings import get_default_generate_activation_map_settings torch.cuda.empty_cache() gc.collect() plt.clf() use_cuda = torch.cuda.is_available() device = torch.device("cuda" if use_cuda else "cpu") source_folder = os.path.dirname(os.path.dirname(settings['dataset'])) settings['src'] = source_folder settings = get_default_generate_activation_map_settings(settings) save_settings(settings, name=f"{settings['cam_type']}_settings", show=False) if settings['model_type'] == 'maxvit' and settings['target_layer'] == None: settings['target_layer'] = 'base_model.blocks.3.layers.1.layers.MBconv.layers.conv_b' if settings['cam_type'] in ['saliency_image', 'saliency_channel']: settings['target_layer'] = None # Set number of jobs for loading n_jobs = settings['n_jobs'] if n_jobs is None: n_jobs = max(1, cpu_count() - 4) # Set transforms for images transform = transforms.Compose([ transforms.ToTensor(), transforms.CenterCrop(size=(settings['image_size'], settings['image_size'])), transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) if settings['normalize_input'] else None, SelectChannels(settings['channels']) ]) # Handle dataset path if not os.path.exists(settings['dataset']): print(f"Dataset not found at {settings['dataset']}") return # Load the model #model = torch.load(settings['model_path']) model = torch.load(settings['model_path'], map_location=device, weights_only=False) model.to(device) model.eval() # Create directory for saving activation maps if it does not exist dataset_dir = os.path.dirname(settings['dataset']) dataset_name = os.path.splitext(os.path.basename(settings['dataset']))[0] save_dir = os.path.join(dataset_dir, dataset_name, settings['cam_type']) batch_grid_fldr = os.path.join(save_dir, 'batch_grids') if settings['save']: os.makedirs(save_dir, exist_ok=True) print(f"Activation maps will be saved in: {save_dir}") if settings['plot']: os.makedirs(batch_grid_fldr, exist_ok=True) print(f"Batch grid maps will be saved in: {batch_grid_fldr}") # Load dataset dataset = TarImageDataset(settings['dataset'], transform=transform) data_loader = DataLoader(dataset, batch_size=settings['batch_size'], shuffle=settings['shuffle'], num_workers=n_jobs, pin_memory=True) # Initialize generator based on cam_type if settings['cam_type'] in ['gradcam', 'gradcam_pp']: cam_generator = GradCAMGenerator(model, target_layer=settings['target_layer'], cam_type=settings['cam_type']) elif settings['cam_type'] in ['saliency_image', 'saliency_channel']: cam_generator = SaliencyMapGenerator(model) time_ls = [] for batch_idx, (inputs, filenames) in enumerate(data_loader): start = time.time() img_paths = [] inputs = inputs.to(device) # Compute activation maps and predictions if settings['cam_type'] in ['gradcam', 'gradcam_pp']: activation_maps, predicted_classes = cam_generator.compute_gradcam_and_predictions(inputs) elif settings['cam_type'] in ['saliency_image', 'saliency_channel']: activation_maps, predicted_classes = cam_generator.compute_saliency_and_predictions(inputs) # Move activation maps to CPU activation_maps = activation_maps.cpu() # Sum saliency maps for 'saliency_image' type if settings['cam_type'] == 'saliency_image': summed_activation_maps = [] for i in range(activation_maps.size(0)): activation_map = activation_maps[i] #print(f"1: {activation_map.shape}") activation_map_sum = activation_map.sum(dim=0, keepdim=False) #print(f"2: {activation_map.shape}") activation_map_sum = np.squeeze(activation_map_sum, axis=0) #print(f"3: {activation_map_sum.shape}") summed_activation_maps.append(activation_map_sum) activation_maps = torch.stack(summed_activation_maps) # For plotting if settings['plot']: fig = cam_generator.plot_activation_grid(inputs, activation_maps, predicted_classes, overlay=settings['overlay'], normalize=settings['normalize']) pdf_save_path = os.path.join(batch_grid_fldr,f"batch_{batch_idx}_grid.pdf") fig.savefig(pdf_save_path, format='pdf') print(f"Saved batch grid to {pdf_save_path}") #plt.show() display(fig) for i in range(inputs.size(0)): activation_map = activation_maps[i].detach().numpy() if settings['cam_type'] in ['saliency_image', 'gradcam', 'gradcam_pp']: #activation_map = activation_map.sum(axis=0) activation_map = (activation_map - activation_map.min()) / (activation_map.max() - activation_map.min()) activation_map = (activation_map * 255).astype(np.uint8) activation_image = Image.fromarray(activation_map, mode='L') elif settings['cam_type'] == 'saliency_channel': # Handle each channel separately and save as RGB rgb_activation_map = np.zeros((activation_map.shape[1], activation_map.shape[2], 3), dtype=np.uint8) for c in range(min(activation_map.shape[0], 3)): # Limit to 3 channels for RGB channel_map = activation_map[c] channel_map = (channel_map - channel_map.min()) / (channel_map.max() - channel_map.min()) rgb_activation_map[:, :, c] = (channel_map * 255).astype(np.uint8) activation_image = Image.fromarray(rgb_activation_map, mode='RGB') # Save activation maps class_pred = predicted_classes[i].item() parts = filenames[i].split('_') plate = parts[0] well = parts[1] save_class_dir = os.path.join(save_dir, f'class_{class_pred}', str(plate), str(well)) os.makedirs(save_class_dir, exist_ok=True) save_path = os.path.join(save_class_dir, f'{filenames[i]}') if settings['save']: activation_image.save(save_path) img_paths.append(save_path) if settings['save']: activation_maps_to_database(img_paths, source_folder, settings) if settings['correlation']: df = calculate_activation_correlations(inputs, activation_maps, filenames, manders_thresholds=settings['manders_thresholds']) if settings['plot']: display(df) if settings['save']: activation_correlations_to_database(df, img_paths, source_folder, settings) stop = time.time() duration = stop - start time_ls.append(duration) files_processed = batch_idx * settings['batch_size'] files_to_process = len(data_loader) * settings['batch_size'] print_progress(files_processed, files_to_process, n_jobs=n_jobs, time_ls=time_ls, batch_size=settings['batch_size'], operation_type="Generating Activation Maps") torch.cuda.empty_cache() gc.collect() print("Activation map generation complete.")
[docs] def visualize_classes(model, dtype, class_names, **kwargs): from .utils import class_visualization for target_y in range(2): # Assuming binary classification print(f"Visualizing class: {class_names[target_y]}") visualization = class_visualization(target_y, model, dtype, **kwargs) plt.imshow(visualization) plt.title(f"Class {class_names[target_y]} Visualization") plt.axis('off') plt.show()
[docs] def visualize_integrated_gradients(src, model_path, target_label_idx=0, image_size=224, channels=[1,2,3], normalize=True, save_integrated_grads=False, save_dir='integrated_grads'): from .utils import IntegratedGradients, preprocess_image use_cuda = torch.cuda.is_available() device = torch.device("cuda" if use_cuda else "cpu") #model = torch.load(model_path) model = torch.load(model_path, map_location=device, weights_only=False) model.to(device) integrated_gradients = IntegratedGradients(model) if save_integrated_grads and not os.path.exists(save_dir): os.makedirs(save_dir) images = [] filenames = [] for file in os.listdir(src): if not file.endswith('.png'): continue image_path = os.path.join(src, file) image, input_tensor = preprocess_image(image_path, normalize=normalize, image_size=image_size, channels=channels) images.append(image) filenames.append(file) input_tensor = input_tensor.to(device) integrated_grads = integrated_gradients.generate_integrated_gradients(input_tensor, target_label_idx) integrated_grads = np.mean(integrated_grads, axis=1).squeeze() fig, ax = plt.subplots(1, 3, figsize=(20, 5)) ax[0].imshow(image) ax[0].axis('off') ax[0].set_title("Original Image") ax[1].imshow(integrated_grads, cmap='hot') ax[1].axis('off') ax[1].set_title("Integrated Gradients") overlay = np.array(image) overlay = overlay / overlay.max() integrated_grads_rgb = np.stack([integrated_grads] * 3, axis=-1) # Convert saliency map to RGB overlay = (overlay * 0.5 + integrated_grads_rgb * 0.5).clip(0, 1) ax[2].imshow(overlay) ax[2].axis('off') ax[2].set_title("Overlay") plt.show() if save_integrated_grads: os.makedirs(save_dir, exist_ok=True) integrated_grads_image = Image.fromarray((integrated_grads * 255).astype(np.uint8)) integrated_grads_image.save(os.path.join(save_dir, f'integrated_grads_{file}'))
[docs] class SmoothGrad: def __init__(self, model, n_samples=50, stdev_spread=0.15):
[docs] self.model = model
[docs] self.n_samples = n_samples
[docs] self.stdev_spread = stdev_spread
[docs] def compute_smooth_grad(self, input_tensor, target_class): self.model.eval() stdev = self.stdev_spread * (input_tensor.max() - input_tensor.min()) total_gradients = torch.zeros_like(input_tensor) for i in range(self.n_samples): noise = torch.normal(mean=0, std=stdev, size=input_tensor.shape).to(input_tensor.device) noisy_input = input_tensor + noise noisy_input.requires_grad_() output = self.model(noisy_input) self.model.zero_grad() output[0, target_class].backward() total_gradients += noisy_input.grad avg_gradients = total_gradients / self.n_samples return avg_gradients.abs()
[docs] def visualize_smooth_grad(src, model_path, target_label_idx, image_size=224, channels=[1,2,3], normalize=True, save_smooth_grad=False, save_dir='smooth_grad'): from .utils import preprocess_image use_cuda = torch.cuda.is_available() device = torch.device("cuda" if use_cuda else "cpu") #model = torch.load(model_path) model = torch.load(model_path, map_location=device, weights_only=False) model.to(device) smooth_grad = SmoothGrad(model) if save_smooth_grad and not os.path.exists(save_dir): os.makedirs(save_dir) images = [] filenames = [] for file in os.listdir(src): if not file.endswith('.png'): continue image_path = os.path.join(src, file) image, input_tensor = preprocess_image(image_path, normalize=normalize, image_size=image_size, channels=channels) images.append(image) filenames.append(file) input_tensor = input_tensor.to(device) smooth_grad_map = smooth_grad.compute_smooth_grad(input_tensor, target_label_idx) smooth_grad_map = np.mean(smooth_grad_map.cpu().data.numpy(), axis=1).squeeze() fig, ax = plt.subplots(1, 3, figsize=(20, 5)) ax[0].imshow(image) ax[0].axis('off') ax[0].set_title("Original Image") ax[1].imshow(smooth_grad_map, cmap='hot') ax[1].axis('off') ax[1].set_title("SmoothGrad") overlay = np.array(image) overlay = overlay / overlay.max() smooth_grad_map_rgb = np.stack([smooth_grad_map] * 3, axis=-1) # Convert smooth grad map to RGB overlay = (overlay * 0.5 + smooth_grad_map_rgb * 0.5).clip(0, 1) ax[2].imshow(overlay) ax[2].axis('off') ax[2].set_title("Overlay") plt.show() if save_smooth_grad: os.makedirs(save_dir, exist_ok=True) smooth_grad_image = Image.fromarray((smooth_grad_map * 255).astype(np.uint8)) smooth_grad_image.save(os.path.join(save_dir, f'smooth_grad_{file}'))
[docs] def save_top_class_examples(df, tar_path, dst, n=20, classes=None): """ Extract the N most confident images per class from the tar and save them into class-labelled subfolders under dst. For binary classification (classes=[0, 1]): - class_0/ ← the 20 images with the LOWEST pred (closest to 0) - class_1/ ← the 20 images with the HIGHEST pred (closest to 1) Parameters ---------- df : pd.DataFrame Must contain columns 'path' (tar member name) and 'pred' (probability). tar_path : str Path to the tar archive that holds the images. dst : str Root folder where class subfolders will be created. n : int Number of top images to keep per class. classes : list or None Explicit class labels. If None, defaults to binary [0, 1]. """ import os, tarfile from io import BytesIO from PIL import Image # -- default to binary labels if the caller doesn't specify -- if classes is None: classes = [0, 1] # -- collect the tar member names we need to extract -- paths_to_extract = set() for cls in classes: # For binary: class 0 → lowest pred, class 1 → highest pred # For multiclass this would need per-class probability columns; # keeping it binary-friendly for now. if cls == 0: # sort ascending: values closest to 0 come first top = df.nsmallest(n, 'pred') else: # sort descending: values closest to 1 come first top = df.nlargest(n, 'pred') # create the class subfolder cls_dir = os.path.join(dst, f'class_{cls}') os.makedirs(cls_dir, exist_ok=True) # remember which member names belong to this class for _, row in top.iterrows(): paths_to_extract.add(row['path']) # -- build a lookup: tar member name → list of (class, dest_path) -- # (an image could theoretically appear in both extremes for a # degenerate model, so we use a list) member_destinations = {} for cls in classes: if cls == 0: top = df.nsmallest(n, 'pred') else: top = df.nlargest(n, 'pred') cls_dir = os.path.join(dst, f'class_{cls}') for _, row in top.iterrows(): fname = os.path.basename(row['path']) dest_file = os.path.join(cls_dir, fname) member_destinations.setdefault(row['path'], []).append(dest_file) # -- single pass through the tar: extract only the members we need -- extracted = 0 with tarfile.open(tar_path, 'r') as tar: for member in tar.getmembers(): if member.name in member_destinations: img_bytes = tar.extractfile(member).read() for dest_file in member_destinations[member.name]: with open(dest_file, 'wb') as f: f.write(img_bytes) extracted += 1 print(f"Saved {extracted} top-confidence example images to {dst}") return dst
[docs] def merge_predictions_into_db(df, db_path, table='png_list', pred_col='pred', class_col='cv_predictions'): """ Merge prediction scores back into the SQLite database. Matching on basename of png_path (DB) vs path (tar member name), since the tar stores relative member names while the DB stores full disk paths. """ import sqlite3, os if not os.path.exists(db_path): print(f"Database not found at {db_path}; skipping merge.") return df = df.copy() # -- derive basename from the tar member name for matching -- df['_join_key'] = df['path'].apply(os.path.basename) conn = sqlite3.connect(db_path) cur = conn.cursor() # -- add prediction columns if they don't already exist -- for col, col_type in [(pred_col, 'REAL'), (class_col, 'INTEGER')]: try: cur.execute(f"ALTER TABLE {table} ADD COLUMN {col} {col_type}") except sqlite3.OperationalError: pass # -- build a lookup dict keyed on basename for fast matching -- lookup = {} for _, row in df.iterrows(): lookup[row['_join_key']] = (row[pred_col], row[class_col]) # -- fetch png_path from the DB; match on its basename -- cur.execute(f"SELECT rowid, png_path FROM {table}") rows = cur.fetchall() updates = [] matched = 0 for rowid, png_path in rows: # png_path is the full disk path; compare basenames basename = os.path.basename(png_path) if png_path else None if basename in lookup: pred_val, class_val = lookup[basename] updates.append((pred_val, class_val, rowid)) matched += 1 cur.executemany( f"UPDATE {table} SET {pred_col} = ?, {class_col} = ? WHERE rowid = ?", updates ) conn.commit() conn.close() print(f"Merged predictions into {table}: {matched}/{len(rows)} rows matched") return matched
[docs] def deep_spacr_v1(settings={}): import os # local imports kept inside to avoid import cycles on some setups from .settings import deep_spacr_defaults from .io import generate_training_dataset, generate_dataset from .utils import save_settings # 1) expand defaults (now supports things like metadata_rules, annotation_columns, measurement_rules, etc.) settings = deep_spacr_defaults(settings) src_before = settings.get('src') # persist a snapshot of the config for reproducibility save_settings(settings, name='DL_model') # 2) dataset generation (train/test) if settings.get('train') or settings.get('test'): if settings.get('generate_training_dataset'): print("Generating train and test datasets ...") train_path, test_path = generate_training_dataset(settings) print(f'Generated Train set: {train_path}') print(f'Generated Test set: {test_path}') if train_path: settings['src'] = os.path.dirname(train_path) else: print("Training dataset generation failed; skipping model training step.") return # or raise RuntimeError if you prefer hard fail # point training to the newly created train folder by default settings['src'] = os.path.dirname(train_path) print("Training model ...") model_path = train_test_model(settings) settings['model_path'] = model_path # restore original src (so later steps like apply can use the user’s dataset if needed) settings['src'] = src_before # 4) apply model to dataset/tar if settings.get('apply_model_to_dataset'): tar_path = settings.get('tar_path') # if tar_path missing OR invalid, (re)generate it if not tar_path or not os.path.isabs(tar_path) or not os.path.exists(tar_path): print("tar_path not valid/found; generating dataset tar ...") tar_path = generate_dataset(settings) settings['tar_path'] = tar_path model_path = settings.get('model_path') if model_path and os.path.exists(model_path): apply_model_to_tar(settings) else: print(f"Model path {model_path} not found; skipping model application.")
[docs] def deep_spacr(settings={}): import os # local imports kept inside to avoid import cycles on some setups from .settings import deep_spacr_defaults from .io import generate_training_dataset, generate_dataset from .utils import save_settings # 1) expand defaults (now supports things like metadata_rules, annotation_columns, measurement_rules, etc.) settings = deep_spacr_defaults(settings) src_before = settings.get('src') # persist a snapshot of the config for reproducibility save_settings(settings, name='DL_model') # 2) dataset generation (train/test) if settings.get('train') or settings.get('test'): if settings.get('generate_training_dataset'): print("Generating train and test datasets ...") train_path, test_path = generate_training_dataset(settings) print(f'Generated Train set: {train_path}') print(f'Generated Test set: {test_path}') if train_path: settings['src'] = os.path.dirname(train_path) else: print("Training dataset generation failed; skipping model training step.") return # or raise RuntimeError if you prefer hard fail # point training to the newly created train folder by default settings['src'] = os.path.dirname(train_path) print("Training model ...") model_path = train_test_model(settings) settings['model_path'] = model_path # restore original src (so later steps like apply can use the user’s dataset if needed) settings['src'] = src_before # 4) apply model to dataset/tar if settings.get('apply_model_to_dataset'): tar_path = settings.get('tar_path') # if tar_path missing OR invalid, (re)generate it if not tar_path or not os.path.isabs(tar_path) or not os.path.exists(tar_path): print("tar_path not valid/found; generating dataset tar ...") tar_path = generate_dataset(settings) settings['tar_path'] = tar_path model_path = settings.get('model_path') if model_path and os.path.exists(model_path): # -- run inference and get the results DataFrame -- df = apply_model_to_tar(settings) # -- NEW: save the top-N most confident images per class -- # dst sits next to the tar file, in a subfolder called 'top_examples' examples_dst = os.path.join(os.path.dirname(tar_path), 'top_examples') n_examples = settings.get('n_top_examples', 20) save_top_class_examples(df, tar_path, examples_dst, n=n_examples) # -- NEW: merge predictions back into the measurements database -- # settings['src'] can be a string or list; use the first entry src_list = settings['src'] if isinstance(settings['src'], list) else [settings['src']] for src in src_list: db_path = os.path.join(src, 'measurements', 'measurements.db') merge_predictions_into_db(df, db_path) else: print(f"Model path {model_path} not found; skipping model application.")
[docs] def model_knowledge_transfer(teacher_paths, student_save_path, data_loader, device='cpu', student_model_name='maxvit_t', pretrained=True, dropout_rate=None, use_checkpoint=False, alpha=0.5, temperature=2.0, lr=1e-4, epochs=10): from .utils import TorchModel # Adjust filename to reflect knowledge-distillation if desired if student_save_path.endswith('.pth'): base, ext = os.path.splitext(student_save_path) else: base = student_save_path student_save_path = base + '_KD.pth' # -- 1. Load teacher models -- teachers = [] print("Loading teacher models:") for path in teacher_paths: print(f" Loading teacher: {path}") ckpt = torch.load(path, map_location=device) if isinstance(ckpt, TorchModel): teacher = ckpt.to(device) elif isinstance(ckpt, dict): # If it's a dict with 'model' inside # We might need to check if it has 'model_name', etc. # But let's keep it simple: same architecture as the student teacher = TorchModel( model_name=ckpt.get('model_name', student_model_name), pretrained=ckpt.get('pretrained', pretrained), dropout_rate=ckpt.get('dropout_rate', dropout_rate), use_checkpoint=ckpt.get('use_checkpoint', use_checkpoint) ).to(device) teacher.load_state_dict(ckpt['model']) else: raise ValueError(f"Unsupported checkpoint type at {path} (must be TorchModel or dict).") teacher.eval() # For consistent batchnorm, dropout teachers.append(teacher) # -- 2. Initialize the student TorchModel -- student_model = TorchModel( model_name=student_model_name, pretrained=pretrained, dropout_rate=dropout_rate, use_checkpoint=use_checkpoint ).to(device) # You could load a partial checkpoint into the student here if desired. # -- 3. Optimizer -- optimizer = optim.Adam(student_model.parameters(), lr=lr) # Distillation training loop for epoch in range(epochs): student_model.train() running_loss = 0.0 for images, labels in data_loader: images, labels = images.to(device), labels.to(device) # Forward pass student logits_s = student_model(images) # shape: (B, num_classes) logits_s_temp = logits_s / temperature # scale by T # Distillation from teachers with torch.no_grad(): # We'll average teacher probabilities teacher_probs_list = [] for tm in teachers: logits_t = tm(images) / temperature # convert to probabilities teacher_probs_list.append(F.softmax(logits_t, dim=1)) # average them teacher_probs_ensemble = torch.mean(torch.stack(teacher_probs_list), dim=0) # Student probabilities (log-softmax) student_log_probs = F.log_softmax(logits_s_temp, dim=1) # Distillation loss => KLDiv loss_distill = F.kl_div( student_log_probs, teacher_probs_ensemble, reduction='batchmean' ) * (temperature ** 2) # Real label loss => cross-entropy # We can compute this on the raw logits or scaled. Typically raw logits is standard: loss_ce = F.cross_entropy(logits_s, labels) # Weighted sum loss = alpha * loss_ce + (1 - alpha) * loss_distill optimizer.zero_grad() loss.backward() optimizer.step() running_loss += loss.item() avg_loss = running_loss / len(data_loader) print(f"Epoch [{epoch+1}/{epochs}] - Loss: {avg_loss:.4f}") # -- 4. Save final student as a TorchModel -- torch.save(student_model, student_save_path) print(f"Knowledge-distilled student saved to: {student_save_path}") return student_model
[docs] def model_fusion(model_paths,save_path,device='cpu',model_name='maxvit_t',pretrained=True,dropout_rate=None,use_checkpoint=False,aggregator='mean'): from .utils import TorchModel if save_path.endswith('.pth'): save_path_part1, ext = os.path.splitext(save_path) else: save_path_part1 = save_path save_path = save_path_part1 + f'_{aggregator}.pth' valid_aggregators = {'mean', 'geomean', 'median', 'sum', 'max', 'min'} if aggregator not in valid_aggregators: raise ValueError(f"Invalid aggregator '{aggregator}'. " f"Must be one of {valid_aggregators}.") # --- 1. Load the first checkpoint to figure out architecture & hyperparams --- print(f"Loading the first model from: {model_paths[0]} to derive architecture") first_ckpt = torch.load(model_paths[0], map_location=device) if isinstance(first_ckpt, dict): # It's a dict with state_dict + possibly metadata # Use any stored metadata if present model_name = first_ckpt.get('model_name', model_name) pretrained = first_ckpt.get('pretrained', pretrained) dropout_rate = first_ckpt.get('dropout_rate', dropout_rate) use_checkpoint = first_ckpt.get('use_checkpoint', use_checkpoint) # Initialize the fused model fused_model = TorchModel( model_name=model_name, pretrained=pretrained, dropout_rate=dropout_rate, use_checkpoint=use_checkpoint ).to(device) # We'll collect state dicts in a list state_dicts = [first_ckpt['model']] # the actual weights elif isinstance(first_ckpt, TorchModel): # The checkpoint is directly a TorchModel instance fused_model = first_ckpt.to(device) state_dicts = [fused_model.state_dict()] else: raise ValueError("Unsupported checkpoint format. Must be a dict or a TorchModel instance.") # --- 2. Load the rest of the checkpoints --- for path in model_paths[1:]: print(f"Loading model from: {path}") ckpt = torch.load(path, map_location=device) if isinstance(ckpt, dict): state_dicts.append(ckpt['model']) # Just the state dict portion elif isinstance(ckpt, TorchModel): state_dicts.append(ckpt.state_dict()) else: raise ValueError(f"Unsupported checkpoint format in {path} (must be dict or TorchModel).") # --- 3. Verify all state dicts have the same keys --- fused_sd = fused_model.state_dict() for sd in state_dicts: if fused_sd.keys() != sd.keys(): raise ValueError("All models must have identical architecture/state_dict keys.") # --- 4. Define aggregator logic --- def combine_tensors(tensor_list, mode='mean'): """Given a list of Tensors, combine them using the chosen aggregator.""" # stack along new dimension => shape (num_models, *tensor.shape) stacked = torch.stack(tensor_list, dim=0).float() if mode == 'mean': return stacked.mean(dim=0) elif mode == 'geomean': # geometric mean = exp(mean(log(tensor))) # caution: requires all > 0 return torch.exp(torch.log(stacked).mean(dim=0)) elif mode == 'median': return stacked.median(dim=0).values elif mode == 'sum': return stacked.sum(dim=0) elif mode == 'max': return stacked.max(dim=0).values elif mode == 'min': return stacked.min(dim=0).values else: raise ValueError(f"Unsupported aggregator: {mode}") # --- 5. Combine the weights --- for key in fused_sd.keys(): # gather all versions of this tensor all_tensors = [sd[key] for sd in state_dicts] fused_sd[key] = combine_tensors(all_tensors, mode=aggregator) # Load combined weights into the fused model fused_model.load_state_dict(fused_sd) # --- 6. Save the entire TorchModel object --- torch.save(fused_model, save_path) print(f"Fused model (aggregator='{aggregator}') saved as a full TorchModel to: {save_path}") return fused_model
[docs] def annotate_filter_vision(settings): from .utils import annotate_conditions, correct_metadata def filter_csv_by_png(csv_file): """ Filters a DataFrame by removing rows that match PNG filenames in a folder. Parameters: csv_file (str): Path to the CSV file. Returns: pd.DataFrame: Filtered DataFrame. """ # Split the path to identify the datasets folder and build the training folder path before_datasets, after_datasets = csv_file.split(os.sep + "datasets" + os.sep, 1) train_fldr = os.path.join(before_datasets, 'datasets', 'training', 'train') # Paths for train/nc and train/pc nc_folder = os.path.join(train_fldr, 'nc') pc_folder = os.path.join(train_fldr, 'pc') # Load the CSV file into a DataFrame df = pd.read_csv(csv_file) # Collect PNG filenames from train/nc and train/pc png_files = set() for folder in [nc_folder, pc_folder]: if os.path.exists(folder): # Ensure the folder exists png_files.update({file for file in os.listdir(folder) if file.endswith(".png")}) # Filter the DataFrame by excluding rows where filenames match PNG files filtered_df = df[~df['path'].isin(png_files)] return filtered_df if isinstance(settings['src'], str): settings['src'] = [settings['src']] for src in settings['src']: ann_src, ext = os.path.splitext(src) output_csv = ann_src+'_annotated_filtered.csv' print(output_csv) df = pd.read_csv(src) df = correct_metadata(df) df = annotate_conditions(df, cells=settings['cells'], cell_loc=settings['cell_loc'], pathogens=settings['pathogens'], pathogen_loc=settings['pathogen_loc'], treatments=settings['treatments'], treatment_loc=settings['treatment_loc']) if not settings['filter_column'] is None: if settings['filter_column'] in df.columns: filtered_df = df[(df[settings['filter_column']] > settings['upper_threshold']) | (df[settings['filter_column']] < settings['lower_threshold'])] print(f'Filtered DataFrame with {len(df)} rows to {len(filtered_df)} rows.') else: print(f"{settings['filter_column']} not in DataFrame columns.") filtered_df = df else: filtered_df = df filtered_df.to_csv(output_csv, index=False) if settings['remove_train']: df = filter_csv_by_png(output_csv) df.to_csv(output_csv, index=False)