#!/usr/bin/env python3
"""
SPY Optimizer — Full GPU Training Pipeline
============================================
Script conçu pour le PC avec GPU (RTX 5060 Ti 16GB).

Exécute 3 tâches en exploitant pleinement le GPU :
  1. Optuna XGBoost GPU (500 trials) — tree_method='hist', device='cuda'
  2. Optuna LightGBM CPU (500 trials, 4 jobs parallèles)
  3. LSTM+Attention GPU training (PyTorch CUDA)

Usage:
    python train_gpu_full.py                    # Tout (500 trials + LSTM 100 epochs)
    python train_gpu_full.py --trials 1000      # Plus de trials Optuna
    python train_gpu_full.py --lstm-only        # LSTM seulement
    python train_gpu_full.py --tabular-only     # XGB+LGB seulement
    python train_gpu_full.py --epochs 200       # Plus d'epochs LSTM

Inputs requis dans le même dossier :
    - combined_training_dataset.parquet  (15 MB — features tabulaires)
    - lstm_sequences.npz                (81 MB — séquences temporelles)

Outputs :
    - models/optimized_params.json         → params pour le serveur
    - models/surge_predictor_gpu.pt        → modèle LSTM TorchScript
    - models/surge_predictor_gpu_checkpoint.pth → checkpoint complet
    - models/surge_predictor_gpu_meta.json → métadonnées LSTM
    - gpu_optimization_results.json        → résultats complets
"""
import argparse
import json
import os
import sys
import time
import warnings
from datetime import datetime, timezone
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import lightgbm as lgb
import xgboost as xgb
import optuna
from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score
from sklearn.model_selection import TimeSeriesSplit

# Keep Optuna's own logger quiet: only WARNING-level messages and above.
optuna.logging.set_verbosity(optuna.logging.WARNING)
# Ignore every Python warning for the rest of the process; this also hides
# warnings emitted from inside third-party libraries.
warnings.simplefilter("ignore")

# ─── Paths ───
# All paths are resolved relative to this script's directory, not the CWD.
SCRIPT_DIR = Path(__file__).parent
MODELS_DIR = SCRIPT_DIR.joinpath("models")
# Created at import time so every export step can write into it directly.
MODELS_DIR.mkdir(exist_ok=True, parents=True)


def find_data_file(name):
    """Locate a data file next to this script or in its ``data/`` subfolder.

    Args:
        name: file name to look for, e.g. ``"lstm_sequences.npz"``.

    Returns:
        The first existing path among ``SCRIPT_DIR / name`` and
        ``SCRIPT_DIR / "data" / name`` (checked in that order), or ``None``
        when neither exists.
    """
    candidates = (SCRIPT_DIR / name, SCRIPT_DIR / "data" / name)
    return next((candidate for candidate in candidates if candidate.exists()), None)

# ─── Feature columns ───
# Must match feature_engineering.py on the server. prepare_features() turns
# these columns into a positional NumPy matrix, so both names AND order matter.
FEATURE_COLUMNS = [
    # return_*
    "return_1m", "return_3m", "return_5m", "return_10m",
    "return_15m", "return_30m", "return_60m",
    # volatility_*
    "volatility_5m", "volatility_15m", "volatility_30m",
    # price / EMA columns
    "price_vs_ema7", "price_vs_ema21", "ema7_vs_ema21",
    "ema7_slope_3m", "ema7_slope_5m", "ema21_slope_5m",
    "ema_gap_current", "ema_converging",
    # rsi_*
    "rsi_14", "rsi_delta_3m", "rsi_oversold", "rsi_overbought",
    # bb_*
    "bb_width", "bb_position", "bb_squeeze",
    # volume_* and buy_pressure_*
    "volume_ratio_1m", "volume_ratio_3m", "volume_spike", "volume_trend_5m",
    "buy_pressure_5m", "buy_pressure_1m",
    # candle columns
    "green_candles_5m", "green_candles_3m", "avg_body_ratio_5m", "upper_wick_ratio",
    # momentum columns
    "momentum_3m", "momentum_5m", "momentum_10m", "monotonic_up_5m",
    # price_position_* and breakout
    "price_position_30m", "price_position_60m", "breakout_pct",
    # trade-count columns
    "avg_trades_per_min_5m", "trade_intensity_ratio",
    # hour / session flags
    "hour_utc", "is_asia_session", "is_europe_session", "is_us_session",
    # surge descriptor
    "surge_strength",
]

# The integer code of a surge type is its index in this list; prepare_features()
# maps any other value to len(SURGE_TYPES).
SURGE_TYPES = ["FLASH_SURGE", "BREAKOUT_SURGE", "MOMENTUM_SURGE"]


# ═══════════════════════════════════════════════════════════════
#  GPU Detection
# ═══════════════════════════════════════════════════════════════

def detect_gpu():
    """Report which accelerator PyTorch can use.

    When CUDA is available, prints the device-0 name, its total memory in GB,
    the CUDA version and the PyTorch version; otherwise prints a CPU fallback
    notice.

    Returns:
        tuple[str, bool]: ``("cuda", True)`` if ``torch.cuda.is_available()``,
        else ``("cpu", False)``.
    """
    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"  🎮 GPU: {gpu_name} ({vram_gb:.1f} GB VRAM)")
        print(f"  🔧 CUDA: {torch.version.cuda}")
        print(f"  🔧 PyTorch: {torch.__version__}")
        return "cuda", True

    print("  ⚠️  CUDA non disponible — utilisation CPU")
    return "cpu", False


# ═══════════════════════════════════════════════════════════════
#  PART 1: Tabular Data Preparation
# ═══════════════════════════════════════════════════════════════

def load_tabular_dataset():
    """Load the flat (tabular) training dataset.

    The file ``combined_training_dataset.parquet`` is located via
    ``find_data_file``.

    Returns:
        pandas.DataFrame: the parquet contents.

    Terminates the process with exit status 1, after printing an error, when
    the file cannot be found.
    """
    parquet_path = find_data_file("combined_training_dataset.parquet")
    if parquet_path is not None:
        frame = pd.read_parquet(parquet_path)
        print(f"  📂 Dataset tabulaire: {len(frame)} samples ({parquet_path})")
        return frame

    print("  ❌ combined_training_dataset.parquet introuvable")
    sys.exit(1)


def prepare_features(dataset):
    """Build the feature matrix and labels for the tabular models.

    Rows are ordered by ``entry_time`` with a *stable* sort, so rows sharing a
    timestamp keep their incoming relative order. This makes the output
    deterministic. It also means a frame already sorted by ``entry_time``
    comes back in exactly the same row order, which matters to callers that
    keep other per-row arrays (e.g. PnL) aligned with ``X``/``y`` by position.

    Args:
        dataset: DataFrame with ``entry_time``, ``target_profitable``, any
            subset of ``FEATURE_COLUMNS`` and optionally ``surge_type``.

    Returns:
        tuple: ``(X, y, names)``.
          - ``X`` is a NumPy matrix of the available ``FEATURE_COLUMNS`` (in
            that order), plus ``surge_type_encoded`` when ``surge_type``
            exists, with NaNs replaced by 0.
          - ``y`` is the ``target_profitable`` array.
          - ``names`` are the column names of ``X``.
    """
    # mergesort is stable; the default quicksort is not, so it does not
    # guarantee the relative order of rows with equal entry_time.
    dataset = dataset.sort_values("entry_time", kind="mergesort").reset_index(drop=True)
    y = dataset["target_profitable"].values

    # Feature columns absent from the dataset are skipped silently.
    feature_cols = [c for c in FEATURE_COLUMNS if c in dataset.columns]
    X = dataset[feature_cols].copy()

    has_surge_type = "surge_type" in dataset.columns
    if has_surge_type:
        # Known surge types -> index in SURGE_TYPES; anything else -> len(SURGE_TYPES).
        surge_map = {s: i for i, s in enumerate(SURGE_TYPES)}
        X["surge_type_encoded"] = dataset["surge_type"].map(surge_map).fillna(len(SURGE_TYPES))

    X = X.fillna(0).values
    names = feature_cols + (["surge_type_encoded"] if has_surge_type else [])
    return X, y, names


# ═══════════════════════════════════════════════════════════════
#  PART 2: XGBoost GPU Optimization (Optuna)
# ═══════════════════════════════════════════════════════════════

def optimize_xgb_gpu(X_train, y_train, X_val, y_val, n_trials, device):
    """Tune XGBoost hyper-parameters with Optuna on a fixed train/validation split.

    Each trial trains with early stopping (30 rounds, AUC on the validation
    set). It is scored as ``0.6 * AUC + 0.4 * compute_pnl_score`` on that same
    validation set. Trials run sequentially (``n_jobs=1``).

    Args:
        X_train, y_train: training features and binary labels.
        X_val, y_val: validation features and binary labels.
        n_trials: number of Optuna trials.
        device: XGBoost ``device`` string (e.g. ``"cuda"`` or ``"cpu"``).

    Returns:
        tuple: ``(best_params, best_value)`` of the study. ``best_params``
        contains ``num_boost_round``. When ``grow_policy="lossguide"`` it also
        contains ``max_leaves``.
    """
    pos_weight = (len(y_train) - y_train.sum()) / max(y_train.sum(), 1)
    # roc_auc_score raises on a single-class target; fall back to 0.5 like the
    # AUC computations in train_lstm / walk_forward_backtest.
    val_has_both_classes = len(np.unique(y_val)) > 1

    # The data is identical for every trial: build the DMatrix objects once
    # instead of re-converting the same arrays on each of the n_trials trials.
    dtrain = xgb.DMatrix(X_train, label=y_train)
    dval = xgb.DMatrix(X_val, label=y_val)

    def objective(trial):
        params = {
            "objective": "binary:logistic",
            "eval_metric": "auc",
            "device": device,
            "tree_method": "hist",
            "scale_pos_weight": pos_weight,
            "verbosity": 0,
            "learning_rate": trial.suggest_float("learning_rate", 0.005, 0.3, log=True),
            "max_depth": trial.suggest_int("max_depth", 2, 12),
            "min_child_weight": trial.suggest_int("min_child_weight", 1, 50),
            "subsample": trial.suggest_float("subsample", 0.4, 1.0),
            "colsample_bytree": trial.suggest_float("colsample_bytree", 0.3, 1.0),
            "colsample_bylevel": trial.suggest_float("colsample_bylevel", 0.3, 1.0),
            "reg_alpha": trial.suggest_float("reg_alpha", 1e-8, 100.0, log=True),
            "reg_lambda": trial.suggest_float("reg_lambda", 1e-8, 100.0, log=True),
            "gamma": trial.suggest_float("gamma", 0.0, 10.0),
            "max_delta_step": trial.suggest_float("max_delta_step", 0.0, 10.0),
            "grow_policy": trial.suggest_categorical("grow_policy", ["depthwise", "lossguide"]),
        }
        if params["grow_policy"] == "lossguide":
            params["max_leaves"] = trial.suggest_int("max_leaves", 15, 255)

        num_boost_round = trial.suggest_int("num_boost_round", 50, 1000)

        model = xgb.train(
            params, dtrain,
            num_boost_round=num_boost_round,
            evals=[(dval, "val")],
            early_stopping_rounds=30,
            verbose_eval=False,
        )
        # xgb.train returns the booster from the last round, not the best one.
        # Restrict prediction to the trees up to best_iteration so the trial is
        # scored on the model that early stopping actually selected.
        # (0, 0) means "all trees".
        best_it = getattr(model, "best_iteration", None)
        iteration_range = (0, best_it + 1) if best_it is not None else (0, 0)
        pred = model.predict(dval, iteration_range=iteration_range)
        auc = roc_auc_score(y_val, pred) if val_has_both_classes else 0.5
        pnl_score = compute_pnl_score(pred, y_val)
        return auc * 0.6 + pnl_score * 0.4

    study = optuna.create_study(direction="maximize", study_name="xgb_gpu")
    study.optimize(objective, n_trials=n_trials, show_progress_bar=True, n_jobs=1)

    print(f"\n  🏆 XGBoost Best: {study.best_value:.4f} (trial #{study.best_trial.number})")
    return study.best_params, study.best_value


def optimize_lgb(X_train, y_train, X_val, y_val, n_trials):
    """Tune LightGBM hyper-parameters with Optuna on a fixed train/validation split.

    Runs 4 trials concurrently (Optuna ``n_jobs=4``). Each trial trains for up
    to 1000 rounds with early stopping (30 rounds, validation AUC). It is
    scored as ``0.6 * AUC + 0.4 * compute_pnl_score`` on the validation set.

    NOTE(review): every concurrent ``lgb.train`` call uses LightGBM's default
    thread count, so 4 parallel trials may oversubscribe the CPU. Consider
    setting ``num_threads`` explicitly.

    Args:
        X_train, y_train: training features and binary labels.
        X_val, y_val: validation features and binary labels.
        n_trials: number of Optuna trials.

    Returns:
        tuple: ``(best_params, best_value)`` of the study.
    """
    n_positive = y_train.sum()
    pos_weight = (len(y_train) - n_positive) / max(n_positive, 1)

    # Settings shared by every trial; only read, never mutated.
    base_params = {
        "objective": "binary",
        "metric": "auc",
        "verbosity": -1,
        "scale_pos_weight": pos_weight,
    }

    def objective(trial):
        # Suggestions are requested in a fixed order to define the search space.
        search = {
            "learning_rate": trial.suggest_float("learning_rate", 0.005, 0.3, log=True),
            "num_leaves": trial.suggest_int("num_leaves", 7, 255),
            "max_depth": trial.suggest_int("max_depth", 2, 15),
            "min_child_samples": trial.suggest_int("min_child_samples", 3, 60),
            "subsample": trial.suggest_float("subsample", 0.4, 1.0),
            "subsample_freq": trial.suggest_int("subsample_freq", 1, 10),
            "colsample_bytree": trial.suggest_float("colsample_bytree", 0.3, 1.0),
            "reg_alpha": trial.suggest_float("reg_alpha", 1e-8, 100.0, log=True),
            "reg_lambda": trial.suggest_float("reg_lambda", 1e-8, 100.0, log=True),
            "min_split_gain": trial.suggest_float("min_split_gain", 0.0, 2.0),
            "path_smooth": trial.suggest_float("path_smooth", 0.0, 10.0),
        }
        params = {**base_params, **search}

        dtrain = lgb.Dataset(X_train, y_train)
        dvalid = lgb.Dataset(X_val, y_val, reference=dtrain)
        booster = lgb.train(
            params,
            dtrain,
            num_boost_round=1000,
            valid_sets=[dvalid],
            callbacks=[lgb.early_stopping(30), lgb.log_evaluation(period=0)],
        )
        proba = booster.predict(X_val)
        return 0.6 * roc_auc_score(y_val, proba) + 0.4 * compute_pnl_score(proba, y_val)

    study = optuna.create_study(direction="maximize", study_name="lgb_opt")
    study.optimize(objective, n_trials=n_trials, show_progress_bar=True, n_jobs=4)

    print(f"\n  🏆 LightGBM Best: {study.best_value:.4f} (trial #{study.best_trial.number})")
    return study.best_params, study.best_value


def compute_pnl_score(pred_proba, y_true):
    """Score how well the probabilities filter out bad trades.

    Sweeps the thresholds ``np.arange(0.30, 0.75, 0.02)`` (0.30 … 0.74). Only
    thresholds that let at least ``max(3, 10% of samples)`` predictions
    through are eligible. For each eligible threshold the score is
    ``0.65 * precision + 0.35 * recall`` of the passed set, and the best one
    is returned. Despite the name, no PnL values are used.

    Args:
        pred_proba: NumPy array of predicted probabilities.
        y_true: NumPy array of binary labels aligned with ``pred_proba``.

    Returns:
        The best score, or 0 when no threshold is eligible.
    """
    positives = y_true == 1
    negatives = y_true == 0
    min_kept = max(3, len(pred_proba) * 0.10)

    candidates = [0]
    for cut in np.arange(0.30, 0.75, 0.02):
        kept = pred_proba >= cut
        if kept.sum() < min_kept:
            continue
        tp = (kept & positives).sum()
        fp = (kept & negatives).sum()
        fn = (~kept & positives).sum()
        precision = tp / max(tp + fp, 1)
        recall = tp / max(tp + fn, 1)
        candidates.append(precision * 0.65 + recall * 0.35)
    return max(candidates)


def find_optimal_threshold(pred_proba, y_true, pnl_values=None):
    """Pick the probability cut-off that maximises a trade-filtering score.

    Candidate thresholds are ``np.arange(0.25, 0.80, 0.01)``. A threshold is
    eligible only if at least ``max(3, 10% of samples)`` predictions pass it.

    Scoring of an eligible threshold:
      * with ``pnl_values``: ``sum(pnl of passed) - sum(pnl of all) + 10 * precision``,
        i.e. the PnL change caused by filtering plus a precision bonus;
      * without: ``0.65 * precision + 0.35 * recall``.

    On ties the lowest threshold wins.

    Args:
        pred_proba: NumPy array of predicted probabilities.
        y_true: NumPy array of binary labels aligned with ``pred_proba``.
        pnl_values: optional NumPy array of per-sample PnL aligned with
            ``pred_proba``.

    Returns:
        The best threshold, or 0.5 when no threshold is eligible.
        NOTE(review): if ``pnl_values`` contains NaN, every score is NaN and
        the 0.5 fallback is returned.
    """
    best_threshold = 0.5
    # -inf rather than a finite floor: the PnL-based score is unbounded below
    # (it subtracts the PnL of every rejected trade). With a floor such as -999,
    # every threshold could "lose" and the function would silently return 0.5.
    best_score = -np.inf

    min_passed = max(3, len(pred_proba) * 0.10)
    is_pos = y_true == 1
    is_neg = y_true == 0
    pnl_all = pnl_values.sum() if pnl_values is not None else None

    for thresh in np.arange(0.25, 0.80, 0.01):
        passed = pred_proba >= thresh
        if passed.sum() < min_passed:
            continue

        tp = (passed & is_pos).sum()
        fp = (passed & is_neg).sum()
        precision = tp / max(tp + fp, 1)

        if pnl_values is not None:
            score = pnl_values[passed].sum() - pnl_all + precision * 10
        else:
            fn = (~passed & is_pos).sum()
            recall = tp / max(fn + tp, 1)
            score = precision * 0.65 + recall * 0.35

        if score > best_score:
            best_score = score
            best_threshold = thresh

    return best_threshold


# ═══════════════════════════════════════════════════════════════
#  PART 3: LSTM+Attention GPU Training (PyTorch)
# ═══════════════════════════════════════════════════════════════

def load_lstm_data():
    """Load the pre-computed LSTM sequences from ``lstm_sequences.npz``.

    The archive is located via ``find_data_file`` and must contain the arrays
    ``"X"``, ``"surge_types"``, ``"y"`` and ``"pnl"``.

    Returns:
        tuple | None: ``(X, surge_types, y, pnl)``, or ``None`` (after printing
        an error) when the file cannot be found. ``X`` must be 3-D, because its
        first three dimensions are printed.
    """
    path = find_data_file("lstm_sequences.npz")
    if path is None:
        print("  ❌ lstm_sequences.npz introuvable")
        return None
    # np.load on an .npz returns an NpzFile that keeps the file handle open.
    # Use it as a context manager so the handle is released. Indexing reads each
    # array fully into memory, so the arrays stay valid after the archive closes.
    with np.load(path) as data:
        X = data["X"]                # presumably (N, 60, 14) — confirm against the producer
        surge = data["surge_types"]  # one surge-type code per sample
        y = data["y"]                # one target per sample
        pnl = data["pnl"]            # one PnL value per sample
    print(f"  📂 Séquences LSTM: {X.shape[0]} × {X.shape[1]} × {X.shape[2]} ({path})")
    return X, surge, y, pnl


class TemporalAttention(torch.nn.Module):
    """Attention pooling over the time axis.

    A small MLP (Linear -> Tanh -> Linear) produces one score per timestep.
    The scores are softmax-normalised along dim 1, and the timestep vectors
    are combined with those weights.
    """

    def __init__(self, hidden_size):
        super().__init__()
        half = hidden_size // 2
        # Keep this exact Sequential layout: its state_dict keys are
        # attention.0.* and attention.2.*, which saved checkpoints contain.
        self.attention = torch.nn.Sequential(
            torch.nn.Linear(hidden_size, half),
            torch.nn.Tanh(),
            torch.nn.Linear(half, 1),
        )

    def forward(self, lstm_output):
        """Pool ``lstm_output`` of shape (batch, time, hidden_size).

        Returns:
            tuple: ``(context, weights)``.
              - ``context`` has shape (batch, hidden_size).
              - ``weights`` has shape (batch, time); each row sums to 1.
        """
        weights = torch.softmax(self.attention(lstm_output).squeeze(-1), dim=1)
        pooled = torch.bmm(weights.unsqueeze(1), lstm_output)
        return pooled.squeeze(1), weights


class SurgePredictor(torch.nn.Module):
    """Bidirectional LSTM + attention classifier predicting whether a surge is profitable.

    Data flow in ``forward``:
      1. per-timestep input projection;
      2. stacked bidirectional LSTM;
      3. ``TemporalAttention`` pooling over time;
      4. concatenation with a learned surge-type embedding;
      5. MLP head emitting one logit per sample.

    NOTE: sub-module names and construction order define the ``state_dict``
    keys stored in checkpoints and the order in which ``_init_weights``
    re-initialises parameters; keep them stable.
    """

    def __init__(self, n_features, hidden_size=128, num_layers=2, dropout=0.3,
                 n_surge_types=3, surge_embed_dim=8):
        """Build the network.

        Args:
            n_features: number of input features per timestep (last dim of ``x``).
            hidden_size: width of the input projection, of each LSTM direction,
                and of the first classifier layer.
            num_layers: number of stacked LSTM layers.
            dropout: base dropout rate.
                - Full rate after the first classifier block, and between LSTM
                  layers (only when ``num_layers > 1``).
                - Half rate after the input projection and after the second
                  classifier block.
            n_surge_types: number of known surge types. The embedding table has
                ``n_surge_types + 1`` rows. The extra index is presumably the
                "unknown" code (prepare_features maps unknown types to
                ``len(SURGE_TYPES)``) — confirm for the LSTM sequences.
            surge_embed_dim: size of the surge-type embedding vector.
        """
        super().__init__()
        self.n_features = n_features
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Per-timestep projection n_features -> hidden_size (Linear acts on the last dim).
        self.input_proj = torch.nn.Sequential(
            torch.nn.Linear(n_features, hidden_size),
            torch.nn.LayerNorm(hidden_size),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout * 0.5),
        )

        # batch_first=True: consumes (batch, time, hidden_size). Being
        # bidirectional, it emits (batch, time, 2 * hidden_size).
        self.lstm = torch.nn.LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout if num_layers > 1 else 0,
        )

        # Attention runs over the concatenated forward/backward LSTM states.
        self.attention = TemporalAttention(hidden_size * 2)
        self.surge_embedding = torch.nn.Embedding(n_surge_types + 1, surge_embed_dim)

        # Head input = pooled LSTM context (2 * hidden_size) + surge embedding.
        classifier_input = hidden_size * 2 + surge_embed_dim
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(classifier_input, hidden_size),
            torch.nn.LayerNorm(hidden_size),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_size, hidden_size // 2),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout * 0.5),
            torch.nn.Linear(hidden_size // 2, 1),
        )

        self._init_weights()

    def _init_weights(self):
        """Re-initialise parameters in ``named_parameters()`` order.

        - Parameters whose name contains ``"weight"`` and that have >= 2 dims
          (Linear and LSTM weight matrices, the embedding table) get
          Xavier-uniform init.
        - Remaining parameters whose name contains ``"bias"`` are zeroed.
        - 1-D weights (the LayerNorm scales) keep their constructor defaults.
        """
        for name, param in self.named_parameters():
            if "weight" in name and param.dim() >= 2:
                torch.nn.init.xavier_uniform_(param)
            elif "bias" in name:
                torch.nn.init.zeros_(param)

    def forward(self, x, surge_type, return_attention=False):
        """Score a batch of sequences.

        Args:
            x: float tensor of shape (batch, time, n_features).
            surge_type: long tensor of shape (batch,) holding surge-type indices.
            return_attention: when True, also return the attention weights.

        Returns:
            dict with:
              - ``"logits"``: shape (batch,);
              - ``"probability"``: ``sigmoid(logits)``, shape (batch,);
              - ``"attention_weights"``: shape (batch, time), only when
                ``return_attention`` is True.
        """
        projected = self.input_proj(x)
        lstm_out, _ = self.lstm(projected)
        context, attn_weights = self.attention(lstm_out)
        surge_emb = self.surge_embedding(surge_type)
        combined = torch.cat([context, surge_emb], dim=1)
        logits = self.classifier(combined).squeeze(-1)
        prob = torch.sigmoid(logits)
        result = {"logits": logits, "probability": prob}
        if return_attention:
            result["attention_weights"] = attn_weights
        return result


class SurgePredictorExportWrapper(torch.nn.Module):
    """Adapter giving ``torch.jit.trace`` a tensor-in / tensor-out ``forward``.

    Wraps a model whose ``forward(x, surge_type, return_attention=...)`` returns
    a dict, and exposes only that dict's ``"probability"`` entry.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x, surge_type):
        """Return ``model(x, surge_type, return_attention=False)["probability"]``."""
        outputs = self.model(x, surge_type, return_attention=False)
        probability = outputs["probability"]
        return probability


def train_lstm(X, surge_types, y, pnl, device, epochs=100, batch_size=512,
               hidden_size=128, num_layers=2, lr=1e-3, patience=15):
    """
    Train the LSTM+Attention surge model on ``device`` and export it for CPU inference.

    Split: the first 75% of samples (in array order) are used for training and
    the last 25% for validation. This is a walk-forward split that assumes the
    arrays are already in chronological order (NOTE(review): confirm for
    lstm_sequences.npz).

    Training setup:
      - class-weighted ``BCEWithLogitsLoss`` (pos_weight = n_neg / n_pos);
      - AdamW (weight_decay=1e-4);
      - ``CosineAnnealingWarmRestarts(T_0=20, T_mult=2)`` stepped once per epoch;
      - gradient-norm clipping at 1.0;
      - early stopping after ``patience`` epochs without a validation-AUC
        improvement.
    The best-AUC weights are restored before the final evaluation and export.

    Args:
        X: sequence array indexed as (sample, timestep, feature).
        surge_types: integer surge-type code per sample.
        y: binary target per sample.
        pnl: per-sample PnL (printed as %), used for threshold selection and
            the PnL report; may be ``None``.
        device: torch device used for training.
        epochs, batch_size, hidden_size, num_layers, lr, patience: training
            hyper-parameters (see above).

    Returns:
        dict with ``auc``, ``precision``, ``recall``, ``threshold``,
        ``best_epoch``, ``pnl_improvement`` and the per-epoch ``history``.

    Side effects:
        Writes the following files into ``MODELS_DIR``:
          - ``surge_predictor_gpu.pt`` (TorchScript);
          - ``surge_predictor_gpu_checkpoint.pth``;
          - ``surge_predictor_gpu_meta.json``.

    NOTE(review): the decision threshold is tuned on the validation split, and
    precision/recall/PnL are reported on that same split, so those figures are
    optimistic.
    """
    import torch.nn as nn
    from torch.utils.data import TensorDataset, DataLoader

    n_features = X.shape[2]
    split_idx = int(len(X) * 0.75)

    X_train, X_val = X[:split_idx], X[split_idx:]
    s_train, s_val = surge_types[:split_idx], surge_types[split_idx:]
    y_train, y_val = y[:split_idx], y[split_idx:]
    # pnl is optional: find_optimal_threshold and the PnL report accept None.
    pnl_val = pnl[split_idx:] if pnl is not None else None
    # roc_auc_score is undefined for a single-class target; report 0.5 instead.
    val_has_both_classes = len(np.unique(y_val)) > 1

    print(f"  Train: {len(X_train)} | Val: {len(X_val)}")
    print(f"  Train positives: {y_train.sum():.0f}/{len(y_train)} ({y_train.mean()*100:.1f}%)")
    print(f"  Val positives: {y_val.sum():.0f}/{len(y_val)} ({y_val.mean()*100:.1f}%)")

    # Tensors: both splits are moved to `device` once, up-front.
    X_train_t = torch.FloatTensor(X_train).to(device)
    s_train_t = torch.LongTensor(s_train).to(device)
    y_train_t = torch.FloatTensor(y_train).to(device)

    X_val_t = torch.FloatTensor(X_val).to(device)
    s_val_t = torch.LongTensor(s_val).to(device)
    y_val_t = torch.FloatTensor(y_val).to(device)

    train_ds = TensorDataset(X_train_t, s_train_t, y_train_t)
    # Drop the ragged last batch only when at least one full batch exists.
    # With drop_last=True and fewer samples than batch_size, the loader would
    # yield nothing and the model would silently never be trained.
    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        drop_last=len(train_ds) > batch_size,
    )

    # Model
    model = SurgePredictor(
        n_features=n_features,
        hidden_size=hidden_size,
        num_layers=num_layers,
        dropout=0.3,
    ).to(device)

    total_params = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"  Model: {total_params:,} params ({trainable:,} trainable)")

    # Class-weighted loss: pos_weight = n_negative / n_positive. It is created as
    # a float32 tensor on `device` so it matches the logits' dtype and device,
    # whatever dtype the NumPy labels have.
    n_pos = float(y_train.sum())
    pos_weight = torch.tensor(
        [(len(y_train) - n_pos) / max(n_pos, 1.0)], dtype=torch.float32, device=device,
    )
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=20, T_mult=2)

    # Training loop
    best_auc = 0
    best_epoch = 0
    best_state = None
    history = []

    print(f"\n  {'Epoch':>5} | {'Train Loss':>10} | {'Val Loss':>8} | {'Val AUC':>7} | {'Val Prec':>8} | {'LR':>10} | {'Status'}")
    print(f"  {'─'*80}")

    for epoch in range(1, epochs + 1):
        # ── Train ──
        model.train()
        train_loss = 0
        n_batches = 0
        for X_b, s_b, y_b in train_loader:
            optimizer.zero_grad()
            out = model(X_b, s_b)
            loss = criterion(out["logits"], y_b)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            train_loss += loss.item()
            n_batches += 1

        scheduler.step()
        train_loss /= max(n_batches, 1)

        # ── Validate (whole validation split in one forward pass) ──
        model.eval()
        with torch.no_grad():
            val_out = model(X_val_t, s_val_t)
            val_loss = criterion(val_out["logits"], y_val_t).item()
            val_probs = val_out["probability"].cpu().numpy()

        auc = roc_auc_score(y_val, val_probs) if val_has_both_classes else 0.5

        # Precision at the threshold chosen on this same validation split
        threshold = find_optimal_threshold(val_probs, y_val, pnl_val)
        passed = val_probs >= threshold
        prec = precision_score(y_val, passed.astype(int), zero_division=0)

        lr_current = optimizer.param_groups[0]["lr"]

        status = ""
        if auc > best_auc:
            best_auc = auc
            best_epoch = epoch
            # Snapshot on CPU so later optimizer steps cannot modify it.
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            status = f"★ best (AUC {auc:.4f})"

        history.append({
            "epoch": epoch, "train_loss": train_loss, "val_loss": val_loss,
            "val_auc": float(auc), "val_prec": float(prec), "threshold": float(threshold),
        })

        if epoch % 5 == 0 or epoch == 1 or status:
            print(f"  {epoch:5d} | {train_loss:10.4f} | {val_loss:8.4f} | {auc:7.4f} | {prec:8.4f} | {lr_current:10.6f} | {status}")

        # Early stopping
        if epoch - best_epoch >= patience:
            print(f"\n  ⏹️  Early stopping at epoch {epoch} (best was {best_epoch})")
            break

    # Restore the best-AUC weights
    if best_state is not None:
        model.load_state_dict(best_state)
    model.eval()

    # Final evaluation
    with torch.no_grad():
        val_out = model(X_val_t, s_val_t)
        val_probs = val_out["probability"].cpu().numpy()

    # Guarded like the per-epoch AUC: an unguarded call would crash here, after
    # training and before the model is exported.
    final_auc = roc_auc_score(y_val, val_probs) if val_has_both_classes else 0.5
    threshold = find_optimal_threshold(val_probs, y_val, pnl_val)
    passed = val_probs >= threshold
    prec = precision_score(y_val, passed.astype(int), zero_division=0)
    recall = recall_score(y_val, passed.astype(int), zero_division=0)

    # PnL analysis: filtered PnL vs. taking every validation trade
    if pnl_val is not None:
        pnl_all = pnl_val.sum()
        pnl_filtered = pnl_val[passed].sum()
        delta_pnl = pnl_filtered - pnl_all
    else:
        pnl_all = pnl_filtered = delta_pnl = 0

    print(f"\n{'═'*60}")
    print(f"  📊 LSTM Final Results (Best epoch {best_epoch})")
    print(f"{'═'*60}")
    print(f"  AUC:            {final_auc:.4f}")
    print(f"  Precision:      {prec:.4f}")
    print(f"  Recall:         {recall:.4f}")
    print(f"  Threshold:      {threshold:.2f}")
    print(f"  Pass rate:      {passed.sum()}/{len(y_val)} ({passed.mean()*100:.0f}%)")
    print(f"  PnL sans filtre:  {pnl_all:+.2f}%")
    print(f"  PnL avec filtre:  {pnl_filtered:+.2f}%")
    print(f"  Δ PnL:            {delta_pnl:+.2f}%")

    # ── Export models ──
    # 1. TorchScript for CPU inference (dropout=0 copy of the trained weights)
    model_cpu = SurgePredictor(
        n_features=n_features, hidden_size=hidden_size,
        num_layers=num_layers, dropout=0,
    )
    model_cpu.load_state_dict(model.cpu().state_dict())
    model_cpu.eval()

    wrapper = SurgePredictorExportWrapper(model_cpu)
    dummy_x = torch.randn(1, X.shape[1], X.shape[2])
    dummy_s = torch.zeros(1, dtype=torch.long)
    scripted = torch.jit.trace(wrapper, (dummy_x, dummy_s))

    ts_path = MODELS_DIR / "surge_predictor_gpu.pt"
    scripted.save(str(ts_path))
    print(f"\n  💾 TorchScript: {ts_path} ({ts_path.stat().st_size / 1e6:.1f} MB)")

    # 2. Full checkpoint
    ckpt_path = MODELS_DIR / "surge_predictor_gpu_checkpoint.pth"
    torch.save({
        "model_state_dict": model_cpu.state_dict(),
        "n_features": n_features,
        "hidden_size": hidden_size,
        "num_layers": num_layers,
        "threshold": float(threshold),
        "seq_len": X.shape[1],
        "auc": float(final_auc),
        "best_epoch": best_epoch,
    }, str(ckpt_path))
    print(f"  💾 Checkpoint: {ckpt_path} ({ckpt_path.stat().st_size / 1e6:.1f} MB)")

    # 3. Metadata
    meta = {
        "trained_at": datetime.now(timezone.utc).isoformat(),
        "device": str(device),
        "n_features": n_features,
        "hidden_size": hidden_size,
        "num_layers": num_layers,
        "seq_len": int(X.shape[1]),
        "threshold": float(threshold),
        "auc": float(final_auc),
        "precision": float(prec),
        "recall": float(recall),
        "best_epoch": best_epoch,
        "total_epochs": len(history),
        "train_samples": len(X_train),
        "val_samples": len(X_val),
        "pnl_improvement": float(delta_pnl),
    }
    meta_path = MODELS_DIR / "surge_predictor_gpu_meta.json"
    meta_path.write_text(json.dumps(meta, indent=2))
    print(f"  💾 Metadata: {meta_path}")

    return {
        "auc": float(final_auc), "precision": float(prec), "recall": float(recall),
        "threshold": float(threshold), "best_epoch": best_epoch,
        "pnl_improvement": float(delta_pnl),
        "history": history,
    }


# ═══════════════════════════════════════════════════════════════
#  Walk-Forward Backtest (XGB + LGB)
# ═══════════════════════════════════════════════════════════════

def walk_forward_backtest(dataset, best_lgb_params, best_xgb_params, device, n_splits=5):
    """Walk-forward evaluation of the tuned LightGBM + XGBoost ensemble.

    Rows are ordered by ``entry_time`` and split with ``TimeSeriesSplit``.
    On every fold:
      1. both models are retrained on the training part with the given
         hyper-parameters (early stopping on the validation part);
      2. their probabilities are averaged 50/50;
      3. a threshold is chosen with ``find_optimal_threshold``;
      4. the PnL of the filtered trades is compared with taking every trade.

    Args:
        dataset: DataFrame accepted by ``prepare_features``. ``target_pnl_pct``
            is used for the PnL comparison when present.
        best_lgb_params: tuned LightGBM parameters (e.g. Optuna best_params).
        best_xgb_params: tuned XGBoost parameters. They may contain
            ``num_boost_round``, which is passed to ``xgb.train`` (default 500).
        device: XGBoost ``device`` string (e.g. ``"cuda"`` or ``"cpu"``).
        n_splits: number of walk-forward folds.

    Returns:
        list[dict]: one entry per fold with ``fold``, ``n_val``, ``n_passed``,
        ``auc``, ``precision``, ``threshold``, ``pnl_filtered``, ``pnl_all``
        and ``delta``.

    NOTE(review): each validation fold drives both early stopping and threshold
    selection, so the reported fold metrics are optimistic.
    """
    # Stable sort: rows sharing an entry_time keep a deterministic order.
    dataset = dataset.sort_values("entry_time", kind="mergesort").reset_index(drop=True)
    X, y, _feature_names = prepare_features(dataset)
    # NOTE(review): pnl_values is matched to X / y by row position. This holds
    # only if prepare_features re-sorts this already-sorted frame without
    # reordering rows that share an entry_time (i.e. with a stable sort).
    pnl_values = dataset["target_pnl_pct"].values if "target_pnl_pct" in dataset.columns else None

    tscv = TimeSeriesSplit(n_splits=n_splits)
    results = []

    print(f"\n{'═'*60}")
    print(f"  📈 Walk-Forward Backtest ({n_splits} folds)")
    print(f"{'═'*60}")

    for fold, (train_idx, val_idx) in enumerate(tscv.split(X)):
        X_tr, X_va = X[train_idx], X[val_idx]
        y_tr, y_va = y[train_idx], y[val_idx]
        pnl_va = pnl_values[val_idx] if pnl_values is not None else None

        pos_weight = (len(y_tr) - y_tr.sum()) / max(y_tr.sum(), 1)

        # LightGBM
        lgb_params = {
            **best_lgb_params,
            "objective": "binary", "metric": "auc",
            "verbosity": -1, "scale_pos_weight": pos_weight,
        }
        train_set = lgb.Dataset(X_tr, y_tr)
        val_set = lgb.Dataset(X_va, y_va, reference=train_set)
        lgb_model = lgb.train(
            lgb_params, train_set, num_boost_round=1000,
            valid_sets=[val_set],
            callbacks=[lgb.early_stopping(30), lgb.log_evaluation(0)],
        )
        lgb_pred = lgb_model.predict(X_va)

        # XGBoost (GPU when device == "cuda")
        xgb_params = {
            **best_xgb_params,
            "objective": "binary:logistic", "eval_metric": "auc",
            "device": device, "tree_method": "hist",
            "verbosity": 0, "scale_pos_weight": pos_weight,
        }
        # num_boost_round is an xgb.train argument, not a booster parameter.
        xgb_params.pop("num_boost_round", None)
        dtrain = xgb.DMatrix(X_tr, label=y_tr)
        dval = xgb.DMatrix(X_va, label=y_va)
        xgb_model = xgb.train(
            xgb_params, dtrain,
            num_boost_round=best_xgb_params.get("num_boost_round", 500),
            evals=[(dval, "val")],
            early_stopping_rounds=30, verbose_eval=False,
        )
        # xgb.train returns the last-round booster. Predict with the trees up to
        # best_iteration so the fold is scored on the early-stopped model.
        # (0, 0) means "all trees".
        best_it = getattr(xgb_model, "best_iteration", None)
        xgb_pred = xgb_model.predict(
            dval, iteration_range=(0, best_it + 1) if best_it is not None else (0, 0),
        )

        # Equal-weight blend of both models' probabilities
        pred = 0.5 * lgb_pred + 0.5 * xgb_pred
        auc = roc_auc_score(y_va, pred) if len(np.unique(y_va)) > 1 else 0.5
        threshold = find_optimal_threshold(pred, y_va, pnl_va)
        passed = pred >= threshold
        n_passed = passed.sum()
        prec = precision_score(y_va, passed.astype(int), zero_division=0)

        if pnl_va is not None:
            pnl_all = pnl_va.sum()
            pnl_filtered = pnl_va[passed].sum()
            delta = pnl_filtered - pnl_all
        else:
            pnl_all = pnl_filtered = delta = 0

        status = "✅" if delta >= 0 else "❌"
        print(f"  Fold {fold+1} | {len(y_va)} trades → {n_passed} passés ({n_passed/len(y_va)*100:.0f}%) | "
              f"AUC: {auc:.3f} | Prec: {prec:.3f} | Thresh: {threshold:.2f} | "
              f"PnL: {pnl_filtered:+.1f}% vs {pnl_all:+.1f}% Δ={delta:+.1f}% {status}")

        results.append({
            "fold": fold + 1, "n_val": len(y_va), "n_passed": int(n_passed),
            "auc": float(auc), "precision": float(prec), "threshold": float(threshold),
            "pnl_filtered": float(pnl_filtered), "pnl_all": float(pnl_all), "delta": float(delta),
        })

    avg_auc = np.mean([r["auc"] for r in results])
    total_delta = sum(r["delta"] for r in results)
    positive_folds = sum(1 for r in results if r["delta"] >= 0)

    print(f"\n  📊 Summary: Avg AUC {avg_auc:.3f} | {positive_folds}/{len(results)} positive folds | Δ PnL: {total_delta:+.1f}%")
    return results


# ═══════════════════════════════════════════════════════════════
#  Main
# ═══════════════════════════════════════════════════════════════

def _rate(count, seconds):
    """Return ``count / seconds`` (e.g. trials per second), or 0.0 if ``seconds`` <= 0."""
    return count / seconds if seconds > 0 else 0.0


def _print_top_features(lgb_model, feature_names, top_k=10):
    """Print the ``top_k`` LightGBM features ranked by total split gain.

    Each line shows a bar scaled relative to the most important feature.
    """
    lgb_imp = lgb_model.feature_importance(importance_type="gain")
    print(f"\n  🔑 Top {top_k} Features:")
    sorted_idx = np.argsort(lgb_imp)[::-1]
    max_imp = float(lgb_imp[sorted_idx[0]]) if len(sorted_idx) else 0.0
    for rank, idx in enumerate(sorted_idx[:top_k]):
        name = feature_names[idx] if idx < len(feature_names) else f"feat_{idx}"
        # If no split produced any gain, every importance is 0: the ratio would
        # be NaN and int(NaN) raises ValueError, so draw an empty bar instead.
        bar = "█" * int(lgb_imp[idx] / max_imp * 20) if max_imp > 0 else ""
        print(f"    {rank+1:2d}. {name:30s} {bar}")


def _run_tabular_optimization(args, has_gpu, xgb_device):
    """Tune XGBoost + LightGBM with Optuna, fit the final 50/50 ensemble and
    write ``models/optimized_params.json`` for the server.

    Args:
        args: parsed CLI namespace; ``args.trials`` and ``args.folds`` are read.
        has_gpu: result of ``detect_gpu``; only used to label the XGBoost section.
        xgb_device: ``"cuda"`` or ``"cpu"``, forwarded to XGBoost.

    Returns:
        Summary dict stored under ``all_results["tabular"]``.
    """
    print(f"\n{'═'*60}")
    print(f"  PART 1: Tabular Hyperparameter Optimization")
    print(f"{'═'*60}")

    dataset = load_tabular_dataset()
    X, y, feature_names = prepare_features(dataset)

    # Unshuffled 75/25 split: the last quarter of the rows is the validation set.
    # NOTE(review): assumes prepare_features returns rows ordered by entry_time —
    # the PnL slice below relies on that alignment; confirm.
    split_idx = int(len(X) * 0.75)
    X_train, X_val = X[:split_idx], X[split_idx:]
    y_train, y_val = y[:split_idx], y[split_idx:]

    pos = y.sum()
    print(f"  📐 Dataset: {len(y)} samples ({pos:.0f} profitable, {len(y)-pos:.0f} perdants)")
    print(f"  Features: {len(feature_names)}")
    print(f"  Train: {len(y_train)} | Val: {len(y_val)}")

    # ── XGBoost GPU ──
    print(f"\n{'─'*60}")
    print(f"  🔍 XGBoost {'GPU' if has_gpu else 'CPU'} Optimization ({args.trials} trials)")
    print(f"{'─'*60}")

    t0 = time.time()
    best_xgb_params, _xgb_score = optimize_xgb_gpu(X_train, y_train, X_val, y_val, args.trials, xgb_device)
    xgb_time = time.time() - t0
    print(f"  ⏱️  XGBoost: {xgb_time:.0f}s ({_rate(args.trials, xgb_time):.1f} trials/s)")

    # ── LightGBM CPU ──
    print(f"\n{'─'*60}")
    print(f"  🔍 LightGBM CPU Optimization ({args.trials} trials, 4 jobs)")
    print(f"{'─'*60}")

    t0 = time.time()
    best_lgb_params, _lgb_score = optimize_lgb(X_train, y_train, X_val, y_val, args.trials)
    lgb_time = time.time() - t0
    print(f"  ⏱️  LightGBM: {lgb_time:.0f}s ({_rate(args.trials, lgb_time):.1f} trials/s)")

    # ── Train final ensemble ──
    print(f"\n{'─'*60}")
    print(f"  🏋️ Training final ensemble")
    print(f"{'─'*60}")

    # Negative/positive ratio, used to re-weight the positive class in both models.
    pos_weight = (len(y_train) - y_train.sum()) / max(y_train.sum(), 1)

    lgb_params = {**best_lgb_params, "objective": "binary", "metric": "auc",
                  "verbosity": -1, "scale_pos_weight": pos_weight}
    train_set = lgb.Dataset(X_train, y_train)
    val_set = lgb.Dataset(X_val, y_val, reference=train_set)
    lgb_model = lgb.train(lgb_params, train_set, num_boost_round=1000,
                          valid_sets=[val_set],
                          callbacks=[lgb.early_stopping(50), lgb.log_evaluation(0)])
    # LightGBM's Booster.predict already defaults to best_iteration after early stopping.
    lgb_pred = lgb_model.predict(X_val)

    xgb_params_final = {**best_xgb_params, "objective": "binary:logistic", "eval_metric": "auc",
                        "device": xgb_device, "tree_method": "hist",
                        "verbosity": 0, "scale_pos_weight": pos_weight}
    # num_boost_round is a train() argument, not a booster parameter.
    n_boost = xgb_params_final.pop("num_boost_round", 500)
    dtrain = xgb.DMatrix(X_train, label=y_train)
    dval = xgb.DMatrix(X_val, label=y_val)
    xgb_model = xgb.train(xgb_params_final, dtrain, num_boost_round=n_boost,
                          evals=[(dval, "val")], early_stopping_rounds=50, verbose_eval=False)
    # xgb.train returns the model from the LAST iteration and Booster.predict
    # uses every tree by default; restrict to the best iteration so the
    # post-optimum rounds don't leak into the ensemble, AUC and threshold.
    best_iteration = getattr(xgb_model, "best_iteration", None)
    if best_iteration is not None:
        xgb_pred = xgb_model.predict(dval, iteration_range=(0, best_iteration + 1))
    else:
        xgb_pred = xgb_model.predict(dval)

    ensemble_pred = 0.5 * lgb_pred + 0.5 * xgb_pred
    ensemble_auc = roc_auc_score(y_val, ensemble_pred) if len(np.unique(y_val)) > 1 else 0.5

    pnl_values = dataset.sort_values("entry_time")["target_pnl_pct"].values[split_idx:] \
        if "target_pnl_pct" in dataset.columns else None
    threshold = find_optimal_threshold(ensemble_pred, y_val, pnl_values)

    passed = ensemble_pred >= threshold
    prec = precision_score(y_val, passed.astype(int), zero_division=0)
    rec = recall_score(y_val, passed.astype(int), zero_division=0)

    if pnl_values is not None:
        pnl_all = pnl_values.sum()
        pnl_filtered = pnl_values[passed].sum()
        delta = pnl_filtered - pnl_all
    else:
        pnl_all = pnl_filtered = delta = 0

    print(f"\n  📊 Tabular Ensemble Results:")
    print(f"  AUC: {ensemble_auc:.4f} | Precision: {prec:.4f} | Recall: {rec:.4f}")
    print(f"  Threshold: {threshold:.2f} | Pass: {passed.sum()}/{len(y_val)} ({passed.mean()*100:.0f}%)")
    if pnl_values is not None:
        print(f"  PnL: {pnl_all:+.2f}% → {pnl_filtered:+.2f}% (Δ {delta:+.2f}%)")

    _print_top_features(lgb_model, feature_names)

    # Walk-forward
    wf_results = walk_forward_backtest(dataset, best_lgb_params, best_xgb_params, xgb_device, args.folds)

    # Save params for server
    deploy_params = {
        "lgb_params": best_lgb_params,
        "xgb_params": {k: v for k, v in best_xgb_params.items() if k != "num_boost_round"},
        "xgb_num_boost_round": best_xgb_params.get("num_boost_round", 500),
        "threshold": float(threshold),
        "optimized_at": datetime.now(timezone.utc).isoformat(),
        "ensemble_auc": float(ensemble_auc),
        "device_used": xgb_device,
        "n_trials": args.trials,
    }
    deploy_path = MODELS_DIR / "optimized_params.json"
    deploy_path.write_text(json.dumps(deploy_params, indent=2))
    print(f"\n  💾 Params: {deploy_path}")

    return {
        "xgb_params": best_xgb_params, "lgb_params": best_lgb_params,
        "ensemble_auc": float(ensemble_auc), "threshold": float(threshold),
        "precision": float(prec), "recall": float(rec),
        "pnl_improvement": float(delta),
        "xgb_time": xgb_time, "lgb_time": lgb_time,
        "walk_forward": wf_results,
    }


def _run_lstm_training(args, device):
    """Train the LSTM+Attention model on ``device``.

    Returns the dict produced by ``train_lstm``, or None when
    ``load_lstm_data`` finds no sequence data (the stage is then skipped).
    """
    print(f"\n{'═'*60}")
    print(f"  PART 2: LSTM+Attention GPU Training")
    print(f"{'═'*60}")

    lstm_data = load_lstm_data()
    if lstm_data is None:
        print("  ⚠️  Pas de données LSTM — skipping")
        return None

    X_seq, surge_types, y_seq, pnl_seq = lstm_data
    return train_lstm(
        X_seq, surge_types, y_seq, pnl_seq,
        device=device,
        epochs=args.epochs,
        batch_size=args.batch_size,
        hidden_size=args.hidden_size,
        num_layers=args.num_layers,
        lr=args.lr,
        patience=args.patience,
    )


def _print_final_summary(results_path, total_time, ran_tabular, ran_lstm):
    """Print output locations and ``scp`` deploy hints, listing only the
    artifacts that were actually produced during this run."""
    print(f"\n{'═'*60}")
    print(f"  ✅ DONE — Total time: {total_time/60:.1f} min")
    print(f"{'═'*60}")
    print(f"  📦 Results:   {results_path}")
    if ran_tabular:
        print(f"  📦 Params:    {MODELS_DIR / 'optimized_params.json'}")
    if ran_lstm:
        print(f"  📦 LSTM:      {MODELS_DIR / 'surge_predictor_gpu.pt'}")
        print(f"  📦 LSTM meta: {MODELS_DIR / 'surge_predictor_gpu_meta.json'}")
    if not (ran_tabular or ran_lstm):
        return
    print(f"\n  🚀 Pour déployer sur le serveur:")
    if ran_tabular:
        print(f"     scp models/optimized_params.json user@server:~/crypto_trading_bot/spy_optimizer/models/")
    if ran_lstm:
        print(f"     scp models/surge_predictor_gpu.pt user@server:~/crypto_trading_bot/spy_optimizer/models/")
        print(f"     scp models/surge_predictor_gpu_meta.json user@server:~/crypto_trading_bot/spy_optimizer/models/")
        print(f"     scp models/surge_predictor_gpu_checkpoint.pth user@server:~/crypto_trading_bot/spy_optimizer/models/")


def main():
    """Entry point: parse CLI flags, run the tabular and/or LSTM stages and
    save a JSON summary to ``gpu_optimization_results.json``."""
    parser = argparse.ArgumentParser(description="Full GPU Training Pipeline")
    parser.add_argument("--trials", type=int, default=500, help="Optuna trials per model")
    parser.add_argument("--epochs", type=int, default=100, help="LSTM epochs")
    parser.add_argument("--batch-size", type=int, default=512, help="LSTM batch size")
    parser.add_argument("--hidden-size", type=int, default=128, help="LSTM hidden size")
    parser.add_argument("--num-layers", type=int, default=2, help="LSTM layers")
    parser.add_argument("--lr", type=float, default=1e-3, help="LSTM learning rate")
    parser.add_argument("--patience", type=int, default=15, help="Early stopping patience")
    # With both "only" modes set nothing would run, so argparse rejects the combination.
    mode = parser.add_mutually_exclusive_group()
    mode.add_argument("--lstm-only", action="store_true", help="LSTM seulement")
    mode.add_argument("--tabular-only", action="store_true", help="XGB+LGB seulement")
    parser.add_argument("--folds", type=int, default=5, help="Walk-forward folds")
    args = parser.parse_args()

    print(f"\n{'═'*60}")
    print(f"  🚀 SPY Optimizer — Full GPU Training Pipeline")
    print(f"  📅 {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"{'═'*60}")

    device_str, has_gpu = detect_gpu()
    device = torch.device(device_str)
    xgb_device = "cuda" if has_gpu else "cpu"

    total_start = time.time()
    all_results = {"device": device_str, "has_gpu": has_gpu, "started_at": datetime.now(timezone.utc).isoformat()}

    # PART 1: tabular optimization (XGBoost GPU + LightGBM)
    if not args.lstm_only:
        all_results["tabular"] = _run_tabular_optimization(args, has_gpu, xgb_device)

    # PART 2: LSTM+Attention GPU training
    if not args.tabular_only:
        lstm_results = _run_lstm_training(args, device)
        if lstm_results is not None:
            all_results["lstm"] = lstm_results

    # Save full results
    total_time = time.time() - total_start
    all_results["total_time_seconds"] = total_time
    all_results["finished_at"] = datetime.now(timezone.utc).isoformat()

    results_path = SCRIPT_DIR / "gpu_optimization_results.json"
    # Drop the per-epoch LSTM history from the saved summary; any other
    # non-JSON values are stringified by default=str.
    results_save = dict(all_results)
    if "lstm" in results_save and "history" in results_save["lstm"]:
        results_save["lstm"] = {k: v for k, v in results_save["lstm"].items() if k != "history"}
    results_path.write_text(json.dumps(results_save, indent=2, default=str))

    _print_final_summary(
        results_path,
        total_time,
        ran_tabular="tabular" in all_results,
        ran_lstm="lstm" in all_results,
    )


if __name__ == "__main__":
    main()
