Coverage for /home/deng/Projects/ete4/hackathon/ete4/ete4/core/arraytable.py: 20%

115 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2024-03-21 09:19 +0100

1import numpy as np 

2 

3from ..parser.text_arraytable import write_arraytable, read_arraytable 

4 

5__all__ = ["ArrayTable"] 

6 

7 

class ArrayTable:
    """Class to work with matrix datasets (like microarrays).

    It allows loading the matrix and easily accessing row and column vectors.

    Instances keep these attributes in sync (see ``_link_names2matrix``):

    * ``matrix`` -- the 2-D data array (``None`` until data is linked).
    * ``rowNames`` / ``colNames`` -- row / column names, in matrix order.
    * ``rowValues`` / ``colValues`` -- dicts mapping each row / column name
      to the slice ``matrix[i, :]`` / ``matrix[:, i]``.
    * ``mtype`` -- initialized to ``None`` here; presumably filled in by
      ``read_arraytable`` when a file is loaded (TODO confirm).
    """

    def __init__(self, matrix_file=None, mtype="float"):
        """Create an empty table, optionally populated from ``matrix_file``.

        :param matrix_file: If not None, passed to ``read_arraytable``,
            which fills this object in place (``arraytable_object=self``).
        :param mtype: Forwarded to ``read_arraytable`` as its ``mtype``.
        """
        self.colNames = []
        self.rowNames = []
        self.colValues = {}
        self.rowValues = {}
        self.matrix = None
        self.mtype = None

        if matrix_file is not None:
            read_arraytable(matrix_file, mtype=mtype, arraytable_object=self)

    def __repr__(self):
        return "ArrayTable (%s)" % hex(self.__hash__())

    def __str__(self):
        return str(self.matrix)

    def get_row_vector(self, rowname):
        """Return the vector of the given row name, or None if unknown."""
        return self.rowValues.get(rowname)

    def get_column_vector(self, colname):
        """Return the vector of the given column name, or None if unknown."""
        return self.colValues.get(colname)

    def get_several_row_vectors(self, rownames):
        """Return an array stacking the row vectors of the given names.

        :raises KeyError: If any name is not a known row.
        """
        return np.array([self.rowValues[rname] for rname in rownames])

    def get_several_column_vectors(self, colnames):
        """Return an array stacking the column vectors of the given names.

        Each requested column becomes one row of the returned array.

        :raises KeyError: If any name is not a known column.
        """
        return np.array([self.colValues[cname] for cname in colnames])

    def remove_column(self, colname):
        """Remove the given column from the current dataset, in place.

        Unknown column names are silently ignored.
        """
        if colname not in self.colValues:
            return

        index = self.colNames.index(colname)
        del self.colNames[index]
        # np.delete returns a new array without that column; relinking
        # rebuilds both name -> vector dicts against the new matrix.
        self._link_names2matrix(np.delete(self.matrix, index, axis=1))

    def merge_columns(self, groups, grouping_criterion):
        """Return a new ArrayTable with merged columns.

        The columns are merged (grouped) according to the given criterion.
        Columns not mentioned in any group are copied over unchanged, after
        the merged ones.

        :param groups: Dictionary in which keys are the new column
            names, and each value is the list of current column names
            to be merged.
        :param grouping_criterion: How to merge numeric values. Can be
            'min', 'max' or 'mean'.
        :raises ValueError: On an unsupported criterion, an unknown column,
            a column used in more than one group, or an empty group.

        Example::

            my_groups = {'NewColumn': ['column5', 'column6']}
            new_Array = Array.merge_columns(my_groups, 'max')
        """
        groupings = {'max': get_max_vector,
                     'min': get_min_vector,
                     'mean': get_mean_vector}
        try:
            grouping_f = groupings[grouping_criterion]
        except KeyError:
            valid = ' '.join(groupings)
            # Pure f-string: a '%' inside grouping_criterion must not be
            # interpreted as a format directive.
            raise ValueError(f'grouping_criterion "{grouping_criterion}" not '
                             f'supported. Valid ones: {valid}') from None

        grouped_array = self.__class__()
        grouped_matrix = []
        colNames = []
        alltnames = set()
        for gname, tnames in groups.items():
            all_vectors = []
            for tn in tnames:
                if tn not in self.colValues:
                    raise ValueError(f'column not found: {tn}')
                if tn in alltnames:
                    raise ValueError(
                        f'duplicated column name for merging: {tn}')
                alltnames.add(tn)
                all_vectors.append(self.get_column_vector(tn).astype(float))
            if not all_vectors:
                # Reducing an empty group would fail deep inside numpy.
                raise ValueError(f'empty group of columns: {gname}')
            # Collapse the group's columns into one vector using the
            # selected criterion (max, min or mean).
            grouped_matrix.append(grouping_f(all_vectors))
            colNames.append(gname)

        # Columns not used by any group are kept as they are.
        for cname in self.colNames:
            if cname not in alltnames:
                grouped_matrix.append(self.get_column_vector(cname))
                colNames.append(cname)

        # Copy the row names so the new table does not alias our list.
        grouped_array.rowNames = list(self.rowNames)
        grouped_array.colNames = colNames
        # grouped_matrix holds columns; transpose back to rows x columns.
        grouped_array._link_names2matrix(np.array(grouped_matrix).transpose())
        return grouped_array

    def transpose(self):
        """Return a new ArrayTable in which the current matrix is transposed.

        Row names become column names and vice versa. The name lists are
        copied, so the new table does not share them with this one.
        """
        transposed = self.__class__()
        transposed.colNames = list(self.rowNames)
        transposed.rowNames = list(self.colNames)
        transposed._link_names2matrix(self.matrix.transpose())
        return transposed

    def _link_names2matrix(self, m):
        """Set ``m`` as the matrix and relink row/column names to it.

        :raises ValueError: If the shape of ``m`` does not match the number
            of current row and column names.
        """
        n_rows = len(self.rowNames)
        n_cols = len(self.colNames)
        if n_rows != m.shape[0]:
            raise ValueError(f"Expecting matrix with {n_rows} rows, "
                             f"got {m.shape[0]}")
        if n_cols != m.shape[1]:
            raise ValueError(f"Expecting matrix with {n_cols} columns, "
                             f"got {m.shape[1]}")

        self.matrix = m
        self.colValues.clear()
        self.rowValues.clear()

        # Entries are basic slices of the matrix (views for numpy arrays),
        # not copies.
        for i, colname in enumerate(self.colNames):
            self.colValues[colname] = self.matrix[:, i]
        for i, rowname in enumerate(self.rowNames):
            self.rowValues[rowname] = self.matrix[i, :]

    def write(self, fname, colnames=None):
        """Write the table to ``fname`` using ``write_arraytable``.

        :param colnames: Forwarded unchanged to ``write_arraytable``.
        """
        write_arraytable(self, fname, colnames=colnames)

160 

161 

def get_centroid_dist(vcenter, vlist, fdist):
    """Return twice the mean distance from the members of ``vlist`` to ``vcenter``.

    ``fdist(member, vcenter)`` is called once per member. An empty ``vlist``
    raises ZeroDivisionError.
    """
    total = sum(fdist(member, vcenter) for member in vlist)
    return 2 * total / len(vlist)

164 

165 

def get_average_centroid_linkage_dist(vcenter1, vlist1,
                                      vcenter2, vlist2, fdist):
    """Return the average cross distance between two clusters.

    Every member of ``vlist1`` is measured against ``vcenter2`` and every
    member of ``vlist2`` against ``vcenter1``, as ``fdist(member, center)``.
    The sum of all those distances is divided by the total member count.
    """
    cross_1_to_2 = sum(fdist(member, vcenter2) for member in vlist1)
    cross_2_to_1 = sum(fdist(member, vcenter1) for member in vlist2)
    n_members = len(vlist1) + len(vlist2)
    return (cross_1_to_2 + cross_2_to_1) / n_members

172 

173 

def safe_mean(values):
    """Return ``(mean, std)`` of ``values``, ignoring NaN and infinities."""
    finite = list(filter(np.isfinite, values))
    return np.mean(finite), np.std(finite)

178 

179 

def safe_mean_vector(vectors):
    """Return per-position (mean, std) profiles, discarding non-finite values.

    :param vectors: Sequence of vectors; the number of positions is taken
        from the first one. An empty sequence raises IndexError.
    :returns: A pair ``(means, stds)``. With a single vector, that vector is
        returned unchanged together with ``np.zeros(len(vector))``.
        Otherwise both are plain lists holding, for each position, ``np.mean``
        and ``np.std`` of the finite values found at that position.
    """
    # A single vector is its own average, with zero spread.
    if len(vectors) == 1:
        return vectors[0], np.zeros(len(vectors[0]))

    # Named means/stds so the module-level safe_mean() is not shadowed.
    means = []
    stds = []
    for i in range(len(vectors[0])):  # vector length from the first item
        finite = [v[i] for v in vectors if np.isfinite(v[i])]
        means.append(np.mean(finite))
        stds.append(np.std(finite))

    return means, stds

195 

196 

def get_mean_vector(vlist):
    """Return the element-wise mean of the vectors in ``vlist`` (reduce axis 0)."""
    return np.mean(vlist, axis=0)

199 

def get_max_vector(vlist):
    """Return the element-wise maximum of the vectors in ``vlist`` (reduce axis 0)."""
    return np.max(vlist, axis=0)

202 

def get_min_vector(vlist):
    """Return the element-wise minimum of the vectors in ``vlist`` (reduce axis 0)."""
    return np.min(vlist, axis=0)