refactor: SQLite→PostgreSQL migration - db.py connection layer + main/collector/signal-engine/backfill all moved to PG

Phase 1: core data tables (agg_trades/rate_snapshots/signal*) migrated to PG
auth.py stays on SQLite for now (low frequency, no performance impact)
- db.py: psycopg2 sync pool + asyncpg async pool + PG schema + partition management
- main.py: all queries moved to asyncpg
- collector: psycopg2 + execute_values bulk writes
- signal-engine: psycopg2 sync reads/writes
- backfill: psycopg2 + ON CONFLICT DO NOTHING
This commit is contained in:
root 2026-02-27 16:15:16 +00:00
parent 23c7597a40
commit 4168c1dd88
5 changed files with 668 additions and 916 deletions

backend/agg_trades_collector.py
View File

@ -1,27 +1,29 @@
"""
agg_trades_collector.py — aggTrades full-coverage collection daemon
agg_trades_collector.py — aggTrades full-coverage collection daemon (PostgreSQL edition)
Architecture:
- WebSocket main path: real-time push, <100ms latency
- REST gap-filling: catch up from last_agg_id after a reconnect
- Minute-level patrol: verify agg_id continuity, auto-fill gaps when found
- Batched writes: flush at 200 rows or after 1s, reducing WAL pressure
- Monthly sharded tables: agg_trades_YYYYMM, queries stay fast with under ~10M rows per table
- Health endpoint: GET /collector/health for monitoring
- Batched writes: flush at 200 rows or after 1s
- PG partitioned table: automatic monthly partitions, MVCC concurrency without lock contention
"""
import asyncio
import json
import logging
import os
import sqlite3
import time
from datetime import datetime, timezone
from typing import Optional
import httpx
import psycopg2
import psycopg2.extras
import websockets
from db import get_sync_conn, get_sync_pool, ensure_partitions, PG_HOST, PG_PORT, PG_DB, PG_USER, PG_PASS
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
@ -32,136 +34,84 @@ logging.basicConfig(
)
logger = logging.getLogger("collector")
DB_PATH = os.path.join(os.path.dirname(__file__), "..", "arb.db")
BINANCE_FAPI = "https://fapi.binance.com/fapi/v1"
SYMBOLS = ["BTCUSDT", "ETHUSDT"]
HEADERS = {"User-Agent": "Mozilla/5.0 ArbitrageEngine/3.0"}
# Batch-write buffer
_buffer: dict[str, list] = {s: [] for s in SYMBOLS}
BATCH_SIZE = 200
BATCH_TIMEOUT = 1.0 # seconds
BATCH_TIMEOUT = 1.0
# ─── DB helpers ──────────────────────────────────────────────────
def get_conn() -> sqlite3.Connection:
conn = sqlite3.connect(DB_PATH, timeout=30)
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA synchronous=NORMAL")
return conn
def table_name(ts_ms: int) -> str:
"""按月分表agg_trades_202602"""
dt = datetime.fromtimestamp(ts_ms / 1000, tz=timezone.utc)
return f"agg_trades_{dt.strftime('%Y%m')}"
def ensure_table(conn: sqlite3.Connection, tname: str):
conn.execute(f"""
CREATE TABLE IF NOT EXISTS {tname} (
agg_id INTEGER PRIMARY KEY,
symbol TEXT NOT NULL,
price REAL NOT NULL,
qty REAL NOT NULL,
first_trade_id INTEGER,
last_trade_id INTEGER,
time_ms INTEGER NOT NULL,
is_buyer_maker INTEGER NOT NULL
)
""")
conn.execute(f"""
CREATE INDEX IF NOT EXISTS idx_{tname}_sym_time
ON {tname}(symbol, time_ms)
""")
def ensure_meta_table(conn: sqlite3.Connection):
conn.execute("""
CREATE TABLE IF NOT EXISTS agg_trades_meta (
symbol TEXT PRIMARY KEY,
last_agg_id INTEGER NOT NULL,
last_time_ms INTEGER NOT NULL,
updated_at TEXT DEFAULT (datetime('now'))
)
""")
def get_last_agg_id(symbol: str) -> Optional[int]:
try:
conn = get_conn()
ensure_meta_table(conn)
row = conn.execute(
"SELECT last_agg_id FROM agg_trades_meta WHERE symbol = ?", (symbol,)
).fetchone()
conn.close()
return row["last_agg_id"] if row else None
with get_sync_conn() as conn:
with conn.cursor() as cur:
cur.execute("SELECT last_agg_id FROM agg_trades_meta WHERE symbol = %s", (symbol,))
row = cur.fetchone()
return row[0] if row else None
except Exception as e:
logger.error(f"get_last_agg_id error: {e}")
return None
def update_meta(conn: sqlite3.Connection, symbol: str, last_agg_id: int, last_time_ms: int):
conn.execute("""
def update_meta(conn, symbol: str, last_agg_id: int, last_time_ms: int):
with conn.cursor() as cur:
cur.execute("""
INSERT INTO agg_trades_meta (symbol, last_agg_id, last_time_ms, updated_at)
VALUES (?, ?, ?, datetime('now'))
VALUES (%s, %s, %s, NOW())
ON CONFLICT(symbol) DO UPDATE SET
last_agg_id = excluded.last_agg_id,
last_time_ms = excluded.last_time_ms,
updated_at = excluded.updated_at
last_agg_id = EXCLUDED.last_agg_id,
last_time_ms = EXCLUDED.last_time_ms,
updated_at = NOW()
""", (symbol, last_agg_id, last_time_ms))
def flush_buffer(symbol: str, trades: list) -> int:
"""写入一批trades,返回实际写入条数(去重后)"""
"""写入一批trades到PG返回实际写入条数"""
if not trades:
return 0
try:
conn = get_conn()
ensure_meta_table(conn)
# Group by month
by_month: dict[str, list] = {}
for t in trades:
tname = table_name(t["T"])
if tname not in by_month:
by_month[tname] = []
by_month[tname].append(t)
# Ensure partitions exist
ensure_partitions()
inserted = 0
with get_sync_conn() as conn:
with conn.cursor() as cur:
# Bulk insert; ON CONFLICT skips duplicates
values = []
last_agg_id = 0
last_time_ms = 0
for tname, batch in by_month.items():
ensure_table(conn, tname)
for t in batch:
cur = conn.execute(
f"""INSERT OR IGNORE INTO {tname}
(agg_id, symbol, price, qty, first_trade_id, last_trade_id, time_ms, is_buyer_maker)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
(
t["a"], # agg_id
symbol,
float(t["p"]),
float(t["q"]),
t.get("f"), # first_trade_id
t.get("l"), # last_trade_id
t["T"], # time_ms
1 if t["m"] else 0, # is_buyer_maker
for t in trades:
agg_id = t["a"]
time_ms = t["T"]
values.append((
agg_id, symbol,
float(t["p"]), float(t["q"]),
time_ms,
1 if t["m"] else 0,
))
if agg_id > last_agg_id:
last_agg_id = agg_id
last_time_ms = time_ms
# Bulk INSERT
psycopg2.extras.execute_values(
cur,
"""INSERT INTO agg_trades (agg_id, symbol, price, qty, time_ms, is_buyer_maker)
VALUES %s
ON CONFLICT (time_ms, symbol, agg_id) DO NOTHING""",
values,
template="(%s, %s, %s, %s, %s, %s)",
page_size=1000,
)
)
if cur.rowcount > 0:
inserted += 1
if t["a"] > last_agg_id:
last_agg_id = t["a"]
last_time_ms = t["T"]
inserted = cur.rowcount
if last_agg_id > 0:
update_meta(conn, symbol, last_agg_id, last_time_ms)
conn.commit()
conn.close()
return inserted
except Exception as e:
logger.error(f"flush_buffer [{symbol}] error: {e}")
@ -171,7 +121,6 @@ def flush_buffer(symbol: str, trades: list) -> int:
# ─── REST gap-filling ────────────────────────────────────────────
async def rest_catchup(symbol: str, from_id: int) -> int:
"""REST-pull from from_id until caught up to latest; returns rows filled"""
total = 0
current_id = from_id
logger.info(f"[{symbol}] REST catchup from agg_id={from_id}")
@ -195,10 +144,9 @@ async def rest_catchup(symbol: str, from_id: int) -> int:
if last <= current_id:
break
current_id = last + 1
# Fewer than 1000 rows returned means we have caught up
if len(data) < 1000:
break
await asyncio.sleep(0.1) # rate-limit friendly
await asyncio.sleep(0.1)
except Exception as e:
logger.error(f"[{symbol}] REST catchup error: {e}")
break
@ -210,7 +158,6 @@ async def rest_catchup(symbol: str, from_id: int) -> int:
# ─── WebSocket collection ───────────────────────────────────────
async def ws_collect(symbol: str):
"""Per-symbol WS collection loop: auto-reconnect + REST gap-filling"""
stream = symbol.lower() + "@aggTrade"
url = f"wss://fstream.binance.com/ws/{stream}"
buffer: list = []
@ -219,14 +166,13 @@ async def ws_collect(symbol: str):
while True:
try:
# REST gap-fill before (re)connecting
last_id = get_last_agg_id(symbol)
if last_id is not None:
await rest_catchup(symbol, last_id + 1)
logger.info(f"[{symbol}] Connecting WS: {url}")
async with websockets.connect(url, ping_interval=20, ping_timeout=10) as ws:
reconnect_delay = 1.0 # reset once connected
reconnect_delay = 1.0
logger.info(f"[{symbol}] WS connected")
async for raw in ws:
@ -236,7 +182,6 @@ async def ws_collect(symbol: str):
buffer.append(msg)
# Batched flush: 200 rows accumulated or 1 second elapsed
now = time.time()
if len(buffer) >= BATCH_SIZE or (now - last_flush) >= BATCH_TIMEOUT:
count = flush_buffer(symbol, buffer)
@ -250,110 +195,90 @@ async def ws_collect(symbol: str):
except Exception as e:
logger.error(f"[{symbol}] WS error: {e}, reconnecting in {reconnect_delay}s")
finally:
# Flush any remaining buffer
if buffer:
flush_buffer(symbol, buffer)
buffer.clear()
await asyncio.sleep(reconnect_delay)
reconnect_delay = min(reconnect_delay * 2, 30) # exponential backoff, max 30s
reconnect_delay = min(reconnect_delay * 2, 30)
# ─── Continuity patrol ──────────────────────────────────────────
async def continuity_check():
"""Patrol every 60s: check each symbol's recent agg_ids for gaps"""
while True:
await asyncio.sleep(60)
try:
conn = get_conn()
ensure_meta_table(conn)
with get_sync_conn() as conn:
with conn.cursor() as cur:
for symbol in SYMBOLS:
row = conn.execute(
"SELECT last_agg_id, last_time_ms FROM agg_trades_meta WHERE symbol = ?",
cur.execute(
"SELECT last_agg_id, last_time_ms FROM agg_trades_meta WHERE symbol = %s",
(symbol,)
).fetchone()
)
row = cur.fetchone()
if not row:
continue
# Check whether the most recent 1000 rows are continuous
now_month = datetime.now(tz=timezone.utc).strftime("%Y%m")
tname = f"agg_trades_{now_month}"
try:
rows = conn.execute(
f"SELECT agg_id FROM {tname} WHERE symbol = ? ORDER BY agg_id DESC LIMIT 100",
# Check whether the most recent 100 rows are continuous
cur.execute(
"SELECT agg_id FROM agg_trades WHERE symbol = %s ORDER BY agg_id DESC LIMIT 100",
(symbol,)
).fetchall()
)
rows = cur.fetchall()
if len(rows) < 2:
continue
ids = [r["agg_id"] for r in rows]
# Check continuity (descending order)
ids = [r[0] for r in rows]
gaps = []
for i in range(len(ids) - 1):
diff = ids[i] - ids[i + 1]
if diff > 1:
gaps.append((ids[i + 1], ids[i], diff - 1))
if gaps:
logger.warning(f"[{symbol}] Found {len(gaps)} gaps in recent data: {gaps[:3]}")
# Trigger gap-fill
logger.warning(f"[{symbol}] Found {len(gaps)} gaps: {gaps[:3]}")
min_gap_id = min(g[0] for g in gaps)
asyncio.create_task(rest_catchup(symbol, min_gap_id))
else:
logger.debug(f"[{symbol}] Continuity OK, last_agg_id={row['last_agg_id']}")
except Exception:
pass # the table may not exist yet
conn.close()
logger.debug(f"[{symbol}] Continuity OK, last_agg_id={row[0]}")
except Exception as e:
logger.error(f"Continuity check error: {e}")
# ─── Daily integrity report ─────────────────────────────────────
# ─── Hourly report ──────────────────────────────────────────────
async def daily_report():
"""Emit an integrity summary log once per hour"""
while True:
await asyncio.sleep(3600)
try:
conn = get_conn()
ensure_meta_table(conn)
report_lines = ["=== AggTrades Integrity Report ==="]
with get_sync_conn() as conn:
with conn.cursor() as cur:
report = ["=== AggTrades Integrity Report ==="]
for symbol in SYMBOLS:
row = conn.execute(
"SELECT last_agg_id, last_time_ms FROM agg_trades_meta WHERE symbol = ?",
cur.execute(
"SELECT last_agg_id, last_time_ms FROM agg_trades_meta WHERE symbol = %s",
(symbol,)
).fetchone()
if not row:
report_lines.append(f" {symbol}: No data yet")
continue
last_dt = datetime.fromtimestamp(row["last_time_ms"] / 1000, tz=timezone.utc).isoformat()
# Count this month's total
now_month = datetime.now(tz=timezone.utc).strftime("%Y%m")
tname = f"agg_trades_{now_month}"
try:
count = conn.execute(
f"SELECT COUNT(*) as c FROM {tname} WHERE symbol = ?", (symbol,)
).fetchone()["c"]
except Exception:
count = 0
report_lines.append(
f" {symbol}: last_agg_id={row['last_agg_id']}, last_time={last_dt}, month_count={count:,}"
)
conn.close()
logger.info("\n".join(report_lines))
row = cur.fetchone()
if not row:
report.append(f" {symbol}: No data yet")
continue
last_dt = datetime.fromtimestamp(row[1] / 1000, tz=timezone.utc).isoformat()
cur.execute(
"SELECT COUNT(*) FROM agg_trades WHERE symbol = %s", (symbol,)
)
count = cur.fetchone()[0]
report.append(f" {symbol}: last_agg_id={row[0]}, last_time={last_dt}, total={count:,}")
logger.info("\n".join(report))
except Exception as e:
logger.error(f"Daily report error: {e}")
logger.error(f"Report error: {e}")
# ─── Entry point ────────────────────────────────────────────────
async def main():
logger.info("AggTrades Collector starting...")
# Ensure base tables exist
conn = get_conn()
ensure_meta_table(conn)
conn.commit()
conn.close()
logger.info("AggTrades Collector (PG) starting...")
# Ensure partitions exist
ensure_partitions()
# Launch WS for every symbol + patrol + report, in parallel
tasks = [
ws_collect(sym) for sym in SYMBOLS
] + [

backend/backfill_agg_trades.py
View File

@ -1,27 +1,18 @@
"""
backfill_agg_trades.py — historical aggTrades backfill script
Features:
- Backfill backwards in history from the earliest agg_id currently in the DB
- Paginated pulls from the Binance REST API, 1000 rows per request
- INSERT OR IGNORE writes into monthly sharded tables
- Resumable: progress checkpointed in agg_trades_meta
- Rate control: sleep 200ms per request, automatic back-off on 429
- BTC+ETH backfilled in parallel
Usage:
python3 backfill_agg_trades.py [--days 30] [--symbol BTCUSDT]
backfill_agg_trades.py — historical aggTrades backfill script (PostgreSQL edition)
"""
import argparse
import logging
import os
import sqlite3
import sys
import time
from datetime import datetime, timezone
import requests
import psycopg2.extras
from db import get_sync_conn, init_schema, ensure_partitions
logging.basicConfig(
level=logging.INFO,
@ -33,119 +24,56 @@ logging.basicConfig(
)
logger = logging.getLogger("backfill")
DB_PATH = os.path.join(os.path.dirname(__file__), "..", "arb.db")
BINANCE_FAPI = "https://fapi.binance.com/fapi/v1"
HEADERS = {"User-Agent": "Mozilla/5.0 ArbitrageEngine/backfill"}
BATCH_SIZE = 1000
SLEEP_MS = 2000 # low priority: 2s per request, yield to the real-time services
SLEEP_MS = 2000
# ─── DB helpers ──────────────────────────────────────────────────
def get_conn() -> sqlite3.Connection:
conn = sqlite3.connect(DB_PATH, timeout=30)
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA synchronous=NORMAL")
return conn
def get_earliest_agg_id(symbol: str) -> int | None:
with get_sync_conn() as conn:
with conn.cursor() as cur:
cur.execute("SELECT earliest_agg_id FROM agg_trades_meta WHERE symbol = %s", (symbol,))
row = cur.fetchone()
if row and row[0]:
return row[0]
# fallback: scan agg_trades
cur.execute("SELECT MIN(agg_id) FROM agg_trades WHERE symbol = %s", (symbol,))
row = cur.fetchone()
return row[0] if row else None
def table_name(ts_ms: int) -> str:
dt = datetime.fromtimestamp(ts_ms / 1000, tz=timezone.utc)
return f"agg_trades_{dt.strftime('%Y%m')}"
def ensure_table(conn: sqlite3.Connection, tname: str):
conn.execute(f"""
CREATE TABLE IF NOT EXISTS {tname} (
agg_id INTEGER PRIMARY KEY,
symbol TEXT NOT NULL,
price REAL NOT NULL,
qty REAL NOT NULL,
time_ms INTEGER NOT NULL,
is_buyer_maker INTEGER NOT NULL
)
""")
conn.execute(f"CREATE INDEX IF NOT EXISTS idx_{tname}_symbol_time ON {tname}(symbol, time_ms)")
def ensure_meta(conn: sqlite3.Connection):
conn.execute("""
CREATE TABLE IF NOT EXISTS agg_trades_meta (
symbol TEXT PRIMARY KEY,
last_agg_id INTEGER,
last_time_ms INTEGER,
earliest_agg_id INTEGER,
earliest_time_ms INTEGER
)
""")
# Add the earliest columns if the old table lacks them
try:
conn.execute("ALTER TABLE agg_trades_meta ADD COLUMN earliest_agg_id INTEGER")
conn.execute("ALTER TABLE agg_trades_meta ADD COLUMN earliest_time_ms INTEGER")
except Exception:
pass # already exist
def get_earliest_agg_id(conn: sqlite3.Connection, symbol: str) -> int | None:
"""查找DB中该symbol的最小agg_id"""
# 先查meta表
row = conn.execute(
"SELECT earliest_agg_id FROM agg_trades_meta WHERE symbol = ?", (symbol,)
).fetchone()
if row and row["earliest_agg_id"]:
return row["earliest_agg_id"]
# Not in meta: scan every monthly table
tables = conn.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'agg_trades_2%'"
).fetchall()
min_id = None
for t in tables:
r = conn.execute(
f"SELECT MIN(agg_id) as mid FROM {t['name']} WHERE symbol = ?", (symbol,)
).fetchone()
if r and r["mid"] is not None:
if min_id is None or r["mid"] < min_id:
min_id = r["mid"]
return min_id
def update_earliest_meta(conn: sqlite3.Connection, symbol: str, agg_id: int, time_ms: int):
# First check whether this symbol already has a row
row = conn.execute("SELECT symbol FROM agg_trades_meta WHERE symbol = ?", (symbol,)).fetchone()
if row:
conn.execute("""
def update_earliest_meta(symbol: str, agg_id: int, time_ms: int):
with get_sync_conn() as conn:
with conn.cursor() as cur:
cur.execute("SELECT symbol FROM agg_trades_meta WHERE symbol = %s", (symbol,))
if cur.fetchone():
cur.execute("""
UPDATE agg_trades_meta SET
earliest_agg_id = MIN(?, COALESCE(earliest_agg_id, ?)),
earliest_time_ms = MIN(?, COALESCE(earliest_time_ms, ?))
WHERE symbol = ?
earliest_agg_id = LEAST(%s, COALESCE(earliest_agg_id, %s)),
earliest_time_ms = LEAST(%s, COALESCE(earliest_time_ms, %s))
WHERE symbol = %s
""", (agg_id, agg_id, time_ms, time_ms, symbol))
else:
conn.execute("""
cur.execute("""
INSERT INTO agg_trades_meta (symbol, last_agg_id, last_time_ms, earliest_agg_id, earliest_time_ms)
VALUES (?, ?, ?, ?, ?)
VALUES (%s, %s, %s, %s, %s)
""", (symbol, agg_id, time_ms, agg_id, time_ms))
conn.commit()
# ─── REST API ────────────────────────────────────────────────────
def fetch_agg_trades(symbol: str, from_id: int | None = None,
start_time: int | None = None, limit: int = 1000) -> list[dict]:
"""从Binance拉取aggTrades支持fromId和startTime"""
def fetch_agg_trades(symbol: str, from_id: int | None = None, limit: int = 1000) -> list:
params = {"symbol": symbol, "limit": limit}
if from_id is not None:
params["fromId"] = from_id
elif start_time is not None:
params["startTime"] = start_time
for attempt in range(5):
try:
r = requests.get(
f"{BINANCE_FAPI}/aggTrades",
params=params, headers=HEADERS, timeout=15
)
r = requests.get(f"{BINANCE_FAPI}/aggTrades", params=params, headers=HEADERS, timeout=15)
if r.status_code == 429:
wait = min(60 * (2 ** attempt), 300)
logger.warning(f"Rate limited (429), waiting {wait}s (attempt {attempt+1})")
@ -162,24 +90,22 @@ def fetch_agg_trades(symbol: str, from_id: int | None = None,
return []
def write_batch(conn: sqlite3.Connection, symbol: str, trades: list[dict]) -> int:
"""批量写入,返回实际插入条数"""
inserted = 0
tables_seen = set()
for t in trades:
tname = table_name(t["T"])
if tname not in tables_seen:
ensure_table(conn, tname)
tables_seen.add(tname)
try:
conn.execute(
f"INSERT OR IGNORE INTO {tname} (agg_id, symbol, price, qty, time_ms, is_buyer_maker) "
f"VALUES (?, ?, ?, ?, ?, ?)",
(t["a"], symbol, float(t["p"]), float(t["q"]), t["T"], 1 if t["m"] else 0)
def write_batch(symbol: str, trades: list) -> int:
if not trades:
return 0
ensure_partitions()
values = [(t["a"], symbol, float(t["p"]), float(t["q"]), t["T"], 1 if t["m"] else 0) for t in trades]
with get_sync_conn() as conn:
with conn.cursor() as cur:
psycopg2.extras.execute_values(
cur,
"INSERT INTO agg_trades (agg_id, symbol, price, qty, time_ms, is_buyer_maker) "
"VALUES %s ON CONFLICT (time_ms, symbol, agg_id) DO NOTHING",
values,
template="(%s, %s, %s, %s, %s, %s)",
page_size=1000,
)
inserted += 1
except sqlite3.IntegrityError:
pass
inserted = cur.rowcount
conn.commit()
return inserted
@ -187,42 +113,24 @@ def write_batch(conn: sqlite3.Connection, symbol: str, trades: list[dict]) -> in
# ─── Main logic ─────────────────────────────────────────────────
def backfill_symbol(symbol: str, target_days: int | None = None):
"""回补单个symbol的历史数据"""
conn = get_conn()
ensure_meta(conn)
earliest = get_earliest_agg_id(conn, symbol)
earliest = get_earliest_agg_id(symbol)
if earliest is None:
logger.info(f"[{symbol}] DB无数据从最新开始拉最近的数据确定起点")
logger.info(f"[{symbol}] DB无数据拉最新确定起点")
trades = fetch_agg_trades(symbol, limit=1)
if not trades:
logger.error(f"[{symbol}] 无法获取起始数据")
return
earliest = trades[0]["a"]
logger.info(f"[{symbol}] 当前最新agg_id: {earliest}")
# 计算目标:往前补到多少天前
if target_days:
target_ts = int((time.time() - target_days * 86400) * 1000)
else:
target_ts = 0 # pull back to the very beginning
target_ts = int((time.time() - target_days * 86400) * 1000) if target_days else 0
logger.info(f"[{symbol}] 开始回补当前最早agg_id={earliest}")
if target_days:
target_dt = datetime.fromtimestamp(target_ts / 1000, tz=timezone.utc)
logger.info(f"[{symbol}] 目标:回补到 {target_dt.strftime('%Y-%m-%d')} ({target_days}天前)")
total_inserted = 0
total_requests = 0
rate_limit_hits = 0
current_from_id = earliest
while True:
# Pull backwards from current_from_id (first to around current_from_id - BATCH_SIZE*2).
# Binance's fromId pulls forward from the given id, so an endTime-based approach would also work;
# more reliable: use fromId = current_from_id - BATCH_SIZE, then read forward.
fetch_from = max(0, current_from_id - BATCH_SIZE)
trades = fetch_agg_trades(symbol, from_id=fetch_from, limit=BATCH_SIZE)
total_requests += 1
@ -230,104 +138,56 @@ def backfill_symbol(symbol: str, target_days: int | None = None):
logger.info(f"[{symbol}] 无更多数据,回补结束")
break
# Drop rows we already have (agg_id >= current_from_id is covered by data already in the DB)
new_trades = [t for t in trades if t["a"] < current_from_id]
if not new_trades:
# No earlier data left
logger.info(f"[{symbol}] reached the earliest data; backfill finished")
break
inserted = write_batch(conn, symbol, new_trades)
inserted = write_batch(symbol, new_trades)
total_inserted += inserted
oldest = min(new_trades, key=lambda x: x["a"])
oldest_time = datetime.fromtimestamp(oldest["T"] / 1000, tz=timezone.utc)
current_from_id = oldest["a"]
update_earliest_meta(symbol, oldest["a"], oldest["T"])
# Update meta
update_earliest_meta(conn, symbol, oldest["a"], oldest["T"])
# Progress log: once every 50 batches
if total_requests % 50 == 0:
logger.info(
f"[{symbol}] 进度: {total_inserted:,} 条已插入, "
f"当前位置: {oldest_time.strftime('%Y-%m-%d %H:%M')}, "
f"agg_id={current_from_id:,}, "
f"请求数={total_requests}, 429次数={rate_limit_hits}"
f"当前位置: {oldest_time.strftime('%Y-%m-%d %H:%M')}, agg_id={current_from_id:,}, 请求数={total_requests}"
)
# Check whether the target time has been reached
if target_ts and oldest["T"] <= target_ts:
logger.info(f"[{symbol}] target time reached; backfill finished")
break
time.sleep(SLEEP_MS / 1000)
conn.close()
logger.info(
f"[{symbol}] backfill complete: "
f"total inserted={total_inserted:,}, total requests={total_requests}, "
f"429s={rate_limit_hits}"
)
return total_inserted, total_requests, rate_limit_hits
logger.info(f"[{symbol}] backfill complete: total inserted={total_inserted:,}, total requests={total_requests}")
def check_continuity(symbol: str):
"""检查agg_id连续性"""
conn = get_conn()
tables = conn.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'agg_trades_2%' ORDER BY name"
).fetchall()
all_ids = []
for t in tables:
rows = conn.execute(
f"SELECT agg_id FROM {t['name']} WHERE symbol = ? ORDER BY agg_id",
(symbol,)
).fetchall()
all_ids.extend(r["agg_id"] for r in rows)
conn.close()
if not all_ids:
with get_sync_conn() as conn:
with conn.cursor() as cur:
cur.execute("SELECT COUNT(*), MIN(agg_id), MAX(agg_id) FROM agg_trades WHERE symbol = %s", (symbol,))
row = cur.fetchone()
total, min_id, max_id = row
if not total:
logger.info(f"[{symbol}] 无数据")
return
all_ids.sort()
gaps = 0
gap_ranges = []
for i in range(1, len(all_ids)):
diff = all_ids[i] - all_ids[i-1]
if diff > 1:
gaps += 1
if len(gap_ranges) < 10:
gap_ranges.append((all_ids[i-1], all_ids[i], diff - 1))
total = len(all_ids)
span = all_ids[-1] - all_ids[0] + 1
span = max_id - min_id + 1
coverage = total / span * 100 if span > 0 else 0
logger.info(f"[{symbol}] 总条数={total:,}, ID范围={min_id:,}~{max_id:,}, 覆盖率={coverage:.2f}%")
logger.info(
f"[{symbol}] 连续性检查: "
f"总条数={total:,}, ID范围={all_ids[0]:,}~{all_ids[-1]:,}, "
f"理论条数={span:,}, 覆盖率={coverage:.2f}%, 缺口数={gaps}"
)
if gap_ranges:
for start, end, missing in gap_ranges[:5]:
logger.info(f" 缺口: {start}{end} (缺{missing}条)")
# ─── CLI ─────────────────────────────────────────────────────────
def main():
parser = argparse.ArgumentParser(description="Backfill aggTrades from Binance")
parser.add_argument("--days", type=int, default=None, help="回补天数(默认全量)")
parser.add_argument("--symbol", type=str, default=None, help="指定symbol默认BTC+ETH")
parser.add_argument("--check", action="store_true", help="仅检查连续性")
parser = argparse.ArgumentParser()
parser.add_argument("--days", type=int, default=None)
parser.add_argument("--symbol", type=str, default=None)
parser.add_argument("--check", action="store_true")
args = parser.parse_args()
init_schema()
symbols = [args.symbol] if args.symbol else ["BTCUSDT", "ETHUSDT"]
if args.check:
@ -336,18 +196,12 @@ def main():
return
logger.info(f"=== 开始回补 === symbols={symbols}, days={args.days or '全量'}")
for sym in symbols:
logger.info(f"--- 回补 {sym} ---")
result = backfill_symbol(sym, target_days=args.days)
if result:
inserted, reqs, limits = result
logger.info(f"[{sym}] 结果: 插入{inserted:,}条, {reqs}次请求, {limits}次429")
backfill_symbol(sym, target_days=args.days)
logger.info("=== 回补完成,开始连续性检查 ===")
logger.info("=== 连续性检查 ===")
for sym in symbols:
check_continuity(sym)
logger.info("=== 全部完成 ===")

backend/db.py Normal file (282 lines)
View File

@ -0,0 +1,282 @@
"""
db.py — PostgreSQL database connection layer
Sync connection pool (psycopg2) for script-style processes
Async connection pool (asyncpg) for FastAPI
"""
import os
import asyncio
import asyncpg
import psycopg2
import psycopg2.pool
from contextlib import contextmanager
# PG connection parameters
PG_HOST = os.getenv("PG_HOST", "127.0.0.1")
PG_PORT = int(os.getenv("PG_PORT", 5432))
PG_DB = os.getenv("PG_DB", "arb_engine")
PG_USER = os.getenv("PG_USER", "arb")
PG_PASS = os.getenv("PG_PASS", "arb_engine_2026")
PG_DSN = f"postgresql://{PG_USER}:{PG_PASS}@{PG_HOST}:{PG_PORT}/{PG_DB}"
# ─── Sync connection pool (psycopg2) ────────────────────────────
_sync_pool = None
def get_sync_pool() -> psycopg2.pool.ThreadedConnectionPool:
global _sync_pool
if _sync_pool is None:
_sync_pool = psycopg2.pool.ThreadedConnectionPool(
minconn=1, maxconn=5,
host=PG_HOST, port=PG_PORT,
dbname=PG_DB, user=PG_USER, password=PG_PASS,
)
return _sync_pool
@contextmanager
def get_sync_conn():
"""Borrow a sync PG connection; automatically returned to the pool"""
pool = get_sync_pool()
conn = pool.getconn()
try:
yield conn
except Exception:
conn.rollback() # never hand back a connection stuck in an aborted transaction
raise
finally:
pool.putconn(conn)
def sync_execute(sql: str, params=None, fetch=False):
"""简便执行SQL"""
with get_sync_conn() as conn:
with conn.cursor() as cur:
cur.execute(sql, params)
if fetch:
cols = [desc[0] for desc in cur.description]
return [dict(zip(cols, row)) for row in cur.fetchall()]
conn.commit()
def sync_executemany(sql: str, params_list: list):
"""批量执行"""
with get_sync_conn() as conn:
with conn.cursor() as cur:
cur.executemany(sql, params_list)
conn.commit()
# ─── Async connection pool (asyncpg) ────────────────────────────
_async_pool: asyncpg.Pool | None = None
async def get_async_pool() -> asyncpg.Pool:
global _async_pool
if _async_pool is None:
_async_pool = await asyncpg.create_pool(
dsn=PG_DSN, min_size=2, max_size=10,
)
return _async_pool
async def async_fetch(sql: str, *args) -> list[dict]:
"""异步查询返回dict列表"""
pool = await get_async_pool()
async with pool.acquire() as conn:
rows = await conn.fetch(sql, *args)
return [dict(r) for r in rows]
async def async_fetchrow(sql: str, *args) -> dict | None:
"""异步查询单行"""
pool = await get_async_pool()
async with pool.acquire() as conn:
row = await conn.fetchrow(sql, *args)
return dict(row) if row else None
async def async_execute(sql: str, *args):
"""异步执行"""
pool = await get_async_pool()
async with pool.acquire() as conn:
await conn.execute(sql, *args)
async def close_async_pool():
global _async_pool
if _async_pool:
await _async_pool.close()
_async_pool = None
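A minimal usage sketch for the async helpers above (illustration only; note asyncpg's $1-style placeholders and positional arguments):

import asyncio
from db import async_execute, async_fetch, close_async_pool

async def demo():
    await async_execute(
        "INSERT INTO rate_snapshots (ts, btc_rate, eth_rate, btc_price, eth_price) "
        "VALUES ($1, $2, $3, $4, $5)",
        1700000000, 0.0001, 0.0001, 43000.0, 2300.0,
    )
    rows = await async_fetch(
        "SELECT ts, btc_rate FROM rate_snapshots WHERE ts >= $1 ORDER BY ts DESC LIMIT $2",
        1700000000, 10,
    )
    print(rows)  # list[dict], directly JSON-serializable in FastAPI responses
    await close_async_pool()

asyncio.run(demo())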
# ─── Schema (PG) ─────────────────────────────────────────────────
SCHEMA_SQL = """
-- Funding-rate snapshots
CREATE TABLE IF NOT EXISTS rate_snapshots (
id BIGSERIAL PRIMARY KEY,
ts BIGINT NOT NULL,
btc_rate DOUBLE PRECISION NOT NULL,
eth_rate DOUBLE PRECISION NOT NULL,
btc_price DOUBLE PRECISION NOT NULL,
eth_price DOUBLE PRECISION NOT NULL,
btc_index_price DOUBLE PRECISION,
eth_index_price DOUBLE PRECISION
);
CREATE INDEX IF NOT EXISTS idx_rate_snapshots_ts ON rate_snapshots(ts);
-- aggTrades metadata
CREATE TABLE IF NOT EXISTS agg_trades_meta (
symbol TEXT PRIMARY KEY,
last_agg_id BIGINT NOT NULL,
last_time_ms BIGINT,
earliest_agg_id BIGINT,
earliest_time_ms BIGINT,
updated_at TEXT
);
-- aggTrades main table (partitioned parent)
CREATE TABLE IF NOT EXISTS agg_trades (
agg_id BIGINT NOT NULL,
symbol TEXT NOT NULL,
price DOUBLE PRECISION NOT NULL,
qty DOUBLE PRECISION NOT NULL,
time_ms BIGINT NOT NULL,
is_buyer_maker SMALLINT NOT NULL,
PRIMARY KEY (time_ms, symbol, agg_id)
) PARTITION BY RANGE (time_ms);
CREATE INDEX IF NOT EXISTS idx_agg_trades_sym_time ON agg_trades(symbol, time_ms DESC);
CREATE INDEX IF NOT EXISTS idx_agg_trades_sym_agg ON agg_trades(symbol, agg_id);
-- Signal engine tables
CREATE TABLE IF NOT EXISTS signal_indicators (
id BIGSERIAL PRIMARY KEY,
ts BIGINT NOT NULL,
symbol TEXT NOT NULL,
cvd_fast DOUBLE PRECISION,
cvd_mid DOUBLE PRECISION,
cvd_day DOUBLE PRECISION,
cvd_fast_slope DOUBLE PRECISION,
atr_5m DOUBLE PRECISION,
atr_percentile DOUBLE PRECISION,
vwap_30m DOUBLE PRECISION,
price DOUBLE PRECISION,
p95_qty DOUBLE PRECISION,
p99_qty DOUBLE PRECISION,
buy_vol_1m DOUBLE PRECISION,
sell_vol_1m DOUBLE PRECISION,
score INTEGER,
signal TEXT
);
CREATE INDEX IF NOT EXISTS idx_si_ts ON signal_indicators(ts);
CREATE INDEX IF NOT EXISTS idx_si_sym_ts ON signal_indicators(symbol, ts);
CREATE TABLE IF NOT EXISTS signal_indicators_1m (
id BIGSERIAL PRIMARY KEY,
ts BIGINT NOT NULL,
symbol TEXT NOT NULL,
cvd_fast DOUBLE PRECISION,
cvd_mid DOUBLE PRECISION,
cvd_day DOUBLE PRECISION,
atr_5m DOUBLE PRECISION,
vwap_30m DOUBLE PRECISION,
price DOUBLE PRECISION,
score INTEGER,
signal TEXT
);
CREATE INDEX IF NOT EXISTS idx_si1m_sym_ts ON signal_indicators_1m(symbol, ts);
CREATE TABLE IF NOT EXISTS signal_trades (
id BIGSERIAL PRIMARY KEY,
ts_open BIGINT NOT NULL,
ts_close BIGINT,
symbol TEXT NOT NULL,
direction TEXT NOT NULL,
entry_price DOUBLE PRECISION,
exit_price DOUBLE PRECISION,
qty DOUBLE PRECISION,
score INTEGER,
pnl DOUBLE PRECISION,
sl_price DOUBLE PRECISION,
tp1_price DOUBLE PRECISION,
tp2_price DOUBLE PRECISION,
status TEXT DEFAULT 'open'
);
-- Signal logs (legacy-table compatible)
CREATE TABLE IF NOT EXISTS signal_logs (
id BIGSERIAL PRIMARY KEY,
symbol TEXT,
rate DOUBLE PRECISION,
annualized DOUBLE PRECISION,
sent_at TEXT,
message TEXT
);
-- Users (auth)
CREATE TABLE IF NOT EXISTS users (
id BIGSERIAL PRIMARY KEY,
email TEXT UNIQUE NOT NULL,
password_hash TEXT NOT NULL,
role TEXT DEFAULT 'user',
created_at TIMESTAMP DEFAULT NOW()
);
CREATE TABLE IF NOT EXISTS invite_codes (
id BIGSERIAL PRIMARY KEY,
code TEXT UNIQUE NOT NULL,
used_by BIGINT REFERENCES users(id),
created_at TIMESTAMP DEFAULT NOW()
);
"""
def ensure_partitions():
"""Create partitions for the current month and the next two"""
import datetime
now = datetime.datetime.now(datetime.timezone.utc)
months = []
year, month = now.year, now.month
for _ in range(3): # current month + next two, calendar-safe (no 30-day drift across month ends)
months.append(f"{year:04d}{month:02d}")
month += 1
if month > 12:
month, year = 1, year + 1
with get_sync_conn() as conn:
with conn.cursor() as cur:
for m in months:
year = int(m[:4])
month = int(m[4:])
# Partition bounds as UTC millisecond timestamps
start = datetime.datetime(year, month, 1, tzinfo=datetime.timezone.utc)
if month == 12:
end = datetime.datetime(year + 1, 1, 1, tzinfo=datetime.timezone.utc)
else:
end = datetime.datetime(year, month + 1, 1, tzinfo=datetime.timezone.utc)
start_ms = int(start.timestamp() * 1000)
end_ms = int(end.timestamp() * 1000)
part_name = f"agg_trades_{m}"
try:
cur.execute(f"""
CREATE TABLE IF NOT EXISTS {part_name}
PARTITION OF agg_trades
FOR VALUES FROM ({start_ms}) TO ({end_ms})
""")
conn.commit() # commit per partition so one failure cannot roll back the others
except Exception:
conn.rollback() # partition already exists (or overlaps); reset the aborted transaction
def init_schema():
"""初始化全部PG表结构"""
with get_sync_conn() as conn:
with conn.cursor() as cur:
for stmt in SCHEMA_SQL.split(";"):
stmt = stmt.strip()
if stmt:
try:
cur.execute(stmt)
except Exception as e:
conn.rollback()
# ignore already-exists errors
continue
conn.commit()
ensure_partitions()
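Taken together, a typical bootstrap for any process in this repo would look like the sketch below; both helpers are idempotent, so calling them on every start is safe:

from db import init_schema, ensure_partitions, get_sync_conn

init_schema()        # CREATE TABLE IF NOT EXISTS for every table, then initial partitions
ensure_partitions()  # worth repeating from a periodic job: only ~3 months of partitions exist at a time

with get_sync_conn() as conn:
    with conn.cursor() as cur:
        cur.execute("SELECT COUNT(*) FROM agg_trades")
        print(cur.fetchone()[0])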

backend/main.py
View File

@ -2,9 +2,13 @@ from fastapi import FastAPI, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
import httpx
from datetime import datetime, timedelta
import asyncio, time, sqlite3, os
import asyncio, time, os
from auth import router as auth_router, get_current_user, ensure_tables as ensure_auth_tables
from db import (
init_schema, ensure_partitions, get_async_pool, async_fetch, async_fetchrow, async_execute,
close_async_pool,
)
import datetime as _dt
app = FastAPI(title="Arbitrage Engine API")
@ -21,7 +25,6 @@ app.include_router(auth_router)
BINANCE_FAPI = "https://fapi.binance.com/fapi/v1"
SYMBOLS = ["BTCUSDT", "ETHUSDT"]
HEADERS = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}
DB_PATH = os.path.join(os.path.dirname(__file__), "..", "arb.db")
# Simple in-memory cache (history/stats: 60s, rates: 3s)
_cache: dict = {}
@ -36,33 +39,13 @@ def set_cache(key: str, data):
_cache[key] = {"ts": time.time(), "data": data}
def init_db():
conn = sqlite3.connect(DB_PATH)
conn.execute("""
CREATE TABLE IF NOT EXISTS rate_snapshots (
id INTEGER PRIMARY KEY AUTOINCREMENT,
ts INTEGER NOT NULL,
btc_rate REAL NOT NULL,
eth_rate REAL NOT NULL,
btc_price REAL NOT NULL,
eth_price REAL NOT NULL,
btc_index_price REAL,
eth_index_price REAL
)
""")
conn.execute("CREATE INDEX IF NOT EXISTS idx_rate_snapshots_ts ON rate_snapshots(ts)")
conn.commit()
conn.close()
def save_snapshot(rates: dict):
async def save_snapshot(rates: dict):
try:
conn = sqlite3.connect(DB_PATH)
btc = rates.get("BTC", {})
eth = rates.get("ETH", {})
conn.execute(
"INSERT INTO rate_snapshots (ts, btc_rate, eth_rate, btc_price, eth_price, btc_index_price, eth_index_price) VALUES (?,?,?,?,?,?,?)",
(
await async_execute(
"INSERT INTO rate_snapshots (ts, btc_rate, eth_rate, btc_price, eth_price, btc_index_price, eth_index_price) "
"VALUES ($1,$2,$3,$4,$5,$6,$7)",
int(time.time()),
float(btc.get("lastFundingRate", 0)),
float(eth.get("lastFundingRate", 0)),
@ -71,15 +54,12 @@ def save_snapshot(rates: dict):
float(btc.get("indexPrice", 0)),
float(eth.get("indexPrice", 0)),
)
)
conn.commit()
conn.close()
except Exception as e:
pass # persistence failures must not affect API responses
async def background_snapshot_loop():
"""后台每2秒自动拉取费率+价格并落库,不依赖前端调用"""
"""后台每2秒自动拉取费率+价格并落库"""
while True:
try:
async with httpx.AsyncClient(timeout=5, headers=HEADERS) as client:
@ -97,7 +77,7 @@ async def background_snapshot_loop():
"indexPrice": float(data["indexPrice"]),
}
if result:
save_snapshot(result)
await save_snapshot(result)
except Exception:
pass
await asyncio.sleep(2)
@ -105,11 +85,19 @@ async def background_snapshot_loop():
@app.on_event("startup")
async def startup():
init_db()
# Initialize the PG schema
init_schema()
ensure_auth_tables()
# Initialize the asyncpg pool
await get_async_pool()
asyncio.create_task(background_snapshot_loop())
@app.on_event("shutdown")
async def shutdown():
await close_async_pool()
@app.get("/api/health")
async def health():
return {"status": "ok", "timestamp": datetime.utcnow().isoformat()}
@ -137,37 +125,23 @@ async def get_rates():
"timestamp": data["time"],
}
set_cache("rates", result)
# Persist asynchronously (does not block the response)
asyncio.create_task(asyncio.to_thread(save_snapshot, result))
asyncio.create_task(save_snapshot(result))
return result
@app.get("/api/snapshots")
async def get_snapshots(hours: int = 24, limit: int = 5000, user: dict = Depends(get_current_user)):
"""查询本地落库的实时快照数据"""
since = int(time.time()) - hours * 3600
conn = sqlite3.connect(DB_PATH)
conn.row_factory = sqlite3.Row
rows = conn.execute(
"SELECT ts, btc_rate, eth_rate, btc_price, eth_price FROM rate_snapshots WHERE ts >= ? ORDER BY ts ASC LIMIT ?",
(since, limit)
).fetchall()
conn.close()
return {
"count": len(rows),
"hours": hours,
"data": [dict(r) for r in rows]
}
rows = await async_fetch(
"SELECT ts, btc_rate, eth_rate, btc_price, eth_price FROM rate_snapshots "
"WHERE ts >= $1 ORDER BY ts ASC LIMIT $2",
since, limit
)
return {"count": len(rows), "hours": hours, "data": rows}
@app.get("/api/kline")
async def get_kline(symbol: str = "BTC", interval: str = "5m", limit: int = 500, user: dict = Depends(get_current_user)):
"""
Aggregate K-line data from rate_snapshots
symbol: BTC | ETH
interval: 1m | 5m | 30m | 1h | 4h | 8h | 1d | 1w | 1M
Returns: [{time, open, high, low, close, price_open, price_high, price_low, price_close}]
"""
interval_secs = {
"1m": 60, "5m": 300, "30m": 1800,
"1h": 3600, "4h": 14400, "8h": 28800,
@ -176,22 +150,20 @@ async def get_kline(symbol: str = "BTC", interval: str = "5m", limit: int = 500,
bar_secs = interval_secs.get(interval, 300)
rate_col = "btc_rate" if symbol.upper() == "BTC" else "eth_rate"
price_col = "btc_price" if symbol.upper() == "BTC" else "eth_price"
# Fetch enough raw data (limit bars * bar_secs covers the widest needed range)
since = int(time.time()) - bar_secs * limit
conn = sqlite3.connect(DB_PATH)
rows = conn.execute(
f"SELECT ts, {rate_col} as rate, {price_col} as price FROM rate_snapshots WHERE ts >= ? ORDER BY ts ASC",
(since,)
).fetchall()
conn.close()
rows = await async_fetch(
f"SELECT ts, {rate_col} as rate, {price_col} as price FROM rate_snapshots "
f"WHERE ts >= $1 ORDER BY ts ASC",
since
)
if not rows:
return {"symbol": symbol, "interval": interval, "data": []}
# Group by bar_secs and aggregate OHLC
bars: dict = {}
for ts, rate, price in rows:
for r in rows:
ts, rate, price = r["ts"], r["rate"], r["price"]
bar_ts = (ts // bar_secs) * bar_secs
if bar_ts not in bars:
bars[bar_ts] = {
@ -209,20 +181,16 @@ async def get_kline(symbol: str = "BTC", interval: str = "5m", limit: int = 500,
b["price_close"] = price
data = sorted(bars.values(), key=lambda x: x["time"])[-limit:]
# Convert to 1/10000 units (rate × 10000)
for b in data:
for k in ("open", "high", "low", "close"):
b[k] = round(b[k] * 10000, 4)
return {"symbol": symbol, "interval": interval, "count": len(data), "data": data}
@app.get("/api/stats/ytd")
async def get_stats_ytd(user: dict = Depends(get_current_user)):
"""今年以来(YTD)资金费率年化统计"""
cached = get_cache("stats_ytd", 3600)
if cached: return cached
# Jan 1 of the current year, 00:00 UTC
import datetime
year_start = int(datetime.datetime(datetime.datetime.utcnow().year, 1, 1).timestamp() * 1000)
end_time = int(time.time() * 1000)
@ -252,16 +220,12 @@ async def get_stats_ytd(user: dict = Depends(get_current_user)):
@app.get("/api/signals/history")
async def get_signals_history(limit: int = 100, user: dict = Depends(get_current_user)):
"""查询信号推送历史"""
try:
conn = sqlite3.connect(DB_PATH)
conn.row_factory = sqlite3.Row
rows = conn.execute(
"SELECT id, symbol, rate, annualized, sent_at, message FROM signal_logs ORDER BY sent_at DESC LIMIT ?",
(limit,)
).fetchall()
conn.close()
return {"items": [dict(r) for r in rows]}
rows = await async_fetch(
"SELECT id, symbol, rate, annualized, sent_at, message FROM signal_logs ORDER BY sent_at DESC LIMIT $1",
limit
)
return {"items": rows}
except Exception as e:
return {"items": [], "error": str(e)}
@ -334,20 +298,11 @@ async def get_stats(user: dict = Depends(get_current_user)):
return stats
# ─── aggTrades query endpoints ──────────────────────────────────
# ─── aggTrades query endpoints (PG) ─────────────────────────────
@app.get("/api/trades/meta")
async def get_trades_meta(user: dict = Depends(get_current_user)):
"""aggTrades采集状态各symbol最新agg_id和时间"""
conn = sqlite3.connect(DB_PATH)
conn.row_factory = sqlite3.Row
try:
rows = conn.execute(
"SELECT symbol, last_agg_id, last_time_ms, updated_at FROM agg_trades_meta"
).fetchall()
except Exception:
rows = []
conn.close()
rows = await async_fetch("SELECT symbol, last_agg_id, last_time_ms, updated_at FROM agg_trades_meta")
result = {}
for r in rows:
sym = r["symbol"].replace("USDT", "")
@ -367,46 +322,23 @@ async def get_trades_summary(
interval: str = "1m",
user: dict = Depends(get_current_user),
):
"""分钟级聚合买卖delta/成交速率/vwap"""
if end_ms == 0:
end_ms = int(time.time() * 1000)
if start_ms == 0:
start_ms = end_ms - 3600 * 1000 # default: 1 hour
start_ms = end_ms - 3600 * 1000
interval_ms = {"1m": 60000, "5m": 300000, "15m": 900000, "1h": 3600000}.get(interval, 60000)
sym_full = symbol.upper() + "USDT"
# Determine which monthly tables need querying
start_dt = _dt.datetime.fromtimestamp(start_ms / 1000, tz=_dt.timezone.utc)
end_dt = _dt.datetime.fromtimestamp(end_ms / 1000, tz=_dt.timezone.utc)
months = set()
cur = start_dt.replace(day=1)
while cur <= end_dt:
months.add(cur.strftime("%Y%m"))
if cur.month == 12:
cur = cur.replace(year=cur.year + 1, month=1)
else:
cur = cur.replace(month=cur.month + 1)
# PG partition pruning handles this automatically; query the parent table directly
rows = await async_fetch(
"SELECT agg_id, price, qty, time_ms, is_buyer_maker FROM agg_trades "
"WHERE symbol = $1 AND time_ms >= $2 AND time_ms < $3 ORDER BY time_ms ASC",
sym_full, start_ms, end_ms
)
conn = sqlite3.connect(DB_PATH)
conn.row_factory = sqlite3.Row
all_rows = []
for month in sorted(months):
tname = f"agg_trades_{month}"
try:
rows = conn.execute(
f"SELECT agg_id, price, qty, time_ms, is_buyer_maker FROM {tname} "
f"WHERE symbol = ? AND time_ms >= ? AND time_ms < ? ORDER BY time_ms ASC",
(sym_full, start_ms, end_ms)
).fetchall()
all_rows.extend(rows)
except Exception:
pass
conn.close()
# Aggregate by interval
bars: dict = {}
for row in all_rows:
for row in rows:
bar_ms = (row["time_ms"] // interval_ms) * interval_ms
if bar_ms not in bars:
bars[bar_ms] = {"time_ms": bar_ms, "buy_vol": 0.0, "sell_vol": 0.0,
@ -414,7 +346,7 @@ async def get_trades_summary(
b = bars[bar_ms]
qty = float(row["qty"])
price = float(row["price"])
if row["is_buyer_maker"] == 0: # 主动买
if row["is_buyer_maker"] == 0:
b["buy_vol"] += qty
else:
b["sell_vol"] += qty
@ -446,47 +378,28 @@ async def get_trades_latest(
limit: int = 30,
user: dict = Depends(get_current_user),
):
"""查最新N条原始成交记录从本地DB实时刷新用"""
sym_full = symbol.upper() + "USDT"
now_month = _dt.datetime.now(_dt.timezone.utc).strftime("%Y%m")
tname = f"agg_trades_{now_month}"
conn = sqlite3.connect(DB_PATH)
conn.row_factory = sqlite3.Row
try:
rows = conn.execute(
f"SELECT agg_id, price, qty, time_ms, is_buyer_maker FROM {tname} "
f"WHERE symbol = ? ORDER BY agg_id DESC LIMIT ?",
(sym_full, limit)
).fetchall()
except Exception:
rows = []
conn.close()
return {
"symbol": symbol,
"count": len(rows),
"data": [dict(r) for r in rows],
}
rows = await async_fetch(
"SELECT agg_id, price, qty, time_ms, is_buyer_maker FROM agg_trades "
"WHERE symbol = $1 ORDER BY time_ms DESC, agg_id DESC LIMIT $2",
sym_full, limit
)
return {"symbol": symbol, "count": len(rows), "data": rows}
@app.get("/api/collector/health")
async def collector_health(user: dict = Depends(get_current_user)):
"""采集器健康状态"""
conn = sqlite3.connect(DB_PATH)
conn.row_factory = sqlite3.Row
now_ms = int(time.time() * 1000)
rows = await async_fetch("SELECT symbol, last_agg_id, last_time_ms FROM agg_trades_meta")
status = {}
try:
rows = conn.execute(
"SELECT symbol, last_agg_id, last_time_ms FROM agg_trades_meta"
).fetchall()
for r in rows:
sym = r["symbol"].replace("USDT", "")
lag_s = (now_ms - r["last_time_ms"]) / 1000
lag_s = (now_ms - (r["last_time_ms"] or 0)) / 1000
status[sym] = {
"last_agg_id": r["last_agg_id"],
"lag_seconds": round(lag_s, 1),
"healthy": lag_s < 30, # 30秒内有数据算健康
"healthy": lag_s < 30,
}
except Exception:
pass
conn.close()
return {"collector": status, "timestamp": now_ms}
@ -498,50 +411,29 @@ async def get_signal_indicators(
minutes: int = 60,
user: dict = Depends(get_current_user),
):
"""获取signal_indicators_1m数据前端图表用"""
sym_full = symbol.upper() + "USDT"
now_ms = int(time.time() * 1000)
start_ms = now_ms - minutes * 60 * 1000
conn = sqlite3.connect(DB_PATH)
conn.row_factory = sqlite3.Row
try:
rows = conn.execute(
rows = await async_fetch(
"SELECT ts, cvd_fast, cvd_mid, cvd_day, atr_5m, vwap_30m, price, score, signal "
"FROM signal_indicators_1m WHERE symbol = ? AND ts >= ? ORDER BY ts ASC",
(sym_full, start_ms)
).fetchall()
except Exception:
rows = []
conn.close()
return {
"symbol": symbol,
"count": len(rows),
"data": [dict(r) for r in rows],
}
"FROM signal_indicators_1m WHERE symbol = $1 AND ts >= $2 ORDER BY ts ASC",
sym_full, start_ms
)
return {"symbol": symbol, "count": len(rows), "data": rows}
@app.get("/api/signals/latest")
async def get_signal_latest(
user: dict = Depends(get_current_user),
):
"""获取最新一条各symbol的指标快照"""
conn = sqlite3.connect(DB_PATH)
conn.row_factory = sqlite3.Row
async def get_signal_latest(user: dict = Depends(get_current_user)):
result = {}
for sym in ["BTCUSDT", "ETHUSDT"]:
try:
row = conn.execute(
row = await async_fetchrow(
"SELECT ts, cvd_fast, cvd_mid, cvd_day, cvd_fast_slope, atr_5m, atr_percentile, "
"vwap_30m, price, p95_qty, p99_qty, score, signal "
"FROM signal_indicators WHERE symbol = ? ORDER BY ts DESC LIMIT 1",
(sym,)
).fetchone()
"FROM signal_indicators WHERE symbol = $1 ORDER BY ts DESC LIMIT 1",
sym
)
if row:
result[sym.replace("USDT", "")] = dict(row)
except Exception:
pass
conn.close()
result[sym.replace("USDT", "")] = row
return result
@ -551,20 +443,13 @@ async def get_signal_trades(
limit: int = 50,
user: dict = Depends(get_current_user),
):
"""获取信号交易记录"""
conn = sqlite3.connect(DB_PATH)
conn.row_factory = sqlite3.Row
try:
if status == "all":
rows = conn.execute(
"SELECT * FROM signal_trades ORDER BY ts_open DESC LIMIT ?", (limit,)
).fetchall()
rows = await async_fetch(
"SELECT * FROM signal_trades ORDER BY ts_open DESC LIMIT $1", limit
)
else:
rows = conn.execute(
"SELECT * FROM signal_trades WHERE status = ? ORDER BY ts_open DESC LIMIT ?",
(status, limit)
).fetchall()
except Exception:
rows = []
conn.close()
return {"count": len(rows), "data": [dict(r) for r in rows]}
rows = await async_fetch(
"SELECT * FROM signal_trades WHERE status = $1 ORDER BY ts_open DESC LIMIT $2",
status, limit
)
return {"count": len(rows), "data": rows}

backend/signal_engine.py
View File

@ -1,5 +1,5 @@
"""
signal_engine.py — V5 short-term trading signal engine
signal_engine.py — V5 short-term trading signal engine (PostgreSQL edition)
Architecture:
- Standalone PM2 process, looping every 5 seconds
@ -17,14 +17,13 @@ signal_engine.py — V5 short-term trading signal engine
import logging
import os
import sqlite3
import time
import math
import statistics
from collections import deque
from datetime import datetime, timezone
from typing import Optional
from db import get_sync_conn, init_schema
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
@ -35,110 +34,33 @@ logging.basicConfig(
)
logger = logging.getLogger("signal-engine")
DB_PATH = os.path.join(os.path.dirname(__file__), "..", "arb.db")
SYMBOLS = ["BTCUSDT", "ETHUSDT"]
LOOP_INTERVAL = 5 # seconds
# Window sizes (milliseconds)
WINDOW_FAST = 30 * 60 * 1000 # 30 minutes
WINDOW_MID = 4 * 3600 * 1000 # 4 hours
WINDOW_DAY = 24 * 3600 * 1000 # 24 hours (for P95/P99 calculation)
WINDOW_DAY = 24 * 3600 * 1000 # 24 hours
WINDOW_VWAP = 30 * 60 * 1000 # 30 minutes
# ATR parameters
ATR_PERIOD_MS = 5 * 60 * 1000 # 5-minute candles
ATR_LENGTH = 14 # 14 bars
ATR_PERIOD_MS = 5 * 60 * 1000
ATR_LENGTH = 14
# Signal cooldown
COOLDOWN_MS = 10 * 60 * 1000 # 10 minutes
# ─── DB helpers ──────────────────────────────────────────────────
def get_conn() -> sqlite3.Connection:
conn = sqlite3.connect(DB_PATH, timeout=30)
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA synchronous=NORMAL")
return conn
def init_tables(conn: sqlite3.Connection):
conn.execute("""
CREATE TABLE IF NOT EXISTS signal_indicators (
id INTEGER PRIMARY KEY AUTOINCREMENT,
ts INTEGER NOT NULL,
symbol TEXT NOT NULL,
cvd_fast REAL,
cvd_mid REAL,
cvd_day REAL,
cvd_fast_slope REAL,
atr_5m REAL,
atr_percentile REAL,
vwap_30m REAL,
price REAL,
p95_qty REAL,
p99_qty REAL,
buy_vol_1m REAL,
sell_vol_1m REAL,
score INTEGER,
signal TEXT
)
""")
conn.execute("CREATE INDEX IF NOT EXISTS idx_si_ts ON signal_indicators(ts)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_si_sym_ts ON signal_indicators(symbol, ts)")
conn.execute("""
CREATE TABLE IF NOT EXISTS signal_indicators_1m (
id INTEGER PRIMARY KEY AUTOINCREMENT,
ts INTEGER NOT NULL,
symbol TEXT NOT NULL,
cvd_fast REAL,
cvd_mid REAL,
cvd_day REAL,
atr_5m REAL,
vwap_30m REAL,
price REAL,
score INTEGER,
signal TEXT
)
""")
conn.execute("CREATE INDEX IF NOT EXISTS idx_si1m_sym_ts ON signal_indicators_1m(symbol, ts)")
conn.execute("""
CREATE TABLE IF NOT EXISTS signal_trades (
id INTEGER PRIMARY KEY AUTOINCREMENT,
ts_open INTEGER NOT NULL,
ts_close INTEGER,
symbol TEXT NOT NULL,
direction TEXT NOT NULL,
entry_price REAL,
exit_price REAL,
qty REAL,
score INTEGER,
pnl REAL,
sl_price REAL,
tp1_price REAL,
tp2_price REAL,
status TEXT DEFAULT 'open'
)
""")
conn.commit()
COOLDOWN_MS = 10 * 60 * 1000
# ─── Rolling windows ────────────────────────────────────────────
class TradeWindow:
"""滚动时间窗口,维护买卖量和价格数据"""
def __init__(self, window_ms: int):
self.window_ms = window_ms
self.trades: deque = deque() # (time_ms, qty, price, is_buyer_maker)
self.trades: deque = deque()
self.buy_vol = 0.0
self.sell_vol = 0.0
self.pq_sum = 0.0 # price * qty accumulator (for VWAP)
self.q_sum = 0.0 # qty accumulator
self.quantities: list = [] # for P95/P99 (24h window only)
self.pq_sum = 0.0
self.q_sum = 0.0
def add(self, time_ms: int, qty: float, price: float, is_buyer_maker: int):
self.trades.append((time_ms, qty, price, is_buyer_maker))
@ -154,8 +76,7 @@ class TradeWindow:
cutoff = now_ms - self.window_ms
while self.trades and self.trades[0][0] < cutoff:
t_ms, qty, price, ibm = self.trades.popleft()
pq = price * qty
self.pq_sum -= pq
self.pq_sum -= price * qty
self.q_sum -= qty
if ibm == 0:
self.buy_vol -= qty
@ -170,29 +91,21 @@ class TradeWindow:
def vwap(self) -> float:
return self.pq_sum / self.q_sum if self.q_sum > 0 else 0.0
@property
def total_vol(self) -> float:
return self.buy_vol + self.sell_vol
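A minimal usage sketch for TradeWindow (made-up values; cvd, used by evaluate_signal below, is taken to be buy volume minus sell volume):

win = TradeWindow(window_ms=30 * 60 * 1000)                                    # 30-minute window
win.add(time_ms=1_700_000_000_000, qty=0.5, price=43_000.0, is_buyer_maker=0)  # taker buy
win.add(time_ms=1_700_000_060_000, qty=0.2, price=43_010.0, is_buyer_maker=1)  # taker sell
win.trim(now_ms=1_700_000_120_000)  # evict anything older than the window
print(win.cvd, win.vwap)            # 0.3 and ~43002.9 (volume-weighted price)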
class ATRCalculator:
"""5分钟K线ATR计算"""
def __init__(self, period_ms: int = ATR_PERIOD_MS, length: int = ATR_LENGTH):
self.period_ms = period_ms
self.length = length
self.candles: deque = deque(maxlen=length + 1)
self.current_candle: Optional[dict] = None
self.atr_history: deque = deque(maxlen=288) # 24h of 5m candles for percentile
self.atr_history: deque = deque(maxlen=288)
def update(self, time_ms: int, price: float):
bar_ms = (time_ms // self.period_ms) * self.period_ms
if self.current_candle is None or self.current_candle["bar"] != bar_ms:
if self.current_candle is not None:
self.candles.append(self.current_candle)
self.current_candle = {
"bar": bar_ms, "open": price, "high": price, "low": price, "close": price
}
self.current_candle = {"bar": bar_ms, "open": price, "high": price, "low": price, "close": price}
else:
c = self.current_candle
c["high"] = max(c["high"], price)
@ -212,7 +125,6 @@ class ATRCalculator:
trs.append(tr)
if not trs:
return 0.0
# EMA-style ATR
atr_val = trs[0]
for tr in trs[1:]:
atr_val = (atr_val * (self.length - 1) + tr) / self.length
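The loop above is Wilder's smoothing, seeded with the first true range: ATR_t = ((n - 1) * ATR_{t-1} + TR_t) / n, with n = ATR_LENGTH = 14.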
@ -220,7 +132,6 @@ class ATRCalculator:
@property
def atr_percentile(self) -> float:
"""当前ATR在最近24h中的分位数"""
current = self.atr
if current == 0:
return 50.0
@ -233,8 +144,6 @@ class ATRCalculator:
class SymbolState:
"""单个交易对的完整状态"""
def __init__(self, symbol: str):
self.symbol = symbol
self.win_fast = TradeWindow(WINDOW_FAST)
@ -244,12 +153,10 @@ class SymbolState:
self.atr_calc = ATRCalculator()
self.last_processed_id = 0
self.warmup = True
self.warmup_until = 0
self.prev_cvd_fast = 0.0
self.last_signal_ts = 0
self.last_signal_dir = ""
# P99 large-trade tracking (last 15 minutes)
self.recent_large_trades: deque = deque() # (time_ms, qty, is_buyer_maker)
self.recent_large_trades: deque = deque()
def process_trade(self, agg_id: int, time_ms: int, price: float, qty: float, is_buyer_maker: int):
now_ms = time_ms
@ -258,47 +165,34 @@ class SymbolState:
self.win_day.add(time_ms, qty, price, is_buyer_maker)
self.win_vwap.add(time_ms, qty, price, is_buyer_maker)
self.atr_calc.update(time_ms, price)
self.win_fast.trim(now_ms)
self.win_mid.trim(now_ms)
self.win_day.trim(now_ms)
self.win_vwap.trim(now_ms)
self.last_processed_id = agg_id
def compute_p95_p99(self) -> tuple[float, float]:
"""从24h窗口计算大单阈值"""
def compute_p95_p99(self) -> tuple:
if len(self.win_day.trades) < 100:
return 5.0, 10.0 # default floor
qtys = [t[1] for t in self.win_day.trades]
qtys.sort()
return 5.0, 10.0
qtys = sorted([t[1] for t in self.win_day.trades])
n = len(qtys)
p95 = qtys[int(n * 0.95)]
p99 = qtys[int(n * 0.99)]
# Floors: BTC 5/10, ETH 50/100
if "BTC" in self.symbol:
p95 = max(p95, 5.0)
p99 = max(p99, 10.0)
p95 = max(p95, 5.0); p99 = max(p99, 10.0)
else:
p95 = max(p95, 50.0)
p99 = max(p99, 100.0)
p95 = max(p95, 50.0); p99 = max(p99, 100.0)
return p95, p99
def update_large_trades(self, now_ms: int, p99: float):
"""更新最近15分钟的P99大单记录"""
cutoff = now_ms - 15 * 60 * 1000
while self.recent_large_trades and self.recent_large_trades[0][0] < cutoff:
self.recent_large_trades.popleft()
# Scan recently processed trades for large orders
for t in self.win_fast.trades:
if t[1] >= p99 and t[0] > cutoff:
self.recent_large_trades.append((t[0], t[1], t[3]))
def evaluate_signal(self, now_ms: int) -> dict:
"""评估信号核心3条件 + 加分3条件"""
cvd_fast = self.win_fast.cvd
cvd_mid = self.win_mid.cvd
vwap = self.win_vwap.vwap
@ -306,198 +200,127 @@ class SymbolState:
atr_pct = self.atr_calc.atr_percentile
p95, p99 = self.compute_p95_p99()
self.update_large_trades(now_ms, p99)
# Approximate the current price with VWAP
price = vwap if vwap > 0 else 0
# CVD_fast slope
cvd_fast_slope = cvd_fast - self.prev_cvd_fast
self.prev_cvd_fast = cvd_fast
result = {
"cvd_fast": cvd_fast,
"cvd_mid": cvd_mid,
"cvd_day": self.win_day.cvd,
"cvd_fast": cvd_fast, "cvd_mid": cvd_mid, "cvd_day": self.win_day.cvd,
"cvd_fast_slope": cvd_fast_slope,
"atr": atr,
"atr_pct": atr_pct,
"vwap": vwap,
"price": price,
"p95": p95,
"p99": p99,
"signal": None,
"direction": None,
"score": 0,
"atr": atr, "atr_pct": atr_pct, "vwap": vwap, "price": price,
"p95": p95, "p99": p99, "signal": None, "direction": None, "score": 0,
}
if self.warmup or price == 0 or atr == 0:
return result
# Cooldown check
if now_ms - self.last_signal_ts < COOLDOWN_MS:
return result
# === Core conditions ===
# Long
long_core = (
cvd_fast > 0 and cvd_fast_slope > 0 and # CVD_fast positive and rising
cvd_mid > 0 and # CVD_mid positive
price > vwap # price above VWAP
)
# Short
short_core = (
cvd_fast < 0 and cvd_fast_slope < 0 and
cvd_mid < 0 and
price < vwap
)
long_core = cvd_fast > 0 and cvd_fast_slope > 0 and cvd_mid > 0 and price > vwap
short_core = cvd_fast < 0 and cvd_fast_slope < 0 and cvd_mid < 0 and price < vwap
if not long_core and not short_core:
return result
direction = "LONG" if long_core else "SHORT"
# === Bonus conditions ===
score = 0
# 1. ATR compression→expansion (+25)
if atr_pct > 60:
score += 25
# 2. No adverse P99 large trades (+20)
has_adverse_large = False
for lt in self.recent_large_trades:
if direction == "LONG" and lt[2] == 1: # a large sell hurts longs
has_adverse_large = True
elif direction == "SHORT" and lt[2] == 0: # a large buy hurts shorts
has_adverse_large = True
if not has_adverse_large:
has_adverse = any(
(direction == "LONG" and lt[2] == 1) or (direction == "SHORT" and lt[2] == 0)
for lt in self.recent_large_trades
)
if not has_adverse:
score += 20
# 3. Funding-rate alignment (+15) — read from rate_snapshots
# Skipped for now, to be wired in later
# TODO: hook up the funding-rate condition
score += 0 # placeholder
result["signal"] = direction
result["direction"] = direction
result["score"] = score
# Update cooldown
self.last_signal_ts = now_ms
self.last_signal_dir = direction
return result
# ─── Main loop ──────────────────────────────────────────────────
# ─── PG DB operations ───────────────────────────────────────────
def get_month_tables(conn: sqlite3.Connection) -> list[str]:
rows = conn.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'agg_trades_2%' ORDER BY name"
).fetchall()
return [r["name"] for r in rows]
def load_historical(conn: sqlite3.Connection, state: SymbolState, window_ms: int):
"""冷启动:回灌历史数据到内存窗口"""
def load_historical(state: SymbolState, window_ms: int):
now_ms = int(time.time() * 1000)
start_ms = now_ms - window_ms
tables = get_month_tables(conn)
count = 0
for tname in tables:
try:
rows = conn.execute(
f"SELECT agg_id, price, qty, time_ms, is_buyer_maker FROM {tname} "
f"WHERE symbol = ? AND time_ms >= ? ORDER BY agg_id ASC",
with get_sync_conn() as conn:
with conn.cursor() as cur:
cur.execute(
"SELECT agg_id, price, qty, time_ms, is_buyer_maker FROM agg_trades "
"WHERE symbol = %s AND time_ms >= %s ORDER BY agg_id ASC",
(state.symbol, start_ms)
).fetchall()
)
while True:
rows = cur.fetchmany(5000)
if not rows:
break
for r in rows:
state.process_trade(r["agg_id"], r["time_ms"], r["price"], r["qty"], r["is_buyer_maker"])
state.process_trade(r[0], r[3], r[1], r[2], r[4])
count += 1
except Exception as e:
logger.warning(f"Error loading {tname}: {e}")
logger.info(f"[{state.symbol}] 冷启动完成: 加载{count:,}条历史数据 (窗口={window_ms//3600000}h)")
state.warmup = False
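One caveat on the cold-start query above: a default psycopg2 cursor transfers the entire result set to the client when execute() returns, so the fetchmany(5000) loop bounds Python-object churn but not transfer memory. A sketch of a streaming variant using a server-side (named) cursor, which fetches itersize rows per round trip (assumption, not part of this commit):

def load_historical_streaming(state: SymbolState, window_ms: int):
    now_ms = int(time.time() * 1000)
    start_ms = now_ms - window_ms
    count = 0
    with get_sync_conn() as conn:
        with conn.cursor(name="load_hist") as cur:  # a name makes it a server-side cursor
            cur.itersize = 5000                     # rows per network round trip
            cur.execute(
                "SELECT agg_id, price, qty, time_ms, is_buyer_maker FROM agg_trades "
                "WHERE symbol = %s AND time_ms >= %s ORDER BY agg_id ASC",
                (state.symbol, start_ms),
            )
            for r in cur:
                state.process_trade(r[0], r[3], r[1], r[2], r[4])
                count += 1
    state.warmup = False
    return count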
def fetch_new_trades(conn: sqlite3.Connection, symbol: str, last_id: int) -> list:
"""增量读取新aggTrades"""
tables = get_month_tables(conn)
results = []
for tname in tables[-2:]: # only the two most recent monthly tables
try:
rows = conn.execute(
f"SELECT agg_id, price, qty, time_ms, is_buyer_maker FROM {tname} "
f"WHERE symbol = ? AND agg_id > ? ORDER BY agg_id ASC LIMIT 10000",
def fetch_new_trades(symbol: str, last_id: int) -> list:
with get_sync_conn() as conn:
with conn.cursor() as cur:
cur.execute(
"SELECT agg_id, price, qty, time_ms, is_buyer_maker FROM agg_trades "
"WHERE symbol = %s AND agg_id > %s ORDER BY agg_id ASC LIMIT 10000",
(symbol, last_id)
).fetchall()
results.extend(rows)
except Exception:
pass
return results
)
return [{"agg_id": r[0], "price": r[1], "qty": r[2], "time_ms": r[3], "is_buyer_maker": r[4]}
for r in cur.fetchall()]
def save_indicator(conn: sqlite3.Connection, ts: int, symbol: str, result: dict):
conn.execute("""
INSERT INTO signal_indicators
(ts, symbol, cvd_fast, cvd_mid, cvd_day, cvd_fast_slope, atr_5m, atr_percentile,
vwap_30m, price, p95_qty, p99_qty, score, signal)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", (
ts, symbol,
result["cvd_fast"], result["cvd_mid"], result["cvd_day"], result["cvd_fast_slope"],
result["atr"], result["atr_pct"],
result["vwap"], result["price"],
result["p95"], result["p99"],
result["score"], result.get("signal")
))
def save_indicator(ts: int, symbol: str, result: dict):
with get_sync_conn() as conn:
with conn.cursor() as cur:
cur.execute(
"INSERT INTO signal_indicators "
"(ts,symbol,cvd_fast,cvd_mid,cvd_day,cvd_fast_slope,atr_5m,atr_percentile,vwap_30m,price,p95_qty,p99_qty,score,signal) "
"VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)",
(ts, symbol, result["cvd_fast"], result["cvd_mid"], result["cvd_day"], result["cvd_fast_slope"],
result["atr"], result["atr_pct"], result["vwap"], result["price"],
result["p95"], result["p99"], result["score"], result.get("signal"))
)
conn.commit()
def save_indicator_1m(conn: sqlite3.Connection, ts: int, symbol: str, result: dict):
"""每分钟聚合保存"""
def save_indicator_1m(ts: int, symbol: str, result: dict):
bar_ts = (ts // 60000) * 60000
existing = conn.execute(
"SELECT id FROM signal_indicators_1m WHERE ts = ? AND symbol = ?", (bar_ts, symbol)
).fetchone()
if existing:
conn.execute("""
UPDATE signal_indicators_1m SET
cvd_fast=?, cvd_mid=?, cvd_day=?, atr_5m=?, vwap_30m=?, price=?, score=?, signal=?
WHERE ts=? AND symbol=?
""", (
result["cvd_fast"], result["cvd_mid"], result["cvd_day"],
result["atr"], result["vwap"], result["price"],
result["score"], result.get("signal"),
bar_ts, symbol
))
with get_sync_conn() as conn:
with conn.cursor() as cur:
cur.execute("SELECT id FROM signal_indicators_1m WHERE ts=%s AND symbol=%s", (bar_ts, symbol))
if cur.fetchone():
cur.execute(
"UPDATE signal_indicators_1m SET cvd_fast=%s,cvd_mid=%s,cvd_day=%s,atr_5m=%s,vwap_30m=%s,price=%s,score=%s,signal=%s WHERE ts=%s AND symbol=%s",
(result["cvd_fast"], result["cvd_mid"], result["cvd_day"], result["atr"], result["vwap"],
result["price"], result["score"], result.get("signal"), bar_ts, symbol)
)
else:
conn.execute("""
INSERT INTO signal_indicators_1m
(ts, symbol, cvd_fast, cvd_mid, cvd_day, atr_5m, vwap_30m, price, score, signal)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", (
bar_ts, symbol,
result["cvd_fast"], result["cvd_mid"], result["cvd_day"],
result["atr"], result["vwap"], result["price"],
result["score"], result.get("signal")
))
cur.execute(
"INSERT INTO signal_indicators_1m (ts,symbol,cvd_fast,cvd_mid,cvd_day,atr_5m,vwap_30m,price,score,signal) VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)",
(bar_ts, symbol, result["cvd_fast"], result["cvd_mid"], result["cvd_day"], result["atr"],
result["vwap"], result["price"], result["score"], result.get("signal"))
)
conn.commit()
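The SELECT-then-UPDATE/INSERT above is correct with a single writer, but it cannot use ON CONFLICT because the schema only defines a non-unique index on (symbol, ts). If a unique index were added, the whole function collapses to one statement; a sketch under that assumption (the index is hypothetical, not part of this commit):

# Assumes: CREATE UNIQUE INDEX uq_si1m_sym_ts ON signal_indicators_1m(symbol, ts);
def save_indicator_1m_upsert(ts: int, symbol: str, result: dict):
    bar_ts = (ts // 60000) * 60000
    with get_sync_conn() as conn:
        with conn.cursor() as cur:
            cur.execute(
                "INSERT INTO signal_indicators_1m "
                "(ts,symbol,cvd_fast,cvd_mid,cvd_day,atr_5m,vwap_30m,price,score,signal) "
                "VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s) "
                "ON CONFLICT (symbol, ts) DO UPDATE SET "
                "cvd_fast=EXCLUDED.cvd_fast, cvd_mid=EXCLUDED.cvd_mid, cvd_day=EXCLUDED.cvd_day, "
                "atr_5m=EXCLUDED.atr_5m, vwap_30m=EXCLUDED.vwap_30m, price=EXCLUDED.price, "
                "score=EXCLUDED.score, signal=EXCLUDED.signal",
                (bar_ts, symbol, result["cvd_fast"], result["cvd_mid"], result["cvd_day"],
                 result["atr"], result["vwap"], result["price"], result["score"], result.get("signal")),
            )
        conn.commit()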
# ─── Main loop ──────────────────────────────────────────────────
def main():
conn = get_conn()
init_tables(conn)
init_schema()
states = {sym: SymbolState(sym) for sym in SYMBOLS}
# Cold start: replay 4h of data
for sym, state in states.items():
load_historical(conn, state, WINDOW_MID)
load_historical(state, WINDOW_MID)
logger.info("=== Signal Engine 启动完成 ===")
logger.info("=== Signal Engine (PG) 启动完成 ===")
last_1m_save = {}
cycle = 0
@ -505,44 +328,27 @@ def main():
while True:
try:
now_ms = int(time.time() * 1000)
for sym, state in states.items():
# Incrementally read new data
new_trades = fetch_new_trades(conn, sym, state.last_processed_id)
new_trades = fetch_new_trades(sym, state.last_processed_id)
for t in new_trades:
state.process_trade(t["agg_id"], t["time_ms"], t["price"], t["qty"], t["is_buyer_maker"])
# Evaluate the signal
result = state.evaluate_signal(now_ms)
save_indicator(now_ms, sym, result)
# Save 5-second indicators
save_indicator(conn, now_ms, sym, result)
# Per-minute aggregated save
bar_1m = (now_ms // 60000) * 60000
if last_1m_save.get(sym) != bar_1m:
save_indicator_1m(conn, now_ms, sym, result)
save_indicator_1m(now_ms, sym, result)
last_1m_save[sym] = bar_1m
# Log when a signal fires
if result.get("signal"):
logger.info(
f"[{sym}] 🚨 signal: {result['signal']} "
f"score={result['score']} price={result['price']:.1f} "
f"CVD_fast={result['cvd_fast']:.1f} CVD_mid={result['cvd_mid']:.1f}"
)
# TODO: Discord push
# TODO: write to signal_trades
logger.info(f"[{sym}] 🚨 signal: {result['signal']} score={result['score']} price={result['price']:.1f}")
cycle += 1
if cycle % 60 == 0: # status log every 5 minutes
if cycle % 60 == 0:
for sym, state in states.items():
r = state.evaluate_signal(now_ms)
logger.info(
f"[{sym}] state: CVD_fast={r['cvd_fast']:.1f} CVD_mid={r['cvd_mid']:.1f} "
f"ATR={r['atr']:.2f}({r['atr_pct']:.0f}%) VWAP={r['vwap']:.1f} "
f"P95={r['p95']:.4f} P99={r['p99']:.4f}"
)
logger.info(f"[{sym}] state: CVD_fast={r['cvd_fast']:.1f} CVD_mid={r['cvd_mid']:.1f} ATR={r['atr']:.2f}({r['atr_pct']:.0f}%) VWAP={r['vwap']:.1f}")
except Exception as e:
logger.error(f"循环异常: {e}", exc_info=True)