""" agg_trades_collector.py — aggTrades全量采集守护进程(PostgreSQL版) 架构: - WebSocket主链路:实时推送,延迟<100ms - REST补洞:断线重连后从last_agg_id追平 - 每分钟巡检:校验agg_id连续性,发现断档自动补洞 - 批量写入:攒200条或1秒flush一次 - PG分区表:按月自动分区,MVCC并发无锁冲突 """ import asyncio import json import logging import os import time from datetime import datetime, timezone from typing import Optional import httpx import psycopg2 import psycopg2.extras import websockets from db import get_sync_conn, get_sync_pool, ensure_partitions, PG_HOST, PG_PORT, PG_DB, PG_USER, PG_PASS logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", handlers=[ logging.StreamHandler(), logging.FileHandler(os.path.join(os.path.dirname(__file__), "..", "collector.log")), ], ) logger = logging.getLogger("collector") BINANCE_FAPI = "https://fapi.binance.com/fapi/v1" SYMBOLS = ["BTCUSDT", "ETHUSDT"] HEADERS = {"User-Agent": "Mozilla/5.0 ArbitrageEngine/3.0"} BATCH_SIZE = 200 BATCH_TIMEOUT = 1.0 # ─── DB helpers ────────────────────────────────────────────────── def get_last_agg_id(symbol: str) -> Optional[int]: try: with get_sync_conn() as conn: with conn.cursor() as cur: cur.execute("SELECT last_agg_id FROM agg_trades_meta WHERE symbol = %s", (symbol,)) row = cur.fetchone() return row[0] if row else None except Exception as e: logger.error(f"get_last_agg_id error: {e}") return None def update_meta(conn, symbol: str, last_agg_id: int, last_time_ms: int): with conn.cursor() as cur: cur.execute(""" INSERT INTO agg_trades_meta (symbol, last_agg_id, last_time_ms, updated_at) VALUES (%s, %s, %s, NOW()) ON CONFLICT(symbol) DO UPDATE SET last_agg_id = EXCLUDED.last_agg_id, last_time_ms = EXCLUDED.last_time_ms, updated_at = NOW() """, (symbol, last_agg_id, last_time_ms)) def flush_buffer(symbol: str, trades: list) -> int: """写入一批trades到PG,返回实际写入条数""" if not trades: return 0 try: # 确保分区存在 ensure_partitions() with get_sync_conn() as conn: with conn.cursor() as cur: # 批量插入(ON CONFLICT忽略重复) values = [] last_agg_id = 0 last_time_ms = 0 for t in trades: agg_id = t["a"] time_ms = t["T"] values.append(( agg_id, symbol, float(t["p"]), float(t["q"]), time_ms, 1 if t["m"] else 0, )) if agg_id > last_agg_id: last_agg_id = agg_id last_time_ms = time_ms # 批量INSERT psycopg2.extras.execute_values( cur, """INSERT INTO agg_trades (agg_id, symbol, price, qty, time_ms, is_buyer_maker) VALUES %s ON CONFLICT (time_ms, symbol, agg_id) DO NOTHING""", values, template="(%s, %s, %s, %s, %s, %s)", page_size=1000, ) inserted = cur.rowcount if last_agg_id > 0: update_meta(conn, symbol, last_agg_id, last_time_ms) conn.commit() return inserted except Exception as e: logger.error(f"flush_buffer [{symbol}] error: {e}") return 0 # ─── REST补洞 ──────────────────────────────────────────────────── async def rest_catchup(symbol: str, from_id: int) -> int: total = 0 current_id = from_id logger.info(f"[{symbol}] REST catchup from agg_id={from_id}") async with httpx.AsyncClient(timeout=15, headers=HEADERS) as client: while True: try: resp = await client.get( f"{BINANCE_FAPI}/aggTrades", params={"symbol": symbol, "fromId": current_id, "limit": 1000}, ) if resp.status_code != 200: logger.warning(f"[{symbol}] REST catchup HTTP {resp.status_code}") break data = resp.json() if not data: break count = flush_buffer(symbol, data) total += count last = data[-1]["a"] if last <= current_id: break current_id = last + 1 if len(data) < 1000: break await asyncio.sleep(0.1) except Exception as e: logger.error(f"[{symbol}] REST catchup error: {e}") break logger.info(f"[{symbol}] REST catchup done, filled {total} trades") return total # ─── WebSocket采集 ─────────────────────────────────────────────── async def ws_collect(symbol: str): stream = symbol.lower() + "@aggTrade" url = f"wss://fstream.binance.com/ws/{stream}" buffer: list = [] last_flush = time.time() reconnect_delay = 1.0 while True: try: last_id = get_last_agg_id(symbol) if last_id is not None: await rest_catchup(symbol, last_id + 1) logger.info(f"[{symbol}] Connecting WS: {url}") async with websockets.connect(url, ping_interval=20, ping_timeout=10) as ws: reconnect_delay = 1.0 logger.info(f"[{symbol}] WS connected") async for raw in ws: msg = json.loads(raw) if msg.get("e") != "aggTrade": continue buffer.append(msg) now = time.time() if len(buffer) >= BATCH_SIZE or (now - last_flush) >= BATCH_TIMEOUT: count = flush_buffer(symbol, buffer) if count > 0: logger.debug(f"[{symbol}] flushed {count}/{len(buffer)} trades") buffer.clear() last_flush = now except websockets.exceptions.ConnectionClosed as e: logger.warning(f"[{symbol}] WS closed: {e}, reconnecting in {reconnect_delay}s") except Exception as e: logger.error(f"[{symbol}] WS error: {e}, reconnecting in {reconnect_delay}s") finally: if buffer: flush_buffer(symbol, buffer) buffer.clear() await asyncio.sleep(reconnect_delay) reconnect_delay = min(reconnect_delay * 2, 30) # ─── 连续性巡检 ────────────────────────────────────────────────── async def continuity_check(): while True: await asyncio.sleep(60) try: with get_sync_conn() as conn: with conn.cursor() as cur: for symbol in SYMBOLS: cur.execute( "SELECT last_agg_id, last_time_ms FROM agg_trades_meta WHERE symbol = %s", (symbol,) ) row = cur.fetchone() if not row: continue # 检查最近100条是否连续 cur.execute( "SELECT agg_id FROM agg_trades WHERE symbol = %s ORDER BY agg_id DESC LIMIT 100", (symbol,) ) rows = cur.fetchall() if len(rows) < 2: continue ids = [r[0] for r in rows] gaps = [] for i in range(len(ids) - 1): diff = ids[i] - ids[i + 1] if diff > 1: gaps.append((ids[i + 1], ids[i], diff - 1)) if gaps: logger.warning(f"[{symbol}] Found {len(gaps)} gaps: {gaps[:3]}") min_gap_id = min(g[0] for g in gaps) asyncio.create_task(rest_catchup(symbol, min_gap_id)) else: logger.debug(f"[{symbol}] Continuity OK, last_agg_id={row[0]}") except Exception as e: logger.error(f"Continuity check error: {e}") # ─── 每小时报告 ────────────────────────────────────────────────── async def daily_report(): while True: await asyncio.sleep(3600) try: with get_sync_conn() as conn: with conn.cursor() as cur: report = ["=== AggTrades Integrity Report ==="] for symbol in SYMBOLS: cur.execute( "SELECT last_agg_id, last_time_ms FROM agg_trades_meta WHERE symbol = %s", (symbol,) ) row = cur.fetchone() if not row: report.append(f" {symbol}: No data yet") continue last_dt = datetime.fromtimestamp(row[1] / 1000, tz=timezone.utc).isoformat() cur.execute( "SELECT COUNT(*) FROM agg_trades WHERE symbol = %s", (symbol,) ) count = cur.fetchone()[0] report.append(f" {symbol}: last_agg_id={row[0]}, last_time={last_dt}, total={count:,}") logger.info("\n".join(report)) except Exception as e: logger.error(f"Report error: {e}") # ─── 入口 ──────────────────────────────────────────────────────── async def main(): logger.info("AggTrades Collector (PG) starting...") # 确保分区存在 ensure_partitions() tasks = [ ws_collect(sym) for sym in SYMBOLS ] + [ continuity_check(), daily_report(), ] await asyncio.gather(*tasks) if __name__ == "__main__": asyncio.run(main())