#!/usr/bin/env python3
"""
AI Advanced Scorer - Analyse GPU avancée pour scoring des cryptomonnaies
Utilise PyTorch avec CUDA (RTX) pour une analyse multi-dimensionnelle approfondie
"""

import numpy as np
from datetime import datetime, timedelta
from typing import Dict, List, Tuple, Optional
import json
import os
import logging
from dataclasses import dataclass, field
from collections import deque

# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("AIAdvancedScorer")

# PyTorch / GPU detection.
#   TORCH_AVAILABLE: PyTorch was imported (with or without CUDA).
#   DEVICE:          "cuda" when CUDA device 0 could be queried, otherwise "cpu".
#   GPU_NAME:        name of CUDA device 0, or "CPU".
TORCH_AVAILABLE = False
DEVICE = "cpu"
GPU_NAME = "CPU"

# Placeholders so these names exist even when PyTorch is not installed
torch = None
nn = None
F = None

try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
except (ImportError, OSError) as exc:
    # OSError covers installations whose native libraries fail to load.
    torch = nn = F = None
    logger.info("ℹ️  PyTorch non installé - Mode CPU utilisé")
    logger.debug("PyTorch import failed: %s", exc)
    TORCH_AVAILABLE = False
else:
    TORCH_AVAILABLE = True  # PyTorch usable, with or without a GPU
    # FIX 25/03: cap CPU threads to avoid saturating the machine
    torch.set_num_threads(2)
    try:
        torch.set_num_interop_threads(1)
    except RuntimeError:
        pass  # May already have been set by a module imported earlier

    # A broken driver/CUDA runtime must not make the whole module unimportable:
    # fall back to the CPU path instead.
    try:
        if torch.cuda.is_available():
            gpu_name = torch.cuda.get_device_name(0)
            gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1024**3
            DEVICE = "cuda"
            GPU_NAME = gpu_name
            logger.info(f"✅ GPU RTX activé: {GPU_NAME}")
            logger.info(f"   - Mémoire: {gpu_mem:.1f} GB")
            logger.info(f"   - CUDA: {torch.version.cuda}")
    except Exception as exc:
        DEVICE = "cpu"
        GPU_NAME = "CPU"
        logger.warning("CUDA detection failed, using CPU: %s", exc)


@dataclass
class CryptoProfile:
    """Scoring profile of a single crypto asset, filled in by AIAdvancedScorer.

    Component scores use a 0-100 scale; ``to_dict`` returns a rounded,
    JSON-friendly view of the profile.
    """
    symbol: str
    timestamp: datetime = field(default_factory=datetime.now)

    # --- Component scores (0-100) ---
    technical_score: float = 0.0       # EMA / Bollinger / RSI
    momentum_score: float = 0.0        # speed of the price move
    volatility_score: float = 0.0      # volatility seen as opportunity
    volume_score: float = 0.0          # volume confirmation
    pattern_score: float = 0.0         # chart patterns
    trend_score: float = 0.0           # trend direction
    reversal_score: float = 0.0        # reversal timing
    risk_score: float = 0.0            # 0 = risky, 100 = safe

    # --- Aggregates ---
    profit_potential: float = 0.0      # estimated profit potential (%)
    entry_quality: float = 0.0         # quality of the entry point
    final_score: float = 0.0           # weighted final score

    # --- Signal ---
    signal: str = "HOLD"               # per original comment: ACHAT, HOLD, VENTE, NO_BUY
    confidence: float = 0.0            # confidence in the signal (0-100)
    priority: int = 0                  # higher = more urgent

    # --- Context ---
    reasons: List[str] = field(default_factory=list)
    warnings: List[str] = field(default_factory=list)

    def to_dict(self) -> Dict:
        """Return a serialisable dict; component scores are rounded to one decimal."""
        component_scores = (
            ('technical', self.technical_score),
            ('momentum', self.momentum_score),
            ('volatility', self.volatility_score),
            ('volume', self.volume_score),
            ('pattern', self.pattern_score),
            ('trend', self.trend_score),
            ('reversal', self.reversal_score),
            ('risk', self.risk_score),
        )
        payload = {
            'symbol': self.symbol,
            'timestamp': self.timestamp.isoformat(),
        }
        payload['scores'] = {name: round(value, 1) for name, value in component_scores}
        payload['profit_potential'] = round(self.profit_potential, 2)
        payload['entry_quality'] = round(self.entry_quality, 1)
        payload['final_score'] = round(self.final_score, 1)
        payload['signal'] = self.signal
        payload['confidence'] = round(self.confidence, 1)
        payload['priority'] = self.priority
        payload['reasons'] = self.reasons
        payload['warnings'] = self.warnings
        return payload


class TechnicalAnalyzer:
    """Classic technical indicators computed with NumPy on the CPU.

    Every method is static and takes 1-D arrays ordered oldest to newest
    (the last element is the most recent value).
    """

    @staticmethod
    def calculate_ema(prices: np.ndarray, period: int) -> np.ndarray:
        """Return the exponential moving average series of ``prices``.

        Uses the recursion ``ema[i] = a * p[i] + (1 - a) * ema[i-1]`` with
        ``a = 2 / (period + 1)``, seeded with the first price. The result is
        always float64, so integer input is not truncated; empty input returns
        an empty array.
        """
        values = np.asarray(prices, dtype=np.float64)
        ema = np.empty_like(values)
        if values.size == 0:
            return ema
        alpha = 2 / (period + 1)
        decay = 1 - alpha
        ema[0] = values[0]
        for i in range(1, values.size):
            ema[i] = alpha * values[i] + decay * ema[i - 1]
        return ema

    @staticmethod
    def calculate_rsi(prices: np.ndarray, period: int = 14) -> float:
        """Return the RSI over the last ``period`` price changes.

        Gains and losses are averaged with a simple mean (no Wilder smoothing).
        Returns 50.0 (neutral) when there are fewer than ``period + 1`` prices
        or when the window is completely flat, and 100.0 when the window has
        gains but no losses.
        """
        if len(prices) < period + 1:
            return 50.0

        deltas = np.diff(np.asarray(prices, dtype=np.float64))[-period:]
        avg_gain = np.mean(np.where(deltas > 0, deltas, 0))
        avg_loss = np.mean(np.where(deltas < 0, -deltas, 0))

        if avg_loss == 0:
            # No losses: a flat window is neutral, not "extremely overbought".
            return 50.0 if avg_gain == 0 else 100.0

        rs = avg_gain / avg_loss
        return 100 - (100 / (1 + rs))

    @staticmethod
    def calculate_bollinger(prices: np.ndarray, period: int = 20, std_dev: float = 2.0) -> Tuple[float, float, float]:
        """Return ``(upper, middle, lower)`` Bollinger bands of the last ``period`` prices.

        The middle band is the simple mean; the bands are ``std_dev`` population
        standard deviations away. With fewer than ``period`` prices, all three
        values are the last price.
        """
        if len(prices) < period:
            return prices[-1], prices[-1], prices[-1]

        sma = np.mean(prices[-period:])
        std = np.std(prices[-period:])

        upper = sma + std_dev * std
        lower = sma - std_dev * std

        return upper, sma, lower

    @staticmethod
    def calculate_macd(prices: np.ndarray) -> Tuple[float, float, float]:
        """Return the latest ``(macd, signal, histogram)`` values.

        MACD = EMA12 - EMA26 over the whole series; the signal line is the EMA9
        of that MACD series (standard definition); histogram = MACD - signal.
        Returns zeros when fewer than 26 prices are available.
        """
        if len(prices) < 26:
            return 0.0, 0.0, 0.0

        macd_series = (TechnicalAnalyzer.calculate_ema(prices, 12)
                       - TechnicalAnalyzer.calculate_ema(prices, 26))
        signal_series = TechnicalAnalyzer.calculate_ema(macd_series, 9)

        macd_line = float(macd_series[-1])
        signal_line = float(signal_series[-1])
        return macd_line, signal_line, macd_line - signal_line

    @staticmethod
    def calculate_atr(highs: np.ndarray, lows: np.ndarray, closes: np.ndarray, period: int = 14) -> float:
        """Return the Average True Range: mean of the last ``period`` true ranges.

        True range = max(high - low, |high - prev_close|, |low - prev_close|).
        Returns 0.0 when fewer than ``period + 1`` bars are available.
        """
        if len(highs) < period + 1:
            return 0.0

        highs = np.asarray(highs, dtype=np.float64)
        lows = np.asarray(lows, dtype=np.float64)
        closes = np.asarray(closes, dtype=np.float64)
        prev_close = closes[:-1]

        # np.maximum is binary: a third positional argument would be taken as
        # `out`, silently dropping the |low - prev_close| term (gap-down bars).
        true_range = np.maximum(
            np.maximum(highs[1:] - lows[1:], np.abs(highs[1:] - prev_close)),
            np.abs(lows[1:] - prev_close),
        )

        return float(np.mean(true_range[-period:]))


class GPUFeatureExtractor:
    """Extract the indicator features consumed by AIAdvancedScorer.

    When PyTorch reports a CUDA device, vectorisable work (tail means,
    momentum, return volatility, Bollinger statistics) runs on the GPU.
    Otherwise, or if the GPU path raises, an equivalent NumPy path is used.
    Both paths emit the same feature keys with the same formulas (population
    standard deviation, recursive EMA, 20-bar pattern window, 20-bar volume
    window), so scores do not depend on the hardware the process runs on.
    """

    SMA_PERIODS = (5, 9, 12, 21, 26, 50)
    EMA_PERIODS = (9, 12, 21, 26)
    MOMENTUM_PERIODS = (3, 5, 10, 20)
    BB_PERIOD = 20
    PATTERN_WINDOW = 20
    VOLUME_WINDOW = 20
    MIN_CANDLES = 50

    def __init__(self):
        # torch.device when PyTorch is importable, else None (NumPy only).
        self.device = torch.device(DEVICE) if TORCH_AVAILABLE else None

    def extract_all_features(self, prices: List[float], volumes: List[float] = None) -> Dict:
        """Compute every feature for a price series.

        Args:
            prices: Prices, most recent last. With fewer than 50 values the
                neutral defaults of ``_get_default_features`` are returned.
            volumes: Optional volume history (presumably aligned with
                ``prices`` — confirm with callers); at least 20 values are
                needed for the volume features.

        Returns:
            Dict mapping feature name to value.
        """
        if len(prices) < self.MIN_CANDLES:
            return self._get_default_features()

        prices_np = np.array(prices, dtype=np.float64)

        if TORCH_AVAILABLE and self.device is not None and self.device.type == 'cuda':
            return self._extract_gpu(prices_np, volumes)
        return self._extract_cpu(prices_np, volumes)

    def _extract_gpu(self, prices: np.ndarray, volumes: List[float] = None) -> Dict:
        """CUDA implementation; on any error, logs and falls back to ``_extract_cpu``."""
        try:
            prices_tensor = torch.tensor(prices, dtype=torch.float32, device=self.device)
            n = len(prices)
            features = {}

            # === SMA === only the latest value is used: the mean of the last `period` prices
            for period in self.SMA_PERIODS:
                if n >= period:
                    features[f'sma_{period}'] = torch.mean(prices_tensor[-period:]).item()

            # === EMA + 2-bar slope (%) ===
            for period in self.EMA_PERIODS:
                if n >= period:
                    ema = self._gpu_ema(prices_tensor, period)
                    features[f'ema_{period}'] = ema[-1].item()
                    if len(ema) >= 3:
                        features[f'ema_{period}_slope'] = ((ema[-1] - ema[-3]) / ema[-3] * 100).item()

            # === MOMENTUM (%) ===
            for period in self.MOMENTUM_PERIODS:
                if n >= period:
                    reference = prices_tensor[-period]
                    features[f'momentum_{period}'] = ((prices_tensor[-1] - reference) / reference * 100).item()

            # === VOLATILITY === population std of % returns, same as np.std in the CPU path
            returns = (prices_tensor[1:] - prices_tensor[:-1]) / prices_tensor[:-1] * 100
            features['volatility_5'] = self._population_std(returns[-5:]) if len(returns) >= 5 else 0
            features['volatility_20'] = self._population_std(returns[-20:]) if len(returns) >= 20 else 0

            # === BOLLINGER BANDS === (population std, like TechnicalAnalyzer.calculate_bollinger)
            if n >= self.BB_PERIOD:
                window = prices_tensor[-self.BB_PERIOD:]
                bb_mid = torch.mean(window).item()
                bb_std = self._population_std(window)
                features.update(self._bollinger_features(
                    prices[-1], bb_mid + 2 * bb_std, bb_mid, bb_mid - 2 * bb_std))

            # === RSI / MACD === (sequential indicators, computed with NumPy)
            features['rsi'] = TechnicalAnalyzer.calculate_rsi(prices)
            macd, signal, hist = TechnicalAnalyzer.calculate_macd(prices)
            features['macd'] = macd
            features['macd_signal'] = signal
            features['macd_histogram'] = hist

            # === PATTERNS === the squeeze flag needs the Bollinger bandwidth computed above
            features.update(self._detect_patterns_gpu(prices_tensor, features.get('bb_bandwidth', 5)))

            # === VOLUME ===
            features.update(self._volume_features(volumes))

            return features

        except Exception as e:
            logger.warning(f"Erreur GPU, fallback CPU: {e}")
            return self._extract_cpu(prices, volumes)

    @staticmethod
    def _population_std(values) -> float:
        """Population (ddof=0) standard deviation of a 1-D tensor, matching ``np.std``."""
        centered = values - torch.mean(values)
        return torch.sqrt(torch.mean(centered * centered)).item()

    @staticmethod
    def _bollinger_features(last_price: float, upper: float, mid: float, lower: float) -> Dict:
        """Bollinger keys: bands, bandwidth (% of mid) and position (0 = lower band, 1 = upper band)."""
        width = upper - lower
        return {
            'bb_upper': upper,
            'bb_mid': mid,
            'bb_lower': lower,
            'bb_bandwidth': width / mid * 100 if mid > 0 else 0,
            'bb_position': (last_price - lower) / width if width > 0 else 0.5,
        }

    def _gpu_ema(self, prices, period: int):
        """EMA of a 1-D tensor, returned as a tensor with the input's dtype and device.

        The EMA recursion is sequential; running it element by element on the
        device costs one tiny GPU operation per price, so it is evaluated with
        NumPy on the host and moved back.
        """
        host_values = prices.detach().cpu().numpy().astype(np.float64)
        ema = TechnicalAnalyzer.calculate_ema(host_values, period)
        return torch.tensor(ema, dtype=prices.dtype, device=prices.device)

    def _detect_patterns_gpu(self, prices, bb_bandwidth: float = 5) -> Dict:
        """Pattern detection for a 1-D price tensor (see ``_detect_patterns_cpu``).

        The 20-bar window is copied to the host and analysed with the NumPy
        implementation, so the CPU and GPU paths agree exactly.

        Args:
            prices: 1-D tensor of prices, most recent last.
            bb_bandwidth: Bollinger bandwidth (%) used for ``squeeze_active``;
                the default 5 leaves the squeeze flag off.
        """
        if len(prices) < self.PATTERN_WINDOW:
            return {}
        tail = prices[-self.PATTERN_WINDOW:].detach().cpu().numpy()
        return self._detect_patterns_cpu(tail, bb_bandwidth)

    def _detect_patterns_cpu(self, prices: np.ndarray, bb_bandwidth: float = 5) -> Dict:
        """Detect simple chart patterns on the last ``PATTERN_WINDOW`` prices.

        Keys: local_high / local_low / price_range (% of the low),
        trend_slope (least-squares slope, price units per bar) and
        trend_strength (|slope| * 100), double_bottom / double_top (the two
        halves' extremes within 1% of each other), squeeze_active
        (bandwidth < 3) and near_resistance / near_support (last price within
        1% of the window's high / low). Returns {} for short input.
        """
        if len(prices) < self.PATTERN_WINDOW:
            return {}

        recent = np.asarray(prices[-self.PATTERN_WINDOW:], dtype=np.float64)
        local_high = float(np.max(recent))
        local_low = float(np.min(recent))
        current = float(recent[-1])

        # Least-squares slope over the window
        x = np.arange(len(recent), dtype=np.float64)
        x_centered = x - x.mean()
        slope = float(np.sum(x_centered * (recent - recent.mean())) / np.sum(x_centered ** 2))

        mid = len(recent) // 2
        first_low, second_low = float(np.min(recent[:mid])), float(np.min(recent[mid:]))
        first_high, second_high = float(np.max(recent[:mid])), float(np.max(recent[mid:]))

        # Ratios are guarded against zero denominators instead of raising.
        return {
            'local_high': local_high,
            'local_low': local_low,
            'price_range': (local_high - local_low) / local_low * 100 if local_low > 0 else 0.0,
            'trend_slope': slope,
            'trend_strength': abs(slope) * 100,
            'double_bottom': bool(first_low > 0 and abs(first_low - second_low) / first_low < 0.01),
            'double_top': bool(first_high > 0 and abs(first_high - second_high) / first_high < 0.01),
            'squeeze_active': bool(bb_bandwidth < 3),
            'near_resistance': bool(current > 0 and (local_high - current) / current < 0.01),
            'near_support': bool(current > 0 and (current - local_low) / current < 0.01),
        }

    def _volume_features(self, volumes) -> Dict:
        """Volume keys over the last ``VOLUME_WINDOW`` bars (neutral values if unavailable).

        volume_sma: mean volume; volume_ratio: last volume / mean;
        volume_trend: % change of the last volume vs the one 4 bars earlier.
        """
        if volumes is None or len(volumes) < self.VOLUME_WINDOW:
            return {'volume_sma': 0, 'volume_ratio': 1, 'volume_trend': 0}

        window = np.asarray(volumes[-self.VOLUME_WINDOW:], dtype=np.float64)
        volume_sma = float(np.mean(window))
        last, earlier = float(window[-1]), float(window[-5])
        return {
            'volume_sma': volume_sma,
            'volume_ratio': last / volume_sma if volume_sma > 0 else 1,
            'volume_trend': (last - earlier) / earlier * 100 if earlier > 0 else 0,
        }

    def _extract_cpu(self, prices: np.ndarray, volumes: List[float] = None) -> Dict:
        """NumPy implementation producing the same keys as ``_extract_gpu``."""
        features = {}
        n = len(prices)

        # === SMA ===
        for period in self.SMA_PERIODS:
            if n >= period:
                features[f'sma_{period}'] = float(np.mean(prices[-period:]))

        # === EMA + 2-bar slope (%) ===
        for period in self.EMA_PERIODS:
            if n >= period:
                ema = TechnicalAnalyzer.calculate_ema(prices, period)
                features[f'ema_{period}'] = ema[-1]
                if len(ema) >= 3:
                    features[f'ema_{period}_slope'] = (ema[-1] - ema[-3]) / ema[-3] * 100

        # === MOMENTUM (%) ===
        for period in self.MOMENTUM_PERIODS:
            if n >= period:
                features[f'momentum_{period}'] = (prices[-1] - prices[-period]) / prices[-period] * 100

        # === BOLLINGER ===
        bb_upper, bb_mid, bb_lower = TechnicalAnalyzer.calculate_bollinger(prices, self.BB_PERIOD)
        features.update(self._bollinger_features(prices[-1], bb_upper, bb_mid, bb_lower))

        # === RSI / MACD ===
        features['rsi'] = TechnicalAnalyzer.calculate_rsi(prices)
        macd, signal, hist = TechnicalAnalyzer.calculate_macd(prices)
        features['macd'] = macd
        features['macd_signal'] = signal
        features['macd_histogram'] = hist

        # === VOLATILITY === population std of % returns
        returns = np.diff(prices) / prices[:-1] * 100
        features['volatility_5'] = np.std(returns[-5:]) if len(returns) >= 5 else 0
        features['volatility_20'] = np.std(returns[-20:]) if len(returns) >= 20 else 0

        # === PATTERNS / VOLUME === (previously GPU-only)
        features.update(self._detect_patterns_cpu(prices, features['bb_bandwidth']))
        features.update(self._volume_features(volumes))

        return features

    def _get_default_features(self) -> Dict:
        """Neutral features returned when there is not enough data."""
        return {
            'ema_9': 0, 'ema_21': 0,
            'momentum_3': 0, 'momentum_5': 0,
            'bb_position': 0.5, 'bb_bandwidth': 5,
            'rsi': 50, 'volatility_5': 0
        }


class AIAdvancedScorer:
    """Scoring IA avancé avec analyse multi-dimensionnelle"""
    
    def __init__(self):
        self.device = DEVICE  # Exposer le device utilisé
        self.feature_extractor = GPUFeatureExtractor()
        self.profiles_cache: Dict[str, CryptoProfile] = {}
        self.historical_performance: Dict[str, List[float]] = {}
        
        # === PONDÉRATIONS OPTIMISÉES POUR PROFIT ===
        self.weights = {
            'technical': 0.20,    # Analyse technique classique
            'momentum': 0.15,     # Vitesse de mouvement
            'volatility': 0.10,   # Opportunité de mouvement
            'volume': 0.10,       # Confirmation
            'pattern': 0.15,      # Formations graphiques
            'trend': 0.10,        # Direction générale
            'reversal': 0.15,     # Timing de retournement (CRITIQUE)
            'risk': 0.05          # Gestion du risque
        }
        
        logger.info(f"🧠 AIAdvancedScorer initialisé - Device: {DEVICE}")
        if DEVICE == 'cuda':
            logger.info(f"   GPU: {GPU_NAME}")
    
    def analyze_crypto(self, symbol: str, prices: List[float], 
                       volumes: List[float] = None,
                       current_price: float = None) -> CryptoProfile:
        """Analyse complète d'une crypto et génère son profil"""
        
        profile = CryptoProfile(symbol=symbol)
        
        if len(prices) < 50:
            profile.warnings.append("Données insuffisantes (<50 bougies)")
            return profile
        
        # Extraire les features
        features = self.feature_extractor.extract_all_features(prices, volumes)
        
        if current_price is None:
            current_price = prices[-1]
        
        # === 1. SCORE TECHNIQUE (EMA, BB, RSI) ===
        profile.technical_score = self._calculate_technical_score(features, current_price)
        
        # === 2. SCORE MOMENTUM ===
        profile.momentum_score = self._calculate_momentum_score(features)
        
        # === 3. SCORE VOLATILITÉ ===
        profile.volatility_score = self._calculate_volatility_score(features)
        
        # === 4. SCORE VOLUME ===
        profile.volume_score = self._calculate_volume_score(features)
        
        # === 5. SCORE PATTERN ===
        profile.pattern_score = self._calculate_pattern_score(features)
        
        # === 6. SCORE TENDANCE ===
        profile.trend_score = self._calculate_trend_score(features)
        
        # === 7. SCORE RETOURNEMENT (CRITIQUE) ===
        profile.reversal_score = self._calculate_reversal_score(features, prices)
        
        # === 8. SCORE RISQUE ===
        profile.risk_score = self._calculate_risk_score(features)
        
        # === CALCUL DU SCORE FINAL ===
        profile.final_score = self._calculate_final_score(profile)
        
        # === POTENTIEL DE PROFIT ===
        profile.profit_potential = self._estimate_profit_potential(features, profile)
        
        # === QUALITÉ DU POINT D'ENTRÉE ===
        profile.entry_quality = self._calculate_entry_quality(features, profile)
        
        # === GÉNÉRATION DU SIGNAL ===
        self._generate_signal(profile, features)
        
        # === CACHE ===
        self.profiles_cache[symbol] = profile
        
        return profile
    
    def _calculate_technical_score(self, features: Dict, price: float) -> float:
        """Score basé sur EMA, BB, RSI — FIX 01/04: pénaliser le bearish, pas le récompenser"""
        score = 50  # Base neutre
        
        # EMA Configuration
        ema9 = features.get('ema_9', price)
        ema21 = features.get('ema_21', price)
        ema_diff = (ema9 - ema21) / ema21 * 100 if ema21 > 0 else 0
        ema9_slope = features.get('ema_9_slope', features.get('ema_slope', 0))
        mom3 = features.get('momentum_3', 0)
        
        # 🔴 FIX 01/04: EMA bearish = PÉNALITÉ, pas bonus
        # Avant: ema_diff < -0.1 → +20 (traitait bearish comme "opportunité")
        # Maintenant: bearish = pénalité, SAUF si reversal confirmé (EMA slope + momentum)
        if ema_diff < -0.1:
            if ema9_slope > 0 and mom3 > 0:
                score += 10  # Bearish MAIS reversal en cours → petit bonus
            else:
                score -= 15  # Bearish sans reversal → pénalité
        elif ema_diff < 0:
            score += 5   # Légèrement bearish, neutre
        elif ema_diff > 0.5:
            score -= 15  # Tendance déjà en cours, risque de correction
        
        # ═══════════════════════════════════════════════════════════════════════
        # BONUS CROISEMENT IMMINENT: EMA9 sous EMA21 mais en train de remonter
        # C'est la configuration IDÉALE pour acheter!
        # ═══════════════════════════════════════════════════════════════════════
        if ema_diff < 0 and ema_diff > -0.5 and ema9_slope > 0 and mom3 > 0:
            score += 20  # Croisement haussier imminent!
        
        # Position BB — FIX 01/04: bonus conditionné au momentum
        # BB basse + prix qui continue de chuter = falling knife, PAS zone de rebond
        bb_pos = features.get('bb_position', 0.5)
        if bb_pos < 0.2:
            if mom3 > 0:
                score += 20  # BB basse + momentum positif = vrai rebond
            else:
                score += 5   # BB basse mais chute continue → prudence
        elif bb_pos < 0.3:
            if mom3 > 0:
                score += 10
            else:
                score += 3
        elif bb_pos > 0.8:
            score -= 20  # Surachat
        
        # RSI — FIX 01/04: oversold sans reversal = bearish, pas bullish
        rsi = features.get('rsi', 50)
        if rsi < 30:
            if mom3 > 0 and ema9_slope > 0:
                score += 15  # Survendu + reversal confirmé = vrai signal
            elif mom3 > 0:
                score += 5   # Survendu + momentum positif = début de reversal
            else:
                score -= 10  # Survendu + chute continue = falling knife
        elif rsi < 40:
            if mom3 > 0:
                score += 10
            else:
                score -= 5   # RSI bas sans momentum = bearish
        elif rsi > 80:
            score -= 30  # Surachat extrême - pénalité forte
        elif rsi > 70:
            score -= 20  # Surachat - pénalité augmentée (était -15)
        
        return max(0, min(100, score))
    
    def _calculate_momentum_score(self, features: Dict) -> float:
        """Score basé sur la vitesse de mouvement"""
        score = 50
        
        mom3 = features.get('momentum_3', 0)
        mom5 = features.get('momentum_5', 0)
        mom10 = features.get('momentum_10', 0)
        
        # Retournement positif (momentum court terme remonte)
        if mom3 > 0 and mom5 < 0:
            score += 25  # Début de retournement haussier
        elif mom3 > mom5 > 0:
            score += 15  # Accélération haussière
        elif mom3 < -0.5:
            score -= 20  # Chute en cours
        
        # Momentum court terme positif
        if mom3 > 0.2:
            score += 10
        elif mom3 > 0:
            score += 5
        
        # Stabilisation après baisse (bon signe)
        if -0.1 < mom3 < 0.1 and mom5 < -0.3:
            score += 20  # Prix qui se stabilise après une baisse
        
        return max(0, min(100, score))
    
    def _calculate_volatility_score(self, features: Dict) -> float:
        """Score basé sur la volatilité (opportunité)"""
        score = 50
        
        vol5 = features.get('volatility_5', 0)
        vol20 = features.get('volatility_20', 0)
        bb_bandwidth = features.get('bb_bandwidth', 5)
        
        # Squeeze (faible volatilité) = explosion potentielle
        if bb_bandwidth < 2:
            score += 30  # Squeeze très serré
        elif bb_bandwidth < 3:
            score += 20
        elif bb_bandwidth < 4:
            score += 10
        elif bb_bandwidth > 8:
            score -= 15  # Trop volatile
        
        # Volatilité qui augmente = mouvement en cours
        if vol5 > vol20 * 1.5:
            score += 10  # Volatilité en expansion
        
        return max(0, min(100, score))
    
    def _calculate_volume_score(self, features: Dict) -> float:
        """Score basé sur le volume"""
        score = 50
        
        vol_ratio = features.get('volume_ratio', 1)
        vol_trend = features.get('volume_trend', 0)
        bb_pos = features.get('bb_position', 0.5)
        mom3 = features.get('momentum_3', 0)
        
        # Volume élevé = confirmation
        if vol_ratio > 2:
            score += 25  # Volume exceptionnel
        elif vol_ratio > 1.5:
            score += 15
        elif vol_ratio > 1.2:
            score += 10
        elif vol_ratio < 0.5:
            score -= 15  # Faible volume = méfiance
        
        # Volume en augmentation
        if vol_trend > 50:
            score += 15
        elif vol_trend > 20:
            score += 10
        
        # ═══════════════════════════════════════════════════════════════════════
        # VOLUME_REVERSAL: Rebond sur BB basse avec volume = signal TRÈS fort!
        # C'est exactement ce qu'on cherche: prix bas + volume = accumulation
        # ═══════════════════════════════════════════════════════════════════════
        is_volume_reversal = (
            bb_pos < 0.35 and          # Prix proche BB basse
            vol_ratio > 1.3 and        # Volume supérieur à la moyenne
            mom3 > 0                   # Prix en hausse
        )
        if is_volume_reversal:
            score += 25  # GROS BONUS pour volume reversal!
        
        return max(0, min(100, score))
    
    def _calculate_pattern_score(self, features: Dict) -> float:
        """Score basé sur les patterns détectés"""
        score = 50
        
        # Double bottom (figure de retournement haussier)
        if features.get('double_bottom', False):
            score += 25
        
        # Double top (figure de retournement baissier)
        if features.get('double_top', False):
            score -= 20
        
        # Proche support
        if features.get('near_support', False):
            score += 20
        
        # Proche résistance
        if features.get('near_resistance', False):
            score -= 10
        
        # Squeeze actif
        if features.get('squeeze_active', False):
            score += 15
        
        return max(0, min(100, score))
    
    def _calculate_trend_score(self, features: Dict) -> float:
        """Score basé sur la tendance"""
        score = 50
        
        trend_slope = features.get('trend_slope', 0)
        ema9_slope = features.get('ema_9_slope', 0)
        ema21_slope = features.get('ema_21_slope', 0)
        
        # Tendance globale
        if trend_slope > 0.01:
            score += 15  # Tendance haussière
        elif trend_slope < -0.01:
            score -= 10  # Tendance baissière
        
        # EMA9 en hausse
        if ema9_slope > 0.1:
            score += 15
        elif ema9_slope > 0:
            score += 10
        elif ema9_slope < -0.2:
            score -= 15
        
        # EMA21 stable ou haussière
        if ema21_slope and ema21_slope >= 0:
            score += 10
        
        return max(0, min(100, score))
    
    def _calculate_reversal_score(self, features: Dict, prices: List[float]) -> float:
        """Score de retournement — FIX 01/04: exiger des PREUVES de reversal"""
        score = 50
        
        bb_pos = features.get('bb_position', 0.5)
        mom3 = features.get('momentum_3', 0)
        mom5 = features.get('momentum_5', 0)
        ema9 = features.get('ema_9', 0)
        ema21 = features.get('ema_21', 0)
        ema_diff = ema9 - ema21
        ema_diff_pct = ema_diff / ema21 * 100 if ema21 > 0 else 0
        ema9_slope = features.get('ema_9_slope', features.get('ema_slope', 0))
        rsi = features.get('rsi', 50)
        
        # === CONDITIONS DE RETOURNEMENT — PREUVES REQUISES ===
        # 🔴 FIX 01/04: être bas (BB/EMA/RSI) N'EST PAS un reversal.
        # Un reversal = être bas ET remonter (momentum positif, EMA slope positive).
        
        has_positive_momentum = mom3 > 0
        has_ema_turning = ema9_slope > 0
        has_reversal_evidence = has_positive_momentum or has_ema_turning
        
        # 1. Prix proche de la BB basse + preuve de rebond
        if bb_pos < 0.15:
            if has_reversal_evidence:
                score += 25  # Vrai rebond depuis BB basse
            else:
                score += 5   # Juste bas, pas encore de reversal
        elif bb_pos < 0.25:
            if has_reversal_evidence:
                score += 15
            else:
                score += 3
        
        # 2. EMA9 sous EMA21 + preuve de rapprochement
        if ema_diff_pct < -0.1:
            if has_ema_turning and has_positive_momentum:
                score += 20  # EMA bearish mais en train de se retourner
            elif has_ema_turning:
                score += 10
            # Sinon: pas de bonus — juste bearish
        elif ema_diff_pct < 0:
            if has_ema_turning:
                score += 10
        
        # ═══════════════════════════════════════════════════════════════════════
        # 2b. CROISEMENT EMA IMMINENT - Le moment IDÉAL d'achat!
        # ═══════════════════════════════════════════════════════════════════════
        is_crossover_imminent = (
            ema_diff_pct < 0 and
            ema_diff_pct > -0.5 and
            ema9_slope > 0 and
            mom3 > 0
        )
        
        if is_crossover_imminent:
            score += 30  # GROS BONUS pour croisement imminent!
            
        # Croisement très proche (EMA_diff entre -0.2% et 0)
        if -0.2 < ema_diff_pct < 0 and ema9_slope > 0:
            score += 15  # Bonus supplémentaire
        
        # 3. Momentum qui se stabilise ou remonte
        if mom3 > 0 and mom5 < 0:
            score += 25  # Retournement en cours !
        elif mom3 > mom5 and mom3 > 0:
            score += 15  # Momentum qui s'améliore ET positif
        elif -0.1 < mom3 < 0.1 and mom5 < -0.3:
            score += 10  # Stabilisation après baisse (réduit de 20→10)
        
        # 4. RSI survendu + preuve de remontée
        if rsi < 35:
            if has_positive_momentum:
                score += 15
                if len(prices) >= 3:
                    prev_rsi = TechnicalAnalyzer.calculate_rsi(np.array(prices[:-1]))
                    if rsi > prev_rsi:
                        score += 10  # RSI en divergence haussière
            # Sinon: RSI bas sans momentum = pas un signal de reversal
        
        # 5. MACD histogram qui remonte
        macd_hist = features.get('macd_histogram', 0)
        if macd_hist > 0:
            score += 10
        elif macd_hist > -0.001:  # Proche de 0, croisement imminent
            score += 15
        
        return max(0, min(100, score))
    
    def _calculate_risk_score(self, features: Dict) -> float:
        """Score de risque (100 = peu risqué, 0 = très risqué)"""
        score = 100  # Commence à 100 et on soustrait les risques
        
        # Volatilité excessive
        vol = features.get('volatility_5', 0)
        if vol > 3:
            score -= 30
        elif vol > 2:
            score -= 20
        elif vol > 1.5:
            score -= 10
        
        # RSI extrême
        rsi = features.get('rsi', 50)
        if rsi > 80 or rsi < 20:
            score -= 20
        elif rsi > 70 or rsi < 30:
            score -= 10
        
        # Bandwidth BB très large (volatilité)
        bb_bw = features.get('bb_bandwidth', 5)
        if bb_bw > 10:
            score -= 25
        elif bb_bw > 7:
            score -= 15
        
        # Momentum très négatif
        mom5 = features.get('momentum_5', 0)
        if mom5 < -2:
            score -= 30
        elif mom5 < -1:
            score -= 20
        elif mom5 < -0.5:
            score -= 10
        
        return max(0, min(100, score))
    
    def _calculate_final_score(self, profile: CryptoProfile) -> float:
        """Calcule le score final pondéré"""
        score = (
            profile.technical_score * self.weights['technical'] +
            profile.momentum_score * self.weights['momentum'] +
            profile.volatility_score * self.weights['volatility'] +
            profile.volume_score * self.weights['volume'] +
            profile.pattern_score * self.weights['pattern'] +
            profile.trend_score * self.weights['trend'] +
            profile.reversal_score * self.weights['reversal'] +
            profile.risk_score * self.weights['risk']
        )
        return score
    
    def _estimate_profit_potential(self, features: Dict, profile: CryptoProfile) -> float:
        """Estime le potentiel de profit en %"""
        
        # Base: position dans les BB
        bb_pos = features.get('bb_position', 0.5)
        
        # Plus on est bas dans les BB, plus le potentiel est élevé
        if bb_pos < 0.2:
            potential = 3.0  # 3% potentiel si très bas
        elif bb_pos < 0.3:
            potential = 2.5
        elif bb_pos < 0.4:
            potential = 2.0
        elif bb_pos < 0.5:
            potential = 1.5
        else:
            potential = 1.0
        
        # Bonus si squeeze (explosion attendue)
        if features.get('squeeze_active', False) or features.get('bb_bandwidth', 5) < 3:
            potential *= 1.3
        
        # Bonus si retournement détecté
        if profile.reversal_score > 70:
            potential *= 1.2
        
        # Malus si risque élevé
        if profile.risk_score < 50:
            potential *= 0.7
        
        return round(potential, 2)
    
    def _calculate_entry_quality(self, features: Dict, profile: CryptoProfile) -> float:
        """Évalue la qualité du point d'entrée actuel"""
        quality = 50
        
        bb_pos = features.get('bb_position', 0.5)
        ema_diff = features.get('ema_9', 0) - features.get('ema_21', 0)
        ema_diff_pct = ema_diff / features.get('ema_21', 1) * 100 if features.get('ema_21', 0) > 0 else 0
        
        # Proche BB basse = excellent point d'entrée
        if bb_pos < 0.2:
            quality += 25
        elif bb_pos < 0.3:
            quality += 15
        elif bb_pos > 0.7:
            quality -= 20
        
        # EMA9 < EMA21 = bon point d'entrée
        if ema_diff_pct < -0.1:
            quality += 20
        elif ema_diff_pct < 0:
            quality += 10
        elif ema_diff_pct > 0.3:
            quality -= 15
        
        # Score de retournement élevé
        if profile.reversal_score > 70:
            quality += 15
        elif profile.reversal_score > 60:
            quality += 10
        
        # Momentum stabilisé
        mom3 = features.get('momentum_3', 0)
        if -0.2 < mom3 < 0.5:
            quality += 10
        elif mom3 < -0.5:
            quality -= 15
        
        return max(0, min(100, quality))
    
    def _generate_signal(self, profile: CryptoProfile, features: Dict):
        """Génère le signal final et les raisons"""
        
        ema9 = features.get('ema_9', 0)
        ema21 = features.get('ema_21', 0)
        ema_diff_pct = (ema9 - ema21) / ema21 * 100 if ema21 > 0 else 0
        bb_pos = features.get('bb_position', 0.5)
        mom3 = features.get('momentum_3', 0)
        ema_slope = features.get('ema_9_slope', features.get('ema_slope', 0))
        rsi = features.get('rsi', 50)
        
        # ═══════════════════════════════════════════════════════════════════════
        # DÉTECTION CROISEMENT EMA IMMINENT - Signal d'achat PRIORITAIRE!
        # Quand EMA9 < EMA21 mais EMA9 remonte avec momentum positif
        # = Les EMAs vont se croiser = MEILLEUR moment pour acheter
        # ASSOUPLI 29/12 18h45: Accepter Mom3 > -0.10% au lieu de > 0
        # ═══════════════════════════════════════════════════════════════════════
        is_crossover_imminent = (
            ema_diff_pct < 0 and                    # EMA9 encore sous EMA21
            ema_diff_pct > -0.5 and                 # Proche du croisement (< 0.5%)
            ema_slope > 0 and                       # EMA9 en hausse
            mom3 > -0.001 and                       # Momentum > -0.10% (accepter léger négatif)
            rsi < 70                                # Pas en surachat
        )
        
        # === CONDITIONS D'ACHAT AU CREUX ===
        buy_conditions = [
            profile.final_score >= 65,
            profile.entry_quality >= 60,
            ema_diff_pct < -0.05,  # EMA9 sous EMA21
            profile.risk_score >= 40
        ]
        
        buy_optimal = [
            profile.final_score >= 75,
            profile.reversal_score >= 70,
            bb_pos < 0.3,
            profile.profit_potential >= 2.0
        ]
        
        # === DÉCISION ===
        
        # ═══════════════════════════════════════════════════════════════════════
        # PRIORITÉ 0: CROISEMENT EMA IMMINENT = ACHAT (même avec score plus bas)
        # C'est le signal le plus fiable car on achète JUSTE AVANT la confirmation
        # ═══════════════════════════════════════════════════════════════════════
        if is_crossover_imminent and profile.final_score >= 55 and profile.risk_score >= 40:
            profile.signal = "ACHAT"
            profile.confidence = min(90, profile.final_score + 15)
            profile.priority = 0  # PRIORITÉ MAXIMALE
            profile.reasons.append(f"🔥 CROISEMENT EMA IMMINENT (EMA_diff={ema_diff_pct:.2f}%, Mom3={mom3:.2f}%)")
        
        # BLOQUER tout achat si EMA9 > EMA21 (correction après pic)
        elif ema_diff_pct >= 0:
            # EMA9 >= EMA21 = on est dans une zone de correction potentielle
            if mom3 < 0:
                profile.signal = "NO_BUY"
                profile.confidence = 90
                profile.priority = 10
                profile.reasons.append("🚫 EMA9 > EMA21 + Momentum négatif = Correction en cours")
            else:
                profile.signal = "POSSIBLE"  # Changé de NO_BUY à POSSIBLE si momentum positif
                profile.confidence = profile.final_score * 0.7
                profile.priority = 4
                profile.reasons.append("⏳ EMA9 > EMA21 mais momentum positif = Surveiller")
        
        # Configuration optimale au CREUX uniquement
        elif all(buy_optimal) and ema_diff_pct < 0:
            profile.signal = "ACHAT"
            profile.confidence = min(95, profile.final_score + 10)
            profile.priority = 1  # Haute priorité
            profile.reasons.append("🔥 Configuration optimale au creux détectée")
            
        # Conditions d'achat au CREUX
        elif all(buy_conditions) and ema_diff_pct < 0:
            profile.signal = "ACHAT"
            profile.confidence = profile.final_score
            profile.priority = 2
            profile.reasons.append("✅ Conditions d'achat au creux remplies")
            
        # Opportunité potentielle au creux
        elif sum(buy_conditions) >= 3 and ema_diff_pct < 0:
            profile.signal = "POSSIBLE"
            profile.confidence = profile.final_score * 0.8
            profile.priority = 3
            profile.reasons.append("⚡ Opportunité potentielle au creux")
            
        # Sinon HOLD
        else:
            profile.signal = "HOLD"
            profile.confidence = 50
            profile.priority = 5
            profile.reasons.append("⏳ Attendre conditions favorables")
        
        # === RAISONS DÉTAILLÉES ===
        if profile.reversal_score >= 70:
            profile.reasons.append(f"🔄 Retournement détecté (score: {profile.reversal_score:.0f})")
        if bb_pos < 0.25:
            profile.reasons.append(f"🛡️ Zone de rebond BB ({bb_pos:.1%})")
        if profile.profit_potential >= 2.5:
            profile.reasons.append(f"💰 Fort potentiel (+{profile.profit_potential:.1f}%)")
        if profile.momentum_score >= 70:
            profile.reasons.append("📈 Momentum favorable")
        
        # === WARNINGS ===
        if profile.risk_score < 50:
            profile.warnings.append(f"⚠️ Risque élevé (score: {profile.risk_score:.0f})")
        if features.get('volatility_5', 0) > 2:
            profile.warnings.append("⚠️ Volatilité élevée")
        if features.get('momentum_5', 0) < -1:
            profile.warnings.append("⚠️ Momentum très négatif")
    
    def get_top_opportunities(self, profiles: List[CryptoProfile], limit: int = 10) -> List[CryptoProfile]:
        """Return at most ``limit`` actionable profiles, best first.

        Only profiles whose ``signal`` is "ACHAT" or "POSSIBLE" are kept. They
        are ordered by ascending ``priority`` (0 = most urgent), then by
        descending ``final_score``, then by descending ``profit_potential``.
        The sort is stable, so exact ties keep their input order.
        """
        actionable = ('ACHAT', 'POSSIBLE')
        ranked = [p for p in profiles if p.signal in actionable]
        ranked.sort(key=lambda p: (p.priority, -p.final_score, -p.profit_potential))
        return ranked[:limit]
    
    def batch_analyze(self, cryptos_data: Dict[str, Dict]) -> Dict[str, CryptoProfile]:
        """Analyse many symbols and return ``{symbol: profile}``.

        Each value of ``cryptos_data`` is a dict with optional 'prices' and
        'volumes' sequences. Symbols with fewer than 50 prices are silently
        skipped and are absent from the result. Symbols are processed one after
        another, each through ``self.analyze_crypto``.
        """
        analyzed = {}
        for symbol, payload in cryptos_data.items():
            series = payload.get('prices', [])
            vols = payload.get('volumes', [])
            if len(series) < 50:
                continue  # not enough history for the indicators
            analyzed[symbol] = self.analyze_crypto(symbol, series, vols)
        return analyzed


# === GLOBAL INSTANCE ===
# Module-level cache for the shared scorer; filled lazily by get_advanced_scorer().
_advanced_scorer = None


def get_advanced_scorer() -> AIAdvancedScorer:
    """Return the shared AIAdvancedScorer, creating it on the first call.

    Later calls return the same object. No locking is done here.
    """
    global _advanced_scorer
    if _advanced_scorer is not None:
        return _advanced_scorer
    _advanced_scorer = AIAdvancedScorer()
    return _advanced_scorer


# === TEST ===
# Manual smoke test: analyse a synthetic price series and print the profile.
if __name__ == "__main__":
    import random

    print("=" * 60)
    print("  TEST AI ADVANCED SCORER")
    print("=" * 60)

    scorer = get_advanced_scorer()

    # Seed BOTH generators: numpy drives the prices and the stdlib `random`
    # module drives the volumes. Seeding only numpy made the volumes passed to
    # analyze_crypto (and so the printed results) differ on every run.
    np.random.seed(42)
    random.seed(42)

    # 100-step random walk; each step is a normally distributed % change.
    base_price = 100
    prices = [base_price]
    for _ in range(99):
        change = np.random.normal(0, 0.5)
        prices.append(prices[-1] * (1 + change / 100))

    # Simulate a dip: 10 steps of -0.5 %, then 5 steps of +0.2 %.
    for _ in range(10):
        prices.append(prices[-1] * 0.995)  # decline
    for _ in range(5):
        prices.append(prices[-1] * 1.002)  # stabilisation

    volumes = [random.uniform(1000000, 5000000) for _ in range(len(prices))]

    # Analyse
    profile = scorer.analyze_crypto("TESTUSDT", prices, volumes)

    print(f"\n📊 Analyse de {profile.symbol}:")
    print(f"   Score Final: {profile.final_score:.1f}/100")
    print(f"   Signal: {profile.signal} (confiance: {profile.confidence:.0f}%)")
    print(f"   Potentiel: +{profile.profit_potential:.1f}%")
    print(f"   Qualité entrée: {profile.entry_quality:.0f}/100")
    print("\n   Scores détaillés:")
    print(f"   - Technique: {profile.technical_score:.0f}")
    print(f"   - Momentum: {profile.momentum_score:.0f}")
    print(f"   - Retournement: {profile.reversal_score:.0f}")
    print(f"   - Risque: {profile.risk_score:.0f}")
    print(f"\n   Raisons: {profile.reasons}")
    print(f"   Warnings: {profile.warnings}")
