InFlowVarDist(
  (module_genmodel): InFlowGenerativeModel(
    (module_int_mu_u): LinearEncoding()
    (module_int_cov_u): SimpleMLPandExp(
      (module_mlp): SimpleMLP(
        (module): Sequential(
          (0): Linear(in_features=12, out_features=100, bias=True)
        )
      )
    )
    (module_spl_mu_u): LinearEncoding()
    (module_spl_cov_u): SimpleMLPandExp(
      (module_mlp): SimpleMLP(
        (module): Sequential(
          (0): Linear(in_features=12, out_features=100, bias=True)
        )
      )
    )
    (module_theta_aggr): KhopAvgPoolWithoutselfloop(
      (list_modules): ModuleList(
        (0): AvgpoolLayerWithoutselfloop()
      )
    )
    (module_Vflow_unwrapped): MLP(
      (net_z): Sequential(
        (0): Linear(in_features=101, out_features=64, bias=True)
        (1): SELU()
        (2): Linear(in_features=64, out_features=64, bias=True)
        (3): SELU()
        (4): Linear(in_features=64, out_features=64, bias=True)
        (5): SELU()
        (6): Linear(in_features=64, out_features=100, bias=True)
      )
      (net_s): Sequential(
        (0): Linear(in_features=201, out_features=64, bias=True)
        (1): SELU()
        (2): Linear(in_features=64, out_features=64, bias=True)
        (3): SELU()
        (4): Linear(in_features=64, out_features=64, bias=True)
        (5): SELU()
        (6): Linear(in_features=64, out_features=100, bias=True)
      )
    )
    (module_flow): WrapperTorchDiffEq(
      (model): MLP(
        (net_z): Sequential(
          (0): Linear(in_features=101, out_features=64, bias=True)
          (1): SELU()
          (2): Linear(in_features=64, out_features=64, bias=True)
          (3): SELU()
          (4): Linear(in_features=64, out_features=64, bias=True)
          (5): SELU()
          (6): Linear(in_features=64, out_features=100, bias=True)
        )
        (net_s): Sequential(
          (0): Linear(in_features=201, out_features=64, bias=True)
          (1): SELU()
          (2): Linear(in_features=64, out_features=64, bias=True)
          (3): SELU()
          (4): Linear(in_features=64, out_features=64, bias=True)
          (5): SELU()
          (6): Linear(in_features=64, out_features=100, bias=True)
        )
      )
    )
    (module_w_dec_int): SimpleMLP(
      (module): Sequential(
        (0): ReLU()
        (1): Linear(in_features=101, out_features=313, bias=True)
        (2): Softmax(dim=1)
      )
    )
    (module_w_dec_spl): SimpleMLP(
      (module): Sequential(
        (0): ReLU()
        (1): Linear(in_features=101, out_features=313, bias=True)
        (2): Softmax(dim=1)
      )
    )
  )
  (module_varphi_enc_int): EncX2Xbar(
    (module_encX): Sequential(
      (0): Linear(in_features=314, out_features=31, bias=True)
      (1): LayerNorm((31,), eps=1e-05, elementwise_affine=True)
      (2): ReLU()
      (3): Linear(in_features=31, out_features=100, bias=True)
      (4): LayerNorm((100,), eps=1e-05, elementwise_affine=True)
      (5): ReLU()
    )
  )
  (module_varphi_enc_spl): EncX2Xbar(
    (module_encX): Sequential(
      (0): Linear(in_features=314, out_features=31, bias=True)
      (1): LayerNorm((31,), eps=1e-05, elementwise_affine=True)
      (2): ReLU()
      (3): Linear(in_features=31, out_features=100, bias=True)
      (4): LayerNorm((100,), eps=1e-05, elementwise_affine=True)
      (5): ReLU()
    )
  )
  (list_ajdmatpredloss): ListAdjMatPredLoss(
    (list_adjpredictors): ModuleList()
  )
  (module_predictor_ranklossxbarint_X): Sequential(
    (0): Linear(in_features=200, out_features=100, bias=True)
    (1): ReLU()
    (2): Linear(in_features=100, out_features=2, bias=True)
  )
  (module_predictor_ranklossxbarint_Y): Sequential(
    (0): Linear(in_features=200, out_features=100, bias=True)
    (1): ReLU()
    (2): Linear(in_features=100, out_features=2, bias=True)
  )
  (module_predictor_xbarint2notNCC): PredictorPerCT(
    (list_modules): ModuleList(
      (0-11): 12 x Sequential(
        (0): Linear(in_features=100, out_features=50, bias=True)
        (1): LayerNorm((50,), eps=1e-05, elementwise_affine=True)
        (2): ReLU()
        (3): Linear(in_features=50, out_features=50, bias=True)
        (4): LayerNorm((50,), eps=1e-05, elementwise_affine=True)
        (5): ReLU()
        (6): Linear(in_features=50, out_features=50, bias=True)
        (7): LayerNorm((50,), eps=1e-05, elementwise_affine=True)
        (8): ReLU()
        (9): Linear(in_features=50, out_features=12, bias=False)
        (10): Tanh()
      )
    )
  )
  (module_classifier_P1loss): Sequential(
    (0): Linear(in_features=100, out_features=50, bias=True)
    (1): ReLU()
    (2): Linear(in_features=50, out_features=12, bias=True)
  )
  (crit_P1loss): CrossEntropyLoss()
  (module_predictor_P3loss): Sequential(
    (0): Linear(in_features=100, out_features=50, bias=True)
    (1): ReLU()
    (2): Linear(in_features=50, out_features=12, bias=True)
  )
  (crit_P3loss): MSELoss()
  (module_classifier_xbarintCT): Sequential(
    (0): Linear(in_features=100, out_features=50, bias=True)
    (1): ReLU()
    (2): Linear(in_features=50, out_features=12, bias=True)
  )
  (crit_loss_xbarint2CT): CrossEntropyLoss()
  (module_predictor_xbarsplNCC): Sequential(
    (0): Linear(in_features=100, out_features=50, bias=True)
    (1): ReLU()
    (2): Linear(in_features=50, out_features=12, bias=True)
  )
  (crit_loss_xbarspl2NCC): MSELoss()
  (crit_xbarint_rankloss): MarginRankingLoss()
  (crit_loss_xbarint2notNCC): WassDist()
  (module_predictor_z2notNCC): PredictorPerCT(
    (list_modules): ModuleList(
      (0-11): 12 x Sequential(
        (0): Linear(in_features=100, out_features=50, bias=True)
        (1): LayerNorm((50,), eps=1e-05, elementwise_affine=True)
        (2): ReLU()
        (3): Linear(in_features=50, out_features=50, bias=True)
        (4): LayerNorm((50,), eps=1e-05, elementwise_affine=True)
        (5): ReLU()
        (6): Linear(in_features=50, out_features=50, bias=True)
        (7): LayerNorm((50,), eps=1e-05, elementwise_affine=True)
        (8): ReLU()
        (9): Linear(in_features=50, out_features=12, bias=False)
        (10): Tanh()
      )
    )
  )
  (crit_loss_z2notNCC): WassDist()
  (module_predictor_ranklossZ_X): Sequential(
    (0): Linear(in_features=200, out_features=100, bias=True)
    (1): ReLU()
    (2): Linear(in_features=100, out_features=2, bias=True)
  )
  (module_predictor_ranklossZ_Y): Sequential(
    (0): Linear(in_features=200, out_features=100, bias=True)
    (1): ReLU()
    (2): Linear(in_features=100, out_features=2, bias=True)
  )
  (crit_Z_rankloss): MarginRankingLoss()
  (module_impanddisentgl): GNNDisentangler(
    (module_gnn): SAGE(
      (list_modules): ModuleList(
        (0): SAGEConv(314, 100, aggr=mean)
      )
    )
    (module_muxspl): Sequential(
      (0): ReLU()
      (1): Dropout(p=0.1, inplace=False)
      (2): Linear(in_features=112, out_features=31, bias=True)
      (3): LayerNorm((31,), eps=1e-05, elementwise_affine=True)
      (4): ReLU()
      (5): Linear(in_features=31, out_features=313, bias=True)
    )
    (module_covxspl): Sequential(
      (0): ReLU()
      (1): Dropout(p=0.1, inplace=False)
      (2): Linear(in_features=112, out_features=31, bias=True)
      (3): LayerNorm((31,), eps=1e-05, elementwise_affine=True)
      (4): ReLU()
      (5): Linear(in_features=31, out_features=313, bias=True)
    )
  )
  (module_cond4flowvarphi0): Cond4FlowVarphi0SimpleMLPs(
    (module_enc_z): Sequential(
      (0): Dropout(p=0.1, inplace=False)
      (1): SimpleMLP(
        (module): Sequential(
          (0): Linear(in_features=112, out_features=50, bias=True)
          (1): LayerNorm((50,), eps=1e-05, elementwise_affine=True)
          (2): ReLU()
          (3): Linear(in_features=50, out_features=100, bias=True)
        )
      )
    )
    (module_enc_sin): Sequential(
      (0): Dropout(p=0.1, inplace=False)
      (1): SimpleMLP(
        (module): Sequential(
          (0): Linear(in_features=112, out_features=50, bias=True)
          (1): LayerNorm((50,), eps=1e-05, elementwise_affine=True)
          (2): ReLU()
          (3): Linear(in_features=50, out_features=100, bias=True)
        )
      )
    )
    (module_enc_sout): Sequential(
      (0): Dropout(p=0.1, inplace=False)
      (1): SimpleMLP(
        (module): Sequential(
          (0): Linear(in_features=212, out_features=50, bias=True)
          (1): LayerNorm((50,), eps=1e-05, elementwise_affine=True)
          (2): ReLU()
          (3): Linear(in_features=50, out_features=100, bias=True)
        )
      )
    )
  )
  (module_predictor_xbarint2notbatchID): PredictorBatchID(
    (list_modeuls): ModuleList(
      (0): Sequential(
        (0): Linear(in_features=100, out_features=50, bias=True)
        (1): LayerNorm((50,), eps=1e-05, elementwise_affine=True)
        (2): ReLU()
        (3): Linear(in_features=50, out_features=50, bias=True)
        (4): LayerNorm((50,), eps=1e-05, elementwise_affine=True)
        (5): ReLU()
        (6): Linear(in_features=50, out_features=50, bias=True)
        (7): LayerNorm((50,), eps=1e-05, elementwise_affine=True)
        (8): ReLU()
        (9): Linear(in_features=50, out_features=1, bias=False)
        (10): Tanh()
      )
    )
  )
  (crit_loss_xbarint2notbatchID): WassDistBatchID()
  (module_predictor_xbarspl2notbatchID): PredictorBatchID(
    (list_modeuls): ModuleList(
      (0): Sequential(
        (0): Linear(in_features=100, out_features=50, bias=True)
        (1): LayerNorm((50,), eps=1e-05, elementwise_affine=True)
        (2): ReLU()
        (3): Linear(in_features=50, out_features=50, bias=True)
        (4): LayerNorm((50,), eps=1e-05, elementwise_affine=True)
        (5): ReLU()
        (6): Linear(in_features=50, out_features=50, bias=True)
        (7): LayerNorm((50,), eps=1e-05, elementwise_affine=True)
        (8): ReLU()
        (9): Linear(in_features=50, out_features=1, bias=False)
        (10): Tanh()
      )
    )
  )
  (crit_loss_xbarspl2notbatchID): WassDistBatchID()
)