arbitrage-engine/backend/agg_trades_collector.py
root 4168c1dd88 refactor: SQLite→PostgreSQL migration - db.py连接层 + main/collector/signal-engine/backfill全部改PG
Phase 1: 核心数据表(agg_trades/rate_snapshots/signal*)迁PG
auth.py暂保留SQLite(低频,不影响性能)
- db.py: psycopg2同步池 + asyncpg异步池 + PG schema + 分区管理
- main.py: 全部改asyncpg查询
- collector: psycopg2 + execute_values批量写入
- signal-engine: psycopg2同步读写
- backfill: psycopg2 + ON CONFLICT DO NOTHING
2026-02-27 16:15:16 +00:00

293 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
agg_trades_collector.py — aggTrades全量采集守护进程PostgreSQL版
架构:
- WebSocket主链路实时推送延迟<100ms
- REST补洞断线重连后从last_agg_id追平
- 每分钟巡检校验agg_id连续性发现断档自动补洞
- 批量写入攒200条或1秒flush一次
- PG分区表按月自动分区MVCC并发无锁冲突
"""
import asyncio
import json
import logging
import os
import time
from datetime import datetime, timezone
from typing import Optional
import httpx
import psycopg2
import psycopg2.extras
import websockets
from db import get_sync_conn, get_sync_pool, ensure_partitions, PG_HOST, PG_PORT, PG_DB, PG_USER, PG_PASS
# Log to both stdout and ../collector.log so the daemon is debuggable whether
# it runs in a terminal or under a process supervisor.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler(os.path.join(os.path.dirname(__file__), "..", "collector.log")),
    ],
)
logger = logging.getLogger("collector")

# Binance USDT-margined futures REST base URL.
BINANCE_FAPI = "https://fapi.binance.com/fapi/v1"
# Symbols this daemon collects aggTrades for.
SYMBOLS = ["BTCUSDT", "ETHUSDT"]
HEADERS = {"User-Agent": "Mozilla/5.0 ArbitrageEngine/3.0"}
# Flush the write buffer when it reaches BATCH_SIZE trades ...
BATCH_SIZE = 200
# ... or when BATCH_TIMEOUT seconds have elapsed since the last flush.
BATCH_TIMEOUT = 1.0
# ─── DB helpers ──────────────────────────────────────────────────
def get_last_agg_id(symbol: str) -> Optional[int]:
    """Return the highest persisted aggTrade id for *symbol*, or None if absent / on error."""
    try:
        with get_sync_conn() as conn:
            with conn.cursor() as cur:
                cur.execute("SELECT last_agg_id FROM agg_trades_meta WHERE symbol = %s", (symbol,))
                row = cur.fetchone()
    except Exception as e:
        logger.error(f"get_last_agg_id error: {e}")
        return None
    return row[0] if row else None
def update_meta(conn, symbol: str, last_agg_id: int, last_time_ms: int):
    """Upsert the per-symbol watermark (newest agg_id and its trade time) into agg_trades_meta.

    Does not commit — the caller owns the transaction on *conn*.
    """
    upsert_sql = """
        INSERT INTO agg_trades_meta (symbol, last_agg_id, last_time_ms, updated_at)
        VALUES (%s, %s, %s, NOW())
        ON CONFLICT(symbol) DO UPDATE SET
            last_agg_id = EXCLUDED.last_agg_id,
            last_time_ms = EXCLUDED.last_time_ms,
            updated_at = NOW()
    """
    with conn.cursor() as cur:
        cur.execute(upsert_sql, (symbol, last_agg_id, last_time_ms))
def flush_buffer(symbol: str, trades: list) -> int:
    """Persist a batch of raw aggTrade dicts for *symbol* into PG.

    Each trade dict uses the Binance wire keys: a=agg_id, p=price, q=qty,
    T=time_ms, m=is_buyer_maker. Duplicate rows are skipped via ON CONFLICT.
    Returns the number of rows actually inserted (0 on empty input or error).
    """
    if not trades:
        return 0
    try:
        # Make sure the target monthly partition exists before writing.
        ensure_partitions()
        rows = [
            (t["a"], symbol, float(t["p"]), float(t["q"]), t["T"], 1 if t["m"] else 0)
            for t in trades
        ]
        # Watermark = trade with the highest agg_id in this batch.
        newest = max(trades, key=lambda t: t["a"])
        with get_sync_conn() as conn:
            with conn.cursor() as cur:
                # Bulk insert; ON CONFLICT silently ignores already-stored trades.
                psycopg2.extras.execute_values(
                    cur,
                    """INSERT INTO agg_trades (agg_id, symbol, price, qty, time_ms, is_buyer_maker)
                    VALUES %s
                    ON CONFLICT (time_ms, symbol, agg_id) DO NOTHING""",
                    rows,
                    template="(%s, %s, %s, %s, %s, %s)",
                    page_size=1000,
                )
                inserted = cur.rowcount
                if newest["a"] > 0:
                    update_meta(conn, symbol, newest["a"], newest["T"])
            conn.commit()
        return inserted
    except Exception as e:
        logger.error(f"flush_buffer [{symbol}] error: {e}")
        return 0
# ─── REST补洞 ────────────────────────────────────────────────────
async def rest_catchup(symbol: str, from_id: int) -> int:
    """Backfill aggTrades via the Binance REST API starting at *from_id*.

    Pages through /aggTrades 1000 rows at a time, persisting each page with
    flush_buffer, until the stream is caught up or an error occurs.
    Returns the total number of rows written.
    """
    filled = 0
    cursor = from_id
    logger.info(f"[{symbol}] REST catchup from agg_id={from_id}")
    async with httpx.AsyncClient(timeout=15, headers=HEADERS) as client:
        while True:
            try:
                resp = await client.get(
                    f"{BINANCE_FAPI}/aggTrades",
                    params={"symbol": symbol, "fromId": cursor, "limit": 1000},
                )
                if resp.status_code != 200:
                    logger.warning(f"[{symbol}] REST catchup HTTP {resp.status_code}")
                    break
                page = resp.json()
                if not page:
                    break
                filled += flush_buffer(symbol, page)
                newest = page[-1]["a"]
                if newest <= cursor:
                    break  # no forward progress — bail out to avoid spinning
                cursor = newest + 1
                if len(page) < 1000:
                    break  # short page means we are caught up
                await asyncio.sleep(0.1)  # gentle pacing between pages
            except Exception as e:
                logger.error(f"[{symbol}] REST catchup error: {e}")
                break
    logger.info(f"[{symbol}] REST catchup done, filled {filled} trades")
    return filled
# ─── WebSocket采集 ───────────────────────────────────────────────
async def ws_collect(symbol: str):
    """Run forever: stream aggTrades for *symbol* over WebSocket and batch-write to PG.

    On every (re)connect it first REST-backfills from the persisted watermark,
    then streams. The buffer is flushed once it holds BATCH_SIZE trades or
    BATCH_TIMEOUT seconds have passed. Reconnects use exponential backoff
    capped at 30 seconds.
    """
    url = f"wss://fstream.binance.com/ws/{symbol.lower()}@aggTrade"
    pending: list = []
    flushed_at = time.time()
    backoff = 1.0
    while True:
        try:
            watermark = get_last_agg_id(symbol)
            if watermark is not None:
                # Fill whatever was missed while disconnected before streaming.
                await rest_catchup(symbol, watermark + 1)
            logger.info(f"[{symbol}] Connecting WS: {url}")
            async with websockets.connect(url, ping_interval=20, ping_timeout=10) as ws:
                backoff = 1.0  # healthy connection: reset the backoff
                logger.info(f"[{symbol}] WS connected")
                async for raw in ws:
                    msg = json.loads(raw)
                    if msg.get("e") != "aggTrade":
                        continue
                    pending.append(msg)
                    now = time.time()
                    if len(pending) >= BATCH_SIZE or (now - flushed_at) >= BATCH_TIMEOUT:
                        written = flush_buffer(symbol, pending)
                        if written > 0:
                            logger.debug(f"[{symbol}] flushed {written}/{len(pending)} trades")
                        pending.clear()
                        flushed_at = now
        except websockets.exceptions.ConnectionClosed as e:
            logger.warning(f"[{symbol}] WS closed: {e}, reconnecting in {backoff}s")
        except Exception as e:
            logger.error(f"[{symbol}] WS error: {e}, reconnecting in {backoff}s")
        finally:
            if pending:
                # Don't lose trades buffered at the moment of disconnect.
                flush_buffer(symbol, pending)
                pending.clear()
            await asyncio.sleep(backoff)
            backoff = min(backoff * 2, 30)
# ─── 连续性巡检 ──────────────────────────────────────────────────
async def continuity_check():
    """Every 60s, verify the newest agg_ids per symbol are gapless.

    Inspects the most recent 100 agg_ids; if any gap is found, schedules a
    background REST backfill starting at the lowest gap boundary.
    Runs forever; all errors are logged and the loop continues.
    """
    # Keep strong references to fire-and-forget backfill tasks: per the asyncio
    # docs, the event loop only holds a weak reference to tasks, so a bare
    # create_task() result may be garbage-collected before it finishes.
    background: set = set()
    while True:
        await asyncio.sleep(60)
        try:
            with get_sync_conn() as conn:
                with conn.cursor() as cur:
                    for symbol in SYMBOLS:
                        cur.execute(
                            "SELECT last_agg_id, last_time_ms FROM agg_trades_meta WHERE symbol = %s",
                            (symbol,)
                        )
                        row = cur.fetchone()
                        if not row:
                            continue
                        # Check the most recent 100 trades for id continuity.
                        cur.execute(
                            "SELECT agg_id FROM agg_trades WHERE symbol = %s ORDER BY agg_id DESC LIMIT 100",
                            (symbol,)
                        )
                        rows = cur.fetchall()
                        if len(rows) < 2:
                            continue
                        ids = [r[0] for r in rows]
                        gaps = []
                        # ids is descending; any adjacent difference > 1 is a hole.
                        for i in range(len(ids) - 1):
                            diff = ids[i] - ids[i + 1]
                            if diff > 1:
                                gaps.append((ids[i + 1], ids[i], diff - 1))
                        if gaps:
                            logger.warning(f"[{symbol}] Found {len(gaps)} gaps: {gaps[:3]}")
                            min_gap_id = min(g[0] for g in gaps)
                            # Backfill from the lowest gap; ON CONFLICT dedupes overlap.
                            task = asyncio.create_task(rest_catchup(symbol, min_gap_id))
                            background.add(task)
                            task.add_done_callback(background.discard)
                        else:
                            logger.debug(f"[{symbol}] Continuity OK, last_agg_id={row[0]}")
        except Exception as e:
            logger.error(f"Continuity check error: {e}")
# ─── 每小时报告 ──────────────────────────────────────────────────
async def daily_report():
    """Log an integrity summary (watermark + row count per symbol) every hour.

    NOTE: the name is historical — the interval is 3600 s (hourly), not daily.
    Runs forever; all errors are logged and the loop continues.
    """
    while True:
        await asyncio.sleep(3600)
        try:
            with get_sync_conn() as conn:
                with conn.cursor() as cur:
                    lines = ["=== AggTrades Integrity Report ==="]
                    for symbol in SYMBOLS:
                        cur.execute(
                            "SELECT last_agg_id, last_time_ms FROM agg_trades_meta WHERE symbol = %s",
                            (symbol,)
                        )
                        meta = cur.fetchone()
                        if not meta:
                            lines.append(f" {symbol}: No data yet")
                            continue
                        last_dt = datetime.fromtimestamp(meta[1] / 1000, tz=timezone.utc).isoformat()
                        cur.execute(
                            "SELECT COUNT(*) FROM agg_trades WHERE symbol = %s", (symbol,)
                        )
                        total = cur.fetchone()[0]
                        lines.append(f" {symbol}: last_agg_id={meta[0]}, last_time={last_dt}, total={total:,}")
                    logger.info("\n".join(lines))
        except Exception as e:
            logger.error(f"Report error: {e}")
# ─── 入口 ────────────────────────────────────────────────────────
async def main():
    """Launch one WS collector per symbol plus the continuity patrol and hourly report."""
    logger.info("AggTrades Collector (PG) starting...")
    # Make sure PG partitions exist before any collector starts writing.
    ensure_partitions()
    coros = [ws_collect(sym) for sym in SYMBOLS]
    coros.append(continuity_check())
    coros.append(daily_report())
    await asyncio.gather(*coros)


if __name__ == "__main__":
    asyncio.run(main())