#!/usr/bin/env python

"""
Thin wrapper around the "OpenTofu" command line interface (CLI) for use
with LocalStack.

The "tflocal" CLI allows you to easily interact with your local services
without having to specify the local endpoints in the "provider" section of
your TF config.
"""

import os
import sys
import glob
import subprocess

from urllib.parse import urlparse

# Absolute path of the directory one level above the folder containing this script.
PARENT_FOLDER = os.path.realpath(os.path.join(os.path.dirname(__file__), ".."))
# If that directory contains a ".venv" folder, give it precedence on the module
# search path, so that the imports below are resolved from there first
# (presumably a local development checkout - TODO confirm).
if os.path.isdir(os.path.join(PARENT_FOLDER, ".venv")):
    sys.path.insert(0, PARENT_FOLDER)

from localstack_client import config  # noqa: E402
import hcl2  # noqa: E402

# Region used by `get_region()` when neither the environment nor boto3 yields one
DEFAULT_REGION = "us-east-1"
# Mock access key ID, used unless CUSTOMIZE_ACCESS_KEY is enabled
DEFAULT_ACCESS_KEY = "test"
# Optional endpoint URL; its hostname/port take precedence for LOCALSTACK_HOSTNAME/EDGE_PORT below
AWS_ENDPOINT_URL = os.environ.get("AWS_ENDPOINT_URL")
# Flag is enabled if the env variable is set to "1" or "true" (case-insensitive, surrounding whitespace ignored)
CUSTOMIZE_ACCESS_KEY = str(os.environ.get("CUSTOMIZE_ACCESS_KEY")).strip().lower() in ["1", "true"]
# Base domain used to derive service-specific hostnames (e.g., for s3 and mwaa)
LOCALHOST_HOSTNAME = "localhost.localstack.cloud"
# Hostname of the default S3 endpoint; can be overridden via the S3_HOSTNAME env variable
S3_HOSTNAME = os.environ.get("S3_HOSTNAME") or f"s3.{LOCALHOST_HOSTNAME}"
# Flag (same parsing as above): run the TF command via os.exec instead of a subprocess
USE_EXEC = str(os.environ.get("USE_EXEC")).strip().lower() in ["1", "true"]
# Name of the TF command line binary to invoke
TF_CMD = os.environ.get("TF_CMD") or "tofu"
# File name of the generated providers override file
LS_PROVIDERS_FILE = os.environ.get("LS_PROVIDERS_FILE") or "localstack_providers_override.tf"
# Target hostname: taken from AWS_ENDPOINT_URL, else the LOCALSTACK_HOSTNAME env variable, else "localhost"
LOCALSTACK_HOSTNAME = urlparse(AWS_ENDPOINT_URL).hostname or os.environ.get("LOCALSTACK_HOSTNAME") or "localhost"
# Target port: taken from AWS_ENDPOINT_URL, else the EDGE_PORT env variable, else 4566
EDGE_PORT = int(urlparse(AWS_ENDPOINT_URL).port or os.environ.get("EDGE_PORT") or 4566)
# Template of a `provider "aws"` block; the <access_key>, <configs> and <endpoints>
# placeholders are substituted in `create_provider_config_file()`
TF_PROVIDER_CONFIG = """
provider "aws" {
  access_key                  = "<access_key>"
  secret_key                  = "test"
  skip_credentials_validation = true
  skip_metadata_api_check     = true
  <configs>
 endpoints {
<endpoints>
 }
}
"""
# Template of the S3 `backend` block; the <...> placeholders are substituted
# in `generate_s3_backend_config()`
TF_S3_BACKEND_CONFIG = """
terraform {
  backend "s3" {
    region         = "<region>"
    bucket         = "<bucket>"
    key            = "<key>"
    dynamodb_table = "<dynamodb_table>"

    access_key        = "test"
    secret_key        = "test"
    endpoint          = "<s3_endpoint>"
    iam_endpoint      = "<iam_endpoint>"
    sts_endpoint      = "<sts_endpoint>"
    dynamodb_endpoint = "<dynamodb_endpoint>"
    skip_credentials_validation = true
    skip_metadata_api_check     = true
  }
}
"""
# Handle of the TF subprocess; assigned in `run_tf_subprocess()`, read by `signal_handler()`
PROCESS = None


# ---
# CONFIG GENERATION UTILS
# ---

def create_provider_config_file(provider_aliases=None):
    """Generate the providers override file and return its path.

    One `provider "aws"` block is rendered per entry of `provider_aliases`, with all
    known service endpoints pointing to the local endpoints. A default (non-aliased)
    provider is added if none of the given entries is a default provider. The S3
    backend block (if a backend is configured) is appended to the generated content.

    :param provider_aliases: list of provider dicts (as returned by `determine_provider_aliases()`),
        with optional "alias", "region" and "access_key" entries. The list is not modified.
    :returns: path of the providers override file that was written
    :raises Exception: if the providers override file already exists
    """
    # work on a copy, to avoid appending the default provider to the caller's list
    provider_aliases = list(provider_aliases or [])

    # maps services to be replaced with alternative names
    # skip services which do not have equivalent endpoint overrides
    # see https://registry.terraform.io/providers/hashicorp/aws/latest/docs/guides/custom-service-endpoints
    service_replaces = {
        "apigatewaymanagementapi": "",
        "appconfigdata": "",
        "ce": "costexplorer",
        "dynamodbstreams": "",
        "edge": "",
        "emrserverless": "",
        "iotdata": "",
        "ioteventsdata": "",
        "iotjobsdata": "",
        "iotwireless": "",
        "logs": "cloudwatchlogs",
        "mediastoredata": "",
        "qldbsession": "",
        "rdsdata": "",
        "sagemakerruntime": "",
        "support": "",
        "timestream": "",
        "timestreamquery": "",
    }
    # service names to be excluded (not yet available in TF)
    service_excludes = ["meteringmarketplace"]

    # create list of service names
    services = [
        srvc.replace("-", "")
        for srvc in config.get_service_ports()
        if srvc not in service_excludes
    ]
    for old, new in service_replaces.items():
        try:
            services.remove(old)
            if new:
                services.append(new)
        except ValueError:
            # service is not in the list - nothing to replace
            pass
    services = sorted(services)

    # add default (non-aliased) provider, if not defined yet
    has_default_provider = any(not p.get("alias") for p in provider_aliases)
    if not has_default_provider:
        provider_aliases.append({"region": get_region()})

    # the endpoints and the S3 addressing style are identical for all providers
    endpoints = "\n".join([f'{s} = "{get_service_endpoint(s)}"' for s in services])
    s3_path_style = use_s3_path_style()

    # create provider configs
    provider_configs = []
    for provider in provider_aliases:
        provider_config = TF_PROVIDER_CONFIG.replace(
            "<access_key>",
            get_access_key(provider) if CUSTOMIZE_ACCESS_KEY else DEFAULT_ACCESS_KEY
        )
        provider_config = provider_config.replace("<endpoints>", endpoints)
        additional_configs = []
        if s3_path_style:
            additional_configs += [" s3_use_path_style = true"]
        # note: parsed attribute values may be wrapped in a list - use the first item
        alias = provider.get("alias")
        if alias:
            if isinstance(alias, list):
                alias = alias[0]
            additional_configs += [f' alias = "{alias}"']
        region = provider.get("region") or get_region()
        if isinstance(region, list):
            region = region[0]
        additional_configs += [f' region = "{region}"']
        provider_config = provider_config.replace("<configs>", "\n".join(additional_configs))
        provider_configs.append(provider_config)

    # construct final config file content
    tf_config = "\n".join(provider_configs)

    # create s3 backend config
    tf_config += generate_s3_backend_config()

    # write temporary config file
    providers_file = get_providers_file_path()
    if os.path.exists(providers_file):
        msg = f"Providers override file {providers_file} already exists - please delete it first"
        raise Exception(msg)
    with open(providers_file, mode="w") as fp:
        fp.write(tf_config)
    return providers_file


def get_providers_file_path() -> str:
    """Return the path of the providers override file.

    The file is located in the directory passed via the first `-chdir=<dir>` command
    line argument (if any), otherwise in the current directory.
    """
    chdir_flag = "-chdir="
    target_dir = next(
        (arg[len(chdir_flag):] for arg in sys.argv if arg.startswith(chdir_flag)),
        ".",
    )
    return os.path.join(target_dir, LS_PROVIDERS_FILE)


def determine_provider_aliases() -> list:
    """Collect the `aws` provider blocks declared in the parsed *.tf files (if any).

    Files whose provider section cannot be processed are reported via a printed
    warning and do not contribute to the result.
    """
    collected = []
    for tf_file, content in parse_tf_files().items():
        try:
            found = []
            for provider_block in ensure_list(content.get("provider", [])):
                if provider_block.get("aws"):
                    found.append(provider_block["aws"])
        except Exception as e:
            print(f"Warning: Unable to extract providers from {tf_file}:", e)
        else:
            # only add the providers of files that could be processed entirely
            collected.extend(found)
    return collected


def generate_s3_backend_config() -> str:
    """Generate an S3 `backend {..}` block with local endpoints, if configured.

    Scans the `terraform` blocks of the parsed *.tf files for a backend definition.
    If an `s3` backend is found, the state bucket and the DynamoDB table are created
    (unless they exist already), and the rendered backend block is returned.

    :returns: the rendered backend block, or an empty string if no S3 backend is configured
    """
    backend_config = None
    for obj in parse_tf_files().values():
        for tf_config in ensure_list(obj.get("terraform", [])):
            # note: `ensure_list(None)` yields `[None]` (a non-empty list), hence the first item
            # itself is checked, in order to skip `terraform` blocks without a backend
            candidates = ensure_list(tf_config.get("backend"))
            if candidates and candidates[0]:
                backend_config = candidates[0]
                break
        if backend_config:
            # stop scanning, to not lose the backend found to files processed later
            break
    s3_backend = backend_config and backend_config.get("s3")
    if not s3_backend:
        return ""

    configs = {
        # note: default values, updated by the configured backend attributes further below...
        "bucket": "tf-test-state",
        "key": "terraform.tfstate",
        "dynamodb_table": "tf-test-state",
        "region": get_region(),
        "s3_endpoint": get_service_endpoint("s3"),
        "iam_endpoint": get_service_endpoint("iam"),
        "sts_endpoint": get_service_endpoint("sts"),
        "dynamodb_endpoint": get_service_endpoint("dynamodb"),
    }
    configs.update(s3_backend)
    get_or_create_bucket(configs["bucket"])
    get_or_create_ddb_table(configs["dynamodb_table"], region=configs["region"])
    result = TF_S3_BACKEND_CONFIG
    for key, value in configs.items():
        # render booleans in HCL notation (true/false)
        value = str(value).lower() if isinstance(value, bool) else str(value)
        result = result.replace(f"<{key}>", value)
    return result


# ---
# AWS CLIENT UTILS
# ---

def use_s3_path_style() -> bool:
    """
    Whether to use S3 path addressing (depending on the configured S3 endpoint)
    If the endpoint starts with the `s3.` prefix, LocalStack will recognize virtual host addressing. If the endpoint
    does not start with it, use path style. This also allows overriding the endpoint to always use path style in case of
    inter container communications in Docker.

    :returns: False if the hostname of the S3 endpoint starts with `s3.`, True otherwise (including
        endpoints whose hostname cannot be determined)
    """
    try:
        host = urlparse(get_service_endpoint("s3")).hostname
    except ValueError:
        host = None

    # note: `hostname` is None for URLs without a host part
    return not (host or "").startswith("s3.")


def get_region() -> str:
    """Determine the region to use.

    Lookup order: the `AWS_DEFAULT_REGION` environment variable, then the region of
    the local boto3 session (if boto3 is installed), then `DEFAULT_REGION`.
    """
    env_region = str(os.environ.get("AWS_DEFAULT_REGION") or "").strip()
    if env_region:
        return env_region

    session_region = None
    try:
        # boto3 is an optional dependency (not part of the requirements, to keep
        # the library lightweight) - use it for the lookup only if available
        import boto3
        session_region = boto3.session.Session().region_name
    except Exception:
        pass

    if session_region:
        return session_region
    # nothing configured - use the default region
    return DEFAULT_REGION


def get_access_key(provider: dict) -> str:
    """Determine the (deactivated) access key ID to use for the given provider.

    Lookup order: the `AWS_ACCESS_KEY_ID` environment variable, then the `access_key`
    entry of the provider; if that yields no key other than the default one, the
    credentials of the local boto3 session (if boto3 is installed) are consulted,
    with `DEFAULT_ACCESS_KEY` as the final fallback.
    """
    configured_key = os.environ.get("AWS_ACCESS_KEY_ID") or provider.get("access_key", "")
    candidate = str(configured_key).strip()
    if candidate and candidate != DEFAULT_ACCESS_KEY:
        # turn a (potentially live) access key into a mocked one
        return deactivate_access_key(candidate)

    try:
        # boto3 is an optional dependency (not part of the requirements, to keep
        # the library lightweight) - use it for the lookup only if available
        import boto3
        candidate = boto3.session.Session().get_credentials().access_key
    except Exception:
        pass

    # fall back to the default access key
    return deactivate_access_key(candidate or DEFAULT_ACCESS_KEY)


def deactivate_access_key(access_key: str) -> str:
    """Safe guarding user from accidental live credential usage by deactivating access key IDs.
        See more: https://docs.localstack.cloud/references/credentials/

    :param access_key: the access key ID to deactivate (may be empty)
    :returns: the key with a leading "A" replaced by "L"; any other key (including
        the empty string) is returned unchanged
    """
    # note: `startswith` (in contrast to indexing) is safe for empty strings
    if access_key.startswith("A"):
        return "L" + access_key[1:]
    return access_key


def get_service_endpoint(service: str) -> str:
    """Return the endpoint URL to use for the given service name.

    A custom endpoint can be configured via the `<SERVICE>_ENDPOINT` environment
    variable (dashes in the service name become underscores); `http://` is prepended
    if the configured value has no scheme. Otherwise, the URL is composed of the
    service hostname and `EDGE_PORT`.
    """
    env_var = service.replace("-", "_").upper().strip() + "_ENDPOINT"
    custom_endpoint = os.environ.get(env_var, "").strip()
    if custom_endpoint:
        return custom_endpoint if "://" in custom_endpoint else f"http://{custom_endpoint}"

    # some services need specific hostnames, all others use the default one
    if service == "s3":
        host = S3_HOSTNAME
    elif service == "mwaa":
        host = f"mwaa.{LOCALHOST_HOSTNAME}"
    else:
        host = LOCALSTACK_HOSTNAME
    return f"http://{host}:{EDGE_PORT}"


def connect_to_service(service: str, region: str = None):
    """Create a boto3 client for the given service, targeting its local endpoint.

    Uses the given region (falling back to `get_region()`) and the static
    "test" credentials. Requires boto3 to be installed.
    """
    import boto3
    region_name = region or get_region()
    client_kwargs = {
        "endpoint_url": get_service_endpoint(service),
        "region_name": region_name,
        "aws_access_key_id": "test",
        "aws_secret_access_key": "test",
    }
    return boto3.client(service, **client_kwargs)


def get_or_create_bucket(bucket_name: str):
    """Return the `head_bucket` response for the bucket, creating the bucket if that request fails.

    The bucket is created in the region of the S3 client; a location constraint
    is only specified for regions other than "us-east-1".
    """
    client = connect_to_service("s3")
    try:
        return client.head_bucket(Bucket=bucket_name)
    except Exception:
        create_args = {"Bucket": bucket_name}
        client_region = client.meta.region_name
        if client_region != "us-east-1":
            create_args["CreateBucketConfiguration"] = {"LocationConstraint": client_region}
        return client.create_bucket(**create_args)


def get_or_create_ddb_table(table_name: str, region: str = None):
    """Return the `describe_table` response for the table, creating the table if that request fails.

    The table is created with on-demand billing and a string hash key named "LockID".
    """
    client = connect_to_service("dynamodb", region=region)
    try:
        return client.describe_table(TableName=table_name)
    except Exception:
        table_spec = {
            "TableName": table_name,
            "BillingMode": "PAY_PER_REQUEST",
            "KeySchema": [{"AttributeName": "LockID", "KeyType": "HASH"}],
            "AttributeDefinitions": [{"AttributeName": "LockID", "AttributeType": "S"}],
        }
        return client.create_table(**table_spec)


# ---
# TF UTILS
# ---

def parse_tf_files() -> dict:
    """Parse the local *.tf files and return a dict of <filename> -> <resource_dict>

    The files are looked up in the directory passed via the `-chdir=<dir>` command line
    argument (if any - consistent with `get_providers_file_path()`), otherwise in the
    current directory. Files that cannot be parsed are reported via a printed message
    and omitted from the result.
    """
    chdir_flag = "-chdir="
    base_dir = next(
        (arg[len(chdir_flag):] for arg in sys.argv if arg.startswith(chdir_flag)),
        None,
    )
    # escape the directory name, such that only the file name part is treated as a pattern
    pattern = os.path.join(glob.escape(base_dir), "*.tf") if base_dir else "*.tf"

    result = {}
    for _file in glob.glob(pattern):
        try:
            # TF configuration files are UTF-8 encoded - do not rely on the platform default
            with open(_file, "r", encoding="utf-8") as fp:
                result[_file] = hcl2.load(fp)
        except Exception as e:
            print(f'Unable to parse "{_file}" as HCL file: {e}')
    return result


def run_tf_exec(cmd, env):
    """Replace the current process with the TF command, using os.exec.

    Can be useful as no I/O handling for stdin/out/err is required; however, it does
    *not* allow us to perform any cleanup logic afterwards.
    """
    executable = cmd[0]
    os.execvpe(executable, cmd, env=env)


def run_tf_subprocess(cmd, env):
    """Run terraform in a subprocess - useful to perform cleanup logic at the end.

    Forwards SIGINT to the subprocess, waits for it to finish, and then exits the
    current process with the return code of the subprocess.

    :param cmd: command to run, as a list of program name and arguments
    :param env: dict of environment variables for the subprocess (None to inherit the
        environment of the current process)
    :raises SystemExit: always, once the subprocess has terminated
    """
    global PROCESS

    # register signal handlers
    import signal
    signal.signal(signal.SIGINT, signal_handler)

    # note: stderr of the subprocess is sent to the stdout stream of this process
    PROCESS = subprocess.Popen(
        cmd, stdin=sys.stdin, stdout=sys.stdout, stderr=sys.stdout, env=env)
    PROCESS.communicate()
    sys.exit(PROCESS.returncode)


# ---
# UTIL FUNCTIONS
# ---

def signal_handler(sig, frame):
    """Forward the received signal to the TF subprocess.

    :param sig: number of the received signal
    :param frame: current stack frame (unused)
    :raises KeyboardInterrupt: if no subprocess has been started yet
    """
    if PROCESS is None:
        # the handler gets registered before the subprocess is started - there is
        # nothing to forward the signal to yet, hence interrupt the current process
        raise KeyboardInterrupt
    PROCESS.send_signal(sig)


def ensure_list(obj) -> list:
    """Return `obj` itself if it is a list, otherwise a new single-item list containing it."""
    if isinstance(obj, list):
        return obj
    return [obj]


def to_bytes(obj) -> bytes:
    """Encode `obj` as UTF-8 if it is a string; any other object is returned unchanged."""
    if isinstance(obj, str):
        return obj.encode("UTF-8")
    return obj


def to_str(obj) -> str:
    """Decode `obj` as UTF-8 if it is a bytes object; any other object is returned unchanged."""
    return obj.decode("UTF-8") if isinstance(obj, bytes) else obj


# ---
# MAIN ENTRYPOINT
# ---

def main():
    """Entrypoint: generate the providers override file, run the TF command, and clean up.

    The command line arguments are passed through to the TF command. The generated
    file is removed once the command returns or raises.
    """
    tf_command = [TF_CMD, *sys.argv[1:]]
    environment = dict(os.environ)

    # create TF provider config file
    override_file = create_provider_config_file(determine_provider_aliases())

    # call terraform command
    try:
        run_command = run_tf_exec if USE_EXEC else run_tf_subprocess
        run_command(tf_command, environment)
    finally:
        os.remove(override_file)


# Run the CLI wrapper only if this file is executed as a script (not on import)
if __name__ == "__main__":
    main()
