#!/usr/bin/env python3
"""Count or split Illumina dual-index strings from Undetermined FASTQs.

This is intended for large BCL Convert Undetermined/Unclassified R1 FASTQs.
It streams gzip data from local paths, s3:// URIs, or presigned http(s) URLs,
extracts the index token from each FASTQ header, and aggregates the top index
pairs across lanes. With --split-fastqs it streams paired R1/R2 FASTQs and
writes one combined FASTQ pair per selected index pair across all lanes.
"""

from __future__ import annotations

import argparse
import csv
import os
import re
import shutil
import shlex
import subprocess
import sys
import tempfile
from collections import Counter, defaultdict
from pathlib import Path


# C source of the header filter that compile_filter() writes to disk and builds.
#
# Usage of the resulting program: ``ilmn_index_filter MODE HEADER_COUNT_FILE``
# where MODE is "all", "uncalled" or "called".
#
# It reads FASTQ text on stdin and treats every fourth line (starting with the
# first) as a record header.  For each header that contains a space, it takes
# the text after the last ':' that follows that first space as the index token
# and prints it on its own line to stdout.  "uncalled" keeps only tokens that
# contain 'N', "called" keeps only tokens without 'N'.  After stdin is
# exhausted, the number of headers seen (before mode filtering) is written to
# HEADER_COUNT_FILE.
#
# Exit codes: 2 = bad usage or unknown mode, 3 = cannot open the count file,
# 4 = read error on stdin, 0 = success.
FILTER_SOURCE = r'''
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/types.h>

enum filter_mode_t { MODE_ALL = 0, MODE_UNCALLED = 1, MODE_CALLED = 2 };

static int contains_n(const char *s, size_t n) {
    for (size_t i = 0; i < n; ++i) {
        if (s[i] == 'N') return 1;
    }
    return 0;
}

static void process_header(const char *buf, size_t n, enum filter_mode_t mode) {
    if (n && buf[n - 1] == '\n') n--;
    if (n && buf[n - 1] == '\r') n--;

    const char *end = buf + n;
    const char *space = memchr(buf, ' ', n);
    if (!space) return;

    const char *p = end;
    while (p > space && *(p - 1) != ':') p--;
    if (p <= space || p >= end) return;

    size_t idx_n = (size_t)(end - p);
    int has_n = contains_n(p, idx_n);
    if (mode == MODE_UNCALLED && !has_n) return;
    if (mode == MODE_CALLED && has_n) return;

    fwrite(p, 1, idx_n, stdout);
    fputc('\n', stdout);
}

static enum filter_mode_t parse_mode(const char *s) {
    if (strcmp(s, "all") == 0) return MODE_ALL;
    if (strcmp(s, "uncalled") == 0) return MODE_UNCALLED;
    if (strcmp(s, "called") == 0) return MODE_CALLED;
    fprintf(stderr, "unknown mode: %s\n", s);
    exit(2);
}

int main(int argc, char **argv) {
    if (argc != 3) {
        fprintf(stderr, "usage: %s MODE HEADER_COUNT_FILE\n", argv[0]);
        return 2;
    }
    enum filter_mode_t mode = parse_mode(argv[1]);
    const char *count_path = argv[2];

    char *line = NULL;
    size_t cap = 0;
    unsigned long long headers = 0;
    unsigned line_mod = 0;
    ssize_t len;

    while ((len = getline(&line, &cap, stdin)) != -1) {
        if (line_mod == 0) {
            headers++;
            process_header(line, (size_t)len, mode);
        }
        line_mod = (line_mod + 1) & 3u;
    }
    free(line);

    FILE *fh = fopen(count_path, "w");
    if (!fh) {
        perror("fopen header count");
        return 3;
    }
    fprintf(fh, "%llu\n", headers);
    fclose(fh);
    return ferror(stdin) ? 4 : 0;
}
'''

# C source of the paired FASTQ splitter that compile_splitter() writes to disk
# and builds against zlib.
#
# Usage of the resulting program:
#   ilmn_index_splitter MODE ALLOWLIST_OR_- MAX_TAGS OUT_DIR SUMMARY_PATH R1 R2
#
# It reads uncompressed R1 and R2 FASTQ records in lockstep, extracts the index
# token (text after the last ':' following the first space of the header),
# applies MODE ("all", "uncalled", "called") to the R1 token, requires the R2
# token to be identical, and appends each matching record pair to
# OUT_DIR/<sanitized tag>_R1.fastq.gz and _R2.fastq.gz.  When an allowlist file
# is given (one tag per line, '#' comments allowed) only listed tags are
# written; with "-" every tag is accepted, up to MAX_TAGS distinct tags.
# SUMMARY_PATH receives "#scanned", "#matched" and one "tag<TAB>count" line per
# tag that received at least one record.
#
# Exit codes: 2 = usage / unknown mode / errno-reported failure, 3 = more than
# MAX_TAGS distinct tags, 4 = cannot open an output FASTQ, 5 = gzwrite or
# gzclose failure, 6 = truncated FASTQ record, 7 = index token too long,
# 8 = R1/R2 record count mismatch, 9 = R2 header without index token,
# 10 = R1/R2 index token mismatch.
SPLITTER_SOURCE = r'''
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <zlib.h>

enum filter_mode_t { MODE_ALL = 0, MODE_UNCALLED = 1, MODE_CALLED = 2 };

struct entry_t {
    char *tag;
    unsigned long long count;
    gzFile r1;
    gzFile r2;
};

struct table_t {
    struct entry_t *entries;
    size_t n_entries;
    size_t max_entries;
    size_t hash_size;
    size_t *hash_to_entry;
};

static unsigned long long fnv1a(const char *s) {
    unsigned long long h = 1469598103934665603ULL;
    while (*s) {
        h ^= (unsigned char)*s++;
        h *= 1099511628211ULL;
    }
    return h;
}

static size_t next_power_two(size_t n) {
    size_t p = 1;
    while (p < n) p <<= 1;
    return p;
}

static int contains_n(const char *s, size_t n) {
    for (size_t i = 0; i < n; ++i) {
        if (s[i] == 'N') return 1;
    }
    return 0;
}

static enum filter_mode_t parse_mode(const char *s) {
    if (strcmp(s, "all") == 0) return MODE_ALL;
    if (strcmp(s, "uncalled") == 0) return MODE_UNCALLED;
    if (strcmp(s, "called") == 0) return MODE_CALLED;
    fprintf(stderr, "unknown mode: %s\n", s);
    exit(2);
}

static void die_errno(const char *what) {
    fprintf(stderr, "%s: %s\n", what, strerror(errno));
    exit(2);
}

static void init_table(struct table_t *table, size_t max_entries) {
    memset(table, 0, sizeof(*table));
    table->max_entries = max_entries;
    table->hash_size = next_power_two(max_entries * 4 + 16);
    table->entries = calloc(max_entries, sizeof(struct entry_t));
    table->hash_to_entry = calloc(table->hash_size, sizeof(size_t));
    if (!table->entries || !table->hash_to_entry) die_errno("calloc");
}

static int lookup_entry(struct table_t *table, const char *tag) {
    unsigned long long h = fnv1a(tag);
    size_t mask = table->hash_size - 1;
    size_t pos = (size_t)h & mask;
    for (;;) {
        size_t entry_plus_one = table->hash_to_entry[pos];
        if (entry_plus_one == 0) return -1;
        size_t idx = entry_plus_one - 1;
        if (strcmp(table->entries[idx].tag, tag) == 0) return (int)idx;
        pos = (pos + 1) & mask;
    }
}

static int add_entry(struct table_t *table, const char *tag) {
    /* Resolve duplicates first so a repeated tag never trips the capacity check. */
    int existing = lookup_entry(table, tag);
    if (existing >= 0) return existing;

    if (table->n_entries >= table->max_entries) {
        fprintf(
            stderr,
            "too many unique tag pairs; saw at least %zu, max is %zu\n",
            table->n_entries + 1,
            table->max_entries
        );
        exit(3);
    }

    size_t idx = table->n_entries++;
    table->entries[idx].tag = strdup(tag);
    if (!table->entries[idx].tag) die_errno("strdup");

    unsigned long long h = fnv1a(tag);
    size_t mask = table->hash_size - 1;
    size_t pos = (size_t)h & mask;
    while (table->hash_to_entry[pos] != 0) pos = (pos + 1) & mask;
    table->hash_to_entry[pos] = idx + 1;
    return (int)idx;
}

static void sanitize_tag(const char *tag, char *out, size_t out_size) {
    size_t j = 0;
    for (size_t i = 0; tag[i] && j + 1 < out_size; ++i) {
        char c = tag[i];
        if (c == '+') {
            if (j + 2 >= out_size) break;
            out[j++] = '_';
            out[j++] = '_';
        } else if (
            (c >= 'A' && c <= 'Z') ||
            (c >= 'a' && c <= 'z') ||
            (c >= '0' && c <= '9') ||
            c == '-' ||
            c == '_'
        ) {
            out[j++] = c;
        } else {
            out[j++] = '_';
        }
    }
    out[j] = '\0';
}

static void ensure_open(struct entry_t *entry, const char *out_dir) {
    if (entry->r1 && entry->r2) return;
    char safe[512];
    char path1[4096];
    char path2[4096];
    sanitize_tag(entry->tag, safe, sizeof(safe));
    snprintf(path1, sizeof(path1), "%s/%s_R1.fastq.gz", out_dir, safe);
    snprintf(path2, sizeof(path2), "%s/%s_R2.fastq.gz", out_dir, safe);
    entry->r1 = gzopen(path1, "ab");
    if (!entry->r1) {
        fprintf(stderr, "unable to open %s\n", path1);
        exit(4);
    }
    entry->r2 = gzopen(path2, "ab");
    if (!entry->r2) {
        fprintf(stderr, "unable to open %s\n", path2);
        exit(4);
    }
}

static void gzwrite_checked(gzFile fh, const char *buf, size_t n) {
    if (gzwrite(fh, buf, (unsigned)n) != (int)n) {
        fprintf(stderr, "gzwrite failed\n");
        exit(5);
    }
}

static void gzclose_checked(gzFile fh) {
    /* gzclose flushes buffered output; a failure here means the file is incomplete. */
    if (gzclose(fh) != Z_OK) {
        fprintf(stderr, "gzclose failed\n");
        exit(5);
    }
}

static int read_record(FILE *fh, char **lines, size_t *caps) {
    ssize_t n0 = getline(&lines[0], &caps[0], fh);
    if (n0 == -1) return 0;
    for (int i = 1; i < 4; ++i) {
        if (getline(&lines[i], &caps[i], fh) == -1) {
            fprintf(stderr, "truncated FASTQ record\n");
            exit(6);
        }
    }
    return 1;
}

static int extract_tag(char *header, char *tag, size_t tag_size, enum filter_mode_t mode) {
    size_t n = strlen(header);
    while (n && (header[n - 1] == '\n' || header[n - 1] == '\r')) n--;
    char *space = memchr(header, ' ', n);
    if (!space) return 0;
    char *end = header + n;
    char *p = end;
    while (p > space && *(p - 1) != ':') p--;
    if (p <= space || p >= end) return 0;
    size_t tag_n = (size_t)(end - p);
    if (tag_n + 1 > tag_size) {
        fprintf(stderr, "index token is too long: %zu bytes\n", tag_n);
        exit(7);
    }
    int has_n = contains_n(p, tag_n);
    if (mode == MODE_UNCALLED && !has_n) return 0;
    if (mode == MODE_CALLED && has_n) return 0;
    memcpy(tag, p, tag_n);
    tag[tag_n] = '\0';
    return 1;
}

static void write_record(gzFile out, char **lines) {
    for (int i = 0; i < 4; ++i) {
        gzwrite_checked(out, lines[i], strlen(lines[i]));
    }
}

static void load_allowlist(struct table_t *table, const char *path) {
    if (strcmp(path, "-") == 0) return;
    FILE *fh = fopen(path, "r");
    if (!fh) die_errno("fopen allowlist");
    char *line = NULL;
    size_t cap = 0;
    while (getline(&line, &cap, fh) != -1) {
        size_t n = strlen(line);
        while (n && (line[n - 1] == '\n' || line[n - 1] == '\r')) line[--n] = '\0';
        if (n == 0 || line[0] == '#') continue;
        add_entry(table, line);
    }
    free(line);
    fclose(fh);
}

int main(int argc, char **argv) {
    if (argc != 8) {
        fprintf(
            stderr,
            "usage: %s MODE ALLOWLIST_OR_- MAX_TAGS OUT_DIR SUMMARY_PATH R1_FIFO R2_FIFO\n",
            argv[0]
        );
        return 2;
    }
    enum filter_mode_t mode = parse_mode(argv[1]);
    const char *allowlist = argv[2];
    size_t max_tags = (size_t)strtoull(argv[3], NULL, 10);
    const char *out_dir = argv[4];
    const char *summary_path = argv[5];
    const char *r1_path = argv[6];
    const char *r2_path = argv[7];
    int use_allowlist = strcmp(allowlist, "-") != 0;

    struct table_t table;
    init_table(&table, max_tags);
    load_allowlist(&table, allowlist);

    FILE *r1 = fopen(r1_path, "r");
    if (!r1) die_errno("fopen R1");
    FILE *r2 = fopen(r2_path, "r");
    if (!r2) die_errno("fopen R2");

    char *r1_lines[4] = {0};
    char *r2_lines[4] = {0};
    size_t r1_caps[4] = {0};
    size_t r2_caps[4] = {0};
    char tag[512];
    char r2_tag[512];
    unsigned long long scanned = 0;
    unsigned long long matched = 0;

    for (;;) {
        int have_r1 = read_record(r1, r1_lines, r1_caps);
        int have_r2 = read_record(r2, r2_lines, r2_caps);
        if (!have_r1 && !have_r2) break;
        if (have_r1 != have_r2) {
            fprintf(stderr, "R1/R2 FASTQ record count mismatch\n");
            exit(8);
        }
        scanned++;
        if (!extract_tag(r1_lines[0], tag, sizeof(tag), mode)) continue;
        if (!extract_tag(r2_lines[0], r2_tag, sizeof(r2_tag), MODE_ALL)) {
            fprintf(stderr, "R2 FASTQ header is missing an index token\n");
            exit(9);
        }
        if (strcmp(tag, r2_tag) != 0) {
            fprintf(stderr, "R1/R2 index token mismatch: %s != %s\n", tag, r2_tag);
            exit(10);
        }
        int entry_idx = lookup_entry(&table, tag);
        if (entry_idx < 0) {
            if (use_allowlist) continue;
            entry_idx = add_entry(&table, tag);
        }
        struct entry_t *entry = &table.entries[entry_idx];
        ensure_open(entry, out_dir);
        write_record(entry->r1, r1_lines);
        write_record(entry->r2, r2_lines);
        entry->count++;
        matched++;
    }

    fclose(r1);
    fclose(r2);
    for (int i = 0; i < 4; ++i) {
        free(r1_lines[i]);
        free(r2_lines[i]);
    }
    /* Close (and thereby flush) every output before the summary is written,
       so a failed flush can never be reported as a successful split. */
    for (size_t i = 0; i < table.n_entries; ++i) {
        if (table.entries[i].r1) gzclose_checked(table.entries[i].r1);
        if (table.entries[i].r2) gzclose_checked(table.entries[i].r2);
    }

    FILE *summary = fopen(summary_path, "w");
    if (!summary) die_errno("fopen summary");
    fprintf(summary, "#scanned\t%llu\n", scanned);
    fprintf(summary, "#matched\t%llu\n", matched);
    for (size_t i = 0; i < table.n_entries; ++i) {
        if (table.entries[i].count > 0) {
            fprintf(summary, "%s\t%llu\n", table.entries[i].tag, table.entries[i].count);
        }
    }
    fclose(summary);

    for (size_t i = 0; i < table.n_entries; ++i) free(table.entries[i].tag);
    free(table.entries);
    free(table.hash_to_entry);
    return 0;
}
'''


# Lane token such as "L001" or "L2": an "L" followed by digits, delimited on
# each side by the start/end of the string or one of "_", ".", "/", "-".
# Group 1 captures the digits after any leading zeros have been consumed.
LANE_RE = re.compile(r"(?:^|[_./-])L0*([0-9]+)(?:[_./-]|$)")


def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        The populated namespace. Cross-option requirements (for example that
        ``--split-fastqs`` needs ``--read2-inputs``) are not enforced here.
    """
    parser = argparse.ArgumentParser(
        description="Extract, rank, or split Illumina index strings from Undetermined/Unclassified FASTQs."
    )
    add = parser.add_argument

    # Longer help texts are named up front to keep the option table readable.
    total_reads_help = (
        "Denominator for pct_of_all_reads. If omitted, uses the number of "
        "FASTQ records scanned across inputs."
    )
    read2_help = (
        "R2 .fastq.gz paths, s3:// URIs, or presigned URLs matching the positional "
        "R1 inputs. Required with --split-fastqs."
    )
    tag_pairs_help = (
        "Optional allowlist TSV with index/index2 columns, or one index+index2 per line. "
        "Only matching tag pairs are emitted."
    )
    split_all_help = (
        "Permit splitting every tag pair that passes --mode. Without this or "
        "--tag-pairs-tsv, --split-fastqs fails to avoid creating millions of files."
    )

    # Inputs, output and counting options.
    add("inputs", nargs="+", help="Input .fastq.gz path, s3:// URI, or presigned http(s) URL.")
    add("-o", "--output", default="-", help="Output TSV path. Default: stdout.")
    add(
        "--mode",
        default="uncalled",
        choices=("all", "uncalled", "called"),
        help="Which index tokens to count. 'uncalled' means index or index2 contains N.",
    )
    add("--top", default=96, type=int, help="Number of ranked rows to emit.")
    add("--total-reads", default=None, type=int, help=total_reads_help)

    # Scratch space and verbosity.
    add("--tmp-dir", default=None, help="Temporary work directory for sort files and per-input counts.")
    add("--sort-memory", default="8G", help="Memory hint passed to sort -S.")
    add("--keep-tmp", action="store_true", help="Do not remove the temporary work directory.")
    add("--quiet", action="store_true", help="Suppress progress messages on stderr.")

    # Paired FASTQ splitting.
    add(
        "--split-fastqs",
        action="store_true",
        help="Also split paired FASTQs by index pair instead of only counting indexes.",
    )
    add("--read2-inputs", default=None, nargs="+", help=read2_help)
    add(
        "--fastq-out-dir",
        default=None,
        help="Directory where split FASTQ.gz files are written. Required with --split-fastqs.",
    )
    add("--tag-pairs-tsv", default=None, help=tag_pairs_help)
    add("--split-all-tags", action="store_true", help=split_all_help)
    add(
        "--max-split-tags",
        default=10000,
        type=int,
        help="Maximum unique tag pairs allowed when --split-all-tags is used.",
    )
    add(
        "--overwrite-fastqs",
        action="store_true",
        help="Allow removing existing *_R1.fastq.gz and *_R2.fastq.gz files in --fastq-out-dir.",
    )
    return parser.parse_args()


def log(message: str, *, quiet: bool) -> None:
    """Print a progress message to stderr, flushed immediately, unless ``quiet``."""
    if quiet:
        return
    print(message, file=sys.stderr, flush=True)


def require_tool(name: str) -> str:
    """Return the path of executable ``name`` as found on PATH.

    Raises:
        SystemExit: if no such executable is found.
    """
    resolved = shutil.which(name)
    if resolved:
        return resolved
    raise SystemExit(f"required command not found on PATH: {name}")


def compile_filter(work_dir: Path, quiet: bool) -> Path:
    """Write FILTER_SOURCE into ``work_dir`` and build it with gcc.

    Returns:
        Path of the compiled ``ilmn_index_filter`` binary.

    Raises:
        SystemExit: if gcc is not on PATH.
        subprocess.CalledProcessError: if compilation fails.
    """
    compiler = require_tool("gcc")
    c_file = work_dir / "ilmn_index_filter.c"
    executable = work_dir / "ilmn_index_filter"
    c_file.write_text(FILTER_SOURCE, encoding="utf-8")
    log("compiling FASTQ header filter", quiet=quiet)
    subprocess.run(
        [compiler, "-O3", "-Wall", "-o", str(executable), str(c_file)],
        check=True,
    )
    return executable


def compile_splitter(work_dir: Path, quiet: bool) -> Path:
    """Write SPLITTER_SOURCE into ``work_dir`` and build it with gcc, linking zlib.

    Returns:
        Path of the compiled ``ilmn_index_splitter`` binary.

    Raises:
        SystemExit: if gcc is not on PATH.
        subprocess.CalledProcessError: if compilation fails.
    """
    compiler = require_tool("gcc")
    c_file = work_dir / "ilmn_index_splitter.c"
    executable = work_dir / "ilmn_index_splitter"
    c_file.write_text(SPLITTER_SOURCE, encoding="utf-8")
    log("compiling paired FASTQ splitter", quiet=quiet)
    # -lz stays after the source file, as in the original command line.
    subprocess.run(
        [compiler, "-O3", "-Wall", "-o", str(executable), str(c_file), "-lz"],
        check=True,
    )
    return executable


def lane_for_input(input_ref: str, ordinal: int) -> str:
    """Return a lane label for ``input_ref``.

    The first LANE_RE match in the reference supplies the lane number, with
    leading zeros removed; when nothing matches, ``ordinal`` is used instead.
    """
    found = LANE_RE.search(input_ref)
    if found is None:
        return str(ordinal)
    return str(int(found.group(1)))


def producer_command(input_ref: str) -> list[str]:
    """Return the argv that streams the raw bytes of ``input_ref`` to stdout.

    ``s3://`` references use ``aws s3 cp``, ``http(s)://`` references use
    ``curl``, and anything else is treated as a local path read with ``cat``.

    Raises:
        SystemExit: if the required tool is not on PATH.
    """
    if input_ref.startswith("s3://"):
        return [require_tool("aws"), "s3", "cp", input_ref, "-"]

    if input_ref.startswith(("http://", "https://")):
        # Follow redirects, fail on HTTP errors, retry up to 20 times 10 s
        # apart, and give up on a transfer slower than 1024 B/s for 120 s.
        retry_flags = ["--retry", "20", "--retry-delay", "10"]
        stall_flags = ["--speed-time", "120", "--speed-limit", "1024"]
        return [require_tool("curl"), "-L", "--fail", *retry_flags, *stall_flags, input_ref]

    return [require_tool("cat"), input_ref]


def decompressor_command() -> list[str]:
    """Return the argv that gunzips stdin to stdout, using pigz when available.

    Raises:
        SystemExit: if neither pigz nor gzip is on PATH.
    """
    tool = shutil.which("pigz") or require_tool("gzip")
    return [tool, "-dc"]


def shell_pipeline_to_fifo(input_ref: str, fifo: Path) -> str:
    """Return a shell command line that downloads/reads, gunzips and writes to ``fifo``.

    Every word is shell-quoted, so the result can be embedded in a script.
    """
    stages = (producer_command(input_ref), decompressor_command())
    pipeline = " | ".join(shlex.join(stage) for stage in stages)
    return f"{pipeline} > {shlex.quote(str(fifo))}"


def run_pipeline(
    input_ref: str,
    *,
    lane: str,
    mode: str,
    filter_binary: Path,
    work_dir: Path,
    sort_memory: str,
    quiet: bool,
) -> tuple[Path, int]:
    """Count index tokens of one FASTQ with ``producer | gunzip | filter | sort | uniq -c``.

    Args:
        input_ref: Local path, ``s3://`` URI or http(s) URL of a gzipped FASTQ.
        lane: Lane label, used for messages and to name the scratch directory.
        mode: Filter mode passed to the filter binary.
        filter_binary: Compiled header filter (see ``compile_filter``).
        work_dir: Parent directory for this input's scratch directory.
        sort_memory: Value passed to ``sort -S``.
        quiet: Suppress progress messages.

    Returns:
        ``(count_path, headers)``: the ``uniq -c`` output file and the number
        of FASTQ headers the filter reported.

    Raises:
        RuntimeError: if any stage exits non-zero or the header count file is
            missing.
    """
    # Each call gets its own scratch directory. Two inputs can resolve to the
    # same lane label (several files of one lane, or an ordinal fallback that
    # equals a parsed lane number); sharing a directory would overwrite the
    # first input's counts and make the caller read the second file twice.
    work_dir.mkdir(parents=True, exist_ok=True)
    lane_dir = work_dir / f"lane_{lane}"
    try:
        lane_dir.mkdir()
    except FileExistsError:
        lane_dir = Path(tempfile.mkdtemp(prefix=f"lane_{lane}_", dir=work_dir))
    header_count_path = lane_dir / "headers.txt"
    count_path = lane_dir / "index_counts.txt"

    # sort and uniq must agree on what "equal" means; the C locale gives plain
    # byte comparison and is also the fastest collation for sort.
    c_locale = {**os.environ, "LC_ALL": "C"}
    stages = [
        ("producer", producer_command(input_ref), None),
        ("decompressor", decompressor_command(), None),
        ("filter", [str(filter_binary), mode, str(header_count_path)], None),
        ("sort", ["sort", "-T", str(lane_dir), "-S", sort_memory], c_locale),
        ("uniq", ["uniq", "-c"], c_locale),
    ]

    log(f"scanning lane {lane}: {input_ref}", quiet=quiet)
    procs: list[tuple[str, subprocess.Popen]] = []
    with count_path.open("wb") as out_fh:
        try:
            upstream = None
            for position, (name, cmd, env) in enumerate(stages):
                is_last = position == len(stages) - 1
                proc = subprocess.Popen(
                    cmd,
                    stdin=upstream,
                    stdout=out_fh if is_last else subprocess.PIPE,
                    env=env,
                )
                procs.append((name, proc))
                # Drop our copy of the read end so an early downstream exit
                # reaches the upstream stage as SIGPIPE instead of a stall.
                if upstream is not None:
                    upstream.close()
                upstream = proc.stdout
            statuses = [(name, proc.wait()) for name, proc in procs]
        except BaseException:
            # A stage failed to start (or we were interrupted): do not leave
            # the stages that are already running behind.
            for _name, proc in procs:
                if proc.poll() is None:
                    proc.kill()
                proc.wait()
            raise

    failed = [(name, rc) for name, rc in statuses if rc != 0]
    if failed:
        detail = ", ".join(f"{name}={rc}" for name, rc in failed)
        raise RuntimeError(f"pipeline failed for lane {lane}: {detail}")

    try:
        headers = int(header_count_path.read_text(encoding="utf-8").strip())
    except FileNotFoundError as exc:
        raise RuntimeError(f"header count file missing for lane {lane}") from exc
    log(f"lane {lane}: scanned {headers} FASTQ records", quiet=quiet)
    return count_path, headers


def load_counts(paths_by_lane: list[tuple[str, Path]]) -> tuple[Counter[str], dict[str, set[str]]]:
    """Merge per-lane ``uniq -c`` files into overall totals.

    Each non-blank line is ``<count> <index>`` separated by whitespace.

    Returns:
        ``(counts, lanes_by_index)``: the summed count per index string and,
        per index string, the set of lane labels it was seen in.
    """
    totals: Counter[str] = Counter()
    seen_lanes: dict[str, set[str]] = defaultdict(set)
    for lane, path in paths_by_lane:
        with path.open("rt", encoding="utf-8") as handle:
            for raw in handle:
                stripped = raw.strip()
                if stripped:
                    occurrences, index = stripped.split(maxsplit=1)
                    totals[index] += int(occurrences)
                    seen_lanes[index].add(lane)
    return totals, seen_lanes


def write_tsv(
    output: str,
    *,
    counts: Counter[str],
    lanes_by_index: dict[str, set[str]],
    denominator: int,
    top: int,
) -> None:
    """Write the ``top`` most frequent index pairs as a ranked TSV.

    Args:
        output: Destination path, or ``"-"`` for stdout.
        counts: Observation count per index string (``index`` or
            ``index+index2``).
        lanes_by_index: Lane labels each index string was seen in; the TSV
            reports how many there are.
        denominator: Total reads used for ``pct_of_all_reads``; 0 yields 0%.
        top: Maximum number of rows after the header.

    Rows are ordered by descending count, then by index string. Lines end with
    a bare ``\\n``: the csv module's default ``\\r\\n`` would leave a stray
    carriage return on the last column for line-oriented Unix tools.
    """
    rows = sorted(counts.items(), key=lambda item: (-item[1], item[0]))[:top]
    if output == "-":
        fh = sys.stdout
        close = False
    else:
        fh = open(output, "w", newline="", encoding="utf-8")
        close = True
    try:
        writer = csv.writer(fh, delimiter="\t", lineterminator="\n")
        writer.writerow(["rank", "index", "index2", "count", "pct_of_all_reads", "lanes_detected"])
        for rank, (index_pair, count) in enumerate(rows, 1):
            # Single-index runs have no "+"; partition leaves index2 empty.
            index, _, index2 = index_pair.partition("+")
            pct = (count / denominator * 100.0) if denominator else 0.0
            writer.writerow(
                [
                    rank,
                    index,
                    index2,
                    count,
                    f"{pct:.8f}",
                    # .get keeps this working for plain dicts as well as
                    # defaultdicts, and avoids inserting keys while reporting.
                    len(lanes_by_index.get(index_pair, ())),
                ]
            )
    finally:
        if close:
            fh.close()


def parse_tag_pairs(path: Path) -> list[str]:
    """Read an allowlist of index pairs and return them as ``index+index2`` strings.

    Blank lines and lines starting with ``#`` are ignored. The delimiter is a
    tab if one occurs in the first five remaining lines, otherwise a comma.
    Two layouts are accepted:

    * a header row naming ``index`` and ``index2`` columns (matched
      case-insensitively, surrounding whitespace ignored); rows lacking either
      value are skipped;
    * no header: each line is either ``index<delim>index2`` or an already
      joined token in the first field.

    Returns:
        The tags in first-seen order with duplicates removed.

    Raises:
        ValueError: if the file has no content lines or yields no tags.
    """
    # utf-8-sig drops a leading BOM (common in spreadsheet exports) that would
    # otherwise be glued to the first header name and hide the header row.
    text = path.read_text(encoding="utf-8-sig")
    raw_lines = [
        line
        for line in text.splitlines()
        if line.strip() and not line.lstrip().startswith("#")
    ]
    if not raw_lines:
        raise ValueError(f"tag-pair allowlist is empty: {path}")

    sample = "\n".join(raw_lines[:5])
    delimiter = "\t" if "\t" in sample else ","
    lowered = [field.strip().lower() for field in raw_lines[0].split(delimiter)]
    tags: list[str] = []
    if "index" in lowered and "index2" in lowered:
        reader = csv.DictReader(raw_lines, delimiter=delimiter)
        # Detection above is case-insensitive, so the lookup keys must be
        # normalised the same way or "Index"/"Index2" headers match nothing.
        reader.fieldnames = [(name or "").strip().lower() for name in (reader.fieldnames or [])]
        for row in reader:
            index = (row.get("index") or "").strip()
            index2 = (row.get("index2") or "").strip()
            if index and index2:
                tags.append(f"{index}+{index2}")
    else:
        for line in raw_lines:
            fields = [field.strip() for field in line.split(delimiter)]
            if len(fields) >= 2 and "+" not in fields[0]:
                tags.append(f"{fields[0]}+{fields[1]}")
            else:
                tags.append(fields[0])

    unique_tags = list(dict.fromkeys(tags))
    if not unique_tags:
        raise ValueError(f"no usable tag pairs found in {path}")
    return unique_tags


def write_allowlist(args: argparse.Namespace, work_dir: Path) -> tuple[Path | None, int]:
    """Decide which tag pairs the splitter may emit.

    Returns:
        ``(allowlist_path, max_tags)``. With ``--tag-pairs-tsv`` the parsed
        tags are written one per line into ``work_dir`` and ``max_tags`` is
        their number. With ``--split-all-tags`` there is no allowlist file
        (``None``) and ``max_tags`` is ``--max-split-tags``.

    Raises:
        SystemExit: if neither option was given.
    """
    if not args.tag_pairs_tsv:
        if args.split_all_tags:
            return None, args.max_split_tags
        raise SystemExit(
            "--split-fastqs requires --tag-pairs-tsv, or --split-all-tags if you "
            "really want one FASTQ pair per observed tag pair."
        )

    tags = parse_tag_pairs(Path(args.tag_pairs_tsv))
    allowlist = work_dir / "tag_pairs.allowlist.txt"
    allowlist.write_text("\n".join(tags) + "\n", encoding="utf-8")
    return allowlist, len(tags)


def prepare_fastq_out_dir(path: str, *, overwrite: bool) -> Path:
    """Create the split-FASTQ output directory and deal with earlier results.

    The splitter appends to its output files, so ``*_R1.fastq.gz`` and
    ``*_R2.fastq.gz`` files already in the directory are either refused
    (default) or deleted (``overwrite=True``).

    Returns:
        The output directory as a ``Path``.

    Raises:
        SystemExit: if such files exist and ``overwrite`` is false.
    """
    out_dir = Path(path)
    out_dir.mkdir(parents=True, exist_ok=True)
    leftovers = sorted(out_dir.glob("*_R[12].fastq.gz"))

    if not overwrite:
        if leftovers:
            examples = ", ".join(map(str, leftovers[:3]))
            raise SystemExit(
                f"{out_dir} already contains split FASTQs; use --overwrite-fastqs to replace "
                f"them. Examples: {examples}"
            )
        return out_dir

    for pattern in ("*_R1.fastq.gz", "*_R2.fastq.gz"):
        for stale in out_dir.glob(pattern):
            stale.unlink()
    return out_dir


def run_split_pair(
    read1_ref: str,
    read2_ref: str,
    *,
    lane: str,
    mode: str,
    splitter_binary: Path,
    allowlist: Path | None,
    max_tags: int,
    fastq_out_dir: Path,
    work_dir: Path,
    quiet: bool,
) -> Path:
    """Stream one R1/R2 FASTQ pair through the splitter binary.

    Both inputs are fetched and gunzipped by background shell pipelines that
    write into two FIFOs, which the splitter reads in lockstep.

    Args:
        read1_ref, read2_ref: Local path, ``s3://`` URI or http(s) URL of the
            gzipped R1 and R2 FASTQs.
        lane: Lane label, used for messages and to name the scratch directory.
        mode: Filter mode passed to the splitter.
        splitter_binary: Compiled splitter (see ``compile_splitter``).
        allowlist: Allowlist file, or ``None`` to accept every tag.
        max_tags: Maximum number of distinct tags passed to the splitter.
        fastq_out_dir: Directory receiving the split FASTQ files.
        work_dir: Parent directory for this pair's scratch directory.
        quiet: Suppress progress messages.

    Returns:
        Path of the summary TSV written by the splitter.

    Raises:
        subprocess.CalledProcessError: if the bash script exits non-zero.
    """
    # Each call gets its own scratch directory. Two inputs can resolve to the
    # same lane label; sharing a directory would overwrite the first pair's
    # summary and make the caller count the second pair twice.
    work_dir.mkdir(parents=True, exist_ok=True)
    lane_dir = work_dir / f"split_lane_{lane}"
    try:
        lane_dir.mkdir()
    except FileExistsError:
        lane_dir = Path(tempfile.mkdtemp(prefix=f"split_lane_{lane}_", dir=work_dir))
    r1_fifo = lane_dir / "r1.fastq"
    r2_fifo = lane_dir / "r2.fastq"
    summary = lane_dir / "split_summary.tsv"
    for fifo in (r1_fifo, r2_fifo):
        os.mkfifo(fifo)

    allowlist_arg = str(allowlist) if allowlist else "-"
    splitter_cmd = [
        str(splitter_binary),
        mode,
        allowlist_arg,
        str(max_tags),
        str(fastq_out_dir),
        str(summary),
        str(r1_fifo),
        str(r2_fifo),
    ]
    # The splitter's status is captured instead of aborting under `set -e`, so
    # the background writers are reaped before the script exits with it.
    # NOTE(review): if the splitter exits before it has opened a FIFO, that
    # FIFO's writer looks like it would stay blocked in open() and `wait`
    # would not return -- confirm and add a kill/timeout if this is hit.
    script = "\n".join(
        [
            "set -euo pipefail",
            f"({shell_pipeline_to_fifo(read1_ref, r1_fifo)}) &",
            "r1_pid=$!",
            f"({shell_pipeline_to_fifo(read2_ref, r2_fifo)}) &",
            "r2_pid=$!",
            "split_rc=0",
            f"{shlex.join(splitter_cmd)} || split_rc=$?",
            "wait $r1_pid",
            "wait $r2_pid",
            "exit $split_rc",
        ]
    )
    log(f"splitting lane {lane}: {read1_ref} / {read2_ref}", quiet=quiet)
    subprocess.run(["bash", "-c", script], check=True)
    return summary


def load_split_summaries(paths_by_lane: list[tuple[str, Path]]) -> tuple[Counter[str], dict[str, set[str]], int, int]:
    """Merge per-lane splitter summaries.

    Every non-empty line is ``key<TAB>value``. The keys ``#scanned`` and
    ``#matched`` carry per-lane totals; any other key is a tag with its count.

    Returns:
        ``(counts, lanes_by_index, scanned, matched)``: summed count per tag,
        lane labels per tag, and the summed ``#scanned`` / ``#matched`` values.
    """
    totals: Counter[str] = Counter()
    seen_lanes: dict[str, set[str]] = defaultdict(set)
    meta = {"#scanned": 0, "#matched": 0}
    for lane, path in paths_by_lane:
        with path.open("rt", encoding="utf-8") as handle:
            for raw in handle:
                record = raw.rstrip("\n")
                if not record:
                    continue
                key, value = record.split("\t", 1)
                if key in meta:
                    meta[key] += int(value)
                else:
                    totals[key] += int(value)
                    seen_lanes[key].add(lane)
    return totals, seen_lanes, meta["#scanned"], meta["#matched"]


def fastq_paths_for_tag(out_dir: Path, tag: str) -> tuple[Path, Path]:
    """Return the R1/R2 output paths the splitter uses for ``tag``.

    This mirrors ``sanitize_tag`` in SPLITTER_SOURCE so the reported paths
    match the files on disk: ``+`` becomes ``__``, ASCII letters, digits,
    ``-`` and ``_`` are kept, and every other *byte* becomes ``_``. The C code
    works on bytes, so a non-ASCII character contributes one ``_`` per UTF-8
    byte, and ``str.isalnum`` alone (which accepts non-ASCII letters and
    digits) would disagree with it.

    The splitter additionally truncates names that overflow its 512-byte
    buffer; that truncation is not reproduced here.
    """
    pieces: list[str] = []
    for char in tag:
        if char == "+":
            pieces.append("__")
        elif char.isascii() and (char.isalnum() or char in "-_"):
            pieces.append(char)
        else:
            # One underscore per encoded byte, as the byte-wise C loop emits.
            pieces.append("_" * len(char.encode("utf-8", "replace")))
    stem = "".join(pieces)
    return out_dir / f"{stem}_R1.fastq.gz", out_dir / f"{stem}_R2.fastq.gz"


def write_split_tsv(
    output: str,
    *,
    counts: Counter[str],
    lanes_by_index: dict[str, set[str]],
    denominator: int,
    fastq_out_dir: Path,
) -> None:
    """Write one ranked TSV row per split tag pair, including its FASTQ paths.

    Args:
        output: Destination path, or ``"-"`` for stdout.
        counts: Number of read pairs written per tag.
        lanes_by_index: Lane labels each tag was seen in; the TSV reports how
            many there are.
        denominator: Total reads used for ``pct_of_all_reads``; 0 yields 0%.
        fastq_out_dir: Directory the split FASTQs were written to.

    Rows are ordered by descending count, then by tag. Lines end with a bare
    ``\\n``: the csv module's default ``\\r\\n`` would append a carriage return
    to the last column, which here is the ``read2_fastq`` path.
    """
    rows = sorted(counts.items(), key=lambda item: (-item[1], item[0]))
    if output == "-":
        fh = sys.stdout
        close = False
    else:
        fh = open(output, "w", newline="", encoding="utf-8")
        close = True
    try:
        writer = csv.writer(fh, delimiter="\t", lineterminator="\n")
        writer.writerow(
            [
                "rank",
                "index",
                "index2",
                "count",
                "pct_of_all_reads",
                "lanes_detected",
                "read1_fastq",
                "read2_fastq",
            ]
        )
        for rank, (tag, count) in enumerate(rows, 1):
            # Single-index tags have no "+"; partition leaves index2 empty.
            index, _, index2 = tag.partition("+")
            read1, read2 = fastq_paths_for_tag(fastq_out_dir, tag)
            pct = (count / denominator * 100.0) if denominator else 0.0
            writer.writerow(
                [
                    rank,
                    index,
                    index2,
                    count,
                    f"{pct:.8f}",
                    # .get keeps this working for plain dicts as well as
                    # defaultdicts, and avoids inserting keys while reporting.
                    len(lanes_by_index.get(tag, ())),
                    read1,
                    read2,
                ]
            )
    finally:
        if close:
            fh.close()


def run_split_fastqs(args: argparse.Namespace, work_dir: Path) -> None:
    """Run ``--split-fastqs`` mode: split every R1/R2 pair and write the summary TSV.

    Args:
        args: Parsed command-line options.
        work_dir: Scratch directory for the allowlist, splitter binary and
            per-lane files.

    Raises:
        SystemExit: on missing or inconsistent options, or when the output
            directory already holds split FASTQs and overwriting was not
            requested.
    """
    if not args.read2_inputs:
        raise SystemExit("--read2-inputs is required with --split-fastqs")
    if len(args.read2_inputs) != len(args.inputs):
        raise SystemExit("--read2-inputs must have the same number of entries as positional R1 inputs")
    if not args.fastq_out_dir:
        raise SystemExit("--fastq-out-dir is required with --split-fastqs")
    if args.max_split_tags < 1:
        raise SystemExit("--max-split-tags must be positive")

    # Validate the allowlist and build the splitter *before* touching the
    # output directory: with --overwrite-fastqs the preparation step deletes
    # existing FASTQs, which must not happen for a run that cannot start.
    allowlist, max_tags = write_allowlist(args, work_dir)
    splitter_binary = compile_splitter(work_dir, args.quiet)
    fastq_out_dir = prepare_fastq_out_dir(args.fastq_out_dir, overwrite=args.overwrite_fastqs)

    summaries: list[tuple[str, Path]] = []
    for ordinal, (read1_ref, read2_ref) in enumerate(zip(args.inputs, args.read2_inputs), 1):
        lane = lane_for_input(read1_ref, ordinal)
        summary = run_split_pair(
            read1_ref,
            read2_ref,
            lane=lane,
            mode=args.mode,
            splitter_binary=splitter_binary,
            allowlist=allowlist,
            max_tags=max_tags,
            fastq_out_dir=fastq_out_dir,
            work_dir=work_dir,
            quiet=args.quiet,
        )
        summaries.append((lane, summary))

    counts, lanes_by_index, scanned, matched = load_split_summaries(summaries)
    # An explicit --total-reads (even 0) wins over the scanned record count.
    denominator = args.total_reads if args.total_reads is not None else scanned
    write_split_tsv(
        args.output,
        counts=counts,
        lanes_by_index=lanes_by_index,
        denominator=denominator,
        fastq_out_dir=fastq_out_dir,
    )
    log(
        f"split {matched} read pairs across {len(counts)} tag pairs from {scanned} scanned records",
        quiet=args.quiet,
    )
    if args.output != "-":
        log(f"wrote {args.output}", quiet=args.quiet)


def main() -> int:
    """Entry point: count index pairs, or split FASTQs with ``--split-fastqs``.

    Returns:
        0 on success; failures surface as exceptions or ``SystemExit``.
    """
    args = parse_args()
    if args.top < 1:
        raise SystemExit("--top must be positive")

    if args.tmp_dir:
        base_tmp = Path(args.tmp_dir)
    else:
        base_tmp = Path(tempfile.mkdtemp(prefix="ilmn-indexes-"))
    base_tmp.mkdir(parents=True, exist_ok=True)
    # Only a directory this run created itself is ever removed.
    remove_tmp = args.tmp_dir is None and not args.keep_tmp

    try:
        if args.split_fastqs:
            run_split_fastqs(args, base_tmp)
        else:
            filter_binary = compile_filter(base_tmp, args.quiet)
            count_paths: list[tuple[str, Path]] = []
            scanned_records = 0
            for ordinal, input_ref in enumerate(args.inputs, 1):
                lane = lane_for_input(input_ref, ordinal)
                count_path, headers = run_pipeline(
                    input_ref,
                    lane=lane,
                    mode=args.mode,
                    filter_binary=filter_binary,
                    work_dir=base_tmp,
                    sort_memory=args.sort_memory,
                    quiet=args.quiet,
                )
                scanned_records += headers
                count_paths.append((lane, count_path))

            counts, lanes_by_index = load_counts(count_paths)
            if args.total_reads is not None:
                denominator = args.total_reads
            else:
                denominator = scanned_records
            write_tsv(
                args.output,
                counts=counts,
                lanes_by_index=lanes_by_index,
                denominator=denominator,
                top=args.top,
            )
            log(
                f"counted {sum(counts.values())} {args.mode} index observations "
                f"across {len(counts)} unique index pairs",
                quiet=args.quiet,
            )
            if args.output != "-":
                log(f"wrote {args.output}", quiet=args.quiet)
    finally:
        if remove_tmp:
            shutil.rmtree(base_tmp, ignore_errors=True)
        elif not args.quiet:
            print(f"temporary files: {base_tmp}", file=sys.stderr)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
