#!/usr/bin/env python3
"""
AI Adaptive Retrainer - Réentraînement automatique des modèles GPU
Inspiré de FreqAI - Permet l'adaptation automatique aux changements de marché

Features:
- Retraining périodique (toutes les 24-48h)
- Détection changements de régime marché
- Retraining d'urgence si win rate < 30%
- Sauvegarde automatique des modèles
- Crash resilience
"""

import numpy as np
from datetime import datetime, timedelta
from typing import Dict, List, Tuple, Optional
import json
import os
import sys
import subprocess
import logging
from pathlib import Path
import pickle

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("AIAdaptiveRetrainer")

# Imports des modèles à réentraîner
try:
    from ai_advanced_scorer import get_advanced_scorer, AdvancedScorer
    ADVANCED_SCORER_AVAILABLE = True
except ImportError:
    ADVANCED_SCORER_AVAILABLE = False
    logger.warning("⚠️ Advanced Scorer non disponible")

try:
    from ai_opportunity_selector import get_opportunity_selector
    OPPORTUNITY_SELECTOR_AVAILABLE = True
except ImportError:
    OPPORTUNITY_SELECTOR_AVAILABLE = False
    logger.warning("⚠️ Opportunity Selector non disponible")

try:
    from market_regime import MarketRegimeDetector
    MARKET_REGIME_AVAILABLE = True
except ImportError:
    MARKET_REGIME_AVAILABLE = False
    logger.warning("⚠️ Market Regime non disponible")


class AdaptiveRetrainer:
    """
    Adaptive retraining manager for the GPU models.

    Retraining strategies:
    1. Periodic - every N days (default: 2 days)
    2. Performance-based - when the win rate drops below a threshold (default: 30%)
    3. Regime-change - on a major market shift (BULL→BEAR)
    """
    
    def __init__(self, 
                 retrain_interval_hours: int = 48,
                 min_samples: int = 100,
                 emergency_winrate_threshold: float = 0.30,
                 models_dir: str = "trained_models"):
        """
        Args:
            retrain_interval_hours: Interval between retrainings (hours)
            min_samples: Minimum number of trades required before retraining
            emergency_winrate_threshold: Minimum win rate before an emergency retrain
            models_dir: Directory where models and state are saved
        """
        self.retrain_interval = timedelta(hours=retrain_interval_hours)
        self.min_samples = min_samples
        self.emergency_threshold = emergency_winrate_threshold
        
        # Model save directory (parents=True so a nested path works on first run)
        self.models_dir = Path(models_dir)
        self.models_dir.mkdir(parents=True, exist_ok=True)
        
        # Retrainer state (restored from disk below when available)
        self.last_retrain: Optional[datetime] = None
        self.last_market_regime: Optional[str] = None
        self.retrain_history: List[Dict] = []
        
        # Restore previous state from disk (crash resilience)
        self._load_state()
        
        logger.info("✅ Adaptive Retrainer initialisé")
        logger.info(f"   • Intervalle: {retrain_interval_hours}h")
        logger.info(f"   • Min samples: {min_samples}")
        logger.info(f"   • Emergency threshold: {emergency_winrate_threshold*100}%")
    
    def _load_state(self):
        """Load the retrainer state persisted by a previous run, if any."""
        state_file = self.models_dir / "retrainer_state.json"
        if state_file.exists():
            try:
                with open(state_file, 'r') as f:
                    state = json.load(f)
                    last_retrain_str = state.get('last_retrain')
                    self.last_retrain = datetime.fromisoformat(last_retrain_str) if last_retrain_str else None
                    self.last_market_regime = state.get('last_regime')
                    self.retrain_history = state.get('history', [])
                    logger.info(f"📂 État chargé: dernier retrain {self.last_retrain}")
            except Exception as e:
                # Corrupted or partial state file: start fresh rather than crash
                logger.warning(f"⚠️ Erreur chargement état: {e}")
    
    def _save_state(self):
        """Persist the retrainer state to disk so it survives restarts."""
        state_file = self.models_dir / "retrainer_state.json"
        try:
            state = {
                'last_retrain': self.last_retrain.isoformat() if self.last_retrain else None,
                'last_regime': self.last_market_regime,
                'history': self.retrain_history[-50:]  # keep only the last 50 entries
            }
            with open(state_file, 'w') as f:
                json.dump(state, f, indent=2)
        except Exception as e:
            logger.error(f"❌ Erreur sauvegarde état: {e}")
    
    def should_retrain(self) -> Tuple[bool, str]:
        """
        Decide whether a retraining is needed right now.
        
        Returns:
            (should_retrain, reason)
        """
        # Reason 1: never trained yet, or the periodic interval has elapsed
        if self.last_retrain is None:
            return True, "INITIAL_TRAIN"
        
        time_since_last = datetime.now() - self.last_retrain
        if time_since_last >= self.retrain_interval:
            return True, f"PERIODIC ({time_since_last.total_seconds()/3600:.1f}h écoulées)"
        
        # Reason 2: catastrophic recent win rate
        recent_winrate = self._get_recent_winrate()
        if recent_winrate is not None and recent_winrate < self.emergency_threshold:
            return True, f"EMERGENCY (win rate {recent_winrate*100:.1f}% < {self.emergency_threshold*100}%)"
        
        # Reason 3: market regime change
        if MARKET_REGIME_AVAILABLE:
            current_regime = self._get_current_market_regime()
            if self.last_market_regime and current_regime != self.last_market_regime:
                # Only major transitions (BULL↔BEAR) justify a retrain
                major_changes = [
                    ('BULL_STRONG', 'BEAR'),
                    ('BULL_WEAK', 'BEAR'),
                    ('BEAR', 'BULL_STRONG'),
                    ('BEAR', 'BULL_WEAK')
                ]
                if (self.last_market_regime, current_regime) in major_changes:
                    return True, f"REGIME_CHANGE ({self.last_market_regime}→{current_regime})"
        
        return False, "NO_RETRAIN_NEEDED"
    
    def _load_recent_trades(self, days: int) -> Optional[List[Dict]]:
        """
        Read trade_history.json and return trades opened in the last *days* days.
        
        Shared by `_get_recent_winrate` and `load_training_data` (previously
        duplicated logic). Returns None when the file is missing or unreadable;
        trades with a missing/invalid 'entry_time' are skipped.
        """
        trade_file = Path("trade_history.json")
        if not trade_file.exists():
            return None
        
        try:
            with open(trade_file, 'r') as f:
                all_trades = json.load(f)
        except (OSError, json.JSONDecodeError) as e:
            logger.warning(f"⚠️ Erreur lecture trade_history.json: {e}")
            return None
        
        cutoff = datetime.now() - timedelta(days=days)
        recent_trades = []
        for trade in all_trades:
            try:
                if datetime.fromisoformat(trade.get('entry_time', '')) >= cutoff:
                    recent_trades.append(trade)
            except (TypeError, ValueError):
                # fromisoformat raises ValueError on bad strings, TypeError on None
                continue
        return recent_trades
    
    def _get_recent_winrate(self, days: int = 7) -> Optional[float]:
        """Win rate over the last N days, or None with fewer than 10 trades."""
        try:
            recent_trades = self._load_recent_trades(days)
            if recent_trades is None or len(recent_trades) < 10:
                # Sample too small (or no history file) to be meaningful
                return None
            
            wins = sum(1 for t in recent_trades if t.get('pnl', 0) > 0)
            return wins / len(recent_trades)
            
        except Exception as e:
            logger.warning(f"⚠️ Erreur calcul win rate: {e}")
            return None
    
    def _get_current_market_regime(self) -> Optional[str]:
        """Return the current market regime label, or None on failure."""
        try:
            from market_regime import MarketRegimeDetector
            detector = MarketRegimeDetector()
            regime, _ = detector.get_current_regime()
            return regime
        except Exception as e:
            logger.warning(f"⚠️ Erreur détection régime: {e}")
            return None
    
    def load_training_data(self, days: int = 30) -> Optional[Dict]:
        """
        Load the training data for the last N days.
        
        Returns:
            {
                'trades': [...],
                'trades_by_symbol': {...},
                'patterns_stats': {...},
                'total_trades': int,
                'period_days': int
            }
            or None when the history is missing or has fewer than
            `self.min_samples` recent trades.
        """
        try:
            logger.info(f"📥 Chargement données training ({days} jours)...")
            
            recent_trades = self._load_recent_trades(days)
            if recent_trades is None:
                logger.warning("⚠️ trade_history.json introuvable")
                return None
            
            if len(recent_trades) < self.min_samples:
                logger.warning(f"⚠️ Pas assez de trades: {len(recent_trades)} < {self.min_samples}")
                return None
            
            logger.info(f"✅ {len(recent_trades)} trades chargés")
            
            # Group trades per symbol
            trades_by_symbol: Dict[str, List[Dict]] = {}
            for trade in recent_trades:
                symbol = trade.get('symbol')
                if symbol:
                    trades_by_symbol.setdefault(symbol, []).append(trade)
            
            # Per-pattern win/loss statistics
            patterns_stats: Dict[str, Dict] = {}
            for trade in recent_trades:
                pattern = trade.get('pattern', 'UNKNOWN')
                stats = patterns_stats.setdefault(pattern, {'wins': 0, 'losses': 0, 'total': 0})
                stats['total'] += 1
                if trade.get('pnl', 0) > 0:
                    stats['wins'] += 1
                else:
                    stats['losses'] += 1
            
            # Derive win rates
            for pattern, stats in patterns_stats.items():
                if stats['total'] > 0:
                    stats['win_rate'] = stats['wins'] / stats['total']
            
            return {
                'trades': recent_trades,
                'trades_by_symbol': trades_by_symbol,
                'patterns_stats': patterns_stats,
                'total_trades': len(recent_trades),
                'period_days': days
            }
            
        except Exception as e:
            logger.error(f"❌ Erreur chargement données: {e}")
            return None
    
    def retrain_advanced_scorer(self, training_data: Dict) -> bool:
        """
        Retrain the Advanced Scorer with fresh trade data.
        
        Updates the scorer's per-pattern weights from recent win rates, then
        saves a performance summary and a timestamped pickle of the scorer.
        Returns True on success.
        """
        if not ADVANCED_SCORER_AVAILABLE:
            logger.warning("⚠️ Advanced Scorer non disponible")
            return False
        
        try:
            logger.info("🔄 Retraining Advanced Scorer...")
            
            scorer = get_advanced_scorer()
            
            # Extract (feature vector, label) pairs from the trades
            features = []
            labels = []
            
            for trade in training_data['trades']:
                feature_vector = self._extract_trade_features(trade)
                if feature_vector is not None:
                    features.append(feature_vector)
                    # Label: 1 = profitable trade, 0 = losing trade
                    labels.append(1 if trade.get('pnl', 0) > 0 else 0)
            
            if len(features) < 50:
                logger.warning("⚠️ Pas assez de features pour training")
                return False
            
            features = np.array(features)
            labels = np.array(labels)
            
            logger.info(f"   • Features shape: {features.shape}")
            logger.info(f"   • Labels positifs: {sum(labels)}/{len(labels)} ({sum(labels)/len(labels)*100:.1f}%)")
            
            # Re-weight patterns according to their observed performance
            patterns_stats = training_data.get('patterns_stats', {})
            if hasattr(scorer, 'pattern_weights'):
                for pattern, stats in patterns_stats.items():
                    wr = stats.get('win_rate', 0.5)
                    total = stats.get('total', 0)
                    if total >= 5:  # require a minimal sample before trusting the stat
                        # Winning patterns get a higher weight, range [0.5, 1.5]
                        weight = 0.5 + wr
                        scorer.pattern_weights[pattern] = weight
                        logger.info(f"     • {pattern}: WR={wr*100:.0f}% → weight={weight:.2f}")
            
            # Persist the performance stats for later inspection
            perf_path = self.models_dir / "scorer_performance.json"
            with open(perf_path, 'w') as f:
                json.dump({
                    'patterns': {p: {**s, 'win_rate': float(s.get('win_rate', 0))} 
                                 for p, s in patterns_stats.items()},
                    'total_features': len(features),
                    'positive_rate': float(sum(labels)/len(labels)),
                    'updated_at': datetime.now().isoformat()
                }, f, indent=2)
            
            # Snapshot the updated scorer.
            # NOTE(security): pickle is only safe because these files are produced
            # and consumed locally — never unpickle files from untrusted sources.
            model_path = self.models_dir / f"advanced_scorer_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pkl"
            with open(model_path, 'wb') as f:
                pickle.dump(scorer, f)
            
            logger.info(f"✅ Advanced Scorer retrained - sauvegardé: {model_path.name}")
            return True
            
        except Exception as e:
            logger.error(f"❌ Erreur retraining scorer: {e}")
            return False
    
    def _extract_trade_features(self, trade: Dict) -> Optional[np.ndarray]:
        """Build a normalized float32 feature vector (15 values) from one trade."""
        try:
            entry_price = trade.get('entry_price', 0)
            exit_price = trade.get('exit_price', 0)
            pnl_pct = trade.get('pnl_pct', 0)
            if pnl_pct == 0 and entry_price > 0:
                # Derive the PnL% when it was not recorded on the trade
                pnl_pct = (exit_price - entry_price) / entry_price * 100
            
            # One-hot encode the 8 main patterns
            pattern = trade.get('pattern', 'UNKNOWN')
            patterns_list = ['CREUX_REBOUND', 'SQUEEZE_BREAKOUT', 'EARLY_BREAKOUT', 
                           'PULLBACK', 'CROSSOVER_IMMINENT', 'IMMEDIATE_DIP',
                           'RSI_REVERSAL', 'BB_BREAKOUT']
            pattern_encoded = [1.0 if pattern == p else 0.0 for p in patterns_list]
            
            features = [
                entry_price / 10000 if entry_price > 1000 else entry_price / 100,  # normalized price
                pnl_pct / 10,  # PnL as %/10
                min(trade.get('hold_time', 0), 1440) / 1440,  # hold time normalized (capped at 24h)
                trade.get('ai_score', 50) / 100,  # AI score normalized
                trade.get('volatility_score', 50) / 100,  # volatility score normalized
                # Structural trend context (PHA case — EMA bearish + falling knife)
                float(trade.get('ema_trend_bearish', trade.get('features', {}).get('ema_trend_bearish', 0))),  # 1=bearish
                float(trade.get('momentum_20', trade.get('features', {}).get('momentum_20', 0))) / 5,  # 100-min trend
            ] + pattern_encoded  # +8 one-hot features = 15 total
            
            return np.array(features, dtype=np.float32)
        except Exception as e:
            logger.debug(f"Erreur extraction features: {e}")
            return None
    
    def retrain_opportunity_selector(self, training_data: Dict) -> bool:
        """
        Refresh the Opportunity Selector statistics from recent trades.
        
        Writes the per-pattern and per-symbol win-rate stats to
        opportunity_stats.json. Returns True on success.
        """
        if not OPPORTUNITY_SELECTOR_AVAILABLE:
            logger.warning("⚠️ Opportunity Selector non disponible")
            return False
        
        try:
            logger.info("🔄 Retraining Opportunity Selector...")
            
            # Make sure the selector singleton is initialised (called for its
            # side effect only; the return value is unused here)
            get_opportunity_selector()
            
            # Identify the winning patterns (win rate > 50% on >= 5 trades)
            patterns_stats = training_data['patterns_stats']
            winning_patterns = [
                p for p, stats in patterns_stats.items()
                if stats.get('win_rate', 0) > 0.50 and stats['total'] >= 5
            ]
            
            logger.info(f"   • Patterns gagnants: {winning_patterns}")
            
            # Identify the best-performing symbols (at least 3 trades each)
            symbol_stats = {}
            for symbol, trades in training_data['trades_by_symbol'].items():
                if len(trades) >= 3:
                    wins = sum(1 for t in trades if t.get('pnl', 0) > 0)
                    symbol_stats[symbol] = {
                        'trades': len(trades),
                        'win_rate': wins / len(trades)
                    }
            
            top_symbols = sorted(
                symbol_stats.items(),
                key=lambda x: x[1]['win_rate'],
                reverse=True
            )[:10]
            
            logger.info(f"   • Top 10 cryptos: {[s[0] for s in top_symbols]}")
            
            # Persist the refreshed stats
            stats_path = self.models_dir / "opportunity_stats.json"
            with open(stats_path, 'w') as f:
                json.dump({
                    'patterns': patterns_stats,
                    'symbols': symbol_stats,
                    'updated_at': datetime.now().isoformat()
                }, f, indent=2)
            
            logger.info("✅ Opportunity Selector retrained")
            return True
            
        except Exception as e:
            logger.error(f"❌ Erreur retraining selector: {e}")
            return False
    
    def retrain_all(self) -> Dict[str, bool]:
        """
        Retrain every model.
        
        Returns:
            Dict with per-model results. 'status' is SUCCESS only when every
            *attempted* retraining succeeded; models that were deliberately
            skipped (result None) do not count against the status.
        """
        logger.info("\n" + "="*60)
        logger.info("🔄 RETRAINING ADAPTATIF LANCÉ")
        logger.info("="*60)
        
        start_time = datetime.now()
        
        # Load training data
        training_data = self.load_training_data(days=30)
        if training_data is None:
            logger.error("❌ Impossible de charger les données - retraining annulé")
            return {'status': 'FAILED', 'reason': 'NO_DATA', 'results': {}}
        
        # Data stats
        logger.info(f"\n📊 Données training:")
        logger.info(f"   • Période: {training_data['period_days']} jours")
        logger.info(f"   • Trades: {training_data['total_trades']}")
        logger.info(f"   • Symboles: {len(training_data['trades_by_symbol'])}")
        logger.info(f"   • Patterns: {len(training_data['patterns_stats'])}")
        
        # Global win rate
        total_wins = sum(1 for t in training_data['trades'] if t.get('pnl', 0) > 0)
        global_winrate = total_wins / training_data['total_trades'] if training_data['total_trades'] > 0 else 0
        logger.info(f"   • Win rate global: {global_winrate*100:.1f}%")
        
        # Retrain each model
        results = {}
        
        results['advanced_scorer'] = self.retrain_advanced_scorer(training_data)
        results['opportunity_selector'] = self.retrain_opportunity_selector(training_data)
        
        # LSTM retraining only when the win rate is poor
        if global_winrate < 0.40:
            logger.info(f"\n🚨 Win rate {global_winrate*100:.1f}% < 40% → relance LSTM training")
            results['lstm_model'] = self.retrain_lstm_model()
        else:
            results['lstm_model'] = None  # skipped: not needed
            logger.info(f"   ℹ️  LSTM: win rate OK ({global_winrate*100:.1f}%), pas de retrain")
        
        # Update state
        self.last_retrain = datetime.now()
        if MARKET_REGIME_AVAILABLE:
            self.last_market_regime = self._get_current_market_regime()
        
        # History entry
        self.retrain_history.append({
            'timestamp': self.last_retrain.isoformat(),
            'trades_count': training_data['total_trades'],
            'win_rate': global_winrate,
            'regime': self.last_market_regime,
            'results': results
        })
        
        self._save_state()
        
        # Summary — count only the models actually attempted (None = skipped).
        # Previously a skipped LSTM (None) made all(results.values()) False and
        # wrongly downgraded a fully successful run to PARTIAL.
        duration = (datetime.now() - start_time).total_seconds()
        attempted = {name: ok for name, ok in results.items() if ok is not None}
        success_count = sum(1 for ok in attempted.values() if ok)
        
        logger.info("\n" + "="*60)
        logger.info("✅ RETRAINING TERMINÉ")
        logger.info("="*60)
        logger.info(f"Durée: {duration:.1f}s")
        logger.info(f"Modèles: {success_count}/{len(attempted)} réussis")
        logger.info(f"Prochain retrain prévu: {(self.last_retrain + self.retrain_interval).strftime('%Y-%m-%d %H:%M')}")
        
        return {
            'status': 'SUCCESS' if attempted and all(attempted.values()) else 'PARTIAL',
            'results': results,
            'duration': duration,
            'next_retrain': (self.last_retrain + self.retrain_interval).isoformat()
        }
    
    def retrain_lstm_model(self) -> bool:
        """
        Relaunch the full LSTM model training via train_ai_model.py.
        
        Runs the training script in the project's venv as a subprocess
        (2h timeout), then reloads the model into ai_predictor when possible.
        Returns True on success.
        """
        try:
            logger.info("🔄 Retraining LSTM model (train_ai_model.py)...")
            
            script_dir = Path(__file__).parent
            if sys.platform == 'win32':
                venv_python = script_dir / '.venv' / 'Scripts' / 'python.exe'
            else:
                venv_python = script_dir / '.venv' / 'bin' / 'python3'
            train_script = script_dir / 'train_ai_model.py'
            
            if not venv_python.exists():
                logger.error(f"❌ Python venv non trouvé: {venv_python}")
                return False
            
            if not train_script.exists():
                logger.error(f"❌ Script training non trouvé: {train_script}")
                return False
            
            # -u: unbuffered output so logs stream in real time
            cmd = [str(venv_python), '-u', str(train_script), '--epochs', '50']
            
            # Force UTF-8 I/O in the child (emoji-heavy logs on Windows consoles)
            env = os.environ.copy()
            env['PYTHONIOENCODING'] = 'utf-8'
            env['PYTHONUTF8'] = '1'
            
            creation_flags = 0
            if sys.platform == 'win32':
                # Don't pop up a console window on Windows
                creation_flags = subprocess.CREATE_NO_WINDOW
            
            result = subprocess.run(
                cmd,
                cwd=str(script_dir),
                capture_output=True,
                text=True,
                timeout=7200,  # 2h timeout
                encoding='utf-8',
                errors='replace',
                env=env,
                creationflags=creation_flags
            )
            
            # The script may exit non-zero while still reporting completion
            success = result.returncode == 0 or 'terminé' in (result.stdout or '').lower()
            
            if success:
                logger.info("✅ LSTM model retrained successfully")
                
                # Hot-reload the model in ai_predictor when possible
                try:
                    from ai_predictor import get_ai_predictor
                    predictor = get_ai_predictor()
                    if hasattr(predictor, 'load_model'):
                        predictor.load_model()
                        logger.info("✅ Modèle rechargé dans ai_predictor")
                except Exception as e:
                    logger.warning(f"⚠️ Impossible de recharger le modèle: {e}")
            else:
                stderr = result.stderr or ''
                logger.error(f"❌ LSTM training failed: {stderr[:300]}")
            
            return success
            
        except subprocess.TimeoutExpired:
            logger.error("⏰ LSTM training timeout (2h)")
            return False
        except Exception as e:
            logger.error(f"❌ Erreur LSTM retraining: {e}")
            return False
    
    def check_and_retrain(self) -> Optional[Dict]:
        """
        Check whether retraining is needed and run it if so.
        
        Returns:
            The retraining results, or None when no retraining was needed.
        """
        should_train, reason = self.should_retrain()
        
        if not should_train:
            logger.info(f"ℹ️  Retraining non nécessaire: {reason}")
            return None
        
        logger.info(f"🚨 Retraining déclenché: {reason}")
        return self.retrain_all()
    
    def get_status(self) -> Dict:
        """Return a snapshot of the retrainer's current status."""
        should_train, reason = self.should_retrain()
        
        time_since_last = None
        if self.last_retrain:
            time_since_last = (datetime.now() - self.last_retrain).total_seconds() / 3600
        
        return {
            'last_retrain': self.last_retrain.isoformat() if self.last_retrain else None,
            'time_since_last_hours': time_since_last,
            'should_retrain': should_train,
            'reason': reason,
            'current_regime': self.last_market_regime,
            'recent_winrate': self._get_recent_winrate(),
            'history_count': len(self.retrain_history),
            'next_retrain_expected': (self.last_retrain + self.retrain_interval).isoformat() if self.last_retrain else None
        }


# Module-level singleton instance
_retrainer_instance = None

def get_adaptive_retrainer() -> AdaptiveRetrainer:
    """Return the shared AdaptiveRetrainer, creating it on first access."""
    global _retrainer_instance
    # An existing instance is always truthy, so `or` only builds it once.
    _retrainer_instance = _retrainer_instance or AdaptiveRetrainer()
    return _retrainer_instance


if __name__ == "__main__":
    # Manual smoke test: print the retrainer status, then trigger a full
    # retraining pass if one is currently warranted.
    banner = "=" * 60
    print(banner)
    print("🧪 TEST ADAPTIVE RETRAINER")
    print(banner)

    retrainer = get_adaptive_retrainer()

    # Current status snapshot
    status = retrainer.get_status()
    print("\n📊 Status actuel:")
    for key, value in status.items():
        print(f"   • {key}: {value}")

    # Retraining decision
    print("\n🔄 Test retraining...")
    should_train, reason = retrainer.should_retrain()
    print(f"Should retrain: {should_train}")
    print(f"Reason: {reason}")

    if should_train:
        result = retrainer.retrain_all()
        print(f"\n✅ Résultat: {result.get('status')}")