466 lines
17 KiB
Python
466 lines
17 KiB
Python
import asyncio
import datetime as _dt
import logging
import os
import sqlite3
import time
from datetime import datetime, timedelta

import httpx
from fastapi import FastAPI, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware

from auth import router as auth_router, get_current_user, ensure_tables as ensure_auth_tables
|
||
|
||
# ASGI application object served by the web server.
app = FastAPI(title="Arbitrage Engine API")

# CORS is fully open: every origin, HTTP method and request header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# Mount the routes exported by the local ``auth`` module.
app.include_router(auth_router)
|
||
|
||
# Base URL of the Binance USD-M futures public REST API (v1 endpoints).
BINANCE_FAPI = "https://fapi.binance.com/fapi/v1"
# Perpetual contracts tracked by this service. Handlers strip the "USDT"
# suffix to build the "BTC"/"ETH" keys used in responses.
SYMBOLS = [f"{base}USDT" for base in ("BTC", "ETH")]
# Browser-like User-Agent attached to every outbound Binance request.
HEADERS = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}
# SQLite database file, located one directory above this module.
DB_PATH = os.path.join(os.path.dirname(__file__), os.pardir, "arb.db")
|
||
|
||
# Minimal in-process TTL cache. Each entry is stored as
# {"ts": <time.time() when stored>, "data": <payload>}; the TTL is chosen by
# the reader (e.g. rates 3 s, history/stats 60 s, YTD stats 3600 s).
_cache: dict = {}


def get_cache(key: str, ttl: int):
    """Return the payload stored under *key* if it is younger than *ttl* seconds, else None."""
    hit = _cache.get(key)
    if not hit:
        return None
    age = time.time() - hit["ts"]
    return hit["data"] if age < ttl else None


def set_cache(key: str, data):
    """Store *data* under *key*, stamped with the current wall-clock time."""
    _cache[key] = dict(ts=time.time(), data=data)
|
||
|
||
|
||
def init_db():
    """Create the ``rate_snapshots`` table and its ``ts`` index if they do not exist.

    Idempotent thanks to ``IF NOT EXISTS``. The connection is closed in a
    ``finally`` so a failing statement does not leak the handle.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.execute("""
            CREATE TABLE IF NOT EXISTS rate_snapshots (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                ts INTEGER NOT NULL,
                btc_rate REAL NOT NULL,
                eth_rate REAL NOT NULL,
                btc_price REAL NOT NULL,
                eth_price REAL NOT NULL,
                btc_index_price REAL,
                eth_index_price REAL
            )
        """)
        # Every reader filters on ts, so index it.
        conn.execute("CREATE INDEX IF NOT EXISTS idx_rate_snapshots_ts ON rate_snapshots(ts)")
        conn.commit()
    finally:
        conn.close()
|
||
|
||
|
||
def save_snapshot(rates: dict):
    """Insert one BTC/ETH funding-rate + price row into ``rate_snapshots``.

    Args:
        rates: mapping keyed by "BTC" and "ETH"; each value is a dict with
            ``lastFundingRate``, ``markPrice`` and ``indexPrice``. A missing
            symbol or field is stored as 0.0.

    Best-effort: any failure is logged and swallowed so that persisting a
    snapshot never breaks the caller (an API response or the background loop).
    """
    try:
        btc = rates.get("BTC", {})
        eth = rates.get("ETH", {})
        row = (
            int(time.time()),
            float(btc.get("lastFundingRate", 0)),
            float(eth.get("lastFundingRate", 0)),
            float(btc.get("markPrice", 0)),
            float(eth.get("markPrice", 0)),
            float(btc.get("indexPrice", 0)),
            float(eth.get("indexPrice", 0)),
        )
        conn = sqlite3.connect(DB_PATH)
        try:
            conn.execute(
                "INSERT INTO rate_snapshots (ts, btc_rate, eth_rate, btc_price, eth_price, btc_index_price, eth_index_price) VALUES (?,?,?,?,?,?,?)",
                row,
            )
            conn.commit()
        finally:
            # Close even when the INSERT fails (locked DB, missing table, ...).
            conn.close()
    except Exception:
        logging.getLogger(__name__).exception("failed to persist rate snapshot")
|
||
|
||
|
||
async def background_snapshot_loop(interval: float = 2):
    """Poll Binance premiumIndex and persist a snapshot every *interval* seconds, forever.

    Runs independently of frontend traffic.
    - One HTTP client is reused across iterations instead of being rebuilt every tick.
    - A snapshot is written only when every symbol in SYMBOLS was fetched.
      ``save_snapshot`` stores 0.0 for a missing symbol, and a zero price would
      corrupt the OHLC lows computed from ``rate_snapshots``.
    - The SQLite write runs in a worker thread so it does not block the event loop.
    - Errors within an iteration are logged and the loop continues.
    - Cancellation (``asyncio.CancelledError``) is not caught and ends the loop.
    """
    log = logging.getLogger(__name__)
    async with httpx.AsyncClient(timeout=5, headers=HEADERS) as client:
        while True:
            try:
                responses = await asyncio.gather(
                    *(client.get(f"{BINANCE_FAPI}/premiumIndex", params={"symbol": s}) for s in SYMBOLS),
                    return_exceptions=True,
                )
                result = {}
                for sym, resp in zip(SYMBOLS, responses):
                    if isinstance(resp, BaseException) or resp.status_code != 200:
                        continue
                    data = resp.json()
                    result[sym.replace("USDT", "")] = {
                        "lastFundingRate": float(data["lastFundingRate"]),
                        "markPrice": float(data["markPrice"]),
                        "indexPrice": float(data["indexPrice"]),
                    }
                if len(result) == len(SYMBOLS):
                    await asyncio.to_thread(save_snapshot, result)
                else:
                    log.warning("snapshot skipped: fetched %d/%d symbols", len(result), len(SYMBOLS))
            except Exception:
                log.exception("background snapshot iteration failed")
            await asyncio.sleep(interval)
|
||
|
||
|
||
@app.on_event("startup")
async def startup():
    """Create the SQLite schemas and launch the background snapshot poller.

    The task handle is kept on ``app.state``. The asyncio docs warn that the
    event loop holds only weak references to tasks, so an unreferenced task
    can disappear mid-execution.
    """
    init_db()
    ensure_auth_tables()
    app.state.snapshot_task = asyncio.create_task(background_snapshot_loop())
|
||
|
||
|
||
@app.get("/api/health")
async def health():
    """Liveness probe: status "ok" plus the current UTC time.

    ``timestamp`` is a naive ISO-8601 UTC string with no offset suffix. It is
    derived from an aware ``now(timezone.utc)`` because ``datetime.utcnow()``
    is deprecated since Python 3.12. Dropping tzinfo keeps the output format
    unchanged.
    """
    now_utc = datetime.now(_dt.timezone.utc).replace(tzinfo=None)
    return {"status": "ok", "timestamp": now_utc.isoformat()}
|
||
|
||
|
||
@app.get("/api/rates")
async def get_rates():
    """Live mark/index price and current funding rate per symbol (cached 3 s).

    Returns {"BTC": {"symbol", "markPrice", "indexPrice", "lastFundingRate",
    "nextFundingTime", "timestamp"}, "ETH": {...}}. Also schedules a SQLite
    snapshot write in a worker thread without awaiting it.

    Raises:
        HTTPException(502): Binance unreachable, non-200 response, or an
            unexpected payload. Previously transport errors surfaced as a 500.
    """
    cached = get_cache("rates", 3)
    if cached:
        return cached

    try:
        async with httpx.AsyncClient(timeout=10, headers=HEADERS) as client:
            responses = await asyncio.gather(
                *(client.get(f"{BINANCE_FAPI}/premiumIndex", params={"symbol": s}) for s in SYMBOLS)
            )
    except httpx.HTTPError as exc:
        raise HTTPException(status_code=502, detail="Binance request failed") from exc

    result = {}
    for sym, resp in zip(SYMBOLS, responses):
        if resp.status_code != 200:
            raise HTTPException(status_code=502, detail=f"Binance error for {sym}")
        try:
            data = resp.json()
            result[sym.replace("USDT", "")] = {
                "symbol": sym,
                "markPrice": float(data["markPrice"]),
                "indexPrice": float(data["indexPrice"]),
                "lastFundingRate": float(data["lastFundingRate"]),
                "nextFundingTime": data["nextFundingTime"],
                "timestamp": data["time"],
            }
        except (KeyError, TypeError, ValueError) as exc:
            raise HTTPException(status_code=502, detail=f"Binance error for {sym}") from exc

    set_cache("rates", result)
    # Persist asynchronously so the response is not delayed by the SQLite write.
    # NOTE(review): background_snapshot_loop already persists every 2 s, so this
    # adds extra rows on each cache miss; the task handle is also not retained.
    # Confirm both are intended.
    asyncio.create_task(asyncio.to_thread(save_snapshot, result))
    return result
|
||
|
||
|
||
@app.get("/api/snapshots")
async def get_snapshots(hours: int = 24, limit: int = 5000, user: dict = Depends(get_current_user)):
    """Return locally persisted rate snapshots from the last *hours* hours.

    At most *limit* rows are returned: the most recent ones in the window,
    ordered oldest-first. Taking the oldest rows instead would, at one snapshot
    every ~2 s, stop a 24 h query hours short of "now".

    Raises:
        HTTPException(422): negative *limit*. SQLite treats a negative LIMIT as
            "no limit", which would dump the whole window.
    """
    if limit < 0:
        raise HTTPException(status_code=422, detail="limit must be >= 0")
    since = int(time.time()) - hours * 3600
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.row_factory = sqlite3.Row
        rows = conn.execute(
            "SELECT ts, btc_rate, eth_rate, btc_price, eth_price FROM ("
            " SELECT ts, btc_rate, eth_rate, btc_price, eth_price FROM rate_snapshots"
            " WHERE ts >= ? ORDER BY ts DESC LIMIT ?"
            ") ORDER BY ts ASC",
            (since, limit),
        ).fetchall()
    finally:
        conn.close()
    return {
        "count": len(rows),
        "hours": hours,
        "data": [dict(r) for r in rows],
    }
|
||
|
||
|
||
def _kline_bars(rate_col: str, price_col: str, since: int, bar_secs: int, limit: int) -> list:
    """Read snapshots since *since* and bucket them into OHLC bars of *bar_secs* seconds.

    Synchronous (SQLite + pure-Python aggregation), so async callers should run
    it in a worker thread. *rate_col* and *price_col* are interpolated into the
    SQL and must come from a fixed whitelist, never from request input. Returns
    at most *limit* bars, oldest first, with the rate OHLC scaled x10000
    (basis points); prices are left raw.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        rows = conn.execute(
            f"SELECT ts, {rate_col} as rate, {price_col} as price FROM rate_snapshots WHERE ts >= ? ORDER BY ts ASC",
            (since,),
        ).fetchall()
    finally:
        conn.close()

    bars: dict = {}
    for ts, rate, price in rows:
        bar_ts = (ts // bar_secs) * bar_secs
        bar = bars.get(bar_ts)
        if bar is None:
            bars[bar_ts] = {
                "time": bar_ts,
                "open": rate, "high": rate, "low": rate, "close": rate,
                "price_open": price, "price_high": price, "price_low": price, "price_close": price,
            }
            continue
        bar["high"] = max(bar["high"], rate)
        bar["low"] = min(bar["low"], rate)
        bar["close"] = rate
        bar["price_high"] = max(bar["price_high"], price)
        bar["price_low"] = min(bar["price_low"], price)
        bar["price_close"] = price

    data = sorted(bars.values(), key=lambda b: b["time"])[-limit:]
    for bar in data:
        for k in ("open", "high", "low", "close"):
            bar[k] = round(bar[k] * 10000, 4)
    return data


@app.get("/api/kline")
async def get_kline(symbol: str = "BTC", interval: str = "5m", limit: int = 500, user: dict = Depends(get_current_user)):
    """Aggregate funding-rate / mark-price candles from ``rate_snapshots``.

    Args:
        symbol: "BTC" or "ETH", case-insensitive; a trailing "USDT" is accepted.
        interval: 1m | 5m | 30m | 1h | 4h | 8h | 1d | 1w | 1M. Unknown values
            fall back to 5m. "1M" is a fixed 30-day bucket aligned to the Unix
            epoch, not a calendar month.
        limit: maximum number of bars returned.

    Returns:
        {"symbol", "interval", "count", "data": [{time, open, high, low, close,
        price_open, price_high, price_low, price_close}]}. The rate fields are
        in basis points (rate x 10000).

    Raises:
        HTTPException(400): unsupported symbol. Previously anything other than
            BTC silently returned ETH data.
    """
    interval_secs = {
        "1m": 60, "5m": 300, "30m": 1800,
        "1h": 3600, "4h": 14400, "8h": 28800,
        "1d": 86400, "1w": 604800, "1M": 2592000,
    }
    bar_secs = interval_secs.get(interval, 300)

    base = symbol.upper()
    if base.endswith("USDT"):
        base = base[:-4]
    if base not in ("BTC", "ETH"):
        raise HTTPException(status_code=400, detail=f"Unsupported symbol: {symbol}")
    prefix = base.lower()

    if limit <= 0:
        return {"symbol": symbol, "interval": interval, "count": 0, "data": []}

    # Fetch enough raw snapshots to cover `limit` bars.
    since = int(time.time()) - bar_secs * limit
    # Long intervals can scan the whole table; keep that off the event loop.
    data = await asyncio.to_thread(_kline_bars, f"{prefix}_rate", f"{prefix}_price", since, bar_secs, limit)
    return {"symbol": symbol, "interval": interval, "count": len(data), "data": data}
|
||
|
||
|
||
async def _fetch_ytd_funding_rates(client, symbol: str, start_ms: int, end_ms: int, page_size: int = 1000) -> list:
    """Return every fundingRate value for *symbol* in [start_ms, end_ms], following pages.

    Binance returns at most *page_size* records counted from startTime. A full
    year of 8-hour funding (~1095 settlements) therefore needs more than one
    request. Raises RuntimeError on a non-200 response.
    """
    rates = []
    cursor = start_ms
    while cursor <= end_ms:
        resp = await client.get(
            f"{BINANCE_FAPI}/fundingRate",
            params={"symbol": symbol, "startTime": cursor, "endTime": end_ms, "limit": page_size},
        )
        if resp.status_code != 200:
            raise RuntimeError(f"Binance fundingRate HTTP {resp.status_code} for {symbol}")
        page = resp.json()
        rates.extend(float(item["fundingRate"]) for item in page)
        if len(page) < page_size:
            break
        # Resume just after the last settlement returned.
        cursor = page[-1]["fundingTime"] + 1
    return rates


@app.get("/api/stats/ytd")
async def get_stats_ytd(user: dict = Depends(get_current_user)):
    """Year-to-date (YTD) annualized funding-rate statistics per symbol (cached 1 h).

    ``annualized`` = mean rate x 3 settlements/day x 365 x 100, i.e. a percentage
    that assumes 8-hour funding. A symbol that could not be fetched, or has no
    data, reports {"annualized": 0, "count": 0}. Results containing a fetch
    failure are returned but not cached, so a transient error is not pinned
    for an hour.
    """
    cached = get_cache("stats_ytd", 3600)
    if cached:
        return cached

    # Jan 1 00:00 UTC of the current year. Use an aware datetime: .timestamp()
    # on a naive one would interpret it in the server's local timezone.
    utc = _dt.timezone.utc
    year_start = int(_dt.datetime(_dt.datetime.now(utc).year, 1, 1, tzinfo=utc).timestamp() * 1000)
    end_time = int(time.time() * 1000)

    async with httpx.AsyncClient(timeout=20, headers=HEADERS) as client:
        fetched = await asyncio.gather(
            *(_fetch_ytd_funding_rates(client, s, year_start, end_time) for s in SYMBOLS),
            return_exceptions=True,
        )

    result = {}
    complete = True
    for sym, rates in zip(SYMBOLS, fetched):
        key = sym.replace("USDT", "")
        if isinstance(rates, BaseException):
            complete = False
            result[key] = {"annualized": 0, "count": 0}
            continue
        if not rates:
            result[key] = {"annualized": 0, "count": 0}
            continue
        mean = sum(rates) / len(rates)
        result[key] = {"annualized": round(mean * 3 * 365 * 100, 2), "count": len(rates)}

    if complete:
        set_cache("stats_ytd", result)
    return result
|
||
|
||
|
||
@app.get("/api/signals/history")
async def get_signals_history(limit: int = 100, user: dict = Depends(get_current_user)):
    """Return up to *limit* rows from ``signal_logs``, newest ``sent_at`` first.

    On any error (e.g. the table does not exist) the response is
    {"items": [], "error": <message>}. The connection is closed in a
    ``finally``; previously a failing query skipped ``close()`` and leaked it.
    """
    try:
        conn = sqlite3.connect(DB_PATH)
        try:
            conn.row_factory = sqlite3.Row
            rows = conn.execute(
                "SELECT id, symbol, rate, annualized, sent_at, message FROM signal_logs ORDER BY sent_at DESC LIMIT ?",
                (limit,)
            ).fetchall()
        finally:
            conn.close()
        return {"items": [dict(r) for r in rows]}
    except Exception as e:
        return {"items": [], "error": str(e)}
|
||
|
||
|
||
@app.get("/api/history")
async def get_history(user: dict = Depends(get_current_user)):
    """Last 7 days of Binance funding-rate settlements per symbol (cached 60 s).

    Response: {"BTC": [{"fundingTime": <epoch ms>, "fundingRate": float,
    "timestamp": <naive UTC ISO-8601>}, ...], "ETH": [...]}.

    Raises:
        HTTPException(502): Binance unreachable or returned a non-200 status.
    """
    cached = get_cache("history", 60)
    if cached:
        return cached

    # Build the window from epoch time. datetime.utcnow().timestamp() treats the
    # naive UTC value as local time and shifts the window on non-UTC hosts.
    end_time = int(time.time() * 1000)
    start_time = end_time - 7 * 24 * 3600 * 1000
    utc = _dt.timezone.utc

    try:
        async with httpx.AsyncClient(timeout=15, headers=HEADERS) as client:
            responses = await asyncio.gather(*(
                client.get(f"{BINANCE_FAPI}/fundingRate",
                           params={"symbol": s, "startTime": start_time, "endTime": end_time, "limit": 1000})
                for s in SYMBOLS
            ))
    except httpx.HTTPError as exc:
        raise HTTPException(status_code=502, detail="Binance history request failed") from exc

    result = {}
    for sym, resp in zip(SYMBOLS, responses):
        if resp.status_code != 200:
            raise HTTPException(status_code=502, detail=f"Binance history error for {sym}")
        result[sym.replace("USDT", "")] = [
            {
                "fundingTime": item["fundingTime"],
                "fundingRate": float(item["fundingRate"]),
                # Naive UTC ISO string, the same format utcfromtimestamp() (deprecated) produced.
                "timestamp": datetime.fromtimestamp(item["fundingTime"] / 1000, utc).replace(tzinfo=None).isoformat(),
            }
            for item in resp.json()
        ]
    set_cache("history", result)
    return result
|
||
|
||
|
||
@app.get("/api/stats")
async def get_stats(user: dict = Depends(get_current_user)):
    """7-day funding-rate statistics per symbol plus an equal-weight BTC/ETH "combo" (cached 60 s).

    Per symbol:
    - ``mean7d`` is the mean funding rate in percent.
    - ``annualized`` is mean x 3 settlements/day x 365, in percent (assumes 8-hour funding).
    - ``count`` is the number of settlements.
    ``combo`` is the simple average of the BTC and ETH figures.

    Raises:
        HTTPException(502): Binance unreachable or returned a non-200 status.
    """
    cached = get_cache("stats", 60)
    if cached:
        return cached

    # Epoch-based window; datetime.utcnow().timestamp() would misread the naive
    # UTC value as local time and shift the window by the host's UTC offset.
    end_time = int(time.time() * 1000)
    start_time = end_time - 7 * 24 * 3600 * 1000

    try:
        async with httpx.AsyncClient(timeout=15, headers=HEADERS) as client:
            responses = await asyncio.gather(*(
                client.get(f"{BINANCE_FAPI}/fundingRate",
                           params={"symbol": s, "startTime": start_time, "endTime": end_time, "limit": 1000})
                for s in SYMBOLS
            ))
    except httpx.HTTPError as exc:
        raise HTTPException(status_code=502, detail="Binance stats request failed") from exc

    stats = {}
    for sym, resp in zip(SYMBOLS, responses):
        if resp.status_code != 200:
            raise HTTPException(status_code=502, detail=f"Binance stats error for {sym}")
        key = sym.replace("USDT", "")
        rates = [float(item["fundingRate"]) for item in resp.json()]
        if not rates:
            stats[key] = {"mean7d": 0, "annualized": 0, "count": 0}
            continue
        mean = sum(rates) / len(rates)
        stats[key] = {
            "mean7d": round(mean * 100, 6),
            "annualized": round(mean * 3 * 365 * 100, 2),
            "count": len(rates),
        }

    btc = stats.get("BTC", {})
    eth = stats.get("ETH", {})
    stats["combo"] = {
        "mean7d": round((btc.get("mean7d", 0) + eth.get("mean7d", 0)) / 2, 6),
        "annualized": round((btc.get("annualized", 0) + eth.get("annualized", 0)) / 2, 2),
    }
    set_cache("stats", stats)
    return stats
|
||
|
||
|
||
# ─── aggTrades 查询接口 ──────────────────────────────────────────
|
||
|
||
@app.get("/api/trades/meta")
async def get_trades_meta(user: dict = Depends(get_current_user)):
    """Report aggTrades collection progress per symbol.

    Returns {"BTC": {"last_agg_id", "last_time_ms", "updated_at"}, ...}, keyed
    by the symbol with its "USDT" suffix removed. Returns an empty dict when
    reading ``agg_trades_meta`` fails (e.g. the table does not exist).
    """
    conn = sqlite3.connect(DB_PATH)
    conn.row_factory = sqlite3.Row
    try:
        meta_rows = conn.execute(
            "SELECT symbol, last_agg_id, last_time_ms, updated_at FROM agg_trades_meta"
        ).fetchall()
    except Exception:
        meta_rows = []
    conn.close()

    fields = ("last_agg_id", "last_time_ms", "updated_at")
    return {
        row["symbol"].replace("USDT", ""): {name: row[name] for name in fields}
        for row in meta_rows
    }
|
||
|
||
|
||
def _agg_trade_months(start_ms: int, end_ms: int) -> list:
    """Return the "YYYYMM" suffixes of every monthly agg_trades table touching [start_ms, end_ms], in order."""
    start = _dt.datetime.fromtimestamp(start_ms / 1000, tz=_dt.timezone.utc)
    end = _dt.datetime.fromtimestamp(end_ms / 1000, tz=_dt.timezone.utc)
    year, month = start.year, start.month
    months = []
    # Compare (year, month) pairs, not datetimes. A cursor that keeps the
    # start's time-of-day can overshoot an end early on the 1st of a month and
    # skip that month's table.
    while (year, month) <= (end.year, end.month):
        months.append(f"{year:04d}{month:02d}")
        year, month = (year + 1, 1) if month == 12 else (year, month + 1)
    return months


@app.get("/api/trades/summary")
async def get_trades_summary(
    symbol: str = "BTC",
    start_ms: int = 0,
    end_ms: int = 0,
    interval: str = "1m",
    user: dict = Depends(get_current_user),
):
    """Bucket aggTrades into bars: taker buy/sell volume, delta, trade count, VWAP and max size.

    Args:
        symbol: base asset; upper-cased and suffixed with "USDT" to match stored rows.
        start_ms, end_ms: half-open window [start_ms, end_ms) in epoch ms.
            end_ms=0 means now; start_ms=0 means one hour before end_ms.
        interval: 1m | 5m | 15m | 1h; unknown values fall back to 1m.

    Trades are read from monthly tables ``agg_trades_YYYYMM``. A month whose
    table cannot be queried is skipped. Bars with no trades are omitted.
    """
    if end_ms == 0:
        end_ms = int(time.time() * 1000)
    if start_ms == 0:
        start_ms = end_ms - 3600 * 1000  # default: last hour

    interval_ms = {"1m": 60000, "5m": 300000, "15m": 900000, "1h": 3600000}.get(interval, 60000)
    sym_full = symbol.upper() + "USDT"

    conn = sqlite3.connect(DB_PATH)
    conn.row_factory = sqlite3.Row
    all_rows = []
    try:
        for month in _agg_trade_months(start_ms, end_ms):
            # Table name is built from digits only, so interpolation is safe.
            tname = f"agg_trades_{month}"
            try:
                rows = conn.execute(
                    f"SELECT agg_id, price, qty, time_ms, is_buyer_maker FROM {tname} "
                    f"WHERE symbol = ? AND time_ms >= ? AND time_ms < ? ORDER BY time_ms ASC",
                    (sym_full, start_ms, end_ms)
                ).fetchall()
            except sqlite3.Error:
                continue  # e.g. no table for that month
            all_rows.extend(rows)
    finally:
        conn.close()

    bars: dict = {}
    for row in all_rows:
        bar_ms = (row["time_ms"] // interval_ms) * interval_ms
        if bar_ms not in bars:
            bars[bar_ms] = {"time_ms": bar_ms, "buy_vol": 0.0, "sell_vol": 0.0,
                            "trade_count": 0, "vwap_num": 0.0, "vwap_den": 0.0, "max_qty": 0.0}
        b = bars[bar_ms]
        qty = float(row["qty"])
        price = float(row["price"])
        if row["is_buyer_maker"] == 0:  # buyer is taker: aggressive buy
            b["buy_vol"] += qty
        else:
            b["sell_vol"] += qty
        b["trade_count"] += 1
        b["vwap_num"] += price * qty
        b["vwap_den"] += qty
        b["max_qty"] = max(b["max_qty"], qty)

    result = []
    for b in sorted(bars.values(), key=lambda x: x["time_ms"]):
        total = b["buy_vol"] + b["sell_vol"]
        result.append({
            "time_ms": b["time_ms"],
            "buy_vol": round(b["buy_vol"], 4),
            "sell_vol": round(b["sell_vol"], 4),
            "delta": round(b["buy_vol"] - b["sell_vol"], 4),
            "total_vol": round(total, 4),
            "trade_count": b["trade_count"],
            "vwap": round(b["vwap_num"] / b["vwap_den"], 2) if b["vwap_den"] > 0 else 0,
            "max_qty": round(b["max_qty"], 4),
        })

    return {"symbol": symbol, "interval": interval, "count": len(result), "data": result}
|
||
|
||
|
||
@app.get("/api/collector/health")
async def collector_health(user: dict = Depends(get_current_user)):
    """Collector liveness: per-symbol lag between now and the newest collected trade.

    A symbol counts as healthy when its newest trade is under 30 seconds old.
    Any failure while reading ``agg_trades_meta`` is swallowed; symbols
    processed before the failure are still reported.
    """
    conn = sqlite3.connect(DB_PATH)
    conn.row_factory = sqlite3.Row
    now_ms = int(time.time() * 1000)
    healthy_within_s = 30
    status = {}
    try:
        meta_rows = conn.execute(
            "SELECT symbol, last_agg_id, last_time_ms FROM agg_trades_meta"
        ).fetchall()
        for row in meta_rows:
            short_sym = row["symbol"].replace("USDT", "")
            lag = (now_ms - row["last_time_ms"]) / 1000
            status[short_sym] = dict(
                last_agg_id=row["last_agg_id"],
                lag_seconds=round(lag, 1),
                healthy=lag < healthy_within_s,
            )
    except Exception:
        pass
    conn.close()
    return {"collector": status, "timestamp": now_ms}
|