#!/usr/bin/env python3
"""
SPY Optimizer — Binance 1m Klines Downloader
Télécharge toutes les bougies 1 minute des paires USDC éligibles depuis le 01/01/2026.

Features:
  - Reprise automatique (state file persisté par symbole)
  - Rate limiting respectueux de l'API Binance
  - Filtrage volume 24h pour ne garder que les paires pertinentes
  - Stockage parquet compressé par symbole
  - Progression détaillée en temps réel

Usage:
  python downloader.py                  # Télécharge tout
  python downloader.py --symbols BTCUSDC ETHUSDC  # Paires spécifiques
  python downloader.py --resume         # Reprend uniquement les paires incomplètes
  python downloader.py --check          # Vérifie l'état sans télécharger
"""
import argparse
import json
import os
import sys
import time
from datetime import datetime, timezone
from pathlib import Path

import pandas as pd
import requests

from config import (
    BASE_URL, DATA_DIR, DATA_START, DATA_END, DOWNLOAD_STATE_FILE,
    EXCHANGE_INFO_ENDPOINT, KLINE_COLUMNS, KLINE_INTERVAL, KLINE_LIMIT,
    KLINES_DIR, KLINES_ENDPOINT, MAX_PRICE, MAX_REQUESTS_PER_MINUTE,
    MIN_PRICE, MIN_VOLUME_USDT, TRAINING_QUOTES, REQUEST_DELAY, STABLECOINS,
    TICKER_24H_ENDPOINT,
)

# ─── Terminal colors ───
# ANSI escape codes used by log() and check_status() for colored output.
GREEN = "\033[92m"
YELLOW = "\033[93m"
RED = "\033[91m"
CYAN = "\033[96m"
RESET = "\033[0m"   # clears any active color/style
BOLD = "\033[1m"


def log(msg: str, color: str = ""):
    """Print *msg* to stdout with a HH:MM:SS timestamp, optionally colored.

    Args:
        msg: Text to display.
        color: ANSI escape prefix (e.g. GREEN). When given, the reset code is
            appended automatically; when empty, no escape codes are emitted.
    """
    ts = datetime.now().strftime("%H:%M:%S")
    # `color` is already a plain string ("" when unset), so no f-string
    # wrapping is needed; only the reset suffix is conditional.
    suffix = RESET if color else ""
    print(f"[{ts}] {color}{msg}{suffix}", flush=True)


# ─── State Management ───

def load_state() -> dict:
    """Return the persisted per-symbol download state ({} when no file exists)."""
    if not os.path.exists(DOWNLOAD_STATE_FILE):
        return {}
    with open(DOWNLOAD_STATE_FILE, "r") as fh:
        return json.load(fh)


def save_state(state: dict):
    """Atomically persist *state* as pretty-printed JSON.

    Writes to a sibling ``.tmp`` file first, then renames it over the real
    state file, so a crash mid-write never corrupts the previous state.
    """
    os.makedirs(DATA_DIR, exist_ok=True)
    tmp_path = f"{DOWNLOAD_STATE_FILE}.tmp"
    payload = json.dumps(state, indent=2)
    with open(tmp_path, "w") as fh:
        fh.write(payload)
    os.replace(tmp_path, DOWNLOAD_STATE_FILE)


# ─── Binance API helpers ───

# Shared HTTP session (connection pooling) tagged with a custom User-Agent.
_session = requests.Session()
_session.headers.update({"User-Agent": "SpyOptimizer/1.0"})
# Rate-limiter bookkeeping, mutated only by _rate_limit():
_last_request_time = 0.0  # wall-clock time of the most recent request
_request_count = 0        # requests issued in the current 60s window
_window_start = 0.0       # wall-clock start of the current 60s window


def _rate_limit():
    """Throttle requests to respect Binance's rate limit (1200 req/min) with margin.

    Mutates the module-level counters (_last_request_time, _request_count,
    _window_start) and sleeps as needed before returning; call once before
    every API request.
    """
    global _last_request_time, _request_count, _window_start
    now = time.time()

    # Reset window every 60s
    if now - _window_start > 60:
        _request_count = 0
        _window_start = now

    # Hard limit: stop 50 requests short of the configured max and wait out
    # the remainder of the current 60s window (plus 1s of slack).
    if _request_count >= MAX_REQUESTS_PER_MINUTE - 50:  # 50-request safety margin
        sleep_time = 60 - (now - _window_start) + 1
        if sleep_time > 0:
            log(f"⏳ Rate limit approché ({_request_count} req) — pause {sleep_time:.0f}s", YELLOW)
            time.sleep(sleep_time)
            _request_count = 0
            _window_start = time.time()

    # Min delay between requests.
    # NOTE(review): `now` is not refreshed after the hard-limit sleep above, so
    # this elapsed value can be stale and force one extra REQUEST_DELAY sleep —
    # harmless, but worth confirming if timing gets tightened.
    elapsed = now - _last_request_time
    if elapsed < REQUEST_DELAY:
        time.sleep(REQUEST_DELAY - elapsed)

    _last_request_time = time.time()
    _request_count += 1


def api_get(endpoint: str, params: dict, retries: int = 3) -> list | dict:
    """GET *endpoint* with rate limiting, throttle handling and retries.

    Binance throttle responses (HTTP 429/418) trigger a pause and consume an
    attempt; network errors back off exponentially and the last one is
    re-raised. Returns [] when every attempt was consumed by throttling.
    """
    url = BASE_URL + endpoint
    attempt = 0
    while attempt < retries:
        _rate_limit()
        try:
            resp = _session.get(url, params=params, timeout=15)
            if resp.status_code not in (429, 418):
                resp.raise_for_status()
                return resp.json()
            # Throttled: honor Retry-After for 429, fixed 120s for an IP ban.
            if resp.status_code == 429:
                pause = int(resp.headers.get("Retry-After", 30))
                log(f"⚠️ HTTP 429 — pause {pause}s", RED)
            else:
                pause = 120
                log("🚫 IP ban (418) — pause 120s", RED)
            time.sleep(pause)
        except requests.exceptions.RequestException as err:
            if attempt >= retries - 1:
                raise
            wait = 2 ** (attempt + 1)
            log(f"⚠️ Erreur réseau ({err}) — retry dans {wait}s", YELLOW)
            time.sleep(wait)
        attempt += 1
    return []


# ─── Pair Discovery ───

def get_traded_symbols() -> set[str]:
    """Return every symbol the SPY bot has historically traded.

    Reads espion_history.json from the parent directory of this script's
    package; returns an empty set (with a warning) when the file is missing.
    """
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    history_path = os.path.join(project_root, "espion_history.json")
    if not os.path.exists(history_path):
        log(f"⚠️ {history_path} introuvable — pas d'inclusion historique", YELLOW)
        return set()
    with open(history_path, "r") as fh:
        trades = json.load(fh)
    # Keep only entries with a non-empty "symbol" field.
    return {t["symbol"] for t in trades if t.get("symbol")}


def get_eligible_pairs(force_symbols: list[str] | None = None, include_traded: bool = True) -> list[dict]:
    """Return eligible USDC/USDT pairs from Binance, sorted by 24h volume (desc).

    Args:
        force_symbols: When provided, bypass discovery and filtering entirely
            and return these symbols as-is.
        include_traded: Also include pairs from the SPY trade history that are
            still actively trading, even when they fail the current
            volume/price filters.

    Returns:
        List of dicts with keys "symbol", "volume_24h", "price" (plus
        "source" == "traded_history" for history-only additions).
    """
    if force_symbols:
        log(f"📋 Mode forcé: {len(force_symbols)} symboles spécifiés")
        return [{"symbol": s} for s in force_symbols]

    log("📡 Récupération des paires actives (USDC + USDT)...", CYAN)
    info = api_get(EXCHANGE_INFO_ENDPOINT, {})
    candidate_symbols = []
    for s in info.get("symbols", []):
        if (s.get("quoteAsset") in TRAINING_QUOTES
                and s.get("status") == "TRADING"
                and s["symbol"] not in STABLECOINS):
            candidate_symbols.append(s["symbol"])
    log(f"   {len(candidate_symbols)} paires actives (hors stablecoins)")

    # Filter by 24h quote volume and by price band.
    log("📊 Vérification volumes 24h...", CYAN)
    tickers = api_get(TICKER_24H_ENDPOINT, {})
    ticker_map = {t["symbol"]: t for t in tickers}

    # (Removed dead code: an unused `seen_bases` dict — whose comment claimed a
    # USDC-over-USDT dedup that was never implemented — and an `eligible = []`
    # that was immediately overwritten by a copy of a second list.)
    eligible = []
    for sym in candidate_symbols:
        t = ticker_map.get(sym)
        if not t:
            continue
        vol_24h = float(t.get("quoteVolume", 0))
        price = float(t.get("lastPrice", 0))
        if vol_24h < MIN_VOLUME_USDT:
            continue
        if price > MAX_PRICE or price < MIN_PRICE:
            continue
        eligible.append({
            "symbol": sym,
            "volume_24h": vol_24h,
            "price": price,
        })

    # Both quotes are kept: the SPY traded USDT Jan-Apr, then USDC from Apr on.
    eligible_syms = set(p["symbol"] for p in eligible)

    # Add pairs historically traded by the SPY (even if current volume is below
    # the threshold), provided they are still actively trading on Binance.
    extra_traded = 0
    if include_traded:
        traded = get_traded_symbols()
        active_symbols = set(s["symbol"] for s in info.get("symbols", []) if s.get("status") == "TRADING")
        for sym in traded:
            if sym in eligible_syms or sym not in active_symbols:
                continue
            t = ticker_map.get(sym)
            vol = float(t.get("quoteVolume", 0)) if t else 0
            price = float(t.get("lastPrice", 0)) if t else 0
            eligible.append({"symbol": sym, "volume_24h": vol, "price": price, "source": "traded_history"})
            extra_traded += 1
        if extra_traded:
            log(f"   📜 +{extra_traded} paires ajoutées depuis l'historique de trades SPY", YELLOW)

    eligible.sort(key=lambda x: x["volume_24h"], reverse=True)
    usdc_count = sum(1 for p in eligible if p["symbol"].endswith("USDC"))
    usdt_count = sum(1 for p in eligible if p["symbol"].endswith("USDT"))
    log(f"   ✅ {len(eligible)} paires éligibles ({usdc_count} USDC + {usdt_count} USDT)")
    return eligible


# ─── Klines Download ───

def download_symbol(symbol: str, state: dict) -> int:
    """Download all 1m klines for *symbol*, resuming from the saved state.

    Fetches candles batch by batch from DATA_START (or just after the last
    persisted close time), merges them with any existing parquet data for the
    symbol, and writes the combined result back. *state* is mutated in place
    after every batch so an interrupted run can resume accurately.

    Returns:
        The number of newly downloaded candles.
    """
    sym_state = state.get(symbol, {})
    parquet_path = os.path.join(KLINES_DIR, f"{symbol}.parquet")

    # Starting point: configured start date, or right after the last candle
    # persisted by a previous run.
    start_ms = int(DATA_START.timestamp() * 1000)
    if sym_state.get("last_close_time"):
        start_ms = sym_state["last_close_time"] + 1

    end_ms = int((DATA_END or datetime.now(timezone.utc)).timestamp() * 1000)

    if start_ms >= end_ms:
        return 0  # Already up to date

    # Load previously downloaded candles when resuming.
    existing_df = None
    if os.path.exists(parquet_path) and sym_state.get("last_close_time"):
        try:
            existing_df = pd.read_parquet(parquet_path)
        except Exception:
            existing_df = None  # unreadable/corrupted file: rebuild from new data

    prior_count = sym_state.get("candles_downloaded", 0)
    all_candles = []
    current_start = start_ms
    total_new = 0

    while current_start < end_ms:
        params = {
            "symbol": symbol,
            "interval": KLINE_INTERVAL,
            "startTime": current_start,
            "endTime": end_ms,
            "limit": KLINE_LIMIT,
        }

        data = api_get(KLINES_ENDPOINT, params)
        if not data:
            break

        # Keep the 11 useful kline fields (drop Binance's trailing "ignore").
        for k in data:
            all_candles.append([
                int(k[0]),       # open_time
                float(k[1]),     # open
                float(k[2]),     # high
                float(k[3]),     # low
                float(k[4]),     # close
                float(k[5]),     # volume
                int(k[6]),       # close_time
                float(k[7]),     # quote_volume
                int(k[8]),       # num_trades
                float(k[9]),     # taker_buy_base_vol
                float(k[10]),    # taker_buy_quote_vol
            ])

        total_new += len(data)
        last_close = int(data[-1][6])
        current_start = last_close + 1

        # Persist progress after each batch.
        # BUG FIX: the previous code wrote `sym_state.get(...) + len(data)`,
        # but `sym_state` is a snapshot that is never rebound after
        # `state[symbol] = {...}`, so every batch overwrote the count with
        # prior + last-batch instead of prior + running-total, undercounting
        # progress across batches. Use the accumulated total instead.
        state[symbol] = {
            "last_close_time": last_close,
            "candles_downloaded": prior_count + total_new,
            "last_update": datetime.now(timezone.utc).isoformat(),
        }

        if len(data) < KLINE_LIMIT:
            break  # No more data available

    if not all_candles:
        return 0

    # Build the DataFrame for the newly fetched candles.
    new_df = pd.DataFrame(all_candles, columns=KLINE_COLUMNS)

    # Merge with existing data on resume, deduplicating on open_time.
    if existing_df is not None and len(existing_df) > 0:
        combined = pd.concat([existing_df, new_df], ignore_index=True)
        combined = combined.drop_duplicates(subset=["open_time"]).sort_values("open_time").reset_index(drop=True)
    else:
        combined = new_df

    # Persist to parquet.
    os.makedirs(KLINES_DIR, exist_ok=True)
    combined.to_parquet(parquet_path, index=False, compression="snappy")

    # Final state: exact on-disk candle count and completion status.
    state[symbol]["candles_downloaded"] = len(combined)
    state[symbol]["status"] = "complete" if current_start >= end_ms else "partial"

    return total_new


def check_status(state: dict, eligible: list[dict]):
    """Print a colored summary of download progress for the eligible pairs."""
    complete = partial = not_started = 0
    total_candles = 0

    # Tally per-pair progress from the persisted state.
    for pair in eligible:
        entry = state.get(pair["symbol"], {})
        count = entry.get("candles_downloaded", 0)
        total_candles += count
        if entry.get("status", "not_started") == "complete":
            complete += 1
        elif count > 0:
            partial += 1
        else:
            not_started += 1

    # On-disk footprint of all parquet files.
    data_size = sum(f.stat().st_size for f in Path(KLINES_DIR).glob("*.parquet"))

    print(f"\n{BOLD}═══ SPY Optimizer — Download Status ═══{RESET}")
    print(f"  Paires éligibles : {len(eligible)}")
    print(f"  {GREEN}✅ Complètes     : {complete}{RESET}")
    print(f"  {YELLOW}🔄 Partielles    : {partial}{RESET}")
    print(f"  {RED}⏳ Non démarrées : {not_started}{RESET}")
    print(f"  Candles totales  : {total_candles:,}")
    print(f"  Taille disque    : {data_size / 1024 / 1024:.1f} MB")
    print()


# ─── Main ───

def main():
    """CLI entry point: discover eligible pairs and download their 1m klines.

    Flags:
        --symbols: restrict to specific pairs (skips discovery filters).
        --resume: only process pairs whose saved status is not "complete".
        --check: print download status and exit without downloading.
    """
    parser = argparse.ArgumentParser(description="Télécharge les klines 1m Binance")
    parser.add_argument("--symbols", nargs="+", help="Paires spécifiques à télécharger")
    parser.add_argument("--resume", action="store_true", help="Reprend uniquement les paires incomplètes")
    parser.add_argument("--check", action="store_true", help="Affiche l'état sans télécharger")
    args = parser.parse_args()

    os.makedirs(KLINES_DIR, exist_ok=True)
    state = load_state()

    eligible = get_eligible_pairs(args.symbols)

    if args.check:
        check_status(state, eligible)
        return

    if args.resume:
        eligible = [p for p in eligible if state.get(p["symbol"], {}).get("status") != "complete"]
        log(f"🔄 Mode reprise: {len(eligible)} paires à compléter")

    total_pairs = len(eligible)
    if total_pairs == 0:
        log("✅ Tout est à jour!", GREEN)
        return

    log(f"\n{BOLD}═══ SPY Optimizer — Téléchargement klines 1m ═══{RESET}")
    log(f"   Période: {DATA_START.strftime('%Y-%m-%d')} → maintenant")
    log(f"   Paires : {total_pairs}")
    log(f"   Intervalle: {KLINE_INTERVAL}")
    log("")

    grand_total = 0
    start_time = time.time()

    for i, pair in enumerate(eligible):
        symbol = pair["symbol"]
        # Forced symbols (--symbols) carry no ticker data, hence the fallback.
        vol_str = f"${pair.get('volume_24h', 0):,.0f}" if "volume_24h" in pair else "N/A"

        existing = state.get(symbol, {}).get("candles_downloaded", 0)
        status_str = f" (reprise: {existing:,} existantes)" if existing > 0 else ""

        log(f"[{i + 1}/{total_pairs}] {BOLD}{symbol}{RESET} vol24h={vol_str}{status_str}", CYAN)

        try:
            new_candles = download_symbol(symbol, state)
            grand_total += new_candles

            total = state.get(symbol, {}).get("candles_downloaded", 0)
            elapsed = time.time() - start_time
            rate = grand_total / elapsed if elapsed > 0 else 0

            if new_candles > 0:
                log(f"   ✅ +{new_candles:,} candles (total: {total:,}) — {rate:,.0f} candles/s", GREEN)
            else:
                log(f"   ⏭️ Déjà à jour ({total:,} candles)")

            # Persist after every symbol so an interruption loses at most one pair.
            save_state(state)

        except Exception as e:
            # Deliberately broad: one bad pair must not abort the whole batch;
            # the error is recorded in the state file for later inspection.
            log(f"   ❌ Erreur: {e}", RED)
            state[symbol] = state.get(symbol, {})
            state[symbol]["error"] = str(e)
            state[symbol]["status"] = "error"
            save_state(state)

    elapsed = time.time() - start_time
    log("")
    # FIX: was an f-string with no placeholders (lint F541); same bytes printed.
    log("═══ Téléchargement terminé ═══", GREEN)
    log(f"   {grand_total:,} nouvelles candles téléchargées")
    log(f"   Durée: {elapsed / 60:.1f} min")

    check_status(state, eligible)


# Run the CLI only when executed as a script (not when imported).
if __name__ == "__main__":
    main()
