{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "2f935234",
   "metadata": {},
   "source": [
    "# Pruning Demo"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c079a023",
   "metadata": {},
   "source": [
    "In this demo, we show how to:\n",
    "* Use the pruning functionality in `SymbolicModel` to reduce the output dimensionality of a PyTorch MLP\n",
    "* Set up a pruning schedule and train\n",
    "* Perform symbolic regression on a pruned MLP"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "127b5f78",
   "metadata": {},
   "source": [
    "## Pruning Background"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17f29635",
   "metadata": {},
   "source": [
    "For interpretability, it helps to reduce the dimensionality of a deep learning model's internal representations. High-dimensional representations often entangle multiple features, making it difficult to extract clear, human-understandable relationships. Encouraging a sparse representation pushes the network to compress information into a smaller set of meaningful components. Fewer dimensions also leave fewer expressions for symbolic regression to find, which may make symbolic regression feasible on these models."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e6d849f2",
   "metadata": {},
   "source": [
    "SymTorch's pruning functionality lets you dynamically reduce the output dimensionality of an MLP during training by zero-masking its unimportant output dimensions."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f18ba6a",
   "metadata": {},
   "source": [
    "**Important dimensions**: the output dimensions the model relies on most when predicting its target. These should be the dimensions that vary most as the input varies, so we take the important dimensions to be those with the highest standard deviation across datapoints."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "725b0f14",
   "metadata": {},
   "source": [
    "To find them, we pass some input data through the model (usually a subset of the validation set), record the MLP's outputs, and keep the output dimensions with the highest standard deviation across the datapoints, as shown below."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "36902e01",
   "metadata": {},
   "source": [
    "<img src=\"../../_static/choosing_important_dims.png\" width=\"450\" height=\"300\">"
   ]
  },
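  {
   "cell_type": "markdown",
   "id": "b7e41c20",
   "metadata": {},
   "source": [
    "To make the selection rule concrete, the cell below is a minimal sketch of the idea, not SymTorch's internal implementation. It ranks the dimensions of some stand-in MLP outputs by their standard deviation across samples and builds a 0/1 mask that keeps the top `k`; multiplying by the mask zero-masks the rest. The synthetic outputs, `k`, and the helper name `top_k_std_mask` are illustrative assumptions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7e41c21",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Hypothetical helper for illustration; not part of the SymTorch API\n",
    "def top_k_std_mask(outputs, k):\n",
    "    # outputs has shape (num_samples, num_dims); rank dims by their std across samples\n",
    "    stds = outputs.std(dim=0)\n",
    "    keep = torch.topk(stds, k).indices\n",
    "    mask = torch.zeros(outputs.shape[1])\n",
    "    mask[keep] = 1.0\n",
    "    return mask\n",
    "\n",
    "# Stand-in for MLP outputs on 1000 sample inputs: 6 dims with different spreads\n",
    "gen = torch.Generator().manual_seed(0)\n",
    "scales = torch.tensor([0.1, 2.0, 0.5, 3.0, 0.2, 0.3])\n",
    "demo_outputs = torch.randn(1000, 6, generator=gen) * scales\n",
    "\n",
    "demo_mask = top_k_std_mask(demo_outputs, k=2)\n",
    "masked_outputs = demo_outputs * demo_mask  # zero-mask the less important dims\n",
    "print('std per dim:', demo_outputs.std(dim=0))\n",
    "print('mask:', demo_mask)  # keeps dims 1 and 3, the two largest spreads"
   ]
  },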
  {
   "cell_type": "markdown",
   "id": "a92f4a7a",
   "metadata": {},
   "source": [
    "## Wrapping a PyTorch model\n",
    "First, create a simple PyTorch model: an MLP `f_net` whose output feeds a linear layer `g_net`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "adc61e0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "\n",
    "class MLP(nn.Module):\n",
    "    \"\"\"\n",
    "    Simple MLP.\n",
    "    \"\"\"\n",
    "    def __init__(self, input_dim, output_dim, hidden_dim):\n",
    "        super(MLP, self).__init__()\n",
    "        self.mlp = nn.Sequential(\n",
    "            nn.Linear(input_dim, hidden_dim),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(0.2),\n",
    "            nn.Linear(hidden_dim, hidden_dim),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(0.2),\n",
    "            nn.Linear(hidden_dim, hidden_dim),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(0.2),\n",
    "            nn.Linear(hidden_dim, output_dim)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.mlp(x)\n",
    "\n",
    "class SimpleModel(nn.Module):\n",
    "    \"\"\"\n",
    "    Model with MLP f_net and linear g_net.\n",
    "    \"\"\"\n",
    "    def __init__(self, input_dim, output_dim, output_dim_f=32, hidden_dim=128):\n",
    "        super(SimpleModel, self).__init__()\n",
    "\n",
    "        self.f_net = MLP(input_dim, output_dim_f, hidden_dim)\n",
    "        # g is linear: after pruning, it only learns to combine the 2 unmasked outputs of f\n",
    "        self.g_net = nn.Linear(output_dim_f, output_dim)  # masked dims of f contribute nothing\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.f_net(x)\n",
    "        x = self.g_net(x)\n",
    "        return x"
   ]
  },
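  {
   "cell_type": "markdown",
   "id": "c3a95d10",
   "metadata": {},
   "source": [
    "Because `g_net` is linear, zero-masking an output of `f_net` removes that dimension's contribution entirely, so after pruning only the surviving dimensions matter. The sketch below checks this with a random stand-in for the `f_net` outputs and a hypothetical mask that keeps dimensions 3 and 17; in practice the kept dimensions are chosen by the pruning step."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3a95d11",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "# Stand-ins: a linear g and a batch of 4 outputs from a 32-dim f\n",
    "demo_g = nn.Linear(32, 1)\n",
    "demo_f_out = torch.randn(4, 32)\n",
    "\n",
    "# Hypothetical mask keeping dims 3 and 17 (in practice chosen by the pruning step)\n",
    "demo_mask = torch.zeros(32)\n",
    "demo_mask[[3, 17]] = 1.0\n",
    "\n",
    "with torch.no_grad():\n",
    "    full = demo_g(demo_f_out * demo_mask)\n",
    "    # Only the two kept columns of g's weight matrix contribute\n",
    "    reduced = demo_f_out[:, [3, 17]] @ demo_g.weight[:, [3, 17]].T + demo_g.bias\n",
    "print(torch.allclose(full, reduced))  # True"
   ]
  },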
  {
   "cell_type": "markdown",
   "id": "6941215c",
   "metadata": {},
   "source": [
    "Next, train the model on synthetic data from a composite function $y=g(f(\\mathbf{x}))$, with five inputs $x_0, \\dots, x_4$ drawn uniformly from $[0, 1]$. The inner function is\n",
    "\n",
    "$$\n",
    "\\begin{aligned}\n",
    "& f_0 = x_0^2\\\\\n",
    "& f_1 = \\sin{x_4}\n",
    "\\end{aligned}\n",
    "$$\n",
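    "so only $x_0$ and $x_4$ affect the target; the other three inputs are distractors. The outer function $g$ is a linear combination of $f_0$ and $f_1$:\n",
    "\n",
    "$$\n",
    "g(\\mathbf{f}) = 2.5f_0 - 1.3f_1\n",
    "$$\n",
    "\n",
    "Finally, Gaussian noise with a standard deviation of 5% of the standard deviation of $y$ is added to $y$."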
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5f363579",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Make the dataset: 10,000 samples with 5 features drawn uniformly from [0, 1]\n",
    "x = np.random.uniform(0, 1, size=(10_000, 5))\n",
    "\n",
    "def f_func(x):\n",
    "    f0 = x[:, 0]**2\n",
    "    f1 = np.sin(x[:, 4])\n",
    "    return np.stack([f0, f1], axis=1)\n",
    "\n",
    "def g_func(f_output):\n",
    "    a, b = 2.5, -1.3\n",
    "    return a * f_output[:, 0] + b * f_output[:, 1]\n",
    "\n",
    "# Generate ground truth data\n",
    "f_true = f_func(x)\n",
    "y = g_func(f_true)\n",
    "\n",
    "# Add Gaussian noise with std equal to 5% of std(y)\n",
    "noise = np.random.normal(0, 0.05 * np.std(y), size=len(y))\n",
    "y = y + noise"
   ]
  },
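  {
   "cell_type": "markdown",
   "id": "d8b02e30",
   "metadata": {},
   "source": [
    "Because $g$ is linear, its coefficients can be recovered from the ground-truth $f$ values by ordinary least squares, which is a quick way to confirm the data match the equations above. The check below is a sketch: it regenerates equivalent data with a seeded `numpy.random.default_rng` (an addition for reproducibility, not used in the cell above) under new variable names so the training data are left untouched. It should return coefficients close to $2.5$ and $-1.3$ and an intercept near zero."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8b02e31",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Regenerate equivalent data with a seeded generator so the check is reproducible\n",
    "rng = np.random.default_rng(0)\n",
    "x_check = rng.uniform(0, 1, size=(10_000, 5))\n",
    "f_check = np.stack([x_check[:, 0]**2, np.sin(x_check[:, 4])], axis=1)\n",
    "y_check = 2.5 * f_check[:, 0] - 1.3 * f_check[:, 1]\n",
    "y_check = y_check + rng.normal(0, 0.05 * np.std(y_check), size=len(y_check))\n",
    "\n",
    "# Ordinary least squares of y on [f0, f1, 1]\n",
    "A = np.column_stack([f_check, np.ones(len(y_check))])\n",
    "coef, *_ = np.linalg.lstsq(A, y_check, rcond=None)\n",
    "print('a, b, intercept =', np.round(coef, 3))"
   ]
  },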
  {
   "cell_type": "markdown",
   "id": "24d25b3e",
   "metadata": {},
   "source": [
    "Next, wrap `f_net` in a `SymbolicModel` and configure the pruning schedule."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8028b74a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Detected IPython. Loading juliacall extension. See https://juliapy.github.io/PythonCall.jl/stable/compat/#IPython\n"
     ]
    }
   ],
   "source": [
    "from symtorch import SymbolicModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2023b11b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "✅ Pruning successfully set up for block f_net.\n",
      "   Initial dimensions: 32\n",
      "   Target dimensions: 2\n",
      "   Total steps: 31250\n",
      "   Pruning will complete at step 15625\n"
     ]
    }
   ],
   "source": [
    "# Create the model and wrap f_net in a SymbolicModel so it can be pruned; g_net stays linear\n",
    "model = SimpleModel(input_dim=x.shape[1], output_dim=1, output_dim_f=32)\n",
    "model.f_net = SymbolicModel(model.f_net, block_name=\"f_net\")\n",
    "\n",
    "# Set up pruning\n",
    "epochs = 100\n",
    "batch_size = 32\n",
    "# Approximate step count: uses the full dataset, although training below uses a 90% split\n",
    "steps_per_epoch = x.shape[0] / batch_size\n",
    "total_steps = epochs * steps_per_epoch\n",
    "model.f_net.setup_pruning(initial_dim=32,\n",
    "                          target_dim=2,  # dims remaining at the end of pruning\n",
    "                          end_step_frac=0.5,  # finish pruning ~50% of the way through total_steps\n",
    "                          decay_rate='exp',\n",
    "                          total_steps=int(total_steps))"
   ]
  },
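  {
   "cell_type": "markdown",
   "id": "e5c17f40",
   "metadata": {},
   "source": [
    "The exact formula behind `decay_rate='exp'` is internal to SymTorch and is not shown in this notebook. As a hypothetical illustration only, one simple exponential schedule interpolates geometrically from `initial_dim` to `target_dim`, reaching the target at `end_step_frac * total_steps` and holding it afterwards. The function name `exp_schedule` and its rounding are our own assumptions, so its values need not match the active-dimension counts in the training log exactly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5c17f41",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical schedule for illustration; SymTorch's own formula may differ\n",
    "def exp_schedule(step, initial_dim=32, target_dim=2, total_steps=31250, end_step_frac=0.5):\n",
    "    end_step = int(end_step_frac * total_steps)\n",
    "    if step >= end_step:\n",
    "        return target_dim\n",
    "    # Geometric interpolation: before rounding, dims shrink by a constant factor per step\n",
    "    dims = initial_dim * (target_dim / initial_dim) ** (step / end_step)\n",
    "    return max(target_dim, round(dims))\n",
    "\n",
    "for s in [0, 3125, 7812, 15625, 31250]:\n",
    "    print(s, exp_schedule(s))"
   ]
  },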
  {
   "cell_type": "markdown",
   "id": "213f5b8b",
   "metadata": {},
   "source": [
    "## Training our model and dynamically reducing dimensionality"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "f495823a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up training\n",
    "\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "def train_model(model, dataloader, X_val, opt, criterion, epochs=100):\n",
    "    \"\"\"\n",
    "    Train model with MLP f (with pruning) and linear g_net.\n",
    "    \n",
    "    Args:\n",
    "        model: PyTorch model to train\n",
    "        dataloader: DataLoader for training data\n",
    "        X_val: Validation inputs (or a subset) used to rank dimensions during pruning\n",
    "        opt: Optimizer\n",
    "        criterion: Loss function\n",
    "        epochs: Number of training epochs\n",
    "        \n",
    "    Returns:\n",
    "        tuple: (trained_model, loss_tracker, active_dims_tracker)\n",
    "    \"\"\"\n",
    "    loss_tracker = []\n",
    "    active_dims_tracker = []\n",
    "    \n",
    "    step = 0\n",
    "    for epoch in range(epochs):\n",
    "        epoch_loss = 0.0\n",
    "    \n",
    "        for batch_x, batch_y in dataloader:\n",
    "            # Forward pass\n",
    "            pred = model(batch_x)\n",
    "            loss = criterion(pred, batch_y)\n",
    "            # Backward pass\n",
    "            opt.zero_grad()\n",
    "            loss.backward()\n",
    "            opt.step()\n",
    "            \n",
    "            epoch_loss += loss.item()\n",
    "\n",
    "            # Update the pruning mask: sample_data is the validation set (or a subset of it),\n",
    "            # and parent_model is passed so the correct inputs reach f_net\n",
    "            model.f_net.prune(step, sample_data=X_val, parent_model=model)\n",
    "\n",
    "            step += 1\n",
    "\n",
    "        loss_tracker.append(epoch_loss)\n",
    "        active_dims_tracker.append(model.f_net.pruning_mask.sum().item())\n",
    "\n",
    "        if (epoch + 1) % 10 == 0:\n",
    "            avg_loss = epoch_loss / len(dataloader)\n",
    "            active_dims = model.f_net.pruning_mask.sum().item()\n",
    "            print(f'Epoch [{epoch+1}/{epochs}], Avg Loss: {avg_loss:.6f}, Active dims: {active_dims}')\n",
    "            \n",
    "    return model, loss_tracker, active_dims_tracker\n",
    "\n",
    "# Loss function and optimiser\n",
    "criterion = nn.MSELoss()\n",
    "opt = optim.Adam(model.parameters(), lr=0.001)\n",
    "# Hold out 10% of the data for validation (also used as sample data for pruning)\n",
    "X_train, X_val, y_train, y_val = train_test_split(\n",
    "    x, y.reshape(-1,1), test_size=0.1, random_state=290402)\n",
    "\n",
    "# Wrap the training split in a DataLoader\n",
    "dataset = TensorDataset(torch.FloatTensor(X_train), torch.FloatTensor(y_train))\n",
    "dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "add838bc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting training...\n",
      "Epoch [10/100], Avg Loss: 0.002543, Active dims: 18\n",
      "Epoch [20/100], Avg Loss: 0.002195, Active dims: 11\n",
      "Epoch [30/100], Avg Loss: 0.002210, Active dims: 6\n",
      "Epoch [40/100], Avg Loss: 0.002022, Active dims: 4\n",
      "Epoch [50/100], Avg Loss: 0.001870, Active dims: 2\n",
      "Epoch [60/100], Avg Loss: 0.001901, Active dims: 2\n",
      "Epoch [70/100], Avg Loss: 0.001836, Active dims: 2\n",
      "Epoch [80/100], Avg Loss: 0.001824, Active dims: 2\n",
      "Epoch [90/100], Avg Loss: 0.001841, Active dims: 2\n",
      "Epoch [100/100], Avg Loss: 0.001829, Active dims: 2\n",
      "Training completed!\n"
     ]
    }
   ],
   "source": [
    "# Train the model and save the weights\n",
    "print(\"Starting training...\")\n",
    "model, losses, active_dims = train_model(model, dataloader, torch.FloatTensor(X_val), opt, criterion, 100)\n",
    "print(\"Training completed!\")\n",
    "torch.save(model.state_dict(), 'model_weights.pth')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e21b32eb",
   "metadata": {},
   "source": [
    "Let's see how the number of active dimensions decreases as training progresses."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "63359981",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjIAAAGwCAYAAACzXI8XAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjMsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvZiW1igAAAAlwSFlzAAAPYQAAD2EBqD+naQAASvdJREFUeJzt3Qd4lFX2+PFDegfSSKUjgdCDIFUFVprSUVlWAVlZFRXBXpFdFeztZ/nrirh26YgC0hGkGHqvoaQRakghCSTzf+6NGROamWQm75Tv53nezTvvDDPHd5U5uffce6qZTCaTAAAAOCA3owMAAACoKBIZAADgsEhkAACAwyKRAQAADotEBgAAOCwSGQAA4LBIZAAAgMPyECdXVFQkqampEhgYKNWqVTM6HAAAUA5qm7usrCyJiooSNzc3101kVBITGxtrdBgAAKACjh07JjExMa6byKiRmJIbERQUZHQ4AACgHM6dO6cHIkq+x102kSmZTlJJDIkMAACO5a/KQij2BQAADotEBgAAOCwSGQAA4LBIZAAAgMMikQEAAA6LRAYAADgsEhkAAOCwSGQAAIDDIpEBAAAOi0QGAAC4TiJz9OhR3ZHyUuqaes4SH330kbRo0cLcPqBDhw6yYMEC8/N5eXkyduxYCQkJkYCAABk8eLAcP37c0pABAICTsjiRqVevnpw4ceKy66dPn9bPWUJ1s5wyZYps3LhREhMTpVu3btK/f3/ZuXOnfn78+PHy448/yvTp02XlypW6k/WgQYMsDRkAADipaqYrDa9cg5ubmx4VCQsLK3P9yJEj0rRpU8nJyalUQMHBwfL666/LkCFD9Gd88803+lzZs2ePNGnSRNauXSs33HBDubtnVq9eXTIzM63aNLKoyCRJp3Kkuq+nhAZ4W+19AQCAlPv7u9zdrydMmGDuQvn888+Ln5+f+bnCwkJZv369tGrVqsIBq/dQIy8qEVJTTGqU5sKFC9KjRw/za+Li4qR27drXTGTy8/P1UfpG2MLYbzbJgh3p8uJtTWVkJ8tGogAAgHWUO5HZvHmz/qkGcLZv3y5eXl7m59R5y5Yt5bHHHrM4APVeKnFR9TCqDmb27Nl6ZGfLli36fWvUqFHm9bVq1ZL09PSrvt/kyZNl0qRJYmuNagXqRGZnqm0SJQAAYMVEZvny5frnqFGj5N1337XaNE3jxo110qKGjmbMmCEjRozQ9TAV9fTTT5tHj0pGZGJjY8XamkYW//OTyAAA4ACJTInPP//cqgGoUZeGDRvq84SEBPn99991onTHHXdIQUGBnD17tsyojKrPiYiIuOr7eXt768PW4qOKE5n9GVlScLFIvDxYyQ4AgN0nMmpl0bUsW7asMvFIUVGRrnFRSY2np6csXbpUL7tW9u7dq5d4q6koo8XU9JUgHw85l3dRJzPxUdWNDgkAAJdjcSKjamFKUwW5ampox44delrI0mmg3r176wLerKwsvUJpxYoVsmjRIl2pPHr0aD1NpFYyqamshx56SCcx5V2xZEuq6LlpVJCsO3RadqWeI5EBAMAREpm33377itdffPFFyc7Otui9MjIy5O6775a0tDSduKjN8VQS87e//c38WWq5txqRUaM0PXv2lA8//FDshUpeVCKj6mSGGh0MAAAuyOJ9ZK7mwIED0q5dO70xnj2x1T4yysyNyfLo9K3Srl6w/PAv46e7AABwFuX9/rZahara28XHx0dcSXx08Y3dnXpOb5AHAADsfGrp0hYBakBHTQ2pFgNqozxX0iAsQK9Wysq/KMfO5EqdEH+jQwIAwKVYnMioYZ7SVA2L2gvm3//+t9xyyy3iSjzd3aRxrUDZnpKpC35JZAAAcLF9ZByd2hhPJTKq4Ld380ijwwEAwKVYnMiUUFNJu3fv1ueqpYDa98UV6TqZRJFdaezwCwCA3ScyycnJMmzYMFmzZo15x121+27Hjh3lu+++k5iYGHElJTv87kzNNDoU
AABcjsWrlv75z3/qTfDUaIxaaq0Oda525FXPuZq4iCCpVk3k+Ll8OZn9Z9dtAABgh4mMauj40Ucf6QLfEur8/fffl1WrVomr8ff2kHp/FPmqgl8AAGDHiYzqJK1GZC5VWFgoUVFR4oqamKeXSGQAALDrROb111/XPY9UsW8JdT5u3Dh54403xBWV1MlQ8AsAgJ23KKhZs6bk5ubKxYsXxcOjuFa45Nzfv+w+KvbQrsCWLQpKrNx3QkZM3SD1w/xl2aM32eQzAABwJefK+f1doaaRqvMzyu4loySdzJHcgovi51XhVe0AAMACFn/jjhw50tI/4vTCAr0lPNBbMrLyZXdaliTUqWl0SAAAuASLa2Tc3d0lIyPjsuunTp3Sz7mqptTJAABg/4nM1Upq8vPzxcvLS1yVueCXjfEAALC/qaX33ntP/1T1Mf/9738lICCgzNJrtYdMXFycuKr4qOJmmizBBgDADhMZVeRbMiLz8ccfl5lGUiMxdevW1dddveB3T3qW5F0oFB9P151mAwDA7hKZpKQk/fPmm2+WWbNm6WXY+FPtYD+JrO4jaZl5smT3cbm1hWtuDggAgF3XyCxfvpwk5grc3KrJoDbR+nx6YrLR4QAA4BIsXn59zz33XPP5qVOniqsamhArHyw/KL/uPyHpmXkSUd3H6JAAAHBqFicyZ86cKfNY9V3asWOHnD17Vrp16yaurG6ov7SrGywbDp+WmZuSZezNDY0OCQAAp2ZxIjN79uzLrhUVFcn9998vDRo0EFc3pG2MTmSmJx6TB25qwC7IAADYU43MFd/EzU0mTJhgXtnkyvo2jxQ/L3c5fCpXEo+UHb0CAAB2mMgoBw8e1M0jXZ2/t4f0aR6pz9WoDAAAsKOpJTXyUpraVyYtLU1++uknGTFihDVjc1i3t42VGRuT5adtafJiv3iaSAIAYCMWf8Nu3rz5smmlsLAwefPNN/9yRZOruL5uTakb4qenl37eni5DEmKMDgkAAKfkUZF9ZHBtqsBXJS9v/LJPfkg8RiIDAIC91cicOHFCVq9erQ91jrIGtYkRtWBpQ9JpOXIqx+hwAABwShYnMjk5OXoKKTIyUrp27aqPqKgoGT16tOTm5tomSgcUVcNXOjcM1eezNqUYHQ4AAE7JrSLFvitXrpQff/xRb4Knjrlz5+prjz76qG2idFC9mkXon5uOsgwbAAC7qJGZOXOmzJgxQ2666SbztT59+oivr6/cfvvt8tFHH1k7RofVLKq6/rkz9Zxe3cXmeAAAGDwio6aPatWqddn18PBwppYu0TgiUNzdqsnpnAI5fi7f6HAAAHA6FicyHTp0kIkTJ0peXp752vnz52XSpEn6OfzJx9NdGoT56/OdqZlGhwMAgNOxeGrp3XfflZ49e0pMTIy0bNlSX9u6dav4+PjIokWLbBGjQ4uPqi77jmfr6aXuTS4fyQIAAFWYyDRr1kz2798vX3/9tezZs0dfGzZsmAwfPlzXyaCsppFBMntziuxKPWd0KAAAOJ0K7Z3v5+cn9957r/WjcULxUUH65840ppYAALDbppG4sqZ/JDLHTp+XzPMXjA4HAACnQiJjYzX8vCS6RvGU2+40ppcAALAmEpkqHJWhTgYAAOsikamigl9FrVwCAAAGFft+//33Mm/ePCkoKJDu3bvLfffdZ8VQnL/gdxdTSwAAGJPIqNYDY8eOlUaNGull1rNmzZKDBw/K66+/bt2InFB8dHGrgv3HsyT/YqF4e7gbHRIAAK41tfR///d/ekffvXv3ypYtW+SLL76QDz/80LbROYmo6j5S3ddTLhaZZP/xbKPDAQDA9RKZQ4cOyYgRI8yP//73v8vFixclLS3NVrE5DdUs0jy9RJ0MAABVn8jk5+eLv7//n3/QzU28vLx0nyVYUvDLxngAABhS7Pv888/rXX1LqKLfl19+WapXL64BUd566y2rBedM4qMp+AUAwLBEpmvXrro+prSOHTvqKSeUr3lkydRSUZFJ
3NyqGR0SAACuk8isWLHCtpE4ufqh/uLt4SY5BYVy9HSu1A39c5oOAAAYvCGeGpm55ZZbrPV2TsfD3U3iIgL1ORvjAQBgZ4lMVlaWLF261KI/M3nyZLn++uslMDBQwsPDZcCAAZdNX91000161U/pw1E34jO3KqATNgAAjt+iYOXKlXqTvXXr1snixYvlwoULelQnJyenzOvuvfdevcy75HjttdfEETX9o06GERkAAAxYtWRtCxcuLPN42rRpemRm48aNuri4hFopFRERIY6uZC8ZEhkAAJywaWRmZvGUS3BwcJnrX3/9tYSGhkqzZs3k6aefltzc3Gvud3Pu3Lkyh71QNTLVqomcyMqX9Mw8o8MBAMB1RmRat26t61Ou5lrJRXkUFRXJI488Ip06ddIJS+kdhOvUqSNRUVGybds2efLJJ3Udjer1dLW6m0mTJok98vPykJYxNWTLsbMyf1uq/LNLfaNDAgDANRIZVYhrS6pWZseOHbJ69eoy18eMGWM+b968uURGRurO26phZYMGDS57HzViM2HCBPNjNSITGxsr9mJwQoxOZKYnJsvozvWumRwCAAArJTKqYaStPPjggzJ//nxZtWqVxMTEXPO17du31z8PHDhwxUTG29tbH/aqX8soeWn+Ltl7PEu2p2RKi5gaRocEAIDDMrRGxmQy6SRm9uzZsmzZMqlXr95f/hnVeVtRIzOOSHXB7hlfXLisRmUAAEAVjMh069atXK9TCYkl00nffPONzJ07V+8lk56erq+r3k2+vr56+kg936dPHwkJCdE1MuPHj9crmlq0aCGOamjbGJm3NVXmbkmRZ/s2ER9Pd6NDAgDA+VsUqKLbvn37iqenp1U+/KOPPjJvelfa559/LiNHjtTdtZcsWSLvvPOO3ltG1boMHjxYnnvuOXFkHRuESlR1H0nNzJNfdh3X000AAMBy1UxqfqccXn/9dZ1gnDp1SoYPHy733HNPmdVF9koV+6oRHrW0OyioeB8Xe/DmL3vl/WUHpEujUPlydHHdDwAAsOz7u9w1Mo8//rjs2rVL5syZo9sRqGXS7dq1k48//tiu9mpxFEMSiouaVx84KalnzxsdDgAArlHs26FDB/n00091qwBV4zJ16lS9xwvJjGXqhPhL+3rBosbDZm2i6BcAgCpdtbRp0ybdK2n37t16isladTOuZGjb4v1tpm9M1iu4AACADROZ1NRUeeWVV+S6666TIUOG6FYC69ev100f1SojWKZP8wjx93KXI6dyZUPSaaPDAQDAeRMZtQRabUCnEhdV+JucnCxvvPGGNG3a1LYROjHVsqBvi0jzqAwAALDRqiU3Nze9CZ3qTn2tbfXVlJM9sddVSyV+P3xahn68Vvy83OX3Z3uIv7ehDckBAHCo72+7aFHgytrWqSn1Qv0l6WSO/LQ9TW7/o24GAAD8NRIZg6nRLbUU+/VFe2VGYjKJDAAAjtJrCcUGt4kRt2oiGw6f1iMzAACgfEhk7EBEdR/p0ihMn8/YeMzocAAAcBgkMnbUSFKZuTFFCovYUwYAgPIgkbETPZrUkuq+npJ+Lk+3LQAAAFZKZNTGdydPFn+5qmaRqtcSrMvH0136tyrugj09keklAACslsgUFBSYeyl98cUXkpeXV643h2VKViz9suu4nM0tMDocAACcY/m1ahQ5YMAASUhI0D2BHn744au2JFBNJFEx8VFBEhcRKHvSs2Te1lS5u0Ndo0MCAMDxR2S++uor3aIgOztb73uidtk7c+bMFQ9UnLq35kaSibQsAADAai0KStSrV08SExMlJCREHIG9tyi41KnsfGn/ylK5WGSShY90kbgI+48ZAACjvr8tXrWUlJTkMEmMIwoJ8JbuTcL1OaMyAABcG8uv7dDQhOLppTmbU+RCYZHR4QAAYLdIZOzQTY3DJCzQW07lFMiyPRlGhwMAgN0ikbFDHu5uMqh1tD5nTxkAAK6ORMbOWxYs33tCMrLYtwcAAKslMgcPHpTn
nntOhg0bJhkZxVMfCxYskJ07d1bk7XAFDcMDpVVsDd13SdXKAAAAKyQyK1eulObNm8v69etl1qxZem8ZZevWrTJx4kRL3w7lGJVRq5csXCUPAIBLsDiReeqpp+Sll16SxYsXi5eXl/l6t27dZN26ddaOz6Xd1jJKvD3cZH9GtmxNzjQ6HAAAHD+R2b59uwwcOPCy6+Hh4ebGkrCOIB9P6d0sQp9T9AsAgBUSmRo1akhaWtpl1zdv3izR0cUrbWA9JS0LVO+lvAuFRocDAIBjJzJ33nmnPPnkk5Kenq57AxUVFcmaNWvksccek7vvvts2UbqwDvVDJLqGr2TlXZRFO9ONDgcAAMdOZF555RWJi4uT2NhYXejbtGlT6dq1q3Ts2FGvZIJ1ublVk8EJxUW/MzexegkAgEo1jSxx9OhR2bFjh05mWrduLY0aNRJ75GhNI69kV+o56fPerxLg7SHbJt6ikxsAAJxZeb+/PSr6AbVr19YHbK9RrQDx8nCT7PyLcuxMrtQJ8Tc6JAAA7ILFiUxhYaFMmzZNli5dqjfDUzUypS1btsya8UFEPN3dpHGtQNmekqlHZ0hkAACoYCIzbtw4ncj07dtXmjVrpgt+YXtNI4N0IrMz9Zz0bh5pdDgAADhmIvPdd9/JDz/8IH369LFNRLii+OggkUSRnalsjAcAQIVXLandfBs2bGjpH4MVRmSUXWnnjA4FAADHTWQeffRReffdd+n9U8WaRAaJmsU7fi5fTmbnGx0OAACOM7U0aNCgywp6Vbfr+Ph48fT0LPOcaiQJ6/P39pB6If5y6GSOLvjtel2Y0SEBAOAYiYxax13alXotwfaaRAXpREYV/JLIAABQzkTm888/t30k+EvxUUHy07Y06mQAAKhojUy3bt3k7NmzV9yBTz0H2xf8snIJAIAKJjIrVqyQgoKCy67n5eXJr7/+aunbwQLxUcVTfEkncyS34KLR4QAA4Dj7yGzbts18vmvXLt39uvRuvwsXLpTo6GjrRwizsEBvCQ/0loysfNmdliUJdWoaHRIAAI6RyLRq1Urv4quOK00h+fr6yvvvv2/t+HCJplFBkrH3hOxKzSSRAQC4vHInMklJSXrvmPr168uGDRskLCyszCZ54eHh4u7ubqs4Uargd4VKZCj4BQCg/IlMnTp19M9Lm0SiajWNLK6TUUuwAQBwdRYX+8L4ERllT3qWXCwkqQQAuDYSGQdTO9hPArw9pOBikRw8kWN0OAAAGIpExsG4uVWTJpGB+nxXGvvJAABcm0WJjFpmvWrVqituiIeq309mZwp1MgAA12ZRIqNWJd1yyy1y5swZq3z45MmT5frrr5fAwEC96mnAgAGyd+/eyzbaGzt2rISEhEhAQIAMHjxYjh8/Lq6sZIdfVi4BAFydxVNLzZo1k0OHDlnlw1euXKmTlHXr1snixYvlwoULOlHKyfmz9mP8+PHy448/yvTp0/XrU1NTL+vG7Yp7yZSsXFJL4gEAcFXVTBZ+E6odfJ9++mn5z3/+IwkJCeLv71/m+aCg4i/Zijhx4oQemVEJS9euXSUzM1PvV/PNN9/IkCFD9Gv27NkjTZo0kbVr18oNN9zwl++pekCp7t3qvSoTmz1Rhb7xExfKhUKTrH7yZomp6Wd0SAAAWFV5v7/LvY9MiT59+uif/fr107v8llD5kHqs6mgqSgWrBAcH658bN27UozQ9evQwvyYuLk5q16591UQmPz9fH6VvhLPx8nCThuGBsjvtnGxPziSRAQC4LIsTmeXLl9skELXR3iOPPCKdOnXS01eK6uekdg2uUaNGmdfWqlWrTK+nS+tuJk2aJM7uhvrBOpGZvy1NejePNDocAAAcI5G58cYbbRKIqpXZsWOHrF69ulLvo6a9JkyYUGZEJjY2VpzN4DYx8vmaw7J413E5m1sgNfy8jA4JAAD7T2QUtfz6s88+k927d+vH8fHxcs899+i5rIp48MEHZf78+Xppd0xMjPl6RESEFBQU
6M8rPSqjVi2p567E29tbH86uWXR1vXpJrVyauyVVRnSsa3RIAADY/6qlxMREadCggbz99tty+vRpfbz11lv62qZNmyx6L1VXo5KY2bNny7Jly6RevXplnlfFxJ6enrJ06VLzNbU8++jRo9KhQwdxdUPbFid90zceMzoUAAAcY9VSly5dpGHDhvLpp5+Kh0fxgM7Fixfln//8p16WrUZVyuuBBx7QK5Lmzp0rjRs3Nl9XIzu+vr76/P7775eff/5Zpk2bpquWH3roIX39t99+K9dnOOOqpRKncwqk/StL9Oqlnx/uYl6WDQCAoyvv97fFiYxKMDZv3qxXD5W2a9cuadu2reTm5pb7vUqveirt888/l5EjR5o3xHv00Ufl22+/1auRevbsKR9++OFVp5ZcKZFR7v9qoyzYkS6jOtWVibfFGx0OAABWUd7vb4unltSbqamdSx07dkzv0GsJlUNd6ShJYhQfHx/54IMP9BSW2ihv1qxZ5U5iXGl6SdXJqP1lAABwJRYnMnfccYeMHj1avv/+e528qOO7777TU0vDhg2zTZS4qq6NwiQ80FtPMy3b49qtGwAArsfiVUtvvPGGnhK6++67dW2MogpyVS3LlClTbBEjrsHD3U0GtYmRj1celOmJydKrGXvKAABcR7lqZLZt26Y3qXNz+3MAR9XCHDx4UJ+rFUt+fva5u6yz18goB09kS/c3V4pbNZF1T3eX8CAfo0MCAMB+amRat24tJ0+e1Of169eXU6dO6cSlefPm+rDXJMZVNAgLkDa1a0iRSWTW5hSjwwEAoMqUK5FRm9ElJSXp88OHD+t2ArAvQ9sW7148PfEYHbEBAC6jXDUygwcP1q0JIiMjdX2MWmbt7u5+xdeqvWRQ9W5tESmTftwpB0/kyOZjZ6VN7ZpGhwQAgH0kMp988okMGjRIDhw4IA8//LDce++9Fi+1hm0F+nhK72aRMntzii76JZEBALiCcq9a6tWrl/65ceNGGTduHImMHVJ7yqhEZv7WVHnh1qbi63XlUTMAAFx2Hxm16y5JjH26oV6IxNT0laz8i7JwZ5rR4QAAYH+JDOyXm1s1GZLwRyPJxGSjwwEAwOZIZJzM4DbFicxvB0/JsdPl73sFAIAjIpFxMrHBftKxQYg+n7mJURkAgHMjkXHiRpIzNiZLkdolDwAAJ2VxIvPFF1/ITz/9ZH78xBNP6A3zOnbsKEeOHLF2fKiAXvGREujtIclnzsu6pFNGhwMAgP0kMq+88or4+vrq87Vr18oHH3wgr732moSGhsr48eNtESMspJZd39oySp9T9AsAcGYWJzLHjh2Thg0b6vM5c+boXX/HjBkjkydPll9//dUWMaIS00sLdqTJubwLRocDAIB9JDIBAQG6aaTyyy+/yN/+9jd97uPjI+fPn7d+hKiQ1rE1pEGYv+RdKJKftrGnDADAOVmcyKjE5Z///Kc+9u3bJ3369NHXd+7cKXXr1rVFjKgA1ROrdCNJAACckcWJjKqJ6dChg5w4cUJmzpwpISEh5tYFw4YNs0WMqKBBraPF3a2abDp6Vg5kZBsdDgAAVlfNZDI59frcc+fOSfXq1SUzM1OCgoLE1Yye9rss3ZMh993YQJ7qHWd0OAAAWPX7u9xNI0s7e/asbNiwQTIyMqSoqKjMdMZdd91VkbeEDYt+VSKjNsd77JbrxMOdrYMAAM7D4kTmxx9/lOHDh0t2drbOkFTyUoJExv50i6slwf5eciIrX1btP6EfAwDgLCz+9fzRRx+Ve+65RycyamTmzJkz5uP06dO2iRIV5uXhJv1bsacMAMA5WZzIpKSkyMMPPyx+fn62iQhWNzShePXSkt3H5XROgdHhAABgXCLTs2dPSUxMtF4EsLmmUUESHxUkFwpNMndLitHhAABgXI1M37595fHHH5ddu3ZJ8+bNxdPTs8zz/fr1s150sJrb28bKxHk79fTSqE71jA4HAABjll+7uV19EEcV+xYWFoo9cfXl1yXO
5hZIu5eXSkFhkcx/qLM0i65udEgAAFT6+9viqSW13Ppqh70lMfhTDT8v+VvT4hVLMzZS9AsAcA5sKuJChvzRSHLOlhTJv0jSCQBw0URm5cqVctttt+ku2OpQdTF0vrZ/XRuFSa0gbzmbe0GW7s4wOhwAAKo+kfnqq6+kR48eevm1WoatDl9fX+nevbt88803lY8INqP6Lg1qUzwqQyNJAIBLFvs2adJExowZI+PHjy9z/a233pJPP/1Udu/eLfaEYt+yDp3Ilm5vrhS3aiK/PdVdIqr7GB0SAABVV+x76NAhPa10KTW9lJSUZOnboYrVDwuQtnVqSpFJZNZmin4BAI7N4kQmNjZWli5detn1JUuW6OfgGI0klRmJyeLkzc8BAE7OoyK9llRdzJYtW6Rjx4762po1a2TatGny7rvv2iJGWFnfFlHy4rxdcuhkjmw6ekYS6gQbHRIAAFWTyNx///0SEREhb775pvzwww/mupnvv/9e+vfvX7EoUKUCvD2kd/MImbUpRe/0SyIDAHCZYl9HQ7Hvla09eEqGfbpO/L3c5ffneoifl8U5LQAAjlfsC+fQvl6w1A72k5yCQlmwPd3ocAAAqJByJTLBwcFy8uRJfV6zZk39+GoHHIObWzUZkvDHnjIb2VMGAOCYyjWf8Pbbb0tgYKD5XDWHhOMbnBAjby/ZJ+sOndZFv+GB3ubnQvy9xdfL3dD4AAD4K9TIuLh//He9rD5QPNpWWrC/lyyZcKP+CQCA09TIbNq0SbZv325+PHfuXBkwYIA888wzUlBQUPGIYYgHbmqgkxVvDzfzoXb9PZ1TILM3pxgdHgAA12RxIvOvf/1L9u3bZ97l94477tB9l6ZPny5PPPGEpW8Hg3VsGCqbnv+b7H2pt/mY1C/e3I/JyQfsAACulsioJKZVq1b6XCUvN954o24WqTbEmzlzpi1iRBXr1zJavDzcZE96luxMPWd0OAAAWC+RUb+hFxUVmdsS9OnTR5+r9gQlK5vg2Kr7ecotTWvp8x/okg0AcKZEpm3btvLSSy/Jl19+KStXrpS+ffvq66phZK1axV9+cHxD2xb3zZq7JVXyLhQaHQ4AANZJZN555x1d8Pvggw/Ks88+Kw0bNtTXZ8yYYe69BMfXuWGoRFb3kczzF2TJ7uNGhwMAgG2XX+fl5Ym7u7t4enqKPWH5dcW9sWiv/N/yA3LjdWHyxT3tjA4HAOBCztm6RYFaap2cnCxHjx7VR0ZGhqSlpVX07WCHSnb+/XX/CUnPzDM6HAAArLNqqUuXLuLr6yt16tSRevXq6aNu3br6pyVWrVolt912m0RFRendgufMmVPm+ZEjR+rrpY9evXpZGjIqqG6ov7SrGyxFJpGZm5KNDgcAgMtY3PJ41KhR4uHhIfPnz5fIyMhKtSvIycmRli1byj333CODBg264mtU4vL555+bH3t7/7mNPmxvSNsY2XD4tN5TRm2eR3sKAIBDJzJbtmyRjRs3SlxcXKU/vHfv3vq4FpW4REREVPqzUDF9m0fKi/N2yuFTuZJ45IxcX5fGoAAAB55aatq0aZXuF7NixQoJDw+Xxo0by/333y+nTp265uvz8/N1gVDpAxXn7+0hfZpH6nM1KgMAgEMnMq+++qpuRaASDJVU2DJpUNNK//vf/2Tp0qX6c9W+NWoEp7Dw6vuaTJ48WVc5lxxqoz5UztA/in5/2pYmuQUXjQ4HAICKL792cyvOfS6tlVBvo65dK8m4FvVnZ8+erRtQXo3q7dSgQQO9o3D37t2vOiKjjhIquVLJDMuvK079f3vTGyvkyKlceWNoS/NqJgAAjF5+bXGNzPLly8Uo9evXl9DQUDlw4MBVExlVU0NBsHWpJFONyrzxyz49vUQiAwCwFxYnMqpJpFHUvjVqOkutlkLVGtQmRt5cvE/WJ52WI6dypE6Iv9EhAQBQsQ3xfv31V/nHP/6hWxKkpKToa6r30urVqy16n+zsbL0KSh0l/ZrUudpgTz33
+OOPy7p16+Tw4cO6TqZ///66JULPnj0rEjYqIaqGr25boMzYyJ4yAAAHTWRmzpypEwm1IZ7quVRSj6LmsF555RWL3isxMVFat26tD2XChAn6/IUXXtDtDrZt2yb9+vWT6667TkaPHi0JCQk6iWLqyNhGkjM3Jkuh2iUPAABHK/ZVicb48ePl7rvvlsDAQNm6dauuXdm8ebNeUZSeni72hF5L1qO6YLd7eYmcy7soX45uJ10ahRkdEgDASdms19LevXula9eul11XH3b27FnLI4XD8PF0l36tovT59ESmlwAAxrM4kVG77KpVQ5dS9TFqZAbO7fY/ppcW7kyXzNwLRocDAHBxFicy9957r4wbN07Wr1+vl+WmpqbK119/LY899pjeeRfOrXl0dWlcK1AKLhbJvG2pRocDAHBxFi+/fuqpp6SoqEjv45Kbm6unmVTxrUpkHnroIdtECfvaU6ZtjLz0026ZkXhM7rqhjtEhAQBcmMXFviUKCgr0FJNaJq36LwUEBIg9otjX+k5m58sNryyVi0UmWfRIV2kcEWh0SAAAJ2OzYt8SXl5eOoFp166d3SYxsI3QAG+5OS5cn8/YSCNJAIADTS3l5eXJ+++/r1sVZGRk6Gmm0tTeMnB+qmXB4l3HZfbmFHmiV5x4ulc4JwYAoOoSGbUx3S+//CJDhgzRozGXNo+Ea1AjMqEBXnIyu0CW78mQW+IjjA4JAOCCLE5k5s+fLz///LN06tTJNhHBIagRmIGto+XTX5Nk+sZkEhkAgCEsng+Ijo7WO/oCJS0L1IiMKgAGAMDuE5k333xTnnzySTly5IhtIoLDuK5WoLSMqa5XL83ZXNw8FAAAu05k2rZtqwt+1S6+amQmODi4zAHXMuSPUZkfEo9JBVfyAwBQdTUyw4YNk5SUFN3pulatWhT7urh+LaLkP/N3yb7j2bItOVNaxtYwOiQAgAuxOJH57bffZO3atdKyZUvbRASHUt3PU3rFR8i8rakyfeMxEhkAgH1PLcXFxcn58+dtEw0ckmpZoMzbkip5FwqNDgcA4EIsTmSmTJkijz76qKxYsUJOnTqltxAufcD1dGwQKlHVfeRc3kWZsTFZks/kmo/M83TIBgDYUa8lN7fi3OfS2hj1NupaYaF9/UZOr6Wq8eYve+X9ZQcuu+7uVk2+H3ODtK1LITgAwPrf3xbXyKjWBMClhrevo+tk0jPzzNcKi0x6afa03w6TyAAA7Kv7taNgRMY4O1Iy5db3V4uXh5v8/kwPXRgMAECVj8hs27ZNmjVrpqeV1Pm1tGjRolwBwvnFRwVJXESg7EnPknlbU+SuDnWNDgkA4GTKlci0atVK0tPTJTw8XJ+rWpgrDeTYY40MjKP+fVBtDNQ+Mz8kJpPIAACMSWSSkpIkLCzMfA6U14BWUTL5592yPSVT9qSfk7gIpvcAAFWcyNSpU+eK58BfCQnwlu5NwmXRzuMyPTFZnr+1qdEhAQBcLZGZN29eud+wX79+lYkHTuj2trE6kVGNJZ/qHSee7hZvXwQAQMUTmQEDBpR5fGmNTOk9ZaiRwaVuvC5MwgK95URWvizbkyE94yOMDgkA4CTK9atxUVGR+fjll190we+CBQvk7Nmz+vj555+lTZs2snDhQttHDIfj4e4mg1pH63M1vQQAgLVYvCHeI488Ih9//LF07tzZfK1nz57i5+cnY8aMkd27d1stODhXP6b/t+qQLN+bIRlZeRIe6GN0SAAAJ2BxscLBgwelRo3LOxyrTWsOHz5srbjgZBqGB0qr2Bp6t19VKwMAgCGJzPXXXy8TJkyQ48ePm6+p88cff1zatWtnlaDg3F2y1fSSk28oDQCw10Rm6tSpkpaWJrVr15aGDRvqQ52npKTIZ599Zpso4RRuaxkl3h5usj8jW7YmZxodDgDAFWtkVOKi2hQsXrxY9uzZo681adJEevTocVlHbKC0IB9P6dUsQuZuSZXpicf0VBMAAJVB00hUqTUHTsrw/66X
QB8P+f3ZHuLj6W50SAAAB/7+ZmcyVKkO9UMkuoavZOVdlEU7040OBwDg4EhkUKXc3KrJ4IQ/i34BAKgMEhlUuaF/JDJrDp6UlLPnjQ4HAODASGRQ5WKD/eSG+sGiqrNmbmRUBgBQxYmM2hTvueeek2HDhklGRoa+ploW7Ny5sxKhwJUMTYjVP6dvPCZFRU5dbw4AsKdEZuXKldK8eXNZv369zJo1S7Kzs/X1rVu3ysSJE20RI5xQ7+YREuDtIcdOn5f1SaeNDgcA4CqJzFNPPSUvvfSS3kfGy8vLfL1bt26ybt06a8cHJ+Xn5SG3tog0j8oAAFAlicz27dtl4MCBl10PDw+XkydPVigIuHbLggXb0+XgiWxJPpNrPgouFhkdHgDAGXf2VQ0jVYuCevXqlbm+efNmiY6OtmZscHJtateU+mH+cuhEjnR/c2WZ5+qF+suiR7qKlwf16ACAq7P4W+LOO++UJ598UtLT03VLgqKiIlmzZo089thjcvfdd1v6dnBh6t+fcd0b6V1+VQ+mkkN1ukg6mSPL9vzZmBQAAKu0KCgoKJCxY8fKtGnTpLCwUDw8PPTPv//97/qau7t9bTlPiwLHM3nBbvl/Kw9J97hw+Wzk9UaHAwCw4+/vCvdaOnr0qOzYsUOvWmrdurU0atRI7BGJjOM5kJEtPd5aKW7VRNY93V3Cg3yMDgkAYKff3xbXyKxevVo6d+4stWvX1gdgbQ3DA6R17Rqy+ehZmbU5Re67sYHRIQEAnKVGRi2zVoW+zzzzjOzatcs2UcHl3d72jw3zEo+JkzdoBwBUZSKTmpoqjz76qN4Yr1mzZtKqVSt5/fXXJTmZreZhPWqPGR9PNzl4Ikc2HztrdDgAAGdJZEJDQ+XBBx/UK5VUq4KhQ4fKF198IXXr1tWjNYA1BPp4Su9mf2yYR5dsAMBVVGqTDjXFpHb6nTJlim5boEZpAGt3yf5xa6qcLyg0OhwAgDMlMmpE5oEHHpDIyEi99FpNM/30008WvceqVavktttuk6ioKL2nyJw5c8o8r2ojXnjhBf0Zvr6+0qNHD9m/f39FQ4aDuaF+iMTU9JXs/IuycGea0eEAAJwhkXn66af1SIyaRlJLsN999129Od6XX34pvXr1sui9cnJypGXLlvLBBx9c8fnXXntN3nvvPfn44491k0p/f3/p2bOn5OXlWRo2HJCbWzUZ8seoDNNLAACr7CPTqVMnGT58uNx+++26XsZa1IjM7NmzZcCAAfqxCkuN1KjCYrVrsKLWkteqVUtvvKd2GC4P9pFxbMdO50qX15br89kPdJSwQG/zc6EB3uLjaV8bMAIA7HwfGTWlVBWSkpL0SI+aTiqh/oHat28va9euvWoik5+fr4/SNwKOKzbYTzo2CJHfDp6SgR/+Vua5EH8vWTLhRqnp/2cXdgCAaylXIjNv3jzp3bu3eHp66vNr6devn1UCU0mMokZgSlOPS567ksmTJ8ukSZOsEgPswwM3NZTdaeckt1TB74XCIjmVUyBztqTIqE5lG5gCAFxHuRIZNd2jkofw8HDz1M/VpodU3yUjqRqeCRMmlBmRiY0t3lwNjqlzo1DZ/MItZa5NW5MkL/64S9fOkMgAgOsqV7Gv6nCtkpiS86sd1kxiIiIi9M/jx8t2QFaPS567Em9vbz2XVvqA8+nfKlq83N1kV9o52ZmaaXQ4AABHWbX0v//9r0wNSumu2Oo5a1Ero1TCsnTp0jKjK2r1UocOHaz2OXBMqi6mR9Pi5JoVTQDguixOZEaNGqUriC+VlZWln7OE6py9ZcsWfZQU+KpztaxbTVM98sgj8tJLL+m6nO3bt8vdd9+tVzJda3oLrmPoH/2YVJ1M/kU2zAMAV2TxqiW1LFolGZdSvZbUqiJLJCYmys0332x+XFLbMmLECL3E+oknntB7zYwZM0bOnj2ru24vXLhQfHx8LA0bTqhrozCpFeQtx8/ly9LdGdKneXFLAwCA6yj3PjKtW7fWCczW
rVslPj5ePDz+zIFUbYwaTVEb4v3www9iT9hHxrm9unCPfLTioNzcOEw+H9XO6HAAAPa6j0zJdI6a+lG76wYEBJif8/Ly0k0jBw8eXNm4AYv7MalEZuW+E3L8XJ7UCmK0DgBcSbkTmYkTJ+qfKmFRm9Gp1UGA0eqHBUhCnZqy8cgZmbUpRe6/qYHRIQEA7LnYt2nTpubi3NLUaiJV8wJUtdvb/tGPaeMxXcMFAHAdFicyY8eOlWPHjl12PSUlRT8HVLW+LaLE19NdDp3IkWV7MiT5TK75UJ2zAQDOy+JVS7t27ZI2bdpcsRhYPQdUtQBvD+ndPEJPLY3+ouyooEpwFozrInVD/Q2LDwBgRyMyqjbm0t12lbS0tDIrmYCqdG+X+hIe6C3eHm7mw92tmpy/UCjfbjhqdHgAAKOXX5cYNmyYTlrmzp1r3jdG7fGiVjWpNgYsv4a9WLgjXe77aqOEBXrL2qe6iYe7xXk7AMBZll+XeOONN6Rr165Sp04dPZ2kqOJf1ZX6yy+/rFzUgBV1iwuXYH8vOZGVr5dnd29StpM6AMDxWfwranR0tGzbtk1ee+01vYIpISFB3n33Xd1CgC7TsCdeHm4yoFW0PqcfEwA4pwoVtfj7++u2AYC9u/36GJm6JkmW7D4up7LzJSSA/Y8AwJlUuDpXrVBSzR1V1+vS+vXrZ424AKuIiwiS5tHVZXtKpszZkiqjO9czOiQAgJGJzKFDh2TgwIF6Kkn1XiqpFS5pJKn6LgH2ZGjbGJ3ITE88Jvd0qnvFpqcAABepkRk3bpzUq1dPMjIyxM/PT3bu3CmrVq2Stm3byooVK2wTJVAJ/VpGiZe7m+xJz5KdqeeMDgcAYGQis3btWvn3v/8toaGh4ubmpo/OnTvL5MmT5eGHH7ZmbIBV1PDzkr/FF69YUqMyAAAXTmTU1FFgYKA+V8lMamqqPlfLsffu3Wv9CAEruL1t8Yq6uVtTJf8i058A4LI1Ms2aNZOtW7fq6aX27dvrZdheXl7yySefSP369W0TJVBJnRuGSmR1H0nLzNOtDLo0CjU/F+jjKdV9PQ2NDwBQRYnMc889Jzk5OfpcTTHdeuut0qVLFwkJCZHvv/++gmEAtqXaFQxqEy0fLD8oT8/aftlzP/yrgyTUqWlYfACAKmpRcCWnT5+WmjVr2uVqEFoUoERa5nm54/+tk+Pn8szXCotMcrHIpAuC3xtWvFM1AMCJWxRcSXBwsDXeBrCpyOq+suqJm8tc23rsrPT/YI0s2pkumecvMMUEAA6GLnpwaS1iqst1tQIk/2KR/Li1uHAdAOA4SGTg0tR06NCE4hVN0zfSjwkAHA2JDFzegNbR4uFWTU8z7TueZXQ4AABrJzJt2rSRM2fOmFcq5ebmWvIZgF0LC/SWm+PC9Tkb5gGAEyYyu3fvNi+5njRpkmRnZ9s6LqBKDU2I0T9nb06RC4VFRocDACincq1aatWqlYwaNUq3IlCrtd944w0JCAi44mtfeOGF8n42YDfUiExogJeczC6QFXtPyN+aFrc0AAA4QSIzbdo0mThxosyfP18XRy5YsEA8PC7/o+o5Ehk4Ik93NxnQKlr+uzpJTy+RyACAk26Ip5pEpqenS3h4cU2BvWNDPJTX3vQs6fnOKl34u+6Z7hIa4G10SADgss6V8/vb4lVLRUVFDpPEAJZoHBEoLWOq651+v153VJLP5JqP0zkFRocHALDWzr4HDx6Ud955RxcBK02bNpVx48ZJgwYNKvJ2gN0Y0jZWtiZnyttL9umjtI//0UZ6NYs0LDYAgBVGZBYtWqQTlw0bNkiLFi30sX79eomPj5fFixdb+naAXRnQKkqaRAaJt4eb+fB0L+4hNnX1YaPDAwBUtkamdevW0rNnT5kyZUqZ60899ZT88ssvsmnTJrEn1MjAGs0mO01ZJkUmkRWP3SR1Q/2NDgkAnN45W9XIqOmk0aNHX3b9nnvukV27dlkeKeAAzSY7NwrT5zNoYwAAdsXi
RCYsLEy2bNly2XV1jSJgOPuGeTM3JUuhGpoBADhmse+9994rY8aMkUOHDknHjh31tTVr1sirr74qEyZMsEWMgOHUvjLVfT0lLTNPVh84KTdeVzxCAwBwsETm+eefl8DAQHnzzTfl6aef1teioqLkxRdflIcfftgWMQKG8/F0l/6touR/a4/oDfNIZADAQYt9S8vKKu4UrBIbe0WxL6xle3Km3PZ/q8XLw01+f6aHVPfzNDokAHBaNiv2LU0lMPacxADW1Cw6SOIiAqXgYpHM25pidDgAgMomMoArUb3EhvxR9Dud1UsAYBdIZAALDGwdrXsxbUvO1L2ZAAAO2KIAcFUhAd7SvUm4LNp5XL5cd1juu7GBxV22wwO99egOAKCKE5kLFy5Ir1695OOPP5ZGjRpZ4eMBxzM0IVYnMl+tO6oPS6nk56necTaJDQBcjUVTS56enrJt2zbbRQM4gJsah8kN9YPL9GMqz6FWOylfrzsi5wsKjf7HAADXXH49fvx48fb2vqzXkr1i+TXsRVGRSbq+vlySz5yXd+5oJQNaRxsdEgA4/Pe3xTUyFy9elKlTp8qSJUskISFB/P3LNtB76623KhYx4OTc3KrJ4DYx8u7S/TJ94zESGQCwAosTmR07dkibNm30+b59+8o8RwEjcG1q+bZKZH47eEqOnc6V2GA/o0MCANdKZJYvX26bSAAXoBKXjg1CdCKjGlA+0uM6o0MCANfcR+bAgQOyaNEiOX/+vH5ciU4HgEsZ2rZ4U70ZG5N13QwAoAoTmVOnTkn37t3luuuukz59+khaWpq+Pnr0aHn00UcrEQrgGnrFR0qgt4cu+l2XdMrocADAtRIZtWpJLcM+evSo+Pn9Ob9/xx13yMKFC60dH+B0fL3c5daWkfp8RiKtDgCgShOZX375RV599VWJiSkeHi+hNsg7cuSIWNOLL76oC4hLH3FxbCQGxzckIVb//HlHmmTlXTA6HABwnUQmJyenzEhMidOnT+v9ZawtPj5eT1+VHKtXr7b6ZwBVrU3tGtIgzF/yLhTJT9uKp2cBAFWQyHTp0kX+97//mR+rUZKioiJ57bXX5OabbxZr8/DwkIiICPMRGhpq9c8Aqpr672Zo2+JRme8Tj0nymdxrHpnnGbUBAKssv1YJiyr2TUxMlIKCAnniiSdk586dekRmzZo1Ym379++XqKgo8fHxkQ4dOsjkyZOldu3aV319fn6+PkrvDAjYo0Gto+X1RXtl89Gz0vnVa29roDpuT7+vg7SuXbPK4gMApxyRadasmd4Ir3PnztK/f3891TRo0CDZvHmzNGhgWSfgv9K+fXuZNm2aLiL+6KOPJCkpSY8IZWVlXfXPqERHbWlccsTGFv/WC9ib8CAfGdmxrvh4XrtHk0piLhaZ5IvfDhsdMgA4fq8lI509e1bq1Kmj2yCo5d7lHZFRyQy9luCoNh89IwM//E0nNb8/10OCfDyNDgkAHLfXknLmzBn57LPPZPfu3fpx06ZNZdSoURIcHCy2VKNGDb1/jdqM72pUwbEtio4Bo7SKrSENwwPkQEa2zN+aJn9vf/WpVQBwNRZPLa1atUrq1q0r7733nk5o1KHO69Wrp5+zpezsbDl48KBERhbvwQG4TGFwQvF2B6rZJACgEonM2LFj9eZ3ql5l1qxZ+jh06JDceeed+jlreuyxx2TlypVy+PBh+e2332TgwIHi7u4uw4YNs+rnAPZuYJtocXerpguDD2RcvUYMAFyNxYmMmtZRrQhUQlFCnU+YMOGaUz4VkZycrJOWxo0by+233y4hISGybt06CQsLs+rnAPYuPNBHbm5c/O/9dHYDBoCK18i0adNG18ao5KI0da1ly5ZiTd99951V3w9w9N2Al+zOkFmbU+Txno3Fw73CPV8BwLUSmW3btpnPH374YRk3bpwefbnhhhv0NTVK8sEHH8iUKVNsFyng4rrFhUuwv5ecyMqXlftOSPcmtYwOCQAcY/m1m5ubLjj8q5eq
1xQWFoojLt8CHMG/f9wlU9ckSa/4CPn4rgSjwwEAx1h+rQp7ARjv9utjdCKzdM9xOZ1ToEdoAMCVlSuRUZvQATBeXESQNI+uLttTMuXrdUf0aqbKCg3wFh/PP4v3AcCRVGhDvNTUVN2FOiMjQzeMLE3V0ACwnaFtY3Qi8+biffqwRiKzdMKNUt2PHYMBuEAio3of/etf/xIvLy+9HFrVxZRQ5yQygG0NaB0t36w/Kkkncyr9XhcKi+Rkdr7M25oid3Woa5X4AMCuey2pvkX33XefPP3007oI2N5R7Atc3X9/PSQv/bRbWsRUl3kPdjY6HACw+Pvb4kwkNzdX7+LrCEkMgGsb2Dpad9felpwpe9PZMRiA47E4G1Fdp6dPn26baABUqZAAb70/jTI9kT5OAFxgakntE3PrrbfK+fPnpXnz5uLpWbZA8K233hJ7wtQScG1Ldh2Xf/4vUUL8vWTdM93Fkx2DATjbPjKlTZ48WRYtWmRuUXBpsS8Ax3JT4zC9ckkV/S7bkyE94yOMDgkAys3iRObNN9+UqVOnysiRIy39owDskOrZNKhNtHyy6pBuSEkiA8CRWDyG7O3tLZ06dbJNNAAMMTQhRv9cvjdD93ICAKdNZFTDyPfff9820QAwRKNagdIytoYUFplkzuYUo8MBANtNLW3YsEGWLVsm8+fPl/j4+MuKfWfNmmXpWwKwA7e3jZGtx87K9I3H5J9d6lHzBsA5E5kaNWrIoEGDbBMNAMPc1jJKd9fedzxbVuw7IY3CA6z+Gf5eHlKTRpcAjFx+7WhYfg2U37jvNsvcLak2/Yz/3t1WejStZdPPAOD4bLb8GoDzurdLffnt4Ck5d/6C1d+7yGSSC4UmmbomiUQGgNVYnMjUq3ftufNDhw5VNiYABmkWXV1+f7aHTd47+UyudHltuU6Ujp3OldhgP5t8DgDXYnEi88gjj5R5fOHCBdm8ebMsXLhQHn/8cWvGBsCJxNT0k44NQmTNgVMyY2OyjP/bdUaHBMAVExm1/PpKPvjgA0lMTLRGTACc1NCEWHMiM657I3FzY2UUgMqxWlOV3r17y8yZM631dgCcUK9mERLo4yEpZ8/LukOnjA4HgBOwWiIzY8YMCQ4OttbbAXBCPp7uepm3Mn1jstHhAHDFqaXWrVuXKfZVq7fT09PlxIkT8uGHH1o7PgBO2A7hm/VHZcGONJnUP16CfMpuqgkANk1kBgwYUOaxm5ubhIWFyU033SRxcXGWvh0AF9MqtoY0DA+QAxnZMn9rmvy9fW2jQwLgSonMxIkTbRMJAJegRnTVqMzkBXt0OwQSGQB2USMDAOU1sE20uLtVk81Hz8qBjCyjwwHgCiMyagrpr5rIqecvXrxojbgAOLHwQB+5uXGYLNmdIV+uPSL3dq1vdEgAKqGGn5cEeBvTLKDcnzp79uyrPrd27Vp57733pKioyFpxAXByQxJidSLzxdoj+gDguF4Z2NywaeJyJzL9+/e/7NrevXvlqaeekh9//FGGDx8u//73v60dHwAn1S0uXNrVDZatyWeNDgVAJbkbWKhSoXGg1NRUXfT7xRdfSM+ePWXLli3SrFkz60cHwGl5ebjJD/d1MDoMAA7OohxKtdJ+8sknpWHDhrJz505ZunSpHo0hiQEAAHY9IvPaa6/Jq6++KhEREfLtt99ecaoJAACgKlUzqa15y7lqydfXV3r06CHu7u5Xfd2sWbPEnpw7d06qV6+uR5OCgoKMDgcAAFjx+7vcIzJ33333Xy6/BgAAqErlTmSmTZtm20gAAAAsxM6+AADAYZHIAAAAh0UiAwAAHBaJDAAAcFgkMgAAwGGRyAAAAIdFIgMAABwWiQwAAHBYJDIAAMD5d/Z1VCWtpFTPBgAA4BhKvrf/qiWk0ycyWVlZ+mdsbKzRoQAAgAp8j6vmkZXufu2oioqKJDU1VQIDA63a9FJliio5OnbsGF21qwD3u+pwr6sO97rqcK8d716r9EQlMVFRUeLm5ua6
IzLqHz4mJsZm76/+T+I/iqrD/a463Ouqw72uOtxrx7rX1xqJKUGxLwAAcFgkMgAAwGGRyFSQt7e3TJw4Uf+E7XG/qw73uupwr6sO99p577XTF/sCAADnxYgMAABwWCQyAADAYZHIAAAAh0UiAwAAHBaJTAV98MEHUrduXfHx8ZH27dvLhg0bjA7J4U2ePFmuv/56vQtzeHi4DBgwQPbu3VvmNXl5eTJ27FgJCQmRgIAAGTx4sBw/ftywmJ3FlClT9M7XjzzyiPka99p6UlJS5B//+Ie+l76+vtK8eXNJTEw0P6/WXLzwwgsSGRmpn+/Ro4fs37/f0JgdUWFhoTz//PNSr149fR8bNGgg//nPf8r06uFeV8yqVavktttu07vsqr8r5syZU+b58tzX06dPy/Dhw/UmeTVq1JDRo0dLdnZ2BSMq++Gw0HfffWfy8vIyTZ061bRz507Tvffea6pRo4bp+PHjRofm0Hr27Gn6/PPPTTt27DBt2bLF1KdPH1Pt2rVN2dnZ5tfcd999ptjYWNPSpUtNiYmJphtuuMHUsWNHQ+N2dBs2bDDVrVvX1KJFC9O4cePM17nX1nH69GlTnTp1TCNHjjStX7/edOjQIdOiRYtMBw4cML9mypQppurVq5vmzJlj2rp1q6lfv36mevXqmc6fP29o7I7m5ZdfNoWEhJjmz59vSkpKMk2fPt0UEBBgevfdd82v4V5XzM8//2x69tlnTbNmzVJZoWn27Nllni/Pfe3Vq5epZcuWpnXr1pl+/fVXU8OGDU3Dhg0zVRaJTAW0a9fONHbsWPPjwsJCU1RUlGny5MmGxuVsMjIy9H8wK1eu1I/Pnj1r8vT01H85ldi9e7d+zdq1aw2M1HFlZWWZGjVqZFq8eLHpxhtvNCcy3GvrefLJJ02dO3e+6vNFRUWmiIgI0+uvv26+pu6/t7e36dtvv62iKJ1D3759Tffcc0+Za4MGDTINHz5cn3OvrePSRKY893XXrl36z/3+++/m1yxYsMBUrVo1U0pKSqXiYWrJQgUFBbJx40Y9bFa6n5N6vHbtWkNjczaZmZn6Z3BwsP6p7vuFCxfK3Pu4uDipXbs2976C1NRR3759y9xThXttPfPmzZO2bdvK0KFD9ZRp69at5dNPPzU/n5SUJOnp6WXuteovo6asudeW6dixoyxdulT27dunH2/dulVWr14tvXv31o+517ZRnvuqfqrpJPXfQgn1evX9uX79+kp9vtM3jbS2kydP6nnYWrVqlbmuHu/Zs8ewuJyxa7mq1+jUqZM0a9ZMX1P/oXh5een/GC699+o5WOa7776TTZs2ye+//37Zc9xr6zl06JB89NFHMmHCBHnmmWf0/X744Yf1/R0xYoT5fl7p7xTutWWeeuop3XlZJd3u7u767+qXX35Z12Uo3GvbKM99VT9VIl+ah4eH/kW1sveeRAZ2O1KwY8cO/dsUrO/YsWMybtw4Wbx4sS5Yh22TcvVb6CuvvKIfqxEZ9e/2xx9/rBMZWM8PP/wgX3/9tXzzzTcSHx8vW7Zs0b8QqQJV7rXzYmrJQqGhoTrTv3T1hnocERFhWFzO5MEHH5T58+fL8uXLJSYmxnxd3V81tXf27Nkyr+feW05NHWVkZEibNm30b0XqWLlypbz33nv6XP0mxb22DrWKo2nTpmWuNWnSRI4eParPS+4nf6dU3uOPP65HZe688069Muyuu+6S8ePH6xWRCvfaNspzX9VP9XdOaRcvXtQrmSp770lkLKSGgxMSEvQ8bOnfuNTjDh06GBqbo1M1ZCqJmT17tixbtkwvoSxN3XdPT88y914tz1ZfCNx7y3Tv3l22b9+uf2MtOdSogRqCLznnXluHmh69dBsBVcNRp04dfa7+PVd/kZe+12p6RNUNcK8tk5ubq2suSlO/eKq/oxXutW2U576qn+oXI/VLVAn197z6/0bV0lRKpUqFXXj5tarGnjZtmq7EHjNmjF5+nZ6ebnRoDu3+++/Xy/dW
rFhhSktLMx+5ublllgSrJdnLli3TS4I7dOigD1Re6VVLCvfaesvbPTw89NLg/fv3m77++muTn5+f6auvviqzdFX9HTJ37lzTtm3bTP3792dJcAWMGDHCFB0dbV5+rZYKh4aGmp544gnza7jXFV/huHnzZn2o1OGtt97S50eOHCn3fVXLr1u3bq23IVi9erVeMcnyawO9//77+i95tZ+MWo6t1sWjctR/HFc61N4yJdR/FA888ICpZs2a+stg4MCBOtmB9RMZ7rX1/Pjjj6ZmzZrpX4Di4uJMn3zySZnn1fLV559/3lSrVi39mu7du5v27t1rWLyO6ty5c/rfYfV3s4+Pj6l+/fp675P8/Hzza7jXFbN8+fIr/v2sksfy3tdTp07pxEXt7RMUFGQaNWqUTpAqq5r6n8qN6QAAABiDGhkAAOCwSGQAAIDDIpEBAAAOi0QGAAA4LBIZAADgsEhkAACAwyKRAQAADotEBgAAOCwSGQBOr1q1ajJnzhyjwwBgAyQyAGxq5MiROpG49OjVq5fRoQFwAh5GBwDA+amk5fPPPy9zzdvb27B4ADgPRmQA2JxKWiIiIsocNWvW1M+p0ZmPPvpIevfuLb6+vlK/fn2ZMWNGmT+/fft26datm34+JCRExowZI9nZ2WVeM3XqVImPj9efFRkZKQ8++GCZ50+ePCkDBw4UPz8/adSokcybN8/83JkzZ2T48OESFhamP0M9f2niBcA+kcgAMNzzzz8vgwcPlq1bt+qE4s4775Tdu3fr53JycqRnz5468fn9999l+vTpsmTJkjKJikqExo4dqxMclfSoJKVhw4ZlPmPSpEly++23y7Zt26RPnz76c06fPm3+/F27dsmCBQv056r3Cw0NreK7AKBCKt0/GwCuYcSIESZ3d3eTv79/mePll1/Wz6u/hu67774yf6Z9+/am+++/X59/8sknppo1a5qys7PNz//0008mNzc3U3p6un4cFRVlevbZZ68ag/qM5557zvxYvZe6tmDBAv34tttuM40aNcrK/+QAqgI1MgBs7uabb9ajHKUFBwebzzt06FDmOfV4y5Yt+lyNkLRs2VL8/f3Nz3fq1EmKiopk7969emoqNTVVunfvfs0YWrRoYT5X7xUUFCQZGRn68f33369HhDZt2iS33HKLDBgwQDp27FjJf2oAVYFEBoDNqcTh0qkea1E1LeXh6elZ5rFKgFQypKj6nCNHjsjPP/8sixcv1kmRmqp64403bBIzAOuhRgaA4datW3fZ4yZNmuhz9VPVzqhamRJr1qwRNzc3ady4sQQGBkrdunVl6dKllYpBFfqOGDFCvvrqK3nnnXfkk08+qdT7AagajMgAsLn8/HxJT08vc83Dw8NcUKsKeNu2bSudO3eWr7/+WjZs2CCfffaZfk4V5U6cOFEnGS+++KKcOHFCHnroIbnrrrukVq1a+jXq+n333Sfh4eF6dCUrK0snO+p15fHCCy9IQkKCXvWkYp0/f745kQJg30hkANjcwoUL9ZLo0tRoyp49e8wrir777jt54IEH9Ou+/fZbadq0qX5OLZdetGiRjBs3Tq6//nr9WNWzvPXWW+b3UklOXl6evP322/LYY4/pBGnIkCHljs/Ly0uefvppOXz4sJ6q6tKli44HgP2rpip+jQ4CgOtStSqzZ8/WBbYAYClqZAAAgMMikQEAAA6LGhkAhmJ2G0BlMCIDAAAcFokMAABwWCQyAADAYZHIAAAAh0UiAwAAHBaJDAAAcFgkMgAAwGGRyAAAAHFU/x+n78KL2KJWOAAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plt.plot(active_dims)\n",
    "plt.xlabel('Epochs')\n",
    "plt.ylabel('Number of active dimensions for the f MLP output')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "396c9804",
   "metadata": {},
   "source": [
    "You can pass a `decay_rate` argument to the `.set_schedule` method of a `Pruning_MLP` to choose the shape of the pruning schedule. The default is cosine decay (as in the plot above); the other options are `exp` and `linear`.\n",
    "\n",
    "<img src=\"../_static/pruning_decay_schedules.png\" alt=\"Pruning decay schedules\">\n",
    "\n",
    "In the image above, pruning finishes at epoch 75 and reduces the output from 100 dimensions to 2."
   ]
  },
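  {
   "cell_type": "markdown",
   "id": "b7e1c3a2",
   "metadata": {},
   "source": [
    "To build intuition for the three decay options, the sketch below computes a target number of active dimensions per epoch for cosine, exponential and linear decay. It prunes 100 dimensions down to 2 by epoch 75, matching the image above. The formulas and the helper `target_dims` are illustrative assumptions and are not SymTorch's implementation, so the exact curves used by `.set_schedule` may differ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7e1c3a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "\n",
    "def target_dims(epoch, start_dims=100, end_dims=2, end_epoch=75, decay_rate='cosine'):\n",
    "    # Hypothetical schedule (not the SymTorch code): number of dimensions to keep at `epoch`.\n",
    "    if epoch >= end_epoch:\n",
    "        return end_dims\n",
    "    t = epoch / end_epoch  # progress through the pruning phase, in [0, 1)\n",
    "    if decay_rate == 'cosine':\n",
    "        frac = 0.5 * (1 + math.cos(math.pi * t))  # falls from 1 to 0 along a half cosine\n",
    "        dims = end_dims + (start_dims - end_dims) * frac\n",
    "    elif decay_rate == 'exp':\n",
    "        dims = start_dims * (end_dims / start_dims) ** t  # geometric interpolation\n",
    "    elif decay_rate == 'linear':\n",
    "        dims = start_dims + (end_dims - start_dims) * t\n",
    "    else:\n",
    "        raise ValueError(f'unknown decay_rate: {decay_rate}')\n",
    "    return max(end_dims, round(dims))\n",
    "\n",
    "\n",
    "for rate in ('cosine', 'exp', 'linear'):\n",
    "    print(rate, [target_dims(e, decay_rate=rate) for e in (0, 25, 50, 75)])"
   ]
  },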
  {
   "cell_type": "markdown",
   "id": "a1db5df8",
   "metadata": {},
   "source": [
    "## Interpret the MLP"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "05200922",
   "metadata": {},
   "source": [
    "The `.distill` method runs symbolic regression only on the active (non-masked) output dimensions. Masked dimensions are skipped."
   ]
  },
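  {
   "cell_type": "markdown",
   "id": "c4d92f10",
   "metadata": {},
   "source": [
    "As a rough illustration of what counts as an active dimension, the sketch below applies the criterion from the Pruning Background section to a toy tensor. It ranks output dimensions by their standard deviation across datapoints, keeps the top `k`, and zero-masks the rest. The helper `importance_mask` and the toy data are hypothetical, and this is not the SymTorch implementation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4d92f11",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def importance_mask(outputs, k):\n",
    "    # Hypothetical helper: `outputs` has shape (n_samples, n_dims).\n",
    "    # Keep the k dimensions with the largest standard deviation across samples.\n",
    "    std = outputs.std(dim=0)\n",
    "    keep = torch.topk(std, k).indices\n",
    "    mask = torch.zeros(outputs.shape[1], dtype=torch.bool)\n",
    "    mask[keep] = True\n",
    "    return mask\n",
    "\n",
    "\n",
    "gen = torch.Generator().manual_seed(0)  # local RNG so the global seed is untouched\n",
    "toy_outputs = 0.01 * torch.randn(500, 100, generator=gen)  # mostly near-constant dimensions\n",
    "toy_outputs[:, 0] += 3.0 * torch.randn(500, generator=gen)  # two high-variance dimensions\n",
    "toy_outputs[:, 21] += 2.0 * torch.randn(500, generator=gen)\n",
    "\n",
    "toy_mask = importance_mask(toy_outputs, k=2)\n",
    "toy_masked = toy_outputs * toy_mask  # inactive dimensions become zero\n",
    "print(toy_mask.nonzero().flatten().tolist())  # indices of the active dimensions"
   ]
  },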
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "d637d376",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Running symbolic regression on pruned f...\n",
      "🛠️ Running SR on active dimension 0 (1/2)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/liz/PhD/SymTorch_project/symtorch_venv/lib/python3.11/site-packages/pysr/sr.py:2811: UserWarning: Note: it looks like you are running in Jupyter. The progress bar will be turned off.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "💡Best equation for active dimension 0: ((x0 * x0) * -9.177365) + (sin(x4) * (x0 + 3.7915835)).\n",
      "🛠️ Running SR on active dimension 21 (2/2)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/liz/PhD/SymTorch_project/symtorch_venv/lib/python3.11/site-packages/pysr/sr.py:2811: UserWarning: Note: it looks like you are running in Jupyter. The progress bar will be turned off.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "💡Best equation for active dimension 21: ((x0 * ((x4 * 0.02569695) + x0)) * 9.053295) + (sin(x4) * -4.9888153).\n",
      "❤️ SR on f_net active dimensions complete.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{0: PySRRegressor.equations_ = [\n",
       " \t    pick     score                                           equation  \\\n",
       " \t0         0.000000                                                 x4   \n",
       " \t1         0.222776                                         -1.0853426   \n",
       " \t2         0.308274                                     x0 * -3.806873   \n",
       " \t3         0.058161                               inv(x0 + -1.1267803)   \n",
       " \t4         0.431742                              x0 * (x0 * -5.707732)   \n",
       " \t5         0.477709                       (x0 * -8.686444) + 3.2480958   \n",
       " \t6         0.153751                      x4 + ((x0 * -6.5375166) * x0)   \n",
       " \t7         0.544373                (x0 * -7.4334426) + (x4 * 4.868192)   \n",
       " \t8         0.376226                ((x0 * x0) * -7.367284) + (x4 + x4)   \n",
       " \t9         3.306437         ((x0 * x0) * -8.744584) + (x4 * 3.6595654)   \n",
       " \t10        0.125803   (x4 * (x0 + 3.344076)) + (x0 * (x0 * -9.105152))   \n",
       " \t11        0.138711  ((x4 * 3.3089917) + ((x4 + -10.204634) * (x0 *...   \n",
       " \t12  >>>>  1.101763  ((x0 * x0) * -9.177365) + (sin(x4) * (x0 + 3.7...   \n",
       " \t13        0.090894  ((x0 * ((x0 * -1.7644632) + -0.17303404)) + si...   \n",
       " \t14        0.000671  (((((x0 + 0.09639773) * -1.7700083) * x0) + si...   \n",
       " \t15        0.074827  ((x0 + 4.5193567) * (sin(x4) + (x0 * (x0 * -1....   \n",
       " \t16        0.008228  ((sin(x4) + ((x0 * x0) * -1.5951045)) * (x0 + ...   \n",
       " \t\n",
       " \t        loss  complexity  \n",
       " \t0   9.869409           1  \n",
       " \t1   7.898432           2  \n",
       " \t2   4.263609           4  \n",
       " \t3   4.022708           5  \n",
       " \t4   2.612253           6  \n",
       " \t5   1.620126           7  \n",
       " \t6   1.389234           8  \n",
       " \t7   0.806041           9  \n",
       " \t8   0.553305          10  \n",
       " \t9   0.020277          11  \n",
       " \t10  0.015766          13  \n",
       " \t11  0.011947          15  \n",
       " \t12  0.003970          16  \n",
       " \t13  0.003022          19  \n",
       " \t14  0.003016          22  \n",
       " \t15  0.002799          23  \n",
       " \t16  0.002730          26  \n",
       " ],\n",
       " 21: PySRRegressor.equations_ = [\n",
       " \t    pick     score                                           equation  \\\n",
       " \t0         0.000000                                                 x0   \n",
       " \t1         0.095859                                            x0 + x0   \n",
       " \t2         0.123548                                     x0 * 3.4731839   \n",
       " \t3         0.192502                              (x0 * x0) * 5.3865266   \n",
       " \t4         0.694850                     (x0 + -0.41388267) * 9.1824665   \n",
       " \t5         0.378998                (x0 * 7.6599708) + (x4 * -5.620235)   \n",
       " \t6         1.981346         (x0 * (x0 * 9.021318)) + (x4 * -4.3798647)   \n",
       " \t7         0.792054   (x4 * (x4 + -5.188016)) + ((x0 * 9.143758) * x0)   \n",
       " \t8         0.034904      (sin(x4) * -4.90822) + (x0 * (x0 * 9.136547))   \n",
       " \t9         0.003309  ((x0 * 9.15831) * x0) + (((x4 * 1.1163542) + -...   \n",
       " \t10        0.056247  ((sin(x4) * -4.863487) + -0.03387193) + ((x0 *...   \n",
       " \t11  >>>>  0.093207  ((x0 * ((x4 * 0.02569695) + x0)) * 9.053295) +...   \n",
       " \t12        0.015435  ((x0 + 0.0090198405) * (x0 * 8.947147)) + (sin...   \n",
       " \t13        0.063585  (sin(x4) * -5.000108) + ((((x4 + (x4 + x2)) * ...   \n",
       " \t14        0.001942  (((x0 * 9.004585) * x0) + (x2 * 0.044007648)) ...   \n",
       " \t15        0.022210  (sin(x4) * ((x0 * 0.31133893) + -5.0260777)) +...   \n",
       " \t16        0.000561  ((((x2 * 0.009918883) + (x0 + 0.0050150244)) *...   \n",
       " \t\n",
       " \t        loss  complexity  \n",
       " \t0   7.511932           1  \n",
       " \t1   6.201404           3  \n",
       " \t2   5.480674           4  \n",
       " \t3   3.729311           6  \n",
       " \t4   1.861484           7  \n",
       " \t5   0.872300           9  \n",
       " \t6   0.016584          11  \n",
       " \t7   0.003402          13  \n",
       " \t8   0.003285          14  \n",
       " \t9   0.003264          16  \n",
       " \t10  0.003085          17  \n",
       " \t11  0.002560          19  \n",
       " \t12  0.002445          22  \n",
       " \t13  0.002294          23  \n",
       " \t14  0.002281          26  \n",
       " \t15  0.002230          27  \n",
       " \t16  0.002227          30  \n",
       " ]}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(\"\\nRunning symbolic regression on pruned f...\")\n",
    "\n",
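    "# Symbolic regression settings (PySRRegressor keyword arguments)\n",
    "sr_params = {'complexity_of_operators': {\"sin\": 3, \"exp\": 3},  # sin/exp cost 3 instead of the default 1\n",
    "             'complexity_of_constants': 2,  # each constant adds 2 to an equation's complexity\n",
    "             'constraints': {\"sin\": 3, \"exp\": 3},  # max complexity allowed inside sin(...) and exp(...)\n",
    "             'parsimony': 0.01,  # penalty on complexity during the search\n",
    "             'verbosity': 0,\n",
    "             'niterations': 100}  # number of search iterations\n",
    "\n",
    "model.f_net.distill(torch.FloatTensor(X_train),\n",
    "                    sr_params=sr_params)"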
   ]
  },
  {
   "cell_type": "markdown",
   "id": "718de0df",
   "metadata": {},
   "source": [
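    "The best expressions for both active dimensions are dominated by the same two terms, `x0 * x0` and `sin(x4)`, with different coefficients. In other words, the outputs of `f_net` are (approximately) linear combinations of the components of the f function."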
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "8a96625f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "➡️ Symbolic expressions for output dimension 0:\n",
      "    complexity      loss                                           equation  \\\n",
      "0            1  9.869409                                                 x4   \n",
      "1            2  7.898432                                         -1.0853426   \n",
      "2            4  4.263609                                     x0 * -3.806873   \n",
      "3            5  4.022708                               inv(x0 + -1.1267803)   \n",
      "4            6  2.612253                              x0 * (x0 * -5.707732)   \n",
      "5            7  1.620126                       (x0 * -8.686444) + 3.2480958   \n",
      "6            8  1.389234                      x4 + ((x0 * -6.5375166) * x0)   \n",
      "7            9  0.806041                (x0 * -7.4334426) + (x4 * 4.868192)   \n",
      "8           10  0.553305                ((x0 * x0) * -7.367284) + (x4 + x4)   \n",
      "9           11  0.020277         ((x0 * x0) * -8.744584) + (x4 * 3.6595654)   \n",
      "10          13  0.015766   (x4 * (x0 + 3.344076)) + (x0 * (x0 * -9.105152))   \n",
      "11          15  0.011947  ((x4 * 3.3089917) + ((x4 + -10.204634) * (x0 *...   \n",
      "12          16  0.003970  ((x0 * x0) * -9.177365) + (sin(x4) * (x0 + 3.7...   \n",
      "13          19  0.003022  ((x0 * ((x0 * -1.7644632) + -0.17303404)) + si...   \n",
      "14          22  0.003016  (((((x0 + 0.09639773) * -1.7700083) * x0) + si...   \n",
      "15          23  0.002799  ((x0 + 4.5193567) * (sin(x4) + (x0 * (x0 * -1....   \n",
      "16          26  0.002730  ((sin(x4) + ((x0 * x0) * -1.5951045)) * (x0 + ...   \n",
      "\n",
      "       score                                       sympy_format  \\\n",
      "0   0.000000                                                 x4   \n",
      "1   0.222776                                  -1.08534260000000   \n",
      "2   0.308274                                     x0*(-3.806873)   \n",
      "3   0.058161                                 1/(x0 - 1.1267803)   \n",
      "4   0.431742                                  x0*x0*(-5.707732)   \n",
      "5   0.477709                         3.2480958 + x0*(-8.686444)   \n",
      "6   0.153751                            x0*(-6.5375166)*x0 + x4   \n",
      "7   0.544373                      x0*(-7.4334426) + x4*4.868192   \n",
      "8   0.376226                        x0*x0*(-7.367284) + x4 + x4   \n",
      "9   3.306437                   x0*x0*(-8.744584) + x4*3.6595654   \n",
      "10  0.125803             x0*x0*(-9.105152) + x4*(x0 + 3.344076)   \n",
      "11  0.138711         x0*x0*(x4 - 10.204634) + x0 + x4*3.3089917   \n",
      "12  1.101763       x0*x0*(-9.177365) + (x0 + 3.7915835)*sin(x4)   \n",
      "13  0.090894  (x0 + 3.811739)*(x0*(x0*(-1.7644632) - 0.17303...   \n",
      "14  0.000671  ((x0 + 0.09639773)*(-1.7700083)*x0 + sin(x4))*...   \n",
      "15  0.074827  (x0 + 4.5193567)*(x0*x0*(-1.5749025) + sin(x4)...   \n",
      "16  0.008228  (x0 + 4.476581)*(x0*x0*(-1.5951045) + sin(x4))...   \n",
      "\n",
      "                                        lambda_format  \n",
      "0                                 PySRFunction(X=>x4)  \n",
      "1                  PySRFunction(X=>-1.08534260000000)  \n",
      "2                     PySRFunction(X=>x0*(-3.806873))  \n",
      "3                 PySRFunction(X=>1/(x0 - 1.1267803))  \n",
      "4                  PySRFunction(X=>x0*x0*(-5.707732))  \n",
      "5         PySRFunction(X=>3.2480958 + x0*(-8.686444))  \n",
      "6            PySRFunction(X=>x0*(-6.5375166)*x0 + x4)  \n",
      "7      PySRFunction(X=>x0*(-7.4334426) + x4*4.868192)  \n",
      "8        PySRFunction(X=>x0*x0*(-7.367284) + x4 + x4)  \n",
      "9   PySRFunction(X=>x0*x0*(-8.744584) + x4*3.6595654)  \n",
      "10  PySRFunction(X=>x0*x0*(-9.105152) + x4*(x0 + 3...  \n",
      "11  PySRFunction(X=>x0*x0*(x4 - 10.204634) + x0 + ...  \n",
      "12  PySRFunction(X=>x0*x0*(-9.177365) + (x0 + 3.79...  \n",
      "13  PySRFunction(X=>(x0 + 3.811739)*(x0*(x0*(-1.76...  \n",
      "14  PySRFunction(X=>((x0 + 0.09639773)*(-1.7700083...  \n",
      "15  PySRFunction(X=>(x0 + 4.5193567)*(x0*x0*(-1.57...  \n",
      "16  PySRFunction(X=>(x0 + 4.476581)*(x0*x0*(-1.595...  \n",
      "🏆 Best: ((x0 * x0) * -9.177365) + (sin(x4) * (x0 + 3.7915835)) (loss: 3.969665e-03)\n"
     ]
    }
   ],
   "source": [
    "model.f_net.show_symbolic_expression(dim=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0952d48",
   "metadata": {},
   "source": [
    "We can also run SR on `g_net` to check that it is just a linear transformation of its inputs.\\\n",
    "Because `g_net` is an internal block of the model whose inputs come from `f_net`, we pass the whole model as `parent_model` so that symbolic regression uses the inputs `g_net` actually receives."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "f36f3941",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "🛠️ Running SR on output dimension 0 of 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/liz/PhD/SymTorch_project/symtorch_venv/lib/python3.11/site-packages/pysr/sr.py:2811: UserWarning: Note: it looks like you are running in Jupyter. The progress bar will be turned off.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "💡Best equation for output 0 found to be (((x0 + 0.00012825789) * -0.104232065) + -0.0026094334) + (x21 * 0.17798738).\n",
      "❤️ SR on g_net complete.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{0: PySRRegressor.equations_ = [\n",
       " \t   pick         score                                           equation  \\\n",
       " \t0        0.000000e+00                                                x31   \n",
       " \t1        8.785424e-02                                         0.24939868   \n",
       " \t2        3.121439e+00                                    x21 * 0.2789512   \n",
       " \t3        5.638493e-01                    (x21 * 0.27616408) + 0.03388125   \n",
       " \t4        2.722068e+00           (x0 * -0.097962625) + (x21 * 0.18385741)   \n",
       " \t5        3.338238e-07     (x0 * -0.09796284) + ((x12 + 0.1838572) * x21)   \n",
       " \t6        1.949614e+01  ((x21 * 0.17798749) + (x0 * -0.10423195)) + -0...   \n",
       " \t7        2.760529e-01  (x21 * 0.17798734) + ((x26 + (x0 * -0.10423212...   \n",
       " \t8  >>>>  4.623406e-01  (((x0 + 0.00012825789) * -0.104232065) + -0.00...   \n",
       " \t\n",
       " \t           loss  complexity  \n",
       " \t0  7.395062e-01           1  \n",
       " \t1  6.773096e-01           2  \n",
       " \t2  1.316859e-03           4  \n",
       " \t3  2.426105e-04           7  \n",
       " \t4  1.048458e-06           9  \n",
       " \t5  1.048457e-06          11  \n",
       " \t6  3.576713e-15          12  \n",
       " \t7  2.059243e-15          14  \n",
       " \t8  1.296927e-15          15  \n",
       " ]}"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.g_net = SymbolicModel(model.g_net, block_name='g_net')\n",
    "model.g_net.distill(torch.FloatTensor(X_train),\n",
    "                    parent_model=model,  # g_net is an internal block, so pass the whole model\n",
    "                    sr_params=sr_params)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "16a5805a",
   "metadata": {},
   "source": [
    "The only variables in the `g_net` expression are `x0` and `x21`, which are the two active output dimensions of `f_net` (dimensions 0 and 21)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db788500",
   "metadata": {},
   "source": [
    "## Switch to Using the Equations in the Forward Pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "abef72c8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "✅ Successfully switched f_net to symbolic equations for 2 active dimensions:\n",
      "   Dimension 0: ((x0 * x0) * -9.177365) + (sin(x4) * (x0 + 3.7915835))\n",
      "   Variables: ['x0', 'x4']\n",
      "   Dimension 21: ((x0 * ((x4 * 0.02569695) + x0)) * 9.053295) + (sin(x4) * -4.9888153)\n",
      "   Variables: ['x0', 'x4']\n",
      "🎯 Active dimensions [0, 21] now using symbolic equations.\n",
      "🔒 Inactive dimensions will output zeros.\n"
     ]
    }
   ],
   "source": [
    "model.f_net.switch_to_symbolic() "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "49460cc6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "✅ Successfully switched g_net to symbolic equations for all 1 dimensions:\n",
      "   Dimension 0: (((x0 + 0.00012825789) * -0.104232065) + -0.0026094334) + (x21 * 0.17798738)\n",
      "   Variables: ['x0', 'x21']\n",
      "🎯 All 1 output dimensions now using symbolic equations.\n"
     ]
    }
   ],
   "source": [
    "model.g_net.switch_to_symbolic()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d8e7cfd0",
   "metadata": {},
   "source": [
    "Now the forward pass through the model uses the symbolic equations instead of the MLPs."
   ]
  },
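  {
   "cell_type": "markdown",
   "id": "e8a05b61",
   "metadata": {},
   "source": [
    "To make the symbolic forward pass concrete, the sketch below composes the printed equations by hand with NumPy. The `f_net` expressions for active dimensions 0 and 21 use the input features `x0` and `x4`. In the `g_net` expression, `x0` and `x21` are those two `f_net` outputs. The constants are copied from the outputs above and the function names are hypothetical. The sketch also assumes that the model feeds the `f_net` output straight into `g_net` (the architecture is defined earlier in the notebook), so treat it as an illustration rather than a replacement for `model(...)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8a05b62",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def f_symbolic(X):\n",
    "    # Active f_net dimensions as printed by switch_to_symbolic(); columns of X are x0..x4.\n",
    "    x0, x4 = X[:, 0], X[:, 4]\n",
    "    f0 = (x0 * x0) * -9.177365 + np.sin(x4) * (x0 + 3.7915835)\n",
    "    f21 = (x0 * ((x4 * 0.02569695) + x0)) * 9.053295 + np.sin(x4) * -4.9888153\n",
    "    return f0, f21\n",
    "\n",
    "\n",
    "def g_symbolic(f0, f21):\n",
    "    # g_net equation: its x0 and x21 are f_net output dimensions 0 and 21.\n",
    "    return ((f0 + 0.00012825789) * -0.104232065 + -0.0026094334) + f21 * 0.17798738\n",
    "\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "X_demo = rng.normal(size=(4, 5))  # hypothetical inputs with 5 features\n",
    "y_demo = g_symbolic(*f_symbolic(X_demo))\n",
    "print(y_demo.shape)  # one prediction per row"
   ]
  },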
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "e526154f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.0244],\n",
       "        [ 0.1939],\n",
       "        [-0.4942],\n",
       "        ...,\n",
       "        [-0.7961],\n",
       "        [ 1.8698],\n",
       "        [ 1.2709]])"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "interpretable_outputs = model(torch.tensor(X_train, dtype=torch.float32))\n",
    "interpretable_outputs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5835dd8e",
   "metadata": {},
   "source": [
    "## Switch Back to Using the MLPs in the Forward Pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "34291a48",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "✅ Switched f_net back to block\n",
      "✅ Switched g_net back to block\n"
     ]
    }
   ],
   "source": [
    "model.f_net.switch_to_block()\n",
    "model.g_net.switch_to_block()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "79e43dc0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.0293],\n",
       "        [ 0.1937],\n",
       "        [-0.4943],\n",
       "        ...,\n",
       "        [-0.7949],\n",
       "        [ 1.8919],\n",
       "        [ 1.2544]])"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    model_outputs = model(torch.tensor(X_train, dtype=torch.float32))\n",
    "model_outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "0e258b2c",
   "metadata": {},
   "outputs": [],
   "source": [
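    "# Clean up files created by this demo\n",
    "import os\n",
    "import shutil\n",
    "\n",
    "if os.path.exists('SR_output'):\n",
    "    shutil.rmtree('SR_output')  # symbolic regression outputs\n",
    "if os.path.exists('model_weights.pth'):\n",
    "    os.remove('model_weights.pth')  # saved model weights"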
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "symtorch_venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
