#!/usr/bin/env python3
"""
🎓 Script d'entraînement du modèle IA
=====================================
Utilise les données historiques de Binance pour entraîner le modèle LSTM
sur GPU (RTX 5060 Ti) pour prédire les mouvements de prix.

Usage:
    python train_ai_model.py                    # Entraînement standard
    python train_ai_model.py --epochs 100       # Plus d'epochs
    python train_ai_model.py --symbols 20       # Limiter le nombre de symboles
"""

import os
# Cap BLAS/MKL/OpenBLAS threads BEFORE importing numpy (CPU-only server)
os.environ.setdefault('OMP_NUM_THREADS', '2')
os.environ.setdefault('MKL_NUM_THREADS', '2')
os.environ.setdefault('OPENBLAS_NUM_THREADS', '2')
import sys
import json
import time
import gc
import argparse
import requests
import numpy as np
from datetime import datetime, timedelta
from pathlib import Path
from typing import List, Dict, Tuple, Optional

# Logging configuration
import logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s | %(levelname)s | %(message)s',
    datefmt='%H:%M:%S'
)
logger = logging.getLogger("AITrainer")

# Check PyTorch availability
try:
    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset
    
    # Check whether CUDA is available AND actually usable
    DEVICE = "cpu"  # Default to CPU
    GPU_NAME = None
    
    if torch.cuda.is_available():
        GPU_NAME = torch.cuda.get_device_name(0)
        # Test whether the GPU is actually usable
        try:
            # Simple allocation to verify compatibility
            test_tensor = torch.zeros(1).cuda()
            del test_tensor
            DEVICE = "cuda"
            logger.info(f"✅ GPU détecté et compatible: {GPU_NAME}")
        except RuntimeError as e:
            if "no kernel image" in str(e) or "not compatible" in str(e):
                logger.warning(f"⚠️ GPU {GPU_NAME} détecté mais non compatible avec PyTorch")
                logger.warning("   Architecture trop récente - Utilisation du CPU")
                DEVICE = "cpu"
            else:
                raise
    else:
        logger.info("ℹ️ Pas de GPU CUDA, utilisation du CPU")
    
    logger.info(f"🔧 Device sélectionné: {DEVICE}")
    # Limiter les threads PyTorch pour ne pas saturer le serveur CPU
    torch.set_num_threads(2)
    torch.set_num_interop_threads(1)
    logger.info("🔧 Threads PyTorch limités à 2 (économie CPU/RAM, pas de GPU)")

except ImportError:
    logger.error("❌ PyTorch non installé!")
    sys.exit(1)

# Import the model
from ai_predictor import get_ai_predictor, TORCH_AVAILABLE

# ═══════════════════════════════════════════════════════════════════════════════
# CONFIGURATION
# ═══════════════════════════════════════════════════════════════════════════════

BINANCE_API = "https://api.binance.com/api/v3"
WATCHLIST_FILE = Path(__file__).parent / "watchlist.json"
STATS_FILE = Path(__file__).parent / "ai_training_stats.json"
MODEL_PATH = Path(__file__).parent / "models" / "predictor.pt"

# Default training parameters (tuned for a CPU-only server)
DEFAULT_EPOCHS = 30       # Reduced 50→30 (early stopping handles quality)
DEFAULT_BATCH_SIZE = 64   # Increased 32→64 (better BLAS utilization on CPU)
DEFAULT_LOOKBACK = 50     # Number of candles per sample
DEFAULT_FUTURE = 12       # Number of candles used to label the outcome (12 x 5min = 1h)
GAIN_THRESHOLD = 1.0      # % gain required to label a sample as "up"
MAX_TOTAL_SAMPLES = 15000 # Memory cap: at most 15k samples (avoids swap saturation)
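# Horizon arithmetic for the default 5m feed: DEFAULT_FUTURE = 12 candles x 5 min
# = 1 hour; a sample is labeled "up" when the close 1h later is at least
# GAIN_THRESHOLD = 1.0% above the current close (see determine_label below).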


# ═══════════════════════════════════════════════════════════════════════════════
# UTILITY FUNCTIONS
# ═══════════════════════════════════════════════════════════════════════════════

def load_watchlist() -> List[str]:
    """Charge la liste des symboles"""
    if WATCHLIST_FILE.exists():
        with open(WATCHLIST_FILE, 'r') as f:
            data = json.load(f)
            return data.get('symbols', [])
    return ['BTCUSDT', 'ETHUSDT', 'SOLUSDT']
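

# Expected watchlist.json shape, inferred from the keys read above (an
# assumption — only the 'symbols' list is consumed here):
#   {"symbols": ["BTCUSDT", "ETHUSDT", "SOLUSDT"]}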


HISTORICAL_DIR = Path(__file__).parent / "historical_data"


def fetch_klines(symbol: str, interval: str = '5m', limit: int = 1000) -> List:
    """Récupère les klines depuis Binance"""
    try:
        url = f"{BINANCE_API}/klines"
        params = {'symbol': symbol, 'interval': interval, 'limit': limit}
        response = requests.get(url, params=params, timeout=10)
        if response.status_code == 200:
            return response.json()
    except Exception as e:
        logger.warning(f"Erreur fetch {symbol}: {e}")
    return []
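

# Binance /api/v3/klines rows are positional JSON arrays; the indices used in
# generate_training_data follow the documented layout:
#   [0] open time, [1] open, [2] high, [3] low, [4] close, [5] volume, ...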


def load_local_historical(symbol: str, interval: str = '1h') -> List[Dict]:
    """Charge les données historiques locales (historical_data/) pour un symbole.
    Retourne une liste de dicts avec keys: open, high, low, close, volume."""
    hist_file = HISTORICAL_DIR / f"{symbol}_historical.json"
    if not hist_file.exists():
        return []
    try:
        with open(hist_file, 'r') as f:
            data = json.load(f)
        intervals = data.get('intervals', {})
        if interval not in intervals:
            # Fallback: try any available interval
            available = list(intervals.keys())
            if not available:
                return []
            interval = available[0]
        klines = intervals[interval].get('klines', [])
        logger.info(f"    📂 Local: {len(klines)} candles {interval} pour {symbol}")
        return klines
    except Exception as e:
        logger.warning(f"Erreur lecture historique local {symbol}: {e}")
        return []


def _ema(arr: np.ndarray, period: int) -> np.ndarray:
    """Calcule l'EMA sur un array (retourne array de même taille)"""
    ema = np.zeros_like(arr, dtype=np.float64)
    if len(arr) < period:
        return ema
    ema[period-1] = np.mean(arr[:period])
    multiplier = 2.0 / (period + 1)
    for i in range(period, len(arr)):
        ema[i] = arr[i] * multiplier + ema[i-1] * (1 - multiplier)
    # Backfill the leading values with the first computed EMA
    ema[:period-1] = ema[period-1]
    return ema
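

# Worked EMA example (hypothetical values): period=3, arr=[1, 2, 3, 4] →
# seed ema[2] = mean([1, 2, 3]) = 2.0, multiplier = 2/(3+1) = 0.5,
# ema[3] = 4*0.5 + 2.0*0.5 = 3.0, and ema[:2] are backfilled with 2.0.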


def create_training_sample(prices: np.ndarray, volumes: Optional[np.ndarray] = None,
                           highs: Optional[np.ndarray] = None, lows: Optional[np.ndarray] = None,
                           lookback: int = 50) -> Optional[np.ndarray]:
    """Build one training sample with the full set of 20 features.
    
    Features (aligned with ai_predictor.py PatternFeatures):
     0: Normalized price (z-score)
     1: 5-period momentum
     2: 10-period momentum
     3: Rolling volatility (10 periods)
     4: Normalized RSI (centered on 0)
     5: Normalized volume (z-score)
     6: Normalized EMA9
     7: Normalized EMA21
     8: EMA_diff (EMA9-EMA21 as % of price, useful for squeeze/dip setups)
     9: EMA_slope (EMA9 slope over 5 periods)
    10: Normalized Bollinger upper band
    11: Normalized Bollinger lower band
    12: Bollinger bandwidth (squeeze indicator)
    13: Position within the Bollinger Bands (0-1)
    14: Normalized MACD (EMA12-EMA26)
    15: Normalized ATR (Average True Range)
    16: 3-period momentum (short term, crucial for entry timing)
    17: Volume momentum (current volume / moving-average ratio)
    18: Normalized high-low range (intrabar volatility)
    19: Price-EMA21 distance (underlying trend)
    
    Returns None when there is not enough data or the price window is flat.
    """
    if len(prices) < lookback:
        return None
    
    prices_sample = prices[-lookback:]
    mean_p = np.mean(prices_sample)
    std_p = np.std(prices_sample)
    if std_p == 0:
        return None
    
    # Normalize
    prices_norm = (prices_sample - mean_p) / std_p
    
    # Prepare volumes, highs, lows
    if volumes is not None and len(volumes) >= lookback:
        vol_sample = volumes[-lookback:]
        vol_mean = np.mean(vol_sample)
        vol_std = np.std(vol_sample)
        vol_norm = (vol_sample - vol_mean) / vol_std if vol_std > 0 else np.zeros(lookback)
    else:
        vol_sample = None
        vol_norm = np.zeros(lookback)
    
    if highs is not None and len(highs) >= lookback:
        highs_sample = highs[-lookback:]
    else:
        highs_sample = prices_sample  # Fallback
    
    if lows is not None and len(lows) >= lookback:
        lows_sample = lows[-lookback:]
    else:
        lows_sample = prices_sample  # Fallback
    
    # Compute the EMAs over the lookback window
    ema9_arr = _ema(prices_sample, 9)
    ema21_arr = _ema(prices_sample, 21)
    ema12_arr = _ema(prices_sample, 12)
    ema26_arr = _ema(prices_sample, 26)
    
    # Bollinger Bands (period 20, 2 std) — vectorized
    bb_period = 20
    bb_mid = np.zeros(lookback)
    bb_upper = np.zeros(lookback)
    bb_lower = np.zeros(lookback)
    if lookback > bb_period:
        from numpy.lib.stride_tricks import sliding_window_view
        # The window ending at index i covers prices_sample[i-bb_period+1:i+1];
        # dropping the first window aligns the bands with indices bb_period..lookback-1
        _bb_wins = sliding_window_view(prices_sample, bb_period)[1:]
        _bb_mid_v = _bb_wins.mean(axis=1)
        _bb_std_v = _bb_wins.std(axis=1)
        bb_mid[bb_period:] = _bb_mid_v
        bb_upper[bb_period:] = _bb_mid_v + 2 * _bb_std_v
        bb_lower[bb_period:] = _bb_mid_v - 2 * _bb_std_v
    
    # ATR (Average True Range) — vectorized
    atr_arr = np.zeros(lookback)
    atr_arr[1:] = np.maximum(
        highs_sample[1:] - lows_sample[1:],
        np.maximum(
            np.abs(highs_sample[1:] - prices_sample[:-1]),
            np.abs(lows_sample[1:] - prices_sample[:-1])
        )
    )
    # Smoothed ATR (14 periods)
    atr_smooth = _ema(atr_arr, 14)
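    # True-range example (hypothetical bar): high=105, low=102, previous
    # close=100 → TR = max(105-102, |105-100|, |102-100|) = 5.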
    
    # Build the features (lookback timesteps x 20 features)
    features = np.zeros((lookback, 20), dtype=np.float32)
    
    # F0: Normalized price
    features[:, 0] = prices_norm
    
    # F1: 5-period momentum — vectorized
    features[5:, 1] = (prices_norm[5:] - prices_norm[:lookback - 5]) / 5

    # F2: 10-period momentum — vectorized
    features[10:, 2] = (prices_norm[10:] - prices_norm[:lookback - 10]) / 10

    # F3: Rolling volatility (10 periods) — vectorized
    if lookback > 10:
        from numpy.lib.stride_tricks import sliding_window_view
        # Drop the first window so features[10+j, 3] is the std of the
        # 10 values ending at index 10+j
        _vol_wins = sliding_window_view(prices_norm, 10)[1:]
        features[10:, 3] = _vol_wins.std(axis=1)
    
    # F4: Normalized RSI
    for i in range(14, lookback):
        changes = np.diff(prices_sample[i-14:i+1])
        gains = np.mean(np.where(changes > 0, changes, 0))
        losses = np.mean(np.where(changes < 0, -changes, 0))
        if losses > 0:
            rs = gains / losses
            features[i, 4] = (100 - (100 / (1 + rs))) / 100 - 0.5
        elif gains > 0:
            features[i, 4] = 0.5  # All gains, no losses: RSI = 100
    
    # F5: Normalized volume
    features[:, 5] = np.clip(vol_norm, -3, 3) / 3  # Clipped and rescaled to [-1, 1]
    
    # F6: Normalized EMA9
    features[:, 6] = (ema9_arr - mean_p) / std_p
    
    # F7: Normalized EMA21
    features[:, 7] = (ema21_arr - mean_p) / std_p
    
    # F8: EMA_diff (EMA9-EMA21) as % of price — key for dip/squeeze strategies
    for i in range(lookback):
        if ema21_arr[i] > 0:
            features[i, 8] = (ema9_arr[i] - ema21_arr[i]) / ema21_arr[i] * 100
    features[:, 8] = np.clip(features[:, 8], -5, 5) / 5  # Normalized to [-1, 1]
    
    # F9: EMA9 slope (over 5 periods)
    for i in range(5, lookback):
        if ema9_arr[i-5] > 0:
            features[i, 9] = (ema9_arr[i] - ema9_arr[i-5]) / ema9_arr[i-5] * 100
    features[:, 9] = np.clip(features[:, 9], -3, 3) / 3
    
    # F10: Normalized Bollinger upper band
    for i in range(bb_period, lookback):
        if std_p > 0:
            features[i, 10] = (bb_upper[i] - mean_p) / std_p
    
    # F11: Normalized Bollinger lower band
    for i in range(bb_period, lookback):
        if std_p > 0:
            features[i, 11] = (bb_lower[i] - mean_p) / std_p
    
    # F12: Bollinger bandwidth (squeeze indicator)
    for i in range(bb_period, lookback):
        if bb_mid[i] > 0:
            features[i, 12] = (bb_upper[i] - bb_lower[i]) / bb_mid[i] * 100
    bw_max = np.max(np.abs(features[:, 12]))
    if bw_max > 0:
        features[:, 12] /= bw_max  # Normalized
    
    # F13: Position within the Bollinger Bands (0-1)
    for i in range(bb_period, lookback):
        bb_range = bb_upper[i] - bb_lower[i]
        if bb_range > 0:
            features[i, 13] = (prices_sample[i] - bb_lower[i]) / bb_range
    features[:, 13] = np.clip(features[:, 13], 0, 1) - 0.5  # Centered on 0
    
    # F14: Normalized MACD
    for i in range(lookback):
        if mean_p > 0:
            features[i, 14] = (ema12_arr[i] - ema26_arr[i]) / mean_p * 100
    features[:, 14] = np.clip(features[:, 14], -3, 3) / 3
    
    # F15: Normalized ATR
    for i in range(lookback):
        if mean_p > 0:
            features[i, 15] = atr_smooth[i] / mean_p * 100  # ATR as % of price
    atr_max = np.max(features[:, 15])
    if atr_max > 0:
        features[:, 15] /= atr_max
    
    # F16: 3-period momentum (short term — crucial for entry timing) — vectorized
    features[3:, 16] = (prices_norm[3:] - prices_norm[:lookback - 3]) / 3
    
    # F17: Volume momentum (current volume / 10-period moving average)
    if vol_sample is not None:
        vol_ma = _ema(vol_sample.astype(np.float64), 10)
        for i in range(10, lookback):
            if vol_ma[i] > 0:
                features[i, 17] = (vol_sample[i] / vol_ma[i]) - 1  # 0 = normal, >0 = elevated volume
        features[:, 17] = np.clip(features[:, 17], -2, 5) / 5
    
    # F18: Normalized high-low range
    for i in range(lookback):
        if prices_sample[i] > 0:
            features[i, 18] = (highs_sample[i] - lows_sample[i]) / prices_sample[i] * 100
    hl_max = np.max(features[:, 18])
    if hl_max > 0:
        features[:, 18] /= hl_max
    
    # F19: Price-EMA21 distance (underlying trend)
    for i in range(lookback):
        if ema21_arr[i] > 0:
            features[i, 19] = (prices_sample[i] - ema21_arr[i]) / ema21_arr[i] * 100
    features[:, 19] = np.clip(features[:, 19], -5, 5) / 5
    
    return features


def determine_label(prices: np.ndarray, future_bars: int = 12, threshold: float = 1.0) -> int:
    """
    Détermine le label basé sur le mouvement futur
    0 = baisse, 1 = neutre, 2 = hausse
    """
    if len(prices) < future_bars + 1:
        return 1  # Neutre par défaut
    
    current_price = prices[0]
    future_price = prices[future_bars]
    
    pct_change = (future_price - current_price) / current_price * 100
    
    if pct_change >= threshold:
        return 2  # Up
    elif pct_change <= -threshold:
        return 0  # Down
    else:
        return 1  # Neutral
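

# Labeling example (hypothetical prices, threshold=1.0): 100.0 → 101.2 is +1.2%
# → label 2 (up); 100.0 → 99.5 is -0.5% → label 1 (neutral); 100.0 → 98.9 is
# -1.1% → label 0 (down).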


def _generate_samples_from_ohlcv(prices: np.ndarray, volumes: np.ndarray,
                                   highs: np.ndarray, lows: np.ndarray,
                                   lookback: int, future: int, threshold: float) -> Tuple[List, List]:
    """Génère des échantillons glissants à partir de données OHLCV."""
    X_out, y_out = [], []
    for i in range(lookback, len(prices) - future):
        sample = create_training_sample(
            prices[:i+1], volumes[:i+1] if volumes is not None else None,
            highs[:i+1] if highs is not None else None,
            lows[:i+1] if lows is not None else None,
            lookback
        )
        if sample is None:
            continue
        label = determine_label(prices[i:], future, threshold)
        X_out.append(sample)
        y_out.append(label)
    return X_out, y_out
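

# Sample-count example: with 500 closes, lookback=50 and future=12, the loop
# above runs for i in range(50, 488), yielding up to 438 samples (fewer when a
# flat window makes create_training_sample return None).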


def generate_training_data(symbols: List[str], max_symbols: int = 30) -> Tuple[np.ndarray, np.ndarray]:
    """Génère les données d'entraînement depuis l'historique local + API Binance.
    
    Priorité: données locales (historical_data/) pour volume plus important,
    puis API Binance (5m) pour données très récentes.
    """
    logger.info(f"📊 Génération des données d'entraînement (local + API)...")
    
    X_all = []
    y_all = []
    
    symbols_to_use = symbols[:max_symbols]
    total = len(symbols_to_use)
    
    for idx, symbol in enumerate(symbols_to_use):
        logger.info(f"  [{idx+1}/{total}] {symbol}...")
        symbol_samples = 0
        
        # ═══ SOURCE 1: Local historical data (1h, ~6 months) ═══
        local_klines = load_local_historical(symbol, '1h')
        if len(local_klines) >= 100:
            prices_local = np.array([float(k['close']) for k in local_klines])
            volumes_local = np.array([float(k.get('volume', 0)) for k in local_klines])
            highs_local = np.array([float(k.get('high', k['close'])) for k in local_klines])
            lows_local = np.array([float(k.get('low', k['close'])) for k in local_klines])
            
            # For 1h candles: future = 6 bars (6h), adjusted threshold
            local_future = 6  # 6h prediction horizon
            local_threshold = 1.5  # ±1.5% over 6h
            
            X_local, y_local = _generate_samples_from_ohlcv(
                prices_local, volumes_local, highs_local, lows_local,
                DEFAULT_LOOKBACK, local_future, local_threshold
            )
            X_all.extend(X_local)
            y_all.extend(y_local)
            symbol_samples += len(X_local)
            logger.info(f"    📂 Local 1h: {len(X_local)} échantillons")
        
        # ═══ SOURCE 2: Binance API (5m, most recent data) ═══
        klines = fetch_klines(symbol, '5m', 500)  # Reduced 1000→500 (saves RAM)
        if len(klines) >= 100:
            prices_api = np.array([float(k[4]) for k in klines])  # close
            volumes_api = np.array([float(k[5]) for k in klines])  # volume
            highs_api = np.array([float(k[2]) for k in klines])   # high
            lows_api = np.array([float(k[3]) for k in klines])    # low
            
            X_api, y_api = _generate_samples_from_ohlcv(
                prices_api, volumes_api, highs_api, lows_api,
                DEFAULT_LOOKBACK, DEFAULT_FUTURE, GAIN_THRESHOLD
            )
            X_all.extend(X_api)
            y_all.extend(y_api)
            symbol_samples += len(X_api)
            logger.info(f"    🌐 API 5m: {len(X_api)} échantillons")
        
        if symbol_samples == 0:
            logger.warning(f"    ⚠️ No data for {symbol}")
        else:
            logger.info(f"    ✓ Total: {symbol_samples} samples")
        
        # Pause to avoid rate limiting
        time.sleep(0.15)
    
    if len(X_all) == 0:
        logger.error("❌ Aucune donnée générée!")
        return None, None

    # Cap the dataset to avoid RAM/swap saturation
    if len(X_all) > MAX_TOTAL_SAMPLES:
        import random as _random
        _idx = sorted(_random.sample(range(len(X_all)), MAX_TOTAL_SAMPLES))
        X_all = [X_all[i] for i in _idx]
        y_all = [y_all[i] for i in _idx]
        logger.info(f"   ↓ Dataset limité à {MAX_TOTAL_SAMPLES} échantillons (économie mémoire)")

    X = np.array(X_all, dtype=np.float32)
    y = np.array(y_all, dtype=np.int64)
    del X_all, y_all  # Free the source lists
    gc.collect()

    # Label statistics
    unique, counts = np.unique(y, return_counts=True)
    logger.info("\n📈 Label distribution:")
    for u, c in zip(unique, counts):
        label_name = ['Down', 'Neutral', 'Up'][u]
        pct = c / len(y) * 100
        logger.info(f"   {label_name}: {c} ({pct:.1f}%)")
    
    return X, y


# ═══════════════════════════════════════════════════════════════════════════════
# LSTM MODEL
# ═══════════════════════════════════════════════════════════════════════════════

class PredictorLSTM(nn.Module):
    """Modèle LSTM pour la prédiction de direction de prix"""
    
    def __init__(self, input_size=20, hidden_size=64, num_layers=1, dropout=0.2):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=0.0  # dropout must be 0.0 when num_layers=1 (PyTorch constraint)
        )
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(hidden_size, 32)
        self.fc2 = nn.Linear(32, 3)  # 3 classes: down, neutral, up

    def forward(self, x):
        lstm_out, _ = self.lstm(x)
        # Take the last timestep
        last_output = lstm_out[:, -1, :]
        x = self.dropout(last_output)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x
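

# Minimal shape sanity check (a sketch; a hypothetical helper that is defined
# but never called — run it manually to verify the forward pass):
def _smoke_test_predictor_shapes():
    model = PredictorLSTM()
    dummy = torch.randn(4, DEFAULT_LOOKBACK, 20)  # (batch, timesteps, features)
    logits = model(dummy)
    assert logits.shape == (4, 3)  # one (down/neutral/up) logit triple per window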


def train_model(X: np.ndarray, y: np.ndarray, epochs: int = 50, batch_size: int = 32) -> Dict:
    """Entraîne le modèle LSTM sur GPU"""
    logger.info(f"\n🎓 Démarrage de l'entraînement sur {DEVICE.upper()}...")
    logger.info(f"   Epochs: {epochs}")
    logger.info(f"   Batch size: {batch_size}")
    logger.info(f"   Échantillons: {len(X)}")
    
    # Convert to tensors
    X_tensor = torch.tensor(X, dtype=torch.float32)
    y_tensor = torch.tensor(y, dtype=torch.long)
    
    # Train/validation split (80/20, in dataset order — no shuffling)
    split_idx = int(len(X) * 0.8)
    X_train = X_tensor[:split_idx].clone()
    X_val   = X_tensor[split_idx:].clone()
    y_train = y_tensor[:split_idx].clone()
    y_val   = y_tensor[split_idx:].clone()
    del X_tensor, y_tensor  # Free the full tensors (avoids double allocation)
    gc.collect()

    logger.info(f"   Train: {len(X_train)}, Val: {len(X_val)}")
    
    # Create the DataLoaders
    train_dataset = TensorDataset(X_train, y_train)
    val_dataset = TensorDataset(X_val, y_val)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    
    # Create the model
    model = PredictorLSTM().to(DEVICE)
    
    # Log the parameter count
    total_params = sum(p.numel() for p in model.parameters())
    logger.info(f"   Model parameters: {total_params:,}")
    
    # Optimizer and loss with CLASS BALANCING
    # Compute class weights to offset the imbalance (too many "neutral" samples)
    class_counts = torch.bincount(y_train, minlength=3).float()
    # Weights inversely proportional to frequency, with smoothing
    total_samples = class_counts.sum()
    class_weights = total_samples / (3.0 * class_counts.clamp(min=1))
    # Normalize so the mean weight is 1.0
    class_weights = class_weights / class_weights.mean()
    class_weights = class_weights.to(DEVICE)
    logger.info(f"   Class weights (down/neutral/up): {class_weights.cpu().tolist()}")
    
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, factor=0.5)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    
    # Training loop with EARLY STOPPING
    best_val_acc = 0
    best_val_loss = float('inf')
    best_model_state = None
    patience_counter = 0
    early_stop_patience = 8   # Stop after 8 epochs without improvement (reduced from 12)
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    
    start_time = time.time()
    
    for epoch in range(epochs):
        # Training mode
        model.train()
        train_loss = 0
        train_correct = 0
        train_total = 0
        
        for batch_X, batch_y in train_loader:
            batch_X = batch_X.to(DEVICE)
            batch_y = batch_y.to(DEVICE)
            
            optimizer.zero_grad()
            outputs = model(batch_X)
            loss = criterion(outputs, batch_y)
            loss.backward()
            # Gradient clipping for stability (must run between backward() and step())
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            
            train_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            train_total += batch_y.size(0)
            train_correct += (predicted == batch_y).sum().item()
        
        train_loss /= len(train_loader)
        train_acc = train_correct / train_total
        
        # Evaluation mode
        model.eval()
        val_loss = 0
        val_correct = 0
        val_total = 0
        
        with torch.no_grad():
            for batch_X, batch_y in val_loader:
                batch_X = batch_X.to(DEVICE)
                batch_y = batch_y.to(DEVICE)
                
                outputs = model(batch_X)
                loss = criterion(outputs, batch_y)
                
                val_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                val_total += batch_y.size(0)
                val_correct += (predicted == batch_y).sum().item()
        
        val_loss /= len(val_loader)
        val_acc = val_correct / val_total
        
        # Scheduler
        scheduler.step(val_loss)
        
        # Keep the best model on validation accuracy (val_loss drives patience below)
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # Clone the tensors: state_dict() returns live references that keep updating
            best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            patience_counter = 0
        
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1
        
        # Early stopping (the epoch > 20 guard prevents stopping too early)
        if patience_counter >= early_stop_patience and epoch > 20:
            logger.info(f"   ⏹️ Early stopping at epoch {epoch+1} (no improvement for {early_stop_patience} epochs)")
            break
        
        # History
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        
        # Progress logging
        if epoch % 5 == 0 or epoch == epochs - 1:
            elapsed = time.time() - start_time
            logger.info(
                f"   Epoch {epoch+1:3d}/{epochs} | "
                f"Train: {train_loss:.4f} ({train_acc:.1%}) | "
                f"Val: {val_loss:.4f} ({val_acc:.1%}) | "
                f"Best: {best_val_acc:.1%} | "
                f"Time: {elapsed:.0f}s"
            )
    
    # Restore the best model (guard against the unlikely case where no epoch improved)
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    
    total_time = time.time() - start_time
    logger.info(f"\n✅ Entraînement terminé en {total_time:.1f}s")
    logger.info(f"   Meilleure accuracy validation: {best_val_acc:.1%}")
    
    return {
        'model': model,
        'history': history,
        'best_val_acc': best_val_acc,
        'total_time': total_time,
        'samples_count': len(X),
        'epochs': epochs
    }


def save_model(model: nn.Module, stats: Dict):
    """Sauvegarde le modèle et met à jour les stats"""
    # Créer le dossier models si nécessaire
    MODEL_PATH.parent.mkdir(exist_ok=True)
    
    # Sauvegarder le state_dict (meilleure pratique PyTorch)
    torch.save({
        'model_state_dict': model.state_dict(),
        'model_class': 'PredictorLSTM',
        'input_size': 20,
        'hidden_size1': 64,
        'hidden_size2': 32,
        'output_size': 3
    }, MODEL_PATH)
    logger.info(f"💾 Modèle sauvegardé: {MODEL_PATH}")
    
    # Update the stats
    training_stats = {
        'status': 'trained',
        'samples_count': stats['samples_count'],
        'epochs_completed': stats['epochs'],
        'last_loss': stats['history']['train_loss'][-1],
        'last_accuracy': stats['history']['train_acc'][-1],
        'validation_accuracy': stats['best_val_acc'] * 100,
        'last_training': datetime.now().isoformat(),
        'predictions_made': 0,
        'correct_predictions': 0,
        'gpu_name': GPU_NAME if DEVICE == 'cuda' else 'CPU',
        'gpu_available': DEVICE == 'cuda',
        'training_time_seconds': stats['total_time']
    }
    
    with open(STATS_FILE, 'w') as f:
        json.dump(training_stats, f, indent=2)
    
    logger.info(f"📊 Stats sauvegardées: {STATS_FILE}")


# ═══════════════════════════════════════════════════════════════════════════════
# MAIN
# ═══════════════════════════════════════════════════════════════════════════════

def main():
    parser = argparse.ArgumentParser(description="Train the AI model")
    parser.add_argument('--epochs', type=int, default=DEFAULT_EPOCHS, help='Number of epochs')
    parser.add_argument('--batch-size', type=int, default=DEFAULT_BATCH_SIZE, help='Batch size')
    parser.add_argument('--symbols', type=int, default=15, help='Max number of symbols')
    args = parser.parse_args()
    
    print(f"""
╔══════════════════════════════════════════════════════════════╗
║         🎓 ENTRAÎNEMENT DU MODÈLE IA - GPU ACCELERATED       ║
╠══════════════════════════════════════════════════════════════╣
║  Device: {DEVICE.upper():6s}                                          ║
║  GPU: {GPU_NAME if DEVICE == 'cuda' else 'N/A':50s}   ║
║  Epochs: {args.epochs:4d}                                           ║
║  Batch Size: {args.batch_size:4d}                                       ║
╚══════════════════════════════════════════════════════════════╝
""")
    
    # 1. Load the watchlist
    symbols = load_watchlist()
    logger.info(f"📋 {len(symbols)} symbols in the watchlist")
    
    # 2. Generate the training data
    X, y = generate_training_data(symbols, max_symbols=args.symbols)
    if X is None:
        logger.error("Could not generate training data")
        return
    
    logger.info(f"\n📦 Dataset: {X.shape[0]} échantillons, {X.shape[1]} timesteps, {X.shape[2]} features")
    
    # 3. Train the model
    result = train_model(X, y, epochs=args.epochs, batch_size=args.batch_size)
    del X, y  # Free the dataset after training
    gc.collect()

    # 4. Save
    save_model(result['model'], result)
    
    print(f"""
╔══════════════════════════════════════════════════════════════╗
║                    ✅ ENTRAÎNEMENT TERMINÉ                    ║
╠══════════════════════════════════════════════════════════════╣
║  Échantillons utilisés: {result['samples_count']:,}                           ║
║  Accuracy validation: {result['best_val_acc']:.1%}                            ║
║  Temps total: {result['total_time']:.1f}s                                    ║
║  Modèle sauvegardé: models/predictor.pt                      ║
╚══════════════════════════════════════════════════════════════╝
""")


if __name__ == "__main__":
    main()
