#!/usr/bin/env python3
"""
=============================================================================
TRADE COMPRESSION AUDIT — Analyse rétrospective du filtre de compression
=============================================================================

OBJECTIF
--------
Pour chaque trade historique du bot, calculer la "compression de prix"
qui existait au moment exact de l'entrée. Comparer ensuite la performance
des trades selon le niveau de compression — pour valider (ou invalider)
l'hypothèse : "les trades en compression > 1.0% surperforment".

CE QUE FAIT LE SCRIPT
---------------------
1. Charge les trades historiques (auto-détection du format)
2. Pour chaque trade, télécharge les klines 5m des ~35 min précédant l'entrée
3. Calcule compression_pct = moyenne de (high - low) / close × 100 sur les
   6 dernières klines obtenues (~30 min)
4. Stocke ce score à côté du PnL réel observé
5. Produit un rapport stratifié par bucket de compression (intervalles [bas, haut[) :
   - Bucket [0, 0.3[   : compression très faible
   - Bucket [0.3, 0.7[ : compression faible
   - Bucket [0.7, 1.0[ : compression modérée
   - Bucket [1.0, 1.5[ : compression élevée
   - Bucket >= 1.5     : compression extrême

RÉSULTAT ATTENDU
----------------
Si l'hypothèse est correcte, on devrait observer :
- Win rate qui MONTE avec la compression
- PnL moyen qui MONTE avec la compression
- Trades à compression >= 1.0 significativement plus rentables que ceux < 1.0

USAGE
-----
    python3 audit_compression.py --trades /path/to/trades.csv
    python3 audit_compression.py --trades /path/to/trades.json --output ./report
    python3 audit_compression.py --trades trades.csv --max-trades 100  # test rapide

DÉPENDANCES
-----------
    pip install requests pandas numpy pyarrow tabulate scipy
=============================================================================
"""

import argparse
import json
import logging
import sys
import time
from datetime import datetime, timedelta, timezone
from pathlib import Path

try:
    import numpy as np
    import pandas as pd
    import requests
    from scipy.stats import mannwhitneyu
except ImportError as e:
    print(f"ERREUR : dépendance manquante ({e})")
    print("Installe avec : pip install requests pandas numpy scipy pyarrow tabulate")
    sys.exit(1)

logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
log = logging.getLogger("audit")


# ---------------------------------------------------------------------------
# CONFIG
# ---------------------------------------------------------------------------

# Binance REST root; the public market-data mirror below is the fallback
# suggested in the logs when no kline request succeeds.
BINANCE_BASE = "https://api.binance.com"   # or "https://data-api.binance.vision"

KLINE_INTERVAL = "5m"
PRE_WINDOW_KLINES = 6      # 6 x 5 min = 30 min of history before the entry
HTTP_TIMEOUT = 15          # seconds, passed to requests.get(timeout=...)

# Buckets are half-open intervals [lo, hi[ on compression_pct (in percent).
# They are generated from consecutive edges, so they can never overlap or
# leave a gap between them.
_BUCKET_EDGES = (0.0, 0.3, 0.7, 1.0, 1.5, 99.0)
_BUCKET_LABELS = ("très faible", "faible", "modérée", "élevée", "extrême")
COMPRESSION_BUCKETS = [
    (lower, upper, name)
    for lower, upper, name in zip(_BUCKET_EDGES[:-1], _BUCKET_EDGES[1:], _BUCKET_LABELS)
]


# ---------------------------------------------------------------------------
# 1. TRADE LOADING — file format auto-detection
# ---------------------------------------------------------------------------

# Canonical column name -> accepted aliases, tried in order (case-insensitive).
# NOTE(review): the generic "pnl" alias is read as a percentage and
# "timestamp"/"time" as the entry time; confirm this matches the bot's export.
REQUIRED_FIELDS = dict(
    symbol=["symbol", "coin", "pair", "asset"],
    entry_time=["entry_time", "entry_timestamp", "open_time", "buy_time", "timestamp", "time"],
    pnl_pct=["pnl_pct", "pnl_percent", "profit_pct", "return_pct", "pnl"],
)


def _parse_entry_times(raw: pd.Series) -> pd.Series:
    """Parse raw entry timestamps into timezone-aware UTC datetimes (NaT if invalid).

    ``pd.to_datetime`` reads bare numbers as *nanoseconds* since the epoch, so an
    export storing epoch seconds or milliseconds would silently land in January
    1970 and every kline download would come back empty. Numeric columns get
    their unit inferred from the magnitude of the median value instead; any
    other column is parsed as dates, naive values being taken as UTC.
    """
    if pd.api.types.is_numeric_dtype(raw) and not pd.api.types.is_bool_dtype(raw):
        numeric = pd.to_numeric(raw, errors="coerce")
        magnitude = numeric.abs().median()
        # Present-day epochs are ~1.7e9 s, ~1.7e12 ms, ~1.7e15 us, ~1.7e18 ns.
        if pd.isna(magnitude) or magnitude < 1e11:
            unit = "s"
        elif magnitude < 1e14:
            unit = "ms"
        elif magnitude < 1e17:
            unit = "us"
        else:
            unit = "ns"
        return pd.to_datetime(numeric, unit=unit, utc=True, errors="coerce")
    return pd.to_datetime(raw, utc=True, errors="coerce")


def _normalize_symbol(raw: object) -> str:
    """Map a raw pair name to the Binance USDT spot symbol used for klines.

    Upper-cases, drops a ':SETTLE' suffix and the separators '/', '-', '_' and
    spaces (so 'btc/usdt', 'BTC-USDT' and 'BTCUSDT' all become 'BTCUSDT'),
    rewrites a USDC quote to USDT and appends USDT to a bare base asset.
    """
    symbol = str(raw).strip().upper().split(":", 1)[0]
    for sep in ("/", "-", "_", " "):
        symbol = symbol.replace(sep, "")
    if symbol.endswith("USDT"):
        return symbol
    if symbol.endswith("USDC"):
        return symbol[:-4] + "USDT"
    return f"{symbol}USDT"


def load_trades(path: Path) -> pd.DataFrame:
    """Load historical trades from CSV, JSON/JSONL or Parquet into a canonical frame.

    Column names are matched case-insensitively against the REQUIRED_FIELDS
    aliases and renamed to ``symbol``, ``entry_time`` and ``pnl_pct``; all other
    columns are kept. Rows whose entry time or PnL cannot be parsed are dropped
    (and counted in the log). Symbols are normalised to Binance USDT symbols.

    Args:
        path: trades file; the reader is chosen from its extension.

    Returns:
        Cleaned DataFrame with ``entry_time`` as UTC datetimes, ``pnl_pct`` as
        numbers and a fresh 0..n-1 index.

    Raises:
        ValueError: unsupported extension, or a required column not found.
    """
    log.info(f"Chargement des trades depuis : {path}")

    suffix = path.suffix.lower()
    if suffix == ".csv":
        df = pd.read_csv(path)
    elif suffix in (".json", ".jsonl"):
        try:
            df = pd.read_json(path)
        except ValueError:
            # JSON Lines (one object per line) is not a valid plain JSON document.
            df = pd.read_json(path, lines=True)
    elif suffix in (".parquet", ".pq"):
        df = pd.read_parquet(path)
    else:
        raise ValueError(f"Format non reconnu : {path.suffix}")

    log.info(f"Trades bruts chargés : {len(df)}")
    log.info(f"Colonnes disponibles : {list(df.columns)}")

    # First column (in file order) for each lower-cased name. str() guards
    # against non-string labels such as integer column names.
    by_lower = {}
    for col in df.columns:
        by_lower.setdefault(str(col).lower(), col)

    mapping = {}
    for canonical, aliases in REQUIRED_FIELDS.items():
        found = next((by_lower[a.lower()] for a in aliases if a.lower() in by_lower), None)
        if found is None:
            raise ValueError(
                f"Colonne '{canonical}' introuvable. "
                f"Aliases tentés : {aliases}. "
                f"Colonnes du fichier : {list(df.columns)}."
            )
        mapping[found] = canonical

    df = df.rename(columns=mapping)

    df["entry_time"] = _parse_entry_times(df["entry_time"])
    invalid = df["entry_time"].isna().sum()
    if invalid > 0:
        log.warning(f"{invalid} trades avec entry_time invalide → exclus")
        df = df.dropna(subset=["entry_time"])

    df["pnl_pct"] = pd.to_numeric(df["pnl_pct"], errors="coerce")
    invalid_pnl = df["pnl_pct"].isna().sum()
    if invalid_pnl > 0:
        log.warning(f"{invalid_pnl} trades avec pnl_pct invalide → exclus")
        df = df.dropna(subset=["pnl_pct"])

    df["symbol"] = df["symbol"].map(_normalize_symbol)

    log.info(f"Trades exploitables après nettoyage : {len(df)}")
    log.info(f"Période couverte : {df['entry_time'].min()} → {df['entry_time'].max()}")
    log.info(f"Win rate brut : {(df['pnl_pct'] > 0).mean()*100:.1f}%")
    log.info(f"PnL moyen brut : {df['pnl_pct'].mean():.3f}%")

    return df.reset_index(drop=True)


# ---------------------------------------------------------------------------
# 2. RÉCUPÉRATION DES KLINES
# ---------------------------------------------------------------------------

def fetch_klines_around(symbol: str, entry_time: datetime,
                         pre_minutes: int = 35) -> pd.DataFrame:
    """Download the KLINE_INTERVAL klines that were fully closed before *entry_time*.

    Queries Binance ``/api/v3/klines`` for klines opening in
    [entry_time - pre_minutes, entry_time], then keeps only those whose
    ``close_time`` is strictly before the entry. The kline that contains the
    entry is dropped on purpose. Downloaded after the fact, its high/low
    already include the price action that followed the entry, which would leak
    the trade's own outcome into the "compression at entry" measure. With the
    default 35-minute look-back, at least 6 closed 5m klines remain whenever
    Binance has data for the whole window.

    Args:
        symbol: Binance symbol, e.g. "BTCUSDT".
        entry_time: entry timestamp (load_trades supplies it as UTC).
        pre_minutes: look-back of the request, in minutes.

    Returns:
        DataFrame with Binance's 12 kline columns (prices and volumes as float,
        ``open_time`` as UTC datetime) in the order returned by the API. It is
        empty on an HTTP/decoding failure (logged) or when no kline matches.
    """
    end_ts = int(entry_time.timestamp() * 1000)
    start_ts = int((entry_time - timedelta(minutes=pre_minutes)).timestamp() * 1000)

    try:
        r = requests.get(
            f"{BINANCE_BASE}/api/v3/klines",
            params={
                "symbol": symbol,
                "interval": KLINE_INTERVAL,
                "startTime": start_ts,
                "endTime": end_ts,
                "limit": 50,
            },
            timeout=HTTP_TIMEOUT,
        )
        r.raise_for_status()
        data = r.json()
    except (requests.RequestException, ValueError) as e:
        # Network errors, HTTP 4xx/5xx (invalid symbol, rate limit) and bad JSON.
        log.warning(f"  Klines KO {symbol} @ {entry_time}: {e}")
        return pd.DataFrame()

    if not data:
        return pd.DataFrame()

    cols = ["open_time", "open", "high", "low", "close", "volume", "close_time",
            "quote_volume", "trades", "taker_buy_base", "taker_buy_quote", "ignore"]
    df = pd.DataFrame(data, columns=cols)
    for c in ["open", "high", "low", "close", "volume", "quote_volume",
              "taker_buy_base", "taker_buy_quote"]:
        df[c] = df[c].astype(float)
    df["open_time"] = pd.to_datetime(df["open_time"], unit="ms", utc=True)

    # close_time is in epoch ms, like end_ts; see the docstring for why.
    closed_before_entry = df["close_time"].astype("int64") < end_ts
    return df.loc[closed_before_entry].reset_index(drop=True)


def compute_compression(klines: pd.DataFrame, window: "int | None" = None) -> dict:
    """Measure the average intrabar range of the last *window* klines.

    For each of the last ``window`` rows, the range is (high - low) / close.

    Args:
        klines: frame with float ``high``, ``low`` and ``close`` columns; its
            last rows are taken as the most recent klines.
        window: number of trailing klines to use; defaults to PRE_WINDOW_KLINES.

    Returns:
        ``{"compression_pct": mean range in %, "spread_max_pct": max range in %,
        "n_klines_used": window}``, or ``{}`` when fewer than ``window`` klines
        are available.

    Raises:
        ValueError: if ``window`` < 1 (``iloc[-0:]`` would select every row).
    """
    if window is None:
        window = PRE_WINDOW_KLINES
    if window < 1:
        raise ValueError(f"window doit être >= 1 (reçu : {window})")
    if len(klines) < window:
        return {}

    pre = klines.iloc[-window:]
    range_ratio = (pre["high"] - pre["low"]) / pre["close"]

    return {
        "compression_pct": float(range_ratio.mean() * 100),
        "spread_max_pct": float(range_ratio.max() * 100),
        "n_klines_used": len(pre),
    }


# ---------------------------------------------------------------------------
# 3. ENRICHISSEMENT DES TRADES
# ---------------------------------------------------------------------------

def _load_klines_cache(cache_path: Path) -> dict:
    """Read the compression-feature cache; return {} if it is missing or unreadable."""
    if not cache_path.exists():
        return {}
    try:
        with open(cache_path, encoding="utf-8") as f:
            cache = json.load(f)
    except (OSError, ValueError) as e:
        # e.g. truncated JSON left by an interrupted write: start over rather
        # than crash every subsequent run.
        log.warning(f"Cache klines illisible ({e}) → ignoré")
        return {}
    if not isinstance(cache, dict):
        log.warning("Cache klines au format inattendu → ignoré")
        return {}
    log.info(f"Cache klines : {len(cache)} entrées chargées")
    return cache


def _save_klines_cache(cache: dict, cache_path: Path) -> None:
    """Write the cache atomically: dump to a sibling temp file, then rename it over the target."""
    tmp_path = cache_path.with_name(cache_path.name + ".tmp")
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(cache, f)
    tmp_path.replace(cache_path)


def enrich_trades(trades: pd.DataFrame, max_trades: int = None,
                  cache_dir: Path = None) -> pd.DataFrame:
    """Attach compression features (see compute_compression) to every trade.

    Features are read from a JSON cache keyed by symbol and entry epoch second,
    or computed from freshly downloaded klines. Only successful computations
    are cached. A failed download (network error, HTTP error such as a rate
    limit, no data) is retried on the next run instead of being remembered
    as "no data" forever. Empty entries written by earlier runs are retried too.

    NOTE: cached features are never invalidated. After changing how klines are
    fetched or how compression is computed, delete ``klines_cache.json`` (or
    run with --no-cache).

    Args:
        trades: output of load_trades (needs ``symbol`` and ``entry_time``).
        max_trades: if truthy, only the first ``max_trades`` rows are processed.
        cache_dir: directory holding ``klines_cache.json``; None disables the cache.

    Returns:
        One row per processed trade: its original columns plus compression_pct,
        spread_max_pct and n_klines_used when computable. These are NaN for other
        rows, or absent entirely when no trade could be enriched.
    """
    if max_trades:
        trades = trades.head(max_trades).copy()
        log.info(f"Mode test : limité à {max_trades} trades")

    cache = {}
    cache_path = None
    if cache_dir:
        cache_dir.mkdir(parents=True, exist_ok=True)
        cache_path = cache_dir / "klines_cache.json"
        cache = _load_klines_cache(cache_path)

    enriched_rows = []
    n_total = len(trades)
    n_ok = 0
    n_fail = 0

    # 1-based position, independent of the DataFrame's index labels.
    for pos, (_, row) in enumerate(trades.iterrows(), start=1):
        cache_key = f"{row['symbol']}_{int(row['entry_time'].timestamp())}"

        features = cache.get(cache_key)
        if not features:
            klines = fetch_klines_around(row["symbol"], row["entry_time"])
            features = compute_compression(klines) if not klines.empty else {}
            if features:
                cache[cache_key] = features
            time.sleep(0.05)  # crude throttle between Binance requests

            if cache_path and pos % 50 == 0:
                _save_klines_cache(cache, cache_path)

        if features:
            n_ok += 1
        else:
            n_fail += 1

        merged = row.to_dict()
        merged.update(features)
        enriched_rows.append(merged)

        if pos % 50 == 0 or pos == n_total:
            log.info(f"  Progression : {pos}/{n_total} ({n_ok} OK, {n_fail} fails)")

    if cache_path:
        _save_klines_cache(cache, cache_path)

    df = pd.DataFrame(enriched_rows)
    log.info(f"Trades enrichis : {n_ok}/{n_total} avec compression calculable")

    if n_ok == 0:
        log.error("AUCUN TRADE N'A PU ÊTRE ENRICHI")
        log.error(f"→ Tester : curl {BINANCE_BASE}/api/v3/ping")
        log.error("  Si bloqué, change BINANCE_BASE = 'https://data-api.binance.vision'")

    return df


# ---------------------------------------------------------------------------
# 4. ANALYSE STRATIFIÉE
# ---------------------------------------------------------------------------

def stratify_by_compression(df: pd.DataFrame) -> pd.DataFrame:
    """Summarise trade performance for each bucket of COMPRESSION_BUCKETS.

    Trades without a compression_pct are ignored. Buckets are half-open
    intervals [lo, hi[. Every bucket gets a row, even an empty one. An empty
    bucket has n_trades = 0 and pnl_total = 0.0, but win_rate, pnl_mean and
    pnl_median are NaN. Those statistics are undefined, and 0.0 would be
    indistinguishable from a bucket whose trades all broke even or lost.

    Args:
        df: enriched trades with ``compression_pct`` and ``pnl_pct`` columns.

    Returns:
        One row per bucket with columns bucket, label, n_trades, win_rate (in
        %), pnl_mean, pnl_median and pnl_total (in the unit of pnl_pct).
    """
    valid = df.dropna(subset=["compression_pct"])
    log.info(f"Trades avec compression calculable : {len(valid)} / {len(df)}")

    compression = valid["compression_pct"]
    rows = []
    for lo, hi, label in COMPRESSION_BUCKETS:
        pnl = valid.loc[(compression >= lo) & (compression < hi), "pnl_pct"]
        if pnl.empty:
            win_rate = pnl_mean = pnl_median = float("nan")
        else:
            win_rate = (pnl > 0).mean() * 100
            pnl_mean = pnl.mean()
            pnl_median = pnl.median()
        rows.append({
            "bucket": f"[{lo:.1f} – {hi:.1f}[",
            "label": label,
            "n_trades": len(pnl),
            "win_rate": win_rate,
            "pnl_mean": pnl_mean,
            "pnl_median": pnl_median,
            "pnl_total": float(pnl.sum()),
        })

    return pd.DataFrame(rows)


def split_test_compression_threshold(df: pd.DataFrame, threshold: float = 1.0,
                                     min_samples: int = 10) -> dict:
    """Compare the PnL of trades at/above vs below a compression threshold.

    Runs a one-sided Mann-Whitney U test (alternative: high-compression PnL
    tends to be greater). It also computes Cliff's delta,
    P(high > low) - P(high < low), which lies in [-1, 1]. The delta is counted
    exactly with a sort plus binary search, O((n + m) log m). This replaces
    materialising n x m comparison matrices, which reach hundreds of MB to GB
    once each side has tens of thousands of trades.

    Args:
        df: enriched trades with ``compression_pct`` and ``pnl_pct`` columns;
            rows where either is NaN are ignored.
        threshold: compression (in %) separating "high" (>=) from "low" (<).
        min_samples: minimum number of trades required on each side.

    Returns:
        ``{"error": "échantillons trop petits"}`` when either side has fewer
        than ``min_samples`` trades. Otherwise a dict with the sample sizes, win
        rates (in %), mean and total PnL per side, the one-sided p-value (1.0
        if scipy rejects the input) and Cliff's delta.
    """
    valid = df.dropna(subset=["compression_pct", "pnl_pct"])
    high = valid[valid["compression_pct"] >= threshold]
    low = valid[valid["compression_pct"] < threshold]

    if len(high) < min_samples or len(low) < min_samples:
        return {"error": "échantillons trop petits"}

    h_arr = high["pnl_pct"].to_numpy(dtype=float)
    l_arr = low["pnl_pct"].to_numpy(dtype=float)

    try:
        _, pvalue = mannwhitneyu(h_arr, l_arr, alternative="greater")
    except ValueError:
        pvalue = 1.0

    l_sorted = np.sort(l_arr)
    # For each high value h: side="left" counts lows < h, side="right" counts lows <= h.
    n_low_below = np.searchsorted(l_sorted, h_arr, side="left")
    n_low_at_or_below = np.searchsorted(l_sorted, h_arr, side="right")
    gt = int(n_low_below.sum())
    lt = int((len(l_sorted) - n_low_at_or_below).sum())
    cliff = (gt - lt) / (len(h_arr) * len(l_arr))

    return {
        "threshold": threshold,
        "n_high": len(high),
        "n_low": len(low),
        "winrate_high": (high["pnl_pct"] > 0).mean() * 100,
        "winrate_low": (low["pnl_pct"] > 0).mean() * 100,
        "pnl_mean_high": high["pnl_pct"].mean(),
        "pnl_mean_low": low["pnl_pct"].mean(),
        "pnl_total_high": high["pnl_pct"].sum(),
        "pnl_total_low": low["pnl_pct"].sum(),
        "p_value_one_sided": pvalue,
        "cliff_delta": cliff,
    }


# ---------------------------------------------------------------------------
# 5. RAPPORT
# ---------------------------------------------------------------------------

def _threshold_tests_section(enriched: pd.DataFrame, thresholds: list) -> list:
    """Format the high-vs-low compression comparison for each threshold."""
    lines = []
    for thr in thresholds:
        r = split_test_compression_threshold(enriched, thr)
        if "error" in r:
            lines.append(f"\nSeuil {thr}: {r['error']}")
            continue
        lines.append(f"\nSeuil compression >= {thr}%")
        lines.append(f"  Trades haute compression : {r['n_high']:>5}  | WR {r['winrate_high']:>5.1f}%  | PnL moy {r['pnl_mean_high']:+.3f}%")
        lines.append(f"  Trades basse compression : {r['n_low']:>5}  | WR {r['winrate_low']:>5.1f}%  | PnL moy {r['pnl_mean_low']:+.3f}%")
        lines.append(f"  Différence PnL/trade     : {r['pnl_mean_high'] - r['pnl_mean_low']:+.3f}%")
        lines.append(f"  p-value (high > low)     : {r['p_value_one_sided']:.4f}")
        lines.append(f"  Cliff's delta            : {r['cliff_delta']:+.3f}")

        if r["p_value_one_sided"] < 0.05 and r["cliff_delta"] > 0.1:
            lines.append(f"  ✓ VERDICT : la compression >= {thr}% est associée à de meilleurs trades")
        elif r["p_value_one_sided"] < 0.05:
            lines.append("  ~ VERDICT : différence significative mais effet faible")
        else:
            lines.append("  ✗ VERDICT : pas de différence significative")

    if len(thresholds) > 1:
        # Each verdict uses p < 0.05 on the same trades. Across several
        # thresholds, the chance of at least one spurious "significant" verdict
        # grows, so state the Bonferroni-corrected bound.
        lines.append(
            f"\nNB : {len(thresholds)} seuils testés sur les mêmes trades, p-values non corrigées"
            f" — seuil de Bonferroni : p < {0.05 / len(thresholds):.4f}"
        )
    return lines


def _simulation_section(valid: pd.DataFrame, threshold: float) -> list:
    """Format the what-if: keep only trades whose compression_pct >= threshold."""
    keep = valid["compression_pct"] >= threshold
    actual_total = valid["pnl_pct"].sum()
    filtered_total = valid.loc[keep, "pnl_pct"].sum()
    excluded_total = valid.loc[~keep, "pnl_pct"].sum()
    n_kept = int(keep.sum())
    n_excluded = len(valid) - n_kept

    kept_label = f"Avec filtre {threshold}%"
    excluded_label = f"Trades exclus (<{threshold}%)"
    return [
        "",
        "-" * 78,
        f"SIMULATION : QUE SE SERAIT-IL PASSÉ AVEC FILTRE COMPRESSION >= {threshold}% ?",
        "-" * 78,
        f"{'Réel observé':<22}: {len(valid)} trades, PnL total {actual_total:+.1f}%",
        f"{kept_label:<22}: {n_kept} trades, PnL total {filtered_total:+.1f}%",
        f"{excluded_label:<22}: {n_excluded} trades, PnL total {excluded_total:+.1f}%",
        "",
        "Si on avait appliqué le filtre :",
        f"  - {n_excluded} trades évités",
        # Filtering changes the summed PnL by exactly minus what was excluded.
        f"  - Gain net (PnL exclus inversé) : {-excluded_total:+.1f}% de PnL",
    ]


def _save_enriched(enriched: pd.DataFrame, output_dir: Path) -> None:
    """Save enriched trades as Parquet, falling back to CSV if Parquet fails."""
    try:
        enriched.to_parquet(output_dir / "trades_enriched.parquet")
    except (ImportError, ValueError, TypeError, NotImplementedError) as e:
        # to_parquet needs pyarrow/fastparquet (never checked at start-up) and
        # can reject object columns of mixed types. Don't lose the run over it.
        log.warning(f"Export Parquet impossible ({e}) → export CSV")
        enriched.to_csv(output_dir / "trades_enriched.csv", index=False)


def generate_report(enriched: pd.DataFrame, output_dir: Path,
                    thresholds: tuple = (0.5, 0.7, 1.0, 1.5),
                    filter_threshold: float = 1.0) -> str:
    """Build the text audit report and save it alongside the data exports.

    Writes into *output_dir* (created if needed):
      - report.txt: the text report, UTF-8 (it contains characters such as
        "✓" that legacy locale encodings cannot represent);
      - stratification.csv: per-bucket statistics;
      - trades_enriched.parquet, or trades_enriched.csv if Parquet fails.
    When no trade has a computable compression, only report.txt is written.

    Args:
        enriched: output of enrich_trades; compression_pct may be absent.
        output_dir: destination directory.
        thresholds: compression thresholds (%) for the high-vs-low tests.
        filter_threshold: threshold (%) used in the filter simulation.

    Returns:
        The report text.
    """
    thresholds = list(thresholds)
    output_dir.mkdir(parents=True, exist_ok=True)
    report_path = output_dir / "report.txt"

    if "compression_pct" in enriched.columns:
        valid = enriched.dropna(subset=["compression_pct"])
    else:
        valid = pd.DataFrame()
    n_total = len(enriched)
    n_valid = len(valid)

    if n_valid == 0:
        msg = (
            "AUCUN TRADE EXPLOITABLE\n"
            f"{n_total} trades chargés mais aucune compression calculable.\n"
            "Vérifie les logs ci-dessus pour la cause (API bloquée, symboles invalides...)."
        )
        report_path.write_text(msg, encoding="utf-8")
        return msg

    strat = stratify_by_compression(enriched)

    report = [
        "=" * 78,
        "RAPPORT D'AUDIT — IMPACT DE LA COMPRESSION SUR LES TRADES",
        "=" * 78,
        f"Date du rapport : {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M UTC')}",
        f"Trades analysés : {n_total} ({n_valid} avec compression calculable)",
        f"Win rate global : {(valid['pnl_pct'] > 0).mean()*100:.1f}%",
        f"PnL moyen      : {valid['pnl_pct'].mean():.3f}%",
        # Plain sum of per-trade percentages, not a compounded return.
        f"PnL total      : {valid['pnl_pct'].sum():.1f}%",
        "",
        "-" * 78,
        "STRATIFICATION PAR BUCKET DE COMPRESSION",
        "-" * 78,
        f"{'Bucket':<14}{'Label':<14}{'N':>8}{'Win rate':>12}{'PnL moy':>12}{'PnL méd':>12}{'PnL tot':>12}",
        "-" * 78,
    ]
    for _, r in strat.iterrows():
        report.append(
            f"{r['bucket']:<14}{r['label']:<14}{r['n_trades']:>8}"
            f"{r['win_rate']:>11.1f}%{r['pnl_mean']:>11.3f}%"
            f"{r['pnl_median']:>11.3f}%{r['pnl_total']:>11.1f}%"
        )

    report += [
        "",
        "-" * 78,
        "TEST STATISTIQUE — TRADES HAUTE COMPRESSION vs BASSE COMPRESSION",
        "-" * 78,
    ]
    report += _threshold_tests_section(enriched, thresholds)
    report += _simulation_section(valid, filter_threshold)

    text_report = "\n".join(report)
    report_path.write_text(text_report, encoding="utf-8")
    # Small CSV first, so a failing enriched export cannot prevent it.
    strat.to_csv(output_dir / "stratification.csv", index=False)
    _save_enriched(enriched, output_dir)

    log.info(f"\nRapport sauvegardé dans : {output_dir}")
    return text_report


# ---------------------------------------------------------------------------
# MAIN
# ---------------------------------------------------------------------------

def main():
    """Command-line entry point: parse arguments, run the audit, print the report.

    Exits with status 1 when the trades file does not exist or load_trades
    rejects it with a ValueError. argparse itself exits with status 2 on
    invalid arguments, including a non-positive --max-trades.
    """
    parser = argparse.ArgumentParser(description="Audit compression sur trades historiques")
    parser.add_argument("--trades", required=True, help="Fichier CSV/JSON/Parquet des trades")
    parser.add_argument("--output", default="./audit_output", help="Dossier de sortie")
    parser.add_argument("--max-trades", type=int, default=None, help="Limiter pour test rapide")
    parser.add_argument("--no-cache", action="store_true", help="Désactive le cache klines")
    args = parser.parse_args()

    # enrich_trades treats a falsy max_trades as "no limit", so 0 would
    # silently process every trade; reject non-positive values up front.
    if args.max_trades is not None and args.max_trades <= 0:
        parser.error("--max-trades doit être un entier strictement positif")

    trades_path = Path(args.trades)
    if not trades_path.exists():
        log.error(f"Fichier introuvable : {trades_path}")
        sys.exit(1)

    output_dir = Path(args.output)
    cache_dir = None if args.no_cache else output_dir

    log.info("=" * 60)
    log.info("DÉMARRAGE DE L'AUDIT")
    log.info("=" * 60)

    try:
        trades = load_trades(trades_path)
    except ValueError as e:
        # Bad input (e.g. unsupported extension or missing column, both raised
        # by load_trades): a user error, reported cleanly instead of a traceback.
        log.error(f"Chargement des trades impossible : {e}")
        sys.exit(1)

    enriched = enrich_trades(trades, max_trades=args.max_trades, cache_dir=cache_dir)
    text = generate_report(enriched, output_dir)
    print("\n" + text)


if __name__ == "__main__":
    main()
