Coverage for intelligence_toolkit/tests/unit/AI/test_base_batch_async.py: 100%

70 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-16 13:41 -0300

1# Copyright (c) 2024 Microsoft Corporation. All rights reserved. 

2# Licensed under the MIT license. See LICENSE file in the project. 

3# 

4import asyncio 

5from unittest.mock import MagicMock 

6 

7import pytest 

8 

9from intelligence_toolkit.AI.base_batch_async import BaseBatchAsync 

10from intelligence_toolkit.helpers.progress_batch_callback import ProgressBatchCallback 

11 

12 

@pytest.fixture
def base_batch():
    """Provide a newly constructed BaseBatchAsync to the requesting test."""
    instance = BaseBatchAsync()
    return instance

16 

17 

def test_base_batch_async_initialization():
    """A fresh BaseBatchAsync reports one total task and zero progress."""
    fresh = BaseBatchAsync()
    expected_state = {
        "total_tasks": 1,
        "completed_tasks": 0,
        "previous_completed_tasks": 0,
    }
    for attr_name, expected_value in expected_state.items():
        assert getattr(fresh, attr_name) == expected_value, attr_name

23 

24 

def test_base_batch_async_progress_callback(base_batch):
    """Each progress_callback() call bumps completed_tasks by exactly one."""
    assert base_batch.completed_tasks == 0
    for expected_count in (1, 2):
        base_batch.progress_callback()
        assert base_batch.completed_tasks == expected_count

31 

32 

@pytest.mark.asyncio
async def test_track_progress_with_tasks():
    """track_progress notifies its callback while three tasks report progress."""
    task_count = 3
    batch = BaseBatchAsync()
    batch.total_tasks = task_count

    progress_cb = MagicMock(spec=ProgressBatchCallback)
    progress_cb.on_batch_change = MagicMock()

    async def fake_work():
        # Yield to the event loop briefly, then record one finished unit.
        await asyncio.sleep(0.01)
        batch.progress_callback()

    pending = [asyncio.create_task(fake_work()) for _ in range(task_count)]

    await batch.track_progress(pending, [progress_cb])

    assert progress_cb.on_batch_change.called
    assert batch.completed_tasks == task_count

54 

55 

@pytest.mark.asyncio
async def test_track_progress_multiple_callbacks():
    """Every callback handed to track_progress receives batch updates."""
    batch = BaseBatchAsync()
    batch.total_tasks = 2

    def make_listener():
        listener = MagicMock(spec=ProgressBatchCallback)
        listener.on_batch_change = MagicMock()
        return listener

    listeners = [make_listener() for _ in range(2)]

    async def fake_work():
        # Brief suspension, then report one completed unit of work.
        await asyncio.sleep(0.01)
        batch.progress_callback()

    jobs = [asyncio.create_task(fake_work()) for _ in range(2)]

    await batch.track_progress(jobs, listeners)

    for listener in listeners:
        assert listener.on_batch_change.called
    assert batch.completed_tasks == 2

77 

78 

@pytest.mark.asyncio
async def test_track_progress_completed_immediately():
    """track_progress still notifies callbacks when the batch is already
    complete and the task it receives has already finished."""
    batch = BaseBatchAsync()
    batch.total_tasks = 1
    batch.completed_tasks = 1

    callback = MagicMock(spec=ProgressBatchCallback)
    callback.on_batch_change = MagicMock()

    # Create a single task and drive it to completion before tracking starts.
    async def completed_task():
        pass

    task = asyncio.create_task(completed_task())
    # Awaiting the task finishes it deterministically, with no reliance on a
    # timed sleep. It also surfaces any exception the task might raise.
    await task
    assert task.done()

    await batch.track_progress([task], [callback])

    assert callback.on_batch_change.called

98 

99 

@pytest.mark.asyncio
async def test_track_progress_no_change():
    """With an empty task list, track_progress must still invoke the callback."""
    batch = BaseBatchAsync()
    batch.total_tasks, batch.completed_tasks = 0, 0

    listener = MagicMock(spec=ProgressBatchCallback)
    listener.on_batch_change = MagicMock()

    no_tasks = []
    await batch.track_progress(no_tasks, [listener])

    # Even though nothing was scheduled, the callback must have been called.
    assert listener.on_batch_change.called