# contextualized/dags/graph_utils.py

import igraph as ig
import numpy as np
import torch


def dag_pred_with_factors(X, W, P):
    """
    Pass observations X through a linear SEM defined by the low-dimensional
    network W and the factor loadings P.
    """
    # For a linear SEM with p factors, X is [n, d], W is [p, p], P is [p, d]:
    # X_hat = X P^T W (P / rowsum(P))
    # where the normalization divides each row of P by its sum.
    return torch.matmul(torch.matmul(torch.matmul(X, P.T), W), (P.T / P.sum(axis=1)).T)
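
# A minimal usage sketch of dag_pred_with_factors; the shapes and random
# values below are illustrative assumptions, not part of the original module.
def _demo_dag_pred_with_factors():
    n, d, p = 8, 6, 2                   # samples, variables, factors (assumed)
    X = torch.randn(n, d)               # observations
    W = torch.randn(p, p)               # low-dimensional network
    P = torch.rand(p, d) + 0.1          # positive factor loadings
    X_hat = dag_pred_with_factors(X, W, P)
    assert X_hat.shape == (n, d)        # predictions live in observation space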

# Batched SEM predictions: each row of X ([n, d]) is multiplied by an adjacency
# matrix W, either shared across samples ([d, d]) or per-sample ([n, d, d]).
dag_pred = lambda X, W: torch.matmul(X.unsqueeze(1), W).squeeze(1)
dag_pred_np = lambda x, w: np.matmul(x[:, np.newaxis, :], w).squeeze()
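
# A small sketch of the batched predictors above (hypothetical shapes):
def _demo_dag_pred():
    X = torch.randn(4, 3)
    W_batch = torch.randn(4, 3, 3)      # one adjacency matrix per sample
    assert dag_pred(X, W_batch).shape == (4, 3)
    x_np = np.random.randn(4, 3)
    w_np = np.random.randn(4, 3, 3)
    assert dag_pred_np(x_np, w_np).shape == (4, 3)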

def simulate_linear_sem(W, n_samples, sem_type, noise_scale=None):
    """Simulate samples from a linear SEM with the specified type of noise.

    For uniform noise, z ~ uniform(-a, a), where a = noise_scale.

    Args:
        W (np.ndarray): [d, d] weighted adjacency matrix of a DAG
        n_samples (int): number of samples; n_samples=inf mimics population risk
        sem_type (str): gauss, exp, gumbel, uniform, logistic, poisson
        noise_scale (float or np.ndarray): scale parameter of additive noise,
            defaults to all ones

    Returns:
        X (np.ndarray): [n_samples, d] sample matrix, [d, d] if n_samples=inf
    """

    def _simulate_single_equation(X, w, scale):
        """X: [n, num of parents], w: [num of parents]; returns x: [n]."""
        if sem_type == "gauss":
            z = np.random.normal(scale=scale, size=n_samples)
            x = np.matmul(X, w) + z
        elif sem_type == "exp":
            z = np.random.exponential(scale=scale, size=n_samples)
            x = np.matmul(X, w) + z
        elif sem_type == "gumbel":
            z = np.random.gumbel(scale=scale, size=n_samples)
            x = np.matmul(X, w) + z
        elif sem_type == "uniform":
            z = np.random.uniform(low=-scale, high=scale, size=n_samples)
            x = np.matmul(X, w) + z
        elif sem_type == "logistic":
            x = np.random.binomial(1, 1 / (1 + np.exp(-np.matmul(X, w)))) * 1.0
        elif sem_type == "poisson":
            x = np.random.poisson(np.exp(np.matmul(X, w))) * 1.0
        else:
            raise ValueError("unknown sem type")
        return x

    d = W.shape[0]
    if noise_scale is None:
        scale_vec = np.ones(d)
    elif np.isscalar(noise_scale):
        scale_vec = noise_scale * np.ones(d)
    else:
        if len(noise_scale) != d:
            raise ValueError("noise scale must be a scalar or have length d")
        scale_vec = noise_scale
    if not is_dag(W):
        raise ValueError("W must be a DAG")
    if np.isinf(n_samples):  # population risk for linear gauss SEM
        if sem_type == "gauss":
            # make 1/d X'X = true cov
            X = np.sqrt(d) * np.matmul(np.diag(scale_vec), np.linalg.inv(np.eye(d) - W))
            return X
        else:
            raise ValueError("population risk not available")
    # empirical risk: simulate each variable in topological order so that all
    # parents are sampled before their children
    G = ig.Graph.Weighted_Adjacency(W.tolist())
    ordered_vertices = G.topological_sorting()
    assert len(ordered_vertices) == d
    X = np.zeros([n_samples, d])
    for j in ordered_vertices:
        parents = G.neighbors(j, mode=ig.IN)
        X[:, j] = _simulate_single_equation(X[:, parents], W[parents, j], scale_vec[j])
    return X
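
# A minimal usage sketch of simulate_linear_sem; the 3-node chain DAG below is
# an illustrative assumption.
def _demo_simulate_linear_sem():
    W_chain = np.array([
        [0.0, 0.9, 0.0],
        [0.0, 0.0, -1.1],
        [0.0, 0.0, 0.0],
    ])  # DAG: 0 -> 1 -> 2
    X = simulate_linear_sem(W_chain, n_samples=500, sem_type="gauss")
    assert X.shape == (500, 3)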

def break_symmetry(w):
    """Zero the diagonal of w and, for each pair (i, j), zero the
    smaller-magnitude direction so that no 2-cycles remain."""
    for i in range(w.shape[0]):
        w[i][i] = 0.0
        for j in range(i):
            if np.abs(w[i][j]) > np.abs(w[j][i]):
                w[j][i] = 0.0
            else:
                w[i][j] = 0.0
    return w
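
# A tiny sketch of break_symmetry (values are illustrative):
def _demo_break_symmetry():
    w = np.array([[0.5, 0.2],
                  [-0.7, 0.0]])
    out = break_symmetry(w.copy())
    # |w[1][0]| > |w[0][1]|, so the weaker edge 0 -> 1 is dropped, and the
    # diagonal is zeroed.
    assert out[0][1] == 0.0 and out[1][0] == -0.7 and out[0][0] == 0.0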

# w is the weighted adjacency matrix (a NumPy array, despite the name)
def project_to_dag_torch(w):
    if is_dag(w):
        return w, 0.0

    w_dag = w.copy()
    w_dag = break_symmetry(w_dag)

    vals = sorted(set(np.abs(w_dag).flatten()))
    low = 0
    high = len(vals) - 1

    def binary_search(arr, low, high, w):  # low and high are indices
        # Check base case
        if high == low:
            return high
        if high > low:
            mid = (high + low) // 2
            if mid == 0:
                return mid
            result = trim_params(w, arr[mid])
            if is_dag(result):
                result2 = trim_params(w, arr[mid - 1])
                if is_dag(result2):  # middle value is too high; go lower
                    return binary_search(arr, low, mid - 1, w)
                else:
                    return mid  # found it
            else:  # middle value is too low; go higher
                return binary_search(arr, mid + 1, high, w)
        else:
            # Element is not present in the array
            print("this should be impossible")
            return -1

    idx = binary_search(vals, low, high, w_dag) + 1
    thresh = vals[idx]
    w_dag = trim_params(w_dag, thresh)

    # Now add back edges with weights at or below the threshold that don't
    # violate DAG-ness, visiting candidate edges (i, j) in decreasing order
    # of the original weight magnitudes.
    all_vals = np.abs(w).flatten()
    idxs_sorted = reversed(np.argsort(all_vals))
    for idx in idxs_sorted:
        i = idx // w_dag.shape[1]
        j = idx % w_dag.shape[1]
        if np.abs(w[i][j]) > thresh:  # already retained
            continue
        w_dag[i][j] = w[i][j]
        if not is_dag(w_dag):
            w_dag[i][j] = 0.0

    assert is_dag(w_dag)
    return w_dag, thresh
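
# A small sketch of the projection; the 2-cycle matrix is an illustrative
# assumption.
def _demo_project_to_dag():
    w_cyclic = np.array([[0.0, 0.8],
                         [0.3, 0.0]])  # 2-cycle: 0 <-> 1
    w_dag, thresh = project_to_dag_torch(w_cyclic)
    assert is_dag(w_dag)
    assert w_dag[0][1] == 0.8 and w_dag[1][0] == 0.0  # stronger edge survives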

def is_dag(W):
    """Return True if the weighted adjacency matrix W encodes a DAG."""
    G = ig.Graph.Weighted_Adjacency(W.tolist())
    return G.is_dag()


def trim_params(w, thresh=0.2):
    """Zero out all entries of w with absolute value at or below thresh."""
    return w * (np.abs(w) > thresh)
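
# A tiny sketch of trim_params and is_dag together (values illustrative):
def _demo_trim_params():
    w = np.array([[0.0, 0.5],
                  [0.1, 0.0]])
    trimmed = trim_params(w, thresh=0.2)  # drops the 0.1 entry
    assert trimmed[0][1] == 0.5 and trimmed[1][0] == 0.0
    assert is_dag(trimmed)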