mgplot.bar_plot

Create bar plots using Matplotlib.

Note: bar plots in Matplotlib differ from the bar charts found in some other libraries: they represent categorical data with rectangular bars. As a result, bar plots and line plots typically cannot be drawn on the same axes.

  1"""Create bar plots using Matplotlib.
  2
  3Note: bar plots in Matplotlib are not the same as bar charts in other
  4libraries. Bar plots are used to represent categorical data with
  5rectangular bars. As a result, bar plots and line plots typically
  6cannot be plotted on the same axes.
  7"""
  8
  9from collections.abc import Sequence
 10from typing import Any, Final, NotRequired, TypedDict, Unpack
 11
 12import matplotlib.patheffects as pe
 13import matplotlib.pyplot as plt
 14import numpy as np
 15from matplotlib.axes import Axes
 16from pandas import DataFrame, Period, Series
 17
 18from mgplot.axis_utils import map_periodindex, set_labels
 19from mgplot.keyword_checking import BaseKwargs, report_kwargs, validate_kwargs
 20from mgplot.settings import DataT, get_setting
 21from mgplot.utilities import (
 22    apply_defaults,
 23    constrain_data,
 24    default_rounding,
 25    get_axes,
 26    get_color_list,
 27)
 28
 29# --- constants
 30ME: Final[str] = "bar_plot"
 31MAX_ANNOTATIONS: Final[int] = 30
 32ADJUSTMENT_FACTOR: Final[float] = 0.02
 33MIN_BAR_WIDTH: Final[float] = 0.0
 34MAX_BAR_WIDTH: Final[float] = 1.0
 35DEFAULT_GROUPED_WIDTH: Final[float] = 0.8
 36DEFAULT_BAR_OFFSET: Final[float] = 0.5
 37DEFAULT_MAX_TICKS: Final[int] = 10
 38
 39
 40class BarKwargs(BaseKwargs):
 41    """Keyword arguments for the bar_plot function."""
 42
 43    # --- options for the entire bar plot
 44    ax: NotRequired[Axes | None]
 45    stacked: NotRequired[bool]
 46    max_ticks: NotRequired[int]
 47    plot_from: NotRequired[int | Period]
 48    label_rotation: NotRequired[int | float]
 49    # --- options for each bar ...
 50    color: NotRequired[str | Sequence[str]]
 51    label_series: NotRequired[bool | Sequence[bool]]
 52    width: NotRequired[float | int | Sequence[float | int]]
 53    zorder: NotRequired[int | float | Sequence[int | float]]
 54    # --- options for bar annotations
 55    annotate: NotRequired[bool]
 56    fontsize: NotRequired[int | float | str]
 57    fontname: NotRequired[str]
 58    rounding: NotRequired[int]
 59    rotation: NotRequired[int | float]
 60    annotate_color: NotRequired[str]
 61    above: NotRequired[bool]
 62
 63
 64# --- functions
 65class AnnoKwargs(TypedDict, total=False):
 66    """TypedDict for the kwargs used in annotate_bars."""
 67
 68    annotate: bool
 69    fontsize: int | float | str
 70    fontname: str
 71    color: str
 72    rotation: int | float
 73    foreground: str  # used for stroke effect on text
 74    above: bool
 75    rounding: bool | int  # if True, uses default rounding; if int, uses that value
 76
 77
 78def annotate_bars(
 79    series: Series,
 80    offset: float,
 81    base: np.ndarray,
 82    axes: Axes,
 83    **anno_kwargs: Unpack[AnnoKwargs],
 84) -> None:
 85    """Bar plot annotations.
 86
 87    Note: "annotate", "fontsize", "fontname", "color", and "rotation" are expected in anno_kwargs.
 88    """
 89    # --- only annotate in limited circumstances
 90    if "annotate" not in anno_kwargs or not anno_kwargs["annotate"]:
 91        return
 92    max_annotations = MAX_ANNOTATIONS
 93    if len(series) > max_annotations:
 94        return
 95
 96    # --- internal logic check
 97    if len(base) != len(series):
 98        print(f"Warning: base array length {len(base)} does not match series length {len(series)}.")
 99        return
100
101    # --- assemble the annotation parameters
102    above: Final[bool | None] = anno_kwargs.get("above", False)  # None is also False-ish
103    annotate_style: dict[str, Any] = {
104        "fontsize": anno_kwargs.get("fontsize"),
105        "fontname": anno_kwargs.get("fontname"),
106        "color": anno_kwargs.get("color"),
107        "rotation": anno_kwargs.get("rotation"),
108    }
109    rounding = default_rounding(series=series, provided=anno_kwargs.get("rounding"))
110    adjustment = (series.max() - series.min()) * ADJUSTMENT_FACTOR
111    zero_correction = series.index.min()
112
113    # --- annotate each bar
114    for index, value in zip(series.index.astype(int), series, strict=True):
115        position = base[index - zero_correction] + (adjustment if value >= 0 else -adjustment)
116        if above:
117            position += value
118        text = axes.text(
119            x=index + offset,
120            y=position,
121            s=f"{value:.{rounding}f}",
122            ha="center",
123            va="bottom" if value >= 0 else "top",
124            **annotate_style,
125        )
126        if not above and "foreground" in anno_kwargs:
127            # apply a stroke-effect to within bar annotations
128            # to make them more readable with very small bars.
129            text.set_path_effects([pe.withStroke(linewidth=2, foreground=anno_kwargs.get("foreground"))])
130
131
132class GroupedKwargs(TypedDict):
133    """TypedDict for the kwargs used in grouped."""
134
135    color: Sequence[str]
136    width: Sequence[float | int]
137    label_series: Sequence[bool]
138    zorder: Sequence[int | float | None]
139
140
def grouped(axes: Axes, df: DataFrame, anno_args: AnnoKwargs, **kwargs: Unpack[GroupedKwargs]) -> None:
    """Plot a grouped bar plot: the columns sit side by side within each x-slot.

    Each x-slot is one unit wide and centred on the (numeric) index value.
    A column's width is divided by the number of columns, so when all columns
    share the same width the group of bars is centred on the index value.
    Widths outside [MIN_BAR_WIDTH, MAX_BAR_WIDTH] fall back to
    DEFAULT_GROUPED_WIDTH. Columns that are entirely NaN are skipped but keep
    their place in the group, so the other columns do not shift.

    Args:
        axes: Axes to draw on.
        df: Data to plot, one bar series per column.
        anno_args: Annotation options for annotate_bars; this dict is not modified.
        **kwargs: Per-column sequences of color, width, label_series and zorder.
    """
    series_count = len(df.columns)

    for i, col in enumerate(df.columns):
        series = df[col]
        if series.isna().all():
            continue  # nothing to draw, but position i stays reserved
        width = kwargs["width"][i]
        if width < MIN_BAR_WIDTH or width > MAX_BAR_WIDTH:
            width = DEFAULT_GROUPED_WIDTH
        adjusted_width = width / series_count
        # far-left + margin + halfway through one grouped column
        left = -DEFAULT_BAR_OFFSET + ((1 - width) / 2.0) + (adjusted_width / 2.0)
        offset = left + (i * adjusted_width)
        foreground = kwargs["color"][i]
        axes.bar(
            x=series.index + offset,
            height=series,
            color=foreground,
            width=adjusted_width,
            zorder=kwargs["zorder"][i],
            # matplotlib omits labels starting with "_" from the legend
            label=col if kwargs["label_series"][i] else f"_{col}_",
        )
        # build a per-column copy rather than writing into the caller's dict
        column_anno_args: AnnoKwargs = {**anno_args, "foreground": foreground}
        annotate_bars(
            series=series,
            offset=offset,
            base=np.zeros(len(series)),
            axes=axes,
            **column_anno_args,
        )
173
174
175class StackedKwargs(TypedDict):
176    """TypedDict for the kwargs used in stacked."""
177
178    color: Sequence[str]
179    width: Sequence[float | int]
180    label_series: Sequence[bool]
181    zorder: Sequence[int | float | None]
182
183
def stacked(axes: Axes, df: DataFrame, anno_args: AnnoKwargs, **kwargs: Unpack[StackedKwargs]) -> None:
    """Plot a stacked bar plot.

    Columns are stacked in column order. Non-negative values sit on top of the
    running positive total and negative values below the running negative
    total, so the two directions build independently from zero.

    Args:
        axes: Axes to draw on.
        df: Data to plot, one bar series per column.
        anno_args: Annotation options for annotate_bars; this dict is not modified.
        **kwargs: Per-column sequences of color, width, label_series and zorder.
    """
    row_count = len(df)
    base_plus: np.ndarray = np.zeros(shape=row_count, dtype=np.float64)
    base_minus: np.ndarray = np.zeros(shape=row_count, dtype=np.float64)
    for i, col in enumerate(df.columns):
        series = df[col]
        # each bar starts from the running total on its own side of zero
        base = np.where(series >= 0, base_plus, base_minus)
        foreground = kwargs["color"][i]
        axes.bar(
            x=series.index,
            height=series,
            bottom=base,
            color=foreground,
            width=kwargs["width"][i],
            zorder=kwargs["zorder"][i],
            # matplotlib omits labels starting with "_" from the legend
            label=col if kwargs["label_series"][i] else f"_{col}_",
        )
        # build a per-column copy rather than writing into the caller's dict
        column_anno_args: AnnoKwargs = {**anno_args, "foreground": foreground}
        annotate_bars(
            series=series,
            offset=0,
            base=base,
            axes=axes,
            **column_anno_args,
        )
        # NaN compares False on both sides, so missing values add nothing
        base_plus += np.where(series >= 0, series, 0)
        base_minus += np.where(series < 0, series, 0)
212
213
def bar_plot(data: DataT, **kwargs: Unpack[BarKwargs]) -> Axes:
    """Create a bar plot from the given data.

    By default the columns of the DataFrame are drawn side by side (grouped)
    within each x-slot. With ``stacked=True`` the columns are stacked on top
    of each other instead, with positive values above zero and negative
    values below zero.

    Args:
        data: Series | DataFrame - The data to plot. A Series is converted to
            a single-column DataFrame.
        **kwargs: BarKwargs - Additional keyword arguments for customization
            (see BarKwargs for details).

    Note: This function does not assume all data is timeseries with a PeriodIndex.

    Returns:
        axes: Axes - The axes for the plot.

    """
    # --- check the kwargs
    report_kwargs(caller=ME, **kwargs)
    validate_kwargs(schema=BarKwargs, caller=ME, **kwargs)

    # --- get the data
    # no call to check_clean_timeseries here, as bar plots are not
    # necessarily timeseries data. If the data is a Series, it will be
    # converted to a DataFrame with a single column.
    df = DataFrame(data)  # really we are only plotting DataFrames
    df, kwargs_d = constrain_data(df, **kwargs)
    item_count = len(df.columns)

    # --- deal with complete PeriodIndex indices
    saved_pi = map_periodindex(df)
    if saved_pi is not None:
        df = saved_pi[0]  # extract the reindexed DataFrame from the PeriodIndex

    # --- set up the default arguments
    chart_args: dict[str, bool | int | float] = {
        "stacked": kwargs_d.get("stacked", False),
        "max_ticks": kwargs_d.get("max_ticks", DEFAULT_MAX_TICKS),
        # "label_rotation" is the key declared in BarKwargs; the older
        # "xlabel_rotation" spelling is still accepted as a fallback.
        "label_rotation": kwargs_d.get("label_rotation", kwargs_d.get("xlabel_rotation", 0)),
    }

    bar_defaults = {
        "color": get_color_list(item_count),
        "width": get_setting("bar_width"),
        "label_series": item_count > 1,
        "zorder": None,
    }
    above = kwargs_d.get("above", False)
    anno_args: AnnoKwargs = {
        "annotate": kwargs_d.get("annotate", False),
        "fontsize": kwargs_d.get("fontsize", "small"),
        "fontname": kwargs_d.get("fontname", "Helvetica"),
        "rotation": kwargs_d.get("rotation", 0),
        "rounding": kwargs_d.get("rounding", True),
        # dark text above the bars, light text inside them
        "color": kwargs_d.get("annotate_color", "black" if above else "white"),
        "above": above,
    }
    bar_args, remaining_kwargs = apply_defaults(item_count, bar_defaults, kwargs_d)

    # --- plot the data
    axes, remaining_kwargs = get_axes(**dict(remaining_kwargs))
    if chart_args["stacked"]:
        stacked(axes, df, anno_args, **bar_args)
    else:
        grouped(axes, df, anno_args, **bar_args)

    # --- handle complete PeriodIndex data and label rotation
    if saved_pi is not None:
        set_labels(axes, saved_pi[1], chart_args["max_ticks"])
    # Rotate the tick labels of *this* axes; plt.xticks() would act on the
    # pyplot "current" axes, which can differ from ``axes`` when the caller
    # supplies ``ax`` for another subplot.
    axes.tick_params(axis="x", labelrotation=chart_args["label_rotation"])

    return axes
# identifies this module when keyword arguments are reported/validated
ME: Final[str] = "bar_plot"

# annotation tuning
MAX_ANNOTATIONS: Final[int] = 30  # series longer than this are not annotated
ADJUSTMENT_FACTOR: Final[float] = 0.02  # label offset as a fraction of the data range

# bar geometry (x-axis units; each x-slot is one unit wide)
MIN_BAR_WIDTH: Final[float] = 0.0
MAX_BAR_WIDTH: Final[float] = 1.0
DEFAULT_GROUPED_WIDTH: Final[float] = 0.8  # fallback when a grouped width is out of range
DEFAULT_BAR_OFFSET: Final[float] = 0.5  # distance from a slot's centre to its left edge

# tick labelling
DEFAULT_MAX_TICKS: Final[int] = 10
class BarKwargs(mgplot.keyword_checking.BaseKwargs):
class BarKwargs(BaseKwargs):
    """Keyword arguments for the bar_plot function.

    Every key is optional. The per-bar options accept either a single value
    or a sequence of values; expanding them to one value per column is left
    to ``apply_defaults`` in bar_plot.
    """

    # --- options for the entire bar plot
    ax: NotRequired[Axes | None]  # axes to draw on; presumably consumed by get_axes — confirm
    stacked: NotRequired[bool]  # True: stack the columns; False (default): group them side by side
    max_ticks: NotRequired[int]  # tick limit passed to set_labels for PeriodIndex data
    plot_from: NotRequired[int | Period]  # presumably a start point used by constrain_data — confirm
    label_rotation: NotRequired[int | float]  # intended rotation of the x-axis tick labels
    # --- options for each bar ...
    color: NotRequired[str | Sequence[str]]  # bar colour(s)
    label_series: NotRequired[bool | Sequence[bool]]  # whether column(s) get a legend entry
    width: NotRequired[float | int | Sequence[float | int]]  # bar width(s) in x-axis units
    zorder: NotRequired[int | float | Sequence[int | float]]  # drawing order of the bars
    # --- options for bar annotations
    annotate: NotRequired[bool]  # write each bar's value on the plot
    fontsize: NotRequired[int | float | str]
    fontname: NotRequired[str]
    # True selects the default rounding; an int is used as the number of decimals
    # (bar_plot itself defaults this to True, so bool must be accepted here).
    rounding: NotRequired[bool | int]
    rotation: NotRequired[int | float]  # rotation of the annotation text
    annotate_color: NotRequired[str]  # colour of the annotation text
    above: NotRequired[bool]  # place annotations beyond the end of the bar rather than inside it

Keyword arguments for the bar_plot function.

ax: NotRequired[matplotlib.axes._axes.Axes | None]
stacked: NotRequired[bool]
max_ticks: NotRequired[int]
plot_from: NotRequired[int | pandas._libs.tslibs.period.Period]
label_rotation: NotRequired[int | float]
color: NotRequired[str | Sequence[str]]
label_series: NotRequired[bool | Sequence[bool]]
width: NotRequired[float | int | Sequence[float | int]]
zorder: NotRequired[float | int | Sequence[float | int]]
annotate: NotRequired[bool]
fontsize: NotRequired[int | float | str]
fontname: NotRequired[str]
rounding: NotRequired[int]
rotation: NotRequired[int | float]
annotate_color: NotRequired[str]
above: NotRequired[bool]
class AnnoKwargs(typing.TypedDict):
66class AnnoKwargs(TypedDict, total=False):
67    """TypedDict for the kwargs used in annotate_bars."""
68
69    annotate: bool
70    fontsize: int | float | str
71    fontname: str
72    color: str
73    rotation: int | float
74    foreground: str  # used for stroke effect on text
75    above: bool
76    rounding: bool | int  # if True, uses default rounding; if int, uses that value

TypedDict for the kwargs used in annotate_bars.

annotate: bool
fontsize: int | float | str
fontname: str
color: str
rotation: int | float
foreground: str
above: bool
rounding: bool | int
def annotate_bars( series: pandas.core.series.Series, offset: float, base: numpy.ndarray, axes: matplotlib.axes._axes.Axes, **anno_kwargs: Unpack[AnnoKwargs]) -> None:
 79def annotate_bars(
 80    series: Series,
 81    offset: float,
 82    base: np.ndarray,
 83    axes: Axes,
 84    **anno_kwargs: Unpack[AnnoKwargs],
 85) -> None:
 86    """Bar plot annotations.
 87
 88    Note: "annotate", "fontsize", "fontname", "color", and "rotation" are expected in anno_kwargs.
 89    """
 90    # --- only annotate in limited circumstances
 91    if "annotate" not in anno_kwargs or not anno_kwargs["annotate"]:
 92        return
 93    max_annotations = MAX_ANNOTATIONS
 94    if len(series) > max_annotations:
 95        return
 96
 97    # --- internal logic check
 98    if len(base) != len(series):
 99        print(f"Warning: base array length {len(base)} does not match series length {len(series)}.")
100        return
101
102    # --- assemble the annotation parameters
103    above: Final[bool | None] = anno_kwargs.get("above", False)  # None is also False-ish
104    annotate_style: dict[str, Any] = {
105        "fontsize": anno_kwargs.get("fontsize"),
106        "fontname": anno_kwargs.get("fontname"),
107        "color": anno_kwargs.get("color"),
108        "rotation": anno_kwargs.get("rotation"),
109    }
110    rounding = default_rounding(series=series, provided=anno_kwargs.get("rounding"))
111    adjustment = (series.max() - series.min()) * ADJUSTMENT_FACTOR
112    zero_correction = series.index.min()
113
114    # --- annotate each bar
115    for index, value in zip(series.index.astype(int), series, strict=True):
116        position = base[index - zero_correction] + (adjustment if value >= 0 else -adjustment)
117        if above:
118            position += value
119        text = axes.text(
120            x=index + offset,
121            y=position,
122            s=f"{value:.{rounding}f}",
123            ha="center",
124            va="bottom" if value >= 0 else "top",
125            **annotate_style,
126        )
127        if not above and "foreground" in anno_kwargs:
128            # apply a stroke-effect to within bar annotations
129            # to make them more readable with very small bars.
130            text.set_path_effects([pe.withStroke(linewidth=2, foreground=anno_kwargs.get("foreground"))])

Bar plot annotations.

Note: "annotate", "fontsize", "fontname", "color", and "rotation" are expected in anno_kwargs.

class GroupedKwargs(typing.TypedDict):
133class GroupedKwargs(TypedDict):
134    """TypedDict for the kwargs used in grouped."""
135
136    color: Sequence[str]
137    width: Sequence[float | int]
138    label_series: Sequence[bool]
139    zorder: Sequence[int | float | None]

TypedDict for the kwargs used in grouped.

color: Sequence[str]
width: Sequence[float | int]
label_series: Sequence[bool]
zorder: Sequence[int | float | None]
def grouped( axes: matplotlib.axes._axes.Axes, df: pandas.core.frame.DataFrame, anno_args: AnnoKwargs, **kwargs: Unpack[GroupedKwargs]) -> None:
def grouped(axes: Axes, df: DataFrame, anno_args: AnnoKwargs, **kwargs: Unpack[GroupedKwargs]) -> None:
    """Plot a grouped bar plot: the columns sit side by side within each x-slot.

    Each x-slot is one unit wide and centred on the (numeric) index value.
    A column's width is divided by the number of columns, so when all columns
    share the same width the group of bars is centred on the index value.
    Widths outside [MIN_BAR_WIDTH, MAX_BAR_WIDTH] fall back to
    DEFAULT_GROUPED_WIDTH. Columns that are entirely NaN are skipped but keep
    their place in the group, so the other columns do not shift.

    Args:
        axes: Axes to draw on.
        df: Data to plot, one bar series per column.
        anno_args: Annotation options for annotate_bars; this dict is not modified.
        **kwargs: Per-column sequences of color, width, label_series and zorder.
    """
    series_count = len(df.columns)

    for i, col in enumerate(df.columns):
        series = df[col]
        if series.isna().all():
            continue  # nothing to draw, but position i stays reserved
        width = kwargs["width"][i]
        if width < MIN_BAR_WIDTH or width > MAX_BAR_WIDTH:
            width = DEFAULT_GROUPED_WIDTH
        adjusted_width = width / series_count
        # far-left + margin + halfway through one grouped column
        left = -DEFAULT_BAR_OFFSET + ((1 - width) / 2.0) + (adjusted_width / 2.0)
        offset = left + (i * adjusted_width)
        foreground = kwargs["color"][i]
        axes.bar(
            x=series.index + offset,
            height=series,
            color=foreground,
            width=adjusted_width,
            zorder=kwargs["zorder"][i],
            # matplotlib omits labels starting with "_" from the legend
            label=col if kwargs["label_series"][i] else f"_{col}_",
        )
        # build a per-column copy rather than writing into the caller's dict
        column_anno_args: AnnoKwargs = {**anno_args, "foreground": foreground}
        annotate_bars(
            series=series,
            offset=offset,
            base=np.zeros(len(series)),
            axes=axes,
            **column_anno_args,
        )

Plot a grouped bar plot.

class StackedKwargs(typing.TypedDict):
176class StackedKwargs(TypedDict):
177    """TypedDict for the kwargs used in stacked."""
178
179    color: Sequence[str]
180    width: Sequence[float | int]
181    label_series: Sequence[bool]
182    zorder: Sequence[int | float | None]

TypedDict for the kwargs used in stacked.

color: Sequence[str]
width: Sequence[float | int]
label_series: Sequence[bool]
zorder: Sequence[int | float | None]
def stacked( axes: matplotlib.axes._axes.Axes, df: pandas.core.frame.DataFrame, anno_args: AnnoKwargs, **kwargs: Unpack[StackedKwargs]) -> None:
def stacked(axes: Axes, df: DataFrame, anno_args: AnnoKwargs, **kwargs: Unpack[StackedKwargs]) -> None:
    """Plot a stacked bar plot.

    Columns are stacked in column order. Non-negative values sit on top of the
    running positive total and negative values below the running negative
    total, so the two directions build independently from zero.

    Args:
        axes: Axes to draw on.
        df: Data to plot, one bar series per column.
        anno_args: Annotation options for annotate_bars; this dict is not modified.
        **kwargs: Per-column sequences of color, width, label_series and zorder.
    """
    row_count = len(df)
    base_plus: np.ndarray = np.zeros(shape=row_count, dtype=np.float64)
    base_minus: np.ndarray = np.zeros(shape=row_count, dtype=np.float64)
    for i, col in enumerate(df.columns):
        series = df[col]
        # each bar starts from the running total on its own side of zero
        base = np.where(series >= 0, base_plus, base_minus)
        foreground = kwargs["color"][i]
        axes.bar(
            x=series.index,
            height=series,
            bottom=base,
            color=foreground,
            width=kwargs["width"][i],
            zorder=kwargs["zorder"][i],
            # matplotlib omits labels starting with "_" from the legend
            label=col if kwargs["label_series"][i] else f"_{col}_",
        )
        # build a per-column copy rather than writing into the caller's dict
        column_anno_args: AnnoKwargs = {**anno_args, "foreground": foreground}
        annotate_bars(
            series=series,
            offset=0,
            base=base,
            axes=axes,
            **column_anno_args,
        )
        # NaN compares False on both sides, so missing values add nothing
        base_plus += np.where(series >= 0, series, 0)
        base_minus += np.where(series < 0, series, 0)

Plot a stacked bar plot.

def bar_plot( data: ~DataT, **kwargs: Unpack[BarKwargs]) -> matplotlib.axes._axes.Axes:
def bar_plot(data: DataT, **kwargs: Unpack[BarKwargs]) -> Axes:
    """Create a bar plot from the given data.

    By default the columns of the DataFrame are drawn side by side (grouped)
    within each x-slot. With ``stacked=True`` the columns are stacked on top
    of each other instead, with positive values above zero and negative
    values below zero.

    Args:
        data: Series | DataFrame - The data to plot. A Series is converted to
            a single-column DataFrame.
        **kwargs: BarKwargs - Additional keyword arguments for customization
            (see BarKwargs for details).

    Note: This function does not assume all data is timeseries with a PeriodIndex.

    Returns:
        axes: Axes - The axes for the plot.

    """
    # --- check the kwargs
    report_kwargs(caller=ME, **kwargs)
    validate_kwargs(schema=BarKwargs, caller=ME, **kwargs)

    # --- get the data
    # no call to check_clean_timeseries here, as bar plots are not
    # necessarily timeseries data. If the data is a Series, it will be
    # converted to a DataFrame with a single column.
    df = DataFrame(data)  # really we are only plotting DataFrames
    df, kwargs_d = constrain_data(df, **kwargs)
    item_count = len(df.columns)

    # --- deal with complete PeriodIndex indices
    saved_pi = map_periodindex(df)
    if saved_pi is not None:
        df = saved_pi[0]  # extract the reindexed DataFrame from the PeriodIndex

    # --- set up the default arguments
    chart_args: dict[str, bool | int | float] = {
        "stacked": kwargs_d.get("stacked", False),
        "max_ticks": kwargs_d.get("max_ticks", DEFAULT_MAX_TICKS),
        # "label_rotation" is the key declared in BarKwargs; the older
        # "xlabel_rotation" spelling is still accepted as a fallback.
        "label_rotation": kwargs_d.get("label_rotation", kwargs_d.get("xlabel_rotation", 0)),
    }

    bar_defaults = {
        "color": get_color_list(item_count),
        "width": get_setting("bar_width"),
        "label_series": item_count > 1,
        "zorder": None,
    }
    above = kwargs_d.get("above", False)
    anno_args: AnnoKwargs = {
        "annotate": kwargs_d.get("annotate", False),
        "fontsize": kwargs_d.get("fontsize", "small"),
        "fontname": kwargs_d.get("fontname", "Helvetica"),
        "rotation": kwargs_d.get("rotation", 0),
        "rounding": kwargs_d.get("rounding", True),
        # dark text above the bars, light text inside them
        "color": kwargs_d.get("annotate_color", "black" if above else "white"),
        "above": above,
    }
    bar_args, remaining_kwargs = apply_defaults(item_count, bar_defaults, kwargs_d)

    # --- plot the data
    axes, remaining_kwargs = get_axes(**dict(remaining_kwargs))
    if chart_args["stacked"]:
        stacked(axes, df, anno_args, **bar_args)
    else:
        grouped(axes, df, anno_args, **bar_args)

    # --- handle complete PeriodIndex data and label rotation
    if saved_pi is not None:
        set_labels(axes, saved_pi[1], chart_args["max_ticks"])
    # Rotate the tick labels of *this* axes; plt.xticks() would act on the
    # pyplot "current" axes, which can differ from ``axes`` when the caller
    # supplies ``ax`` for another subplot.
    axes.tick_params(axis="x", labelrotation=chart_args["label_rotation"])

    return axes

Create a bar plot from the given data.

By default, the columns of the DataFrame are drawn side by side (grouped). With `stacked=True`, the columns are stacked on top of each other instead, with positive values above zero and negative values below zero.

Args: data: Series | DataFrame - The data to plot. Can be a DataFrame or a Series. **kwargs: BarKwargs - Additional keyword arguments for customization. (see BarKwargs for details)

Note: This function does not assume all data is timeseries with a PeriodIndex.

Returns: axes: Axes - The axes for the plot.