feat: Phase 0 - add signal_feature_events/label_events tables, atr_value snapshot, label_backfill script
This commit is contained in:
parent
8280aaf6ea
commit
0a5222a1de
@ -290,7 +290,61 @@ CREATE TABLE IF NOT EXISTS paper_trades (
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
-- Live trading tables
|
||||
-- V5.3 Feature Events(每次信号评估快照,无论是否开仓)
|
||||
CREATE TABLE IF NOT EXISTS signal_feature_events (
|
||||
event_id BIGSERIAL PRIMARY KEY,
|
||||
ts BIGINT NOT NULL,
|
||||
symbol TEXT NOT NULL,
|
||||
track TEXT NOT NULL DEFAULT 'ALT',
|
||||
side TEXT,
|
||||
strategy TEXT NOT NULL,
|
||||
strategy_version TEXT,
|
||||
config_hash TEXT,
|
||||
-- 原始特征
|
||||
cvd_fast_raw DOUBLE PRECISION,
|
||||
cvd_mid_raw DOUBLE PRECISION,
|
||||
cvd_day_raw DOUBLE PRECISION,
|
||||
cvd_fast_slope_raw DOUBLE PRECISION,
|
||||
p95_qty_raw DOUBLE PRECISION,
|
||||
p99_qty_raw DOUBLE PRECISION,
|
||||
atr_value DOUBLE PRECISION,
|
||||
atr_percentile DOUBLE PRECISION,
|
||||
oi_delta_raw DOUBLE PRECISION,
|
||||
ls_ratio_raw DOUBLE PRECISION,
|
||||
top_pos_raw DOUBLE PRECISION,
|
||||
coinbase_premium_raw DOUBLE PRECISION,
|
||||
obi_raw DOUBLE PRECISION,
|
||||
tiered_cvd_whale_raw DOUBLE PRECISION,
|
||||
-- 分层评分
|
||||
score_direction DOUBLE PRECISION,
|
||||
score_crowding DOUBLE PRECISION,
|
||||
score_environment DOUBLE PRECISION,
|
||||
score_aux DOUBLE PRECISION,
|
||||
score_total DOUBLE PRECISION,
|
||||
-- 决策结果
|
||||
gate_passed BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
block_reason TEXT,
|
||||
price DOUBLE PRECISION,
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_sfe_sym_ts ON signal_feature_events(symbol, ts DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_sfe_strategy ON signal_feature_events(strategy, ts DESC);
|
||||
|
||||
-- V5.3 Label Events(延迟回填标签,评估信号预测能力)
|
||||
CREATE TABLE IF NOT EXISTS signal_label_events (
|
||||
event_id BIGINT PRIMARY KEY REFERENCES signal_feature_events(event_id) ON DELETE CASCADE,
|
||||
y_binary_15m SMALLINT,
|
||||
y_binary_30m SMALLINT,
|
||||
y_binary_60m SMALLINT,
|
||||
y_return_15m DOUBLE PRECISION,
|
||||
y_return_30m DOUBLE PRECISION,
|
||||
y_return_60m DOUBLE PRECISION,
|
||||
mfe_r_60m DOUBLE PRECISION,
|
||||
mae_r_60m DOUBLE PRECISION,
|
||||
label_ts BIGINT,
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS live_config (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT NOT NULL,
|
||||
@ -396,6 +450,7 @@ def init_schema():
|
||||
"ALTER TABLE paper_trades ADD COLUMN IF NOT EXISTS risk_distance DOUBLE PRECISION",
|
||||
"ALTER TABLE signal_indicators ADD COLUMN IF NOT EXISTS strategy TEXT",
|
||||
"ALTER TABLE signal_indicators ADD COLUMN IF NOT EXISTS factors JSONB",
|
||||
"ALTER TABLE signal_indicators ADD COLUMN IF NOT EXISTS atr_value DOUBLE PRECISION",
|
||||
"ALTER TABLE users ADD COLUMN IF NOT EXISTS discord_id TEXT",
|
||||
"ALTER TABLE users ADD COLUMN IF NOT EXISTS banned BOOLEAN DEFAULT FALSE",
|
||||
]
|
||||
|
||||
@ -375,6 +375,7 @@ class SymbolState:
|
||||
"cvd_day": cvd_day,
|
||||
"vwap": vwap,
|
||||
"atr": atr,
|
||||
"atr_value": atr, # V5.3: ATR绝对值快照(用于feature_events落库)
|
||||
"atr_pct": atr_pct,
|
||||
"p95": p95,
|
||||
"p99": p99,
|
||||
@ -694,10 +695,11 @@ def save_indicator(ts: int, symbol: str, result: dict, strategy: str = "v52_8sig
|
||||
factors_json = _json3.dumps(result.get("factors")) if result.get("factors") else None
|
||||
cur.execute(
|
||||
"INSERT INTO signal_indicators "
|
||||
"(ts,symbol,strategy,cvd_fast,cvd_mid,cvd_day,cvd_fast_slope,atr_5m,atr_percentile,vwap_30m,price,p95_qty,p99_qty,score,signal,factors) "
|
||||
"VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)",
|
||||
"(ts,symbol,strategy,cvd_fast,cvd_mid,cvd_day,cvd_fast_slope,atr_5m,atr_percentile,atr_value,vwap_30m,price,p95_qty,p99_qty,score,signal,factors) "
|
||||
"VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)",
|
||||
(ts, symbol, strategy, result["cvd_fast"], result["cvd_mid"], result["cvd_day"], result["cvd_fast_slope"],
|
||||
result["atr"], result["atr_pct"], result["vwap"], result["price"],
|
||||
result["atr"], result["atr_pct"], result.get("atr_value", result["atr"]),
|
||||
result["vwap"], result["price"],
|
||||
result["p95"], result["p99"], result["score"], result.get("signal"), factors_json)
|
||||
)
|
||||
# 有信号时通知live_executor
|
||||
|
||||
242
scripts/label_backfill.py
Normal file
242
scripts/label_backfill.py
Normal file
@ -0,0 +1,242 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
label_backfill.py — V5.3 信号标签回填脚本
|
||||
|
||||
功能:
|
||||
- 遍历 signal_indicators 中有 signal 的历史记录
|
||||
- 根据 signal 时点后 15/30/60 分钟的 agg_trades 价格计算标签
|
||||
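- Writes to signal_label_events, reusing signal_indicators.id as event_id. NOTE(review): the schema declares signal_label_events.event_id as a foreign key to signal_feature_events(event_id), so inserts fail unless a feature event with the same id exists — confirm the intended id source.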
|
||||
|
||||
Label definitions (prices are agg_trades trade prices, not mark price; all labels are direction-aware):
|
||||
y_binary_60m = 1 if price_60m_later > price_at_signal (LONG)
|
||||
= 1 if price_60m_later < price_at_signal (SHORT)
|
||||
= 0 otherwise
|
||||
y_return_Xm = (price_Xm_later - price_at_signal) / price_at_signal for LONG, negated for SHORT (direction-aware: a positive value is favorable for either side)
|
||||
mfe_r_60m = max favorable excursion / atr_value (需 atr_value 不为0)
|
||||
mae_r_60m = max adverse excursion / atr_value
|
||||
|
||||
运行方式:
|
||||
python3 scripts/label_backfill.py
|
||||
python3 scripts/label_backfill.py --symbol BTCUSDT
|
||||
python3 scripts/label_backfill.py --since 1709000000000 # ms timestamp
|
||||
python3 scripts/label_backfill.py --dry-run # 只打印不写入
|
||||
|
||||
依赖:
|
||||
PG_PASS / PG_HOST 环境变量(同其他脚本)
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
import psycopg2
|
||||
import psycopg2.extras
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "backend"))
|
||||
from db import get_sync_conn, init_schema
|
||||
|
||||
# Script-wide logging: timestamped INFO-level lines on the root logger.
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger("label_backfill")

# Label horizons, keyed by suffix, expressed in milliseconds after the signal ts.
_MINUTE_MS = 60 * 1000
HORIZONS_MS = {
    "15m": 15 * _MINUTE_MS,
    "30m": 30 * _MINUTE_MS,
    "60m": 60 * _MINUTE_MS,
}

# Default number of signals fetched and labelled per batch.
BATCH_SIZE = 200

# Destination table for the computed labels.
LABEL_TABLE = "signal_label_events"
|
||||
|
||||
|
||||
def ensure_label_table(conn):
    """Verify that the label table exists; it is never created here.

    The table is expected to be created by init_schema() from db.py. This
    function only looks the name up in information_schema.tables.

    Raises:
        RuntimeError: if no table named LABEL_TABLE is found.
    """
    lookup_sql = "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_name=%s)"
    with conn.cursor() as cur:
        cur.execute(lookup_sql, (LABEL_TABLE,))
        table_exists = cur.fetchone()[0]
    if table_exists:
        return
    raise RuntimeError(f"表 {LABEL_TABLE} 不存在,请先运行 init_schema()")
|
||||
|
||||
|
||||
def fetch_unlabeled_signals(conn, symbol=None, since_ms=None, limit=BATCH_SIZE):
    """Fetch signal_indicators rows that still need labels.

    A row qualifies when it has a non-NULL signal, has no matching row in the
    label table, and is older than the 60m horizon plus a one-minute buffer.

    Args:
        conn: open database connection.
        symbol: optional symbol filter; upper-cased before comparison.
        since_ms: optional lower bound (inclusive) on the signal ts, in ms.
        limit: maximum number of rows to return.

    Returns:
        Rows (id, ts, symbol, signal, price, atr_value), oldest first.
    """
    # Only signals whose full 60m window has elapsed (plus 1 minute of slack).
    newest_allowed_ts = int(time.time() * 1000) - HORIZONS_MS["60m"] - 60_000

    filters = [
        "si.signal IS NOT NULL",
        "si.ts < %s",
        "sle.event_id IS NULL",  # anti-join: no label row yet
    ]
    query_args = [newest_allowed_ts]
    if symbol:
        filters.append("si.symbol = %s")
        query_args.append(symbol.upper())
    if since_ms:
        filters.append("si.ts >= %s")
        query_args.append(since_ms)
    query_args.append(limit)

    where = " AND ".join(filters)
    sql = f"""
        SELECT si.id, si.ts, si.symbol, si.signal, si.price, si.atr_value
        FROM signal_indicators si
        LEFT JOIN {LABEL_TABLE} sle ON sle.event_id = si.id
        WHERE {where}
        ORDER BY si.ts ASC
        LIMIT %s
        """
    with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
        cur.execute(sql, query_args)
        return cur.fetchall()
|
||||
|
||||
|
||||
def fetch_price_at(conn, symbol: str, ts_ms: int) -> float | None:
    """Return the price of the first agg_trades row at or after ts_ms.

    Args:
        conn: open database connection.
        symbol: symbol to look up.
        ts_ms: lower bound (inclusive) on time_ms.

    Returns:
        The trade price as a float, or None when no such trade exists.
    """
    sql = "SELECT price FROM agg_trades WHERE symbol=%s AND time_ms >= %s ORDER BY time_ms ASC LIMIT 1"
    with conn.cursor() as cur:
        cur.execute(sql, (symbol, ts_ms))
        first_trade = cur.fetchone()
    if not first_trade:
        return None
    return float(first_trade[0])
|
||||
|
||||
|
||||
def fetch_price_range(conn, symbol: str, from_ms: int, to_ms: int):
    """Return the (highest, lowest) agg_trades price in [from_ms, to_ms].

    Used as input for the MFE/MAE computation. Both bounds are inclusive
    because the query uses BETWEEN.

    Returns:
        (max_price, min_price) as floats, or (None, None) when the window
        holds no trades.
    """
    sql = "SELECT MAX(price), MIN(price) FROM agg_trades WHERE symbol=%s AND time_ms BETWEEN %s AND %s"
    with conn.cursor() as cur:
        cur.execute(sql, (symbol, from_ms, to_ms))
        extremes = cur.fetchone()
    # MAX/MIN over zero rows yields NULLs, so an empty window shows up as None.
    if not extremes or extremes[0] is None:
        return None, None
    return float(extremes[0]), float(extremes[1])
|
||||
|
||||
|
||||
def compute_label(signal: str, entry_price: float, future_price: float) -> int:
    """Return the direction-aware binary label for one horizon.

    Returns:
        1 when the price moved strictly in the signalled direction (up for
        "LONG", down for "SHORT"); 0 otherwise, including unchanged prices
        and any other signal value.
    """
    if signal == "LONG":
        predicted_correctly = future_price > entry_price
    elif signal == "SHORT":
        predicted_correctly = future_price < entry_price
    else:
        predicted_correctly = False
    return 1 if predicted_correctly else 0
|
||||
|
||||
|
||||
def compute_return(signal: str, entry_price: float, future_price: float) -> float:
    """Return the direction-aware fractional return (positive = favorable).

    The raw return is (future_price - entry_price) / entry_price. It is
    returned as-is for "LONG" and negated for every other signal value.
    An entry price of 0 yields 0.0 instead of dividing by zero.
    """
    if entry_price == 0:
        return 0.0
    price_change = (future_price - entry_price) / entry_price
    if signal == "LONG":
        return price_change
    return -price_change
|
||||
|
||||
|
||||
def compute_mfe_mae(signal: str, entry_price: float, high: float, low: float, atr: float):
    """Compute MFE and MAE as multiples of atr.

    For "LONG" the favorable excursion is (high - entry_price) and the
    adverse excursion is (entry_price - low); for any other signal value the
    two are swapped. Both are divided by atr and rounded to 4 decimals.

    Args:
        signal: "LONG", or anything else (treated as the short side).
        entry_price: price at signal time.
        high: highest price seen in the window, or None if unknown.
        low: lowest price seen in the window, or None if unknown.
        atr: normalisation divisor; must be positive.

    Returns:
        (mfe, mae), or (None, None) when atr is missing or non-positive, or
        when either price extreme is missing.
    """
    if atr is None or atr <= 0:
        return None, None
    # An empty price window has no extremes; report "unknown" rather than
    # raising TypeError on the subtraction below.
    if high is None or low is None:
        return None, None

    upside = (high - entry_price) / atr
    downside = (entry_price - low) / atr
    if signal == "LONG":
        mfe, mae = upside, downside
    else:
        mfe, mae = downside, upside
    return round(mfe, 4), round(mae, 4)
|
||||
|
||||
|
||||
def _build_label_record(conn, row) -> tuple:
    """Compute the label-table tuple for one signal_indicators row.

    The tuple order matches the INSERT column list used by backfill_batch:
    (event_id, y_binary_15m/30m/60m, y_return_15m/30m/60m, mfe, mae, label_ts).
    """
    ts = row["ts"]
    symbol = row["symbol"]
    signal = row["signal"]
    entry_price = row["price"]

    # A NULL or non-positive entry price cannot anchor a label: comparing
    # against 0 would mark every LONG as "correct" and force returns to 0.0.
    # Store NULLs in that case instead of fabricated values.
    has_entry = entry_price is not None and entry_price > 0

    labels = {}
    for horizon, delta_ms in HORIZONS_MS.items():
        future_price = fetch_price_at(conn, symbol, ts + delta_ms) if has_entry else None
        if future_price is None:
            labels[horizon] = None
            labels[f"ret_{horizon}"] = None
        else:
            labels[horizon] = compute_label(signal, entry_price, future_price)
            labels[f"ret_{horizon}"] = round(compute_return(signal, entry_price, future_price), 6)

    # MFE/MAE are measured over the whole 60m window after the signal.
    mfe, mae = None, None
    if has_entry:
        high, low = fetch_price_range(conn, symbol, ts, ts + HORIZONS_MS["60m"])
        # Explicit None checks: a truthiness test would also reject 0.0.
        if high is not None and low is not None:
            mfe, mae = compute_mfe_mae(signal, entry_price, high, low, row["atr_value"])

    return (
        row["id"],
        labels.get("15m"), labels.get("30m"), labels.get("60m"),
        labels.get("ret_15m"), labels.get("ret_30m"), labels.get("ret_60m"),
        mfe, mae,
        int(time.time() * 1000),
    )


def backfill_batch(conn, rows: list, dry_run: bool) -> int:
    """Compute labels for a batch of signals and insert them.

    Args:
        conn: open database connection.
        rows: mapping-like rows with keys id, ts, symbol, signal, price and
            atr_value (as returned by fetch_unlabeled_signals).
        dry_run: when True, log a preview of up to five records and write
            nothing.

    Returns:
        The number of label records built for this batch. Rows skipped by
        ON CONFLICT DO NOTHING are still counted.
    """
    if not rows:
        return 0

    records = [_build_label_record(conn, row) for row in rows]

    if dry_run:
        logger.info(f"[dry-run] 准备写入 {len(records)} 条标签")
        for r in records[:5]:
            logger.info(f" event_id={r[0]} y15m={r[1]} y30m={r[2]} y60m={r[3]} ret60m={r[6]} mfe={r[7]} mae={r[8]}")
        return len(records)

    with conn.cursor() as cur:
        psycopg2.extras.execute_values(
            cur,
            f"""
            INSERT INTO {LABEL_TABLE}
            (event_id, y_binary_15m, y_binary_30m, y_binary_60m,
            y_return_15m, y_return_30m, y_return_60m, mfe_r_60m, mae_r_60m, label_ts)
            VALUES %s
            ON CONFLICT (event_id) DO NOTHING
            """,
            records,
        )
    conn.commit()
    return len(records)
|
||||
|
||||
|
||||
def main():
    """Command-line entry point for the label backfill.

    Flags:
        --symbol   only backfill the given symbol.
        --since    only backfill signals with ts >= this ms timestamp.
        --dry-run  compute and log labels without writing them.
        --loop     keep running, starting a new pass every 5 minutes.

    Each pass drains every pending batch rather than stopping after the
    first one, so a single run covers the whole backlog. A dry run processes
    only one batch per pass: nothing is written, so the same rows would be
    fetched again forever.
    """
    parser = argparse.ArgumentParser(description="V5.3 信号标签回填")
    parser.add_argument("--symbol", help="只回填指定品种(BTCUSDT/ETHUSDT/...)")
    parser.add_argument("--since", type=int, help="起始时间戳(ms)")
    parser.add_argument("--dry-run", action="store_true", help="只打印,不写入")
    parser.add_argument("--loop", action="store_true", help="持续运行,每5分钟跑一次")
    args = parser.parse_args()

    init_schema()

    total = 0
    while True:
        with get_sync_conn() as conn:
            ensure_label_table(conn)
            while True:
                rows = fetch_unlabeled_signals(conn, symbol=args.symbol, since_ms=args.since)
                if not rows:
                    if args.loop:
                        logger.info("没有待回填的信号,等待下一轮。")
                    else:
                        logger.info("没有待回填的信号,结束。")
                    break
                n = backfill_batch(conn, rows, dry_run=args.dry_run)
                total += n
                logger.info(f"本批回填 {n} 条,累计 {total} 条")
                if args.dry_run:
                    # Nothing was persisted, so the next fetch would return
                    # the very same rows.
                    break

        if not args.loop:
            break
        # In loop mode an empty pass is not a reason to exit: signals that
        # mature later must still be labelled.
        time.sleep(300)  # 5 minutes

    logger.info(f"回填完成,总计 {total} 条。")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Reference in New Issue
Block a user