import logging

import autoarray as aa
from autoarray.dataset import abstract_dataset
import numpy as np

logger = logging.getLogger(__name__)


class TestProperties:
    """Tests of the derived noise / signal-to-noise properties of ``AbstractDataset``."""

    @staticmethod
    def _dataset_from(data_values, noise_values):
        """Build an ``AbstractDataset`` from nested lists, using unit pixel scales."""
        data = aa.Array2D.manual_native(data_values, pixel_scales=1.0)
        noise_map = aa.Array2D.manual_native(noise_values, pixel_scales=1.0)
        return abstract_dataset.AbstractDataset(data=data, noise_map=noise_map)

    def test__inverse_noise_is_one_over_noise(self):
        """The inverse noise-map is the element-wise reciprocal of the noise-map."""
        dataset = self._dataset_from(
            [[1.0, 2.0], [3.0, 4.0]], [[1.0, 2.0], [4.0, 8.0]]
        )
        expected = np.array([[1.0, 0.5], [0.25, 0.125]])

        matches = dataset.inverse_noise_map.native == expected
        assert matches.all()

    def test__signal_to_noise_map__image_and_noise_are_values__signal_to_noise_is_ratio_of_each(
        self,
    ):
        """Signal-to-noise is data divided by noise, element by element."""
        dataset = self._dataset_from(
            [[1.0, 2.0], [3.0, 4.0]], [[10.0, 10.0], [30.0, 4.0]]
        )
        expected = np.array([[0.1, 0.2], [0.1, 1.0]])

        matches = dataset.signal_to_noise_map.native == expected
        assert matches.all()
        assert dataset.signal_to_noise_max == 1.0

    def test__signal_to_noise_map__same_as_above__but_image_has_negative_values__replaced_with_zeros(
        self,
    ):
        """Pixels with negative data get a signal-to-noise of zero."""
        dataset = self._dataset_from(
            [[-1.0, 2.0], [3.0, -4.0]], [[10.0, 10.0], [30.0, 4.0]]
        )
        expected = np.array([[0.0, 0.2], [0.1, 0.0]])

        matches = dataset.signal_to_noise_map.native == expected
        assert matches.all()
        assert dataset.signal_to_noise_max == 0.2

    def test__absolute_signal_to_noise_map__image_and_noise_are_values__signal_to_noise_is_absolute_image_value_over_noise(
        self,
    ):
        """Absolute signal-to-noise is |data| divided by noise, element by element."""
        dataset = self._dataset_from(
            [[-1.0, 2.0], [3.0, -4.0]], [[10.0, 10.0], [30.0, 4.0]]
        )
        expected = np.array([[0.1, 0.2], [0.1, 1.0]])

        matches = dataset.absolute_signal_to_noise_map.native == expected
        assert matches.all()
        assert dataset.absolute_signal_to_noise_max == 1.0

    def test__potential_chi_squared_map__image_and_noise_are_values__signal_to_noise_is_absolute_image_value_over_noise(
        self,
    ):
        """The potential chi-squared map is the absolute signal-to-noise squared."""
        dataset = self._dataset_from(
            [[-1.0, 2.0], [3.0, -4.0]], [[10.0, 10.0], [30.0, 4.0]]
        )
        # Written as Python ``**`` expressions so the expected floats are
        # exactly those the test has always compared against.
        expected = np.array([[0.1 ** 2.0, 0.2 ** 2.0], [0.1 ** 2.0, 1.0 ** 2.0]])

        matches = dataset.potential_chi_squared_map.native == expected
        assert matches.all()
        assert dataset.potential_chi_squared_max == 1.0


class TestMethods:
    """Tests of ``AbstractDataset`` methods that return new datasets."""

    def test__new_imaging_with_arrays_trimmed_via_kernel_shape(self):
        """Trimming a 3x3 dataset with a 3x3 kernel shape keeps only its central pixel."""
        image = aa.Array2D.full(fill_value=20.0, shape_native=(3, 3), pixel_scales=1.0)
        image[4] = 5.0

        noise = aa.Array2D.full(fill_value=5.0, shape_native=(3, 3), pixel_scales=1.0)
        noise[4] = 2.0

        # Only the value written at slim index 4 should survive the trim.
        trimmed = abstract_dataset.AbstractDataset(
            data=image, noise_map=noise
        ).trimmed_after_convolution_from(kernel_shape=(3, 3))

        assert (trimmed.data.native == np.array([[5.0]])).all()
        assert (trimmed.noise_map.native == np.array([[2.0]])).all()


class TestAbstractMaskedDatasetTags:
    """Tests of the string tags produced by ``AbstractSettingsMaskedDataset``."""

    def test__grids__sub_size_tags(self):
        """``Grid2D`` grids get a ``sub_<n>`` tag; a ``Grid2DIterate`` grid gets an empty tag.

        The same table is checked for both the main grid and the inversion grid.
        """
        # (grid class, sub_size, expected tag)
        cases = [
            (aa.Grid2DIterate, 1, ""),
            (aa.Grid2D, 1, "sub_1"),
            (aa.Grid2D, 2, "sub_2"),
            (aa.Grid2D, 4, "sub_4"),
        ]

        for grid_cls, size, expected_tag in cases:
            settings = abstract_dataset.AbstractSettingsMaskedDataset(
                grid_class=grid_cls, sub_size=size
            )
            assert settings.grid_sub_size_tag == expected_tag

        for grid_cls, size, expected_tag in cases:
            settings = abstract_dataset.AbstractSettingsMaskedDataset(
                grid_inversion_class=grid_cls, sub_size=size
            )
            assert settings.grid_inversion_sub_size_tag == expected_tag

    def test__grids__fractional_accuracy_tags(self):
        """``Grid2DIterate`` grids get a ``facc_<x>`` tag; a ``Grid2D`` grid gets an empty tag.

        The same table is checked for both the main grid and the inversion grid.
        """
        # (grid class, fractional_accuracy, expected tag)
        cases = [
            (aa.Grid2D, 1, ""),
            (aa.Grid2DIterate, 0.5, "facc_0.5"),
            (aa.Grid2DIterate, 0.71, "facc_0.71"),
            (aa.Grid2DIterate, 0.999999, "facc_0.999999"),
        ]

        for grid_cls, accuracy, expected_tag in cases:
            settings = abstract_dataset.AbstractSettingsMaskedDataset(
                grid_class=grid_cls, fractional_accuracy=accuracy
            )
            assert settings.grid_fractional_accuracy_tag == expected_tag

        for grid_cls, accuracy, expected_tag in cases:
            settings = abstract_dataset.AbstractSettingsMaskedDataset(
                grid_inversion_class=grid_cls, fractional_accuracy=accuracy
            )
            assert settings.grid_inversion_fractional_accuracy_tag == expected_tag

    def test__grid__pixel_scales_interp_tag(self):
        """Interpolation grids with a pixel scale get an ``interp_<x.xxx>`` tag.

        A ``Grid2D`` grid, or an interpolation grid with ``pixel_scales_interp=None``,
        gets an empty tag. The same table is checked for both the main grid and the
        inversion grid.
        """
        # (grid class, pixel_scales_interp, expected tag)
        cases = [
            (aa.Grid2D, 0.1, ""),
            (aa.Grid2DInterpolate, None, ""),
            (aa.Grid2DInterpolate, 0.5, "interp_0.500"),
            (aa.Grid2DInterpolate, 0.25, "interp_0.250"),
            (aa.Grid2DInterpolate, 0.234, "interp_0.234"),
        ]

        for grid_cls, interp, expected_tag in cases:
            settings = abstract_dataset.AbstractSettingsMaskedDataset(
                grid_class=grid_cls, pixel_scales_interp=interp
            )
            assert settings.grid_pixel_scales_interp_tag == expected_tag

        for grid_cls, interp, expected_tag in cases:
            settings = abstract_dataset.AbstractSettingsMaskedDataset(
                grid_inversion_class=grid_cls, pixel_scales_interp=interp
            )
            assert settings.grid_inversion_pixel_scales_interp_tag == expected_tag

    def test__grid_tags(self):
        """The combined grid tags join the main-grid and inversion-grid tags."""
        # (settings kwargs, expected no-inversion tag, expected with-inversion tag)
        cases = [
            (
                dict(
                    grid_class=aa.Grid2D,
                    sub_size=1,
                    grid_inversion_class=aa.Grid2DIterate,
                    fractional_accuracy=0.5,
                ),
                "grid_sub_1",
                "grid_sub_1_inv_facc_0.5",
            ),
            (
                dict(
                    grid_class=aa.Grid2DInterpolate,
                    grid_inversion_class=aa.Grid2DInterpolate,
                    pixel_scales_interp=0.5,
                ),
                "grid_interp_0.500",
                "grid_interp_0.500_inv_interp_0.500",
            ),
            (
                dict(
                    grid_class=aa.Grid2DIterate,
                    fractional_accuracy=0.8,
                    grid_inversion_class=aa.Grid2D,
                    sub_size=2,
                ),
                "grid_facc_0.8",
                "grid_facc_0.8_inv_sub_2",
            ),
        ]

        for kwargs, expected_no_inv, expected_with_inv in cases:
            settings = abstract_dataset.AbstractSettingsMaskedDataset(**kwargs)
            assert settings.grid_tag_no_inversion == expected_no_inv
            assert settings.grid_tag_with_inversion == expected_with_inv

    def test__signal_to_noise_limit_tag(self):
        """No limit gives an empty tag; a limit ``n`` gives ``__snr_<n>``."""
        for limit, expected_tag in [(None, ""), (1, "__snr_1"), (2, "__snr_2")]:
            settings = abstract_dataset.AbstractSettingsMaskedDataset(
                signal_to_noise_limit=limit
            )
            assert settings.signal_to_noise_limit_tag == expected_tag


class TestAbstractMaskedData:
    """Tests of the grids and mask built by ``AbstractMaskedDataset``."""

    def test__grid(
        self,
        imaging_7x7,
        sub_mask_7x7,
        grid_7x7,
        sub_grid_7x7,
        grid_iterate_7x7,
    ):
        """The ``grid_class`` setting selects the type of the masked dataset's ``grid``."""
        masked_imaging_7x7 = abstract_dataset.AbstractMaskedDataset(
            dataset=imaging_7x7,
            mask=sub_mask_7x7,
            settings=abstract_dataset.AbstractSettingsMaskedDataset(
                grid_class=aa.Grid2D
            ),
        )

        assert isinstance(masked_imaging_7x7.grid, aa.Grid2D)
        assert (masked_imaging_7x7.grid.slim_binned == grid_7x7).all()
        assert (masked_imaging_7x7.grid.slim == sub_grid_7x7).all()

        masked_imaging_7x7 = abstract_dataset.AbstractMaskedDataset(
            dataset=imaging_7x7,
            mask=sub_mask_7x7,
            settings=abstract_dataset.AbstractSettingsMaskedDataset(
                grid_class=aa.Grid2DIterate
            ),
        )

        assert isinstance(masked_imaging_7x7.grid, aa.Grid2DIterate)
        assert (masked_imaging_7x7.grid.slim_binned == grid_iterate_7x7).all()

        masked_imaging_7x7 = abstract_dataset.AbstractMaskedDataset(
            dataset=imaging_7x7,
            mask=sub_mask_7x7,
            settings=abstract_dataset.AbstractSettingsMaskedDataset(
                grid_class=aa.Grid2DInterpolate, pixel_scales_interp=1.0
            ),
        )

        # Reference grid built directly from the same mask and interpolation scale.
        grid = aa.Grid2DInterpolate.from_mask(
            mask=sub_mask_7x7, pixel_scales_interp=1.0
        )

        assert isinstance(masked_imaging_7x7.grid, aa.Grid2DInterpolate)
        assert (masked_imaging_7x7.grid == grid).all()
        assert (masked_imaging_7x7.grid.grid_interp == grid.grid_interp).all()
        assert (masked_imaging_7x7.grid.vtx == grid.vtx).all()
        assert (masked_imaging_7x7.grid.wts == grid.wts).all()

    def test__grid_inversion(self, imaging_7x7, sub_mask_7x7, grid_7x7, sub_grid_7x7):
        """The ``grid_inversion_class`` setting selects the type of ``grid_inversion``."""
        masked_imaging_7x7 = abstract_dataset.AbstractMaskedDataset(
            dataset=imaging_7x7,
            mask=sub_mask_7x7,
            settings=abstract_dataset.AbstractSettingsMaskedDataset(
                grid_inversion_class=aa.Grid2D
            ),
        )

        assert isinstance(masked_imaging_7x7.grid_inversion, aa.Grid2D)
        assert (masked_imaging_7x7.grid_inversion.slim_binned == grid_7x7).all()
        assert (masked_imaging_7x7.grid_inversion.slim == sub_grid_7x7).all()

        masked_imaging_7x7 = abstract_dataset.AbstractMaskedDataset(
            dataset=imaging_7x7,
            mask=sub_mask_7x7,
            settings=abstract_dataset.AbstractSettingsMaskedDataset(
                grid_inversion_class=aa.Grid2DInterpolate, pixel_scales_interp=1.0
            ),
        )

        # Reference grid built directly from the same mask and interpolation scale.
        grid = aa.Grid2DInterpolate.from_mask(
            mask=sub_mask_7x7, pixel_scales_interp=1.0
        )

        assert isinstance(masked_imaging_7x7.grid_inversion, aa.Grid2DInterpolate)
        assert (masked_imaging_7x7.grid_inversion == grid).all()
        # Mirror the full set of interpolation checks performed in ``test__grid``.
        assert (
            masked_imaging_7x7.grid_inversion.grid_interp == grid.grid_interp
        ).all()
        assert (masked_imaging_7x7.grid_inversion.vtx == grid.vtx).all()
        assert (masked_imaging_7x7.grid_inversion.wts == grid.wts).all()

    def test__mask_changes_sub_size_using_settings(self, imaging_7x7):
        """The masked dataset's mask takes its ``sub_size`` from the settings.

        The mask's values and pixel scales are kept from the input mask.
        """
        # If an input mask is supplied we use mask input.

        mask_input = aa.Mask2D.circular(
            shape_native=imaging_7x7.shape_native,
            pixel_scales=1,
            sub_size=1,
            radius=1.5,
        )

        masked_dataset = abstract_dataset.AbstractMaskedDataset(
            dataset=imaging_7x7,
            mask=mask_input,
            settings=abstract_dataset.AbstractSettingsMaskedDataset(sub_size=1),
        )

        assert (masked_dataset.mask == mask_input).all()
        assert masked_dataset.mask.sub_size == 1
        assert masked_dataset.mask.pixel_scales == mask_input.pixel_scales

        masked_dataset = abstract_dataset.AbstractMaskedDataset(
            dataset=imaging_7x7,
            mask=mask_input,
            settings=abstract_dataset.AbstractSettingsMaskedDataset(sub_size=2),
        )

        assert (masked_dataset.mask == mask_input).all()
        assert masked_dataset.mask.sub_size == 2
        assert masked_dataset.mask.pixel_scales == mask_input.pixel_scales
