mgplot.postcovid_plot

Plot the linear pre-COVID trajectory against the current data.

  1"""Plot the linear pre-COVID trajectory against the current data."""
  2
  3from typing import Literal, NotRequired, Unpack, cast
  4
  5from matplotlib.axes import Axes
  6from numpy import array, polyfit
  7from pandas import DataFrame, Period, PeriodIndex, Series, period_range
  8
  9from mgplot.keyword_checking import (
 10    report_kwargs,
 11    validate_kwargs,
 12)
 13from mgplot.line_plot import LineKwargs, line_plot
 14from mgplot.settings import DataT, get_setting
 15from mgplot.utilities import check_clean_timeseries
 16
# --- constants
ME = "postcovid_plot"  # caller name passed to the kwargs- and data-checking helpers
MIN_REGRESSION_POINTS = 10  # minimum number of points for a useful linear regression

# Default regression periods by frequency
# Keyed by the first character of the PeriodIndex frequency string (Q, M or D).
# The start/end strings are converted to Period objects in regression_period().
DEFAULT_PERIODS = {
    "Q": {"start": "2014Q4", "end": "2019Q4"},
    "M": {"start": "2015-01", "end": "2020-01"},
    "D": {"start": "2015-01-01", "end": "2020-01-01"},
}
 27
 28
class PostcovidKwargs(LineKwargs):
    """Keyword arguments for the post-COVID plot.

    Extends LineKwargs with two optional keys that bound the regression
    window. When a key is omitted, regression_period() falls back to the
    frequency-specific entry in DEFAULT_PERIODS.
    """

    start_r: NotRequired[Period]  # start of regression period
    end_r: NotRequired[Period]  # end of regression period
 34
 35
 36# --- functions
 37def get_projection(source: Series, to_period: Period) -> Series:
 38    """Create a linear projection based on pre-COVID data.
 39
 40    Args:
 41        source: Series - the original series with a PeriodIndex
 42            Assume the index is a PeriodIndex, that is unique and monotonic increasing.
 43            Assume there may be gaps in the source series (either missing or NaNs)
 44            And that it starts from when the regression should start.
 45        to_period: Period - the period to which the projection should extend.
 46
 47    Returns:
 48        Series: A pandas Series with linear projection values using the same index as original.
 49            Returns an empty Series if it fails to create a projection.
 50
 51    Raises:
 52        ValueError: If to_period is not within the original series index range.
 53
 54    """
 55    # --- initial validation
 56    if not isinstance(source.index, PeriodIndex):
 57        raise TypeError("Source index must be a PeriodIndex")
 58    if source.empty or not source.index.is_monotonic_increasing or not source.index.is_unique:
 59        print("Source series must be non-empty, uniquely indexed, and a monotonic increasing index.")
 60        return Series(dtype=float)  # return empty series if validation fails
 61
 62    # --- Drop any missing data and establish the input data for regression
 63    source_no_nan = source.dropna()
 64    input_series = source_no_nan[source_no_nan.index <= to_period]
 65
 66    # --- further validation
 67    if input_series.empty or len(input_series) < MIN_REGRESSION_POINTS:
 68        print("Insufficient data points for regression.")
 69        return Series(dtype=float)  # return empty series if no data for regression
 70
 71    # --- Establish the simple linear regression model
 72    input_index = input_series.index
 73    x_cause = array([p.ordinal for p in input_index if p <= to_period])
 74    y_effect = input_series.to_numpy()
 75    slope, intercept = polyfit(x_cause, y_effect, 1)
 76
 77    # --- use the regression model to create an out-of-sample projection
 78    x_complete = array([p.ordinal for p in source.index])
 79    projection = Series((x_complete * slope) + intercept, index=source.index)
 80
 81    # --- ensure the projection covers any date gaps in the PeriodIndex
 82    source_index = source.index
 83    return projection.reindex(period_range(start=source_index[0], end=source_index[-1])).interpolate(
 84        method="linear"
 85    )
 86
 87
 88def regression_period(data: Series, **kwargs: Unpack[PostcovidKwargs]) -> tuple[Period, Period, bool]:
 89    """Establish the regression period.
 90
 91    Args:
 92        data: Series - the original time series data.
 93        **kwargs: Additional keyword arguments.
 94
 95    Returns:
 96        A tuple containing the start and end periods for regression,
 97        and a boolean indicating if the period is robust.
 98
 99    Raises:
100        TypeError: If the series index is not a PeriodIndex.
101        ValueError: If the series index does not have a D, M, or Q frequency
102
103    """
104    # --- check that the series index is a PeriodIndex with a valid frequency
105    if not isinstance(data.index, PeriodIndex):
106        raise TypeError("The series index must be a PeriodIndex")
107    freq_str = data.index.freqstr
108    freq_key = freq_str[0]
109    if not freq_str or freq_key not in ("Q", "M", "D"):
110        raise ValueError("The series index must have a D, M or Q frequency")
111
112    # --- set the default regression period, use user provided periods if specified
113    default_periods = DEFAULT_PERIODS[freq_key]
114    start_regression = Period(default_periods["start"], freq=freq_str)
115    end_regression = Period(default_periods["end"], freq=freq_str)
116
117    user_start = kwargs.pop("start_r", None)
118    user_end = kwargs.pop("end_r", None)
119    start_r = Period(user_start, freq=freq_str) if user_start else start_regression
120    end_r = Period(user_end, freq=freq_str) if user_end else end_regression
121
122    # --- Validate the regression period
123    robust = True
124    if start_r >= end_r:
125        print(f"Invalid regression period: {start_r=}, {end_r=}")
126        robust = False
127
128    return start_r, end_r, robust
129
130
def postcovid_plot(data: DataT, **kwargs: Unpack[PostcovidKwargs]) -> Axes:
    """Plot a series alongside a straight-line projection of its pre-COVID trend.

    The projection is fitted over the regression window (``start_r`` to
    ``end_r``, or the frequency-specific defaults) and drawn, together with the
    data, from ``start_r`` onwards. When the window is invalid or no projection
    can be produced, a message is printed and only the data itself is plotted.

    Args:
        data: Series - the series to be plotted.
        kwargs: PostcovidKwargs - plotting arguments. ``plot_from`` is ignored
            (with a printed warning).

    Returns:
        Axes: the matplotlib Axes returned by line_plot().

    Raises:
        TypeError if series is not a pandas Series
        TypeError if series does not have a PeriodIndex
        ValueError if series does not have a D, M or Q frequency

    """

    # --- fallback used whenever a projection cannot be drawn
    def failure() -> Axes:
        print("postcovid_plot(): plotting the raw data only.")
        unwanted: list[Literal["plot_from", "start_r", "end_r"]] = ["plot_from", "start_r", "end_r"]
        for name in unwanted:
            kwargs.pop(name, None)
        return line_plot(data, **cast("LineKwargs", kwargs))

    # --- report and validate the keyword arguments
    report_kwargs(caller=ME, **kwargs)
    validate_kwargs(schema=PostcovidKwargs, caller=ME, **kwargs)

    # --- check the data
    data = check_clean_timeseries(data, ME)
    if not isinstance(data, Series):
        raise TypeError("The series argument must be a pandas Series")

    # --- plot_from makes no sense here: the plot always starts at start_r
    if "plot_from" in kwargs:
        print("Warning: the 'plot_from' argument is ignored in postcovid_plot().")
        del kwargs["plot_from"]

    # --- establish the regression window, then strip its keys before plotting
    start_r, end_r, robust = regression_period(data, **kwargs)
    kwargs.pop("start_r", None)
    kwargs.pop("end_r", None)
    if not robust:
        return failure()

    # --- restrict the data to the window start onwards and project the trend
    first_observation = data.dropna().index.min()
    if start_r < first_observation:
        print(f"Caution: Regression start period pre-dates the series index: {start_r=}")
    observed = data[data.index >= start_r].copy()
    observed.name = "Series"
    trend = get_projection(observed, end_r)
    if trend.empty:
        return failure()
    trend.name = "Pre-COVID projection"

    # --- projection column first, data column second;
    #     the paired defaults below are given in that same order
    combined = DataFrame({trend.name: trend, observed.name: observed})

    # --- plot settings: caller-supplied values win, otherwise use these defaults
    kwargs["width"] = kwargs.pop(
        "width",
        (get_setting("line_normal"), get_setting("line_wide")),
    )  # series line is thicker than projection
    kwargs["style"] = kwargs.pop("style", ("--", "-"))  # dashed regression line
    kwargs["label_series"] = kwargs.pop("label_series", True)
    kwargs["annotate"] = kwargs.pop("annotate", (False, True))  # annotate series only
    kwargs["color"] = kwargs.pop("color", ("darkblue", "#dd0000"))
    kwargs["dropna"] = kwargs.pop("dropna", False)  # by default NaN values are not dropped

    return line_plot(combined, **cast("LineKwargs", kwargs))
211
212
if __name__ == "__main__":

    def test_make_projection() -> None:
        """Print a synthetic quarterly series next to its get_projection() output."""
        n_points = 30
        quarters = period_range(start="2015-Q1", periods=n_points, freq="Q")
        # an upward trend with a small repeating wobble
        values = [step + (step % 3) for step in range(n_points)]
        series = Series(values, index=quarters)

        regression_end = Period("2019-Q4", freq="Q")
        proj = get_projection(series, regression_end)

        side_by_side = DataFrame({"Input": series, "Projection": proj})
        print(side_by_side)

    test_make_projection()
ME = 'postcovid_plot'
MIN_REGRESSION_POINTS = 10
DEFAULT_PERIODS = {'Q': {'start': '2014Q4', 'end': '2019Q4'}, 'M': {'start': '2015-01', 'end': '2020-01'}, 'D': {'start': '2015-01-01', 'end': '2020-01-01'}}
class PostcovidKwargs(mgplot.line_plot.LineKwargs):
class PostcovidKwargs(LineKwargs):
    """Keyword arguments for the post-COVID plot.

    Extends LineKwargs with two optional keys that bound the regression
    window. When a key is omitted, regression_period() falls back to the
    frequency-specific entry in DEFAULT_PERIODS.
    """

    start_r: NotRequired[Period]  # start of regression period
    end_r: NotRequired[Period]  # end of regression period

Keyword arguments for the post-COVID plot.

start_r: NotRequired[pandas._libs.tslibs.period.Period]
end_r: NotRequired[pandas._libs.tslibs.period.Period]
def get_projection( source: pandas.core.series.Series, to_period: pandas._libs.tslibs.period.Period) -> pandas.core.series.Series:
def get_projection(source: Series, to_period: Period) -> Series:
    """Fit a straight line to the early part of a series and extend it over the whole index.

    A least-squares line is fitted to the non-missing observations dated on or
    before ``to_period``; that line is then evaluated for every period spanned
    by ``source``.

    Args:
        source: Series - the original series with a PeriodIndex.
            It is expected to begin where the regression should begin, and it
            may contain NaNs or missing periods.
        to_period: Period - the last period (inclusive) used to fit the line.

    Returns:
        Series: the fitted line over the contiguous period range running from
            the first to the last index entry of ``source``.
            An empty Series is returned (after printing a message) when the
            source is empty, has a duplicated or non-monotonic index, or has
            fewer than MIN_REGRESSION_POINTS usable observations.

    Raises:
        TypeError: If the index of ``source`` is not a PeriodIndex.

    """
    # --- initial validation
    index = source.index
    if not isinstance(index, PeriodIndex):
        raise TypeError("Source index must be a PeriodIndex")

    well_formed = not source.empty and index.is_monotonic_increasing and index.is_unique
    if not well_formed:
        print("Source series must be non-empty, uniquely indexed, and a monotonic increasing index.")
        return Series(dtype=float)

    # --- the regression sample: observed values up to and including to_period
    observed = source.dropna()
    fit_sample = observed[observed.index <= to_period]
    if fit_sample.empty or len(fit_sample) < MIN_REGRESSION_POINTS:
        print("Insufficient data points for regression.")
        return Series(dtype=float)

    # --- simple linear regression of value on period ordinal
    # (every period in fit_sample is already <= to_period, so no further filter is needed)
    fit_x = array([period.ordinal for period in fit_sample.index])
    fit_y = fit_sample.to_numpy()
    gradient, offset = polyfit(fit_x, fit_y, 1)

    # --- evaluate the fitted line at every period present in the source
    all_x = array([period.ordinal for period in index])
    fitted = Series((all_x * gradient) + offset, index=index)

    # --- fill any periods missing from the source index so the line is unbroken
    full_range = period_range(start=index[0], end=index[-1])
    return fitted.reindex(full_range).interpolate(method="linear")

Create a linear projection based on pre-COVID data.

Args: source: Series - the original series with a PeriodIndex Assume the index is a PeriodIndex, that is unique and monotonic increasing. Assume there may be gaps in the source series (either missing or NaNs) And that it starts from when the regression should start. to_period: Period - the period to which the projection should extend.

Returns: Series: A pandas Series with linear projection values using the same index as original. Returns an empty Series if it fails to create a projection.

Raises: TypeError: If the source index is not a PeriodIndex. (No exception is raised for an unusable series or too few data points; an empty Series is returned instead.)

def regression_period( data: pandas.core.series.Series, **kwargs: Unpack[PostcovidKwargs]) -> tuple[pandas._libs.tslibs.period.Period, pandas._libs.tslibs.period.Period, bool]:
 89def regression_period(data: Series, **kwargs: Unpack[PostcovidKwargs]) -> tuple[Period, Period, bool]:
 90    """Establish the regression period.
 91
 92    Args:
 93        data: Series - the original time series data.
 94        **kwargs: Additional keyword arguments.
 95
 96    Returns:
 97        A tuple containing the start and end periods for regression,
 98        and a boolean indicating if the period is robust.
 99
100    Raises:
101        TypeError: If the series index is not a PeriodIndex.
102        ValueError: If the series index does not have a D, M, or Q frequency
103
104    """
105    # --- check that the series index is a PeriodIndex with a valid frequency
106    if not isinstance(data.index, PeriodIndex):
107        raise TypeError("The series index must be a PeriodIndex")
108    freq_str = data.index.freqstr
109    freq_key = freq_str[0]
110    if not freq_str or freq_key not in ("Q", "M", "D"):
111        raise ValueError("The series index must have a D, M or Q frequency")
112
113    # --- set the default regression period, use user provided periods if specified
114    default_periods = DEFAULT_PERIODS[freq_key]
115    start_regression = Period(default_periods["start"], freq=freq_str)
116    end_regression = Period(default_periods["end"], freq=freq_str)
117
118    user_start = kwargs.pop("start_r", None)
119    user_end = kwargs.pop("end_r", None)
120    start_r = Period(user_start, freq=freq_str) if user_start else start_regression
121    end_r = Period(user_end, freq=freq_str) if user_end else end_regression
122
123    # --- Validate the regression period
124    robust = True
125    if start_r >= end_r:
126        print(f"Invalid regression period: {start_r=}, {end_r=}")
127        robust = False
128
129    return start_r, end_r, robust

Establish the regression period.

Args: data: Series - the original time series data. **kwargs: Additional keyword arguments.

Returns: A tuple containing the start and end periods for regression, and a boolean indicating if the period is robust.

Raises: TypeError: If the series index is not a PeriodIndex. ValueError: If the series index does not have a D, M, or Q frequency

def postcovid_plot( data: ~DataT, **kwargs: Unpack[PostcovidKwargs]) -> matplotlib.axes._axes.Axes:
def postcovid_plot(data: DataT, **kwargs: Unpack[PostcovidKwargs]) -> Axes:
    """Plot a series alongside a straight-line projection of its pre-COVID trend.

    The projection is fitted over the regression window (``start_r`` to
    ``end_r``, or the frequency-specific defaults) and drawn, together with the
    data, from ``start_r`` onwards. When the window is invalid or no projection
    can be produced, a message is printed and only the data itself is plotted.

    Args:
        data: Series - the series to be plotted.
        kwargs: PostcovidKwargs - plotting arguments. ``plot_from`` is ignored
            (with a printed warning).

    Returns:
        Axes: the matplotlib Axes returned by line_plot().

    Raises:
        TypeError if series is not a pandas Series
        TypeError if series does not have a PeriodIndex
        ValueError if series does not have a D, M or Q frequency

    """

    # --- fallback used whenever a projection cannot be drawn
    def failure() -> Axes:
        print("postcovid_plot(): plotting the raw data only.")
        unwanted: list[Literal["plot_from", "start_r", "end_r"]] = ["plot_from", "start_r", "end_r"]
        for name in unwanted:
            kwargs.pop(name, None)
        return line_plot(data, **cast("LineKwargs", kwargs))

    # --- report and validate the keyword arguments
    report_kwargs(caller=ME, **kwargs)
    validate_kwargs(schema=PostcovidKwargs, caller=ME, **kwargs)

    # --- check the data
    data = check_clean_timeseries(data, ME)
    if not isinstance(data, Series):
        raise TypeError("The series argument must be a pandas Series")

    # --- plot_from makes no sense here: the plot always starts at start_r
    if "plot_from" in kwargs:
        print("Warning: the 'plot_from' argument is ignored in postcovid_plot().")
        del kwargs["plot_from"]

    # --- establish the regression window, then strip its keys before plotting
    start_r, end_r, robust = regression_period(data, **kwargs)
    kwargs.pop("start_r", None)
    kwargs.pop("end_r", None)
    if not robust:
        return failure()

    # --- restrict the data to the window start onwards and project the trend
    first_observation = data.dropna().index.min()
    if start_r < first_observation:
        print(f"Caution: Regression start period pre-dates the series index: {start_r=}")
    observed = data[data.index >= start_r].copy()
    observed.name = "Series"
    trend = get_projection(observed, end_r)
    if trend.empty:
        return failure()
    trend.name = "Pre-COVID projection"

    # --- projection column first, data column second;
    #     the paired defaults below are given in that same order
    combined = DataFrame({trend.name: trend, observed.name: observed})

    # --- plot settings: caller-supplied values win, otherwise use these defaults
    kwargs["width"] = kwargs.pop(
        "width",
        (get_setting("line_normal"), get_setting("line_wide")),
    )  # series line is thicker than projection
    kwargs["style"] = kwargs.pop("style", ("--", "-"))  # dashed regression line
    kwargs["label_series"] = kwargs.pop("label_series", True)
    kwargs["annotate"] = kwargs.pop("annotate", (False, True))  # annotate series only
    kwargs["color"] = kwargs.pop("color", ("darkblue", "#dd0000"))
    kwargs["dropna"] = kwargs.pop("dropna", False)  # by default NaN values are not dropped

    return line_plot(combined, **cast("LineKwargs", kwargs))

Plot a series with a PeriodIndex, including a post-COVID projection.

Args: data: Series - the series to be plotted. kwargs: PostcovidKwargs - plotting arguments.

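Raises: TypeError if series is not a pandas Series TypeError if series does not have a PeriodIndex ValueError if series does not have a D, M or Q frequency. Note: if the regression start is not before the regression end, or a projection cannot be computed, no exception is raised; a message is printed and only the raw data is plotted.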