arbitrage-engine/backend/main.py

280 lines
9.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import logging
import os
import sqlite3
import time
from datetime import datetime, timedelta

import httpx
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
app = FastAPI(title="Arbitrage Engine API")

# CORS: every origin, method and header is accepted.
# NOTE(review): wildcard origins are only safe while no cookies/credentials are
# used; confirm this is intended for deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Binance futures REST API base URL and the perpetual symbols tracked here.
BINANCE_FAPI = "https://fapi.binance.com/fapi/v1"
SYMBOLS = ["BTCUSDT", "ETHUSDT"]

# Browser-like User-Agent sent with every outgoing Binance request.
HEADERS = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}

# SQLite database file, located one directory above this module's directory.
DB_PATH = os.path.join(
    os.path.dirname(__file__),
    "..",
    "arb.db",
)
# Simple in-memory TTL cache (history/stats are cached for 60 s, rates for 3 s).
_cache: dict = {}


def get_cache(key: str, ttl: int):
    """Return the value cached under *key* if it is younger than *ttl* seconds.

    Returns None when the key was never set or the entry has expired.
    """
    hit = _cache.get(key)
    if not hit:
        return None
    age = time.time() - hit["ts"]
    return hit["data"] if age < ttl else None


def set_cache(key: str, data):
    """Store *data* under *key*, stamped with the current wall-clock time."""
    _cache[key] = dict(ts=time.time(), data=data)
def init_db(db_path=None):
    """Create the ``rate_snapshots`` table and its ``ts`` index if missing.

    Safe to call repeatedly (``IF NOT EXISTS``). The connection is always
    closed, even if schema creation fails.

    Args:
        db_path: SQLite file to initialize; defaults to the module's DB_PATH.
    """
    conn = sqlite3.connect(DB_PATH if db_path is None else db_path)
    try:
        conn.execute("""
            CREATE TABLE IF NOT EXISTS rate_snapshots (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                ts INTEGER NOT NULL,
                btc_rate REAL NOT NULL,
                eth_rate REAL NOT NULL,
                btc_price REAL NOT NULL,
                eth_price REAL NOT NULL,
                btc_index_price REAL,
                eth_index_price REAL
            )
        """)
        # Every read path filters/orders by ts, so index it.
        conn.execute("CREATE INDEX IF NOT EXISTS idx_rate_snapshots_ts ON rate_snapshots(ts)")
        conn.commit()
    finally:
        conn.close()
def save_snapshot(rates: dict, db_path=None):
    """Insert one BTC/ETH funding-rate and price snapshot into ``rate_snapshots``.

    Best-effort: any failure (bad values, unreachable DB, ...) is logged as a
    warning and swallowed, so API responses and the poller are never broken
    by persistence. The connection is always closed.

    Args:
        rates: mapping with optional "BTC"/"ETH" entries, each a dict that may
            contain "lastFundingRate", "markPrice" and "indexPrice"; missing
            values are stored as 0.
        db_path: SQLite file to write to; defaults to the module's DB_PATH.
    """
    conn = None
    try:
        btc = rates.get("BTC", {})
        eth = rates.get("ETH", {})
        row = (
            int(time.time()),
            float(btc.get("lastFundingRate", 0)),
            float(eth.get("lastFundingRate", 0)),
            float(btc.get("markPrice", 0)),
            float(eth.get("markPrice", 0)),
            float(btc.get("indexPrice", 0)),
            float(eth.get("indexPrice", 0)),
        )
        conn = sqlite3.connect(DB_PATH if db_path is None else db_path)
        conn.execute(
            "INSERT INTO rate_snapshots (ts, btc_rate, eth_rate, btc_price, eth_price, btc_index_price, eth_index_price) VALUES (?,?,?,?,?,?,?)",
            row,
        )
        conn.commit()
    except Exception:
        # Persistence must never break the API response, but don't hide it either.
        logging.getLogger(__name__).warning("Failed to save rate snapshot", exc_info=True)
    finally:
        if conn is not None:
            conn.close()
async def background_snapshot_loop(interval: float = 2.0):
    """Poll Binance ``premiumIndex`` for every symbol and persist a snapshot.

    Runs forever (until cancelled), independently of frontend traffic,
    sleeping *interval* seconds between polls.

    - One HTTP client is reused for all iterations, so connections are kept
      alive instead of being re-established on every poll.
    - A snapshot is written only when every symbol in SYMBOLS was fetched.
      The table's rate/price columns are NOT NULL and ``save_snapshot``
      stores missing values as 0, so a partial result would insert fake zero
      prices that distort later min/max aggregation.
    - The SQLite write runs in a worker thread so it does not block the loop.
    - Unexpected errors are logged and the loop keeps going.
    """
    log = logging.getLogger(__name__)
    async with httpx.AsyncClient(timeout=5, headers=HEADERS) as client:
        while True:
            try:
                responses = await asyncio.gather(
                    *(client.get(f"{BINANCE_FAPI}/premiumIndex", params={"symbol": s}) for s in SYMBOLS),
                    return_exceptions=True,
                )
                result = {}
                for sym, resp in zip(SYMBOLS, responses):
                    if isinstance(resp, BaseException) or resp.status_code != 200:
                        continue
                    data = resp.json()
                    result[sym.replace("USDT", "")] = {
                        "lastFundingRate": float(data["lastFundingRate"]),
                        "markPrice": float(data["markPrice"]),
                        "indexPrice": float(data["indexPrice"]),
                    }
                if len(result) == len(SYMBOLS):
                    await asyncio.to_thread(save_snapshot, result)
                else:
                    log.warning("Skipping snapshot: fetched %d/%d symbols", len(result), len(SYMBOLS))
            except Exception:
                log.warning("Snapshot poll failed", exc_info=True)
            await asyncio.sleep(interval)
@app.on_event("startup")
async def startup():
    """Initialize the SQLite schema and launch the background snapshot poller.

    The poller task is kept on ``app.state``. The event loop holds only weak
    references to tasks, so an unreferenced task may be garbage-collected
    before it finishes.
    """
    init_db()
    app.state.snapshot_task = asyncio.create_task(background_snapshot_loop())
@app.get("/api/health")
async def health():
    """Liveness probe.

    Returns {"status": "ok", "timestamp": <current UTC time, naive ISO-8601>}.
    """
    stamp = datetime.utcnow().isoformat()
    return dict(status="ok", timestamp=stamp)
# Strong references to in-flight fire-and-forget snapshot writes. The event loop
# keeps only weak references to tasks, so an unreferenced task can be
# garbage-collected before it completes.
_pending_snapshot_saves: set = set()


@app.get("/api/rates")
async def get_rates():
    """Return live mark/index prices and funding rates for every tracked symbol.

    Results are cached in memory for 3 seconds. On a cache miss, every symbol
    is fetched concurrently from Binance ``premiumIndex``. The result is then
    cached, and a snapshot write is scheduled in a worker thread so SQLite
    does not delay the response.

    Returns:
        dict keyed by base asset ("BTC", "ETH"). Each value has symbol,
        markPrice, indexPrice, lastFundingRate, nextFundingTime and timestamp.

    Raises:
        HTTPException(502): a Binance request failed at the transport level
            (timeout, connection error) or returned a non-200 status.
    """
    cached = get_cache("rates", 3)
    if cached:
        return cached
    async with httpx.AsyncClient(timeout=10, headers=HEADERS) as client:
        # return_exceptions=True lets every request settle before the client is
        # closed, instead of abandoning siblings when the first one fails.
        responses = await asyncio.gather(
            *(client.get(f"{BINANCE_FAPI}/premiumIndex", params={"symbol": s}) for s in SYMBOLS),
            return_exceptions=True,
        )
    result = {}
    for sym, resp in zip(SYMBOLS, responses):
        if isinstance(resp, httpx.HTTPError):
            # Upstream network failure: report as a bad gateway, not a server bug.
            raise HTTPException(status_code=502, detail=f"Binance error for {sym}") from resp
        if isinstance(resp, BaseException):
            raise resp
        if resp.status_code != 200:
            raise HTTPException(status_code=502, detail=f"Binance error for {sym}")
        data = resp.json()
        result[sym.replace("USDT", "")] = {
            "symbol": sym,
            "markPrice": float(data["markPrice"]),
            "indexPrice": float(data["indexPrice"]),
            "lastFundingRate": float(data["lastFundingRate"]),
            "nextFundingTime": data["nextFundingTime"],
            "timestamp": data["time"],
        }
    set_cache("rates", result)
    # Persist asynchronously (does not block the response); keep a reference
    # until the write finishes.
    task = asyncio.create_task(asyncio.to_thread(save_snapshot, result))
    _pending_snapshot_saves.add(task)
    task.add_done_callback(_pending_snapshot_saves.discard)
    return result
@app.get("/api/snapshots")
async def get_snapshots(hours: int = 24, limit: int = 5000):
    """Return locally stored real-time snapshots from the last *hours* hours.

    If the window contains more than *limit* rows, the most recent *limit*
    rows are returned. Rows are always in ascending ``ts`` order.

    Args:
        hours: look-back window, in hours.
        limit: maximum number of rows returned; negative values are treated as 0.

    Returns:
        {"count", "hours", "data"}. Each data row has ts, btc_rate, eth_rate,
        btc_price and eth_price.
    """
    since = int(time.time()) - hours * 3600
    # SQLite interprets a negative LIMIT as "no upper bound", which would dump
    # the whole table.
    limit = max(limit, 0)

    def _fetch():
        conn = sqlite3.connect(DB_PATH)
        try:
            conn.row_factory = sqlite3.Row
            # Newest-first so LIMIT trims the *oldest* rows. With ASC + LIMIT, a
            # window larger than the limit would never reach the present.
            return conn.execute(
                "SELECT ts, btc_rate, eth_rate, btc_price, eth_price FROM rate_snapshots "
                "WHERE ts >= ? ORDER BY ts DESC, id DESC LIMIT ?",
                (since, limit),
            ).fetchall()
        finally:
            conn.close()

    # Run the blocking SQLite query off the event loop.
    rows = await asyncio.to_thread(_fetch)
    rows.reverse()  # back to ascending time order for the client
    return {
        "count": len(rows),
        "hours": hours,
        "data": [dict(r) for r in rows],
    }
def _build_kline_bars(rate_col: str, price_col: str, bar_secs: int, since: int) -> list:
    """Load snapshots with ``ts >= since`` and aggregate them into OHLC bars.

    Synchronous (SQLite read plus Python aggregation); intended to be run off
    the event loop. Bars are bucketed on ``(ts // bar_secs) * bar_secs``, i.e.
    aligned to the Unix epoch, and returned oldest first. Rates are unscaled.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        # rate_col/price_col are interpolated into SQL, so callers must pass
        # names from a fixed whitelist, never raw request input.
        rows = conn.execute(
            f"SELECT ts, {rate_col} AS rate, {price_col} AS price "
            "FROM rate_snapshots WHERE ts >= ? ORDER BY ts ASC",
            (since,),
        ).fetchall()
    finally:
        conn.close()

    bars: dict = {}
    for ts, rate, price in rows:
        bar_ts = (ts // bar_secs) * bar_secs
        bar = bars.get(bar_ts)
        if bar is None:
            bars[bar_ts] = {
                "time": bar_ts,
                "open": rate, "high": rate, "low": rate, "close": rate,
                "price_open": price, "price_high": price,
                "price_low": price, "price_close": price,
            }
        else:
            bar["high"] = max(bar["high"], rate)
            bar["low"] = min(bar["low"], rate)
            bar["close"] = rate
            bar["price_high"] = max(bar["price_high"], price)
            bar["price_low"] = min(bar["price_low"], price)
            bar["price_close"] = price
    return sorted(bars.values(), key=lambda b: b["time"])


@app.get("/api/kline")
async def get_kline(symbol: str = "BTC", interval: str = "5m", limit: int = 500):
    """Aggregate stored snapshots into K-line (OHLC) bars.

    Args:
        symbol: "BTC" (case-insensitive) selects the BTC columns; any other
            value selects the ETH columns.
        interval: 1m | 5m | 30m | 1h | 4h | 8h | 1d | 1w | 1M. Unknown values
            fall back to 5m. "1M" is a fixed 30-day bucket, not a calendar
            month, and buckets are aligned to the Unix epoch.
        limit: maximum number of bars returned (the most recent ones).

    Returns:
        {"symbol", "interval", "count", "data"}, where each bar has:
        - time: bucket start;
        - open/high/low/close: the funding rate in basis points
          (rate x 10000, rounded to 4 dp);
        - price_open/price_high/price_low/price_close: the stored mark price.
    """
    interval_secs = {
        "1m": 60, "5m": 300, "30m": 1800,
        "1h": 3600, "4h": 14400, "8h": 28800,
        "1d": 86400, "1w": 604800, "1M": 2592000,
    }
    bar_secs = interval_secs.get(interval, 300)
    is_btc = symbol.upper() == "BTC"
    rate_col = "btc_rate" if is_btc else "eth_rate"
    price_col = "btc_price" if is_btc else "eth_price"
    # Look back far enough for `limit` bars. The oldest bucket is usually
    # partial and falls off when trimming to the last `limit` bars below.
    since = int(time.time()) - bar_secs * limit
    # Wide intervals (e.g. 1M x 500) can scan the whole table, so keep the
    # read and aggregation off the event loop.
    bars = await asyncio.to_thread(_build_kline_bars, rate_col, price_col, bar_secs, since)
    # bars[-0:] would return every bar, so handle limit <= 0 explicitly.
    data = bars[-limit:] if limit > 0 else []
    # Express funding rates per ten-thousand (basis points).
    for bar in data:
        for k in ("open", "high", "low", "close"):
            bar[k] = round(bar[k] * 10000, 4)
    return {"symbol": symbol, "interval": interval, "count": len(data), "data": data}
@app.get("/api/history")
async def get_history():
    """Return the last 7 days of settled funding rates for every tracked symbol.

    Results are cached in memory for 60 seconds.

    Returns:
        dict keyed by base asset ("BTC", "ETH"). Each value is a list of
        {fundingTime (epoch ms), fundingRate (float),
        timestamp (naive UTC ISO-8601)}.

    Raises:
        HTTPException(502): a Binance request failed at the transport level
            or returned a non-200 status.
    """
    cached = get_cache("history", 60)
    if cached:
        return cached
    # Take epoch milliseconds directly from time.time(). Calling .timestamp()
    # on a naive datetime.utcnow() interprets it as *local* time, shifting the
    # query window by the server's UTC offset.
    end_time = int(time.time() * 1000)
    start_time = end_time - 7 * 24 * 3600 * 1000
    async with httpx.AsyncClient(timeout=15, headers=HEADERS) as client:
        responses = await asyncio.gather(
            *(
                client.get(
                    f"{BINANCE_FAPI}/fundingRate",
                    params={"symbol": s, "startTime": start_time, "endTime": end_time, "limit": 1000},
                )
                for s in SYMBOLS
            ),
            return_exceptions=True,
        )
    result = {}
    for sym, resp in zip(SYMBOLS, responses):
        if isinstance(resp, httpx.HTTPError):
            raise HTTPException(status_code=502, detail=f"Binance history error for {sym}") from resp
        if isinstance(resp, BaseException):
            raise resp
        if resp.status_code != 200:
            raise HTTPException(status_code=502, detail=f"Binance history error for {sym}")
        result[sym.replace("USDT", "")] = [
            {
                "fundingTime": item["fundingTime"],
                "fundingRate": float(item["fundingRate"]),
                "timestamp": datetime.utcfromtimestamp(item["fundingTime"] / 1000).isoformat(),
            }
            for item in resp.json()
        ]
    set_cache("history", result)
    return result
@app.get("/api/stats")
async def get_stats():
    """Return 7-day funding-rate statistics per symbol plus a BTC/ETH combo.

    Results are cached in memory for 60 seconds.

    Returns:
        {"BTC": {...}, "ETH": {...}, "combo": {...}}.
        - Each symbol entry has ``mean7d`` (mean funding rate in percent,
          6 dp), ``annualized`` (percent, 2 dp) and ``count`` (number of
          funding events).
        - ``combo`` is the plain average of the BTC and ETH ``mean7d`` /
          ``annualized`` values.

    Raises:
        HTTPException(502): a Binance request failed at the transport level
            or returned a non-200 status.
    """
    cached = get_cache("stats", 60)
    if cached:
        return cached
    # Take epoch milliseconds directly from time.time(). Calling .timestamp()
    # on a naive datetime.utcnow() interprets it as *local* time, shifting the
    # query window by the server's UTC offset.
    end_time = int(time.time() * 1000)
    start_time = end_time - 7 * 24 * 3600 * 1000
    async with httpx.AsyncClient(timeout=15, headers=HEADERS) as client:
        responses = await asyncio.gather(
            *(
                client.get(
                    f"{BINANCE_FAPI}/fundingRate",
                    params={"symbol": s, "startTime": start_time, "endTime": end_time, "limit": 1000},
                )
                for s in SYMBOLS
            ),
            return_exceptions=True,
        )
    stats = {}
    for sym, resp in zip(SYMBOLS, responses):
        if isinstance(resp, httpx.HTTPError):
            raise HTTPException(status_code=502, detail=f"Binance stats error for {sym}") from resp
        if isinstance(resp, BaseException):
            raise resp
        if resp.status_code != 200:
            raise HTTPException(status_code=502, detail=f"Binance stats error for {sym}")
        key = sym.replace("USDT", "")
        rates = [float(item["fundingRate"]) for item in resp.json()]
        if not rates:
            stats[key] = {"mean7d": 0, "annualized": 0, "count": 0}
            continue
        mean = sum(rates) / len(rates)
        # Annualize assuming 3 settlements per day (8-hour funding interval).
        annualized = mean * 3 * 365 * 100
        stats[key] = {
            "mean7d": round(mean * 100, 6),
            "annualized": round(annualized, 2),
            "count": len(rates),
        }
    btc_ann = stats.get("BTC", {}).get("annualized", 0)
    eth_ann = stats.get("ETH", {}).get("annualized", 0)
    btc_mean = stats.get("BTC", {}).get("mean7d", 0)
    eth_mean = stats.get("ETH", {}).get("mean7d", 0)
    stats["combo"] = {
        "mean7d": round((btc_mean + eth_mean) / 2, 6),
        "annualized": round((btc_ann + eth_ann) / 2, 2),
    }
    set_cache("stats", stats)
    return stats