Coverage for /home/deng/Projects/metatree_drawer/treeprofiler_algo/pastml/pastml/models/__init__.py: 28%

236 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2024-03-21 09:19 +0100

1import logging 

2import os 

3 

4import numpy as np 

5import pandas as pd 

6 

7from pastml import NUM_NODES, NUM_TIPS 

8 

9 

# Keys used in parameter files / dicts
# (read by Model.parse_parameters, written by Model.save_parameters).
MODEL = 'model'
CHANGES_PER_AVG_BRANCH = 'state_changes_per_avg_branch'
SCALING_FACTOR = 'scaling_factor'
SMOOTHING_FACTOR = 'smoothing_factor'
FREQUENCIES = 'frequencies'

15 

16 

class Model(object):
    """
    Base class for evolutionary character models.

    Stores the (sorted) character states, forest statistics, and the two basic
    parameters shared by all models: the branch scaling factor (sf) and the
    branch-length smoothing factor (tau).
    """

    def __init__(self, states, forest_stats, sf=None, tau=0, optimise_tau=False,
                 parameter_file=None, reoptimise=False, **kwargs):
        """
        :param states: iterable of character states (stored sorted)
        :param forest_stats: forest statistics object; this class reads its
            avg_nonzero_brlen, forest_length, num_nodes and num_tips attributes
        :param sf: scaling factor; if None and not given in the parameter file,
            defaults to 1 / average non-zero branch length
        :param tau: smoothing factor added to each branch length (default 0)
        :param optimise_tau: whether tau should be optimised
        :param parameter_file: optional parameter dict or path to a tab-delimited
            parameter file (see parse_parameters)
        :param reoptimise: whether input parameter values are starting values (True)
            or fixed (False)
        """
        self._name = None
        self._states = np.sort(states)
        self._forest_stats = forest_stats
        self._optimise_tau = optimise_tau
        # The scaling factor is optimised unless a fixed value comes from the parameter file.
        self._optimise_sf = True
        self._sf = None
        self._tau = None
        # May set self._sf / self._tau (and the corresponding optimisation flags)
        # from the input parameters.
        self.parse_parameters(parameter_file, reoptimise)
        if self._sf is None:
            self._sf = sf if sf is not None else 1. / forest_stats.avg_nonzero_brlen
        if self._tau is None:
            self._tau = tau if tau else 0

        self.calc_tau_factor()

        self._extra_params_fixed = False

    def calc_tau_factor(self):
        """
        (Re)calculates the branch renormalization factor, chosen so that after adding
        tau to every branch the total forest length stays unchanged.

        :return: void, updates self._tau_factor
        """
        self._tau_factor = \
            self._forest_stats.forest_length / (self._forest_stats.forest_length
                                                + self._tau * (self._forest_stats.num_nodes - 1)) if self._tau else 1

    def __str__(self):
        return \
            'Model {} with parameter values:\n' \
            '{}'.format(self.name,
                        self._print_parameters())

    def _print_parameters(self):
        """
        Constructs a string representing parameter values (to be used for logging).

        :return: str representing parameter values
        """
        return self._print_basic_parameters()

    def _print_basic_parameters(self):
        # Reports the scaling factor (also expressed as expected state changes
        # per average branch) and the smoothing factor.
        return '\tscaling factor:\t{:.6f}, i.e. {:.6f} changes per avg branch\t{}\n' \
               '\tsmoothing factor:\t{:.6f}\t{}\n'\
            .format(self.sf, self.forest_stats.avg_nonzero_brlen * self.sf,
                    '(optimised)' if self._optimise_sf else '(fixed)',
                    self.tau, '(optimised)' if self._optimise_tau else '(fixed)')

    def save_parameters(self, filehandle):
        """
        Writes this model parameter values to the parameter file (in the same format as the input parameter file).

        :param filehandle: filehandle for the file where the parameter values should be written.
        :return: void
        """
        filehandle.write('{}\t{}\n'.format(MODEL, self.name))
        filehandle.write('{}\t{}\n'.format(NUM_NODES, self.forest_stats.num_nodes))
        filehandle.write('{}\t{}\n'.format(NUM_TIPS, self.forest_stats.num_tips))
        filehandle.write('{}\t{}\n'.format(SCALING_FACTOR, self.sf))
        filehandle.write('{}\t{}\n'.format(CHANGES_PER_AVG_BRANCH, self.sf * self.forest_stats.avg_nonzero_brlen))
        filehandle.write('{}\t{}\n'.format(SMOOTHING_FACTOR, self.tau))

    def extra_params_fixed(self):
        # Whether the model-specific (non-basic) parameters are currently frozen.
        return self._extra_params_fixed

    def fix_extra_params(self):
        self._extra_params_fixed = True

    def unfix_extra_params(self):
        self._extra_params_fixed = False

    @property
    def forest_stats(self):
        return self._forest_stats

    @forest_stats.setter
    def forest_stats(self, forest_stats):
        # A new forest changes the total length, hence the tau factor must be redone.
        self._forest_stats = forest_stats
        self.calc_tau_factor()

    @property
    def name(self):
        return self._name

    @name.setter
    def name(self, name):
        self._name = name

    @property
    def states(self):
        return self._states

    @states.setter
    def states(self, states):
        self._states = states

    @property
    def sf(self):
        return self._sf

    @sf.setter
    def sf(self, sf):
        if self._optimise_sf:
            self._sf = sf
        else:
            raise NotImplementedError('The scaling factor is preset and cannot be changed.')

    @property
    def tau(self):
        return self._tau

    @tau.setter
    def tau(self, tau):
        if self._optimise_tau:
            self._tau = tau
            # Tau affects the branch renormalization factor.
            self.calc_tau_factor()
        else:
            raise NotImplementedError('Tau is preset and cannot be changed.')

    def get_Pij_t(self, t, *args, **kwargs):
        """
        Returns a function for calculation of probability matrix of substitutions i->j over time t.

        :return: probability matrix
        :rtype: function
        """
        raise NotImplementedError("Please implement this method in the Model subclass")

    def set_params_from_optimised(self, ps, **kwargs):
        """
        Update this model parameter values from a vector representing parameters
        for the likelihood optimization algorithm.

        :param ps: np.array containing parameters of the likelihood optimization algorithm
        :param kwargs: dict of eventual other arguments
        :return: void, update this model
        """
        # The vector layout is [sf?, tau?] depending on which flags are on.
        if self._optimise_sf:
            self.sf = ps[0]
        if self._optimise_tau:
            self.tau = ps[1 if self._optimise_sf else 0]

    def get_optimised_parameters(self):
        """
        Converts this model parameters to a vector representing parameters
        for the likelihood optimization algorithm.

        :return: np.array containing parameters of the likelihood optimization algorithm
        """
        return np.hstack(([self.sf] if self._optimise_sf else [],
                          [self.tau] if self._optimise_tau else []))

    def get_bounds(self):
        """
        Get bounds for parameters for likelihood optimization algorithm.

        :return: np.array containing lower and upper (potentially infinite) bounds for each parameter
        """
        bounds = []
        if self._optimise_sf:
            # sf bounds correspond to 0.001..10 expected changes per average branch.
            bounds += [np.array([0.001 / self.forest_stats.avg_nonzero_brlen,
                                 10. / self.forest_stats.avg_nonzero_brlen])]
        if self._optimise_tau:
            bounds += [np.array([0, self.forest_stats.avg_nonzero_brlen])]
        return np.array(bounds, np.float64)

    def get_num_params(self):
        """
        Returns the number of optimized parameters for this model.

        :return: the number of optimized parameters
        """
        return (1 if self._optimise_sf else 0) + (1 if self._optimise_tau else 0)

    def parse_parameters(self, params, reoptimise=False):
        """
        Update this model's values from the input parameters.
        The input might contain:
        (1) the scaling factor, by which each branch length will be multiplied (optional).
        The key for this parameter is pastml.models.SCALING_FACTOR;
        (2) the smoothing factor, which will be added to each branch length
        before the branches are renormalized to keep the initial tree length (optional).
        The key for this parameter is pastml.models.SMOOTHING_FACTOR;

        :param params: dict {key->value}
            or a path to the file containing a tab-delimited table with the first column containing keys
            and the second (named 'value') containing values.
        :param reoptimise: whether these model parameters should be treated as starting values (True)
            or as fixed values (False)
        :return: dict with parameter values
        """
        logger = logging.getLogger('pastml')
        if params is None:
            return {}
        if not isinstance(params, str) and not isinstance(params, dict):
            raise ValueError('Parameters must be specified either as a dict or as a path to a csv file, not as {}!'
                             .format(type(params)))
        if isinstance(params, str):
            if not os.path.exists(params):
                raise ValueError('The specified parameter file ({}) does not exist.'.format(params))
            try:
                param_df = pd.read_csv(params, header=0, index_col=0, sep='\t')
            except Exception:
                raise ValueError('The specified parameter file {} is malformed, '
                                 'should be a tab-delimited file with two columns, '
                                 'the first one containing parameter names, '
                                 'and the second, named "value", containing parameter values.'.format(params))
            # Checked outside the try above, so that this more precise message
            # is not masked by the generic "malformed file" one
            # (the original also lacked the .format(params) call, printing a literal '{}').
            if 'value' not in param_df.columns:
                raise ValueError('Could not find the "value" column in the parameter file {}. '
                                 'It should be a tab-delimited file with two columns, '
                                 'the first one containing parameter names, '
                                 'and the second, named "value", containing parameter values.'.format(params))
            params = param_df.to_dict()['value']
        params = {str(k.encode('ASCII', 'replace').decode()): v for (k, v) in params.items()}
        if SCALING_FACTOR in params:
            sf = params[SCALING_FACTOR]
            try:
                self._sf = np.float64(sf)
                if self._sf <= 0:
                    # Fixed: log the offending input value (the original formatted
                    # an unset local variable, always printing 'None').
                    logger.error('Scaling factor cannot be negative, ignoring the value given in parameters ({}).'
                                 .format(sf))
                    self._sf = None
                else:
                    # A valid fixed value turns off sf optimisation unless reoptimising.
                    self._optimise_sf = reoptimise
            except Exception:
                logger.error('Scaling factor ({}) given in parameters is not float, ignoring it.'.format(sf))
                self._sf = None
        if SMOOTHING_FACTOR in params:
            tau = params[SMOOTHING_FACTOR]
            try:
                self._tau = np.float64(tau)
                if self._tau < 0:
                    # Fixed: log the offending input value (see note above).
                    logger.error(
                        'Smoothing factor cannot be negative, ignoring the value given in parameters ({}).'.format(tau))
                    self._tau = None
            except Exception:
                logger.error('Smoothing factor ({}) given in parameters is not float, ignoring it.'.format(tau))
                self._tau = None
        return params

    def freeze(self):
        """
        Prohibit parameter optimization by setting all optimization flags to False.

        :return: void
        """
        self._optimise_sf = False
        self._optimise_tau = False

    def transform_t(self, t):
        # Smooth (add tau), renormalize (tau factor) and scale (sf) a branch length.
        return (t + self.tau) * self._tau_factor * self.sf

    def basic_params_fixed(self):
        # Whether neither sf nor tau is being optimised.
        return not self._optimise_tau and not self._optimise_sf

273 

274 

class ModelWithFrequencies(Model):
    """
    A Model that additionally carries equilibrium state frequencies,
    which may be fixed, freely optimised, or smoothed (a single smoothing
    pseudo-count parameter instead of free per-state frequencies).
    """

    def __init__(self, states, forest_stats, sf=None, frequencies=None, tau=0,
                 optimise_tau=False, frequency_smoothing=False, parameter_file=None, reoptimise=False, **kwargs):
        """
        :param frequencies: state frequencies; if None and not given in the parameter file,
            defaults to equal frequencies
        :param frequency_smoothing: whether frequencies should be smoothed
            instead of freely optimised
        (other parameters as in Model.__init__)
        """
        self._frequencies = None
        self._optimise_frequencies = not frequency_smoothing
        self._frequency_smoothing = frequency_smoothing
        # Model.__init__ calls parse_parameters, which may set self._frequencies.
        Model.__init__(self, states, forest_stats=forest_stats,
                       sf=sf, tau=tau, optimise_tau=optimise_tau, reoptimise=reoptimise,
                       parameter_file=parameter_file, **kwargs)
        if self._frequencies is None:
            self._frequencies = frequencies if frequencies is not None \
                else np.ones(len(states), dtype=np.float64) / len(states)

    @property
    def frequencies(self):
        return self._frequencies

    @frequencies.setter
    def frequencies(self, frequencies):
        if self._optimise_frequencies or self._frequency_smoothing:
            self._frequencies = frequencies
        else:
            raise NotImplementedError('The frequencies are preset and cannot be changed.')

    def get_num_params(self):
        """
        Returns the number of optimized parameters for this model.

        :return: the number of optimized parameters
        """
        # n-1 free frequencies (they sum to one), or 1 smoothing parameter, or 0 if fixed.
        return Model.get_num_params(self) \
            + ((len(self.frequencies) - 1)
               if self._optimise_frequencies else (1 if self._frequency_smoothing else 0))

    def set_params_from_optimised(self, ps, **kwargs):
        """
        Update this model parameter values from a vector representing parameters
        for the likelihood optimization algorithm.

        :param ps: np.array containing parameters of the likelihood optimization algorithm
        :param kwargs: dict of eventual other arguments
        :return: void, update this model
        """
        Model.set_params_from_optimised(self, ps, **kwargs)
        if not self.extra_params_fixed():
            n_freq = len(self.frequencies)
            # Frequency-related entries follow the basic (sf/tau) ones in the vector.
            n_params = Model.get_num_params(self)

            freqs = self.frequencies

            if self._optimise_frequencies:
                # The last frequency is the implicit reference, fixed at 1 before renormalization.
                freqs = np.hstack((ps[n_params: n_params + (n_freq - 1)], [1.]))
                freqs /= freqs.sum()
                self.frequencies = freqs
            elif self._frequency_smoothing:
                # Convert frequencies back to (pseudo-)counts, add the smoothing
                # pseudo-count, and renormalize.
                freqs = freqs * self.forest_stats.num_tips
                freqs += ps[n_params]
                freqs /= freqs.sum()
                self.frequencies = freqs

    def get_optimised_parameters(self):
        """
        Converts this model parameters to a vector representing parameters
        for the likelihood optimization algorithm.

        :return: np.array containing parameters of the likelihood optimization algorithm
        """
        if not self.extra_params_fixed():
            return np.hstack((Model.get_optimised_parameters(self),
                              self.frequencies[:-1] / self.frequencies[-1] if self._optimise_frequencies
                              else ([0] if self._frequency_smoothing else [])))
        return Model.get_optimised_parameters(self)

    def get_bounds(self):
        """
        Get bounds for parameters for likelihood optimization algorithm.

        :return: np.array containing lower and upper (potentially infinite) bounds for each parameter
        """
        if not self.extra_params_fixed():
            extras = []
            if self._optimise_frequencies:
                extras += [np.array([1e-6, 10e6], np.float64)] * (len(self.frequencies) - 1)
            if self._frequency_smoothing:
                extras.append(np.array([0, self.forest_stats.num_nodes]))
            return np.array((*Model.get_bounds(self), *extras))
        return Model.get_bounds(self)

    def parse_parameters(self, params, reoptimise=False):
        """
        Update this model's values from the input parameters.
        For a model with frequencies, apart from the basic parameters
        (scaling factor and smoothing factor, see pastml.models.Model),
        the input might contain the frequency values:
        the key for each frequency value is the name of the corresponding character state.

        :param params: dict {key->value}
        :param reoptimise: whether these model parameters should be treated as starting values (True)
            or as fixed values (False)
        :return: dict with parameter values (same as input)
        """
        params = Model.parse_parameters(self, params, reoptimise)
        logger = logging.getLogger('pastml')
        known_freq_states = set(self.states) & set(params.keys())
        if known_freq_states:
            unknown_freq_states = [state for state in self.states if state not in params.keys()]
            if unknown_freq_states and not reoptimise:
                logger.error('Frequencies for some of the states ({}) are missing, '
                             'ignoring the specified frequencies.'.format(', '.join(unknown_freq_states)))
            else:
                self._frequencies = np.array([params[state] if state in params.keys() else 0 for state in self.states])
                try:
                    self._frequencies = self._frequencies.astype(np.float64)
                    if np.round(self._frequencies.sum() - 1, 2) != 0 and not reoptimise:
                        logger.error('Frequencies given in parameters ({}) do not sum up to one ({}),'
                                     'ignoring them.'.format(self._frequencies, self._frequencies.sum()))
                        self._frequencies = None
                    else:
                        if np.any(self._frequencies < 0) and not reoptimise:
                            logger.error('Some of the frequencies given in parameters ({}) are negative,'
                                         'ignoring them.'.format(self._frequencies))
                            self._frequencies = None
                        else:
                            # Replace zero/tiny frequencies by half of the smallest
                            # positive one (capped at 1/num_tips), then renormalize.
                            min_freq = \
                                min(1 / self.forest_stats.num_tips,
                                    min(float(params[state]) for state in known_freq_states
                                        if float(params[state]) > 0)) / 2
                            if unknown_freq_states:
                                logger.error('Frequencies for some of the states ({}) are missing from parameters, '
                                             'setting them to {}.'.format(', '.join(unknown_freq_states), min_freq))
                            frequencies = np.maximum(self._frequencies, min_freq)
                            frequencies /= frequencies.sum()
                            # Fixed: the clamped/renormalized frequencies were previously
                            # computed into a local and never stored, silently discarding them.
                            self._frequencies = frequencies
                            self._optimise_frequencies = reoptimise and not self._frequency_smoothing
                except Exception:
                    logger.error('Could not convert the frequencies given in parameters ({}) to float, '
                                 'ignoring them.'.format(self._frequencies))
                    self._frequencies = None
        return params

    def _print_parameters(self):
        """
        Constructs a string representing parameter values (to be used for logging).

        :return: str representing parameter values
        """
        return '{}' \
               '\tfrequencies\t{}\n' \
               '{}\n'.format(Model._print_parameters(self),
                             '(optimised)' if self._optimise_frequencies
                             else '(smoothed)' if self._frequency_smoothing else '(fixed)',
                             '\n'.join('\t\t{}:\t{:g}'.format(state, freq)
                                       for (state, freq) in zip(self.states, self.frequencies)))

    def freeze(self):
        """
        Prohibit parameter optimization by setting all optimization flags to False.

        :return: void
        """
        Model.freeze(self)
        self._optimise_frequencies = False
        self._frequency_smoothing = False

    def extra_params_fixed(self):
        # Extra parameters count as fixed if explicitly frozen or if the frequency
        # settings add no optimised parameters beyond the basic ones.
        return self._extra_params_fixed or Model.get_num_params(self) == self.get_num_params()

    def basic_params_fixed(self):
        return not Model.get_num_params(self)

    def save_parameters(self, filehandle):
        """
        Writes this model parameter values to the parameter file (in the same format as the input parameter file).

        :param filehandle: filehandle for the file where the parameter values should be written.
        :return: void
        """
        Model.save_parameters(self, filehandle)
        # One line per state: <state>\t<frequency>.
        for state, frequency in zip(self.states, self.frequencies):
            filehandle.write('{}\t{}\n'.format(state, frequency))