# kye/loader/loader.py
import duckdb
from duckdb import DuckDBPyRelation, DuckDBPyConnection

from kye.loader.json_lines import from_json
from kye.types import Type, EDGE, TYPE_REF

def append_table(con: DuckDBPyConnection, orig: DuckDBPyRelation, new: DuckDBPyRelation):
    """
    Append the rows of `new` to the table behind `orig`, reconciling schemas.

    This function will not be needed in the future if we can figure out a
    standard way to create the staging tables with the correct types before
    any data is uploaded.
    """
    def get_dtypes(r: DuckDBPyRelation):
        return dict(zip(r.columns, r.dtypes))

    orig_dtypes = get_dtypes(orig)
    new_dtypes = get_dtypes(new)

    # Check that the types of the shared columns match
    for col in set(orig_dtypes) & set(new_dtypes):
        if orig_dtypes[col] != new_dtypes[col]:
            raise ValueError(f'Column {col} has conflicting types: {orig_dtypes[col]} != {new_dtypes[col]}')

    # Alter the original table to include any new columns
    for col in set(new_dtypes) - set(orig_dtypes):
        con.sql(f'ALTER TABLE "{orig.alias}" ADD COLUMN {col} {new_dtypes[col]}')

    # Preserve the order of columns from the original table,
    # casting columns missing from the new chunk to typed NULLs
    new = new.select(', '.join(
        col if col in new_dtypes
        else f'CAST(NULL as {orig_dtypes[col]}) as {col}'
        for col in con.table(f'"{orig.alias}"').columns
    ))

    new.insert_into(f'"{orig.alias}"')
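
# A minimal sketch of how append_table widens a staging table when a later
# chunk carries a column the first chunk did not (the table and column names
# here are illustrative, not part of the loader):
#
#   con = duckdb.connect(':memory:')
#   con.sql("CREATE TABLE t AS SELECT 1 AS id, 'alice' AS name")
#   append_table(con, con.table('t'), con.sql("SELECT 2 AS id, 30 AS age"))
#   # `age` is added to t; row 1 gets a NULL age, row 2 a NULL name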


def get_struct_keys(r: DuckDBPyRelation):
    assert r.columns[1] == 'val'
    assert r.dtypes[1].id == 'struct'
    return [col[0] for col in r.dtypes[1].children]


def struct_pack(edges: list[str], r: DuckDBPyRelation):
    # Build a struct_pack(...) SQL expression over the edges that are
    # actually present as columns in `r`
    return 'struct_pack(' + ','.join(
        f'''"{edge_name}":="{edge_name}"'''
        for edge_name in edges
        if edge_name in r.columns
    ) + ')'
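
# For edges ['id', 'name'] over a chunk that only has an `id` column, the
# expression above would come out as (illustrative):
#
#   struct_pack("id":="id")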


def get_index(typ: Type, r: DuckDBPyRelation):
    # Hash the index columns (sorted so the hash does not depend on the
    # declaration order of the index edges)
    r = r.select(f'''hash({struct_pack(sorted(typ.index), r)}) as _index, *''')
    # Filter out rows where any index column is null
    r = r.filter(f'''{' AND '.join(edge + ' IS NOT NULL' for edge in typ.index)}''')
    return r
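
# For an index of ('id',) this composes SQL along the lines of (illustrative):
#
#   SELECT hash(struct_pack("id":="id")) AS _index, * FROM ...
#   WHERE id IS NOT NULL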


class Loader:
    """
    The loader is responsible for normalizing the shape of the data. It makes
    sure that all of the columns are present (filling in nulls where
    necessary) and computes the index hash for each row so that it is easy to
    join the data together later.

    The loader operates on each chunk of the data as it loads, so it does not
    do any cross-table aggregation or validation.

    Any value normalization needs to be done here so that the index hash is
    consistent.
    """
    # If I store types and edges in separate relations, that would allow a
    # more standard storage format without having to append columns to tables,
    # because every edge table would look like (index:int64, value:str,
    # args:list[str]). It would also allow me to do my quad store if I really
    # wanted to.
    tables: dict[TYPE_REF, duckdb.DuckDBPyRelation]
    models: dict[TYPE_REF, Type]
    db: duckdb.DuckDBPyConnection
    chunks: dict[str, duckdb.DuckDBPyRelation]

    def __init__(self, models: dict[TYPE_REF, Type]):
        self.tables = {}
        self.models = models
        self.db = duckdb.connect(':memory:')
        self.chunks = {}

    def _insert(self, model_name: TYPE_REF, r: duckdb.DuckDBPyRelation):
        # Create the staging table on the first chunk; append thereafter
        table_name = f'"{model_name}.staging"'
        if model_name not in self.tables:
            r.create(table_name)
        else:
            append_table(self.db, self.tables[model_name], r)
        self.tables[model_name] = self.db.table(table_name)

    def _load(self, typ: Type, r: duckdb.DuckDBPyRelation):
        chunk_id = typ.ref + '_' + str(len(self.chunks) + 1)
        # Tag each row with a path `_` of (chunk id, row offset) and pack the
        # row's edges into a single `val` struct
        chunk = r.select(f'''list_value('{chunk_id}', ROW_NUMBER() OVER () - 1) as _, {struct_pack(typ.edges, r)} as val''').set_alias(chunk_id)
        self.chunks[chunk_id] = chunk
        self._get_value(typ, chunk)
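
    # Chunk layout sketch (illustrative): a two-row chunk for a `User` type
    # would come out as
    #   _ = ['User_1', 0], val = {'id': ..., 'name': ...}
    #   _ = ['User_1', 1], val = {'id': ..., 'name': ...}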

    def _get_value(self, typ: Type, r: DuckDBPyRelation):
        if typ.has_index:
            edges = r.select('_')
            for edge in typ.edges:
                if edge in get_struct_keys(r):
                    edge_rel = self._get_edge(typ, edge, r.select(f'''list_append(_, '{edge}') as _, val.{edge} as val''')).set_alias(typ.ref + '.' + edge)
                    edge_rel = edge_rel.select(f'''array_pop_back(_) as _, val as {edge}''')
                    edges = edges.join(edge_rel, '_', how='left')

            edges = get_index(typ, edges)
            self._insert(typ.ref, edges)
            return edges.select(f'''_, _index as val''')

        # Eventually this will be replaced with a custom function for
        # normalizing values, right? Like a DateTime type needs to be
        # converted into a standard format for index and equivalency checks.
        # The standard format that it is converted into might also depend on
        # the storage system.
        elif r.dtypes[1].id != 'varchar':
            dtype = r.dtypes[1].id
            r = r.select(f'''_, CAST(val AS VARCHAR) as val''')
            # Remove a trailing '.0' from decimals so that
            # they will match integers of the same value
            if dtype in ['double', 'decimal', 'real']:
                r = r.select(f'''_, REGEXP_REPLACE(val, '\\.0$', '') as val''')
        # Varchar values pass through unchanged
        return r
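
    # Normalization sketch: a DOUBLE value 1.0 becomes '1.0' after the cast
    # and '1' after the regexp, so it hashes the same as the INTEGER 1.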

    def _get_edge(self, typ: Type, edge: EDGE, r: DuckDBPyRelation):
        # Plural edges are exploded to one value per row, with the list
        # position appended to the `_` path so the values can be regrouped
        if typ.allows_multiple(edge):
            r = r.select('''_, unnest(val) as val''').select('list_append(_, ROW_NUMBER() OVER (PARTITION BY _) - 1) as _, val')

        r = self._get_value(typ.get_edge(edge), r)

        if typ.allows_multiple(edge):
            r = r.aggregate('array_pop_back(_) as _, list(val) as val', 'array_pop_back(_)')

        return r
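
    # Round-trip sketch for a plural edge value ['a', 'b'] at path p
    # (illustrative):
    #   unnest     -> (p + [0], 'a'), (p + [1], 'b')
    #   _get_value -> each scalar normalized independently
    #   aggregate  -> (p, ['a', 'b'])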

    def from_json(self, model_name: TYPE_REF, data: list[dict]):
        r = from_json(self.models[model_name], data, self.db)
        self._load(self.models[model_name], r)

    def __getitem__(self, model_name: str):
        return self.tables[model_name]

    def __repr__(self):
        return f"<Loader {','.join(self.tables.keys())}>"
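

# A minimal usage sketch, assuming `models` already maps type refs to built
# kye.types.Type objects (their construction is outside this module) and that
# 'User' is one such type with an `id` index:
#
#   loader = Loader(models)
#   loader.from_json('User', [
#       {'id': 1, 'name': 'alice'},
#       {'id': 2, 'name': 'bob'},
#   ])
#   loader['User']  # -> staged relation with one `_index` hash per row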