from numpy import *
import gc, traceback, sys
from time import time
import logging
log = logging.getLogger(__file__)
from pyec.util.TernaryString import TernaryString
from pyec.util.partitions import Segment, Point, Partition, ScoreTree
class BadAlgorithm(Exception):
    """Exception type exported by this module.

    Nothing in the visible part of this file raises it; presumably callers
    use it to signal an unusable algorithm configuration -- confirm against
    the rest of the package.
    """
class RunStats(object):
    """Accumulate wall-clock timings for named sections of code.

    Bracket a section with ``start(key)`` / ``stop(key)``.  For every key the
    object keeps the total elapsed seconds (``totals``), the number of
    completed start/stop pairs (``counts``) and the timestamp of the pending
    ``start`` (``times``).  ``stats[key]`` gives the mean duration and
    ``str(stats)`` a one-line-per-key report.

    Setting ``recording`` to a false value turns ``start`` and ``stop`` into
    no-ops.
    """

    # Class-level default so that ``recording`` can be overridden per
    # instance simply by assigning to it.
    recording = True

    def __init__(self):
        # Per-instance containers: keeping these on the class would make
        # every RunStats object share (and pollute) the same timings.
        self.totals = {}
        self.times = {}
        self.counts = {}

    def start(self, key):
        """Record the current time as the start of section ``key``.

        A second ``start`` for the same key before ``stop`` simply replaces
        the pending timestamp.
        """
        if not self.recording:
            return
        if key not in self.totals:
            self.totals[key] = 0.0
            self.counts[key] = 0
        self.times[key] = time()

    def stop(self, key):
        """Close section ``key`` and add the elapsed time to its total.

        :raises KeyError: if ``key`` has no pending ``start``.
        """
        if not self.recording:
            return
        now = time()
        self.totals[key] += now - self.times.pop(key)
        self.counts[key] += 1

    def __getitem__(self, key):
        """Return the mean duration in seconds of section ``key``.

        A key that has been started but never stopped has no completed
        measurement yet and reports ``0.0`` instead of dividing by zero.

        :raises KeyError: if ``key`` was never started.
        """
        count = self.counts[key]
        if not count:
            return 0.0
        return self.totals[key] / count

    def __str__(self):
        """Return ``"<key>: <mean seconds>"`` lines, sorted by key."""
        return "".join(
            "%s: %.9f\n" % (key, self[key]) for key in sorted(self.totals)
        )
class Trainer(object):
    """Run an evolutionary algorithm against a fitness function.

    For each generation the trainer asks the algorithm for a batch of
    candidates, scores them, optionally stores them as ``Point`` records in
    a ``Segment``, and passes the scored population to the algorithm's
    ``update`` method.
    """

    def __init__(self, fitness, evoAlg, **kwargs):
        """Set up a trainer.

        :param fitness: callable applied to each (optionally converted)
            candidate; its result is passed through ``float``.  If it has a
            ``statistics`` attribute, that value is stored with each point.
        :param evoAlg: algorithm object providing ``config``, ``batch``,
            ``convert`` and ``update``.
        :param kwargs: ``save`` overrides ``config.save``; when neither is
            given, saving defaults to ``True``.
        """
        self.fitness = fitness
        self.algorithm = evoAlg
        self.config = evoAlg.config
        self.sort = True
        self.segment = None
        self.data = []
        self.groupby = 50
        self.since = 0
        self.groupCount = 0
        self.maxOrg = None
        self.maxScore = None
        if "save" in kwargs:
            self.save = kwargs["save"]
        else:
            self.save = getattr(self.config, "save", True)
        if self.save:
            self.segment = Segment(name=self.config.segment, config=self.config)

    def _score(self, x):
        """Return ``(score, fitnessStats)`` for one candidate ``x``.

        The candidate is converted once with ``algorithm.convert`` unless
        ``config.convert`` exists and is false.  Out-of-bounds candidates
        (when ``config.bounded``) and NaN scores are both replaced by
        ``-1e300``.
        """
        if not hasattr(self.config, 'convert') or self.config.convert:
            z = self.algorithm.convert(x)
        else:
            z = x
        score = float(self.fitness(z))
        if self.config.bounded and not self.config.in_bounds(z):
            score = -1e300
        # Read after the fitness call, as the original code did.
        fitnessStats = getattr(self.fitness, 'statistics', None)
        if score != score:
            # NaN is the only value that differs from itself.
            score = -1e300
        return score, fitnessStats

    def _sample_generation(self, stats):
        """Draw and score one batch from the algorithm.

        :param stats: ``RunStats`` receiving "sample" and "score" timings.
        :returns: ``(population, genavg)`` where ``population`` is a list of
            ``(candidate, score, fitnessStats)`` and ``genavg`` is the mean
            score.
        :raises ZeroDivisionError: if the batch yields no candidates.
        """
        population = []
        total = 0.
        stats.start("sample")
        # Reset before the batch is generated; presumably the algorithm
        # fills this while sampling -- confirm against the algorithm code.
        self.config.selectedScores = []
        for x in self.algorithm.batch(self.config.populationSize):
            stats.stop("sample")
            stats.start("score")
            score, fitnessStats = self._score(x)
            total += score
            self.since += 1
            stats.stop("score")
            population.append((x, score, fitnessStats))
            # Time spent inside the batch iterator counts as "sample".
            stats.start("sample")
        return population, total / len(population)

    def _make_point(self, point, score, fitnessStats):
        """Build the ``Point`` record for one scored candidate.

        The candidate is passed in exactly one of the ``point`` / ``binary``
        / ``bayes`` / ``other`` slots depending on its type.
        """
        pt = bn = bit = other = None
        if isinstance(point, ndarray):
            # Floor every coordinate's magnitude at 1e-30, then restore the
            # sign.  NOTE(review): presumably avoids storing tiny values;
            # confirm the intent.
            pt = maximum(1e-30 * ones(len(point)), abs(point))
            pt *= sign(point)
        elif isinstance(point, TernaryString):
            bit = point
        elif hasattr(point, 'computeEdgeStatistics'):
            bn = point
        else:
            other = point
        return Point(point=pt, bayes=bn, binary=bit, other=other,
                     statistics=fitnessStats, score=score, count=1,
                     segment=self.segment)

    def train(self):
        """Run up to ``config.generations`` generations.

        Stops early once the best score reaches ``config.stopAt``.  Every
        ``groupby`` evaluations a ``(groupCount, genavg, maxScore, genmax)``
        tuple is appended to ``self.data``.  Progress is printed at most
        about once per second when ``config.printOut`` is set.

        :returns: ``(maxScore, maxOrg)``, also stored on ``self.maxScore``
            and ``self.maxOrg``.  With fewer than one generation the result
            is ``(-1e100, None)`` and the attributes are left untouched.
        """
        maxScore = -1e100
        maxOrg = None
        gens = self.config.generations
        if gens < 1:
            return maxScore, maxOrg

        trainStart = time()
        stats = RunStats()
        stats.recording = self.config.recording
        lastTime = time()
        for idx in range(gens):
            startTime = time()
            stats.start("generation")
            population, genavg = self._sample_generation(stats)

            if self.sort:
                # Stable sort: ties keep their sampling order.
                population.sort(key=lambda entry: entry[1], reverse=True)
                genorg = population[0][0]
                genmax = population[0][1]
            else:
                # A list (not a generator) is passed on purpose: the star
                # import from numpy may shadow the builtin ``max``.
                genmax = max([s for _, s, _ in population])
                for x, s, _ in population:
                    if s == genmax:
                        genorg = x
                        break

            if genmax > maxScore:
                maxScore = genmax
                maxOrg = genorg
                if hasattr(maxOrg, 'computeEdgeStatistics'):
                    maxOrg.computeEdgeStatistics()
                    print(maxOrg.edges)

            while self.since >= self.groupby:
                self.since -= self.groupby
                self.groupCount += 1
                self.data.append((self.groupCount, genavg, maxScore, genmax))

            scored = []
            gps = []
            for point, score, fitnessStats in population:
                stats.start("point")
                scored.append((point, score))
                if self.save:
                    gps.append(self._make_point(point, score, fitnessStats))
                stats.stop("point")
            if self.save:
                stats.start("save")
                Point.objects.bulkSave(gps, stats)
                stats.stop("save")

            stats.start("update")
            # NOTE(review): the generation number handed to the algorithm is
            # idx + 2; the reason for this offset is not visible here.
            self.algorithm.update(idx + 2, scored)
            stats.stop("update")
            stats.stop("generation")

            # Drop the per-generation lists before collecting garbage.
            del population, scored
            if self.save:
                gc.collect()

            if (time() - lastTime) > 1.0:
                lastTime = time()
                if self.config.printOut:
                    if self.config.recording:
                        print(stats)
                    elapsed = time() - startTime
                    if hasattr(self.algorithm, 'var'):
                        print("%s :  %s %s %.16f %.16f" % (
                            idx, elapsed, self.algorithm.var, genmax, maxScore))
                    else:
                        print("%s :  %s %s %s" % (
                            idx, elapsed, genmax, maxScore))

            if maxScore >= self.config.stopAt:
                break

        if self.config.printOut:
            print("total time:  %s" % (time() - trainStart))
            print("best score:  %s" % maxScore)
        self.maxScore = maxScore
        self.maxOrg = maxOrg
        return maxScore, maxOrg