"""
Méthodes Ensemble ML pour Trading
Inspiré de: https://github.com/N00Bception/AI-CryptoTrader
Combine Random Forest, Gradient Boosting et autres modèles
"""
import numpy as np
import logging
from typing import List, Dict, Optional, Tuple
from collections import deque
import json
import os

# Import conditionnel de sklearn
try:
    from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
    from sklearn.preprocessing import StandardScaler
    from sklearn.model_selection import train_test_split
    SKLEARN_AVAILABLE = True
except ImportError:
    SKLEARN_AVAILABLE = False
    logging.warning("⚠️ scikit-learn non installé. Installez avec: pip install scikit-learn")


class EnsembleMLPredictor:
    """Ensemble predictor blending several ML classifiers for trading signals.

    A random forest and a gradient-boosting classifier are trained on a
    bounded window of labelled feature vectors. Their "price goes up"
    probabilities are averaged with weights derived from hold-out accuracy,
    optionally blended with an external technical-indicator signal, and
    thresholded into a BUY / SELL / HOLD decision.

    The instance can be constructed even when scikit-learn is not installed;
    in that case training is skipped and ``predict`` falls back to the
    indicator signal supplied by the caller.
    """

    # Minimum number of prices required to build a feature vector.
    MIN_PRICE_HISTORY = 50
    # Decision thresholds applied to the blended "price goes up" probability.
    BUY_THRESHOLD = 0.65
    SELL_THRESHOLD = 0.35
    # Share of the final probability coming from the ML models; the rest
    # comes from the technical-indicator signal when one is supplied.
    ML_BLEND_WEIGHT = 0.7
    # Number of prediction outcomes kept in ``predictions_history``.
    MAX_PREDICTIONS_HISTORY = 100

    def __init__(self, max_history=1000):
        """Create the predictor and, if scikit-learn is present, its models.

        Args:
            max_history: Maximum number of training samples retained; the
                oldest samples are dropped once the limit is reached.
        """
        self.max_history = max_history
        self.training_data = deque(maxlen=max_history)

        # Models stay None when scikit-learn is missing.
        self.random_forest = None
        self.gradient_boosting = None
        # Only reference StandardScaler when the import succeeded; touching
        # it unconditionally raises NameError and defeats the optional import.
        self.scaler = StandardScaler() if SKLEARN_AVAILABLE else None

        # Relative weights used when combining model outputs.
        self.model_weights = {
            'random_forest': 1.0,
            'gradient_boosting': 1.0,
            # Traditional technical indicators (stored, but not read by any
            # method of this class).
            'indicators': 1.0
        }

        self.predictions_history = []
        self.is_trained = False

        if SKLEARN_AVAILABLE:
            self._initialize_models()
        else:
            logging.warning("⚠️ Ensemble ML désactivé (sklearn manquant)")

    def _initialize_models(self):
        """Instantiate the (untrained) scikit-learn classifiers."""
        # Random forest: good at capturing non-linearities.
        self.random_forest = RandomForestClassifier(
            n_estimators=100,
            max_depth=10,
            min_samples_split=20,
            min_samples_leaf=10,
            random_state=42
        )

        # Gradient boosting: improves the fit stage by stage.
        self.gradient_boosting = GradientBoostingClassifier(
            n_estimators=100,
            learning_rate=0.1,
            max_depth=5,
            min_samples_split=20,
            min_samples_leaf=10,
            random_state=42
        )

        logging.info("✅ Modèles ensemble initialisés")

    @staticmethod
    def _safe_ratio(numerator, denominator, default=0.0):
        """Return ``numerator / denominator``, or ``default`` if the denominator is zero."""
        if denominator == 0:
            return default
        return numerator / denominator

    def extract_features(self, prices: List[float], volumes: Optional[List[float]] = None,
                        rsi: Optional[float] = None, ema_short: Optional[float] = None,
                        ema_long: Optional[float] = None) -> Optional[np.ndarray]:
        """Build the feature vector for the most recent price.

        Feature layout (14 values, in order): returns over 1, 3, 5, 10 and
        20 periods; 20-period coefficient of variation; skewness and excess
        kurtosis of the last 19 one-period returns; price relative to the
        20- and 50-period simple moving averages; RSI / 100 (0.5 when not
        given); relative EMA spread (0 when not given); last volume over the
        20-period mean volume (1 when not available); position of the price
        inside the 20-period high/low range (0.5 when the range is flat).

        Args:
            prices: Price series, oldest first. At least 50 values required.
            volumes: Optional volume series, oldest first.
            rsi: Optional RSI value; divided by 100 before use.
            ema_short: Optional short EMA value.
            ema_long: Optional long EMA value.

        Returns:
            Array of shape ``(1, 14)``, or ``None`` when fewer than 50 prices
            are supplied. Ratios whose denominator is zero fall back to 0
            instead of raising.
        """
        if prices is None or len(prices) < self.MIN_PRICE_HISTORY:
            return None

        # The guard above guarantees at least 50 prices, so every window
        # below is fully populated.
        current_price = prices[-1]
        last_20 = np.asarray(prices[-20:], dtype=float)
        last_50 = np.asarray(prices[-50:], dtype=float)

        features = []

        # 1. Price-based features: returns over several look-back periods.
        for period in (1, 3, 5, 10, 20):
            base_price = prices[-period - 1]
            features.append(self._safe_ratio(current_price - base_price, base_price))

        # Volatility as the coefficient of variation of the last 20 prices.
        features.append(self._safe_ratio(np.std(last_20), np.mean(last_20)))

        # 2. Statistical features: skewness and excess kurtosis of returns.
        skew = 0
        kurt = 0
        previous_prices = last_20[:-1]
        if np.all(previous_prices != 0):
            returns = np.diff(last_20) / previous_prices
            sigma = np.std(returns)
            if sigma > 0:
                standardized = (returns - np.mean(returns)) / sigma
                skew = np.mean(standardized ** 3)
                kurt = np.mean(standardized ** 4) - 3
        features.extend([skew, kurt])

        # 3. Momentum features: price relative to its moving averages.
        sma_20 = np.mean(last_20)
        features.append(self._safe_ratio(current_price - sma_20, sma_20))
        sma_50 = np.mean(last_50)
        features.append(self._safe_ratio(current_price - sma_50, sma_50))

        # 4. Technical indicators (when provided).
        features.append(rsi / 100.0 if rsi is not None else 0.5)

        if ema_short is not None and ema_long is not None:
            features.append(self._safe_ratio(ema_short - ema_long, ema_long))
        else:
            features.append(0)

        # 5. Volume feature (when available). ``is not None`` rather than
        # truthiness so that NumPy arrays are accepted as well as lists.
        if volumes is not None and len(volumes) >= 20:
            current_volume = volumes[-1]
            avg_volume = np.mean(volumes[-20:])
            features.append(current_volume / avg_volume if avg_volume > 0 else 1)
        else:
            features.append(1)

        # 6. Price pattern: position inside the 20-period high/low range.
        high_20 = last_20.max()
        low_20 = last_20.min()
        range_20 = high_20 - low_20
        if range_20 > 0:
            features.append((current_price - low_20) / range_20)
        else:
            features.append(0.5)

        return np.array(features, dtype=float).reshape(1, -1)

    def add_training_sample(self, features: np.ndarray, label: int):
        """Store one labelled sample for later training.

        Args:
            features: Feature array as returned by ``extract_features``;
                ``None`` is ignored so callers can pass the result through
                unchecked.
            label: 1 when the price went up, 0 when it went down.
        """
        if features is None:
            return

        self.training_data.append({
            'features': np.asarray(features).flatten(),
            'label': label
        })

    def train_models(self, min_samples=100):
        """Fit both classifiers on the stored samples.

        The samples are split 80/20 (stratified on the label); the scaler
        and both models are fitted on the training part, and the accuracy
        on the held-out part is used to set the model weights.

        NOTE(review): ``train_test_split`` shuffles by default; if the stored
        samples are time-ordered this mixes later samples into the training
        set. Confirm this is acceptable for the intended evaluation.

        Args:
            min_samples: Minimum number of stored samples required.

        Returns:
            True when training succeeded, False otherwise (scikit-learn
            missing, not enough samples, a single class, or any error).
        """
        if not SKLEARN_AVAILABLE:
            return False

        if len(self.training_data) < min_samples:
            logging.info(f"⏳ Pas assez de données pour entraîner ({len(self.training_data)}/{min_samples})")
            return False

        try:
            # Assemble the design matrix and label vector.
            X = np.array([sample['features'] for sample in self.training_data])
            y = np.array([sample['label'] for sample in self.training_data])

            # Both classes must be present for a classifier to be trainable.
            if len(np.unique(y)) < 2:
                logging.warning("⚠️ Une seule classe dans les données, impossible d'entraîner")
                return False

            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=0.2, random_state=42, stratify=y
            )

            # Fit the scaler on the training part only, then apply to both.
            X_train_scaled = self.scaler.fit_transform(X_train)
            X_test_scaled = self.scaler.transform(X_test)

            self.random_forest.fit(X_train_scaled, y_train)
            rf_score = self.random_forest.score(X_test_scaled, y_test)

            self.gradient_boosting.fit(X_train_scaled, y_train)
            gb_score = self.gradient_boosting.score(X_test_scaled, y_test)

            # Weight each model by its share of the combined hold-out accuracy.
            total_score = rf_score + gb_score
            if total_score > 0:
                self.model_weights['random_forest'] = rf_score / total_score
                self.model_weights['gradient_boosting'] = gb_score / total_score

            self.is_trained = True

            logging.info(f"✅ Modèles entraînés - RF: {rf_score:.2%}, GB: {gb_score:.2%}")
            logging.info(f"📊 Poids: RF={self.model_weights['random_forest']:.2f}, "
                        f"GB={self.model_weights['gradient_boosting']:.2f}")

            return True

        except Exception as e:
            # Best-effort: a failed training run must not take the caller down.
            logging.error(f"❌ Erreur entraînement: {e}")
            return False

    def predict(self, features: np.ndarray, indicators_signal: Optional[str] = None,
               indicators_confidence: float = 0.5) -> Tuple[str, float, Dict]:
        """Combine all models into a single trading signal.

        Args:
            features: Feature array from ``extract_features`` (a flat vector
                is accepted too), or ``None``.
            indicators_signal: Optional external signal; "BUY" and "SELL"
                are blended into the ML probability.
            indicators_confidence: Confidence attached to that signal.
                NOTE(review): the blend divides this by 10, which implies a
                0-10 scale, while the default (0.5) and the fallback paths
                return it untouched. Confirm the scale with callers.

        Returns:
            ``(signal, confidence, details)`` where signal is "BUY", "SELL"
            or "HOLD" (or the caller's ``indicators_signal`` on fallback)
            and details describes how the result was produced.
        """
        if not SKLEARN_AVAILABLE or not self.is_trained or features is None:
            # No usable model: fall back on the indicator signal alone.
            if indicators_signal:
                return indicators_signal, indicators_confidence, {'method': 'indicators_only'}
            return "HOLD", 0.5, {'method': 'no_prediction'}

        try:
            # scikit-learn expects a 2-D array; promote a flat vector to one row.
            sample = np.asarray(features)
            if sample.ndim == 1:
                sample = sample.reshape(1, -1)
            features_scaled = self.scaler.transform(sample)

            # Each proba row is [P(class 0), P(class 1)];
            # class 1 = price goes up (BUY), class 0 = price goes down (SELL).
            rf_proba = self.random_forest.predict_proba(features_scaled)[0]
            gb_proba = self.gradient_boosting.predict_proba(features_scaled)[0]

            rf_weight = self.model_weights['random_forest']
            gb_weight = self.model_weights['gradient_boosting']
            weight_sum = rf_weight + gb_weight
            if weight_sum > 0:
                ensemble_proba_up = (
                    rf_proba[1] * rf_weight + gb_proba[1] * gb_weight
                ) / weight_sum
            else:
                # Degenerate weights (e.g. loaded from a file): plain average.
                ensemble_proba_up = (rf_proba[1] + gb_proba[1]) / 2

            # Blend in the technical-indicator signal.
            if indicators_signal in ("BUY", "SELL"):
                # Clamp so an out-of-scale confidence cannot push the blended
                # value outside the [0, 1] probability range.
                indicators_score = min(max(indicators_confidence / 10, 0.0), 1.0)
                if indicators_signal == "SELL":
                    indicators_score = 1 - indicators_score
                ensemble_proba_up = (
                    ensemble_proba_up * self.ML_BLEND_WEIGHT
                    + indicators_score * (1 - self.ML_BLEND_WEIGHT)
                )

            ensemble_proba_up = float(ensemble_proba_up)

            # Decision thresholds.
            if ensemble_proba_up > self.BUY_THRESHOLD:
                signal = "BUY"
                confidence = ensemble_proba_up
            elif ensemble_proba_up < self.SELL_THRESHOLD:
                signal = "SELL"
                confidence = 1 - ensemble_proba_up
            else:
                signal = "HOLD"
                confidence = 0.5

            details = {
                'method': 'ensemble',
                'rf_proba_up': float(rf_proba[1]),
                'gb_proba_up': float(gb_proba[1]),
                'ensemble_proba_up': ensemble_proba_up,
                'rf_weight': rf_weight,
                'gb_weight': gb_weight
            }

            return signal, confidence, details

        except Exception as e:
            # Best-effort: degrade to the indicator signal instead of raising.
            logging.error(f"❌ Erreur prédiction ensemble: {e}")
            if indicators_signal:
                return indicators_signal, indicators_confidence, {'method': 'indicators_fallback'}
            return "HOLD", 0.5, {'method': 'error'}

    def save_models(self, filepath: str):
        """Pickle the trained models, scaler and weights to ``filepath``.

        Returns:
            True on success; False when nothing has been trained yet or the
            file could not be written.
        """
        if not self.is_trained:
            return False

        try:
            import pickle
            data = {
                'random_forest': self.random_forest,
                'gradient_boosting': self.gradient_boosting,
                'scaler': self.scaler,
                'model_weights': self.model_weights,
                'is_trained': self.is_trained
            }

            with open(filepath, 'wb') as f:
                pickle.dump(data, f)

            logging.info(f"💾 Modèles sauvegardés: {filepath}")
            return True

        except Exception as e:
            logging.error(f"❌ Erreur sauvegarde: {e}")
            return False

    def load_models(self, filepath: str):
        """Restore models, scaler and weights written by ``save_models``.

        SECURITY: ``pickle.load`` can execute arbitrary code contained in
        the file. Only load files that this application wrote itself.

        Returns:
            True on success; False when the file is missing, unreadable or
            incomplete. On failure the current state is left untouched.
        """
        if not os.path.exists(filepath):
            return False

        try:
            import pickle

            with open(filepath, 'rb') as f:
                data = pickle.load(f)

            # Read every entry before assigning anything, so a missing key
            # cannot leave the predictor half-updated.
            random_forest = data['random_forest']
            gradient_boosting = data['gradient_boosting']
            scaler = data['scaler']
            model_weights = data['model_weights']
            is_trained = data['is_trained']

        except Exception as e:
            logging.error(f"❌ Erreur chargement: {e}")
            return False

        self.random_forest = random_forest
        self.gradient_boosting = gradient_boosting
        self.scaler = scaler
        self.model_weights = model_weights
        self.is_trained = is_trained

        logging.info(f"📂 Modèles chargés: {filepath}")
        return True

    def update_model_performance(self, prediction: str, actual_result: str):
        """Record whether a prediction matched the actual outcome.

        Only the outcome history is maintained for now, along with a debug
        log of the accuracy over the last 10 predictions.

        Args:
            prediction: The signal that was predicted.
            actual_result: The signal that turned out to be correct.
        """
        # TODO: implement dynamic weight adjustment (raise the weight of the
        # model that predicts better).
        self.predictions_history.append({
            'prediction': prediction,
            'actual': actual_result,
            'correct': prediction == actual_result
        })

        # Keep only the most recent outcomes.
        if len(self.predictions_history) > self.MAX_PREDICTIONS_HISTORY:
            del self.predictions_history[:-self.MAX_PREDICTIONS_HISTORY]

        # Accuracy over the last 10 predictions.
        if len(self.predictions_history) >= 10:
            recent = self.predictions_history[-10:]
            recent_accuracy = sum(1 for p in recent if p['correct']) / 10
            logging.debug(f"📊 Accuracy récente (10): {recent_accuracy:.1%}")


# Module-level singleton, created lazily by get_ensemble_predictor().
ensemble_predictor = None


def get_ensemble_predictor(max_history=1000):
    """Return the shared EnsembleMLPredictor, creating it on first use.

    ``max_history`` is only used by the call that actually builds the
    instance; later calls return the existing predictor unchanged.
    """
    global ensemble_predictor
    if ensemble_predictor is not None:
        return ensemble_predictor
    ensemble_predictor = EnsembleMLPredictor(max_history)
    return ensemble_predictor
