Coverage for src/trapi_predict_kit/save.py: 94%
39 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-25 21:14 +0200
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-25 21:14 +0200
1import os.path
2import pickle
3from dataclasses import dataclass
4from typing import Any, Optional
6from mlem import api as mlem
7from rdflib import Graph
9# from mlem.api import save as mlem_save, load as mlem_load
10from trapi_predict_kit.utils import get_run_metadata, log
@dataclass
class LoadedModel:
    """A model object bundled with its RDF run metadata.

    Instances are returned by both ``save`` and ``load`` in this module.
    """

    # File path the model was written to (by save) or read from (by load).
    path: str
    # The deserialized / in-memory model object itself.
    model: Any
    # rdflib graph of run metadata; persisted alongside the model as "<path>.ttl".
    metadata: Graph
    # Hyper-parameters recorded at save time; load() in this module leaves this None.
    hyper_params: Optional[Any] = None
    # Evaluation scores recorded at save time; load() in this module leaves this None.
    scores: Optional[Any] = None
def save(
    model: Any,
    path: str,
    sample_data: Any,
    method: str = "pickle",
    scores: Optional[Any] = None,
    hyper_params: Optional[Any] = None,
) -> LoadedModel:
    """Persist ``model`` to ``path`` and write its RDF run metadata next to it.

    The model is written with MLEM when ``method == "mlem"``; any other value
    falls back to :mod:`pickle`. Metadata produced by ``get_run_metadata`` is
    serialized in Turtle format to ``f"{path}.ttl"``.

    Args:
        model: The model object to persist.
        path: Destination path for the model file. Missing parent
            directories are created (best effort; failures are logged).
        sample_data: Example input data, forwarded to ``mlem.save`` and to
            ``get_run_metadata``.
        method: ``"mlem"`` to save with MLEM; anything else uses pickle.
        scores: Optional evaluation scores, recorded in the metadata.
        hyper_params: Optional hyper-parameters, recorded in the metadata.

    Returns:
        A ``LoadedModel`` holding ``path``, ``model``, the metadata graph,
        ``hyper_params`` and ``scores``.
    """
    # os.path.basename honours the platform's separators, unlike splitting on "/".
    model_name = os.path.basename(path)

    # dirname() is "" for a bare filename; calling makedirs("") would always
    # raise, so only create a directory when there actually is one.
    # exist_ok=True avoids the race between an exists() check and makedirs().
    parent_dir = os.path.dirname(path)
    if parent_dir:
        try:
            os.makedirs(parent_dir, exist_ok=True)
        except OSError as e:
            # Best effort: the subsequent write will surface a real failure.
            log.warning(f"Error creating directory: {e}")

    log.info(f"💾 Saving the model in {path} using {method}")
    if method == "mlem":
        mlem.save(model, path, sample_data=sample_data)
    else:
        with open(path, "wb") as f:
            pickle.dump(model, f)

    g = get_run_metadata(scores, sample_data, hyper_params, model_name)
    g.serialize(f"{path}.ttl", format="ttl")

    return LoadedModel(
        path=path,
        model=model,
        metadata=g,
        hyper_params=hyper_params,
        scores=scores,
    )
def load(path: str, method: str = "pickle") -> LoadedModel:
    """Read back a model written by ``save`` together with its Turtle metadata.

    ``method == "mlem"`` loads through MLEM; every other value unpickles the
    file at ``path``. The metadata graph is parsed from ``f"{path}.ttl"``.

    Warning:
        The pickle branch can execute arbitrary code embedded in the file.
        Only load model files that come from a trusted source.

    Args:
        path: Location of the saved model file.
        method: ``"mlem"`` or anything else (treated as pickle).

    Returns:
        A ``LoadedModel`` with ``path``, the model and its metadata graph.
        ``hyper_params`` and ``scores`` are not restored and stay ``None``.
    """
    log.info(f"💽 Loading model from {path} using {method}")
    if method != "mlem":
        # NOTE(security): pickle.load is unsafe on untrusted input.
        with open(path, "rb") as handle:
            loaded = pickle.load(handle)
    else:
        loaded = mlem.load(path)

    metadata = Graph()
    metadata.parse(f"{path}.ttl", format="ttl")
    # TODO: extract scores and hyper_params from RDF?

    return LoadedModel(path=path, model=loaded, metadata=metadata)