"""backtest.py — V5.1 backtesting framework (tick-by-tick event replay).

Architecture:
- Read historical agg_trades from PostgreSQL and replay them tick by tick
  in agg_id order.
- Reuse signal_engine's SymbolState and the shared strategy_scoring entry
  point for signal evaluation.
- Simulate entries and exits: TP1 (partial close), TP2, stop-loss,
  break-even stop and a time stop.
- Report statistics: win rate, profit factor, Sharpe, MDD, etc.

PnL is measured in R, where 1R is the distance from the entry price to the
*initial* stop-loss price.

Usage:
    python3 backtest.py --symbol BTCUSDT --days 20
    python3 backtest.py --symbol BTCUSDT --start 2026-02-08 --end 2026-02-28
"""

import argparse
import json
import logging
import os
import statistics
import sys
import time
from collections import Counter, deque
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Optional

import psycopg2

# Reuse the core state class and scoring logic from signal_engine / strategy_scoring.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from signal_engine import SymbolState, WINDOW_MID  # noqa: E402
from strategy_scoring import evaluate_signal as score_strategy  # noqa: E402

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)
logger = logging.getLogger("backtest")

PG_HOST = os.getenv("PG_HOST", "10.106.0.3")
PG_PORT = int(os.getenv("PG_PORT", "5432"))
PG_DB = os.getenv("PG_DB", "arb_engine")
PG_USER = os.getenv("PG_USER", "arb")
# NOTE(review): a real-looking password is hard-coded as the fallback default.
# Prefer requiring PG_PASS from the environment or a secrets store.
PG_PASS = os.getenv("PG_PASS", "arb_engine_2026")

# ─── Trade-management parameters ─────────────────────────────────
DAY_MS = 24 * 3600 * 1000
OPEN_STATUSES = ("active", "tp1_hit")
TP1_CLOSE_FRACTION = 0.5       # share of the position closed at TP1
BREAKEVEN_FEE_OFFSET = 0.0005  # BE stop is nudged past entry to cover fees
TIME_STOP_MS = 60 * 60 * 1000  # force-close positions held longer than 60 minutes
COOLDOWN_MS = 10 * 60 * 1000   # no new entry for 10 minutes after an entry
MIN_ENTRY_SCORE = 60           # V5.1 entry threshold
RISK_ATR_MULT = 0.7            # risk unit = 0.7 * ATR (simplified: 5m ATR only)
SL_RISK_MULT = 2.0             # stop-loss distance, in risk units
TP1_RISK_MULT = 1.5            # TP1 distance, in risk units
TP2_RISK_MULT = 3.0            # TP2 distance, in risk units


# ─── Position management ─────────────────────────────────────────

@dataclass
class Position:
    """A single simulated position.

    ``pnl_r`` is measured in R, where 1R is the distance between
    ``entry_price`` and the *initial* ``sl_price``. When TP1 is reached,
    ``TP1_CLOSE_FRACTION`` of the position is closed at TP1 and the stop on
    the remainder moves to break-even (plus a small fee offset). The
    remainder then exits at TP2, at the break-even stop, or via the time stop.

    ``size`` records the tier's size multiplier; it is not applied to ``pnl_r``.
    """

    entry_ts: int
    entry_price: float
    direction: str  # "LONG" or "SHORT" (anything other than "LONG" is treated as SHORT)
    size: float  # tier size multiplier
    tier: str  # "light", "standard", "heavy"
    score: int
    sl_price: float
    tp1_price: float
    tp2_price: float
    tp1_hit: bool = False
    status: str = "active"  # active, tp1_hit, tp, sl, sl_be, timeout
    exit_ts: int = 0
    exit_price: float = 0.0
    pnl_r: float = 0.0
    # Entry-to-initial-stop distance (1R). When left at 0 it is derived from
    # sl_price in __post_init__, i.e. before the stop can move to break-even.
    risk: float = 0.0

    def __post_init__(self) -> None:
        if self.risk <= 0:
            self.risk = abs(self.entry_price - self.sl_price)

    @property
    def is_open(self) -> bool:
        """True while the position (or its post-TP1 remainder) is still open."""
        return self.status in OPEN_STATUSES

    def _r_multiple(self, price: float) -> float:
        """Return the signed favourable move from entry to *price*, in units of initial risk."""
        if self.risk <= 0:
            return 0.0
        move = price - self.entry_price
        if self.direction != "LONG":
            move = -move
        return move / self.risk

    def _close(self, time_ms: int, exit_price: float, status: str, remainder_r: float) -> None:
        """Mark the position closed.

        *remainder_r* is the R-multiple of the still-open part. If TP1 was hit,
        the part already closed at TP1 is added to the total.
        """
        self.exit_ts = time_ms
        self.exit_price = exit_price
        self.status = status
        if self.tp1_hit:
            locked_r = TP1_CLOSE_FRACTION * self._r_multiple(self.tp1_price)
            self.pnl_r = locked_r + (1.0 - TP1_CLOSE_FRACTION) * remainder_r
        else:
            self.pnl_r = remainder_r

    def check_exit(self, time_ms: int, price: float, atr: float, *,
                   check_timeout: bool = True) -> bool:
        """Apply the stop / TP1 / TP2 / time-stop rules at *price*.

        Args:
            time_ms: Timestamp of the check, in epoch milliseconds.
            price: Price tested against the stop and targets.
            atr: Current ATR. Kept for interface compatibility; not used.
            check_timeout: When False, skip the time stop. Pass False when
                *price* is a window extreme rather than the actual current
                price, so a timeout is never filled at an optimistic high/low.

        Returns:
            True if the position is closed (now or earlier), otherwise False.
        """
        if not self.is_open:
            return True

        is_long = self.direction == "LONG"

        # Stop-loss: the initial stop, or the break-even stop after TP1.
        stop_hit = price <= self.sl_price if is_long else price >= self.sl_price
        if stop_hit:
            if self.tp1_hit:
                # The remainder is stopped at break-even. The fee offset is
                # meant to cover costs, so that leg is booked as 0R.
                self._close(time_ms, price, "sl_be", 0.0)
            else:
                self._close(time_ms, price, "sl", -1.0)
            return True

        # TP1: lock in part of the position and move the stop to break-even.
        tp1_reached = price >= self.tp1_price if is_long else price <= self.tp1_price
        if not self.tp1_hit and tp1_reached:
            self.tp1_hit = True
            self.status = "tp1_hit"
            be_factor = (1 + BREAKEVEN_FEE_OFFSET) if is_long else (1 - BREAKEVEN_FEE_OFFSET)
            self.sl_price = self.entry_price * be_factor

        # TP2: close the remainder at the TP2 level (checked with the same
        # price, so a move through both targets in one step books both).
        tp2_reached = price >= self.tp2_price if is_long else price <= self.tp2_price
        if self.tp1_hit and tp2_reached:
            self._close(time_ms, self.tp2_price, "tp", self._r_multiple(self.tp2_price))
            return True

        # Time stop: close at the given price after TIME_STOP_MS.
        if check_timeout and time_ms - self.entry_ts > TIME_STOP_MS:
            self._close(time_ms, price, "timeout", self._r_multiple(price))
            return True

        return False

    def force_close(self, time_ms: int, price: float) -> None:
        """Close an open position at *price* (mark-to-market), booked as "timeout"."""
        if self.is_open:
            self._close(time_ms, price, "timeout", self._r_multiple(price))


# ─── Backtest engine ─────────────────────────────────────────────

class BacktestEngine:
    """Replay ticks through SymbolState and simulate V5.1 entries and exits."""

    def __init__(self, symbol: str, eval_interval_ms: int = 15000):
        self.symbol = symbol
        self.state = SymbolState(symbol)
        self.state.warmup = False
        self.positions: list[Position] = []
        self.closed_positions: list[Position] = []
        self.cooldown_until: int = 0
        self.equity_curve: list[tuple[int, float]] = []
        self.cumulative_pnl = 0.0
        self.trade_count = 0
        self.eval_interval_ms = eval_interval_ms
        self.last_eval_ts: int = 0
        # High/low/last price of the current evaluation window. The range is
        # used to approximate tick-level TP/SL checks at evaluation time.
        self.period_high: float = 0.0
        self.period_low: float = float("inf")
        self.period_last_price: float = 0.0
        self.period_last_ts: int = 0

    def process_tick(self, agg_id: int, time_ms: int, price: float, qty: float, is_buyer_maker: int):
        """Feed one tick: update indicator state, track the window range, evaluate when due."""
        # 1. Update the core state (CVD/ATR/VWAP etc.).
        self.state.process_trade(agg_id, time_ms, price, qty, is_buyer_maker)

        # 2. Track the window high/low and the last tick.
        if price > self.period_high:
            self.period_high = price
        if price < self.period_low:
            self.period_low = price
        self.period_last_price = price
        self.period_last_ts = time_ms

        # 3. Once per eval_interval_ms: manage positions and look for a new signal.
        if time_ms - self.last_eval_ts >= self.eval_interval_ms:
            self._evaluate(time_ms, price)
            self.last_eval_ts = time_ms
            self.period_high = price
            self.period_low = price

    def _evaluate(self, time_ms: int, price: float):
        """Run one evaluation cycle: exits first, then a possible new entry."""
        atr = self.state.atr_calc.atr
        self._manage_positions(time_ms, price, atr)
        self._try_open_position(time_ms, price)

    def _manage_positions(self, time_ms: int, price: float, atr: float) -> None:
        """Check each open position against this window's range and the current price."""
        for pos in list(self.positions):
            if pos.direction == "LONG":
                stop_touched = self.period_low <= pos.sl_price
                best_price = self.period_high
            else:
                stop_touched = self.period_high >= pos.sl_price
                best_price = self.period_low
            # Tick order inside the window is unknown, so when the range
            # touches the stop we conservatively assume it was hit first.
            probe = pos.sl_price if stop_touched else best_price
            exited = pos.check_exit(time_ms, probe, atr, check_timeout=False)
            if not exited:
                # Time stop (and a final stop/target check) at the real current price.
                exited = pos.check_exit(time_ms, price, atr)
            if exited:
                self._record_close(pos)

    def _record_close(self, pos: Position) -> None:
        """Move a closed position to the history and extend the equity curve."""
        self.positions[:] = [p for p in self.positions if p is not pos]
        self.closed_positions.append(pos)
        self.cumulative_pnl += pos.pnl_r
        self.equity_curve.append((pos.exit_ts, self.cumulative_pnl))

    def _try_open_position(self, time_ms: int, price: float) -> None:
        """Open a position when the score passes, no position is open and no cooldown applies."""
        if time_ms < self.cooldown_until or self.positions:
            return

        # Unified scoring entry point (V5.1 baseline configuration).
        strategy_cfg = {
            "name": "v51_baseline",
            "threshold": 75,
            "signals": ["cvd", "p99", "accel", "ls_ratio", "oi", "coinbase_premium"],
        }
        result = score_strategy(self.state, time_ms, strategy_cfg=strategy_cfg)
        signal = result.get("signal")
        if not signal:
            return
        score = result.get("score", 0)
        if score < MIN_ENTRY_SCORE:
            return

        # Position tier by score.
        if score >= 85:
            tier, size_mult = "heavy", 1.3
        elif score >= 75:
            tier, size_mult = "standard", 1.0
        else:
            tier, size_mult = "light", 0.5

        risk_atr = RISK_ATR_MULT * self.state.atr_calc.atr
        if risk_atr <= 0:
            return

        # +1 for LONG, -1 otherwise (treated as SHORT, like Position does).
        sign = 1 if signal == "LONG" else -1
        pos = Position(
            entry_ts=time_ms,
            entry_price=price,
            direction=signal,
            size=size_mult,
            tier=tier,
            score=score,
            sl_price=price - sign * SL_RISK_MULT * risk_atr,
            tp1_price=price + sign * TP1_RISK_MULT * risk_atr,
            tp2_price=price + sign * TP2_RISK_MULT * risk_atr,
        )
        self.positions.append(pos)
        self.trade_count += 1
        self.cooldown_until = time_ms + COOLDOWN_MS

    def close_all(self, time_ms: int, price: Optional[float] = None) -> None:
        """Force-close every open position (mark-to-market) and book its PnL.

        *price* defaults to the last processed tick price. If no tick has been
        processed yet, each position's entry price is used instead.
        """
        if price is None and self.period_last_price > 0:
            price = self.period_last_price
        for pos in list(self.positions):
            pos.force_close(time_ms, price if price is not None else pos.entry_price)
            self._record_close(pos)

    def report(self) -> dict:
        """Build the summary statistics of all closed trades."""
        trades = self.closed_positions
        if not trades:
            return {"error": "没有交易记录"}

        total = len(trades)
        wins = [t for t in trades if t.pnl_r > 0]
        losses = [t for t in trades if t.pnl_r <= 0]
        win_rate = len(wins) / total * 100

        total_pnl = sum(t.pnl_r for t in trades)
        gross_profit = sum(t.pnl_r for t in wins)
        gross_loss = abs(sum(t.pnl_r for t in losses))
        profit_factor = gross_profit / gross_loss if gross_loss > 0 else float("inf")

        # Average win / loss.
        avg_win = gross_profit / len(wins) if wins else 0
        avg_loss = gross_loss / len(losses) if losses else 0
        avg_rr = avg_win / avg_loss if avg_loss > 0 else float("inf")

        # Maximum drawdown of the cumulative R curve, in trade-close order.
        peak = running = mdd = 0.0
        for t in trades:
            running += t.pnl_r
            peak = max(peak, running)
            mdd = max(mdd, peak - running)

        # Average holding time in minutes.
        hold_times = [(t.exit_ts - t.entry_ts) / 60000 for t in trades if t.exit_ts > 0]
        avg_hold_min = sum(hold_times) / len(hold_times) if hold_times else 0

        # Simplified Sharpe: per-trade R returns scaled by sqrt(252). This is
        # not a time-based annualization; treat it as a relative score only.
        returns = [t.pnl_r for t in trades]
        sharpe = 0
        if len(returns) > 1:
            std_ret = statistics.stdev(returns)
            if std_ret > 0:
                sharpe = (statistics.mean(returns) / std_ret) * (252 ** 0.5)

        status_counts = dict(Counter(t.status for t in trades))
        tier_counts = dict(Counter(t.tier for t in trades))

        longs = [t for t in trades if t.direction == "LONG"]
        shorts = [t for t in trades if t.direction == "SHORT"]
        long_wr = sum(1 for t in longs if t.pnl_r > 0) / len(longs) * 100 if longs else 0
        short_wr = sum(1 for t in shorts if t.pnl_r > 0) / len(shorts) * 100 if shorts else 0

        return {
            "总交易数": total,
            "胜率": f"{win_rate:.1f}%",
            "总盈亏(R)": f"{total_pnl:+.2f}R",
            "盈亏比(Profit Factor)": f"{profit_factor:.2f}",
            "平均盈利(R)": f"{avg_win:.2f}R",
            "平均亏损(R)": f"-{avg_loss:.2f}R",
            "盈亏比(Win/Loss)": f"{avg_rr:.2f}",
            "夏普比率": f"{sharpe:.2f}",
            "最大回撤(MDD)": f"{mdd:.2f}R",
            "平均持仓(分钟)": f"{avg_hold_min:.1f}",
            "做多胜率": f"{long_wr:.1f}% ({len(longs)}笔)",
            "做空胜率": f"{short_wr:.1f}% ({len(shorts)}笔)",
            "状态分布": status_counts,
            "仓位档分布": tier_counts,
        }

    def print_report(self):
        """Print the summary report to stdout."""
        report = self.report()
        print("\n" + "=" * 60)
        print(f" V5.1 回测报告 — {self.symbol}")
        print("=" * 60)
        for k, v in report.items():
            if isinstance(v, dict):
                print(f" {k}:")
                for kk, vv in v.items():
                    print(f" {kk}: {vv}")
            else:
                print(f" {k}: {v}")
        print("=" * 60)

    def print_trades(self, limit: int = 20):
        """Print the last *limit* closed trades as a table."""
        print(f"\n最近 {min(limit, len(self.closed_positions))} 笔交易:")
        print(f"{'方向':<6} {'分数':<5} {'档位':<8} {'入场价':<12} {'出场价':<12} {'PnL(R)':<8} {'状态':<8} {'持仓(分)':<8}")
        print("-" * 72)
        for t in self.closed_positions[-limit:]:
            hold_min = (t.exit_ts - t.entry_ts) / 60000 if t.exit_ts > 0 else 0
            pnl_str = f"{t.pnl_r:+.2f}"
            print(f"{t.direction:<6} {t.score:<5} {t.tier:<8} {t.entry_price:<12.2f} {t.exit_price:<12.2f} {pnl_str:<8} {t.status:<8} {hold_min:<8.1f}")


# ─── Data loading ────────────────────────────────────────────────

def _connect():
    """Open a new PostgreSQL connection using the module-level PG_* settings."""
    return psycopg2.connect(
        host=PG_HOST, port=PG_PORT, dbname=PG_DB, user=PG_USER, password=PG_PASS
    )


def load_trades(symbol: str, start_ms: int, end_ms: int) -> int:
    """Return the number of agg_trades rows for *symbol* with start_ms <= time_ms <= end_ms.

    Despite its name, this only counts rows; the count drives progress logging.
    """
    conn = _connect()
    try:
        with conn.cursor() as cur:
            cur.execute(
                "SELECT COUNT(*) FROM agg_trades WHERE symbol = %s AND time_ms >= %s AND time_ms <= %s",
                (symbol, start_ms, end_ms)
            )
            return cur.fetchone()[0]
    finally:
        conn.close()


def run_backtest(symbol: str, start_ms: int, end_ms: int, warmup_ms: int = WINDOW_MID):
    """Warm up indicators, replay [start_ms, end_ms], force-close leftovers and print results.

    Returns the finished BacktestEngine.
    """
    engine = BacktestEngine(symbol)
    processed = 0
    conn = _connect()
    try:
        # Named (server-side) cursors must run inside a transaction, so
        # autocommit is deliberately left off on this connection.

        # Warm up the indicators with data from before the backtest start.
        warmup_start = start_ms - warmup_ms
        logger.info(f"预热中... 加载 {warmup_ms // 3600000}h 历史数据")
        warmup_count = 0
        with conn.cursor("warmup_cursor") as cur:
            cur.itersize = 50000
            cur.execute(
                "SELECT agg_id, time_ms, price, qty, is_buyer_maker FROM agg_trades "
                "WHERE symbol = %s AND time_ms >= %s AND time_ms < %s ORDER BY agg_id ASC",
                (symbol, warmup_start, start_ms)
            )
            for row in cur:
                engine.state.process_trade(row[0], row[1], row[2], row[3], row[4])
                warmup_count += 1
        logger.info(f"预热完成: {warmup_count:,} 条")

        # Main replay loop.
        total_count = load_trades(symbol, start_ms, end_ms)
        logger.info(f"开始回测: {symbol}, {total_count:,} 条tick, 评估间隔: {engine.eval_interval_ms}ms")
        last_log = time.time()
        with conn.cursor("backtest_cursor") as cur:
            cur.itersize = 50000
            cur.execute(
                "SELECT agg_id, time_ms, price, qty, is_buyer_maker FROM agg_trades "
                "WHERE symbol = %s AND time_ms >= %s AND time_ms <= %s ORDER BY agg_id ASC",
                (symbol, start_ms, end_ms)
            )
            for row in cur:
                engine.process_tick(row[0], row[1], row[2], row[3], row[4])
                processed += 1
                # Log progress every 30 seconds of wall-clock time.
                if time.time() - last_log > 30:
                    pct = processed / total_count * 100 if total_count > 0 else 0
                    logger.info(f"进度: {processed:,}/{total_count:,} ({pct:.1f}%) | 交易: {engine.trade_count} | PnL: {engine.cumulative_pnl:+.2f}R")
                    last_log = time.time()
    finally:
        conn.close()

    # Force-close whatever is still open at the last seen price and book its PnL.
    engine.close_all(end_ms)

    logger.info(f"回测完成: 处理 {processed:,} 条tick")
    engine.print_report()
    engine.print_trades()
    return engine


# ─── CLI ─────────────────────────────────────────────────────────

def _utc_date_to_ms(text: str) -> int:
    """Parse a YYYY-MM-DD date as UTC midnight and return epoch milliseconds."""
    dt = datetime.strptime(text, "%Y-%m-%d").replace(tzinfo=timezone.utc)
    return int(dt.timestamp() * 1000)


def _fmt_utc_ms(ms: int) -> str:
    """Format epoch milliseconds as 'YYYY-MM-DD HH:MM' in UTC."""
    return datetime.fromtimestamp(ms / 1000, tz=timezone.utc).strftime("%Y-%m-%d %H:%M")


def main():
    """CLI entry point: parse arguments, run the backtest and save a JSON report."""
    parser = argparse.ArgumentParser(description="V5.1 回测框架")
    parser.add_argument("--symbol", default="BTCUSDT", help="交易对 (default: BTCUSDT)")
    parser.add_argument("--days", type=int, default=7, help="回测天数 (default: 7)")
    parser.add_argument("--start", type=str, help="开始日期 (YYYY-MM-DD)")
    parser.add_argument("--end", type=str, help="结束日期 (YYYY-MM-DD)")
    args = parser.parse_args()

    if args.start or args.end:
        # An explicit range needs both ends; a lone bound is rejected rather
        # than silently ignored.
        if not (args.start and args.end):
            parser.error("--start 和 --end 必须同时指定")
        try:
            start_ms = _utc_date_to_ms(args.start)
            end_ms = _utc_date_to_ms(args.end)
        except ValueError as exc:
            parser.error(f"日期格式错误 (应为 YYYY-MM-DD): {exc}")
    else:
        end_ms = int(time.time() * 1000)
        start_ms = end_ms - args.days * DAY_MS
    if start_ms >= end_ms:
        parser.error("开始时间必须早于结束时间")

    start_str = _fmt_utc_ms(start_ms)
    end_str = _fmt_utc_ms(end_ms)
    span_days = (end_ms - start_ms) / DAY_MS
    logger.info(f"回测参数: {args.symbol} | {start_str} → {end_str} | ~{span_days:g}天")

    engine = run_backtest(args.symbol, start_ms, end_ms)

    # Save the report as JSON.
    report = engine.report()
    report["symbol"] = args.symbol
    report["start"] = start_str
    report["end"] = end_str
    report["trades_count"] = len(engine.closed_positions)

    output_file = f"backtest_{args.symbol}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
    with open(output_file, "w") as f:
        json.dump(report, f, indent=2, ensure_ascii=False, default=str)
    logger.info(f"报告已保存: {output_file}")


if __name__ == "__main__":
    main()