#!/usr/bin/env python3
"""
Remote GPU Training Orchestrator
==================================
Lance le pipeline GPU complet sur le PC Windows distant depuis le serveur.

Étapes automatisées :
  1. Push des données (combined_training_dataset.parquet + lstm_sequences.npz)
  2. Push du script train_gpu_full.py
  3. Exécution distante via SSH
  4. Récupération des résultats (optimized_params.json + modèle LSTM)
  5. Re-entraînement du signal_classifier avec les nouveaux params

Usage:
    python remote_gpu_train.py                          # Tout le pipeline
    python remote_gpu_train.py --push-only              # Push data seulement
    python remote_gpu_train.py --pull-only              # Pull résultats seulement
    python remote_gpu_train.py --retrain-only            # Retrain classifier seulement
    python remote_gpu_train.py --trials 1000 --epochs 200  # Custom params

Prérequis:
    1. OpenSSH Server activé sur le PC Windows
    2. Clé SSH du serveur autorisée sur le PC
    3. gpu_remote_config.json rempli avec les infos du PC
"""
import argparse
import json
import os
import subprocess
import sys
import time
from datetime import datetime, timezone
from pathlib import Path

# Paths resolved relative to this script so it can be launched from any cwd.
PROJECT_DIR = Path(__file__).parent
DATA_DIR = PROJECT_DIR / "data"        # datasets pushed to / pulled from the PC
MODELS_DIR = PROJECT_DIR / "models"    # trained model artefacts land here
CONFIG_FILE = PROJECT_DIR / "gpu_remote_config.json"  # SSH connection settings


def load_config():
    """Load and sanity-check the remote-PC connection settings.

    Exits the process with status 1 when the config file is missing or
    still contains the CHANGE_ME placeholders.
    """
    if not CONFIG_FILE.exists():
        print("  ❌ gpu_remote_config.json introuvable")
        print("  📝 Créez-le avec : pc_user, pc_host, pc_port, pc_work_dir")
        sys.exit(1)

    cfg = json.loads(CONFIG_FILE.read_text())

    # Refuse to run against an unconfigured template.
    if "CHANGE_ME" in (cfg.get("pc_host"), cfg.get("pc_user")):
        print("  ❌ gpu_remote_config.json non configuré")
        print("  📝 Remplissez pc_user et pc_host avec les infos de votre PC")
        sys.exit(1)

    return cfg


def ssh_cmd(config, command, capture=False, timeout=None):
    """Run *command* on the remote PC over SSH.

    Args:
        config: connection dict (pc_port, pc_user, pc_host, optional pc_ssh_key).
        capture: when True, capture stdout/stderr as decoded text.
        timeout: seconds before subprocess.TimeoutExpired, or None for no limit.

    Returns:
        The subprocess.CompletedProcess of the ssh invocation; with
        capture=True its stdout/stderr are str.
    """
    ssh_args = [
        "ssh",
        "-p", str(config["pc_port"]),
        "-o", "StrictHostKeyChecking=accept-new",
        "-o", "ConnectTimeout=10",
    ]
    if config.get("pc_ssh_key"):
        ssh_args += ["-i", config["pc_ssh_key"]]

    target = f"{config['pc_user']}@{config['pc_host']}"
    ssh_args += [target, command]

    if capture:
        # Decode via subprocess itself instead of mutating the result object
        # afterwards. Windows consoles often emit bytes that are not valid
        # UTF-8 (cp850/cp1252); errors="replace" keeps the rest readable
        # instead of raising UnicodeDecodeError.
        return subprocess.run(
            ssh_args,
            capture_output=True,
            timeout=timeout,
            encoding="utf-8",
            errors="replace",
        )
    return subprocess.run(ssh_args, timeout=timeout)


def scp_to_pc(config, local_path, remote_path):
    """Copy a local file to the remote PC via scp.

    Returns the subprocess.CompletedProcess of the scp invocation.
    """
    args = [
        "scp",
        "-P", str(config["pc_port"]),
        "-o", "StrictHostKeyChecking=accept-new",
    ]
    key = config.get("pc_ssh_key")
    if key:
        args.extend(["-i", key])

    destination = f"{config['pc_user']}@{config['pc_host']}:{remote_path}"
    args.extend([str(local_path), destination])

    return subprocess.run(args)


def scp_from_pc(config, remote_path, local_path):
    """Copy a file from the remote PC to a local path via scp.

    Returns the subprocess.CompletedProcess of the scp invocation.
    """
    args = [
        "scp",
        "-P", str(config["pc_port"]),
        "-o", "StrictHostKeyChecking=accept-new",
    ]
    key = config.get("pc_ssh_key")
    if key:
        args.extend(["-i", key])

    source = f"{config['pc_user']}@{config['pc_host']}:{remote_path}"
    args.extend([source, str(local_path)])

    return subprocess.run(args)


def test_connection(config):
    """Check SSH reachability of the PC (possibly through a reverse tunnel).

    Returns:
        True when a trivial remote command (`echo OK`) succeeds, False on
        auth failure, timeout, missing ssh binary, or a dead tunnel.
    """
    print("  🔌 Test connexion SSH au PC...")

    if config["pc_host"] == "localhost":
        # "localhost" means we reach the PC through a reverse tunnel: probe
        # the local port first so we can give a targeted hint when the tunnel
        # itself is down (rather than a generic SSH timeout).
        try:
            import socket
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            s.settimeout(3)
            result = s.connect_ex(("localhost", config["pc_port"]))
            s.close()
            if result != 0:
                print(f"  ❌ Tunnel SSH non actif sur localhost:{config['pc_port']}")
                print(f"     → Sur le PC, lancez: .\\connect_tunnel.ps1")
                return False
        except Exception:
            # Best-effort probe only — fall through to the real SSH attempt.
            pass

    try:
        result = ssh_cmd(config, "echo OK", capture=True, timeout=15)
        if result.returncode == 0 and "OK" in result.stdout:
            print(f"  ✅ Connexion OK → {config['pc_user']}@{config['pc_host']}:{config['pc_port']}")
            # Show the remote hostname as a sanity check (Windows cmd syntax).
            info = ssh_cmd(config, "hostname & echo --- & where python 2>nul", capture=True, timeout=10)
            if info.returncode == 0:
                # splitlines() also drops the trailing '\r' that Windows
                # output carries, which split(chr(10)) used to leave on the
                # printed hostname; guard against empty output.
                lines = info.stdout.strip().splitlines()
                print(f"     PC: {lines[0] if lines else ''}")
            return True
        else:
            print(f"  ❌ Connexion échouée: {result.stderr.strip()}")
            if config["pc_host"] == "localhost":
                print(f"     → Le tunnel est actif mais l'auth SSH échoue.")
                print(f"     → Vérifiez que la clé serveur est dans authorized_keys du PC")
            return False
    except subprocess.TimeoutExpired:
        print("  ❌ Timeout — le PC est-il allumé et le tunnel actif ?")
        if config["pc_host"] == "localhost":
            print(f"     → Sur le PC, lancez: .\\connect_tunnel.ps1")
        return False
    except FileNotFoundError:
        print("  ❌ ssh introuvable")
        return False


def push_data(config, rebuild_lstm=False):
    """Send the training data and GPU training script to the remote PC.

    Args:
        config: connection dict (pc_work_dir, pc_user, ...).
        rebuild_lstm: currently unused; kept for call-site compatibility.

    Returns:
        True when every queued file transferred successfully, False on a
        missing local file or a failed scp transfer.
    """
    work_dir = config["pc_work_dir"]
    # scp accepts forward slashes even for Windows destination paths.
    win_work_dir = work_dir.replace("\\", "/")

    # Create the models directory on the PC (Windows cmd syntax).
    ssh_cmd(config, f'cmd /c "if not exist {work_dir}\\models mkdir {work_dir}\\models"', timeout=15)

    files_to_push = [
        (PROJECT_DIR / "train_gpu_full.py", f"{win_work_dir}/train_gpu_full.py"),
        (DATA_DIR / "combined_training_dataset.parquet", f"{win_work_dir}/combined_training_dataset.parquet"),
    ]

    # LSTM sequences are optional — warn but continue without them.
    lstm_path = DATA_DIR / "lstm_sequences.npz"
    if lstm_path.exists():
        files_to_push.append((lstm_path, f"{win_work_dir}/lstm_sequences.npz"))
    else:
        print("  ⚠️  lstm_sequences.npz manquant — lancez build_lstm_sequences.py d'abord")

    print(f"\n  📤 Push {len(files_to_push)} fichiers vers le PC...")
    for local, remote in files_to_push:
        # Guard: a missing mandatory file would otherwise raise
        # FileNotFoundError in .stat() below.
        if not local.exists():
            print(f"    ❌ Fichier local manquant: {local.name}")
            return False
        size_mb = local.stat().st_size / 1e6
        print(f"    → {local.name} ({size_mb:.1f} MB)")
        result = scp_to_pc(config, local, remote)
        if result.returncode != 0:
            print(f"    ❌ Échec transfer {local.name}")
            return False

    print("  ✅ Données transférées")
    return True


def run_training(config, trials=500, epochs=100, mode="full"):
    """Launch the GPU training run on the remote PC over SSH.

    Args:
        config: connection dict; honors optional "pc_python" interpreter.
        trials: Optuna trial count for the tabular search.
        epochs: LSTM training epochs.
        mode: "full", "lstm-only" or "tabular-only".

    Returns:
        True when the remote training process exits with code 0.
    """
    work_dir = config["pc_work_dir"]
    python = config.get("pc_python", "python")

    # Translate mode into the remote script's CLI flag.
    if mode == "lstm-only":
        extra = "--lstm-only"
    elif mode == "tabular-only":
        extra = "--tabular-only"
    else:
        extra = ""

    # Windows cmd: "cd /d" also switches drives; activate a venv when found.
    # chcp 65001 + PYTHONIOENCODING force UTF-8 output to avoid cp1252 errors.
    # The venv may live in work_dir\.venv or work_dir\windows\.venv.
    venv1 = f"{work_dir}\\.venv\\Scripts\\activate.bat"
    venv2 = f"{work_dir}\\windows\\.venv\\Scripts\\activate.bat"
    # BUG FIX: the configured interpreter (pc_python) was previously read but
    # ignored — the command hard-coded "python". Default is unchanged.
    cmd = f'cmd /c "chcp 65001 >nul & cd /d {work_dir} & ( if exist {venv1} ( call {venv1} ) else if exist {venv2} ( call {venv2} ) ) & set PYTHONIOENCODING=utf-8 & {python} train_gpu_full.py --trials {trials} --epochs {epochs} {extra}"'

    print(f"\n  🚀 Lancement du training GPU sur le PC...")
    print(f"     Command: {cmd}")
    print(f"     {datetime.now().strftime('%H:%M:%S')} — Début")
    print(f"     (Ctrl+C pour abandonner)\n")

    t0 = time.time()
    # No timeout here — training can legitimately take 30+ minutes.
    result = ssh_cmd(config, cmd, capture=False, timeout=None)
    elapsed = time.time() - t0

    if result.returncode == 0:
        print(f"\n  ✅ Training terminé en {elapsed/60:.1f} min")
        return True
    else:
        print(f"\n  ❌ Training échoué (exit code {result.returncode}) après {elapsed/60:.1f} min")
        return False

def pull_results(config):
    """Retrieve training artefacts (params, LSTM model, results) from the PC.

    Missing remote files are reported as warnings, not errors.

    Returns:
        True when at least one file was retrieved.
    """
    remote_base = config["pc_work_dir"].replace("\\", "/")

    transfers = [
        # Tabular params
        (f"{remote_base}/models/optimized_params.json", MODELS_DIR / "optimized_params.json"),
        # LSTM model
        (f"{remote_base}/models/surge_predictor_gpu.pt", MODELS_DIR / "surge_predictor_gpu.pt"),
        (f"{remote_base}/models/surge_predictor_gpu_checkpoint.pth", MODELS_DIR / "surge_predictor_gpu_checkpoint.pth"),
        (f"{remote_base}/models/surge_predictor_gpu_meta.json", MODELS_DIR / "surge_predictor_gpu_meta.json"),
        # Full results
        (f"{remote_base}/gpu_optimization_results.json", DATA_DIR / "gpu_optimization_results.json"),
    ]

    print(f"\n  📥 Pull résultats depuis le PC...")
    ok_count = 0
    for remote_path, local_path in transfers:
        if scp_from_pc(config, remote_path, local_path).returncode != 0:
            print(f"    ⚠️  {local_path.name} — non disponible")
            continue
        if local_path.exists():
            size_kb = local_path.stat().st_size / 1024
            print(f"    ✅ {local_path.name} ({size_kb:.0f} KB)")
            ok_count += 1

    print(f"\n  📦 {ok_count}/{len(transfers)} fichiers récupérés")
    return ok_count > 0


def retrain_classifier(verbose=True):
    """Retrain the signal_classifier with the fresh GPU params (+ LSTM if any).

    Args:
        verbose: forwarded to SignalClassifier.train.

    Returns:
        True on success; False when prerequisites are missing or training raises.
    """
    bar = "═" * 60
    print(f"\n{bar}")
    print(f"  🔄 Re-training signal_classifier avec GPU params")
    print(f"{bar}")

    # GPU-optimized hyperparameters are a hard prerequisite.
    params_path = MODELS_DIR / "optimized_params.json"
    if not params_path.exists():
        print("  ❌ optimized_params.json manquant — lancez le training GPU d'abord")
        return False

    # So is the combined dataset.
    dataset_path = DATA_DIR / "combined_training_dataset.parquet"
    if not dataset_path.exists():
        print("  ❌ combined_training_dataset.parquet manquant")
        return False

    try:
        import pandas as pd
        # Make project-local modules importable regardless of cwd.
        if str(PROJECT_DIR) not in sys.path:
            sys.path.insert(0, str(PROJECT_DIR))

        from signal_classifier import SignalClassifier

        df = pd.read_parquet(dataset_path)
        print(f"  📂 Dataset: {len(df)} samples")

        gpu_params = json.loads(params_path.read_text())
        print(f"  📂 GPU params: AUC {gpu_params.get('ensemble_auc', '?')}, threshold {gpu_params.get('threshold', '?')}")

        classifier = SignalClassifier()
        train_metrics = classifier.train(df, optimize_hyperparams=False, verbose=verbose)

        classifier.save(str(MODELS_DIR / "signal_classifier.pkl"))
        print(f"\n  ✅ signal_classifier.pkl mis à jour")
        print(f"     AUC: {train_metrics.get('auc_roc', '?')}")
        print(f"     PnL improvement: {train_metrics.get('pnl_improvement_pct', '?')}%")

        # If a GPU-trained LSTM came back too, mention the hybrid ensemble.
        meta_path = MODELS_DIR / "surge_predictor_gpu_meta.json"
        if meta_path.exists():
            lstm_meta = json.loads(meta_path.read_text())
            print(f"  🧠 Modèle LSTM disponible: AUC {lstm_meta.get('auc', '?')}")
            print(f"     → Le classifier utilisera l'ensemble hybride (60% tabular + 40% LSTM)")

        return True

    except Exception as e:
        print(f"  ❌ Erreur retrain: {e}")
        import traceback
        traceback.print_exc()
        return False


def main():
    """CLI entry point: parse the flags and run the requested pipeline stage(s)."""
    arg_parser = argparse.ArgumentParser(description="Remote GPU Training Orchestrator")
    arg_parser.add_argument("--trials", type=int, default=500, help="Optuna trials")
    arg_parser.add_argument("--epochs", type=int, default=100, help="LSTM epochs")
    arg_parser.add_argument("--push-only", action="store_true", help="Push data only")
    arg_parser.add_argument("--pull-only", action="store_true", help="Pull results only")
    arg_parser.add_argument("--retrain-only", action="store_true", help="Retrain classifier only")
    arg_parser.add_argument("--lstm-only", action="store_true", help="LSTM training only")
    arg_parser.add_argument("--tabular-only", action="store_true", help="Tabular only (no LSTM)")
    arg_parser.add_argument("--skip-retrain", action="store_true", help="Skip classifier retrain")
    args = arg_parser.parse_args()

    banner = "═" * 60
    print(f"\n{banner}")
    print(f"  🖥️  Remote GPU Training Orchestrator")
    print(f"  📅 {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"{banner}")

    # Retrain-only runs locally and needs no SSH connection at all.
    if args.retrain_only:
        retrain_classifier()
        return

    config = load_config()
    print(f"  PC: {config['pc_user']}@{config['pc_host']}:{config['pc_port']}")
    print(f"  Dir: {config['pc_work_dir']}")

    # Every remote mode requires a working connection.
    if not test_connection(config):
        sys.exit(1)

    if args.pull_only:
        pull_results(config)
        if not args.skip_retrain:
            retrain_classifier()
        return

    if args.push_only:
        push_data(config)
        return

    # ── Full pipeline: push → train → pull → retrain ──
    pipeline_start = time.time()

    if not push_data(config):
        sys.exit(1)

    if args.lstm_only:
        mode = "lstm-only"
    elif args.tabular_only:
        mode = "tabular-only"
    else:
        mode = "full"
    # A failed training run may still have produced partial artefacts,
    # so we attempt the pull regardless.
    if not run_training(config, args.trials, args.epochs, mode):
        print("\n  ⚠️  Training a échoué, tentative de récupération des résultats partiels...")

    if not pull_results(config):
        print("  ❌ Aucun résultat récupéré")
        sys.exit(1)

    if not args.skip_retrain:
        retrain_classifier()

    elapsed = time.time() - pipeline_start
    print(f"\n{banner}")
    print(f"  ✅ Pipeline complet en {elapsed/60:.1f} min")
    print(f"{banner}")


if __name__ == "__main__":
    main()
