#!/usr/bin/env python3
"""
SPY Optimizer — GPU Training Script
Entraîne le modèle LSTM+Attention sur GPU (RTX 3060+).

Usage (sur le PC avec GPU):
  python train_gpu.py                          # Train complet
  python train_gpu.py --epochs 100             # Plus d'époques
  python train_gpu.py --seq-len 90             # Séquences de 90 minutes
  python train_gpu.py --export-only            # Re-exporter un modèle existant
  python train_gpu.py --backtest               # Walk-forward backtest

Le modèle entraîné est exporté en TorchScript (.pt) pour inference CPU sur le serveur.
"""
import argparse
import json
import os
import sys
import time
from datetime import datetime, timezone
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
)

from deep_model import (
    SurgePredictor, SurgePredictorExportWrapper,
    build_sequence_dataset, N_SEQUENCE_FEATURES, SURGE_TYPE_MAP,
)

# ─── Paths ───
# All paths resolve relative to this file so the script works from any CWD.
PROJECT_DIR = Path(__file__).parent
DATA_DIR = PROJECT_DIR / "data"
KLINES_DIR = DATA_DIR / "klines_1m"  # per-symbol 1-minute kline parquet files
MODELS_DIR = PROJECT_DIR / "models"  # exported model artifacts land here
# Check both locations: same dir (PC layout) or parent dir (server layout)
_hist_local = PROJECT_DIR / "espion_history.json"
_hist_parent = PROJECT_DIR.parent / "espion_history.json"
HISTORY_FILE = _hist_local if _hist_local.exists() else _hist_parent


# ─── Dataset ───

class SurgeDataset(Dataset):
    """PyTorch dataset of surge sequences.

    Each item is a triple ``(X, surge_type, y)``: the float32 feature
    sequence, the integer surge class, and the float binary target.
    """

    def __init__(self, X: np.ndarray, surge_types: np.ndarray, y: np.ndarray):
        # Materialize all tensors up front so __getitem__ is pure indexing.
        self.X = torch.tensor(X, dtype=torch.float32)
        self.surge_types = torch.tensor(surge_types, dtype=torch.long)
        self.y = torch.tensor(y, dtype=torch.float32)

    def __len__(self) -> int:
        return self.y.shape[0]

    def __getitem__(self, idx):
        item = (self.X[idx], self.surge_types[idx], self.y[idx])
        return item


# ─── Training ───

class Trainer:
    """Training harness with early stopping, cosine-restart LR scheduling,
    gradient clipping and (on CUDA) mixed-precision training."""

    def __init__(
        self,
        model: "SurgePredictor",
        device: torch.device,
        lr: float = 1e-3,
        weight_decay: float = 1e-4,
    ):
        """Move *model* to *device* and set up optimizer/scheduler/AMP state.

        Args:
            model: the surge predictor to train (forward must return a dict
                with "logits" and "probability" keys — see train_epoch).
            device: target device; AMP is enabled only for CUDA.
            lr: initial AdamW learning rate.
            weight_decay: AdamW weight decay.
        """
        self.model = model.to(device)
        self.device = device
        self.optimizer = torch.optim.AdamW(
            model.parameters(), lr=lr, weight_decay=weight_decay
        )
        # Warm restarts: first cycle is 10 epochs, each cycle doubles in length.
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer, T_0=10, T_mult=2
        )
        # GradScaler only applies to CUDA mixed precision; None means full fp32.
        self.scaler = torch.amp.GradScaler("cuda") if device.type == "cuda" else None
        self.best_val_auc = 0.0
        self.best_state = None
        self.patience_counter = 0

    def train_epoch(self, loader: DataLoader, pos_weight: float) -> dict:
        """Run one training epoch.

        Args:
            loader: yields (X, surge_type, y) batches.
            pos_weight: BCE positive-class weight (neg/pos ratio).

        Returns:
            {"loss": mean per-sample loss over samples actually seen,
             "auc": train ROC-AUC (0.5 if only one class appeared)}.
        """
        self.model.train()
        total_loss = 0.0
        n_seen = 0  # samples actually processed (drop_last may skip the tail batch)
        all_preds = []
        all_targets = []

        criterion = nn.BCEWithLogitsLoss(
            pos_weight=torch.tensor([pos_weight], device=self.device)
        )

        for X, surge, y in loader:
            X = X.to(self.device)
            surge = surge.to(self.device)
            y = y.to(self.device)

            self.optimizer.zero_grad()

            if self.scaler:
                with torch.amp.autocast("cuda"):
                    output = self.model(X, surge)
                    loss = criterion(output["logits"], y)
                self.scaler.scale(loss).backward()
                # Unscale before clipping so the norm is measured in true units.
                self.scaler.unscale_(self.optimizer)
                nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                output = self.model(X, surge)
                loss = criterion(output["logits"], y)
                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.optimizer.step()

            total_loss += loss.item() * len(y)
            n_seen += len(y)
            all_preds.extend(output["probability"].detach().cpu().numpy())
            all_targets.extend(y.cpu().numpy())

        self.scheduler.step()

        preds = np.array(all_preds)
        targets = np.array(all_targets)
        auc = roc_auc_score(targets, preds) if len(np.unique(targets)) > 1 else 0.5

        return {
            # BUGFIX: divide by samples actually seen, not len(loader.dataset) —
            # with drop_last=True the dropped tail batch never contributes to
            # total_loss, so the old denominator biased the reported loss low.
            "loss": total_loss / max(n_seen, 1),
            "auc": auc,
        }

    @torch.no_grad()
    def evaluate(self, loader: DataLoader) -> dict:
        """Evaluate on a validation loader.

        Sweeps decision thresholds in [0.30, 0.70) and keeps the one
        maximizing 0.7*precision + 0.3*recall (subject to a minimum number
        of passed trades), then reports metrics at that threshold.

        Returns a dict with "auc", "threshold", "accuracy", "precision",
        "recall", "f1", "pass_rate", plus raw "predictions"/"targets" arrays.
        """
        self.model.eval()
        all_preds = []
        all_targets = []

        for X, surge, y in loader:
            X = X.to(self.device)
            surge = surge.to(self.device)

            if self.scaler:
                with torch.amp.autocast("cuda"):
                    output = self.model(X, surge)
            else:
                output = self.model(X, surge)

            all_preds.extend(output["probability"].cpu().numpy())
            all_targets.extend(y.numpy())

        preds = np.array(all_preds)
        targets = np.array(all_targets)

        auc = roc_auc_score(targets, preds) if len(np.unique(targets)) > 1 else 0.5

        # Optimize threshold for max profit
        best_threshold = 0.5
        best_score = -999
        for thresh in np.arange(0.30, 0.70, 0.02):
            pred_buy = preds >= thresh
            # Require at least 3 passed trades (and >=15% pass rate) so a
            # degenerate "pass almost nothing" threshold can't win.
            if pred_buy.sum() < max(3, len(preds) * 0.15):
                continue
            # Score = precision-weighted recall
            tp = ((pred_buy) & (targets == 1)).sum()
            fp = ((pred_buy) & (targets == 0)).sum()
            fn = ((~pred_buy) & (targets == 1)).sum()
            precision = tp / max(tp + fp, 1)
            recall = tp / max(tp + fn, 1)
            score = precision * 0.7 + recall * 0.3
            if score > best_score:
                best_score = score
                best_threshold = thresh

        pred_binary = (preds >= best_threshold).astype(int)

        return {
            "auc": auc,
            "threshold": float(best_threshold),
            "accuracy": accuracy_score(targets, pred_binary),
            "precision": precision_score(targets, pred_binary, zero_division=0),
            "recall": recall_score(targets, pred_binary, zero_division=0),
            "f1": f1_score(targets, pred_binary, zero_division=0),
            "pass_rate": pred_binary.mean(),
            "predictions": preds,
            "targets": targets,
        }

    def check_early_stopping(self, val_auc: float, patience: int = 15) -> bool:
        """Track the best validation AUC; return True when training should stop.

        Snapshots the model weights (on CPU) whenever AUC improves, otherwise
        increments the patience counter.
        """
        if val_auc > self.best_val_auc:
            self.best_val_auc = val_auc
            self.best_state = {k: v.cpu().clone() for k, v in self.model.state_dict().items()}
            self.patience_counter = 0
            return False
        else:
            self.patience_counter += 1
            return self.patience_counter >= patience

    def restore_best(self):
        """Reload the best snapshot taken by check_early_stopping (if any)."""
        if self.best_state:
            self.model.load_state_dict(self.best_state)


# ─── Data Loading ───

def load_trades() -> list[dict]:
    """Load the trade history, keeping only entries that carry both a
    surge_type and an entry_time (required by the sequence builder)."""
    all_trades = json.loads(HISTORY_FILE.read_text())
    trades = [t for t in all_trades if t.get("surge_type") and t.get("entry_time")]
    print(f"  📂 {len(trades)} trades avec surge_type")
    return trades


def prepare_data(
    trades: list[dict],
    seq_len: int = 60,
    val_ratio: float = 0.25,
    batch_size: int = 64,
) -> tuple[DataLoader, DataLoader, float]:
    """Build train/val DataLoaders with a chronological (walk-forward) split.

    Args:
        trades: trade dicts fed to build_sequence_dataset.
        seq_len: sequence length in minutes.
        val_ratio: fraction of the most recent samples reserved for validation.
        batch_size: train batch size (validation uses 2x).

    Returns:
        (train_loader, val_loader, pos_weight) where pos_weight is the
        negative/positive ratio of the training split.

    Raises:
        ValueError: if fewer than 30 sequences could be built.
    """
    print(f"\n  📐 Building sequence dataset (seq_len={seq_len})...")
    X, surge_types, y = build_sequence_dataset(trades, str(KLINES_DIR), seq_len)

    if len(X) < 30:
        raise ValueError(f"Pas assez de données: {len(X)} séquences (min: 30)")

    # Walk-forward: oldest samples train, most recent samples validate.
    cut = int(len(X) * (1 - val_ratio))
    X_train, s_train, y_train = X[:cut], surge_types[:cut], y[:cut]
    X_val, s_val, y_val = X[cut:], surge_types[cut:], y[cut:]

    n_pos = y_train.sum()
    pos_weight = (len(y_train) - n_pos) / max(n_pos, 1)

    print(f"  Train: {len(y_train)} ({y_train.sum():.0f} positifs)")
    print(f"  Val:   {len(y_val)} ({y_val.sum():.0f} positifs)")
    print(f"  Pos weight: {pos_weight:.2f}")

    # Oversample positives so each training batch is roughly class-balanced.
    sample_weights = np.where(y_train == 1, pos_weight, 1.0)
    sampler = WeightedRandomSampler(sample_weights, len(sample_weights), replacement=True)

    train_loader = DataLoader(
        SurgeDataset(X_train, s_train, y_train),
        batch_size=batch_size, sampler=sampler,
        num_workers=2, pin_memory=True, drop_last=True,
    )
    val_loader = DataLoader(
        SurgeDataset(X_val, s_val, y_val),
        batch_size=batch_size * 2, shuffle=False,
        num_workers=2, pin_memory=True,
    )

    return train_loader, val_loader, pos_weight


# ─── Export ───

def export_model(model: SurgePredictor, threshold: float, seq_len: int, metrics: dict):
    """Export the trained model for CPU inference on the server.

    Writes up to three artifacts under MODELS_DIR:
      - surge_predictor_gpu.pt              TorchScript trace (preferred path)
      - surge_predictor_gpu_checkpoint.pth  state_dict + training config
      - surge_predictor_gpu_meta.json       plain-JSON metadata for the server
    """
    MODELS_DIR.mkdir(parents=True, exist_ok=True)

    model.eval()
    # NOTE: .cpu() moves the model in place — the caller's model ends up on CPU.
    model_cpu = model.cpu()

    # ── TorchScript trace ──
    wrapper = SurgePredictorExportWrapper(model_cpu)
    wrapper.eval()
    dummy_inputs = (
        torch.randn(1, seq_len, N_SEQUENCE_FEATURES),
        torch.tensor([0], dtype=torch.long),
    )
    try:
        traced = torch.jit.trace(wrapper, dummy_inputs)
        ts_path = MODELS_DIR / "surge_predictor_gpu.pt"
        traced.save(str(ts_path))
        print(f"  💾 TorchScript model: {ts_path} ({os.path.getsize(ts_path)/1e6:.1f} MB)")
    except Exception as e:
        # Tracing can fail on data-dependent control flow; fall back to weights.
        print(f"  ⚠️  TorchScript trace failed: {e}")
        print(f"  💾 Saving as state_dict instead...")
        torch.save(model_cpu.state_dict(), MODELS_DIR / "surge_predictor_gpu_state.pth")

    # ── Checkpoint: state dict + config (more flexible than TorchScript) ──
    checkpoint = {
        "model_state_dict": model_cpu.state_dict(),
        "threshold": threshold,
        "seq_len": seq_len,
        "n_features": N_SEQUENCE_FEATURES,
        "hidden_size": model_cpu.hidden_size,
        "num_layers": model_cpu.num_layers,
        "metrics": metrics,
        "trained_at": datetime.now(timezone.utc).isoformat(),
        "surge_type_map": SURGE_TYPE_MAP,
    }
    torch.save(checkpoint, MODELS_DIR / "surge_predictor_gpu_checkpoint.pth")

    # ── JSON metadata consumed by the inference server ──
    # Drop the raw arrays and coerce numpy floats so the dict is JSON-safe.
    scalar_metrics = {
        k: float(v) if isinstance(v, (np.floating, float)) else v
        for k, v in metrics.items()
        if k not in ("predictions", "targets")
    }
    meta = {
        "model_type": "lstm_attention",
        "trained_at": datetime.now(timezone.utc).isoformat(),
        "seq_len": seq_len,
        "n_features": N_SEQUENCE_FEATURES,
        "hidden_size": model_cpu.hidden_size,
        "num_layers": model_cpu.num_layers,
        "threshold": threshold,
        "metrics": scalar_metrics,
        "surge_type_map": SURGE_TYPE_MAP,
        "device_trained": "cuda",
        "device_inference": "cpu",
    }
    meta_path = MODELS_DIR / "surge_predictor_gpu_meta.json"
    meta_path.write_text(json.dumps(meta, indent=2))
    print(f"  📋 Metadata: {meta_path}")


# ─── Walk-Forward Backtest ───

def walk_forward_backtest(
    trades: list[dict],
    seq_len: int = 60,
    train_days: int = 7,
    val_days: int = 1,
    step_days: int = 1,
    epochs: int = 50,
    device: torch.device = None,
) -> dict:
    """Walk-forward backtest with per-window re-training.

    Slides a [train_days | val_days] window over the trade history day by
    day, trains a small model on the train slice, and compares the PnL of
    trades the model would have passed against simply taking every trade.

    Args:
        trades: trade dicts (need "entry_time", "surge_type", "symbol", "pnl_pct").
        seq_len: sequence length in minutes fed to the model.
        train_days / val_days / step_days: window geometry in calendar days.
        epochs: training epochs per window.
        device: torch device; auto-detected when None.

    Returns:
        {"windows": [per-window result dicts]} — or {} when there are too
        few distinct trading days for a single window.
    """
    import pandas as pd
    # BUGFIX: build_sequence_features was referenced below but never imported
    # anywhere in this file (only build_sequence_dataset is), so every window
    # crashed with NameError when re-matching trades to predictions.
    from deep_model import build_sequence_features

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Chronological ordering + the list of distinct trading days.
    trades_sorted = sorted(trades, key=lambda t: t.get("entry_time", ""))
    dates = sorted(set(t["entry_time"][:10] for t in trades_sorted))

    if len(dates) < train_days + val_days:
        print(f"  ⚠️  Pas assez de jours ({len(dates)}) pour walk-forward")
        return {}

    print(f"\n{'═'*60}")
    print(f"  📈 Walk-Forward Backtest (LSTM+Attention)")
    print(f"{'═'*60}")
    print(f"  Train: {train_days}d | Val: {val_days}d | Step: {step_days}d")
    print(f"  Dates: {dates[0]} → {dates[-1]}")
    print(f"  Epochs/window: {epochs}")

    results = []
    start_idx = train_days

    while start_idx + val_days <= len(dates):
        train_dates = set(dates[max(0, start_idx - train_days):start_idx])
        val_dates = set(dates[start_idx:start_idx + val_days])

        train_trades = [t for t in trades_sorted if t["entry_time"][:10] in train_dates]
        val_trades = [t for t in trades_sorted if t["entry_time"][:10] in val_dates]

        if len(val_trades) == 0:
            start_idx += step_days
            continue

        val_date = sorted(val_dates)[0]
        print(f"\n  Window: train {min(train_dates)}→{max(train_dates)} | val {val_date}")

        # Build sequences for both slices.
        X_tr, s_tr, y_tr = build_sequence_dataset(train_trades, str(KLINES_DIR), seq_len)
        X_va, s_va, y_va = build_sequence_dataset(val_trades, str(KLINES_DIR), seq_len)

        if len(X_tr) < 20 or len(X_va) == 0:
            print(f"    ⚠️  Skipping (train={len(X_tr)}, val={len(X_va)})")
            start_idx += step_days
            continue

        # Quick train with a deliberately small model (speed over capacity).
        model = SurgePredictor(
            n_features=N_SEQUENCE_FEATURES,
            hidden_size=64,  # Smaller for speed
            num_layers=1,
            dropout=0.2,
        )
        trainer = Trainer(model, device, lr=2e-3)

        pos_weight = (len(y_tr) - y_tr.sum()) / max(y_tr.sum(), 1)
        train_loader = DataLoader(
            SurgeDataset(X_tr, s_tr, y_tr), batch_size=32, shuffle=True
        )
        val_loader = DataLoader(
            SurgeDataset(X_va, s_va, y_va), batch_size=64, shuffle=False
        )

        for _ in range(epochs):
            trainer.train_epoch(train_loader, pos_weight)

        val_metrics = trainer.evaluate(val_loader)
        preds = val_metrics["predictions"]
        threshold = val_metrics["threshold"]
        passed_indices = [i for i, p in enumerate(preds) if p >= threshold]

        # Baseline PnL: take every validation trade.
        pnl_all = sum(float(t.get("pnl_pct", 0)) for t in val_trades)

        # Re-match trades to predictions: build_sequence_dataset silently skips
        # trades without usable kline data, so rebuild each trade's sequence the
        # same way to recover which trades the predictions line up with
        # (assumed same order — TODO confirm against build_sequence_dataset).
        matched_val = []
        for t in val_trades:
            pq = Path(KLINES_DIR) / f"{t['symbol']}.parquet"
            if not pq.exists():
                # Fall back to the USDT-quoted symbol file.
                pq = Path(KLINES_DIR) / f"{t['symbol'].replace('USDC', 'USDT')}.parquet"
            try:
                df = pd.read_parquet(pq)
            except (FileNotFoundError, OSError):
                # BUGFIX: previously this crashed when neither parquet existed;
                # skip the trade instead (the dataset builder skipped it too).
                continue
            seq = build_sequence_features(
                df,
                int(pd.Timestamp(t["entry_time"]).timestamp() * 1000),
                seq_len,
            )
            if seq is not None:
                matched_val.append(t)

        pnl_filtered = sum(
            float(matched_val[i].get("pnl_pct", 0))
            for i in passed_indices
            if i < len(matched_val)
        )

        delta = pnl_filtered - pnl_all
        status = "✅" if delta >= 0 else "❌"
        n_passed = len(passed_indices)

        print(f"    {val_date} | {len(val_trades)} trades → {n_passed} passés | "
              f"AUC: {val_metrics['auc']:.3f} | PnL: {pnl_filtered:+.1f}% vs {pnl_all:+.1f}% | "
              f"Δ={delta:+.1f}% {status}")

        results.append({
            "val_date": val_date,
            "n_trades": len(val_trades),
            "n_passed": n_passed,
            "auc": val_metrics["auc"],
            "pnl_filtered": pnl_filtered,
            "pnl_all": pnl_all,
            "delta": delta,
            "threshold": threshold,
        })

        start_idx += step_days

    # Summary
    if results:
        total_delta = sum(r["delta"] for r in results)
        positive = sum(1 for r in results if r["delta"] >= 0)
        avg_auc = np.mean([r["auc"] for r in results])
        print(f"\n{'═'*60}")
        print(f"  📊 Walk-Forward Results (LSTM)")
        print(f"{'═'*60}")
        print(f"  Windows: {len(results)} | Positive: {positive}/{len(results)}")
        print(f"  Avg AUC: {avg_auc:.3f}")
        print(f"  Total Δ PnL: {total_delta:+.1f}%")

    return {"windows": results}


# ─── Main ───

def main():
    """CLI entry point: parse arguments, then run one of three modes —
    walk-forward backtest (--backtest), checkpoint re-export (--export-only),
    or the default full train + TorchScript export."""
    parser = argparse.ArgumentParser(description="GPU Training for Surge Predictor")
    parser.add_argument("--epochs", type=int, default=80, help="Training epochs")
    parser.add_argument("--batch-size", type=int, default=64, help="Batch size")
    parser.add_argument("--seq-len", type=int, default=60, help="Sequence length (minutes)")
    parser.add_argument("--hidden-size", type=int, default=128, help="LSTM hidden size")
    parser.add_argument("--num-layers", type=int, default=2, help="LSTM layers")
    parser.add_argument("--dropout", type=float, default=0.3, help="Dropout rate")
    parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate")
    parser.add_argument("--patience", type=int, default=15, help="Early stopping patience")
    parser.add_argument("--export-only", action="store_true", help="Just re-export existing model")
    parser.add_argument("--backtest", action="store_true", help="Run walk-forward backtest")
    parser.add_argument("--cpu", action="store_true", help="Force CPU training")
    args = parser.parse_args()

    # Device: prefer CUDA unless --cpu was passed or CUDA is unavailable.
    if args.cpu or not torch.cuda.is_available():
        device = torch.device("cpu")
        if not args.cpu:
            print("  ⚠️  CUDA not available, falling back to CPU")
    else:
        device = torch.device("cuda")
        gpu_name = torch.cuda.get_device_name(0)
        gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"  🎮 GPU: {gpu_name} ({gpu_mem:.1f} GB)")

    print(f"  📍 Device: {device}")

    # Mode 1: walk-forward backtest — trains per-window, then exits.
    if args.backtest:
        trades = load_trades()
        results = walk_forward_backtest(
            trades, seq_len=args.seq_len,
            epochs=args.epochs // 2,  # Fewer epochs per window
            device=device,
        )
        # Save results (default=str handles numpy scalars in the result dicts)
        results_path = DATA_DIR / "backtest_results_gpu.json"
        results_path.write_text(json.dumps(results, indent=2, default=str))
        print(f"\n  💾 Results: {results_path}")
        return

    # Load data
    trades = load_trades()
    train_loader, val_loader, pos_weight = prepare_data(
        trades, seq_len=args.seq_len, batch_size=args.batch_size,
    )

    # Create model
    model = SurgePredictor(
        n_features=N_SEQUENCE_FEATURES,
        hidden_size=args.hidden_size,
        num_layers=args.num_layers,
        dropout=args.dropout,
    )

    n_params = sum(p.numel() for p in model.parameters())
    print(f"\n  🧠 Model: {n_params:,} parameters")
    print(f"  Architecture: LSTM({args.hidden_size}) × {args.num_layers} layers + Attention")

    # Mode 2: re-export an existing checkpoint without retraining.
    if args.export_only:
        ckpt_path = MODELS_DIR / "surge_predictor_gpu_checkpoint.pth"
        if ckpt_path.exists():
            # weights_only=False is needed to load the metrics dict; only load
            # checkpoints produced by this script (arbitrary pickles are unsafe).
            ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
            model.load_state_dict(ckpt["model_state_dict"])
            export_model(model, ckpt["threshold"], ckpt["seq_len"], ckpt["metrics"])
        else:
            print("  ❌ No checkpoint found to export")
        return

    # Mode 3 (default): full training run with early stopping.
    trainer = Trainer(model, device, lr=args.lr)

    print(f"\n{'═'*60}")
    print(f"  🏋️ Training — {args.epochs} epochs")
    print(f"{'═'*60}")

    start_time = time.time()

    for epoch in range(1, args.epochs + 1):
        t0 = time.time()
        train_metrics = trainer.train_epoch(train_loader, pos_weight)
        val_metrics = trainer.evaluate(val_loader)
        elapsed = time.time() - t0

        # Log the first 3 epochs and then every 5th to keep output readable.
        if epoch % 5 == 0 or epoch <= 3:
            lr = trainer.optimizer.param_groups[0]["lr"]
            print(
                f"  Epoch {epoch:3d}/{args.epochs} | "
                f"Loss: {train_metrics['loss']:.4f} | "
                f"Train AUC: {train_metrics['auc']:.3f} | "
                f"Val AUC: {val_metrics['auc']:.3f} | "
                f"Val Prec: {val_metrics['precision']:.3f} | "
                f"LR: {lr:.1e} | "
                f"{elapsed:.1f}s"
            )

        should_stop = trainer.check_early_stopping(val_metrics["auc"], args.patience)
        if should_stop:
            print(f"\n  ⏹️  Early stopping at epoch {epoch} (best AUC: {trainer.best_val_auc:.3f})")
            break

    # Restore best model (snapshot taken whenever validation AUC improved)
    trainer.restore_best()
    total_time = time.time() - start_time

    # Final evaluation
    final_metrics = trainer.evaluate(val_loader)

    print(f"\n{'═'*60}")
    print(f"  📊 Final Results")
    print(f"{'═'*60}")
    print(f"  AUC:       {final_metrics['auc']:.4f}")
    print(f"  Accuracy:  {final_metrics['accuracy']:.4f}")
    print(f"  Precision: {final_metrics['precision']:.4f}")
    print(f"  Recall:    {final_metrics['recall']:.4f}")
    print(f"  F1:        {final_metrics['f1']:.4f}")
    print(f"  Threshold: {final_metrics['threshold']:.2f}")
    print(f"  Pass rate: {final_metrics['pass_rate']:.1%}")
    print(f"  Time:      {total_time:.0f}s")

    # Export (TorchScript + checkpoint + JSON metadata; moves model to CPU)
    print(f"\n  📦 Exporting model for CPU inference...")
    export_model(model, final_metrics["threshold"], args.seq_len, final_metrics)

    print(f"\n  ✅ Done! Transfer models/ folder back to server:")
    print(f"     scp {MODELS_DIR}/surge_predictor_gpu* user@server:~/crypto_trading_bot/spy_optimizer/models/")


if __name__ == "__main__":
    main()
