"""
agg_trades_collector.py — aggTrades full-coverage collection daemon

Architecture:
- WebSocket main link: real-time push, latency < 100 ms
- REST gap-fill: after a reconnect, catch up from last_agg_id
- Per-minute patrol: verify agg_id continuity, auto-fill detected gaps
- Batched writes: flush every 200 rows or 1 second to reduce WAL pressure
- Monthly partitioned tables: agg_trades_YYYYMM keeps single-table queries fast below tens of millions of rows
- Health endpoint: GET /collector/health for monitoring
"""
|
||
|
||
import asyncio
|
||
import json
|
||
import logging
|
||
import os
|
||
import sqlite3
|
||
import time
|
||
from datetime import datetime, timezone
|
||
from typing import Optional
|
||
|
||
import httpx
|
||
import websockets
|
||
|
||
# Log to both the console and a collector.log file one directory above this file.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler(os.path.join(os.path.dirname(__file__), "..", "collector.log")),
    ],
)
logger = logging.getLogger("collector")

# SQLite database lives next to collector.log, one directory up.
DB_PATH = os.path.join(os.path.dirname(__file__), "..", "arb.db")
# Binance USD-M futures REST base URL.
BINANCE_FAPI = "https://fapi.binance.com/fapi/v1"
# Symbols collected by this daemon.
SYMBOLS = ["BTCUSDT", "ETHUSDT"]
HEADERS = {"User-Agent": "Mozilla/5.0 ArbitrageEngine/3.0"}

# Batched-write buffer settings: flush at BATCH_SIZE rows or after BATCH_TIMEOUT seconds.
# NOTE(review): _buffer appears unused — ws_collect keeps its own local buffer; confirm before removing.
_buffer: dict[str, list] = {s: [] for s in SYMBOLS}
BATCH_SIZE = 200
BATCH_TIMEOUT = 1.0  # seconds
|
||
|
||
|
||
# ─── DB helpers ──────────────────────────────────────────────────
|
||
|
||
def get_conn() -> sqlite3.Connection:
    """Open a connection to the collector database, configured for concurrent use."""
    connection = sqlite3.connect(DB_PATH, timeout=30)
    connection.row_factory = sqlite3.Row
    # WAL + NORMAL sync lets readers run while the collector writes.
    for pragma in ("PRAGMA journal_mode=WAL", "PRAGMA synchronous=NORMAL"):
        connection.execute(pragma)
    return connection
|
||
|
||
|
||
def table_name(ts_ms: int) -> str:
    """Return the monthly partition table name (e.g. ``agg_trades_202602``) for a UTC ms timestamp."""
    moment = datetime.fromtimestamp(ts_ms / 1000, tz=timezone.utc)
    return "agg_trades_" + moment.strftime("%Y%m")
|
||
|
||
|
||
def ensure_table(conn: sqlite3.Connection, tname: str):
    """Create the monthly aggTrades partition *tname* and its (symbol, time_ms) index if absent.

    Idempotent: both statements use IF NOT EXISTS.
    """
    ddl = f"""
        CREATE TABLE IF NOT EXISTS {tname} (
            agg_id INTEGER PRIMARY KEY,
            symbol TEXT NOT NULL,
            price REAL NOT NULL,
            qty REAL NOT NULL,
            first_trade_id INTEGER,
            last_trade_id INTEGER,
            time_ms INTEGER NOT NULL,
            is_buyer_maker INTEGER NOT NULL
        )
    """
    index_ddl = f"""
        CREATE INDEX IF NOT EXISTS idx_{tname}_sym_time
        ON {tname}(symbol, time_ms)
    """
    for statement in (ddl, index_ddl):
        conn.execute(statement)
|
||
|
||
|
||
def ensure_meta_table(conn: sqlite3.Connection):
    """Create the per-symbol checkpoint table agg_trades_meta if it does not exist."""
    conn.execute(
        """
        CREATE TABLE IF NOT EXISTS agg_trades_meta (
            symbol TEXT PRIMARY KEY,
            last_agg_id INTEGER NOT NULL,
            last_time_ms INTEGER NOT NULL,
            updated_at TEXT DEFAULT (datetime('now'))
        )
        """
    )
|
||
|
||
|
||
def get_last_agg_id(symbol: str) -> Optional[int]:
    """Return the highest recorded aggTrade id for *symbol*, or None if unknown.

    Best-effort: any DB error is logged and treated as "no checkpoint",
    which makes the caller skip the REST catch-up step.
    """
    conn = None
    try:
        conn = get_conn()
        ensure_meta_table(conn)
        row = conn.execute(
            "SELECT last_agg_id FROM agg_trades_meta WHERE symbol = ?", (symbol,)
        ).fetchone()
        return row["last_agg_id"] if row else None
    except Exception as e:
        logger.error(f"get_last_agg_id error: {e}")
        return None
    finally:
        # Close even when the query raises; the original leaked the handle
        # because close() was only reached on the success path.
        if conn is not None:
            conn.close()
|
||
|
||
|
||
def update_meta(conn: sqlite3.Connection, symbol: str, last_agg_id: int, last_time_ms: int):
    """Upsert the checkpoint row for *symbol* in agg_trades_meta."""
    upsert_sql = (
        "INSERT INTO agg_trades_meta (symbol, last_agg_id, last_time_ms, updated_at) "
        "VALUES (?, ?, ?, datetime('now')) "
        "ON CONFLICT(symbol) DO UPDATE SET "
        "last_agg_id = excluded.last_agg_id, "
        "last_time_ms = excluded.last_time_ms, "
        "updated_at = excluded.updated_at"
    )
    conn.execute(upsert_sql, (symbol, last_agg_id, last_time_ms))
|
||
|
||
|
||
def flush_buffer(symbol: str, trades: list) -> int:
    """Write one batch of aggTrade dicts for *symbol*; return rows actually inserted.

    Trades are routed to their monthly partition table; duplicates are
    dropped via INSERT OR IGNORE, and the agg_trades_meta checkpoint is
    advanced to the highest agg_id that was newly inserted.

    Best-effort: on any DB error the batch is logged and 0 is returned
    (callers treat missing data as recoverable via REST catch-up).
    """
    if not trades:
        return 0
    conn = None
    try:
        conn = get_conn()
        ensure_meta_table(conn)
        # Group by monthly partition (a batch can straddle a month boundary).
        by_month: dict[str, list] = {}
        for t in trades:
            by_month.setdefault(table_name(t["T"]), []).append(t)

        inserted = 0
        last_agg_id = 0
        last_time_ms = 0

        for tname, batch in by_month.items():
            ensure_table(conn, tname)
            for t in batch:
                # Row-by-row (not executemany) so rowcount tells us whether
                # this particular trade was new or an already-stored duplicate.
                cur = conn.execute(
                    f"""INSERT OR IGNORE INTO {tname}
                        (agg_id, symbol, price, qty, first_trade_id, last_trade_id, time_ms, is_buyer_maker)
                        VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
                    (
                        t["a"],              # agg_id
                        symbol,
                        float(t["p"]),       # price
                        float(t["q"]),       # quantity
                        t.get("f"),          # first_trade_id
                        t.get("l"),          # last_trade_id
                        t["T"],              # event time (ms)
                        1 if t["m"] else 0,  # is_buyer_maker
                    ),
                )
                if cur.rowcount > 0:
                    inserted += 1
                    # Only advance the checkpoint past rows we really stored.
                    if t["a"] > last_agg_id:
                        last_agg_id = t["a"]
                        last_time_ms = t["T"]

        if last_agg_id > 0:
            update_meta(conn, symbol, last_agg_id, last_time_ms)

        conn.commit()
        return inserted
    except Exception as e:
        logger.error(f"flush_buffer [{symbol}] error: {e}")
        return 0
    finally:
        # The original leaked the connection when an exception hit mid-batch;
        # closing in finally covers both success and error paths.
        if conn is not None:
            conn.close()
|
||
|
||
|
||
# ─── REST补洞 ────────────────────────────────────────────────────
|
||
|
||
async def rest_catchup(symbol: str, from_id: int) -> int:
    """Pull aggTrades over REST starting at *from_id* until caught up to live.

    Returns the number of trades back-filled into the database.
    """
    total = 0
    current_id = from_id
    logger.info(f"[{symbol}] REST catchup from agg_id={from_id}")

    async with httpx.AsyncClient(timeout=15, headers=HEADERS) as client:
        caught_up = False
        while not caught_up:
            try:
                resp = await client.get(
                    f"{BINANCE_FAPI}/aggTrades",
                    params={"symbol": symbol, "fromId": current_id, "limit": 1000},
                )
                if resp.status_code != 200:
                    logger.warning(f"[{symbol}] REST catchup HTTP {resp.status_code}")
                    break
                page = resp.json()
                if not page:
                    break
                total += flush_buffer(symbol, page)
                newest = page[-1]["a"]
                if newest <= current_id:
                    # No forward progress — stop rather than loop forever.
                    break
                current_id = newest + 1
                # A partial page means we have reached the live edge.
                caught_up = len(page) < 1000
                if not caught_up:
                    await asyncio.sleep(0.1)  # stay friendly to rate limits
            except Exception as e:
                logger.error(f"[{symbol}] REST catchup error: {e}")
                break

    logger.info(f"[{symbol}] REST catchup done, filled {total} trades")
    return total
|
||
|
||
|
||
# ─── WebSocket采集 ───────────────────────────────────────────────
|
||
|
||
async def ws_collect(symbol: str):
    """Per-symbol WebSocket collection loop with auto-reconnect and REST gap-fill.

    Before each (re)connect, catches up from the stored checkpoint via REST,
    then streams aggTrade events, buffering and batch-flushing them to SQLite.
    Reconnects forever with exponential backoff (1s doubling, capped at 30s).
    """
    stream = symbol.lower() + "@aggTrade"
    url = f"wss://fstream.binance.com/ws/{stream}"
    buffer: list = []
    last_flush = time.time()
    reconnect_delay = 1.0

    while True:
        try:
            # Fill any gap via REST before (re)connecting the stream.
            last_id = get_last_agg_id(symbol)
            if last_id is not None:
                await rest_catchup(symbol, last_id + 1)

            logger.info(f"[{symbol}] Connecting WS: {url}")
            async with websockets.connect(url, ping_interval=20, ping_timeout=10) as ws:
                reconnect_delay = 1.0  # reset backoff once connected
                logger.info(f"[{symbol}] WS connected")

                async for raw in ws:
                    msg = json.loads(raw)
                    if msg.get("e") != "aggTrade":
                        continue

                    buffer.append(msg)

                    # Batch flush: when the buffer is full or the timeout elapsed.
                    # Note the timeout is only checked on message arrival.
                    now = time.time()
                    if len(buffer) >= BATCH_SIZE or (now - last_flush) >= BATCH_TIMEOUT:
                        count = flush_buffer(symbol, buffer)
                        if count > 0:
                            logger.debug(f"[{symbol}] flushed {count}/{len(buffer)} trades")
                        # NOTE(review): buffer is cleared even if flush_buffer
                        # failed (it returns 0 on error) — those trades are only
                        # recoverable via REST catch-up; confirm this is intended.
                        buffer.clear()
                        last_flush = now

        except websockets.exceptions.ConnectionClosed as e:
            logger.warning(f"[{symbol}] WS closed: {e}, reconnecting in {reconnect_delay}s")
        except Exception as e:
            logger.error(f"[{symbol}] WS error: {e}, reconnecting in {reconnect_delay}s")
        finally:
            # Flush whatever is left before sleeping out the backoff.
            if buffer:
                flush_buffer(symbol, buffer)
                buffer.clear()

            await asyncio.sleep(reconnect_delay)
            reconnect_delay = min(reconnect_delay * 2, 30)  # exponential backoff, max 30s
|
||
|
||
|
||
# ─── 连续性巡检 ──────────────────────────────────────────────────
|
||
|
||
async def continuity_check():
    """Every 60s, verify recent agg_id continuity per symbol and back-fill gaps.

    Inspects the newest 100 rows of the current month's partition; any gap in
    the (descending) agg_id sequence schedules a background REST catch-up
    starting at the lowest missing id.
    """
    # Hold references so scheduled catch-up tasks aren't garbage-collected
    # before completion (asyncio keeps only weak refs to tasks).
    pending: set = set()
    while True:
        await asyncio.sleep(60)
        conn = None
        try:
            conn = get_conn()
            ensure_meta_table(conn)
            for symbol in SYMBOLS:
                row = conn.execute(
                    "SELECT last_agg_id, last_time_ms FROM agg_trades_meta WHERE symbol = ?",
                    (symbol,)
                ).fetchone()
                if not row:
                    continue
                # Check the newest 100 rows of this month's partition.
                now_month = datetime.now(tz=timezone.utc).strftime("%Y%m")
                tname = f"agg_trades_{now_month}"
                try:
                    rows = conn.execute(
                        f"SELECT agg_id FROM {tname} WHERE symbol = ? ORDER BY agg_id DESC LIMIT 100",
                        (symbol,)
                    ).fetchall()
                except sqlite3.OperationalError:
                    # Monthly table may not exist yet (month just rolled over).
                    continue
                if len(rows) < 2:
                    continue
                ids = [r["agg_id"] for r in rows]
                # Scan the descending sequence for missing ids.
                gaps = []
                for i in range(len(ids) - 1):
                    diff = ids[i] - ids[i + 1]
                    if diff > 1:
                        gaps.append((ids[i + 1], ids[i], diff - 1))
                if gaps:
                    logger.warning(f"[{symbol}] Found {len(gaps)} gaps in recent data: {gaps[:3]}")
                    # Back-fill from the lowest missing id.
                    min_gap_id = min(g[0] for g in gaps)
                    task = asyncio.create_task(rest_catchup(symbol, min_gap_id))
                    pending.add(task)
                    task.add_done_callback(pending.discard)
                else:
                    logger.debug(f"[{symbol}] Continuity OK, last_agg_id={row['last_agg_id']}")
        except Exception as e:
            logger.error(f"Continuity check error: {e}")
        finally:
            # The original leaked the connection on any exception before close().
            if conn is not None:
                conn.close()
|
||
|
||
|
||
# ─── 每日完整性报告 ──────────────────────────────────────────────
|
||
|
||
async def daily_report():
    """Log an integrity summary once per hour.

    (Named "daily" historically; the sleep interval is 3600s.)
    """
    while True:
        await asyncio.sleep(3600)
        conn = None
        try:
            conn = get_conn()
            ensure_meta_table(conn)
            report_lines = ["=== AggTrades Integrity Report ==="]
            for symbol in SYMBOLS:
                row = conn.execute(
                    "SELECT last_agg_id, last_time_ms FROM agg_trades_meta WHERE symbol = ?",
                    (symbol,)
                ).fetchone()
                if not row:
                    report_lines.append(f" {symbol}: No data yet")
                    continue
                last_dt = datetime.fromtimestamp(row["last_time_ms"] / 1000, tz=timezone.utc).isoformat()
                # Row count for the current month's partition.
                now_month = datetime.now(tz=timezone.utc).strftime("%Y%m")
                tname = f"agg_trades_{now_month}"
                try:
                    count = conn.execute(
                        f"SELECT COUNT(*) as c FROM {tname} WHERE symbol = ?", (symbol,)
                    ).fetchone()["c"]
                except sqlite3.OperationalError:
                    # Monthly table may not exist yet.
                    count = 0
                report_lines.append(
                    f" {symbol}: last_agg_id={row['last_agg_id']}, last_time={last_dt}, month_count={count:,}"
                )
            logger.info("\n".join(report_lines))
        except Exception as e:
            logger.error(f"Daily report error: {e}")
        finally:
            # The original leaked the connection on any exception before close().
            if conn is not None:
                conn.close()
|
||
|
||
|
||
# ─── 入口 ────────────────────────────────────────────────────────
|
||
|
||
async def main():
    """Bootstrap the database, then run all collectors and watchdogs forever."""
    logger.info("AggTrades Collector starting...")
    # Make sure the checkpoint table exists before any task touches it.
    conn = get_conn()
    ensure_meta_table(conn)
    conn.commit()
    conn.close()

    # One WS loop per symbol, plus the continuity patrol and hourly report.
    workers = [ws_collect(sym) for sym in SYMBOLS]
    workers.append(continuity_check())
    workers.append(daily_report())
    await asyncio.gather(*workers)
|
||
|
||
|
||
# Run as a standalone daemon process.
if __name__ == "__main__":
    asyncio.run(main())
|