#!/usr/bin/env python3
"""
SPY Optimizer — Historical Surge Detector
Scans 1m klines since 2026-01-01 to retroactively detect surges
and simulate the outcome of each trade. Produces a much larger training dataset.

Usage:
  python build_historical_dataset.py                  # Build the full dataset
  python build_historical_dataset.py --quick           # Only 20 symbols (test)
  python build_historical_dataset.py --symbols 50      # Top 50 by volume
"""
import argparse
import json
import sys
import time
from datetime import datetime, timezone
from pathlib import Path

import numpy as np
import pandas as pd

from feature_engineering import compute_features_at_timestamp, FEATURE_COLUMNS
from config import STABLECOINS

# ─── Paths ───
PROJECT_DIR = Path(__file__).parent
DATA_DIR = PROJECT_DIR / "data"
KLINES_DIR = DATA_DIR / "klines_1m"
OUTPUT_FILE = DATA_DIR / "historical_training_dataset.parquet"

# ─── Surge Detection Constants (from market_spy.py) ───
FLASH_MIN_CHANGE = 1.0       # % over 1 scan (~1 min in klines)
BREAKOUT_MIN_CHANGE_2 = 1.1  # % over 2 scans
BREAKOUT_MIN_CHANGE_1 = 0.5  # % over the last scan
MOMENTUM_CHANGE_20 = 4.0     # % over 20 min
MOMENTUM_CHANGE_40 = 6.0     # % over 40 min
MOMENTUM_MIN_LAST = 0.3      # % over the last minute

# ─── Trade Simulation Constants ───
TAKE_PROFIT_PCT = 3.0    # TP %
STOP_LOSS_PCT = -2.0     # SL %
TRAILING_ACTIVATE = 1.5  # Trailing stop arms at +1.5%
TRAILING_DROP = 0.8      # % drop from the peak that triggers the trailing exit
MAX_HOLD_MINUTES = 30    # Max hold time
COOLDOWN_MINUTES = 10    # Per-symbol cooldown between two detections


def detect_surges_in_klines(df: pd.DataFrame, symbol: str) -> list[dict]:
    """
    Detect price surges in a series of 1-minute klines.

    Replays the live detector's rules (market_spy.py) over historical
    candles and classifies each hit as FLASH_SURGE (sharp 1-min spike on
    volume), BREAKOUT_SURGE (2-min move still accelerating) or
    MOMENTUM_SURGE (steady 20-40 min climb). A per-symbol cooldown of
    COOLDOWN_MINUTES suppresses detections right after a previous one.

    Args:
        df: Klines with ``close``, ``open_time`` (epoch ms) and
            ``quote_volume`` columns, one row per minute.
        symbol: Symbol name copied into every detection.

    Returns:
        Chronological list of dicts with keys ``symbol``,
        ``timestamp_ms``, ``surge_type``, ``surge_strength``,
        ``vol_ratio`` and ``index`` (row position of the surge candle).
    """
    if len(df) < 60:
        return []

    close = df["close"].values
    open_time = df["open_time"].values
    volume = df["quote_volume"].values

    surges = []
    last_surge_ts = 0

    # Start at 45 so every lookback used below (up to 40 candles, plus the
    # 30-candle volume window) is fully populated; stop early enough to
    # leave room for simulate_trade() after each detection.
    for i in range(45, len(df) - MAX_HOLD_MINUTES - 5):
        ts = int(open_time[i])

        # Cooldown check (open_time is in milliseconds)
        if ts - last_surge_ts < COOLDOWN_MINUTES * 60 * 1000:
            continue

        price = close[i]
        if price <= 0:
            continue

        # Percent changes over several lookbacks (guard against zero closes).
        # i >= 45 here, so every index below is always valid.
        change_1 = (price / close[i - 1] - 1) * 100 if close[i - 1] > 0 else 0
        change_2 = (price / close[i - 2] - 1) * 100 if close[i - 2] > 0 else 0
        change_20 = (price / close[i - 20] - 1) * 100 if close[i - 20] > 0 else 0
        change_40 = (price / close[i - 40] - 1) * 100 if close[i - 40] > 0 else 0

        # Volume ratio: current candle vs its trailing 30-minute average
        avg_vol_30 = np.mean(volume[i - 30:i])
        vol_ratio = volume[i] / avg_vol_30 if avg_vol_30 > 0 else 1

        surge_type = None
        surge_strength = 0

        # FLASH_SURGE: spike >= 1% in 1 min backed by 2x volume
        if change_1 >= FLASH_MIN_CHANGE and vol_ratio >= 2.0:
            surge_type = "FLASH_SURGE"
            surge_strength = change_1

        # BREAKOUT_SURGE: >= 1.1% in 2 min, still accelerating
        elif change_2 >= BREAKOUT_MIN_CHANGE_2 and change_1 >= BREAKOUT_MIN_CHANGE_1 and vol_ratio >= 1.5:
            surge_type = "BREAKOUT_SURGE"
            surge_strength = change_2

        # MOMENTUM_SURGE: gradual rise over 20-40 min, still moving up
        elif change_1 >= MOMENTUM_MIN_LAST:
            if change_20 >= MOMENTUM_CHANGE_20 or change_40 >= MOMENTUM_CHANGE_40:
                # Monotonicity check: at least 3 of 4 quartiles rising
                segment = close[i - 20:i + 1]
                q_size = len(segment) // 4
                if q_size >= 2:
                    rising = sum(
                        1 for q in range(4)
                        if segment[min((q + 1) * q_size, len(segment) - 1)] > segment[q * q_size]
                    )
                    if rising >= 3:
                        surge_type = "MOMENTUM_SURGE"
                        surge_strength = change_20

        if surge_type:
            surges.append({
                "symbol": symbol,
                "timestamp_ms": ts,
                "surge_type": surge_type,
                "surge_strength": surge_strength,
                "vol_ratio": vol_ratio,
                "index": i,
            })
            last_surge_ts = ts

    return surges


def simulate_trade(df: pd.DataFrame, entry_idx: int, entry_price: float) -> dict:
    """
    Replay a trade opened right after a surge detection.

    Walks forward candle by candle, applying exit rules in this order:
    stop-loss (candle low), take-profit (candle high), trailing stop
    (armed once the high reaches TRAILING_ACTIVATE, fired when the close
    falls TRAILING_DROP below the best high seen), then a hard time
    limit of MAX_HOLD_MINUTES.

    Args:
        df: Klines with ``close``, ``high`` and ``low`` columns.
        entry_idx: Row index of the surge candle; entry happens there.
        entry_price: Fill price used for every PnL computation.

    Returns:
        Dict with ``exit_price``, ``pnl_pct``, ``pnl_usdt`` (normalized:
        1 unit == 1 %, not actual USDT), ``max_pnl`` (best intra-candle
        gain, floored at 0), ``hold_minutes`` and ``exit_reason``.
    """
    closes = df["close"].values
    highs = df["high"].values
    lows = df["low"].values
    n = len(df)

    best_gain = 0          # highest intra-candle PnL (%) seen so far
    trailing_armed = False
    reason = "MAX_HOLD"
    fill = entry_price
    held = MAX_HOLD_MINUTES

    for minute in range(1, MAX_HOLD_MINUTES + 1):
        pos = entry_idx + minute

        # Ran out of klines before any exit rule fired.
        if pos >= n:
            fill = closes[min(pos - 1, n - 1)]
            held = minute
            reason = "END_OF_DATA"
            break

        gain_high = (highs[pos] / entry_price - 1) * 100
        gain_low = (lows[pos] / entry_price - 1) * 100
        gain_close = (closes[pos] / entry_price - 1) * 100

        best_gain = max(best_gain, gain_high)

        # Stop loss first: assume the worst path through the candle.
        if gain_low <= STOP_LOSS_PCT:
            fill = entry_price * (1 + STOP_LOSS_PCT / 100)
            reason = "STOP_LOSS"
            held = minute
            break

        # Take profit on the candle high.
        if gain_high >= TAKE_PROFIT_PCT:
            fill = entry_price * (1 + TAKE_PROFIT_PCT / 100)
            reason = "TAKE_PROFIT"
            held = minute
            break

        # Arm the trailing stop once the position has run far enough.
        if gain_high >= TRAILING_ACTIVATE:
            trailing_armed = True

        # Trailing exit on the candle close once armed.
        if trailing_armed and best_gain - gain_close >= TRAILING_DROP:
            fill = closes[pos]
            reason = "TRAILING"
            held = minute
            break

        # Time limit reached without any other exit: close at market.
        if minute == MAX_HOLD_MINUTES:
            fill = closes[pos]
            reason = "MAX_HOLD"
            held = minute

    pct = (fill / entry_price - 1) * 100

    return {
        "exit_price": float(fill),
        "pnl_pct": float(pct),
        "pnl_usdt": float(pct),  # normalized, not actual USDT
        "max_pnl": float(best_gain),
        "hold_minutes": held,
        "exit_reason": reason,
    }


def build_historical_dataset(max_symbols: int = 0, quick: bool = False) -> pd.DataFrame:
    """
    Build a massive training dataset by detecting surges across the whole
    klines history and simulating each resulting trade.

    For every USDC/USDT symbol file in KLINES_DIR (stablecoins excluded),
    detects historical surges, simulates the trade with the bot's
    TP/SL/trailing rules, computes the feature vector at the surge
    timestamp, and collects one labeled row per surge.

    Args:
        max_symbols: Keep only the first N symbol files (0 = no limit).
        quick: Test mode — keep only the first 20 symbol files.

    Returns:
        DataFrame with one row per detected surge: feature columns plus
        metadata (symbol, surge_type, surge_strength, entry_time) and
        targets (target_profitable, target_pnl_pct, target_max_pnl,
        target_exit_reason). Empty DataFrame if no klines or no surges.
    """
    parquet_files = sorted(KLINES_DIR.glob("*.parquet"))
    if not parquet_files:
        print("  ❌ No klines data found!")
        return pd.DataFrame()

    # Filter: only USDC+USDT quote pairs, skip stablecoins (file stem == symbol)
    valid_files = []
    for f in parquet_files:
        symbol = f.stem
        if symbol in STABLECOINS:
            continue
        if not (symbol.endswith("USDC") or symbol.endswith("USDT")):
            continue
        valid_files.append(f)

    # NOTE(review): --quick is applied after --symbols, so it wins when both are set.
    if max_symbols > 0:
        valid_files = valid_files[:max_symbols]
    if quick:
        valid_files = valid_files[:20]

    print(f"  📂 Processing {len(valid_files)} symbols...")

    all_rows = []
    total_surges = 0
    total_profitable = 0

    t0 = time.time()

    for file_idx, parquet_file in enumerate(valid_files):
        symbol = parquet_file.stem
        df = pd.read_parquet(parquet_file)

        # Need enough history for the 120-minute feature lookback used below.
        if len(df) < 120:
            continue

        # Detect surges
        surges = detect_surges_in_klines(df, symbol)

        for surge in surges:
            idx = surge["index"]
            entry_price = df["close"].values[idx]
            ts_ms = surge["timestamp_ms"]

            # Simulate trade with the bot's TP/SL/trailing rules
            trade_result = simulate_trade(df, idx, entry_price)

            # Compute features at surge moment; None presumably means not
            # enough context for this timestamp — skip the sample.
            features = compute_features_at_timestamp(df, ts_ms, lookback_minutes=120)
            if features is None:
                continue

            # Add metadata
            features["symbol"] = symbol
            features["surge_type"] = surge["surge_type"]
            features["surge_strength"] = surge["surge_strength"]
            entry_dt = pd.Timestamp(ts_ms, unit="ms", tz="UTC")
            features["entry_time"] = entry_dt.isoformat()

            # Add supervised-learning targets from the simulated trade
            features["target_profitable"] = 1 if trade_result["pnl_usdt"] > 0 else 0
            features["target_pnl_pct"] = trade_result["pnl_pct"]
            features["target_max_pnl"] = trade_result["max_pnl"]
            features["target_exit_reason"] = trade_result["exit_reason"]

            all_rows.append(features)
            total_surges += 1
            if trade_result["pnl_usdt"] > 0:
                total_profitable += 1

        # Progress log every 20 files and on the last one
        if (file_idx + 1) % 20 == 0 or file_idx == len(valid_files) - 1:
            elapsed = time.time() - t0
            rate = (file_idx + 1) / elapsed
            eta = (len(valid_files) - file_idx - 1) / rate if rate > 0 else 0
            print(f"    [{file_idx + 1}/{len(valid_files)}] {symbol:20s} | "
                  f"Surges: {total_surges:5d} | "
                  f"Profitable: {total_profitable}/{total_surges} "
                  f"({total_profitable / max(total_surges, 1) * 100:.1f}%) | "
                  f"ETA: {eta:.0f}s")

    elapsed = time.time() - t0

    if not all_rows:
        print("  ❌ No surges detected!")
        return pd.DataFrame()

    dataset = pd.DataFrame(all_rows)

    # Stats (total_surges >= 1 here, so the division below is safe)
    print(f"\n{'═' * 60}")
    print(f"  📊 Historical Dataset Built")
    print(f"{'═' * 60}")
    print(f"  Symbols:     {len(valid_files)}")
    print(f"  Surges:      {total_surges}")
    print(f"  Profitable:  {total_profitable} ({total_profitable / total_surges * 100:.1f}%)")
    print(f"  Time:        {elapsed:.0f}s")

    # Win rate and average PnL per surge type
    for st in ["FLASH_SURGE", "BREAKOUT_SURGE", "MOMENTUM_SURGE"]:
        subset = dataset[dataset["surge_type"] == st]
        if len(subset) > 0:
            wr = subset["target_profitable"].mean() * 100
            avg_pnl = subset["target_pnl_pct"].mean()
            print(f"  {st:20s}: {len(subset):5d} surges | WR: {wr:.1f}% | Avg PnL: {avg_pnl:+.2f}%")

    # By exit reason
    print(f"\n  Exit reasons:")
    for reason in dataset["target_exit_reason"].value_counts().items():
        print(f"    {reason[0]:15s}: {reason[1]:5d}")

    # Date range: entry_time is ISO-8601 UTC, so lexicographic sort is chronological
    dates = sorted(dataset["entry_time"].values)
    print(f"\n  Period: {str(dates[0])[:10]} → {str(dates[-1])[:10]}")

    return dataset


def main():
    """CLI entry point: build the historical dataset, save it, and
    optionally merge it with the real-trade dataset if one exists."""
    parser = argparse.ArgumentParser(description="Build historical training dataset")
    parser.add_argument("--quick", action="store_true", help="Quick test with 20 symbols")
    parser.add_argument("--symbols", type=int, default=0, help="Max number of symbols (0=all)")
    args = parser.parse_args()

    print(f"\n{'═' * 60}")
    print(f"  🔍 Historical Surge Detection & Trade Simulation")
    print(f"{'═' * 60}")

    dataset = build_historical_dataset(max_symbols=args.symbols, quick=args.quick)

    # Nothing detected (or no klines) -> nothing to save.
    if len(dataset) == 0:
        return

    # Save
    OUTPUT_FILE.parent.mkdir(parents=True, exist_ok=True)
    dataset.to_parquet(OUTPUT_FILE, index=False)
    print(f"\n  💾 Saved: {OUTPUT_FILE} ({len(dataset)} samples)")
    print(f"     Size: {OUTPUT_FILE.stat().st_size / 1e6:.1f} MB")

    # Also save a combined dataset (historical + real trades), when available
    real_dataset_path = DATA_DIR / "training_dataset.parquet"
    if real_dataset_path.exists():
        real = pd.read_parquet(real_dataset_path)
        # Mark the provenance of each row before concatenating
        dataset["source"] = "historical"
        real["source"] = "real"
        combined = pd.concat([dataset, real], ignore_index=True)
        combined = combined.sort_values("entry_time").reset_index(drop=True)
        combined_path = DATA_DIR / "combined_training_dataset.parquet"
        combined.to_parquet(combined_path, index=False)
        print(f"  💾 Combined: {combined_path} ({len(combined)} samples)")
        print(f"     Historical: {len(dataset)} | Real: {len(real)}")

# Guard script entry so the module can be imported without side effects.
if __name__ == "__main__":
    main()