refactor: SQLite→PostgreSQL migration - db.py连接层 + main/collector/signal-engine/backfill全部改PG
Phase 1: 核心数据表(agg_trades/rate_snapshots/signal*)迁PG auth.py暂保留SQLite(低频,不影响性能) - db.py: psycopg2同步池 + asyncpg异步池 + PG schema + 分区管理 - main.py: 全部改asyncpg查询 - collector: psycopg2 + execute_values批量写入 - signal-engine: psycopg2同步读写 - backfill: psycopg2 + ON CONFLICT DO NOTHING
This commit is contained in:
parent
23c7597a40
commit
4168c1dd88
@ -1,27 +1,29 @@
|
|||||||
"""
|
"""
|
||||||
agg_trades_collector.py — aggTrades全量采集守护进程
|
agg_trades_collector.py — aggTrades全量采集守护进程(PostgreSQL版)
|
||||||
|
|
||||||
架构:
|
架构:
|
||||||
- WebSocket主链路:实时推送,延迟<100ms
|
- WebSocket主链路:实时推送,延迟<100ms
|
||||||
- REST补洞:断线重连后从last_agg_id追平
|
- REST补洞:断线重连后从last_agg_id追平
|
||||||
- 每分钟巡检:校验agg_id连续性,发现断档自动补洞
|
- 每分钟巡检:校验agg_id连续性,发现断档自动补洞
|
||||||
- 批量写入:攒200条或1秒flush一次,减少WAL压力
|
- 批量写入:攒200条或1秒flush一次
|
||||||
- 按月分表:agg_trades_YYYYMM,单表千万行内查询快
|
- PG分区表:按月自动分区,MVCC并发无锁冲突
|
||||||
- 健康接口:GET /collector/health 可监控
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sqlite3
|
|
||||||
import time
|
import time
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
import psycopg2
|
||||||
|
import psycopg2.extras
|
||||||
import websockets
|
import websockets
|
||||||
|
|
||||||
|
from db import get_sync_conn, get_sync_pool, ensure_partitions, PG_HOST, PG_PORT, PG_DB, PG_USER, PG_PASS
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO,
|
level=logging.INFO,
|
||||||
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
||||||
@ -32,137 +34,85 @@ logging.basicConfig(
|
|||||||
)
|
)
|
||||||
logger = logging.getLogger("collector")
|
logger = logging.getLogger("collector")
|
||||||
|
|
||||||
DB_PATH = os.path.join(os.path.dirname(__file__), "..", "arb.db")
|
|
||||||
BINANCE_FAPI = "https://fapi.binance.com/fapi/v1"
|
BINANCE_FAPI = "https://fapi.binance.com/fapi/v1"
|
||||||
SYMBOLS = ["BTCUSDT", "ETHUSDT"]
|
SYMBOLS = ["BTCUSDT", "ETHUSDT"]
|
||||||
HEADERS = {"User-Agent": "Mozilla/5.0 ArbitrageEngine/3.0"}
|
HEADERS = {"User-Agent": "Mozilla/5.0 ArbitrageEngine/3.0"}
|
||||||
|
|
||||||
# 批量写入缓冲
|
|
||||||
_buffer: dict[str, list] = {s: [] for s in SYMBOLS}
|
|
||||||
BATCH_SIZE = 200
|
BATCH_SIZE = 200
|
||||||
BATCH_TIMEOUT = 1.0 # seconds
|
BATCH_TIMEOUT = 1.0
|
||||||
|
|
||||||
|
|
||||||
# ─── DB helpers ──────────────────────────────────────────────────
|
# ─── DB helpers ──────────────────────────────────────────────────
|
||||||
|
|
||||||
def get_conn() -> sqlite3.Connection:
|
|
||||||
conn = sqlite3.connect(DB_PATH, timeout=30)
|
|
||||||
conn.row_factory = sqlite3.Row
|
|
||||||
conn.execute("PRAGMA journal_mode=WAL")
|
|
||||||
conn.execute("PRAGMA synchronous=NORMAL")
|
|
||||||
return conn
|
|
||||||
|
|
||||||
|
|
||||||
def table_name(ts_ms: int) -> str:
|
|
||||||
"""按月分表:agg_trades_202602"""
|
|
||||||
dt = datetime.fromtimestamp(ts_ms / 1000, tz=timezone.utc)
|
|
||||||
return f"agg_trades_{dt.strftime('%Y%m')}"
|
|
||||||
|
|
||||||
|
|
||||||
def ensure_table(conn: sqlite3.Connection, tname: str):
|
|
||||||
conn.execute(f"""
|
|
||||||
CREATE TABLE IF NOT EXISTS {tname} (
|
|
||||||
agg_id INTEGER PRIMARY KEY,
|
|
||||||
symbol TEXT NOT NULL,
|
|
||||||
price REAL NOT NULL,
|
|
||||||
qty REAL NOT NULL,
|
|
||||||
first_trade_id INTEGER,
|
|
||||||
last_trade_id INTEGER,
|
|
||||||
time_ms INTEGER NOT NULL,
|
|
||||||
is_buyer_maker INTEGER NOT NULL
|
|
||||||
)
|
|
||||||
""")
|
|
||||||
conn.execute(f"""
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_{tname}_sym_time
|
|
||||||
ON {tname}(symbol, time_ms)
|
|
||||||
""")
|
|
||||||
|
|
||||||
|
|
||||||
def ensure_meta_table(conn: sqlite3.Connection):
|
|
||||||
conn.execute("""
|
|
||||||
CREATE TABLE IF NOT EXISTS agg_trades_meta (
|
|
||||||
symbol TEXT PRIMARY KEY,
|
|
||||||
last_agg_id INTEGER NOT NULL,
|
|
||||||
last_time_ms INTEGER NOT NULL,
|
|
||||||
updated_at TEXT DEFAULT (datetime('now'))
|
|
||||||
)
|
|
||||||
""")
|
|
||||||
|
|
||||||
|
|
||||||
def get_last_agg_id(symbol: str) -> Optional[int]:
|
def get_last_agg_id(symbol: str) -> Optional[int]:
|
||||||
try:
|
try:
|
||||||
conn = get_conn()
|
with get_sync_conn() as conn:
|
||||||
ensure_meta_table(conn)
|
with conn.cursor() as cur:
|
||||||
row = conn.execute(
|
cur.execute("SELECT last_agg_id FROM agg_trades_meta WHERE symbol = %s", (symbol,))
|
||||||
"SELECT last_agg_id FROM agg_trades_meta WHERE symbol = ?", (symbol,)
|
row = cur.fetchone()
|
||||||
).fetchone()
|
return row[0] if row else None
|
||||||
conn.close()
|
|
||||||
return row["last_agg_id"] if row else None
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"get_last_agg_id error: {e}")
|
logger.error(f"get_last_agg_id error: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def update_meta(conn: sqlite3.Connection, symbol: str, last_agg_id: int, last_time_ms: int):
|
def update_meta(conn, symbol: str, last_agg_id: int, last_time_ms: int):
|
||||||
conn.execute("""
|
with conn.cursor() as cur:
|
||||||
INSERT INTO agg_trades_meta (symbol, last_agg_id, last_time_ms, updated_at)
|
cur.execute("""
|
||||||
VALUES (?, ?, ?, datetime('now'))
|
INSERT INTO agg_trades_meta (symbol, last_agg_id, last_time_ms, updated_at)
|
||||||
ON CONFLICT(symbol) DO UPDATE SET
|
VALUES (%s, %s, %s, NOW())
|
||||||
last_agg_id = excluded.last_agg_id,
|
ON CONFLICT(symbol) DO UPDATE SET
|
||||||
last_time_ms = excluded.last_time_ms,
|
last_agg_id = EXCLUDED.last_agg_id,
|
||||||
updated_at = excluded.updated_at
|
last_time_ms = EXCLUDED.last_time_ms,
|
||||||
""", (symbol, last_agg_id, last_time_ms))
|
updated_at = NOW()
|
||||||
|
""", (symbol, last_agg_id, last_time_ms))
|
||||||
|
|
||||||
|
|
||||||
def flush_buffer(symbol: str, trades: list) -> int:
|
def flush_buffer(symbol: str, trades: list) -> int:
|
||||||
"""写入一批trades,返回实际写入条数(去重后)"""
|
"""写入一批trades到PG,返回实际写入条数"""
|
||||||
if not trades:
|
if not trades:
|
||||||
return 0
|
return 0
|
||||||
try:
|
try:
|
||||||
conn = get_conn()
|
# 确保分区存在
|
||||||
ensure_meta_table(conn)
|
ensure_partitions()
|
||||||
# 按月分组
|
|
||||||
by_month: dict[str, list] = {}
|
|
||||||
for t in trades:
|
|
||||||
tname = table_name(t["T"])
|
|
||||||
if tname not in by_month:
|
|
||||||
by_month[tname] = []
|
|
||||||
by_month[tname].append(t)
|
|
||||||
|
|
||||||
inserted = 0
|
with get_sync_conn() as conn:
|
||||||
last_agg_id = 0
|
with conn.cursor() as cur:
|
||||||
last_time_ms = 0
|
# 批量插入(ON CONFLICT忽略重复)
|
||||||
|
values = []
|
||||||
|
last_agg_id = 0
|
||||||
|
last_time_ms = 0
|
||||||
|
|
||||||
for tname, batch in by_month.items():
|
for t in trades:
|
||||||
ensure_table(conn, tname)
|
agg_id = t["a"]
|
||||||
for t in batch:
|
time_ms = t["T"]
|
||||||
cur = conn.execute(
|
values.append((
|
||||||
f"""INSERT OR IGNORE INTO {tname}
|
agg_id, symbol,
|
||||||
(agg_id, symbol, price, qty, first_trade_id, last_trade_id, time_ms, is_buyer_maker)
|
float(t["p"]), float(t["q"]),
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
|
time_ms,
|
||||||
(
|
1 if t["m"] else 0,
|
||||||
t["a"], # agg_id
|
))
|
||||||
symbol,
|
if agg_id > last_agg_id:
|
||||||
float(t["p"]),
|
last_agg_id = agg_id
|
||||||
float(t["q"]),
|
last_time_ms = time_ms
|
||||||
t.get("f"), # first_trade_id
|
|
||||||
t.get("l"), # last_trade_id
|
# 批量INSERT
|
||||||
t["T"], # time_ms
|
psycopg2.extras.execute_values(
|
||||||
1 if t["m"] else 0, # is_buyer_maker
|
cur,
|
||||||
)
|
"""INSERT INTO agg_trades (agg_id, symbol, price, qty, time_ms, is_buyer_maker)
|
||||||
|
VALUES %s
|
||||||
|
ON CONFLICT (time_ms, symbol, agg_id) DO NOTHING""",
|
||||||
|
values,
|
||||||
|
template="(%s, %s, %s, %s, %s, %s)",
|
||||||
|
page_size=1000,
|
||||||
)
|
)
|
||||||
if cur.rowcount > 0:
|
inserted = cur.rowcount
|
||||||
inserted += 1
|
|
||||||
if t["a"] > last_agg_id:
|
|
||||||
last_agg_id = t["a"]
|
|
||||||
last_time_ms = t["T"]
|
|
||||||
|
|
||||||
if last_agg_id > 0:
|
if last_agg_id > 0:
|
||||||
update_meta(conn, symbol, last_agg_id, last_time_ms)
|
update_meta(conn, symbol, last_agg_id, last_time_ms)
|
||||||
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
return inserted
|
||||||
return inserted
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"flush_buffer [{symbol}] error: {e}")
|
logger.error(f"flush_buffer [{symbol}] error: {e}")
|
||||||
return 0
|
return 0
|
||||||
@ -171,7 +121,6 @@ def flush_buffer(symbol: str, trades: list) -> int:
|
|||||||
# ─── REST补洞 ────────────────────────────────────────────────────
|
# ─── REST补洞 ────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def rest_catchup(symbol: str, from_id: int) -> int:
|
async def rest_catchup(symbol: str, from_id: int) -> int:
|
||||||
"""从from_id开始REST拉取,追平到最新,返回补洞条数"""
|
|
||||||
total = 0
|
total = 0
|
||||||
current_id = from_id
|
current_id = from_id
|
||||||
logger.info(f"[{symbol}] REST catchup from agg_id={from_id}")
|
logger.info(f"[{symbol}] REST catchup from agg_id={from_id}")
|
||||||
@ -195,10 +144,9 @@ async def rest_catchup(symbol: str, from_id: int) -> int:
|
|||||||
if last <= current_id:
|
if last <= current_id:
|
||||||
break
|
break
|
||||||
current_id = last + 1
|
current_id = last + 1
|
||||||
# 如果拉到的比最新少1000条,说明追平了
|
|
||||||
if len(data) < 1000:
|
if len(data) < 1000:
|
||||||
break
|
break
|
||||||
await asyncio.sleep(0.1) # rate limit友好
|
await asyncio.sleep(0.1)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[{symbol}] REST catchup error: {e}")
|
logger.error(f"[{symbol}] REST catchup error: {e}")
|
||||||
break
|
break
|
||||||
@ -210,7 +158,6 @@ async def rest_catchup(symbol: str, from_id: int) -> int:
|
|||||||
# ─── WebSocket采集 ───────────────────────────────────────────────
|
# ─── WebSocket采集 ───────────────────────────────────────────────
|
||||||
|
|
||||||
async def ws_collect(symbol: str):
|
async def ws_collect(symbol: str):
|
||||||
"""单Symbol的WS采集循环,自动断线重连+REST补洞"""
|
|
||||||
stream = symbol.lower() + "@aggTrade"
|
stream = symbol.lower() + "@aggTrade"
|
||||||
url = f"wss://fstream.binance.com/ws/{stream}"
|
url = f"wss://fstream.binance.com/ws/{stream}"
|
||||||
buffer: list = []
|
buffer: list = []
|
||||||
@ -219,14 +166,13 @@ async def ws_collect(symbol: str):
|
|||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
# 断线重连前先REST补洞
|
|
||||||
last_id = get_last_agg_id(symbol)
|
last_id = get_last_agg_id(symbol)
|
||||||
if last_id is not None:
|
if last_id is not None:
|
||||||
await rest_catchup(symbol, last_id + 1)
|
await rest_catchup(symbol, last_id + 1)
|
||||||
|
|
||||||
logger.info(f"[{symbol}] Connecting WS: {url}")
|
logger.info(f"[{symbol}] Connecting WS: {url}")
|
||||||
async with websockets.connect(url, ping_interval=20, ping_timeout=10) as ws:
|
async with websockets.connect(url, ping_interval=20, ping_timeout=10) as ws:
|
||||||
reconnect_delay = 1.0 # 连上了就重置
|
reconnect_delay = 1.0
|
||||||
logger.info(f"[{symbol}] WS connected")
|
logger.info(f"[{symbol}] WS connected")
|
||||||
|
|
||||||
async for raw in ws:
|
async for raw in ws:
|
||||||
@ -236,7 +182,6 @@ async def ws_collect(symbol: str):
|
|||||||
|
|
||||||
buffer.append(msg)
|
buffer.append(msg)
|
||||||
|
|
||||||
# 批量flush:满200条或超1秒
|
|
||||||
now = time.time()
|
now = time.time()
|
||||||
if len(buffer) >= BATCH_SIZE or (now - last_flush) >= BATCH_TIMEOUT:
|
if len(buffer) >= BATCH_SIZE or (now - last_flush) >= BATCH_TIMEOUT:
|
||||||
count = flush_buffer(symbol, buffer)
|
count = flush_buffer(symbol, buffer)
|
||||||
@ -250,110 +195,90 @@ async def ws_collect(symbol: str):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[{symbol}] WS error: {e}, reconnecting in {reconnect_delay}s")
|
logger.error(f"[{symbol}] WS error: {e}, reconnecting in {reconnect_delay}s")
|
||||||
finally:
|
finally:
|
||||||
# flush剩余buffer
|
|
||||||
if buffer:
|
if buffer:
|
||||||
flush_buffer(symbol, buffer)
|
flush_buffer(symbol, buffer)
|
||||||
buffer.clear()
|
buffer.clear()
|
||||||
|
|
||||||
await asyncio.sleep(reconnect_delay)
|
await asyncio.sleep(reconnect_delay)
|
||||||
reconnect_delay = min(reconnect_delay * 2, 30) # exponential backoff, max 30s
|
reconnect_delay = min(reconnect_delay * 2, 30)
|
||||||
|
|
||||||
|
|
||||||
# ─── 连续性巡检 ──────────────────────────────────────────────────
|
# ─── 连续性巡检 ──────────────────────────────────────────────────
|
||||||
|
|
||||||
async def continuity_check():
|
async def continuity_check():
|
||||||
"""每60秒巡检一次:检查各symbol最近的agg_id是否有断档"""
|
|
||||||
while True:
|
while True:
|
||||||
await asyncio.sleep(60)
|
await asyncio.sleep(60)
|
||||||
try:
|
try:
|
||||||
conn = get_conn()
|
with get_sync_conn() as conn:
|
||||||
ensure_meta_table(conn)
|
with conn.cursor() as cur:
|
||||||
for symbol in SYMBOLS:
|
for symbol in SYMBOLS:
|
||||||
row = conn.execute(
|
cur.execute(
|
||||||
"SELECT last_agg_id, last_time_ms FROM agg_trades_meta WHERE symbol = ?",
|
"SELECT last_agg_id, last_time_ms FROM agg_trades_meta WHERE symbol = %s",
|
||||||
(symbol,)
|
(symbol,)
|
||||||
).fetchone()
|
)
|
||||||
if not row:
|
row = cur.fetchone()
|
||||||
continue
|
if not row:
|
||||||
# 检查最近1000条是否连续
|
continue
|
||||||
now_month = datetime.now(tz=timezone.utc).strftime("%Y%m")
|
# 检查最近100条是否连续
|
||||||
tname = f"agg_trades_{now_month}"
|
cur.execute(
|
||||||
try:
|
"SELECT agg_id FROM agg_trades WHERE symbol = %s ORDER BY agg_id DESC LIMIT 100",
|
||||||
rows = conn.execute(
|
(symbol,)
|
||||||
f"SELECT agg_id FROM {tname} WHERE symbol = ? ORDER BY agg_id DESC LIMIT 100",
|
)
|
||||||
(symbol,)
|
rows = cur.fetchall()
|
||||||
).fetchall()
|
if len(rows) < 2:
|
||||||
if len(rows) < 2:
|
continue
|
||||||
continue
|
ids = [r[0] for r in rows]
|
||||||
ids = [r["agg_id"] for r in rows]
|
gaps = []
|
||||||
# 检查是否连续(降序)
|
for i in range(len(ids) - 1):
|
||||||
gaps = []
|
diff = ids[i] - ids[i + 1]
|
||||||
for i in range(len(ids) - 1):
|
if diff > 1:
|
||||||
diff = ids[i] - ids[i + 1]
|
gaps.append((ids[i + 1], ids[i], diff - 1))
|
||||||
if diff > 1:
|
if gaps:
|
||||||
gaps.append((ids[i + 1], ids[i], diff - 1))
|
logger.warning(f"[{symbol}] Found {len(gaps)} gaps: {gaps[:3]}")
|
||||||
if gaps:
|
min_gap_id = min(g[0] for g in gaps)
|
||||||
logger.warning(f"[{symbol}] Found {len(gaps)} gaps in recent data: {gaps[:3]}")
|
asyncio.create_task(rest_catchup(symbol, min_gap_id))
|
||||||
# 触发补洞
|
else:
|
||||||
min_gap_id = min(g[0] for g in gaps)
|
logger.debug(f"[{symbol}] Continuity OK, last_agg_id={row[0]}")
|
||||||
asyncio.create_task(rest_catchup(symbol, min_gap_id))
|
|
||||||
else:
|
|
||||||
logger.debug(f"[{symbol}] Continuity OK, last_agg_id={row['last_agg_id']}")
|
|
||||||
except Exception:
|
|
||||||
pass # 表可能还不存在
|
|
||||||
conn.close()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Continuity check error: {e}")
|
logger.error(f"Continuity check error: {e}")
|
||||||
|
|
||||||
|
|
||||||
# ─── 每日完整性报告 ──────────────────────────────────────────────
|
# ─── 每小时报告 ──────────────────────────────────────────────────
|
||||||
|
|
||||||
async def daily_report():
|
async def daily_report():
|
||||||
"""每小时生成一次完整性摘要日志"""
|
|
||||||
while True:
|
while True:
|
||||||
await asyncio.sleep(3600)
|
await asyncio.sleep(3600)
|
||||||
try:
|
try:
|
||||||
conn = get_conn()
|
with get_sync_conn() as conn:
|
||||||
ensure_meta_table(conn)
|
with conn.cursor() as cur:
|
||||||
report_lines = ["=== AggTrades Integrity Report ==="]
|
report = ["=== AggTrades Integrity Report ==="]
|
||||||
for symbol in SYMBOLS:
|
for symbol in SYMBOLS:
|
||||||
row = conn.execute(
|
cur.execute(
|
||||||
"SELECT last_agg_id, last_time_ms FROM agg_trades_meta WHERE symbol = ?",
|
"SELECT last_agg_id, last_time_ms FROM agg_trades_meta WHERE symbol = %s",
|
||||||
(symbol,)
|
(symbol,)
|
||||||
).fetchone()
|
)
|
||||||
if not row:
|
row = cur.fetchone()
|
||||||
report_lines.append(f" {symbol}: No data yet")
|
if not row:
|
||||||
continue
|
report.append(f" {symbol}: No data yet")
|
||||||
last_dt = datetime.fromtimestamp(row["last_time_ms"] / 1000, tz=timezone.utc).isoformat()
|
continue
|
||||||
# 统计本月总量
|
last_dt = datetime.fromtimestamp(row[1] / 1000, tz=timezone.utc).isoformat()
|
||||||
now_month = datetime.now(tz=timezone.utc).strftime("%Y%m")
|
cur.execute(
|
||||||
tname = f"agg_trades_{now_month}"
|
"SELECT COUNT(*) FROM agg_trades WHERE symbol = %s", (symbol,)
|
||||||
try:
|
)
|
||||||
count = conn.execute(
|
count = cur.fetchone()[0]
|
||||||
f"SELECT COUNT(*) as c FROM {tname} WHERE symbol = ?", (symbol,)
|
report.append(f" {symbol}: last_agg_id={row[0]}, last_time={last_dt}, total={count:,}")
|
||||||
).fetchone()["c"]
|
logger.info("\n".join(report))
|
||||||
except Exception:
|
|
||||||
count = 0
|
|
||||||
report_lines.append(
|
|
||||||
f" {symbol}: last_agg_id={row['last_agg_id']}, last_time={last_dt}, month_count={count:,}"
|
|
||||||
)
|
|
||||||
conn.close()
|
|
||||||
logger.info("\n".join(report_lines))
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Daily report error: {e}")
|
logger.error(f"Report error: {e}")
|
||||||
|
|
||||||
|
|
||||||
# ─── 入口 ────────────────────────────────────────────────────────
|
# ─── 入口 ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
logger.info("AggTrades Collector starting...")
|
logger.info("AggTrades Collector (PG) starting...")
|
||||||
# 确保基础表存在
|
# 确保分区存在
|
||||||
conn = get_conn()
|
ensure_partitions()
|
||||||
ensure_meta_table(conn)
|
|
||||||
conn.commit()
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
# 并行启动所有symbol的WS + 巡检 + 报告
|
|
||||||
tasks = [
|
tasks = [
|
||||||
ws_collect(sym) for sym in SYMBOLS
|
ws_collect(sym) for sym in SYMBOLS
|
||||||
] + [
|
] + [
|
||||||
|
|||||||
@ -1,27 +1,18 @@
|
|||||||
"""
|
"""
|
||||||
backfill_agg_trades.py — 历史aggTrades回补脚本
|
backfill_agg_trades.py — 历史aggTrades回补脚本(PostgreSQL版)
|
||||||
|
|
||||||
功能:
|
|
||||||
- 从当前DB最早agg_id向历史方向回补
|
|
||||||
- Binance REST API分页拉取,每次1000条
|
|
||||||
- INSERT OR IGNORE写入按月分表
|
|
||||||
- 断点续传:记录进度到agg_trades_meta
|
|
||||||
- 速率控制:sleep 200ms/请求,429自动退避
|
|
||||||
- BTC+ETH并行回补
|
|
||||||
|
|
||||||
用法:
|
|
||||||
python3 backfill_agg_trades.py [--days 30] [--symbol BTCUSDT]
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sqlite3
|
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
import psycopg2.extras
|
||||||
|
|
||||||
|
from db import get_sync_conn, init_schema, ensure_partitions
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO,
|
level=logging.INFO,
|
||||||
@ -33,119 +24,56 @@ logging.basicConfig(
|
|||||||
)
|
)
|
||||||
logger = logging.getLogger("backfill")
|
logger = logging.getLogger("backfill")
|
||||||
|
|
||||||
DB_PATH = os.path.join(os.path.dirname(__file__), "..", "arb.db")
|
|
||||||
BINANCE_FAPI = "https://fapi.binance.com/fapi/v1"
|
BINANCE_FAPI = "https://fapi.binance.com/fapi/v1"
|
||||||
HEADERS = {"User-Agent": "Mozilla/5.0 ArbitrageEngine/backfill"}
|
HEADERS = {"User-Agent": "Mozilla/5.0 ArbitrageEngine/backfill"}
|
||||||
BATCH_SIZE = 1000
|
BATCH_SIZE = 1000
|
||||||
SLEEP_MS = 2000 # 低优先级:2秒/请求,让出锁给实时服务
|
SLEEP_MS = 2000
|
||||||
|
|
||||||
|
|
||||||
# ─── DB helpers ──────────────────────────────────────────────────
|
# ─── DB helpers ──────────────────────────────────────────────────
|
||||||
|
|
||||||
def get_conn() -> sqlite3.Connection:
|
def get_earliest_agg_id(symbol: str) -> int | None:
|
||||||
conn = sqlite3.connect(DB_PATH, timeout=30)
|
with get_sync_conn() as conn:
|
||||||
conn.row_factory = sqlite3.Row
|
with conn.cursor() as cur:
|
||||||
conn.execute("PRAGMA journal_mode=WAL")
|
cur.execute("SELECT earliest_agg_id FROM agg_trades_meta WHERE symbol = %s", (symbol,))
|
||||||
conn.execute("PRAGMA synchronous=NORMAL")
|
row = cur.fetchone()
|
||||||
return conn
|
if row and row[0]:
|
||||||
|
return row[0]
|
||||||
|
# fallback: scan agg_trades
|
||||||
|
cur.execute("SELECT MIN(agg_id) FROM agg_trades WHERE symbol = %s", (symbol,))
|
||||||
|
row = cur.fetchone()
|
||||||
|
return row[0] if row else None
|
||||||
|
|
||||||
|
|
||||||
def table_name(ts_ms: int) -> str:
|
def update_earliest_meta(symbol: str, agg_id: int, time_ms: int):
|
||||||
dt = datetime.fromtimestamp(ts_ms / 1000, tz=timezone.utc)
|
with get_sync_conn() as conn:
|
||||||
return f"agg_trades_{dt.strftime('%Y%m')}"
|
with conn.cursor() as cur:
|
||||||
|
cur.execute("SELECT symbol FROM agg_trades_meta WHERE symbol = %s", (symbol,))
|
||||||
|
if cur.fetchone():
|
||||||
def ensure_table(conn: sqlite3.Connection, tname: str):
|
cur.execute("""
|
||||||
conn.execute(f"""
|
UPDATE agg_trades_meta SET
|
||||||
CREATE TABLE IF NOT EXISTS {tname} (
|
earliest_agg_id = LEAST(%s, COALESCE(earliest_agg_id, %s)),
|
||||||
agg_id INTEGER PRIMARY KEY,
|
earliest_time_ms = LEAST(%s, COALESCE(earliest_time_ms, %s))
|
||||||
symbol TEXT NOT NULL,
|
WHERE symbol = %s
|
||||||
price REAL NOT NULL,
|
""", (agg_id, agg_id, time_ms, time_ms, symbol))
|
||||||
qty REAL NOT NULL,
|
else:
|
||||||
time_ms INTEGER NOT NULL,
|
cur.execute("""
|
||||||
is_buyer_maker INTEGER NOT NULL
|
INSERT INTO agg_trades_meta (symbol, last_agg_id, last_time_ms, earliest_agg_id, earliest_time_ms)
|
||||||
)
|
VALUES (%s, %s, %s, %s, %s)
|
||||||
""")
|
""", (symbol, agg_id, time_ms, agg_id, time_ms))
|
||||||
conn.execute(f"CREATE INDEX IF NOT EXISTS idx_{tname}_symbol_time ON {tname}(symbol, time_ms)")
|
conn.commit()
|
||||||
|
|
||||||
|
|
||||||
def ensure_meta(conn: sqlite3.Connection):
|
|
||||||
conn.execute("""
|
|
||||||
CREATE TABLE IF NOT EXISTS agg_trades_meta (
|
|
||||||
symbol TEXT PRIMARY KEY,
|
|
||||||
last_agg_id INTEGER,
|
|
||||||
last_time_ms INTEGER,
|
|
||||||
earliest_agg_id INTEGER,
|
|
||||||
earliest_time_ms INTEGER
|
|
||||||
)
|
|
||||||
""")
|
|
||||||
# 添加earliest字段(如果旧表没有)
|
|
||||||
try:
|
|
||||||
conn.execute("ALTER TABLE agg_trades_meta ADD COLUMN earliest_agg_id INTEGER")
|
|
||||||
conn.execute("ALTER TABLE agg_trades_meta ADD COLUMN earliest_time_ms INTEGER")
|
|
||||||
except Exception:
|
|
||||||
pass # 已存在
|
|
||||||
|
|
||||||
|
|
||||||
def get_earliest_agg_id(conn: sqlite3.Connection, symbol: str) -> int | None:
|
|
||||||
"""查找DB中该symbol的最小agg_id"""
|
|
||||||
# 先查meta表
|
|
||||||
row = conn.execute(
|
|
||||||
"SELECT earliest_agg_id FROM agg_trades_meta WHERE symbol = ?", (symbol,)
|
|
||||||
).fetchone()
|
|
||||||
if row and row["earliest_agg_id"]:
|
|
||||||
return row["earliest_agg_id"]
|
|
||||||
|
|
||||||
# meta表没有,扫所有月表
|
|
||||||
tables = conn.execute(
|
|
||||||
"SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'agg_trades_2%'"
|
|
||||||
).fetchall()
|
|
||||||
min_id = None
|
|
||||||
for t in tables:
|
|
||||||
r = conn.execute(
|
|
||||||
f"SELECT MIN(agg_id) as mid FROM {t['name']} WHERE symbol = ?", (symbol,)
|
|
||||||
).fetchone()
|
|
||||||
if r and r["mid"] is not None:
|
|
||||||
if min_id is None or r["mid"] < min_id:
|
|
||||||
min_id = r["mid"]
|
|
||||||
return min_id
|
|
||||||
|
|
||||||
|
|
||||||
def update_earliest_meta(conn: sqlite3.Connection, symbol: str, agg_id: int, time_ms: int):
|
|
||||||
# 先检查是否已有该symbol的记录
|
|
||||||
row = conn.execute("SELECT symbol FROM agg_trades_meta WHERE symbol = ?", (symbol,)).fetchone()
|
|
||||||
if row:
|
|
||||||
conn.execute("""
|
|
||||||
UPDATE agg_trades_meta SET
|
|
||||||
earliest_agg_id = MIN(?, COALESCE(earliest_agg_id, ?)),
|
|
||||||
earliest_time_ms = MIN(?, COALESCE(earliest_time_ms, ?))
|
|
||||||
WHERE symbol = ?
|
|
||||||
""", (agg_id, agg_id, time_ms, time_ms, symbol))
|
|
||||||
else:
|
|
||||||
conn.execute("""
|
|
||||||
INSERT INTO agg_trades_meta (symbol, last_agg_id, last_time_ms, earliest_agg_id, earliest_time_ms)
|
|
||||||
VALUES (?, ?, ?, ?, ?)
|
|
||||||
""", (symbol, agg_id, time_ms, agg_id, time_ms))
|
|
||||||
conn.commit()
|
|
||||||
|
|
||||||
|
|
||||||
# ─── REST API ────────────────────────────────────────────────────
|
# ─── REST API ────────────────────────────────────────────────────
|
||||||
|
|
||||||
def fetch_agg_trades(symbol: str, from_id: int | None = None,
|
def fetch_agg_trades(symbol: str, from_id: int | None = None, limit: int = 1000) -> list:
|
||||||
start_time: int | None = None, limit: int = 1000) -> list[dict]:
|
|
||||||
"""从Binance拉取aggTrades,支持fromId和startTime"""
|
|
||||||
params = {"symbol": symbol, "limit": limit}
|
params = {"symbol": symbol, "limit": limit}
|
||||||
if from_id is not None:
|
if from_id is not None:
|
||||||
params["fromId"] = from_id
|
params["fromId"] = from_id
|
||||||
elif start_time is not None:
|
|
||||||
params["startTime"] = start_time
|
|
||||||
|
|
||||||
for attempt in range(5):
|
for attempt in range(5):
|
||||||
try:
|
try:
|
||||||
r = requests.get(
|
r = requests.get(f"{BINANCE_FAPI}/aggTrades", params=params, headers=HEADERS, timeout=15)
|
||||||
f"{BINANCE_FAPI}/aggTrades",
|
|
||||||
params=params, headers=HEADERS, timeout=15
|
|
||||||
)
|
|
||||||
if r.status_code == 429:
|
if r.status_code == 429:
|
||||||
wait = min(60 * (2 ** attempt), 300)
|
wait = min(60 * (2 ** attempt), 300)
|
||||||
logger.warning(f"Rate limited (429), waiting {wait}s (attempt {attempt+1})")
|
logger.warning(f"Rate limited (429), waiting {wait}s (attempt {attempt+1})")
|
||||||
@ -162,67 +90,47 @@ def fetch_agg_trades(symbol: str, from_id: int | None = None,
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
def write_batch(conn: sqlite3.Connection, symbol: str, trades: list[dict]) -> int:
|
def write_batch(symbol: str, trades: list) -> int:
|
||||||
"""批量写入,返回实际插入条数"""
|
if not trades:
|
||||||
inserted = 0
|
return 0
|
||||||
tables_seen = set()
|
ensure_partitions()
|
||||||
for t in trades:
|
values = [(t["a"], symbol, float(t["p"]), float(t["q"]), t["T"], 1 if t["m"] else 0) for t in trades]
|
||||||
tname = table_name(t["T"])
|
with get_sync_conn() as conn:
|
||||||
if tname not in tables_seen:
|
with conn.cursor() as cur:
|
||||||
ensure_table(conn, tname)
|
psycopg2.extras.execute_values(
|
||||||
tables_seen.add(tname)
|
cur,
|
||||||
try:
|
"INSERT INTO agg_trades (agg_id, symbol, price, qty, time_ms, is_buyer_maker) "
|
||||||
conn.execute(
|
"VALUES %s ON CONFLICT (time_ms, symbol, agg_id) DO NOTHING",
|
||||||
f"INSERT OR IGNORE INTO {tname} (agg_id, symbol, price, qty, time_ms, is_buyer_maker) "
|
values,
|
||||||
f"VALUES (?, ?, ?, ?, ?, ?)",
|
template="(%s, %s, %s, %s, %s, %s)",
|
||||||
(t["a"], symbol, float(t["p"]), float(t["q"]), t["T"], 1 if t["m"] else 0)
|
page_size=1000,
|
||||||
)
|
)
|
||||||
inserted += 1
|
inserted = cur.rowcount
|
||||||
except sqlite3.IntegrityError:
|
conn.commit()
|
||||||
pass
|
|
||||||
conn.commit()
|
|
||||||
return inserted
|
return inserted
|
||||||
|
|
||||||
|
|
||||||
# ─── 主逻辑 ──────────────────────────────────────────────────────
|
# ─── 主逻辑 ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
def backfill_symbol(symbol: str, target_days: int | None = None):
|
def backfill_symbol(symbol: str, target_days: int | None = None):
|
||||||
"""回补单个symbol的历史数据"""
|
earliest = get_earliest_agg_id(symbol)
|
||||||
conn = get_conn()
|
|
||||||
ensure_meta(conn)
|
|
||||||
|
|
||||||
earliest = get_earliest_agg_id(conn, symbol)
|
|
||||||
if earliest is None:
|
if earliest is None:
|
||||||
logger.info(f"[{symbol}] DB无数据,从最新开始拉最近的数据确定起点")
|
logger.info(f"[{symbol}] DB无数据,拉最新确定起点")
|
||||||
trades = fetch_agg_trades(symbol, limit=1)
|
trades = fetch_agg_trades(symbol, limit=1)
|
||||||
if not trades:
|
if not trades:
|
||||||
logger.error(f"[{symbol}] 无法获取起始数据")
|
logger.error(f"[{symbol}] 无法获取起始数据")
|
||||||
return
|
return
|
||||||
earliest = trades[0]["a"]
|
earliest = trades[0]["a"]
|
||||||
logger.info(f"[{symbol}] 当前最新agg_id: {earliest}")
|
|
||||||
|
|
||||||
# 计算目标:往前补到多少天前
|
target_ts = int((time.time() - target_days * 86400) * 1000) if target_days else 0
|
||||||
if target_days:
|
|
||||||
target_ts = int((time.time() - target_days * 86400) * 1000)
|
|
||||||
else:
|
|
||||||
target_ts = 0 # 拉到最早
|
|
||||||
|
|
||||||
logger.info(f"[{symbol}] 开始回补,当前最早agg_id={earliest}")
|
logger.info(f"[{symbol}] 开始回补,当前最早agg_id={earliest}")
|
||||||
if target_days:
|
|
||||||
target_dt = datetime.fromtimestamp(target_ts / 1000, tz=timezone.utc)
|
|
||||||
logger.info(f"[{symbol}] 目标:回补到 {target_dt.strftime('%Y-%m-%d')} ({target_days}天前)")
|
|
||||||
|
|
||||||
total_inserted = 0
|
total_inserted = 0
|
||||||
total_requests = 0
|
total_requests = 0
|
||||||
rate_limit_hits = 0
|
|
||||||
current_from_id = earliest
|
current_from_id = earliest
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
# 从current_from_id向前拉:先拉current_from_id - BATCH_SIZE*2的位置
|
|
||||||
# Binance fromId是从该id开始正向拉,所以我们用endTime方式
|
|
||||||
# 更可靠:用fromId = current_from_id - BATCH_SIZE,然后正向拉
|
|
||||||
fetch_from = max(0, current_from_id - BATCH_SIZE)
|
fetch_from = max(0, current_from_id - BATCH_SIZE)
|
||||||
|
|
||||||
trades = fetch_agg_trades(symbol, from_id=fetch_from, limit=BATCH_SIZE)
|
trades = fetch_agg_trades(symbol, from_id=fetch_from, limit=BATCH_SIZE)
|
||||||
total_requests += 1
|
total_requests += 1
|
||||||
|
|
||||||
@ -230,104 +138,56 @@ def backfill_symbol(symbol: str, target_days: int | None = None):
|
|||||||
logger.info(f"[{symbol}] 无更多数据,回补结束")
|
logger.info(f"[{symbol}] 无更多数据,回补结束")
|
||||||
break
|
break
|
||||||
|
|
||||||
# 过滤掉已有的(>= earliest的数据实时采集器已覆盖)
|
|
||||||
new_trades = [t for t in trades if t["a"] < current_from_id]
|
new_trades = [t for t in trades if t["a"] < current_from_id]
|
||||||
|
|
||||||
if not new_trades:
|
if not new_trades:
|
||||||
# 没有更早的数据了
|
|
||||||
logger.info(f"[{symbol}] 已到达最早数据,回补结束")
|
logger.info(f"[{symbol}] 已到达最早数据,回补结束")
|
||||||
break
|
break
|
||||||
|
|
||||||
inserted = write_batch(conn, symbol, new_trades)
|
inserted = write_batch(symbol, new_trades)
|
||||||
total_inserted += inserted
|
total_inserted += inserted
|
||||||
|
|
||||||
oldest = min(new_trades, key=lambda x: x["a"])
|
oldest = min(new_trades, key=lambda x: x["a"])
|
||||||
oldest_time = datetime.fromtimestamp(oldest["T"] / 1000, tz=timezone.utc)
|
oldest_time = datetime.fromtimestamp(oldest["T"] / 1000, tz=timezone.utc)
|
||||||
current_from_id = oldest["a"]
|
current_from_id = oldest["a"]
|
||||||
|
update_earliest_meta(symbol, oldest["a"], oldest["T"])
|
||||||
|
|
||||||
# 更新meta
|
|
||||||
update_earliest_meta(conn, symbol, oldest["a"], oldest["T"])
|
|
||||||
|
|
||||||
# 进度日志(每50批打一次)
|
|
||||||
if total_requests % 50 == 0:
|
if total_requests % 50 == 0:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[{symbol}] 进度: {total_inserted:,} 条已插入, "
|
f"[{symbol}] 进度: {total_inserted:,} 条已插入, "
|
||||||
f"当前位置: {oldest_time.strftime('%Y-%m-%d %H:%M')}, "
|
f"当前位置: {oldest_time.strftime('%Y-%m-%d %H:%M')}, agg_id={current_from_id:,}, 请求数={total_requests}"
|
||||||
f"agg_id={current_from_id:,}, "
|
|
||||||
f"请求数={total_requests}, 429次数={rate_limit_hits}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 检查是否到达目标时间
|
|
||||||
if target_ts and oldest["T"] <= target_ts:
|
if target_ts and oldest["T"] <= target_ts:
|
||||||
logger.info(f"[{symbol}] 已达到目标时间,回补结束")
|
logger.info(f"[{symbol}] 已达到目标时间,回补结束")
|
||||||
break
|
break
|
||||||
|
|
||||||
time.sleep(SLEEP_MS / 1000)
|
time.sleep(SLEEP_MS / 1000)
|
||||||
|
|
||||||
conn.close()
|
logger.info(f"[{symbol}] 回补完成: 总插入={total_inserted:,}, 总请求={total_requests}")
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"[{symbol}] 回补完成: "
|
|
||||||
f"总插入={total_inserted:,}, 总请求={total_requests}, "
|
|
||||||
f"429次数={rate_limit_hits}"
|
|
||||||
)
|
|
||||||
return total_inserted, total_requests, rate_limit_hits
|
|
||||||
|
|
||||||
|
|
||||||
def check_continuity(symbol: str):
|
def check_continuity(symbol: str):
|
||||||
"""检查agg_id连续性"""
|
with get_sync_conn() as conn:
|
||||||
conn = get_conn()
|
with conn.cursor() as cur:
|
||||||
tables = conn.execute(
|
cur.execute("SELECT COUNT(*), MIN(agg_id), MAX(agg_id) FROM agg_trades WHERE symbol = %s", (symbol,))
|
||||||
"SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'agg_trades_2%' ORDER BY name"
|
row = cur.fetchone()
|
||||||
).fetchall()
|
total, min_id, max_id = row
|
||||||
|
if not total:
|
||||||
|
logger.info(f"[{symbol}] 无数据")
|
||||||
|
return
|
||||||
|
span = max_id - min_id + 1
|
||||||
|
coverage = total / span * 100 if span > 0 else 0
|
||||||
|
logger.info(f"[{symbol}] 总条数={total:,}, ID范围={min_id:,}~{max_id:,}, 覆盖率={coverage:.2f}%")
|
||||||
|
|
||||||
all_ids = []
|
|
||||||
for t in tables:
|
|
||||||
rows = conn.execute(
|
|
||||||
f"SELECT agg_id FROM {t['name']} WHERE symbol = ? ORDER BY agg_id",
|
|
||||||
(symbol,)
|
|
||||||
).fetchall()
|
|
||||||
all_ids.extend(r["agg_id"] for r in rows)
|
|
||||||
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
if not all_ids:
|
|
||||||
logger.info(f"[{symbol}] 无数据")
|
|
||||||
return
|
|
||||||
|
|
||||||
all_ids.sort()
|
|
||||||
gaps = 0
|
|
||||||
gap_ranges = []
|
|
||||||
for i in range(1, len(all_ids)):
|
|
||||||
diff = all_ids[i] - all_ids[i-1]
|
|
||||||
if diff > 1:
|
|
||||||
gaps += 1
|
|
||||||
if len(gap_ranges) < 10:
|
|
||||||
gap_ranges.append((all_ids[i-1], all_ids[i], diff - 1))
|
|
||||||
|
|
||||||
total = len(all_ids)
|
|
||||||
span = all_ids[-1] - all_ids[0] + 1
|
|
||||||
coverage = total / span * 100 if span > 0 else 0
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"[{symbol}] 连续性检查: "
|
|
||||||
f"总条数={total:,}, ID范围={all_ids[0]:,}~{all_ids[-1]:,}, "
|
|
||||||
f"理论条数={span:,}, 覆盖率={coverage:.2f}%, 缺口数={gaps}"
|
|
||||||
)
|
|
||||||
if gap_ranges:
|
|
||||||
for start, end, missing in gap_ranges[:5]:
|
|
||||||
logger.info(f" 缺口: {start} → {end} (缺{missing}条)")
|
|
||||||
|
|
||||||
|
|
||||||
# ─── CLI ─────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description="Backfill aggTrades from Binance")
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--days", type=int, default=None, help="回补天数(默认全量)")
|
parser.add_argument("--days", type=int, default=None)
|
||||||
parser.add_argument("--symbol", type=str, default=None, help="指定symbol(默认BTC+ETH)")
|
parser.add_argument("--symbol", type=str, default=None)
|
||||||
parser.add_argument("--check", action="store_true", help="仅检查连续性")
|
parser.add_argument("--check", action="store_true")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
init_schema()
|
||||||
symbols = [args.symbol] if args.symbol else ["BTCUSDT", "ETHUSDT"]
|
symbols = [args.symbol] if args.symbol else ["BTCUSDT", "ETHUSDT"]
|
||||||
|
|
||||||
if args.check:
|
if args.check:
|
||||||
@ -336,18 +196,12 @@ def main():
|
|||||||
return
|
return
|
||||||
|
|
||||||
logger.info(f"=== 开始回补 === symbols={symbols}, days={args.days or '全量'}")
|
logger.info(f"=== 开始回补 === symbols={symbols}, days={args.days or '全量'}")
|
||||||
|
|
||||||
for sym in symbols:
|
for sym in symbols:
|
||||||
logger.info(f"--- 回补 {sym} ---")
|
backfill_symbol(sym, target_days=args.days)
|
||||||
result = backfill_symbol(sym, target_days=args.days)
|
|
||||||
if result:
|
|
||||||
inserted, reqs, limits = result
|
|
||||||
logger.info(f"[{sym}] 结果: 插入{inserted:,}条, {reqs}次请求, {limits}次429")
|
|
||||||
|
|
||||||
logger.info("=== 回补完成,开始连续性检查 ===")
|
logger.info("=== 连续性检查 ===")
|
||||||
for sym in symbols:
|
for sym in symbols:
|
||||||
check_continuity(sym)
|
check_continuity(sym)
|
||||||
|
|
||||||
logger.info("=== 全部完成 ===")
|
logger.info("=== 全部完成 ===")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
282
backend/db.py
Normal file
282
backend/db.py
Normal file
@ -0,0 +1,282 @@
|
|||||||
|
"""
|
||||||
|
db.py — PostgreSQL 数据库连接层
|
||||||
|
同步连接池(psycopg2)供脚本类使用
|
||||||
|
异步连接池(asyncpg)供FastAPI使用
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
import asyncpg
|
||||||
|
import psycopg2
|
||||||
|
import psycopg2.pool
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
# PG连接参数
|
||||||
|
PG_HOST = os.getenv("PG_HOST", "127.0.0.1")
|
||||||
|
PG_PORT = int(os.getenv("PG_PORT", 5432))
|
||||||
|
PG_DB = os.getenv("PG_DB", "arb_engine")
|
||||||
|
PG_USER = os.getenv("PG_USER", "arb")
|
||||||
|
PG_PASS = os.getenv("PG_PASS", "arb_engine_2026")
|
||||||
|
|
||||||
|
PG_DSN = f"postgresql://{PG_USER}:{PG_PASS}@{PG_HOST}:{PG_PORT}/{PG_DB}"
|
||||||
|
|
||||||
|
# ─── 同步连接池(psycopg2)─────────────────────────────────────
|
||||||
|
|
||||||
|
_sync_pool = None
|
||||||
|
|
||||||
|
def get_sync_pool() -> psycopg2.pool.ThreadedConnectionPool:
|
||||||
|
global _sync_pool
|
||||||
|
if _sync_pool is None:
|
||||||
|
_sync_pool = psycopg2.pool.ThreadedConnectionPool(
|
||||||
|
minconn=1, maxconn=5,
|
||||||
|
host=PG_HOST, port=PG_PORT,
|
||||||
|
dbname=PG_DB, user=PG_USER, password=PG_PASS,
|
||||||
|
)
|
||||||
|
return _sync_pool
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def get_sync_conn():
|
||||||
|
"""获取同步PG连接(自动归还到池)"""
|
||||||
|
pool = get_sync_pool()
|
||||||
|
conn = pool.getconn()
|
||||||
|
try:
|
||||||
|
yield conn
|
||||||
|
finally:
|
||||||
|
pool.putconn(conn)
|
||||||
|
|
||||||
|
|
||||||
|
def sync_execute(sql: str, params=None, fetch=False):
|
||||||
|
"""简便执行SQL"""
|
||||||
|
with get_sync_conn() as conn:
|
||||||
|
with conn.cursor() as cur:
|
||||||
|
cur.execute(sql, params)
|
||||||
|
if fetch:
|
||||||
|
cols = [desc[0] for desc in cur.description]
|
||||||
|
return [dict(zip(cols, row)) for row in cur.fetchall()]
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def sync_executemany(sql: str, params_list: list):
|
||||||
|
"""批量执行"""
|
||||||
|
with get_sync_conn() as conn:
|
||||||
|
with conn.cursor() as cur:
|
||||||
|
cur.executemany(sql, params_list)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
|
||||||
|
# ─── 异步连接池(asyncpg)─────────────────────────────────────
|
||||||
|
|
||||||
|
_async_pool: asyncpg.Pool | None = None
|
||||||
|
|
||||||
|
async def get_async_pool() -> asyncpg.Pool:
|
||||||
|
global _async_pool
|
||||||
|
if _async_pool is None:
|
||||||
|
_async_pool = await asyncpg.create_pool(
|
||||||
|
dsn=PG_DSN, min_size=2, max_size=10,
|
||||||
|
)
|
||||||
|
return _async_pool
|
||||||
|
|
||||||
|
|
||||||
|
async def async_fetch(sql: str, *args) -> list[dict]:
|
||||||
|
"""异步查询,返回dict列表"""
|
||||||
|
pool = await get_async_pool()
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
rows = await conn.fetch(sql, *args)
|
||||||
|
return [dict(r) for r in rows]
|
||||||
|
|
||||||
|
|
||||||
|
async def async_fetchrow(sql: str, *args) -> dict | None:
|
||||||
|
"""异步查询单行"""
|
||||||
|
pool = await get_async_pool()
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
row = await conn.fetchrow(sql, *args)
|
||||||
|
return dict(row) if row else None
|
||||||
|
|
||||||
|
|
||||||
|
async def async_execute(sql: str, *args):
|
||||||
|
"""异步执行"""
|
||||||
|
pool = await get_async_pool()
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
await conn.execute(sql, *args)
|
||||||
|
|
||||||
|
|
||||||
|
async def close_async_pool():
|
||||||
|
global _async_pool
|
||||||
|
if _async_pool:
|
||||||
|
await _async_pool.close()
|
||||||
|
_async_pool = None
|
||||||
|
|
||||||
|
|
||||||
|
# ─── 建表(PG Schema)──────────────────────────────────────────
|
||||||
|
|
||||||
|
SCHEMA_SQL = """
|
||||||
|
-- 费率快照
|
||||||
|
CREATE TABLE IF NOT EXISTS rate_snapshots (
|
||||||
|
id BIGSERIAL PRIMARY KEY,
|
||||||
|
ts BIGINT NOT NULL,
|
||||||
|
btc_rate DOUBLE PRECISION NOT NULL,
|
||||||
|
eth_rate DOUBLE PRECISION NOT NULL,
|
||||||
|
btc_price DOUBLE PRECISION NOT NULL,
|
||||||
|
eth_price DOUBLE PRECISION NOT NULL,
|
||||||
|
btc_index_price DOUBLE PRECISION,
|
||||||
|
eth_index_price DOUBLE PRECISION
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_rate_snapshots_ts ON rate_snapshots(ts);
|
||||||
|
|
||||||
|
-- aggTrades元数据
|
||||||
|
CREATE TABLE IF NOT EXISTS agg_trades_meta (
|
||||||
|
symbol TEXT PRIMARY KEY,
|
||||||
|
last_agg_id BIGINT NOT NULL,
|
||||||
|
last_time_ms BIGINT,
|
||||||
|
earliest_agg_id BIGINT,
|
||||||
|
earliest_time_ms BIGINT,
|
||||||
|
updated_at TEXT
|
||||||
|
);
|
||||||
|
|
||||||
|
-- aggTrades主表(分区父表)
|
||||||
|
CREATE TABLE IF NOT EXISTS agg_trades (
|
||||||
|
agg_id BIGINT NOT NULL,
|
||||||
|
symbol TEXT NOT NULL,
|
||||||
|
price DOUBLE PRECISION NOT NULL,
|
||||||
|
qty DOUBLE PRECISION NOT NULL,
|
||||||
|
time_ms BIGINT NOT NULL,
|
||||||
|
is_buyer_maker SMALLINT NOT NULL,
|
||||||
|
PRIMARY KEY (time_ms, symbol, agg_id)
|
||||||
|
) PARTITION BY RANGE (time_ms);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_agg_trades_sym_time ON agg_trades(symbol, time_ms DESC);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_agg_trades_sym_agg ON agg_trades(symbol, agg_id);
|
||||||
|
|
||||||
|
-- Signal Engine表
|
||||||
|
CREATE TABLE IF NOT EXISTS signal_indicators (
|
||||||
|
id BIGSERIAL PRIMARY KEY,
|
||||||
|
ts BIGINT NOT NULL,
|
||||||
|
symbol TEXT NOT NULL,
|
||||||
|
cvd_fast DOUBLE PRECISION,
|
||||||
|
cvd_mid DOUBLE PRECISION,
|
||||||
|
cvd_day DOUBLE PRECISION,
|
||||||
|
cvd_fast_slope DOUBLE PRECISION,
|
||||||
|
atr_5m DOUBLE PRECISION,
|
||||||
|
atr_percentile DOUBLE PRECISION,
|
||||||
|
vwap_30m DOUBLE PRECISION,
|
||||||
|
price DOUBLE PRECISION,
|
||||||
|
p95_qty DOUBLE PRECISION,
|
||||||
|
p99_qty DOUBLE PRECISION,
|
||||||
|
buy_vol_1m DOUBLE PRECISION,
|
||||||
|
sell_vol_1m DOUBLE PRECISION,
|
||||||
|
score INTEGER,
|
||||||
|
signal TEXT
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_si_ts ON signal_indicators(ts);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_si_sym_ts ON signal_indicators(symbol, ts);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS signal_indicators_1m (
|
||||||
|
id BIGSERIAL PRIMARY KEY,
|
||||||
|
ts BIGINT NOT NULL,
|
||||||
|
symbol TEXT NOT NULL,
|
||||||
|
cvd_fast DOUBLE PRECISION,
|
||||||
|
cvd_mid DOUBLE PRECISION,
|
||||||
|
cvd_day DOUBLE PRECISION,
|
||||||
|
atr_5m DOUBLE PRECISION,
|
||||||
|
vwap_30m DOUBLE PRECISION,
|
||||||
|
price DOUBLE PRECISION,
|
||||||
|
score INTEGER,
|
||||||
|
signal TEXT
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_si1m_sym_ts ON signal_indicators_1m(symbol, ts);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS signal_trades (
|
||||||
|
id BIGSERIAL PRIMARY KEY,
|
||||||
|
ts_open BIGINT NOT NULL,
|
||||||
|
ts_close BIGINT,
|
||||||
|
symbol TEXT NOT NULL,
|
||||||
|
direction TEXT NOT NULL,
|
||||||
|
entry_price DOUBLE PRECISION,
|
||||||
|
exit_price DOUBLE PRECISION,
|
||||||
|
qty DOUBLE PRECISION,
|
||||||
|
score INTEGER,
|
||||||
|
pnl DOUBLE PRECISION,
|
||||||
|
sl_price DOUBLE PRECISION,
|
||||||
|
tp1_price DOUBLE PRECISION,
|
||||||
|
tp2_price DOUBLE PRECISION,
|
||||||
|
status TEXT DEFAULT 'open'
|
||||||
|
);
|
||||||
|
|
||||||
|
-- 信号日志(旧表兼容)
|
||||||
|
CREATE TABLE IF NOT EXISTS signal_logs (
|
||||||
|
id BIGSERIAL PRIMARY KEY,
|
||||||
|
symbol TEXT,
|
||||||
|
rate DOUBLE PRECISION,
|
||||||
|
annualized DOUBLE PRECISION,
|
||||||
|
sent_at TEXT,
|
||||||
|
message TEXT
|
||||||
|
);
|
||||||
|
|
||||||
|
-- 用户表(auth)
|
||||||
|
CREATE TABLE IF NOT EXISTS users (
|
||||||
|
id BIGSERIAL PRIMARY KEY,
|
||||||
|
email TEXT UNIQUE NOT NULL,
|
||||||
|
password_hash TEXT NOT NULL,
|
||||||
|
role TEXT DEFAULT 'user',
|
||||||
|
created_at TIMESTAMP DEFAULT NOW()
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS invite_codes (
|
||||||
|
id BIGSERIAL PRIMARY KEY,
|
||||||
|
code TEXT UNIQUE NOT NULL,
|
||||||
|
used_by BIGINT REFERENCES users(id),
|
||||||
|
created_at TIMESTAMP DEFAULT NOW()
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_partitions():
|
||||||
|
"""创建当月和下月的分区表"""
|
||||||
|
import datetime
|
||||||
|
now = datetime.datetime.utcnow()
|
||||||
|
months = []
|
||||||
|
for delta in range(0, 3): # 当月+下2个月
|
||||||
|
d = now + datetime.timedelta(days=delta * 30)
|
||||||
|
months.append(d.strftime("%Y%m"))
|
||||||
|
|
||||||
|
with get_sync_conn() as conn:
|
||||||
|
with conn.cursor() as cur:
|
||||||
|
for m in set(months):
|
||||||
|
year = int(m[:4])
|
||||||
|
month = int(m[4:])
|
||||||
|
# 计算分区范围(毫秒时间戳)
|
||||||
|
start = datetime.datetime(year, month, 1)
|
||||||
|
if month == 12:
|
||||||
|
end = datetime.datetime(year + 1, 1, 1)
|
||||||
|
else:
|
||||||
|
end = datetime.datetime(year, month + 1, 1)
|
||||||
|
start_ms = int(start.timestamp() * 1000)
|
||||||
|
end_ms = int(end.timestamp() * 1000)
|
||||||
|
part_name = f"agg_trades_{m}"
|
||||||
|
try:
|
||||||
|
cur.execute(f"""
|
||||||
|
CREATE TABLE IF NOT EXISTS {part_name}
|
||||||
|
PARTITION OF agg_trades
|
||||||
|
FOR VALUES FROM ({start_ms}) TO ({end_ms})
|
||||||
|
""")
|
||||||
|
except Exception:
|
||||||
|
pass # 分区已存在
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def init_schema():
|
||||||
|
"""初始化全部PG表结构"""
|
||||||
|
with get_sync_conn() as conn:
|
||||||
|
with conn.cursor() as cur:
|
||||||
|
for stmt in SCHEMA_SQL.split(";"):
|
||||||
|
stmt = stmt.strip()
|
||||||
|
if stmt:
|
||||||
|
try:
|
||||||
|
cur.execute(stmt)
|
||||||
|
except Exception as e:
|
||||||
|
conn.rollback()
|
||||||
|
# 忽略已存在错误
|
||||||
|
continue
|
||||||
|
conn.commit()
|
||||||
|
ensure_partitions()
|
||||||
317
backend/main.py
317
backend/main.py
@ -2,9 +2,13 @@ from fastapi import FastAPI, HTTPException, Depends
|
|||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
import httpx
|
import httpx
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
import asyncio, time, sqlite3, os
|
import asyncio, time, os
|
||||||
|
|
||||||
from auth import router as auth_router, get_current_user, ensure_tables as ensure_auth_tables
|
from auth import router as auth_router, get_current_user, ensure_tables as ensure_auth_tables
|
||||||
|
from db import (
|
||||||
|
init_schema, ensure_partitions, get_async_pool, async_fetch, async_fetchrow, async_execute,
|
||||||
|
close_async_pool,
|
||||||
|
)
|
||||||
import datetime as _dt
|
import datetime as _dt
|
||||||
|
|
||||||
app = FastAPI(title="Arbitrage Engine API")
|
app = FastAPI(title="Arbitrage Engine API")
|
||||||
@ -21,7 +25,6 @@ app.include_router(auth_router)
|
|||||||
BINANCE_FAPI = "https://fapi.binance.com/fapi/v1"
|
BINANCE_FAPI = "https://fapi.binance.com/fapi/v1"
|
||||||
SYMBOLS = ["BTCUSDT", "ETHUSDT"]
|
SYMBOLS = ["BTCUSDT", "ETHUSDT"]
|
||||||
HEADERS = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}
|
HEADERS = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}
|
||||||
DB_PATH = os.path.join(os.path.dirname(__file__), "..", "arb.db")
|
|
||||||
|
|
||||||
# 简单内存缓存(history/stats 60秒,rates 3秒)
|
# 简单内存缓存(history/stats 60秒,rates 3秒)
|
||||||
_cache: dict = {}
|
_cache: dict = {}
|
||||||
@ -36,50 +39,27 @@ def set_cache(key: str, data):
|
|||||||
_cache[key] = {"ts": time.time(), "data": data}
|
_cache[key] = {"ts": time.time(), "data": data}
|
||||||
|
|
||||||
|
|
||||||
def init_db():
|
async def save_snapshot(rates: dict):
|
||||||
conn = sqlite3.connect(DB_PATH)
|
|
||||||
conn.execute("""
|
|
||||||
CREATE TABLE IF NOT EXISTS rate_snapshots (
|
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
ts INTEGER NOT NULL,
|
|
||||||
btc_rate REAL NOT NULL,
|
|
||||||
eth_rate REAL NOT NULL,
|
|
||||||
btc_price REAL NOT NULL,
|
|
||||||
eth_price REAL NOT NULL,
|
|
||||||
btc_index_price REAL,
|
|
||||||
eth_index_price REAL
|
|
||||||
)
|
|
||||||
""")
|
|
||||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_rate_snapshots_ts ON rate_snapshots(ts)")
|
|
||||||
conn.commit()
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
|
|
||||||
def save_snapshot(rates: dict):
|
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(DB_PATH)
|
|
||||||
btc = rates.get("BTC", {})
|
btc = rates.get("BTC", {})
|
||||||
eth = rates.get("ETH", {})
|
eth = rates.get("ETH", {})
|
||||||
conn.execute(
|
await async_execute(
|
||||||
"INSERT INTO rate_snapshots (ts, btc_rate, eth_rate, btc_price, eth_price, btc_index_price, eth_index_price) VALUES (?,?,?,?,?,?,?)",
|
"INSERT INTO rate_snapshots (ts, btc_rate, eth_rate, btc_price, eth_price, btc_index_price, eth_index_price) "
|
||||||
(
|
"VALUES ($1,$2,$3,$4,$5,$6,$7)",
|
||||||
int(time.time()),
|
int(time.time()),
|
||||||
float(btc.get("lastFundingRate", 0)),
|
float(btc.get("lastFundingRate", 0)),
|
||||||
float(eth.get("lastFundingRate", 0)),
|
float(eth.get("lastFundingRate", 0)),
|
||||||
float(btc.get("markPrice", 0)),
|
float(btc.get("markPrice", 0)),
|
||||||
float(eth.get("markPrice", 0)),
|
float(eth.get("markPrice", 0)),
|
||||||
float(btc.get("indexPrice", 0)),
|
float(btc.get("indexPrice", 0)),
|
||||||
float(eth.get("indexPrice", 0)),
|
float(eth.get("indexPrice", 0)),
|
||||||
)
|
|
||||||
)
|
)
|
||||||
conn.commit()
|
|
||||||
conn.close()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass # 落库失败不影响API响应
|
pass # 落库失败不影响API响应
|
||||||
|
|
||||||
|
|
||||||
async def background_snapshot_loop():
|
async def background_snapshot_loop():
|
||||||
"""后台每2秒自动拉取费率+价格并落库,不依赖前端调用"""
|
"""后台每2秒自动拉取费率+价格并落库"""
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(timeout=5, headers=HEADERS) as client:
|
async with httpx.AsyncClient(timeout=5, headers=HEADERS) as client:
|
||||||
@ -97,7 +77,7 @@ async def background_snapshot_loop():
|
|||||||
"indexPrice": float(data["indexPrice"]),
|
"indexPrice": float(data["indexPrice"]),
|
||||||
}
|
}
|
||||||
if result:
|
if result:
|
||||||
save_snapshot(result)
|
await save_snapshot(result)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
await asyncio.sleep(2)
|
await asyncio.sleep(2)
|
||||||
@ -105,11 +85,19 @@ async def background_snapshot_loop():
|
|||||||
|
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
async def startup():
|
async def startup():
|
||||||
init_db()
|
# 初始化PG schema
|
||||||
|
init_schema()
|
||||||
ensure_auth_tables()
|
ensure_auth_tables()
|
||||||
|
# 初始化asyncpg池
|
||||||
|
await get_async_pool()
|
||||||
asyncio.create_task(background_snapshot_loop())
|
asyncio.create_task(background_snapshot_loop())
|
||||||
|
|
||||||
|
|
||||||
|
@app.on_event("shutdown")
|
||||||
|
async def shutdown():
|
||||||
|
await close_async_pool()
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/health")
|
@app.get("/api/health")
|
||||||
async def health():
|
async def health():
|
||||||
return {"status": "ok", "timestamp": datetime.utcnow().isoformat()}
|
return {"status": "ok", "timestamp": datetime.utcnow().isoformat()}
|
||||||
@ -137,37 +125,23 @@ async def get_rates():
|
|||||||
"timestamp": data["time"],
|
"timestamp": data["time"],
|
||||||
}
|
}
|
||||||
set_cache("rates", result)
|
set_cache("rates", result)
|
||||||
# 异步落库(不阻塞响应)
|
asyncio.create_task(save_snapshot(result))
|
||||||
asyncio.create_task(asyncio.to_thread(save_snapshot, result))
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/snapshots")
|
@app.get("/api/snapshots")
|
||||||
async def get_snapshots(hours: int = 24, limit: int = 5000, user: dict = Depends(get_current_user)):
|
async def get_snapshots(hours: int = 24, limit: int = 5000, user: dict = Depends(get_current_user)):
|
||||||
"""查询本地落库的实时快照数据"""
|
|
||||||
since = int(time.time()) - hours * 3600
|
since = int(time.time()) - hours * 3600
|
||||||
conn = sqlite3.connect(DB_PATH)
|
rows = await async_fetch(
|
||||||
conn.row_factory = sqlite3.Row
|
"SELECT ts, btc_rate, eth_rate, btc_price, eth_price FROM rate_snapshots "
|
||||||
rows = conn.execute(
|
"WHERE ts >= $1 ORDER BY ts ASC LIMIT $2",
|
||||||
"SELECT ts, btc_rate, eth_rate, btc_price, eth_price FROM rate_snapshots WHERE ts >= ? ORDER BY ts ASC LIMIT ?",
|
since, limit
|
||||||
(since, limit)
|
)
|
||||||
).fetchall()
|
return {"count": len(rows), "hours": hours, "data": rows}
|
||||||
conn.close()
|
|
||||||
return {
|
|
||||||
"count": len(rows),
|
|
||||||
"hours": hours,
|
|
||||||
"data": [dict(r) for r in rows]
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/kline")
|
@app.get("/api/kline")
|
||||||
async def get_kline(symbol: str = "BTC", interval: str = "5m", limit: int = 500, user: dict = Depends(get_current_user)):
|
async def get_kline(symbol: str = "BTC", interval: str = "5m", limit: int = 500, user: dict = Depends(get_current_user)):
|
||||||
"""
|
|
||||||
从 rate_snapshots 聚合K线数据
|
|
||||||
symbol: BTC | ETH
|
|
||||||
interval: 1m | 5m | 30m | 1h | 4h | 8h | 1d | 1w | 1M
|
|
||||||
返回: [{time, open, high, low, close, price_open, price_high, price_low, price_close}]
|
|
||||||
"""
|
|
||||||
interval_secs = {
|
interval_secs = {
|
||||||
"1m": 60, "5m": 300, "30m": 1800,
|
"1m": 60, "5m": 300, "30m": 1800,
|
||||||
"1h": 3600, "4h": 14400, "8h": 28800,
|
"1h": 3600, "4h": 14400, "8h": 28800,
|
||||||
@ -176,22 +150,20 @@ async def get_kline(symbol: str = "BTC", interval: str = "5m", limit: int = 500,
|
|||||||
bar_secs = interval_secs.get(interval, 300)
|
bar_secs = interval_secs.get(interval, 300)
|
||||||
rate_col = "btc_rate" if symbol.upper() == "BTC" else "eth_rate"
|
rate_col = "btc_rate" if symbol.upper() == "BTC" else "eth_rate"
|
||||||
price_col = "btc_price" if symbol.upper() == "BTC" else "eth_price"
|
price_col = "btc_price" if symbol.upper() == "BTC" else "eth_price"
|
||||||
|
|
||||||
# 查询足够多的原始数据(limit根K * bar_secs最多需要的时间范围)
|
|
||||||
since = int(time.time()) - bar_secs * limit
|
since = int(time.time()) - bar_secs * limit
|
||||||
conn = sqlite3.connect(DB_PATH)
|
|
||||||
rows = conn.execute(
|
rows = await async_fetch(
|
||||||
f"SELECT ts, {rate_col} as rate, {price_col} as price FROM rate_snapshots WHERE ts >= ? ORDER BY ts ASC",
|
f"SELECT ts, {rate_col} as rate, {price_col} as price FROM rate_snapshots "
|
||||||
(since,)
|
f"WHERE ts >= $1 ORDER BY ts ASC",
|
||||||
).fetchall()
|
since
|
||||||
conn.close()
|
)
|
||||||
|
|
||||||
if not rows:
|
if not rows:
|
||||||
return {"symbol": symbol, "interval": interval, "data": []}
|
return {"symbol": symbol, "interval": interval, "data": []}
|
||||||
|
|
||||||
# 按bar_secs分组聚合OHLC
|
|
||||||
bars: dict = {}
|
bars: dict = {}
|
||||||
for ts, rate, price in rows:
|
for r in rows:
|
||||||
|
ts, rate, price = r["ts"], r["rate"], r["price"]
|
||||||
bar_ts = (ts // bar_secs) * bar_secs
|
bar_ts = (ts // bar_secs) * bar_secs
|
||||||
if bar_ts not in bars:
|
if bar_ts not in bars:
|
||||||
bars[bar_ts] = {
|
bars[bar_ts] = {
|
||||||
@ -209,20 +181,16 @@ async def get_kline(symbol: str = "BTC", interval: str = "5m", limit: int = 500,
|
|||||||
b["price_close"] = price
|
b["price_close"] = price
|
||||||
|
|
||||||
data = sorted(bars.values(), key=lambda x: x["time"])[-limit:]
|
data = sorted(bars.values(), key=lambda x: x["time"])[-limit:]
|
||||||
# 转换为万分之(费率 × 10000)
|
|
||||||
for b in data:
|
for b in data:
|
||||||
for k in ("open", "high", "low", "close"):
|
for k in ("open", "high", "low", "close"):
|
||||||
b[k] = round(b[k] * 10000, 4)
|
b[k] = round(b[k] * 10000, 4)
|
||||||
|
|
||||||
return {"symbol": symbol, "interval": interval, "count": len(data), "data": data}
|
return {"symbol": symbol, "interval": interval, "count": len(data), "data": data}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/stats/ytd")
|
@app.get("/api/stats/ytd")
|
||||||
async def get_stats_ytd(user: dict = Depends(get_current_user)):
|
async def get_stats_ytd(user: dict = Depends(get_current_user)):
|
||||||
"""今年以来(YTD)资金费率年化统计"""
|
|
||||||
cached = get_cache("stats_ytd", 3600)
|
cached = get_cache("stats_ytd", 3600)
|
||||||
if cached: return cached
|
if cached: return cached
|
||||||
# 今年1月1日 00:00 UTC
|
|
||||||
import datetime
|
import datetime
|
||||||
year_start = int(datetime.datetime(datetime.datetime.utcnow().year, 1, 1).timestamp() * 1000)
|
year_start = int(datetime.datetime(datetime.datetime.utcnow().year, 1, 1).timestamp() * 1000)
|
||||||
end_time = int(time.time() * 1000)
|
end_time = int(time.time() * 1000)
|
||||||
@ -252,16 +220,12 @@ async def get_stats_ytd(user: dict = Depends(get_current_user)):
|
|||||||
|
|
||||||
@app.get("/api/signals/history")
|
@app.get("/api/signals/history")
|
||||||
async def get_signals_history(limit: int = 100, user: dict = Depends(get_current_user)):
|
async def get_signals_history(limit: int = 100, user: dict = Depends(get_current_user)):
|
||||||
"""查询信号推送历史"""
|
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(DB_PATH)
|
rows = await async_fetch(
|
||||||
conn.row_factory = sqlite3.Row
|
"SELECT id, symbol, rate, annualized, sent_at, message FROM signal_logs ORDER BY sent_at DESC LIMIT $1",
|
||||||
rows = conn.execute(
|
limit
|
||||||
"SELECT id, symbol, rate, annualized, sent_at, message FROM signal_logs ORDER BY sent_at DESC LIMIT ?",
|
)
|
||||||
(limit,)
|
return {"items": rows}
|
||||||
).fetchall()
|
|
||||||
conn.close()
|
|
||||||
return {"items": [dict(r) for r in rows]}
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {"items": [], "error": str(e)}
|
return {"items": [], "error": str(e)}
|
||||||
|
|
||||||
@ -334,20 +298,11 @@ async def get_stats(user: dict = Depends(get_current_user)):
|
|||||||
return stats
|
return stats
|
||||||
|
|
||||||
|
|
||||||
# ─── aggTrades 查询接口 ──────────────────────────────────────────
|
# ─── aggTrades 查询接口(PG版)───────────────────────────────────
|
||||||
|
|
||||||
@app.get("/api/trades/meta")
|
@app.get("/api/trades/meta")
|
||||||
async def get_trades_meta(user: dict = Depends(get_current_user)):
|
async def get_trades_meta(user: dict = Depends(get_current_user)):
|
||||||
"""aggTrades采集状态:各symbol最新agg_id和时间"""
|
rows = await async_fetch("SELECT symbol, last_agg_id, last_time_ms, updated_at FROM agg_trades_meta")
|
||||||
conn = sqlite3.connect(DB_PATH)
|
|
||||||
conn.row_factory = sqlite3.Row
|
|
||||||
try:
|
|
||||||
rows = conn.execute(
|
|
||||||
"SELECT symbol, last_agg_id, last_time_ms, updated_at FROM agg_trades_meta"
|
|
||||||
).fetchall()
|
|
||||||
except Exception:
|
|
||||||
rows = []
|
|
||||||
conn.close()
|
|
||||||
result = {}
|
result = {}
|
||||||
for r in rows:
|
for r in rows:
|
||||||
sym = r["symbol"].replace("USDT", "")
|
sym = r["symbol"].replace("USDT", "")
|
||||||
@ -367,46 +322,23 @@ async def get_trades_summary(
|
|||||||
interval: str = "1m",
|
interval: str = "1m",
|
||||||
user: dict = Depends(get_current_user),
|
user: dict = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""分钟级聚合:买卖delta/成交速率/vwap"""
|
|
||||||
if end_ms == 0:
|
if end_ms == 0:
|
||||||
end_ms = int(time.time() * 1000)
|
end_ms = int(time.time() * 1000)
|
||||||
if start_ms == 0:
|
if start_ms == 0:
|
||||||
start_ms = end_ms - 3600 * 1000 # 默认1小时
|
start_ms = end_ms - 3600 * 1000
|
||||||
|
|
||||||
interval_ms = {"1m": 60000, "5m": 300000, "15m": 900000, "1h": 3600000}.get(interval, 60000)
|
interval_ms = {"1m": 60000, "5m": 300000, "15m": 900000, "1h": 3600000}.get(interval, 60000)
|
||||||
sym_full = symbol.upper() + "USDT"
|
sym_full = symbol.upper() + "USDT"
|
||||||
|
|
||||||
# 确定需要查哪些月表
|
# PG分区表自动裁剪,直接查主表
|
||||||
start_dt = _dt.datetime.fromtimestamp(start_ms / 1000, tz=_dt.timezone.utc)
|
rows = await async_fetch(
|
||||||
end_dt = _dt.datetime.fromtimestamp(end_ms / 1000, tz=_dt.timezone.utc)
|
"SELECT agg_id, price, qty, time_ms, is_buyer_maker FROM agg_trades "
|
||||||
months = set()
|
"WHERE symbol = $1 AND time_ms >= $2 AND time_ms < $3 ORDER BY time_ms ASC",
|
||||||
cur = start_dt.replace(day=1)
|
sym_full, start_ms, end_ms
|
||||||
while cur <= end_dt:
|
)
|
||||||
months.add(cur.strftime("%Y%m"))
|
|
||||||
if cur.month == 12:
|
|
||||||
cur = cur.replace(year=cur.year + 1, month=1)
|
|
||||||
else:
|
|
||||||
cur = cur.replace(month=cur.month + 1)
|
|
||||||
|
|
||||||
conn = sqlite3.connect(DB_PATH)
|
|
||||||
conn.row_factory = sqlite3.Row
|
|
||||||
all_rows = []
|
|
||||||
for month in sorted(months):
|
|
||||||
tname = f"agg_trades_{month}"
|
|
||||||
try:
|
|
||||||
rows = conn.execute(
|
|
||||||
f"SELECT agg_id, price, qty, time_ms, is_buyer_maker FROM {tname} "
|
|
||||||
f"WHERE symbol = ? AND time_ms >= ? AND time_ms < ? ORDER BY time_ms ASC",
|
|
||||||
(sym_full, start_ms, end_ms)
|
|
||||||
).fetchall()
|
|
||||||
all_rows.extend(rows)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
# 按interval聚合
|
|
||||||
bars: dict = {}
|
bars: dict = {}
|
||||||
for row in all_rows:
|
for row in rows:
|
||||||
bar_ms = (row["time_ms"] // interval_ms) * interval_ms
|
bar_ms = (row["time_ms"] // interval_ms) * interval_ms
|
||||||
if bar_ms not in bars:
|
if bar_ms not in bars:
|
||||||
bars[bar_ms] = {"time_ms": bar_ms, "buy_vol": 0.0, "sell_vol": 0.0,
|
bars[bar_ms] = {"time_ms": bar_ms, "buy_vol": 0.0, "sell_vol": 0.0,
|
||||||
@ -414,7 +346,7 @@ async def get_trades_summary(
|
|||||||
b = bars[bar_ms]
|
b = bars[bar_ms]
|
||||||
qty = float(row["qty"])
|
qty = float(row["qty"])
|
||||||
price = float(row["price"])
|
price = float(row["price"])
|
||||||
if row["is_buyer_maker"] == 0: # 主动买
|
if row["is_buyer_maker"] == 0:
|
||||||
b["buy_vol"] += qty
|
b["buy_vol"] += qty
|
||||||
else:
|
else:
|
||||||
b["sell_vol"] += qty
|
b["sell_vol"] += qty
|
||||||
@ -446,47 +378,28 @@ async def get_trades_latest(
|
|||||||
limit: int = 30,
|
limit: int = 30,
|
||||||
user: dict = Depends(get_current_user),
|
user: dict = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""查最新N条原始成交记录(从本地DB,实时刷新用)"""
|
|
||||||
sym_full = symbol.upper() + "USDT"
|
sym_full = symbol.upper() + "USDT"
|
||||||
now_month = _dt.datetime.now(_dt.timezone.utc).strftime("%Y%m")
|
rows = await async_fetch(
|
||||||
tname = f"agg_trades_{now_month}"
|
"SELECT agg_id, price, qty, time_ms, is_buyer_maker FROM agg_trades "
|
||||||
conn = sqlite3.connect(DB_PATH)
|
"WHERE symbol = $1 ORDER BY time_ms DESC, agg_id DESC LIMIT $2",
|
||||||
conn.row_factory = sqlite3.Row
|
sym_full, limit
|
||||||
try:
|
)
|
||||||
rows = conn.execute(
|
return {"symbol": symbol, "count": len(rows), "data": rows}
|
||||||
f"SELECT agg_id, price, qty, time_ms, is_buyer_maker FROM {tname} "
|
|
||||||
f"WHERE symbol = ? ORDER BY agg_id DESC LIMIT ?",
|
|
||||||
(sym_full, limit)
|
@app.get("/api/collector/health")
|
||||||
).fetchall()
|
|
||||||
except Exception:
|
|
||||||
rows = []
|
|
||||||
conn.close()
|
|
||||||
return {
|
|
||||||
"symbol": symbol,
|
|
||||||
"count": len(rows),
|
|
||||||
"data": [dict(r) for r in rows],
|
|
||||||
}
|
|
||||||
async def collector_health(user: dict = Depends(get_current_user)):
|
async def collector_health(user: dict = Depends(get_current_user)):
|
||||||
"""采集器健康状态"""
|
|
||||||
conn = sqlite3.connect(DB_PATH)
|
|
||||||
conn.row_factory = sqlite3.Row
|
|
||||||
now_ms = int(time.time() * 1000)
|
now_ms = int(time.time() * 1000)
|
||||||
|
rows = await async_fetch("SELECT symbol, last_agg_id, last_time_ms FROM agg_trades_meta")
|
||||||
status = {}
|
status = {}
|
||||||
try:
|
for r in rows:
|
||||||
rows = conn.execute(
|
sym = r["symbol"].replace("USDT", "")
|
||||||
"SELECT symbol, last_agg_id, last_time_ms FROM agg_trades_meta"
|
lag_s = (now_ms - (r["last_time_ms"] or 0)) / 1000
|
||||||
).fetchall()
|
status[sym] = {
|
||||||
for r in rows:
|
"last_agg_id": r["last_agg_id"],
|
||||||
sym = r["symbol"].replace("USDT", "")
|
"lag_seconds": round(lag_s, 1),
|
||||||
lag_s = (now_ms - r["last_time_ms"]) / 1000
|
"healthy": lag_s < 30,
|
||||||
status[sym] = {
|
}
|
||||||
"last_agg_id": r["last_agg_id"],
|
|
||||||
"lag_seconds": round(lag_s, 1),
|
|
||||||
"healthy": lag_s < 30, # 30秒内有数据算健康
|
|
||||||
}
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
conn.close()
|
|
||||||
return {"collector": status, "timestamp": now_ms}
|
return {"collector": status, "timestamp": now_ms}
|
||||||
|
|
||||||
|
|
||||||
@ -498,50 +411,29 @@ async def get_signal_indicators(
|
|||||||
minutes: int = 60,
|
minutes: int = 60,
|
||||||
user: dict = Depends(get_current_user),
|
user: dict = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""获取signal_indicators_1m数据(前端图表用)"""
|
|
||||||
sym_full = symbol.upper() + "USDT"
|
sym_full = symbol.upper() + "USDT"
|
||||||
now_ms = int(time.time() * 1000)
|
now_ms = int(time.time() * 1000)
|
||||||
start_ms = now_ms - minutes * 60 * 1000
|
start_ms = now_ms - minutes * 60 * 1000
|
||||||
|
rows = await async_fetch(
|
||||||
conn = sqlite3.connect(DB_PATH)
|
"SELECT ts, cvd_fast, cvd_mid, cvd_day, atr_5m, vwap_30m, price, score, signal "
|
||||||
conn.row_factory = sqlite3.Row
|
"FROM signal_indicators_1m WHERE symbol = $1 AND ts >= $2 ORDER BY ts ASC",
|
||||||
try:
|
sym_full, start_ms
|
||||||
rows = conn.execute(
|
)
|
||||||
"SELECT ts, cvd_fast, cvd_mid, cvd_day, atr_5m, vwap_30m, price, score, signal "
|
return {"symbol": symbol, "count": len(rows), "data": rows}
|
||||||
"FROM signal_indicators_1m WHERE symbol = ? AND ts >= ? ORDER BY ts ASC",
|
|
||||||
(sym_full, start_ms)
|
|
||||||
).fetchall()
|
|
||||||
except Exception:
|
|
||||||
rows = []
|
|
||||||
conn.close()
|
|
||||||
return {
|
|
||||||
"symbol": symbol,
|
|
||||||
"count": len(rows),
|
|
||||||
"data": [dict(r) for r in rows],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/signals/latest")
|
@app.get("/api/signals/latest")
|
||||||
async def get_signal_latest(
|
async def get_signal_latest(user: dict = Depends(get_current_user)):
|
||||||
user: dict = Depends(get_current_user),
|
|
||||||
):
|
|
||||||
"""获取最新一条各symbol的指标快照"""
|
|
||||||
conn = sqlite3.connect(DB_PATH)
|
|
||||||
conn.row_factory = sqlite3.Row
|
|
||||||
result = {}
|
result = {}
|
||||||
for sym in ["BTCUSDT", "ETHUSDT"]:
|
for sym in ["BTCUSDT", "ETHUSDT"]:
|
||||||
try:
|
row = await async_fetchrow(
|
||||||
row = conn.execute(
|
"SELECT ts, cvd_fast, cvd_mid, cvd_day, cvd_fast_slope, atr_5m, atr_percentile, "
|
||||||
"SELECT ts, cvd_fast, cvd_mid, cvd_day, cvd_fast_slope, atr_5m, atr_percentile, "
|
"vwap_30m, price, p95_qty, p99_qty, score, signal "
|
||||||
"vwap_30m, price, p95_qty, p99_qty, score, signal "
|
"FROM signal_indicators WHERE symbol = $1 ORDER BY ts DESC LIMIT 1",
|
||||||
"FROM signal_indicators WHERE symbol = ? ORDER BY ts DESC LIMIT 1",
|
sym
|
||||||
(sym,)
|
)
|
||||||
).fetchone()
|
if row:
|
||||||
if row:
|
result[sym.replace("USDT", "")] = row
|
||||||
result[sym.replace("USDT", "")] = dict(row)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
conn.close()
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@ -551,20 +443,13 @@ async def get_signal_trades(
|
|||||||
limit: int = 50,
|
limit: int = 50,
|
||||||
user: dict = Depends(get_current_user),
|
user: dict = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""获取信号交易记录"""
|
if status == "all":
|
||||||
conn = sqlite3.connect(DB_PATH)
|
rows = await async_fetch(
|
||||||
conn.row_factory = sqlite3.Row
|
"SELECT * FROM signal_trades ORDER BY ts_open DESC LIMIT $1", limit
|
||||||
try:
|
)
|
||||||
if status == "all":
|
else:
|
||||||
rows = conn.execute(
|
rows = await async_fetch(
|
||||||
"SELECT * FROM signal_trades ORDER BY ts_open DESC LIMIT ?", (limit,)
|
"SELECT * FROM signal_trades WHERE status = $1 ORDER BY ts_open DESC LIMIT $2",
|
||||||
).fetchall()
|
status, limit
|
||||||
else:
|
)
|
||||||
rows = conn.execute(
|
return {"count": len(rows), "data": rows}
|
||||||
"SELECT * FROM signal_trades WHERE status = ? ORDER BY ts_open DESC LIMIT ?",
|
|
||||||
(status, limit)
|
|
||||||
).fetchall()
|
|
||||||
except Exception:
|
|
||||||
rows = []
|
|
||||||
conn.close()
|
|
||||||
return {"count": len(rows), "data": [dict(r) for r in rows]}
|
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
signal_engine.py — V5 短线交易信号引擎
|
signal_engine.py — V5 短线交易信号引擎(PostgreSQL版)
|
||||||
|
|
||||||
架构:
|
架构:
|
||||||
- 独立PM2进程,每5秒循环
|
- 独立PM2进程,每5秒循环
|
||||||
@ -17,14 +17,13 @@ signal_engine.py — V5 短线交易信号引擎
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sqlite3
|
|
||||||
import time
|
import time
|
||||||
import math
|
|
||||||
import statistics
|
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from db import get_sync_conn, init_schema
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO,
|
level=logging.INFO,
|
||||||
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
||||||
@ -35,110 +34,33 @@ logging.basicConfig(
|
|||||||
)
|
)
|
||||||
logger = logging.getLogger("signal-engine")
|
logger = logging.getLogger("signal-engine")
|
||||||
|
|
||||||
DB_PATH = os.path.join(os.path.dirname(__file__), "..", "arb.db")
|
|
||||||
SYMBOLS = ["BTCUSDT", "ETHUSDT"]
|
SYMBOLS = ["BTCUSDT", "ETHUSDT"]
|
||||||
LOOP_INTERVAL = 5 # 秒
|
LOOP_INTERVAL = 5 # 秒
|
||||||
|
|
||||||
# 窗口大小(毫秒)
|
# 窗口大小(毫秒)
|
||||||
WINDOW_FAST = 30 * 60 * 1000 # 30分钟
|
WINDOW_FAST = 30 * 60 * 1000 # 30分钟
|
||||||
WINDOW_MID = 4 * 3600 * 1000 # 4小时
|
WINDOW_MID = 4 * 3600 * 1000 # 4小时
|
||||||
WINDOW_DAY = 24 * 3600 * 1000 # 24小时(用于P95/P99计算)
|
WINDOW_DAY = 24 * 3600 * 1000 # 24小时
|
||||||
WINDOW_VWAP = 30 * 60 * 1000 # 30分钟
|
WINDOW_VWAP = 30 * 60 * 1000 # 30分钟
|
||||||
|
|
||||||
# ATR参数
|
# ATR参数
|
||||||
ATR_PERIOD_MS = 5 * 60 * 1000 # 5分钟K线
|
ATR_PERIOD_MS = 5 * 60 * 1000
|
||||||
ATR_LENGTH = 14 # 14根
|
ATR_LENGTH = 14
|
||||||
|
|
||||||
# 信号冷却
|
# 信号冷却
|
||||||
COOLDOWN_MS = 10 * 60 * 1000 # 10分钟
|
COOLDOWN_MS = 10 * 60 * 1000
|
||||||
|
|
||||||
|
|
||||||
# ─── DB helpers ──────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def get_conn() -> sqlite3.Connection:
|
|
||||||
conn = sqlite3.connect(DB_PATH, timeout=30)
|
|
||||||
conn.row_factory = sqlite3.Row
|
|
||||||
conn.execute("PRAGMA journal_mode=WAL")
|
|
||||||
conn.execute("PRAGMA synchronous=NORMAL")
|
|
||||||
return conn
|
|
||||||
|
|
||||||
|
|
||||||
def init_tables(conn: sqlite3.Connection):
|
|
||||||
conn.execute("""
|
|
||||||
CREATE TABLE IF NOT EXISTS signal_indicators (
|
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
ts INTEGER NOT NULL,
|
|
||||||
symbol TEXT NOT NULL,
|
|
||||||
cvd_fast REAL,
|
|
||||||
cvd_mid REAL,
|
|
||||||
cvd_day REAL,
|
|
||||||
cvd_fast_slope REAL,
|
|
||||||
atr_5m REAL,
|
|
||||||
atr_percentile REAL,
|
|
||||||
vwap_30m REAL,
|
|
||||||
price REAL,
|
|
||||||
p95_qty REAL,
|
|
||||||
p99_qty REAL,
|
|
||||||
buy_vol_1m REAL,
|
|
||||||
sell_vol_1m REAL,
|
|
||||||
score INTEGER,
|
|
||||||
signal TEXT
|
|
||||||
)
|
|
||||||
""")
|
|
||||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_si_ts ON signal_indicators(ts)")
|
|
||||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_si_sym_ts ON signal_indicators(symbol, ts)")
|
|
||||||
|
|
||||||
conn.execute("""
|
|
||||||
CREATE TABLE IF NOT EXISTS signal_indicators_1m (
|
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
ts INTEGER NOT NULL,
|
|
||||||
symbol TEXT NOT NULL,
|
|
||||||
cvd_fast REAL,
|
|
||||||
cvd_mid REAL,
|
|
||||||
cvd_day REAL,
|
|
||||||
atr_5m REAL,
|
|
||||||
vwap_30m REAL,
|
|
||||||
price REAL,
|
|
||||||
score INTEGER,
|
|
||||||
signal TEXT
|
|
||||||
)
|
|
||||||
""")
|
|
||||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_si1m_sym_ts ON signal_indicators_1m(symbol, ts)")
|
|
||||||
|
|
||||||
conn.execute("""
|
|
||||||
CREATE TABLE IF NOT EXISTS signal_trades (
|
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
ts_open INTEGER NOT NULL,
|
|
||||||
ts_close INTEGER,
|
|
||||||
symbol TEXT NOT NULL,
|
|
||||||
direction TEXT NOT NULL,
|
|
||||||
entry_price REAL,
|
|
||||||
exit_price REAL,
|
|
||||||
qty REAL,
|
|
||||||
score INTEGER,
|
|
||||||
pnl REAL,
|
|
||||||
sl_price REAL,
|
|
||||||
tp1_price REAL,
|
|
||||||
tp2_price REAL,
|
|
||||||
status TEXT DEFAULT 'open'
|
|
||||||
)
|
|
||||||
""")
|
|
||||||
conn.commit()
|
|
||||||
|
|
||||||
|
|
||||||
# ─── 滚动窗口 ───────────────────────────────────────────────────
|
# ─── 滚动窗口 ───────────────────────────────────────────────────
|
||||||
|
|
||||||
class TradeWindow:
|
class TradeWindow:
|
||||||
"""滚动时间窗口,维护买卖量和价格数据"""
|
|
||||||
|
|
||||||
def __init__(self, window_ms: int):
|
def __init__(self, window_ms: int):
|
||||||
self.window_ms = window_ms
|
self.window_ms = window_ms
|
||||||
self.trades: deque = deque() # (time_ms, qty, price, is_buyer_maker)
|
self.trades: deque = deque()
|
||||||
self.buy_vol = 0.0
|
self.buy_vol = 0.0
|
||||||
self.sell_vol = 0.0
|
self.sell_vol = 0.0
|
||||||
self.pq_sum = 0.0 # price * qty 累加(VWAP用)
|
self.pq_sum = 0.0
|
||||||
self.q_sum = 0.0 # qty累加
|
self.q_sum = 0.0
|
||||||
self.quantities: list = [] # 用于P95/P99(仅24h窗口用)
|
|
||||||
|
|
||||||
def add(self, time_ms: int, qty: float, price: float, is_buyer_maker: int):
|
def add(self, time_ms: int, qty: float, price: float, is_buyer_maker: int):
|
||||||
self.trades.append((time_ms, qty, price, is_buyer_maker))
|
self.trades.append((time_ms, qty, price, is_buyer_maker))
|
||||||
@ -154,8 +76,7 @@ class TradeWindow:
|
|||||||
cutoff = now_ms - self.window_ms
|
cutoff = now_ms - self.window_ms
|
||||||
while self.trades and self.trades[0][0] < cutoff:
|
while self.trades and self.trades[0][0] < cutoff:
|
||||||
t_ms, qty, price, ibm = self.trades.popleft()
|
t_ms, qty, price, ibm = self.trades.popleft()
|
||||||
pq = price * qty
|
self.pq_sum -= price * qty
|
||||||
self.pq_sum -= pq
|
|
||||||
self.q_sum -= qty
|
self.q_sum -= qty
|
||||||
if ibm == 0:
|
if ibm == 0:
|
||||||
self.buy_vol -= qty
|
self.buy_vol -= qty
|
||||||
@ -170,29 +91,21 @@ class TradeWindow:
|
|||||||
def vwap(self) -> float:
|
def vwap(self) -> float:
|
||||||
return self.pq_sum / self.q_sum if self.q_sum > 0 else 0.0
|
return self.pq_sum / self.q_sum if self.q_sum > 0 else 0.0
|
||||||
|
|
||||||
@property
|
|
||||||
def total_vol(self) -> float:
|
|
||||||
return self.buy_vol + self.sell_vol
|
|
||||||
|
|
||||||
|
|
||||||
class ATRCalculator:
|
class ATRCalculator:
|
||||||
"""5分钟K线ATR计算"""
|
|
||||||
|
|
||||||
def __init__(self, period_ms: int = ATR_PERIOD_MS, length: int = ATR_LENGTH):
|
def __init__(self, period_ms: int = ATR_PERIOD_MS, length: int = ATR_LENGTH):
|
||||||
self.period_ms = period_ms
|
self.period_ms = period_ms
|
||||||
self.length = length
|
self.length = length
|
||||||
self.candles: deque = deque(maxlen=length + 1)
|
self.candles: deque = deque(maxlen=length + 1)
|
||||||
self.current_candle: Optional[dict] = None
|
self.current_candle: Optional[dict] = None
|
||||||
self.atr_history: deque = deque(maxlen=288) # 24h of 5m candles for percentile
|
self.atr_history: deque = deque(maxlen=288)
|
||||||
|
|
||||||
def update(self, time_ms: int, price: float):
|
def update(self, time_ms: int, price: float):
|
||||||
bar_ms = (time_ms // self.period_ms) * self.period_ms
|
bar_ms = (time_ms // self.period_ms) * self.period_ms
|
||||||
if self.current_candle is None or self.current_candle["bar"] != bar_ms:
|
if self.current_candle is None or self.current_candle["bar"] != bar_ms:
|
||||||
if self.current_candle is not None:
|
if self.current_candle is not None:
|
||||||
self.candles.append(self.current_candle)
|
self.candles.append(self.current_candle)
|
||||||
self.current_candle = {
|
self.current_candle = {"bar": bar_ms, "open": price, "high": price, "low": price, "close": price}
|
||||||
"bar": bar_ms, "open": price, "high": price, "low": price, "close": price
|
|
||||||
}
|
|
||||||
else:
|
else:
|
||||||
c = self.current_candle
|
c = self.current_candle
|
||||||
c["high"] = max(c["high"], price)
|
c["high"] = max(c["high"], price)
|
||||||
@ -212,7 +125,6 @@ class ATRCalculator:
|
|||||||
trs.append(tr)
|
trs.append(tr)
|
||||||
if not trs:
|
if not trs:
|
||||||
return 0.0
|
return 0.0
|
||||||
# EMA-style ATR
|
|
||||||
atr_val = trs[0]
|
atr_val = trs[0]
|
||||||
for tr in trs[1:]:
|
for tr in trs[1:]:
|
||||||
atr_val = (atr_val * (self.length - 1) + tr) / self.length
|
atr_val = (atr_val * (self.length - 1) + tr) / self.length
|
||||||
@ -220,7 +132,6 @@ class ATRCalculator:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def atr_percentile(self) -> float:
|
def atr_percentile(self) -> float:
|
||||||
"""当前ATR在最近24h中的分位数"""
|
|
||||||
current = self.atr
|
current = self.atr
|
||||||
if current == 0:
|
if current == 0:
|
||||||
return 50.0
|
return 50.0
|
||||||
@ -233,8 +144,6 @@ class ATRCalculator:
|
|||||||
|
|
||||||
|
|
||||||
class SymbolState:
|
class SymbolState:
|
||||||
"""单个交易对的完整状态"""
|
|
||||||
|
|
||||||
def __init__(self, symbol: str):
|
def __init__(self, symbol: str):
|
||||||
self.symbol = symbol
|
self.symbol = symbol
|
||||||
self.win_fast = TradeWindow(WINDOW_FAST)
|
self.win_fast = TradeWindow(WINDOW_FAST)
|
||||||
@ -244,12 +153,10 @@ class SymbolState:
|
|||||||
self.atr_calc = ATRCalculator()
|
self.atr_calc = ATRCalculator()
|
||||||
self.last_processed_id = 0
|
self.last_processed_id = 0
|
||||||
self.warmup = True
|
self.warmup = True
|
||||||
self.warmup_until = 0
|
|
||||||
self.prev_cvd_fast = 0.0
|
self.prev_cvd_fast = 0.0
|
||||||
self.last_signal_ts = 0
|
self.last_signal_ts = 0
|
||||||
self.last_signal_dir = ""
|
self.last_signal_dir = ""
|
||||||
# P99大单追踪(最近15分钟)
|
self.recent_large_trades: deque = deque()
|
||||||
self.recent_large_trades: deque = deque() # (time_ms, qty, is_buyer_maker)
|
|
||||||
|
|
||||||
def process_trade(self, agg_id: int, time_ms: int, price: float, qty: float, is_buyer_maker: int):
|
def process_trade(self, agg_id: int, time_ms: int, price: float, qty: float, is_buyer_maker: int):
|
||||||
now_ms = time_ms
|
now_ms = time_ms
|
||||||
@ -258,47 +165,34 @@ class SymbolState:
|
|||||||
self.win_day.add(time_ms, qty, price, is_buyer_maker)
|
self.win_day.add(time_ms, qty, price, is_buyer_maker)
|
||||||
self.win_vwap.add(time_ms, qty, price, is_buyer_maker)
|
self.win_vwap.add(time_ms, qty, price, is_buyer_maker)
|
||||||
self.atr_calc.update(time_ms, price)
|
self.atr_calc.update(time_ms, price)
|
||||||
|
|
||||||
self.win_fast.trim(now_ms)
|
self.win_fast.trim(now_ms)
|
||||||
self.win_mid.trim(now_ms)
|
self.win_mid.trim(now_ms)
|
||||||
self.win_day.trim(now_ms)
|
self.win_day.trim(now_ms)
|
||||||
self.win_vwap.trim(now_ms)
|
self.win_vwap.trim(now_ms)
|
||||||
|
|
||||||
self.last_processed_id = agg_id
|
self.last_processed_id = agg_id
|
||||||
|
|
||||||
def compute_p95_p99(self) -> tuple[float, float]:
|
def compute_p95_p99(self) -> tuple:
|
||||||
"""从24h窗口计算大单阈值"""
|
|
||||||
if len(self.win_day.trades) < 100:
|
if len(self.win_day.trades) < 100:
|
||||||
return 5.0, 10.0 # 默认兜底
|
return 5.0, 10.0
|
||||||
|
qtys = sorted([t[1] for t in self.win_day.trades])
|
||||||
qtys = [t[1] for t in self.win_day.trades]
|
|
||||||
qtys.sort()
|
|
||||||
n = len(qtys)
|
n = len(qtys)
|
||||||
p95 = qtys[int(n * 0.95)]
|
p95 = qtys[int(n * 0.95)]
|
||||||
p99 = qtys[int(n * 0.99)]
|
p99 = qtys[int(n * 0.99)]
|
||||||
|
|
||||||
# BTC兜底5,ETH兜底50
|
|
||||||
if "BTC" in self.symbol:
|
if "BTC" in self.symbol:
|
||||||
p95 = max(p95, 5.0)
|
p95 = max(p95, 5.0); p99 = max(p99, 10.0)
|
||||||
p99 = max(p99, 10.0)
|
|
||||||
else:
|
else:
|
||||||
p95 = max(p95, 50.0)
|
p95 = max(p95, 50.0); p99 = max(p99, 100.0)
|
||||||
p99 = max(p99, 100.0)
|
|
||||||
return p95, p99
|
return p95, p99
|
||||||
|
|
||||||
def update_large_trades(self, now_ms: int, p99: float):
|
def update_large_trades(self, now_ms: int, p99: float):
|
||||||
"""更新最近15分钟的P99大单记录"""
|
|
||||||
cutoff = now_ms - 15 * 60 * 1000
|
cutoff = now_ms - 15 * 60 * 1000
|
||||||
while self.recent_large_trades and self.recent_large_trades[0][0] < cutoff:
|
while self.recent_large_trades and self.recent_large_trades[0][0] < cutoff:
|
||||||
self.recent_large_trades.popleft()
|
self.recent_large_trades.popleft()
|
||||||
|
|
||||||
# 从最近处理的trades中找大单
|
|
||||||
for t in self.win_fast.trades:
|
for t in self.win_fast.trades:
|
||||||
if t[1] >= p99 and t[0] > cutoff:
|
if t[1] >= p99 and t[0] > cutoff:
|
||||||
self.recent_large_trades.append((t[0], t[1], t[3]))
|
self.recent_large_trades.append((t[0], t[1], t[3]))
|
||||||
|
|
||||||
def evaluate_signal(self, now_ms: int) -> dict:
|
def evaluate_signal(self, now_ms: int) -> dict:
|
||||||
"""评估信号:核心3条件 + 加分3条件"""
|
|
||||||
cvd_fast = self.win_fast.cvd
|
cvd_fast = self.win_fast.cvd
|
||||||
cvd_mid = self.win_mid.cvd
|
cvd_mid = self.win_mid.cvd
|
||||||
vwap = self.win_vwap.vwap
|
vwap = self.win_vwap.vwap
|
||||||
@ -306,198 +200,127 @@ class SymbolState:
|
|||||||
atr_pct = self.atr_calc.atr_percentile
|
atr_pct = self.atr_calc.atr_percentile
|
||||||
p95, p99 = self.compute_p95_p99()
|
p95, p99 = self.compute_p95_p99()
|
||||||
self.update_large_trades(now_ms, p99)
|
self.update_large_trades(now_ms, p99)
|
||||||
|
|
||||||
# 当前价格(用VWAP近似)
|
|
||||||
price = vwap if vwap > 0 else 0
|
price = vwap if vwap > 0 else 0
|
||||||
|
|
||||||
# CVD_fast斜率
|
|
||||||
cvd_fast_slope = cvd_fast - self.prev_cvd_fast
|
cvd_fast_slope = cvd_fast - self.prev_cvd_fast
|
||||||
self.prev_cvd_fast = cvd_fast
|
self.prev_cvd_fast = cvd_fast
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
"cvd_fast": cvd_fast,
|
"cvd_fast": cvd_fast, "cvd_mid": cvd_mid, "cvd_day": self.win_day.cvd,
|
||||||
"cvd_mid": cvd_mid,
|
|
||||||
"cvd_day": self.win_day.cvd,
|
|
||||||
"cvd_fast_slope": cvd_fast_slope,
|
"cvd_fast_slope": cvd_fast_slope,
|
||||||
"atr": atr,
|
"atr": atr, "atr_pct": atr_pct, "vwap": vwap, "price": price,
|
||||||
"atr_pct": atr_pct,
|
"p95": p95, "p99": p99, "signal": None, "direction": None, "score": 0,
|
||||||
"vwap": vwap,
|
|
||||||
"price": price,
|
|
||||||
"p95": p95,
|
|
||||||
"p99": p99,
|
|
||||||
"signal": None,
|
|
||||||
"direction": None,
|
|
||||||
"score": 0,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.warmup or price == 0 or atr == 0:
|
if self.warmup or price == 0 or atr == 0:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# 冷却期检查
|
|
||||||
if now_ms - self.last_signal_ts < COOLDOWN_MS:
|
if now_ms - self.last_signal_ts < COOLDOWN_MS:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# === 核心条件 ===
|
long_core = cvd_fast > 0 and cvd_fast_slope > 0 and cvd_mid > 0 and price > vwap
|
||||||
# 做多
|
short_core = cvd_fast < 0 and cvd_fast_slope < 0 and cvd_mid < 0 and price < vwap
|
||||||
long_core = (
|
|
||||||
cvd_fast > 0 and cvd_fast_slope > 0 and # CVD_fast正且上升
|
|
||||||
cvd_mid > 0 and # CVD_mid正
|
|
||||||
price > vwap # 价格在VWAP上方
|
|
||||||
)
|
|
||||||
# 做空
|
|
||||||
short_core = (
|
|
||||||
cvd_fast < 0 and cvd_fast_slope < 0 and
|
|
||||||
cvd_mid < 0 and
|
|
||||||
price < vwap
|
|
||||||
)
|
|
||||||
|
|
||||||
if not long_core and not short_core:
|
if not long_core and not short_core:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
direction = "LONG" if long_core else "SHORT"
|
direction = "LONG" if long_core else "SHORT"
|
||||||
|
|
||||||
# === 加分条件 ===
|
|
||||||
score = 0
|
score = 0
|
||||||
|
|
||||||
# 1. ATR压缩→扩张 (+25)
|
|
||||||
if atr_pct > 60:
|
if atr_pct > 60:
|
||||||
score += 25
|
score += 25
|
||||||
|
has_adverse = any(
|
||||||
# 2. 无反向P99大单 (+20)
|
(direction == "LONG" and lt[2] == 1) or (direction == "SHORT" and lt[2] == 0)
|
||||||
has_adverse_large = False
|
for lt in self.recent_large_trades
|
||||||
for lt in self.recent_large_trades:
|
)
|
||||||
if direction == "LONG" and lt[2] == 1: # 大卖单对多头不利
|
if not has_adverse:
|
||||||
has_adverse_large = True
|
|
||||||
elif direction == "SHORT" and lt[2] == 0: # 大买单对空头不利
|
|
||||||
has_adverse_large = True
|
|
||||||
if not has_adverse_large:
|
|
||||||
score += 20
|
score += 20
|
||||||
|
|
||||||
# 3. 资金费率配合 (+15) — 从rate_snapshots读取
|
|
||||||
# 暂时跳过,后续接入
|
|
||||||
# TODO: 接入资金费率条件
|
|
||||||
score += 0 # placeholder
|
|
||||||
|
|
||||||
result["signal"] = direction
|
result["signal"] = direction
|
||||||
result["direction"] = direction
|
result["direction"] = direction
|
||||||
result["score"] = score
|
result["score"] = score
|
||||||
|
|
||||||
# 更新冷却
|
|
||||||
self.last_signal_ts = now_ms
|
self.last_signal_ts = now_ms
|
||||||
self.last_signal_dir = direction
|
self.last_signal_dir = direction
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
# ─── 主循环 ──────────────────────────────────────────────────────
|
# ─── PG DB操作 ───────────────────────────────────────────────────
|
||||||
|
|
||||||
def get_month_tables(conn: sqlite3.Connection) -> list[str]:
|
def load_historical(state: SymbolState, window_ms: int):
|
||||||
rows = conn.execute(
|
|
||||||
"SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'agg_trades_2%' ORDER BY name"
|
|
||||||
).fetchall()
|
|
||||||
return [r["name"] for r in rows]
|
|
||||||
|
|
||||||
|
|
||||||
def load_historical(conn: sqlite3.Connection, state: SymbolState, window_ms: int):
|
|
||||||
"""冷启动:回灌历史数据到内存窗口"""
|
|
||||||
now_ms = int(time.time() * 1000)
|
now_ms = int(time.time() * 1000)
|
||||||
start_ms = now_ms - window_ms
|
start_ms = now_ms - window_ms
|
||||||
tables = get_month_tables(conn)
|
|
||||||
|
|
||||||
count = 0
|
count = 0
|
||||||
for tname in tables:
|
with get_sync_conn() as conn:
|
||||||
try:
|
with conn.cursor() as cur:
|
||||||
rows = conn.execute(
|
cur.execute(
|
||||||
f"SELECT agg_id, price, qty, time_ms, is_buyer_maker FROM {tname} "
|
"SELECT agg_id, price, qty, time_ms, is_buyer_maker FROM agg_trades "
|
||||||
f"WHERE symbol = ? AND time_ms >= ? ORDER BY agg_id ASC",
|
"WHERE symbol = %s AND time_ms >= %s ORDER BY agg_id ASC",
|
||||||
(state.symbol, start_ms)
|
(state.symbol, start_ms)
|
||||||
).fetchall()
|
)
|
||||||
for r in rows:
|
while True:
|
||||||
state.process_trade(r["agg_id"], r["time_ms"], r["price"], r["qty"], r["is_buyer_maker"])
|
rows = cur.fetchmany(5000)
|
||||||
count += 1
|
if not rows:
|
||||||
except Exception as e:
|
break
|
||||||
logger.warning(f"Error loading {tname}: {e}")
|
for r in rows:
|
||||||
|
state.process_trade(r[0], r[3], r[1], r[2], r[4])
|
||||||
|
count += 1
|
||||||
logger.info(f"[{state.symbol}] 冷启动完成: 加载{count:,}条历史数据 (窗口={window_ms//3600000}h)")
|
logger.info(f"[{state.symbol}] 冷启动完成: 加载{count:,}条历史数据 (窗口={window_ms//3600000}h)")
|
||||||
state.warmup = False
|
state.warmup = False
|
||||||
|
|
||||||
|
|
||||||
def fetch_new_trades(conn: sqlite3.Connection, symbol: str, last_id: int) -> list:
|
def fetch_new_trades(symbol: str, last_id: int) -> list:
|
||||||
"""增量读取新aggTrades"""
|
with get_sync_conn() as conn:
|
||||||
tables = get_month_tables(conn)
|
with conn.cursor() as cur:
|
||||||
results = []
|
cur.execute(
|
||||||
for tname in tables[-2:]: # 只查最近两个月表
|
"SELECT agg_id, price, qty, time_ms, is_buyer_maker FROM agg_trades "
|
||||||
try:
|
"WHERE symbol = %s AND agg_id > %s ORDER BY agg_id ASC LIMIT 10000",
|
||||||
rows = conn.execute(
|
|
||||||
f"SELECT agg_id, price, qty, time_ms, is_buyer_maker FROM {tname} "
|
|
||||||
f"WHERE symbol = ? AND agg_id > ? ORDER BY agg_id ASC LIMIT 10000",
|
|
||||||
(symbol, last_id)
|
(symbol, last_id)
|
||||||
).fetchall()
|
)
|
||||||
results.extend(rows)
|
return [{"agg_id": r[0], "price": r[1], "qty": r[2], "time_ms": r[3], "is_buyer_maker": r[4]}
|
||||||
except Exception:
|
for r in cur.fetchall()]
|
||||||
pass
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
def save_indicator(conn: sqlite3.Connection, ts: int, symbol: str, result: dict):
|
def save_indicator(ts: int, symbol: str, result: dict):
|
||||||
conn.execute("""
|
with get_sync_conn() as conn:
|
||||||
INSERT INTO signal_indicators
|
with conn.cursor() as cur:
|
||||||
(ts, symbol, cvd_fast, cvd_mid, cvd_day, cvd_fast_slope, atr_5m, atr_percentile,
|
cur.execute(
|
||||||
vwap_30m, price, p95_qty, p99_qty, score, signal)
|
"INSERT INTO signal_indicators "
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
"(ts,symbol,cvd_fast,cvd_mid,cvd_day,cvd_fast_slope,atr_5m,atr_percentile,vwap_30m,price,p95_qty,p99_qty,score,signal) "
|
||||||
""", (
|
"VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)",
|
||||||
ts, symbol,
|
(ts, symbol, result["cvd_fast"], result["cvd_mid"], result["cvd_day"], result["cvd_fast_slope"],
|
||||||
result["cvd_fast"], result["cvd_mid"], result["cvd_day"], result["cvd_fast_slope"],
|
result["atr"], result["atr_pct"], result["vwap"], result["price"],
|
||||||
result["atr"], result["atr_pct"],
|
result["p95"], result["p99"], result["score"], result.get("signal"))
|
||||||
result["vwap"], result["price"],
|
)
|
||||||
result["p95"], result["p99"],
|
conn.commit()
|
||||||
result["score"], result.get("signal")
|
|
||||||
))
|
|
||||||
conn.commit()
|
|
||||||
|
|
||||||
|
|
||||||
def save_indicator_1m(conn: sqlite3.Connection, ts: int, symbol: str, result: dict):
|
def save_indicator_1m(ts: int, symbol: str, result: dict):
|
||||||
"""每分钟聚合保存"""
|
|
||||||
bar_ts = (ts // 60000) * 60000
|
bar_ts = (ts // 60000) * 60000
|
||||||
existing = conn.execute(
|
with get_sync_conn() as conn:
|
||||||
"SELECT id FROM signal_indicators_1m WHERE ts = ? AND symbol = ?", (bar_ts, symbol)
|
with conn.cursor() as cur:
|
||||||
).fetchone()
|
cur.execute("SELECT id FROM signal_indicators_1m WHERE ts=%s AND symbol=%s", (bar_ts, symbol))
|
||||||
if existing:
|
if cur.fetchone():
|
||||||
conn.execute("""
|
cur.execute(
|
||||||
UPDATE signal_indicators_1m SET
|
"UPDATE signal_indicators_1m SET cvd_fast=%s,cvd_mid=%s,cvd_day=%s,atr_5m=%s,vwap_30m=%s,price=%s,score=%s,signal=%s WHERE ts=%s AND symbol=%s",
|
||||||
cvd_fast=?, cvd_mid=?, cvd_day=?, atr_5m=?, vwap_30m=?, price=?, score=?, signal=?
|
(result["cvd_fast"], result["cvd_mid"], result["cvd_day"], result["atr"], result["vwap"],
|
||||||
WHERE ts=? AND symbol=?
|
result["price"], result["score"], result.get("signal"), bar_ts, symbol)
|
||||||
""", (
|
)
|
||||||
result["cvd_fast"], result["cvd_mid"], result["cvd_day"],
|
else:
|
||||||
result["atr"], result["vwap"], result["price"],
|
cur.execute(
|
||||||
result["score"], result.get("signal"),
|
"INSERT INTO signal_indicators_1m (ts,symbol,cvd_fast,cvd_mid,cvd_day,atr_5m,vwap_30m,price,score,signal) VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)",
|
||||||
bar_ts, symbol
|
(bar_ts, symbol, result["cvd_fast"], result["cvd_mid"], result["cvd_day"], result["atr"],
|
||||||
))
|
result["vwap"], result["price"], result["score"], result.get("signal"))
|
||||||
else:
|
)
|
||||||
conn.execute("""
|
conn.commit()
|
||||||
INSERT INTO signal_indicators_1m
|
|
||||||
(ts, symbol, cvd_fast, cvd_mid, cvd_day, atr_5m, vwap_30m, price, score, signal)
|
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
||||||
""", (
|
|
||||||
bar_ts, symbol,
|
|
||||||
result["cvd_fast"], result["cvd_mid"], result["cvd_day"],
|
|
||||||
result["atr"], result["vwap"], result["price"],
|
|
||||||
result["score"], result.get("signal")
|
|
||||||
))
|
|
||||||
conn.commit()
|
|
||||||
|
|
||||||
|
|
||||||
|
# ─── 主循环 ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
conn = get_conn()
|
init_schema()
|
||||||
init_tables(conn)
|
|
||||||
|
|
||||||
states = {sym: SymbolState(sym) for sym in SYMBOLS}
|
states = {sym: SymbolState(sym) for sym in SYMBOLS}
|
||||||
|
|
||||||
# 冷启动:回灌4h数据
|
|
||||||
for sym, state in states.items():
|
for sym, state in states.items():
|
||||||
load_historical(conn, state, WINDOW_MID)
|
load_historical(state, WINDOW_MID)
|
||||||
|
|
||||||
logger.info("=== Signal Engine 启动完成 ===")
|
logger.info("=== Signal Engine (PG) 启动完成 ===")
|
||||||
|
|
||||||
last_1m_save = {}
|
last_1m_save = {}
|
||||||
cycle = 0
|
cycle = 0
|
||||||
@ -505,44 +328,27 @@ def main():
|
|||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
now_ms = int(time.time() * 1000)
|
now_ms = int(time.time() * 1000)
|
||||||
|
|
||||||
for sym, state in states.items():
|
for sym, state in states.items():
|
||||||
# 增量读取新数据
|
new_trades = fetch_new_trades(sym, state.last_processed_id)
|
||||||
new_trades = fetch_new_trades(conn, sym, state.last_processed_id)
|
|
||||||
for t in new_trades:
|
for t in new_trades:
|
||||||
state.process_trade(t["agg_id"], t["time_ms"], t["price"], t["qty"], t["is_buyer_maker"])
|
state.process_trade(t["agg_id"], t["time_ms"], t["price"], t["qty"], t["is_buyer_maker"])
|
||||||
|
|
||||||
# 评估信号
|
|
||||||
result = state.evaluate_signal(now_ms)
|
result = state.evaluate_signal(now_ms)
|
||||||
|
save_indicator(now_ms, sym, result)
|
||||||
|
|
||||||
# 保存5秒指标
|
|
||||||
save_indicator(conn, now_ms, sym, result)
|
|
||||||
|
|
||||||
# 每分钟保存聚合
|
|
||||||
bar_1m = (now_ms // 60000) * 60000
|
bar_1m = (now_ms // 60000) * 60000
|
||||||
if last_1m_save.get(sym) != bar_1m:
|
if last_1m_save.get(sym) != bar_1m:
|
||||||
save_indicator_1m(conn, now_ms, sym, result)
|
save_indicator_1m(now_ms, sym, result)
|
||||||
last_1m_save[sym] = bar_1m
|
last_1m_save[sym] = bar_1m
|
||||||
|
|
||||||
# 有信号则记录
|
|
||||||
if result.get("signal"):
|
if result.get("signal"):
|
||||||
logger.info(
|
logger.info(f"[{sym}] 🚨 信号: {result['signal']} score={result['score']} price={result['price']:.1f}")
|
||||||
f"[{sym}] 🚨 信号: {result['signal']} "
|
|
||||||
f"score={result['score']} price={result['price']:.1f} "
|
|
||||||
f"CVD_fast={result['cvd_fast']:.1f} CVD_mid={result['cvd_mid']:.1f}"
|
|
||||||
)
|
|
||||||
# TODO: Discord推送
|
|
||||||
# TODO: 写入signal_trades
|
|
||||||
|
|
||||||
cycle += 1
|
cycle += 1
|
||||||
if cycle % 60 == 0: # 每5分钟打一次状态
|
if cycle % 60 == 0:
|
||||||
for sym, state in states.items():
|
for sym, state in states.items():
|
||||||
r = state.evaluate_signal(now_ms)
|
r = state.evaluate_signal(now_ms)
|
||||||
logger.info(
|
logger.info(f"[{sym}] 状态: CVD_fast={r['cvd_fast']:.1f} CVD_mid={r['cvd_mid']:.1f} ATR={r['atr']:.2f}({r['atr_pct']:.0f}%) VWAP={r['vwap']:.1f}")
|
||||||
f"[{sym}] 状态: CVD_fast={r['cvd_fast']:.1f} CVD_mid={r['cvd_mid']:.1f} "
|
|
||||||
f"ATR={r['atr']:.2f}({r['atr_pct']:.0f}%) VWAP={r['vwap']:.1f} "
|
|
||||||
f"P95={r['p95']:.4f} P99={r['p99']:.4f}"
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"循环异常: {e}", exc_info=True)
|
logger.error(f"循环异常: {e}", exc_info=True)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user