#!/usr/bin/env python3
"""
SPY Optimizer — GPU Hyperparameter Optimization
Utilise Optuna + XGBoost GPU pour trouver les meilleurs paramètres.

Usage (sur le PC avec GPU):
  python optimize_gpu.py                     # 500 trials Optuna
  python optimize_gpu.py --trials 2000       # Plus de trials
  python optimize_gpu.py --backtest          # Walk-forward avec best params
  python optimize_gpu.py --full              # Optimize + backtest + export

Le résultat est un fichier JSON avec les meilleurs hyperparamètres,
à déployer sur le serveur pour le signal_classifier.
"""
import argparse
import json
import os
import sys
import time
import warnings
from datetime import datetime, timezone
from pathlib import Path

import numpy as np
import pandas as pd
import lightgbm as lgb
import xgboost as xgb
import optuna
from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score
from sklearn.model_selection import TimeSeriesSplit

optuna.logging.set_verbosity(optuna.logging.WARNING)  # silence per-trial Optuna logs
warnings.filterwarnings("ignore")

# ─── Paths (compatible PC + serveur) ───
PROJECT_DIR = Path(__file__).parent
DATA_DIR = PROJECT_DIR / "data"
KLINES_DIR = DATA_DIR / "klines_1m"  # 1-minute candle files used to rebuild the dataset
MODELS_DIR = PROJECT_DIR / "models"
# Trade history may sit next to this script (PC) or one directory up (server).
_hist_local = PROJECT_DIR / "espion_history.json"
_hist_parent = PROJECT_DIR.parent / "espion_history.json"
HISTORY_FILE = _hist_local if _hist_local.exists() else _hist_parent
RESULTS_FILE = DATA_DIR / "gpu_optimization_results.json"

# NOTE(review): project-local import placed after path constants — move to the
# top import group unless feature_engineering depends on import-time side effects.
from feature_engineering import build_dataset_from_trades, FEATURE_COLUMNS, SURGE_TYPES

# ─── GPU Detection ───

def detect_device():
    """Detect whether a CUDA GPU is usable for XGBoost training.

    Returns "cuda" when PyTorch reports an available CUDA device, or when a
    tiny 1-round XGBoost training run on device "cuda" succeeds; "cpu" otherwise.
    """
    def _probe_torch():
        # Preferred path: torch also reports the device name and memory size.
        try:
            import torch
        except ImportError:
            return False
        if not torch.cuda.is_available():
            return False
        name = torch.cuda.get_device_name(0)
        mem = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"  🎮 GPU: {name} ({mem:.1f} GB)")
        return True

    def _probe_xgb():
        # Fallback: attempt an actual (throwaway) GPU fit with XGBoost.
        try:
            probe = xgb.DMatrix(np.random.randn(10, 3), label=np.random.randint(0, 2, 10))
            xgb.train({"device": "cuda", "verbosity": 0}, probe, num_boost_round=1)
        except Exception:
            return False
        print("  🎮 GPU: XGBoost CUDA available")
        return True

    if _probe_torch() or _probe_xgb():
        return "cuda"

    print("  ⚠️  No GPU detected, using CPU")
    return "cpu"


# ─── Data ───

def load_dataset() -> pd.DataFrame:
    """Load the training dataset, building it from trade history if needed.

    Preference order: combined parquet (historical + live trades), then the
    plain parquet, then a rebuild from trades (cached back to parquet).
    """
    combined_path = DATA_DIR / "combined_training_dataset.parquet"
    dataset_path = DATA_DIR / "training_dataset.parquet"

    if combined_path.exists():
        df = pd.read_parquet(combined_path)
        print(f"  📂 Combined dataset chargé: {len(df)} samples (historique + réel)")
        return df

    if dataset_path.exists():
        df = pd.read_parquet(dataset_path)
        print(f"  📂 Dataset chargé: {len(df)} samples")
        return df

    # Nothing cached: rebuild from the raw trade history and persist it.
    print("  📂 Building dataset from trades...")
    df = build_dataset_from_trades(load_trades(), str(KLINES_DIR))
    df.to_parquet(dataset_path, index=False)
    print(f"  ✅ Dataset: {len(df)} samples")
    return df


def load_trades() -> list:
    """Read the trade history file and keep only trades usable for training.

    A trade qualifies when it has both a truthy surge_type and entry_time.
    """
    raw = json.loads(HISTORY_FILE.read_text())
    usable = [t for t in raw if t.get("surge_type") and t.get("entry_time")]
    print(f"  📂 {len(usable)} trades avec surge_type")
    return usable


def prepare_features(dataset: pd.DataFrame):
    """Turn the trade dataset into model inputs.

    Sorts chronologically, keeps the known feature columns, encodes
    surge_type as an integer (unknown types map to len(SURGE_TYPES)),
    and fills remaining NaNs with 0.

    Returns:
        (X, y, feature_names): 2-D numpy feature matrix, binary target
        array, and the ordered list of feature column names.
    """
    ordered = dataset.sort_values("entry_time").reset_index(drop=True)

    y = ordered["target_profitable"].values

    # Keep only the numeric features actually present in this dataset.
    feature_cols = [c for c in FEATURE_COLUMNS if c in ordered.columns]
    feats = ordered[feature_cols].copy()

    has_surge = "surge_type" in ordered.columns
    if has_surge:
        # Stable integer encoding; unseen types fall back to len(SURGE_TYPES).
        surge_map = {s: i for i, s in enumerate(SURGE_TYPES)}
        feats["surge_type_encoded"] = ordered["surge_type"].map(surge_map).fillna(len(SURGE_TYPES))

    names = feature_cols + (["surge_type_encoded"] if has_surge else [])
    return feats.fillna(0).values, y, names


# ─── Optuna Optimization ───

def optimize_xgb_gpu(X_train, y_train, X_val, y_val, n_trials: int, device: str):
    """Optuna hyperparameter search for XGBoost (GPU-capable).

    Each trial trains with early stopping on the validation split and is
    scored as 0.6 * validation AUC + 0.4 * compute_pnl_score (how well the
    probabilities filter losing trades).

    Returns:
        (best_params, best_value): the winning trial's parameter dict and
        its objective score. (The old ``-> dict`` annotation was wrong —
        this has always returned a 2-tuple.)
    """
    # Class-imbalance correction: ratio of negatives to positives.
    pos_weight = (len(y_train) - y_train.sum()) / max(y_train.sum(), 1)

    # DMatrix construction is trial-invariant — build once instead of
    # re-wrapping the same arrays on every one of the n_trials trials.
    dtrain = xgb.DMatrix(X_train, label=y_train)
    dval = xgb.DMatrix(X_val, label=y_val)

    def objective(trial):
        params = {
            "objective": "binary:logistic",
            "eval_metric": "auc",
            "device": device,
            "scale_pos_weight": pos_weight,
            "verbosity": 0,
            "learning_rate": trial.suggest_float("learning_rate", 0.005, 0.3, log=True),
            "max_depth": trial.suggest_int("max_depth", 2, 12),
            "min_child_weight": trial.suggest_int("min_child_weight", 1, 50),
            "subsample": trial.suggest_float("subsample", 0.4, 1.0),
            "colsample_bytree": trial.suggest_float("colsample_bytree", 0.3, 1.0),
            "colsample_bylevel": trial.suggest_float("colsample_bylevel", 0.3, 1.0),
            "reg_alpha": trial.suggest_float("reg_alpha", 1e-8, 100.0, log=True),
            "reg_lambda": trial.suggest_float("reg_lambda", 1e-8, 100.0, log=True),
            "gamma": trial.suggest_float("gamma", 0.0, 10.0),
            "max_delta_step": trial.suggest_float("max_delta_step", 0.0, 10.0),
            "grow_policy": trial.suggest_categorical("grow_policy", ["depthwise", "lossguide"]),
        }
        # max_leaves is only meaningful for lossguide growth.
        if params["grow_policy"] == "lossguide":
            params["max_leaves"] = trial.suggest_int("max_leaves", 15, 255)

        num_boost_round = trial.suggest_int("num_boost_round", 50, 1000)

        model = xgb.train(
            params, dtrain,
            num_boost_round=num_boost_round,
            evals=[(dval, "val")],
            early_stopping_rounds=30,
            verbose_eval=False,
        )
        pred = model.predict(dval)

        auc = roc_auc_score(y_val, pred)

        # Blend ranking quality with usefulness as a trade filter.
        pnl_score = compute_pnl_score(pred, y_val)
        return auc * 0.6 + pnl_score * 0.4

    study = optuna.create_study(direction="maximize", study_name="xgb_gpu")
    study.optimize(objective, n_trials=n_trials, show_progress_bar=True, n_jobs=1)

    print(f"\n  🏆 XGBoost Best trial: {study.best_value:.4f}")
    return study.best_params, study.best_value


def optimize_lgb(X_train, y_train, X_val, y_val, n_trials: int):
    """Optuna hyperparameter search for LightGBM (CPU — GPU builds of
    LightGBM are painful to produce on Windows).

    Each trial trains with early stopping on the validation split and is
    scored as 0.6 * validation AUC + 0.4 * compute_pnl_score (how well the
    probabilities filter losing trades).

    Returns:
        (best_params, best_value): the winning trial's parameter dict and
        its objective score. (The old ``-> dict`` annotation was wrong —
        this has always returned a 2-tuple.)
    """
    # Class-imbalance correction: ratio of negatives to positives.
    pos_weight = (len(y_train) - y_train.sum()) / max(y_train.sum(), 1)

    def objective(trial):
        params = {
            "objective": "binary",
            "metric": "auc",
            "verbosity": -1,
            "scale_pos_weight": pos_weight,
            "learning_rate": trial.suggest_float("learning_rate", 0.005, 0.3, log=True),
            "num_leaves": trial.suggest_int("num_leaves", 7, 255),
            "max_depth": trial.suggest_int("max_depth", 2, 15),
            "min_child_samples": trial.suggest_int("min_child_samples", 3, 60),
            "subsample": trial.suggest_float("subsample", 0.4, 1.0),
            "subsample_freq": trial.suggest_int("subsample_freq", 1, 10),
            "colsample_bytree": trial.suggest_float("colsample_bytree", 0.3, 1.0),
            "reg_alpha": trial.suggest_float("reg_alpha", 1e-8, 100.0, log=True),
            "reg_lambda": trial.suggest_float("reg_lambda", 1e-8, 100.0, log=True),
            "min_split_gain": trial.suggest_float("min_split_gain", 0.0, 2.0),
            "path_smooth": trial.suggest_float("path_smooth", 0.0, 10.0),
        }

        # Rebuilt per trial on purpose: trials run concurrently (n_jobs=4)
        # and lgb.Dataset objects are not safe to share across trainings.
        train_set = lgb.Dataset(X_train, y_train)
        val_set = lgb.Dataset(X_val, y_val, reference=train_set)

        model = lgb.train(
            params, train_set,
            num_boost_round=1000,
            valid_sets=[val_set],
            callbacks=[lgb.early_stopping(30), lgb.log_evaluation(period=0)],
        )
        pred = model.predict(X_val)

        auc = roc_auc_score(y_val, pred)
        pnl_score = compute_pnl_score(pred, y_val)
        return auc * 0.6 + pnl_score * 0.4

    study = optuna.create_study(direction="maximize", study_name="lgb_opt")
    study.optimize(objective, n_trials=n_trials, show_progress_bar=True, n_jobs=4)

    print(f"\n  🏆 LightGBM Best trial: {study.best_value:.4f}")
    return study.best_params, study.best_value


def compute_pnl_score(pred_proba, y_true):
    """Normalized score of how well probabilities filter out losing trades.

    Sweeps thresholds 0.30..0.74 (step 0.02); any cutoff keeping fewer than
    max(3, 10% of samples) is skipped. Each qualifying cutoff scores
    0.65 * precision + 0.35 * recall; the best score is returned (0 when no
    cutoff qualifies).
    """
    min_kept = max(3, len(pred_proba) * 0.10)
    best = 0
    for thresh in np.arange(0.30, 0.75, 0.02):
        kept = pred_proba >= thresh
        if kept.sum() < min_kept:
            continue
        tp = (kept & (y_true == 1)).sum()
        fp = (kept & (y_true == 0)).sum()
        fn = ((~kept) & (y_true == 1)).sum()
        precision = tp / max(tp + fp, 1)
        recall = tp / max(tp + fn, 1)
        # Weight precision (filtering out losers) above recall.
        candidate = precision * 0.65 + recall * 0.35
        if candidate > best:
            best = candidate
    return best


def find_optimal_threshold(pred_proba, y_true, pnl_values=None):
    """Pick the probability cutoff that maximizes a profit-oriented score.

    Scans thresholds 0.25..0.79 (step 0.01), skipping cutoffs that keep
    fewer than max(3, 10% of samples). With pnl_values, the score rewards
    the PnL change from filtering plus precision * 10; without, it is
    0.65 * precision + 0.35 * recall. Returns 0.5 when nothing qualifies.
    """
    min_kept = max(3, len(pred_proba) * 0.10)
    best_threshold, best_score = 0.5, -999

    for thresh in np.arange(0.25, 0.80, 0.01):
        kept = pred_proba >= thresh
        if kept.sum() < min_kept:
            continue

        tp = (kept & (y_true == 1)).sum()
        fp = (kept & (y_true == 0)).sum()
        precision = tp / max(tp + fp, 1)

        if pnl_values is None:
            fn = ((~kept) & (y_true == 1)).sum()
            recall = tp / max(fn + tp, 1)
            score = precision * 0.65 + recall * 0.35
        else:
            # Reward PnL improvement from filtering, nudged by precision.
            score = pnl_values[kept].sum() - pnl_values.sum() + precision * 10

        if score > best_score:
            best_score, best_threshold = score, thresh

    return best_threshold


# ─── Walk-Forward Backtest ───

def walk_forward_backtest(dataset, best_lgb_params, best_xgb_params, device, n_splits=5):
    """Walk-forward backtest of the 50/50 LGB+XGB ensemble with tuned params.

    For each expanding TimeSeriesSplit fold: trains both models on the past,
    predicts the held-out future fold, picks a per-fold threshold, and logs
    AUC / precision / filtered-vs-unfiltered PnL.

    Returns a list of per-fold metric dicts.

    NOTE(review): the threshold is chosen on the same fold it is evaluated
    on (find_optimal_threshold sees y_va / pnl_va), so per-fold PnL deltas
    are optimistic — confirm this is intended.
    """
    dataset = dataset.sort_values("entry_time").reset_index(drop=True)
    X, y, feature_names = prepare_features(dataset)
    # Per-trade PnL, used for threshold tuning and reporting when present.
    pnl_values = dataset["target_pnl_pct"].values if "target_pnl_pct" in dataset.columns else None

    tscv = TimeSeriesSplit(n_splits=n_splits)
    results = []

    print(f"\n{'═'*60}")
    print(f"  📈 Walk-Forward Backtest ({n_splits} folds)")
    print(f"{'═'*60}")

    for fold, (train_idx, val_idx) in enumerate(tscv.split(X)):
        X_tr, X_va = X[train_idx], X[val_idx]
        y_tr, y_va = y[train_idx], y[val_idx]
        pnl_va = pnl_values[val_idx] if pnl_values is not None else None

        # Recomputed per fold: class balance shifts as the train window grows.
        pos_weight = (len(y_tr) - y_tr.sum()) / max(y_tr.sum(), 1)

        # LightGBM
        lgb_params = {
            **best_lgb_params,
            "objective": "binary", "metric": "auc",
            "verbosity": -1, "scale_pos_weight": pos_weight,
        }
        train_set = lgb.Dataset(X_tr, y_tr)
        val_set = lgb.Dataset(X_va, y_va, reference=train_set)
        lgb_model = lgb.train(
            lgb_params, train_set, num_boost_round=1000,
            valid_sets=[val_set],
            callbacks=[lgb.early_stopping(30), lgb.log_evaluation(0)],
        )
        lgb_pred = lgb_model.predict(X_va)

        # XGBoost GPU
        xgb_params = {
            **best_xgb_params,
            "objective": "binary:logistic", "eval_metric": "auc",
            "device": device, "verbosity": 0, "scale_pos_weight": pos_weight,
        }
        # Remove non-xgb keys (num_boost_round is a train() arg, not a param).
        for k in ["num_boost_round"]:
            xgb_params.pop(k, None)
        dtrain = xgb.DMatrix(X_tr, label=y_tr)
        dval = xgb.DMatrix(X_va, label=y_va)
        xgb_model = xgb.train(
            xgb_params, dtrain,
            num_boost_round=best_xgb_params.get("num_boost_round", 500),
            evals=[(dval, "val")],
            early_stopping_rounds=30, verbose_eval=False,
        )
        xgb_pred = xgb_model.predict(dval)

        # Ensemble: simple 50/50 probability average; AUC falls back to 0.5
        # when the fold has a single class.
        pred = 0.5 * lgb_pred + 0.5 * xgb_pred
        auc = roc_auc_score(y_va, pred) if len(np.unique(y_va)) > 1 else 0.5
        threshold = find_optimal_threshold(pred, y_va, pnl_va)
        passed = pred >= threshold
        n_passed = passed.sum()

        prec = precision_score(y_va, passed.astype(int), zero_division=0)

        # PnL: compare the filtered subset against taking every trade.
        if pnl_va is not None:
            pnl_all = pnl_va.sum()
            pnl_filtered = pnl_va[passed].sum()
            delta = pnl_filtered - pnl_all
        else:
            pnl_all = pnl_filtered = delta = 0

        status = "✅" if delta >= 0 else "❌"
        print(f"  Fold {fold+1} | {len(y_va)} trades → {n_passed} passés ({n_passed/len(y_va)*100:.0f}%) | "
              f"AUC: {auc:.3f} | Prec: {prec:.3f} | Threshold: {threshold:.2f} | "
              f"PnL: {pnl_filtered:+.1f}% vs {pnl_all:+.1f}% Δ={delta:+.1f}% {status}")

        results.append({
            "fold": fold + 1,
            "n_val": len(y_va),
            "n_passed": int(n_passed),
            "auc": float(auc),
            "precision": float(prec),
            "threshold": float(threshold),
            "pnl_filtered": float(pnl_filtered),
            "pnl_all": float(pnl_all),
            "delta": float(delta),
        })

    # Summary
    avg_auc = np.mean([r["auc"] for r in results])
    avg_prec = np.mean([r["precision"] for r in results])
    total_delta = sum(r["delta"] for r in results)
    positive_folds = sum(1 for r in results if r["delta"] >= 0)

    print(f"\n{'═'*60}")
    print(f"  📊 Walk-Forward Summary")
    print(f"{'═'*60}")
    print(f"  Avg AUC:        {avg_auc:.3f}")
    print(f"  Avg Precision:  {avg_prec:.3f}")
    print(f"  Positive folds: {positive_folds}/{len(results)}")
    print(f"  Total Δ PnL:    {total_delta:+.1f}%")

    return results


# ─── Main ───

def main():
    """CLI entry point: optimize hyperparameters, evaluate, export.

    Modes:
      default     — Optuna search for XGB (GPU) and LGB, final ensemble fit,
                    results JSON + deployable params JSON.
      --backtest  — walk-forward backtest only, using previously saved params.
      --full      — optimization plus walk-forward backtest.
    """
    parser = argparse.ArgumentParser(description="GPU Hyperparameter Optimization")
    parser.add_argument("--trials", type=int, default=500, help="Optuna trials per model")
    parser.add_argument("--backtest", action="store_true", help="Run walk-forward backtest only")
    parser.add_argument("--full", action="store_true", help="Optimize + backtest + export")
    parser.add_argument("--cpu", action="store_true", help="Force CPU mode")
    parser.add_argument("--folds", type=int, default=5, help="Walk-forward folds")
    args = parser.parse_args()

    print(f"\n{'═'*60}")
    print(f"  🚀 SPY Optimizer — GPU Hyperparameter Search")
    print(f"{'═'*60}")

    device = "cpu" if args.cpu else detect_device()

    # Load data
    dataset = load_dataset()
    X, y, feature_names = prepare_features(dataset)

    # Walk-forward split: chronological 75/25 holdout (prepare_features sorted X).
    split_idx = int(len(X) * 0.75)
    X_train, X_val = X[:split_idx], X[split_idx:]
    y_train, y_val = y[:split_idx], y[split_idx:]

    pos = y.sum()
    print(f"\n  📐 Dataset: {len(y)} samples ({pos:.0f} profitable, {len(y)-pos:.0f} perdants)")
    print(f"  Features: {len(feature_names)}")
    print(f"  Train: {len(y_train)} | Val: {len(y_val)}")

    if args.backtest:
        # Backtest-only mode: reuse params from a previous optimization run.
        if RESULTS_FILE.exists():
            results = json.loads(RESULTS_FILE.read_text())
            best_lgb = results["best_lgb_params"]
            best_xgb = results["best_xgb_params"]
            print(f"  📂 Loaded best params from {RESULTS_FILE}")
        else:
            print("  ❌ No optimization results found. Run without --backtest first.")
            return
        walk_forward_backtest(dataset, best_lgb, best_xgb, device, args.folds)
        return

    # ─── Optimize XGBoost (GPU) ───
    print(f"\n{'─'*60}")
    print(f"  🔍 XGBoost GPU Optimization ({args.trials} trials)")
    print(f"{'─'*60}")

    t0 = time.time()
    best_xgb_params, xgb_score = optimize_xgb_gpu(
        X_train, y_train, X_val, y_val, args.trials, device
    )
    xgb_time = time.time() - t0
    print(f"  ⏱️  XGBoost: {xgb_time:.0f}s ({args.trials/xgb_time:.1f} trials/s)")

    # ─── Optimize LightGBM (CPU, parallel) ───
    print(f"\n{'─'*60}")
    print(f"  🔍 LightGBM Optimization ({args.trials} trials)")
    print(f"{'─'*60}")

    t0 = time.time()
    best_lgb_params, lgb_score = optimize_lgb(
        X_train, y_train, X_val, y_val, args.trials
    )
    lgb_time = time.time() - t0
    print(f"  ⏱️  LightGBM: {lgb_time:.0f}s ({args.trials/lgb_time:.1f} trials/s)")

    # ─── Train final ensemble with best params ───
    print(f"\n{'─'*60}")
    print(f"  🏋️ Training final ensemble with best params")
    print(f"{'─'*60}")

    # Class-imbalance correction for the final fit on the train split.
    pos_weight = (len(y_train) - y_train.sum()) / max(y_train.sum(), 1)

    # LightGBM final
    lgb_params = {
        **best_lgb_params,
        "objective": "binary", "metric": "auc",
        "verbosity": -1, "scale_pos_weight": pos_weight,
    }
    train_set = lgb.Dataset(X_train, y_train)
    val_set = lgb.Dataset(X_val, y_val, reference=train_set)
    lgb_model = lgb.train(
        lgb_params, train_set, num_boost_round=1000,
        valid_sets=[val_set],
        callbacks=[lgb.early_stopping(50), lgb.log_evaluation(0)],
    )
    lgb_pred = lgb_model.predict(X_val)

    # XGBoost final
    xgb_params_final = {
        **best_xgb_params,
        "objective": "binary:logistic", "eval_metric": "auc",
        "device": device, "verbosity": 0, "scale_pos_weight": pos_weight,
    }
    # num_boost_round is a train() argument, not a booster param — pop it out.
    n_boost = xgb_params_final.pop("num_boost_round", 500)
    dtrain = xgb.DMatrix(X_train, label=y_train)
    dval = xgb.DMatrix(X_val, label=y_val)
    xgb_model = xgb.train(
        xgb_params_final, dtrain,
        num_boost_round=n_boost,
        evals=[(dval, "val")],
        early_stopping_rounds=50, verbose_eval=False,
    )
    xgb_pred = xgb_model.predict(dval)

    # Ensemble: 50/50 probability average; AUC needs both classes present.
    ensemble_pred = 0.5 * lgb_pred + 0.5 * xgb_pred
    ensemble_auc = roc_auc_score(y_val, ensemble_pred) if len(np.unique(y_val)) > 1 else 0.5

    # Re-sort to align PnL values with the chronological order used by X/y.
    pnl_values = dataset.sort_values("entry_time")["target_pnl_pct"].values[split_idx:] \
        if "target_pnl_pct" in dataset.columns else None
    # NOTE(review): threshold is tuned on the same validation split used for
    # the reported metrics — confirm the optimism is acceptable.
    threshold = find_optimal_threshold(ensemble_pred, y_val, pnl_values)

    passed = ensemble_pred >= threshold
    prec = precision_score(y_val, passed.astype(int), zero_division=0)
    rec = recall_score(y_val, passed.astype(int), zero_division=0)

    if pnl_values is not None:
        pnl_all = pnl_values.sum()
        pnl_filtered = pnl_values[passed].sum()
        delta = pnl_filtered - pnl_all
    else:
        pnl_all = pnl_filtered = delta = 0

    print(f"\n{'═'*60}")
    print(f"  📊 Final Results (Ensemble)")
    print(f"{'═'*60}")
    print(f"  AUC:            {ensemble_auc:.4f}")
    print(f"  Precision:      {prec:.4f}")
    print(f"  Recall:         {rec:.4f}")
    print(f"  Threshold:      {threshold:.2f}")
    print(f"  Pass rate:      {passed.sum()}/{len(y_val)} ({passed.mean()*100:.0f}%)")
    if pnl_values is not None:
        print(f"  PnL sans filtre:  {pnl_all:+.2f}%")
        print(f"  PnL avec filtre:  {pnl_filtered:+.2f}%")
        print(f"  Amélioration:     {delta:+.2f}%")

    # Feature importance (gain-based, from the final LightGBM model).
    lgb_imp = lgb_model.feature_importance(importance_type="gain")
    print(f"\n  🔑 Top 10 Features (LightGBM):")
    sorted_idx = np.argsort(lgb_imp)[::-1]
    for rank, idx in enumerate(sorted_idx[:10]):
        name = feature_names[idx] if idx < len(feature_names) else f"feat_{idx}"
        bar = "█" * int(lgb_imp[idx] / lgb_imp[sorted_idx[0]] * 20)
        print(f"    {rank+1:2d}. {name:30s} {bar}")

    # ─── Save results ───
    output = {
        "optimized_at": datetime.now(timezone.utc).isoformat(),
        "device": device,
        "n_trials": args.trials,
        "dataset_size": len(y),
        "train_size": len(y_train),
        "val_size": len(y_val),
        "best_lgb_params": best_lgb_params,
        "best_xgb_params": best_xgb_params,
        "lgb_score": float(lgb_score),
        "xgb_score": float(xgb_score),
        "ensemble_metrics": {
            "auc": float(ensemble_auc),
            "precision": float(prec),
            "recall": float(rec),
            "threshold": float(threshold),
            "pass_rate": float(passed.mean()),
            "pnl_without_filter": float(pnl_all) if pnl_values is not None else None,
            "pnl_with_filter": float(pnl_filtered) if pnl_values is not None else None,
            "pnl_improvement": float(delta) if pnl_values is not None else None,
        },
        "timing": {
            "xgb_seconds": float(xgb_time),
            "lgb_seconds": float(lgb_time),
        },
    }

    RESULTS_FILE.parent.mkdir(parents=True, exist_ok=True)
    RESULTS_FILE.write_text(json.dumps(output, indent=2))
    print(f"\n  💾 Results: {RESULTS_FILE}")

    # Walk-forward if --full (results file is rewritten with the extra key).
    if args.full:
        wf_results = walk_forward_backtest(dataset, best_lgb_params, best_xgb_params, device, args.folds)
        output["walk_forward"] = wf_results
        RESULTS_FILE.write_text(json.dumps(output, indent=2, default=str))

    # Export params for server deployment
    deploy_params = {
        "lgb_params": best_lgb_params,
        "xgb_params": {k: v for k, v in best_xgb_params.items() if k != "num_boost_round"},
        "xgb_num_boost_round": best_xgb_params.get("num_boost_round", 500),
        "threshold": float(threshold),
        "optimized_at": datetime.now(timezone.utc).isoformat(),
        "ensemble_auc": float(ensemble_auc),
    }
    deploy_path = MODELS_DIR / "optimized_params.json"
    MODELS_DIR.mkdir(parents=True, exist_ok=True)
    deploy_path.write_text(json.dumps(deploy_params, indent=2))
    print(f"  📦 Deploy params: {deploy_path}")
    print(f"\n  ✅ Deploy to server:")
    print(f"     scp {deploy_path} user@server:~/crypto_trading_bot/spy_optimizer/models/")


if __name__ == "__main__":
    main()
