sqlglot.lineage
1from __future__ import annotations 2 3import json 4import logging 5import typing as t 6from dataclasses import dataclass, field 7 8from sqlglot import Schema, exp, maybe_parse 9from sqlglot.errors import SqlglotError 10from sqlglot.optimizer import Scope, build_scope, find_all_in_scope, normalize_identifiers, qualify 11 12if t.TYPE_CHECKING: 13 from sqlglot.dialects.dialect import DialectType 14 15logger = logging.getLogger("sqlglot") 16 17 18@dataclass(frozen=True) 19class Node: 20 name: str 21 expression: exp.Expression 22 source: exp.Expression 23 downstream: t.List[Node] = field(default_factory=list) 24 source_name: str = "" 25 reference_node_name: str = "" 26 27 def walk(self) -> t.Iterator[Node]: 28 yield self 29 30 for d in self.downstream: 31 yield from d.walk() 32 33 def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML: 34 nodes = {} 35 edges = [] 36 37 for node in self.walk(): 38 if isinstance(node.expression, exp.Table): 39 label = f"FROM {node.expression.this}" 40 title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>" 41 group = 1 42 else: 43 label = node.expression.sql(pretty=True, dialect=dialect) 44 source = node.source.transform( 45 lambda n: ( 46 exp.Tag(this=n, prefix="<b>", postfix="</b>") if n is node.expression else n 47 ), 48 copy=False, 49 ).sql(pretty=True, dialect=dialect) 50 title = f"<pre>{source}</pre>" 51 group = 0 52 53 node_id = id(node) 54 55 nodes[node_id] = { 56 "id": node_id, 57 "label": label, 58 "title": title, 59 "group": group, 60 } 61 62 for d in node.downstream: 63 edges.append({"from": node_id, "to": id(d)}) 64 return GraphHTML(nodes, edges, **opts) 65 66 67def lineage( 68 column: str | exp.Column, 69 sql: str | exp.Expression, 70 schema: t.Optional[t.Dict | Schema] = None, 71 sources: t.Optional[t.Dict[str, str | exp.Query]] = None, 72 dialect: DialectType = None, 73 **kwargs, 74) -> Node: 75 """Build the lineage graph for a column of a SQL query. 
76 77 Args: 78 column: The column to build the lineage for. 79 sql: The SQL string or expression. 80 schema: The schema of tables. 81 sources: A mapping of queries which will be used to continue building lineage. 82 dialect: The dialect of input SQL. 83 **kwargs: Qualification optimizer kwargs. 84 85 Returns: 86 A lineage node. 87 """ 88 89 expression = maybe_parse(sql, dialect=dialect) 90 column = normalize_identifiers.normalize_identifiers(column, dialect=dialect).name 91 92 if sources: 93 expression = exp.expand( 94 expression, 95 {k: t.cast(exp.Query, maybe_parse(v, dialect=dialect)) for k, v in sources.items()}, 96 dialect=dialect, 97 ) 98 99 qualified = qualify.qualify( 100 expression, 101 dialect=dialect, 102 schema=schema, 103 **{"validate_qualify_columns": False, "identify": False, **kwargs}, # type: ignore 104 ) 105 106 scope = build_scope(qualified) 107 108 if not scope: 109 raise SqlglotError("Cannot build lineage, sql must be SELECT") 110 111 if not any(select.alias_or_name == column for select in scope.expression.selects): 112 raise SqlglotError(f"Cannot find column '{column}' in query.") 113 114 return to_node(column, scope, dialect) 115 116 117def to_node( 118 column: str | int, 119 scope: Scope, 120 dialect: DialectType, 121 scope_name: t.Optional[str] = None, 122 upstream: t.Optional[Node] = None, 123 source_name: t.Optional[str] = None, 124 reference_node_name: t.Optional[str] = None, 125) -> Node: 126 source_names = { 127 dt.alias: dt.comments[0].split()[1] 128 for dt in scope.derived_tables 129 if dt.comments and dt.comments[0].startswith("source: ") 130 } 131 132 # Find the specific select clause that is the source of the column we want. 133 # This can either be a specific, named select or a generic `*` clause. 
134 select = ( 135 scope.expression.selects[column] 136 if isinstance(column, int) 137 else next( 138 (select for select in scope.expression.selects if select.alias_or_name == column), 139 exp.Star() if scope.expression.is_star else scope.expression, 140 ) 141 ) 142 143 if isinstance(scope.expression, exp.Subquery): 144 for source in scope.subquery_scopes: 145 return to_node( 146 column, 147 scope=source, 148 dialect=dialect, 149 upstream=upstream, 150 source_name=source_name, 151 reference_node_name=reference_node_name, 152 ) 153 if isinstance(scope.expression, exp.Union): 154 upstream = upstream or Node(name="UNION", source=scope.expression, expression=select) 155 156 index = ( 157 column 158 if isinstance(column, int) 159 else next( 160 ( 161 i 162 for i, select in enumerate(scope.expression.selects) 163 if select.alias_or_name == column or select.is_star 164 ), 165 -1, # mypy will not allow a None here, but a negative index should never be returned 166 ) 167 ) 168 169 if index == -1: 170 raise ValueError(f"Could not find {column} in {scope.expression}") 171 172 for s in scope.union_scopes: 173 to_node( 174 index, 175 scope=s, 176 dialect=dialect, 177 upstream=upstream, 178 source_name=source_name, 179 reference_node_name=reference_node_name, 180 ) 181 182 return upstream 183 184 if isinstance(scope.expression, exp.Select): 185 # For better ergonomics in our node labels, replace the full select with 186 # a version that has only the column we care about. 187 # "x", SELECT x, y FROM foo 188 # => "x", SELECT x FROM foo 189 source = t.cast(exp.Expression, scope.expression.select(select, append=False)) 190 else: 191 source = scope.expression 192 193 # Create the node for this step in the lineage chain, and attach it to the previous one. 
194 node = Node( 195 name=f"{scope_name}.{column}" if scope_name else str(column), 196 source=source, 197 expression=select, 198 source_name=source_name or "", 199 reference_node_name=reference_node_name or "", 200 ) 201 202 if upstream: 203 upstream.downstream.append(node) 204 205 subquery_scopes = { 206 id(subquery_scope.expression): subquery_scope for subquery_scope in scope.subquery_scopes 207 } 208 209 for subquery in find_all_in_scope(select, exp.UNWRAPPED_QUERIES): 210 subquery_scope = subquery_scopes.get(id(subquery)) 211 if not subquery_scope: 212 logger.warning(f"Unknown subquery scope: {subquery.sql(dialect=dialect)}") 213 continue 214 215 for name in subquery.named_selects: 216 to_node(name, scope=subquery_scope, dialect=dialect, upstream=node) 217 218 # if the select is a star add all scope sources as downstreams 219 if select.is_star: 220 for source in scope.sources.values(): 221 if isinstance(source, Scope): 222 source = source.expression 223 node.downstream.append(Node(name=select.sql(), source=source, expression=source)) 224 225 # Find all columns that went into creating this one to list their lineage nodes. 226 source_columns = set(find_all_in_scope(select, exp.Column)) 227 228 # If the source is a UDTF find columns used in the UTDF to generate the table 229 if isinstance(source, exp.UDTF): 230 source_columns |= set(source.find_all(exp.Column)) 231 232 for c in source_columns: 233 table = c.table 234 source = scope.sources.get(table) 235 236 if isinstance(source, Scope): 237 selected_node, _ = scope.selected_sources.get(table, (None, None)) 238 # The table itself came from a more specific scope. Recurse into that one using the unaliased column name. 
239 to_node( 240 c.name, 241 scope=source, 242 dialect=dialect, 243 scope_name=table, 244 upstream=node, 245 source_name=source_names.get(table) or source_name, 246 reference_node_name=selected_node.name if selected_node else None, 247 ) 248 else: 249 # The source is not a scope - we've reached the end of the line. At this point, if a source is not found 250 # it means this column's lineage is unknown. This can happen if the definition of a source used in a query 251 # is not passed into the `sources` map. 252 source = source or exp.Placeholder() 253 node.downstream.append(Node(name=c.sql(), source=source, expression=source)) 254 255 return node 256 257 258class GraphHTML: 259 """Node to HTML generator using vis.js. 260 261 https://visjs.github.io/vis-network/docs/network/ 262 """ 263 264 def __init__( 265 self, nodes: t.Dict, edges: t.List, imports: bool = True, options: t.Optional[t.Dict] = None 266 ): 267 self.imports = imports 268 269 self.options = { 270 "height": "500px", 271 "width": "100%", 272 "layout": { 273 "hierarchical": { 274 "enabled": True, 275 "nodeSpacing": 200, 276 "sortMethod": "directed", 277 }, 278 }, 279 "interaction": { 280 "dragNodes": False, 281 "selectable": False, 282 }, 283 "physics": { 284 "enabled": False, 285 }, 286 "edges": { 287 "arrows": "to", 288 }, 289 "nodes": { 290 "font": "20px monaco", 291 "shape": "box", 292 "widthConstraint": { 293 "maximum": 300, 294 }, 295 }, 296 **(options or {}), 297 } 298 299 self.nodes = nodes 300 self.edges = edges 301 302 def __str__(self): 303 nodes = json.dumps(list(self.nodes.values())) 304 edges = json.dumps(self.edges) 305 options = json.dumps(self.options) 306 imports = ( 307 """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script> 308 <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script> 309 <link rel="stylesheet" type="text/css" 
href="https://unpkg.com/vis-network/styles/vis-network.min.css" />""" 310 if self.imports 311 else "" 312 ) 313 314 return f"""<div> 315 <div id="sqlglot-lineage"></div> 316 {imports} 317 <script type="text/javascript"> 318 var nodes = new vis.DataSet({nodes}) 319 nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0]) 320 321 new vis.Network( 322 document.getElementById("sqlglot-lineage"), 323 {{ 324 nodes: nodes, 325 edges: new vis.DataSet({edges}) 326 }}, 327 {options}, 328 ) 329 </script> 330</div>""" 331 332 def _repr_html_(self) -> str: 333 return self.__str__()
logger =
<Logger sqlglot (WARNING)>
@dataclass(frozen=True)
class
Node:
@dataclass(frozen=True)
class Node:
    """One step in a column's lineage graph."""

    # Display name of the column at this step, e.g. "x" or "<scope alias>.x".
    name: str
    # The expression producing the column at this step (a projection, table, ...).
    expression: exp.Expression
    # The enclosing expression shown as context for `expression`.
    source: exp.Expression
    # Nodes this step's column is derived from.
    downstream: t.List[Node] = field(default_factory=list)
    # Name taken from a "source: <name>" comment on a derived table, if any.
    source_name: str = ""
    # Name of the FROM-clause expression through which this scope was reached, if any.
    reference_node_name: str = ""

    def walk(self) -> t.Iterator[Node]:
        """Yield this node, then every downstream node, depth-first in pre-order."""
        yield self

        for d in self.downstream:
            yield from d.walk()

    def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML:
        """Render this node and everything downstream of it as a vis.js graph.

        Table nodes are labelled ``FROM <table>``. Every other node is labelled
        with its expression's SQL, and its tooltip shows the source query with
        that expression wrapped in ``<b>`` tags.

        Args:
            dialect: The SQL dialect used to render labels and tooltips.
            **opts: Extra keyword arguments forwarded to `GraphHTML`
                (``imports``, ``options``).

        Returns:
            A `GraphHTML` whose ``str()`` is an embeddable HTML snippet.
        """

        def highlighted_source(node: Node) -> str:
            # Render `node.source` with `node.expression` wrapped in <b>...</b>.
            # The target is located by identity, so the wrap must happen on the
            # original tree. It is undone afterwards; leaving the Tag in place
            # corrupted `node.source` and nested tags on repeated calls.
            target = node.expression
            if node.source is target:
                # Wrapping the root would only re-parent it; add the markers directly.
                return f"<b>{target.sql(pretty=True, dialect=dialect)}</b>"

            tags: t.List[exp.Tag] = []

            def wrap(n: exp.Expression) -> exp.Expression:
                if n is target:
                    tag = exp.Tag(this=n, prefix="<b>", postfix="</b>")
                    tags.append(tag)
                    return tag
                return n

            node.source.transform(wrap, copy=False)
            try:
                return node.source.sql(pretty=True, dialect=dialect)
            finally:
                for tag in tags:
                    tag.replace(tag.this)

        nodes = {}
        edges = []

        for node in self.walk():
            if isinstance(node.expression, exp.Table):
                label = f"FROM {node.expression.this}"
                title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>"
                group = 1
            else:
                label = node.expression.sql(pretty=True, dialect=dialect)
                title = f"<pre>{highlighted_source(node)}</pre>"
                group = 0

            node_id = id(node)

            nodes[node_id] = {
                "id": node_id,
                "label": label,
                "title": title,
                "group": group,
            }

            edges.extend({"from": node_id, "to": id(d)} for d in node.downstream)

        return GraphHTML(nodes, edges, **opts)
Node( name: str, expression: sqlglot.expressions.Expression, source: sqlglot.expressions.Expression, downstream: List[Node] = <factory>, source_name: str = '', reference_node_name: str = '')
expression: sqlglot.expressions.Expression
source: sqlglot.expressions.Expression
downstream: List[Node]
def
to_html( self, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, **opts) -> GraphHTML:
def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML:
    """Render this node and everything downstream of it as a vis.js graph.

    Table nodes are labelled ``FROM <table>``. Every other node is labelled
    with its expression's SQL, and its tooltip shows the source query with
    that expression wrapped in ``<b>`` tags.

    Args:
        dialect: The SQL dialect used to render labels and tooltips.
        **opts: Extra keyword arguments forwarded to `GraphHTML`
            (``imports``, ``options``).

    Returns:
        A `GraphHTML` whose ``str()`` is an embeddable HTML snippet.
    """

    def highlighted_source(node: Node) -> str:
        # Render `node.source` with `node.expression` wrapped in <b>...</b>.
        # The target is located by identity, so the wrap must happen on the
        # original tree. It is undone afterwards; leaving the Tag in place
        # corrupted `node.source` and nested tags on repeated calls.
        target = node.expression
        if node.source is target:
            # Wrapping the root would only re-parent it; add the markers directly.
            return f"<b>{target.sql(pretty=True, dialect=dialect)}</b>"

        tags: t.List[exp.Tag] = []

        def wrap(n: exp.Expression) -> exp.Expression:
            if n is target:
                tag = exp.Tag(this=n, prefix="<b>", postfix="</b>")
                tags.append(tag)
                return tag
            return n

        node.source.transform(wrap, copy=False)
        try:
            return node.source.sql(pretty=True, dialect=dialect)
        finally:
            for tag in tags:
                tag.replace(tag.this)

    nodes = {}
    edges = []

    for node in self.walk():
        if isinstance(node.expression, exp.Table):
            label = f"FROM {node.expression.this}"
            title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>"
            group = 1
        else:
            label = node.expression.sql(pretty=True, dialect=dialect)
            title = f"<pre>{highlighted_source(node)}</pre>"
            group = 0

        node_id = id(node)

        nodes[node_id] = {
            "id": node_id,
            "label": label,
            "title": title,
            "group": group,
        }

        edges.extend({"from": node_id, "to": id(d)} for d in node.downstream)

    return GraphHTML(nodes, edges, **opts)
def
lineage( column: str | sqlglot.expressions.Column, sql: str | sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema, NoneType] = None, sources: Optional[Dict[str, str | sqlglot.expressions.Query]] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, **kwargs) -> Node:
def lineage(
    column: str | exp.Column,
    sql: str | exp.Expression,
    schema: t.Optional[t.Dict | Schema] = None,
    sources: t.Optional[t.Dict[str, str | exp.Query]] = None,
    dialect: DialectType = None,
    **kwargs,
) -> Node:
    """Build the lineage graph for a column of a SQL query.

    Args:
        column: The column to build the lineage for.
        sql: The SQL string or expression.
        schema: The schema of the referenced tables, used during qualification.
        sources: A mapping of source (table) names to the queries that define
            them; references to those names are expanded so lineage continues
            through the queries.
        dialect: The dialect of input SQL.
        **kwargs: Extra keyword arguments forwarded to the `qualify` optimizer;
            they override the defaults set here.

    Returns:
        A lineage node.

    Raises:
        SqlglotError: If no scope can be built from `sql`, or if `column` is
            not one of the query's output columns.
    """
    parsed = maybe_parse(sql, dialect=dialect)
    target = normalize_identifiers.normalize_identifiers(column, dialect=dialect).name

    if sources:
        # Parse every source definition, then inline them into the main query.
        parsed_sources = {}
        for name, definition in sources.items():
            parsed_sources[name] = t.cast(exp.Query, maybe_parse(definition, dialect=dialect))
        parsed = exp.expand(parsed, parsed_sources, dialect=dialect)

    # Defaults first so caller-supplied kwargs take precedence.
    qualify_options = {"validate_qualify_columns": False, "identify": False, **kwargs}
    qualified = qualify.qualify(parsed, dialect=dialect, schema=schema, **qualify_options)  # type: ignore

    root_scope = build_scope(qualified)
    if not root_scope:
        raise SqlglotError("Cannot build lineage, sql must be SELECT")

    output_names = {projection.alias_or_name for projection in root_scope.expression.selects}
    if target not in output_names:
        raise SqlglotError(f"Cannot find column '{target}' in query.")

    return to_node(target, root_scope, dialect)
Build the lineage graph for a column of a SQL query.
Arguments:
- column: The column to build the lineage for.
- sql: The SQL string or expression.
- schema: The schema of tables.
- sources: A mapping of source (table) names to the SQL queries that define them; references to those names are expanded so lineage continues through the queries.
- dialect: The dialect of input SQL.
- **kwargs: Extra keyword arguments forwarded to the `qualify` optimizer.
Returns:
A lineage node.
def
to_node( column: str | int, scope: sqlglot.optimizer.scope.Scope, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType], scope_name: Optional[str] = None, upstream: Optional[Node] = None, source_name: Optional[str] = None, reference_node_name: Optional[str] = None) -> Node:
def to_node(
    column: str | int,
    scope: Scope,
    dialect: DialectType,
    scope_name: t.Optional[str] = None,
    upstream: t.Optional[Node] = None,
    source_name: t.Optional[str] = None,
    reference_node_name: t.Optional[str] = None,
) -> Node:
    """Create the lineage node for `column` in `scope` and recurse into its sources.

    Args:
        column: The projection to trace. This is its output name, or, when
            tracing a UNION branch, its position in the scope's select list.
        scope: The scope that selects the column.
        dialect: The dialect used when rendering SQL in log messages.
        scope_name: The alias of this scope in its parent; prefixed onto the node name.
        upstream: If given, the new node is appended to its `downstream` list.
        source_name: Inherited source name, recorded on the created node.
        reference_node_name: Name of the FROM-clause expression through which
            this scope was reached.

    Returns:
        The node created for this step. For a UNION scope, this is the UNION
        node (or `upstream`, if one was given) that every branch was attached to.

    Raises:
        ValueError: If `column` cannot be located in a UNION scope's selects.
    """
    # Derived tables annotated with a leading "source: <name>" comment pass
    # that name down to the lineage nodes built from them.
    source_names = {
        dt.alias: dt.comments[0].split()[1]
        for dt in scope.derived_tables
        if dt.comments and dt.comments[0].startswith("source: ")
    }

    # Find the specific select clause that is the source of the column we want.
    # This can either be a specific, named select or a generic `*` clause.
    select = (
        scope.expression.selects[column]
        if isinstance(column, int)
        else next(
            (select for select in scope.expression.selects if select.alias_or_name == column),
            exp.Star() if scope.expression.is_star else scope.expression,
        )
    )

    if isinstance(scope.expression, exp.Subquery):
        # A parenthesized query: continue in the scope of the query it wraps.
        # Without an inner scope, fall through and treat the Subquery itself as the source.
        if scope.subquery_scopes:
            return to_node(
                column,
                scope=scope.subquery_scopes[0],
                dialect=dialect,
                upstream=upstream,
                source_name=source_name,
                reference_node_name=reference_node_name,
            )

    if isinstance(scope.expression, exp.Union):
        upstream = upstream or Node(name="UNION", source=scope.expression, expression=select)

        # UNION branches are matched by position, not by name.
        index = (
            column
            if isinstance(column, int)
            else next(
                (
                    i
                    for i, projection in enumerate(scope.expression.selects)
                    if projection.alias_or_name == column or projection.is_star
                ),
                -1,  # mypy will not allow a None here, but a negative index should never be returned
            )
        )

        if index == -1:
            raise ValueError(f"Could not find {column} in {scope.expression}")

        for union_scope in scope.union_scopes:
            to_node(
                index,
                scope=union_scope,
                dialect=dialect,
                upstream=upstream,
                source_name=source_name,
                reference_node_name=reference_node_name,
            )

        return upstream

    if isinstance(scope.expression, exp.Select):
        # For better ergonomics in our node labels, replace the full select with
        # a version that has only the column we care about.
        #   "x", SELECT x, y FROM foo
        #     => "x", SELECT x FROM foo
        source = t.cast(exp.Expression, scope.expression.select(select, append=False))
    else:
        source = scope.expression

    # Create the node for this step in the lineage chain, and attach it to the previous one.
    node = Node(
        name=f"{scope_name}.{column}" if scope_name else str(column),
        source=source,
        expression=select,
        source_name=source_name or "",
        reference_node_name=reference_node_name or "",
    )

    if upstream:
        upstream.downstream.append(node)

    subquery_scopes = {
        id(subquery_scope.expression): subquery_scope for subquery_scope in scope.subquery_scopes
    }

    for subquery in find_all_in_scope(select, exp.UNWRAPPED_QUERIES):
        subquery_scope = subquery_scopes.get(id(subquery))
        if not subquery_scope:
            logger.warning("Unknown subquery scope: %s", subquery.sql(dialect=dialect))
            continue

        for name in subquery.named_selects:
            to_node(name, scope=subquery_scope, dialect=dialect, upstream=node)

    # If the select is a star, add all scope sources as downstreams.
    # NOTE: this loop rebinds `source`, so the UDTF check below then inspects the
    # last scope source rather than the narrowed select (existing behavior, kept).
    if select.is_star:
        for source in scope.sources.values():
            if isinstance(source, Scope):
                source = source.expression
            node.downstream.append(Node(name=select.sql(), source=source, expression=source))

    # Find all columns that went into creating this one to list their lineage nodes.
    # A dict (rather than a set) dedupes the same way but keeps discovery order, so
    # the order of `downstream` does not depend on per-process hash randomization.
    source_columns = dict.fromkeys(find_all_in_scope(select, exp.Column))

    # If the source is a UDTF, find the columns used in the UDTF to generate the table.
    if isinstance(source, exp.UDTF):
        source_columns.update(dict.fromkeys(source.find_all(exp.Column)))

    for c in source_columns:
        table = c.table
        source = scope.sources.get(table)

        if isinstance(source, Scope):
            selected_node, _ = scope.selected_sources.get(table, (None, None))
            # The table itself came from a more specific scope. Recurse into that one using the unaliased column name.
            to_node(
                c.name,
                scope=source,
                dialect=dialect,
                scope_name=table,
                upstream=node,
                source_name=source_names.get(table) or source_name,
                reference_node_name=selected_node.name if selected_node else None,
            )
        else:
            # The source is not a scope - we've reached the end of the line. At this point, if a source is not found
            # it means this column's lineage is unknown. This can happen if the definition of a source used in a query
            # is not passed into the `sources` map.
            source = source or exp.Placeholder()
            node.downstream.append(Node(name=c.sql(), source=source, expression=source))

    return node
class
GraphHTML:
class GraphHTML:
    """Node to HTML generator using vis.js.

    https://visjs.github.io/vis-network/docs/network/
    """

    def __init__(
        self, nodes: t.Dict, edges: t.List, imports: bool = True, options: t.Optional[t.Dict] = None
    ):
        """Store the graph data and merge `options` over the default vis.js options.

        Args:
            nodes: Mapping of node id to a vis.js node dict; only the values are rendered.
            edges: List of vis.js edge dicts.
            imports: Whether to emit the vis.js <script>/<link> tags.
            options: vis.js network options. Top-level keys replace the defaults
                wholesale (the merge is shallow).
        """
        self.imports = imports

        self.options = {
            "height": "500px",
            "width": "100%",
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "nodeSpacing": 200,
                    "sortMethod": "directed",
                },
            },
            "interaction": {
                "dragNodes": False,
                "selectable": False,
            },
            "physics": {
                "enabled": False,
            },
            "edges": {
                "arrows": "to",
            },
            "nodes": {
                "font": "20px monaco",
                "shape": "box",
                "widthConstraint": {
                    "maximum": 300,
                },
            },
            **(options or {}),
        }

        self.nodes = nodes
        self.edges = edges

    @staticmethod
    def _to_script_json(value: t.Any) -> str:
        """Serialize `value` as JSON that is safe to embed inside an HTML <script>.

        `json.dumps` leaves `<`, `>` and `&` as-is, so a string such as a SQL
        literal containing "</script>" would end the script element early.
        Escaping them as \\u003c / \\u003e / \\u0026 yields the same value in JS.
        """
        return (
            json.dumps(value)
            .replace("<", "\\u003c")
            .replace(">", "\\u003e")
            .replace("&", "\\u0026")
        )

    def __str__(self) -> str:
        nodes = self._to_script_json(list(self.nodes.values()))
        edges = self._to_script_json(self.edges)
        options = self._to_script_json(self.options)
        imports = (
            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
            if self.imports
            else ""
        )

        return f"""<div>
  <div id="sqlglot-lineage"></div>
  {imports}
  <script type="text/javascript">
    var nodes = new vis.DataSet({nodes})
    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])

    new vis.Network(
        document.getElementById("sqlglot-lineage"),
        {{
            nodes: nodes,
            edges: new vis.DataSet({edges})
        }},
        {options},
    )
  </script>
</div>"""

    def _repr_html_(self) -> str:
        return self.__str__()
Node to HTML generator using vis.js.
GraphHTML( nodes: Dict, edges: List, imports: bool = True, options: Optional[Dict] = None)
def __init__(
    self, nodes: t.Dict, edges: t.List, imports: bool = True, options: t.Optional[t.Dict] = None
):
    """Keep the graph payload and build the vis.js options for rendering.

    Args:
        nodes: Mapping of node id to a vis.js node dict; only the values are rendered.
        edges: List of vis.js edge dicts.
        imports: Whether the rendered HTML should include the vis.js <script>/<link> tags.
        options: vis.js network options laid over the defaults below. The overlay
            is shallow: a top-level key supplied here replaces the default value
            for that key entirely.
    """
    self.imports = imports

    # Default look: a top-down hierarchical layout with static, boxed nodes.
    hierarchical_layout = {"enabled": True, "nodeSpacing": 200, "sortMethod": "directed"}
    node_style = {"font": "20px monaco", "shape": "box", "widthConstraint": {"maximum": 300}}
    defaults = {
        "height": "500px",
        "width": "100%",
        "layout": {"hierarchical": hierarchical_layout},
        "interaction": {"dragNodes": False, "selectable": False},
        "physics": {"enabled": False},
        "edges": {"arrows": "to"},
        "nodes": node_style,
    }
    self.options = {**defaults, **(options or {})}

    self.nodes = nodes
    self.edges = edges