refactor: SQLite→PostgreSQL migration - db.py连接层 + main/collector/signal-engine/backfill全部改PG
Phase 1: 核心数据表(agg_trades/rate_snapshots/signal*)迁PG auth.py暂保留SQLite(低频,不影响性能) - db.py: psycopg2同步池 + asyncpg异步池 + PG schema + 分区管理 - main.py: 全部改asyncpg查询 - collector: psycopg2 + execute_values批量写入 - signal-engine: psycopg2同步读写 - backfill: psycopg2 + ON CONFLICT DO NOTHING
This commit is contained in:
parent
23c7597a40
commit
4168c1dd88
@ -1,27 +1,29 @@
|
||||
"""
|
||||
agg_trades_collector.py — aggTrades全量采集守护进程
|
||||
agg_trades_collector.py — aggTrades全量采集守护进程(PostgreSQL版)
|
||||
|
||||
架构:
|
||||
- WebSocket主链路:实时推送,延迟<100ms
|
||||
- REST补洞:断线重连后从last_agg_id追平
|
||||
- 每分钟巡检:校验agg_id连续性,发现断档自动补洞
|
||||
- 批量写入:攒200条或1秒flush一次,减少WAL压力
|
||||
- 按月分表:agg_trades_YYYYMM,单表千万行内查询快
|
||||
- 健康接口:GET /collector/health 可监控
|
||||
- 批量写入:攒200条或1秒flush一次
|
||||
- PG分区表:按月自动分区,MVCC并发无锁冲突
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
import psycopg2
|
||||
import psycopg2.extras
|
||||
import websockets
|
||||
|
||||
from db import get_sync_conn, get_sync_pool, ensure_partitions, PG_HOST, PG_PORT, PG_DB, PG_USER, PG_PASS
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
||||
@ -32,136 +34,84 @@ logging.basicConfig(
|
||||
)
|
||||
logger = logging.getLogger("collector")
|
||||
|
||||
DB_PATH = os.path.join(os.path.dirname(__file__), "..", "arb.db")
|
||||
BINANCE_FAPI = "https://fapi.binance.com/fapi/v1"
|
||||
SYMBOLS = ["BTCUSDT", "ETHUSDT"]
|
||||
HEADERS = {"User-Agent": "Mozilla/5.0 ArbitrageEngine/3.0"}
|
||||
|
||||
# 批量写入缓冲
|
||||
_buffer: dict[str, list] = {s: [] for s in SYMBOLS}
|
||||
BATCH_SIZE = 200
|
||||
BATCH_TIMEOUT = 1.0 # seconds
|
||||
BATCH_TIMEOUT = 1.0
|
||||
|
||||
|
||||
# ─── DB helpers ──────────────────────────────────────────────────
|
||||
|
||||
def get_conn() -> sqlite3.Connection:
|
||||
conn = sqlite3.connect(DB_PATH, timeout=30)
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("PRAGMA synchronous=NORMAL")
|
||||
return conn
|
||||
|
||||
|
||||
def table_name(ts_ms: int) -> str:
|
||||
"""按月分表:agg_trades_202602"""
|
||||
dt = datetime.fromtimestamp(ts_ms / 1000, tz=timezone.utc)
|
||||
return f"agg_trades_{dt.strftime('%Y%m')}"
|
||||
|
||||
|
||||
def ensure_table(conn: sqlite3.Connection, tname: str):
|
||||
conn.execute(f"""
|
||||
CREATE TABLE IF NOT EXISTS {tname} (
|
||||
agg_id INTEGER PRIMARY KEY,
|
||||
symbol TEXT NOT NULL,
|
||||
price REAL NOT NULL,
|
||||
qty REAL NOT NULL,
|
||||
first_trade_id INTEGER,
|
||||
last_trade_id INTEGER,
|
||||
time_ms INTEGER NOT NULL,
|
||||
is_buyer_maker INTEGER NOT NULL
|
||||
)
|
||||
""")
|
||||
conn.execute(f"""
|
||||
CREATE INDEX IF NOT EXISTS idx_{tname}_sym_time
|
||||
ON {tname}(symbol, time_ms)
|
||||
""")
|
||||
|
||||
|
||||
def ensure_meta_table(conn: sqlite3.Connection):
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS agg_trades_meta (
|
||||
symbol TEXT PRIMARY KEY,
|
||||
last_agg_id INTEGER NOT NULL,
|
||||
last_time_ms INTEGER NOT NULL,
|
||||
updated_at TEXT DEFAULT (datetime('now'))
|
||||
)
|
||||
""")
|
||||
|
||||
|
||||
def get_last_agg_id(symbol: str) -> Optional[int]:
|
||||
try:
|
||||
conn = get_conn()
|
||||
ensure_meta_table(conn)
|
||||
row = conn.execute(
|
||||
"SELECT last_agg_id FROM agg_trades_meta WHERE symbol = ?", (symbol,)
|
||||
).fetchone()
|
||||
conn.close()
|
||||
return row["last_agg_id"] if row else None
|
||||
with get_sync_conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("SELECT last_agg_id FROM agg_trades_meta WHERE symbol = %s", (symbol,))
|
||||
row = cur.fetchone()
|
||||
return row[0] if row else None
|
||||
except Exception as e:
|
||||
logger.error(f"get_last_agg_id error: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def update_meta(conn: sqlite3.Connection, symbol: str, last_agg_id: int, last_time_ms: int):
|
||||
conn.execute("""
|
||||
def update_meta(conn, symbol: str, last_agg_id: int, last_time_ms: int):
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("""
|
||||
INSERT INTO agg_trades_meta (symbol, last_agg_id, last_time_ms, updated_at)
|
||||
VALUES (?, ?, ?, datetime('now'))
|
||||
VALUES (%s, %s, %s, NOW())
|
||||
ON CONFLICT(symbol) DO UPDATE SET
|
||||
last_agg_id = excluded.last_agg_id,
|
||||
last_time_ms = excluded.last_time_ms,
|
||||
updated_at = excluded.updated_at
|
||||
last_agg_id = EXCLUDED.last_agg_id,
|
||||
last_time_ms = EXCLUDED.last_time_ms,
|
||||
updated_at = NOW()
|
||||
""", (symbol, last_agg_id, last_time_ms))
|
||||
|
||||
|
||||
def flush_buffer(symbol: str, trades: list) -> int:
|
||||
"""写入一批trades,返回实际写入条数(去重后)"""
|
||||
"""写入一批trades到PG,返回实际写入条数"""
|
||||
if not trades:
|
||||
return 0
|
||||
try:
|
||||
conn = get_conn()
|
||||
ensure_meta_table(conn)
|
||||
# 按月分组
|
||||
by_month: dict[str, list] = {}
|
||||
for t in trades:
|
||||
tname = table_name(t["T"])
|
||||
if tname not in by_month:
|
||||
by_month[tname] = []
|
||||
by_month[tname].append(t)
|
||||
# 确保分区存在
|
||||
ensure_partitions()
|
||||
|
||||
inserted = 0
|
||||
with get_sync_conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
# 批量插入(ON CONFLICT忽略重复)
|
||||
values = []
|
||||
last_agg_id = 0
|
||||
last_time_ms = 0
|
||||
|
||||
for tname, batch in by_month.items():
|
||||
ensure_table(conn, tname)
|
||||
for t in batch:
|
||||
cur = conn.execute(
|
||||
f"""INSERT OR IGNORE INTO {tname}
|
||||
(agg_id, symbol, price, qty, first_trade_id, last_trade_id, time_ms, is_buyer_maker)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
(
|
||||
t["a"], # agg_id
|
||||
symbol,
|
||||
float(t["p"]),
|
||||
float(t["q"]),
|
||||
t.get("f"), # first_trade_id
|
||||
t.get("l"), # last_trade_id
|
||||
t["T"], # time_ms
|
||||
1 if t["m"] else 0, # is_buyer_maker
|
||||
for t in trades:
|
||||
agg_id = t["a"]
|
||||
time_ms = t["T"]
|
||||
values.append((
|
||||
agg_id, symbol,
|
||||
float(t["p"]), float(t["q"]),
|
||||
time_ms,
|
||||
1 if t["m"] else 0,
|
||||
))
|
||||
if agg_id > last_agg_id:
|
||||
last_agg_id = agg_id
|
||||
last_time_ms = time_ms
|
||||
|
||||
# 批量INSERT
|
||||
psycopg2.extras.execute_values(
|
||||
cur,
|
||||
"""INSERT INTO agg_trades (agg_id, symbol, price, qty, time_ms, is_buyer_maker)
|
||||
VALUES %s
|
||||
ON CONFLICT (time_ms, symbol, agg_id) DO NOTHING""",
|
||||
values,
|
||||
template="(%s, %s, %s, %s, %s, %s)",
|
||||
page_size=1000,
|
||||
)
|
||||
)
|
||||
if cur.rowcount > 0:
|
||||
inserted += 1
|
||||
if t["a"] > last_agg_id:
|
||||
last_agg_id = t["a"]
|
||||
last_time_ms = t["T"]
|
||||
inserted = cur.rowcount
|
||||
|
||||
if last_agg_id > 0:
|
||||
update_meta(conn, symbol, last_agg_id, last_time_ms)
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return inserted
|
||||
except Exception as e:
|
||||
logger.error(f"flush_buffer [{symbol}] error: {e}")
|
||||
@ -171,7 +121,6 @@ def flush_buffer(symbol: str, trades: list) -> int:
|
||||
# ─── REST补洞 ────────────────────────────────────────────────────
|
||||
|
||||
async def rest_catchup(symbol: str, from_id: int) -> int:
|
||||
"""从from_id开始REST拉取,追平到最新,返回补洞条数"""
|
||||
total = 0
|
||||
current_id = from_id
|
||||
logger.info(f"[{symbol}] REST catchup from agg_id={from_id}")
|
||||
@ -195,10 +144,9 @@ async def rest_catchup(symbol: str, from_id: int) -> int:
|
||||
if last <= current_id:
|
||||
break
|
||||
current_id = last + 1
|
||||
# 如果拉到的比最新少1000条,说明追平了
|
||||
if len(data) < 1000:
|
||||
break
|
||||
await asyncio.sleep(0.1) # rate limit友好
|
||||
await asyncio.sleep(0.1)
|
||||
except Exception as e:
|
||||
logger.error(f"[{symbol}] REST catchup error: {e}")
|
||||
break
|
||||
@ -210,7 +158,6 @@ async def rest_catchup(symbol: str, from_id: int) -> int:
|
||||
# ─── WebSocket采集 ───────────────────────────────────────────────
|
||||
|
||||
async def ws_collect(symbol: str):
|
||||
"""单Symbol的WS采集循环,自动断线重连+REST补洞"""
|
||||
stream = symbol.lower() + "@aggTrade"
|
||||
url = f"wss://fstream.binance.com/ws/{stream}"
|
||||
buffer: list = []
|
||||
@ -219,14 +166,13 @@ async def ws_collect(symbol: str):
|
||||
|
||||
while True:
|
||||
try:
|
||||
# 断线重连前先REST补洞
|
||||
last_id = get_last_agg_id(symbol)
|
||||
if last_id is not None:
|
||||
await rest_catchup(symbol, last_id + 1)
|
||||
|
||||
logger.info(f"[{symbol}] Connecting WS: {url}")
|
||||
async with websockets.connect(url, ping_interval=20, ping_timeout=10) as ws:
|
||||
reconnect_delay = 1.0 # 连上了就重置
|
||||
reconnect_delay = 1.0
|
||||
logger.info(f"[{symbol}] WS connected")
|
||||
|
||||
async for raw in ws:
|
||||
@ -236,7 +182,6 @@ async def ws_collect(symbol: str):
|
||||
|
||||
buffer.append(msg)
|
||||
|
||||
# 批量flush:满200条或超1秒
|
||||
now = time.time()
|
||||
if len(buffer) >= BATCH_SIZE or (now - last_flush) >= BATCH_TIMEOUT:
|
||||
count = flush_buffer(symbol, buffer)
|
||||
@ -250,110 +195,90 @@ async def ws_collect(symbol: str):
|
||||
except Exception as e:
|
||||
logger.error(f"[{symbol}] WS error: {e}, reconnecting in {reconnect_delay}s")
|
||||
finally:
|
||||
# flush剩余buffer
|
||||
if buffer:
|
||||
flush_buffer(symbol, buffer)
|
||||
buffer.clear()
|
||||
|
||||
await asyncio.sleep(reconnect_delay)
|
||||
reconnect_delay = min(reconnect_delay * 2, 30) # exponential backoff, max 30s
|
||||
reconnect_delay = min(reconnect_delay * 2, 30)
|
||||
|
||||
|
||||
# ─── 连续性巡检 ──────────────────────────────────────────────────
|
||||
|
||||
async def continuity_check():
|
||||
"""每60秒巡检一次:检查各symbol最近的agg_id是否有断档"""
|
||||
while True:
|
||||
await asyncio.sleep(60)
|
||||
try:
|
||||
conn = get_conn()
|
||||
ensure_meta_table(conn)
|
||||
with get_sync_conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
for symbol in SYMBOLS:
|
||||
row = conn.execute(
|
||||
"SELECT last_agg_id, last_time_ms FROM agg_trades_meta WHERE symbol = ?",
|
||||
cur.execute(
|
||||
"SELECT last_agg_id, last_time_ms FROM agg_trades_meta WHERE symbol = %s",
|
||||
(symbol,)
|
||||
).fetchone()
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
continue
|
||||
# 检查最近1000条是否连续
|
||||
now_month = datetime.now(tz=timezone.utc).strftime("%Y%m")
|
||||
tname = f"agg_trades_{now_month}"
|
||||
try:
|
||||
rows = conn.execute(
|
||||
f"SELECT agg_id FROM {tname} WHERE symbol = ? ORDER BY agg_id DESC LIMIT 100",
|
||||
# 检查最近100条是否连续
|
||||
cur.execute(
|
||||
"SELECT agg_id FROM agg_trades WHERE symbol = %s ORDER BY agg_id DESC LIMIT 100",
|
||||
(symbol,)
|
||||
).fetchall()
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
if len(rows) < 2:
|
||||
continue
|
||||
ids = [r["agg_id"] for r in rows]
|
||||
# 检查是否连续(降序)
|
||||
ids = [r[0] for r in rows]
|
||||
gaps = []
|
||||
for i in range(len(ids) - 1):
|
||||
diff = ids[i] - ids[i + 1]
|
||||
if diff > 1:
|
||||
gaps.append((ids[i + 1], ids[i], diff - 1))
|
||||
if gaps:
|
||||
logger.warning(f"[{symbol}] Found {len(gaps)} gaps in recent data: {gaps[:3]}")
|
||||
# 触发补洞
|
||||
logger.warning(f"[{symbol}] Found {len(gaps)} gaps: {gaps[:3]}")
|
||||
min_gap_id = min(g[0] for g in gaps)
|
||||
asyncio.create_task(rest_catchup(symbol, min_gap_id))
|
||||
else:
|
||||
logger.debug(f"[{symbol}] Continuity OK, last_agg_id={row['last_agg_id']}")
|
||||
except Exception:
|
||||
pass # 表可能还不存在
|
||||
conn.close()
|
||||
logger.debug(f"[{symbol}] Continuity OK, last_agg_id={row[0]}")
|
||||
except Exception as e:
|
||||
logger.error(f"Continuity check error: {e}")
|
||||
|
||||
|
||||
# ─── 每日完整性报告 ──────────────────────────────────────────────
|
||||
# ─── 每小时报告 ──────────────────────────────────────────────────
|
||||
|
||||
async def daily_report():
|
||||
"""每小时生成一次完整性摘要日志"""
|
||||
while True:
|
||||
await asyncio.sleep(3600)
|
||||
try:
|
||||
conn = get_conn()
|
||||
ensure_meta_table(conn)
|
||||
report_lines = ["=== AggTrades Integrity Report ==="]
|
||||
with get_sync_conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
report = ["=== AggTrades Integrity Report ==="]
|
||||
for symbol in SYMBOLS:
|
||||
row = conn.execute(
|
||||
"SELECT last_agg_id, last_time_ms FROM agg_trades_meta WHERE symbol = ?",
|
||||
cur.execute(
|
||||
"SELECT last_agg_id, last_time_ms FROM agg_trades_meta WHERE symbol = %s",
|
||||
(symbol,)
|
||||
).fetchone()
|
||||
if not row:
|
||||
report_lines.append(f" {symbol}: No data yet")
|
||||
continue
|
||||
last_dt = datetime.fromtimestamp(row["last_time_ms"] / 1000, tz=timezone.utc).isoformat()
|
||||
# 统计本月总量
|
||||
now_month = datetime.now(tz=timezone.utc).strftime("%Y%m")
|
||||
tname = f"agg_trades_{now_month}"
|
||||
try:
|
||||
count = conn.execute(
|
||||
f"SELECT COUNT(*) as c FROM {tname} WHERE symbol = ?", (symbol,)
|
||||
).fetchone()["c"]
|
||||
except Exception:
|
||||
count = 0
|
||||
report_lines.append(
|
||||
f" {symbol}: last_agg_id={row['last_agg_id']}, last_time={last_dt}, month_count={count:,}"
|
||||
)
|
||||
conn.close()
|
||||
logger.info("\n".join(report_lines))
|
||||
row = cur.fetchone()
|
||||
if not row:
|
||||
report.append(f" {symbol}: No data yet")
|
||||
continue
|
||||
last_dt = datetime.fromtimestamp(row[1] / 1000, tz=timezone.utc).isoformat()
|
||||
cur.execute(
|
||||
"SELECT COUNT(*) FROM agg_trades WHERE symbol = %s", (symbol,)
|
||||
)
|
||||
count = cur.fetchone()[0]
|
||||
report.append(f" {symbol}: last_agg_id={row[0]}, last_time={last_dt}, total={count:,}")
|
||||
logger.info("\n".join(report))
|
||||
except Exception as e:
|
||||
logger.error(f"Daily report error: {e}")
|
||||
logger.error(f"Report error: {e}")
|
||||
|
||||
|
||||
# ─── 入口 ────────────────────────────────────────────────────────
|
||||
|
||||
async def main():
|
||||
logger.info("AggTrades Collector starting...")
|
||||
# 确保基础表存在
|
||||
conn = get_conn()
|
||||
ensure_meta_table(conn)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
logger.info("AggTrades Collector (PG) starting...")
|
||||
# 确保分区存在
|
||||
ensure_partitions()
|
||||
|
||||
# 并行启动所有symbol的WS + 巡检 + 报告
|
||||
tasks = [
|
||||
ws_collect(sym) for sym in SYMBOLS
|
||||
] + [
|
||||
|
||||
@ -1,27 +1,18 @@
|
||||
"""
|
||||
backfill_agg_trades.py — 历史aggTrades回补脚本
|
||||
|
||||
功能:
|
||||
- 从当前DB最早agg_id向历史方向回补
|
||||
- Binance REST API分页拉取,每次1000条
|
||||
- INSERT OR IGNORE写入按月分表
|
||||
- 断点续传:记录进度到agg_trades_meta
|
||||
- 速率控制:sleep 200ms/请求,429自动退避
|
||||
- BTC+ETH并行回补
|
||||
|
||||
用法:
|
||||
python3 backfill_agg_trades.py [--days 30] [--symbol BTCUSDT]
|
||||
backfill_agg_trades.py — 历史aggTrades回补脚本(PostgreSQL版)
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import requests
|
||||
import psycopg2.extras
|
||||
|
||||
from db import get_sync_conn, init_schema, ensure_partitions
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
@ -33,119 +24,56 @@ logging.basicConfig(
|
||||
)
|
||||
logger = logging.getLogger("backfill")
|
||||
|
||||
DB_PATH = os.path.join(os.path.dirname(__file__), "..", "arb.db")
|
||||
BINANCE_FAPI = "https://fapi.binance.com/fapi/v1"
|
||||
HEADERS = {"User-Agent": "Mozilla/5.0 ArbitrageEngine/backfill"}
|
||||
BATCH_SIZE = 1000
|
||||
SLEEP_MS = 2000 # 低优先级:2秒/请求,让出锁给实时服务
|
||||
SLEEP_MS = 2000
|
||||
|
||||
|
||||
# ─── DB helpers ──────────────────────────────────────────────────
|
||||
|
||||
def get_conn() -> sqlite3.Connection:
|
||||
conn = sqlite3.connect(DB_PATH, timeout=30)
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("PRAGMA synchronous=NORMAL")
|
||||
return conn
|
||||
def get_earliest_agg_id(symbol: str) -> int | None:
|
||||
with get_sync_conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("SELECT earliest_agg_id FROM agg_trades_meta WHERE symbol = %s", (symbol,))
|
||||
row = cur.fetchone()
|
||||
if row and row[0]:
|
||||
return row[0]
|
||||
# fallback: scan agg_trades
|
||||
cur.execute("SELECT MIN(agg_id) FROM agg_trades WHERE symbol = %s", (symbol,))
|
||||
row = cur.fetchone()
|
||||
return row[0] if row else None
|
||||
|
||||
|
||||
def table_name(ts_ms: int) -> str:
|
||||
dt = datetime.fromtimestamp(ts_ms / 1000, tz=timezone.utc)
|
||||
return f"agg_trades_{dt.strftime('%Y%m')}"
|
||||
|
||||
|
||||
def ensure_table(conn: sqlite3.Connection, tname: str):
|
||||
conn.execute(f"""
|
||||
CREATE TABLE IF NOT EXISTS {tname} (
|
||||
agg_id INTEGER PRIMARY KEY,
|
||||
symbol TEXT NOT NULL,
|
||||
price REAL NOT NULL,
|
||||
qty REAL NOT NULL,
|
||||
time_ms INTEGER NOT NULL,
|
||||
is_buyer_maker INTEGER NOT NULL
|
||||
)
|
||||
""")
|
||||
conn.execute(f"CREATE INDEX IF NOT EXISTS idx_{tname}_symbol_time ON {tname}(symbol, time_ms)")
|
||||
|
||||
|
||||
def ensure_meta(conn: sqlite3.Connection):
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS agg_trades_meta (
|
||||
symbol TEXT PRIMARY KEY,
|
||||
last_agg_id INTEGER,
|
||||
last_time_ms INTEGER,
|
||||
earliest_agg_id INTEGER,
|
||||
earliest_time_ms INTEGER
|
||||
)
|
||||
""")
|
||||
# 添加earliest字段(如果旧表没有)
|
||||
try:
|
||||
conn.execute("ALTER TABLE agg_trades_meta ADD COLUMN earliest_agg_id INTEGER")
|
||||
conn.execute("ALTER TABLE agg_trades_meta ADD COLUMN earliest_time_ms INTEGER")
|
||||
except Exception:
|
||||
pass # 已存在
|
||||
|
||||
|
||||
def get_earliest_agg_id(conn: sqlite3.Connection, symbol: str) -> int | None:
|
||||
"""查找DB中该symbol的最小agg_id"""
|
||||
# 先查meta表
|
||||
row = conn.execute(
|
||||
"SELECT earliest_agg_id FROM agg_trades_meta WHERE symbol = ?", (symbol,)
|
||||
).fetchone()
|
||||
if row and row["earliest_agg_id"]:
|
||||
return row["earliest_agg_id"]
|
||||
|
||||
# meta表没有,扫所有月表
|
||||
tables = conn.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'agg_trades_2%'"
|
||||
).fetchall()
|
||||
min_id = None
|
||||
for t in tables:
|
||||
r = conn.execute(
|
||||
f"SELECT MIN(agg_id) as mid FROM {t['name']} WHERE symbol = ?", (symbol,)
|
||||
).fetchone()
|
||||
if r and r["mid"] is not None:
|
||||
if min_id is None or r["mid"] < min_id:
|
||||
min_id = r["mid"]
|
||||
return min_id
|
||||
|
||||
|
||||
def update_earliest_meta(conn: sqlite3.Connection, symbol: str, agg_id: int, time_ms: int):
|
||||
# 先检查是否已有该symbol的记录
|
||||
row = conn.execute("SELECT symbol FROM agg_trades_meta WHERE symbol = ?", (symbol,)).fetchone()
|
||||
if row:
|
||||
conn.execute("""
|
||||
def update_earliest_meta(symbol: str, agg_id: int, time_ms: int):
|
||||
with get_sync_conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("SELECT symbol FROM agg_trades_meta WHERE symbol = %s", (symbol,))
|
||||
if cur.fetchone():
|
||||
cur.execute("""
|
||||
UPDATE agg_trades_meta SET
|
||||
earliest_agg_id = MIN(?, COALESCE(earliest_agg_id, ?)),
|
||||
earliest_time_ms = MIN(?, COALESCE(earliest_time_ms, ?))
|
||||
WHERE symbol = ?
|
||||
earliest_agg_id = LEAST(%s, COALESCE(earliest_agg_id, %s)),
|
||||
earliest_time_ms = LEAST(%s, COALESCE(earliest_time_ms, %s))
|
||||
WHERE symbol = %s
|
||||
""", (agg_id, agg_id, time_ms, time_ms, symbol))
|
||||
else:
|
||||
conn.execute("""
|
||||
cur.execute("""
|
||||
INSERT INTO agg_trades_meta (symbol, last_agg_id, last_time_ms, earliest_agg_id, earliest_time_ms)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
VALUES (%s, %s, %s, %s, %s)
|
||||
""", (symbol, agg_id, time_ms, agg_id, time_ms))
|
||||
conn.commit()
|
||||
|
||||
|
||||
# ─── REST API ────────────────────────────────────────────────────
|
||||
|
||||
def fetch_agg_trades(symbol: str, from_id: int | None = None,
|
||||
start_time: int | None = None, limit: int = 1000) -> list[dict]:
|
||||
"""从Binance拉取aggTrades,支持fromId和startTime"""
|
||||
def fetch_agg_trades(symbol: str, from_id: int | None = None, limit: int = 1000) -> list:
|
||||
params = {"symbol": symbol, "limit": limit}
|
||||
if from_id is not None:
|
||||
params["fromId"] = from_id
|
||||
elif start_time is not None:
|
||||
params["startTime"] = start_time
|
||||
|
||||
for attempt in range(5):
|
||||
try:
|
||||
r = requests.get(
|
||||
f"{BINANCE_FAPI}/aggTrades",
|
||||
params=params, headers=HEADERS, timeout=15
|
||||
)
|
||||
r = requests.get(f"{BINANCE_FAPI}/aggTrades", params=params, headers=HEADERS, timeout=15)
|
||||
if r.status_code == 429:
|
||||
wait = min(60 * (2 ** attempt), 300)
|
||||
logger.warning(f"Rate limited (429), waiting {wait}s (attempt {attempt+1})")
|
||||
@ -162,24 +90,22 @@ def fetch_agg_trades(symbol: str, from_id: int | None = None,
|
||||
return []
|
||||
|
||||
|
||||
def write_batch(conn: sqlite3.Connection, symbol: str, trades: list[dict]) -> int:
|
||||
"""批量写入,返回实际插入条数"""
|
||||
inserted = 0
|
||||
tables_seen = set()
|
||||
for t in trades:
|
||||
tname = table_name(t["T"])
|
||||
if tname not in tables_seen:
|
||||
ensure_table(conn, tname)
|
||||
tables_seen.add(tname)
|
||||
try:
|
||||
conn.execute(
|
||||
f"INSERT OR IGNORE INTO {tname} (agg_id, symbol, price, qty, time_ms, is_buyer_maker) "
|
||||
f"VALUES (?, ?, ?, ?, ?, ?)",
|
||||
(t["a"], symbol, float(t["p"]), float(t["q"]), t["T"], 1 if t["m"] else 0)
|
||||
def write_batch(symbol: str, trades: list) -> int:
|
||||
if not trades:
|
||||
return 0
|
||||
ensure_partitions()
|
||||
values = [(t["a"], symbol, float(t["p"]), float(t["q"]), t["T"], 1 if t["m"] else 0) for t in trades]
|
||||
with get_sync_conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
psycopg2.extras.execute_values(
|
||||
cur,
|
||||
"INSERT INTO agg_trades (agg_id, symbol, price, qty, time_ms, is_buyer_maker) "
|
||||
"VALUES %s ON CONFLICT (time_ms, symbol, agg_id) DO NOTHING",
|
||||
values,
|
||||
template="(%s, %s, %s, %s, %s, %s)",
|
||||
page_size=1000,
|
||||
)
|
||||
inserted += 1
|
||||
except sqlite3.IntegrityError:
|
||||
pass
|
||||
inserted = cur.rowcount
|
||||
conn.commit()
|
||||
return inserted
|
||||
|
||||
@ -187,42 +113,24 @@ def write_batch(conn: sqlite3.Connection, symbol: str, trades: list[dict]) -> in
|
||||
# ─── 主逻辑 ──────────────────────────────────────────────────────
|
||||
|
||||
def backfill_symbol(symbol: str, target_days: int | None = None):
|
||||
"""回补单个symbol的历史数据"""
|
||||
conn = get_conn()
|
||||
ensure_meta(conn)
|
||||
|
||||
earliest = get_earliest_agg_id(conn, symbol)
|
||||
earliest = get_earliest_agg_id(symbol)
|
||||
if earliest is None:
|
||||
logger.info(f"[{symbol}] DB无数据,从最新开始拉最近的数据确定起点")
|
||||
logger.info(f"[{symbol}] DB无数据,拉最新确定起点")
|
||||
trades = fetch_agg_trades(symbol, limit=1)
|
||||
if not trades:
|
||||
logger.error(f"[{symbol}] 无法获取起始数据")
|
||||
return
|
||||
earliest = trades[0]["a"]
|
||||
logger.info(f"[{symbol}] 当前最新agg_id: {earliest}")
|
||||
|
||||
# 计算目标:往前补到多少天前
|
||||
if target_days:
|
||||
target_ts = int((time.time() - target_days * 86400) * 1000)
|
||||
else:
|
||||
target_ts = 0 # 拉到最早
|
||||
target_ts = int((time.time() - target_days * 86400) * 1000) if target_days else 0
|
||||
|
||||
logger.info(f"[{symbol}] 开始回补,当前最早agg_id={earliest}")
|
||||
if target_days:
|
||||
target_dt = datetime.fromtimestamp(target_ts / 1000, tz=timezone.utc)
|
||||
logger.info(f"[{symbol}] 目标:回补到 {target_dt.strftime('%Y-%m-%d')} ({target_days}天前)")
|
||||
|
||||
total_inserted = 0
|
||||
total_requests = 0
|
||||
rate_limit_hits = 0
|
||||
current_from_id = earliest
|
||||
|
||||
while True:
|
||||
# 从current_from_id向前拉:先拉current_from_id - BATCH_SIZE*2的位置
|
||||
# Binance fromId是从该id开始正向拉,所以我们用endTime方式
|
||||
# 更可靠:用fromId = current_from_id - BATCH_SIZE,然后正向拉
|
||||
fetch_from = max(0, current_from_id - BATCH_SIZE)
|
||||
|
||||
trades = fetch_agg_trades(symbol, from_id=fetch_from, limit=BATCH_SIZE)
|
||||
total_requests += 1
|
||||
|
||||
@ -230,104 +138,56 @@ def backfill_symbol(symbol: str, target_days: int | None = None):
|
||||
logger.info(f"[{symbol}] 无更多数据,回补结束")
|
||||
break
|
||||
|
||||
# 过滤掉已有的(>= earliest的数据实时采集器已覆盖)
|
||||
new_trades = [t for t in trades if t["a"] < current_from_id]
|
||||
|
||||
if not new_trades:
|
||||
# 没有更早的数据了
|
||||
logger.info(f"[{symbol}] 已到达最早数据,回补结束")
|
||||
break
|
||||
|
||||
inserted = write_batch(conn, symbol, new_trades)
|
||||
inserted = write_batch(symbol, new_trades)
|
||||
total_inserted += inserted
|
||||
|
||||
oldest = min(new_trades, key=lambda x: x["a"])
|
||||
oldest_time = datetime.fromtimestamp(oldest["T"] / 1000, tz=timezone.utc)
|
||||
current_from_id = oldest["a"]
|
||||
update_earliest_meta(symbol, oldest["a"], oldest["T"])
|
||||
|
||||
# 更新meta
|
||||
update_earliest_meta(conn, symbol, oldest["a"], oldest["T"])
|
||||
|
||||
# 进度日志(每50批打一次)
|
||||
if total_requests % 50 == 0:
|
||||
logger.info(
|
||||
f"[{symbol}] 进度: {total_inserted:,} 条已插入, "
|
||||
f"当前位置: {oldest_time.strftime('%Y-%m-%d %H:%M')}, "
|
||||
f"agg_id={current_from_id:,}, "
|
||||
f"请求数={total_requests}, 429次数={rate_limit_hits}"
|
||||
f"当前位置: {oldest_time.strftime('%Y-%m-%d %H:%M')}, agg_id={current_from_id:,}, 请求数={total_requests}"
|
||||
)
|
||||
|
||||
# 检查是否到达目标时间
|
||||
if target_ts and oldest["T"] <= target_ts:
|
||||
logger.info(f"[{symbol}] 已达到目标时间,回补结束")
|
||||
break
|
||||
|
||||
time.sleep(SLEEP_MS / 1000)
|
||||
|
||||
conn.close()
|
||||
|
||||
logger.info(
|
||||
f"[{symbol}] 回补完成: "
|
||||
f"总插入={total_inserted:,}, 总请求={total_requests}, "
|
||||
f"429次数={rate_limit_hits}"
|
||||
)
|
||||
return total_inserted, total_requests, rate_limit_hits
|
||||
logger.info(f"[{symbol}] 回补完成: 总插入={total_inserted:,}, 总请求={total_requests}")
|
||||
|
||||
|
||||
def check_continuity(symbol: str):
|
||||
"""检查agg_id连续性"""
|
||||
conn = get_conn()
|
||||
tables = conn.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'agg_trades_2%' ORDER BY name"
|
||||
).fetchall()
|
||||
|
||||
all_ids = []
|
||||
for t in tables:
|
||||
rows = conn.execute(
|
||||
f"SELECT agg_id FROM {t['name']} WHERE symbol = ? ORDER BY agg_id",
|
||||
(symbol,)
|
||||
).fetchall()
|
||||
all_ids.extend(r["agg_id"] for r in rows)
|
||||
|
||||
conn.close()
|
||||
|
||||
if not all_ids:
|
||||
with get_sync_conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("SELECT COUNT(*), MIN(agg_id), MAX(agg_id) FROM agg_trades WHERE symbol = %s", (symbol,))
|
||||
row = cur.fetchone()
|
||||
total, min_id, max_id = row
|
||||
if not total:
|
||||
logger.info(f"[{symbol}] 无数据")
|
||||
return
|
||||
|
||||
all_ids.sort()
|
||||
gaps = 0
|
||||
gap_ranges = []
|
||||
for i in range(1, len(all_ids)):
|
||||
diff = all_ids[i] - all_ids[i-1]
|
||||
if diff > 1:
|
||||
gaps += 1
|
||||
if len(gap_ranges) < 10:
|
||||
gap_ranges.append((all_ids[i-1], all_ids[i], diff - 1))
|
||||
|
||||
total = len(all_ids)
|
||||
span = all_ids[-1] - all_ids[0] + 1
|
||||
span = max_id - min_id + 1
|
||||
coverage = total / span * 100 if span > 0 else 0
|
||||
logger.info(f"[{symbol}] 总条数={total:,}, ID范围={min_id:,}~{max_id:,}, 覆盖率={coverage:.2f}%")
|
||||
|
||||
logger.info(
|
||||
f"[{symbol}] 连续性检查: "
|
||||
f"总条数={total:,}, ID范围={all_ids[0]:,}~{all_ids[-1]:,}, "
|
||||
f"理论条数={span:,}, 覆盖率={coverage:.2f}%, 缺口数={gaps}"
|
||||
)
|
||||
if gap_ranges:
|
||||
for start, end, missing in gap_ranges[:5]:
|
||||
logger.info(f" 缺口: {start} → {end} (缺{missing}条)")
|
||||
|
||||
|
||||
# ─── CLI ─────────────────────────────────────────────────────────
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Backfill aggTrades from Binance")
|
||||
parser.add_argument("--days", type=int, default=None, help="回补天数(默认全量)")
|
||||
parser.add_argument("--symbol", type=str, default=None, help="指定symbol(默认BTC+ETH)")
|
||||
parser.add_argument("--check", action="store_true", help="仅检查连续性")
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--days", type=int, default=None)
|
||||
parser.add_argument("--symbol", type=str, default=None)
|
||||
parser.add_argument("--check", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
init_schema()
|
||||
symbols = [args.symbol] if args.symbol else ["BTCUSDT", "ETHUSDT"]
|
||||
|
||||
if args.check:
|
||||
@ -336,18 +196,12 @@ def main():
|
||||
return
|
||||
|
||||
logger.info(f"=== 开始回补 === symbols={symbols}, days={args.days or '全量'}")
|
||||
|
||||
for sym in symbols:
|
||||
logger.info(f"--- 回补 {sym} ---")
|
||||
result = backfill_symbol(sym, target_days=args.days)
|
||||
if result:
|
||||
inserted, reqs, limits = result
|
||||
logger.info(f"[{sym}] 结果: 插入{inserted:,}条, {reqs}次请求, {limits}次429")
|
||||
backfill_symbol(sym, target_days=args.days)
|
||||
|
||||
logger.info("=== 回补完成,开始连续性检查 ===")
|
||||
logger.info("=== 连续性检查 ===")
|
||||
for sym in symbols:
|
||||
check_continuity(sym)
|
||||
|
||||
logger.info("=== 全部完成 ===")
|
||||
|
||||
|
||||
|
||||
282
backend/db.py
Normal file
282
backend/db.py
Normal file
@ -0,0 +1,282 @@
|
||||
"""
|
||||
db.py — PostgreSQL 数据库连接层
|
||||
同步连接池(psycopg2)供脚本类使用
|
||||
异步连接池(asyncpg)供FastAPI使用
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import asyncpg
|
||||
import psycopg2
|
||||
import psycopg2.pool
|
||||
from contextlib import contextmanager
|
||||
|
||||
# PG连接参数
|
||||
PG_HOST = os.getenv("PG_HOST", "127.0.0.1")
|
||||
PG_PORT = int(os.getenv("PG_PORT", 5432))
|
||||
PG_DB = os.getenv("PG_DB", "arb_engine")
|
||||
PG_USER = os.getenv("PG_USER", "arb")
|
||||
PG_PASS = os.getenv("PG_PASS", "arb_engine_2026")
|
||||
|
||||
PG_DSN = f"postgresql://{PG_USER}:{PG_PASS}@{PG_HOST}:{PG_PORT}/{PG_DB}"
|
||||
|
||||
# ─── 同步连接池(psycopg2)─────────────────────────────────────
|
||||
|
||||
_sync_pool = None
|
||||
|
||||
def get_sync_pool() -> psycopg2.pool.ThreadedConnectionPool:
|
||||
global _sync_pool
|
||||
if _sync_pool is None:
|
||||
_sync_pool = psycopg2.pool.ThreadedConnectionPool(
|
||||
minconn=1, maxconn=5,
|
||||
host=PG_HOST, port=PG_PORT,
|
||||
dbname=PG_DB, user=PG_USER, password=PG_PASS,
|
||||
)
|
||||
return _sync_pool
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_sync_conn():
|
||||
"""获取同步PG连接(自动归还到池)"""
|
||||
pool = get_sync_pool()
|
||||
conn = pool.getconn()
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
pool.putconn(conn)
|
||||
|
||||
|
||||
def sync_execute(sql: str, params=None, fetch=False):
|
||||
"""简便执行SQL"""
|
||||
with get_sync_conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(sql, params)
|
||||
if fetch:
|
||||
cols = [desc[0] for desc in cur.description]
|
||||
return [dict(zip(cols, row)) for row in cur.fetchall()]
|
||||
conn.commit()
|
||||
|
||||
|
||||
def sync_executemany(sql: str, params_list: list):
|
||||
"""批量执行"""
|
||||
with get_sync_conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.executemany(sql, params_list)
|
||||
conn.commit()
|
||||
|
||||
|
||||
# ─── 异步连接池(asyncpg)─────────────────────────────────────
|
||||
|
||||
_async_pool: asyncpg.Pool | None = None
|
||||
|
||||
async def get_async_pool() -> asyncpg.Pool:
|
||||
global _async_pool
|
||||
if _async_pool is None:
|
||||
_async_pool = await asyncpg.create_pool(
|
||||
dsn=PG_DSN, min_size=2, max_size=10,
|
||||
)
|
||||
return _async_pool
|
||||
|
||||
|
||||
async def async_fetch(sql: str, *args) -> list[dict]:
|
||||
"""异步查询,返回dict列表"""
|
||||
pool = await get_async_pool()
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch(sql, *args)
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
|
||||
async def async_fetchrow(sql: str, *args) -> dict | None:
|
||||
"""异步查询单行"""
|
||||
pool = await get_async_pool()
|
||||
async with pool.acquire() as conn:
|
||||
row = await conn.fetchrow(sql, *args)
|
||||
return dict(row) if row else None
|
||||
|
||||
|
||||
async def async_execute(sql: str, *args):
|
||||
"""异步执行"""
|
||||
pool = await get_async_pool()
|
||||
async with pool.acquire() as conn:
|
||||
await conn.execute(sql, *args)
|
||||
|
||||
|
||||
async def close_async_pool():
|
||||
global _async_pool
|
||||
if _async_pool:
|
||||
await _async_pool.close()
|
||||
_async_pool = None
|
||||
|
||||
|
||||
# ─── 建表(PG Schema)──────────────────────────────────────────
|
||||
|
||||
SCHEMA_SQL = """
|
||||
-- 费率快照
|
||||
CREATE TABLE IF NOT EXISTS rate_snapshots (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
ts BIGINT NOT NULL,
|
||||
btc_rate DOUBLE PRECISION NOT NULL,
|
||||
eth_rate DOUBLE PRECISION NOT NULL,
|
||||
btc_price DOUBLE PRECISION NOT NULL,
|
||||
eth_price DOUBLE PRECISION NOT NULL,
|
||||
btc_index_price DOUBLE PRECISION,
|
||||
eth_index_price DOUBLE PRECISION
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_rate_snapshots_ts ON rate_snapshots(ts);
|
||||
|
||||
-- aggTrades元数据
|
||||
CREATE TABLE IF NOT EXISTS agg_trades_meta (
|
||||
symbol TEXT PRIMARY KEY,
|
||||
last_agg_id BIGINT NOT NULL,
|
||||
last_time_ms BIGINT,
|
||||
earliest_agg_id BIGINT,
|
||||
earliest_time_ms BIGINT,
|
||||
updated_at TEXT
|
||||
);
|
||||
|
||||
-- aggTrades主表(分区父表)
|
||||
CREATE TABLE IF NOT EXISTS agg_trades (
|
||||
agg_id BIGINT NOT NULL,
|
||||
symbol TEXT NOT NULL,
|
||||
price DOUBLE PRECISION NOT NULL,
|
||||
qty DOUBLE PRECISION NOT NULL,
|
||||
time_ms BIGINT NOT NULL,
|
||||
is_buyer_maker SMALLINT NOT NULL,
|
||||
PRIMARY KEY (time_ms, symbol, agg_id)
|
||||
) PARTITION BY RANGE (time_ms);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_agg_trades_sym_time ON agg_trades(symbol, time_ms DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_agg_trades_sym_agg ON agg_trades(symbol, agg_id);
|
||||
|
||||
-- Signal Engine表
|
||||
CREATE TABLE IF NOT EXISTS signal_indicators (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
ts BIGINT NOT NULL,
|
||||
symbol TEXT NOT NULL,
|
||||
cvd_fast DOUBLE PRECISION,
|
||||
cvd_mid DOUBLE PRECISION,
|
||||
cvd_day DOUBLE PRECISION,
|
||||
cvd_fast_slope DOUBLE PRECISION,
|
||||
atr_5m DOUBLE PRECISION,
|
||||
atr_percentile DOUBLE PRECISION,
|
||||
vwap_30m DOUBLE PRECISION,
|
||||
price DOUBLE PRECISION,
|
||||
p95_qty DOUBLE PRECISION,
|
||||
p99_qty DOUBLE PRECISION,
|
||||
buy_vol_1m DOUBLE PRECISION,
|
||||
sell_vol_1m DOUBLE PRECISION,
|
||||
score INTEGER,
|
||||
signal TEXT
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_si_ts ON signal_indicators(ts);
|
||||
CREATE INDEX IF NOT EXISTS idx_si_sym_ts ON signal_indicators(symbol, ts);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS signal_indicators_1m (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
ts BIGINT NOT NULL,
|
||||
symbol TEXT NOT NULL,
|
||||
cvd_fast DOUBLE PRECISION,
|
||||
cvd_mid DOUBLE PRECISION,
|
||||
cvd_day DOUBLE PRECISION,
|
||||
atr_5m DOUBLE PRECISION,
|
||||
vwap_30m DOUBLE PRECISION,
|
||||
price DOUBLE PRECISION,
|
||||
score INTEGER,
|
||||
signal TEXT
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_si1m_sym_ts ON signal_indicators_1m(symbol, ts);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS signal_trades (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
ts_open BIGINT NOT NULL,
|
||||
ts_close BIGINT,
|
||||
symbol TEXT NOT NULL,
|
||||
direction TEXT NOT NULL,
|
||||
entry_price DOUBLE PRECISION,
|
||||
exit_price DOUBLE PRECISION,
|
||||
qty DOUBLE PRECISION,
|
||||
score INTEGER,
|
||||
pnl DOUBLE PRECISION,
|
||||
sl_price DOUBLE PRECISION,
|
||||
tp1_price DOUBLE PRECISION,
|
||||
tp2_price DOUBLE PRECISION,
|
||||
status TEXT DEFAULT 'open'
|
||||
);
|
||||
|
||||
-- 信号日志(旧表兼容)
|
||||
CREATE TABLE IF NOT EXISTS signal_logs (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
symbol TEXT,
|
||||
rate DOUBLE PRECISION,
|
||||
annualized DOUBLE PRECISION,
|
||||
sent_at TEXT,
|
||||
message TEXT
|
||||
);
|
||||
|
||||
-- 用户表(auth)
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
email TEXT UNIQUE NOT NULL,
|
||||
password_hash TEXT NOT NULL,
|
||||
role TEXT DEFAULT 'user',
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS invite_codes (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
code TEXT UNIQUE NOT NULL,
|
||||
used_by BIGINT REFERENCES users(id),
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
);
|
||||
"""
|
||||
|
||||
|
||||
def ensure_partitions():
|
||||
"""创建当月和下月的分区表"""
|
||||
import datetime
|
||||
now = datetime.datetime.utcnow()
|
||||
months = []
|
||||
for delta in range(0, 3): # 当月+下2个月
|
||||
d = now + datetime.timedelta(days=delta * 30)
|
||||
months.append(d.strftime("%Y%m"))
|
||||
|
||||
with get_sync_conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
for m in set(months):
|
||||
year = int(m[:4])
|
||||
month = int(m[4:])
|
||||
# 计算分区范围(毫秒时间戳)
|
||||
start = datetime.datetime(year, month, 1)
|
||||
if month == 12:
|
||||
end = datetime.datetime(year + 1, 1, 1)
|
||||
else:
|
||||
end = datetime.datetime(year, month + 1, 1)
|
||||
start_ms = int(start.timestamp() * 1000)
|
||||
end_ms = int(end.timestamp() * 1000)
|
||||
part_name = f"agg_trades_{m}"
|
||||
try:
|
||||
cur.execute(f"""
|
||||
CREATE TABLE IF NOT EXISTS {part_name}
|
||||
PARTITION OF agg_trades
|
||||
FOR VALUES FROM ({start_ms}) TO ({end_ms})
|
||||
""")
|
||||
except Exception:
|
||||
pass # 分区已存在
|
||||
conn.commit()
|
||||
|
||||
|
||||
def init_schema():
|
||||
"""初始化全部PG表结构"""
|
||||
with get_sync_conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
for stmt in SCHEMA_SQL.split(";"):
|
||||
stmt = stmt.strip()
|
||||
if stmt:
|
||||
try:
|
||||
cur.execute(stmt)
|
||||
except Exception as e:
|
||||
conn.rollback()
|
||||
# 忽略已存在错误
|
||||
continue
|
||||
conn.commit()
|
||||
ensure_partitions()
|
||||
279
backend/main.py
279
backend/main.py
@ -2,9 +2,13 @@ from fastapi import FastAPI, HTTPException, Depends
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import httpx
|
||||
from datetime import datetime, timedelta
|
||||
import asyncio, time, sqlite3, os
|
||||
import asyncio, time, os
|
||||
|
||||
from auth import router as auth_router, get_current_user, ensure_tables as ensure_auth_tables
|
||||
from db import (
|
||||
init_schema, ensure_partitions, get_async_pool, async_fetch, async_fetchrow, async_execute,
|
||||
close_async_pool,
|
||||
)
|
||||
import datetime as _dt
|
||||
|
||||
app = FastAPI(title="Arbitrage Engine API")
|
||||
@ -21,7 +25,6 @@ app.include_router(auth_router)
|
||||
BINANCE_FAPI = "https://fapi.binance.com/fapi/v1"
|
||||
SYMBOLS = ["BTCUSDT", "ETHUSDT"]
|
||||
HEADERS = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}
|
||||
DB_PATH = os.path.join(os.path.dirname(__file__), "..", "arb.db")
|
||||
|
||||
# 简单内存缓存(history/stats 60秒,rates 3秒)
|
||||
_cache: dict = {}
|
||||
@ -36,33 +39,13 @@ def set_cache(key: str, data):
|
||||
_cache[key] = {"ts": time.time(), "data": data}
|
||||
|
||||
|
||||
def init_db():
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS rate_snapshots (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
ts INTEGER NOT NULL,
|
||||
btc_rate REAL NOT NULL,
|
||||
eth_rate REAL NOT NULL,
|
||||
btc_price REAL NOT NULL,
|
||||
eth_price REAL NOT NULL,
|
||||
btc_index_price REAL,
|
||||
eth_index_price REAL
|
||||
)
|
||||
""")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_rate_snapshots_ts ON rate_snapshots(ts)")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
|
||||
def save_snapshot(rates: dict):
|
||||
async def save_snapshot(rates: dict):
|
||||
try:
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
btc = rates.get("BTC", {})
|
||||
eth = rates.get("ETH", {})
|
||||
conn.execute(
|
||||
"INSERT INTO rate_snapshots (ts, btc_rate, eth_rate, btc_price, eth_price, btc_index_price, eth_index_price) VALUES (?,?,?,?,?,?,?)",
|
||||
(
|
||||
await async_execute(
|
||||
"INSERT INTO rate_snapshots (ts, btc_rate, eth_rate, btc_price, eth_price, btc_index_price, eth_index_price) "
|
||||
"VALUES ($1,$2,$3,$4,$5,$6,$7)",
|
||||
int(time.time()),
|
||||
float(btc.get("lastFundingRate", 0)),
|
||||
float(eth.get("lastFundingRate", 0)),
|
||||
@ -71,15 +54,12 @@ def save_snapshot(rates: dict):
|
||||
float(btc.get("indexPrice", 0)),
|
||||
float(eth.get("indexPrice", 0)),
|
||||
)
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
pass # 落库失败不影响API响应
|
||||
|
||||
|
||||
async def background_snapshot_loop():
|
||||
"""后台每2秒自动拉取费率+价格并落库,不依赖前端调用"""
|
||||
"""后台每2秒自动拉取费率+价格并落库"""
|
||||
while True:
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5, headers=HEADERS) as client:
|
||||
@ -97,7 +77,7 @@ async def background_snapshot_loop():
|
||||
"indexPrice": float(data["indexPrice"]),
|
||||
}
|
||||
if result:
|
||||
save_snapshot(result)
|
||||
await save_snapshot(result)
|
||||
except Exception:
|
||||
pass
|
||||
await asyncio.sleep(2)
|
||||
@ -105,11 +85,19 @@ async def background_snapshot_loop():
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup():
|
||||
init_db()
|
||||
# 初始化PG schema
|
||||
init_schema()
|
||||
ensure_auth_tables()
|
||||
# 初始化asyncpg池
|
||||
await get_async_pool()
|
||||
asyncio.create_task(background_snapshot_loop())
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown():
|
||||
await close_async_pool()
|
||||
|
||||
|
||||
@app.get("/api/health")
|
||||
async def health():
|
||||
return {"status": "ok", "timestamp": datetime.utcnow().isoformat()}
|
||||
@ -137,37 +125,23 @@ async def get_rates():
|
||||
"timestamp": data["time"],
|
||||
}
|
||||
set_cache("rates", result)
|
||||
# 异步落库(不阻塞响应)
|
||||
asyncio.create_task(asyncio.to_thread(save_snapshot, result))
|
||||
asyncio.create_task(save_snapshot(result))
|
||||
return result
|
||||
|
||||
|
||||
@app.get("/api/snapshots")
|
||||
async def get_snapshots(hours: int = 24, limit: int = 5000, user: dict = Depends(get_current_user)):
|
||||
"""查询本地落库的实时快照数据"""
|
||||
since = int(time.time()) - hours * 3600
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
conn.row_factory = sqlite3.Row
|
||||
rows = conn.execute(
|
||||
"SELECT ts, btc_rate, eth_rate, btc_price, eth_price FROM rate_snapshots WHERE ts >= ? ORDER BY ts ASC LIMIT ?",
|
||||
(since, limit)
|
||||
).fetchall()
|
||||
conn.close()
|
||||
return {
|
||||
"count": len(rows),
|
||||
"hours": hours,
|
||||
"data": [dict(r) for r in rows]
|
||||
}
|
||||
rows = await async_fetch(
|
||||
"SELECT ts, btc_rate, eth_rate, btc_price, eth_price FROM rate_snapshots "
|
||||
"WHERE ts >= $1 ORDER BY ts ASC LIMIT $2",
|
||||
since, limit
|
||||
)
|
||||
return {"count": len(rows), "hours": hours, "data": rows}
|
||||
|
||||
|
||||
@app.get("/api/kline")
|
||||
async def get_kline(symbol: str = "BTC", interval: str = "5m", limit: int = 500, user: dict = Depends(get_current_user)):
|
||||
"""
|
||||
从 rate_snapshots 聚合K线数据
|
||||
symbol: BTC | ETH
|
||||
interval: 1m | 5m | 30m | 1h | 4h | 8h | 1d | 1w | 1M
|
||||
返回: [{time, open, high, low, close, price_open, price_high, price_low, price_close}]
|
||||
"""
|
||||
interval_secs = {
|
||||
"1m": 60, "5m": 300, "30m": 1800,
|
||||
"1h": 3600, "4h": 14400, "8h": 28800,
|
||||
@ -176,22 +150,20 @@ async def get_kline(symbol: str = "BTC", interval: str = "5m", limit: int = 500,
|
||||
bar_secs = interval_secs.get(interval, 300)
|
||||
rate_col = "btc_rate" if symbol.upper() == "BTC" else "eth_rate"
|
||||
price_col = "btc_price" if symbol.upper() == "BTC" else "eth_price"
|
||||
|
||||
# 查询足够多的原始数据(limit根K * bar_secs最多需要的时间范围)
|
||||
since = int(time.time()) - bar_secs * limit
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
rows = conn.execute(
|
||||
f"SELECT ts, {rate_col} as rate, {price_col} as price FROM rate_snapshots WHERE ts >= ? ORDER BY ts ASC",
|
||||
(since,)
|
||||
).fetchall()
|
||||
conn.close()
|
||||
|
||||
rows = await async_fetch(
|
||||
f"SELECT ts, {rate_col} as rate, {price_col} as price FROM rate_snapshots "
|
||||
f"WHERE ts >= $1 ORDER BY ts ASC",
|
||||
since
|
||||
)
|
||||
|
||||
if not rows:
|
||||
return {"symbol": symbol, "interval": interval, "data": []}
|
||||
|
||||
# 按bar_secs分组聚合OHLC
|
||||
bars: dict = {}
|
||||
for ts, rate, price in rows:
|
||||
for r in rows:
|
||||
ts, rate, price = r["ts"], r["rate"], r["price"]
|
||||
bar_ts = (ts // bar_secs) * bar_secs
|
||||
if bar_ts not in bars:
|
||||
bars[bar_ts] = {
|
||||
@ -209,20 +181,16 @@ async def get_kline(symbol: str = "BTC", interval: str = "5m", limit: int = 500,
|
||||
b["price_close"] = price
|
||||
|
||||
data = sorted(bars.values(), key=lambda x: x["time"])[-limit:]
|
||||
# 转换为万分之(费率 × 10000)
|
||||
for b in data:
|
||||
for k in ("open", "high", "low", "close"):
|
||||
b[k] = round(b[k] * 10000, 4)
|
||||
|
||||
return {"symbol": symbol, "interval": interval, "count": len(data), "data": data}
|
||||
|
||||
|
||||
@app.get("/api/stats/ytd")
|
||||
async def get_stats_ytd(user: dict = Depends(get_current_user)):
|
||||
"""今年以来(YTD)资金费率年化统计"""
|
||||
cached = get_cache("stats_ytd", 3600)
|
||||
if cached: return cached
|
||||
# 今年1月1日 00:00 UTC
|
||||
import datetime
|
||||
year_start = int(datetime.datetime(datetime.datetime.utcnow().year, 1, 1).timestamp() * 1000)
|
||||
end_time = int(time.time() * 1000)
|
||||
@ -252,16 +220,12 @@ async def get_stats_ytd(user: dict = Depends(get_current_user)):
|
||||
|
||||
@app.get("/api/signals/history")
|
||||
async def get_signals_history(limit: int = 100, user: dict = Depends(get_current_user)):
|
||||
"""查询信号推送历史"""
|
||||
try:
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
conn.row_factory = sqlite3.Row
|
||||
rows = conn.execute(
|
||||
"SELECT id, symbol, rate, annualized, sent_at, message FROM signal_logs ORDER BY sent_at DESC LIMIT ?",
|
||||
(limit,)
|
||||
).fetchall()
|
||||
conn.close()
|
||||
return {"items": [dict(r) for r in rows]}
|
||||
rows = await async_fetch(
|
||||
"SELECT id, symbol, rate, annualized, sent_at, message FROM signal_logs ORDER BY sent_at DESC LIMIT $1",
|
||||
limit
|
||||
)
|
||||
return {"items": rows}
|
||||
except Exception as e:
|
||||
return {"items": [], "error": str(e)}
|
||||
|
||||
@ -334,20 +298,11 @@ async def get_stats(user: dict = Depends(get_current_user)):
|
||||
return stats
|
||||
|
||||
|
||||
# ─── aggTrades 查询接口 ──────────────────────────────────────────
|
||||
# ─── aggTrades 查询接口(PG版)───────────────────────────────────
|
||||
|
||||
@app.get("/api/trades/meta")
|
||||
async def get_trades_meta(user: dict = Depends(get_current_user)):
|
||||
"""aggTrades采集状态:各symbol最新agg_id和时间"""
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
conn.row_factory = sqlite3.Row
|
||||
try:
|
||||
rows = conn.execute(
|
||||
"SELECT symbol, last_agg_id, last_time_ms, updated_at FROM agg_trades_meta"
|
||||
).fetchall()
|
||||
except Exception:
|
||||
rows = []
|
||||
conn.close()
|
||||
rows = await async_fetch("SELECT symbol, last_agg_id, last_time_ms, updated_at FROM agg_trades_meta")
|
||||
result = {}
|
||||
for r in rows:
|
||||
sym = r["symbol"].replace("USDT", "")
|
||||
@ -367,46 +322,23 @@ async def get_trades_summary(
|
||||
interval: str = "1m",
|
||||
user: dict = Depends(get_current_user),
|
||||
):
|
||||
"""分钟级聚合:买卖delta/成交速率/vwap"""
|
||||
if end_ms == 0:
|
||||
end_ms = int(time.time() * 1000)
|
||||
if start_ms == 0:
|
||||
start_ms = end_ms - 3600 * 1000 # 默认1小时
|
||||
start_ms = end_ms - 3600 * 1000
|
||||
|
||||
interval_ms = {"1m": 60000, "5m": 300000, "15m": 900000, "1h": 3600000}.get(interval, 60000)
|
||||
sym_full = symbol.upper() + "USDT"
|
||||
|
||||
# 确定需要查哪些月表
|
||||
start_dt = _dt.datetime.fromtimestamp(start_ms / 1000, tz=_dt.timezone.utc)
|
||||
end_dt = _dt.datetime.fromtimestamp(end_ms / 1000, tz=_dt.timezone.utc)
|
||||
months = set()
|
||||
cur = start_dt.replace(day=1)
|
||||
while cur <= end_dt:
|
||||
months.add(cur.strftime("%Y%m"))
|
||||
if cur.month == 12:
|
||||
cur = cur.replace(year=cur.year + 1, month=1)
|
||||
else:
|
||||
cur = cur.replace(month=cur.month + 1)
|
||||
# PG分区表自动裁剪,直接查主表
|
||||
rows = await async_fetch(
|
||||
"SELECT agg_id, price, qty, time_ms, is_buyer_maker FROM agg_trades "
|
||||
"WHERE symbol = $1 AND time_ms >= $2 AND time_ms < $3 ORDER BY time_ms ASC",
|
||||
sym_full, start_ms, end_ms
|
||||
)
|
||||
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
conn.row_factory = sqlite3.Row
|
||||
all_rows = []
|
||||
for month in sorted(months):
|
||||
tname = f"agg_trades_{month}"
|
||||
try:
|
||||
rows = conn.execute(
|
||||
f"SELECT agg_id, price, qty, time_ms, is_buyer_maker FROM {tname} "
|
||||
f"WHERE symbol = ? AND time_ms >= ? AND time_ms < ? ORDER BY time_ms ASC",
|
||||
(sym_full, start_ms, end_ms)
|
||||
).fetchall()
|
||||
all_rows.extend(rows)
|
||||
except Exception:
|
||||
pass
|
||||
conn.close()
|
||||
|
||||
# 按interval聚合
|
||||
bars: dict = {}
|
||||
for row in all_rows:
|
||||
for row in rows:
|
||||
bar_ms = (row["time_ms"] // interval_ms) * interval_ms
|
||||
if bar_ms not in bars:
|
||||
bars[bar_ms] = {"time_ms": bar_ms, "buy_vol": 0.0, "sell_vol": 0.0,
|
||||
@ -414,7 +346,7 @@ async def get_trades_summary(
|
||||
b = bars[bar_ms]
|
||||
qty = float(row["qty"])
|
||||
price = float(row["price"])
|
||||
if row["is_buyer_maker"] == 0: # 主动买
|
||||
if row["is_buyer_maker"] == 0:
|
||||
b["buy_vol"] += qty
|
||||
else:
|
||||
b["sell_vol"] += qty
|
||||
@ -446,47 +378,28 @@ async def get_trades_latest(
|
||||
limit: int = 30,
|
||||
user: dict = Depends(get_current_user),
|
||||
):
|
||||
"""查最新N条原始成交记录(从本地DB,实时刷新用)"""
|
||||
sym_full = symbol.upper() + "USDT"
|
||||
now_month = _dt.datetime.now(_dt.timezone.utc).strftime("%Y%m")
|
||||
tname = f"agg_trades_{now_month}"
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
conn.row_factory = sqlite3.Row
|
||||
try:
|
||||
rows = conn.execute(
|
||||
f"SELECT agg_id, price, qty, time_ms, is_buyer_maker FROM {tname} "
|
||||
f"WHERE symbol = ? ORDER BY agg_id DESC LIMIT ?",
|
||||
(sym_full, limit)
|
||||
).fetchall()
|
||||
except Exception:
|
||||
rows = []
|
||||
conn.close()
|
||||
return {
|
||||
"symbol": symbol,
|
||||
"count": len(rows),
|
||||
"data": [dict(r) for r in rows],
|
||||
}
|
||||
rows = await async_fetch(
|
||||
"SELECT agg_id, price, qty, time_ms, is_buyer_maker FROM agg_trades "
|
||||
"WHERE symbol = $1 ORDER BY time_ms DESC, agg_id DESC LIMIT $2",
|
||||
sym_full, limit
|
||||
)
|
||||
return {"symbol": symbol, "count": len(rows), "data": rows}
|
||||
|
||||
|
||||
@app.get("/api/collector/health")
|
||||
async def collector_health(user: dict = Depends(get_current_user)):
|
||||
"""采集器健康状态"""
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
conn.row_factory = sqlite3.Row
|
||||
now_ms = int(time.time() * 1000)
|
||||
rows = await async_fetch("SELECT symbol, last_agg_id, last_time_ms FROM agg_trades_meta")
|
||||
status = {}
|
||||
try:
|
||||
rows = conn.execute(
|
||||
"SELECT symbol, last_agg_id, last_time_ms FROM agg_trades_meta"
|
||||
).fetchall()
|
||||
for r in rows:
|
||||
sym = r["symbol"].replace("USDT", "")
|
||||
lag_s = (now_ms - r["last_time_ms"]) / 1000
|
||||
lag_s = (now_ms - (r["last_time_ms"] or 0)) / 1000
|
||||
status[sym] = {
|
||||
"last_agg_id": r["last_agg_id"],
|
||||
"lag_seconds": round(lag_s, 1),
|
||||
"healthy": lag_s < 30, # 30秒内有数据算健康
|
||||
"healthy": lag_s < 30,
|
||||
}
|
||||
except Exception:
|
||||
pass
|
||||
conn.close()
|
||||
return {"collector": status, "timestamp": now_ms}
|
||||
|
||||
|
||||
@ -498,50 +411,29 @@ async def get_signal_indicators(
|
||||
minutes: int = 60,
|
||||
user: dict = Depends(get_current_user),
|
||||
):
|
||||
"""获取signal_indicators_1m数据(前端图表用)"""
|
||||
sym_full = symbol.upper() + "USDT"
|
||||
now_ms = int(time.time() * 1000)
|
||||
start_ms = now_ms - minutes * 60 * 1000
|
||||
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
conn.row_factory = sqlite3.Row
|
||||
try:
|
||||
rows = conn.execute(
|
||||
rows = await async_fetch(
|
||||
"SELECT ts, cvd_fast, cvd_mid, cvd_day, atr_5m, vwap_30m, price, score, signal "
|
||||
"FROM signal_indicators_1m WHERE symbol = ? AND ts >= ? ORDER BY ts ASC",
|
||||
(sym_full, start_ms)
|
||||
).fetchall()
|
||||
except Exception:
|
||||
rows = []
|
||||
conn.close()
|
||||
return {
|
||||
"symbol": symbol,
|
||||
"count": len(rows),
|
||||
"data": [dict(r) for r in rows],
|
||||
}
|
||||
"FROM signal_indicators_1m WHERE symbol = $1 AND ts >= $2 ORDER BY ts ASC",
|
||||
sym_full, start_ms
|
||||
)
|
||||
return {"symbol": symbol, "count": len(rows), "data": rows}
|
||||
|
||||
|
||||
@app.get("/api/signals/latest")
|
||||
async def get_signal_latest(
|
||||
user: dict = Depends(get_current_user),
|
||||
):
|
||||
"""获取最新一条各symbol的指标快照"""
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
conn.row_factory = sqlite3.Row
|
||||
async def get_signal_latest(user: dict = Depends(get_current_user)):
|
||||
result = {}
|
||||
for sym in ["BTCUSDT", "ETHUSDT"]:
|
||||
try:
|
||||
row = conn.execute(
|
||||
row = await async_fetchrow(
|
||||
"SELECT ts, cvd_fast, cvd_mid, cvd_day, cvd_fast_slope, atr_5m, atr_percentile, "
|
||||
"vwap_30m, price, p95_qty, p99_qty, score, signal "
|
||||
"FROM signal_indicators WHERE symbol = ? ORDER BY ts DESC LIMIT 1",
|
||||
(sym,)
|
||||
).fetchone()
|
||||
"FROM signal_indicators WHERE symbol = $1 ORDER BY ts DESC LIMIT 1",
|
||||
sym
|
||||
)
|
||||
if row:
|
||||
result[sym.replace("USDT", "")] = dict(row)
|
||||
except Exception:
|
||||
pass
|
||||
conn.close()
|
||||
result[sym.replace("USDT", "")] = row
|
||||
return result
|
||||
|
||||
|
||||
@ -551,20 +443,13 @@ async def get_signal_trades(
|
||||
limit: int = 50,
|
||||
user: dict = Depends(get_current_user),
|
||||
):
|
||||
"""获取信号交易记录"""
|
||||
conn = sqlite3.connect(DB_PATH)
|
||||
conn.row_factory = sqlite3.Row
|
||||
try:
|
||||
if status == "all":
|
||||
rows = conn.execute(
|
||||
"SELECT * FROM signal_trades ORDER BY ts_open DESC LIMIT ?", (limit,)
|
||||
).fetchall()
|
||||
rows = await async_fetch(
|
||||
"SELECT * FROM signal_trades ORDER BY ts_open DESC LIMIT $1", limit
|
||||
)
|
||||
else:
|
||||
rows = conn.execute(
|
||||
"SELECT * FROM signal_trades WHERE status = ? ORDER BY ts_open DESC LIMIT ?",
|
||||
(status, limit)
|
||||
).fetchall()
|
||||
except Exception:
|
||||
rows = []
|
||||
conn.close()
|
||||
return {"count": len(rows), "data": [dict(r) for r in rows]}
|
||||
rows = await async_fetch(
|
||||
"SELECT * FROM signal_trades WHERE status = $1 ORDER BY ts_open DESC LIMIT $2",
|
||||
status, limit
|
||||
)
|
||||
return {"count": len(rows), "data": rows}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
"""
|
||||
signal_engine.py — V5 短线交易信号引擎
|
||||
signal_engine.py — V5 短线交易信号引擎(PostgreSQL版)
|
||||
|
||||
架构:
|
||||
- 独立PM2进程,每5秒循环
|
||||
@ -17,14 +17,13 @@ signal_engine.py — V5 短线交易信号引擎
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
import time
|
||||
import math
|
||||
import statistics
|
||||
from collections import deque
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from db import get_sync_conn, init_schema
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
||||
@ -35,110 +34,33 @@ logging.basicConfig(
|
||||
)
|
||||
logger = logging.getLogger("signal-engine")
|
||||
|
||||
DB_PATH = os.path.join(os.path.dirname(__file__), "..", "arb.db")
|
||||
SYMBOLS = ["BTCUSDT", "ETHUSDT"]
|
||||
LOOP_INTERVAL = 5 # 秒
|
||||
|
||||
# 窗口大小(毫秒)
|
||||
WINDOW_FAST = 30 * 60 * 1000 # 30分钟
|
||||
WINDOW_MID = 4 * 3600 * 1000 # 4小时
|
||||
WINDOW_DAY = 24 * 3600 * 1000 # 24小时(用于P95/P99计算)
|
||||
WINDOW_DAY = 24 * 3600 * 1000 # 24小时
|
||||
WINDOW_VWAP = 30 * 60 * 1000 # 30分钟
|
||||
|
||||
# ATR参数
|
||||
ATR_PERIOD_MS = 5 * 60 * 1000 # 5分钟K线
|
||||
ATR_LENGTH = 14 # 14根
|
||||
ATR_PERIOD_MS = 5 * 60 * 1000
|
||||
ATR_LENGTH = 14
|
||||
|
||||
# 信号冷却
|
||||
COOLDOWN_MS = 10 * 60 * 1000 # 10分钟
|
||||
|
||||
|
||||
# ─── DB helpers ──────────────────────────────────────────────────
|
||||
|
||||
def get_conn() -> sqlite3.Connection:
|
||||
conn = sqlite3.connect(DB_PATH, timeout=30)
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("PRAGMA synchronous=NORMAL")
|
||||
return conn
|
||||
|
||||
|
||||
def init_tables(conn: sqlite3.Connection):
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS signal_indicators (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
ts INTEGER NOT NULL,
|
||||
symbol TEXT NOT NULL,
|
||||
cvd_fast REAL,
|
||||
cvd_mid REAL,
|
||||
cvd_day REAL,
|
||||
cvd_fast_slope REAL,
|
||||
atr_5m REAL,
|
||||
atr_percentile REAL,
|
||||
vwap_30m REAL,
|
||||
price REAL,
|
||||
p95_qty REAL,
|
||||
p99_qty REAL,
|
||||
buy_vol_1m REAL,
|
||||
sell_vol_1m REAL,
|
||||
score INTEGER,
|
||||
signal TEXT
|
||||
)
|
||||
""")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_si_ts ON signal_indicators(ts)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_si_sym_ts ON signal_indicators(symbol, ts)")
|
||||
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS signal_indicators_1m (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
ts INTEGER NOT NULL,
|
||||
symbol TEXT NOT NULL,
|
||||
cvd_fast REAL,
|
||||
cvd_mid REAL,
|
||||
cvd_day REAL,
|
||||
atr_5m REAL,
|
||||
vwap_30m REAL,
|
||||
price REAL,
|
||||
score INTEGER,
|
||||
signal TEXT
|
||||
)
|
||||
""")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_si1m_sym_ts ON signal_indicators_1m(symbol, ts)")
|
||||
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS signal_trades (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
ts_open INTEGER NOT NULL,
|
||||
ts_close INTEGER,
|
||||
symbol TEXT NOT NULL,
|
||||
direction TEXT NOT NULL,
|
||||
entry_price REAL,
|
||||
exit_price REAL,
|
||||
qty REAL,
|
||||
score INTEGER,
|
||||
pnl REAL,
|
||||
sl_price REAL,
|
||||
tp1_price REAL,
|
||||
tp2_price REAL,
|
||||
status TEXT DEFAULT 'open'
|
||||
)
|
||||
""")
|
||||
conn.commit()
|
||||
COOLDOWN_MS = 10 * 60 * 1000
|
||||
|
||||
|
||||
# ─── 滚动窗口 ───────────────────────────────────────────────────
|
||||
|
||||
class TradeWindow:
|
||||
"""滚动时间窗口,维护买卖量和价格数据"""
|
||||
|
||||
def __init__(self, window_ms: int):
|
||||
self.window_ms = window_ms
|
||||
self.trades: deque = deque() # (time_ms, qty, price, is_buyer_maker)
|
||||
self.trades: deque = deque()
|
||||
self.buy_vol = 0.0
|
||||
self.sell_vol = 0.0
|
||||
self.pq_sum = 0.0 # price * qty 累加(VWAP用)
|
||||
self.q_sum = 0.0 # qty累加
|
||||
self.quantities: list = [] # 用于P95/P99(仅24h窗口用)
|
||||
self.pq_sum = 0.0
|
||||
self.q_sum = 0.0
|
||||
|
||||
def add(self, time_ms: int, qty: float, price: float, is_buyer_maker: int):
|
||||
self.trades.append((time_ms, qty, price, is_buyer_maker))
|
||||
@ -154,8 +76,7 @@ class TradeWindow:
|
||||
cutoff = now_ms - self.window_ms
|
||||
while self.trades and self.trades[0][0] < cutoff:
|
||||
t_ms, qty, price, ibm = self.trades.popleft()
|
||||
pq = price * qty
|
||||
self.pq_sum -= pq
|
||||
self.pq_sum -= price * qty
|
||||
self.q_sum -= qty
|
||||
if ibm == 0:
|
||||
self.buy_vol -= qty
|
||||
@ -170,29 +91,21 @@ class TradeWindow:
|
||||
def vwap(self) -> float:
|
||||
return self.pq_sum / self.q_sum if self.q_sum > 0 else 0.0
|
||||
|
||||
@property
|
||||
def total_vol(self) -> float:
|
||||
return self.buy_vol + self.sell_vol
|
||||
|
||||
|
||||
class ATRCalculator:
|
||||
"""5分钟K线ATR计算"""
|
||||
|
||||
def __init__(self, period_ms: int = ATR_PERIOD_MS, length: int = ATR_LENGTH):
|
||||
self.period_ms = period_ms
|
||||
self.length = length
|
||||
self.candles: deque = deque(maxlen=length + 1)
|
||||
self.current_candle: Optional[dict] = None
|
||||
self.atr_history: deque = deque(maxlen=288) # 24h of 5m candles for percentile
|
||||
self.atr_history: deque = deque(maxlen=288)
|
||||
|
||||
def update(self, time_ms: int, price: float):
|
||||
bar_ms = (time_ms // self.period_ms) * self.period_ms
|
||||
if self.current_candle is None or self.current_candle["bar"] != bar_ms:
|
||||
if self.current_candle is not None:
|
||||
self.candles.append(self.current_candle)
|
||||
self.current_candle = {
|
||||
"bar": bar_ms, "open": price, "high": price, "low": price, "close": price
|
||||
}
|
||||
self.current_candle = {"bar": bar_ms, "open": price, "high": price, "low": price, "close": price}
|
||||
else:
|
||||
c = self.current_candle
|
||||
c["high"] = max(c["high"], price)
|
||||
@ -212,7 +125,6 @@ class ATRCalculator:
|
||||
trs.append(tr)
|
||||
if not trs:
|
||||
return 0.0
|
||||
# EMA-style ATR
|
||||
atr_val = trs[0]
|
||||
for tr in trs[1:]:
|
||||
atr_val = (atr_val * (self.length - 1) + tr) / self.length
|
||||
@ -220,7 +132,6 @@ class ATRCalculator:
|
||||
|
||||
@property
|
||||
def atr_percentile(self) -> float:
|
||||
"""当前ATR在最近24h中的分位数"""
|
||||
current = self.atr
|
||||
if current == 0:
|
||||
return 50.0
|
||||
@ -233,8 +144,6 @@ class ATRCalculator:
|
||||
|
||||
|
||||
class SymbolState:
|
||||
"""单个交易对的完整状态"""
|
||||
|
||||
def __init__(self, symbol: str):
|
||||
self.symbol = symbol
|
||||
self.win_fast = TradeWindow(WINDOW_FAST)
|
||||
@ -244,12 +153,10 @@ class SymbolState:
|
||||
self.atr_calc = ATRCalculator()
|
||||
self.last_processed_id = 0
|
||||
self.warmup = True
|
||||
self.warmup_until = 0
|
||||
self.prev_cvd_fast = 0.0
|
||||
self.last_signal_ts = 0
|
||||
self.last_signal_dir = ""
|
||||
# P99大单追踪(最近15分钟)
|
||||
self.recent_large_trades: deque = deque() # (time_ms, qty, is_buyer_maker)
|
||||
self.recent_large_trades: deque = deque()
|
||||
|
||||
def process_trade(self, agg_id: int, time_ms: int, price: float, qty: float, is_buyer_maker: int):
|
||||
now_ms = time_ms
|
||||
@ -258,47 +165,34 @@ class SymbolState:
|
||||
self.win_day.add(time_ms, qty, price, is_buyer_maker)
|
||||
self.win_vwap.add(time_ms, qty, price, is_buyer_maker)
|
||||
self.atr_calc.update(time_ms, price)
|
||||
|
||||
self.win_fast.trim(now_ms)
|
||||
self.win_mid.trim(now_ms)
|
||||
self.win_day.trim(now_ms)
|
||||
self.win_vwap.trim(now_ms)
|
||||
|
||||
self.last_processed_id = agg_id
|
||||
|
||||
def compute_p95_p99(self) -> tuple[float, float]:
|
||||
"""从24h窗口计算大单阈值"""
|
||||
def compute_p95_p99(self) -> tuple:
|
||||
if len(self.win_day.trades) < 100:
|
||||
return 5.0, 10.0 # 默认兜底
|
||||
|
||||
qtys = [t[1] for t in self.win_day.trades]
|
||||
qtys.sort()
|
||||
return 5.0, 10.0
|
||||
qtys = sorted([t[1] for t in self.win_day.trades])
|
||||
n = len(qtys)
|
||||
p95 = qtys[int(n * 0.95)]
|
||||
p99 = qtys[int(n * 0.99)]
|
||||
|
||||
# BTC兜底5,ETH兜底50
|
||||
if "BTC" in self.symbol:
|
||||
p95 = max(p95, 5.0)
|
||||
p99 = max(p99, 10.0)
|
||||
p95 = max(p95, 5.0); p99 = max(p99, 10.0)
|
||||
else:
|
||||
p95 = max(p95, 50.0)
|
||||
p99 = max(p99, 100.0)
|
||||
p95 = max(p95, 50.0); p99 = max(p99, 100.0)
|
||||
return p95, p99
|
||||
|
||||
def update_large_trades(self, now_ms: int, p99: float):
|
||||
"""更新最近15分钟的P99大单记录"""
|
||||
cutoff = now_ms - 15 * 60 * 1000
|
||||
while self.recent_large_trades and self.recent_large_trades[0][0] < cutoff:
|
||||
self.recent_large_trades.popleft()
|
||||
|
||||
# 从最近处理的trades中找大单
|
||||
for t in self.win_fast.trades:
|
||||
if t[1] >= p99 and t[0] > cutoff:
|
||||
self.recent_large_trades.append((t[0], t[1], t[3]))
|
||||
|
||||
def evaluate_signal(self, now_ms: int) -> dict:
|
||||
"""评估信号:核心3条件 + 加分3条件"""
|
||||
cvd_fast = self.win_fast.cvd
|
||||
cvd_mid = self.win_mid.cvd
|
||||
vwap = self.win_vwap.vwap
|
||||
@ -306,198 +200,127 @@ class SymbolState:
|
||||
atr_pct = self.atr_calc.atr_percentile
|
||||
p95, p99 = self.compute_p95_p99()
|
||||
self.update_large_trades(now_ms, p99)
|
||||
|
||||
# 当前价格(用VWAP近似)
|
||||
price = vwap if vwap > 0 else 0
|
||||
|
||||
# CVD_fast斜率
|
||||
cvd_fast_slope = cvd_fast - self.prev_cvd_fast
|
||||
self.prev_cvd_fast = cvd_fast
|
||||
|
||||
result = {
|
||||
"cvd_fast": cvd_fast,
|
||||
"cvd_mid": cvd_mid,
|
||||
"cvd_day": self.win_day.cvd,
|
||||
"cvd_fast": cvd_fast, "cvd_mid": cvd_mid, "cvd_day": self.win_day.cvd,
|
||||
"cvd_fast_slope": cvd_fast_slope,
|
||||
"atr": atr,
|
||||
"atr_pct": atr_pct,
|
||||
"vwap": vwap,
|
||||
"price": price,
|
||||
"p95": p95,
|
||||
"p99": p99,
|
||||
"signal": None,
|
||||
"direction": None,
|
||||
"score": 0,
|
||||
"atr": atr, "atr_pct": atr_pct, "vwap": vwap, "price": price,
|
||||
"p95": p95, "p99": p99, "signal": None, "direction": None, "score": 0,
|
||||
}
|
||||
|
||||
if self.warmup or price == 0 or atr == 0:
|
||||
return result
|
||||
|
||||
# 冷却期检查
|
||||
if now_ms - self.last_signal_ts < COOLDOWN_MS:
|
||||
return result
|
||||
|
||||
# === 核心条件 ===
|
||||
# 做多
|
||||
long_core = (
|
||||
cvd_fast > 0 and cvd_fast_slope > 0 and # CVD_fast正且上升
|
||||
cvd_mid > 0 and # CVD_mid正
|
||||
price > vwap # 价格在VWAP上方
|
||||
)
|
||||
# 做空
|
||||
short_core = (
|
||||
cvd_fast < 0 and cvd_fast_slope < 0 and
|
||||
cvd_mid < 0 and
|
||||
price < vwap
|
||||
)
|
||||
long_core = cvd_fast > 0 and cvd_fast_slope > 0 and cvd_mid > 0 and price > vwap
|
||||
short_core = cvd_fast < 0 and cvd_fast_slope < 0 and cvd_mid < 0 and price < vwap
|
||||
|
||||
if not long_core and not short_core:
|
||||
return result
|
||||
|
||||
direction = "LONG" if long_core else "SHORT"
|
||||
|
||||
# === 加分条件 ===
|
||||
score = 0
|
||||
|
||||
# 1. ATR压缩→扩张 (+25)
|
||||
if atr_pct > 60:
|
||||
score += 25
|
||||
|
||||
# 2. 无反向P99大单 (+20)
|
||||
has_adverse_large = False
|
||||
for lt in self.recent_large_trades:
|
||||
if direction == "LONG" and lt[2] == 1: # 大卖单对多头不利
|
||||
has_adverse_large = True
|
||||
elif direction == "SHORT" and lt[2] == 0: # 大买单对空头不利
|
||||
has_adverse_large = True
|
||||
if not has_adverse_large:
|
||||
has_adverse = any(
|
||||
(direction == "LONG" and lt[2] == 1) or (direction == "SHORT" and lt[2] == 0)
|
||||
for lt in self.recent_large_trades
|
||||
)
|
||||
if not has_adverse:
|
||||
score += 20
|
||||
|
||||
# 3. 资金费率配合 (+15) — 从rate_snapshots读取
|
||||
# 暂时跳过,后续接入
|
||||
# TODO: 接入资金费率条件
|
||||
score += 0 # placeholder
|
||||
|
||||
result["signal"] = direction
|
||||
result["direction"] = direction
|
||||
result["score"] = score
|
||||
|
||||
# 更新冷却
|
||||
self.last_signal_ts = now_ms
|
||||
self.last_signal_dir = direction
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ─── 主循环 ──────────────────────────────────────────────────────
|
||||
# ─── PG DB操作 ───────────────────────────────────────────────────
|
||||
|
||||
def get_month_tables(conn: sqlite3.Connection) -> list[str]:
|
||||
rows = conn.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'agg_trades_2%' ORDER BY name"
|
||||
).fetchall()
|
||||
return [r["name"] for r in rows]
|
||||
|
||||
|
||||
def load_historical(conn: sqlite3.Connection, state: SymbolState, window_ms: int):
|
||||
"""冷启动:回灌历史数据到内存窗口"""
|
||||
def load_historical(state: SymbolState, window_ms: int):
|
||||
now_ms = int(time.time() * 1000)
|
||||
start_ms = now_ms - window_ms
|
||||
tables = get_month_tables(conn)
|
||||
|
||||
count = 0
|
||||
for tname in tables:
|
||||
try:
|
||||
rows = conn.execute(
|
||||
f"SELECT agg_id, price, qty, time_ms, is_buyer_maker FROM {tname} "
|
||||
f"WHERE symbol = ? AND time_ms >= ? ORDER BY agg_id ASC",
|
||||
with get_sync_conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"SELECT agg_id, price, qty, time_ms, is_buyer_maker FROM agg_trades "
|
||||
"WHERE symbol = %s AND time_ms >= %s ORDER BY agg_id ASC",
|
||||
(state.symbol, start_ms)
|
||||
).fetchall()
|
||||
)
|
||||
while True:
|
||||
rows = cur.fetchmany(5000)
|
||||
if not rows:
|
||||
break
|
||||
for r in rows:
|
||||
state.process_trade(r["agg_id"], r["time_ms"], r["price"], r["qty"], r["is_buyer_maker"])
|
||||
state.process_trade(r[0], r[3], r[1], r[2], r[4])
|
||||
count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"Error loading {tname}: {e}")
|
||||
|
||||
logger.info(f"[{state.symbol}] 冷启动完成: 加载{count:,}条历史数据 (窗口={window_ms//3600000}h)")
|
||||
state.warmup = False
|
||||
|
||||
|
||||
def fetch_new_trades(conn: sqlite3.Connection, symbol: str, last_id: int) -> list:
|
||||
"""增量读取新aggTrades"""
|
||||
tables = get_month_tables(conn)
|
||||
results = []
|
||||
for tname in tables[-2:]: # 只查最近两个月表
|
||||
try:
|
||||
rows = conn.execute(
|
||||
f"SELECT agg_id, price, qty, time_ms, is_buyer_maker FROM {tname} "
|
||||
f"WHERE symbol = ? AND agg_id > ? ORDER BY agg_id ASC LIMIT 10000",
|
||||
def fetch_new_trades(symbol: str, last_id: int) -> list:
|
||||
with get_sync_conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"SELECT agg_id, price, qty, time_ms, is_buyer_maker FROM agg_trades "
|
||||
"WHERE symbol = %s AND agg_id > %s ORDER BY agg_id ASC LIMIT 10000",
|
||||
(symbol, last_id)
|
||||
).fetchall()
|
||||
results.extend(rows)
|
||||
except Exception:
|
||||
pass
|
||||
return results
|
||||
)
|
||||
return [{"agg_id": r[0], "price": r[1], "qty": r[2], "time_ms": r[3], "is_buyer_maker": r[4]}
|
||||
for r in cur.fetchall()]
|
||||
|
||||
|
||||
def save_indicator(conn: sqlite3.Connection, ts: int, symbol: str, result: dict):
|
||||
conn.execute("""
|
||||
INSERT INTO signal_indicators
|
||||
(ts, symbol, cvd_fast, cvd_mid, cvd_day, cvd_fast_slope, atr_5m, atr_percentile,
|
||||
vwap_30m, price, p95_qty, p99_qty, score, signal)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""", (
|
||||
ts, symbol,
|
||||
result["cvd_fast"], result["cvd_mid"], result["cvd_day"], result["cvd_fast_slope"],
|
||||
result["atr"], result["atr_pct"],
|
||||
result["vwap"], result["price"],
|
||||
result["p95"], result["p99"],
|
||||
result["score"], result.get("signal")
|
||||
))
|
||||
def save_indicator(ts: int, symbol: str, result: dict):
|
||||
with get_sync_conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"INSERT INTO signal_indicators "
|
||||
"(ts,symbol,cvd_fast,cvd_mid,cvd_day,cvd_fast_slope,atr_5m,atr_percentile,vwap_30m,price,p95_qty,p99_qty,score,signal) "
|
||||
"VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)",
|
||||
(ts, symbol, result["cvd_fast"], result["cvd_mid"], result["cvd_day"], result["cvd_fast_slope"],
|
||||
result["atr"], result["atr_pct"], result["vwap"], result["price"],
|
||||
result["p95"], result["p99"], result["score"], result.get("signal"))
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
def save_indicator_1m(conn: sqlite3.Connection, ts: int, symbol: str, result: dict):
|
||||
"""每分钟聚合保存"""
|
||||
def save_indicator_1m(ts: int, symbol: str, result: dict):
|
||||
bar_ts = (ts // 60000) * 60000
|
||||
existing = conn.execute(
|
||||
"SELECT id FROM signal_indicators_1m WHERE ts = ? AND symbol = ?", (bar_ts, symbol)
|
||||
).fetchone()
|
||||
if existing:
|
||||
conn.execute("""
|
||||
UPDATE signal_indicators_1m SET
|
||||
cvd_fast=?, cvd_mid=?, cvd_day=?, atr_5m=?, vwap_30m=?, price=?, score=?, signal=?
|
||||
WHERE ts=? AND symbol=?
|
||||
""", (
|
||||
result["cvd_fast"], result["cvd_mid"], result["cvd_day"],
|
||||
result["atr"], result["vwap"], result["price"],
|
||||
result["score"], result.get("signal"),
|
||||
bar_ts, symbol
|
||||
))
|
||||
with get_sync_conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("SELECT id FROM signal_indicators_1m WHERE ts=%s AND symbol=%s", (bar_ts, symbol))
|
||||
if cur.fetchone():
|
||||
cur.execute(
|
||||
"UPDATE signal_indicators_1m SET cvd_fast=%s,cvd_mid=%s,cvd_day=%s,atr_5m=%s,vwap_30m=%s,price=%s,score=%s,signal=%s WHERE ts=%s AND symbol=%s",
|
||||
(result["cvd_fast"], result["cvd_mid"], result["cvd_day"], result["atr"], result["vwap"],
|
||||
result["price"], result["score"], result.get("signal"), bar_ts, symbol)
|
||||
)
|
||||
else:
|
||||
conn.execute("""
|
||||
INSERT INTO signal_indicators_1m
|
||||
(ts, symbol, cvd_fast, cvd_mid, cvd_day, atr_5m, vwap_30m, price, score, signal)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""", (
|
||||
bar_ts, symbol,
|
||||
result["cvd_fast"], result["cvd_mid"], result["cvd_day"],
|
||||
result["atr"], result["vwap"], result["price"],
|
||||
result["score"], result.get("signal")
|
||||
))
|
||||
cur.execute(
|
||||
"INSERT INTO signal_indicators_1m (ts,symbol,cvd_fast,cvd_mid,cvd_day,atr_5m,vwap_30m,price,score,signal) VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)",
|
||||
(bar_ts, symbol, result["cvd_fast"], result["cvd_mid"], result["cvd_day"], result["atr"],
|
||||
result["vwap"], result["price"], result["score"], result.get("signal"))
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
# ─── 主循环 ──────────────────────────────────────────────────────
|
||||
|
||||
def main():
|
||||
conn = get_conn()
|
||||
init_tables(conn)
|
||||
|
||||
init_schema()
|
||||
states = {sym: SymbolState(sym) for sym in SYMBOLS}
|
||||
|
||||
# 冷启动:回灌4h数据
|
||||
for sym, state in states.items():
|
||||
load_historical(conn, state, WINDOW_MID)
|
||||
load_historical(state, WINDOW_MID)
|
||||
|
||||
logger.info("=== Signal Engine 启动完成 ===")
|
||||
logger.info("=== Signal Engine (PG) 启动完成 ===")
|
||||
|
||||
last_1m_save = {}
|
||||
cycle = 0
|
||||
@ -505,44 +328,27 @@ def main():
|
||||
while True:
|
||||
try:
|
||||
now_ms = int(time.time() * 1000)
|
||||
|
||||
for sym, state in states.items():
|
||||
# 增量读取新数据
|
||||
new_trades = fetch_new_trades(conn, sym, state.last_processed_id)
|
||||
new_trades = fetch_new_trades(sym, state.last_processed_id)
|
||||
for t in new_trades:
|
||||
state.process_trade(t["agg_id"], t["time_ms"], t["price"], t["qty"], t["is_buyer_maker"])
|
||||
|
||||
# 评估信号
|
||||
result = state.evaluate_signal(now_ms)
|
||||
save_indicator(now_ms, sym, result)
|
||||
|
||||
# 保存5秒指标
|
||||
save_indicator(conn, now_ms, sym, result)
|
||||
|
||||
# 每分钟保存聚合
|
||||
bar_1m = (now_ms // 60000) * 60000
|
||||
if last_1m_save.get(sym) != bar_1m:
|
||||
save_indicator_1m(conn, now_ms, sym, result)
|
||||
save_indicator_1m(now_ms, sym, result)
|
||||
last_1m_save[sym] = bar_1m
|
||||
|
||||
# 有信号则记录
|
||||
if result.get("signal"):
|
||||
logger.info(
|
||||
f"[{sym}] 🚨 信号: {result['signal']} "
|
||||
f"score={result['score']} price={result['price']:.1f} "
|
||||
f"CVD_fast={result['cvd_fast']:.1f} CVD_mid={result['cvd_mid']:.1f}"
|
||||
)
|
||||
# TODO: Discord推送
|
||||
# TODO: 写入signal_trades
|
||||
logger.info(f"[{sym}] 🚨 信号: {result['signal']} score={result['score']} price={result['price']:.1f}")
|
||||
|
||||
cycle += 1
|
||||
if cycle % 60 == 0: # 每5分钟打一次状态
|
||||
if cycle % 60 == 0:
|
||||
for sym, state in states.items():
|
||||
r = state.evaluate_signal(now_ms)
|
||||
logger.info(
|
||||
f"[{sym}] 状态: CVD_fast={r['cvd_fast']:.1f} CVD_mid={r['cvd_mid']:.1f} "
|
||||
f"ATR={r['atr']:.2f}({r['atr_pct']:.0f}%) VWAP={r['vwap']:.1f} "
|
||||
f"P95={r['p95']:.4f} P99={r['p99']:.4f}"
|
||||
)
|
||||
logger.info(f"[{sym}] 状态: CVD_fast={r['cvd_fast']:.1f} CVD_mid={r['cvd_mid']:.1f} ATR={r['atr']:.2f}({r['atr_pct']:.0f}%) VWAP={r['vwap']:.1f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"循环异常: {e}", exc_info=True)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user