"""
🧪 SPY BACKTESTER - Simule différents paramètres sur les trades récents
============================================================================
Télécharge les klines 1m réels de Binance pour chaque trade historique
et simule la stratégie de sortie avec différents jeux de paramètres.
"""

import json
import os
import time
import requests
import numpy as np
from datetime import datetime, timedelta
from itertools import product

# Directory containing this script; the data files below are resolved relative to it.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
# Trade history, loaded as JSON by main().
HISTORY_FILE = os.path.join(SCRIPT_DIR, "espion_history.json")
# On-disk kline cache, read and written by load_or_fetch_klines().
KLINES_CACHE = os.path.join(SCRIPT_DIR, "spy_backtest_cache.json")

# ═══════════════════════════════════════════════════════════════════════
# PARAM GRID - candidate values explored for each exit parameter
# ═══════════════════════════════════════════════════════════════════════

PARAM_GRID = dict(
    # Base trailing distance (used while peak PnL is below 2%)
    trail_base=[0.5, 0.8, 1.0, 1.2, 1.5],
    # Wider trailing distance (peak PnL between 2% and 5%)
    trail_wide=[1.5, 2.0, 2.5, 3.0],
    # Large trailing distance (peak PnL between 5% and 10%)
    trail_large=[2.5, 3.0, 4.0],
    # Hard stop loss, in percent below entry
    hard_sl=[1.0, 1.5, 2.0, 2.5],
    # Peak PnL (percent) required before the trailing stop is armed
    trail_activation=[0.5, 0.8, 1.0],
    # Momentum exit: number of consecutive lower closes
    momentum_candles=[3, 4, 5],
    # Maximum holding time, in minutes
    max_hold=[30, 60, 120, 240],
)

# Parameters CURRENTLY in use, kept as the comparison baseline
CURRENT_PARAMS = dict(
    trail_base=1.0,
    trail_wide=2.5,
    trail_large=3.0,
    hard_sl=1.5,
    trail_activation=1.0,
    momentum_candles=4,
    max_hold=120,
)

# PREVIOUS parameters (before the recent changes), also kept for comparison
OLD_PARAMS = dict(
    trail_base=0.5,
    trail_wide=1.5,
    trail_large=2.5,
    hard_sl=1.5,
    trail_activation=0.8,
    momentum_candles=3,
    max_hold=45,
)


def fetch_klines_for_trade(symbol, entry_time_str, lookforward_minutes=300):
    """Download 1-minute klines covering a trade, starting just before its entry.

    The window opens one minute before the entry timestamp and spans
    ``lookforward_minutes`` minutes. It is fetched in pages of up to 500
    candles (the ``limit`` sent with each request).

    Args:
        symbol: Value sent as the ``symbol`` query parameter.
        entry_time_str: Entry time in a format ``datetime.fromisoformat`` accepts.
            NOTE(review): a naive timestamp is converted with the machine's
            local timezone by ``timestamp()`` - confirm this matches how the
            history file was written.
        lookforward_minutes: Length of the requested window, in minutes.

    Returns:
        The raw kline rows in the order received. The list may be partial if
        a page comes back with a non-200 status, and is ``[]`` if any
        exception is raised while fetching.
    """
    try:
        entry_dt = datetime.fromisoformat(entry_time_str)
        start_ms = int(entry_dt.timestamp() * 1000) - 60000  # one minute before entry
        end_ms = start_ms + (lookforward_minutes * 60 * 1000)

        url = "https://api.binance.com/api/v3/klines"
        all_klines = []
        current_start = start_ms

        while current_start < end_ms:
            params = {
                "symbol": symbol,
                "interval": "1m",
                "startTime": current_start,
                "endTime": min(current_start + 500 * 60000, end_ms),
                "limit": 500
            }
            r = requests.get(url, params=params, timeout=10)
            if r.status_code != 200:
                # Surface the failure instead of silently returning partial data.
                print(f"  ⚠️ HTTP {r.status_code} pour {symbol} — téléchargement interrompu")
                break
            klines = r.json()
            if not klines:
                break
            all_klines.extend(klines)

            # Resume right after the last candle received.
            next_start = klines[-1][0] + 60000
            if next_start <= current_start:
                # Defensive: never spin if the cursor fails to advance.
                break
            current_start = next_start
            time.sleep(0.15)  # Rate limit

        return all_klines
    except Exception as e:
        # Best-effort download: report the problem and let the caller skip this trade.
        print(f"  ❌ Erreur fetch {symbol}: {e}")
        return []


def load_or_fetch_klines(trades):
    """Return the kline cache, downloading klines for trades not yet cached.

    The cache is a dict keyed by ``"<symbol>_<entry_time>"`` and stored as JSON
    in ``KLINES_CACHE``. An unreadable or malformed cache file is ignored and
    rebuilt from scratch. The file is rewritten only when at least one trade
    was newly downloaded.

    Args:
        trades: Iterable of trade dicts providing ``'symbol'`` and
            ``'entry_time'`` (a string, sliced for display).

    Returns:
        The cache dict, including entries loaded from disk for trades that
        are not in ``trades``.
    """
    cache = {}
    if os.path.exists(KLINES_CACHE):
        try:
            with open(KLINES_CACHE, 'r', encoding='utf-8') as f:
                loaded = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers JSONDecodeError and UnicodeDecodeError. A narrow
            # except keeps Ctrl-C and SystemExit working, unlike a bare except.
            print(f"  ⚠️ Cache illisible, ignoré: {e}")
        else:
            # Anything other than a dict cannot serve as the cache.
            if isinstance(loaded, dict):
                cache = loaded

    updated = False
    for trade in trades:
        key = f"{trade['symbol']}_{trade['entry_time']}"
        if key not in cache:
            print(f"  📡 Téléchargement klines 1m {trade['symbol']} ({trade['entry_time'][:16]})...")
            klines = fetch_klines_for_trade(trade['symbol'], trade['entry_time'])
            if klines:
                cache[key] = klines
                updated = True
                print(f"     ✅ {len(klines)} klines récupérées")
            else:
                # Failed downloads are not cached, so they are retried on the next run.
                print(f"     ⚠️ Pas de klines")

    if updated:
        with open(KLINES_CACHE, 'w', encoding='utf-8') as f:
            json.dump(cache, f)
        print(f"  💾 Cache sauvegardé ({len(cache)} trades)")

    return cache


def simulate_exit(klines, entry_price, entry_time_str, params):
    """Replay the exit strategy over 1-minute klines and report how the trade ends.

    Candles whose open time is before the entry timestamp are skipped. For
    each remaining candle the rules are checked in this order, and the first
    one that fires ends the simulation:

    1. HARD_SL - the candle low reaches ``-params['hard_sl']`` percent.
    2. TRAILING - once peak PnL has reached ``params['trail_activation']``,
       exit when the close has fallen from the peak by the trailing distance
       selected from the peak PnL tier.
    3. MAX_HOLD - ``params['max_hold']`` minutes have elapsed and the close
       PnL is below 3%.
    4. MOMENTUM_EXIT - ``params['momentum_candles']`` consecutive lower
       closes and the close PnL is below 5%.

    If no rule fires, the trade is closed at the last simulated candle
    (END_OF_DATA).

    Args:
        klines: Rows indexed as ``[open_time_ms, open, high, low, close, ...]``;
            prices may be strings or numbers. Assumed sorted by open time.
        entry_price: Entry price as a non-zero number.
        entry_time_str: Entry time accepted by ``datetime.fromisoformat``.
        params: Strategy parameters (``trail_base``, ``trail_wide``,
            ``trail_large``, ``hard_sl``, ``trail_activation``,
            ``momentum_candles``, ``max_hold``; optional ``trail_ultra``).

    Returns:
        Tuple ``(pnl_pct, max_pnl, hold_minutes, exit_reason)``.
    """
    entry_dt = datetime.fromisoformat(entry_time_str)
    entry_ms = int(entry_dt.timestamp() * 1000)

    max_pnl = 0.0
    trailing_activated = False
    consecutive_drops = 0
    prev_close = entry_price
    # State of the last candle actually simulated, used by the END_OF_DATA exit.
    last_close = entry_price
    last_elapsed = 0.0

    for kline in klines:
        kline_time_ms = kline[0]
        if kline_time_ms < entry_ms:
            continue

        close = float(kline[4])
        high = float(kline[2])
        low = float(kline[3])

        elapsed_min = (kline_time_ms - entry_ms) / 60000

        # PnL in percent at the candle's close, high and low.
        pnl_close = ((close - entry_price) / entry_price) * 100
        pnl_high = ((high - entry_price) / entry_price) * 100
        pnl_low = ((low - entry_price) / entry_price) * 100

        # The peak PnL is driven by candle highs.
        max_pnl = max(max_pnl, pnl_high)

        # -- RULE 1: HARD STOP LOSS --
        if pnl_low <= -params['hard_sl']:
            # Booked at the better of the stop level and the candle close.
            # NOTE(review): this is optimistic if the live stop fills at the
            # stop price intrabar - confirm against the live exit logic.
            actual_exit = max(-params['hard_sl'], pnl_close)
            return actual_exit, max_pnl, elapsed_min, "HARD_SL"

        # -- RULE 2: TRAILING STOP (progressive) --
        if max_pnl >= params['trail_activation']:
            trailing_activated = True

        if trailing_activated and max_pnl > 0:
            # The trailing distance widens as the peak PnL grows.
            if max_pnl >= 15.0:
                trail_pct = 5.0
            elif max_pnl >= 10.0:
                trail_pct = params.get('trail_ultra', 4.0)
            elif max_pnl >= 5.0:
                trail_pct = params['trail_large']
            elif max_pnl >= 2.0:
                trail_pct = params['trail_wide']
            else:
                trail_pct = params['trail_base']

            drop_from_max = max_pnl - pnl_close
            if drop_from_max >= trail_pct:
                exit_pnl = max_pnl - trail_pct
                return exit_pnl, max_pnl, elapsed_min, "TRAILING"

        # -- RULE 3: MAX HOLD --
        if elapsed_min >= params['max_hold']:
            # Only sell below 3% PnL; above that the trailing stop manages the exit.
            if pnl_close < 3.0:
                return pnl_close, max_pnl, elapsed_min, "MAX_HOLD"

        # -- RULE 4: MOMENTUM EXIT --
        if close < prev_close:
            consecutive_drops += 1
        else:
            consecutive_drops = 0

        if consecutive_drops >= params['momentum_candles']:
            # Only sell below 5% PnL.
            if pnl_close < 5.0:
                return pnl_close, max_pnl, elapsed_min, "MOMENTUM_EXIT"

        prev_close = close
        last_close = close
        last_elapsed = elapsed_min

    # Out of data: close at the last simulated candle. Hold time is the elapsed
    # minutes of that candle, not the row count, which would include skipped
    # pre-entry candles. With no post-entry candle the trade is flat at entry.
    final_pnl = ((last_close - entry_price) / entry_price) * 100
    return final_pnl, max_pnl, last_elapsed, "END_OF_DATA"


def run_backtest(trades, klines_cache, params, label=""):
    """Simulate every trade that has cached klines with one parameter set.

    Trades whose cache entry is missing or empty are skipped. ``label`` is
    accepted but not used.

    Returns:
        One dict per simulated trade, holding the simulated exit, the peak
        PnL, the hold time and the originally recorded PnL.
    """
    report = []

    for t in trades:
        sym, opened_at = t['symbol'], t['entry_time']
        candles = klines_cache.get(f"{sym}_{opened_at}", [])
        if candles:
            exit_pnl, peak_pnl, held, why = simulate_exit(
                candles, t['entry_price'], opened_at, params
            )
            row = {
                'symbol': sym,
                'entry_time': opened_at,
                'surge_strength': t.get('surge_strength', 0),
                'pnl_pct': round(exit_pnl, 2),
                'max_pnl': round(peak_pnl, 2),
                'hold_min': round(held, 1),
                'exit_reason': why,
                # Gap between the best PnL seen and the PnL actually realised.
                'missed_gain': round(peak_pnl - exit_pnl, 2),
                'original_pnl': t.get('pnl_pct', 0),
            }
            report.append(row)

    return report


def score_results(results):
    """Aggregate per-trade results into summary statistics and a composite score.

    The score is ``2 * total_pnl + 5 * win_ratio - 0.5 * avg_missed
    + 0.3 * worst_trade``. The worst trade is normally negative, so the last
    term acts as a penalty.

    Args:
        results: List of dicts with at least ``'pnl_pct'`` and ``'missed_gain'``.

    Returns:
        Dict with ``total_pnl``, ``avg_pnl``, ``win_rate`` (percent), ``wins``,
        ``losses``, ``max_dd``, ``avg_missed``, ``score`` and ``n_trades``.
        ``max_dd`` is the single worst trade PnL, not a cumulative drawdown.
        For empty input every statistic is 0 and ``score`` is -999; the full
        key set is still returned so callers can format it without KeyError.
    """
    if not results:
        return {
            'total_pnl': 0,
            'avg_pnl': 0,
            'win_rate': 0,
            'wins': 0,
            'losses': 0,
            'max_dd': 0,
            'avg_missed': 0,
            'score': -999,
            'n_trades': 0,
        }

    pnls = [r['pnl_pct'] for r in results]
    n_trades = len(pnls)
    total_pnl = sum(pnls)
    # A break-even trade (PnL == 0) is counted as a loss.
    wins = sum(1 for p in pnls if p > 0)
    losses = n_trades - wins
    win_ratio = wins / n_trades
    avg_pnl = np.mean(pnls)
    max_drawdown = min(pnls)
    missed = np.mean([r['missed_gain'] for r in results])

    # Composite score:
    # + total PnL (main driver)
    # + win-rate bonus
    # - penalty for missed gains
    # - penalty for the worst trade
    score = (
        total_pnl * 2.0          # total PnL counts double
        + win_ratio * 5          # win-rate bonus
        - missed * 0.5           # missed-gain penalty
        + max_drawdown * 0.3     # worst-trade penalty (value is negative)
    )

    return {
        'total_pnl': round(total_pnl, 2),
        'avg_pnl': round(avg_pnl, 2),
        'win_rate': round(win_ratio * 100, 1),
        'wins': wins,
        'losses': losses,
        'max_dd': round(max_drawdown, 2),
        'avg_missed': round(missed, 2),
        'score': round(score, 2),
        'n_trades': n_trades,
    }


def smart_grid_search(trades, klines_cache):
    """Phase 1: sweep each parameter on its own around the current settings.

    Instead of testing every combination, each parameter in ``PARAM_GRID`` is
    varied in turn while all the others stay at their ``CURRENT_PARAMS``
    value. The best-scoring value of each sweep is kept.

    Returns:
        Tuple ``(optimized, param_impact)``. ``optimized`` maps each parameter
        to its best value. ``param_impact`` maps each parameter to
        ``{'best_val': ..., 'best_score': ...}``.
    """
    banner = "═" * 70
    print("\n" + banner)
    print("🔬 PHASE 1: Test des paramètres individuels (ancrage = params actuels)")
    print(banner)

    anchor = CURRENT_PARAMS.copy()
    param_impact = {}

    # Sweep one parameter at a time, leaving the rest at the anchor values.
    for param_name, candidates in PARAM_GRID.items():
        print(f"\n📊 Test de '{param_name}' ({len(candidates)} valeurs)...")
        top_score, top_val = -999, anchor[param_name]

        for candidate in candidates:
            trial = {**anchor, param_name: candidate}
            stats = score_results(run_backtest(trades, klines_cache, trial))

            improved = stats['score'] > top_score
            marker = " 🏆" if improved else ""
            print(f"   {param_name}={candidate:>5} → PnL: {stats['total_pnl']:+.2f}% | "
                  f"WR: {stats['win_rate']:>5.1f}% | Missed: {stats['avg_missed']:.2f}% | "
                  f"Score: {stats['score']:>6.2f}{marker}")

            if improved:
                top_score, top_val = stats['score'], candidate

        param_impact[param_name] = {'best_val': top_val, 'best_score': top_score}
        print(f"   ✅ Meilleur: {param_name}={top_val}")

    # Assemble the per-parameter winners into one parameter set.
    optimized = {name: impact['best_val'] for name, impact in param_impact.items()}

    return optimized, param_impact


def targeted_grid_search(trades, klines_cache, base_params):
    """Phase 2: combinatorial search around the best individual parameters.

    Only ``trail_base``, ``trail_wide`` and ``hard_sl`` vary. Each takes three
    values - one step below (clamped to a floor), the base value, and one
    step above - rounded to one decimal, giving 27 combinations. All other
    parameters keep their ``base_params`` value.

    Args:
        trades: Trades passed through to ``run_backtest``.
        klines_cache: Kline cache passed through to ``run_backtest``.
        base_params: Starting parameter set; not mutated.

    Returns:
        Tuple ``(best_combo, best_stats)`` for the highest-scoring combination.
        Both are always set: the first combination wins by default, even when
        no trade could be simulated.
    """
    print("\n" + "═" * 70)
    print("🎯 PHASE 2: Recherche combinatoire ciblée (top 3 paramètres)")
    print("═" * 70)

    tb0 = base_params['trail_base']
    tw0 = base_params['trail_wide']
    sl0 = base_params['hard_sl']

    # Round the candidates up front so the printed values are exactly the
    # values tested (avoids output such as 1.2999999999999998).
    key_params = {
        'trail_base': [round(v, 1) for v in (max(0.3, tb0 - 0.3), tb0, tb0 + 0.3)],
        'trail_wide': [round(v, 1) for v in (max(1.0, tw0 - 0.5), tw0, tw0 + 0.5)],
        'hard_sl': [round(v, 1) for v in (max(0.8, sl0 - 0.5), sl0, sl0 + 0.5)],
    }

    # -inf (not the -999 "no data" sentinel) guarantees the first combination
    # is recorded, so best_combo and best_stats are never left unset.
    best_score = float('-inf')
    best_combo = None
    best_stats = None
    combos = list(product(key_params['trail_base'], key_params['trail_wide'], key_params['hard_sl']))

    print(f"   Testing {len(combos)} combinaisons...\n")

    for tb, tw, sl in combos:
        test_params = base_params.copy()
        test_params['trail_base'] = tb
        test_params['trail_wide'] = tw
        test_params['hard_sl'] = sl

        results = run_backtest(trades, klines_cache, test_params)
        stats = score_results(results)

        if stats['score'] > best_score:
            best_score = stats['score']
            best_combo = test_params.copy()
            best_stats = stats
            print(f"   🏆 NEW BEST: trail={tb}/{tw} sl={sl} → PnL:{stats['total_pnl']:+.2f}% WR:{stats['win_rate']:.0f}% Score:{stats['score']:.2f}")

    return best_combo, best_stats


def main():
    """Run the full backtest and parameter optimisation, then save a report.

    Steps: load the trade history from HISTORY_FILE, load or download the
    klines, benchmark OLD_PARAMS and CURRENT_PARAMS, run the two search
    phases, print a three-way comparison and a per-trade breakdown, and write
    everything to spy_backtest_results.json next to this script.

    Returns:
        Tuple ``(best_params, final_stats)`` for the optimised configuration.
    """
    print("=" * 70)
    print("🧪 SPY BACKTESTER - Optimisation des paramètres d'exit")
    print("=" * 70)
    
    # 1. Load the trade history
    with open(HISTORY_FILE, 'r', encoding='utf-8') as f:
        trades = json.load(f)
    
    print(f"\n📋 {len(trades)} trades dans l'historique")
    
    # 2. Download / load the klines
    print("\n📡 Chargement des données de prix 1m (Binance)...")
    klines_cache = load_or_fetch_klines(trades)
    # NOTE(review): len(klines_cache) counts every cached entry, which may
    # include trades that are no longer in the history file.
    print(f"   ✅ Klines disponibles pour {len(klines_cache)} trades")
    
    # 3. Benchmark: OLD parameters (before the changes)
    print("\n" + "═" * 70)
    print("📦 BENCHMARK: Paramètres ANCIENS (avant modifications)")
    print("═" * 70)
    old_results = run_backtest(trades, klines_cache, OLD_PARAMS)
    old_stats = score_results(old_results)
    print(f"   PnL Total: {old_stats['total_pnl']:+.2f}% | WR: {old_stats['win_rate']:.1f}% ({old_stats['wins']}W/{old_stats['losses']}L)")
    print(f"   PnL Moyen: {old_stats['avg_pnl']:+.2f}% | Max DD: {old_stats['max_dd']:.2f}% | Missed: {old_stats['avg_missed']:.2f}%")
    print(f"   Score: {old_stats['score']:.2f}")
    
    # 4. Benchmark: CURRENT parameters
    print("\n" + "═" * 70)
    print("📦 BENCHMARK: Paramètres ACTUELS (après nos modifications)")
    print("═" * 70)
    current_results = run_backtest(trades, klines_cache, CURRENT_PARAMS)
    current_stats = score_results(current_results)
    print(f"   PnL Total: {current_stats['total_pnl']:+.2f}% | WR: {current_stats['win_rate']:.1f}% ({current_stats['wins']}W/{current_stats['losses']}L)")
    print(f"   PnL Moyen: {current_stats['avg_pnl']:+.2f}% | Max DD: {current_stats['max_dd']:.2f}% | Missed: {current_stats['avg_missed']:.2f}%")
    print(f"   Score: {current_stats['score']:.2f}")
    
    # 5. Phase 1: one-parameter-at-a-time search
    optimized_p1, impacts = smart_grid_search(trades, klines_cache)
    
    # 6. Phase 2: targeted combinatorial search, seeded with the phase 1 result
    best_params, best_stats = targeted_grid_search(trades, klines_cache, optimized_p1)
    
    # 7. Detailed final result
    print("\n" + "=" * 70)
    print("🏆 RÉSULTAT FINAL - Comparaison des 3 configurations")
    print("=" * 70)
    
    # Re-run the winning parameters to get the per-trade rows for the report.
    final_results = run_backtest(trades, klines_cache, best_params)
    final_stats = score_results(final_results)
    
    headers = f"{'':>20} | {'ANCIENS':>12} | {'ACTUELS':>12} | {'OPTIMISÉS':>12}"
    sep = "-" * 65
    print(headers)
    print(sep)
    print(f"{'PnL Total':>20} | {old_stats['total_pnl']:>+11.2f}% | {current_stats['total_pnl']:>+11.2f}% | {final_stats['total_pnl']:>+11.2f}%")
    print(f"{'Win Rate':>20} | {old_stats['win_rate']:>11.1f}% | {current_stats['win_rate']:>11.1f}% | {final_stats['win_rate']:>11.1f}%")
    print(f"{'PnL Moyen':>20} | {old_stats['avg_pnl']:>+11.2f}% | {current_stats['avg_pnl']:>+11.2f}% | {final_stats['avg_pnl']:>+11.2f}%")
    print(f"{'Max Drawdown':>20} | {old_stats['max_dd']:>+11.2f}% | {current_stats['max_dd']:>+11.2f}% | {final_stats['max_dd']:>+11.2f}%")
    print(f"{'Gains Manqués Moy':>20} | {old_stats['avg_missed']:>11.2f}% | {current_stats['avg_missed']:>11.2f}% | {final_stats['avg_missed']:>11.2f}%")
    print(f"{'Score':>20} | {old_stats['score']:>11.2f} | {current_stats['score']:>11.2f} | {final_stats['score']:>11.2f}")
    
    print(f"\n📊 Paramètres OPTIMISÉS recommandés:")
    print(sep)
    for k, v in sorted(best_params.items()):
        current_v = CURRENT_PARAMS.get(k, '?')
        diff = ""
        # Flag only numeric parameters whose value differs from the current setting.
        if isinstance(v, (int, float)) and isinstance(current_v, (int, float)):
            if v != current_v:
                diff = f" ← changé (était {current_v})"
        print(f"   {k:>20} = {v}{diff}")
    
    # 8. Per-trade detail: simulated result versus the recorded one
    print(f"\n📋 Détail par trade (OPTIMISÉ):")
    print(sep)
    print(f"{'Symbole':>14} | {'Surge':>6} | {'PnL Real':>9} | {'PnL Sim':>8} | {'MaxPnL':>7} | {'Missed':>7} | {'Hold':>6} | Exit")
    print(sep)
    for r in final_results:
        orig = r['original_pnl']
        sim = r['pnl_pct']
        # Check mark when the simulation beats the recorded PnL, cross when worse, dash when equal.
        better = "✅" if sim > orig else "❌" if sim < orig else "➖"
        print(f"{r['symbol']:>14} | {r['surge_strength']:>+5.1f}% | {orig:>+8.2f}% | {sim:>+7.2f}% | {r['max_pnl']:>+6.2f}% | {r['missed_gain']:>6.2f}% | {r['hold_min']:>5.0f}m | {r['exit_reason']} {better}")
    
    # 9. Save the results
    output = {
        'timestamp': datetime.now().isoformat(),
        # NOTE(review): counts every trade in the history, including those that
        # run_backtest skipped for lack of klines.
        'trades_tested': len(trades),
        'old_params': OLD_PARAMS,
        'old_stats': old_stats,
        'current_params': CURRENT_PARAMS,
        'current_stats': current_stats,
        'optimized_params': best_params,
        'optimized_stats': final_stats,
        'detail_results': final_results,
    }
    output_file = os.path.join(SCRIPT_DIR, "spy_backtest_results.json")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(output, f, indent=2, ensure_ascii=False)
    print(f"\n💾 Résultats sauvegardés dans spy_backtest_results.json")
    
    return best_params, final_stats


# Script entry point: run the full optimisation only when executed directly.
if __name__ == "__main__":
    best_params, stats = main()
