#!/usr/bin/env python3
"""
Python script for downloading data from the NASA CMR API

Requirements:
    > pip install tqdm

To run the script: `python download_data_{{short_name}}_{{version}}_{{date}}.py`
"""

import os
import requests
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed

CMR_BASE_URL = "https://cmr.earthdata.nasa.gov"

SHORT_NAME = "{{short_name}}"
VERSION = "{{version}}"
FILTER_TEMPORAL = "{{filter_temporal}}"
FILTER_BBOX = "{{filter_bbox}}"
FILTER_SEARCH = "{{filter_search}}"
FILTER_CLOUD_COVER_MIN = "{{filter_cloud_cover_min}}"
FILTER_CLOUD_COVER_MAX = "{{filter_cloud_cover_max}}"
DOWNLOAD_DIR = f"./{SHORT_NAME}_{VERSION}"
MAX_WORKERS = 5  # Number of parallel downloads

os.makedirs(DOWNLOAD_DIR, exist_ok=True)


def query_cmr_granules(
    short_name: str,
    version: str,
    page_size: int = 2000,
    search_after: str | None = None,
    **extra_params,
):
    """
    Queries the CMR granules endpoint.

    Args:
        short_name: The short name of the collection
        version: The version of the collection
        page_size: The number of results per page (max is 2000)
        search_after: The pagination token for subsequent requests (optional, see https://cmr.earthdata.nasa.gov/search/site/docs/search/api.html#search-after for more details)
        **extra_params: Additional query parameters (e.g., temporal, bounding_box, etc.)

    Returns:
        tuple: (response object, list of granule items)
    """
    url = f"{CMR_BASE_URL}/search/granules.umm_json"

    params = {
        "short_name": short_name,
        "version": version,
        "page_size": page_size,
        **extra_params,  # Merge any additional parameters
    }

    headers = {"Accept": "application/json"}
    if search_after:
        headers["CMR-Search-After"] = search_after

    response = requests.get(url, params=params, headers=headers)

    try:
        response.raise_for_status()
    except requests.exceptions.HTTPError:
        print("Failed to fetch granules:", response.text)
        raise

    data = response.json()
    items = data.get("items", [])

    return response, items


def download_data_from_cmr(
    short_name: str, version: str, total_granules: int, page_size: int = 2000, **params
):
    """Fetches granules for a given collection from the CMR API, then downloads the data from the "GET DATA" URLs"""
    search_after_value: str | None = None
    all_download_urls = []
    granules_without_urls = 0

    # First pass: collect all download URLs
    print("Collecting download URLs...")
    with tqdm(total=total_granules, desc="Collecting URLs", unit="granule") as pbar:
        while True:
            # Use shared query function
            response, items = query_cmr_granules(
                short_name, version, page_size, search_after_value, **params
            )

            # Collect "GET DATA" URLs
            for item in items:
                pbar.update(1)  # Update progress for each granule processed
                download_urls = []
                for related_url in item.get("umm", {}).get("RelatedUrls", []):
                    if related_url.get("Type") == "GET DATA":
                        download_urls.append(related_url.get("URL"))
                
                if download_urls:
                    all_download_urls.extend(download_urls)
                else:
                    granules_without_urls += 1

            # Read the next search-after value from response headers
            search_after_value = response.headers.get("CMR-Search-After")

            # There is no next search-after value, we've reached the end
            if not search_after_value:
                break

    print(f"Found {len(all_download_urls)} files to download")
    if granules_without_urls > 0:
        print(f"⚠️ {granules_without_urls} granules have no download URLs")

    # Second pass: download all files in parallel
    downloaded_count = 0
    skipped_count = 0
    failed_count = 0
    
    with tqdm(total=len(all_download_urls), desc="Downloading files", unit="file") as pbar:
        with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
            # Submit all download tasks at once
            future_to_url = {executor.submit(download_file, url): url for url in all_download_urls}
            
            # Process completed downloads as they finish
            for future in as_completed(future_to_url):
                status, filename = future.result()
                if status == "success":
                    downloaded_count += 1
                    pbar.set_description(f"Downloaded: {filename}")
                elif status == "skipped":
                    skipped_count += 1
                    pbar.set_description(f"Skipped: {filename}")
                else:  # failed
                    failed_count += 1
                    print(f"Failed: {filename}")
                
                pbar.update(1)

    print(f"Downloaded: {downloaded_count} files")
    if skipped_count > 0:
        print(f"Skipped: {skipped_count} files (already exist)")
    if failed_count > 0:
        print(f"Failed: {failed_count} files")


def download_file(url: str):
    """Download a single file and return (success, skipped, failed) status and filename"""
    local_filename = os.path.join(DOWNLOAD_DIR, os.path.basename(url))
    if os.path.exists(local_filename):
        return "skipped", os.path.basename(url)  # skip existing

    try:
        with requests.get(url, stream=True, timeout=120) as r:
            r.raise_for_status()
            with open(local_filename, "wb") as f:
                for chunk in r.iter_content(8192):
                    f.write(chunk)
        return "success", os.path.basename(url)
    except Exception as e:
        print(f"⚠️ Error downloading {os.path.basename(url)}: {e}")
        return "failed", os.path.basename(url)


def fetch_total_granules_count_from_cmr(short_name: str, version: str, **params):
    """Fetches the total number of granules for a given collection and filters from the CMR API"""
    response, _ = query_cmr_granules(short_name, version, page_size=1, **params)
    data = response.json()
    return data.get("hits", 0)


def main():
    """Main function to download data from the CMR API"""
    filter_params = {}

    if FILTER_TEMPORAL:
        filter_params["temporal"] = FILTER_TEMPORAL

    if FILTER_BBOX:
        filter_params["bounding_box"] = FILTER_BBOX

    if FILTER_SEARCH:
        filter_params["producer_granule_id[]"] = FILTER_SEARCH
        filter_params["options[producer_granule_id][pattern]"] = 'true'

    if FILTER_CLOUD_COVER_MIN and FILTER_CLOUD_COVER_MAX:
        filter_params["cloud_cover"] = f"{FILTER_CLOUD_COVER_MIN},{FILTER_CLOUD_COVER_MAX}"

    total_granules = fetch_total_granules_count_from_cmr(
        SHORT_NAME, VERSION, **filter_params
    )
    print(f"Total granules: {total_granules:,}")

    download_data_from_cmr(SHORT_NAME, VERSION, total_granules, **filter_params)

    print("✅ All downloads complete.")


if __name__ == "__main__":
    main()
