sqlglot.lineage
"""Column-level lineage for SQL queries, plus an HTML (vis.js) renderer for the resulting graph."""

from __future__ import annotations

import html
import json
import logging
import typing as t
from dataclasses import dataclass, field

from sqlglot import Schema, exp, maybe_parse
from sqlglot.errors import SqlglotError
from sqlglot.optimizer import Scope, build_scope, find_all_in_scope, normalize_identifiers, qualify

if t.TYPE_CHECKING:
    from sqlglot.dialects.dialect import DialectType

logger = logging.getLogger("sqlglot")


@dataclass(frozen=True)
class Node:
    """One step in a column's lineage graph."""

    # Display name of the column at this step, e.g. "x" or "<scope_name>.x".
    name: str
    # The expression producing the column: a projection, `*`, a Table or a Placeholder.
    expression: exp.Expression
    # The query/table the expression belongs to (for SELECTs built by `to_node`, the
    # query trimmed down to the single projection being traced).
    source: exp.Expression
    # The nodes this one is derived from.
    downstream: t.List[Node] = field(default_factory=list)
    # Name taken from a "source: <name>" comment on a derived table, if any.
    source_name: str = ""
    # Name of the node that `scope.selected_sources` maps the referencing table to, if any.
    reference_node_name: str = ""

    def walk(self) -> t.Iterator[Node]:
        """Yield this node, then every downstream node, depth-first in pre-order."""
        yield self

        for d in self.downstream:
            if isinstance(d, Node):
                yield from d.walk()
            else:
                yield d

    def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML:
        """Render the lineage graph rooted at this node as an interactive vis.js graph.

        Each node's tooltip shows the SQL it comes from with the traced expression in
        bold. The SQL text is HTML-escaped, and the lineage expressions are restored
        after rendering, so this method does not modify the graph and can be called
        repeatedly.

        Args:
            dialect: The dialect used to generate the SQL shown in labels and tooltips.
            **opts: Extra keyword arguments forwarded to `GraphHTML` (`imports`, `options`).

        Returns:
            A `GraphHTML` whose string form is an embeddable HTML snippet.
        """
        # Control characters mark the highlighted expression: html.escape leaves them
        # untouched, so they can be swapped for real <b> tags after escaping the SQL.
        open_mark, close_mark = "\x02", "\x03"

        def highlighted_source(node):
            """Return `node.source` as escaped HTML with `node.expression` wrapped in <b>."""
            target = node.expression
            parent, arg_key = target.parent, target.arg_key

            def mark(n):
                return exp.Tag(this=n, prefix=open_mark, postfix=close_mark) if n is target else n

            try:
                sql = node.source.transform(mark, copy=False).sql(pretty=True, dialect=dialect)
            finally:
                # transform(copy=False) wraps `target` in a Tag in place. Unwrap it so the
                # lineage graph is left as it was and repeated calls don't nest the tags.
                wrapper = target.parent
                if wrapper is not parent and isinstance(wrapper, exp.Tag):
                    wrapper.replace(target)
                    target.parent = parent
                    target.arg_key = arg_key

            return html.escape(sql).replace(open_mark, "<b>").replace(close_mark, "</b>")

        nodes = {}
        edges = []

        for node in self.walk():
            if isinstance(node.expression, exp.Table):
                label = f"FROM {node.expression.this}"
                query = f"SELECT {node.name} FROM {node.expression.this}"
                title = f"<pre>{html.escape(query)}</pre>"
                group = 1
            else:
                label = node.expression.sql(pretty=True, dialect=dialect)
                title = f"<pre>{highlighted_source(node)}</pre>"
                group = 0

            node_id = id(node)

            nodes[node_id] = {
                "id": node_id,
                "label": label,
                "title": title,
                "group": group,
            }

            edges.extend({"from": node_id, "to": id(d)} for d in node.downstream)

        return GraphHTML(nodes, edges, **opts)


def lineage(
    column: str | exp.Column,
    sql: str | exp.Expression,
    schema: t.Optional[t.Dict | Schema] = None,
    sources: t.Optional[t.Dict[str, str | exp.Query]] = None,
    dialect: DialectType = None,
    **kwargs,
) -> Node:
    """Build the lineage graph for a column of a SQL query.

    Args:
        column: The column to build the lineage for; it is normalized for `dialect`.
        sql: The SQL string or expression.
        schema: The schema of the tables referenced by the query, used when qualifying columns.
        sources: A mapping from source names to queries (SQL strings or parsed queries).
            They are expanded into `sql` via `exp.expand`, so lineage continues through them.
        dialect: The dialect of input SQL.
        **kwargs: Extra keyword arguments forwarded to `qualify.qualify`. They override the
            defaults `validate_qualify_columns=False` and `identify=False`.

    Returns:
        A lineage node.

    Raises:
        SqlglotError: If no scope can be built (the SQL is not a SELECT), or if the
            outermost query has no projection named `column`.
    """

    expression = maybe_parse(sql, dialect=dialect)
    column = normalize_identifiers.normalize_identifiers(column, dialect=dialect).name

    if sources:
        expression = exp.expand(
            expression,
            {k: t.cast(exp.Query, maybe_parse(v, dialect=dialect)) for k, v in sources.items()},
            dialect=dialect,
        )

    qualified = qualify.qualify(
        expression,
        dialect=dialect,
        schema=schema,
        **{"validate_qualify_columns": False, "identify": False, **kwargs},  # type: ignore
    )

    scope = build_scope(qualified)

    if not scope:
        raise SqlglotError("Cannot build lineage, sql must be SELECT")

    if not any(select.alias_or_name == column for select in scope.expression.selects):
        raise SqlglotError(f"Cannot find column '{column}' in query.")

    return to_node(column, scope, dialect)


def to_node(
    column: str | int,
    scope: Scope,
    dialect: DialectType,
    scope_name: t.Optional[str] = None,
    upstream: t.Optional[Node] = None,
    source_name: t.Optional[str] = None,
    reference_node_name: t.Optional[str] = None,
) -> Node:
    """Build the lineage node for `column` in `scope` and recurse into the scopes it reads.

    Args:
        column: The projection to trace, by output name or by position. Positions are
            used when descending into the branches of a UNION.
        scope: The scope whose query defines the column.
        dialect: The dialect; passed down the recursion and used when logging unknown subqueries.
        scope_name: If given, the node is named "<scope_name>.<column>".
        upstream: The node the new node is attached to (as a downstream), if any.
        source_name: Name taken from a "source: <name>" comment on a derived table.
        reference_node_name: Stored on the new node as `reference_node_name`.

    Returns:
        The node for `column`. For a UNION, this is `upstream` if given, otherwise a new
        "UNION" node.

    Raises:
        ValueError: If `column` cannot be located in the selects of a UNION.
    """
    source_names = {
        dt.alias: dt.comments[0].split()[1]
        for dt in scope.derived_tables
        if dt.comments and dt.comments[0].startswith("source: ")
    }

    # Find the specific select clause that is the source of the column we want.
    # This can either be a specific, named select or a generic `*` clause.
    select = (
        scope.expression.selects[column]
        if isinstance(column, int)
        else next(
            (select for select in scope.expression.selects if select.alias_or_name == column),
            exp.Star() if scope.expression.is_star else scope.expression,
        )
    )

    if isinstance(scope.expression, exp.Subquery):
        # Trace the column through the first inner query scope, if there is one.
        inner_scope = next(iter(scope.subquery_scopes), None)
        if inner_scope is not None:
            return to_node(
                column,
                scope=inner_scope,
                dialect=dialect,
                upstream=upstream,
                source_name=source_name,
                reference_node_name=reference_node_name,
            )

    if isinstance(scope.expression, exp.Union):
        upstream = upstream or Node(name="UNION", source=scope.expression, expression=select)

        index = (
            column
            if isinstance(column, int)
            else next(
                (
                    i
                    for i, select in enumerate(scope.expression.selects)
                    if select.alias_or_name == column or select.is_star
                ),
                -1,  # "not found" sentinel (mypy will not allow None here); handled just below
            )
        )

        if index == -1:
            raise ValueError(f"Could not find {column} in {scope.expression}")

        for s in scope.union_scopes:
            to_node(
                index,
                scope=s,
                dialect=dialect,
                upstream=upstream,
                source_name=source_name,
                reference_node_name=reference_node_name,
            )

        return upstream

    if isinstance(scope.expression, exp.Select):
        # For better ergonomics in our node labels, replace the full select with
        # a version that has only the column we care about.
        #   "x", SELECT x, y FROM foo
        #     => "x", SELECT x FROM foo
        source = t.cast(exp.Expression, scope.expression.select(select, append=False))
    else:
        source = scope.expression

    # Create the node for this step in the lineage chain, and attach it to the previous one.
    node = Node(
        name=f"{scope_name}.{column}" if scope_name else str(column),
        source=source,
        expression=select,
        source_name=source_name or "",
        reference_node_name=reference_node_name or "",
    )

    if upstream:
        upstream.downstream.append(node)

    subquery_scopes = {
        id(subquery_scope.expression): subquery_scope for subquery_scope in scope.subquery_scopes
    }

    for subquery in find_all_in_scope(select, exp.UNWRAPPED_QUERIES):
        subquery_scope = subquery_scopes.get(id(subquery))
        if not subquery_scope:
            logger.warning("Unknown subquery scope: %s", subquery.sql(dialect=dialect))
            continue

        for name in subquery.named_selects:
            to_node(name, scope=subquery_scope, dialect=dialect, upstream=node)

    # if the select is a star add all scope sources as downstreams
    if select.is_star:
        for source in scope.sources.values():
            if isinstance(source, Scope):
                source = source.expression
            node.downstream.append(Node(name=select.sql(), source=source, expression=source))

    # Find all columns that went into creating this one to list their lineage nodes.
    # A dict de-duplicates like a set but keeps discovery order, so downstream nodes are
    # produced in the same order on every run (set iteration order can vary).
    source_columns = dict.fromkeys(find_all_in_scope(select, exp.Column))

    # If the source is a UDTF, find the columns used in the UDTF to generate the table.
    # NOTE(review): for a star select, the loop above rebinds `source` to the last scope
    # source, so this checks that source rather than this node's own; kept as-is to
    # preserve existing behaviour — confirm whether intended.
    if isinstance(source, exp.UDTF):
        source_columns.update(dict.fromkeys(source.find_all(exp.Column)))

    for c in source_columns:
        table = c.table
        source = scope.sources.get(table)

        if isinstance(source, Scope):
            selected_node, _ = scope.selected_sources.get(table, (None, None))
            # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
            to_node(
                c.name,
                scope=source,
                dialect=dialect,
                scope_name=table,
                upstream=node,
                source_name=source_names.get(table) or source_name,
                reference_node_name=selected_node.name if selected_node else None,
            )
        else:
            # The source is not a scope - we've reached the end of the line. At this point, if a source is not found
            # it means this column's lineage is unknown. This can happen if the definition of a source used in a query
            # is not passed into the `sources` map.
            source = source or exp.Placeholder()
            node.downstream.append(Node(name=c.sql(), source=source, expression=source))

    return node


class GraphHTML:
    """Node to HTML generator using vis.js.

    https://visjs.github.io/vis-network/docs/network/
    """

    def __init__(
        self, nodes: t.Dict, edges: t.List, imports: bool = True, options: t.Optional[t.Dict] = None
    ):
        self.imports = imports

        self.options = {
            "height": "500px",
            "width": "100%",
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "nodeSpacing": 200,
                    "sortMethod": "directed",
                },
            },
            "interaction": {
                "dragNodes": False,
                "selectable": False,
            },
            "physics": {
                "enabled": False,
            },
            "edges": {
                "arrows": "to",
            },
            "nodes": {
                "font": "20px monaco",
                "shape": "box",
                "widthConstraint": {
                    "maximum": 300,
                },
            },
            **(options or {}),
        }

        self.nodes = nodes
        self.edges = edges

    @staticmethod
    def _script_json(value: t.Any) -> str:
        """Serialize `value` as JSON that is safe to inline inside a <script> element.

        `<`, `>` and `&` are emitted as unicode escapes, so a string such as
        "</script>" cannot terminate the script early. JavaScript decodes the escapes
        back to the same characters, so the data is unchanged.
        """
        return (
            json.dumps(value)
            .replace("<", "\\u003c")
            .replace(">", "\\u003e")
            .replace("&", "\\u0026")
        )

    def __str__(self):
        nodes = self._script_json(list(self.nodes.values()))
        edges = self._script_json(self.edges)
        options = self._script_json(self.options)
        imports = (
            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
            if self.imports
            else ""
        )

        return f"""<div>
  <div id="sqlglot-lineage"></div>
  {imports}
  <script type="text/javascript">
    var nodes = new vis.DataSet({nodes})
    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])

    new vis.Network(
        document.getElementById("sqlglot-lineage"),
        {{
            nodes: nodes,
            edges: new vis.DataSet({edges})
        }},
        {options},
    )
  </script>
</div>"""

    def _repr_html_(self) -> str:
        return self.__str__()
logger =
<Logger sqlglot (WARNING)>
@dataclass(frozen=True)
class
Node:
@dataclass(frozen=True)
class Node:
    """One step in a column's lineage graph."""

    # Display name of the column at this step, e.g. "x" or "<scope_name>.x".
    name: str
    # The expression producing the column: a projection, `*`, a Table or a Placeholder.
    expression: exp.Expression
    # The query/table the expression belongs to.
    source: exp.Expression
    # The nodes this one is derived from.
    downstream: t.List[Node] = field(default_factory=list)
    # Name taken from a "source: <name>" comment on a derived table, if any.
    source_name: str = ""
    # Name of the node the referencing table was selected through, if any.
    reference_node_name: str = ""

    def walk(self) -> t.Iterator[Node]:
        """Yield this node, then every downstream node, depth-first in pre-order."""
        yield self

        for d in self.downstream:
            if isinstance(d, Node):
                yield from d.walk()
            else:
                yield d

    def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML:
        """Render the lineage graph rooted at this node as an interactive vis.js graph.

        Each node's tooltip shows the SQL it comes from with the traced expression in
        bold. The SQL text is HTML-escaped, and the lineage expressions are restored
        after rendering, so this method does not modify the graph and can be called
        repeatedly.

        Args:
            dialect: The dialect used to generate the SQL shown in labels and tooltips.
            **opts: Extra keyword arguments forwarded to `GraphHTML` (`imports`, `options`).

        Returns:
            A `GraphHTML` whose string form is an embeddable HTML snippet.
        """
        import html

        # Control characters mark the highlighted expression: html.escape leaves them
        # untouched, so they can be swapped for real <b> tags after escaping the SQL.
        open_mark, close_mark = "\x02", "\x03"

        def highlighted_source(node):
            """Return `node.source` as escaped HTML with `node.expression` wrapped in <b>."""
            target = node.expression
            parent, arg_key = target.parent, target.arg_key

            def mark(n):
                return exp.Tag(this=n, prefix=open_mark, postfix=close_mark) if n is target else n

            try:
                sql = node.source.transform(mark, copy=False).sql(pretty=True, dialect=dialect)
            finally:
                # transform(copy=False) wraps `target` in a Tag in place. Unwrap it so the
                # lineage graph is left as it was and repeated calls don't nest the tags.
                wrapper = target.parent
                if wrapper is not parent and isinstance(wrapper, exp.Tag):
                    wrapper.replace(target)
                    target.parent = parent
                    target.arg_key = arg_key

            return html.escape(sql).replace(open_mark, "<b>").replace(close_mark, "</b>")

        nodes = {}
        edges = []

        for node in self.walk():
            if isinstance(node.expression, exp.Table):
                label = f"FROM {node.expression.this}"
                query = f"SELECT {node.name} FROM {node.expression.this}"
                title = f"<pre>{html.escape(query)}</pre>"
                group = 1
            else:
                label = node.expression.sql(pretty=True, dialect=dialect)
                title = f"<pre>{highlighted_source(node)}</pre>"
                group = 0

            node_id = id(node)

            nodes[node_id] = {
                "id": node_id,
                "label": label,
                "title": title,
                "group": group,
            }

            edges.extend({"from": node_id, "to": id(d)} for d in node.downstream)

        return GraphHTML(nodes, edges, **opts)
Node( name: str, expression: sqlglot.expressions.Expression, source: sqlglot.expressions.Expression, downstream: List[Node] = <factory>, source_name: str = '', reference_node_name: str = '')
expression: sqlglot.expressions.Expression
source: sqlglot.expressions.Expression
downstream: List[Node]
def
to_html( self, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, **opts) -> GraphHTML:
def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML:
    """Render the lineage graph rooted at this node as an interactive vis.js graph.

    Each node's tooltip shows the SQL it comes from with the traced expression in
    bold. The SQL text is HTML-escaped, and the lineage expressions are restored after
    rendering, so this method does not modify the graph and can be called repeatedly.

    Args:
        dialect: The dialect used to generate the SQL shown in labels and tooltips.
        **opts: Extra keyword arguments forwarded to `GraphHTML` (`imports`, `options`).

    Returns:
        A `GraphHTML` whose string form is an embeddable HTML snippet.
    """
    import html

    # Control characters mark the highlighted expression: html.escape leaves them
    # untouched, so they can be swapped for real <b> tags after escaping the SQL.
    open_mark, close_mark = "\x02", "\x03"

    def highlighted_source(node):
        """Return `node.source` as escaped HTML with `node.expression` wrapped in <b>."""
        target = node.expression
        parent, arg_key = target.parent, target.arg_key

        def mark(n):
            return exp.Tag(this=n, prefix=open_mark, postfix=close_mark) if n is target else n

        try:
            sql = node.source.transform(mark, copy=False).sql(pretty=True, dialect=dialect)
        finally:
            # transform(copy=False) wraps `target` in a Tag in place. Unwrap it so the
            # lineage graph is left as it was and repeated calls don't nest the tags.
            wrapper = target.parent
            if wrapper is not parent and isinstance(wrapper, exp.Tag):
                wrapper.replace(target)
                target.parent = parent
                target.arg_key = arg_key

        return html.escape(sql).replace(open_mark, "<b>").replace(close_mark, "</b>")

    nodes = {}
    edges = []

    for node in self.walk():
        if isinstance(node.expression, exp.Table):
            label = f"FROM {node.expression.this}"
            query = f"SELECT {node.name} FROM {node.expression.this}"
            title = f"<pre>{html.escape(query)}</pre>"
            group = 1
        else:
            label = node.expression.sql(pretty=True, dialect=dialect)
            title = f"<pre>{highlighted_source(node)}</pre>"
            group = 0

        node_id = id(node)

        nodes[node_id] = {
            "id": node_id,
            "label": label,
            "title": title,
            "group": group,
        }

        edges.extend({"from": node_id, "to": id(d)} for d in node.downstream)

    return GraphHTML(nodes, edges, **opts)
def
lineage( column: str | sqlglot.expressions.Column, sql: str | sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema, NoneType] = None, sources: Optional[Dict[str, str | sqlglot.expressions.Query]] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, **kwargs) -> Node:
def lineage(
    column: str | exp.Column,
    sql: str | exp.Expression,
    schema: t.Optional[t.Dict | Schema] = None,
    sources: t.Optional[t.Dict[str, str | exp.Query]] = None,
    dialect: DialectType = None,
    **kwargs,
) -> Node:
    """Trace where a single output column of a SQL query comes from.

    Args:
        column: Name (or column expression) of the output column to trace; it is
            normalized for `dialect` before matching.
        sql: The query, as SQL text or an already-parsed expression.
        schema: The schema of the tables referenced by the query, used when qualifying columns.
        sources: Optional mapping from source names to queries (SQL text or parsed).
            They are expanded into `sql` via `exp.expand`, so lineage continues through them.
        dialect: The SQL dialect used for parsing and normalization.
        **kwargs: Extra keyword arguments forwarded to `qualify.qualify`. They override
            the defaults `validate_qualify_columns=False` and `identify=False`.

    Returns:
        The root lineage `Node` for `column`.

    Raises:
        SqlglotError: If no scope can be built (the SQL is not a SELECT), or if no
            projection of the outermost query is named `column`.
    """
    tree = maybe_parse(sql, dialect=dialect)
    column = normalize_identifiers.normalize_identifiers(column, dialect=dialect).name

    if sources:
        parsed_sources = {
            name: t.cast(exp.Query, maybe_parse(query, dialect=dialect))
            for name, query in sources.items()
        }
        tree = exp.expand(tree, parsed_sources, dialect=dialect)

    # Caller-supplied kwargs win over these defaults.
    qualify_options = {"validate_qualify_columns": False, "identify": False}
    qualify_options.update(kwargs)
    root = build_scope(
        qualify.qualify(tree, dialect=dialect, schema=schema, **qualify_options)  # type: ignore
    )

    if not root:
        raise SqlglotError("Cannot build lineage, sql must be SELECT")

    if all(projection.alias_or_name != column for projection in root.expression.selects):
        raise SqlglotError(f"Cannot find column '{column}' in query.")

    return to_node(column, root, dialect)
Build the lineage graph for a column of a SQL query.
Arguments:
- column: The column to build the lineage for.
- sql: The SQL string or expression.
- schema: The schema of the tables referenced by the query, used when qualifying columns.
- sources: A mapping from source names to queries (SQL strings or parsed queries). They are expanded into the query so that lineage can continue through them.
- dialect: The dialect of input SQL.
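- **kwargs: Extra keyword arguments forwarded to the `qualify` optimizer. They override the defaults `validate_qualify_columns=False` and `identify=False`.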
Returns:
A lineage node.
def
to_node( column: str | int, scope: sqlglot.optimizer.scope.Scope, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType], scope_name: Optional[str] = None, upstream: Optional[Node] = None, source_name: Optional[str] = None, reference_node_name: Optional[str] = None) -> Node:
def to_node(
    column: str | int,
    scope: Scope,
    dialect: DialectType,
    scope_name: t.Optional[str] = None,
    upstream: t.Optional[Node] = None,
    source_name: t.Optional[str] = None,
    reference_node_name: t.Optional[str] = None,
) -> Node:
    """Build the lineage node for `column` in `scope` and recurse into the scopes it reads.

    Args:
        column: The projection to trace, by output name or by position. Positions are
            used when descending into the branches of a UNION.
        scope: The scope whose query defines the column.
        dialect: The dialect; passed down the recursion and used when logging unknown subqueries.
        scope_name: If given, the node is named "<scope_name>.<column>".
        upstream: The node the new node is attached to (as a downstream), if any.
        source_name: Name taken from a "source: <name>" comment on a derived table.
        reference_node_name: Stored on the new node as `reference_node_name`.

    Returns:
        The node for `column`. For a UNION, this is `upstream` if given, otherwise a new
        "UNION" node.

    Raises:
        ValueError: If `column` cannot be located in the selects of a UNION.
    """
    source_names = {
        dt.alias: dt.comments[0].split()[1]
        for dt in scope.derived_tables
        if dt.comments and dt.comments[0].startswith("source: ")
    }

    # Find the specific select clause that is the source of the column we want.
    # This can either be a specific, named select or a generic `*` clause.
    select = (
        scope.expression.selects[column]
        if isinstance(column, int)
        else next(
            (select for select in scope.expression.selects if select.alias_or_name == column),
            exp.Star() if scope.expression.is_star else scope.expression,
        )
    )

    if isinstance(scope.expression, exp.Subquery):
        # Trace the column through the first inner query scope, if there is one.
        inner_scope = next(iter(scope.subquery_scopes), None)
        if inner_scope is not None:
            return to_node(
                column,
                scope=inner_scope,
                dialect=dialect,
                upstream=upstream,
                source_name=source_name,
                reference_node_name=reference_node_name,
            )

    if isinstance(scope.expression, exp.Union):
        upstream = upstream or Node(name="UNION", source=scope.expression, expression=select)

        index = (
            column
            if isinstance(column, int)
            else next(
                (
                    i
                    for i, select in enumerate(scope.expression.selects)
                    if select.alias_or_name == column or select.is_star
                ),
                -1,  # "not found" sentinel (mypy will not allow None here); handled just below
            )
        )

        if index == -1:
            raise ValueError(f"Could not find {column} in {scope.expression}")

        for s in scope.union_scopes:
            to_node(
                index,
                scope=s,
                dialect=dialect,
                upstream=upstream,
                source_name=source_name,
                reference_node_name=reference_node_name,
            )

        return upstream

    if isinstance(scope.expression, exp.Select):
        # For better ergonomics in our node labels, replace the full select with
        # a version that has only the column we care about.
        #   "x", SELECT x, y FROM foo
        #     => "x", SELECT x FROM foo
        source = t.cast(exp.Expression, scope.expression.select(select, append=False))
    else:
        source = scope.expression

    # Create the node for this step in the lineage chain, and attach it to the previous one.
    node = Node(
        name=f"{scope_name}.{column}" if scope_name else str(column),
        source=source,
        expression=select,
        source_name=source_name or "",
        reference_node_name=reference_node_name or "",
    )

    if upstream:
        upstream.downstream.append(node)

    subquery_scopes = {
        id(subquery_scope.expression): subquery_scope for subquery_scope in scope.subquery_scopes
    }

    for subquery in find_all_in_scope(select, exp.UNWRAPPED_QUERIES):
        subquery_scope = subquery_scopes.get(id(subquery))
        if not subquery_scope:
            logger.warning("Unknown subquery scope: %s", subquery.sql(dialect=dialect))
            continue

        for name in subquery.named_selects:
            to_node(name, scope=subquery_scope, dialect=dialect, upstream=node)

    # if the select is a star add all scope sources as downstreams
    if select.is_star:
        for source in scope.sources.values():
            if isinstance(source, Scope):
                source = source.expression
            node.downstream.append(Node(name=select.sql(), source=source, expression=source))

    # Find all columns that went into creating this one to list their lineage nodes.
    # A dict de-duplicates like a set but keeps discovery order, so downstream nodes are
    # produced in the same order on every run (set iteration order can vary).
    source_columns = dict.fromkeys(find_all_in_scope(select, exp.Column))

    # If the source is a UDTF, find the columns used in the UDTF to generate the table.
    # NOTE(review): for a star select, the loop above rebinds `source` to the last scope
    # source, so this checks that source rather than this node's own; kept as-is to
    # preserve existing behaviour — confirm whether intended.
    if isinstance(source, exp.UDTF):
        source_columns.update(dict.fromkeys(source.find_all(exp.Column)))

    for c in source_columns:
        table = c.table
        source = scope.sources.get(table)

        if isinstance(source, Scope):
            selected_node, _ = scope.selected_sources.get(table, (None, None))
            # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
            to_node(
                c.name,
                scope=source,
                dialect=dialect,
                scope_name=table,
                upstream=node,
                source_name=source_names.get(table) or source_name,
                reference_node_name=selected_node.name if selected_node else None,
            )
        else:
            # The source is not a scope - we've reached the end of the line. At this point, if a source is not found
            # it means this column's lineage is unknown. This can happen if the definition of a source used in a query
            # is not passed into the `sources` map.
            source = source or exp.Placeholder()
            node.downstream.append(Node(name=c.sql(), source=source, expression=source))

    return node
class
GraphHTML:
class GraphHTML:
    """Node to HTML generator using vis.js.

    https://visjs.github.io/vis-network/docs/network/
    """

    def __init__(
        self, nodes: t.Dict, edges: t.List, imports: bool = True, options: t.Optional[t.Dict] = None
    ):
        """Store the graph data and build the vis.js network options.

        Args:
            nodes: Mapping of node id to a vis.js node dict; only the values are rendered.
            edges: List of vis.js edge dicts.
            imports: Whether the rendered HTML includes the vis.js <script>/<link> tags.
            options: vis.js network options; top-level keys replace the defaults wholesale.
        """
        self.imports = imports

        self.options = {
            "height": "500px",
            "width": "100%",
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "nodeSpacing": 200,
                    "sortMethod": "directed",
                },
            },
            "interaction": {
                "dragNodes": False,
                "selectable": False,
            },
            "physics": {
                "enabled": False,
            },
            "edges": {
                "arrows": "to",
            },
            "nodes": {
                "font": "20px monaco",
                "shape": "box",
                "widthConstraint": {
                    "maximum": 300,
                },
            },
            **(options or {}),
        }

        self.nodes = nodes
        self.edges = edges

    @staticmethod
    def _script_json(value: t.Any) -> str:
        """Serialize `value` as JSON that is safe to inline inside a <script> element.

        `<`, `>` and `&` are emitted as unicode escapes, so a string such as
        "</script>" cannot terminate the script early. JavaScript decodes the escapes
        back to the same characters, so the data is unchanged.
        """
        return (
            json.dumps(value)
            .replace("<", "\\u003c")
            .replace(">", "\\u003e")
            .replace("&", "\\u0026")
        )

    def __str__(self):
        nodes = self._script_json(list(self.nodes.values()))
        edges = self._script_json(self.edges)
        options = self._script_json(self.options)
        imports = (
            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
            if self.imports
            else ""
        )

        return f"""<div>
  <div id="sqlglot-lineage"></div>
  {imports}
  <script type="text/javascript">
    var nodes = new vis.DataSet({nodes})
    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])

    new vis.Network(
        document.getElementById("sqlglot-lineage"),
        {{
            nodes: nodes,
            edges: new vis.DataSet({edges})
        }},
        {options},
    )
  </script>
</div>"""

    def _repr_html_(self) -> str:
        return self.__str__()
Node to HTML generator using vis.js.
GraphHTML( nodes: Dict, edges: List, imports: bool = True, options: Optional[Dict] = None)
def __init__(
    self, nodes: t.Dict, edges: t.List, imports: bool = True, options: t.Optional[t.Dict] = None
):
    """Store the graph data and build the vis.js network options.

    Args:
        nodes: Mapping of node id to a vis.js node dict.
        edges: List of vis.js edge dicts.
        imports: Whether the vis.js <script>/<link> tags should be emitted.
        options: vis.js network options; top-level keys replace the defaults wholesale
            (later keys win, so a caller's "nodes" dict replaces the default "nodes" dict).
    """
    defaults = dict(
        height="500px",
        width="100%",
        layout={"hierarchical": {"enabled": True, "nodeSpacing": 200, "sortMethod": "directed"}},
        interaction={"dragNodes": False, "selectable": False},
        physics={"enabled": False},
        edges={"arrows": "to"},
        nodes={"font": "20px monaco", "shape": "box", "widthConstraint": {"maximum": 300}},
    )

    self.imports = imports
    self.options = {**defaults, **(options or {})}
    self.nodes = nodes
    self.edges = edges