Unit tests for `generate_patches`

[1]:
import os
import sys
import unittest

import anndata as ad
import numpy as np
import pandas as pd

# Make the in-repo ``mesa`` package importable. The path is relative to the
# notebook's working directory (two levels up). Use ``sys`` directly rather than
# ``os.sys`` (an implementation detail of ``os``), and guard the append so that
# re-running this cell does not keep adding duplicate entries to ``sys.path``.
REPO_ROOT = os.path.abspath(os.path.join('..', '..'))
if REPO_ROOT not in sys.path:
    sys.path.append(REPO_ROOT)

from mesa.ecospatial import generate_patches
/opt/miniconda3/envs/mesa/lib/python3.11/site-packages/geopandas/_compat.py:106: UserWarning: The Shapely GEOS version (3.8.0-CAPI-1.13.1) is incompatible with the GEOS version PyGEOS was compiled with (3.10.4-CAPI-1.16.2). Conversions between both will be slow.
  warnings.warn(
OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.
/opt/miniconda3/envs/mesa/lib/python3.11/site-packages/spaghetti/network.py:41: FutureWarning: The next major release of pysal/spaghetti (2.0.0) will drop support for all ``libpysal.cg`` geometries. This change is a first step in refactoring ``spaghetti`` that is expected to result in dramatically reduced runtimes for network instantiation and operations. Users currently requiring network and point pattern input as ``libpysal.cg`` geometries should prepare for this simply by converting to ``shapely`` geometries.
  warnings.warn(dep_msg, FutureWarning, stacklevel=1)
[2]:
class TestGeneratePatches(unittest.TestCase):
    """Unit tests for ``generate_patches``.

    ``setUp`` builds a 5x5 grid of points spaced 20 units apart (0, 20, ..., 80
    on each axis), all labelled with library ``'lib1'``. The grid is exposed both
    as an AnnData object (coordinates in ``obsm['spatial']``) and as a DataFrame
    (coordinates in the ``spatial_x`` / ``spatial_y`` columns), so both input
    types are exercised.
    """

    # Keyword arguments shared by the AnnData-based tests. Each test overrides
    # only the argument it is exercising, which keeps the intent of every test
    # visible instead of burying it in repeated boilerplate.
    DEFAULT_KWARGS = {
        'library_key': 'library_key',
        'library_id': 'lib1',
        'scaling_factor': 2,
        'spatial_key': 'spatial',
    }

    def setUp(self):
        """Build the shared AnnData and DataFrame fixtures."""
        # Seeded generator so the expression matrix is reproducible across runs
        # (its values are not inspected by any test).
        rng = np.random.default_rng(0)

        self.spatial_coords = np.array(
            [[x, y] for x in range(0, 100, 20) for y in range(0, 100, 20)]
        )
        n_obs = len(self.spatial_coords)
        self.obs = pd.DataFrame({'library_key': ['lib1'] * n_obs})
        # Use string observation names up front rather than relying on
        # AnnData's implicit index conversion.
        self.obs.index = self.obs.index.astype(str)

        # AnnData input: coordinates stored in obsm['spatial'].
        self.adata = ad.AnnData(X=rng.random((n_obs, 5)), obs=self.obs)
        self.adata.obsm['spatial'] = self.spatial_coords

        # DataFrame input: coordinates stored as two columns.
        self.df = pd.DataFrame({
            'library_key': ['lib1'] * n_obs,
            'spatial_x': self.spatial_coords[:, 0],
            'spatial_y': self.spatial_coords[:, 1],
        })

    def _generate(self, **overrides):
        """Call ``generate_patches`` on ``self.adata`` using ``DEFAULT_KWARGS``.

        Any keyword in ``overrides``, including ``spatial_data``, replaces the
        corresponding default.
        """
        kwargs = {'spatial_data': self.adata, **self.DEFAULT_KWARGS, **overrides}
        return generate_patches(**kwargs)

    def test_generate_patches_with_AnnData(self):
        # scaling_factor=2 splits each axis in two, so expect 2x2=4 patches.
        patches = self._generate(scaling_factor=2)
        self.assertIsInstance(patches, list)
        self.assertEqual(len(patches), 4)

    def test_generate_patches_with_DataFrame(self):
        # DataFrame input names its coordinate columns via a list spatial_key.
        # scaling_factor=5 splits each axis in five, so expect 5x5=25 patches.
        patches = self._generate(
            spatial_data=self.df,
            scaling_factor=5,
            spatial_key=['spatial_x', 'spatial_y'],
        )
        self.assertIsInstance(patches, list)
        self.assertEqual(len(patches), 25)

    def test_invalid_spatial_data_type(self):
        # Input that is neither AnnData nor a DataFrame is expected to be rejected.
        with self.assertRaises(ValueError):
            self._generate(spatial_data='invalid_type')

    def test_invalid_library_key(self):
        # library_key naming a column that does not exist in obs.
        with self.assertRaises(KeyError):
            self._generate(library_key='invalid_key')

    def test_invalid_library_id(self):
        # library_id that matches no observations.
        with self.assertRaises(ValueError):
            self._generate(library_id='invalid_id')

    def test_invalid_spatial_key(self):
        # spatial_key naming an entry that does not exist in obsm.
        with self.assertRaises(KeyError):
            self._generate(spatial_key='invalid_spatial_key')

    def test_zero_scaling_factor(self):
        with self.assertRaises(ValueError):
            self._generate(scaling_factor=0)

    def test_non_numeric_scaling_factor(self):
        with self.assertRaises(TypeError):
            self._generate(scaling_factor='two')

    def test_empty_spatial_data_filtered(self):
        # Every observation belongs to 'lib2', so filtering for the default
        # library_id 'lib1' leaves no data.
        n_obs = len(self.spatial_coords)
        obs_other = pd.DataFrame({'library_key': ['lib2'] * n_obs})
        obs_other.index = obs_other.index.astype(str)
        adata_other = ad.AnnData(
            X=np.random.default_rng(1).random((n_obs, 5)), obs=obs_other
        )
        adata_other.obsm['spatial'] = self.spatial_coords
        with self.assertRaises(ValueError):
            self._generate(spatial_data=adata_other)

    def test_output_patch_coordinates(self):
        """Each patch has the expected size and together they span the data extent."""
        scaling_factor = 2
        patches = self._generate(scaling_factor=scaling_factor)
        # Without this check an empty result would make the loop below pass vacuously.
        self.assertEqual(len(patches), scaling_factor ** 2)

        min_x, min_y = self.spatial_coords.min(axis=0)
        max_x, max_y = self.spatial_coords.max(axis=0)
        expected_patch_width = (max_x - min_x) / scaling_factor
        expected_patch_height = (max_y - min_y) / scaling_factor

        # Patches are unpacked as (x0, y0, x1, y1) corner coordinates.
        for index, (x0, y0, x1, y1) in enumerate(patches):
            with self.subTest(patch=index):
                self.assertAlmostEqual(x1 - x0, expected_patch_width)
                self.assertAlmostEqual(y1 - y0, expected_patch_height)

        # Together, the patches should span the bounding box of the points.
        x0s, y0s, x1s, y1s = zip(*patches)
        self.assertAlmostEqual(min(x0s), min_x)
        self.assertAlmostEqual(min(y0s), min_y)
        self.assertAlmostEqual(max(x1s), max_x)
        self.assertAlmostEqual(max(y1s), max_y)

[3]:
# Run the tests in the notebook. ``argv`` is given explicitly because unittest
# would otherwise parse ``sys.argv``, which in a notebook holds the kernel's own
# launch arguments. ``exit=False`` stops unittest from calling ``sys.exit`` and
# shutting down the kernel.
test_program = unittest.main(argv=['first-arg-is-ignored'], exit=False)

# With exit=False a failing test would only show up in the text report above.
# Failing the cell makes "Run All" and automated notebook runs surface the failure.
# Assigning the result also keeps the TestProgram repr out of the cell output.
assert test_program.result.wasSuccessful(), 'generate_patches unit tests failed'
..........
----------------------------------------------------------------------
Ran 10 tests in 0.083s

OK
[3]:
<unittest.main.TestProgram at 0x1515898d0>
[ ]: