"""
agg_trades_collector.py — aggTrades全量采集守护进程(PostgreSQL版)

架构:
- WebSocket主链路:实时推送,延迟<100ms
- REST补洞:断线重连后从last_agg_id追平
- 每分钟巡检:校验agg_id连续性,发现断档自动补洞
- 批量写入:攒200条或1秒flush一次
- PG分区表:按月自动分区,MVCC并发无锁冲突
- 统一写入 Cloud SQL(双写机制已移除)
"""
|
||
|
||
import asyncio
|
||
import json
|
||
import logging
|
||
import os
|
||
import time
|
||
from datetime import datetime, timezone
|
||
from typing import Optional
|
||
|
||
import httpx
|
||
import psycopg2
|
||
import psycopg2.extras
|
||
import websockets
|
||
|
||
from db import get_sync_conn, get_sync_pool, ensure_partitions, PG_HOST
|
||
|
||
# Log to both stderr and a rotatingly-appended file one directory above this module.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler(os.path.join(os.path.dirname(__file__), "..", "collector.log")),
    ],
)
logger = logging.getLogger("collector")

# Binance USDT-M futures REST base URL.
BINANCE_FAPI = "https://fapi.binance.com/fapi/v1"
# Symbols this daemon collects; one WebSocket task is spawned per symbol.
SYMBOLS = ["BTCUSDT", "ETHUSDT", "XRPUSDT", "SOLUSDT"]
HEADERS = {"User-Agent": "Mozilla/5.0 ArbitrageEngine/3.0"}

# Flush the write buffer when it reaches BATCH_SIZE trades or when
# BATCH_TIMEOUT seconds have elapsed since the last flush, whichever first.
BATCH_SIZE = 200
BATCH_TIMEOUT = 1.0


# ─── DB helpers ──────────────────────────────────────────────────
|
||
|
||
def get_last_agg_id(symbol: str) -> Optional[int]:
    """Return the last persisted agg_id for *symbol*, or None if unknown.

    Any DB error is logged and reported as None, so callers treat it the
    same as "no watermark yet".
    """
    try:
        with get_sync_conn() as conn:
            with conn.cursor() as cur:
                cur.execute(
                    "SELECT last_agg_id FROM agg_trades_meta WHERE symbol = %s",
                    (symbol,),
                )
                row = cur.fetchone()
    except Exception as exc:
        logger.error(f"get_last_agg_id error: {exc}")
        return None
    return row[0] if row else None
|
||
|
||
|
||
def update_meta(conn, symbol: str, last_agg_id: int, last_time_ms: int):
    """Upsert the per-symbol watermark row in agg_trades_meta.

    Does not commit; the caller owns the transaction on *conn*.
    """
    upsert_sql = """
            INSERT INTO agg_trades_meta (symbol, last_agg_id, last_time_ms, updated_at)
            VALUES (%s, %s, %s, NOW())
            ON CONFLICT(symbol) DO UPDATE SET
                last_agg_id = EXCLUDED.last_agg_id,
                last_time_ms = EXCLUDED.last_time_ms,
                updated_at = NOW()
        """
    with conn.cursor() as cur:
        cur.execute(upsert_sql, (symbol, last_agg_id, last_time_ms))
|
||
|
||
|
||
def flush_buffer(symbol: str, trades: list) -> int:
    """Write a batch of aggTrade dicts to Cloud SQL; return rows actually inserted.

    Duplicates are skipped via ON CONFLICT DO NOTHING, and the per-symbol
    watermark in agg_trades_meta is advanced to the highest agg_id in the
    batch. Returns 0 on empty input or on any error (logged, not raised),
    so callers can treat a failed flush as "nothing written".
    """
    if not trades:
        return 0
    try:
        ensure_partitions()

        values = []
        last_agg_id = 0
        last_time_ms = 0

        for t in trades:
            agg_id = t["a"]
            time_ms = t["T"]
            values.append((
                agg_id, symbol,
                float(t["p"]), float(t["q"]),
                time_ms,
                1 if t["m"] else 0,
            ))
            # Track the max agg_id explicitly — input batches are not
            # guaranteed to be sorted.
            if agg_id > last_agg_id:
                last_agg_id = agg_id
                last_time_ms = time_ms

        # RETURNING 1 + fetch=True yields one row per actually-inserted trade
        # across ALL pages. The previous cur.rowcount approach only counted
        # the last page when execute_values split the batch (page_size=1000),
        # under-reporting for larger batches.
        insert_sql = """INSERT INTO agg_trades (agg_id, symbol, price, qty, time_ms, is_buyer_maker)
                        VALUES %s
                        ON CONFLICT (time_ms, symbol, agg_id) DO NOTHING
                        RETURNING 1"""
        insert_template = "(%s, %s, %s, %s, %s, %s)"

        inserted = 0
        with get_sync_conn() as conn:
            with conn.cursor() as cur:
                returned = psycopg2.extras.execute_values(
                    cur, insert_sql, values,
                    template=insert_template, page_size=1000,
                    fetch=True,
                )
                inserted = len(returned)
                if last_agg_id > 0:
                    update_meta(conn, symbol, last_agg_id, last_time_ms)
            conn.commit()

        return inserted
    except Exception as e:
        logger.error(f"flush_buffer [{symbol}] error: {e}")
        return 0
|
||
|
||
|
||
# ─── REST补洞 ────────────────────────────────────────────────────
|
||
|
||
async def rest_catchup(symbol: str, from_id: int) -> int:
    """Backfill aggTrades via REST starting at *from_id* (inclusive).

    Pages forward 1000 trades at a time until Binance returns a short or
    empty page, writing each page through flush_buffer. Returns the number
    of trades written (as reported by flush_buffer). Any HTTP or parse
    error ends the catchup early.
    """
    filled = 0
    cursor_id = from_id
    logger.info(f"[{symbol}] REST catchup from agg_id={from_id}")

    async with httpx.AsyncClient(timeout=15, headers=HEADERS) as client:
        while True:
            try:
                resp = await client.get(
                    f"{BINANCE_FAPI}/aggTrades",
                    params={"symbol": symbol, "fromId": cursor_id, "limit": 1000},
                )
                if resp.status_code != 200:
                    logger.warning(f"[{symbol}] REST catchup HTTP {resp.status_code}")
                    break
                batch = resp.json()
                if not batch:
                    break
                filled += flush_buffer(symbol, batch)
                newest = batch[-1]["a"]
                if newest <= cursor_id:
                    # No forward progress — avoid an infinite loop.
                    break
                cursor_id = newest + 1
                if len(batch) < 1000:
                    # Short page: we've caught up to the live stream.
                    break
                await asyncio.sleep(0.1)
            except Exception as e:
                logger.error(f"[{symbol}] REST catchup error: {e}")
                break

    logger.info(f"[{symbol}] REST catchup done, filled {filled} trades")
    return filled
|
||
|
||
|
||
# ─── WebSocket采集 ───────────────────────────────────────────────
|
||
|
||
async def ws_collect(symbol: str):
    """Stream aggTrades for *symbol* over WebSocket forever, with auto-reconnect.

    On each (re)connect, first backfills missed trades over REST from the
    stored watermark, then buffers live messages and flushes them in batches
    of BATCH_SIZE or every BATCH_TIMEOUT seconds, whichever comes first.
    Reconnects use exponential backoff (1s doubling, capped at 30s).
    """
    stream = symbol.lower() + "@aggTrade"
    url = f"wss://fstream.binance.com/ws/{stream}"
    buffer: list = []
    last_flush = time.time()
    reconnect_delay = 1.0  # doubled after each failure, capped at 30s below

    while True:
        try:
            # Close the gap since the last persisted agg_id before subscribing,
            # so trades missed while disconnected are backfilled via REST.
            last_id = get_last_agg_id(symbol)
            if last_id is not None:
                await rest_catchup(symbol, last_id + 1)

            logger.info(f"[{symbol}] Connecting WS: {url}")
            async with websockets.connect(url, ping_interval=20, ping_timeout=10) as ws:
                reconnect_delay = 1.0  # reset backoff on a successful connect
                logger.info(f"[{symbol}] WS connected")

                async for raw in ws:
                    msg = json.loads(raw)
                    if msg.get("e") != "aggTrade":
                        continue

                    buffer.append(msg)

                    # Flush on size OR age; the age check only fires when a
                    # message arrives, which is fine since an empty buffer has
                    # nothing to flush.
                    now = time.time()
                    if len(buffer) >= BATCH_SIZE or (now - last_flush) >= BATCH_TIMEOUT:
                        count = flush_buffer(symbol, buffer)
                        if count > 0:
                            logger.debug(f"[{symbol}] flushed {count}/{len(buffer)} trades")
                        buffer.clear()
                        last_flush = now

        except websockets.exceptions.ConnectionClosed as e:
            logger.warning(f"[{symbol}] WS closed: {e}, reconnecting in {reconnect_delay}s")
        except Exception as e:
            logger.error(f"[{symbol}] WS error: {e}, reconnecting in {reconnect_delay}s")
        finally:
            # Drain whatever is buffered before sleeping/reconnecting so a
            # dropped connection does not lose already-received trades.
            if buffer:
                flush_buffer(symbol, buffer)
                buffer.clear()

        await asyncio.sleep(reconnect_delay)
        reconnect_delay = min(reconnect_delay * 2, 30)
|
||
|
||
|
||
# ─── 连续性巡检 ──────────────────────────────────────────────────
|
||
|
||
async def continuity_check():
    """Every 60s, verify agg_id continuity for each symbol and backfill gaps.

    Inspects the most recent 100 stored agg_ids per symbol; any jump > 1
    between consecutive ids is a gap. When gaps exist, a REST catchup task
    is scheduled starting just above the lowest *present* id of the earliest
    gap. Errors are logged and the loop continues.
    """
    while True:
        await asyncio.sleep(60)
        try:
            with get_sync_conn() as conn:
                with conn.cursor() as cur:
                    for symbol in SYMBOLS:
                        cur.execute(
                            "SELECT last_agg_id, last_time_ms FROM agg_trades_meta WHERE symbol = %s",
                            (symbol,)
                        )
                        row = cur.fetchone()
                        if not row:
                            continue
                        # Check whether the most recent 100 ids are contiguous.
                        cur.execute(
                            "SELECT agg_id FROM agg_trades WHERE symbol = %s ORDER BY agg_id DESC LIMIT 100",
                            (symbol,)
                        )
                        rows = cur.fetchall()
                        if len(rows) < 2:
                            continue
                        ids = [r[0] for r in rows]
                        # ids are descending; gap tuple = (present lower id,
                        # present upper id, number of missing ids between them).
                        gaps = []
                        for i in range(len(ids) - 1):
                            diff = ids[i] - ids[i + 1]
                            if diff > 1:
                                gaps.append((ids[i + 1], ids[i], diff - 1))
                        if gaps:
                            logger.warning(f"[{symbol}] Found {len(gaps)} gaps: {gaps[:3]}")
                            # min_gap_id itself is already stored (it is the
                            # present lower bound of the gap), so start the
                            # catchup one past it instead of refetching it.
                            min_gap_id = min(g[0] for g in gaps)
                            asyncio.create_task(rest_catchup(symbol, min_gap_id + 1))
                        else:
                            logger.debug(f"[{symbol}] Continuity OK, last_agg_id={row[0]}")
        except Exception as e:
            logger.error(f"Continuity check error: {e}")
|
||
|
||
|
||
# ─── 每小时报告 ──────────────────────────────────────────────────
|
||
|
||
async def daily_report():
    """Log an hourly per-symbol integrity summary.

    NOTE(review): despite the name, this runs every 3600s (hourly); the name
    is kept unchanged for caller compatibility.
    """
    while True:
        await asyncio.sleep(3600)
        try:
            lines = ["=== AggTrades Integrity Report ==="]
            with get_sync_conn() as conn, conn.cursor() as cur:
                for symbol in SYMBOLS:
                    cur.execute(
                        "SELECT last_agg_id, last_time_ms FROM agg_trades_meta WHERE symbol = %s",
                        (symbol,)
                    )
                    meta = cur.fetchone()
                    if not meta:
                        lines.append(f" {symbol}: No data yet")
                        continue
                    last_dt = datetime.fromtimestamp(meta[1] / 1000, tz=timezone.utc).isoformat()
                    cur.execute(
                        "SELECT COUNT(*) FROM agg_trades WHERE symbol = %s", (symbol,)
                    )
                    total = cur.fetchone()[0]
                    lines.append(f" {symbol}: last_agg_id={meta[0]}, last_time={last_dt}, total={total:,}")
            logger.info("\n".join(lines))
        except Exception as e:
            logger.error(f"Report error: {e}")
|
||
|
||
|
||
# ─── 入口 ────────────────────────────────────────────────────────
|
||
|
||
async def main():
    """Entry point: ensure partitions exist, then run all tasks concurrently."""
    logger.info("AggTrades Collector (PG) starting...")
    # Partitions must exist before any collector attempts a write.
    ensure_partitions()

    collectors = [ws_collect(sym) for sym in SYMBOLS]
    monitors = [continuity_check(), daily_report()]
    await asyncio.gather(*collectors, *monitors)


if __name__ == "__main__":
    asyncio.run(main())
|