arbitrage-engine/backend/signal_state.py

299 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

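"""
signal_state.py — CVD/ATR rolling windows and the SymbolState abstraction.

Pure data structures and calculation logic split out of the original
signal_engine.py:
- TradeWindow: rolling window over individual trades; computes CVD/VWAP.
- ATRCalculator: aggregates candles over a fixed period and computes ATR and
  its percentile.
- get_max_fr: cache of each symbol's historical maximum |funding rate|.
- SymbolState: per-symbol in-memory state (windows, indicator snapshots,
  whale data, etc.).
At the time of the split these classes were kept identical to the original
implementation; they were only moved into a separate module to make them
easier to maintain and test.
"""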
import logging
import time
from collections import Counter, deque
from typing import Any, Optional

from db import get_sync_conn
def to_float(value: Any) -> Optional[float]:
    """Convert *value* to ``float``, or return ``None``.

    ``None`` passes through as ``None``. Any value that ``float()`` rejects
    with ``TypeError`` or ``ValueError`` also yields ``None`` instead of raising.
    """
    if value is None:
        return None
    try:
        return float(value)
    except (TypeError, ValueError):
        return None
class TradeWindow:
    """Time-based rolling window over individual trades.

    Running sums are maintained so CVD (buy volume minus sell volume) and
    VWAP can be read in O(1). Trades are stored as
    ``(time_ms, qty, price, is_buyer_maker)``. A trade with
    ``is_buyer_maker == 0`` counts as buy volume; anything else counts as
    sell volume.
    """

    def __init__(self, window_ms: int):
        """Create an empty window that keeps trades from the last ``window_ms``."""
        self.window_ms = window_ms
        self.trades: deque = deque()
        self.buy_vol = 0.0
        self.sell_vol = 0.0
        self.pq_sum = 0.0  # sum of price * qty (VWAP numerator)
        self.q_sum = 0.0  # sum of qty (VWAP denominator)

    def add(self, time_ms: int, qty: float, price: float, is_buyer_maker: int):
        """Append a trade and fold it into the running sums (does not trim)."""
        self.trades.append((time_ms, qty, price, is_buyer_maker))
        self.pq_sum += price * qty
        self.q_sum += qty
        if is_buyer_maker == 0:
            self.buy_vol += qty
        else:
            self.sell_vol += qty

    def trim(self, now_ms: int):
        """Evict trades older than ``now_ms - window_ms`` and update the sums.

        When the window becomes empty, the running sums are reset to exactly
        zero. Repeated float add/subtract leaves tiny non-zero residues. Left
        in place, these would make ``vwap`` divide a residue by a near-zero
        ``q_sum`` and report a meaningless price.
        """
        cutoff = now_ms - self.window_ms
        trades = self.trades
        while trades and trades[0][0] < cutoff:
            _, qty, price, ibm = trades.popleft()
            self.pq_sum -= price * qty
            self.q_sum -= qty
            if ibm == 0:
                self.buy_vol -= qty
            else:
                self.sell_vol -= qty
        if not trades:
            self.buy_vol = self.sell_vol = 0.0
            self.pq_sum = self.q_sum = 0.0

    @property
    def cvd(self) -> float:
        """Cumulative volume delta over the window: buy volume minus sell volume."""
        return self.buy_vol - self.sell_vol

    @property
    def vwap(self) -> float:
        """Volume-weighted average price over the window; 0.0 if ``q_sum <= 0``."""
        return self.pq_sum / self.q_sum if self.q_sum > 0 else 0.0
class ATRCalculator:
    """Aggregate trade prices into fixed-period candles and compute ATR.

    Only closed candles (those already moved into ``candles``) contribute to
    the ATR. The in-progress ``current_candle`` is excluded.
    """

    def __init__(
        self,
        period_ms: int,
        length: int,
        history_len: int = 288,
        min_history: int = 10,
    ):
        """
        Args:
            period_ms: candle period in milliseconds.
            length: ATR smoothing length. At most ``length + 1`` closed
                candles are retained, giving up to ``length`` true ranges.
            history_len: number of past ATR readings that
                ``atr_percentile`` ranks against.
            min_history: readings required before ``atr_percentile`` stops
                returning the neutral 50.0.
        """
        self.period_ms = period_ms
        self.length = length
        self.min_history = min_history
        self.candles: deque = deque(maxlen=length + 1)
        self.current_candle: Optional[dict] = None
        self.atr_history: deque = deque(maxlen=history_len)

    def update(self, time_ms: int, price: float):
        """Fold a trade price into the candle for its ``period_ms`` bucket.

        When the bucket changes, the previous candle is closed (pushed onto
        ``candles``) and a new one is opened at ``price``.
        """
        bar_ms = (time_ms // self.period_ms) * self.period_ms
        if self.current_candle is None or self.current_candle["bar"] != bar_ms:
            if self.current_candle is not None:
                self.candles.append(self.current_candle)
            self.current_candle = {
                "bar": bar_ms,
                "open": price,
                "high": price,
                "low": price,
                "close": price,
            }
        else:
            c = self.current_candle
            c["high"] = max(c["high"], price)
            c["low"] = min(c["low"], price)
            c["close"] = price

    @property
    def atr(self) -> float:
        """Smoothed true range over the closed candles.

        The first true range is used as the seed (not an SMA). Each later
        one is folded in as ``(prev * (length - 1) + tr) / length``.
        Returns 0.0 until at least two closed candles exist.
        """
        if len(self.candles) < 2:
            return 0.0
        candles = list(self.candles)
        atr_val = None
        for prev, cur in zip(candles, candles[1:]):
            prev_close = prev["close"]
            tr = max(
                cur["high"] - cur["low"],
                abs(cur["high"] - prev_close),
                abs(cur["low"] - prev_close),
            )
            if atr_val is None:
                atr_val = tr
            else:
                atr_val = (atr_val * (self.length - 1) + tr) / self.length
        return atr_val

    @property
    def atr_percentile(self) -> float:
        """Percentile (0-100) of the current ATR within the recent ATR history.

        NOTE: despite being a property, this MUTATES state. Every access with
        a non-zero ATR appends that ATR to ``atr_history``, so read it once
        per evaluation. Returns 50.0 when the ATR is 0 or fewer than
        ``min_history`` readings exist.
        """
        current = self.atr
        if current == 0:
            return 50.0
        self.atr_history.append(current)
        if len(self.atr_history) < self.min_history:
            return 50.0
        # Counting readings <= current needs no sorting.
        rank = sum(1 for x in self.atr_history if x <= current)
        return (rank / len(self.atr_history)) * 100
# ─── Historical max |FR| cache (refreshed hourly) ───────────────────────
_MAX_FR_TTL_S = 3600  # seconds between scheduled full refreshes
_MAX_FR_RETRY_S = 60  # seconds to wait before retrying after a failed refresh
_MAX_FR_DEFAULT = 0.0001  # 0.01%: fallback value, keeps callers from dividing by zero
_max_fr_cache: dict = {}  # {symbol: max_abs_fr}
_max_fr_updated: float = 0
_max_fr_failed_at: float = 0
_logger = logging.getLogger(__name__)


def get_max_fr(symbol: str) -> float:
    """Return the historical maximum |funding rate| for *symbol*.

    One query loads the value for every symbol into a module-level cache.
    The cache is refreshed when it is older than ``_MAX_FR_TTL_S`` or when
    *symbol* has no entry yet.

    A symbol with no funding-rate rows is cached with the default after a
    successful refresh. This stops it from re-running the full GROUP BY
    query on every call until the next scheduled refresh.

    If the query fails, cached values are kept, the failure is logged, and
    refreshes are paused for ``_MAX_FR_RETRY_S``. Returns
    ``_MAX_FR_DEFAULT`` (0.0001) when no value is known; a NULL/zero max is
    also replaced by that default.
    """
    global _max_fr_updated, _max_fr_failed_at
    now = time.time()
    needs_refresh = now - _max_fr_updated > _MAX_FR_TTL_S or symbol not in _max_fr_cache
    if needs_refresh and now - _max_fr_failed_at >= _MAX_FR_RETRY_S:
        try:
            with get_sync_conn() as conn:
                with conn.cursor() as cur:
                    cur.execute(
                        "SELECT symbol, MAX(ABS((value->>'fundingRate')::float)) as max_fr "
                        "FROM market_indicators WHERE indicator_type='funding_rate' "
                        "GROUP BY symbol"
                    )
                    rows = cur.fetchall()
        except Exception:
            # Best-effort: keep the old cache but don't hammer a failing DB.
            _max_fr_failed_at = now
            _logger.warning(
                "get_max_fr: refresh failed for %s; using cached/default values",
                symbol,
                exc_info=True,
            )
        else:
            for row in rows:
                _max_fr_cache[row[0]] = row[1] if row[1] else _MAX_FR_DEFAULT
            _max_fr_cache.setdefault(symbol, _MAX_FR_DEFAULT)
            _max_fr_updated = now
    return _max_fr_cache.get(symbol, _MAX_FR_DEFAULT)
class SymbolState:
    """In-memory state for one trading symbol.

    Holds:
    - the rolling trade windows (fast / mid / day / VWAP);
    - the ATR calculator;
    - the "previous" values used to derive slopes and OI change;
    - the latest ``market_indicators`` payload;
    - the large-trade and whale-trade windows.
    """

    def __init__(
        self,
        symbol: str,
        window_fast_ms: int,
        window_mid_ms: int,
        window_day_ms: int,
        window_vwap_ms: int,
        atr_period_ms: int,
        atr_length: int,
        fetch_market_indicators_fn,
    ):
        """
        Args:
            symbol: trading symbol; also used to choose quantity floors
                (symbols containing "BTC" get lower ones).
            window_fast_ms / window_mid_ms / window_day_ms / window_vwap_ms:
                lengths of the corresponding ``TradeWindow``s.
            atr_period_ms / atr_length: passed to ``ATRCalculator``.
            fetch_market_indicators_fn: called once with ``symbol``. The
                result is stored as ``market_indicators`` and must support
                ``.get``.
        """
        self.symbol = symbol
        self.win_fast = TradeWindow(window_fast_ms)
        self.win_mid = TradeWindow(window_mid_ms)
        self.win_day = TradeWindow(window_day_ms)
        self.win_vwap = TradeWindow(window_vwap_ms)
        self.atr_calc = ATRCalculator(atr_period_ms, atr_length)
        self.last_processed_id = 0
        self.last_trade_price = 0.0
        self.warmup = True
        self.prev_cvd_fast = 0.0
        self.prev_cvd_fast_slope = 0.0
        self.prev_oi_value = 0.0
        # Latest market_indicators, fetched via the injected callable.
        self.market_indicators = fetch_market_indicators_fn(symbol)
        self.last_signal_ts: dict[str, int] = {}
        self.last_signal_dir: dict[str, str] = {}
        # (time_ms, qty, is_buyer_maker) of trades with qty >= p99, last 15 minutes.
        self.recent_large_trades: deque = deque()
        # ── Phase 2 real-time fields (updated by a background WebSocket coroutine) ──
        self.rt_obi: float = 0.0
        self.rt_spot_perp_div: float = 0.0
        # tiered_cvd_whale: whale trades (by notional) over a rolling 15-minute window.
        self._whale_trades: deque = deque()  # (time_ms, usd_val, is_sell)
        self.WHALE_WINDOW_MS: int = 15 * 60 * 1000  # 15 minutes

    def process_trade(self, agg_id: int, time_ms: int, price: float, qty: float, is_buyer_maker: int):
        """Feed one aggregated trade into all windows, the ATR and the whale tracker.

        The trade's own timestamp is used as "now" when trimming. The windows
        therefore advance with trade time, not wall-clock time.
        """
        now_ms = time_ms
        self.win_fast.add(time_ms, qty, price, is_buyer_maker)
        self.win_mid.add(time_ms, qty, price, is_buyer_maker)
        self.win_day.add(time_ms, qty, price, is_buyer_maker)
        self.win_vwap.add(time_ms, qty, price, is_buyer_maker)
        self.atr_calc.update(time_ms, price)
        self.win_fast.trim(now_ms)
        self.win_mid.trim(now_ms)
        self.win_day.trim(now_ms)
        self.win_vwap.trim(now_ms)
        self.last_processed_id = agg_id
        self.last_trade_price = price  # latest trade price, used as entry_price
        # tiered_cvd_whale: trades with notional price * qty >= 100,000 count as whale trades.
        usd_val = price * qty
        if usd_val >= 100_000:
            self._whale_trades.append((time_ms, usd_val, bool(is_buyer_maker)))
        # Trim the whale window to the last WHALE_WINDOW_MS.
        cutoff = now_ms - self.WHALE_WINDOW_MS
        while self._whale_trades and self._whale_trades[0][0] < cutoff:
            self._whale_trades.popleft()

    @property
    def whale_cvd_ratio(self) -> float:
        """Net whale CVD ratio in [-1, 1] over the whale window.

        Computed as ``(buy_usd - sell_usd) / (buy_usd + sell_usd)`` over
        ``_whale_trades``; 0.0 when the window is empty. The window is only
        trimmed in ``process_trade``, so it is relative to the last
        processed trade.
        """
        buy_usd = sum(t[1] for t in self._whale_trades if not t[2])
        sell_usd = sum(t[1] for t in self._whale_trades if t[2])
        total = buy_usd + sell_usd
        return (buy_usd - sell_usd) / total if total > 0 else 0.0

    def compute_p95_p99(self) -> tuple:
        """Return ``(p95, p99)`` of trade quantities in the day window.

        Each value is floored per symbol: (5.0, 10.0) for symbols containing
        "BTC", (50.0, 100.0) otherwise. With fewer than 100 trades the floors
        themselves are returned. Previously the BTC floors were returned for
        every symbol during warm-up, which put non-BTC thresholds below their
        own floors.
        """
        if "BTC" in self.symbol:
            floor95, floor99 = 5.0, 10.0
        else:
            floor95, floor99 = 50.0, 100.0
        trades = self.win_day.trades
        if len(trades) < 100:
            return floor95, floor99
        qtys = sorted(t[1] for t in trades)
        n = len(qtys)
        p95 = max(qtys[int(n * 0.95)], floor95)
        p99 = max(qtys[int(n * 0.99)], floor99)
        return p95, p99

    def update_large_trades(self, now_ms: int, p99: float):
        """Record fast-window trades with ``qty >= p99`` in ``recent_large_trades``.

        Entries are ``(time_ms, qty, is_buyer_maker)`` and expire after 15
        minutes. Trades recorded on an earlier call are not added again.
        Deduplication counts identical entry tuples (a multiset) rather than
        keying on the timestamp alone. Several large trades that share one
        millisecond are therefore all kept, not just the first.
        """
        cutoff = now_ms - 15 * 60 * 1000
        recent = self.recent_large_trades
        while recent and recent[0][0] < cutoff:
            recent.popleft()
        already_recorded = Counter(recent)
        for t_ms, qty, _price, ibm in self.win_fast.trades:
            if qty >= p99 and t_ms > cutoff:
                key = (t_ms, qty, ibm)
                if already_recorded[key] > 0:
                    already_recorded[key] -= 1
                else:
                    recent.append(key)

    def build_evaluation_snapshot(self, now_ms: int) -> dict:
        """Compute the indicator snapshot used for signal evaluation.

        Side effects (the next call depends on them):
        - reads ``atr_calc.atr_percentile``, which may append to the ATR history;
        - refreshes ``recent_large_trades``;
        - stores CVD-fast, its slope and the OI value as the "previous" values
          used for slope, acceleration and OI change next time.

        ``environment_score``:
        - 15 if OI rose by >= 3%;
        - 10 if OI rose by less than that;
        - 5 if OI is flat or falling;
        - 10 when there is no usable current OI or no previous OI.
        """
        cvd_fast = self.win_fast.cvd
        cvd_mid = self.win_mid.cvd
        cvd_day = self.win_day.cvd
        vwap = self.win_vwap.vwap
        atr = self.atr_calc.atr
        atr_pct = self.atr_calc.atr_percentile
        p95, p99 = self.compute_p95_p99()
        self.update_large_trades(now_ms, p99)
        # Prefer the latest trade price; fall back to VWAP before any trade.
        price = self.last_trade_price if self.last_trade_price > 0 else vwap
        cvd_fast_slope = cvd_fast - self.prev_cvd_fast
        cvd_fast_accel = cvd_fast_slope - self.prev_cvd_fast_slope
        self.prev_cvd_fast = cvd_fast
        self.prev_cvd_fast_slope = cvd_fast_slope
        oi_value = to_float(self.market_indicators.get("open_interest_hist"))
        if oi_value is None or self.prev_oi_value == 0:
            oi_change = 0.0
            environment_score = 10
        else:
            oi_change = (
                (oi_value - self.prev_oi_value) / self.prev_oi_value
                if self.prev_oi_value > 0
                else 0.0
            )
            if oi_change >= 0.03:
                environment_score = 15
            elif oi_change > 0:
                environment_score = 10
            else:
                environment_score = 5
        if oi_value is not None and oi_value > 0:
            self.prev_oi_value = oi_value
        return {
            "cvd_fast": cvd_fast,
            "cvd_mid": cvd_mid,
            "cvd_day": cvd_day,
            "vwap": vwap,
            "atr": atr,
            "atr_value": atr,
            "atr_pct": atr_pct,
            "p95": p95,
            "p99": p99,
            "price": price,
            "cvd_fast_slope": cvd_fast_slope,
            "cvd_fast_accel": cvd_fast_accel,
            "oi_change": oi_change,
            "environment_score": environment_score,
            "oi_value": oi_value,
        }