perf: backtest optimization - 15s eval interval + 50k batch + OHLC TP/SL check
This commit is contained in:
parent
ec6a8fc64d
commit
3155e8848b
@ -137,36 +137,88 @@ class Position:
|
|||||||
# ─── 回测引擎 ────────────────────────────────────────────────────
|
# ─── 回测引擎 ────────────────────────────────────────────────────
|
||||||
|
|
||||||
class BacktestEngine:
|
class BacktestEngine:
|
||||||
def __init__(self, symbol: str):
|
def __init__(self, symbol: str, eval_interval_ms: int = 15000):
|
||||||
self.symbol = symbol
|
self.symbol = symbol
|
||||||
self.state = SymbolState(symbol)
|
self.state = SymbolState(symbol)
|
||||||
self.state.warmup = False # 回测不需要warmup标记
|
self.state.warmup = False
|
||||||
self.positions: list[Position] = []
|
self.positions: list[Position] = []
|
||||||
self.closed_positions: list[Position] = []
|
self.closed_positions: list[Position] = []
|
||||||
self.cooldown_until: int = 0
|
self.cooldown_until: int = 0
|
||||||
self.equity_curve: list[tuple[int, float]] = [] # (ts, cumulative_pnl_r)
|
self.equity_curve: list[tuple[int, float]] = []
|
||||||
self.cumulative_pnl = 0.0
|
self.cumulative_pnl = 0.0
|
||||||
self.trade_count = 0
|
self.trade_count = 0
|
||||||
|
self.eval_interval_ms = eval_interval_ms
|
||||||
|
self.last_eval_ts: int = 0
|
||||||
|
# 每个eval周期内追踪高低价(用于止盈止损检查)
|
||||||
|
self.period_high: float = 0.0
|
||||||
|
self.period_low: float = float('inf')
|
||||||
|
self.period_last_price: float = 0.0
|
||||||
|
self.period_last_ts: int = 0
|
||||||
|
|
||||||
def process_tick(self, agg_id: int, time_ms: int, price: float, qty: float, is_buyer_maker: int):
|
def process_tick(self, agg_id: int, time_ms: int, price: float, qty: float, is_buyer_maker: int):
|
||||||
"""处理单个tick"""
|
"""处理单个tick — 轻量更新状态 + 追踪高低价"""
|
||||||
# 1. 更新状态
|
# 1. 更新核心状态(CVD/ATR/VWAP等)
|
||||||
self.state.process_trade(agg_id, time_ms, price, qty, is_buyer_maker)
|
self.state.process_trade(agg_id, time_ms, price, qty, is_buyer_maker)
|
||||||
|
|
||||||
# 2. 检查持仓止盈止损
|
# 2. 追踪区间高低价
|
||||||
|
if price > self.period_high:
|
||||||
|
self.period_high = price
|
||||||
|
if price < self.period_low:
|
||||||
|
self.period_low = price
|
||||||
|
self.period_last_price = price
|
||||||
|
self.period_last_ts = time_ms
|
||||||
|
|
||||||
|
# 3. 每eval_interval检查一次止盈止损 + 评估信号
|
||||||
|
if time_ms - self.last_eval_ts >= self.eval_interval_ms:
|
||||||
|
self._evaluate(time_ms, price)
|
||||||
|
self.last_eval_ts = time_ms
|
||||||
|
self.period_high = price
|
||||||
|
self.period_low = price
|
||||||
|
|
||||||
|
def _evaluate(self, time_ms: int, price: float):
|
||||||
|
"""每15秒执行一次:检查止盈止损 + 评估新信号"""
|
||||||
atr = self.state.atr_calc.atr
|
atr = self.state.atr_calc.atr
|
||||||
|
|
||||||
|
# 检查持仓止盈止损(用区间高低价模拟逐tick精度)
|
||||||
for pos in self.positions[:]:
|
for pos in self.positions[:]:
|
||||||
if pos.check_exit(time_ms, price, atr):
|
# 用high/low检查是否触发TP/SL
|
||||||
|
exited = False
|
||||||
|
if pos.direction == "LONG":
|
||||||
|
# 先检查止损(用low)
|
||||||
|
if self.period_low <= pos.sl_price:
|
||||||
|
exited = pos.check_exit(time_ms, pos.sl_price, atr)
|
||||||
|
# 再检查止盈(用high)
|
||||||
|
elif self.period_high >= pos.tp1_price or (pos.tp1_hit and self.period_high >= pos.tp2_price):
|
||||||
|
check_price = pos.tp2_price if pos.tp1_hit and self.period_high >= pos.tp2_price else pos.tp1_price
|
||||||
|
exited = pos.check_exit(time_ms, check_price, atr)
|
||||||
|
else:
|
||||||
|
# 用TP1检查(可能只触发TP1不退出)
|
||||||
|
pos.check_exit(time_ms, self.period_high, atr)
|
||||||
|
else: # SHORT
|
||||||
|
if self.period_high >= pos.sl_price:
|
||||||
|
exited = pos.check_exit(time_ms, pos.sl_price, atr)
|
||||||
|
elif self.period_low <= pos.tp1_price or (pos.tp1_hit and self.period_low <= pos.tp2_price):
|
||||||
|
check_price = pos.tp2_price if pos.tp1_hit and self.period_low <= pos.tp2_price else pos.tp1_price
|
||||||
|
exited = pos.check_exit(time_ms, check_price, atr)
|
||||||
|
else:
|
||||||
|
pos.check_exit(time_ms, self.period_low, atr)
|
||||||
|
|
||||||
|
# 时间止损检查
|
||||||
|
if not exited and pos.status in ("active", "tp1_hit"):
|
||||||
|
exited = pos.check_exit(time_ms, price, atr)
|
||||||
|
|
||||||
|
if exited or pos.status not in ("active", "tp1_hit"):
|
||||||
|
if pos in self.positions:
|
||||||
self.positions.remove(pos)
|
self.positions.remove(pos)
|
||||||
self.closed_positions.append(pos)
|
self.closed_positions.append(pos)
|
||||||
self.cumulative_pnl += pos.pnl_r
|
self.cumulative_pnl += pos.pnl_r
|
||||||
self.equity_curve.append((time_ms, self.cumulative_pnl))
|
self.equity_curve.append((time_ms, self.cumulative_pnl))
|
||||||
|
|
||||||
# 3. 评估新信号
|
# 评估新信号
|
||||||
if time_ms < self.cooldown_until:
|
if time_ms < self.cooldown_until:
|
||||||
return
|
return
|
||||||
if len(self.positions) > 0:
|
if len(self.positions) > 0:
|
||||||
return # 不同时持多个仓位
|
return
|
||||||
|
|
||||||
result = self.state.evaluate_signal(time_ms)
|
result = self.state.evaluate_signal(time_ms)
|
||||||
signal = result.get("signal")
|
signal = result.get("signal")
|
||||||
@ -356,7 +408,7 @@ def run_backtest(symbol: str, start_ms: int, end_ms: int, warmup_ms: int = WINDO
|
|||||||
warmup_count = 0
|
warmup_count = 0
|
||||||
|
|
||||||
with conn.cursor("warmup_cursor") as cur:
|
with conn.cursor("warmup_cursor") as cur:
|
||||||
cur.itersize = 10000
|
cur.itersize = 50000
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"SELECT agg_id, time_ms, price, qty, is_buyer_maker FROM agg_trades "
|
"SELECT agg_id, time_ms, price, qty, is_buyer_maker FROM agg_trades "
|
||||||
"WHERE symbol = %s AND time_ms >= %s AND time_ms < %s ORDER BY agg_id ASC",
|
"WHERE symbol = %s AND time_ms >= %s AND time_ms < %s ORDER BY agg_id ASC",
|
||||||
@ -370,13 +422,13 @@ def run_backtest(symbol: str, start_ms: int, end_ms: int, warmup_ms: int = WINDO
|
|||||||
|
|
||||||
# 回测主循环
|
# 回测主循环
|
||||||
total_count = load_trades(symbol, start_ms, end_ms)
|
total_count = load_trades(symbol, start_ms, end_ms)
|
||||||
logger.info(f"开始回测: {symbol}, {total_count:,} 条tick")
|
logger.info(f"开始回测: {symbol}, {total_count:,} 条tick, 评估间隔: {engine.eval_interval_ms}ms")
|
||||||
|
|
||||||
processed = 0
|
processed = 0
|
||||||
last_log = time.time()
|
last_log = time.time()
|
||||||
|
|
||||||
with conn.cursor("backtest_cursor") as cur:
|
with conn.cursor("backtest_cursor") as cur:
|
||||||
cur.itersize = 10000
|
cur.itersize = 50000
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"SELECT agg_id, time_ms, price, qty, is_buyer_maker FROM agg_trades "
|
"SELECT agg_id, time_ms, price, qty, is_buyer_maker FROM agg_trades "
|
||||||
"WHERE symbol = %s AND time_ms >= %s AND time_ms <= %s ORDER BY agg_id ASC",
|
"WHERE symbol = %s AND time_ms >= %s AND time_ms <= %s ORDER BY agg_id ASC",
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user