Metadata-Version: 2.4
Name: adhs
Version: 0.1.4
Summary: Adaptive Hierarchical Shrinkage
Author: Markus Loecher, Bastian Pfeifer
Author-email: Arne Gevaert <arne.gevaert@ugent.be>
License-Expression: MIT
Project-URL: Homepage, https://github.com/arnegevaert/adaptive-hierarchical-shrinkage
Classifier: Development Status :: 3 - Alpha
Classifier: Programming Language :: Python
Classifier: Programming Language :: Python :: 3.13
Requires-Python: >=3.13
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: numpy<2.0.0,>=1.23.0
Requires-Dist: scipy>=1.10.0
Requires-Dist: pandas>=2.3.3
Requires-Dist: scikit-learn>=1.2.1
Provides-Extra: dev
Requires-Dist: bumpver; extra == "dev"
Provides-Extra: experiments
Requires-Dist: matplotlib; extra == "experiments"
Requires-Dist: seaborn; extra == "experiments"
Requires-Dist: tqdm; extra == "experiments"
Requires-Dist: ipykernel; extra == "experiments"
Requires-Dist: imodels; extra == "experiments"
Requires-Dist: pmlb; extra == "experiments"
Requires-Dist: shap>=0.42.0; extra == "experiments"
Requires-Dist: adhs; extra == "experiments"
Dynamic: license-file

# Scikit-Learn-compatible implementation of Adaptive Hierarchical Shrinkage
This package contains an implementation of Adaptive Hierarchical Shrinkage that is compatible with scikit-learn. It exports two estimator classes:
- `ShrinkageClassifier`
- `ShrinkageRegressor`

## Installation
### `adhs` Package
The `adhs` package, which contains the implementations of
Adaptive Hierarchical Shrinkage, can be installed from [PyPI](https://pypi.org/project/adhs/):

```bash
$ pip install adhs
```

Alternatively, the package can be installed from source by cloning the repository:
```bash
$ git clone git@github.com:arnegevaert/adaptive-hierarchical-shrinkage.git
$ cd adaptive-hierarchical-shrinkage
$ pip install .
```
Note that this will install the version corresponding to the latest commit on the
master branch, which may or may not be stable.

### Experiments
To be able to run the scripts in the `experiments` directory, some extra
requirements are needed. These can be installed in a new conda
environment as follows:
```bash
$ git clone git@github.com:arnegevaert/adaptive-hierarchical-shrinkage.git
$ cd adaptive-hierarchical-shrinkage
$ conda create -n shrinkage python=3.13
$ conda activate shrinkage
$ pip install ".[experiments]"
```

For more info on reproducing the experiments, see [experiments/README.md](experiments/README.md).

## Basic API
This package exports 2 classes and 1 function:
- `ShrinkageClassifier`
- `ShrinkageRegressor`
- `cross_val_shrinkage`

### `ShrinkageClassifier` and `ShrinkageRegressor`
Both classes inherit from `ShrinkageEstimator`, which extends `sklearn.base.BaseEstimator`. Adaptive hierarchical shrinkage can be summarized as follows:

```math
\hat{f}(\mathbf{x}) = \mathbb{E}_{t_0}[y] + \sum_{l=1}^L\frac{\mathbb{E}_{t_l}[y] - \mathbb{E}_{t_{l-1}}[y]}{1 + \frac{g(t_{l-1})}{N(t_{l-1})}}
```

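where $t_0, \dots, t_L$ is the path from the root $t_0$ to the leaf $t_L$ that contains $\mathbf{x}$, $\mathbb{E}_{t}[y]$ is the average of $y$ over the training samples in node $t$, $N(t)$ is the number of training samples in node $t$, and $g(t_{l-1})$ is some function of the node $t_{l-1}$. Classical hierarchical shrinkage (Agarwal et al. 2022) corresponds to $g(t_{l-1}) = \lambda$, where $\lambda$ is a chosen constant.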
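
The formula can be evaluated by hand for a single root-to-leaf path. The following is a minimal sketch in plain Python, not the package's implementation: the function name `shrunk_prediction` and the node statistics are made up for illustration.

```python
def shrunk_prediction(node_means, node_sizes, g):
    """Evaluate the shrinkage formula along one root-to-leaf path.

    node_means[l] is the mean response in node t_l,
    node_sizes[l] is the number of training samples N(t_l),
    g maps a node index l to the penalty g(t_l).
    """
    prediction = node_means[0]  # start from the root mean
    for l in range(1, len(node_means)):
        parent = l - 1
        denominator = 1.0 + g(parent) / node_sizes[parent]
        prediction += (node_means[l] - node_means[parent]) / denominator
    return prediction


# Hypothetical path: root -> internal node -> leaf
means = [0.5, 0.7, 1.0]
sizes = [100, 40, 10]

# g = 0 recovers the unshrunk leaf mean
unshrunk = shrunk_prediction(means, sizes, lambda l: 0.0)
# Classical hierarchical shrinkage: g is the constant lambda = 10
hs = shrunk_prediction(means, sizes, lambda l: 10.0)
```

With $g = 0$ the sum telescopes to the leaf mean; a larger $g$ pulls the prediction back towards the means of the ancestor nodes, and more strongly so where the parent node holds few samples.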

- `__init__()` parameters:
    - `base_estimator`: the estimator around which we "wrap" hierarchical shrinkage. This should be a tree-based estimator: `DecisionTreeClassifier`, `RandomForestClassifier`, ... (analogous for `Regressor`s)
    - `lmb`: $\lambda$ hyperparameter
    - `random_state`: random state for reproducibility
    - `shrink_mode`: 6 options:
        - `"no_shrinkage"`: dummy value. This setting will not influence the `base_estimator` in any way, and is equivalent to just using the `base_estimator` by itself. Added for easy comparison between different modes of shrinkage and no shrinkage at all.
        - `"hs"`: classical Hierarchical Shrinkage (from Agarwal et al. 2022): $g(t_{l-1}) = \lambda$.
        - `"hs_entropy"`: Adaptive Hierarchical Shrinkage with added entropy term: $g(t_{l-1}) = \lambda H(t_{l-1})$.
        - `"hs_log_cardinality"`: Adaptive Hierarchical Shrinkage with log of cardinality term: $g(t_{l-1}) = \lambda \log C(t_{l-1})$ where $C(t)$ is the number of unique values in $t$.
        - `"hs_global_permutation"`: Same as `"hs_permutation"`, but the data is permuted only once for the full dataset rather than once in each node.
        - `"hs_permutation"`: Adaptive Hierarchical Shrinkage with:
```math
\begin{aligned}
g(t_{l-1}) &= \frac{1}{\alpha(t_{l-1})}\\
\alpha(t_{l-1}) &= 1 - \frac{\Delta_\mathcal{I}(t_{l-1}, { }_\pi x(t_{l-1})) + \epsilon}{\Delta_\mathcal{I}(t_{l-1}, x(t_{l-1}))+ \epsilon}
\end{aligned}
```
- `reshrink(shrink_mode, lmb, X)`: changes the shrinkage mode and/or the $\lambda$ value of an already fitted model. Calling `reshrink` with a given `shrink_mode` and/or `lmb` is equivalent to fitting a new model with the same base estimator and the new values. Because it avoids redundant computations in the shrinkage process, it can be more efficient than fitting a new `ShrinkageClassifier` or `ShrinkageRegressor`.
- Other methods: `fit(X, y)`, `predict(X)`, `predict_proba(X)`, `score(X, y)` work as they do for any other `sklearn` estimator.
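
The penalty functions $g$ behind the shrink modes above can be written out as a rough sketch in plain Python. This is not the package's code, and it rests on assumptions: the function names are hypothetical, `values` is assumed to be the split feature's values among the samples in the node, the natural logarithm is assumed, and the default `eps` is arbitrary.

```python
import math
from collections import Counter


def entropy(values):
    """Shannon entropy (in nats) of the empirical distribution of `values`."""
    n = len(values)
    return -sum((c / n) * math.log(c / n) for c in Counter(values).values())


def g_hs(lmb):
    # "hs": constant penalty
    return lmb


def g_entropy(lmb, values):
    # "hs_entropy": lambda times the entropy H(t)
    return lmb * entropy(values)


def g_log_cardinality(lmb, values):
    # "hs_log_cardinality": lambda times the log of the number of unique values C(t)
    return lmb * math.log(len(set(values)))


def g_permutation(gain, permuted_gain, eps=1e-8):
    # "hs_permutation": g = 1 / alpha, where alpha compares the impurity
    # reduction on permuted data with the one on the original data.
    # This sketch does not handle alpha <= 0.
    alpha = 1.0 - (permuted_gain + eps) / (gain + eps)
    return 1.0 / alpha


node_values = [0, 0, 1, 1]  # hypothetical feature values in one node
penalties = {
    "hs": g_hs(2.0),
    "hs_entropy": g_entropy(2.0, node_values),
    "hs_log_cardinality": g_log_cardinality(2.0, node_values),
    "hs_permutation": g_permutation(0.4, 0.1, eps=0.0),
}
```

In the permutation mode, a split whose impurity reduction largely survives permuting the feature gets a small $\alpha$ and therefore a large penalty.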

### `cross_val_shrinkage`
This function can be used to efficiently run cross-validation for the `shrink_mode` and/or `lmb` hyperparameters. As adaptive hierarchical shrinkage is a fully post-hoc procedure, cross-validation requires no retraining of the base model. This function exploits this property.
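
The idea can be illustrated without the package. In the sketch below, the node statistics of a tree fitted once are reused to score several $\lambda$ values on held-out data. The names `select_lambda` and `shrunk_prediction` and all numbers are hypothetical; this shows the principle only, not how `cross_val_shrinkage` is implemented.

```python
def shrunk_prediction(path, lmb):
    """Classical HS prediction for one path of (node_mean, node_size) pairs."""
    prediction = path[0][0]
    for (parent_mean, parent_size), (child_mean, _) in zip(path, path[1:]):
        prediction += (child_mean - parent_mean) / (1.0 + lmb / parent_size)
    return prediction


def select_lambda(paths, y_true, lambdas):
    """Score every candidate lambda by mean squared error on held-out data.

    The tree statistics in `paths` are computed once and reused for all
    candidates, so no model is refitted inside the loop.
    """
    scores = {}
    for lmb in lambdas:
        predictions = [shrunk_prediction(path, lmb) for path in paths]
        squared_errors = [(p - y) ** 2 for p, y in zip(predictions, y_true)]
        scores[lmb] = sum(squared_errors) / len(squared_errors)
    best = min(scores, key=scores.get)
    return best, scores


# Hypothetical root-to-leaf statistics for two validation samples,
# taken from a tree that was fitted once on the training fold
paths = [
    [(0.5, 100), (0.7, 40), (1.0, 10)],
    [(0.5, 100), (0.3, 60), (0.0, 5)],
]
y_val = [0.9, 0.2]

best_lmb, scores = select_lambda(paths, y_val, [0.0, 1.0, 10.0, 100.0])
```

Only the cheap shrinkage step depends on $\lambda$, so the cost of trying more candidates grows with the number of predictions rather than with the cost of fitting trees.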

## Tutorials

- [General usage](notebooks/tutorial_general_usage.ipynb): Shows how to apply
hierarchical shrinkage on a simple dataset and access feature importances.
- [Cross-validating shrinkage parameters](notebooks/tutorial_shrinkage_cf.ipynb):
Hyperparameters for (adaptive) hierarchical shrinkage (i.e. `shrink_mode` and
`lmb`) can be tuned using cross-validation without retraining the
underlying model, because (adaptive) hierarchical shrinkage is a
**fully post-hoc** procedure. As `ShrinkageClassifier` and
`ShrinkageRegressor` are valid scikit-learn estimators, you could simply tune
these hyperparameters using [`GridSearchCV`](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html) as you would with any other scikit-learn
model. However, this **will** retrain the decision tree or random forest for every
hyperparameter combination, which adds unnecessary computation. This notebook shows
how to use our cross-validation function, `cross_val_shrinkage`, to tune
`shrink_mode` and `lmb` without this overhead.

## Experiments
To reproduce the experiments from the paper, see [experiments/README.md](experiments/README.md).
