Coverage for trnbl\loggers\local\locallogger.py: 93%

123 statements  

« prev     ^ index     » next       coverage.py v7.6.10, created at 2025-01-17 02:23 -0700

1import datetime 

2import hashlib 

3import json 

4from typing import Any 

5from pathlib import Path 

6import io 

7import inspect 

8 

9import yaml # type: ignore[import-untyped] 

10from trnbl.loggers.base import TrainingLoggerBase, rand_syllabic_string 

11 

12 

class FilePaths:
    """Relative file and directory names used by the local logger.

    Most entries are joined onto a run directory; ``RUNS_MANIFEST`` and
    ``RUNS_DIR`` are relative to the project directory instead, as are the
    frontend files.
    """

    # configs and metadata
    TRAIN_CONFIG: Path = Path("config.json")
    LOGGER_META: Path = Path("meta.json")
    # configs and metadata in yaml format for easier human readability
    TRAIN_CONFIG_YML: Path = Path("config.yml")
    LOGGER_META_YML: Path = Path("meta.yml")

    # logs, metrics, and artifacts
    ARTIFACTS: Path = Path("artifacts.jsonl")
    METRICS: Path = Path("metrics.jsonl")
    LOG: Path = Path("log.jsonl")
    # keeps error message if an error occurs
    ERROR_FILE: Path = Path("ERROR.txt")

    # manifest is shared between all runs in a project
    # relative to project path instead of run path
    RUNS_MANIFEST: Path = Path("runs.jsonl")
    # directory in project path for runs
    # relative to project path instead of run path
    RUNS_DIR: Path = Path("runs")

    # frontend files (written into the project directory)
    HTML_INDEX: Path = Path("index.html")
    START_SERVER: Path = Path("start_server.py")

38 

39 

class LocalLogger(TrainingLoggerBase):
    """Training logger that records a run to files on the local filesystem.

    Each run gets its own directory ``<base_path>/<project>/runs/<run_id>``
    holding JSON-lines files for log messages, metrics, and artifacts, plus
    JSON and YAML copies of the logger metadata and the training config.
    A project-level manifest and the frontend files (``index.html``,
    ``start_server.py``) are written into the project directory.

    Every record is also kept in memory in ``log_list``, ``metrics_list``,
    and ``artifacts_list``.
    """

    def __init__(
        self,
        project: str,
        metric_names: list[str],
        train_config: dict,
        group: str = "",
        base_path: str | Path = Path("trnbl-logs"),
        memusage_as_metrics: bool = True,
        console_msg_prefix: str = "# ",
    ):
        """Create the run directory, open the log files, and write init files.

        # Parameters:
         - `project : str`
            project name; used as the directory name under `base_path`
         - `metric_names : list[str]`
            names of the metrics that will be logged (copied, never mutated)
         - `train_config : dict`
            training configuration; written to disk and hashed into the run id
         - `group : str`
            optional group name, prepended to the run id followed by `-`
         - `base_path : str | Path`
            root directory for all projects
         - `memusage_as_metrics : bool`
            if `True`, the keys of `self.get_mem_usage()` are appended to
            the metric names
         - `console_msg_prefix : str`
            prefix for messages printed to stdout by `message()`

        # Raises:
         - `AssertionError` : if the run directory already exists
        """
        # in-memory copies of everything written to the jsonl files
        self.log_list: list[dict] = []
        self.metrics_list: list[dict] = []
        self.artifacts_list: list[dict] = []

        # copy kwargs
        self.train_config: dict = train_config
        self.project: str = project
        self.group: str = group
        # append a "-" separator unless the group is empty or already ends in one
        self.group_str: str = self.group + ("-" if group and group[-1] != "-" else "")
        self.base_path: Path = Path(base_path)
        self.console_msg_prefix: str = console_msg_prefix

        # set up id
        self._syllabic_id: str = rand_syllabic_string()
        self.run_init_timestamp: datetime.datetime = datetime.datetime.now()
        self.run_id: str = self._get_run_id()

        # set up paths
        self.project_path: Path = self.base_path / project
        self._run_path: Path = self.project_path / FilePaths.RUNS_DIR / self.run_id
        # make sure the run path doesn't already exist
        # NOTE: this check is skipped when python runs with -O
        assert not self._run_path.exists(), (
            f"run path already exists: {self._run_path.as_posix()}"
        )
        self._run_path.mkdir(parents=True, exist_ok=True)

        # set up files and objects for logs, artifacts, and metrics
        # ----------------------------------------

        self.log_file: io.TextIOWrapper = open(self.run_path / FilePaths.LOG, "a")

        self.metrics_file: io.TextIOWrapper = open(
            self.run_path / FilePaths.METRICS, "a"
        )

        self.artifacts_file: io.TextIOWrapper = open(
            self.run_path / FilePaths.ARTIFACTS, "a"
        )

        # metric names (getting mem usage might cause problems if we have an error)
        # copy the list so that extending it below does not modify the
        # list object owned by the caller
        self.metric_names: list[str] = list(metric_names)
        if memusage_as_metrics:
            self.metric_names.extend(self.get_mem_usage().keys())

        # put everything in a config
        self.logger_meta: dict = dict(
            run_id=self.run_id,
            run_path=self.run_path.as_posix(),
            syllabic_id=self.syllabic_id,
            group=self.group,
            project=self.project,
            run_init_timestamp=self.run_init_timestamp.isoformat(),
            # includes the memory-usage metric names when enabled
            metric_names=self.metric_names,
            train_config=train_config,  # TODO: this duplicates the contents of FilePaths.TRAIN_CONFIG, is that ok?
        )

        # write to the project jsonl
        with open(self.project_path / FilePaths.RUNS_MANIFEST, "a") as f:
            json.dump(self.logger_meta, f)
            f.write("\n")

        # write the index.html and start_server.py files
        # ----------------------------------------
        from trnbl.loggers.local.html_frontend import get_html_frontend

        with open(self.project_path / FilePaths.HTML_INDEX, "w") as f:
            f.write(get_html_frontend())

        import trnbl.loggers.local.start_server as start_server_module

        # the server script is copied verbatim from the module's source
        with open(self.project_path / FilePaths.START_SERVER, "w") as f:
            f.write(inspect.getsource(start_server_module))

        # write init files
        # ----------------------------------------

        # logger metadata
        with open(self.run_path / FilePaths.LOGGER_META, "w") as f:
            json.dump(self.logger_meta, f, indent="\t")

        with open(self.run_path / FilePaths.LOGGER_META_YML, "w") as f:
            yaml.dump(self.logger_meta, f)

        # training/model/dataset config
        with open(self.run_path / FilePaths.TRAIN_CONFIG, "w") as f:
            json.dump(train_config, f, indent="\t")

        with open(self.run_path / FilePaths.TRAIN_CONFIG_YML, "w") as f:
            yaml.dump(train_config, f)

        self.message(f"starting logger with id {self.run_id}")

    @property
    def _run_hash(self) -> str:
        """MD5 hex digest of the string representation of the training config"""
        return hashlib.md5(str(self.train_config).encode()).hexdigest()

    @property
    def syllabic_id(self) -> str:
        """Random syllabic string generated for this run at init"""
        return self._syllabic_id

    def _get_run_id(self) -> str:
        """Build the run id: `{group-}h{hash[:5]}-{yymmdd_HHMM}-{syllabic_id}`"""
        timestamp: str = self.run_init_timestamp.strftime("%y%m%d_%H%M")
        return f"{self.group_str}h{self._run_hash[:5]}-{timestamp}-{self.syllabic_id}"

    def get_timestamp(self) -> str:
        """Current local time as an ISO 8601 string"""
        return datetime.datetime.now().isoformat()

    def _log(self, message: str, **kwargs) -> None:
        """(internal) log a progress message

        The record is appended to `log_list` and written to the log file as
        one JSON line, then the file is flushed. Extra `kwargs` are merged
        into the record and override `timestamp` if they contain that key.
        """
        # TODO: also log messages via regular logger to stdout
        msg_dict: dict = {
            "message": message,
            "timestamp": self.get_timestamp(),
            **kwargs,
        }

        self.log_list.append(msg_dict)
        self.log_file.write(json.dumps(msg_dict) + "\n")
        self.log_file.flush()

    def debug(self, message: str, **kwargs) -> None:
        """log a debug message (written to the log file only, not printed)"""
        self._log(message, __dbg__=True, **kwargs)

    def message(self, message: str, **kwargs) -> None:
        """log a progress message and print it with the console prefix"""
        # TODO: also log messages via regular logger to stdout
        self._log(message, **kwargs)
        print(self.console_msg_prefix + message)

    def warning(self, message: str, **kwargs) -> None:
        """log a warning message"""
        self.message(
            f"WARNING: {message}",
            __warning__=True,
            **kwargs,
        )

    def error(self, message: str, **kwargs) -> None:
        """log an error message and append it to the run's error file"""
        self.message(
            f"ERROR: {message}",
            __error__=True,
            **kwargs,
        )
        # the file is closed (and therefore flushed) when the block exits
        with open(self.run_path / FilePaths.ERROR_FILE, "a") as f:
            f.write("=" * 80 + "\n")
            f.write("exception at " + self.get_timestamp() + "\n")
            f.write(message)
            f.write("\n")

    def metrics(self, data: dict[str, Any]) -> None:
        """log a dictionary of metrics

        A `timestamp` entry is added to the stored record. The caller's
        dictionary is left untouched; a shallow copy is stored and written.
        """
        record: dict[str, Any] = {**data, "timestamp": self.get_timestamp()}

        self.metrics_list.append(record)
        self.metrics_file.write(json.dumps(record) + "\n")

    def artifact(
        self,
        path: Path,
        type: str,
        aliases: list[str] | None = None,
        metadata: dict | None = None,
    ) -> None:
        """log an artifact from a file

        Only a record describing the artifact (timestamp, posix path, type,
        aliases, metadata) is stored and written; the file at `path` is not
        read or copied here.
        """
        artifact_dict: dict = dict(
            timestamp=self.get_timestamp(),
            path=path.as_posix(),
            type=type,
            aliases=aliases,
            metadata=metadata if metadata else {},
        )

        self.artifacts_list.append(artifact_dict)
        self.artifacts_file.write(json.dumps(artifact_dict) + "\n")

    @property
    def url(self) -> str:
        """Get the URL for the current logging run (the run path, posix style)"""
        return self.run_path.as_posix()

    @property
    def run_path(self) -> Path:
        """Get the path to the current logging run"""
        return self._run_path

    def flush(self) -> None:
        """flush the log, metrics, and artifacts files"""
        self.log_file.flush()
        self.metrics_file.flush()
        self.artifacts_file.flush()

    def finish(self) -> None:
        """log a closing message, then flush and close all open files"""
        self.message("closing logger")

        for handle in (self.log_file, self.metrics_file, self.artifacts_file):
            handle.flush()
            handle.close()