#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Timestamp: "2025-09-22 16:50:00 (ywatanabe)"
# File: _TimeSeriesStratifiedSplit.py
import scitex_io
__FILE__ = "_TimeSeriesStratifiedSplit.py"
"""
Functionalities:
- Implements time series cross-validation with stratification support
- Ensures chronological order (test data always after training data)
- Supports optional validation set between train and test
- Maintains temporal gaps to prevent data leakage
- Provides visualization with scatter plots for verification
- Validates temporal integrity in all splits
Dependencies:
- packages:
- numpy
- sklearn
- matplotlib
IO:
- input-files:
- None (generates synthetic data for demonstration)
- output-files:
- ./stratified_splits_demo.png (visualization)
"""
"""Imports"""
import argparse
import os
import sys
from typing import Iterator, Optional, Tuple
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import BaseCrossValidator
from sklearn.utils.validation import _num_samples
import scitex_logging as logging
logger = logging.getLogger(__name__)
[docs]
class TimeSeriesStratifiedSplit(BaseCrossValidator):
"""
Time series cross-validation with stratification support.
This splitter ensures:
1. Test data is always chronologically after training data
2. Optional validation set between train and test
3. Class balance preservation in splits
4. Gap period between train and test to avoid leakage
Parameters
----------
n_splits : int
Number of splits (folds)
test_ratio : float
Proportion of data for test set (default: 0.2)
val_ratio : float
Proportion of data for validation set (default: 0.1)
gap : int
Number of samples to exclude between train and test (default: 0)
stratify : bool
Whether to maintain class proportions (default: True)
random_state : int, optional
Random seed for reproducibility (default: None)
Examples
--------
>>> from scitex_ml.classification import TimeSeriesStratifiedSplit
>>> import numpy as np
>>>
>>> X = np.random.randn(100, 10)
>>> y = np.random.randint(0, 2, 100)
>>> timestamps = np.arange(100)
>>>
>>> tscv = TimeSeriesStratifiedSplit(n_splits=3)
>>> for train_idx, test_idx in tscv.split(X, y, timestamps):
... print(f"Train: {len(train_idx)}, Test: {len(test_idx)}")
"""
[docs]
def __init__(
self,
n_splits: int = 5,
test_ratio: float = 0.2,
val_ratio: float = 0.1,
gap: int = 0,
stratify: bool = True,
random_state: Optional[int] = None,
):
self.n_splits = n_splits
self.test_ratio = test_ratio
self.val_ratio = val_ratio
self.gap = gap
self.stratify = stratify
self.random_state = random_state
self.rng = np.random.default_rng(random_state)
[docs]
def split(
self,
X: np.ndarray,
y: Optional[np.ndarray] = None,
timestamps: Optional[np.ndarray] = None,
groups: Optional[np.ndarray] = None,
) -> Iterator[Tuple[np.ndarray, np.ndarray]]:
"""
Generate indices to split data into training and test sets.
Parameters
----------
X : array-like, shape (n_samples, n_features)
Training data
y : array-like, shape (n_samples,)
Target variable
timestamps : array-like, shape (n_samples,)
Timestamps for temporal ordering (required)
groups : array-like, shape (n_samples,), optional
Group labels for grouped CV
Yields
------
train : ndarray
Training set indices
test : ndarray
Test set indices
"""
if timestamps is None:
raise ValueError("timestamps must be provided for time series split")
n_samples = _num_samples(X)
indices = np.arange(n_samples)
# Sort by timestamp
time_order = np.argsort(timestamps)
sorted_indices = indices[time_order]
sorted_y = y[time_order] if y is not None else None
# Calculate split sizes
test_size = int(n_samples * self.test_ratio)
val_size = int(n_samples * self.val_ratio) if self.val_ratio > 0 else 0
# Generate splits with expanding training window
for i in range(self.n_splits):
# Expanding window approach
train_end = n_samples - test_size - self.gap
train_end = train_end - (self.n_splits - i - 1) * (
test_size // self.n_splits
)
train_end = max(test_size, train_end) # Ensure min training size
# Apply gap and start test set immediately after gap
test_start = train_end + self.gap
test_end = min(test_start + test_size, n_samples)
# Get indices
train_indices = sorted_indices[:train_end]
test_indices = sorted_indices[test_start:test_end]
# For time series, temporal integrity is prioritized over stratification
# Chronological order must be preserved to prevent data leakage
# Class imbalance should be handled through other methods or at dataset level
yield train_indices, test_indices
[docs]
def split_with_val(
self,
X: np.ndarray,
y: Optional[np.ndarray] = None,
timestamps: Optional[np.ndarray] = None,
groups: Optional[np.ndarray] = None,
) -> Iterator[Tuple[np.ndarray, np.ndarray, np.ndarray]]:
"""
Generate indices with separate validation set.
Yields
------
train : ndarray
Training set indices
val : ndarray
Validation set indices
test : ndarray
Test set indices
"""
if timestamps is None:
raise ValueError("timestamps must be provided for time series split")
n_samples = _num_samples(X)
indices = np.arange(n_samples)
# Sort by timestamp
time_order = np.argsort(timestamps)
sorted_indices = indices[time_order]
sorted_y = y[time_order] if y is not None else None
# Calculate split sizes
test_size = int(n_samples * self.test_ratio)
val_size = int(n_samples * self.val_ratio) if self.val_ratio > 0 else 0
# Generate splits with strict temporal order
for i in range(self.n_splits):
# Calculate split points in temporal order (sorted domain)
# Work backwards from the end to ensure proper spacing
test_start_pos = n_samples - test_size
test_start_pos = test_start_pos - i * (
test_size // self.n_splits
) # Earlier for each fold
test_end_pos = min(test_start_pos + test_size, n_samples)
# Validation comes before test with gap
val_end_pos = test_start_pos - self.gap
val_start_pos = max(0, val_end_pos - val_size)
# Training comes before validation with gap
train_end_pos = val_start_pos - self.gap
train_start_pos = 0 # Always start from beginning (expanding window)
# Ensure all positions are valid
if (
train_end_pos <= train_start_pos
or val_start_pos >= val_end_pos
or test_start_pos >= test_end_pos
):
continue
# Extract indices from temporally sorted sequence
train_indices = sorted_indices[train_start_pos:train_end_pos]
val_indices = sorted_indices[val_start_pos:val_end_pos]
test_indices = sorted_indices[test_start_pos:test_end_pos]
# For split_with_val, we prioritize temporal integrity over stratification
# to ensure no overlapping between train, validation, and test sets
# Class imbalance should be handled through other methods for 3-way splits
yield train_indices, val_indices, test_indices
def _stratify_indices_temporal(
self, indices: np.ndarray, y: np.ndarray, target_size: int
) -> np.ndarray:
"""Apply stratification while preserving temporal order for time series.
This method maintains chronological order as the top priority while
attempting to balance class representation within the temporal window.
"""
# If target_size >= current size, return as-is
if target_size >= len(indices):
return indices
# Get the labels for these indices in their current temporal order
current_labels = y[indices]
unique_classes = np.unique(current_labels)
# Calculate desired samples per class based on current distribution
class_counts = {}
for cls in unique_classes:
class_counts[cls] = np.sum(current_labels == cls)
total_current = len(indices)
# Calculate target samples per class, proportional to current distribution
target_per_class = {}
remaining_target = target_size
for cls in unique_classes:
proportion = class_counts[cls] / total_current
target_count = max(1, int(target_size * proportion))
target_per_class[cls] = min(target_count, class_counts[cls])
remaining_target -= target_per_class[cls]
# Adjust if we're under/over the target
if remaining_target > 0:
# Distribute remaining samples to classes with most samples
sorted_classes = sorted(
unique_classes, key=lambda x: class_counts[x], reverse=True
)
for cls in sorted_classes:
if remaining_target <= 0:
break
if target_per_class[cls] < class_counts[cls]:
target_per_class[cls] += 1
remaining_target -= 1
# Select indices while preserving temporal order
selected_indices = []
class_taken = {cls: 0 for cls in unique_classes}
for idx in indices: # indices are already in temporal order
label = y[idx]
if class_taken[label] < target_per_class[label]:
selected_indices.append(idx)
class_taken[label] += 1
# Stop if we've reached our target
if len(selected_indices) >= target_size:
break
return np.array(selected_indices)
[docs]
def get_n_splits(self, X=None, y=None, groups=None):
"""Returns the number of splitting iterations in the CV."""
return self.n_splits
def _find_contiguous_segments(self, indices):
"""Find contiguous segments in a sorted array of indices."""
if len(indices) == 0:
return []
sorted_indices = np.sort(indices)
segments = []
start = sorted_indices[0]
end = sorted_indices[0]
for i in range(1, len(sorted_indices)):
if sorted_indices[i] == end + 1:
end = sorted_indices[i]
else:
segments.append((start, end))
start = sorted_indices[i]
end = sorted_indices[i]
segments.append((start, end))
return segments
[docs]
def plot_splits(self, X, y=None, timestamps=None, figsize=(12, 6), save_path=None):
"""
Visualize the stratified time series splits.
Shows train (blue), validation (green), and test (red) sets.
When val_ratio=0, only shows train and test.
Parameters
----------
X : array-like
Training data
y : array-like, optional
Target variable
timestamps : array-like, optional
Timestamps (if None, uses sample indices)
figsize : tuple, default (12, 6)
Figure size
save_path : str, optional
Path to save the plot
Returns
-------
fig : matplotlib.figure.Figure
The created figure
"""
# Use sample indices if no timestamps provided
if timestamps is None:
timestamps = np.arange(len(X))
# Create figure
fig, ax = plt.subplots(figsize=figsize)
# Check if we have validation sets
if self.val_ratio > 0:
# Use split_with_val for 3-way splits
splits = list(self.split_with_val(X, y, timestamps))
split_type = "train-val-test"
else:
# Use regular split for 2-way splits
splits = list(self.split(X, y, timestamps))
split_type = "train-test"
if not splits:
raise ValueError("No splits generated")
# Plot each fold
for fold, split_indices in enumerate(splits):
y_pos = fold
if len(split_indices) == 3: # train, val, test
train_idx, val_idx, test_idx = split_indices
# Train set (blue) - plot as individual segments if non-contiguous
if len(train_idx) > 0:
# Find contiguous segments in train indices
train_segments = self._find_contiguous_segments(train_idx)
for start_idx, end_idx in train_segments:
train_rect = patches.Rectangle(
(start_idx, y_pos - 0.3),
end_idx - start_idx + 1,
0.6,
linewidth=1,
edgecolor="blue",
facecolor="lightblue",
alpha=0.7,
label=(
"Train"
if fold == 0 and start_idx == train_segments[0][0]
else ""
),
)
ax.add_patch(train_rect)
# Validation set (green) - plot as individual segments if non-contiguous
if len(val_idx) > 0:
val_segments = self._find_contiguous_segments(val_idx)
for start_idx, end_idx in val_segments:
val_rect = patches.Rectangle(
(start_idx, y_pos - 0.3),
end_idx - start_idx + 1,
0.6,
linewidth=1,
edgecolor="green",
facecolor="lightgreen",
alpha=0.7,
label=(
"Validation"
if fold == 0 and start_idx == val_segments[0][0]
else ""
),
)
ax.add_patch(val_rect)
# Test set (red) - plot as individual segments if non-contiguous
if len(test_idx) > 0:
test_segments = self._find_contiguous_segments(test_idx)
for start_idx, end_idx in test_segments:
test_rect = patches.Rectangle(
(start_idx, y_pos - 0.3),
end_idx - start_idx + 1,
0.6,
linewidth=1,
edgecolor="red",
facecolor="lightcoral",
alpha=0.7,
label=(
"Test"
if fold == 0 and start_idx == test_segments[0][0]
else ""
),
)
ax.add_patch(test_rect)
else: # train, test (2-way split)
train_idx, test_idx = split_indices
# Train set (blue) - plot as individual segments if non-contiguous
if len(train_idx) > 0:
train_segments = self._find_contiguous_segments(train_idx)
for start_idx, end_idx in train_segments:
train_rect = patches.Rectangle(
(start_idx, y_pos - 0.3),
end_idx - start_idx + 1,
0.6,
linewidth=1,
edgecolor="blue",
facecolor="lightblue",
alpha=0.7,
label=(
"Train"
if fold == 0 and start_idx == train_segments[0][0]
else ""
),
)
ax.add_patch(train_rect)
# Test set (red) - plot as individual segments if non-contiguous
if len(test_idx) > 0:
test_segments = self._find_contiguous_segments(test_idx)
for start_idx, end_idx in test_segments:
test_rect = patches.Rectangle(
(start_idx, y_pos - 0.3),
end_idx - start_idx + 1,
0.6,
linewidth=1,
edgecolor="red",
facecolor="lightcoral",
alpha=0.7,
label=(
"Test"
if fold == 0 and start_idx == test_segments[0][0]
else ""
),
)
ax.add_patch(test_rect)
# Add scatter plots of actual data points with jittering
np.random.seed(42) # For reproducible jittering
jitter_strength = 0.15 # Amount of vertical jittering
for fold, split_indices in enumerate(splits):
y_pos = fold
if len(split_indices) == 3: # train, val, test
train_idx, val_idx, test_idx = split_indices
# Add jittered scatter plots for 3-way split
if len(train_idx) > 0:
train_jitter = np.random.normal(0, jitter_strength, len(train_idx))
ax.scatter(
train_idx,
y_pos + train_jitter,
c="darkblue",
s=15,
alpha=0.6,
marker="o",
label="Train points" if fold == 0 else "",
zorder=3,
)
if len(val_idx) > 0:
val_jitter = np.random.normal(0, jitter_strength, len(val_idx))
ax.scatter(
val_idx,
y_pos + val_jitter,
c="darkgreen",
s=15,
alpha=0.6,
marker="^",
label="Val points" if fold == 0 else "",
zorder=3,
)
if len(test_idx) > 0:
test_jitter = np.random.normal(0, jitter_strength, len(test_idx))
ax.scatter(
test_idx,
y_pos + test_jitter,
c="darkred",
s=15,
alpha=0.6,
marker="s",
label="Test points" if fold == 0 else "",
zorder=3,
)
else: # train, test (2-way split)
train_idx, test_idx = split_indices
# Add jittered scatter plots for 2-way split
if len(train_idx) > 0:
train_jitter = np.random.normal(0, jitter_strength, len(train_idx))
ax.scatter(
train_idx,
y_pos + train_jitter,
c="darkblue",
s=15,
alpha=0.6,
marker="o",
label="Train points" if fold == 0 else "",
zorder=3,
)
if len(test_idx) > 0:
test_jitter = np.random.normal(0, jitter_strength, len(test_idx))
ax.scatter(
test_idx,
y_pos + test_jitter,
c="darkred",
s=15,
alpha=0.6,
marker="s",
label="Test points" if fold == 0 else "",
zorder=3,
)
# Format plot
ax.set_ylim(-0.5, len(splits) - 0.5)
ax.set_xlim(0, len(X))
ax.set_xlabel("Sample Index (original order)")
ax.set_ylabel("Fold")
title = f"Time Series Stratified Split Visualization ({split_type})"
if self.stratify:
title += "\nMaintains class balance across splits"
if self.gap > 0:
title += f", Gap: {self.gap} samples"
title += "\nRectangles show ranges, dots show actual data points"
ax.set_title(title)
# Set y-ticks
ax.set_yticks(range(len(splits)))
ax.set_yticklabels([f"Fold {i}" for i in range(len(splits))])
# Add legend with scatter points
ax.legend(loc="upper right")
plt.tight_layout()
if save_path:
fig.savefig(save_path, dpi=150, bbox_inches="tight")
return fig
"""Functions & Classes"""
def main(args) -> int:
"""Demonstrate TimeSeriesStratifiedSplit functionality.
Args:
args: Command line arguments
Returns:
int: Exit status
"""
logger.info("Demonstrating TimeSeriesStratifiedSplit functionality")
# Generate test data
np.random.seed(42)
n_samples = 200
X = np.random.randn(n_samples, 5)
y = np.random.randint(0, 2, n_samples)
timestamps = np.arange(n_samples) + np.random.normal(0, 0.1, n_samples)
logger.info(
f"Generated test data: {n_samples} samples, {X.shape[1]} features, {len(np.unique(y))} classes"
)
# Test regular split
logger.info("Testing regular train/test split")
splitter = TimeSeriesStratifiedSplit(n_splits=3, test_ratio=0.2, gap=5)
for fold, (train_idx, test_idx) in enumerate(splitter.split(X, y, timestamps)):
logger.info(f"Fold {fold}: Train={len(train_idx)}, Test={len(test_idx)}")
# Test split with validation
logger.info("Testing train/validation/test split")
splitter_val = TimeSeriesStratifiedSplit(
n_splits=2, test_ratio=0.2, val_ratio=0.15, gap=3
)
for fold, (train_idx, val_idx, test_idx) in enumerate(
splitter_val.split_with_val(X, y, timestamps)
):
logger.info(
f"Fold {fold}: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}"
)
# Check temporal order
train_times = timestamps[train_idx]
val_times = timestamps[val_idx] if len(val_idx) > 0 else np.array([])
test_times = timestamps[test_idx] if len(test_idx) > 0 else np.array([])
temporal_ok = True
if len(val_times) > 0 and len(test_times) > 0:
temporal_ok = (train_times.max() < val_times.min()) and (
val_times.max() < test_times.min()
)
elif len(test_times) > 0:
temporal_ok = train_times.max() < test_times.min()
status = "✓" if temporal_ok else "✗"
logger.info(f" Temporal order: {status}")
# Generate visualization
logger.info("Generating split visualization")
fig = splitter_val.plot_splits(X, y, timestamps)
# Save using SciTeX framework
scitex_io.save(fig, "./stratified_splits_demo.png", symlink_from_cwd=True)
plt.close(fig)
logger.info("TimeSeriesStratifiedSplit demonstration completed successfully")
return 0
def parse_args() -> argparse.Namespace:
"""Parse command line arguments."""
import argparse
parser = argparse.ArgumentParser(
description="Demonstrate TimeSeriesStratifiedSplit with temporal integrity validation"
)
args = parser.parse_args()
return args
def run_main() -> None:
"""Initialize scitex framework, run main function, and cleanup."""
global CONFIG, CC, sys, plt, rng
import sys
import matplotlib.pyplot as plt
import scitex as stx
args = parse_args()
CONFIG, sys.stdout, sys.stderr, plt, CC, rng = stx.session.start(
sys,
plt,
args=args,
file=__FILE__,
sdir_suffix=None,
verbose=False,
agg=True,
)
exit_status = main(args)
stx.session.close(
CONFIG,
verbose=False,
notify=False,
message="",
exit_status=exit_status,
)
if __name__ == "__main__":
run_main()
# EOF