import os
from sqlalchemy import create_engine, Column, Float, Integer, String, DateTime
from sqlalchemy.orm import sessionmaker
import datetime
from Fireworks import Message
from . import database as db
from deprecated import deprecated
"""
This module contains classes and functions for saving and loading data collected during experiments.
"""
# Columns of the 'metadata' table, declared as (column name, SQLAlchemy type) pairs
# and turned into Column objects in the order listed.
metadata_columns = [
    Column(column_name, column_type)
    for column_name, column_type in (
        ('name', String),
        ('iteration', Integer),
        ('description', String),
        ('timestamp', DateTime),
    )
]
# The table itself is built by the project's database helper module from the columns above.
metadata_table = db.create_table('metadata', columns=metadata_columns)
def load_experiment(experiment_path): # TODO: clean up attribute assignments for loading
    """
    Returns an experiment object corresponding to the database in the given path.

    The last component of the path is used as the experiment name and everything before it
    as the directory that contains the experiment folder.

    Args:
        experiment_path (str): Path to the experiment folder. A trailing path separator is ignored.

    Returns:
        experiment (Experiment): An Experiment object loaded using the files in the given folder path.
    """
    # normpath drops trailing separators so that 'runs/exp/' and 'runs/exp' name the same folder;
    # a plain split('/') would produce an empty experiment name for the former.
    db_path, experiment_name = os.path.split(os.path.normpath(experiment_path))
    # A bare folder name has no parent component; fall back to the current directory instead of
    # passing an empty string, which os.listdir cannot open.
    return Experiment(experiment_name, db_path or os.curdir, load=True)
class Experiment:
    """
    A folder on local disk that groups the SQLite databases and other files produced by one experiment.

    A new experiment creates a folder called {name}_{iteration} inside db_path; a loaded experiment
    reuses an existing folder. In both cases a 'metadata.sqlite' engine is created for the folder.

    Attributes:
        name (str): Name given to the experiment.
        db_path (str): Directory that contains the experiment folder.
        save_path (str): Name of the experiment folder, always relative to db_path.
        description (str): Free-form description ('' if none was given).
        timestamp (datetime.datetime): Time at which this object was constructed.
        engine: SQLAlchemy engine connected to metadata.sqlite in the experiment folder.
        engines (dict): Engines created through get_engine, keyed by name.
        iteration (int): Numeric suffix of the folder. Only set when the folder is created by create_dir.
        filenames (list): Names of the entries in the experiment folder as of the last refresh.
    """
    # NOTE: For now, we assume that the underlying database is sqlite on local disk
    # QUESTION: Should we implement an __eq__ method for experiments?
    # TODO: Expand to support nonlocal databases

    def __init__(self, experiment_name, db_path, description=None, load=False):
        """
        Args:
            experiment_name (str): Name of the experiment. When loading, this is the name of the existing folder.
            db_path (str): Directory in which the experiment folder lives or will be created.
            description (str): Optional description of the experiment.
            load (bool): If True, open the existing folder instead of creating a new one.
        """
        self.name = experiment_name
        self.db_path = db_path
        self.description = description or ''
        self.timestamp = datetime.datetime.now() # QUESTION: Should this be updated on each load?
        self.engines = {}
        # NOTE(review): load_metadata and init_metadata are not defined in the visible part of this
        # class -- confirm that they are provided elsewhere.
        if load:
            self.load_experiment()
            # Re-create one engine per database file already present in the folder.
            for filename in os.listdir(self._directory()):
                if filename.endswith('.sqlite'):
                    self.get_engine(self._engine_name(filename))
            self.load_metadata()
        else:
            self.create_dir()
            self.init_metadata()
        self.filenames = os.listdir(self._directory()) # Refresh list of filenames

    @staticmethod
    def _engine_name(filename):
        """
        Returns filename without a trailing '.sqlite' extension.

        str.rstrip is deliberately not used: it strips a set of characters rather than a suffix and
        would turn 'results.sqlite' into 'resu'.
        """
        suffix = '.sqlite'
        if filename.endswith(suffix):
            return filename[:-len(suffix)]
        return filename

    def _directory(self):
        """
        Returns the path of this experiment's folder (db_path joined with save_path).
        """
        return os.path.join(self.db_path, self.save_path)

    def load_experiment(self, path=None, experiment_name=None):
        """
        Points this object at an existing experiment folder and connects to its metadata database.

        Args:
            path (str): Directory containing the experiment folder. Defaults to self.db_path.
            experiment_name (str): Name of the experiment folder inside path. Defaults to self.name.

        Raises:
            ValueError: If path has no entry called experiment_name.
        """
        path = path or self.db_path
        experiment_name = experiment_name or self.name
        if experiment_name not in os.listdir(path):
            raise ValueError("Directory {exp_dir} was not found in {path}".format(exp_dir=experiment_name, path=path))
        # Keep save_path relative to db_path, exactly as create_dir does, so that every method that
        # joins db_path and save_path resolves to the same folder for new and loaded experiments.
        self.db_path = path
        self.save_path = experiment_name
        self.engine = create_engine("sqlite:///{save_path}".format(save_path=os.path.join(self._directory(),'metadata.sqlite')))

    def create_dir(self):
        """
        Creates a folder in db_path directory corresponding to this Experiment.

        The folder is called {name}_{iteration}, where iteration is one more than the highest
        iteration already present for this name (0 if there is none).
        """
        prefix = "{name}_".format(name=self.name)
        # Only entries of the exact form {name}_{number} count as earlier iterations; this avoids
        # miscounting unrelated folders that merely share the prefix and avoids reusing a number
        # that is still taken after an earlier iteration has been deleted.
        taken = [
            int(entry[len(prefix):])
            for entry in os.listdir(self.db_path)
            if entry.startswith(prefix) and entry[len(prefix):].isdecimal()
        ]
        self.iteration = max(taken) + 1 if taken else 0
        self.save_path = "{name}_{iteration}".format(name=self.name, iteration=self.iteration) # TODO: Upgrade to 3.6 and use f-strings
        os.makedirs(self._directory())
        self.engine = create_engine("sqlite:///{save_path}".format(save_path=os.path.join(self._directory(),'metadata.sqlite')))

    def get_engine(self, name):
        """
        Creates an engine corresponding to a database with the given name. The engine points at a file called {name}.sqlite
        in this experiment's save directory.

        Args:
            name: Name of engine to create. This will also be the name of the database file.

        Returns:
            engine: The new engine. You can also reach this engine now by calling self.engines[name]
        """
        self.engines[name] = create_engine("sqlite:///{filename}".format(filename=os.path.join(self._directory(), name+'.sqlite')))
        return self.engines[name]

    def get_session(self, name):
        """
        Creates an SQLalchemy session corresponding to the engine with the given name that can be used to interact with the database.

        Args:
            name: Name of engine corresponding to session. The engine will be created if one with that name does not already exist.

        Returns:
            session: A session created from the chosen engine.
        """
        # QUESTION: Should a missing engine raise an error instead of being autocreated?
        engine = self.engines[name] if name in self.engines else self.get_engine(name)
        Session = sessionmaker(bind=engine)
        return Session()

    def open(self, filename, *args, string_only=False):
        """
        Returns a handle to a file with the given filename inside this experiment's directory.
        If string_only is true, then this instead returns a string with the path to the file.
        No check is made for an existing file of the same name; what happens to it depends on
        the mode passed through args.

        Args:
            filename (str): Name of file.
            args: Additional positional args for the built-in open function (for example the mode).
            string_only (bool): If true, will return the path to the file rather than the file handle. This can be useful if you want to
                create the file using some other library.

        Returns:
            file: If string_only is True, the path to the file. Otherwise, the opened file handle. Note: You can use this method in a
                with statement to auto-close the file.
        """
        directory = self._directory()
        path = os.path.join(directory, filename)
        if string_only:
            self.filenames = os.listdir(directory) # Refresh list of filenames
            return path
        # Inside a method body the name `open` resolves to the built-in, not to this method.
        handle = open(path, *args)
        # Refresh after opening so that a file created by this call is already listed.
        self.filenames = os.listdir(directory)
        return handle
def filter_columns(message, columns = None):
    """
    Placeholder for column filtering: currently hands back message untouched, whatever columns is.

    The intended behaviour (not implemented yet) is to keep only the requested columns of the
    message, converting tensor columns to ndarray first, and to return everything when columns is None.

    Args:
        message: The Message to filter.
        columns: Columns to keep. Default = None, meaning return the Message as is. Ignored for now.

    Returns:
        message: The same object that was passed in.
    """
    # TODO: Implement this
    return message