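Unit tests for `generate_patches`

This notebook exercises `mesa.ecospatial.generate_patches` on a synthetic single-library dataset. It covers AnnData and DataFrame inputs, argument validation (invalid types, keys, library IDs, and scaling factors), and the size of the generated patches.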

[1]:
import unittest
import numpy as np
import pandas as pd
import anndata as ad

import os
os.sys.path.append('../../../')
from mesa.ecospatial import generate_patches
/opt/miniconda3/envs/mesa/lib/python3.11/site-packages/geopandas/_compat.py:106: UserWarning: The Shapely GEOS version (3.8.0-CAPI-1.13.1) is incompatible with the GEOS version PyGEOS was compiled with (3.10.4-CAPI-1.16.2). Conversions between both will be slow.
  warnings.warn(
OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.
/opt/miniconda3/envs/mesa/lib/python3.11/site-packages/spaghetti/network.py:41: FutureWarning: The next major release of pysal/spaghetti (2.0.0) will drop support for all ``libpysal.cg`` geometries. This change is a first step in refactoring ``spaghetti`` that is expected to result in dramatically reduced runtimes for network instantiation and operations. Users currently requiring network and point pattern input as ``libpysal.cg`` geometries should prepare for this simply by converting to ``shapely`` geometries.
  warnings.warn(dep_msg, FutureWarning, stacklevel=1)
[2]:
class TestGeneratePatches(unittest.TestCase):
    """Unit tests for ``generate_patches``.

    Fixture: ``TOTAL_POINTS`` points drawn uniformly from [0, 1000) on each
    axis, all belonging to library ``'sample_1'``. The same points are exposed
    in two forms:

    * ``adata``: an AnnData object with coordinates in ``obsm['spatial']``.
    * ``df``: a plain DataFrame with coordinates in columns ``'x'``/``'y'``.
    """

    TOTAL_POINTS = 2500
    SEED = 42  # fixed so the synthetic fixture is reproducible across runs

    @classmethod
    def setUpClass(cls):
        # A local RandomState seeded with 42 produces exactly the same draws as
        # the legacy global `np.random.seed(42)` + `np.random.*` calls, but
        # does not mutate NumPy's global RNG state for other cells/tests.
        rng = np.random.RandomState(cls.SEED)
        n_points = cls.TOTAL_POINTS
        cls.spatial_data = pd.DataFrame({
            'x': 1000 * rng.rand(n_points),
            'y': 1000 * rng.rand(n_points),
            'library_key': ['sample_1'] * n_points,
            'cluster_key': rng.randint(0, 10, size=n_points),
        })
        cls.spatial_coords = cls.spatial_data[['x', 'y']]

        # AnnData input: obs names as strings, coordinates under obsm['spatial'].
        cls.obs = cls.spatial_data[['library_key']].copy()
        cls.obs.index = cls.spatial_data.index.astype(str)
        cls.adata = ad.AnnData(X=rng.rand(n_points, 50), obs=cls.obs)
        cls.adata.obsm['spatial'] = cls.spatial_coords.to_numpy()

        # DataFrame input: rows of 'sample_1' with a fresh 0..n-1 index.
        cls.df = (
            cls.spatial_data[cls.spatial_data['library_key'] == 'sample_1']
            .reset_index(drop=True)
        )

    def _generate(self, **overrides):
        """Call ``generate_patches`` with valid AnnData defaults.

        Any keyword passed in ``overrides`` replaces the corresponding default,
        so each test only spells out the argument it is exercising.

        Returns:
            Whatever ``generate_patches`` returns for the merged arguments.
        """
        kwargs = dict(
            spatial_data=self.adata,
            library_key='library_key',
            library_id='sample_1',
            scaling_factor=2,
            spatial_key='spatial',
        )
        kwargs.update(overrides)
        return generate_patches(**kwargs)

    def test_generate_patches_with_AnnData(self):
        # scaling_factor=2 is expected to yield a 2x2 grid of patches.
        patches = self._generate()
        self.assertIsInstance(patches, list)
        self.assertEqual(len(patches), 2 * 2)

    def test_generate_patches_with_DataFrame(self):
        # DataFrame input names its coordinate columns via spatial_key; 5x5 grid.
        patches = self._generate(
            spatial_data=self.df,
            scaling_factor=5,
            spatial_key=['x', 'y'],
        )
        self.assertIsInstance(patches, list)
        self.assertEqual(len(patches), 5 * 5)

    def test_invalid_spatial_data_type(self):
        # Neither AnnData nor DataFrame.
        with self.assertRaises(ValueError):
            self._generate(spatial_data='invalid_type')

    def test_invalid_library_key(self):
        # Column that does not exist in adata.obs.
        with self.assertRaises(KeyError):
            self._generate(library_key='invalid_key')

    def test_invalid_library_id(self):
        # Existing column, but no rows carry this library id.
        with self.assertRaises(ValueError):
            self._generate(library_id='invalid_id')

    def test_invalid_spatial_key(self):
        # Key that is not present in adata.obsm.
        with self.assertRaises(KeyError):
            self._generate(spatial_key='invalid_spatial_key')

    def test_zero_scaling_factor(self):
        with self.assertRaises(ValueError):
            self._generate(scaling_factor=0)

    def test_non_numeric_scaling_factor(self):
        with self.assertRaises(TypeError):
            self._generate(scaling_factor='two')

    def test_empty_spatial_data_filtered(self):
        # Build the empty (0-observation) AnnData outside assertRaises so that
        # only generate_patches itself can satisfy the expected ValueError.
        empty_adata = self.adata[self.adata.obs['library_key'] == 'non_existent']
        with self.assertRaises(ValueError):
            self._generate(spatial_data=empty_adata, library_id='non_existent')

    def test_output_patch_coordinates(self):
        # Each patch should span exactly 1/scaling_factor of the data's
        # bounding box along each axis.
        scaling_factor = 2
        patches = self._generate(scaling_factor=scaling_factor)

        # Label-based access: positional Series[int] lookup is deprecated in
        # pandas and will raise KeyError once integer keys are labels-only.
        extent = self.spatial_coords.max() - self.spatial_coords.min()
        expected_patch_width = extent['x'] / scaling_factor
        expected_patch_height = extent['y'] / scaling_factor

        # Guard against the loop below passing vacuously on an empty result.
        self.assertGreater(len(patches), 0)
        for i, (x0, y0, x1, y1) in enumerate(patches):
            with self.subTest(patch=i):
                self.assertAlmostEqual(x1 - x0, expected_patch_width)
                self.assertAlmostEqual(y1 - y0, expected_patch_height)

[3]:
# Run every TestCase defined in this notebook's namespace.
# - `argv` supplies a placeholder program name (argv[0] is ignored), so
#   unittest does not try to parse the kernel's own sys.argv.
# - `exit=False` stops unittest from calling sys.exit() after the run.
_notebook_argv = ['first-arg-is-ignored']
unittest.main(argv=_notebook_argv, exit=False)
......../var/folders/7g/phdhh_ld3dlbnrst0t60bwzr0000gn/T/ipykernel_85830/3363036317.py:141: FutureWarning: Series.__getitem__ treating keys as positions is deprecated. In a future version, integer keys will always be treated as labels (consistent with DataFrame behavior). To access a value by position, use `ser.iloc[pos]`
  expected_patch_width = (max_coords[0] - min_coords[0]) / 2
/var/folders/7g/phdhh_ld3dlbnrst0t60bwzr0000gn/T/ipykernel_85830/3363036317.py:142: FutureWarning: Series.__getitem__ treating keys as positions is deprecated. In a future version, integer keys will always be treated as labels (consistent with DataFrame behavior). To access a value by position, use `ser.iloc[pos]`
  expected_patch_height = (max_coords[1] - min_coords[1]) / 2
..
----------------------------------------------------------------------
Ran 10 tests in 0.055s

OK
[3]:
<unittest.main.TestProgram at 0x158b78950>
[ ]: