perf: backtest optimization - 15s eval interval + 50k batch + OHLC TP/SL check

This commit is contained in:
root 2026-02-28 07:34:48 +00:00
parent ec6a8fc64d
commit 3155e8848b

View File

@ -137,36 +137,88 @@ class Position:
# ─── 回测引擎 ──────────────────────────────────────────────────── # ─── 回测引擎 ────────────────────────────────────────────────────
class BacktestEngine: class BacktestEngine:
def __init__(self, symbol: str): def __init__(self, symbol: str, eval_interval_ms: int = 15000):
self.symbol = symbol self.symbol = symbol
self.state = SymbolState(symbol) self.state = SymbolState(symbol)
self.state.warmup = False # 回测不需要warmup标记 self.state.warmup = False
self.positions: list[Position] = [] self.positions: list[Position] = []
self.closed_positions: list[Position] = [] self.closed_positions: list[Position] = []
self.cooldown_until: int = 0 self.cooldown_until: int = 0
self.equity_curve: list[tuple[int, float]] = [] # (ts, cumulative_pnl_r) self.equity_curve: list[tuple[int, float]] = []
self.cumulative_pnl = 0.0 self.cumulative_pnl = 0.0
self.trade_count = 0 self.trade_count = 0
self.eval_interval_ms = eval_interval_ms
self.last_eval_ts: int = 0
# 每个eval周期内追踪高低价用于止盈止损检查
self.period_high: float = 0.0
self.period_low: float = float('inf')
self.period_last_price: float = 0.0
self.period_last_ts: int = 0
def process_tick(self, agg_id: int, time_ms: int, price: float, qty: float, is_buyer_maker: int): def process_tick(self, agg_id: int, time_ms: int, price: float, qty: float, is_buyer_maker: int):
"""处理单个tick""" """处理单个tick — 轻量更新状态 + 追踪高低价"""
# 1. 更新状态 # 1. 更新核心状态CVD/ATR/VWAP等
self.state.process_trade(agg_id, time_ms, price, qty, is_buyer_maker) self.state.process_trade(agg_id, time_ms, price, qty, is_buyer_maker)
# 2. 检查持仓止盈止损 # 2. 追踪区间高低价
if price > self.period_high:
self.period_high = price
if price < self.period_low:
self.period_low = price
self.period_last_price = price
self.period_last_ts = time_ms
# 3. 每eval_interval检查一次止盈止损 + 评估信号
if time_ms - self.last_eval_ts >= self.eval_interval_ms:
self._evaluate(time_ms, price)
self.last_eval_ts = time_ms
self.period_high = price
self.period_low = price
def _evaluate(self, time_ms: int, price: float):
"""每15秒执行一次检查止盈止损 + 评估新信号"""
atr = self.state.atr_calc.atr atr = self.state.atr_calc.atr
# 检查持仓止盈止损用区间高低价模拟逐tick精度
for pos in self.positions[:]: for pos in self.positions[:]:
if pos.check_exit(time_ms, price, atr): # 用high/low检查是否触发TP/SL
exited = False
if pos.direction == "LONG":
# 先检查止损用low
if self.period_low <= pos.sl_price:
exited = pos.check_exit(time_ms, pos.sl_price, atr)
# 再检查止盈用high
elif self.period_high >= pos.tp1_price or (pos.tp1_hit and self.period_high >= pos.tp2_price):
check_price = pos.tp2_price if pos.tp1_hit and self.period_high >= pos.tp2_price else pos.tp1_price
exited = pos.check_exit(time_ms, check_price, atr)
else:
# 用TP1检查可能只触发TP1不退出
pos.check_exit(time_ms, self.period_high, atr)
else: # SHORT
if self.period_high >= pos.sl_price:
exited = pos.check_exit(time_ms, pos.sl_price, atr)
elif self.period_low <= pos.tp1_price or (pos.tp1_hit and self.period_low <= pos.tp2_price):
check_price = pos.tp2_price if pos.tp1_hit and self.period_low <= pos.tp2_price else pos.tp1_price
exited = pos.check_exit(time_ms, check_price, atr)
else:
pos.check_exit(time_ms, self.period_low, atr)
# 时间止损检查
if not exited and pos.status in ("active", "tp1_hit"):
exited = pos.check_exit(time_ms, price, atr)
if exited or pos.status not in ("active", "tp1_hit"):
if pos in self.positions:
self.positions.remove(pos) self.positions.remove(pos)
self.closed_positions.append(pos) self.closed_positions.append(pos)
self.cumulative_pnl += pos.pnl_r self.cumulative_pnl += pos.pnl_r
self.equity_curve.append((time_ms, self.cumulative_pnl)) self.equity_curve.append((time_ms, self.cumulative_pnl))
# 3. 评估新信号 # 评估新信号
if time_ms < self.cooldown_until: if time_ms < self.cooldown_until:
return return
if len(self.positions) > 0: if len(self.positions) > 0:
return # 不同时持多个仓位 return
result = self.state.evaluate_signal(time_ms) result = self.state.evaluate_signal(time_ms)
signal = result.get("signal") signal = result.get("signal")
@ -356,7 +408,7 @@ def run_backtest(symbol: str, start_ms: int, end_ms: int, warmup_ms: int = WINDO
warmup_count = 0 warmup_count = 0
with conn.cursor("warmup_cursor") as cur: with conn.cursor("warmup_cursor") as cur:
cur.itersize = 10000 cur.itersize = 50000
cur.execute( cur.execute(
"SELECT agg_id, time_ms, price, qty, is_buyer_maker FROM agg_trades " "SELECT agg_id, time_ms, price, qty, is_buyer_maker FROM agg_trades "
"WHERE symbol = %s AND time_ms >= %s AND time_ms < %s ORDER BY agg_id ASC", "WHERE symbol = %s AND time_ms >= %s AND time_ms < %s ORDER BY agg_id ASC",
@ -370,13 +422,13 @@ def run_backtest(symbol: str, start_ms: int, end_ms: int, warmup_ms: int = WINDO
# 回测主循环 # 回测主循环
total_count = load_trades(symbol, start_ms, end_ms) total_count = load_trades(symbol, start_ms, end_ms)
logger.info(f"开始回测: {symbol}, {total_count:,} 条tick") logger.info(f"开始回测: {symbol}, {total_count:,} 条tick, 评估间隔: {engine.eval_interval_ms}ms")
processed = 0 processed = 0
last_log = time.time() last_log = time.time()
with conn.cursor("backtest_cursor") as cur: with conn.cursor("backtest_cursor") as cur:
cur.itersize = 10000 cur.itersize = 50000
cur.execute( cur.execute(
"SELECT agg_id, time_ms, price, qty, is_buyer_maker FROM agg_trades " "SELECT agg_id, time_ms, price, qty, is_buyer_maker FROM agg_trades "
"WHERE symbol = %s AND time_ms >= %s AND time_ms <= %s ORDER BY agg_id ASC", "WHERE symbol = %s AND time_ms >= %s AND time_ms <= %s ORDER BY agg_id ASC",