arbitrage-engine/backend/main.py

571 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio, time, sqlite3, os
import datetime as _dt
import logging
from datetime import datetime, timedelta

import httpx
from fastapi import FastAPI, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware

from auth import router as auth_router, get_current_user, ensure_tables as ensure_auth_tables
# Application instance; every route in this module is registered on it.
app = FastAPI(title="Arbitrage Engine API")
# CORS is fully open: any origin, method and header is accepted.
# NOTE(review): confirm this is intended outside of development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# Mount the routes exported by the local auth module.
app.include_router(auth_router)
# Base URL of the Binance futures REST API (fapi v1) used by every upstream call.
BINANCE_FAPI = "https://fapi.binance.com/fapi/v1"
# Contracts polled by the rate, history and stats endpoints.
SYMBOLS = ["BTCUSDT", "ETHUSDT"]
# Browser-like User-Agent sent with every outbound request.
HEADERS = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}
# SQLite database file, located one directory above this module.
DB_PATH = os.path.join(os.path.dirname(__file__), "..", "arb.db")
# Simple in-memory cache: name -> {"ts": <epoch seconds>, "data": <payload>}.
# The TTL is chosen by each caller of get_cache (e.g. 3 s for rates, 60 s for history/stats).
_cache: dict = {}
def get_cache(key: str, ttl: int):
    """Return the payload cached under *key* if it is younger than *ttl* seconds.

    Returns None when the key is absent or the entry has expired.
    """
    cached = _cache.get(key)
    if not cached:
        return None
    age = time.time() - cached["ts"]
    return cached["data"] if age < ttl else None
def set_cache(key: str, data):
    """Store *data* under *key*, stamped with the current epoch time in seconds."""
    stamped = {"ts": time.time(), "data": data}
    _cache[key] = stamped
def init_db(db_path=None):
    """Create the rate_snapshots table and its ts index if they do not exist.

    Safe to call repeatedly: both statements use IF NOT EXISTS.

    Args:
        db_path: SQLite file to initialise. Defaults to the module-level
            DB_PATH when omitted.
    """
    conn = sqlite3.connect(DB_PATH if db_path is None else db_path)
    try:
        conn.execute("""
            CREATE TABLE IF NOT EXISTS rate_snapshots (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                ts INTEGER NOT NULL,
                btc_rate REAL NOT NULL,
                eth_rate REAL NOT NULL,
                btc_price REAL NOT NULL,
                eth_price REAL NOT NULL,
                btc_index_price REAL,
                eth_index_price REAL
            )
        """)
        conn.execute("CREATE INDEX IF NOT EXISTS idx_rate_snapshots_ts ON rate_snapshots(ts)")
        conn.commit()
    finally:
        # Release the handle even when a DDL statement fails.
        conn.close()
def save_snapshot(rates: dict, db_path=None):
    """Insert one row into rate_snapshots from a {"BTC": {...}, "ETH": {...}} mapping.

    Each per-symbol dict may carry "lastFundingRate", "markPrice" and
    "indexPrice"; a missing symbol or field is stored as 0. The write is
    best-effort: any failure is logged and swallowed so that it never
    affects the API response that triggered it.

    Args:
        rates: Per-symbol rate/price mapping keyed by "BTC" and "ETH".
        db_path: SQLite file to write to. Defaults to the module-level
            DB_PATH when omitted.
    """
    conn = None
    try:
        conn = sqlite3.connect(DB_PATH if db_path is None else db_path)
        btc = rates.get("BTC", {})
        eth = rates.get("ETH", {})
        conn.execute(
            "INSERT INTO rate_snapshots (ts, btc_rate, eth_rate, btc_price, eth_price, btc_index_price, eth_index_price) VALUES (?,?,?,?,?,?,?)",
            (
                int(time.time()),
                float(btc.get("lastFundingRate", 0)),
                float(eth.get("lastFundingRate", 0)),
                float(btc.get("markPrice", 0)),
                float(eth.get("markPrice", 0)),
                float(btc.get("indexPrice", 0)),
                float(eth.get("indexPrice", 0)),
            ),
        )
        conn.commit()
    except Exception:
        # Persistence failures must not break the caller, but they should be visible.
        logging.getLogger(__name__).warning("save_snapshot failed", exc_info=True)
    finally:
        if conn is not None:
            conn.close()
async def background_snapshot_loop():
    """Poll Binance premiumIndex every 2 seconds and persist a snapshot.

    Runs forever and does not depend on any frontend request. One request
    is issued per entry in SYMBOLS; a snapshot is written only when every
    symbol was fetched successfully, because save_snapshot stores missing
    symbols as 0 and those rows would distort the k-line aggregation.
    Errors are logged and the loop carries on with the next tick.
    """
    log = logging.getLogger(__name__)
    while True:
        try:
            async with httpx.AsyncClient(timeout=5, headers=HEADERS) as client:
                tasks = [client.get(f"{BINANCE_FAPI}/premiumIndex", params={"symbol": s}) for s in SYMBOLS]
                responses = await asyncio.gather(*tasks, return_exceptions=True)
            result = {}
            for sym, resp in zip(SYMBOLS, responses):
                # gather() hands back the exception object for a failed request.
                if isinstance(resp, BaseException) or resp.status_code != 200:
                    continue
                data = resp.json()
                result[sym.replace("USDT", "")] = {
                    "lastFundingRate": float(data["lastFundingRate"]),
                    "markPrice": float(data["markPrice"]),
                    "indexPrice": float(data["indexPrice"]),
                }
            if len(result) == len(SYMBOLS):
                # sqlite3 is blocking; keep the write off the event loop.
                await asyncio.to_thread(save_snapshot, result)
            else:
                log.warning("snapshot skipped: fetched %d of %d symbols", len(result), len(SYMBOLS))
        except Exception:
            log.warning("background snapshot iteration failed", exc_info=True)
        await asyncio.sleep(2)
@app.on_event("startup")
async def startup():
    """Prepare the database tables and start the background snapshot poller."""
    init_db()
    ensure_auth_tables()
    # The event loop keeps only weak references to tasks, so hold a strong one
    # for the lifetime of the application to keep the poller from being
    # garbage-collected mid-flight.
    app.state.snapshot_task = asyncio.create_task(background_snapshot_loop())
@app.get("/api/health")
async def health():
    """Liveness probe: a fixed "ok" status plus the current UTC time as a naive ISO-8601 string."""
    now_iso = datetime.utcnow().isoformat()
    return {"status": "ok", "timestamp": now_iso}
@app.get("/api/rates")
async def get_rates():
    """Return the current mark/index price and funding rate for every symbol.

    Results are cached for 3 seconds. On a cache miss one premiumIndex
    request is sent per entry in SYMBOLS, and the fresh result is also
    written to the database in the background.

    Raises:
        HTTPException: 502 when Binance cannot be reached, answers with a
            non-200 status, or returns a payload missing expected fields.
    """
    cached = get_cache("rates", 3)
    if cached:
        return cached
    try:
        async with httpx.AsyncClient(timeout=10, headers=HEADERS) as client:
            tasks = [client.get(f"{BINANCE_FAPI}/premiumIndex", params={"symbol": s}) for s in SYMBOLS]
            responses = await asyncio.gather(*tasks)
    except httpx.HTTPError as exc:
        # Timeouts and connection failures are upstream problems, not server bugs.
        raise HTTPException(status_code=502, detail="Binance request failed") from exc
    result = {}
    for sym, resp in zip(SYMBOLS, responses):
        if resp.status_code != 200:
            raise HTTPException(status_code=502, detail=f"Binance error for {sym}")
        try:
            data = resp.json()
            result[sym.replace("USDT", "")] = {
                "symbol": sym,
                "markPrice": float(data["markPrice"]),
                "indexPrice": float(data["indexPrice"]),
                "lastFundingRate": float(data["lastFundingRate"]),
                "nextFundingTime": data["nextFundingTime"],
                "timestamp": data["time"],
            }
        except (KeyError, TypeError, ValueError) as exc:
            raise HTTPException(status_code=502, detail=f"Binance error for {sym}") from exc
    set_cache("rates", result)
    # Persist asynchronously so the response is not delayed by the write.
    asyncio.create_task(asyncio.to_thread(save_snapshot, result))
    return result
@app.get("/api/snapshots")
async def get_snapshots(hours: int = 24, limit: int = 5000, user: dict = Depends(get_current_user)):
    """Return locally stored rate/price snapshots from the last *hours* hours.

    Rows are ordered by ts ascending and capped at *limit*.
    NOTE(review): with ASC ordering the cap keeps the oldest rows of the
    window, not the newest — confirm this is what consumers expect.

    Returns:
        {"count": int, "hours": int, "data": [row dicts]}
    """
    since = int(time.time()) - hours * 3600
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.row_factory = sqlite3.Row
        rows = conn.execute(
            "SELECT ts, btc_rate, eth_rate, btc_price, eth_price FROM rate_snapshots WHERE ts >= ? ORDER BY ts ASC LIMIT ?",
            (since, limit)
        ).fetchall()
    finally:
        # Close the handle even when the query raises.
        conn.close()
    return {
        "count": len(rows),
        "hours": hours,
        "data": [dict(r) for r in rows],
    }
@app.get("/api/kline")
async def get_kline(symbol: str = "BTC", interval: str = "5m", limit: int = 500, user: dict = Depends(get_current_user)):
    """
    Aggregate rate_snapshots rows into OHLC candles.

    symbol: BTC | ETH (case-insensitive; a trailing USDT is accepted)
    interval: 1m | 5m | 30m | 1h | 4h | 8h | 1d | 1w | 1M (unknown values use 5m bars)
    limit: maximum number of bars; values <= 0 yield an empty result
    Returns: {symbol, interval, count, data} where data is
        [{time, open, high, low, close, price_open, price_high, price_low, price_close}]
        and the funding-rate fields are scaled by 10000.
    Raises: HTTPException 400 for any other symbol, instead of silently
        serving ETH data labelled with the requested symbol.
    """
    interval_secs = {
        "1m": 60, "5m": 300, "30m": 1800,
        "1h": 3600, "4h": 14400, "8h": 28800,
        "1d": 86400, "1w": 604800, "1M": 2592000,
    }
    bar_secs = interval_secs.get(interval, 300)
    # Column names come from this fixed table, never from user input.
    columns = {"BTC": ("btc_rate", "btc_price"), "ETH": ("eth_rate", "eth_price")}
    base = symbol.upper().removesuffix("USDT")
    if base not in columns:
        raise HTTPException(status_code=400, detail=f"Unsupported symbol: {symbol}")
    rate_col, price_col = columns[base]
    empty = {"symbol": symbol, "interval": interval, "count": 0, "data": []}
    if limit <= 0:
        # Guard: a [-0:] slice would return every bar rather than none.
        return empty
    # Look back far enough to cover `limit` bars of `bar_secs` each.
    since = int(time.time()) - bar_secs * limit
    conn = sqlite3.connect(DB_PATH)
    try:
        rows = conn.execute(
            f"SELECT ts, {rate_col} as rate, {price_col} as price FROM rate_snapshots WHERE ts >= ? ORDER BY ts ASC",
            (since,)
        ).fetchall()
    finally:
        conn.close()
    if not rows:
        return empty
    bars = _aggregate_ohlc(rows, bar_secs)
    data = sorted(bars.values(), key=lambda x: x["time"])[-limit:]
    # Express funding rates in units of 1/10000 (rate x 10000).
    for bar in data:
        for k in ("open", "high", "low", "close"):
            bar[k] = round(bar[k] * 10000, 4)
    return {"symbol": symbol, "interval": interval, "count": len(data), "data": data}


def _aggregate_ohlc(rows, bar_secs):
    """Group (ts, rate, price) rows, ordered by ts, into OHLC bars keyed by bar start time."""
    bars: dict = {}
    for ts, rate, price in rows:
        bar_ts = (ts // bar_secs) * bar_secs
        bar = bars.get(bar_ts)
        if bar is None:
            bars[bar_ts] = {
                "time": bar_ts,
                "open": rate, "high": rate, "low": rate, "close": rate,
                "price_open": price, "price_high": price, "price_low": price, "price_close": price,
            }
            continue
        bar["high"] = max(bar["high"], rate)
        bar["low"] = min(bar["low"], rate)
        bar["close"] = rate
        bar["price_high"] = max(bar["price_high"], price)
        bar["price_low"] = min(bar["price_low"], price)
        bar["price_close"] = price
    return bars
@app.get("/api/stats/ytd")
async def get_stats_ytd(user: dict = Depends(get_current_user)):
    """Year-to-date annualized funding-rate statistics per symbol.

    Fetches funding-rate history from 1 January 00:00 UTC of the current
    year until now and annualizes the mean as mean * 3 * 365 * 100.
    A symbol whose request fails is reported as {"annualized": 0, "count": 0}.
    Complete results are cached for one hour; degraded results are not
    cached, so the next call retries.

    NOTE(review): limit=1000 caps each response; at three settlements per
    day that is roughly 333 days, so late in the year the newest entries
    may be missing — confirm and paginate if that matters.
    """
    cached = get_cache("stats_ytd", 3600)
    if cached:
        return cached
    # Build the boundary as an aware UTC datetime; a naive one would be
    # interpreted in the server's local timezone by timestamp().
    utc = _dt.timezone.utc
    current_year = _dt.datetime.now(utc).year
    year_start = int(_dt.datetime(current_year, 1, 1, tzinfo=utc).timestamp() * 1000)
    end_time = int(time.time() * 1000)
    async with httpx.AsyncClient(timeout=20, headers=HEADERS) as client:
        tasks = [
            client.get(f"{BINANCE_FAPI}/fundingRate",
                       params={"symbol": s, "startTime": year_start, "endTime": end_time, "limit": 1000})
            for s in SYMBOLS
        ]
        responses = await asyncio.gather(*tasks, return_exceptions=True)
    result = {}
    complete = True
    for sym, resp in zip(SYMBOLS, responses):
        key = sym.replace("USDT", "")
        if isinstance(resp, Exception) or resp.status_code != 200:
            result[key] = {"annualized": 0, "count": 0}
            complete = False
            continue
        rates = [float(item["fundingRate"]) for item in resp.json()]
        if not rates:
            result[key] = {"annualized": 0, "count": 0}
            continue
        mean = sum(rates) / len(rates)
        annualized = round(mean * 3 * 365 * 100, 2)
        result[key] = {"annualized": annualized, "count": len(rates)}
    if complete:
        # Never pin a transient upstream failure in the cache for an hour.
        set_cache("stats_ytd", result)
    return result
@app.get("/api/signals/history")
async def get_signals_history(limit: int = 100, user: dict = Depends(get_current_user)):
    """Return the most recent signal push records, newest first.

    Returns:
        {"items": [row dicts]} on success, or
        {"items": [], "error": <message>} when the query fails.
    """
    conn = None
    try:
        conn = sqlite3.connect(DB_PATH)
        conn.row_factory = sqlite3.Row
        rows = conn.execute(
            "SELECT id, symbol, rate, annualized, sent_at, message FROM signal_logs ORDER BY sent_at DESC LIMIT ?",
            (limit,)
        ).fetchall()
        return {"items": [dict(r) for r in rows]}
    except Exception as e:
        return {"items": [], "error": str(e)}
    finally:
        # Runs on both paths, so the handle is not leaked on failure.
        if conn is not None:
            conn.close()
@app.get("/api/history")
async def get_history(user: dict = Depends(get_current_user)):
    """Return the last 7 days of funding-rate history for every symbol.

    Results are cached for 60 seconds. Each entry carries the raw
    fundingTime (ms), the rate as a float and a naive UTC ISO timestamp.

    Raises:
        HTTPException: 502 when Binance cannot be reached or answers with
            a non-200 status.
    """
    cached = get_cache("history", 60)
    if cached:
        return cached
    # time.time() is already a UTC epoch. datetime.utcnow().timestamp() is
    # not: timestamp() treats the naive value as local time, which shifts
    # the window by the server's UTC offset.
    end_time = int(time.time() * 1000)
    start_time = end_time - 7 * 24 * 3600 * 1000
    try:
        async with httpx.AsyncClient(timeout=15, headers=HEADERS) as client:
            tasks = [
                client.get(f"{BINANCE_FAPI}/fundingRate",
                           params={"symbol": s, "startTime": start_time, "endTime": end_time, "limit": 1000})
                for s in SYMBOLS
            ]
            responses = await asyncio.gather(*tasks)
    except httpx.HTTPError as exc:
        raise HTTPException(status_code=502, detail="Binance history request failed") from exc
    result = {}
    for sym, resp in zip(SYMBOLS, responses):
        if resp.status_code != 200:
            raise HTTPException(status_code=502, detail=f"Binance history error for {sym}")
        key = sym.replace("USDT", "")
        result[key] = [
            {"fundingTime": item["fundingTime"], "fundingRate": float(item["fundingRate"]),
             "timestamp": datetime.utcfromtimestamp(item["fundingTime"] / 1000).isoformat()}
            for item in resp.json()
        ]
    set_cache("history", result)
    return result
@app.get("/api/stats")
async def get_stats(user: dict = Depends(get_current_user)):
    """Return 7-day funding-rate statistics per symbol plus a BTC/ETH average.

    For each symbol: "mean7d" (mean rate in percent), "annualized"
    (mean * 3 * 365 * 100) and "count". The "combo" entry is the plain
    average of the BTC and ETH figures. Results are cached for 60 seconds.

    Raises:
        HTTPException: 502 when Binance cannot be reached or answers with
            a non-200 status.
    """
    cached = get_cache("stats", 60)
    if cached:
        return cached
    # time.time() is already a UTC epoch. datetime.utcnow().timestamp() is
    # not: timestamp() treats the naive value as local time, which shifts
    # the window by the server's UTC offset.
    end_time = int(time.time() * 1000)
    start_time = end_time - 7 * 24 * 3600 * 1000
    try:
        async with httpx.AsyncClient(timeout=15, headers=HEADERS) as client:
            tasks = [
                client.get(f"{BINANCE_FAPI}/fundingRate",
                           params={"symbol": s, "startTime": start_time, "endTime": end_time, "limit": 1000})
                for s in SYMBOLS
            ]
            responses = await asyncio.gather(*tasks)
    except httpx.HTTPError as exc:
        raise HTTPException(status_code=502, detail="Binance stats request failed") from exc
    stats = {}
    for sym, resp in zip(SYMBOLS, responses):
        if resp.status_code != 200:
            raise HTTPException(status_code=502, detail=f"Binance stats error for {sym}")
        key = sym.replace("USDT", "")
        rates = [float(item["fundingRate"]) for item in resp.json()]
        if not rates:
            stats[key] = {"mean7d": 0, "annualized": 0, "count": 0}
            continue
        mean = sum(rates) / len(rates)
        annualized = mean * 3 * 365 * 100
        stats[key] = {
            "mean7d": round(mean * 100, 6),
            "annualized": round(annualized, 2),
            "count": len(rates),
        }
    btc = stats.get("BTC", {})
    eth = stats.get("ETH", {})
    stats["combo"] = {
        "mean7d": round((btc.get("mean7d", 0) + eth.get("mean7d", 0)) / 2, 6),
        "annualized": round((btc.get("annualized", 0) + eth.get("annualized", 0)) / 2, 2),
    }
    set_cache("stats", stats)
    return stats
# ─── aggTrades query endpoints ───────────────────────────────────
@app.get("/api/trades/meta")
async def get_trades_meta(user: dict = Depends(get_current_user)):
    """aggTrades collection status: latest agg_id and timestamps per symbol.

    Keys are symbols with "USDT" removed. An empty mapping is returned
    when the meta table cannot be read.
    """
    query = "SELECT symbol, last_agg_id, last_time_ms, updated_at FROM agg_trades_meta"
    conn = sqlite3.connect(DB_PATH)
    conn.row_factory = sqlite3.Row
    try:
        meta_rows = conn.execute(query).fetchall()
    except Exception:
        meta_rows = []
    conn.close()
    return {
        row["symbol"].replace("USDT", ""): {
            "last_agg_id": row["last_agg_id"],
            "last_time_ms": row["last_time_ms"],
            "updated_at": row["updated_at"],
        }
        for row in meta_rows
    }
@app.get("/api/trades/summary")
async def get_trades_summary(
    symbol: str = "BTC",
    start_ms: int = 0,
    end_ms: int = 0,
    interval: str = "1m",
    user: dict = Depends(get_current_user),
):
    """Aggregate stored aggTrades into per-interval buy/sell volume, delta, trade count and VWAP.

    Args:
        symbol: Base asset; "USDT" is appended to form the stored symbol.
        start_ms: Inclusive window start in epoch ms (0 = one hour before end_ms).
        end_ms: Exclusive window end in epoch ms (0 = now).
        interval: 1m | 5m | 15m | 1h; unknown values use 1m.

    Returns:
        {"symbol", "interval", "count", "data"} with one entry per
        non-empty interval, ordered by time_ms.
    """
    if end_ms == 0:
        end_ms = int(time.time() * 1000)
    if start_ms == 0:
        start_ms = end_ms - 3600 * 1000  # default window: one hour
    interval_ms = {"1m": 60000, "5m": 300000, "15m": 900000, "1h": 3600000}.get(interval, 60000)
    sym_full = symbol.upper() + "USDT"
    # Work out which monthly tables the window touches. Compare on
    # (year, month) only: a cursor that keeps the start's time of day would
    # skip the last month whenever the window ends early on the 1st.
    start_dt = _dt.datetime.fromtimestamp(start_ms / 1000, tz=_dt.timezone.utc)
    end_dt = _dt.datetime.fromtimestamp(end_ms / 1000, tz=_dt.timezone.utc)
    months = []
    year, month = start_dt.year, start_dt.month
    while (year, month) <= (end_dt.year, end_dt.month):
        months.append(f"{year:04d}{month:02d}")
        year, month = (year + 1, 1) if month == 12 else (year, month + 1)
    conn = sqlite3.connect(DB_PATH)
    all_rows = []
    try:
        conn.row_factory = sqlite3.Row
        for month_tag in months:
            tname = f"agg_trades_{month_tag}"
            try:
                rows = conn.execute(
                    f"SELECT agg_id, price, qty, time_ms, is_buyer_maker FROM {tname} "
                    f"WHERE symbol = ? AND time_ms >= ? AND time_ms < ? ORDER BY time_ms ASC",
                    (sym_full, start_ms, end_ms)
                ).fetchall()
            except Exception:
                # Best-effort: a month whose table cannot be queried contributes nothing.
                continue
            all_rows.extend(rows)
    finally:
        conn.close()
    # Bucket the trades by interval.
    bars: dict = {}
    for row in all_rows:
        bar_ms = (row["time_ms"] // interval_ms) * interval_ms
        if bar_ms not in bars:
            bars[bar_ms] = {"time_ms": bar_ms, "buy_vol": 0.0, "sell_vol": 0.0,
                            "trade_count": 0, "vwap_num": 0.0, "vwap_den": 0.0, "max_qty": 0.0}
        b = bars[bar_ms]
        qty = float(row["qty"])
        price = float(row["price"])
        if row["is_buyer_maker"] == 0:  # taker was the buyer
            b["buy_vol"] += qty
        else:
            b["sell_vol"] += qty
        b["trade_count"] += 1
        b["vwap_num"] += price * qty
        b["vwap_den"] += qty
        b["max_qty"] = max(b["max_qty"], qty)
    result = []
    for b in sorted(bars.values(), key=lambda x: x["time_ms"]):
        total = b["buy_vol"] + b["sell_vol"]
        result.append({
            "time_ms": b["time_ms"],
            "buy_vol": round(b["buy_vol"], 4),
            "sell_vol": round(b["sell_vol"], 4),
            "delta": round(b["buy_vol"] - b["sell_vol"], 4),
            "total_vol": round(total, 4),
            "trade_count": b["trade_count"],
            "vwap": round(b["vwap_num"] / b["vwap_den"], 2) if b["vwap_den"] > 0 else 0,
            "max_qty": round(b["max_qty"], 4),
        })
    return {"symbol": symbol, "interval": interval, "count": len(result), "data": result}
@app.get("/api/trades/latest")
async def get_trades_latest(
    symbol: str = "BTC",
    limit: int = 30,
    user: dict = Depends(get_current_user),
):
    """Return the newest *limit* raw trades for *symbol* from the local database.

    Only the table for the current UTC month is read; the result is empty
    when that table cannot be queried.
    """
    pair = symbol.upper() + "USDT"
    month_tag = _dt.datetime.now(_dt.timezone.utc).strftime("%Y%m")
    table = f"agg_trades_{month_tag}"
    query = (
        f"SELECT agg_id, price, qty, time_ms, is_buyer_maker FROM {table} "
        f"WHERE symbol = ? ORDER BY agg_id DESC LIMIT ?"
    )
    conn = sqlite3.connect(DB_PATH)
    conn.row_factory = sqlite3.Row
    try:
        trades = conn.execute(query, (pair, limit)).fetchall()
    except Exception:
        trades = []
    conn.close()
    payload = [dict(t) for t in trades]
    return {"symbol": symbol, "count": len(trades), "data": payload}
# NOTE(review): this handler carried no route decorator and was therefore
# unreachable over HTTP. The path below follows the function name — confirm
# it against whatever client is expected to call it.
@app.get("/api/collector/health")
async def collector_health(user: dict = Depends(get_current_user)):
    """Report collector health per symbol, based on agg_trades_meta.

    A symbol counts as healthy when its last recorded trade is less than
    30 seconds old. If the meta table cannot be read, whatever was
    gathered so far (possibly nothing) is returned.

    Returns:
        {"collector": {SYM: {"last_agg_id", "lag_seconds", "healthy"}},
         "timestamp": <now in epoch ms>}
    """
    now_ms = int(time.time() * 1000)
    status = {}
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.row_factory = sqlite3.Row
        rows = conn.execute(
            "SELECT symbol, last_agg_id, last_time_ms FROM agg_trades_meta"
        ).fetchall()
        for r in rows:
            sym = r["symbol"].replace("USDT", "")
            lag_s = (now_ms - r["last_time_ms"]) / 1000
            status[sym] = {
                "last_agg_id": r["last_agg_id"],
                "lag_seconds": round(lag_s, 1),
                "healthy": lag_s < 30,  # data within the last 30 seconds
            }
    except Exception:
        # Best-effort endpoint: an unreadable meta table yields a partial or empty report.
        pass
    finally:
        conn.close()
    return {"collector": status, "timestamp": now_ms}
# ─── V5 Signal Engine API ────────────────────────────────────────
@app.get("/api/signals/indicators")
async def get_signal_indicators(
    symbol: str = "BTC",
    minutes: int = 60,
    user: dict = Depends(get_current_user),
):
    """Return signal_indicators_1m rows for *symbol* over the last *minutes* minutes (chart data).

    Rows are ordered by ts ascending; the result is empty when the table
    cannot be queried.
    """
    pair = symbol.upper() + "USDT"
    cutoff_ms = int(time.time() * 1000) - minutes * 60 * 1000
    query = (
        "SELECT ts, cvd_fast, cvd_mid, cvd_day, atr_5m, vwap_30m, price, score, signal "
        "FROM signal_indicators_1m WHERE symbol = ? AND ts >= ? ORDER BY ts ASC"
    )
    conn = sqlite3.connect(DB_PATH)
    conn.row_factory = sqlite3.Row
    try:
        found = conn.execute(query, (pair, cutoff_ms)).fetchall()
    except Exception:
        found = []
    conn.close()
    points = [dict(item) for item in found]
    return {"symbol": symbol, "count": len(found), "data": points}
@app.get("/api/signals/latest")
async def get_signal_latest(
    user: dict = Depends(get_current_user),
):
    """Return the most recent signal_indicators row for each tracked symbol.

    Keys are symbols with "USDT" removed. A symbol with no row, or whose
    query fails, is left out of the result.
    """
    conn = sqlite3.connect(DB_PATH)
    result = {}
    try:
        conn.row_factory = sqlite3.Row
        # Iterate the shared SYMBOLS list so this endpoint cannot drift
        # from the symbols used by the rest of the module.
        for sym in SYMBOLS:
            try:
                row = conn.execute(
                    "SELECT ts, cvd_fast, cvd_mid, cvd_day, cvd_fast_slope, atr_5m, atr_percentile, "
                    "vwap_30m, price, p95_qty, p99_qty, score, signal "
                    "FROM signal_indicators WHERE symbol = ? ORDER BY ts DESC LIMIT 1",
                    (sym,)
                ).fetchone()
            except Exception:
                # Best-effort per symbol: skip it and carry on with the others.
                continue
            if row:
                result[sym.replace("USDT", "")] = dict(row)
    finally:
        conn.close()
    return result
@app.get("/api/signals/trades")
async def get_signal_trades(
    status: str = "all",
    limit: int = 50,
    user: dict = Depends(get_current_user),
):
    """Return signal trade records, newest first.

    status="all" returns every record; any other value filters on the
    status column. The result is empty when the table cannot be queried.
    """
    if status == "all":
        query = "SELECT * FROM signal_trades ORDER BY ts_open DESC LIMIT ?"
        args = (limit,)
    else:
        query = "SELECT * FROM signal_trades WHERE status = ? ORDER BY ts_open DESC LIMIT ?"
        args = (status, limit)
    conn = sqlite3.connect(DB_PATH)
    conn.row_factory = sqlite3.Row
    try:
        records = conn.execute(query, args).fetchall()
    except Exception:
        records = []
    conn.close()
    return {"count": len(records), "data": [dict(rec) for rec in records]}