#!python

"""
Parallel sync optimizes throughput by utilizing available CPUs/threads,
particularly in environments with high network latency.

Scenario & Challenge:
When your PostgreSQL database, Elasticsearch/OpenSearch, and PGSync servers
are on different networks, request/response latency becomes significant.
The main bottleneck is the database query roundtrip time. Server-side cursors
help by fetching batches of records, but the delay between cursor fetches
still limits overall synchronization speed.

Solution:
Perform an initial fast parallel sync to populate Elasticsearch/OpenSearch
in one pass. After this, regular pgsync can continue running as a daemon.

Technical Approach:
This implementation uses PostgreSQL's tuple identifier system column "ctid"
(type "tid"), which identifies the page and row number of each record.
The ctid is a tuple like (1, 5), indicating the disk page and row position.

The sync process works as follows:
1. Retrieve all ctid values from the root table upfront
2. Divide records into work units based on BLOCK_SIZE (number of root node
   records per worker)
3. Distribute work units across available CPUs/threads
4. Workers execute queries filtered by page and row numbers in parallel
5. Each worker performs bulk inserts to Elasticsearch/OpenSearch

This parallel approach significantly improves synchronization performance,
especially in high-latency network environments.
"""

import asyncio
import multiprocessing
import os
import re
import sys
import typing as t
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from dataclasses import dataclass
from queue import Queue
from threading import Thread

import click
import sqlalchemy as sa

from pgsync.settings import (
    BLOCK_SIZE,
    CHECKPOINT_PATH,
    S3_SCHEMA_URL,
    SCHEMA,
    SCHEMA_URL,
)
from pgsync.sync import Sync
from pgsync.utils import (
    config_loader,
    MutuallyExclusiveOption,
    show_settings,
    timeit,
    validate_config,
)


def save_ctid(page: int, row: int, filename: str) -> None:
    """
    Persist a ctid checkpoint as a single "page,row" line.

    The checkpoint lives in CHECKPOINT_PATH as a hidden file named
    ".<filename>.ctid"; any previous content is overwritten.

    Args:
        page (int): Page component of the ctid to record.
        row (int): Row component of the ctid to record.
        filename (str): Base name used to build the checkpoint file name.
    """
    checkpoint_file: str = os.path.join(CHECKPOINT_PATH, f".{filename}.ctid")
    # "w+" truncates the file on open, so only the latest checkpoint is kept.
    with open(checkpoint_file, "w+") as handle:
        handle.write(f"{page},{row}\n")


def read_ctid(filename: str) -> t.Tuple[t.Optional[int], t.Optional[int]]:
    """
    Read the ctid checkpoint stored for the given name.

    The checkpoint is expected to be a single "page,row" line in the hidden
    file ".<filename>.ctid" under CHECKPOINT_PATH.

    Args:
        filename (str): Base name used to build the checkpoint file name.

    Returns:
        tuple: (page, row) as integers. (None, None) is returned when the
            checkpoint file does not exist or contains no data.

    Raises:
        IndexError: If the first token of the file has no "," separator.
        ValueError: If the page or row component is not an integer.
    """
    filepath: str = os.path.join(CHECKPOINT_PATH, f".{filename}.ctid")
    # EAFP: open directly rather than exists()-then-open, which can race with
    # another process removing the file in between.
    try:
        with open(filepath, "r") as fp:
            tokens: t.List[str] = fp.read().split()
    except FileNotFoundError:
        return None, None

    # An empty (or whitespace-only) file carries no checkpoint; treat it the
    # same as a missing file instead of failing with IndexError.
    if not tokens:
        return None, None

    pairs: t.List[str] = tokens[0].split(",")
    page: int = int(pairs[0])
    row: int = int(pairs[1])
    return page, row


def logical_slot_changes(
    doc: dict, verbose: bool = False, validate: bool = False
) -> None:
    """
    Replay logical slot changes between the checkpoint and the current txid.

    Builds a fresh Sync for `doc`, asks it to process logical slot changes in
    the [checkpoint, txid_current] window, then stores the upper bound as the
    new checkpoint.

    Args:
        doc (dict): Schema document handed to Sync.
        verbose (bool): Forwarded to Sync.
        validate (bool): Forwarded to Sync.
    """
    # catch up on anything the bulk passes may have missed
    syncer: Sync = Sync(doc, verbose=verbose, validate=validate)
    lower: int = syncer.checkpoint
    upper: int = syncer.txid_current
    syncer.logical_slot_changes(txmin=lower, txmax=upper)
    # fall back to a fresh txid read when the captured upper bound is falsy
    if not upper:
        upper = syncer.txid_current
    syncer.checkpoint = upper


@dataclass
class Task:
    """
    Picklable unit of work used by the process-pool sync mode.

    Holds the Sync construction arguments; `process` builds its own Sync for
    every block of ctids it is given.
    """

    # schema document handed to Sync
    doc: dict
    # forwarded to Sync as-is
    verbose: bool = False
    # forwarded to Sync as-is
    validate: bool = False

    def process(self, task: dict) -> None:
        """
        Sync one block of root-table rows and bulk-index the result.

        Args:
            task (dict): Block produced by fetch_tasks, passed to Sync.sync
                as `ctid`.
        """
        worker: Sync = Sync(
            self.doc, verbose=self.verbose, validate=self.validate
        )
        lower: int = worker.checkpoint
        upper: int = worker.txid_current
        bulk = worker.search_client.bulk
        bulk(worker.index, worker.sync(ctid=task, txmin=lower, txmax=upper))
        # report completion per block, flushed so output is not held back in
        # the stdout buffer
        sys.stdout.write(f"Process pid: {os.getpid()} complete.\n")
        sys.stdout.flush()


@timeit
def fetch_tasks(
    doc: dict,
    block_size: t.Optional[int] = None,
) -> t.Generator:
    """
    Yield blocks of root-table ctids to be synced by the workers.

    Every ctid of the root table is fetched and grouped into dicts mapping a
    page number to the list of row numbers found on that page. A dict is
    yielded each time `block_size` rows have been collected, and whatever
    remains is yielded once at the end.

    If a ctid checkpoint exists for this database/index, only rows strictly
    after that (page, row) position are fetched.

    Args:
        doc (dict): Schema document handed to Sync.
        block_size (Optional[int]): Number of root rows per yielded block.
            Falls back to BLOCK_SIZE when not given (or zero).

    Yields:
        dict: {page: [row, ...]} blocks. The final block may be empty, e.g.
            when the row count is an exact multiple of `block_size`; it is
            still yielded so that consumers always receive at least one task.
    """
    block_size = block_size or BLOCK_SIZE
    sync: Sync = Sync(doc)
    filename: str = re.sub(
        "[^0-9a-zA-Z_]+", "", f"{sync.database.lower()}_{sync.index}"
    )
    page, row = read_ctid(filename)

    def ctid_part(position: int):
        # Integer component of ctid: strip the parentheses from its text form
        # "(page,row)" and take the comma separated field at `position`
        # (1 = page, 2 = row).
        return sa.cast(
            sa.func.SPLIT_PART(
                sa.func.REPLACE(
                    sa.func.REPLACE(
                        sa.cast(sa.column("ctid"), sa.Text),
                        "(",
                        "",
                    ),
                    ")",
                    "",
                ),
                ",",
                position,
            ),
            sa.Integer,
        )

    statement: sa.sql.Select = sa.select(
        *[
            sa.literal_column("1").label("x"),
            sa.literal_column("1").label("y"),
            sa.column("ctid"),
        ]
    ).select_from(sync.tree.root.model)

    # Resume after the checkpoint. Compare against None explicitly: page 0 is
    # a valid page number and must not be mistaken for "no checkpoint".
    if page is not None:
        page_number = ctid_part(1)
        if row is None:
            statement = statement.where(page_number > page)
        else:
            # (page, row) ordering is lexicographic: keep every later page in
            # full, plus the remaining rows of the checkpoint page. Requiring
            # both "page > P" and "row > R" would skip rows numbered <= R on
            # all later pages as well as the rest of page P.
            statement = statement.where(
                sa.or_(
                    page_number > page,
                    sa.and_(page_number == page, ctid_part(2) > row),
                )
            )

    pages: dict = {}
    # NOTE(review): fetchmany appears to hand back the trailing column(s) as a
    # sequence, hence ctid[0] holding the "(page,row)" text - confirm against
    # Sync.fetchmany.
    for count, (_, _, ctid) in enumerate(sync.fetchmany(statement), start=1):
        value: list = ctid[0].split(",")
        page_no: int = int(value[0].replace("(", ""))
        row_no: int = int(value[1].replace(")", ""))
        pages.setdefault(page_no, []).append(row_no)
        if count % block_size == 0:
            yield pages
            pages = {}
    yield pages


@timeit
def synchronous(
    tasks: t.Generator,
    doc: dict,
    verbose: bool = False,
    validate: bool = False,
) -> None:
    """
    Sync every block one after another in the calling thread.

    A single Sync is created and its checkpoint / current txid are read once
    and reused for all blocks. After the last block, logical slot changes are
    replayed.

    Args:
        tasks (Generator): Blocks produced by fetch_tasks.
        doc (dict): Schema document handed to Sync.
        verbose (bool): Forwarded to Sync.
        validate (bool): Forwarded to Sync.
    """
    sys.stdout.write("Synchronous\n")
    engine: Sync = Sync(doc, verbose=verbose, validate=validate)
    lower_txid: int = engine.checkpoint
    upper_txid: int = engine.txid_current
    target_index: str = engine.index

    def push(block: dict) -> None:
        # bulk-index the documents generated for one block of ctids
        engine.search_client.bulk(
            target_index,
            engine.sync(ctid=block, txmin=lower_txid, txmax=upper_txid),
        )

    for block in tasks:
        push(block)
    logical_slot_changes(doc, verbose=verbose, validate=validate)


@timeit
def multithreaded(
    tasks: t.Generator,
    doc: dict,
    nthreads: t.Optional[int] = None,
    verbose: bool = False,
    validate: bool = False,
) -> None:
    """
    Sync blocks using a pool of daemon threads fed from a queue.

    One Sync object is shared by all worker threads. Each worker reads the
    checkpoint / current txid once, then loops forever pulling blocks off the
    queue and bulk-indexing them. The caller enqueues every block, waits for
    the queue to drain and finally replays logical slot changes.

    A failure while processing a block is reported on stdout and the block is
    still marked as done; otherwise the failed block would never be
    acknowledged and `queue.join()` would block forever.

    Args:
        tasks (Generator): Blocks produced by fetch_tasks.
        doc (dict): Schema document handed to Sync.
        nthreads (Optional[int]): Number of worker threads; defaults to 1.
        verbose (bool): Forwarded to Sync.
        validate (bool): Forwarded to Sync.
    """
    sys.stdout.write("Multithreaded\n")

    def worker(sync: Sync, queue: Queue) -> None:
        txmin: int = sync.checkpoint
        txmax: int = sync.txid_current
        while True:
            task: dict = queue.get()
            try:
                sync.search_client.bulk(
                    sync.index,
                    sync.sync(ctid=task, txmin=txmin, txmax=txmax),
                )
            except Exception as e:
                # keep the worker alive so remaining blocks are still drained
                sys.stdout.write(f"Exception: {e}\n")
            finally:
                # always acknowledge the block, even on failure, so that
                # queue.join() below can return
                queue.task_done()

    nthreads = nthreads or 1
    queue: Queue = Queue()
    sync: Sync = Sync(doc, verbose=verbose, validate=validate)

    for _ in range(nthreads):
        thread: Thread = Thread(
            target=worker,
            args=(
                sync,
                queue,
            ),
        )
        # daemon threads: the endless worker loops must not keep the
        # interpreter alive once the main thread is finished
        thread.daemon = True
        thread.start()
    for task in tasks:
        queue.put(task)

    queue.join()  # block until all tasks are done
    logical_slot_changes(doc, verbose=verbose, validate=validate)


@timeit
def multiprocess(
    tasks: t.Generator,
    doc: dict,
    ncpus: t.Optional[int] = None,
    verbose: bool = False,
    validate: bool = False,
) -> None:
    """
    Sync blocks across a pool of worker processes.

    Each block is mapped onto Task.process in a ProcessPoolExecutor. Any
    exception surfacing while collecting the results is written to stdout
    rather than propagated. Logical slot changes are replayed afterwards.

    Args:
        tasks (Generator): Blocks produced by fetch_tasks.
        doc (dict): Schema document handed to Sync.
        ncpus (Optional[int]): Maximum number of worker processes.
        verbose (bool): Forwarded to Sync.
        validate (bool): Forwarded to Sync.
    """
    sys.stdout.write("Multiprocess\n")
    job: Task = Task(doc, verbose=verbose, validate=validate)
    with ProcessPoolExecutor(max_workers=ncpus) as pool:
        try:
            # drain the result iterator so worker exceptions are raised here
            for _ in pool.map(job.process, tasks):
                pass
        except Exception as error:
            sys.stdout.write(f"Exception: {error}\n")
    logical_slot_changes(doc, verbose=verbose, validate=validate)


@timeit
def multithreaded_async(
    tasks: t.Generator,
    doc: dict,
    nthreads: t.Optional[int] = None,
    verbose: bool = False,
    validate: bool = False,
) -> None:
    """
    Sync blocks on a thread pool driven by an asyncio event loop.

    The blocks are dispatched through run_tasks onto a ThreadPoolExecutor;
    once they have all completed, logical slot changes are replayed.

    Args:
        tasks (Generator): Blocks produced by fetch_tasks.
        doc (dict): Schema document handed to Sync.
        nthreads (Optional[int]): Maximum number of worker threads.
        verbose (bool): Forwarded to Sync.
        validate (bool): Forwarded to Sync.
    """
    sys.stdout.write("Multi-threaded async\n")
    # The context manager shuts the pool down once the run is over, so
    # repeated calls (one per schema document) do not accumulate idle pools.
    with ThreadPoolExecutor(max_workers=nthreads) as executor:
        # asyncio.run owns the loop lifecycle; asyncio.get_event_loop()
        # without a running loop is deprecated in recent Python versions.
        asyncio.run(
            run_tasks(executor, tasks, doc, verbose=verbose, validate=validate)
        )
    logical_slot_changes(doc, verbose=verbose, validate=validate)


@timeit
def multiprocess_async(
    tasks: t.Generator,
    doc: dict,
    ncpus: t.Optional[int] = None,
    verbose: bool = False,
    validate: bool = False,
) -> None:
    """
    Sync blocks on a process pool driven by an asyncio event loop.

    The blocks are dispatched through run_tasks onto a ProcessPoolExecutor.
    A KeyboardInterrupt raised while waiting is swallowed, and logical slot
    changes are replayed afterwards in either case.

    Args:
        tasks (Generator): Blocks produced by fetch_tasks.
        doc (dict): Schema document handed to Sync.
        ncpus (Optional[int]): Maximum number of worker processes.
        verbose (bool): Forwarded to Sync.
        validate (bool): Forwarded to Sync.
    """
    sys.stdout.write("Multi-process async\n")
    executor: ProcessPoolExecutor = ProcessPoolExecutor(max_workers=ncpus)
    try:
        # asyncio.run owns the loop lifecycle; asyncio.get_event_loop()
        # without a running loop is deprecated in recent Python versions.
        asyncio.run(
            run_tasks(executor, tasks, doc, verbose=verbose, validate=validate)
        )
    except KeyboardInterrupt:
        # best effort: an interrupted run still goes on to the logical slot
        # catch-up below
        pass
    finally:
        # release the worker processes instead of leaving the pool open
        executor.shutdown()
    logical_slot_changes(doc, verbose=verbose, validate=validate)


async def run_tasks(
    executor: t.Union[ThreadPoolExecutor, ProcessPoolExecutor],
    tasks: t.Generator,
    doc: dict,
    verbose: bool = False,
    validate: bool = False,
) -> None:
    """
    Schedule run_task for every block on the executor and wait for them all.

    With a ThreadPoolExecutor a single Sync is created here and passed to
    every run_task call; otherwise None is passed and run_task builds its
    own. The collected results are printed once every block has finished.

    Args:
        executor: Pool on which the blocks are executed.
        tasks (Generator): Blocks produced by fetch_tasks.
        doc (dict): Schema document handed to Sync.
        verbose (bool): Forwarded to Sync.
        validate (bool): Forwarded to Sync.
    """
    sync: t.Optional[Sync] = None
    if isinstance(executor, ThreadPoolExecutor):
        # threads can share a common Sync object
        sync = Sync(doc, verbose=verbose, validate=validate)
    # we are inside a coroutine, so a loop is guaranteed to be running
    event_loop = asyncio.get_running_loop()
    futures: list = [
        event_loop.run_in_executor(
            executor, run_task, task, sync, doc, verbose, validate
        )
        for task in tasks
    ]
    results: list = []
    # asyncio.wait raises ValueError when given nothing to wait for
    if futures:
        completed, _ = await asyncio.wait(futures)
        results = [future.result() for future in completed]
    print("results: {!r}".format(results))
    print("exiting")


def run_task(
    task: dict,
    sync: t.Optional[Sync] = None,
    doc: t.Optional[dict] = None,
    verbose: bool = False,
    validate: bool = False,
) -> int:
    """
    Sync one block of ctids, bulk-index it and record a ctid checkpoint.

    Args:
        task (dict): Block mapping page number to a list of row numbers.
        sync (Optional[Sync]): Sync to use; a new one is built from `doc`
            when omitted.
        doc (Optional[dict]): Schema document, used only when `sync` is None.
        verbose (bool): Forwarded to Sync when one is built here.
        validate (bool): Forwarded to Sync when one is built here.

    Returns:
        int: Always 1.
    """
    if sync is None:
        sync = Sync(doc, verbose=verbose, validate=validate)
    lower: int = sync.checkpoint
    upper: int = sync.txid_current
    bulk = sync.search_client.bulk
    bulk(sync.index, sync.sync(ctid=task, txmin=lower, txmax=upper))

    # nothing to checkpoint for an empty block
    if not task:
        return 1

    # checkpoint the highest page of the block and the highest row on it
    last_page: int = max(task)
    last_row: int = max(task[last_page])
    checkpoint_name: str = re.sub(
        "[^0-9a-zA-Z_]+", "", f"{sync.database.lower()}_{sync.index}"
    )
    save_ctid(last_page, last_row, checkpoint_name)
    return 1


@click.command()
@click.option(
    "--config",
    "-c",
    help="Schema config",
    type=click.Path(exists=True),
    default=SCHEMA,
    show_default=True,
    cls=MutuallyExclusiveOption,
    mutually_exclusive=["s3_schema_url", "schema_url"],
)
@click.option(
    "--schema_url",
    help="URL for schema config",
    type=click.STRING,
    default=SCHEMA_URL,
    show_default=True,
    cls=MutuallyExclusiveOption,
    mutually_exclusive=["config", "s3_schema_url"],
)
@click.option(
    "--s3_schema_url",
    help="S3 URL for schema config",
    type=click.STRING,
    default=S3_SCHEMA_URL,
    show_default=True,
    cls=MutuallyExclusiveOption,
    mutually_exclusive=["config", "schema_url"],
)
@click.option(
    "--verbose",
    "-v",
    is_flag=True,
    default=False,
    help="Turn on verbosity",
)
@click.option(
    "--nprocs",
    "-n",
    help="Number of threads/process",
    type=int,
    default=multiprocessing.cpu_count() * 2,
)
@click.option(
    "--mode",
    "-m",
    help="Sync mode",
    type=click.Choice(
        [
            "synchronous",
            "multithreaded",
            "multiprocess",
            "multithreaded_async",
            "multiprocess_async",
        ],
        case_sensitive=False,
    ),
    default="multiprocess_async",
)
def main(
    config: str,
    schema_url: str,
    s3_schema_url: str,
    nprocs: int,
    mode: str,
    verbose: bool,
) -> None:
    """
    TODO:
    - Track progress across cpus/threads
    - Handle KeyboardInterrupt Exception
    """
    # NOTE(review): the docstring above is left untouched on purpose - click
    # presumably surfaces it as the command's help text; confirm before
    # rewording it.

    # The three schema sources are always passed together.
    source: dict = {
        "config": config,
        "schema_url": schema_url,
        "s3_schema_url": s3_schema_url,
    }
    validate_config(**source)
    show_settings(**source)

    # One runner per sync mode; thread based modes receive the worker count
    # as `nthreads`, process based ones as `ncpus`.
    runners: dict = {
        "synchronous": lambda tasks, doc: synchronous(
            tasks, doc, verbose=verbose
        ),
        "multithreaded": lambda tasks, doc: multithreaded(
            tasks, doc, nthreads=nprocs, verbose=verbose
        ),
        "multiprocess": lambda tasks, doc: multiprocess(
            tasks, doc, ncpus=nprocs, verbose=verbose
        ),
        "multithreaded_async": lambda tasks, doc: multithreaded_async(
            tasks, doc, nthreads=nprocs, verbose=verbose
        ),
        "multiprocess_async": lambda tasks, doc: multiprocess_async(
            tasks, doc, ncpus=nprocs, verbose=verbose
        ),
    }
    run = runners.get(mode)

    for doc in config_loader(**source):
        tasks: t.Generator = fetch_tasks(doc)
        # an unrecognised mode runs nothing
        if run is not None:
            run(tasks, doc)


# Run the click command only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()
