
API Reference

Graphical Causal Models (GCM)

causalinf.gcm.DAG

graph: (str, dict, list) A string describing the graph, or a list or dictionary with the edges

If a string, it can combine several edge formats
Example:
'''
X1 -> Y
X1 -> Z -> Y
X1 <- X2
# This is equivalent to directed edges: X3 -> A and X3 -> B
# and X4 -> C and X4 -> D
X3 -> {A, B}
{C, D} <- X4
# Bidirected edges (this is a comment, which is also allowed)
X3 <-> X4
# Undirected edges 
X3 -- X4
# Mixing
X5 -- X6 -> X7
'''
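The string notation above can be read line by line: strip comments, split on the edge operators, and expand brace groups. The following is a minimal sketch of that reading, not the parser used by causalinf; the names `parse_edges`, `_group`, and `EDGE_OP` are hypothetical and exist only for illustration.

```python
import re

# Hypothetical helper, NOT part of causalinf: a simplified reading of the
# edge syntax shown above. Longest operator first so '<->' wins over '<-'.
EDGE_OP = re.compile(r"\s*(<->|->|<-|--)\s*")


def _group(token):
    """Expand '{A, B}' into ['A', 'B']; a bare name becomes ['name']."""
    token = token.strip()
    if token.startswith("{") and token.endswith("}"):
        return [t.strip() for t in token[1:-1].split(",") if t.strip()]
    return [token]


def parse_edges(text):
    """Return the edges found in `text`, grouped by edge type."""
    edges = {"directed": [], "bidirected": [], "undirected": []}
    for line in text.splitlines():
        line = line.split("#", 1)[0].strip()  # drop comments and blank lines
        if not line:
            continue
        parts = EDGE_OP.split(line)  # [node, op, node, op, node, ...]
        for i in range(1, len(parts) - 1, 2):
            left, op, right = _group(parts[i - 1]), parts[i], _group(parts[i + 1])
            for a in left:
                for b in right:
                    if op == "->":
                        edges["directed"].append((a, b))
                    elif op == "<-":
                        edges["directed"].append((b, a))
                    elif op == "<->":
                        edges["bidirected"].append(((a, b), (b, a)))
                    else:  # "--"
                        edges["undirected"].append({a, b})
    return edges


example = """
X1 -> Y
X1 -> Z -> Y
X1 <- X2
X3 -> {A, B}
{C, D} <- X4
# Bidirected edge
X3 <-> X4
# Mixing
X5 -- X6 -> X7
"""
result = parse_edges(example)
```

In this sketch a chained line such as `X5 -- X6 -> X7` yields one edge per operator, and a line containing only a node name produces no edge. Whether the library treats isolated nodes the same way is not stated in this page.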

 If list, the edge types will be parsed based on their format:
 [
     ('X', 'Y'),                    # becomes X -> Y  (directed edge)
     {'X', 'Z'},                    # becomes X -- Z  (undirected edge)
     (('X1', 'X2'), ('X2', 'X1')),  # becomes X1 <-> X2 (bidirected edge)
 ]
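Since the list form encodes the edge type in the Python shape of each element, the classification can be sketched as a small type check. This is an illustrative assumption about how such a check could look, not the library's internal `__edge_type__`; `edge_type` and `split_by_type` are hypothetical names.

```python
def edge_type(edge):
    """Classify an edge by its Python shape (hypothetical helper)."""
    # A two-element set is an undirected edge: {'X', 'Z'}
    if isinstance(edge, (set, frozenset)) and len(edge) == 2:
        return "undirected"
    if isinstance(edge, tuple) and len(edge) == 2:
        # A tuple of two node names is a directed edge: ('X', 'Y')
        if all(isinstance(n, str) for n in edge):
            return "directed"
        # A tuple of two mirrored tuples is a bidirected edge
        if all(isinstance(n, tuple) and len(n) == 2 for n in edge):
            if edge[0] == edge[1][::-1]:
                return "bidirected"
    raise ValueError(f"Unrecognized edge format: {edge!r}")


def split_by_type(edges):
    """Group a mixed edge list into the three edge types."""
    out = {"directed": [], "undirected": [], "bidirected": []}
    for e in edges:
        out[edge_type(e)].append(e)
    return out


graph = [
    ("X", "Y"),
    {"X", "Z"},
    (("X1", "X2"), ("X2", "X1")),
]
grouped = split_by_type(graph)
```

The sketch rejects anything that does not match one of the three shapes, so a malformed entry fails early instead of being silently dropped.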

If dictionary, it must contain the edge types (directed, undirected,
bidirected) as keys and lists of edges as values
Example:
{'directed'  : [('X', 'Y'), ...],  # list of tuples
 'undirected': [{'X1', 'X2'}, ...], # list of sets
 'bidirected': [ (('X1', 'X2'), ('X2', 'X1')), ...] # list of 2-tuple tuples
 }
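The dictionary form carries the same information as the list form, with the edge type spelled out as the key. A minimal sketch of validating the keys and flattening the dictionary into a single edge list follows; `ALLOWED_TYPES` and `dict_to_edge_list` are hypothetical names used only for illustration and are not part of causalinf.

```python
ALLOWED_TYPES = ("directed", "undirected", "bidirected")


def dict_to_edge_list(graph):
    """Flatten {'edge type': [edges, ...]} into one list of edges."""
    unknown = set(graph) - set(ALLOWED_TYPES)
    if unknown:
        raise KeyError(f"Unknown edge types: {sorted(unknown)}")
    edges = []
    for kind in ALLOWED_TYPES:  # fixed order keeps the output predictable
        edges.extend(graph.get(kind, []))
    return edges


graph = {
    "directed": [("X", "Y")],
    "undirected": [{"X1", "X2"}],
    "bidirected": [(("X3", "X4"), ("X4", "X3"))],
}
flat = dict_to_edge_list(graph)
```

Missing keys are treated as empty lists in this sketch, so a dictionary with only `'directed'` edges is accepted.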

SEM: a string with the structural equation model (SEM). Parameter and path-effect definitions are allowed in the SEM string. If parameters are provided, they are used as 'edge_labels', unless the latter is also provided. See examples below.

Source code in causalinf/gcm.py
class DAG:
    """
    graph: (str, dict, list)
        A string describing the graph, or a list or dictionary with the edges

        If a string, it can combine several edge formats
        Example:
        '''
        X1 -> Y
        X1 -> Z -> Y
        X1 <- X2
        # This is equivalent to directed edges: X3 -> A and X3 -> B
        # and X4 -> C and X4 -> D
        X3 -> {A, B}
        {C, D} <- X4
        # Bidirected edges (this is a comment, which is also allowed)
        X3 <-> X4
        # Undirected edges 
        X3 -- X4
        # Mixing
        X5 -- X6 -> X7
        '''

         If list, the edge types will be parsed based on their format:
         [
             ('X', 'Y'),                    # becomes X -> Y  (directed edge)
             {'X', 'Z'},                    # becomes X -- Z  (undirected edge)
             (('X1', 'X2'), ('X2', 'X1')),  # becomes X1 <-> X2 (bidirected edge)
         ]

        If dictionary, it must contain the edge types (directed, undirected,
        bidirected) as keys and lists of edges as values
        Example:
        {'directed'  : [('X', 'Y'), ...],  # list of tuples
         'undirected': [{'X1', 'X2'}, ...], # list of sets
         'bidirected': [ (('X1', 'X2'), ('X2', 'X1')), ...] # list of 2-tuple tuples
         }

    SEM: a string with the structural equation model (SEM).
         Parameter and path-effect definitions are allowed in the
         SEM string. If parameters are provided,
         they are used as 'edge_labels', unless the latter is also
         provided.
         See examples below.
    """

    def __init__(self,
                 graph=None,
                 data=None,
                 # nodes
                 nodes_role=None,
                 nodes_label=None,
                 nodes_position=None,
                 # edges
                 edge_label=None
                 ):
        assert graph, "'graph' must be provided."
        assert nodes_position is None or isinstance(nodes_position, dict), (
            "nodes_position must be None or dict")
        assert nodes_label is None or isinstance(nodes_label, dict), (
            "nodes_label must be None or dict")
        assert nodes_role is None or isinstance(nodes_role, dict), (
            "nodes_role must be None or dict")

        # accept user-provided role keys written in lowercase
        key_roles = ['Outcome', 'Exposure', "Latent"]
        if nodes_role:
            for role in key_roles:
                if role.lower() in nodes_role.keys():
                    nodes_role[role] = nodes_role[role.lower()]
                    nodes_role.pop(role.lower())


        # graph
        self.__graph_list__ = []
        self.__graph_dict__ = {}
        self.__graph_str_original__ = None
        self.__graph_str_parsed__ = None
        self.__dagitty__ = None
        # edges 
        self.__edges_str_allowed__ = ['->', '<-', '<->', "--"]
        self.edge_label = edge_label or {}
        self.directed = []
        self.bidirected = []
        self.undirected = []
        # nodes 
        self.nodes = set()
        self.nodes_parents = {}
        self.exposure = []
        self.outcome = []
        self.latent = []
        self.observed = []
        self.nodes_role = {}
        self.nodes_position = {}
        self.nodes_label = {}
        self.nodes_info = {}
        # keep this order:
        self.__build_graph__(graph)
        self.__collect_info__(nodes_role, nodes_position, nodes_label)
        # dagitty
        self.__create_dagitty__()
        # others
        self.data = data
        self.__identification__ = None

    # manipulating graph  -----------------------------
    def get_nodes(self, exclude_latent=False):
        nodes = list(self.nodes)
        latent_nodes = self.latent

        if exclude_latent and latent_nodes:
            nodes = [n for n in nodes if n not in latent_nodes]
        return nodes

    def set_node_label(self, nodes_label):
        for node, label in nodes_label.items():
            self.nodes_label[node] = label

    def set_nodes_role(self, nodes_role):
        res = DAG(graph=self.__graph_str_parsed__,
                  nodes_role=nodes_role,
                  nodes_label=self.nodes_label,
                  nodes_position=self.nodes_position,
                  edge_label=self.edge_label,
                  data=self.data)
        return res

    def set_node_position(self, position):
        # positions are stored in `nodes_position` (see __init__)
        for node, p in position.items():
            self.nodes_position[node] = p

    def edge_add(self, edge):
        res = self
        if not self.edge_exist(edge):
            graph = self.__graph_list__.copy()
            graph.append(edge)
            res = self.__rebuild_graph__(graph)
        return res

    def edge_remove(self, edge):
        removed = False
        graph = self.__graph_list__.copy()

        if edge in self.__graph_list__:
            graph.remove(edge)
            removed = True
        elif self.__edge_type__(edge)=='bidirected':
            edge = (edge[1], edge[0])
            if edge in self.__graph_list__:
                graph.remove(edge)
                removed = True

        if removed:
            return self.__rebuild_graph__(graph)
        else:
            return self

    def edge_replace(self, remove, add):
        res = self.edge_remove(remove)
        res = res.edge_add(add)
        return res

    def edge_exist(self, edge, edges=None):
        """
        Check whether `edge` exists in `edges`,
        robust to order of nodes for undirected and bidirected edges.

        Parameters
        ----------
        edge : tuple or set
            A tuple representing the edge to check, e.g., (node1, node2).
            For directed edge: tuple is (from, to)
            For bidirected edge: tuple is ((node1, node2), (node2, node1))
            For undirected edge: set is {node1, node2}

        edges : list or None, optional
            A list of edges to check against. If None, the method will retrieve 
            the edges associated with the edge type.

        Returns
        -------
        bool
            True if the edge exists in the edges list, False otherwise.
        """
        if edges is None:
            edge_type = self.__edge_type__(edge)
            edges = self.__getattribute__(edge_type)
        edges = [edges] if not isinstance(edges, list) else edges
        edge = self.__edge_frozen_format__(edge)
        edges_in_list = {self.__edge_frozen_format__(e) for e in edges}
        return edge in edges_in_list

    def set_edge_label(self, edge_label):
        for edge, label in edge_label.items():
            self.edge_label[edge] = label

    # computations --------------------------------------
    # dagitty (R dependencies)
    def dseparated(self, var1=None, var2=None, conditional=None):
        assert var1 and isinstance(var1, str), "'var1' (a str) must be provided."
        assert var2 and isinstance(var2, str), "'var2' (a str) must be provided."

        if conditional is None:
            conditional = NULL
        res = dagitty.dseparated(self.__dagitty__, X = var1, Y = var2, Z=conditional)[0]
        return res

    # dagitty (R dependencies)
    def dseparation(self, var1, var2):
        assert var1 and isinstance(var1, str), "'var1' (a str) must be provided."
        assert var2 and isinstance(var2, str), "'var2' (a str) must be provided."

        res = self.local_independencies()
        if res.nrow>0:
            res = (
                res
                .separate('term', into=['var1', 'var2|conditional'], sep='_||_', remove=False)
                .separate('var2|conditional', into=['var2', 'conditional'], sep=' | ', remove=True)
                .mutate(var1 = tp.str_trim('var1'),
                        var2 = tp.str_trim('var2'),
                        conditional = tp.str_trim('conditional'),
                        )
                .replace_null({'conditional':''})
                .filter(((tp.col("var1")==var1) & (tp.col('var2')==var2)) |
                        ((tp.col("var2")==var1) & (tp.col('var1')==var2))
                        )
            )
            res = res.pull('conditional')
            res = [s.split(',') for s in res]
            res = [[string.strip() for string in inner_list] for inner_list in res]
        else:
            print(f'Not possible to d-separate {var1} and {var2} in the graph.')
            res = None
        return res

    # dagitty (R dependencies)
    def local_independencies(self, data=None, alpha=0.05, include_sep_cols=False):
        """
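        Return the conditional independencies implied by the graph, using dagitty (via R).

        Parameters:
            data: tibble data frame from tidypolars4sci. If None, `self.data` is used.
                  When data are available, each implied independence is tested
                  against them.
            alpha: level used to recover the standard error from the
                   confidence interval bounds
            include_sep_cols: if True, also return the columns 'var1', 'var2', and 'cond'

        Returns:
            tibble dataframe from tidypolars4sci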
        """
        if data is None:
            data = self.data
        # compute
        if data is None:
            inds = dagitty.impliedConditionalIndependencies(self.__dagitty__)
            res = tp.tibble()
            for ind in inds:
                y = ind[0][0]
                x = ind[1][0]
                z = ind[2]
                term = f"{y} _||_ {x}"
                term = f"{term} | {', '.join(z)}" if z else term
                tmp = tp.tibble({'term': [term],
                                 "var1": [y],
                                 "var2": [x],
                                 "cond": [z]})
                res = res.bind_rows(tmp)
            inds = res
        else:
            inds = dagitty.localTests(self.__dagitty__, data=convert().tp2tibble(data), abbreviate_names=False)
            z = dnorm.ppf(1-alpha/2)
            inds = convert().rtibble2tp(inds, rownames2col='term')\
                         .rename({'p.value':"pvalue",
                                  '2.5%':'lo',
                                  '97.5%':'hi',
                                  })\
                         .mutate(se = ( tp.col('hi')-tp.col('lo') ) / (2*z) )
            if inds.nrow>0:
                inds = (
                    inds
                    .separate('term', into=['var1', 'var2_cond'], sep='_||_', remove=False)
                    .separate('var2_cond', into=['var2', 'cond'], sep='|')
                )

        vars = ['term', 'estimate', 'se', 'lo', 'hi', 'pvalue']
        if include_sep_cols:
            vars += ['var1', 'var2', 'cond']
        inds = inds.select(vars)

        return inds

    # dagitty (R dependencies)
    def identification_analysis(self, exposure=None, outcome=None,
                                conditional = None,
                                causal_probability='maybe',
                                iv='maybe',
                                verbose=True
                                ):
        """
        causal_probability: str
            If 'always', always compute it; if 'maybe', compute it
            only if the total effect is not identified by adjustment

        conditional: list or str
            List of variables to condition the causal effect on
        """
        assert not outcome or isinstance(outcome, str), 'Outcome must be a string.'
        assert not exposure or (isinstance(exposure, str) or isinstance(exposure, list)), 'Exposure must be a string or list.'

        assert outcome or self.outcome, "No outcome found."
        assert exposure or self.exposure, "No exposure found."

        exposure = exposure or self.exposure
        outcome = outcome or self.outcome[0]
        conditional = [conditional] if isinstance(conditional, str) else conditional

        assert exposure is not None, "Exposure must be provided."
        assert outcome is not None, "Outcome must be provided."

        self.__identification__ = identification(G=self,
                                                 exposure = exposure,
                                                 outcome = outcome,
                                                 conditional = conditional,
                                                 causal_probability = causal_probability,
                                                 iv = iv,
                                                 verbose=verbose)
        if verbose:
            self.print('identification')

        return None

    def get_identified(self, by='parameter', include_all=False):
        """
        by : str
           'parameter' or 'strategy'
        """
        if not self.__identification__:
            self.identification_analysis()
        res = self.__identification__.get_identified(by=by, include_all=include_all)
        return res

    def identification(self, print='default', parameter='ACE', *args, **kws):
        if not self.__identification__:
            self.identification_analysis(verbose=False)

        identification = kws.get("identification", {})
        identification["content"] = print
        identification["parameter"] = parameter

        self.print('identification', identification=identification)
        return None

    @property
    def identification_dict(self):
        if not self.__identification__:
            self.identification_analysis()
        res = self.__identification__.identification
        return res

    def print(self,
              what = 'graph',
              identification = dict(
                  content='default',
                  style='text',
                  strategy = 'all',
                  parameter = 'ACE',
                  omit_DAG=True,
                  print_assumptions=None,
                  print_assumptions_verbose=None
              )
              ):
        """
        what : str
            What to print
            'graph', 'DAG', 'dag', 'identification'

        identification : dict
            Options to print identification results
             - content
             - style
             - strategy
             - parameter
             - omit_DAG
             - print_assumptions
             - print_assumptions_verbose

        """
        if what in ['graph', 'DAG', 'dag']:
            print(self)
        if what=='identification':
            ops = identification.copy()
            # defaults
            pars = ["print_assumptions", "print_assumptions_verbose"]
            for par in pars:
                if ops.get(par, None) is None:
                    ops[par] = get_options()[par]

            if not self.__identification__:
                self.identification_analysis()
            self.__identification__.print(**identification)
            self.__identification__.__assumptions_print__(category='identification', **ops)
        return None

    # dagitty (R dependencies)
    def paths(self, exposure=None, outcome=None, adj_set=None, directed=False):
        exposure = exposure or self.exposure
        outcome = outcome or self.outcome

        assert exposure, "Exposure must be provided."
        assert outcome, "Outcome must be provided."

        adj = adj_set or NULL
        paths_info = dagitty.paths(self.__dagitty__, exposure, to=outcome, Z=adj, directed=directed)
        paths = list(paths_info.rx2['paths'])
        are_open = list(paths_info.rx2['open'])

        return {path:{'open':is_open, 'adj_set':adj_set} for path, is_open in zip(paths, are_open)}

    def mediators(self, as_string=False):
        paths = self.paths(directed=True)
        paths = [p.split('->') for p in paths]
        exposure = self.exposure
        outcome = self.outcome
        res = []
        for path in paths:
            res += [[var.strip() for var in path if var.strip() not in  exposure + outcome]]
        res = [l for l in res if len(l)>0]

        if as_string:
            # avoid nested same-quote f-strings, which fail to parse before Python 3.12
            res = "[" + ", ".join("[" + ", ".join(l) + "]" for l in res) + "]"
        return res

    # dagitty (R dependencies)
    def equivalence_class(self):
        """
        Details
        -------
        An equivalence class of a DAG is a graph that replaces directed edges
        with undirected edges except in v-structures (triples X->Z<-Y where
        X and Y are not adjacent). Therefore, all Markov
        equivalent DAGs will have the same equivalence class.
        """
        eq = dagitty.equivalenceClass(self.__dagitty__)
        dag, _ = self.__dagitty2inputs__(eq)
        res = self.__rebuild_graph__(dag)
        return res

    # dagitty (R dependencies)
    def equivalent_dags(self):
        eqs = dagitty.equivalentDAGs(self.__dagitty__)
        res = []
        for eq in eqs:
            dag, _ = self.__dagitty2inputs__(eq)
            res += [self.__rebuild_graph__(dag)]
        return res

    def observationally_equivalent(self, G):
        """
        Check whether two DAGs are observationally equivalent by comparing their
        Markov equivalence classes. This applies to CBNs, or to SCMs when no
        functional form has been selected for the SCM equations.
        See details.

        Details
        -------
        Observational equivalence is related to Markov equivalence.

        Two DAGs are Markov equivalent iff
        A. They have the same skeleton (same set of adjacencies, i.e. same undirected edges)
        B. They have the same set of v-structures (triples X->Z<-Y where
           X and Y are not adjacent).

        An equivalence class of a DAG is a graph that replaces directed edges
        with undirected edges except in v-structures. Therefore, all Markov
        equivalent DAGs will have the same equivalence class.

        For CBN:
        - Two CBNs are observationally equivalent iff they are Markov equivalent.

        For SCM:
        Without functional form assumptions, for observational equivalence:
        - Necessary condition: both SCMs have the same set of conditional independencies
        - Sufficient condition: both SCMs are in the same Markov equivalence class (Pearl, 2009)
        - In short, two SCMs are observationally equivalent iff their causal graphs belong
          to the same Markov equivalence class, i.e., they share the same skeleton and v-structures.

        With functional form assumptions:
        - Once you impose functional form restrictions on SCMs, such as linearity,
          Gaussian disturbances, or additive errors, observational equivalence
          can be strictly finer. That is, Markov equivalence is no longer a sufficient condition.
          Examples:
          a. Linear Gaussian SEMs
             - All DAGs in the same equivalence class remain indistinguishable.
               Markov equivalence = observational equivalence.
               Reason: any covariance matrix that one DAG can generate can also
               be generated by another DAG in its equivalence class, via a suitable parameter choice.

          b. Linear non-Gaussian models (LiNGAM)
             - Orientations become testable because independent non-Gaussian noise
               'pins down' which variable must be the parent, breaking Markov equivalence.
               Example: X -> Y and X <- Y. In the Gaussian case: indistinguishable.
                        In the non-Gaussian case: identifiable.

          c. Additive Noise Models (ANMs)
             - If the true relation is Y = f(X) + e with independent noise e,
               then typically the 'wrong' orientation X = g(Y) + e'
               cannot hold with independent noise. So the direction becomes identifiable.

        In summary, for general SCMs (no distributional restrictions), Markov equivalence
        implies observational equivalence. But once you impose restrictions
        (linear, Gaussian, additive, etc.), observational equivalence can be strictly finer:
        under assumptions about functional forms or noise properties, some DAGs inside
        a Markov equivalence class become distinguishable. The test of equivalence then
        depends on the functional form assumption adopted, so it is case-by-case.

        References
        ----------
        - Pearl, J. (2009). Causality: Models, Reasoning and Inference. Cambridge University Press.
        """
        # check if same equivalence class
        G1_eq = self.equivalence_class()
        G2_eq = G.equivalence_class()
        diff = G1_eq.edge_differences(G2_eq)
        obs_eq = True
        for g, edges in diff.items():
            obs_eq &= all([len(e)==0 for e in edges.values()])
        return obs_eq 

    def assumptions(self, category=None, verbose=False):
        if not self.__identification__:
            self.identification_analysis()
        return self.__identification__.assumptions(category=category, verbose=verbose)
    # -------------------------------------------------

    # plots -------------------------------------------
    def plot(self,
             # nodes
             graph_style = None,
             nodes_label=None,
             nodes_position=None,
             estimates=None,
             # node
             node_subset=None,
             node_shape=None,
             node_size = None,
             node_color = None,
             node_border_color=None,
             node_border_style=None,
             node_border_width=None,
             node_latent_show=True,
             # node label
             show_labels = True,
             use_labels = True,
             node_label_box=True,
             node_label_fontsize=None,
             node_label_fontweight='normal',
             node_label_adj_x=0,
             node_label_adj_y=0,
             node_label_box_style="square",
             node_label_box_margin=.5,
             # edges
             edge_subset=None,
             edge_color=None,
             edge_style=None,
             edge_arc = None,
             edge_linewidth = None,
             edge_head_size = None,
             edge_head_style = None,
             edge_margin_tail=None,
             edge_margin_head=None,
             # edges labels
             edge_label=None,
             edge_label_color_background='white',
             edge_label_color_border='white',
             edge_label_size=None,
             edge_label_color=None,
             edge_label_alpha=None,
             edge_label_rotate=None,
             edge_label_position=None,
             edge_label_sig_level=0.05,
             edge_label_pvalue=None,
             edge_label_font_family = None,
             # legend
             legend_show=True,
             legend_title='Nodes',
             legend_title_align='left',
             legend_title_weight='bold',
             legend_title_size=12,
             legend_omit_cases=['Observed'],
             legend_keys=None,
             legend_loc='best',
             legend_fontsize=10,
             legend_frame=False,
             legend_kws={},
             #
             title = None,
             title_loc = 'left',
             title_kws = {},
             # 
             figsize = [6, 4],
             usetex = True,
             ax=None,
             show_plot=None,
             *args,
             **kws
             ):
        """
        Draw a custom DAG with support for:
          - Latent variables
          - Curved edges
          - Colored and dotted arcs
          - Optional arc representation for latent confounding
          - Custom node labels

        Parameters:
            G (nx.DiGraph): The input DAG with optional edge attributes 'style', 'color', 'curved'.
            estimates: obj
                An LSEM object of the class causalinf.scm.estimate
            nodes_position (dict): Optional node positions for layout.
            nodes_role (dict): Optional dict with keys 'latent', 'exposure', 'outcome' listing node names.
            use_arc (bool): If True, draw dotted arcs between children of latent confounders instead of drawing latent nodes.
            nodes_label (dict): Optional dict mapping node names to display labels.
            show_labels (bool; default=True): If True, draw the node labels; if False, omit
                labels and names of nodes altogether.
            use_labels (bool; default=True): If True, use the labels in 'nodes_label' when provided;
                if False, always use the node name.
            node_label_adj_x (float or dict): displaces the labels in the x direction. If dict, the keys
                should be the node labels or names, and the displacement is applied only to the nodes
                specified in the dict. If float, the same displacement is applied to all nodes.
            node_label_adj_y (float or dict): same as node_label_adj_x, but for the y axis.
            graph_style (str): specific styles for nodes and arrows
                  - 'default': nodes in circles with labels in their middle
                  - 'rectangle': nodes in rectangles with labels in their middle
                  - 'pearl': nodes as dots with labels next to them (use node_label_adj_x and node_label_adj_y
                             to adjust the location of the labels)
                  All features can be overridden by specifying the value of the corresponding plot parameters.
        """
        assert estimates is None or isinstance(estimates, estimate), (
            "'estimates' must be either None or an object of causalinf.scm.estimate ")

        default_usetex = plt.rcParams["text.usetex"] 
        plt.rcParams["text.usetex"] = usetex
        plt.rcParams['text.latex.preamble'] = r'\usepackage{amsmath, amssymb, siunitx, bm}'
        # fall back to the global option only when the caller did not set show_plot
        show_plot = show_plot if show_plot is not None else get_options('show_plot')

        # collect arguments
        pars = dict(locals())      # {'node_position':..., 'arg2':..., 'args':(...), 'kws':{...}}
        args = pars.pop('args') # extra positional
        kws  = pars.pop('kws')  # extra keyword

        # use estimates as labels
        if estimates is not None:
            edge_label, edge_label_pvalue = self.__plot_collect_labels_estimate__(estimates)

        # figure 
        # ------
        G_draw = self.__plot_create_nx__()
        if ax is None:
            fig, ax = plt.subplots(figsize=figsize, tight_layout=True)
        plt.sca(ax)

        # styles
        # ------
        graph_style = graph_style or get_options('graph_style')
        nodes_style, labels_style, edges_style = self.__plot_get_style__(graph_style)

        # nodes 
        # -----
        node_subset    = self.__plot_nodes_subset__(node_subset, node_latent_show)
        nodes_position = self.__plot_nodes_positions__(G_draw, nodes_position)
        for role, nodes in node_subset.items():
            fig_nodes = nx.draw_networkx_nodes(
                G_draw,
                nodes_position,
                nodelist=nodes,
                ax=ax,
                # 
                node_size  = self.__plot_collect_aes__(role, node_size, nodes_style[role]['node_size']),
                node_color = self.__plot_collect_aes__(role, node_color, nodes_style[role]['node_color']),
                node_shape = self.__plot_collect_aes__(role, node_shape, nodes_style[role]['node_shape']),
                linewidths = self.__plot_collect_aes__(role, node_border_width,
                                                       nodes_style[role]['node_border_width']),
                edgecolors = self.__plot_collect_aes__(role, node_border_color,
                                                       nodes_style[role]['node_border_color']),
                alpha      = None,
                cmap       = None,
                vmin       = None,
                vmax       = None,
                label      = None,
                margins    = None, 
                hide_ticks = True
            )
            fig_nodes.set_linestyle(self.__plot_collect_aes__(role, node_border_style,
                                                              nodes_style[role]['node_border_style']))

        # nodes labels 
        # ------------
        if show_labels:
            nodes = set(itertools.chain.from_iterable(node_subset.values()))
            nodes_label = self.nodes_label | (nodes_label or {})
            adj_x = self.__plot_label_adj__(node_label_adj_x, nodes_label)
            adj_y = self.__plot_label_adj__(node_label_adj_y, nodes_label)
            for node in nodes:
                label = nodes_label.get(node, node) if use_labels else node
                role  = self.nodes_info[node]['role']
                x, y  = nodes_position[node] if nodes_position and nodes_position[node] else\
                    self.nodes_info[node]['position'] 

                bbox = None
                if node_label_box and graph_style=='rectangle':
                    bbox = {
                        "boxstyle": f"{node_label_box_style},pad={node_label_box_margin}",
                        "fc": self.__plot_collect_aes__(role, node_color, nodes_style[role]['node_color']),
                        "ec": self.__plot_collect_aes__(role, node_border_color,
                                                        nodes_style[role]['node_border_color']),
                        "lw": self.__plot_collect_aes__(role, node_border_width,
                                                       nodes_style[role]['node_border_width']),
                        "linestyle": self.__plot_collect_aes__(role, node_border_style,
                                                               nodes_style[role]['node_border_style']),
                        "alpha": 1}

                weight = self.__plot_collect_aes__(role, node_label_fontweight,
                                                   labels_style[role]['node_label_fontweight'])
                label = f"\\textbf{{{label}}}" if weight == 'bold' else label
                plt.text(x + adj_x[node],
                         y + adj_y[node],
                         label,
                         fontweight = weight,
                         fontsize   = self.__plot_collect_aes__(role, node_label_fontsize,
                                                                labels_style[role]['node_label_fontsize']),
                         ha = 'center',
                         va = 'center',
                         bbox = bbox)

        # edges and edges labels
        # ----------------------
        nodes = set(itertools.chain.from_iterable(node_subset.values()))
        for edge_type in ['directed', 'bidirected', 'undirected']:
            style = self.__plot_collect_aes__(edge_type, edge_style, edges_style['edge_style'][edge_type])
            color = self.__plot_collect_aes__(edge_type, edge_color, edges_style['edge_color'][edge_type])
            arc   = self.__plot_collect_aes__(edge_type, edge_arc, edges_style['edge_arc'][edge_type])
            width = self.__plot_collect_aes__(edge_type, edge_linewidth, edges_style['edge_linewidth'][edge_type])
            arrow_head_size = self.__plot_collect_aes__(edge_type, edge_head_size, edges_style['edge_head_size'][edge_type])
            arrow_head_style = self.__plot_collect_aes__(edge_type, edge_head_style, edges_style['edge_head_style'][edge_type])
            edge_margin_tail = self.__plot_edge_margin__(edge_margin_tail, edges_style["edge_margin_tail"][edge_type])
            edge_margin_head = self.__plot_edge_margin__(edge_margin_head, edges_style["edge_margin_head"][edge_type])

            for edge in getattr(self, edge_type):
                edge = tuple(edge)
                if edge_type!='bidirected':
                    u, v = edge
                else:
                    u, v = edge[0][0], edge[0][1]

                # collect edges to show if edge_subset 
                show_edge = True
                if edge_subset:
                    e = set(edge) if edge_type=='undirected' else edge
                    show_edge = self.edge_exist(e, edge_subset.get(edge_type, []))

                if u in nodes and v in nodes and show_edge:
                    # edge
                    nx.draw_networkx_edges(
                        G_draw,
                        nodes_position,
                        edgelist            = [(u, v)],
                        style               = style,
                        edge_color          = color,
                        connectionstyle     = f"arc3,rad={arc}",
                        arrows              = True,
                        arrowstyle          = arrow_head_style,
                        arrowsize           = arrow_head_size,
                        min_source_margin   = edge_margin_tail.get(edge, 0),
                        min_target_margin   = edge_margin_head.get(edge, 0),
                        width               = width,
                        ax=ax)

                    # edge label
                    edge_label = edge_label or self.edge_label
                    label = edge_label.get(edge, '')
                    rotate = edge_label_rotate if edge_label_rotate is not None else True # must keep "is not None" here
                    nx.draw_networkx_edge_labels(
                        G_draw,
                        pos             = nodes_position,
                        connectionstyle = f"arc3,rad={arc}",
                        edge_labels     = {(u, v): label},
                        bbox=dict(facecolor=edge_label_color_background, edgecolor=edge_label_color_border),
                        # 
                        alpha      = self.__plot_edge_label_feature__('alpha', edge, edge_label_alpha, None, edge_label_sig_level,
                                                                      edge_label_pvalue=edge_label_pvalue),
                        font_size  = self.__plot_edge_label_feature__('size' , edge, edge_label_size, 15),
                        font_color = self.__plot_edge_label_feature__('color', edge, edge_label_color, label=label),
                        rotate     = self.__plot_edge_label_feature__('rotate', edge, edge_label_rotate, default=rotate),
                        label_pos  = self.__plot_edge_label_feature__('position', edge, edge_label_position, .5),
                        font_family=edge_label_font_family,
                        ax         = ax
                    )

        # legend 
        # ------
        if legend_show:
            keys = []
            for role, _ in node_subset.items():
                if role not in legend_omit_cases:
                    if role=='Latent' and node_latent_show:
                        marker = ''
                        linecolor = self.__plot_collect_aes__(role, node_border_color,
                                                              nodes_style[role]['node_border_color']) 
                    else:
                        marker = 'o'
                        linecolor='white'
                    keys += [
                        Line2D(
                            [0], [0],
                            marker=marker,
                            color=linecolor,
                            label=role,
                            markersize=10,
                            markeredgecolor=self.__plot_collect_aes__(role, node_border_color,
                                                                      nodes_style[role]['node_border_color']),
                            markerfacecolor=self.__plot_collect_aes__(role, node_color,
                                                                      nodes_style[role]['node_color']),
                            linestyle=self.__plot_collect_aes__(role, node_border_style,
                                                                nodes_style[role]['node_border_style'])
                        )
                    ]
            # build the legend once, after all handles have been collected
            if keys:
                legend = plt.legend(handles        = keys,
                                    title          = legend_title,
                                    title_fontsize = legend_title_size,
                                    alignment      = legend_title_align,
                                    # title_weight   = legend_title_weight,
                                    loc            = legend_loc,
                                    fontsize       = legend_fontsize,
                                    frameon        = legend_frame,
                                    **legend_kws
                                    )
                if legend_title_weight=='bold' and legend_title:
                    legend.set_title(title=f'\\textbf{{{legend_title}}}', prop={'weight': 'bold'})

        # title 
        # -----
        if title:
            plt.title(label=title, loc=title_loc, **title_kws)

        plt.axis("off")
        plt.tight_layout()
        if show_plot:
            plt.show()
        plt.rcParams["text.usetex"] = default_usetex

        return plt, ax

    def plot_paths(self, exposure=None, outcome=None, adj_set=None, directed=False,
                   show_full_dag = True,
                   use_labels=True,
                   title_fontsize = 10,
                   figsize=(16, 9),
                   path_color='black',
                   **plot_kws
                   ):
        adj_set = [adj_set] if isinstance(adj_set, str) else adj_set

        paths = self.paths(exposure=exposure, outcome=outcome, adj_set=adj_set, directed=directed)
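        npaths = len(paths)
        if npaths == 0:
            # nothing to plot; also avoids a division by zero in the grid computation
            return None
        ncols = int(math.ceil(math.sqrt(npaths)))
        nrows = int(math.ceil(npaths / ncols))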
        fig, axs = plt.subplots(nrows, ncols, figsize=figsize, tight_layout=True)
        if ncols >1 or nrows>1:
            axs=axs.flatten()
        else:
            axs = [axs]
        for ax in axs:
            ax.axis('off')
        # 

        pos = self.nodes_position
        roles = self.nodes_role
        nodes_label = self.nodes_label
        edge_label = self.edge_label
        for i, (path, info) in enumerate(paths.items()):
            ax = axs[i]

            show_labels=True
            if show_full_dag:
                self.plot(ax=ax, edge_color ='lightgray', **plot_kws)
                show_labels=False

            # G2 = DAG(path, nodes_role=roles, nodes_position=pos, nodes_label=nodes_label)
            G2 = self.__rebuild_graph__(path)
            G2.plot(ax=ax, edge_linewidth=3, show_labels=show_labels,
                    edge_color=path_color, use_labels=use_labels, **plot_kws)
            adj = info['adj_set']
            if adj:
                adj = [self.nodes_label.get(x, x) for x in adj] if use_labels else adj
                adj = ', '.join(adj)
            else:
                adj = ""
            # raw strings for the LaTeX braces: "\{" in a normal string is an invalid escape sequence
            title = rf"Path is \textbf{{{'open' if info['open'] else 'closed'}}}; Adjustment set: " + r"\{" + adj + r"\}"
            ax.set_title(title, loc='left', fontsize=title_fontsize)
            ax.axis('on')
            plt.tight_layout()

        return axs

    def plot_equivalent_dags(self,
                             use_labels=True,
                             show_labels=True,
                             edge_difference_color='red',
                             title_fontsize = 10,
                             title_original_graph = 'Original Graph',
                             title_equivalent_graph = "Equivalent DAG",
                             show_footnote = True,
                             figsize=(16, 9),
                             max_per_figure = 9,
                             max_eq_dags= 27,
                             **plot_kws
                             ):
        # collecting equivalent DAGs
        eq_dags = self.equivalent_dags()
        n_eq_dags = len(eq_dags)
        if n_eq_dags == 0:
            return None

        if n_eq_dags > max_eq_dags:
            print(f"\n**Note:**\n"+
                  f"---------\n"
                  f"Maximum number of equivalent DAGs to plot is set to {max_eq_dags}"+
                  f" by default, but there are {n_eq_dags} equivalent DAGs. Some equivalent DAGs"+
                  f" will be omitted. To change it, set 'max_eq_dags'.\n")

        max_eq_dags = min(n_eq_dags, max_eq_dags)
        figs = dict(self.__chunked_ranges__(max_eq_dags, max_per_figure))

        print(f"Total of equivalent DAGs: {n_eq_dags}\n"+
              f"Plotting {max_eq_dags} equivalent DAG(s)\n"
              f"Generating {len(figs.keys())} figure(s) with a maximum of {max_per_figure} panels per figure\n")
        figs_res = {}

        for fig_number, panels in figs.items():

            # figure
            ncols = int(math.ceil(math.sqrt(max_per_figure)))
            nrows = int(math.ceil(max_per_figure / ncols))
            fig, axs = plt.subplots(nrows, ncols, figsize=figsize, tight_layout=True)
            if ncols >1 or nrows>1:
                axs=axs.flatten()
            else:
                axs = [axs]
            for ax in axs:
                ax.axis('off')

            # panels
            for panel, panel_number in enumerate(panels):
                print(f"Creating plot {panel_number+1} of {n_eq_dags}...", end='')
                ax = axs[panel]
                eq_dag = eq_dags[panel_number]
                # baseline plot
                eq_dag.plot(ax=ax, edge_linewidth=1,
                            show_labels=show_labels,
                            use_labels=use_labels,
                            title=title_equivalent_graph,
                            title_kws={'fontsize': title_fontsize},  # plot() has no 'title_fontsize' parameter
                            **plot_kws)
                # superimpose edges highlighting the differences
                edges = self.edge_differences(eq_dag)['G2']
                nodes = self.__collect_nodes_from_edges__(edges)
                eq_dag.plot(ax=ax, edge_linewidth=3,
                            node_subset = nodes,
                            edge_subset = edges,
                            show_labels=show_labels,
                            edge_color=edge_difference_color,
                            use_labels=use_labels,
                            title=title_equivalent_graph,
                            title_kws={'fontsize': title_fontsize},  # plot() has no 'title_fontsize' parameter
                            **plot_kws)
                if show_footnote:
                    # footnote
                    xcoord=1
                    ycoord=1.07
                    yoffset=-.1
                    fn = f"Equivalent DAG: {panel_number+1} of {n_eq_dags}"
                    ax.annotate(fn, xy=(xcoord,yoffset), xytext=(xcoord,yoffset),
                                xycoords='axes fraction', size=11, ha='right',
                                style='italic', alpha=.6)
                print('done!')
                ax.axis('on')
                plt.tight_layout()
                figs_res[fig_number] = [fig, axs]
        return figs_res

    @ut.copy_docstring(plot)
    def plot_equivalence_class(self, *args, **kws):
        return self.equivalence_class().plot(*args, **kws)

    def plot_identification(self,
                            content='default', # detailed, default
                            effect='total', #total, direct, or do, only if if_info=full
                            show_np = True,
                            show_linear = True,
                            show_do = True,
                            kws_graph={},
                            kws_identification={},
                            kws_detailed = None,
                            figsize = None,
                            ratio   = None,
                            ncols   = None,
                            nrows   = None,
                            title_dag = None,
                            title_info = None,
                            txt_line_height=.55,
                            *args,
                            **kws
                            ):
        """
        txt_line_height: float
            height of the lines for the text. Not used if figsize is set.
        kws_detailed : dict
            Example:
            {'parameter': 'ACE',
             'strategy': 'SoO'}
        """
        roles = ['Exposure', 'Outcome', 'Latent', 'Observed',
                 'exposure', 'outcome', 'latent', 'observed']
        for role in roles:
            assert not kws_graph.get(role, None) and not kws_identification.get(role, None), (
                f"Setting node role ({role}) not allowed in the plot kws. "+
                f"To set the node role, create a new DAG or use set_node_role before plotting.")

        if not self.__identification__ or kws_identification:
            self.identification_analysis(**kws_identification, verbose=False)

        # defaults for kws_detailed
        kws_detailed = kws_detailed or {}
        strategy = kws_detailed.get('strategy', 'SoO')
        parameter = kws_detailed.get('parameter', None)
        if not parameter:
            parameter = next(iter(self.__identification__.identification[strategy]))
        kws_detailed['strategy'] = strategy
        kws_detailed['parameter'] = parameter

        return self.__identification__.plot(G=self,
                                            info=content,
                                            effect=effect,
                                            show_np = show_np,
                                            show_linear = show_linear,
                                            show_do = show_do,
                                            figsize=figsize,
                                            ratio=ratio,
                                            ncols=ncols,
                                            nrows=nrows,
                                            kws_graph=kws_graph,
                                            kws_detailed = kws_detailed,
                                            txt_line_height=txt_line_height,
                                            title_dag = title_dag,
                                            title_info = title_info,
                                            *args,
                                            **kws
                                            )

    # building graph --------------------------------
    def __build_graph__(self, graph):
        # Always convert to dict first, and from dict to other formats
        # dict -> list
        # dict -> str
        # str -> dict -> list
        # list-> dict -> str
        if isinstance(graph, str):
            self.__graph_str_parse__(graph)
            self.__graph_str2dict__()
            self.__graph_dict2list__()
        elif isinstance(graph, dict):
            self.__graph_dict_parse__(graph)
            self.__graph_dict2str__()
            self.__graph_dict2list__()
        elif isinstance(graph, list):
            self.__graph_list_parse__(graph)
            self.__graph_list2dict__()
            self.__graph_dict2str__()

    def __graph_list_parse__(self, graph):
        for e in graph:
            if e not in self.__graph_list__:
                self.__graph_list__ += [e]

    def __graph_dict_parse__(self, graph):
        self.__graph_dict__ = {'directed':[], 'bidirected':[], 'undirected':[]}
        for edge_type, edges in graph.items():
            for edge in edges:
                if edge not in self.__graph_dict__[edge_type]:
                    self.__graph_dict__[edge_type] += [edge]

    def __graph_str_parse__(self, graph):
        self.__graph_str_original__ = graph
        edges_type = "|".join(self.__edges_str_allowed__)
        # edges_type = '|'.join(sorted(map(re.escape, self.__edges_str_allowed__), key=len, reverse=True))

        self.__graph_str_parsed__ = []
        regex = re.compile(rf"(\w+|\{{[^}}]*\}})\s*({edges_type})\s*(\w+|\{{[^}}]*\}})")

        # remove comments
        graph = "\n".join(line for line in re.sub(r"#.*", "", graph).splitlines() if line.strip())

        graph = self.__graph_str_parse_inline_paths__(graph)
        for ln in graph.strip().splitlines():
            ln = ln.strip()

            # collect if not a comment
            if not bool(re.search(pattern="^ ?#", string=ln)):
                m = regex.match(ln) 
                if m:
                    nodes1, edge, nodes2 = m.groups()

                    nodes1 = re.sub(pattern='\\{|\\}', repl='', string=nodes1)
                    nodes1 = re.split(r"[,\s]+", nodes1.strip())

                    nodes2 = re.sub(pattern='\\{|\\}', repl='', string=nodes2)
                    nodes2 = re.split(r"[,\s]+", nodes2.strip())

                    for n1, n2 in itertools.product(nodes1, nodes2):
                        self.__graph_str_parsed__.append(f"{n1} {edge} {n2}")
                else:
                    raise ValueError(f"Unrecognized line format: '{ln}'")

        self.__graph_str_parsed__ = "\n".join(self.__graph_str_parsed__)
        return None

    def __graph_str_parse_inline_paths__(self, dag):
        # Split the graph string into lines; each line may chain several edges
        # (e.g., "X1 -> Z -> Y"), which are expanded into pairwise edges below
        lines = dag.split("\n")
        edges_type = '|'.join(sorted(map(re.escape, self.__edges_str_allowed__), key=len, reverse=True))

        # compile the delimiter pattern once and reuse it for every line
        delimiter_pattern = re.compile(rf'({edges_type})')
        res = []
        for path in lines:
            unique_edges = set()

            # Split the path by the arrow delimiters
            components_raw = delimiter_pattern.split(path)

            # Clean the list: remove empty strings and strip whitespace from each part
            components = [c.strip() for c in components_raw if c and c.strip()]

            # Step two components at a time so that consecutive edges share a node:
            # each edge is the triple (node, arrow, node)
            for i in range(0, len(components) - 1, 2):
                node1 = components[i]
                arrow = components[i+1]
                node2 = components[i+2]

                # Re-format the edge with standard spacing for consistent output
                edge = f"{node1} {arrow} {node2}"
                unique_edges.add(edge)
            res += ["\n".join(unique_edges)]

        res = "\n".join(res)
        res = res.replace("<- >", "<->")
        return res
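        # Illustration (comments only, not executed): a minimal sketch of how a
        # chained path is split into pairwise edges. It assumes the allowed edge
        # tokens are '->', '<-', '<->' and '--'; the actual list lives in
        # self.__edges_str_allowed__, which is not shown here.
        #
        #     tokens = sorted(['->', '<-', '<->', '--'], key=len, reverse=True)
        #     pat = re.compile('(' + '|'.join(map(re.escape, tokens)) + ')')
        #     parts = [c.strip() for c in pat.split("X1 -> Z -> Y") if c.strip()]
        #     # parts: ['X1', '->', 'Z', '->', 'Y']
        #     edges = [" ".join(parts[i:i+3]) for i in range(0, len(parts) - 1, 2)]
        #     # edges: ['X1 -> Z', 'Z -> Y']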

    def __graph_str2dict__(self):
        # Parse DAG string to properties of the graph: nodes, directed, 
        # bidirected, and undirected edges. 
        DAG = self.__graph_str_parsed__
        directed, undirected, bidirected = [], [], []

        # One regex to handle all edge types
        pattern = re.compile(r"^\s*(\w+)\s*(->|<-|<->|--)\s*(\w+)\s*$")

        lines = DAG.strip().splitlines()
        for line in lines:
            line = line.strip()
            if not line or line.startswith("#"):
                continue  # skip empty/comment lines

            m = pattern.match(line)
            if not m:
                raise ValueError(f"\nUnrecognized format: '{line}'")

            lhs, op, rhs = m.groups()
            if op == "->":
                a, b = lhs, rhs
                directed.append((a, b))

            elif op == "<-":
                a, b = rhs, lhs   # normalize as parent=a -> child=b
                directed.append((a, b))

            elif op == "<->":
                a, b = lhs, rhs
                bidirected.append( ((a, b), (b, a)) )

            elif op == "--":
                a, b = lhs, rhs
                undirected.append({a, b})

            # single place to update the node set
            self.nodes.update({a, b})

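        # eliminate duplicates (preserving the order of first appearance)
        directed = list(dict.fromkeys(directed))
        bidirected = list(dict.fromkeys(bidirected))
        # frozenset makes {a, b} and {b, a} compare equal regardless of iteration order
        undirected = [set(g) for g in dict.fromkeys(frozenset(g) for g in undirected)]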

        self.__graph_dict__ = {"directed"  : directed,
                               'bidirected': bidirected,
                               'undirected': undirected}
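        # Illustration (comments only; a sketch derived from the rules above).
        # The parsed string
        #     X1 -> Y
        #     X1 <- X2
        #     X3 <-> X4
        #     X5 -- X6
        # is stored, up to the order of the edges within each list, as
        #     {'directed'  : [('X1', 'Y'), ('X2', 'X1')],
        #      'bidirected': [(('X3', 'X4'), ('X4', 'X3'))],
        #      'undirected': [{'X5', 'X6'}]}
        # Note that 'X1 <- X2' is normalized to the directed edge ('X2', 'X1').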

    def __graph_list2dict__(self):
        self.__graph_dict__ = {'directed':[], 'bidirected':[], 'undirected':[]}
        for edge in self.__graph_list__:
            edge_type = self.__edge_type__(edge)
            self.__graph_dict__[edge_type] += [edge]

    def __graph_dict2list__(self):
        # flatten the per-type edge lists into a single list of edges
        self.__graph_list__ = list(itertools.chain.from_iterable(self.__graph_dict__.values()))

    def __graph_dict2str__(self):
        self.__graph_str_parsed__ = ''
        for edge_type, edges in self.__graph_dict__.items():
            for nodes in edges:
                if edge_type == 'directed':
                    edge = '->'
                elif edge_type == 'bidirected':
                    edge = '<->'
                    nodes = nodes[0]  # ((a, b), (b, a)) -> (a, b)
                elif edge_type == 'undirected':
                    edge = '--'
                    nodes = list(nodes)
                self.__graph_str_parsed__ += f"{nodes[0]} {edge} {nodes[1]}\n"
        self.__graph_str_original__ = self.__graph_str_parsed__

    # collect info
    def __collect_info__(self, nodes_role, nodes_position, nodes_label):
        # collect info (keep order)
        self.__collect_nodes__()
        self.__collect_nodes_parents__()
        self.__collect_nodes_role__(nodes_role)
        self.__collect_nodes_position__(nodes_position)
        self.__collect_nodes_label__(nodes_label)
        # 
        self.nodes_info = {node:{} for node in self.nodes}
        self.__collect_info_nodes_role__()
        self.__collect_info_nodes_position__()
        self.__collect_info_nodes_label__()
        # 
        self.__collect_edges_properties__()

    def __collect_nodes__(self):
        nodes = set()
        for edge_type, edges in self.__graph_dict__.items():
            for edge in edges:
                for node in edge:
                    if edge_type=='bidirected':
                        node = node[0]
                    nodes = nodes.union([node])
        self.nodes = nodes

    def __collect_nodes_parents__(self):
        self.nodes_parents = defaultdict(set)  # child -> {parents}
        for n1, n2 in self.__graph_dict__['directed']:
            self.nodes_parents[n2].update([n1])
        self.nodes_parents = dict(self.nodes_parents)

    def __collect_nodes_label__(self, nodes_label):
        nodes_label = nodes_label or {}
        for node in self.nodes:
            self.nodes_label[node] = nodes_label.get(node, None) or node

    def __collect_nodes_position__(self, nodes_position):
        if nodes_position:
            self.nodes_position = {}
            for node, pos in nodes_position.items():
                if node in self.nodes:
                    self.nodes_position[node] = pos

    def __collect_nodes_role__(self, nodes_role):
        nodes_role = nodes_role or {}
        self.nodes_role['Observed'] = [] # keep this here
        nodes_with_role_already_set = []

        for role, node in nodes_role.items() :
            if role=='Outcome':
                if isinstance(node, list) and len(node)==1:
                    node = node[0]
                assert isinstance(node, str), "Check nodes_role. Node 'Outcome' must be a string or a 1-element list."

            else:
                assert isinstance(node, str) or isinstance(node, list), \
                    "Check nodes_role. Nodes 'Exposure' and 'Latent' must be strings or lists"
            node = node if isinstance(node, list) else [node]
            self.nodes_role[role] = [n for n in node if n in self.nodes]
            nodes_with_role_already_set += node

        # set observed as default if role of node is not provided
        for node in self.nodes:
            if node not in nodes_with_role_already_set:
                self.nodes_role['Observed'] += [node]

        self.exposure = self.nodes_role.get('Exposure', None)
        self.outcome  = self.nodes_role.get('Outcome', None)
        self.latent   = self.nodes_role.get('Latent', None)
        self.observed = self.nodes_role.get('Observed', None)

    def __collect_info_nodes_role__(self):
        for role, nodes in self.nodes_role.items():
            for node in nodes:
                self.nodes_info[node]['role'] = role

    def __collect_info_nodes_position__(self):
        for node, position in self.nodes_position.items():
            self.nodes_info[node]['position'] = position

    def __collect_info_nodes_label__(self):
        for node, label in self.nodes_label.items():
            self.nodes_info[node]['label'] = label

    def __collect_edges_properties__(self):
        self.directed   = self.__graph_dict__['directed']
        self.bidirected = self.__graph_dict__['bidirected']
        self.undirected = self.__graph_dict__['undirected']


    # R dagitty
    def __create_dagitty__(self):
        # # Convert to dagitty string: "dag { A -> B; B -> C; ... }"
        # edges = [f"{u} -> {v}" for u, v in self.G.edges()]
        # edges = '; '.join(edges)

        roles = ''
        for role, nodes in self.nodes_role.items():
            for node in nodes:
                roles += f"{node} [{role.lower()}]\n"

        # Load dagitty and pass the DAG string
        dagitty_str = f"dag {{ {self.__graph_str_parsed__} \n {roles} }}"
        self.__dagitty__ = dagitty.dagitty(dagitty_str)

    # R dagitty
    def __dagitty2inputs__(self, dag_dagitty):
        dag_str = ''
        dag_df = convert().rtibble2tp(dagitty.edges(dag_dagitty))
        for a, b, e, *_ in dag_df.to_polars().iter_rows():
            dag_str += f"{a} {e} {b}\n"

        roles = {"Exposure": list(dagitty.exposures(dag_dagitty)),
                 'Outcome' : list(dagitty.outcomes(dag_dagitty)),
                 "Latent"  : list(dagitty.latents(dag_dagitty))}

        return dag_str, roles
    # -------------------------------------------------

    def __rebuild_graph__(self, graph):
        res = DAG(graph,
                  nodes_role     = self.nodes_role,
                  nodes_position = self.nodes_position,
                  nodes_label    = self.nodes_label,
                  edge_label     = self.edge_label
                  )
        return res

    def __repr__(self):
        self.__print_graph__()
        return ''

    def __str__(self):
         self.__repr__()
         return ''

    def __print_graph__(self):
        out = 'Graph:\n'

        d = [f"{n1} -> {n2}" for n1, n2 in self.directed]
        out += '\n'.join(d) if len(d)>0 else ''

        b = [f"{n1[0]} <-> {n2[0]}" for n1, n2 in self.bidirected]
        out += '\n' + '\n'.join(b) if len(b)>0 else ''

        u = [f"{n1} -- {n2}" for n1, n2 in self.undirected]
        out += '\n' +'\n'.join(u) if len(u)>0 else ''

        roles = [f"{role}: {', '.join(nodes)}" for role, nodes in self.nodes_role.items()]
        out += "\n"+"\n".join(roles) if len(roles)>0 else ''

        print(out)
        return out

    def __collect_nodes_from_edges__(self, edges_dict):
        nodes = []
        for edge_type, edges in edges_dict.items():
            if edge_type!='bidirected':
                nodes += list(set(itertools.chain.from_iterable(edges)))
            else:
                nodes += list(set(itertools.chain.from_iterable(itertools.chain.from_iterable(edges))))
        return nodes

    def __chunked_ranges__(self, limit, n):
        # Split range(limit), i.e. 0..limit-1, into consecutive chunks.
        # Each chunk has n elements, except the last one, which may have
        # fewer when limit is not divisible by n.
        # Yields (chunk_index, list_of_indices).
        start = 0
        idx = 0
        limit -=1
        while start <= limit:
            end = start + n - 1
            if end >= limit:   # last chunk, go all the way to limit
                yield idx, list(range(start, limit + 1))
                break
            else:
                yield idx, list(range(start, end + 1))
                start = end + 1
                idx += 1

    def __edge_frozen_format__(self, edge):
        # Convert an edge into a canonical, hashable form.
        # - directed: ('A','B')
        # - undirected: frozenset({'A','B'})
        # - bidirected: frozenset({('A','B'),('B','A')})
        # undirected
        if isinstance(edge, (set, frozenset)):
            return frozenset(edge)

        # bidirected
        if (isinstance(edge, tuple) 
            and len(edge) == 2 
            and all(isinstance(e, tuple) and len(e) == 2 for e in edge)):
            return frozenset([tuple(edge[0]), tuple(edge[1])])

        # directed
        if (isinstance(edge, tuple) 
            and len(edge) == 2 
            and all(isinstance(x, str) for x in edge)):
            return tuple(edge)

        raise ValueError(f"Unrecognized edge format: {edge}")

    def __edge_type__(self, edge):
        """
        Classify an edge as 'directed', 'bidirected', or 'undirected'.
        """
        # Undirected: set/frozenset of 2 nodes
        if isinstance(edge, (set, frozenset)):
            if all(isinstance(x, str) for x in edge) and len(edge) == 2:
                return "undirected"

        # Bidirected: tuple of two directed edges
        if (isinstance(edge, tuple) 
            and len(edge) == 2 
            and all(isinstance(e, tuple) and len(e) == 2 for e in edge)
            and all(isinstance(x, str) for e in edge for x in e)):
            return "bidirected"

        # Directed: tuple of two nodes
        if (isinstance(edge, tuple) 
            and len(edge) == 2 
            and all(isinstance(x, str) for x in edge)):
            return "directed"

        raise ValueError(f"Unrecognized edge format: {edge}")

    # comparing SCM
    def edge_differences(self, G2):
        res1 = self.__edge_differences__(G2)
        res2 = G2.__edge_differences__(self)
        return {"G1":res1, "G2":res2}

    def __edge_differences__(self, G2):
        res1 = {}
        edge_types = ['directed', 'undirected', 'bidirected']
        for edge_type in edge_types:
            res1[edge_type] = []
            edges_list1 = self.__getattribute__(edge_type)
            edges_list2 = G2.__getattribute__(edge_type)
            for edge in edges_list1:
                if edge_type=='bidirected':
                    if edge not in edges_list2 and (edge[1], edge[0]) not in edges_list2:
                        res1[edge_type] += [edge]
                else:
                    if edge not in edges_list2:
                        res1[edge_type] += [edge]
        return res1

    # -------------------------------------------------

    # ancillary
    def __plot_create_nx__(self):
        G = nx.MultiDiGraph()  # allows multiple edges & types

        # Directed edges
        for u, v in self.directed:
            G.add_edge(u, v, type="directed")

        # Bidirected edges: add both directions
        for (u1, v1), (u2, v2) in self.bidirected:
            G.add_edge(u1, v1, type="bidirected")
            G.add_edge(u2, v2, type="bidirected")

        # Undirected edges: add both directions
        for uv in self.undirected:
            u, v = tuple(uv)
            G.add_edge(u, v, type="undirected")
            G.add_edge(v, u, type="undirected")

        return G

    def __plot_nodes_subset__(self, node_subset, node_latent_show):
        node_subset = node_subset or self.nodes
        nodes_to_plot = {}
        for role, nodes in self.nodes_role.items():
            if role=='Latent' and not node_latent_show:
                continue
            else:
                nodes_to_plot[role] = set([node for node in nodes if node in node_subset])
        return nodes_to_plot

    def __plot_nodes_positions__(self, G_draw, nodes_position):
        nodes_position = nodes_position or self.nodes_position
        if not nodes_position:
            try:
                from networkx.drawing.nx_pydot import graphviz_layout
                nodes_position = graphviz_layout(G_draw, prog="dot")
            except ImportError:
                nodes_position = nx.spring_layout(G_draw)
        return nodes_position 

    def __plot_label_adj__(self, node_label_adj, nodes_label):
        if isinstance(node_label_adj, dict):
            adj = {node:node_label_adj.get(node, 0)
                   for node in self.get_nodes(exclude_latent=False)}
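        elif isinstance(node_label_adj, (float, int)):
            adj = {node:node_label_adj
                   for node in self.get_nodes(exclude_latent=False)}
        else:
            # no usable adjustment (e.g. None): fall back to 0 so `adj` is always defined
            adj = {node:0
                   for node in self.get_nodes(exclude_latent=False)}
        # mirror each adjustment under the node's display label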
        for node, label in nodes_label.items():
            adj[label] = adj[node]
        return adj

    def __plot_collect_labels_estimate__(self, estimates, show_sig=True, show_se=False, show_ci=False):
        tab = (estimates.est.parameters
               .separate('term',  ['_to', '_from'], '~', remove=False)
               .mutate(_to = tp.str_trim('_to'),
                       _from = tp.str_trim('_from')
                       )
               .filter(tp.col("_from")!='')
               .filter(tp.col("_to")!='')
               .drop_null('_from', '_to')
               )
        digits = 4
        ests = {}
        pvalues = {}
        for row in tab.iterrows():
            est = round(row['estimate'], digits)
            ests |= {(row['_from'], row['_to']): est}

            pvalue = row['pvalue']
            pvalues |= {(row['_from'], row['_to']): pvalue}

        return ests, pvalues

    # styles
    def __plot_get_style__(self, graph_style):
        if graph_style=='default':
            aes = self.__plot_get_style_default__()

        elif graph_style=='rectangle':
            aes = self.__plot_get_style_rectangle__()

        elif graph_style=='pearl':
            aes = self.__plot_get_style_pearl__()

        else:
            aes = self.__plot_get_style_default__()

        return aes

    def __plot_get_style_default__(self):
        nodes = {
            "Exposure": {
                "node_shape": "o",
                "node_size": 1000,
                "node_color": "lightgray",
                "node_border_color": "black",
                "node_border_width": 1,
                "node_border_style": "-"
            },
            "Outcome": {
                "node_shape": "o",
                "node_size": 1000,
                "node_color": "gray",
                "node_border_color": "black",
                "node_border_width": 1,
                "node_border_style": "-"
            },
            "Observed": {
                "node_shape": "o",
                "node_size": 1000,
                "node_color": "white",
                "node_border_color": "black",
                "node_border_width": 1,
                "node_border_style": "-"
            },
            "Latent": {
                "node_shape": "o",
                "node_size": 1000,
                "node_color": "white",
                "node_border_color": "black",
                "node_border_width": 1,
                "node_border_style": "--"
            }
        }
        for role, node in self.nodes_role.items():
            if role not in nodes.keys():
                nodes[role] = nodes['Observed']

        labels = {
            "Exposure": {
                "node_label_fontweight": "normal",
                "node_label_fontsize"  : 12,
            },
            "Outcome": {
                "node_label_fontweight": "normal",
                "node_label_fontsize"  : 12,
            },
            "Observed": {
                "node_label_fontweight": "normal",
                "node_label_fontsize"  : 12,
            },
            "Latent": {
                "node_label_fontweight": "normal",
                "node_label_fontsize"  : 12,
            }
        }
        for role, node in self.nodes_role.items():
            if role not in labels.keys():
                labels[role] = labels['Observed']

        edges = {
            "edge_label": self.edge_label,
            "edge_style": {"directed": "solid",
                           "bidirected": "dashed",
                           "undirected": "solid"
                           },
            "edge_color": {"directed": "black",
                           "bidirected": "black",
                           "undirected": "orange"
                           },
            "edge_arc": {"directed": 0,
                         "bidirected": -.33,
                         "undirected": 0,
                         },
            "edge_linewidth": {"directed": 1.5,
                               "bidirected": 1.5,
                               "undirected": 1.5
                               },
            "edge_head_size": {"directed": 20,
                               "bidirected": 20,
                               "undirected": 0
                               },
            "edge_head_style": {"directed": None,
                               "bidirected": '<|-|>',
                               "undirected": '-' 
                               },
            "edge_margin_tail": {"directed": 20,
                                 "bidirected": 20,
                                 "undirected": 0
                                 },
            "edge_margin_head": {"directed": 20,
                                 "bidirected": 20,
                                 "undirected": 0
                                 }
        }
        return [nodes, labels, edges]

    def __plot_get_style_pearl__(self):
        nodes = {
            "Exposure": {
                "node_shape": ".",
                "node_size": 200,
                "node_color": "lightgray",
                "node_border_color": "black",
                "node_border_width": 1,
                "node_border_style": "-"
            },
            "Outcome": {
                "node_shape": ".",
                "node_size": 200,
                "node_color": "gray",
                "node_border_color": "black",
                "node_border_width": 1,
                "node_border_style": "-"
            },
            "Observed": {
                "node_shape": ".",
                "node_size": 200,
                "node_color": "black",
                "node_border_color": "black",
                "node_border_width": 1,
                "node_border_style": "-"
            },
            "Latent": {
                "node_shape": ".",
                "node_size": 200,
                "node_color": "white",
                "node_border_color": "black",
                "node_border_width": 1,
                "node_border_style": "--"
            }
        }
        for role, node in self.nodes_role.items():
            if role not in nodes.keys():
                nodes[role] = nodes['Observed']

        labels = {
            "Exposure": {
                "node_label_fontweight": "normal",
                "node_label_fontsize"  : 12,
            },
            "Outcome": {
                "node_label_fontweight": "normal",
                "node_label_fontsize"  : 12,
            },
            "Observed": {
                "node_label_fontweight": "normal",
                "node_label_fontsize"  : 12,
            },
            "Latent": {
                "node_label_fontweight": "normal",
                "node_label_fontsize"  : 12,
            }
        }
        for role, node in self.nodes_role.items():
            if role not in labels.keys():
                labels[role] = labels['Observed']

        edges = {
            "edge_label": self.edge_label,
            "edge_style": {"directed": "solid",
                           "bidirected": "dashed",
                           "undirected": "solid"
                           },
            "edge_color": {"directed": "black",
                           "bidirected": "black",
                           "undirected": "orange"
                           },
            "edge_arc": {"directed": 0,
                         "bidirected": -.33,
                         "undirected": 0,
                         },
            "edge_linewidth": {"directed": 1.5,
                               "bidirected": 1.5,
                               "undirected": 1.5
                               },
            "edge_head_size": {"directed": 20,
                               "bidirected": 20,
                               "undirected": 0
                               },
            "edge_head_style": {"directed": None,
                               "bidirected": '<|-|>',
                               "undirected": '-' 
                               },
            "edge_margin_tail": {"directed": 0,
                                 "bidirected": -10,
                                 "undirected": -10
                                 },
            "edge_margin_head": {"directed": -10,
                                 "bidirected": -10,
                                 "undirected": 0
                                 }
        }
        return [nodes, labels, edges]

    def __plot_get_style_rectangle__(self, *args, **kws):
        nodes = {
            "Exposure": {
                "node_shape": "",
                "node_size": 1000,
                "node_color": "lightgray",
                "node_border_color": "black",
                "node_border_width": 1,
                "node_border_style": "-"
            },
            "Outcome": {
                "node_shape": "",
                "node_size": 1000,
                "node_color": "gray",
                "node_border_color": "black",
                "node_border_width": 1,
                "node_border_style": "-"
            },
            "Observed": {
                "node_shape": "",
                "node_size": 1000,
                "node_color": "white",
                "node_border_color": "black",
                "node_border_width": 1,
                "node_border_style": "-"
            },
            "Latent": {
                "node_shape": "",
                "node_size": 1000,
                "node_color": "white",
                "node_border_color": "black",
                "node_border_width": 1,
                "node_border_style": "--"
            }
        }
        for role, node in self.nodes_role.items():
            if role not in nodes.keys():
                nodes[role] = nodes['Observed']

        labels = {
            "Exposure": {
                "node_label_fontweight": "normal",
                "node_label_fontsize"  : 12,
            },
            "Outcome": {
                "node_label_fontweight": "normal",
                "node_label_fontsize"  : 12,
            },
            "Observed": {
                "node_label_fontweight": "normal",
                "node_label_fontsize"  : 12,
            },
            "Latent": {
                "node_label_fontweight": "normal",
                "node_label_fontsize"  : 12,
            }
        }
        for role, node in self.nodes_role.items():
            if role not in labels.keys():
                labels[role] = labels['Observed']

        edges = {
            "edge_label": self.edge_label,
            "edge_style": {"directed": "solid",
                           "bidirected": "dashed",
                           "undirected": "solid"
                           },
            "edge_color": {"directed": "black",
                           "bidirected": "black",
                           "undirected": "orange"
                           },
            "edge_arc": {"directed": 0,
                         "bidirected": -.33,
                         "undirected": 0,
                         },
            "edge_linewidth": {"directed": 1.5,
                               "bidirected": 1.5,
                               "undirected": 1.5
                               },
            "edge_head_size": {"directed": 20,
                               "bidirected": 20,
                               "undirected": 0
                               },
            "edge_head_style": {"directed": None,
                               "bidirected": '<|-|>',
                               "undirected": '-' 
                               },
            "edge_margin_tail": {"directed": 20,
                                 "bidirected": 20,
                                 "undirected": 0
                                 },
            "edge_margin_head": {"directed": 20,
                                 "bidirected": 20,
                                 "undirected": 0
                                 }
        }
        return [nodes, labels, edges]

    def __plot_collect_aes__(self, role, aes, default):
        res = None
        if aes is not None:
            if isinstance(aes, dict):
                res = aes.get(role, None)
            else:
                res = aes

        if not res:
            res = default
        return res

    def __plot_edge_margin__(self, edge_margin, default=20):
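        # test for None explicitly so that an explicit margin of 0 is not replaced by the default
        edge_margin = {} if edge_margin is None else edge_margin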
        edges = self.directed + self.bidirected
        if isinstance(edge_margin, (float, int)):
            edge_margin = {e:edge_margin for e in edges}
        edge_margin = {e:edge_margin.get(e, default) for e in edges}

        return edge_margin

    def __plot_edge_label_feature__(self, feature, edge, value, default=None,
                                    alpha_level=0.05, label=None, edge_label_pvalue=None):
        res = value.get(edge, default) if isinstance(value, dict) else (value or default)

        # default color: red for negative, black for positive
        if feature=='color' and not res:
            try:
                label = float(label)
                res = 'red' if label < 0 else 'black'
            except (TypeError, ValueError) as e:
                # default
                res = 'black'

        # default alpha: full for significant, faded otherwise
        if feature=='alpha' and not res and edge_label_pvalue:
            try:
                res = 1 if edge_label_pvalue.get(edge, 0) <= alpha_level else 0.2
            except (TypeError, ValueError) as e:
                # default
                res = 1
        return res

causalinf.gcm.DAG.edge_exist(edge, edges=None)

Check whether edge exists in edges, robust to order of nodes for undirected and bidirected edges.

Parameters:

edge : tuple or set, required
    The edge to check.
    For a directed edge: tuple (from, to)
    For a bidirected edge: tuple ((node1, node2), (node2, node1))
    For an undirected edge: set {node1, node2}

edges : list or None, default None
    A list of edges to check against. If None, the method uses the graph's own edges of the same type as edge.

Returns:

bool
    True if the edge exists in the edges list, False otherwise.

Source code in causalinf/gcm.py
def edge_exist(self, edge, edges=None):
    """
    Check whether `edge` exists in `edges`,
    robust to order of nodes for undirected and bidirected edges.

    Parameters
    ----------
    edge : tuple or set
        The edge to check.
        For a directed edge: tuple (from, to)
        For a bidirected edge: tuple ((node1, node2), (node2, node1))
        For an undirected edge: set {node1, node2}

    edges : list or None, optional
        A list of edges to check against. If None, the method will retrieve 
        the edges associated with the edge type.

    Returns
    -------
    bool
        True if the edge exists in the edges list, False otherwise.
    """
    if edges is None:
        edge_type = self.__edge_type__(edge)
        edges = self.__getattribute__(edge_type)
    edges = [edges] if not isinstance(edges, list) else edges
    edge = self.__edge_frozen_format__(edge)
    edges_in_list = {self.__edge_frozen_format__(e) for e in edges}
    return edge in edges_in_list
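
The order-insensitive lookup used by edge_exist can be illustrated outside the class. The following is a minimal sketch, not the library's code: `canonical` and `edge_in` are hypothetical stand-ins for the private helper `__edge_frozen_format__` and for `edge_exist`, assuming the edge formats documented above.

```python
def canonical(edge):
    """Return a hashable, order-insensitive form of an edge (illustrative helper)."""
    # undirected: {'A', 'B'} -> frozenset, so node order is irrelevant
    if isinstance(edge, (set, frozenset)):
        return frozenset(edge)
    # bidirected: (('A', 'B'), ('B', 'A')) -> frozenset of the two arrows
    if isinstance(edge, tuple) and all(isinstance(e, tuple) for e in edge):
        return frozenset(edge)
    # directed: ('A', 'B') stays a tuple because direction matters
    return tuple(edge)


def edge_in(edge, edges):
    """True if `edge` matches any element of `edges` after canonicalization."""
    return canonical(edge) in {canonical(e) for e in edges}


edges = [("X", "Y"), {"X", "Z"}, (("A", "B"), ("B", "A"))]
print(edge_in(("X", "Y"), edges))                # True
print(edge_in(("Y", "X"), edges))                # False: direction matters
print(edge_in({"Z", "X"}, edges))                # True: order ignored
print(edge_in((("B", "A"), ("A", "B")), edges))  # True: order ignored
```

Converting to frozenset is what makes the comparison independent of node order; directed edges stay as tuples so that (X, Y) and (Y, X) remain distinct.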

causalinf.gcm.DAG.equivalence_class()

Details

The equivalence class of a DAG is a graph that replaces directed edges with undirected edges, except those in v-structures (triples X->Z<-Y where X and Y are not adjacent). Therefore, all Markov equivalent DAGs have the same equivalence class.

Source code in causalinf/gcm.py
def equivalence_class(self):
    """
    Details
    -------
    The equivalence class of a DAG is a graph that replaces directed edges
    with undirected edges, except those in v-structures (triples X->Z<-Y
    where X and Y are not adjacent). Therefore, all Markov
    equivalent DAGs have the same equivalence class.
    """
    eq = dagitty.equivalenceClass(self.__dagitty__)
    dag, _ = self.__dagitty2inputs__(eq)
    res = self.__rebuild_graph__(dag)
    return res
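
The description above can be made concrete with a small pure-Python sketch. It is not how the method works (the method delegates to dagitty's equivalenceClass); `v_structures` and `pattern` are hypothetical helpers that only keep the arrows belonging to a v-structure. dagitty may additionally orient edges whose direction is forced by those v-structures, which this sketch does not attempt.

```python
from itertools import combinations


def v_structures(directed):
    """Return colliders (x, z, y) with x -> z <- y and x, y not adjacent."""
    adjacent = {frozenset(e) for e in directed}
    parents = {}
    for a, b in directed:
        parents.setdefault(b, set()).add(a)
    found = set()
    for z, pa in parents.items():
        for x, y in combinations(sorted(pa), 2):
            if frozenset((x, y)) not in adjacent:
                found.add((x, z, y))
    return found


def pattern(directed):
    """Keep arrows that take part in a v-structure; make every other edge undirected."""
    keep = set()
    for x, z, y in v_structures(directed):
        keep.update({(x, z), (y, z)})
    undirected = {frozenset(e) for e in directed if e not in keep}
    return keep, undirected


# W -> X -> Z <- Y: the collider at Z is kept, W - X loses its direction
kept, undirected = pattern([("W", "X"), ("X", "Z"), ("Y", "Z")])
print(sorted(kept))                           # [('X', 'Z'), ('Y', 'Z')]
print(sorted(sorted(e) for e in undirected))  # [['W', 'X']]
```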

causalinf.gcm.DAG.get_identified(by='parameter', include_all=False)

by : str ‘parameter’ or ‘strategy’

Source code in causalinf/gcm.py
def get_identified(self, by='parameter', include_all=False):
    """
    by : str
       'parameter' or 'strategy'
    """
    if not self.__identification__:
        self.identification_analysis()
    res = self.__identification__.get_identified(by=by, include_all=include_all)
    return res

causalinf.gcm.DAG.identification_analysis(exposure=None, outcome=None, conditional=None, causal_probability='maybe', iv='maybe', verbose=True)

causal_probability: str If ‘always’, always compute it; if ‘maybe’, compute it only when there is no identification by adjustment for the total effect (total_effect_adj_set)

conditional: list or str Variable name, or list of variable names, to condition the causal effect on

Source code in causalinf/gcm.py
def identification_analysis(self, exposure=None, outcome=None,
                            conditional = None,
                            causal_probability='maybe',
                            iv='maybe',
                            verbose=True
                            ):
    """
    causal_probability: str
        If 'always', always compute it; if 'maybe', compute it
        only when there is no identification by adjustment for
        the total effect (total_effect_adj_set)

    conditional: list or str
        Variable name, or list of variable names, to condition the causal effect on
    """
    assert not outcome or isinstance(outcome, str), 'Outcome must be a string.'
    assert not exposure or (isinstance(exposure, str) or isinstance(exposure, list)), 'Exposure must be a string or list.'

    assert outcome or self.outcome, "No outcome found."
    assert exposure or self.exposure, "No exposure found."

    exposure = exposure or self.exposure
    outcome = outcome or self.outcome[0]
    conditional = [conditional] if isinstance(conditional, str) else conditional

    assert exposure is not None, "Exposure must be provided."
    assert outcome is not None, "Outcome must be provided."

    self.__identification__ = identification(G=self,
                                             exposure = exposure,
                                             outcome = outcome,
                                             conditional = conditional,
                                             causal_probability = causal_probability,
                                             iv = iv,
                                             verbose=verbose)
    if verbose:
        self.print('identification')

    return None

causalinf.gcm.DAG.local_independencies(data=None, alpha=0.05, include_sep_cols=False)

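Return the conditional independencies implied by the graph, using dagitty (via R). If data is available, each implied independence is also tested with dagitty's localTests.

Parameters: data: tibble data frame from tidypolars4sci; if None, self.data is used. alpha: level used to convert the confidence interval bounds into a standard error, se = (hi - lo) / (2 * z), with z the 1 - alpha/2 normal quantile. include_sep_cols: if True, also return the columns var1, var2, and cond.

Returns: tibble dataframe from tidypolars4sci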

Source code in causalinf/gcm.py
def local_independencies(self, data=None, alpha=0.05, include_sep_cols=False):
    """
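    Return the conditional independencies implied by the graph, using dagitty (via R).
    If data is available, each implied independence is also tested with
    dagitty's localTests.

    Parameters:
        data: tibble data frame from tidypolars4sci. If None, self.data is used.
        alpha: level used to convert the confidence interval bounds into a
            standard error, se = (hi - lo) / (2 * z), with z the 1 - alpha/2
            normal quantile.
        include_sep_cols: if True, also return the columns var1, var2, and cond.

    Returns:
        tibble dataframe from tidypolars4sci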
    """
    if data is None:
        data = self.data
    # compute
    if data is None:
        inds = dagitty.impliedConditionalIndependencies(self.__dagitty__)
        res = tp.tibble()
        for ind in inds:
            y = ind[0][0]
            x = ind[1][0]
            z = ind[2]
            term = f"{y} _||_ {x}"
            term = f"{term} | {', '.join(z)}" if z else term
            tmp = tp.tibble({'term': [term],
                             "var1": [y],
                             "var2": [x],
                             "cond": [z]})
            res = res.bind_rows(tmp)
        inds = res
    else:
        inds = dagitty.localTests(self.__dagitty__, data=convert().tp2tibble(data), abbreviate_names=False)
        z = dnorm.ppf(1-alpha/2)
        inds = convert().rtibble2tp(inds, rownames2col='term')\
                     .rename({'p.value':"pvalue",
                              '2.5%':'lo',
                              '97.5%':'hi',
                              })\
                     .mutate(se = ( tp.col('hi')-tp.col('lo') ) / (2*z) )
        if inds.nrow>0:
            inds = (
                inds
                .separate('term', into=['var1', 'var2_cond'], sep='_||_', remove=False)
                .separate('var2_cond', into=['var2', 'cond'], sep='|')
            )

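    # the estimate columns exist only when the independencies were tested against data
    vars = ['term'] if data is None else ['term', 'estimate', 'se', 'lo', 'hi', 'pvalue']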
    if include_sep_cols:
        vars += ['var1', 'var2', 'cond']
    inds = inds.select(vars)

    return inds
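The `se` column is not returned by the test itself; it is recovered from the confidence interval bounds as (hi - lo) / (2z), with z the normal quantile at 1 - alpha/2. Below is a minimal standalone sketch of that arithmetic, using only the standard library. The function name is an assumption made for illustration. Because the bounds are read from the `2.5%` and `97.5%` columns, the formula presumably matches only when `alpha=0.05`.

```python
from statistics import NormalDist


def se_from_interval(lo, hi, alpha=0.05):
    """Recover a standard error from a symmetric normal confidence interval."""
    # Two-sided critical value, e.g. about 1.96 for alpha = 0.05.
    z = NormalDist().inv_cdf(1 - alpha / 2)
    # The interval spans estimate +/- z * se, so its width is 2 * z * se.
    return (hi - lo) / (2 * z)


print(round(se_from_interval(-0.196, 0.196), 3))
```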

causalinf.gcm.DAG.observationally_equivalent(G)

Check if two DAGs are observationally equivalent by comparing their Markov equivalence classes. This applies to CBNs, or to SCMs when no functional form has been chosen for the structural equations. See details.

Details

Observational equivalence is related to Markov equivalence.

Two DAGs are Markov equivalent iff:
A. They have the same skeleton (same set of adjacencies, i.e. the same undirected edges).
B. They have the same set of v-structures (triples X -> Z <- Y where X and Y are not adjacent).

The equivalence class of a DAG is represented by a graph that replaces directed edges with undirected edges, except in v-structures. Therefore, all Markov equivalent DAGs have the same equivalence class.
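The two conditions can be checked directly on edge lists. The sketch below is a standalone illustration in plain Python, not the library's implementation; the function names are assumptions, and each DAG is given as a list of `(parent, child)` tuples with no isolated nodes.

```python
from itertools import combinations


def skeleton(edges):
    """Set of adjacencies, ignoring edge direction."""
    return {frozenset(edge) for edge in edges}


def v_structures(edges):
    """Triples (a, child, b) with a -> child <- b and a, b not adjacent."""
    skel = skeleton(edges)
    parents = {}
    for parent, child in edges:
        parents.setdefault(child, set()).add(parent)
    found = set()
    for child, pars in parents.items():
        for a, b in combinations(sorted(pars), 2):
            if frozenset((a, b)) not in skel:
                found.add((a, child, b))
    return found


def markov_equivalent(edges1, edges2):
    """Same skeleton and same v-structures."""
    return (skeleton(edges1) == skeleton(edges2)
            and v_structures(edges1) == v_structures(edges2))


chain = [("X", "Z"), ("Z", "Y")]
fork = [("Z", "X"), ("Z", "Y")]
collider = [("X", "Z"), ("Y", "Z")]
print(markov_equivalent(chain, fork))      # chain and fork share skeleton, no v-structures
print(markov_equivalent(chain, collider))  # the collider adds a v-structure
```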

For CBN:
- Two CBNs are observationally equivalent iff they are Markov equivalent.

For SCM: Without functional form assumptions, for observational equivalence:
- Necessary condition: both SCMs have the same set of conditional independencies.
- Sufficient condition: both SCMs are in the same Markov equivalence class (Pearl, 2009).
- In short, two SCMs are observationally equivalent iff their causal graphs belong to the same Markov equivalence class, i.e., they share the same skeleton and v-structures.

With functional form assumptions:
- Once you impose functional form restrictions on SCMs, such as linearity, Gaussian disturbances, or additive errors, observational equivalence can be strictly finer than Markov equivalence. That is, Markov equivalence is no longer a sufficient condition. Examples:

a. Linear Gaussian SEMs:
- All DAGs in the same equivalence class remain indistinguishable. Markov equivalence = observational equivalence. Reason: any covariance matrix that one DAG can generate can also be generated by another DAG in its equivalence class, via a suitable choice of parameters.
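For two variables, the linear Gaussian argument can be written out explicitly. The derivation below is a sketch that assumes zero-mean variables; the symbols b, c, and the variances are introduced here for illustration only.

```latex
\begin{align*}
\text{Model } X \to Y:\quad
  & X \sim N(0, \sigma_X^2), \quad Y = bX + e, \quad
    e \sim N(0, \sigma_e^2),\ e \perp X \\
  & \operatorname{Var}(Y) = b^2\sigma_X^2 + \sigma_e^2, \qquad
    \operatorname{Cov}(X, Y) = b\,\sigma_X^2 \\[4pt]
\text{Model } Y \to X:\quad
  & Y \sim N(0, \sigma_Y^2), \quad X = cY + e', \quad
    e' \sim N(0, \sigma_{e'}^2),\ e' \perp Y \\
  & \sigma_Y^2 = b^2\sigma_X^2 + \sigma_e^2, \qquad
    c = \frac{b\,\sigma_X^2}{b^2\sigma_X^2 + \sigma_e^2}, \qquad
    \sigma_{e'}^2 = \frac{\sigma_X^2\,\sigma_e^2}{b^2\sigma_X^2 + \sigma_e^2}
\end{align*}
```

With these parameter choices the reversed model gives Var(X) = c²σ_Y² + σ_e'² = σ_X² and Cov(X, Y) = cσ_Y² = bσ_X², so both orientations generate the same bivariate Gaussian distribution.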

b. Linear non-Gaussian models (LiNGAM):
- Orientations become testable because independent non-Gaussian noise 'pins down' which variable must be the parent, breaking Markov equivalence. Example: X -> Y and X <- Y are indistinguishable in the Gaussian case but identifiable in the non-Gaussian case.

c. Additive Noise Models (ANMs):
- If the true relation is Y = f(X) + e with independent noise e, then typically the 'wrong' orientation X = g(Y) + e' cannot hold with independent noise. So the direction becomes identifiable.

In summary, for general SCMs (no distributional restrictions), Markov equivalence does imply observational equivalence. But once you impose restrictions (linear, Gaussian, additive, etc.), observational equivalence can be strictly finer: some Markov-equivalent DAGs become distinguishable. The test of equivalence then depends on the functional form assumption adopted, so it is case-by-case.

References
  • Pearl, J. (2009). Causality: Models, Reasoning, and Inference (2nd ed.). Cambridge University Press.
Source code in causalinf/gcm.py
def observationally_equivalent(self, G):
    """
    Check if two DAGs are observationally equivalent by comparing their
    Markov equivalence classes. This applies to CBNs, or to SCMs when no
    functional form has been chosen for the structural equations.
    See details.

    Details
    -------
    Observational equivalence is related to Markov equivalence.

    Two DAGs are Markov equivalent iff:
    A. They have the same skeleton (same set of adjacencies, i.e. the same undirected edges).
    B. They have the same set of v-structures (triples X -> Z <- Y where
       X and Y are not adjacent).

    The equivalence class of a DAG is represented by a graph that replaces
    directed edges with undirected edges, except in v-structures. Therefore,
    all Markov equivalent DAGs have the same equivalence class.

    For CBN:
    - Two CBNs are observationally equivalent iff they are Markov equivalent.

    For SCM:
    Without functional form assumptions, for observational equivalence:
    - Necessary condition: both SCMs have the same set of conditional independencies.
    - Sufficient condition: both SCMs are in the same Markov equivalence class (Pearl, 2009).
    - In short, two SCMs are observationally equivalent iff their causal graphs belong
      to the same Markov equivalence class, i.e., they share the same skeleton and v-structures.

    With functional form assumptions:
    - Once you impose functional form restrictions on SCMs, such as linearity,
      Gaussian disturbances, or additive errors, observational equivalence
      can be strictly finer than Markov equivalence. That is, Markov equivalence
      is no longer a sufficient condition.
      Examples:
      a. Linear Gaussian SEMs:
         - All DAGs in the same equivalence class remain indistinguishable.
           Markov equivalence = observational equivalence.
           Reason: any covariance matrix that one DAG can generate can also
           be generated by another DAG in its equivalence class, via a suitable
           choice of parameters.

      b. Linear non-Gaussian models (LiNGAM):
         - Orientations become testable because independent non-Gaussian noise
           'pins down' which variable must be the parent, breaking Markov equivalence.
           Example: X -> Y and X <- Y are indistinguishable in the Gaussian case
           but identifiable in the non-Gaussian case.

      c. Additive Noise Models (ANMs):
         - If the true relation is Y = f(X) + e with independent noise e,
           then typically the 'wrong' orientation X = g(Y) + e'
           cannot hold with independent noise. So the direction becomes identifiable.

    In summary, for general SCMs (no distributional restrictions), Markov equivalence
    does imply observational equivalence. But once you impose restrictions
    (linear, Gaussian, additive, etc.), observational equivalence can be strictly finer:
    some Markov-equivalent DAGs become distinguishable. The test of equivalence then
    depends on the functional form assumption adopted, so it is case-by-case.

    References
    ----------
    - Pearl, J. (2009). Causality: Models, Reasoning, and Inference (2nd ed.). Cambridge University Press.
    """
    # check if same equivalence class
    G1_eq = self.equivalence_class()
    G2_eq = G.equivalence_class()
    diff = G1_eq.edge_differences(G2_eq)
    obs_eq = True
    for g, edges in diff.items():
        obs_eq &= all([len(e)==0 for e in edges.values()])
    return obs_eq 
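The final loop reduces a nested mapping of edge differences to one boolean: the graphs are equivalent only if every collection of differing edges is empty. Below is a standalone sketch of that reduction. The shape of `diff` (graph key, then edge type, then a collection of edges) is inferred from the loop above; the example dictionaries are hypothetical, not actual output of `edge_differences`.

```python
def no_edge_differences(diff):
    """True when every edge collection in the nested mapping is empty."""
    return all(len(edges) == 0
               for by_type in diff.values()
               for edges in by_type.values())


# Hypothetical outputs, for illustration only.
same = {"G1": {"directed": [], "undirected": []},
        "G2": {"directed": [], "undirected": []}}
different = {"G1": {"directed": [("X", "Z")], "undirected": []},
             "G2": {"directed": [], "undirected": [{"X", "Z"}]}}

print(no_edge_differences(same))
print(no_edge_differences(different))
```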

causalinf.gcm.DAG.plot(graph_style=None, nodes_label=None, nodes_position=None, estimates=None, node_subset=None, node_shape=None, node_size=None, node_color=None, node_border_color=None, node_border_style=None, node_border_width=None, node_latent_show=True, show_labels=True, use_labels=True, node_label_box=True, node_label_fontsize=None, node_label_fontweight='normal', node_label_adj_x=0, node_label_adj_y=0, node_label_box_style='square', node_label_box_margin=0.5, edge_subset=None, edge_color=None, edge_style=None, edge_arc=None, edge_linewidth=None, edge_head_size=None, edge_head_style=None, edge_margin_tail=None, edge_margin_head=None, edge_label=None, edge_label_color_background='white', edge_label_color_border='white', edge_label_size=None, edge_label_color=None, edge_label_alpha=None, edge_label_rotate=None, edge_label_position=None, edge_label_sig_level=0.05, edge_label_pvalue=None, edge_label_font_family=None, legend_show=True, legend_title='Nodes', legend_title_align='left', legend_title_weight='bold', legend_title_size=12, legend_omit_cases=['Observed'], legend_keys=None, legend_loc='best', legend_fontsize=10, legend_frame=False, legend_kws={}, title=None, title_loc='left', title_kws={}, figsize=[6, 4], usetex=True, ax=None, show_plot=None, *args, **kws)

Draw a custom DAG with support for:
- Latent variables
- Curved edges
- Colored and dotted arcs
- Optional arc representation for latent confounding
- Custom node labels

Parameters:
    estimates (obj): An LSEM object from the class causalinf.scm.estimate. If given, the estimates are used as edge labels.
    nodes_position (dict): Optional node positions for layout.
    nodes_label (dict): Optional dict mapping node names to display labels.
    show_labels (bool; Default=True): If False, omit the labels and names of the nodes altogether.
    use_labels (bool; Default=True): If True, use labels if provided; if False, always use the node name.
    node_label_adj_x (float or dict): Displaces the labels in the x direction. If dict, the keys should be the node labels or names, and the displacement is applied only to the nodes specified in the dict. If float, the same displacement is applied to all nodes.
    node_label_adj_y (float or dict): Same as node_label_adj_x, but for the y axis.
    graph_style (str): Specific styles for nodes and arrows:
        - 'default': nodes in circles with labels in their middle
        - 'rectangle': nodes in rectangles with labels in their middle
        - 'pearl': nodes as dots with labels next to them (use node_label_adj_x and node_label_adj_y to adjust the location of the labels)
      All features can be overridden by specifying the value of the corresponding plot parameters.
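The float-or-dict convention of `node_label_adj_x` and `node_label_adj_y` amounts to expanding the argument into one offset per node. Below is a minimal sketch of that expansion. `expand_label_adjustment` is a hypothetical helper, not the library's internal function, and it keys the dict by node name only.

```python
def expand_label_adjustment(adj, nodes):
    """Return a per-node offset from a scalar or a partial dict."""
    if isinstance(adj, dict):
        # Only the listed nodes are displaced; the rest stay in place.
        return {node: float(adj.get(node, 0.0)) for node in nodes}
    # A scalar applies the same displacement to every node.
    return {node: float(adj) for node in nodes}


nodes = ["X", "Y", "Z"]
print(expand_label_adjustment(0.1, nodes))
print(expand_label_adjustment({"Y": -0.2}, nodes))
```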

Source code in causalinf/gcm.py
def plot(self,
         # nodes
         graph_style = None,
         nodes_label=None,
         nodes_position=None,
         estimates=None,
         # node
         node_subset=None,
         node_shape=None,
         node_size = None,
         node_color = None,
         node_border_color=None,
         node_border_style=None,
         node_border_width=None,
         node_latent_show=True,
         # node label
         show_labels = True,
         use_labels = True,
         node_label_box=True,
         node_label_fontsize=None,
         node_label_fontweight='normal',
         node_label_adj_x=0,
         node_label_adj_y=0,
         node_label_box_style="square",
         node_label_box_margin=.5,
         # edges
         edge_subset=None,
         edge_color=None,
         edge_style=None,
         edge_arc = None,
         edge_linewidth = None,
         edge_head_size = None,
         edge_head_style = None,
         edge_margin_tail=None,
         edge_margin_head=None,
         # edges labels
         edge_label=None,
         edge_label_color_background='white',
         edge_label_color_border='white',
         edge_label_size=None,
         edge_label_color=None,
         edge_label_alpha=None,
         edge_label_rotate=None,
         edge_label_position=None,
         edge_label_sig_level=0.05,
         edge_label_pvalue=None,
         edge_label_font_family = None,
         # legend
         legend_show=True,
         legend_title='Nodes',
         legend_title_align='left',
         legend_title_weight='bold',
         legend_title_size=12,
         legend_omit_cases=['Observed'],
         legend_keys=None,
         legend_loc='best',
         legend_fontsize=10,
         legend_frame=False,
         legend_kws={},
         #
         title = None,
         title_loc = 'left',
         title_kws = {},
         # 
         figsize = [6, 4],
         usetex = True,
         ax=None,
         show_plot=None,
         *args,
         **kws
         ):
    """
    Draw a custom DAG with support for:
      - Latent variables
      - Curved edges
      - Colored and dotted arcs
      - Optional arc representation for latent confounding
      - Custom node labels

    Parameters:
        estimates (obj): An LSEM object from the class causalinf.scm.estimate. If given,
            the estimates are used as edge labels.
        nodes_position (dict): Optional node positions for layout.
        nodes_label (dict): Optional dict mapping node names to display labels.
        show_labels (bool; Default=True): If False, omit the labels and names of the nodes altogether.
        use_labels (bool; Default=True): If True, use labels if provided; if False, always
            use the node name.
        node_label_adj_x (float or dict): displaces the labels in the x direction. If dict, the keys
            should be the node labels or names, and the displacement is applied only to the nodes
            specified in the dict. If float, the same displacement is applied to all nodes.
        node_label_adj_y (float or dict): same as node_label_adj_x, but for the y axis
        graph_style (str): specific styles for nodes and arrows
              - 'default': nodes in circles with labels in their middle
              - 'rectangle': nodes in rectangles with labels in their middle
              - 'pearl': nodes as dots with labels next to them (use node_label_adj_x and node_label_adj_y
                         to adjust the location of the labels)
              All features can be overridden by specifying the value of the corresponding plot parameters.
    """
    assert estimates is None or isinstance(estimates, estimate), (
        "'estimates' must be either None or an object of causalinf.scm.estimate ")

    default_usetex = plt.rcParams["text.usetex"] 
    plt.rcParams["text.usetex"] = usetex
    plt.rcParams['text.latex.preamble'] = r'\usepackage{amsmath, amssymb, siunitx, bm}'
    # fall back to the global option only when show_plot was not given
    show_plot = show_plot if show_plot is not None else get_options('show_plot')

    # collect arguments
    pars = dict(locals())      # {'node_position':..., 'arg2':..., 'args':(...), 'kws':{...}}
    args = pars.pop('args') # extra positional
    kws  = pars.pop('kws')  # extra keyword

    # use estimates as labels
    if estimates is not None:
        edge_label, edge_label_pvalue = self.__plot_collect_labels_estimate__(estimates)

    # figure 
    # ------
    G_draw = self.__plot_create_nx__()
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize, tight_layout=True)
    plt.sca(ax)

    # styles
    # ------
    graph_style = graph_style or get_options('graph_style')
    nodes_style, labels_style, edges_style = self.__plot_get_style__(graph_style)

    # nodes 
    # -----
    node_subset    = self.__plot_nodes_subset__(node_subset, node_latent_show)
    nodes_position = self.__plot_nodes_positions__(G_draw, nodes_position)
    for role, nodes in node_subset.items():
        fig_nodes = nx.draw_networkx_nodes(
            G_draw,
            nodes_position,
            nodelist=nodes,
            ax=ax,
            # 
            node_size  = self.__plot_collect_aes__(role, node_size, nodes_style[role]['node_size']),
            node_color = self.__plot_collect_aes__(role, node_color, nodes_style[role]['node_color']),
            node_shape = self.__plot_collect_aes__(role, node_shape, nodes_style[role]['node_shape']),
            linewidths = self.__plot_collect_aes__(role, node_border_width,
                                                   nodes_style[role]['node_border_width']),
            edgecolors = self.__plot_collect_aes__(role, node_border_color,
                                                   nodes_style[role]['node_border_color']),
            alpha      = None,
            cmap       = None,
            vmin       = None,
            vmax       = None,
            label      = None,
            margins    = None, 
            hide_ticks = True
        )
        fig_nodes.set_linestyle(self.__plot_collect_aes__(role, node_border_style,
                                                          nodes_style[role]['node_border_style']))

    # nodes labels 
    # ------------
    if show_labels:
        nodes = set(itertools.chain.from_iterable(node_subset.values()))
        nodes_label = self.nodes_label | (nodes_label or {})
        adj_x = self.__plot_label_adj__(node_label_adj_x, nodes_label)
        adj_y = self.__plot_label_adj__(node_label_adj_y, nodes_label)
        for node in nodes:
            label = nodes_label.get(node, node) if use_labels else node
            role  = self.nodes_info[node]['role']
            x, y  = nodes_position[node] if nodes_position and nodes_position[node] else\
                self.nodes_info[node]['position'] 

            bbox = None
            if node_label_box and graph_style=='rectangle':
                bbox = {
                    "boxstyle": f"{node_label_box_style},pad={node_label_box_margin}",
                    "fc": self.__plot_collect_aes__(role, node_color, nodes_style[role]['node_color']),
                    "ec": self.__plot_collect_aes__(role, node_border_color,
                                                    nodes_style[role]['node_border_color']),
                    "lw": self.__plot_collect_aes__(role, node_border_width,
                                                   nodes_style[role]['node_border_width']),
                    "linestyle": self.__plot_collect_aes__(role, node_border_style,
                                                           nodes_style[role]['node_border_style']),
                    "alpha": 1}

            weight = self.__plot_collect_aes__(role, node_label_fontweight,
                                               labels_style[role]['node_label_fontweight'])
            label = f"\\textbf{{{label}}}" if weight == 'bold' else label
            plt.text(x + adj_x[node],
                     y + adj_y[node],
                     label,
                     fontweight = weight,
                     fontsize   = self.__plot_collect_aes__(role, node_label_fontsize,
                                                            labels_style[role]['node_label_fontsize']),
                     ha = 'center',
                     va = 'center',
                     bbox = bbox)

    # edges and edges labels
    # ----------------------
    nodes = set(itertools.chain.from_iterable(node_subset.values()))
    for edge_type in ['directed', 'bidirected', 'undirected']:
        style = self.__plot_collect_aes__(edge_type, edge_style, edges_style['edge_style'][edge_type])
        color = self.__plot_collect_aes__(edge_type, edge_color, edges_style['edge_color'][edge_type])
        arc   = self.__plot_collect_aes__(edge_type, edge_arc, edges_style['edge_arc'][edge_type])
        width = self.__plot_collect_aes__(edge_type, edge_linewidth, edges_style['edge_linewidth'][edge_type])
        arrow_head_size = self.__plot_collect_aes__(edge_type, edge_head_size, edges_style['edge_head_size'][edge_type])
        arrow_head_style = self.__plot_collect_aes__(edge_type, edge_head_style, edges_style['edge_head_style'][edge_type])
        edge_margin_tail = self.__plot_edge_margin__(edge_margin_tail, edges_style["edge_margin_tail"][edge_type])
        edge_margin_head = self.__plot_edge_margin__(edge_margin_head, edges_style["edge_margin_head"][edge_type])

        for edge in self.__getattribute__(edge_type):
            edge = tuple(edge)
            if edge_type!='bidirected':
                u, v = edge
            else:
                u, v = edge[0][0], edge[0][1]

            # collect edges to show if edge_subset 
            show_edge = True
            if edge_subset:
                e = set(edge) if edge_type=='undirected' else edge
                show_edge = self.edge_exist(e, edge_subset.get(edge_type, []))

            if u in nodes and v in nodes and show_edge:
                # edge
                nx.draw_networkx_edges(
                    G_draw,
                    nodes_position,
                    edgelist            = [(u, v)],
                    style               = style,
                    edge_color          = color,
                    connectionstyle     = f"arc3,rad={arc}",
                    arrows              = True,
                    arrowstyle          = arrow_head_style,
                    arrowsize           = arrow_head_size,
                    min_source_margin   = edge_margin_tail.get(edge, 0),
                    min_target_margin   = edge_margin_head.get(edge, 0),
                    width               = width,
                    ax=ax)

                # edge label
                edge_label = edge_label or self.edge_label
                label = edge_label.get(edge, '')
                rotate = edge_label_rotate if edge_label_rotate is not None else True # must keep "is not None" here
                nx.draw_networkx_edge_labels(
                    G_draw,
                    pos             = nodes_position,
                    connectionstyle = f"arc3,rad={arc}",
                    edge_labels     = {(u, v): label},
                    bbox=dict(facecolor=edge_label_color_background, edgecolor=edge_label_color_border),
                    # 
                    alpha      = self.__plot_edge_label_feature__('alpha', edge, edge_label_alpha, None, edge_label_sig_level,
                                                                  edge_label_pvalue=edge_label_pvalue),
                    font_size  = self.__plot_edge_label_feature__('size' , edge, edge_label_size, 15),
                    font_color = self.__plot_edge_label_feature__('color', edge, edge_label_color, label=label),
                    rotate     = self.__plot_edge_label_feature__('rotate', edge, edge_label_rotate, default=rotate),
                    label_pos  = self.__plot_edge_label_feature__('position', edge, edge_label_position, .5),
                    font_family=edge_label_font_family,
                    ax         = ax
                )

    # legend 
    # ------
    if legend_show:
        keys = []
        for role, _ in node_subset.items():
            if role not in legend_omit_cases:
                if role=='Latent' and node_latent_show:
                    marker = ''
                    linecolor = self.__plot_collect_aes__(role, node_border_color,
                                                          nodes_style[role]['node_border_color']) 
                else:
                    marker = 'o'
                    linecolor='white'
                keys += [
                    Line2D(
                        [0], [0],
                        marker=marker,
                        color=linecolor,
                        label=role,
                        markersize=10,
                        markeredgecolor=self.__plot_collect_aes__(role, node_border_color,
                                                                  nodes_style[role]['node_border_color']),
                        markerfacecolor=self.__plot_collect_aes__(role, node_color,
                                                                  nodes_style[role]['node_color']),
                        linestyle=self.__plot_collect_aes__(role, node_border_style,
                                                            nodes_style[role]['node_border_style'])
                    )
                ]
            if keys: 
                legend = plt.legend(handles        = keys,
                                    title          = legend_title,
                                    title_fontsize = legend_title_size,
                                    alignment      = legend_title_align,
                                    # title_weight   = legend_title_weight,
                                    loc            = legend_loc,
                                    fontsize       = legend_fontsize,
                                    frameon        = legend_frame,
                                    **legend_kws
                                    )
                if legend_title_weight=='bold' and legend_title:
                    legend.set_title(title=f'\\textbf{{{legend_title}}}', prop={'weight': 'bold'})

    # title 
    # -----
    if title:
        plt.title(label=title, loc=title_loc, **title_kws)

    plt.axis("off")
    plt.tight_layout()
    if show_plot:
        plt.show()
    plt.rcParams["text.usetex"] = default_usetex

    return plt, ax

causalinf.gcm.DAG.plot_identification(content='default', effect='total', show_np=True, show_linear=True, show_do=True, kws_graph={}, kws_identification={}, kws_detailed=None, figsize=None, ratio=None, ncols=None, nrows=None, title_dag=None, title_info=None, txt_line_height=0.55, *args, **kws)

txt_line_height : float
    Height of the lines for the text. Not used if figsize is set.
kws_detailed : dict
    Example: {'parameter': 'ACE', 'strategy': 'SoO'}
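When `kws_detailed` omits a key, the method defaults the strategy to `'SoO'` and the parameter to the first one available for that strategy. The sketch below reproduces that resolution on plain dictionaries. The function name and the `identification` mapping are hypothetical stand-ins for the object the method reads from.

```python
def resolve_detailed(kws_detailed, identification):
    """Fill in 'strategy' and 'parameter' defaults without mutating the input."""
    resolved = dict(kws_detailed or {})
    strategy = resolved.get("strategy", "SoO")
    # Default parameter: the first one listed under the chosen strategy.
    parameter = resolved.get("parameter") or next(iter(identification[strategy]))
    resolved.update(strategy=strategy, parameter=parameter)
    return resolved


# Hypothetical identification results: strategy -> parameter -> details.
identification = {"SoO": {"ACE": "...", "ATT": "..."}, "IV": {"LATE": "..."}}
print(resolve_detailed(None, identification))
print(resolve_detailed({"strategy": "IV"}, identification))
```

Copying the input first means the caller's dictionary is left unchanged, a small design choice of this sketch.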

Source code in causalinf/gcm.py
def plot_identification(self,
                        content='default', # detailed, default
                        effect='total', #total, direct, or do, only if if_info=full
                        show_np = True,
                        show_linear = True,
                        show_do = True,
                        kws_graph={},
                        kws_identification={},
                        kws_detailed = None,
                        figsize = None,
                        ratio   = None,
                        ncols   = None,
                        nrows   = None,
                        title_dag = None,
                        title_info = None,
                        txt_line_height=.55,
                        *args,
                        **kws
                        ):
    """
    txt_line_height : float
        Height of the lines for the text. Not used if figsize is set.
    kws_detailed : dict
        Example:
        {'parameter': 'ACE',
         'strategy': 'SoO'}
    """
    roles = ['Exposure', 'Outcome', 'Latent', 'Observed',
             'exposure', 'outcome', 'latent', 'observed']
    for role in roles:
        assert not kws_graph.get(role, None) and not kws_identification.get(role, None), (
            f"Setting node role ({role}) not allowed in the plot kws. "+
            f"To set the node role, create a new DAG or use set_node_role before plotting.")

    if not self.__identification__ or kws_identification:
        self.identification_analysis(**kws_identification, verbose=False)

    # defaults for kws_detailed
    kws_detailed = kws_detailed or {}
    strategy = kws_detailed.get('strategy', 'SoO')
    parameter = kws_detailed.get('parameter', None)
    if not parameter:
        parameter = next(iter(self.__identification__.identification[strategy]))
    kws_detailed['strategy'] = strategy
    kws_detailed['parameter'] = parameter

    return self.__identification__.plot(G=self,
                                        info=content,
                                        effect=effect,
                                        show_np = show_np,
                                        show_linear = show_linear,
                                        show_do = show_do,
                                        figsize=figsize,
                                        ratio=ratio,
                                        ncols=ncols,
                                        nrows=nrows,
                                        kws_graph=kws_graph,
                                        kws_detailed = kws_detailed,
                                        txt_line_height=txt_line_height,
                                        title_dag = title_dag,
                                        title_info = title_info,
                                        *args,
                                        **kws
                                        )

causalinf.gcm.DAG.print(what='graph', identification=dict(content='default', style='text', strategy='all', parameter='ACE', omit_DAG=True, print_assumptions=None, print_assumptions_verbose=None))

what : str
    What to print: 'graph', 'DAG', 'dag', or 'identification'.

identification : dict
    Options to print identification results:
    - content
    - style
    - strategy
    - parameter
    - omit_DAG
    - print_assumptions
    - print_assumptions_verbose
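Options left as `None` in the `identification` dict are filled from the package-wide options before printing. The sketch below shows that fill-in step on plain dictionaries. `fill_missing` and the `defaults` mapping are assumptions made for illustration; in the method the defaults come from `get_options()`.

```python
def fill_missing(options, defaults, keys):
    """Replace None (or absent) entries for `keys` with values from `defaults`."""
    filled = dict(options)
    for key in keys:
        if filled.get(key) is None:
            filled[key] = defaults[key]
    return filled


defaults = {"print_assumptions": True, "print_assumptions_verbose": False}
options = {"content": "default", "print_assumptions": None}
print(fill_missing(options, defaults, ["print_assumptions", "print_assumptions_verbose"]))
```

Because the check is `is None` rather than truthiness, an explicit `False` supplied by the caller is kept.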

Source code in causalinf/gcm.py
def print(self,
          what = 'graph',
          identification = dict(
              content='default',
              style='text',
              strategy = 'all',
              parameter = 'ACE',
              omit_DAG=True,
              print_assumptions=None,
              print_assumptions_verbose=None
          )
          ):
    """
    what : str
        What to print
        'graph', 'DAG', 'dag', 'identification'

    identification : dict
        Options to print identification results
         - content
         - style
         - strategy
         - parameter
         - omit_DAG
         - print_assumptions
         - print_assumptions_verbose

    """
    if what in ['graph', 'DAG', 'dag']:
        print(self)
    if what=='identification':
        ops = identification.copy()
        # defaults
        pars = ["print_assumptions", "print_assumptions_verbose"]
        for par in pars:
            if ops.get(par, None) is None:
                ops[par] = get_options()[par]

        if not self.__identification__:
            self.identification_analysis()
        self.__identification__.print(**identification)
        self.__identification__.__assumptions_print__(category='identification', **ops)
    return None

causalinf.gcm.examples

Registry of example DAGs.

Usage:

>>> from causalinf.scm import examples
>>> examples()                                # -> ['Not identifiable', ...]
>>> dag = examples(which='Not identifiable')  # -> DAG instance
>>> examples('Not identifiable', as_text=True)
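An unknown example name raises an error that suggests close matches. The lookup-with-suggestion pattern can be sketched with the standard library's `difflib`. The registry below is a hypothetical stand-in with placeholder values, not the actual example graphs.

```python
import difflib

# Placeholder registry: the names come from the examples list, the values are dummies.
REGISTRY = {
    "Not identifiable": "<graph definition>",
    "One confounder": "<graph definition>",
    "Front-door": "<graph definition>",
}


def get_example(which, registry):
    """Return the registry entry, or raise ValueError with close-match hints."""
    try:
        return registry[which]
    except KeyError:
        close = difflib.get_close_matches(which, list(registry), n=3, cutoff=0.4)
        hint = f" Did you mean: {', '.join(close)}?" if close else ""
        raise ValueError(f"Unknown example '{which}'.{hint}") from None


try:
    get_example("Front door", REGISTRY)
except ValueError as err:
    print(err)
```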

Source code in causalinf/gcm.py
class examples:
    """
    Registry of example DAGs.

    Usage:
    -------
    >>> from causalinf.scm import examples
    >>> examples()                                # -> ['Not identifiable', ...]
    >>> dag = examples(which='Not identifiable')  # -> DAG instance
    >>> examples('Not identifiable', as_text=True)
    """

    def __new__(cls, which=None, print_DAG=False, *args, **kws):
        if not which:
            examples._print_examples(print_DAG=print_DAG)
            dag = None
        else:
            try:
                dag = examples._get_examples(which, *args, **kws)
            except KeyError:
                # friendly suggestion for typos
                suggestion = difflib.get_close_matches(which, examples._get_examples().keys(), n=3, cutoff=0.4)
                hint = f" Did you mean: {', '.join(suggestion)}?" if suggestion else ""
                raise ValueError(f"Unknown example '{which}'.{hint}")
        return DAG(**dag) if dag else None

    def _get_examples(which=None, *args, **kws):
        all_examples = {
            "Not identifiable" : examples._example_not_identifiable(*args, **kws),
            "One confounder"  : examples._example_one_confounder(*args, **kws),
            "Two confounders" : examples._example_two_confounder(*args, **kws),
            "Front-door"       : examples._example_front_door(*args, **kws),
            "IV with 1 instrument"  : examples._example_iv_1_instrument(*args, **kws),
            "IV with 3 instruments"  : examples._example_iv_3_instruments(*args, **kws),
            "SoO, IV, and do identified with 1 confounder": examples._example_soo_iv_do_one_counfounder(*args,
                                                                                                        **kws),
            "Mediation: 2 sequential 1 confounder" : examples._example_mediation_2_sequential_1_confounder(*args,
                                                                                                           **kws),
            # "Back-door": self._back_door(),
            # Pearl's book
            "Pearl Example 1.1 (a)"  : examples._example_pearl_fig_1_1_a(*args, **kws),
            "Pearl Example 1.1 (b)"  : examples._example_pearl_fig_1_1_b(*args, **kws),
            "Pearl Example 1.2"  : examples._example_pearl_fig_1_2(*args, **kws),
            "Pearl Example 1.3 (a)"  : examples._example_pearl_fig_1_3_a(*args, **kws),
            "Pearl Example 1.3 (b)"  : examples._example_pearl_fig_1_3_b(*args, **kws),
            "Pearl Example 3.1"  : examples._example_pearl_fig_3_1(*args, **kws),
            "Pearl Example 3.4"  : examples._example_pearl_fig_3_4(*args, **kws),
            "Pearl Example 3.5"  : examples._example_pearl_fig_3_5(*args, **kws),
        }
        res = all_examples[which] if which else all_examples
        return res

    def _print_examples(print_DAG):
        print(dedent("""
        List of available examples:
        --------------------------\
        """))
        for i, (name, example) in enumerate(examples._get_examples().items()):
            print(f"{i+1}. {name}")
            if print_DAG:
                print(DAG(**example))
        print(f"\nUsage: examples(which='<example name>')"+
              f"\nExample: G = examples(which='{name}')")
        if not print_DAG:
            print("Note: To print the associated DAG of each example, use examples(print_DAG=True)")
        return None

    def _example_not_identifiable(*args, **kws):
        dag = """
        D  -> Y
        D <-> Y
        Z -> {D, Y}
        """
        pos = {"D": (0, 0), "Y": (1, 0), "Z": (0.5, 1)}
        roles = {"Exposure": "D", "Outcome": "Y"}
        labels = None
        return dict(graph=dag, nodes_role=roles, nodes_position=pos, nodes_label=labels)

    def _example_one_confounder(*args, **kws):
        dag = """
        D  -> Y
        Z1 -> {D, Y}
        """
        pos = {"D": (0, 0), "Y": (1, 0), "Z1": (0.5, 1)}
        roles = {"Exposure": "D", "Outcome": "Y"}
        labels = {'Z1':"$Z_1$"}
        return dict(graph=dag, nodes_role=roles, nodes_position=pos, nodes_label=labels)

    def _example_two_confounder(*args, **kws):
        dag = """
        D  -> Y
        Z1 -> {D, Y}
        Z2 -> {D, Y}
        """
        pos = {"D": (0, 0), "Y": (1, 0), "Z1": (0.5, 1), "Z2": (0.5, -1)}
        roles = {"Exposure": "D", "Outcome": "Y"}
        labels = {'Z2':"$Z_2$", 'Z1':"$Z_1$"}
        return dict(graph=dag, nodes_role=roles, nodes_position=pos, nodes_label=labels)

    def _example_front_door(*args, **kws):
        dag  = """
        U -> {D, Y}
        D  -> Z -> Y
        Z2 -> {D, Y}
        """
        pos = {'D': (0 , 0),
               'Z': (.5, 0),
               'Y': (1 , 0),
               'U': (.5, 1),
               'Z2': (.5, -1),
               }
        roles = {'Exposure': "D",
                 'Outcome' : "Y",
                 "Latent"  : "U"
                 }
        labels = {'Z2':"$Z_2$"}
        return dict(graph=dag, nodes_role=roles, nodes_position=pos, nodes_label=labels)

    def _example_iv_1_instrument(*args, **kws):
        """
        Source:
        - Pearl, J. (2009). Causality: Models, Reasoning, and Inference. Cambridge University Press.
        """
        dag  = """
        X <-> Y
        Z -> X -> Y
        """
        pos = {'Z': ( 0, 0),
               'X': ( .5, -1),
               "Y": ( 1,-2)}
        roles = {'Exposure': "X",
                 'Outcome' : "Y"}
        edge_labels = {('Z', 'X'): "$\\beta$",
                       ('X', 'Y'): "$\\alpha$"}
        labels = None
        return dict(graph=dag, nodes_role=roles, nodes_position=pos, nodes_label=labels,
                    edge_label=edge_labels)

    def _example_iv_3_instruments(*args, **kws):
        """
        Source:
        - Pearl, J. (2009). Causality: Models, Reasoning, and Inference. Cambridge University Press.
        """
        dag  = """
        D <-> Y
        Z1 -> D -> Y
        D <- X1 -> Y
        Z1<- Z2 -> Y
        Z1<- Z3 -> D
        Z1-> Z4 <- X2 -> Y
        """
        pos = {'D':  ( 0, 0),
               "Y":  ( 1, 0),
               'Z1': (-1, 0),
               "Z2": (0,-.5),
               "Z3": (-.5, 1),
               "Z4": (-.5,-1),
               "X2": ( .5,-1),
               "X1": (.5, 1),
               }
        roles = {'Exposure': "D",
                 'Outcome' : "Y"}
        edge_labels = None
        labels = {'X1': '$X_1$',
                  'X2': '$X_2$',
                  'Z1': '$Z_1$',
                  'Z2': '$Z_2$',
                  'Z3': '$Z_3$',
                  'Z4': '$Z_4$',
                  }
        return dict(graph=dag, nodes_role=roles, nodes_position=pos, nodes_label=labels,
                    edge_label=edge_labels)

    def _example_soo_iv_do_one_counfounder(*args, **kws):
        dag  = """
        Z  -> D -> Y
        D <- X1 -> Y
        Z <- X2 -> Y
        """
        pos = {'D':  ( 0, 0),
               "Y":  ( 1, 0),
               'Z': (-1, 0),
               "X2": ( .5,-1),
               "X1": (.5, 1),
               }
        roles = {'Exposure': "D",
                 'Outcome' : "Y"}
        edge_labels = None
        labels = {'X1': '$X_1$',
                  'X2': '$X_2$',
                  }
        return dict(graph=dag, nodes_role=roles, nodes_position=pos, nodes_label=labels,
                    edge_label=edge_labels)

    def _example_mediation_2_sequential_1_confounder(*args, **kws):
        dag  = """
        D -> M1 -> M2 -> Y
        D -> Y
        D  <- Z -> Y
        """
        pos = {'D' : (0 , 0),
               'M1': (.5, 1),
               'M2': ( 1, 1),
               'Y' : (1.5 , 0),
               'Z': (.75, -1),
               }
        roles = {'Exposure': "D",
                 'Outcome' : "Y",
                 }
        labels = {"M1":'$M_1$',
                  "M2":'$M_2$',
                  }
        return dict(graph=dag, nodes_role=roles, nodes_position=pos, nodes_label=labels)

    def _example_pearl_fig_1_1_a(*args, **kws):
        """
        Source:
        - Pearl, J. (2009). Causality: Models, Reasoning, and Inference. Cambridge University Press.
        """
        dag  = """
        W -> Z -> Y
        Z <-> X -> Y
        """
        pos = {'Z': ( 0, 0),
               'X': ( 1, 0),
               "Y": ( .5,-1),
               'W': ( 0, 1),
               }
        roles = {'Exposure': "X",
                 'Outcome' : "Y"
                 }
        edge_labels = None
        labels = None
        return dict(graph=dag, nodes_role=roles, nodes_position=pos, nodes_label=labels,
                    edge_label=edge_labels)

    def _example_pearl_fig_1_1_b(*args, **kws):
        """
        Source:
        - Pearl, J. (2009). Causality: Models, Reasoning, and Inference. Cambridge University Press.
        """
        dag  = """
        Z -> {W, Z, Y}
        Y -> X
        """
        pos = {'Z': ( 0, 0),
               'X': ( 1, 0),
               "Y": ( .5,-1),
               'W': ( 0, 1),
               }
        roles = {'Exposure': "Z",
                 'Outcome' : "Y"
                 }
        edge_labels = None
        labels = None
        return dict(graph=dag, nodes_role=roles, nodes_position=pos, nodes_label=labels,
                    edge_label=edge_labels)

    def _example_pearl_fig_1_2(*args, **kws):
        """
        Source:
        - Pearl, J. (2009). Causality: Models, Reasoning, and Inference. Cambridge University Press.
        """
        dag  = """
        X1 -> {X2, X3} -> X4 -> X5
        """
        pos = {"X1": ( 0, 0),
               "X2": ( 1, -1),
               "X3": ( -1,-1),
               "X4": ( 0, -2),
               "X5": ( 0, -3),
               }
        roles = {'Exposure': "X1",
                 'Outcome' : "X5"
                 }
        edge_labels = None
        labels = {"X1" : 'X1 (Season)',
                  "X2" : "X2 (Rain)",
                  "X3" : "X3 (Sprinkler)",
                  "X4" : 'X4 (Wet)',
                  "X5" : 'X5 (Slippery)'
                  }
        return dict(graph=dag, nodes_role=roles, nodes_position=pos, nodes_label=labels,
                    edge_label=edge_labels)

    def _example_pearl_fig_1_3_a(*args, **kws):
        """
        Source:
        - Pearl, J. (2009). Causality: Models, Reasoning, and Inference. Cambridge University Press.
        """
        dag  = """
        X -> Z1 <- Z2 <- Z3 <- Y
        Z1 <-> Z3
        """
        pos = {"X":  (0 ,0),
               "Z1": (1 ,0),
               "Z2": (2 ,0),
               "Z3": (3 ,0),
               "Y" : (4 ,0),
               }
        roles = {'Exposure': "X",
                 'Outcome' : "Y"
                 }
        edge_labels = None
        labels = None
        return dict(graph=dag, nodes_role=roles, nodes_position=pos, nodes_label=labels,
                    edge_label=edge_labels)

    def _example_pearl_fig_1_3_b(*args, **kws):
        """
        Source:
        - Pearl, J. (2009). Causality: Models, Reasoning, and Inference. Cambridge University Press.
        """
        dag  = """
        X -> Z2 -> Z1 -> X
        Y -> Z2
        """
        pos = {"X":  (0 ,0),
               "Z1": (1 ,1),
               "Z2": (2 ,0),
               "Y" : (3 ,0),
               }
        roles = {'Exposure': "X",
                 'Outcome' : "Y"
                 }
        edge_labels = None
        labels = None
        return dict(graph=dag, nodes_role=roles, nodes_position=pos, nodes_label=labels,
                    edge_label=edge_labels)

    def _example_pearl_fig_3_1(*args, **kws):
        """
        Source:
        - Pearl, J. (2009). Causality: Models, Reasoning, and Inference. Cambridge University Press.
        """
        dag  = """
        X -> {Z2, Y}
        Z2 -> {Z3, Y}
        Z3 -> Y
        Z1 -> Z2
        B -> Z3
        Z0 -> {X, Z1, B}
        """
        pos = {"X":  (0 ,0),
               "Z0": (1 ,1),
               "Z1": (1 ,.5),
               "Z2": (1 ,0),
               "Z3": (2 ,0),
               "B":  (1.5 ,.5),
               "Y" : (1 ,-.5),
               }
        roles = {'Exposure': "X",
                 'Outcome' : "Y",
                 'Latent'  : ['Z0', 'B']
                 }
        edge_labels = None
        labels = {"Z0": "$Z_0$",
                  "Z1": "$Z_1$",
                  "Z2": "$Z_2$",
                  "Z3": "$Z_3$",
                  }
        return dict(graph=dag, nodes_role=roles, nodes_position=pos, nodes_label=labels,
                    edge_label=edge_labels)

    def _example_pearl_fig_3_4(*args, **kws):
        """
        Source:
        - Pearl, J. (2009). Causality: Models, Reasoning, and Inference. Cambridge University Press.
        """
        dag  = """
        X1 -> {X3, X4}
        X2 -> {X4, X5}
        X3 -> Xi
        X4 -> {Xi, Xj}
        X5 -> Xj
        X6 -> Xj
        Xi -> X6
        """
        pos = {"Xi":  (0, 0),
               'Xj':  (2, 0),
               "X1":  (0, 2),
               "X2":  (2, 2),
               "X3":  (0, 1),
               "X4":  (1, 1),
               "X5":  (2, 1),
               "X6":  (1, 0),
               }
        roles = {'Exposure': "Xi",
                 'Outcome' : "Xj",
                 }
        edge_labels = None
        labels = {"Xi": "$X_i$",
                  "Xj": "$X_j$",
                  "X1": "$X_1$",
                  "X2": "$X_2$",
                  "X3": "$X_3$",
                  "X4": "$X_4$",
                  "X5": "$X_5$",
                  "X6": "$X_6$",
                  }
        return dict(graph=dag, nodes_role=roles, nodes_position=pos, nodes_label=labels,
                    edge_label=edge_labels)

    def _example_pearl_fig_3_5(*args, **kws):
        dag  = """
        U -> {X, Y}
        X  -> Z -> Y
        """
        pos = {'X': (0 , 0),
               'Z': (.5, 0),
               'Y': (1 , 0),
               'U': (.5, 1),
               }
        roles = {'Exposure': "X",
                 'Outcome' : "Y",
                 "Latent"  : "U"
                 }
        labels = None
        return dict(graph=dag, nodes_role=roles, nodes_position=pos, nodes_label=labels)
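The lookup in `__new__` above pairs a name-keyed registry with `difflib.get_close_matches` so that a mistyped example name yields a suggestion. A minimal standalone sketch of that pattern follows; `REGISTRY` and `lookup` are hypothetical names used only for illustration, and the entries are plain strings rather than the keyword dictionaries the class builds.

```python
import difflib

# Hypothetical registry: example names mapped to graph strings.
REGISTRY = {
    "One confounder": "D -> Y\nZ1 -> {D, Y}",
    "Front-door": "U -> {D, Y}\nD -> Z -> Y",
    "IV with 1 instrument": "X <-> Y\nZ -> X -> Y",
}


def lookup(name, registry=REGISTRY):
    """Return the entry for `name`; on a miss, raise ValueError with close matches."""
    try:
        return registry[name]
    except KeyError:
        # Same arguments as the source: up to 3 candidates, similarity >= 0.4
        close = difflib.get_close_matches(name, list(registry), n=3, cutoff=0.4)
        hint = f" Did you mean: {', '.join(close)}?" if close else ""
        raise ValueError(f"Unknown example '{name}'.{hint}") from None
```

Calling `lookup("Front door")` in this sketch raises a `ValueError` whose message suggests `Front-door`.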

SCM

causalinf.scm.estimate

Estimate a structural causal model.

formula : str or None (optional)
    A structural equation model. Example:
        y ~ d + x1 + z1
        x2 ~ z1 + z2
    where y, d, the Z’s, and the X’s are variables in the data.

model : str
    Used only if formula is not provided.
    ‘auto’  : use LSEM
    ‘LSEM’  : use linear structural equation models
    ‘GLSEM’ : use generalized linear structural equation models
    ‘NPSEM’ : use nonparametric structural equation estimation (in this case, GAM)

se_cluster : str
    Name of the variable used to cluster the std. errors.

se : str or None
    See the documentation of the specific model used. Example: causalinf.models.lsem (for LSEM)

Specific models

For documentation of model-specific arguments, see models. Example: - causalinf.models.lsem (for LSEM)
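
When `formula` is omitted, the source derives one from the graph. A minimal sketch of that idea, assuming the graph is given as a plain dictionary mapping each node to its parents; `parents_to_formula` and `parents` are hypothetical names, and the sketch leaves out the labeled coefficients and intercept terms that the source adds.

```python
def parents_to_formula(parents):
    """Build one regression line per node that has at least one parent."""
    lines = []
    for child, pa in parents.items():
        if not pa:  # exogenous nodes get no equation in this sketch
            continue
        lines.append(f"{child} ~ " + " + ".join(pa))
    return "\n".join(lines)


# Hypothetical graph: z1 -> d, z1 -> y, d -> y
parents = {"z1": [], "d": ["z1"], "y": ["d", "z1"]}
formula = parents_to_formula(parents)
print(formula)
```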

Source code in causalinf/scm.py
class estimate:
    """
    Estimate a structural causal model.


    formula : str or None (optional)
        A structural equation model
           Ex: y ~ d + x1 + z1
               x2 ~ z1 + z2
        where y, d, Z's, and X's are variables in the data

    model: str
        Used only if formula is not provided
        'auto' : use LSEM
        'LSEM': use linear structural equation models. 
        'GLSEM': use generalized linear structural equation models
        'NPSEM': uses nonparametric structural equation estimation
                 In this case, it uses GAM.

    se_cluster : str
        Name of the variable to cluster the std. errors. 

    se : str or None
       See the documentation of the specific model used. Example:
       causalinf.models.lsem (for LSEM)

    Specific models
    ---------------
    For documentation of model-specific arguments, see models. Example:
    - causalinf.models.lsem (for LSEM)
    """
    def __init__(self,
                 G,
                 formula=None,
                 data=None,
                 model='auto',
                 family = 'auto',
                 se_cluster=None,
                 se=None,
                 # 
                 model_kws={},
                 sem=None,
                 # 
                 weights=1,
                 *args,
                 **kws
                 ):
        assert data is not None, 'Data must be provided.'
        data = ut.data2tibble(data)
        self.model = "LSEM" if model=='auto' else model
        self.formula = formula or self._graph2sem(G)
        self.family = family
        # 
        self.G = G
        self.outcome = G.outcome[0]
        self.exposure = G.exposure
        #
        self.se_cluster = se_cluster
        self.se = se

        if self.model == 'LSEM':
            self._lsem(G, data=data, weights=weights, *args, **kws)

    @ut.copy_docstring(lsem)
    def _lsem(self, G, data, weights, *args, **kws):
        est =  lsem(formula=self.formula,
                    data=data,
                    weights=weights,
                    se=self.se,
                    se_cluster=self.se_cluster,
                    *args, **kws)
        self.fit = est.fit
        self.est = est.est

    @property
    def causalinf(self):
        print(self)

    @ut.copy_docstring(ut.summary)
    def summary(self, *args, **kws):
        formula = ("\n" + self.formula).replace('\n', '\n        ')
        kws |= {'formula': kws.get("formula", formula)}

        latex_replace = {"~~": "\\\\leftrightarrow ", "~" : "\\\\leftarrow "}
        kws |= {'latex_replace': kws.get("latex_replace", latex_replace)}

        res = ut.summary(
            model = self,
            id_strategy='SCM',
            *args, **kws
        )
        return res.res

    def get_fit_statsitics(self, stats):
        res = self.est['fit_stats'].get(stats, None)
        if res is None:
            print(f'Statistics {stats} not available.')
        return res

    @ut.copy_docstring(gcm.DAG.plot)
    def plot(self, *args, **kws):
        self.G.plot(estimates=self, *args, **kws)

    def _graph2sem(self, G, parameter_fmt="(beta_{cause}.{effect})"):
        # regression lines
        reg = ''
        for y, pa_y in G.nodes_parents.items():
            pa_y = " + ".join([f"(beta_{v}.{y})*{v}" for v in pa_y])
            reg += f"{y} ~ (beta_0{y})*1 + {pa_y}\n"
        reg = f"# LSEM:\n{reg}"

        # correlations
        # collect all lines first so the two groups are newline-separated
        corr_lines = [f"{n1[0]} ~~ {n2[0]}" for n1, n2 in G.bidirected]
        corr_lines += [f"{n1} ~~ {n2}" for n1, n2 in G.undirected]
        corr = "\n".join(corr_lines)
        if corr:
            corr = f"# Correlations:\n{corr}\n"

        # indirect effects
        if G.exposure and G.outcome:
            ind = self._graph2sem_indirect_and_total_effects(G, parameter_fmt)
        else:
            ind = ''

        sem = f"{reg}{corr}{ind}"
        return sem

    def _graph2sem_indirect_and_total_effects(self, G, parameter_fmt):
        """
        Build the indirect-, direct-, and total-effect definitions for the SEM string.

        G             : graph providing the directed edges, exposure, and outcome
        parameter_fmt : format used to name each coefficient, with fields
                        {cause} and {effect} (e.g., '(beta_{cause}.{effect})')
        returns       : str with the effect definitions ('' if there are none)
        """
        # build adjacency
        edges = G.directed
        exposure = G.exposure[0]
        outcome = G.outcome[0]

        adj = defaultdict(list)
        for u, v in edges:
            adj[u].append(v)

        dir_effect = []
        tot_effect = []
        ind_effects = []

        # direct effect
        if outcome in adj[exposure]:
            dir_effect = parameter_fmt.format(cause=exposure, effect=outcome)

        # enumerate simple paths (DFS) recursive function
        def dfs(node, visited, path):
            if node == outcome and len(path) >= 3:  # indirect: at least 2 edges
                # build product beta terms for edges along the path
                terms = [parameter_fmt.format(cause=path[i], effect=path[i+1]) for i in range(len(path)-1)]
                ind_effects.append( (path[:], " * ".join(terms)) )
                return
            for nxt in adj[node]:
                if nxt in visited:       # avoid cycles
                    continue
                visited.add(nxt)
                path.append(nxt)
                dfs(nxt, visited, path)
                path.pop()
                visited.remove(nxt)
        # this fills ind_effects with (path, product-of-coefficients) pairs
        dfs(exposure, {exposure}, [exposure])

        res = ''
        ind_effect_str = ''
        if ind_effects:
            for i, ind in enumerate(ind_effects):
                parameter = f"Indirect_effect_{i+1}" # f"beta_{'.'.join(ind[0])}"
                ind_effect_str += f"{parameter} := {ind[1]}\n"
                tot_effect += [parameter]
            res += f"# Indirect effects:\n{ind_effect_str}"
        if dir_effect:
            parameter = "Direct_effect"
            res += f"# Direct effect:\n{parameter} := {dir_effect}"
            tot_effect += [parameter]
        if tot_effect:
            res += f"\n# Total effect:\nTotal_effect := {' + '.join(tot_effect)}"

        return res

    def __repr__(self):
        self.summary(style="full")
        return ''
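
The indirect effects above come from enumerating every simple directed path from the exposure to the outcome with at least two edges, then multiplying the coefficients along each path. A standalone sketch of that enumeration follows; `indirect_paths` and `path_product` are hypothetical helpers, and unlike the source the search stops as soon as it reaches the outcome.

```python
from collections import defaultdict


def indirect_paths(edges, exposure, outcome):
    """Return all simple directed paths exposure -> ... -> outcome with >= 2 edges."""
    adj = defaultdict(list)
    for u, v in edges:
        adj[u].append(v)

    found = []

    def dfs(node, path):
        if node == outcome:
            if len(path) >= 3:  # at least 2 edges, so not the direct edge
                found.append(list(path))
            return
        for nxt in adj[node]:
            if nxt in path:  # skip nodes already on the path (no cycles)
                continue
            path.append(nxt)
            dfs(nxt, path)
            path.pop()

    dfs(exposure, [exposure])
    return found


def path_product(path, fmt="(beta_{cause}.{effect})"):
    """Name the product of edge coefficients along a path."""
    return " * ".join(fmt.format(cause=a, effect=b) for a, b in zip(path, path[1:]))


# Edges of the "Mediation: 2 sequential 1 confounder" example
edges = [("D", "M1"), ("M1", "M2"), ("M2", "Y"), ("D", "Y"), ("Z", "D"), ("Z", "Y")]
paths = indirect_paths(edges, "D", "Y")
products = [path_product(p) for p in paths]
```

Here the direct edge `D -> Y` is skipped because its path has only two nodes, which leaves the single mediated path through `M1` and `M2`.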

Core Methods

causalinf.utils.summary(model=None, model_name='Model 1', compare=None, output='text', style=None, omit=None, show_sig=True, show_se=False, show_ci=True, show_fit=True, digits=4, digits_fit=2, col_width=1000, col_width_term=20, latex_kws=None, fn=None, save_style='concise', save_copies=['csv', 'xlsx'], *args, **kws)

model (causalinf..estimate) An estimate object from causalinf

compare (dict or list) A list or dict of other causalinf estimate objects. The estimates are shown in separate columns. If a dictionary is used, its keys are used as the column names. To name the column of the object calling the summary, use ‘model_name’. If a list is provided, the names are set to “Model 1”, “Model 2”, etc.

output (str) Format of the output:
- ‘text’: prints the summary and returns None
- ‘tibble’: returns a tibble
- ‘latex’: returns a LaTeX table
To save the output to a file, use ‘fn’. The output and the saved version are independent. For instance, it is possible to print the summary (text) and at the same time save it in LaTeX using fn.

model_name (str) Name of the column showing the estimates when output is ‘tibble’ or ‘latex’.

style (str) If style=‘concise’, the summary table returns only:
- the parameter name (‘term’)
- the confidence interval (if show_ci=True), the std. errors (if show_se=True), and the p-value indicator (if show_sig=True)
If style=‘full’, the summary table includes all estimation statistics available. Defaults:
- ‘full’ when compare=None and output=‘text’
- ‘concise’ otherwise

omit (str) A regular expression to match elements in the column terms. Matched cases will be omitted.

save_style (str) Same as ‘style’, but applied when saving the summary to files based on ‘fn’ and ‘save_copies’.

fn (str) Path with the filename to save the output to. Relative paths are allowed. The saved format is chosen automatically from the filename extension (tex, xlsx, xls, csv). Copies are saved based on ‘save_copies’.

save_copies (list) List of strings with the extensions to save copies of the output table in the format of the extensions provided. Available are xls, xlsx, and csv. If None, it will not save copies of the output.

show_fit (bool or list) If False, omits the fit statistics; if True, shows the stats listed in causalinf.estimate.est.fit; otherwise, shows the statistics included in the list provided.

show_sig, show_se, show_ci (bool) When comparing models, show_* sets which information, such as standard errors (se), confidence intervals (ci), and significance level indicators (sig), appears alongside the parameter estimates whenever available. This is ignored when output=‘text’ and compare=None.

digits, digits_fit: (int) Digits to show in the estimates and fit statistics, respectively.

col_width (int) Width of the columns in the printed summary.

latex_kws : Keywords from tibble.to_latex()
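
The default for ‘style’ described above depends on two other arguments. A minimal sketch of that rule, assuming it reduces to a single condition; `default_style` is a hypothetical helper and not the library's `get_style`, whose implementation is not shown here.

```python
def default_style(compare=None, output="text"):
    """'full' only for a single model printed as text; 'concise' otherwise."""
    return "full" if compare is None and output == "text" else "concise"


print(default_style())                    # single model, printed as text
print(default_style(output="latex"))      # single model, LaTeX output
print(default_style(compare=["other"]))   # model comparison
```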

Source code in causalinf/utils.py
def __init__(self,
             model=None,
             model_name = 'Model 1',
             compare=None,
             output = 'text',
             style = None,
             omit = None,
             show_sig = True,
             show_se =  False,
             show_ci =  True,
             show_fit = True,
             digits = 4,
             digits_fit = 2,
             col_width = 1000,
             col_width_term = 20,
             # latex args
             latex_kws=None,
             fn = None,
             save_style = 'concise', 
             save_copies = ['csv', 'xlsx'],
             *args, **kws
             ):
    # compare = kws.get("compare", None)
    assert isinstance(latex_kws, dict | None), "'latex_kws' must be None or a dict."
    assert isinstance(save_copies, list | None), "'save_copies' must be None or a list of file extensions."

    self.model_name = model_name
    self.output = output
    self.style = style or self.get_style(compare, output)
    self.omit = omit
    self.digits = digits
    self.digits_fit = digits_fit
    self.show_sig=show_sig
    self.show_se=show_se
    self.show_ci=show_ci
    self.show_fit = show_fit # self.show_fit = kws.get("show_fit", True)
    self.latex_kws = latex_kws or {}
    self.fn = fn
    self.save_style = save_style
    self.save_copies = save_copies
    self.col_width = col_width
    self.col_width_term = col_width_term

    self.outcome = model.outcome
    self.exposure = model.exposure

    self.collect_models(model, model_name, compare)
    self.merge_models()
    self.collect_info()

    # # implicit parameters
    self.id_strategy = kws.get("id_strategy", '')
    self.formula = kws.get("formula", '')
    self.latex_replace = kws.get("latex_replace", None) ## for latex only
    self.estimator = model.est.fit['Estimator']
    self.footnote_added = False # used to avoid duplicating footnote entries

    # keep this order
    self._save(fn=self.fn, silent=False)
    self._save_copies()
    self._output(self.style)