# BROJA_2PID.py -- Python module
#
# BROJA_2PID: Bertschinger-Rauh-Olbrich-Jost-Ay (BROJA) bivariate Partial
# Information Decomposition https://github.com/Abzinger/BROJA_2PID
# (c) Abdullah Makkeh, Dirk Oliver Theis
# Permission to use and modify with proper attribution
# (Apache License version 2.0)
#
# Information about the algorithm, documentation, and examples are here:
# @Article{makkeh-theis-vicente:pidOpt:2017,
# author = {Makkeh, Abdullah and Theis, Dirk Oliver and
# Vicente, Raul},
# title = {BROJA-2PID: A cone programming based Partial
# Information Decomposition estimator},
# journal = {jo},
# year = 2017,
# key = {key},
# volume = {vol},
# number = {nr},
# pages = {1--2}
# }
# Please cite this paper when you use this software (cf. README.md)
###############################################################################
from scipy import sparse
import numpy as np
from numpy import linalg as LA
import math
from collections import defaultdict
from . import idtxl_exceptions as ex
from .idtxl_exceptions import BROJA_2PID_Exception
try:
import ecos
except ImportError as err:
ex.package_missing(err, 'ECOS is not available on this system. Install it '
'from https://pypi.python.org/pypi/ecos to use '
'The Tartu cone programming PID estimator.')
log = math.log2
ln = math.log
# ECOS's exp cone: (r,p,q) w/ q>0 & exp(r/q) ≤ p/q
# Translation: (0,1,2) w/ 2>0 & 0/2 ≤ ln(1/2)
[docs]def r_vidx(i):
return 3*i
[docs]def p_vidx(i):
return 3*i+1
[docs]def q_vidx(i):
return 3*i+2
[docs]class Solve_w_ECOS:
# (c) Abdullah Makkeh, Dirk Oliver Theis
# Permission to use and modify under Apache License version 2.0
def __init__(self, marg_xy, marg_xz):
# (c) Abdullah Makkeh, Dirk Oliver Theis
# Permission to use and modify under Apache License version 2.0
# ECOS parameters
self.ecos_kwargs = dict()
self.verbose = False
# Data for ECOS
self.c = None
self.G = None
self.h = None
self.dims = dict()
self.A = None
self.b = None
# ECOS result
self.sol_rpq = None
self.sol_slack = None
self.sol_lambda = None # dual variables for equality constraints
self.sol_mu = None # dual variables for generalized ieqs
self.sol_info = None
# Probability density funciton data
self.b_xy = dict(marg_xy)
self.b_xz = dict(marg_xz)
self.X = set([x for x, y in self.b_xy.keys()] +
[x for x, z in self.b_xz.keys()])
self.Y = set([y for x, y in self.b_xy.keys()])
self.Z = set([z for x, z in self.b_xz.keys()])
self.idx_of_trip = dict()
self.trip_of_idx = []
# Do stuff:
for x in self.X:
for y in self.Y:
if (x, y) in self.b_xy.keys():
for z in self.Z:
if (x, z) in self.b_xz.keys():
self.idx_of_trip[(x, y, z)] = len(self.trip_of_idx)
self.trip_of_idx.append((x, y, z))
[docs] def create_model(self):
# (c) Abdullah Makkeh, Dirk Oliver Theis
# Permission to use and modify under Apache License version 2.0
n = len(self.trip_of_idx)
m = len(self.b_xy) + len(self.b_xz)
n_vars = 3*n
n_cons = n+m
# Create the equations: Ax = b
self.b = np.zeros((n_cons,), dtype=np.double)
Eqn = []
Var = []
Coeff = []
# The q-p coupling equations: q_{*yz} - p_{xyz} = 0
for i, xyz in enumerate(self.trip_of_idx):
eqn = i
p_var = p_vidx(i)
Eqn.append(eqn)
Var.append(p_var)
Coeff.append(-1.)
(x, y, z) = xyz
for u in self.X:
if (u, y, z) in self.idx_of_trip.keys():
q_var = q_vidx(self.idx_of_trip[(u, y, z)])
Eqn.append(eqn)
Var.append(q_var)
Coeff.append(+1.)
# running number
eqn = -1 + len(self.trip_of_idx)
# The xy marginals q_{xy*} = b^y_{xy}
for x in self.X:
for y in self.Y:
if (x, y) in self.b_xy.keys():
eqn += 1
for z in self.Z:
if (x, y, z) in self.idx_of_trip.keys():
q_var = q_vidx(self.idx_of_trip[(x, y, z)])
Eqn.append(eqn)
Var.append(q_var)
Coeff.append(1.)
self.b[eqn] = self.b_xy[(x, y)]
# The xz marginals q_{x*z} = b^z_{xz}
for x in self.X:
for z in self.Z:
if (x, z) in self.b_xz.keys():
eqn += 1
for y in self.Y:
if (x, y, z) in self.idx_of_trip.keys():
q_var = q_vidx(self.idx_of_trip[(x, y, z)])
Eqn.append(eqn)
Var.append(q_var)
Coeff.append(1.)
self.b[eqn] = self.b_xz[(x, z)]
self.A = sparse.csc_matrix(
(Coeff, (Eqn, Var)), shape=(n_cons, n_vars), dtype=np.double)
# Generalized ieqs: gen.nneg of the variable triples (r_i,q_i,p_i),
# i=0,dots,n-1:
Ieq = []
Var = []
Coeff = []
for i, xyz in enumerate(self.trip_of_idx):
r_var = r_vidx(i)
q_var = q_vidx(i)
p_var = p_vidx(i)
Ieq.append(len(Ieq))
Var.append(r_var)
Coeff.append(-1.)
Ieq.append(len(Ieq))
Var.append(p_var)
Coeff.append(-1.)
Ieq.append(len(Ieq))
Var.append(q_var)
Coeff.append(-1.)
self.G = sparse.csc_matrix(
(Coeff, (Ieq, Var)), shape=(n_vars, n_vars), dtype=np.double)
self.h = np.zeros((n_vars,), dtype=np.double)
self.dims['e'] = n
# Objective function:
self.c = np.zeros((n_vars,), dtype=np.double)
for i, xyz in enumerate(self.trip_of_idx):
self.c[r_vidx(i)] = -1.
[docs] def solve(self):
# (c) Abdullah Makkeh, Dirk Oliver Theis
# Permission to use and modify under Apache License version 2.0
self.marg_yz = None # for cond[]mutinf computation below
if self.verbose is not None:
self.ecos_kwargs["verbose"] = self.verbose
solution = ecos.solve(self.c, self.G, self.h,
self.dims, self.A, self.b, **self.ecos_kwargs)
if 'x' in solution.keys():
self.sol_rpq = solution['x']
self.sol_slack = solution['s']
self.sol_lambda = solution['y']
self.sol_mu = solution['z']
self.sol_info = solution['info']
return "success"
else: # "x" not in dict solution
return "x not in dict solution -- No Solution Found!!!"
[docs] def provide_marginals(self):
if self.marg_yz == None:
self.marg_yz = dict()
self.marg_y = defaultdict(lambda: 0.)
self.marg_z = defaultdict(lambda: 0.)
for y in self.Y:
for z in self.Z:
zysum = 0.
for x in self.X:
if (x, y, z) in self.idx_of_trip.keys():
q = self.sol_rpq[
q_vidx(self.idx_of_trip[(x, y, z)])]
if q > 0:
zysum += q
self.marg_y[y] += q
self.marg_z[z] += q
if zysum > 0.:
self.marg_yz[(y, z)] = zysum
[docs] def condYmutinf(self):
self.provide_marginals()
mysum = 0.
for x in self.X:
for z in self.Z:
if not (x, z) in self.b_xz.keys():
continue
for y in self.Y:
if (x, y, z) in self.idx_of_trip.keys():
i = q_vidx(self.idx_of_trip[(x, y, z)])
q = self.sol_rpq[i]
if q > 0:
mysum += q * log(q * self.marg_y[y] / (
self.b_xy[(x, y)] * self.marg_yz[(y, z)]))
return mysum
[docs] def condZmutinf(self):
self.provide_marginals()
mysum = 0.
for x in self.X:
for y in self.Y:
if not (x, y) in self.b_xy.keys():
continue
for z in self.Z:
if (x, y, z) in self.idx_of_trip.keys():
i = q_vidx(self.idx_of_trip[(x, y, z)])
q = self.sol_rpq[i]
if q > 0:
mysum += q * log(
q * self.marg_z[z] / (
self.b_xz[(x, z)] * self.marg_yz[(y, z)]))
return mysum
[docs] def entropy_X(self, pdf):
mysum = 0.
for x in self.X:
psum = 0.
for y in self.Y:
if not (x, y) in self.b_xy:
continue
for z in self.Z:
if (x, y, z) in pdf.keys():
psum += pdf[(x, y, z)]
mysum -= psum * log(psum)
return mysum
[docs] def condentropy(self):
# compute cond entropy of the distribution in self.sol_rpq
mysum = 0.
for y in self.Y:
for z in self.Z:
marg_x = 0.
q_list = [q_vidx(self.idx_of_trip[ (x,y,z)]) for
x in self.X if (x, y, z) in self.idx_of_trip.keys()]
for i in q_list:
marg_x += max(0, self.sol_rpq[i])
for i in q_list:
q = self.sol_rpq[i]
if q > 0:
mysum -= q * log(q / marg_x)
return mysum
[docs] def condentropy__orig(self, pdf):
mysum = 0.
for y in self.Y:
for z in self.Z:
x_list = [x for x in self.X if (x, y, z) in pdf.keys()]
marg = 0.
for x in x_list:
marg += pdf[(x, y, z)]
for x in x_list:
p = pdf[(x, y, z)]
mysum -= p * log(p / marg)
return mysum
[docs] def dual_value(self):
return -np.dot(self.sol_lambda, self.b)
[docs] def check_feasibility(self): # returns pair (p,d) of primal/dual infeasibility (maxima)
# Primal infeasiblility
# ---------------------
max_q_negativity = 0.
for i in range(len(self.trip_of_idx)):
max_q_negativity = max(max_q_negativity, -self.sol_rpq[q_vidx(i)])
#^ for
max_violation_of_eqn = 0.
# xy* - marginals:
for xy in self.b_xy.keys():
mysum = self.b_xy[xy]
for z in self.Z:
x, y = xy
if (x, y, z) in self.idx_of_trip.keys():
i = self.idx_of_trip[(x, y, z)]
q = max(0., self.sol_rpq[q_vidx(i)])
mysum -= q
max_violation_of_eqn = max(max_violation_of_eqn, abs(mysum))
for xz in self.b_xz.keys():
mysum = self.b_xz[xz]
for y in self.Y:
x, z = xz
if (x, y, z) in self.idx_of_trip.keys():
i = self.idx_of_trip[(x, y, z)]
q = max(0., self.sol_rpq[q_vidx(i)])
mysum -= q
max_violation_of_eqn = max(max_violation_of_eqn, abs(mysum))
primal_infeasability = max(max_violation_of_eqn, max_q_negativity)
# Dual infeasiblility
# -------------------
idx_of_xy = dict()
i = 0
for x in self.X:
for y in self.Y:
if (x,y) in self.b_xy.keys():
idx_of_xy[(x,y)] = i
i += 1
#^ for
idx_of_xz = dict()
i = 0
for x in self.X:
for z in self.Z:
if (x,z) in self.b_xz.keys():
idx_of_xz[(x,z)] = i
i += 1
#^ for
dual_infeasability = 0.
for i,xyz in enumerate(self.trip_of_idx):
mu_yz = 0.
x,y,z = xyz
# Compute mu_*yz
# mu_xyz: dual variable of the coupling constraints
for j,uvw in enumerate(self.trip_of_idx):
u,v,w = uvw
if v == y and w == z:
mu_yz += self.sol_lambda[j]
# Get indices of dual variables of the marginal constriants
xy_idx = len(self.trip_of_idx) + idx_of_xy[(x,y)]
xz_idx = len(self.trip_of_idx) + len(self.b_xy) + idx_of_xz[(x,z)]
# Find the most violated dual ieq
dual_infeasability = max(
dual_infeasability,
-self.sol_lambda[xy_idx] - self.sol_lambda[xz_idx] - mu_yz -ln(-self.sol_lambda[i]) - 1)
#^ for
return primal_infeasability, dual_infeasability
#^ check_feasibility()
#^ class Solve_w_ECOS
[docs]def marginal_xy(p):
marg = dict()
for xyz, r in p.items():
x, y, z = xyz
if (x, y) in marg.keys():
marg[(x, y)] += r
else:
marg[(x, y)] = r
return marg
[docs]def marginal_xz(p):
marg = dict()
for xyz, r in p.items():
x, y, z = xyz
if (x, z) in marg.keys():
marg[(x, z)] += r
else:
marg[(x, z)] = r
return marg
[docs]def I_X_Y(p):
# Mutual information I( X ; Y )
mysum = 0.
marg_x = defaultdict(lambda: 0.)
marg_y = defaultdict(lambda: 0.)
b_xy = marginal_xy(p)
for xyz, r in p.items():
x, y, z = xyz
if r > 0:
marg_x[x] += r
marg_y[y] += r
for xy, t in b_xy.items():
x, y = xy
if t > 0:
mysum += t * log(t / (marg_x[x] * marg_y[y]))
return mysum
#^ I_X_Y()
[docs]def I_X_Z(p):
# Mutual information I( X ; Z )
mysum = 0.
marg_x = defaultdict(lambda: 0.)
marg_z = defaultdict(lambda: 0.)
b_xz = marginal_xz(p)
for xyz,r in p.items():
x,y,z = xyz
if r > 0 :
marg_x[x] += r
marg_z[z] += r
for xz,t in b_xz.items():
x,z = xz
if t > 0: mysum += t * log( t / ( marg_x[x]*marg_z[z] ) )
return mysum
#^ I_X_Z()
[docs]def I_X_YZ(p):
# Mutual information I( X ; Y , Z )
mysum = 0.
marg_x = defaultdict(lambda: 0.)
marg_yz = defaultdict(lambda: 0.)
for xyz,r in p.items():
x,y,z = xyz
if r > 0 :
marg_x[x] += r
marg_yz[(y,z)] += r
for xyz,t in p.items():
x,y,z = xyz
if t > 0: mysum += t * log( t / ( marg_x[x]*marg_yz[(y,z)] ) )
return mysum
#^ I_X_YZ()
[docs]def pid(pdf_dirty, cone_solver="ECOS", output=0, **solver_args):
# (c) Abdullah Makkeh, Dirk Oliver Theis
# Permission to use and modify under Apache License version 2.0
assert type(pdf_dirty) is dict, "broja_2pid.pid(pdf): pdf must be a dictionary"
assert type(cone_solver) is str, "broja_2pid.pid(pdf): `cone_solver' parameter must be string (e.g., 'ECOS')"
if __debug__:
for k,v in pdf_dirty.items():
assert type(k) is tuple or type(k) is list, "broja_2pid.pid(pdf): pdf's keys must be tuples or lists"
assert len(k)==3, "broja_2pid.pid(pdf): pdf's keys must be tuples/lists of length 3"
assert type(v) is float or ( type(v)==int and v==0 ), "broja_2pid.pid(pdf): pdf's values must be floats"
assert v > -.1, "broja_2pid.pid(pdf): pdf's values must not be negative"
#^ for
#^ if
assert type(output) is int, "broja_2pid.pid(pdf,output): output must be an integer"
# Check if the solver is implemented:
assert cone_solver=="ECOS", "broja_2pid.pid(pdf): We currently don't have an interface for the Cone Solver "+cone_solver+" (only ECOS)."
pdf = { k:v for k,v in pdf_dirty.items() if v > 1.e-300 }
by_xy = marginal_xy(pdf)
bz_xz = marginal_xz(pdf)
# if cone_solver=="ECOS": .....
if output > 0: print("BROJA_2PID: Preparing Cone Program data",end="...")
solver = Solve_w_ECOS(by_xy, bz_xz)
solver.create_model()
if output > 1: solver.verbose = True
ecos_keep_solver_obj = False
if 'keep_solver_object' in solver_args.keys():
if solver_args['keep_solver_object']==True: ecos_keep_solver_obj = True
del solver_args['keep_solver_object']
solver.ecos_kwargs = solver_args
if output > 0: print("done.")
if output == 1: print("BROJA_2PID: Starting solver",end="...")
if output > 1: print("BROJA_2PID: Starting solver.")
retval = solver.solve()
if retval != "success":
print("\nCone Programming solver failed to find (near) optimal "
"solution.\nPlease report the input probability density "
"function to abdullah.makkeh@gmail.com\n")
if ecos_keep_solver_obj:
return solver
else:
raise BROJA_2PID_Exception(
"BROJA_2PID_Exception: Cone Programming solver failed to find "
"(near) optimal solution. Please report the input probability "
"density function to abdullah.makkeh@gmail.com")
if output > 0: print("\nBROJA_2PID: done.")
if output > 1: print(solver.sol_info)
entropy_X = solver.entropy_X(pdf)
condent = solver.condentropy()
condent__orig = solver.condentropy__orig(pdf)
condYmutinf = solver.condYmutinf()
condZmutinf = solver.condZmutinf()
dual_val = solver.dual_value()
bits = 1/log(2)
# elsif cone_solver=="SCS":
# .....
# #^endif
return_data = dict()
return_data["SI"] = (entropy_X - condent -
condZmutinf - condYmutinf) * bits
return_data["UIY"] = (condZmutinf) * bits
return_data["UIZ"] = (condYmutinf) * bits
return_data["CI"] = (condent - condent__orig) * bits
primal_infeas,dual_infeas = solver.check_feasibility()
return_data["Num_err"] = (primal_infeas, dual_infeas, max(-condent*ln(2) - dual_val, 0.0))
return_data["Solver"] = "ECOS http://www.embotech.com/ECOS"
if ecos_keep_solver_obj:
return_data["Solver Object"] = solver
return return_data