Coverage for tests/test_spliced_mixup_dataset.py: 0%
58 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-16 16:14 -0700
1import os
2import unittest
3from unittest.mock import MagicMock, patch
5import numpy as np
7from copick_torch import SplicedMixupDataset
class TestSplicedMixupDataset(unittest.TestCase):
    """Test the SplicedMixupDataset class with focus on Gaussian blending functionality.

    Only ``_splice_volumes`` is exercised. The private loading/generation methods
    listed in ``_PATCH_TARGETS`` are patched out, and the parent
    ``SimpleCopickDataset.__init__`` is replaced by a no-op while the dataset is
    constructed, so no real data loading should happen during these tests.
    """

    # Private methods patched (with MagicMocks) for the duration of each test so
    # that constructing the dataset does not run the real implementations.
    _PATCH_TARGETS = (
        "copick_torch.dataset.SplicedMixupDataset._load_copick_roots",
        "copick_torch.dataset.SimpleCopickDataset._load_or_process_data",
        "copick_torch.dataset.SplicedMixupDataset._ensure_zarr_loaded",
        "copick_torch.dataset.SplicedMixupDataset._generate_synthetic_samples",
    )

    def setUp(self):
        """Build a SplicedMixupDataset with its loading methods and parent constructor patched out."""
        for target in self._PATCH_TARGETS:
            patcher = patch(target)
            patcher.start()
            # Register the cleanup only once the patch is live; cleanups run in
            # LIFO order, so patches are stopped in reverse start order.
            self.addCleanup(patcher.stop)

        # Replace the *parent* constructor with a no-op (not a validation check):
        # only the SplicedMixupDataset-level __init__ code runs, and the
        # SimpleCopickDataset setup is skipped entirely.
        with patch("copick_torch.dataset.SimpleCopickDataset.__init__", return_value=None):
            self.dataset = SplicedMixupDataset(exp_dataset_id=1, synth_dataset_id=2, blend_sigma=2.0)

        # The skipped parent __init__ may be what normally sets these, so set the
        # attributes the tests rely on explicitly.
        self.dataset.blend_sigma = 2.0
        self.dataset._subvolumes = []
        self.dataset._molecule_ids = []

    @staticmethod
    def _make_inputs():
        """Return ``(synthetic_region, region_mask, exp_crop)`` for a 16x16x16 box.

        The synthetic volume is all ones, the experimental crop is all zeros, and
        the boolean mask selects the inner cube ``[4:12]`` on every axis. The tests
        therefore read values strictly between 0 and 1 as evidence of blending.
        """
        boxsize = (16, 16, 16)
        synthetic_region = np.ones(boxsize)
        exp_crop = np.zeros(boxsize)
        region_mask = np.zeros(boxsize, dtype=bool)
        region_mask[4:12, 4:12, 4:12] = True
        return synthetic_region, region_mask, exp_crop

    def _splice(self, blend_sigma, synthetic_region, region_mask, exp_crop):
        """Set ``blend_sigma`` on the dataset, then return its ``_splice_volumes`` result."""
        self.dataset.blend_sigma = blend_sigma
        return self.dataset._splice_volumes(synthetic_region, region_mask, exp_crop)

    def test_splice_volumes_gaussian_blending(self):
        """Test the _splice_volumes method with Gaussian blending."""
        synthetic_region, region_mask, exp_crop = self._make_inputs()

        # blend_sigma=0: expect a hard cut -- 1 inside the mask, 0 outside.
        result_no_blend = self._splice(0.0, synthetic_region, region_mask, exp_crop)
        np.testing.assert_array_equal(result_no_blend[region_mask], 1.0)
        np.testing.assert_array_equal(result_no_blend[~region_mask], 0.0)

        # blend_sigma=2.0: Gaussian blending.
        result_gaussian = self._splice(2.0, synthetic_region, region_mask, exp_crop)

        # 1. The core of the masked cube stays relatively high. The threshold is 0.7
        #    (not 0.9) to account for the blending effect.
        np.testing.assert_array_less(0.7, result_gaussian[7:9, 7:9, 7:9])

        # 2. Far outside the mask (the box corner) stays close to 0.
        np.testing.assert_array_less(result_gaussian[0:2, 0:2, 0:2], 0.1)

        # 3. Points straddling the mask face at index 4 take intermediate values.
        boundary_values = result_gaussian[3:5, 8, 8]
        self.assertTrue(
            np.any((boundary_values > 0.1) & (boundary_values < 0.9)),
            msg=f"expected an intermediate value near the mask boundary, got {boundary_values}",
        )

        # 4. The transition is smooth: walking from the centre to the box edge,
        #    at least 80% of the steps must be non-increasing.
        center_to_edge = result_gaussian[8, 8, 8:16]
        diffs = np.diff(center_to_edge)
        self.assertGreaterEqual(
            np.sum(diffs <= 0),
            0.8 * len(diffs),
            msg=f"centre-to-edge profile is not mostly decreasing: {center_to_edge}",
        )

    def test_gaussian_blending_vs_no_blending(self):
        """Compare Gaussian blending with no blending to ensure they're different."""
        synthetic_region, region_mask, exp_crop = self._make_inputs()
        result_no_blend = self._splice(0.0, synthetic_region, region_mask, exp_crop)
        result_gaussian = self._splice(2.0, synthetic_region, region_mask, exp_crop)

        # Blending must actually change the output.
        self.assertFalse(np.array_equal(result_no_blend, result_gaussian))

        # Without blending, thresholding at 0.5 recovers the binary mask exactly.
        np.testing.assert_array_equal(result_no_blend > 0.5, region_mask)

        # With blending, more voxels than the mask itself should be non-negligible.
        self.assertGreater(np.sum(result_gaussian > 0.01), np.sum(region_mask))
if __name__ == "__main__":
    # Executing this file directly runs its test cases via unittest's CLI runner.
    unittest.main()