
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/applications/plot_over_sampling_benchmark_lfw.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_applications_plot_over_sampling_benchmark_lfw.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_applications_plot_over_sampling_benchmark_lfw.py:


==========================================================
Benchmark over-sampling methods in a face recognition task
==========================================================

In this face recognition example, two faces from the LFW
(Labeled Faces in the Wild) dataset are used. Several over-sampling
methods are combined with a 3-nearest-neighbors (3NN) classifier
to examine how much an over-sampler improves the quality of the
classifier's output.

.. GENERATED FROM PYTHON SOURCE LINES 12-17

.. code-block:: Python


    # Authors: Christos Aridas
    #          Guillaume Lemaitre <g.lemaitre58@gmail.com>
    # License: MIT








.. GENERATED FROM PYTHON SOURCE LINES 18-24

.. code-block:: Python

    print(__doc__)

    import seaborn as sns

    sns.set_context("poster")








.. GENERATED FROM PYTHON SOURCE LINES 25-31

Load the dataset
----------------

We will use a dataset containing images of known people and build a model
that recognizes the person in each image. We turn this into a binary problem
by keeping only the pictures of George W. Bush and Bill Clinton.

.. GENERATED FROM PYTHON SOURCE LINES 33-42

.. code-block:: Python

    import numpy as np
    from sklearn.datasets import fetch_lfw_people

    data = fetch_lfw_people()
    george_bush_id = 1871  # Photos of George W. Bush
    bill_clinton_id = 531  # Photos of Bill Clinton
    classes = [george_bush_id, bill_clinton_id]
    classes_name = np.array(["B. Clinton", "G.W. Bush"], dtype=object)








.. GENERATED FROM PYTHON SOURCE LINES 43-48

.. code-block:: Python

    mask_photos = np.isin(data.target, classes)
    X, y = data.data[mask_photos], data.target[mask_photos]
    y = (y == george_bush_id).astype(np.int8)
    y = classes_name[y]
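
The masking and relabelling steps above can be illustrated on a toy target
array. This is only a sketch: the identity codes and the names ``target``,
``majority_id``, ``minority_id`` and ``names`` are hypothetical and are not
taken from the LFW dataset.

.. code-block:: Python

    import numpy as np

    # Hypothetical identity codes standing in for `data.target`.
    target = np.array([7, 3, 9, 3, 3, 7, 5])
    majority_id, minority_id = 3, 7  # assumed ids, for illustration only
    names = np.array(["minority", "majority"], dtype=object)

    # Keep only the samples that belong to one of the two identities.
    mask = np.isin(target, [majority_id, minority_id])
    y_toy = target[mask]

    # The comparison gives True/False, the cast gives 1/0, and indexing
    # `names` with it maps 1 to names[1] and 0 to names[0].
    y_toy = names[(y_toy == majority_id).astype(np.int8)]
    print(y_toy)

Indexing an array of names with a 0/1 array is a compact way to turn a boolean
comparison into readable string labels.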








.. GENERATED FROM PYTHON SOURCE LINES 49-50

We can check the ratio between the two classes.

.. GENERATED FROM PYTHON SOURCE LINES 52-62

.. code-block:: Python

    import matplotlib.pyplot as plt
    import pandas as pd

    class_distribution = pd.Series(y).value_counts(normalize=True)
    ax = class_distribution.plot.barh()
    ax.set_title("Class distribution")
    pos_label = class_distribution.idxmin()
    plt.tight_layout()
    print(f"The positive label considered as the minority class is {pos_label}")




.. image-sg:: /auto_examples/applications/images/sphx_glr_plot_over_sampling_benchmark_lfw_001.png
   :alt: Class distribution
   :srcset: /auto_examples/applications/images/sphx_glr_plot_over_sampling_benchmark_lfw_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    The positive label considered as the minority class is B. Clinton




.. GENERATED FROM PYTHON SOURCE LINES 63-75

We see that we have an imbalanced classification problem with ~95% of the
data belonging to the class G.W. Bush.
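
With such a ratio, accuracy alone says little about the minority class. The
sketch below uses hypothetical counts (95 and 5 samples, chosen only to mimic
the ratio) and a trivial model that always predicts the majority class.

.. code-block:: Python

    import numpy as np

    # Hypothetical counts chosen to mimic a 95% / 5% split.
    y_true = np.array(["G.W. Bush"] * 95 + ["B. Clinton"] * 5, dtype=object)
    # A trivial "classifier" that always answers with the majority class.
    y_pred = np.full(y_true.shape, "G.W. Bush", dtype=object)

    accuracy = np.mean(y_pred == y_true)
    is_minority = y_true == "B. Clinton"
    # Fraction of minority samples that were correctly recognized.
    recall_minority = np.mean(y_pred[is_minority] == "B. Clinton")
    print(f"accuracy={accuracy:.2f}, minority recall={recall_minority:.2f}")

The trivial model reaches 95% accuracy while never recognizing the minority
class, which is why recall, precision and ROC AUC are worth inspecting as well.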

Compare over-sampling approaches
--------------------------------

We will combine different over-sampling approaches with a kNN classifier
to check whether we can tell the two presidents apart. The evaluation is
performed through cross-validation, and we will plot the mean ROC curve
using `skore.evaluate`.

We will create different pipelines and evaluate them.

.. GENERATED FROM PYTHON SOURCE LINES 77-96

.. code-block:: Python

    import skore
    from sklearn.model_selection import StratifiedKFold
    from sklearn.neighbors import KNeighborsClassifier

    from imblearn import FunctionSampler
    from imblearn.over_sampling import ADASYN, SMOTE, RandomOverSampler
    from imblearn.pipeline import make_pipeline

    classifier = KNeighborsClassifier(n_neighbors=3)

    pipelines = {
        "No resampling": make_pipeline(FunctionSampler(), classifier),
        "Random Over-Sampler": make_pipeline(
            RandomOverSampler(random_state=42), classifier
        ),
        "ADASYN": make_pipeline(ADASYN(random_state=42), classifier),
        "SMOTE": make_pipeline(SMOTE(random_state=42), classifier),
    }
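
To give an intuition of what random over-sampling does before the classifier
is fitted, here is a minimal NumPy sketch. It is not the implementation used
by ``RandomOverSampler``; the helper ``random_over_sample`` and the toy data
are hypothetical.

.. code-block:: Python

    import numpy as np


    def random_over_sample(X, y, rng):
        """Duplicate minority rows (with replacement) until classes are balanced."""
        labels, counts = np.unique(y, return_counts=True)
        n_max = counts.max()
        indices = [np.arange(len(y))]  # keep every original sample
        for label, count in zip(labels, counts):
            if count < n_max:
                candidates = np.flatnonzero(y == label)
                extra = rng.choice(candidates, size=n_max - count, replace=True)
                indices.append(extra)
        indices = np.concatenate(indices)
        return X[indices], y[indices]


    rng = np.random.default_rng(0)
    X_toy = np.arange(20, dtype=float).reshape(10, 2)
    y_toy = np.array([0] * 8 + [1] * 2)
    X_res, y_res = random_over_sample(X_toy, y_toy, rng)
    print(np.bincount(y_res))  # both classes now have 8 samples

Because the added rows are exact copies of existing minority samples, this
strategy does not create any new information, in contrast with the
interpolation-based samplers.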








.. GENERATED FROM PYTHON SOURCE LINES 97-100

We use `skore.evaluate` to evaluate each pipeline using a
:class:`~sklearn.model_selection.StratifiedKFold` cross-validation and
compare their performance.

.. GENERATED FROM PYTHON SOURCE LINES 102-108

.. code-block:: Python

    cv = StratifiedKFold(n_splits=3)

    reports = {}
    for name, model in pipelines.items():
        reports[name] = skore.evaluate(model, X, y, splitter=cv, pos_label=pos_label)
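
Stratification matters here because the minority class is small: each test
fold should still contain some minority samples. The sketch below checks this
on hypothetical toy labels (27 majority and 3 minority samples).

.. code-block:: Python

    import numpy as np
    from sklearn.model_selection import StratifiedKFold

    # Hypothetical labels: 27 majority samples and 3 minority samples.
    y_toy = np.array(["majority"] * 27 + ["minority"] * 3)
    X_toy = np.zeros((len(y_toy), 1))  # feature values do not affect the split

    cv_toy = StratifiedKFold(n_splits=3)
    minority_per_fold = [
        int(np.sum(y_toy[test_idx] == "minority"))
        for _, test_idx in cv_toy.split(X_toy, y_toy)
    ]
    print(minority_per_fold)

Each of the three test folds receives one minority sample, so metrics such as
recall remain defined on every fold.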








.. GENERATED FROM PYTHON SOURCE LINES 109-114

.. code-block:: Python

    import pandas as pd

    results = {name: r.metrics.summarize().frame() for name, r in reports.items()}
    pd.concat(results)






.. raw:: html

    <div class="output_subarea output_html rendered_html output_result">
    <div>
    <style scoped>
        .dataframe tbody tr th:only-of-type {
            vertical-align: middle;
        }

        .dataframe tbody tr th {
            vertical-align: top;
        }

        .dataframe thead tr th {
            text-align: left;
        }

        .dataframe thead tr:last-of-type th {
            text-align: right;
        }
    </style>
    <table border="1" class="dataframe">
      <thead>
        <tr>
          <th></th>
          <th></th>
          <th colspan="2" halign="left">KNeighborsClassifier</th>
        </tr>
        <tr>
          <th></th>
          <th></th>
          <th>mean</th>
          <th>std</th>
        </tr>
        <tr>
          <th></th>
          <th>Metric</th>
          <th></th>
          <th></th>
        </tr>
      </thead>
      <tbody>
        <tr>
          <th rowspan="8" valign="top">No resampling</th>
          <th>Accuracy</th>
          <td>0.949926</td>
          <td>0.011066</td>
        </tr>
        <tr>
          <th>Precision</th>
          <td>0.488889</td>
          <td>0.269430</td>
        </tr>
        <tr>
          <th>Recall</th>
          <td>0.203704</td>
          <td>0.170088</td>
        </tr>
        <tr>
          <th>ROC AUC</th>
          <td>0.695430</td>
          <td>0.102237</td>
        </tr>
        <tr>
          <th>Log loss</th>
          <td>30.684276</td>
          <td>0.465169</td>
        </tr>
        <tr>
          <th>Brier score</th>
          <td>0.048489</td>
          <td>0.007454</td>
        </tr>
        <tr>
          <th>Fit time (s)</th>
          <td>0.001147</td>
          <td>0.000171</td>
        </tr>
        <tr>
          <th>Predict time (s)</th>
          <td>0.020389</td>
          <td>0.001263</td>
        </tr>
        <tr>
          <th rowspan="8" valign="top">Random Over-Sampler</th>
          <th>Accuracy</th>
          <td>0.905152</td>
          <td>0.017512</td>
        </tr>
        <tr>
          <th>Precision</th>
          <td>0.271384</td>
          <td>0.083465</td>
        </tr>
        <tr>
          <th>Recall</th>
          <td>0.477778</td>
          <td>0.195316</td>
        </tr>
        <tr>
          <th>ROC AUC</th>
          <td>0.700296</td>
          <td>0.094037</td>
        </tr>
        <tr>
          <th>Log loss</th>
          <td>31.329503</td>
          <td>0.281080</td>
        </tr>
        <tr>
          <th>Brier score</th>
          <td>0.077141</td>
          <td>0.009174</td>
        </tr>
        <tr>
          <th>Fit time (s)</th>
          <td>0.003459</td>
          <td>0.000389</td>
        </tr>
        <tr>
          <th>Predict time (s)</th>
          <td>0.017630</td>
          <td>0.001382</td>
        </tr>
        <tr>
          <th rowspan="8" valign="top">ADASYN</th>
          <th>Accuracy</th>
          <td>0.695839</td>
          <td>0.043894</td>
        </tr>
        <tr>
          <th>Precision</th>
          <td>0.121987</td>
          <td>0.027832</td>
        </tr>
        <tr>
          <th>Recall</th>
          <td>0.785185</td>
          <td>0.198865</td>
        </tr>
        <tr>
          <th>ROC AUC</th>
          <td>0.806093</td>
          <td>0.085753</td>
        </tr>
        <tr>
          <th>Log loss</th>
          <td>20.463838</td>
          <td>1.912165</td>
        </tr>
        <tr>
          <th>Brier score</th>
          <td>0.220684</td>
          <td>0.035568</td>
        </tr>
        <tr>
          <th>Fit time (s)</th>
          <td>0.042784</td>
          <td>0.024095</td>
        </tr>
        <tr>
          <th>Predict time (s)</th>
          <td>0.016614</td>
          <td>0.001087</td>
        </tr>
        <tr>
          <th rowspan="8" valign="top">SMOTE</th>
          <th>Accuracy</th>
          <td>0.713760</td>
          <td>0.033024</td>
        </tr>
        <tr>
          <th>Precision</th>
          <td>0.116084</td>
          <td>0.012309</td>
        </tr>
        <tr>
          <th>Recall</th>
          <td>0.685185</td>
          <td>0.122894</td>
        </tr>
        <tr>
          <th>ROC AUC</th>
          <td>0.800637</td>
          <td>0.080561</td>
        </tr>
        <tr>
          <th>Log loss</th>
          <td>20.476704</td>
          <td>2.230097</td>
        </tr>
        <tr>
          <th>Brier score</th>
          <td>0.212727</td>
          <td>0.027566</td>
        </tr>
        <tr>
          <th>Fit time (s)</th>
          <td>0.015151</td>
          <td>0.005906</td>
        </tr>
        <tr>
          <th>Predict time (s)</th>
          <td>0.016362</td>
          <td>0.001455</td>
        </tr>
      </tbody>
    </table>
    </div>
    </div>
    <br />
    <br />

.. GENERATED FROM PYTHON SOURCE LINES 115-116

We can also plot the ROC curves for each pipeline.

.. GENERATED FROM PYTHON SOURCE LINES 118-123

.. code-block:: Python

    fig, ax = plt.subplots(figsize=(9, 9))
    for name, report in reports.items():
        report.metrics.roc().plot()
    plt.show()




.. rst-class:: sphx-glr-horizontal


    *

      .. image-sg:: /auto_examples/applications/images/sphx_glr_plot_over_sampling_benchmark_lfw_002.png
         :alt: plot over sampling benchmark lfw
         :srcset: /auto_examples/applications/images/sphx_glr_plot_over_sampling_benchmark_lfw_002.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/applications/images/sphx_glr_plot_over_sampling_benchmark_lfw_003.png
         :alt: ROC Curve for KNeighborsClassifier Positive label: B. Clinton Data source: Test set
         :srcset: /auto_examples/applications/images/sphx_glr_plot_over_sampling_benchmark_lfw_003.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/applications/images/sphx_glr_plot_over_sampling_benchmark_lfw_004.png
         :alt: ROC Curve for KNeighborsClassifier Positive label: B. Clinton Data source: Test set
         :srcset: /auto_examples/applications/images/sphx_glr_plot_over_sampling_benchmark_lfw_004.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/applications/images/sphx_glr_plot_over_sampling_benchmark_lfw_005.png
         :alt: ROC Curve for KNeighborsClassifier Positive label: B. Clinton Data source: Test set
         :srcset: /auto_examples/applications/images/sphx_glr_plot_over_sampling_benchmark_lfw_005.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/applications/images/sphx_glr_plot_over_sampling_benchmark_lfw_006.png
         :alt: ROC Curve for KNeighborsClassifier Positive label: B. Clinton Data source: Test set
         :srcset: /auto_examples/applications/images/sphx_glr_plot_over_sampling_benchmark_lfw_006.png
         :class: sphx-glr-multi-img





.. GENERATED FROM PYTHON SOURCE LINES 124-127

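We see that, for this task, the methods that generate new samples by
interpolation (i.e. ADASYN and SMOTE) reach a higher ROC AUC (about 0.80)
and a higher recall than random over-sampling or no resampling (ROC AUC of
about 0.70), at the cost of a lower precision and accuracy.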
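
The interpolation idea can be sketched with NumPy as follows. This is a
simplified illustration of a single SMOTE-style step, not the imbalanced-learn
implementation; it leaves out how ADASYN decides where to generate more
samples. The helper ``interpolate_sample`` and the toy points are hypothetical.

.. code-block:: Python

    import numpy as np


    def interpolate_sample(X_minority, rng, k=2):
        """Create one synthetic point between a minority sample and one of
        its k nearest minority neighbours."""
        i = rng.integers(len(X_minority))
        distances = np.linalg.norm(X_minority - X_minority[i], axis=1)
        # Position 0 is the sample itself (distance 0), assuming no duplicated rows.
        neighbours = np.argsort(distances)[1 : k + 1]
        j = rng.choice(neighbours)
        lam = rng.uniform(0.0, 1.0)  # position along the segment from i to j
        x_new = X_minority[i] + lam * (X_minority[j] - X_minority[i])
        return x_new, i, j


    rng = np.random.default_rng(0)
    X_minority = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [4.0, 4.0]])
    x_new, i, j = interpolate_sample(X_minority, rng)
    print(x_new, "lies between", X_minority[i], "and", X_minority[j])

Unlike a duplicated row, the synthetic point lies on the segment joining two
existing minority samples, so the classifier sees a new location in the
feature space.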


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 23.170 seconds)

**Estimated memory usage:**  794 MB


.. _sphx_glr_download_auto_examples_applications_plot_over_sampling_benchmark_lfw.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_over_sampling_benchmark_lfw.ipynb <plot_over_sampling_benchmark_lfw.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_over_sampling_benchmark_lfw.py <plot_over_sampling_benchmark_lfw.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_over_sampling_benchmark_lfw.zip <plot_over_sampling_benchmark_lfw.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
