Source code for scitex_core.repro._RandomStateManager

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# File: ./src/scitex_core/repro/_RandomStateManager.py

"""
Clean, simple RandomStateManager for scientific reproducibility.

Manages random number generators across multiple libraries (numpy, torch,
tensorflow, jax) with deterministic seeding and verification capabilities.

Main API:
    rng_manager = RandomStateManager(seed=42)  # Create instance
    gen = rng_manager("name")                  # Get named generator
    rng_manager.verify(obj, "name")            # Verify reproducibility
"""

import hashlib
import json
import os
import pickle
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Dict, List, Optional, Union


# Global singleton instance
_GLOBAL_INSTANCE = None


[docs] class RandomStateManager: """ Simple, robust random state manager for scientific computing. Provides centralized management of random number generators with deterministic seeding across multiple ML/scientific libraries. Parameters ---------- seed : int, optional Master seed for all random number generators (default: 42) verbose : bool, optional Print status messages (default: False) Examples -------- >>> from scitex_core.repro import RandomStateManager >>> >>> # Direct usage >>> rng_manager = RandomStateManager(seed=42) >>> gen = rng_manager("data") >>> data = gen.random(100) >>> >>> # Verify reproducibility >>> rng_manager.verify(data, "my_data") >>> >>> # Named generators for different purposes >>> data_gen = rng_manager("data") >>> model_gen = rng_manager("model") >>> augment_gen = rng_manager("augment") Notes ----- - Automatically detects and seeds available libraries (numpy, torch, tf, jax) - Creates independent named generators for different experiment components - Verification cache stored in ~/.scitex/rng/ """
[docs] def __init__(self, seed: int = 42, verbose: bool = False): """Initialize with automatic module detection.""" self.seed = seed self.verbose = verbose self._generators = {} self._cache_dir = Path.home() / ".scitex" / "rng" self._cache_dir.mkdir(parents=True, exist_ok=True) self._jax_key = None self._np = None if verbose: print(f"RandomStateManager initialized with seed {seed}") # Auto-fix all available seeds self._auto_fix_seeds(verbose=verbose)
[docs] def _auto_fix_seeds(self, verbose: Optional[bool] = None): """Automatically detect and fix ALL available random modules.""" # Use instance verbose if not specified if verbose is None: verbose = self.verbose # OS environment os.environ["PYTHONHASHSEED"] = str(self.seed) os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" fixed_modules = [] # Python random try: import random random.seed(self.seed) fixed_modules.append("random") except ImportError: pass # NumPy try: import numpy as np np.random.seed(self.seed) self._np = np self._np_default_rng_manager = np.random.default_rng(self.seed) fixed_modules.append("numpy") except ImportError: self._np = None # PyTorch try: import torch torch.manual_seed(self.seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(self.seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False fixed_modules.append("torch+cuda") else: fixed_modules.append("torch") except ImportError: pass # TensorFlow try: import tensorflow as tf tf.random.set_seed(self.seed) fixed_modules.append("tensorflow") except ImportError: pass # JAX try: import jax self._jax_key = jax.random.PRNGKey(self.seed) fixed_modules.append("jax") except (ImportError, AttributeError, RuntimeError): self._jax_key = None if verbose and fixed_modules: print(f"Fixed random seeds for: {', '.join(fixed_modules)}")
[docs] def get_np_generator(self, name: str): """ Get or create a named NumPy random generator. Parameters ---------- name : str Generator name (e.g., "data", "model", "augment") Returns ------- numpy.random.Generator Independent NumPy random generator Examples -------- >>> rng_manager = RandomStateManager(42) >>> gen = rng_manager.get_np_generator("data") >>> values = gen.random(100) >>> perm = gen.permutation(100) """ if self._np is None: raise ImportError("NumPy required for random generators") if name not in self._generators: # Create deterministic seed from name name_hash = int(hashlib.md5(name.encode()).hexdigest()[:8], 16) seed = (self.seed + name_hash) % (2**32) self._generators[name] = self._np.random.default_rng(seed) return self._generators[name]
[docs] def __call__(self, name: str, verbose: bool = None): """ Get or create a named NumPy random generator. This is a convenience wrapper for get_np_generator(). Parameters ---------- name : str Generator name verbose : bool, optional Whether to show deprecation warning Returns ------- numpy.random.Generator NumPy random generator with deterministic seed """ if verbose: print( f"Note: rng('{name}') is deprecated. " f"Use rng.get_np_generator('{name}') instead." ) return self.get_np_generator(name)
[docs] def verify(self, obj: Any, name: Optional[str] = None, verbose: bool = True) -> bool: """ Verify object matches cached hash (detects broken reproducibility). First call: caches the object's hash Later calls: verifies object matches cached hash Parameters ---------- obj : Any Object to verify (array, tensor, data, model weights, etc.) Supports: numpy arrays, torch tensors, tf tensors, jax arrays, lists, dicts, pandas dataframes, and basic types name : str, optional Cache name. Auto-generated from caller location if not provided. verbose : bool, optional Print verification results (default: True) Returns ------- bool True if matches cache (or first call), False if different Raises ------ ValueError If verification fails (object doesn't match cached hash) Examples -------- >>> data = generate_data() >>> rng_manager.verify(data, "train_data") # First run: caches >>> # Next run: >>> rng_manager.verify(data, "train_data") # Verifies match """ import numpy as np # Auto-generate name if needed if name is None: import inspect frame = inspect.currentframe().f_back filename = Path(frame.f_code.co_filename).stem lineno = frame.f_lineno name = f"{filename}_L{lineno}" # Sanitize name safe_name = "".join(c if c.isalnum() or c in "-_" else "_" for c in name) cache_file = self._cache_dir / f"{safe_name}.json" # Compute hash based on object type obj_hash = self._compute_hash(obj) # Use instance verbose if not specified if verbose is None: verbose = self.verbose # Check cache if cache_file.exists(): with open(cache_file, "r") as f: cached = json.load(f) matches = cached["hash"] == obj_hash if not matches and verbose: print(f"⚠️ Reproducibility broken for '{name}'!") print(f" Expected: {cached['hash'][:16]}...") print(f" Got: {obj_hash[:16]}...") raise ValueError(f"Reproducibility verification failed for '{name}'") elif matches and verbose: print(f"✓ Reproducibility verified for '{name}'") return matches else: # First call - cache it with open(cache_file, "w") as f: json.dump({"name": name, "hash": obj_hash, "seed": self.seed}, f) return True
[docs] def _compute_hash(self, obj: Any) -> str: """ Compute hash for various object types. Supports: - NumPy arrays - PyTorch tensors - TensorFlow tensors - JAX arrays - Pandas DataFrames/Series - Lists, tuples, dicts - Basic types (int, float, str, bool) """ import numpy as np # NumPy array if isinstance(obj, np.ndarray): return hashlib.sha256(obj.tobytes()).hexdigest()[:32] # PyTorch tensor try: import torch if isinstance(obj, torch.Tensor): obj_np = obj.detach().cpu().numpy() return hashlib.sha256(obj_np.tobytes()).hexdigest()[:32] except ImportError: pass # TensorFlow tensor try: import tensorflow as tf if isinstance(obj, (tf.Tensor, tf.Variable)): obj_np = obj.numpy() return hashlib.sha256(obj_np.tobytes()).hexdigest()[:32] except ImportError: pass # JAX array try: import jax.numpy as jnp if isinstance(obj, jnp.ndarray): obj_np = np.array(obj) return hashlib.sha256(obj_np.tobytes()).hexdigest()[:32] except ImportError: pass # Pandas DataFrame/Series try: import pandas as pd if isinstance(obj, (pd.DataFrame, pd.Series)): obj_str = obj.to_json(orient="split", date_format="iso") return hashlib.sha256(obj_str.encode()).hexdigest()[:32] except ImportError: pass # Lists and tuples - convert to numpy array if numeric if isinstance(obj, (list, tuple)): try: obj_np = np.array(obj) if obj_np.dtype != object: # Numeric array return hashlib.sha256(obj_np.tobytes()).hexdigest()[:32] except: pass # Dictionaries - serialize to JSON if isinstance(obj, dict): try: obj_str = json.dumps(obj, sort_keys=True, default=str) return hashlib.sha256(obj_str.encode()).hexdigest()[:32] except: pass # Default: convert to string obj_str = str(obj) return hashlib.sha256(obj_str.encode()).hexdigest()[:32]
[docs] def checkpoint(self, name: str = "checkpoint"): """ Save current state of all generators. Parameters ---------- name : str, optional Checkpoint name (default: "checkpoint") Returns ------- Path Path to checkpoint file """ checkpoint_file = self._cache_dir / f"{name}.pkl" state = { "seed": self.seed, "generators": { k: v.bit_generator.state for k, v in self._generators.items() }, } with open(checkpoint_file, "wb") as f: pickle.dump(state, f) return checkpoint_file
[docs] def restore(self, checkpoint: Union[str, Path]): """ Restore from checkpoint. Parameters ---------- checkpoint : str or Path Path to checkpoint file """ if isinstance(checkpoint, str): checkpoint = Path(checkpoint) with open(checkpoint, "rb") as f: state = pickle.load(f) self.seed = state["seed"] self._auto_fix_seeds() # Restore generator states for name, gen_state in state["generators"].items(): gen = self(name) gen.bit_generator.state = gen_state
[docs] @contextmanager def temporary_seed(self, seed: int): """ Context manager for temporary seed change. Parameters ---------- seed : int Temporary seed value Examples -------- >>> rng_manager = RandomStateManager(42) >>> with rng_manager.temporary_seed(123): ... data = np.random.random(10) """ import random import numpy as np # Save current states old_random_state = random.getstate() old_np_state = np.random.get_state() if self._np else None # Set temporary seed random.seed(seed) if self._np: np.random.seed(seed) try: yield finally: # Restore states random.setstate(old_random_state) if self._np and old_np_state: np.random.set_state(old_np_state)
[docs] def get_sklearn_random_state(self, name: str) -> int: """ Get a random state for scikit-learn. Scikit-learn uses integers for random_state parameter. Parameters ---------- name : str Generator name Returns ------- int Random state integer for sklearn Examples -------- >>> rng_manager = RandomStateManager(42) >>> from sklearn.model_selection import train_test_split >>> X_train, X_test = train_test_split( ... X, test_size=0.2, ... random_state=rng_manager.get_sklearn_random_state("split") ... ) """ # Create deterministic seed from name name_hash = int(hashlib.md5(name.encode()).hexdigest()[:8], 16) seed = (self.seed + name_hash) % (2**32) return seed
[docs] def get_torch_generator(self, name: str): """ Get or create a named PyTorch generator. Parameters ---------- name : str Generator name Returns ------- torch.Generator PyTorch generator with deterministic seed Examples -------- >>> rng_manager = RandomStateManager(42) >>> gen = rng_manager.get_torch_generator("model") >>> torch.randn(5, 5, generator=gen) """ try: import torch except ImportError: raise ImportError("PyTorch not installed") if not hasattr(self, "_torch_generators"): self._torch_generators = {} if name not in self._torch_generators: # Create deterministic seed from name name_hash = int(hashlib.md5(name.encode()).hexdigest()[:8], 16) seed = (self.seed + name_hash) % (2**32) gen = torch.Generator() gen.manual_seed(seed) self._torch_generators[name] = gen return self._torch_generators[name]
[docs] def get_generator(self, name: str): """Alias for get_np_generator for compatibility.""" return self.get_np_generator(name)
[docs] def clear_cache(self, patterns: Union[str, List[str], None] = None) -> int: """ Clear verification cache files. Parameters ---------- patterns : str or list of str, optional Specific cache patterns to clear. If None, clears all. Can be: - Single name: "my_data" - List of names: ["data1", "data2"] - Glob pattern: "experiment_*" - None: clear all cache files Returns ------- int Number of cache files removed Examples -------- >>> rng_manager = RandomStateManager(42) >>> rng_manager.clear_cache() # Clear all >>> rng_manager.clear_cache("old_data") # Clear specific >>> rng_manager.clear_cache(["test1", "test2"]) # Clear multiple >>> rng_manager.clear_cache("experiment_*") # Clear pattern """ if not self._cache_dir.exists(): return 0 removed_count = 0 if patterns is None: # Clear all .json files cache_files = list(self._cache_dir.glob("*.json")) for cache_file in cache_files: cache_file.unlink() removed_count += 1 else: # Ensure patterns is a list if isinstance(patterns, str): patterns = [patterns] for pattern in patterns: # Handle glob patterns if "*" in pattern or "?" in pattern: cache_files = list(self._cache_dir.glob(f"{pattern}.json")) else: # Exact match cache_file = self._cache_dir / f"{pattern}.json" cache_files = [cache_file] if cache_file.exists() else [] for cache_file in cache_files: cache_file.unlink() removed_count += 1 return removed_count
[docs] def get(verbose: bool = False) -> RandomStateManager: """ Get or create the global RandomStateManager instance. Parameters ---------- verbose : bool, optional Whether to print status messages (default: False) Returns ------- RandomStateManager Global instance Examples -------- >>> from scitex_core.repro import get >>> rng_manager = get() >>> data = rng_manager("data").random(100) """ global _GLOBAL_INSTANCE if _GLOBAL_INSTANCE is None: _GLOBAL_INSTANCE = RandomStateManager(42, verbose=verbose) return _GLOBAL_INSTANCE
[docs] def reset(seed: int = 42, verbose: bool = False) -> RandomStateManager: """ Reset global RandomStateManager with new seed. Parameters ---------- seed : int New seed value verbose : bool, optional Whether to print status messages (default: False) Returns ------- RandomStateManager New global instance Examples -------- >>> from scitex_core.repro import reset >>> rng_manager = reset(seed=123) """ global _GLOBAL_INSTANCE _GLOBAL_INSTANCE = RandomStateManager(seed, verbose=verbose) return _GLOBAL_INSTANCE
__all__ = ["RandomStateManager", "get", "reset"] # EOF