""" agg_trades_collector.py — aggTrades全量采集守护进程 架构: - WebSocket主链路:实时推送,延迟<100ms - REST补洞:断线重连后从last_agg_id追平 - 每分钟巡检:校验agg_id连续性,发现断档自动补洞 - 批量写入:攒200条或1秒flush一次,减少WAL压力 - 按月分表:agg_trades_YYYYMM,单表千万行内查询快 - 健康接口:GET /collector/health 可监控 """ import asyncio import json import logging import os import sqlite3 import time from datetime import datetime, timezone from typing import Optional import httpx import websockets logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", handlers=[ logging.StreamHandler(), logging.FileHandler(os.path.join(os.path.dirname(__file__), "..", "collector.log")), ], ) logger = logging.getLogger("collector") DB_PATH = os.path.join(os.path.dirname(__file__), "..", "arb.db") BINANCE_FAPI = "https://fapi.binance.com/fapi/v1" SYMBOLS = ["BTCUSDT", "ETHUSDT"] HEADERS = {"User-Agent": "Mozilla/5.0 ArbitrageEngine/3.0"} # 批量写入缓冲 _buffer: dict[str, list] = {s: [] for s in SYMBOLS} BATCH_SIZE = 200 BATCH_TIMEOUT = 1.0 # seconds # ─── DB helpers ────────────────────────────────────────────────── def get_conn() -> sqlite3.Connection: conn = sqlite3.connect(DB_PATH, timeout=30) conn.row_factory = sqlite3.Row conn.execute("PRAGMA journal_mode=WAL") conn.execute("PRAGMA synchronous=NORMAL") return conn def table_name(ts_ms: int) -> str: """按月分表:agg_trades_202602""" dt = datetime.fromtimestamp(ts_ms / 1000, tz=timezone.utc) return f"agg_trades_{dt.strftime('%Y%m')}" def ensure_table(conn: sqlite3.Connection, tname: str): conn.execute(f""" CREATE TABLE IF NOT EXISTS {tname} ( agg_id INTEGER PRIMARY KEY, symbol TEXT NOT NULL, price REAL NOT NULL, qty REAL NOT NULL, first_trade_id INTEGER, last_trade_id INTEGER, time_ms INTEGER NOT NULL, is_buyer_maker INTEGER NOT NULL ) """) conn.execute(f""" CREATE INDEX IF NOT EXISTS idx_{tname}_sym_time ON {tname}(symbol, time_ms) """) def ensure_meta_table(conn: sqlite3.Connection): conn.execute(""" CREATE TABLE IF NOT EXISTS agg_trades_meta ( symbol TEXT PRIMARY KEY, 
last_agg_id INTEGER NOT NULL, last_time_ms INTEGER NOT NULL, updated_at TEXT DEFAULT (datetime('now')) ) """) def get_last_agg_id(symbol: str) -> Optional[int]: try: conn = get_conn() ensure_meta_table(conn) row = conn.execute( "SELECT last_agg_id FROM agg_trades_meta WHERE symbol = ?", (symbol,) ).fetchone() conn.close() return row["last_agg_id"] if row else None except Exception as e: logger.error(f"get_last_agg_id error: {e}") return None def update_meta(conn: sqlite3.Connection, symbol: str, last_agg_id: int, last_time_ms: int): conn.execute(""" INSERT INTO agg_trades_meta (symbol, last_agg_id, last_time_ms, updated_at) VALUES (?, ?, ?, datetime('now')) ON CONFLICT(symbol) DO UPDATE SET last_agg_id = excluded.last_agg_id, last_time_ms = excluded.last_time_ms, updated_at = excluded.updated_at """, (symbol, last_agg_id, last_time_ms)) def flush_buffer(symbol: str, trades: list) -> int: """写入一批trades,返回实际写入条数(去重后)""" if not trades: return 0 try: conn = get_conn() ensure_meta_table(conn) # 按月分组 by_month: dict[str, list] = {} for t in trades: tname = table_name(t["T"]) if tname not in by_month: by_month[tname] = [] by_month[tname].append(t) inserted = 0 last_agg_id = 0 last_time_ms = 0 for tname, batch in by_month.items(): ensure_table(conn, tname) for t in batch: cur = conn.execute( f"""INSERT OR IGNORE INTO {tname} (agg_id, symbol, price, qty, first_trade_id, last_trade_id, time_ms, is_buyer_maker) VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", ( t["a"], # agg_id symbol, float(t["p"]), float(t["q"]), t.get("f"), # first_trade_id t.get("l"), # last_trade_id t["T"], # time_ms 1 if t["m"] else 0, # is_buyer_maker ) ) if cur.rowcount > 0: inserted += 1 if t["a"] > last_agg_id: last_agg_id = t["a"] last_time_ms = t["T"] if last_agg_id > 0: update_meta(conn, symbol, last_agg_id, last_time_ms) conn.commit() conn.close() return inserted except Exception as e: logger.error(f"flush_buffer [{symbol}] error: {e}") return 0 # ─── REST补洞 
async def rest_catchup(symbol: str, from_id: int) -> int:
    """Backfill trades via REST starting at from_id (inclusive).

    Pages through /aggTrades in chunks of 1000 until caught up to live,
    returning the number of trades written. On HTTP 429/418 it honors
    Retry-After and retries instead of abandoning the backfill.
    """
    total = 0
    current_id = from_id
    logger.info(f"[{symbol}] REST catchup from agg_id={from_id}")
    async with httpx.AsyncClient(timeout=15, headers=HEADERS) as client:
        while True:
            try:
                resp = await client.get(
                    f"{BINANCE_FAPI}/aggTrades",
                    params={"symbol": symbol, "fromId": current_id, "limit": 1000},
                )
                if resp.status_code in (418, 429):
                    # Rate-limited / temporarily banned: back off and retry.
                    wait = float(resp.headers.get("Retry-After", 5))
                    logger.warning(f"[{symbol}] REST catchup HTTP {resp.status_code}")
                    await asyncio.sleep(wait)
                    continue
                if resp.status_code != 200:
                    logger.warning(f"[{symbol}] REST catchup HTTP {resp.status_code}")
                    break
                data = resp.json()
                if not data:
                    break
                total += flush_buffer(symbol, data)
                last = data[-1]["a"]
                if last <= current_id:
                    break
                current_id = last + 1
                # A short page (< limit) means we've caught up to live.
                if len(data) < 1000:
                    break
                await asyncio.sleep(0.1)  # rate-limit friendly pacing
            except Exception as e:
                logger.error(f"[{symbol}] REST catchup error: {e}")
                break
    logger.info(f"[{symbol}] REST catchup done, filled {total} trades")
    return total


# ─── WebSocket collection ────────────────────────────────────────

async def ws_collect(symbol: str):
    """Per-symbol WS collection loop with auto-reconnect and REST gap-fill.

    NOTE(review): flush_buffer performs blocking sqlite I/O on the event
    loop; batches are small so stalls should be brief, but consider
    asyncio.to_thread if flush latency grows — confirm with profiling.
    """
    stream = symbol.lower() + "@aggTrade"
    url = f"wss://fstream.binance.com/ws/{stream}"
    buffer: list = []
    last_flush = time.time()
    reconnect_delay = 1.0
    while True:
        try:
            # Before (re)connecting, REST-fill anything missed while offline.
            last_id = get_last_agg_id(symbol)
            if last_id is not None:
                await rest_catchup(symbol, last_id + 1)

            logger.info(f"[{symbol}] Connecting WS: {url}")
            async with websockets.connect(url, ping_interval=20, ping_timeout=10) as ws:
                reconnect_delay = 1.0  # reset backoff once connected
                logger.info(f"[{symbol}] WS connected")
                async for raw in ws:
                    msg = json.loads(raw)
                    if msg.get("e") != "aggTrade":
                        continue
                    buffer.append(msg)
                    # Batched flush: 200 rows or 1 second, whichever first.
                    now = time.time()
                    if len(buffer) >= BATCH_SIZE or (now - last_flush) >= BATCH_TIMEOUT:
                        count = flush_buffer(symbol, buffer)
                        if count > 0:
                            logger.debug(f"[{symbol}] flushed {count}/{len(buffer)} trades")
                        buffer.clear()
                        last_flush = now
        except websockets.exceptions.ConnectionClosed as e:
            logger.warning(f"[{symbol}] WS closed: {e}, reconnecting in {reconnect_delay}s")
        except Exception as e:
            logger.error(f"[{symbol}] WS error: {e}, reconnecting in {reconnect_delay}s")
        finally:
            # Flush whatever is left before backing off for reconnect.
            if buffer:
                flush_buffer(symbol, buffer)
                buffer.clear()
            await asyncio.sleep(reconnect_delay)
            reconnect_delay = min(reconnect_delay * 2, 30)  # capped exponential backoff


# ─── Continuity patrol ───────────────────────────────────────────

# Strong references to fire-and-forget backfill tasks: the event loop only
# keeps weak references to tasks, so an unreferenced create_task() result
# can be garbage-collected before it finishes.
_catchup_tasks: set = set()


async def continuity_check():
    """Every 60s, scan each symbol's recent rows for agg_id gaps and backfill.

    Only the current month's shard is inspected; gaps trigger an async
    rest_catchup starting from the lowest gap boundary.
    """
    while True:
        await asyncio.sleep(60)
        try:
            conn = get_conn()
            try:
                ensure_meta_table(conn)
                for symbol in SYMBOLS:
                    row = conn.execute(
                        "SELECT last_agg_id, last_time_ms FROM agg_trades_meta WHERE symbol = ?",
                        (symbol,)
                    ).fetchone()
                    if not row:
                        continue
                    # Check the most recent 100 rows for contiguity.
                    now_month = datetime.now(tz=timezone.utc).strftime("%Y%m")
                    tname = f"agg_trades_{now_month}"
                    try:
                        rows = conn.execute(
                            f"SELECT agg_id FROM {tname} WHERE symbol = ? ORDER BY agg_id DESC LIMIT 100",
                            (symbol,)
                        ).fetchall()
                    except sqlite3.OperationalError:
                        # Monthly shard not created yet (e.g. right after rollover).
                        continue
                    if len(rows) < 2:
                        continue
                    ids = [r["agg_id"] for r in rows]
                    # ids is descending; a step > 1 between neighbors is a gap,
                    # recorded as (low_id, high_id, missing_count).
                    gaps = [
                        (lo, hi, hi - lo - 1)
                        for hi, lo in zip(ids, ids[1:])
                        if hi - lo > 1
                    ]
                    if gaps:
                        logger.warning(f"[{symbol}] Found {len(gaps)} gaps in recent data: {gaps[:3]}")
                        # Backfill from the lowest gap boundary; hold a strong
                        # reference so the task cannot be GC'd mid-flight.
                        min_gap_id = min(g[0] for g in gaps)
                        task = asyncio.create_task(rest_catchup(symbol, min_gap_id))
                        _catchup_tasks.add(task)
                        task.add_done_callback(_catchup_tasks.discard)
                    else:
                        logger.debug(f"[{symbol}] Continuity OK, last_agg_id={row['last_agg_id']}")
            finally:
                conn.close()
        except Exception as e:
            logger.error(f"Continuity check error: {e}")


# ─── Periodic integrity report ───────────────────────────────────

async def daily_report():
    """Log an integrity summary once per hour.

    NOTE(review): the name says "daily" but the interval is 3600s; the name
    is kept since callers (main) reference it.
    """
    while True:
        await asyncio.sleep(3600)
        try:
            conn = get_conn()
            try:
                ensure_meta_table(conn)
                report_lines = ["=== AggTrades Integrity Report ==="]
                for symbol in SYMBOLS:
                    row = conn.execute(
                        "SELECT last_agg_id, last_time_ms FROM agg_trades_meta WHERE symbol = ?",
                        (symbol,)
                    ).fetchone()
                    if not row:
                        report_lines.append(f"  {symbol}: No data yet")
                        continue
                    last_dt = datetime.fromtimestamp(
                        row["last_time_ms"] / 1000, tz=timezone.utc
                    ).isoformat()
                    # Row count for this symbol in the current month's shard.
                    now_month = datetime.now(tz=timezone.utc).strftime("%Y%m")
                    tname = f"agg_trades_{now_month}"
                    try:
                        count = conn.execute(
                            f"SELECT COUNT(*) as c FROM {tname} WHERE symbol = ?", (symbol,)
                        ).fetchone()["c"]
                    except sqlite3.OperationalError:
                        count = 0  # shard not created yet
                    report_lines.append(
                        f"  {symbol}: last_agg_id={row['last_agg_id']}, last_time={last_dt}, month_count={count:,}"
                    )
            finally:
                conn.close()
            logger.info("\n".join(report_lines))
        except Exception as e:
            logger.error(f"Daily report error: {e}")


# ─── Entry point ─────────────────────────────────────────────────

async def main():
    """Start one WS collector per symbol plus the patrol and report loops."""
    logger.info("AggTrades Collector starting...")

    # Make sure the base meta table exists before any task touches it.
    conn = get_conn()
    try:
        ensure_meta_table(conn)
        conn.commit()
    finally:
        conn.close()

    # Run all symbol collectors, the continuity patrol, and the report
    # loop concurrently; gather never returns in normal operation.
    tasks = [
        ws_collect(sym) for sym in SYMBOLS
    ] + [
        continuity_check(),
        daily_report(),
    ]
    await asyncio.gather(*tasks)


if __name__ == "__main__":
    asyncio.run(main())