Coverage for src/trapi_predict_kit/save.py: 94%

39 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-25 21:14 +0200

1import os.path 

2import pickle 

3from dataclasses import dataclass 

4from typing import Any, Optional 

5 

6from mlem import api as mlem 

7from rdflib import Graph 

8 

9# from mlem.api import save as mlem_save, load as mlem_load 

10from trapi_predict_kit.utils import get_run_metadata, log 

11 

12 

@dataclass
class LoadedModel:
    """Bundle a model object with the location and metadata it was saved with.

    Instances are returned by both ``save`` and ``load`` in this module.
    """

    # Location the model was written to / read from.
    path: str
    # The model object itself (any type).
    model: Any
    # RDF graph holding the run metadata stored alongside the model.
    metadata: Graph
    # Optional extras; ``load`` leaves both at their ``None`` default.
    hyper_params: Optional[Any] = None
    scores: Optional[Any] = None

21 

22 

def save(
    model: Any,
    path: str,
    sample_data: Any,
    method: str = "pickle",
    scores: Optional[Any] = None,
    hyper_params: Optional[Any] = None,
) -> LoadedModel:
    """Store a model at ``path`` and write its run metadata to ``{path}.ttl``.

    Args:
        model: The model object to store.
        path: Destination of the model file. Its parent directory is created
            when missing.
        sample_data: Passed to ``mlem.save`` (when ``method == "mlem"``) and
            to ``get_run_metadata``.
        method: ``"mlem"`` stores the model with ``mlem.save``; any other
            value pickles it to ``path``.
        scores: Passed to ``get_run_metadata`` and returned on the result.
        hyper_params: Passed to ``get_run_metadata`` and returned on the
            result.

    Returns:
        A ``LoadedModel`` holding the path, the model, the metadata graph and
        the given ``scores`` / ``hyper_params``.

    Raises:
        OSError: If the model file cannot be opened for writing (for example
            because creating the parent directory failed).
    """
    model_name = path.rsplit("/", 1)[-1]

    # A bare file name has an empty dirname; there is nothing to create then,
    # and os.makedirs("") would raise and produce a misleading warning.
    parent_dir = os.path.dirname(path)
    if parent_dir:
        try:
            # exist_ok avoids the check-then-create race of a separate
            # os.path.exists() test.
            os.makedirs(parent_dir, exist_ok=True)
        except OSError as e:
            # Best effort: the write below will fail loudly if the directory
            # is really unusable.
            # `warning` replaces the deprecated `warn` alias (assumes `log` is
            # a standard logging.Logger -- confirm in trapi_predict_kit.utils).
            log.warning(f"Error creating directory: {e}")

    log.info(f"💾 Saving the model in {path} using {method}")

    if method == "mlem":
        mlem.save(model, path, sample_data=sample_data)
    else:
        with open(path, "wb") as f:
            pickle.dump(model, f)

    # The metadata graph is serialized as Turtle next to the model file.
    g = get_run_metadata(scores, sample_data, hyper_params, model_name)
    g.serialize(f"{path}.ttl", format="ttl")

    return LoadedModel(
        path=path,
        model=model,
        metadata=g,
        hyper_params=hyper_params,
        scores=scores,
    )

62 

63 

def load(path: str, method: str = "pickle") -> LoadedModel:
    """Read a model from ``path`` together with its ``{path}.ttl`` metadata.

    Args:
        path: Location of the stored model; the metadata is parsed from
            ``{path}.ttl`` in Turtle format.
        method: ``"mlem"`` reads the model with ``mlem.load``; any other
            value unpickles the file at ``path``.

    Returns:
        A ``LoadedModel`` with ``path``, the model and the parsed metadata
        graph; ``scores`` and ``hyper_params`` keep their ``None`` defaults.
    """
    log.info(f"💽 Loading model from {path} using {method}")

    if method != "mlem":
        with open(path, "rb") as model_file:
            loaded = pickle.load(model_file)
    else:
        loaded = mlem.load(path)

    metadata = Graph()
    metadata.parse(f"{path}.ttl", format="ttl")
    # TODO: extract scores and hyper_params from RDF?

    return LoadedModel(path=path, model=loaded, metadata=metadata)