"""
Système de Backtesting et Walk-Forward Testing
Inspiré de: Freqtrade et Intelligent Trading Bot
"""
import json
import logging
from dataclasses import dataclass
from datetime import datetime
from typing import Dict, List, Optional, Tuple, Union

import numpy as np


@dataclass
class Trade:
    """A single position: its entry fill, optional exit fill, and result.

    Only ``entry_time`` and ``entry_price`` are required. Every other field
    has a default, so a trade can be created before its exit is known.
    """

    # Entry fill.
    entry_time: datetime
    entry_price: float

    # Exit fill; defaults to None (not yet exited).
    exit_time: Optional[datetime] = None
    exit_price: Optional[float] = None

    # Size and direction.
    quantity: float = 0.0
    side: str = "BUY"  # "BUY" or "SELL"

    # Outcome fields; default to zero / empty.
    pnl: float = 0.0
    pnl_percent: float = 0.0
    exit_reason: str = ""


class BacktestEngine:
    """
    Backtesting engine used to validate trading strategies.

    Keeps a cash ``balance``, lists of open and closed trades, and
    equity / drawdown curves appended on each ``update_equity`` call.
    A commission of ``commission_rate`` times the traded notional is
    charged on every entry and every exit.
    """

    def __init__(self, initial_balance: float = 10000.0,
                 commission_rate: float = 0.001,  # 0.1% per trade
                 max_positions: int = 5):
        """
        Args:
            initial_balance: Starting cash.
            commission_rate: Fraction of the notional charged per fill.
            max_positions: Maximum number of simultaneously open trades; also
                the divisor used to size an order when no quantity is given.
        """
        self.initial_balance = initial_balance
        self.balance = initial_balance
        self.commission_rate = commission_rate
        self.max_positions = max_positions

        self.open_trades: List["Trade"] = []
        self.closed_trades: List["Trade"] = []

        self.equity_curve = []
        self.drawdown_curve = []
        # Highest total equity observed so far; the drawdown reference point.
        self.peak_balance = initial_balance

        # Cash actually debited to open each trade (entry commission included),
        # keyed by id() of the trade object while it is open. Used on exit so
        # that a trade's PnL reflects both commissions.
        self._cost_basis: Dict[int, float] = {}

    def execute_buy(self, price: float, timestamp: datetime, quantity: Optional[float] = None):
        """Open a long position.

        Args:
            price: Fill price.
            timestamp: Entry time recorded on the trade.
            quantity: Requested quantity. When None, the order uses
                ``balance / max_positions`` of cash.

        Returns:
            The new Trade, or None when the order is rejected (max positions
            reached, non-positive price or size, or insufficient balance).

        The entry commission is carved out of the position notional. The
        balance is debited exactly ``position_size``, and the trade receives
        ``(position_size - commission) / price`` units.
        """
        if len(self.open_trades) >= self.max_positions:
            logging.debug(f"⚠️ Max positions atteint ({self.max_positions})")
            return None

        if price <= 0:
            # Would otherwise divide by zero (or buy a negative quantity).
            logging.debug(f"⚠️ Prix invalide: {price}")
            return None

        # Position size (notional, commission included)
        if quantity is None:
            position_size = self.balance / self.max_positions
        else:
            position_size = quantity * price

        if position_size <= 0:
            logging.debug("⚠️ Taille de position nulle ou négative")
            return None

        if position_size > self.balance:
            logging.debug("⚠️ Solde insuffisant")
            return None

        commission = position_size * self.commission_rate
        actual_quantity = (position_size - commission) / price

        # The commission is already taken out of position_size above, so only
        # position_size leaves the balance (it used to be charged twice).
        self.balance -= position_size

        trade = Trade(
            entry_time=timestamp,
            entry_price=price,
            quantity=actual_quantity,
            side="BUY"
        )

        self.open_trades.append(trade)
        self._cost_basis[id(trade)] = position_size
        logging.debug(f"✅ BUY: {actual_quantity:.6f} @ {price:.6f}, Solde: {self.balance:.2f}")

        return trade

    def _open_index(self, trade) -> Optional[int]:
        """Return the index of *trade* in open_trades by identity, or None.

        Identity (not ==) is used because Trade is a dataclass with value
        equality: two distinct fills with identical fields compare equal.
        """
        for index, candidate in enumerate(self.open_trades):
            if candidate is trade:
                return index
        return None

    def execute_sell(self, trade: "Trade", price: float, timestamp: datetime, reason: str = ""):
        """Close an open trade at *price*.

        The net proceeds (after exit commission) are credited to the balance.
        ``trade.pnl`` is computed against the cash spent at entry (commission
        included) when the trade was opened by ``execute_buy``. For other
        trades it falls back to ``quantity * entry_price``.

        Returns:
            The closed trade, or None if *trade* is not currently open.
        """
        index = self._open_index(trade)
        if index is None:
            return None

        gross_proceeds = trade.quantity * price
        commission = gross_proceeds * self.commission_rate
        net_proceeds = gross_proceeds - commission

        self.balance += net_proceeds

        cost_basis = self._cost_basis.pop(id(trade), trade.quantity * trade.entry_price)
        trade.pnl = net_proceeds - cost_basis
        trade.pnl_percent = (trade.pnl / cost_basis) * 100 if cost_basis else 0.0

        trade.exit_time = timestamp
        trade.exit_price = price
        trade.exit_reason = reason

        del self.open_trades[index]
        self.closed_trades.append(trade)

        logging.debug(f"✅ SELL: {trade.quantity:.6f} @ {price:.6f}, "
                      f"PnL: {trade.pnl:+.2f} ({trade.pnl_percent:+.2f}%)")

        return trade

    @staticmethod
    def _mark_price(trade, current_prices) -> float:
        """Price used to value *trade* in the equity curve.

        - A scalar price is applied to every open trade. This matches the
          single-price model of ``check_stop_loss_take_profit``.
        - A one-entry mapping is treated as the instrument's quote. Trade
          carries no symbol, so the single value is unambiguous.
        - Otherwise the original lookup (key ``trade.side``, falling back to
          the entry price) is kept for backward compatibility.
        """
        if isinstance(current_prices, (int, float, np.number)):
            return float(current_prices)
        if len(current_prices) == 1:
            return next(iter(current_prices.values()))
        return current_prices.get(trade.side, trade.entry_price)

    def update_equity(self, current_prices: Union[float, Dict[str, float]], timestamp: datetime):
        """Append the current total equity and drawdown to the curves.

        Args:
            current_prices: A mark price, or a mapping of prices (see
                ``_mark_price`` for how a mapping is resolved).
            timestamp: Timestamp stored with both curve points.
        """
        open_positions_value = sum(
            trade.quantity * self._mark_price(trade, current_prices)
            for trade in self.open_trades
        )

        total_equity = self.balance + open_positions_value
        self.equity_curve.append({
            'timestamp': timestamp,
            'equity': total_equity
        })

        if total_equity > self.peak_balance:
            self.peak_balance = total_equity

        # Drawdown in percent of the running peak; undefined for a non-positive peak.
        if self.peak_balance > 0:
            drawdown = ((self.peak_balance - total_equity) / self.peak_balance) * 100
        else:
            drawdown = 0.0
        self.drawdown_curve.append({
            'timestamp': timestamp,
            'drawdown': drawdown
        })

    def check_stop_loss_take_profit(self, current_price: float, timestamp: datetime,
                                    stop_loss_pct: float = 2.0, take_profit_pct: float = 4.0):
        """Close BUY trades whose move from entry hits the stop-loss or take-profit.

        Args:
            current_price: Price used for both the check and the exit fill.
            timestamp: Exit time recorded on closed trades.
            stop_loss_pct: Loss in percent of entry that triggers "Stop Loss".
            take_profit_pct: Gain in percent of entry that triggers "Take Profit".
        """
        # Iterate over a copy: execute_sell removes from open_trades.
        for trade in list(self.open_trades):
            if trade.side != "BUY" or not trade.entry_price:
                continue

            loss_pct = ((trade.entry_price - current_price) / trade.entry_price) * 100
            profit_pct = ((current_price - trade.entry_price) / trade.entry_price) * 100

            # Stop-loss is checked first; a trade is closed at most once.
            if loss_pct >= stop_loss_pct:
                self.execute_sell(trade, current_price, timestamp, "Stop Loss")
            elif profit_pct >= take_profit_pct:
                self.execute_sell(trade, current_price, timestamp, "Take Profit")

    def get_statistics(self) -> Dict:
        """Compute performance statistics.

        Always returns the full set of keys, even when no trade has been
        closed yet (the trade-based figures are then 0).
        Open trades are valued at their entry price for ``final_balance`` and
        ``roi``. ``profit_factor`` is 0 when there is no losing trade.
        """
        trades = self.closed_trades
        total_trades = len(trades)
        winning = [t.pnl for t in trades if t.pnl > 0]
        losing = [t.pnl for t in trades if t.pnl < 0]

        total_pnl = sum(t.pnl for t in trades)
        gross_loss = sum(losing)

        max_drawdown = max((d['drawdown'] for d in self.drawdown_curve), default=0)

        final_equity = self.balance + sum(t.quantity * t.entry_price for t in self.open_trades)
        if self.initial_balance:
            roi = ((final_equity - self.initial_balance) / self.initial_balance) * 100
        else:
            roi = 0.0

        return {
            'total_trades': total_trades,
            'winning_trades': len(winning),
            'losing_trades': len(losing),
            'win_rate': (len(winning) / total_trades) * 100 if total_trades else 0,
            'total_pnl': total_pnl,
            'avg_pnl': total_pnl / total_trades if total_trades else 0,
            'avg_win': np.mean(winning) if winning else 0,
            'avg_loss': np.mean(losing) if losing else 0,
            'profit_factor': abs(sum(winning) / gross_loss) if losing else 0,
            'max_drawdown': max_drawdown,
            'roi': roi,
            'final_balance': final_equity
        }

    def print_summary(self):
        """Print a formatted summary of ``get_statistics()`` to stdout."""
        stats = self.get_statistics()

        print("\n" + "="*60)
        print("📊 RÉSULTATS DU BACKTEST")
        print("="*60)
        print(f"Solde initial:     {self.initial_balance:>12,.2f} USDT")
        print(f"Solde final:       {stats['final_balance']:>12,.2f} USDT")
        print(f"ROI:               {stats['roi']:>12,.2f}%")
        print(f"PnL Total:         {stats['total_pnl']:>12,.2f} USDT")
        print("-"*60)
        print(f"Trades totaux:     {stats['total_trades']:>12}")
        print(f"Trades gagnants:   {stats['winning_trades']:>12}")
        print(f"Trades perdants:   {stats['losing_trades']:>12}")
        print(f"Win Rate:          {stats['win_rate']:>12,.1f}%")
        print("-"*60)
        print(f"Gain moyen:        {stats['avg_win']:>12,.2f} USDT")
        print(f"Perte moyenne:     {stats['avg_loss']:>12,.2f} USDT")
        print(f"Profit Factor:     {stats['profit_factor']:>12,.2f}")
        print(f"Max Drawdown:      {stats['max_drawdown']:>12,.2f}%")
        print("="*60 + "\n")

    def export_results(self, filepath: str):
        """Write statistics, closed trades and the equity curve to *filepath* as JSON.

        Timestamps are serialized with ``isoformat()``.
        """
        data = {
            'statistics': self.get_statistics(),
            'trades': [
                {
                    'entry_time': t.entry_time.isoformat() if t.entry_time else None,
                    'entry_price': t.entry_price,
                    'exit_time': t.exit_time.isoformat() if t.exit_time else None,
                    'exit_price': t.exit_price,
                    'quantity': t.quantity,
                    'pnl': t.pnl,
                    'pnl_percent': t.pnl_percent,
                    'exit_reason': t.exit_reason
                }
                for t in self.closed_trades
            ],
            'equity_curve': [
                {
                    'timestamp': e['timestamp'].isoformat(),
                    'equity': e['equity']
                }
                for e in self.equity_curve
            ]
        }

        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2)

        logging.info(f"💾 Résultats exportés: {filepath}")


class WalkForwardTester:
    """
    Walk-forward testing on successive train/test windows.

    Each split takes ``train_window`` consecutive samples for training,
    immediately followed by ``test_window`` samples for testing. Successive
    splits slide forward by ``test_window`` samples.
    """

    def __init__(self, train_window: int = 100, test_window: int = 20):
        self.train_window = train_window
        self.test_window = test_window
        # One entry per split of the most recent run: {'split': i, 'stats': {...}}.
        self.results = []

    def split_data(self, data: List, n_splits: int = 5) -> List[Tuple]:
        """
        Cut *data* into at most *n_splits* successive (train, test) windows.

        Splitting stops early once a test window would run past the end of
        *data*, so fewer than *n_splits* pairs may be returned.
        """
        total_length = len(data)
        splits = []

        for i in range(n_splits):
            train_start = i * self.test_window
            train_end = train_start + self.train_window
            # The test window starts right where the train window ends.
            test_end = train_end + self.test_window

            if test_end > total_length:
                break

            splits.append((data[train_start:train_end], data[train_end:test_end]))

        return splits

    def run_walk_forward(self, data: List, strategy_func, n_splits: int = 5, **kwargs):
        """
        Run the walk-forward loop over the splits of *data*.

        Args:
            data: Sequence to split (see ``split_data``).
            strategy_func: Strategy to train/evaluate. It is not called yet;
                see the TODOs below.
            n_splits: Maximum number of splits (previously fixed at 5).
            **kwargs: Forwarded to ``BacktestEngine`` for each split.

        Results from any previous run are discarded. Otherwise the summary
        would mix runs and repeat split numbers.
        """
        splits = self.split_data(data, n_splits)
        self.results = []

        for i, (train_data, test_data) in enumerate(splits):
            logging.info(f"🔄 Walk-Forward {i+1}/{len(splits)}")

            # TODO: train strategy_func on train_data (strategy-specific).

            backtest = BacktestEngine(**kwargs)

            # TODO: run strategy_func on test_data against `backtest`.

            stats = backtest.get_statistics()
            self.results.append({
                'split': i,
                'stats': stats
            })

        self.print_walk_forward_summary()

    def print_walk_forward_summary(self):
        """Print per-split ROI / win rate / trade count and their averages.

        Missing statistics default to 0. A backtest with no closed trade may
        report only a subset of keys (e.g. no 'roi').
        """
        if not self.results:
            return

        print("\n" + "="*60)
        print("🔄 RÉSULTATS WALK-FORWARD TESTING")
        print("="*60)

        rois = []
        win_rates = []
        for result in self.results:
            stats = result['stats']
            roi = stats.get('roi', 0.0)
            win_rate = stats.get('win_rate', 0.0)
            rois.append(roi)
            win_rates.append(win_rate)
            print(f"\nSplit {result['split']+1}:")
            print(f"  ROI: {roi:+.2f}% | Win Rate: {win_rate:.1f}% | "
                  f"Trades: {stats.get('total_trades', 0)}")

        avg_roi = np.mean(rois)
        avg_win_rate = np.mean(win_rates)

        print("\n" + "-"*60)
        print(f"Moyenne ROI:       {avg_roi:+.2f}%")
        print(f"Moyenne Win Rate:  {avg_win_rate:.1f}%")
        print("="*60 + "\n")
