#!/usr/bin/env python3
#
# Compute incremental P&L from fills
#

import argparse
import json
import sys
from typing import List, Dict, Any
from dataclasses import dataclass
from collections import defaultdict

import pandas as pd

# Project-local imports from the bulk_api package
from bulk_api.common.comparisons import *
from bulk_api.common.inventory import Inventory, Position, Pnl

def parse_fills(json_file: str) -> List[Dict[str, Any]]:
    """Load the fills stored in a JSON file.

    The file must contain a JSON array whose elements each carry a
    ``'fills'`` entry. That entry may be a single fill dict or a list of
    fill dicts; lists are flattened so the result always holds one dict per
    fill, in file order.

    Args:
        json_file: Path to the JSON file (read as UTF-8).

    Returns:
        The fills, in file order.

    Exits:
        Prints an error to stderr and exits with status 1 if the file cannot
        be opened, is not valid JSON, or does not have the structure above.
    """
    try:
        with open(json_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
        fills: List[Dict[str, Any]] = []
        for item in data:
            entry = item['fills']
            # Accept several fills grouped under one key as well as a single
            # fill dict; without flattening, a list would end up as one
            # "fill" and break callers that index it like a dict.
            if isinstance(entry, list):
                fills.extend(entry)
            else:
                fills.append(entry)
        return fills
    except (OSError, ValueError, KeyError, TypeError) as e:
        # OSError: missing/unreadable file. ValueError: covers
        # json.JSONDecodeError and UnicodeDecodeError. KeyError/TypeError:
        # the decoded JSON does not have the expected array-of-objects shape.
        print(f"Error reading JSON file: {e}", file=sys.stderr)
        sys.exit(1)


def calculate_incremental_pnl(fills: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Replay fills through an ``Inventory`` and snapshot P&L after each one.

    Args:
        fills: Fill dicts in processing order. Each must have the keys
            ``'isBuy'``, ``'amount'``, ``'price'``, ``'symbol'`` and
            ``'timestamp'`` (nanoseconds since the Unix epoch, converted with
            ``pd.Timestamp(..., unit='ns')``).

    Returns:
        One dict per fill, in the same order, holding the fill's details, the
        realized / unrealized / total P&L reported by ``Inventory.pv`` after
        applying that fill, and the resulting quantity and VWAP for the
        fill's symbol.

    Raises:
        KeyError: if a fill is missing one of the keys listed above.
    """
    inventory = Inventory()
    # Most recent fill price per symbol. The whole map is handed to
    # Inventory.pv so that, when several symbols are traded, every open
    # position gets a mark -- not only the symbol that just traded.
    last_prices: Dict[str, Any] = {}
    results: List[Dict[str, Any]] = []

    for i, fill in enumerate(fills):
        is_buy = fill['isBuy']
        side = 1 if is_buy else -1
        amount = fill['amount']
        price = fill['price']
        symbol = fill['symbol']
        timestamp = pd.Timestamp(fill['timestamp'], unit='ns')

        inventory.traded(symbol, side, amount, price)

        last_prices[symbol] = price
        # Hand pv a fresh dict on every call so the running map is never
        # shared with it.
        pnl = inventory.pv(dict(last_prices))

        results.append({
            'fill_index': i,
            'timestamp': timestamp,
            'symbol': symbol,
            'amount': amount,
            'price': price,
            'side': 'Buy' if is_buy else 'Sell',
            'realized_pnl': pnl.realized,
            'unrealized_pnl': pnl.unrealized,
            'total_pnl': pnl.realized + pnl.unrealized,
            'position': {
                'quantity': inventory.quantity_for(symbol),
                'vwap': inventory.position_for(symbol).vwap,
            },
        })

    return results


def print_results(results: List[Dict[str, Any]], show_all: bool = True):
    """Print P&L results as built by ``calculate_incremental_pnl``.

    Args:
        results: Per-fill result dicts in fill order (keys ``timestamp``,
            ``symbol``, ``side``, ``amount``, ``price``, ``realized_pnl``,
            ``unrealized_pnl``, ``total_pnl`` and ``position`` with
            ``quantity``/``vwap``). ``timestamp`` must support ``strftime``.
        show_all: True prints one table row per fill; False prints only the
            P&L, time and position after the last fill.
    """
    print("Incremental P&L Calculation:")
    print("=" * 125)

    if show_all:
        print(f"{'Time':<27} {'Symbol':<10} {'Side':<5} {'Amount':<12} {'Price':<10} {'Realized':>12} {'Unrealized':>12} {'Total':>12} {'Position':>13}")
        print("-" * 125)

        for result in results:
            timestamp_str = result['timestamp'].strftime('%Y-%m-%d %H:%M:%S.%f')
            position_quantity = result['position']['quantity']
            # Total is printed with the same 4-decimal precision as its two
            # components so the displayed columns add up.
            print(f"{timestamp_str:<27} {result['symbol']:<10} {result['side']:<5} {result['amount']:<12.8f} {result['price']:<10} "
                  f"{result['realized_pnl']:>12.4f} {result['unrealized_pnl']:>12.4f} {result['total_pnl']:>12.4f} {position_quantity:>13.8f}")
        return

    # Final-only mode needs a last result; report cleanly instead of raising
    # IndexError on results[-1].
    if not results:
        print("No results to display.")
        return

    final_result = results[-1]
    timestamp_str = final_result['timestamp'].strftime('%Y-%m-%d %H:%M:%S.%f')
    print(f"Final P&L: {final_result['total_pnl']:.6f}")
    print(f"Time: {timestamp_str}")
    print(f"Final Position: {final_result['position']['quantity']:.8f} @ {final_result['position']['vwap']:.2f}")


def main(argv=None):
    """Command-line entry point: load fills, print P&L, optionally a summary.

    Args:
        argv: Argument list to parse instead of the process command line.
            The default ``None`` makes argparse read ``sys.argv[1:]``, so
            calling ``main()`` with no arguments behaves like the script.

    Exits:
        With status 1 (message on stderr) when the JSON file yields no fills.
    """
    parser = argparse.ArgumentParser(description='Calculate incremental and final P&L from fills')
    parser.add_argument('json_file', help='Path to JSON file containing fills')
    parser.add_argument('--final-only', action='store_true', help='Show only final P&L')
    parser.add_argument('--summary', action='store_true', help='Show position summary')

    args = parser.parse_args(argv)

    fills = parse_fills(args.json_file)

    if not fills:
        # Diagnostics go to stderr so they don't mix with report output.
        print("No fills found in JSON file", file=sys.stderr)
        sys.exit(1)

    print(f"Processing {len(fills)} fills...")

    results = calculate_incremental_pnl(fills)

    # --final-only switches print_results to its last-fill-only mode.
    print_results(results, not args.final_only)

    if args.summary:
        print("\nPosition Summary:")
        print("=" * 40)
        # Replay every fill into a fresh inventory to obtain the end-of-run
        # positions reported by Inventory.summary().
        inventory = Inventory()
        for fill in fills:
            side = 1 if fill['isBuy'] else -1
            inventory.traded(fill['symbol'], side, fill['amount'], fill['price'])

        for pos in inventory.summary():
            print(f"{pos['id']}: {pos['amount']:.8f} @ {pos['price']:.2f} (PV: {pos['pv']:.6f})")

# Run the command-line entry point only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()
