arbitrage-engine/scripts/replay_paper_trades.py

322 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
全量重放脚本 v2: 用真实成交价(agg_trades) + 原始信号重算所有 paper_trades 的结果.
修复点(xiaofan 审阅后):
1. 事件判定按时间戳最先发生者为准, 不是固定优先级.
2. TP1 后进入半仓状态机, pnl_r 不重复计算全仓.
3. flip 与价格触发冲突时, 谁时间早用谁; 同一时间优先价格触发.
4. 保本价为显式常量, 区分净保本/毛保本.
"""
import psycopg2
from datetime import datetime, timezone, timedelta
BJ = timezone(timedelta(hours=8))  # Beijing time (UTC+8); ts_bj formats timestamps in it

# Replay window after entry; a trade still open at this point exits as 'timeout'.
TIMEOUT_MS = 60 * 60_000  # 60 minutes, in milliseconds

# Breakeven stop placement after TP1. These are *gross* breakeven levels (fees
# not included); the small offset only keeps a stop placed exactly at entry
# from being taken out by a single tick.
BE_OFFSET_LONG = 1.0005   # LONG: stop moves to entry * 1.0005
BE_OFFSET_SHORT = 0.9995  # SHORT: stop moves to entry * 0.9995

# Per-strategy SL / TP1 / TP2 distances, expressed as multiples of ATR.
STRATEGY_CONFIG = dict(
    v52_8signals=dict(sl=2.1, tp1=1.4, tp2=3.15),
    v51_baseline=dict(sl=1.4, tp1=1.05, tp2=2.1),
)
def ts_bj(ts_ms):
    """Render an epoch-millisecond timestamp as 'MM-DD HH:MM:SS' in the BJ timezone."""
    seconds = ts_ms / 1000
    moment = datetime.fromtimestamp(seconds, tz=BJ)
    return moment.strftime('%m-%d %H:%M:%S')
def _last_price_at(cur, symbol, ts_ms):
    """Return the latest agg_trades price for ``symbol`` at or before ``ts_ms``.

    Returns None when no such trade exists. The value is coerced to float so
    it can be combined with the float SL/TP arithmetic in replay_trade:
    psycopg2 returns NUMERIC columns as Decimal, and Decimal raises TypeError
    when mixed with float arithmetic. NOTE(review): the real column type of
    agg_trades.price is not visible here; for double precision this is a no-op.
    """
    cur.execute("""
        SELECT price FROM agg_trades
        WHERE symbol=%s AND time_ms <= %s
        ORDER BY time_ms DESC LIMIT 1
    """, (symbol, ts_ms))
    row = cur.fetchone()
    if not row or row[0] is None:
        return None
    return float(row[0])


def _reached_stop(is_long, price, level):
    """Return True if ``price`` is at or beyond stop ``level`` against the position."""
    return price <= level if is_long else price >= level


def _reached_target(is_long, price, level):
    """Return True if ``price`` is at or beyond profit ``level`` in the position's favour."""
    return price >= level if is_long else price <= level


def replay_trade(cur, tid, symbol, direction, strategy, entry_ts, atr):
    """Replay one paper trade on real agg_trades prices and recompute its outcome.

    Events are resolved strictly in time order:
      * before TP1: initial SL (full position, -1R) or TP1 (closes half and
        moves the stop to gross breakeven via BE_OFFSET_LONG/BE_OFFSET_SHORT);
      * after TP1: breakeven SL or TP2 on the remaining half;
      * an opposite signal in signal_indicators (signal_flip) exits at the
        last trade price at the flip time; if it shares a timestamp with a
        tick, that tick's price trigger is evaluated first;
      * otherwise the trade exits at the last trade price TIMEOUT_MS after entry.

    Args:
        cur: psycopg2 cursor used for every query.
        tid: paper_trades id, echoed back as ``result['id']``.
        symbol: symbol used to filter agg_trades and signal_indicators.
        direction: 'LONG' or 'SHORT'.
        strategy: key into STRATEGY_CONFIG.
        entry_ts: entry time in epoch milliseconds.
        atr: ATR at entry; SL/TP distances are multiples of it.

    Returns:
        ``(result, None)`` on success, where ``result`` holds the recomputed
        entry, levels, exit status/timestamp/price and fee-adjusted ``pnl_r``;
        or ``(None, error_message)`` when the trade cannot be replayed.
    """
    cfg = STRATEGY_CONFIG.get(strategy)
    if not cfg:
        return None, f"未知策略: {strategy}"
    if direction not in ('LONG', 'SHORT'):
        # Previously any value other than 'LONG' was silently replayed as SHORT.
        return None, f"未知方向: {direction}"
    is_long = direction == 'LONG'
    sign = 1 if is_long else -1  # +1 for LONG, -1 for SHORT

    # 1. Real entry price: the latest trade at or before entry_ts.
    entry = _last_price_at(cur, symbol, entry_ts)
    if entry is None:
        return None, "找不到entry时刻的agg_trade"

    # 2. Original SL/TP levels; rd (risk distance) is 1R in price units.
    atr = float(atr)
    rd = cfg['sl'] * atr
    if rd <= 0:
        return None, f"risk_distance={rd}ATR={atr}无效"
    sl_orig = entry - sign * rd
    tp1 = entry + sign * cfg['tp1'] * atr
    tp2 = entry + sign * cfg['tp2'] * atr
    timeout_ts = entry_ts + TIMEOUT_MS
    be_offset = BE_OFFSET_LONG if is_long else BE_OFFSET_SHORT

    def r_multiple(exit_price):
        """Signed R earned by exiting the whole position at ``exit_price``."""
        return sign * (exit_price - entry) / rd

    # 3. Every trade in (entry_ts, timeout_ts], oldest first.
    cur.execute("""
        SELECT time_ms, price FROM agg_trades
        WHERE symbol=%s AND time_ms > %s AND time_ms <= %s
        ORDER BY time_ms ASC
    """, (symbol, entry_ts, timeout_ts))
    price_rows = cur.fetchall()

    # 4. First non-empty signal in the same window that differs from our direction.
    cur.execute("""
        SELECT ts FROM signal_indicators
        WHERE symbol=%s AND strategy=%s AND ts > %s AND ts <= %s
          AND signal IS NOT NULL AND signal != ''
          AND signal != %s
        ORDER BY ts ASC
        LIMIT 1
    """, (symbol, strategy, entry_ts, timeout_ts, direction))
    flip_row = cur.fetchone()
    flip_ts = flip_row[0] if flip_row else None

    # 5. Walk the ticks in time order.
    sl_current = sl_orig  # moves to gross breakeven once TP1 fills
    tp1_hit = False
    tp1_hit_ts = None
    tp1_r = r_multiple(tp1)  # R of the half closed at TP1
    status = exit_ts = exit_price = pnl_r = None

    for time_ms, price in price_rows:
        # A flip strictly earlier than this tick wins; on an exact tie the
        # tick is evaluated first. The flip exit is priced after the loop.
        if flip_ts is not None and flip_ts < time_ms:
            break
        if not tp1_hit:
            if _reached_stop(is_long, price, sl_current):
                # Full position stopped out, assumed filled at the SL order price.
                status, exit_ts, exit_price, pnl_r = 'sl', time_ms, sl_orig, -1.0
                break
            if _reached_target(is_long, price, tp1):
                # TP1: half closed, stop moved to gross breakeven.
                tp1_hit = True
                tp1_hit_ts = time_ms
                sl_current = entry * be_offset
        elif _reached_stop(is_long, price, sl_current):
            # Breakeven stop on the remaining half. That half exits at
            # sl_current (slightly beyond entry), so credit its R instead of
            # treating it as exactly 0 -- this keeps pnl_r consistent with
            # the exit_price recorded for the trade.
            status, exit_ts, exit_price = 'sl_be', time_ms, sl_current
            pnl_r = 0.5 * tp1_r + 0.5 * r_multiple(sl_current)
            break
        elif _reached_target(is_long, price, tp2):
            status, exit_ts, exit_price = 'tp', time_ms, tp2
            pnl_r = 0.5 * tp1_r + 0.5 * r_multiple(tp2)
            break

    # No price trigger fired: exit on the flip if one occurred in the window,
    # otherwise at the timeout, at the last trade price at that moment.
    if status is None:
        if flip_ts is not None:
            status, exit_ts = 'signal_flip', flip_ts
        else:
            status, exit_ts = 'timeout', timeout_ts
        last = _last_price_at(cur, symbol, exit_ts)
        exit_price = entry if last is None else last
        rest_r = r_multiple(exit_price)
        # After TP1 only half the position is still open.
        pnl_r = 0.5 * tp1_r + 0.5 * rest_r if tp1_hit else rest_r

    # Fees in R: taker fee on open and close of the full notional,
    # 2 * rate * entry / rd (rd > 0 is guaranteed above).
    paper_fee_rate = 0.0005
    pnl_r -= 2 * paper_fee_rate * entry / rd

    return {
        'id': tid,
        'entry': entry,
        'rd': rd,
        'tp1': tp1,
        'tp2': tp2,
        'sl_orig': sl_orig,
        'tp1_hit': tp1_hit,
        'tp1_hit_ts': tp1_hit_ts,
        'status': status,
        'exit_ts': exit_ts,
        'exit_price': exit_price,
        'pnl_r': round(pnl_r, 4),
    }, None
def main(dry_run=False):
    """Replay all finished, not-yet-v2 paper trades and optionally persist the results.

    Selects paper_trades rows with atr_at_entry > 0, a status other than
    'active'/'tp1_hit' and calc_version < 2, replays each with replay_trade,
    writes the recomputed fields back (tagged calc_version = 2) unless
    ``dry_run`` is set, and prints per-strategy statistics.

    Args:
        dry_run: when True, nothing is written to the database.

    Connection settings are read from ARB_DB_HOST / ARB_DB_NAME / ARB_DB_USER /
    ARB_DB_PASSWORD, falling back to the values that used to be hard-coded.
    """
    import os
    from collections import Counter, defaultdict

    # TODO(security): remove the password fallback once ARB_DB_PASSWORD is set
    # in every environment; credentials committed to the repo are leaked.
    conn = psycopg2.connect(
        host=os.environ.get('ARB_DB_HOST', '10.106.0.3'),
        dbname=os.environ.get('ARB_DB_NAME', 'arb_engine'),
        user=os.environ.get('ARB_DB_USER', 'arb'),
        password=os.environ.get('ARB_DB_PASSWORD', 'arb_engine_2026'),
    )
    try:
        cur = conn.cursor()
        cur.execute("""
            SELECT id, symbol, direction, strategy, entry_ts, atr_at_entry
            FROM paper_trades
            WHERE atr_at_entry > 0
              AND status NOT IN ('active', 'tp1_hit')
              AND COALESCE(calc_version, 0) < 2
            ORDER BY id ASC
        """)
        trades = cur.fetchall()
        print(f"总计 {len(trades)} 笔待重放")

        results = []  # (result_dict, strategy)
        errors = []   # (trade_id, error_message)
        for tid, symbol, direction, strategy, entry_ts, atr in trades:
            res, err = replay_trade(cur, tid, symbol, direction, strategy, entry_ts, atr)
            if err:
                errors.append((tid, err))
            else:
                results.append((res, strategy))
        print(f"成功重放: {len(results)}, 错误: {len(errors)}")
        for e in errors[:5]:
            print(f" 错误: {e}")

        if not dry_run and results:
            for r, _ in results:
                cur.execute("""
                    UPDATE paper_trades SET
                        entry_price = %s,
                        tp1_price = %s,
                        tp2_price = %s,
                        sl_price = %s,
                        tp1_hit = %s,
                        status = %s,
                        exit_price = %s,
                        exit_ts = %s,
                        pnl_r = %s,
                        risk_distance = %s,
                        price_source = 'last_trade',
                        calc_version = 2
                    WHERE id = %s
                """, (
                    r['entry'], r['tp1'], r['tp2'], r['sl_orig'],
                    r['tp1_hit'], r['status'], r['exit_price'], r['exit_ts'],
                    r['pnl_r'], r['rd'], r['id'],
                ))
            # All UPDATEs are committed together in one transaction.
            conn.commit()
            print(f"已写入 {len(results)}")

        # Per-strategy summary.
        by_strategy = defaultdict(
            lambda: {'n': 0, 'wins': 0, 'total_r': 0.0, 'status': Counter()})
        for res, strat in results:
            s = by_strategy[strat]
            s['n'] += 1
            s['total_r'] += res['pnl_r']
            if res['pnl_r'] > 0:
                s['wins'] += 1
            s['status'][res['status']] += 1

        print("\n===== 重放统计(真实价口径)=====")
        for strat, s in sorted(by_strategy.items()):
            win_pct = s['wins'] / s['n'] * 100 if s['n'] > 0 else 0
            print(f" {strat}: {s['n']}笔, 胜率{win_pct:.1f}%, 总R={s['total_r']:+.2f}")
            for st, c in sorted(s['status'].items()):
                print(f" {st}: {c}")
    finally:
        # Always release the connection; psycopg2 discards uncommitted
        # changes when a connection is closed without commit.
        conn.close()
if __name__ == '__main__':
    import sys

    # Any '--dry-run' argument switches to read-only mode; other args are ignored.
    dry_run_flag = '--dry-run' in sys.argv
    mode_label = 'DRY RUN不写入' if dry_run_flag else '正式写入'
    print(f"模式: {mode_label}")
    main(dry_run=dry_run_flag)