#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# File: ./src/scitex_core/repro/_RandomStateManager.py
"""
Clean, simple RandomStateManager for scientific reproducibility.

Manages random number generators across multiple libraries (numpy, torch,
tensorflow, jax) with deterministic seeding and verification capabilities.

Main API:
    rng_manager = RandomStateManager(seed=42)  # Create instance
    gen = rng_manager("name")                  # Get named generator
    rng_manager.verify(obj, "name")            # Verify reproducibility
"""
import hashlib
import json
import os
import pickle
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
# Process-wide singleton: created lazily by ``get()`` and replaced by
# ``reset()``; stays ``None`` until one of them is first called.
_GLOBAL_INSTANCE: Optional["RandomStateManager"] = None
class RandomStateManager:
    """
    Simple, robust random state manager for scientific computing.

    Provides centralized management of random number generators with
    deterministic seeding across multiple ML/scientific libraries.

    Parameters
    ----------
    seed : int, optional
        Master seed for all random number generators (default: 42)
    verbose : bool, optional
        Print status messages (default: False)

    Examples
    --------
    >>> from scitex_core.repro import RandomStateManager
    >>>
    >>> # Direct usage
    >>> rng_manager = RandomStateManager(seed=42)
    >>> gen = rng_manager("data")
    >>> data = gen.random(100)
    >>>
    >>> # Verify reproducibility
    >>> rng_manager.verify(data, "my_data")
    >>>
    >>> # Named generators for different purposes
    >>> data_gen = rng_manager("data")
    >>> model_gen = rng_manager("model")
    >>> augment_gen = rng_manager("augment")

    Notes
    -----
    - Automatically detects and seeds available libraries (numpy, torch, tf, jax)
    - Creates independent named generators for different experiment
      components. Each name gets a sub-seed derived from the master seed plus
      an MD5 hash of the name, so the same (seed, name) pair always yields
      the same stream.
    - Verification cache stored in ~/.scitex/rng/ as ``<name>.json``;
      checkpoints are written to the same directory as ``<name>.pkl``.
    """

    def __init__(self, seed: int = 42, verbose: bool = False):
        """
        Initialize with automatic module detection.

        Creates the cache directory ``~/.scitex/rng`` (if missing) and seeds
        every importable random module via ``_auto_fix_seeds``.
        """
        self.seed = seed
        self.verbose = verbose
        # Named generators, created lazily on first request.
        self._generators: Dict[str, Any] = {}
        self._torch_generators: Dict[str, Any] = {}
        self._cache_dir = Path.home() / ".scitex" / "rng"
        self._cache_dir.mkdir(parents=True, exist_ok=True)
        self._jax_key = None
        self._np = None
        self._np_default_rng_manager = None
        if verbose:
            print(f"RandomStateManager initialized with seed {seed}")
        # Auto-fix all available seeds
        self._auto_fix_seeds(verbose=verbose)

    def _auto_fix_seeds(self, verbose: Optional[bool] = None):
        """
        Seed every importable random module with ``self.seed``.

        Libraries that are not installed are skipped silently.

        Parameters
        ----------
        verbose : bool, optional
            Print which modules were seeded; defaults to ``self.verbose``.
        """
        if verbose is None:
            verbose = self.verbose

        # NOTE: PYTHONHASHSEED only affects interpreters started after this
        # point (e.g. subprocesses/workers), not the running process.
        os.environ["PYTHONHASHSEED"] = str(self.seed)
        # Setting used for deterministic cuBLAS; presumably only effective if
        # set before CUDA is initialized - confirm for your setup.
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

        fixed_modules = []

        # Python random
        try:
            import random

            random.seed(self.seed)
            fixed_modules.append("random")
        except ImportError:
            pass

        # NumPy: seed the legacy global RNG and keep a seeded default_rng.
        try:
            import numpy as np

            np.random.seed(self.seed)
            self._np = np
            self._np_default_rng_manager = np.random.default_rng(self.seed)
            fixed_modules.append("numpy")
        except ImportError:
            self._np = None
            self._np_default_rng_manager = None

        # PyTorch (CPU always; CUDA + deterministic cuDNN when available)
        try:
            import torch

            torch.manual_seed(self.seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(self.seed)
                torch.backends.cudnn.deterministic = True
                torch.backends.cudnn.benchmark = False
                fixed_modules.append("torch+cuda")
            else:
                fixed_modules.append("torch")
        except ImportError:
            pass

        # TensorFlow
        try:
            import tensorflow as tf

            tf.random.set_seed(self.seed)
            fixed_modules.append("tensorflow")
        except ImportError:
            pass

        # JAX has no global RNG; keep an explicit PRNG key instead.
        try:
            import jax

            self._jax_key = jax.random.PRNGKey(self.seed)
            fixed_modules.append("jax")
        except (ImportError, AttributeError, RuntimeError):
            self._jax_key = None

        if verbose and fixed_modules:
            print(f"Fixed random seeds for: {', '.join(fixed_modules)}")

    def _derive_seed(self, name: str) -> int:
        """
        Return the deterministic 32-bit sub-seed for generator ``name``.

        The first 8 hex digits of the MD5 digest of ``name`` are added to the
        master seed and reduced modulo 2**32, so the result is always in
        ``[0, 2**32)``.
        """
        name_hash = int(hashlib.md5(name.encode()).hexdigest()[:8], 16)
        return (self.seed + name_hash) % (2**32)

    @staticmethod
    def _sanitize_name(name: str, keep: str = "") -> str:
        """
        Map ``name`` to a safe cache-file stem.

        Alphanumerics, ``-``, ``_`` and any character in ``keep`` are kept;
        everything else (spaces, dots, path separators, ...) becomes ``_``.
        """
        allowed = "-_" + keep
        return "".join(c if c.isalnum() or c in allowed else "_" for c in name)

    def get_np_generator(self, name: str):
        """
        Get or create a named NumPy random generator.

        Parameters
        ----------
        name : str
            Generator name (e.g., "data", "model", "augment")

        Returns
        -------
        numpy.random.Generator
            Independent NumPy random generator. The same object is returned
            on every call with the same name.

        Raises
        ------
        ImportError
            If NumPy is not available.

        Examples
        --------
        >>> rng_manager = RandomStateManager(42)
        >>> gen = rng_manager.get_np_generator("data")
        >>> values = gen.random(100)
        >>> perm = gen.permutation(100)
        """
        if self._np is None:
            raise ImportError("NumPy required for random generators")
        if name not in self._generators:
            self._generators[name] = self._np.random.default_rng(
                self._derive_seed(name)
            )
        return self._generators[name]

    def __call__(self, name: str, verbose: Optional[bool] = None):
        """
        Get or create a named NumPy random generator.

        This is a convenience wrapper for get_np_generator().

        Parameters
        ----------
        name : str
            Generator name
        verbose : bool, optional
            If truthy, print a note recommending ``get_np_generator()``.

        Returns
        -------
        numpy.random.Generator
            NumPy random generator with deterministic seed
        """
        if verbose:
            print(
                f"Note: rng('{name}') is deprecated. "
                f"Use rng.get_np_generator('{name}') instead."
            )
        return self.get_np_generator(name)

    def verify(
        self,
        obj: Any,
        name: Optional[str] = None,
        verbose: Optional[bool] = True,
        strict: bool = True,
    ) -> bool:
        """
        Verify object matches cached hash (detects broken reproducibility).

        First call: caches the object's hash.
        Later calls: verifies the object matches the cached hash.

        Parameters
        ----------
        obj : Any
            Object to verify (array, tensor, data, model weights, etc.)
            Supports: numpy arrays, torch tensors, tf tensors, jax arrays,
            lists, dicts, pandas dataframes, and basic types
        name : str, optional
            Cache name. Auto-generated from caller location
            (``<file stem>_L<line>``) if not provided.
        verbose : bool or None, optional
            Print verification results (default: True). ``None`` uses the
            instance's ``verbose`` setting. Only controls printing.
        strict : bool, optional
            If True (default), a mismatch raises ``ValueError``; if False, a
            mismatch returns False instead.

        Returns
        -------
        bool
            True if the object matches the cache (or on the first call);
            False on mismatch when ``strict=False``.

        Raises
        ------
        ValueError
            If ``strict`` is True and the object doesn't match the cached hash.

        Notes
        -----
        The cache is keyed by name only. The seed is stored in the cache file
        but not compared, so verifying the same name under a different seed
        reports a mismatch.

        Examples
        --------
        >>> data = generate_data()
        >>> rng_manager.verify(data, "train_data")  # First run: caches
        >>> # Next run:
        >>> rng_manager.verify(data, "train_data")  # Verifies match
        """
        if name is None:
            import inspect

            frame = inspect.currentframe().f_back
            try:
                filename = Path(frame.f_code.co_filename).stem
                lineno = frame.f_lineno
            finally:
                # Drop the frame reference to avoid a reference cycle.
                del frame
            name = f"{filename}_L{lineno}"

        cache_file = self._cache_dir / f"{self._sanitize_name(name)}.json"
        obj_hash = self._compute_hash(obj)

        if verbose is None:
            verbose = self.verbose

        if not cache_file.exists():
            # First call - cache it
            with open(cache_file, "w", encoding="utf-8") as f:
                json.dump({"name": name, "hash": obj_hash, "seed": self.seed}, f)
            return True

        with open(cache_file, "r", encoding="utf-8") as f:
            cached = json.load(f)

        if cached["hash"] == obj_hash:
            if verbose:
                print(f"✓ Reproducibility verified for '{name}'")
            return True

        # Mismatch: printing depends on verbose, failing depends on strict.
        if verbose:
            print(f"⚠️ Reproducibility broken for '{name}'!")
            print(f" Expected: {cached['hash'][:16]}...")
            print(f" Got: {obj_hash[:16]}...")
        if strict:
            raise ValueError(f"Reproducibility verification failed for '{name}'")
        return False

    def _compute_hash(self, obj: Any) -> str:
        """
        Compute a short, deterministic content hash for ``obj``.

        Supports:
        - NumPy arrays (raw buffer; object-dtype arrays are hashed by value)
        - PyTorch tensors
        - TensorFlow tensors / variables
        - JAX arrays
        - Pandas DataFrames/Series
        - Lists, tuples (numeric ones hash like the equivalent NumPy array)
        - Dicts (key-sorted JSON)
        - Basic types (int, float, str, bool), via ``str(obj)``

        Returns
        -------
        str
            First 32 hex characters of a SHA-256 digest.

        Notes
        -----
        Array hashes cover only the raw element bytes, not shape or dtype.
        Framework types are only checked if that framework is already
        imported (an instance of its types cannot exist otherwise), so
        hashing never triggers a slow torch/tensorflow/jax/pandas import.
        The ``str(obj)`` fallback is only reproducible if the object's
        string form is (a default ``repr`` includes a memory address).
        """
        import sys

        try:
            import numpy as np
        except ImportError:
            np = None

        def digest(data: bytes) -> str:
            return hashlib.sha256(data).hexdigest()[:32]

        # NumPy array
        if np is not None and isinstance(obj, np.ndarray):
            if obj.dtype == object:
                # tobytes() of an object array is the PyObject pointers, which
                # change between runs; hash the element values instead.
                values = obj.tolist()
                try:
                    text = json.dumps(values, sort_keys=True, default=str)
                except Exception:
                    text = str(values)
                return digest(text.encode())
            return digest(obj.tobytes())

        # PyTorch tensor
        torch = sys.modules.get("torch")
        if torch is not None and isinstance(obj, torch.Tensor):
            return digest(obj.detach().cpu().numpy().tobytes())

        # TensorFlow tensor
        tf = sys.modules.get("tensorflow")
        if tf is not None and isinstance(obj, (tf.Tensor, tf.Variable)):
            return digest(obj.numpy().tobytes())

        # JAX array
        if np is not None and "jax" in sys.modules:
            try:
                import jax.numpy as jnp
            except ImportError:
                jnp = None
            if jnp is not None and isinstance(obj, jnp.ndarray):
                return digest(np.array(obj).tobytes())

        # Pandas DataFrame/Series
        pd = sys.modules.get("pandas")
        if pd is not None and isinstance(obj, (pd.DataFrame, pd.Series)):
            text = obj.to_json(orient="split", date_format="iso")
            return digest(text.encode())

        # Lists and tuples - hash as a numeric array when possible
        if np is not None and isinstance(obj, (list, tuple)):
            try:
                arr = np.array(obj)
            except Exception:
                # Best effort: ragged/exotic contents use the str() fallback.
                arr = None
            if arr is not None and arr.dtype != object:
                return digest(arr.tobytes())

        # Dictionaries - key-sorted JSON
        if isinstance(obj, dict):
            try:
                text = json.dumps(obj, sort_keys=True, default=str)
            except Exception:
                # e.g. unsortable mixed-type keys: use the str() fallback.
                text = None
            if text is not None:
                return digest(text.encode())

        # Default: convert to string
        return digest(str(obj).encode())

    def checkpoint(self, name: str = "checkpoint"):
        """
        Save the current state of all named generators.

        Stores the master seed plus the state of every named NumPy and
        PyTorch generator created so far. Global module RNG states
        (``random``, ``np.random``, ``torch``) are not saved.

        Parameters
        ----------
        name : str, optional
            Checkpoint name (default: "checkpoint")

        Returns
        -------
        Path
            Path to checkpoint file (``<cache dir>/<name>.pkl``)
        """
        checkpoint_file = self._cache_dir / f"{name}.pkl"
        state = {
            "seed": self.seed,
            "generators": {
                k: v.bit_generator.state for k, v in self._generators.items()
            },
            "torch_generators": {
                k: v.get_state() for k, v in self._torch_generators.items()
            },
        }
        with open(checkpoint_file, "wb") as f:
            pickle.dump(state, f)
        return checkpoint_file

    def restore(self, checkpoint: Union[str, Path]):
        """
        Restore the seed and named-generator states from a checkpoint.

        Parameters
        ----------
        checkpoint : str or Path
            Path to checkpoint file

        Notes
        -----
        - All available modules are re-seeded with the restored seed.
        - Generators named in the checkpoint keep their identity (existing
          objects are updated in place), so references obtained earlier see
          the restored state.
        - Generators *not* in the checkpoint are discarded and re-derived
          from the restored seed on their next request.
        - Checkpoints without PyTorch generator states are still accepted.

        Warnings
        --------
        Uses ``pickle.load``; only restore checkpoint files you trust.
        """
        checkpoint = Path(checkpoint)
        # SECURITY: unpickling can execute arbitrary code - never point this
        # at a file from an untrusted source.
        with open(checkpoint, "rb") as f:
            state = pickle.load(f)

        self.seed = state["seed"]
        self._auto_fix_seeds()

        previous_np = self._generators
        self._generators = {}
        for gen_name, gen_state in state["generators"].items():
            gen = previous_np.get(gen_name)
            if gen is None:
                gen = self.get_np_generator(gen_name)
            else:
                self._generators[gen_name] = gen
            gen.bit_generator.state = gen_state

        previous_torch = self._torch_generators
        self._torch_generators = {}
        for gen_name, gen_state in state.get("torch_generators", {}).items():
            gen = previous_torch.get(gen_name)
            if gen is None:
                gen = self.get_torch_generator(gen_name)
            else:
                self._torch_generators[gen_name] = gen
            gen.set_state(gen_state)

    @contextmanager
    def temporary_seed(self, seed: int):
        """
        Context manager for a temporary seed change.

        Reseeds Python's ``random`` and (if available) NumPy's global legacy
        RNG. Both states are restored on exit, even if seeding or the body
        raises. Named generators and other libraries are not affected.

        Parameters
        ----------
        seed : int
            Temporary seed value

        Examples
        --------
        >>> rng_manager = RandomStateManager(42)
        >>> with rng_manager.temporary_seed(123):
        ...     data = np.random.random(10)
        """
        import random

        np = self._np
        # Save current states
        old_random_state = random.getstate()
        old_np_state = np.random.get_state() if np is not None else None
        try:
            # Seeding is inside the try so that a failure (e.g. a seed NumPy
            # rejects) still restores whatever was already changed.
            random.seed(seed)
            if np is not None:
                np.random.seed(seed)
            yield
        finally:
            random.setstate(old_random_state)
            if old_np_state is not None:
                np.random.set_state(old_np_state)

    def get_sklearn_random_state(self, name: str) -> int:
        """
        Get a random state for scikit-learn.

        Scikit-learn uses integers for the random_state parameter. The value
        equals the sub-seed used for the named NumPy/PyTorch generators of
        the same name.

        Parameters
        ----------
        name : str
            Generator name

        Returns
        -------
        int
            Random state integer in ``[0, 2**32)``

        Examples
        --------
        >>> rng_manager = RandomStateManager(42)
        >>> from sklearn.model_selection import train_test_split
        >>> X_train, X_test = train_test_split(
        ...     X, test_size=0.2,
        ...     random_state=rng_manager.get_sklearn_random_state("split")
        ... )
        """
        return self._derive_seed(name)

    def get_torch_generator(self, name: str):
        """
        Get or create a named PyTorch generator.

        Parameters
        ----------
        name : str
            Generator name

        Returns
        -------
        torch.Generator
            PyTorch generator with deterministic seed. The same object is
            returned on every call with the same name.

        Raises
        ------
        ImportError
            If PyTorch is not installed.

        Examples
        --------
        >>> rng_manager = RandomStateManager(42)
        >>> gen = rng_manager.get_torch_generator("model")
        >>> torch.randn(5, 5, generator=gen)
        """
        try:
            import torch
        except ImportError as err:
            raise ImportError("PyTorch not installed") from err
        gen = self._torch_generators.get(name)
        if gen is None:
            gen = torch.Generator()
            gen.manual_seed(self._derive_seed(name))
            self._torch_generators[name] = gen
        return gen

    def get_generator(self, name: str):
        """Alias for get_np_generator for compatibility."""
        return self.get_np_generator(name)

    def clear_cache(self, patterns: Union[str, List[str], None] = None) -> int:
        """
        Clear verification cache files.

        Parameters
        ----------
        patterns : str or list of str, optional
            Specific cache patterns to clear. If None, clears all.
            Names are sanitized exactly as ``verify`` does, so the name you
            passed to ``verify`` can be passed here unchanged.
            Can be:
            - Single name: "my_data"
            - List of names: ["data1", "data2"]
            - Glob pattern (contains ``*`` or ``?``): "experiment_*"
            - None: clear all cache files

        Returns
        -------
        int
            Number of cache files removed. Only ``.json`` verification files
            are touched; ``.pkl`` checkpoints are kept.

        Examples
        --------
        >>> rng_manager = RandomStateManager(42)
        >>> rng_manager.clear_cache()  # Clear all
        >>> rng_manager.clear_cache("old_data")  # Clear specific
        >>> rng_manager.clear_cache(["test1", "test2"])  # Clear multiple
        >>> rng_manager.clear_cache("experiment_*")  # Clear pattern
        """
        if not self._cache_dir.exists():
            return 0

        if patterns is None:
            patterns = ["*"]
        elif isinstance(patterns, str):
            patterns = [patterns]

        removed_count = 0
        for pattern in patterns:
            if "*" in pattern or "?" in pattern:
                # Keep glob metacharacters; sanitize the rest like verify().
                stem = self._sanitize_name(pattern, keep="*?[]")
                cache_files = list(self._cache_dir.glob(f"{stem}.json"))
            else:
                cache_file = self._cache_dir / f"{self._sanitize_name(pattern)}.json"
                cache_files = [cache_file] if cache_file.exists() else []
            for cache_file in cache_files:
                cache_file.unlink()
                removed_count += 1
        return removed_count
def get(verbose: bool = False) -> RandomStateManager:
    """
    Get or create the global RandomStateManager instance.

    On the first call the instance is created with seed 42; later calls
    return that same object. Use ``reset()`` to replace it with a different
    seed.

    Parameters
    ----------
    verbose : bool, optional
        Whether to print status messages (default: False). Only used when
        the instance is created; ignored if it already exists.

    Returns
    -------
    RandomStateManager
        Global instance

    Examples
    --------
    >>> from scitex_core.repro import get
    >>> rng_manager = get()
    >>> data = rng_manager("data").random(100)
    """
    global _GLOBAL_INSTANCE
    if _GLOBAL_INSTANCE is None:
        _GLOBAL_INSTANCE = RandomStateManager(42, verbose=verbose)
    return _GLOBAL_INSTANCE
def reset(seed: int = 42, verbose: bool = False) -> RandomStateManager:
    """
    Reset global RandomStateManager with new seed.

    Always constructs a fresh instance, which re-seeds all available
    libraries, and makes it the one returned by ``get()``. Generators
    obtained from the previous instance are left untouched.

    Parameters
    ----------
    seed : int
        New seed value
    verbose : bool, optional
        Whether to print status messages (default: False)

    Returns
    -------
    RandomStateManager
        New global instance

    Examples
    --------
    >>> from scitex_core.repro import reset
    >>> rng_manager = reset(seed=123)
    """
    global _GLOBAL_INSTANCE
    _GLOBAL_INSTANCE = RandomStateManager(seed, verbose=verbose)
    return _GLOBAL_INSTANCE
# Public API re-exported by ``from ... import *``.
__all__ = [
    "RandomStateManager",
    "get",
    "reset",
]

# EOF