"""
SPY Optimizer — Deep Learning Inference (CPU)
Charge le modèle LSTM+Attention entraîné sur GPU pour l'inference en production.

Ce module fonctionne SUR LE SERVEUR (CPU-only). Le modèle est entraîné sur le PC (GPU)
et transféré via deploy_model.sh.

Usage:
    from deep_inference import DeepPredictor
    predictor = DeepPredictor()
    result = predictor.predict(klines_df, timestamp_ms, surge_type="FLASH_SURGE")
    # → {"probability": 0.72, "signal": "BUY", "confidence": 42.0, "threshold": 0.5}
"""
import json
import os
from pathlib import Path
from typing import Optional

import numpy as np
import torch

from deep_model import (
    SurgePredictor,
    build_sequence_features,
    N_SEQUENCE_FEATURES,
    SURGE_TYPE_MAP,
)

# Default location of trained model artifacts: the "models/" folder next to this file.
MODELS_DIR = Path(__file__).parent / "models"


class DeepPredictor:
    """
    CPU inference wrapper for the LSTM+Attention model trained on GPU.

    Prefers the TorchScript export (fastest, no model class needed); falls
    back to the raw checkpoint. Claimed thread-safe for use from
    market_spy.py: inference runs under torch.no_grad() and predict()
    mutates no shared state after loading.
    """

    def __init__(self, models_dir: Optional[str] = None):
        """
        Args:
            models_dir: directory containing the exported model files.
                Defaults to the package-local "models/" folder (MODELS_DIR).
        """
        self.models_dir = Path(models_dir) if models_dir else MODELS_DIR
        self.model = None        # torch.jit.ScriptModule or eager SurgePredictor
        self.threshold = 0.5     # decision threshold for BUY vs SKIP
        self.seq_len = 60        # number of 1m klines fed to the model
        self.metadata = {}       # raw contents of the training-time meta JSON
        self.is_loaded = False
        self._load()

    def _load(self):
        """Load the TorchScript export or, failing that, the raw checkpoint."""
        # Priority 1: TorchScript (faster, self-contained graph)
        ts_path = self.models_dir / "surge_predictor_gpu.pt"
        ckpt_path = self.models_dir / "surge_predictor_gpu_checkpoint.pth"
        meta_path = self.models_dir / "surge_predictor_gpu_meta.json"

        # Metadata carries threshold/seq_len tuned at training time.
        # A corrupted or unreadable JSON file must not abort loading:
        # fall back to the defaults and still try to load the model.
        if meta_path.exists():
            try:
                with open(meta_path) as f:
                    self.metadata = json.load(f)
                self.threshold = self.metadata.get("threshold", 0.5)
                self.seq_len = self.metadata.get("seq_len", 60)
            except (json.JSONDecodeError, OSError) as e:
                print(f"  ⚠️  Metadata load failed: {e}")

        # Try TorchScript first
        if ts_path.exists():
            try:
                self.model = torch.jit.load(str(ts_path), map_location="cpu")
                self.model.eval()
                self.is_loaded = True
                print(f"  ✅ Deep model loaded (TorchScript): {ts_path.name}")
                return
            except Exception as e:
                print(f"  ⚠️  TorchScript load failed: {e}")

        # Fallback: rebuild the architecture, then restore trained weights.
        if ckpt_path.exists():
            try:
                # NOTE(security): weights_only=False unpickles arbitrary objects.
                # Acceptable only because the checkpoint is produced by our own
                # training box and shipped via deploy_model.sh — never load
                # third-party checkpoints through this path.
                ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
                model = SurgePredictor(
                    n_features=ckpt.get("n_features", N_SEQUENCE_FEATURES),
                    hidden_size=ckpt.get("hidden_size", 128),
                    num_layers=ckpt.get("num_layers", 2),
                    dropout=0,  # No dropout during inference
                )
                model.load_state_dict(ckpt["model_state_dict"])
                model.eval()
                self.model = model
                # Checkpoint values win over metadata when both are present.
                self.threshold = ckpt.get("threshold", 0.5)
                self.seq_len = ckpt.get("seq_len", 60)
                self.is_loaded = True
                print(f"  ✅ Deep model loaded (checkpoint): {ckpt_path.name}")
                return
            except Exception as e:
                print(f"  ⚠️  Checkpoint load failed: {e}")

        print(f"  ℹ️  No GPU model found in {self.models_dir}. Deep inference disabled.")

    @torch.no_grad()
    def predict(
        self,
        klines_df: 'pd.DataFrame',
        timestamp_ms: int,
        surge_type: str = "UNKNOWN",
    ) -> dict:
        """
        Predict whether a surge signal will be profitable.

        Args:
            klines_df: DataFrame of 1m klines for the symbol
            timestamp_ms: analysis timestamp in milliseconds
            surge_type: detected surge type ("FLASH_SURGE", etc.)

        Returns:
            dict with probability, signal, confidence, threshold, model_type.
            Fails OPEN: with no model (or insufficient data) it returns a
            neutral 0.5 probability and signal "BUY" with confidence 0, so
            the caller's upstream filters decide instead of this gate.
        """
        if not self.is_loaded:
            return {
                "probability": 0.5,
                "signal": "BUY",
                "confidence": 0,
                "threshold": self.threshold,
                "model_type": "none",
            }

        # Build the (seq_len, n_features) input window ending at timestamp_ms.
        seq = build_sequence_features(klines_df, timestamp_ms, self.seq_len)
        if seq is None:
            return {
                "probability": 0.5,
                "signal": "BUY",
                "confidence": 0,
                "threshold": self.threshold,
                "model_type": "deep_no_data",
            }

        # Prepare tensors
        x = torch.tensor(seq, dtype=torch.float32).unsqueeze(0)  # (1, seq_len, n_features)
        surge_idx = SURGE_TYPE_MAP.get(surge_type, 3)  # 3 = fallback index for unknown types
        surge_t = torch.tensor([surge_idx], dtype=torch.long)

        # Sanitize: the model must never see NaN/Inf from degenerate features.
        x = torch.nan_to_num(x, nan=0.0, posinf=10.0, neginf=-10.0)

        # Inference — the scripted graph returns the probability tensor
        # directly, while the eager model returns a dict.
        if isinstance(self.model, torch.jit.ScriptModule):
            prob = self.model(x, surge_t).item()
        else:
            output = self.model(x, surge_t)
            prob = output["probability"].item()

        signal = "BUY" if prob >= self.threshold else "SKIP"
        # Confidence = distance from threshold, as % of threshold, capped at 100.
        confidence = abs(prob - self.threshold) / max(self.threshold, 0.01) * 100
        confidence = min(confidence, 100)

        return {
            "probability": float(prob),
            "signal": signal,
            "confidence": float(confidence),
            "threshold": float(self.threshold),
            "model_type": "lstm_attention",
        }

    def get_info(self) -> dict:
        """Return loading state and hyperparameters of the current model."""
        return {
            "is_loaded": self.is_loaded,
            "seq_len": self.seq_len,
            "threshold": self.threshold,
            "metadata": self.metadata,
        }
