perf: backtest optimization - 15s eval interval + 50k batch + OHLC TP/SL check

This commit is contained in:
root 2026-02-28 07:34:48 +00:00
parent ec6a8fc64d
commit 3155e8848b

View File

@ -137,36 +137,88 @@ class Position:
# ─── 回测引擎 ────────────────────────────────────────────────────
class BacktestEngine:
def __init__(self, symbol: str):
def __init__(self, symbol: str, eval_interval_ms: int = 15000):
self.symbol = symbol
self.state = SymbolState(symbol)
self.state.warmup = False # 回测不需要warmup标记
self.state.warmup = False
self.positions: list[Position] = []
self.closed_positions: list[Position] = []
self.cooldown_until: int = 0
self.equity_curve: list[tuple[int, float]] = [] # (ts, cumulative_pnl_r)
self.equity_curve: list[tuple[int, float]] = []
self.cumulative_pnl = 0.0
self.trade_count = 0
self.eval_interval_ms = eval_interval_ms
self.last_eval_ts: int = 0
# 每个eval周期内追踪高低价用于止盈止损检查
self.period_high: float = 0.0
self.period_low: float = float('inf')
self.period_last_price: float = 0.0
self.period_last_ts: int = 0
def process_tick(self, agg_id: int, time_ms: int, price: float, qty: float, is_buyer_maker: int):
"""处理单个tick"""
# 1. 更新状态
"""处理单个tick — 轻量更新状态 + 追踪高低价"""
# 1. 更新核心状态CVD/ATR/VWAP等
self.state.process_trade(agg_id, time_ms, price, qty, is_buyer_maker)
# 2. 检查持仓止盈止损
atr = self.state.atr_calc.atr
for pos in self.positions[:]:
if pos.check_exit(time_ms, price, atr):
self.positions.remove(pos)
self.closed_positions.append(pos)
self.cumulative_pnl += pos.pnl_r
self.equity_curve.append((time_ms, self.cumulative_pnl))
# 2. 追踪区间高低价
if price > self.period_high:
self.period_high = price
if price < self.period_low:
self.period_low = price
self.period_last_price = price
self.period_last_ts = time_ms
# 3. 评估新信号
# 3. 每eval_interval检查一次止盈止损 + 评估信号
if time_ms - self.last_eval_ts >= self.eval_interval_ms:
self._evaluate(time_ms, price)
self.last_eval_ts = time_ms
self.period_high = price
self.period_low = price
def _evaluate(self, time_ms: int, price: float):
"""每15秒执行一次检查止盈止损 + 评估新信号"""
atr = self.state.atr_calc.atr
# 检查持仓止盈止损用区间高低价模拟逐tick精度
for pos in self.positions[:]:
# 用high/low检查是否触发TP/SL
exited = False
if pos.direction == "LONG":
# 先检查止损用low
if self.period_low <= pos.sl_price:
exited = pos.check_exit(time_ms, pos.sl_price, atr)
# 再检查止盈用high
elif self.period_high >= pos.tp1_price or (pos.tp1_hit and self.period_high >= pos.tp2_price):
check_price = pos.tp2_price if pos.tp1_hit and self.period_high >= pos.tp2_price else pos.tp1_price
exited = pos.check_exit(time_ms, check_price, atr)
else:
# 用TP1检查可能只触发TP1不退出
pos.check_exit(time_ms, self.period_high, atr)
else: # SHORT
if self.period_high >= pos.sl_price:
exited = pos.check_exit(time_ms, pos.sl_price, atr)
elif self.period_low <= pos.tp1_price or (pos.tp1_hit and self.period_low <= pos.tp2_price):
check_price = pos.tp2_price if pos.tp1_hit and self.period_low <= pos.tp2_price else pos.tp1_price
exited = pos.check_exit(time_ms, check_price, atr)
else:
pos.check_exit(time_ms, self.period_low, atr)
# 时间止损检查
if not exited and pos.status in ("active", "tp1_hit"):
exited = pos.check_exit(time_ms, price, atr)
if exited or pos.status not in ("active", "tp1_hit"):
if pos in self.positions:
self.positions.remove(pos)
self.closed_positions.append(pos)
self.cumulative_pnl += pos.pnl_r
self.equity_curve.append((time_ms, self.cumulative_pnl))
# 评估新信号
if time_ms < self.cooldown_until:
return
if len(self.positions) > 0:
return # 不同时持多个仓位
return
result = self.state.evaluate_signal(time_ms)
signal = result.get("signal")
@ -356,7 +408,7 @@ def run_backtest(symbol: str, start_ms: int, end_ms: int, warmup_ms: int = WINDO
warmup_count = 0
with conn.cursor("warmup_cursor") as cur:
cur.itersize = 10000
cur.itersize = 50000
cur.execute(
"SELECT agg_id, time_ms, price, qty, is_buyer_maker FROM agg_trades "
"WHERE symbol = %s AND time_ms >= %s AND time_ms < %s ORDER BY agg_id ASC",
@ -370,13 +422,13 @@ def run_backtest(symbol: str, start_ms: int, end_ms: int, warmup_ms: int = WINDO
# 回测主循环
total_count = load_trades(symbol, start_ms, end_ms)
logger.info(f"开始回测: {symbol}, {total_count:,} 条tick")
logger.info(f"开始回测: {symbol}, {total_count:,} 条tick, 评估间隔: {engine.eval_interval_ms}ms")
processed = 0
last_log = time.time()
with conn.cursor("backtest_cursor") as cur:
cur.itersize = 10000
cur.itersize = 50000
cur.execute(
"SELECT agg_id, time_ms, price, qty, is_buyer_maker FROM agg_trades "
"WHERE symbol = %s AND time_ms >= %s AND time_ms <= %s ORDER BY agg_id ASC",