import abc
from Fireworks import Message
from Fireworks.utils.exceptions import EndHyperparameterOptimization
from .database import create_table, TablePipe
from sqlalchemy.orm import sessionmaker
from sqlalchemy import Column
from sqlalchemy_utils import JSONType as JSON
from collections import defaultdict
class Factory(abc.ABC):
    """
    Base class for hyperparameter optimization in pytorch.

    The design is intended to be parallelized with queues, but it currently runs serially.

    Each iteration of :meth:`run`:

    1. reads the parameter/metric history via :meth:`read`;
    2. asks ``generator`` for a new parameter set;
    3. builds an evaluator from it with ``trainer``;
    4. attaches every metric in ``metrics_dict``;
    5. runs the evaluator on ``eval_dataloader``;
    6. stores the parameters together with the computed metrics via :meth:`write`.

    Subclasses must implement :meth:`get_connection`, :meth:`read` and :meth:`write`.
    Because the class derives from :class:`abc.ABC`, a subclass missing any of them
    cannot be instantiated.
    """
    # NOTE: This is currently not parallelized yet

    def __init__(self, trainer, metrics_dict, generator, eval_dataloader, *args, **kwargs):
        """
        Args:
            trainer: Callable ``trainer(params)`` returning an evaluator that has a ``run(dataloader)`` method.
            metrics_dict (dict): Maps metric name -> metric object providing ``attach(evaluator, name)``
                and ``compute()`` (pytorch ignite style).
            generator: Callable ``generator(past_params, past_metrics)`` returning the next parameter set.
            eval_dataloader: Passed to ``evaluator.run`` on every iteration.
            *args, **kwargs: Accepted and ignored here.
        """
        self.trainer = trainer
        self.metrics_dict = metrics_dict
        self.generator = generator
        self.dataloader = eval_dataloader
        # Subclasses set up their storage here (e.g. in-memory containers or database tables).
        self.get_connection()

    @abc.abstractmethod
    def get_connection(self):
        """Initialise the storage backend. Called once from ``__init__``."""

    def run(self):
        """
        Run optimization iterations until ``EndHyperparameterOptimization`` is raised.

        That exception may come from the generator or any later step of an iteration.
        When it is caught, :meth:`after` is called and the loop exits. Any other
        exception propagates to the caller.
        """
        while True:
            past_params, past_metrics = self.read()
            try:
                # Generate new set of parameters
                params = self.generator(past_params, past_metrics)
                # Generate an evaluator from the params
                evaluator = self.trainer(params)
                # NOTE: This part is pytorch ignite syntax
                for name, metric in self.metrics_dict.items():
                    metric.attach(evaluator, name)  # TODO: Make sure this resets the metric
                # Running the evaluator should perform training on the dataset followed by evaluation
                # and return evaluation metrics
                evaluator.run(self.dataloader)
                # Evaluate the metrics that were attached to the evaluator
                computed_metrics = {name: metric.compute() for name, metric in self.metrics_dict.items()}
                self.write(params, computed_metrics)
                # Drop the reference so the evaluator can be garbage-collected before the next iteration.
                evaluator = None
            except EndHyperparameterOptimization:
                self.after()
                break

    @abc.abstractmethod
    def read(self):
        """Return ``(past_params, past_metrics)``, which is passed to the generator."""

    @abc.abstractmethod
    def write(self, params, metrics_dict):
        """Store one iteration's ``params`` and its ``metrics_dict`` (metric name -> computed value)."""

    def after(self, *args, **kwargs):
        """Hook called once when optimization ends. Does nothing by default."""
class LocalMemoryFactory(Factory):
    """
    Factory that keeps the parameter and metric history in process memory.

    Nothing is persisted; the history lives only as long as this object does.
    """

    def get_connection(self):
        """Create empty in-memory containers for parameters and per-metric histories."""
        self.params = Message()
        # A fresh Message is created the first time an unseen metric name is accessed.
        self.metrics = defaultdict(Message)

    def read(self):
        """Return the ``(params, metrics)`` history collected so far."""
        return self.params, self.metrics

    def write(self, params, metrics_dict):
        """
        Add one iteration's results to the in-memory history.

        ``params`` is appended to the stored parameters. For each metric name,
        ``metrics_dict[name]`` is appended to that name's history. The stored
        containers are replaced by the values ``append`` returns.
        """
        self.params = self.params.append(params)
        for name in metrics_dict:
            value = metrics_dict[name]
            history = self.metrics[name]
            self.metrics[name] = history.append(value)
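# Former single-table schema for storing hyperparameter data in SQLFactory (unused; SQLFactory now takes
# separate params_table and metrics_tables arguments):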
# columns = [
# Column('parameters', JSON),
# Column('metrics', JSON),
# ]
#
# factory_table = create_table('hyperparmeters', columns)
class SQLFactory(Factory):
    """
    Factory that stores parameters and metrics in SQLAlchemy database tables while caching them locally.

    Parameters are written to ``params_table``. Each metric is written to its own table from
    ``metrics_tables``, keyed by metric name. :meth:`read` serves the local cache. The cache is
    loaded from the database during construction and reloaded when optimization ends (:meth:`after`).
    """

    def __init__(self, *args, params_table, metrics_tables, engine, **kwargs):
        """
        Args:
            *args, **kwargs: Forwarded to :class:`Factory` (trainer, metrics_dict, generator, eval_dataloader).
            params_table: Table used for parameters; must expose ``.metadata`` (e.g. a SQLAlchemy table).
            metrics_tables (dict): Maps metric name -> table used for that metric.
            engine: SQLAlchemy engine the tables are created in and queried from.
        """
        self.engine = engine
        self.metrics_tables = metrics_tables
        self.params_table = params_table
        self.params_pipe = TablePipe(self.params_table, self.engine)
        self.metrics_pipes = {key: TablePipe(value, self.engine) for key, value in self.metrics_tables.items()}
        # Must come last: Factory.__init__ calls get_connection(), which relies on the pipes above.
        super().__init__(*args, **kwargs)

    def get_connection(self):
        """Create the backing tables if they do not exist and load the stored rows into the local cache."""
        # TODO: Ensure id consistency across these tables using foreign key constraints. This should implicitly
        # hold true without such constraints however, because these tables are updated in sync.
        for table in self.metrics_tables.values():
            table.metadata.create_all(self.engine)
        self.params_table.metadata.create_all(self.engine)
        self.id = 0
        self.sync()

    def write(self, params, metrics):
        """
        Persist one batch of parameters and its metrics, then update the local cache.

        Args:
            params: Message of parameter rows.
            metrics (dict): Maps metric name -> metric rows for the same batch. Each value must
                have the same length as ``params``.

        Raises:
            KeyError: If a metric name has no table in ``metrics_tables``.
            ValueError: If any metric's length differs from ``len(params)``.
        """
        # Validate everything before touching the database. Row ids across the tables only line up
        # because they are written together, so a half-written batch would misalign them.
        n_rows = len(params)
        for key, metric in metrics.items():
            if key not in self.metrics_pipes:
                raise KeyError(key)
            if len(metric) != n_rows:
                raise ValueError("Parameters and Metrics messages must be equal length.")
        for key, metric in metrics.items():
            pipe = self.metrics_pipes[key]
            pipe.insert(metric)
            pipe.commit()
            # Update the cache only after the row is committed, so it never holds unsaved data.
            self.metrics[key] = self.metrics[key].append(metric)
        self.params_pipe.insert(params)
        self.params_pipe.commit()
        self.params = self.params.append(params)

    def read(self):
        """Return the locally cached ``(params, metrics)``."""
        return self.params, self.metrics

    def read_db(self):
        """Query the database and return ``(all params, {metric name: all rows for that metric})``."""
        params = self.params_pipe.query().all()
        metrics = {key: pipe.query().all() for key, pipe in self.metrics_pipes.items()}
        return params, metrics

    def sync(self):
        """Replace the local copy of metrics and params with the current database contents."""
        self.params, self.metrics = self.read_db()

    def after(self, *args, **kwargs):
        """Re-sync the local cache with the database when optimization ends."""
        self.sync()