# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from datetime import datetime, timedelta, timezone
import uuid
from airflow import DAG
from airflow.providers.google.cloud.operators.dataproc import DataprocCreateBatchOperator
from airflow.operators.python_operator import PythonOperator
import os


# Default arguments handed to the DAG below. Placeholders are substituted
# when this template is rendered.
default_args = {
    'owner': '{{owner}}',
    'start_date': '{{start_date}}',
    # Cast to int, mirroring retry_delay below, so Airflow receives a numeric
    # retry count rather than a string it would have to coerce.
    'retries': int('{{retry_count}}'),
    'retry_delay': timedelta(minutes=int('{{retry_delay}}')),
    # Rendered unquoted and unescaped; presumably a Python list literal of
    # addresses -- confirm against the code that renders this template.
    'email': {{email | safe}},
    # The three flags below are rendered unquoted, so they are expected to
    # render as Python expressions (presumably True/False).
    'email_on_failure': {{email_failure}},
    'email_on_retry': {{email_delay}},
    'email_on_success': {{email_success}},
}

def write_output_to_file(run_id, **kwargs):
    """Build the output notebook path for a run and publish it through XCom.

    Args:
        run_id: Identifier of the current run. It is appended to the
            configured output prefix and followed by ``.ipynb``.
        **kwargs: Task context; must contain the task instance under ``ti``.

    Returns:
        The output notebook path. The same value is pushed to XCom under
        the key ``output_file_path``.
    """
    notebook_path = f"{{output_notebook}}{run_id}.ipynb"
    print(notebook_path)
    task_instance = kwargs['ti']
    task_instance.xcom_push(key='output_file_path', value=notebook_path)
    return notebook_path


# Values substituted when this template is rendered.
time_zone = '{{time_zone}}'
serverless_name = '{{serverless_name}}'
input_notebook = '{{input_notebook}}'
# Kept as a literal Airflow template expression (the raw block shields it from
# the first rendering pass). Presumably Airflow resolves it at run time to the
# value returned by the generate_output_file task -- confirm that the field
# receiving notebook_args is templated by the operator.
output_notebook = {% raw %}"{{ ti.xcom_pull(task_ids='generate_output_file') }}"{% endraw %}
parameters = '''
{{parameters}}
'''
# Arguments for the batch's main Python file: input notebook first, then the
# output notebook.
notebook_args = [input_notebook, output_notebook]
# Forward the parameters only when the rendered text is not blank
# (i.e. not empty and not whitespace-only).
if parameters.strip():
    notebook_args.extend(["--parameters", parameters])


# DAG definition; the tasks below attach to it through their dag argument.
# The rendered name is used both as the DAG id and as its description.
dag = DAG(
    dag_id='{{name}}',
    description='{{name}}',
    default_args=default_args,
    schedule_interval='{{schedule_interval}}',
    catchup=False,
    tags=['scheduler_jupyter_plugin'],
)
 
# First task: calls write_output_to_file, which builds the per-run output
# notebook path and pushes it to XCom.
write_output_task = PythonOperator(
    dag=dag,
    task_id='generate_output_file',
    python_callable=write_output_to_file,
    # The run_id value is left as a literal template expression (shielded by
    # the raw block) so that it is filled in at run time, not at render time.
    op_kwargs={'run_id': {% raw %}'{{run_id}}'{% endraw %}},
    provide_context=True,
)

# Second task: creates a Dataproc batch whose PySpark entry point receives
# the notebook arguments assembled above.
create_batch = DataprocCreateBatchOperator(
    task_id="batch_create",
    project_id='{{gcpProjectId}}',
    region='{{gcpRegion}}',
    batch={
        "pyspark_batch": {
            "main_python_file_uri": '{{inputFilePath}}',
            "args": notebook_args,
        },
        "environment_config": {
            "peripherals_config": {
                {% if metastore_service %}
                "metastore_service": '{{metastore_service}}',
                {% endif %}
                "spark_history_server_config": {
                    "dataproc_cluster": '{{phs_path}}',
                },
            },
        },
        # Both runtime settings are optional; each entry is emitted only when
        # the corresponding value is provided at render time.
        "runtime_config": {
            {% if custom_container %}
            "container_image": '{{custom_container}}',
            {% endif %}
            {% if version %}
            "version": '{{version}}',
            {% endif %}
        },
    },
    # A fresh random id is generated every time this module is evaluated.
    batch_id=str(uuid.uuid4()),
    dag=dag,
)

# The output path is generated before the batch is submitted.
write_output_task >> create_batch