Coverage for kye/engine/engine.py: 34%
61 statements
« prev ^ index » next coverage.py v7.3.2, created at 2024-01-16 14:12 -0700
« prev ^ index » next coverage.py v7.3.2, created at 2024-01-16 14:12 -0700
1import duckdb
2from duckdb import DuckDBPyConnection
3from kye.types import Type, EDGE, TYPE_REF
4from kye.engine.load_json import json_to_edges
5from kye.engine.validate import check_table
6from kye.errors import error_factory, Error
7import pandas as pd
class DuckDBEngine:
    """In-memory DuckDB store for model edges and their validation errors.

    Data is kept in two tables: ``edges`` (one row per loaded value) and
    ``errors`` (rows describing validation failures).  ``has_validated``
    records whether ``errors`` is current with respect to ``edges``; it is
    cleared whenever new data is loaded and set again by ``validate``.
    """
    db: DuckDBPyConnection
    models: dict[TYPE_REF, Type]

    def __init__(self, models: dict[TYPE_REF, Type]):
        """Open a fresh in-memory database and create the empty tables.

        Args:
            models: Mapping from model reference to its type definition.
        """
        self.db = duckdb.connect(':memory:')
        self.models = models
        # Nothing has been loaded yet, so there is nothing to validate.
        self.has_validated = True
        self.create_tables()

    def create_tables(self):
        """Create the ``edges`` and ``errors`` tables on the connection."""
        self.db.sql('''
            CREATE TABLE edges (
                loc TEXT NOT NULL,
                tbl TEXT NOT NULL,
                row TEXT NOT NULL,
                col TEXT NOT NULL,
                val TEXT NOT NULL,
                idx UINT64
            );
            CREATE TABLE errors (
                err TEXT NOT NULL,
                tbl TEXT NOT NULL,
                idx TEXT,
                row TEXT,
                col TEXT,
                val TEXT
            );
        ''')

    @property
    def edges(self):
        """Relation over the ``edges`` table."""
        return self.db.table('edges')

    @property
    def errors(self):
        """Relation over the ``errors`` table."""
        return self.db.table('errors')

    def load_json(self, model: TYPE_REF, data):
        """Convert ``data`` to edge rows for ``model`` and append them to ``edges``.

        The validated flag is cleared first, so the next ``validate`` call
        re-runs the checks.  Raises ``AssertionError`` if ``model`` is not a
        known model.
        """
        self.has_validated = False
        assert model in self.models
        frame = pd.DataFrame(json_to_edges(self.models[model], data))
        relation = duckdb.df(frame, connection=self.db)
        # Loaded rows start without an index; idx is inserted as NULL.
        relation.select('*, NULL as idx').insert_into('edges')

    def validate(self):
        """Re-run validation over every loaded table if data has changed.

        Does nothing when the current data has already been validated.
        Otherwise clears ``errors``, resets every ``edges.idx`` to NULL and
        calls ``check_table`` once per distinct table present in ``edges``.
        """
        if self.has_validated:
            return
        self.db.sql('''
            TRUNCATE errors;
            UPDATE edges SET idx = NULL;
        ''')
        for (table_name,) in self.edges.aggregate('distinct tbl').fetchall():
            check_table(self.models[table_name], self.db)
        self.has_validated = True

    def get_table(self, model: TYPE_REF):
        """Return a relation with one column per edge of ``model``.

        Validation is run first.  Edge rows matched by an error row are
        excluded; a NULL in an error's ``row``/``col``/``val``/``idx`` field
        matches any value.  The remaining rows are pivoted on ``col`` and
        grouped by ``idx``.  Raises ``AssertionError`` if ``model`` is not a
        known model.
        """
        assert model in self.models
        self.validate()
        typ = self.models[model]
        pivoted = self.db.sql(f'''
            PIVOT (
                SELECT * FROM edges
                ANTI JOIN errors on
                    edges.tbl=errors.tbl
                    AND (edges.row = errors.row OR errors.row IS NULL)
                    AND (edges.col = errors.col OR errors.col IS NULL)
                    AND (edges.val = errors.val OR errors.val IS NULL)
                    AND (edges.idx = errors.idx OR errors.idx IS NULL)
                WHERE tbl = '{model}'
            ) ON col USING list(val) GROUP BY idx
        ''')
        # Keyed by (column present in the pivot, edge allows multiple values).
        # Edges missing from the pivot still get a typed placeholder column.
        templates = {
            (True, True): 'list_distinct({0}) as {0}',
            (True, False): 'list_any_value({0}) as {0}',
            (False, True): 'CAST([] AS VARCHAR[]) as {0}',
            (False, False): 'CAST(NULL AS VARCHAR) as {0}',
        }
        present_columns = pivoted.columns
        projections = [
            templates[
                (edge in present_columns, bool(typ.allows_multiple(edge)))
            ].format(edge)
            for edge in typ.edges
        ]
        return pivoted.select(','.join(projections))

    def fetch_json(self, model: TYPE_REF):
        """Return the table for ``model`` as a list of per-row dicts.

        Raises ``AssertionError`` if ``model`` is not a known model.
        """
        assert model in self.models
        frame = self.get_table(model).fetchdf()
        return frame.to_dict(orient='records')

    def get_errors(self) -> list[Error]:
        """Summarise the ``errors`` table as a list of ``Error`` objects.

        Errors are aggregated per (err, tbl, col); each summary carries the
        distinct row/idx/val counts and one example of each.  This method
        does not call ``validate`` itself.
        """
        summary = self.errors.aggregate('''
            err,
            tbl,
            col,
            count(distinct row) as num_row,
            count(distinct idx) as num_idx,
            count(distinct val) as num_val,
            first(row) as row_example,
            first(idx) as idx_example,
            first(val) as val_example,
        ''')
        # Keyword names for error_factory, in the aggregate's column order.
        field_names = (
            'err_type',
            'table_name',
            'column_name',
            'num_rows',
            'num_indexes',
            'num_values',
            'row_example',
            'idx_example',
            'val_example',
        )
        return [
            error_factory(**dict(zip(field_names, record)))
            for record in summary.fetchall()
        ]