Unit Tests for `generate_patches_randomly`

[1]:
# Standard library
import os
import sys
import unittest

# Third-party
import anndata as ad
import numpy as np
import pandas as pd

# Make the repository root importable so the local `mesa` package resolves.
# (Was `os.sys.path.append`, which only works because `os` happens to
# re-export `sys`; an explicit `import sys` is the idiomatic form.)
sys.path.append('../../../')
from mesa.ecospatial import generate_patches_randomly
import mesa.ecospatial._utils as utils
/opt/miniconda3/envs/mesa/lib/python3.11/site-packages/geopandas/_compat.py:106: UserWarning: The Shapely GEOS version (3.8.0-CAPI-1.13.1) is incompatible with the GEOS version PyGEOS was compiled with (3.10.4-CAPI-1.16.2). Conversions between both will be slow.
  warnings.warn(
OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.
/opt/miniconda3/envs/mesa/lib/python3.11/site-packages/spaghetti/network.py:41: FutureWarning: The next major release of pysal/spaghetti (2.0.0) will drop support for all ``libpysal.cg`` geometries. This change is a first step in refactoring ``spaghetti`` that is expected to result in dramatically reduced runtimes for network instantiation and operations. Users currently requiring network and point pattern input as ``libpysal.cg`` geometries should prepare for this simply by converting to ``shapely`` geometries.
  warnings.warn(dep_msg, FutureWarning, stacklevel=1)
[2]:
class TestGeneratePatchesRandomly(unittest.TestCase):
    """Unit tests for ``mesa.ecospatial.generate_patches_randomly``.

    Covers: AnnData and DataFrame inputs, invalid-argument errors,
    patch geometry, the max-overlap constraint, minimum point counts,
    and random-seed reproducibility.
    """

    @classmethod
    def setUpClass(cls):
        """Build shared synthetic data once for all tests.

        Creates 5000 random 2-D points in [0, 1000)^2, split evenly
        between 'sample_1' and 'sample_2', exposed both as a DataFrame
        (``cls.df_sample1`` for the first sample) and as an AnnData
        object (``cls.adata``) with coordinates in ``obsm['spatial']``.
        """
        np.random.seed(42)  # For reproducibility in data generation
        total_points = 5000
        cls.spatial_data = pd.DataFrame({
            'x': 1000 * np.random.rand(total_points),
            'y': 1000 * np.random.rand(total_points),
            'library_key': ['sample_1'] * (total_points // 2) + ['sample_2'] * (total_points // 2),
            'cluster_key': np.random.randint(0, 10, size=total_points)
        })

        # Create AnnData object from the DataFrame.
        # .copy() ensures `obs` is an independent frame, not a view of
        # `spatial_data` — mutating the index of a view would trigger
        # SettingWithCopyWarning and risks aliasing bugs.
        cls.obs = cls.spatial_data[['library_key']].copy()
        # AnnData expects string observation names.
        cls.obs.index = cls.spatial_data.index.astype(str)
        cls.adata = ad.AnnData(X=np.random.rand(len(cls.spatial_data), 50), obs=cls.obs)
        cls.adata.obsm['spatial'] = cls.spatial_data[['x', 'y']].values

        # Subset DataFrame for 'sample_1' (copy so tests can't mutate the parent)
        cls.df_sample1 = cls.spatial_data[cls.spatial_data['library_key'] == 'sample_1'].copy()
        cls.df_sample1.reset_index(drop=True, inplace=True)

    def test_generate_patches_randomly_with_AnnData(self):
        """AnnData input yields scaling_factor**2 patches as a list."""
        patches = generate_patches_randomly(
            spatial_data=self.adata,
            library_key='library_key',
            library_id='sample_1',
            scaling_factor=2,
            spatial_key='spatial',
            max_overlap=0.5,
            random_seed=42
        )
        expected_patches = 4  # scaling_factor=2 -> 2x2 = 4 patches
        self.assertEqual(len(patches), expected_patches)
        self.assertIsInstance(patches, list)

    def test_generate_patches_randomly_with_DataFrame(self):
        """DataFrame input yields scaling_factor**2 patches as a list."""
        patches = generate_patches_randomly(
            spatial_data=self.df_sample1,
            library_key='library_key',
            library_id='sample_1',
            scaling_factor=4,
            spatial_key=['x', 'y'],
            max_overlap=0.5,
            random_seed=42
        )
        expected_patches = 16  # scaling_factor=4 -> 4x4 = 16 patches
        self.assertEqual(len(patches), expected_patches)
        self.assertIsInstance(patches, list)

    def test_invalid_spatial_data_type(self):
        """A non-AnnData, non-DataFrame input raises ValueError."""
        with self.assertRaises(ValueError):
            generate_patches_randomly(
                spatial_data='invalid_type',
                library_key='library_key',
                library_id='sample_1',
                scaling_factor=2,
                spatial_key='spatial',
                max_overlap=0.5,
                random_seed=42
            )

    def test_invalid_library_key(self):
        """An unknown library_key column raises KeyError."""
        with self.assertRaises(KeyError):
            generate_patches_randomly(
                spatial_data=self.adata,
                library_key='invalid_key',
                library_id='sample_1',
                scaling_factor=2,
                spatial_key='spatial',
                max_overlap=0.5,
                random_seed=42
            )

    def test_invalid_library_id(self):
        """A library_id absent from the data raises ValueError."""
        with self.assertRaises(ValueError):
            generate_patches_randomly(
                spatial_data=self.adata,
                library_key='library_key',
                library_id='invalid_id',
                scaling_factor=2,
                spatial_key='spatial',
                max_overlap=0.5,
                random_seed=42
            )

    def test_invalid_spatial_key(self):
        """An unknown obsm spatial_key raises KeyError."""
        with self.assertRaises(KeyError):
            generate_patches_randomly(
                spatial_data=self.adata,
                library_key='library_key',
                library_id='sample_1',
                scaling_factor=2,
                spatial_key='invalid_spatial_key',
                max_overlap=0.5,
                random_seed=42
            )

    def test_zero_scaling_factor(self):
        """scaling_factor=0 is invalid and raises ValueError."""
        with self.assertRaises(ValueError):
            generate_patches_randomly(
                spatial_data=self.adata,
                library_key='library_key',
                library_id='sample_1',
                scaling_factor=0,
                spatial_key='spatial',
                max_overlap=0.5,
                random_seed=42
            )

    def test_non_numeric_scaling_factor(self):
        """A non-numeric scaling_factor raises TypeError."""
        with self.assertRaises(TypeError):
            generate_patches_randomly(
                spatial_data=self.adata,
                library_key='library_key',
                library_id='sample_1',
                scaling_factor='two',
                spatial_key='spatial',
                max_overlap=0.5,
                random_seed=42
            )

    def test_empty_spatial_data_filtered(self):
        """Filtering down to zero observations raises ValueError."""
        with self.assertRaises(ValueError):
            generate_patches_randomly(
                spatial_data=self.adata[self.adata.obs['library_key'] == 'non_existent'],
                library_key='library_key',
                library_id='non_existent',
                scaling_factor=2,
                spatial_key='spatial',
                max_overlap=0.5,
                random_seed=42
            )

    def test_output_patch_sizes(self):
        """Each patch's width/height equals the sample extent / scaling_factor."""
        patches = generate_patches_randomly(
            spatial_data=self.adata,
            library_key='library_key',
            library_id='sample_1',
            scaling_factor=5,
            spatial_key='spatial',
            max_overlap=0.5,
            random_seed=42
        )
        # Extract min and max spatial values for 'sample_1'
        spatial_values = self.adata[self.adata.obs['library_key'] == 'sample_1'].obsm['spatial']
        min_coords = spatial_values.min(axis=0)
        max_coords = spatial_values.max(axis=0)
        expected_patch_width = (max_coords[0] - min_coords[0]) / 5
        expected_patch_height = (max_coords[1] - min_coords[1]) / 5
        # Patches are (x0, y0, x1, y1) bounding boxes.
        for patch in patches:
            x0, y0, x1, y1 = patch
            self.assertAlmostEqual(x1 - x0, expected_patch_width, places=5)
            self.assertAlmostEqual(y1 - y0, expected_patch_height, places=5)

    def test_max_overlap(self):
        """No pair of returned patches exceeds the max_overlap ratio."""
        patches = generate_patches_randomly(
            spatial_data=self.df_sample1,
            library_key='library_key',
            library_id='sample_1',
            scaling_factor=4,
            spatial_key=['x', 'y'],
            max_overlap=0.5,
            random_seed=42
        )
        # Check every pair against the same overlap helper used internally.
        for i, patch1 in enumerate(patches):
            for patch2 in patches[i+1:]:
                overlap_allowed = utils._overlap_check(
                    new_patch=patch1,
                    existing_patches=[patch2],
                    max_overlap_ratio=0.5
                )
                self.assertTrue(overlap_allowed)

    def test_min_points(self):
        """A min_points larger than the sample size yields zero patches."""
        patches = generate_patches_randomly(
            spatial_data=self.df_sample1,
            library_key='library_key',
            library_id='sample_1',
            scaling_factor=4,
            spatial_key=['x', 'y'],
            min_points=5000,  # More than total points in 'sample_1' (2500)
            max_overlap=0.5,
            random_seed=42
        )
        self.assertEqual(len(patches), 0)

    def test_random_seed_consistency(self):
        """Identical random_seed values produce identical patch lists."""
        patches1 = generate_patches_randomly(
            spatial_data=self.df_sample1,
            library_key='library_key',
            library_id='sample_1',
            scaling_factor=4,
            spatial_key=['x', 'y'],
            max_overlap=0.5,
            random_seed=42
        )
        patches2 = generate_patches_randomly(
            spatial_data=self.df_sample1,
            library_key='library_key',
            library_id='sample_1',
            scaling_factor=4,
            spatial_key=['x', 'y'],
            max_overlap=0.5,
            random_seed=42
        )
        self.assertEqual(patches1, patches2)

    def test_contains_min_points(self):
        """Every returned patch contains at least min_points points."""
        min_points = 20
        patches = generate_patches_randomly(
            spatial_data=self.df_sample1,
            library_key='library_key',
            library_id='sample_1',
            scaling_factor=4,
            spatial_key=['x', 'y'],
            max_overlap=0.5,
            random_seed=42,
            min_points=min_points
        )
        spatial_values = self.df_sample1[['x', 'y']].values
        for patch in patches:
            contains = utils._contains_points(
                patch=patch,
                spatial_values=spatial_values,
                min_points=min_points
            )
            self.assertTrue(contains)

[3]:
# Run the tests in the notebook.
# argv[0] is a required placeholder so unittest.main does not try to
# parse the notebook kernel's own command-line arguments; exit=False
# prevents SystemExit from killing the kernel after the run.
unittest.main(argv=['first-arg-is-ignored'], exit=False)
..............
----------------------------------------------------------------------
Ran 14 tests in 5.393s

OK
Warning: Could not generate a new patch within 5 seconds. Returning 0 out of 16 patches
[3]:
<unittest.main.TestProgram at 0x15ba4a150>