#!/usr/bin/env python3
"""
Advanced Feature Engineering - Extraction de 50+ features techniques avancées
Améliore drastiquement la précision du modèle IA
"""

import numpy as np
from typing import Dict, List, Optional
import logging
try:
    from config import EMA_SHORT, EMA_LONG
except ImportError:
    EMA_SHORT, EMA_LONG = 7, 25  # fallback

# Module-level logger used by AdvancedFeatureExtractor (named explicitly, not via __name__).
logger = logging.getLogger("AdvancedFeatureEngineering")

class AdvancedFeatureExtractor:
    """Compute advanced technical-analysis features from price/volume series.

    Every feature is derived from a single price series (no separate
    open/high/low values), so candlestick, ATR, ADX and Supertrend features are
    close-only approximations of the classic indicators.
    """

    def __init__(self):
        logger.info("✅ Advanced Feature Extractor initialisé")

    def extract_all_features(self, prices: List[float], volumes: Optional[List[float]] = None) -> Dict:
        """
        Extract every advanced feature group for the most recent bar.

        Args:
            prices: Price series, oldest first. Lists and 1-D numpy arrays are
                both accepted.
            volumes: Optional volume series, oldest first. When missing or
                empty, a constant volume of 1 per bar is used.

        Returns:
            Dict mapping feature name to value (up to ~42 features), or an
            empty dict when fewer than 50 prices are supplied.
        """
        # Use len() rather than truthiness: `not array` raises ValueError for a
        # multi-element numpy array.
        if prices is None or len(prices) < 50:
            return {}

        prices_array = np.asarray(prices, dtype=float)
        if volumes is not None and len(volumes) > 0:
            volumes_array = np.asarray(volumes, dtype=float)
        else:
            volumes_array = np.ones(len(prices_array))

        features = {}

        # === PRICE ACTION FEATURES ===
        features.update(self._extract_price_action_features(prices_array))

        # === MOMENTUM FEATURES ===
        features.update(self._extract_momentum_features(prices_array))

        # === VOLATILITY FEATURES ===
        features.update(self._extract_volatility_features(prices_array))

        # === VOLUME FEATURES ===
        features.update(self._extract_volume_features(prices_array, volumes_array))

        # === TREND FEATURES ===
        features.update(self._extract_trend_features(prices_array))

        # === CYCLE FEATURES ===
        features.update(self._extract_cycle_features(prices_array, volumes_array))

        return features

    def _extract_price_action_features(self, prices: np.ndarray) -> Dict:
        """Price-action features: close-only candle patterns, swing range, Fibonacci levels."""
        features = {}

        # Candlestick patterns, approximated from the last three closes.
        if len(prices) >= 3:
            close_2_ago = prices[-3]
            close_1_ago = prices[-2]
            close_now = prices[-1]

            # "Doji" (indecision): the last close moved less than 0.1 % from the previous one.
            body_size = abs(close_now - close_1_ago) / close_1_ago
            features['is_doji'] = 1 if body_size < 0.001 else 0

            # Down-then-up turn, used as a stand-in for a bullish engulfing pattern.
            features['is_bullish_engulfing'] = 1 if (close_now > close_1_ago and close_1_ago < close_2_ago) else 0

            # Up-then-down turn, used as a stand-in for a bearish engulfing pattern.
            features['is_bearish_engulfing'] = 1 if (close_now < close_1_ago and close_1_ago > close_2_ago) else 0

        # Swing points: local high/low over the last 20 bars.
        if len(prices) >= 20:
            recent_high = np.max(prices[-20:])
            recent_low = np.min(prices[-20:])
            current_price = prices[-1]

            features['distance_to_high'] = (recent_high - current_price) / current_price
            features['distance_to_low'] = (current_price - recent_low) / current_price

            # Relative position inside the 20-bar range (0 = low, 1 = high).
            if recent_high != recent_low:
                features['position_in_range'] = (current_price - recent_low) / (recent_high - recent_low)
            else:
                features['position_in_range'] = 0.5

        # Fibonacci retracement level over the last 50 bars (simplified).
        if len(prices) >= 50:
            high_50 = np.max(prices[-50:])
            low_50 = np.min(prices[-50:])
            current = prices[-1]

            if high_50 != low_50:
                fib_level = (current - low_50) / (high_50 - low_50)
                features['fibonacci_level'] = fib_level

                # Within 0.05 of the key levels 0.382, 0.5 and 0.618.
                features['near_fib_382'] = 1 if abs(fib_level - 0.382) < 0.05 else 0
                features['near_fib_50'] = 1 if abs(fib_level - 0.5) < 0.05 else 0
                features['near_fib_618'] = 1 if abs(fib_level - 0.618) < 0.05 else 0
            else:
                features['fibonacci_level'] = 0.5
                features['near_fib_382'] = 0
                features['near_fib_50'] = 0
                features['near_fib_618'] = 0

        return features

    def _extract_momentum_features(self, prices: np.ndarray) -> Dict:
        """Momentum features: MACD, stochastic RSI, Williams %R, rate of change."""
        features = {}

        # MACD = EMA12 - EMA26, with a 9-period EMA signal line.
        if len(prices) >= 26:
            # Both EMA series aligned on bars 25 .. n-1 (first bar where EMA26 exists).
            ema12 = self._ema_series(prices, 12)[26 - 12:]
            ema26 = self._ema_series(prices, 26)
            macd_series = ema12 - ema26
            macd_line = macd_series[-1]

            # Signal line: EMA9 of the MACD history, including the current bar.
            if len(macd_series) >= 9:
                signal_line = self._ema(macd_series, 9)
                features['macd_histogram'] = macd_line - signal_line
                features['macd_cross'] = 1 if macd_line > signal_line else 0

            features['macd'] = macd_line
            features['macd_normalized'] = macd_line / prices[-1] if prices[-1] > 0 else 0

        # Stochastic RSI over the last 14 RSI values.
        rsi_values = self._calculate_rsi_array(prices, 14)
        if len(rsi_values) >= 14:
            stoch_rsi = self._stochastic(rsi_values[-14:])
            features['stochastic_rsi'] = stoch_rsi
            features['stoch_rsi_oversold'] = 1 if stoch_rsi < 20 else 0
            features['stoch_rsi_overbought'] = 1 if stoch_rsi > 80 else 0

        # Williams %R
        if len(prices) >= 14:
            williams_r = self._williams_r(prices, 14)
            features['williams_r'] = williams_r
            features['williams_oversold'] = 1 if williams_r < -80 else 0

        # Rate of change (%) over `period` bars: compare with the close `period`
        # bars back, i.e. index -(period + 1), like the volatility windows.
        for period in (3, 5, 10, 20):
            if len(prices) > period:
                base = prices[-period - 1]
                features[f'roc_{period}'] = ((prices[-1] - base) / base) * 100 if base != 0 else 0

        return features

    def _extract_volatility_features(self, prices: np.ndarray) -> Dict:
        """Volatility features: ATR %, realised volatility, short/long ratio and regime."""
        features = {}

        # ATR as a percentage of the last price (close-only ATR, see `_atr`).
        if len(prices) >= 14:
            atr = self._atr(prices, 14)
            features['atr_percent'] = (atr / prices[-1]) * 100 if prices[-1] > 0 else 0

        # Realised volatility (std of simple returns, in %) over several windows.
        for period in (5, 10, 20):
            if len(prices) > period:
                features[f'volatility_{period}d'] = self._returns_std(prices, period) * 100

        # Short-term vs long-term volatility.
        if len(prices) >= 20:
            vol_5 = self._returns_std(prices, 5) * 100
            vol_20 = self._returns_std(prices, 20) * 100
            features['volatility_ratio'] = vol_5 / vol_20 if vol_20 > 0 else 1

        # Regime: 20-return volatility compared with 50-return volatility.
        if len(prices) >= 50:
            current_vol = self._returns_std(prices, 20)
            historical_vol = self._returns_std(prices, 50)

            if current_vol > historical_vol * 1.5:
                features['volatility_regime'] = 2  # High
            elif current_vol < historical_vol * 0.7:
                features['volatility_regime'] = 0  # Low
            else:
                features['volatility_regime'] = 1  # Normal

        return features

    def _extract_volume_features(self, prices: np.ndarray, volumes: np.ndarray) -> Dict:
        """Volume features: OBV trend/slope, price vs VWAP, relative volume, volume trend."""
        features = {}

        # Align both series on their most recent bars so a volume series of a
        # different length can neither overrun nor misalign the OBV computation.
        n = min(len(prices), len(volumes))
        if n < 20:
            return features
        prices = prices[-n:]
        volumes = volumes[-n:]

        # OBV (On-Balance Volume)
        obv = self._calculate_obv(prices, volumes)
        obv_now = obv[-1]
        obv_before = obv[-10]
        features['obv_trend'] = 1 if obv_now > obv_before else 0

        # OBV slope. OBV can be negative, so normalise by its magnitude (+1
        # avoids division by zero); a rising OBV always gives a positive slope.
        features['obv_slope'] = (obv_now - obv_before) / (abs(obv_before) + 1)

        # Price relative to the 20-bar VWAP (Volume Weighted Average Price).
        vwap = self._calculate_vwap(prices[-20:], volumes[-20:])
        features['price_vs_vwap'] = (prices[-1] - vwap) / vwap if vwap > 0 else 0

        # Current volume relative to the 20-bar average.
        avg_volume = np.mean(volumes[-20:])
        features['volume_ratio_current'] = volumes[-1] / avg_volume if avg_volume > 0 else 1

        # Volume trend: last 5 bars vs the 15 bars before them.
        recent_vol = np.mean(volumes[-5:])
        older_vol = np.mean(volumes[-20:-5])
        features['volume_trend'] = (recent_vol - older_vol) / older_vol if older_vol > 0 else 0

        return features

    def _extract_trend_features(self, prices: np.ndarray) -> Dict:
        """Trend features: simplified ADX, simplified Supertrend, EMA slopes."""
        features = {}

        # ADX (simplified, see `_calculate_adx`)
        if len(prices) >= 14:
            adx = self._calculate_adx(prices, 14)
            features['adx'] = adx
            features['strong_trend'] = 1 if adx > 25 else 0
            features['very_strong_trend'] = 1 if adx > 40 else 0

        # Supertrend (simplified): only the lower band around the midpoint of
        # the last two closes is used.
        if len(prices) >= 10:
            atr = self._atr(prices, 10)
            multiplier = 3
            midpoint = (prices[-1] + prices[-2]) / 2
            lower_band = midpoint - multiplier * atr
            features['above_supertrend'] = 1 if prices[-1] > lower_band else 0

        # Relative EMA change over the last 5 bars, for several EMA periods.
        for period in (EMA_SHORT, EMA_LONG, 99):
            if len(prices) >= period + 5:
                ema_series = self._ema_series(prices, period)
                ema_current = ema_series[-1]
                ema_5_ago = ema_series[-6]  # EMA of prices[:-5]
                slope = (ema_current - ema_5_ago) / ema_5_ago if ema_5_ago > 0 else 0
                features[f'ema{period}_slope'] = slope

        return features

    def _extract_cycle_features(self, prices: np.ndarray, volumes: np.ndarray) -> Dict:
        """Market-cycle phase (0 accumulation, 1 distribution, 2 markup, 3 markdown, 4 neutral)."""
        features = {}

        if len(prices) < 50:
            return features

        # Simplified cycle-phase detection:
        #   Accumulation = low in the range + high volume
        #   Distribution = high in the range + high volume
        #   Markup       = strong upward momentum
        #   Markdown     = strong downward momentum

        price_range_50 = np.max(prices[-50:]) - np.min(prices[-50:])
        current_position = (prices[-1] - np.min(prices[-50:])) / price_range_50 if price_range_50 > 0 else 0.5

        avg_volume = np.mean(volumes[-50:]) if len(volumes) >= 50 else np.mean(volumes)
        current_volume = volumes[-1]
        volume_ratio = current_volume / avg_volume if avg_volume > 0 else 1

        # Recent momentum (relative change vs prices[-10])
        momentum = (prices[-1] - prices[-10]) / prices[-10] if len(prices) >= 10 else 0

        # Classification
        if current_position < 0.3 and volume_ratio > 1.2:
            features['market_cycle_phase'] = 0  # Accumulation
        elif current_position > 0.7 and volume_ratio > 1.2:
            features['market_cycle_phase'] = 1  # Distribution
        elif momentum > 0.02:
            features['market_cycle_phase'] = 2  # Markup (rising)
        elif momentum < -0.02:
            features['market_cycle_phase'] = 3  # Markdown (falling)
        else:
            features['market_cycle_phase'] = 4  # Neutral

        return features

    # === HELPER FUNCTIONS ===

    def _ema_series(self, values: np.ndarray, period: int) -> np.ndarray:
        """
        Running EMA of `values`, seeded with the simple mean of the first `period` values.

        Element k equals `_ema(values[:period + k], period)`, so the result is
        aligned on bars period-1 .. len(values)-1. Returns an empty array when
        fewer than `period` values are available.
        """
        if len(values) < period:
            return np.empty(0)
        multiplier = 2 / (period + 1)
        series = np.empty(len(values) - period + 1)
        ema = np.mean(values[:period])
        series[0] = ema
        for k, value in enumerate(values[period:], start=1):
            ema = (value * multiplier) + (ema * (1 - multiplier))
            series[k] = ema
        return series

    def _ema(self, prices: np.ndarray, period: int) -> float:
        """EMA of the whole series; plain mean when shorter than `period`."""
        if len(prices) < period:
            return np.mean(prices)
        return self._ema_series(prices, period)[-1]

    def _returns_std(self, prices: np.ndarray, window: int) -> float:
        """Standard deviation of the last `window` simple returns (as a fraction, not %)."""
        recent = prices[-window - 1:]
        return np.std(np.diff(recent) / recent[:-1])

    def _calculate_rsi_array(self, prices: np.ndarray, period: int = 14) -> np.ndarray:
        """
        Wilder RSI for every bar from index `period` onward.

        Returns len(prices) - period values (the first one from the simple
        averages of the first `period` moves, the rest Wilder-smoothed), or
        `[50]` when fewer than period + 1 prices are given.
        """
        if len(prices) < period + 1:
            return np.array([50])

        deltas = np.diff(prices)
        gains = np.where(deltas > 0, deltas, 0)
        losses = np.where(deltas < 0, -deltas, 0)

        avg_gain = np.mean(gains[:period])
        avg_loss = np.mean(losses[:period])
        rsi_values = [self._rsi_from_averages(avg_gain, avg_loss)]

        for gain, loss in zip(gains[period:], losses[period:]):
            avg_gain = (avg_gain * (period - 1) + gain) / period
            avg_loss = (avg_loss * (period - 1) + loss) / period
            rsi_values.append(self._rsi_from_averages(avg_gain, avg_loss))

        return np.array(rsi_values)

    @staticmethod
    def _rsi_from_averages(avg_gain: float, avg_loss: float) -> float:
        """RSI from average gain/loss; 100 when there is no loss at all."""
        if avg_loss == 0:
            return 100
        rs = avg_gain / avg_loss
        return 100 - (100 / (1 + rs))

    def _stochastic(self, values: np.ndarray) -> float:
        """Stochastic oscillator (0-100) of the last value within `values`; 50 when undefined."""
        if len(values) < 2:
            return 50
        high = np.max(values)
        low = np.min(values)
        current = values[-1]
        if high == low:
            return 50
        return ((current - low) / (high - low)) * 100

    def _williams_r(self, prices: np.ndarray, period: int) -> float:
        """Williams %R (-100..0) over the last `period` closes; -50 when undefined."""
        if len(prices) < period:
            return -50
        high = np.max(prices[-period:])
        low = np.min(prices[-period:])
        close = prices[-1]
        if high == low:
            return -50
        return ((high - close) / (high - low)) * -100

    def _atr(self, prices: np.ndarray, period: int) -> float:
        """Simplified ATR: mean absolute close-to-close move over `period` bars (no high/low data)."""
        if len(prices) < period + 1:
            return 0
        true_ranges = np.abs(np.diff(prices[-period - 1:]))
        return np.mean(true_ranges)

    def _calculate_obv(self, prices: np.ndarray, volumes: np.ndarray) -> np.ndarray:
        """
        On-Balance Volume, starting at volumes[0].

        Each bar adds its volume on an up-close, subtracts it on a down-close,
        and leaves OBV unchanged on a flat close. `volumes` must have at least
        len(prices) entries.
        """
        obv = np.zeros(len(prices))
        obv[0] = volumes[0]
        for i in range(1, len(prices)):
            if prices[i] > prices[i - 1]:
                obv[i] = obv[i - 1] + volumes[i]
            elif prices[i] < prices[i - 1]:
                obv[i] = obv[i - 1] - volumes[i]
            else:
                obv[i] = obv[i - 1]
        return obv

    def _calculate_vwap(self, prices: np.ndarray, volumes: np.ndarray) -> float:
        """VWAP of the given window; falls back to the last price (or 0) when undefined."""
        if len(prices) != len(volumes) or len(prices) == 0:
            return prices[-1] if len(prices) > 0 else 0
        total_volume = np.sum(volumes)
        return np.sum(prices * volumes) / total_volume if total_volume > 0 else prices[-1]

    def _calculate_adx(self, prices: np.ndarray, period: int) -> float:
        """
        Simplified ADX (0-100) from close-to-close moves over the last `period` bars.

        This is |avg_up - avg_down| / (avg_up + avg_down) * 100 without further
        smoothing; returns 20 when there is too little data or no movement.
        """
        if len(prices) < period + 1:
            return 20

        # Directional moves: an up-close counts toward +DM, a down-close toward -DM.
        moves = np.diff(prices[-period - 1:])
        avg_up = np.mean(np.where(moves > 0, moves, 0))
        avg_down = np.mean(np.where(moves < 0, -moves, 0))

        if avg_up + avg_down == 0:
            return 20

        adx = abs(avg_up - avg_down) / (avg_up + avg_down) * 100
        return min(100, adx)


# Shared module-level instance, created lazily on first call to get_feature_extractor().
_feature_extractor = None

def get_feature_extractor() -> AdvancedFeatureExtractor:
    """Return the shared AdvancedFeatureExtractor, creating it on first use."""
    global _feature_extractor
    if _feature_extractor is not None:
        return _feature_extractor
    _feature_extractor = AdvancedFeatureExtractor()
    return _feature_extractor
