# Stats
- 21 files
- 5032 (5.0K) lines
- 167243 (167K) chars
- 66962 (67K) `gpt2` tokens

# File Tree

```
pattern-lens                     
├── .github                      
│   └── workflows                
│       └── checks.yml           [   85L  1,688C    844T]
├── data                         
│   └── pile_5.jsonl             [    5L 22,124C  6,858T]
├── pattern_lens                 
│   ├── frontend                 
│   │   ├── app.js               [  950L 29,307C 13,185T]
│   │   ├── index.template.html  [  204L 10,838C  5,785T]
│   │   ├── style.css            [  645L 10,988C  4,536T]
│   │   └── util.js              [   58L  1,521C    573T]
│   ├── __init__.py              [   13L    180C     83T]
│   ├── activations.py           [  657L 18,397C  7,077T]
│   ├── attn_figure_funcs.py     [  118L  3,127C  1,159T]
│   ├── consts.py                [   37L  1,093C    354T]
│   ├── figure_util.py           [  515L 16,788C  6,213T]
│   ├── figures.py               [  472L 14,402C  5,289T]
│   ├── indexes.py               [  201L  6,103C  2,281T]
│   ├── load_activations.py      [  167L  4,270C  1,579T]
│   ├── prompts.py               [   82L  1,721C    743T]
│   ├── py.typed                 [    0L      0C      0T]
│   └── server.py                [   59L  1,372C    493T]
├── README.md                    [  106L  3,862C  1,295T]
├── demo.ipynb                   [  165L  4,787C  2,343T]
├── makefile                     [1,681L 51,167C 19,727T]
└── pyproject.toml               [  225L  7,181C  3,758T]
```

# File Contents

``````{ path=".github/workflows/checks.yml"  }
name: Checks

on:
  workflow_dispatch:
  pull_request:
    branches:
      - '*'
  push:
    branches:
      - main

jobs:
  dep-check:
    name: Check dependencies
    runs-on: ubuntu-latest
    strategy:
      matrix:
        versions:
          - python: "3.11"
    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0 # fetch the whole history, needed for making the version
      
      - uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.versions.python }}

      - name: set up uv
        run: curl -LsSf https://astral.sh/uv/install.sh | sh

      - name: print python version
        run: python --version

      - name: check deps
        run: make dep-check
      

  lint:
    name: Lint
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      
      - uses: actions/setup-python@v5
        with:
          python-version: '3.12'

      - name: install
        run: |
          curl -LsSf https://astral.sh/uv/install.sh | sh
          make setup

      - name: format-check
        run: make format-check

  test:
    name: Test
    runs-on: ubuntu-latest
    strategy:
      matrix:
        versions:
          - python: "3.11"
          - python: "3.12"
          - python: "3.13"
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.versions.python }}

      - name: install
        run: |
          curl -LsSf https://astral.sh/uv/install.sh | sh
          make setup

      - name: Tests
        run: make test

      - name: typing
        run: make typing
    
      - name: run demo (no server)
        run: make demo-docs
     
``````{ end_of_file=".github/workflows/checks.yml" }

``````{ path="data/pile_5.jsonl"  }
{"text": "It is done, and submitted. You can play \u201cSurvival of the Tastiest\u201d on Android, and on the web. Playing on the web works, but you have to simulate multi-touch for table moving and that can be a bit confusing.\n\nThere\u2019s a lot I\u2019d like to talk about. I\u2019ll go through every topic, insted of making the typical what went right/wrong list.\n\nConcept\n\nWorking over the theme was probably one of the hardest tasks I had to face.\n\nOriginally, I had an idea of what kind of game I wanted to develop, gameplay wise \u2013 something with lots of enemies/actors, simple graphics, maybe set in space, controlled from a top-down view. I was confident I could fit any theme around it.\n\nIn the end, the problem with a theme like \u201cEvolution\u201d in a game is that evolution is unassisted. It happens through several seemingly random mutations over time, with the most apt permutation surviving. This genetic car simulator is, in my opinion, a great example of actual evolution of a species facing a challenge. But is it a game?\n\nIn a game, you need to control something to reach an objective. That control goes against what evolution is supposed to be like. If you allow the user to pick how to evolve something, it\u2019s not evolution anymore \u2013 it\u2019s the equivalent of intelligent design, the fable invented by creationists to combat the very idea of evolution. Being agnostic and a Pastafarian, that\u2019s not something that rubbed me the right way.\n\nHence, my biggest dillema when deciding what to create was not with what I wanted to create, but with what I did not. I didn\u2019t want to create an \u201cintelligent design\u201d simulator and wrongly call it evolution.\n\nThis is a problem, of course, every other contestant also had to face. And judging by the entries submitted, not many managed to work around it. I\u2019d say the only real solution was through the use of artificial selection, somehow. 
So far, I haven\u2019t seen any entry using this at its core gameplay.\n\nAlas, this is just a fun competition and after a while I decided not to be as strict with the game idea, and allowed myself to pick whatever I thought would work out.\n\nMy initial idea was to create something where humanity tried to evolve to a next level but had some kind of foe trying to stop them from doing so. I kind of had this image of human souls flying in space towards a monolith or a space baby (all based in 2001: A Space Odyssey of course) but I couldn\u2019t think of compelling (read: serious) mechanics for that.\n\nBorgs were my next inspiration, as their whole hypothesis fit pretty well into the evolution theme. But how to make it work? Are you the borg, or fighting the Borg?\n\nThe third and final idea came to me through my girlfriend, who somehow gave me the idea of making something about the evolution of Pasta. The more I thought about it the more it sounded like it would work, so I decided to go with it.\n\nConversations with my inspiring co-worker Roushey (who also created the \u201cMechanical Underdogs\u201d signature logo for my intros) further matured the concept, as it involved into the idea of having individual pieces of pasta flying around and trying to evolve until they became all-powerful. A secondary idea here was that the game would work to explain how the Flying Spaghetti Monster came to exist \u2013 by evolving from a normal dinner table.\n\nSo the idea evolved more or less into this: you are sitting a table. You have your own plate, with is your \u201cbase\u201d. There are 5 other guests at the table, each with their own plate.\n\nYour plate can spawn little pieces of pasta. You do so by \u201cordering\u201d them through a menu. Some pastas are better than others; some are faster, some are stronger. They have varying costs, which are debited from your credits (you start with a number of credits).\n\nOnce spawned, your pastas start flying around. 
Their instinct is to fly to other plates, in order to conquer them (the objective of the game is having your pasta conquer all the plates on the table). But they are really autonomous, so after being spawned, you have no control over your pasta (think DotA or LoL creeps).\n\nYour pasta doesn\u2019t like other people\u2019s pasta, so if they meet, they shoot sauce at each other until one dies. You get credits for other pastas your own pasta kill.\n\nOnce a pasta is in the vicinity of a plate, it starts conquering it for its team. It takes around 10 seconds for a plate to be conquered; less if more pasta from the same team are around. If pasta from other team are around, though, they get locked down in their attempt, unable to conquer the plate, until one of them die (think Battlefield\u2019s standard \u201cConquest\u201d mode).\n\nYou get points every second for every plate you own.\n\nOver time, the concept also evolved to use an Italian bistro as its main scenario.\n\nCarlos, Carlos\u2019 Bistro\u2019s founder and owner\n\nSetup\n\nNo major changes were made from my work setup. I used FDT and Starling creating an Adobe AIR (ActionScript) project, all tools or frameworks I already had some knowledge with.\n\nOne big change for me was that I livestreamed my work through a twitch.tv account. This was a new thing for me. As recommended by Roushey, I used a program called XSplit and I got to say, it is pretty amazing. It made the livestream pretty effortless and the features are awesome, even for the free version. It was great to have some of my friends watch me, and then interact with them and random people through chat. It was also good knowing that I was also recording a local version of the files, so I could make a timelapse video later.\n\nKnowing the video was being recorded also made me a lot more self-conscious about my computer use, as if someone was watching over my shoulder. 
It made me realize that sometimes I spend too much time in seemingly inane tasks (I ended up wasting the longest time just to get some text alignment the way I wanted \u2013 it\u2019ll probably drive someone crazy if they watch it) and that I do way too many typos where writing code. I pretty much spend half of the time writing a line and the other half fixing the crazy characters in it.\n\nMy own stream was probably boring to watch since I was coding for the most time. But livestreaming is one of the cool things to do as a spectator too. It was great seeing other people working \u2013 I had a few tabs opened on my second monitor all the time. It\u2019s actually a bit sad, because if I could, I could have spent the whole weekend just watching other people working! But I had to do my own work, so I\u2019d only do it once in a while, when resting for a bit.\n\nDesign\n\nAlthough I wanted some simple, low-fi, high-contrast kind of design, I ended up going with somewhat realistic (vector) art. I think it worked very well, fitting the mood of the game, but I also went overboard.\n\nFor example: to know the state of a plate (who owns it, who\u2019s conquering it and how much time they have left before conquering it, which pasta units are in the queue, etc), you have to look at the plate\u2019s bill.\n\nThe problem I realized when doing some tests is that people never look at the bill! They think it\u2019s some kind of prop, so they never actually read its details.\n\nPlus, if you\u2019re zoomed out too much, you can\u2019t actually read it, so it\u2019s hard to know what\u2019s going on with the game until you zoom in to the area of a specific plate.\n\nOne other solution that didn\u2019t turn out to be as perfect as I thought was how to indicate who a plate base belongs to. In the game, that\u2019s indicated by the plate\u2019s decoration \u2013 its color denotes the team owner. 
But it\u2019s something that fits so well into the design that people never realized it, until they were told about it.\n\nIn the end, the idea of going with a full physical metaphor is one that should be done with care. Things that are very important risk becoming background noise, unless the player knows its importance.\n\nOriginally, I wanted to avoid any kind of heads-up display in my game. In the end, I ended up adding it at the bottom to indicate your credits and bases owned, as well as the hideous out-of-place-and-still-not-obvious \u201cCall Waiter\u201d button. But in hindsight, I should have gone with a simple HUD from the start, especially one that indicated each team\u2019s colors and general state of the game without the need for zooming in and out.\n\nDevelopment\n\nDevelopment went fast. But not fast enough.\n\nEven though I worked around 32+ hours for this Ludum Dare, the biggest problem I had to face in the end was overscoping. I had too much planned, and couldn\u2019t get it all done.\n\nContent-wise, I had several kinds of pasta planned (Wikipedia is just amazing in that regard), split into several different groups, from small Pastina to huge Pasta al forno. But because of time constraints, I ended up scratching most of them, and ended up with 5 different types of very small pasta \u2013 barely something to start when talking about the evolution of Pasta.\n\nPastas used in the game. Unfortunately, the macs where never used\n\nWhich is one of the saddest things about the project, really. It had the framework and the features to allow an endless number of elements in there, but I just didn\u2019t have time to draw the rest of the assets needed (something I loved to do, by the way).\n\nOther non-obvious features had to be dropped, too. For example, when ordering some pasta, you were supposed to select what kind of sauce you\u2019d like with your pasta, each with different attributes. 
Bolognese, for example, is very strong, but inaccurate; Pesto is very accurate and has great range, but it\u2019s weaker; and my favorite, Vodka, would triggers 10% loss of speed on the pasta hit by it.\n\nThe code for that is mostly in there. But in the end, I didn\u2019t have time to implement the sauce selection interface; all pasta ended up using bolognese sauce.\n\nTo-do list: lots of things were not done\n\nActual programming also took a toll in the development time. Having been programming for a while, I like to believe I got to a point where I know how to make things right, but at the expense of forgetting how to do things wrong in a seemingly good way. What I mean is that I had to take a lot of shortcuts in my code to save time (e.g. a lot of singletons references for cross-communication rather than events or observers, all-encompassing check loops, not fast enough) that left a very sour taste in my mouth. While I know I used to do those a few years ago and survive, I almost cannot accept the state my code is in right now.\n\nAt the same time, I do know it was the right thing to do given the timeframe.\n\nOne small thing that had some impact was using a somewhat new platform for me. That\u2019s Starling, the accelerated graphics framework I used in Flash. I had tested it before and I knew how to use it well \u2013 the API is very similar to Flash itself. However, there were some small details that had some impact during development, making me feel somewhat uneasy the whole time I was writing the game. It was, again, the right thing to do, but I should have used Starling more deeply before (which is the conundrum: I used it for Ludum Dare just so I could learn more about it).\n\nArgument and user experience\n\nOne final aspect of the game that I learned is that making the game obvious for your players goes a long way into making it fun. 
If you have to spend the longest time explaining things, your game is doing something wrong.\n\nAnd that\u2019s exactly the problem Survival of the Tastiest ultimately faced. It\u2019s very hard for people to understand what\u2019s going on with the game, why, and how. I did have some introductory text at the beginning, but that was a last-minute thing. More importantly, I should have had a better interface or simplified the whole concept so it would be easier for people to understand.\n\nThat doesn\u2019t mean the game itself should be simple. It just means that the experience and interface should be approachable and understandable.\n\nConclusion\n\nI\u2019m extremely happy with what I\u2019ve done and, especially given that this was my first Ludum Dare. However, I feel like I\u2019ve learned a lot of what not to do.\n\nThe biggest problem is overscoping. Like Eric Decker said, the biggest lesson we can learn with this is probably with scoping \u2013 deciding what to do beforehand in a way you can complete it without having to rush and do something half-assed.\n\nI\u2019m sure I will do more Ludum Dares in the future. But if there are any lessons I can take of it, they are to make it simple, to use frameworks and platforms you already have some absolute experience with (otherwise you\u2019ll spend too much time trying to solve easy questions), and to scope for a game that you can complete in one day only (that way, you can actually take two days and make it cool).\n\nThis entry was posted\non Monday, August 27th, 2012 at 10:54 am and is filed under LD #24.\nYou can follow any responses to this entry through the RSS 2.0 feed.\nYou can skip to the end and leave a response. 
Pinging is currently not allowed.\n\n3 Responses to \u201c\u201cSurvival of the Tastiest\u201d Post-mortem\u201d\n\ndarn it , knowing that I missed your livestream makes me a sad panda ;( but more to the point, the game is \u2026 well for a startup its original to say the least ;D it has some really neat ideas and more importantly its designed arround touch screens whitch by the looks of the submission is something rare ;o or that could be just me and my short memory -_-! awesum game, love et <3", "meta": {"pile_set_name": "Pile-CC"}}
{"text": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\r\n<segment>\r\n    <name>PD1</name>\r\n    <description>Patient Additional Demographic</description>\r\n    <elements>\r\n        <field minOccurs=\"0\" maxOccurs=\"0\">\r\n            <name>PD1.1</name>\r\n            <description>Living Dependency</description>\r\n            <datatype>IS</datatype>\r\n        </field>\r\n        <field minOccurs=\"0\" maxOccurs=\"0\">\r\n            <name>PD1.2</name>\r\n            <description>Living Arrangement</description>\r\n            <datatype>IS</datatype>\r\n        </field>\r\n        <field minOccurs=\"0\" maxOccurs=\"0\">\r\n            <name>PD1.3</name>\r\n            <description>Patient Primary Facility</description>\r\n            <datatype>XON</datatype>\r\n        </field>\r\n        <field minOccurs=\"0\" maxOccurs=\"0\">\r\n            <name>PD1.4</name>\r\n            <description>Patient Primary Care Provider Name &amp; ID No.</description>\r\n            <datatype>XCN</datatype>\r\n        </field>\r\n        <field minOccurs=\"0\" maxOccurs=\"0\">\r\n            <name>PD1.5</name>\r\n            <description>Student Indicator</description>\r\n            <datatype>IS</datatype>\r\n        </field>\r\n        <field minOccurs=\"0\" maxOccurs=\"0\">\r\n            <name>PD1.6</name>\r\n            <description>Handicap</description>\r\n            <datatype>IS</datatype>\r\n        </field>\r\n        <field minOccurs=\"0\" maxOccurs=\"0\">\r\n            <name>PD1.7</name>\r\n            <description>Living Will Code</description>\r\n            <datatype>IS</datatype>\r\n        </field>\r\n        <field minOccurs=\"0\" maxOccurs=\"0\">\r\n            <name>PD1.8</name>\r\n            <description>Organ Donor Code</description>\r\n            <datatype>IS</datatype>\r\n        </field>\r\n        <field minOccurs=\"0\" maxOccurs=\"0\">\r\n            <name>PD1.9</name>\r\n            <description>Separate Bill</description>\r\n            
<datatype>ID</datatype>\r\n        </field>\r\n        <field minOccurs=\"0\" maxOccurs=\"0\">\r\n            <name>PD1.10</name>\r\n            <description>Duplicate Patient</description>\r\n            <datatype>CX</datatype>\r\n        </field>\r\n        <field minOccurs=\"0\" maxOccurs=\"0\">\r\n            <name>PD1.11</name>\r\n            <description>Publicity Code</description>\r\n            <datatype>CE</datatype>\r\n        </field>\r\n        <field minOccurs=\"0\" maxOccurs=\"0\">\r\n            <name>PD1.12</name>\r\n            <description>Protection Indicator</description>\r\n            <datatype>ID</datatype>\r\n        </field>\r\n        <field minOccurs=\"0\" maxOccurs=\"0\">\r\n            <name>PD1.13</name>\r\n            <description>Protection Indicator Effective Date</description>\r\n            <datatype>DT</datatype>\r\n        </field>\r\n        <field minOccurs=\"0\" maxOccurs=\"0\">\r\n            <name>PD1.14</name>\r\n            <description>Place of Worship</description>\r\n            <datatype>XON</datatype>\r\n        </field>\r\n        <field minOccurs=\"0\" maxOccurs=\"0\">\r\n            <name>PD1.15</name>\r\n            <description>Advance Directive Code</description>\r\n            <datatype>CE</datatype>\r\n        </field>\r\n        <field minOccurs=\"0\" maxOccurs=\"0\">\r\n            <name>PD1.16</name>\r\n            <description>Immunization Registry Status</description>\r\n            <datatype>IS</datatype>\r\n        </field>\r\n        <field minOccurs=\"0\" maxOccurs=\"0\">\r\n            <name>PD1.17</name>\r\n            <description>Immunization Registry Status Effective Date</description>\r\n            <datatype>DT</datatype>\r\n        </field>\r\n        <field minOccurs=\"0\" maxOccurs=\"0\">\r\n            <name>PD1.18</name>\r\n            <description>Publicity Code Effective Date</description>\r\n            <datatype>DT</datatype>\r\n        </field>\r\n        <field minOccurs=\"0\" 
maxOccurs=\"0\">\r\n            <name>PD1.19</name>\r\n            <description>Military Branch</description>\r\n            <datatype>IS</datatype>\r\n        </field>\r\n        <field minOccurs=\"0\" maxOccurs=\"0\">\r\n            <name>PD1.20</name>\r\n            <description>Military Rank/Grade</description>\r\n            <datatype>IS</datatype>\r\n        </field>\r\n        <field minOccurs=\"0\" maxOccurs=\"0\">\r\n            <name>PD1.21</name>\r\n            <description>Military Status</description>\r\n            <datatype>IS</datatype>\r\n        </field>\r\n    </elements>\r\n</segment>\r\n", "meta": {"pile_set_name": "Github"}}
{"text": "Article content\n\nHuman behavior has a tremendous impact on investing \u2014 more so than most realize \u2014 and one of our biggest weaknesses is the tendency to constantly compare and contrast ourselves to others.\n\n[np_storybar title=\u201dFollow Financial Post\u201d link=\u201d\u201d]\n\nWe apologize, but this video has failed to load.\n\ntap here to see other videos from our team. Try refreshing your browser, or Three signs bubbles are brewing again in the market \u2014 and one of them has wheels Back to video\n\n\u2022 Twitter\n\n\u2022 Facebook\n\n[/np_storybar]\n\nFor example, a 1995 study by the Harvard School of Public Health indicated that people will forgo a stronger income scenario in favour of a weaker one as long as it meant earning more than their neighbours.\n\nUnfortunately, many in the investment world are keenly aware of this and will structure their marketing efforts accordingly. As a result, you have a compounding of momentum or trends in the market as investors buy at or near market tops for fear of not doing as well as or better than others.\n\nFor the same reason, investors piled into technology stocks in 2000 with only the promise of earnings in some distant future, and into housing-related investments in 2007 that were backstopped by very low incomes.", "meta": {"pile_set_name": "OpenWebText2"}}
{"text": "Topic: reinvent midnight madness\n\nAmazon announced a new service at the AWS re:Invent Midnight Madness event. Amazon Sumerian is a solution that aims to make it easier for developers to build virtual reality, augmented reality, and 3D applications. It features a user friendly editor, which can be used to drag and drop 3D objects and characters into scenes. Amazon \u2026 continue reading", "meta": {"pile_set_name": "Pile-CC"}}
{"text": "About Grand Slam Fishing Charters\n\nAs a family owned business we know how important it is that your trip becomes the best memory of your vacation, we are proud of our islands, our waters and our crew and we are desperate show you the best possible time during your stay. We can not guarantee fish every time but we can guarantee you a great time! The biggest perk of our job is seeing so many of our customers become close friends\u201d\n\nA Great Way To Make New Friends!\n\nOur dockside parties are a great way to make new friends! Everyone is welcome!\n\nAndrea runs the whole operation, from discussing your initial needs by phone or email through to ensuring you have sufficient potato chips. Andrea has worked as concierge for many International resorts and fully understands the high expectations of international visitors.\n\n\u201cLife\u2019s A Game But Fishing Is Serious!\u201d\n\nUnlike many tour operators, our crew are highly valued and have been with us since day 1. Each have their own personalities and sense of humour and understand the importance of making your day perfect, for us the saying is true, \u201cLifes a game but fishing is serious!\u201d\n\nTRIP ADVISOR\n\nPlan Your Trip!\n\nAJ and Earl were excellent. My son and I did a half day deep sea trip and though the fish weren\u2019t too cooperative, they did everything to try to get something to bite. Very knowledgeable about the waters and my son was able to land a nice barracuda. The next day my wife, daughter, son [\u2026]\n\nWhen we arrived the crew made us feel right at home. They made us feel comfortable and answered all questions. The crew worked hard all day to put us on fish. We were successful in landing a nice size Wahoo even though the weather did not cooperate the entire day was enjoyable. I highly recommend [\u2026]", "meta": {"pile_set_name": "Pile-CC"}}
``````{ end_of_file="data/pile_5.jsonl" }
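
For orientation, here is a minimal sketch of how records shaped like the ones in `data/pile_5.jsonl` could be read: one JSON object per line, each with a `text` string and a `meta.pile_set_name` field. The helper names `iter_jsonl_records` and `count_by_source` are hypothetical and are not taken from `pattern_lens`. The sample lines are short placeholders standing in for the real prompts.

```python
import json
from collections import Counter
from typing import Iterable, Iterator


def iter_jsonl_records(lines: Iterable[str]) -> Iterator[dict]:
    """Yield one parsed JSON object per non-empty line (hypothetical helper)."""
    for line in lines:
        line = line.strip()
        if line:  # skip blank lines instead of failing on them
            yield json.loads(line)


def count_by_source(records: Iterable[dict]) -> Counter:
    """Count records by their `meta.pile_set_name` value (hypothetical helper)."""
    return Counter(rec["meta"]["pile_set_name"] for rec in records)


# Placeholder sample with the same shape as the data file above;
# the text values are stand-ins, not the real prompts.
sample_lines = [
    '{"text": "first prompt", "meta": {"pile_set_name": "Pile-CC"}}',
    '{"text": "second prompt", "meta": {"pile_set_name": "Github"}}',
    "",
    '{"text": "third prompt", "meta": {"pile_set_name": "Pile-CC"}}',
]

records = list(iter_jsonl_records(sample_lines))
source_counts = count_by_source(records)
```

An open file handle is also an iterable of lines, so under the same assumptions the real file could be read with `with open("data/pile_5.jsonl", encoding="utf-8") as f: records = list(iter_jsonl_records(f))`.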

``````{ path="pattern_lens/frontend/app.js"  }
const app = Vue.createApp({

	// ########     ###    ########    ###
	// ##     ##   ## ##      ##      ## ##
	// ##     ##  ##   ##     ##     ##   ##
	// ##     ## ##     ##    ##    ##     ##
	// ##     ## #########    ##    #########
	// ##     ## ##     ##    ##    ##     ##
	// ########  ##     ##    ##    ##     ##

	data() {
		return {
			isDarkMode: false,
			prompts: {
				all: {},        // hash -> prompt mapping
				selected: [],   // selected from table
				grid: {
					api: null,
					isReady: false
				},
			},
			loading: false,
			images: {
				visible: [],
				expected: 0,
				requested: false,
				upToDate: false,
				perRow: 4,
			},
			models: {
				configs: {},    // model -> config mapping
				grid: {
					api: null,
				},
			},
			filters: {
				available: {    // all available options
					models: [],
					functions: [],
					layers: [],
					heads: [],
				},
				selected: {     // currently selected options
					models: [],
					functions: [],
					layers: [],
					heads: [],
				},
			},
			head_selections_str: {}, // model -> selection string mapping
			visualization: {
				colorBy: '',
				sortBy: '',
				sortOrder: 'asc',
				colorMap: {},
			},
		};
	},

	methods: {

		// ##     ## ########    ###    ########   ######
		// ##     ## ##         ## ##   ##     ## ##    ##
		// ##     ## ##        ##   ##  ##     ## ##
		// ######### ######   ##     ## ##     ##  ######
		// ##     ## ##       ######### ##     ##       ##
		// ##     ## ##       ##     ## ##     ## ##    ##
		// ##     ## ######## ##     ## ########   ######

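		// Parse a head selection string into a maxLayer x maxHead 2D array of booleans.
		// The string is a comma-separated list of tokens; "x" is accepted as an alias for "*":
		//   "L3H5"   -> head 5 of layer 3
		//   "L2-4H*" -> every head of layers 2 through 4 (inclusive)
		//   "L8"     -> every head of layer 8 (same as "L8H*")
		//   "L*H0"   -> head 0 of every layer
		// Head ranges such as "H1-3" are not accepted.
		// Returns an all-false array for an empty string, and null if any token is
		// malformed or refers to a layer/head index that is out of range.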
		parseHeadString(str, maxLayer, maxHead) {
			try {
				const result = Array(maxLayer).fill().map(() => Array(maxHead).fill(false));
				if (!str || str.trim() === '') return result;

				const selections = str.replaceAll("x", "*").split(',').map(s => s.trim());

				for (const selection of selections) {
					const match = selection.match(/^L(\d+|\d+-\d+|\*)(H\d+|H\*|Hx)?$/);
					if (!match) return null;

					const layerPart = match[1];
					let headPart = match[2];

					// If the user typed only "L8" (no head specification), default to H*
					if (!headPart) {
						headPart = 'H*';
					}

					let layers = [];
					if (layerPart === '*') {
						layers = Array.from({ length: maxLayer }, (_, i) => i);
					} else if (layerPart.includes('-')) {
						const [start, end] = layerPart.split('-').map(Number);
						if (start > end || end >= maxLayer) return null;
						layers = Array.from({ length: end - start + 1 }, (_, i) => start + i);
					} else {
						const layer = Number(layerPart);
						if (layer >= maxLayer) return null;
						layers = [layer];
					}

					const headStr = headPart.substring(1);
					if (headStr === '*' || headStr === 'x') {
						for (const layer of layers) {
							result[layer].fill(true);
						}
					} else {
						const head = Number(headStr);
						if (head >= maxHead) return null;
						for (const layer of layers) {
							result[layer][head] = true;
						}
					}
				}

				return result;
			} catch (e) {
				console.error('Error parsing head string:', e);
				return null;
			}
		},

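		// Whether (layer, head) is selected for `model`; returns false when the
		// parsed selection is missing or the indices are out of bounds.
		isHeadSelected(model, layer, head) {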
			// First check if we have parsed selections for this model
			if (!this.head_selections_arr[model]) {
				console.warn(`No parsed head selections found for model: ${model}`);
				return false;
			}

			try {
				// Verify layer and head are within bounds
				const parsedSelections = this.head_selections_arr[model];
				if (!Array.isArray(parsedSelections) ||
					!Array.isArray(parsedSelections[layer]) ||
					typeof parsedSelections[layer][head] === 'undefined') {
					console.warn(
						`Invalid layer/head combination for ${model}: L${layer}H${head}`,
						`Max bounds: L${parsedSelections.length - 1}H${parsedSelections[0]?.length - 1}`
					);
					return false;
				}

				return parsedSelections[layer][head];
			} catch (e) {
				console.error('Error checking head selection:', e);
				console.log('Model:', model, 'Layer:', layer, 'Head:', head);
				return false;
			}
		},

		// True unless the parsed selection stored for `model` is null
		isValidHeadSelection(model) {
			return this.head_selections_arr[model] !== null;
		},

		// ##     ## ########  ##
		// ##     ## ##     ## ##
		// ##     ## ##     ## ##
		// ##     ## ########  ##
		// ##     ## ##   ##   ##
		// ##     ## ##    ##  ##
		//  #######  ##     ## ########

		// Write the current selections (functions, prompts, models, heads) into the URL query string
		updateURL() {
			const params = new URLSearchParams();

			if (this.filters.selected.functions.length > 0) {
				params.set('functions', this.filters.selected.functions.join('~'));
			}

			if (this.prompts.selected.length > 0) {
				params.set('prompts', this.prompts.selected.join('~'));
			}

			if (this.filters.selected.models.length > 0) {
				params.set('models', this.filters.selected.models.join('~'));

				// One parameter per model; "*" is written as "x" in the URL
				// (parseHeadString maps "x" back to "*")
				for (const model of Object.keys(this.head_selections_str)) {
					params.set(
						`${URL_HEAD_PREFIX}${model}`,
						this.head_selections_str[model].replaceAll("*", "x").replaceAll(" ", "").split(',').join('~')
					);
				}
			}

			const newURL = `${window.location.pathname}?${params.toString()}`;
			history.replaceState(null, '', newURL);
		},

		// Restore selections from the URL query string (the inverse of updateURL)
		readURL() {
			const params = new URLSearchParams(window.location.search);

			this.filters.selected.functions = params.get('functions')?.split('~') || [];

			this.prompts.selected = params.get('prompts')?.split('~') || [];

			this.filters.selected.models = params.get('models')?.split('~') || [];

			try {
				this.head_selections_str = {};
				for (const [key, value] of params) {
					if (key.startsWith(URL_HEAD_PREFIX)) {
						const model = key.substring(URL_HEAD_PREFIX.length);
						this.head_selections_str[model] = value.split('~').join(', ');
					}
				}
			} catch (e) {
				console.error('Error parsing head selections from URL:', e);
			}
		},

		// Re-apply the prompt selection read from the URL to the grid rows
		selectPromptsFromURL() {
			if (!this.isGridReady || this.prompts.selected.length === 0) return;

			const promptSet = new Set(this.prompts.selected);
			this.prompts.grid.api.forEachNode((node) => {
				if (promptSet.has(node.data.hash)) {
					node.setSelected(true);
				}
			});
		},

		// Build a link that filters the view down to a single image
		getImageUrl(image) {
			return this.getFilterUrl('all', [image.model], [image.promptHash], [image.layer], [image.head], [image.function]);
		},

		getSinglePropertyFilterUrl(type, value) {
			const params = new URLSearchParams(window.location.search);
			params.set(type, value); // This preserves other params while updating just this one
			return `${window.location.pathname}?${params.toString()}`;
		},

		getFilterUrl(type, ...values) {
			const params = new URLSearchParams(window.location.search);

			if (type === 'all') {
				params.set('models', values[0].join('~'));
				params.set('prompts', values[1].join('~'));
				params.set('layers', values[2].join('~'));
				params.set('heads', values[3].join('~'));
				params.set('functions', values[4].join('~'));
			} else {
				params.set(type, values.flat().join('~'));
			}

			return `${window.location.pathname}?${params.toString()}`;
		},

		// ##     ## ######## ##       ########  ######## ########
		// ##     ## ##       ##       ##     ## ##       ##     ##
		// ##     ## ##       ##       ##     ## ##       ##     ##
		// ######### ######   ##       ########  ######   ########
		// ##     ## ##       ##       ##        ##       ##   ##
		// ##     ## ##       ##       ##        ##       ##    ##
		// ##     ## ######## ######## ##        ######## ##     ##

		toggleDarkMode() {
			console.log('Toggling dark mode');
			this.isDarkMode = !this.isDarkMode;
			localStorage.setItem('darkMode', this.isDarkMode);
			// Force a DOM update
			this.$nextTick(() => {
				document.documentElement.classList.toggle('dark-mode', this.isDarkMode);
			});
		},
		clearAllSelections() {
			// Clear prompts selection
			if (this.prompts.grid.api) {
				this.prompts.grid.api.deselectAll();
			}

			// Clear models selection
			if (this.models.grid.api) {
				this.models.grid.api.deselectAll();
			}

			// Clear function selections
			this.filters.selected.functions = [];

			// Reset head selections
			this.head_selections_str = {};

			// Update URL to reflect cleared state
			this.updateURL();
		},
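		isIndeterminate(category) {
			// `available.functions` is an object keyed by name; other categories are arrays
			const items = this.filters.available[category];
			const n_items = Array.isArray(items) ? items.length : Object.keys(items).length;
			const selectedItems = this.filters.selected[category];
			return selectedItems.length > 0 && selectedItems.length < n_items;
		},
		isChecked(category) {
			const items = this.filters.available[category];
			const n_items = Array.isArray(items) ? items.length : Object.keys(items).length;
			const selectedItems = this.filters.selected[category];
			return selectedItems.length === n_items && n_items > 0;
		},
		toggleSelectAll(category, event) {
			const checked = event.target.checked;
			const items = this.filters.available[category];
			// Select by name: array entries directly, or the keys of an object
			const allItems = Array.isArray(items) ? [...items] : Object.keys(items);
			this.filters.selected[category] = checked ? allItems : [];
		},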
		async loadData() {
			try {
				await this.loadModels();
				await Promise.all([
					this.loadAllPrompts(),
					this.loadFunctions()
				]);

				this.updateLayersAndHeads();
			} catch (error) {
				console.error('Error loading data:', error);
			}
		},
		async loadModels() {
			this.loading = true;
			console.log('Loading models...');
			try {
				const models = await fileOps.fetchJsonL(`${DATA_DIR}/models.jsonl`);
				this.models.configs = {};
				for (const model of models) {
					this.models.configs[model["model_name"]] = model;
				}
				this.filters.available.models = Object.keys(this.models.configs);
				console.log('Models:', this.filters.available.models);
			} finally {
				// Always clear the loading flag, even if the fetch fails
				this.loading = false;
			}

			// After loading models, initialize head selections
			this.filters.selected.models.forEach(model => {
				if (!this.head_selections_str[model]) {
					this.head_selections_str[model] = 'L*H*';
				}
			});
		},
		async loadFunctions() {
			const functions = await fileOps.fetchJsonL(`${DATA_DIR}/figures.jsonl`);
			console.log('Functions:', functions);
			this.filters.available.functions = functions.reduce(
				(acc, item) => {
					acc[item.name] = item;
					return acc;
				},
				{},
			);
			console.log('this.filters.available.functions:', this.filters.available.functions);
		},
		onFirstDataRendered(params) {
			this.selectPromptsFromURL();
		},
		// Handle selection change in ag-Grid
		onSelectionChanged() {
			// getSelectedRows() returns the row data objects, not grid nodes
			const selectedRows = this.prompts.grid.api.getSelectedRows();
			this.prompts.selected = selectedRows.map(row => row.hash);
			this.updateURL();
		},
		// Update layers and heads based on selected models
		updateLayersAndHeads() {
			// get all layer and head counts
			let mdl_n_layers = [];
			let mdl_n_heads = [];
			for (const model of this.filters.selected.models) {
				const config = this.models.configs[model];
				if (config) {
					mdl_n_layers.push(config.n_layers);
					mdl_n_heads.push(config.n_heads);
				}
			}
			// get the max layer and head counts, generate lists
			this.filters.available.layers = [];
			this.filters.available.heads = [];

			for (let i = 0; i < _.max(mdl_n_layers); i++) {
				this.filters.available.layers.push(i.toString());
			}
			for (let i = 0; i < _.max(mdl_n_heads); i++) {
				this.filters.available.heads.push(i.toString());
			}
		},

		// ##     ##  #######  ########  ######## ##        ######
		// ###   ### ##     ## ##     ## ##       ##       ##    ##
		// #### #### ##     ## ##     ## ##       ##       ##
		// ## ### ## ##     ## ##     ## ######   ##        ######
		// ##     ## ##     ## ##     ## ##       ##             ##
		// ##     ## ##     ## ##     ## ##       ##       ##    ##
		// ##     ##  #######  ########  ######## ########  ######
		getHeadSelectionCount(model) {
			const parsed = this.head_selections_arr[model];
			if (!parsed) return 0;
			return parsed.reduce((acc, layer) =>
				acc + layer.reduce((sum, isSelected) => sum + (isSelected ? 1 : 0), 0), 0);
		},
		getTotalHeads(model) {
			const config = this.models.configs[model];
			return config ? config.n_layers * config.n_heads : 0;
		},
		setupModelTable() {
			const columnDefs = [
				{
					headerName: 'Model',
					field: 'model_name',
					sort: 'asc',
					width: 150
				},
				{
					headerName: 'd_model',
					field: 'd_model',
					width: 90,
					filter: 'agNumberColumnFilter'
				},
				{
					headerName: 'n_layers',
					field: 'n_layers',
					width: 90,
					filter: 'agNumberColumnFilter'
				},
				{
					headerName: 'n_heads',
					field: 'n_heads',
					width: 90,
					filter: 'agNumberColumnFilter'
				},
				{
					headerName: 'Selected',
					valueGetter: (params) => {
						return `${this.getHeadSelectionCount(params.data.model_name)} / ${this.getTotalHeads(params.data.model_name)}`;
					},
					width: 100
				},
				{
					headerName: 'Head Grid',
					field: 'head_grid',
					width: 150,
					cellRenderer: (params) => {
						const model = params.data.model_name;
						const div = document.createElement('div');
						div.className = 'head-grid';
						div.setAttribute('data-model', model); // Add data attribute for updates

						const n_heads = params.data.n_heads;
						const n_layers = params.data.n_layers;

						// Each column is one head; the cells within it are that head's layers
						for (let h = 0; h < n_heads; h++) {
							const headCol = document.createElement('div');
							headCol.className = 'headsGrid-col';

							for (let l = 0; l < n_layers; l++) {
								const cell = document.createElement('div');
								cell.className = `headsGrid-cell ${this.isHeadSelected(model, l, h) ? 'headsGrid-cell-selected' : 'headsGrid-cell-empty'}`;
								cell.setAttribute('data-layer', l);
								cell.setAttribute('data-head', h);
								headCol.appendChild(cell);
							}

							div.appendChild(headCol);
						}

						return div;
					}
				},
				{
					headerName: 'Head Selection',
					field: 'head_selection',
					editable: true,
					width: 200,
					cellEditor: 'agTextCellEditor',
					cellEditorParams: {
						maxLength: 50
					},
					valueSetter: params => {
						const newValue = params.newValue;
						const model = params.data.model_name;

						// Update the head selection in Vue's data
						params.context.componentParent.head_selections_str[model] = newValue;

						// Update the cell class for validation styling
						const isValid = params.context.componentParent.isValidHeadSelection(model);
						const cell = params.api.getCellRendererInstances({
							rowNodes: [params.node],
							columns: [params.column]
						})[0];

						if (cell) {
							const element = cell.getGui();
							if (isValid) {
								element.classList.remove('invalid-selection');
							} else {
								element.classList.add('invalid-selection');
							}
						}

						// Force refresh of the head grid cell
						const gridCol = params.api.getColumnDef('head_grid');
						if (gridCol) {
							params.api.refreshCells({
								rowNodes: [params.node],
								columns: ['head_grid'],
								force: true
							});
						}

						return true;
					},
					valueGetter: params => {
						return params.context.componentParent.head_selections_str[params.data.model_name] || 'L*H*';
					},
					cellClass: params => {
						const isValid = params.context.componentParent.isValidHeadSelection(params.data.model_name);
						return isValid ? '' : 'invalid-selection';
					}
				},
			];

			const modelGrid_options = {
				columnDefs: columnDefs,
				rowData: Object.values(this.models.configs),
				selection: {
					headerCheckbox: true,
					selectAll: 'filtered',
					checkboxes: true,
					mode: 'multiRow',
					enableClickSelection: true,
				},
				defaultColDef: {
					sortable: true,
					filter: true,
					resizable: true,
					floatingFilter: true,
					suppressKeyboardEvent: params => {
						// Allow all keyboard events in edit mode
						if (params.editing) {
							return false;
						}
						// Not editing: let printable keys reach the grid so that typing starts an edit
						if (params.event.key.length === 1 && !params.event.ctrlKey && !params.event.metaKey) {
							return false;
						}
						// Suppress the grid's handling of all other keys
						return true;
					},
				},
				context: {
					componentParent: this
				},
				onSelectionChanged: (event) => {
					const selectedRows = event.api.getSelectedRows();
					this.filters.selected.models = selectedRows.map(row => row.model_name);
				},
				onGridReady: (params) => {
					this.models.grid.api = params.api;
					// Select models from URL
					if (this.filters.selected.models.length > 0) {
						params.api.forEachNode(node => {
							if (this.filters.selected.models.includes(node.data.model_name)) {
								node.setSelected(true);
							}
						});
					}
				},
			};

			const modelGrid_div = document.querySelector('#modelGrid');
			this.models.grid.api = agGrid.createGrid(modelGrid_div, modelGrid_options);
		},
		refreshHeadGrids() {
			if (this.models.grid.api) {
				this.models.grid.api.refreshCells({
					columns: ['head_grid'],
					force: true
				});
			}
		},
		// ########  ########   #######  ##     ## ########  ########
		// ##     ## ##     ## ##     ## ###   ### ##     ##    ##   
		// ##     ## ##     ## ##     ## #### #### ##     ##    ##   
		// ########  ########  ##     ## ## ### ## ########     ##   
		// ##        ##   ##   ##     ## ##     ## ##           ##    
		// ##        ##    ##  ##     ## ##     ## ##           ##    
		// ##        ##     ##  #######  ##     ## ##           ##    

		async loadAllPrompts() {
			this.loading = true;
			console.log('Loading prompts...');
			this.prompts.all = {};

			for (const model of this.filters.available.models) {
				try {
					const modelPrompts = await fileOps.fetchJsonL(`${DATA_DIR}/${model}/prompts.jsonl`);
					for (const prompt of modelPrompts) {
						if (prompt.hash in this.prompts.all) {
							this.prompts.all[prompt.hash].models.push(model);
						} else {
							this.prompts.all[prompt.hash] = { ...prompt, models: [model] };
						}
					}
				} catch (error) {
					console.error(`Error loading prompts for model ${model}:`, error);
				}
			}
			console.log('loaded number of prompts:', Object.keys(this.prompts.all).length);
			this.loading = false;
		},
		// Initialize the ag-Grid table
		setupPromptTable() {
			const columnDefs = [
				{
					headerName: 'Prompt Text',
					field: 'text',
					sortable: true,
					filter: true,
					flex: 2,
					cellRenderer: (params) => {
						const eGui = document.createElement('div');
						// Replace tabs and newlines with spaces for display
						eGui.innerText = params.value.replace(/\s+/g, ' ');
						eGui.classList.add('prompt-text-cell');
						eGui.addEventListener('click', () => {
							navigator.clipboard.writeText(params.value);
						});

						eGui.addEventListener('contextmenu', (event) => {
							event.preventDefault();
							const newWindow = window.open();
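							// Escape HTML so prompts containing markup are shown as literal text
							const escaped = params.value.replace(/&/g, '&amp;').replace(/</g, '&lt;').replace(/>/g, '&gt;');
							newWindow.document.write(`<pre>${escaped}</pre>`);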
							newWindow.document.close();
							newWindow.document.title = `Prompt '${params.data.hash}'`;
						});

						return eGui;
					},
				},
				{
					headerName: 'Models', field: 'models', sortable: true, filter: true, width: 150,
					valueFormatter: (params) => params.value.join(', '),
				},
				{ headerName: 'Hash', field: 'hash', sortable: true, filter: true, width: 100 },
				{ headerName: 'Tokens', field: 'n_tokens', sortable: true, filter: 'agNumberColumnFilter', width: 80 },
				{ headerName: 'Dataset', field: 'meta.pile_set_name', sortable: true, filter: true, width: 150 },
			];

			// Grid options
			const promptGrid_options = {
				columnDefs: columnDefs,
				rowData: Object.values(this.prompts.all),
				pagination: true,
				enableCellTextSelection: true,
				paginationPageSize: 20,
				paginationPageSizeSelector: [5, 10, 20, 50, 100, 500],
				selection: {
					headerCheckbox: true,
					selectAll: 'filtered',
					checkboxes: true,
					mode: 'multiRow',
					enableClickSelection: true,
				},

				defaultColDef: {
					sortable: true,
					filter: true,
					resizable: true,
					floatingFilter: true
				},
				onSelectionChanged: this.onSelectionChanged.bind(this),
				onFirstDataRendered: this.onFirstDataRendered.bind(this),
				onGridReady: (params) => {
					this.prompts.grid.api = params.api;
					this.isGridReady = true;
					this.selectPromptsFromURL();
				},
			};

			const promptGrid_div = document.querySelector('#promptGrid');
			this.prompts.grid.api = agGrid.createGrid(promptGrid_div, promptGrid_options);
		},

		// ########  ####  ######  ########  ##          ###    ##    ##
		// ##     ##  ##  ##    ## ##     ## ##         ## ##    ##  ##
		// ##     ##  ##  ##       ##     ## ##        ##   ##    ####
		// ##     ##  ##   ######  ########  ##       ##     ##    ##
		// ##     ##  ##        ## ##        ##       #########    ##
		// ##     ##  ##  ##    ## ##        ##       ##     ##    ##
		// ########  ####  ######  ##        ######## ##     ##    ##

		// Display images based on selected criteria
		async displayImages() {
			this.loading = true;
			this.images.requested = true;
			this.images.visible = [];

			// Calculate total images based on parsed head selections
			let totalImages = 0;
			for (const model of this.filters.selected.models) {
				totalImages += this.getHeadSelectionCount(model) * this.prompts.selected.length * this.filters.selected.functions.length;
			}
			this.images.expected = totalImages;

			// Load images based on parsed head selections
			for (const model of this.filters.selected.models) {
				const config = this.models.configs[model];
				const rawString = this.head_selections_str[model] || 'L*H*';
				const parsedHeads = this.parseHeadString(rawString, config.n_layers, config.n_heads);
				if (!parsedHeads) {
					console.warn(`Invalid head selection for ${model}: "${rawString}"`);
					continue;
				}

				// Iterate over all layers and heads
				for (let layer = 0; layer < config.n_layers; layer++) {
					for (let head = 0; head < config.n_heads; head++) {
						if (!parsedHeads[layer][head]) {
							continue;
						}
						// Now for each selected prompt and function:
						for (const promptHash of this.prompts.selected) {
							for (
								const func_name of
								this.filters.selected.functions
							) {
								let func = this.filters.available.functions[func_name];
								if (!func) {
									console.warn(`Function not found ${func_name}`, typeof func_name, JSON.stringify(func_name), func_name, this.filters.available.functions);
									// Skip unknown functions instead of failing on `func.figure_save_fmt` below
									continue;
								}
								const basePath = `${DATA_DIR}/${model}/prompts/${promptHash}/L${layer}/H${head}`;

								// get the figure format from metadata
								let figure_format = func.figure_save_fmt;
								if (!figure_format) {
									// as a fallback, look for all valid formats
									figure_format = await fileOps.figureExists(`${basePath}/${func_name}`);
									console.log('could not find figure format for func name', func_name, 'found', figure_format);
								}

								if (figure_format) {
									// Create figure entry
									const figure_meta = {
										name: `${model} - Prompt ${promptHash} - L${layer}H${head} - ${func_name}`,
										model: model,
										promptHash: promptHash,
										layer: layer,
										head: head,
										function: func_name,
										figure_format: figure_format,
									};

									if (figure_format === 'svgz') {
										const svgText = await fileOps.fetchAndDecompressSvgz(`${basePath}/${func_name}.svgz`);
										if (svgText) {
											this.images.visible.push({
												content: svgText,
												...figure_meta,
											});
										}
									} else {
										const imglink = `<img src="${basePath}/${func_name}.${figure_format}" alt="${figure_meta.name}">`;
										this.images.visible.push({
											content: imglink,
											...figure_meta,
										});
									}
								}
							}
						}
					}
				}
			}

			this.images.upToDate = true;
			this.loading = false;
		},
		openMetadata(func) {
			const newWindow = window.open('', '_blank');
			let content = `<div style="font-family: sans-serif; line-height:1.4;">`;
			if (func.doc) {
				content += `<p><strong>Description:</strong> ${func.doc}</p>`;
			}
			if (func.figure_save_fmt) {
				content += `<p><strong>Format:</strong> ${func.figure_save_fmt}</p>`;
			}
			if (func.source) {
				content += `<p><strong>Source:</strong> ${func.source}</p>`;
			}
			content += `</div>`;
			newWindow.document.write(content);
			newWindow.document.close();
			newWindow.document.title = `Metadata for ${func.name}`;
		},

		regenerateColors() {
			if (!this.visualization.colorBy) return;
			
			// Get unique values for the selected property
			const uniqueValues = [...new Set(this.images.visible.map(img => img[this.visualization.colorBy]))];
			
			// Generate new random colors
			this.visualization.colorMap = {};
			uniqueValues.forEach(value => {
				this.visualization.colorMap[value] = colorUtils.getRandomColor();
			});
		},
		
		
		getBorderColor(image) {
			if (!this.visualization.colorBy || !image) return 'transparent';
			const value = image[this.visualization.colorBy];
			return this.visualization.colorMap[value] || 'transparent';
		},
	},



	//  ######   #######  ##     ## ########  ##     ## ######## ######## ########
	// ##    ## ##     ## ###   ### ##     ## ##     ##    ##    ##       ##     ##
	// ##       ##     ## #### #### ##     ## ##     ##    ##    ##       ##     ##
	// ##       ##     ## ## ### ## ########  ##     ##    ##    ######   ##     ##
	// ##       ##     ## ##     ## ##        ##     ##    ##    ##       ##     ##
	// ##    ## ##     ## ##     ## ##        ##     ##    ##    ##       ##     ##
	//  ######   #######  ##     ## ##         #######     ##    ######## ########

	computed: {
		uniqueDatasets() {
			return [
				...new Set(
					Object.values(this.prompts.all).map(prompt => prompt.meta?.pile_set_name).filter(Boolean)
				)
			];
		},
		head_selections_arr() {
			// model -> boolean[][] mapping for efficient lookup
			let parsed = {};

			for (const model in this.head_selections_str) {
				const config = this.models.configs[model];
				if (!config) {
					console.warn(`No config found for model: ${model}`);
					parsed[model] = null;
					continue;
				}

				const parsedHeads = this.parseHeadString(
					this.head_selections_str[model] || 'L*H*',
					config.n_layers,
					config.n_heads
				);

				if (!parsedHeads) {
					console.warn(
						`Invalid head selection for ${model}: "${this.head_selections_str[model]}"`
					);
				}

				parsed[model] = parsedHeads;
			}

			return parsed;
		},
		sortedImages() {
			if (!this.visualization.sortBy) return this.images.visible;
			
			return [...this.images.visible].sort((a, b) => {
				const valueA = a[this.visualization.sortBy];
				const valueB = b[this.visualization.sortBy];
				
				// Handle numeric values for layer and head
				if (['layer', 'head'].includes(this.visualization.sortBy)) {
					const numA = Number(valueA);
					const numB = Number(valueB);
					return this.visualization.sortOrder === 'asc' 
						? numA - numB 
						: numB - numA;
				}
				
				// Handle string values
				const comparison = String(valueA).localeCompare(String(valueB));
				return this.visualization.sortOrder === 'asc' ? comparison : -comparison;
			});
		},
	},


	// ##      ##    ###    ########  ######  ##     ##
	// ##  ##  ##   ## ##      ##    ##    ## ##     ##
	// ##  ##  ##  ##   ##     ##    ##       ##     ##
	// ##  ##  ## ##     ##    ##    ##       #########
	// ##  ##  ## #########    ##    ##       ##     ##
	// ##  ##  ## ##     ##    ##    ##    ## ##     ##
	//  ###  ###  ##     ##    ##     ######  ##     ##

	// Watchers: keep the URL, the head grids, and the images' up-to-date flag in sync with the selections
	watch: {
		'filters.selected': {
			deep: true,
			handler() {
				this.images.upToDate = false;
				this.updateURL();
			}
		},
		'prompts.selected': {
			handler() {
				this.images.upToDate = false;
			}
		},
		'head_selections_str': {
			deep: true,
			handler(newValue) {
				Object.keys(newValue).forEach(model => {
					if (!this.models.configs[model]) {
						console.warn(`Attempting to update head selections for unknown model: ${model}`);
						return;
					}
				});
				this.images.upToDate = false;
				this.updateURL();
				this.refreshHeadGrids();
			}
		},
		'filters.selected.models': {
			deep: true,
			handler(newModels) {
				// Initialize head selections for new models
				newModels.forEach(model => {
					if (!this.head_selections_str[model]) {
						this.head_selections_str[model] = 'L*H*';
					}
				});
				this.updateURL();
			}
		},
		'visualization.colorBy': {
			handler(newValue) {
				if (newValue) {
					this.regenerateColors();
				}
			}
		},
	},

	// Lifecycle hook when component is mounted
	async mounted() {
		console.log('Mounting app:', this);
		const savedDarkMode = localStorage.getItem('darkMode');
		if (savedDarkMode !== null) {
			this.isDarkMode = savedDarkMode === 'true';
		}
		if (this.isDarkMode) {
			document.documentElement.classList.add('dark-mode');
		}
		this.readURL(); // Read filters from URL first
		await this.loadData(); // Load models, prompts, and functions
		this.setupModelTable(); // Initialize the model grid
		this.setupPromptTable(); // Initialize the prompts grid
		console.log('Mounted app:', this);
	}
});

``````{ end_of_file="pattern_lens/frontend/app.js" }
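
The `updateURL()` and `readURL()` methods in `app.js` above store each model's head selection in a query parameter named `heads-<model>`. The sketch below pulls that encoding out into two standalone functions so the round trip can be inspected in isolation. The function names `encodeHeadSelection` and `decodeHeadSelection` and the model name `gpt2` are hypothetical and do not exist in the repository; `URL_HEAD_PREFIX` is copied from `index.template.html`.

```javascript
// Hypothetical standalone helpers mirroring the string handling in updateURL()/readURL().
const URL_HEAD_PREFIX = 'heads-'; // same value as in index.template.html

// 'L0H*, L1H2' -> 'L0Hx~L1H2': wildcards become 'x', spaces are dropped,
// and the comma-separated parts are joined with '~'.
function encodeHeadSelection(selection) {
	return selection.replaceAll('*', 'x').replaceAll(' ', '').split(',').join('~');
}

// 'L0Hx~L1H2' -> 'L0Hx, L1H2': only the separator is restored;
// the 'x' wildcard is left as-is, exactly as readURL() does.
function decodeHeadSelection(value) {
	return value.split('~').join(', ');
}

const params = new URLSearchParams();
params.set(`${URL_HEAD_PREFIX}gpt2`, encodeHeadSelection('L0H*, L1H2'));
const restored = decodeHeadSelection(params.get(`${URL_HEAD_PREFIX}gpt2`));
console.log(restored);
```

The round trip is not an exact inverse: a `*` written by the user comes back as `x`. This only works if `parseHeadString` accepts `x` as a wildcard, which this part of the file does not show.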

``````{ path="pattern_lens/frontend/index.template.html"  }
<!DOCTYPE html>
<html lang="en">

<head>
    <meta charset="UTF-8">
    <!-- <meta name="viewport" content="width=device-width, initial-scale=1.0"> -->
    <title>Attention Pattern Analysis</title>

    <link rel="icon" type="image/svg+xml"
        href='data:image/svg+xml,<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 100 100"><rect width="100" height="100" fill="%23000" stroke="%23333" stroke-width="1"/><path d="M0,0 L100,100 L0,100 Z" fill="rgba(0,255,255,0.2)"/><path d="M0,0 L15,15 L0,15 Z" fill="rgba(255,255,255,0.3)"/><path d="M15,15 L35,35 L15,35 Z" fill="rgba(255,255,255,0.3)"/><path d="M35,35 L60,60 L35,60 Z" fill="rgba(255,255,255,0.3)"/><path d="M60,60 L100,100 L60,100 Z" fill="rgba(255,255,255,0.3)"/></svg>'>

    <script src="https://cdnjs.cloudflare.com/ajax/libs/vue/3.2.31/vue.global.min.js"></script>
    <!-- Include lodash library for utility functions -->
    <script src="https://cdnjs.cloudflare.com/ajax/libs/lodash.js/4.17.21/lodash.min.js"></script>
    <!-- Include pako library for decompressing SVGZ files -->
    <script src="https://cdnjs.cloudflare.com/ajax/libs/pako/2.0.4/pako.min.js"></script>
    <!-- Include ag-Grid library for prompts table -->
    <!-- <script src="https://cdnjs.cloudflare.com/ajax/libs/ag-grid/32.1.0/ag-grid-community.min.js"></script> -->
    <script src="https://cdn.jsdelivr.net/npm/ag-grid-community@32.2.0/dist/ag-grid-community.min.js"></script>

    <style src="style.css"></style>
</head>

<body>
    <!-- Root element for Vue app -->
    <div id="app" class="container" :class="{ 'dark-mode': isDarkMode }">
        <div class="header-container">
            <h1 class="header-title">Attention Pattern Analysis</h1>
            <a href="https://github.com/mivanit/pattern-lens/">built with pattern-lens v$$PATTERN_LENS_VERSION$$</a>
            <div class="header-controls">
                <button class="btn btn-header dark-mode-button" @click="toggleDarkMode">
                    <span>Dark Mode</span>
                    <div class="dark-mode-toggle">
                        <div class="dark-mode-icon" style="left: 4px">☀️</div>
                        <div class="dark-mode-icon" style="right: 4px">🌙</div>
                    </div>
                </button>
                <button class="btn btn-header" @click="clearAllSelections">
                    🗑️ Clear All Selections
                </button>
            </div>
        </div>

        <div class="main-selection-content">
            <!-- Top section with functions and models side by side -->
            <div class="top-filters">
                <!-- Functions Filter -->
                <div class="functions-filter">
                    <div class="filter-label">
                        <input type="checkbox" id="select-all-functions"
                            :indeterminate.prop="isIndeterminate('functions')" :checked="isChecked('functions')"
                            @change="toggleSelectAll('functions', $event)">
                        <label for="select-all-functions">Functions:</label>
                        <span class="counter">
                            {{ filters.selected.functions.length }} / {{ Object.keys(filters.available.functions).length
                            }}
                        </span>
                    </div>
                    <div class="checkbox-list">
                        <div v-for="(func, name) in filters.available.functions" :key="name" class="checkbox-item">
                            <input type="checkbox" :id="'func-' + name" :value="name"
                                v-model="filters.selected.functions">
                            <label :for="'func-' + name">
                                <span class="function-name">{{ name }}</span>
                                <span class="function-info">ℹ️
                                    <div class="function-tooltip">
                                        <div v-if="func.figure_save_fmt"><strong>Format:</strong> {{
                                            func.figure_save_fmt }}</div>
                                        <div v-if="func.source"><strong>Source:</strong> {{ func.source }}</div>
                                        <div v-if="func.doc"> {{ func.doc }} </div>
                                    </div>
                                </span>
                            </label>
                        </div>
                    </div>
                </div>

                <!-- Model Selection -->
                <div id="modelGrid" class="ag-theme-alpine" style="height: 300px; width: 100%;"></div>
            </div>

            <!-- Prompts Table (full width) -->
            <div class="prompt-table">
                <div class="prompts-info">
                    <div class="prompt-counter">
                        Selected Prompts: {{ prompts.selected.length }} / {{ Object.keys(prompts.all).length }}
                    </div>
                    <div class="dataset-list-container">
                        <div class="dataset-list">
                            Hover here to see unique datasets
                            <div class="dataset-list-content">
                                <ul>
                                    <li v-for="dataset in uniqueDatasets" :key="dataset">{{ dataset }}</li>
                                </ul>
                            </div>
                        </div>
                    </div>
                </div>
                <div id="promptGrid" class="ag-theme-alpine"></div>
            </div>
        </div>

        <!-- image display button and size controls -->
        <div class="image-controls-container">
            <div class="visualization-controls">
                <div class="color-controls">
                    <label for="colorBy">Color borders by:</label>
                    <select id="colorBy" v-model="visualization.colorBy">
                        <option value="">None</option>
                        <option value="model">Model</option>
                        <option value="function">Function</option>
                        <option value="layer">Layer</option>
                        <option value="head">Head</option>
                        <option value="promptHash">Prompt</option>
                    </select>
                    <button class="btn btn-secondary" @click="regenerateColors" :disabled="!visualization.colorBy"
                        title="Generate new random colors">
                        🔄 Regenerate Colors
                    </button>
                </div>
                <div class="sort-controls">
                    <label for="sortBy">Sort by:</label>
                    <select id="sortBy" v-model="visualization.sortBy">
                        <option value="">Default Order</option>
                        <option value="model">Model</option>
                        <option value="function">Function</option>
                        <option value="layer">Layer</option>
                        <option value="head">Head</option>
                        <option value="promptHash">Prompt</option>
                    </select>
                    <select v-model="visualization.sortOrder">
                        <option value="asc">Ascending</option>
                        <option value="desc">Descending</option>
                    </select>
                </div>
            </div>
            <div class="image-controls">
                <div class="image-controls-display">
                    <button class="btn" :class="{ 'btn-primary': !images.upToDate, 'btn-secondary': images.upToDate }"
                        @click="displayImages">
                        {{ images.upToDate ? 'Images Up to Date' : 'Display Images' }}
                    </button>
                    <div class="progress-wrapper">
                        <span class="progress-status" v-if="images.expected > 0"> {{ images.visible.length || 'N/A' }} /
                            {{ images.expected }} images</span>
                        <div class="progress-bar" v-if="loading || images.visible.length > 0">
                            <div class="progress-bar-fill" :class="{ 'loading': loading, 'complete': !loading }"
                                :style="{ width: `${images.expected > 0 ? (images.visible.length / images.expected) * 100 : 0}%` }">
                            </div>
                        </div>
                    </div>
                </div>
                <div class="image-controls-size" v-if="images.visible.length > 0">
                    <label for="resizeSlider">Images per row:</label>
                    <input type="range" id="resizeSlider" class="resize-slider" v-model.number="images.perRow" min="1"
                        max="16" step="1">
                    <input type="number" class="resize-input" v-model.number="images.perRow" min="1" max="64">
                </div>
            </div>
        </div>

        <!-- images are loading -->
        <div v-if="loading" class="loading">Loading...</div>

        <!-- actual images display -->
        <div v-else-if="images.visible.length > 0" class="images" :style="{ 'grid-template-columns': `repeat(${images.perRow}, 1fr)` }">
            <div
                v-for="image in sortedImages"
                class="image-container"
                :style="{ borderColor: getBorderColor(image) }"
            >
                <p v-if="images.perRow <= 4" class="image-info">
                    <a :href="getSinglePropertyFilterUrl('models', image.model)">{{ image.model }}</a> -
                    <a :href="getSinglePropertyFilterUrl('functions', image.function)">{{ image.function }}</a> -
                    <a :href="getSinglePropertyFilterUrl('layers', image.layer)">L{{ image.layer }}</a> -
                    <a :href="getSinglePropertyFilterUrl('heads', image.head)">H{{ image.head }}</a> -
                    <a :href="getSinglePropertyFilterUrl('prompts', image.promptHash)">{{ image.promptHash }}</a>
                </p>
                <a :href="getImageUrl(image)" class="img-container" v-html="image.content"
                :title="images.perRow > 4 ? image.name : ''">
                </a>
            </div>
        </div>


        <!-- no images found -->
        <div v-else-if="images.requested" class="error">No images found for the selected criteria.</div>

    </div>

    <script src="util.js"></script>
    <script src="app.js"></script>
    <script>
        const DATA_DIR = '.';
        const FIGURE_FORMATS = ['svg', 'svgz', 'png'];
        const URL_HEAD_PREFIX = 'heads-';
        // Mount the Vue app to the DOM element with id="app"
        app.mount('#app');
    </script>
</body>

</html>
``````{ end_of_file="pattern_lens/frontend/index.template.html" }

``````{ path="pattern_lens/frontend/style.css"  }
/* CSS Variables */
:root {
	/* Colors */
	--primary: #007bff;
	--primary-hover: #0056b3;
	--secondary: #6c757d;
	--secondary-hover: #545b62;
	--success: #28a745;
	--border: #ccc;
	--text-muted: #666;
	--bg-light: #f0f0f0;
	--bg-white: #fff;
	--shadow: rgba(0, 0, 0, 0.1);
	--text-color: #000;

	/* Dark mode colors */
	--dark-bg: #1a1a1a;
	--dark-text: #ffffff;
	--dark-border: #444;
	--dark-bg-light: #2d2d2d;
	--dark-shadow: rgba(0, 0, 0, 0.3);

	/* Spacing */
	--space-xs: 3px;
	--space-sm: 5px;
	--space-md: 10px;
	--space-lg: 20px;

	/* Layout */
	--border-radius: 4px;
	--container-max-width: 1200px;
	--checkbox-size: 12px;
}

/* Base Styles */
body {
	font-family: Arial, sans-serif;
	line-height: 1.4;
	margin: 0;
	padding: var(--space-md);
}

.container {
	max-width: var(--container-max-width);
	margin: 0 auto;
}

/* Header Styles */
.header-container {
	display: flex;
	justify-content: space-between;
	align-items: center;
	margin-bottom: 1rem;
}

.header-title {
	margin: 0;
}

.header-controls {
	display: flex;
	gap: 1rem;
	align-items: center;
}

/* Layout Components */
.main-selection-content {
	display: flex;
	flex-direction: column;
	border: 2px solid var(--border);
	height: 800px;
	min-height: 400px;
	resize: vertical;
	overflow: hidden;
}

.top-filters {
	display: flex;
	gap: var(--space-md);
	height: 350px;
	border-bottom: 1px solid var(--border);
	min-height: 100px;
	max-height: 80vh;
	padding: var(--space-md);
	resize: vertical;
	position: relative;
	overflow: auto;
}

/* Functions Filter */
.functions-filter {
	width: 200px;
	min-width: 100px;
	max-width: 500px;
	display: flex;
	flex-direction: column;
	border: 1px solid var(--border);
	padding: var(--space-md);
	border-radius: var(--border-radius);
	flex-shrink: 0;
}

/* Filter Components */
.filter-item {
	margin-bottom: var(--space-sm);
	border: 1px solid var(--border);
	padding: var(--space-sm);
	border-radius: var(--border-radius);
}

.filter-label {
	display: flex;
	align-items: center;
	justify-content: space-between;
	margin-bottom: var(--space-xs);
}

/* Checkbox Lists */
.checkbox-list {
	border: 1px solid var(--border);
	padding: var(--space-xs);
	flex: 1;
	overflow-y: auto;
	overflow-x: visible;
}

.checkbox-item {
	position: relative;
	display: flex;
	align-items: center;
	margin-bottom: 1px;
	line-height: 1;
	width: 100%;
}

.checkbox-item label {
	display: flex;
	align-items: center;
	justify-content: space-between;
	width: 100%;
	margin-left: 4px;
}

.function-name {
	flex-grow: 1;
	margin-right: 8px;
}

input[type="checkbox"] {
	margin: 0 0.2em 0 0;
	width: var(--checkbox-size);
	height: var(--checkbox-size);
	vertical-align: middle;
}

/* Head Grid */
.head-grid {
	display: flex;
	gap: 1px;
	margin: 0 8px;
	height: 100%;
	align-items: center;
}

.headsGrid-col {
	display: flex;
	flex-direction: column;
	gap: 1px;
	height: 100%;
	justify-content: center;
}

.headsGrid-cell {
	width: 5px;
	height: 5px;
	margin: 0.5px;
	transition: background-color 0.2s ease;
}

.headsGrid-cell-selected {
	background-color: #2a1fee;
}

.headsGrid-cell-empty {
	background-color: #ac9a9a;
}

/* Model Grid */
#modelGrid {
	flex: 1;
	min-width: 200px;
	overflow: auto;
}

/* Prompt Table */
.prompt-table {
	flex: 1;
	min-height: 200px;
	display: flex;
	flex-direction: column;
	overflow: hidden;
	position: relative;
}

.prompts-info {
	border: 1px solid var(--border);
	padding: var(--space-sm);
	border-radius: var(--border-radius);
}

.prompt-counter {
	display: flex;
	align-items: center;
	justify-content: space-between;
}

.prompt-text-cell {
	cursor: pointer;
}

/* ag-Grid Customization */
.ag-theme-alpine {
	height: calc(100% - 3em) !important;
	width: 100% !important;
}

.ag-cell-edit-input {
	height: 100% !important;
	line-height: normal !important;
	padding: 0 8px !important;
}

.ag-cell:not(.invalid-selection) {
	background-color: transparent !important;
}

.ag-cell.invalid-selection {
	background-color: #ffeaea !important;
}

/* Dataset List */
.dataset-list-container {
	position: absolute;
	right: var(--space-md);
	top: 0.5em;
}

.dataset-list {
	position: relative;
	cursor: pointer;
	border: 1px solid var(--border);
	padding: 1px;
	border-radius: var(--border-radius);
	background-color: #f9f9f9;
}

.dataset-list-content {
	display: none;
	position: absolute;
	right: 0;
	top: 100%;
	background-color: var(--bg-white);
	border: 1px solid var(--border);
	padding: var(--space-xs) var(--space-lg) var(--space-xs) var(--space-xs);
	font-family: monospace;
	box-shadow: 0 4px 8px var(--shadow);
	z-index: 1000;
}

.dataset-list:hover .dataset-list-content {
	display: block;
}

/* Image Controls and Display */
.image-controls-container {
	margin: var(--space-lg) 0;
}

.image-controls {
	display: flex;
	align-items: center;
	justify-content: space-between;
	padding: var(--space-md);
	background-color: var(--bg-light);
	border-radius: 8px;
	box-shadow: 0 2px 4px var(--shadow);
}

.image-controls-display,
.image-controls-size {
	display: flex;
	align-items: center;
	width: 50%;
}

.image-controls-size {
	justify-content: flex-end;
}

.resize-slider {
	width: 250px;
	margin: 0 var(--space-md);
}

.resize-input {
	width: 75px;
	padding: 2px var(--space-sm);
}

/* Image Grid */
.images {
	display: grid;
	gap: var(--space-sm);
	margin-top: var(--space-md);
}

.image-container {
	text-align: center;
}

.image-info {
	font-size: 0.8em;
	margin: 2em 0 -1em;
}

.img-container svg,
.img-container img {
	width: 100%;
	height: 100%;
	object-fit: contain;
	image-rendering: pixelated;
	-ms-interpolation-mode: nearest-neighbor;
}

/* Buttons */
.btn {
	margin: 5px;
	padding: 8px 16px;
	font-size: 14px;
	font-weight: bold;
	border: none;
	border-radius: var(--border-radius);
	cursor: pointer;
	transition: background-color 0.3s ease;
}

.btn-primary {
	background-color: var(--primary);
	color: white;
}

.btn-primary:hover {
	background-color: var(--primary-hover);
}

.btn-secondary,
.btn-header,
.btn-dark-mode {
	background-color: var(--secondary);
	color: white;
}

.btn-secondary:hover,
.btn-header:hover,
.btn-dark-mode:hover {
	background-color: var(--secondary-hover);
}

/* Progress Bar */
.progress-bar {
	height: 12px;
	width: 200px;
	background: #ddd;
	border-radius: 6px;
	overflow: hidden;
}

.progress-bar-fill {
	height: 100%;
	transition: width 0.3s ease;
}

.progress-bar-fill.loading {
	background-color: var(--primary);
}

.progress-bar-fill.complete {
	background-color: var(--success);
}

.progress-wrapper {
	padding-left: 1rem;
}

/* Function Info Tooltip */
.function-info {
	position: relative;
	cursor: help;
	display: flex;
	align-items: center;
	margin-left: auto;
}

.function-tooltip {
	display: none;
	position: fixed;
	background-color: #eee;
	border: 1px solid #ccc;
	padding: 8px;
	width: 250px;
	z-index: 9999999;
	border-radius: 4px;
	box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
}

.function-info:hover .function-tooltip {
	display: block;
}

/* Dark Mode Styles */
.dark-mode {
	background-color: var(--dark-bg);
	color: var(--dark-text);
}

.dark-mode .container {
	background-color: var(--dark-bg);
}

.dark-mode .functions-filter,
.dark-mode .filter-item,
.dark-mode .checkbox-list {
	background-color: var(--dark-bg-light);
	border-color: var(--dark-border);
}

.dark-mode .ag-theme-alpine {
	--ag-background-color: var(--dark-bg-light);
	--ag-header-background-color: var(--dark-bg);
	--ag-odd-row-background-color: var(--dark-bg);
	--ag-header-foreground-color: var(--dark-text);
	--ag-foreground-color: var(--dark-text);
	--ag-border-color: var(--dark-border);
}

.dark-mode .top-controls {
	background-color: var(--dark-bg-light);
}

.dark-mode .dataset-list,
.dark-mode .dataset-list-content {
	background-color: var(--dark-bg-light);
	border-color: var(--dark-border);
	color: var(--dark-text);
}

.dark-mode .dataset-list-content ul {
	margin: 0;
	padding: 0.5em 1em;
	list-style-type: none;
}

.dark-mode .dataset-list-content li {
	color: var(--dark-text);
	padding: 0.2em 0;
}

.dark-mode .image-controls-container {
	background-color: transparent;
}

.dark-mode .image-controls {
	background-color: var(--dark-bg-light);
	border-color: var(--dark-border);
	box-shadow: 0 2px 4px var(--dark-shadow);
}

.dark-mode .resize-slider,
.dark-mode .resize-input {
	background-color: var(--dark-bg);
	border-color: var(--dark-border);
}

.dark-mode .resize-input {
	color: var(--dark-text);
}

.dark-mode .image-controls label {
	color: var(--dark-text);
}

.dark-mode .progress-bar {
	background-color: var(--dark-bg);
	border: 1px solid var(--dark-border);
}

.dark-mode .progress-status {
	color: var(--dark-text);
}

.dark-mode .function-tooltip {
	background-color: var(--dark-bg-light);
	border-color: var(--dark-border);
	color: var(--dark-text);
}

/* Utility Classes */
.loading,
.error {
	text-align: center;
	padding: var(--space-md);
}

.counter {
	font-size: 0.8em;
	color: var(--text-muted);
	margin-left: auto;
}

/* Dark Mode Toggle */
.dark-mode-toggle {
	position: relative;
	width: 60px;
	height: 30px;
	border-radius: 15px;
	background-color: #e2e8f0;
	cursor: pointer;
	transition: background-color 0.3s ease;
	border: none;
	padding: 0;
	overflow: hidden;
}

.dark-mode-toggle::before {
	content: "";
	position: absolute;
	top: 3px;
	left: 3px;
	width: 24px;
	height: 24px;
	border-radius: 50%;
	background-color: white;
	transition: transform 0.3s ease;
	z-index: 1;
}

.dark-mode .dark-mode-toggle {
	background-color: #4a5568;
}

.dark-mode .dark-mode-toggle::before {
	transform: translateX(30px);
}

.dark-mode-icon {
	position: absolute;
	top: 50%;
	transform: translateY(-50%);
	font-size: 14px;
	pointer-events: none;
	line-height: 1;
	display: flex;
	align-items: center;
	justify-content: center;
	width: 24px;
	height: 24px;
}

.sun-icon {
	left: 8px;
	opacity: 1;
}

.moon-icon {
	right: 8px;
	opacity: 1;
}

.dark-mode-button {
	display: flex;
	align-items: center;
	gap: 8px;
	cursor: pointer;
	color: inherit;
}

/* Visualization Controls Styles */
.visualization-controls {
	display: flex;
	align-items: center;
	gap: 20px;
	margin-top: 10px;
	padding: 5px 0;
	border-top: 1px solid var(--border);
}

.color-controls,
.sort-controls {
	display: flex;
	align-items: center;
	gap: 8px;
}

.image-container {
	position: relative;
	border: 3px solid transparent;
	border-radius: 6px;
	padding: 4px;
	transition: border-color 0.2s ease;
}

.dark-mode .visualization-controls {
	border-color: var(--dark-border);
}

.color-legend {
	display: flex;
	flex-wrap: wrap;
	gap: 10px;
	margin-top: 10px;
	padding: 10px;
	background-color: var(--bg-light);
	border-radius: 6px;
}

.legend-item {
	display: flex;
	align-items: center;
	gap: 5px;
	font-size: 0.85em;
}

.legend-color {
	width: 16px;
	height: 16px;
	border-radius: 3px;
	border: 3px solid #000;
}

.dark-mode .color-legend {
	background-color: var(--dark-bg-light);
}

.dark-mode .legend-color {
	border-color: #fff;
}
``````{ end_of_file="pattern_lens/frontend/style.css" }

``````{ path="pattern_lens/frontend/util.js"  }
const fileOps = {
	async getDirectoryContents(path) {
		const response = await fetch(`${path}/index.txt`);
		const text = await response.text();
		return text.trim().split('\n');
	},
	async fileExists(path) {
		const response = await fetch(path, { method: 'HEAD' });
		return response.ok;
	},
	async fetchJson(path) {
		const response = await fetch(path);
		return response.json();
	},
	async fetchJsonL(path) {
		const response = await fetch(path);
		const text = await response.text();
		// allow for the last line being incomplete
		const text_split = text.trim().split('\n');
		// wrap JSON.parse so that map's index argument is not passed as the reviver
		let output = text_split.slice(0, -1).map((line) => JSON.parse(line));
		try {
			output.push(JSON.parse(text_split[text_split.length - 1]));
		} catch (error) {
			console.error('Error parsing last line of JSONL:', error);
		}
		return output;
	},
	async fetchAndDecompressSvgz(path) {
		// returns null if file does not exist
		const response = await fetch(path);
		if (!response.ok) {
			return null;
		} else {
			const arrayBuffer = await response.arrayBuffer();
			const uint8Array = new Uint8Array(arrayBuffer);
			return pako.inflate(uint8Array, { to: 'string' });
		}
	},
	async figureExists(path) {
		for (const format of FIGURE_FORMATS) {
			// declare locally; a bare assignment creates an implicit global
			const fig_path = `${path}.${format}`;
			if (await this.fileExists(fig_path)) {
				return format;
			}
		}
		return null;
	}
};


const colorUtils = {
	getRandomColor() {
		// Generate vibrant colors with good contrast
		const hue = Math.floor(Math.random() * 360);
		return `hsl(${hue}, 70%, 60%)`;
	},
};


``````{ end_of_file="pattern_lens/frontend/util.js" }
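
For readers who want to consume the same JSONL files from the Python side, the tolerant parsing in `fetchJsonL` above can be mirrored roughly as follows. This is a minimal sketch, not part of the package: the name `parse_jsonl_tolerant` is hypothetical, and it assumes (as the JavaScript does) that only the final line may be incomplete.

```python
import json


def parse_jsonl_tolerant(text: str) -> list:
	"""Parse JSONL text, dropping a final line that is not valid JSON (hypothetical helper)."""
	lines = text.strip().split("\n")
	# every line except the last must parse; errors here propagate to the caller
	records = [json.loads(line) for line in lines[:-1]]
	try:
		records.append(json.loads(lines[-1]))
	except json.JSONDecodeError:
		# the last line may still be mid-write, so it is skipped
		pass
	return records


sample = '{"a": 1}\n{"a": 2}\n{"a": 3'
print(parse_jsonl_tolerant(sample))
```

As in the JavaScript version, a malformed line anywhere other than the end still raises, so truncation is tolerated but corruption is not silently hidden.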

``````{ path="pattern_lens/__init__.py"  }
""".. include:: ../README.md"""

__all__ = [
	"activations",
	"attn_figure_funcs",
	"consts",
	"figure_util",
	"figures",
	"indexes",
	"load_activations",
	"prompts",
	"server",
]

``````{ end_of_file="pattern_lens/__init__.py" }

``````{ path="pattern_lens/activations.py"  }
"""computing and saving activations given a model and prompts

# Usage:

from the command line:

```bash
python -m pattern_lens.activations --model <model_name> --prompts <prompts_path> --save-path <save_path> --min-chars <min_chars> --max-chars <max_chars> --n-samples <n_samples>
```

from a script:

```python
from pattern_lens.activations import activations_main
activations_main(
	model_name="gpt2",
	save_path="demo/",
	prompts_path="data/pile_1k.jsonl",
	raw_prompts=True,
	min_chars=100,
	max_chars=1000,
	force=False,
	n_samples=10,
	no_index_html=False,
)
```

"""

import argparse
import functools
import json
import re
from collections.abc import Callable
from dataclasses import asdict
from pathlib import Path
from typing import Literal, overload

import numpy as np
import torch
import tqdm
from jaxtyping import Float
from muutils.json_serialize import json_serialize
from muutils.misc.numerical import shorten_numerical_to_str

# custom utils
from muutils.spinner import SpinnerContext
from transformer_lens import (  # type: ignore[import-untyped]
	ActivationCache,
	HookedTransformer,
	HookedTransformerConfig,
)

# pattern_lens
from pattern_lens.consts import (
	ATTN_PATTERN_REGEX,
	DATA_DIR,
	DIVIDER_S1,
	DIVIDER_S2,
	SPINNER_KWARGS,
	ActivationCacheNp,
	ReturnCache,
)
from pattern_lens.indexes import (
	generate_models_jsonl,
	generate_prompts_jsonl,
	write_html_index,
)
from pattern_lens.load_activations import (
	ActivationsMissingError,
	augment_prompt_with_hash,
	load_activations,
)
from pattern_lens.prompts import load_text_data


# return nothing, but `stack_heads` still affects how we save the activations
@overload
def compute_activations(
	prompt: dict,
	model: HookedTransformer | None = None,
	save_path: Path = Path(DATA_DIR),
	names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
	return_cache: Literal[None] = None,
	stack_heads: bool = False,
) -> tuple[Path, None]: ...
# return stacked heads in numpy or torch form
@overload
def compute_activations(
	prompt: dict,
	model: HookedTransformer | None = None,
	save_path: Path = Path(DATA_DIR),
	names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
	return_cache: Literal["torch"] = "torch",
	stack_heads: Literal[True] = True,
) -> tuple[Path, Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"]]: ...
@overload
def compute_activations(
	prompt: dict,
	model: HookedTransformer | None = None,
	save_path: Path = Path(DATA_DIR),
	names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
	return_cache: Literal["numpy"] = "numpy",
	stack_heads: Literal[True] = True,
) -> tuple[Path, Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]]: ...
# return dicts in numpy or torch form
@overload
def compute_activations(
	prompt: dict,
	model: HookedTransformer | None = None,
	save_path: Path = Path(DATA_DIR),
	names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
	return_cache: Literal["numpy"] = "numpy",
	stack_heads: Literal[False] = False,
) -> tuple[Path, ActivationCacheNp]: ...
@overload
def compute_activations(
	prompt: dict,
	model: HookedTransformer | None = None,
	save_path: Path = Path(DATA_DIR),
	names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
	return_cache: Literal["torch"] = "torch",
	stack_heads: Literal[False] = False,
) -> tuple[Path, ActivationCache]: ...
# actual function body
def compute_activations(  # noqa: PLR0915
	prompt: dict,
	model: HookedTransformer | None = None,
	save_path: Path = Path(DATA_DIR),
	names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
	return_cache: ReturnCache = "torch",
	stack_heads: bool = False,
) -> tuple[
	Path,
	ActivationCacheNp
	| ActivationCache
	| Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]
	| Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"]
	| None,
]:
	"""compute activations for a given model and prompt, and save them to disk

	the prompt metadata (with `tokens` and `n_tokens` added) is saved to `prompt.json` in the same directory as the activations

	# Parameters:
	- `prompt : dict`
		must contain the `"text"` and `"hash"` keys. if it has no `"tokens"` key, the text is tokenized with the model's tokenizer
	- `model : HookedTransformer | None`
		the model to run. must not be `None`, an assertion checks this
		(defaults to `None`)
	- `save_path : Path`
		(defaults to `Path(DATA_DIR)`)
	- `names_filter : Callable[[str], bool]|re.Pattern`
		a filter for the names of the activations to return. if an `re.Pattern`, will use `lambda key: names_filter.match(key) is not None`
		(defaults to `ATTN_PATTERN_REGEX`)
	- `return_cache : Literal[None, "numpy", "torch"]`
		will return `None` as the second element if `None`, otherwise will return the cache in the specified tensor format. `stack_heads` still affects whether it will be a dict (False) or a single tensor (True)
		(defaults to `"torch"`)
	- `stack_heads : bool`
		whether the heads should be stacked in the output. this causes a number of changes:
		- `npy` file with a single `(n_layers, n_heads, n_ctx, n_ctx)` tensor saved for each prompt instead of `npz` file with dict by layer
		- `cache` will be a single `(n_layers, n_heads, n_ctx, n_ctx)` tensor instead of a dict by layer if `return_cache` is not `None`
		- will assert that everything in the activation cache is only attention patterns, and is all of the attention patterns. raises an exception if not.
		- the stacked tensor keeps a leading batch dimension of size 1, so its full shape is `(1, n_layers, n_heads, n_ctx, n_ctx)`
		(defaults to `False`)

	# Returns:
	```
	tuple[
		Path,
		Union[
			None,
			ActivationCacheNp, ActivationCache,
			Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"], Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"],
		]
	]
	```
	"""
	# check inputs
	assert model is not None, "model must be passed"
	assert "text" in prompt, "prompt must contain 'text' key"
	prompt_str: str = prompt["text"]

	# compute or get prompt metadata
	prompt_tokenized: list[str] = prompt.get(
		"tokens",
		model.tokenizer.tokenize(prompt_str),
	)
	prompt.update(
		dict(
			n_tokens=len(prompt_tokenized),
			tokens=prompt_tokenized,
		),
	)

	# save metadata
	prompt_dir: Path = save_path / model.cfg.model_name / "prompts" / prompt["hash"]
	prompt_dir.mkdir(parents=True, exist_ok=True)
	with open(prompt_dir / "prompt.json", "w") as f:
		json.dump(prompt, f)

	# set up names filter
	names_filter_fn: Callable[[str], bool]
	if isinstance(names_filter, re.Pattern):
		names_filter_fn = lambda key: names_filter.match(key) is not None  # noqa: E731
	else:
		names_filter_fn = names_filter

	# compute activations
	cache_torch: ActivationCache
	with torch.no_grad():
		model.eval()
		# TODO: batching?
		_, cache_torch = model.run_with_cache(
			prompt_str,
			names_filter=names_filter_fn,
			return_type=None,
		)

	activations_path: Path
	# saving and returning
	if stack_heads:
		n_layers: int = model.cfg.n_layers
		key_pattern: str = "blocks.{i}.attn.hook_pattern"
		# NOTE: this only works for stacking heads at the moment
		# activations_specifier: str = key_pattern.format(i=f'0-{n_layers}')
		activations_specifier: str = key_pattern.format(i="-")
		activations_path = prompt_dir / f"activations-{activations_specifier}.npy"

		# check the keys are only attention heads
		head_keys: list[str] = [key_pattern.format(i=i) for i in range(n_layers)]
		cache_torch_keys_set: set[str] = set(cache_torch.keys())
		assert cache_torch_keys_set == set(head_keys), (
			f"unexpected keys!\n{set(head_keys).symmetric_difference(cache_torch_keys_set) = }\n{cache_torch_keys_set} != {set(head_keys)}"
		)

		# stack heads
		patterns_stacked: Float[torch.Tensor, "n_layers n_heads n_ctx n_ctx"] = (
			torch.stack([cache_torch[k] for k in head_keys], dim=1)
		)
		# check shape
		pattern_shape_no_ctx: tuple[int, ...] = tuple(patterns_stacked.shape[:3])
		assert pattern_shape_no_ctx == (1, n_layers, model.cfg.n_heads), (
			f"unexpected shape: {patterns_stacked.shape[:3] = } ({pattern_shape_no_ctx = }), expected {(1, n_layers, model.cfg.n_heads) = }"
		)

		patterns_stacked_np: Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"] = (
			patterns_stacked.cpu().numpy()
		)

		# save
		np.save(activations_path, patterns_stacked_np)

		# return
		match return_cache:
			case "numpy":
				return activations_path, patterns_stacked_np
			case "torch":
				return activations_path, patterns_stacked
			case None:
				return activations_path, None
			case _:
				msg = f"invalid return_cache: {return_cache = }"
				raise ValueError(msg)
	else:
		activations_path = prompt_dir / "activations.npz"

		# save
		cache_np: ActivationCacheNp = {
			k: v.detach().cpu().numpy() for k, v in cache_torch.items()
		}

		np.savez_compressed(
			activations_path,
			**cache_np,
		)

		# return
		match return_cache:
			case "numpy":
				return activations_path, cache_np
			case "torch":
				return activations_path, cache_torch
			case None:
				return activations_path, None
			case _:
				msg = f"invalid return_cache: {return_cache = }"
				raise ValueError(msg)


@overload
def get_activations(
	prompt: dict,
	model: HookedTransformer | str,
	save_path: Path = Path(DATA_DIR),
	allow_disk_cache: bool = True,
	return_cache: Literal[None] = None,
) -> tuple[Path, None]: ...
@overload
def get_activations(
	prompt: dict,
	model: HookedTransformer | str,
	save_path: Path = Path(DATA_DIR),
	allow_disk_cache: bool = True,
	return_cache: Literal["torch"] = "torch",
) -> tuple[Path, ActivationCache]: ...
@overload
def get_activations(
	prompt: dict,
	model: HookedTransformer | str,
	save_path: Path = Path(DATA_DIR),
	allow_disk_cache: bool = True,
	return_cache: Literal["numpy"] = "numpy",
) -> tuple[Path, ActivationCacheNp]: ...
def get_activations(
	prompt: dict,
	model: HookedTransformer | str,
	save_path: Path = Path(DATA_DIR),
	allow_disk_cache: bool = True,
	return_cache: ReturnCache = "numpy",
) -> tuple[Path, ActivationCacheNp | ActivationCache | None]:
	"""given a prompt and a model, save or load activations

	# Parameters:
	- `prompt : dict`
		expected to contain the 'text' key
	- `model : HookedTransformer | str`
		either a `HookedTransformer` or a string model name, to be loaded with `HookedTransformer.from_pretrained`
	- `save_path : Path`
		path to save the activations to (and load from)
		(defaults to `Path(DATA_DIR)`)
	- `allow_disk_cache : bool`
		whether to allow loading from disk cache
		(defaults to `True`)
	- `return_cache : Literal[None, "numpy", "torch"]`
		whether to return the cache, and in what format
		(defaults to `"numpy"`)

	# Returns:
	- `tuple[Path, ActivationCacheNp | ActivationCache | None]`
		the path to the activations and the cache if `return_cache is not None`

	"""
	# add hash to prompt
	augment_prompt_with_hash(prompt)

	# get the model
	model_name: str = (
		model.cfg.model_name if isinstance(model, HookedTransformer) else model
	)

	# from cache
	if allow_disk_cache:
		try:
			path, cache = load_activations(
				model_name=model_name,
				prompt=prompt,
				save_path=save_path,
			)
			if return_cache:
				return path, cache
			else:
				# TODO: this basically does nothing, since we load the activations and then immediately get rid of them.
				# maybe refactor this so that load_activations can take a parameter to simply assert that the cache exists?
				# this will let us avoid loading it, which slows things down
				return path, None
		except ActivationsMissingError:
			pass

	# compute them
	if isinstance(model, str):
		model = HookedTransformer.from_pretrained(model_name)

	return compute_activations(
		prompt=prompt,
		model=model,
		save_path=save_path,
		return_cache=return_cache,
	)


DEFAULT_DEVICE: torch.device = torch.device(
	"cuda" if torch.cuda.is_available() else "cpu",
)


def activations_main(
	model_name: str,
	save_path: str,
	prompts_path: str,
	raw_prompts: bool,
	min_chars: int,
	max_chars: int,
	force: bool,
	n_samples: int | None,
	no_index_html: bool,
	shuffle: bool = False,
	stacked_heads: bool = False,
	device: str | torch.device = DEFAULT_DEVICE,
) -> None:
	"""main function for computing activations

	# Parameters:
	- `model_name : str`
		name of a model to load with `HookedTransformer.from_pretrained`
	- `save_path : str`
		path to save the activations to
	- `prompts_path : str`
		path to the prompts file
	- `raw_prompts : bool`
		whether the prompts are raw, not filtered by length. if `True`, `load_text_data` will be called on `prompts_path`. otherwise prompts are read from `prompts.jsonl` in the model's directory under `save_path`, and `prompts_path` is not used
	- `min_chars : int`
		minimum number of characters for a prompt
	- `max_chars : int`
		maximum number of characters for a prompt
	- `force : bool`
		whether to overwrite existing files
	- `n_samples : int | None`
		maximum number of samples to process. all prompts are processed if `None`
	- `no_index_html : bool`
		whether to skip writing an index.html file
	- `shuffle : bool`
		whether to shuffle the prompts
		(defaults to `False`)
	- `stacked_heads : bool`
		whether to stack the heads in the output tensor. will save as `.npy` instead of `.npz` if `True`. not implemented yet in this function: raises `NotImplementedError` if `True`
		(defaults to `False`)
	- `device : str | torch.device`
		the device to use. if a string, will be passed to `torch.device`
		(defaults to `DEFAULT_DEVICE`)
	"""
	# figure out the device to use
	device_: torch.device
	if isinstance(device, torch.device):
		device_ = device
	elif isinstance(device, str):
		device_ = torch.device(device)
	else:
		msg = f"invalid device: {device}"
		raise TypeError(msg)

	print(f"using device: {device_}")

	with SpinnerContext(message="loading model", **SPINNER_KWARGS):
		model: HookedTransformer = HookedTransformer.from_pretrained(
			model_name,
			device=device_,
		)
		model.model_name = model_name
		model.cfg.model_name = model_name
		n_params: int = sum(p.numel() for p in model.parameters())
	print(
		f"loaded {model_name} with {shorten_numerical_to_str(n_params)} ({n_params}) parameters",
	)
	print(f"\tmodel devices: { {p.device for p in model.parameters()} }")

	save_path_p: Path = Path(save_path)
	save_path_p.mkdir(parents=True, exist_ok=True)
	model_path: Path = save_path_p / model_name
	with SpinnerContext(
		message=f"saving model info to {model_path.as_posix()}",
		**SPINNER_KWARGS,
	):
		model_cfg: HookedTransformerConfig
		model_cfg = model.cfg
		model_path.mkdir(parents=True, exist_ok=True)
		with open(model_path / "model_cfg.json", "w") as f:
			json.dump(json_serialize(asdict(model_cfg)), f)

	# load prompts
	with SpinnerContext(
		message=f"loading prompts from {prompts_path = }",
		**SPINNER_KWARGS,
	):
		prompts: list[dict]
		if raw_prompts:
			prompts = load_text_data(
				Path(prompts_path),
				min_chars=min_chars,
				max_chars=max_chars,
				shuffle=shuffle,
			)
		else:
			with open(model_path / "prompts.jsonl", "r") as f:
				prompts = [json.loads(line) for line in f.readlines()]
		# truncate to n_samples
		prompts = prompts[:n_samples]

	print(f"{len(prompts)} prompts loaded")

	# write index.html
	with SpinnerContext(message="writing index.html", **SPINNER_KWARGS):
		if not no_index_html:
			write_html_index(save_path_p)

	# TODO: not implemented yet
	if stacked_heads:
		raise NotImplementedError("stacked_heads not implemented yet")

	# get activations
	list(
		tqdm.tqdm(
			map(
				functools.partial(
					get_activations,
					model=model,
					save_path=save_path_p,
					allow_disk_cache=not force,
					return_cache=None,
					# stacked_heads=stacked_heads,
				),
				prompts,
			),
			total=len(prompts),
			desc="Computing activations",
			unit="prompt",
		),
	)

	with SpinnerContext(
		message="updating jsonl metadata for models and prompts",
		**SPINNER_KWARGS,
	):
		generate_models_jsonl(save_path_p)
		generate_prompts_jsonl(save_path_p / model_name)


def main() -> None:
	"generate attention pattern activations for a model and prompts"
	print(DIVIDER_S1)
	with SpinnerContext(message="parsing args", **SPINNER_KWARGS):
		arg_parser: argparse.ArgumentParser = argparse.ArgumentParser()
		# input and output
		arg_parser.add_argument(
			"--model",
			"-m",
			type=str,
			required=True,
			help="The model name(s) to use. comma separated with no whitespace if multiple",
		)

		arg_parser.add_argument(
			"--prompts",
			"-p",
			type=str,
			required=False,
			help="The path to the prompts file (jsonl with 'text' key on each line). If `None`, expects that `--figures` is passed and will generate figures for all prompts in the model directory",
			default=None,
		)

		arg_parser.add_argument(
			"--save-path",
			"-s",
			type=str,
			required=False,
			help="The path to save the attention patterns",
			default=DATA_DIR,
		)

		# min and max prompt lengths
		arg_parser.add_argument(
			"--min-chars",
			type=int,
			required=False,
			help="The minimum number of characters for a prompt",
			default=100,
		)
		arg_parser.add_argument(
			"--max-chars",
			type=int,
			required=False,
			help="The maximum number of characters for a prompt",
			default=1000,
		)

		# number of samples
		arg_parser.add_argument(
			"--n-samples",
			"-n",
			type=int,
			required=False,
			help="The max number of samples to process, do all in the file if None",
			default=None,
		)

		# force overwrite
		arg_parser.add_argument(
			"--force",
			"-f",
			action="store_true",
			help="If passed, will overwrite existing files",
		)

		# no index html
		arg_parser.add_argument(
			"--no-index-html",
			action="store_true",
			help="If passed, will not write an index.html file for the model",
		)

		# raw prompts
		arg_parser.add_argument(
			"--raw-prompts",
			"-r",
			action="store_true",
			help="pass if the prompts have not been split and tokenized (still needs keys 'text' and 'meta' for each item)",
		)

		# shuffle
		arg_parser.add_argument(
			"--shuffle",
			action="store_true",
			help="If passed, will shuffle the prompts",
		)

		# stack heads
		arg_parser.add_argument(
			"--stacked-heads",
			action="store_true",
			help="If passed, will stack the heads in the output tensor",
		)

		# device
		arg_parser.add_argument(
			"--device",
			type=str,
			required=False,
			help="The device to use for the model",
			default="cuda" if torch.cuda.is_available() else "cpu",
		)

		args: argparse.Namespace = arg_parser.parse_args()

	print(f"args parsed: {args}")

	models: list[str]
	if "," in args.model:
		models = args.model.split(",")
	else:
		models = [args.model]

	n_models: int = len(models)
	for idx, model in enumerate(models):
		print(DIVIDER_S2)
		print(f"processing model {idx + 1} / {n_models}: {model}")
		print(DIVIDER_S2)

		activations_main(
			model_name=model,
			save_path=args.save_path,
			prompts_path=args.prompts,
			raw_prompts=args.raw_prompts,
			min_chars=args.min_chars,
			max_chars=args.max_chars,
			force=args.force,
			n_samples=args.n_samples,
			no_index_html=args.no_index_html,
			shuffle=args.shuffle,
			stacked_heads=args.stacked_heads,
			device=args.device,
		)
		del model

	print(DIVIDER_S1)


if __name__ == "__main__":
	main()

``````{ end_of_file="pattern_lens/activations.py" }
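
The `main` function above accepts several model names in a single `--model` argument, separated by commas. A minimal standalone sketch of that parsing step, using only `argparse`, might look like the following. The helper name `parse_models` and the model names in the usage line are illustrative assumptions and are not part of `pattern_lens`.

```python
import argparse


def parse_models(argv: list[str]) -> list[str]:
	"""parse `--model` / `-m` and split it into a list of model names (hypothetical helper)"""
	parser = argparse.ArgumentParser()
	parser.add_argument("--model", "-m", type=str, required=True)
	args = parser.parse_args(argv)
	# `str.split` returns a one-element list when there is no comma,
	# so single and multiple model names go through the same code path
	return args.model.split(",")


# illustrative model names, chosen only for the example
models = parse_models(["--model", "gpt2,pythia-14m"])
for idx, model in enumerate(models):
	print(f"processing model {idx + 1} / {len(models)}: {model}")
```

Because `str.split(",")` already returns a one-element list for a string without commas, the explicit `if "," in args.model` branch in `main` behaves the same as calling `split` unconditionally.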

``````{ path="pattern_lens/attn_figure_funcs.py"  }
"""default figure functions

- If you are making a PR, add your new figure function here.
- If you are using this as a library, you can see examples here.


note that for `pattern_lens.figures` to recognize your function, you need to use the `register_attn_figure_func` decorator
which adds your function to `ATTENTION_MATRIX_FIGURE_FUNCS`

"""

import itertools
from collections.abc import Callable, Sequence

from pattern_lens.consts import AttentionMatrix
from pattern_lens.figure_util import (
	AttentionMatrixFigureFunc,
	Matrix2D,
	save_matrix_wrapper,
)

_FIGURE_NAMES_KEY: str = "_figure_names"

ATTENTION_MATRIX_FIGURE_FUNCS: list[AttentionMatrixFigureFunc] = list()


def get_all_figure_names() -> list[str]:
	"""get all figure names"""
	return list(
		itertools.chain.from_iterable(
			getattr(
				func,
				_FIGURE_NAMES_KEY,
				[func.__name__],
			)
			for func in ATTENTION_MATRIX_FIGURE_FUNCS
		),
	)


def register_attn_figure_func(
	func: AttentionMatrixFigureFunc,
) -> AttentionMatrixFigureFunc:
	"""decorator for registering attention matrix figure function

	if you want to add a new figure function, you should use this decorator

	# Parameters:
	- `func : AttentionMatrixFigureFunc`
		your function, which should take an attention matrix and path

	# Returns:
	- `AttentionMatrixFigureFunc`
		your function, after we add it to `ATTENTION_MATRIX_FIGURE_FUNCS`

	# Usage:
	```python
	@register_attn_figure_func
	def my_new_figure_func(attn_matrix: AttentionMatrix, path: Path) -> None:
		fig, ax = plt.subplots(figsize=(10, 10))
		ax.matshow(attn_matrix, cmap="viridis")
		ax.set_title("My New Figure Function")
		ax.axis("off")
		plt.savefig(path / "my_new_figure_func", format="svgz")
		plt.close(fig)
	```

	"""
	setattr(func, _FIGURE_NAMES_KEY, (func.__name__,))
	global ATTENTION_MATRIX_FIGURE_FUNCS  # noqa: PLW0602
	ATTENTION_MATRIX_FIGURE_FUNCS.append(func)

	return func


def register_attn_figure_multifunc(
	names: Sequence[str],
) -> Callable[[AttentionMatrixFigureFunc], AttentionMatrixFigureFunc]:
	"decorator which registers a function as a multi-figure function"

	def decorator(func: AttentionMatrixFigureFunc) -> AttentionMatrixFigureFunc:
		setattr(
			func,
			_FIGURE_NAMES_KEY,
			tuple([f"{func.__name__}.{name}" for name in names]),
		)
		global ATTENTION_MATRIX_FIGURE_FUNCS  # noqa: PLW0602
		ATTENTION_MATRIX_FIGURE_FUNCS.append(func)
		return func

	return decorator


@register_attn_figure_func
@save_matrix_wrapper(fmt="png")
def raw(attn_matrix: AttentionMatrix) -> Matrix2D:
	"raw attention matrix"
	return attn_matrix


# some more examples:

# @register_attn_figure_func
# @matplotlib_figure_saver
# def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
#     ax.matshow(attn_matrix, cmap="viridis")
#     ax.set_title("Raw Attention Pattern")
#     ax.axis("off")

# @register_attn_figure_func
# @save_matrix_wrapper(fmt="svg")
# def raw_svg(attn_matrix: AttentionMatrix) -> Matrix2D:
#     return attn_matrix

# @register_attn_figure_func
# @save_matrix_wrapper(fmt="svgz")
# def raw_svgz(attn_matrix: AttentionMatrix) -> Matrix2D:
#     return attn_matrix

``````{ end_of_file="pattern_lens/attn_figure_funcs.py" }
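
The registration decorators above follow a common registry pattern: each decorator records the figure names on the function object and appends the function to a module-level list. The sketch below re-implements that idea in a standalone form so it can be run without `pattern_lens` installed. The names `REGISTRY`, `register`, `register_multi`, `stats`, and `all_figure_names` are simplified stand-ins chosen for this example, not the library's API.

```python
import itertools
from collections.abc import Callable, Sequence

FIGURE_NAMES_KEY = "_figure_names"
REGISTRY: list[Callable] = []  # stand-in for ATTENTION_MATRIX_FIGURE_FUNCS


def register(func: Callable) -> Callable:
	"""register a function that saves a single figure named after itself"""
	setattr(func, FIGURE_NAMES_KEY, (func.__name__,))
	REGISTRY.append(func)
	return func


def register_multi(names: Sequence[str]) -> Callable[[Callable], Callable]:
	"""register a function that saves one figure per entry in `names`"""

	def decorator(func: Callable) -> Callable:
		setattr(
			func,
			FIGURE_NAMES_KEY,
			tuple(f"{func.__name__}.{name}" for name in names),
		)
		REGISTRY.append(func)
		return func

	return decorator


@register
def raw(attn_matrix, save_dir) -> None:
	"""placeholder body: a real figure function would write a file into `save_dir`"""


@register_multi(["mean", "max"])
def stats(attn_matrix, save_dir) -> None:
	"""placeholder body: a real multi-figure function would write two files"""


def all_figure_names() -> list[str]:
	"""flatten the figure names of every registered function"""
	return list(
		itertools.chain.from_iterable(
			getattr(func, FIGURE_NAMES_KEY, [func.__name__]) for func in REGISTRY
		),
	)


print(all_figure_names())  # ['raw', 'stats.mean', 'stats.max']
```

Storing the names as an attribute on the function keeps the registry a plain list while still letting a single function advertise several output figures.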

``````{ path="pattern_lens/consts.py"  }
"""implements some constants and types"""

import re
from typing import Literal

import numpy as np
import torch
from jaxtyping import Float

AttentionMatrix = Float[np.ndarray, "n_ctx n_ctx"]
"type alias for attention matrix"

ActivationCacheNp = dict[str, np.ndarray]
"type alias for a cache of activations, like a transformer_lens.ActivationCache"

ActivationCacheTorch = dict[str, torch.Tensor]
"type alias for a cache of activations, like a transformer_lens.ActivationCache but without the extras. useful for when loading from an npz file"

DATA_DIR: str = "attn_data"
"default directory for attention data"

ATTN_PATTERN_REGEX: re.Pattern = re.compile(r"blocks\.(\d+)\.attn\.hook_pattern")
"regex for finding attention pattern keys (e.g. `blocks.0.attn.hook_pattern`) in activation caches"

SPINNER_KWARGS: dict = dict(
	config=dict(success="✔️ "),
)
"default kwargs for `muutils.spinner.Spinner`"

DIVIDER_S1: str = "=" * 70
"divider string for separating sections"

DIVIDER_S2: str = "-" * 50
"divider string for separating subsections"

ReturnCache = Literal[None, "numpy", "torch"]
"return type for a cache of activations"

``````{ end_of_file="pattern_lens/consts.py" }
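
As an illustration of how `ATTN_PATTERN_REGEX` might be used, the sketch below filters a list of cache keys down to the attention-pattern entries and extracts their layer indices. The helper `attn_layer_indices` and the sample keys are assumptions made for this example. The regex is copied from the constants above so that the snippet runs on its own.

```python
import re

# copied from `pattern_lens/consts.py` so the example is self-contained
ATTN_PATTERN_REGEX: re.Pattern = re.compile(r"blocks\.(\d+)\.attn\.hook_pattern")


def attn_layer_indices(keys: list[str]) -> list[int]:
	"""return the sorted layer indices of all keys that name an attention pattern (hypothetical helper)"""
	layers: list[int] = []
	for key in keys:
		# `fullmatch` rejects keys that merely start with the pattern
		match = ATTN_PATTERN_REGEX.fullmatch(key)
		if match is not None:
			layers.append(int(match.group(1)))
	return sorted(layers)


# sample keys for illustration only
sample_keys: list[str] = [
	"blocks.1.attn.hook_pattern",
	"blocks.0.hook_resid_pre",
	"blocks.0.attn.hook_pattern",
]
print(attn_layer_indices(sample_keys))  # [0, 1]
```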

``````{ path="pattern_lens/figure_util.py"  }
"""implements a bunch of types, default values, and templates which are useful for figure functions

notably, you can use the decorators `matplotlib_figure_saver`, `save_matrix_wrapper` to make your functions save figures
"""

import base64
import functools
import gzip
import io
from collections.abc import Callable, Sequence
from pathlib import Path
from typing import Literal, overload

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from jaxtyping import Float, UInt8
from matplotlib.colors import Colormap
from PIL import Image

from pattern_lens.consts import AttentionMatrix

AttentionMatrixFigureFunc = Callable[[AttentionMatrix, Path], None]
"Type alias for a function that, given an attention matrix, saves one or more figures"

Matrix2D = Float[np.ndarray, "n m"]
"Type alias for a 2D matrix (plottable)"

Matrix2Drgb = UInt8[np.ndarray, "n m rgb=3"]
"Type alias for a 2D matrix with 3 channels (RGB)"

AttentionMatrixToMatrixFunc = Callable[[AttentionMatrix], Matrix2D]
"Type alias for a function that, given an attention matrix, returns a 2D matrix"

MATPLOTLIB_FIGURE_FMT: str = "svgz"
"format for saving matplotlib figures"

MatrixSaveFormat = Literal["png", "svg", "svgz"]
"Type alias for the format to save a matrix as when saving raw matrix, not matplotlib figure"

MATRIX_SAVE_NORMALIZE: bool = False
"default for whether to normalize the matrix to range [0, 1]"

MATRIX_SAVE_CMAP: str = "viridis"
"default colormap for saving matrices"

MATRIX_SAVE_FMT: MatrixSaveFormat = "svgz"
"default format for saving matrices"

MATRIX_SAVE_SVG_TEMPLATE: str = """<svg xmlns="http://www.w3.org/2000/svg" width="{m}" height="{n}" viewBox="0 0 {m} {n}" image-rendering="pixelated"> <image href="data:image/png;base64,{png_base64}" width="{m}" height="{n}" /> </svg>"""
"template for saving an `n` by `m` matrix as an svg/svgz"


# TYPING: mypy hates it when we don't pass func=None or None as the first arg
@overload  # without keyword arguments, returns decorated function
def matplotlib_figure_saver(
	func: Callable[[AttentionMatrix, plt.Axes], None],
) -> AttentionMatrixFigureFunc: ...
@overload  # with keyword arguments, returns decorator
def matplotlib_figure_saver(
	func: None = None,
	fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> Callable[
	[Callable[[AttentionMatrix, plt.Axes], None], str],
	AttentionMatrixFigureFunc,
]: ...
def matplotlib_figure_saver(
	func: Callable[[AttentionMatrix, plt.Axes], None] | None = None,
	fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> (
	AttentionMatrixFigureFunc
	| Callable[
		[Callable[[AttentionMatrix, plt.Axes], None], str],
		AttentionMatrixFigureFunc,
	]
):
	"""decorator for functions which take an attention matrix and predefined `ax` object, making it save a figure

	# Parameters:
	- `func : Callable[[AttentionMatrix, plt.Axes], None]`
		your function, which should take an attention matrix and predefined `ax` object
	- `fmt : str`
		format for saving matplotlib figures
		(defaults to `MATPLOTLIB_FIGURE_FMT`)

	# Returns:
	- `AttentionMatrixFigureFunc`
		your function, after we wrap it to save a figure

	# Usage:
	```python
	@register_attn_figure_func
	@matplotlib_figure_saver
	def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
		ax.matshow(attn_matrix, cmap="viridis")
		ax.set_title("Raw Attention Pattern")
		ax.axis("off")
	```

	"""

	def decorator(
		func: Callable[[AttentionMatrix, plt.Axes], None],
		fmt: str = fmt,
	) -> AttentionMatrixFigureFunc:
		@functools.wraps(func)
		def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
			fig_path: Path = save_dir / f"{func.__name__}.{fmt}"

			fig, ax = plt.subplots(figsize=(10, 10))
			func(attn_matrix, ax)
			plt.tight_layout()
			plt.savefig(fig_path)
			plt.close(fig)

		wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

		return wrapped

	if callable(func):
		# Handle no-arguments case
		return decorator(func)
	else:
		# Handle arguments case
		return decorator


def matplotlib_multifigure_saver(
	names: Sequence[str],
	fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> Callable[
	# decorator takes in function
	# which takes a matrix and a dictionary of axes corresponding to the names
	[Callable[[AttentionMatrix, dict[str, plt.Axes]], None]],
	# returns the decorated function
	AttentionMatrixFigureFunc,
]:
	"""decorate a function such that it saves multiple figures, one for each name in `names`

	# Parameters:
	- `names : Sequence[str]`
		the names of the figures to save
	- `fmt : str`
		format for saving matplotlib figures
		(defaults to `MATPLOTLIB_FIGURE_FMT`)

	# Returns:
	- `Callable[[Callable[[AttentionMatrix, dict[str, plt.Axes]], None]], AttentionMatrixFigureFunc]`
		the decorator, which will then be applied to the function
		we expect the decorated function to take an attention pattern, and a dict of axes corresponding to the names

	"""

	def decorator(
		func: Callable[[AttentionMatrix, dict[str, plt.Axes]], None],
	) -> AttentionMatrixFigureFunc:
		func_name: str = func.__name__

		@functools.wraps(func)
		def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
			# set up axes and corresponding figures
			axes_dict: dict[str, plt.Axes] = {}
			figs_dict: dict[str, plt.Figure] = {}

			# Create all figures and axes
			for name in names:
				fig, ax = plt.subplots(figsize=(10, 10))
				axes_dict[name] = ax
				figs_dict[name] = fig

			try:
				# Run the function to make plots
				func(attn_matrix, axes_dict)

				# Save each figure
				for name, fig_ in figs_dict.items():
					fig_path: Path = save_dir / f"{func_name}.{name}.{fmt}"
					# TYPING: error: Item "SubFigure" of "Figure | SubFigure" has no attribute "tight_layout"  [union-attr]
					fig_.tight_layout()  # type: ignore[union-attr]
					# TYPING: error: Item "SubFigure" of "Figure | SubFigure" has no attribute "savefig"  [union-attr]
					fig_.savefig(fig_path)  # type: ignore[union-attr]
			finally:
				# Always clean up figures, even if an error occurred
				for fig in figs_dict.values():
					# TYPING: error: Argument 1 to "close" has incompatible type "Figure | SubFigure"; expected "int | str | Figure | Literal['all'] | None"  [arg-type]
					plt.close(fig)  # type: ignore[arg-type]

		# it doesn't normally have this attribute, but we're adding it
		wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

		return wrapped

	return decorator


def matrix_to_image_preprocess(
	matrix: Matrix2D,
	normalize: bool = False,
	cmap: str | Colormap = "viridis",
	diverging_colormap: bool = False,
	normalize_min: float | None = None,
) -> Matrix2Drgb:
	"""preprocess a 2D matrix into a plottable heatmap image

	# Parameters:
	- `matrix : Matrix2D`
		input matrix
	- `normalize : bool`
		whether to normalize the matrix to range [0, 1]
		(defaults to `MATRIX_SAVE_NORMALIZE`)
	- `cmap : str|Colormap`
		the colormap to use for the matrix
		(defaults to `MATRIX_SAVE_CMAP`)
	- `diverging_colormap : bool`
		if True and using a diverging colormap, ensures 0 values map to the center of the colormap
		(defaults to False)
	- `normalize_min : float|None`
		if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?).
		if `None`, then the minimum value of the matrix is used.
		if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`.
		(defaults to `None`)

	# Returns:
	- `Matrix2Drgb`
	"""
	# check dims (2 is not that magic of a value here, hence noqa)
	assert matrix.ndim == 2, f"Matrix must be 2D, got {matrix.ndim = }"  # noqa: PLR2004

	# check matrix is not empty
	assert matrix.size > 0, "Matrix cannot be empty"

	if normalize_min is not None:
		assert not diverging_colormap, (
			"normalize_min cannot be used with diverging_colormap=True"
		)
		assert normalize, "normalize_min cannot be used with normalize=False"

	# Normalize the matrix to range [0, 1]
	normalized_matrix: Matrix2D
	if normalize:
		if diverging_colormap:
			# For diverging colormaps, we want to center around 0
			max_abs: float = max(abs(matrix.max()), abs(matrix.min()))
			normalized_matrix = (matrix / (2 * max_abs)) + 0.5
		else:
			max_val: float = matrix.max()
			min_val: float
			if normalize_min is not None:
				min_val = normalize_min
				assert min_val < max_val, "normalize_min must be less than matrix max"
				assert min_val <= matrix.min(), (
					"normalize_min must be less than or equal to matrix min"
				)
			else:
				min_val = matrix.min()

			normalized_matrix = (matrix - min_val) / (max_val - min_val)
	else:
		if diverging_colormap:
			assert matrix.min() >= -1 and matrix.max() <= 1, (  # noqa: PT018
				"For diverging colormaps without normalization, matrix values must be in range [-1, 1]"
			)
			normalized_matrix = matrix
		else:
			assert matrix.min() >= 0 and matrix.max() <= 1, (  # noqa: PT018
				"Matrix values must be in range [0, 1], or normalize must be True"
			)
			normalized_matrix = matrix

	# get the colormap
	cmap_: Colormap
	if isinstance(cmap, str):
		cmap_ = mpl.colormaps[cmap]
	elif isinstance(cmap, Colormap):
		cmap_ = cmap
	else:
		msg = f"Invalid type for {cmap = }, {type(cmap) = }, must be str or Colormap"
		raise TypeError(
			msg,
		)

	# Apply the colormap
	rgb_matrix: Matrix2Drgb = (
		cmap_(normalized_matrix)[:, :, :3] * 255
	).astype(np.uint8)  # Drop alpha channel

	assert rgb_matrix.shape == (
		matrix.shape[0],
		matrix.shape[1],
		3,
	), f"Matrix after colormap must have 3 channels, got {rgb_matrix.shape = }"

	return rgb_matrix


@overload
def matrix2drgb_to_png_bytes(matrix: Matrix2Drgb, buffer: None = None) -> bytes: ...
@overload
def matrix2drgb_to_png_bytes(matrix: Matrix2Drgb, buffer: io.BytesIO) -> None: ...
def matrix2drgb_to_png_bytes(
	matrix: Matrix2Drgb,
	buffer: io.BytesIO | None = None,
) -> bytes | None:
	"""Convert a `Matrix2Drgb` to valid PNG bytes via PIL

	- if `buffer` is provided, it will write the PNG bytes to the buffer and return `None`
	- if `buffer` is not provided, it will return the PNG bytes

	# Parameters:
	- `matrix : Matrix2Drgb`
	- `buffer : io.BytesIO | None`
		(defaults to `None`, in which case it will return the PNG bytes)

	# Returns:
	- `bytes|None`
		`bytes` if `buffer` is `None`, otherwise `None`
	"""
	pil_img: Image.Image = Image.fromarray(matrix, mode="RGB")
	if buffer is None:
		buffer = io.BytesIO()
		pil_img.save(buffer, format="PNG")
		buffer.seek(0)
		return buffer.read()
	else:
		pil_img.save(buffer, format="PNG")
		return None


def matrix_as_svg(
	matrix: Matrix2D,
	normalize: bool = MATRIX_SAVE_NORMALIZE,
	cmap: str | Colormap = MATRIX_SAVE_CMAP,
	diverging_colormap: bool = False,
	normalize_min: float | None = None,
) -> str:
	"""quickly convert a 2D matrix to an SVG image, without matplotlib

	# Parameters:
	- `matrix : Float[np.ndarray, 'n m']`
		a 2D matrix to convert to an SVG image
	- `normalize : bool`
		whether to normalize the matrix to range [0, 1]. if it's not in the range [0, 1], this must be `True` or it will raise an `AssertionError`
		(defaults to `False`)
	- `cmap : str`
		the colormap to use for the matrix -- will look up in `matplotlib.colormaps` if it's a string
		(defaults to `"viridis"`)
	- `diverging_colormap : bool`
		if True and using a diverging colormap, ensures 0 values map to the center of the colormap
		(defaults to False)
	- `normalize_min : float|None`
		if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?)
		if `None`, then the minimum value of the matrix is used
		if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`
		(defaults to `None`)


	# Returns:
	- `str`
		the SVG content for the matrix
	"""
	# Get the dimensions of the matrix
	assert matrix.ndim == 2, f"Matrix must be 2D, got {matrix.shape = }"  # noqa: PLR2004
	n, m = matrix.shape  # n rows (SVG height), m columns (SVG width)

	# Preprocess the matrix into an RGB image
	matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
		matrix,
		normalize=normalize,
		cmap=cmap,
		diverging_colormap=diverging_colormap,
		normalize_min=normalize_min,
	)

	# Convert the RGB image to PNG bytes
	image_data: bytes = matrix2drgb_to_png_bytes(matrix_rgb)

	# Encode the PNG bytes as base64
	png_base64: str = base64.b64encode(image_data).decode("utf-8")

	# Generate the SVG content
	svg_content: str = MATRIX_SAVE_SVG_TEMPLATE.format(m=m, n=n, png_base64=png_base64)

	return svg_content


@overload  # with keyword arguments, returns decorator
def save_matrix_wrapper(
	func: None = None,
	*args: tuple[()],
	fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
	normalize: bool = MATRIX_SAVE_NORMALIZE,
	cmap: str | Colormap = MATRIX_SAVE_CMAP,
	diverging_colormap: bool = False,
	normalize_min: float | None = None,
) -> Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]: ...
@overload  # without keyword arguments, returns decorated function
def save_matrix_wrapper(
	func: AttentionMatrixToMatrixFunc,
	*args: tuple[()],
	fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
	normalize: bool = MATRIX_SAVE_NORMALIZE,
	cmap: str | Colormap = MATRIX_SAVE_CMAP,
	diverging_colormap: bool = False,
	normalize_min: float | None = None,
) -> AttentionMatrixFigureFunc: ...
def save_matrix_wrapper(
	func: AttentionMatrixToMatrixFunc | None = None,
	*args,
	fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
	normalize: bool = MATRIX_SAVE_NORMALIZE,
	cmap: str | Colormap = MATRIX_SAVE_CMAP,
	diverging_colormap: bool = False,
	normalize_min: float | None = None,
) -> (
	AttentionMatrixFigureFunc
	| Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]
):
	"""Decorator for functions that process an attention matrix and save it as an image (PNG, SVG, or SVGZ, depending on `fmt`).

	Can handle both argumentless usage and with arguments.

	# Parameters:

	- `func : AttentionMatrixToMatrixFunc|None`
		Either the function to decorate (in the no-arguments case) or `None` when used with arguments.
	- `fmt : MatrixSaveFormat, keyword-only`
		The format to save the matrix as. Defaults to `MATRIX_SAVE_FMT`.
	- `normalize : bool, keyword-only`
		Whether to normalize the matrix to range [0, 1]. Defaults to `False`.
	- `cmap : str, keyword-only`
		The colormap to use for the matrix. Defaults to `MATRIX_SAVE_CMAP`.
	- `diverging_colormap : bool`
		if True and using a diverging colormap, ensures 0 values map to the center of the colormap
		(defaults to False)
	- `normalize_min : float|None`
		if a float, then for `normalize=True` and `diverging_colormap=False`, the minimum value to normalize to (generally set this to zero?)
		if `None`, then the minimum value of the matrix is used
		if `diverging_colormap=True` OR `normalize=False`, this **must** be `None`
		(defaults to `None`)

	# Returns:

	`AttentionMatrixFigureFunc|Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]`

	- `AttentionMatrixFigureFunc` if `func` is `AttentionMatrixToMatrixFunc` (no arguments case)
	- `Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]` if `func` is `None` -- returns the decorator which will then be applied to the function (with arguments case)

	# Usage:

	```python
	@save_matrix_wrapper
	def identity_matrix(matrix):
		return matrix

	@save_matrix_wrapper(normalize=True, fmt="png")
	def scale_matrix(matrix):
		return matrix * 2

	@save_matrix_wrapper(normalize=True, cmap="plasma")
	def scale_matrix_plasma(matrix):
		return matrix * 2
	```

	"""
	assert len(args) == 0, "This decorator only supports keyword arguments"

	assert (
		fmt in MatrixSaveFormat.__args__  # type: ignore[attr-defined]
	), f"Invalid format {fmt = }, must be one of {MatrixSaveFormat.__args__}"  # type: ignore[attr-defined]

	def decorator(
		func: Callable[[AttentionMatrix], Matrix2D],
	) -> AttentionMatrixFigureFunc:
		@functools.wraps(func)
		def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
			fig_path: Path = save_dir / f"{func.__name__}.{fmt}"
			processed_matrix: Matrix2D = func(attn_matrix)

			if fmt == "png":
				processed_matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
					processed_matrix,
					normalize=normalize,
					cmap=cmap,
					diverging_colormap=diverging_colormap,
					normalize_min=normalize_min,
				)
				image_data: bytes = matrix2drgb_to_png_bytes(processed_matrix_rgb)
				fig_path.write_bytes(image_data)

			else:
				svg_content: str = matrix_as_svg(
					processed_matrix,
					normalize=normalize,
					cmap=cmap,
					diverging_colormap=diverging_colormap,
					normalize_min=normalize_min,
				)

				if fmt == "svgz":
					with gzip.open(fig_path, "wt") as f:
						f.write(svg_content)

				else:
					fig_path.write_text(svg_content, encoding="utf-8")

		wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

		return wrapped

	if callable(func):
		# Handle no-arguments case
		return decorator(func)
	else:
		# Handle arguments case
		return decorator

``````{ end_of_file="pattern_lens/figure_util.py" }

``````{ path="pattern_lens/figures.py"  }
"""code for generating figures from attention patterns, using the functions decorated with `register_attn_figure_func`"""

import argparse
import fnmatch
import functools
import itertools
import json
import multiprocessing
import re
import warnings
from collections import defaultdict
from pathlib import Path

import numpy as np
from jaxtyping import Float

# custom utils
from muutils.json_serialize import json_serialize
from muutils.parallel import run_maybe_parallel
from muutils.spinner import SpinnerContext

# pattern_lens
from pattern_lens.attn_figure_funcs import ATTENTION_MATRIX_FIGURE_FUNCS
from pattern_lens.consts import (
	DATA_DIR,
	DIVIDER_S1,
	DIVIDER_S2,
	SPINNER_KWARGS,
	ActivationCacheNp,
	AttentionMatrix,
)
from pattern_lens.figure_util import AttentionMatrixFigureFunc
from pattern_lens.indexes import (
	generate_functions_jsonl,
	generate_models_jsonl,
	generate_prompts_jsonl,
)
from pattern_lens.load_activations import load_activations


class HTConfigMock:
	"""Mock of `transformer_lens.HookedTransformerConfig` for type hinting and loading config json

	can be initialized with any kwargs, and will update its `__dict__` with them. does, however, require the following attributes:
	- `n_layers: int`
	- `n_heads: int`
	- `model_name: str`

	we do this to avoid having to import `torch` and `transformer_lens`, since this would have to be done for each process in the parallelization and probably slows things down significantly
	"""

	def __init__(self, **kwargs: dict[str, str | int]) -> None:
		"will pass all kwargs to `__dict__`"
		self.n_layers: int
		self.n_heads: int
		self.model_name: str
		self.__dict__.update(kwargs)

	def serialize(self) -> dict:
		"""serialize the config to json. values which aren't serializable will be converted via `muutils.json_serialize.json_serialize`"""
		# it's fine, we know it's a dict
		return json_serialize(self.__dict__)  # type: ignore[return-value]

	@classmethod
	def load(cls, data: dict) -> "HTConfigMock":
		"try to load a config from a dict, using the `__init__` method"
		return cls(**data)


def process_single_head(
	layer_idx: int,
	head_idx: int,
	attn_pattern: AttentionMatrix,
	save_dir: Path,
	figure_funcs: list[AttentionMatrixFigureFunc],
	force_overwrite: bool = False,
) -> dict[str, bool | Exception]:
	"""process a single head's attention pattern, running all the functions in `figure_funcs` on the attention pattern

	> [gotcha:] if `force_overwrite` is `False`, and we used a multi-figure function,
	> it will skip all figures for that function if any are already saved
	> and it assumes a format of `{func_name}.{figure_name}.{fmt}` for the saved figures

	# Parameters:
	- `layer_idx : int`
	- `head_idx : int`
	- `attn_pattern : AttentionMatrix`
		attention pattern for the head
	- `save_dir : Path`
		directory to save the figures to
	- `figure_funcs : list[AttentionMatrixFigureFunc]`
		list of functions to run
	- `force_overwrite : bool`
		whether to overwrite existing figures. if `False`, will skip any functions which have already saved a figure
		(defaults to `False`)

	# Returns:
	- `dict[str, bool | Exception]`
		a dictionary of the status of each function, with the function name as the key and the status as the value
	"""
	funcs_status: dict[str, bool | Exception] = dict()

	for func in figure_funcs:
		func_name: str = func.__name__
		fig_path: list[Path] = list(save_dir.glob(f"{func_name}.*"))

		if not force_overwrite and len(fig_path) > 0:
			funcs_status[func_name] = True
			continue

		try:
			func(attn_pattern, save_dir)
			funcs_status[func_name] = True

		# blind catch any exception
		except Exception as e:  # noqa: BLE001
			error_file = save_dir / f"{func.__name__}.error.txt"
			error_file.write_text(str(e))
			warnings.warn(
				f"Error in {func.__name__} for L{layer_idx}H{head_idx}: {e!s}",
				stacklevel=2,
			)
			funcs_status[func_name] = e

	return funcs_status


def compute_and_save_figures(
	model_cfg: "HookedTransformerConfig|HTConfigMock",  # type: ignore[name-defined] # noqa: F821
	activations_path: Path,
	cache: ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"],
	figure_funcs: list[AttentionMatrixFigureFunc],
	save_path: Path = Path(DATA_DIR),
	force_overwrite: bool = False,
	track_results: bool = False,
) -> None:
	"""compute and save figures for all heads in the model, using the functions in `ATTENTION_MATRIX_FIGURE_FUNCS`

	# Parameters:
	- `model_cfg : HookedTransformerConfig|HTConfigMock`
		configuration of the model, used for loading the activations
	- `activations_path : Path`
		path to the saved activations for this prompt. figures are written under its parent directory
	- `cache : ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]`
		activation cache containing actual patterns for the prompt we are processing
	- `figure_funcs : list[AttentionMatrixFigureFunc]`
		list of functions to run
	- `save_path : Path`
		directory to save the figures to
		(defaults to `Path(DATA_DIR)`)
	- `force_overwrite : bool`
		force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure
		(defaults to `False`)
	- `track_results : bool`
		whether to track the results of each function for each head. Isn't used for anything yet, but this is a TODO
		(defaults to `False`)
	"""
	prompt_dir: Path = activations_path.parent

	if track_results:
		results: defaultdict[
			str,  # func name
			dict[
				tuple[int, int],  # layer, head
				bool | Exception,  # success or exception
			],
		] = defaultdict(dict)

	for layer_idx, head_idx in itertools.product(
		range(model_cfg.n_layers),
		range(model_cfg.n_heads),
	):
		attn_pattern: AttentionMatrix
		if isinstance(cache, dict):
			attn_pattern = cache[f"blocks.{layer_idx}.attn.hook_pattern"][0, head_idx]
		elif isinstance(cache, np.ndarray):
			attn_pattern = cache[layer_idx, head_idx]
		else:
			msg = (
				f"cache must be a dict or np.ndarray, not {type(cache) = }\n{cache = }"
			)
			raise TypeError(
				msg,
			)

		save_dir: Path = prompt_dir / f"L{layer_idx}" / f"H{head_idx}"
		save_dir.mkdir(parents=True, exist_ok=True)
		head_res: dict[str, bool | Exception] = process_single_head(
			layer_idx=layer_idx,
			head_idx=head_idx,
			attn_pattern=attn_pattern,
			save_dir=save_dir,
			force_overwrite=force_overwrite,
			figure_funcs=figure_funcs,
		)

		if track_results:
			for func_name, status in head_res.items():
				results[func_name][(layer_idx, head_idx)] = status

	# TODO: do something with results

	generate_prompts_jsonl(save_path / model_cfg.model_name)


def process_prompt(
	prompt: dict,
	model_cfg: "HookedTransformerConfig|HTConfigMock",  # type: ignore[name-defined] # noqa: F821
	save_path: Path,
	figure_funcs: list[AttentionMatrixFigureFunc],
	force_overwrite: bool = False,
) -> None:
	"""process a single prompt, loading the activations and computing and saving the figures

	basically just calls `load_activations` and then `compute_and_save_figures`

	# Parameters:
	- `prompt : dict`
		prompt to process, should be a dict with the following keys:
		- `"text"`: the prompt string
		- `"hash"`: the hash of the prompt
	- `model_cfg : HookedTransformerConfig|HTConfigMock`
		configuration of the model, used for figuring out where to save
	- `save_path : Path`
		directory to save the figures to
	- `figure_funcs : list[AttentionMatrixFigureFunc]`
		list of functions to run
	- `force_overwrite : bool`
		(defaults to `False`)
	"""
	# load the activations
	activations_path: Path
	cache: ActivationCacheNp | Float[np.ndarray, "n_layers n_heads n_ctx n_ctx"]
	activations_path, cache = load_activations(
		model_name=model_cfg.model_name,
		prompt=prompt,
		save_path=save_path,
		return_fmt="numpy",
	)

	# compute and save the figures
	compute_and_save_figures(
		model_cfg=model_cfg,
		activations_path=activations_path,
		cache=cache,
		figure_funcs=figure_funcs,
		save_path=save_path,
		force_overwrite=force_overwrite,
	)


def select_attn_figure_funcs(
	figure_funcs_select: set[str] | str | None = None,
) -> list[AttentionMatrixFigureFunc]:
	"""given a selector, figure out which functions from `ATTENTION_MATRIX_FIGURE_FUNCS` to use

	- if arg is `None`, will use all functions
	- if a string, will use the function names which match the string (glob/fnmatch syntax)
	- if a set, will use functions whose names are in the set

	"""
	# figure out which functions to use
	figure_funcs: list[AttentionMatrixFigureFunc]
	if figure_funcs_select is None:
		# all if nothing specified
		figure_funcs = ATTENTION_MATRIX_FIGURE_FUNCS
	elif isinstance(figure_funcs_select, str):
		# if a string, assume a glob pattern
		pattern: re.Pattern = re.compile(fnmatch.translate(figure_funcs_select))
		figure_funcs = [
			func
			for func in ATTENTION_MATRIX_FIGURE_FUNCS
			if pattern.match(func.__name__)
		]
	elif isinstance(figure_funcs_select, set):
		# if a set, assume a set of function names
		figure_funcs = [
			func
			for func in ATTENTION_MATRIX_FIGURE_FUNCS
			if func.__name__ in figure_funcs_select
		]
	else:
		err_msg: str = (
			f"figure_funcs_select must be None, str, or set, not {type(figure_funcs_select) = }"
			f"\n{figure_funcs_select = }"
		)
		raise TypeError(err_msg)
	return figure_funcs


def figures_main(
	model_name: str,
	save_path: str,
	n_samples: int | None,
	force: bool,
	figure_funcs_select: set[str] | str | None = None,
	parallel: bool | int = True,
) -> None:
	"""main function for generating figures from attention patterns, using the functions in `ATTENTION_MATRIX_FIGURE_FUNCS`

	# Parameters:
	- `model_name : str`
		model name to use, used for loading the model config, prompts, activations, and saving the figures
	- `save_path : str`
		base path to look in
	- `n_samples : int | None`
		max number of samples to process. if `None`, all prompts in `prompts.jsonl` are processed
	- `force : bool`
		force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure
	- `figure_funcs_select : set[str]|str|None`
		figure functions to use. if `None`, will use all functions. if a string, will use the function names which match the string. if a set, will use the function names in the set
		(defaults to `None`)
	- `parallel : bool | int`
		whether to run in parallel. if `True`, will use all available cores. if `False`, will run in serial. if an int, will try to use that many cores
		(defaults to `True`)
	"""
	with SpinnerContext(message="setting up paths", **SPINNER_KWARGS):
		# load the model config that was saved alongside the activations
		save_path_p: Path = Path(save_path)
		model_path: Path = save_path_p / model_name
		with open(model_path / "model_cfg.json", "r") as f:
			model_cfg = HTConfigMock.load(json.load(f))

	with SpinnerContext(message="loading prompts", **SPINNER_KWARGS):
		# load prompts
		with open(model_path / "prompts.jsonl", "r") as f:
			prompts: list[dict] = [json.loads(line) for line in f.readlines()]
		# truncate to n_samples
		prompts = prompts[:n_samples]

	print(f"{len(prompts)} prompts loaded")

	figure_funcs: list[AttentionMatrixFigureFunc] = select_attn_figure_funcs(
		figure_funcs_select=figure_funcs_select,
	)
	print(f"{len(figure_funcs)} figure functions loaded")
	print("\t" + ", ".join([func.__name__ for func in figure_funcs]))

	chunksize: int = int(
		max(
			1,
			len(prompts) // (5 * multiprocessing.cpu_count()),
		),
	)
	print(f"chunksize: {chunksize}")

	list(
		run_maybe_parallel(
			func=functools.partial(
				process_prompt,
				model_cfg=model_cfg,
				save_path=save_path_p,
				figure_funcs=figure_funcs,
				force_overwrite=force,
			),
			iterable=prompts,
			parallel=parallel,
			chunksize=chunksize,
			pbar="tqdm",
			pbar_kwargs=dict(
				desc="Making figures",
				unit="prompt",
			),
		),
	)

	with SpinnerContext(
		message="updating jsonl metadata for models and functions",
		**SPINNER_KWARGS,
	):
		generate_models_jsonl(save_path_p)
		generate_functions_jsonl(save_path_p)


def _parse_args() -> tuple[
	argparse.Namespace,
	list[str],  # models
	set[str] | str | None,  # figure_funcs_select
]:
	arg_parser: argparse.ArgumentParser = argparse.ArgumentParser()
	# input and output
	arg_parser.add_argument(
		"--model",
		"-m",
		type=str,
		required=True,
		help="The model name(s) to use. comma separated with no whitespace if multiple",
	)
	arg_parser.add_argument(
		"--save-path",
		"-s",
		type=str,
		required=False,
		help="The path to save the attention patterns",
		default=DATA_DIR,
	)
	# number of samples
	arg_parser.add_argument(
		"--n-samples",
		"-n",
		type=int,
		required=False,
		help="The max number of samples to process, do all in the file if None",
		default=None,
	)
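	# force overwrite of existing figures
	# NOTE: `type=bool` would treat any non-empty string (including "False") as `True`,
	# so the value is parsed explicitly
	arg_parser.add_argument(
		"--force",
		"-f",
		type=lambda s: s.strip().lower() in ("1", "true", "yes", "y"),
		required=False,
		help="Force overwrite of existing figures (pass `true` or `false`)",
		default=False,
	)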
	# figure functions
	arg_parser.add_argument(
		"--figure-funcs",
		type=str,
		required=False,
		help="The figure functions to use. if 'None' (default), will use all functions. if a string, will use the function names which match the string. if a comma-separated list of strings, will use the function names in the set",
		default=None,
	)

	args: argparse.Namespace = arg_parser.parse_args()

	# figure out models
	models: list[str]
	if "," in args.model:
		models = args.model.split(",")
	else:
		models = [args.model]

	# figure out figures
	figure_funcs_select: set[str] | str | None
	if (args.figure_funcs is None) or (args.figure_funcs.lower().strip() == "none"):
		figure_funcs_select = None
	elif "," in args.figure_funcs:
		figure_funcs_select = {x.strip() for x in args.figure_funcs.split(",")}
	else:
		figure_funcs_select = args.figure_funcs.strip()

	return args, models, figure_funcs_select


def main() -> None:
	"generates figures from the activations using the functions decorated with `register_attn_figure_func`"
	# parse args
	print(DIVIDER_S1)
	args: argparse.Namespace
	models: list[str]
	figure_funcs_select: set[str] | str | None
	with SpinnerContext(message="parsing args", **SPINNER_KWARGS):
		args, models, figure_funcs_select = _parse_args()
	print(f"\targs parsed: '{args}'")
	print(f"\tmodels: '{models}'")
	print(f"\tfigure_funcs_select: '{figure_funcs_select}'")

	# compute for each model
	n_models: int = len(models)
	for idx, model in enumerate(models):
		print(DIVIDER_S2)
		print(f"processing model {idx + 1} / {n_models}: {model}")
		print(DIVIDER_S2)
		figures_main(
			model_name=model,
			save_path=args.save_path,
			n_samples=args.n_samples,
			force=args.force,
			figure_funcs_select=figure_funcs_select,
		)

	print(DIVIDER_S1)


if __name__ == "__main__":
	main()

``````{ end_of_file="pattern_lens/figures.py" }

``````{ path="pattern_lens/indexes.py"  }
"""writes indexes to the model directory for the frontend to use or for record keeping"""

import importlib.metadata
import importlib.resources
import inspect
import itertools
import json
from collections.abc import Callable
from pathlib import Path
from typing import Literal

import pattern_lens
from pattern_lens.attn_figure_funcs import (
	_FIGURE_NAMES_KEY,
	ATTENTION_MATRIX_FIGURE_FUNCS,
)


def generate_prompts_jsonl(model_dir: Path) -> None:
	"""creates a `prompts.jsonl` file with all the prompts in the model directory

	looks in all directories in `{model_dir}/prompts` for a `prompt.json` file
	"""
	prompts: list[dict] = list()
	for prompt_dir in (model_dir / "prompts").iterdir():
		prompt_file: Path = prompt_dir / "prompt.json"
		if prompt_file.exists():
			with open(prompt_file, "r") as f:
				prompt_data: dict = json.load(f)
				prompts.append(prompt_data)

	with open(model_dir / "prompts.jsonl", "w") as f:
		for prompt in prompts:
			f.write(json.dumps(prompt))
			f.write("\n")


def generate_models_jsonl(path: Path) -> None:
	"""creates a `models.jsonl` file with all the models"""
	models: list[dict] = list()
	for model_dir in (path).iterdir():
		model_cfg_path: Path = model_dir / "model_cfg.json"
		if model_cfg_path.exists():
			with open(model_cfg_path, "r") as f:
				model_cfg: dict = json.load(f)
				models.append(model_cfg)

	with open(path / "models.jsonl", "w") as f:
		for model in models:
			f.write(json.dumps(model))
			f.write("\n")


def get_func_metadata(func: Callable) -> list[dict[str, str | None]]:
	"""get metadata for a function

	# Parameters:
	- `func : Callable` which has a `_FIGURE_NAMES_KEY` (by default `_figure_names`) attribute

	# Returns:

	`list[dict[str, str | None]]`
	each dictionary describes one figure produced by the function, containing:

	- `name : str` : the name of the figure
	- `func_name : str`
		the name of the function. if not a multi-figure function, this is identical to `name`
		if it is a multi-figure function, then `name` is `{func_name}.{figure_name}`
	- `doc : str` : the docstring of the function
	- `figure_save_fmt : str | None` : the format of the figure that the function saves, using the `figure_save_fmt` attribute of the function. `None` if the attribute does not exist
	- `source : str | None` : the source file of the function
	- `code : str | None` : the source code of the function as a single string. `None` if the source cannot be read

	"""
	source_file: str | None = inspect.getsourcefile(func)
	output: dict[str, str | None] = dict(
		func_name=func.__name__,
		doc=func.__doc__,
		figure_save_fmt=getattr(func, "figure_save_fmt", None),
		source=Path(source_file).as_posix() if source_file else None,
	)

	try:
		output["code"] = inspect.getsource(func)
	except OSError:
		output["code"] = None

	fig_names: list[str] | None = getattr(func, _FIGURE_NAMES_KEY, None)
	if fig_names:
		return [
			{
				"name": fig_name,
				**output,
			}
			for fig_name in fig_names
		]
	else:
		return [
			{
				"name": func.__name__,
				**output,
			},
		]


def generate_functions_jsonl(path: Path) -> None:
	"unions the metadata already in `figures.jsonl` with metadata for `ATTENTION_MATRIX_FIGURE_FUNCS` and rewrites the file. on a name conflict, the registered function wins"
	figures_file: Path = path / "figures.jsonl"
	existing_figures: dict[str, dict] = dict()

	if figures_file.exists():
		with open(figures_file, "r") as f:
			for line in f:
				func_data: dict = json.loads(line)
				existing_figures[func_data["name"]] = func_data

	# Add any new functions from ATTENTION_MATRIX_FIGURE_FUNCS
	new_functions_lst: list[dict] = list(
		itertools.chain.from_iterable(
			get_func_metadata(func) for func in ATTENTION_MATRIX_FIGURE_FUNCS
		),
	)
	new_functions: dict[str, dict] = {func["name"]: func for func in new_functions_lst}

	all_functions: list[dict] = list(
		{
			**existing_figures,
			**new_functions,
		}.values(),
	)

	with open(figures_file, "w") as f:
		for func_meta in sorted(all_functions, key=lambda x: x["name"]):
			json.dump(func_meta, f)
			f.write("\n")


def inline_assets(
	html: str,
	assets: list[tuple[Literal["script", "style"], str]],
	base_path: Path,
) -> str:
	"""Inline specified local CSS/JS files into an HTML document.

	Each entry in `assets` should be a tuple like `("script", "app.js")` or `("style", "style.css")`.

	# Parameters:
	- `html : str`
		input HTML content.
	- `assets : list[tuple[Literal["script", "style"], str]]`
		List of (tag_type, filename) tuples to inline.
	- `base_path : Path`
		directory containing the asset files. each `filename` is read relative to it.

	# Returns:
	`str` : Modified HTML content with inlined assets.
	"""
	for tag_type, filename in assets:
		if tag_type not in ("style", "script"):
			err_msg: str = f"Unsupported tag type: {tag_type}"
			raise ValueError(err_msg)

		# Dynamically create the pattern for the given tag and filename
		pattern: str = rf'<{tag_type} src="{filename}"></{tag_type}>'
		# assert it's in the text exactly once
		assert html.count(pattern) == 1, (
			f"Pattern {pattern} should be in the html exactly once, found {html.count(pattern) = }"
		)
		# read the content and create the replacement
		content: str = (base_path / filename).read_text()
		replacement: str = f"<{tag_type}>\n{content}\n</{tag_type}>"
		# perform the replacement
		html = html.replace(pattern, replacement)

	return html


def write_html_index(path: Path) -> None:
	"""writes an index.html file to the path"""
	# TYPING: error: Argument 1 to "Path" has incompatible type "Traversable"; expected "str | PathLike[str]"  [arg-type]
	frontend_resources_path: Path = Path(
		importlib.resources.files(pattern_lens).joinpath("frontend"),  # type: ignore[arg-type]
	)
	html_index: str = (frontend_resources_path / "index.template.html").read_text(
		encoding="utf-8",
	)
	# inline assets
	html_index = inline_assets(
		html_index,
		[
			("style", "style.css"),
			("script", "util.js"),
			("script", "app.js"),
		],
		base_path=frontend_resources_path,
	)

	# add version
	pattern_lens_version: str = importlib.metadata.version("pattern-lens")
	html_index = html_index.replace("$$PATTERN_LENS_VERSION$$", pattern_lens_version)
	# write the index.html file
	with open(path / "index.html", "w", encoding="utf-8") as f:
		f.write(html_index)

``````{ end_of_file="pattern_lens/indexes.py" }
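
The merge performed by `generate_functions_jsonl` is a dictionary union keyed on `"name"`, where freshly computed metadata overrides what was on disk. A minimal sketch of that behaviour, with hypothetical records and a hypothetical helper `merge_by_name` (neither is part of `pattern_lens`):

```python
import json


def merge_by_name(existing: list[dict], new: list[dict]) -> list[dict]:
	"""union two lists of records keyed on "name"; entries in `new` win on conflict"""
	merged: dict[str, dict] = {
		**{rec["name"]: rec for rec in existing},
		**{rec["name"]: rec for rec in new},
	}
	# same ordering as the written file: sorted by name
	return sorted(merged.values(), key=lambda rec: rec["name"])


# hypothetical records, for illustration only
on_disk: list[dict] = [
	{"name": "raw", "doc": "old docstring"},
	{"name": "removed_func", "doc": "no longer registered"},
]
registered: list[dict] = [
	{"name": "raw", "doc": "new docstring"},
	{"name": "svd_spectra", "doc": None},
]

for rec in merge_by_name(on_disk, registered):
	print(json.dumps(rec))
```

One consequence of the union is that entries for functions which are no longer registered stay in `figures.jsonl` until the file is deleted or edited by hand.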

``````{ path="pattern_lens/load_activations.py"  }
"loading activations from .npz on disk. implements some custom Exception classes"

import base64
import hashlib
import json
from pathlib import Path
from typing import Literal, overload

import numpy as np

from pattern_lens.consts import ReturnCache


class GetActivationsError(ValueError):
	"""base class for errors in getting activations"""

	pass


class ActivationsMissingError(GetActivationsError, FileNotFoundError):
	"""error for missing activations -- can't find the activations file"""

	pass


class ActivationsMismatchError(GetActivationsError):
	"""error for mismatched activations -- the prompt text or hash do not match

	raised by `compare_prompt_to_loaded`
	"""

	pass


class InvalidPromptError(GetActivationsError):
	"""error for invalid prompt -- the prompt does not have fields "hash" or "text"

	raised by `augment_prompt_with_hash`
	"""

	pass


def compare_prompt_to_loaded(prompt: dict, prompt_loaded: dict) -> None:
	"""compare a prompt to a loaded prompt, raise an error if they do not match

	# Parameters:
	- `prompt : dict`
	- `prompt_loaded : dict`

	# Returns:
	- `None`

	# Raises:
	- `ActivationsMismatchError` : if the prompt text or hash do not match
	"""
	for key in ("text", "hash"):
		if prompt[key] != prompt_loaded[key]:
			msg = f"Prompt file does not match prompt at key {key}:\n{prompt}\n{prompt_loaded}"
			raise ActivationsMismatchError(
				msg,
			)


def augment_prompt_with_hash(prompt: dict) -> dict:
	"""if a prompt does not have a hash, add one

	not having a "text" field is allowed, but only if "hash" is present

	# Parameters:
	- `prompt : dict`

	# Returns:
	- `dict`

	# Modifies:
	the input `prompt` dictionary, if it does not have a `"hash"` key
	"""
	if "hash" not in prompt:
		if "text" not in prompt:
			msg = f"Prompt does not have 'text' field or 'hash' field: {prompt}"
			raise InvalidPromptError(
				msg,
			)
		prompt_str: str = prompt["text"]
		prompt_hash: str = (
			# we don't need this to be a secure hash
			base64.urlsafe_b64encode(hashlib.md5(prompt_str.encode()).digest())  # noqa: S324
			.decode()
			.rstrip("=")
		)
		prompt.update(hash=prompt_hash)
	return prompt


@overload
def load_activations(
	model_name: str,
	prompt: dict,
	save_path: Path,
	return_fmt: Literal["torch"] = "torch",
) -> "tuple[Path, dict[str, torch.Tensor]]":  # type: ignore[name-defined] # noqa: F821
	...
@overload
def load_activations(
	model_name: str,
	prompt: dict,
	save_path: Path,
	return_fmt: Literal["numpy"],
) -> "tuple[Path, dict[str, np.ndarray]]": ...
def load_activations(
	model_name: str,
	prompt: dict,
	save_path: Path,
	return_fmt: ReturnCache = "torch",
) -> "tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]":  # type: ignore[name-defined] # noqa: F821
	"""load activations for a prompt and model, from an npz file

	# Parameters:
	- `model_name : str`
	- `prompt : dict`
	- `save_path : Path`
	- `return_fmt : Literal["torch", "numpy"]`
		(defaults to `"torch"`)

	# Returns:
	- `tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]`
		the path to the activations file and the activations as a dictionary
		of numpy arrays or torch tensors, depending on `return_fmt`

	# Raises:
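	- `ActivationsMissingError` : if the `prompt.json` file for this prompt is missing
	- `ActivationsMismatchError` : if the stored prompt does not match `prompt` (text or hash)
	- `InvalidPromptError` : if `prompt` has neither a `"hash"` nor a `"text"` field
	- `ValueError` : if `return_fmt` is not `"torch"` or `"numpy"`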
	"""
	if return_fmt not in ("torch", "numpy"):
		msg = f"Invalid return_fmt: {return_fmt}, expected 'torch' or 'numpy'"
		raise ValueError(
			msg,
		)
	if return_fmt == "torch":
		import torch

	augment_prompt_with_hash(prompt)

	prompt_dir: Path = save_path / model_name / "prompts" / prompt["hash"]
	prompt_file: Path = prompt_dir / "prompt.json"
	if not prompt_file.exists():
		msg = f"Prompt file {prompt_file} does not exist"
		raise ActivationsMissingError(msg)
	with open(prompt_dir / "prompt.json", "r") as f:
		prompt_loaded: dict = json.load(f)
		compare_prompt_to_loaded(prompt, prompt_loaded)

	activations_path: Path = prompt_dir / "activations.npz"

	cache: dict

	with np.load(activations_path) as npz_data:
		if return_fmt == "numpy":
			cache = dict(npz_data.items())
		elif return_fmt == "torch":
			cache = {k: torch.from_numpy(v) for k, v in npz_data.items()}

	return activations_path, cache


# def load_activations_stacked()

``````{ end_of_file="pattern_lens/load_activations.py" }
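
The prompt hash and the resulting on-disk location can be reproduced with the standard library alone. The sketch below follows the same steps as `augment_prompt_with_hash` and `load_activations`; the helper names `prompt_hash` and `prompt_dir` are hypothetical, as are the example path and model name.

```python
import base64
import hashlib
from pathlib import Path


def prompt_hash(text: str) -> str:
	"""md5 digest, url-safe base64 encoded, with the `=` padding stripped"""
	# md5 is used as a short identifier here, not for security
	digest: bytes = hashlib.md5(text.encode()).digest()
	return base64.urlsafe_b64encode(digest).decode().rstrip("=")


def prompt_dir(save_path: Path, model_name: str, text: str) -> Path:
	"""directory where the prompt and its activations are expected to live"""
	return save_path / model_name / "prompts" / prompt_hash(text)


example_dir: Path = prompt_dir(Path("attn_data"), "gpt2", "hello world")
print(example_dir / "prompt.json")
print(example_dir / "activations.npz")
```

An md5 digest is 16 bytes, so the stripped base64 string is always 22 characters long and contains only characters that are safe in file names and URLs.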

``````{ path="pattern_lens/prompts.py"  }
"implements `load_text_data` for loading prompts"

import json
import random
from pathlib import Path


def load_text_data(
	fname: Path,
	min_chars: int | None = None,
	max_chars: int | None = None,
	shuffle: bool = False,
) -> list[dict]:
	"""given `fname`, the path to a jsonl file, split prompts up into more reasonable sizes

	# Parameters:
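	- `fname : Path`
		jsonl file with prompts. Expects one JSON object per line, each with a "text" key
	- `min_chars : int | None`
		drop samples shorter than this many characters, both before and after splitting. no minimum if `None`
		(defaults to `None`)
	- `max_chars : int | None`
		split samples longer than this into consecutive chunks of at most `max_chars` characters. no splitting if `None`
		(defaults to `None`)
	- `shuffle : bool`
		shuffle the processed prompts with `random.shuffle`
		(defaults to `False`)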

	# Returns:
	- `list[dict]`
		processed list of prompts. Each prompt has a "text" key w/ a string value and some metadata.
		this is not guaranteed to be the same length as the input list!
	"""
	# read raw data
	with open(fname, "r") as f:
		data_raw: list[dict] = [json.loads(d) for d in f.readlines()]

	# add fname metadata
	for d in data_raw:
		d["source_fname"] = fname.as_posix()

	# trim too-short samples
	if min_chars is not None:
		data_raw = list(
			filter(
				lambda x: len(x["text"]) >= min_chars,
				data_raw,
			),
		)

	# split up too-long samples
	if max_chars is not None:
		data_new: list[dict] = []
		for d in data_raw:
			d_text: str = d["text"]
			while len(d_text) > max_chars:
				data_new.append(
					{
						**d,
						"text": d_text[:max_chars],
					},
				)
				d_text = d_text[max_chars:]
			data_new.append(
				{
					**d,
					"text": d_text,
				},
			)
		data_raw = data_new

	# trim too-short samples again
	if min_chars is not None:
		data_raw = list(
			filter(
				lambda x: len(x["text"]) >= min_chars,
				data_raw,
			),
		)

	# shuffle
	if shuffle:
		random.shuffle(data_raw)

	return data_raw

``````{ end_of_file="pattern_lens/prompts.py" }

``````{ path="pattern_lens/py.typed"  }

``````{ end_of_file="pattern_lens/py.typed" }

``````{ path="pattern_lens/server.py"  }
"""cli for starting the server to show the web ui.

can also run with --rewrite-index to update the index.html file.
this is useful for working on the ui.
"""

import argparse
import http.server
import os
import socketserver
import sys
from pathlib import Path

from pattern_lens.indexes import write_html_index


def main(path: str | None = None, port: int = 8000) -> None:
	"move to the given path and start the server"
	if path is not None:
		os.chdir(path)
	try:
		with socketserver.TCPServer(
			("", port),
			http.server.SimpleHTTPRequestHandler,
		) as httpd:
			print(f"Serving at http://localhost:{port}")
			httpd.serve_forever()
	except KeyboardInterrupt:
		print("Server stopped")
		sys.exit(0)


if __name__ == "__main__":
	arg_parser: argparse.ArgumentParser = argparse.ArgumentParser()
	arg_parser.add_argument(
		"--path",
		type=str,
		required=False,
		help="The path to serve, defaults to the current directory",
		default=None,
	)
	arg_parser.add_argument(
		"--port",
		type=int,
		required=False,
		help="The port to serve on, defaults to 8000",
		default=8000,
	)
	arg_parser.add_argument(
		"--rewrite-index",
		action="store_true",
		help="Whether to write the latest index.html file",
	)
	args: argparse.Namespace = arg_parser.parse_args()

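	if args.rewrite_index:
		# `--path` is optional; fall back to the current directory, which is what gets served
		write_html_index(path=Path(args.path) if args.path is not None else Path.cwd())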

	main(path=args.path, port=args.port)

``````{ end_of_file="pattern_lens/server.py" }
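
`server.main` changes the working directory and then blocks in `serve_forever`. For use inside a notebook or a test, an alternative is to pass the directory to the request handler and run the server in a background thread. The sketch below is one possible variant using only the standard library; `start_server` is a hypothetical helper and is not how `pattern_lens.server` is implemented.

```python
import functools
import http.server
import threading
from pathlib import Path


def start_server(
	directory: Path,
	port: int = 0,
) -> http.server.ThreadingHTTPServer:
	"""serve `directory` in a background thread; port 0 lets the OS pick a free port"""
	# `directory=` makes the handler serve that folder without calling `os.chdir`
	handler = functools.partial(
		http.server.SimpleHTTPRequestHandler,
		directory=str(directory),
	)
	httpd = http.server.ThreadingHTTPServer(("127.0.0.1", port), handler)
	threading.Thread(target=httpd.serve_forever, daemon=True).start()
	return httpd


if __name__ == "__main__":
	server = start_server(Path.cwd())
	print(f"Serving at http://localhost:{server.server_address[1]}")
	# a real script would keep the process alive here; this demo stops right away
	server.shutdown()
	server.server_close()
```

The caller is responsible for calling `shutdown()` and `server_close()` when done, since the daemon thread would otherwise only stop when the process exits.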

``````{ path="README.md"  }
[![PyPI](https://img.shields.io/pypi/v/pattern-lens)](https://pypi.org/project/pattern-lens/)
![PyPI - Downloads](https://img.shields.io/pypi/dm/pattern-lens)
[![docs](https://img.shields.io/badge/docs-latest-blue)](https://miv.name/pattern-lens)
[![Checks](https://github.com/mivanit/pattern-lens/actions/workflows/checks.yml/badge.svg)](https://github.com/mivanit/pattern-lens/actions/workflows/checks.yml)

[![Coverage](docs/coverage/coverage.svg)](docs/coverage/html/)
![GitHub commits](https://img.shields.io/github/commit-activity/t/mivanit/pattern-lens)
![GitHub commit activity](https://img.shields.io/github/commit-activity/m/mivanit/pattern-lens)
![GitHub closed pull requests](https://img.shields.io/github/issues-pr-closed/mivanit/pattern-lens)
![code size, bytes](https://img.shields.io/github/languages/code-size/mivanit/pattern-lens)

# pattern-lens

visualization of LLM attention patterns and things computed about them

`pattern-lens` makes it easy to:

- Generate visualizations of attention patterns, or figures computed from attention patterns, from models supported by [TransformerLens](https://github.com/TransformerLensOrg/TransformerLens)
- Compare generated figures across models, layers, and heads in an [interactive web interface](https://miv.name/pattern-lens/demo/)

# Installation

```bash
pip install pattern-lens
```


# Usage

The pipeline is as follows:

- Generate attention patterns using `pattern_lens.activations.activations_main()`, saving them in `npz` files
- Generate visualizations using `pattern_lens.figures.figures_main()` -- read the `npz` files, pass each attention pattern to each visualization function, and save the resulting figures
- Serve the web interface using `pattern_lens.server` -- web interface reads metadata in json/jsonl files, then lets the user select figures to show
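
To check what the first step left on disk before generating figures, something like the following can help. It is a rough sketch: the layout (`{save_path}/{model_name}/prompts/{hash}/` containing `prompt.json` and `activations.npz`) is taken from how `load_activations` reads the files, while `list_prompt_dirs` is a hypothetical helper that is not part of `pattern-lens`.

```python
import json
from pathlib import Path


def list_prompt_dirs(save_path: Path, model_name: str) -> list[Path]:
	"""return prompt directories that contain both `prompt.json` and `activations.npz`"""
	prompts_root: Path = save_path / model_name / "prompts"
	if not prompts_root.is_dir():
		return []
	return sorted(
		d
		for d in prompts_root.iterdir()
		if (d / "prompt.json").exists() and (d / "activations.npz").exists()
	)


# prints nothing if `attn_data/gpt2` has not been generated yet
for d in list_prompt_dirs(Path("attn_data"), "gpt2"):
	prompt: dict = json.loads((d / "prompt.json").read_text())
	print(d.name, repr(prompt.get("text", ""))[:60])
```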


## Basic CLI

Generate attention patterns and default visualizations:

```bash
# generate activations
python -m pattern_lens.activations --model gpt2 --prompts data/pile_1k.jsonl --save-path attn_data
# create visualizations
python -m pattern_lens.figures --model gpt2 --save-path attn_data
```

Serve the web UI:

```bash
python -m pattern_lens.server --path attn_data
```


## Web UI

View a demo of the web UI at [miv.name/pattern-lens/demo](https://miv.name/pattern-lens/demo/).

## Custom Figures

Add custom visualization functions by decorating them with `@register_attn_figure_func`. You should still generate the activations first:
```bash
python -m pattern_lens.activations --model gpt2 --prompts data/pile_1k.jsonl --save-path attn_data
```

and then write+run a script/notebook that looks something like this:

```python
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import svd

# these functions simplify writing a function which saves a figure
from pattern_lens.figure_util import matplotlib_figure_saver, save_matrix_wrapper
# decorator to register your function, such that it will be run by `figures_main`
from pattern_lens.attn_figure_funcs import register_attn_figure_func
# runs the actual figure generation pipeline
from pattern_lens.figures import figures_main

# define your own functions
# this one uses `matplotlib_figure_saver` -- define a function that takes matrix and `plt.Axes`, modify the axes
@register_attn_figure_func
@matplotlib_figure_saver(fmt="svgz")
def svd_spectra(attn_matrix: np.ndarray, ax: plt.Axes) -> None:
    # Perform SVD
    U, s, Vh = svd(attn_matrix)

    # Plot singular values
    ax.plot(s, "o-")
    ax.set_yscale("log")
    ax.set_xlabel("Singular Value Index")
    ax.set_ylabel("Singular Value")
    ax.set_title("Singular Value Spectrum of Attention Matrix")


# run the pipeline
figures_main(
    model_name="pythia-14m",
    save_path=Path("docs/demo/"),
    n_samples=5,
    force=False,
)
```

See `demo.ipynb` for a full example.
``````{ end_of_file="README.md" }

``````{ path="demo.ipynb" processed_with="ipynb_to_md" }
```python
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from scipy.linalg import svd

from pattern_lens.attn_figure_funcs import register_attn_figure_func
from pattern_lens.figure_util import matplotlib_figure_saver, save_matrix_wrapper
from pattern_lens.figures import figures_main
```

```python
# define and register your own functions
# don't take these too seriously, they're just examples


# using matplotlib_figure_saver -- define a function that takes matrix and `plt.Axes`, modify the axes
@register_attn_figure_func
@matplotlib_figure_saver(fmt="svgz")
def svd_spectra(attn_matrix: np.ndarray, ax: plt.Axes) -> None:
	# Perform SVD
	U, s, Vh = svd(attn_matrix)

	# Plot singular values
	ax.plot(s, "o-")
	ax.set_yscale("log")
	ax.set_xlabel("Singular Value Index")
	ax.set_ylabel("Singular Value")
	ax.set_title("Singular Value Spectrum of Attention Matrix")


# manually creating and saving a figure
@register_attn_figure_func
def attention_flow(attn_matrix: np.ndarray, path: Path) -> None:
	"""Visualize attention as flows between tokens.

	Creates a simplified Sankey-style diagram where line thickness and color
	intensity represent attention strength.
	"""
	fig, ax = plt.subplots(figsize=(6, 6))
	n_tokens: int = attn_matrix.shape[0]

	# Create positions for tokens on left and right
	left_pos: np.ndarray = np.arange(n_tokens)
	right_pos: np.ndarray = np.arange(n_tokens)

	# Plot flows
	for i in range(n_tokens):
		for j in range(n_tokens):
			weight = attn_matrix[i, j]
			if weight > 0.05:  # Only plot stronger connections
				ax.plot(
					[0, 1],
					[left_pos[i], right_pos[j]],
					alpha=weight,
					linewidth=weight * 5,
					color="blue",
				)

	ax.set_xlim(-0.1, 1.1)
	ax.set_ylim(-1, n_tokens)
	ax.axis("off")
	ax.set_title("Attention Flow Between Positions")

	# be sure to save the figure as `function_name.format` in the given location
	fig.savefig(path / "attention_flow.svgz", format="svgz")


@register_attn_figure_func
@save_matrix_wrapper(fmt="svgz")
def gram_matrix(attn_matrix: np.ndarray) -> np.ndarray:
	return attn_matrix @ attn_matrix.T
```

```python
# run the pipeline
figures_main(
	model_name="pythia-14m",
	save_path=Path("docs/demo/"),
	n_samples=5,
	force=False,
)
```


``````{ end_of_file="demo.ipynb" }

``````{ path="makefile" processed_with="makefile_recipes" }
# first/default target is help
.PHONY: default
default: help
	...

# this recipe is weird. we need it because:
# - a one liner for getting the version with toml is unwieldy, and using regex is fragile
# - using $$SCRIPT_GET_VERSION within $(shell ...) doesn't work because of escaping issues
# - trying to write to the file inside the `gen-version-info` recipe doesn't work, 
# 	shell eval happens before our `python -c ...` gets run and `cat` doesn't see the new file
.PHONY: write-proj-version
write-proj-version:
	...

# gets version info from $(PYPROJECT), last version from $(LAST_VERSION_FILE), and python version
# uses just `python` for everything except getting the python version. no echo here, because this is "private"
.PHONY: gen-version-info
gen-version-info: write-proj-version
	...

# getting commit log since the tag specified in $(LAST_VERSION_FILE)
# will write to $(COMMIT_LOG_FILE)
# when publishing, the contents of $(COMMIT_LOG_FILE) will be used as the tag description (but can be edited during the process)
# no echo here, because this is "private"
.PHONY: gen-commit-log
gen-commit-log: gen-version-info
	...

# force the version info to be read, printing it out
# also force the commit log to be generated, and cat it out
.PHONY: version
version: gen-commit-log
	@echo "Current version is $(PROJ_VERSION), last auto-uploaded version is $(LAST_VERSION)"
	...

.PHONY: setup
setup: dep-check
	@echo "install and update via uv"
	...

.PHONY: dep-check-torch
dep-check-torch:
	@echo "see if torch is installed, and which CUDA version and devices it sees"
	...

.PHONY: dep
dep:
	@echo "Exporting dependencies as per $(PYPROJECT) section 'tool.uv-exports.exports'"
	...

.PHONY: dep-check
dep-check:
	@echo "Checking that exported requirements are up to date"
	...

.PHONY: dep-clean
dep-clean:
	@echo "clean up lock files, .venv, and requirements files"
	...

# runs ruff
.PHONY: format
format:
	@echo "format the source code"
	...

# runs ruff to check if the code is formatted correctly
.PHONY: format-check
format-check:
	@echo "check if the source code is formatted correctly"
	...

# runs type checks with mypy
.PHONY: typing
typing: clean
	@echo "running type checks"
	...

# generates a report of the mypy output
.PHONY: typing-report
typing-report:
	@echo "generate a report of the type check output -- errors per file"
	...

.PHONY: test
test: clean
	@echo "running tests"
	...

.PHONY: check
check: clean format-check test typing
	@echo "run format checks, tests, and typing checks"
	...

# generates a whole tree of documentation in html format.
# see `$(MAKE_DOCS_SCRIPT_PATH)` and the templates in `$(DOCS_RESOURCES_DIR)/templates/html/` for more info
.PHONY: docs-html
docs-html:
	@echo "generate html docs"
	...

# instead of a whole website, generates a single markdown file with all docs using the templates in `$(DOCS_RESOURCES_DIR)/templates/markdown/`.
# this is useful if you want to have a copy that you can grep/search, but those docs are much messier.
# docs-combined will use pandoc to convert them to other formats.
.PHONY: docs-md
docs-md:
	@echo "generate combined (single-file) docs in markdown"
	...

# after running docs-md, this will convert the combined markdown file to other formats:
# gfm (github-flavored markdown), plain text, and html
# requires pandoc in path, pointed to by $(PANDOC)
# pdf output would be nice but requires other deps
.PHONY: docs-combined
docs-combined: docs-md
	@echo "generate combined (single-file) docs in markdown and convert to other formats"
	...

# generates coverage reports as html and text with `pytest-cov`, and a badge with `coverage-badge`
# if `.coverage` is not found, will run tests first
# also removes the `.gitignore` file that `coverage html` creates, since we count that as part of the docs
.PHONY: cov
cov:
	@echo "generate coverage reports"
	...

# runs the coverage report, then the docs, then the combined docs
.PHONY: docs
docs: cov docs-html docs-combined todo lmcat
	@echo "generate all documentation and coverage reports"
	...

# removes all generated documentation files, but leaves everything in `$DOCS_RESOURCES_DIR`
# and leaves things defined in `pyproject.toml:tool.makefile.docs.no_clean`
# (templates, svg, css, make_docs.py script)
# distinct from `make clean`
.PHONY: docs-clean
docs-clean:
	@echo "remove generated docs except resources"
	...

.PHONY: todo
todo:
	@echo "get all TODO's from the code"
	...

.PHONY: lmcat-tree
lmcat-tree:
	@echo "show in console the lmcat tree view"
	...

.PHONY: lmcat
lmcat:
	@echo "write the lmcat full output to pyproject.toml:[tool.lmcat.output]"
	...

# verifies that the current branch is $(PUBLISH_BRANCH) and that git is clean
# used before publishing
.PHONY: verify-git
verify-git: 
	@echo "checking git status"
	...

.PHONY: build
build: 
	@echo "build the package"
	...

# gets the commit log, checks everything, builds, and then publishes with twine
# will ask the user to confirm the new version number (and this allows for editing the tag info)
# will also print the contents of $(PYPI_TOKEN_FILE) to the console for the user to copy and paste in when prompted by twine
.PHONY: publish
publish: gen-commit-log check build verify-git version gen-version-info
	@echo "run all checks, build, and then publish"
	...

# cleans up temp files from formatter, type checking, tests, coverage
# removes all built files
# removes $(TESTS_TEMP_DIR) to remove temporary test files
# recursively removes all `__pycache__` directories and `*.pyc` or `*.pyo` files
# distinct from `make docs-clean`, which only removes generated documentation files
.PHONY: clean
clean:
	@echo "clean up temporary files"
	...

.PHONY: clean-all
clean-all: clean docs-clean dep-clean
	@echo "clean up all temporary files, dep files, venv, and generated docs"
	...

.PHONY: info
info: gen-version-info
	@echo "# makefile variables"
	...

.PHONY: info-long
info-long: info
	@echo "# other variables"
	...

# immediately print out the help targets, and then local variables (but those take a bit longer)
.PHONY: help
help: help-targets info
	@echo -n ""
	...

.PHONY: demo-clean
demo-clean:
	...

.PHONY: demo-activations
demo-activations:
	...

.PHONY: demo-figures
demo-figures:
	...

.PHONY: demo-server
demo-server:
	...

.PHONY: demo
demo: demo-clean demo-activations demo-figures demo-server
	@echo "generate demo"
	...

.PHONY: demo-docs
demo-docs: demo-clean demo-activations demo-figures
	@echo "generate demo for docs (no server)"
	...

.PHONY: summary
summary:
	@echo "write docs/summary.md using lmcat"
	...

``````{ end_of_file="makefile" }

``````{ path="pyproject.toml"  }
[project]
    name = "pattern_lens"
    version = "0.4.0"
    description = ""
    readme = "README.md"
    requires-python = ">=3.11"
    dependencies = [
        # standard
        "numpy>=1.26.1,<2.0.0",
        "torch>=2.5.1",
        "jaxtyping>=0.2.33",
        "tqdm>=4.66.5",
        "pandas>=2.2.2",
        "scipy>=1.14.1",
        # "scikit-learn>=1.3",
        "matplotlib>=3.8.0",
        "pillow>=11.0.0",
        # jupyter
        "ipykernel>=6.29.5",
        "ipywidgets>=8.1.5",
        # typing
        "beartype>=0.14.1",
        # custom utils
        "muutils>=0.6.19",
        "zanj>=0.3.1",
        # TL
        "transformer-lens>=2.10.0",
        # needed by transformer-lens but apparently not declared by it (possibly only in an extra?)
        "typeguard>=4.4.1",
    ]

[dependency-groups]
    dev = [
        # lmcat
        "lmcat>=0.2.0; python_version >= '3.11'",
        # test
        "pytest>=8.2.2",
        # coverage
        "pytest-cov>=4.1.0",
        "coverage-badge>=1.1.0",
        # type checking
        "mypy>=1.0.1",
        "types-tqdm",
        # docs
        'pdoc>=14.6.0',
        "nbconvert>=7.16.4",
        # tomli, since there is no tomllib in python < 3.11
        "tomli>=2.1.0; python_version < '3.11'",
        # lint
        "ruff>=0.4.8",
    ]

[tool.uv]
    package = true

[project.urls]
    Homepage = "https://miv.name/pattern-lens"
    Documentation = "https://miv.name/pattern-lens"
    Repository = "https://github.com/mivanit/pattern-lens"
    Issues = "https://github.com/mivanit/pattern-lens/issues"

[build-system]
    requires = ["hatchling"]
    build-backend = "hatchling.build"

# tools
[tool]
    [tool.hatch.build.targets.wheel]
        packages = ["pattern_lens"]

    # ruff config
    [tool.ruff]
        exclude = ["__pycache__"]

        [tool.ruff.lint]
            ignore = [
                "F722", # doesn't like jaxtyping
                "W191", # we like tabs
                "D400", # missing-trailing-period
                "D415", # missing-terminal-punctuation
                "E501", # line-too-long
                "S101", # assert is fine
                "D403", # first-word-uncapitalized
                "D206", # docstring-tab-indentation
                "ERA001", # commented-out-code
                "T201", # print is fine
                "C408", # calling dict() is fine
                "UP015", # we like specifying the mode even if it's the default
                "D300", # we like docstrings
                # boolean positional arguments are fine
                "FBT001", 
                "FBT002",
                "PTH123", # opening files is fine
                "RET505", # else return is fine
                "FIX002", # `make todo` will give us the TODO comments
                "PIE790", # be explicit about when we pass
                "EM101", # fine to have string literal exceptions
                "FURB129", # .readlines() is fine
                "SIM108", # ternary operators can be hard to read, choose on a case-by-case basis
                "PLR5501", # nested if else is fine, for readability
                "D203", # docstring right after the class
                "D213", # docstring on first line
                "NPY002", # legacy numpy generator is fine
                "D401", # don't care about imperative mood
                # todos:
                "TD002", # don't care about author
                "TD003", # `make todo` will give us a table where we can create issues
                "PLR0913", # sometimes you have to have a lot of args
            ]
            select = ["ALL"]
            # select = ["ICN001"]

            [tool.ruff.lint.per-file-ignores]
                "tests/*" = [
                    # don't need docstrings in test functions or modules
                    "D100",
                    "D102",
                    "D103", 
                    "D107",
                    # don't need __init__ either
                    "INP001",
                    "ANN204",
                    # don't need type annotations in test functions
                    "ANN001",
                    "ANN201", 
                    "ANN202",
                    "TRY003", # long exception messages in tests are fine
                    "PLR2004", # magic values fine in tests
                ]
                "docs/*" = ["ALL"] # not our problem
                "**/*.ipynb" = [
                    "D103", # don't need docstrings
                    "PLR2004", # magic variables are fine
                    "N806", # uppercase vars are fine
                ]

        [tool.ruff.format]
            indent-style = "tab"
            skip-magic-trailing-comma = false

    [tool.pytest.ini_options]
        addopts = "--jaxtyping-packages=pattern_lens,beartype.beartype"
        filterwarnings = [
            "ignore: PEP 484 type hint*:beartype.roar._roarwarn.BeartypeDecorHintPep585DeprecationWarning",
        ]

    [tool.mypy]
        check_untyped_defs = true

    # `make lmcat` depends on the `lmcat` package, which is configured here
    [tool.lmcat]
        output = "docs/other/lmcat.txt" # changing this might mean it won't be accessible from the docs
        ignore_patterns = [
            "!docs/resources/make_docs.py",
            "docs/**",
            ".venv/**",
            ".git/**",
            ".meta/**",
            "data/pile_demo.jsonl",
            "tests/**",
            "uv.lock",
            "LICENSE",
        ]
        [tool.lmcat.glob_process]
            "[mM]akefile" = "makefile_recipes"
            "*.ipynb" = "ipynb_to_md"

# for configuring this tool (makefile, make_docs.py)
# ============================================================
[tool.makefile]

# documentation configuration, for `make docs` and `make docs-clean`
[tool.makefile.docs]
    output_dir = "docs"
    no_clean = [
        ".nojekyll",
        "demo",
    ]
    markdown_headings_increment = 2
    warnings_ignore = [
        ".*No docstring.*",
        ".*Private member.*",
    ]
    [tool.makefile.docs.notebooks]
        enabled = false
        source_path = "notebooks"
        output_path_relative = "notebooks"
        # [tool.makefile.docs.notebooks.descriptions]
        #     "example" = "Example notebook showing basic usage"
        #     "advanced" = "Advanced usage patterns and techniques"
        
        

# Custom export configurations
# affects `make dep` and related commands
[tool.makefile.uv-exports]
	args = [
		"--no-hashes"
	]
	exports = [
		# # all groups and extras
		{ name = "all", filename="requirements.txt", groups = true, extras=true },
		# # all groups and extras, a different way
		{ name = "all", groups = true, options = ["--all-extras"] },
	]

# configures `make todo`
[tool.makefile.inline-todo]
	search_dir = "."
	out_file_base = "docs/other/todo-inline.md"
	context_lines = 2
	extensions = ["py", "md"]
	tags = ["CRIT", "TODO", "FIXME", "HACK", "BUG", "DOC", "TYPING"]
	exclude = [
		"docs/**",
		".venv/**",
		"scripts/get_todos.py",
	]
	branch = "main"
	[tool.makefile.inline-todo.tag_label_map]
		"BUG" = "bug"
		"TODO" = "enhancement"
		"DOC" = "documentation"

# ============================================================
``````{ end_of_file="pyproject.toml" }
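
The `[tool.makefile.inline-todo]` table above is plain TOML, so any script can read it with the standard library. The sketch below shows one way that might look. It inlines a trimmed copy of the table so that it runs on its own. The helper name `load_inline_todo_config` is hypothetical, and the treatment of tags missing from `tag_label_map` (no label) is an assumption, not necessarily what the project's own TODO script does.

```python
try:
    import tomllib  # standard library on Python >= 3.11
except ModuleNotFoundError:
    import tomli as tomllib  # backport for older interpreters

# a trimmed copy of the table from pyproject.toml, inlined for the example
EXAMPLE_TOML = """
[tool.makefile.inline-todo]
search_dir = "."
context_lines = 2
extensions = ["py", "md"]
tags = ["CRIT", "TODO", "FIXME", "HACK", "BUG", "DOC", "TYPING"]

[tool.makefile.inline-todo.tag_label_map]
"BUG" = "bug"
"TODO" = "enhancement"
"DOC" = "documentation"
"""


def load_inline_todo_config(toml_text: str) -> dict:
    """Return the `[tool.makefile.inline-todo]` table, or {} if it is absent."""
    data = tomllib.loads(toml_text)
    return data.get("tool", {}).get("makefile", {}).get("inline-todo", {})


config = load_inline_todo_config(EXAMPLE_TOML)
# assumption: tags without an entry in tag_label_map simply get no label
labels = {tag: config["tag_label_map"].get(tag) for tag in config["tags"]}
```

With this config, only `BUG`, `TODO`, and `DOC` map to an issue label; the remaining tags (`CRIT`, `FIXME`, `HACK`, `TYPING`) map to `None` in this sketch.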