Coverage for /home/deng/Projects/ete4/hackathon/ete4/ete4/clustering/clustertree.py: 30%

88 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2024-03-21 09:19 +0100

1from sys import stderr 

2from . import clustvalidation 

3from ete4 import Tree, ArrayTable 

4import numpy 

5 

6__all__ = ["ClusterTree"] 

7 

8 

9class ClusterTree(Tree): 

10 """ 

11 A ClusterTree is a Tree that represents a clustering result. 

12 """ 

13 

14 def _set_forbidden(self, value): 

15 raise ValueError("This attribute can not be manually set.") 

16 

17 def _get_intra(self): 

18 if self._silhouette is None: 

19 self.get_silhouette() 

20 return self._intracluster_dist 

21 

22 def _get_inter(self): 

23 if self._silhouette is None: 

24 self.get_silhouette() 

25 return self._intercluster_dist 

26 

27 def _get_silh(self): 

28 if self._silhouette is None: 

29 self.get_silhouette() 

30 return self._silhouette 

31 

32 def _get_prof(self): 

33 if self._profile is None: 

34 self._calculate_avg_profile() 

35 return self._profile 

36 

37 def _get_std(self): 

38 if self._std_profile is None: 

39 self._calculate_avg_profile() 

40 return self._std_profile 

41 

42 def _set_profile(self, value): 

43 self._profile = value 

44 

45 intracluster_dist = property(fget=_get_intra, fset=_set_forbidden) 

46 intercluster_dist = property(fget=_get_inter, fset=_set_forbidden) 

47 silhouette = property(fget=_get_silh, fset=_set_forbidden) 

48 profile = property(fget=_get_prof, fset=_set_profile) 

49 deviation = property(fget=_get_std, fset=_set_forbidden) 

50 

51 def __init__(self, data=None, children=None, text_array=None, 

52 fdist=clustvalidation.default_dist): 

53 # Default dist is spearman_dist when scipy module is loaded 

54 # otherwise, it is set to euclidean_dist. 

55 

56 # Initialize basic tree features and loads the newick (if any) 

57 Tree.__init__(self, data, children) 

58 self._fdist = None 

59 self._silhouette = None 

60 self._intercluster_dist = None 

61 self._intracluster_dist = None 

62 self._profile = None 

63 self._std_profile = None 

64 

65 # Cluster especific features 

66 # self.features.add("intercluster_dist") 

67 # self.features.add("intracluster_dist") 

68 # self.features.add("silhouette") 

69 # self.features.add("profile") 

70 # self.features.add("deviation") 

71 

72 # Initialize tree with array data 

73 if text_array: 

74 self.link_to_arraytable(text_array) 

75 

76 if data: 

77 self.set_distance_function(fdist) 

78 

79 def __repr__(self): 

80 return "ClusterTree node (%s)" %hex(self.__hash__()) 

81 

82 def set_distance_function(self, fn): 

83 """Set the distance function used to calculate cluster 

84 distances and silouette index. 

85 

86 :param fn: Function acepting two numpy arrays as arguments. 

87 

88 Example::: 

89 

90 # Set a simple euclidean distance. 

91 my_dist_fn = lambda x,y: abs(x-y) 

92 tree.set_distance_function(my_dist_fn) 

93 """ 

94 for n in self.traverse(): 

95 n._fdist = fn 

96 n._silhouette = None 

97 n._intercluster_dist = None 

98 n._intracluster_dist = None 

99 

100 def link_to_arraytable(self, arraytbl): 

101 """Link the given arraytable to the tree and return a list of 

102 nodes for with profiles could not been found in arraytable. 

103 

104 Row names in the arraytable object are expected to match leaf 

105 names. 

106 """ 

107 # Initialize tree with array data 

108 

109 if type(arraytbl) == ArrayTable: 

110 array = arraytbl 

111 else: 

112 array = ArrayTable(arraytbl) 

113 

114 missing_leaves = [] 

115 matrix_values = [i for r in range(len(array.matrix))\ 

116 for i in array.matrix[r] if numpy.isfinite(i)] 

117 

118 array._matrix_min = min(matrix_values) 

119 array._matrix_max = max(matrix_values) 

120 

121 for n in self.traverse(): 

122 n.arraytable = array 

123 if n.is_leaf and n.name in array.rowNames: 

124 n._profile = array.get_row_vector(n.name) 

125 elif n.is_leaf: 

126 n._profile = [numpy.nan]*len(array.colNames) 

127 missing_leaves.append(n) 

128 

129 

130 if len(missing_leaves)>0: 

131 print("""[%d] leaf names could not be mapped to the matrix rows.""" %\ 

132 len(missing_leaves), file=stderr) 

133 

134 self.arraytable = array 

135 

136 def leaf_profiles(self): 

137 """Yield profiles associated to the leaves under this node.""" 

138 for l in self.leaves(): 

139 yield l.get_profile()[0] 

140 

141 def get_silhouette(self, fdist=None): 

142 """Calculates the node's silhouette value by using a given 

143 distance function. By default, euclidean distance is used. It 

144 also calculates the deviation profile, mean profile, and 

145 inter/intra-cluster distances. 

146 

147 It sets the following features into the analyzed node: 

148 - node.intracluster 

149 - node.intercluster 

150 - node.silhouete 

151 

152 Intracluster distances a(i) are calculated as the Centroid Diameter. 

153 

154 Intercluster distances b(i) are calculated as the Centroid 

155 linkage distance. 

156 

157 :Citation: 

158 

159 *Rousseeuw, P.J. (1987) Silhouettes: A graphical aid to the 

160 interpretation and validation of cluster analysis.* 

161 

162 J. Comput. Appl. Math., 20, 53-65. 

163 """ 

164 if fdist is None: 

165 fdist = self._fdist 

166 

167 # Updates internal values 

168 self._silhouette, self._intracluster_dist, self._intercluster_dist = \ 

169 clustvalidation.get_silhouette_width(fdist, self) 

170 # And returns them 

171 return self._silhouette, self._intracluster_dist, self._intercluster_dist 

172 

173 def get_dunn(self, clusters, fdist=None): 

174 """ Calculates the Dunn index for the given set of descendant 

175 nodes. 

176 """ 

177 

178 if fdist is None: 

179 fdist = self._fdist 

180 nodes = self._translate_nodes(clusters) 

181 return clustvalidation.get_dunn_index(fdist, *nodes) 

182 

183 def _calculate_avg_profile(self): 

184 """ This internal function updates the mean profile 

185 associated to an internal node. """ 

186 

187 # Updates internal values 

188 self._profile, self._std_profile = clustvalidation.get_avg_profile(self)