class TestGeneratePatches(unittest.TestCase):
    """Unit tests for ``generate_patches`` with AnnData and DataFrame inputs.

    The shared fixture is a 5x5 grid of 2-D coordinates (0, 20, ..., 80 on
    each axis), all assigned to the single library id ``'lib1'``.
    """

    def setUp(self):
        """Build the coordinate grid and wrap it as AnnData and as DataFrame."""
        # Seeded generator so the expression matrix is identical on every run.
        self.rng = np.random.RandomState(0)

        # Create common test data
        self.spatial_coords = np.array(
            [[x, y] for x in range(0, 100, 20) for y in range(0, 100, 20)]
        )
        n_obs = len(self.spatial_coords)

        self.obs = pd.DataFrame({'library_key': ['lib1'] * n_obs})
        self.obs.index = self.obs.index.astype(str)

        # AnnData object
        self.adata = ad.AnnData(X=self.rng.rand(n_obs, 5), obs=self.obs)
        self.adata.obsm['spatial'] = self.spatial_coords

        # DataFrame
        self.df = pd.DataFrame({
            'library_key': ['lib1'] * n_obs,
            'spatial_x': self.spatial_coords[:, 0],
            'spatial_y': self.spatial_coords[:, 1],
        })

    def _generate(self, **overrides):
        """Call ``generate_patches`` with the default arguments, plus overrides.

        Defaults are the AnnData fixture, ``library_key='library_key'``,
        ``library_id='lib1'``, ``scaling_factor=2`` and ``spatial_key='spatial'``;
        each test overrides only the argument it is exercising.
        """
        kwargs = {
            'spatial_data': self.adata,
            'library_key': 'library_key',
            'library_id': 'lib1',
            'scaling_factor': 2,
            'spatial_key': 'spatial',
        }
        kwargs.update(overrides)
        return generate_patches(**kwargs)

    def test_generate_patches_with_AnnData(self):
        """AnnData input with scaling_factor=2 yields a list of 2x2=4 patches."""
        patches = self._generate()
        expected_patches = 4  # Since scaling_factor=2, expect 2x2=4 patches
        self.assertEqual(len(patches), expected_patches)
        self.assertIsInstance(patches, list)

    def test_generate_patches_with_DataFrame(self):
        """DataFrame input with scaling_factor=5 yields a list of 5x5=25 patches."""
        patches = self._generate(
            spatial_data=self.df,
            scaling_factor=5,
            spatial_key=['spatial_x', 'spatial_y'],
        )
        expected_patches = 25  # Since scaling_factor=5, expect 5x5=25 patches
        self.assertEqual(len(patches), expected_patches)
        self.assertIsInstance(patches, list)

    def test_invalid_spatial_data_type(self):
        """A spatial_data value of an unsupported type (a str) raises ValueError."""
        with self.assertRaises(ValueError):
            self._generate(spatial_data='invalid_type')

    def test_invalid_library_key(self):
        """A library_key that is not present in the data raises KeyError."""
        with self.assertRaises(KeyError):
            self._generate(library_key='invalid_key')

    def test_invalid_library_id(self):
        """A library_id that matches no observation raises ValueError."""
        with self.assertRaises(ValueError):
            self._generate(library_id='invalid_id')

    def test_invalid_spatial_key(self):
        """A spatial_key that is not present in the data raises KeyError."""
        with self.assertRaises(KeyError):
            self._generate(spatial_key='invalid_spatial_key')

    def test_zero_scaling_factor(self):
        """A scaling_factor of zero raises ValueError."""
        with self.assertRaises(ValueError):
            self._generate(scaling_factor=0)

    def test_non_numeric_scaling_factor(self):
        """A non-numeric scaling_factor (a str) raises TypeError."""
        with self.assertRaises(TypeError):
            self._generate(scaling_factor='two')

    def test_empty_spatial_data_filtered(self):
        """Filtering that leaves no observations raises ValueError."""
        n_obs = len(self.spatial_coords)

        # Every observation belongs to 'lib2', so selecting 'lib1' matches none.
        obs_empty = pd.DataFrame({'library_key': ['lib2'] * n_obs})
        obs_empty.index = obs_empty.index.astype(str)
        adata_empty = ad.AnnData(X=self.rng.rand(n_obs, 5), obs=obs_empty)
        adata_empty.obsm['spatial'] = self.spatial_coords

        with self.assertRaises(ValueError):
            self._generate(spatial_data=adata_empty)

    def test_output_patch_coordinates(self):
        """Each patch spans half of the coordinate extent along both axes."""
        patches = self._generate()

        # Guard against a vacuous pass: with no patches the loop below would
        # assert nothing. Same count as test_generate_patches_with_AnnData.
        self.assertEqual(len(patches), 4)

        # Extract min and max spatial values
        min_coords = self.spatial_coords.min(axis=0)
        max_coords = self.spatial_coords.max(axis=0)
        expected_patch_width = (max_coords[0] - min_coords[0]) / 2
        expected_patch_height = (max_coords[1] - min_coords[1]) / 2

        # Check if patches cover the area correctly
        for index, patch in enumerate(patches):
            with self.subTest(patch_index=index):
                # NOTE(review): assumes each patch unpacks as (x0, y0, x1, y1);
                # confirm against generate_patches' documented return format.
                x0, y0, x1, y1 = patch
                self.assertAlmostEqual(x1 - x0, expected_patch_width)
                self.assertAlmostEqual(y1 - y0, expected_patch_height)