# Source code for scitex_ml.classification.timeseries._TimeSeriesBlockingSplit

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Timestamp: "2025-09-22 17:10:00 (ywatanabe)"
# File: _TimeSeriesBlockingSplit.py

import scitex_io

__FILE__ = "_TimeSeriesBlockingSplit.py"

"""
Functionalities:
  - Implements time series split with blocking for multiple subjects/groups
  - Ensures temporal integrity within each subject's timeline
  - Allows cross-subject generalization while preventing data leakage
  - Provides visualization with scatter plots and subject color coding
  - Validates that no data mixing occurs between subjects
  - Supports expanding window approach for more training data in later folds

Dependencies:
  - packages:
    - numpy
    - sklearn
    - matplotlib
    - scitex

IO:
  - input-files:
    - None (generates synthetic multi-subject data for demonstration)
  - output-files:
    - ./blocking_splits_demo.png (visualization with scatter plots)
"""

"""Imports"""
import argparse
import os
import sys
from typing import Iterator, Optional, Tuple

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import BaseCrossValidator

import scitex_logging as logging

logger = logging.getLogger(__name__)


class TimeSeriesBlockingSplit(BaseCrossValidator):
    """
    Time series split with blocking to handle multiple subjects/groups.

    This splitter ensures temporal integrity within each subject while
    allowing cross-subject generalization. Each subject's data is kept
    temporally coherent, but subjects can appear in both training and test
    sets at different time periods.

    Key Features:
    - Temporal order preserved within each subject
    - No data leakage within individual subject timelines (for every subject,
      training samples precede its test samples in time)
    - Expanding window approach: more training data in later folds
    - Cross-subject generalization: subjects can be in both train and test

    Use Cases:
    - Multiple patients with longitudinal medical data
    - Multiple stocks with time series financial data
    - Multiple sensors with temporal measurements
    - Any scenario with grouped time series data

    Parameters
    ----------
    n_splits : int, default=5
        Number of splits (folds). Must be >= 1.
    test_ratio : float, default=0.2
        Proportion of each subject's samples placed in its test window.
        Must lie in [0, 1).
    val_ratio : float, default=0.0
        Proportion of each subject's samples placed in its validation window
        (only used by ``split_with_val``). Must lie in [0, 1), and
        ``test_ratio + val_ratio`` must be < 1.
    random_state : int, optional
        Seed stored on the instance and used to build ``self.rng``. The
        splitting logic itself is deterministic and does not draw from it.

    Raises
    ------
    ValueError
        If any parameter lies outside the ranges above.

    Examples
    --------
    >>> from scitex_ml.classification import TimeSeriesBlockingSplit
    >>> import numpy as np
    >>>
    >>> # Create data: 100 samples, 4 subjects (25 samples each)
    >>> X = np.random.randn(100, 10)
    >>> y = np.random.randint(0, 2, 100)
    >>> timestamps = np.arange(100)
    >>> groups = np.repeat([0, 1, 2, 3], 25)  # Subject IDs
    >>>
    >>> # Each subject gets temporal split: early samples -> train, later -> test
    >>> splitter = TimeSeriesBlockingSplit(n_splits=3, test_ratio=0.3)
    >>> for train_idx, test_idx in splitter.split(X, y, timestamps, groups):
    ...     train_subjects = set(groups[train_idx])
    ...     test_subjects = set(groups[test_idx])
    ...     print(f"Train subjects: {train_subjects}, Test subjects: {test_subjects}")
    ...     # Output shows same subjects in both sets but different time periods
    """

    def __init__(
        self,
        n_splits: int = 5,
        test_ratio: float = 0.2,
        val_ratio: float = 0.0,
        random_state: Optional[int] = None,
    ):
        if n_splits < 1:
            raise ValueError(f"n_splits must be >= 1, got {n_splits}")
        if not 0.0 <= test_ratio < 1.0:
            raise ValueError(f"test_ratio must be in [0, 1), got {test_ratio}")
        if not 0.0 <= val_ratio < 1.0:
            raise ValueError(f"val_ratio must be in [0, 1), got {val_ratio}")
        if test_ratio + val_ratio >= 1.0:
            # Guarantees every non-empty subject keeps at least one training sample.
            raise ValueError(
                "test_ratio + val_ratio must be < 1, "
                f"got {test_ratio + val_ratio}"
            )
        self.n_splits = n_splits
        self.test_ratio = test_ratio
        self.val_ratio = val_ratio
        self.random_state = random_state
        self.rng = np.random.default_rng(random_state)

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _check_inputs(timestamps, groups) -> Tuple[np.ndarray, np.ndarray]:
        """Validate ``timestamps``/``groups`` and return them as numpy arrays."""
        if groups is None:
            raise ValueError("groups must be provided for blocking time series split")
        if timestamps is None:
            raise ValueError("timestamps must be provided")
        groups = np.asarray(groups)
        timestamps = np.asarray(timestamps)
        if groups.shape[0] != timestamps.shape[0]:
            raise ValueError(
                f"timestamps ({timestamps.shape[0]}) and groups "
                f"({groups.shape[0]}) must have the same length"
            )
        return timestamps, groups

    @staticmethod
    def _time_sorted_groups(timestamps: np.ndarray, groups: np.ndarray) -> list:
        """Return one index array per unique group, ordered by timestamp."""
        ordered = []
        for group in np.unique(groups):
            group_indices = np.flatnonzero(groups == group)
            # Stable sort keeps tied timestamps in original sample order.
            time_order = np.argsort(timestamps[group_indices], kind="stable")
            ordered.append(group_indices[time_order])
        return ordered

    def _train_end(self, fold: int, train_size: int, test_size: int) -> int:
        """Exclusive end of one group's training window in ``fold``.

        Earlier folds stop ``test_size // n_splits`` samples earlier per
        remaining fold (expanding window); the result is clamped to
        ``[1, train_size]``.
        """
        end = train_size - (self.n_splits - fold - 1) * (test_size // self.n_splits)
        return max(1, min(end, train_size))

    @staticmethod
    def _concat_indices(parts: list) -> np.ndarray:
        """Concatenate per-group index arrays into one integer index array.

        Always returns an integer dtype, even when empty, so the result can be
        used directly to index arrays.
        """
        if not parts:
            return np.empty(0, dtype=np.intp)
        return np.concatenate(parts).astype(np.intp, copy=False)

    @staticmethod
    def _draw_span(ax, indices: np.ndarray, y_pos: float, **style) -> None:
        """Draw one rectangle covering ``indices`` on row ``y_pos`` (no-op if empty)."""
        if len(indices) == 0:
            return
        start = indices.min()
        width = indices.max() - start + 1
        ax.add_patch(patches.Rectangle((start, y_pos - 0.3), width, 0.6, **style))

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def split(
        self,
        X: np.ndarray,
        y: Optional[np.ndarray] = None,
        timestamps: Optional[np.ndarray] = None,
        groups: Optional[np.ndarray] = None,
    ) -> Iterator[Tuple[np.ndarray, np.ndarray]]:
        """
        Generate indices respecting group boundaries.

        For every group, samples are ordered by timestamp; the last
        ``int(n_group * test_ratio)`` positions form the test budget, and each
        fold takes a training prefix that grows with the fold number,
        followed by a test window of that budget.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            Training data (not inspected; kept for sklearn compatibility)
        y : array-like, shape (n_samples,)
            Target variable (not used)
        timestamps : array-like, shape (n_samples,)
            Timestamps for temporal ordering (required)
        groups : array-like, shape (n_samples,)
            Group labels (e.g., patient IDs) - required

        Yields
        ------
        train : ndarray of int
            Training set indices
        test : ndarray of int
            Test set indices

        Raises
        ------
        ValueError
            If ``groups`` or ``timestamps`` is missing or their lengths differ.
        """
        timestamps, groups = self._check_inputs(timestamps, groups)
        # The per-group time ordering does not depend on the fold: compute once.
        per_group = self._time_sorted_groups(timestamps, groups)

        for fold in range(self.n_splits):
            train_parts, test_parts = [], []
            for ordered in per_group:
                n_group = len(ordered)
                test_size = int(n_group * self.test_ratio)
                train_size = n_group - test_size
                end = self._train_end(fold, train_size, test_size)
                train_parts.append(ordered[:end])
                test_parts.append(ordered[end : end + test_size])
            yield self._concat_indices(train_parts), self._concat_indices(test_parts)

    def split_with_val(
        self,
        X: np.ndarray,
        y: Optional[np.ndarray] = None,
        timestamps: Optional[np.ndarray] = None,
        groups: Optional[np.ndarray] = None,
    ) -> Iterator[Tuple[np.ndarray, np.ndarray, np.ndarray]]:
        """
        Generate indices with separate validation set respecting group boundaries.

        Each subject gets its own train/val/test split maintaining temporal
        order: training prefix, then ``int(n_group * val_ratio)`` validation
        samples, then ``int(n_group * test_ratio)`` test samples (truncated at
        the end of the subject's data).

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            Training data (not inspected; kept for sklearn compatibility)
        y : array-like, shape (n_samples,)
            Target variable (not used)
        timestamps : array-like, shape (n_samples,)
            Timestamps for temporal ordering (required)
        groups : array-like, shape (n_samples,)
            Group labels (e.g., patient IDs) - required

        Yields
        ------
        train : ndarray of int
            Training set indices
        val : ndarray of int
            Validation set indices (empty, integer dtype, when val_ratio == 0)
        test : ndarray of int
            Test set indices

        Raises
        ------
        ValueError
            If ``groups`` or ``timestamps`` is missing or their lengths differ.
        """
        timestamps, groups = self._check_inputs(timestamps, groups)
        per_group = self._time_sorted_groups(timestamps, groups)

        for fold in range(self.n_splits):
            train_parts, val_parts, test_parts = [], [], []
            for ordered in per_group:
                n_group = len(ordered)
                test_size = int(n_group * self.test_ratio)
                val_size = int(n_group * self.val_ratio)
                train_size = n_group - test_size - val_size

                val_start = self._train_end(fold, train_size, test_size)
                test_start = val_start + val_size
                train_parts.append(ordered[:val_start])
                val_parts.append(ordered[val_start:test_start])
                # Slicing clamps at n_group, so no explicit overflow check is needed.
                test_parts.append(ordered[test_start : test_start + test_size])
            yield (
                self._concat_indices(train_parts),
                self._concat_indices(val_parts),
                self._concat_indices(test_parts),
            )

    def get_n_splits(self, X=None, y=None, groups=None):
        """Returns the number of splitting iterations."""
        return self.n_splits

    def plot_splits(
        self, X, y=None, timestamps=None, groups=None, figsize=(12, 6), save_path=None
    ):
        """
        Visualize the blocking splits showing subject separation.

        This visualization shows how data from different subjects/groups is
        allocated to training and test sets while maintaining temporal order
        within each subject.

        Color Scheme:
        - Rectangle border: Blue = Training set, Red = Test set
        - Training rectangle fill: one colour per subject (Set3 colormap),
          matching the legend entries
        - Test rectangle fill: light coral
        - Scatter markers: dark-blue circles = train samples, dark-red
          squares = test samples, jittered vertically (fixed seed, local
          generator - the global numpy RNG is not touched)

        Parameters
        ----------
        X : array-like
            Training data (only its length is used)
        y : array-like, optional
            Target variable (not used)
        timestamps : array-like, optional
            Timestamps (if None, sample indices are used)
        groups : array-like
            Group labels (required for blocking split) - each unique value
            represents a subject
        figsize : tuple, default (12, 6)
            Figure size
        save_path : str, optional
            Path to save the plot

        Returns
        -------
        fig : matplotlib.figure.Figure
            The created figure with a legend showing subject colors

        Raises
        ------
        ValueError
            If ``groups`` is None.

        Examples
        --------
        >>> splitter = TimeSeriesBlockingSplit(n_splits=3)
        >>> fig = splitter.plot_splits(X, timestamps=timestamps, groups=subject_ids)
        >>> fig.show()  # Will show train (blue border) vs test (red border) by subject
        """
        from matplotlib.lines import Line2D
        from matplotlib.patches import Patch

        if groups is None:
            raise ValueError("groups must be provided for blocking split visualization")
        groups = np.asarray(groups)

        # Must happen before split(), which rejects timestamps=None.
        if timestamps is None:
            timestamps = np.arange(len(X))

        splits = list(self.split(X, y, timestamps, groups))
        if not splits:
            raise ValueError("No splits generated")
        n_folds = len(splits)

        # Colour is chosen by the group's rank among sorted unique labels, so
        # the rectangles and the legend agree for any label values/types.
        unique_groups = np.unique(groups)
        colors = plt.cm.Set3(np.linspace(0, 1, len(unique_groups)))

        fig, ax = plt.subplots(figsize=figsize)

        # Rectangles: all training spans of a fold first, then test spans on top.
        for fold, (train_idx, test_idx) in enumerate(splits):
            train_groups = groups[train_idx]
            test_groups = groups[test_idx]
            for color, group in zip(colors, unique_groups):
                self._draw_span(
                    ax,
                    train_idx[train_groups == group],
                    fold,
                    linewidth=1,
                    edgecolor="blue",
                    facecolor=color,
                    alpha=0.7,
                )
            for group in unique_groups:
                self._draw_span(
                    ax,
                    test_idx[test_groups == group],
                    fold,
                    linewidth=2,
                    edgecolor="red",
                    facecolor="lightcoral",
                    alpha=0.8,
                )

        # Format plot
        ax.set_ylim(-0.5, n_folds - 0.5)
        ax.set_xlim(0, len(X))
        ax.set_xlabel("Sample Index")
        ax.set_ylabel("Fold")
        ax.set_title(
            "Time Series Blocking Split Visualization\n"
            "No mixing between subjects/groups"
        )
        ax.set_yticks(range(n_folds))
        ax.set_yticklabels([f"Fold {i}" for i in range(n_folds)])

        # Jittered scatter of individual samples (reproducible, local RNG).
        jitter_rng = np.random.default_rng(42)
        jitter_strength = 0.15
        for fold, (train_idx, test_idx) in enumerate(splits):
            for indices, color, marker in (
                (train_idx, "darkblue", "o"),
                (test_idx, "darkred", "s"),
            ):
                if len(indices) == 0:
                    continue
                jitter = jitter_rng.normal(0, jitter_strength, len(indices))
                ax.scatter(
                    indices,
                    fold + jitter,
                    c=color,
                    s=15,
                    alpha=0.6,
                    marker=marker,
                    zorder=3,
                )

        # Legend: border meaning, a blank separator, then one entry per subject.
        legend_elements = [
            Line2D([0], [0], color="blue", lw=3, alpha=0.7, label="Training Set (blue border)"),
            Line2D([0], [0], color="red", lw=3, alpha=0.8, label="Test Set (red border)"),
            Line2D([0], [0], color="white", lw=0, label=""),
        ]
        legend_elements.extend(
            Patch(facecolor=color, alpha=0.7, label=f"Subject/Group {group}")
            for color, group in zip(colors, unique_groups)
        )
        ncol = 1 if len(unique_groups) <= 3 else 2
        ax.legend(
            handles=legend_elements,
            loc="center left",
            bbox_to_anchor=(1.02, 0.5),
            ncol=ncol,
        )

        fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=150, bbox_inches="tight")

        return fig
"""Functions & Classes""" def main(args) -> int: """Demonstrate TimeSeriesBlockingSplit functionality. Args: args: Command line arguments Returns: int: Exit status """ logger.info("Demonstrating TimeSeriesBlockingSplit functionality") # Generate test data with multiple subjects np.random.seed(42) n_samples = args.n_samples n_subjects = args.n_subjects # Generate data X = np.random.randn(n_samples, 5) y = np.random.randint(0, 2, n_samples) timestamps = np.arange(n_samples) + np.random.normal(0, 0.1, n_samples) # Create subject groups samples_per_subject = n_samples // n_subjects groups = np.repeat(range(n_subjects), samples_per_subject) # Pad if necessary groups = np.pad( groups, (0, n_samples - len(groups)), mode="constant", constant_values=n_subjects - 1, ) logger.info(f"Generated test data: {n_samples} samples, {n_subjects} subjects") logger.info(f"Samples per subject: ~{samples_per_subject}") # Create blocking splitter splitter = TimeSeriesBlockingSplit( n_splits=args.n_splits, test_ratio=args.test_ratio ) logger.info(f"Blocking split configuration:") logger.info(f" Number of splits: {args.n_splits}") logger.info(f" Test ratio: {args.test_ratio}") # Test splits for fold, (train_idx, test_idx) in enumerate( splitter.split(X, y, timestamps, groups) ): train_subjects = sorted(set(groups[train_idx])) test_subjects = sorted(set(groups[test_idx])) logger.info(f"Fold {fold}:") logger.info(f" Train: {len(train_idx)} samples from subjects {train_subjects}") logger.info(f" Test: {len(test_idx)} samples from subjects {test_subjects}") # Check subject overlap overlap = set(train_subjects) & set(test_subjects) if overlap: logger.info( f" Subjects in both: {sorted(overlap)} (temporal separation maintained)" ) else: logger.info(f" No subject overlap") # Generate visualization logger.info("Generating blocking split visualization with scatter plots") fig = splitter.plot_splits(X, y, timestamps, groups) # Save using SciTeX framework scitex_io.save(fig, 
"./blocking_splits_demo.png", symlink_from_cwd=True) plt.close(fig) logger.info("TimeSeriesBlockingSplit demonstration completed successfully") return 0 def parse_args() -> argparse.Namespace: """Parse command line arguments.""" parser = argparse.ArgumentParser( description="Demonstrate TimeSeriesBlockingSplit for multi-subject time series" ) parser.add_argument( "--n-samples", type=int, default=300, help="Total number of samples (default: %(default)s)", ) parser.add_argument( "--n-subjects", type=int, default=4, help="Number of subjects/groups (default: %(default)s)", ) parser.add_argument( "--n-splits", type=int, default=3, help="Number of CV splits (default: %(default)s)", ) parser.add_argument( "--test-ratio", type=float, default=0.3, help="Proportion of data for test per subject (default: %(default)s)", ) args = parser.parse_args() return args def run_main() -> None: """Initialize scitex framework, run main function, and cleanup.""" global CONFIG, CC, sys, plt, rng import sys import matplotlib.pyplot as plt import scitex as stx args = parse_args() CONFIG, sys.stdout, sys.stderr, plt, CC, rng = stx.session.start( sys, plt, args=args, file=__FILE__, sdir_suffix=None, verbose=False, agg=True, ) exit_status = main(args) stx.session.close( CONFIG, verbose=False, notify=False, message="", exit_status=exit_status, ) if __name__ == "__main__": run_main() # EOF