#!/usr/bin/env python3
"""
Adaptive Backtesting - Backtest avec retraining périodique simulé
Inspiré de FreqAI - Émule le comportement réel du bot avec adaptation

Features:
- Simulation retraining périodique (tous les N jours)
- Utilise UNIQUEMENT les données disponibles AVANT chaque retraining
- Métriques réalistes (win rate, profit, drawdown, Sharpe)
- Comparaison vs backtest classique (optimiste)
- Résumé des résultats dans les logs et export JSON
"""

import numpy as np
from datetime import datetime, timedelta
from typing import Dict, List, Tuple, Optional
import json
import os
import logging
from pathlib import Path
from dataclasses import dataclass, field
from collections import defaultdict

# Module-level logging setup.
# NOTE(review): basicConfig runs at import time, so merely importing this module
# configures the root logger at INFO (it is a no-op if the root logger already
# has handlers).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("AdaptiveBacktesting")

# Optional project dependencies: the module still imports when they are missing;
# the *_AVAILABLE flags record which ones loaded (no code in this file reads them).
try:
    from ai_predictor import AIPredictor, PatternItem
    AI_PREDICTOR_AVAILABLE = True
except ImportError:
    AI_PREDICTOR_AVAILABLE = False
    PatternItem = dict  # Fallback to dict if not available
    # NOTE(review): AdaptiveBacktester.open_trade reads signal.symbol / .pattern /
    # .score as attributes, which a plain dict does not support; this fallback
    # mainly keeps the PatternItem type annotations resolvable.
    logger.warning("⚠️ AI Predictor non disponible")

try:
    # BinanceClient is not referenced elsewhere in this file.
    from binance_api import BinanceClient
    BINANCE_AVAILABLE = True
except ImportError:
    BINANCE_AVAILABLE = False
    logger.warning("⚠️ Binance Client non disponible")


@dataclass
class BacktestTrade:
    """A single simulated trade in the backtest.

    Prices are quote-currency prices (USDT in this file) and ``quantity`` is in
    base units. ``fees_pct`` is a percentage (``0.1`` means 0.1%) deducted once,
    on the entry notional, when the trade is closed.
    """
    symbol: str
    entry_time: datetime
    entry_price: float
    exit_time: Optional[datetime] = None
    exit_price: Optional[float] = None
    quantity: float = 0.0
    side: str = "BUY"  # "BUY" = long; "SELL" = short (P&L sign is inverted)
    pattern: str = "UNKNOWN"
    score: float = 0.0

    stop_loss: float = 0.0
    take_profit: float = 0.0

    pnl: float = 0.0
    pnl_pct: float = 0.0
    fees_pct: float = 0.1  # 0.1% Binance fees

    status: str = "OPEN"  # OPEN, TP, SL, TIMEOUT (the backtester also uses END)
    hold_time_hours: float = 0.0

    def close(self, exit_price: float, exit_time: datetime, reason: str):
        """Close the trade and compute its net P&L and holding time.

        Args:
            exit_price: Price at which the position is closed.
            exit_time: Time of the exit; must be comparable with ``entry_time``.
            reason: Exit reason stored in ``status`` (e.g. "TP", "SL").
        """
        self.exit_price = exit_price
        self.exit_time = exit_time
        self.status = reason

        # Signed per-unit move: positive means the trade made money. A short
        # ("SELL") profits when the price goes down.
        direction = -1.0 if self.side.upper() == "SELL" else 1.0
        price_diff = direction * (exit_price - self.entry_price)

        self.pnl_pct = (price_diff / self.entry_price * 100) - self.fees_pct
        self.pnl = self.quantity * price_diff - (self.quantity * self.entry_price * self.fees_pct / 100)

        self.hold_time_hours = (exit_time - self.entry_time).total_seconds() / 3600

    def to_dict(self) -> Dict:
        """Return a JSON-serialisable summary (P&L rounded to 2 decimals)."""
        return {
            'symbol': self.symbol,
            'entry_time': self.entry_time.isoformat(),
            'entry_price': self.entry_price,
            'exit_time': self.exit_time.isoformat() if self.exit_time else None,
            'exit_price': self.exit_price,
            'quantity': self.quantity,
            'side': self.side,
            'pattern': self.pattern,
            'score': self.score,
            'pnl': round(self.pnl, 2),
            'pnl_pct': round(self.pnl_pct, 2),
            'status': self.status,
            'hold_time_hours': round(self.hold_time_hours, 1)
        }


@dataclass
class BacktestMetrics:
    """Aggregate performance metrics for a backtest.

    Notes:
        - ``win_rate`` is stored as a fraction in [0, 1]; ``to_dict`` reports
          it as a percentage.
        - Monetary fields share the unit of ``trade.pnl``.
        - ``max_drawdown_pct`` is relative to the peak *cumulative P&L*, not to
          capital; it is 0 when cumulative P&L never rises above 0.
        - Sharpe and Sortino are simplified: computed on per-trade ``pnl_pct``
          and annualised with sqrt(252) as if each trade were one trading day.
    """
    total_trades: int = 0
    winning_trades: int = 0
    losing_trades: int = 0

    win_rate: float = 0.0

    total_profit: float = 0.0
    total_loss: float = 0.0
    net_profit: float = 0.0

    avg_win: float = 0.0
    avg_loss: float = 0.0
    profit_factor: float = 0.0

    max_drawdown: float = 0.0
    max_drawdown_pct: float = 0.0

    sharpe_ratio: float = 0.0
    sortino_ratio: float = 0.0

    avg_hold_time_hours: float = 0.0

    retrains_count: int = 0

    def calculate(self, trades: List["BacktestTrade"], retrains: int = 0):
        """Compute every metric from a list of closed trades.

        Args:
            trades: Closed trades exposing ``pnl``, ``pnl_pct`` and
                ``hold_time_hours``. Drawdown follows the list order.
            retrains: Number of retrainings performed during the run.
        """
        # Record retrains even when no trade was taken: a run can retrain
        # without ever trading, and that count must not be lost.
        self.retrains_count = retrains
        if not trades:
            return

        self.total_trades = len(trades)

        # Break-even trades (pnl == 0) count as losses.
        wins = [t for t in trades if t.pnl > 0]
        losses = [t for t in trades if t.pnl <= 0]

        self.winning_trades = len(wins)
        self.losing_trades = len(losses)
        self.win_rate = self.winning_trades / self.total_trades if self.total_trades > 0 else 0

        self.total_profit = sum(t.pnl for t in wins)
        self.total_loss = abs(sum(t.pnl for t in losses))
        self.net_profit = self.total_profit - self.total_loss

        self.avg_win = self.total_profit / len(wins) if wins else 0
        self.avg_loss = self.total_loss / len(losses) if losses else 0
        self.profit_factor = self.total_profit / self.total_loss if self.total_loss > 0 else 0

        # Drawdown on the cumulative P&L curve (peak starts at 0).
        cumulative_pnl = 0
        peak = 0
        max_dd = 0

        for trade in trades:
            cumulative_pnl += trade.pnl
            if cumulative_pnl > peak:
                peak = cumulative_pnl
            drawdown = peak - cumulative_pnl
            if drawdown > max_dd:
                max_dd = drawdown

        self.max_drawdown = max_dd
        self.max_drawdown_pct = (max_dd / peak * 100) if peak > 0 else 0

        # Sharpe ratio (simplified, per-trade returns)
        returns = [t.pnl_pct for t in trades]
        if len(returns) > 1:
            avg_return = np.mean(returns)
            std_return = np.std(returns)
            self.sharpe_ratio = (avg_return / std_return) * np.sqrt(252) if std_return > 0 else 0

            # Sortino: downside deviation w.r.t. a 0% target, averaged over ALL
            # trades (not the std of the losing trades alone, which is 0 when
            # every loss has the same size).
            downside_dev = np.sqrt(np.mean([min(r, 0.0) ** 2 for r in returns]))
            self.sortino_ratio = (avg_return / downside_dev) * np.sqrt(252) if downside_dev > 0 else 0

        # Average holding time; trades with a zero holding time are ignored and
        # an empty selection yields 0 instead of NaN.
        hold_times = [t.hold_time_hours for t in trades if t.hold_time_hours > 0]
        self.avg_hold_time_hours = float(np.mean(hold_times)) if hold_times else 0.0

    def to_dict(self) -> Dict:
        """Return the metrics rounded for reporting (``win_rate`` in %)."""
        return {
            'total_trades': self.total_trades,
            'winning_trades': self.winning_trades,
            'losing_trades': self.losing_trades,
            'win_rate': round(self.win_rate * 100, 2),
            'net_profit': round(self.net_profit, 2),
            'total_profit': round(self.total_profit, 2),
            'total_loss': round(self.total_loss, 2),
            'avg_win': round(self.avg_win, 2),
            'avg_loss': round(self.avg_loss, 2),
            'profit_factor': round(self.profit_factor, 2),
            'max_drawdown': round(self.max_drawdown, 2),
            'max_drawdown_pct': round(self.max_drawdown_pct, 2),
            'sharpe_ratio': round(self.sharpe_ratio, 2),
            'sortino_ratio': round(self.sortino_ratio, 2),
            'avg_hold_time_hours': round(self.avg_hold_time_hours, 1),
            'retrains_count': self.retrains_count
        }


class AdaptiveBacktester:
    """
    Backtester with simulated periodic (adaptive) retraining.

    Mimics the live bot's behaviour:
    1. Start with the initial model
    2. Trade for N days
    3. Retrain on the accumulated (already closed) trades
    4. Repeat

    Notes:
        - Positions are long-only: TP is set above and SL below the entry
          price, and ``check_exits`` tests them accordingly.
        - ``current_capital`` is debited ``position_size`` when a position is
          opened and credited the position cost plus its net P&L (after fees)
          when it is closed, so it stays consistent with ``BacktestTrade.pnl``.
        - Signal generation (``get_signals``) and daily price loading are still
          stubs, so ``run`` does not open any trade yet.
    """

    def __init__(self,
                 initial_capital: float = 10000.0,
                 max_positions: int = 10,
                 position_size_usdt: float = 100.0,
                 retrain_interval_days: int = 7,
                 max_hold_hours: int = 48):
        """
        Args:
            initial_capital: Starting capital in USDT.
            max_positions: Maximum number of simultaneously open positions.
            position_size_usdt: Notional (USDT) allocated to each position.
            retrain_interval_days: Days between two (simulated) retrainings.
            max_hold_hours: Maximum holding time before a TIMEOUT exit (hours).
        """
        self.initial_capital = initial_capital
        self.current_capital = initial_capital
        self.max_positions = max_positions
        self.position_size = position_size_usdt
        self.retrain_interval = timedelta(days=retrain_interval_days)
        self.max_hold_time = timedelta(hours=max_hold_hours)

        self.open_positions: List[BacktestTrade] = []
        self.closed_trades: List[BacktestTrade] = []

        self.last_retrain = None
        self.retrains_count = 0

        logger.info("✅ Adaptive Backtester initialisé")
        logger.info(f"   • Capital initial: {initial_capital} USDT")
        logger.info(f"   • Max positions: {max_positions}")
        logger.info(f"   • Taille position: {position_size_usdt} USDT")
        logger.info(f"   • Retrain interval: {retrain_interval_days} jours")

    def should_retrain(self, current_time: datetime) -> bool:
        """Return True once ``retrain_interval`` has elapsed since the last retrain."""
        if self.last_retrain is None:
            return False  # No retrain before a reference time has been set

        return (current_time - self.last_retrain) >= self.retrain_interval

    def retrain_models(self, current_time: datetime):
        """
        Simulate retraining the models at ``current_time``.

        Only trades already *closed* before ``current_time`` are used, to avoid
        look-ahead bias. At least 20 such trades are required; with fewer, the
        retrain is skipped but the schedule still advances, so the next attempt
        happens one ``retrain_interval`` later instead of on every tick.
        """
        logger.info(f"\n🔄 RETRAINING @ {current_time.strftime('%Y-%m-%d %H:%M')}")

        # Filter on exit_time, not entry_time: a trade opened before but closed
        # after current_time has an outcome that was not known yet.
        past_trades = [
            t for t in self.closed_trades
            if t.exit_time is not None and t.exit_time < current_time
        ]

        if len(past_trades) < 20:
            logger.warning(f"⚠️ Pas assez de trades pour retrain: {len(past_trades)}")
            self.last_retrain = current_time
            return

        # Per-pattern win/loss counts over the 100 most recently closed trades.
        pattern_stats = defaultdict(lambda: {'wins': 0, 'losses': 0, 'total': 0})

        for trade in past_trades[-100:]:
            stats = pattern_stats[trade.pattern]
            stats['total'] += 1
            if trade.pnl > 0:
                stats['wins'] += 1
            else:
                stats['losses'] += 1

        for pattern, stats in pattern_stats.items():
            if stats['total'] > 0:
                wr = stats['wins'] / stats['total']
                logger.info(f"   • {pattern}: {wr*100:.1f}% ({stats['total']} trades)")

        # Note: actual (GPU) model retraining would happen here; for now only
        # the per-pattern performance analysis is simulated.

        self.last_retrain = current_time
        self.retrains_count += 1

        logger.info(f"✅ Retrain #{self.retrains_count} terminé")

    def get_signals(self, current_time: datetime, symbols: List[str]) -> List[PatternItem]:
        """
        Return the trading signals for a given timestamp.

        Stub: a real backtest would load historical data here; for now no
        signal is ever produced.
        """
        # TODO: integrate with AIPredictor for real signals
        return []

    def open_trade(self, signal: PatternItem, current_time: datetime, current_price: float):
        """Open a long position for ``signal`` if limits and capital allow it.

        The position uses a -2% stop-loss and a +1.5% take-profit. Nothing is
        opened when ``max_positions`` is reached, capital is insufficient, or
        ``current_price`` is not strictly positive.
        """
        if len(self.open_positions) >= self.max_positions:
            return

        if self.current_capital < self.position_size:
            return

        if current_price <= 0:
            # Guards the quantity division below against bad price data.
            logger.warning(f"⚠️ Prix invalide pour {signal.symbol}: {current_price}")
            return

        quantity = self.position_size / current_price

        trade = BacktestTrade(
            symbol=signal.symbol,
            entry_time=current_time,
            entry_price=current_price,
            quantity=quantity,
            pattern=signal.pattern,
            score=signal.score,
            stop_loss=current_price * 0.98,  # -2% SL
            take_profit=current_price * 1.015  # +1.5% TP
        )

        self.open_positions.append(trade)
        self.current_capital -= self.position_size

        logger.debug(f"   ✅ {signal.symbol} ouvert @ {current_price:.6f}")

    def _close_position(self, trade: BacktestTrade, exit_price: float,
                        exit_time: datetime, reason: str):
        """Close ``trade``, move it to ``closed_trades`` and credit cost + net P&L."""
        trade.close(exit_price, exit_time, reason)
        self.open_positions.remove(trade)
        self.closed_trades.append(trade)
        # quantity * entry_price is what was debited at open; pnl is already
        # net of fees, so capital stays consistent with the reported P&L.
        self.current_capital += trade.quantity * trade.entry_price + trade.pnl

    def check_exits(self, current_time: datetime, prices: Dict[str, float]):
        """Apply TP, then SL, then timeout exits to open positions.

        Args:
            current_time: Current simulation time.
            prices: Latest price per symbol; symbols missing from it are skipped.
        """
        # Iterate over a copy because _close_position removes from the list.
        for trade in list(self.open_positions):
            current_price = prices.get(trade.symbol)
            if current_price is None:
                continue

            if current_price >= trade.take_profit:
                self._close_position(trade, current_price, current_time, "TP")
                logger.debug(f"   🟢 {trade.symbol} TP @ {current_price:.6f} (+{trade.pnl_pct:.2f}%)")
            elif current_price <= trade.stop_loss:
                self._close_position(trade, current_price, current_time, "SL")
                logger.debug(f"   🔴 {trade.symbol} SL @ {current_price:.6f} ({trade.pnl_pct:.2f}%)")
            elif (current_time - trade.entry_time) >= self.max_hold_time:
                self._close_position(trade, current_price, current_time, "TIMEOUT")
                logger.debug(f"   ⏰ {trade.symbol} TIMEOUT @ {current_price:.6f} ({trade.pnl_pct:.2f}%)")

    def run(self,
            start_date: datetime,
            end_date: datetime,
            symbols: List[str]) -> BacktestMetrics:
        """
        Run the adaptive backtest day by day from ``start_date`` to ``end_date``.

        Args:
            start_date: First simulated day (also the initial retrain reference).
            end_date: Last simulated day (inclusive).
            symbols: Symbols to trade.

        Returns:
            Performance metrics over all closed trades.
        """
        logger.info("\n" + "="*60)
        logger.info("🚀 ADAPTIVE BACKTEST DÉMARRÉ")
        logger.info("="*60)
        logger.info(f"Période: {start_date.strftime('%Y-%m-%d')} → {end_date.strftime('%Y-%m-%d')}")
        logger.info(f"Symboles: {len(symbols)}")
        logger.info(f"Capital initial: {self.initial_capital} USDT")

        # Reset state so the same instance can be run several times.
        self.current_capital = self.initial_capital
        self.open_positions = []
        self.closed_trades = []
        self.last_retrain = start_date
        self.retrains_count = 0

        current_date = start_date
        days_count = 0

        while current_date <= end_date:
            days_count += 1

            if self.should_retrain(current_date):
                self.retrain_models(current_date)

            # TODO: load real market data for this day; trading is not simulated yet

            # Progress log every 7 days
            if days_count % 7 == 0:
                logger.info(f"📅 {current_date.strftime('%Y-%m-%d')} | Trades: {len(self.closed_trades)} | Capital: {self.current_capital:.0f} USDT")

            current_date += timedelta(days=1)

        # Close remaining positions. No price data is loaded yet, so entry_price
        # is used as the exit (P&L = -fees). Going through _close_position also
        # returns their capital and empties open_positions.
        for trade in list(self.open_positions):
            self._close_position(trade, trade.entry_price, end_date, "END")

        metrics = BacktestMetrics()
        metrics.calculate(self.closed_trades, self.retrains_count)

        logger.info("\n" + "="*60)
        logger.info("✅ BACKTEST TERMINÉ")
        logger.info("="*60)
        self._print_metrics(metrics)

        return metrics

    def _print_metrics(self, metrics: BacktestMetrics):
        """Log a human-readable summary of ``metrics``."""
        logger.info(f"\n📊 RÉSULTATS:")
        logger.info(f"   Trades: {metrics.total_trades} (W:{metrics.winning_trades} L:{metrics.losing_trades})")
        # win_rate is stored as a fraction; convert to % for display.
        logger.info(f"   Win Rate: {metrics.win_rate * 100:.2f}%")
        logger.info(f"   Profit Net: {metrics.net_profit:.2f} USDT")
        logger.info(f"   Profit Factor: {metrics.profit_factor:.2f}")
        logger.info(f"   Max Drawdown: {metrics.max_drawdown:.2f} USDT ({metrics.max_drawdown_pct:.2f}%)")
        logger.info(f"   Sharpe Ratio: {metrics.sharpe_ratio:.2f}")
        logger.info(f"   Sortino Ratio: {metrics.sortino_ratio:.2f}")
        logger.info(f"   Avg Hold Time: {metrics.avg_hold_time_hours:.1f}h")
        logger.info(f"   Retrains: {metrics.retrains_count}")

    def save_results(self, metrics: BacktestMetrics, filename: str = "backtest_results.json"):
        """Write config, metrics and the last 100 closed trades to a JSON file."""
        results = {
            'config': {
                'initial_capital': self.initial_capital,
                'max_positions': self.max_positions,
                'position_size': self.position_size,
                'retrain_interval_days': self.retrain_interval.days,
                'max_hold_hours': self.max_hold_time.total_seconds() / 3600
            },
            'metrics': metrics.to_dict(),
            'trades': [t.to_dict() for t in self.closed_trades[-100:]],  # last 100
            'timestamp': datetime.now().isoformat()
        }

        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2)

        logger.info(f"💾 Résultats sauvegardés: {filename}")


def compare_adaptive_vs_classic(start_date: datetime,
                                  end_date: datetime,
                                  symbols: List[str]) -> Dict:
    """
    Compare an adaptive backtest (periodic retraining) with a classic one.

    The "classic" run uses a retrain interval longer than the whole period,
    so it never retrains.

    Args:
        start_date: First simulated day.
        end_date: Last simulated day (inclusive).
        symbols: Symbols to trade.

    Returns:
        Dict with ``'adaptive'`` and ``'classic'`` (each
        ``BacktestMetrics.to_dict()``, win rate in %) and ``'difference'``
        (adaptive minus classic): ``win_rate_delta`` in percentage points,
        ``profit_delta`` in USDT and ``sharpe_delta``.
    """
    logger.info("\n" + "="*60)
    logger.info("🔬 COMPARAISON: ADAPTIVE vs CLASSIC")
    logger.info("="*60)

    # Adaptive backtest (realistic)
    logger.info("\n1️⃣ ADAPTIVE BACKTEST (avec retraining)")
    adaptive_backtester = AdaptiveBacktester(retrain_interval_days=7)
    adaptive_metrics = adaptive_backtester.run(start_date, end_date, symbols)

    # Classic backtest (optimistic, no retraining). A fixed 999-day interval
    # would still retrain on periods longer than that, so make it exceed the
    # whole backtest window.
    logger.info("\n2️⃣ CLASSIC BACKTEST (sans retraining)")
    no_retrain_days = max(999, (end_date - start_date).days + 1)
    classic_backtester = AdaptiveBacktester(retrain_interval_days=no_retrain_days)
    classic_metrics = classic_backtester.run(start_date, end_date, symbols)

    logger.info("\n" + "="*60)
    logger.info("📊 COMPARAISON RÉSULTATS")
    logger.info("="*60)

    # BacktestMetrics.win_rate is a fraction; express it in %, like to_dict().
    adaptive_wr = adaptive_metrics.win_rate * 100
    classic_wr = classic_metrics.win_rate * 100

    comparison = {
        'adaptive': adaptive_metrics.to_dict(),
        'classic': classic_metrics.to_dict(),
        'difference': {
            'win_rate_delta': adaptive_wr - classic_wr,
            'profit_delta': adaptive_metrics.net_profit - classic_metrics.net_profit,
            'sharpe_delta': adaptive_metrics.sharpe_ratio - classic_metrics.sharpe_ratio
        }
    }
    diff = comparison['difference']

    logger.info(f"\n   Win Rate: Adaptive {adaptive_wr:.1f}% vs Classic {classic_wr:.1f}% ({diff['win_rate_delta']:+.1f}%)")
    logger.info(f"   Profit: Adaptive {adaptive_metrics.net_profit:.0f} vs Classic {classic_metrics.net_profit:.0f} ({diff['profit_delta']:+.0f} USDT)")
    logger.info(f"   Sharpe: Adaptive {adaptive_metrics.sharpe_ratio:.2f} vs Classic {classic_metrics.sharpe_ratio:.2f} ({diff['sharpe_delta']:+.2f})")

    return comparison


if __name__ == "__main__":
    # Smoke test of the adaptive backtester over the last 30 days.
    print("="*60)
    print("🧪 TEST ADAPTIVE BACKTESTER")
    print("="*60)

    # Read the clock once so start and end are exactly 30 days apart.
    now = datetime.now()
    start = now - timedelta(days=30)
    end = now
    symbols = ["BTCUSDT", "ETHUSDT", "BNBUSDT", "SOLUSDT"]

    backtester = AdaptiveBacktester(
        initial_capital=10000,
        retrain_interval_days=7
    )

    metrics = backtester.run(start, end, symbols)
    backtester.save_results(metrics)
