# arbitrage-engine/backend/signal_engine.py
#
# NOTE: this copy was captured from a web repository viewer (555 lines, 18 KiB,
# Python). The viewer warned that the file contains ambiguous Unicode characters
# (characters that can be confused with others); double-check string literals
# and comments before relying on their exact text.
"""
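signal_engine.py — V5 短线交易信号引擎 (V5 short-term trading signal engine)

架构 / Architecture:
- 独立PM2进程, 每5秒循环 — runs as a standalone PM2 process, one cycle every 5 seconds
- 内存滚动窗口计算指标 (CVD/ATR/VWAP/大单阈值) — indicators (CVD, ATR, VWAP,
  large-trade thresholds) are computed from in-memory rolling windows
- 启动时回灌历史数据 (冷启动warmup) — on startup, historical trades are replayed
  into the windows (cold-start warmup)
- 信号评估 (核心3条件 + 加分3条件) — signal evaluation: 3 core conditions plus
  3 bonus-score conditions (the funding-rate bonus is still a TODO)
- 输出: signal_indicators表 + signal_trades表 + Discord推送 — outputs: the
  signal_indicators table, the signal_trades table and Discord notifications
  (signal_trades writes and Discord pushes are still TODOs in main())

指标 / Indicators:
- CVD_fast (30m滚动) / CVD_mid (4h滚动) / CVD_day (UTC日内) — cumulative volume
  delta over a rolling 30m / rolling 4h / day window (note: the day window is
  currently a rolling 24h window, not reset at the UTC day boundary)
- ATR (5m, 14周期) — ATR on 5-minute candles, 14 periods
- VWAP_30m — volume-weighted average price over the last 30 minutes
- 大单阈值 P95/P99 (24h滚动) — large-trade size thresholds: P95/P99 of trade
  quantity over a rolling 24h window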
"""
import logging
import os
import sqlite3
import time
import math
import statistics
from collections import deque
from datetime import datetime, timezone
from typing import Optional
# Log to stderr and to signal-engine.log one directory above this file.
_LOG_FILE = os.path.join(os.path.dirname(__file__), "..", "signal-engine.log")
_LOG_HANDLERS = [logging.StreamHandler(), logging.FileHandler(_LOG_FILE)]
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    handlers=_LOG_HANDLERS,
)
logger = logging.getLogger("signal-engine")
DB_PATH = os.path.join(os.path.dirname(__file__), "..", "arb.db")
SYMBOLS = ["BTCUSDT", "ETHUSDT"]
LOOP_INTERVAL = 5  # seconds between evaluation cycles

_MINUTE_MS = 60_000
_HOUR_MS = 60 * _MINUTE_MS

# Rolling-window lengths, in milliseconds.
WINDOW_FAST = 30 * _MINUTE_MS  # 30 min: CVD_fast
WINDOW_MID = 4 * _HOUR_MS  # 4 h: CVD_mid (also the cold-start backfill span)
WINDOW_DAY = 24 * _HOUR_MS  # 24 h: CVD_day and the P95/P99 trade-size thresholds
WINDOW_VWAP = 30 * _MINUTE_MS  # 30 min: VWAP

# ATR parameters.
ATR_PERIOD_MS = 5 * _MINUTE_MS  # candle length: 5 minutes
ATR_LENGTH = 14  # number of candles in the ATR

# Minimum time between two signals for the same symbol.
COOLDOWN_MS = 10 * _MINUTE_MS
# ─── DB helpers ──────────────────────────────────────────────────
def get_conn() -> sqlite3.Connection:
    """Open a connection to the engine's SQLite database at DB_PATH.

    The connection waits up to 30 s on a locked database, returns rows as
    ``sqlite3.Row`` (addressable by column name), and has WAL journaling and
    ``synchronous=NORMAL`` enabled.
    """
    connection = sqlite3.connect(DB_PATH, timeout=30)
    connection.row_factory = sqlite3.Row
    for pragma in ("PRAGMA journal_mode=WAL", "PRAGMA synchronous=NORMAL"):
        connection.execute(pragma)
    return connection
def init_tables(conn: sqlite3.Connection):
    """Create the engine's output tables and their indexes if missing.

    Tables: signal_indicators (per-evaluation snapshots), signal_indicators_1m
    (per-minute rows) and signal_trades. Every statement uses IF NOT EXISTS,
    so calling this repeatedly is harmless. Commits once at the end.
    """
    # Executed in order; the SQL text is kept exactly as originally written.
    ddl_statements = (
        """
CREATE TABLE IF NOT EXISTS signal_indicators (
id INTEGER PRIMARY KEY AUTOINCREMENT,
ts INTEGER NOT NULL,
symbol TEXT NOT NULL,
cvd_fast REAL,
cvd_mid REAL,
cvd_day REAL,
cvd_fast_slope REAL,
atr_5m REAL,
atr_percentile REAL,
vwap_30m REAL,
price REAL,
p95_qty REAL,
p99_qty REAL,
buy_vol_1m REAL,
sell_vol_1m REAL,
score INTEGER,
signal TEXT
)
""",
        "CREATE INDEX IF NOT EXISTS idx_si_ts ON signal_indicators(ts)",
        "CREATE INDEX IF NOT EXISTS idx_si_sym_ts ON signal_indicators(symbol, ts)",
        """
CREATE TABLE IF NOT EXISTS signal_indicators_1m (
id INTEGER PRIMARY KEY AUTOINCREMENT,
ts INTEGER NOT NULL,
symbol TEXT NOT NULL,
cvd_fast REAL,
cvd_mid REAL,
cvd_day REAL,
atr_5m REAL,
vwap_30m REAL,
price REAL,
score INTEGER,
signal TEXT
)
""",
        "CREATE INDEX IF NOT EXISTS idx_si1m_sym_ts ON signal_indicators_1m(symbol, ts)",
        """
CREATE TABLE IF NOT EXISTS signal_trades (
id INTEGER PRIMARY KEY AUTOINCREMENT,
ts_open INTEGER NOT NULL,
ts_close INTEGER,
symbol TEXT NOT NULL,
direction TEXT NOT NULL,
entry_price REAL,
exit_price REAL,
qty REAL,
score INTEGER,
pnl REAL,
sl_price REAL,
tp1_price REAL,
tp2_price REAL,
status TEXT DEFAULT 'open'
)
""",
    )
    for ddl in ddl_statements:
        conn.execute(ddl)
    conn.commit()
# ─── Rolling windows (滚动窗口) ───────────────────────────────────
class TradeWindow:
    """Time-based rolling window over aggregated trades.

    Keeps running buy/sell volume and price*qty / qty sums so that CVD and
    VWAP can be read in O(1). ``trim`` evicts trades older than ``window_ms``
    relative to the timestamp it is given.
    """

    def __init__(self, window_ms: int):
        self.window_ms = window_ms
        self.trades: deque = deque()  # (time_ms, qty, price, is_buyer_maker)
        self.buy_vol = 0.0
        self.sell_vol = 0.0
        self.pq_sum = 0.0  # running sum of price * qty (VWAP numerator)
        self.q_sum = 0.0  # running sum of qty (VWAP denominator)
        # Intended for P95/P99 on the 24h window; not read anywhere in this file.
        self.quantities: list = []

    def add(self, time_ms: int, qty: float, price: float, is_buyer_maker: int):
        """Append one trade; is_buyer_maker == 0 counts as buy volume, else sell."""
        self.trades.append((time_ms, qty, price, is_buyer_maker))
        self.pq_sum += price * qty
        self.q_sum += qty
        if is_buyer_maker == 0:
            self.buy_vol += qty
        else:
            self.sell_vol += qty

    def trim(self, now_ms: int):
        """Drop trades with time_ms < now_ms - window_ms and update the sums."""
        cutoff = now_ms - self.window_ms
        trades = self.trades
        while trades and trades[0][0] < cutoff:
            _t_ms, qty, price, ibm = trades.popleft()
            self.pq_sum -= price * qty
            self.q_sum -= qty
            if ibm == 0:
                self.buy_vol -= qty
            else:
                self.sell_vol -= qty
        if not trades:
            # The sums are maintained by repeated float add/subtract, so a
            # drained window can be left with tiny residues (e.g. q_sum ~ 1e-17).
            # Those would make `vwap` return pq_sum / q_sum noise and `cvd`
            # non-zero. An empty window has exactly zero of everything.
            self.buy_vol = self.sell_vol = 0.0
            self.pq_sum = self.q_sum = 0.0

    @property
    def cvd(self) -> float:
        """Cumulative volume delta: buy volume minus sell volume in the window."""
        return self.buy_vol - self.sell_vol

    @property
    def vwap(self) -> float:
        """Volume-weighted average price of the window; 0.0 when it holds no volume."""
        return self.pq_sum / self.q_sum if self.q_sum > 0 else 0.0

    @property
    def total_vol(self) -> float:
        """Total traded quantity in the window."""
        return self.buy_vol + self.sell_vol
class ATRCalculator:
    """ATR over fixed-length candles (5 minutes by default) built from trade prices.

    ``update`` buckets each trade price into a candle of ``period_ms``. A
    candle is closed (moved into ``candles``) when the first trade of a
    different bucket arrives. ``atr`` uses closed candles only, so it changes
    only when a candle closes. ``atr_history`` holds one ATR sample per closed
    candle (up to 288, i.e. 24h at the 5-minute default) for
    ``atr_percentile``.
    """

    def __init__(self, period_ms: int = ATR_PERIOD_MS, length: int = ATR_LENGTH):
        self.period_ms = period_ms
        self.length = length
        # length + 1 closed candles yield `length` true ranges.
        self.candles: deque = deque(maxlen=length + 1)
        self.current_candle: Optional[dict] = None  # candle still being built
        # One ATR sample per closed candle; 288 x 5m = 24h at the default period.
        self.atr_history: deque = deque(maxlen=288)

    def update(self, time_ms: int, price: float):
        """Add one trade price, closing the current candle when the bucket changes."""
        bar_ms = (time_ms // self.period_ms) * self.period_ms
        candle = self.current_candle
        if candle is not None and candle["bar"] == bar_ms:
            candle["high"] = max(candle["high"], price)
            candle["low"] = min(candle["low"], price)
            candle["close"] = price
            return
        if candle is not None:
            self.candles.append(candle)
            # Sample the ATR here, exactly once per closed candle. Sampling on
            # every `atr_percentile` read (which happens every loop cycle) would
            # fill the 288-entry history with repeated values within minutes
            # instead of spanning 24h of candles.
            closed_atr = self.atr
            if closed_atr > 0:
                self.atr_history.append(closed_atr)
        self.current_candle = {
            "bar": bar_ms, "open": price, "high": price, "low": price, "close": price
        }

    @property
    def atr(self) -> float:
        """ATR of the closed candles; 0.0 until at least two candles have closed.

        True ranges are smoothed with Wilder-style averaging
        ``atr = (atr * (length - 1) + tr) / length``, seeded with the first TR.
        """
        if len(self.candles) < 2:
            return 0.0
        closed = list(self.candles)
        true_ranges = [
            max(cur["high"] - cur["low"],
                abs(cur["high"] - prev["close"]),
                abs(cur["low"] - prev["close"]))
            for prev, cur in zip(closed, closed[1:])
        ]
        atr_val = true_ranges[0]
        for tr in true_ranges[1:]:
            atr_val = (atr_val * (self.length - 1) + tr) / self.length
        return atr_val

    @property
    def atr_percentile(self) -> float:
        """Percentile rank (0-100) of the current ATR within ``atr_history``.

        Returns 50.0 while the ATR is 0 or fewer than 10 samples exist. Reading
        this property has no side effects.
        """
        current = self.atr
        if current == 0:
            return 50.0
        history = self.atr_history
        if len(history) < 10:
            return 50.0
        rank = sum(1 for x in history if x <= current)
        return (rank / len(history)) * 100
class SymbolState:
    """All rolling state for one trading symbol.

    Holds four trade windows (WINDOW_FAST for CVD_fast, WINDOW_MID for CVD_mid,
    WINDOW_DAY for CVD_day and the P95/P99 thresholds, WINDOW_VWAP for VWAP),
    a candle-based ATR, the last processed agg_id for incremental reads, and
    the signal cooldown bookkeeping.
    """

    def __init__(self, symbol: str):
        self.symbol = symbol
        self.win_fast = TradeWindow(WINDOW_FAST)
        self.win_mid = TradeWindow(WINDOW_MID)
        self.win_day = TradeWindow(WINDOW_DAY)
        self.win_vwap = TradeWindow(WINDOW_VWAP)
        self.atr_calc = ATRCalculator()
        self.last_processed_id = 0
        self.last_price = 0.0  # price of the most recently processed trade
        self.warmup = True
        self.warmup_until = 0  # not read anywhere in this file
        self.prev_cvd_fast = 0.0  # CVD_fast at the previous evaluate_signal call
        self.last_signal_ts = 0
        self.last_signal_dir = ""
        # P99-sized trades from the last 15 minutes: (time_ms, qty, is_buyer_maker)
        self.recent_large_trades: deque = deque()

    def process_trade(self, agg_id: int, time_ms: int, price: float, qty: float, is_buyer_maker: int):
        """Feed one trade into every window and the ATR.

        Windows are trimmed relative to the trade's own timestamp, not the wall clock.
        """
        now_ms = time_ms
        windows = (self.win_fast, self.win_mid, self.win_day, self.win_vwap)
        for win in windows:
            win.add(time_ms, qty, price, is_buyer_maker)
        self.atr_calc.update(time_ms, price)
        for win in windows:
            win.trim(now_ms)
        self.last_price = price
        self.last_processed_id = agg_id

    def compute_p95_p99(self) -> tuple[float, float]:
        """Return the large-trade thresholds (P95, P99 of trade qty) from the 24h window.

        Results are floored per symbol: 5/10 for BTC symbols, 50/100 otherwise.
        While fewer than 100 trades are in the window, the floors themselves
        are returned. Previously 5/10 was returned for every symbol here,
        which undercut the non-BTC floor.
        """
        if "BTC" in self.symbol:
            floor95, floor99 = 5.0, 10.0
        else:
            floor95, floor99 = 50.0, 100.0
        if len(self.win_day.trades) < 100:
            return floor95, floor99
        # NOTE(review): sorts the whole 24h window on every call (every loop
        # cycle); may be costly for very active symbols.
        qtys = sorted(t[1] for t in self.win_day.trades)
        n = len(qtys)
        p95 = qtys[int(n * 0.95)]
        p99 = qtys[int(n * 0.99)]
        return max(p95, floor95), max(p99, floor99)

    def update_large_trades(self, now_ms: int, p99: float):
        """Rebuild ``recent_large_trades``: trades >= p99 from the last 15 minutes.

        The list is rebuilt from the 30-minute window on each call instead of
        appended to. Appending re-added every large trade on every call and
        piled up duplicates.
        """
        cutoff = now_ms - 15 * 60 * 1000
        self.recent_large_trades = deque(
            (t_ms, qty, ibm)
            for t_ms, qty, _price, ibm in self.win_fast.trades
            if qty >= p99 and t_ms > cutoff
        )

    def evaluate_signal(self, now_ms: int) -> dict:
        """Compute the indicators and evaluate the entry signal.

        Core conditions, all required (SHORT mirrors LONG):
          * CVD_fast > 0 and rising since the previous call (SHORT: < 0, falling)
          * CVD_mid > 0 (SHORT: < 0)
          * last trade price above the 30m VWAP (SHORT: below)
        Bonus score: +25 if the ATR percentile is above 60; +20 if no adverse
        P99 trade occurred in the last 15 minutes. The funding-rate bonus is
        not implemented yet.

        Returns the indicator dict in every case. "signal"/"direction" are set
        only when a signal fires; firing also starts the COOLDOWN_MS cooldown.
        Side effects: updates ``prev_cvd_fast``, so the slope is measured
        between consecutive calls, and rebuilds ``recent_large_trades``.
        """
        cvd_fast = self.win_fast.cvd
        cvd_mid = self.win_mid.cvd
        vwap = self.win_vwap.vwap
        atr = self.atr_calc.atr
        atr_pct = self.atr_calc.atr_percentile
        p95, p99 = self.compute_p95_p99()
        self.update_large_trades(now_ms, p99)

        # Current price = last traded price. Using the VWAP itself as the price
        # made "price > vwap" / "price < vwap" always false, so no signal could
        # ever fire. Price stays 0 when there is no VWAP, so the no-data guard
        # below still applies.
        price = self.last_price if vwap > 0 else 0

        cvd_fast_slope = cvd_fast - self.prev_cvd_fast
        self.prev_cvd_fast = cvd_fast

        result = {
            "cvd_fast": cvd_fast,
            "cvd_mid": cvd_mid,
            "cvd_day": self.win_day.cvd,
            "cvd_fast_slope": cvd_fast_slope,
            "atr": atr,
            "atr_pct": atr_pct,
            "vwap": vwap,
            "price": price,
            "p95": p95,
            "p99": p99,
            "signal": None,
            "direction": None,
            "score": 0,
        }
        if self.warmup or price == 0 or atr == 0:
            return result
        # Cooldown since the last fired signal.
        if now_ms - self.last_signal_ts < COOLDOWN_MS:
            return result

        # === Core conditions ===
        long_core = (
            cvd_fast > 0 and cvd_fast_slope > 0   # CVD_fast positive and rising
            and cvd_mid > 0                       # CVD_mid positive
            and price > vwap                      # price above VWAP
        )
        short_core = (
            cvd_fast < 0 and cvd_fast_slope < 0
            and cvd_mid < 0
            and price < vwap
        )
        if not (long_core or short_core):
            return result
        direction = "LONG" if long_core else "SHORT"

        # === Bonus conditions ===
        score = 0
        # 1. ATR compression -> expansion, approximated as percentile > 60 (+25).
        if atr_pct > 60:
            score += 25
        # 2. No adverse P99 trade (+20). is_buyer_maker == 1 is a sell-side
        #    trade (adverse for LONG); 0 is a buy-side trade (adverse for SHORT).
        adverse_flag = 1 if direction == "LONG" else 0
        if not any(lt[2] == adverse_flag for lt in self.recent_large_trades):
            score += 20
        # 3. Funding-rate confirmation from rate_snapshots (+15).
        # TODO: wire in the funding-rate condition.

        result["signal"] = direction
        result["direction"] = direction
        result["score"] = score
        self.last_signal_ts = now_ms
        self.last_signal_dir = direction
        return result
# ─── Main loop (主循环) ──────────────────────────────────────────────
def get_month_tables(conn: sqlite3.Connection) -> list[str]:
    """List the monthly aggTrades tables (names starting with agg_trades_2), sorted by name.

    Rows are read by column name, so ``conn`` must use a name-indexable row
    factory such as ``sqlite3.Row``.
    """
    cursor = conn.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'agg_trades_2%' ORDER BY name"
    )
    return [row["name"] for row in cursor]
def load_historical(conn: sqlite3.Connection, state: SymbolState, window_ms: int):
    """Cold start: replay the last ``window_ms`` of trades into ``state``.

    Every monthly aggTrades table is queried for ``state.symbol`` rows with
    time_ms >= now - window_ms, in agg_id order, and each row is fed to
    ``state.process_trade``. A table whose query or replay fails is logged
    and skipped. Warmup is switched off afterwards regardless of the outcome.
    """
    since_ms = int(time.time() * 1000) - window_ms
    loaded = 0
    for table in get_month_tables(conn):
        try:
            rows = conn.execute(
                f"SELECT agg_id, price, qty, time_ms, is_buyer_maker FROM {table} "
                f"WHERE symbol = ? AND time_ms >= ? ORDER BY agg_id ASC",
                (state.symbol, since_ms)
            ).fetchall()
            for row in rows:
                state.process_trade(
                    row["agg_id"], row["time_ms"], row["price"], row["qty"], row["is_buyer_maker"]
                )
                loaded += 1
        except Exception as e:
            logger.warning(f"Error loading {table}: {e}")
    logger.info(f"[{state.symbol}] 冷启动完成: 加载{loaded:,}条历史数据 (窗口={window_ms//3600000}h)")
    state.warmup = False
def fetch_new_trades(conn: sqlite3.Connection, symbol: str, last_id: int) -> list:
    """Incrementally read aggTrades for ``symbol`` with agg_id > ``last_id``.

    Only the two newest monthly tables are scanned, which is enough to
    straddle a month rollover. At most 10,000 rows are read per table per
    call. If a table returns a full page, newer tables are skipped for this
    call. Otherwise rows from the newer table would follow the page, and the
    caller's ``last_processed_id`` would jump past this table's unread rows.
    Query errors are logged (no longer silently ignored) and that table is
    skipped.

    Returns a list of rows in table-name order, then agg_id order.
    """
    page_limit = 10000
    results = []
    for tname in get_month_tables(conn)[-2:]:  # only the two most recent month tables
        try:
            rows = conn.execute(
                f"SELECT agg_id, price, qty, time_ms, is_buyer_maker FROM {tname} "
                f"WHERE symbol = ? AND agg_id > ? ORDER BY agg_id ASC LIMIT ?",
                (symbol, last_id, page_limit)
            ).fetchall()
        except sqlite3.Error as e:
            logger.warning(f"[{symbol}] Error reading {tname}: {e}")
            continue
        results.extend(rows)
        if len(rows) >= page_limit:
            # Assumes newer month tables hold higher agg_ids (tables are read
            # in name order). Finish this table on the next call first.
            break
    return results
def save_indicator(conn: sqlite3.Connection, ts: int, symbol: str, result: dict):
    """Insert one indicator snapshot into signal_indicators and commit.

    ``result`` is an evaluate_signal()-style dict. Keys read: cvd_fast,
    cvd_mid, cvd_day, cvd_fast_slope, atr, atr_pct, vwap, price, p95, p99,
    score, and optionally signal (stored as NULL when absent). The
    buy_vol_1m / sell_vol_1m columns are left unset.
    """
    metric_keys = (
        "cvd_fast", "cvd_mid", "cvd_day", "cvd_fast_slope",
        "atr", "atr_pct",
        "vwap", "price",
        "p95", "p99",
        "score",
    )
    metrics = tuple(result[key] for key in metric_keys)
    insert_sql = """
INSERT INTO signal_indicators
(ts, symbol, cvd_fast, cvd_mid, cvd_day, cvd_fast_slope, atr_5m, atr_percentile,
vwap_30m, price, p95_qty, p99_qty, score, signal)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"""
    conn.execute(insert_sql, (ts, symbol, *metrics, result.get("signal")))
    conn.commit()
def save_indicator_1m(conn: sqlite3.Connection, ts: int, symbol: str, result: dict):
    """Write the per-minute indicator row for ``symbol`` and commit.

    ``ts`` (ms) is floored to the start of its minute. If a row for that
    (minute, symbol) already exists it is overwritten with ``result``;
    otherwise a new row is inserted. This stores the given snapshot as-is;
    no aggregation across the minute happens here.
    """
    minute_ts = (ts // 60000) * 60000
    payload = (
        result["cvd_fast"], result["cvd_mid"], result["cvd_day"],
        result["atr"], result["vwap"], result["price"],
        result["score"], result.get("signal"),
    )
    found = conn.execute(
        "SELECT id FROM signal_indicators_1m WHERE ts = ? AND symbol = ?", (minute_ts, symbol)
    ).fetchone()
    if not found:
        conn.execute("""
INSERT INTO signal_indicators_1m
(ts, symbol, cvd_fast, cvd_mid, cvd_day, atr_5m, vwap_30m, price, score, signal)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", (minute_ts, symbol) + payload)
    else:
        conn.execute("""
UPDATE signal_indicators_1m SET
cvd_fast=?, cvd_mid=?, cvd_day=?, atr_5m=?, vwap_30m=?, price=?, score=?, signal=?
WHERE ts=? AND symbol=?
""", payload + (minute_ts, symbol))
    conn.commit()
def main():
    """Run the signal engine: cold-start backfill, then loop every LOOP_INTERVAL seconds.

    For each symbol, every cycle:
      * read new aggTrades;
      * evaluate the signal;
      * write a row to signal_indicators;
      * on the first cycle of each minute, write a signal_indicators_1m row;
      * log any signal.
    Every 60 cycles the latest indicators are logged. A failing cycle is
    logged and the loop continues. The DB connection is closed when the
    loop exits (e.g. on KeyboardInterrupt).
    """
    conn = get_conn()
    try:
        init_tables(conn)
        states = {sym: SymbolState(sym) for sym in SYMBOLS}
        # Cold start: replay WINDOW_MID (4h) of history into each symbol.
        for state in states.values():
            load_historical(conn, state, WINDOW_MID)
        logger.info("=== Signal Engine 启动完成 ===")

        last_1m_save: dict = {}
        # Latest evaluate_signal() result per symbol, reused by the status log.
        # Calling evaluate_signal() again just to log would re-run its side
        # effects (CVD slope baseline, large-trade list, ATR bookkeeping).
        last_results: dict = {}
        cycle = 0
        while True:
            cycle_started = time.monotonic()
            try:
                now_ms = int(time.time() * 1000)
                for sym, state in states.items():
                    # Incrementally read new trades.
                    new_trades = fetch_new_trades(conn, sym, state.last_processed_id)
                    for t in new_trades:
                        state.process_trade(t["agg_id"], t["time_ms"], t["price"], t["qty"], t["is_buyer_maker"])
                    # Evaluate the signal.
                    result = state.evaluate_signal(now_ms)
                    last_results[sym] = result
                    # Per-cycle indicator row.
                    save_indicator(conn, now_ms, sym, result)
                    # Per-minute row, written on the first cycle of each minute.
                    bar_1m = (now_ms // 60000) * 60000
                    if last_1m_save.get(sym) != bar_1m:
                        save_indicator_1m(conn, now_ms, sym, result)
                        last_1m_save[sym] = bar_1m
                    if result.get("signal"):
                        logger.info(
                            f"[{sym}] 🚨 信号: {result['signal']} "
                            f"score={result['score']} price={result['price']:.1f} "
                            f"CVD_fast={result['cvd_fast']:.1f} CVD_mid={result['cvd_mid']:.1f}"
                        )
                        # TODO: Discord push
                        # TODO: write to signal_trades
                cycle += 1
                if cycle % 60 == 0:  # 60 cycles x 5 s = status every 5 minutes
                    for sym, r in last_results.items():
                        logger.info(
                            f"[{sym}] 状态: CVD_fast={r['cvd_fast']:.1f} CVD_mid={r['cvd_mid']:.1f} "
                            f"ATR={r['atr']:.2f}({r['atr_pct']:.0f}%) VWAP={r['vwap']:.1f} "
                            f"P95={r['p95']:.4f} P99={r['p99']:.4f}"
                        )
            except Exception as e:
                logger.error(f"循环异常: {e}", exc_info=True)
            # Sleep only for the remainder of the interval, so processing time
            # does not stretch each cycle beyond LOOP_INTERVAL.
            elapsed = time.monotonic() - cycle_started
            time.sleep(max(0.0, LOOP_INTERVAL - elapsed))
    finally:
        conn.close()
# Script entry point: the module docstring describes running this file as a
# standalone PM2 process. Importing the module does not start the loop.
if __name__ == "__main__":
    main()