#!/usr/bin/env python3
"""
Construit le dataset séquentiel LSTM à partir des klines 1m et du combined dataset.
Sauvegarde en .npz compact pour transfert vers le PC GPU.

Usage:
    python build_lstm_sequences.py
    # → data/lstm_sequences.npz (~150 MB)
"""
import json
import os
import sys
import time
from pathlib import Path

import numpy as np
import pandas as pd

# Import deep_model's sequence builder
from deep_model import build_sequence_features, SURGE_TYPE_MAP, N_SEQUENCE_FEATURES

PROJECT_DIR = Path(__file__).parent
DATA_DIR = PROJECT_DIR / "data"
KLINES_DIR = DATA_DIR / "klines_1m"
SEQ_LEN = 60  # 60 minutes of context


def _load_symbol_klines(symbol: str, cache: dict) -> "pd.DataFrame | None":
    """Load the 1m klines parquet for *symbol*, memoized in *cache*.

    Falls back to the sibling quote currency (USDC ↔ USDT) when the primary
    file is missing. Failures (missing file, unreadable parquet) are cached
    as None so each symbol is only probed once.
    """
    if symbol in cache:
        return cache[symbol]

    path = KLINES_DIR / f"{symbol}.parquet"
    if not path.exists():
        # Some symbols are stored under the sibling quote currency.
        alt_symbol = (symbol.replace("USDC", "USDT")
                      if symbol.endswith("USDC")
                      else symbol.replace("USDT", "USDC"))
        path = KLINES_DIR / f"{alt_symbol}.parquet"

    df = None
    if path.exists():
        try:
            df = pd.read_parquet(path)
        except Exception:
            df = None  # unreadable file: treat the same as a missing symbol

    cache[symbol] = df
    return df


def main():
    """Build the LSTM sequence dataset and save it as data/lstm_sequences.npz.

    Reads combined_training_dataset.parquet, builds a SEQ_LEN-minute kline
    sequence for each row via build_sequence_features(), and saves the
    resulting arrays (X, surge_types, y, pnl, indices) compressed.
    Exits with status 1 if the input dataset is missing or no sequence
    could be built.
    """
    print(f"\n{'═'*60}")
    print(f"  🧠 Build LSTM Sequence Dataset")
    print(f"{'═'*60}")

    # Load combined dataset
    combined_path = DATA_DIR / "combined_training_dataset.parquet"
    if not combined_path.exists():
        print("  ❌ combined_training_dataset.parquet not found")
        sys.exit(1)

    dataset = pd.read_parquet(combined_path)
    print(f"  📂 Dataset: {len(dataset)} samples")
    print(f"  📂 Klines dir: {KLINES_DIR} ({len(list(KLINES_DIR.glob('*.parquet')))} files)")

    # Sort by entry_time to maintain temporal order
    dataset = dataset.sort_values("entry_time").reset_index(drop=True)

    klines_cache = {}  # symbol -> DataFrame or None (load failure)
    X_list = []
    surge_list = []
    y_list = []
    pnl_list = []
    indices = []  # To track which dataset rows have sequences
    skipped = 0

    t0 = time.time()

    for i in range(len(dataset)):
        row = dataset.iloc[i]
        symbol = row["symbol"]
        entry_time = row["entry_time"]

        df_klines = _load_symbol_klines(symbol, klines_cache)
        if df_klines is None:
            skipped += 1
            continue

        # Convert entry_time to ms timestamp
        try:
            ts_ms = int(pd.Timestamp(entry_time).timestamp() * 1000)
        except Exception:
            skipped += 1
            continue

        # Build sequence (None when not enough history before entry_time)
        seq = build_sequence_features(df_klines, ts_ms, SEQ_LEN)
        if seq is None:
            skipped += 1
            continue

        X_list.append(seq)
        surge_type = row.get("surge_type", "UNKNOWN")
        # 3 = fallback code for surge types absent from SURGE_TYPE_MAP
        surge_list.append(SURGE_TYPE_MAP.get(surge_type, 3))
        y_list.append(int(row["target_profitable"]))
        pnl_list.append(float(row.get("target_pnl_pct", 0)))
        indices.append(i)

        if (i + 1) % 2000 == 0:
            elapsed = time.time() - t0
            rate = (i + 1) / elapsed
            eta = (len(dataset) - i - 1) / rate
            print(f"  Processed {i + 1}/{len(dataset)} "
                  f"({len(X_list)} sequences, {skipped} skipped) "
                  f"[{elapsed:.0f}s, ETA {eta:.0f}s]")

    elapsed = time.time() - t0
    print(f"\n  ✅ Built {len(X_list)} sequences in {elapsed:.0f}s ({skipped} skipped)")
    print(f"  Shape: ({len(X_list)}, {SEQ_LEN}, {N_SEQUENCE_FEATURES})")

    # Guard: np.array([]) would have shape (0,) instead of
    # (0, SEQ_LEN, N_FEATURES), producing a malformed npz, and
    # y.mean() on an empty array would warn. Bail out instead.
    if not X_list:
        print("  ❌ No sequences built — nothing to save")
        sys.exit(1)

    # Save as compressed npz
    X = np.array(X_list, dtype=np.float32)
    surge_types = np.array(surge_list, dtype=np.int64)
    y = np.array(y_list, dtype=np.float32)
    pnl = np.array(pnl_list, dtype=np.float32)
    idx = np.array(indices, dtype=np.int64)

    # Replace NaN/inf so downstream training never sees non-finite values
    X = np.nan_to_num(X, nan=0.0, posinf=10.0, neginf=-10.0)

    output_path = DATA_DIR / "lstm_sequences.npz"
    np.savez_compressed(
        output_path,
        X=X,
        surge_types=surge_types,
        y=y,
        pnl=pnl,
        indices=idx,
        seq_len=np.array([SEQ_LEN]),
        n_features=np.array([N_SEQUENCE_FEATURES]),
    )

    file_size = output_path.stat().st_size / 1e6
    print(f"\n  💾 Saved: {output_path} ({file_size:.1f} MB)")
    print(f"  Positives: {y.sum():.0f}/{len(y)} ({y.mean()*100:.1f}%)")


if __name__ == "__main__":
    main()
