Coverage for intelligence_toolkit/tests/unit/AI/test_base_batch_async.py: 100%
70 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-16 13:41 -0300
1# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
2# Licensed under the MIT license. See LICENSE file in the project.
3#
4import asyncio
5from unittest.mock import MagicMock
7import pytest
9from intelligence_toolkit.AI.base_batch_async import BaseBatchAsync
10from intelligence_toolkit.helpers.progress_batch_callback import ProgressBatchCallback
@pytest.fixture
def base_batch():
    """Provide a freshly constructed ``BaseBatchAsync`` instance for each test."""
    instance = BaseBatchAsync()
    return instance


def test_base_batch_async_initialization():
    """A new batch expects one task and has recorded no progress yet."""
    fresh = BaseBatchAsync()
    # (total_tasks, completed_tasks, previous_completed_tasks) defaults
    observed = (
        fresh.total_tasks,
        fresh.completed_tasks,
        fresh.previous_completed_tasks,
    )
    assert observed == (1, 0, 0)


def test_base_batch_async_progress_callback(base_batch):
    """Each ``progress_callback`` call increments ``completed_tasks`` by one."""
    assert base_batch.completed_tasks == 0
    for expected_count in (1, 2):
        base_batch.progress_callback()
        assert base_batch.completed_tasks == expected_count


@pytest.mark.asyncio
async def test_track_progress_with_tasks():
    """``track_progress`` notifies its callback while three tasks report progress."""
    task_count = 3
    batch = BaseBatchAsync()
    batch.total_tasks = task_count

    progress_cb = MagicMock(spec=ProgressBatchCallback)
    progress_cb.on_batch_change = MagicMock()

    async def report_after_delay():
        # Sleep so each task suspends at least once before reporting progress.
        await asyncio.sleep(0.01)
        batch.progress_callback()

    pending = [asyncio.create_task(report_after_delay()) for _ in range(task_count)]

    await batch.track_progress(pending, [progress_cb])

    assert progress_cb.on_batch_change.called
    assert batch.completed_tasks == task_count


@pytest.mark.asyncio
async def test_track_progress_multiple_callbacks():
    """Every callback handed to ``track_progress`` receives batch-change notifications."""
    batch = BaseBatchAsync()
    batch.total_tasks = 2

    listeners = []
    for _ in range(2):
        listener = MagicMock(spec=ProgressBatchCallback)
        listener.on_batch_change = MagicMock()
        listeners.append(listener)

    async def tick():
        # Suspend briefly, then record one unit of completed work.
        await asyncio.sleep(0.01)
        batch.progress_callback()

    running = [asyncio.create_task(tick()) for _ in range(2)]
    await batch.track_progress(running, listeners)

    for listener in listeners:
        assert listener.on_batch_change.called
    assert batch.completed_tasks == 2


@pytest.mark.asyncio
async def test_track_progress_completed_immediately():
    """``track_progress`` still notifies callbacks when its task is already done.

    The batch is pre-set to a fully completed state (1 of 1 tasks), and the
    single task has finished before ``track_progress`` is invoked.
    """
    batch = BaseBatchAsync()
    batch.total_tasks = 1
    batch.completed_tasks = 1

    callback = MagicMock(spec=ProgressBatchCallback)
    callback.on_batch_change = MagicMock()

    # A task that does no work, so it finishes as soon as it is scheduled.
    async def completed_task():
        pass

    task = asyncio.create_task(completed_task())
    # Await the task itself instead of sleeping for an arbitrary interval:
    # this deterministically guarantees it has finished before tracking starts.
    await task
    assert task.done()

    await batch.track_progress([task], [callback])

    assert callback.on_batch_change.called


@pytest.mark.asyncio
async def test_track_progress_no_change():
    """With an empty task list, ``track_progress`` still notifies the callback."""
    batch = BaseBatchAsync()
    batch.total_tasks, batch.completed_tasks = 0, 0

    listener = MagicMock(spec=ProgressBatchCallback)
    listener.on_batch_change = MagicMock()

    no_tasks = []
    await batch.track_progress(no_tasks, [listener])

    assert listener.on_batch_change.called