arbitrage-engine/backend/agg_trades_collector.py

368 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
agg_trades_collector.py — aggTrades全量采集守护进程
架构:
- WebSocket主链路实时推送延迟<100ms
- REST补洞断线重连后从last_agg_id追平
- 每分钟巡检校验agg_id连续性发现断档自动补洞
- 批量写入攒200条或1秒flush一次减少WAL压力
- 按月分表agg_trades_YYYYMM单表千万行内查询快
- 健康接口GET /collector/health 可监控
"""
import asyncio
import json
import logging
import os
import sqlite3
import time
from datetime import datetime, timezone
from typing import Optional
import httpx
import websockets
# Log both to stdout and to ../collector.log so the daemon can be tailed either way.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler(os.path.join(os.path.dirname(__file__), "..", "collector.log")),
    ],
)
logger = logging.getLogger("collector")
# SQLite database lives one directory above this file.
DB_PATH = os.path.join(os.path.dirname(__file__), "..", "arb.db")
BINANCE_FAPI = "https://fapi.binance.com/fapi/v1"
SYMBOLS = ["BTCUSDT", "ETHUSDT"]
HEADERS = {"User-Agent": "Mozilla/5.0 ArbitrageEngine/3.0"}
# Batch-write buffer.
# NOTE(review): unused in the code visible here — ws_collect keeps its own local
# buffer; confirm whether anything else in the file still references this.
_buffer: dict[str, list] = {s: [] for s in SYMBOLS}
# Flush thresholds used by ws_collect: whichever of the two is hit first.
BATCH_SIZE = 200
BATCH_TIMEOUT = 1.0  # seconds
# ─── DB helpers ──────────────────────────────────────────────────
def get_conn() -> sqlite3.Connection:
    """Open a connection to the collector database.

    The connection waits up to 30 s on a locked database, returns rows as
    sqlite3.Row (dict-like access), and enables WAL journaling with NORMAL
    synchronous mode for better concurrent write throughput.
    """
    connection = sqlite3.connect(DB_PATH, timeout=30)
    connection.row_factory = sqlite3.Row
    for pragma in ("PRAGMA journal_mode=WAL", "PRAGMA synchronous=NORMAL"):
        connection.execute(pragma)
    return connection
def table_name(ts_ms: int) -> str:
    """Return the monthly partition table name (e.g. agg_trades_202602).

    The month is derived from *ts_ms* (epoch milliseconds) in UTC.
    """
    when = datetime.fromtimestamp(ts_ms / 1000, tz=timezone.utc)
    return "agg_trades_" + when.strftime("%Y%m")
def ensure_table(conn: sqlite3.Connection, tname: str):
    """Create the monthly aggTrades table *tname* and its query index if absent.

    agg_id is the primary key, so INSERT OR IGNORE dedupes naturally; the
    (symbol, time_ms) index serves time-range queries per symbol.
    """
    ddl = f"""
        CREATE TABLE IF NOT EXISTS {tname} (
            agg_id INTEGER PRIMARY KEY,
            symbol TEXT NOT NULL,
            price REAL NOT NULL,
            qty REAL NOT NULL,
            first_trade_id INTEGER,
            last_trade_id INTEGER,
            time_ms INTEGER NOT NULL,
            is_buyer_maker INTEGER NOT NULL
        )
    """
    index_ddl = f"""
        CREATE INDEX IF NOT EXISTS idx_{tname}_sym_time
        ON {tname}(symbol, time_ms)
    """
    conn.execute(ddl)
    conn.execute(index_ddl)
def ensure_meta_table(conn: sqlite3.Connection):
    """Create the per-symbol watermark table if it does not exist.

    One row per symbol records the newest persisted agg_id and its timestamp;
    updated_at defaults to the current UTC time via SQLite's datetime('now').
    """
    conn.execute(
        """
        CREATE TABLE IF NOT EXISTS agg_trades_meta (
            symbol TEXT PRIMARY KEY,
            last_agg_id INTEGER NOT NULL,
            last_time_ms INTEGER NOT NULL,
            updated_at TEXT DEFAULT (datetime('now'))
        )
        """
    )
def get_last_agg_id(symbol: str) -> Optional[int]:
    """Return the last persisted agg_id for *symbol*, or None if unknown.

    Any DB failure is logged and reported as None so callers fall back to
    "no watermark" behaviour instead of crashing the collector loop.
    """
    try:
        conn = get_conn()
        try:
            ensure_meta_table(conn)
            row = conn.execute(
                "SELECT last_agg_id FROM agg_trades_meta WHERE symbol = ?", (symbol,)
            ).fetchone()
            return row["last_agg_id"] if row else None
        finally:
            # Bug fix: the connection used to leak when the query raised,
            # because close() was only reached on the success path.
            conn.close()
    except Exception as e:
        logger.error(f"get_last_agg_id error: {e}")
        return None
def update_meta(conn: sqlite3.Connection, symbol: str, last_agg_id: int, last_time_ms: int):
    """Upsert the watermark row for *symbol* (the caller commits).

    Inserts a fresh row or, on conflict, overwrites last_agg_id / last_time_ms
    and refreshes updated_at.
    """
    upsert_sql = """
        INSERT INTO agg_trades_meta (symbol, last_agg_id, last_time_ms, updated_at)
        VALUES (?, ?, ?, datetime('now'))
        ON CONFLICT(symbol) DO UPDATE SET
            last_agg_id = excluded.last_agg_id,
            last_time_ms = excluded.last_time_ms,
            updated_at = excluded.updated_at
    """
    conn.execute(upsert_sql, (symbol, last_agg_id, last_time_ms))
def flush_buffer(symbol: str, trades: list) -> int:
    """Persist a batch of raw aggTrade dicts for *symbol*; return rows inserted.

    Trades are grouped into monthly partition tables by their "T" (event time,
    ms) field. Duplicate agg_ids are skipped via INSERT OR IGNORE, so the
    return value counts only rows that were actually new. The meta watermark
    advances to the highest newly-inserted agg_id. Errors are logged and
    reported as 0 so the calling collection loop keeps running.
    """
    if not trades:
        return 0
    try:
        conn = get_conn()
        try:
            ensure_meta_table(conn)
            # Group trades by monthly partition table.
            by_month: dict[str, list] = {}
            for t in trades:
                by_month.setdefault(table_name(t["T"]), []).append(t)
            inserted = 0
            last_agg_id = 0
            last_time_ms = 0
            for tname, batch in by_month.items():
                ensure_table(conn, tname)
                for t in batch:
                    cur = conn.execute(
                        f"""INSERT OR IGNORE INTO {tname}
                        (agg_id, symbol, price, qty, first_trade_id, last_trade_id, time_ms, is_buyer_maker)
                        VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
                        (
                            t["a"],              # agg_id
                            symbol,
                            float(t["p"]),       # price
                            float(t["q"]),       # quantity
                            t.get("f"),          # first_trade_id
                            t.get("l"),          # last_trade_id
                            t["T"],              # time_ms
                            1 if t["m"] else 0,  # is_buyer_maker
                        ),
                    )
                    if cur.rowcount > 0:
                        inserted += 1
                        # Watermark only advances on rows we actually inserted,
                        # so pure duplicates never move it forward.
                        if t["a"] > last_agg_id:
                            last_agg_id = t["a"]
                            last_time_ms = t["T"]
            if last_agg_id > 0:
                update_meta(conn, symbol, last_agg_id, last_time_ms)
            conn.commit()
            return inserted
        finally:
            # Bug fix: the connection used to leak when an insert raised
            # mid-batch; now it is closed on every exit path.
            conn.close()
    except Exception as e:
        logger.error(f"flush_buffer [{symbol}] error: {e}")
        return 0
# ─── REST gap-filling ────────────────────────────────────────────
async def rest_catchup(symbol: str, from_id: int) -> int:
    """Pull aggTrades via REST starting at *from_id* until caught up to live.

    Pages through the endpoint 1000 trades at a time, persisting each page
    with flush_buffer. Returns the total number of rows written. Any HTTP or
    parsing failure ends the catchup early; the caller's loop retries later.
    """
    total = 0
    cursor = from_id
    logger.info(f"[{symbol}] REST catchup from agg_id={from_id}")
    async with httpx.AsyncClient(timeout=15, headers=HEADERS) as client:
        while True:
            try:
                resp = await client.get(
                    f"{BINANCE_FAPI}/aggTrades",
                    params={"symbol": symbol, "fromId": cursor, "limit": 1000},
                )
                if resp.status_code != 200:
                    logger.warning(f"[{symbol}] REST catchup HTTP {resp.status_code}")
                    break
                page = resp.json()
                if not page:
                    break
                total += flush_buffer(symbol, page)
                newest = page[-1]["a"]
                if newest <= cursor:
                    # No forward progress — bail out rather than spin.
                    break
                cursor = newest + 1
                if len(page) < 1000:
                    # A short page means we have reached the live edge.
                    break
                await asyncio.sleep(0.1)  # stay friendly with rate limits
            except Exception as e:
                logger.error(f"[{symbol}] REST catchup error: {e}")
                break
    logger.info(f"[{symbol}] REST catchup done, filled {total} trades")
    return total
# ─── WebSocket collection ────────────────────────────────────────
async def ws_collect(symbol: str):
    """Per-symbol WS collection loop: auto-reconnect + REST gap-fill.

    Runs forever. Before each (re)connect it backfills from the stored
    watermark via REST, then streams aggTrade events, flushing to SQLite in
    batches (BATCH_SIZE rows or BATCH_TIMEOUT seconds, whichever comes first).
    On disconnect it flushes any residue, sleeps, and retries with
    exponential backoff.
    """
    stream = symbol.lower() + "@aggTrade"
    url = f"wss://fstream.binance.com/ws/{stream}"
    buffer: list = []  # raw aggTrade dicts awaiting flush
    last_flush = time.time()
    reconnect_delay = 1.0  # backoff start; doubled on each failed cycle
    while True:
        try:
            # Before (re)connecting, REST-backfill anything missed while down.
            last_id = get_last_agg_id(symbol)
            if last_id is not None:
                await rest_catchup(symbol, last_id + 1)
            logger.info(f"[{symbol}] Connecting WS: {url}")
            async with websockets.connect(url, ping_interval=20, ping_timeout=10) as ws:
                reconnect_delay = 1.0  # successful connect resets the backoff
                logger.info(f"[{symbol}] WS connected")
                async for raw in ws:
                    msg = json.loads(raw)
                    if msg.get("e") != "aggTrade":
                        continue
                    buffer.append(msg)
                    # Batch flush: BATCH_SIZE rows or BATCH_TIMEOUT seconds.
                    # NOTE(review): the timeout branch only fires when a new
                    # message arrives, so on a quiet stream trades can sit in
                    # the buffer longer than BATCH_TIMEOUT — confirm acceptable.
                    now = time.time()
                    if len(buffer) >= BATCH_SIZE or (now - last_flush) >= BATCH_TIMEOUT:
                        count = flush_buffer(symbol, buffer)
                        if count > 0:
                            logger.debug(f"[{symbol}] flushed {count}/{len(buffer)} trades")
                        buffer.clear()
                        last_flush = now
        except websockets.exceptions.ConnectionClosed as e:
            logger.warning(f"[{symbol}] WS closed: {e}, reconnecting in {reconnect_delay}s")
        except Exception as e:
            logger.error(f"[{symbol}] WS error: {e}, reconnecting in {reconnect_delay}s")
        finally:
            # Flush whatever is left before sleeping and reconnecting.
            if buffer:
                flush_buffer(symbol, buffer)
                buffer.clear()
            await asyncio.sleep(reconnect_delay)
            reconnect_delay = min(reconnect_delay * 2, 30)  # exponential backoff, max 30s
# ─── Continuity patrol ───────────────────────────────────────────
async def continuity_check():
    """Every 60 s, check each symbol's recent agg_id sequence for gaps.

    Looks at the newest 100 rows in the current month's partition table; any
    missing agg_id range is logged and a REST backfill task is scheduled from
    the lowest gap boundary. Errors are logged and the patrol keeps running.
    """
    # Keep strong references to backfill tasks: bug fix — a bare
    # asyncio.create_task() result can be garbage-collected mid-flight.
    pending: set = set()
    while True:
        await asyncio.sleep(60)
        try:
            conn = get_conn()
            try:
                ensure_meta_table(conn)
                for symbol in SYMBOLS:
                    row = conn.execute(
                        "SELECT last_agg_id, last_time_ms FROM agg_trades_meta WHERE symbol = ?",
                        (symbol,)
                    ).fetchone()
                    if not row:
                        continue
                    # Check the most recent rows of this month's partition.
                    now_month = datetime.now(tz=timezone.utc).strftime("%Y%m")
                    tname = f"agg_trades_{now_month}"
                    try:
                        rows = conn.execute(
                            f"SELECT agg_id FROM {tname} WHERE symbol = ? ORDER BY agg_id DESC LIMIT 100",
                            (symbol,)
                        ).fetchall()
                    except sqlite3.OperationalError:
                        # Narrowed from a silent catch-all: only "no such
                        # table" style errors are expected early in a month.
                        continue
                    if len(rows) < 2:
                        continue
                    ids = [r["agg_id"] for r in rows]
                    # ids is descending; any adjacent difference > 1 is a gap.
                    gaps = []
                    for i in range(len(ids) - 1):
                        diff = ids[i] - ids[i + 1]
                        if diff > 1:
                            gaps.append((ids[i + 1], ids[i], diff - 1))
                    if gaps:
                        logger.warning(f"[{symbol}] Found {len(gaps)} gaps in recent data: {gaps[:3]}")
                        # Backfill from the lowest gap boundary.
                        min_gap_id = min(g[0] for g in gaps)
                        task = asyncio.create_task(rest_catchup(symbol, min_gap_id))
                        pending.add(task)
                        task.add_done_callback(pending.discard)
                    else:
                        logger.debug(f"[{symbol}] Continuity OK, last_agg_id={row['last_agg_id']}")
            finally:
                # Bug fix: the connection used to leak when anything above raised.
                conn.close()
        except Exception as e:
            logger.error(f"Continuity check error: {e}")
# ─── Integrity report (hourly) ───────────────────────────────────
async def daily_report():
    """Log an integrity summary once per hour (name notwithstanding).

    For each symbol: watermark agg_id, its UTC timestamp, and the row count
    in the current month's partition table. Errors are logged and the loop
    keeps running.
    """
    while True:
        await asyncio.sleep(3600)
        try:
            conn = get_conn()
            try:
                ensure_meta_table(conn)
                report_lines = ["=== AggTrades Integrity Report ==="]
                for symbol in SYMBOLS:
                    row = conn.execute(
                        "SELECT last_agg_id, last_time_ms FROM agg_trades_meta WHERE symbol = ?",
                        (symbol,)
                    ).fetchone()
                    if not row:
                        report_lines.append(f" {symbol}: No data yet")
                        continue
                    last_dt = datetime.fromtimestamp(row["last_time_ms"] / 1000, tz=timezone.utc).isoformat()
                    # Row count for the current month's partition.
                    now_month = datetime.now(tz=timezone.utc).strftime("%Y%m")
                    tname = f"agg_trades_{now_month}"
                    try:
                        count = conn.execute(
                            f"SELECT COUNT(*) as c FROM {tname} WHERE symbol = ?", (symbol,)
                        ).fetchone()["c"]
                    except sqlite3.OperationalError:
                        # Narrowed from a catch-all: table may not exist yet.
                        count = 0
                    report_lines.append(
                        f" {symbol}: last_agg_id={row['last_agg_id']}, last_time={last_dt}, month_count={count:,}"
                    )
            finally:
                # Bug fix: the connection used to leak when a query raised.
                conn.close()
            logger.info("\n".join(report_lines))
        except Exception as e:
            logger.error(f"Daily report error: {e}")
# ─── Entry point ─────────────────────────────────────────────────
async def main():
    """Bootstrap the database, then run every collector task until cancelled."""
    logger.info("AggTrades Collector starting...")
    # Make sure the watermark table exists before any task touches it.
    conn = get_conn()
    ensure_meta_table(conn)
    conn.commit()
    conn.close()
    # One WS loop per symbol, plus the periodic patrol and report tasks,
    # all running concurrently.
    coros = [ws_collect(sym) for sym in SYMBOLS]
    coros.append(continuity_check())
    coros.append(daily_report())
    await asyncio.gather(*coros)
if __name__ == "__main__":
asyncio.run(main())