From 5ba4c7fe9819d418d26a554ee3f721bb06ccc8d8 Mon Sep 17 00:00:00 2001 From: root Date: Sat, 28 Feb 2026 05:53:42 +0000 Subject: [PATCH] feat: V5.1 backtest framework - tick-by-tick replay with TP/SL/position management --- backend/backtest.py | 451 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 451 insertions(+) create mode 100644 backend/backtest.py diff --git a/backend/backtest.py b/backend/backtest.py new file mode 100644 index 0000000..e094d30 --- /dev/null +++ b/backend/backtest.py @@ -0,0 +1,451 @@ +""" +backtest.py — V5.1 回测框架(逐tick事件回放) + +架构: + - 从PG读取历史agg_trades,按时间顺序逐tick回放 + - 复用signal_engine的SymbolState评估逻辑 + - 模拟开仓/平仓/止盈止损 + - 输出:胜率/盈亏比/夏普/MDD等统计 + +用法: + python3 backtest.py --symbol BTCUSDT --days 20 + python3 backtest.py --symbol BTCUSDT --start 2026-02-08 --end 2026-02-28 +""" + +import argparse +import logging +import os +import sys +import time +from collections import deque +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Optional + +import psycopg2 + +# 复用signal_engine的核心类 +sys.path.insert(0, os.path.dirname(__file__)) +from signal_engine import SymbolState, WINDOW_MID + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", +) +logger = logging.getLogger("backtest") + +PG_HOST = os.getenv("PG_HOST", "127.0.0.1") +PG_PORT = int(os.getenv("PG_PORT", "5432")) +PG_DB = os.getenv("PG_DB", "arb_engine") +PG_USER = os.getenv("PG_USER", "arb") +PG_PASS = os.getenv("PG_PASS", "arb_engine_2026") + + +# ─── 仓位管理 ──────────────────────────────────────────────────── + +@dataclass +class Position: + """单笔仓位""" + entry_ts: int + entry_price: float + direction: str # "LONG" or "SHORT" + size: float # 仓位大小(R单位) + tier: str # "light", "standard", "heavy" + score: int + sl_price: float + tp1_price: float + tp2_price: float + tp1_hit: bool = False + status: str = "active" # active, tp1_hit, tp, sl, sl_be, timeout + exit_ts: int = 0 + exit_price: float = 0.0 + pnl_r: float = 0.0 + + def check_exit(self, time_ms: int, price: float, atr: float) -> bool: + """检查是否触发止盈止损,返回True表示已平仓""" + if self.status != "active" and self.status != "tp1_hit": + return True + + if self.direction == "LONG": + # 止损 + if price <= self.sl_price: + self.exit_ts = time_ms + self.exit_price = price + if self.tp1_hit: + self.status = "sl_be" + self.pnl_r = 0.5 * 1.5 # TP1已平一半,剩余保本 + else: + self.status = "sl" + self.pnl_r = -1.0 + return True + # TP1 + if not self.tp1_hit and price >= self.tp1_price: + self.tp1_hit = True + self.status = "tp1_hit" + # 移动止损到成本价 + self.sl_price = self.entry_price * 1.0005 # 加手续费 + # TP2 + if self.tp1_hit and price >= self.tp2_price: + self.exit_ts = time_ms + self.exit_price = price + self.status = "tp" + self.pnl_r = 0.5 * 1.5 + 0.5 * 3.0 # TP1=1.5R的50% + TP2=3.0R的50% + return True + else: # SHORT + # 止损 + if price >= self.sl_price: + self.exit_ts = time_ms + self.exit_price = price + if self.tp1_hit: + self.status = "sl_be" + self.pnl_r = 0.5 * 1.5 + else: + self.status = "sl" + self.pnl_r = -1.0 + return True + # TP1 + if not self.tp1_hit and price <= self.tp1_price: + self.tp1_hit = True + self.status = "tp1_hit" + self.sl_price = self.entry_price * 0.9995 + # TP2 + if self.tp1_hit and price <= self.tp2_price: + self.exit_ts = time_ms + self.exit_price = price + self.status = "tp" + self.pnl_r = 0.5 * 1.5 + 0.5 * 3.0 + return True + + # 时间止损:持仓超过60分钟强平 + if time_ms - self.entry_ts > 60 * 60 * 1000: + self.exit_ts = time_ms + self.exit_price = price + self.status = "timeout" + if self.direction == "LONG": + move = (price - self.entry_price) / self.entry_price + else: + move = (self.entry_price - price) / self.entry_price + risk_atr_pct = abs(self.sl_price - self.entry_price) / self.entry_price + self.pnl_r = move / risk_atr_pct if risk_atr_pct > 0 else 0 + if self.tp1_hit: + self.pnl_r = max(self.pnl_r, 0.5 * 1.5) # 至少保TP1的收益 + return True + + return False + + +# ─── 回测引擎 ──────────────────────────────────────────────────── + +class BacktestEngine: + def __init__(self, symbol: str): + self.symbol = symbol + self.state = SymbolState(symbol) + self.state.warmup = False # 回测不需要warmup标记 + self.positions: list[Position] = [] + self.closed_positions: list[Position] = [] + self.cooldown_until: int = 0 + self.equity_curve: list[tuple[int, float]] = [] # (ts, cumulative_pnl_r) + self.cumulative_pnl = 0.0 + self.trade_count = 0 + + def process_tick(self, agg_id: int, time_ms: int, price: float, qty: float, is_buyer_maker: int): + """处理单个tick""" + # 1. 更新状态 + self.state.process_trade(agg_id, time_ms, price, qty, is_buyer_maker) + + # 2. 检查持仓止盈止损 + atr = self.state.atr_calc.atr + for pos in self.positions[:]: + if pos.check_exit(time_ms, price, atr): + self.positions.remove(pos) + self.closed_positions.append(pos) + self.cumulative_pnl += pos.pnl_r + self.equity_curve.append((time_ms, self.cumulative_pnl)) + + # 3. 评估新信号 + if time_ms < self.cooldown_until: + return + if len(self.positions) > 0: + return # 不同时持多个仓位 + + result = self.state.evaluate_signal(time_ms) + signal = result.get("signal") + if not signal: + return + + score = result.get("score", 0) + if score < 60: + return # V5.1门槛 + + # 确定仓位档 + if score >= 85: + tier = "heavy" + size_mult = 1.3 + elif score >= 75: + tier = "standard" + size_mult = 1.0 + else: + tier = "light" + size_mult = 0.5 + + # 计算TP/SL + risk_atr = 0.7 * self.state.atr_calc.atr # 简化版,只用5m ATR + if risk_atr <= 0: + return + + if signal == "LONG": + sl = price - 2.0 * risk_atr + tp1 = price + 1.5 * risk_atr + tp2 = price + 3.0 * risk_atr + else: + sl = price + 2.0 * risk_atr + tp1 = price - 1.5 * risk_atr + tp2 = price - 3.0 * risk_atr + + pos = Position( + entry_ts=time_ms, + entry_price=price, + direction=signal, + size=size_mult, + tier=tier, + score=score, + sl_price=sl, + tp1_price=tp1, + tp2_price=tp2, + ) + self.positions.append(pos) + self.trade_count += 1 + + # 冷却 + self.cooldown_until = time_ms + 10 * 60 * 1000 + + def report(self) -> dict: + """生成回测统计报告""" + all_trades = self.closed_positions + if not all_trades: + return {"error": "没有交易记录"} + + total = len(all_trades) + wins = [t for t in all_trades if t.pnl_r > 0] + losses = [t for t in all_trades if t.pnl_r <= 0] + win_rate = len(wins) / total * 100 if total > 0 else 0 + + total_pnl = sum(t.pnl_r for t in all_trades) + gross_profit = sum(t.pnl_r for t in wins) if wins else 0 + gross_loss = abs(sum(t.pnl_r for t in losses)) if losses else 0 + profit_factor = gross_profit / gross_loss if gross_loss > 0 else float('inf') + + # 平均盈亏 + avg_win = gross_profit / len(wins) if wins else 0 + avg_loss = gross_loss / len(losses) if losses else 0 + avg_rr = avg_win / avg_loss if avg_loss > 0 else float('inf') + + # 最大回撤 (MDD) + peak = 0.0 + mdd = 0.0 + running = 0.0 + for t in all_trades: + running += t.pnl_r + peak = max(peak, running) + dd = peak - running + mdd = max(mdd, dd) + + # 平均持仓时间 + hold_times = [(t.exit_ts - t.entry_ts) / 60000 for t in all_trades if t.exit_ts > 0] + avg_hold_min = sum(hold_times) / len(hold_times) if hold_times else 0 + + # 夏普比率(简化版,用R回报序列) + returns = [t.pnl_r for t in all_trades] + if len(returns) > 1: + import statistics + avg_ret = statistics.mean(returns) + std_ret = statistics.stdev(returns) + sharpe = (avg_ret / std_ret) * (252 ** 0.5) if std_ret > 0 else 0 # 年化 + else: + sharpe = 0 + + # 按状态分类 + status_counts = {} + for t in all_trades: + status_counts[t.status] = status_counts.get(t.status, 0) + 1 + + # 按tier分类 + tier_counts = {} + for t in all_trades: + tier_counts[t.tier] = tier_counts.get(t.tier, 0) + 1 + + # LONG vs SHORT + longs = [t for t in all_trades if t.direction == "LONG"] + shorts = [t for t in all_trades if t.direction == "SHORT"] + long_wr = len([t for t in longs if t.pnl_r > 0]) / len(longs) * 100 if longs else 0 + short_wr = len([t for t in shorts if t.pnl_r > 0]) / len(shorts) * 100 if shorts else 0 + + return { + "总交易数": total, + "胜率": f"{win_rate:.1f}%", + "总盈亏(R)": f"{total_pnl:+.2f}R", + "盈亏比(Profit Factor)": f"{profit_factor:.2f}", + "平均盈利(R)": f"{avg_win:.2f}R", + "平均亏损(R)": f"-{avg_loss:.2f}R", + "盈亏比(Win/Loss)": f"{avg_rr:.2f}", + "夏普比率": f"{sharpe:.2f}", + "最大回撤(MDD)": f"{mdd:.2f}R", + "平均持仓(分钟)": f"{avg_hold_min:.1f}", + "做多胜率": f"{long_wr:.1f}% ({len(longs)}笔)", + "做空胜率": f"{short_wr:.1f}% ({len(shorts)}笔)", + "状态分布": status_counts, + "仓位档分布": tier_counts, + } + + def print_report(self): + """打印回测报告""" + report = self.report() + print("\n" + "=" * 60) + print(f" V5.1 回测报告 — {self.symbol}") + print("=" * 60) + for k, v in report.items(): + if isinstance(v, dict): + print(f" {k}:") + for kk, vv in v.items(): + print(f" {kk}: {vv}") + else: + print(f" {k}: {v}") + print("=" * 60) + + def print_trades(self, limit: int = 20): + """打印最近的交易记录""" + print(f"\n最近 {min(limit, len(self.closed_positions))} 笔交易:") + print(f"{'方向':<6} {'分数':<5} {'档位':<8} {'入场价':<12} {'出场价':<12} {'PnL(R)':<8} {'状态':<8} {'持仓(分)':<8}") + print("-" * 72) + for t in self.closed_positions[-limit:]: + hold_min = (t.exit_ts - t.entry_ts) / 60000 if t.exit_ts > 0 else 0 + pnl_str = f"{t.pnl_r:+.2f}" + print(f"{t.direction:<6} {t.score:<5} {t.tier:<8} {t.entry_price:<12.2f} {t.exit_price:<12.2f} {pnl_str:<8} {t.status:<8} {hold_min:<8.1f}") + + +# ─── 数据加载 ──────────────────────────────────────────────────── + +def load_trades(symbol: str, start_ms: int, end_ms: int) -> int: + """返回数据总条数""" + conn = psycopg2.connect( + host=PG_HOST, port=PG_PORT, dbname=PG_DB, + user=PG_USER, password=PG_PASS + ) + with conn.cursor() as cur: + cur.execute( + "SELECT COUNT(*) FROM agg_trades WHERE symbol = %s AND time_ms >= %s AND time_ms <= %s", + (symbol, start_ms, end_ms) + ) + count = cur.fetchone()[0] + conn.close() + return count + + +def run_backtest(symbol: str, start_ms: int, end_ms: int, warmup_ms: int = WINDOW_MID): + """运行回测""" + engine = BacktestEngine(symbol) + conn = psycopg2.connect( + host=PG_HOST, port=PG_PORT, dbname=PG_DB, + user=PG_USER, password=PG_PASS + ) + conn.set_session(autocommit=True) + + # 先warmup(用回测起点前4h的数据预热指标) + warmup_start = start_ms - warmup_ms + logger.info(f"预热中... 加载 {warmup_ms // 3600000}h 历史数据") + warmup_count = 0 + + with conn.cursor("warmup_cursor") as cur: + cur.itersize = 10000 + cur.execute( + "SELECT agg_id, time_ms, price, qty, is_buyer_maker FROM agg_trades " + "WHERE symbol = %s AND time_ms >= %s AND time_ms < %s ORDER BY agg_id ASC", + (symbol, warmup_start, start_ms) + ) + for row in cur: + engine.state.process_trade(row[0], row[1], row[2], row[3], row[4]) + warmup_count += 1 + + logger.info(f"预热完成: {warmup_count:,} 条") + + # 回测主循环 + total_count = load_trades(symbol, start_ms, end_ms) + logger.info(f"开始回测: {symbol}, {total_count:,} 条tick") + + processed = 0 + last_log = time.time() + + with conn.cursor("backtest_cursor") as cur: + cur.itersize = 10000 + cur.execute( + "SELECT agg_id, time_ms, price, qty, is_buyer_maker FROM agg_trades " + "WHERE symbol = %s AND time_ms >= %s AND time_ms <= %s ORDER BY agg_id ASC", + (symbol, start_ms, end_ms) + ) + for row in cur: + engine.process_tick(row[0], row[1], row[2], row[3], row[4]) + processed += 1 + + # 每30秒打印进度 + if time.time() - last_log > 30: + pct = processed / total_count * 100 if total_count > 0 else 0 + logger.info(f"进度: {processed:,}/{total_count:,} ({pct:.1f}%) | 交易: {engine.trade_count} | PnL: {engine.cumulative_pnl:+.2f}R") + last_log = time.time() + + conn.close() + + # 强平所有未平仓位 + for pos in engine.positions: + pos.status = "timeout" + pos.exit_ts = end_ms + pos.exit_price = engine.state.win_fast.trades[-1][2] if engine.state.win_fast.trades else pos.entry_price + engine.closed_positions.append(pos) + + logger.info(f"回测完成: 处理 {processed:,} 条tick") + engine.print_report() + engine.print_trades() + + return engine + + +# ─── CLI ───────────────────────────────────────────────────────── + +def main(): + parser = argparse.ArgumentParser(description="V5.1 回测框架") + parser.add_argument("--symbol", default="BTCUSDT", help="交易对 (default: BTCUSDT)") + parser.add_argument("--days", type=int, default=7, help="回测天数 (default: 7)") + parser.add_argument("--start", type=str, help="开始日期 (YYYY-MM-DD)") + parser.add_argument("--end", type=str, help="结束日期 (YYYY-MM-DD)") + args = parser.parse_args() + + if args.start and args.end: + start_dt = datetime.strptime(args.start, "%Y-%m-%d").replace(tzinfo=timezone.utc) + end_dt = datetime.strptime(args.end, "%Y-%m-%d").replace(tzinfo=timezone.utc) + start_ms = int(start_dt.timestamp() * 1000) + end_ms = int(end_dt.timestamp() * 1000) + else: + end_ms = int(time.time() * 1000) + start_ms = end_ms - args.days * 24 * 3600 * 1000 + + start_str = datetime.fromtimestamp(start_ms / 1000, tz=timezone.utc).strftime("%Y-%m-%d %H:%M") + end_str = datetime.fromtimestamp(end_ms / 1000, tz=timezone.utc).strftime("%Y-%m-%d %H:%M") + logger.info(f"回测参数: {args.symbol} | {start_str} → {end_str} | ~{args.days}天") + + engine = run_backtest(args.symbol, start_ms, end_ms) + + # 保存结果到JSON + import json + report = engine.report() + report["symbol"] = args.symbol + report["start"] = start_str + report["end"] = end_str + report["trades_count"] = len(engine.closed_positions) + + output_file = f"backtest_{args.symbol}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" + with open(output_file, "w") as f: + json.dump(report, f, indent=2, ensure_ascii=False, default=str) + logger.info(f"报告已保存: {output_file}") + + +if __name__ == "__main__": + main()