"""
backtest.py — V5.1 回测框架(逐tick事件回放)

架构:
- 从PG读取历史agg_trades,按时间顺序逐tick回放
- 复用signal_engine的SymbolState评估逻辑
- 模拟开仓/平仓/止盈止损
- 输出:胜率/盈亏比/夏普/MDD等统计

用法:
    python3 backtest.py --symbol BTCUSDT --days 20
    python3 backtest.py --symbol BTCUSDT --start 2026-02-08 --end 2026-02-28
"""
|
||
|
||
import argparse
|
||
import logging
|
||
import os
|
||
import sys
|
||
import time
|
||
from collections import deque
|
||
from dataclasses import dataclass, field
|
||
from datetime import datetime, timezone
|
||
from typing import Optional
|
||
|
||
import psycopg2
|
||
|
||
# 复用 signal_engine/signal_state 的核心类与评分逻辑
|
||
sys.path.insert(0, os.path.dirname(__file__))
|
||
from signal_engine import SymbolState, WINDOW_MID
|
||
from strategy_scoring import evaluate_signal as score_strategy
|
||
|
||
# Process-wide logging: INFO level, one timestamped line per record.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("backtest")

# PostgreSQL connection settings; each can be overridden by the environment
# variable of the same name.
# NOTE(review): a default password is committed in source; consider requiring
# PG_PASS to be provided via the environment instead.
PG_HOST = os.environ.get("PG_HOST", "127.0.0.1")
PG_PORT = int(os.environ.get("PG_PORT", "5432"))
PG_DB = os.environ.get("PG_DB", "arb_engine")
PG_USER = os.environ.get("PG_USER", "arb")
PG_PASS = os.environ.get("PG_PASS", "arb_engine_2026")
|
||
|
||
|
||
# ─── 仓位管理 ────────────────────────────────────────────────────
|
||
|
||
@dataclass
class Position:
    """A single simulated position with a two-stage take-profit.

    Half of the position is closed at ``tp1_price``. The stop is then moved to
    break-even plus a 0.05% fee allowance. The other half is closed at
    ``tp2_price``, at the moved stop, or on timeout.

    P&L (``pnl_r``) is expressed in R, where 1R is the *initial* distance
    between ``entry_price`` and ``sl_price``. That distance is captured in
    ``risk`` at construction, because ``sl_price`` is overwritten once TP1 is
    hit. TP R-multiples are derived from the actual TP prices, so the
    accounting stays consistent with however the caller places TP/SL.
    """

    entry_ts: int
    entry_price: float
    direction: str  # "LONG" or "SHORT"
    size: float  # position size multiplier in R units (from tier); pnl_r is NOT scaled by it
    tier: str  # "light", "standard", "heavy"
    score: int
    sl_price: float
    tp1_price: float
    tp2_price: float
    tp1_hit: bool = False
    status: str = "active"  # active, tp1_hit, tp, sl, sl_be, timeout
    exit_ts: int = 0
    exit_price: float = 0.0
    pnl_r: float = 0.0
    # Initial risk per unit, |entry_price - initial sl_price|; derived in
    # __post_init__ when left at the default 0.
    risk: float = 0.0

    # Maximum holding time before a forced timeout exit. Deliberately not
    # annotated, so dataclass does not treat it as a field.
    _MAX_HOLD_MS = 60 * 60 * 1000

    def __post_init__(self):
        # Freeze the initial stop distance before sl_price gets moved to break-even.
        if self.risk <= 0:
            self.risk = abs(self.entry_price - self.sl_price)

    def _favorable(self, price: float, level: float) -> bool:
        """Return True if ``price`` is at or beyond ``level`` in the profitable direction."""
        if self.direction == "LONG":
            return price >= level
        return price <= level

    def _r_multiple(self, price: float) -> float:
        """Return the signed move from entry to ``price`` in R (positive = profit)."""
        if self.risk <= 0:
            return 0.0
        if self.direction == "LONG":
            move = price - self.entry_price
        else:
            move = self.entry_price - price
        return move / self.risk

    def _close(self, time_ms: int, price: float, status: str, pnl_r: float) -> bool:
        """Record the exit and return True (position closed)."""
        self.exit_ts = time_ms
        self.exit_price = price
        self.status = status
        self.pnl_r = pnl_r
        return True

    def check_exit(self, time_ms: int, price: float, atr: float) -> bool:
        """Apply one price observation; return True if the position is (now) closed.

        The checks run in this order:
        1. Stop loss. After TP1 the stop sits at break-even, so this exit is "sl_be".
        2. TP1: this does not close the position. It banks half and moves the stop.
        3. TP2: closes the remaining half.
        4. Time stop: closes once the position has been held longer than 60 minutes.

        Args:
            time_ms: observation timestamp in milliseconds.
            price: observed price; used as the exit price if an exit triggers.
            atr: accepted for interface compatibility; not used.

        Returns:
            True if the position is closed (including when it was already
            closed before this call), otherwise False.
        """
        if self.status not in ("active", "tp1_hit"):
            return True

        is_long = self.direction == "LONG"
        tp1_r = self._r_multiple(self.tp1_price)
        tp2_r = self._r_multiple(self.tp2_price)

        # 1. Stop loss (checked first, so a price at or through the stop always exits).
        sl_crossed = (price <= self.sl_price) if is_long else (price >= self.sl_price)
        if sl_crossed:
            if self.tp1_hit:
                # First half already banked at TP1; the rest exits at ~break-even.
                return self._close(time_ms, price, "sl_be", 0.5 * tp1_r)
            return self._close(time_ms, price, "sl", -1.0)

        # 2. TP1: bank half, move the stop to cost plus a fee allowance.
        if not self.tp1_hit and self._favorable(price, self.tp1_price):
            self.tp1_hit = True
            self.status = "tp1_hit"
            self.sl_price = self.entry_price * (1.0005 if is_long else 0.9995)

        # 3. TP2: close the remaining half.
        if self.tp1_hit and self._favorable(price, self.tp2_price):
            return self._close(time_ms, price, "tp", 0.5 * tp1_r + 0.5 * tp2_r)

        # 4. Time stop. R is measured against the initial risk, not the moved stop.
        if time_ms - self.entry_ts > self._MAX_HOLD_MS:
            move_r = self._r_multiple(price)
            if self.tp1_hit:
                # Keep at least the TP1 profit: the remaining half is floored at 0.
                pnl = 0.5 * tp1_r + 0.5 * max(move_r, 0.0)
            else:
                pnl = move_r
            return self._close(time_ms, price, "timeout", pnl)

        return False
|
||
|
||
|
||
# ─── 回测引擎 ────────────────────────────────────────────────────
|
||
|
||
class BacktestEngine:
    """Tick-replay backtest engine for a single symbol.

    Ticks are fed through process_tick(). Once at least eval_interval_ms of
    data time has passed since the previous evaluation, the engine checks open
    positions against the high/low seen since then. It then asks the scoring
    function for a new signal.
    """

    def __init__(self, symbol: str, eval_interval_ms: int = 15000):
        self.symbol = symbol
        self.state = SymbolState(symbol)
        # NOTE(review): cleared before any warmup data is fed; presumably this
        # lets SymbolState produce signals right away — confirm in signal_engine.
        self.state.warmup = False
        self.positions: list[Position] = []
        self.closed_positions: list[Position] = []
        self.cooldown_until: int = 0
        self.equity_curve: list[tuple[int, float]] = []
        self.cumulative_pnl = 0.0
        self.trade_count = 0
        self.eval_interval_ms = eval_interval_ms
        self.last_eval_ts: int = 0
        # High/low/last price seen since the previous evaluation; used to
        # approximate tick-level TP/SL detection between evaluations.
        self.period_high: float = 0.0
        self.period_low: float = float('inf')
        self.period_last_price: float = 0.0
        self.period_last_ts: int = 0

    def process_tick(self, agg_id: int, time_ms: int, price: float, qty: float, is_buyer_maker: int):
        """Feed one aggregated trade into the engine.

        Updates indicator state and tracks the period high/low/last price. It
        runs _evaluate() once at least eval_interval_ms (data time) has elapsed
        since the previous evaluation.
        """
        # 1. Update core indicator state (CVD/ATR/VWAP etc.).
        self.state.process_trade(agg_id, time_ms, price, qty, is_buyer_maker)

        # 2. Track the price range since the last evaluation.
        if price > self.period_high:
            self.period_high = price
        if price < self.period_low:
            self.period_low = price
        self.period_last_price = price
        self.period_last_ts = time_ms

        # 3. Periodic exit management + signal evaluation.
        if time_ms - self.last_eval_ts >= self.eval_interval_ms:
            self._evaluate(time_ms, price)
            self.last_eval_ts = time_ms
            # The next period starts from the current price.
            self.period_high = price
            self.period_low = price

    def _check_position(self, pos: "Position", time_ms: int, price: float, atr: float) -> bool:
        """Apply this period's price range to one open position.

        If the stop and a target both fall inside the same period, the stop is
        assumed to have been hit first (conservative). TP1 and TP2 are checked
        in order, so both can fire within one period. The time stop is
        evaluated at the current ``price``, not at the period extreme, so
        timeouts are not filled at an optimistic price.

        Returns:
            True if the position is closed after this check.
        """
        is_long = pos.direction == "LONG"
        # Worst and best prices reached this period, from the position's point of view.
        adverse = self.period_low if is_long else self.period_high
        favorable = self.period_high if is_long else self.period_low

        def reached(level: float) -> bool:
            return (favorable >= level) if is_long else (favorable <= level)

        stop_hit = (adverse <= pos.sl_price) if is_long else (adverse >= pos.sl_price)
        if stop_hit:
            return pos.check_exit(time_ms, pos.sl_price, atr)

        if not pos.tp1_hit and reached(pos.tp1_price):
            if pos.check_exit(time_ms, pos.tp1_price, atr):
                return True
        if pos.tp1_hit and reached(pos.tp2_price):
            if pos.check_exit(time_ms, pos.tp2_price, atr):
                return True

        if pos.status in ("active", "tp1_hit"):
            # No level was crossed inside the range: only the time stop can close it now.
            return pos.check_exit(time_ms, price, atr)
        return True

    def _evaluate(self, time_ms: int, price: float):
        """Run one evaluation step: manage exits, then look for a new entry.

        Args:
            time_ms: timestamp of the tick that triggered this evaluation.
            price: price of that tick. It is used as the entry price and as
                the exit price for time-based exits.
        """
        atr = self.state.atr_calc.atr

        # 1. Exit management (period high/low approximates tick-level precision).
        still_open = []
        for pos in self.positions:
            if self._check_position(pos, time_ms, price, atr):
                self.closed_positions.append(pos)
                self.cumulative_pnl += pos.pnl_r
                self.equity_curve.append((time_ms, self.cumulative_pnl))
            else:
                still_open.append(pos)
        # Slice assignment keeps the same list object for any external references.
        self.positions[:] = still_open

        # 2. New entries: respect the cooldown and allow at most one open position.
        if time_ms < self.cooldown_until:
            return
        if self.positions:
            return

        # Unified scoring entry point (V5.1 baseline configuration).
        strategy_cfg = {
            "name": "v51_baseline",
            "threshold": 75,
            "signals": ["cvd", "p99", "accel", "ls_ratio", "oi", "coinbase_premium"],
        }
        result = score_strategy(self.state, time_ms, strategy_cfg=strategy_cfg)
        signal = result.get("signal")
        if not signal:
            return

        score = result.get("score", 0)
        # NOTE(review): the cfg threshold is 75 while this gate is 60. If
        # score_strategy enforces its threshold, the "light" tier is unreachable.
        if score < 60:
            return  # V5.1 entry threshold

        # Position tier by score.
        if score >= 85:
            tier, size_mult = "heavy", 1.3
        elif score >= 75:
            tier, size_mult = "standard", 1.0
        else:
            tier, size_mult = "light", 0.5

        # TP/SL placement (simplified: 5m ATR only).
        risk_atr = 0.7 * self.state.atr_calc.atr
        if risk_atr <= 0:
            return

        if signal == "LONG":
            sl = price - 2.0 * risk_atr
            tp1 = price + 1.5 * risk_atr
            tp2 = price + 3.0 * risk_atr
        else:
            sl = price + 2.0 * risk_atr
            tp1 = price - 1.5 * risk_atr
            tp2 = price - 3.0 * risk_atr

        self.positions.append(Position(
            entry_ts=time_ms,
            entry_price=price,
            direction=signal,
            size=size_mult,
            tier=tier,
            score=score,
            sl_price=sl,
            tp1_price=tp1,
            tp2_price=tp2,
        ))
        self.trade_count += 1

        # 10-minute cooldown after each entry.
        self.cooldown_until = time_ms + 10 * 60 * 1000

    def report(self) -> dict:
        """Build the backtest statistics dict from closed positions.

        Returns {"error": ...} when there are no closed trades.
        """
        all_trades = self.closed_positions
        if not all_trades:
            return {"error": "没有交易记录"}

        total = len(all_trades)
        wins = [t for t in all_trades if t.pnl_r > 0]
        # Break-even trades (pnl_r == 0) are counted as losses here.
        losses = [t for t in all_trades if t.pnl_r <= 0]
        win_rate = len(wins) / total * 100 if total > 0 else 0

        total_pnl = sum(t.pnl_r for t in all_trades)
        gross_profit = sum(t.pnl_r for t in wins) if wins else 0
        gross_loss = abs(sum(t.pnl_r for t in losses)) if losses else 0
        profit_factor = gross_profit / gross_loss if gross_loss > 0 else float('inf')

        # Average win / loss.
        avg_win = gross_profit / len(wins) if wins else 0
        avg_loss = gross_loss / len(losses) if losses else 0
        avg_rr = avg_win / avg_loss if avg_loss > 0 else float('inf')

        # Max drawdown (MDD) of the cumulative R curve, in closing order.
        peak = 0.0
        mdd = 0.0
        running = 0.0
        for t in all_trades:
            running += t.pnl_r
            peak = max(peak, running)
            dd = peak - running
            mdd = max(mdd, dd)

        # Average holding time in minutes.
        hold_times = [(t.exit_ts - t.entry_ts) / 60000 for t in all_trades if t.exit_ts > 0]
        avg_hold_min = sum(hold_times) / len(hold_times) if hold_times else 0

        # Simplified Sharpe over the per-trade R series.
        returns = [t.pnl_r for t in all_trades]
        if len(returns) > 1:
            import statistics
            avg_ret = statistics.mean(returns)
            std_ret = statistics.stdev(returns)
            # NOTE(review): scaled by sqrt(252) as if each trade were one
            # trading day; this is not a true annualization.
            sharpe = (avg_ret / std_ret) * (252 ** 0.5) if std_ret > 0 else 0
        else:
            sharpe = 0

        # Breakdown by exit status.
        status_counts = {}
        for t in all_trades:
            status_counts[t.status] = status_counts.get(t.status, 0) + 1

        # Breakdown by tier.
        tier_counts = {}
        for t in all_trades:
            tier_counts[t.tier] = tier_counts.get(t.tier, 0) + 1

        # LONG vs SHORT win rates.
        longs = [t for t in all_trades if t.direction == "LONG"]
        shorts = [t for t in all_trades if t.direction == "SHORT"]
        long_wr = len([t for t in longs if t.pnl_r > 0]) / len(longs) * 100 if longs else 0
        short_wr = len([t for t in shorts if t.pnl_r > 0]) / len(shorts) * 100 if shorts else 0

        return {
            "总交易数": total,
            "胜率": f"{win_rate:.1f}%",
            "总盈亏(R)": f"{total_pnl:+.2f}R",
            "盈亏比(Profit Factor)": f"{profit_factor:.2f}",
            "平均盈利(R)": f"{avg_win:.2f}R",
            "平均亏损(R)": f"-{avg_loss:.2f}R",
            "盈亏比(Win/Loss)": f"{avg_rr:.2f}",
            "夏普比率": f"{sharpe:.2f}",
            "最大回撤(MDD)": f"{mdd:.2f}R",
            "平均持仓(分钟)": f"{avg_hold_min:.1f}",
            "做多胜率": f"{long_wr:.1f}% ({len(longs)}笔)",
            "做空胜率": f"{short_wr:.1f}% ({len(shorts)}笔)",
            "状态分布": status_counts,
            "仓位档分布": tier_counts,
        }

    def print_report(self):
        """Print the report() output; nested dicts are printed one entry per line."""
        report = self.report()
        print("\n" + "=" * 60)
        print(f" V5.1 回测报告 — {self.symbol}")
        print("=" * 60)
        for k, v in report.items():
            if isinstance(v, dict):
                print(f" {k}:")
                for kk, vv in v.items():
                    print(f" {kk}: {vv}")
            else:
                print(f" {k}: {v}")
        print("=" * 60)

    def print_trades(self, limit: int = 20):
        """Print the last ``limit`` closed trades as a fixed-width table."""
        print(f"\n最近 {min(limit, len(self.closed_positions))} 笔交易:")
        print(f"{'方向':<6} {'分数':<5} {'档位':<8} {'入场价':<12} {'出场价':<12} {'PnL(R)':<8} {'状态':<8} {'持仓(分)':<8}")
        print("-" * 72)
        for t in self.closed_positions[-limit:]:
            hold_min = (t.exit_ts - t.entry_ts) / 60000 if t.exit_ts > 0 else 0
            pnl_str = f"{t.pnl_r:+.2f}"
            print(f"{t.direction:<6} {t.score:<5} {t.tier:<8} {t.entry_price:<12.2f} {t.exit_price:<12.2f} {pnl_str:<8} {t.status:<8} {hold_min:<8.1f}")
|
||
|
||
|
||
# ─── 数据加载 ────────────────────────────────────────────────────
|
||
|
||
def load_trades(symbol: str, start_ms: int, end_ms: int) -> int:
    """Return the number of agg_trades rows for ``symbol`` with start_ms <= time_ms <= end_ms.

    Despite the name, no rows are loaded; only the count is returned (used
    for progress reporting).
    """
    conn = psycopg2.connect(
        host=PG_HOST, port=PG_PORT, dbname=PG_DB,
        user=PG_USER, password=PG_PASS
    )
    # Always release the connection, even if the query raises.
    try:
        with conn.cursor() as cur:
            cur.execute(
                "SELECT COUNT(*) FROM agg_trades WHERE symbol = %s AND time_ms >= %s AND time_ms <= %s",
                (symbol, start_ms, end_ms)
            )
            count = cur.fetchone()[0]
    finally:
        conn.close()
    return count
|
||
|
||
|
||
def run_backtest(symbol: str, start_ms: int, end_ms: int, warmup_ms: int = WINDOW_MID):
    """Run a tick-replay backtest for ``symbol`` over [start_ms, end_ms].

    Indicators are first warmed up with the ``warmup_ms`` of data before
    ``start_ms``. The ticks in range are then replayed in agg_id order.
    Positions still open at the end are force-closed at the last replayed
    price, with their P&L booked. The report and recent trades are printed,
    and the engine is returned.
    """
    engine = BacktestEngine(symbol)

    conn = psycopg2.connect(
        host=PG_HOST, port=PG_PORT, dbname=PG_DB,
        user=PG_USER, password=PG_PASS
    )
    # Named (server-side) cursors must run inside a transaction: do not set autocommit=True.
    processed = 0
    try:
        # Warm up indicators with data preceding the backtest start
        # (length = warmup_ms; presumably WINDOW_MID = 4h — confirm in signal_engine).
        warmup_start = start_ms - warmup_ms
        logger.info(f"预热中... 加载 {warmup_ms // 3600000}h 历史数据")
        warmup_count = 0

        with conn.cursor("warmup_cursor") as cur:
            cur.itersize = 50000
            cur.execute(
                "SELECT agg_id, time_ms, price, qty, is_buyer_maker FROM agg_trades "
                "WHERE symbol = %s AND time_ms >= %s AND time_ms < %s ORDER BY agg_id ASC",
                (symbol, warmup_start, start_ms)
            )
            for row in cur:
                engine.state.process_trade(row[0], row[1], row[2], row[3], row[4])
                warmup_count += 1

        logger.info(f"预热完成: {warmup_count:,} 条")

        # Main replay loop.
        total_count = load_trades(symbol, start_ms, end_ms)
        logger.info(f"开始回测: {symbol}, {total_count:,} 条tick, 评估间隔: {engine.eval_interval_ms}ms")

        last_log = time.time()
        with conn.cursor("backtest_cursor") as cur:
            cur.itersize = 50000
            cur.execute(
                "SELECT agg_id, time_ms, price, qty, is_buyer_maker FROM agg_trades "
                "WHERE symbol = %s AND time_ms >= %s AND time_ms <= %s ORDER BY agg_id ASC",
                (symbol, start_ms, end_ms)
            )
            for row in cur:
                engine.process_tick(row[0], row[1], row[2], row[3], row[4])
                processed += 1

                # Progress log every 30 seconds of wall-clock time.
                if time.time() - last_log > 30:
                    pct = processed / total_count * 100 if total_count > 0 else 0
                    logger.info(f"进度: {processed:,}/{total_count:,} ({pct:.1f}%) | 交易: {engine.trade_count} | PnL: {engine.cumulative_pnl:+.2f}R")
                    last_log = time.time()
    finally:
        conn.close()

    # Force-close every position still open at the end of the data, at the
    # last replayed price, and book its P&L like any other exit.
    atr = engine.state.atr_calc.atr
    one_day_ms = 24 * 60 * 60 * 1000
    for pos in engine.positions:
        exit_price = engine.period_last_price if engine.period_last_price > 0 else pos.entry_price
        # Evaluate at a timestamp well past the holding limit, so Position's own
        # timeout accounting computes pnl_r. An SL/TP is still honoured if the
        # last price already crossed one.
        pos.check_exit(max(end_ms, pos.entry_ts) + one_day_ms, exit_price, atr)
        if pos.status in ("active", "tp1_hit"):
            # Defensive: should not happen once the holding limit is exceeded.
            pos.status = "timeout"
            pos.exit_price = exit_price
        pos.exit_ts = end_ms
        engine.closed_positions.append(pos)
        engine.cumulative_pnl += pos.pnl_r
        engine.equity_curve.append((end_ms, engine.cumulative_pnl))
    engine.positions.clear()

    logger.info(f"回测完成: 处理 {processed:,} 条tick")
    engine.print_report()
    engine.print_trades()

    return engine
|
||
|
||
|
||
# ─── CLI ─────────────────────────────────────────────────────────
|
||
|
||
def main():
    """CLI entry point: parse arguments, run the backtest, and save the report as JSON.

    The window runs from --start 00:00 UTC to --end 00:00 UTC, and both must
    be given together. Otherwise it is the last --days days ending now.
    """
    parser = argparse.ArgumentParser(description="V5.1 回测框架")
    parser.add_argument("--symbol", default="BTCUSDT", help="交易对 (default: BTCUSDT)")
    parser.add_argument("--days", type=int, default=7, help="回测天数 (default: 7)")
    parser.add_argument("--start", type=str, help="开始日期 (YYYY-MM-DD)")
    parser.add_argument("--end", type=str, help="结束日期 (YYYY-MM-DD)")
    args = parser.parse_args()

    # Require both bounds, so a lone --start/--end is not silently ignored.
    if bool(args.start) != bool(args.end):
        parser.error("--start 和 --end 必须同时指定")

    if args.start and args.end:
        start_dt = datetime.strptime(args.start, "%Y-%m-%d").replace(tzinfo=timezone.utc)
        end_dt = datetime.strptime(args.end, "%Y-%m-%d").replace(tzinfo=timezone.utc)
        start_ms = int(start_dt.timestamp() * 1000)
        end_ms = int(end_dt.timestamp() * 1000)
        if end_ms <= start_ms:
            parser.error("--end 必须晚于 --start")
    else:
        end_ms = int(time.time() * 1000)
        start_ms = end_ms - args.days * 24 * 3600 * 1000

    start_str = datetime.fromtimestamp(start_ms / 1000, tz=timezone.utc).strftime("%Y-%m-%d %H:%M")
    end_str = datetime.fromtimestamp(end_ms / 1000, tz=timezone.utc).strftime("%Y-%m-%d %H:%M")
    # Log the actual window length; --days does not apply when --start/--end are given.
    span_days = round((end_ms - start_ms) / (24 * 3600 * 1000), 1)
    logger.info(f"回测参数: {args.symbol} | {start_str} → {end_str} | ~{span_days:g}天")

    engine = run_backtest(args.symbol, start_ms, end_ms)

    # Save the results as JSON.
    import json
    report = engine.report()
    report["symbol"] = args.symbol
    report["start"] = start_str
    report["end"] = end_str
    report["trades_count"] = len(engine.closed_positions)

    output_file = f"backtest_{args.symbol}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
    # Explicit UTF-8: ensure_ascii=False writes non-ASCII keys, which would fail
    # under a non-UTF-8 default locale encoding.
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2, ensure_ascii=False, default=str)
    logger.info(f"报告已保存: {output_file}")


if __name__ == "__main__":
    main()
|