arbitrage-engine/backend/main.py
2026-02-28 11:20:05 +00:00

864 lines
30 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastapi import FastAPI, HTTPException, Depends, Request
from fastapi.middleware.cors import CORSMiddleware
import httpx
from datetime import datetime, timedelta
import asyncio, time, os
from auth import router as auth_router, get_current_user, ensure_tables as ensure_auth_tables
from db import (
init_schema, ensure_partitions, get_async_pool, async_fetch, async_fetchrow, async_execute,
close_async_pool,
)
import datetime as _dt
# FastAPI application instance; every route in this file is registered on it.
app = FastAPI(title="Arbitrage Engine API")
# CORS is fully open: any origin, method, and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# Mount the routes exported by the local auth module.
app.include_router(auth_router)
# Base URL for the Binance endpoints requested below (premiumIndex, fundingRate).
BINANCE_FAPI = "https://fapi.binance.com/fapi/v1"
# Symbols polled from Binance; handlers strip the "USDT" suffix to build response keys.
SYMBOLS = ["BTCUSDT", "ETHUSDT", "XRPUSDT", "SOLUSDT"]
# Browser-like User-Agent attached to every outbound httpx client in this file.
HEADERS = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}
# Simple in-process cache; the TTL is chosen by each caller at read time.
_cache: dict = {}


def get_cache(key: str, ttl: int):
    """Return the payload cached under *key* if it is younger than *ttl* seconds.

    Returns None when the key is unknown or the entry is stale.
    """
    record = _cache.get(key)
    if not record:
        return None
    age = time.time() - record["ts"]
    return record["data"] if age < ttl else None


def set_cache(key: str, data):
    """Store *data* under *key*, stamped with the current wall-clock time."""
    _cache[key] = dict(ts=time.time(), data=data)
async def save_snapshot(rates: dict):
    """Insert one BTC/ETH funding-rate and price snapshot into ``rate_snapshots``.

    *rates* maps a base symbol ("BTC", "ETH") to a dict that may contain
    ``lastFundingRate``, ``markPrice`` and ``indexPrice``. Missing symbols or
    fields are stored as 0.

    Best-effort: a failure is logged and swallowed so that persistence
    problems never affect the API response. Always returns None.
    """
    import logging

    try:
        btc = rates.get("BTC", {})
        eth = rates.get("ETH", {})
        await async_execute(
            "INSERT INTO rate_snapshots (ts, btc_rate, eth_rate, btc_price, eth_price, btc_index_price, eth_index_price) "
            "VALUES ($1,$2,$3,$4,$5,$6,$7)",
            int(time.time()),
            float(btc.get("lastFundingRate", 0)),
            float(eth.get("lastFundingRate", 0)),
            float(btc.get("markPrice", 0)),
            float(eth.get("markPrice", 0)),
            float(btc.get("indexPrice", 0)),
            float(eth.get("indexPrice", 0)),
        )
    except Exception:
        # Still non-fatal, but no longer invisible: leave a trace for operators.
        logging.getLogger(__name__).warning("rate snapshot insert failed", exc_info=True)
async def background_snapshot_loop():
    """Poll Binance premiumIndex for every symbol every 2 seconds and persist it.

    Runs forever. Each cycle requests all SYMBOLS concurrently, keeps the
    responses that succeeded and parsed, and hands them to ``save_snapshot``.
    Any failure is logged and the loop continues with the next cycle.
    """
    import logging

    log = logging.getLogger(__name__)
    while True:
        try:
            async with httpx.AsyncClient(timeout=5, headers=HEADERS) as client:
                responses = await asyncio.gather(
                    *(client.get(f"{BINANCE_FAPI}/premiumIndex", params={"symbol": s}) for s in SYMBOLS),
                    return_exceptions=True,
                )
            result = {}
            for sym, resp in zip(SYMBOLS, responses):
                # gather(return_exceptions=True) can also hand back BaseException
                # instances (e.g. CancelledError), which have no status_code.
                if isinstance(resp, BaseException) or resp.status_code != 200:
                    continue
                try:
                    data = resp.json()
                    result[sym.replace("USDT", "")] = {
                        "lastFundingRate": float(data["lastFundingRate"]),
                        "markPrice": float(data["markPrice"]),
                        "indexPrice": float(data["indexPrice"]),
                    }
                except (KeyError, TypeError, ValueError):
                    # One malformed payload must not discard the other symbols.
                    log.warning("unparseable premiumIndex payload for %s", sym, exc_info=True)
            if result:
                await save_snapshot(result)
        except Exception:
            log.warning("snapshot polling cycle failed", exc_info=True)
        await asyncio.sleep(2)
@app.on_event("startup")
async def startup():
    """Prepare the database layer and start the background snapshot poller."""
    # Initialise the PostgreSQL schema and the auth tables.
    init_schema()
    ensure_auth_tables()
    # Initialise the async connection pool.
    await get_async_pool()
    # The event loop keeps only weak references to tasks, so hold a strong one
    # on the app; otherwise the poller may be garbage-collected mid-execution.
    app.state.snapshot_task = asyncio.create_task(background_snapshot_loop())


@app.on_event("shutdown")
async def shutdown():
    """Stop the background poller (if running), then close the async pool."""
    task = getattr(app.state, "snapshot_task", None)
    if task is not None and not task.done():
        # Cancel before closing the pool so the loop cannot use a closed pool.
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
    await close_async_pool()
@app.get("/api/health")
async def health():
    """Liveness probe: a constant "ok" status plus the current UTC time in ISO format."""
    now_iso = datetime.utcnow().isoformat()
    payload = {"status": "ok"}
    payload["timestamp"] = now_iso
    return payload
# Strong references to fire-and-forget snapshot tasks; the event loop itself
# only keeps weak references, so an unreferenced task can vanish mid-flight.
_rates_snapshot_tasks: set = set()


@app.get("/api/rates")
async def get_rates():
    """Return the live mark price, index price and funding rate per symbol.

    Results are cached for 3 seconds. On a cache miss every symbol in SYMBOLS
    is requested from Binance concurrently, the combined result is cached, and
    a snapshot insert is scheduled in the background.

    Raises:
        HTTPException(502): a Binance request failed at the transport level or
            returned a non-200 status.
    """
    cached = get_cache("rates", 3)
    if cached:
        return cached
    try:
        async with httpx.AsyncClient(timeout=10, headers=HEADERS) as client:
            responses = await asyncio.gather(
                *(client.get(f"{BINANCE_FAPI}/premiumIndex", params={"symbol": s}) for s in SYMBOLS)
            )
    except httpx.HTTPError as exc:
        # Timeouts / connection errors are upstream failures, not server bugs.
        raise HTTPException(status_code=502, detail=f"Binance request failed: {type(exc).__name__}") from exc
    result = {}
    for sym, resp in zip(SYMBOLS, responses):
        if resp.status_code != 200:
            raise HTTPException(status_code=502, detail=f"Binance error for {sym}")
        data = resp.json()
        key = sym.replace("USDT", "")
        result[key] = {
            "symbol": sym,
            "markPrice": float(data["markPrice"]),
            "indexPrice": float(data["indexPrice"]),
            "lastFundingRate": float(data["lastFundingRate"]),
            "nextFundingTime": data["nextFundingTime"],
            "timestamp": data["time"],
        }
    set_cache("rates", result)
    # Persist without delaying the response; keep the task referenced until done.
    task = asyncio.create_task(save_snapshot(result))
    _rates_snapshot_tasks.add(task)
    task.add_done_callback(_rates_snapshot_tasks.discard)
    return result
@app.get("/api/snapshots")
async def get_snapshots(hours: int = 24, limit: int = 5000, user: dict = Depends(get_current_user)):
    """Return stored rate snapshots from the last *hours* hours.

    Rows are ordered by ascending timestamp and capped at *limit* rows.
    """
    window_start = int(time.time()) - hours * 3600
    query = (
        "SELECT ts, btc_rate, eth_rate, btc_price, eth_price FROM rate_snapshots "
        "WHERE ts >= $1 ORDER BY ts ASC LIMIT $2"
    )
    snapshots = await async_fetch(query, window_start, limit)
    return {"count": len(snapshots), "hours": hours, "data": snapshots}
@app.get("/api/kline")
async def get_kline(symbol: str = "BTC", interval: str = "5m", limit: int = 500, user: dict = Depends(get_current_user)):
    """Aggregate stored rate snapshots into OHLC bars for one symbol.

    Args:
        symbol: "BTC" or "ETH" (case-insensitive; a trailing "USDT" is accepted).
        interval: bar width key; unknown keys fall back to 300 seconds.
        limit: maximum number of most-recent bars to return.

    Returns:
        A dict with ``symbol``, ``interval``, ``count`` and ``data``. Each bar
        holds the rate open/high/low/close (scaled by 10000 and rounded to 4
        decimals) and the raw price open/high/low/close.

    Raises:
        HTTPException(400): *symbol* has no columns in ``rate_snapshots``.
    """
    interval_secs = {
        "1m": 60, "5m": 300, "30m": 1800,
        "1h": 3600, "4h": 14400, "8h": 28800,
        "1d": 86400, "1w": 604800, "1M": 2592000,
    }
    # The queries in this file only use BTC and ETH columns of rate_snapshots.
    # Previously every non-BTC symbol silently returned ETH data under its own name.
    columns_by_symbol = {
        "BTC": ("btc_rate", "btc_price"),
        "ETH": ("eth_rate", "eth_price"),
    }
    base = symbol.upper()
    if base.endswith("USDT"):
        base = base[:-4]
    if base not in columns_by_symbol:
        raise HTTPException(status_code=400, detail=f"Unsupported symbol: {symbol}")
    rate_col, price_col = columns_by_symbol[base]
    if limit < 1:
        # Guard: with limit == 0 the slice [-0:] would return every bar.
        return {"symbol": symbol, "interval": interval, "count": 0, "data": []}
    bar_secs = interval_secs.get(interval, 300)
    since = int(time.time()) - bar_secs * limit
    # Column names come from the whitelist above, never from user input.
    rows = await async_fetch(
        f"SELECT ts, {rate_col} as rate, {price_col} as price FROM rate_snapshots "
        f"WHERE ts >= $1 ORDER BY ts ASC",
        since
    )
    if not rows:
        return {"symbol": symbol, "interval": interval, "count": 0, "data": []}
    bars: dict = {}
    for r in rows:
        ts, rate, price = r["ts"], r["rate"], r["price"]
        # Floor the timestamp to the start of its bar.
        bar_ts = (ts // bar_secs) * bar_secs
        b = bars.get(bar_ts)
        if b is None:
            bars[bar_ts] = {
                "time": bar_ts,
                "open": rate, "high": rate, "low": rate, "close": rate,
                "price_open": price, "price_high": price, "price_low": price, "price_close": price,
            }
            continue
        b["high"] = max(b["high"], rate)
        b["low"] = min(b["low"], rate)
        b["close"] = rate
        b["price_high"] = max(b["price_high"], price)
        b["price_low"] = min(b["price_low"], price)
        b["price_close"] = price
    data = sorted(bars.values(), key=lambda x: x["time"])[-limit:]
    for b in data:
        for k in ("open", "high", "low", "close"):
            b[k] = round(b[k] * 10000, 4)
    return {"symbol": symbol, "interval": interval, "count": len(data), "data": data}
@app.get("/api/stats/ytd")
async def get_stats_ytd(user: dict = Depends(get_current_user)):
    """Return the year-to-date annualized funding rate per symbol.

    For each symbol the funding-rate history since 1 January (UTC) of the
    current year is fetched from Binance, and ``annualized`` is computed as
    ``mean * 3 * 365 * 100`` rounded to 2 decimals. A symbol whose request
    fails gets ``{"annualized": 0, "count": 0}``.

    Only a fully successful result is cached (for one hour), so a transient
    upstream failure is not served for the next hour.
    """
    cached = get_cache("stats_ytd", 3600)
    if cached:
        return cached
    utc = _dt.timezone.utc
    # Build the boundary as an aware UTC datetime: .timestamp() on a naive
    # datetime is interpreted in the server's local timezone.
    this_year = _dt.datetime.now(utc).year
    year_start = int(_dt.datetime(this_year, 1, 1, tzinfo=utc).timestamp() * 1000)
    end_time = int(time.time() * 1000)
    # NOTE(review): limit=1000 may truncate late in the year if there are
    # three funding events per day (as the annualization implies) - confirm.
    async with httpx.AsyncClient(timeout=20, headers=HEADERS) as client:
        responses = await asyncio.gather(
            *(
                client.get(f"{BINANCE_FAPI}/fundingRate",
                           params={"symbol": s, "startTime": year_start, "endTime": end_time, "limit": 1000})
                for s in SYMBOLS
            ),
            return_exceptions=True,
        )
    result = {}
    all_ok = True
    for sym, resp in zip(SYMBOLS, responses):
        key = sym.replace("USDT", "")
        if isinstance(resp, BaseException) or resp.status_code != 200:
            result[key] = {"annualized": 0, "count": 0}
            all_ok = False
            continue
        rates = [float(item["fundingRate"]) for item in resp.json()]
        if not rates:
            result[key] = {"annualized": 0, "count": 0}
            continue
        mean = sum(rates) / len(rates)
        annualized = round(mean * 3 * 365 * 100, 2)
        result[key] = {"annualized": annualized, "count": len(rates)}
    if all_ok:
        set_cache("stats_ytd", result)
    return result
@app.get("/api/signals/history")
async def get_signals_history(limit: int = 100, user: dict = Depends(get_current_user)):
    """Return the most recent rows of ``signal_logs``, newest first.

    On any query failure an empty list is returned together with the error text.
    """
    sql = "SELECT id, symbol, rate, annualized, sent_at, message FROM signal_logs ORDER BY sent_at DESC LIMIT $1"
    try:
        items = await async_fetch(sql, limit)
    except Exception as exc:
        return {"items": [], "error": str(exc)}
    return {"items": items}
@app.get("/api/history")
async def get_history(user: dict = Depends(get_current_user)):
    """Return the last 7 days of Binance funding-rate history per symbol.

    Results are cached for 60 seconds. Each entry carries the raw
    ``fundingTime`` (ms), the rate as a float, and a naive-UTC ISO timestamp.

    Raises:
        HTTPException(502): a Binance request failed at the transport level or
            returned a non-200 status.
    """
    cached = get_cache("history", 60)
    if cached:
        return cached
    # time.time() is true epoch time. datetime.utcnow().timestamp() treats the
    # naive UTC value as local time and is shifted on non-UTC hosts.
    now_s = time.time()
    end_time = int(now_s * 1000)
    start_time = int((now_s - 7 * 86400) * 1000)
    try:
        async with httpx.AsyncClient(timeout=15, headers=HEADERS) as client:
            responses = await asyncio.gather(
                *(
                    client.get(f"{BINANCE_FAPI}/fundingRate",
                               params={"symbol": s, "startTime": start_time, "endTime": end_time, "limit": 1000})
                    for s in SYMBOLS
                )
            )
    except httpx.HTTPError as exc:
        raise HTTPException(status_code=502, detail=f"Binance history request failed: {type(exc).__name__}") from exc
    utc = _dt.timezone.utc
    result = {}
    for sym, resp in zip(SYMBOLS, responses):
        if resp.status_code != 200:
            raise HTTPException(status_code=502, detail=f"Binance history error for {sym}")
        key = sym.replace("USDT", "")
        result[key] = [
            {
                "fundingTime": item["fundingTime"],
                "fundingRate": float(item["fundingRate"]),
                # Same naive-UTC ISO string as utcfromtimestamp(), without the deprecated call.
                "timestamp": _dt.datetime.fromtimestamp(item["fundingTime"] / 1000, utc)
                .replace(tzinfo=None)
                .isoformat(),
            }
            for item in resp.json()
        ]
    set_cache("history", result)
    return result
@app.get("/api/stats")
async def get_stats(user: dict = Depends(get_current_user)):
    """Return 7-day funding-rate statistics per symbol plus a BTC/ETH combo.

    For each symbol: ``mean7d`` is the mean rate times 100 (6 decimals),
    ``annualized`` is ``mean * 3 * 365 * 100`` (2 decimals), and ``count`` is
    the number of samples. ``combo`` averages the BTC and ETH values.
    Results are cached for 60 seconds.

    Raises:
        HTTPException(502): a Binance request failed at the transport level or
            returned a non-200 status.
    """
    cached = get_cache("stats", 60)
    if cached:
        return cached
    # time.time() is true epoch time. datetime.utcnow().timestamp() treats the
    # naive UTC value as local time and is shifted on non-UTC hosts.
    now_s = time.time()
    end_time = int(now_s * 1000)
    start_time = int((now_s - 7 * 86400) * 1000)
    try:
        async with httpx.AsyncClient(timeout=15, headers=HEADERS) as client:
            responses = await asyncio.gather(
                *(
                    client.get(f"{BINANCE_FAPI}/fundingRate",
                               params={"symbol": s, "startTime": start_time, "endTime": end_time, "limit": 1000})
                    for s in SYMBOLS
                )
            )
    except httpx.HTTPError as exc:
        raise HTTPException(status_code=502, detail=f"Binance stats request failed: {type(exc).__name__}") from exc
    stats = {}
    for sym, resp in zip(SYMBOLS, responses):
        if resp.status_code != 200:
            raise HTTPException(status_code=502, detail=f"Binance stats error for {sym}")
        key = sym.replace("USDT", "")
        rates = [float(item["fundingRate"]) for item in resp.json()]
        if not rates:
            stats[key] = {"mean7d": 0, "annualized": 0, "count": 0}
            continue
        mean = sum(rates) / len(rates)
        annualized = mean * 3 * 365 * 100
        stats[key] = {
            "mean7d": round(mean * 100, 6),
            "annualized": round(annualized, 2),
            "count": len(rates),
        }
    btc = stats.get("BTC", {})
    eth = stats.get("ETH", {})
    stats["combo"] = {
        "mean7d": round((btc.get("mean7d", 0) + eth.get("mean7d", 0)) / 2, 6),
        "annualized": round((btc.get("annualized", 0) + eth.get("annualized", 0)) / 2, 2),
    }
    set_cache("stats", stats)
    return stats
# ─── aggTrades query endpoints (PostgreSQL version) ───────────────
@app.get("/api/trades/meta")
async def get_trades_meta(user: dict = Depends(get_current_user)):
    """Return the ``agg_trades_meta`` row of every symbol, keyed by base symbol."""
    meta_rows = await async_fetch("SELECT symbol, last_agg_id, last_time_ms, updated_at FROM agg_trades_meta")
    return {
        row["symbol"].replace("USDT", ""): {
            "last_agg_id": row["last_agg_id"],
            "last_time_ms": row["last_time_ms"],
            "updated_at": row["updated_at"],
        }
        for row in meta_rows
    }
@app.get("/api/trades/summary")
async def get_trades_summary(
    symbol: str = "BTC",
    start_ms: int = 0,
    end_ms: int = 0,
    interval: str = "1m",
    user: dict = Depends(get_current_user),
):
    """Return per-bar buy/sell volume, trade count, VWAP and max quantity.

    ``end_ms`` of 0 means "now"; ``start_ms`` of 0 means one hour before the
    end. Unknown intervals fall back to one-minute bars.
    """
    bucket_sizes = {"1m": 60000, "5m": 300000, "15m": 900000, "1h": 3600000}
    window_end = end_ms if end_ms != 0 else int(time.time() * 1000)
    window_start = start_ms if start_ms != 0 else window_end - 3600 * 1000
    bucket_ms = bucket_sizes.get(interval, 60000)
    pair = symbol.upper() + "USDT"
    # Aggregation runs inside PostgreSQL instead of a Python loop
    # (the original author's note says this is about 100x faster).
    sql = """
SELECT
(time_ms / $4) * $4 AS bar_ms,
ROUND(SUM(CASE WHEN is_buyer_maker = 0 THEN qty ELSE 0 END)::numeric, 4) AS buy_vol,
ROUND(SUM(CASE WHEN is_buyer_maker = 1 THEN qty ELSE 0 END)::numeric, 4) AS sell_vol,
COUNT(*) AS trade_count,
ROUND((SUM(price * qty) / NULLIF(SUM(qty), 0))::numeric, 2) AS vwap,
ROUND(MAX(qty)::numeric, 4) AS max_qty
FROM agg_trades
WHERE symbol = $1 AND time_ms >= $2 AND time_ms < $3
GROUP BY bar_ms
ORDER BY bar_ms ASC
"""
    bars = await async_fetch(sql, pair, window_start, window_end, bucket_ms)

    def _shape(bar):
        # Convert one aggregated row into the JSON-friendly response entry.
        bought = float(bar["buy_vol"])
        sold = float(bar["sell_vol"])
        vwap = bar["vwap"]
        return {
            "time_ms": bar["bar_ms"],
            "buy_vol": bought,
            "sell_vol": sold,
            "delta": round(bought - sold, 4),
            "total_vol": round(bought + sold, 4),
            "trade_count": bar["trade_count"],
            "vwap": float(vwap) if vwap else 0,
            "max_qty": float(bar["max_qty"]),
        }

    shaped = [_shape(bar) for bar in bars]
    return {"symbol": symbol, "interval": interval, "count": len(shaped), "data": shaped}
@app.get("/api/trades/latest")
async def get_trades_latest(
    symbol: str = "BTC",
    limit: int = 30,
    user: dict = Depends(get_current_user),
):
    """Return the newest aggregated trades for *symbol*, cached for 2 seconds."""
    cache_key = f"trades_latest_{symbol}_{limit}"
    hit = get_cache(cache_key, 2)
    if hit:
        return hit
    pair = symbol.upper() + "USDT"
    query = (
        "SELECT agg_id, price, qty, time_ms, is_buyer_maker FROM agg_trades "
        "WHERE symbol = $1 ORDER BY time_ms DESC, agg_id DESC LIMIT $2"
    )
    trades = await async_fetch(query, pair, limit)
    payload = {"symbol": symbol, "count": len(trades), "data": trades}
    set_cache(cache_key, payload)
    return payload
@app.get("/api/collector/health")
async def collector_health(user: dict = Depends(get_current_user)):
    """Report, per symbol, how far the collector's last trade lags behind now.

    A symbol is flagged healthy while its lag is below 30 seconds.
    """
    now_ms = int(time.time() * 1000)
    meta_rows = await async_fetch("SELECT symbol, last_agg_id, last_time_ms FROM agg_trades_meta")

    def _describe(meta):
        # A NULL last_time_ms counts as epoch 0, i.e. a very large lag.
        lag_seconds = (now_ms - (meta["last_time_ms"] or 0)) / 1000
        return {
            "last_agg_id": meta["last_agg_id"],
            "lag_seconds": round(lag_seconds, 1),
            "healthy": lag_seconds < 30,
        }

    report = {meta["symbol"].replace("USDT", ""): _describe(meta) for meta in meta_rows}
    return {"collector": report, "timestamp": now_ms}
# ─── V5 Signal Engine API ────────────────────────────────────────
@app.get("/api/signals/indicators")
async def get_signal_indicators(
    symbol: str = "BTC",
    minutes: int = 60,
    user: dict = Depends(get_current_user),
):
    """Return the ``signal_indicators_1m`` rows of the last *minutes* minutes, oldest first."""
    pair = symbol.upper() + "USDT"
    cutoff_ms = int(time.time() * 1000) - minutes * 60 * 1000
    query = (
        "SELECT ts, cvd_fast, cvd_mid, cvd_day, atr_5m, vwap_30m, price, score, signal "
        "FROM signal_indicators_1m WHERE symbol = $1 AND ts >= $2 ORDER BY ts ASC"
    )
    series = await async_fetch(query, pair, cutoff_ms)
    return {"symbol": symbol, "count": len(series), "data": series}
@app.get("/api/signals/latest")
async def get_signal_latest(user: dict = Depends(get_current_user)):
    """Return the newest ``signal_indicators`` row per symbol; symbols with no rows are omitted."""
    query = (
        "SELECT ts, cvd_fast, cvd_mid, cvd_day, cvd_fast_slope, atr_5m, atr_percentile, "
        "vwap_30m, price, p95_qty, p99_qty, score, signal "
        "FROM signal_indicators WHERE symbol = $1 ORDER BY ts DESC LIMIT 1"
    )
    latest = {}
    for pair in SYMBOLS:
        newest = await async_fetchrow(query, pair)
        if not newest:
            continue
        latest[pair.replace("USDT", "")] = newest
    return latest
@app.get("/api/signals/market-indicators")
async def get_market_indicators(user: dict = Depends(get_current_user)):
    """Return the latest market_indicators data (four data sources added in V5.1)."""
    import json as _json

    indicator_types = ("long_short_ratio", "top_trader_position", "open_interest_hist", "coinbase_premium")
    query = "SELECT value, timestamp_ms FROM market_indicators WHERE symbol = $1 AND indicator_type = $2 ORDER BY timestamp_ms DESC LIMIT 1"
    snapshot = {}
    for pair in SYMBOLS:
        per_symbol = {}
        for kind in indicator_types:
            newest = await async_fetchrow(query, pair, kind)
            if not newest:
                continue
            value = newest["value"]
            if isinstance(value, str):
                # Values stored as text are decoded as JSON when possible;
                # anything that fails to parse is passed through unchanged.
                try:
                    value = _json.loads(value)
                except Exception:
                    pass
            per_symbol[kind] = {"value": value, "ts": newest["timestamp_ms"]}
        snapshot[pair.replace("USDT", "")] = per_symbol
    return snapshot
@app.get("/api/signals/signal-history")
async def get_signal_history(
    symbol: str = "BTC",
    limit: int = 50,
    user: dict = Depends(get_current_user),
):
    """Return recent signal history (only records that actually carry a signal)."""
    pair = symbol.upper() + "USDT"
    query = (
        "SELECT ts, score, signal FROM signal_indicators "
        "WHERE symbol = $1 AND signal IS NOT NULL "
        "ORDER BY ts DESC LIMIT $2"
    )
    history = await async_fetch(query, pair, limit)
    return {"symbol": symbol, "count": len(history), "data": history}
@app.get("/api/signals/trades")
async def get_signal_trades(
    status: str = "all",
    limit: int = 50,
    user: dict = Depends(get_current_user),
):
    """Return rows of ``signal_trades``, newest first, optionally filtered by status.

    ``status="all"`` disables the filter.
    """
    if status != "all":
        query = "SELECT * FROM signal_trades WHERE status = $1 ORDER BY ts_open DESC LIMIT $2"
        args = (status, limit)
    else:
        query = "SELECT * FROM signal_trades ORDER BY ts_open DESC LIMIT $1"
        args = (limit,)
    trades = await async_fetch(query, *args)
    return {"count": len(trades), "data": trades}
# ─── Paper-trading API ───────────────────────────────────────────
# Paper-trading configuration (runtime state shared with signal_engine)
paper_config = dict(
    enabled=False,
    initial_balance=10000,
    risk_per_trade=0.02,
    max_positions=4,
    tier_multiplier=dict(light=0.5, standard=1.0, heavy=1.5),
)
@app.get("/api/paper/config")
async def paper_get_config(user: dict = Depends(get_current_user)):
    """Return the current in-memory paper-trading configuration."""
    current = paper_config
    return current
@app.post("/api/paper/config")
async def paper_set_config(request: Request, user: dict = Depends(get_current_user)):
    """Update the paper-trading configuration (admin only).

    Only ``enabled``, ``initial_balance``, ``risk_per_trade`` and
    ``max_positions`` are copied from the JSON body. The merged configuration
    is then written to ``paper_config.json`` next to this file so that
    signal_engine can read it.

    Raises:
        HTTPException(403): the caller is not an admin.
        HTTPException(400): the body is not a JSON object.
    """
    import json

    if user.get("role") != "admin":
        raise HTTPException(status_code=403, detail="仅管理员可修改")
    try:
        body = await request.json()
    except ValueError as exc:
        # json.JSONDecodeError is a ValueError; a bad body is a client error.
        raise HTTPException(status_code=400, detail="Invalid JSON body") from exc
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="JSON object expected")
    for k in ("enabled", "initial_balance", "risk_per_trade", "max_positions"):
        if k in body:
            paper_config[k] = body[k]
    config_path = os.path.join(os.path.dirname(__file__), "paper_config.json")
    # Write to a sibling temp file and rename it into place: os.replace is
    # atomic, so a concurrent reader never sees a half-written config file.
    tmp_path = config_path + ".tmp"
    with open(tmp_path, "w") as f:
        json.dump(paper_config, f, indent=2)
    os.replace(tmp_path, config_path)
    return {"ok": True, "config": paper_config}
@app.get("/api/paper/summary")
async def paper_summary(user: dict = Depends(get_current_user)):
    """Paper-trading overview: trade count, win rate, total PnL, open positions, profit factor."""
    closed = await async_fetch(
        "SELECT pnl_r, direction FROM paper_trades WHERE status NOT IN ('active','tp1_hit')"
    )
    active = await async_fetch(
        "SELECT id FROM paper_trades WHERE status IN ('active','tp1_hit')"
    )
    first = await async_fetchrow("SELECT MIN(created_at) as start FROM paper_trades")

    outcomes = [trade["pnl_r"] for trade in closed]
    n_closed = len(outcomes)
    n_wins = sum(1 for pnl in outcomes if pnl > 0)
    net_pnl = sum(outcomes)
    gains = sum(pnl for pnl in outcomes if pnl > 0)
    losses = abs(sum(pnl for pnl in outcomes if pnl <= 0))

    win_rate = 0 if n_closed <= 0 else n_wins / n_closed * 100
    profit_factor = 0 if losses <= 0 else gains / losses

    started = None
    if first and first["start"]:
        started = str(first["start"])

    return {
        "total_trades": n_closed,
        "win_rate": round(win_rate, 1),
        "total_pnl": round(net_pnl, 2),
        "active_positions": len(active),
        "profit_factor": round(profit_factor, 2),
        "start_time": started,
    }
@app.get("/api/paper/positions")
async def paper_positions(user: dict = Depends(get_current_user)):
    """Return the currently open paper positions, newest entry first."""
    query = (
        "SELECT id, symbol, direction, score, tier, entry_price, entry_ts, "
        "tp1_price, tp2_price, sl_price, tp1_hit, status, atr_at_entry "
        "FROM paper_trades WHERE status IN ('active','tp1_hit') ORDER BY entry_ts DESC"
    )
    open_positions = await async_fetch(query)
    return {"data": open_positions}
@app.get("/api/paper/trades")
async def paper_trades(
    symbol: str = "all",
    result: str = "all",
    limit: int = 100,
    user: dict = Depends(get_current_user),
):
    """List closed paper trades, newest exit first.

    ``symbol="all"`` disables the symbol filter; ``result`` may be "win"
    (pnl_r > 0) or "loss" (pnl_r <= 0), any other value disables it.
    """
    filters = ["status NOT IN ('active','tp1_hit')"]
    args = []
    if symbol != "all":
        args.append(symbol.upper() + "USDT")
        # The placeholder number is the position of the argument just added.
        filters.append(f"symbol = ${len(args)}")
    outcome_filter = {"win": "pnl_r > 0", "loss": "pnl_r <= 0"}.get(result)
    if outcome_filter is not None:
        filters.append(outcome_filter)
    args.append(limit)
    limit_slot = len(args)
    where = " AND ".join(filters)
    trades = await async_fetch(
        f"SELECT id, symbol, direction, score, tier, entry_price, exit_price, "
        f"entry_ts, exit_ts, pnl_r, status, tp1_hit "
        f"FROM paper_trades WHERE {where} ORDER BY exit_ts DESC LIMIT ${limit_slot}",
        *args
    )
    return {"count": len(trades), "data": trades}
@app.get("/api/paper/equity-curve")
async def paper_equity_curve(user: dict = Depends(get_current_user)):
    """Equity curve: running sum of ``pnl_r`` over closed trades in exit order."""
    closed = await async_fetch(
        "SELECT exit_ts, pnl_r FROM paper_trades WHERE status NOT IN ('active','tp1_hit') ORDER BY exit_ts ASC"
    )
    points = []
    running_total = 0.0
    for trade in closed:
        running_total = running_total + trade["pnl_r"]
        points.append({"ts": trade["exit_ts"], "pnl": round(running_total, 2)})
    return {"data": points}
@app.get("/api/paper/stats")
async def paper_stats(user: dict = Depends(get_current_user)):
    """Detailed statistics over all closed paper trades.

    Covers win rate, average win/loss, max drawdown, a Sharpe-style ratio and
    win-rate breakdowns by symbol, direction and tier. Returns an ``error``
    payload when there are no closed trades.
    """
    import statistics

    trades = await async_fetch(
        "SELECT symbol, direction, pnl_r, tier, entry_ts, exit_ts "
        "FROM paper_trades WHERE status NOT IN ('active','tp1_hit')"
    )
    if not trades:
        return {"error": "暂无数据"}

    def _win_rates_by(key_of):
        # Tally totals and wins per group, then convert to a rounded win rate.
        tally = {}
        for t in trades:
            group = key_of(t)
            if group not in tally:
                tally[group] = {"total": 0, "wins": 0}
            tally[group]["total"] += 1
            if t["pnl_r"] > 0:
                tally[group]["wins"] += 1
        return {
            group: {"total": v["total"], "win_rate": round(v["wins"]/v["total"]*100, 1)}
            for group, v in tally.items()
        }

    def _side_win_rate(side_trades):
        # Win rate in percent for one direction; 0 when there are no trades.
        if not side_trades:
            return 0
        return len([t for t in side_trades if t["pnl_r"] > 0]) / len(side_trades) * 100

    # Basic statistics
    count = len(trades)
    winners = [t for t in trades if t["pnl_r"] > 0]
    losers = [t for t in trades if t["pnl_r"] <= 0]
    win_rate = len(winners) / count * 100
    avg_win = 0 if not winners else sum(t["pnl_r"] for t in winners) / len(winners)
    avg_loss = 0 if not losers else abs(sum(t["pnl_r"] for t in losers)) / len(losers)
    win_loss_ratio = avg_win / avg_loss if avg_loss > 0 else 0

    # Max drawdown over the cumulative PnL, in exit order
    equity = 0.0
    high_water = 0.0
    max_drawdown = 0.0
    for t in sorted(trades, key=lambda x: x["exit_ts"] or 0):
        equity += t["pnl_r"]
        if equity > high_water:
            high_water = equity
        drawdown = high_water - equity
        if drawdown > max_drawdown:
            max_drawdown = drawdown

    # Sharpe-style ratio over per-trade returns
    pnl_series = [t["pnl_r"] for t in trades]
    sharpe = 0
    if len(pnl_series) > 1:
        mu = statistics.mean(pnl_series)
        sigma = statistics.stdev(pnl_series)
        sharpe = (mu / sigma) * (252 ** 0.5) if sigma > 0 else 0

    # Breakdowns by symbol, direction and tier
    symbol_stats = _win_rates_by(lambda t: t["symbol"].replace("USDT", ""))
    longs = [t for t in trades if t["direction"] == "LONG"]
    shorts = [t for t in trades if t["direction"] == "SHORT"]
    long_wr = _side_win_rate(longs)
    short_wr = _side_win_rate(shorts)
    tier_stats = _win_rates_by(lambda t: t["tier"])

    return {
        "total": count,
        "win_rate": round(win_rate, 1),
        "avg_win": round(avg_win, 2),
        "avg_loss": round(avg_loss, 2),
        "win_loss_ratio": round(win_loss_ratio, 2),
        "mdd": round(max_drawdown, 2),
        "sharpe": round(sharpe, 2),
        "long_win_rate": round(long_wr, 1),
        "long_count": len(longs),
        "short_win_rate": round(short_wr, 1),
        "short_count": len(shorts),
        "by_symbol": symbol_stats,
        "by_tier": tier_stats,
    }
# ─── 服务器状态监控 ───────────────────────────────────────────────
import shutil, subprocess, psutil
# Server-status cache (avoids re-running slow operations): "data" holds the last
# payload built by get_server_status and "ts" the time.time() value stored with it.
_server_cache: dict = {"data": None, "ts": 0}
_PM2_BIN = None


def _find_pm2_bin():
    """Locate the pm2 executable so callers need not go through npx every time.

    Checks a few fixed paths, then PATH. A found path is memoised in
    ``_PM2_BIN``; the "npx pm2" fallback is returned without being memoised.
    """
    global _PM2_BIN
    if not _PM2_BIN:
        import shutil as _sh

        candidates = ("/home/fzq1228/.local/bin/pm2", "/usr/local/bin/pm2", "/usr/bin/pm2")
        located = next((c for c in candidates if os.path.exists(c)), None) or _sh.which("pm2")
        if not located:
            return "npx pm2"
        _PM2_BIN = located
    return _PM2_BIN
# Prime psutil's CPU sampler at import time so that later non-blocking
# cpu_percent(interval=None) calls have a previous sample to compare against.
psutil.cpu_percent(interval=None)


def _collect_pm2_processes():
    """Blocking helper: run ``pm2 jlist`` and summarise each process; [] on any failure."""
    import json as _json

    try:
        pm2_bin = _find_pm2_bin()
        # Call the pm2 binary directly when it was found; otherwise go through npx.
        cmd = [pm2_bin, "jlist"] if not pm2_bin.startswith("npx") else ["npx", "pm2", "jlist"]
        completed = subprocess.run(cmd, capture_output=True, text=True, timeout=5)
        return [
            {
                "name": p.get("name", ""),
                "status": p.get("pm2_env", {}).get("status", "unknown"),
                "cpu": p.get("monit", {}).get("cpu", 0),
                "memory_mb": round(p.get("monit", {}).get("memory", 0) / 1024 / 1024, 1),
                "restarts": p.get("pm2_env", {}).get("restart_time", 0),
                "uptime_ms": p.get("pm2_env", {}).get("pm_uptime", 0),
                "pid": p.get("pid", 0),
            }
            for p in _json.loads(completed.stdout)
        ]
    except Exception:
        return []


def _is_backfill_running():
    """Blocking helper: True if any process command line mentions backfill_agg_trades."""
    try:
        for proc in psutil.process_iter(["pid", "cmdline"]):
            cmdline = " ".join(proc.info.get("cmdline") or [])
            if "backfill_agg_trades" in cmdline:
                return True
    except Exception:
        pass
    return False


@app.get("/api/server/status")
async def get_server_status(user: dict = Depends(get_current_user)):
    """Full server status: CPU, memory, disk, load, PM2 processes, PostgreSQL, backfill.

    The payload is cached for 5 seconds. The PM2 query and the process scan
    are blocking calls, so they run in the default thread-pool executor
    instead of stalling the event loop. PM2 and PostgreSQL sections are
    best-effort: on failure they are returned empty or partially filled.
    """
    import logging

    # 5-second cache to avoid hammering the slow operations below.
    now = time.time()
    if _server_cache["data"] and (now - _server_cache["ts"]) < 5:
        return _server_cache["data"]
    loop = asyncio.get_running_loop()
    # CPU: non-blocking, reports usage since the previous sample.
    cpu_percent = psutil.cpu_percent(interval=None)
    cpu_count = psutil.cpu_count()
    # Memory
    mem = psutil.virtual_memory()
    swap = psutil.swap_memory()
    # Disk
    disk = shutil.disk_usage("/")
    # Load average
    load1, load5, load15 = os.getloadavg()
    # Uptime
    boot_time = psutil.boot_time()
    uptime_s = time.time() - boot_time
    # Network I/O
    net = psutil.net_io_counters()
    # PM2 process status (subprocess with a 5 s timeout, hence off the event loop).
    pm2_procs = await loop.run_in_executor(None, _collect_pm2_processes)
    # PostgreSQL size and row-count estimates (per the original author's note,
    # statistics-based estimates are used because COUNT(*) is far slower).
    pg_info = {}
    try:
        row = await async_fetchrow(
            "SELECT pg_database_size(current_database()) as db_size"
        )
        pg_info["db_size_mb"] = round(row["db_size"] / 1024 / 1024, 1) if row else 0
        row2 = await async_fetchrow(
            "SELECT SUM(n_live_tup)::bigint as cnt FROM pg_stat_user_tables WHERE relname LIKE 'agg_trades%'"
        )
        pg_info["agg_trades_count"] = row2["cnt"] if row2 and row2["cnt"] else 0
        row3 = await async_fetchrow(
            "SELECT n_live_tup::bigint as cnt FROM pg_stat_user_tables WHERE relname = 'rate_snapshots'"
        )
        pg_info["rate_snapshots_count"] = row3["cnt"] if row3 else 0
        # Latest / earliest data time per symbol
        meta_rows = await async_fetch("SELECT symbol, last_time_ms, earliest_time_ms FROM agg_trades_meta")
        pg_info["symbols"] = {}
        for m in meta_rows:
            sym = m["symbol"].replace("USDT", "")
            pg_info["symbols"][sym] = {
                "latest_ms": m["last_time_ms"],
                "earliest_ms": m["earliest_time_ms"],
                "span_hours": round((m["last_time_ms"] - m["earliest_time_ms"]) / 3600000, 1),
            }
    except Exception:
        # Still best-effort (a partial pg_info is returned), but leave a trace.
        logging.getLogger(__name__).warning("collecting PostgreSQL status failed", exc_info=True)
    # Backfill process detection (iterates every process, hence off the event loop).
    backfill_running = await loop.run_in_executor(None, _is_backfill_running)
    result = {
        "timestamp": int(time.time() * 1000),
        "cpu": {
            "percent": cpu_percent,
            "cores": cpu_count,
        },
        "memory": {
            "total_gb": round(mem.total / 1024**3, 1),
            "used_gb": round(mem.used / 1024**3, 1),
            "percent": mem.percent,
            "swap_percent": swap.percent,
        },
        "disk": {
            "total_gb": round(disk.total / 1024**3, 1),
            "used_gb": round(disk.used / 1024**3, 1),
            "free_gb": round(disk.free / 1024**3, 1),
            "percent": round(disk.used / disk.total * 100, 1),
        },
        "load": {
            "load1": round(load1, 2),
            "load5": round(load5, 2),
            "load15": round(load15, 2),
        },
        "uptime_hours": round(uptime_s / 3600, 1),
        "network": {
            "bytes_sent_gb": round(net.bytes_sent / 1024**3, 2),
            "bytes_recv_gb": round(net.bytes_recv / 1024**3, 2),
        },
        "pm2": pm2_procs,
        "postgres": pg_info,
        "backfill_running": backfill_running,
    }
    _server_cache["data"] = result
    _server_cache["ts"] = now
    return result