arbitrage-engine/backend/backtest.py

504 lines
19 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
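backtest.py — V5.1 backtest framework (tick-by-tick event replay).

Architecture:
- Read historical agg_trades from PostgreSQL and replay them tick by tick in chronological order
- Reuse the SymbolState evaluation logic from signal_engine
- Simulate opening / closing positions and take-profit / stop-loss handling
- Output statistics: win rate, profit/loss ratio, Sharpe ratio, maximum drawdown (MDD), etc.

Usage:
    python3 backtest.py --symbol BTCUSDT --days 20
    python3 backtest.py --symbol BTCUSDT --start 2026-02-08 --end 2026-02-28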
"""
import argparse
import logging
import os
import sys
import time
from collections import deque
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Optional
import psycopg2
# 复用signal_engine的核心类
sys.path.insert(0, os.path.dirname(__file__))
from signal_engine import SymbolState, WINDOW_MID
# Process-wide logging: timestamped records at INFO level and above.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("backtest")

# PostgreSQL connection settings. Each value is read from the environment
# and falls back to the default shown here when the variable is unset.
# NOTE(review): a default password is embedded in source; consider requiring
# PG_PASS to be supplied through the environment instead.
PG_HOST = os.environ.get("PG_HOST", "10.106.0.3")
PG_PORT = int(os.environ.get("PG_PORT", "5432"))
PG_DB = os.environ.get("PG_DB", "arb_engine")
PG_USER = os.environ.get("PG_USER", "arb")
PG_PASS = os.environ.get("PG_PASS", "arb_engine_2026")
# ─── 仓位管理 ────────────────────────────────────────────────────
@dataclass
class Position:
    """A single simulated position with a stop, two take-profit levels and a time stop.

    Results are booked in R multiples in ``pnl_r``:

    * stop-out before TP1            -> -1.0
    * stop-out after TP1 ("sl_be")   -> 0.5 * 1.5 (half was closed at TP1)
    * TP2 reached ("tp")             -> 0.5 * 1.5 + 0.5 * 3.0
    * time stop ("timeout")          -> price move divided by the initial stop
      distance, floored at 0.5 * 1.5 when TP1 was already hit
    """

    entry_ts: int  # entry time, milliseconds
    entry_price: float
    direction: str  # "LONG" or "SHORT" (anything other than "LONG" is handled as SHORT)
    size: float  # position size in R units
    tier: str  # "light", "standard", "heavy"
    score: int
    sl_price: float  # current stop; moved to roughly break-even once TP1 is hit
    tp1_price: float
    tp2_price: float
    tp1_hit: bool = False
    status: str = "active"  # active, tp1_hit, tp, sl, sl_be, timeout
    exit_ts: int = 0
    exit_price: float = 0.0
    pnl_r: float = 0.0
    # Absolute distance between entry and the ORIGINAL stop, captured at
    # construction. Excluded from __init__/repr/eq so the public interface and
    # equality semantics are the same as before this field existed.
    initial_risk: float = field(default=0.0, init=False, repr=False, compare=False)

    def __post_init__(self) -> None:
        """Remember the initial stop distance before the stop can be moved."""
        self.initial_risk = abs(self.sl_price - self.entry_price)

    def check_exit(self, time_ms: int, price: float, atr: float) -> bool:
        """Apply stop / take-profit / time-stop rules at ``price``.

        Args:
            time_ms: Timestamp of the check, in milliseconds.
            price: Price to test the exit levels against.
            atr: Unused by this method; kept for interface compatibility.

        Returns:
            True when the position is closed (either by this call or earlier),
            False when it remains open. Reaching TP1 alone does not close the
            position: it sets ``tp1_hit`` and moves the stop near break-even.
        """
        if self.status != "active" and self.status != "tp1_hit":
            return True

        is_long = self.direction == "LONG"

        # Stop-loss is checked first, so it wins over take-profit at the same price.
        stop_hit = price <= self.sl_price if is_long else price >= self.sl_price
        if stop_hit:
            self.exit_ts = time_ms
            self.exit_price = price
            if self.tp1_hit:
                # Half was already closed at TP1; the remainder exits at break-even.
                self.status = "sl_be"
                self.pnl_r = 0.5 * 1.5
            else:
                self.status = "sl"
                self.pnl_r = -1.0
            return True

        # TP1: mark the partial take-profit and move the stop to entry plus fees.
        tp1_reached = price >= self.tp1_price if is_long else price <= self.tp1_price
        if not self.tp1_hit and tp1_reached:
            self.tp1_hit = True
            self.status = "tp1_hit"
            self.sl_price = self.entry_price * (1.0005 if is_long else 0.9995)

        # TP2: close the remainder (50% at 1.5R from TP1 + 50% at 3.0R).
        tp2_reached = price >= self.tp2_price if is_long else price <= self.tp2_price
        if self.tp1_hit and tp2_reached:
            self.exit_ts = time_ms
            self.exit_price = price
            self.status = "tp"
            self.pnl_r = 0.5 * 1.5 + 0.5 * 3.0
            return True

        # Time stop: force-close after more than 60 minutes in the trade.
        if time_ms - self.entry_ts > 60 * 60 * 1000:
            self.exit_ts = time_ms
            self.exit_price = price
            self.status = "timeout"
            move = price - self.entry_price if is_long else self.entry_price - price
            # Measure the move against the INITIAL stop distance. After TP1 the
            # live stop sits ~0.05% from entry, so dividing by the live stop
            # distance would inflate the R multiple by orders of magnitude.
            self.pnl_r = move / self.initial_risk if self.initial_risk > 0 else 0.0
            if self.tp1_hit:
                # Keep at least the profit already realised at TP1.
                # NOTE(review): the full move is credited although only half
                # the position is still open after TP1 - confirm intent.
                self.pnl_r = max(self.pnl_r, 0.5 * 1.5)
            return True

        return False
# ─── 回测引擎 ────────────────────────────────────────────────────
class BacktestEngine:
    """Replay ticks for one symbol, simulate trades and collect statistics.

    Every tick updates the shared ``SymbolState``. At most once per
    ``eval_interval_ms`` the engine checks open positions against the high and
    low seen since the previous evaluation, then asks the state for a new
    signal. Results are accumulated in R multiples.
    """

    def __init__(self, symbol: str, eval_interval_ms: int = 15000):
        """Create an engine for ``symbol`` evaluating every ``eval_interval_ms`` ms."""
        self.symbol = symbol
        self.state = SymbolState(symbol)
        self.state.warmup = False
        self.positions: list[Position] = []
        self.closed_positions: list[Position] = []
        self.cooldown_until: int = 0
        self.equity_curve: list[tuple[int, float]] = []
        self.cumulative_pnl = 0.0
        self.trade_count = 0
        self.eval_interval_ms = eval_interval_ms
        self.last_eval_ts: int = 0
        # High / low / last price seen since the previous evaluation; used to
        # test stop and take-profit levels between evaluations.
        self.period_high: float = 0.0
        self.period_low: float = float('inf')
        self.period_last_price: float = 0.0
        self.period_last_ts: int = 0

    def process_tick(self, agg_id: int, time_ms: int, price: float, qty: float, is_buyer_maker: int):
        """Feed one trade into the state and run an evaluation when one is due."""
        # 1. Update the core indicator state.
        self.state.process_trade(agg_id, time_ms, price, qty, is_buyer_maker)

        # 2. Track the price range of the current evaluation period.
        if price > self.period_high:
            self.period_high = price
        if price < self.period_low:
            self.period_low = price
        self.period_last_price = price
        self.period_last_ts = time_ms

        # 3. Once per interval: check exits, look for a new signal, then start
        #    a fresh period from the current price.
        if time_ms - self.last_eval_ts >= self.eval_interval_ms:
            self._evaluate(time_ms, price)
            self.last_eval_ts = time_ms
            self.period_high = price
            self.period_low = price

    def _evaluate(self, time_ms: int, price: float):
        """Check open positions for exits, then open a new one if a signal qualifies.

        Args:
            time_ms: Timestamp of the evaluation, in milliseconds.
            price: Latest traded price; used for the time stop and as the
                entry price of a new position.
        """
        atr = self.state.atr_calc.atr

        # Exit checks use the period high/low to approximate tick-level fills.
        for pos in self.positions[:]:
            exited = False
            if pos.direction == "LONG":
                # Stop first (period low), then take-profit (period high).
                if self.period_low <= pos.sl_price:
                    exited = pos.check_exit(time_ms, pos.sl_price, atr)
                elif self.period_high >= pos.tp1_price or (pos.tp1_hit and self.period_high >= pos.tp2_price):
                    check_price = pos.tp2_price if pos.tp1_hit and self.period_high >= pos.tp2_price else pos.tp1_price
                    exited = pos.check_exit(time_ms, check_price, atr)
            else:  # SHORT
                if self.period_high >= pos.sl_price:
                    exited = pos.check_exit(time_ms, pos.sl_price, atr)
                elif self.period_low <= pos.tp1_price or (pos.tp1_hit and self.period_low <= pos.tp2_price):
                    check_price = pos.tp2_price if pos.tp1_hit and self.period_low <= pos.tp2_price else pos.tp1_price
                    exited = pos.check_exit(time_ms, check_price, atr)
            # When no level was touched nothing is checked against the period
            # extreme: the only exit reachable there is the time stop, and
            # booking it at the period high (LONG) / low (SHORT) would fill
            # timeouts at the most favourable price of the period.

            # Time stop, evaluated at the latest traded price.
            if not exited and pos.status in ("active", "tp1_hit"):
                exited = pos.check_exit(time_ms, price, atr)

            if exited or pos.status not in ("active", "tp1_hit"):
                if pos in self.positions:
                    self.positions.remove(pos)
                self.closed_positions.append(pos)
                # NOTE(review): pnl_r is accumulated without weighting by
                # pos.size - confirm this is intended.
                self.cumulative_pnl += pos.pnl_r
                self.equity_curve.append((time_ms, self.cumulative_pnl))

        # New-signal evaluation: skipped during cooldown or while a position is open.
        if time_ms < self.cooldown_until:
            return
        if len(self.positions) > 0:
            return
        result = self.state.evaluate_signal(time_ms)
        signal = result.get("signal")
        if not signal:
            return
        score = result.get("score", 0)
        if score < 60:
            return  # V5.1 entry threshold

        # Position tier by score.
        if score >= 85:
            tier = "heavy"
            size_mult = 1.3
        elif score >= 75:
            tier = "standard"
            size_mult = 1.0
        else:
            tier = "light"
            size_mult = 0.5

        # TP/SL levels (simplified: derived from a single ATR value).
        # NOTE(review): the stop sits at 2.0 x risk_atr while TP1/TP2 sit at
        # 1.5 x / 3.0 x risk_atr - confirm which distance is meant to be 1R.
        risk_atr = 0.7 * self.state.atr_calc.atr
        if risk_atr <= 0:
            return
        if signal == "LONG":
            sl = price - 2.0 * risk_atr
            tp1 = price + 1.5 * risk_atr
            tp2 = price + 3.0 * risk_atr
        else:
            sl = price + 2.0 * risk_atr
            tp1 = price - 1.5 * risk_atr
            tp2 = price - 3.0 * risk_atr

        pos = Position(
            entry_ts=time_ms,
            entry_price=price,
            direction=signal,
            size=size_mult,
            tier=tier,
            score=score,
            sl_price=sl,
            tp1_price=tp1,
            tp2_price=tp2,
        )
        self.positions.append(pos)
        self.trade_count += 1
        # 10-minute cooldown before the next entry.
        self.cooldown_until = time_ms + 10 * 60 * 1000

    def report(self) -> dict:
        """Build the summary statistics for all closed positions.

        Returns:
            A dict of formatted statistics, or ``{"error": ...}`` when no
            position has been closed.
        """
        all_trades = self.closed_positions
        if not all_trades:
            return {"error": "没有交易记录"}

        total = len(all_trades)
        wins = [t for t in all_trades if t.pnl_r > 0]
        losses = [t for t in all_trades if t.pnl_r <= 0]
        win_rate = len(wins) / total * 100 if total > 0 else 0
        total_pnl = sum(t.pnl_r for t in all_trades)
        gross_profit = sum(t.pnl_r for t in wins) if wins else 0
        gross_loss = abs(sum(t.pnl_r for t in losses)) if losses else 0
        profit_factor = gross_profit / gross_loss if gross_loss > 0 else float('inf')

        # Average win / loss and their ratio.
        avg_win = gross_profit / len(wins) if wins else 0
        avg_loss = gross_loss / len(losses) if losses else 0
        avg_rr = avg_win / avg_loss if avg_loss > 0 else float('inf')

        # Maximum drawdown of the cumulative R curve, in trade order.
        peak = 0.0
        mdd = 0.0
        running = 0.0
        for t in all_trades:
            running += t.pnl_r
            peak = max(peak, running)
            dd = peak - running
            mdd = max(mdd, dd)

        # Average holding time in minutes.
        hold_times = [(t.exit_ts - t.entry_ts) / 60000 for t in all_trades if t.exit_ts > 0]
        avg_hold_min = sum(hold_times) / len(hold_times) if hold_times else 0

        # Simplified Sharpe ratio over the per-trade R returns.
        # NOTE(review): sqrt(252) annualisation is applied to per-trade (not
        # daily) returns - confirm this is intended.
        returns = [t.pnl_r for t in all_trades]
        if len(returns) > 1:
            import statistics
            avg_ret = statistics.mean(returns)
            std_ret = statistics.stdev(returns)
            sharpe = (avg_ret / std_ret) * (252 ** 0.5) if std_ret > 0 else 0
        else:
            sharpe = 0

        # Counts by exit status and by tier.
        status_counts = {}
        for t in all_trades:
            status_counts[t.status] = status_counts.get(t.status, 0) + 1
        tier_counts = {}
        for t in all_trades:
            tier_counts[t.tier] = tier_counts.get(t.tier, 0) + 1

        # Win rate split by direction.
        longs = [t for t in all_trades if t.direction == "LONG"]
        shorts = [t for t in all_trades if t.direction == "SHORT"]
        long_wr = len([t for t in longs if t.pnl_r > 0]) / len(longs) * 100 if longs else 0
        short_wr = len([t for t in shorts if t.pnl_r > 0]) / len(shorts) * 100 if shorts else 0

        return {
            "总交易数": total,
            "胜率": f"{win_rate:.1f}%",
            "总盈亏(R)": f"{total_pnl:+.2f}R",
            "盈亏比(Profit Factor)": f"{profit_factor:.2f}",
            "平均盈利(R)": f"{avg_win:.2f}R",
            "平均亏损(R)": f"-{avg_loss:.2f}R",
            "盈亏比(Win/Loss)": f"{avg_rr:.2f}",
            "夏普比率": f"{sharpe:.2f}",
            "最大回撤(MDD)": f"{mdd:.2f}R",
            "平均持仓(分钟)": f"{avg_hold_min:.1f}",
            "做多胜率": f"{long_wr:.1f}% ({len(longs)}笔)",
            "做空胜率": f"{short_wr:.1f}% ({len(shorts)}笔)",
            "状态分布": status_counts,
            "仓位档分布": tier_counts,
        }

    def print_report(self):
        """Print the report to stdout; nested dicts are printed one entry per line."""
        report = self.report()
        print("\n" + "=" * 60)
        print(f" V5.1 回测报告 — {self.symbol}")
        print("=" * 60)
        for k, v in report.items():
            if isinstance(v, dict):
                print(f" {k}:")
                for kk, vv in v.items():
                    print(f" {kk}: {vv}")
            else:
                print(f" {k}: {v}")
        print("=" * 60)

    def print_trades(self, limit: int = 20):
        """Print a table with the last ``limit`` closed trades."""
        print(f"\n最近 {min(limit, len(self.closed_positions))} 笔交易:")
        print(f"{'方向':<6} {'分数':<5} {'档位':<8} {'入场价':<12} {'出场价':<12} {'PnL(R)':<8} {'状态':<8} {'持仓(分)':<8}")
        print("-" * 72)
        for t in self.closed_positions[-limit:]:
            hold_min = (t.exit_ts - t.entry_ts) / 60000 if t.exit_ts > 0 else 0
            pnl_str = f"{t.pnl_r:+.2f}"
            print(f"{t.direction:<6} {t.score:<5} {t.tier:<8} {t.entry_price:<12.2f} {t.exit_price:<12.2f} {pnl_str:<8} {t.status:<8} {hold_min:<8.1f}")
# ─── 数据加载 ────────────────────────────────────────────────────
def load_trades(symbol: str, start_ms: int, end_ms: int) -> int:
    """Return the number of agg_trades rows for ``symbol`` in the time range.

    Despite its name this function loads no rows; it only counts them.

    Args:
        symbol: Trading pair to count.
        start_ms: Inclusive lower bound on ``time_ms``.
        end_ms: Inclusive upper bound on ``time_ms``.

    Returns:
        The row count reported by the database.
    """
    conn = psycopg2.connect(
        host=PG_HOST, port=PG_PORT, dbname=PG_DB,
        user=PG_USER, password=PG_PASS
    )
    # Close the connection even when the query raises.
    try:
        with conn.cursor() as cur:
            cur.execute(
                "SELECT COUNT(*) FROM agg_trades WHERE symbol = %s AND time_ms >= %s AND time_ms <= %s",
                (symbol, start_ms, end_ms)
            )
            count = cur.fetchone()[0]
    finally:
        conn.close()
    return count
def run_backtest(symbol: str, start_ms: int, end_ms: int, warmup_ms: int = WINDOW_MID):
    """Run a backtest for ``symbol`` over [start_ms, end_ms] and print the results.

    Args:
        symbol: Trading pair to replay.
        start_ms: Start of the backtest window, in milliseconds (inclusive).
        end_ms: End of the backtest window, in milliseconds (inclusive).
        warmup_ms: Length of the history replayed before ``start_ms`` to prime
            the indicator state; no trades are simulated during warm-up.

    Returns:
        The ``BacktestEngine`` holding closed positions and statistics.
    """
    engine = BacktestEngine(symbol)
    conn = psycopg2.connect(
        host=PG_HOST, port=PG_PORT, dbname=PG_DB,
        user=PG_USER, password=PG_PASS
    )
    # Close the connection even when a query or the replay raises.
    try:
        # Named (server-side) cursors need a transaction: do not enable autocommit.
        # Warm-up: replay the data just before the backtest start through the
        # state only, so indicators are primed before any signal is evaluated.
        warmup_start = start_ms - warmup_ms
        logger.info(f"预热中... 加载 {warmup_ms // 3600000}h 历史数据")
        warmup_count = 0
        with conn.cursor("warmup_cursor") as cur:
            cur.itersize = 50000
            cur.execute(
                "SELECT agg_id, time_ms, price, qty, is_buyer_maker FROM agg_trades "
                "WHERE symbol = %s AND time_ms >= %s AND time_ms < %s ORDER BY agg_id ASC",
                (symbol, warmup_start, start_ms)
            )
            for row in cur:
                engine.state.process_trade(row[0], row[1], row[2], row[3], row[4])
                warmup_count += 1
        logger.info(f"预热完成: {warmup_count:,}")

        # Main replay loop.
        total_count = load_trades(symbol, start_ms, end_ms)
        logger.info(f"开始回测: {symbol}, {total_count:,} 条tick, 评估间隔: {engine.eval_interval_ms}ms")
        processed = 0
        last_log = time.time()
        with conn.cursor("backtest_cursor") as cur:
            cur.itersize = 50000
            cur.execute(
                "SELECT agg_id, time_ms, price, qty, is_buyer_maker FROM agg_trades "
                "WHERE symbol = %s AND time_ms >= %s AND time_ms <= %s ORDER BY agg_id ASC",
                (symbol, start_ms, end_ms)
            )
            for row in cur:
                engine.process_tick(row[0], row[1], row[2], row[3], row[4])
                processed += 1
                # Log progress every 30 seconds of wall-clock time.
                if time.time() - last_log > 30:
                    pct = processed / total_count * 100 if total_count > 0 else 0
                    logger.info(f"进度: {processed:,}/{total_count:,} ({pct:.1f}%) | 交易: {engine.trade_count} | PnL: {engine.cumulative_pnl:+.2f}R")
                    last_log = time.time()
    finally:
        conn.close()

    # Force-close whatever is still open at the end of the window and book its
    # result, so these trades do not enter the statistics with a PnL of zero.
    for pos in list(engine.positions):
        pos.status = "timeout"
        pos.exit_ts = end_ms
        # presumably index 2 of a win_fast trade record is its price - verify against signal_engine
        pos.exit_price = engine.state.win_fast.trades[-1][2] if engine.state.win_fast.trades else pos.entry_price
        if pos.tp1_hit:
            # After TP1 the stop has been moved next to entry, so the live stop
            # distance is no longer a usable risk unit; book the half already
            # realised at TP1 (the same value a break-even stop-out books).
            pos.pnl_r = 0.5 * 1.5
        else:
            risk = abs(pos.sl_price - pos.entry_price)
            if pos.direction == "LONG":
                move = pos.exit_price - pos.entry_price
            else:
                move = pos.entry_price - pos.exit_price
            pos.pnl_r = move / risk if risk > 0 else 0.0
        engine.closed_positions.append(pos)
        engine.cumulative_pnl += pos.pnl_r
        engine.equity_curve.append((end_ms, engine.cumulative_pnl))
    # A position must not be listed as both open and closed.
    engine.positions.clear()

    logger.info(f"回测完成: 处理 {processed:,} 条tick")
    engine.print_report()
    engine.print_trades()
    return engine
# ─── CLI ─────────────────────────────────────────────────────────
def main():
    """Command-line entry point: parse arguments, run the backtest, save a JSON report.

    The window is either ``--start``/``--end`` (UTC dates, both required
    together) or the last ``--days`` days ending now. Invalid combinations
    terminate through ``argparse`` with a usage error.
    """
    parser = argparse.ArgumentParser(description="V5.1 回测框架")
    parser.add_argument("--symbol", default="BTCUSDT", help="交易对 (default: BTCUSDT)")
    parser.add_argument("--days", type=int, default=7, help="回测天数 (default: 7)")
    parser.add_argument("--start", type=str, help="开始日期 (YYYY-MM-DD)")
    parser.add_argument("--end", type=str, help="结束日期 (YYYY-MM-DD)")
    args = parser.parse_args()

    # A lone --start or --end used to be ignored silently in favour of --days.
    if bool(args.start) != bool(args.end):
        parser.error("--start and --end must be given together")

    if args.start and args.end:
        try:
            start_dt = datetime.strptime(args.start, "%Y-%m-%d").replace(tzinfo=timezone.utc)
            end_dt = datetime.strptime(args.end, "%Y-%m-%d").replace(tzinfo=timezone.utc)
        except ValueError as exc:
            parser.error(f"invalid date (expected YYYY-MM-DD): {exc}")
        start_ms = int(start_dt.timestamp() * 1000)
        end_ms = int(end_dt.timestamp() * 1000)
        if end_ms <= start_ms:
            parser.error("--end must be later than --start")
    else:
        if args.days <= 0:
            parser.error("--days must be a positive integer")
        end_ms = int(time.time() * 1000)
        start_ms = end_ms - args.days * 24 * 3600 * 1000

    start_str = datetime.fromtimestamp(start_ms / 1000, tz=timezone.utc).strftime("%Y-%m-%d %H:%M")
    end_str = datetime.fromtimestamp(end_ms / 1000, tz=timezone.utc).strftime("%Y-%m-%d %H:%M")
    # Report the real span: args.days is unrelated to the window when explicit
    # dates are given.
    span_days = (end_ms - start_ms) / (24 * 3600 * 1000)
    logger.info(f"回测参数: {args.symbol} | {start_str} → {end_str} | ~{span_days:g}")

    engine = run_backtest(args.symbol, start_ms, end_ms)

    # Save the report as JSON in the current working directory.
    import json
    report = engine.report()
    report["symbol"] = args.symbol
    report["start"] = start_str
    report["end"] = end_str
    report["trades_count"] = len(engine.closed_positions)
    output_file = f"backtest_{args.symbol}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
    # ensure_ascii=False writes the non-ASCII report keys verbatim, so the file
    # encoding must be explicit instead of depending on the locale.
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2, ensure_ascii=False, default=str)
    logger.info(f"报告已保存: {output_file}")


if __name__ == "__main__":
    main()