# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from datetime import datetime, timedelta, timezone
from airflow import DAG
import papermill as pm
from airflow.operators.python_operator import PythonOperator



# Task-level defaults handed to the DAG below. The double-brace placeholders
# are substituted when this template is rendered into a concrete DAG file.
default_args = {
    'owner': '{{owner}}',
    # NOTE(review): start_date is rendered as a string; presumably Airflow
    # parses it into a datetime when the DAG is built - confirm.
    'start_date': '{{start_date}}',
    # Coerced to int in the same way as retry_delay so that Airflow receives
    # a numeric retry count rather than a quoted string.
    'retries': int('{{retry_count}}'),
    'retry_delay': timedelta(minutes=int('{{retry_delay}}')),
    'email': {{email | safe}},
    'email_on_failure': {{email_failure}},
    'email_on_retry': {{email_delay}},
    'email_on_success': {{email_success}}
}

def write_output_to_file(run_id, **kwargs):
    """Build the per-run output notebook path and publish it through XCom.

    The path is the output prefix substituted into this template, followed by
    ``run_id`` and the ``.ipynb`` suffix. The path is printed, pushed to XCom
    under the key ``output_file_path`` using the task instance found in
    ``kwargs['ti']``, and finally returned.

    Args:
        run_id: Identifier of the current run, appended to the output prefix.
        **kwargs: Task context; must contain the task instance under ``ti``.

    Returns:
        The generated output notebook path.
    """
    notebook_path = f"{{output_notebook}}{run_id}.ipynb"
    print(notebook_path)
    task_instance = kwargs['ti']
    task_instance.xcom_push(key='output_file_path', value=notebook_path)
    return notebook_path

def convert_parameters(param_str):
    """Parse a ``key:value,key:value`` string into a dict of stripped strings.

    Each comma-separated item is split on its first colon only, so a value may
    itself contain colons (for example ``path:gs://bucket/nb.ipynb``).
    Whitespace around keys and values is discarded; values stay strings.

    Args:
        param_str: Comma-separated ``key:value`` pairs.

    Returns:
        A dict mapping each key to its value, both as stripped strings.

    Raises:
        ValueError: If an item contains no colon, or if ``param_str`` cannot
            be split as a string.
    """
    param_dict = {}
    try:
        for item in param_str.split(','):
            # maxsplit=1 keeps any further colons as part of the value.
            key, value = item.split(':', 1)
            param_dict[key.strip()] = value.strip()
    except (AttributeError, TypeError, ValueError) as e:
        raise ValueError(f"Invalid parameters: {param_str} - Error: {e}") from e
    return param_dict

# Values substituted into this template at render time.
# NOTE(review): time_zone is not referenced anywhere else in the visible file.
time_zone = '{{time_zone}}'
input_notebook = '{{input_notebook}}'
# The raw/endraw markers keep the xcom_pull expression literal in the rendered
# DAG file. run_notebook_task pulls the output path from XCom itself rather
# than reading this module-level value.
output_notebook = {% raw %}"{{ ti.xcom_pull(task_ids='generate_output_file') }}"{% endraw %}


# DAG definition: the id and the description are both filled from the same
# name placeholder when the template is rendered.
dag = DAG(
    dag_id='{{name}}',
    description='{{name}}',
    schedule_interval='{{schedule_interval}}',
    default_args=default_args,
    catchup=False,
    tags=['scheduler_jupyter_plugin'],
)

def run_notebook_task(**kwargs):
    """Execute the input notebook with papermill, writing to the per-run path.

    The parameter string substituted into this template is parsed with
    ``convert_parameters`` unless it is blank, in which case no parameters are
    passed. The output path is pulled from XCom (task ``generate_output_file``)
    through the task instance in ``kwargs['ti']``.

    Args:
        **kwargs: Task context; must contain the task instance under ``ti``.
    """
    raw_parameters = '''
    {{parameters}}
    '''
    notebook_parameters = (
        convert_parameters(raw_parameters) if raw_parameters.strip() else {}
    )
    target_path = kwargs['ti'].xcom_pull(task_ids='generate_output_file')
    pm.execute_notebook(
        input_notebook,
        target_path,
        kernel_name="python3",
        parameters=notebook_parameters,
    )

 
# First task: runs write_output_to_file to compute the per-run output path.
# The raw/endraw markers keep the run_id expression literal when this template
# is rendered, leaving it in the generated DAG file (assumes Airflow then
# templates op_kwargs at run time - TODO confirm for the target version).
write_output_task = PythonOperator(
    task_id='generate_output_file',
    python_callable=write_output_to_file,
    provide_context=True,  
    op_kwargs={'run_id': {% raw %}'{{run_id}}'{% endraw %}},  
    dag=dag
)

# Second task: runs run_notebook_task, which executes the notebook.
execute_notebook_task = PythonOperator(
    dag=dag,
    task_id='execute_notebook',
    provide_context=True,
    python_callable=run_notebook_task,
)

# The output path must be generated before the notebook run starts.
write_output_task >> execute_notebook_task