feat: V3.0 aggTrades collector - WS+REST补洞+巡检+按月分表+查询API

This commit is contained in:
root 2026-02-27 11:29:16 +00:00
parent 1ab228286c
commit 7e38b24fa8
2 changed files with 499 additions and 0 deletions

View File

@ -0,0 +1,367 @@
"""
agg_trades_collector.py aggTrades全量采集守护进程
架构
- WebSocket主链路实时推送延迟<100ms
- REST补洞断线重连后从last_agg_id追平
- 每分钟巡检校验agg_id连续性发现断档自动补洞
- 批量写入攒200条或1秒flush一次减少WAL压力
- 按月分表agg_trades_YYYYMM单表千万行内查询快
- 健康接口GET /collector/health 可监控
"""
import asyncio
import json
import logging
import os
import sqlite3
import time
from datetime import datetime, timezone
from typing import Optional
import httpx
import websockets
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
handlers=[
logging.StreamHandler(),
logging.FileHandler(os.path.join(os.path.dirname(__file__), "..", "collector.log")),
],
)
logger = logging.getLogger("collector")
DB_PATH = os.path.join(os.path.dirname(__file__), "..", "arb.db")
BINANCE_FAPI = "https://fapi.binance.com/fapi/v1"
SYMBOLS = ["BTCUSDT", "ETHUSDT"]
HEADERS = {"User-Agent": "Mozilla/5.0 ArbitrageEngine/3.0"}
# 批量写入缓冲
_buffer: dict[str, list] = {s: [] for s in SYMBOLS}
BATCH_SIZE = 200
BATCH_TIMEOUT = 1.0 # seconds
# ─── DB helpers ──────────────────────────────────────────────────
def get_conn() -> sqlite3.Connection:
    """Open the shared SQLite DB with Row access, WAL journaling, and relaxed sync."""
    connection = sqlite3.connect(DB_PATH, timeout=30)
    connection.row_factory = sqlite3.Row
    # WAL lets the collector write while readers query; NORMAL sync trades a
    # little durability for write throughput.
    for pragma in ("PRAGMA journal_mode=WAL", "PRAGMA synchronous=NORMAL"):
        connection.execute(pragma)
    return connection
def table_name(ts_ms: int) -> str:
    """Return the monthly shard name for a ms-epoch timestamp, e.g. agg_trades_202602."""
    when = datetime.fromtimestamp(ts_ms / 1000, tz=timezone.utc)
    return "agg_trades_" + when.strftime("%Y%m")
def ensure_table(conn: sqlite3.Connection, tname: str):
    """Create the monthly shard table *tname* and its (symbol, time_ms) index if missing.

    Table/index names cannot be bound as SQL parameters, so *tname* must be
    interpolated into the DDL. To guard against injection or typos, only the
    canonical shape produced by table_name() — "agg_trades_" + 6 digits — is
    accepted.

    Raises:
        ValueError: if tname is not of the form agg_trades_YYYYMM.
    """
    # Identifier whitelist: "agg_trades_" followed by exactly 6 digits.
    suffix = tname.removeprefix("agg_trades_")
    if suffix == tname or len(suffix) != 6 or not suffix.isdigit():
        raise ValueError(f"invalid aggTrades table name: {tname!r}")
    conn.execute(f"""
        CREATE TABLE IF NOT EXISTS {tname} (
            agg_id INTEGER PRIMARY KEY,
            symbol TEXT NOT NULL,
            price REAL NOT NULL,
            qty REAL NOT NULL,
            first_trade_id INTEGER,
            last_trade_id INTEGER,
            time_ms INTEGER NOT NULL,
            is_buyer_maker INTEGER NOT NULL
        )
    """)
    conn.execute(f"""
        CREATE INDEX IF NOT EXISTS idx_{tname}_sym_time
        ON {tname}(symbol, time_ms)
    """)
def ensure_meta_table(conn: sqlite3.Connection):
    """Create the per-symbol watermark table if it does not exist.

    agg_trades_meta stores, per symbol, the newest agg_id and event time
    written so far, letting the collector resume (REST gap-fill) after a
    restart or disconnect.
    """
    conn.execute("""
        CREATE TABLE IF NOT EXISTS agg_trades_meta (
            symbol TEXT PRIMARY KEY,
            last_agg_id INTEGER NOT NULL,
            last_time_ms INTEGER NOT NULL,
            updated_at TEXT DEFAULT (datetime('now'))
        )
    """)
def get_last_agg_id(symbol: str) -> Optional[int]:
    """Return the persisted high-water agg_id for *symbol*, or None.

    None means either no row exists yet (fresh symbol) or the lookup failed;
    failures are logged and swallowed so the caller's reconnect loop keeps
    running.
    """
    try:
        conn = get_conn()
        try:
            ensure_meta_table(conn)
            row = conn.execute(
                "SELECT last_agg_id FROM agg_trades_meta WHERE symbol = ?", (symbol,)
            ).fetchone()
            return row["last_agg_id"] if row else None
        finally:
            # Fix: close on error paths too — the connection was leaked when
            # the query raised before reaching close().
            conn.close()
    except Exception as e:
        logger.error(f"get_last_agg_id error: {e}")
        return None
def update_meta(conn: sqlite3.Connection, symbol: str, last_agg_id: int, last_time_ms: int):
    """Upsert the per-symbol watermark row (newest agg_id / event time)."""
    upsert_sql = """
        INSERT INTO agg_trades_meta (symbol, last_agg_id, last_time_ms, updated_at)
        VALUES (?, ?, ?, datetime('now'))
        ON CONFLICT(symbol) DO UPDATE SET
            last_agg_id = excluded.last_agg_id,
            last_time_ms = excluded.last_time_ms,
            updated_at = excluded.updated_at
    """
    params = (symbol, last_agg_id, last_time_ms)
    conn.execute(upsert_sql, params)
def flush_buffer(symbol: str, trades: list) -> int:
    """Write a batch of raw aggTrade dicts for *symbol*; return rows actually inserted.

    Trades are grouped by their monthly shard, inserted with INSERT OR IGNORE
    (agg_id is the PRIMARY KEY, so WS/REST replays dedupe for free), and the
    per-symbol watermark in agg_trades_meta is advanced to the highest agg_id
    seen in the batch. Any error is logged and 0 is returned — callers run in
    always-on loops and must not crash on a bad batch.
    """
    if not trades:
        return 0
    try:
        conn = get_conn()
        try:
            ensure_meta_table(conn)
            # Group by target month so each trade lands in its own
            # agg_trades_YYYYMM shard.
            by_month: dict[str, list] = {}
            for t in trades:
                by_month.setdefault(table_name(t["T"]), []).append(t)
            inserted = 0
            last_agg_id = 0
            last_time_ms = 0
            for tname, batch in by_month.items():
                ensure_table(conn, tname)
                for t in batch:
                    cur = conn.execute(
                        f"""INSERT OR IGNORE INTO {tname}
                        (agg_id, symbol, price, qty, first_trade_id, last_trade_id, time_ms, is_buyer_maker)
                        VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
                        (
                            t["a"],               # agg_id
                            symbol,
                            float(t["p"]),        # price
                            float(t["q"]),        # quantity
                            t.get("f"),           # first_trade_id (absent in some payloads)
                            t.get("l"),           # last_trade_id
                            t["T"],               # event time, ms epoch
                            1 if t["m"] else 0,   # is_buyer_maker
                        )
                    )
                    if cur.rowcount > 0:
                        inserted += 1
                    # Track the newest trade seen (inserted or duplicate) so the
                    # watermark still advances when a gap-fill replays known rows.
                    if t["a"] > last_agg_id:
                        last_agg_id = t["a"]
                        last_time_ms = t["T"]
            if last_agg_id > 0:
                update_meta(conn, symbol, last_agg_id, last_time_ms)
            conn.commit()
            return inserted
        finally:
            # Fix: close on error paths too — previously the connection leaked
            # whenever anything above raised.
            conn.close()
    except Exception as e:
        logger.error(f"flush_buffer [{symbol}] error: {e}")
        return 0
# ─── REST gap-fill ───────────────────────────────────────────────
async def rest_catchup(symbol: str, from_id: int) -> int:
    """Pull aggTrades via REST starting at *from_id* until caught up; return rows filled.

    Pages through /fapi/v1/aggTrades 1000 rows at a time, writing each page
    with flush_buffer (which dedupes, so the returned count is rows actually
    inserted). Stops on: HTTP error, empty page, no forward progress, or a
    short page (< 1000 rows, meaning we reached the present).
    """
    total = 0
    current_id = from_id
    logger.info(f"[{symbol}] REST catchup from agg_id={from_id}")
    async with httpx.AsyncClient(timeout=15, headers=HEADERS) as client:
        while True:
            try:
                resp = await client.get(
                    f"{BINANCE_FAPI}/aggTrades",
                    params={"symbol": symbol, "fromId": current_id, "limit": 1000},
                )
                if resp.status_code != 200:
                    logger.warning(f"[{symbol}] REST catchup HTTP {resp.status_code}")
                    break
                data = resp.json()
                if not data:
                    break
                count = flush_buffer(symbol, data)
                total += count
                last = data[-1]["a"]  # newest agg_id in this page
                if last <= current_id:
                    break  # no forward progress — bail out rather than loop forever
                current_id = last + 1
                # A short page (< 1000 rows) means we've caught up to the present.
                if len(data) < 1000:
                    break
                await asyncio.sleep(0.1)  # be gentle on the REST rate limit
            except Exception as e:
                logger.error(f"[{symbol}] REST catchup error: {e}")
                break
    logger.info(f"[{symbol}] REST catchup done, filled {total} trades")
    return total
# ─── WebSocket collection ────────────────────────────────────────
async def ws_collect(symbol: str):
    """Per-symbol WS collection loop: auto-reconnect with REST gap-fill on (re)connect.

    Runs forever. Before each connect it REST-fills from the persisted
    last_agg_id, then streams aggTrade events, flushing in batches of
    BATCH_SIZE trades or every BATCH_TIMEOUT seconds. Reconnects use
    exponential backoff (1 s doubling up to 30 s), reset on success.
    """
    stream = symbol.lower() + "@aggTrade"
    url = f"wss://fstream.binance.com/ws/{stream}"
    buffer: list = []
    last_flush = time.time()
    reconnect_delay = 1.0  # seconds; doubled per failure, capped at 30
    while True:
        try:
            # Before (re)connecting, backfill anything missed while offline.
            last_id = get_last_agg_id(symbol)
            if last_id is not None:
                await rest_catchup(symbol, last_id + 1)
            logger.info(f"[{symbol}] Connecting WS: {url}")
            async with websockets.connect(url, ping_interval=20, ping_timeout=10) as ws:
                reconnect_delay = 1.0  # connected — reset the backoff
                logger.info(f"[{symbol}] WS connected")
                async for raw in ws:
                    msg = json.loads(raw)
                    if msg.get("e") != "aggTrade":
                        continue  # ignore non-trade frames
                    buffer.append(msg)
                    # Batched flush: BATCH_SIZE trades or BATCH_TIMEOUT seconds,
                    # whichever comes first.
                    # NOTE(review): the time-based condition is only evaluated when
                    # a message arrives, so on a silent stream the buffered tail
                    # waits for the next message or a disconnect.
                    now = time.time()
                    if len(buffer) >= BATCH_SIZE or (now - last_flush) >= BATCH_TIMEOUT:
                        count = flush_buffer(symbol, buffer)
                        if count > 0:
                            logger.debug(f"[{symbol}] flushed {count}/{len(buffer)} trades")
                        buffer.clear()
                        last_flush = now
        except websockets.exceptions.ConnectionClosed as e:
            logger.warning(f"[{symbol}] WS closed: {e}, reconnecting in {reconnect_delay}s")
        except Exception as e:
            logger.error(f"[{symbol}] WS error: {e}, reconnecting in {reconnect_delay}s")
        finally:
            # Flush whatever is still buffered before backing off.
            # NOTE(review): this finally also runs on task cancellation, and the
            # awaited sleep will re-raise CancelledError inside cleanup — confirm
            # the intended shutdown behavior.
            if buffer:
                flush_buffer(symbol, buffer)
                buffer.clear()
            await asyncio.sleep(reconnect_delay)
            reconnect_delay = min(reconnect_delay * 2, 30)  # exponential backoff, max 30s
# ─── Continuity patrol ───────────────────────────────────────────
async def continuity_check():
    """Every 60 s, verify the most recent agg_ids per symbol are gap-free.

    Inspects the newest 100 rows of the current month's shard; any gap in the
    id sequence triggers a background REST gap-fill starting at the lowest id
    bounding a gap (INSERT OR IGNORE in flush_buffer dedupes the overlap).
    """
    while True:
        await asyncio.sleep(60)
        try:
            conn = get_conn()
            ensure_meta_table(conn)
            for symbol in SYMBOLS:
                row = conn.execute(
                    "SELECT last_agg_id, last_time_ms FROM agg_trades_meta WHERE symbol = ?",
                    (symbol,)
                ).fetchone()
                if not row:
                    continue  # nothing collected for this symbol yet
                # Check the most recent rows of this month's shard for continuity.
                now_month = datetime.now(tz=timezone.utc).strftime("%Y%m")
                tname = f"agg_trades_{now_month}"
                try:
                    rows = conn.execute(
                        f"SELECT agg_id FROM {tname} WHERE symbol = ? ORDER BY agg_id DESC LIMIT 100",
                        (symbol,)
                    ).fetchall()
                    if len(rows) < 2:
                        continue
                    ids = [r["agg_id"] for r in rows]  # descending
                    # Adjacent descending ids should differ by exactly 1.
                    gaps = []
                    for i in range(len(ids) - 1):
                        diff = ids[i] - ids[i + 1]
                        if diff > 1:
                            gaps.append((ids[i + 1], ids[i], diff - 1))  # (lo, hi, missing count)
                    if gaps:
                        logger.warning(f"[{symbol}] Found {len(gaps)} gaps in recent data: {gaps[:3]}")
                        # Backfill from the lowest gap boundary.
                        # NOTE(review): the task handle is not retained — its
                        # exceptions go unobserved and the task may be garbage
                        # collected before completing; consider keeping a reference.
                        min_gap_id = min(g[0] for g in gaps)
                        asyncio.create_task(rest_catchup(symbol, min_gap_id))
                    else:
                        logger.debug(f"[{symbol}] Continuity OK, last_agg_id={row['last_agg_id']}")
                except Exception:
                    pass  # best effort: this month's shard may not exist yet
            conn.close()
        except Exception as e:
            logger.error(f"Continuity check error: {e}")
# ─── Hourly integrity report ─────────────────────────────────────
async def daily_report():
    """Log an integrity summary (watermarks + current-month row counts) once per hour."""
    while True:
        await asyncio.sleep(3600)
        try:
            db = get_conn()
            ensure_meta_table(db)
            lines = ["=== AggTrades Integrity Report ==="]
            for symbol in SYMBOLS:
                row = db.execute(
                    "SELECT last_agg_id, last_time_ms FROM agg_trades_meta WHERE symbol = ?",
                    (symbol,)
                ).fetchone()
                if row is None:
                    lines.append(f" {symbol}: No data yet")
                    continue
                last_dt = datetime.fromtimestamp(row["last_time_ms"] / 1000, tz=timezone.utc).isoformat()
                # Total rows for this symbol in the current month's shard.
                now_month = datetime.now(tz=timezone.utc).strftime("%Y%m")
                tname = f"agg_trades_{now_month}"
                try:
                    count_row = db.execute(
                        f"SELECT COUNT(*) as c FROM {tname} WHERE symbol = ?", (symbol,)
                    ).fetchone()
                    count = count_row["c"]
                except Exception:
                    count = 0  # shard may not exist yet this month
                lines.append(
                    f" {symbol}: last_agg_id={row['last_agg_id']}, last_time={last_dt}, month_count={count:,}"
                )
            db.close()
            logger.info("\n".join(lines))
        except Exception as e:
            logger.error(f"Daily report error: {e}")
# ─── Entry point ─────────────────────────────────────────────────
async def main():
    """Bootstrap the DB, then run every symbol collector plus patrol and report loops."""
    logger.info("AggTrades Collector starting...")
    # Make sure the meta table exists before any task races to use it.
    db = get_conn()
    ensure_meta_table(db)
    db.commit()
    db.close()
    # One WS loop per symbol, plus the continuity patrol and the hourly report,
    # all running concurrently.
    coros = [ws_collect(s) for s in SYMBOLS]
    coros.append(continuity_check())
    coros.append(daily_report())
    await asyncio.gather(*coros)
if __name__ == "__main__":
    asyncio.run(main())

View File

@ -5,6 +5,7 @@ from datetime import datetime, timedelta
import asyncio, time, sqlite3, os import asyncio, time, sqlite3, os
from auth import router as auth_router, get_current_user, ensure_tables as ensure_auth_tables from auth import router as auth_router, get_current_user, ensure_tables as ensure_auth_tables
import datetime as _dt
app = FastAPI(title="Arbitrage Engine API") app = FastAPI(title="Arbitrage Engine API")
@ -331,3 +332,134 @@ async def get_stats(user: dict = Depends(get_current_user)):
} }
set_cache("stats", stats) set_cache("stats", stats)
return stats return stats
# ─── aggTrades query endpoints ───────────────────────────────────
@app.get("/api/trades/meta")
async def get_trades_meta(user: dict = Depends(get_current_user)):
    """Collection status per symbol: newest agg_id, its event time, and update time."""
    db = sqlite3.connect(DB_PATH)
    db.row_factory = sqlite3.Row
    try:
        rows = db.execute(
            "SELECT symbol, last_agg_id, last_time_ms, updated_at FROM agg_trades_meta"
        ).fetchall()
    except Exception:
        rows = []  # meta table not created yet — report nothing
    db.close()
    # Key by base asset ("BTCUSDT" -> "BTC").
    return {
        r["symbol"].replace("USDT", ""): {
            "last_agg_id": r["last_agg_id"],
            "last_time_ms": r["last_time_ms"],
            "updated_at": r["updated_at"],
        }
        for r in rows
    }
@app.get("/api/trades/summary")
async def get_trades_summary(
symbol: str = "BTC",
start_ms: int = 0,
end_ms: int = 0,
interval: str = "1m",
user: dict = Depends(get_current_user),
):
"""分钟级聚合买卖delta/成交速率/vwap"""
if end_ms == 0:
end_ms = int(time.time() * 1000)
if start_ms == 0:
start_ms = end_ms - 3600 * 1000 # 默认1小时
interval_ms = {"1m": 60000, "5m": 300000, "15m": 900000, "1h": 3600000}.get(interval, 60000)
sym_full = symbol.upper() + "USDT"
# 确定需要查哪些月表
start_dt = _dt.datetime.fromtimestamp(start_ms / 1000, tz=_dt.timezone.utc)
end_dt = _dt.datetime.fromtimestamp(end_ms / 1000, tz=_dt.timezone.utc)
months = set()
cur = start_dt.replace(day=1)
while cur <= end_dt:
months.add(cur.strftime("%Y%m"))
if cur.month == 12:
cur = cur.replace(year=cur.year + 1, month=1)
else:
cur = cur.replace(month=cur.month + 1)
conn = sqlite3.connect(DB_PATH)
conn.row_factory = sqlite3.Row
all_rows = []
for month in sorted(months):
tname = f"agg_trades_{month}"
try:
rows = conn.execute(
f"SELECT agg_id, price, qty, time_ms, is_buyer_maker FROM {tname} "
f"WHERE symbol = ? AND time_ms >= ? AND time_ms < ? ORDER BY time_ms ASC",
(sym_full, start_ms, end_ms)
).fetchall()
all_rows.extend(rows)
except Exception:
pass
conn.close()
# 按interval聚合
bars: dict = {}
for row in all_rows:
bar_ms = (row["time_ms"] // interval_ms) * interval_ms
if bar_ms not in bars:
bars[bar_ms] = {"time_ms": bar_ms, "buy_vol": 0.0, "sell_vol": 0.0,
"trade_count": 0, "vwap_num": 0.0, "vwap_den": 0.0, "max_qty": 0.0}
b = bars[bar_ms]
qty = float(row["qty"])
price = float(row["price"])
if row["is_buyer_maker"] == 0: # 主动买
b["buy_vol"] += qty
else:
b["sell_vol"] += qty
b["trade_count"] += 1
b["vwap_num"] += price * qty
b["vwap_den"] += qty
b["max_qty"] = max(b["max_qty"], qty)
result = []
for b in sorted(bars.values(), key=lambda x: x["time_ms"]):
total = b["buy_vol"] + b["sell_vol"]
result.append({
"time_ms": b["time_ms"],
"buy_vol": round(b["buy_vol"], 4),
"sell_vol": round(b["sell_vol"], 4),
"delta": round(b["buy_vol"] - b["sell_vol"], 4),
"total_vol": round(total, 4),
"trade_count": b["trade_count"],
"vwap": round(b["vwap_num"] / b["vwap_den"], 2) if b["vwap_den"] > 0 else 0,
"max_qty": round(b["max_qty"], 4),
})
return {"symbol": symbol, "interval": interval, "count": len(result), "data": result}
@app.get("/api/collector/health")
async def collector_health(user: dict = Depends(get_current_user)):
"""采集器健康状态"""
conn = sqlite3.connect(DB_PATH)
conn.row_factory = sqlite3.Row
now_ms = int(time.time() * 1000)
status = {}
try:
rows = conn.execute(
"SELECT symbol, last_agg_id, last_time_ms FROM agg_trades_meta"
).fetchall()
for r in rows:
sym = r["symbol"].replace("USDT", "")
lag_s = (now_ms - r["last_time_ms"]) / 1000
status[sym] = {
"last_agg_id": r["last_agg_id"],
"lag_seconds": round(lag_s, 1),
"healthy": lag_s < 30, # 30秒内有数据算健康
}
except Exception:
pass
conn.close()
return {"collector": status, "timestamp": now_ms}