import asyncio
import datetime as _dt
import logging
import os
import sqlite3
import time
from datetime import datetime, timedelta

import httpx
from fastapi import Depends, FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware

from auth import router as auth_router, get_current_user, ensure_tables as ensure_auth_tables

logger = logging.getLogger(__name__)

app = FastAPI(title="Arbitrage Engine API")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

app.include_router(auth_router)

BINANCE_FAPI = "https://fapi.binance.com/fapi/v1"
SYMBOLS = ["BTCUSDT", "ETHUSDT"]
HEADERS = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}

DB_PATH = os.path.join(os.path.dirname(__file__), "..", "arb.db")

# Simple in-memory cache (history/stats: 60 s, rates: 3 s).
# Maps cache key -> {"ts": <epoch seconds when stored>, "data": <payload>}.
_cache: dict = {}


def get_cache(key: str, ttl: int):
    """Return the payload cached under *key* if it is younger than *ttl* seconds.

    Returns None when the key is absent or the entry has expired. Expired
    entries are not evicted here; they are overwritten by the next set_cache.
    """
    entry = _cache.get(key)
    if entry and time.time() - entry["ts"] < ttl:
        return entry["data"]
    return None


def set_cache(key: str, data):
    """Store *data* under *key*, stamped with the current wall-clock time."""
    _cache[key] = {"ts": time.time(), "data": data}


def init_db():
    """Create the rate_snapshots table and its ts index if they do not exist."""
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.execute("""
            CREATE TABLE IF NOT EXISTS rate_snapshots (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                ts INTEGER NOT NULL,
                btc_rate REAL NOT NULL,
                eth_rate REAL NOT NULL,
                btc_price REAL NOT NULL,
                eth_price REAL NOT NULL,
                btc_index_price REAL,
                eth_index_price REAL
            )
        """)
        conn.execute("CREATE INDEX IF NOT EXISTS idx_rate_snapshots_ts ON rate_snapshots(ts)")
        conn.commit()
    finally:
        # Close even when a statement fails so the handle is not leaked.
        conn.close()


def save_snapshot(rates: dict):
    """Persist one funding-rate/price snapshot row; never raises.

    *rates* is expected to map "BTC" and "ETH" to dicts carrying
    "lastFundingRate", "markPrice" and "indexPrice". A missing symbol or
    field is stored as 0.0, so callers that care about data quality should
    only pass complete results.

    Persistence is best-effort: any failure is logged and swallowed so that
    it can never affect the API response that triggered it.
    """
    try:
        btc = rates.get("BTC", {})
        eth = rates.get("ETH", {})
        row = (
            int(time.time()),
            float(btc.get("lastFundingRate", 0)),
            float(eth.get("lastFundingRate", 0)),
            float(btc.get("markPrice", 0)),
            float(eth.get("markPrice", 0)),
            float(btc.get("indexPrice", 0)),
            float(eth.get("indexPrice", 0)),
        )
        conn = sqlite3.connect(DB_PATH)
        try:
            conn.execute(
                "INSERT INTO rate_snapshots (ts, btc_rate, eth_rate, btc_price, eth_price, btc_index_price, eth_index_price) VALUES (?,?,?,?,?,?,?)",
                row,
            )
            conn.commit()
        finally:
            conn.close()
    except Exception:
        # A failed write must not affect the API response; log it instead of
        # hiding it completely.
        logger.warning("failed to persist rate snapshot", exc_info=True)
import logging  # imported here so this section is self-contained

logger = logging.getLogger(__name__)

# asyncio keeps only weak references to tasks, so fire-and-forget tasks are
# parked here until they finish to keep them from being garbage-collected.
_background_tasks: set = set()

# Look-back window used by /api/history and /api/stats.
_WEEK_MS = 7 * 24 * 3600 * 1000

# Bar width in seconds for each interval accepted by /api/kline.
_KLINE_INTERVAL_SECS = {
    "1m": 60,
    "5m": 300,
    "30m": 1800,
    "1h": 3600,
    "4h": 14400,
    "8h": 28800,
    "1d": 86400,
    "1w": 604800,
    "1M": 2592000,
}

# Bar width in milliseconds for each interval accepted by /api/trades/summary.
_TRADE_INTERVAL_MS = {"1m": 60000, "5m": 300000, "15m": 900000, "1h": 3600000}


def _spawn(coro):
    """Schedule *coro* as a task and hold a reference until it completes."""
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return task


def _fetch_rows(sql: str, params: tuple = ()) -> list:
    """Run one read query against DB_PATH and return all rows as sqlite3.Row.

    Blocking; call it through asyncio.to_thread from async code. The
    connection is always closed, including when the query fails.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.row_factory = sqlite3.Row
        return conn.execute(sql, params).fetchall()
    finally:
        conn.close()


async def _binance_get(path: str, params=None, *, timeout: float):
    """GET ``BINANCE_FAPI/<path>`` once per symbol, concurrently.

    Returns the responses in SYMBOLS order. Status codes are left for the
    caller to inspect. Transport failures (timeouts, connection errors) are
    turned into HTTP 502, matching how upstream non-200 replies are reported.
    """
    extra = dict(params or {})
    try:
        async with httpx.AsyncClient(timeout=timeout, headers=HEADERS) as client:
            return await asyncio.gather(
                *(
                    client.get(f"{BINANCE_FAPI}/{path}", params={"symbol": s, **extra})
                    for s in SYMBOLS
                )
            )
    except httpx.HTTPError as exc:
        raise HTTPException(status_code=502, detail="Binance request failed") from exc


def _month_keys(start_ms: int, end_ms: int) -> list:
    """Return the "YYYYMM" key of every UTC month touched by [start_ms, end_ms].

    Months are compared as (year, month) pairs so the time of day of the
    start timestamp can never cause the final month to be skipped.
    """
    start_dt = _dt.datetime.fromtimestamp(start_ms / 1000, tz=_dt.timezone.utc)
    end_dt = _dt.datetime.fromtimestamp(end_ms / 1000, tz=_dt.timezone.utc)
    keys = []
    year, month = start_dt.year, start_dt.month
    while (year, month) <= (end_dt.year, end_dt.month):
        keys.append(f"{year:04d}{month:02d}")
        month += 1
        if month > 12:
            year, month = year + 1, 1
    return keys


def _load_agg_trades(sym_full: str, start_ms: int, end_ms: int) -> list:
    """Read aggTrades for *sym_full* in [start_ms, end_ms) from the monthly tables.

    Blocking; call it through asyncio.to_thread. Tables are named
    ``agg_trades_<YYYYMM>``; a month whose table cannot be queried (for
    example because it does not exist) is skipped.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.row_factory = sqlite3.Row
        rows = []
        for month in _month_keys(start_ms, end_ms):
            # The table name is built only from digits derived from the
            # timestamps, never from user-supplied text.
            tname = f"agg_trades_{month}"
            try:
                rows.extend(
                    conn.execute(
                        f"SELECT agg_id, price, qty, time_ms, is_buyer_maker FROM {tname} "
                        f"WHERE symbol = ? AND time_ms >= ? AND time_ms < ? ORDER BY time_ms ASC",
                        (sym_full, start_ms, end_ms),
                    ).fetchall()
                )
            except sqlite3.Error:
                logger.debug("skipping %s", tname, exc_info=True)
        return rows
    finally:
        conn.close()


def _aggregate_trades(rows, interval_ms: int) -> list:
    """Bucket trade rows into bars of *interval_ms* and summarise each bar."""
    bars: dict = {}
    for row in rows:
        bar_ms = (row["time_ms"] // interval_ms) * interval_ms
        b = bars.get(bar_ms)
        if b is None:
            b = bars[bar_ms] = {
                "time_ms": bar_ms,
                "buy_vol": 0.0,
                "sell_vol": 0.0,
                "trade_count": 0,
                "vwap_num": 0.0,
                "vwap_den": 0.0,
                "max_qty": 0.0,
            }
        qty = float(row["qty"])
        price = float(row["price"])
        if row["is_buyer_maker"] == 0:  # taker was the buyer
            b["buy_vol"] += qty
        else:
            b["sell_vol"] += qty
        b["trade_count"] += 1
        b["vwap_num"] += price * qty
        b["vwap_den"] += qty
        b["max_qty"] = max(b["max_qty"], qty)

    result = []
    for b in sorted(bars.values(), key=lambda x: x["time_ms"]):
        total = b["buy_vol"] + b["sell_vol"]
        result.append({
            "time_ms": b["time_ms"],
            "buy_vol": round(b["buy_vol"], 4),
            "sell_vol": round(b["sell_vol"], 4),
            "delta": round(b["buy_vol"] - b["sell_vol"], 4),
            "total_vol": round(total, 4),
            "trade_count": b["trade_count"],
            "vwap": round(b["vwap_num"] / b["vwap_den"], 2) if b["vwap_den"] > 0 else 0,
            "max_qty": round(b["max_qty"], 4),
        })
    return result


async def background_snapshot_loop():
    """Poll funding rate and prices every 2 seconds and persist them.

    Runs independently of any frontend request. Each iteration is
    best-effort: failures are logged and the loop carries on.
    """
    while True:
        try:
            async with httpx.AsyncClient(timeout=5, headers=HEADERS) as client:
                responses = await asyncio.gather(
                    *(
                        client.get(f"{BINANCE_FAPI}/premiumIndex", params={"symbol": s})
                        for s in SYMBOLS
                    ),
                    return_exceptions=True,
                )
            result = {}
            for sym, resp in zip(SYMBOLS, responses):
                if isinstance(resp, BaseException) or resp.status_code != 200:
                    continue
                data = resp.json()
                result[sym.replace("USDT", "")] = {
                    "lastFundingRate": float(data["lastFundingRate"]),
                    "markPrice": float(data["markPrice"]),
                    "indexPrice": float(data["indexPrice"]),
                }
            # Persist only complete sets: save_snapshot stores 0.0 for a
            # missing symbol, which would inject bogus zero prices/rates
            # into the snapshot history.
            if len(result) == len(SYMBOLS):
                # sqlite I/O is blocking; keep it off the event loop.
                await asyncio.to_thread(save_snapshot, result)
        except Exception:
            logger.warning("background snapshot poll failed", exc_info=True)
        await asyncio.sleep(2)


@app.on_event("startup")
async def startup():
    """Create the tables and start the background snapshot poller."""
    init_db()
    ensure_auth_tables()
    _spawn(background_snapshot_loop())


@app.get("/api/health")
async def health():
    """Liveness probe; the timestamp is naive UTC in ISO-8601 form."""
    return {"status": "ok", "timestamp": datetime.utcnow().isoformat()}


@app.get("/api/rates")
async def get_rates():
    """Return current mark/index price and funding rate per symbol.

    Results are cached for 3 seconds. Raises HTTP 502 when Binance cannot
    be reached or answers with a non-200 status.
    """
    cached = get_cache("rates", 3)
    if cached:
        return cached

    responses = await _binance_get("premiumIndex", timeout=10)
    result = {}
    for sym, resp in zip(SYMBOLS, responses):
        if resp.status_code != 200:
            raise HTTPException(status_code=502, detail=f"Binance error for {sym}")
        data = resp.json()
        result[sym.replace("USDT", "")] = {
            "symbol": sym,
            "markPrice": float(data["markPrice"]),
            "indexPrice": float(data["indexPrice"]),
            "lastFundingRate": float(data["lastFundingRate"]),
            "nextFundingTime": data["nextFundingTime"],
            "timestamp": data["time"],
        }

    set_cache("rates", result)
    # Persist asynchronously so the response is not delayed by the write.
    _spawn(asyncio.to_thread(save_snapshot, result))
    return result


@app.get("/api/snapshots")
async def get_snapshots(hours: int = 24, limit: int = 5000, user: dict = Depends(get_current_user)):
    """Return locally stored snapshots from the last *hours* hours.

    Rows are ordered oldest first and capped at *limit*.
    NOTE(review): because of the ascending order the cap keeps the OLDEST
    rows of the window; confirm that this is what the frontend expects.
    """
    since = int(time.time()) - hours * 3600
    # A negative LIMIT means "no limit" in SQLite, so clamp it at zero.
    rows = await asyncio.to_thread(
        _fetch_rows,
        "SELECT ts, btc_rate, eth_rate, btc_price, eth_price FROM rate_snapshots WHERE ts >= ? ORDER BY ts ASC LIMIT ?",
        (since, max(limit, 0)),
    )
    return {
        "count": len(rows),
        "hours": hours,
        "data": [dict(r) for r in rows],
    }


@app.get("/api/kline")
async def get_kline(symbol: str = "BTC", interval: str = "5m", limit: int = 500, user: dict = Depends(get_current_user)):
    """Aggregate rate_snapshots into OHLC bars.

    symbol: BTC | ETH (any other value is treated as ETH)
    interval: 1m | 5m | 30m | 1h | 4h | 8h | 1d | 1w | 1M
        (an unknown value falls back to 5m)
    Returns bars of {time, open, high, low, close, price_open, price_high,
    price_low, price_close}; the rate fields are scaled by 10000.
    NOTE(review): the fallbacks echo the requested symbol/interval back even
    though different data was used; consider rejecting unknown values.
    """
    bar_secs = _KLINE_INTERVAL_SECS.get(interval, 300)
    # Column names come from this fixed pair, never from user input.
    if symbol.upper() == "BTC":
        rate_col, price_col = "btc_rate", "btc_price"
    else:
        rate_col, price_col = "eth_rate", "eth_price"

    empty = {"symbol": symbol, "interval": interval, "count": 0, "data": []}
    if limit <= 0:
        # Guard: a zero limit would make the [-limit:] slice return every bar.
        return empty

    # Fetch enough raw rows to cover `limit` bars of `bar_secs` each.
    since = int(time.time()) - bar_secs * limit
    rows = await asyncio.to_thread(
        _fetch_rows,
        f"SELECT ts, {rate_col} as rate, {price_col} as price FROM rate_snapshots WHERE ts >= ? ORDER BY ts ASC",
        (since,),
    )
    if not rows:
        return empty

    # Group rows into bar_secs-wide buckets and fold each into OHLC.
    bars: dict = {}
    for row in rows:
        ts, rate, price = row["ts"], row["rate"], row["price"]
        bar_ts = (ts // bar_secs) * bar_secs
        b = bars.get(bar_ts)
        if b is None:
            bars[bar_ts] = {
                "time": bar_ts,
                "open": rate,
                "high": rate,
                "low": rate,
                "close": rate,
                "price_open": price,
                "price_high": price,
                "price_low": price,
                "price_close": price,
            }
            continue
        b["high"] = max(b["high"], rate)
        b["low"] = min(b["low"], rate)
        b["close"] = rate
        b["price_high"] = max(b["price_high"], price)
        b["price_low"] = min(b["price_low"], price)
        b["price_close"] = price

    data = sorted(bars.values(), key=lambda x: x["time"])[-limit:]
    # Express the funding rate in units of 1/10000 (rate x 10000).
    for b in data:
        for k in ("open", "high", "low", "close"):
            b[k] = round(b[k] * 10000, 4)

    return {"symbol": symbol, "interval": interval, "count": len(data), "data": data}


@app.get("/api/stats/ytd")
async def get_stats_ytd(user: dict = Depends(get_current_user)):
    """Year-to-date annualized funding rate per symbol.

    Cached for one hour. A symbol whose upstream call fails is reported with
    zeros, and such a partial answer is not cached.
    NOTE(review): the request asks for at most 1000 records per symbol; at
    three fundings a day that looks insufficient late in the year - confirm
    whether pagination is needed.
    """
    cached = get_cache("stats_ytd", 3600)
    if cached:
        return cached

    # Jan 1st 00:00 UTC of the current year. The datetime must be
    # timezone-aware: a naive one would be interpreted in local time.
    now = _dt.datetime.now(_dt.timezone.utc)
    year_start = int(_dt.datetime(now.year, 1, 1, tzinfo=_dt.timezone.utc).timestamp() * 1000)
    end_time = int(now.timestamp() * 1000)

    responses = await _binance_get(
        "fundingRate",
        {"startTime": year_start, "endTime": end_time, "limit": 1000},
        timeout=20,
    )

    result = {}
    complete = True
    for sym, resp in zip(SYMBOLS, responses):
        key = sym.replace("USDT", "")
        if resp.status_code != 200:
            result[key] = {"annualized": 0, "count": 0}
            complete = False
            continue
        rates = [float(item["fundingRate"]) for item in resp.json()]
        if not rates:
            result[key] = {"annualized": 0, "count": 0}
            continue
        mean = sum(rates) / len(rates)
        # 3 fundings per day x 365 days, expressed as a percentage.
        annualized = round(mean * 3 * 365 * 100, 2)
        result[key] = {"annualized": annualized, "count": len(rates)}

    if complete:
        # Do not pin placeholder zeros from a failed call for a whole hour.
        set_cache("stats_ytd", result)
    return result


@app.get("/api/signals/history")
async def get_signals_history(limit: int = 100, user: dict = Depends(get_current_user)):
    """Return the most recent signal notifications, newest first.

    On any failure (for example the signal_logs table is missing) an empty
    list is returned together with the error text.
    """
    try:
        # A negative LIMIT means "no limit" in SQLite, so clamp it at zero.
        rows = await asyncio.to_thread(
            _fetch_rows,
            "SELECT id, symbol, rate, annualized, sent_at, message FROM signal_logs ORDER BY sent_at DESC LIMIT ?",
            (max(limit, 0),),
        )
        return {"items": [dict(r) for r in rows]}
    except Exception as e:
        logger.warning("signal history query failed", exc_info=True)
        return {"items": [], "error": str(e)}


@app.get("/api/history")
async def get_history(user: dict = Depends(get_current_user)):
    """Return the last 7 days of funding-rate records per symbol.

    Cached for 60 seconds. Raises HTTP 502 when Binance cannot be reached
    or answers with a non-200 status.
    """
    cached = get_cache("history", 60)
    if cached:
        return cached

    # time.time() is a true epoch value; naive utcnow().timestamp() would be
    # shifted by the host's UTC offset.
    end_time = int(time.time() * 1000)
    start_time = end_time - _WEEK_MS

    responses = await _binance_get(
        "fundingRate",
        {"startTime": start_time, "endTime": end_time, "limit": 1000},
        timeout=15,
    )

    result = {}
    for sym, resp in zip(SYMBOLS, responses):
        if resp.status_code != 200:
            raise HTTPException(status_code=502, detail=f"Binance history error for {sym}")
        result[sym.replace("USDT", "")] = [
            {
                "fundingTime": item["fundingTime"],
                "fundingRate": float(item["fundingRate"]),
                "timestamp": datetime.utcfromtimestamp(item["fundingTime"] / 1000).isoformat(),
            }
            for item in resp.json()
        ]

    set_cache("history", result)
    return result


@app.get("/api/stats")
async def get_stats(user: dict = Depends(get_current_user)):
    """Return 7-day mean and annualized funding rate per symbol plus a combo.

    "combo" is the plain average of the BTC and ETH figures. Cached for 60
    seconds. Raises HTTP 502 when Binance cannot be reached or answers with
    a non-200 status.
    """
    cached = get_cache("stats", 60)
    if cached:
        return cached

    # See get_history: use a true epoch value, not naive utcnow().timestamp().
    end_time = int(time.time() * 1000)
    start_time = end_time - _WEEK_MS

    responses = await _binance_get(
        "fundingRate",
        {"startTime": start_time, "endTime": end_time, "limit": 1000},
        timeout=15,
    )

    stats = {}
    for sym, resp in zip(SYMBOLS, responses):
        if resp.status_code != 200:
            raise HTTPException(status_code=502, detail=f"Binance stats error for {sym}")
        key = sym.replace("USDT", "")
        rates = [float(item["fundingRate"]) for item in resp.json()]
        if not rates:
            stats[key] = {"mean7d": 0, "annualized": 0, "count": 0}
            continue
        mean = sum(rates) / len(rates)
        # 3 fundings per day x 365 days, expressed as a percentage.
        annualized = mean * 3 * 365 * 100
        stats[key] = {
            "mean7d": round(mean * 100, 6),
            "annualized": round(annualized, 2),
            "count": len(rates),
        }

    btc = stats.get("BTC", {})
    eth = stats.get("ETH", {})
    stats["combo"] = {
        "mean7d": round((btc.get("mean7d", 0) + eth.get("mean7d", 0)) / 2, 6),
        "annualized": round((btc.get("annualized", 0) + eth.get("annualized", 0)) / 2, 2),
    }

    set_cache("stats", stats)
    return stats


# --- aggTrades query endpoints ---------------------------------------------

@app.get("/api/trades/meta")
async def get_trades_meta(user: dict = Depends(get_current_user)):
    """aggTrades collection state: latest agg_id and time for each symbol.

    Returns an empty dict when the agg_trades_meta table cannot be read.
    """
    try:
        rows = await asyncio.to_thread(
            _fetch_rows,
            "SELECT symbol, last_agg_id, last_time_ms, updated_at FROM agg_trades_meta",
        )
    except sqlite3.Error:
        logger.debug("agg_trades_meta not readable", exc_info=True)
        rows = []

    result = {}
    for r in rows:
        sym = r["symbol"].replace("USDT", "")
        result[sym] = {
            "last_agg_id": r["last_agg_id"],
            "last_time_ms": r["last_time_ms"],
            "updated_at": r["updated_at"],
        }
    return result


@app.get("/api/trades/summary")
async def get_trades_summary(
    symbol: str = "BTC",
    start_ms: int = 0,
    end_ms: int = 0,
    interval: str = "1m",
    user: dict = Depends(get_current_user),
):
    """Per-interval trade aggregation: buy/sell delta, trade count, VWAP.

    end_ms == 0 means "now"; start_ms == 0 means one hour before end_ms.
    interval: 1m | 5m | 15m | 1h (an unknown value falls back to 1m).
    """
    if end_ms == 0:
        end_ms = int(time.time() * 1000)
    if start_ms == 0:
        start_ms = end_ms - 3600 * 1000  # default window: one hour

    interval_ms = _TRADE_INTERVAL_MS.get(interval, 60000)
    sym_full = symbol.upper() + "USDT"

    # Trades live in one table per month; read every month the window touches.
    all_rows = await asyncio.to_thread(_load_agg_trades, sym_full, start_ms, end_ms)
    result = _aggregate_trades(all_rows, interval_ms)

    return {"symbol": symbol, "interval": interval, "count": len(result), "data": result}


@app.get("/api/collector/health")
async def collector_health(user: dict = Depends(get_current_user)):
    """Collector health: per-symbol lag behind the current time.

    A symbol counts as healthy when its last trade is under 30 seconds old.
    Returns an empty "collector" dict when agg_trades_meta cannot be read.
    """
    now_ms = int(time.time() * 1000)
    try:
        rows = await asyncio.to_thread(
            _fetch_rows,
            "SELECT symbol, last_agg_id, last_time_ms FROM agg_trades_meta",
        )
    except sqlite3.Error:
        logger.debug("agg_trades_meta not readable", exc_info=True)
        rows = []

    status = {}
    for r in rows:
        try:
            sym = r["symbol"].replace("USDT", "")
            lag_s = (now_ms - r["last_time_ms"]) / 1000
        except (AttributeError, TypeError):
            # A NULL symbol/time in one row must not hide the other symbols.
            continue
        status[sym] = {
            "last_agg_id": r["last_agg_id"],
            "lag_seconds": round(lag_s, 1),
            "healthy": lag_s < 30,  # data within the last 30 s counts as healthy
        }
    return {"collector": status, "timestamp": now_ms}