#
# This file is part of TensorToolbox.
#
# TensorToolbox is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# TensorToolbox is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with TensorToolbox. If not, see <http://www.gnu.org/licenses/>.
#
# DTU UQ Library
# Copyright (C) 2014 The Technical University of Denmark
# Scientific Computing Section
# Department of Applied Mathematics and Computer Science
#
# Author: Daniele Bigoni
#
__all__ = ['load','storable_object','ttcross_store','to_v_0_3_0']
import time
import itertools
import shutil
import os.path
import cPickle as pkl
import h5py
from TensorToolbox import __version__ as TT_version
def to_v_0_3_0(filename):
    """ Used to upgrade the storage version from version <0.3.0 to version 0.3.0

    The object pickled in ``filename + ".pkl"`` is loaded. If it has no ``VERSION``
    attribute (storage format < 0.3.0), the old file is copied to
    ``filename + ".pkl.deprec"`` and removed. The object is then upgraded through
    its ``to_v_0_3_0`` method and stored again with ``store(force=True)``.
    Otherwise the file is left untouched.

    :param string filename: path to the filename. This must be the main filename with no extension.

    .. warning:: the file is unpickled, so it must come from a trusted source.
    """
    pkl_location = filename + ".pkl"
    # Open the old version (the handle is closed even if unpickling fails)
    print("Opening %s" % pkl_location)
    with open(pkl_location, 'rb') as ff:
        obj = pkl.load(ff)
    try:
        obj.VERSION
    except AttributeError:
        # No VERSION attribute: the object was stored with a version < 0.3.0
        # Copy old version for backup
        print("Backup copy %s" % pkl_location)
        shutil.copyfile(pkl_location, pkl_location + ".deprec")
        # Remove old version file, so that the new storage does not make a ".old" copy of it
        print("Removing %s" % pkl_location)
        os.remove(pkl_location)
        # Update version of the object
        print("Updating objects")
        obj.to_v_0_3_0(filename)
        # Force storage
        print("Storing %s" % filename)
        obj.store(force=True)
    else:
        print("The version of the file is already >=0.3.0")
def load(filename, load_data=True):
    """ Used to load TensorToolbox data.

    Unpickles the object stored in ``filename + ".pkl"``. If ``load_data`` is True,
    it then calls the object's ``load`` method with ``h5_location=filename + ".h5"``.

    :param string filename: path to the filename. This must be the main filename with no extension.
    :param bool load_data: whether to load additional data from ".h5" files.
    :returns: the loaded object.

    .. warning:: the file is unpickled, so it must come from a trusted source.
    """
    pkl_location = filename + ".pkl"
    h5_location = filename + ".h5"
    # The handle is closed even if unpickling fails
    with open(pkl_location, 'rb') as ff:
        obj = pkl.load(ff)
    if load_data:
        obj.load(h5_location=h5_location)
    return obj
class storable_object(object):
    """ Base class providing disk storage for TensorToolbox objects.

    Objects with a ``VERSION`` attribute (storage format >= 0.3.0) are stored as
    two files:

    - ``store_location + ".pkl"`` holds the pickled attributes named in ``serialize_list``.
    - ``store_location + ".h5"`` holds the additional data written by :meth:`h5store`.

    Objects without ``VERSION`` (format < 0.3.0) are pickled directly to
    ``store_location``. When ``store_object`` is set, :meth:`store` and :meth:`load`
    are forwarded to it.

    :param string store_location: main storage path, with no extension. ``''`` disables storage.
    :param float store_freq: minimum number of seconds between two non-forced storages.
        With ``None`` only forced storages happen.
    :param bool store_overwrite: allow overwriting storage files already present at ``store_location``.
    :param store_object: object exposing ``store(force)`` and ``load()`` to which storage is delegated.
    :raises AttributeError: if no ``store_object`` is given, ``store_overwrite`` is False
        and a storage file already exists at ``store_location``.
    """

    def __init__(self, store_location='', store_freq=None, store_overwrite=False, store_object=None):
        #######################################
        # List of attributes
        self.VERSION = None
        self.store_freq = None
        self.store_location = ''
        # Names of the attributes pickled by __getstate__
        self.serialize_list = ['VERSION', 'serialize_list', 'subserialize_list', 'store_location', 'store_freq']
        self.subserialize_list = []
        # Non serialized attributes which must be re-init on setstate
        self.last_store_time = None
        # End list
        #######################################
        self.VERSION = TT_version
        self.store_location = store_location
        self.store_freq = store_freq
        # -inf makes the first time-based check in to_be_stored() succeed
        self.last_store_time = -float("inf")
        self.store_object = store_object
        if self.store_object is None and self.store_location and not store_overwrite:
            # store() writes "<location>.pkl" and "<location>.h5" for new-format
            # objects, so those are checked in addition to the bare path.
            for location in (self.store_location,
                             self.store_location + ".pkl",
                             self.store_location + ".h5"):
                if os.path.isfile(location):
                    raise AttributeError("The file %s already exist." % location)

    def __getstate__(self):
        """ Pickle only the attributes listed in ``serialize_list``. """
        return dict((tag, getattr(self, tag)) for tag in self.serialize_list)

    def __setstate__(self, state, store_object=None):
        """ Restore the pickled attributes and re-initialize the non-serialized ones.

        :param dict state: attribute values produced by :meth:`__getstate__`.
        :param store_object: object to which storage is delegated after unpickling.
        """
        for tag, value in state.items():
            setattr(self, tag, value)
        # Reset parameters: the storage clock restarts at unpickling time
        self.reset_store_time()
        self.set_store_object(store_object)

    def to_be_stored(self, force=False):
        """ Tell whether :meth:`store` should write to disk now.

        :param bool force: request storage regardless of ``store_freq``.
        :returns: False if ``store_location`` is ``''`` or ``None``. Otherwise True if
            forced, or if ``store_freq`` is set and more than ``store_freq`` seconds
            have passed since ``last_store_time``.
        """
        if self.store_location in ('', None):
            return False
        # Ensure first storage when last_store_time is not set yet
        if force or (self.store_freq is not None and self.last_store_time is None):
            return True
        return (self.store_freq is not None and
                time.time() > self.last_store_time + self.store_freq)

    def reset_store_time(self):
        """ Set ``last_store_time`` to the current time. """
        self.last_store_time = time.time()

    def set_store_object(self, store_object):
        """ Set the object to which :meth:`store` and :meth:`load` are delegated (``None`` disables delegation). """
        self.store_object = store_object

    def h5store(self, h5file):
        """ Used to store additional data in hdf5 format. To be redefined in subclasses.

        :param h5file: ``h5py.File`` opened by :meth:`store` in append mode.
        """
        pass

    def h5load(self, h5file):
        """ Used to load additional data in hdf5 format. To be redefined in subclasses.

        :param h5file: ``h5py.File`` opened by :meth:`load` in read-only mode.
        """
        pass

    def load(self, h5_location=None):
        """ Used to load additional data.

        Delegates to ``store_object.load()`` when a ``store_object`` is set. Objects
        without ``VERSION`` (format < 0.3.0) have nothing extra to load. Otherwise the
        hdf5 file is opened read-only and passed to :meth:`h5load`.

        :param string h5_location: path of the hdf5 file. Defaults to ``store_location + ".h5"``.
        """
        if self.store_object is not None:
            self.store_object.load()
            return
        try:
            self.VERSION
        except AttributeError:
            # Old pickle serialization. Nothing extra to be loaded.
            return
        # New storage of objects (versions >= 0.3.0): data lives in an h5 file
        if h5_location is None:
            h5_location = self.store_location + ".h5"
        # The file is closed even if h5load raises
        with h5py.File(h5_location, 'r') as h5file:
            self.h5load(h5file)

    def store(self, force=False):
        """ Used to store any object in the library.

        Delegates to ``store_object.store(force)`` when a ``store_object`` is set.
        Otherwise it writes only if :meth:`to_be_stored` agrees. Existing files are
        first copied to ``<file>.old``. The storage time is reset after every
        storage attempt, even a failed one.

        :param bool force: force storage before time
        """
        if self.store_object is not None:
            self.store_object.store(force)
            return
        if not self.to_be_stored(force):
            return
        try:
            self.VERSION
        except AttributeError:
            # Old pickle serialization of objects (< 0.3.0): one pickle file at store_location
            if os.path.isfile(self.store_location):
                # If the file already exists, make a safety copy
                shutil.copyfile(self.store_location, self.store_location + ".old")
            with open(self.store_location, 'wb') as ff:
                pkl.dump(self, ff)
        else:
            # New storage of objects (versions >= 0.3.0). The output will be two files:
            # - a pickle file containing the serialization of self
            # - an h5 file containing the data
            pkl_location = self.store_location + ".pkl"
            h5_location = self.store_location + ".h5"
            # Store old copies for safety
            for location in (pkl_location, h5_location):
                if os.path.isfile(location):
                    shutil.copyfile(location, location + ".old")
            # Dump the serialized version of the object
            with open(pkl_location, 'wb') as ff:
                pkl.dump(self, ff)
            # 'a': read/write if the file exists, create it otherwise
            with h5py.File(h5_location, 'a') as h5file:
                self.h5store(h5file)
        finally:
            self.reset_store_time()

    def to_v_0_3_0(self, store_location):
        """ To be implemented for objects that need to be upgraded to v0.3.0.

        This base implementation sets ``VERSION`` to ``'0.3.0'``, sets the new
        ``store_location`` and makes sure ``VERSION`` is serialized.

        :param string store_location: path to the filename. This must be the main filename with no extension.
        """
        self.VERSION = '0.3.0'
        self.store_location = store_location
        # Do not duplicate the entry if the upgrade is run more than once
        if 'VERSION' not in self.serialize_list:
            self.serialize_list.append('VERSION')
##############################################
# DEPRECATED
##############################################
def ttcross_store(path, TW, TTapp):
    """ Used to store the computed values of a TTcross approximation. Usually needed when the single function evaluation is demanding or when we need to restart TTcross later on.

    The values are written to ``path`` as a pickled dictionary. If ``path`` already
    exists, a safety copy is first made at ``path + ".old"``.

    :param string path: path pointing to the location where to store the data
    :param TensorWrapper TW: Tensor wrapper used to build the ttcross approximation. TW.get_data(), TW.get_X() and TW.get_params() will be stored.
    :param TTvec TTapp: TTcross approximation. Its attributes ``ttcross_Jinit``, ``ttcross_rs``,
        ``ttcross_Js``, ``ttcross_Is``, ``ttcross_Js_last``, ``ltor_fiber_lists`` and
        ``rtol_fiber_lists`` will be stored.

    .. deprecated:: 0.3.0
       Use the objects' methods :func:`store`.
    """
    dic = {'Jinit': TTapp.ttcross_Jinit,
           'rs': TTapp.ttcross_rs,
           'Js': TTapp.ttcross_Js,
           'Is': TTapp.ttcross_Is,
           'Js_last': TTapp.ttcross_Js_last,
           'ltor_fiber_lists': TTapp.ltor_fiber_lists,
           'rtol_fiber_lists': TTapp.rtol_fiber_lists,
           'X': TW.get_X(),
           'params': TW.get_params(),
           'data': TW.get_data()}
    if os.path.isfile(path):
        # If the file already exists, make a safety copy
        shutil.copyfile(path, path + ".old")
    # The handle is closed even if pickling fails
    with open(path, 'wb') as ff:
        pkl.dump(dic, ff)