"""
🧠 AI SELF-OPTIMIZER - Analyse automatique et optimisation du bot

Fonctionnalités:
1. Analyse qualité des signaux IA sur historique
2. Détection des fins de cycle manquées
3. Analyse des ventes tardives/manquées
4. Calcul des seuils optimaux (EMA, BB, KC, RSI, momentum)
5. Comparaison avec standards financiers
6. Recommandations d'optimisation automatiques

Auteur: Trading Bot AI System
Date: 2026-01-15
"""

import json
import logging
from datetime import datetime, timedelta
from collections import defaultdict
from typing import Dict, List, Tuple, Optional
import numpy as np
from pathlib import Path

# Configure the root logger at INFO level and expose this module's named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)


class AISelfOptimizer:
    """Analyse recent trading logs and suggest parameter optimisations.

    The optimizer reads two JSON-Lines files (``trade_logs/trades_log.jsonl``
    and ``trade_logs/signals_log.jsonl``), keeps the records whose
    ``timestamp`` falls within the last ``lookback_hours`` hours, and builds a
    report covering global P&L metrics, per-pattern signal quality, missed
    cycle tops, exit reasons and entry-indicator thresholds.

    Record fields read by this class: ``timestamp``, ``type``
    (``TRADE_OPEN`` / ``TRADE_CLOSE``), ``trade_id``, ``symbol``, ``pnl_pct``,
    ``max_profit_pct``, ``reason``, ``action``, ``bb_position``, ``rsi`` and
    ``momentum_3``. The exact log schema is defined by whatever writes the
    files (not visible here).
    """

    # Same object as the module-level ``logger`` (``logging.getLogger`` caches
    # loggers by name); held on the class so methods do not rely on a global.
    _logger = logging.getLogger(__name__)

    # Substrings searched, in this order, in a signal's ``reason`` to classify
    # it. The first match wins, so order matters.
    _SIGNAL_PATTERNS = (
        'POSSIBLE',
        'HIGH_SCORE_OVERRIDE',
        'CROSSOVER_IMMINENT',
        'SQUEEZE_WAITING',
        'CREUX_WAITING',
    )

    # Sort rank for recommendation priorities (lower = more urgent).
    _PRIORITY_ORDER = {'HIGH': 1, 'MEDIUM': 2, 'LOW': 3}

    def __init__(self, lookback_hours: int = 24):
        """Create an optimizer over a sliding time window.

        Args:
            lookback_hours: Size of the analysis window, in hours before now.
        """
        self.lookback_hours = lookback_hours
        self.trades_log_path = Path("trade_logs/trades_log.jsonl")
        self.signals_log_path = Path("trade_logs/signals_log.jsonl")

        # Reference thresholds ("financial standards") used as defaults and
        # comparison points by the threshold analysis and the recommendations.
        self.financial_standards = {
            'bb_position': {
                'buy_zone': (0.0, 0.4),      # Lower part of the bands: preferred entries
                'caution_zone': (0.4, 0.6),  # Middle of the bands
                'avoid_zone': (0.6, 1.0),    # Upper part of the bands: avoid entries
            },
            'rsi': {
                'oversold': 30,
                'buy_zone': (30, 45),
                'neutral': (45, 55),
                'sell_zone': (55, 70),
                'overbought': 70,
            },
            'momentum': {
                'strong_bullish': 0.005,    # > 0.5%
                'moderate_bullish': 0.002,  # > 0.2%
                'weak': 0.001,              # > 0.1%
            },
            'win_rate_target': 0.50,        # Minimum acceptable win rate (50%)
            'profit_factor_target': 1.5,    # Target gross-gain / gross-loss ratio
        }

    # ------------------------------------------------------------------ #
    # Entry point
    # ------------------------------------------------------------------ #

    def analyze_performance(self) -> Dict:
        """Run every analysis on the current lookback window.

        Returns:
            A dict (serialisable with ``json.dump``) with keys ``timestamp``,
            ``period_hours``, ``global_metrics``, ``signal_quality``,
            ``cycle_detection``, ``exit_timing``, ``optimal_thresholds`` and
            ``recommendations``.
        """
        self._logger.info(f"🔍 Analyse des performances sur {self.lookback_hours}h...")

        trades = self._load_recent_trades()
        signals = self._load_recent_signals()

        results = {
            'timestamp': datetime.now().isoformat(),
            'period_hours': self.lookback_hours,
            'global_metrics': self._calculate_global_metrics(trades),
            'signal_quality': self._analyze_signal_quality(trades, signals),
            'cycle_detection': self._analyze_cycle_detection(trades),
            'exit_timing': self._analyze_exit_timing(trades),
            'optimal_thresholds': self._calculate_optimal_thresholds(trades),
            'recommendations': [],
        }

        # Recommendations are derived from the analyses computed above.
        results['recommendations'] = self._generate_recommendations(results)

        return results

    # ------------------------------------------------------------------ #
    # Loading
    # ------------------------------------------------------------------ #

    @staticmethod
    def _parse_timestamp(value) -> Optional[datetime]:
        """Parse an ISO-8601 timestamp into a naive *local* datetime.

        Timezone-aware values (including a trailing ``Z``) are converted to
        local time and made naive so they compare correctly with
        ``datetime.now()``. Returns ``None`` for non-strings or unparsable text.
        """
        if not isinstance(value, str):
            return None
        text = value.strip()
        # datetime.fromisoformat only accepts a 'Z' suffix from Python 3.11.
        if text.endswith('Z'):
            text = text[:-1] + '+00:00'
        try:
            parsed = datetime.fromisoformat(text)
        except ValueError:
            return None
        if parsed.tzinfo is not None:
            parsed = parsed.astimezone().replace(tzinfo=None)
        return parsed

    def _load_recent_jsonl(self, path: Path) -> List[Dict]:
        """Load JSON-Lines records from *path* whose timestamp is in the window.

        Blank lines are ignored. Lines that are not JSON objects, or have a
        missing or unparsable ``timestamp``, are skipped and counted in a
        single warning. A missing file yields an empty list.
        """
        if not path.exists():
            return []

        cutoff_time = datetime.now() - timedelta(hours=self.lookback_hours)
        records = []
        skipped = 0

        with open(path, 'r', encoding='utf-8') as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    continue
                try:
                    record = json.loads(line)
                except json.JSONDecodeError:
                    skipped += 1
                    continue
                if not isinstance(record, dict):
                    skipped += 1
                    continue
                record_time = self._parse_timestamp(record.get('timestamp'))
                if record_time is None:
                    skipped += 1
                    continue
                if record_time >= cutoff_time:
                    records.append(record)

        if skipped:
            self._logger.warning(
                "%d ligne(s) ignorée(s) dans %s (JSON ou timestamp invalide)", skipped, path
            )
        return records

    def _load_recent_trades(self) -> List[Dict]:
        """Load trade records from ``self.trades_log_path`` within the window."""
        return self._load_recent_jsonl(self.trades_log_path)

    def _load_recent_signals(self) -> List[Dict]:
        """Load signal records from ``self.signals_log_path`` within the window."""
        return self._load_recent_jsonl(self.signals_log_path)

    # ------------------------------------------------------------------ #
    # Small shared helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _pnl(record: Dict) -> float:
        """Return ``record['pnl_pct']``, treating a missing or ``None`` value as 0."""
        value = record.get('pnl_pct')
        return value if value is not None else 0

    @staticmethod
    def _closed_trades(trades: List[Dict]) -> List[Dict]:
        """Return only the ``TRADE_CLOSE`` records, preserving order."""
        return [t for t in trades if t.get('type') == 'TRADE_CLOSE']

    @classmethod
    def _classify_signal_reason(cls, reason) -> str:
        """Map a signal ``reason`` string to its pattern name, or ``'OTHER'``."""
        text = str(reason)
        return next((p for p in cls._SIGNAL_PATTERNS if p in text), 'OTHER')

    # ------------------------------------------------------------------ #
    # Analyses
    # ------------------------------------------------------------------ #

    def _calculate_global_metrics(self, trades: List[Dict]) -> Dict:
        """Compute win rate, average P&L, profit factor and extremes.

        A trade with ``pnl_pct <= 0`` counts as a loss. ``profit_factor`` is
        gross gains / gross losses; it is ``float('inf')`` when there are gains
        but no losses, and 0 when there are neither. The empty-window result
        has exactly the same keys as the populated one.
        """
        closes = self._closed_trades(trades)

        if not closes:
            return {
                'total_trades': 0,
                'wins': 0,
                'losses': 0,
                'win_rate': 0,
                'avg_pnl': 0,
                'avg_win': 0,
                'avg_loss': 0,
                'profit_factor': 0,
                'max_drawdown': 0,
                'best_trade': 0,
            }

        pnls = [self._pnl(t) for t in closes]
        win_pnls = [p for p in pnls if p > 0]
        loss_pnls = [p for p in pnls if p <= 0]

        total_wins = sum(win_pnls)
        total_losses = abs(sum(loss_pnls))
        if total_losses > 0:
            profit_factor = total_wins / total_losses
        elif total_wins > 0:
            # Gains and no losses: a perfect record, not a "0" profit factor.
            profit_factor = float('inf')
        else:
            profit_factor = 0

        return {
            'total_trades': len(closes),
            'wins': len(win_pnls),
            'losses': len(loss_pnls),
            'win_rate': len(win_pnls) / len(closes),
            'avg_pnl': float(np.mean(pnls)),
            'avg_win': float(np.mean(win_pnls)) if win_pnls else 0,
            'avg_loss': float(np.mean(loss_pnls)) if loss_pnls else 0,
            'profit_factor': profit_factor,
            # Worst single closed trade, not a cumulative equity drawdown.
            # The key name is kept for backward compatibility.
            'max_drawdown': min(pnls),
            'best_trade': max(pnls),
        }

    def _analyze_signal_quality(self, trades: List[Dict], signals: List[Dict]) -> Dict:
        """Attribute each closed trade to a buy signal and aggregate by pattern.

        A signal counts as a buy signal when ``action == 'BUY'`` or its reason
        contains ``'IA SIGNAL'``. Each closed trade is attributed to the most
        recent buy signal for the same symbol emitted at or before the trade's
        close timestamp. Signals emitted after the trade closed cannot have
        caused it. If the trade has no parsable timestamp, the latest signal for
        the symbol is used.
        """
        buy_signals = [
            s for s in signals
            if s.get('action') == 'BUY' or 'IA SIGNAL' in (s.get('reason') or '')
        ]

        # Index signals by symbol once, keeping file order (oldest -> newest).
        signals_by_symbol = defaultdict(list)
        for sig in buy_signals:
            signals_by_symbol[sig.get('symbol')].append(
                (self._parse_timestamp(sig.get('timestamp')), sig)
            )

        signal_performance = defaultdict(lambda: {'count': 0, 'wins': 0, 'total_pnl': 0})
        closes = self._closed_trades(trades)

        for trade in closes:
            candidates = signals_by_symbol.get(trade.get('symbol'), [])
            trade_time = self._parse_timestamp(trade.get('timestamp'))
            if trade_time is not None:
                candidates = [
                    (sig_time, sig) for sig_time, sig in candidates
                    if sig_time is not None and sig_time <= trade_time
                ]
            if not candidates:
                continue

            reason = candidates[-1][1].get('reason') or 'UNKNOWN'
            pattern = self._classify_signal_reason(reason)

            pnl = self._pnl(trade)
            stats = signal_performance[pattern]
            stats['count'] += 1
            stats['total_pnl'] += pnl
            if pnl > 0:
                stats['wins'] += 1

        pattern_stats = {
            pattern: {
                'count': stats['count'],
                'win_rate': stats['wins'] / stats['count'] if stats['count'] > 0 else 0,
                'avg_pnl': stats['total_pnl'] / stats['count'] if stats['count'] > 0 else 0,
            }
            for pattern, stats in signal_performance.items()
        }

        return {
            'total_signals': len(buy_signals),
            'pattern_performance': pattern_stats,
            'signal_to_trade_ratio': len(closes) / len(buy_signals) if buy_signals else 0,
        }

    def _analyze_cycle_detection(self, trades: List[Dict]) -> Dict:
        """Find closed trades that gave back a large part of their peak profit.

        A trade counts as a "missed cycle top" when ``max_profit_pct > 1.0``
        and it closed more than 0.5 percentage points below that peak.
        ``details`` holds the 10 worst cases, sorted by ``loss_from_peak``
        descending.
        """
        missed_peaks = []

        for trade in self._closed_trades(trades):
            max_profit = trade.get('max_profit_pct')
            max_profit = max_profit if max_profit is not None else 0
            final_pnl = self._pnl(trade)
            loss_from_peak = max_profit - final_pnl

            if max_profit > 1.0 and loss_from_peak > 0.5:
                missed_peaks.append({
                    'symbol': trade.get('symbol'),
                    'max_profit': max_profit,
                    'final_pnl': final_pnl,
                    'loss_from_peak': loss_from_peak,
                    'exit_reason': trade.get('reason'),
                })

        worst_first = sorted(missed_peaks, key=lambda m: m['loss_from_peak'], reverse=True)

        return {
            'total_missed_peaks': len(missed_peaks),
            'avg_loss_from_peak': (
                float(np.mean([m['loss_from_peak'] for m in missed_peaks])) if missed_peaks else 0
            ),
            'details': worst_first[:10],
        }

    def _analyze_exit_timing(self, trades: List[Dict]) -> Dict:
        """Aggregate closed trades by exit ``reason``.

        For each reason, reports the count, its share of all closes (percent),
        the win rate and the average P&L. A missing or ``None`` reason is
        grouped under ``'unknown'``.
        """
        closes = self._closed_trades(trades)
        exit_stats = defaultdict(lambda: {'count': 0, 'wins': 0, 'total_pnl': 0})

        for trade in closes:
            reason = trade.get('reason')
            if reason is None:
                reason = 'unknown'
            pnl = self._pnl(trade)

            stats = exit_stats[reason]
            stats['count'] += 1
            stats['total_pnl'] += pnl
            if pnl > 0:
                stats['wins'] += 1

        return {
            reason: {
                'count': stats['count'],
                'percentage': stats['count'] / len(closes) * 100 if closes else 0,
                'win_rate': stats['wins'] / stats['count'] if stats['count'] > 0 else 0,
                'avg_pnl': stats['total_pnl'] / stats['count'] if stats['count'] > 0 else 0,
            }
            for reason, stats in exit_stats.items()
        }

    def _calculate_optimal_thresholds(self, trades: List[Dict]) -> Dict:
        """Pair opens with closes by ``trade_id`` and analyse entry indicators.

        Records without a ``trade_id`` are never paired. Previously, every
        id-less close matched the first id-less open. When several opens share
        an id, the first one is used.

        Returns:
            ``{'bb_position': ..., 'rsi': ..., 'momentum': ...}``, or
            ``{'error': 'Pas assez de données'}`` when no pair exists.
        """
        opens_by_id = {}
        for t in trades:
            if t.get('type') == 'TRADE_OPEN' and t.get('trade_id') is not None:
                opens_by_id.setdefault(t['trade_id'], t)

        trade_pairs = [
            {'open': opens_by_id[c['trade_id']], 'close': c}
            for c in self._closed_trades(trades)
            if c.get('trade_id') is not None and c['trade_id'] in opens_by_id
        ]

        if not trade_pairs:
            return {'error': 'Pas assez de données'}

        return {
            'bb_position': self._analyze_bb_position_optimal(trade_pairs),
            'rsi': self._analyze_rsi_optimal(trade_pairs),
            'momentum': self._analyze_momentum_optimal(trade_pairs),
        }

    def _split_feature_by_outcome(self, trade_pairs: List[Dict], feature: str) -> Tuple[List, List]:
        """Split the open-time values of *feature* into winning and losing trades.

        Pairs whose open record lacks the feature (or has ``None``) are ignored.
        A trade with ``pnl_pct <= 0`` counts as a loss.
        """
        wins, losses = [], []
        for pair in trade_pairs:
            value = pair['open'].get(feature)
            if value is None:
                continue
            (wins if self._pnl(pair['close']) > 0 else losses).append(value)
        return wins, losses

    @staticmethod
    def _describe_usage(wins: List, losses: List) -> Dict:
        """Return the mean and median of the winning and losing values (0 when empty)."""
        return {
            'avg_wins': float(np.mean(wins)) if wins else 0,
            'avg_losses': float(np.mean(losses)) if losses else 0,
            'median_wins': float(np.median(wins)) if wins else 0,
            'median_losses': float(np.median(losses)) if losses else 0,
        }

    def _analyze_bb_position_optimal(self, trade_pairs: List[Dict]) -> Dict:
        """Analyse Bollinger-band position at entry for winners vs losers.

        The recommended ``max_threshold`` is the 75th percentile of winning
        entries once more than 10 winners exist; otherwise it is 0.65.
        """
        wins, losses = self._split_feature_by_outcome(trade_pairs, 'bb_position')
        current_std = self.financial_standards['bb_position']

        return {
            'current_usage': self._describe_usage(wins, losses),
            'recommended': {
                'buy_zone': current_std['buy_zone'],
                'max_threshold': float(np.percentile(wins, 75)) if len(wins) > 10 else 0.65,
            },
            'expert_standards': current_std,
        }

    def _analyze_rsi_optimal(self, trade_pairs: List[Dict]) -> Dict:
        """Analyse RSI at entry for winners vs losers.

        The recommended ``max_threshold`` is the 75th percentile of winning
        entries once more than 10 winners exist; otherwise it is 60.
        """
        wins, losses = self._split_feature_by_outcome(trade_pairs, 'rsi')
        current_std = self.financial_standards['rsi']

        return {
            'current_usage': self._describe_usage(wins, losses),
            'recommended': {
                'buy_zone': current_std['buy_zone'],
                'max_threshold': float(np.percentile(wins, 75)) if len(wins) > 10 else 60,
            },
            'expert_standards': current_std,
        }

    def _analyze_momentum_optimal(self, trade_pairs: List[Dict]) -> Dict:
        """Analyse ``momentum_3`` at entry for winners vs losers.

        The recommended ``minimum_threshold`` is the 25th percentile of
        winning entries once more than 10 winners exist; otherwise it is
        0.0025.
        """
        wins, losses = self._split_feature_by_outcome(trade_pairs, 'momentum_3')
        current_std = self.financial_standards['momentum']

        return {
            'current_usage': self._describe_usage(wins, losses),
            'recommended': {
                'strong_bullish': current_std['strong_bullish'],
                'minimum_threshold': float(np.percentile(wins, 25)) if len(wins) > 10 else 0.0025,
            },
            'expert_standards': current_std,
        }

    # ------------------------------------------------------------------ #
    # Recommendations & report
    # ------------------------------------------------------------------ #

    def _generate_recommendations(self, results: Dict) -> List[Dict]:
        """Turn the analysis results into prioritised recommendations.

        Each recommendation is a dict with ``priority`` (HIGH/MEDIUM/LOW),
        ``category``, ``issue``, ``impact`` and ``actions``. The list is sorted
        by priority (HIGH first; the sort is stable within a priority).
        """
        recommendations = []
        standards = self.financial_standards
        metrics = results['global_metrics']
        # Rate-based checks are meaningless without closed trades: an empty
        # window has win_rate 0 and profit_factor 0 by construction.
        has_closed_trades = metrics.get('total_trades', 0) > 0

        # 1: low win rate
        win_rate_target = standards['win_rate_target']
        if has_closed_trades and metrics['win_rate'] < win_rate_target:
            gap = (win_rate_target - metrics['win_rate']) * 100
            recommendations.append({
                'priority': 'HIGH',
                'category': 'WIN_RATE',
                'issue': f"Win rate {metrics['win_rate']:.1%} < objectif {win_rate_target:.0%}",
                'impact': f"Gap de {gap:.1f} points",
                'actions': [
                    "Augmenter score minimum IA de 65 à 70",
                    "Renforcer filtres bb_position (0.65 → 0.60)",
                    "Exiger momentum > 0.25% pour tous les achats",
                ],
            })

        # 2: too many quick exits
        exit_timing = results.get('exit_timing', {})
        quick_exit_rate = exit_timing.get('quick-exit', {}).get('percentage', 0)
        if quick_exit_rate > 50:
            recommendations.append({
                'priority': 'HIGH',
                'category': 'QUICK_EXITS',
                'issue': f"Quick-exits: {quick_exit_rate:.0f}% des trades (objectif < 40%)",
                'impact': "Achats mal timés, sorties prématurées",
                'actions': [
                    "Implémenter score minimum progressif selon bb_position",
                    "Augmenter seuils quick-exit de 8 à 10 points",
                    "Ajouter délai minimum 15min avant quick-exit",
                ],
            })

        # 3: missed cycle tops
        cycle_data = results.get('cycle_detection', {})
        if cycle_data.get('total_missed_peaks', 0) > 3:
            avg_loss = cycle_data.get('avg_loss_from_peak', 0)
            recommendations.append({
                'priority': 'MEDIUM',
                'category': 'CYCLE_DETECTION',
                'issue': f"{cycle_data['total_missed_peaks']} fins de cycle manquées",
                'impact': f"Perte moyenne de {avg_loss:.2f}% depuis le pic",
                'actions': [
                    "Activer trailing stop dynamique",
                    "Ajouter détection fin de cycle (RSI > 70 + momentum < 0)",
                    "Implémenter take-profit partiel à +2%",
                ],
            })

        # 4: low profit factor
        profit_factor_target = standards['profit_factor_target']
        if has_closed_trades and metrics['profit_factor'] < profit_factor_target:
            recommendations.append({
                'priority': 'MEDIUM',
                'category': 'PROFIT_FACTOR',
                'issue': f"Profit factor {metrics['profit_factor']:.2f} < objectif {profit_factor_target:.1f}",
                'impact': "Ratio gains/pertes insuffisant",
                'actions': [
                    "Augmenter take-profit de 3% à 3.5%",
                    "Réduire stop-loss de 2% à 1.8%",
                    "Filtrer cryptos avec win rate < 40%",
                ],
            })

        # 5: losing entries bought high in the Bollinger bands.
        # When optimal_thresholds is the error dict, bb_data is {} and this is skipped.
        bb_data = results.get('optimal_thresholds', {}).get('bb_position', {})
        if bb_data:
            avg_losses_bb = bb_data.get('current_usage', {}).get('avg_losses', 0)
            if avg_losses_bb > 0.55:
                max_threshold = bb_data.get('recommended', {}).get('max_threshold', 0.65)
                recommendations.append({
                    'priority': 'HIGH',
                    'category': 'BB_POSITION',
                    'issue': f"Achats en perte: bb_position moyen {avg_losses_bb:.2f} (zone médiane-haute)",
                    'impact': "Achats trop hauts dans les bandes de Bollinger",
                    'actions': [
                        f"Réduire seuil bb_position max de 0.65 à {max_threshold:.2f}",
                        "Appliquer score minimum 75 si bb > 0.55",
                        "Bloquer achats si bb > 0.60 sans momentum > 0.30%",
                    ],
                })

        recommendations.sort(key=lambda rec: self._PRIORITY_ORDER.get(rec['priority'], 3))
        return recommendations

    def generate_report(self, results: Optional[Dict] = None) -> str:
        """Render a human-readable text report.

        Args:
            results: Output of :meth:`analyze_performance`. When omitted, the
                analysis is run now. Pass it in to render the same data you
                also persist, without reading the logs twice.

        Returns:
            The multi-line report as a single string.
        """
        if results is None:
            results = self.analyze_performance()

        report = []
        report.append("=" * 80)
        report.append("🧠 AI SELF-OPTIMIZER - RAPPORT D'ANALYSE")
        report.append("=" * 80)
        report.append(f"\n📅 Période: {self.lookback_hours}h")
        report.append(f"🕐 Généré: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")

        # Global metrics
        metrics = results['global_metrics']
        report.append("\n📊 MÉTRIQUES GLOBALES:")
        report.append(f"   Trades fermés: {metrics['total_trades']}")
        report.append(f"   Win Rate: {metrics['win_rate']:.1%} ({metrics['wins']}W / {metrics['losses']}L)")
        report.append(f"   P&L Moyen: {metrics['avg_pnl']:.2f}%")
        report.append(f"   Profit Factor: {metrics['profit_factor']:.2f}")
        report.append(f"   Meilleur trade: {metrics['best_trade']:.2f}%")
        report.append(f"   Pire trade: {metrics['max_drawdown']:.2f}%")

        # Signal quality
        signal_quality = results['signal_quality']
        report.append("\n🎯 QUALITÉ DES SIGNAUX:")
        report.append(f"   Signaux générés: {signal_quality['total_signals']}")
        report.append(f"   Ratio signal→trade: {signal_quality['signal_to_trade_ratio']:.1%}")
        report.append("\n   Performance par pattern:")
        for pattern, stats in signal_quality['pattern_performance'].items():
            report.append(f"      • {pattern}: {stats['count']} trades, WR={stats['win_rate']:.0%}, Avg={stats['avg_pnl']:.2f}%")

        # Missed cycle tops (only shown when there are any)
        cycle = results['cycle_detection']
        if cycle['total_missed_peaks'] > 0:
            report.append("\n⚠️ FINS DE CYCLE MANQUÉES:")
            report.append(f"   Total: {cycle['total_missed_peaks']}")
            report.append(f"   Perte moyenne depuis pic: {cycle['avg_loss_from_peak']:.2f}%")

        # Recommendations
        report.append(f"\n💡 RECOMMANDATIONS ({len(results['recommendations'])}):")
        for i, rec in enumerate(results['recommendations'], 1):
            report.append(f"\n   {i}. [{rec['priority']}] {rec['category']}")
            report.append(f"      Issue: {rec['issue']}")
            report.append(f"      Impact: {rec['impact']}")
            report.append("      Actions:")
            for action in rec['actions']:
                report.append(f"         - {action}")

        report.append("\n" + "=" * 80)

        return "\n".join(report)


if __name__ == "__main__":
    # Manual smoke run: print the text report, then persist the raw analysis
    # as JSON. The analysis is re-run for the JSON, so it may include log
    # lines written after the report was printed.
    optimizer = AISelfOptimizer(lookback_hours=24)

    print("\n🧠 Test AI Self-Optimizer...\n")

    print(optimizer.generate_report())

    results = optimizer.analyze_performance()
    output_path = Path("ai_self_optimizer_results.json")
    with output_path.open('w', encoding='utf-8') as handle:
        json.dump(results, handle, indent=2, ensure_ascii=False)

    print(f"\n✅ Résultats sauvegardés: {output_path}")
