Coverage for tests / unit / with_torch / test_get_module_device.py: 97%

33 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-21 22:18 -0700

1from __future__ import annotations 

2 

3import pytest 

4import torch # type: ignore[import-not-found] 

5 

6from zanj.torchutil import get_module_device 

7 

8 

def test_get_module_device_single_device():
    """A module whose parameters all live on one device reports that device.

    Uses ``cuda:0`` when CUDA is available, otherwise ``cpu``, and expects
    ``get_module_device`` to return ``(truthy, <that device>)``.
    """
    target = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    # nn.Module.to moves parameters in place and returns the module itself.
    layer = torch.nn.Linear(10, 2).to(target)

    all_same, result = get_module_device(layer)

    # Single device: the flag is set and the device itself comes back.
    assert all_same
    assert result == target

21 

22 

def test_get_module_device_multiple_devices():
    """Parameters spread over several devices yield ``(falsy, {name: device})``.

    The two devices used here are ``meta`` and ``cpu``, neither of which
    needs a GPU, so this test runs everywhere. (It used to be skipped
    whenever no CUDA device was present, even though it never touches CUDA.)
    """
    model = torch.nn.Linear(10, 2)
    with torch.no_grad():
        # Re-wrap each moved tensor in nn.Parameter so it stays registered
        # on the module under its original name.
        model.weight = torch.nn.Parameter(model.weight.to("meta"))
        model.bias = torch.nn.Parameter(model.bias.to("cpu"))

    is_single, device_or_dict = get_module_device(model)

    # Mixed devices: no single device, so a per-parameter mapping is returned.
    assert not is_single
    assert isinstance(device_or_dict, dict)
    assert device_or_dict["weight"] == torch.device("meta")
    assert device_or_dict["bias"] == torch.device("cpu")

50 

51 

def test_get_module_device_no_parameters():
    """A module with no parameters yields a falsy flag and an empty dict."""
    empty = torch.nn.Sequential()  # no submodules, hence no parameters

    single_flag, mapping = get_module_device(empty)

    assert not single_flag
    assert mapping == {}