arbitrage-engine/backend/main.py

2846 lines
105 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio, time, os, json
import datetime as _dt
import logging
from datetime import datetime, timedelta

import httpx
from fastapi import FastAPI, HTTPException, Depends, Request
from fastapi.middleware.cors import CORSMiddleware

from auth import router as auth_router, get_current_user, ensure_tables as ensure_auth_tables
from db import (
    init_schema, ensure_partitions, get_async_pool, async_fetch, async_fetchrow, async_execute,
    close_async_pool,
)
# FastAPI application; CORS allows browser calls only from the production frontend origin.
app = FastAPI(title="Arbitrage Engine API")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://arb.zhouyangclaw.com"],
    allow_methods=["*"],
    allow_headers=["*"],
)
app.include_router(auth_router)
# Binance futures (fapi) REST base URL used by all upstream calls in this module.
BINANCE_FAPI = "https://fapi.binance.com/fapi/v1"
# Perpetual pairs polled by the rate/history/stats endpoints and the snapshot loop.
SYMBOLS = ["BTCUSDT", "ETHUSDT", "XRPUSDT", "SOLUSDT"]
# Browser-like User-Agent sent with Binance requests.
HEADERS = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}
# Simple in-process cache: key -> {"ts": epoch seconds, "data": payload}.
# The TTL is chosen per call site (e.g. history/stats 60 s, rates 3 s); entries are never pruned.
_cache: dict = {}
def get_cache(key: str, ttl: int):
    """Return the payload cached under *key* if it is younger than *ttl* seconds, else None."""
    cached_entry = _cache.get(key)
    if not cached_entry:
        return None
    age_seconds = time.time() - cached_entry["ts"]
    return cached_entry["data"] if age_seconds < ttl else None
def set_cache(key: str, data):
    """Store *data* under *key*, stamped with the current epoch time (overwrites any old entry)."""
    fresh_entry = {"ts": time.time(), "data": data}
    _cache[key] = fresh_entry
async def save_snapshot(rates: dict):
    """Persist one BTC/ETH funding-rate and price row into rate_snapshots (best effort).

    Args:
        rates: maps short symbols ("BTC", "ETH", ...) to dicts carrying
            lastFundingRate, markPrice and indexPrice. Only the BTC and ETH
            entries are stored.

    When either BTC or ETH is missing (e.g. its upstream request failed), nothing
    is written. Substituting zeros would insert fake price/rate points that
    get_kline later folds into bars. Database errors are logged, never raised,
    so callers' API responses are unaffected.
    """
    log = logging.getLogger(__name__)
    btc = rates.get("BTC")
    eth = rates.get("ETH")
    if not btc or not eth:
        log.debug("Skipping rate snapshot: %s data missing", "BTC" if not btc else "ETH")
        return
    try:
        await async_execute(
            "INSERT INTO rate_snapshots (ts, btc_rate, eth_rate, btc_price, eth_price, btc_index_price, eth_index_price) "
            "VALUES ($1,$2,$3,$4,$5,$6,$7)",
            int(time.time()),
            float(btc.get("lastFundingRate", 0)),
            float(eth.get("lastFundingRate", 0)),
            float(btc.get("markPrice", 0)),
            float(eth.get("markPrice", 0)),
            float(btc.get("indexPrice", 0)),
            float(eth.get("indexPrice", 0)),
        )
    except Exception as exc:
        # Best effort: a failed insert must not break the API response or the polling loop.
        log.warning("Failed to save rate snapshot: %s", exc)
async def background_snapshot_loop():
    """Poll Binance premiumIndex for every symbol roughly every 2 seconds and persist a snapshot.

    Runs until cancelled. One HTTP client is reused across iterations instead of
    building a new connection pool every 2 seconds. A failed or malformed response
    for one symbol only drops that symbol; any other error in an iteration is
    logged and the loop continues. CancelledError is not caught, so cancelling
    the task stops the loop and closes the client.
    """
    log = logging.getLogger(__name__)
    async with httpx.AsyncClient(timeout=5, headers=HEADERS) as client:
        while True:
            try:
                responses = await asyncio.gather(
                    *(client.get(f"{BINANCE_FAPI}/premiumIndex", params={"symbol": s}) for s in SYMBOLS),
                    return_exceptions=True,
                )
                result = {}
                for sym, resp in zip(SYMBOLS, responses):
                    if isinstance(resp, BaseException) or resp.status_code != 200:
                        continue
                    try:
                        data = resp.json()
                        result[sym.replace("USDT", "")] = {
                            "lastFundingRate": float(data["lastFundingRate"]),
                            "markPrice": float(data["markPrice"]),
                            "indexPrice": float(data["indexPrice"]),
                        }
                    except (ValueError, KeyError, TypeError):
                        # Skip just this symbol instead of losing the whole snapshot.
                        log.warning("Malformed premiumIndex payload for %s", sym)
                if result:
                    await save_snapshot(result)
            except Exception:
                log.exception("Snapshot loop iteration failed")
            await asyncio.sleep(2)
@app.on_event("startup")
async def startup():
    """Initialise the PostgreSQL schema and auth tables, open the asyncpg pool, and
    start the background snapshot loop."""
    init_schema()
    ensure_auth_tables()
    await get_async_pool()
    # The event loop keeps only weak references to tasks. Storing a strong one on
    # app.state prevents the loop from being garbage-collected mid-run and lets
    # shutdown() cancel it.
    app.state.snapshot_task = asyncio.create_task(background_snapshot_loop())
@app.on_event("shutdown")
async def shutdown():
    """Stop the background snapshot task (if startup recorded one), then close the asyncpg pool.

    The task is cancelled first so it cannot try to write through a pool that is
    already closed.
    """
    task = getattr(app.state, "snapshot_task", None)
    if task is not None and not task.done():
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
    await close_async_pool()
@app.get("/api/health")
async def health():
    """Liveness probe: {"status": "ok", "timestamp": naive-UTC ISO-8601 string}."""
    stamp = datetime.utcnow().isoformat()
    return {"status": "ok", "timestamp": stamp}
@app.get("/api/rates")
async def get_rates():
    """Return live mark/index price and funding rate for every symbol in SYMBOLS.

    The result is cached for 3 seconds and also persisted via save_snapshot in a
    background task.

    Raises:
        HTTPException(502): an upstream request failed (network error or non-200).
    """
    cached = get_cache("rates", 3)
    if cached:
        return cached
    async with httpx.AsyncClient(timeout=10, headers=HEADERS) as client:
        tasks = [client.get(f"{BINANCE_FAPI}/premiumIndex", params={"symbol": s}) for s in SYMBOLS]
        # return_exceptions=True so a connection error maps to 502 instead of an unhandled 500.
        responses = await asyncio.gather(*tasks, return_exceptions=True)
    result = {}
    for sym, resp in zip(SYMBOLS, responses):
        if isinstance(resp, BaseException) or resp.status_code != 200:
            raise HTTPException(status_code=502, detail=f"Binance error for {sym}")
        data = resp.json()
        key = sym.replace("USDT", "")
        result[key] = {
            "symbol": sym,
            "markPrice": float(data["markPrice"]),
            "indexPrice": float(data["indexPrice"]),
            "lastFundingRate": float(data["lastFundingRate"]),
            "nextFundingTime": data["nextFundingTime"],
            "timestamp": data["time"],
        }
    set_cache("rates", result)
    # Fire-and-forget persistence. Keep a strong reference until the task finishes,
    # because the event loop only holds weak references to tasks.
    pending = getattr(app.state, "pending_snapshot_writes", None)
    if pending is None:
        pending = set()
        app.state.pending_snapshot_writes = pending
    task = asyncio.create_task(save_snapshot(result))
    pending.add(task)
    task.add_done_callback(pending.discard)
    return result
@app.get("/api/snapshots")
async def get_snapshots(hours: int = 24, limit: int = 5000, user: dict = Depends(get_current_user)):
    """Return raw rate_snapshots rows from the last `hours` hours, oldest first, at most `limit` rows."""
    window_start = int(time.time()) - hours * 3600
    snapshot_rows = await async_fetch(
        "SELECT ts, btc_rate, eth_rate, btc_price, eth_price FROM rate_snapshots "
        "WHERE ts >= $1 ORDER BY ts ASC LIMIT $2",
        window_start,
        limit,
    )
    payload = {"count": len(snapshot_rows), "hours": hours}
    payload["data"] = snapshot_rows
    return payload
@app.get("/api/kline")
async def get_kline(symbol: str = "BTC", interval: str = "5m", limit: int = 500, user: dict = Depends(get_current_user)):
    """Build OHLC bars of funding rate and mark price from rate_snapshots.

    Only BTC and ETH columns exist in rate_snapshots, so any other symbol gets a 400.
    Unknown intervals fall back to 300-second bars. The last `limit` bars are
    returned; rate OHLC values are multiplied by 10000 and rounded to 4 decimals,
    while price OHLC values are returned as stored.
    """
    sym = symbol.upper()
    if sym not in ("BTC", "ETH"):
        raise HTTPException(
            status_code=400,
            detail=f"K线数据仅支持 BTC / ETH暂不支持 {sym}。XRP/SOL 的 K 线功能在 V5.3 中规划。"
        )
    step = {
        "1m": 60, "5m": 300, "30m": 1800,
        "1h": 3600, "4h": 14400, "8h": 28800,
        "1d": 86400, "1w": 604800, "1M": 2592000,
    }.get(interval, 300)
    prefix = "btc" if sym == "BTC" else "eth"
    rows = await async_fetch(
        f"SELECT ts, {prefix}_rate as rate, {prefix}_price as price FROM rate_snapshots "
        f"WHERE ts >= $1 ORDER BY ts ASC",
        int(time.time()) - step * limit,
    )
    if not rows:
        return {"symbol": symbol, "interval": interval, "data": []}

    bars: dict = {}
    for row in rows:
        rate, price = row["rate"], row["price"]
        start = row["ts"] // step * step
        bar = bars.get(start)
        if bar is None:
            # First sample in this bucket seeds every OHLC field.
            bars[start] = {
                "time": start,
                "open": rate, "high": rate, "low": rate, "close": rate,
                "price_open": price, "price_high": price, "price_low": price, "price_close": price,
            }
            continue
        bar["high"] = max(bar["high"], rate)
        bar["low"] = min(bar["low"], rate)
        bar["close"] = rate
        bar["price_high"] = max(bar["price_high"], price)
        bar["price_low"] = min(bar["price_low"], price)
        bar["price_close"] = price

    ordered = sorted(bars.values(), key=lambda b: b["time"])
    data = ordered[-limit:]
    for bar in data:
        for field in ("open", "high", "low", "close"):
            bar[field] = round(bar[field] * 10000, 4)
    return {"symbol": symbol, "interval": interval, "count": len(data), "data": data}
@app.get("/api/stats/ytd")
async def get_stats_ytd(user: dict = Depends(get_current_user)):
    """Year-to-date mean funding rate per symbol, annualised in percent (cached 1 hour).

    Returns {short_symbol: {"annualized": pct, "count": n}}. Annualisation assumes
    3 fundings per day. A symbol whose upstream fetch fails, or that has no data,
    reports zeros. A result containing a failure is not cached, so the next request
    retries instead of serving zeros for an hour.
    """
    cached = get_cache("stats_ytd", 3600)
    if cached:
        return cached
    # Compute the year start in UTC. A naive datetime's .timestamp() is interpreted as
    # local time, which would shift the window by the server's UTC offset.
    utc = _dt.timezone.utc
    year_start = int(_dt.datetime(_dt.datetime.now(utc).year, 1, 1, tzinfo=utc).timestamp() * 1000)
    end_time = int(time.time() * 1000)
    page_size = 1000  # maximum rows Binance returns per fundingRate call

    async def fetch_year(client, symbol):
        """Page through /fundingRate from year_start to end_time; None on any non-200."""
        rates = []
        start = year_start
        while start <= end_time:
            resp = await client.get(
                f"{BINANCE_FAPI}/fundingRate",
                params={"symbol": symbol, "startTime": start, "endTime": end_time, "limit": page_size},
            )
            if resp.status_code != 200:
                return None
            page = resp.json()
            rates.extend(float(item["fundingRate"]) for item in page)
            if len(page) < page_size:
                break
            # A full page may mean more data: continue just after the last funding time.
            start = int(page[-1]["fundingTime"]) + 1
        return rates

    async with httpx.AsyncClient(timeout=20, headers=HEADERS) as client:
        fetched = await asyncio.gather(*(fetch_year(client, s) for s in SYMBOLS), return_exceptions=True)

    result = {}
    any_failed = False
    for sym, rates in zip(SYMBOLS, fetched):
        key = sym.replace("USDT", "")
        if isinstance(rates, BaseException) or rates is None:
            any_failed = True
            result[key] = {"annualized": 0, "count": 0}
            continue
        if not rates:
            result[key] = {"annualized": 0, "count": 0}
            continue
        mean = sum(rates) / len(rates)
        result[key] = {"annualized": round(mean * 3 * 365 * 100, 2), "count": len(rates)}
    if not any_failed:
        set_cache("stats_ytd", result)
    return result
@app.get("/api/signals/history")
async def get_signals_history(
    limit: int = 100,
    symbol: str = None,
    strategy: str = None,
    user: dict = Depends(get_current_user),
):
    """List recent signal_indicators rows, newest first, with optional filters.

    Args:
        limit: maximum number of rows returned.
        symbol: optional filter. Accepts a short name ("BTC") or a full pair
            ("BTCUSDT"). "USDT" is appended when missing, because the other
            endpoints in this module query signal_indicators with full pair
            names such as "BTCUSDT".
        strategy: optional exact-match strategy filter.

    Returns {"items": rows}. On a database error the exception is logged and
    {"items": [], "error": message} is returned instead of raising.
    """
    try:
        conditions = []
        args = []
        if symbol:
            sym = symbol.upper()
            if not sym.endswith("USDT"):
                sym += "USDT"
            args.append(sym)
            conditions.append(f"symbol = ${len(args)}")
        if strategy:
            args.append(strategy)
            conditions.append(f"strategy = ${len(args)}")
        where = f"WHERE {' AND '.join(conditions)}" if conditions else ""
        args.append(limit)
        rows = await async_fetch(
            f"SELECT id, ts, symbol, strategy, score, signal, price, factors "
            f"FROM signal_indicators {where} ORDER BY ts DESC LIMIT ${len(args)}",
            *args
        )
        return {"items": rows}
    except Exception as e:
        logging.getLogger(__name__).exception("signal history query failed")
        return {"items": [], "error": str(e)}
@app.get("/api/history")
async def get_history(user: dict = Depends(get_current_user)):
    """Funding-rate history for the last 7 days for every symbol (cached 60 s).

    Returns {short_symbol: [{"fundingTime": ms, "fundingRate": float,
    "timestamp": naive-UTC ISO string}, ...]}.

    Raises:
        HTTPException(502): an upstream request failed (network error or non-200).
    """
    cached = get_cache("history", 60)
    if cached:
        return cached
    # Derive the window from epoch time. datetime.utcnow().timestamp() treats the naive
    # UTC value as local time and shifts the window by the server's UTC offset.
    end_time = int(time.time() * 1000)
    start_time = end_time - 7 * 24 * 3600 * 1000
    async with httpx.AsyncClient(timeout=15, headers=HEADERS) as client:
        tasks = [
            client.get(f"{BINANCE_FAPI}/fundingRate",
                       params={"symbol": s, "startTime": start_time, "endTime": end_time, "limit": 1000})
            for s in SYMBOLS
        ]
        responses = await asyncio.gather(*tasks, return_exceptions=True)
    result = {}
    for sym, resp in zip(SYMBOLS, responses):
        if isinstance(resp, BaseException) or resp.status_code != 200:
            raise HTTPException(status_code=502, detail=f"Binance history error for {sym}")
        result[sym.replace("USDT", "")] = [
            {"fundingTime": item["fundingTime"], "fundingRate": float(item["fundingRate"]),
             "timestamp": datetime.utcfromtimestamp(item["fundingTime"] / 1000).isoformat()}
            for item in resp.json()
        ]
    set_cache("history", result)
    return result
@app.get("/api/stats")
async def get_stats(user: dict = Depends(get_current_user)):
    """7-day funding statistics per symbol plus a BTC/ETH "combo" average (cached 60 s).

    Per symbol: mean7d (mean rate in percent), annualized (percent, assuming
    3 fundings per day) and count. "combo" is the simple average of the BTC and
    ETH values.

    Raises:
        HTTPException(502): an upstream request failed (network error or non-200).
    """
    cached = get_cache("stats", 60)
    if cached:
        return cached
    # Epoch-based window; datetime.utcnow().timestamp() would be off by the server's UTC offset.
    end_time = int(time.time() * 1000)
    start_time = end_time - 7 * 24 * 3600 * 1000
    async with httpx.AsyncClient(timeout=15, headers=HEADERS) as client:
        tasks = [
            client.get(f"{BINANCE_FAPI}/fundingRate",
                       params={"symbol": s, "startTime": start_time, "endTime": end_time, "limit": 1000})
            for s in SYMBOLS
        ]
        responses = await asyncio.gather(*tasks, return_exceptions=True)
    stats = {}
    for sym, resp in zip(SYMBOLS, responses):
        if isinstance(resp, BaseException) or resp.status_code != 200:
            raise HTTPException(status_code=502, detail=f"Binance stats error for {sym}")
        key = sym.replace("USDT", "")
        rates = [float(item["fundingRate"]) for item in resp.json()]
        if not rates:
            stats[key] = {"mean7d": 0, "annualized": 0, "count": 0}
            continue
        mean = sum(rates) / len(rates)
        annualized = mean * 3 * 365 * 100
        stats[key] = {
            "mean7d": round(mean * 100, 6),
            "annualized": round(annualized, 2),
            "count": len(rates),
        }
    btc_ann = stats.get("BTC", {}).get("annualized", 0)
    eth_ann = stats.get("ETH", {}).get("annualized", 0)
    btc_mean = stats.get("BTC", {}).get("mean7d", 0)
    eth_mean = stats.get("ETH", {}).get("mean7d", 0)
    stats["combo"] = {
        "mean7d": round((btc_mean + eth_mean) / 2, 6),
        "annualized": round((btc_ann + eth_ann) / 2, 2),
    }
    set_cache("stats", stats)
    return stats
# ─── aggTrades 查询接口PG版───────────────────────────────────
@app.get("/api/trades/meta")
async def get_trades_meta(user: dict = Depends(get_current_user)):
    """Return the aggTrades collector's progress per short symbol (last agg id, last trade time, update time)."""
    meta_rows = await async_fetch("SELECT symbol, last_agg_id, last_time_ms, updated_at FROM agg_trades_meta")
    return {
        row["symbol"].replace("USDT", ""): {
            "last_agg_id": row["last_agg_id"],
            "last_time_ms": row["last_time_ms"],
            "updated_at": row["updated_at"],
        }
        for row in meta_rows
    }
@app.get("/api/trades/summary")
async def get_trades_summary(
    symbol: str = "BTC",
    start_ms: int = 0,
    end_ms: int = 0,
    interval: str = "1m",
    user: dict = Depends(get_current_user),
):
    """Aggregate agg_trades into fixed-width bars of buy/sell volume, VWAP and max trade size.

    end_ms defaults to now and start_ms to one hour before end_ms whenever either
    is 0; the window is [start_ms, end_ms). Unknown intervals fall back to 60 s
    bars. Rows with is_buyer_maker = 0 count as buy volume and rows with 1 count
    as sell volume.
    """
    end_ms = end_ms or int(time.time() * 1000)
    start_ms = start_ms or end_ms - 3600 * 1000
    bucket_ms = {"1m": 60000, "5m": 300000, "15m": 900000, "1h": 3600000}.get(interval, 60000)
    # Aggregate inside PostgreSQL instead of looping over raw trades in Python.
    rows = await async_fetch(
        """
        SELECT
            (time_ms / $4) * $4 AS bar_ms,
            ROUND(SUM(CASE WHEN is_buyer_maker = 0 THEN qty ELSE 0 END)::numeric, 4) AS buy_vol,
            ROUND(SUM(CASE WHEN is_buyer_maker = 1 THEN qty ELSE 0 END)::numeric, 4) AS sell_vol,
            COUNT(*) AS trade_count,
            ROUND((SUM(price * qty) / NULLIF(SUM(qty), 0))::numeric, 2) AS vwap,
            ROUND(MAX(qty)::numeric, 4) AS max_qty
        FROM agg_trades
        WHERE symbol = $1 AND time_ms >= $2 AND time_ms < $3
        GROUP BY bar_ms
        ORDER BY bar_ms ASC
        """,
        symbol.upper() + "USDT", start_ms, end_ms, bucket_ms
    )

    def as_bar(row):
        buy_volume = float(row["buy_vol"])
        sell_volume = float(row["sell_vol"])
        return {
            "time_ms": row["bar_ms"],
            "buy_vol": buy_volume,
            "sell_vol": sell_volume,
            "delta": round(buy_volume - sell_volume, 4),
            "total_vol": round(buy_volume + sell_volume, 4),
            "trade_count": row["trade_count"],
            "vwap": float(row["vwap"]) if row["vwap"] else 0,
            "max_qty": float(row["max_qty"]),
        }

    bars = [as_bar(row) for row in rows]
    return {"symbol": symbol, "interval": interval, "count": len(bars), "data": bars}
@app.get("/api/trades/latest")
async def get_trades_latest(
    symbol: str = "BTC",
    limit: int = 30,
    user: dict = Depends(get_current_user),
):
    """Return the newest `limit` aggTrades for `symbol` (short name such as "BTC"), newest first.

    Responses are cached for 2 seconds, but only for symbols in SYMBOLS with
    0 < limit <= 1000. The cache key is built from raw query parameters and the
    module cache is never pruned, so caching arbitrary values would let clients
    grow it without bound. Uncached requests return the same data, just read
    straight from the database.
    """
    max_cached_limit = 1000
    sym_full = symbol.upper() + "USDT"
    cacheable = sym_full in SYMBOLS and 0 < limit <= max_cached_limit
    cache_key = f"trades_latest_{symbol}_{limit}"
    if cacheable:
        cached = get_cache(cache_key, 2)
        if cached:
            return cached
    rows = await async_fetch(
        "SELECT agg_id, price, qty, time_ms, is_buyer_maker FROM agg_trades "
        "WHERE symbol = $1 ORDER BY time_ms DESC, agg_id DESC LIMIT $2",
        sym_full, limit
    )
    result = {"symbol": symbol, "count": len(rows), "data": rows}
    if cacheable:
        set_cache(cache_key, result)
    return result
@app.get("/api/collector/health")
async def collector_health(user: dict = Depends(get_current_user)):
    """Report per-symbol collector lag; a symbol is healthy when its last trade is under 30 s old.

    A missing last_time_ms is treated as 0, so the lag is then measured from the epoch.
    """
    now_ms = int(time.time() * 1000)
    meta_rows = await async_fetch("SELECT symbol, last_agg_id, last_time_ms FROM agg_trades_meta")
    collector = {}
    for row in meta_rows:
        lag_seconds = (now_ms - (row["last_time_ms"] or 0)) / 1000
        collector[row["symbol"].replace("USDT", "")] = {
            "last_agg_id": row["last_agg_id"],
            "lag_seconds": round(lag_seconds, 1),
            "healthy": lag_seconds < 30,
        }
    return {"collector": collector, "timestamp": now_ms}
# ─── V5 Signal Engine API ────────────────────────────────────────
@app.get("/api/signals/indicators")
async def get_signal_indicators(
    symbol: str = "BTC",
    minutes: int = 60,
    user: dict = Depends(get_current_user),
):
    """Return signal_indicators_1m rows for `symbol` over the last `minutes` minutes, oldest first."""
    window_start_ms = int(time.time() * 1000) - minutes * 60 * 1000
    rows = await async_fetch(
        "SELECT ts, cvd_fast, cvd_mid, cvd_day, atr_5m, vwap_30m, price, score, signal "
        "FROM signal_indicators_1m WHERE symbol = $1 AND ts >= $2 ORDER BY ts ASC",
        symbol.upper() + "USDT",
        window_start_ms,
    )
    return {"symbol": symbol, "count": len(rows), "data": rows}
@app.get("/api/signals/latest")
async def get_signal_latest(user: dict = Depends(get_current_user), strategy: str = "v53"):
    """Return the newest signal_indicators row per symbol for `strategy`, keyed by short symbol.

    A JSON-string `factors` column is decoded when possible. For strategies whose
    name starts with "v53" and whose factors decode to a dict, two fields are
    added. `display_score` is factors["alt_score_ref"] when the row's score is 0
    and that reference is present, otherwise the score. `gate_passed` is
    factors["gate_passed"], defaulting to True. Symbols with no row are omitted.
    """
    latest = {}
    for full_symbol in SYMBOLS:
        row = await async_fetchrow(
            "SELECT ts, cvd_fast, cvd_mid, cvd_day, cvd_fast_slope, atr_5m, atr_percentile, "
            "vwap_30m, price, p95_qty, p99_qty, score, signal, factors "
            "FROM signal_indicators WHERE symbol = $1 AND strategy = $2 ORDER BY ts DESC LIMIT 1",
            full_symbol, strategy
        )
        if not row:
            continue
        record = dict(row)
        raw_factors = record.get("factors")
        if isinstance(raw_factors, str):
            try:
                record["factors"] = json.loads(raw_factors)
            except Exception:
                pass
        factors = record.get("factors")
        if strategy.startswith("v53") and isinstance(factors, dict):
            score = record.get("score", 0)
            alt_ref = factors.get("alt_score_ref")
            record["display_score"] = alt_ref if score == 0 and alt_ref is not None else score
            record["gate_passed"] = factors.get("gate_passed", True)
        latest[full_symbol.replace("USDT", "")] = record
    return latest
def _primary_signal_strategy() -> str:
strategy_dir = os.path.join(os.path.dirname(__file__), "strategies")
try:
names = []
for fn in os.listdir(strategy_dir):
if not fn.endswith(".json"):
continue
with open(os.path.join(strategy_dir, fn), "r", encoding="utf-8") as f:
cfg = json.load(f)
if cfg.get("name"):
names.append(cfg["name"])
if "v52_8signals" in names:
return "v52_8signals"
if "v51_baseline" in names:
return "v51_baseline"
except Exception:
pass
return "v51_baseline"
def _normalize_factors(raw):
if not raw:
return {}
if isinstance(raw, str):
try:
return json.loads(raw)
except Exception:
return {}
if isinstance(raw, dict):
return raw
return {}
def _strategy_payload(trade_row, base_row, use_base_row):
    """Build one strategy card for /api/signals/latest-v52.

    Uses the strategy's newest paper trade when there is one. Otherwise, if
    `use_base_row` is true (this is the primary strategy) and a signal_indicators
    row exists, uses that row. Otherwise the card reports "unavailable". The
    funding_rate / liquidation factor scores come from the paper trade's
    score_factors and are None when absent or not shaped as {"score": ...}.
    """
    if trade_row:
        payload = {
            "score": trade_row.get("score"),
            "signal": trade_row.get("direction"),
            "ts": trade_row.get("entry_ts"),
            "source": "paper_trade",
        }
    elif base_row and use_base_row:
        payload = {
            "score": base_row.get("score"),
            "signal": base_row.get("signal"),
            "ts": base_row.get("ts"),
            "source": "signal_indicators",
        }
    else:
        payload = {
            "score": None,
            "signal": None,
            "ts": None,
            "source": "unavailable",
        }
    factors = _normalize_factors(trade_row.get("score_factors") if trade_row else None)

    def factor_score(name):
        # Guard both levels: stored factors may not be dicts.
        entry = factors.get(name) if isinstance(factors, dict) else None
        return entry.get("score") if isinstance(entry, dict) else None

    payload["funding_rate_score"] = factor_score("funding_rate")
    payload["liquidation_score"] = factor_score("liquidation")
    return payload


@app.get("/api/signals/latest-v52")
async def get_signal_latest_v52(user: dict = Depends(get_current_user)):
    """Latest signal info per symbol for the V5.1 / V5.2 side-by-side view.

    For each symbol in SYMBOLS the response contains:
    - the primary strategy name;
    - the newest signal_indicators signal and timestamp;
    - a card for each of v51_baseline and v52_8signals (see _strategy_payload).
    """
    primary_strategy = _primary_signal_strategy()
    result = {}
    for sym in SYMBOLS:
        base_row = await async_fetchrow(
            "SELECT ts, score, signal FROM signal_indicators WHERE symbol = $1 ORDER BY ts DESC LIMIT 1",
            sym,
        )
        # DISTINCT ON yields only the newest paper trade per strategy, instead of pulling
        # the symbol's whole trade history into Python and scanning it.
        strategy_rows = await async_fetch(
            "SELECT DISTINCT ON (strategy) strategy, score, direction, entry_ts, score_factors "
            "FROM paper_trades WHERE symbol = $1 AND strategy IN ('v51_baseline','v52_8signals') "
            "ORDER BY strategy, entry_ts DESC",
            sym,
        )
        latest_by_strategy = {row.get("strategy"): row for row in strategy_rows}
        result[sym.replace("USDT", "")] = {
            "primary_strategy": primary_strategy,
            "latest_signal": base_row.get("signal") if base_row else None,
            "latest_ts": base_row.get("ts") if base_row else None,
            "v51": _strategy_payload(
                latest_by_strategy.get("v51_baseline"), base_row, primary_strategy == "v51_baseline"
            ),
            "v52": _strategy_payload(
                latest_by_strategy.get("v52_8signals"), base_row, primary_strategy == "v52_8signals"
            ),
        }
    return result
@app.get("/api/signals/market-indicators")
async def get_market_indicators(user: dict = Depends(get_current_user)):
    """Return the newest market_indicators value per symbol and indicator type.

    These cover the data sources added in V5.1. Output shape:
    {short_symbol: {indicator_type: {"value": ..., "ts": timestamp_ms}}}.
    Indicator types with no rows are omitted. String values are JSON-decoded
    when valid and returned unchanged otherwise.
    """
    indicator_types = (
        "long_short_ratio", "top_trader_position", "open_interest_hist", "coinbase_premium", "funding_rate",
    )
    result = {}
    for sym in SYMBOLS:
        indicators = {}
        for ind_type in indicator_types:
            row = await async_fetchrow(
                "SELECT value, timestamp_ms FROM market_indicators WHERE symbol = $1 AND indicator_type = $2 ORDER BY timestamp_ms DESC LIMIT 1",
                sym,
                ind_type,
            )
            if not row:
                continue
            val = row["value"]
            if isinstance(val, str):
                # json is imported at module level; re-importing it per row was redundant.
                try:
                    val = json.loads(val)
                except ValueError:
                    pass
            indicators[ind_type] = {"value": val, "ts": row["timestamp_ms"]}
        result[sym.replace("USDT", "")] = indicators
    return result
@app.get("/api/signals/signal-history")
async def get_signal_history(
    symbol: str = "BTC",
    limit: int = 50,
    strategy: str = "v53",
    user: dict = Depends(get_current_user),
):
    """Return recent rows that carry a signal (signal IS NOT NULL), newest first, with their factors.

    JSON-string factors are decoded in place when valid; invalid strings are left as-is.
    """
    rows = await async_fetch(
        "SELECT ts, score, signal, factors FROM signal_indicators "
        "WHERE symbol = $1 AND strategy = $2 AND signal IS NOT NULL "
        "ORDER BY ts DESC LIMIT $3",
        symbol.upper() + "USDT", strategy, limit
    )
    for row in rows:
        raw_factors = row.get("factors")
        if not isinstance(raw_factors, str):
            continue
        try:
            row["factors"] = json.loads(raw_factors)
        except Exception:
            pass
    return {"symbol": symbol, "count": len(rows), "data": rows}
@app.get("/api/signals/trades")
async def get_signal_trades(
    status: str = "all",
    limit: int = 50,
    user: dict = Depends(get_current_user),
):
    """List signal_trades rows, newest ts_open first, optionally filtered by exact status ("all" = no filter)."""
    if status == "all":
        query, args = "SELECT * FROM signal_trades ORDER BY ts_open DESC LIMIT $1", (limit,)
    else:
        query, args = (
            "SELECT * FROM signal_trades WHERE status = $1 ORDER BY ts_open DESC LIMIT $2",
            (status, limit),
        )
    rows = await async_fetch(query, *args)
    return {"count": len(rows), "data": rows}
# ─── 模拟盘 API ──────────────────────────────────────────────────
# Paper-trading (simulated) config: runtime state shared with signal_engine through
# paper_config.json, which the POST /api/paper/config handler writes.
paper_config = {
    "enabled": False,
    "enabled_strategies": [],  # per-strategy switches, e.g. ["v51_baseline", "v52_8signals"]
    "initial_balance": 10000,
    "risk_per_trade": 0.02,
    "max_positions": 4,
    "tier_multiplier": {"light": 0.5, "standard": 1.0, "heavy": 1.5},
}
# Load any previously saved config at import time. A missing file keeps the defaults;
# an unreadable, corrupt, or non-object file is logged and the defaults are kept.
_config_path = os.path.join(os.path.dirname(__file__), "paper_config.json")
if os.path.exists(_config_path):
    try:
        with open(_config_path, "r", encoding="utf-8") as _f:
            _saved = json.load(_f)
        if isinstance(_saved, dict):
            paper_config.update(_saved)
        else:
            logging.getLogger(__name__).warning(
                "Ignoring %s: expected a JSON object, got %s", _config_path, type(_saved).__name__
            )
    except (OSError, ValueError) as _exc:
        logging.getLogger(__name__).warning("Could not load %s, using defaults: %s", _config_path, _exc)
@app.get("/api/paper/config")
async def paper_get_config(user: dict = Depends(get_current_user)):
    """Return the live paper-trading config dict (the same object the POST handler updates)."""
    current_config = paper_config
    return current_config
@app.post("/api/paper/config")
async def paper_set_config(request: Request, user: dict = Depends(get_current_user)):
    """Update the paper-trading config (admin only) and persist it to paper_config.json.

    Only enabled, enabled_strategies, initial_balance, risk_per_trade and
    max_positions are accepted; other body keys are ignored. Every supplied value
    is type-checked before anything is applied, so a bad request cannot leave a
    half-updated config. The file is written to a temp file and then moved into
    place with os.replace, so readers never see a partially written JSON file.

    Raises:
        HTTPException(403): caller is not an admin.
        HTTPException(400): body is not a JSON object or a value has the wrong type.
    """
    if user.get("role") != "admin":
        raise HTTPException(status_code=403, detail="仅管理员可修改")
    try:
        body = await request.json()
    except ValueError:
        raise HTTPException(status_code=400, detail="request body must be valid JSON")
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="request body must be a JSON object")

    def is_number(v):
        # bool is a subclass of int; reject it for numeric fields.
        return isinstance(v, (int, float)) and not isinstance(v, bool)

    validators = {
        "enabled": lambda v: isinstance(v, bool),
        "enabled_strategies": lambda v: isinstance(v, list) and all(isinstance(s, str) for s in v),
        "initial_balance": lambda v: is_number(v) and v > 0,
        "risk_per_trade": lambda v: is_number(v) and v >= 0,
        "max_positions": lambda v: isinstance(v, int) and not isinstance(v, bool) and v >= 0,
    }
    updates = {}
    for key, is_valid in validators.items():
        if key in body:
            if not is_valid(body[key]):
                raise HTTPException(status_code=400, detail=f"invalid value for {key}")
            updates[key] = body[key]
    paper_config.update(updates)
    # Persist so signal_engine can read the new config too.
    config_path = os.path.join(os.path.dirname(__file__), "paper_config.json")
    tmp_path = config_path + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(paper_config, f, indent=2)
    os.replace(tmp_path, config_path)
    return {"ok": True, "config": paper_config}
@app.get("/api/paper/summary")
async def paper_summary(
    strategy: str = "all",
    strategy_id: str = "all",
    user: dict = Depends(get_current_user),
):
    """Paper-account overview: trade count, win rate, PnL in R and USDT, balance and profit factor.

    Filters on strategy_id when it is not "all" (this takes precedence), otherwise
    on strategy when it is not "all". With a strategy_id filter, initial_balance
    comes from the strategies table when a matching row exists; otherwise from
    paper_config. 1R in USDT = initial_balance * risk_per_trade. Trades with status
    'active' or 'tp1_hit' count as open; everything else counts as closed.
    profit_factor is 0 when there are no losing trades.
    """
    initial_balance = paper_config["initial_balance"]
    risk_per_trade = paper_config["risk_per_trade"]
    if strategy_id != "all":
        column, value = "strategy_id", strategy_id
    elif strategy != "all":
        column, value = "strategy", strategy
    else:
        column, value = None, None

    closed_sql = "SELECT pnl_r, direction FROM paper_trades WHERE status NOT IN ('active','tp1_hit')"
    active_sql = "SELECT id FROM paper_trades WHERE status IN ('active','tp1_hit')"
    first_sql = "SELECT MIN(created_at) as start FROM paper_trades"
    if column is None:
        closed = await async_fetch(closed_sql)
        active = await async_fetch(active_sql)
        first = await async_fetchrow(first_sql)
    else:
        closed = await async_fetch(f"{closed_sql} AND {column} = $1", value)
        active = await async_fetch(f"{active_sql} AND {column} = $1", value)
        first = await async_fetchrow(f"{first_sql} WHERE {column} = $1", value)
        if column == "strategy_id":
            # Per-strategy starting balance, when the strategy is registered.
            strat_row = await async_fetchrow(
                "SELECT initial_balance FROM strategies WHERE strategy_id = $1",
                strategy_id,
            )
            if strat_row:
                initial_balance = float(strat_row["initial_balance"])

    pnls = [r["pnl_r"] for r in closed]
    trade_count = len(pnls)
    win_count = sum(1 for p in pnls if p > 0)
    pnl_r_total = sum(pnls)
    one_r_usdt = initial_balance * risk_per_trade
    pnl_usdt = pnl_r_total * one_r_usdt
    gross_profit = sum(p for p in pnls if p > 0)
    gross_loss = abs(sum(p for p in pnls if p <= 0))
    win_rate = win_count / trade_count * 100 if trade_count else 0
    profit_factor = gross_profit / gross_loss if gross_loss > 0 else 0
    started = first["start"] if first else None
    return {
        "total_trades": trade_count,
        "win_rate": round(win_rate, 1),
        "total_pnl": round(pnl_r_total, 2),
        "total_pnl_usdt": round(pnl_usdt, 2),
        "balance": round(initial_balance + pnl_usdt, 2),
        "active_positions": len(active),
        "profit_factor": round(profit_factor, 2),
        "start_time": str(started) if started else None,
    }
@app.get("/api/paper/positions")
async def paper_positions(
    strategy: str = "all",
    strategy_id: str = "all",
    user: dict = Depends(get_current_user),
):
    """Open paper positions (status 'active' or 'tp1_hit') with live price and unrealized PnL.

    Prices come from Binance /ticker/price. For a symbol missing there, the latest
    signal_indicators price is used instead. When no positive price is available,
    current_price is 0 and unrealized PnL is reported as 0 rather than computed
    against a zero price (which would show a bogus near-total loss). USDT PnL uses
    the global paper_config 1R (initial_balance * risk_per_trade).
    """
    log = logging.getLogger(__name__)
    select_open = (
        "SELECT id, symbol, direction, score, tier, strategy, strategy_id, entry_price, entry_ts, "
        "tp1_price, tp2_price, sl_price, tp1_hit, status, atr_at_entry, score_factors, risk_distance "
        "FROM paper_trades WHERE status IN ('active','tp1_hit')"
    )
    if strategy_id != "all":
        rows = await async_fetch(select_open + " AND strategy_id = $1 ORDER BY entry_ts DESC", strategy_id)
    elif strategy == "all":
        rows = await async_fetch(select_open + " ORDER BY entry_ts DESC")
    else:
        rows = await async_fetch(select_open + " AND strategy = $1 ORDER BY entry_ts DESC", strategy)

    prices = {}
    symbols_needed = {r["symbol"] for r in rows}
    if symbols_needed:
        try:
            async with httpx.AsyncClient(timeout=5, headers=HEADERS) as client:
                resp = await client.get(f"{BINANCE_FAPI}/ticker/price")
            if resp.status_code == 200:
                for item in resp.json():
                    if item["symbol"] in symbols_needed:
                        prices[item["symbol"]] = float(item["price"])
        except Exception as exc:
            log.warning("Binance ticker fetch failed, falling back to signal_indicators: %s", exc)
    # Fallback: latest stored price for symbols Binance did not provide.
    for sym in symbols_needed - prices.keys():
        try:
            latest = await async_fetchrow(
                "SELECT price FROM signal_indicators WHERE symbol=$1 ORDER BY ts DESC LIMIT 1", sym
            )
            stored = latest["price"] if latest else None
            prices[sym] = float(stored) if stored is not None else 0
        except Exception:
            prices[sym] = 0

    paper_1r = paper_config["initial_balance"] * paper_config["risk_per_trade"]
    result = []
    for r in rows:
        d = dict(r)
        current_price = prices.get(r["symbol"], 0)
        d["current_price"] = current_price
        entry = r["entry_price"]
        # Risk distance (1R in price units); fall back to |entry - stop| and finally 1.
        rd = r.get("risk_distance") or abs(entry - r["sl_price"]) or 1
        if rd > 0 and entry > 0 and current_price > 0:
            move = current_price - entry if r["direction"] == "LONG" else entry - current_price
            d["unrealized_pnl_r"] = round(move / rd, 2)
            d["unrealized_pnl_usdt"] = round(d["unrealized_pnl_r"] * paper_1r, 2)
        else:
            d["unrealized_pnl_r"] = 0
            d["unrealized_pnl_usdt"] = 0
        result.append(d)
    return {"data": result}
@app.get("/api/paper/trades")
async def paper_trades(
    symbol: str = "all",
    result: str = "all",
    strategy: str = "all",
    strategy_id: str = "all",
    limit: int = 100,
    user: dict = Depends(get_current_user),
):
    """Closed paper trades, newest exit first.

    Filters (each "all" = no filter):
    - symbol: short name, "USDT" is appended;
    - result: "win" keeps pnl_r > 0, "loss" keeps pnl_r <= 0;
    - strategy_id, which takes precedence over strategy.
    """
    filters = ["status NOT IN ('active','tp1_hit')"]
    args = []

    def bind(value):
        # Append a query argument and return its positional placeholder ($1, $2, ...).
        args.append(value)
        return f"${len(args)}"

    if symbol != "all":
        filters.append(f"symbol = {bind(symbol.upper() + 'USDT')}")
    if result == "win":
        filters.append("pnl_r > 0")
    elif result == "loss":
        filters.append("pnl_r <= 0")
    if strategy_id != "all":
        filters.append(f"strategy_id = {bind(strategy_id)}")
    elif strategy != "all":
        filters.append(f"strategy = {bind(strategy)}")
    where_clause = " AND ".join(filters)
    limit_placeholder = bind(limit)
    rows = await async_fetch(
        "SELECT id, symbol, direction, score, tier, strategy, strategy_id, entry_price, exit_price, "
        "entry_ts, exit_ts, pnl_r, status, tp1_hit, score_factors "
        f"FROM paper_trades WHERE {where_clause} ORDER BY exit_ts DESC LIMIT {limit_placeholder}",
        *args
    )
    return {"count": len(rows), "data": rows}
@app.get("/api/paper/equity-curve")
async def paper_equity_curve(
    strategy: str = "all",
    strategy_id: str = "all",
    user: dict = Depends(get_current_user),
):
    """Cumulative closed-trade PnL (in R) ordered by exit time, one point per trade.

    strategy_id takes precedence over strategy; "all" means no filter.
    """
    closed_sql = "SELECT exit_ts, pnl_r FROM paper_trades WHERE status NOT IN ('active','tp1_hit')"
    if strategy_id != "all":
        rows = await async_fetch(closed_sql + " AND strategy_id = $1 ORDER BY exit_ts ASC", strategy_id)
    elif strategy == "all":
        rows = await async_fetch(closed_sql + " ORDER BY exit_ts ASC")
    else:
        rows = await async_fetch(closed_sql + " AND strategy = $1 ORDER BY exit_ts ASC", strategy)
    running_total = 0.0
    points = []
    for row in rows:
        running_total = running_total + row["pnl_r"]
        points.append({"ts": row["exit_ts"], "pnl": round(running_total, 2)})
    return {"data": points}
def _paper_trade_stats(trade_list):
    """Summarise closed paper trades (dicts with pnl_r, direction, exit_ts), all PnL in R units.

    A trade with pnl_r > 0 is a win; every other trade is a loss. Returns:
    - total, win_rate (%), avg_win / avg_loss (positive R) and win_loss_ratio;
    - mdd: maximum drawdown of the cumulative R curve ordered by exit_ts, with
      missing exit_ts sorted first;
    - sharpe: mean/stdev of per-trade R times sqrt(252);
    - total_pnl, plus long/short win rates (%) and counts.
    """
    import statistics

    t = len(trade_list)
    wins = [r for r in trade_list if r["pnl_r"] > 0]
    losses = [r for r in trade_list if r["pnl_r"] <= 0]
    aw = sum(r["pnl_r"] for r in wins) / len(wins) if wins else 0
    al = abs(sum(r["pnl_r"] for r in losses)) / len(losses) if losses else 0
    wlr = aw / al if al > 0 else 0
    # Max drawdown of the running R total in exit order.
    pk, dd, rn = 0.0, 0.0, 0.0
    for r in sorted(trade_list, key=lambda x: x["exit_ts"] or 0):
        rn += r["pnl_r"]
        pk = max(pk, rn)
        dd = max(dd, pk - rn)
    # NOTE(review): sqrt(252) annualises daily returns, but each sample here is one trade.
    rets = [r["pnl_r"] for r in trade_list]
    if len(rets) > 1:
        avg_r = statistics.mean(rets)
        std_r = statistics.stdev(rets)
        sp = (avg_r / std_r) * (252 ** 0.5) if std_r > 0 else 0
    else:
        sp = 0
    lg = [r for r in trade_list if r["direction"] == "LONG"]
    sh = [r for r in trade_list if r["direction"] == "SHORT"]
    lwr = len([r for r in lg if r["pnl_r"] > 0]) / len(lg) * 100 if lg else 0
    swr = len([r for r in sh if r["pnl_r"] > 0]) / len(sh) * 100 if sh else 0
    total_pnl = sum(r["pnl_r"] for r in trade_list)
    return {
        "total": t, "win_rate": round(len(wins) / t * 100, 1) if t else 0,
        "avg_win": round(aw, 2), "avg_loss": round(al, 2),
        "win_loss_ratio": round(wlr, 2), "mdd": round(dd, 2),
        "sharpe": round(sp, 2), "total_pnl": round(total_pnl, 2),
        "long_win_rate": round(lwr, 1), "long_count": len(lg),
        "short_win_rate": round(swr, 1), "short_count": len(sh),
    }


@app.get("/api/paper/stats")
async def paper_stats(
    strategy: str = "all",
    strategy_id: str = "all",
    user: dict = Depends(get_current_user),
):
    """Detailed closed-trade statistics, overall plus per symbol and per tier.

    strategy_id takes precedence over strategy; "all" means no filter. The
    overall and per-symbol figures use the same calculation (_paper_trade_stats),
    replacing a verbatim duplicate that used to live in this function. Returns
    {"error": "暂无数据"} when there are no closed trades.
    """
    closed_sql = (
        "SELECT symbol, direction, pnl_r, tier, entry_ts, exit_ts "
        "FROM paper_trades WHERE status NOT IN ('active','tp1_hit')"
    )
    if strategy_id != "all":
        rows = await async_fetch(closed_sql + " AND strategy_id = $1", strategy_id)
    elif strategy == "all":
        rows = await async_fetch(closed_sql)
    else:
        rows = await async_fetch(closed_sql + " AND strategy = $1", strategy)
    if not rows:
        return {"error": "暂无数据"}

    by_symbol = {}
    by_tier = {}
    for r in rows:
        by_symbol.setdefault(r["symbol"].replace("USDT", ""), []).append(r)
        bucket = by_tier.setdefault(r["tier"], {"total": 0, "wins": 0})
        bucket["total"] += 1
        if r["pnl_r"] > 0:
            bucket["wins"] += 1

    summary = _paper_trade_stats(rows)
    summary["by_symbol"] = {s: _paper_trade_stats(tl) for s, tl in by_symbol.items()}
    summary["by_tier"] = {
        t: {"total": v["total"], "win_rate": round(v["wins"] / v["total"] * 100, 1)}
        for t, v in by_tier.items()
    }
    return summary
@app.get("/api/paper/stats-by-strategy")
async def paper_stats_by_strategy(user: dict = Depends(get_current_user)):
    """Per-strategy paper performance: closed-trade stats plus current open-position count.

    Rows with a NULL strategy are attributed to "v51_baseline". Strategies that
    only have open positions appear with zeroed stats. The output is sorted by
    strategy name.
    """
    closed = await async_fetch(
        "SELECT strategy, pnl_r FROM paper_trades WHERE status NOT IN ('active','tp1_hit')"
    )
    open_counts = await async_fetch(
        "SELECT strategy, COUNT(*) AS active_count FROM paper_trades "
        "WHERE status IN ('active','tp1_hit') GROUP BY strategy"
    )
    if not closed and not open_counts:
        return {"data": []}

    active_map = {}
    for row in open_counts:
        active_map[row["strategy"] or "v51_baseline"] = int(row["active_count"])
    pnl_by_strategy = {}
    for row in closed:
        pnl_by_strategy.setdefault(row["strategy"] or "v51_baseline", []).append(float(row["pnl_r"]))

    def summarize(name, pnls):
        winners = [p for p in pnls if p > 0]
        losers = [p for p in pnls if p <= 0]
        n = len(pnls)
        return {
            "strategy": name,
            "total": n,
            "win_rate": round((len(winners) / n) * 100, 1) if n else 0,
            "total_pnl": round(sum(pnls), 2),
            "avg_win": round(sum(winners) / len(winners) if winners else 0, 2),
            "avg_loss": round(abs(sum(losers) / len(losers)) if losers else 0, 2),
            "active_positions": active_map.get(name, 0),
        }

    stats = [summarize(name, pnls) for name, pnls in pnl_by_strategy.items()]
    stats.extend(
        {
            "strategy": name,
            "total": 0,
            "win_rate": 0,
            "total_pnl": 0,
            "avg_win": 0,
            "avg_loss": 0,
            "active_positions": count,
        }
        for name, count in active_map.items()
        if name not in pnl_by_strategy
    )
    stats.sort(key=lambda item: item["strategy"])
    return {"data": stats}
# ─── 服务器状态监控 ───────────────────────────────────────────────
import shutil, subprocess, psutil
# Server-status cache: the last /api/server/status payload ("data") and the
# time.time() at which it was built ("ts"), so repeated polls within a short
# window skip the slow probes.
_server_cache: dict = {"data": None, "ts": 0}
# pm2 executable path, memoized by _find_pm2_bin() once it has been found.
_PM2_BIN = None
def _find_pm2_bin():
    """Locate the pm2 executable so callers can avoid going through npx.

    Tries a few fixed install locations in order, then ``$PATH``.  The first
    hit is memoized in the module-level ``_PM2_BIN``.  When nothing is found,
    the sentinel string ``"npx pm2"`` is returned (and not memoized); callers
    detect it with ``startswith("npx")``.
    """
    global _PM2_BIN
    if not _PM2_BIN:
        known_locations = (
            "/home/fzq1228/.local/bin/pm2",
            "/usr/local/bin/pm2",
            "/usr/bin/pm2",
        )
        located = next((path for path in known_locations if os.path.exists(path)), None)
        if located is None:
            located = shutil.which("pm2")
        if not located:
            return "npx pm2"
        _PM2_BIN = located
    return _PM2_BIN
# Prime psutil's CPU sampling at startup. Later cpu_percent(interval=None)
# calls are non-blocking and report usage over the interval since the previous
# sample, so this first call establishes that baseline.
psutil.cpu_percent(interval=None)
def _server_status_pm2_procs() -> list:
    """Summarize pm2-managed processes from ``pm2 jlist`` (blocking, best effort).

    Returns an empty list if pm2 cannot be run, times out (5 s), or prints
    something other than the expected JSON list of process objects.
    """
    try:
        pm2_bin = _find_pm2_bin()
        cmd = ["npx", "pm2", "jlist"] if pm2_bin.startswith("npx") else [pm2_bin, "jlist"]
        completed = subprocess.run(cmd, capture_output=True, text=True, timeout=5)
        procs = []
        for p in json.loads(completed.stdout):
            env = p.get("pm2_env", {})
            monit = p.get("monit", {})
            procs.append({
                "name": p.get("name", ""),
                "status": env.get("status", "unknown"),
                "cpu": monit.get("cpu", 0),
                "memory_mb": round(monit.get("memory", 0) / 1024 / 1024, 1),
                "restarts": env.get("restart_time", 0),
                "uptime_ms": env.get("pm_uptime", 0),
                "pid": p.get("pid", 0),
            })
        return procs
    except Exception:
        # Best effort: the status page must still render without pm2 data.
        return []


def _server_status_backfill_running() -> bool:
    """Return True if any process command line contains "backfill_agg_trades" (blocking)."""
    try:
        for proc in psutil.process_iter(["pid", "cmdline"]):
            if "backfill_agg_trades" in " ".join(proc.info.get("cmdline") or []):
                return True
    except Exception:
        pass
    return False


async def _server_status_pg_info() -> dict:
    """Collect PostgreSQL size, estimated row counts and per-symbol agg_trades span.

    Row counts come from pg_stat_user_tables (planner estimates), which is
    much cheaper than COUNT(*).  Best effort: on a DB error, whatever was
    collected so far is returned.
    """
    pg_info = {}
    try:
        row = await async_fetchrow(
            "SELECT pg_database_size(current_database()) as db_size"
        )
        pg_info["db_size_mb"] = round(row["db_size"] / 1024 / 1024, 1) if row else 0
        row2 = await async_fetchrow(
            "SELECT SUM(n_live_tup)::bigint as cnt FROM pg_stat_user_tables WHERE relname LIKE 'agg_trades%'"
        )
        pg_info["agg_trades_count"] = row2["cnt"] if row2 and row2["cnt"] else 0
        row3 = await async_fetchrow(
            "SELECT n_live_tup::bigint as cnt FROM pg_stat_user_tables WHERE relname = 'rate_snapshots'"
        )
        pg_info["rate_snapshots_count"] = row3["cnt"] if row3 else 0
        # Latest/earliest stored trade time per symbol.
        meta_rows = await async_fetch("SELECT symbol, last_time_ms, earliest_time_ms FROM agg_trades_meta")
        pg_info["symbols"] = {}
        for m in meta_rows:
            last_ms = m["last_time_ms"]
            earliest_ms = m["earliest_time_ms"]
            # A NULL bound would raise TypeError and drop every later symbol.
            span_hours = (
                round((last_ms - earliest_ms) / 3600000, 1)
                if last_ms is not None and earliest_ms is not None
                else None
            )
            pg_info["symbols"][m["symbol"].replace("USDT", "")] = {
                "latest_ms": last_ms,
                "earliest_ms": earliest_ms,
                "span_hours": span_hours,
            }
    except Exception:
        pass
    return pg_info


@app.get("/api/server/status")
async def get_server_status(user: dict = Depends(get_current_user)):
    """Full server status: CPU, memory, disk, load, network, pm2, PostgreSQL and backfill.

    The payload is cached for 5 seconds.  The blocking probes (``pm2 jlist``
    with a 5 s timeout and a full process scan) run in the default thread-pool
    executor, so they no longer stall the event loop for other requests.
    """
    now = time.time()
    if _server_cache["data"] and (now - _server_cache["ts"]) < 5:
        return _server_cache["data"]

    loop = asyncio.get_running_loop()

    # CPU: non-blocking, reports the interval since the previous sample.
    cpu_percent = psutil.cpu_percent(interval=None)
    cpu_count = psutil.cpu_count()
    mem = psutil.virtual_memory()
    swap = psutil.swap_memory()
    disk = shutil.disk_usage("/")
    load1, load5, load15 = os.getloadavg()
    uptime_s = time.time() - psutil.boot_time()
    net = psutil.net_io_counters()

    pm2_procs = await loop.run_in_executor(None, _server_status_pm2_procs)
    pg_info = await _server_status_pg_info()
    backfill_running = await loop.run_in_executor(None, _server_status_backfill_running)

    result = {
        "timestamp": int(time.time() * 1000),
        "cpu": {
            "percent": cpu_percent,
            "cores": cpu_count,
        },
        "memory": {
            "total_gb": round(mem.total / 1024**3, 1),
            "used_gb": round(mem.used / 1024**3, 1),
            "percent": mem.percent,
            "swap_percent": swap.percent,
        },
        "disk": {
            "total_gb": round(disk.total / 1024**3, 1),
            "used_gb": round(disk.used / 1024**3, 1),
            "free_gb": round(disk.free / 1024**3, 1),
            "percent": round(disk.used / disk.total * 100, 1),
        },
        "load": {
            "load1": round(load1, 2),
            "load5": round(load5, 2),
            "load15": round(load15, 2),
        },
        "uptime_hours": round(uptime_s / 3600, 1),
        "network": {
            "bytes_sent_gb": round(net.bytes_sent / 1024**3, 2),
            "bytes_recv_gb": round(net.bytes_recv / 1024**3, 2),
        },
        "pm2": pm2_procs,
        "postgres": pg_info,
        "backfill_running": backfill_running,
    }
    _server_cache["data"] = result
    _server_cache["ts"] = now
    return result
# ============================================================
# 实盘 API/api/live/...
# ============================================================
@app.get("/api/live/summary")
async def live_summary(
    strategy: str = "v52_8signals",
    user: dict = Depends(get_current_user),
):
    """Live-trading overview for one strategy.

    Returns closed-trade count, win rate, total PnL (R, and USDT via
    ``_get_risk_usd()``), open position count, profit factor, fee and funding
    totals, the first trade's creation time, and the contents of
    /tmp/risk_guard_state.json (``{"status": "unknown"}`` if it cannot be read).
    """
    closed = await async_fetch(
        "SELECT pnl_r, direction, fee_usdt, funding_fee_usdt, slippage_bps "
        "FROM live_trades WHERE status NOT IN ('active','tp1_hit') AND strategy = $1",
        strategy,
    )
    active = await async_fetch(
        "SELECT id FROM live_trades WHERE status IN ('active','tp1_hit') AND strategy = $1",
        strategy,
    )
    first = await async_fetchrow(
        "SELECT MIN(created_at) as start FROM live_trades WHERE strategy = $1",
        strategy,
    )
    total = len(closed)
    wins = len([r for r in closed if r["pnl_r"] and r["pnl_r"] > 0])
    total_pnl = sum(r["pnl_r"] for r in closed if r["pnl_r"])
    total_fee = sum(r["fee_usdt"] or 0 for r in closed)
    total_funding = sum(r["funding_fee_usdt"] or 0 for r in closed)
    win_rate = (wins / total * 100) if total > 0 else 0
    gross_profit = sum(r["pnl_r"] for r in closed if r["pnl_r"] and r["pnl_r"] > 0)
    gross_loss = abs(sum(r["pnl_r"] for r in closed if r["pnl_r"] and r["pnl_r"] <= 0))
    profit_factor = (gross_profit / gross_loss) if gross_loss > 0 else 0

    # Risk-guard state (file presumably maintained by the risk_guard process).
    # A missing, unreadable or malformed file is reported as "unknown"; the
    # former bare `except:` also swallowed BaseExceptions such as cancellation.
    try:
        with open("/tmp/risk_guard_state.json") as f:
            risk_status = json.load(f)
    except (OSError, ValueError):
        risk_status = {"status": "unknown"}

    return {
        "total_trades": total,
        "win_rate": round(win_rate, 1),
        "total_pnl_r": round(total_pnl, 2),
        "total_pnl_usdt": round(total_pnl * (await _get_risk_usd()), 2),
        "active_positions": len(active),
        "profit_factor": round(profit_factor, 2),
        "total_fee_usdt": round(total_fee, 2),
        "total_funding_usdt": round(total_funding, 2),
        "start_time": str(first["start"]) if first and first["start"] else None,
        "risk_status": risk_status,
    }
@app.get("/api/live/positions")
async def live_positions(
    strategy: str = "v52_8signals",
    user: dict = Depends(get_current_user),
):
    """Open live positions of one strategy, marked to the current Binance price.

    Each open live_trades row (status 'active'/'tp1_hit') is returned with:
    ``current_price`` (0 if unavailable), ``unrealized_pnl_r`` /
    ``unrealized_pnl_usdt`` (0 when price, entry or risk distance is
    missing), and ``hold_time_min`` when ``entry_ts`` (ms) is set.
    """
    rows = await async_fetch(
        "SELECT id, symbol, direction, score, tier, strategy, entry_price, entry_ts, "
        "tp1_price, tp2_price, sl_price, tp1_hit, status, risk_distance, "
        "binance_order_id, fill_price, slippage_bps, protection_gap_ms, "
        "signal_to_order_ms, order_to_fill_ms, score_factors "
        "FROM live_trades WHERE status IN ('active','tp1_hit') AND strategy = $1 "
        "ORDER BY entry_ts DESC",
        strategy,
    )

    # Latest prices: one ticker request, filtered to the symbols held (set
    # membership rather than a list scan per ticker entry).
    prices = {}
    symbols_needed = {r["symbol"] for r in rows}
    if symbols_needed:
        try:
            async with httpx.AsyncClient(timeout=5) as client:
                resp = await client.get(f"{BINANCE_FAPI}/ticker/price")
                if resp.status_code == 200:
                    for item in resp.json():
                        if item["symbol"] in symbols_needed:
                            prices[item["symbol"]] = float(item["price"])
        except (httpx.HTTPError, ValueError, KeyError, TypeError):
            # Best effort: without prices the PnL fields below fall back to 0.
            pass

    # 1R in USDT and "now" are the same for every row; look them up once.
    risk_usd = await _get_risk_usd() if rows else 0.0
    now_ms = time.time() * 1000

    result = []
    for r in rows:
        d = dict(r)
        current_price = prices.get(r["symbol"], 0)
        d["current_price"] = current_price
        entry = r["entry_price"] or 0
        rd = r["risk_distance"] or 1
        if rd > 0 and entry > 0 and current_price > 0:
            move = current_price - entry if r["direction"] == "LONG" else entry - current_price
            d["unrealized_pnl_r"] = round(move / rd, 4)
            d["unrealized_pnl_usdt"] = round(d["unrealized_pnl_r"] * risk_usd, 2)
        else:
            d["unrealized_pnl_r"] = 0
            d["unrealized_pnl_usdt"] = 0
        if r["entry_ts"]:
            d["hold_time_min"] = round((now_ms - r["entry_ts"]) / 60000, 1)
        result.append(d)
    return {"data": result}
@app.get("/api/live/trades")
async def live_trades(
    symbol: str = "all",
    result: str = "all",
    strategy: str = "v52_8signals",
    limit: int = 100,
    user: dict = Depends(get_current_user),
):
    """Closed live trades with a per-trade PnL breakdown.

    Args:
        symbol: coin without the USDT suffix (e.g. "btc"), or "all".
        result: "win" (pnl_r > 0), "loss" (pnl_r <= 0) or anything else for all.
        strategy: strategy name to filter on.
        limit: maximum number of rows, newest exit first.

    Each row gains ``gross_pnl_r`` (price move only; when TP1 was hit, half of
    the position is counted at tp1_price), ``fee_r``, ``funding_r`` (only
    funding paid, i.e. negative), ``slippage_r`` and ``net_pnl_r`` (the stored
    pnl_r).
    """
    conditions = ["status NOT IN ('active','tp1_hit')", "strategy = $1"]
    params = [strategy]
    if symbol != "all":
        params.append(symbol.upper() + "USDT")
        conditions.append(f"symbol = ${len(params)}")
    if result == "win":
        conditions.append("pnl_r > 0")
    elif result == "loss":
        conditions.append("pnl_r <= 0")
    where = " AND ".join(conditions)
    params.append(limit)
    rows = await async_fetch(
        "SELECT id, symbol, direction, score, tier, strategy, entry_price, exit_price, "
        "entry_ts, exit_ts, pnl_r, status, tp1_hit, score_factors, "
        "binance_order_id, fill_price, slippage_bps, fee_usdt, funding_fee_usdt, "
        "protection_gap_ms, signal_to_order_ms, order_to_fill_ms, risk_distance, "
        "tp1_price, tp2_price, sl_price "
        f"FROM live_trades WHERE {where} ORDER BY exit_ts DESC LIMIT ${len(params)}",
        *params,
    )

    # 1R in USDT is the same for every row; read it once.
    risk_usd = await _get_risk_usd()
    result_data = []
    for r in rows:
        d = dict(r)
        entry = r["entry_price"] or 0
        exit_p = r["exit_price"] or 0
        rd = r["risk_distance"] or 1
        tp1_price = r["tp1_price"] or 0

        # gross_pnl_r: price move only, before fees, funding and slippage.
        if rd > 0:
            raw_r = (exit_p - entry) / rd if r["direction"] == "LONG" else (entry - exit_p) / rd
        else:
            raw_r = 0
        if r["tp1_hit"] and tp1_price:
            tp1_r = abs(tp1_price - entry) / rd if rd > 0 else 0
            gross_r = 0.5 * tp1_r + 0.5 * raw_r
        else:
            gross_r = raw_r

        fee_usdt = r["fee_usdt"] or 0
        funding_usdt = r["funding_fee_usdt"] or 0
        fee_r = fee_usdt / risk_usd if risk_usd > 0 else 0
        # Only funding paid (negative) is a cost.  Guard risk_usd like fee_r
        # does: a 0 risk_per_trade_usd setting used to raise ZeroDivisionError.
        funding_r = abs(funding_usdt) / risk_usd if funding_usdt < 0 and risk_usd > 0 else 0

        # slippage_r: |slippage| in bps of the entry price, times the quantity
        # implied by 1R (risk_usd / rd), expressed in R.
        slippage_bps = r["slippage_bps"] or 0
        slippage_usdt = abs(slippage_bps) / 10000 * entry * (risk_usd / rd) if rd > 0 else 0
        slippage_r = slippage_usdt / risk_usd if risk_usd > 0 else 0

        d["gross_pnl_r"] = round(gross_r, 4)
        d["fee_r"] = round(fee_r, 4)
        d["funding_r"] = round(funding_r, 4)
        d["slippage_r"] = round(slippage_r, 4)
        d["net_pnl_r"] = r["pnl_r"]  # stored pnl_r is already net
        result_data.append(d)
    return {"count": len(result_data), "data": result_data}
@app.get("/api/live/equity-curve")
async def live_equity_curve(
    strategy: str = "v52_8signals",
    user: dict = Depends(get_current_user),
):
    """Cumulative realized PnL curve (in R) of one strategy's closed live trades.

    Points are ordered by exit time.  Each point is ``{"ts": exit_ts,
    "pnl": running total rounded to 2 dp}``; a NULL pnl_r adds 0.
    """
    closed_trades = await async_fetch(
        "SELECT exit_ts, pnl_r FROM live_trades "
        "WHERE status NOT IN ('active','tp1_hit') AND strategy = $1 ORDER BY exit_ts ASC",
        strategy,
    )

    def running_points():
        equity = 0.0
        for trade in closed_trades:
            equity += trade["pnl_r"] or 0
            yield {"ts": trade["exit_ts"], "pnl": round(equity, 2)}

    return {"data": list(running_points())}
@app.get("/api/live/stats")
async def live_stats(
    strategy: str = "v52_8signals",
    user: dict = Depends(get_current_user),
):
    """Detailed statistics over one strategy's closed live trades.

    Returns ``{"error": "no data"}`` when there are none.  Otherwise the result
    contains:
    - win rate, average win/loss (R), win/loss ratio and total PnL;
    - max drawdown of the cumulative-R curve, ordered by exit time;
    - slippage avg/p50/p95 in bps;
    - a per-symbol breakdown.

    Rows whose pnl_r is NULL or 0 count toward ``total`` but are neither wins
    nor losses.
    """
    closed = await async_fetch(
        "SELECT symbol, direction, pnl_r, tier, entry_ts, exit_ts, slippage_bps "
        "FROM live_trades WHERE status NOT IN ('active','tp1_hit') AND strategy = $1",
        strategy,
    )
    if not closed:
        return {"error": "no data"}

    trade_count = len(closed)
    win_pnls = [t["pnl_r"] for t in closed if t["pnl_r"] and t["pnl_r"] > 0]
    loss_sizes = [abs(t["pnl_r"]) for t in closed if t["pnl_r"] and t["pnl_r"] <= 0]
    win_rate = len(win_pnls) / trade_count * 100
    avg_win = sum(win_pnls) / len(win_pnls) if win_pnls else 0
    avg_loss = sum(loss_sizes) / len(loss_sizes) if loss_sizes else 0
    win_loss_ratio = avg_win / avg_loss if avg_loss > 0 else 0
    total_pnl = sum(t["pnl_r"] for t in closed if t["pnl_r"])

    # Slippage: mean over non-NULL values; p50/p95 are picked from the sorted list.
    slips = [t["slippage_bps"] for t in closed if t["slippage_bps"] is not None]
    avg_slippage = sum(slips) / len(slips) if slips else 0
    ordered = sorted(slips) or [0]
    p50_slip = ordered[len(ordered) // 2]
    p95_slip = ordered[min(int(len(ordered) * 0.95), len(ordered) - 1)]

    # Max drawdown of the cumulative PnL, walking trades by exit time.
    running = peak = mdd = 0
    for t in sorted(closed, key=lambda x: x["exit_ts"] or 0):
        running += t["pnl_r"] or 0
        peak = max(peak, running)
        mdd = max(mdd, peak - running)

    # Per-symbol tallies.
    by_symbol = {}
    for t in closed:
        bucket = by_symbol.setdefault(t["symbol"], {"wins": 0, "total": 0, "pnl": 0})
        bucket["total"] += 1
        bucket["pnl"] += t["pnl_r"] or 0
        if t["pnl_r"] and t["pnl_r"] > 0:
            bucket["wins"] += 1
    for bucket in by_symbol.values():
        bucket["win_rate"] = round(bucket["wins"] / bucket["total"] * 100, 1) if bucket["total"] > 0 else 0
        bucket["total_pnl"] = round(bucket["pnl"], 2)

    return {
        "total": trade_count,
        "win_rate": round(win_rate, 1),
        "avg_win": round(avg_win, 3),
        "avg_loss": round(avg_loss, 3),
        "win_loss_ratio": round(win_loss_ratio, 2),
        "total_pnl": round(total_pnl, 2),
        "mdd": round(mdd, 2),
        "avg_slippage_bps": round(avg_slippage, 2),
        "p50_slippage_bps": round(p50_slip, 2),
        "p95_slippage_bps": round(p95_slip, 2),
        "by_symbol": by_symbol,
    }
@app.get("/api/live/risk-status")
async def live_risk_status(user: dict = Depends(get_current_user)):
    """Return the risk-guard state stored in /tmp/risk_guard_state.json.

    A missing, unreadable or malformed file yields a placeholder dict with
    ``status == "unknown"`` instead of an error response.  Only I/O and
    decode errors are caught; the former bare ``except:`` also swallowed
    BaseExceptions such as task cancellation.
    """
    try:
        with open("/tmp/risk_guard_state.json") as f:
            return json.load(f)
    except (OSError, ValueError):
        return {"status": "unknown", "error": "risk_guard_state.json not found"}
# Cached 1R amount in USDT ("v") and the time.time() it was loaded at ("ts").
_risk_usd_cache = {"v": 2.0, "ts": 0.0}


async def _get_risk_usd() -> float:
    """Return the USDT value of 1R from live_config key "risk_per_trade_usd".

    The value is cached for 60 seconds.  A missing key, or any error while
    querying or parsing it, yields 2.0, and that fallback is cached too.
    """
    now = time.time()
    if now - _risk_usd_cache["ts"] >= 60:
        fallback = 2.0
        value = fallback
        try:
            row = await async_fetchrow("SELECT value FROM live_config WHERE key = $1", "risk_per_trade_usd")
            if row:
                value = float(row["value"])
        except Exception:
            value = fallback
        _risk_usd_cache.update({"v": value, "ts": now})
        return value
    return _risk_usd_cache["v"]
def _require_admin(user: dict):
    """Raise HTTP 403 unless the authenticated user's role is "admin"."""
    is_admin = user.get("role") == "admin"
    if not is_admin:
        raise HTTPException(status_code=403, detail="仅管理员可执行此操作")
_RISK_GUARD_EMERGENCY_FILE = "/tmp/risk_guard_emergency.json"


def _send_risk_guard_command(action: str, user: dict, ok_message: str) -> dict:
    """Write a command for risk_guard to the emergency file and build the API reply.

    The payload ``{"action", "time", "user"}`` is first written to a
    per-process temp file, then moved into place with ``os.replace``.  A reader
    of the command file (presumably the risk_guard process) therefore never
    sees an empty or half-written document; ``open(path, "w")`` truncates the
    file before writing.  Nothing awaits between write and replace, so
    requests handled by the same process cannot interleave on the temp file.

    Returns ``{"ok": True, "message": ok_message}`` on success, or
    ``{"ok": False, "error": ...}`` if the file could not be written.
    """
    payload = {"action": action, "time": time.time(), "user": user.get("email", "unknown")}
    tmp_path = f"{_RISK_GUARD_EMERGENCY_FILE}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, "w") as f:
            json.dump(payload, f)
        os.replace(tmp_path, _RISK_GUARD_EMERGENCY_FILE)
    except OSError as e:
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        return {"ok": False, "error": str(e)}
    return {"ok": True, "message": ok_message}


@app.post("/api/live/emergency-close")
async def live_emergency_close(user: dict = Depends(get_current_user)):
    """Emergency close-all (admin only): writes a "close_all" command for risk_guard to execute."""
    _require_admin(user)
    return _send_risk_guard_command("close_all", user, "紧急平仓指令已发送")


@app.post("/api/live/block-new")
async def live_block_new(user: dict = Depends(get_current_user)):
    """Block new positions (admin only): writes a "block_new" command for risk_guard."""
    _require_admin(user)
    return _send_risk_guard_command("block_new", user, "已禁止新开仓")


@app.post("/api/live/resume")
async def live_resume(user: dict = Depends(get_current_user)):
    """Resume trading (admin only): writes a "resume" command for risk_guard."""
    _require_admin(user)
    return _send_risk_guard_command("resume", user, "已恢复交易")
# ============================================================
# 实盘 API 补充L0-L11
# ============================================================
@app.get("/api/live/account")
async def live_account(user: dict = Depends(get_current_user)):
    """L2: account overview: equity, margin, effective leverage and today's realized PnL.

    Reads ``/fapi/v2/account`` from Binance (the futures testnet when
    TRADE_ENV is "testnet", the default) and sums the closed live_trades
    whose exit_ts (ms) falls on or after today's UTC midnight.
    ``today_volume`` is not computed yet and is always 0.
    """
    import hashlib
    import hmac
    from urllib.parse import urlencode

    api_key = os.environ.get("BINANCE_API_KEY", "")
    secret_key = os.environ.get("BINANCE_SECRET_KEY", "")
    trade_env = os.environ.get("TRADE_ENV", "testnet")
    base = "https://testnet.binancefuture.com" if trade_env == "testnet" else "https://fapi.binance.com"
    if not api_key or not secret_key:
        return {"error": "API keys not configured", "equity": 0, "available_margin": 0, "used_margin": 0, "effective_leverage": 0}

    def sign(params):
        """Add the timestamp and HMAC-SHA256 signature that Binance signed endpoints require."""
        params["timestamp"] = int(time.time() * 1000)
        qs = urlencode(params)
        params["signature"] = hmac.new(secret_key.encode(), qs.encode(), hashlib.sha256).hexdigest()
        return params

    try:
        async with httpx.AsyncClient(timeout=5) as client:
            resp = await client.get(f"{base}/fapi/v2/account", params=sign({}), headers={"X-MBX-APIKEY": api_key})
            if resp.status_code != 200:
                return {"error": f"API {resp.status_code}"}
            acc = resp.json()
        equity = float(acc.get("totalWalletBalance", 0))
        available = float(acc.get("availableBalance", 0))
        used_margin = float(acc.get("totalInitialMargin", 0))
        unrealized = float(acc.get("totalUnrealizedProfit", 0))
        effective_leverage = round(used_margin / equity, 2) if equity > 0 else 0

        # Today's realized PnL (UTC day).
        today_realized_r = 0
        today_fee = 0
        today_volume = 0  # TODO: not computed yet
        try:
            # Fully qualified via the module alias `_dt`: the bare name
            # `timezone` is not among this module's top-level imports, and the
            # resulting NameError used to be swallowed, always reporting 0.
            midnight_ms = int(
                _dt.datetime.now(_dt.timezone.utc)
                .replace(hour=0, minute=0, second=0, microsecond=0)
                .timestamp() * 1000
            )
            rows = await async_fetch(
                "SELECT pnl_r, fee_usdt FROM live_trades WHERE exit_ts >= $1 AND status NOT IN ('active','tp1_hit')",
                midnight_ms,
            )
            today_realized_r = sum(r["pnl_r"] or 0 for r in rows)
            today_fee = sum(r["fee_usdt"] or 0 for r in rows)
        except Exception:
            # Best effort: the exchange figures are still returned without DB totals.
            pass

        return {
            "equity": round(equity, 2),
            "available_margin": round(available, 2),
            "used_margin": round(used_margin, 2),
            "unrealized_pnl": round(unrealized, 2),
            "effective_leverage": effective_leverage,
            "today_realized_r": round(today_realized_r, 2),
            "today_realized_usdt": round(today_realized_r * (await _get_risk_usd()), 2),
            "today_fee": round(today_fee, 2),
            "today_volume": round(today_volume, 2),
        }
    except Exception as e:
        return {"error": str(e)}
@app.get("/api/live/health")
async def live_system_health(user: dict = Depends(get_current_user)):
    """L11: system health: pm2 process state, market-data freshness and risk_guard state.

    Every probe is best effort; a failing probe leaves its section empty.
    Market data is "green" up to 5 s old, "yellow" up to 10 s, "red" beyond
    (based on MAX(ts) of signal_indicators, in ms).
    """
    watched = (
        "live-executor", "position-sync", "risk-guard", "signal-engine",
        "market-collector", "paper-monitor", "liq-collector",
    )
    health = {
        "ts": int(time.time() * 1000),
        "processes": {},
        "data_freshness": {},
        "api_status": "unknown",
    }

    def pm2_jlist():
        # Resolve pm2 like /api/server/status does; a bare "pm2" fails when it
        # is not on the service's PATH.
        pm2_bin = _find_pm2_bin()
        cmd = ["npx", "pm2", "jlist"] if pm2_bin.startswith("npx") else [pm2_bin, "jlist"]
        completed = subprocess.run(cmd, capture_output=True, text=True, timeout=5)
        return json.loads(completed.stdout) if completed.stdout else []

    # pm2 process state; subprocess.run blocks for up to 5 s, so run it off
    # the event loop.
    try:
        procs = await asyncio.get_running_loop().run_in_executor(None, pm2_jlist)
        for p in procs:
            name = p.get("name", "")
            if name in watched:
                env = p.get("pm2_env", {})
                monit = p.get("monit", {})
                health["processes"][name] = {
                    "status": env.get("status", "unknown"),
                    "uptime_ms": env.get("pm_uptime", 0),
                    "restarts": env.get("restart_time", 0),
                    "memory_mb": round(monit.get("memory", 0) / 1024 / 1024, 1),
                    "cpu": monit.get("cpu", 0),
                }
    except Exception:
        pass

    # Data freshness.
    try:
        now_ms = int(time.time() * 1000)
        latest_market = await async_fetchrow("SELECT MAX(ts) as ts FROM signal_indicators")
        if latest_market and latest_market["ts"]:
            age_sec = (now_ms - latest_market["ts"]) / 1000
            health["data_freshness"]["market_data"] = {
                "last_ts": latest_market["ts"],
                "age_sec": round(age_sec, 1),
                "status": "red" if age_sec > 10 else "yellow" if age_sec > 5 else "green",
            }
    except Exception:
        pass

    # Risk-guard state: read independently, so a DB failure above no longer
    # hides it.
    try:
        with open("/tmp/risk_guard_state.json") as f:
            risk_state = json.load(f)
    except (OSError, ValueError):
        risk_state = {}
    health["risk_guard"] = risk_state
    return health
@app.get("/api/live/reconciliation")
async def live_reconciliation(user: dict = Depends(get_current_user)):
    """L5: reconciliation of local open live_trades against Binance futures positions.

    ``diffs`` lists:
    - local-only positions (critical);
    - exchange-only positions (high);
    - direction mismatches (critical);
    - liquidation distance below 8 % (critical) or 12 % (high).

    ``status`` is "ok", "mismatch" when any diff exists, or "error" when the
    exchange snapshot could not be obtained or an exception occurred.  If the
    positionRisk call fails, no diffs are computed, so a failed API call is
    not misreported as every local position missing on the exchange.
    """
    import hashlib
    import hmac
    from urllib.parse import urlencode

    api_key = os.environ.get("BINANCE_API_KEY", "")
    secret_key = os.environ.get("BINANCE_SECRET_KEY", "")
    trade_env = os.environ.get("TRADE_ENV", "testnet")
    base = "https://testnet.binancefuture.com" if trade_env == "testnet" else "https://fapi.binance.com"
    if not api_key or not secret_key:
        return {"error": "API keys not configured"}

    def sign(params):
        """Add the timestamp and HMAC-SHA256 signature that Binance signed endpoints require."""
        params["timestamp"] = int(time.time() * 1000)
        qs = urlencode(params)
        params["signature"] = hmac.new(secret_key.encode(), qs.encode(), hashlib.sha256).hexdigest()
        return params

    result = {"local_positions": [], "exchange_positions": [], "local_orders": 0, "exchange_orders": 0, "diffs": [], "status": "ok"}
    headers = {"X-MBX-APIKEY": api_key}
    try:
        async with httpx.AsyncClient(timeout=5) as client:
            # Exchange positions
            resp = await client.get(f"{base}/fapi/v2/positionRisk", params=sign({}), headers=headers)
            if resp.status_code != 200:
                result["status"] = "error"
                result["error"] = f"positionRisk API {resp.status_code}"
                return result
            exchange_positions = []
            for p in resp.json():
                amt = float(p.get("positionAmt", 0))
                if amt != 0:
                    exchange_positions.append({
                        "symbol": p["symbol"],
                        "direction": "LONG" if amt > 0 else "SHORT",
                        "amount": abs(amt),
                        "entry_price": float(p.get("entryPrice", 0)),
                        "mark_price": float(p.get("markPrice", 0)),
                        "liquidation_price": float(p.get("liquidationPrice", 0)),
                        "unrealized_pnl": float(p.get("unRealizedProfit", 0)),
                    })
            result["exchange_positions"] = exchange_positions
            # Exchange open orders
            resp2 = await client.get(f"{base}/fapi/v1/openOrders", params=sign({}), headers=headers)
            exchange_orders = resp2.json() if resp2.status_code == 200 else []
            result["exchange_orders"] = len(exchange_orders)

        # Local open positions
        local = await async_fetch(
            "SELECT id, symbol, direction, entry_price, sl_price, tp1_price, tp2_price, status, tp1_hit "
            "FROM live_trades WHERE status IN ('active','tp1_hit')"
        )
        result["local_positions"] = [dict(r) for r in local]
        result["local_orders"] = len(local) * 3  # expected per position: SL + TP1 + TP2

        # Compare per symbol
        local_syms = {r["symbol"]: r for r in local}
        exchange_syms = {p["symbol"]: p for p in exchange_positions}
        for sym, lp in local_syms.items():
            ep = exchange_syms.get(sym)
            if ep is None:
                result["diffs"].append({"symbol": sym, "type": "local_only", "severity": "critical", "detail": f"本地有{lp['direction']}仓位但币安无持仓"})
                continue
            if lp["direction"] != ep["direction"]:
                result["diffs"].append({"symbol": sym, "type": "direction_mismatch", "severity": "critical", "detail": f"本地={lp['direction']} 币安={ep['direction']}"})
            # Distance to liquidation, in % of the mark price
            if ep["liquidation_price"] > 0 and ep["mark_price"] > 0:
                if ep["direction"] == "LONG":
                    dist = (ep["mark_price"] - ep["liquidation_price"]) / ep["mark_price"] * 100
                else:
                    dist = (ep["liquidation_price"] - ep["mark_price"]) / ep["mark_price"] * 100
                if dist < 8:
                    result["diffs"].append({"symbol": sym, "type": "liquidation_critical", "severity": "critical", "detail": f"距清算仅{dist:.1f}%"})
                elif dist < 12:
                    result["diffs"].append({"symbol": sym, "type": "liquidation_warning", "severity": "high", "detail": f"距清算{dist:.1f}%"})
        for sym, ep in exchange_syms.items():
            if sym not in local_syms:
                result["diffs"].append({"symbol": sym, "type": "exchange_only", "severity": "high", "detail": f"币安有{ep['direction']}仓位但本地无记录"})
        if result["diffs"]:
            result["status"] = "mismatch"
    except Exception as e:
        # Previously status stayed "ok" here despite the failure.
        result["status"] = "error"
        result["error"] = str(e)
    return result
@app.get("/api/live/execution-quality")
async def live_execution_quality(user: dict = Depends(get_current_user)):
    """L4: execution-quality panel.

    Summarizes four slippage and latency metrics over the 200 most recently
    created live trades that have ``signal_to_order_ms`` set.  Each metric
    gets avg/p50/p95/min/max, both overall and per symbol; NULL metric values
    are ignored.
    """
    rows = await async_fetch(
        "SELECT symbol, slippage_bps, signal_to_order_ms, order_to_fill_ms, protection_gap_ms "
        "FROM live_trades WHERE signal_to_order_ms IS NOT NULL ORDER BY created_at DESC LIMIT 200"
    )
    if not rows:
        return {"error": "no data"}

    metrics = ("slippage_bps", "signal_to_order_ms", "order_to_fill_ms", "protection_gap_ms")

    def pick_percentile(values, pct):
        # Element at index floor(n * pct / 100) of the sorted values, clamped to the last one.
        if not values:
            return 0
        ordered = sorted(values)
        return ordered[min(int(len(ordered) * pct / 100), len(ordered) - 1)]

    def summarize(values):
        if not values:
            return {"avg": 0, "p50": 0, "p95": 0, "min": 0, "max": 0}
        return {
            "avg": round(sum(values) / len(values), 2),
            "p50": round(pick_percentile(values, 50), 2),
            "p95": round(pick_percentile(values, 95), 2),
            "min": round(min(values), 2),
            "max": round(max(values), 2),
        }

    overall = {m: [] for m in metrics}
    per_symbol = {}
    for row in rows:
        bucket = per_symbol.setdefault(row["symbol"], {"count": 0, "values": {m: [] for m in metrics}})
        bucket["count"] += 1
        for m in metrics:
            value = row[m]
            if value is not None:
                bucket["values"][m].append(value)
                overall[m].append(value)

    return {
        "total_trades": len(rows),
        "overall": {m: summarize(overall[m]) for m in metrics},
        "by_symbol": {
            sym: {"count": bucket["count"], **{m: summarize(bucket["values"][m]) for m in metrics}}
            for sym, bucket in per_symbol.items()
        },
    }
@app.get("/api/live/paper-comparison")
async def live_paper_comparison(
    limit: int = 50,
    user: dict = Depends(get_current_user),
):
    """L8: closed live trades next to the paper trade for the same signal.

    Rows are matched on (signal_id, strategy) via LEFT JOIN, newest exit
    first, at most ``limit`` rows.  When both entries are known, a row gets
    ``entry_diff`` (live - paper) and ``entry_diff_bps``.  When both pnl_r
    values are known, it gets ``pnl_diff_r``, which also feeds the average
    ``avg_pnl_diff_r``.
    """
    rows = await async_fetch("""
        SELECT lt.symbol, lt.direction, lt.entry_price as live_entry, lt.exit_price as live_exit,
               lt.pnl_r as live_pnl, lt.slippage_bps as live_slip, lt.entry_ts as live_entry_ts,
               lt.signal_id,
               pt.entry_price as paper_entry, pt.exit_price as paper_exit, pt.pnl_r as paper_pnl
        FROM live_trades lt
        LEFT JOIN paper_trades pt ON lt.signal_id = pt.signal_id AND lt.strategy = pt.strategy
        WHERE lt.status NOT IN ('active','tp1_hit')
        ORDER BY lt.exit_ts DESC
        LIMIT $1
    """, limit)
    comparisons = []
    total_pnl_diff = 0
    count = 0
    for r in rows:
        d = dict(r)
        live_entry, paper_entry = r["live_entry"], r["paper_entry"]
        if paper_entry and live_entry:
            d["entry_diff"] = round(live_entry - paper_entry, 6)
            # paper_entry is truthy here, so the division is safe.
            d["entry_diff_bps"] = round(d["entry_diff"] / paper_entry * 10000, 2)
        if r["paper_pnl"] is not None and r["live_pnl"] is not None:
            d["pnl_diff_r"] = round(r["live_pnl"] - r["paper_pnl"], 4)
            total_pnl_diff += d["pnl_diff_r"]
            count += 1
        comparisons.append(d)
    return {
        "count": len(comparisons),
        "avg_pnl_diff_r": round(total_pnl_diff / count, 4) if count > 0 else 0,
        "data": comparisons,
    }
# ============ Live Events (L7通知流) ============
async def log_live_event(level: str, category: str, message: str, symbol: str = None, detail: dict = None):
    """Insert one row into live_events (meant to be called from other modules).

    ``detail`` is stored JSON-encoded, or NULL when it is empty or None.  Any
    exception is swallowed, so a logging failure never propagates to the caller.
    """
    try:
        encoded_detail = json.dumps(detail) if detail else None
        await async_execute(
            "INSERT INTO live_events (level, category, symbol, message, detail) VALUES ($1, $2, $3, $4, $5)",
            level, category, symbol, message, encoded_detail
        )
    except Exception:
        pass
@app.get("/api/live/events")
async def live_events(
    limit: int = 50,
    level: str = "all",
    category: str = "all",
    user: dict = Depends(get_current_user),
):
    """Live event stream, newest first.

    ``level`` and ``category`` filter by exact match unless they are "all";
    at most ``limit`` rows are returned.
    """
    clauses = ["1=1"]
    args = []
    for column, wanted in (("level", level), ("category", category)):
        if wanted == "all":
            continue
        args.append(wanted)
        clauses.append(f"{column} = ${len(args)}")
    args.append(limit)
    where = " AND ".join(clauses)
    rows = await async_fetch(
        f"SELECT id, ts, level, category, symbol, message, detail "
        f"FROM live_events WHERE {where} ORDER BY ts DESC LIMIT ${len(args)}",
        *args
    )
    return {"count": len(rows), "data": rows}
# ============ Live Config (实盘配置) ============
@app.get("/api/live/config")
async def live_config_get(user: dict = Depends(get_current_user)):
    """Return every live_config row as ``{key: {value, label, updated_at}}``, in key order."""
    rows = await async_fetch("SELECT key, value, label, updated_at FROM live_config ORDER BY key")
    return {
        row["key"]: {"value": row["value"], "label": row["label"], "updated_at": str(row["updated_at"])}
        for row in rows
    }
@app.put("/api/live/config")
async def live_config_update(request: Request, user: dict = Depends(get_current_user)):
    """Update live_config values (admin only).

    The body must be a JSON object ``{key: value}``.  Each value is stored
    as ``str(value)`` and ``updated_at`` is set to NOW().  Returns the keys
    that were submitted; a key with no matching row is a no-op UPDATE.

    Raises:
        HTTPException: 403 for non-admins; 400 when the body is not valid JSON
            or not a JSON object (previously an unhandled 500).
    """
    _require_admin(user)
    try:
        body = await request.json()
    except ValueError:
        raise HTTPException(status_code=400, detail="request body must be valid JSON")
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="request body must be a JSON object")
    updated = []
    for key, value in body.items():
        await async_execute(
            "UPDATE live_config SET value = $1, updated_at = NOW() WHERE key = $2",
            str(value), key
        )
        updated.append(key)
    return {"updated": updated}
# ─────────────────────────────────────────────
# 策略广场 API
# ─────────────────────────────────────────────
import json as _json
import os as _os
import statistics as _statistics
# Static metadata for the strategies shown in the strategy plaza, keyed by the
# strategy id that the plaza endpoints use to filter paper_trades.strategy and
# signal_indicators.strategy.
#   display_name    - card title
#   cvd_windows     - display text naming the variant's two CVD windows
#   description     - card description
#   initial_balance - nominal starting balance (USDT); the plaza adds net PnL to it
_STRATEGY_META = {
    "v53": {
        "display_name": "V5.3 标准版",
        "cvd_windows": "30m / 4h",
        "description": "标准版30分钟+4小时CVD双轨适配1小时信号周期",
        "initial_balance": 10000,
    },
    "v53_fast": {
        "display_name": "V5.3 Fast版",
        "cvd_windows": "5m / 30m",
        "description": "快速版5分钟+30分钟CVD双轨捕捉短期动量",
        "initial_balance": 10000,
    },
    "v53_middle": {
        "display_name": "V5.3 Middle版",
        "cvd_windows": "15m / 1h",
        "description": "中速版15分钟+1小时CVD双轨平衡噪音与时效",
        "initial_balance": 10000,
    },
}
async def _get_strategy_status(strategy_id: str) -> str:
    """Classify a strategy as "paused", "running" or "error".

    "paused": not listed in ``enabled_strategies`` of paper_config.json (next
    to this module), or the file cannot be read or parsed.
    "running": enabled, and signal_indicators has a row for it within the
    last 5 minutes (ts in ms).
    "error": enabled but no such recent row.
    """
    config_path = _os.path.join(_os.path.dirname(__file__), "paper_config.json")
    try:
        with open(config_path) as fh:
            enabled_ids = _json.load(fh).get("enabled_strategies", [])
        enabled = strategy_id in enabled_ids
    except Exception:
        enabled = False
    if not enabled:
        return "paused"

    since_ms = int((time.time() - 300) * 1000)
    recent = await async_fetch(
        "SELECT ts FROM signal_indicators WHERE strategy=$1 AND ts > $2 ORDER BY ts DESC LIMIT 1",
        strategy_id, since_ms
    )
    return "running" if recent else "error"
@app.get("/api/strategy-plaza")
async def strategy_plaza(user: dict = Depends(get_current_user)):
    """Strategy plaza overview: one card per strategy in _STRATEGY_META.

    Figures come from closed paper trades (exit_ts NOT NULL), converted to
    USDT at a fixed 200 USDT per R.  The 24h figures count only trades that
    exited within the last 86,400,000 ms.  ``status`` comes from
    _get_strategy_status().
    """
    usdt_per_r = 200
    now_ms = int(time.time() * 1000)
    cutoff_24h = now_ms - 86400000
    cards = []
    for sid, meta in _STRATEGY_META.items():
        closed = await async_fetch(
            "SELECT pnl_r, entry_ts, exit_ts FROM paper_trades "
            "WHERE strategy=$1 AND exit_ts IS NOT NULL",
            sid
        )
        open_count_rows = await async_fetch(
            "SELECT COUNT(*) as cnt FROM paper_trades WHERE strategy=$1 AND exit_ts IS NULL",
            sid
        )
        open_positions = int(open_count_rows[0]["cnt"]) if open_count_rows else 0

        recent = [r for r in closed if (r["exit_ts"] or 0) >= cutoff_24h]
        all_pnls = [float(r["pnl_r"]) for r in closed]
        win_pnls = [p for p in all_pnls if p > 0]
        loss_pnls = [p for p in all_pnls if p <= 0]
        trade_count = len(all_pnls)

        net_r = round(sum(all_pnls), 3)
        net_usdt = round(net_r * usdt_per_r, 0)
        pnl_r_24h = round(sum(float(r["pnl_r"]) for r in recent), 3)
        pnl_usdt_24h = round(pnl_r_24h * usdt_per_r, 0)
        std_r = round(_statistics.stdev(all_pnls), 3) if trade_count > 1 else 0.0

        if closed:
            started_at = min(r["entry_ts"] for r in closed)
            last_trade_at = max(r["exit_ts"] for r in closed if r["exit_ts"])
        else:
            started_at = now_ms
            last_trade_at = None

        status = await _get_strategy_status(sid)
        cards.append({
            "id": sid,
            "display_name": meta["display_name"],
            "status": status,
            "started_at": started_at,
            "initial_balance": meta["initial_balance"],
            "current_balance": meta["initial_balance"] + int(net_usdt),
            "net_usdt": int(net_usdt),
            "net_r": net_r,
            "trade_count": trade_count,
            "win_rate": round(len(win_pnls) / trade_count * 100, 1) if all_pnls else 0.0,
            "avg_win_r": round(sum(win_pnls) / len(win_pnls), 3) if win_pnls else 0.0,
            "avg_loss_r": round(sum(loss_pnls) / len(loss_pnls), 3) if loss_pnls else 0.0,
            "open_positions": open_positions,
            "pnl_usdt_24h": int(pnl_usdt_24h),
            "pnl_r_24h": pnl_r_24h,
            "std_r": std_r,
            "last_trade_at": last_trade_at,
        })
    return {"strategies": cards}
@app.get("/api/strategy-plaza/{strategy_id}/summary")
async def strategy_plaza_summary(strategy_id: str, user: dict = Depends(get_current_user)):
    """Strategy detail: the plaza card plus description, symbols, weights and thresholds.

    Weights, thresholds and symbols come from strategies/<strategy_id>.json
    next to this module, with built-in defaults.  If the file cannot be read,
    whatever was read before the failure is returned and the rest stays empty.

    Raises:
        HTTPException: 404 when ``strategy_id`` is unknown.
    """
    if strategy_id not in _STRATEGY_META:
        raise HTTPException(status_code=404, detail="Strategy not found")

    overview = await strategy_plaza(user)
    card = next((c for c in overview["strategies"] if c["id"] == strategy_id), None)
    if card is None:
        raise HTTPException(status_code=404, detail="Strategy not found")
    meta = _STRATEGY_META[strategy_id]

    weights, thresholds, symbols = {}, {}, []
    cfg_path = _os.path.join(_os.path.dirname(__file__), "strategies", f"{strategy_id}.json")
    try:
        with open(cfg_path) as fh:
            cfg = _json.load(fh)
        weights = {
            "direction": cfg.get("direction_weight", 55),
            "crowding": cfg.get("crowding_weight", 25),
            "environment": cfg.get("environment_weight", 15),
            "auxiliary": cfg.get("auxiliary_weight", 5),
        }
        thresholds = {
            "signal_threshold": cfg.get("threshold", 75),
            "flip_threshold": cfg.get("flip_threshold", 85),
        }
        symbols = list(cfg.get("symbol_gates", {}).keys())
    except Exception:
        pass

    detail = dict(card)
    detail.update(
        cvd_windows=meta["cvd_windows"],
        description=meta["description"],
        symbols=symbols,
        weights=weights,
        thresholds=thresholds,
    )
    return detail
@app.get("/api/strategy-plaza/{strategy_id}/signals")
async def strategy_plaza_signals(
    strategy_id: str,
    limit: int = 50,
    user: dict = Depends(get_current_user)
):
    """Most recent signal rows for one plaza strategy, newest first.

    Args:
        strategy_id: key of ``_STRATEGY_META``; matched against
            ``signal_indicators.strategy``.
        limit: maximum number of rows to return (must be >= 0).

    Raises:
        HTTPException: 404 for an unknown strategy, 400 for a negative limit.
    """
    if strategy_id not in _STRATEGY_META:
        raise HTTPException(status_code=404, detail="Strategy not found")
    if limit < 0:
        # PostgreSQL rejects a negative LIMIT. Report it as a client error
        # instead of letting the database error surface as a 500.
        raise HTTPException(status_code=400, detail="limit must be non-negative")
    rows = await async_fetch(
        "SELECT ts, symbol, score, signal, price, factors, cvd_5m, cvd_15m, cvd_30m, cvd_1h, cvd_4h, atr_value "
        "FROM signal_indicators WHERE strategy=$1 ORDER BY ts DESC LIMIT $2",
        strategy_id, limit
    )
    return {"signals": [dict(r) for r in rows]}
@app.get("/api/strategy-plaza/{strategy_id}/trades")
async def strategy_plaza_trades(
    strategy_id: str,
    limit: int = 50,
    user: dict = Depends(get_current_user)
):
    """Most recent paper trades for one plaza strategy, newest entry first.

    Args:
        strategy_id: key of ``_STRATEGY_META``; matched against
            ``paper_trades.strategy``.
        limit: maximum number of rows to return (must be >= 0).

    Raises:
        HTTPException: 404 for an unknown strategy, 400 for a negative limit.
    """
    if strategy_id not in _STRATEGY_META:
        raise HTTPException(status_code=404, detail="Strategy not found")
    if limit < 0:
        # PostgreSQL rejects a negative LIMIT. Report it as a client error
        # instead of letting the database error surface as a 500.
        raise HTTPException(status_code=400, detail="limit must be non-negative")
    rows = await async_fetch(
        "SELECT id, symbol, direction, score, tier, entry_price, exit_price, "
        "tp1_price, tp2_price, sl_price, tp1_hit, pnl_r, risk_distance, "
        "entry_ts, exit_ts, status, strategy "
        "FROM paper_trades WHERE strategy=$1 ORDER BY entry_ts DESC LIMIT $2",
        strategy_id, limit
    )
    return {"trades": [dict(r) for r in rows]}
# ─────────────────────────────────────────────────────────────────────────────
# V5.4 Strategy Factory API
# ─────────────────────────────────────────────────────────────────────────────
import uuid as _uuid
from typing import Optional
from pydantic import BaseModel, Field, field_validator, model_validator
# ── Pydantic Models ──────────────────────────────────────────────────────────
class StrategyCreateRequest(BaseModel):
    """Request body for creating a strategy instance (POST /api/strategies).

    Each field validator enforces an allowed set or a closed numeric range.
    The model validator requires the four weights to sum to exactly 100.
    Validators raise ``ValueError`` so pydantic reports them as validation
    errors.
    """
    display_name: str = Field(..., min_length=1, max_length=50)
    symbol: str
    direction: str = "both"
    initial_balance: float = 10000.0
    cvd_fast_window: str = "30m"
    cvd_slow_window: str = "4h"
    weight_direction: int = 55
    weight_env: int = 25
    weight_aux: int = 15
    weight_momentum: int = 5
    entry_score: int = 75
    # Gate 1: volatility
    gate_vol_enabled: bool = True
    vol_atr_pct_min: float = 0.002
    # Gate 2: CVD resonance
    gate_cvd_enabled: bool = True
    # Gate 3: whale veto
    gate_whale_enabled: bool = True
    whale_usd_threshold: float = 50000.0
    whale_flow_pct: float = 0.5
    # Gate 4: OBI veto
    gate_obi_enabled: bool = True
    obi_threshold: float = 0.35
    # Gate 5: spot/perp divergence
    gate_spot_perp_enabled: bool = False
    spot_perp_threshold: float = 0.005
    # Risk-control parameters
    sl_atr_multiplier: float = 1.5
    tp1_ratio: float = 0.75
    tp2_ratio: float = 1.5
    timeout_minutes: int = 240
    flip_threshold: int = 80
    description: Optional[str] = None

    @field_validator("symbol")
    @classmethod
    def validate_symbol(cls, v):
        allowed = {"BTCUSDT", "ETHUSDT", "SOLUSDT", "XRPUSDT"}
        if v not in allowed:
            raise ValueError(f"symbol must be one of {allowed}")
        return v

    @field_validator("direction")
    @classmethod
    def validate_direction(cls, v):
        if v not in {"long_only", "short_only", "both"}:
            raise ValueError("direction must be long_only, short_only, or both")
        return v

    @field_validator("cvd_fast_window")
    @classmethod
    def validate_cvd_fast(cls, v):
        if v not in {"5m", "15m", "30m"}:
            raise ValueError("cvd_fast_window must be 5m, 15m, or 30m")
        return v

    @field_validator("cvd_slow_window")
    @classmethod
    def validate_cvd_slow(cls, v):
        if v not in {"30m", "1h", "4h"}:
            raise ValueError("cvd_slow_window must be 30m, 1h, or 4h")
        return v

    @field_validator("weight_direction")
    @classmethod
    def validate_w_dir(cls, v):
        if not 10 <= v <= 80:
            raise ValueError("weight_direction must be 10-80")
        return v

    @field_validator("weight_env")
    @classmethod
    def validate_w_env(cls, v):
        if not 5 <= v <= 60:
            raise ValueError("weight_env must be 5-60")
        return v

    @field_validator("weight_aux")
    @classmethod
    def validate_w_aux(cls, v):
        if not 0 <= v <= 40:
            raise ValueError("weight_aux must be 0-40")
        return v

    @field_validator("weight_momentum")
    @classmethod
    def validate_w_mom(cls, v):
        if not 0 <= v <= 20:
            raise ValueError("weight_momentum must be 0-20")
        return v

    @model_validator(mode="after")
    def validate_weights_sum(self):
        # Runs after the per-field checks; the four weights must total 100.
        total = self.weight_direction + self.weight_env + self.weight_aux + self.weight_momentum
        if total != 100:
            raise ValueError(f"Weights must sum to 100, got {total}")
        return self

    @field_validator("entry_score")
    @classmethod
    def validate_entry_score(cls, v):
        if not 60 <= v <= 95:
            raise ValueError("entry_score must be 60-95")
        return v

    @field_validator("vol_atr_pct_min")
    @classmethod
    def validate_vol_atr(cls, v):
        if not 0.0001 <= v <= 0.02:
            raise ValueError("vol_atr_pct_min must be 0.0001-0.02")
        return v

    @field_validator("whale_usd_threshold")
    @classmethod
    def validate_whale_usd(cls, v):
        if not 1000 <= v <= 1000000:
            raise ValueError("whale_usd_threshold must be 1000-1000000")
        return v

    @field_validator("whale_flow_pct")
    @classmethod
    def validate_whale_flow(cls, v):
        if not 0.0 <= v <= 1.0:
            raise ValueError("whale_flow_pct must be 0.0-1.0")
        return v

    @field_validator("obi_threshold")
    @classmethod
    def validate_obi(cls, v):
        if not 0.1 <= v <= 0.9:
            raise ValueError("obi_threshold must be 0.1-0.9")
        return v

    @field_validator("spot_perp_threshold")
    @classmethod
    def validate_spot_perp(cls, v):
        if not 0.0005 <= v <= 0.01:
            raise ValueError("spot_perp_threshold must be 0.0005-0.01")
        return v

    @field_validator("sl_atr_multiplier")
    @classmethod
    def validate_sl(cls, v):
        if not 0.5 <= v <= 3.0:
            raise ValueError("sl_atr_multiplier must be 0.5-3.0")
        return v

    @field_validator("tp1_ratio")
    @classmethod
    def validate_tp1(cls, v):
        if not 0.3 <= v <= 2.0:
            raise ValueError("tp1_ratio must be 0.3-2.0")
        return v

    @field_validator("tp2_ratio")
    @classmethod
    def validate_tp2(cls, v):
        if not 0.5 <= v <= 4.0:
            raise ValueError("tp2_ratio must be 0.5-4.0")
        return v

    @field_validator("timeout_minutes")
    @classmethod
    def validate_timeout(cls, v):
        if not 30 <= v <= 1440:
            raise ValueError("timeout_minutes must be 30-1440")
        return v

    @field_validator("flip_threshold")
    @classmethod
    def validate_flip(cls, v):
        if not 60 <= v <= 95:
            raise ValueError("flip_threshold must be 60-95")
        return v

    @field_validator("initial_balance")
    @classmethod
    def validate_balance(cls, v):
        import math  # local import: only this validator needs it

        # A plain `v < 1000` lets NaN (every comparison is False) and +inf
        # through, and those values would be persisted as the balance.
        if not math.isfinite(v):
            raise ValueError("initial_balance must be a finite number")
        if v < 1000:
            raise ValueError("initial_balance must be >= 1000")
        return v
class StrategyUpdateRequest(BaseModel):
    """Partial update - all fields optional.

    ``update_strategy`` applies only the fields that are not None
    (``model_dump(exclude_none=True)``). An explicit null therefore means
    "leave unchanged", and e.g. ``description`` cannot be cleared through this
    model. Range and allowed-value checks for these fields are performed in
    ``update_strategy``, not here; only ``display_name`` length is enforced by
    the model itself.
    """
    display_name: Optional[str] = Field(None, min_length=1, max_length=50)
    direction: Optional[str] = None
    cvd_fast_window: Optional[str] = None
    cvd_slow_window: Optional[str] = None
    weight_direction: Optional[int] = None
    weight_env: Optional[int] = None
    weight_aux: Optional[int] = None
    weight_momentum: Optional[int] = None
    entry_score: Optional[int] = None
    # Gate 1: volatility
    gate_vol_enabled: Optional[bool] = None
    vol_atr_pct_min: Optional[float] = None
    # Gate 2: CVD resonance
    gate_cvd_enabled: Optional[bool] = None
    # Gate 3: whale veto
    gate_whale_enabled: Optional[bool] = None
    whale_usd_threshold: Optional[float] = None
    whale_flow_pct: Optional[float] = None
    # Gate 4: OBI veto
    gate_obi_enabled: Optional[bool] = None
    obi_threshold: Optional[float] = None
    # Gate 5: spot/perp divergence
    gate_spot_perp_enabled: Optional[bool] = None
    spot_perp_threshold: Optional[float] = None
    # Risk control
    sl_atr_multiplier: Optional[float] = None
    tp1_ratio: Optional[float] = None
    tp2_ratio: Optional[float] = None
    timeout_minutes: Optional[int] = None
    flip_threshold: Optional[int] = None
    description: Optional[str] = None
class AddBalanceRequest(BaseModel):
    """Body for POST /api/strategies/{sid}/add-balance.

    ``amount`` must be a finite number greater than zero. ``gt=0`` alone still
    accepts ``inf``, which would be written to both balance columns.
    """
    amount: float = Field(..., gt=0, allow_inf_nan=False)
class DeprecateRequest(BaseModel):
    """Body for POST /api/strategies/{sid}/deprecate.

    ``confirm`` is required; the endpoint refuses to act unless it is true.
    """
    confirm: bool = Field(...)
# ── Helper ──────────────────────────────────────────────────────────────────
async def _get_strategy_or_404(strategy_id: str) -> dict:
    """Fetch one ``strategies`` row as a plain dict, or raise 404.

    Strategy ids are UUID strings: new ones come from ``uuid4`` and the
    migrated legacy ones are fixed UUIDs. Anything that does not parse as a
    UUID cannot exist, so it is rejected with 404 before querying. Otherwise
    the malformed string would reach the database driver, and if the column
    is a UUID type the query itself would fail.
    NOTE(review): assumes ``strategies.strategy_id`` holds UUIDs only.

    Raises:
        HTTPException: 404 when the id is malformed or no row matches.
    """
    try:
        _uuid.UUID(strategy_id)
    except ValueError:
        raise HTTPException(status_code=404, detail="Strategy not found") from None
    row = await async_fetchrow(
        "SELECT * FROM strategies WHERE strategy_id=$1",
        strategy_id
    )
    if not row:
        raise HTTPException(status_code=404, detail="Strategy not found")
    return dict(row)
def _strategy_row_to_card(row: dict) -> dict:
"""Convert a strategies row to a card-level response (no config params)"""
return {
"strategy_id": str(row["strategy_id"]),
"display_name": row["display_name"],
"status": row["status"],
"symbol": row["symbol"],
"direction": row["direction"],
"started_at": int(row["created_at"].timestamp() * 1000) if row.get("created_at") else 0,
"initial_balance": row["initial_balance"],
"current_balance": row["current_balance"],
"net_usdt": round(row["current_balance"] - row["initial_balance"], 2),
"deprecated_at": int(row["deprecated_at"].timestamp() * 1000) if row.get("deprecated_at") else None,
"last_run_at": int(row["last_run_at"].timestamp() * 1000) if row.get("last_run_at") else None,
"schema_version": row["schema_version"],
}
def _strategy_row_to_detail(row: dict) -> dict:
    """Build the full detail view of a ``strategies`` row.

    Starts from the card view, then appends every configuration column
    verbatim (in a fixed order), ``description`` (None when absent), and
    ``created_at`` / ``updated_at`` as integer epoch ms (0 when absent).
    """
    detail = _strategy_row_to_card(row)
    config_columns = (
        "cvd_fast_window",
        "cvd_slow_window",
        "weight_direction",
        "weight_env",
        "weight_aux",
        "weight_momentum",
        "entry_score",
        # Gate 1: volatility
        "gate_vol_enabled",
        "vol_atr_pct_min",
        # Gate 2: CVD resonance
        "gate_cvd_enabled",
        # Gate 3: whale veto
        "gate_whale_enabled",
        "whale_usd_threshold",
        "whale_flow_pct",
        # Gate 4: OBI veto
        "gate_obi_enabled",
        "obi_threshold",
        # Gate 5: spot/perp divergence
        "gate_spot_perp_enabled",
        "spot_perp_threshold",
        # Risk control
        "sl_atr_multiplier",
        "tp1_ratio",
        "tp2_ratio",
        "timeout_minutes",
        "flip_threshold",
    )
    for column in config_columns:
        detail[column] = row[column]
    detail["description"] = row.get("description")
    for stamp_column in ("created_at", "updated_at"):
        stamp = row.get(stamp_column)
        detail[stamp_column] = int(stamp.timestamp() * 1000) if stamp else 0
    return detail
# Fixed UUID -> legacy `paper_trades.strategy` text name for the three
# strategies whose ids were hard-coded during migration. Their older trades
# carry only the text name (strategy_id IS NULL), so both must be matched.
_LEGACY_STRATEGY_NAMES = {
    "00000000-0000-0000-0000-000000000053": "v53",
    "00000000-0000-0000-0000-000000000054": "v53_middle",
    "00000000-0000-0000-0000-000000000055": "v53_fast",
}


async def _get_strategy_trade_stats(strategy_id: str) -> dict:
    """Aggregate paper-trade statistics for one strategy.

    Trades whose status is 'active' or 'tp1_hit' count as open positions; all
    other statuses count as closed. For the migrated legacy strategies, rows
    with ``strategy_id IS NULL`` and the matching legacy ``strategy`` name
    are included as well. A NULL ``pnl_r`` counts as 0.

    Returns:
        dict with trade_count, win_rate (percent), avg_win_r, avg_loss_r,
        open_positions, pnl_usdt_24h, pnl_r_24h, last_trade_at (latest
        exit_ts in ms, or None), net_r and net_usdt.
    """
    # USDT value of 1R. NOTE(review): same hard-coded 200 as the
    # strategy-plaza stats; confirm it is still the intended per-trade risk.
    usdt_per_r = 200

    legacy_name = _LEGACY_STRATEGY_NAMES.get(strategy_id)
    # owner_clause is one of two fixed SQL fragments (never user text); the id
    # and the legacy name are always passed as bind parameters.
    if legacy_name:
        owner_clause = "(strategy_id=$1 OR (strategy_id IS NULL AND strategy=$2))"
        params = (strategy_id, legacy_name)
    else:
        owner_clause = "strategy_id=$1"
        params = (strategy_id,)

    closed = await async_fetch(
        f"""SELECT status, pnl_r, tp1_hit, entry_ts, exit_ts
        FROM paper_trades
        WHERE {owner_clause} AND status NOT IN ('active', 'tp1_hit')
        ORDER BY entry_ts DESC""",
        *params,
    )
    pnl = [r["pnl_r"] or 0 for r in closed]
    wins = [p for p in pnl if p > 0]
    losses = [p for p in pnl if p < 0]
    total = len(pnl)
    net_r = sum(pnl)

    # Rows are ordered by entry time, so the first row is not necessarily the
    # most recent exit; take the max exit_ts explicitly.
    last_trade_at = max((r["exit_ts"] for r in closed if r["exit_ts"]), default=None)

    # Cutoff straight from the epoch clock (ms). datetime.utcnow().timestamp()
    # would interpret the naive UTC wall time as *local* time and shift the
    # 24h window on any host not running in UTC.
    cutoff_ms = int(time.time() * 1000) - 24 * 60 * 60 * 1000
    pnl_r_24h = round(
        sum(r["pnl_r"] or 0 for r in closed if (r["exit_ts"] or 0) >= cutoff_ms), 3
    )

    open_rows = await async_fetch(
        f"""SELECT COUNT(*) as cnt FROM paper_trades
        WHERE {owner_clause} AND status IN ('active','tp1_hit')""",
        *params,
    )
    open_positions = int(open_rows[0]["cnt"]) if open_rows else 0

    return {
        "trade_count": total,
        "win_rate": round(len(wins) / total * 100, 1) if total else 0.0,
        "avg_win_r": round(sum(wins) / len(wins), 3) if wins else 0.0,
        "avg_loss_r": round(sum(losses) / len(losses), 3) if losses else 0.0,
        "open_positions": open_positions,
        "pnl_usdt_24h": round(pnl_r_24h * usdt_per_r, 2),
        "pnl_r_24h": pnl_r_24h,
        "last_trade_at": last_trade_at,
        "net_r": round(net_r, 3),
        "net_usdt": round(net_r * usdt_per_r, 2),
    }
# ── Endpoints ────────────────────────────────────────────────────────────────
@app.post("/api/strategies")
async def create_strategy(body: StrategyCreateRequest, user: dict = Depends(get_current_user)):
    """Create a new strategy instance and return its full detail.

    The row starts with ``schema_version=1`` and ``status='running'``.
    ``current_balance`` is seeded from ``initial_balance`` because parameter
    $27 is bound to both columns. ``RETURNING *`` yields the stored row from
    the same statement, so no second SELECT round-trip is needed.
    """
    new_id = str(_uuid.uuid4())
    row = await async_fetchrow(
        """INSERT INTO strategies (
            strategy_id, display_name, schema_version, status,
            symbol, direction,
            cvd_fast_window, cvd_slow_window,
            weight_direction, weight_env, weight_aux, weight_momentum,
            entry_score,
            gate_vol_enabled, vol_atr_pct_min,
            gate_cvd_enabled,
            gate_whale_enabled, whale_usd_threshold, whale_flow_pct,
            gate_obi_enabled, obi_threshold,
            gate_spot_perp_enabled, spot_perp_threshold,
            sl_atr_multiplier, tp1_ratio, tp2_ratio,
            timeout_minutes, flip_threshold,
            initial_balance, current_balance,
            description
        ) VALUES (
            $1,$2,1,'running',
            $3,$4,$5,$6,
            $7,$8,$9,$10,
            $11,
            $12,$13,
            $14,
            $15,$16,$17,
            $18,$19,
            $20,$21,
            $22,$23,$24,
            $25,$26,
            $27,$27,$28
        )
        RETURNING *""",
        new_id, body.display_name,
        body.symbol, body.direction, body.cvd_fast_window, body.cvd_slow_window,
        body.weight_direction, body.weight_env, body.weight_aux, body.weight_momentum,
        body.entry_score,
        body.gate_vol_enabled, body.vol_atr_pct_min,
        body.gate_cvd_enabled,
        body.gate_whale_enabled, body.whale_usd_threshold, body.whale_flow_pct,
        body.gate_obi_enabled, body.obi_threshold,
        body.gate_spot_perp_enabled, body.spot_perp_threshold,
        body.sl_atr_multiplier, body.tp1_ratio, body.tp2_ratio,
        body.timeout_minutes, body.flip_threshold,
        body.initial_balance, body.description
    )
    return {"ok": True, "strategy": _strategy_row_to_detail(dict(row))}
@app.get("/api/strategies")
async def list_strategies(
    include_deprecated: bool = False,
    user: dict = Depends(get_current_user)
):
    """List strategies, oldest first, each card merged with its trade stats.

    Deprecated strategies are hidden unless ``include_deprecated`` is true.
    Each card's ``current_balance`` is replaced by
    ``initial_balance + net_usdt`` computed from trades, rather than the value
    stored on the row.
    """
    if include_deprecated:
        query = "SELECT * FROM strategies ORDER BY created_at ASC"
    else:
        query = "SELECT * FROM strategies WHERE status != 'deprecated' ORDER BY created_at ASC"

    cards = []
    for record in await async_fetch(query):
        base = _strategy_row_to_card(dict(record))
        stats = await _get_strategy_trade_stats(str(record["strategy_id"]))
        card = {**base, **stats}
        # Live balance: stored initial balance plus realised PnL from trades.
        card["current_balance"] = round(record["initial_balance"] + card["net_usdt"], 2)
        cards.append(card)
    return {"strategies": cards}
@app.get("/api/strategies/{sid}")
async def get_strategy(sid: str, user: dict = Depends(get_current_user)):
    """Return one strategy with its full parameter set and trade statistics.

    ``current_balance`` in the response is ``initial_balance + net_usdt``
    computed from trades, not the value stored on the row.
    """
    record = await _get_strategy_or_404(sid)
    base = _strategy_row_to_detail(record)
    stats = await _get_strategy_trade_stats(sid)
    merged = {**base, **stats}
    merged["current_balance"] = round(record["initial_balance"] + merged["net_usdt"], 2)
    return {"strategy": merged}
@app.patch("/api/strategies/{sid}")
async def update_strategy(sid: str, body: StrategyUpdateRequest, user: dict = Depends(get_current_user)):
    """Partially update a strategy's parameters.

    Only fields sent with a non-null value are written, and ``updated_at`` is
    refreshed. Validation order: the strategy must exist (404) and not be
    deprecated (403). At least one field must be present (400). If any weight
    changes, the four weights (stored values filling the gaps) must sum to 100
    (400). Each field must lie in its allowed set or range (400). Returns the
    refreshed detail view.
    """
    current = await _get_strategy_or_404(sid)
    if current["status"] == "deprecated":
        raise HTTPException(status_code=403, detail="Cannot modify a deprecated strategy")

    changes = body.model_dump(exclude_none=True)
    if not changes:
        raise HTTPException(status_code=400, detail="No fields to update")

    weight_keys = ("weight_direction", "weight_env", "weight_aux", "weight_momentum")
    if any(key in changes for key in weight_keys):
        total = sum(changes.get(key, current[key]) for key in weight_keys)
        if total != 100:
            raise HTTPException(status_code=400, detail=f"Weights must sum to 100, got {total}")

    # Same allowed sets / inclusive ranges as StrategyCreateRequest.
    allowed_values = {
        "direction": {"long_only", "short_only", "both"},
        "cvd_fast_window": {"5m", "15m", "30m"},
        "cvd_slow_window": {"30m", "1h", "4h"},
    }
    inclusive_ranges = {
        "weight_direction": (10, 80),
        "weight_env": (5, 60),
        "weight_aux": (0, 40),
        "weight_momentum": (0, 20),
        "entry_score": (60, 95),
        "obi_threshold": (0.1, 0.9),
        "vol_atr_pct_min": (0.0001, 0.02),
        "whale_usd_threshold": (1000, 1000000),
        "whale_flow_pct": (0.0, 1.0),
        "spot_perp_threshold": (0.0005, 0.01),
        "sl_atr_multiplier": (0.5, 3.0),
        "tp1_ratio": (0.3, 2.0),
        "tp2_ratio": (0.5, 4.0),
        "timeout_minutes": (30, 1440),
        "flip_threshold": (60, 95),
    }
    for name, value in changes.items():
        if name in allowed_values:
            valid = value in allowed_values[name]
        elif name in inclusive_ranges:
            low, high = inclusive_ranges[name]
            valid = low <= value <= high
        else:
            continue
        if not valid:
            raise HTTPException(status_code=400, detail=f"Invalid value for {name}: {value}")

    # Column names come from the model's own field names, never from request
    # text; the values are bound as $2..$N after the id in $1.
    assignments = [f"{name}=${position}" for position, name in enumerate(changes, start=2)]
    assignments.append("updated_at=NOW()")
    await async_execute(
        f"UPDATE strategies SET {', '.join(assignments)} WHERE strategy_id=$1",
        sid,
        *changes.values(),
    )
    refreshed = await async_fetchrow("SELECT * FROM strategies WHERE strategy_id=$1", sid)
    return {"ok": True, "strategy": _strategy_row_to_detail(dict(refreshed))}
@app.post("/api/strategies/{sid}/pause")
async def pause_strategy(sid: str, user: dict = Depends(get_current_user)):
    """Pause a strategy: stop opening new positions; existing positions are untouched.

    Deprecated strategies are refused (403). An already-paused strategy is
    reported as such without writing.
    """
    current_status = (await _get_strategy_or_404(sid))["status"]
    if current_status == "deprecated":
        raise HTTPException(status_code=403, detail="Cannot pause a deprecated strategy")
    if current_status == "paused":
        return {"ok": True, "message": "Already paused"}

    pause_sql = (
        "UPDATE strategies SET status='paused', status_changed_at=NOW(), "
        "updated_at=NOW() WHERE strategy_id=$1"
    )
    await async_execute(pause_sql, sid)
    return {"ok": True, "status": "paused"}
@app.post("/api/strategies/{sid}/resume")
async def resume_strategy(sid: str, user: dict = Depends(get_current_user)):
    """Resume a paused strategy by setting its status back to 'running'.

    An already-running strategy is reported as such without writing.
    Deprecated strategies are refused (403), consistent with pause, update and
    add-balance. They must go through ``/restore``, which also clears
    ``deprecated_at``. Resuming them here would leave status='running' with a
    stale ``deprecated_at``.
    """
    row = await _get_strategy_or_404(sid)
    if row["status"] == "deprecated":
        raise HTTPException(
            status_code=403,
            detail="Cannot resume a deprecated strategy; use restore instead",
        )
    if row["status"] == "running":
        return {"ok": True, "message": "Already running"}
    await async_execute(
        "UPDATE strategies SET status='running', status_changed_at=NOW(), updated_at=NOW() WHERE strategy_id=$1",
        sid
    )
    return {"ok": True, "status": "running"}
@app.post("/api/strategies/{sid}/deprecate")
async def deprecate_strategy(sid: str, body: DeprecateRequest, user: dict = Depends(get_current_user)):
    """Deprecate a strategy; its data is kept permanently and it can be restored.

    Requires ``confirm=true`` (400 otherwise; checked before any lookup).
    Sets ``deprecated_at`` and ``status_changed_at``. A strategy that is
    already deprecated is reported as such without writing.
    """
    if not body.confirm:
        raise HTTPException(status_code=400, detail="Must set confirm=true to deprecate")

    current_status = (await _get_strategy_or_404(sid))["status"]
    if current_status == "deprecated":
        return {"ok": True, "message": "Already deprecated"}

    deprecate_sql = (
        "UPDATE strategies\n"
        "SET status='deprecated', deprecated_at=NOW(),\n"
        "status_changed_at=NOW(), updated_at=NOW()\n"
        "WHERE strategy_id=$1"
    )
    await async_execute(deprecate_sql, sid)
    return {"ok": True, "status": "deprecated"}
@app.post("/api/strategies/{sid}/restore")
async def restore_strategy(sid: str, user: dict = Depends(get_current_user)):
    """Re-enable a deprecated strategy, keeping its existing balance and history.

    Sets status back to 'running' and clears ``deprecated_at``. Responds 400
    when the strategy is not currently deprecated.
    """
    current_status = (await _get_strategy_or_404(sid))["status"]
    if current_status != "deprecated":
        raise HTTPException(status_code=400, detail="Strategy is not deprecated")

    restore_sql = (
        "UPDATE strategies\n"
        "SET status='running', deprecated_at=NULL,\n"
        "status_changed_at=NOW(), updated_at=NOW()\n"
        "WHERE strategy_id=$1"
    )
    await async_execute(restore_sql, sid)
    return {"ok": True, "status": "running"}
@app.post("/api/strategies/{sid}/add-balance")
async def add_balance(sid: str, body: AddBalanceRequest, user: dict = Depends(get_current_user)):
    """Top up a strategy: add ``amount`` to both initial_balance and current_balance.

    The increment is applied in a single UPDATE (``col = col + $2``) instead
    of a read-modify-write in Python, so two concurrent top-ups cannot
    overwrite each other. Both new balances are rounded to 2 decimal places
    and read back via ``RETURNING``.

    Raises:
        HTTPException: 404 if the strategy does not exist (or disappears before
            the update), 403 if it is deprecated.
    """
    row = await _get_strategy_or_404(sid)
    if row["status"] == "deprecated":
        raise HTTPException(status_code=403, detail="Cannot add balance to a deprecated strategy")
    updated = await async_fetchrow(
        """UPDATE strategies
        SET initial_balance = ROUND((initial_balance + $2)::numeric, 2),
            current_balance = ROUND((current_balance + $2)::numeric, 2),
            updated_at = NOW()
        WHERE strategy_id=$1
        RETURNING initial_balance, current_balance""",
        sid, body.amount
    )
    if updated is None:
        raise HTTPException(status_code=404, detail="Strategy not found")
    return {
        "ok": True,
        # float(): keep the response numeric type independent of column type.
        "initial_balance": float(updated["initial_balance"]),
        "current_balance": float(updated["current_balance"]),
        "added": body.amount,
    }