arbitrage-engine/backend/paper_trading.py

203 lines
7.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
paper_trading.py — 模拟盘开仓/平仓辅助函数
从原来的 signal_engine.py 拆分出的 paper_trades 辅助逻辑:
- paper_open_trade(): 写入 paper_trades 开仓记录;
- paper_has_active_position() / paper_get_active_direction()
- paper_close_by_signal() / paper_active_count()。
行为保持与原实现完全一致,供 signal_engine 调用。
"""
from typing import Optional
from db import get_sync_conn
from signal_engine import PAPER_FEE_RATE # 复用全局配置
def paper_open_trade(
    symbol: str,
    direction: str,
    price: float,
    score: int,
    tier: str,
    atr: float,
    now_ms: int,
    factors: Optional[dict] = None,
    strategy: str = "v51_baseline",
    tp_sl: Optional[dict] = None,
    strategy_id: Optional[str] = None,
    strategy_name_snapshot: Optional[str] = None,
    logger=None,
):
    """Open a simulated position by inserting one row into ``paper_trades``.

    1R (stored as ``risk_distance``) is ``sl_multiplier * atr``. The stop-loss
    sits exactly 1R from ``price`` on the losing side. Take-profit levels are
    derived from ``tp_sl`` in one of two ways:

    * R-based (v5.4 ``strategies`` table): when both ``tp1_ratio`` and
      ``tp2_ratio`` are present, ``TPn = price ± tpN_ratio * risk_distance``.
    * ATR-based (legacy v5.2/v5.3 JSON strategies): otherwise
      ``TPn = price ± tpN_multiplier * atr`` (defaults 1.5 and 3.0).

    Args:
        symbol: Trading symbol; stored on the row and used as log prefix.
        direction: ``"LONG"`` opens a long; any other value is treated as a short.
        price: Entry price.
        score: Signal score stored with the trade.
        tier: Signal tier stored with the trade.
        atr: ATR at entry. A non-positive value aborts silently.
        now_ms: Value stored in ``entry_ts`` (presumably epoch milliseconds —
            confirm with callers).
        factors: Score breakdown, stored JSON-encoded; a falsy value stores NULL.
        strategy: Strategy label stored in the ``strategy`` column.
        tp_sl: Optional TP/SL config; ``sl_multiplier`` defaults to 2.0.
        strategy_id: Stored as-is in ``strategy_id``.
        strategy_name_snapshot: Stored as-is in ``strategy_name_snapshot``.
        logger: Optional logger. It receives ``error`` when the stop distance is
            rejected and ``info`` after a successful insert.

    Returns:
        None. Nothing is written when ``atr`` is not positive or when the
        computed stop distance fails validation.
    """
    import json
    import math

    if atr <= 0:
        return

    tp_sl_cfg = tp_sl or {}
    sl_multiplier = float(tp_sl_cfg.get("sl_multiplier", 2.0))

    # Two supported configuration styles:
    # - v5.4 strategies table: tp1_ratio / tp2_ratio are targets in R (x risk_distance)
    # - v5.2/v5.3 JSON strategies: tp1_multiplier / tp2_multiplier are targets in ATR
    tp1_ratio = tp_sl_cfg.get("tp1_ratio")
    tp2_ratio = tp_sl_cfg.get("tp2_ratio")
    use_r_based = tp1_ratio is not None and tp2_ratio is not None
    if use_r_based:
        tp1_offset_unit = float(tp1_ratio)
        tp2_offset_unit = float(tp2_ratio)
    else:
        tp1_offset_unit = float(tp_sl_cfg.get("tp1_multiplier", 1.5))
        tp2_offset_unit = float(tp_sl_cfg.get("tp2_multiplier", 3.0))

    # Single definition: 1R = SL distance = sl_multiplier x ATR
    risk_distance = sl_multiplier * atr
    # R-based targets scale with risk_distance, legacy targets scale with raw ATR.
    tp_scale = risk_distance if use_r_based else atr

    if direction == "LONG":
        sl = price - risk_distance
        tp1 = price + tp1_offset_unit * tp_scale
        tp2 = price + tp2_offset_unit * tp_scale
    else:
        sl = price + risk_distance
        tp1 = price - tp1_offset_unit * tp_scale
        tp2 = price - tp2_offset_unit * tp_scale

    # SL sanity check. risk_distance must be a finite, strictly positive 1R:
    # sl_multiplier <= 0 would put the stop at or beyond the entry, and a
    # NaN/inf atr or price makes every plain comparison below evaluate False
    # (so the 80%-120% band alone would let such trades through).
    actual_sl_dist = abs(sl - price)
    if (
        not math.isfinite(risk_distance)
        or risk_distance <= 0
        or not math.isfinite(actual_sl_dist)
        or actual_sl_dist < risk_distance * 0.8
        or actual_sl_dist > risk_distance * 1.2
    ):
        if logger:
            logger.error(
                f"[{symbol}] ⚠️ SL校验失败拒绝开仓: direction={direction} price={price:.4f} "
                f"sl={sl:.4f} actual_dist={actual_sl_dist:.4f} expected={risk_distance:.4f} atr={atr:.4f}"
            )
        return

    with get_sync_conn() as conn:
        with conn.cursor() as cur:
            cur.execute(
                "INSERT INTO paper_trades (symbol,direction,score,tier,entry_price,entry_ts,tp1_price,tp2_price,sl_price,atr_at_entry,score_factors,strategy,risk_distance,strategy_id,strategy_name_snapshot) "
                "VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)",
                (
                    symbol,
                    direction,
                    score,
                    tier,
                    price,
                    now_ms,
                    tp1,
                    tp2,
                    sl,
                    atr,
                    json.dumps(factors) if factors else None,
                    strategy,
                    risk_distance,
                    strategy_id,
                    strategy_name_snapshot,
                ),
            )
        conn.commit()

    if logger:
        logger.info(
            f"[{symbol}] 📝 模拟开仓: {direction} @ {price:.2f} score={score} tier={tier} strategy={strategy} "
            f"TP1={tp1:.2f} TP2={tp2:.2f} SL={sl:.2f}"
        )
def paper_has_active_position(symbol: str, strategy: Optional[str] = None) -> bool:
    """Report whether ``symbol`` currently has at least one open paper position.

    A position is open while its status is 'active' or 'tp1_hit'. When
    ``strategy`` is truthy only that strategy's positions are counted;
    otherwise positions from every strategy count.
    """
    if not strategy:
        count_sql = "SELECT COUNT(*) FROM paper_trades WHERE symbol=%s AND status IN ('active','tp1_hit')"
        count_args = (symbol,)
    else:
        count_sql = "SELECT COUNT(*) FROM paper_trades WHERE symbol=%s AND strategy=%s AND status IN ('active','tp1_hit')"
        count_args = (symbol, strategy)

    with get_sync_conn() as conn, conn.cursor() as cur:
        cur.execute(count_sql, count_args)
        open_positions = cur.fetchone()[0]
    return open_positions > 0
def paper_get_active_direction(symbol: str, strategy: Optional[str] = None) -> str | None:
    """Return the ``direction`` of an open position on ``symbol``, or None if flat.

    Open means status 'active' or 'tp1_hit'. A truthy ``strategy`` limits the
    lookup to that strategy. The query uses ``LIMIT 1`` with no ``ORDER BY``,
    so if several open positions match (possible when ``strategy`` is omitted),
    the SQL does not specify which one's direction is returned.
    """
    lookup = (
        (
            "SELECT direction FROM paper_trades WHERE symbol=%s AND strategy=%s AND status IN ('active','tp1_hit') LIMIT 1",
            (symbol, strategy),
        )
        if strategy
        else (
            "SELECT direction FROM paper_trades WHERE symbol=%s AND status IN ('active','tp1_hit') LIMIT 1",
            (symbol,),
        )
    )

    with get_sync_conn() as conn, conn.cursor() as cur:
        cur.execute(*lookup)
        first_match = cur.fetchone()

    if not first_match:
        return None
    return first_match[0]
def paper_close_by_signal(symbol: str, current_price: float, now_ms: int, strategy: Optional[str] = None, logger=None):
    """Close every open paper position on ``symbol`` at ``current_price`` (reverse signal).

    Each matching row (status 'active' or 'tp1_hit', optionally restricted to a
    truthy ``strategy``) is updated to status 'signal_flip', with ``exit_price``
    set to ``current_price``, ``exit_ts`` set to ``now_ms``, and ``pnl_r`` set
    to the realized result in R rounded to 4 decimals.

    PnL in R = price move in the trade's favour / risk_distance, minus a fee
    term of ``2 * PAPER_FEE_RATE * entry_price / risk_distance`` (fees are
    charged on entry price for both legs). The stored ``risk_distance`` is used
    when it is positive; otherwise 1% of the entry price is used as a fallback.
    If the resulting distance is not positive the PnL is recorded as 0.

    NOTE(review): ``tp1_hit`` and ``atr_at_entry`` are selected but do not
    affect the PnL here. If a TP1 hit realizes part of the position elsewhere,
    that is not reflected in this calculation — confirm this is intended.
    """
    if strategy:
        select_sql = (
            "SELECT id, direction, entry_price, tp1_hit, atr_at_entry, risk_distance "
            "FROM paper_trades WHERE symbol=%s AND strategy=%s AND status IN ('active','tp1_hit')"
        )
        select_args = (symbol, strategy)
    else:
        select_sql = (
            "SELECT id, direction, entry_price, tp1_hit, atr_at_entry, risk_distance "
            "FROM paper_trades WHERE symbol=%s AND status IN ('active','tp1_hit')"
        )
        select_args = (symbol,)
    strategy_tag = f' strategy={strategy}' if strategy else ''

    with get_sync_conn() as conn:
        with conn.cursor() as cur:
            cur.execute(select_sql, select_args)
            # Materialize the full result first: the UPDATEs below reuse this
            # cursor, which would otherwise discard the pending SELECT rows.
            open_rows = cur.fetchall()

            for trade_id, side, entry, _tp1_hit, _atr_entry, stored_rd in open_rows:
                risk_distance = stored_rd if (stored_rd or 0) > 0 else abs(entry * 0.01)

                if risk_distance > 0:
                    if side == "LONG":
                        favourable_move = current_price - entry
                    else:
                        favourable_move = entry - current_price
                    # Subtract round-trip fees expressed in R.
                    pnl_r = favourable_move / risk_distance - (2 * PAPER_FEE_RATE * entry) / risk_distance
                else:
                    pnl_r = 0

                cur.execute(
                    "UPDATE paper_trades SET status='signal_flip', exit_price=%s, exit_ts=%s, pnl_r=%s WHERE id=%s",
                    (current_price, now_ms, round(pnl_r, 4), trade_id),
                )
                if logger:
                    logger.info(
                        f"[{symbol}] 📝 反向信号平仓: {side} @ {current_price:.2f} pnl={pnl_r:+.2f}R"
                        f"{strategy_tag}"
                    )
        conn.commit()
def paper_active_count(strategy: Optional[str] = None) -> int:
    """Count open paper positions (status 'active' or 'tp1_hit').

    With a truthy ``strategy`` only that strategy's positions are counted,
    so each strategy's limit can be enforced independently; otherwise all
    open positions across strategies are counted.
    """
    # Positional arguments for cur.execute: the unfiltered query is sent
    # without a parameter sequence at all.
    statement = (
        ("SELECT COUNT(*) FROM paper_trades WHERE strategy=%s AND status IN ('active','tp1_hit')", (strategy,))
        if strategy
        else ("SELECT COUNT(*) FROM paper_trades WHERE status IN ('active','tp1_hit')",)
    )

    with get_sync_conn() as conn, conn.cursor() as cur:
        cur.execute(*statement)
        total = cur.fetchone()[0]
    return total