Unit tests for calculate_diversity_index

[1]:
import unittest
import numpy as np
import pandas as pd
import anndata as ad
from unittest.mock import patch
import os
os.sys.path.append('../../../')

from mesa import ecospatial as eco
/opt/miniconda3/envs/mesa/lib/python3.11/site-packages/geopandas/_compat.py:106: UserWarning: The Shapely GEOS version (3.8.0-CAPI-1.13.1) is incompatible with the GEOS version PyGEOS was compiled with (3.10.4-CAPI-1.16.2). Conversions between both will be slow.
  warnings.warn(
OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.
/opt/miniconda3/envs/mesa/lib/python3.11/site-packages/spaghetti/network.py:41: FutureWarning: The next major release of pysal/spaghetti (2.0.0) will drop support for all ``libpysal.cg`` geometries. This change is a first step in refactoring ``spaghetti`` that is expected to result in dramatically reduced runtimes for network instantiation and operations. Users currently requiring network and point pattern input as ``libpysal.cg`` geometries should prepare for this simply by converting to ``shapely`` geometries.
  warnings.warn(dep_msg, FutureWarning, stacklevel=1)
[2]:
class TestCalculateDiversityIndex(unittest.TestCase):
    """Tests for ``eco.calculate_diversity_index``.

    A five-cell fixture is built once, both as an AnnData object and as a
    plain DataFrame, together with three rectangular patches. Every test
    queries library ``'sample_1'``; the expected counts used below are:

    * patch 0 -> two 'A' cells and one 'B' cell
    * patch 1 -> no 'sample_1' cells (reported as empty)
    * patch 2 -> two 'A' cells
    """

    @classmethod
    def setUpClass(cls):
        """Create the AnnData, DataFrame and patch fixtures shared by all tests."""
        # Create sample data for testing, shared among all tests
        print("Setting up test data...")
        # For AnnData
        obs = pd.DataFrame({
            'library_key': ['sample_1', 'sample_1', 'sample_2', 'sample_2', 'sample_1'],
            'cluster_key': ['A', 'B', 'A', 'B', 'A']
        }, index=['cell1', 'cell2', 'cell3', 'cell4', 'cell5'])
        obsm = {'spatial_key': np.array([[0, 0], [1, 1], [2, 2], [3, 3], [0.5, 0.5]])}
        cls.adata = ad.AnnData(obs=obs)
        cls.adata.obsm = obsm

        # For DataFrame: same cells, with the coordinates stored as columns
        cls.df = pd.DataFrame({
            'library_key': ['sample_1', 'sample_1', 'sample_2', 'sample_2', 'sample_1'],
            'cluster_key': ['A', 'B', 'A', 'B', 'A'],
            'x_coord': [0, 1, 2, 3, 0.5],
            'y_coord': [0, 1, 2, 3, 0.5]
        }, index=['cell1', 'cell2', 'cell3', 'cell4', 'cell5'])

        # Patches: [(x0, y0, x1, y1), ...]
        cls.patches = [(0, 0, 1, 1), (2, 2, 3, 3), (0, 0, 0.6, 0.6)]

    def _run(self, **overrides):
        """Call ``eco.calculate_diversity_index`` with the default test arguments.

        The defaults target the AnnData fixture, library ``'sample_1'`` and the
        'Shannon Diversity' metric. Any keyword given in ``overrides`` replaces
        (or extends) those defaults, so each test only spells out what is
        specific to it.
        """
        kwargs = {
            'spatial_data': self.adata,
            'library_key': 'library_key',
            'library_id': 'sample_1',
            'spatial_key': 'spatial_key',
            'patches': self.patches,
            'cluster_key': 'cluster_key',
            'metric': 'Shannon Diversity',
        }
        kwargs.update(overrides)
        return eco.calculate_diversity_index(**kwargs)

    def _assert_patch_values(self, result, index_fn):
        """Check patches 0 and 2 of ``result`` against ``index_fn`` on the known counts.

        ``.loc`` is used so the lookup is unambiguously by patch label rather
        than by position.
        """
        self.assertAlmostEqual(result.loc[0], index_fn([2, 1]))
        self.assertAlmostEqual(result.loc[2], index_fn([2]))

    def test_calculate_diversity_index_with_adata(self):
        """AnnData input yields a Series holding only the non-empty patches."""
        result = self._run()

        # Assert that result is a pandas Series
        self.assertIsInstance(result, pd.Series)
        # Expected Shannon entropies for the two non-empty patches
        self._assert_patch_values(result, eco.calculate_shannon_entropy)
        # Patch 1 holds no 'sample_1' cells, so it must be absent from the result
        self.assertNotIn(1, result.index)
        self.assertEqual(len(result), 2)

    def test_calculate_diversity_index_with_dataframe(self):
        """DataFrame input (coordinate column names as spatial_key) matches the AnnData result."""
        result = self._run(
            spatial_data=self.df,
            spatial_key=['x_coord', 'y_coord'],
        )

        self.assertIsInstance(result, pd.Series)
        self._assert_patch_values(result, eco.calculate_shannon_entropy)
        # Same exclusion check as the AnnData test, kept for parity
        self.assertNotIn(1, result.index)
        self.assertEqual(len(result), 2)

    def test_invalid_metric(self):
        """An unrecognised metric name raises ValueError mentioning 'Unknown metric'."""
        with self.assertRaises(ValueError) as context:
            self._run(metric='Invalid Metric')
        self.assertIn("Unknown metric", str(context.exception))

    def test_invalid_spatial_data(self):
        """A spatial_data value that is neither AnnData nor DataFrame raises ValueError."""
        with self.assertRaises(ValueError) as context:
            self._run(spatial_data='invalid_data')
        self.assertIn(
            "spatial_data should be either an AnnData object or a pandas DataFrame",
            str(context.exception),
        )

    def test_missing_cluster_key(self):
        """A cluster_key absent from ``obs`` raises ValueError naming the key."""
        # Work on a copy so the class-level fixture stays intact for other tests
        adata_missing_cluster = self.adata.copy()
        adata_missing_cluster.obs.drop('cluster_key', axis=1, inplace=True)
        with self.assertRaises(ValueError) as context:
            self._run(spatial_data=adata_missing_cluster)
        self.assertIn("cluster_key 'cluster_key' not found", str(context.exception))

    def test_empty_patches(self):
        """A patch containing no cells produces an empty result."""
        empty_patches = [(10, 10, 11, 11)]
        result = self._run(patches=empty_patches)
        # Result should be empty
        self.assertEqual(len(result), 0)

    def test_return_comp_true(self):
        """return_comp=True additionally returns the per-patch composition list."""
        result_series, patches_comp = self._run(return_comp=True)

        self.assertIsInstance(result_series, pd.Series)
        self.assertIsInstance(patches_comp, list)
        # One entry per patch, in patch order; the empty patch is represented by None
        self.assertEqual(patches_comp[0].to_dict(), {'A': 2, 'B': 1})
        self.assertIsNone(patches_comp[1])
        self.assertEqual(patches_comp[2].to_dict(), {'A': 2})

    def test_metric_simpson(self):
        """metric='Simpson' agrees with ``eco.calculate_simpson_index``."""
        result = self._run(metric='Simpson')
        self._assert_patch_values(result, eco.calculate_simpson_index)

    def test_metric_simpson_diversity(self):
        """metric='Simpson Diversity' agrees with ``eco.calculate_simpsonDiversity_index``."""
        result = self._run(metric='Simpson Diversity')
        self._assert_patch_values(result, eco.calculate_simpsonDiversity_index)
[3]:
# Run the tests in the notebook.
# argv is overridden so unittest does not try to interpret the kernel's own
# command-line arguments, and exit=False stops it from calling sys.exit()
# once the run finishes. Binding the returned TestProgram to a name keeps its
# repr from being echoed as the cell's output.
test_program = unittest.main(argv=['first-arg-is-ignored'], exit=False)
......./opt/miniconda3/envs/mesa/lib/python3.11/site-packages/anndata/_core/anndata.py:183: ImplicitModificationWarning: Transforming to str index.
  warnings.warn("Transforming to str index.", ImplicitModificationWarning)
..
----------------------------------------------------------------------
Ran 9 tests in 0.082s

OK
Setting up test data...
33.333 per cent patches are empty
33.333 per cent patches are empty
100.000 per cent patches are empty
33.333 per cent patches are empty
33.333 per cent patches are empty
33.333 per cent patches are empty
[3]:
<unittest.main.TestProgram at 0x151d4a150>