import warnings
import inspect
import shutil
import locale
import types
from copy import deepcopy
from pathlib import Path

import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.testing.compare import compare_images

from plotnine import ggplot, theme


TOLERANCE = 2           # Default tolerance for the tests
DPI = 72                # Default DPI for the tests

# This partial theme modifies all themes that are used in
# the tests. It is limited to setting the size of the test
# images. Should a test require a larger or smaller figure
# size, the dpi or aspect_ratio should be modified.
test_theme = theme(figure_size=(640/DPI, 480/DPI), dpi=DPI)

tests_dir = Path(__file__).parent
baseline_images_dir = tests_dir / 'baseline_images'
result_images_dir = tests_dir / 'result_images'

# Fail early, at import time, if the baseline images are not available.
# Use is_dir() rather than exists(): a stray *file* with this name would
# otherwise pass the check and the tests would fail later with a much
# less helpful error.
if not baseline_images_dir.is_dir():
    raise OSError(
        "The baseline image directory does not exist. "
        "This is most likely because the test data is not installed. "
        "You may need to install plotnine from source to get the "
        "test data."
    )


def raise_no_baseline_image(filename):
    """
    Raise an error reporting that a baseline image is missing

    Parameters
    ----------
    filename : str or pathlib.Path
        Path of the missing baseline image.

    Raises
    ------
    FileNotFoundError
        Always. It is a subclass of ``Exception``, so code that caught
        the previous generic exception still catches it.
    """
    raise FileNotFoundError(f"Baseline image {filename} is missing")


def ggplot_equals(gg, name):
    """
    Compare ggplot object to the baseline image identified by `name`

    Parameters
    ----------
    gg : ggplot
        ggplot object
    name : str
        Identifier for the test image. The file name of the calling
        test module decides the image subdirectory.

    Returns
    -------
    out : bool
        True if the rendered plot matches the baseline image within
        ``TOLERANCE``, False otherwise.

    Raises
    ------
    Exception
        Raised by ``raise_no_baseline_image`` when the baseline image
        does not exist. The result image has already been saved by then.

    This function is meant to monkey patch ggplot.__eq__
    so that tests can use the `assert` statement, e.g.
    ``assert p == 'my_plot'``. On return, ``gg._err`` holds the
    result of the image comparison (falsy when the images match);
    ``pytest_assertrepr_compare`` uses it for the failure message.
    """
    # The calling test module determines the image subdirectory.
    # Read it directly from the caller's frame: inspect.stack() would
    # build FrameInfo (including source context) for every frame on the
    # stack just to get one filename.
    frame = inspect.currentframe()
    try:
        test_file = frame.f_back.f_code.co_filename
    finally:
        # Break the frame reference cycle deterministically.
        del frame

    filenames = make_test_image_filenames(name, test_file)
    bbox_inches = 'tight' if 'caption' in gg.labels else None

    # Save the figure before testing whether the original image
    # actually exists. This makes creating new tests much easier,
    # as the result image can afterwards just be copied.
    # Theme a separate name so that `_err` below is always set on the
    # object the caller compared, even if `+=` rebinds to a new object.
    plot = gg
    plot += test_theme
    with _test_cleanup():
        plot.save(filenames.result, verbose=False, bbox_inches=bbox_inches)

    if filenames.baseline.exists():
        shutil.copyfile(filenames.baseline, filenames.expected)
    else:
        # Putting the exception in short function makes for
        # short pytest error messages
        raise_no_baseline_image(filenames.baseline)

    err = compare_images(
        filenames.expected,
        filenames.result,
        TOLERANCE,
        in_decorator=True
    )
    gg._err = err  # For the pytest error message
    return not err


ggplot.__eq__ = ggplot_equals


def draw_test(self):
    """
    Draw the ggplot object inside the test cleanup context

    Parameters
    ----------
    self : ggplot
        ggplot object

    Installed below as ``ggplot.draw_test``. Tests can call
    ``p.draw_test()`` to check that a plot draws without error, while
    ``_test_cleanup`` closes the resulting matplotlib figures.
    """
    cleanup = _test_cleanup()
    with cleanup:
        self.draw()


ggplot.draw_test = draw_test


def build_test(self):
    """
    Build a copy of the ggplot object and return it

    Parameters
    ----------
    self : ggplot
        ggplot object. It is deep-copied, so the caller's object is
        left untouched.

    Returns
    -------
    out : ggplot
        The built copy, whose state tests can inspect.

    Installed below as ``ggplot.build_test``.
    """
    built = deepcopy(self)
    built._build()
    return built


ggplot.build_test = build_test


def pytest_assertrepr_compare(op, left, right):
    """
    pytest hook: explain a failed ``ggplot == 'name'`` image assertion

    Parameters
    ----------
    op : str
        Comparison operator used in the failed assertion.
    left : object
        Left operand; only ggplot objects are handled.
    right : object
        Right operand; only strings (image identifiers) are handled.

    Returns
    -------
    out : list of str or None
        A one-line explanation built from ``left._err`` (the comparison
        result stored by ``ggplot_equals``). Returns None, so pytest falls
        back to its default explanation, when the assertion is not a
        ggplot-vs-name equality or when no comparison result is available.
    """
    if not (op == "==" and
            isinstance(left, ggplot) and
            isinstance(right, str)):
        return None

    err = getattr(left, '_err', None)
    if not err:
        # No stored comparison result; raising here would only replace
        # the assertion message with an internal error.
        return None

    # No ':s' on the paths so that non-str path values format too.
    msg = ("images not close: {actual} vs. {expected} "
           "(RMS {rms:.2f})".format(**err))
    return [msg]


def make_test_image_filenames(name, test_file):
    """
    Create filenames for testing

    Parameters
    ----------
    name : str
        An identifier for the specific test. It makes up part of the
        filenames; any existing suffix is replaced by ``.png``.
    test_file : str
        Full path of the test file. Its stem is used as the
        subdirectory for the images.

    Returns
    -------
    out : types.SimpleNamespace
        Object with 3 attributes holding the generated filenames

            - result
            - baseline
            - expected

        `result` is the filename of the image generated by the test.
        `baseline` is the filename of the baseline image that the
        result is compared to.
        `expected` is the filename of the copy of the baseline that is
        stored in the same directory as the result image. Having a copy
        there makes comparison easier.

    The directory that will hold the result image is created if it
    does not exist.
    """
    png = Path(name).with_suffix('.png')
    group = Path(test_file).stem
    result_dir = result_images_dir / group
    out = types.SimpleNamespace(
        baseline=baseline_images_dir / group / png,
        result=result_dir / png,
        expected=result_dir / f'{png.stem}-expected{png.suffix}',
    )
    out.result.parent.mkdir(parents=True, exist_ok=True)
    return out


class _test_cleanup:
    """
    Context manager that puts the process into the state the baseline
    images were generated in

    On entry it tries to select an English/United States locale, closes
    all open matplotlib figures, switches to the Agg backend and resets
    rcParams to matplotlib's defaults plus the text settings the image
    comparisons depend on. On exit it closes all figures and resets the
    warning filters.
    """

    def __enter__(self):
        self._use_baseline_locale()
        self._close_leftover_figures()
        self._reset_rcparams()
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        plt.close('all')
        # NOTE(review): resetwarnings() discards *all* warning filters,
        # including ones configured for the running test (e.g. by
        # pytest) - confirm this is intended.
        warnings.resetwarnings()

    @staticmethod
    def _use_baseline_locale():
        # The baseline images are created in this locale, so it should
        # be used during all of the tests. Try the POSIX name first,
        # then the Windows name; warn if neither is available.
        for candidate in ('en_US.UTF-8', 'English_United States.1252'):
            try:
                locale.setlocale(locale.LC_ALL, candidate)
            except locale.Error:
                continue
            return
        warnings.warn(
            "Could not set locale to English/United States. "
            "Some date-related tests may fail"
        )

    @staticmethod
    def _close_leftover_figures():
        # Make sure we don't carry over bad plots from former tests
        plt.close('all')
        n_figs = len(plt.get_fignums())
        assert n_figs == 0, (
            f"No. of open figs: {n_figs}. Make sure the "
            "figures from the previous tests are cleaned up."
        )

    @staticmethod
    def _reset_rcparams():
        mpl.use('Agg')
        # These settings *must* be hardcoded for running the comparison
        # tests: start from all defaults, then pin the text rendering.
        mpl.rcdefaults()
        pinned = (
            ('text.hinting', 'auto'),
            ('text.antialiased', True),
            ('text.hinting_factor', 8),
        )
        for key, value in pinned:
            mpl.rcParams[key] = value


def layer_data(p, i=0):
    """
    Return the data of one layer, as computed when the plot is built

    Parameters
    ----------
    p : ggplot
        ggplot object. It is deep-copied, so the caller's object is
        not built or modified.
    i : int
        Index of the layer whose data to return.

    Returns
    -------
    out : dataframe
        The built data of layer `i`.
    """
    built = deepcopy(p)
    built._build()
    layers = built.layers
    return layers.data[i]
