# Source code for pipolars.cache.strategies

"""Caching strategies for PI data.

This module provides different caching strategies for managing
how PI data is cached and retrieved.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Callable
from datetime import datetime, timedelta
from typing import Any

import polars as pl

from pipolars.cache.storage import CacheBackendBase


class CacheStrategy(ABC):
    """Base class for cache strategies.

    A strategy decides how data is cached and when cached data is
    served instead of being fetched from the server.
    """

    def __init__(self, backend: CacheBackendBase) -> None:
        """Initialize the strategy.

        Args:
            backend: Cache backend to use
        """
        self._backend = backend

    @property
    def backend(self) -> CacheBackendBase:
        """The cache backend this strategy reads from and writes to."""
        return self._backend

    @abstractmethod
    def get_or_fetch(
        self,
        key: str,
        fetch_func: Callable[[], pl.DataFrame],
    ) -> pl.DataFrame:
        """Return cached data for *key*, or fetch and cache it.

        Args:
            key: Cache key
            fetch_func: Called to produce the data on a cache miss

        Returns:
            DataFrame from cache or freshly fetched
        """

    def invalidate(self, key: str) -> bool:
        """Remove a single cache entry.

        Args:
            key: Cache key

        Returns:
            True if invalidated
        """
        return self._backend.delete(key)

    def clear_all(self) -> None:
        """Drop every entry from the backing cache."""
        self._backend.clear()

class TTLStrategy(CacheStrategy):
    """Time-to-live caching strategy.

    Every cached entry expires after a fixed duration; once expired,
    the data is re-fetched from the source.
    """

    def __init__(
        self,
        backend: CacheBackendBase,
        ttl: timedelta = timedelta(hours=24),
    ) -> None:
        """Initialize the TTL strategy.

        Args:
            backend: Cache backend
            ttl: Time-to-live for cached data
        """
        super().__init__(backend)
        self._ttl = ttl

    @property
    def ttl(self) -> timedelta:
        """Default time-to-live applied to cached entries."""
        return self._ttl

    def get_or_fetch(
        self,
        key: str,
        fetch_func: Callable[[], pl.DataFrame],
    ) -> pl.DataFrame:
        """Serve *key* from the cache, fetching and storing it on a miss."""
        hit = self._backend.get(key)
        if hit is not None:
            return hit
        fresh = fetch_func()
        self._backend.set(key, fresh, self._ttl)
        return fresh

    def set_with_ttl(
        self,
        key: str,
        data: pl.DataFrame,
        ttl: timedelta | None = None,
    ) -> None:
        """Cache *data* under *key* with a custom TTL.

        Args:
            key: Cache key
            data: Data to cache
            ttl: Custom TTL (falls back to the strategy default if None)
        """
        self._backend.set(key, data, ttl or self._ttl)
class SlidingWindowStrategy(CacheStrategy):
    """Sliding window caching strategy for time-series data.

    Optimized for time-series queries whose time range slides forward
    over time: data is cached in overlapping windows so new queries can
    reuse most of the previously fetched data.
    """

    def __init__(
        self,
        backend: CacheBackendBase,
        window_size: timedelta = timedelta(hours=24),
        overlap: timedelta = timedelta(hours=1),
    ) -> None:
        """Initialize the sliding window strategy.

        Args:
            backend: Cache backend
            window_size: Size of each cache window
            overlap: Overlap between consecutive windows
        """
        super().__init__(backend)
        self._window_size = window_size
        self._overlap = overlap

    def get_or_fetch(
        self,
        key: str,
        fetch_func: Callable[[], pl.DataFrame],
    ) -> pl.DataFrame:
        """Serve *key* from the cache, fetching on a miss.

        Freshly fetched data is cached with the window size as its TTL.
        """
        hit = self._backend.get(key)
        if hit is not None:
            return hit
        fresh = fetch_func()
        self._backend.set(key, fresh, self._window_size)
        return fresh

    def get_time_range_data(
        self,
        tag: str,
        start: datetime,
        end: datetime,
        fetch_func: Callable[[datetime, datetime], pl.DataFrame],
        timestamp_col: str = "timestamp",
    ) -> pl.DataFrame:
        """Get data for a time range, reusing cached windows.

        Args:
            tag: Tag name
            start: Start time
            end: End time
            fetch_func: Function to fetch data for a time range
            timestamp_col: Name of the timestamp column

        Returns:
            Combined DataFrame for the full time range
        """
        in_range = (pl.col(timestamp_col) >= start) & (pl.col(timestamp_col) <= end)
        pieces = []
        for window_start, window_end in self._calculate_windows(start, end):
            key = self._backend.generate_key(
                tag,
                window_start.isoformat(),
                window_end.isoformat(),
                "window",
            )
            window_data = self._backend.get(key)
            if window_data is None:
                # Cache miss: fetch and cache the whole window.
                window_data = fetch_func(window_start, window_end)
                self._backend.set(key, window_data, self._window_size)
            # Keep only the rows inside the requested range.
            pieces.append(window_data.filter(in_range))

        if not pieces:
            return pl.DataFrame()

        # Overlapping windows may contribute duplicate rows; keep one of each.
        return pl.concat(pieces).unique(subset=[timestamp_col]).sort(timestamp_col)

    def _calculate_windows(
        self,
        start: datetime,
        end: datetime,
    ) -> list[tuple[datetime, datetime]]:
        """Compute the cache window boundaries covering the query range.

        NOTE(review): assumes overlap < window_size — otherwise the
        cursor would never advance and this loop would not terminate.

        Args:
            start: Query start time
            end: Query end time

        Returns:
            List of (window_start, window_end) tuples
        """
        windows: list[tuple[datetime, datetime]] = []
        cursor = self._align_to_window(start)
        while cursor < end:
            windows.append((cursor, cursor + self._window_size))
            # Step forward, keeping `overlap` worth of shared data.
            cursor = cursor + self._window_size - self._overlap
        return windows

    def _align_to_window(self, dt: datetime) -> datetime:
        """Round *dt* down to the nearest window boundary.

        Boundaries are multiples of the window size measured from the
        Unix epoch, so alignment is stable across calls.

        Args:
            dt: Datetime to align

        Returns:
            Aligned datetime (same tzinfo as the input)
        """
        span = self._window_size.total_seconds()
        floored = (dt.timestamp() // span) * span
        return datetime.fromtimestamp(floored, tz=dt.tzinfo)


class SmartCacheStrategy(CacheStrategy):
    """Smart caching strategy that adapts based on query patterns.

    Analyzes query patterns and adjusts caching behavior (mainly the
    TTL chosen for each entry) for optimal performance.
    """

    def __init__(
        self,
        backend: CacheBackendBase,
        default_ttl: timedelta = timedelta(hours=24),
        hot_data_ttl: timedelta = timedelta(minutes=5),
        historical_ttl: timedelta = timedelta(days=30),
        hot_data_threshold: timedelta = timedelta(hours=1),
    ) -> None:
        """Initialize the smart cache strategy.

        Args:
            backend: Cache backend
            default_ttl: Default TTL for normal data
            hot_data_ttl: TTL for recent/hot data
            historical_ttl: TTL for historical data
            hot_data_threshold: Threshold for considering data "hot"
        """
        super().__init__(backend)
        self._default_ttl = default_ttl
        self._hot_data_ttl = hot_data_ttl
        self._historical_ttl = historical_ttl
        self._hot_data_threshold = hot_data_threshold
        self._query_history: list[dict[str, Any]] = []

    def get_or_fetch(
        self,
        key: str,
        fetch_func: Callable[[], pl.DataFrame],
    ) -> pl.DataFrame:
        """Serve *key* from the cache, fetching with a key-based TTL on a miss."""
        hit = self._backend.get(key)
        if hit is not None:
            return hit
        fresh = fetch_func()
        self._backend.set(key, fresh, self._select_ttl(key))
        return fresh

    def get_with_time_range(
        self,
        key: str,
        start: datetime,
        end: datetime,
        fetch_func: Callable[[], pl.DataFrame],
    ) -> pl.DataFrame:
        """Get data with a TTL chosen from the queried time range.

        Args:
            key: Cache key
            start: Query start time
            end: Query end time
            fetch_func: Function to fetch data

        Returns:
            DataFrame
        """
        hit = self._backend.get(key)
        if hit is not None:
            return hit
        fresh = fetch_func()
        self._backend.set(key, fresh, self._select_ttl_for_range(start, end))
        # Remember the query so popular patterns can be prefetched later.
        self._record_query(key, start, end)
        return fresh

    def _select_ttl(self, _key: str) -> timedelta:
        """Select a TTL based on key patterns (currently always the default)."""
        return self._default_ttl

    def _select_ttl_for_range(
        self,
        start: datetime,
        end: datetime,
    ) -> timedelta:
        """Select a TTL based on the queried time range.

        Recent data gets a short TTL as it may still be updating;
        historical data gets a long TTL as it is unlikely to change.
        """
        now = datetime.now(tz=start.tzinfo)
        age = now - end
        if age < self._hot_data_threshold:
            # End of range is recent: this is hot, still-changing data.
            return self._hot_data_ttl
        if age > timedelta(days=7):
            # Data older than a week is treated as settled history.
            return self._historical_ttl
        return self._default_ttl

    def _record_query(
        self,
        key: str,
        start: datetime,
        end: datetime,
    ) -> None:
        """Append a query to the history used for pattern analysis."""
        self._query_history.append({
            "key": key,
            "start": start,
            "end": end,
            "timestamp": datetime.now(),
        })
        # Bound memory: past 1000 entries, keep only the newest 500.
        if len(self._query_history) > 1000:
            self._query_history = self._query_history[-500:]

    def get_popular_patterns(self) -> list[dict[str, Any]]:
        """Analyze query history for popular patterns.

        Returns:
            The ten most frequently queried keys with their counts.
        """
        # Simple frequency analysis over the recorded keys.
        from collections import Counter

        frequency = Counter(entry["key"] for entry in self._query_history)
        return [
            {"key": key, "count": count}
            for key, count in frequency.most_common(10)
        ]

    def prefetch_popular(
        self,
        fetch_func: Callable[[str], pl.DataFrame],
    ) -> int:
        """Prefetch data for popular query patterns.

        Args:
            fetch_func: Function to fetch data by key

        Returns:
            Number of patterns prefetched
        """
        prefetched = 0
        for pattern in self.get_popular_patterns():
            key = pattern["key"]
            if self._backend.exists(key):
                continue
            try:
                self._backend.set(key, fetch_func(key), self._default_ttl)
                prefetched += 1
            except Exception:
                # Best-effort: a failed prefetch is simply skipped.
                pass
        return prefetched