#!/usr/bin/env python3
"""
CRASH-TEST DES CORRECTIONS
===========================
Compare les performances AVANT vs APRÈS les 4 corrections :

1.  Open candle fix  — la bougie en cours n'est plus incluse dans l'analyse IA
2.  Hot signals      — bot réagit en ~3s au lieu de ~80s (fin de cycle complet)
3.  Workers 12→20    — cycle IA ~40% plus rapide (~80s → ~50s)
4.  cache_ttl        — validation pre-achat sans appel réseau redondant

Méthode :
- Lit les 50 trades exécutés de l'archive Mars 18-21
- Pour chaque trade : récupère les klines 5m réelles depuis Binance
- Simule 2 scénarios : AVANT (entry +80s) et APRÈS (entry +3s avec hot signal)
- Applique les mêmes SL/TP que le bot réel (SL 1.5 %, TP 2.5 %, MAX_HOLD 180 min)
- Rapport détaillé des gains/pertes et écarts de prix
"""

import json
import time
import os
import sys
import statistics
from datetime import datetime, timedelta
from collections import defaultdict

# ── Paramètres ──────────────────────────────────────────────────────────────
# Archive directory (relative path) in which the loaders below look for
# signals_log.jsonl and trades_log.jsonl.
ARCHIVE = "trade_logs/archives_reset/dashboard_reset_20260321_105334"
SL_PCT  = 1.5   # Stop-loss, in percent of the entry price
TP_PCT  = 2.5   # Take-profit, in percent of the entry price
# Maximum holding time in minutes, passed to simulate_trade() as max_hold_min.
MAX_HOLD_MIN = 180
DELAY_OLD_S  = 80   # Average entry delay (s) BEFORE hot signals (end of the AI cycle)
DELAY_NEW_S  = 3    # Entry delay (s) AFTER hot signals (immediate callback)

# ── Chargement des données réelles ─────────────────────────────────────────

def load_executed_signals(archive_dir: str):
    """Load the executed signals recorded in an archive directory.

    Reads ``signals_log.jsonl`` inside *archive_dir* (one JSON document per
    line) and keeps every JSON object whose ``"executed"`` field is truthy.
    Blank lines, lines that are not valid JSON and JSON values that are not
    objects are skipped, so a partially corrupted log still loads.

    Args:
        archive_dir: Directory containing ``signals_log.jsonl``.

    Returns:
        The matching signal dicts, in file order (however many the log holds).

    Raises:
        OSError: If the log file cannot be opened.
    """
    signals_file = os.path.join(archive_dir, "signals_log.jsonl")
    executed = []
    with open(signals_file, encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except (ValueError, RecursionError):
                # json.JSONDecodeError is a ValueError: malformed line, skip it.
                continue
            # Only JSON objects can carry an "executed" flag.
            if isinstance(record, dict) and record.get("executed"):
                executed.append(record)
    return executed


def load_trade_details(archive_dir: str):
    """Load the real trade events recorded in an archive directory.

    Reads ``trades_log.jsonl`` inside *archive_dir* (one JSON document per
    line) and splits the JSON objects by their ``"type"`` field. Blank lines,
    invalid JSON, non-object values and objects with any other (or no)
    ``"type"`` are skipped.

    Args:
        archive_dir: Directory containing ``trades_log.jsonl``.

    Returns:
        A ``(opens, closes)`` tuple of two flat lists, in file order: the
        ``"TRADE_OPEN"`` records and the ``"TRADE_CLOSE"`` records. The lists
        are not indexed or grouped by symbol.

    Raises:
        OSError: If the log file cannot be opened.
    """
    trades_file = os.path.join(archive_dir, "trades_log.jsonl")
    opens, closes = [], []
    with open(trades_file, encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except (ValueError, RecursionError):
                # json.JSONDecodeError is a ValueError: malformed line, skip it.
                continue
            if not isinstance(record, dict):
                continue
            kind = record.get("type")
            if kind == "TRADE_OPEN":
                opens.append(record)
            elif kind == "TRADE_CLOSE":
                closes.append(record)
    return opens, closes


def get_klines_rest(symbol: str, interval: str, start_ms: int, limit: int = 150):
    """Fetch historical klines from Binance's public REST endpoint.

    Sends a GET request to ``https://api.binance.com/api/v3/klines`` with the
    given query parameters (10 s timeout) and decodes the JSON response.

    Args:
        symbol: Value sent as the ``symbol`` query parameter.
        interval: Value sent as the ``interval`` query parameter (e.g. "5m").
        start_ms: Value sent as ``startTime``.
        limit: Value sent as ``limit``.

    Returns:
        The decoded JSON list, or ``None`` when the request fails, the body
        is not valid JSON, or the decoded payload is not a list. The reason
        for a ``None`` is written to stderr so that it is not lost.
    """
    import http.client
    import urllib.parse
    import urllib.request

    params = urllib.parse.urlencode({
        "symbol": symbol,
        "interval": interval,
        "startTime": start_ms,
        "limit": limit,
    })
    url = f"https://api.binance.com/api/v3/klines?{params}"
    req = urllib.request.Request(url, headers={"User-Agent": "crash-test/1.0"})
    try:
        with urllib.request.urlopen(req, timeout=10) as resp:
            payload = json.loads(resp.read())
    except (OSError, ValueError, http.client.HTTPException) as exc:
        # URLError/HTTPError and socket timeouts are OSError subclasses;
        # JSON/Unicode decoding errors are ValueError subclasses.
        print(f"[get_klines_rest] {symbol}: request failed ({exc!r})", file=sys.stderr)
        return None

    if not isinstance(payload, list):
        # Anything other than a list cannot be iterated as klines by callers.
        print(f"[get_klines_rest] {symbol}: unexpected payload {payload!r}", file=sys.stderr)
        return None
    return payload


# ── Simulation d'un trade ────────────────────────────────────────────────────

def simulate_trade(klines, entry_ts_ms: int, sl_pct: float, tp_pct: float,
                   max_hold_min: int = 180, label: str = ""):
    """Replay a long trade over a sequence of klines.

    Each kline is indexed positionally: [0] open time (ms), [2] high,
    [3] low, [4] close, [6] close time (ms).

    The entry candle is the first one whose [open time, close time] range
    contains *entry_ts_ms*; failing that, the first candle opening at or
    after *entry_ts_ms*. The entry price is the CLOSE of that candle, while
    the holding time is counted from its OPEN time. Every later candle is
    then checked, in this order: maximum hold reached (exit at that candle's
    close), low touching the stop-loss, high touching the take-profit. When
    a candle touches both levels the stop-loss wins.

    Args:
        klines: Sequence of klines as described above.
        entry_ts_ms: Desired entry time, Unix timestamp in milliseconds.
        sl_pct: Stop-loss distance below the entry price, in percent.
        tp_pct: Take-profit distance above the entry price, in percent.
        max_hold_min: Maximum holding time in minutes.
        label: Unused by this function.

    Returns:
        ``None`` when *klines* is empty or no entry candle is found, else a
        tuple ``(pnl_pct, exit_reason, duration_min, entry_price,
        exit_price)``. ``exit_reason`` is one of ``"MAX_HOLD_<n>min"``,
        ``"STOP_LOSS"``, ``"TAKE_PROFIT"`` or ``"FIN_DONNEES"`` (data ran
        out; exit at the last candle's close).
    """
    if not klines:
        return None

    def _covers(candle):
        # True when the candle's time range contains the requested entry time.
        starts, ends = int(candle[0]), int(candle[6])
        return starts <= entry_ts_ms <= ends

    # First choice: the candle containing the entry time.
    idx = next((i for i, c in enumerate(klines) if _covers(c)), None)
    if idx is None:
        # Fallback: the nearest candle opening at or after the entry time.
        idx = next(
            (i for i, c in enumerate(klines) if int(c[0]) >= entry_ts_ms),
            None,
        )
    if idx is None:
        return None

    entry_candle = klines[idx]
    entry_price = float(entry_candle[4])  # close of the entry candle
    stop_level = entry_price * (1 - sl_pct / 100)
    target_level = entry_price * (1 + tp_pct / 100)
    t0 = int(entry_candle[0])

    def _pnl(exit_price):
        # Percentage gain of a long position closed at exit_price.
        return (exit_price - entry_price) / entry_price * 100

    for candle in klines[idx + 1:]:
        candle_open = int(candle[0])
        low = float(candle[3])
        high = float(candle[2])
        close = float(candle[4])
        held_min = (candle_open - t0) / 60000

        if held_min >= max_hold_min:
            return _pnl(close), f"MAX_HOLD_{max_hold_min}min", held_min, entry_price, close
        if low <= stop_level:
            return _pnl(stop_level), "STOP_LOSS", held_min, entry_price, stop_level
        if high >= target_level:
            return _pnl(target_level), "TAKE_PROFIT", held_min, entry_price, target_level

    # Ran out of candles without hitting any exit condition.
    final = klines[-1]
    final_close = float(final[4])
    return _pnl(final_close), "FIN_DONNEES", (int(final[0]) - t0) / 60000, entry_price, final_close


# ── Rapport ─────────────────────────────────────────────────────────────────

def print_stats(label: str, results: list):
    """Print a summary table for one list of trade results.

    Args:
        label: Title printed above the table.
        results: List of dicts. Each must have a ``"pnl_pct"`` key (``None``
            values are counted as simulated but not as valid); the optional
            ``"duration_min"`` and ``"exit_reason"`` keys feed the average
            duration and the exit breakdown. Falsy durations are left out of
            the average.

    Prints a single "no result" line when *results* is empty; returns nothing.
    """
    if not results:
        print(f"\n{label}: aucun résultat")
        return

    # Split the valid P&L values into winners (> 0) and losers (<= 0).
    pnls, wins, losses = [], [], []
    for row in results:
        value = row["pnl_pct"]
        if value is None:
            continue
        pnls.append(value)
        if value > 0:
            wins.append(value)
        elif value <= 0:
            losses.append(value)

    valid = len(pnls)
    denom = max(valid, 1)  # avoids dividing by zero when nothing is valid
    bar = "=" * 60

    report = [
        f"\n{bar}",
        f"  {label}",
        bar,
        f"  Trades simulés  : {len(results)}",
        f"  Trades valides  : {valid}",
        f"  Wins            : {len(wins)} ({len(wins) / denom * 100:.1f}%)",
        f"  Losses          : {len(losses)} ({len(losses) / denom * 100:.1f}%)",
        f"  P&L moyen       : {sum(pnls) / denom:.3f}%",
        f"  P&L total       : {sum(pnls):.3f}% (cumulé sur {valid} trades)",
        f"  P&L max win     : {max(pnls, default=0):.3f}%",
        f"  P&L max loss    : {min(pnls, default=0):.3f}%",
    ]

    # The sample standard deviation needs at least two data points.
    if valid > 1:
        report.append(f"  P&L std dev     : {statistics.stdev(pnls):.3f}%")

    held = [row.get("duration_min", 0) for row in results if row.get("duration_min")]
    if held:
        report.append(f"  Durée moy       : {sum(held) / len(held):.1f} min")

    # Count exits per reason, in order of first appearance.
    tally = {}
    for row in results:
        reason = row.get("exit_reason", "?")
        tally[reason] = tally.get(reason, 0) + 1
    report.append(f"  Sorties         : {tally}")

    print("\n".join(report))


# ── Main ─────────────────────────────────────────────────────────────────────

def _result_row(symbol, score, pattern, outcome):
    """Turn a simulate_trade() result tuple into the dict stored in the report."""
    pnl, reason, duration, entry_price, exit_price = outcome
    return {
        "symbol": symbol, "score": score, "pattern": pattern,
        "pnl_pct": pnl, "exit_reason": reason, "duration_min": duration,
        "entry_price": entry_price, "exit_price": exit_price,
    }


def main():
    """Run the crash-test end to end and print/save the report.

    Steps:
      1. Load the executed signals and the real trade events from ARCHIVE.
      2. For every signal, fetch 5m klines and simulate two entries: one
         DELAY_OLD_S seconds after the signal ("before" scenario) and one
         DELAY_NEW_S seconds after it ("after" scenario).
      3. Print statistics for the real closed trades and for both
         scenarios, the entry-price comparison, the collected errors and a
         final summary.
      4. Write everything to ``crash_test_results.json`` in the current
         working directory.

    Side effects: one klines request per signal (with a pause after each
    one), console output, and the JSON file mentioned above.
    """
    print("=" * 60)
    print("  CRASH-TEST CORRECTIONS — Mars 18-21 2026")
    print("=" * 60)

    # 1. Load the archived data
    print("\n📂 Chargement des données…")
    signals = load_executed_signals(ARCHIVE)
    opens, closes = load_trade_details(ARCHIVE)
    print(f"  Signaux exécutés : {len(signals)}")
    print(f"  Trades ouverts   : {len(opens)}")
    print(f"  Trades fermés    : {len(closes)}")

    results_old = []   # BEFORE the corrections
    results_new = []   # AFTER the corrections
    results_real = []  # Reality (real closed trades)
    errors = []
    # (old_entry_price, new_entry_price) for signals where BOTH scenarios
    # produced a result. Collected per signal so that a scenario missing for
    # one signal can never misalign the comparison of the following ones.
    entry_pairs = []

    print(f"\n🔄 Simulation sur {len(signals)} signaux (klines Binance en direct)…")
    print("  (Chaque symbole = 1 appel REST Binance)\n")

    processed = 0
    for sig in signals:
        symbol = sig["symbol"]
        ts_str = sig["timestamp"]   # e.g. "2026-03-18 18:06:40.162044"
        score = sig.get("ai_score", 0)
        pattern = sig.get("pattern", "?")

        # Parse the timestamp into Unix milliseconds.
        # NOTE(review): the string is parsed as a naive datetime, so
        # .timestamp() interprets it in this machine's local timezone. If the
        # log was written in UTC or in another timezone, the klines window is
        # shifted accordingly — confirm against the bot's logger.
        try:
            dt = datetime.strptime(ts_str[:19], "%Y-%m-%d %H:%M:%S")
            signal_ts_ms = int(dt.timestamp() * 1000)
        except (TypeError, ValueError, OverflowError, OSError) as e:
            errors.append(f"{symbol}: timestamp invalide ({e})")
            continue

        # 50 five-minute candles starting 5 min before the signal, i.e. a
        # window reaching about 245 min after it.
        start_ms = signal_ts_ms - 5 * 60 * 1000
        klines = get_klines_rest(symbol, "5m", start_ms, limit=50)

        if not klines:
            errors.append(f"{symbol}: pas de klines")
            time.sleep(0.3)
            continue

        # Sanity check: require at least three candles before simulating.
        if len(klines) < 3:
            errors.append(f"{symbol}: trop peu de klines ({len(klines)})")
            # A request was still made: keep the same pause as the normal path.
            time.sleep(0.12)
            continue

        # BEFORE the correction: entry = signal + DELAY_OLD_S
        ts_old = signal_ts_ms + DELAY_OLD_S * 1000
        res_old = simulate_trade(klines, ts_old, SL_PCT, TP_PCT, MAX_HOLD_MIN)

        # AFTER the correction: entry = signal + DELAY_NEW_S (hot signal)
        ts_new = signal_ts_ms + DELAY_NEW_S * 1000
        res_new = simulate_trade(klines, ts_new, SL_PCT, TP_PCT, MAX_HOLD_MIN)

        if res_old:
            results_old.append(_result_row(symbol, score, pattern, res_old))
        if res_new:
            results_new.append(_result_row(symbol, score, pattern, res_new))

        processed += 1
        status = ""
        if res_old and res_new:
            diff = res_new[0] - res_old[0]
            status = f" Δ={diff:+.3f}%"
            # Index 3 of the result tuple is the entry price.
            if res_old[3] and res_new[3]:
                entry_pairs.append((res_old[3], res_new[3]))
        print(f"  [{processed:2d}/{len(signals)}] {symbol:<18} score={score:<5}{status}")

        # Pause between requests (original note: max ~20 req/s without an
        # API key — not verified here).
        time.sleep(0.12)

    # Real results, taken from the TRADE_CLOSE records.
    for t in closes:
        results_real.append({
            "symbol": t["symbol"],
            "pnl_pct": t.get("pnl_pct", 0),
            # Tolerate a null reason/duration in the log instead of crashing.
            "exit_reason": str(t.get("reason") or "?")[:40],
            "duration_min": (t.get("duration_seconds") or 0) / 60,
        })

    # ── Report ───────────────────────────────────────────────────────────────
    print("\n\n" + "=" * 60)
    print("  RÉSULTATS DU CRASH-TEST")
    print("=" * 60)

    print_stats("📌 RÉALITÉ BOT (trades réels clôturés)", results_real)
    print_stats(f"⏱  AVANT corrections (entry +{DELAY_OLD_S}s après signal)", results_old)
    print_stats(f"🚀 APRÈS corrections (hot signal, entry +{DELAY_NEW_S}s)", results_new)

    # ── Direct per-trade comparison of the entry prices ─────────────────────
    if results_old and results_new:
        print(f"\n{'='*60}")
        print(f"  IMPACT PRIX : Arrivée en avance de {DELAY_OLD_S - DELAY_NEW_S}s")
        print(f"{'='*60}")
        price_improvement = [
            (new_entry - old_entry) / old_entry * 100
            for old_entry, new_entry in entry_pairs
        ]
        paired = len(price_improvement)

        if price_improvement:
            avg_entry_diff = sum(price_improvement) / paired
            better_entry = sum(1 for d in price_improvement if d <= 0)
            print(f"  Trades comparés     : {paired}")
            print(f"  Entrée moins chère  : {better_entry}/{paired} ({better_entry/paired*100:.0f}%)")
            print(f"  Δ prix entrée moy   : {avg_entry_diff:+.4f}%")
            print("  (négatif = mieux, on entre avant la hausse)")

    # ── Errors (first ten only) ─────────────────────────────────────────────
    if errors:
        print(f"\n⚠️  {len(errors)} erreurs:")
        for e in errors[:10]:
            print(f"   • {e}")

    # ── Final summary ───────────────────────────────────────────────────────
    pnls_old = [r["pnl_pct"] for r in results_old]
    pnls_new = [r["pnl_pct"] for r in results_new]
    # print_stats() ignores null P&L values; do the same here so that sum()
    # cannot fail on a real trade logged without a P&L.
    pnls_real = [r["pnl_pct"] for r in results_real if r["pnl_pct"] is not None]
    # Win rates default to 0 when a scenario produced no result.
    wr_old = 0
    wr_new = 0

    print(f"\n{'='*60}")
    print("  SYNTHÈSE — IMPACT DES CORRECTIONS")
    print(f"{'='*60}")

    if pnls_real:
        print(f"  Réalité bot   : P&L total = {sum(pnls_real):.3f}%  |  WR = {sum(1 for p in pnls_real if p>0)/len(pnls_real)*100:.1f}%")
    if pnls_old:
        wr_old = sum(1 for p in pnls_old if p > 0) / len(pnls_old) * 100
        print(f"  AVANT fix     : P&L total = {sum(pnls_old):.3f}%  |  WR = {wr_old:.1f}%")
    if pnls_new:
        wr_new = sum(1 for p in pnls_new if p > 0) / len(pnls_new) * 100
        print(f"  APRÈS fix     : P&L total = {sum(pnls_new):.3f}%  |  WR = {wr_new:.1f}%")
    if pnls_old and pnls_new:
        delta_pnl = sum(pnls_new) - sum(pnls_old)
        delta_wr = wr_new - wr_old
        print(f"\n  🎯 Gain hot signal : Δ P&L total = {delta_pnl:+.3f}%  |  Δ WR = {delta_wr:+.1f}pp")
        if delta_pnl > 0:
            print("  ✅ Les corrections améliorent la rentabilité")
        else:
            print("  ⚠️  Amélioration non visible sur ce marché baissier (entrée rapide ≠ meilleur prix si trend négatif)")

    print("\n📋 Note : ce crash-test simule uniquement l'impact du délai d'entrée.")
    print("   L'open candle fix améliore la QUALITÉ des signaux (moins de faux signaux)")
    print("   Les workers 12→20 accélèrent le cycle IA (~80s→~50s, impact supplémentaire)")
    print("   Ces 2 effets ne sont pas simulables sans relancer l'IA complète.\n")

    # Save everything as JSON
    output = {
        "generated_at": datetime.now().isoformat(),
        "archive": ARCHIVE,
        "config": {"sl_pct": SL_PCT, "tp_pct": TP_PCT, "max_hold_min": MAX_HOLD_MIN,
                   "delay_old_s": DELAY_OLD_S, "delay_new_s": DELAY_NEW_S},
        "results_real": results_real,
        "results_old": results_old,
        "results_new": results_new,
        "summary": {
            "pnl_total_real": sum(pnls_real),
            "pnl_total_old":  sum(pnls_old),
            "pnl_total_new":  sum(pnls_new),
            "wr_old": wr_old,
            "wr_new": wr_new,
        }
    }
    out_file = "crash_test_results.json"
    with open(out_file, "w", encoding="utf-8") as f:
        json.dump(output, f, ensure_ascii=False, indent=2)
    print(f"💾 Résultats complets sauvegardés → {out_file}\n")


# Run the crash-test only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
