class TestGeneratePatches(unittest.TestCase):
    """Tests for ``generate_patches`` using AnnData and DataFrame inputs.

    A single synthetic dataset (2500 random points in a 1000 x 1000 square,
    all belonging to library ``'sample_1'``) is built once in ``setUpClass``
    and shared by every test.
    """

    @classmethod
    def setUpClass(cls):
        """Build the shared DataFrame, AnnData object and per-sample subset."""
        # Set up test data with more points
        np.random.seed(42)  # For reproducibility in data generation
        total_points = 2500
        # NOTE: the order of the random draws below (x, y, cluster_key, then X)
        # determines the generated values for the fixed seed; keep it stable.
        cls.spatial_data = pd.DataFrame({
            'x': 1000 * np.random.rand(total_points),
            'y': 1000 * np.random.rand(total_points),
            'library_key': ['sample_1'] * total_points,
            'cluster_key': np.random.randint(0, 10, size=total_points),
        })
        cls.spatial_coords = cls.spatial_data[['x', 'y']]

        # Create AnnData object from the DataFrame.
        # Take an explicit copy: the index is relabelled with strings on the
        # next line, and that must stay independent of ``spatial_data``.
        cls.obs = cls.spatial_data[['library_key']].copy()
        cls.obs.index = cls.spatial_data.index.astype(str)
        cls.adata = ad.AnnData(
            X=np.random.rand(len(cls.spatial_data), 50),
            obs=cls.obs,
        )
        cls.adata.obsm['spatial'] = cls.spatial_data[['x', 'y']].values

        # Subset DataFrame for 'sample_1'
        cls.df = cls.spatial_data[
            cls.spatial_data['library_key'] == 'sample_1'
        ].copy()
        cls.df.reset_index(drop=True, inplace=True)

    def test_generate_patches_with_AnnData(self):
        # Test with AnnData input
        patches = generate_patches(
            spatial_data=self.adata,
            library_key='library_key',
            library_id='sample_1',
            scaling_factor=2,
            spatial_key='spatial'
        )
        expected_patches = 4  # Since scaling_factor=2, expect 2x2=4 patches
        # Check the container type first so a wrong return type is reported
        # as such rather than as a length mismatch.
        self.assertIsInstance(patches, list)
        self.assertEqual(len(patches), expected_patches)

    def test_generate_patches_with_DataFrame(self):
        # Test with DataFrame input
        patches = generate_patches(
            spatial_data=self.df,
            library_key='library_key',
            library_id='sample_1',
            scaling_factor=5,
            spatial_key=['x', 'y']
        )
        expected_patches = 25  # Since scaling_factor=5, expect 5x5=25 patches
        self.assertIsInstance(patches, list)
        self.assertEqual(len(patches), expected_patches)

    def test_invalid_spatial_data_type(self):
        # A plain string is neither AnnData nor DataFrame -> ValueError
        with self.assertRaises(ValueError):
            generate_patches(
                spatial_data='invalid_type',
                library_key='library_key',
                library_id='sample_1',
                scaling_factor=2,
                spatial_key='spatial'
            )

    def test_invalid_library_key(self):
        # A library_key that is not present in the data -> KeyError
        with self.assertRaises(KeyError):
            generate_patches(
                spatial_data=self.adata,
                library_key='invalid_key',
                library_id='sample_1',
                scaling_factor=2,
                spatial_key='spatial'
            )

    def test_invalid_library_id(self):
        # A library_id that matches no rows -> ValueError
        with self.assertRaises(ValueError):
            generate_patches(
                spatial_data=self.adata,
                library_key='library_key',
                library_id='invalid_id',
                scaling_factor=2,
                spatial_key='spatial'
            )

    def test_invalid_spatial_key(self):
        # A spatial_key that is not present in the data -> KeyError
        with self.assertRaises(KeyError):
            generate_patches(
                spatial_data=self.adata,
                library_key='library_key',
                library_id='sample_1',
                scaling_factor=2,
                spatial_key='invalid_spatial_key'
            )

    def test_zero_scaling_factor(self):
        # scaling_factor=0 is rejected with ValueError
        with self.assertRaises(ValueError):
            generate_patches(
                spatial_data=self.adata,
                library_key='library_key',
                library_id='sample_1',
                scaling_factor=0,
                spatial_key='spatial'
            )

    def test_non_numeric_scaling_factor(self):
        # A non-numeric scaling_factor is rejected with TypeError
        with self.assertRaises(TypeError):
            generate_patches(
                spatial_data=self.adata,
                library_key='library_key',
                library_id='sample_1',
                scaling_factor='two',
                spatial_key='spatial'
            )

    def test_empty_spatial_data_filtered(self):
        # Test when spatial_data_filtered is empty: the input is pre-filtered
        # on a library id that does not exist, so it contains no observations.
        empty_adata = self.adata[self.adata.obs['library_key'] == 'non_existent']
        with self.assertRaises(ValueError):
            generate_patches(
                spatial_data=empty_adata,
                library_key='library_key',
                library_id='non_existent',
                scaling_factor=2,
                spatial_key='spatial'
            )

    def test_output_patch_coordinates(self):
        # Test that patches have correct coordinates
        scaling_factor = 2  # single source of truth for call and expectation
        patches = generate_patches(
            spatial_data=self.adata,
            library_key='library_key',
            library_id='sample_1',
            scaling_factor=scaling_factor,
            spatial_key='spatial'
        )
        # Extract min and max spatial values (Series labelled 'x' / 'y')
        min_coords = self.spatial_coords.min(axis=0)
        max_coords = self.spatial_coords.max(axis=0)
        # Use label-based access: integer keys on a string-labelled Series are
        # a deprecated positional fallback (pandas 2.1) that newer pandas
        # treats as labels, raising KeyError.
        expected_patch_width = (max_coords['x'] - min_coords['x']) / scaling_factor
        expected_patch_height = (max_coords['y'] - min_coords['y']) / scaling_factor
        # Check if patches cover the area correctly
        for patch in patches:
            x0, y0, x1, y1 = patch
            self.assertAlmostEqual(x1 - x0, expected_patch_width)
            self.assertAlmostEqual(y1 - y0, expected_patch_height)