muutils.math.bins
"""Histogram-style bin specifications with linear or logarithmic spacing."""

from __future__ import annotations

from dataclasses import dataclass
from functools import cached_property
from typing import Literal

import numpy as np
from jaxtyping import Float


@dataclass(frozen=True)
class Bins:
    """Immutable description of ``n_bins`` bins spanning ``start`` to ``stop``.

    ``edges`` and ``centers`` are computed on first access and then cached on
    the instance.

    Attributes:
        n_bins: number of bins; ``edges`` has ``n_bins + 1`` entries.
        start: lower end of the range; must be >= 0 for ``"log"`` scale.
        stop: upper end of the range.
        scale: ``"lin"`` for evenly spaced edges, ``"log"`` for log-spaced edges.
        _log_min: for ``"log"`` scale with ``start == 0``, the first positive
            edge (log spacing cannot begin at 0).
        _zero_in_small_start_log: for ``"log"`` scale with
            ``0 < start < _log_min``, whether the first edge is 0 (the remaining
            ``n_bins`` edges are then log-spaced from ``start``).
    """

    n_bins: int = 32
    start: float = 0
    stop: float = 1.0
    scale: Literal["lin", "log"] = "log"

    _log_min: float = 1e-3
    _zero_in_small_start_log: bool = True

    @cached_property
    def edges(self) -> Float[np.ndarray, "n_bins+1"]:
        """Bin edges: ``n_bins + 1`` values from low to high.

        - ``scale == "lin"``: evenly spaced from ``start`` to ``stop``.
        - ``scale == "log"``:
            - ``start == 0``: ``0`` followed by ``n_bins`` log-spaced edges
              from ``_log_min`` to ``stop``.
            - ``0 < start < _log_min`` and ``_zero_in_small_start_log``:
              ``0`` followed by ``n_bins`` log-spaced edges from ``start``
              to ``stop``.
            - otherwise: ``n_bins + 1`` log-spaced edges from ``start`` to
              ``stop``.

        Raises:
            ValueError: if ``scale`` is neither ``"lin"`` nor ``"log"``; or,
                for ``"log"`` scale, if ``start`` is negative, if
                ``start == 0`` and ``_log_min`` is not positive, or if
                ``stop`` is below the first log-spaced edge (which would
                otherwise yield NaN or decreasing edges).
        """
        if self.scale == "lin":
            return np.linspace(self.start, self.stop, self.n_bins + 1)
        if self.scale != "log":
            raise ValueError(f"Invalid scale {self.scale}, expected lin or log")

        if self.start < 0:
            raise ValueError(
                f"start must be non-negative for log scale, got {self.start}"
            )

        # log spacing cannot begin at 0, so a zero start uses `_log_min` instead
        log_lo: float = self._log_min if self.start == 0 else self.start
        if log_lo <= 0:
            raise ValueError(
                f"_log_min must be positive for log scale, got {self._log_min}"
            )
        if self.stop < log_lo:
            # np.log10 of a non-positive stop gives nan/-inf, and a stop below
            # `log_lo` gives decreasing edges -- fail loudly instead
            raise ValueError(
                f"stop must be at least {log_lo} for log scale, got {self.stop}"
            )

        lo_exp = np.log10(log_lo)  # pyright: ignore[reportAny]
        hi_exp = np.log10(self.stop)  # pyright: ignore[reportAny]

        if self.start == 0 or (
            self.start < self._log_min and self._zero_in_small_start_log
        ):
            # first edge is exactly 0; the other n_bins edges are log-spaced
            return np.concatenate(
                [  # pyright: ignore[reportUnknownArgumentType]
                    np.array([0.0]),
                    np.logspace(lo_exp, hi_exp, self.n_bins),  # pyright: ignore[reportAny]
                ]
            )
        return np.logspace(lo_exp, hi_exp, self.n_bins + 1)  # pyright: ignore[reportUnknownVariableType]

    @cached_property
    def centers(self) -> Float[np.ndarray, "n_bins"]:
        """Arithmetic midpoint of each bin (also for ``"log"`` scale); length ``n_bins``."""
        return (self.edges[:-1] + self.edges[1:]) / 2

    def changed_n_bins_copy(self, n_bins: int) -> "Bins":
        """Return a new ``Bins`` identical to this one except for ``n_bins``."""
        return Bins(
            n_bins=n_bins,
            start=self.start,
            stop=self.stop,
            scale=self.scale,
            _log_min=self._log_min,
            _zero_in_small_start_log=self._zero_in_small_start_log,
        )
@dataclass(frozen=True)
class Bins:
@dataclass(frozen=True)
class Bins:
    """Immutable description of ``n_bins`` bins spanning ``start`` to ``stop``.

    ``edges`` and ``centers`` are computed on first access and then cached on
    the instance.

    Attributes:
        n_bins: number of bins; ``edges`` has ``n_bins + 1`` entries.
        start: lower end of the range; must be >= 0 for ``"log"`` scale.
        stop: upper end of the range.
        scale: ``"lin"`` for evenly spaced edges, ``"log"`` for log-spaced edges.
        _log_min: for ``"log"`` scale with ``start == 0``, the first positive
            edge (log spacing cannot begin at 0).
        _zero_in_small_start_log: for ``"log"`` scale with
            ``0 < start < _log_min``, whether the first edge is 0 (the remaining
            ``n_bins`` edges are then log-spaced from ``start``).
    """

    n_bins: int = 32
    start: float = 0
    stop: float = 1.0
    scale: Literal["lin", "log"] = "log"

    _log_min: float = 1e-3
    _zero_in_small_start_log: bool = True

    @cached_property
    def edges(self) -> Float[np.ndarray, "n_bins+1"]:
        """Bin edges: ``n_bins + 1`` values from low to high.

        - ``scale == "lin"``: evenly spaced from ``start`` to ``stop``.
        - ``scale == "log"``:
            - ``start == 0``: ``0`` followed by ``n_bins`` log-spaced edges
              from ``_log_min`` to ``stop``.
            - ``0 < start < _log_min`` and ``_zero_in_small_start_log``:
              ``0`` followed by ``n_bins`` log-spaced edges from ``start``
              to ``stop``.
            - otherwise: ``n_bins + 1`` log-spaced edges from ``start`` to
              ``stop``.

        Raises:
            ValueError: if ``scale`` is neither ``"lin"`` nor ``"log"``; or,
                for ``"log"`` scale, if ``start`` is negative, if
                ``start == 0`` and ``_log_min`` is not positive, or if
                ``stop`` is below the first log-spaced edge (which would
                otherwise yield NaN or decreasing edges).
        """
        if self.scale == "lin":
            return np.linspace(self.start, self.stop, self.n_bins + 1)
        if self.scale != "log":
            raise ValueError(f"Invalid scale {self.scale}, expected lin or log")

        if self.start < 0:
            raise ValueError(
                f"start must be non-negative for log scale, got {self.start}"
            )

        # log spacing cannot begin at 0, so a zero start uses `_log_min` instead
        log_lo: float = self._log_min if self.start == 0 else self.start
        if log_lo <= 0:
            raise ValueError(
                f"_log_min must be positive for log scale, got {self._log_min}"
            )
        if self.stop < log_lo:
            # np.log10 of a non-positive stop gives nan/-inf, and a stop below
            # `log_lo` gives decreasing edges -- fail loudly instead
            raise ValueError(
                f"stop must be at least {log_lo} for log scale, got {self.stop}"
            )

        lo_exp = np.log10(log_lo)  # pyright: ignore[reportAny]
        hi_exp = np.log10(self.stop)  # pyright: ignore[reportAny]

        if self.start == 0 or (
            self.start < self._log_min and self._zero_in_small_start_log
        ):
            # first edge is exactly 0; the other n_bins edges are log-spaced
            return np.concatenate(
                [  # pyright: ignore[reportUnknownArgumentType]
                    np.array([0.0]),
                    np.logspace(lo_exp, hi_exp, self.n_bins),  # pyright: ignore[reportAny]
                ]
            )
        return np.logspace(lo_exp, hi_exp, self.n_bins + 1)  # pyright: ignore[reportUnknownVariableType]

    @cached_property
    def centers(self) -> Float[np.ndarray, "n_bins"]:
        """Arithmetic midpoint of each bin (also for ``"log"`` scale); length ``n_bins``."""
        return (self.edges[:-1] + self.edges[1:]) / 2

    def changed_n_bins_copy(self, n_bins: int) -> "Bins":
        """Return a new ``Bins`` identical to this one except for ``n_bins``."""
        return Bins(
            n_bins=n_bins,
            start=self.start,
            stop=self.stop,
            scale=self.scale,
            _log_min=self._log_min,
            _zero_in_small_start_log=self._zero_in_small_start_log,
        )
Bins( n_bins: int = 32, start: float = 0, stop: float = 1.0, scale: Literal['lin', 'log'] = 'log', _log_min: float = 0.001, _zero_in_small_start_log: bool = True)
edges: jaxtyping.Float[ndarray, 'n_bins+1']
@cached_property
def edges(self) -> Float[np.ndarray, "n_bins+1"]:
    """Bin edges: ``n_bins + 1`` values from low to high.

    - ``scale == "lin"``: evenly spaced from ``start`` to ``stop``.
    - ``scale == "log"``:
        - ``start == 0``: ``0`` followed by ``n_bins`` log-spaced edges
          from ``_log_min`` to ``stop``.
        - ``0 < start < _log_min`` and ``_zero_in_small_start_log``:
          ``0`` followed by ``n_bins`` log-spaced edges from ``start``
          to ``stop``.
        - otherwise: ``n_bins + 1`` log-spaced edges from ``start`` to
          ``stop``.

    Raises:
        ValueError: if ``scale`` is neither ``"lin"`` nor ``"log"``; or,
            for ``"log"`` scale, if ``start`` is negative, if
            ``start == 0`` and ``_log_min`` is not positive, or if
            ``stop`` is below the first log-spaced edge (which would
            otherwise yield NaN or decreasing edges).
    """
    if self.scale == "lin":
        return np.linspace(self.start, self.stop, self.n_bins + 1)
    if self.scale != "log":
        raise ValueError(f"Invalid scale {self.scale}, expected lin or log")

    if self.start < 0:
        raise ValueError(
            f"start must be non-negative for log scale, got {self.start}"
        )

    # log spacing cannot begin at 0, so a zero start uses `_log_min` instead
    log_lo: float = self._log_min if self.start == 0 else self.start
    if log_lo <= 0:
        raise ValueError(
            f"_log_min must be positive for log scale, got {self._log_min}"
        )
    if self.stop < log_lo:
        # np.log10 of a non-positive stop gives nan/-inf, and a stop below
        # `log_lo` gives decreasing edges -- fail loudly instead
        raise ValueError(
            f"stop must be at least {log_lo} for log scale, got {self.stop}"
        )

    lo_exp = np.log10(log_lo)  # pyright: ignore[reportAny]
    hi_exp = np.log10(self.stop)  # pyright: ignore[reportAny]

    if self.start == 0 or (
        self.start < self._log_min and self._zero_in_small_start_log
    ):
        # first edge is exactly 0; the other n_bins edges are log-spaced
        return np.concatenate(
            [  # pyright: ignore[reportUnknownArgumentType]
                np.array([0.0]),
                np.logspace(lo_exp, hi_exp, self.n_bins),  # pyright: ignore[reportAny]
            ]
        )
    return np.logspace(lo_exp, hi_exp, self.n_bins + 1)  # pyright: ignore[reportUnknownVariableType]