245 lines
8.2 KiB
Python
245 lines
8.2 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
label_backfill.py — V5.3 信号标签回填脚本
|
||
|
||
功能:
|
||
- 遍历 signal_feature_events 中有 side 的历史记录
|
||
- 根据 side 时点后 15/30/60 分钟的 agg_trades 价格计算标签
|
||
- 写入 signal_label_events 表(event_id 复用 signal_feature_events.event_id)
|
||
|
||
Label definitions (computed from agg_trades trade prices, taking the first trade at or after each horizon; direction-aware):
|
||
y_binary_60m = 1 if price_60m_later > price_at_signal (LONG)
|
||
= 1 if price_60m_later < price_at_signal (SHORT)
|
||
= 0 otherwise
|
||
  y_return_Xm  = (price_Xm_later - price_at_signal) / price_at_signal for LONG, negated for SHORT (direction-adjusted: positive = price moved in the signal's favor)
|
||
mfe_r_60m = max favorable excursion / atr_value (需 atr_value 不为0)
|
||
mae_r_60m = max adverse excursion / atr_value
|
||
|
||
运行方式:
|
||
python3 scripts/label_backfill.py
|
||
python3 scripts/label_backfill.py --symbol BTCUSDT
|
||
python3 scripts/label_backfill.py --since 1709000000000 # ms timestamp
|
||
python3 scripts/label_backfill.py --dry-run # 只打印不写入
|
||
|
||
依赖:
|
||
PG_PASS / PG_HOST 环境变量(同其他脚本)
|
||
"""
|
||
|
||
import argparse
|
||
import logging
|
||
import os
|
||
import sys
|
||
import time
|
||
|
||
import psycopg2
|
||
import psycopg2.extras
|
||
|
||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "backend"))
|
||
from db import get_sync_conn, init_schema
|
||
|
||
# Script-wide logging: INFO level, each record rendered as "time [LEVEL] message".
_LOG_FORMAT = "%(asctime)s [%(levelname)s] %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Named logger shared by every function in this script.
logger = logging.getLogger("label_backfill")
|
||
|
||
# Label horizons ("15m", "30m", "60m") mapped to their length in milliseconds.
_MS_PER_MINUTE = 60 * 1000
HORIZONS_MS = {f"{minutes}m": minutes * _MS_PER_MINUTE for minutes in (15, 30, 60)}

# Default number of signals fetched per query.
BATCH_SIZE = 200

# Table the computed labels are written to.
LABEL_TABLE = "signal_label_events"
|
||
|
||
|
||
def ensure_label_table(conn):
    """Check that the label table exists; this function never creates it.

    The table is expected to be created by db.py's init_schema.

    Raises:
        RuntimeError: if information_schema.tables has no table named
            LABEL_TABLE.
    """
    lookup_sql = "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_name=%s)"
    with conn.cursor() as cursor:
        cursor.execute(lookup_sql, (LABEL_TABLE,))
        table_exists = cursor.fetchone()[0]
        if table_exists:
            return
        raise RuntimeError(f"表 {LABEL_TABLE} 不存在,请先运行 init_schema()")
|
||
|
||
|
||
def fetch_unlabeled_signals(conn, symbol=None, since_ms=None, limit=BATCH_SIZE):
    """Return signal_feature_events rows that have a side but no label row yet.

    Only events older than the 60m horizon plus a one-minute buffer are
    selected, so the full label window has already elapsed.

    Args:
        conn: open database connection.
        symbol: optional symbol filter; upper-cased before comparison.
        since_ms: optional lower bound (inclusive) on the event timestamp, ms.
        limit: maximum number of rows to return.

    Returns:
        Rows ordered by ts ascending, exposing id, ts, symbol, signal,
        price and atr_value.
    """
    now_ms = int(time.time() * 1000)
    cutoff_ms = now_ms - HORIZONS_MS["60m"] - 60_000  # one extra minute of slack

    clauses = [
        "sfe.side IS NOT NULL",
        "sfe.side != ''",
        "sfe.ts < %s",
        "sle.event_id IS NULL",  # the LEFT JOIN found no label row
    ]
    args = [cutoff_ms]

    if symbol:
        clauses.append("sfe.symbol = %s")
        args.append(symbol.upper())
    if since_ms:
        clauses.append("sfe.ts >= %s")
        args.append(since_ms)

    where = " AND ".join(clauses)
    args.append(limit)  # bound to the trailing LIMIT placeholder

    query = f"""
            SELECT sfe.event_id AS id, sfe.ts, sfe.symbol, sfe.side AS signal,
                   sfe.price, sfe.atr_value
            FROM signal_feature_events sfe
            LEFT JOIN {LABEL_TABLE} sle ON sle.event_id = sfe.event_id
            WHERE {where}
            ORDER BY sfe.ts ASC
            LIMIT %s
            """

    with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cursor:
        cursor.execute(query, args)
        return cursor.fetchall()
|
||
|
||
|
||
def fetch_price_at(conn, symbol: str, ts_ms: int) -> float | None:
    """Return the price of the earliest agg_trades row at or after ts_ms.

    Returns None when no such trade exists for the symbol.
    """
    query = (
        "SELECT price FROM agg_trades WHERE symbol=%s AND time_ms >= %s ORDER BY time_ms ASC LIMIT 1"
    )
    with conn.cursor() as cursor:
        cursor.execute(query, (symbol, ts_ms))
        first_trade = cursor.fetchone()
        if not first_trade:
            return None
        return float(first_trade[0])
|
||
|
||
|
||
def fetch_price_range(conn, symbol: str, from_ms: int, to_ms: int):
    """Return (highest, lowest) agg_trades price in [from_ms, to_ms], bounds inclusive.

    Used for the MFE/MAE computation. Returns (None, None) when the window
    contains no trades.
    """
    query = (
        "SELECT MAX(price), MIN(price) FROM agg_trades WHERE symbol=%s AND time_ms BETWEEN %s AND %s"
    )
    with conn.cursor() as cursor:
        cursor.execute(query, (symbol, from_ms, to_ms))
        extremes = cursor.fetchone()
        # MAX/MIN over an empty window come back as NULL.
        if not extremes or extremes[0] is None:
            return None, None
        highest, lowest = extremes[0], extremes[1]
        return float(highest), float(lowest)
|
||
|
||
|
||
def compute_label(signal: str, entry_price: float, future_price: float) -> int:
    """Return the direction-aware binary label: 1 if the signal was right, else 0.

    A LONG is right when the future price is strictly above the entry price,
    a SHORT when it is strictly below. Any other signal value yields 0.
    """
    if signal == "LONG":
        went_our_way = future_price > entry_price
    elif signal == "SHORT":
        went_our_way = future_price < entry_price
    else:
        went_our_way = False
    return 1 if went_our_way else 0
|
||
|
||
|
||
def compute_return(signal: str, entry_price: float, future_price: float) -> float:
    """Return the direction-adjusted return; a positive value is favorable.

    The raw return (future - entry) / entry is kept as-is for LONG and
    negated for every other signal value. A zero entry price yields 0.0
    instead of dividing by zero.
    """
    if entry_price == 0:
        return 0.0
    ratio = (future_price - entry_price) / entry_price
    if signal == "LONG":
        return ratio
    return -ratio
|
||
|
||
|
||
def compute_mfe_mae(signal: str, entry_price: float, high: float, low: float, atr: float):
|
||
"""MFE/MAE(以R为单位)"""
|
||
if atr is None or atr <= 0:
|
||
return None, None
|
||
if signal == "LONG":
|
||
mfe = (high - entry_price) / atr
|
||
mae = (entry_price - low) / atr
|
||
else:
|
||
mfe = (entry_price - low) / atr
|
||
mae = (high - entry_price) / atr
|
||
return round(mfe, 4), round(mae, 4)
|
||
|
||
|
||
def _build_label_record(conn, row) -> tuple:
    """Compute the signal_label_events value tuple for one signal row."""
    event_ts = row["ts"]
    symbol = row["symbol"]
    signal = row["signal"]

    # NOTE(review): numeric columns may arrive as Decimal depending on the
    # schema — coerce so arithmetic with float trade prices cannot fail.
    raw_price = row["price"]
    raw_atr = row["atr_value"]
    entry_price = float(raw_price) if raw_price is not None else None
    atr_value = float(raw_atr) if raw_atr is not None else None

    # Without a positive entry price no label is meaningful: comparing
    # against 0 would mark every LONG as correct and every SHORT as wrong.
    has_entry = entry_price is not None and entry_price > 0

    labels = {}
    for horizon, delta_ms in HORIZONS_MS.items():
        future_price = fetch_price_at(conn, symbol, event_ts + delta_ms) if has_entry else None
        if future_price is None:
            labels[horizon] = None
            labels[f"ret_{horizon}"] = None
        else:
            labels[horizon] = compute_label(signal, entry_price, future_price)
            labels[f"ret_{horizon}"] = round(compute_return(signal, entry_price, future_price), 6)

    # MFE/MAE are measured over the 60m window that starts at the signal.
    mfe = mae = None
    if has_entry:
        high, low = fetch_price_range(conn, symbol, event_ts, event_ts + HORIZONS_MS["60m"])
        if high is not None and low is not None:
            mfe, mae = compute_mfe_mae(signal, entry_price, high, low, atr_value)

    return (
        row["id"],
        labels.get("15m"), labels.get("30m"), labels.get("60m"),
        labels.get("ret_15m"), labels.get("ret_30m"), labels.get("ret_60m"),
        mfe, mae,
        int(time.time() * 1000),
    )


def backfill_batch(conn, rows: list, dry_run: bool) -> int:
    """Compute labels for a batch of signals and insert them into LABEL_TABLE.

    A record is produced for every input row. Labels that cannot be computed
    (no trade found for a horizon, missing or non-positive entry price) are
    stored as NULL, so the signal still gets a label row and is not selected
    again by the "no label row yet" query.

    Args:
        conn: open database connection.
        rows: signal rows exposing id, ts, symbol, signal, price, atr_value.
        dry_run: when True, log a preview of up to five records and write
            nothing.

    Returns:
        Number of records prepared (equal to len(rows)). Rows skipped by
        ON CONFLICT DO NOTHING are still counted.
    """
    if not rows:
        return 0

    records = [_build_label_record(conn, row) for row in rows]

    if dry_run:
        logger.info(f"[dry-run] 准备写入 {len(records)} 条标签")
        for r in records[:5]:
            logger.info(f" event_id={r[0]} y15m={r[1]} y30m={r[2]} y60m={r[3]} ret60m={r[6]} mfe={r[7]} mae={r[8]}")
        return len(records)

    with conn.cursor() as cur:
        psycopg2.extras.execute_values(
            cur,
            f"""
            INSERT INTO {LABEL_TABLE}
                (event_id, y_binary_15m, y_binary_30m, y_binary_60m,
                 y_return_15m, y_return_30m, y_return_60m, mfe_r_60m, mae_r_60m, label_ts)
            VALUES %s
            ON CONFLICT (event_id) DO NOTHING
            """,
            records,
        )
    conn.commit()
    return len(records)
|
||
|
||
|
||
def main():
    """CLI entry point: parse arguments and backfill labels for pending signals.

    Each cycle opens a connection, checks that the label table exists and
    processes pending signals batch by batch until none are left. Without
    --loop the script stops after one cycle; with --loop it sleeps five
    minutes and starts another cycle, including when the queue was empty.

    In --dry-run mode only a single batch is previewed per cycle, because
    nothing is written and the same rows would be fetched again.
    """
    parser = argparse.ArgumentParser(description="V5.3 信号标签回填")
    parser.add_argument("--symbol", help="只回填指定品种(BTCUSDT/ETHUSDT/...)")
    parser.add_argument("--since", type=int, help="起始时间戳(ms)")
    parser.add_argument("--dry-run", action="store_true", help="只打印,不写入")
    parser.add_argument("--loop", action="store_true", help="持续运行,每5分钟跑一次")
    args = parser.parse_args()

    init_schema()

    total = 0
    while True:
        with get_sync_conn() as conn:
            ensure_label_table(conn)
            # Drain the backlog: keep fetching batches until nothing is pending.
            while True:
                rows = fetch_unlabeled_signals(conn, symbol=args.symbol, since_ms=args.since)
                if not rows:
                    if args.loop:
                        logger.info("没有待回填的信号,等待下一轮。")
                    else:
                        logger.info("没有待回填的信号,结束。")
                    break
                n = backfill_batch(conn, rows, dry_run=args.dry_run)
                total += n
                logger.info(f"本批回填 {n} 条,累计 {total} 条")
                # Stop when no progress is possible, to avoid spinning on the
                # same rows (dry-run writes nothing).
                if args.dry_run or n == 0:
                    break

        if not args.loop:
            break
        time.sleep(300)  # five minutes between cycles

    logger.info(f"回填完成,总计 {total} 条。")
|
||
|
||
|
||
# Run the backfill only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|