Coverage for src/hdmf/common/hierarchicaltable.py: 100%

93 statements  

« prev     ^ index     » next       coverage.py v7.2.5, created at 2023-07-10 23:48 +0000

1""" 

2Module providing additional functionality for dealing with hierarchically nested tables, i.e., 

3tables containing DynamicTableRegion references. 

4""" 

5import pandas as pd 

6import numpy as np 

7from hdmf.common.table import DynamicTable, DynamicTableRegion, VectorIndex 

8from hdmf.common.alignedtable import AlignedDynamicTable 

9from hdmf.utils import docval, getargs 

10 

11 

@docval({'name': 'dynamic_table', 'type': DynamicTable,
         'doc': 'DynamicTable object to be converted to a hierarchical pandas.Dataframe'},
        returns="Hierarchical pandas.DataFrame with usually a pandas.MultiIndex on both the index and columns.",
        rtype='pandas.DataFrame',
        is_method=False)
def to_hierarchical_dataframe(dynamic_table):
    """
    Create a hierarchical pandas.DataFrame that represents all data from a collection of linked DynamicTables.

    **LIMITATIONS:** Currently this function only supports DynamicTables with a single DynamicTableRegion column.
    If a table has more than one DynamicTableRegion column then the function will expand only the
    first DynamicTableRegion column found for each table. Any additional DynamicTableRegion columns will remain
    nested.

    **NOTE:** Some useful functions for further processing of the generated
    DataFrame include:

    * pandas.DataFrame.reset_index to turn the data from the pandas.MultiIndex into columns
    * :py:meth:`~hdmf.common.hierarchicaltable.drop_id_columns` to remove all 'id' columns
    * :py:meth:`~hdmf.common.hierarchicaltable.flatten_column_index` to flatten the column index

    The values of the source table (its id and all columns other than the expanded
    DynamicTableRegion column) become levels of the row index; the columns of the
    referenced table(s) become the columns of the output. If the table has no
    DynamicTableRegion column, the result of ``dynamic_table.to_dataframe()`` is returned.
    """
    # TODO: Need to deal with the case where we have more than one DynamicTableRegion column in a given table
    # Get the references column
    foreign_columns = dynamic_table.get_foreign_columns()
    # If the table does not contain any DynamicTableRegion columns then we can just convert it to a dataframe
    if len(foreign_columns) == 0:
        return dynamic_table.to_dataframe()
    hcol_name = foreign_columns[0]  # We only denormalize the first foreign column for now
    hcol = dynamic_table[hcol_name]  # Either a VectorIndex pointing to a DynamicTableRegion or a DynamicTableRegion
    # Get the target DynamicTable that hcol is pointing to. If hcol is a VectorIndex then we first need
    # to get the target of it before we look up the table.
    hcol_target = hcol.table if isinstance(hcol, DynamicTableRegion) else hcol.target.table

    # Containers collecting the rows of the output dataframe and the matching index tuples
    index = []
    data = []

    # Get a list of DataFrames, one for each row of the column we need to process.
    # If hcol is a VectorIndex (i.e., our column is a ragged array of row indices), then simply loading
    # the data from the VectorIndex will do the trick. If we have a regular DynamicTableRegion column,
    # then we need to load the elements ourselves (using slice syntax to make sure we get DataFrames)
    # one-row-at-a-time
    if isinstance(hcol, VectorIndex):
        rows = hcol.get(slice(None), index=False, df=True)
    else:
        rows = [hcol[i:(i+1)] for i in range(len(hcol))]

    # Retrieve the columns we need to iterate over from our input table. For AlignedDynamicTable we need to
    # use the get_colnames function instead of the colnames property to ensure we get all columns not just
    # the columns from the main table
    dynamic_table_colnames = (dynamic_table.get_colnames(include_category_tables=True, ignore_category_ids=False)
                              if isinstance(dynamic_table, AlignedDynamicTable)
                              else dynamic_table.colnames)
    # All columns of the source table except the hierarchical column we are flattening.
    # These (together with the row id) form the leading levels of the output index.
    index_colnames = [colname for colname in dynamic_table_colnames if colname != hcol_name]
    # Index level names contributed by the source table. This is needed even if the table is empty.
    base_index_names = ([(dynamic_table.name, 'id')] +
                        [(dynamic_table.name, colname) for colname in index_colnames])

    def _index_prefix(row_index):
        """Read the id and all non-hierarchical column values of one source row."""
        return ([dynamic_table.id[row_index], ] +
                [dynamic_table[row_index, colname] for colname in index_colnames])

    # Case 1: Our DynamicTableRegion column points to a DynamicTable that itself does not contain
    #         any DynamicTableRegion references (i.e., we have reached the end of our table hierarchy).
    #         If this is the case then we need to de-normalize the data and flatten the hierarchy
    if not hcol_target.has_foreign_columns():
        # Iterate over all rows, where each row is described by a DataFrame with one-or-more rows
        for row_index, row_df in enumerate(rows):
            if len(row_df) == 0:
                # No referenced rows means no output rows, so skip reading the source row values
                continue
            # The index prefix is identical for every target row of this source row, so the
            # source table is read only once per source row rather than once per output row
            index_tuple = tuple(_index_prefix(row_index))
            # Each selected row from our target table becomes a row in our flattened table
            for row in row_df.itertuples(index=True):
                data.append(row)
                index.append(index_tuple)

        index_names = base_index_names
        # Determine the name of our columns. For AlignedDynamicTable we need get_colnames() to
        # make sure we include the category table columns as well.
        hcol_iter_columns = (hcol_target.get_colnames(include_category_tables=True, ignore_category_ids=False)
                             if isinstance(hcol_target, AlignedDynamicTable)
                             else hcol_target.colnames)
        columns = pd.MultiIndex.from_tuples([(hcol_target.name, 'id'), ] +
                                            [(hcol_target.name, c) for c in hcol_iter_columns],
                                            names=('source_table', 'label'))

    # Case 2: Our DynamicTableRegion column points to another table with a DynamicTableRegion, i.e.,
    #         we need to recursively resolve more levels of the table hierarchy
    else:
        # First we need to recursively flatten the hierarchy by calling 'to_hierarchical_dataframe()'
        # (i.e., this function) on the target of our hierarchical column
        hcol_hdf = to_hierarchical_dataframe(hcol_target)
        # Iterate over all rows, where each row is described by a DataFrame with one-or-more rows
        for row_index, row_df_level1 in enumerate(rows):
            if len(row_df_level1) == 0:
                # No referenced rows means no output rows, so skip reading the source row values
                continue
            # Read the source row values once; they are shared by all rows expanded below
            index_prefix = _index_prefix(row_index)
            for row_df_level2 in row_df_level1.itertuples(index=True):
                # Since our target is itself a DynamicTable with a DynamicTableRegion column,
                # each target row itself may expand into multiple rows in the flattened hcol_hdf.
                # So we now need to look up the rows in hcol_hdf that correspond to the row
                # in row_df_level2 (element 0 of the tuple is the row's index value).
                # NOTE: In this look-up we assume that the ids (and hence the index) of
                #       each row in the table are in fact unique.
                for row_tuple_level3 in hcol_hdf.loc[[row_df_level2[0]]].itertuples(index=True):
                    # Element 0 is the (multi-)index of the expanded row, the rest is the column data
                    data.append(row_tuple_level3[1:])
                    index.append(tuple(index_prefix + list(row_tuple_level3[0])))
        # Determine the names for our index and columns of our output table
        # We need to do this even if our table was empty (i.e. even if len(rows)==0)
        index_names = base_index_names + hcol_hdf.index.names
        columns = hcol_hdf.columns

    # Check if the index contains any unhashable types. If a table contains a VectorIndex column
    # (other than the DynamicTableRegion column) then "TypeError: unhashable type: 'list'" will
    # occur when converting the index to pd.MultiIndex. To avoid this error, we next check if any
    # of the columns in our index are of type list or np.ndarray. Only the first index entry is inspected.
    unhashable_index_cols = []
    if index:
        unhashable_index_cols = [i for i, v in enumerate(index[0]) if isinstance(v, (list, np.ndarray))]

    # If we have any unhashable list or np.array objects in the index then update them to tuples.
    # Ideally we would detect this case when constructing the index, but it is easier to do this
    # here and it should not be much more expensive, but it requires iterating over all rows again
    if unhashable_index_cols:
        for i, v in enumerate(index):
            temp = list(v)
            for ci in unhashable_index_cols:
                temp[ci] = tuple(temp[ci])
            index[i] = tuple(temp)

    # Construct the pandas dataframe with the hierarchical multi-index
    multi_index = pd.MultiIndex.from_tuples(index, names=index_names)
    out_df = pd.DataFrame(data=data, index=multi_index, columns=columns)
    return out_df

159 

160 

def __get_col_name(col):
    """
    Internal helper function to resolve the actual name of a pandas DataFrame column.

    A column label may be a plain value or an arbitrarily nested sequence of tuples.
    Tuples are unwrapped by repeatedly taking their last element until a non-tuple
    value is found.

    :param col: Column label to resolve

    :returns: The last value of the innermost tuple, or col itself if it is not a tuple
    """
    name = col
    while True:
        if not isinstance(name, tuple):
            return name
        # Descend into the last element of the current tuple
        name = name[-1]

171 

172 

def __flatten_column_name(col):
    """
    Internal helper function used to iteratively flatten a nested tuple

    :param col: Column name to flatten
    :type col: Tuple or String

    :returns: If col is a tuple then the result is a flat tuple otherwise col is returned as is
    """
    if not isinstance(col, tuple):
        return col
    flat = col
    # Each pass removes one level of nesting; repeat until no element is a tuple
    while any(isinstance(v, tuple) for v in flat):
        expanded = []
        for v in flat:
            if isinstance(v, tuple):
                expanded.extend(v)
            else:
                expanded.append(v)
        flat = expanded
    return tuple(flat)

195 

196 

@docval({'name': 'dataframe', 'type': pd.DataFrame,
         'doc': 'Pandas dataframe to update (usually generated by the to_hierarchical_dataframe function)'},
        {'name': 'inplace', 'type': 'bool', 'doc': 'Update the dataframe inplace or return a modified copy',
         'default': False},
        returns="pandas.DataFrame with the id columns removed",
        rtype='pandas.DataFrame',
        is_method=False)
def drop_id_columns(**kwargs):
    """
    Remove every column named 'id' from the dataframe.

    For column labels that are (nested) tuples, the inner-most name decides, i.e.,
    a column is dropped if the last value of its innermost tuple is 'id'. The 'id'
    columns of DynamicTable are in many cases not needed for analysis or display,
    and this function allows us to filter all of them in one step.

    When ``inplace`` is True the given dataframe is modified and returned; otherwise
    the result of ``DataFrame.drop`` is returned.

    :raises TypeError: In case that dataframe parameter is not a pandas.Dataframe.
    """
    dataframe, inplace = getargs('dataframe', 'inplace', kwargs)
    # Collect the full labels of all columns whose inner-most name is 'id'
    id_labels = [label for label in dataframe.columns if __get_col_name(label) == 'id']
    dropped = dataframe.drop(labels=id_labels, axis=1, inplace=inplace)
    if inplace:
        # DataFrame.drop returns None for inplace=True, so hand back the modified input
        return dataframe
    return dropped

223 

224 

@docval({'name': 'dataframe', 'type': pd.DataFrame,
         'doc': 'Pandas dataframe to update (usually generated by the to_hierarchical_dataframe function)'},
        {'name': 'max_levels', 'type': (int, np.integer),
         'doc': 'Maximum number of levels to use in the resulting column Index. NOTE: When '
                'limiting the number of levels the function simply removes levels from the '
                'beginning. As such, removing levels may result in columns with duplicate names. '
                'Value must be >0.',
         'default': None},
        {'name': 'inplace', 'type': 'bool', 'doc': 'Update the dataframe inplace or return a modified copy',
         'default': False},
        returns="pandas.DataFrame with a regular pandas.Index columns rather than a pandas.MultiIndex",
        rtype='pandas.DataFrame',
        is_method=False)
def flatten_column_index(**kwargs):
    """
    Flatten the column index of a pandas DataFrame.

    The function replaces dataframe.columns with a list of flattened column names,
    with each column usually being identified by a tuple of strings. This function is
    typically used in conjunction with DataFrames generated
    by :py:meth:`~hdmf.common.hierarchicaltable.to_hierarchical_dataframe`

    If ``max_levels`` is given, only the last ``max_levels`` elements of each tuple name
    are kept. For ``max_levels == 1`` the name becomes the last element itself rather
    than a 1-tuple. Column names that are not tuples are never shortened.

    :raises ValueError: In case the max_levels is not >0
    :raises TypeError: In case that dataframe parameter is not a pandas.Dataframe.
    """
    dataframe, max_levels, inplace = getargs('dataframe', 'max_levels', 'inplace', kwargs)
    if max_levels is not None and max_levels <= 0:
        raise ValueError('max_levels must be greater than 0')
    # Compute the new column names
    col_names = [__flatten_column_name(col) for col in dataframe.columns.values]
    # Apply the max_levels filter. Make sure to do this only for columns that are actually tuples
    # in order not to accidentally shorten the actual string name of columns
    if max_levels is None:
        select_levels = slice(None)
    elif max_levels == 1:
        # Plain index (not a slice) so that the resulting name is the element itself
        select_levels = -1
    else:  # max_levels > 1
        select_levels = slice(-max_levels, None)
    col_names = [col[select_levels] if isinstance(col, tuple) else col for col in col_names]
    result = dataframe if inplace else dataframe.copy()
    result.columns = col_names
    return result