#!/usr/bin/env python3
"""
label_backfill.py — V5.3 signal label backfill script.

What it does:
  - Walks the rows of signal_feature_events that carry a non-empty ``side``.
  - Computes labels from agg_trades prices 15/30/60 minutes after the signal.
  - Writes them to signal_label_events (event_id is reused from
    signal_feature_events.event_id).

Label definitions (as implemented below):
  y_binary_Xm  = 1 if price_Xm_later > price_at_signal  (LONG)
               = 1 if price_Xm_later < price_at_signal  (SHORT)
               = 0 otherwise
  y_return_Xm  = (price_Xm_later - price_at_signal) / price_at_signal,
                 negated for non-LONG signals so a positive value is favourable
  mfe_r_60m    = max favourable excursion / atr_value (requires atr_value > 0)
  mae_r_60m    = max adverse excursion / atr_value

  ``price_at_signal`` is the ``price`` column of the signal row;
  ``price_Xm_later`` is the first agg_trades price at or after ts + X minutes.
  Rows without a usable (positive) entry price are stored with NULL labels.

Usage:
  python3 scripts/label_backfill.py
  python3 scripts/label_backfill.py --symbol BTCUSDT
  python3 scripts/label_backfill.py --since 1709000000000   # ms timestamp
  python3 scripts/label_backfill.py --dry-run               # print only, no writes
  python3 scripts/label_backfill.py --loop                  # re-run every 5 minutes

Requires: PG_PASS / PG_HOST environment variables (same as the other scripts).
"""

import argparse
import logging
import os
import sys
import time

import psycopg2
import psycopg2.extras

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "backend"))
from db import get_sync_conn, init_schema

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)
logger = logging.getLogger("label_backfill")

# Label horizons, expressed in milliseconds after the signal timestamp.
HORIZONS_MS = {
    "15m": 15 * 60 * 1000,
    "30m": 30 * 60 * 1000,
    "60m": 60 * 60 * 1000,
}
BATCH_SIZE = 200
LABEL_TABLE = "signal_label_events"


def ensure_label_table(conn):
    """Raise RuntimeError unless the label table exists.

    The table is expected to be created by db.py's init_schema(); this only
    checks information_schema for a table with that name.
    """
    with conn.cursor() as cur:
        cur.execute(
            "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_name=%s)",
            (LABEL_TABLE,),
        )
        if not cur.fetchone()[0]:
            raise RuntimeError(f"表 {LABEL_TABLE} 不存在,请先运行 init_schema()")


def fetch_unlabeled_signals(conn, symbol=None, since_ms=None, limit=BATCH_SIZE):
    """Return up to ``limit`` signals that still need labels, oldest first.

    A signal qualifies when it has a non-empty ``side``, its 60m horizon (plus
    a one-minute safety margin) has already elapsed, and no row with the same
    event_id exists in the label table.

    Args:
        conn: open database connection.
        symbol: optional symbol filter (upper-cased before comparison).
        since_ms: optional lower bound on the signal timestamp, in ms.
        limit: maximum number of rows to return.

    Returns:
        A list of DictCursor rows with keys
        id, ts, symbol, signal, price, atr_value.
    """
    # Only signals whose longest horizon is fully in the past (+1 min buffer).
    cutoff_ms = int(time.time() * 1000) - HORIZONS_MS["60m"] - 60_000

    conds = [
        "sfe.side IS NOT NULL",
        "sfe.side != ''",
        "sfe.ts < %s",
        "sle.event_id IS NULL",  # not labelled yet
    ]
    params = [cutoff_ms]
    if symbol:
        conds.append("sfe.symbol = %s")
        params.append(symbol.upper())
    if since_ms is not None:
        conds.append("sfe.ts >= %s")
        params.append(since_ms)
    where = " AND ".join(conds)
    params.append(limit)

    with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
        # Only constants are interpolated into the SQL text; every value is
        # passed as a bound parameter.
        cur.execute(
            f"""
            SELECT sfe.event_id AS id, sfe.ts, sfe.symbol, sfe.side AS signal,
                   sfe.price, sfe.atr_value
            FROM signal_feature_events sfe
            LEFT JOIN {LABEL_TABLE} sle ON sle.event_id = sfe.event_id
            WHERE {where}
            ORDER BY sfe.ts ASC
            LIMIT %s
            """,
            params,
        )
        return cur.fetchall()


def fetch_price_at(conn, symbol: str, ts_ms: int) -> float | None:
    """Return the price of the first agg_trades row at or after ``ts_ms``.

    Returns None when no such trade exists.

    NOTE: the lookup has no upper time bound, so across a gap in agg_trades
    the returned trade may be arbitrarily later than ``ts_ms``.
    """
    with conn.cursor() as cur:
        cur.execute(
            "SELECT price FROM agg_trades WHERE symbol=%s AND time_ms >= %s "
            "ORDER BY time_ms ASC LIMIT 1",
            (symbol, ts_ms),
        )
        row = cur.fetchone()
        return float(row[0]) if row else None


def fetch_price_range(conn, symbol: str, from_ms: int, to_ms: int):
    """Return (max_price, min_price) of agg_trades within [from_ms, to_ms].

    Returns (None, None) when the window contains no trades. Used for the
    MFE/MAE computation.
    """
    with conn.cursor() as cur:
        cur.execute(
            "SELECT MAX(price), MIN(price) FROM agg_trades "
            "WHERE symbol=%s AND time_ms BETWEEN %s AND %s",
            (symbol, from_ms, to_ms),
        )
        row = cur.fetchone()
        if row and row[0] is not None:
            return float(row[0]), float(row[1])
        return None, None


def compute_label(signal: str, entry_price: float, future_price: float) -> int:
    """Direction-aware binary label: 1 = the signal was right, 0 = it was not.

    Any signal other than "LONG" / "SHORT" yields 0.
    """
    if signal == "LONG":
        return 1 if future_price > entry_price else 0
    if signal == "SHORT":
        return 1 if future_price < entry_price else 0
    return 0


def compute_return(signal: str, entry_price: float, future_price: float) -> float:
    """Direction-aware fractional return (positive = favourable).

    The raw return is negated for every signal other than "LONG". Returns 0.0
    when ``entry_price`` is 0 to avoid a division by zero.
    """
    if entry_price == 0:
        return 0.0
    raw = (future_price - entry_price) / entry_price
    return raw if signal == "LONG" else -raw


def compute_mfe_mae(signal: str, entry_price: float, high: float, low: float, atr: float):
    """Return (MFE, MAE) in units of ATR ("R"), each rounded to 4 decimals.

    For "LONG" the favourable move is high - entry and the adverse move is
    entry - low; for any other signal the two are swapped. Returns
    (None, None) when ``atr`` is None or not positive.
    """
    if atr is None or atr <= 0:
        return None, None
    if signal == "LONG":
        mfe = (high - entry_price) / atr
        mae = (entry_price - low) / atr
    else:
        mfe = (entry_price - low) / atr
        mae = (high - entry_price) / atr
    return round(mfe, 4), round(mae, 4)


def _to_float(value) -> float | None:
    """Convert a database numeric to float, passing None through.

    psycopg2 returns Decimal for NUMERIC columns, and Decimal cannot be mixed
    with float in arithmetic, so values are normalised once at the boundary.
    """
    return None if value is None else float(value)


def backfill_batch(conn, rows: list, dry_run: bool) -> int:
    """Compute labels for one batch of signals and store them.

    Args:
        conn: open database connection.
        rows: signal rows as returned by fetch_unlabeled_signals().
        dry_run: when True, only log a preview and write nothing.

    Returns:
        Number of label records prepared (one per input row). With
        ON CONFLICT DO NOTHING the number actually inserted may be lower.
    """
    if not rows:
        return 0

    records = []
    for row in rows:
        event_id = row["id"]
        ts = row["ts"]
        symbol = row["symbol"]
        signal = row["signal"]
        entry_price = _to_float(row["price"])
        atr_value = _to_float(row["atr_value"])

        labels = {}
        mfe, mae = None, None

        if entry_price is None or entry_price <= 0:
            # Without a real entry price every comparison would be made
            # against 0 and produce fabricated labels. Store NULLs instead;
            # the row is still written so it is not fetched again.
            logger.warning(
                "event_id=%s has no usable entry price; storing NULL labels",
                event_id,
            )
        else:
            for horizon, delta_ms in HORIZONS_MS.items():
                fp = fetch_price_at(conn, symbol, ts + delta_ms)
                if fp is None:
                    labels[horizon] = None
                    labels[f"ret_{horizon}"] = None
                else:
                    labels[horizon] = compute_label(signal, entry_price, fp)
                    labels[f"ret_{horizon}"] = round(
                        compute_return(signal, entry_price, fp), 6
                    )

            # MFE/MAE are measured over the 60m window after the signal.
            high, low = fetch_price_range(conn, symbol, ts, ts + HORIZONS_MS["60m"])
            if high is not None:
                mfe, mae = compute_mfe_mae(signal, entry_price, high, low, atr_value)

        records.append((
            event_id,
            labels.get("15m"),
            labels.get("30m"),
            labels.get("60m"),
            labels.get("ret_15m"),
            labels.get("ret_30m"),
            labels.get("ret_60m"),
            mfe,
            mae,
            int(time.time() * 1000),
        ))

    if dry_run:
        logger.info("[dry-run] 准备写入 %s 条标签", len(records))
        for r in records[:5]:
            logger.info(
                " event_id=%s y15m=%s y30m=%s y60m=%s ret60m=%s mfe=%s mae=%s",
                r[0], r[1], r[2], r[3], r[6], r[7], r[8],
            )
        return len(records)

    with conn.cursor() as cur:
        psycopg2.extras.execute_values(
            cur,
            f"""
            INSERT INTO {LABEL_TABLE}
                (event_id, y_binary_15m, y_binary_30m, y_binary_60m,
                 y_return_15m, y_return_30m, y_return_60m,
                 mfe_r_60m, mae_r_60m, label_ts)
            VALUES %s
            ON CONFLICT (event_id) DO NOTHING
            """,
            records,
        )
    conn.commit()
    return len(records)


def _drain_pending(args, total: int) -> int:
    """Backfill batches until nothing is pending; return the updated total.

    A fresh connection is opened for every batch. In dry-run mode only one
    batch is handled per call: nothing is written, so the next fetch would
    return the very same rows again.
    """
    while True:
        with get_sync_conn() as conn:
            ensure_label_table(conn)
            rows = fetch_unlabeled_signals(
                conn, symbol=args.symbol, since_ms=args.since
            )
            if not rows:
                logger.info("没有待回填的信号,结束。")
                return total
            n = backfill_batch(conn, rows, dry_run=args.dry_run)

        total += n
        logger.info("本批回填 %s 条,累计 %s 条", n, total)
        if args.dry_run:
            return total
        time.sleep(1)  # throttle: 1 second between batches to ease DB load


def main():
    """CLI entry point: parse arguments and run the backfill.

    Without --loop, all pending signals are backfilled once and the script
    exits. With --loop, the same pass is repeated every 5 minutes; an empty
    pass no longer terminates the process.
    """
    parser = argparse.ArgumentParser(description="V5.3 信号标签回填")
    parser.add_argument("--symbol", help="只回填指定品种(BTCUSDT/ETHUSDT/...)")
    parser.add_argument("--since", type=int, help="起始时间戳(ms)")
    parser.add_argument("--dry-run", action="store_true", help="只打印,不写入")
    parser.add_argument("--loop", action="store_true", help="持续运行,每5分钟跑一次")
    args = parser.parse_args()

    init_schema()

    total = 0
    while True:
        total = _drain_pending(args, total)
        if not args.loop:
            break
        time.sleep(300)  # 5 minutes between passes

    logger.info("回填完成,总计 %s 条。", total)


if __name__ == "__main__":
    main()