#!/usr/bin/env python3
"""
SPY Optimizer — Main Runner
Point d'entrée pour entraîner, évaluer et backtester le Signal Classifier.

Usage:
  python run_optimizer.py train          # Entraîne sur toutes les données
  python run_optimizer.py backtest       # Walk-forward backtest
  python run_optimizer.py evaluate       # Évalue le modèle courant
  python run_optimizer.py full           # train + backtest complet
"""
import argparse
import json
import sys
from datetime import datetime, timezone
from pathlib import Path

import pandas as pd

from feature_engineering import build_dataset_from_trades, FEATURE_COLUMNS
from signal_classifier import SignalClassifier, walk_forward_backtest

PROJECT_DIR = Path(__file__).parent
DATA_DIR = PROJECT_DIR / "data"
KLINES_DIR = DATA_DIR / "klines_1m"
MODELS_DIR = PROJECT_DIR / "models"
HISTORY_FILE = PROJECT_DIR.parent / "espion_history.json"
RESULTS_FILE = DATA_DIR / "optimizer_results.json"
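
# Note: espion_history.json is expected one level above the project directory;
# adjust HISTORY_FILE if the trade history lives elsewhere.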


def load_trades() -> list[dict]:
    """Charge tous les trades exploitables."""
    with open(HISTORY_FILE) as f:
        all_trades = json.load(f)
    # Trades avec surge_type = post-refactor, données fiables
    trades = [t for t in all_trades if t.get("surge_type") and t.get("entry_time")]
    print(f"  📂 {len(trades)} trades avec surge_type chargés")
    print(f"     (sur {len(all_trades)} total, {len(all_trades)-len(trades)} RECONSTRUCTED ignorés)")

    if not trades:
        sys.exit("  ❌ No usable trades found, aborting.")

    # Quick stats
    profitable = sum(1 for t in trades if float(t.get("pnl_usdt", 0)) > 0)
    total_pnl = sum(float(t.get("pnl_usdt", 0)) for t in trades)
    print(f"     Raw WR: {profitable/len(trades)*100:.1f}% | PnL: ${total_pnl:+.2f}")

    dates = sorted(t.get("entry_time", "")[:10] for t in trades)
    print(f"     Period: {dates[0]} → {dates[-1]}")

    return trades


def cmd_train(args):
    """Entraîne le modèle sur toutes les données disponibles."""
    print(f"\n{'='*60}")
    print(f"  🏋️ Mode: TRAIN")
    print(f"{'='*60}")

    trades = load_trades()

    print(f"\n  📐 Construction du dataset de features...")
    dataset = build_dataset_from_trades(trades, str(KLINES_DIR))
    print(f"  ✅ Dataset: {len(dataset)} samples × {len(FEATURE_COLUMNS)} features")

    # Save the dataset so later runs can reuse it
    dataset_path = DATA_DIR / "training_dataset.parquet"
    dataset.to_parquet(dataset_path, index=False)
    print(f"  💾 Dataset sauvegardé: {dataset_path}")

    # Train (Optuna hyperparameter search runs unless --no-optuna was passed)
    clf = SignalClassifier()
    metrics = clf.train(
        dataset,
        optimize_hyperparams=not args.no_optuna,
        n_optuna_trials=args.optuna_trials,
        verbose=True,
    )

    # Save the model
    MODELS_DIR.mkdir(parents=True, exist_ok=True)
    clf.save()

    # Save the run results
    results = {
        "mode": "train",
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "dataset_size": len(dataset),
        "metrics": metrics,
    }
    RESULTS_FILE.write_text(json.dumps(results, indent=2))
    print(f"\n  📊 Résultats sauvegardés: {RESULTS_FILE}")

    return metrics


def cmd_backtest(args):
    """Walk-forward backtest complet."""
    print(f"\n{'='*60}")
    print(f"  📈 Mode: WALK-FORWARD BACKTEST")
    print(f"{'='*60}")

    # Load the existing dataset or rebuild it from the trade history
    dataset_path = DATA_DIR / "training_dataset.parquet"
    if dataset_path.exists() and not args.rebuild:
        print(f"  📂 Existing dataset loaded: {dataset_path}")
        dataset = pd.read_parquet(dataset_path)
    else:
        trades = load_trades()
        print(f"\n  📐 Construction du dataset de features...")
        dataset = build_dataset_from_trades(trades, str(KLINES_DIR))
        dataset.to_parquet(dataset_path, index=False)

    print(f"  ✅ Dataset: {len(dataset)} samples")

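    # Walk-forward scheme: train on `train_days` of data, validate on the
    # next `val_days`, then slide the window forward by `step_days` and repeat.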
    results = walk_forward_backtest(
        dataset,
        train_days=args.train_days,
        val_days=args.val_days,
        step_days=args.step_days,
        n_optuna_trials=args.optuna_trials,
        verbose=True,
    )

    # Save the results
    backtest_results = {
        "mode": "backtest",
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "config": {
            "train_days": args.train_days,
            "val_days": args.val_days,
            "step_days": args.step_days,
        },
        "results": results,
    }
    backtest_file = DATA_DIR / "backtest_results.json"
    backtest_file.write_text(json.dumps(backtest_results, indent=2, default=str))
    print(f"\n  📊 Résultats sauvegardés: {backtest_file}")

    return results


def cmd_evaluate(args):
    """Évalue le modèle actuel sur les dernières données."""
    print(f"\n{'='*60}")
    print(f"  🔍 Mode: EVALUATE")
    print(f"{'='*60}")

    try:
        clf = SignalClassifier.load()
    except FileNotFoundError:
        print("  ❌ Aucun modèle entraîné trouvé. Lance 'train' d'abord.")
        return

    print(f"  📂 Modèle chargé (entraîné: {clf.training_stats.get('trained_at', '?')})")
    print(f"  Seuil: {clf.optimal_threshold:.2f}")
    print(f"  Samples train: {clf.training_stats.get('train_samples', '?')}")

    # Load recent trades
    trades = load_trades()

    # Build features and predict
    dataset = build_dataset_from_trades(trades, str(KLINES_DIR))

    correct = 0
    total = 0
    pnl_passed = 0.0
    pnl_all = 0.0

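    # clf.predict() is assumed to return at least {"signal": str, "probability": float}.
    # pnl_passed accumulates only the trades the classifier would have taken,
    # so comparing it to pnl_all measures the value of the filter.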
    for _, row in dataset.iterrows():
        features = row.to_dict()
        prediction = clf.predict(features)
        actual = row["target_profitable"]
        pnl = row["target_pnl_pct"]

        pnl_all += pnl
        if prediction["signal"] == "BUY":
            pnl_passed += pnl
            if (prediction["probability"] >= clf.optimal_threshold) == actual:
                correct += 1
        else:
            if actual == 0:
                correct += 1
        total += 1

    if total > 0:
        print(f"\n  Accuracy: {correct/total*100:.1f}%")
        print(f"  PnL sans filtre: {pnl_all:+.2f}%")
        print(f"  PnL avec filtre: {pnl_passed:+.2f}%")
        print(f"  Δ PnL: {pnl_passed - pnl_all:+.2f}%")


def cmd_full(args):
    """Train + Backtest complet."""
    print(f"\n{'='*60}")
    print(f"  🚀 Mode: FULL (Train + Backtest)")
    print(f"{'='*60}")

    cmd_train(args)
    print("\n" + "="*60)
    cmd_backtest(args)


def main():
    parser = argparse.ArgumentParser(description="SPY Optimizer — Signal Classifier")
    subparsers = parser.add_subparsers(dest="command", help="Command to run")

    # One subparser per command; --optuna-trials is shared by all of them
    for name in ["train", "backtest", "evaluate", "full"]:
        if name == "train":
            sp = subparsers.add_parser(name, help="Train the model")
            sp.add_argument("--no-optuna", action="store_true", help="Skip Optuna hyperparameter optimization")
        elif name == "backtest":
            sp = subparsers.add_parser(name, help="Walk-forward backtest")
            sp.add_argument("--train-days", type=int, default=7, help="Jours d'entraînement par window")
            sp.add_argument("--val-days", type=int, default=1, help="Jours de validation par window")
            sp.add_argument("--step-days", type=int, default=1, help="Pas entre windows")
            sp.add_argument("--rebuild", action="store_true", help="Reconstruire le dataset")
        elif name == "evaluate":
            sp = subparsers.add_parser(name, help="Evaluate the current model")
        elif name == "full":
            sp = subparsers.add_parser(name, help="Train + Backtest")
            sp.add_argument("--train-days", type=int, default=7)
            sp.add_argument("--val-days", type=int, default=1)
            sp.add_argument("--step-days", type=int, default=1)
            sp.add_argument("--rebuild", action="store_true")
            sp.add_argument("--no-optuna", action="store_true")

        sp.add_argument("--optuna-trials", type=int, default=80, help="Nombre d'essais Optuna")

    args = parser.parse_args()

    if not args.command:
        parser.print_help()
        return

    commands = {
        "train": cmd_train,
        "backtest": cmd_backtest,
        "evaluate": cmd_evaluate,
        "full": cmd_full,
    }
    commands[args.command](args)


if __name__ == "__main__":
    main()
