Edit on GitHub

sqlglot.lineage

  1from __future__ import annotations
  2
  3import json
  4import logging
  5import typing as t
  6from dataclasses import dataclass, field
  7
  8from sqlglot import Schema, exp, maybe_parse
  9from sqlglot.errors import SqlglotError
 10from sqlglot.optimizer import Scope, build_scope, find_all_in_scope, normalize_identifiers, qualify
 11
 12if t.TYPE_CHECKING:
 13    from sqlglot.dialects.dialect import DialectType
 14
 15logger = logging.getLogger("sqlglot")
 16
 17
 18@dataclass(frozen=True)
 19class Node:
 20    name: str
 21    expression: exp.Expression
 22    source: exp.Expression
 23    downstream: t.List[Node] = field(default_factory=list)
 24    source_name: str = ""
 25    reference_node_name: str = ""
 26
 27    def walk(self) -> t.Iterator[Node]:
 28        yield self
 29
 30        for d in self.downstream:
 31            if isinstance(d, Node):
 32                yield from d.walk()
 33            else:
 34                yield d
 35
 36    def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML:
 37        nodes = {}
 38        edges = []
 39
 40        for node in self.walk():
 41            if isinstance(node.expression, exp.Table):
 42                label = f"FROM {node.expression.this}"
 43                title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>"
 44                group = 1
 45            else:
 46                label = node.expression.sql(pretty=True, dialect=dialect)
 47                source = node.source.transform(
 48                    lambda n: (
 49                        exp.Tag(this=n, prefix="<b>", postfix="</b>") if n is node.expression else n
 50                    ),
 51                    copy=False,
 52                ).sql(pretty=True, dialect=dialect)
 53                title = f"<pre>{source}</pre>"
 54                group = 0
 55
 56            node_id = id(node)
 57
 58            nodes[node_id] = {
 59                "id": node_id,
 60                "label": label,
 61                "title": title,
 62                "group": group,
 63            }
 64
 65            for d in node.downstream:
 66                edges.append({"from": node_id, "to": id(d)})
 67        return GraphHTML(nodes, edges, **opts)
 68
 69
 70def lineage(
 71    column: str | exp.Column,
 72    sql: str | exp.Expression,
 73    schema: t.Optional[t.Dict | Schema] = None,
 74    sources: t.Optional[t.Dict[str, str | exp.Query]] = None,
 75    dialect: DialectType = None,
 76    **kwargs,
 77) -> Node:
 78    """Build the lineage graph for a column of a SQL query.
 79
 80    Args:
 81        column: The column to build the lineage for.
 82        sql: The SQL string or expression.
 83        schema: The schema of tables.
 84        sources: A mapping of queries which will be used to continue building lineage.
 85        dialect: The dialect of input SQL.
 86        **kwargs: Qualification optimizer kwargs.
 87
 88    Returns:
 89        A lineage node.
 90    """
 91
 92    expression = maybe_parse(sql, dialect=dialect)
 93    column = normalize_identifiers.normalize_identifiers(column, dialect=dialect).name
 94
 95    if sources:
 96        expression = exp.expand(
 97            expression,
 98            {k: t.cast(exp.Query, maybe_parse(v, dialect=dialect)) for k, v in sources.items()},
 99            dialect=dialect,
100        )
101
102    qualified = qualify.qualify(
103        expression,
104        dialect=dialect,
105        schema=schema,
106        **{"validate_qualify_columns": False, "identify": False, **kwargs},  # type: ignore
107    )
108
109    scope = build_scope(qualified)
110
111    if not scope:
112        raise SqlglotError("Cannot build lineage, sql must be SELECT")
113
114    if not any(select.alias_or_name == column for select in scope.expression.selects):
115        raise SqlglotError(f"Cannot find column '{column}' in query.")
116
117    return to_node(column, scope, dialect)
118
119
def to_node(
    column: str | int,
    scope: Scope,
    dialect: DialectType,
    scope_name: t.Optional[str] = None,
    upstream: t.Optional[Node] = None,
    source_name: t.Optional[str] = None,
    reference_node_name: t.Optional[str] = None,
) -> Node:
    """Build the lineage node for ``column`` in ``scope`` and recurse into its inputs.

    Args:
        column: Projection name to trace, or a projection index (UNION branches are
            traced by position).
        scope: The scope whose projections are searched.
        dialect: Dialect used to render SQL in warnings; passed down recursively.
        scope_name: Alias of ``scope`` in its parent; when set, it prefixes the node name.
        upstream: If given, the new node is appended to its ``downstream``.
        source_name: Recorded on the new node (empty string when ``None``).
        reference_node_name: Recorded on the new node (empty string when ``None``).

    Returns:
        The node built for ``column``. For a UNION scope this is the UNION node
        (``upstream`` itself when provided). For a subquery scope it is whatever the
        recursion into its first inner scope returns.

    Raises:
        ValueError: If ``column`` cannot be located among a UNION's projections.
    """
    # Derived tables carrying a "source: <name>" comment, keyed by alias. These
    # comments are presumably attached by ``exp.expand`` for ``lineage``'s ``sources``.
    source_names: t.Dict[str, str] = {}
    for derived in scope.derived_tables:
        if derived.comments and derived.comments[0].startswith("source: "):
            source_names[derived.alias] = derived.comments[0].split()[1]

    # Locate the projection being traced: by position or by name. If no name matches,
    # fall back to a generic `*` (star queries) or to the whole scope expression.
    projections = scope.expression.selects
    if isinstance(column, int):
        select = projections[column]
    else:
        fallback = exp.Star() if scope.expression.is_star else scope.expression
        select = next((p for p in projections if p.alias_or_name == column), fallback)

    if isinstance(scope.expression, exp.Subquery):
        # Only the first inner scope is followed, and ``scope_name`` is not forwarded.
        inner = next(iter(scope.subquery_scopes), None)
        if inner is not None:
            return to_node(
                column,
                scope=inner,
                dialect=dialect,
                upstream=upstream,
                source_name=source_name,
                reference_node_name=reference_node_name,
            )

    if isinstance(scope.expression, exp.Union):
        union_node = upstream or Node(name="UNION", source=scope.expression, expression=select)

        if isinstance(column, int):
            index = column
        else:
            # -1 is a "not found" sentinel (mypy rejects None as the default here).
            index = next(
                (
                    position
                    for position, projection in enumerate(scope.expression.selects)
                    if projection.alias_or_name == column or projection.is_star
                ),
                -1,
            )

        if index == -1:
            raise ValueError(f"Could not find {column} in {scope.expression}")

        # Every branch of the set operation contributes the projection at that position.
        for branch in scope.union_scopes:
            to_node(
                index,
                scope=branch,
                dialect=dialect,
                upstream=union_node,
                source_name=source_name,
                reference_node_name=reference_node_name,
            )

        return union_node

    # Nicer labels: show the SELECT narrowed down to just the traced projection,
    #   "x", SELECT x, y FROM foo  =>  "x", SELECT x FROM foo
    if isinstance(scope.expression, exp.Select):
        node_source = t.cast(exp.Expression, scope.expression.select(select, append=False))
    else:
        node_source = scope.expression

    node = Node(
        name=f"{scope_name}.{column}" if scope_name else str(column),
        source=node_source,
        expression=select,
        source_name=source_name or "",
        reference_node_name=reference_node_name or "",
    )
    if upstream:
        upstream.downstream.append(node)

    # Subqueries used inside the projection each contribute all of their projections.
    scopes_by_id = {
        id(inner_scope.expression): inner_scope for inner_scope in scope.subquery_scopes
    }
    for subquery in find_all_in_scope(select, exp.UNWRAPPED_QUERIES):
        subquery_scope = scopes_by_id.get(id(subquery))
        if subquery_scope:
            for name in subquery.named_selects:
                to_node(name, scope=subquery_scope, dialect=dialect, upstream=node)
        else:
            logger.warning(f"Unknown subquery scope: {subquery.sql(dialect=dialect)}")

    # With a `*` projection, every source of the scope feeds this node.
    # NOTE(review): when the projection is a star, the UDTF check below inspects the
    # *last* star source rather than ``node_source``. This matches the module's
    # historical behaviour; confirm it is intended.
    udtf_candidate = node_source
    if select.is_star:
        for star_source in scope.sources.values():
            if isinstance(star_source, Scope):
                star_source = star_source.expression
            node.downstream.append(
                Node(name=select.sql(), source=star_source, expression=star_source)
            )
            udtf_candidate = star_source

    # Every column referenced by the projection is an input to it.
    source_columns = set(find_all_in_scope(select, exp.Column))

    # A UDTF generates its rows from the columns it references, so those count too.
    if isinstance(udtf_candidate, exp.UDTF):
        source_columns |= set(udtf_candidate.find_all(exp.Column))

    for col in source_columns:
        table = col.table
        found = scope.sources.get(table)

        if isinstance(found, Scope):
            # The column comes from a nested scope: recurse with the unaliased name.
            selected_node, _ = scope.selected_sources.get(table, (None, None))
            to_node(
                col.name,
                scope=found,
                dialect=dialect,
                scope_name=table,
                upstream=node,
                source_name=source_names.get(table) or source_name,
                reference_node_name=selected_node.name if selected_node else None,
            )
        else:
            # End of the chain. An unknown source (e.g. a query missing from the
            # ``sources`` map) is represented by a placeholder.
            leaf = found or exp.Placeholder()
            node.downstream.append(Node(name=col.sql(), source=leaf, expression=leaf))

    return node
259
260
class GraphHTML:
    """Node to HTML generator using vis.js.

    https://visjs.github.io/vis-network/docs/network/
    """

    def __init__(
        self, nodes: t.Dict, edges: t.List, imports: bool = True, options: t.Optional[t.Dict] = None
    ):
        """Store the graph data and build the vis.js network options.

        Args:
            nodes: Mapping of node id to vis.js node dict; only the values are rendered.
            edges: vis.js edge dicts, e.g. ``{"from": 1, "to": 2}``.
            imports: Whether to include the vis.js ``<script>``/``<link>`` tags.
            options: vis.js network options. Top-level keys replace the defaults
                wholesale (the merge is shallow).
        """
        self.imports = imports

        self.options = {
            "height": "500px",
            "width": "100%",
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "nodeSpacing": 200,
                    "sortMethod": "directed",
                },
            },
            "interaction": {
                "dragNodes": False,
                "selectable": False,
            },
            "physics": {
                "enabled": False,
            },
            "edges": {
                "arrows": "to",
            },
            "nodes": {
                "font": "20px monaco",
                "shape": "box",
                "widthConstraint": {
                    "maximum": 300,
                },
            },
            **(options or {}),
        }

        self.nodes = nodes
        self.edges = edges

    @staticmethod
    def _script_json(value: t.Any) -> str:
        """Serialize ``value`` as JSON that is safe to inline in a ``<script>`` element.

        ``json.dumps`` leaves ``<``, ``>`` and ``&`` as-is, so a string containing
        ``</script>`` (e.g. SQL text in a node title) would end the script element
        early. Their ``\\uXXXX`` escapes decode to the same JSON value.
        """
        return (
            json.dumps(value)
            .replace("<", "\\u003c")
            .replace(">", "\\u003e")
            .replace("&", "\\u0026")
        )

    def __str__(self):
        """Return an HTML snippet that draws the graph with vis.js."""
        import uuid

        nodes = self._script_json(list(self.nodes.values()))
        edges = self._script_json(self.edges)
        options = self._script_json(self.options)
        # A fresh container id per render. With a fixed id, a second graph on the
        # same page (or re-displaying this one) would be drawn into the first
        # container, since getElementById returns the first match.
        container_id = f"sqlglot-lineage-{uuid.uuid4().hex}"
        imports = (
            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
            if self.imports
            else ""
        )

        return f"""<div>
  <div id="{container_id}"></div>
  {imports}
  <script type="text/javascript">
    var nodes = new vis.DataSet({nodes})
    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])

    new vis.Network(
        document.getElementById("{container_id}"),
        {{
            nodes: nodes,
            edges: new vis.DataSet({edges})
        }},
        {options},
    )
  </script>
</div>"""

    def _repr_html_(self) -> str:
        """Jupyter rich-display hook; returns the same markup as ``str()``."""
        return self.__str__()
logger = <Logger sqlglot (WARNING)>
@dataclass(frozen=True)
class Node:
@dataclass(frozen=True)
class Node:
    """One step of a column-lineage graph.

    Attributes:
        name: Human-readable label for the step (e.g. ``"x"``, ``"t.x"`` or ``"UNION"``).
        expression: The expression this step refers to (a projection, or for leaf
            steps the source ``Table``/placeholder itself).
        source: The expression ``expression`` was taken from.
        downstream: Steps this one was derived from. The list is appended to while
            the graph is being built, even though the dataclass is frozen.
        source_name: Optional annotation carried over by ``to_node``; ``""`` if absent.
        reference_node_name: Optional annotation carried over by ``to_node``; ``""`` if absent.
    """

    name: str
    expression: exp.Expression
    source: exp.Expression
    downstream: t.List[Node] = field(default_factory=list)
    source_name: str = ""
    reference_node_name: str = ""

    def walk(self) -> t.Iterator[Node]:
        """Yield this node and then everything downstream of it, depth-first (pre-order).

        There is no de-duplication: a node reachable along several paths is yielded
        once per path. Non-``Node`` entries in ``downstream`` are yielded as-is.
        """
        yield self

        for d in self.downstream:
            if isinstance(d, Node):
                yield from d.walk()
            else:
                yield d

    def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML:
        """Render the lineage graph rooted at this node as an interactive vis.js widget.

        Table nodes are labelled ``FROM <table>``. Every other node is labelled with
        its expression's SQL, and its tooltip shows the source SQL with that
        expression in bold.

        Args:
            dialect: Dialect used to generate the SQL shown in labels and tooltips.
            **opts: Forwarded to ``GraphHTML`` (``imports``, ``options``).

        Returns:
            A ``GraphHTML`` instance; ``str()`` it to get the markup.
        """
        import html

        # Control characters stand in for <b>/</b> while the SQL is generated. This
        # lets the SQL text be HTML-escaped without escaping the highlight tags.
        open_mark, close_mark = "\x02", "\x03"

        nodes = {}
        edges = []

        for node in self.walk():
            if isinstance(node.expression, exp.Table):
                label = f"FROM {node.expression.this}"
                query = f"SELECT {node.name} FROM {node.expression.this}"
                title = f"<pre>{html.escape(query, quote=False)}</pre>"
                group = 1
            else:
                label = node.expression.sql(pretty=True, dialect=dialect)

                # Highlight a copy. Transforming ``node.source`` in place leaves Tag
                # wrappers inside the node's tree, which corrupts ``node.source`` and
                # nests another <b> on every re-render. ``copy()`` keeps the tree
                # shape, so walking both trees in lockstep finds the copy of
                # ``node.expression``.
                source_copy = node.source.copy()
                target = next(
                    (
                        twin
                        for original, twin in zip(node.source.dfs(), source_copy.dfs())
                        if original is node.expression
                    ),
                    None,
                )
                highlighted = source_copy.transform(
                    lambda n: (
                        exp.Tag(this=n, prefix=open_mark, postfix=close_mark)
                        if n is target
                        else n
                    ),
                    copy=False,
                ).sql(pretty=True, dialect=dialect)
                body = (
                    html.escape(highlighted, quote=False)
                    .replace(open_mark, "<b>")
                    .replace(close_mark, "</b>")
                )
                title = f"<pre>{body}</pre>"
                group = 0

            node_id = id(node)

            nodes[node_id] = {
                "id": node_id,
                "label": label,
                "title": title,
                "group": group,
            }

            edges.extend({"from": node_id, "to": id(d)} for d in node.downstream)

        return GraphHTML(nodes, edges, **opts)
Node( name: str, expression: sqlglot.expressions.Expression, source: sqlglot.expressions.Expression, downstream: List[Node] = <factory>, source_name: str = '', reference_node_name: str = '')
name: str
downstream: List[Node]
source_name: str = ''
reference_node_name: str = ''
def walk(self) -> Iterator[Node]:
28    def walk(self) -> t.Iterator[Node]:
29        yield self
30
31        for d in self.downstream:
32            if isinstance(d, Node):
33                yield from d.walk()
34            else:
35                yield d
def to_html( self, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, **opts) -> GraphHTML:
37    def to_html(self, dialect: DialectType = None, **opts) -> GraphHTML:
38        nodes = {}
39        edges = []
40
41        for node in self.walk():
42            if isinstance(node.expression, exp.Table):
43                label = f"FROM {node.expression.this}"
44                title = f"<pre>SELECT {node.name} FROM {node.expression.this}</pre>"
45                group = 1
46            else:
47                label = node.expression.sql(pretty=True, dialect=dialect)
48                source = node.source.transform(
49                    lambda n: (
50                        exp.Tag(this=n, prefix="<b>", postfix="</b>") if n is node.expression else n
51                    ),
52                    copy=False,
53                ).sql(pretty=True, dialect=dialect)
54                title = f"<pre>{source}</pre>"
55                group = 0
56
57            node_id = id(node)
58
59            nodes[node_id] = {
60                "id": node_id,
61                "label": label,
62                "title": title,
63                "group": group,
64            }
65
66            for d in node.downstream:
67                edges.append({"from": node_id, "to": id(d)})
68        return GraphHTML(nodes, edges, **opts)
def lineage( column: str | sqlglot.expressions.Column, sql: str | sqlglot.expressions.Expression, schema: Union[Dict, sqlglot.schema.Schema, NoneType] = None, sources: Optional[Dict[str, str | sqlglot.expressions.Query]] = None, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None, **kwargs) -> Node:
 71def lineage(
 72    column: str | exp.Column,
 73    sql: str | exp.Expression,
 74    schema: t.Optional[t.Dict | Schema] = None,
 75    sources: t.Optional[t.Dict[str, str | exp.Query]] = None,
 76    dialect: DialectType = None,
 77    **kwargs,
 78) -> Node:
 79    """Build the lineage graph for a column of a SQL query.
 80
 81    Args:
 82        column: The column to build the lineage for.
 83        sql: The SQL string or expression.
 84        schema: The schema of tables.
 85        sources: A mapping of queries which will be used to continue building lineage.
 86        dialect: The dialect of input SQL.
 87        **kwargs: Qualification optimizer kwargs.
 88
 89    Returns:
 90        A lineage node.
 91    """
 92
 93    expression = maybe_parse(sql, dialect=dialect)
 94    column = normalize_identifiers.normalize_identifiers(column, dialect=dialect).name
 95
 96    if sources:
 97        expression = exp.expand(
 98            expression,
 99            {k: t.cast(exp.Query, maybe_parse(v, dialect=dialect)) for k, v in sources.items()},
100            dialect=dialect,
101        )
102
103    qualified = qualify.qualify(
104        expression,
105        dialect=dialect,
106        schema=schema,
107        **{"validate_qualify_columns": False, "identify": False, **kwargs},  # type: ignore
108    )
109
110    scope = build_scope(qualified)
111
112    if not scope:
113        raise SqlglotError("Cannot build lineage, sql must be SELECT")
114
115    if not any(select.alias_or_name == column for select in scope.expression.selects):
116        raise SqlglotError(f"Cannot find column '{column}' in query.")
117
118    return to_node(column, scope, dialect)

Build the lineage graph for a column of a SQL query.

Arguments:
  • column: The column to build the lineage for.
  • sql: The SQL string or expression.
  • schema: The schema of tables.
  • sources: A mapping of queries which will be used to continue building lineage.
  • dialect: The dialect of input SQL.
  • **kwargs: Qualification optimizer kwargs.
Returns:

A lineage node.

def to_node( column: str | int, scope: sqlglot.optimizer.scope.Scope, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType], scope_name: Optional[str] = None, upstream: Optional[Node] = None, source_name: Optional[str] = None, reference_node_name: Optional[str] = None) -> Node:
def to_node(
    column: str | int,
    scope: Scope,
    dialect: DialectType,
    scope_name: t.Optional[str] = None,
    upstream: t.Optional[Node] = None,
    source_name: t.Optional[str] = None,
    reference_node_name: t.Optional[str] = None,
) -> Node:
    """Build the lineage node for ``column`` in ``scope`` and recurse into its inputs.

    Args:
        column: Projection name to trace, or a projection index (UNION branches are
            traced by position).
        scope: The scope whose projections are searched.
        dialect: Dialect used to render SQL in warnings; passed down recursively.
        scope_name: Alias of ``scope`` in its parent; when set, it prefixes the node name.
        upstream: If given, the new node is appended to its ``downstream``.
        source_name: Recorded on the new node (empty string when ``None``).
        reference_node_name: Recorded on the new node (empty string when ``None``).

    Returns:
        The node built for ``column``. For a UNION scope this is the UNION node
        (``upstream`` itself when provided). For a subquery scope it is whatever the
        recursion into its first inner scope returns.

    Raises:
        ValueError: If ``column`` cannot be located among a UNION's projections.
    """
    # Derived tables carrying a "source: <name>" comment, keyed by alias. These
    # comments are presumably attached by ``exp.expand`` for ``lineage``'s ``sources``.
    source_names: t.Dict[str, str] = {}
    for derived in scope.derived_tables:
        if derived.comments and derived.comments[0].startswith("source: "):
            source_names[derived.alias] = derived.comments[0].split()[1]

    # Locate the projection being traced: by position or by name. If no name matches,
    # fall back to a generic `*` (star queries) or to the whole scope expression.
    projections = scope.expression.selects
    if isinstance(column, int):
        select = projections[column]
    else:
        fallback = exp.Star() if scope.expression.is_star else scope.expression
        select = next((p for p in projections if p.alias_or_name == column), fallback)

    if isinstance(scope.expression, exp.Subquery):
        # Only the first inner scope is followed, and ``scope_name`` is not forwarded.
        inner = next(iter(scope.subquery_scopes), None)
        if inner is not None:
            return to_node(
                column,
                scope=inner,
                dialect=dialect,
                upstream=upstream,
                source_name=source_name,
                reference_node_name=reference_node_name,
            )

    if isinstance(scope.expression, exp.Union):
        union_node = upstream or Node(name="UNION", source=scope.expression, expression=select)

        if isinstance(column, int):
            index = column
        else:
            # -1 is a "not found" sentinel (mypy rejects None as the default here).
            index = next(
                (
                    position
                    for position, projection in enumerate(scope.expression.selects)
                    if projection.alias_or_name == column or projection.is_star
                ),
                -1,
            )

        if index == -1:
            raise ValueError(f"Could not find {column} in {scope.expression}")

        # Every branch of the set operation contributes the projection at that position.
        for branch in scope.union_scopes:
            to_node(
                index,
                scope=branch,
                dialect=dialect,
                upstream=union_node,
                source_name=source_name,
                reference_node_name=reference_node_name,
            )

        return union_node

    # Nicer labels: show the SELECT narrowed down to just the traced projection,
    #   "x", SELECT x, y FROM foo  =>  "x", SELECT x FROM foo
    if isinstance(scope.expression, exp.Select):
        node_source = t.cast(exp.Expression, scope.expression.select(select, append=False))
    else:
        node_source = scope.expression

    node = Node(
        name=f"{scope_name}.{column}" if scope_name else str(column),
        source=node_source,
        expression=select,
        source_name=source_name or "",
        reference_node_name=reference_node_name or "",
    )
    if upstream:
        upstream.downstream.append(node)

    # Subqueries used inside the projection each contribute all of their projections.
    scopes_by_id = {
        id(inner_scope.expression): inner_scope for inner_scope in scope.subquery_scopes
    }
    for subquery in find_all_in_scope(select, exp.UNWRAPPED_QUERIES):
        subquery_scope = scopes_by_id.get(id(subquery))
        if subquery_scope:
            for name in subquery.named_selects:
                to_node(name, scope=subquery_scope, dialect=dialect, upstream=node)
        else:
            logger.warning(f"Unknown subquery scope: {subquery.sql(dialect=dialect)}")

    # With a `*` projection, every source of the scope feeds this node.
    # NOTE(review): when the projection is a star, the UDTF check below inspects the
    # *last* star source rather than ``node_source``. This matches the module's
    # historical behaviour; confirm it is intended.
    udtf_candidate = node_source
    if select.is_star:
        for star_source in scope.sources.values():
            if isinstance(star_source, Scope):
                star_source = star_source.expression
            node.downstream.append(
                Node(name=select.sql(), source=star_source, expression=star_source)
            )
            udtf_candidate = star_source

    # Every column referenced by the projection is an input to it.
    source_columns = set(find_all_in_scope(select, exp.Column))

    # A UDTF generates its rows from the columns it references, so those count too.
    if isinstance(udtf_candidate, exp.UDTF):
        source_columns |= set(udtf_candidate.find_all(exp.Column))

    for col in source_columns:
        table = col.table
        found = scope.sources.get(table)

        if isinstance(found, Scope):
            # The column comes from a nested scope: recurse with the unaliased name.
            selected_node, _ = scope.selected_sources.get(table, (None, None))
            to_node(
                col.name,
                scope=found,
                dialect=dialect,
                scope_name=table,
                upstream=node,
                source_name=source_names.get(table) or source_name,
                reference_node_name=selected_node.name if selected_node else None,
            )
        else:
            # End of the chain. An unknown source (e.g. a query missing from the
            # ``sources`` map) is represented by a placeholder.
            leaf = found or exp.Placeholder()
            node.downstream.append(Node(name=col.sql(), source=leaf, expression=leaf))

    return node
class GraphHTML:
class GraphHTML:
    """Node to HTML generator using vis.js.

    https://visjs.github.io/vis-network/docs/network/
    """

    def __init__(
        self, nodes: t.Dict, edges: t.List, imports: bool = True, options: t.Optional[t.Dict] = None
    ):
        """Store the graph data and build the vis.js network options.

        Args:
            nodes: Mapping of node id to vis.js node dict; only the values are rendered.
            edges: vis.js edge dicts, e.g. ``{"from": 1, "to": 2}``.
            imports: Whether to include the vis.js ``<script>``/``<link>`` tags.
            options: vis.js network options. Top-level keys replace the defaults
                wholesale (the merge is shallow).
        """
        self.imports = imports

        self.options = {
            "height": "500px",
            "width": "100%",
            "layout": {
                "hierarchical": {
                    "enabled": True,
                    "nodeSpacing": 200,
                    "sortMethod": "directed",
                },
            },
            "interaction": {
                "dragNodes": False,
                "selectable": False,
            },
            "physics": {
                "enabled": False,
            },
            "edges": {
                "arrows": "to",
            },
            "nodes": {
                "font": "20px monaco",
                "shape": "box",
                "widthConstraint": {
                    "maximum": 300,
                },
            },
            **(options or {}),
        }

        self.nodes = nodes
        self.edges = edges

    @staticmethod
    def _script_json(value: t.Any) -> str:
        """Serialize ``value`` as JSON that is safe to inline in a ``<script>`` element.

        ``json.dumps`` leaves ``<``, ``>`` and ``&`` as-is, so a string containing
        ``</script>`` (e.g. SQL text in a node title) would end the script element
        early. Their ``\\uXXXX`` escapes decode to the same JSON value.
        """
        return (
            json.dumps(value)
            .replace("<", "\\u003c")
            .replace(">", "\\u003e")
            .replace("&", "\\u0026")
        )

    def __str__(self):
        """Return an HTML snippet that draws the graph with vis.js."""
        import uuid

        nodes = self._script_json(list(self.nodes.values()))
        edges = self._script_json(self.edges)
        options = self._script_json(self.options)
        # A fresh container id per render. With a fixed id, a second graph on the
        # same page (or re-displaying this one) would be drawn into the first
        # container, since getElementById returns the first match.
        container_id = f"sqlglot-lineage-{uuid.uuid4().hex}"
        imports = (
            """<script type="text/javascript" src="https://unpkg.com/vis-data@latest/peer/umd/vis-data.min.js"></script>
  <script type="text/javascript" src="https://unpkg.com/vis-network@latest/peer/umd/vis-network.min.js"></script>
  <link rel="stylesheet" type="text/css" href="https://unpkg.com/vis-network/styles/vis-network.min.css" />"""
            if self.imports
            else ""
        )

        return f"""<div>
  <div id="{container_id}"></div>
  {imports}
  <script type="text/javascript">
    var nodes = new vis.DataSet({nodes})
    nodes.forEach(row => row["title"] = new DOMParser().parseFromString(row["title"], "text/html").body.childNodes[0])

    new vis.Network(
        document.getElementById("{container_id}"),
        {{
            nodes: nodes,
            edges: new vis.DataSet({edges})
        }},
        {options},
    )
  </script>
</div>"""

    def _repr_html_(self) -> str:
        """Jupyter rich-display hook; returns the same markup as ``str()``."""
        return self.__str__()

Node to HTML generator using vis.js.

https://visjs.github.io/vis-network/docs/network/

GraphHTML( nodes: Dict, edges: List, imports: bool = True, options: Optional[Dict] = None)
268    def __init__(
269        self, nodes: t.Dict, edges: t.List, imports: bool = True, options: t.Optional[t.Dict] = None
270    ):
271        self.imports = imports
272
273        self.options = {
274            "height": "500px",
275            "width": "100%",
276            "layout": {
277                "hierarchical": {
278                    "enabled": True,
279                    "nodeSpacing": 200,
280                    "sortMethod": "directed",
281                },
282            },
283            "interaction": {
284                "dragNodes": False,
285                "selectable": False,
286            },
287            "physics": {
288                "enabled": False,
289            },
290            "edges": {
291                "arrows": "to",
292            },
293            "nodes": {
294                "font": "20px monaco",
295                "shape": "box",
296                "widthConstraint": {
297                    "maximum": 300,
298                },
299            },
300            **(options or {}),
301        }
302
303        self.nodes = nodes
304        self.edges = edges
imports
options
nodes
edges