"""
SPY Optimizer — Feature Engineering
Calcule les features techniques à partir des klines 1m pour chaque moment de trade.
Utilisé à la fois pour l'entraînement (trades historiques) et l'inférence (signal en temps réel).
"""
import numpy as np
import pandas as pd
from typing import Optional


def ema(series: pd.Series, period: int) -> pd.Series:
    """Return the exponential moving average of *series* (span = period, recursive form)."""
    smoothed = series.ewm(span=period, adjust=False)
    return smoothed.mean()


def rsi(series: pd.Series, period: int = 14) -> pd.Series:
    """
    Relative Strength Index with EMA (span=period) smoothing of gains/losses.

    Fix: a window with zero average loss previously yielded NaN (the
    divide-by-zero was mapped to NaN); by definition RSI is 100 in that
    case, and neutral 50 when there are neither gains nor losses.

    NOTE(review): span-based smoothing (alpha = 2/(period+1)) differs from
    Wilder's classic alpha = 1/period — kept as-is for consistency with
    any model already trained on these features.
    """
    delta = series.diff()
    gain = delta.clip(lower=0)
    loss = (-delta.clip(upper=0))
    avg_gain = gain.ewm(span=period, adjust=False).mean()
    avg_loss = loss.ewm(span=period, adjust=False).mean()
    rs = avg_gain / avg_loss.replace(0, np.nan)
    out = 100 - (100 / (1 + rs))
    # Zero average loss: RSI is 100 on a pure uptrend, 50 on a flat window.
    zero_loss = avg_loss == 0
    out = out.mask(zero_loss & (avg_gain > 0), 100.0)
    out = out.mask(zero_loss & (avg_gain == 0), 50.0)
    return out


def bollinger_bands(series: pd.Series, period: int = 20, std: float = 2.0):
    """Return (middle, upper, lower) Bollinger bands over a rolling window."""
    roll = series.rolling(period)
    mid = roll.mean()
    band = std * roll.std()
    return mid, mid + band, mid - band


def compute_features_at_timestamp(
    df: pd.DataFrame,
    timestamp_ms: int,
    lookback_minutes: int = 120,
) -> Optional[dict]:
    """
    Compute ~50 technical features from 1m klines at one precise
    moment (the trade's entry timestamp).

    Args:
        df: klines DataFrame with standard columns (open_time, open, high, low, close, volume, ...)
        timestamp_ms: analysis moment, as a millisecond timestamp
        lookback_minutes: number of minutes of data before the timestamp

    Returns:
        dict of features, or None if there is not enough data
    """
    # Data window: `lookback_minutes` minutes up to the timestamp.
    # NOTE(review): the mask is inclusive (<= timestamp_ms), so the candle that
    # OPENS at timestamp_ms is included — its close/high/low extend up to one
    # minute past the entry. For offline training this is potential lookahead
    # bias; confirm this matches what the live signal sees.
    start_ms = timestamp_ms - (lookback_minutes * 60 * 1000)
    mask = (df["open_time"] >= start_ms) & (df["open_time"] <= timestamp_ms)
    window = df.loc[mask].copy()

    if len(window) < 30:  # minimum 30 minutes of data
        return None

    close = window["close"]
    high = window["high"]
    low = window["low"]
    # NOTE(review): base-asset `volume` is unpacked but never used below —
    # all volume features are built on quote_volume. Dead local?
    volume = window["volume"]
    quote_vol = window["quote_volume"]
    num_trades = window["num_trades"]
    taker_buy_qvol = window["taker_buy_quote_vol"]

    features = {}

    # ─── Price & Returns ───
    last_price = close.iloc[-1]
    features["price"] = last_price
    features["return_1m"] = _pct(close, 1)
    features["return_3m"] = _pct(close, 3)
    features["return_5m"] = _pct(close, 5)
    features["return_10m"] = _pct(close, 10)
    features["return_15m"] = _pct(close, 15)
    features["return_30m"] = _pct(close, 30)
    features["return_60m"] = _pct(close, 60)

    # Volatility of the 1m returns (std of pct changes, scaled to percent)
    returns = close.pct_change().dropna()
    features["volatility_5m"] = returns.tail(5).std() * 100 if len(returns) >= 5 else 0
    features["volatility_15m"] = returns.tail(15).std() * 100 if len(returns) >= 15 else 0
    features["volatility_30m"] = returns.tail(30).std() * 100 if len(returns) >= 30 else 0

    # ─── EMAs ───
    ema7 = ema(close, 7)
    ema21 = ema(close, 21)
    ema50 = ema(close, 50)

    features["ema7"] = ema7.iloc[-1]
    features["ema21"] = ema21.iloc[-1]
    features["price_vs_ema7"] = (last_price / ema7.iloc[-1] - 1) * 100 if ema7.iloc[-1] > 0 else 0
    features["price_vs_ema21"] = (last_price / ema21.iloc[-1] - 1) * 100 if ema21.iloc[-1] > 0 else 0
    features["ema7_vs_ema21"] = (ema7.iloc[-1] / ema21.iloc[-1] - 1) * 100 if ema21.iloc[-1] > 0 else 0

    # EMA slopes (derivatives) — momentum of the EMAs themselves
    features["ema7_slope_3m"] = _slope(ema7, 3)
    features["ema7_slope_5m"] = _slope(ema7, 5)
    features["ema21_slope_5m"] = _slope(ema21, 5)

    # EMA7/EMA21 gap — shrinking gap = imminent crossover
    ema_gap = ((ema7 - ema21) / ema21 * 100).dropna()
    features["ema_gap_current"] = ema_gap.iloc[-1] if len(ema_gap) > 0 else 0
    features["ema_gap_3m_ago"] = ema_gap.iloc[-4] if len(ema_gap) >= 4 else features["ema_gap_current"]
    features["ema_converging"] = 1 if abs(features["ema_gap_current"]) < abs(features["ema_gap_3m_ago"]) else 0

    # ─── RSI ───
    rsi_vals = rsi(close, 14)
    # NaN guard: falls back to a neutral 50 when RSI is undefined
    # (e.g. a zero-loss window under the current rsi() implementation).
    features["rsi_14"] = rsi_vals.iloc[-1] if not np.isnan(rsi_vals.iloc[-1]) else 50
    features["rsi_3m_ago"] = rsi_vals.iloc[-4] if len(rsi_vals) >= 4 else features["rsi_14"]
    features["rsi_delta_3m"] = features["rsi_14"] - features["rsi_3m_ago"]
    features["rsi_oversold"] = 1 if features["rsi_14"] < 30 else 0
    features["rsi_overbought"] = 1 if features["rsi_14"] > 70 else 0

    # ─── Bollinger Bands ───
    bb_mid, bb_upper, bb_lower = bollinger_bands(close, 20)
    bb_width = ((bb_upper - bb_lower) / bb_mid * 100).dropna()
    features["bb_width"] = bb_width.iloc[-1] if len(bb_width) > 0 else 0
    # Position of price inside the band: 0 = lower band, 1 = upper band.
    features["bb_position"] = ((last_price - bb_lower.iloc[-1]) / (bb_upper.iloc[-1] - bb_lower.iloc[-1])
                               if (bb_upper.iloc[-1] - bb_lower.iloc[-1]) > 0 else 0.5)
    # Squeeze: current width in the bottom quintile of the last 30 widths.
    features["bb_squeeze"] = 1 if features["bb_width"] < bb_width.tail(30).quantile(0.2) else 0

    # ─── Volume ───
    avg_vol_30m = quote_vol.tail(30).mean()
    avg_vol_5m = quote_vol.tail(5).mean()
    features["volume_ratio_1m"] = quote_vol.iloc[-1] / avg_vol_30m if avg_vol_30m > 0 else 1
    features["volume_ratio_3m"] = avg_vol_5m / avg_vol_30m if avg_vol_30m > 0 else 1
    features["volume_spike"] = 1 if features["volume_ratio_1m"] > 3.0 else 0
    features["volume_trend_5m"] = _slope(quote_vol, 5)

    # Buy pressure (taker buy / total quote volume, in percent; 50 = neutral fallback)
    recent_qvol = quote_vol.tail(5).sum()
    recent_buy_qvol = taker_buy_qvol.tail(5).sum()
    features["buy_pressure_5m"] = recent_buy_qvol / recent_qvol * 100 if recent_qvol > 0 else 50
    features["buy_pressure_1m"] = (taker_buy_qvol.iloc[-1] / quote_vol.iloc[-1] * 100
                                   if quote_vol.iloc[-1] > 0 else 50)

    # ─── Candle Patterns ───
    body = close - window["open"]
    candle_range = high - low
    # Fraction of green (close > open) candles over the last 5 / 3 minutes.
    features["green_candles_5m"] = (body.tail(5) > 0).sum() / 5
    features["green_candles_3m"] = (body.tail(3) > 0).sum() / 3
    features["avg_body_ratio_5m"] = (abs(body.tail(5)) / candle_range.tail(5).replace(0, np.nan)).mean()

    # Wick analysis — a large upper wick suggests price rejection
    upper_wick = high - pd.concat([close, window["open"]], axis=1).max(axis=1)
    features["upper_wick_ratio"] = (upper_wick.iloc[-1] / candle_range.iloc[-1]
                                    if candle_range.iloc[-1] > 0 else 0)

    # ─── Momentum ───
    # NOTE(review): momentum_* duplicate return_* for the same lookbacks —
    # confirm the redundancy is intentional (both are in FEATURE_COLUMNS).
    features["momentum_3m"] = _pct(close, 3)
    features["momentum_5m"] = _pct(close, 5)
    features["momentum_10m"] = _pct(close, 10)

    # Monotonic check — has price risen continuously over the last 5 closes?
    recent_5 = close.tail(5).values
    features["monotonic_up_5m"] = 1 if all(recent_5[i] <= recent_5[i+1] for i in range(len(recent_5)-1)) else 0

    # ─── Price Position ───
    # 0 = at the period low, 1 = at the period high.
    high_30m = high.tail(30).max()
    low_30m = low.tail(30).min()
    features["price_position_30m"] = ((last_price - low_30m) / (high_30m - low_30m)
                                      if (high_30m - low_30m) > 0 else 0.5)
    high_60m = high.tail(60).max() if len(high) >= 60 else high.max()
    low_60m = low.tail(60).min() if len(low) >= 60 else low.min()
    features["price_position_60m"] = ((last_price - low_60m) / (high_60m - low_60m)
                                      if (high_60m - low_60m) > 0 else 0.5)

    # Breakout — price vs the prior 15-minute high (excluding the current candle)
    prev_high_15m = high.iloc[-16:-1].max() if len(high) >= 16 else high.max()
    features["breakout_pct"] = (last_price / prev_high_15m - 1) * 100 if prev_high_15m > 0 else 0

    # ─── Trade Activity ───
    features["avg_trades_per_min_5m"] = num_trades.tail(5).mean()
    features["trade_intensity_ratio"] = (num_trades.iloc[-1] / num_trades.tail(30).mean()
                                         if num_trades.tail(30).mean() > 0 else 1)

    # ─── Temporal ───
    # UTC hour — trading patterns change across sessions
    # (session windows deliberately overlap: Europe/US share 13-16h UTC).
    hour = pd.Timestamp(timestamp_ms, unit="ms", tz="UTC").hour
    features["hour_utc"] = hour
    features["is_asia_session"] = 1 if 0 <= hour < 8 else 0
    features["is_europe_session"] = 1 if 7 <= hour < 16 else 0
    features["is_us_session"] = 1 if 13 <= hour < 22 else 0

    return features


def _pct(series: pd.Series, lookback: int) -> float:
    """Pourcentage de variation sur les N dernières bougies."""
    if len(series) <= lookback:
        return 0.0
    old = series.iloc[-(lookback + 1)]
    new = series.iloc[-1]
    return (new / old - 1) * 100 if old > 0 else 0.0


def _slope(series: pd.Series, lookback: int) -> float:
    """Slope normalisée sur les N dernières valeurs."""
    if len(series) <= lookback:
        return 0.0
    vals = series.tail(lookback + 1).values
    mid = vals[len(vals) // 2]
    if mid == 0:
        return 0.0
    return (vals[-1] - vals[0]) / mid * 100


def build_dataset_from_trades(
    trades: list[dict],
    klines_dir: str,
    lookback_minutes: int = 120,
) -> pd.DataFrame:
    """
    Build the training dataset by recomputing the features from the 1m
    klines for each historical trade.

    Args:
        trades: trade records; each needs at least "symbol" and "entry_time".
        klines_dir: directory holding one `<SYMBOL>.parquet` klines file per symbol.
        lookback_minutes: minutes of klines history used per trade.

    Returns:
        DataFrame with one row per usable trade: features + metadata
        + target columns (target_profitable, target_pnl_pct, ...).
    """
    from pathlib import Path

    def _num(value) -> float:
        # Defensive coercion: trade dicts may carry None or non-numeric
        # values, which would crash a bare float() call (dict.get defaults
        # only apply to MISSING keys, not explicit None).
        try:
            return float(value)
        except (TypeError, ValueError):
            return 0.0

    base_dir = Path(klines_dir)
    rows = []
    skipped = 0
    loaded_klines = {}  # per-symbol cache of parquet DataFrames

    for i, trade in enumerate(trades):
        symbol = trade.get("symbol", "")
        entry_time = trade.get("entry_time", "")
        if not symbol or not entry_time:
            skipped += 1
            continue

        # Load the klines once per symbol (cached).
        if symbol not in loaded_klines:
            parquet_path = base_dir / f"{symbol}.parquet"
            if not parquet_path.exists():
                # Fall back to the USDT↔USDC twin of the pair.
                alt_symbol = symbol.replace("USDC", "USDT") if symbol.endswith("USDC") else symbol.replace("USDT", "USDC")
                alt_path = base_dir / f"{alt_symbol}.parquet"
                if alt_path.exists():
                    parquet_path = alt_path
                else:
                    skipped += 1
                    continue
            loaded_klines[symbol] = pd.read_parquet(parquet_path)

        df = loaded_klines[symbol]

        # Convert entry_time to a millisecond timestamp.
        # NOTE(review): a naive entry_time string is interpreted as UTC by
        # pd.Timestamp.timestamp() — confirm trade timestamps are stored in UTC.
        try:
            ts_ms = int(pd.Timestamp(entry_time).timestamp() * 1000)
        except Exception:
            skipped += 1
            continue

        # Compute the features at the trade's entry moment.
        features = compute_features_at_timestamp(df, ts_ms, lookback_minutes)
        if features is None:
            skipped += 1
            continue

        # Trade metadata.
        features["symbol"] = symbol
        features["surge_type"] = trade.get("surge_type", "UNKNOWN")
        features["surge_strength"] = _num(trade.get("surge_strength", 0))
        features["entry_time"] = entry_time

        # Targets.
        pnl_usdt = _num(trade.get("pnl_usdt", 0))
        features["target_profitable"] = 1 if pnl_usdt > 0 else 0
        features["target_pnl_pct"] = _num(trade.get("pnl_pct", 0))
        features["target_max_pnl"] = _num(trade.get("max_pnl", 0))
        features["target_exit_reason"] = trade.get("exit_reason", "")

        rows.append(features)

        if (i + 1) % 100 == 0:
            print(f"  Processed {i + 1}/{len(trades)} trades ({skipped} skipped)")

    print(f"  Done: {len(rows)} trades avec features, {skipped} skipped")
    return pd.DataFrame(rows)


# ─── Feature columns used by the model (excludes metadata & targets) ───
# NOTE(review): column ORDER presumably defines the model's input layout —
# do not reorder without retraining. Computed features deliberately absent:
# "price", "ema7", "ema21" (absolute levels, not comparable across symbols),
# and "ema_gap_3m_ago" / "rsi_3m_ago" (redundant with the delta features) —
# confirm these exclusions are intentional.
FEATURE_COLUMNS = [
    # Price returns
    "return_1m", "return_3m", "return_5m", "return_10m", "return_15m", "return_30m", "return_60m",
    # Volatility
    "volatility_5m", "volatility_15m", "volatility_30m",
    # EMAs
    "price_vs_ema7", "price_vs_ema21", "ema7_vs_ema21",
    "ema7_slope_3m", "ema7_slope_5m", "ema21_slope_5m",
    "ema_gap_current", "ema_converging",
    # RSI
    "rsi_14", "rsi_delta_3m", "rsi_oversold", "rsi_overbought",
    # Bollinger
    "bb_width", "bb_position", "bb_squeeze",
    # Volume
    "volume_ratio_1m", "volume_ratio_3m", "volume_spike", "volume_trend_5m",
    "buy_pressure_5m", "buy_pressure_1m",
    # Candle patterns
    "green_candles_5m", "green_candles_3m", "avg_body_ratio_5m", "upper_wick_ratio",
    # Momentum
    "momentum_3m", "momentum_5m", "momentum_10m", "monotonic_up_5m",
    # Price position
    "price_position_30m", "price_position_60m", "breakout_pct",
    # Trade activity
    "avg_trades_per_min_5m", "trade_intensity_ratio",
    # Temporal
    "hour_utc", "is_asia_session", "is_europe_session", "is_us_session",
    # Surge info (from spy detection)
    "surge_strength",
]

# Surge type encoded separately (categorical; values come from the
# trades' "surge_type" field — confirm this list covers every emitter).
SURGE_TYPES = ["FLASH_SURGE", "BREAKOUT_SURGE", "MOMENTUM_SURGE", "TREND_MOMENTUM_SURGE"]
