perf: backtest optimization - 15s eval interval + 50k batch + OHLC TP/SL check
This commit is contained in:
parent
ec6a8fc64d
commit
3155e8848b
@ -137,36 +137,88 @@ class Position:
|
||||
# ─── 回测引擎 ────────────────────────────────────────────────────
|
||||
|
||||
class BacktestEngine:
|
||||
def __init__(self, symbol: str):
|
||||
def __init__(self, symbol: str, eval_interval_ms: int = 15000):
|
||||
self.symbol = symbol
|
||||
self.state = SymbolState(symbol)
|
||||
self.state.warmup = False # 回测不需要warmup标记
|
||||
self.state.warmup = False
|
||||
self.positions: list[Position] = []
|
||||
self.closed_positions: list[Position] = []
|
||||
self.cooldown_until: int = 0
|
||||
self.equity_curve: list[tuple[int, float]] = [] # (ts, cumulative_pnl_r)
|
||||
self.equity_curve: list[tuple[int, float]] = []
|
||||
self.cumulative_pnl = 0.0
|
||||
self.trade_count = 0
|
||||
self.eval_interval_ms = eval_interval_ms
|
||||
self.last_eval_ts: int = 0
|
||||
# 每个eval周期内追踪高低价(用于止盈止损检查)
|
||||
self.period_high: float = 0.0
|
||||
self.period_low: float = float('inf')
|
||||
self.period_last_price: float = 0.0
|
||||
self.period_last_ts: int = 0
|
||||
|
||||
def process_tick(self, agg_id: int, time_ms: int, price: float, qty: float, is_buyer_maker: int):
|
||||
"""处理单个tick"""
|
||||
# 1. 更新状态
|
||||
"""处理单个tick — 轻量更新状态 + 追踪高低价"""
|
||||
# 1. 更新核心状态(CVD/ATR/VWAP等)
|
||||
self.state.process_trade(agg_id, time_ms, price, qty, is_buyer_maker)
|
||||
|
||||
# 2. 检查持仓止盈止损
|
||||
# 2. 追踪区间高低价
|
||||
if price > self.period_high:
|
||||
self.period_high = price
|
||||
if price < self.period_low:
|
||||
self.period_low = price
|
||||
self.period_last_price = price
|
||||
self.period_last_ts = time_ms
|
||||
|
||||
# 3. 每eval_interval检查一次止盈止损 + 评估信号
|
||||
if time_ms - self.last_eval_ts >= self.eval_interval_ms:
|
||||
self._evaluate(time_ms, price)
|
||||
self.last_eval_ts = time_ms
|
||||
self.period_high = price
|
||||
self.period_low = price
|
||||
|
||||
def _evaluate(self, time_ms: int, price: float):
|
||||
"""每15秒执行一次:检查止盈止损 + 评估新信号"""
|
||||
atr = self.state.atr_calc.atr
|
||||
|
||||
# 检查持仓止盈止损(用区间高低价模拟逐tick精度)
|
||||
for pos in self.positions[:]:
|
||||
if pos.check_exit(time_ms, price, atr):
|
||||
# 用high/low检查是否触发TP/SL
|
||||
exited = False
|
||||
if pos.direction == "LONG":
|
||||
# 先检查止损(用low)
|
||||
if self.period_low <= pos.sl_price:
|
||||
exited = pos.check_exit(time_ms, pos.sl_price, atr)
|
||||
# 再检查止盈(用high)
|
||||
elif self.period_high >= pos.tp1_price or (pos.tp1_hit and self.period_high >= pos.tp2_price):
|
||||
check_price = pos.tp2_price if pos.tp1_hit and self.period_high >= pos.tp2_price else pos.tp1_price
|
||||
exited = pos.check_exit(time_ms, check_price, atr)
|
||||
else:
|
||||
# 用TP1检查(可能只触发TP1不退出)
|
||||
pos.check_exit(time_ms, self.period_high, atr)
|
||||
else: # SHORT
|
||||
if self.period_high >= pos.sl_price:
|
||||
exited = pos.check_exit(time_ms, pos.sl_price, atr)
|
||||
elif self.period_low <= pos.tp1_price or (pos.tp1_hit and self.period_low <= pos.tp2_price):
|
||||
check_price = pos.tp2_price if pos.tp1_hit and self.period_low <= pos.tp2_price else pos.tp1_price
|
||||
exited = pos.check_exit(time_ms, check_price, atr)
|
||||
else:
|
||||
pos.check_exit(time_ms, self.period_low, atr)
|
||||
|
||||
# 时间止损检查
|
||||
if not exited and pos.status in ("active", "tp1_hit"):
|
||||
exited = pos.check_exit(time_ms, price, atr)
|
||||
|
||||
if exited or pos.status not in ("active", "tp1_hit"):
|
||||
if pos in self.positions:
|
||||
self.positions.remove(pos)
|
||||
self.closed_positions.append(pos)
|
||||
self.cumulative_pnl += pos.pnl_r
|
||||
self.equity_curve.append((time_ms, self.cumulative_pnl))
|
||||
|
||||
# 3. 评估新信号
|
||||
# 评估新信号
|
||||
if time_ms < self.cooldown_until:
|
||||
return
|
||||
if len(self.positions) > 0:
|
||||
return # 不同时持多个仓位
|
||||
return
|
||||
|
||||
result = self.state.evaluate_signal(time_ms)
|
||||
signal = result.get("signal")
|
||||
@ -356,7 +408,7 @@ def run_backtest(symbol: str, start_ms: int, end_ms: int, warmup_ms: int = WINDO
|
||||
warmup_count = 0
|
||||
|
||||
with conn.cursor("warmup_cursor") as cur:
|
||||
cur.itersize = 10000
|
||||
cur.itersize = 50000
|
||||
cur.execute(
|
||||
"SELECT agg_id, time_ms, price, qty, is_buyer_maker FROM agg_trades "
|
||||
"WHERE symbol = %s AND time_ms >= %s AND time_ms < %s ORDER BY agg_id ASC",
|
||||
@ -370,13 +422,13 @@ def run_backtest(symbol: str, start_ms: int, end_ms: int, warmup_ms: int = WINDO
|
||||
|
||||
# 回测主循环
|
||||
total_count = load_trades(symbol, start_ms, end_ms)
|
||||
logger.info(f"开始回测: {symbol}, {total_count:,} 条tick")
|
||||
logger.info(f"开始回测: {symbol}, {total_count:,} 条tick, 评估间隔: {engine.eval_interval_ms}ms")
|
||||
|
||||
processed = 0
|
||||
last_log = time.time()
|
||||
|
||||
with conn.cursor("backtest_cursor") as cur:
|
||||
cur.itersize = 10000
|
||||
cur.itersize = 50000
|
||||
cur.execute(
|
||||
"SELECT agg_id, time_ms, price, qty, is_buyer_maker FROM agg_trades "
|
||||
"WHERE symbol = %s AND time_ms >= %s AND time_ms <= %s ORDER BY agg_id ASC",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user