mgplot.bar_plot
Create bar plots using Matplotlib.
Note: bar plots in Matplotlib are not the same as bar charts in other libraries. Bar plots are used to represent categorical data with rectangular bars. As a result, bar plots and line plots typically cannot be plotted on the same axes.
"""Create bar plots using Matplotlib.

Note: bar plots in Matplotlib are not the same as bar charts in other
libraries. Bar plots are used to represent categorical data with
rectangular bars. As a result, bar plots and line plots typically
cannot be plotted on the same axes.
"""

from collections.abc import Sequence
from typing import Any, Final, NotRequired, TypedDict, Unpack

import matplotlib.patheffects as pe
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from pandas import DataFrame, Period, Series

from mgplot.axis_utils import map_periodindex, set_labels
from mgplot.keyword_checking import BaseKwargs, report_kwargs, validate_kwargs
from mgplot.settings import DataT, get_setting
from mgplot.utilities import (
    apply_defaults,
    constrain_data,
    default_rounding,
    get_axes,
    get_color_list,
)

# --- constants
ME: Final[str] = "bar_plot"
MAX_ANNOTATIONS: Final[int] = 30  # series longer than this are never annotated
ADJUSTMENT_FACTOR: Final[float] = 0.02  # label nudge, as a fraction of the series' value range
MIN_BAR_WIDTH: Final[float] = 0.0
MAX_BAR_WIDTH: Final[float] = 1.0
DEFAULT_GROUPED_WIDTH: Final[float] = 0.8  # used when a grouped width is out of range
DEFAULT_BAR_OFFSET: Final[float] = 0.5
DEFAULT_MAX_TICKS: Final[int] = 10


class BarKwargs(BaseKwargs):
    """Keyword arguments for the bar_plot function."""

    # --- options for the entire bar plot
    ax: NotRequired[Axes | None]
    stacked: NotRequired[bool]  # True -> stacked bars, otherwise grouped bars
    max_ticks: NotRequired[int]  # passed to set_labels for PeriodIndex data
    plot_from: NotRequired[int | Period]  # not read here; presumably used by constrain_data
    label_rotation: NotRequired[int | float]  # rotation applied to the x tick labels
    # --- options for each bar ...
    color: NotRequired[str | Sequence[str]]
    label_series: NotRequired[bool | Sequence[bool]]
    width: NotRequired[float | int | Sequence[float | int]]
    # --- options for bar annotations
    annotate: NotRequired[bool]
    fontsize: NotRequired[int | float | str]
    fontname: NotRequired[str]
    rounding: NotRequired[bool | int]  # bar_plot defaults this to True
    rotation: NotRequired[int | float]
    annotate_color: NotRequired[str]
    above: NotRequired[bool]


# --- functions
class AnnoKwargs(TypedDict, total=False):
    """TypedDict for the kwargs used in annotate_bars."""

    annotate: bool  # master switch: nothing is drawn unless this is truthy
    fontsize: int | float | str
    fontname: str
    color: str
    rotation: int | float
    foreground: str  # used for stroke effect on text
    above: bool  # True: label beyond the end of the bar; False: label at the bar's base
    rounding: bool | int  # if True, uses default rounding; if int, uses that value


def annotate_bars(
    series: Series,
    offset: float,
    base: np.ndarray,
    axes: Axes,
    **anno_kwargs: Unpack[AnnoKwargs],
) -> None:
    """Write the value of each bar onto the plot as a text label.

    Args:
        series: Series - the bar heights; its index is cast to int and used
            as the x position of each bar.
        offset: float - added to each index value to get the label's x position.
        base: np.ndarray - the bottom of each bar, aligned position-by-position
            with series (not by index label).
        axes: Axes - the axes to draw the text on.
        **anno_kwargs: AnnoKwargs - annotation options; nothing is drawn
            unless "annotate" is truthy.

    Note: "annotate", "fontsize", "fontname", "color", and "rotation" are expected in anno_kwargs.
    """
    # --- only annotate in limited circumstances
    if not anno_kwargs.get("annotate"):
        return
    if len(series) > MAX_ANNOTATIONS:
        return

    # --- internal logic check
    if len(base) != len(series):
        print(f"Warning: base array length {len(base)} does not match series length {len(series)}.")
        return

    # --- assemble the annotation parameters
    above: Final[bool] = bool(anno_kwargs.get("above", False))  # None is also False-ish
    annotate_style: dict[str, Any] = {
        "fontsize": anno_kwargs.get("fontsize"),
        "fontname": anno_kwargs.get("fontname"),
        "color": anno_kwargs.get("color"),
        "rotation": anno_kwargs.get("rotation"),
    }
    rounding = default_rounding(series=series, provided=anno_kwargs.get("rounding"))
    adjustment = (series.max() - series.min()) * ADJUSTMENT_FACTOR
    has_value = series.notna().to_numpy()  # positional mask of bars that exist

    # --- annotate each bar
    # base is looked up by position, so this works for any integer-like
    # index, not only one that is ascending and contiguous.
    for pos, (index, value) in enumerate(zip(series.index.astype(int), series, strict=True)):
        if not has_value[pos]:
            continue  # no bar was drawn here, so there is nothing to label
        position = base[pos] + (adjustment if value >= 0 else -adjustment)
        if above:
            position += value
        text = axes.text(
            x=index + offset,
            y=position,
            s=f"{value:.{rounding}f}",
            ha="center",
            va="bottom" if value >= 0 else "top",
            **annotate_style,
        )
        if not above and "foreground" in anno_kwargs:
            # apply a stroke-effect to within bar annotations
            # to make them more readable with very small bars.
            text.set_path_effects([pe.withStroke(linewidth=2, foreground=anno_kwargs.get("foreground"))])


class GroupedKwargs(TypedDict):
    """TypedDict for the kwargs used in grouped."""

    color: Sequence[str]  # one colour per column
    width: Sequence[float | int]  # one group width per column
    label_series: Sequence[bool]  # per column: use the column name as the bar label?


def grouped(axes: Axes, df: DataFrame, anno_args: AnnoKwargs, **kwargs: Unpack[GroupedKwargs]) -> None:
    """Plot a grouped bar plot.

    Each column of df becomes one set of bars; the bars that share an
    index value sit side by side. Columns that are entirely NaN are skipped.

    Args:
        axes: Axes - the axes to draw on.
        df: DataFrame - the data, one column per bar series.
        anno_args: AnnoKwargs - annotation options (not modified).
        **kwargs: GroupedKwargs - per-column colour, width and labelling.
    """
    series_count = len(df.columns)

    for i, col in enumerate(df.columns):
        series = df[col]
        if series.isna().all():
            continue
        width = kwargs["width"][i]
        if width < MIN_BAR_WIDTH or width > MAX_BAR_WIDTH:
            width = DEFAULT_GROUPED_WIDTH
        adjusted_width = width / series_count
        # far-left + margin + halfway through one grouped column
        left = -DEFAULT_BAR_OFFSET + ((1 - width) / 2.0) + (adjusted_width / 2.0)
        offset = left + (i * adjusted_width)
        foreground = kwargs["color"][i]
        axes.bar(
            x=series.index + offset,
            height=series,
            color=foreground,
            width=adjusted_width,
            label=col if kwargs["label_series"][i] else f"_{col}_",
        )
        # copy rather than write into the caller's dict
        bar_anno: AnnoKwargs = {**anno_args, "foreground": foreground}
        annotate_bars(
            series=series,
            offset=offset,
            base=np.zeros(len(series)),
            axes=axes,
            **bar_anno,
        )


class StackedKwargs(TypedDict):
    """TypedDict for the kwargs used in stacked."""

    color: Sequence[str]  # one colour per column
    width: Sequence[float | int]  # one bar width per column
    label_series: Sequence[bool]  # per column: use the column name as the bar label?


def stacked(axes: Axes, df: DataFrame, anno_args: AnnoKwargs, **kwargs: Unpack[StackedKwargs]) -> None:
    """Plot a stacked bar plot.

    Columns are stacked in order; non-negative values build upwards from
    zero and negative values build downwards from zero.

    Args:
        axes: Axes - the axes to draw on.
        df: DataFrame - the data, one column per bar series.
        anno_args: AnnoKwargs - annotation options (not modified).
        **kwargs: StackedKwargs - per-column colour, width and labelling.
    """
    row_count = len(df)
    base_plus: np.ndarray = np.zeros(shape=row_count, dtype=np.float64)
    base_minus: np.ndarray = np.zeros(shape=row_count, dtype=np.float64)
    for i, col in enumerate(df.columns):
        series = df[col]
        # each bar starts where the previous bars of the same sign ended
        base = np.where(series >= 0, base_plus, base_minus)
        foreground = kwargs["color"][i]
        axes.bar(
            x=series.index,
            height=series,
            bottom=base,
            color=foreground,
            width=kwargs["width"][i],
            label=col if kwargs["label_series"][i] else f"_{col}_",
        )
        # copy rather than write into the caller's dict
        bar_anno: AnnoKwargs = {**anno_args, "foreground": foreground}
        annotate_bars(
            series=series,
            offset=0,
            base=base,
            axes=axes,
            **bar_anno,
        )
        base_plus += np.where(series >= 0, series, 0)
        base_minus += np.where(series < 0, series, 0)


def bar_plot(data: DataT, **kwargs: Unpack[BarKwargs]) -> Axes:
    """Create a bar plot from the given data.

    By default the columns of the data are drawn side by side as grouped
    bars. With stacked=True, each column is stacked on the previous ones,
    with positive values above zero and negative values below zero.

    Args:
        data: Series | DataFrame - The data to plot. Can be a DataFrame or a Series.
        **kwargs: BarKwargs - Additional keyword arguments for customization.
        (see BarKwargs for details)

    Note: This function does not assume all data is timeseries with a PeriodIndex.

    Returns:
        axes: Axes - The axes for the plot.

    """
    # --- check the kwargs
    report_kwargs(caller=ME, **kwargs)
    validate_kwargs(schema=BarKwargs, caller=ME, **kwargs)

    # --- get the data
    # no call to check_clean_timeseries here, as bar plots are not
    # necessarily timeseries data. If the data is a Series, it will be
    # converted to a DataFrame with a single column.
    df = DataFrame(data)  # really we are only plotting DataFrames
    df, kwargs_d = constrain_data(df, **kwargs)
    item_count = len(df.columns)

    # --- deal with complete PeriodIndex indices
    saved_pi = map_periodindex(df)
    if saved_pi is not None:
        df = saved_pi[0]  # extract the reindexed DataFrame from the PeriodIndex

    # --- set up the default arguments
    use_stacked = kwargs_d.get("stacked", False)
    max_ticks = kwargs_d.get("max_ticks", DEFAULT_MAX_TICKS)
    # BarKwargs declares "label_rotation"; "xlabel_rotation" is the key that
    # was read previously and is still honoured as a fallback.
    label_rotation = kwargs_d.get("label_rotation", kwargs_d.get("xlabel_rotation", 0))

    bar_defaults = {
        "color": get_color_list(item_count),
        "width": get_setting("bar_width"),
        "label_series": item_count > 1,
    }
    above = kwargs_d.get("above", False)
    anno_args: AnnoKwargs = {
        "annotate": kwargs_d.get("annotate", False),
        "fontsize": kwargs_d.get("fontsize", "small"),
        "fontname": kwargs_d.get("fontname", "Helvetica"),
        "rotation": kwargs_d.get("rotation", 0),
        "rounding": kwargs_d.get("rounding", True),
        "color": kwargs_d.get("annotate_color", "black" if above else "white"),
        "above": above,
    }
    bar_args, remaining_kwargs = apply_defaults(item_count, bar_defaults, kwargs_d)

    # --- plot the data
    axes, remaining_kwargs = get_axes(**dict(remaining_kwargs))
    if use_stacked:
        stacked(axes, df, anno_args, **bar_args)
    else:
        grouped(axes, df, anno_args, **bar_args)

    # --- handle complete periodIndex data and label rotation
    if saved_pi is not None:
        set_labels(axes, saved_pi[1], max_ticks)
    # rotate the labels of the axes we drew on; plt.xticks would act on
    # pyplot's current axes, which need not be this one.
    plt.setp(axes.get_xticklabels(), rotation=label_rotation)

    return axes
class BarKwargs(BaseKwargs):
    """Keyword arguments for the bar_plot function."""

    # --- options for the entire bar plot
    ax: NotRequired[Axes | None]
    stacked: NotRequired[bool]  # True -> stacked bars, otherwise grouped bars
    max_ticks: NotRequired[int]  # passed to set_labels for PeriodIndex data
    plot_from: NotRequired[int | Period]  # not read by bar_plot itself; presumably used by constrain_data
    # NOTE(review): bar_plot reads the key "xlabel_rotation" when rotating
    # the tick labels - confirm that the two names are reconciled.
    label_rotation: NotRequired[int | float]
    # --- options for each bar ...
    color: NotRequired[str | Sequence[str]]
    label_series: NotRequired[bool | Sequence[bool]]
    width: NotRequired[float | int | Sequence[float | int]]
    # --- options for bar annotations
    annotate: NotRequired[bool]
    fontsize: NotRequired[int | float | str]
    fontname: NotRequired[str]
    rounding: NotRequired[bool | int]  # bar_plot defaults this to True
    rotation: NotRequired[int | float]
    annotate_color: NotRequired[str]
    above: NotRequired[bool]
Keyword arguments for the bar_plot function.
65class AnnoKwargs(TypedDict, total=False): 66 """TypedDict for the kwargs used in annotate_bars.""" 67 68 annotate: bool 69 fontsize: int | float | str 70 fontname: str 71 color: str 72 rotation: int | float 73 foreground: str # used for stroke effect on text 74 above: bool 75 rounding: bool | int # if True, uses default rounding; if int, uses that value
TypedDict for the kwargs used in annotate_bars.
def annotate_bars(
    series: Series,
    offset: float,
    base: np.ndarray,
    axes: Axes,
    **anno_kwargs: Unpack[AnnoKwargs],
) -> None:
    """Write the value of each bar onto the plot as a text label.

    Args:
        series: Series - the bar heights; its index is cast to int and used
            as the x position of each bar.
        offset: float - added to each index value to get the label's x position.
        base: np.ndarray - the bottom of each bar, aligned position-by-position
            with series (not by index label).
        axes: Axes - the axes to draw the text on.
        **anno_kwargs: AnnoKwargs - annotation options; nothing is drawn
            unless "annotate" is truthy.

    Note: "annotate", "fontsize", "fontname", "color", and "rotation" are expected in anno_kwargs.
    """
    # --- only annotate in limited circumstances
    if not anno_kwargs.get("annotate"):
        return
    if len(series) > MAX_ANNOTATIONS:
        return

    # --- internal logic check
    if len(base) != len(series):
        print(f"Warning: base array length {len(base)} does not match series length {len(series)}.")
        return

    # --- assemble the annotation parameters
    above: Final[bool] = bool(anno_kwargs.get("above", False))  # None is also False-ish
    annotate_style: dict[str, Any] = {
        "fontsize": anno_kwargs.get("fontsize"),
        "fontname": anno_kwargs.get("fontname"),
        "color": anno_kwargs.get("color"),
        "rotation": anno_kwargs.get("rotation"),
    }
    rounding = default_rounding(series=series, provided=anno_kwargs.get("rounding"))
    adjustment = (series.max() - series.min()) * ADJUSTMENT_FACTOR
    has_value = series.notna().to_numpy()  # positional mask of bars that exist

    # --- annotate each bar
    # base is looked up by position, so this works for any integer-like
    # index, not only one that is ascending and contiguous.
    for pos, (index, value) in enumerate(zip(series.index.astype(int), series, strict=True)):
        if not has_value[pos]:
            continue  # no bar was drawn here, so there is nothing to label
        position = base[pos] + (adjustment if value >= 0 else -adjustment)
        if above:
            position += value
        text = axes.text(
            x=index + offset,
            y=position,
            s=f"{value:.{rounding}f}",
            ha="center",
            va="bottom" if value >= 0 else "top",
            **annotate_style,
        )
        if not above and "foreground" in anno_kwargs:
            # apply a stroke-effect to within bar annotations
            # to make them more readable with very small bars.
            text.set_path_effects([pe.withStroke(linewidth=2, foreground=anno_kwargs.get("foreground"))])
Bar plot annotations.
Note: "annotate", "fontsize", "fontname", "color", and "rotation" are expected in anno_kwargs.
132class GroupedKwargs(TypedDict): 133 """TypedDict for the kwargs used in grouped.""" 134 135 color: Sequence[str] 136 width: Sequence[float | int] 137 label_series: Sequence[bool]
TypedDict for the kwargs used in grouped.
def grouped(axes: Axes, df: DataFrame, anno_args: AnnoKwargs, **kwargs: Unpack[GroupedKwargs]) -> None:
    """Plot a grouped bar plot.

    Each column of df becomes one set of bars; the bars that share an
    index value sit side by side. Columns that are entirely NaN are skipped.

    Args:
        axes: Axes - the axes to draw on.
        df: DataFrame - the data, one column per bar series.
        anno_args: AnnoKwargs - annotation options (not modified).
        **kwargs: GroupedKwargs - per-column colour, width and labelling.
    """
    series_count = len(df.columns)

    for i, col in enumerate(df.columns):
        series = df[col]
        if series.isna().all():
            continue
        width = kwargs["width"][i]
        if width < MIN_BAR_WIDTH or width > MAX_BAR_WIDTH:
            width = DEFAULT_GROUPED_WIDTH
        adjusted_width = width / series_count
        # far-left + margin + halfway through one grouped column
        left = -DEFAULT_BAR_OFFSET + ((1 - width) / 2.0) + (adjusted_width / 2.0)
        offset = left + (i * adjusted_width)
        foreground = kwargs["color"][i]
        axes.bar(
            x=series.index + offset,
            height=series,
            color=foreground,
            width=adjusted_width,
            label=col if kwargs["label_series"][i] else f"_{col}_",
        )
        # copy rather than write into the caller's dict
        bar_anno: AnnoKwargs = {**anno_args, "foreground": foreground}
        annotate_bars(
            series=series,
            offset=offset,
            base=np.zeros(len(series)),
            axes=axes,
            **bar_anno,
        )
Plot a grouped bar plot.
173class StackedKwargs(TypedDict): 174 """TypedDict for the kwargs used in stacked.""" 175 176 color: Sequence[str] 177 width: Sequence[float | int] 178 label_series: Sequence[bool]
TypedDict for the kwargs used in stacked.
def stacked(axes: Axes, df: DataFrame, anno_args: AnnoKwargs, **kwargs: Unpack[StackedKwargs]) -> None:
    """Plot a stacked bar plot.

    Columns are stacked in order; non-negative values build upwards from
    zero and negative values build downwards from zero.

    Args:
        axes: Axes - the axes to draw on.
        df: DataFrame - the data, one column per bar series.
        anno_args: AnnoKwargs - annotation options (not modified).
        **kwargs: StackedKwargs - per-column colour, width and labelling.
    """
    row_count = len(df)
    base_plus: np.ndarray = np.zeros(shape=row_count, dtype=np.float64)
    base_minus: np.ndarray = np.zeros(shape=row_count, dtype=np.float64)
    for i, col in enumerate(df.columns):
        series = df[col]
        # each bar starts where the previous bars of the same sign ended
        base = np.where(series >= 0, base_plus, base_minus)
        foreground = kwargs["color"][i]
        axes.bar(
            x=series.index,
            height=series,
            bottom=base,
            color=foreground,
            width=kwargs["width"][i],
            label=col if kwargs["label_series"][i] else f"_{col}_",
        )
        # copy rather than write into the caller's dict
        bar_anno: AnnoKwargs = {**anno_args, "foreground": foreground}
        annotate_bars(
            series=series,
            offset=0,
            base=base,
            axes=axes,
            **bar_anno,
        )
        base_plus += np.where(series >= 0, series, 0)
        base_minus += np.where(series < 0, series, 0)
Plot a stacked bar plot.
def bar_plot(data: DataT, **kwargs: Unpack[BarKwargs]) -> Axes:
    """Create a bar plot from the given data.

    By default the columns of the data are drawn side by side as grouped
    bars. With stacked=True, each column is stacked on the previous ones,
    with positive values above zero and negative values below zero.

    Args:
        data: Series | DataFrame - The data to plot. Can be a DataFrame or a Series.
        **kwargs: BarKwargs - Additional keyword arguments for customization.
        (see BarKwargs for details)

    Note: This function does not assume all data is timeseries with a PeriodIndex.

    Returns:
        axes: Axes - The axes for the plot.

    """
    # --- check the kwargs
    report_kwargs(caller=ME, **kwargs)
    validate_kwargs(schema=BarKwargs, caller=ME, **kwargs)

    # --- get the data
    # no call to check_clean_timeseries here, as bar plots are not
    # necessarily timeseries data. If the data is a Series, it will be
    # converted to a DataFrame with a single column.
    df = DataFrame(data)  # really we are only plotting DataFrames
    df, kwargs_d = constrain_data(df, **kwargs)
    item_count = len(df.columns)

    # --- deal with complete PeriodIndex indices
    saved_pi = map_periodindex(df)
    if saved_pi is not None:
        df = saved_pi[0]  # extract the reindexed DataFrame from the PeriodIndex

    # --- set up the default arguments
    use_stacked = kwargs_d.get("stacked", False)
    max_ticks = kwargs_d.get("max_ticks", DEFAULT_MAX_TICKS)
    # BarKwargs declares "label_rotation"; "xlabel_rotation" is the key that
    # was read previously and is still honoured as a fallback.
    label_rotation = kwargs_d.get("label_rotation", kwargs_d.get("xlabel_rotation", 0))

    bar_defaults = {
        "color": get_color_list(item_count),
        "width": get_setting("bar_width"),
        "label_series": item_count > 1,
    }
    above = kwargs_d.get("above", False)
    anno_args: AnnoKwargs = {
        "annotate": kwargs_d.get("annotate", False),
        "fontsize": kwargs_d.get("fontsize", "small"),
        "fontname": kwargs_d.get("fontname", "Helvetica"),
        "rotation": kwargs_d.get("rotation", 0),
        "rounding": kwargs_d.get("rounding", True),
        "color": kwargs_d.get("annotate_color", "black" if above else "white"),
        "above": above,
    }
    bar_args, remaining_kwargs = apply_defaults(item_count, bar_defaults, kwargs_d)

    # --- plot the data
    axes, remaining_kwargs = get_axes(**dict(remaining_kwargs))
    if use_stacked:
        stacked(axes, df, anno_args, **bar_args)
    else:
        grouped(axes, df, anno_args, **bar_args)

    # --- handle complete periodIndex data and label rotation
    if saved_pi is not None:
        set_labels(axes, saved_pi[1], max_ticks)
    # rotate the labels of the axes we drew on; plt.xticks would act on
    # pyplot's current axes, which need not be this one.
    plt.setp(axes.get_xticklabels(), rotation=label_rotation)

    return axes
Create a bar plot from the given data.
By default, the columns of the DataFrame are drawn side by side as grouped bars. With stacked=True, each column is stacked on top of the previous ones, with positive values above zero and negative values below zero.
Args: data: Series | DataFrame - The data to plot. Can be a DataFrame or a Series. **kwargs: BarKwargs - Additional keyword arguments for customization. (see BarKwargs for details)
Note: This function does not assume all data is timeseries with a PeriodIndex.
Returns: axes: Axes - The axes for the plot.