{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d86b3359",
   "metadata": {},
   "source": [
    "# LLM Learned Operations"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2026a62f",
   "metadata": {},
   "source": [
    "Small LLMs are notoriously bad at maths, which is ironic given that a single forward pass through a transformer involves a huge number of matrix multiplications and additions. We can use our `SymbolicModel` to probe which functions a small LLM effectively computes when it carries out mathematical operations.\n",
    "\n",
    "In this demo, we use the small model Llama-3.2-1B-Instruct. Depending on your laptop, you should be able to run this whole notebook locally!"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad63fb54",
   "metadata": {},
   "source": [
    "## Set-up"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "faf78eb4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Detected IPython. Loading juliacall extension. See https://juliapy.github.io/PythonCall.jl/stable/compat/#IPython\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/liz/PhD/venv/lib/python3.13/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from symtorch import SymbolicModel\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "import torch\n",
    "import re"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "dd88bbe0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# This is the model we are going to use\n",
    "model_name = \"meta-llama/Llama-3.2-1B-Instruct\"\n",
    "\n",
    "# Load the tokenizer and the model\n",
    "tok = AutoTokenizer.from_pretrained(model_name)\n",
    "model = AutoModelForCausalLM.from_pretrained(model_name)\n",
    "model.generation_config.pad_token_id = tok.eos_token_id\n",
    "\n",
    "torch.manual_seed(290402)\n",
    "# For our experiment, we want a deterministic model\n",
    "torch.use_deterministic_algorithms(True)\n",
    "\n",
    "# Function which calls our LLM\n",
    "def llm_call(prompt: str, max_tokens = 250) -> str:\n",
    "    inputs = tok(prompt, return_tensors=\"pt\")\n",
    "\n",
    "    out = model.generate(\n",
    "        **inputs,\n",
    "        max_new_tokens=max_tokens,\n",
    "        do_sample=False,          # greedy\n",
    "    )\n",
    "    new_tokens = out[0][inputs['input_ids'].shape[1]:]\n",
    "    return tok.decode(new_tokens, skip_special_tokens=True).strip()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d1f25bb4",
   "metadata": {},
   "source": [
    "Let's try out our LLM to see how it performs at basic addition."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c82a0253",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.\n"
     ]
    }
   ],
   "source": [
    "output = llm_call(\"Return only the numeric answer in the format $boxed$. What is 12+7=?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e228a9ca",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      ".\n",
      "\n",
      "## Step 1: We need to add 12 and 7 together.\n",
      "## Step 2: The result of the addition is 19.\n",
      "## Step 3: We need to put the result in the format $boxed$.\n",
      "## Step 4: The final answer is $\\boxed{19}$.\n",
      "\n",
      "The final answer is: $\\boxed{19}$\n"
     ]
    }
   ],
   "source": [
    "print(output)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e279434e",
   "metadata": {},
   "source": [
    "For smaller numbers it performs reasonably well. Let's see its behaviour for larger (3-digit) numbers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "17f262b1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "## Step 1: Add the two numbers together\n",
      "First, we need to add 972 and 373 together.\n",
      "\n",
      "## Step 2: Calculate the sum\n",
      "972 + 373 = 1445\n",
      "\n",
      "## Step 3: Format the answer\n",
      "The answer should be in the format $boxed{1445}$.\n",
      "\n",
      "The final answer is: $\\boxed{1445}$\n",
      "True answer =  1345\n"
     ]
    }
   ],
   "source": [
    "output = llm_call(\"Return only the numeric answer in the format $boxed$. What is 972+373=?\")\n",
    "print(output)\n",
    "\n",
    "print(\"True answer = \", 972+373)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb5db808",
   "metadata": {},
   "source": [
    "The model no longer performs well: it returns 1445, but the true answer is 1345."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9899703e",
   "metadata": {},
   "source": [
    "We can use `SymbolicModel` to approximate the functions that the LLM is using when performing maths."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "588ed039",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract the number returned by the LLM as a float\n",
    "def extract_boxed_number(text: str) -> float:\n",
    "    def parse_number(s: str) -> float:\n",
    "        return float(s.replace(',', ''))\n",
    "    \n",
    "    # Try $\\boxed{...}$ format first\n",
    "    match = re.search(r'\\$\\\\boxed\\{([^}]+)\\}\\$', text)\n",
    "    if match:\n",
    "        return parse_number(match.group(1))\n",
    "    # Try $boxed{...}$ format (without backslash)\n",
    "    match = re.search(r'\\$boxed\\{([^}]+)\\}\\$', text)\n",
    "    if match:\n",
    "        return parse_number(match.group(1))\n",
    "    # Try \\boxed{...} without dollar signs\n",
    "    match = re.search(r'\\\\boxed\\{([^}]+)\\}', text)\n",
    "    if match:\n",
    "        return parse_number(match.group(1))\n",
    "    # Try boxed{...} without anything\n",
    "    match = re.search(r'boxed\\{([^}]+)\\}', text)\n",
    "    if match:\n",
    "        return parse_number(match.group(1))\n",
    "    # Try $number$ format (without boxed)\n",
    "    match = re.search(r'\\$([0-9,.]+)\\$', text)\n",
    "    if match:\n",
    "        return parse_number(match.group(1))\n",
    "    # Fallback: try to find any number after an equals sign\n",
    "    match = re.search(r'=\\s*([\\d,]+)', text)\n",
    "    if match:\n",
    "        return parse_number(match.group(1))\n",
    "    # Fallback: try to find \"Answer: number\"\n",
    "    match = re.search(r'Answer:\\s*([\\d,.]+)', text)\n",
    "    if match:\n",
    "        return parse_number(match.group(1))\n",
    "    raise ValueError(f\"No boxed number found in: {text}\")\n",
    "\n",
    "# Create a dataset of N random integer pairs; np.random.randint excludes the upper bound, so values lie in [0, maximum)\n",
    "def random_number_pairs(N = 100, maximum = 999):\n",
    "    return np.random.randint(0, maximum, size=(N, 2))"
   ]
  },
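  {
   "cell_type": "markdown",
   "id": "b7e41a02",
   "metadata": {},
   "source": [
    "As a quick sanity check of the extraction logic, the sketch below applies a single pattern to strings shaped like the model outputs above. It is a simplified, standalone illustration: the helper name `first_boxed_number` is hypothetical and it only covers the `\\boxed{...}` form, so it is not a replacement for `extract_boxed_number`.\n",
    "\n",
    "```python\n",
    "import re\n",
    "\n",
    "def first_boxed_number(text: str) -> float:\n",
    "    # Hypothetical helper: handles only the \\boxed{...} form\n",
    "    match = re.search(r'\\\\boxed\\{([^}]+)\\}', text)\n",
    "    if match is None:\n",
    "        raise ValueError(f\"No boxed number found in: {text}\")\n",
    "    # Drop thousands separators before converting to float\n",
    "    return float(match.group(1).replace(',', ''))\n",
    "\n",
    "print(first_boxed_number(r\"The final answer is: $\\boxed{1445}$\"))\n",
    "print(first_boxed_number(r\"$\\boxed{1,345}$\"))\n",
    "```\n",
    "\n",
    "Trying the most specific pattern first and falling back to looser ones, as the full function does, keeps the parser tolerant of the formatting drift seen in small-model outputs."
   ]
  },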
  {
   "cell_type": "markdown",
   "id": "d48ba53b",
   "metadata": {},
   "source": [
    "## Addition"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c198206a",
   "metadata": {},
   "source": [
    "`SymbolicModel` is model-agnostic: you only need to pass a callable that maps an array of inputs to an array of outputs, i.e. `outputs = f(inputs)`."
   ]
  },
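  {
   "cell_type": "markdown",
   "id": "c3f90d15",
   "metadata": {},
   "source": [
    "To make that contract concrete, here is a minimal sketch of a stand-in black-box function that follows the same shape convention as the wrappers in this notebook: it takes an array of shape `(N, 2)` and returns an array of shape `(N,)`. The function `noisy_adder` and its off-by-100 rule are invented for illustration; they are not measured LLM behaviour.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def noisy_adder(X: np.ndarray) -> np.ndarray:\n",
    "    # Hypothetical stand-in for an LLM call: adds the two columns,\n",
    "    # but is off by 100 whenever the true sum exceeds 999\n",
    "    X = np.asarray(X)\n",
    "    total = X[:, 0] + X[:, 1]\n",
    "    return np.where(total > 999, total + 100, total).astype(float)\n",
    "\n",
    "X_demo = np.array([[12, 7], [972, 373]])\n",
    "print(noisy_adder(X_demo))\n",
    "```\n",
    "\n",
    "A callable with this input/output shape should be usable in place of the LLM wrapper, which makes it a cheap way to test the pipeline before spending time on model calls."
   ]
  },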
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6d1c5fc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the LLM in a function with the input/output form SymbolicModel expects\n",
    "def llm_addition(X):\n",
    "    outputs = []\n",
    "    # X is of shape (N,2)\n",
    "    for n in range(X.shape[0]):\n",
    "        a = X[n,0]\n",
    "        b = X[n,1]\n",
    "        output = llm_call(f\"Return only the numeric answer in the format $boxed$. What is {int(a)}+{int(b)}=?\")\n",
    "        output = extract_boxed_number(output)\n",
    "        outputs.append(output)\n",
    "    return np.array(outputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71748d66",
   "metadata": {},
   "source": [
    "Create a random dataset of numbers to add."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "63902f23",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(290402)\n",
    "\n",
    "X = random_number_pairs(50)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4669b1f",
   "metadata": {},
   "source": [
    "Example of the numbers in our dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "f995c198",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[451  41]\n",
      " [871 582]\n",
      " [237 193]\n",
      " [661 992]\n",
      " [417 724]]\n"
     ]
    }
   ],
   "source": [
    "print(X[:5,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "ec7b911c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialise our model\n",
    "symbolic_model_addition = SymbolicModel(llm_addition, block_name = \"llm_addition_func\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "644e6d71",
   "metadata": {},
   "outputs": [],
   "source": [
    "sr_params = {'constraints': {'sin':1, 'exp':1}, 'niterations' : 1000}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "9c021eba",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/liz/PhD/SymTorch_project/symtorch_venv/lib/python3.11/site-packages/pysr/sr.py:2811: UserWarning: Note: it looks like you are running in Jupyter. The progress bar will be turned off.\n",
      "  warnings.warn(\n",
      "Compiling Julia backend...\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "🛠️ Running SR on output dimension 0 of 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[ Info: Started!\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Expressions evaluated per second: 2.380e+06\n",
      "Progress: 10707 / 31000 total iterations (34.539%)\n",
      "════════════════════════════════════════════════════════════════════════════════════════════════════\n",
      "───────────────────────────────────────────────────────────────────────────────────────────────────\n",
      "Complexity  Loss       Score      Equation\n",
      "1           1.522e+05  0.000e+00  y = 1002.9\n",
      "3           8.858e+03  1.414e+00  y = x₁ + x₀\n",
      "5           8.173e+03  4.001e-02  y = (x₁ * 0.95782) + x₀\n",
      "7           7.936e+03  1.442e-02  y = ((x₁ * 0.89627) + x₀) * 1.0399\n",
      "9           7.932e+03  1.577e-05  y = (((x₁ * 0.89719) + -5.5603) + x₀) * 1.0447\n",
      "10          1.660e+03  1.564e+00  y = (x₀ + inv((x₀ * -0.005024) + 0.34498)) + x₁\n",
      "12          1.510e+03  4.739e-02  y = x₀ + ((x₁ * 0.98066) + inv((x₀ * -0.003442) + 0.23577)...\n",
      "                                      )\n",
      "14          1.465e+03  1.486e-02  y = ((x₁ * 0.97429) + (inv((x₀ * -0.0031489) + 0.21556) + ...\n",
      "                                      7.7701)) + x₀\n",
      "16          1.399e+03  2.322e-02  y = ((x₁ + (x₁ * (x₁ * -3.3621e-05))) + inv((x₀ * -0.00344...\n",
      "                                      21) + 0.23577)) + x₀\n",
      "18          1.304e+03  3.496e-02  y = ((inv((x₀ * -0.0050244) + 0.34497) + x₀) + (((x₁ * x₁)...\n",
      "                                       * -4.3997e-08) * x₁)) + x₁\n",
      "20          1.159e+03  5.875e-02  y = (((x₀ * (inv(sin(x₀)) * -0.0050244)) + x₁) + inv(0.344...\n",
      "                                      98 + (-0.0050244 * x₀))) + x₀\n",
      "22          1.109e+03  2.206e-02  y = (x₁ + -9.1433) + (x₀ + (((inv(sin(x₀)) * x₀) * -0.0050...\n",
      "                                      244) + inv((x₀ * -0.0050244) + 0.34498)))\n",
      "24          9.103e+02  9.873e-02  y = (inv((x₀ * -0.0050242) + 0.34498) + (x₁ * (((x₀ * x₀) ...\n",
      "                                      * (inv(sin(x₀)) * -9.4883e-09)) + 0.98734))) + x₀\n",
      "26          9.075e+02  1.519e-03  y = inv(0.34498 + (-0.0050244 * x₀)) + ((x₀ + (x₁ * 0.9873...\n",
      "                                      4)) + ((x₀ * (x₀ * inv(sin(x₀)))) * (x₀ * -9.7307e-09)))\n",
      "28          8.602e+02  2.669e-02  y = ((x₁ * (((inv(sin(x₀)) * x₀) + (x₁ + x₁)) * (-9.0981e-...\n",
      "                                      09 * x₁))) + x₁) + (inv((x₀ * -0.0050244) + 0.34498) + x₀)\n",
      "30          8.591e+02  6.374e-04  y = ((x₁ * (((inv(sin(x₀)) * x₀) + (x₁ + x₁)) * (x₁ * -9.4...\n",
      "                                      813e-09))) + x₁) + ((x₀ + inv((x₀ * -0.0050244) + 0.34498)...\n",
      "                                      ) + 0.16443)\n",
      "───────────────────────────────────────────────────────────────────────────────────────────────────\n",
      "════════════════════════════════════════════════════════════════════════════════════════════════════\n",
      "Press 'q' and then <enter> to stop execution early.\n",
      "\n",
      "Expressions evaluated per second: 2.450e+06\n",
      "Progress: 20702 / 31000 total iterations (66.781%)\n",
      "════════════════════════════════════════════════════════════════════════════════════════════════════\n",
      "───────────────────────────────────────────────────────────────────────────────────────────────────\n",
      "Complexity  Loss       Score      Equation\n",
      "1           1.522e+05  0.000e+00  y = 1002.9\n",
      "3           8.858e+03  1.414e+00  y = x₁ + x₀\n",
      "5           8.173e+03  4.001e-02  y = x₀ + (x₁ * 0.95783)\n",
      "7           7.936e+03  1.442e-02  y = ((x₁ * 0.89627) + x₀) * 1.0399\n",
      "9           7.932e+03  1.577e-05  y = (((x₁ * 0.89719) + -5.5603) + x₀) * 1.0447\n",
      "10          1.657e+03  1.565e+00  y = x₁ + (inv((x₀ * -0.0071333) + 0.49053) + x₀)\n",
      "12          1.494e+03  5.176e-02  y = x₀ + ((x₁ * 0.97876) + inv((x₀ * -0.0071331) + 0.49048...\n",
      "                                      ))\n",
      "14          1.451e+03  1.442e-02  y = ((inv((x₀ * -0.005023) + 0.34489) + (x₁ * 0.97293)) + ...\n",
      "                                      x₀) + 6.6198\n",
      "16          1.386e+03  2.288e-02  y = (((x₁ * (x₁ * -3.3431e-05)) + x₀) + x₁) + inv((x₀ * -0...\n",
      "                                      .0071332) + 0.49048)\n",
      "18          1.301e+03  3.163e-02  y = x₀ + (inv((x₀ * -0.0071332) + 0.49048) + (((x₁ * x₁) *...\n",
      "                                       (x₁ * -4.3968e-08)) + x₁))\n",
      "20          1.103e+03  8.245e-02  y = ((x₀ * -0.0071334) * inv(sin(x₀))) + (x₀ + (inv(0.4905...\n",
      "                                      4 + (x₀ * -0.0071334)) + x₁))\n",
      "22          9.207e+02  9.041e-02  y = (inv((x₀ * -0.0071331) + 0.49048) + (x₁ * (((x₀ * -8.2...\n",
      "                                      236e-06) * inv(sin(x₀))) + 0.97876))) + x₀\n",
      "24          8.808e+02  2.209e-02  y = ((x₀ + (((inv(sin(x₀)) * (x₀ * -8.2191e-06)) + 0.97374...\n",
      "                                      ) * x₁)) + 8.4552) + inv((x₀ * -0.0050228) + 0.34489)\n",
      "26          8.502e+02  1.765e-02  y = (((x₁ * ((-2.6855e-05 + ((x₀ * inv(sin(x₀))) * -8.528e...\n",
      "                                      -09)) * x₁)) + x₀) + inv((-0.0071331 * x₀) + 0.49049)) + x...\n",
      "                                      ₁\n",
      "28          8.013e+02  2.961e-02  y = (x₁ + inv((x₀ * -0.007133) + 0.49047)) + (x₀ + ((x₁ * ...\n",
      "                                      x₁) * ((inv(sin(x₀)) * (x₀ * -7.9752e-09)) + (x₁ * -3.5858...\n",
      "                                      e-08))))\n",
      "30          7.531e+02  3.102e-02  y = (((x₁ * ((x₁ * -5.1365e-08) + ((inv(sin(x₀)) * x₀) * -...\n",
      "                                      8.0908e-09))) * x₁) + ((inv((x₀ * -0.0050244) + 0.34498) +...\n",
      "                                       x₀) + x₁)) * 1.0078\n",
      "───────────────────────────────────────────────────────────────────────────────────────────────────\n",
      "════════════════════════════════════════════════════════════════════════════════════════════════════\n",
      "Press 'q' and then <enter> to stop execution early.\n",
      "\n",
      "Expressions evaluated per second: 2.500e+06\n",
      "Progress: 30305 / 31000 total iterations (97.758%)\n",
      "════════════════════════════════════════════════════════════════════════════════════════════════════\n",
      "───────────────────────────────────────────────────────────────────────────────────────────────────\n",
      "Complexity  Loss       Score      Equation\n",
      "1           1.522e+05  0.000e+00  y = 1002.9\n",
      "3           8.858e+03  1.414e+00  y = x₁ + x₀\n",
      "5           8.173e+03  4.001e-02  y = x₀ + (x₁ * 0.95783)\n",
      "7           7.936e+03  1.442e-02  y = (x₁ + (x₀ * 1.1157)) * 0.93206\n",
      "9           7.932e+03  1.568e-05  y = (((x₁ * 0.89719) + -5.5603) + x₀) * 1.0447\n",
      "10          1.657e+03  1.566e+00  y = x₀ + (x₁ + inv((x₀ * -0.012401) + 0.85401))\n",
      "12          1.494e+03  5.156e-02  y = (inv((x₀ * -0.0071331) + 0.49048) + x₀) + (x₁ * 0.9791...\n",
      "                                      9)\n",
      "14          1.436e+03  1.990e-02  y = ((x₀ + (x₁ * 0.96111)) + 9.9769) + inv((x₀ * -0.005023...\n",
      "                                      ) + 0.34487)\n",
      "16          1.386e+03  1.738e-02  y = (x₀ + (inv((x₀ * -0.0071332) + 0.49048) + x₁)) + (x₁ *...\n",
      "                                       (x₁ * -3.3414e-05))\n",
      "18          1.301e+03  3.163e-02  y = ((x₁ + inv((x₀ * -0.0071332) + 0.49048)) + x₀) + ((x₁ ...\n",
      "                                      * (x₁ * x₁)) * -4.4005e-08)\n",
      "20          1.103e+03  8.245e-02  y = ((x₀ * -0.0071334) * inv(sin(x₀))) + (x₀ + (inv(0.4905...\n",
      "                                      4 + (x₀ * -0.0071334)) + x₁))\n",
      "22          9.207e+02  9.041e-02  y = ((((x₀ * (inv(sin(x₀)) * -8.2236e-06)) + 0.97876) * x₁...\n",
      "                                      ) + inv((x₀ * -0.0071331) + 0.49048)) + x₀\n",
      "24          8.767e+02  2.444e-02  y = (x₁ * ((x₀ * (inv(sin(x₀)) * -8.2038e-06)) + 0.97192))...\n",
      "                                       + (inv((x₀ * -0.0050228) + 0.34489) + (x₀ + 9.3771))\n",
      "26          8.465e+02  1.751e-02  y = (x₁ + (((x₁ * -2.7782e-05) + (inv(sin(x₀)) * (x₀ * -8....\n",
      "                                      0432e-06))) * x₁)) + (inv((x₀ * -0.0071328) + 0.49047) + x...\n",
      "                                      ₀)\n",
      "28          7.956e+02  3.099e-02  y = ((x₁ + inv((x₀ * -0.007133) + 0.49047)) + (((inv(sin(x...\n",
      "                                      ₀)) * (x₀ * -7.8172e-06)) + ((x₁ * -3.6475e-08) * x₁)) * x...\n",
      "                                      ₁)) + x₀\n",
      "30          7.516e+02  2.840e-02  y = (((inv((x₀ * -0.0050244) + 0.34498) + x₀) + x₁) + ((x₁...\n",
      "                                       * x₁) * ((x₁ * -5.2133e-08) + ((x₀ * -8.1475e-09) * inv(s...\n",
      "                                      in(x₀)))))) * 1.0082\n",
      "───────────────────────────────────────────────────────────────────────────────────────────────────\n",
      "════════════════════════════════════════════════════════════════════════════════════════════════════\n",
      "Press 'q' and then <enter> to stop execution early.\n",
      "───────────────────────────────────────────────────────────────────────────────────────────────────\n",
      "Complexity  Loss       Score      Equation\n",
      "1           1.522e+05  0.000e+00  y = 1002.9\n",
      "3           8.858e+03  1.414e+00  y = x₁ + x₀\n",
      "5           8.173e+03  4.001e-02  y = x₀ + (x₁ * 0.95783)\n",
      "7           7.936e+03  1.442e-02  y = (x₁ + (x₀ * 1.1157)) * 0.93206\n",
      "9           7.932e+03  1.568e-05  y = (((x₁ * 0.89719) + -5.5603) + x₀) * 1.0447\n",
      "10          1.657e+03  1.566e+00  y = x₀ + (x₁ + inv((x₀ * -0.012401) + 0.85401))\n",
      "12          1.494e+03  5.156e-02  y = (inv((x₀ * -0.0071331) + 0.49048) + x₀) + (x₁ * 0.9791...\n",
      "                                      9)\n",
      "14          1.436e+03  1.990e-02  y = ((x₀ + (x₁ * 0.96111)) + 9.9769) + inv((x₀ * -0.005023...\n",
      "                                      ) + 0.34487)\n",
      "16          1.386e+03  1.738e-02  y = (x₀ + (inv((x₀ * -0.0071332) + 0.49048) + x₁)) + (x₁ *...\n",
      "                                       (x₁ * -3.3414e-05))\n",
      "18          1.301e+03  3.163e-02  y = ((x₁ + inv((x₀ * -0.0071332) + 0.49048)) + x₀) + ((x₁ ...\n",
      "                                      * (x₁ * x₁)) * -4.4005e-08)\n",
      "20          1.103e+03  8.245e-02  y = ((x₀ * -0.0071334) * inv(sin(x₀))) + (x₀ + (inv(0.4905...\n",
      "                                      4 + (x₀ * -0.0071334)) + x₁))\n",
      "22          9.162e+02  9.283e-02  y = ((((inv(sin(x₀)) * (x₀ * -8.281e-06)) + 0.98164) * x₁)...\n",
      "                                       + inv((x₀ * -0.0071331) + 0.49048)) + x₀\n",
      "24          8.767e+02  2.202e-02  y = (x₁ * ((x₀ * (inv(sin(x₀)) * -8.2038e-06)) + 0.97192))...\n",
      "                                       + (inv((x₀ * -0.0050228) + 0.34489) + (x₀ + 9.3771))\n",
      "26          8.465e+02  1.751e-02  y = (x₁ + (((x₁ * -2.7782e-05) + (inv(sin(x₀)) * (x₀ * -8....\n",
      "                                      0432e-06))) * x₁)) + (inv((x₀ * -0.0071328) + 0.49047) + x...\n",
      "                                      ₀)\n",
      "28          7.956e+02  3.099e-02  y = ((x₁ + inv((x₀ * -0.007133) + 0.49047)) + (((inv(sin(x...\n",
      "                                      ₀)) * (x₀ * -7.8172e-06)) + ((x₁ * -3.6475e-08) * x₁)) * x...\n",
      "                                      ₁)) + x₀\n",
      "30          7.516e+02  2.840e-02  y = (((inv((x₀ * -0.0050244) + 0.34498) + x₀) + x₁) + ((x₁...\n",
      "                                       * x₁) * ((x₁ * -5.2133e-08) + ((x₀ * -8.1475e-09) * inv(s...\n",
      "                                      in(x₀)))))) * 1.0082\n",
      "───────────────────────────────────────────────────────────────────────────────────────────────────\n",
      "💡Best equation for output 0 found to be ((((inv(sin(x0)) * (x0 * -8.280994e-6)) + 0.9816416) * x1) + inv((x0 * -0.007133129) + 0.49048254)) + x0.\n",
      "❤️ SR on llm_addition_func complete.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[ Info: Final population:\n",
      "[ Info: Results saved to:\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{0: PySRRegressor.equations_ = [\n",
       " \t    pick     score                                           equation  \\\n",
       " \t0         0.000000                                          1002.8534   \n",
       " \t1         1.422072                                            x1 + x0   \n",
       " \t2         0.040272                             x0 + (x1 * 0.95783037)   \n",
       " \t3         0.014673                (x1 + (x0 * 1.1157405)) * 0.9320593   \n",
       " \t4         0.000260  (((x1 * 0.89719146) + -5.5602617) + x0) * 1.04...   \n",
       " \t5         1.566168   x0 + (x1 + inv((x0 * -0.012401185) + 0.8540098))   \n",
       " \t6         0.051629  (inv((x0 * -0.0071331137) + 0.49047625) + x0) ...   \n",
       " \t7         0.019960  ((x0 + (x1 * 0.9611138)) + 9.976891) + inv((x0...   \n",
       " \t8         0.017441  (x0 + (inv((x0 * -0.007133195) + 0.49047813) +...   \n",
       " \t9         0.031700  ((x1 + inv((x0 * -0.007133231) + 0.49048457)) ...   \n",
       " \t10        0.082518  ((x0 * -0.007133387) * inv(sin(x0))) + (x0 + (...   \n",
       " \t11  >>>>  0.092871  ((((inv(sin(x0)) * (x0 * -8.280994e-6)) + 0.98...   \n",
       " \t12        0.022050  (x1 * ((x0 * (inv(sin(x0)) * -8.203762e-6)) + ...   \n",
       " \t13        0.017543  (x1 + (((x1 * -2.778236e-5) + (inv(sin(x0)) * ...   \n",
       " \t14        0.031019  ((x1 + inv((x0 * -0.0071330313) + 0.49047446))...   \n",
       " \t15        0.028436  (((inv((x0 * -0.0050244) + 0.34497637) + x0) +...   \n",
       " \t\n",
       " \t            loss  complexity  \n",
       " \t0   152241.12000           1  \n",
       " \t1     8858.00000           3  \n",
       " \t2     8172.52050           5  \n",
       " \t3     7936.18160           7  \n",
       " \t4     7932.05760           9  \n",
       " \t5     1656.56140          10  \n",
       " \t6     1494.04370          12  \n",
       " \t7     1435.57570          14  \n",
       " \t8     1386.36360          16  \n",
       " \t9     1301.19750          18  \n",
       " \t10    1103.23750          20  \n",
       " \t11     916.22595          22  \n",
       " \t12     876.69750          24  \n",
       " \t13     846.47180          26  \n",
       " \t14     795.55450          28  \n",
       " \t15     751.57166          30  \n",
       " ]}"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  - SR_output/llm_addition_func/dim0_1764345478/hall_of_fame.csv\n"
     ]
    }
   ],
   "source": [
    "# Perform symbolic regression (SR) on our model\n",
    "symbolic_model_addition.distill(X, sr_params=sr_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "3742e12b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "➡️ Dimension 0 - Complexity 3:\n",
      "   x1 + x0 (loss: 8.858000e+03)\n"
     ]
    }
   ],
   "source": [
    "symbolic_model_addition.show_symbolic_expression(complexity=[3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "4cf4d505",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "➡️ Standard symbolic expressions for output dimension 0:\n",
      "    complexity          loss  \\\n",
      "0            1  152241.12000   \n",
      "1            3    8858.00000   \n",
      "2            5    8172.51950   \n",
      "3            7    7936.18260   \n",
      "4            8    5539.63570   \n",
      "5           10    1634.10170   \n",
      "6           11    1587.55290   \n",
      "7           12    1507.41760   \n",
      "8           13    1477.81480   \n",
      "9           14    1379.15530   \n",
      "10          16    1229.54610   \n",
      "11          18    1178.01400   \n",
      "12          20    1151.30570   \n",
      "13          21    1131.40780   \n",
      "14          22    1008.77200   \n",
      "15          24     863.80830   \n",
      "16          26     795.50726   \n",
      "17          27     783.65510   \n",
      "18          28     712.91790   \n",
      "19          30     704.59010   \n",
      "\n",
      "                                             equation     score  \\\n",
      "0                                           1002.8534  0.000000   \n",
      "1                                             x1 + x0  1.422072   \n",
      "2                              (x1 * 0.95781434) + x0  0.040272   \n",
      "3                 ((x1 * 0.8962664) + x0) * 1.0399358  0.014672   \n",
      "4                     inv(x0 + -69.00627) + (x0 + x1)  0.359504   \n",
      "5       x0 + (x1 * (inv(x0 + -70.18642) + 0.9776793))  0.610418   \n",
      "6   (x1 * inv(inv(x0 + -68.841896) + 1.0215467)) + x0  0.028899   \n",
      "7   x0 + (x1 * (inv(-325.42346 + (4.6990423 * x0))...  0.051796   \n",
      "8   ((inv(inv(x0 + -68.874886) + 1.0532107) * x1) ...  0.019833   \n",
      "9   x0 + (x1 * (((x1 * -0.00013732273) + 1.0851918...  0.069093   \n",
      "10  x0 + (x1 * (((x1 * -0.00013732273) + 1.0851918...  0.057413   \n",
      "11  (x1 * ((((x1 * -1.243298e-7) * x1) + inv(-325....  0.021408   \n",
      "12  (x1 * ((inv(-325.42786 + (x0 * 4.6990423)) + 1...  0.011467   \n",
      "13  (((((x1 * x1) * -1.15719295e-7) + 1.0469434) +...  0.017434   \n",
      "14  (x1 * ((inv(x0 + -70.16485) + 1.070555) + ((in...  0.114729   \n",
      "15  (x1 * (inv((x0 * 4.6989927) + -325.42795) + ((...  0.077569   \n",
      "16  (x1 * ((inv((x0 * 4.699038) + -325.42786) + ((...  0.041185   \n",
      "17  (x1 * ((inv(x0) + 1.0708052) + (inv((4.699 * x...  0.015011   \n",
      "18  x0 + (x1 * ((inv(-325.42786 + (x0 * 4.699038))...  0.094603   \n",
      "19  (((((x1 + x0) * -8.387242e-8) * (x1 + inv(0.80...  0.005875   \n",
      "\n",
      "                                         sympy_format  \\\n",
      "0                                    1002.85340000000   \n",
      "1                                             x0 + x1   \n",
      "2                                  x0 + x1*0.95781434   \n",
      "3                       (x0 + x1*0.8962664)*1.0399358   \n",
      "4                         x0 + x1 + 1/(x0 - 69.00627)   \n",
      "5             x0 + x1*(0.9776793 + 1/(x0 - 70.18642))   \n",
      "6            x0 + x1/(1.0215467 + 1/(x0 - 68.841896))   \n",
      "7   x0 + x1*(0.9747181 + 1/(4.6990423*x0 - 325.423...   \n",
      "8   x0 + x1/(1.0532107 + 1/(x0 - 68.874886)) + 20....   \n",
      "9   x0 + x1*(x1*(-0.00013732273) + 1.0851918 + 1/(...   \n",
      "10  x0 + x1*(x1*(-0.00013732273) + 1.0851918 + 1/(...   \n",
      "11  x0 + x1*(x1*(-1.243298e-7)*x1 + 1.0599524 + 1/...   \n",
      "12  x0 + x1*(x1*x1*(-1.3013685e-7) + 1.0772082 + 1...   \n",
      "13  x0 + x1*(x1*x1*(-1.15719295e-7) + 1.0469434 + ...   \n",
      "14  x0 + x1*((x1 + 1/(sin(x1) + 0.8001926))*(-0.00...   \n",
      "15  x0 + x1*((x1 + 1/(sin(x1) + 0.80024636))*(-0.0...   \n",
      "16  x0 + x1*((x1 + 1/(sin(x1) + 0.80029213))*x1*(-...   \n",
      "17  x0 + x1*((x1 + 1/(sin(x1) + 0.8002465))*(-0.00...   \n",
      "18  x0 + x1*((x1 + 1/(sin(x1) + 0.80029213))*(x0 +...   \n",
      "19  x0 + x1*((x0 + x1)*(-8.387242e-8)*(x1 + 1/(sin...   \n",
      "\n",
      "                                        lambda_format  \n",
      "0                   PySRFunction(X=>1002.85340000000)  \n",
      "1                            PySRFunction(X=>x0 + x1)  \n",
      "2                 PySRFunction(X=>x0 + x1*0.95781434)  \n",
      "3      PySRFunction(X=>(x0 + x1*0.8962664)*1.0399358)  \n",
      "4        PySRFunction(X=>x0 + x1 + 1/(x0 - 69.00627))  \n",
      "5   PySRFunction(X=>x0 + x1*(0.9776793 + 1/(x0 - 7...  \n",
      "6   PySRFunction(X=>x0 + x1/(1.0215467 + 1/(x0 - 6...  \n",
      "7   PySRFunction(X=>x0 + x1*(0.9747181 + 1/(4.6990...  \n",
      "8   PySRFunction(X=>x0 + x1/(1.0532107 + 1/(x0 - 6...  \n",
      "9   PySRFunction(X=>x0 + x1*(x1*(-0.00013732273) +...  \n",
      "10  PySRFunction(X=>x0 + x1*(x1*(-0.00013732273) +...  \n",
      "11  PySRFunction(X=>x0 + x1*(x1*(-1.243298e-7)*x1 ...  \n",
      "12  PySRFunction(X=>x0 + x1*(x1*x1*(-1.3013685e-7)...  \n",
      "13  PySRFunction(X=>x0 + x1*(x1*x1*(-1.15719295e-7...  \n",
      "14  PySRFunction(X=>x0 + x1*((x1 + 1/(sin(x1) + 0....  \n",
      "15  PySRFunction(X=>x0 + x1*((x1 + 1/(sin(x1) + 0....  \n",
      "16  PySRFunction(X=>x0 + x1*((x1 + 1/(sin(x1) + 0....  \n",
      "17  PySRFunction(X=>x0 + x1*((x1 + 1/(sin(x1) + 0....  \n",
      "18  PySRFunction(X=>x0 + x1*((x1 + 1/(sin(x1) + 0....  \n",
      "19  PySRFunction(X=>x0 + x1*((x0 + x1)*(-8.387242e...  \n",
      "🏆 Best: (x1 * ((inv(x0 + -70.16485) + 1.070555) + ((inv(sin(x1) + 0.8001926) + x1) * -0.0001287453))) + x0 (loss: 1.008772e+03)\n"
     ]
    }
   ],
   "source": [
    "symbolic_model_addition.show_symbolic_expression()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9aa03a2b",
   "metadata": {},
   "source": [
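    "`symbolic_model_addition` now holds a list of candidate equations, ordered by complexity. More complex equations fit the input $\\rightarrow$ output mapping more closely, but are also more likely to overfit. The 'best' equation is the one that best balances complexity against accuracy: roughly, a large gain in accuracy per unit of added complexity, among the equations whose loss is already close to the lowest loss found."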
   ]
  },
  {
   "cell_type": "markdown",
   "id": "370ae0ba",
   "metadata": {},
   "source": [
    "Let's see how the LLM performs other tasks."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8389a0fa",
   "metadata": {},
   "source": [
    "## Multiplication"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "0b16fd23",
   "metadata": {},
   "outputs": [],
   "source": [
    "def llm_multiplication(X):\n",
    "    outputs = []\n",
    "    # X is of shape (N, 2): each row holds the two integers to multiply\n",
    "    for n in range(X.shape[0]):\n",
    "        a = X[n,0]\n",
    "        b = X[n,1]\n",
    "        output = llm_call(f\"Return only the numeric answer in the format $boxed$. What is {int(a)} * {int(b)}=?\")\n",
    "        output = extract_boxed_number(output)\n",
    "        outputs.append(output)\n",
    "    return np.array(outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "9cb3c36f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "No name specified for this block. Label is block_17429733968.\n"
     ]
    }
   ],
   "source": [
    "# Initialise our model\n",
    "symbolic_model_multiplication = SymbolicModel(llm_multiplication)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "81107959",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "🛠️ Running SR on output dimension 0 of 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/liz/PhD/SymTorch_project/symtorch_venv/lib/python3.11/site-packages/pysr/sr.py:2811: UserWarning: Note: it looks like you are running in Jupyter. The progress bar will be turned off.\n",
      "  warnings.warn(\n",
      "[ Info: Started!\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Expressions evaluated per second: 2.590e+06\n",
      "Progress: 10862 / 31000 total iterations (35.039%)\n",
      "════════════════════════════════════════════════════════════════════════════════════════════════════\n",
      "───────────────────────────────────────────────────────────────────────────────────────────────────\n",
      "Complexity  Loss       Score      Equation\n",
      "1           4.527e+10  0.000e+00  y = 2.2987e+05\n",
      "3           1.676e+10  -0.000e+00  y = x₁ * x₀\n",
      "5           1.429e+10  -0.000e+00  y = (x₁ * 0.85344) * x₀\n",
      "7           1.361e+10  -0.000e+00  y = ((x₀ * 0.77067) * x₁) + 38313\n",
      "9           1.129e+10  -0.000e+00  y = (x₁ * x₀) * ((x₀ * -0.00090611) + 1.5175)\n",
      "10          9.190e+07  -0.000e+00  y = ((x₀ * inv(x₁ + -982.95)) + x₁) * x₀\n",
      "11          7.972e+06  5.332e-01  y = x₀ * (inv(-0.054014 + sin(x₀)) + x₁)\n",
      "13          4.205e+06  -0.000e+00  y = x₁ * (x₀ + inv((sin(x₀) * -9.606) + 0.50798))\n",
      "15          3.842e+06  -0.000e+00  y = (x₀ + (inv((sin(x₀) * -9.606) + 0.50798) + -0.96585))...\n",
      "                                        * x₁\n",
      "17          3.818e+06  -0.000e+00  y = (x₀ + (inv((sin(x₀) * -9.606) + 0.50798) + -1.2062)) ...\n",
      "                                       * (x₁ + 0.40649)\n",
      "18          3.218e+06  -0.000e+00  y = x₁ * (x₀ + (inv((sin(x₁) * (inv(x₀) * x₁)) + -1.0185)...\n",
      "                                        * -0.67973))\n",
      "20          2.896e+06  -0.000e+00  y = ((inv(-1.0185 + ((x₁ * sin(x₁)) * inv(x₀))) * -0.6797...\n",
      "                                       3) + (-0.92131 + x₀)) * x₁\n",
      "22          2.857e+06  -0.000e+00  y = (((x₀ + -0.98589) + (inv(((x₁ * inv(x₀)) * sin(x₁)) +...\n",
      "                                        -1.0183) * -0.82886)) * x₁) + 0.36201\n",
      "25          2.841e+06  -0.000e+00  y = (sin(x₀) + x₁) * ((x₀ + -0.84241) + (inv(((inv(x₀) * ...\n",
      "                                       sin(x₁)) * x₁) + -1.0185) * -0.67973))\n",
      "27          2.828e+06  -0.000e+00  y = (((inv((x₁ * (sin(x₁) * inv(x₀))) + -1.0185) * -0.679...\n",
      "                                       73) + x₀) + -1.0116) * (x₁ + (sin(x₀) + 0.29815))\n",
      "29          2.828e+06  -0.000e+00  y = (((x₀ + -1.001) + (inv((x₁ * (inv(x₀) * sin(x₁))) + -...\n",
      "                                       1.0185) * -0.67973)) * ((x₁ + 0.3082) + sin(x₀))) + -14.6...\n",
      "                                       34\n",
      "30          2.826e+06  -0.000e+00  y = (x₁ + (sin(x₀) + 0.2785)) * ((x₀ + (-1.0116 + inv(x₀)...\n",
      "                                       )) + (inv((x₁ * (sin(x₁) * inv(x₀))) + -1.0185) * -0.6797...\n",
      "                                       3))\n",
      "───────────────────────────────────────────────────────────────────────────────────────────────────\n",
      "════════════════════════════════════════════════════════════════════════════════════════════════════\n",
      "Press 'q' and then <enter> to stop execution early.\n",
      "\n",
      "Expressions evaluated per second: 2.630e+06\n",
      "Progress: 20758 / 31000 total iterations (66.961%)\n",
      "════════════════════════════════════════════════════════════════════════════════════════════════════\n",
      "───────────────────────────────────────────────────────────────────────────────────────────────────\n",
      "Complexity  Loss       Score      Equation\n",
      "1           4.527e+10  0.000e+00  y = 2.2987e+05\n",
      "3           1.676e+10  -0.000e+00  y = x₁ * x₀\n",
      "5           1.429e+10  -0.000e+00  y = (x₁ * 0.85344) * x₀\n",
      "7           1.361e+10  -0.000e+00  y = ((x₀ * 0.77067) * x₁) + 38313\n",
      "9           1.129e+10  -0.000e+00  y = ((x₀ * -0.00090613) + 1.5175) * (x₁ * x₀)\n",
      "10          9.190e+07  -0.000e+00  y = ((x₀ * inv(x₁ + -982.95)) + x₁) * x₀\n",
      "11          7.959e+06  5.334e-01  y = (x₁ + inv(-0.054013 + sin(x₀))) * x₀\n",
      "13          3.895e+06  1.510e-01  y = x₁ * (inv(1.6555 + (-31.261 * sin(x₀))) + x₀)\n",
      "15          3.665e+06  -0.000e+00  y = (x₀ + (-0.82677 + inv((sin(x₀) * -20.657) + 1.0936)))...\n",
      "                                        * x₁\n",
      "17          3.655e+06  -0.000e+00  y = (x₀ + (inv((sin(x₀) * -20.657) + 1.0936) + -0.95681))...\n",
      "                                        * (x₁ + 0.23308)\n",
      "18          3.197e+06  -0.000e+00  y = (x₀ + (inv((x₁ * (sin(x₁) * inv(x₀))) + -1.0183) * -0...\n",
      "                                       .82948)) * x₁\n",
      "20          2.857e+06  -0.000e+00  y = x₁ * (x₀ + ((inv((sin(x₁) * (x₁ * inv(x₀))) + -1.0183...\n",
      "                                       ) * -0.82886) + -0.96585))\n",
      "22          2.854e+06  -0.000e+00  y = (((x₀ + -0.65273) + (inv((x₁ * (sin(x₁) * inv(x₀))) +...\n",
      "                                        -1.0183) * -0.82886)) * x₁) + -160.49\n",
      "23          2.854e+06  -0.000e+00  y = x₁ * ((inv(((inv(x₀) * sin(x₁)) * x₁) + -1.0183) * -0...\n",
      "                                       .82886) + (x₀ + (inv(x₀) + -0.9719)))\n",
      "25          2.790e+06  -0.000e+00  y = (x₁ + sin(x₀)) * ((inv(((inv(x₀) * x₁) * sin(x₁)) + -...\n",
      "                                       1.0183) * -0.82886) + (-0.88726 + x₀))\n",
      "27          2.780e+06  -0.000e+00  y = (((inv(((x₁ * inv(x₀)) * sin(x₁)) + -1.0183) * -0.828...\n",
      "                                       86) + x₀) + -1.0183) * (x₁ + (sin(x₀) + 0.24672))\n",
      "29          2.779e+06  -0.000e+00  y = (((x₁ + sin(x₀)) + 0.25208) * (((inv(((x₁ * sin(x₁)) ...\n",
      "                                       * inv(x₀)) + -1.0183) * -0.82886) + -1.0183) + x₀)) + -17...\n",
      "                                       .132\n",
      "30          2.759e+06  -0.000e+00  y = (x₁ + sin(x₀)) * ((x₀ + ((inv(((x₁ * inv(x₀)) * sin(x...\n",
      "                                       ₁)) + -1.0183) * -0.82886) + inv(x₀ * 0.036934))) + -1.07...\n",
      "                                       66)\n",
      "───────────────────────────────────────────────────────────────────────────────────────────────────\n",
      "════════════════════════════════════════════════════════════════════════════════════════════════════\n",
      "Press 'q' and then <enter> to stop execution early.\n",
      "\n",
      "Expressions evaluated per second: 2.640e+06\n",
      "Progress: 30366 / 31000 total iterations (97.955%)\n",
      "════════════════════════════════════════════════════════════════════════════════════════════════════\n",
      "───────────────────────────────────────────────────────────────────────────────────────────────────\n",
      "Complexity  Loss       Score      Equation\n",
      "1           4.527e+10  0.000e+00  y = 2.2987e+05\n",
      "3           1.676e+10  -0.000e+00  y = x₁ * x₀\n",
      "5           1.429e+10  -0.000e+00  y = (x₁ * 0.85344) * x₀\n",
      "7           1.361e+10  -0.000e+00  y = ((x₀ * 0.77067) * x₁) + 38313\n",
      "9           1.129e+10  -0.000e+00  y = ((x₀ * -0.00090613) + 1.5175) * (x₁ * x₀)\n",
      "10          9.190e+07  -0.000e+00  y = ((x₀ * inv(x₁ + -982.95)) + x₁) * x₀\n",
      "11          7.957e+06  5.334e-01  y = x₀ * (inv(sin(x₀) + -0.054013) + x₁)\n",
      "13          3.895e+06  1.509e-01  y = x₁ * (inv(1.6555 + (-31.261 * sin(x₀))) + x₀)\n",
      "15          3.665e+06  -0.000e+00  y = (x₀ + (-0.82677 + inv((sin(x₀) * -20.657) + 1.0936)))...\n",
      "                                        * x₁\n",
      "17          3.655e+06  -0.000e+00  y = ((inv(1.0936 + (-20.657 * sin(x₀))) + -0.96823) + x₀)...\n",
      "                                        * (0.23308 + x₁)\n",
      "18          3.197e+06  -0.000e+00  y = (x₀ + (inv((x₁ * (sin(x₁) * inv(x₀))) + -1.0183) * -0...\n",
      "                                       .82948)) * x₁\n",
      "20          2.857e+06  -0.000e+00  y = x₁ * (x₀ + ((inv((sin(x₁) * (x₁ * inv(x₀))) + -1.0183...\n",
      "                                       ) * -0.82886) + -0.96585))\n",
      "22          2.852e+06  -0.000e+00  y = (((x₀ + -0.82886) + (inv(((x₁ * sin(x₁)) * inv(x₀)) +...\n",
      "                                        -1.0183) * -0.82886)) * x₁) + -113.38\n",
      "25          2.790e+06  -0.000e+00  y = (x₁ + sin(x₀)) * ((inv(((inv(x₀) * x₁) * sin(x₁)) + -...\n",
      "                                       1.0183) * -0.82886) + (-0.88726 + x₀))\n",
      "27          2.780e+06  -0.000e+00  y = ((x₀ + -1.0204) + (-0.82886 * inv((x₁ * (sin(x₁) * in...\n",
      "                                       v(x₀))) + -1.0183))) * ((x₁ + 0.23138) + sin(x₀))\n",
      "29          2.776e+06  -0.000e+00  y = ((((inv(((sin(x₁) * inv(x₀)) * x₁) + -1.0183) * -0.82...\n",
      "                                       886) + -0.98854) + x₀) * ((sin(x₀) + 0.31478) + x₁)) + -8...\n",
      "                                       4.007\n",
      "30          2.730e+06  -0.000e+00  y = (sin(x₀) + x₁) * ((-0.91582 + ((sin(x₀) + inv(((sin(x...\n",
      "                                       ₁) * x₁) * inv(x₀)) + -1.0183)) * -0.82886)) + x₀)\n",
      "───────────────────────────────────────────────────────────────────────────────────────────────────\n",
      "════════════════════════════════════════════════════════════════════════════════════════════════════\n",
      "Press 'q' and then <enter> to stop execution early.\n",
      "───────────────────────────────────────────────────────────────────────────────────────────────────\n",
      "Complexity  Loss       Score      Equation\n",
      "1           4.527e+10  0.000e+00  y = 2.2987e+05\n",
      "3           1.676e+10  -0.000e+00  y = x₁ * x₀\n",
      "5           1.429e+10  -0.000e+00  y = (x₁ * 0.85344) * x₀\n",
      "7           1.361e+10  -0.000e+00  y = ((x₀ * 0.77067) * x₁) + 38313\n",
      "9           1.129e+10  -0.000e+00  y = ((x₀ * -0.00090613) + 1.5175) * (x₁ * x₀)\n",
      "10          9.190e+07  -0.000e+00  y = ((x₀ * inv(x₁ + -982.95)) + x₁) * x₀\n",
      "11          7.957e+06  5.334e-01  y = x₀ * (inv(sin(x₀) + -0.054013) + x₁)\n",
      "13          3.895e+06  1.509e-01  y = x₁ * (inv(1.6555 + (-31.261 * sin(x₀))) + x₀)\n",
      "15          3.665e+06  -0.000e+00  y = (x₀ + (-0.82677 + inv((sin(x₀) * -20.657) + 1.0936)))...\n",
      "                                        * x₁\n",
      "17          3.655e+06  -0.000e+00  y = (x₁ + 0.24277) * ((inv((sin(x₀) * -20.657) + 1.0936) ...\n",
      "                                       + -0.96823) + x₀)\n",
      "18          3.197e+06  -0.000e+00  y = (x₀ + (inv((x₁ * (sin(x₁) * inv(x₀))) + -1.0183) * -0...\n",
      "                                       .82948)) * x₁\n",
      "20          2.857e+06  -0.000e+00  y = x₁ * (x₀ + ((inv((sin(x₁) * (x₁ * inv(x₀))) + -1.0183...\n",
      "                                       ) * -0.82886) + -0.96585))\n",
      "22          2.852e+06  -0.000e+00  y = (((x₀ + -0.82886) + (inv(((x₁ * sin(x₁)) * inv(x₀)) +...\n",
      "                                        -1.0183) * -0.82886)) * x₁) + -113.38\n",
      "25          2.790e+06  -0.000e+00  y = (x₁ + sin(x₀)) * ((inv(((inv(x₀) * x₁) * sin(x₁)) + -...\n",
      "                                       1.0183) * -0.82886) + (-0.88726 + x₀))\n",
      "27          2.780e+06  -0.000e+00  y = ((x₀ + -1.0204) + (-0.82886 * inv((x₁ * (sin(x₁) * in...\n",
      "                                       v(x₀))) + -1.0183))) * ((x₁ + 0.23138) + sin(x₀))\n",
      "29          2.776e+06  -0.000e+00  y = ((sin(x₀) + (0.31478 + x₁)) * (x₀ + ((inv(((x₁ * sin(...\n",
      "                                       x₁)) * inv(x₀)) + -1.0183) * -0.82886) + -0.97458))) + -8...\n",
      "                                       4.007\n",
      "30          2.730e+06  -0.000e+00  y = (sin(x₀) + x₁) * ((-0.91582 + ((sin(x₀) + inv(((sin(x...\n",
      "                                       ₁) * x₁) * inv(x₀)) + -1.0183)) * -0.82886)) + x₀)\n",
      "───────────────────────────────────────────────────────────────────────────────────────────────────\n",
      "💡Best equation for output 0 found to be x1 * (inv(1.6555322 + (-31.26067 * sin(x0))) + x0).\n",
      "❤️ SR on block_17429733968 complete.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[ Info: Final population:\n",
      "[ Info: Results saved to:\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{0: PySRRegressor.equations_ = [\n",
       " \t    pick     score                                           equation  \\\n",
       " \t0         0.000000                                          229870.23   \n",
       " \t1         0.496819                                            x1 * x0   \n",
       " \t2         0.079785                              (x1 * 0.8534426) * x0   \n",
       " \t3         0.024371               ((x0 * 0.77066994) * x1) + 38313.215   \n",
       " \t4         0.093507    ((x0 * -0.00090613356) + 1.5175121) * (x1 * x0)   \n",
       " \t5         4.810631             ((x0 * inv(x1 + -982.9535)) + x1) * x0   \n",
       " \t6         2.446648            x0 * (inv(sin(x0) + -0.054012753) + x1)   \n",
       " \t7   >>>>  0.357153  x1 * (inv(1.6555322 + (-31.26067 * sin(x0))) +...   \n",
       " \t8         0.030416  (x0 + (-0.82677305 + inv((sin(x0) * -20.656754...   \n",
       " \t9         0.001398  (x1 + 0.24277031) * ((inv((sin(x0) * -20.65675...   \n",
       " \t10        0.133837  (x0 + (inv((x1 * (sin(x1) * inv(x0))) + -1.018...   \n",
       " \t11        0.056327  x1 * (x0 + ((inv((sin(x1) * (x1 * inv(x0))) + ...   \n",
       " \t12        0.000825  (((x0 + -0.82885695) + (inv(((x1 * sin(x1)) * ...   \n",
       " \t13        0.007344  (x1 + sin(x0)) * ((inv(((inv(x0) * x1) * sin(x...   \n",
       " \t14        0.001708  ((x0 + -1.020432) + (-0.82885695 * inv((x1 * (...   \n",
       " \t15        0.000746  ((sin(x0) + (0.31477514 + x1)) * (x0 + ((inv((...   \n",
       " \t16        0.016927  (sin(x0) + x1) * ((-0.91581607 + ((sin(x0) + i...   \n",
       " \t\n",
       " \t            loss  complexity  \n",
       " \t0   4.526662e+10           1  \n",
       " \t1   1.675895e+10           3  \n",
       " \t2   1.428719e+10           5  \n",
       " \t3   1.360751e+10           7  \n",
       " \t4   1.128650e+10           9  \n",
       " \t5   9.190282e+07          10  \n",
       " \t6   7.957250e+06          11  \n",
       " \t7   3.895329e+06          13  \n",
       " \t8   3.665433e+06          15  \n",
       " \t9   3.655202e+06          17  \n",
       " \t10  3.197324e+06          18  \n",
       " \t11  2.856680e+06          20  \n",
       " \t12  2.851968e+06          22  \n",
       " \t13  2.789822e+06          25  \n",
       " \t14  2.780307e+06          27  \n",
       " \t15  2.776162e+06          29  \n",
       " \t16  2.729567e+06          30  \n",
       " ]}"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  - SR_output/block_17429733968/dim0_1764339129/hall_of_fame.csv\n"
     ]
    }
   ],
   "source": [
    "# Perform SR on our model\n",
    "symbolic_model_multiplication.distill(X, sr_params=sr_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "fd2b8003",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "➡️ Standard symbolic expressions for output dimension 0:\n",
      "    complexity          loss  \\\n",
      "0            1  4.526662e+10   \n",
      "1            3  1.675895e+10   \n",
      "2            5  1.428719e+10   \n",
      "3            7  1.360751e+10   \n",
      "4            9  1.128650e+10   \n",
      "5           10  9.190282e+07   \n",
      "6           11  7.957250e+06   \n",
      "7           13  3.895329e+06   \n",
      "8           15  3.665433e+06   \n",
      "9           17  3.655202e+06   \n",
      "10          18  3.197324e+06   \n",
      "11          20  2.856680e+06   \n",
      "12          22  2.851968e+06   \n",
      "13          25  2.789822e+06   \n",
      "14          27  2.780307e+06   \n",
      "15          29  2.776162e+06   \n",
      "16          30  2.729567e+06   \n",
      "\n",
      "                                             equation     score  \\\n",
      "0                                           229870.23  0.000000   \n",
      "1                                             x1 * x0  0.496819   \n",
      "2                               (x1 * 0.8534426) * x0  0.079785   \n",
      "3                ((x0 * 0.77066994) * x1) + 38313.215  0.024371   \n",
      "4     ((x0 * -0.00090613356) + 1.5175121) * (x1 * x0)  0.093507   \n",
      "5              ((x0 * inv(x1 + -982.9535)) + x1) * x0  4.810631   \n",
      "6             x0 * (inv(sin(x0) + -0.054012753) + x1)  2.446648   \n",
      "7   x1 * (inv(1.6555322 + (-31.26067 * sin(x0))) +...  0.357153   \n",
      "8   (x0 + (-0.82677305 + inv((sin(x0) * -20.656754...  0.030416   \n",
      "9   (x1 + 0.24277031) * ((inv((sin(x0) * -20.65675...  0.001398   \n",
      "10  (x0 + (inv((x1 * (sin(x1) * inv(x0))) + -1.018...  0.133837   \n",
      "11  x1 * (x0 + ((inv((sin(x1) * (x1 * inv(x0))) + ...  0.056327   \n",
      "12  (((x0 + -0.82885695) + (inv(((x1 * sin(x1)) * ...  0.000825   \n",
      "13  (x1 + sin(x0)) * ((inv(((inv(x0) * x1) * sin(x...  0.007344   \n",
      "14  ((x0 + -1.020432) + (-0.82885695 * inv((x1 * (...  0.001708   \n",
      "15  ((sin(x0) + (0.31477514 + x1)) * (x0 + ((inv((...  0.000746   \n",
      "16  (sin(x0) + x1) * ((-0.91581607 + ((sin(x0) + i...  0.016927   \n",
      "\n",
      "                                         sympy_format  \\\n",
      "0                                    229870.230000000   \n",
      "1                                               x0*x1   \n",
      "2                                     x1*0.8534426*x0   \n",
      "3                        x0*0.77066994*x1 + 38313.215   \n",
      "4             x0*x1*(1.5175121 + x0*(-0.00090613356))   \n",
      "5                        x0*(x0/(x1 - 982.9535) + x1)   \n",
      "6                 x0*(x1 + 1/(sin(x0) - 0.054012753))   \n",
      "7          x1*(x0 + 1/(1.6555322 - 31.26067*sin(x0)))   \n",
      "8   x1*(x0 - 0.82677305 + 1/(1.093595 + sin(x0)*(-...   \n",
      "9   (x1 + 0.24277031)*(x0 - 0.9682323 + 1/(1.09359...   \n",
      "10    x1*(x0 - 0.829478/(-1.0183077 + x1*sin(x1)/x0))   \n",
      "11  x1*(x0 - 0.9658466 - 0.82885695/(-1.0183077 + ...   \n",
      "12  x1*(x0 - 0.82885695 - 0.82885695/(-1.0183077 +...   \n",
      "13  (x1 + sin(x0))*(x0 - 0.88726354 - 0.82885695/(...   \n",
      "14  (x0 - 1.020432 - 0.82885695/(-1.0183077 + x1*s...   \n",
      "15  (x0 - 0.97458076 - 0.82885695/(-1.0183077 + x1...   \n",
      "16  (x1 + sin(x0))*(x0 + (sin(x0) + 1/(-1.0183077 ...   \n",
      "\n",
      "                                        lambda_format  \n",
      "0                   PySRFunction(X=>229870.230000000)  \n",
      "1                              PySRFunction(X=>x0*x1)  \n",
      "2                    PySRFunction(X=>x1*0.8534426*x0)  \n",
      "3       PySRFunction(X=>x0*0.77066994*x1 + 38313.215)  \n",
      "4   PySRFunction(X=>x0*x1*(1.5175121 + x0*(-0.0009...  \n",
      "5       PySRFunction(X=>x0*(x0/(x1 - 982.9535) + x1))  \n",
      "6   PySRFunction(X=>x0*(x1 + 1/(sin(x0) - 0.054012...  \n",
      "7   PySRFunction(X=>x1*(x0 + 1/(1.6555322 - 31.260...  \n",
      "8   PySRFunction(X=>x1*(x0 - 0.82677305 + 1/(1.093...  \n",
      "9   PySRFunction(X=>(x1 + 0.24277031)*(x0 - 0.9682...  \n",
      "10  PySRFunction(X=>x1*(x0 - 0.829478/(-1.0183077 ...  \n",
      "11  PySRFunction(X=>x1*(x0 - 0.9658466 - 0.8288569...  \n",
      "12  PySRFunction(X=>x1*(x0 - 0.82885695 - 0.828856...  \n",
      "13  PySRFunction(X=>(x1 + sin(x0))*(x0 - 0.8872635...  \n",
      "14  PySRFunction(X=>(x0 - 1.020432 - 0.82885695/(-...  \n",
      "15  PySRFunction(X=>(x0 - 0.97458076 - 0.82885695/...  \n",
      "16  PySRFunction(X=>(x1 + sin(x0))*(x0 + (sin(x0) ...  \n",
      "🏆 Best: x1 * (inv(1.6555322 + (-31.26067 * sin(x0))) + x0) (loss: 3.895329e+06)\n"
     ]
    }
   ],
   "source": [
    "symbolic_model_multiplication.show_symbolic_expression()"
   ]
  },
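  {
   "cell_type": "markdown",
   "id": "c4e1a7f0",
   "metadata": {},
   "source": [
    "The `score` column above can be reproduced from the `loss` and `complexity` columns. The cell below is a minimal sketch, not SymTorch's own code: it assumes PySR's definition of the score (the drop in log-loss per unit of added complexity) and PySR's default `best` selection rule (the highest score among equations whose loss is within 1.5x of the lowest loss). The arrays are copied by hand from the multiplication table, and the variable names are illustrative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e93b5d21",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Complexity and loss columns copied from the multiplication table above\n",
    "complexity = np.array([1, 3, 5, 7, 9, 10, 11, 13, 15, 17, 18, 20, 22, 25, 27, 29, 30])\n",
    "loss = np.array([\n",
    "    4.526662e10, 1.675895e10, 1.428719e10, 1.360751e10, 1.128650e10,\n",
    "    9.190282e07, 7.957250e06, 3.895329e06, 3.665433e06, 3.655202e06,\n",
    "    3.197324e06, 2.856680e06, 2.851968e06, 2.789822e06, 2.780307e06,\n",
    "    2.776162e06, 2.729567e06,\n",
    "])\n",
    "\n",
    "# Assumed score definition: drop in log-loss per unit of added complexity\n",
    "# (the first row has no predecessor, so its score is 0)\n",
    "score = np.zeros_like(loss)\n",
    "score[1:] = -np.diff(np.log(loss)) / np.diff(complexity)\n",
    "\n",
    "# Assumed selection rule: highest score among equations whose loss is\n",
    "# within 1.5x of the lowest loss\n",
    "candidates = loss <= 1.5 * loss.min()\n",
    "best_row = int(np.argmax(np.where(candidates, score, -np.inf)))\n",
    "\n",
    "print(best_row, complexity[best_row], round(score[best_row], 3))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b7d20ac",
   "metadata": {},
   "source": [
    "Under these assumptions the rule selects row 7 (complexity 13), the row marked `>>>>` in the PySR output. Rows 5 and 6 have higher scores, but their losses are well above the 1.5x threshold, so they are not candidates."
   ]
  },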
  {
   "cell_type": "markdown",
   "id": "792a67fe",
   "metadata": {},
   "source": [
    "## Counting"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "111d0579",
   "metadata": {},
   "source": [
    "What does the LLM return when counting the number of 1s in a string of 1s and 0s?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "531cdfe2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4.0"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "extract_boxed_number(llm_call(\"Return only the numeric answer in the format $boxed$. How many 1s are there in the string 000101\", max_tokens=250))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "144e5c28",
   "metadata": {},
   "outputs": [],
   "source": [
    "def random_number_string_01(N = 100, len_sequence = 4):\n",
    "    return np.random.randint(0, 2, size=(N, len_sequence))\n",
    "\n",
    "X_counts_01 = random_number_string_01(N = 25)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "516bdbcb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def llm_counting(X):\n",
    "    outputs = []\n",
    "    # X is of shape (N, len_sequence), with every entry 0 or 1\n",
    "    for n in range(X.shape[0]):\n",
    "        # Join the row into a single string, e.g. [0, 1, 0, 1] -> '0101'\n",
    "        sequence = ''.join(map(str, X[n,:]))\n",
    "        prompt = f\"Return only the numeric answer in the format $boxed$. How many 1s are there in the string {sequence}\"\n",
    "        output = llm_call(prompt, max_tokens=250)\n",
    "\n",
    "        try:\n",
    "            output = extract_boxed_number(output)\n",
    "        except ValueError:\n",
    "            # No boxed answer in the reply, so retry once with a larger token budget\n",
    "            print(\"No boxed number found. Trying again with more tokens.\")\n",
    "            output = llm_call(prompt, max_tokens=500)\n",
    "            output = extract_boxed_number(output)\n",
    "        outputs.append(output)\n",
    "    return np.array(outputs)"
   ]
  },
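  {
   "cell_type": "markdown",
   "id": "2f8a6c13",
   "metadata": {},
   "source": [
    "For comparison, the exact answer the LLM is being asked for is simply the row sum, $x_0 + x_1 + x_2 + x_3$. The cell below is a small sketch of that ground truth. It builds a binary matrix with the same shape as the output of `random_number_string_01(N = 25)`; `X_demo`, `true_counts` and the fixed seed are illustrative choices, not part of the original workflow."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d04be59",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Illustrative stand-in for X_counts_01 (hypothetical names, fixed seed)\n",
    "rng = np.random.default_rng(0)\n",
    "X_demo = rng.integers(0, 2, size=(25, 4))\n",
    "\n",
    "# The exact number of 1s in each row is the row sum\n",
    "true_counts = X_demo.sum(axis=1)\n",
    "\n",
    "# Cross-check against counting characters in the string sent to the LLM\n",
    "sequences = [''.join(map(str, row)) for row in X_demo]\n",
    "string_counts = np.array([s.count('1') for s in sequences])\n",
    "\n",
    "print(np.array_equal(true_counts, string_counts))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "91ce4d7a",
   "metadata": {},
   "source": [
    "An LLM that counted perfectly would therefore distill to a plain sum of its inputs, which gives a useful baseline when reading the equations found by symbolic regression."
   ]
  },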
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "60a45816",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "No name specified for this block. Label is block_14826965584.\n"
     ]
    }
   ],
   "source": [
    "# Initialise our model\n",
    "symbolic_model_counting = SymbolicModel(llm_counting)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "78ffff43",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "🛠️ Running SR on output dimension 0 of 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/liz/PhD/SymTorch_project/symtorch_venv/lib/python3.11/site-packages/pysr/sr.py:2811: UserWarning: Note: it looks like you are running in Jupyter. The progress bar will be turned off.\n",
      "  warnings.warn(\n",
      "[ Info: Started!\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Expressions evaluated per second: 1.970e+06\n",
      "Progress: 11800 / 31000 total iterations (38.065%)\n",
      "════════════════════════════════════════════════════════════════════════════════════════════════════\n",
      "───────────────────────────────────────────────────────────────────────────────────────────────────\n",
      "Complexity  Loss       Score      Equation\n",
      "1           1.434e+00  0.000e+00  y = 2.92\n",
      "3           1.206e+00  8.627e-02  y = x₂ + 2.44\n",
      "5           1.078e+00  5.608e-02  y = (x₂ + 2.04) + x₁\n",
      "7           1.038e+00  1.924e-02  y = (x₀ * (x₁ + -1.0961)) + 3.25\n",
      "8           7.888e-01  2.743e-01  y = (x₂ * inv(x₃ + 0.43874)) + 2.2695\n",
      "9           7.258e-01  8.321e-02  y = (((x₃ * -2.0286) + 2.1385) * x₂) + 2.4615\n",
      "10          7.044e-01  2.995e-02  y = (inv(x₃ + 0.43035) * x₂) + (x₁ + 1.8595)\n",
      "11          6.458e-01  8.684e-02  y = (x₂ * ((x₃ * (x₁ + -2.1716)) + 2.1384)) + 2.4616\n",
      "12          4.649e-01  3.287e-01  y = (x₂ * inv(x₀ + (x₃ + 0.32121))) + (x₁ + 1.9131)\n",
      "13          3.969e-01  1.581e-01  y = ((((x₀ + -1.7462) + x₃) * (x₂ * -1.6667)) + x₁) + 1.92...\n",
      "                                      31\n",
      "15          2.307e-01  2.713e-01  y = ((((x₂ * -2.6489) + 0.9822) * (x₃ + (x₀ + -1.555))) + ...\n",
      "                                      x₁) + 2.2416\n",
      "17          1.548e-01  1.994e-01  y = (x₁ + ((x₂ + -0.43704) * ((x₁ + -3.0282) * (x₀ + (x₃ +...\n",
      "                                       -1.5562))))) + 2.2358\n",
      "19          6.625e-02  4.244e-01  y = (x₁ + ((((x₂ + -0.52102) * ((x₀ + -1.536) + x₃)) + -0....\n",
      "                                      49287) * (x₁ + -1.4819))) * 2.751\n",
      "21          1.929e-02  6.170e-01  y = (x₂ + 2.8941) * (((x₁ + -1.216) * (((x₀ + (x₃ + -1.354...\n",
      "                                      2)) * (x₂ + -0.57892)) + 0.43981)) + 1.1238)\n",
      "23          1.923e-02  1.640e-03  y = ((x₂ + 2.8592) * (((((x₂ + -0.5796) * (x₀ + (x₃ + -1.3...\n",
      "                                      686))) + 0.44353) * (x₁ + -1.2275)) + 1.08)) + 0.20065\n",
      "26          1.147e-02  1.720e-01  y = (x₂ + ((((((x₁ * exp(x₀)) + -2.4806) * ((x₀ + x₃) + -1...\n",
      "                                      .2217)) * (x₂ + -0.47523)) + x₁) * 1.5505)) + 1.538\n",
      "28          1.050e-02  4.443e-02  y = ((((((exp(x₀) * x₁) + -2.4802) * ((x₀ + -1.2517) + x₃)...\n",
      "                                      ) * (x₂ + -0.47585)) + x₁) * 1.5519) + ((x₂ * 0.88771) + 1...\n",
      "                                      .5881)\n",
      "29          2.930e-03  1.276e+00  y = (x₁ + ((((x₂ * 3.2764) + (x₁ + -1.636)) * ((x₀ * -0.96...\n",
      "                                      809) + (((x₃ + -0.27179) * ((x₀ * x₁) + -0.99686)) + 1.201...\n",
      "                                      ))) + 1.6211)) * 1.241\n",
      "───────────────────────────────────────────────────────────────────────────────────────────────────\n",
      "════════════════════════════════════════════════════════════════════════════════════════════════════\n",
      "Press 'q' and then <enter> to stop execution early.\n",
      "\n",
      "Expressions evaluated per second: 2.040e+06\n",
      "Progress: 23360 / 31000 total iterations (75.355%)\n",
      "════════════════════════════════════════════════════════════════════════════════════════════════════\n",
      "───────────────────────────────────────────────────────────────────────────────────────────────────\n",
      "Complexity  Loss       Score      Equation\n",
      "1           1.434e+00  0.000e+00  y = 2.92\n",
      "3           1.206e+00  8.627e-02  y = x₂ + 2.44\n",
      "5           1.078e+00  5.608e-02  y = (x₂ + 2.04) + x₁\n",
      "7           1.038e+00  1.924e-02  y = (x₀ * (x₁ + -1.0961)) + 3.25\n",
      "8           7.888e-01  2.743e-01  y = (x₂ * inv(x₃ + 0.43874)) + 2.2695\n",
      "9           7.258e-01  8.321e-02  y = (((x₃ * -2.0286) + 2.1385) * x₂) + 2.4615\n",
      "10          7.044e-01  2.995e-02  y = (inv(x₃ + 0.43035) * x₂) + (x₁ + 1.8595)\n",
      "11          6.458e-01  8.684e-02  y = (x₂ * ((x₃ * (x₁ + -2.1716)) + 2.1384)) + 2.4616\n",
      "12          4.649e-01  3.287e-01  y = (x₂ * inv(x₀ + (x₃ + 0.32121))) + (x₁ + 1.9131)\n",
      "13          3.969e-01  1.581e-01  y = ((((x₀ + -1.7462) + x₃) * (x₂ * -1.6667)) + x₁) + 1.92...\n",
      "                                      31\n",
      "15          2.307e-01  2.713e-01  y = ((((x₂ * -2.6489) + 0.9822) * (x₃ + (x₀ + -1.555))) + ...\n",
      "                                      x₁) + 2.2416\n",
      "17          1.098e-01  3.713e-01  y = ((x₁ + -3.2857) * (((x₂ + -0.4518) * ((x₃ + x₀) + -1.5...\n",
      "                                      69)) + 1.4834)) + 6.9132\n",
      "19          6.503e-02  2.619e-01  y = (((x₁ + -1.5731) * (((x₂ + -0.51391) * (x₀ + (x₃ + -1....\n",
      "                                      5436))) + 0.56624)) + 1.6908) * 2.5161\n",
      "21          1.929e-02  6.077e-01  y = (x₂ + 2.8942) * (((((x₃ + (x₀ + -1.3542)) * (x₂ + -0.5...\n",
      "                                      7892)) + 0.43979) * (x₁ + -1.216)) + 1.1237)\n",
      "23          1.921e-02  1.995e-03  y = ((((x₁ + -1.2293) * (((x₀ + (x₃ + -1.3652)) * (x₂ + -0...\n",
      "                                      .57943)) + 0.44524)) + 1.0949) * (x₂ + 2.8501)) + 0.16584\n",
      "25          3.278e-03  8.841e-01  y = (x₂ + 1.5019) + ((((x₀ + (x₃ + -1.2447)) * ((x₂ + -0.4...\n",
      "                                      9825) * ((x₁ * (x₀ + 1.5671)) + -2.5734))) + x₁) * 1.5524)\n",
      "27          3.059e-03  3.458e-02  y = (((((((x₀ + 1.5737) * x₁) + -2.5829) * ((x₂ + -0.4989)...\n",
      "                                       * ((x₀ + x₃) + -1.2597))) + x₁) * 1.5481) + 1.5288) + (x₂...\n",
      "                                       * 0.94361)\n",
      "29          2.198e-03  1.652e-01  y = (x₂ + ((((((x₂ + -0.48702) * ((x₁ * (x₀ + 1.5624)) + -...\n",
      "                                      2.5606)) * (x₀ + (x₃ + -1.2392))) + x₁) * 1.5589) + 1.447)...\n",
      "                                      ) + (x₃ * 0.081675)\n",
      "30          2.198e-03  1.455e-04  y = ((x₂ + (((x₀ + -1.2474) + x₃) * (((sin(x₁) * (x₃ + (x₀...\n",
      "                                       + 2.9041))) + -4.0016) * (x₂ + -0.50006)))) + (x₁ * 1.523...\n",
      "                                      )) + 1.5008\n",
      "───────────────────────────────────────────────────────────────────────────────────────────────────\n",
      "════════════════════════════════════════════════════════════════════════════════════════════════════\n",
      "Press 'q' and then <enter> to stop execution early.\n",
      "───────────────────────────────────────────────────────────────────────────────────────────────────\n",
      "Complexity  Loss       Score      Equation\n",
      "1           1.434e+00  0.000e+00  y = 2.92\n",
      "3           1.206e+00  8.627e-02  y = x₂ + 2.44\n",
      "5           1.078e+00  5.608e-02  y = (x₂ + 2.04) + x₁\n",
      "7           1.038e+00  1.924e-02  y = (x₀ * (x₁ + -1.0961)) + 3.25\n",
      "8           7.888e-01  2.743e-01  y = (x₂ * inv(x₃ + 0.43874)) + 2.2695\n",
      "9           7.258e-01  8.321e-02  y = (((x₃ * -2.0286) + 2.1385) * x₂) + 2.4615\n",
      "10          7.044e-01  2.995e-02  y = (inv(x₃ + 0.43035) * x₂) + (x₁ + 1.8595)\n",
      "11          6.458e-01  8.684e-02  y = (x₂ * ((x₃ * (x₁ + -2.1716)) + 2.1384)) + 2.4616\n",
      "12          4.649e-01  3.287e-01  y = (x₂ * inv(x₀ + (x₃ + 0.32121))) + (x₁ + 1.9131)\n",
      "13          3.969e-01  1.581e-01  y = ((((x₀ + -1.7462) + x₃) * (x₂ * -1.6667)) + x₁) + 1.92...\n",
      "                                      31\n",
      "15          2.307e-01  2.713e-01  y = ((((x₂ * -2.6489) + 0.9822) * (x₃ + (x₀ + -1.555))) + ...\n",
      "                                      x₁) + 2.2416\n",
      "17          1.098e-01  3.713e-01  y = ((x₁ + -3.2857) * (((x₂ + -0.4518) * ((x₃ + x₀) + -1.5...\n",
      "                                      69)) + 1.4834)) + 6.9132\n",
      "19          6.503e-02  2.619e-01  y = (((x₁ + -1.5731) * (((x₂ + -0.51391) * (x₀ + (x₃ + -1....\n",
      "                                      5436))) + 0.56624)) + 1.6908) * 2.5161\n",
      "21          1.929e-02  6.077e-01  y = (x₂ + 2.8942) * (((((x₃ + (x₀ + -1.3542)) * (x₂ + -0.5...\n",
      "                                      7892)) + 0.43979) * (x₁ + -1.216)) + 1.1237)\n",
      "23          1.921e-02  1.995e-03  y = ((((x₁ + -1.2293) * (((x₀ + (x₃ + -1.3652)) * (x₂ + -0...\n",
      "                                      .57943)) + 0.44524)) + 1.0949) * (x₂ + 2.8501)) + 0.16584\n",
      "25          3.278e-03  8.841e-01  y = (x₂ + 1.5019) + ((((x₀ + (x₃ + -1.2447)) * ((x₂ + -0.4...\n",
      "                                      9825) * ((x₁ * (x₀ + 1.5671)) + -2.5734))) + x₁) * 1.5524)\n",
      "27          2.855e-03  6.903e-02  y = (((x₁ * 1.5302) + 1.5019) + x₂) + ((((((x₃ + 2.2809) +...\n",
      "                                       x₀) * x₁) + -3.9928) * ((x₃ + x₀) + -1.2457)) * (x₂ + -0....\n",
      "                                      49792))\n",
      "29          4.516e-04  9.221e-01  y = (((x₀ * -0.10471) + 1.6673) * (x₁ + (((((x₁ * (x₀ + 1....\n",
      "                                      4338)) + -2.5013) * (x₂ + -0.52095)) * ((x₀ + x₃) + -1.231...\n",
      "                                      9)) + 0.92096))) + x₂\n",
      "───────────────────────────────────────────────────────────────────────────────────────────────────\n",
      "💡Best equation for output 0 found to be (((x0 * -0.10471333) + 1.6673021) * (x1 + (((((x1 * (x0 + 1.4338295)) + -2.501253) * (x2 + -0.52095324)) * ((x0 + x3) + -1.2319472)) + 0.92096275))) + x2.\n",
      "❤️ SR on block_14826965584 complete.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[ Info: Final population:\n",
      "[ Info: Results saved to:\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{0: PySRRegressor.equations_ = [\n",
       " \t    pick     score                                           equation  \\\n",
       " \t0         0.000000                                          2.9200118   \n",
       " \t1         0.086274                                      x2 + 2.439992   \n",
       " \t2         0.056081                              (x2 + 2.0399914) + x1   \n",
       " \t3         0.019240                (x0 * (x1 + -1.096135)) + 3.2499862   \n",
       " \t4         0.274270            (x2 * inv(x3 + 0.43874383)) + 2.2695189   \n",
       " \t5         0.083207   (((x3 * -2.028587) + 2.138481) * x2) + 2.4615328   \n",
       " \t6         0.029947      (inv(x3 + 0.43034637) * x2) + (x1 + 1.859502)   \n",
       " \t7         0.086838  (x2 * ((x3 * (x1 + -2.171553)) + 2.1383862)) +...   \n",
       " \t8         0.328665  (x2 * inv(x0 + (x3 + 0.32121232))) + (x1 + 1.9...   \n",
       " \t9         0.158085  ((((x0 + -1.7461581) + x3) * (x2 * -1.6666652)...   \n",
       " \t10        0.271281  ((((x2 * -2.648921) + 0.9822007) * (x3 + (x0 +...   \n",
       " \t11        0.371268  ((x1 + -3.2856712) * (((x2 + -0.4517958) * ((x...   \n",
       " \t12        0.261910  (((x1 + -1.5730897) * (((x2 + -0.5139112) * (x...   \n",
       " \t13        0.607665  (x2 + 2.8942268) * (((((x3 + (x0 + -1.3542218)...   \n",
       " \t14        0.001995  ((((x1 + -1.2293217) * (((x0 + (x3 + -1.365233...   \n",
       " \t15        0.884148  (x2 + 1.5019486) + ((((x0 + (x3 + -1.2446526))...   \n",
       " \t16        0.069028  (((x1 * 1.5302026) + 1.5018678) + x2) + ((((((...   \n",
       " \t17  >>>>  0.922076  (((x0 * -0.10471333) + 1.6673021) * (x1 + ((((...   \n",
       " \t\n",
       " \t        loss  complexity  \n",
       " \t0   1.433600           1  \n",
       " \t1   1.206400           3  \n",
       " \t2   1.078400           5  \n",
       " \t3   1.037692           7  \n",
       " \t4   0.788777           8  \n",
       " \t5   0.725802           9  \n",
       " \t6   0.704389          10  \n",
       " \t7   0.645802          11  \n",
       " \t8   0.464902          12  \n",
       " \t9   0.396923          13  \n",
       " \t10  0.230714          15  \n",
       " \t11  0.109798          17  \n",
       " \t12  0.065028          19  \n",
       " \t13  0.019288          21  \n",
       " \t14  0.019211          23  \n",
       " \t15  0.003278          25  \n",
       " \t16  0.002855          27  \n",
       " \t17  0.000452          29  \n",
       " ]}"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  - SR_output/block_14826965584/dim0_1764340041/hall_of_fame.csv\n"
     ]
    }
   ],
   "source": [
    "symbolic_model_counting.distill(X_counts_01, sr_params=sr_params)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "167b6e3b",
   "metadata": {},
   "source": [
    "The LLM is really terrible at counting! The equations it learns are not remotely what you would expect ($x_0+x_1+...+x_N$)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "6989015c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "➡️ Standard symbolic expressions for output dimension 0:\n",
      "    complexity      loss                                           equation  \\\n",
      "0            1  1.433600                                          2.9200118   \n",
      "1            3  1.206400                                      x2 + 2.439992   \n",
      "2            5  1.078400                              (x2 + 2.0399914) + x1   \n",
      "3            7  1.037692                (x0 * (x1 + -1.096135)) + 3.2499862   \n",
      "4            8  0.788777            (x2 * inv(x3 + 0.43874383)) + 2.2695189   \n",
      "5            9  0.725802   (((x3 * -2.028587) + 2.138481) * x2) + 2.4615328   \n",
      "6           10  0.704389      (inv(x3 + 0.43034637) * x2) + (x1 + 1.859502)   \n",
      "7           11  0.645802  (x2 * ((x3 * (x1 + -2.171553)) + 2.1383862)) +...   \n",
      "8           12  0.464902  (x2 * inv(x0 + (x3 + 0.32121232))) + (x1 + 1.9...   \n",
      "9           13  0.396923  ((((x0 + -1.7461581) + x3) * (x2 * -1.6666652)...   \n",
      "10          15  0.230714  ((((x2 * -2.648921) + 0.9822007) * (x3 + (x0 +...   \n",
      "11          17  0.109798  ((x1 + -3.2856712) * (((x2 + -0.4517958) * ((x...   \n",
      "12          19  0.065028  (((x1 + -1.5730897) * (((x2 + -0.5139112) * (x...   \n",
      "13          21  0.019288  (x2 + 2.8942268) * (((((x3 + (x0 + -1.3542218)...   \n",
      "14          23  0.019211  ((((x1 + -1.2293217) * (((x0 + (x3 + -1.365233...   \n",
      "15          25  0.003278  (x2 + 1.5019486) + ((((x0 + (x3 + -1.2446526))...   \n",
      "16          27  0.002855  (((x1 * 1.5302026) + 1.5018678) + x2) + ((((((...   \n",
      "17          29  0.000452  (((x0 * -0.10471333) + 1.6673021) * (x1 + ((((...   \n",
      "\n",
      "       score                                       sympy_format  \\\n",
      "0   0.000000                                   2.92001180000000   \n",
      "1   0.086274                                      x2 + 2.439992   \n",
      "2   0.056081                                x1 + x2 + 2.0399914   \n",
      "3   0.019240                     x0*(x1 - 1.096135) + 3.2499862   \n",
      "4   0.274270                   x2/(x3 + 0.43874383) + 2.2695189   \n",
      "5   0.083207         x2*(2.138481 + x3*(-2.028587)) + 2.4615328   \n",
      "6   0.029947               x1 + x2/(x3 + 0.43034637) + 1.859502   \n",
      "7   0.086838    x2*(x3*(x1 - 2.171553) + 2.1383862) + 2.4616337   \n",
      "8   0.328665          x1 + x2/(x0 + x3 + 0.32121232) + 1.913123   \n",
      "9   0.158085  x1 + (x0 + x3 - 1.7461581)*x2*(-1.6666652) + 1...   \n",
      "10  0.271281  x1 + (0.9822007 + x2*(-2.648921))*(x0 + x3 - 1...   \n",
      "11  0.371268  (x1 - 3.2856712)*((x2 - 0.4517958)*(x0 + x3 - ...   \n",
      "12  0.261910  ((x1 - 1.5730897)*((x2 - 0.5139112)*(x0 + x3 -...   \n",
      "13  0.607665  (x2 + 2.8942268)*((x1 - 1.2159615)*((x2 - 0.57...   \n",
      "14  0.001995  (x2 + 2.8500886)*((x1 - 1.2293217)*((x2 - 0.57...   \n",
      "15  0.884148  x2 + (x1 + (x2 - 0.4982464)*(x1*(x0 + 1.567087...   \n",
      "16  0.069028  x1*1.5302026 + x2 + (x2 - 0.4979197)*(x1*(x0 +...   \n",
      "17  0.922076  x2 + (1.6673021 + x0*(-0.10471333))*(x1 + (x2 ...   \n",
      "\n",
      "                                        lambda_format  \n",
      "0                   PySRFunction(X=>2.92001180000000)  \n",
      "1                      PySRFunction(X=>x2 + 2.439992)  \n",
      "2                PySRFunction(X=>x1 + x2 + 2.0399914)  \n",
      "3     PySRFunction(X=>x0*(x1 - 1.096135) + 3.2499862)  \n",
      "4   PySRFunction(X=>x2/(x3 + 0.43874383) + 2.2695189)  \n",
      "5   PySRFunction(X=>x2*(2.138481 + x3*(-2.028587))...  \n",
      "6   PySRFunction(X=>x1 + x2/(x3 + 0.43034637) + 1....  \n",
      "7   PySRFunction(X=>x2*(x3*(x1 - 2.171553) + 2.138...  \n",
      "8   PySRFunction(X=>x1 + x2/(x0 + x3 + 0.32121232)...  \n",
      "9   PySRFunction(X=>x1 + (x0 + x3 - 1.7461581)*x2*...  \n",
      "10  PySRFunction(X=>x1 + (0.9822007 + x2*(-2.64892...  \n",
      "11  PySRFunction(X=>(x1 - 3.2856712)*((x2 - 0.4517...  \n",
      "12  PySRFunction(X=>((x1 - 1.5730897)*((x2 - 0.513...  \n",
      "13  PySRFunction(X=>(x2 + 2.8942268)*((x1 - 1.2159...  \n",
      "14  PySRFunction(X=>(x2 + 2.8500886)*((x1 - 1.2293...  \n",
      "15  PySRFunction(X=>x2 + (x1 + (x2 - 0.4982464)*(x...  \n",
      "16  PySRFunction(X=>x1*1.5302026 + x2 + (x2 - 0.49...  \n",
      "17  PySRFunction(X=>x2 + (1.6673021 + x0*(-0.10471...  \n",
      "🏆 Best: (((x0 * -0.10471333) + 1.6673021) * (x1 + (((((x1 * (x0 + 1.4338295)) + -2.501253) * (x2 + -0.52095324)) * ((x0 + x3) + -1.2319472)) + 0.92096275))) + x2 (loss: 4.515823e-04)\n"
     ]
    }
   ],
   "source": [
    "symbolic_model_counting.show_symbolic_expression()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "604635e6",
   "metadata": {},
   "source": [
    "## Temperature conversion"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "915198e4",
   "metadata": {},
   "source": [
    "Let's see how the LLM calculates Celsius to Fahrenheit. We would expect $y = \\frac{9}{5}x + 32$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "5102cfd9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"To convert Celsius to Fahrenheit, multiply the Celsius temperature by 9/5 and add 32. Here's the formula: $F = \\\\frac{9}{5}C + 32$ where $C$ is the temperature in Celsius. Plug in the value of $C$ and solve for $F$. $F = \\\\frac{9}{5}(30) + 32$ $F = \\\\frac{270}{5} + 32$ $F = 54 + 32$ $F = 86$ Therefore, 30 degrees Celsius is equal to 86 degrees Fahrenheit.\""
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "llm_call(\"Return only the numeric answer in the format $boxed$. What is 30 degrees Celsius in Fahrenheit?\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "53d01801",
   "metadata": {},
   "source": [
    "First, let's try with temperatures that are within a regular range (ie. between -20 and 200C)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "1d2d0e6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def llm_C_to_F(X):\n",
    "    outputs = []\n",
    "    # X is of shape (N,1)\n",
    "    for n in range(X.shape[0]):\n",
    "        temp_C = X[n,0]\n",
    "        output = llm_call(f\"Return only the numeric answer in the format $boxed$. What is {int(temp_C)} degreees Celsius in Fahrenheit?\")\n",
    "        try:\n",
    "            output = extract_boxed_number(output)\n",
    "        except ValueError:\n",
    "            print(\"No boxed number found. Trying again with more tokens.\")\n",
    "            output = llm_call(f\"Return only the numeric answer in the format $boxed$. What is {int(temp_C)} degreees Celsius in Fahrenheit?\", max_tokens= 500)\n",
    "            output = extract_boxed_number(output)\n",
    "        outputs.append(output)\n",
    "    return np.array(outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "18fe0c41",
   "metadata": {},
   "outputs": [],
   "source": [
    "def random_numbers(N = 100, minimum = 0, maximum = 999):\n",
    "    return np.random.randint(minimum, maximum, size=(N, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "59107042",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_temps = random_numbers (N = 50, minimum=-20, maximum=200)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "1e95456c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "No name specified for this block. Label is block_17429614352.\n"
     ]
    }
   ],
   "source": [
    "# Initialise our model\n",
    "symbolic_model_C_to_F = SymbolicModel(llm_C_to_F)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "7ef7ea43",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "🛠️ Running SR on output dimension 0 of 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/liz/PhD/SymTorch_project/symtorch_venv/lib/python3.11/site-packages/pysr/sr.py:2811: UserWarning: Note: it looks like you are running in Jupyter. The progress bar will be turned off.\n",
      "  warnings.warn(\n",
      "[ Info: Started!\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Expressions evaluated per second: 2.010e+06\n",
      "Progress: 10198 / 31000 total iterations (32.897%)\n",
      "════════════════════════════════════════════════════════════════════════════════════════════════════\n",
      "───────────────────────────────────────────────────────────────────────────────────────────────────\n",
      "Complexity  Loss       Score      Equation\n",
      "1           5.150e+04  0.000e+00  y = 213.54\n",
      "3           3.533e+04  1.856e-01  y = x₀ * 2.2271\n",
      "5           3.522e+04  -0.000e+00  y = (x₀ * 2.0842) + 19.069\n",
      "7           3.515e+04  -0.000e+00  y = (x₀ * (x₀ * 0.010869)) + 78.179\n",
      "8           6.494e+03  1.686e+00  y = x₀ * (inv(x₀ + -168.86) + 1.9003)\n",
      "10          6.086e+03  3.212e-02  y = (x₀ + 16.528) * (inv(x₀ + -168.85) + 1.7258)\n",
      "12          5.954e+03  1.076e-02  y = x₀ + ((x₀ * (inv(x₀ + -168.86) + 0.57332)) + 42.782)\n",
      "14          5.948e+03  2.451e-04  y = (x₀ * (inv(x₀ + -168.86) + -0.394)) + ((x₀ + x₀) + 39....\n",
      "                                      498)\n",
      "16          5.948e+03  -0.000e+00  y = ((inv(x₀ + -168.86) + -0.394) * (x₀ + 0.63519)) + ((x...\n",
      "                                       ₀ + x₀) + 39.34)\n",
      "17          3.432e+03  5.495e-01  y = (((inv(x₀ + -168.86) * x₀) + x₀) + x₀) + inv((x₀ + -52...\n",
      "                                      .375) * -0.0070229)\n",
      "19          3.221e+03  3.149e-02  y = (0.86328 * x₀) + ((inv(-168.86 + x₀) * x₀) + (inv((x₀ ...\n",
      "                                      + -52.375) * -0.0070229) + x₀))\n",
      "21          3.168e+03  8.308e-03  y = (((inv(x₀ + -168.86) * x₀) + ((x₀ * 0.86328) + 5.6507)...\n",
      "                                      ) + x₀) + inv((x₀ + -52.375) * -0.0070229)\n",
      "23          3.085e+03  1.313e-02  y = (((inv(x₀ + -168.86) * x₀) + x₀) + 21.752) + ((inv(((x...\n",
      "                                      ₀ * 0.85589) + -44.893) * -0.005057) + x₀) * 0.71591)\n",
      "25          1.657e+03  3.107e-01  y = (((inv(x₀ + -168.86) * x₀) + (inv(0.1886 + sin(x₀)) + ...\n",
      "                                      x₀)) + x₀) + inv((x₀ + -52.375) * -0.0070229)\n",
      "27          1.635e+03  6.551e-03  y = inv(-0.0070229 * (x₀ + -52.375)) + ((((x₀ + x₀) + inv(...\n",
      "                                      sin(x₀) + 0.1886)) + (x₀ * inv(x₀ + -168.86))) * 0.98213)\n",
      "29          1.582e+03  1.636e-02  y = inv((x₀ + -52.375) * -0.0070229) + (inv(sin(x₀) + 0.18...\n",
      "                                      86) + ((((x₀ + x₀) + 7.3005) + (inv(x₀ + -168.86) * x₀)) *...\n",
      "                                       0.97359))\n",
      "───────────────────────────────────────────────────────────────────────────────────────────────────\n",
      "════════════════════════════════════════════════════════════════════════════════════════════════════\n",
      "Press 'q' and then <enter> to stop execution early.\n",
      "\n",
      "Expressions evaluated per second: 2.230e+06\n",
      "Progress: 21595 / 31000 total iterations (69.661%)\n",
      "════════════════════════════════════════════════════════════════════════════════════════════════════\n",
      "───────────────────────────────────────────────────────────────────────────────────────────────────\n",
      "Complexity  Loss       Score      Equation\n",
      "1           5.150e+04  0.000e+00  y = 213.54\n",
      "3           3.533e+04  1.856e-01  y = x₀ * 2.2271\n",
      "5           3.522e+04  -0.000e+00  y = (x₀ * 2.0842) + 19.069\n",
      "7           3.515e+04  -0.000e+00  y = (x₀ * (x₀ * 0.010869)) + 78.179\n",
      "8           6.492e+03  1.686e+00  y = x₀ * (inv(x₀ + -168.86) + 1.8887)\n",
      "10          6.086e+03  3.202e-02  y = (x₀ + 16.528) * (inv(x₀ + -168.85) + 1.7258)\n",
      "12          5.954e+03  1.076e-02  y = x₀ + ((x₀ * (inv(x₀ + -168.86) + 0.57332)) + 42.782)\n",
      "14          5.943e+03  6.285e-04  y = ((inv(x₀ + -168.86) + -0.42483) * x₀) + ((x₀ + x₀) + 4...\n",
      "                                      2.305)\n",
      "15          5.668e+03  4.688e-02  y = x₀ + (((inv(x₀ + -168.86) + inv(x₀ + -51.107)) + 0.863...\n",
      "                                      26) * x₀)\n",
      "17          3.212e+03  2.838e-01  y = x₀ + ((x₀ * (inv(x₀ + -168.86) + 0.86326)) + inv((x₀ +...\n",
      "                                       -52.376) * -0.0066555))\n",
      "19          3.175e+03  5.623e-03  y = (3.3255 + x₀) + ((x₀ * (inv(x₀ + -168.86) + 0.86326)) ...\n",
      "                                      + inv((x₀ + -52.376) * -0.0066555))\n",
      "21          3.070e+03  1.664e-02  y = (((inv(x₀ + -168.86) + 0.74751) * x₀) + 4.3337) + (inv...\n",
      "                                      ((x₀ + -52.376) * -0.0070607) + (x₀ + 14.136))\n",
      "25          1.651e+03  1.550e-01  y = (x₀ * inv(x₀ + -168.86)) + (inv(sin(x₀) + 0.1886) + ((...\n",
      "                                      inv((x₀ + -52.376) * -0.0067068) + x₀) + x₀))\n",
      "27          1.630e+03  6.350e-03  y = (((x₀ + inv((x₀ + -52.376) * -0.0065694)) + ((x₀ * inv...\n",
      "                                      (x₀ + -168.86)) + x₀)) + inv(sin(x₀) + 0.1886)) * 0.9864\n",
      "29          1.525e+03  3.316e-02  y = x₀ + (inv((x₀ + -52.375) * -0.0070229) + (((inv(sin(x₀...\n",
      "                                      ) + 0.1886) + 14.111) + (x₀ * inv(x₀ + -168.86))) + (x₀ * ...\n",
      "                                      0.87967)))\n",
      "───────────────────────────────────────────────────────────────────────────────────────────────────\n",
      "════════════════════════════════════════════════════════════════════════════════════════════════════\n",
      "Press 'q' and then <enter> to stop execution early.\n",
      "───────────────────────────────────────────────────────────────────────────────────────────────────\n",
      "Complexity  Loss       Score      Equation\n",
      "1           5.150e+04  0.000e+00  y = 213.54\n",
      "3           3.533e+04  1.856e-01  y = x₀ * 2.2271\n",
      "5           3.522e+04  -0.000e+00  y = (x₀ * 2.0842) + 19.069\n",
      "7           3.515e+04  -0.000e+00  y = ((x₀ * 0.010869) * x₀) + 78.177\n",
      "8           6.492e+03  1.686e+00  y = x₀ * (inv(x₀ + -168.86) + 1.8887)\n",
      "10          5.962e+03  4.230e-02  y = ((inv(x₀ + -168.86) + 1.633) * x₀) + 37.564\n",
      "12          5.954e+03  4.838e-04  y = x₀ + ((x₀ * (inv(x₀ + -168.86) + 0.57332)) + 42.782)\n",
      "13          3.278e+03  5.963e-01  y = x₀ * (inv(x₀ + -168.86) + (inv(x₀ + -51.868) + 1.8501)...\n",
      "                                      )\n",
      "15          3.043e+03  3.709e-02  y = (((inv(x₀ + -168.86) + inv(x₀ + -51.863)) + 1.6465) * ...\n",
      "                                      x₀) + 27.001\n",
      "17          3.042e+03  7.007e-05  y = (((inv(x₀ + -51.864) + 1.6465) + inv(x₀ + -168.87)) * ...\n",
      "                                      (x₀ + -0.60253)) + 28.113\n",
      "18          1.870e+03  4.865e-01  y = x₀ * (((inv(-129.84 + x₀) + inv(x₀ + -51.867)) + 1.816...\n",
      "                                      8) + inv(x₀ + -168.86))\n",
      "20          1.416e+03  1.388e-01  y = ((((inv(x₀ + -168.86) + 1.6777) + inv(x₀ + -129.65)) +...\n",
      "                                       inv(x₀ + -51.867)) * x₀) + 32.111\n",
      "22          1.416e+03  -0.000e+00  y = (x₀ * ((((inv(x₀ + -129.65) + 1.6465) + inv(x₀ + -51....\n",
      "                                       867)) + 0.031154) + inv(x₀ + -168.86))) + 32.099\n",
      "25          1.285e+03  3.249e-02  y = (x₀ * (inv(x₀ + -129.65) + ((inv(x₀ + -165.54) + 1.646...\n",
      "                                      5) + (inv(x₀ + -51.867) + inv(x₀ + -168.86))))) + 32.099\n",
      "27          1.154e+03  5.337e-02  y = (x₀ * (inv((x₀ + -167.16) + x₀) + ((inv(x₀ + -129.65) ...\n",
      "                                      + 1.6465) + (inv(x₀ + -168.86) + inv(x₀ + -51.867))))) + 3...\n",
      "                                      2.099\n",
      "29          1.149e+03  2.301e-03  y = (((((inv(x₀ + -168.86) + inv(x₀ + -51.867)) + inv((x₀ ...\n",
      "                                      + (-168.86 + x₀)) + 1.8319)) + 1.6465) + inv(x₀ + -129.65)...\n",
      "                                      ) * x₀) + 32.111\n",
      "───────────────────────────────────────────────────────────────────────────────────────────────────\n",
      "💡Best equation for output 0 found to be ((((inv(x0 + -168.86176) + 1.6776766) + inv(x0 + -129.65385)) + inv(x0 + -51.86743)) * x0) + 32.11095.\n",
      "❤️ SR on block_17429614352 complete.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[ Info: Final population:\n",
      "[ Info: Results saved to:\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{0: PySRRegressor.equations_ = [\n",
       " \t    pick     score                                           equation  \\\n",
       " \t0         0.000000                                           213.5406   \n",
       " \t1         0.188442                                     x0 * 2.2270966   \n",
       " \t2         0.001546                       (x0 * 2.0842118) + 19.069176   \n",
       " \t3         0.001028                 ((x0 * 0.01086909) * x0) + 78.1772   \n",
       " \t4         1.688924            x0 * (inv(x0 + -168.86307) + 1.8887489)   \n",
       " \t5         0.042564  ((inv(x0 + -168.86176) + 1.6329502) * x0) + 37...   \n",
       " \t6         0.000728  x0 + ((x0 * (inv(x0 + -168.86124) + 0.5733208)...   \n",
       " \t7         0.596715  x0 * (inv(x0 + -168.86176) + (inv(x0 + -51.868...   \n",
       " \t8         0.037221  (((inv(x0 + -168.86482) + inv(x0 + -51.86347))...   \n",
       " \t9         0.000192  (((inv(x0 + -51.864178) + 1.6464747) + inv(x0 ...   \n",
       " \t10        0.486669  x0 * (((inv(-129.83676 + x0) + inv(x0 + -51.86...   \n",
       " \t11  >>>>  0.138889  ((((inv(x0 + -168.86176) + 1.6776766) + inv(x0...   \n",
       " \t12        0.000006  (x0 * ((((inv(x0 + -129.65385) + 1.6465319) + ...   \n",
       " \t13        0.032532  (x0 * (inv(x0 + -129.65385) + ((inv(x0 + -165....   \n",
       " \t14        0.053434  (x0 * (inv((x0 + -167.16382) + x0) + ((inv(x0 ...   \n",
       " \t15        0.002362  (((((inv(x0 + -168.86176) + inv(x0 + -51.86743...   \n",
       " \t\n",
       " \t          loss  complexity  \n",
       " \t0   51499.5800           1  \n",
       " \t1   35328.5040           3  \n",
       " \t2   35219.4500           5  \n",
       " \t3   35147.1200           7  \n",
       " \t4    6492.3110           8  \n",
       " \t5    5962.5000          10  \n",
       " \t6    5953.8223          12  \n",
       " \t7    3278.2800          13  \n",
       " \t8    3043.1028          15  \n",
       " \t9    3041.9333          17  \n",
       " \t10   1869.7866          18  \n",
       " \t11   1416.2981          20  \n",
       " \t12   1416.2822          22  \n",
       " \t13   1284.5906          25  \n",
       " \t14   1154.3903          27  \n",
       " \t15   1148.9487          29  \n",
       " ]}"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  - SR_output/block_17429614352/dim0_1764340979/hall_of_fame.csv\n"
     ]
    }
   ],
   "source": [
    "symbolic_model_C_to_F.distill(X_temps, sr_params=sr_params)"
   ]
  },
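  {
   "cell_type": "markdown",
   "id": "c2f0a91e",
   "metadata": {},
   "source": [
    "As a rough sanity check on the fit, the complexity-5 candidate in the hall of fame above, `(x0 * 2.0842118) + 19.069176`, can be compared with the exact conversion $y = \\frac{9}{5}x + 32$. The sketch below is not part of the original run: the helper names `exact_c_to_f` and `fitted_linear` are made up for illustration, the coefficients are copied by hand from the printed table, and the grid simply covers the integer range that `X_temps` is sampled from.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def exact_c_to_f(temp_c):\n",
    "    # Exact Celsius-to-Fahrenheit conversion: F = 9/5 * C + 32\n",
    "    return 9.0 / 5.0 * np.asarray(temp_c, dtype=float) + 32.0\n",
    "\n",
    "\n",
    "def fitted_linear(temp_c):\n",
    "    # Complexity-5 candidate, coefficients copied from the hall of fame (assumed unchanged)\n",
    "    return 2.0842118 * np.asarray(temp_c, dtype=float) + 19.069176\n",
    "\n",
    "\n",
    "# Integer grid over [-20, 200), the same range random_numbers samples X_temps from\n",
    "temps = np.arange(-20, 200)\n",
    "\n",
    "# Signed difference between the fitted line and the textbook formula, in Fahrenheit\n",
    "gap = fitted_linear(temps) - exact_c_to_f(temps)\n",
    "\n",
    "print(f'slope: fitted 2.0842 vs exact {9 / 5:.4f}')\n",
    "print(f'intercept: fitted 19.069 vs exact {32:.3f}')\n",
    "print(f'largest absolute gap on the grid: {np.abs(gap).max():.2f} F')\n",
    "```\n",
    "\n",
    "If the LLM applied the textbook formula exactly, the slope would be 1.8 and the intercept 32, so the gap would be zero everywhere. Keep in mind that this only describes one linear candidate: the large loss reported for it (about 3.5e4) suggests that a straight line describes the LLM's answers poorly in the first place."
   ]
  },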
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "7589d521",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "➡️ Dimension 0 - Complexity 5:\n",
      "   (x0 * 2.0842118) + 19.069176 (loss: 3.521945e+04)\n"
     ]
    }
   ],
   "source": [
    "symbolic_model_C_to_F.show_symbolic_expression(complexity=5)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base_venv",
   "language": "python",
   "name": "venv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
