#!/usr/bin/env python3
"""
═══════════════════════════════════════════════════════════════════════════════
 LSTM REVERSAL PREDICTOR - Détection avancée des retournements de marché
 Optimisé pour GPU RTX 5060 Ti / 5070 (16 GB VRAM)
═══════════════════════════════════════════════════════════════════════════════

Architecture:
  - BiLSTM (bidirectionnel) 3 couches × 256 hidden → capture contexte passé ET futur
  - Multi-Head Self-Attention (4 têtes) → focus sur les bougies clés du retournement
  - Feature Engineering spécialisé retournement (30 features)
  - Double sortie: classification (4 classes) + régression (probabilité retournement)

Classes de prédiction:
  0 = NEUTRAL       — Pas de retournement imminent
  1 = REVERSAL_UP   — Retournement haussier détecté (achat)
  2 = REVERSAL_DOWN — Retournement baissier détecté (vente)
  3 = CONTINUATION  — Tendance continue (pas de retournement)

Entraînement:
  - Online learning: se ré-entraîne toutes les 30 min sur données récentes
  - Transfer learning: pré-entraîné sur historique, fine-tuné en temps réel
  - Data augmentation: bruit gaussien + time-warping pour robustesse

Auteur: IA Trading Bot
Date: 01/03/2026
"""

import os
import json
import time
import logging
import threading
import numpy as np
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from collections import deque

logger = logging.getLogger("LSTMReversal")

# ═══════════════════════════════════════════════════════════════════════════════
# PYTORCH / GPU SETUP
# ═══════════════════════════════════════════════════════════════════════════════
# Populated by the guarded import below; the module degrades gracefully to a
# no-ML mode when PyTorch is not installed (predict() then returns neutral).
TORCH_AVAILABLE = False
DEVICE = "cpu"  # becomes "cuda" only when a usable CUDA device is detected
torch = None
nn = None
F = None

try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils.data import DataLoader, TensorDataset
    # 🔵 FIX 25/03: Cap CPU threads (3-layer BiLSTM launched in 20 workers → CPU saturation)
    torch.set_num_threads(2)
    try:
        torch.set_num_interop_threads(1)
    except RuntimeError:
        pass  # May already have been set by another module imported earlier

    if torch.cuda.is_available() and torch.cuda.device_count() > 0:
        try:
            DEVICE = "cuda"
            gpu_name = torch.cuda.get_device_name(0)
            vram_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
            logger.info(f"✅ LSTM Reversal: GPU {gpu_name} ({vram_gb:.0f} GB VRAM)")
        except (AssertionError, RuntimeError):
            # CUDA reported as available but the device is unusable
            # (driver/arch mismatch) → fall back to CPU
            DEVICE = "cpu"
            logger.info("ℹ️  LSTM Reversal: GPU non accessible, mode CPU")
        TORCH_AVAILABLE = True
    else:
        TORCH_AVAILABLE = True
        logger.info("ℹ️  LSTM Reversal: CPU mode (pas de GPU CUDA)")
except ImportError:
    logger.info("ℹ️  LSTM Reversal: PyTorch non disponible")

# ═══════════════════════════════════════════════════════════════════════════════
# CONFIGURATION
# ═══════════════════════════════════════════════════════════════════════════════
MODEL_DIR = Path(__file__).parent / "models"
MODEL_PATH = MODEL_DIR / "lstm_reversal.pt"
STATS_PATH = Path(__file__).parent / "lstm_reversal_stats.json"

# Architecture (tuned for RTX 5060 Ti / 5070 — 16 GB VRAM)
SEQUENCE_LENGTH = 60           # 60 candles of context (5h on the 5-min timeframe)
NUM_FEATURES = 30              # 30 reversal-specialised features
HIDDEN_SIZE = 256              # 256 units per LSTM layer
NUM_LAYERS = 3                 # 3 BiLSTM layers
NUM_HEADS = 4                  # 4 attention heads
DROPOUT = 0.25                 # 25% dropout
NUM_CLASSES = 4                # NEUTRAL, REVERSAL_UP, REVERSAL_DOWN, CONTINUATION

# Training
LEARNING_RATE = 1e-3
BATCH_SIZE = 64
ONLINE_EPOCHS = 3              # Epochs for online fine-tuning (fast)
PRETRAIN_EPOCHS = 30           # Epochs for pre-training
RETRAIN_INTERVAL_MIN = 30      # Retrain every 30 min
MIN_SAMPLES_TRAIN = 200        # Minimum number of samples before training
REVERSAL_LOOKFORWARD = 8       # Look 8 candles ahead to detect a reversal (40 min)
REVERSAL_THRESHOLD_PCT = 0.8   # ±0.8% move = significant reversal

# Decision thresholds
# 🔧 FIX 01/03 v7: Confidence 0.55→0.65 — filter false signals (young model)
REVERSAL_CONFIDENCE_MIN = 0.65   # Minimum confidence to flag a reversal (was 0.55)
HIGH_CONFIDENCE_THRESHOLD = 0.75 # High confidence


# ═══════════════════════════════════════════════════════════════════════════════
# MODÈLE: BiLSTM + Multi-Head Attention
# ═══════════════════════════════════════════════════════════════════════════════
if TORCH_AVAILABLE:

    class MultiHeadAttention(nn.Module):
        """Scaled dot-product self-attention with several parallel heads.

        The hidden dimension is split across `num_heads` heads so the model
        can focus on different key candles of the sequence at once.
        """

        def __init__(self, hidden_dim: int, num_heads: int = 4):
            super().__init__()
            self.num_heads = num_heads
            self.head_dim = hidden_dim // num_heads
            assert hidden_dim % num_heads == 0, "hidden_dim doit être divisible par num_heads"

            # One full-width projection per role; heads are carved out in forward().
            self.query = nn.Linear(hidden_dim, hidden_dim)
            self.key = nn.Linear(hidden_dim, hidden_dim)
            self.value = nn.Linear(hidden_dim, hidden_dim)
            self.out_proj = nn.Linear(hidden_dim, hidden_dim)
            self.scale = self.head_dim ** 0.5

        def forward(self, x: 'torch.Tensor') -> 'torch.Tensor':
            """Self-attention over the sequence axis; x is (batch, seq_len, hidden_dim)."""
            batch, steps, dim = x.shape

            def split_heads(tensor):
                # (batch, seq, dim) -> (batch, heads, seq, head_dim)
                return tensor.view(batch, steps, self.num_heads, self.head_dim).transpose(1, 2)

            q = split_heads(self.query(x))
            k = split_heads(self.key(x))
            v = split_heads(self.value(x))

            scores = (q @ k.transpose(-2, -1)) / self.scale
            weights = F.softmax(scores, dim=-1)
            merged = (weights @ v).transpose(1, 2).contiguous().view(batch, steps, dim)
            return self.out_proj(merged)


    class ReversalLSTM(nn.Module):
        """
        BiLSTM + attention network for reversal detection.

        Input:  (batch, 60, 30) — 60 timesteps × 30 features.
        Output: a tuple of
          - class_logits  (batch, 4): logits over NEUTRAL / REVERSAL_UP /
            REVERSAL_DOWN / CONTINUATION
          - reversal_prob (batch, 1): raw reversal probability in [0, 1]
        """

        def __init__(self,
                     input_size: int = NUM_FEATURES,
                     hidden_size: int = HIDDEN_SIZE,
                     num_layers: int = NUM_LAYERS,
                     num_heads: int = NUM_HEADS,
                     num_classes: int = NUM_CLASSES,
                     dropout: float = DROPOUT):
            super().__init__()

            # Input projection: compress the raw features (30 → 64) to cut noise.
            self.input_proj = nn.Sequential(
                nn.Linear(input_size, 64),
                nn.LayerNorm(64),
                nn.GELU(),
                nn.Dropout(dropout * 0.5),
            )

            # Bidirectional LSTM → outputs are hidden_size * 2 wide.
            self.lstm = nn.LSTM(
                input_size=64,
                hidden_size=hidden_size,
                num_layers=num_layers,
                batch_first=True,
                bidirectional=True,
                dropout=dropout,
            )

            # Normalisation after the LSTM stack.
            self.ln_lstm = nn.LayerNorm(hidden_size * 2)

            # Self-attention over the LSTM outputs, with its own norm.
            self.attention = MultiHeadAttention(hidden_size * 2, num_heads)
            self.ln_attn = nn.LayerNorm(hidden_size * 2)

            # Classification head (4 classes).
            self.classifier = nn.Sequential(
                nn.Linear(hidden_size * 2, 128),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(128, 64),
                nn.GELU(),
                nn.Dropout(dropout * 0.5),
                nn.Linear(64, num_classes),
            )

            # Regression head: reversal probability squashed to [0, 1].
            self.reversal_head = nn.Sequential(
                nn.Linear(hidden_size * 2, 64),
                nn.GELU(),
                nn.Dropout(dropout * 0.5),
                nn.Linear(64, 1),
                nn.Sigmoid(),
            )

        def forward(self, x: 'torch.Tensor') -> Tuple:
            """x: (batch, seq_len, num_features) → (class_logits, reversal_prob)."""
            projected = self.input_proj(x)                   # (B, T, 64)
            sequence, _ = self.lstm(projected)               # (B, T, hidden*2)
            sequence = self.ln_lstm(sequence)

            # Residual attention block.
            attended = self.ln_attn(sequence + self.attention(sequence))

            # Pooling: final timestep combined with the sequence mean (more robust).
            pooled = attended[:, -1, :] + attended.mean(dim=1)

            class_logits = self.classifier(pooled)           # (B, 4)
            reversal_prob = self.reversal_head(pooled)       # (B, 1)
            return class_logits, reversal_prob

    # Allow-list our classes for PyTorch ≥ 2.6 safe loading (torch.load weights_only)
    try:
        torch.serialization.add_safe_globals([ReversalLSTM, MultiHeadAttention])
    except AttributeError:
        pass  # Older PyTorch without add_safe_globals


# ═══════════════════════════════════════════════════════════════════════════════
# FEATURE ENGINEERING SPÉCIALISÉ RETOURNEMENT
# ═══════════════════════════════════════════════════════════════════════════════

def _ema_array(data: np.ndarray, period: int) -> np.ndarray:
    """EMA sur tout l'array"""
    out = np.zeros_like(data, dtype=np.float64)
    if len(data) < period:
        return out
    out[period - 1] = np.mean(data[:period])
    k = 2.0 / (period + 1)
    for i in range(period, len(data)):
        out[i] = data[i] * k + out[i - 1] * (1 - k)
    out[:period - 1] = out[period - 1]
    return out


def extract_reversal_features(prices: np.ndarray,
                               volumes: np.ndarray = None,
                               highs: np.ndarray = None,
                               lows: np.ndarray = None,
                               seq_len: int = SEQUENCE_LENGTH) -> Optional[np.ndarray]:
    """
    Extract 30 reversal-oriented features for a sequence of candles.

    Features 0-19:  the 20 standard features (kept for compatibility).
    Features 20-29: features SPECIALISED for reversal detection.

    Args:
        prices:  close prices, oldest first (last element = most recent candle).
        volumes: optional volumes aligned with `prices`; defaults to ones.
        highs:   optional highs aligned with `prices`; defaults to `prices`.
        lows:    optional lows aligned with `prices`; defaults to `prices`.
        seq_len: number of timesteps in the returned sequence.

    Returns:
        ndarray of shape (seq_len, 30), or None when there is not enough data.
    """
    if len(prices) < seq_len + 20:  # Need a warm-up margin for the indicators
        return None

    p = prices[-(seq_len + 20):]     # Keep a 20-candle margin upstream
    v = volumes[-(seq_len + 20):] if volumes is not None and len(volumes) >= seq_len + 20 else np.ones(seq_len + 20)
    h = highs[-(seq_len + 20):] if highs is not None and len(highs) >= seq_len + 20 else p.copy()
    lo = lows[-(seq_len + 20):] if lows is not None and len(lows) >= seq_len + 20 else p.copy()

    # Normalisation statistics over the full window
    mean_p = np.mean(p)
    std_p = np.std(p) + 1e-10

    # EMAs
    ema9 = _ema_array(p, 9)
    ema21 = _ema_array(p, 21)
    ema12 = _ema_array(p, 12)
    ema26 = _ema_array(p, 26)

    # Bollinger Bands (20-period mean ± 2 std)
    bb_mid = np.zeros_like(p)
    bb_upper = np.zeros_like(p)
    bb_lower = np.zeros_like(p)
    for i in range(20, len(p)):
        w = p[i - 20:i]
        bb_mid[i] = np.mean(w)
        s = np.std(w)
        bb_upper[i] = bb_mid[i] + 2 * s
        bb_lower[i] = bb_mid[i] - 2 * s

    # ATR (true range, smoothed with a 14-period EMA)
    atr = np.zeros_like(p)
    for i in range(1, len(p)):
        tr = max(h[i] - lo[i], abs(h[i] - p[i - 1]), abs(lo[i] - p[i - 1]))
        atr[i] = tr
    atr_smooth = _ema_array(atr, 14)

    # Trim to the target window (the first 20 candles serve as warm-up only)
    offset = 20
    features = np.zeros((seq_len, 30), dtype=np.float32)

    for t in range(seq_len):
        i = t + offset  # index into the full array (always >= 20)

        # ─── Features 0-19: standard set (same as train_ai_model.py) ───
        # F0: Normalised price (z-score over the window)
        features[t, 0] = (p[i] - mean_p) / std_p

        # F1: 5-period momentum (%, clipped ±5, scaled to [-1, 1])
        features[t, 1] = (p[i] - p[i - 5]) / (p[i - 5] + 1e-10) * 100 if i >= 5 else 0
        features[t, 1] = np.clip(features[t, 1], -5, 5) / 5

        # F2: 10-period momentum
        features[t, 2] = (p[i] - p[i - 10]) / (p[i - 10] + 1e-10) * 100 if i >= 10 else 0
        features[t, 2] = np.clip(features[t, 2], -5, 5) / 5

        # F3: 10-period volatility (std relative to the window std)
        features[t, 3] = np.std(p[max(0, i - 10):i + 1]) / std_p if i >= 10 else 0

        # F4: Normalised RSI (14 periods), shifted to [-0.5, 0.5]
        if i >= 14:
            changes = np.diff(p[i - 14:i + 1])
            gains = np.mean(np.where(changes > 0, changes, 0))
            losses = np.mean(np.where(changes < 0, -changes, 0))
            if losses > 0:
                rs = gains / losses
                features[t, 4] = (100 - 100 / (1 + rs)) / 100 - 0.5
            else:
                features[t, 4] = 0.5  # No losses at all → RSI pegged at max

        # F5: Normalised volume (z-score over 20 candles, clipped ±3)
        vol_mean = np.mean(v[max(0, i - 20):i + 1])
        vol_std = np.std(v[max(0, i - 20):i + 1]) + 1e-10
        features[t, 5] = np.clip((v[i] - vol_mean) / vol_std, -3, 3) / 3

        # F6-F7: EMA9 / EMA21, window-normalised
        features[t, 6] = (ema9[i] - mean_p) / std_p
        features[t, 7] = (ema21[i] - mean_p) / std_p

        # F8: EMA spread (%)
        if ema21[i] > 0:
            features[t, 8] = np.clip((ema9[i] - ema21[i]) / ema21[i] * 100, -5, 5) / 5

        # F9: EMA9 slope (over 5 periods)
        if i >= 5 and ema9[i - 5] > 0:
            features[t, 9] = np.clip((ema9[i] - ema9[i - 5]) / ema9[i - 5] * 100, -3, 3) / 3

        # F10-F11: Bollinger upper / lower bands, window-normalised
        features[t, 10] = (bb_upper[i] - mean_p) / std_p
        features[t, 11] = (bb_lower[i] - mean_p) / std_p

        # F12: Bollinger bandwidth, normalised
        if bb_mid[i] > 0:
            bw = (bb_upper[i] - bb_lower[i]) / bb_mid[i] * 100
            features[t, 12] = np.clip(bw, 0, 20) / 20

        # F13: Position inside the Bollinger band (centred on 0)
        bb_range = bb_upper[i] - bb_lower[i]
        if bb_range > 0:
            features[t, 13] = (p[i] - bb_lower[i]) / bb_range - 0.5

        # F14: Normalised MACD (EMA12 - EMA26)
        if mean_p > 0:
            features[t, 14] = np.clip((ema12[i] - ema26[i]) / mean_p * 100, -3, 3) / 3

        # F15: Normalised ATR
        if mean_p > 0:
            features[t, 15] = np.clip(atr_smooth[i] / mean_p * 100, 0, 5) / 5

        # F16: 3-period momentum (short term)
        if i >= 3:
            features[t, 16] = (p[i] - p[i - 3]) / (p[i - 3] + 1e-10) * 100
            features[t, 16] = np.clip(features[t, 16], -3, 3) / 3

        # F17: Volume ratio vs its 10-period average
        vol_ma10 = np.mean(v[max(0, i - 10):i]) if i >= 10 else (np.mean(v[:i + 1]) + 1e-10)
        if vol_ma10 > 0:
            features[t, 17] = np.clip(v[i] / vol_ma10 - 1, -2, 5) / 5

        # F18: High-low range relative to the price
        if p[i] > 0:
            features[t, 18] = np.clip((h[i] - lo[i]) / p[i] * 100, 0, 10) / 10

        # F19: Distance between price and EMA21
        if ema21[i] > 0:
            features[t, 19] = np.clip((p[i] - ema21[i]) / ema21[i] * 100, -5, 5) / 5

        # ─── Features 20-29: REVERSAL-SPECIALISED ───

        # F20: Momentum acceleration (derivative of the 3-period momentum)
        if i >= 6:
            mom3_now = (p[i] - p[i - 3]) / (p[i - 3] + 1e-10) * 100
            mom3_prev = (p[i - 3] - p[i - 6]) / (p[i - 6] + 1e-10) * 100
            features[t, 20] = np.clip(mom3_now - mom3_prev, -3, 3) / 3

        # F21: Momentum jerk (2nd derivative — acceleration of the acceleration)
        # NOTE: reuses mom3_now/mom3_prev from the F20 block; safe only because
        # i = t + offset >= 20, so the `i >= 6` branch above always ran.
        if i >= 9:
            mom3_pp = (p[i - 6] - p[i - 9]) / (p[i - 9] + 1e-10) * 100
            accel_now = mom3_now - mom3_prev if i >= 6 else 0
            accel_prev = mom3_prev - mom3_pp if i >= 9 else 0
            features[t, 21] = np.clip(accel_now - accel_prev, -3, 3) / 3

        # F22: RSI divergence (price makes a lower low while RSI makes a higher low)
        if i >= 20:
            price_min_recent = np.min(p[i - 10:i + 1])
            price_min_prev = np.min(p[i - 20:i - 10])
            rsi_at_recent = features[t, 4]
            # Approximation: lower price low but higher RSI → bullish divergence
            if price_min_recent < price_min_prev and rsi_at_recent > -0.2:
                features[t, 22] = 0.5  # Bullish divergence signal
            elif price_min_recent > price_min_prev and rsi_at_recent < -0.2:
                features[t, 22] = -0.5  # Bearish divergence
            # Refine with the last 5 candles
            if i >= 5:
                rsi_slope = features[t, 4] - features[max(0, t - 5), 4] if t >= 5 else 0
                mom_slope = features[t, 1] - features[max(0, t - 5), 1] if t >= 5 else 0
                if rsi_slope > 0 and mom_slope < 0:
                    features[t, 22] = max(features[t, 22], 0.3)

        # F23: Volume anomaly (spike relative to the 20-candle max)
        if i >= 20:
            vol_max_20 = np.max(v[i - 20:i]) + 1e-10
            features[t, 23] = np.clip(v[i] / vol_max_20, 0, 3) / 3

        # F24: Position within the 20-candle range (normalised)
        if i >= 20:
            low_20 = np.min(p[i - 20:i + 1])
            high_20 = np.max(p[i - 20:i + 1])
            range_20 = high_20 - low_20 + 1e-10
            features[t, 24] = (p[i] - low_20) / range_20 - 0.5  # -0.5 = at the low, +0.5 = at the high

        # F25: Rate of change of the EMA21 slope (underlying-trend reversal)
        if i >= 15 and ema21[i - 5] > 0 and ema21[i - 10] > 0:
            slope_now = (ema21[i] - ema21[i - 5]) / ema21[i - 5] * 100
            slope_prev = (ema21[i - 5] - ema21[i - 10]) / ema21[i - 10] * 100
            features[t, 25] = np.clip(slope_now - slope_prev, -2, 2) / 2

        # F26: Candle body/wick ratio (doji detection)
        body = abs(p[i] - p[max(0, i - 1)])  # Approximation: open ≈ previous close
        wick = h[i] - lo[i] + 1e-10
        features[t, 26] = np.clip(body / wick, 0, 1)

        # F27: EMA convergence speed (how fast EMA9/EMA21 close in on each other)
        if i >= 5 and ema21[i] > 0 and ema21[i - 5] > 0:
            gap_now = abs(ema9[i] - ema21[i]) / ema21[i] * 100
            gap_prev = abs(ema9[i - 5] - ema21[i - 5]) / ema21[i - 5] * 100
            # Positive = converging (gap shrinking), negative = diverging
            features[t, 27] = np.clip(gap_prev - gap_now, -2, 2) / 2

        # F28: Smoothed price rate-of-change (7-candle momentum)
        if i >= 7:
            roc7 = (p[i] - p[i - 7]) / (p[i - 7] + 1e-10) * 100
            features[t, 28] = np.clip(roc7, -5, 5) / 5

        # F29: Stochastic %K (14 periods) — complements the RSI
        if i >= 14:
            low_14 = np.min(lo[i - 14:i + 1])
            high_14 = np.max(h[i - 14:i + 1])
            range_14 = high_14 - low_14 + 1e-10
            features[t, 29] = (p[i] - low_14) / range_14 - 0.5

    return features


def create_reversal_labels(prices: np.ndarray,
                            lookforward: int = REVERSAL_LOOKFORWARD,
                            threshold_pct: float = REVERSAL_THRESHOLD_PCT) -> np.ndarray:
    """
    Build training labels for reversal detection.

    For each candle t, inspect the next `lookforward` candles:
      - an upward move of at least +threshold that clearly dominates the
        downside → REVERSAL_UP (1) if the previous 5 candles were falling,
        otherwise CONTINUATION (3)
      - a downward move of at least -threshold that clearly dominates the
        upside → REVERSAL_DOWN (2) if the previous 5 candles were rising,
        otherwise CONTINUATION (3)
      - anything else → NEUTRAL (0)

    Returns: int64 array of shape (n_samples,).
    """
    n = len(prices)
    labels = np.zeros(n, dtype=np.int64)

    for t in range(5, n - lookforward):
        window = prices[t + 1:t + lookforward + 1]
        price_now = prices[t]

        up_move = (np.max(window) - price_now) / price_now * 100
        down_move = (np.min(window) - price_now) / price_now * 100

        # Direction over the previous 5 candles
        past_move = (price_now - prices[t - 5]) / prices[t - 5] * 100

        # A move only counts when it beats the threshold AND dominates the
        # opposite move by a 1.3x margin (filters choppy both-ways candles).
        if up_move >= threshold_pct and abs(up_move) > abs(down_move) * 1.3:
            labels[t] = 1 if past_move < 0 else 3  # reversal up vs bullish continuation
        elif down_move <= -threshold_pct and abs(down_move) > abs(up_move) * 1.3:
            labels[t] = 2 if past_move > 0 else 3  # reversal down vs bearish continuation
        else:
            labels[t] = 0  # NEUTRAL

    return labels


# ═══════════════════════════════════════════════════════════════════════════════
# PRÉDICTEUR PRINCIPAL
# ═══════════════════════════════════════════════════════════════════════════════

class LSTMReversalPredictor:
    """
    Prédicteur de retournements basé sur BiLSTM + Attention.
    S'intègre dans le pipeline ai_predictor.py pour améliorer la détection creux/rebond.
    """

    def __init__(self):
        """Create the predictor, load any persisted model/stats and log readiness."""
        # Model and training machinery (populated by _init_model below).
        self.model: Optional['ReversalLSTM'] = None
        self.optimizer = None
        self.scheduler = None
        self.is_trained = False
        self.last_train_time: Optional[datetime] = None
        self.prediction_count = 0
        self.correct_predictions = 0

        # Rolling training buffers (feature blocks + labels), guarded by a lock.
        self._data_buffer: deque = deque(maxlen=5000)  # ~5000 samples max
        self._label_buffer: deque = deque(maxlen=5000)
        self._lock = threading.Lock()

        # Aggregate runtime statistics, persisted to STATS_PATH.
        self.stats = {
            'model_version': '1.0',
            'total_predictions': 0,
            'accuracy': 0.0,
            'reversal_up_detected': 0,
            'reversal_down_detected': 0,
            'false_reversals': 0,
            'last_train': None,
            'train_loss': 0.0,
            'train_accuracy': 0.0,
            'gpu_used': DEVICE == 'cuda',
        }

        self._init_model()
        self._load_stats()
        logger.info(f"✅ LSTMReversalPredictor initialisé (device={DEVICE})")

    def _init_model(self):
        """Create the reversal model, loading saved weights from disk when present.

        Also builds the AdamW optimizer and the plateau LR scheduler. Falls back
        to a fresh untrained model when loading fails, and returns early in
        degraded mode when PyTorch is unavailable.
        """
        if not TORCH_AVAILABLE:
            logger.warning("⚠️ PyTorch non disponible — mode dégradé")
            return

        # Try to load an existing checkpoint from disk
        MODEL_DIR.mkdir(parents=True, exist_ok=True)

        if MODEL_PATH.exists():
            try:
                try:
                    # Allow-list our classes for PyTorch >= 2.6 safe loading
                    torch.serialization.add_safe_globals([ReversalLSTM, MultiHeadAttention])
                except AttributeError:
                    pass

                # NOTE(review): weights_only=False unpickles arbitrary objects;
                # acceptable only because MODEL_PATH is written by this process —
                # never point it at an untrusted file.
                loaded = torch.load(str(MODEL_PATH), map_location=DEVICE, weights_only=False)
                if isinstance(loaded, dict) and 'model_state_dict' in loaded:
                    # Preferred checkpoint format: dict with state_dict + flags
                    self.model = ReversalLSTM().to(DEVICE)
                    self.model.load_state_dict(loaded['model_state_dict'])
                    self.is_trained = loaded.get('is_trained', True)
                    logger.info(f"✅ Modèle reversal chargé: {MODEL_PATH}")
                else:
                    # Legacy format: the whole pickled module.
                    # NOTE(review): this path does not call .to(DEVICE) — verify
                    # the legacy checkpoint was saved on a compatible device.
                    self.model = loaded
                    self.is_trained = True
            except Exception as e:
                logger.warning(f"⚠️ Impossible de charger {MODEL_PATH}: {e}")
                self.model = ReversalLSTM().to(DEVICE)
                self.is_trained = False
        else:
            self.model = ReversalLSTM().to(DEVICE)
            self.is_trained = False
            logger.info("🆕 Nouveau modèle reversal créé (non entraîné)")

        # Optimizer + LR decay on validation-loss plateau
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4
        )
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', factor=0.5, patience=5
        )

        # Parameter count (fp32 → 4 bytes per parameter)
        n_params = sum(p.numel() for p in self.model.parameters())
        logger.info(f"📐 Modèle reversal: {n_params:,} paramètres ({n_params * 4 / 1024**2:.1f} MB)")

    def _save_model(self):
        """Persist weights, training flag and stats to MODEL_PATH (best effort)."""
        if not self.model:
            return
        try:
            MODEL_DIR.mkdir(parents=True, exist_ok=True)
            checkpoint = {
                'model_state_dict': self.model.state_dict(),
                'is_trained': self.is_trained,
                'stats': self.stats,
                'timestamp': datetime.now().isoformat(),
            }
            torch.save(checkpoint, str(MODEL_PATH))
            logger.info(f"💾 Modèle reversal sauvegardé: {MODEL_PATH}")
        except Exception as e:
            logger.warning(f"⚠️ Erreur sauvegarde modèle: {e}")

    def _load_stats(self):
        """Merge previously saved stats from STATS_PATH into self.stats (best effort)."""
        if not STATS_PATH.exists():
            return
        try:
            with open(STATS_PATH, 'r') as f:
                self.stats.update(json.load(f))
        except Exception:
            pass  # Corrupt or unreadable stats are not fatal; keep the defaults

    def _save_stats(self):
        """Write self.stats to STATS_PATH as JSON (best effort, never raises)."""
        try:
            payload = json.dumps(self.stats, indent=2, default=str)
            with open(STATS_PATH, 'w') as f:
                f.write(payload)
        except Exception:
            pass

    # ─────────────────────────────────────────────────────────────────────
    # PRÉDICTION
    # ─────────────────────────────────────────────────────────────────────

    def predict(self,
                prices: np.ndarray,
                volumes: np.ndarray = None,
                highs: np.ndarray = None,
                lows: np.ndarray = None) -> Dict:
        """
        Predict whether a reversal is imminent.

        Returns a dict:
        {
            'reversal_class': int,          # 0=neutral, 1=reversal_up, 2=reversal_down, 3=continuation
            'reversal_label': str,          # 'NEUTRAL', 'REVERSAL_UP', 'REVERSAL_DOWN', 'CONTINUATION'
            'reversal_probability': float,  # 0-1 probability of a reversal
            'confidence': float,            # 0-100 confidence in the prediction
            'class_probabilities': list,    # [p_neutral, p_rev_up, p_rev_down, p_continuation]
            'is_reversal_signal': bool,     # True on a strong bullish-reversal signal
            'is_danger_signal': bool,       # True on a strong bearish-reversal signal
        }
        """
        neutral_result = {
            'reversal_class': 0,
            'reversal_label': 'NEUTRAL',
            'reversal_probability': 0.0,
            'confidence': 0.0,
            'class_probabilities': [1.0, 0.0, 0.0, 0.0],
            'is_reversal_signal': False,
            'is_danger_signal': False,
        }

        if not TORCH_AVAILABLE or not self.model:
            return neutral_result

        series = np.array(prices, dtype=np.float64)
        features = extract_reversal_features(series, volumes, highs, lows)
        if features is None:
            return neutral_result

        # GPU/CPU inference
        try:
            self.model.eval()
            with torch.no_grad():
                batch = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(DEVICE)
                logits, raw_rev = self.model(batch)

                probs = F.softmax(logits, dim=-1).cpu().numpy()[0]
                rev_probability = raw_rev.cpu().item()
                predicted_class = int(np.argmax(probs))

            labels_map = {0: 'NEUTRAL', 1: 'REVERSAL_UP', 2: 'REVERSAL_DOWN', 3: 'CONTINUATION'}
            confidence = float(probs[predicted_class]) * 100

            # Actionable signals: the class must win with enough softmax mass
            # AND the regression head must agree a reversal is plausible.
            is_reversal_signal = (
                predicted_class == 1
                and probs[1] >= REVERSAL_CONFIDENCE_MIN
                and rev_probability >= 0.4
            )
            is_danger_signal = (
                predicted_class == 2
                and probs[2] >= REVERSAL_CONFIDENCE_MIN
                and rev_probability >= 0.4
            )

            self.stats['total_predictions'] += 1
            if is_reversal_signal:
                self.stats['reversal_up_detected'] += 1
            if is_danger_signal:
                self.stats['reversal_down_detected'] += 1

            return {
                'reversal_class': int(predicted_class),
                'reversal_label': labels_map[predicted_class],
                'reversal_probability': round(float(rev_probability), 4),
                'confidence': round(float(confidence), 1),
                'class_probabilities': [round(float(x), 4) for x in probs],
                'is_reversal_signal': bool(is_reversal_signal),
                'is_danger_signal': bool(is_danger_signal),
            }

        except Exception as e:
            logger.warning(f"⚠️ Erreur prédiction reversal: {e}")
            return neutral_result

    # ─────────────────────────────────────────────────────────────────────
    # ENTRAÎNEMENT ONLINE
    # ─────────────────────────────────────────────────────────────────────

    def feed_data(self, prices: np.ndarray, volumes: np.ndarray = None,
                  highs: np.ndarray = None, lows: np.ndarray = None):
        """
        Push fresh market data into the training buffer.

        Called periodically by the bot with the latest candles; the history is
        sliced into overlapping sequences (stride 3) and each feature block is
        stored together with its reversal label.
        """
        required = SEQUENCE_LENGTH + REVERSAL_LOOKFORWARD + 20
        if not TORCH_AVAILABLE or len(prices) < required:
            return

        series = np.array(prices, dtype=np.float64)
        labels = create_reversal_labels(series)

        # Sliding windows: every 3rd end-index, leaving room for the lookforward
        for end in range(SEQUENCE_LENGTH + 20, len(series) - REVERSAL_LOOKFORWARD, 3):
            feats = extract_reversal_features(
                series[:end + 1],
                None if volumes is None else volumes[:end + 1],
                None if highs is None else highs[:end + 1],
                None if lows is None else lows[:end + 1],
            )
            if feats is None:
                continue
            with self._lock:
                self._data_buffer.append(feats)
                self._label_buffer.append(labels[end])

    def should_retrain(self) -> bool:
        """True when enough samples are buffered and the retrain interval elapsed."""
        if not TORCH_AVAILABLE or len(self._data_buffer) < MIN_SAMPLES_TRAIN:
            return False
        if self.last_train_time is None:
            return True  # Never trained yet → train as soon as data allows
        minutes_since = (datetime.now() - self.last_train_time).total_seconds() / 60
        return minutes_since >= RETRAIN_INTERVAL_MIN

    def train_online(self) -> Dict:
        """
        Retrain the model on the buffered (features, label) samples.

        Pipeline: snapshot the buffers under the lock → log class distribution
        → inverse-frequency class weights → shuffled 85/15 train/val split →
        gaussian-noise augmentation (train set only, applied after the split
        so validation stays clean) → joint loss (classification + 0.3 × reversal
        regression) → keep the best validation checkpoint → persist model/stats.

        Returns:
            Dict: {'status': 'success', epochs, samples, val_accuracy,
            val_loss, class_distribution, device} on success, or a short
            status dict ('unavailable' / 'insufficient_data') otherwise.
        """
        if not TORCH_AVAILABLE or not self.model:
            return {'status': 'unavailable'}

        # Copy the shared buffers under the lock so feeder threads can keep
        # appending while training runs on a consistent snapshot.
        with self._lock:
            if len(self._data_buffer) < MIN_SAMPLES_TRAIN:
                return {'status': 'insufficient_data', 'samples': len(self._data_buffer)}

            X = np.array(list(self._data_buffer), dtype=np.float32)
            y = np.array(list(self._label_buffer), dtype=np.int64)

        logger.info(f"🎓 Entraînement reversal: {len(X)} échantillons sur {DEVICE}")

        # Class distribution — logged so imbalance is visible in monitoring.
        unique, counts = np.unique(y, return_counts=True)
        dist = {int(u): int(c) for u, c in zip(unique, counts)}
        logger.info(f"   Classes: {dist}")

        # Class weighting: reversals are rare, so up-weight minority classes
        # with the standard inverse-frequency formula total / (C * count).
        class_weights = np.ones(NUM_CLASSES, dtype=np.float32)
        total = len(y)
        for cls_id in range(NUM_CLASSES):
            count = np.sum(y == cls_id)
            if count > 0:
                class_weights[cls_id] = total / (NUM_CLASSES * count)
            else:
                # Absent class: neutral weight (CrossEntropyLoss still needs
                # a weight entry for every class index).
                class_weights[cls_id] = 1.0
        weights_tensor = torch.tensor(class_weights).to(DEVICE)

        # Shuffled train/validation split (85/15).
        n_val = max(1, int(len(X) * 0.15))
        indices = np.random.permutation(len(X))
        X_train, X_val = X[indices[n_val:]], X[indices[:n_val]]
        y_train, y_val = y[indices[n_val:]], y[indices[:n_val]]

        # Data augmentation: append a gaussian-noise copy of the training set
        # (σ=0.01, same labels) — doubles the train set, val is untouched.
        noise = np.random.normal(0, 0.01, X_train.shape).astype(np.float32)
        X_train_aug = np.concatenate([X_train, X_train + noise], axis=0)
        y_train_aug = np.concatenate([y_train, y_train], axis=0)

        # Training DataLoader; validation is small enough for one full batch.
        train_ds = TensorDataset(
            torch.tensor(X_train_aug), torch.tensor(y_train_aug)
        )
        train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=False)

        X_val_t = torch.tensor(X_val).to(DEVICE)
        y_val_t = torch.tensor(y_val).to(DEVICE)

        # Loss criteria. NOTE(review): BCELoss requires rev_prob already in
        # [0,1] (i.e. a sigmoid model head) — confirm against the model class.
        class_criterion = nn.CrossEntropyLoss(weight=weights_tensor)
        rev_criterion = nn.BCELoss()

        # Fewer epochs once a trained model exists (online fine-tuning) than
        # for the initial pre-training pass.
        epochs = ONLINE_EPOCHS if self.is_trained else PRETRAIN_EPOCHS
        self.model.train()
        best_val_loss = float('inf')
        best_state = None

        for epoch in range(epochs):
            epoch_loss = 0.0
            n_batches = 0

            for batch_X, batch_y in train_loader:
                batch_X = batch_X.to(DEVICE)
                batch_y = batch_y.to(DEVICE)

                self.optimizer.zero_grad()
                class_logits, rev_prob = self.model(batch_X)

                # Classification loss (class-weighted).
                loss_cls = class_criterion(class_logits, batch_y)

                # Reversal regression target: 1 for either reversal class
                # (1 = up, 2 = down), 0 for neutral/continuation.
                rev_target = ((batch_y == 1) | (batch_y == 2)).float().unsqueeze(1)
                loss_rev = rev_criterion(rev_prob, rev_target)

                # Joint objective; classification dominates (0.3 aux weight).
                loss = loss_cls + 0.3 * loss_rev
                loss.backward()

                # Gradient clipping keeps the recurrent net stable.
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                self.optimizer.step()

                epoch_loss += loss.item()
                n_batches += 1

            avg_loss = epoch_loss / max(n_batches, 1)

            # Validation pass (classification loss only) + best-checkpoint
            # tracking; tensors are cloned so later epochs can't mutate them.
            self.model.eval()
            with torch.no_grad():
                val_logits, _ = self.model(X_val_t)
                val_loss = class_criterion(val_logits, y_val_t).item()
                val_preds = val_logits.argmax(dim=1)
                val_acc = (val_preds == y_val_t).float().mean().item()

                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    best_state = {k: v.clone() for k, v in self.model.state_dict().items()}
            self.model.train()

            # Log roughly 5 times over the run, plus always the last epoch.
            if (epoch + 1) % max(1, epochs // 5) == 0 or epoch == epochs - 1:
                logger.info(f"   Epoch {epoch + 1}/{epochs}: train_loss={avg_loss:.4f} val_loss={val_loss:.4f} val_acc={val_acc:.1%}")

            # Scheduler stepped with the metric — presumably ReduceLROnPlateau
            # (the only stock scheduler taking a metric); confirm at __init__.
            self.scheduler.step(val_loss)

        # Restore the best validation checkpoint before inference/saving.
        if best_state:
            self.model.load_state_dict(best_state)

        self.model.eval()
        self.is_trained = True
        self.last_train_time = datetime.now()

        # Final accuracy measured with the restored best weights.
        with torch.no_grad():
            val_logits, _ = self.model(X_val_t)
            val_preds = val_logits.argmax(dim=1)
            final_acc = (val_preds == y_val_t).float().mean().item()

        self.stats['train_loss'] = round(best_val_loss, 4)
        self.stats['train_accuracy'] = round(final_acc, 4)
        self.stats['last_train'] = datetime.now().isoformat()

        # Persist model weights and stats to disk.
        self._save_model()
        self._save_stats()

        result = {
            'status': 'success',
            'epochs': epochs,
            'samples': len(X),
            'val_accuracy': round(final_acc, 4),
            'val_loss': round(best_val_loss, 4),
            'class_distribution': dist,
            'device': DEVICE
        }
        logger.info(f"✅ Entraînement reversal terminé: acc={final_acc:.1%} loss={best_val_loss:.4f}")
        return result

    def get_stats(self) -> Dict:
        """Return a stats snapshot merged with live runtime state flags."""
        snapshot = dict(self.stats)
        snapshot.update(
            is_trained=self.is_trained,
            buffer_size=len(self._data_buffer),
            model_loaded=self.model is not None,
            device=DEVICE,
        )
        return snapshot


# ═══════════════════════════════════════════════════════════════════════════════
# SINGLETON
# ═══════════════════════════════════════════════════════════════════════════════
_reversal_predictor: Optional[LSTMReversalPredictor] = None
# Guards lazy singleton creation — this module is used from multiple worker
# threads (see thread-limit fix in the header), so the unguarded
# check-then-create was a race that could build two predictors.
_predictor_singleton_lock = threading.Lock()

def get_reversal_predictor() -> LSTMReversalPredictor:
    """Retourne le singleton du prédicteur de retournements.

    Thread-safe via double-checked locking: the fast path is lock-free once
    the instance exists; only the first concurrent callers serialize on the
    lock, and exactly one of them constructs the predictor.
    """
    global _reversal_predictor
    if _reversal_predictor is None:
        with _predictor_singleton_lock:
            # Re-check under the lock: another thread may have won the race.
            if _reversal_predictor is None:
                _reversal_predictor = LSTMReversalPredictor()
    return _reversal_predictor


# ═══════════════════════════════════════════════════════════════════════════════
# MAIN (test standalone)
# ═══════════════════════════════════════════════════════════════════════════════
if __name__ == "__main__":
    # Standalone smoke test: load the singleton predictor, run one prediction
    # on synthetic candles, fill the training buffer, then retrain if due.
    logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(name)s] %(message)s')

    predictor = get_reversal_predictor()
    print(f"\n{'='*60}")
    print(f"LSTM Reversal Predictor")
    print(f"{'='*60}")
    print(f"Device: {DEVICE}")
    print(f"Model loaded: {predictor.model is not None}")
    print(f"Trained: {predictor.is_trained}")
    print(f"Stats: {json.dumps(predictor.get_stats(), indent=2, default=str)}")

    # Deterministic synthetic random-walk candles for the prediction test.
    np.random.seed(42)
    n_candles = 200
    prices = 100 + np.cumsum(np.random.randn(n_candles) * 0.5)
    volumes = np.abs(np.random.randn(n_candles) * 1000 + 5000)

    print(f"\nTest prédiction ({n_candles} bougies)...")
    first_prediction = predictor.predict(prices, volumes)
    print(f"Résultat: {json.dumps(first_prediction, indent=2)}")

    # Feed several independent chunks so the buffer may reach train size.
    print(f"\nAlimentation buffer pour entraînement...")
    for _chunk_idx in range(5):
        chunk_prices = 100 + np.cumsum(np.random.randn(300) * 0.5)
        chunk_volumes = np.abs(np.random.randn(300) * 1000 + 5000)
        predictor.feed_data(chunk_prices, chunk_volumes)
    print(f"Buffer: {len(predictor._data_buffer)} échantillons")

    if predictor.should_retrain():
        print(f"\nEntraînement online...")
        training_metrics = predictor.train_online()
        print(f"Résultat: {json.dumps(training_metrics, indent=2)}")

        # Same input as before — shows how the trained model's output moved.
        print(f"\nRe-prédiction après entraînement...")
        second_prediction = predictor.predict(prices, volumes)
        print(f"Résultat: {json.dumps(second_prediction, indent=2)}")