arbitrage-engine/backend/agg_trades_collector.py
root 4f54e36d1a feat: dual-write agg_trades to local PG + Cloud SQL
- db.py: add Cloud SQL connection pool (CLOUD_PG_ENABLED env toggle)
- agg_trades_collector: flush_buffer writes to both local and cloud
- Cloud SQL write failure is non-fatal (log warning, don't block local)
2026-03-01 07:16:03 +00:00

308 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
agg_trades_collector.py — aggTrades全量采集守护进程PostgreSQL版
架构:
- WebSocket主链路实时推送延迟<100ms
- REST补洞断线重连后从last_agg_id追平
- 每分钟巡检校验agg_id连续性发现断档自动补洞
- 批量写入攒200条或1秒flush一次
- PG分区表按月自动分区MVCC并发无锁冲突
"""
import asyncio
import json
import logging
import os
import time
from datetime import datetime, timezone
from typing import Optional
import httpx
import psycopg2
import psycopg2.extras
import websockets
from db import get_sync_conn, get_sync_pool, get_cloud_sync_conn, ensure_partitions, PG_HOST, PG_PORT, PG_DB, PG_USER, PG_PASS, CLOUD_PG_ENABLED
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
handlers=[
logging.StreamHandler(),
logging.FileHandler(os.path.join(os.path.dirname(__file__), "..", "collector.log")),
],
)
logger = logging.getLogger("collector")
BINANCE_FAPI = "https://fapi.binance.com/fapi/v1"
SYMBOLS = ["BTCUSDT", "ETHUSDT", "XRPUSDT", "SOLUSDT"]
HEADERS = {"User-Agent": "Mozilla/5.0 ArbitrageEngine/3.0"}
BATCH_SIZE = 200
BATCH_TIMEOUT = 1.0
# ─── DB helpers ──────────────────────────────────────────────────
def get_last_agg_id(symbol: str) -> Optional[int]:
    """Return the highest recorded agg_id for *symbol*, or None if absent or on error."""
    try:
        with get_sync_conn() as conn:
            with conn.cursor() as cur:
                cur.execute(
                    "SELECT last_agg_id FROM agg_trades_meta WHERE symbol = %s",
                    (symbol,),
                )
                meta_row = cur.fetchone()
                return meta_row[0] if meta_row else None
    except Exception as e:
        # Treat any DB failure as "unknown"; caller falls back accordingly.
        logger.error(f"get_last_agg_id error: {e}")
        return None
def update_meta(conn, symbol: str, last_agg_id: int, last_time_ms: int):
    """Upsert per-symbol progress markers (last seen agg_id and trade time).

    Runs on the caller's connection/transaction; does not commit.
    """
    upsert_sql = """
        INSERT INTO agg_trades_meta (symbol, last_agg_id, last_time_ms, updated_at)
        VALUES (%s, %s, %s, NOW())
        ON CONFLICT(symbol) DO UPDATE SET
            last_agg_id = EXCLUDED.last_agg_id,
            last_time_ms = EXCLUDED.last_time_ms,
            updated_at = NOW()
    """
    with conn.cursor() as cur:
        cur.execute(upsert_sql, (symbol, last_agg_id, last_time_ms))
# Insert SQL shared by the local and Cloud SQL write paths.
_INSERT_SQL = """INSERT INTO agg_trades (agg_id, symbol, price, qty, time_ms, is_buyer_maker)
VALUES %s
ON CONFLICT (time_ms, symbol, agg_id) DO NOTHING"""
_INSERT_TEMPLATE = "(%s, %s, %s, %s, %s, %s)"


def _prepare_rows(symbol: str, trades: list) -> tuple:
    """Convert raw Binance aggTrade dicts into insert tuples.

    Returns (rows, last_agg_id, last_time_ms), where last_* track the
    highest agg_id seen in the batch and its trade time, for the meta upsert.
    """
    rows = []
    last_agg_id = 0
    last_time_ms = 0
    for t in trades:
        agg_id = t["a"]
        time_ms = t["T"]
        rows.append((
            agg_id, symbol,
            float(t["p"]), float(t["q"]),
            time_ms,
            1 if t["m"] else 0,  # is_buyer_maker stored as 0/1
        ))
        if agg_id > last_agg_id:
            last_agg_id = agg_id
            last_time_ms = time_ms
    return rows, last_agg_id, last_time_ms


def _write_batch(conn, symbol: str, rows: list, last_agg_id: int, last_time_ms: int) -> int:
    """Insert *rows* on *conn*, upsert meta, and commit; return cur.rowcount.

    NOTE(review): rowcount after execute_values reflects only the last page;
    with page_size=1000 and batches well below that it is the full count.
    """
    with conn.cursor() as cur:
        psycopg2.extras.execute_values(
            cur, _INSERT_SQL, rows,
            template=_INSERT_TEMPLATE, page_size=1000,
        )
        inserted = cur.rowcount
        if last_agg_id > 0:
            update_meta(conn, symbol, last_agg_id, last_time_ms)
    conn.commit()
    return inserted


def flush_buffer(symbol: str, trades: list) -> int:
    """Write a batch of trades to PG (dual-write: local + Cloud SQL).

    Returns the number of rows actually inserted locally (duplicates are
    skipped via ON CONFLICT DO NOTHING). A Cloud SQL failure is logged and
    never blocks the local write; any local failure returns 0.
    """
    if not trades:
        return 0
    try:
        # Make sure the monthly partition exists before inserting.
        ensure_partitions()
        rows, last_agg_id, last_time_ms = _prepare_rows(symbol, trades)
        # Local PG is the primary store.
        with get_sync_conn() as conn:
            inserted = _write_batch(conn, symbol, rows, last_agg_id, last_time_ms)
        # Dual-write to Cloud SQL: best-effort, non-fatal on failure.
        if CLOUD_PG_ENABLED:
            try:
                with get_cloud_sync_conn() as cloud_conn:
                    if cloud_conn:
                        _write_batch(cloud_conn, symbol, rows, last_agg_id, last_time_ms)
            except Exception as e:
                logger.warning(f"[{symbol}] Cloud SQL write failed (non-fatal): {e}")
        return inserted
    except Exception as e:
        logger.error(f"flush_buffer [{symbol}] error: {e}")
        return 0
# ─── REST gap-fill ───────────────────────────────────────────────
async def rest_catchup(symbol: str, from_id: int) -> int:
    """Backfill missing trades for *symbol* via the REST aggTrades endpoint.

    Pages forward from *from_id* in chunks of up to 1000 until the feed is
    exhausted, flushing each page through the normal write path.
    Returns the total number of rows written.
    """
    filled = 0
    next_id = from_id
    logger.info(f"[{symbol}] REST catchup from agg_id={from_id}")
    async with httpx.AsyncClient(timeout=15, headers=HEADERS) as client:
        while True:
            try:
                resp = await client.get(
                    f"{BINANCE_FAPI}/aggTrades",
                    params={"symbol": symbol, "fromId": next_id, "limit": 1000},
                )
                if resp.status_code != 200:
                    logger.warning(f"[{symbol}] REST catchup HTTP {resp.status_code}")
                    break
                page = resp.json()
                if not page:
                    break
                filled += flush_buffer(symbol, page)
                newest = page[-1]["a"]
                if newest <= next_id:
                    break  # no forward progress — stop rather than loop forever
                next_id = newest + 1
                if len(page) < 1000:
                    break  # short page: we've caught up to the live stream
                await asyncio.sleep(0.1)  # gentle pacing between pages
            except Exception as e:
                logger.error(f"[{symbol}] REST catchup error: {e}")
                break
    logger.info(f"[{symbol}] REST catchup done, filled {filled} trades")
    return filled
# ─── WebSocket collection ────────────────────────────────────────
async def ws_collect(symbol: str):
    """Stream aggTrade events for *symbol* over WebSocket, forever.

    On each (re)connect, the gap since last_agg_id is first backfilled via
    REST, then live events are buffered and flushed in batches. Reconnects
    use exponential backoff (1s doubling up to 30s), reset on a healthy
    connection.
    """
    url = f"wss://fstream.binance.com/ws/{symbol.lower()}@aggTrade"
    pending: list = []
    last_flush_at = time.time()
    backoff = 1.0
    while True:
        try:
            resume_id = get_last_agg_id(symbol)
            if resume_id is not None:
                # Fill whatever we missed while disconnected.
                await rest_catchup(symbol, resume_id + 1)
            logger.info(f"[{symbol}] Connecting WS: {url}")
            async with websockets.connect(url, ping_interval=20, ping_timeout=10) as ws:
                backoff = 1.0
                logger.info(f"[{symbol}] WS connected")
                async for raw in ws:
                    event = json.loads(raw)
                    if event.get("e") != "aggTrade":
                        continue
                    pending.append(event)
                    now = time.time()
                    # Flush on size OR age threshold; otherwise keep buffering.
                    if len(pending) < BATCH_SIZE and (now - last_flush_at) < BATCH_TIMEOUT:
                        continue
                    written = flush_buffer(symbol, pending)
                    if written > 0:
                        logger.debug(f"[{symbol}] flushed {written}/{len(pending)} trades")
                    pending.clear()
                    last_flush_at = now
        except websockets.exceptions.ConnectionClosed as e:
            logger.warning(f"[{symbol}] WS closed: {e}, reconnecting in {backoff}s")
        except Exception as e:
            logger.error(f"[{symbol}] WS error: {e}, reconnecting in {backoff}s")
        finally:
            # Don't drop buffered trades on disconnect.
            if pending:
                flush_buffer(symbol, pending)
                pending.clear()
            await asyncio.sleep(backoff)
            backoff = min(backoff * 2, 30)
# ─── Continuity patrol ───────────────────────────────────────────
# Strong references to in-flight backfill tasks. asyncio keeps only weak
# references to tasks, so the result of a fire-and-forget create_task()
# can be garbage-collected before it finishes; holding it here prevents that.
_backfill_tasks: set = set()


async def continuity_check():
    """Every minute, verify agg_id continuity for each symbol.

    Inspects the newest 100 rows per symbol; any break in the agg_id
    sequence schedules a REST backfill starting at the oldest gap.
    Runs forever; errors are logged and the loop continues.
    """
    while True:
        await asyncio.sleep(60)
        try:
            with get_sync_conn() as conn:
                with conn.cursor() as cur:
                    for symbol in SYMBOLS:
                        cur.execute(
                            "SELECT last_agg_id, last_time_ms FROM agg_trades_meta WHERE symbol = %s",
                            (symbol,)
                        )
                        row = cur.fetchone()
                        if not row:
                            continue  # no data collected yet for this symbol
                        # Check whether the most recent 100 agg_ids are contiguous.
                        cur.execute(
                            "SELECT agg_id FROM agg_trades WHERE symbol = %s ORDER BY agg_id DESC LIMIT 100",
                            (symbol,)
                        )
                        rows = cur.fetchall()
                        if len(rows) < 2:
                            continue
                        ids = [r[0] for r in rows]  # descending agg_ids
                        gaps = []
                        for i in range(len(ids) - 1):
                            diff = ids[i] - ids[i + 1]
                            if diff > 1:
                                # (lower bound, upper bound, missing count)
                                gaps.append((ids[i + 1], ids[i], diff - 1))
                        if gaps:
                            logger.warning(f"[{symbol}] Found {len(gaps)} gaps: {gaps[:3]}")
                            min_gap_id = min(g[0] for g in gaps)
                            # Keep a strong reference so the task isn't GC'd mid-flight.
                            task = asyncio.create_task(rest_catchup(symbol, min_gap_id))
                            _backfill_tasks.add(task)
                            task.add_done_callback(_backfill_tasks.discard)
                        else:
                            logger.debug(f"[{symbol}] Continuity OK, last_agg_id={row[0]}")
        except Exception as e:
            logger.error(f"Continuity check error: {e}")
# ─── Hourly report ───────────────────────────────────────────────
async def daily_report():
    """Log an hourly integrity report: last agg_id, last trade time, and row count per symbol.

    NOTE(review): despite the name, the interval is hourly (3600s).
    """
    while True:
        await asyncio.sleep(3600)
        try:
            with get_sync_conn() as conn:
                with conn.cursor() as cur:
                    lines = ["=== AggTrades Integrity Report ==="]
                    for symbol in SYMBOLS:
                        cur.execute(
                            "SELECT last_agg_id, last_time_ms FROM agg_trades_meta WHERE symbol = %s",
                            (symbol,)
                        )
                        meta = cur.fetchone()
                        if not meta:
                            lines.append(f" {symbol}: No data yet")
                            continue
                        last_dt = datetime.fromtimestamp(meta[1] / 1000, tz=timezone.utc).isoformat()
                        cur.execute("SELECT COUNT(*) FROM agg_trades WHERE symbol = %s", (symbol,))
                        total = cur.fetchone()[0]
                        lines.append(f" {symbol}: last_agg_id={meta[0]}, last_time={last_dt}, total={total:,}")
                    logger.info("\n".join(lines))
        except Exception as e:
            logger.error(f"Report error: {e}")
# ─── Entry point ─────────────────────────────────────────────────
async def main():
    """Entry point: run one WS collector per symbol plus the periodic checkers, concurrently."""
    logger.info("AggTrades Collector (PG) starting...")
    # Create partitions up front so the first writes don't race partition creation.
    ensure_partitions()
    coros = [ws_collect(sym) for sym in SYMBOLS]
    coros.append(continuity_check())
    coros.append(daily_report())
    await asyncio.gather(*coros)
if __name__ == "__main__":
asyncio.run(main())