# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
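"""Airflow DAG template for running a Jupyter notebook as a Dataproc PySpark job.

This file is templated twice: the scheduler plugin's Jinja pass fills in the
double-brace placeholders when the DAG file is generated, while any raw blocks
are preserved verbatim so Airflow's own Jinja engine can resolve them at task
runtime.
"""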

from datetime import timedelta

from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.providers.google.cloud.operators.dataproc import DataprocSubmitJobOperator
from google.api_core.client_options import ClientOptions
from google.cloud import dataproc_v1

default_args = {
    'owner': '{{owner}}',
    'start_date': '{{start_date}}',
    'retries': int('{{retry_count}}'),
    'retry_delay': timedelta(minutes=int('{{retry_delay}}')),
    'email': {{email | safe}},
    'email_on_failure': {{email_failure}},
    'email_on_retry': {{email_delay}},
    'email_on_success': {{email_success}}
}
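# The email settings above are rendered unquoted on purpose: the plugin's Jinja
# pass emits them as Python literals (a list of addresses and boolean flags).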

def write_output_to_file(run_id, **kwargs):
    """Build the output notebook path for this run and publish it via XCom."""
    output_file_path = f"{{output_notebook}}{run_id}.ipynb"
    print(output_file_path)
    kwargs['ti'].xcom_push(key='output_file_path', value=output_file_path)
    return output_file_path

time_zone = '{{time_zone}}'
stop_cluster_check = '{{stop_cluster}}'
input_notebook = '{{input_notebook}}'
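# Deliberately left inside a raw block so the plugin's Jinja pass keeps it
# intact; Airflow resolves the XCom pull at task runtime because the value is
# passed through the operator's templated `job` field.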
output_notebook = {% raw %}"{{ ti.xcom_pull(task_ids='generate_output_file') }}"{% endraw %}
notebook_args = [input_notebook, output_notebook]
parameters = '''
{{parameters}}
'''
# Check if parameters is not empty or contains only whitespace
if parameters.strip():
    notebook_args.extend(["--parameters", parameters])
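# Note: the main Python file submitted to Dataproc (inputFilePath) is assumed
# to accept the input/output notebook paths positionally plus an optional
# papermill-style --parameters payload.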


def get_client_cert():
    """Placeholder mTLS callback for ClientOptions.client_cert_source.

    Only invoked when mutual TLS is enabled for the endpoint. Replace the body
    with code that returns (client_cert_bytes, client_private_key_bytes), both
    PEM-encoded.
    """
    raise NotImplementedError(
        "Load and return the client certificate and private key bytes here."
    )

def get_cluster_state_start_if_not_running():
    """Start the Dataproc cluster if it is stopped; fail if it is unavailable."""
    options = ClientOptions(
        api_endpoint="{{gcpRegion}}-dataproc.googleapis.com:443",
        client_cert_source=get_client_cert,
    )

    # Create a client
    client = dataproc_v1.ClusterControllerClient(client_options=options)

    # Initialize request argument(s)
    request = dataproc_v1.GetClusterRequest(
        project_id='{{gcpProjectId}}',
        region='{{gcpRegion}}',
        cluster_name='{{cluster_name}}',
    )

    # Make the request
    response = client.get_cluster(request=request)

    # Handle the response, using named cluster states instead of magic numbers
    State = dataproc_v1.ClusterStatus.State
    print(f"State is {response.status.state}")
    if response.status.state in (State.STOPPING, State.STOPPED):
        print("Cluster is in stopping/stopped state. Starting the cluster")
        start_request = dataproc_v1.StartClusterRequest(
            project_id='{{gcpProjectId}}',
            region='{{gcpRegion}}',
            cluster_name='{{cluster_name}}',
        )
        operation = client.start_cluster(request=start_request)
        print("Waiting for operation to complete...")
        response = operation.result()
        if response.status.state in (State.RUNNING, State.UPDATING):
            print("Cluster started successfully")
    elif response.status.state in (State.RUNNING, State.UPDATING):
        print("Cluster is already running")
    else:
        print("Cluster is unavailable")
        raise Exception("Cluster is unavailable")

 
def stop_the_cluster():
    """Stop the Dataproc cluster if the job was configured to do so."""
    if stop_cluster_check == 'True':
        options = ClientOptions(
            api_endpoint="{{gcpRegion}}-dataproc.googleapis.com:443",
            client_cert_source=get_client_cert,
        )

        # Create a client
        client = dataproc_v1.ClusterControllerClient(client_options=options)

        # Initialize request argument(s)
        request = dataproc_v1.StopClusterRequest(
            project_id='{{gcpProjectId}}',
            region='{{gcpRegion}}',
            cluster_name='{{cluster_name}}',
        )

        # Make the request
        operation = client.stop_cluster(request=request)
        print("Waiting for operation to complete...")
        response = operation.result()

        # Handle the response
        State = dataproc_v1.ClusterStatus.State
        if response.status.state in (State.STOPPING, State.STOPPED):
            print("Cluster stopped successfully")
        print(response)

dag = DAG(
    '{{name}}',
    default_args=default_args,
    description='{{name}}',
    tags=['scheduler_jupyter_plugin'],
    schedule_interval='{{schedule_interval}}',
    catchup=False
)


start_cluster = PythonOperator(
    task_id='start_cluster',
    python_callable=get_cluster_state_start_if_not_running,
    retries=2,
    dag=dag)

write_output_task = PythonOperator(
    task_id='generate_output_file',
    python_callable=write_output_to_file,
    op_kwargs={'run_id': {% raw %}'{{run_id}}'{% endraw %}},
    dag=dag
)

submit_pyspark_job = DataprocSubmitJobOperator(
    task_id='submit_pyspark_job',
    project_id='{{gcpProjectId}}',  # This parameter can be overridden by the connection
    region='{{gcpRegion}}',  # This parameter can be overridden by the connection 
    job={
        'reference': {'project_id': '{{gcpProjectId}}'},
        'placement': {'cluster_name': '{{cluster_name}}'},
        'labels': {'client': 'scheduler-jupyter-plugin'},
        'pyspark_job': {
            'main_python_file_uri': '{{inputFilePath}}',
            'args': notebook_args
        },
    },
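    # Rendered only when the plugin is configured with a multi-tenant service
    # account; the submitted job then impersonates that service account.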
    {% if multi_tenant_service_account %}
    impersonation_chain=['{{multi_tenant_service_account}}'],
    {% endif %}
    gcp_conn_id='google_cloud_default',  # Reference to the GCP connection
    dag=dag,
)

stop_cluster = PythonOperator(
    task_id='stop_cluster',
    python_callable=stop_the_cluster,
    dag=dag)

start_cluster >> write_output_task >> submit_pyspark_job >> stop_cluster

