{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "2f935234",
   "metadata": {},
   "source": [
    "# SymbolicModel Demo"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ff2801a7",
   "metadata": {},
   "source": [
    "In this demo, we show you how to:\n",
    "* Wrap an MLP block in a PyTorch model with our `SymbolicModel` class\n",
     "* Perform symbolic regression on the MLP with the `distill` method \n",
    "* Switch the MLP to an equation in the forward pass of the model with the `switch_to_equation` method \n",
    "* Switch back to the MLP with the `switch_to_block` method"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a92f4a7a",
   "metadata": {},
   "source": [
    "## Wrapping a PyTorch model\n",
    "Create a simple PyTorch model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "adc61e0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "\n",
    "class MLP(nn.Module):\n",
    "    \"\"\"\n",
    "    Simple MLP.\n",
    "    \"\"\"\n",
    "    def __init__(self, input_dim, output_dim, hidden_dim):\n",
    "        super(MLP, self).__init__()\n",
    "        self.mlp = nn.Sequential(\n",
    "            nn.Linear(input_dim, hidden_dim),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(0.2),\n",
    "            nn.Linear(hidden_dim, hidden_dim),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(0.2),\n",
    "            nn.Linear(hidden_dim, hidden_dim),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(0.2),\n",
    "            nn.Linear(hidden_dim, output_dim)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.mlp(x)\n",
    "\n",
    "class SimpleModel(nn.Module):\n",
    "    \"\"\"\n",
    "    Simple model class.\n",
    "    \"\"\"\n",
    "    def __init__(self, input_dim, output_dim, hidden_dim = 64):\n",
    "        super(SimpleModel, self).__init__()\n",
    "\n",
    "        self.mlp = MLP(input_dim, output_dim, hidden_dim)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.mlp(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6941215c",
   "metadata": {},
   "source": [
    "Train the model on some data.\n",
    "\n",
    "$$\n",
    "y = x_0^2 +3 \\sin(x_4)-4\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5f363579",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make the dataset \n",
    "x = np.array([np.random.uniform(0, 1, 10_000) for _ in range(5)]).T  \n",
    "y = x[:, 0]**2 + 3*np.sin(x[:, 4]) - 4\n",
    "noise = np.array([np.random.normal(0, 0.05*np.std(y)) for _ in range(len(y))])\n",
    "y = y + noise "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9fcf89e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up training\n",
    "\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "def train_model(model, dataloader, opt, criterion, epochs = 100):\n",
    "    \"\"\"\n",
    "    Train a model for the specified number of epochs.\n",
    "    \n",
    "    Args:\n",
    "        model: PyTorch model to train\n",
    "        dataloader: DataLoader for training data\n",
    "        opt: Optimizer\n",
    "        criterion: Loss function\n",
    "        epochs: Number of training epochs\n",
    "        \n",
    "    Returns:\n",
    "        tuple: (trained_model, loss_tracker)\n",
    "    \"\"\"\n",
    "    loss_tracker = []\n",
    "    for epoch in range(epochs):\n",
    "        epoch_loss = 0.0\n",
    "        \n",
    "        for batch_x, batch_y in dataloader:\n",
    "            # Forward pass\n",
    "            pred = model(batch_x)\n",
    "            loss = criterion(pred, batch_y)\n",
    "            \n",
    "            # Backward pass\n",
    "            opt.zero_grad()\n",
    "            loss.backward()\n",
    "            opt.step()\n",
    "            \n",
    "            epoch_loss += loss.item()\n",
    "        \n",
    "        loss_tracker.append(epoch_loss)\n",
    "        if (epoch + 1) % 5 == 0:\n",
    "            avg_loss = epoch_loss / len(dataloader)\n",
    "            print(f'Epoch [{epoch+1}/{epochs}], Avg Loss: {avg_loss:.6f}')\n",
    "    return model, loss_tracker\n",
    "\n",
    "# Instantiate the model\n",
    "model = SimpleModel(input_dim=x.shape[1], output_dim=1)\n",
    "\n",
    "# Set up training\n",
    "criterion = nn.MSELoss()\n",
    "opt = optim.Adam(model.parameters(), lr=0.001)\n",
    "X_train, _, y_train, _ = train_test_split(x, y.reshape(-1,1), test_size=0.2, random_state=290402)\n",
    "\n",
    "# Set up dataset\n",
    "dataset = TensorDataset(torch.FloatTensor(X_train), torch.FloatTensor(y_train))\n",
    "dataloader = DataLoader(dataset, batch_size=32, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "add838bc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [5/20], Avg Loss: 0.076408\n",
      "Epoch [10/20], Avg Loss: 0.052064\n",
      "Epoch [15/20], Avg Loss: 0.042181\n",
      "Epoch [20/20], Avg Loss: 0.033330\n"
     ]
    }
   ],
   "source": [
    "# Train the model and save the weights\n",
    "\n",
    "model, losses = train_model(model, dataloader, opt, criterion, 20)\n",
    "torch.save(model.state_dict(), 'model_weights.pth')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "feb22ef1",
   "metadata": {},
   "source": [
     "Wrap the mlp layer in the trained model with `SymbolicModel`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "77bc4f4a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Detected IPython. Loading juliacall extension. See https://juliapy.github.io/PythonCall.jl/stable/compat/#IPython\n"
     ]
    }
   ],
   "source": [
    "from symtorch import SymbolicModel\n",
    "model.mlp = SymbolicModel(model.mlp, block_name = 'Sequential')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1db5df8",
   "metadata": {},
   "source": [
    "## Interpret the MLP"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ca15a8d3",
   "metadata": {},
   "source": [
    "In this example, we pass extra parameters into the `.distill` method (complexity of operators/constants and parsimony, which is a penalisation of complexity).\\\n",
    "To see all the possible parameters, please see the `PySRRegressor` class from [PySR](https://ai.damtp.cam.ac.uk/pysr/api/)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c6b47775",
   "metadata": {},
   "source": [
    "In this example, we turn verbosity off because we are in a Jupyter notebook. For best performance, run in IPython, as you can terminate the SR any time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d637d376",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "🛠️ Running SR on output dimension 0 of 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/liz/PhD/SymTorch_project/symtorch_venv/lib/python3.11/site-packages/pysr/sr.py:2811: UserWarning: Note: it looks like you are running in Jupyter. The progress bar will be turned off.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "💡Best equation for output 0 found to be (((exp(x0) * 0.24175334) + -1.879353) + x4) * 2.382997.\n",
      "❤️ SR on Sequential complete.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{0: PySRRegressor.equations_ = [\n",
       " \t    pick         score                                           equation  \\\n",
       " \t0         0.000000e+00                                                 x4   \n",
       " \t1         2.675247e+00                                         -2.3036592   \n",
       " \t2         4.113613e-01                                    x4 + -2.8012233   \n",
       " \t3         4.553621e-01                             (x4 + x4) + -3.2988408   \n",
       " \t4         1.302947e-01                      (x4 * 2.3804636) + -3.4881692   \n",
       " \t5         1.391758e+00                      x4 + ((x4 + -3.7980285) + x0)   \n",
       " \t6         8.372327e-01                (x0 + (x4 * 2.3829083)) + -3.988571   \n",
       " \t7         5.373152e-08      inv(inv((x0 + (x4 * 2.3829083)) + -3.988571))   \n",
       " \t8         1.165584e-02  ((x4 + (x0 * 0.40451804)) + -1.6661206) * 2.38...   \n",
       " \t9         1.494943e-06  ((x0 + -3.9705966) + (x4 * 2.3829074)) + (x0 *...   \n",
       " \t10  >>>>  3.965307e-01  (((exp(x0) * 0.24175334) + -1.879353) + x4) * ...   \n",
       " \t11        4.405181e-02  (x0 * exp(x0 + -1.0060903)) + ((x4 + -1.617400...   \n",
       " \t12        9.909887e-04  (((x4 * 0.36639503) + ((x4 + x4) + (x3 * (x4 *...   \n",
       " \t\n",
       " \t        loss  complexity  \n",
       " \t0   8.091652           1  \n",
       " \t1   0.557432           2  \n",
       " \t2   0.244843           4  \n",
       " \t3   0.098484           6  \n",
       " \t4   0.086453           7  \n",
       " \t5   0.021495           8  \n",
       " \t6   0.009306           9  \n",
       " \t7   0.009306          11  \n",
       " \t8   0.009198          12  \n",
       " \t9   0.009198          14  \n",
       " \t10  0.006187          15  \n",
       " \t11  0.005665          17  \n",
       " \t12  0.005604          28  \n",
       " ]}"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Configure the SR\n",
    "\n",
    "sr_params = {'complexity_of_operators':  {\"sin\":3, \"exp\":3},\n",
    "             'complexity_of_constants': 2, \n",
    "             'parsimony': 0.1,\n",
    "             'verbosity': 0,\n",
    "             'niterations': 50}\n",
    "\n",
    "model.mlp.distill(torch.FloatTensor(X_train),\n",
    "                       sr_params = sr_params, \n",
    "                       parent_model=model) #Pass in the parent model (really only required if the MLP is \n",
    "                                            #not at the start of the model but it is good practice)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8275ed6e",
   "metadata": {},
   "source": [
    "See the full Pareto front of equations. The best equation is chosen as a balance of accuracy and complexity.\\\n",
    "Outputs from *PySR* are saved in `SR_output/MLP_name`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "858c013e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "➡️ Symbolic expressions for output dimension 0:\n",
      "    complexity      loss                                           equation  \\\n",
      "0            1  7.717672                                                 x4   \n",
      "1            2  0.482811                                         -2.2379751   \n",
      "2            4  0.193551                                    x4 + -2.7430148   \n",
      "3            6  0.069146                             (x4 + x4) + -3.2480512   \n",
      "4            7  0.063802                      (x4 * 2.2548008) + -3.3767343   \n",
      "5            8  0.014271                      x0 + ((x4 + x4) + -3.7485306)   \n",
      "6            9  0.009059               (x4 * 2.2514427) + (x0 + -3.8755205)   \n",
      "7           11  0.006982         (x4 * 2.2504792) + ((x0 * x0) + -3.709683)   \n",
      "8           12  0.006277           (x0 + (sin(x4) * 2.6363063)) + -3.962223   \n",
      "9           14  0.003978  ((x0 * x0) * 0.8169185) + ((x4 * 2.2512612) + ...   \n",
      "10          17  0.001264  ((x0 * sin(x0)) + -3.764245) + (sin(x4) * 2.63...   \n",
      "11          19  0.001157  inv(inv((sin(x4) * 2.6362014) + (((x0 * x0) * ...   \n",
      "12          22  0.000922  inv(sin(inv(((sin(x4) + ((x0 * x0) * 0.3172727...   \n",
      "13          25  0.000879  inv(sin(inv(((sin(x4) * 2.77826) + -3.7736876)...   \n",
      "14          28  0.000873  inv(sin(inv((sin(x4) * 2.7769802) + ((exp(x0 *...   \n",
      "\n",
      "       score                                       sympy_format  \\\n",
      "0   0.000000                                                 x4   \n",
      "1   2.771643                                  -2.23797510000000   \n",
      "2   0.457042                                     x4 - 2.7430148   \n",
      "3   0.514658                                x4 + x4 - 3.2480512   \n",
      "4   0.080440                           x4*2.2548008 - 3.3767343   \n",
      "5   1.497572                           x0 + x4 + x4 - 3.7485306   \n",
      "6   0.454409                      x0 + x4*2.2514427 - 3.8755205   \n",
      "7   0.130204                    x0*x0 + x4*2.2504792 - 3.709683   \n",
      "8   0.106513                  x0 + sin(x4)*2.6363063 - 3.962223   \n",
      "9   0.228033         x0*x0*0.8169185 + x4*2.2512612 - 3.6487224   \n",
      "10  0.382286          x0*sin(x0) + sin(x4)*2.6359825 - 3.764245   \n",
      "11  0.044174    x0*x0*0.81745505 + sin(x4)*2.6362014 - 3.735648   \n",
      "12  0.075511  1/sin(1/((x0*x0*0.31727275 + sin(x4))*2.777514...   \n",
      "13  0.015831  1/sin(1/(exp(x0)*0.3389404*x0 + sin(x4)*2.7782...   \n",
      "14  0.002625  1/sin(1/(exp(x0*1.0814174)*x0*0.3136479 + sin(...   \n",
      "\n",
      "                                        lambda_format  \n",
      "0                                 PySRFunction(X=>x4)  \n",
      "1                  PySRFunction(X=>-2.23797510000000)  \n",
      "2                     PySRFunction(X=>x4 - 2.7430148)  \n",
      "3                PySRFunction(X=>x4 + x4 - 3.2480512)  \n",
      "4           PySRFunction(X=>x4*2.2548008 - 3.3767343)  \n",
      "5           PySRFunction(X=>x0 + x4 + x4 - 3.7485306)  \n",
      "6      PySRFunction(X=>x0 + x4*2.2514427 - 3.8755205)  \n",
      "7    PySRFunction(X=>x0*x0 + x4*2.2504792 - 3.709683)  \n",
      "8   PySRFunction(X=>x0 + sin(x4)*2.6363063 - 3.962...  \n",
      "9   PySRFunction(X=>x0*x0*0.8169185 + x4*2.2512612...  \n",
      "10  PySRFunction(X=>x0*sin(x0) + sin(x4)*2.6359825...  \n",
      "11  PySRFunction(X=>x0*x0*0.81745505 + sin(x4)*2.6...  \n",
      "12  PySRFunction(X=>1/sin(1/((x0*x0*0.31727275 + s...  \n",
      "13  PySRFunction(X=>1/sin(1/(exp(x0)*0.3389404*x0 ...  \n",
      "14  PySRFunction(X=>1/sin(1/(exp(x0*1.0814174)*x0*...  \n",
      "🏆 Best: ((x0 * sin(x0)) + -3.764245) + (sin(x4) * 2.6359825) (loss: 1.263592e-03)\n"
     ]
    }
   ],
   "source": [
    "model.mlp.show_symbolic_expression()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db788500",
   "metadata": {},
   "source": [
    "## Switch to Using the Equation Instead in the Forwards Pass"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c2dbf360",
   "metadata": {},
   "source": [
    "You can choose the equation you want to switch to by choosing the desired complexity of equation. \\\n",
    "If left blank, then we choose the best equation chosen by *PySR*."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "abef72c8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "✅ Successfully switched Sequential to symbolic equations for all 1 dimensions:\n",
      "   Dimension 0: ((x0 * x0) * 0.8169185) + ((x4 * 2.2512612) + -3.6487224)\n",
      "   Variables: ['x0', 'x4']\n",
      "🎯 All 1 output dimensions now using symbolic equations.\n"
     ]
    }
   ],
   "source": [
    "model.mlp.switch_to_symbolic(complexity=[14]) "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d8e7cfd0",
   "metadata": {},
   "source": [
    "Now when running the forwards pass through the model, it uses the symbolic equation instead of the MLP. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e526154f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-1.6968],\n",
       "        [-2.0715],\n",
       "        [-1.1126],\n",
       "        ...,\n",
       "        [-2.6302],\n",
       "        [-2.6785],\n",
       "        [-2.9952]])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "interpretable_outputs = model(torch.tensor(X_train, dtype=torch.float32))\n",
    "interpretable_outputs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "748e191d",
   "metadata": {},
   "source": [
    "You can also make a callable function using the `get_symbolic_function` method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "41df60af",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-1.6968284, -2.0715022, -1.1125848, ..., -2.6301584, -2.6784554,\n",
       "       -2.9951892], shape=(8000,), dtype=float32)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "symbolic_function = model.mlp.get_symbolic_function(complexity = 14)\n",
    "symbolic_function(X_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5835dd8e",
   "metadata": {},
   "source": [
    "## Switch to Using the MLP in the Forwards Pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "34291a48",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "✅ Switched Sequential back to block\n"
     ]
    }
   ],
   "source": [
    "mlp_outputs = model.mlp.switch_to_block()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "79e43dc0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-1.6186],\n",
       "        [-2.0624],\n",
       "        [-1.1711],\n",
       "        ...,\n",
       "        [-2.6357],\n",
       "        [-2.7018],\n",
       "        [-3.0245]])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    model_outputs = model.mlp(torch.tensor(X_train, dtype=torch.float32))\n",
    "model_outputs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2a77416b",
   "metadata": {},
   "source": [
    "## Wrapping a Python Function\n",
    "\n",
    "`SymbolicModel` also works with generic Python functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "466cc829",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same function as before\n",
    "def f(X):\n",
    "    return X[:, 0]**2 + 3*np.sin(X[:, 4]) - 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "50208460",
   "metadata": {},
   "outputs": [],
   "source": [
    "symbolic_function = SymbolicModel(f, block_name = \"callable_function\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "69df2ff5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "🛠️ Running SR on output dimension 0 of 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/liz/PhD/SymTorch_project/symtorch_venv/lib/python3.11/site-packages/pysr/sr.py:2811: UserWarning: Note: it looks like you are running in Jupyter. The progress bar will be turned off.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "💡Best equation for output 0 found to be (sin(x4) + (x0 * x0)) + ((sin(x4) + sin(x4)) + -4.0).\n",
      "❤️ SR on callable_function complete.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{0: PySRRegressor.equations_ = [\n",
       " \t    pick         score                                           equation  \\\n",
       " \t0         0.000000e+00                                                 x4   \n",
       " \t1         2.533145e+00                                          -2.272284   \n",
       " \t2         3.840348e-01                                     x4 + -2.777311   \n",
       " \t3         4.541706e-01                             (x4 + x4) + -3.2823517   \n",
       " \t4         2.530731e-01                       (x4 * 2.5681043) + -3.569265   \n",
       " \t5         9.782338e-01                      (x4 + x0) + (x4 + -3.7828333)   \n",
       " \t6         1.416253e+00                 x0 + ((x4 * 2.564849) + -4.068105)   \n",
       " \t7         5.468644e-01         (x0 * x0) + ((x4 * 2.563873) + -3.9022593)   \n",
       " \t8         1.771850e-08  ((x4 * 0.56387687) + ((x4 + -3.902262) + (x0 *...   \n",
       " \t9         3.764534e+00  (((x4 * -0.7083076) + 3.2762177) * x4) + ((x0 ...   \n",
       " \t10        2.500752e+00  ((inv(x4 + -1.958775) + 3.5588908) * x4) + ((x...   \n",
       " \t11        6.710256e-07  inv(inv(((inv(-1.958775 + x4) + 3.5588908) * x...   \n",
       " \t12  >>>>  9.804187e+00  (sin(x4) + (x0 * x0)) + ((sin(x4) + sin(x4)) +...   \n",
       " \t\n",
       " \t            loss  complexity  \n",
       " \t0   8.008494e+00           1  \n",
       " \t1   6.359454e-01           2  \n",
       " \t2   2.950200e-01           4  \n",
       " \t3   1.189498e-01           6  \n",
       " \t4   9.235398e-02           7  \n",
       " \t5   3.472275e-02           8  \n",
       " \t6   8.424487e-03           9  \n",
       " \t7   2.821909e-03          11  \n",
       " \t8   2.821909e-03          15  \n",
       " \t9   6.540737e-05          16  \n",
       " \t10  5.364925e-06          17  \n",
       " \t11  5.364918e-06          19  \n",
       " \t12  1.635891e-14          21  \n",
       " ]}"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "symbolic_function.distill(X_train, sr_params=sr_params)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e723e546",
   "metadata": {},
   "source": [
    "**Note:** if distilling a generic function (not a PyTorch MLP) you cannot pass in the `parent_model` parameter. If this function is part of a hybrid PyTorch model and is not a `nn.Module` type, then you must pass in the inputs to the function not to the whole model when running `.distill`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "41ea9980",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "✅ Successfully switched callable_function to symbolic equations for all 1 dimensions:\n",
      "   Dimension 0: (sin(x4) + (x0 * x0)) + ((sin(x4) + sin(x4)) + -4.0)\n",
      "   Variables: ['x0', 'x4']\n",
      "🎯 All 1 output dimensions now using symbolic equations.\n"
     ]
    }
   ],
   "source": [
    "symbolic_function.switch_to_symbolic()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "0e258b2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clean up \n",
    "import os\n",
    "import shutil\n",
    "if os.path.exists('SR_output'):\n",
    "    shutil.rmtree('SR_output')\n",
    "    os.remove('model_weights.pth')\n",
    "\n",
    "if os.path.exists('symtorch_data'):\n",
    "    shutil.rmtree('symtorch_data')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "symtorch_venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
