Coverage for tests/test_spliced_mixup_dataset.py: 0%

58 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-16 16:14 -0700

1import os 

2import unittest 

3from unittest.mock import MagicMock, patch 

4 

5import numpy as np 

6 

7from copick_torch import SplicedMixupDataset 

8 

9 

class TestSplicedMixupDataset(unittest.TestCase):
    """Test the SplicedMixupDataset class with focus on Gaussian blending functionality."""

    # Shape of every test volume; the splice mask is the inner cube [4:12) on each axis.
    BOXSIZE = (16, 16, 16)

    def setUp(self):
        """Build a SplicedMixupDataset instance without running its real loading code.

        The four loader/generator methods below are patched with mocks for the
        duration of each test. ``SimpleCopickDataset.__init__`` is replaced only
        while the instance is constructed. The attributes the tests rely on are
        then assigned by hand.
        """
        patch_targets = (
            "copick_torch.dataset.SplicedMixupDataset._load_copick_roots",
            "copick_torch.dataset.SimpleCopickDataset._load_or_process_data",
            "copick_torch.dataset.SplicedMixupDataset._ensure_zarr_loaded",
            "copick_torch.dataset.SplicedMixupDataset._generate_synthetic_samples",
        )
        for target in patch_targets:
            patcher = patch(target)
            patcher.start()
            # Register the stop right after a successful start; cleanups run LIFO,
            # so patches are undone in reverse order of activation.
            self.addCleanup(patcher.stop)

        # Skip the base-class constructor entirely (not just a validation check),
        # so no real copick data is needed to create the instance.
        with patch("copick_torch.dataset.SimpleCopickDataset.__init__", return_value=None):
            self.dataset = SplicedMixupDataset(
                exp_dataset_id=1,
                synth_dataset_id=2,
                blend_sigma=2.0,
            )

        # The base __init__ was bypassed, so set the attributes it would normally
        # provide (and pin blend_sigma explicitly).
        self.dataset.blend_sigma = 2.0
        self.dataset._subvolumes = []
        self.dataset._molecule_ids = []

    def _make_inputs(self):
        """Return fresh ``(synthetic_region, region_mask, exp_crop)`` test volumes.

        ``synthetic_region`` is all ones and ``exp_crop`` is all zeros.
        ``region_mask`` is True only on the inner cube ``[4:12)`` along each axis.
        New arrays are built on every call, so one ``_splice_volumes`` call cannot
        affect another through in-place modification of its inputs.
        """
        synthetic_region = np.ones(self.BOXSIZE)
        exp_crop = np.zeros(self.BOXSIZE)
        region_mask = np.zeros(self.BOXSIZE, dtype=bool)
        region_mask[4:12, 4:12, 4:12] = True
        return synthetic_region, region_mask, exp_crop

    def _splice(self, blend_sigma):
        """Run ``_splice_volumes`` on fresh standard inputs with the given ``blend_sigma``.

        Returns:
            ``(result, region_mask)``, where ``result`` is converted with
            ``np.asarray`` and ``region_mask`` is the boolean mask that was used.
        """
        synthetic_region, region_mask, exp_crop = self._make_inputs()
        self.dataset.blend_sigma = blend_sigma
        result = self.dataset._splice_volumes(synthetic_region, region_mask, exp_crop)
        return np.asarray(result), region_mask

    def test_splice_volumes_gaussian_blending(self):
        """Test the _splice_volumes method with Gaussian blending."""
        # blend_sigma=0 (no blending): exactly 1 inside the mask, 0 outside.
        result_no_blend, region_mask = self._splice(0.0)
        np.testing.assert_array_equal(result_no_blend[region_mask], 1.0)
        np.testing.assert_array_equal(result_no_blend[~region_mask], 0.0)

        # blend_sigma=2.0 (Gaussian blending).
        result_gaussian, _ = self._splice(2.0)

        # 1. The core of the masked cube should stay high. The threshold is 0.7
        #    rather than 0.9 to leave room for the blur.
        np.testing.assert_array_less(0.7, result_gaussian[7:9, 7:9, 7:9])

        # 2. Voxels outside the mask and far from its boundary should stay near 0.
        np.testing.assert_array_less(result_gaussian[0:2, 0:2, 0:2], 0.1)

        # 3. Voxels straddling the boundary (x=3 is outside, x=4 is inside) should
        #    include at least one intermediate value.
        boundary_values = result_gaussian[3:5, 8, 8]
        self.assertTrue(
            np.any((boundary_values > 0.1) & (boundary_values < 0.9)),
            msg=f"no intermediate value at the mask boundary: {boundary_values}",
        )

        # 4. Along a line from the center to the volume edge, values should mostly
        #    decrease: at least 80% of successive differences are non-positive.
        center_to_edge = result_gaussian[8, 8, 8:16]
        diffs = np.diff(center_to_edge)
        self.assertTrue(
            np.sum(diffs <= 0) >= 0.8 * len(diffs),
            msg=f"profile from center to edge is not mostly decreasing: {center_to_edge}",
        )

    def test_gaussian_blending_vs_no_blending(self):
        """Compare Gaussian blending with no blending to ensure they're different."""
        result_no_blend, region_mask = self._splice(0.0)
        result_gaussian, _ = self._splice(2.0)

        # Blending must actually change the output.
        self.assertFalse(np.array_equal(result_no_blend, result_gaussian))

        # Without blending, thresholding the output recovers the binary mask.
        np.testing.assert_array_equal(result_no_blend > 0.5, region_mask)

        # With Gaussian blending, more voxels should be affected than just the mask.
        self.assertGreater(
            np.count_nonzero(result_gaussian > 0.01),
            np.count_nonzero(region_mask),
        )

113 

114 

if __name__ == "__main__":
    # Allow running this file directly (python tests/test_spliced_mixup_dataset.py);
    # "__main__" is unittest.main's default module, spelled out for clarity.
    unittest.main(module="__main__")