feat: V5.1 backtest framework - tick-by-tick replay with TP/SL/position management

This commit is contained in:
root 2026-02-28 05:53:42 +00:00
parent 2e969f68b4
commit 5ba4c7fe98

451
backend/backtest.py Normal file
View File

@ -0,0 +1,451 @@
"""
backtest.py V5.1 回测框架逐tick事件回放
架构
- 从PG读取历史agg_trades按时间顺序逐tick回放
- 复用signal_engine的SymbolState评估逻辑
- 模拟开仓/平仓/止盈止损
- 输出胜率/盈亏比/夏普/MDD等统计
用法
python3 backtest.py --symbol BTCUSDT --days 20
python3 backtest.py --symbol BTCUSDT --start 2026-02-08 --end 2026-02-28
"""
import argparse
import logging
import os
import sys
import time
from collections import deque
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Optional
import psycopg2
# 复用signal_engine的核心类
sys.path.insert(0, os.path.dirname(__file__))
from signal_engine import SymbolState, WINDOW_MID
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
)
logger = logging.getLogger("backtest")
PG_HOST = os.getenv("PG_HOST", "127.0.0.1")
PG_PORT = int(os.getenv("PG_PORT", "5432"))
PG_DB = os.getenv("PG_DB", "arb_engine")
PG_USER = os.getenv("PG_USER", "arb")
PG_PASS = os.getenv("PG_PASS", "arb_engine_2026")
# ─── 仓位管理 ────────────────────────────────────────────────────
@dataclass
class Position:
"""单笔仓位"""
entry_ts: int
entry_price: float
direction: str # "LONG" or "SHORT"
size: float # 仓位大小R单位
tier: str # "light", "standard", "heavy"
score: int
sl_price: float
tp1_price: float
tp2_price: float
tp1_hit: bool = False
status: str = "active" # active, tp1_hit, tp, sl, sl_be, timeout
exit_ts: int = 0
exit_price: float = 0.0
pnl_r: float = 0.0
def check_exit(self, time_ms: int, price: float, atr: float) -> bool:
"""检查是否触发止盈止损返回True表示已平仓"""
if self.status != "active" and self.status != "tp1_hit":
return True
if self.direction == "LONG":
# 止损
if price <= self.sl_price:
self.exit_ts = time_ms
self.exit_price = price
if self.tp1_hit:
self.status = "sl_be"
self.pnl_r = 0.5 * 1.5 # TP1已平一半,剩余保本
else:
self.status = "sl"
self.pnl_r = -1.0
return True
# TP1
if not self.tp1_hit and price >= self.tp1_price:
self.tp1_hit = True
self.status = "tp1_hit"
# 移动止损到成本价
self.sl_price = self.entry_price * 1.0005 # 加手续费
# TP2
if self.tp1_hit and price >= self.tp2_price:
self.exit_ts = time_ms
self.exit_price = price
self.status = "tp"
self.pnl_r = 0.5 * 1.5 + 0.5 * 3.0 # TP1=1.5R的50% + TP2=3.0R的50%
return True
else: # SHORT
# 止损
if price >= self.sl_price:
self.exit_ts = time_ms
self.exit_price = price
if self.tp1_hit:
self.status = "sl_be"
self.pnl_r = 0.5 * 1.5
else:
self.status = "sl"
self.pnl_r = -1.0
return True
# TP1
if not self.tp1_hit and price <= self.tp1_price:
self.tp1_hit = True
self.status = "tp1_hit"
self.sl_price = self.entry_price * 0.9995
# TP2
if self.tp1_hit and price <= self.tp2_price:
self.exit_ts = time_ms
self.exit_price = price
self.status = "tp"
self.pnl_r = 0.5 * 1.5 + 0.5 * 3.0
return True
# 时间止损持仓超过60分钟强平
if time_ms - self.entry_ts > 60 * 60 * 1000:
self.exit_ts = time_ms
self.exit_price = price
self.status = "timeout"
if self.direction == "LONG":
move = (price - self.entry_price) / self.entry_price
else:
move = (self.entry_price - price) / self.entry_price
risk_atr_pct = abs(self.sl_price - self.entry_price) / self.entry_price
self.pnl_r = move / risk_atr_pct if risk_atr_pct > 0 else 0
if self.tp1_hit:
self.pnl_r = max(self.pnl_r, 0.5 * 1.5) # 至少保TP1的收益
return True
return False
# ─── 回测引擎 ────────────────────────────────────────────────────
class BacktestEngine:
def __init__(self, symbol: str):
self.symbol = symbol
self.state = SymbolState(symbol)
self.state.warmup = False # 回测不需要warmup标记
self.positions: list[Position] = []
self.closed_positions: list[Position] = []
self.cooldown_until: int = 0
self.equity_curve: list[tuple[int, float]] = [] # (ts, cumulative_pnl_r)
self.cumulative_pnl = 0.0
self.trade_count = 0
def process_tick(self, agg_id: int, time_ms: int, price: float, qty: float, is_buyer_maker: int):
"""处理单个tick"""
# 1. 更新状态
self.state.process_trade(agg_id, time_ms, price, qty, is_buyer_maker)
# 2. 检查持仓止盈止损
atr = self.state.atr_calc.atr
for pos in self.positions[:]:
if pos.check_exit(time_ms, price, atr):
self.positions.remove(pos)
self.closed_positions.append(pos)
self.cumulative_pnl += pos.pnl_r
self.equity_curve.append((time_ms, self.cumulative_pnl))
# 3. 评估新信号
if time_ms < self.cooldown_until:
return
if len(self.positions) > 0:
return # 不同时持多个仓位
result = self.state.evaluate_signal(time_ms)
signal = result.get("signal")
if not signal:
return
score = result.get("score", 0)
if score < 60:
return # V5.1门槛
# 确定仓位档
if score >= 85:
tier = "heavy"
size_mult = 1.3
elif score >= 75:
tier = "standard"
size_mult = 1.0
else:
tier = "light"
size_mult = 0.5
# 计算TP/SL
risk_atr = 0.7 * self.state.atr_calc.atr # 简化版只用5m ATR
if risk_atr <= 0:
return
if signal == "LONG":
sl = price - 2.0 * risk_atr
tp1 = price + 1.5 * risk_atr
tp2 = price + 3.0 * risk_atr
else:
sl = price + 2.0 * risk_atr
tp1 = price - 1.5 * risk_atr
tp2 = price - 3.0 * risk_atr
pos = Position(
entry_ts=time_ms,
entry_price=price,
direction=signal,
size=size_mult,
tier=tier,
score=score,
sl_price=sl,
tp1_price=tp1,
tp2_price=tp2,
)
self.positions.append(pos)
self.trade_count += 1
# 冷却
self.cooldown_until = time_ms + 10 * 60 * 1000
def report(self) -> dict:
"""生成回测统计报告"""
all_trades = self.closed_positions
if not all_trades:
return {"error": "没有交易记录"}
total = len(all_trades)
wins = [t for t in all_trades if t.pnl_r > 0]
losses = [t for t in all_trades if t.pnl_r <= 0]
win_rate = len(wins) / total * 100 if total > 0 else 0
total_pnl = sum(t.pnl_r for t in all_trades)
gross_profit = sum(t.pnl_r for t in wins) if wins else 0
gross_loss = abs(sum(t.pnl_r for t in losses)) if losses else 0
profit_factor = gross_profit / gross_loss if gross_loss > 0 else float('inf')
# 平均盈亏
avg_win = gross_profit / len(wins) if wins else 0
avg_loss = gross_loss / len(losses) if losses else 0
avg_rr = avg_win / avg_loss if avg_loss > 0 else float('inf')
# 最大回撤 (MDD)
peak = 0.0
mdd = 0.0
running = 0.0
for t in all_trades:
running += t.pnl_r
peak = max(peak, running)
dd = peak - running
mdd = max(mdd, dd)
# 平均持仓时间
hold_times = [(t.exit_ts - t.entry_ts) / 60000 for t in all_trades if t.exit_ts > 0]
avg_hold_min = sum(hold_times) / len(hold_times) if hold_times else 0
# 夏普比率简化版用R回报序列
returns = [t.pnl_r for t in all_trades]
if len(returns) > 1:
import statistics
avg_ret = statistics.mean(returns)
std_ret = statistics.stdev(returns)
sharpe = (avg_ret / std_ret) * (252 ** 0.5) if std_ret > 0 else 0 # 年化
else:
sharpe = 0
# 按状态分类
status_counts = {}
for t in all_trades:
status_counts[t.status] = status_counts.get(t.status, 0) + 1
# 按tier分类
tier_counts = {}
for t in all_trades:
tier_counts[t.tier] = tier_counts.get(t.tier, 0) + 1
# LONG vs SHORT
longs = [t for t in all_trades if t.direction == "LONG"]
shorts = [t for t in all_trades if t.direction == "SHORT"]
long_wr = len([t for t in longs if t.pnl_r > 0]) / len(longs) * 100 if longs else 0
short_wr = len([t for t in shorts if t.pnl_r > 0]) / len(shorts) * 100 if shorts else 0
return {
"总交易数": total,
"胜率": f"{win_rate:.1f}%",
"总盈亏(R)": f"{total_pnl:+.2f}R",
"盈亏比(Profit Factor)": f"{profit_factor:.2f}",
"平均盈利(R)": f"{avg_win:.2f}R",
"平均亏损(R)": f"-{avg_loss:.2f}R",
"盈亏比(Win/Loss)": f"{avg_rr:.2f}",
"夏普比率": f"{sharpe:.2f}",
"最大回撤(MDD)": f"{mdd:.2f}R",
"平均持仓(分钟)": f"{avg_hold_min:.1f}",
"做多胜率": f"{long_wr:.1f}% ({len(longs)}笔)",
"做空胜率": f"{short_wr:.1f}% ({len(shorts)}笔)",
"状态分布": status_counts,
"仓位档分布": tier_counts,
}
def print_report(self):
"""打印回测报告"""
report = self.report()
print("\n" + "=" * 60)
print(f" V5.1 回测报告 — {self.symbol}")
print("=" * 60)
for k, v in report.items():
if isinstance(v, dict):
print(f" {k}:")
for kk, vv in v.items():
print(f" {kk}: {vv}")
else:
print(f" {k}: {v}")
print("=" * 60)
def print_trades(self, limit: int = 20):
"""打印最近的交易记录"""
print(f"\n最近 {min(limit, len(self.closed_positions))} 笔交易:")
print(f"{'方向':<6} {'分数':<5} {'档位':<8} {'入场价':<12} {'出场价':<12} {'PnL(R)':<8} {'状态':<8} {'持仓(分)':<8}")
print("-" * 72)
for t in self.closed_positions[-limit:]:
hold_min = (t.exit_ts - t.entry_ts) / 60000 if t.exit_ts > 0 else 0
pnl_str = f"{t.pnl_r:+.2f}"
print(f"{t.direction:<6} {t.score:<5} {t.tier:<8} {t.entry_price:<12.2f} {t.exit_price:<12.2f} {pnl_str:<8} {t.status:<8} {hold_min:<8.1f}")
# ─── 数据加载 ────────────────────────────────────────────────────
def load_trades(symbol: str, start_ms: int, end_ms: int) -> int:
"""返回数据总条数"""
conn = psycopg2.connect(
host=PG_HOST, port=PG_PORT, dbname=PG_DB,
user=PG_USER, password=PG_PASS
)
with conn.cursor() as cur:
cur.execute(
"SELECT COUNT(*) FROM agg_trades WHERE symbol = %s AND time_ms >= %s AND time_ms <= %s",
(symbol, start_ms, end_ms)
)
count = cur.fetchone()[0]
conn.close()
return count
def run_backtest(symbol: str, start_ms: int, end_ms: int, warmup_ms: int = WINDOW_MID):
"""运行回测"""
engine = BacktestEngine(symbol)
conn = psycopg2.connect(
host=PG_HOST, port=PG_PORT, dbname=PG_DB,
user=PG_USER, password=PG_PASS
)
conn.set_session(autocommit=True)
# 先warmup用回测起点前4h的数据预热指标
warmup_start = start_ms - warmup_ms
logger.info(f"预热中... 加载 {warmup_ms // 3600000}h 历史数据")
warmup_count = 0
with conn.cursor("warmup_cursor") as cur:
cur.itersize = 10000
cur.execute(
"SELECT agg_id, time_ms, price, qty, is_buyer_maker FROM agg_trades "
"WHERE symbol = %s AND time_ms >= %s AND time_ms < %s ORDER BY agg_id ASC",
(symbol, warmup_start, start_ms)
)
for row in cur:
engine.state.process_trade(row[0], row[1], row[2], row[3], row[4])
warmup_count += 1
logger.info(f"预热完成: {warmup_count:,}")
# 回测主循环
total_count = load_trades(symbol, start_ms, end_ms)
logger.info(f"开始回测: {symbol}, {total_count:,} 条tick")
processed = 0
last_log = time.time()
with conn.cursor("backtest_cursor") as cur:
cur.itersize = 10000
cur.execute(
"SELECT agg_id, time_ms, price, qty, is_buyer_maker FROM agg_trades "
"WHERE symbol = %s AND time_ms >= %s AND time_ms <= %s ORDER BY agg_id ASC",
(symbol, start_ms, end_ms)
)
for row in cur:
engine.process_tick(row[0], row[1], row[2], row[3], row[4])
processed += 1
# 每30秒打印进度
if time.time() - last_log > 30:
pct = processed / total_count * 100 if total_count > 0 else 0
logger.info(f"进度: {processed:,}/{total_count:,} ({pct:.1f}%) | 交易: {engine.trade_count} | PnL: {engine.cumulative_pnl:+.2f}R")
last_log = time.time()
conn.close()
# 强平所有未平仓位
for pos in engine.positions:
pos.status = "timeout"
pos.exit_ts = end_ms
pos.exit_price = engine.state.win_fast.trades[-1][2] if engine.state.win_fast.trades else pos.entry_price
engine.closed_positions.append(pos)
logger.info(f"回测完成: 处理 {processed:,} 条tick")
engine.print_report()
engine.print_trades()
return engine
# ─── CLI ─────────────────────────────────────────────────────────
def main():
parser = argparse.ArgumentParser(description="V5.1 回测框架")
parser.add_argument("--symbol", default="BTCUSDT", help="交易对 (default: BTCUSDT)")
parser.add_argument("--days", type=int, default=7, help="回测天数 (default: 7)")
parser.add_argument("--start", type=str, help="开始日期 (YYYY-MM-DD)")
parser.add_argument("--end", type=str, help="结束日期 (YYYY-MM-DD)")
args = parser.parse_args()
if args.start and args.end:
start_dt = datetime.strptime(args.start, "%Y-%m-%d").replace(tzinfo=timezone.utc)
end_dt = datetime.strptime(args.end, "%Y-%m-%d").replace(tzinfo=timezone.utc)
start_ms = int(start_dt.timestamp() * 1000)
end_ms = int(end_dt.timestamp() * 1000)
else:
end_ms = int(time.time() * 1000)
start_ms = end_ms - args.days * 24 * 3600 * 1000
start_str = datetime.fromtimestamp(start_ms / 1000, tz=timezone.utc).strftime("%Y-%m-%d %H:%M")
end_str = datetime.fromtimestamp(end_ms / 1000, tz=timezone.utc).strftime("%Y-%m-%d %H:%M")
logger.info(f"回测参数: {args.symbol} | {start_str}{end_str} | ~{args.days}")
engine = run_backtest(args.symbol, start_ms, end_ms)
# 保存结果到JSON
import json
report = engine.report()
report["symbol"] = args.symbol
report["start"] = start_str
report["end"] = end_str
report["trades_count"] = len(engine.closed_positions)
output_file = f"backtest_{args.symbol}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
with open(output_file, "w") as f:
json.dump(report, f, indent=2, ensure_ascii=False, default=str)
logger.info(f"报告已保存: {output_file}")
if __name__ == "__main__":
main()