#!/usr/bin/env python3
"""
Ensemble Predictor - combines several signals to improve prediction accuracy.
Weighted vote between an externally supplied LSTM prediction and trend, momentum and volume indicators.
"""

import numpy as np
from typing import Dict, List, Optional, Tuple
import logging

# Module logger; the explicit name keeps existing logging configuration working.
logger = logging.getLogger(name="EnsemblePredictor")

class EnsemblePredictor:
    """
    Combine several prediction models through a weighted vote.

    Models and their default weights (see ``models_weights``):
    - LSTM (40%): prediction computed elsewhere and passed in by the caller
    - Trend (25%): EMA(9) vs EMA(21), confidence from a simplified ADX
    - Momentum (20%): RSI + simplified MACD + simplified Stochastic
    - Volume (15%): On-Balance Volume direction vs. price direction

    Each model votes 0 (down), 1 (neutral) or 2 (up) with a confidence in 0-100.
    """

    def __init__(self):
        # Weight of each model in the vote (they sum to 1.0).
        self.models_weights = {
            'lstm': 0.40,      # 40% - main model
            'trend': 0.25,     # 25% - trend detection
            'momentum': 0.20,  # 20% - momentum indicators
            'volume': 0.15     # 15% - volume analysis
        }

        logger.info("✅ Ensemble Predictor initialisé (4 modèles)")

    def predict_ensemble(self, prices: List[float], volumes: Optional[List[float]] = None,
                         lstm_prediction: Optional[int] = None,
                         lstm_confidence: Optional[float] = None) -> Dict:
        """
        Predict the next move with a weighted vote of all models.

        Args:
            prices: Price history, oldest first (list or 1-D numpy array).
                With fewer than 20 values a neutral 'WEAK' result is returned.
            volumes: Volume history, oldest first. When None or empty, a
                constant volume of 1 per bar is used.
            lstm_prediction: LSTM prediction (0=down, 1=neutral, 2=up).
            lstm_confidence: LSTM confidence (0-100).

        Returns:
            Dict with 'prediction' (0/1/2), 'confidence' (0-100, weighted by
            model weight), 'consensus' ('WEAK'/'MODERATE'/'STRONG'/'VERY_STRONG'),
            'consensus_strength' (0-1), per-model 'details' and a
            'votes_summary' count per class. The LSTM only takes part in the
            vote and in the summary when both ``lstm_prediction`` and
            ``lstm_confidence`` are given.
        """
        # len() instead of truthiness so numpy arrays are accepted as well.
        if prices is None or len(prices) < 20:
            return {
                'prediction': 1,  # Neutral
                'confidence': 50,
                'consensus': 'WEAK',
                'details': {}
            }

        prices_array = np.asarray(prices, dtype=float)
        if volumes is not None and len(volumes) > 0:
            volumes_array = np.asarray(volumes, dtype=float)
        else:
            volumes_array = np.ones(len(prices_array))

        # === MODEL 1: LSTM (if available) ===
        lstm_available = lstm_prediction is not None and lstm_confidence is not None
        if lstm_available:
            lstm_vote = self._lstm_vote(lstm_prediction, lstm_confidence)
        else:
            lstm_vote = {'vote': 1, 'confidence': 50, 'weight': 0}

        # === MODEL 2: TREND FOLLOWING ===
        trend_vote = self._trend_model(prices_array)

        # === MODEL 3: MOMENTUM ===
        momentum_vote = self._momentum_model(prices_array)

        # === MODEL 4: VOLUME ANALYSIS ===
        volume_vote = self._volume_model(prices_array, volumes_array)

        # (vote, weight) for every model that actually votes. A missing LSTM is
        # left out (its placeholder is marked 'weight': 0); counting it would
        # give a fake neutral vote 40% of the ensemble.
        participants = []
        if lstm_available:
            participants.append((lstm_vote, self.models_weights['lstm']))
        participants += [
            (trend_vote, self.models_weights['trend']),
            (momentum_vote, self.models_weights['momentum']),
            (volume_vote, self.models_weights['volume'])
        ]

        # === VOTE AGGREGATION ===
        final_prediction, consensus_strength = self._aggregate_votes(participants)

        # Overall confidence: weighted mean over the participating models,
        # renormalised so that a missing LSTM does not shrink it.
        total_weight = sum(weight for _, weight in participants)
        if total_weight > 0:
            weighted_confidence = sum(
                vote['confidence'] * weight for vote, weight in participants
            ) / total_weight
        else:
            weighted_confidence = 50

        # Consensus classification
        if consensus_strength >= 0.85:
            consensus = 'VERY_STRONG'
        elif consensus_strength >= 0.70:
            consensus = 'STRONG'
        elif consensus_strength >= 0.55:
            consensus = 'MODERATE'
        else:
            consensus = 'WEAK'

        cast_votes = [vote['vote'] for vote, _ in participants]

        return {
            'prediction': final_prediction,
            'confidence': round(weighted_confidence, 1),
            'consensus': consensus,
            'consensus_strength': round(consensus_strength, 2),
            'details': {
                'lstm': lstm_vote,
                'trend': trend_vote,
                'momentum': momentum_vote,
                'volume': volume_vote
            },
            'votes_summary': {
                'baisse': cast_votes.count(0),
                'neutre': cast_votes.count(1),
                'hausse': cast_votes.count(2)
            }
        }

    def _lstm_vote(self, prediction: int, confidence: float) -> Dict:
        """Wrap the caller-supplied LSTM prediction in the common vote format."""
        return {
            'vote': prediction,
            'confidence': confidence,
            'model': 'LSTM'
        }

    def _trend_model(self, prices: np.ndarray) -> Dict:
        """
        Trend model: compares the latest EMA(9) and EMA(21) values.

        The vote is 2 when EMA(9) is above EMA(21), 0 when below, and 1 when
        equal. Confidence comes from ``_simple_adx`` (stronger trend means
        higher confidence).

        Returns:
            Dict with 'vote', 'confidence', 'model' and the indicator values.
        """
        if len(prices) < 20:
            return {'vote': 1, 'confidence': 50, 'model': 'Trend'}

        # Short and long EMA (EMA(21) falls back to the plain mean below 21 values)
        ema_short = self._ema(prices, 9)
        ema_long = self._ema(prices, 21)

        # Trend direction
        if ema_short > ema_long:
            trend_direction = 2  # Up
        elif ema_short < ema_long:
            trend_direction = 0  # Down
        else:
            trend_direction = 1  # Neutral

        # Trend strength (simplified ADX)
        adx = self._simple_adx(prices)

        # Confidence derived from trend strength
        if adx > 40:
            confidence = 85
        elif adx > 25:
            confidence = 70
        elif adx > 15:
            confidence = 55
        else:
            confidence = 40

        return {
            'vote': trend_direction,
            'confidence': confidence,
            'model': 'Trend',
            'adx': round(adx, 1),
            'ema_short': round(ema_short, 8),
            'ema_long': round(ema_long, 8)
        }

    def _momentum_model(self, prices: np.ndarray) -> Dict:
        """
        Momentum model: RSI(14), MACD line (EMA12 - EMA26) and a Stochastic
        computed on the last 14 prices only (no high/low data).

        Each indicator contributes one bullish or bearish signal. The side
        with more signals wins, with confidence 60 + 10 per signal (capped
        at 95). A tie is a neutral vote at 50.
        """
        if len(prices) < 20:
            return {'vote': 1, 'confidence': 50, 'model': 'Momentum'}

        # RSI
        rsi = self._calculate_rsi(prices, 14)

        # MACD line only (no signal line); EMA26 is the plain mean below 26 values
        ema12 = self._ema(prices, 12)
        ema26 = self._ema(prices, 26)
        macd = ema12 - ema26

        # Simplified Stochastic
        stoch = self._stochastic(prices[-14:])

        bullish_signals = 0
        bearish_signals = 0

        # RSI: oversold is bullish, overbought is bearish
        if rsi < 30:
            bullish_signals += 1
        elif rsi > 70:
            bearish_signals += 1

        # MACD sign
        if macd > 0:
            bullish_signals += 1
        elif macd < 0:
            bearish_signals += 1

        # Stochastic: oversold is bullish, overbought is bearish
        if stoch < 20:
            bullish_signals += 1
        elif stoch > 80:
            bearish_signals += 1

        # Decide the vote
        if bullish_signals > bearish_signals:
            vote = 2  # Up
            confidence = 60 + (bullish_signals * 10)
        elif bearish_signals > bullish_signals:
            vote = 0  # Down
            confidence = 60 + (bearish_signals * 10)
        else:
            vote = 1  # Neutral
            confidence = 50

        return {
            'vote': vote,
            'confidence': min(95, confidence),
            'model': 'Momentum',
            'rsi': round(rsi, 1),
            'macd': round(macd, 8),
            'stochastic': round(stoch, 1)
        }

    def _volume_model(self, prices: np.ndarray, volumes: np.ndarray) -> Dict:
        """
        Volume model: direction of OBV over the last 10 bars vs. the price
        direction over the same bars.

        Agreement confirms the price move (confidence 75 with rising volume,
        otherwise 65). Disagreement is read as accumulation (vote 2) or
        distribution (vote 0) at confidence 70.

        Prices and volumes are aligned on their most recent values, so series
        of different lengths are accepted. At least 20 aligned bars are needed;
        otherwise a neutral vote is returned.
        """
        n = min(len(prices), len(volumes))
        if n < 20:
            return {'vote': 1, 'confidence': 50, 'model': 'Volume'}

        # Align both series on the latest bars. With unequal lengths the OBV
        # loop would otherwise index past the end of `volumes`, or pair prices
        # with the wrong bars.
        prices = prices[-n:]
        volumes = volumes[-n:]

        # OBV
        obv = self._calculate_obv(prices, volumes)
        obv_trend = 'UP' if obv[-1] > obv[-10] else 'DOWN'

        # Volume trend: last 5 bars vs. the 15 bars before them
        recent_vol = np.mean(volumes[-5:])
        older_vol = np.mean(volumes[-20:-5])
        vol_increasing = recent_vol > older_vol * 1.2

        # Price direction over the last 10 bars. Only the sign is used, so the
        # raw difference is used to avoid dividing by a zero price.
        price_change = prices[-1] - prices[-10]

        # Vote based on volume/price agreement
        if obv_trend == 'UP' and price_change > 0:
            vote = 2  # Rise confirmed by volume
            confidence = 75 if vol_increasing else 65
        elif obv_trend == 'DOWN' and price_change < 0:
            vote = 0  # Fall confirmed by volume
            confidence = 75 if vol_increasing else 65
        elif obv_trend == 'UP' and price_change < 0:
            vote = 2  # Accumulation (bullish divergence)
            confidence = 70
        elif obv_trend == 'DOWN' and price_change > 0:
            vote = 0  # Distribution (bearish divergence)
            confidence = 70
        else:
            vote = 1  # Neutral (price unchanged over the window)
            confidence = 50

        return {
            'vote': vote,
            'confidence': confidence,
            'model': 'Volume',
            'obv_trend': obv_trend,
            'volume_increasing': vol_increasing
        }

    def _aggregate_votes(self, weighted_votes: List[Tuple[Dict, float]]) -> Tuple[int, float]:
        """
        Aggregate weighted votes.

        Each vote adds ``weight * confidence / 100`` to the score of its class.
        The final prediction is the class with the highest score; ties go to
        the lowest class (dict order 0, 1, 2).

        Returns:
            (final prediction, consensus strength = winning score / total score,
            or 0.33 when every score is zero)
        """
        # Weighted score per class: down, neutral, up
        scores = {0: 0.0, 1: 0.0, 2: 0.0}

        for vote_dict, weight in weighted_votes:
            vote = vote_dict['vote']
            confidence = vote_dict['confidence'] / 100  # normalise to 0-1
            scores[vote] += weight * confidence

        final_prediction = max(scores, key=scores.get)

        total_score = sum(scores.values())
        consensus_strength = scores[final_prediction] / total_score if total_score > 0 else 0.33

        return final_prediction, consensus_strength

    # === HELPER FUNCTIONS ===

    def _ema(self, prices: np.ndarray, period: int) -> float:
        """
        Return the last value of the exponential moving average.

        The EMA is seeded with the simple mean of the first ``period`` prices.
        With fewer than ``period`` prices, the plain mean is returned.
        """
        if len(prices) < period:
            return np.mean(prices)
        multiplier = 2 / (period + 1)
        ema = np.mean(prices[:period])
        for price in prices[period:]:
            ema = (price * multiplier) + (ema * (1 - multiplier))
        return ema

    def _simple_adx(self, prices: np.ndarray, period: int = 14) -> float:
        """
        Simplified ADX-like trend strength in 0-100.

        Over the last ``period`` price changes it computes
        ``|avg_up - avg_down| / (avg_up + avg_down) * 100``, where up/down are
        the positive/negative moves. Returns 20 when there is not enough data
        or no movement at all.
        """
        if len(prices) < period + 1:
            return 20

        # Only the last `period` moves are used, so only those are computed.
        deltas = np.diff(np.asarray(prices[-(period + 1):], dtype=float))
        avg_up = np.mean(np.maximum(deltas, 0))
        avg_down = np.mean(np.maximum(-deltas, 0))

        if avg_up + avg_down == 0:
            return 20

        adx = abs(avg_up - avg_down) / (avg_up + avg_down) * 100
        return min(100, adx)

    def _calculate_rsi(self, prices: np.ndarray, period: int = 14) -> float:
        """
        Relative Strength Index (0-100) with Wilder smoothing.

        Returns 50 when there are fewer than ``period + 1`` prices or when the
        prices never move.
        """
        if len(prices) < period + 1:
            return 50

        deltas = np.diff(prices)
        gains = np.where(deltas > 0, deltas, 0)
        losses = np.where(deltas < 0, -deltas, 0)

        avg_gain = np.mean(gains[:period])
        avg_loss = np.mean(losses[:period])

        for i in range(period, len(deltas)):
            avg_gain = (avg_gain * (period - 1) + gains[i]) / period
            avg_loss = (avg_loss * (period - 1) + losses[i]) / period

        if avg_loss == 0:
            # No losses: fully overbought if there were gains. A perfectly
            # flat series has neither and is neutral, not "overbought".
            return 100 if avg_gain > 0 else 50

        rs = avg_gain / avg_loss
        rsi = 100 - (100 / (1 + rs))

        return rsi

    def _stochastic(self, prices: np.ndarray) -> float:
        """
        Simplified Stochastic %K: position of the last price within the
        min/max of ``prices``, in 0-100.

        Returns 50 for fewer than 2 prices or a flat range.
        """
        if len(prices) < 2:
            return 50
        high = np.max(prices)
        low = np.min(prices)
        current = prices[-1]
        if high == low:
            return 50
        return ((current - low) / (high - low)) * 100

    def _calculate_obv(self, prices: np.ndarray, volumes: np.ndarray) -> np.ndarray:
        """
        On-Balance Volume series.

        Starts at ``volumes[0]``. Each bar then adds its volume when the price
        rose, subtracts it when the price fell, and leaves the total unchanged
        when the price did not move. ``volumes`` is paired with ``prices`` by
        index and must be at least as long.
        """
        obv = np.zeros(len(prices))
        obv[0] = volumes[0]
        for i in range(1, len(prices)):
            if prices[i] > prices[i-1]:
                obv[i] = obv[i-1] + volumes[i]
            elif prices[i] < prices[i-1]:
                obv[i] = obv[i-1] - volumes[i]
            else:
                obv[i] = obv[i-1]
        return obv

    def get_ensemble_bonus(self, ensemble_result: Dict) -> float:
        """
        Score bonus based on consensus strength.

        Only bullish predictions (2) earn a bonus.

        Args:
            ensemble_result: Result of predict_ensemble.

        Returns:
            Bonus: 20 (VERY_STRONG), 15 (STRONG), 10 (MODERATE), otherwise 0.
        """
        consensus = ensemble_result.get('consensus', 'WEAK')
        prediction = ensemble_result.get('prediction', 1)

        # Bonus only for bullish predictions
        if prediction != 2:
            return 0

        if consensus == 'VERY_STRONG':
            return 20
        elif consensus == 'STRONG':
            return 15
        elif consensus == 'MODERATE':
            return 10
        else:
            return 0


# Shared module-level instance, created lazily by get_ensemble_predictor().
_ensemble_predictor = None

def get_ensemble_predictor() -> EnsemblePredictor:
    """Return the shared EnsemblePredictor, building it on the first call."""
    global _ensemble_predictor
    if _ensemble_predictor is not None:
        return _ensemble_predictor
    _ensemble_predictor = EnsemblePredictor()
    return _ensemble_predictor
