# arbitrage-engine/backend/strategy_loader.py
#
# NOTE: the lines below are residue from a web code-viewer export, not part of the module.
# 191 lines
# 6.9 KiB
# Python
# Raw Permalink Blame History
#
# This file contains ambiguous Unicode characters
#
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
strategy_loader.py — 从 JSON 文件 / DB 加载策略配置
从原来的 signal_engine.py 拆分出的策略加载逻辑:
- load_strategy_configs(): 从 backend/strategies/*.json 读取配置;
- load_strategy_configs_from_db(): 从 strategies 表读取 running 策略并映射到 cfg dict。
行为保持与原实现完全一致,用于给 signal_engine 等调用方复用。
"""
import json
import logging
import os
from typing import Any
from db import get_sync_conn
logger = logging.getLogger("strategy-loader")

# Bundled JSON strategy files live in a "strategies" directory next to this module.
STRATEGY_DIR = os.path.join(os.path.dirname(__file__), "strategies")

# Local default strategies, loaded in this order: only the V5.3 family is kept.
DEFAULT_STRATEGY_FILES = ["v53.json", "v53_fast.json", "v53_middle.json"]
def load_strategy_configs() -> list[dict]:
    """Load the default strategy configs from the bundled JSON files.

    Each file named in ``DEFAULT_STRATEGY_FILES`` is read from
    ``STRATEGY_DIR``. A file is accepted only if it parses to a dict with a
    truthy ``"name"`` entry. Missing, unreadable, malformed or nameless files
    are logged and skipped.

    Returns:
        The accepted configs, in ``DEFAULT_STRATEGY_FILES`` order. If none
        could be loaded, a single hard-coded ``"v53"`` fallback config is
        returned instead, so the result is never empty.
    """
    configs: list[dict[str, Any]] = []
    for filename in DEFAULT_STRATEGY_FILES:
        path = os.path.join(STRATEGY_DIR, filename)
        # Keep the try body to the I/O + parse step only.
        try:
            with open(path, "r", encoding="utf-8") as f:
                cfg = json.load(f)
        except FileNotFoundError:
            logger.warning("策略配置缺失: %s", path)
            continue
        except Exception as e:
            # Best-effort loading: a bad file must not prevent loading the others.
            logger.error("策略配置加载失败 %s: %s", path, e)
            continue

        if isinstance(cfg, dict) and cfg.get("name"):
            configs.append(cfg)
        else:
            # Previously such files were dropped silently, which made a
            # misconfigured strategy file hard to notice.
            logger.warning("策略配置无效(需要包含 name 的 JSON 对象),已跳过: %s", path)

    if not configs:
        logger.warning("未加载到策略配置,回退到 v53 默认配置")
        configs.append(
            {
                "name": "v53",
                "threshold": 75,
                "flip_threshold": 85,
                # NOTE(review): these *_multiplier keys differ from the
                # tp1_ratio / tp2_ratio keys produced by
                # load_strategy_configs_from_db — confirm the consumer
                # accepts both.
                "tp_sl": {
                    "sl_multiplier": 2.0,
                    "tp1_multiplier": 1.5,
                    "tp2_multiplier": 3.0,
                },
                # The four main trading pairs are enabled by default. Other
                # details (gates / symbol_gates) are omitted. Per the original
                # note, evaluate_factory_strategy supplies safe defaults for
                # them; it is defined elsewhere and not verified here.
                "symbols": ["BTCUSDT", "ETHUSDT", "XRPUSDT", "SOLUSDT"],
            }
        )
    return configs
# Legacy strategies are identified by fixed UUIDs in the strategies table.
# They are mapped back to their legacy strategy names so name-based scoring
# logic stays compatible. Any other strategy gets "custom_<first 8 chars of id>".
_LEGACY_UUID_MAP: dict[str, str] = {
    "00000000-0000-0000-0000-000000000053": "v53",
    "00000000-0000-0000-0000-000000000054": "v53_middle",
    "00000000-0000-0000-0000-000000000055": "v53_fast",
}


def _float_or_default(value: Any, default: float) -> float:
    """Return ``float(value)``, or ``default`` when the column is NULL (``None``).

    Unlike ``float(value or default)``, an explicitly stored ``0`` is kept
    as ``0.0`` instead of being silently replaced by the default.
    """
    return default if value is None else float(value)


def _row_to_strategy_cfg(row: tuple[Any, ...]) -> dict[str, Any]:
    """Map one row of the running-strategies SELECT to a JSON-style cfg dict.

    ``row`` must follow the exact column order of the SELECT in
    ``load_strategy_configs_from_db``.
    """
    (
        sid,
        display_name,
        symbol,
        cvd_fast,
        cvd_slow,
        w_dir,
        w_env,
        w_aux,
        w_mom,
        entry_score,
        gate_obi,
        obi_thr,
        gate_whale,
        whale_usd_thr,
        whale_flow_pct_val,
        gate_vol,
        vol_atr_pct,
        gate_cvd,
        gate_spot,
        spot_thr,
        sl_mult,
        tp1_r,
        tp2_r,
        timeout_min,
        flip_thr,
        direction,
    ) = row

    strategy_name = _LEGACY_UUID_MAP.get(sid, f"custom_{sid[:8]}")

    # NOTE(review): only the gate thresholds are coerced to float. Weights,
    # entry_score, the sl/tp ratios, timeout and flip_threshold are passed
    # through as the DB driver returns them (possibly Decimal for NUMERIC
    # columns) — confirm downstream arithmetic handles that.
    return {
        "name": strategy_name,
        "strategy_id": sid,  # V5.4: used to write strategy_id back to the DB
        "strategy_name_snapshot": display_name,
        "symbol": symbol,
        "direction": direction,
        "cvd_fast_window": cvd_fast,
        "cvd_slow_window": cvd_slow,
        "threshold": entry_score,
        "weights": {
            "direction": w_dir,
            "env": w_env,
            "aux": w_aux,
            "momentum": w_mom,
        },
        "gates": {
            "vol": {
                "enabled": gate_vol,
                "vol_atr_pct_min": _float_or_default(vol_atr_pct, 0.002),
            },
            "cvd": {"enabled": gate_cvd},
            "whale": {
                "enabled": gate_whale,
                "whale_usd_threshold": _float_or_default(whale_usd_thr, 50000),
                "whale_flow_pct": _float_or_default(whale_flow_pct_val, 0.5),
            },
            "obi": {
                "enabled": gate_obi,
                "threshold": _float_or_default(obi_thr, 0.35),
            },
            "spot_perp": {
                "enabled": gate_spot,
                "threshold": _float_or_default(spot_thr, 0.005),
            },
        },
        "tp_sl": {
            # V5.4: everything is expressed in units of R:
            #   risk_distance = sl_atr_multiplier × ATR = 1R
            #   TP1 = entry ± tp1_ratio × risk_distance
            #   TP2 = entry ± tp2_ratio × risk_distance
            "sl_multiplier": sl_mult,
            "tp1_ratio": tp1_r,
            "tp2_ratio": tp2_r,
        },
        "timeout_minutes": timeout_min,
        "flip_threshold": flip_thr,
    }


def load_strategy_configs_from_db() -> list[dict]:
    """V5.4: load the strategies whose status is 'running' from the strategies table.

    Each row is mapped into the same dict layout that the JSON strategy
    files use. Rows are ordered by ``created_at`` ascending. Nothing is
    cached; every call runs the query again.

    Returns:
        The list of cfg dicts. On any failure (connection, query or row
        mapping), the error is logged and an empty list is returned; callers
        are expected to fall back to the JSON configs.
    """
    try:
        with get_sync_conn() as conn:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT
                        strategy_id::text, display_name, symbol,
                        cvd_fast_window, cvd_slow_window,
                        weight_direction, weight_env, weight_aux, weight_momentum,
                        entry_score,
                        gate_obi_enabled, obi_threshold,
                        gate_whale_enabled, whale_usd_threshold, whale_flow_pct,
                        gate_vol_enabled, vol_atr_pct_min,
                        gate_cvd_enabled,
                        gate_spot_perp_enabled, spot_perp_threshold,
                        sl_atr_multiplier, tp1_ratio, tp2_ratio,
                        timeout_minutes, flip_threshold, direction
                    FROM strategies
                    WHERE status = 'running'
                    ORDER BY created_at ASC
                    """
                )
                rows = cur.fetchall()

        # Map outside the connection block so the connection is released as soon as possible.
        configs: list[dict[str, Any]] = [_row_to_strategy_cfg(row) for row in rows]
        logger.info(
            "[DB] 已加载 %d 个策略配置: %s", len(configs), [c["name"] for c in configs]
        )
        return configs
    except Exception as e:
        # Deliberate best-effort boundary: callers fall back to JSON. Keep the
        # traceback so the root cause is not lost.
        logger.warning(
            "[DB] load_strategy_configs_from_db 失败,将 fallback 到 JSON: %s",
            e,
            exc_info=True,
        )
        return []