docs for muutils v0.9.1
View Source on GitHub

muutils.math.bins


 1from __future__ import annotations
 2
 3from dataclasses import dataclass
 4from functools import cached_property
 5from typing import Literal
 6
 7import numpy as np
 8from jaxtyping import Float
 9
10
11@dataclass(frozen=True)
12class Bins:
13    n_bins: int = 32
14    start: float = 0
15    stop: float = 1.0
16    scale: Literal["lin", "log"] = "log"
17
18    _log_min: float = 1e-3
19    _zero_in_small_start_log: bool = True
20
21    @cached_property
22    def edges(self) -> Float[np.ndarray, "n_bins+1"]:
23        if self.scale == "lin":
24            return np.linspace(self.start, self.stop, self.n_bins + 1)
25        elif self.scale == "log":
26            if self.start < 0:
27                raise ValueError(
28                    f"start must be positive for log scale, got {self.start}"
29                )
30            if self.start == 0:
31                return np.concatenate(
32                    [  # pyright: ignore[reportUnknownArgumentType]
33                        np.array([0]),
34                        np.logspace(
35                            np.log10(self._log_min),  # pyright: ignore[reportAny]
36                            np.log10(self.stop),  # pyright: ignore[reportAny]
37                            self.n_bins,
38                        ),
39                    ]
40                )
41            elif self.start < self._log_min and self._zero_in_small_start_log:
42                return np.concatenate(
43                    [  # pyright: ignore[reportUnknownArgumentType]
44                        np.array([0]),
45                        np.logspace(
46                            np.log10(self.start),  # pyright: ignore[reportAny]
47                            np.log10(self.stop),  # pyright: ignore[reportAny]
48                            self.n_bins,
49                        ),
50                    ]
51                )
52            else:
53                return np.logspace(  # pyright: ignore[reportUnknownVariableType]
54                    np.log10(self.start),  # pyright: ignore[reportAny]
55                    np.log10(self.stop),  # pyright: ignore[reportAny]
56                    self.n_bins + 1,
57                )
58        else:
59            raise ValueError(f"Invalid scale {self.scale}, expected lin or log")
60
61    @cached_property
62    def centers(self) -> Float[np.ndarray, "n_bins"]:
63        return (self.edges[:-1] + self.edges[1:]) / 2
64
65    def changed_n_bins_copy(self, n_bins: int) -> "Bins":
66        return Bins(
67            n_bins=n_bins,
68            start=self.start,
69            stop=self.stop,
70            scale=self.scale,
71            _log_min=self._log_min,
72            _zero_in_small_start_log=self._zero_in_small_start_log,
73        )

@dataclass(frozen=True)
class Bins:
12@dataclass(frozen=True)
13class Bins:
14    n_bins: int = 32
15    start: float = 0
16    stop: float = 1.0
17    scale: Literal["lin", "log"] = "log"
18
19    _log_min: float = 1e-3
20    _zero_in_small_start_log: bool = True
21
22    @cached_property
23    def edges(self) -> Float[np.ndarray, "n_bins+1"]:
24        if self.scale == "lin":
25            return np.linspace(self.start, self.stop, self.n_bins + 1)
26        elif self.scale == "log":
27            if self.start < 0:
28                raise ValueError(
29                    f"start must be positive for log scale, got {self.start}"
30                )
31            if self.start == 0:
32                return np.concatenate(
33                    [  # pyright: ignore[reportUnknownArgumentType]
34                        np.array([0]),
35                        np.logspace(
36                            np.log10(self._log_min),  # pyright: ignore[reportAny]
37                            np.log10(self.stop),  # pyright: ignore[reportAny]
38                            self.n_bins,
39                        ),
40                    ]
41                )
42            elif self.start < self._log_min and self._zero_in_small_start_log:
43                return np.concatenate(
44                    [  # pyright: ignore[reportUnknownArgumentType]
45                        np.array([0]),
46                        np.logspace(
47                            np.log10(self.start),  # pyright: ignore[reportAny]
48                            np.log10(self.stop),  # pyright: ignore[reportAny]
49                            self.n_bins,
50                        ),
51                    ]
52                )
53            else:
54                return np.logspace(  # pyright: ignore[reportUnknownVariableType]
55                    np.log10(self.start),  # pyright: ignore[reportAny]
56                    np.log10(self.stop),  # pyright: ignore[reportAny]
57                    self.n_bins + 1,
58                )
59        else:
60            raise ValueError(f"Invalid scale {self.scale}, expected lin or log")
61
62    @cached_property
63    def centers(self) -> Float[np.ndarray, "n_bins"]:
64        return (self.edges[:-1] + self.edges[1:]) / 2
65
66    def changed_n_bins_copy(self, n_bins: int) -> "Bins":
67        return Bins(
68            n_bins=n_bins,
69            start=self.start,
70            stop=self.stop,
71            scale=self.scale,
72            _log_min=self._log_min,
73            _zero_in_small_start_log=self._zero_in_small_start_log,
74        )
Bins( n_bins: int = 32, start: float = 0, stop: float = 1.0, scale: Literal['lin', 'log'] = 'log', _log_min: float = 0.001, _zero_in_small_start_log: bool = True)
n_bins: int = 32
start: float = 0
stop: float = 1.0
scale: Literal['lin', 'log'] = 'log'
edges: jaxtyping.Float[ndarray, 'n_bins+1']
22    @cached_property
23    def edges(self) -> Float[np.ndarray, "n_bins+1"]:
24        if self.scale == "lin":
25            return np.linspace(self.start, self.stop, self.n_bins + 1)
26        elif self.scale == "log":
27            if self.start < 0:
28                raise ValueError(
29                    f"start must be positive for log scale, got {self.start}"
30                )
31            if self.start == 0:
32                return np.concatenate(
33                    [  # pyright: ignore[reportUnknownArgumentType]
34                        np.array([0]),
35                        np.logspace(
36                            np.log10(self._log_min),  # pyright: ignore[reportAny]
37                            np.log10(self.stop),  # pyright: ignore[reportAny]
38                            self.n_bins,
39                        ),
40                    ]
41                )
42            elif self.start < self._log_min and self._zero_in_small_start_log:
43                return np.concatenate(
44                    [  # pyright: ignore[reportUnknownArgumentType]
45                        np.array([0]),
46                        np.logspace(
47                            np.log10(self.start),  # pyright: ignore[reportAny]
48                            np.log10(self.stop),  # pyright: ignore[reportAny]
49                            self.n_bins,
50                        ),
51                    ]
52                )
53            else:
54                return np.logspace(  # pyright: ignore[reportUnknownVariableType]
55                    np.log10(self.start),  # pyright: ignore[reportAny]
56                    np.log10(self.stop),  # pyright: ignore[reportAny]
57                    self.n_bins + 1,
58                )
59        else:
60            raise ValueError(f"Invalid scale {self.scale}, expected lin or log")
centers: jaxtyping.Float[ndarray, 'n_bins']
62    @cached_property
63    def centers(self) -> Float[np.ndarray, "n_bins"]:
64        return (self.edges[:-1] + self.edges[1:]) / 2
def changed_n_bins_copy(self, n_bins: int) -> Bins:
66    def changed_n_bins_copy(self, n_bins: int) -> "Bins":
67        return Bins(
68            n_bins=n_bins,
69            start=self.start,
70            stop=self.stop,
71            scale=self.scale,
72            _log_min=self._log_min,
73            _zero_in_small_start_log=self._zero_in_small_start_log,
74        )