from sqlbuild.loaders import loader


@loader(
    write_strategy="append",
    cursor_column="load_seq",
    columns=[
        {"name": "order_id", "type": "INTEGER"},
        {"name": "customer_id", "type": "INTEGER"},
        {"name": "waffle_type", "type": "VARCHAR"},
        {"name": "quantity", "type": "INTEGER"},
        {"name": "load_seq", "type": "INTEGER"},
    ],
)
def fetch_order_events(ctx):
    """Return a batch of two synthetic order events for the next load sequence.

    The sequence number is 1 when ``ctx.current_cursor_value`` is ``None``,
    otherwise one past that value. Each batch uses two consecutive order ids
    starting at ``2 * seq - 1``: batch 1 yields orders 1-2, batch 2 yields
    orders 3-4, and so on. Both rows carry the batch's sequence number in
    ``load_seq``.
    """
    previous_seq = ctx.current_cursor_value
    seq = 1 if previous_seq is None else previous_seq + 1
    base_order_id = 2 * seq - 1

    classic_event = {
        "order_id": base_order_id,
        # Customer 1 on the first batch and customer 3 afterwards.
        # NOTE(review): presumably mirrors fetch_customers introducing
        # customer 3 on its second load; confirm.
        "customer_id": 1 if seq == 1 else 3,
        "waffle_type": "classic",
        "quantity": seq,
        "load_seq": seq,
    }
    blueberry_event = {
        "order_id": base_order_id + 1,
        "customer_id": 2,
        "waffle_type": "blueberry",
        "quantity": seq + 1,
        "load_seq": seq,
    }
    return [classic_event, blueberry_event]


@loader(
    write_strategy="merge",
    unique_key="customer_id",
    cursor_column="load_seq",
    columns=[
        {"name": "customer_id", "type": "INTEGER"},
        {"name": "plan_name", "type": "VARCHAR"},
        {"name": "load_seq", "type": "INTEGER"},
    ],
)
def fetch_customers(ctx):
    """Return a customer snapshot chosen by whether a cursor value exists.

    When ``ctx.current_cursor_value`` is ``None``, customers 1 (``"basic"``)
    and 2 (``"plus"``) are returned with ``load_seq`` 1. For any other cursor
    value, customer 1 is returned on plan ``"pro"`` together with customer 3
    (``"enterprise"``), both with ``load_seq`` 2. With the merge strategy keyed
    on ``customer_id``, this presumably updates customer 1 in place.
    """
    if ctx.current_cursor_value is not None:
        return [
            {"customer_id": 1, "plan_name": "pro", "load_seq": 2},
            {"customer_id": 3, "plan_name": "enterprise", "load_seq": 2},
        ]
    return [
        {"customer_id": 1, "plan_name": "basic", "load_seq": 1},
        {"customer_id": 2, "plan_name": "plus", "load_seq": 1},
    ]


@loader(
    write_strategy="delete_insert",
    cursor_column="load_seq",
    columns=[
        {"name": "waffle_type", "type": "VARCHAR"},
        {"name": "price_cents", "type": "INTEGER"},
        {"name": "load_seq", "type": "INTEGER"},
    ],
)
def fetch_prices(ctx):
    """Return the price list, one row per waffle type (classic, then blueberry).

    The classic price is 600 cents when ``ctx.current_cursor_value`` is
    ``None`` and 650 cents otherwise. Blueberry is always 750 cents.

    NOTE(review): every row has ``load_seq`` 1 regardless of the cursor, so
    the cursor value never advances past 1. This looks intentional for the
    ``delete_insert`` fixture; confirm.
    """
    is_initial_load = ctx.current_cursor_value is None
    # Insertion order fixes the output row order: classic first, then blueberry.
    price_by_type = {
        "classic": 600 if is_initial_load else 650,
        "blueberry": 750,
    }
    return [
        {"waffle_type": waffle, "price_cents": cents, "load_seq": 1}
        for waffle, cents in price_by_type.items()
    ]


@loader(depends_on=[fetch_order_events, fetch_prices])
def load_raw_orders(ctx):
    """Join order events with their prices and return one dict per joined row.

    Reads the outputs of ``fetch_order_events`` and ``fetch_prices`` through
    ``ctx.loader(...).target`` (presumably the materialized table names).
    It inner-joins them on ``waffle_type``, so an event with no matching price
    row is dropped. Rows are returned ordered by ``order_id``.
    """
    events_ref = ctx.loader(fetch_order_events)
    prices_ref = ctx.loader(fetch_prices)
    sql = (
        "SELECT e.order_id, e.customer_id, e.waffle_type, e.quantity, "
        "p.price_cents, e.load_seq "
        f"FROM {events_ref.target} e "
        f"JOIN {prices_ref.target} p ON e.waffle_type = p.waffle_type "
        "ORDER BY e.order_id"
    )
    # Dict keys follow the SELECT column order one-to-one.
    field_names = (
        "order_id",
        "customer_id",
        "waffle_type",
        "quantity",
        "price_cents",
        "load_seq",
    )
    result = ctx.query(sql)
    return [dict(zip(field_names, record)) for record in result.fetchall()]


@loader(depends_on=[fetch_customers])
def load_raw_customers(ctx):
    """Return every row of the ``fetch_customers`` output as a dict.

    Rows are ordered by ``customer_id``. The table is located through
    ``ctx.loader(fetch_customers).target`` (presumably its table name).
    """
    source = ctx.loader(fetch_customers)
    sql = f"SELECT customer_id, plan_name, load_seq FROM {source.target} ORDER BY customer_id"
    # Dict keys follow the SELECT column order one-to-one.
    field_names = ("customer_id", "plan_name", "load_seq")
    result = ctx.query(sql)
    return [dict(zip(field_names, record)) for record in result.fetchall()]
