"""
SPY Optimizer — Deep Learning Model (GPU)
Architecture LSTM + Temporal Attention pour prédire la rentabilité des signaux de surge.

Conçu pour être entraîné sur GPU (RTX 3060+) et exporté pour inference CPU sur le serveur.

Architecture:
  Input (seq_len × n_features) → input projection → bidirectional LSTM
    → temporal attention, concatenated with a surge-type embedding → FC → sigmoid

Le modèle prend en entrée une SÉQUENCE de features (ex: 60 minutes de données 1m),
pas un vecteur plat comme LightGBM. Cela lui permet de capturer les patterns temporels.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Optional


class TemporalAttention(nn.Module):
    """
    Attention pooling over the time axis.

    A small MLP (Linear -> Tanh -> Linear) scores every timestep; the scores are
    softmax-normalised across time and used to average the sequence into a
    single context vector.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        bottleneck = hidden_size // 2
        # Attribute name and layer order define the state_dict keys — keep them.
        self.attention = nn.Sequential(
            nn.Linear(hidden_size, bottleneck),
            nn.Tanh(),
            nn.Linear(bottleneck, 1),
        )

    def forward(self, lstm_output: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            lstm_output: (batch, seq_len, hidden_size) sequence to pool.

        Returns:
            context: (batch, hidden_size) attention-weighted sum over time.
            weights: (batch, seq_len) softmax weights; each row sums to 1.
        """
        per_step_score = self.attention(lstm_output).squeeze(-1)  # (batch, seq_len)
        weights = per_step_score.softmax(dim=1)
        # (batch, 1, seq_len) @ (batch, seq_len, hidden) -> (batch, 1, hidden)
        pooled = torch.matmul(weights.unsqueeze(1), lstm_output)
        return pooled.squeeze(1), weights


class SurgePredictor(nn.Module):
    """
    Bidirectional LSTM + temporal attention that predicts whether a surge
    signal will lead to a profitable trade.

    The model receives a time sequence of pre-processed kline features plus a
    categorical surge-type index, and outputs the probability of a profitable
    trade.
    """

    def __init__(
        self,
        n_features: int,
        hidden_size: int = 128,
        num_layers: int = 2,
        dropout: float = 0.3,
        n_surge_types: int = 3,
        surge_embed_dim: int = 8,
    ):
        """
        Args:
            n_features: number of features per timestep of ``x``.
            hidden_size: width of the input projection and of each LSTM direction.
            num_layers: number of stacked LSTM layers (inter-layer dropout is
                only applied when > 1).
            dropout: base dropout rate; the input projection and the last hidden
                classifier layer use half of it.
            n_surge_types: number of known surge types. One extra embedding row
                is reserved for UNKNOWN, so valid indices are 0..n_surge_types.
            surge_embed_dim: size of the surge-type embedding.
        """
        super().__init__()
        self.n_features = n_features
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Input projection: map raw features to hidden_size and layer-normalise.
        self.input_proj = nn.Sequential(
            nn.Linear(n_features, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.GELU(),
            nn.Dropout(dropout * 0.5),
        )

        # Bidirectional LSTM
        self.lstm = nn.LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout if num_layers > 1 else 0,
        )

        # Temporal attention over the LSTM output
        self.attention = TemporalAttention(hidden_size * 2)  # *2 because bidirectional

        # Surge-type embedding (categorical)
        self.surge_embedding = nn.Embedding(n_surge_types + 1, surge_embed_dim)  # +1 for UNKNOWN

        # Classification head
        classifier_input = hidden_size * 2 + surge_embed_dim  # attention context + surge embedding
        self.classifier = nn.Sequential(
            nn.Linear(classifier_input, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.GELU(),
            nn.Dropout(dropout * 0.5),
            nn.Linear(hidden_size // 2, 1),
        )

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform init for every >=2-D weight (Linear, LSTM, Embedding); zero all biases."""
        for name, param in self.named_parameters():
            if "weight" in name and param.dim() >= 2:
                nn.init.xavier_uniform_(param)
            elif "bias" in name:
                nn.init.zeros_(param)

    def forward(
        self,
        x: torch.Tensor,
        surge_type: torch.Tensor,
        return_attention: bool = False,
    ) -> dict[str, torch.Tensor]:
        """
        Args:
            x: (batch, seq_len, n_features) time sequence.
            surge_type: (batch,) surge-type index in [0, n_surge_types].
            return_attention: if True, also return the attention weights.

        Returns:
            dict with 'logits' and 'probability' (both (batch,)), plus
            'attention_weights' ((batch, seq_len)) when ``return_attention``.
            The value type is spelled out (not a bare ``dict``) so the model
            can be compiled with ``torch.jit.script``.
        """
        # Feature projection
        projected = self.input_proj(x)  # (batch, seq_len, hidden_size)

        # LSTM
        lstm_out, _ = self.lstm(projected)  # (batch, seq_len, hidden_size*2)

        # Temporal attention
        context, attn_weights = self.attention(lstm_out)  # (batch, hidden_size*2)

        # Surge embedding
        surge_emb = self.surge_embedding(surge_type)  # (batch, surge_embed_dim)

        # Concatenate and classify
        combined = torch.cat([context, surge_emb], dim=1)
        logits = self.classifier(combined).squeeze(-1)  # (batch,)
        prob = torch.sigmoid(logits)

        result = {"logits": logits, "probability": prob}
        if return_attention:
            result["attention_weights"] = attn_weights
        return result


class SurgePredictorExportWrapper(nn.Module):
    """
    Export shim around SurgePredictor for TorchScript.

    Its ``forward`` returns a single tensor (the probability) instead of the
    model's dict, which keeps the CPU inference interface minimal.
    """

    def __init__(self, model: SurgePredictor):
        super().__init__()
        # Attribute name defines the "model.*" state_dict prefix — keep it.
        self.model = model

    def forward(self, x: torch.Tensor, surge_type: torch.Tensor) -> torch.Tensor:
        """Return only the profitability probability of the wrapped model."""
        outputs = self.model(x, surge_type, return_attention=False)
        probability = outputs["probability"]
        return probability


# ─── Sequence feature builder ───

# Continuous features extracted from every 1m candle. The order of this list is
# the column order of the array returned by build_sequence_features().
SEQUENCE_FEATURES = [
    "return",             # close-to-close change, in percent
    "high_low_range",     # (high - low) / close, in percent
    "close_position",     # (close - low) / (high - low)
    "volume_norm",        # volume / trailing 20-candle mean of volume
    "quote_volume_norm",  # quote_volume / trailing 20-candle mean of quote_volume
    "buy_pressure",       # taker_buy_quote_vol / quote_volume
    "num_trades_norm",    # num_trades / trailing 20-candle mean of num_trades
    "body_ratio",         # (close - open) / (high - low), signed: + green, - red
    "upper_wick",         # (high - max(close, open)) / (high - low)
    "lower_wick",         # (min(close, open) - low) / (high - low)
    "ema7_dist",          # (close - ema7) / close, in percent
    "ema21_dist",         # (close - ema21) / close, in percent
    "rsi_norm",           # RSI(14) / 100, in [0, 1]
    "bb_position",        # position inside the 20-period, 2-sigma Bollinger Bands
]

N_SEQUENCE_FEATURES = len(SEQUENCE_FEATURES)

# surge_type label -> index fed to the surge-type embedding of the model.
SURGE_TYPE_MAP = {
    "FLASH_SURGE": 0,
    "BREAKOUT_SURGE": 1,
    "MOMENTUM_SURGE": 2,
    "UNKNOWN": 3,
}


def _ema(values: np.ndarray, alpha: float) -> np.ndarray:
    """Recursive EMA seeded with the first value: out[i] = alpha*x[i] + (1-alpha)*out[i-1]."""
    out = np.empty(len(values))
    if len(values):
        out[0] = values[0]
        for i in range(1, len(values)):
            out[i] = alpha * values[i] + (1 - alpha) * out[i - 1]
    return out


def _rsi(close: np.ndarray, period: int = 14) -> np.ndarray:
    """RSI in [0, 100] using exponential smoothing with alpha = 1/period.

    Where the smoothed loss is zero, RSI is 100 if there was any gain and a
    neutral 50 if price has not moved at all.
    """
    delta = np.diff(close, prepend=close[0])
    avg_gain = _ema(np.where(delta > 0, delta, 0.0), 1 / period)
    avg_loss = _ema(np.where(delta < 0, -delta, 0.0), 1 / period)
    rsi = np.where(avg_gain > 0, 100.0, 50.0)
    has_loss = avg_loss > 0
    rsi[has_loss] = 100 - 100 / (1 + avg_gain[has_loss] / avg_loss[has_loss])
    return rsi


def _trailing_mean_std(values: np.ndarray, period: int) -> tuple[np.ndarray, np.ndarray]:
    """Mean and population std of the last ``period`` values up to each index.

    The first ``period - 1`` positions use the shorter window available, so only
    past and current values are used (no look-ahead, no zero padding).
    """
    mean = np.empty(len(values))
    std = np.empty(len(values))
    for i in range(len(values)):
        chunk = values[max(0, i - period + 1): i + 1]
        mean[i] = chunk.mean()
        std[i] = chunk.std()
    return mean, std


def _safe_div(num, den, valid, fill: float) -> np.ndarray:
    """Element-wise ``num / den`` where ``valid`` is True, ``fill`` elsewhere (no divide warnings)."""
    out = np.full(np.broadcast(num, den).shape, fill, dtype=float)
    np.divide(num, den, out=out, where=valid)
    return out


def build_sequence_features(
    df: 'pd.DataFrame',
    timestamp_ms: int,
    seq_len: int = 60,
) -> Optional[np.ndarray]:
    """
    Build a normalised per-minute feature sequence for the LSTM model.

    Unlike the tabular feature engineering (which summarises 120 minutes into
    ~50 scalars), the time dimension is kept: each minute becomes one vector of
    N_SEQUENCE_FEATURES values, in SEQUENCE_FEATURES order.

    All indicators are causal: a row only depends on candles at or before it.
    The *_norm columns use trailing 20-candle means, and volume_norm is
    normalised by the mean of ``volume`` itself; models trained on features
    from the previous centred/quote-volume normalisation must be retrained.

    Args:
        df: 1m klines with columns open_time (ms), open, high, low, close,
            volume, quote_volume, num_trades, taker_buy_quote_vol. Row order
            does not matter (the window is sorted by open_time).
        timestamp_ms: analysis time in ms; candles whose open_time is at most
            this value are used.
        seq_len: number of minutes in the returned sequence.

    Returns:
        float64 array of shape (seq_len, N_SEQUENCE_FEATURES), or None when
        fewer than ``seq_len + 20`` candles fall inside the lookback window.
    """
    # Extra lookback so EMA / RSI / Bollinger / rolling means are warmed up.
    warmup = 50
    start_ms = timestamp_ms - ((seq_len + warmup) * 60 * 1000)
    # NOTE(review): the candle with open_time == timestamp_ms is included. If
    # timestamp_ms is a trade entry time, that candle closes after the entry,
    # which may leak post-entry data into training — confirm intent.
    mask = (df["open_time"] >= start_ms) & (df["open_time"] <= timestamp_ms)
    # The indicator recursions below assume chronological order.
    window = df.loc[mask].sort_values("open_time", kind="mergesort")

    if len(window) < seq_len + 20:
        return None

    close = window["close"].to_numpy(dtype=float)
    high = window["high"].to_numpy(dtype=float)
    low = window["low"].to_numpy(dtype=float)
    opn = window["open"].to_numpy(dtype=float)
    volume = window["volume"].to_numpy(dtype=float)
    quote_vol = window["quote_volume"].to_numpy(dtype=float)
    num_trades = window["num_trades"].to_numpy(dtype=float)
    taker_buy_qvol = window["taker_buy_quote_vol"].to_numpy(dtype=float)
    n = len(window)

    # Indicators over the whole window (warm-up rows included).
    ema7 = _ema(close, 2 / 8)
    ema21 = _ema(close, 2 / 22)
    rsi = _rsi(close, 14)
    bb_mid, bb_std = _trailing_mean_std(close, 20)
    volume_ma, _ = _trailing_mean_std(volume, 20)
    quote_vol_ma, _ = _trailing_mean_std(quote_vol, 20)
    trades_ma, _ = _trailing_mean_std(num_trades, 20)

    # Keep the last seq_len candles. Since n >= seq_len + 20, every kept row
    # has its previous close inside the window.
    tail = slice(n - seq_len, n)
    c, h, l, o = close[tail], high[tail], low[tail], opn[tail]
    prev_close = close[n - seq_len - 1: n - 1]

    # Compare ranges against 0 rather than an absolute epsilon: micro-priced
    # tokens have genuine candle ranges far below 1e-8.
    hl_range = np.maximum(h - l, 0.0)
    has_range = hl_range > 0
    has_close = c > 0

    bb_upper = bb_mid[tail] + 2 * bb_std[tail]
    bb_lower = bb_mid[tail] - 2 * bb_std[tail]
    bb_range = bb_upper - bb_lower
    # Relative threshold: float noise in the mean of a flat series gives a tiny
    # non-zero std that must not be mistaken for a real band.
    bb_valid = bb_range > 1e-8 * np.abs(bb_mid[tail])

    features = {
        "return": (_safe_div(c, prev_close, prev_close > 0, 1.0) - 1.0) * 100,
        "high_low_range": _safe_div(hl_range, c, has_close, 0.0) * 100,
        "close_position": _safe_div(c - l, hl_range, has_range, 0.5),
        "volume_norm": volume[tail] / np.maximum(volume_ma[tail], 1e-8),
        "quote_volume_norm": quote_vol[tail] / np.maximum(quote_vol_ma[tail], 1e-8),
        "buy_pressure": taker_buy_qvol[tail] / np.maximum(quote_vol[tail], 1e-8),
        "num_trades_norm": num_trades[tail] / np.maximum(trades_ma[tail], 1e-8),
        "body_ratio": _safe_div(c - o, hl_range, has_range, 0.0),
        "upper_wick": _safe_div(h - np.maximum(c, o), hl_range, has_range, 0.0),
        "lower_wick": _safe_div(np.minimum(c, o) - l, hl_range, has_range, 0.0),
        "ema7_dist": _safe_div(c - ema7[tail], c, has_close, 0.0) * 100,
        "ema21_dist": _safe_div(c - ema21[tail], c, has_close, 0.0) * 100,
        "rsi_norm": rsi[tail] / 100,
        "bb_position": _safe_div(c - bb_lower, bb_range, bb_valid, 0.5),
    }
    # Column order is driven by SEQUENCE_FEATURES so the two can never drift.
    return np.column_stack([features[name] for name in SEQUENCE_FEATURES])


def _load_klines(symbol: str, klines_dir: str, cache: dict) -> 'Optional[pd.DataFrame]':
    """Return the klines DataFrame for ``symbol``, or None if no parquet file exists.

    Looks for ``<klines_dir>/<symbol>.parquet`` and, failing that, the same pair
    with the other stablecoin quote (USDC <-> USDT, suffix only). Results,
    including misses, are memoised in ``cache`` so each symbol is probed once.
    """
    if symbol in cache:
        return cache[symbol]

    import pandas as pd
    from pathlib import Path

    base = Path(klines_dir)
    candidates = [base / f"{symbol}.parquet"]
    for quote, other in (("USDC", "USDT"), ("USDT", "USDC")):
        if symbol.endswith(quote):
            candidates.append(base / f"{symbol[: -len(quote)]}{other}.parquet")
            break

    path = next((p for p in candidates if p.exists()), None)
    cache[symbol] = None if path is None else pd.read_parquet(path)
    return cache[symbol]


def _trade_to_sample(
    trade: dict,
    klines_dir: str,
    seq_len: int,
    klines_cache: dict,
) -> Optional[tuple[np.ndarray, int, float]]:
    """Turn one trade record into (sequence, surge_index, label), or None if unusable."""
    import pandas as pd

    symbol = trade.get("symbol", "")
    entry_time = trade.get("entry_time", "")
    if not symbol or not entry_time:
        return None

    df = _load_klines(symbol, klines_dir, klines_cache)
    if df is None:
        return None

    try:
        # NOTE(review): assumes entry_time is UTC or timezone-aware — confirm.
        ts_ms = int(pd.Timestamp(entry_time).timestamp() * 1000)
    except (TypeError, ValueError, OverflowError):
        return None

    # A missing/garbage/NaN PnL gives no trustworthy label: skip the trade.
    try:
        pnl = float(trade.get("pnl_usdt", 0))
    except (TypeError, ValueError):
        return None
    if not np.isfinite(pnl):
        return None

    seq = build_sequence_features(df, ts_ms, seq_len)
    if seq is None:
        return None

    unknown = SURGE_TYPE_MAP["UNKNOWN"]
    surge_idx = SURGE_TYPE_MAP.get(trade.get("surge_type", "UNKNOWN"), unknown)
    label = 1.0 if pnl > 0 else 0.0
    return seq, surge_idx, label


def build_sequence_dataset(
    trades: list[dict],
    klines_dir: str,
    seq_len: int = 60,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Build the sequential dataset for the LSTM model.

    Each trade needs ``symbol``, ``entry_time`` and ``pnl_usdt``; a missing or
    unrecognised ``surge_type`` maps to UNKNOWN. Trades that cannot be turned
    into a sample (missing fields, no klines file, unparsable time or PnL, not
    enough history) are skipped and counted.

    Args:
        trades: trade records.
        klines_dir: directory with one ``<SYMBOL>.parquet`` klines file per pair.
        seq_len: sequence length passed to build_sequence_features().

    Returns:
        X: (n_samples, seq_len, N_SEQUENCE_FEATURES) float32
        surge_types: (n_samples,) int64 surge-type indices
        y: (n_samples,) float32 target, 1.0 when pnl_usdt > 0 else 0.0
        X keeps its 3-D shape even when no trade is usable.
    """
    X_list = []
    surge_list = []
    y_list = []
    skipped = 0
    klines_cache: dict = {}
    total = len(trades)

    for i, trade in enumerate(trades, start=1):
        sample = _trade_to_sample(trade, klines_dir, seq_len, klines_cache)
        if sample is None:
            skipped += 1
        else:
            seq, surge_idx, label = sample
            X_list.append(seq)
            surge_list.append(surge_idx)
            y_list.append(label)

        # Report progress on every 100th trade, whether or not it was kept.
        if i % 100 == 0:
            print(f"  Processed {i}/{total} trades ({skipped} skipped)")

    print(f"  Done: {len(X_list)} sequences, {skipped} skipped")

    if X_list:
        X = np.array(X_list, dtype=np.float32)
    else:
        X = np.empty((0, seq_len, N_SEQUENCE_FEATURES), dtype=np.float32)

    return (
        X,
        np.array(surge_list, dtype=np.int64),
        np.array(y_list, dtype=np.float32),
    )
