class TestGeneratePatchesRandomly(unittest.TestCase):
@classmethod
def setUpClass(cls):
# Set up test data with more points
np.random.seed(42) # For reproducibility in data generation
total_points = 5000
cls.spatial_data = pd.DataFrame({
'x': 1000 * np.random.rand(total_points),
'y': 1000 * np.random.rand(total_points),
'library_key': ['sample_1'] * (total_points // 2) + ['sample_2'] * (total_points // 2),
'cluster_key': np.random.randint(0, 10, size=total_points)
})
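        # The synthetic data spans a 1000 x 1000 coordinate domain with
        # 2,500 points per library and cluster labels in the range 0-9.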
# Create AnnData object from the DataFrame
        # Copy the column slice so the index assignment doesn't touch spatial_data;
        # AnnData expects string observation names.
        cls.obs = cls.spatial_data[['library_key']].copy()
        cls.obs.index = cls.spatial_data.index.astype(str)
cls.adata = ad.AnnData(X=np.random.rand(len(cls.spatial_data), 50), obs=cls.obs)
cls.adata.obsm['spatial'] = cls.spatial_data[['x', 'y']].values
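        # obsm['spatial'] mirrors the x/y columns, so the AnnData- and
        # DataFrame-based tests operate on identical coordinates.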
# Subset DataFrame for 'sample_1'
cls.df_sample1 = cls.spatial_data[cls.spatial_data['library_key'] == 'sample_1'].copy()
cls.df_sample1.reset_index(drop=True, inplace=True)
def test_generate_patches_randomly_with_AnnData(self):
# Test with AnnData input
patches = generate_patches_randomly(
spatial_data=self.adata,
library_key='library_key',
library_id='sample_1',
scaling_factor=2,
spatial_key='spatial',
max_overlap=0.5,
random_seed=42
)
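        # With scaling_factor=2, the sample is expected to be tiled into a
        # 2 x 2 grid, giving 4 patches (mirroring the 4x4 case below).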
expected_patches = 4
self.assertEqual(len(patches), expected_patches)
self.assertIsInstance(patches, list)
def test_generate_patches_randomly_with_DataFrame(self):
# Test with DataFrame input
patches = generate_patches_randomly(
spatial_data=self.df_sample1,
library_key='library_key',
library_id='sample_1',
scaling_factor=4,
spatial_key=['x', 'y'],
max_overlap=0.5,
random_seed=42
)
        expected_patches = 16  # With scaling_factor=4, expect a 4x4 grid of 16 patches
self.assertEqual(len(patches), expected_patches)
self.assertIsInstance(patches, list)
def test_invalid_spatial_data_type(self):
# Test with invalid spatial_data type
with self.assertRaises(ValueError):
generate_patches_randomly(
spatial_data='invalid_type',
library_key='library_key',
library_id='sample_1',
scaling_factor=2,
spatial_key='spatial',
max_overlap=0.5,
random_seed=42
)
def test_invalid_library_key(self):
# Test with invalid library_key
with self.assertRaises(KeyError):
generate_patches_randomly(
spatial_data=self.adata,
library_key='invalid_key',
library_id='sample_1',
scaling_factor=2,
spatial_key='spatial',
max_overlap=0.5,
random_seed=42
)
def test_invalid_library_id(self):
# Test with invalid library_id
with self.assertRaises(ValueError):
generate_patches_randomly(
spatial_data=self.adata,
library_key='library_key',
library_id='invalid_id',
scaling_factor=2,
spatial_key='spatial',
max_overlap=0.5,
random_seed=42
)
def test_invalid_spatial_key(self):
# Test with invalid spatial_key
with self.assertRaises(KeyError):
generate_patches_randomly(
spatial_data=self.adata,
library_key='library_key',
library_id='sample_1',
scaling_factor=2,
spatial_key='invalid_spatial_key',
max_overlap=0.5,
random_seed=42
)
def test_zero_scaling_factor(self):
# Test with zero scaling_factor
with self.assertRaises(ValueError):
generate_patches_randomly(
spatial_data=self.adata,
library_key='library_key',
library_id='sample_1',
scaling_factor=0,
spatial_key='spatial',
max_overlap=0.5,
random_seed=42
)
def test_non_numeric_scaling_factor(self):
# Test with non-numeric scaling_factor
with self.assertRaises(TypeError):
generate_patches_randomly(
spatial_data=self.adata,
library_key='library_key',
library_id='sample_1',
scaling_factor='two',
spatial_key='spatial',
max_overlap=0.5,
random_seed=42
)
def test_empty_spatial_data_filtered(self):
# Test when spatial_data_filtered is empty
with self.assertRaises(ValueError):
generate_patches_randomly(
spatial_data=self.adata[self.adata.obs['library_key'] == 'non_existent'],
library_key='library_key',
library_id='non_existent',
scaling_factor=2,
spatial_key='spatial',
max_overlap=0.5,
random_seed=42
)
def test_output_patch_sizes(self):
# Test that patches have correct sizes
patches = generate_patches_randomly(
spatial_data=self.adata,
library_key='library_key',
library_id='sample_1',
scaling_factor=5,
spatial_key='spatial',
max_overlap=0.5,
random_seed=42
)
# Extract min and max spatial values for 'sample_1'
spatial_values = self.adata[self.adata.obs['library_key'] == 'sample_1'].obsm['spatial']
min_coords = spatial_values.min(axis=0)
max_coords = spatial_values.max(axis=0)
expected_patch_width = (max_coords[0] - min_coords[0]) / 5
expected_patch_height = (max_coords[1] - min_coords[1]) / 5
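        # With scaling_factor=5, each patch should cover 1/5 of the sample's
        # extent along each axis.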
# Check if patches have correct sizes
for patch in patches:
x0, y0, x1, y1 = patch
self.assertAlmostEqual(x1 - x0, expected_patch_width, places=5)
self.assertAlmostEqual(y1 - y0, expected_patch_height, places=5)
def test_max_overlap(self):
# Test with max_overlap parameter
patches = generate_patches_randomly(
spatial_data=self.df_sample1,
library_key='library_key',
library_id='sample_1',
scaling_factor=4,
spatial_key=['x', 'y'],
max_overlap=0.5,
random_seed=42
)
        # Every pair of returned patches should respect the max_overlap constraint
for i, patch1 in enumerate(patches):
for patch2 in patches[i+1:]:
                # Verify the pair's overlap using the _overlap_check helper from mesa.ecospatial
overlap_allowed = utils._overlap_check(
new_patch=patch1,
existing_patches=[patch2],
max_overlap_ratio=0.5
)
self.assertTrue(overlap_allowed)
def test_min_points(self):
# Test with min_points set too high, expecting zero patches
patches = generate_patches_randomly(
spatial_data=self.df_sample1,
library_key='library_key',
library_id='sample_1',
scaling_factor=4,
spatial_key=['x', 'y'],
            min_points=5000,  # More than the 2,500 points in 'sample_1'
max_overlap=0.5,
random_seed=42
)
self.assertEqual(len(patches), 0)
def test_random_seed_consistency(self):
# Generate patches with the same random_seed
patches1 = generate_patches_randomly(
spatial_data=self.df_sample1,
library_key='library_key',
library_id='sample_1',
scaling_factor=4,
spatial_key=['x', 'y'],
max_overlap=0.5,
random_seed=42
)
patches2 = generate_patches_randomly(
spatial_data=self.df_sample1,
library_key='library_key',
library_id='sample_1',
scaling_factor=4,
spatial_key=['x', 'y'],
max_overlap=0.5,
random_seed=42
)
# Check that the patches are the same
self.assertEqual(patches1, patches2)
def test_contains_min_points(self):
# Test that each patch contains at least min_points
min_points = 20
patches = generate_patches_randomly(
spatial_data=self.df_sample1,
library_key='library_key',
library_id='sample_1',
scaling_factor=4,
spatial_key=['x', 'y'],
max_overlap=0.5,
random_seed=42,
min_points=min_points
)
spatial_values = self.df_sample1[['x', 'y']].values
for patch in patches:
contains = utils._contains_points(
patch=patch,
spatial_values=spatial_values,
min_points=min_points
)
self.assertTrue(contains)
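
# Standard unittest entry point; assumes this module may also be executed
# directly rather than only discovered by a test runner.
if __name__ == '__main__':
    unittest.main()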