"""
Entry Filter Trainer — Filtre d'entrée IA pour market_spy.py
=============================================================
Analyse les patterns qui causent les pertes, entraîne un filtre XGBoost/LightGBM
sur les données de spy_optimizer, et génère des règles concrètes.

Usage (CPU, sur le serveur Ubuntu):
    python3 entry_filter_trainer.py

Outputs:
    - spy_optimizer/models/entry_filter.pkl    → modèle sérialisé
    - spy_optimizer/models/entry_filter_rules.json → règles + seuils
"""

import json
import pickle
import sys
import warnings
from datetime import datetime, timezone
from pathlib import Path

import numpy as np
import pandas as pd

warnings.filterwarnings("ignore")

# ── Chemins ────────────────────────────────────────────────────────────────────
# All paths are resolved relative to this file, not the current working directory.
ROOT_DIR    = Path(__file__).parent
SPY_DIR     = ROOT_DIR / "spy_optimizer"
DATA_DIR    = SPY_DIR / "data"
MODELS_DIR  = SPY_DIR / "models"
# parents=True: when spy_optimizer/ does not exist yet, a plain mkdir() raises
# FileNotFoundError at import time, before the "dataset not found" check in
# main() can print a helpful message.
MODELS_DIR.mkdir(parents=True, exist_ok=True)

PARQUET     = DATA_DIR / "combined_training_dataset.parquet"

# ── Config ─────────────────────────────────────────────────────────────────────
REAL_WEIGHT = 50    # Training sample weight of a real trade; every other row weighs 1.0
TEST_RATIO  = 0.20  # Last 20% of the rows (dataset is sorted by entry_time) kept for validation

# Features available AT entry time (no look-ahead).
# Only the names actually present in the dataset are used for training.
ENTRY_FEATURES = [
    "return_1m", "return_3m", "return_5m", "return_10m", "return_15m",
    "price_vs_ema7", "price_vs_ema21", "ema7_vs_ema21",
    "ema7_slope_3m", "ema7_slope_5m", "ema21_slope_5m",
    "ema_gap_current", "ema_converging",
    "rsi_14", "rsi_delta_3m", "rsi_oversold", "rsi_overbought",
    "bb_width", "bb_position", "bb_squeeze",
    "volume_ratio_1m", "volume_ratio_3m", "volume_spike", "volume_trend_5m",
    "buy_pressure_5m", "buy_pressure_1m",
    "green_candles_5m", "green_candles_3m", "avg_body_ratio_5m", "upper_wick_ratio",
    "momentum_3m", "momentum_5m", "momentum_10m", "monotonic_up_5m",
    "price_position_30m", "price_position_60m", "breakout_pct",
    "avg_trades_per_min_5m", "trade_intensity_ratio",
    "hour_utc", "is_asia_session", "is_europe_session", "is_us_session",
    "surge_strength",
]

# Integer code for each surge_type label. load_data() maps any label missing
# from this table to 0, i.e. the same code as FLASH_SURGE.
SURGE_TYPE_ENCODING = {"FLASH_SURGE": 0, "BREAKOUT_SURGE": 1, "MOMENTUM_SURGE": 2}


# ── Helpers ────────────────────────────────────────────────────────────────────
def banner(title):
    """Print ``title`` framed by two 60-character rules, preceded by a blank line."""
    rule = "═" * 60
    print(f"\n{rule}\n  {title}\n{rule}")


def load_data():
    """Read the training parquet and return it sorted by entry time.

    Adds two derived columns:
      * ``surge_type_enc`` — integer code from SURGE_TYPE_ENCODING
        (labels missing from the table become 0);
      * ``is_real`` — 1 where ``source == "real"``, else 0.

    Unparseable ``entry_time`` values are coerced to NaT.
    """
    frame = pd.read_parquet(PARQUET)
    frame["entry_time"] = pd.to_datetime(frame["entry_time"], errors="coerce")
    frame = frame.sort_values("entry_time").reset_index(drop=True)

    encoded = frame["surge_type"].map(SURGE_TYPE_ENCODING)
    frame["surge_type_enc"] = encoded.fillna(0).astype(int)
    frame["is_real"] = frame["source"].eq("real").astype(int)
    return frame


# ── 1. Analyse des règles sur vrais trades ─────────────────────────────────────
def analyze_rules(df):
    """Derive simple entry thresholds from the real trades and report their effect.

    Only rows with ``source == "real"`` are used. Five thresholds are taken as
    quantiles of the corresponding feature; for each one the win rate (mean of
    ``target_profitable``) on both sides of the threshold is printed, followed
    by the win rate per ``surge_type`` and the effect of applying all five
    thresholds together.

    Args:
        df: Dataset with at least the columns ``source``, ``surge_type``,
            ``target_profitable``, ``target_pnl_pct``, ``return_1m``,
            ``price_vs_ema7``, ``upper_wick_ratio``, ``rsi_14`` and
            ``volume_ratio_1m``.

    Returns:
        dict with the keys ``return_1m_min``, ``price_vs_ema7_min``,
        ``upper_wick_ratio_max``, ``rsi_14_min``, ``volume_ratio_1m_min``
        (plain ``float`` values) and ``_stats`` (before/after summary of the
        combined filter).

    Raises:
        ValueError: if ``df`` contains no real trade.
    """
    banner("ANALYSE RÈGLES — VRAIS TRADES")
    real = df[df["source"] == "real"].copy()
    n = len(real)
    if n == 0:
        # Without this guard the function fails further down with an opaque
        # ZeroDivisionError when computing the pass rate.
        raise ValueError(
            "Aucun vrai trade (source == 'real') dans le dataset : "
            "impossible de calculer les règles d'entrée"
        )
    wr_global = real["target_profitable"].mean()
    print(f"  Vrais trades : {n} | WR global : {wr_global*100:.1f}%")
    print()

    rules = {}

    def win_rate(row_mask):
        """Win rate of the real trades selected by ``row_mask``."""
        return real.loc[row_mask, "target_profitable"].mean()

    def quartile_rule(feat, direction="min"):
        """Print the win rate on each side of a quartile cut and return the cut.

        "min" rules cut at Q1 and reject values <= cut; "max" rules cut at Q3
        and reject values >= cut, matching the combined filter below.
        """
        col = real[feat]
        if direction == "min":
            q = col.quantile(0.25)
            rejected, kept = col <= q, col > q
            sym, inv = "≤", ">"
        else:
            q = col.quantile(0.75)
            rejected, kept = col >= q, col < q
            sym, inv = "≥", "<"
        reject_wr = win_rate(rejected)
        threshold_wr = win_rate(kept)
        print(f"  {feat:30s} {sym}{q:.3f} → WR {reject_wr*100:.0f}%  |  {inv}{q:.3f} → WR {threshold_wr*100:.0f}%")
        # float() first: the result is then a plain Python float, which is
        # JSON-serialisable whatever the column dtype.
        return round(float(q), 3)

    rules["return_1m_min"]        = quartile_rule("return_1m",        "min")
    rules["price_vs_ema7_min"]    = quartile_rule("price_vs_ema7",    "min")
    rules["upper_wick_ratio_max"] = quartile_rule("upper_wick_ratio", "max")
    rules["rsi_14_min"]           = round(float(real["rsi_14"].quantile(0.50)), 1)
    rules["volume_ratio_1m_min"]  = round(float(real["volume_ratio_1m"].quantile(0.30)), 2)

    # RSI and volume use other quantiles (median / 30th percentile), so their
    # report lines are printed here rather than through quartile_rule().
    q_r = rules["rsi_14_min"]
    wr_lo = win_rate(real["rsi_14"] < q_r)
    wr_hi = win_rate(real["rsi_14"] >= q_r)
    print(f"  {'rsi_14':30s} <{q_r:.1f} → WR {wr_lo*100:.0f}%  |  ≥{q_r:.1f} → WR {wr_hi*100:.0f}%")
    q_v = rules["volume_ratio_1m_min"]
    wr_lo = win_rate(real["volume_ratio_1m"] < q_v)
    wr_hi = win_rate(real["volume_ratio_1m"] >= q_v)
    print(f"  {'volume_ratio_1m':30s} <{q_v:.2f} → WR {wr_lo*100:.0f}%  |  ≥{q_v:.2f} → WR {wr_hi*100:.0f}%")

    print()
    # Win rate per surge type
    by_surge = real.groupby("surge_type")["target_profitable"].agg(["mean", "count"])
    for st, row in by_surge.iterrows():
        # iterrows() upcasts the row to float, so cast the count back to int
        # to avoid printing "n=12.0".
        print(f"  {st:<22} WR {row['mean']*100:.0f}% (n={int(row['count'])})")

    # Combined filter: a trade passes only if it satisfies all five rules
    print()
    mask = (
        (real["return_1m"]        > rules["return_1m_min"]) &
        (real["price_vs_ema7"]    > rules["price_vs_ema7_min"]) &
        (real["upper_wick_ratio"] < rules["upper_wick_ratio_max"]) &
        (real["rsi_14"]           >= rules["rsi_14_min"]) &
        (real["volume_ratio_1m"]  >= rules["volume_ratio_1m_min"])
    )
    passed   = real[mask]
    rejected = real[~mask]
    print(f"  Filtre combiné → passe {len(passed)}/{n} ({len(passed)/n*100:.0f}%)")
    print(f"  WR passé   : {passed['target_profitable'].mean()*100:.1f}%  vs {wr_global*100:.1f}% avant")
    print(f"  WR rejeté  : {rejected['target_profitable'].mean()*100:.1f}%")
    print(f"  PnL passé  : {passed['target_pnl_pct'].mean():.3f}%")
    print(f"  PnL rejeté : {rejected['target_pnl_pct'].mean():.3f}%")

    has_passed = len(passed) > 0
    rules["_stats"] = {
        "n_real":    n,
        "n_passed":  int(mask.sum()),
        "wr_before": round(float(wr_global), 4),
        "wr_after":  round(float(passed["target_profitable"].mean()), 4) if has_passed else 0,
        "pnl_before": round(float(real["target_pnl_pct"].mean()), 4),
        "pnl_after": round(float(passed["target_pnl_pct"].mean()), 4) if has_passed else 0,
    }
    return rules


# ── 2. Entraînement XGBoost (CPU) ─────────────────────────────────────────────
def train_model(df, rules):
    """Train the XGBoost entry filter on CPU and pick a decision threshold.

    The rows of ``df`` are split in order: the first ``1 - TEST_RATIO`` share
    is used for training (real trades weighted ``REAL_WEIGHT``, others 1.0),
    the rest for early stopping, AUC reporting and threshold selection.
    Hyper-parameters are read from ``gpu_optimization_results.json`` in
    DATA_DIR when that file provides them, otherwise built-in defaults apply.

    Args:
        df: Dataset with ``target_profitable``, ``is_real``,
            ``surge_type_enc`` and any subset of ENTRY_FEATURES.
        rules: Unused; kept so existing callers keep working.

    Returns:
        ``None`` when xgboost / scikit-learn cannot be imported or the dataset
        is too small to split; otherwise a dict with the keys ``model``
        (booster truncated to its best iteration), ``features``,
        ``threshold``, ``auc``, ``precision`` and ``coverage``.
    """
    banner("ENTRAÎNEMENT — XGBoost CPU")

    try:
        import xgboost as xgb
        from sklearn.metrics import roc_auc_score
    except ImportError as exc:
        # Either package may be the missing one, so report the actual error.
        print(f"  ❌ Dépendance manquante ({exc}): pip install xgboost scikit-learn")
        return None

    features = [f for f in ENTRY_FEATURES if f in df.columns] + ["surge_type_enc"]
    X = df[features].fillna(0).values
    y = df["target_profitable"].values

    n = len(df)
    split = int(n * (1 - TEST_RATIO))
    if split == 0 or split == n:
        print(f"  ❌ Dataset trop petit pour un split train/val ({n} exemples)")
        return None
    X_tr, X_va = X[:split], X[split:]
    y_tr, y_va = y[:split], y[split:]

    is_real_tr = df["is_real"].values[:split]
    w_tr = np.where(is_real_tr == 1, REAL_WEIGHT, 1.0)

    print(f"  Train: {split} | Val: {n-split} | Vrais trains: {is_real_tr.sum()} (x{REAL_WEIGHT})")

    # Load the best parameters found by the previous optimisation run, if any
    best_xgb = {}
    opt_path = DATA_DIR / "gpu_optimization_results.json"
    if opt_path.exists():
        with open(opt_path) as f:
            best_xgb = json.load(f).get("tabular", {}).get("xgb_params", {})
        print("  Params: chargés depuis gpu_optimization_results.json")

    if not best_xgb:
        best_xgb = {"learning_rate": 0.027, "max_depth": 6, "min_child_weight": 13,
                    "subsample": 0.987, "colsample_bytree": 0.974,
                    "reg_alpha": 0.0, "reg_lambda": 1.2, "gamma": 4.3}

    # These keys are handled separately and must not reach xgb.train() as-is.
    n_rounds = int(best_xgb.pop("num_boost_round", 400))
    grow_policy = best_xgb.pop("grow_policy", "depthwise")
    max_leaves = int(best_xgb.pop("max_leaves", 0))

    # Keys listed after **best_xgb win, so a loaded tree_method / device is
    # always overridden with the CPU settings.
    params = {
        **best_xgb,
        "objective":   "binary:logistic",
        "eval_metric": "auc",
        "tree_method": "hist",   # CPU
        "device":      "cpu",
        "nthread":     4,        # Limit thread usage on the server
        "seed":        42,
        "grow_policy": grow_policy,
    }
    if max_leaves:
        params["max_leaves"] = max_leaves

    dtrain = xgb.DMatrix(X_tr, label=y_tr, weight=w_tr)
    dval   = xgb.DMatrix(X_va, label=y_va)

    print(f"  Entraînement CPU... ({n_rounds} rounds, early_stopping=30)")
    model = xgb.train(
        params, dtrain,
        num_boost_round=n_rounds,
        evals=[(dval, "val")],
        verbose_eval=100,
        early_stopping_rounds=30,
    )

    # xgb.train() returns the booster from the LAST round, not the best one.
    # Slice it so the threshold search, evaluation and the pickled model all
    # use the same best-iteration model.
    best_iter = getattr(model, "best_iteration", None)
    if best_iter is not None:
        model = model[0 : int(best_iter) + 1]

    proba = model.predict(dval)
    auc   = float(roc_auc_score(y_va, proba))
    print(f"\n  AUC val : {auc:.4f}")

    # Threshold with the highest precision, subject to precision > 55% and
    # coverage > 20% on the validation set.
    best_thr, best_prec, best_cov = 0.5, 0.0, 0.0
    found = False
    for raw_thr in np.arange(0.35, 0.72, 0.02):
        # Round away the float noise of np.arange (0.39000000000000007 ...)
        # so the stored threshold is exactly the one that was evaluated.
        thr  = round(float(raw_thr), 2)
        pred = proba >= thr
        n_p  = int(pred.sum())
        if n_p < 10:
            continue
        prec = float(y_va[pred].mean())
        cov  = n_p / len(y_va)
        if prec > 0.55 and cov > 0.20 and prec > best_prec:
            best_prec, best_thr, best_cov = prec, thr, cov
            found = True

    if not found:
        print("  ⚠️  Aucun seuil ne satisfait precision > 55% et coverage > 20% — seuil par défaut 0.50")
    print(f"  Seuil optimal: {best_thr:.2f} | Precision: {best_prec*100:.1f}% | Coverage: {best_cov*100:.1f}%")

    # Feature importance. The DMatrix is built from a bare array, so xgboost
    # names the columns f0, f1, ...; map them back to the real names.
    importance = model.get_score(importance_type="gain")
    fi_sorted = sorted(importance.items(), key=lambda x: x[1], reverse=True)[:15]
    print("\n  Top 15 features (gain):")
    for fname, gain in fi_sorted:
        idx = int(fname[1:]) if fname.startswith("f") and fname[1:].isdigit() else -1
        label = features[idx] if 0 <= idx < len(features) else fname
        print(f"    {label:<32} {gain:.1f}")

    return {
        "model":      model,
        "features":   features,
        "threshold":  best_thr,
        "auc":        auc,
        "precision":  best_prec,
        "coverage":   best_cov,
    }


# ── 3. Évaluation filtre combiné ───────────────────────────────────────────────
def evaluate(df, model_data, rules):
    """Compare rule-only, ML-only and combined filtering on the validation rows.

    The validation rows are the last ``TEST_RATIO`` share of ``df``, using the
    same split formula as train_model(). Win rate and mean PnL are printed for
    each filter, then the combined filter's win rate per ``surge_type``.

    Args:
        df: Same dataset (same row order) that was given to train_model().
        model_data: Dict returned by train_model(); ``model``, ``features``
            and ``threshold`` are read.
        rules: Dict returned by analyze_rules(); a missing threshold falls
            back to a value that lets every row pass.

    Returns:
        ``None`` if xgboost cannot be imported; otherwise a dict with
        ``n_val``, ``wr_raw``, ``wr_combo`` and ``n_combo``.
    """
    banner("RÉSULTAT FILTRE COMBINÉ (val set)")

    try:
        import xgboost as xgb
    except ImportError:
        print("  ❌ xgboost manquant: pip install xgboost")
        return None

    features = model_data["features"]
    thr      = model_data["threshold"]
    n        = len(df)
    split    = int(n * (1 - TEST_RATIO))
    val      = df.iloc[split:].copy()

    dval       = xgb.DMatrix(val[features].fillna(0).values)
    val["ml"]  = model_data["model"].predict(dval)

    val["rules_ok"] = (
        (val["return_1m"]        > rules.get("return_1m_min", -999)) &
        (val["price_vs_ema7"]    > rules.get("price_vs_ema7_min", -999)) &
        (val["upper_wick_ratio"] < rules.get("upper_wick_ratio_max", 999)) &
        (val["rsi_14"]           >= rules.get("rsi_14_min", 0)) &
        (val["volume_ratio_1m"]  >= rules.get("volume_ratio_1m_min", 0))
    )
    val["ml_ok"] = val["ml"] >= thr

    def stats(mask, label):
        """Print size, win rate and mean PnL of the rows selected by ``mask``."""
        sub = val[mask]
        if len(sub) == 0:
            print(f"  {label:<22} n=0")
            return
        print(f"  {label:<22} n={len(sub):>5}  WR={sub['target_profitable'].mean()*100:.1f}%  PnL={sub['target_pnl_pct'].mean():.3f}%")

    stats(pd.Series([True]*len(val), index=val.index), "Aucun filtre")
    stats(val["rules_ok"],                              "Règles seules")
    stats(val["ml_ok"],                                 "ML seul")
    stats(val["ml_ok"] & val["rules_ok"],               "ML + Règles")

    print()
    combo = val[val["ml_ok"] & val["rules_ok"]]
    # dropna(): sorted() raises TypeError when the column mixes str labels
    # with NaN (missing surge_type).
    for st in sorted(val["surge_type"].dropna().unique()):
        sub = combo[combo["surge_type"] == st]
        if len(sub) > 0:
            print(f"    {st:<22} n={len(sub):>5}  WR={sub['target_profitable'].mean()*100:.0f}%")

    # Plain floats so the result serialises to JSON whatever the column dtype.
    return {
        "n_val":     len(val),
        "wr_raw":    round(float(val["target_profitable"].mean()), 4),
        "wr_combo":  round(float(combo["target_profitable"].mean()), 4) if len(combo) > 0 else 0,
        "n_combo":   len(combo),
    }


# ── 4. Sauvegarde ──────────────────────────────────────────────────────────────
def save(model_data, rules, eval_res):
    """Persist the trained filter and print the constants for market_spy.py.

    Writes two files into MODELS_DIR:
      * ``entry_filter.pkl`` — ``model_data`` pickled as-is;
      * ``entry_filter_rules.json`` — timestamp, model metrics, the rules
        whose key does not start with ``_``, and ``eval_res`` (``{}`` if falsy).

    Then prints a ready-to-paste ``ENTRY_FILTER_RULES`` snippet built from
    ``rules`` and ``rules["_stats"]``.
    """
    banner("SAUVEGARDE")

    pkl_target = MODELS_DIR / "entry_filter.pkl"
    with open(pkl_target, "wb") as fh:
        pickle.dump(model_data, fh)
    print(f"  ✅ Modèle : {pkl_target}")

    # Keys starting with "_" (e.g. "_stats") are internal and stay out of the JSON.
    public_rules = {}
    for key, value in rules.items():
        if key.startswith("_"):
            continue
        public_rules[key] = value

    json_target = MODELS_DIR / "entry_filter_rules.json"
    payload = {
        "generated_at":    datetime.now(timezone.utc).isoformat(),
        "model_threshold": model_data["threshold"],
        "model_auc":       round(model_data["auc"], 4),
        "model_precision": round(model_data["precision"], 4),
        "rules":           public_rules,
        "eval":            eval_res if eval_res else {},
    }
    with open(json_target, "w") as fh:
        json.dump(payload, fh, indent=2)
    print(f"  ✅ Règles : {json_target}")

    summary = rules["_stats"]
    print(f"""
  ── Constantes à ajouter dans market_spy.py ──────────────────

# ═══ ENTRY FILTER IA — généré {datetime.now().strftime('%d/%m/%Y')} ═══
# {summary['n_real']} vrais trades | WR: {summary['wr_before']*100:.1f}% → {summary['wr_after']*100:.1f}% après filtre
ENTRY_FILTER_RULES = {{
    'return_1m_min':        {rules['return_1m_min']},
    'price_vs_ema7_min':    {rules['price_vs_ema7_min']},
    'upper_wick_ratio_max': {rules['upper_wick_ratio_max']},
    'rsi_14_min':           {rules['rsi_14_min']},
    'volume_ratio_1m_min':  {rules['volume_ratio_1m_min']},
}}
""")


# ── Main ───────────────────────────────────────────────────────────────────────
def main():
    """Run the full pipeline: load, analyse rules, train, evaluate, save.

    Returns:
        Process exit status: 0 on success, 1 when the dataset file is missing
        or when train_model() returned no model.
    """
    print(f"\n╔{'═'*58}╗")
    print(f"║  ENTRY FILTER TRAINER (CPU) — {datetime.now().strftime('%d/%m/%Y %H:%M'):<27}║")
    print(f"╚{'═'*58}╝")

    if not PARQUET.exists():
        print(f"❌ Dataset introuvable: {PARQUET}")
        return 1

    df = load_data()
    n_real = (df["source"] == "real").sum()
    print(f"  Dataset: {len(df)} exemples | {n_real} vrais trades")

    rules      = analyze_rules(df)
    model_data = train_model(df, rules)

    if model_data is None:
        # train_model() has already printed the reason.
        return 1

    eval_res = evaluate(df, model_data, rules)
    save(model_data, rules, eval_res)

    print("\n  ✅ Terminé.")
    return 0


if __name__ == "__main__":
    # Propagate the status so cron / shell callers can detect a failed run.
    sys.exit(main())
