feat: V5.1 backtest framework - tick-by-tick replay with TP/SL/position management
This commit is contained in:
parent
2e969f68b4
commit
5ba4c7fe98
451
backend/backtest.py
Normal file
451
backend/backtest.py
Normal file
@ -0,0 +1,451 @@
|
||||
"""
|
||||
backtest.py — V5.1 回测框架(逐tick事件回放)
|
||||
|
||||
架构:
|
||||
- 从PG读取历史agg_trades,按时间顺序逐tick回放
|
||||
- 复用signal_engine的SymbolState评估逻辑
|
||||
- 模拟开仓/平仓/止盈止损
|
||||
- 输出:胜率/盈亏比/夏普/MDD等统计
|
||||
|
||||
用法:
|
||||
python3 backtest.py --symbol BTCUSDT --days 20
|
||||
python3 backtest.py --symbol BTCUSDT --start 2026-02-08 --end 2026-02-28
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
import psycopg2
|
||||
|
||||
# 复用signal_engine的核心类
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
from signal_engine import SymbolState, WINDOW_MID
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] %(message)s",
|
||||
)
|
||||
logger = logging.getLogger("backtest")
|
||||
|
||||
PG_HOST = os.getenv("PG_HOST", "127.0.0.1")
|
||||
PG_PORT = int(os.getenv("PG_PORT", "5432"))
|
||||
PG_DB = os.getenv("PG_DB", "arb_engine")
|
||||
PG_USER = os.getenv("PG_USER", "arb")
|
||||
PG_PASS = os.getenv("PG_PASS", "arb_engine_2026")
|
||||
|
||||
|
||||
# ─── 仓位管理 ────────────────────────────────────────────────────
|
||||
|
||||
@dataclass
|
||||
class Position:
|
||||
"""单笔仓位"""
|
||||
entry_ts: int
|
||||
entry_price: float
|
||||
direction: str # "LONG" or "SHORT"
|
||||
size: float # 仓位大小(R单位)
|
||||
tier: str # "light", "standard", "heavy"
|
||||
score: int
|
||||
sl_price: float
|
||||
tp1_price: float
|
||||
tp2_price: float
|
||||
tp1_hit: bool = False
|
||||
status: str = "active" # active, tp1_hit, tp, sl, sl_be, timeout
|
||||
exit_ts: int = 0
|
||||
exit_price: float = 0.0
|
||||
pnl_r: float = 0.0
|
||||
|
||||
def check_exit(self, time_ms: int, price: float, atr: float) -> bool:
|
||||
"""检查是否触发止盈止损,返回True表示已平仓"""
|
||||
if self.status != "active" and self.status != "tp1_hit":
|
||||
return True
|
||||
|
||||
if self.direction == "LONG":
|
||||
# 止损
|
||||
if price <= self.sl_price:
|
||||
self.exit_ts = time_ms
|
||||
self.exit_price = price
|
||||
if self.tp1_hit:
|
||||
self.status = "sl_be"
|
||||
self.pnl_r = 0.5 * 1.5 # TP1已平一半,剩余保本
|
||||
else:
|
||||
self.status = "sl"
|
||||
self.pnl_r = -1.0
|
||||
return True
|
||||
# TP1
|
||||
if not self.tp1_hit and price >= self.tp1_price:
|
||||
self.tp1_hit = True
|
||||
self.status = "tp1_hit"
|
||||
# 移动止损到成本价
|
||||
self.sl_price = self.entry_price * 1.0005 # 加手续费
|
||||
# TP2
|
||||
if self.tp1_hit and price >= self.tp2_price:
|
||||
self.exit_ts = time_ms
|
||||
self.exit_price = price
|
||||
self.status = "tp"
|
||||
self.pnl_r = 0.5 * 1.5 + 0.5 * 3.0 # TP1=1.5R的50% + TP2=3.0R的50%
|
||||
return True
|
||||
else: # SHORT
|
||||
# 止损
|
||||
if price >= self.sl_price:
|
||||
self.exit_ts = time_ms
|
||||
self.exit_price = price
|
||||
if self.tp1_hit:
|
||||
self.status = "sl_be"
|
||||
self.pnl_r = 0.5 * 1.5
|
||||
else:
|
||||
self.status = "sl"
|
||||
self.pnl_r = -1.0
|
||||
return True
|
||||
# TP1
|
||||
if not self.tp1_hit and price <= self.tp1_price:
|
||||
self.tp1_hit = True
|
||||
self.status = "tp1_hit"
|
||||
self.sl_price = self.entry_price * 0.9995
|
||||
# TP2
|
||||
if self.tp1_hit and price <= self.tp2_price:
|
||||
self.exit_ts = time_ms
|
||||
self.exit_price = price
|
||||
self.status = "tp"
|
||||
self.pnl_r = 0.5 * 1.5 + 0.5 * 3.0
|
||||
return True
|
||||
|
||||
# 时间止损:持仓超过60分钟强平
|
||||
if time_ms - self.entry_ts > 60 * 60 * 1000:
|
||||
self.exit_ts = time_ms
|
||||
self.exit_price = price
|
||||
self.status = "timeout"
|
||||
if self.direction == "LONG":
|
||||
move = (price - self.entry_price) / self.entry_price
|
||||
else:
|
||||
move = (self.entry_price - price) / self.entry_price
|
||||
risk_atr_pct = abs(self.sl_price - self.entry_price) / self.entry_price
|
||||
self.pnl_r = move / risk_atr_pct if risk_atr_pct > 0 else 0
|
||||
if self.tp1_hit:
|
||||
self.pnl_r = max(self.pnl_r, 0.5 * 1.5) # 至少保TP1的收益
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
# ─── 回测引擎 ────────────────────────────────────────────────────
|
||||
|
||||
class BacktestEngine:
|
||||
def __init__(self, symbol: str):
|
||||
self.symbol = symbol
|
||||
self.state = SymbolState(symbol)
|
||||
self.state.warmup = False # 回测不需要warmup标记
|
||||
self.positions: list[Position] = []
|
||||
self.closed_positions: list[Position] = []
|
||||
self.cooldown_until: int = 0
|
||||
self.equity_curve: list[tuple[int, float]] = [] # (ts, cumulative_pnl_r)
|
||||
self.cumulative_pnl = 0.0
|
||||
self.trade_count = 0
|
||||
|
||||
def process_tick(self, agg_id: int, time_ms: int, price: float, qty: float, is_buyer_maker: int):
|
||||
"""处理单个tick"""
|
||||
# 1. 更新状态
|
||||
self.state.process_trade(agg_id, time_ms, price, qty, is_buyer_maker)
|
||||
|
||||
# 2. 检查持仓止盈止损
|
||||
atr = self.state.atr_calc.atr
|
||||
for pos in self.positions[:]:
|
||||
if pos.check_exit(time_ms, price, atr):
|
||||
self.positions.remove(pos)
|
||||
self.closed_positions.append(pos)
|
||||
self.cumulative_pnl += pos.pnl_r
|
||||
self.equity_curve.append((time_ms, self.cumulative_pnl))
|
||||
|
||||
# 3. 评估新信号
|
||||
if time_ms < self.cooldown_until:
|
||||
return
|
||||
if len(self.positions) > 0:
|
||||
return # 不同时持多个仓位
|
||||
|
||||
result = self.state.evaluate_signal(time_ms)
|
||||
signal = result.get("signal")
|
||||
if not signal:
|
||||
return
|
||||
|
||||
score = result.get("score", 0)
|
||||
if score < 60:
|
||||
return # V5.1门槛
|
||||
|
||||
# 确定仓位档
|
||||
if score >= 85:
|
||||
tier = "heavy"
|
||||
size_mult = 1.3
|
||||
elif score >= 75:
|
||||
tier = "standard"
|
||||
size_mult = 1.0
|
||||
else:
|
||||
tier = "light"
|
||||
size_mult = 0.5
|
||||
|
||||
# 计算TP/SL
|
||||
risk_atr = 0.7 * self.state.atr_calc.atr # 简化版,只用5m ATR
|
||||
if risk_atr <= 0:
|
||||
return
|
||||
|
||||
if signal == "LONG":
|
||||
sl = price - 2.0 * risk_atr
|
||||
tp1 = price + 1.5 * risk_atr
|
||||
tp2 = price + 3.0 * risk_atr
|
||||
else:
|
||||
sl = price + 2.0 * risk_atr
|
||||
tp1 = price - 1.5 * risk_atr
|
||||
tp2 = price - 3.0 * risk_atr
|
||||
|
||||
pos = Position(
|
||||
entry_ts=time_ms,
|
||||
entry_price=price,
|
||||
direction=signal,
|
||||
size=size_mult,
|
||||
tier=tier,
|
||||
score=score,
|
||||
sl_price=sl,
|
||||
tp1_price=tp1,
|
||||
tp2_price=tp2,
|
||||
)
|
||||
self.positions.append(pos)
|
||||
self.trade_count += 1
|
||||
|
||||
# 冷却
|
||||
self.cooldown_until = time_ms + 10 * 60 * 1000
|
||||
|
||||
def report(self) -> dict:
|
||||
"""生成回测统计报告"""
|
||||
all_trades = self.closed_positions
|
||||
if not all_trades:
|
||||
return {"error": "没有交易记录"}
|
||||
|
||||
total = len(all_trades)
|
||||
wins = [t for t in all_trades if t.pnl_r > 0]
|
||||
losses = [t for t in all_trades if t.pnl_r <= 0]
|
||||
win_rate = len(wins) / total * 100 if total > 0 else 0
|
||||
|
||||
total_pnl = sum(t.pnl_r for t in all_trades)
|
||||
gross_profit = sum(t.pnl_r for t in wins) if wins else 0
|
||||
gross_loss = abs(sum(t.pnl_r for t in losses)) if losses else 0
|
||||
profit_factor = gross_profit / gross_loss if gross_loss > 0 else float('inf')
|
||||
|
||||
# 平均盈亏
|
||||
avg_win = gross_profit / len(wins) if wins else 0
|
||||
avg_loss = gross_loss / len(losses) if losses else 0
|
||||
avg_rr = avg_win / avg_loss if avg_loss > 0 else float('inf')
|
||||
|
||||
# 最大回撤 (MDD)
|
||||
peak = 0.0
|
||||
mdd = 0.0
|
||||
running = 0.0
|
||||
for t in all_trades:
|
||||
running += t.pnl_r
|
||||
peak = max(peak, running)
|
||||
dd = peak - running
|
||||
mdd = max(mdd, dd)
|
||||
|
||||
# 平均持仓时间
|
||||
hold_times = [(t.exit_ts - t.entry_ts) / 60000 for t in all_trades if t.exit_ts > 0]
|
||||
avg_hold_min = sum(hold_times) / len(hold_times) if hold_times else 0
|
||||
|
||||
# 夏普比率(简化版,用R回报序列)
|
||||
returns = [t.pnl_r for t in all_trades]
|
||||
if len(returns) > 1:
|
||||
import statistics
|
||||
avg_ret = statistics.mean(returns)
|
||||
std_ret = statistics.stdev(returns)
|
||||
sharpe = (avg_ret / std_ret) * (252 ** 0.5) if std_ret > 0 else 0 # 年化
|
||||
else:
|
||||
sharpe = 0
|
||||
|
||||
# 按状态分类
|
||||
status_counts = {}
|
||||
for t in all_trades:
|
||||
status_counts[t.status] = status_counts.get(t.status, 0) + 1
|
||||
|
||||
# 按tier分类
|
||||
tier_counts = {}
|
||||
for t in all_trades:
|
||||
tier_counts[t.tier] = tier_counts.get(t.tier, 0) + 1
|
||||
|
||||
# LONG vs SHORT
|
||||
longs = [t for t in all_trades if t.direction == "LONG"]
|
||||
shorts = [t for t in all_trades if t.direction == "SHORT"]
|
||||
long_wr = len([t for t in longs if t.pnl_r > 0]) / len(longs) * 100 if longs else 0
|
||||
short_wr = len([t for t in shorts if t.pnl_r > 0]) / len(shorts) * 100 if shorts else 0
|
||||
|
||||
return {
|
||||
"总交易数": total,
|
||||
"胜率": f"{win_rate:.1f}%",
|
||||
"总盈亏(R)": f"{total_pnl:+.2f}R",
|
||||
"盈亏比(Profit Factor)": f"{profit_factor:.2f}",
|
||||
"平均盈利(R)": f"{avg_win:.2f}R",
|
||||
"平均亏损(R)": f"-{avg_loss:.2f}R",
|
||||
"盈亏比(Win/Loss)": f"{avg_rr:.2f}",
|
||||
"夏普比率": f"{sharpe:.2f}",
|
||||
"最大回撤(MDD)": f"{mdd:.2f}R",
|
||||
"平均持仓(分钟)": f"{avg_hold_min:.1f}",
|
||||
"做多胜率": f"{long_wr:.1f}% ({len(longs)}笔)",
|
||||
"做空胜率": f"{short_wr:.1f}% ({len(shorts)}笔)",
|
||||
"状态分布": status_counts,
|
||||
"仓位档分布": tier_counts,
|
||||
}
|
||||
|
||||
def print_report(self):
|
||||
"""打印回测报告"""
|
||||
report = self.report()
|
||||
print("\n" + "=" * 60)
|
||||
print(f" V5.1 回测报告 — {self.symbol}")
|
||||
print("=" * 60)
|
||||
for k, v in report.items():
|
||||
if isinstance(v, dict):
|
||||
print(f" {k}:")
|
||||
for kk, vv in v.items():
|
||||
print(f" {kk}: {vv}")
|
||||
else:
|
||||
print(f" {k}: {v}")
|
||||
print("=" * 60)
|
||||
|
||||
def print_trades(self, limit: int = 20):
|
||||
"""打印最近的交易记录"""
|
||||
print(f"\n最近 {min(limit, len(self.closed_positions))} 笔交易:")
|
||||
print(f"{'方向':<6} {'分数':<5} {'档位':<8} {'入场价':<12} {'出场价':<12} {'PnL(R)':<8} {'状态':<8} {'持仓(分)':<8}")
|
||||
print("-" * 72)
|
||||
for t in self.closed_positions[-limit:]:
|
||||
hold_min = (t.exit_ts - t.entry_ts) / 60000 if t.exit_ts > 0 else 0
|
||||
pnl_str = f"{t.pnl_r:+.2f}"
|
||||
print(f"{t.direction:<6} {t.score:<5} {t.tier:<8} {t.entry_price:<12.2f} {t.exit_price:<12.2f} {pnl_str:<8} {t.status:<8} {hold_min:<8.1f}")
|
||||
|
||||
|
||||
# ─── 数据加载 ────────────────────────────────────────────────────
|
||||
|
||||
def load_trades(symbol: str, start_ms: int, end_ms: int) -> int:
|
||||
"""返回数据总条数"""
|
||||
conn = psycopg2.connect(
|
||||
host=PG_HOST, port=PG_PORT, dbname=PG_DB,
|
||||
user=PG_USER, password=PG_PASS
|
||||
)
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"SELECT COUNT(*) FROM agg_trades WHERE symbol = %s AND time_ms >= %s AND time_ms <= %s",
|
||||
(symbol, start_ms, end_ms)
|
||||
)
|
||||
count = cur.fetchone()[0]
|
||||
conn.close()
|
||||
return count
|
||||
|
||||
|
||||
def run_backtest(symbol: str, start_ms: int, end_ms: int, warmup_ms: int = WINDOW_MID):
|
||||
"""运行回测"""
|
||||
engine = BacktestEngine(symbol)
|
||||
conn = psycopg2.connect(
|
||||
host=PG_HOST, port=PG_PORT, dbname=PG_DB,
|
||||
user=PG_USER, password=PG_PASS
|
||||
)
|
||||
conn.set_session(autocommit=True)
|
||||
|
||||
# 先warmup(用回测起点前4h的数据预热指标)
|
||||
warmup_start = start_ms - warmup_ms
|
||||
logger.info(f"预热中... 加载 {warmup_ms // 3600000}h 历史数据")
|
||||
warmup_count = 0
|
||||
|
||||
with conn.cursor("warmup_cursor") as cur:
|
||||
cur.itersize = 10000
|
||||
cur.execute(
|
||||
"SELECT agg_id, time_ms, price, qty, is_buyer_maker FROM agg_trades "
|
||||
"WHERE symbol = %s AND time_ms >= %s AND time_ms < %s ORDER BY agg_id ASC",
|
||||
(symbol, warmup_start, start_ms)
|
||||
)
|
||||
for row in cur:
|
||||
engine.state.process_trade(row[0], row[1], row[2], row[3], row[4])
|
||||
warmup_count += 1
|
||||
|
||||
logger.info(f"预热完成: {warmup_count:,} 条")
|
||||
|
||||
# 回测主循环
|
||||
total_count = load_trades(symbol, start_ms, end_ms)
|
||||
logger.info(f"开始回测: {symbol}, {total_count:,} 条tick")
|
||||
|
||||
processed = 0
|
||||
last_log = time.time()
|
||||
|
||||
with conn.cursor("backtest_cursor") as cur:
|
||||
cur.itersize = 10000
|
||||
cur.execute(
|
||||
"SELECT agg_id, time_ms, price, qty, is_buyer_maker FROM agg_trades "
|
||||
"WHERE symbol = %s AND time_ms >= %s AND time_ms <= %s ORDER BY agg_id ASC",
|
||||
(symbol, start_ms, end_ms)
|
||||
)
|
||||
for row in cur:
|
||||
engine.process_tick(row[0], row[1], row[2], row[3], row[4])
|
||||
processed += 1
|
||||
|
||||
# 每30秒打印进度
|
||||
if time.time() - last_log > 30:
|
||||
pct = processed / total_count * 100 if total_count > 0 else 0
|
||||
logger.info(f"进度: {processed:,}/{total_count:,} ({pct:.1f}%) | 交易: {engine.trade_count} | PnL: {engine.cumulative_pnl:+.2f}R")
|
||||
last_log = time.time()
|
||||
|
||||
conn.close()
|
||||
|
||||
# 强平所有未平仓位
|
||||
for pos in engine.positions:
|
||||
pos.status = "timeout"
|
||||
pos.exit_ts = end_ms
|
||||
pos.exit_price = engine.state.win_fast.trades[-1][2] if engine.state.win_fast.trades else pos.entry_price
|
||||
engine.closed_positions.append(pos)
|
||||
|
||||
logger.info(f"回测完成: 处理 {processed:,} 条tick")
|
||||
engine.print_report()
|
||||
engine.print_trades()
|
||||
|
||||
return engine
|
||||
|
||||
|
||||
# ─── CLI ─────────────────────────────────────────────────────────
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="V5.1 回测框架")
|
||||
parser.add_argument("--symbol", default="BTCUSDT", help="交易对 (default: BTCUSDT)")
|
||||
parser.add_argument("--days", type=int, default=7, help="回测天数 (default: 7)")
|
||||
parser.add_argument("--start", type=str, help="开始日期 (YYYY-MM-DD)")
|
||||
parser.add_argument("--end", type=str, help="结束日期 (YYYY-MM-DD)")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.start and args.end:
|
||||
start_dt = datetime.strptime(args.start, "%Y-%m-%d").replace(tzinfo=timezone.utc)
|
||||
end_dt = datetime.strptime(args.end, "%Y-%m-%d").replace(tzinfo=timezone.utc)
|
||||
start_ms = int(start_dt.timestamp() * 1000)
|
||||
end_ms = int(end_dt.timestamp() * 1000)
|
||||
else:
|
||||
end_ms = int(time.time() * 1000)
|
||||
start_ms = end_ms - args.days * 24 * 3600 * 1000
|
||||
|
||||
start_str = datetime.fromtimestamp(start_ms / 1000, tz=timezone.utc).strftime("%Y-%m-%d %H:%M")
|
||||
end_str = datetime.fromtimestamp(end_ms / 1000, tz=timezone.utc).strftime("%Y-%m-%d %H:%M")
|
||||
logger.info(f"回测参数: {args.symbol} | {start_str} → {end_str} | ~{args.days}天")
|
||||
|
||||
engine = run_backtest(args.symbol, start_ms, end_ms)
|
||||
|
||||
# 保存结果到JSON
|
||||
import json
|
||||
report = engine.report()
|
||||
report["symbol"] = args.symbol
|
||||
report["start"] = start_str
|
||||
report["end"] = end_str
|
||||
report["trades_count"] = len(engine.closed_positions)
|
||||
|
||||
output_file = f"backtest_{args.symbol}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
|
||||
with open(output_file, "w") as f:
|
||||
json.dump(report, f, indent=2, ensure_ascii=False, default=str)
|
||||
logger.info(f"报告已保存: {output_file}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Reference in New Issue
Block a user