"""
agg_trades_collector.py — full-coverage aggTrades collection daemon.

Architecture:
    - WebSocket main path: real-time push, latency < 100 ms
    - REST gap-fill: after a disconnect/reconnect, catch up from last_agg_id
    - Minute-level continuity check: verifies agg_id continuity and
      auto-backfills any gap found
    - Batched writes: flush every 200 rows or 1 second to reduce WAL pressure
    - Monthly table partitioning: agg_trades_YYYYMM keeps each table under
      tens of millions of rows so queries stay fast
    - Health endpoint: GET /collector/health for monitoring
"""

import asyncio
import json
import logging
import os
import sqlite3
import time
from datetime import datetime, timezone
from typing import Optional

import httpx
import websockets

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler(os.path.join(os.path.dirname(__file__), "..", "collector.log")),
    ],
)
logger = logging.getLogger("collector")

DB_PATH = os.path.join(os.path.dirname(__file__), "..", "arb.db")
BINANCE_FAPI = "https://fapi.binance.com/fapi/v1"
SYMBOLS = ["BTCUSDT", "ETHUSDT"]
HEADERS = {"User-Agent": "Mozilla/5.0 ArbitrageEngine/3.0"}

# Write-batching buffer (per symbol).
_buffer: dict[str, list] = {s: [] for s in SYMBOLS}
BATCH_SIZE = 200
BATCH_TIMEOUT = 1.0  # seconds

# Strong references to fire-and-forget tasks: the event loop only keeps a
# weak reference, so an untracked create_task() result can be garbage
# collected before it finishes.
_bg_tasks: set = set()


# ─── DB helpers ──────────────────────────────────────────────────

def get_conn() -> sqlite3.Connection:
    """Open a WAL-mode connection to the collector database."""
    conn = sqlite3.connect(DB_PATH, timeout=30)
    conn.row_factory = sqlite3.Row
    conn.execute("PRAGMA journal_mode=WAL")
    conn.execute("PRAGMA synchronous=NORMAL")
    return conn


def table_name(ts_ms: int) -> str:
    """Monthly partition name for a trade timestamp, e.g. agg_trades_202602 (UTC)."""
    dt = datetime.fromtimestamp(ts_ms / 1000, tz=timezone.utc)
    return f"agg_trades_{dt.strftime('%Y%m')}"


def ensure_table(conn: sqlite3.Connection, tname: str):
    """Create a monthly partition table (and its symbol+time index) if missing.

    NOTE: Binance agg_id values are unique only *per symbol*.  Since one
    partition table holds every symbol, the primary key must be composite —
    a bare `agg_id INTEGER PRIMARY KEY` would make INSERT OR IGNORE silently
    drop rows whenever BTCUSDT and ETHUSDT ids collide.
    """
    conn.execute(f"""
        CREATE TABLE IF NOT EXISTS {tname} (
            agg_id INTEGER NOT NULL,
            symbol TEXT NOT NULL,
            price REAL NOT NULL,
            qty REAL NOT NULL,
            first_trade_id INTEGER,
            last_trade_id INTEGER,
            time_ms INTEGER NOT NULL,
            is_buyer_maker INTEGER NOT NULL,
            PRIMARY KEY (agg_id, symbol)
        )
    """)
    conn.execute(f"""
        CREATE INDEX IF NOT EXISTS idx_{tname}_sym_time
        ON {tname}(symbol, time_ms)
    """)


def ensure_meta_table(conn: sqlite3.Connection):
    """Create the per-symbol progress table (last agg_id seen) if missing."""
    conn.execute("""
        CREATE TABLE IF NOT EXISTS agg_trades_meta (
            symbol TEXT PRIMARY KEY,
            last_agg_id INTEGER NOT NULL,
            last_time_ms INTEGER NOT NULL,
            updated_at TEXT DEFAULT (datetime('now'))
        )
    """)


def get_last_agg_id(symbol: str) -> Optional[int]:
    """Return the newest agg_id recorded for *symbol*, or None if unknown."""
    conn = None
    try:
        conn = get_conn()
        ensure_meta_table(conn)
        row = conn.execute(
            "SELECT last_agg_id FROM agg_trades_meta WHERE symbol = ?", (symbol,)
        ).fetchone()
        return row["last_agg_id"] if row else None
    except Exception as e:
        logger.error(f"get_last_agg_id error: {e}")
        return None
    finally:
        # Close even on the error path — the original leaked the connection.
        if conn is not None:
            conn.close()


def update_meta(conn: sqlite3.Connection, symbol: str, last_agg_id: int, last_time_ms: int):
    """Upsert the per-symbol high-water mark (does not commit)."""
    conn.execute("""
        INSERT INTO agg_trades_meta (symbol, last_agg_id, last_time_ms, updated_at)
        VALUES (?, ?, ?, datetime('now'))
        ON CONFLICT(symbol) DO UPDATE SET
            last_agg_id = excluded.last_agg_id,
            last_time_ms = excluded.last_time_ms,
            updated_at = excluded.updated_at
    """, (symbol, last_agg_id, last_time_ms))


def flush_buffer(symbol: str, trades: list) -> int:
    """Write one batch of raw aggTrade dicts (Binance wire format).

    Returns the number of rows actually inserted (duplicates are dropped by
    INSERT OR IGNORE).  Also advances the meta high-water mark to the largest
    agg_id seen in the batch.
    """
    if not trades:
        return 0
    conn = None
    try:
        conn = get_conn()
        ensure_meta_table(conn)
        # Group rows by their monthly partition.
        by_month: dict[str, list] = {}
        for t in trades:
            by_month.setdefault(table_name(t["T"]), []).append(t)

        inserted = 0
        last_agg_id = 0
        last_time_ms = 0

        for tname, batch in by_month.items():
            ensure_table(conn, tname)
            for t in batch:
                cur = conn.execute(
                    f"""INSERT OR IGNORE INTO {tname}
                        (agg_id, symbol, price, qty, first_trade_id, last_trade_id, time_ms, is_buyer_maker)
                        VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
                    (
                        t["a"],             # agg_id
                        symbol,
                        float(t["p"]),      # price
                        float(t["q"]),      # qty
                        t.get("f"),         # first_trade_id
                        t.get("l"),         # last_trade_id
                        t["T"],             # time_ms
                        1 if t["m"] else 0, # is_buyer_maker
                    )
                )
                if cur.rowcount > 0:
                    inserted += 1
                if t["a"] > last_agg_id:
                    last_agg_id = t["a"]
                    last_time_ms = t["T"]

        if last_agg_id > 0:
            update_meta(conn, symbol, last_agg_id, last_time_ms)

        conn.commit()
        return inserted
    except Exception as e:
        logger.error(f"flush_buffer [{symbol}] error: {e}")
        return 0
    finally:
        if conn is not None:
            conn.close()


# ─── REST gap-fill ───────────────────────────────────────────────

async def rest_catchup(symbol: str, from_id: int) -> int:
    """Pull aggTrades via REST starting at *from_id* until caught up.

    Returns the number of rows backfilled.  Stops on HTTP errors, on an empty
    page, or when a partial (<1000) page indicates we've reached the tip.
    """
    total = 0
    current_id = from_id
    logger.info(f"[{symbol}] REST catchup from agg_id={from_id}")

    async with httpx.AsyncClient(timeout=15, headers=HEADERS) as client:
        while True:
            try:
                resp = await client.get(
                    f"{BINANCE_FAPI}/aggTrades",
                    params={"symbol": symbol, "fromId": current_id, "limit": 1000},
                )
                if resp.status_code != 200:
                    logger.warning(f"[{symbol}] REST catchup HTTP {resp.status_code}")
                    break
                data = resp.json()
                if not data:
                    break
                count = flush_buffer(symbol, data)
                total += count
                last = data[-1]["a"]
                if last <= current_id:
                    # Defensive: no forward progress, avoid spinning forever.
                    break
                current_id = last + 1
                # A page shorter than the limit means we've caught up.
                if len(data) < 1000:
                    break
                await asyncio.sleep(0.1)  # be gentle with the rate limit
            except Exception as e:
                logger.error(f"[{symbol}] REST catchup error: {e}")
                break

    logger.info(f"[{symbol}] REST catchup done, filled {total} trades")
    return total


# ─── WebSocket collection ────────────────────────────────────────

async def ws_collect(symbol: str):
    """Per-symbol WS collection loop: auto-reconnect + REST gap-fill.

    Flushes the in-memory buffer every BATCH_SIZE trades or BATCH_TIMEOUT
    seconds; on any disconnect the residue is flushed before backing off.
    """
    stream = symbol.lower() + "@aggTrade"
    url = f"wss://fstream.binance.com/ws/{stream}"
    buffer: list = []
    last_flush = time.time()
    reconnect_delay = 1.0

    while True:
        try:
            # Before (re)connecting, backfill anything missed while offline.
            last_id = get_last_agg_id(symbol)
            if last_id is not None:
                await rest_catchup(symbol, last_id + 1)

            logger.info(f"[{symbol}] Connecting WS: {url}")
            async with websockets.connect(url, ping_interval=20, ping_timeout=10) as ws:
                reconnect_delay = 1.0  # reset backoff once connected
                logger.info(f"[{symbol}] WS connected")

                async for raw in ws:
                    msg = json.loads(raw)
                    if msg.get("e") != "aggTrade":
                        continue

                    buffer.append(msg)

                    # Batched flush: 200 trades or 1 second, whichever first.
                    now = time.time()
                    if len(buffer) >= BATCH_SIZE or (now - last_flush) >= BATCH_TIMEOUT:
                        count = flush_buffer(symbol, buffer)
                        if count > 0:
                            logger.debug(f"[{symbol}] flushed {count}/{len(buffer)} trades")
                        buffer.clear()
                        last_flush = now

        except websockets.exceptions.ConnectionClosed as e:
            logger.warning(f"[{symbol}] WS closed: {e}, reconnecting in {reconnect_delay}s")
        except Exception as e:
            logger.error(f"[{symbol}] WS error: {e}, reconnecting in {reconnect_delay}s")
        finally:
            # Don't lose trades buffered before the disconnect.
            if buffer:
                flush_buffer(symbol, buffer)
                buffer.clear()

        await asyncio.sleep(reconnect_delay)
        reconnect_delay = min(reconnect_delay * 2, 30)  # exponential backoff, max 30s


# ─── Continuity check ────────────────────────────────────────────

async def continuity_check():
    """Every 60 s, verify recent agg_id continuity per symbol; backfill gaps."""
    while True:
        await asyncio.sleep(60)
        conn = None
        try:
            conn = get_conn()
            ensure_meta_table(conn)
            for symbol in SYMBOLS:
                row = conn.execute(
                    "SELECT last_agg_id, last_time_ms FROM agg_trades_meta WHERE symbol = ?",
                    (symbol,)
                ).fetchone()
                if not row:
                    continue
                # Check the most recent 100 rows of the current month for gaps.
                now_month = datetime.now(tz=timezone.utc).strftime("%Y%m")
                tname = f"agg_trades_{now_month}"
                try:
                    rows = conn.execute(
                        f"SELECT agg_id FROM {tname} WHERE symbol = ? ORDER BY agg_id DESC LIMIT 100",
                        (symbol,)
                    ).fetchall()
                    if len(rows) < 2:
                        continue
                    ids = [r["agg_id"] for r in rows]
                    # ids are descending; a step > 1 is a gap.
                    gaps = []
                    for i in range(len(ids) - 1):
                        diff = ids[i] - ids[i + 1]
                        if diff > 1:
                            gaps.append((ids[i + 1], ids[i], diff - 1))
                    if gaps:
                        logger.warning(f"[{symbol}] Found {len(gaps)} gaps in recent data: {gaps[:3]}")
                        # Backfill from the lowest gap start.  Keep a strong
                        # reference so the task isn't garbage-collected.
                        min_gap_id = min(g[0] for g in gaps)
                        task = asyncio.create_task(rest_catchup(symbol, min_gap_id))
                        _bg_tasks.add(task)
                        task.add_done_callback(_bg_tasks.discard)
                    else:
                        logger.debug(f"[{symbol}] Continuity OK, last_agg_id={row['last_agg_id']}")
                except sqlite3.Error:
                    pass  # monthly table may not exist yet
        except Exception as e:
            logger.error(f"Continuity check error: {e}")
        finally:
            if conn is not None:
                conn.close()


# ─── Hourly integrity report ─────────────────────────────────────

async def daily_report():
    """Log an integrity summary (per-symbol high-water mark + month count) hourly."""
    while True:
        await asyncio.sleep(3600)
        conn = None
        try:
            conn = get_conn()
            ensure_meta_table(conn)
            report_lines = ["=== AggTrades Integrity Report ==="]
            for symbol in SYMBOLS:
                row = conn.execute(
                    "SELECT last_agg_id, last_time_ms FROM agg_trades_meta WHERE symbol = ?",
                    (symbol,)
                ).fetchone()
                if not row:
                    report_lines.append(f"  {symbol}: No data yet")
                    continue
                last_dt = datetime.fromtimestamp(row["last_time_ms"] / 1000, tz=timezone.utc).isoformat()
                # Count rows collected in the current month's partition.
                now_month = datetime.now(tz=timezone.utc).strftime("%Y%m")
                tname = f"agg_trades_{now_month}"
                try:
                    count = conn.execute(
                        f"SELECT COUNT(*) as c FROM {tname} WHERE symbol = ?", (symbol,)
                    ).fetchone()["c"]
                except sqlite3.Error:
                    count = 0  # partition not created yet
                report_lines.append(
                    f"  {symbol}: last_agg_id={row['last_agg_id']}, last_time={last_dt}, month_count={count:,}"
                )
            logger.info("\n".join(report_lines))
        except Exception as e:
            logger.error(f"Daily report error: {e}")
        finally:
            if conn is not None:
                conn.close()
# ─── Entry point ─────────────────────────────────────────────────

async def main():
    """Bootstrap the meta table, then run all collector tasks concurrently."""
    logger.info("AggTrades Collector starting...")
    # Make sure the base table exists before any task touches it.
    conn = get_conn()
    try:
        ensure_meta_table(conn)
        conn.commit()
    finally:
        conn.close()

    # One WS loop per symbol, plus the continuity check and hourly report.
    tasks = [ws_collect(sym) for sym in SYMBOLS] + [
        continuity_check(),
        daily_report(),
    ]
    await asyncio.gather(*tasks)


if __name__ == "__main__":
    asyncio.run(main())


# ═════════════════════════════════════════════════════════════════
# backend/main.py — aggTrades query API additions
# ═════════════════════════════════════════════════════════════════

import datetime as _dt


# ─── aggTrades query endpoints ───────────────────────────────────

@app.get("/api/trades/meta")
async def get_trades_meta(user: dict = Depends(get_current_user)):
    """Collector progress: latest agg_id and timestamp per symbol."""
    conn = sqlite3.connect(DB_PATH)
    conn.row_factory = sqlite3.Row
    try:
        rows = conn.execute(
            "SELECT symbol, last_agg_id, last_time_ms, updated_at FROM agg_trades_meta"
        ).fetchall()
    except sqlite3.Error:
        rows = []  # meta table absent: collector has never run
    finally:
        conn.close()
    result = {}
    for r in rows:
        sym = r["symbol"].replace("USDT", "")
        result[sym] = {
            "last_agg_id": r["last_agg_id"],
            "last_time_ms": r["last_time_ms"],
            "updated_at": r["updated_at"],
        }
    return result


@app.get("/api/trades/summary")
async def get_trades_summary(
    symbol: str = "BTC",
    start_ms: int = 0,
    end_ms: int = 0,
    interval: str = "1m",
    user: dict = Depends(get_current_user),
):
    """Minute-level aggregation: buy/sell delta, trade rate, vwap.

    Defaults to the last hour when start_ms/end_ms are 0.  Reads every
    monthly partition that overlaps [start_ms, end_ms).
    """
    if end_ms == 0:
        end_ms = int(time.time() * 1000)
    if start_ms == 0:
        start_ms = end_ms - 3600 * 1000  # default window: 1 hour

    interval_ms = {"1m": 60000, "5m": 300000, "15m": 900000, "1h": 3600000}.get(interval, 60000)
    sym_full = symbol.upper() + "USDT"

    # Determine which monthly partitions the window spans (UTC).
    start_dt = _dt.datetime.fromtimestamp(start_ms / 1000, tz=_dt.timezone.utc)
    end_dt = _dt.datetime.fromtimestamp(end_ms / 1000, tz=_dt.timezone.utc)
    months = set()
    cur = start_dt.replace(day=1)
    while cur <= end_dt:
        months.add(cur.strftime("%Y%m"))
        if cur.month == 12:
            cur = cur.replace(year=cur.year + 1, month=1)
        else:
            cur = cur.replace(month=cur.month + 1)

    conn = sqlite3.connect(DB_PATH)
    conn.row_factory = sqlite3.Row
    all_rows = []
    try:
        for month in sorted(months):
            tname = f"agg_trades_{month}"
            try:
                rows = conn.execute(
                    f"SELECT agg_id, price, qty, time_ms, is_buyer_maker FROM {tname} "
                    f"WHERE symbol = ? AND time_ms >= ? AND time_ms < ? ORDER BY time_ms ASC",
                    (sym_full, start_ms, end_ms)
                ).fetchall()
                all_rows.extend(rows)
            except sqlite3.Error:
                pass  # partition for that month doesn't exist
    finally:
        conn.close()

    # Bucket trades into interval-aligned bars.
    bars: dict = {}
    for row in all_rows:
        bar_ms = (row["time_ms"] // interval_ms) * interval_ms
        if bar_ms not in bars:
            bars[bar_ms] = {"time_ms": bar_ms, "buy_vol": 0.0, "sell_vol": 0.0,
                            "trade_count": 0, "vwap_num": 0.0, "vwap_den": 0.0, "max_qty": 0.0}
        b = bars[bar_ms]
        qty = float(row["qty"])
        price = float(row["price"])
        if row["is_buyer_maker"] == 0:  # taker bought (aggressive buy)
            b["buy_vol"] += qty
        else:
            b["sell_vol"] += qty
        b["trade_count"] += 1
        b["vwap_num"] += price * qty
        b["vwap_den"] += qty
        b["max_qty"] = max(b["max_qty"], qty)

    result = []
    for b in sorted(bars.values(), key=lambda x: x["time_ms"]):
        total = b["buy_vol"] + b["sell_vol"]
        result.append({
            "time_ms": b["time_ms"],
            "buy_vol": round(b["buy_vol"], 4),
            "sell_vol": round(b["sell_vol"], 4),
            "delta": round(b["buy_vol"] - b["sell_vol"], 4),
            "total_vol": round(total, 4),
            "trade_count": b["trade_count"],
            "vwap": round(b["vwap_num"] / b["vwap_den"], 2) if b["vwap_den"] > 0 else 0,
            "max_qty": round(b["max_qty"], 4),
        })

    return {"symbol": symbol, "interval": interval, "count": len(result), "data": result}


@app.get("/api/collector/health")
async def collector_health(user: dict = Depends(get_current_user)):
    """Collector health: per-symbol ingest lag; healthy = data within 30 s."""
    conn = sqlite3.connect(DB_PATH)
    conn.row_factory = sqlite3.Row
    now_ms = int(time.time() * 1000)
    status = {}
    try:
        rows = conn.execute(
            "SELECT symbol, last_agg_id, last_time_ms FROM agg_trades_meta"
        ).fetchall()
        for r in rows:
            sym = r["symbol"].replace("USDT", "")
            lag_s = (now_ms - r["last_time_ms"]) / 1000
            status[sym] = {
                "last_agg_id": r["last_agg_id"],
                "lag_seconds": round(lag_s, 1),
                "healthy": lag_s < 30,  # fresh data within 30 s counts as healthy
            }
    except sqlite3.Error:
        pass  # meta table absent: report empty status
    finally:
        conn.close()
    return {"collector": status, "timestamp": now_ms}