Coverage for trnbl\loggers\local\locallogger.py: 93%
123 statements
coverage.py v7.6.10, created at 2025-01-17 02:23 -0700
import datetime
import hashlib
import json
from typing import Any
from pathlib import Path
import io
import inspect

import yaml  # type: ignore[import-untyped]

from trnbl.loggers.base import TrainingLoggerBase, rand_syllabic_string

class FilePaths:
    # configs and metadata
    TRAIN_CONFIG: Path = Path("config.json")
    LOGGER_META: Path = Path("meta.json")
    # configs and metadata in yaml format for easier human readability
    TRAIN_CONFIG_YML: Path = Path("config.yml")
    LOGGER_META_YML: Path = Path("meta.yml")

    # logs, metrics, and artifacts
    ARTIFACTS: Path = Path("artifacts.jsonl")
    METRICS: Path = Path("metrics.jsonl")
    LOG: Path = Path("log.jsonl")
    # keeps error message if an error occurs
    ERROR_FILE: Path = Path("ERROR.txt")

    # manifest is shared between all runs in a project
    # relative to project path instead of run path
    RUNS_MANIFEST: Path = Path("runs.jsonl")
    # directory in project path for runs
    # relative to project path instead of run path
    RUNS_DIR: Path = Path("runs")

    # frontend files
    HTML_INDEX: Path = Path("index.html")
    START_SERVER: Path = Path("start_server.py")
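
# Illustrative sketch (not part of the original module): with the default
# base_path of "trnbl-logs", a hypothetical project "my-project" would end up
# laid out roughly like this on disk (run id shown as a placeholder):
#
#   trnbl-logs/
#   └── my-project/
#       ├── runs.jsonl          # RUNS_MANIFEST, one line per run, shared by all runs
#       ├── index.html          # HTML_INDEX frontend
#       ├── start_server.py     # START_SERVER, copy of the start_server module source
#       └── runs/
#           └── <run_id>/
#               ├── config.json / config.yml    # TRAIN_CONFIG(_YML)
#               ├── meta.json / meta.yml        # LOGGER_META(_YML)
#               ├── log.jsonl, metrics.jsonl, artifacts.jsonl
#               └── ERROR.txt                   # only written if error() is called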

class LocalLogger(TrainingLoggerBase):
    def __init__(
        self,
        project: str,
        metric_names: list[str],
        train_config: dict,
        group: str = "",
        base_path: str | Path = Path("trnbl-logs"),
        memusage_as_metrics: bool = True,
        console_msg_prefix: str = "# ",
    ):
        # set up lists
        self.log_list: list[dict] = list()
        self.metrics_list: list[dict] = list()
        self.artifacts_list: list[dict] = list()

        # copy kwargs
        self.train_config: dict = train_config
        self.project: str = project
        self.group: str = group
        self.group_str: str = self.group + ("-" if group and group[-1] != "-" else "")
        self.base_path: Path = Path(base_path)
        self.console_msg_prefix: str = console_msg_prefix

        # set up id
        self._syllabic_id: str = rand_syllabic_string()
        self.run_init_timestamp: datetime.datetime = datetime.datetime.now()
        self.run_id: str = self._get_run_id()

        # set up paths
        self.project_path: Path = self.base_path / project
        self._run_path: Path = self.project_path / FilePaths.RUNS_DIR / self.run_id
        # make sure the run path doesn't already exist
        assert not self._run_path.exists()
        self._run_path.mkdir(parents=True, exist_ok=True)

        # set up files and objects for logs, artifacts, and metrics
        # ----------------------------------------
        self.log_file: io.TextIOWrapper = open(self.run_path / FilePaths.LOG, "a")

        self.metrics_file: io.TextIOWrapper = open(
            self.run_path / FilePaths.METRICS, "a"
        )

        self.artifacts_file: io.TextIOWrapper = open(
            self.run_path / FilePaths.ARTIFACTS, "a"
        )

        # metric names (getting mem usage might cause problems if we have an error)
        self.metric_names: list[str] = metric_names
        if memusage_as_metrics:
            self.metric_names += list(self.get_mem_usage().keys())

        # put everything in a config
        self.logger_meta: dict = dict(
            run_id=self.run_id,
            run_path=self.run_path.as_posix(),
            syllabic_id=self.syllabic_id,
            group=self.group,
            project=self.project,
            run_init_timestamp=str(self.run_init_timestamp.isoformat()),
            metric_names=metric_names,
            train_config=train_config,  # TODO: this duplicates the contents of FilePaths.TRAIN_CONFIG, is that ok?
        )

        # write to the project jsonl
        with open(self.project_path / FilePaths.RUNS_MANIFEST, "a") as f:
            json.dump(self.logger_meta, f)
            f.write("\n")

        # write the index.html and start_server.py files
        # ----------------------------------------
        from trnbl.loggers.local.html_frontend import get_html_frontend

        with open(self.project_path / FilePaths.HTML_INDEX, "w") as f:
            f.write(get_html_frontend())

        import trnbl.loggers.local.start_server as start_server_module

        with open(self.project_path / FilePaths.START_SERVER, "w") as f:
            f.write(inspect.getsource(start_server_module))

        # write init files
        # ----------------------------------------

        # logger metadata
        with open(self.run_path / FilePaths.LOGGER_META, "w") as f:
            json.dump(self.logger_meta, f, indent="\t")

        with open(self.run_path / FilePaths.LOGGER_META_YML, "w") as f:
            yaml.dump(self.logger_meta, f)

        # training/model/dataset config
        with open(self.run_path / FilePaths.TRAIN_CONFIG, "w") as f:
            json.dump(train_config, f, indent="\t")

        with open(self.run_path / FilePaths.TRAIN_CONFIG_YML, "w") as f:
            yaml.dump(train_config, f)

        self.message(f"starting logger with id {self.run_id}")

    @property
    def _run_hash(self) -> str:
        return hashlib.md5(str(self.train_config).encode()).hexdigest()

    @property
    def syllabic_id(self) -> str:
        return self._syllabic_id

    def _get_run_id(self) -> str:
        return f"{self.group_str}h{self._run_hash[:5]}-{self.run_init_timestamp.strftime('%y%m%d_%H%M')}-{self.syllabic_id}"
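
    # illustrative note (hypothetical values): with group="baseline", a config hash
    # starting with "1a2b3", and an init time of 2025-01-17 02:23, the run id would
    # look like "baseline-h1a2b3-250117_0223-<syllabic_id>"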

    def get_timestamp(self) -> str:
        return datetime.datetime.now().isoformat()

    def _log(self, message: str, **kwargs) -> None:
        """(internal) log a progress message"""
        # TODO: also log messages via regular logger to stdout
        msg_dict: dict = dict(
            message=message,
            timestamp=self.get_timestamp(),
        )
        if kwargs:
            msg_dict.update(kwargs)

        self.log_list.append(msg_dict)
        self.log_file.write(json.dumps(msg_dict) + "\n")
        self.log_file.flush()
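
    # note: every entry in log.jsonl is a single JSON object per line, containing
    # "message", "timestamp", and any extra keyword arguments passed to _log();
    # the file is flushed on every write, so progress messages survive a crash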

    def debug(self, message: str, **kwargs) -> None:
        """log a debug message"""
        self._log(message, __dbg__=True, **kwargs)

    def message(self, message: str, **kwargs) -> None:
        """log a progress message"""
        # TODO: also log messages via regular logger to stdout
        self._log(message, **kwargs)
        print(self.console_msg_prefix + message)

    def warning(self, message: str, **kwargs) -> None:
        """log a warning message"""
        self.message(
            f"WARNING: {message}",
            __warning__=True,
            **kwargs,
        )

    def error(self, message: str, **kwargs) -> None:
        """log an error message"""
        self.message(
            f"ERROR: {message}",
            __error__=True,
            **kwargs,
        )
        with open(self.run_path / FilePaths.ERROR_FILE, "a") as f:
            f.write("=" * 80 + "\n")
            f.write("exception at " + self.get_timestamp() + "\n")
            f.write(message)
            f.write("\n")
            f.flush()

    def metrics(self, data: dict[str, Any]) -> None:
        """log a dictionary of metrics"""
        data["timestamp"] = self.get_timestamp()

        self.metrics_list.append(data)
        self.metrics_file.write(json.dumps(data) + "\n")
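
    # note: metrics() does not flush, so metric lines are only guaranteed to be
    # on disk after flush() or finish() is called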

    def artifact(
        self,
        path: Path,
        type: str,
        aliases: list[str] | None = None,
        metadata: dict | None = None,
    ) -> None:
        """log an artifact from a file"""
        artifact_dict: dict = dict(
            timestamp=self.get_timestamp(),
            path=path.as_posix(),
            type=type,
            aliases=aliases,
            metadata=metadata if metadata else {},
        )

        self.artifacts_list.append(artifact_dict)
        self.artifacts_file.write(json.dumps(artifact_dict) + "\n")
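
    # note: artifact() only records the path, type, aliases, and metadata of the
    # file; the artifact itself is not copied into the run directory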

    @property
    def url(self) -> str:
        """Get the URL for the current logging run"""
        return self.run_path.as_posix()

    @property
    def run_path(self) -> Path:
        """Get the path to the current logging run"""
        return self._run_path

    def flush(self) -> None:
        self.log_file.flush()
        self.metrics_file.flush()
        self.artifacts_file.flush()

    def finish(self) -> None:
        self.message("closing logger")

        self.log_file.flush()
        self.log_file.close()

        self.metrics_file.flush()
        self.metrics_file.close()

        self.artifacts_file.flush()
        self.artifacts_file.close()
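
A minimal usage sketch, not part of the covered module: the project name, group, metric names, and config values below are hypothetical, and it assumes the get_mem_usage() provided by TrainingLoggerBase (used in __init__ above) works in the current environment.

if __name__ == "__main__":
    # hypothetical values throughout; this only exercises the public API shown above
    logger = LocalLogger(
        project="demo-project",
        metric_names=["train/loss", "train/acc"],
        train_config={"model": "mlp", "lr": 1e-3, "epochs": 2},
        group="baseline",
    )

    for step in range(10):
        # in real training these values would come from the model and optimizer
        logger.metrics({"train/loss": 1.0 / (step + 1), "train/acc": step / 10})

    # only the path is recorded; "checkpoint.pt" is a hypothetical artifact
    logger.artifact(Path("checkpoint.pt"), type="model", aliases=["latest"])

    logger.finish()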