mgplot.bar_plot

Create bar plots using Matplotlib.

Note: bar plots in Matplotlib are not the same as bar charts in other libraries. Bar plots are used to represent categorical data with rectangular bars. As a result, bar plots and line plots typically cannot be plotted on the same axes.

  1"""Create bar plots using Matplotlib.
  2
  3Note: bar plots in Matplotlib are not the same as bar charts in other
  4libraries. Bar plots are used to represent categorical data with
  5rectangular bars. As a result, bar plots and line plots typically
  6cannot be plotted on the same axes.
  7"""
  8
  9from collections.abc import Sequence
 10from typing import Any, Final, NotRequired, TypedDict, Unpack
 11
 12import matplotlib.patheffects as pe
 13import matplotlib.pyplot as plt
 14import numpy as np
 15from matplotlib.axes import Axes
 16from pandas import DataFrame, Period, Series
 17
 18from mgplot.axis_utils import map_periodindex, set_labels
 19from mgplot.keyword_checking import BaseKwargs, report_kwargs, validate_kwargs
 20from mgplot.settings import DataT, get_setting
 21from mgplot.utilities import (
 22    apply_defaults,
 23    constrain_data,
 24    default_rounding,
 25    get_axes,
 26    get_color_list,
 27)
 28
# --- constants
# Caller name handed to report_kwargs() and validate_kwargs() by bar_plot().
ME: Final[str] = "bar_plot"
# annotate_bars() draws no annotations for a series with more rows than this.
MAX_ANNOTATIONS: Final[int] = 30
# Fraction of a series' value range (max - min) by which annotate_bars()
# nudges each annotation away from the point it is anchored to.
ADJUSTMENT_FACTOR: Final[float] = 0.02
# grouped() replaces a requested group width outside [MIN_BAR_WIDTH,
# MAX_BAR_WIDTH] with DEFAULT_GROUPED_WIDTH.
MIN_BAR_WIDTH: Final[float] = 0.0
MAX_BAR_WIDTH: Final[float] = 1.0
DEFAULT_GROUPED_WIDTH: Final[float] = 0.8
# Starting point (negated) for the per-column x offset computed in grouped();
# the code's own comment calls this the "far-left" of the group.
DEFAULT_BAR_OFFSET: Final[float] = 0.5
# Default for the "max_ticks" keyword, which bar_plot() passes to set_labels().
DEFAULT_MAX_TICKS: Final[int] = 10
 38
 39
class BarKwargs(BaseKwargs):
    """Keyword arguments for the bar_plot function.

    Every key declared here is optional (NotRequired). bar_plot() passes this
    class to validate_kwargs() as the schema for its keyword arguments.
    """

    # --- options for the entire bar plot
    # NOTE(review): "ax" and "plot_from" are not read directly in this module;
    # presumably they are consumed by get_axes() / constrain_data() - confirm.
    ax: NotRequired[Axes | None]
    stacked: NotRequired[bool]  # truthy -> stacked(), otherwise grouped(); default False
    max_ticks: NotRequired[int]  # passed to set_labels(); default DEFAULT_MAX_TICKS
    plot_from: NotRequired[int | Period]
    # NOTE(review): verify which rotation key bar_plot() actually consults.
    label_rotation: NotRequired[int | float]
    # --- options for each bar ...
    # Each may be a single value or a sequence; grouped() and stacked() index
    # the resolved values by column position.
    color: NotRequired[str | Sequence[str]]  # default: get_color_list(item_count)
    label_series: NotRequired[bool | Sequence[bool]]  # default: True when more than one column
    width: NotRequired[float | int | Sequence[float | int]]  # default: get_setting("bar_width")
    # --- options for bar annotations
    annotate: NotRequired[bool]  # default False
    fontsize: NotRequired[int | float | str]  # default "small"
    fontname: NotRequired[str]  # default "Helvetica"
    rounding: NotRequired[int]  # default True, i.e. let default_rounding() decide
    rotation: NotRequired[int | float]  # rotation of the annotation text; default 0
    annotate_color: NotRequired[str]  # default "black" when above is truthy, else "white"
    above: NotRequired[bool]  # default False
 61
 62
 63# --- functions
class AnnoKwargs(TypedDict, total=False):
    """TypedDict for the kwargs used in annotate_bars.

    Declared with total=False, so every key is optional. bar_plot() builds
    one of these from its own keyword arguments; grouped() and stacked()
    supply "foreground" (the bar colour) when they call annotate_bars().
    """

    annotate: bool  # annotate_bars() draws nothing unless this is truthy
    fontsize: int | float | str  # passed to Axes.text() as fontsize
    fontname: str  # passed to Axes.text() as fontname
    color: str  # passed to Axes.text() as color (the text colour)
    rotation: int | float  # passed to Axes.text() as rotation
    foreground: str  # used for stroke effect on text
    above: bool  # truthy: anchor at base + value; falsy: anchor at the bar's base
    rounding: bool | int  # if True, uses default rounding; if int, uses that value
 75
 76
 77def annotate_bars(
 78    series: Series,
 79    offset: float,
 80    base: np.ndarray,
 81    axes: Axes,
 82    **anno_kwargs: Unpack[AnnoKwargs],
 83) -> None:
 84    """Bar plot annotations.
 85
 86    Note: "annotate", "fontsize", "fontname", "color", and "rotation" are expected in anno_kwargs.
 87    """
 88    # --- only annotate in limited circumstances
 89    if "annotate" not in anno_kwargs or not anno_kwargs["annotate"]:
 90        return
 91    max_annotations = MAX_ANNOTATIONS
 92    if len(series) > max_annotations:
 93        return
 94
 95    # --- internal logic check
 96    if len(base) != len(series):
 97        print(f"Warning: base array length {len(base)} does not match series length {len(series)}.")
 98        return
 99
100    # --- assemble the annotation parameters
101    above: Final[bool | None] = anno_kwargs.get("above", False)  # None is also False-ish
102    annotate_style: dict[str, Any] = {
103        "fontsize": anno_kwargs.get("fontsize"),
104        "fontname": anno_kwargs.get("fontname"),
105        "color": anno_kwargs.get("color"),
106        "rotation": anno_kwargs.get("rotation"),
107    }
108    rounding = default_rounding(series=series, provided=anno_kwargs.get("rounding"))
109    adjustment = (series.max() - series.min()) * ADJUSTMENT_FACTOR
110    zero_correction = series.index.min()
111
112    # --- annotate each bar
113    for index, value in zip(series.index.astype(int), series, strict=True):
114        position = base[index - zero_correction] + (adjustment if value >= 0 else -adjustment)
115        if above:
116            position += value
117        text = axes.text(
118            x=index + offset,
119            y=position,
120            s=f"{value:.{rounding}f}",
121            ha="center",
122            va="bottom" if value >= 0 else "top",
123            **annotate_style,
124        )
125        if not above and "foreground" in anno_kwargs:
126            # apply a stroke-effect to within bar annotations
127            # to make them more readable with very small bars.
128            text.set_path_effects([pe.withStroke(linewidth=2, foreground=anno_kwargs.get("foreground"))])
129
130
class GroupedKwargs(TypedDict):
    """TypedDict for the kwargs used in grouped.

    All three keys are required. grouped() indexes each sequence by column
    position, so each needs at least one entry per DataFrame column.
    """

    color: Sequence[str]  # bar colour for each column
    width: Sequence[float | int]  # total group width requested for each column
    label_series: Sequence[bool]  # whether each column keeps its name as the bar label
137
138
def grouped(axes: Axes, df: DataFrame, anno_args: AnnoKwargs, **kwargs: Unpack[GroupedKwargs]) -> None:
    """Plot a grouped bar plot.

    Each column of ``df`` is drawn as its own bar at every x position, with
    the bars for one x position placed side by side.

    Args:
        axes: the Axes to draw on.
        df: one column per series. The index is used as the x position and
            must support ``index + float``.
        anno_args: annotation settings forwarded to annotate_bars(). The
            mapping is not modified; the bar colour is added to a per-call
            copy as "foreground".
        **kwargs: "color", "width" and "label_series", each indexed by
            column position (see GroupedKwargs).
    """
    series_count = len(df.columns)

    for i, col in enumerate(df.columns):
        series = df[col]
        if series.isna().all():
            # nothing to draw; the column still keeps its slot in the group
            continue
        width = kwargs["width"][i]
        if width < MIN_BAR_WIDTH or width > MAX_BAR_WIDTH:
            width = DEFAULT_GROUPED_WIDTH
        adjusted_width = width / series_count
        # far-left + margin + halfway through one grouped column
        left = -DEFAULT_BAR_OFFSET + ((1 - width) / 2.0) + (adjusted_width / 2.0)
        offset = left + (i * adjusted_width)
        foreground = kwargs["color"][i]
        axes.bar(
            x=series.index + offset,
            height=series,
            color=foreground,
            width=adjusted_width,
            # an underscore-prefixed label is how the series is hidden when unlabelled
            label=col if kwargs["label_series"][i] else f"_{col}_",
        )
        # build a fresh mapping rather than writing into the caller's dict
        bar_anno: AnnoKwargs = {**anno_args, "foreground": foreground}
        annotate_bars(
            series=series,
            offset=offset,
            base=np.zeros(len(series)),
            axes=axes,
            **bar_anno,
        )
170
171
class StackedKwargs(TypedDict):
    """TypedDict for the kwargs used in stacked.

    All three keys are required. stacked() indexes each sequence by column
    position, so each needs at least one entry per DataFrame column.
    """

    color: Sequence[str]  # bar colour for each column
    width: Sequence[float | int]  # bar width for each column, passed to Axes.bar()
    label_series: Sequence[bool]  # whether each column keeps its name as the bar label
178
179
def stacked(axes: Axes, df: DataFrame, anno_args: AnnoKwargs, **kwargs: Unpack[StackedKwargs]) -> None:
    """Plot a stacked bar plot.

    Columns are drawn in order. Non-negative values are stacked upwards from
    zero and negative values downwards from zero, using two separate running
    totals per row.

    Args:
        axes: the Axes to draw on.
        df: one column per series; the index is used as the x position.
        anno_args: annotation settings forwarded to annotate_bars(). The
            mapping is not modified; the bar colour is added to a per-call
            copy as "foreground".
        **kwargs: "color", "width" and "label_series", each indexed by
            column position (see StackedKwargs).
    """
    row_count = len(df)
    # running tops of the positive stack and bottoms of the negative stack
    base_plus: np.ndarray = np.zeros(shape=row_count, dtype=np.float64)
    base_minus: np.ndarray = np.zeros(shape=row_count, dtype=np.float64)
    for i, col in enumerate(df.columns):
        series = df[col]
        # each bar starts from the stack that matches the sign of its value
        base = np.where(series >= 0, base_plus, base_minus)
        foreground = kwargs["color"][i]
        axes.bar(
            x=series.index,
            height=series,
            bottom=base,
            color=foreground,
            width=kwargs["width"][i],
            # an underscore-prefixed label is how the series is hidden when unlabelled
            label=col if kwargs["label_series"][i] else f"_{col}_",
        )
        # build a fresh mapping rather than writing into the caller's dict
        bar_anno: AnnoKwargs = {**anno_args, "foreground": foreground}
        annotate_bars(
            series=series,
            offset=0,
            base=base,
            axes=axes,
            **bar_anno,
        )
        # np.where builds a new array above, so updating the totals in place
        # here cannot disturb the base already handed to bar() and annotate_bars()
        base_plus += np.where(series >= 0, series, 0)
        base_minus += np.where(series < 0, series, 0)
207
208
def bar_plot(data: DataT, **kwargs: Unpack[BarKwargs]) -> Axes:
    """Create a bar plot from the given data.

    By default each column is drawn as its own bar within a group at every
    x position. With ``stacked=True`` the columns are stacked instead, with
    positive values above zero and negative values below zero.

    Args:
        data: Series | DataFrame - The data to plot. Can be a DataFrame or a Series.
        **kwargs: BarKwargs - Additional keyword arguments for customization.
        (see BarKwargs for details)

    Note: This function does not assume all data is timeseries with a PeriodIndex.

    Returns:
        axes: Axes - The axes for the plot.

    """
    # --- check the kwargs
    report_kwargs(caller=ME, **kwargs)
    validate_kwargs(schema=BarKwargs, caller=ME, **kwargs)

    # --- get the data
    # no call to check_clean_timeseries here, as bar plots are not
    # necessarily timeseries data. If the data is a Series, it will be
    # converted to a DataFrame with a single column.
    df = DataFrame(data)  # really we are only plotting DataFrames
    df, kwargs_d = constrain_data(df, **kwargs)
    item_count = len(df.columns)

    # --- deal with complete PeriodIndex indices
    saved_pi = map_periodindex(df)
    if saved_pi is not None:
        df = saved_pi[0]  # extract the reindexed DataFrame from the PeriodIndex

    # --- set up the default arguments
    chart_defaults: dict[str, bool | int] = {
        "stacked": False,
        "max_ticks": DEFAULT_MAX_TICKS,
    }
    chart_args = {k: kwargs_d.get(k, v) for k, v in chart_defaults.items()}
    # BarKwargs declares "label_rotation"; "xlabel_rotation" is the key this
    # function used to look up, so it is still honoured as a fallback.
    tick_rotation = kwargs_d.get("label_rotation", kwargs_d.get("xlabel_rotation", 0))

    bar_defaults = {
        "color": get_color_list(item_count),
        "width": get_setting("bar_width"),
        "label_series": item_count > 1,
    }
    above = kwargs_d.get("above", False)
    anno_args: AnnoKwargs = {
        "annotate": kwargs_d.get("annotate", False),
        "fontsize": kwargs_d.get("fontsize", "small"),
        "fontname": kwargs_d.get("fontname", "Helvetica"),
        "rotation": kwargs_d.get("rotation", 0),
        "rounding": kwargs_d.get("rounding", True),
        # labels inside the bars default to white, labels above them to black
        "color": kwargs_d.get("annotate_color", "black" if above else "white"),
        "above": above,
    }
    bar_args, remaining_kwargs = apply_defaults(item_count, bar_defaults, kwargs_d)

    # --- plot the data
    axes, remaining_kwargs = get_axes(**dict(remaining_kwargs))
    if chart_args["stacked"]:
        stacked(axes, df, anno_args, **bar_args)
    else:
        grouped(axes, df, anno_args, **bar_args)

    # --- handle complete periodIndex data and label rotation
    if saved_pi is not None:
        set_labels(axes, saved_pi[1], chart_args["max_ticks"])
    # rotate the labels of the axes that were drawn on; plt.xticks() would act
    # on pyplot's current axes, which need not be the one supplied via "ax".
    plt.setp(axes.get_xticklabels(), rotation=tick_rotation)

    return axes
# Caller name handed to report_kwargs() and validate_kwargs() by bar_plot().
ME: Final[str] = 'bar_plot'
# annotate_bars() draws no annotations for a series with more rows than this.
MAX_ANNOTATIONS: Final[int] = 30
# Fraction of a series' value range (max - min) by which annotate_bars()
# nudges each annotation away from the point it is anchored to.
ADJUSTMENT_FACTOR: Final[float] = 0.02
# grouped() replaces a requested group width outside [MIN_BAR_WIDTH,
# MAX_BAR_WIDTH] with DEFAULT_GROUPED_WIDTH.
MIN_BAR_WIDTH: Final[float] = 0.0
MAX_BAR_WIDTH: Final[float] = 1.0
DEFAULT_GROUPED_WIDTH: Final[float] = 0.8
# Starting point (negated) for the per-column x offset computed in grouped();
# the code's own comment calls this the "far-left" of the group.
DEFAULT_BAR_OFFSET: Final[float] = 0.5
# Default for the "max_ticks" keyword, which bar_plot() passes to set_labels().
DEFAULT_MAX_TICKS: Final[int] = 10
class BarKwargs(mgplot.keyword_checking.BaseKwargs):
class BarKwargs(BaseKwargs):
    """Keyword arguments for the bar_plot function.

    Every key declared here is optional (NotRequired). bar_plot() passes this
    class to validate_kwargs() as the schema for its keyword arguments.
    """

    # --- options for the entire bar plot
    # NOTE(review): "ax" and "plot_from" are not read directly in this module;
    # presumably they are consumed by get_axes() / constrain_data() - confirm.
    ax: NotRequired[Axes | None]
    stacked: NotRequired[bool]  # truthy -> stacked(), otherwise grouped(); default False
    max_ticks: NotRequired[int]  # passed to set_labels(); default DEFAULT_MAX_TICKS
    plot_from: NotRequired[int | Period]
    # NOTE(review): verify which rotation key bar_plot() actually consults.
    label_rotation: NotRequired[int | float]
    # --- options for each bar ...
    # Each may be a single value or a sequence; grouped() and stacked() index
    # the resolved values by column position.
    color: NotRequired[str | Sequence[str]]  # default: get_color_list(item_count)
    label_series: NotRequired[bool | Sequence[bool]]  # default: True when more than one column
    width: NotRequired[float | int | Sequence[float | int]]  # default: get_setting("bar_width")
    # --- options for bar annotations
    annotate: NotRequired[bool]  # default False
    fontsize: NotRequired[int | float | str]  # default "small"
    fontname: NotRequired[str]  # default "Helvetica"
    rounding: NotRequired[int]  # default True, i.e. let default_rounding() decide
    rotation: NotRequired[int | float]  # rotation of the annotation text; default 0
    annotate_color: NotRequired[str]  # default "black" when above is truthy, else "white"
    above: NotRequired[bool]  # default False

Keyword arguments for the bar_plot function.

ax: NotRequired[matplotlib.axes._axes.Axes | None]
stacked: NotRequired[bool]
max_ticks: NotRequired[int]
plot_from: NotRequired[int | pandas._libs.tslibs.period.Period]
label_rotation: NotRequired[int | float]
color: NotRequired[str | Sequence[str]]
label_series: NotRequired[bool | Sequence[bool]]
width: NotRequired[float | int | Sequence[float | int]]
annotate: NotRequired[bool]
fontsize: NotRequired[int | float | str]
fontname: NotRequired[str]
rounding: NotRequired[int]
rotation: NotRequired[int | float]
annotate_color: NotRequired[str]
above: NotRequired[bool]
class AnnoKwargs(typing.TypedDict):
class AnnoKwargs(TypedDict, total=False):
    """TypedDict for the kwargs used in annotate_bars.

    Declared with total=False, so every key is optional. bar_plot() builds
    one of these from its own keyword arguments; grouped() and stacked()
    supply "foreground" (the bar colour) when they call annotate_bars().
    """

    annotate: bool  # annotate_bars() draws nothing unless this is truthy
    fontsize: int | float | str  # passed to Axes.text() as fontsize
    fontname: str  # passed to Axes.text() as fontname
    color: str  # passed to Axes.text() as color (the text colour)
    rotation: int | float  # passed to Axes.text() as rotation
    foreground: str  # used for stroke effect on text
    above: bool  # truthy: anchor at base + value; falsy: anchor at the bar's base
    rounding: bool | int  # if True, uses default rounding; if int, uses that value

TypedDict for the kwargs used in annotate_bars.

annotate: bool
fontsize: int | float | str
fontname: str
color: str
rotation: int | float
foreground: str
above: bool
rounding: bool | int
def annotate_bars( series: pandas.core.series.Series, offset: float, base: numpy.ndarray, axes: matplotlib.axes._axes.Axes, **anno_kwargs: Unpack[AnnoKwargs]) -> None:
 78def annotate_bars(
 79    series: Series,
 80    offset: float,
 81    base: np.ndarray,
 82    axes: Axes,
 83    **anno_kwargs: Unpack[AnnoKwargs],
 84) -> None:
 85    """Bar plot annotations.
 86
 87    Note: "annotate", "fontsize", "fontname", "color", and "rotation" are expected in anno_kwargs.
 88    """
 89    # --- only annotate in limited circumstances
 90    if "annotate" not in anno_kwargs or not anno_kwargs["annotate"]:
 91        return
 92    max_annotations = MAX_ANNOTATIONS
 93    if len(series) > max_annotations:
 94        return
 95
 96    # --- internal logic check
 97    if len(base) != len(series):
 98        print(f"Warning: base array length {len(base)} does not match series length {len(series)}.")
 99        return
100
101    # --- assemble the annotation parameters
102    above: Final[bool | None] = anno_kwargs.get("above", False)  # None is also False-ish
103    annotate_style: dict[str, Any] = {
104        "fontsize": anno_kwargs.get("fontsize"),
105        "fontname": anno_kwargs.get("fontname"),
106        "color": anno_kwargs.get("color"),
107        "rotation": anno_kwargs.get("rotation"),
108    }
109    rounding = default_rounding(series=series, provided=anno_kwargs.get("rounding"))
110    adjustment = (series.max() - series.min()) * ADJUSTMENT_FACTOR
111    zero_correction = series.index.min()
112
113    # --- annotate each bar
114    for index, value in zip(series.index.astype(int), series, strict=True):
115        position = base[index - zero_correction] + (adjustment if value >= 0 else -adjustment)
116        if above:
117            position += value
118        text = axes.text(
119            x=index + offset,
120            y=position,
121            s=f"{value:.{rounding}f}",
122            ha="center",
123            va="bottom" if value >= 0 else "top",
124            **annotate_style,
125        )
126        if not above and "foreground" in anno_kwargs:
127            # apply a stroke-effect to within bar annotations
128            # to make them more readable with very small bars.
129            text.set_path_effects([pe.withStroke(linewidth=2, foreground=anno_kwargs.get("foreground"))])

Bar plot annotations.

Note: "annotate", "fontsize", "fontname", "color", and "rotation" are expected in anno_kwargs.

class GroupedKwargs(typing.TypedDict):
class GroupedKwargs(TypedDict):
    """TypedDict for the kwargs used in grouped.

    All three keys are required. grouped() indexes each sequence by column
    position, so each needs at least one entry per DataFrame column.
    """

    color: Sequence[str]  # bar colour for each column
    width: Sequence[float | int]  # total group width requested for each column
    label_series: Sequence[bool]  # whether each column keeps its name as the bar label

TypedDict for the kwargs used in grouped.

color: Sequence[str]
width: Sequence[float | int]
label_series: Sequence[bool]
def grouped( axes: matplotlib.axes._axes.Axes, df: pandas.core.frame.DataFrame, anno_args: AnnoKwargs, **kwargs: Unpack[GroupedKwargs]) -> None:
def grouped(axes: Axes, df: DataFrame, anno_args: AnnoKwargs, **kwargs: Unpack[GroupedKwargs]) -> None:
    """Plot a grouped bar plot.

    Each column of ``df`` is drawn as its own bar at every x position, with
    the bars for one x position placed side by side.

    Args:
        axes: the Axes to draw on.
        df: one column per series. The index is used as the x position and
            must support ``index + float``.
        anno_args: annotation settings forwarded to annotate_bars(). The
            mapping is not modified; the bar colour is added to a per-call
            copy as "foreground".
        **kwargs: "color", "width" and "label_series", each indexed by
            column position (see GroupedKwargs).
    """
    series_count = len(df.columns)

    for i, col in enumerate(df.columns):
        series = df[col]
        if series.isna().all():
            # nothing to draw; the column still keeps its slot in the group
            continue
        width = kwargs["width"][i]
        if width < MIN_BAR_WIDTH or width > MAX_BAR_WIDTH:
            width = DEFAULT_GROUPED_WIDTH
        adjusted_width = width / series_count
        # far-left + margin + halfway through one grouped column
        left = -DEFAULT_BAR_OFFSET + ((1 - width) / 2.0) + (adjusted_width / 2.0)
        offset = left + (i * adjusted_width)
        foreground = kwargs["color"][i]
        axes.bar(
            x=series.index + offset,
            height=series,
            color=foreground,
            width=adjusted_width,
            # an underscore-prefixed label is how the series is hidden when unlabelled
            label=col if kwargs["label_series"][i] else f"_{col}_",
        )
        # build a fresh mapping rather than writing into the caller's dict
        bar_anno: AnnoKwargs = {**anno_args, "foreground": foreground}
        annotate_bars(
            series=series,
            offset=offset,
            base=np.zeros(len(series)),
            axes=axes,
            **bar_anno,
        )

Plot a grouped bar plot.

class StackedKwargs(typing.TypedDict):
class StackedKwargs(TypedDict):
    """TypedDict for the kwargs used in stacked.

    All three keys are required. stacked() indexes each sequence by column
    position, so each needs at least one entry per DataFrame column.
    """

    color: Sequence[str]  # bar colour for each column
    width: Sequence[float | int]  # bar width for each column, passed to Axes.bar()
    label_series: Sequence[bool]  # whether each column keeps its name as the bar label

TypedDict for the kwargs used in stacked.

color: Sequence[str]
width: Sequence[float | int]
label_series: Sequence[bool]
def stacked( axes: matplotlib.axes._axes.Axes, df: pandas.core.frame.DataFrame, anno_args: AnnoKwargs, **kwargs: Unpack[StackedKwargs]) -> None:
def stacked(axes: Axes, df: DataFrame, anno_args: AnnoKwargs, **kwargs: Unpack[StackedKwargs]) -> None:
    """Plot a stacked bar plot.

    Columns are drawn in order. Non-negative values are stacked upwards from
    zero and negative values downwards from zero, using two separate running
    totals per row.

    Args:
        axes: the Axes to draw on.
        df: one column per series; the index is used as the x position.
        anno_args: annotation settings forwarded to annotate_bars(). The
            mapping is not modified; the bar colour is added to a per-call
            copy as "foreground".
        **kwargs: "color", "width" and "label_series", each indexed by
            column position (see StackedKwargs).
    """
    row_count = len(df)
    # running tops of the positive stack and bottoms of the negative stack
    base_plus: np.ndarray = np.zeros(shape=row_count, dtype=np.float64)
    base_minus: np.ndarray = np.zeros(shape=row_count, dtype=np.float64)
    for i, col in enumerate(df.columns):
        series = df[col]
        # each bar starts from the stack that matches the sign of its value
        base = np.where(series >= 0, base_plus, base_minus)
        foreground = kwargs["color"][i]
        axes.bar(
            x=series.index,
            height=series,
            bottom=base,
            color=foreground,
            width=kwargs["width"][i],
            # an underscore-prefixed label is how the series is hidden when unlabelled
            label=col if kwargs["label_series"][i] else f"_{col}_",
        )
        # build a fresh mapping rather than writing into the caller's dict
        bar_anno: AnnoKwargs = {**anno_args, "foreground": foreground}
        annotate_bars(
            series=series,
            offset=0,
            base=base,
            axes=axes,
            **bar_anno,
        )
        # np.where builds a new array above, so updating the totals in place
        # here cannot disturb the base already handed to bar() and annotate_bars()
        base_plus += np.where(series >= 0, series, 0)
        base_minus += np.where(series < 0, series, 0)

Plot a stacked bar plot.

def bar_plot( data: ~DataT, **kwargs: Unpack[BarKwargs]) -> matplotlib.axes._axes.Axes:
def bar_plot(data: DataT, **kwargs: Unpack[BarKwargs]) -> Axes:
    """Create a bar plot from the given data.

    By default each column is drawn as its own bar within a group at every
    x position. With ``stacked=True`` the columns are stacked instead, with
    positive values above zero and negative values below zero.

    Args:
        data: Series | DataFrame - The data to plot. Can be a DataFrame or a Series.
        **kwargs: BarKwargs - Additional keyword arguments for customization.
        (see BarKwargs for details)

    Note: This function does not assume all data is timeseries with a PeriodIndex.

    Returns:
        axes: Axes - The axes for the plot.

    """
    # --- check the kwargs
    report_kwargs(caller=ME, **kwargs)
    validate_kwargs(schema=BarKwargs, caller=ME, **kwargs)

    # --- get the data
    # no call to check_clean_timeseries here, as bar plots are not
    # necessarily timeseries data. If the data is a Series, it will be
    # converted to a DataFrame with a single column.
    df = DataFrame(data)  # really we are only plotting DataFrames
    df, kwargs_d = constrain_data(df, **kwargs)
    item_count = len(df.columns)

    # --- deal with complete PeriodIndex indices
    saved_pi = map_periodindex(df)
    if saved_pi is not None:
        df = saved_pi[0]  # extract the reindexed DataFrame from the PeriodIndex

    # --- set up the default arguments
    chart_defaults: dict[str, bool | int] = {
        "stacked": False,
        "max_ticks": DEFAULT_MAX_TICKS,
    }
    chart_args = {k: kwargs_d.get(k, v) for k, v in chart_defaults.items()}
    # BarKwargs declares "label_rotation"; "xlabel_rotation" is the key this
    # function used to look up, so it is still honoured as a fallback.
    tick_rotation = kwargs_d.get("label_rotation", kwargs_d.get("xlabel_rotation", 0))

    bar_defaults = {
        "color": get_color_list(item_count),
        "width": get_setting("bar_width"),
        "label_series": item_count > 1,
    }
    above = kwargs_d.get("above", False)
    anno_args: AnnoKwargs = {
        "annotate": kwargs_d.get("annotate", False),
        "fontsize": kwargs_d.get("fontsize", "small"),
        "fontname": kwargs_d.get("fontname", "Helvetica"),
        "rotation": kwargs_d.get("rotation", 0),
        "rounding": kwargs_d.get("rounding", True),
        # labels inside the bars default to white, labels above them to black
        "color": kwargs_d.get("annotate_color", "black" if above else "white"),
        "above": above,
    }
    bar_args, remaining_kwargs = apply_defaults(item_count, bar_defaults, kwargs_d)

    # --- plot the data
    axes, remaining_kwargs = get_axes(**dict(remaining_kwargs))
    if chart_args["stacked"]:
        stacked(axes, df, anno_args, **bar_args)
    else:
        grouped(axes, df, anno_args, **bar_args)

    # --- handle complete periodIndex data and label rotation
    if saved_pi is not None:
        set_labels(axes, saved_pi[1], chart_args["max_ticks"])
    # rotate the labels of the axes that were drawn on; plt.xticks() would act
    # on pyplot's current axes, which need not be the one supplied via "ax".
    plt.setp(axes.get_xticklabels(), rotation=tick_rotation)

    return axes

Create a bar plot from the given data.

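By default, each column in the DataFrame is drawn as its own bar within a group at every x position. When stacked=True, the columns are stacked on top of each other instead, with positive values above zero and negative values below zero.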

Args:
    data: Series | DataFrame - The data to plot. Can be a DataFrame or a Series.
    **kwargs: BarKwargs - Additional keyword arguments for customization (see BarKwargs for details).

Note: This function does not assume all data is timeseries with a PeriodIndex.

Returns: axes: Axes - The axes for the plot.