Source code for Fireworks.extensions.factory

import abc
from Fireworks import Message, Junction
from .training import IgniteJunction
from Fireworks.utils.exceptions import EndHyperparameterOptimization
from .database import create_table, TablePipe
from sqlalchemy.orm import sessionmaker
from sqlalchemy import Column
from sqlalchemy_utils import JSONType as JSON
from collections import defaultdict
import types

def update(bundle: dict, parameters: dict):
    """
    Assign parameter values onto the object attributes registered in a bundle.

    Args:
        bundle - A dictionary of key: (obj, attr). obj is the object referred to, and attr is a
            string with the name of the attribute to be assigned.
        parameters - A dictionary of key: value. Wherever a key also appears in bundle, obj.attr
            is set to value. Keys that are absent from bundle are ignored.
    """
    for key, param in parameters.items():
        if key in bundle:
            # The unpacked name must match the one handed to setattr; the previous
            # 'atr'/'attr' mismatch raised NameError on the first matching key.
            obj, attr = bundle[key]
            setattr(obj, attr, param)
class Factory(Junction):
    """
    Base class for hyperparameter optimization in pytorch using queues.

    Subclasses supply the storage backend by implementing get_connection, read and write.
    """
    # NOTE: This is currently not parallelized. It would be nice if it was.

    required_components = {
        'trainer': types.FunctionType,
        'eval_set': object,
        'parameterizer': types.FunctionType,
        'metrics': dict,
    }

    def __init__(self, *args, components=None, **kwargs):
        """ Initializes the underlying Junction, then sets up the storage backend via get_connection. """
        Junction.__init__(self, *args, components=components, **kwargs)
        self.get_connection()

    @abc.abstractmethod
    def get_connection(self):
        """ Sets up whatever storage the factory reads from and writes to. Called once from __init__. """
        pass

    def run(self):
        """
        Repeatedly generates parameters, evaluates them and records the results.

        The loop ends when EndHyperparameterOptimization is raised inside an iteration; the after
        hook is called before returning.
        """
        while True:
            history_params, history_metrics = self.read()
            try:
                # Ask the parameterizer for the next candidate, given everything seen so far.
                candidate = self.parameterizer(history_params, history_metrics)
                # Build an evaluator for the candidate.
                engine = self.trainer(candidate)
                # NOTE: This part is pytorch ignite syntax
                for label, metric in self.metrics.items():
                    metric.attach(engine, label)  # TODO: Make sure this resets the metric
                # Running the evaluator should perform training on the dataset followed by
                # evaluation, and return evaluation metrics.
                engine.run(self.eval_set, max_epochs=1)
                # Collect the value of every metric that was attached to the evaluator.
                results = {}
                for label, metric in self.metrics.items():
                    results[label] = metric.compute()
                self.write(candidate, results)
                # Drop the reference to the evaluator before the next iteration.
                engine = None
            except EndHyperparameterOptimization:
                self.after()
                return

    @abc.abstractmethod
    def read(self):
        """ Returns the previously recorded (params, metrics) pair that run passes to the parameterizer. """
        pass

    @abc.abstractmethod
    def write(self, params, metrics_dict):
        """ Records one set of parameters together with the metrics computed for it. """
        pass

    def train(self, params):
        """ Forwards params to update_components on the trainer's model. """
        self.trainer.model.update_components(params)

    def after(self, *args, **kwargs):
        """ Hook called by run once optimization has ended. Does nothing by default. """
        pass
class LocalMemoryFactory(Factory):
    """
    Factory that stores parameters in memory.
    """

    def get_connection(self):
        """ Creates the empty in-memory stores for parameters and per-metric results. """
        self.params = Message()
        # Metrics are kept per name; an unseen name starts out as an empty Message.
        self.computed_metrics = defaultdict(Message)

    def read(self):
        """ Returns the stored (params, computed_metrics) pair. """
        return self.params, self.computed_metrics

    def write(self, params, metrics_dict):
        """ Appends params and each entry of metrics_dict to the in-memory stores. """
        self.params = self.params.append(params)
        for name in metrics_dict:
            recorded = self.computed_metrics[name]
            self.computed_metrics[name] = recorded.append(metrics_dict[name])
class SQLFactory(Factory):
    """
    Factory that stores parameters in SQLalchemy database while caching them locally.
    """

    required_components = {
        'trainer': types.FunctionType,
        'eval_set': object,
        'parameterizer': types.FunctionType,
        'metrics': dict,
        'engine': object,
        'params_table': object,
        'metrics_tables': object,
    }

    def __init__(self, *args, components=None, **kwargs):
        """ Initializes the Junction, builds one TablePipe per table, then connects and syncs. """
        # Junction.__init__ is called directly (not Factory.__init__) because the pipes
        # must exist before get_connection runs: it ends in sync(), which queries them.
        Junction.__init__(self, *args, components=components, **kwargs)
        self.params_pipe = TablePipe(self.params_table, self.engine)
        self.metrics_pipes = {key: TablePipe(value, self.engine) for key, value in self.metrics_tables.items()}
        self.computed_metrics = defaultdict(Message)
        self.get_connection()

    def get_connection(self):
        """ Creates the metrics and params tables on the engine, then loads their contents locally. """
        # TODO: Ensure id consistency across these tables using foreign key constraints. This should implicitly
        # hold true without such constraints however, because these tables are updated in sync.
        for table in self.metrics_tables.values():
            table.metadata.create_all(self.engine)
        self.params_table.metadata.create_all(self.engine)
        self.id = 0
        self.sync()

    def write(self, params, metrics):
        """
        Records params and metrics both in the local cache and in the database.

        Args:
            params - Parameters for one run; wrapped in a Message before being stored.
            metrics - A dictionary of metric name: value. Each name must have a pipe in metrics_pipes.
        """
        # NOTE: The lengths of params and metrics are not validated against each other here.
        params = Message(params)
        for key, metric in metrics.items():
            self.computed_metrics[key] = self.computed_metrics[key].append(metric)
            self.metrics_pipes[key].insert(metric)
            self.metrics_pipes[key].commit()
        self.params = self.params.append(params)
        self.params_pipe.insert(params)
        self.params_pipe.commit()

    def read(self):
        """ Returns the locally cached (params, computed_metrics) pair without querying the database. """
        return self.params, self.computed_metrics

    def read_db(self):
        """ Queries every row of the params table and of each metrics table; returns (params, {name: metrics}). """
        return self.params_pipe.query().all(), {key: pipe.query().all() for key, pipe in self.metrics_pipes.items()}

    def sync(self):
        """ Syncs local copy of metrics and params with db. """
        # After this call computed_metrics is the plain dict built by read_db.
        self.params, self.computed_metrics = self.read_db()

    def after(self, *args, **kwargs):
        """
        Re-syncs the local cache with the database once optimization has ended.

        Accepts and ignores arbitrary arguments so the signature matches the base class hook.
        """
        self.sync()