"""
Feature Engineering Avancé pour Trading
Inspiré de: https://github.com/asavinov/intelligent-trading-bot
Utilise TSFresh, statistiques avancées et TA-lib
"""
import logging
from datetime import datetime
from typing import List, Dict, Optional

import numpy as np
from scipy import stats

# Import conditionnel
try:
    import talib
    TALIB_AVAILABLE = True
except ImportError:
    TALIB_AVAILABLE = False
    logging.warning("⚠️ TA-Lib non installé. Fonctionnalités limitées.")


class FeatureEngineer:
    """
    Générateur de features avancées pour le machine learning
    """

    @staticmethod
    def calculate_statistical_features(prices: List[float], window: int = 20) -> Dict[str, float]:
        """
        Calcule des features statistiques avancées
        """
        if len(prices) < window:
            return {}

        recent_prices = np.array(prices[-window:])
        returns = np.diff(recent_prices) / recent_prices[:-1]

        features = {}

        # Statistiques de base
        features['mean'] = np.mean(recent_prices)
        features['std'] = np.std(recent_prices)
        features['variance'] = np.var(recent_prices)

        # Statistiques sur les returns
        if len(returns) > 0:
            features['return_mean'] = np.mean(returns)
            features['return_std'] = np.std(returns)

            # Skewness (asymétrie) - mesure l'asymétrie de la distribution
            features['skewness'] = stats.skew(returns)

            # Kurtosis (aplatissement) - mesure les queues de distribution
            features['kurtosis'] = stats.kurtosis(returns)

            # Percentiles
            features['percentile_25'] = np.percentile(recent_prices, 25)
            features['percentile_50'] = np.percentile(recent_prices, 50)  # Médiane
            features['percentile_75'] = np.percentile(recent_prices, 75)

        # Tendance linéaire (pente)
        if len(recent_prices) > 1:
            x = np.arange(len(recent_prices))
            slope, intercept, r_value, p_value, std_err = stats.linregress(x, recent_prices)
            features['trend_slope'] = slope
            features['trend_r_squared'] = r_value ** 2

        # Autocorrélation (mémoire du signal)
        if len(returns) > 1:
            features['autocorr_1'] = np.corrcoef(returns[:-1], returns[1:])[0, 1] if len(returns) > 2 else 0

        return features

    @staticmethod
    def calculate_momentum_features(prices: List[float]) -> Dict[str, float]:
        """
        Features de momentum sur plusieurs horizons temporels
        """
        features = {}

        current_price = prices[-1]

        # Momentum sur différentes périodes
        for period in [3, 5, 10, 20, 50]:
            if len(prices) > period:
                past_price = prices[-period-1]
                momentum = (current_price - past_price) / past_price
                features[f'momentum_{period}'] = momentum
            else:
                features[f'momentum_{period}'] = 0

        # Rate of Change (ROC)
        for period in [5, 10, 20]:
            if len(prices) > period:
                roc = ((current_price - prices[-period-1]) / prices[-period-1]) * 100
                features[f'roc_{period}'] = roc
            else:
                features[f'roc_{period}'] = 0

        return features

    @staticmethod
    def calculate_volatility_features(prices: List[float]) -> Dict[str, float]:
        """
        Features de volatilité
        """
        features = {}

        # Volatilité sur différentes fenêtres
        for window in [5, 10, 20]:
            if len(prices) >= window:
                returns = np.diff(prices[-window:]) / prices[-window:-1]
                features[f'volatility_{window}'] = np.std(returns) * np.sqrt(252)  # Annualisée
            else:
                features[f'volatility_{window}'] = 0

        # ATR-like (Average True Range approximation)
        if len(prices) >= 14:
            ranges = [abs(prices[i] - prices[i-1]) for i in range(-14, 0)]
            features['atr_14'] = np.mean(ranges)
        else:
            features['atr_14'] = 0

        return features

    @staticmethod
    def calculate_pattern_features(prices: List[float]) -> Dict[str, float]:
        """
        Features basées sur les patterns de prix
        """
        features = {}

        current_price = prices[-1]

        # Position dans le range sur différentes périodes
        for period in [10, 20, 50]:
            if len(prices) >= period:
                high = max(prices[-period:])
                low = min(prices[-period:])
                range_size = high - low

                if range_size > 0:
                    position_in_range = (current_price - low) / range_size
                    features[f'position_in_range_{period}'] = position_in_range
                else:
                    features[f'position_in_range_{period}'] = 0.5
            else:
                features[f'position_in_range_{period}'] = 0.5

        # Distance aux plus hauts/bas
        if len(prices) >= 20:
            high_20 = max(prices[-20:])
            low_20 = min(prices[-20:])

            features['distance_from_high_20'] = (high_20 - current_price) / current_price
            features['distance_from_low_20'] = (current_price - low_20) / current_price

        # Nombre de fois où le prix a touché certains niveaux
        if len(prices) >= 50:
            # Compteur de touches du support (prix bas)
            low_50 = min(prices[-50:])
            threshold = low_50 * 1.01  # 1% au-dessus du plus bas
            touches_support = sum(1 for p in prices[-50:] if p <= threshold)
            features['support_touches'] = touches_support

        return features

    @staticmethod
    def calculate_talib_features(prices: np.ndarray, highs: Optional[np.ndarray] = None,
                                lows: Optional[np.ndarray] = None,
                                volumes: Optional[np.ndarray] = None) -> Dict[str, float]:
        """
        Compute TA-Lib indicator features from the latest bar.

        Args:
            prices: Close prices, oldest first (array or any sequence).
            highs: Optional bar highs, same length as ``prices``.
            lows: Optional bar lows, same length as ``prices``.
            volumes: Optional bar volumes, same length as ``prices``.

        Returns:
            An empty dict when TA-Lib is not installed or fewer than 30
            prices are given. Otherwise the last value of RSI(14), MACD
            (12/26/9), SMA(20), EMA(12), EMA(26) and Bollinger Bands (20, 2);
            plus Stochastic, ADX and CCI when highs and lows are supplied,
            and OBV when volumes are supplied. Optional series whose length
            differs from ``prices`` are skipped. Any error raised while
            computing is logged and the features gathered so far are
            returned. Values can be NaN while an indicator is still inside
            its warm-up period.
        """
        features: Dict[str, float] = {}

        if not TALIB_AVAILABLE or len(prices) < 30:
            return features

        def last(series) -> float:
            # Latest value of an indicator output, as a plain Python float.
            return float(series[-1])

        try:
            # TA-Lib only accepts float64 ("double") arrays; integer input
            # would raise and every feature would be lost.
            close = np.asarray(prices, dtype=np.float64)
            high = None if highs is None else np.asarray(highs, dtype=np.float64)
            low = None if lows is None else np.asarray(lows, dtype=np.float64)
            volume = None if volumes is None else np.asarray(volumes, dtype=np.float64)

            # Momentum indicators (the 30-bar guard above already covers the
            # minimum length each of the close-only indicators needs).
            features['rsi_14'] = last(talib.RSI(close, timeperiod=14))

            # MACD
            macd, macdsignal, macdhist = talib.MACD(close, fastperiod=12, slowperiod=26, signalperiod=9)
            features['macd'] = last(macd)
            features['macd_signal'] = last(macdsignal)
            features['macd_hist'] = last(macdhist)

            # Moving averages
            features['sma_20'] = last(talib.SMA(close, timeperiod=20))
            features['ema_12'] = last(talib.EMA(close, timeperiod=12))
            features['ema_26'] = last(talib.EMA(close, timeperiod=26))

            # Bollinger Bands
            upper, middle, lower = talib.BBANDS(close, timeperiod=20, nbdevup=2, nbdevdn=2, matype=0)
            features['bb_upper'] = last(upper)
            features['bb_middle'] = last(middle)
            features['bb_lower'] = last(lower)
            # Relative band width; undefined when the middle band is zero.
            features['bb_width'] = (
                (last(upper) - last(lower)) / last(middle) if last(middle) != 0 else 0.0
            )

            # High/low based indicators need series aligned with the closes;
            # TA-Lib rejects inputs of different lengths.
            has_high_low = (
                high is not None and low is not None
                and len(high) == len(close) and len(low) == len(close)
            )
            if has_high_low:
                # Stochastic oscillator
                slowk, slowd = talib.STOCH(high, low, close, fastk_period=14,
                                           slowk_period=3, slowd_period=3)
                features['stoch_k'] = last(slowk)
                features['stoch_d'] = last(slowd)

                # ADX (Average Directional Index) - trend strength
                features['adx'] = last(talib.ADX(high, low, close, timeperiod=14))

                # CCI (Commodity Channel Index)
                features['cci'] = last(talib.CCI(high, low, close, timeperiod=14))

            # OBV (On Balance Volume) - only when volumes are aligned
            if volume is not None and len(volume) == len(close):
                features['obv'] = last(talib.OBV(close, volume))

        except Exception as e:
            # Deliberate best-effort: indicator failures must not break the
            # caller, so log and return whatever was computed.
            logging.error(f"❌ Erreur calcul TA-Lib features: {e}")

        return features

    @staticmethod
    def calculate_time_features() -> Dict[str, float]:
        """
        Features temporelles (heure, jour de la semaine, etc.)
        """
        from datetime import datetime

        features = {}
        now = datetime.now()

        # Heure (cyclique)
        hour = now.hour
        features['hour_sin'] = np.sin(2 * np.pi * hour / 24)
        features['hour_cos'] = np.cos(2 * np.pi * hour / 24)

        # Jour de la semaine (cyclique)
        day = now.weekday()
        features['day_sin'] = np.sin(2 * np.pi * day / 7)
        features['day_cos'] = np.cos(2 * np.pi * day / 7)

        # Week-end flag
        features['is_weekend'] = 1 if day >= 5 else 0

        return features

    @staticmethod
    def generate_all_features(prices: List[float],
                             highs: Optional[List[float]] = None,
                             lows: Optional[List[float]] = None,
                             volumes: Optional[List[float]] = None) -> Dict[str, float]:
        """
        Build the full feature dict by merging every feature family.

        Args:
            prices: Close prices, oldest first.
            highs: Optional bar highs (list or numpy array).
            lows: Optional bar lows (list or numpy array).
            volumes: Optional bar volumes (list or numpy array).

        Returns:
            The union of the statistical, momentum, volatility, pattern and
            time features, plus the TA-Lib features when TA-Lib is installed
            and at least 30 prices are given. When highs or lows are missing
            or empty, the close prices are used in their place for the
            TA-Lib indicators.
        """
        all_features: Dict[str, float] = {}

        # Price-only feature families
        all_features.update(FeatureEngineer.calculate_statistical_features(prices))
        all_features.update(FeatureEngineer.calculate_momentum_features(prices))
        all_features.update(FeatureEngineer.calculate_volatility_features(prices))
        all_features.update(FeatureEngineer.calculate_pattern_features(prices))

        # TA-Lib (when available)
        if TALIB_AVAILABLE and len(prices) >= 30:
            # TA-Lib requires float64 arrays.
            prices_array = np.asarray(prices, dtype=np.float64)

            def as_array_or(values, fallback):
                # Explicit None/empty test: plain truthiness raises
                # ValueError when `values` is a numpy array.
                if values is None or len(values) == 0:
                    return fallback
                return np.asarray(values, dtype=np.float64)

            all_features.update(FeatureEngineer.calculate_talib_features(
                prices_array,
                as_array_or(highs, prices_array),
                as_array_or(lows, prices_array),
                as_array_or(volumes, None),
            ))

        # Time features
        all_features.update(FeatureEngineer.calculate_time_features())

        return all_features

    @staticmethod
    def features_to_array(features: Dict[str, float], feature_names: Optional[List[str]] = None) -> np.ndarray:
        """
        Convertit un dict de features en array numpy pour ML
        """
        if feature_names is None:
            # Utiliser toutes les features dans un ordre trié
            feature_names = sorted(features.keys())

        feature_values = [features.get(name, 0.0) for name in feature_names]
        return np.array(feature_values).reshape(1, -1)

    @staticmethod
    def normalize_features(features: Dict[str, float]) -> Dict[str, float]:
        """
        Normalise les features pour ML
        """
        normalized = {}

        for key, value in features.items():
            # Éviter les valeurs infinies ou NaN
            if np.isnan(value) or np.isinf(value):
                normalized[key] = 0.0
            else:
                # Clipper les valeurs extrêmes
                normalized[key] = np.clip(value, -100, 100)

        return normalized
