From 3155e8848b9024f01d79a4fb9649aac95ee17f35 Mon Sep 17 00:00:00 2001 From: root Date: Sat, 28 Feb 2026 07:34:48 +0000 Subject: [PATCH] perf: backtest optimization - 15s eval interval + 50k batch + OHLC TP/SL check --- backend/backtest.py | 88 +++++++++++++++++++++++++++++++++++---------- 1 file changed, 70 insertions(+), 18 deletions(-) diff --git a/backend/backtest.py b/backend/backtest.py index d007745..7968509 100644 --- a/backend/backtest.py +++ b/backend/backtest.py @@ -137,36 +137,88 @@ class Position: # ─── 回测引擎 ──────────────────────────────────────────────────── class BacktestEngine: - def __init__(self, symbol: str): + def __init__(self, symbol: str, eval_interval_ms: int = 15000): self.symbol = symbol self.state = SymbolState(symbol) - self.state.warmup = False # 回测不需要warmup标记 + self.state.warmup = False self.positions: list[Position] = [] self.closed_positions: list[Position] = [] self.cooldown_until: int = 0 - self.equity_curve: list[tuple[int, float]] = [] # (ts, cumulative_pnl_r) + self.equity_curve: list[tuple[int, float]] = [] self.cumulative_pnl = 0.0 self.trade_count = 0 + self.eval_interval_ms = eval_interval_ms + self.last_eval_ts: int = 0 + # 每个eval周期内追踪高低价(用于止盈止损检查) + self.period_high: float = 0.0 + self.period_low: float = float('inf') + self.period_last_price: float = 0.0 + self.period_last_ts: int = 0 def process_tick(self, agg_id: int, time_ms: int, price: float, qty: float, is_buyer_maker: int): - """处理单个tick""" - # 1. 更新状态 + """处理单个tick — 轻量更新状态 + 追踪高低价""" + # 1. 更新核心状态(CVD/ATR/VWAP等) self.state.process_trade(agg_id, time_ms, price, qty, is_buyer_maker) - # 2. 检查持仓止盈止损 - atr = self.state.atr_calc.atr - for pos in self.positions[:]: - if pos.check_exit(time_ms, price, atr): - self.positions.remove(pos) - self.closed_positions.append(pos) - self.cumulative_pnl += pos.pnl_r - self.equity_curve.append((time_ms, self.cumulative_pnl)) + # 2. 追踪区间高低价 + if price > self.period_high: + self.period_high = price + if price < self.period_low: + self.period_low = price + self.period_last_price = price + self.period_last_ts = time_ms - # 3. 评估新信号 + # 3. 每eval_interval检查一次止盈止损 + 评估信号 + if time_ms - self.last_eval_ts >= self.eval_interval_ms: + self._evaluate(time_ms, price) + self.last_eval_ts = time_ms + self.period_high = price + self.period_low = price + + def _evaluate(self, time_ms: int, price: float): + """每15秒执行一次:检查止盈止损 + 评估新信号""" + atr = self.state.atr_calc.atr + + # 检查持仓止盈止损(用区间高低价模拟逐tick精度) + for pos in self.positions[:]: + # 用high/low检查是否触发TP/SL + exited = False + if pos.direction == "LONG": + # 先检查止损(用low) + if self.period_low <= pos.sl_price: + exited = pos.check_exit(time_ms, pos.sl_price, atr) + # 再检查止盈(用high) + elif self.period_high >= pos.tp1_price or (pos.tp1_hit and self.period_high >= pos.tp2_price): + check_price = pos.tp2_price if pos.tp1_hit and self.period_high >= pos.tp2_price else pos.tp1_price + exited = pos.check_exit(time_ms, check_price, atr) + else: + # 用TP1检查(可能只触发TP1不退出) + pos.check_exit(time_ms, self.period_high, atr) + else: # SHORT + if self.period_high >= pos.sl_price: + exited = pos.check_exit(time_ms, pos.sl_price, atr) + elif self.period_low <= pos.tp1_price or (pos.tp1_hit and self.period_low <= pos.tp2_price): + check_price = pos.tp2_price if pos.tp1_hit and self.period_low <= pos.tp2_price else pos.tp1_price + exited = pos.check_exit(time_ms, check_price, atr) + else: + pos.check_exit(time_ms, self.period_low, atr) + + # 时间止损检查 + if not exited and pos.status in ("active", "tp1_hit"): + exited = pos.check_exit(time_ms, price, atr) + + if exited or pos.status not in ("active", "tp1_hit"): + if pos in self.positions: + self.positions.remove(pos) + self.closed_positions.append(pos) + self.cumulative_pnl += pos.pnl_r + self.equity_curve.append((time_ms, self.cumulative_pnl)) + + # 评估新信号 if time_ms < self.cooldown_until: return if len(self.positions) > 0: - return # 不同时持多个仓位 + return result = self.state.evaluate_signal(time_ms) signal = result.get("signal") @@ -356,7 +408,7 @@ def run_backtest(symbol: str, start_ms: int, end_ms: int, warmup_ms: int = WINDO warmup_count = 0 with conn.cursor("warmup_cursor") as cur: - cur.itersize = 10000 + cur.itersize = 50000 cur.execute( "SELECT agg_id, time_ms, price, qty, is_buyer_maker FROM agg_trades " "WHERE symbol = %s AND time_ms >= %s AND time_ms < %s ORDER BY agg_id ASC", @@ -370,13 +422,13 @@ def run_backtest(symbol: str, start_ms: int, end_ms: int, warmup_ms: int = WINDO # 回测主循环 total_count = load_trades(symbol, start_ms, end_ms) - logger.info(f"开始回测: {symbol}, {total_count:,} 条tick") + logger.info(f"开始回测: {symbol}, {total_count:,} 条tick, 评估间隔: {engine.eval_interval_ms}ms") processed = 0 last_log = time.time() with conn.cursor("backtest_cursor") as cur: - cur.itersize = 10000 + cur.itersize = 50000 cur.execute( "SELECT agg_id, time_ms, price, qty, is_buyer_maker FROM agg_trades " "WHERE symbol = %s AND time_ms >= %s AND time_ms <= %s ORDER BY agg_id ASC",