Coverage for kye/loader/loader.py: 26%

76 statements  


import duckdb
from duckdb import DuckDBPyRelation, DuckDBPyConnection
from kye.loader.json_lines import from_json
from kye.types import Type, EDGE, TYPE_REF


def append_table(con: DuckDBPyConnection, orig: DuckDBPyRelation, new: DuckDBPyRelation):
    """
    This function will not be needed in the future if we can figure out a standard way
    to create the staging tables with the correct types before any data is uploaded.
    """

    def get_dtypes(r: DuckDBPyRelation):
        return dict(zip(r.columns, r.dtypes))

    orig_dtypes = get_dtypes(orig)
    new_dtypes = get_dtypes(new)

    # Check that the types of the shared columns match
    for col in set(orig_dtypes) & set(new_dtypes):
        if orig_dtypes[col] != new_dtypes[col]:
            raise ValueError(f'''Column {col} has conflicting types: {orig_dtypes[col]} != {new_dtypes[col]}''')

    # Alter the original table to include any columns that only the new relation has
    for col in set(new_dtypes) - set(orig_dtypes):
        con.sql(f'''ALTER TABLE "{orig.alias}" ADD COLUMN {col} {new_dtypes[col]}''')

    # Preserve the column order of the original table, filling columns
    # that are missing from the new relation with typed NULLs
    new = new.select(', '.join(
        col if col in new_dtypes
        else f'CAST(NULL as {orig_dtypes[col]}) as {col}'
        for col in con.table(f'"{orig.alias}"').columns
    ))

    new.insert_into(f'"{orig.alias}"')
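
# Usage sketch (hypothetical table and columns, not from the kye test suite):
# appending a relation that carries an extra `name` column first widens the
# staging table, then inserts with the original column order preserved.
#
#   con = duckdb.connect()
#   con.sql('CREATE TABLE users (id BIGINT)')
#   append_table(con, con.table('users'), con.sql("SELECT 2 AS id, 'bo' AS name"))
#   # the users table now has columns (id BIGINT, name VARCHAR)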


def get_struct_keys(r: DuckDBPyRelation):
    assert r.columns[1] == 'val'
    assert r.dtypes[1].id == 'struct'
    return [col[0] for col in r.dtypes[1].children]
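
# For example, given a relation whose second column is val STRUCT(id BIGINT, name VARCHAR),
# get_struct_keys returns ['id', 'name'].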


def struct_pack(edges: list[str], r: DuckDBPyRelation):
    return 'struct_pack(' + ','.join(
        f'''"{edge_name}":="{edge_name}"'''
        for edge_name in edges
        if edge_name in r.columns
    ) + ')'
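
# For example, struct_pack(['id', 'name'], r) over a relation that has both columns
# yields the SQL fragment struct_pack("id":="id","name":="name"); edges missing
# from the relation are skipped.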


def get_index(typ: Type, r: DuckDBPyRelation):
    # Hash the index columns
    r = r.select(f'''hash({struct_pack(sorted(typ.index), r)}) as _index, *''')

    # Filter out rows with a null index column
    r = r.filter(f'''{' AND '.join(edge + ' IS NOT NULL' for edge in typ.index)}''')
    return r
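
# Sketch: for a type indexed on ['id'], this prepends hash(struct_pack("id":="id"))
# AS _index to every row and drops rows where id IS NULL, so _index is always defined.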


class Loader:
    """
    The loader is responsible for normalizing the shape of the data. It makes sure that
    all of the columns are present (filling in nulls where necessary) and also computes
    the index hash for each row so that it is easy to join the data together later.

    The loader operates on each chunk of the data as it loads, so it does not
    do any cross-table aggregation or validation.

    Any value normalization needs to be done here so that the index hash is consistent.
    """
    # TODO: If types and edges were stored in separate relations, the storage format
    # would be more uniform and we would not have to append columns to tables, because
    # every edge table would look like (index: int64, value: str, args: list[str]).
    # It would also make a quad store possible if we ever wanted one.
    tables: dict[TYPE_REF, duckdb.DuckDBPyRelation]
    models: dict[TYPE_REF, Type]
    db: duckdb.DuckDBPyConnection
    chunks: dict[str, duckdb.DuckDBPyRelation]

    def __init__(self, models: dict[TYPE_REF, Type]):
        self.tables = {}
        self.models = models
        self.db = duckdb.connect(':memory:')
        self.chunks = {}

    def _insert(self, model_name: TYPE_REF, r: duckdb.DuckDBPyRelation):
        table_name = f'"{model_name}.staging"'
        if model_name not in self.tables:
            r.create(table_name)
        else:
            append_table(self.db, self.tables[model_name], r)
        self.tables[model_name] = self.db.table(table_name)

    def _load(self, typ: Type, r: duckdb.DuckDBPyRelation):
        # Tag every row with a unique address: the chunk id plus the row's offset
        chunk_id = typ.ref + '_' + str(len(self.chunks) + 1)
        chunk = r.select(f'''list_value('{chunk_id}', ROW_NUMBER() OVER () - 1) as _, {struct_pack(typ.edges, r)} as val''').set_alias(chunk_id)
        self.chunks[chunk_id] = chunk
        self._get_value(typ, chunk)

    def _get_value(self, typ: Type, r: DuckDBPyRelation):
        if typ.has_index:
            edges = r.select('_')
            for edge in typ.edges:
                if edge in get_struct_keys(r):
                    edge_rel = self._get_edge(typ, edge, r.select(f'''list_append(_, '{edge}') as _, val.{edge} as val''')).set_alias(typ.ref + '.' + edge)
                    edge_rel = edge_rel.select(f'''array_pop_back(_) as _, val as {edge}''')
                    edges = edges.join(edge_rel, '_', how='left')

            edges = get_index(typ, edges)
            self._insert(typ.ref, edges)
            return edges.select(f'''_, _index as val''')

        # TODO: Eventually this will be replaced with a custom normalization function
        # per type, e.g. a DateTime needs to be converted into a standard format for
        # index and equivalence checks, and that format may depend on the storage system.
        elif r.dtypes[1].id != 'varchar':
            dtype = r.dtypes[1].id
            r = r.select(f'''_, CAST(val AS VARCHAR) as val''')
            # Remove a trailing '.0' from floating-point values so that
            # they will match integers of the same value
            if dtype in ['double', 'decimal', 'real']:
                r = r.select(f'''_, REGEXP_REPLACE(val, '\\.0$', '') as val''')
        return r
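
    # e.g. a DOUBLE 12.0 is cast to '12.0' and then normalized to '12', so it hashes
    # and compares equal to the integer 12 loaded from another chunk.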


    def _get_edge(self, typ: Type, edge: EDGE, r: DuckDBPyRelation):
        if typ.allows_multiple(edge):
            # Explode multi-valued edges, appending each value's position to its address
            r = r.select('''_, unnest(val) as val''').select('list_append(_, ROW_NUMBER() OVER (PARTITION BY _) - 1) as _, val')

        r = self._get_value(typ.get_edge(edge), r)

        if typ.allows_multiple(edge):
            # Pop the position back off the address and re-aggregate values into a list
            r = r.aggregate('array_pop_back(_) as _, list(val) as val', 'array_pop_back(_)')

        return r
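
    # Sketch of the multi-valued path: a row with _ = ['User_1', 0, 'tags'] and
    # val = ['a', 'b'] is exploded to (['User_1', 0, 'tags', 0], 'a') and
    # (['User_1', 0, 'tags', 1], 'b'); after normalization the trailing position is
    # popped off and the values are gathered back into a list per original row.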


    def from_json(self, model_name: TYPE_REF, data: list[dict]):
        r = from_json(self.models[model_name], data, self.db)
        self._load(self.models[model_name], r)

    def __getitem__(self, model_name: str):
        return self.tables[model_name]

    def __repr__(self):
        return f"<Loader {','.join(self.tables.keys())}>"