Coverage for contextualized/dags/graph_utils.py: 77% (102 statements)
import torch
import igraph as ig
import numpy as np


def dag_pred_with_factors(X, W, P):
    """
    Pass observations X through a linear SEM defined by the low-dimensional
    network W over latent factors and the factor loading matrix P.
    """
    # For a linear SEM with k factors, X is [n, d], W is [k, k], P is [k, d]:
    #   X_pred = X P^T W P_norm
    # where P_norm is P with each row divided by its sum (its L1 norm for
    # nonnegative P).
    return torch.matmul(torch.matmul(torch.matmul(X, P.T), W), (P.T / P.sum(axis=1)).T)
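

# Hedged usage sketch (not part of the original module): checks the shapes
# that dag_pred_with_factors expects. n, d, k and all values are illustrative.
def _example_dag_pred_with_factors():
    n, d, k = 8, 5, 2  # n samples, d variables, k latent factors
    X = torch.randn(n, d)
    W = torch.randn(k, k)  # low-dimensional network over the factors
    P = torch.rand(k, d) + 0.1  # positive entries keep row sums nonzero
    X_pred = dag_pred_with_factors(X, W, P)
    assert X_pred.shape == (n, d)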


def dag_pred(X, W):
    """Pass observations X [n, d] through the linear SEM W; W may be a single
    [d, d] network or, via broadcasting, a batch of [n, d, d] networks."""
    return torch.matmul(X.unsqueeze(1), W).squeeze(1)


def dag_pred_np(x, w):
    """NumPy analogue of dag_pred; squeezes only the inserted axis so a
    single-sample batch keeps its leading dimension."""
    return np.matmul(x[:, np.newaxis, :], w).squeeze(1)
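

# Hedged usage sketch (illustrative only): torch.matmul broadcasts, so W can
# be one shared [d, d] network or a batch of [n, d, d] per-sample networks.
def _example_dag_pred():
    n, d = 8, 5
    X = torch.randn(n, d)
    W_shared = torch.randn(d, d)
    W_batched = torch.randn(n, d, d)  # one network per observation
    assert dag_pred(X, W_shared).shape == (n, d)
    assert dag_pred(X, W_batched).shape == (n, d)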


def simulate_linear_sem(W, n_samples, sem_type, noise_scale=None):
    """Simulate samples from a linear SEM with the specified type of noise.

    For uniform noise, z ~ uniform(-a, a), where a = noise_scale.

    Args:
        W (np.ndarray): [d, d] weighted adjacency matrix of a DAG
        n_samples (int): number of samples; n_samples=inf mimics population risk
        sem_type (str): gauss, exp, gumbel, uniform, logistic, poisson
        noise_scale (np.ndarray): scale parameter of additive noise, default all ones

    Returns:
        X (np.ndarray): [n_samples, d] sample matrix, [d, d] if n_samples=inf
    """

    def _simulate_single_equation(X, w, scale):
        """X: [n, num of parents], w: [num of parents], returns x: [n]."""
        if sem_type == "gauss":
            z = np.random.normal(scale=scale, size=n_samples)
            x = X @ w + z
        elif sem_type == "exp":
            z = np.random.exponential(scale=scale, size=n_samples)
            x = X @ w + z
        elif sem_type == "gumbel":
            z = np.random.gumbel(scale=scale, size=n_samples)
            x = X @ w + z
        elif sem_type == "uniform":
            z = np.random.uniform(low=-scale, high=scale, size=n_samples)
            x = X @ w + z
        elif sem_type == "logistic":
            x = np.random.binomial(1, 1 / (1 + np.exp(-(X @ w)))) * 1.0
        elif sem_type == "poisson":
            x = np.random.poisson(np.exp(X @ w)) * 1.0
        else:
            raise ValueError("unknown sem type")
        return x

    d = W.shape[0]
    if noise_scale is None:
        scale_vec = np.ones(d)
    elif np.isscalar(noise_scale):
        scale_vec = noise_scale * np.ones(d)
    else:
        if len(noise_scale) != d:
            raise ValueError("noise_scale must be a scalar or have length d")
        scale_vec = noise_scale
    if not is_dag(W):
        raise ValueError("W must be a DAG")
    if np.isinf(n_samples):  # population risk for linear Gaussian SEM
        if sem_type == "gauss":
            # make (1/d) X'X equal the true covariance
            X = np.sqrt(d) * np.diag(scale_vec) @ np.linalg.inv(np.eye(d) - W)
            return X
        else:
            raise ValueError("population risk not available")
    # empirical risk
    G = ig.Graph.Weighted_Adjacency(W.tolist())
    ordered_vertices = G.topological_sorting()
    assert len(ordered_vertices) == d
    X = np.zeros([n_samples, d])
    for j in ordered_vertices:
        parents = G.neighbors(j, mode=ig.IN)
        X[:, j] = _simulate_single_equation(X[:, parents], W[parents, j], scale_vec[j])
    return X
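

# Hedged usage sketch (values are illustrative): simulate Gaussian-noise data
# from a hypothetical 3-node chain DAG 0 -> 1 -> 2.
def _example_simulate_linear_sem():
    W = np.array([[0.0, 1.5, 0.0],
                  [0.0, 0.0, -0.8],
                  [0.0, 0.0, 0.0]])
    X = simulate_linear_sem(W, n_samples=100, sem_type="gauss")
    assert X.shape == (100, 3)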


def break_symmetry(w):
    """Zero the diagonal of w and, for each off-diagonal pair, zero the
    smaller-magnitude direction so only one edge per pair survives.
    Operates in place and returns w."""
    for i in range(w.shape[0]):
        w[i][i] = 0.0
        for j in range(i):
            if np.abs(w[i][j]) > np.abs(w[j][i]):
                w[j][i] = 0.0
            else:
                w[i][j] = 0.0
    return w
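

# Hedged worked example (hypothetical weights): of the symmetric pair
# (0, 1)/(1, 0), only the larger-magnitude edge survives, and the diagonal
# is always cleared. Note that break_symmetry mutates its argument.
def _example_break_symmetry():
    w = np.array([[0.3, 0.9],
                  [0.2, 0.0]])
    out = break_symmetry(w.copy())
    assert out[0][0] == 0.0  # diagonal cleared
    assert out[1][0] == 0.0  # |0.2| <= |0.9|, weaker direction zeroed
    assert out[0][1] == 0.9  # stronger direction kept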


# w is the weighted adjacency matrix
def project_to_dag_torch(w):
    if is_dag(w):
        return w, 0.0

    w_dag = w.copy()
    w_dag = break_symmetry(w_dag)

    # Candidate thresholds: the distinct absolute edge weights, ascending.
    vals = sorted(list(set(np.abs(w_dag).flatten())))
    low = 0
    high = len(vals) - 1

    def binary_search(arr, low, high, w):  # low and high are indices
        # Base case
        if high == low:
            return high
        if high > low:
            mid = (high + low) // 2
            if mid == 0:
                return mid
            result = trim_params(w, arr[mid])
            if is_dag(result):
                result2 = trim_params(w, arr[mid - 1])
                if is_dag(result2):  # middle value is too high; go lower.
                    return binary_search(arr, low, mid - 1, w)
                else:
                    return mid  # smallest threshold that induces a DAG
            else:  # middle value is too low; go higher.
                return binary_search(arr, mid + 1, high, w)
        else:
            # Element is not present in the array
            print("this should be impossible")
            return -1

    idx = binary_search(vals, low, high, w_dag) + 1
    thresh = vals[idx]
    w_dag = trim_params(w_dag, thresh)

    # Add back edges with weights below the threshold that don't violate
    # DAG-ness, visiting edges (i, j) in decreasing order of |w[i][j]|.
    all_vals = np.abs(w).flatten()
    idxs_sorted = reversed(np.argsort(all_vals))
    for idx in idxs_sorted:
        i = idx // w_dag.shape[1]
        j = idx % w_dag.shape[1]
        if np.abs(w[i][j]) > thresh:  # already retained
            continue
        w_dag[i][j] = w[i][j]
        if not is_dag(w_dag):
            w_dag[i][j] = 0.0

    assert is_dag(w_dag)
    return w_dag, thresh
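

# Hedged usage sketch (hypothetical weights): projecting a 2-cycle keeps the
# stronger edge and returns the threshold that was used for trimming.
def _example_project_to_dag():
    w = np.array([[0.0, 0.7],
                  [0.4, 0.0]])
    w_dag, thresh = project_to_dag_torch(w)
    assert is_dag(w_dag)
    assert w_dag[0][1] == 0.7 and w_dag[1][0] == 0.0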


def is_dag(W):
    """Return True if the weighted adjacency matrix W encodes a DAG."""
    G = ig.Graph.Weighted_Adjacency(W.tolist())
    return G.is_dag()


def trim_params(w, thresh=0.2):
    """Zero out entries of w whose absolute value is at or below thresh."""
    return w * (np.abs(w) > thresh)