arbitrage-engine/scripts/label_backfill.py

246 lines
8.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
label_backfill.py — V5.3 信号标签回填脚本
功能:
- 遍历 signal_feature_events 中有 side 的历史记录
- 根据 side 时点后 15/30/60 分钟的 agg_trades 价格计算标签
- 写入 signal_label_events 表event_id 复用 signal_feature_events.event_id
标签定义(严格按 Mark Price + 时间序列方向):
y_binary_60m = 1 if price_60m_later > price_at_signal (LONG)
= 1 if price_60m_later < price_at_signal (SHORT)
= 0 otherwise
y_return_Xm = (price_Xm_later - price_at_signal) / price_at_signal (方向不翻转LONG正SHORT正)
mfe_r_60m = max favorable excursion / atr_value (需 atr_value 不为0)
mae_r_60m = max adverse excursion / atr_value
运行方式:
python3 scripts/label_backfill.py
python3 scripts/label_backfill.py --symbol BTCUSDT
python3 scripts/label_backfill.py --since 1709000000000 # ms timestamp
python3 scripts/label_backfill.py --dry-run # 只打印不写入
依赖:
PG_PASS / PG_HOST 环境变量(同其他脚本)
"""
import argparse
import logging
import os
import sys
import time
import psycopg2
import psycopg2.extras
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "backend"))
from db import get_sync_conn, init_schema
# Root logging: timestamped records at INFO level and above.
logging.basicConfig(format="%(asctime)s [%(levelname)s] %(message)s", level=logging.INFO)

# Logger used by every function in this script.
logger = logging.getLogger("label_backfill")
# Label horizons, keyed by name ("15m", "30m", "60m"), expressed in milliseconds.
HORIZONS_MS = {f"{minutes}m": minutes * 60 * 1000 for minutes in (15, 30, 60)}

# Maximum number of signals fetched and labelled per batch.
BATCH_SIZE = 200

# Destination table for the computed labels.
LABEL_TABLE = "signal_label_events"
def ensure_label_table(conn):
    """Confirm that the label table exists; it is not created here.

    The table is expected to have been created beforehand (the error message
    points at ``init_schema()``); this function only checks
    ``information_schema.tables`` for a table named ``LABEL_TABLE``.

    Args:
        conn: Open database connection providing ``cursor()``.

    Raises:
        RuntimeError: If no table with that name is found.
    """
    exists_sql = "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_name=%s)"
    with conn.cursor() as cursor:
        cursor.execute(exists_sql, (LABEL_TABLE,))
        table_exists = cursor.fetchone()[0]
        if table_exists:
            return
        raise RuntimeError(f"{LABEL_TABLE} 不存在,请先运行 init_schema()")
def fetch_unlabeled_signals(conn, symbol=None, since_ms=None, limit=BATCH_SIZE):
    """Return signal_feature_events rows that still need labels.

    A row qualifies when it has a non-empty ``side``, its 60-minute horizon
    (plus a one-minute safety buffer) has already elapsed, and no row with
    the same ``event_id`` exists in the label table.

    Args:
        conn: Open database connection.
        symbol: Optional symbol filter; upper-cased before being compared.
        since_ms: Optional lower bound (inclusive) on ``ts``, in milliseconds.
        limit: Maximum number of rows to return.

    Returns:
        Rows (DictCursor records) ordered by ``ts`` ascending, exposing the
        keys ``id``, ``ts``, ``symbol``, ``signal``, ``price`` and ``atr_value``.
    """
    now_ms = int(time.time() * 1000)
    # Only signals whose longest horizon is over; keep one extra minute of slack.
    cutoff_ms = now_ms - HORIZONS_MS["60m"] - 60_000

    # Optional filters as (SQL fragment, bound value) pairs, in placeholder order.
    optional_filters = []
    if symbol:
        optional_filters.append(("sfe.symbol = %s", symbol.upper()))
    if since_ms:
        optional_filters.append(("sfe.ts >= %s", since_ms))

    clauses = [
        "sfe.side IS NOT NULL",
        "sfe.side != ''",
        "sfe.ts < %s",
        "sle.event_id IS NULL",  # not labelled yet
    ]
    clauses.extend(fragment for fragment, _ in optional_filters)
    where = " AND ".join(clauses)
    query_args = [cutoff_ms, *(value for _, value in optional_filters), limit]

    query = f"""
        SELECT sfe.event_id AS id, sfe.ts, sfe.symbol, sfe.side AS signal,
        sfe.price, sfe.atr_value
        FROM signal_feature_events sfe
        LEFT JOIN {LABEL_TABLE} sle ON sle.event_id = sfe.event_id
        WHERE {where}
        ORDER BY sfe.ts ASC
        LIMIT %s
        """
    with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cursor:
        cursor.execute(query, query_args)
        return cursor.fetchall()
def fetch_price_at(conn, symbol: str, ts_ms: int) -> float | None:
    """Return the price of the first agg_trades row at or after ``ts_ms``.

    Args:
        conn: Open database connection.
        symbol: Symbol to look up.
        ts_ms: Lower bound (inclusive) on ``time_ms``.

    Returns:
        The trade price as a float, or ``None`` when no such trade exists.
    """
    query = (
        "SELECT price FROM agg_trades WHERE symbol=%s AND time_ms >= %s "
        "ORDER BY time_ms ASC LIMIT 1"
    )
    with conn.cursor() as cursor:
        cursor.execute(query, (symbol, ts_ms))
        first_trade = cursor.fetchone()
    if not first_trade:
        return None
    return float(first_trade[0])
def fetch_price_range(conn, symbol: str, from_ms: int, to_ms: int):
    """Return the highest and lowest agg_trades price in a time window.

    The window is ``from_ms``..``to_ms`` with both ends inclusive (SQL
    ``BETWEEN``). The result feeds the MFE/MAE computation.

    Args:
        conn: Open database connection.
        symbol: Symbol to look up.
        from_ms: Window start, in milliseconds.
        to_ms: Window end, in milliseconds.

    Returns:
        ``(high, low)`` as floats, or ``(None, None)`` when the window
        contains no trades.
    """
    query = (
        "SELECT MAX(price), MIN(price) FROM agg_trades "
        "WHERE symbol=%s AND time_ms BETWEEN %s AND %s"
    )
    with conn.cursor() as cursor:
        cursor.execute(query, (symbol, from_ms, to_ms))
        extremes = cursor.fetchone()
    # MAX()/MIN() over an empty set yield NULL, which arrives here as None.
    if not extremes or extremes[0] is None:
        return None, None
    return float(extremes[0]), float(extremes[1])
def compute_label(signal: str, entry_price: float, future_price: float) -> int:
    """Return the direction-aware binary label (1 = signal was right, 0 = wrong).

    A LONG is right when the future price is strictly above the entry price,
    a SHORT when it is strictly below. Any other signal value yields 0.
    """
    if signal == "LONG":
        correct = future_price > entry_price
    elif signal == "SHORT":
        correct = future_price < entry_price
    else:
        # Unknown direction: never counted as a hit, prices are not compared.
        correct = False
    return 1 if correct else 0
def compute_return(signal: str, entry_price: float, future_price: float) -> float:
    """Return the direction-aware fractional return (positive = favorable).

    The raw return is ``(future_price - entry_price) / entry_price``. It is
    returned unchanged for LONG and negated for every other signal value.
    An entry price of 0 yields 0.0 instead of dividing by zero.
    """
    if entry_price == 0:
        return 0.0
    price_change = (future_price - entry_price) / entry_price
    if signal == "LONG":
        return price_change
    return -price_change
def compute_mfe_mae(signal: str, entry_price: float, high: float, low: float, atr: float):
    """Return ``(mfe, mae)`` expressed in multiples of ``atr`` (R units).

    For LONG the favorable excursion is ``high - entry_price`` and the adverse
    one is ``entry_price - low``; for any other signal the two are swapped.
    Both values are rounded to 4 decimals.

    Returns:
        ``(mfe, mae)``, or ``(None, None)`` when ``atr`` is ``None`` or not
        positive.
    """
    if atr is None or atr <= 0:
        return None, None
    if signal == "LONG":
        favorable_move, adverse_move = high - entry_price, entry_price - low
    else:
        favorable_move, adverse_move = entry_price - low, high - entry_price
    return round(favorable_move / atr, 4), round(adverse_move / atr, 4)
def backfill_batch(conn, rows: list, dry_run: bool) -> int:
    """Compute labels for one batch of signals and insert them.

    For every row the 15m/30m/60m binary label and return are computed from
    the first trade at or after each horizon, and MFE/MAE from the price
    range of the 60-minute window starting at the signal timestamp.

    Rows without a usable entry price (NULL or 0) cannot be labelled: they
    get a record whose labels are all NULL instead of labels computed against
    a fictitious entry price of 0.

    Args:
        conn: Open database connection.
        rows: Mapping-like rows with the keys ``id``, ``ts``, ``symbol``,
            ``signal``, ``price`` and ``atr_value``.
        dry_run: When true, log a preview (first five records) and write
            nothing.

    Returns:
        The number of label records prepared (equal to ``len(rows)``).
    """
    if not rows:
        return 0

    records = []
    for row in rows:
        event_id = row["id"]
        ts = row["ts"]
        symbol = row["symbol"]
        signal = row["signal"]
        # Cast to float so the arithmetic below never mixes float with another
        # numeric type the driver may return (e.g. Decimal for NUMERIC columns).
        entry_price = float(row["price"]) if row["price"] else None
        atr_value = float(row["atr_value"]) if row["atr_value"] is not None else None

        labels = {}
        for horizon, delta_ms in HORIZONS_MS.items():
            future_price = None
            if entry_price is not None:
                future_price = fetch_price_at(conn, symbol, ts + delta_ms)
            if future_price is None:
                labels[horizon] = None
                labels[f"ret_{horizon}"] = None
            else:
                labels[horizon] = compute_label(signal, entry_price, future_price)
                labels[f"ret_{horizon}"] = round(
                    compute_return(signal, entry_price, future_price), 6
                )

        # MFE/MAE over the 60-minute window.
        mfe = mae = None
        if entry_price is not None:
            high, low = fetch_price_range(conn, symbol, ts, ts + HORIZONS_MS["60m"])
            # Explicit None check: "no trades in the window" is the only skip case.
            if high is not None:
                mfe, mae = compute_mfe_mae(signal, entry_price, high, low, atr_value)

        records.append((
            event_id,
            labels.get("15m"), labels.get("30m"), labels.get("60m"),
            labels.get("ret_15m"), labels.get("ret_30m"), labels.get("ret_60m"),
            mfe, mae,
            int(time.time() * 1000),
        ))

    if dry_run:
        logger.info(f"[dry-run] 准备写入 {len(records)} 条标签")
        for r in records[:5]:
            logger.info(f" event_id={r[0]} y15m={r[1]} y30m={r[2]} y60m={r[3]} ret60m={r[6]} mfe={r[7]} mae={r[8]}")
        return len(records)

    with conn.cursor() as cur:
        psycopg2.extras.execute_values(
            cur,
            f"""
            INSERT INTO {LABEL_TABLE}
            (event_id, y_binary_15m, y_binary_30m, y_binary_60m,
            y_return_15m, y_return_30m, y_return_60m, mfe_r_60m, mae_r_60m, label_ts)
            VALUES %s
            ON CONFLICT (event_id) DO NOTHING
            """,
            records,
        )
    conn.commit()
    return len(records)
def main():
    """Command-line entry point: label every pending signal, optionally forever.

    One pass drains the backlog batch by batch, pausing one second between
    batches to limit database load. Without ``--loop`` the script exits after
    that pass; with ``--loop`` it sleeps five minutes and starts another pass.

    In ``--dry-run`` mode nothing is written, so the same rows would be
    returned again on the next fetch; each pass therefore previews a single
    batch only.
    """
    parser = argparse.ArgumentParser(description="V5.3 信号标签回填")
    parser.add_argument("--symbol", help="只回填指定品种BTCUSDT/ETHUSDT/...")
    parser.add_argument("--since", type=int, help="起始时间戳ms")
    parser.add_argument("--dry-run", action="store_true", help="只打印,不写入")
    parser.add_argument("--loop", action="store_true", help="持续运行每5分钟跑一次")
    args = parser.parse_args()

    init_schema()
    total = 0
    while True:
        # Inner loop: one pass over the current backlog.
        while True:
            with get_sync_conn() as conn:
                ensure_label_table(conn)
                rows = fetch_unlabeled_signals(conn, symbol=args.symbol, since_ms=args.since)
                if not rows:
                    logger.info("没有待回填的信号,结束。")
                    break
                n = backfill_batch(conn, rows, dry_run=args.dry_run)
            total += n
            logger.info(f"本批回填 {n} 条,累计 {total}")
            # Stop when no progress is possible, otherwise the same rows
            # would be fetched again indefinitely.
            if args.dry_run or n == 0:
                break
            time.sleep(1)  # throttle: one second between batches to ease DB load
        if not args.loop:
            break
        time.sleep(300)  # five minutes between passes
    logger.info(f"回填完成,总计 {total} 条。")
# Run the backfill only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()