Coverage for tests/unit/with_torch/test_get_module_device.py: 97%
33 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-21 22:18 -0700
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-21 22:18 -0700
1from __future__ import annotations
3import pytest
4import torch # type: ignore[import-not-found]
6from zanj.torchutil import get_module_device
def test_get_module_device_single_device():
    """A module whose parameters all share one device reports that device.

    ``get_module_device`` should return a truthy flag together with the single
    device the parameters live on (CUDA device 0 if available, else the CPU).
    """
    target = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    # Module.to moves parameters in place and returns the same module object.
    layer = torch.nn.Linear(10, 2).to(target)

    all_same, result = get_module_device(layer)

    # Every parameter is on `target`, so the device itself comes back.
    assert all_same
    assert result == target
def test_get_module_device_multiple_devices():
    """Parameters spread across devices yield ``(falsy, {name: device})``.

    The layer's weight is moved to the ``meta`` device while its bias stays on
    the CPU, so ``get_module_device`` must report that the parameters do not
    share a device and return a per-parameter device mapping.
    """
    # No CUDA guard is needed: only the "meta" and "cpu" devices are used, and
    # neither requires GPU hardware, so this test can run on CPU-only machines.
    with torch.no_grad():
        model = torch.nn.Linear(10, 2)
        model.weight = torch.nn.Parameter(model.weight.to("meta"))
        model.bias = torch.nn.Parameter(model.bias.to("cpu"))

    is_single, device_or_dict = get_module_device(model)

    # Not all parameters are on the same device, so a dict is returned ...
    assert not is_single
    assert isinstance(device_or_dict, dict)

    # ... mapping each parameter name to the device it lives on.
    assert device_or_dict["weight"] == torch.device("meta")
    assert device_or_dict["bias"] == torch.device("cpu")
def test_get_module_device_no_parameters():
    """A module with no parameters yields a falsy flag and an empty dict."""
    empty_module = torch.nn.Sequential()  # an empty container owns no parameters

    all_same, mapping = get_module_device(empty_module)

    assert not all_same
    assert mapping == {}