diff --git a/backend/agg_trades_collector.py b/backend/agg_trades_collector.py index cfe3dfe..5977281 100644 --- a/backend/agg_trades_collector.py +++ b/backend/agg_trades_collector.py @@ -1,27 +1,29 @@ """ -agg_trades_collector.py — aggTrades全量采集守护进程 +agg_trades_collector.py — aggTrades全量采集守护进程(PostgreSQL版) 架构: - WebSocket主链路:实时推送,延迟<100ms - REST补洞:断线重连后从last_agg_id追平 - 每分钟巡检:校验agg_id连续性,发现断档自动补洞 - - 批量写入:攒200条或1秒flush一次,减少WAL压力 - - 按月分表:agg_trades_YYYYMM,单表千万行内查询快 - - 健康接口:GET /collector/health 可监控 + - 批量写入:攒200条或1秒flush一次 + - PG分区表:按月自动分区,MVCC并发无锁冲突 """ import asyncio import json import logging import os -import sqlite3 import time from datetime import datetime, timezone from typing import Optional import httpx +import psycopg2 +import psycopg2.extras import websockets +from db import get_sync_conn, get_sync_pool, ensure_partitions, PG_HOST, PG_PORT, PG_DB, PG_USER, PG_PASS + logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", @@ -32,137 +34,85 @@ logging.basicConfig( ) logger = logging.getLogger("collector") -DB_PATH = os.path.join(os.path.dirname(__file__), "..", "arb.db") BINANCE_FAPI = "https://fapi.binance.com/fapi/v1" SYMBOLS = ["BTCUSDT", "ETHUSDT"] HEADERS = {"User-Agent": "Mozilla/5.0 ArbitrageEngine/3.0"} -# 批量写入缓冲 -_buffer: dict[str, list] = {s: [] for s in SYMBOLS} BATCH_SIZE = 200 -BATCH_TIMEOUT = 1.0 # seconds +BATCH_TIMEOUT = 1.0 # ─── DB helpers ────────────────────────────────────────────────── -def get_conn() -> sqlite3.Connection: - conn = sqlite3.connect(DB_PATH, timeout=30) - conn.row_factory = sqlite3.Row - conn.execute("PRAGMA journal_mode=WAL") - conn.execute("PRAGMA synchronous=NORMAL") - return conn - - -def table_name(ts_ms: int) -> str: - """按月分表:agg_trades_202602""" - dt = datetime.fromtimestamp(ts_ms / 1000, tz=timezone.utc) - return f"agg_trades_{dt.strftime('%Y%m')}" - - -def ensure_table(conn: sqlite3.Connection, tname: str): - conn.execute(f""" - CREATE TABLE IF NOT EXISTS {tname} 
( - agg_id INTEGER PRIMARY KEY, - symbol TEXT NOT NULL, - price REAL NOT NULL, - qty REAL NOT NULL, - first_trade_id INTEGER, - last_trade_id INTEGER, - time_ms INTEGER NOT NULL, - is_buyer_maker INTEGER NOT NULL - ) - """) - conn.execute(f""" - CREATE INDEX IF NOT EXISTS idx_{tname}_sym_time - ON {tname}(symbol, time_ms) - """) - - -def ensure_meta_table(conn: sqlite3.Connection): - conn.execute(""" - CREATE TABLE IF NOT EXISTS agg_trades_meta ( - symbol TEXT PRIMARY KEY, - last_agg_id INTEGER NOT NULL, - last_time_ms INTEGER NOT NULL, - updated_at TEXT DEFAULT (datetime('now')) - ) - """) - - def get_last_agg_id(symbol: str) -> Optional[int]: try: - conn = get_conn() - ensure_meta_table(conn) - row = conn.execute( - "SELECT last_agg_id FROM agg_trades_meta WHERE symbol = ?", (symbol,) - ).fetchone() - conn.close() - return row["last_agg_id"] if row else None + with get_sync_conn() as conn: + with conn.cursor() as cur: + cur.execute("SELECT last_agg_id FROM agg_trades_meta WHERE symbol = %s", (symbol,)) + row = cur.fetchone() + return row[0] if row else None except Exception as e: logger.error(f"get_last_agg_id error: {e}") return None -def update_meta(conn: sqlite3.Connection, symbol: str, last_agg_id: int, last_time_ms: int): - conn.execute(""" - INSERT INTO agg_trades_meta (symbol, last_agg_id, last_time_ms, updated_at) - VALUES (?, ?, ?, datetime('now')) - ON CONFLICT(symbol) DO UPDATE SET - last_agg_id = excluded.last_agg_id, - last_time_ms = excluded.last_time_ms, - updated_at = excluded.updated_at - """, (symbol, last_agg_id, last_time_ms)) +def update_meta(conn, symbol: str, last_agg_id: int, last_time_ms: int): + with conn.cursor() as cur: + cur.execute(""" + INSERT INTO agg_trades_meta (symbol, last_agg_id, last_time_ms, updated_at) + VALUES (%s, %s, %s, NOW()) + ON CONFLICT(symbol) DO UPDATE SET + last_agg_id = EXCLUDED.last_agg_id, + last_time_ms = EXCLUDED.last_time_ms, + updated_at = NOW() + """, (symbol, last_agg_id, last_time_ms)) def 
flush_buffer(symbol: str, trades: list) -> int: - """写入一批trades,返回实际写入条数(去重后)""" + """写入一批trades到PG,返回实际写入条数""" if not trades: return 0 try: - conn = get_conn() - ensure_meta_table(conn) - # 按月分组 - by_month: dict[str, list] = {} - for t in trades: - tname = table_name(t["T"]) - if tname not in by_month: - by_month[tname] = [] - by_month[tname].append(t) + # 确保分区存在 + ensure_partitions() - inserted = 0 - last_agg_id = 0 - last_time_ms = 0 + with get_sync_conn() as conn: + with conn.cursor() as cur: + # 批量插入(ON CONFLICT忽略重复) + values = [] + last_agg_id = 0 + last_time_ms = 0 - for tname, batch in by_month.items(): - ensure_table(conn, tname) - for t in batch: - cur = conn.execute( - f"""INSERT OR IGNORE INTO {tname} - (agg_id, symbol, price, qty, first_trade_id, last_trade_id, time_ms, is_buyer_maker) - VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", - ( - t["a"], # agg_id - symbol, - float(t["p"]), - float(t["q"]), - t.get("f"), # first_trade_id - t.get("l"), # last_trade_id - t["T"], # time_ms - 1 if t["m"] else 0, # is_buyer_maker - ) + for t in trades: + agg_id = t["a"] + time_ms = t["T"] + values.append(( + agg_id, symbol, + float(t["p"]), float(t["q"]), + time_ms, + 1 if t["m"] else 0, + )) + if agg_id > last_agg_id: + last_agg_id = agg_id + last_time_ms = time_ms + + # 批量INSERT + psycopg2.extras.execute_values( + cur, + """INSERT INTO agg_trades (agg_id, symbol, price, qty, time_ms, is_buyer_maker) + VALUES %s + ON CONFLICT (time_ms, symbol, agg_id) DO NOTHING""", + values, + template="(%s, %s, %s, %s, %s, %s)", + page_size=1000, ) - if cur.rowcount > 0: - inserted += 1 - if t["a"] > last_agg_id: - last_agg_id = t["a"] - last_time_ms = t["T"] + inserted = cur.rowcount - if last_agg_id > 0: - update_meta(conn, symbol, last_agg_id, last_time_ms) + if last_agg_id > 0: + update_meta(conn, symbol, last_agg_id, last_time_ms) - conn.commit() - conn.close() - return inserted + conn.commit() + return inserted except Exception as e: logger.error(f"flush_buffer [{symbol}] error: 
{e}") return 0 @@ -171,7 +121,6 @@ def flush_buffer(symbol: str, trades: list) -> int: # ─── REST补洞 ──────────────────────────────────────────────────── async def rest_catchup(symbol: str, from_id: int) -> int: - """从from_id开始REST拉取,追平到最新,返回补洞条数""" total = 0 current_id = from_id logger.info(f"[{symbol}] REST catchup from agg_id={from_id}") @@ -195,10 +144,9 @@ async def rest_catchup(symbol: str, from_id: int) -> int: if last <= current_id: break current_id = last + 1 - # 如果拉到的比最新少1000条,说明追平了 if len(data) < 1000: break - await asyncio.sleep(0.1) # rate limit友好 + await asyncio.sleep(0.1) except Exception as e: logger.error(f"[{symbol}] REST catchup error: {e}") break @@ -210,7 +158,6 @@ async def rest_catchup(symbol: str, from_id: int) -> int: # ─── WebSocket采集 ─────────────────────────────────────────────── async def ws_collect(symbol: str): - """单Symbol的WS采集循环,自动断线重连+REST补洞""" stream = symbol.lower() + "@aggTrade" url = f"wss://fstream.binance.com/ws/{stream}" buffer: list = [] @@ -219,14 +166,13 @@ async def ws_collect(symbol: str): while True: try: - # 断线重连前先REST补洞 last_id = get_last_agg_id(symbol) if last_id is not None: await rest_catchup(symbol, last_id + 1) logger.info(f"[{symbol}] Connecting WS: {url}") async with websockets.connect(url, ping_interval=20, ping_timeout=10) as ws: - reconnect_delay = 1.0 # 连上了就重置 + reconnect_delay = 1.0 logger.info(f"[{symbol}] WS connected") async for raw in ws: @@ -236,7 +182,6 @@ async def ws_collect(symbol: str): buffer.append(msg) - # 批量flush:满200条或超1秒 now = time.time() if len(buffer) >= BATCH_SIZE or (now - last_flush) >= BATCH_TIMEOUT: count = flush_buffer(symbol, buffer) @@ -250,110 +195,90 @@ async def ws_collect(symbol: str): except Exception as e: logger.error(f"[{symbol}] WS error: {e}, reconnecting in {reconnect_delay}s") finally: - # flush剩余buffer if buffer: flush_buffer(symbol, buffer) buffer.clear() await asyncio.sleep(reconnect_delay) - reconnect_delay = min(reconnect_delay * 2, 30) # exponential backoff, max 
30s + reconnect_delay = min(reconnect_delay * 2, 30) # ─── 连续性巡检 ────────────────────────────────────────────────── async def continuity_check(): - """每60秒巡检一次:检查各symbol最近的agg_id是否有断档""" while True: await asyncio.sleep(60) try: - conn = get_conn() - ensure_meta_table(conn) - for symbol in SYMBOLS: - row = conn.execute( - "SELECT last_agg_id, last_time_ms FROM agg_trades_meta WHERE symbol = ?", - (symbol,) - ).fetchone() - if not row: - continue - # 检查最近1000条是否连续 - now_month = datetime.now(tz=timezone.utc).strftime("%Y%m") - tname = f"agg_trades_{now_month}" - try: - rows = conn.execute( - f"SELECT agg_id FROM {tname} WHERE symbol = ? ORDER BY agg_id DESC LIMIT 100", - (symbol,) - ).fetchall() - if len(rows) < 2: - continue - ids = [r["agg_id"] for r in rows] - # 检查是否连续(降序) - gaps = [] - for i in range(len(ids) - 1): - diff = ids[i] - ids[i + 1] - if diff > 1: - gaps.append((ids[i + 1], ids[i], diff - 1)) - if gaps: - logger.warning(f"[{symbol}] Found {len(gaps)} gaps in recent data: {gaps[:3]}") - # 触发补洞 - min_gap_id = min(g[0] for g in gaps) - asyncio.create_task(rest_catchup(symbol, min_gap_id)) - else: - logger.debug(f"[{symbol}] Continuity OK, last_agg_id={row['last_agg_id']}") - except Exception: - pass # 表可能还不存在 - conn.close() + with get_sync_conn() as conn: + with conn.cursor() as cur: + for symbol in SYMBOLS: + cur.execute( + "SELECT last_agg_id, last_time_ms FROM agg_trades_meta WHERE symbol = %s", + (symbol,) + ) + row = cur.fetchone() + if not row: + continue + # 检查最近100条是否连续 + cur.execute( + "SELECT agg_id FROM agg_trades WHERE symbol = %s ORDER BY agg_id DESC LIMIT 100", + (symbol,) + ) + rows = cur.fetchall() + if len(rows) < 2: + continue + ids = [r[0] for r in rows] + gaps = [] + for i in range(len(ids) - 1): + diff = ids[i] - ids[i + 1] + if diff > 1: + gaps.append((ids[i + 1], ids[i], diff - 1)) + if gaps: + logger.warning(f"[{symbol}] Found {len(gaps)} gaps: {gaps[:3]}") + min_gap_id = min(g[0] for g in gaps) + 
asyncio.create_task(rest_catchup(symbol, min_gap_id)) + else: + logger.debug(f"[{symbol}] Continuity OK, last_agg_id={row[0]}") except Exception as e: logger.error(f"Continuity check error: {e}") -# ─── 每日完整性报告 ────────────────────────────────────────────── +# ─── 每小时报告 ────────────────────────────────────────────────── async def daily_report(): - """每小时生成一次完整性摘要日志""" while True: await asyncio.sleep(3600) try: - conn = get_conn() - ensure_meta_table(conn) - report_lines = ["=== AggTrades Integrity Report ==="] - for symbol in SYMBOLS: - row = conn.execute( - "SELECT last_agg_id, last_time_ms FROM agg_trades_meta WHERE symbol = ?", - (symbol,) - ).fetchone() - if not row: - report_lines.append(f" {symbol}: No data yet") - continue - last_dt = datetime.fromtimestamp(row["last_time_ms"] / 1000, tz=timezone.utc).isoformat() - # 统计本月总量 - now_month = datetime.now(tz=timezone.utc).strftime("%Y%m") - tname = f"agg_trades_{now_month}" - try: - count = conn.execute( - f"SELECT COUNT(*) as c FROM {tname} WHERE symbol = ?", (symbol,) - ).fetchone()["c"] - except Exception: - count = 0 - report_lines.append( - f" {symbol}: last_agg_id={row['last_agg_id']}, last_time={last_dt}, month_count={count:,}" - ) - conn.close() - logger.info("\n".join(report_lines)) + with get_sync_conn() as conn: + with conn.cursor() as cur: + report = ["=== AggTrades Integrity Report ==="] + for symbol in SYMBOLS: + cur.execute( + "SELECT last_agg_id, last_time_ms FROM agg_trades_meta WHERE symbol = %s", + (symbol,) + ) + row = cur.fetchone() + if not row: + report.append(f" {symbol}: No data yet") + continue + last_dt = datetime.fromtimestamp(row[1] / 1000, tz=timezone.utc).isoformat() + cur.execute( + "SELECT COUNT(*) FROM agg_trades WHERE symbol = %s", (symbol,) + ) + count = cur.fetchone()[0] + report.append(f" {symbol}: last_agg_id={row[0]}, last_time={last_dt}, total={count:,}") + logger.info("\n".join(report)) except Exception as e: - logger.error(f"Daily report error: {e}") + 
logger.error(f"Report error: {e}") # ─── 入口 ──────────────────────────────────────────────────────── async def main(): - logger.info("AggTrades Collector starting...") - # 确保基础表存在 - conn = get_conn() - ensure_meta_table(conn) - conn.commit() - conn.close() + logger.info("AggTrades Collector (PG) starting...") + # 确保分区存在 + ensure_partitions() - # 并行启动所有symbol的WS + 巡检 + 报告 tasks = [ ws_collect(sym) for sym in SYMBOLS ] + [ diff --git a/backend/backfill_agg_trades.py b/backend/backfill_agg_trades.py index 1381e5c..ab55a9b 100644 --- a/backend/backfill_agg_trades.py +++ b/backend/backfill_agg_trades.py @@ -1,27 +1,18 @@ """ -backfill_agg_trades.py — 历史aggTrades回补脚本 - -功能: - - 从当前DB最早agg_id向历史方向回补 - - Binance REST API分页拉取,每次1000条 - - INSERT OR IGNORE写入按月分表 - - 断点续传:记录进度到agg_trades_meta - - 速率控制:sleep 200ms/请求,429自动退避 - - BTC+ETH并行回补 - -用法: - python3 backfill_agg_trades.py [--days 30] [--symbol BTCUSDT] +backfill_agg_trades.py — 历史aggTrades回补脚本(PostgreSQL版) """ import argparse import logging import os -import sqlite3 import sys import time from datetime import datetime, timezone import requests +import psycopg2.extras + +from db import get_sync_conn, init_schema, ensure_partitions logging.basicConfig( level=logging.INFO, @@ -33,119 +24,56 @@ logging.basicConfig( ) logger = logging.getLogger("backfill") -DB_PATH = os.path.join(os.path.dirname(__file__), "..", "arb.db") BINANCE_FAPI = "https://fapi.binance.com/fapi/v1" HEADERS = {"User-Agent": "Mozilla/5.0 ArbitrageEngine/backfill"} BATCH_SIZE = 1000 -SLEEP_MS = 2000 # 低优先级:2秒/请求,让出锁给实时服务 +SLEEP_MS = 2000 # ─── DB helpers ────────────────────────────────────────────────── -def get_conn() -> sqlite3.Connection: - conn = sqlite3.connect(DB_PATH, timeout=30) - conn.row_factory = sqlite3.Row - conn.execute("PRAGMA journal_mode=WAL") - conn.execute("PRAGMA synchronous=NORMAL") - return conn +def get_earliest_agg_id(symbol: str) -> int | None: + with get_sync_conn() as conn: + with conn.cursor() as cur: + cur.execute("SELECT 
earliest_agg_id FROM agg_trades_meta WHERE symbol = %s", (symbol,)) + row = cur.fetchone() + if row and row[0]: + return row[0] + # fallback: scan agg_trades + cur.execute("SELECT MIN(agg_id) FROM agg_trades WHERE symbol = %s", (symbol,)) + row = cur.fetchone() + return row[0] if row else None -def table_name(ts_ms: int) -> str: - dt = datetime.fromtimestamp(ts_ms / 1000, tz=timezone.utc) - return f"agg_trades_{dt.strftime('%Y%m')}" - - -def ensure_table(conn: sqlite3.Connection, tname: str): - conn.execute(f""" - CREATE TABLE IF NOT EXISTS {tname} ( - agg_id INTEGER PRIMARY KEY, - symbol TEXT NOT NULL, - price REAL NOT NULL, - qty REAL NOT NULL, - time_ms INTEGER NOT NULL, - is_buyer_maker INTEGER NOT NULL - ) - """) - conn.execute(f"CREATE INDEX IF NOT EXISTS idx_{tname}_symbol_time ON {tname}(symbol, time_ms)") - - -def ensure_meta(conn: sqlite3.Connection): - conn.execute(""" - CREATE TABLE IF NOT EXISTS agg_trades_meta ( - symbol TEXT PRIMARY KEY, - last_agg_id INTEGER, - last_time_ms INTEGER, - earliest_agg_id INTEGER, - earliest_time_ms INTEGER - ) - """) - # 添加earliest字段(如果旧表没有) - try: - conn.execute("ALTER TABLE agg_trades_meta ADD COLUMN earliest_agg_id INTEGER") - conn.execute("ALTER TABLE agg_trades_meta ADD COLUMN earliest_time_ms INTEGER") - except Exception: - pass # 已存在 - - -def get_earliest_agg_id(conn: sqlite3.Connection, symbol: str) -> int | None: - """查找DB中该symbol的最小agg_id""" - # 先查meta表 - row = conn.execute( - "SELECT earliest_agg_id FROM agg_trades_meta WHERE symbol = ?", (symbol,) - ).fetchone() - if row and row["earliest_agg_id"]: - return row["earliest_agg_id"] - - # meta表没有,扫所有月表 - tables = conn.execute( - "SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'agg_trades_2%'" - ).fetchall() - min_id = None - for t in tables: - r = conn.execute( - f"SELECT MIN(agg_id) as mid FROM {t['name']} WHERE symbol = ?", (symbol,) - ).fetchone() - if r and r["mid"] is not None: - if min_id is None or r["mid"] < min_id: - min_id = r["mid"] 
- return min_id - - -def update_earliest_meta(conn: sqlite3.Connection, symbol: str, agg_id: int, time_ms: int): - # 先检查是否已有该symbol的记录 - row = conn.execute("SELECT symbol FROM agg_trades_meta WHERE symbol = ?", (symbol,)).fetchone() - if row: - conn.execute(""" - UPDATE agg_trades_meta SET - earliest_agg_id = MIN(?, COALESCE(earliest_agg_id, ?)), - earliest_time_ms = MIN(?, COALESCE(earliest_time_ms, ?)) - WHERE symbol = ? - """, (agg_id, agg_id, time_ms, time_ms, symbol)) - else: - conn.execute(""" - INSERT INTO agg_trades_meta (symbol, last_agg_id, last_time_ms, earliest_agg_id, earliest_time_ms) - VALUES (?, ?, ?, ?, ?) - """, (symbol, agg_id, time_ms, agg_id, time_ms)) - conn.commit() +def update_earliest_meta(symbol: str, agg_id: int, time_ms: int): + with get_sync_conn() as conn: + with conn.cursor() as cur: + cur.execute("SELECT symbol FROM agg_trades_meta WHERE symbol = %s", (symbol,)) + if cur.fetchone(): + cur.execute(""" + UPDATE agg_trades_meta SET + earliest_agg_id = LEAST(%s, COALESCE(earliest_agg_id, %s)), + earliest_time_ms = LEAST(%s, COALESCE(earliest_time_ms, %s)) + WHERE symbol = %s + """, (agg_id, agg_id, time_ms, time_ms, symbol)) + else: + cur.execute(""" + INSERT INTO agg_trades_meta (symbol, last_agg_id, last_time_ms, earliest_agg_id, earliest_time_ms) + VALUES (%s, %s, %s, %s, %s) + """, (symbol, agg_id, time_ms, agg_id, time_ms)) + conn.commit() # ─── REST API ──────────────────────────────────────────────────── -def fetch_agg_trades(symbol: str, from_id: int | None = None, - start_time: int | None = None, limit: int = 1000) -> list[dict]: - """从Binance拉取aggTrades,支持fromId和startTime""" +def fetch_agg_trades(symbol: str, from_id: int | None = None, limit: int = 1000) -> list: params = {"symbol": symbol, "limit": limit} if from_id is not None: params["fromId"] = from_id - elif start_time is not None: - params["startTime"] = start_time for attempt in range(5): try: - r = requests.get( - f"{BINANCE_FAPI}/aggTrades", - params=params, 
headers=HEADERS, timeout=15 - ) + r = requests.get(f"{BINANCE_FAPI}/aggTrades", params=params, headers=HEADERS, timeout=15) if r.status_code == 429: wait = min(60 * (2 ** attempt), 300) logger.warning(f"Rate limited (429), waiting {wait}s (attempt {attempt+1})") @@ -162,67 +90,47 @@ def fetch_agg_trades(symbol: str, from_id: int | None = None, return [] -def write_batch(conn: sqlite3.Connection, symbol: str, trades: list[dict]) -> int: - """批量写入,返回实际插入条数""" - inserted = 0 - tables_seen = set() - for t in trades: - tname = table_name(t["T"]) - if tname not in tables_seen: - ensure_table(conn, tname) - tables_seen.add(tname) - try: - conn.execute( - f"INSERT OR IGNORE INTO {tname} (agg_id, symbol, price, qty, time_ms, is_buyer_maker) " - f"VALUES (?, ?, ?, ?, ?, ?)", - (t["a"], symbol, float(t["p"]), float(t["q"]), t["T"], 1 if t["m"] else 0) +def write_batch(symbol: str, trades: list) -> int: + if not trades: + return 0 + ensure_partitions() + values = [(t["a"], symbol, float(t["p"]), float(t["q"]), t["T"], 1 if t["m"] else 0) for t in trades] + with get_sync_conn() as conn: + with conn.cursor() as cur: + psycopg2.extras.execute_values( + cur, + "INSERT INTO agg_trades (agg_id, symbol, price, qty, time_ms, is_buyer_maker) " + "VALUES %s ON CONFLICT (time_ms, symbol, agg_id) DO NOTHING", + values, + template="(%s, %s, %s, %s, %s, %s)", + page_size=1000, ) - inserted += 1 - except sqlite3.IntegrityError: - pass - conn.commit() + inserted = cur.rowcount + conn.commit() return inserted # ─── 主逻辑 ────────────────────────────────────────────────────── def backfill_symbol(symbol: str, target_days: int | None = None): - """回补单个symbol的历史数据""" - conn = get_conn() - ensure_meta(conn) - - earliest = get_earliest_agg_id(conn, symbol) + earliest = get_earliest_agg_id(symbol) if earliest is None: - logger.info(f"[{symbol}] DB无数据,从最新开始拉最近的数据确定起点") + logger.info(f"[{symbol}] DB无数据,拉最新确定起点") trades = fetch_agg_trades(symbol, limit=1) if not trades: logger.error(f"[{symbol}] 
无法获取起始数据") return earliest = trades[0]["a"] - logger.info(f"[{symbol}] 当前最新agg_id: {earliest}") - # 计算目标:往前补到多少天前 - if target_days: - target_ts = int((time.time() - target_days * 86400) * 1000) - else: - target_ts = 0 # 拉到最早 + target_ts = int((time.time() - target_days * 86400) * 1000) if target_days else 0 logger.info(f"[{symbol}] 开始回补,当前最早agg_id={earliest}") - if target_days: - target_dt = datetime.fromtimestamp(target_ts / 1000, tz=timezone.utc) - logger.info(f"[{symbol}] 目标:回补到 {target_dt.strftime('%Y-%m-%d')} ({target_days}天前)") - total_inserted = 0 total_requests = 0 - rate_limit_hits = 0 current_from_id = earliest while True: - # 从current_from_id向前拉:先拉current_from_id - BATCH_SIZE*2的位置 - # Binance fromId是从该id开始正向拉,所以我们用endTime方式 - # 更可靠:用fromId = current_from_id - BATCH_SIZE,然后正向拉 fetch_from = max(0, current_from_id - BATCH_SIZE) - trades = fetch_agg_trades(symbol, from_id=fetch_from, limit=BATCH_SIZE) total_requests += 1 @@ -230,104 +138,56 @@ def backfill_symbol(symbol: str, target_days: int | None = None): logger.info(f"[{symbol}] 无更多数据,回补结束") break - # 过滤掉已有的(>= earliest的数据实时采集器已覆盖) new_trades = [t for t in trades if t["a"] < current_from_id] - if not new_trades: - # 没有更早的数据了 logger.info(f"[{symbol}] 已到达最早数据,回补结束") break - inserted = write_batch(conn, symbol, new_trades) + inserted = write_batch(symbol, new_trades) total_inserted += inserted oldest = min(new_trades, key=lambda x: x["a"]) oldest_time = datetime.fromtimestamp(oldest["T"] / 1000, tz=timezone.utc) current_from_id = oldest["a"] + update_earliest_meta(symbol, oldest["a"], oldest["T"]) - # 更新meta - update_earliest_meta(conn, symbol, oldest["a"], oldest["T"]) - - # 进度日志(每50批打一次) if total_requests % 50 == 0: logger.info( f"[{symbol}] 进度: {total_inserted:,} 条已插入, " - f"当前位置: {oldest_time.strftime('%Y-%m-%d %H:%M')}, " - f"agg_id={current_from_id:,}, " - f"请求数={total_requests}, 429次数={rate_limit_hits}" + f"当前位置: {oldest_time.strftime('%Y-%m-%d %H:%M')}, agg_id={current_from_id:,}, 
请求数={total_requests}" ) - # 检查是否到达目标时间 if target_ts and oldest["T"] <= target_ts: logger.info(f"[{symbol}] 已达到目标时间,回补结束") break time.sleep(SLEEP_MS / 1000) - conn.close() - - logger.info( - f"[{symbol}] 回补完成: " - f"总插入={total_inserted:,}, 总请求={total_requests}, " - f"429次数={rate_limit_hits}" - ) - return total_inserted, total_requests, rate_limit_hits + logger.info(f"[{symbol}] 回补完成: 总插入={total_inserted:,}, 总请求={total_requests}") def check_continuity(symbol: str): - """检查agg_id连续性""" - conn = get_conn() - tables = conn.execute( - "SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'agg_trades_2%' ORDER BY name" - ).fetchall() + with get_sync_conn() as conn: + with conn.cursor() as cur: + cur.execute("SELECT COUNT(*), MIN(agg_id), MAX(agg_id) FROM agg_trades WHERE symbol = %s", (symbol,)) + row = cur.fetchone() + total, min_id, max_id = row + if not total: + logger.info(f"[{symbol}] 无数据") + return + span = max_id - min_id + 1 + coverage = total / span * 100 if span > 0 else 0 + logger.info(f"[{symbol}] 总条数={total:,}, ID范围={min_id:,}~{max_id:,}, 覆盖率={coverage:.2f}%") - all_ids = [] - for t in tables: - rows = conn.execute( - f"SELECT agg_id FROM {t['name']} WHERE symbol = ? 
ORDER BY agg_id", - (symbol,) - ).fetchall() - all_ids.extend(r["agg_id"] for r in rows) - - conn.close() - - if not all_ids: - logger.info(f"[{symbol}] 无数据") - return - - all_ids.sort() - gaps = 0 - gap_ranges = [] - for i in range(1, len(all_ids)): - diff = all_ids[i] - all_ids[i-1] - if diff > 1: - gaps += 1 - if len(gap_ranges) < 10: - gap_ranges.append((all_ids[i-1], all_ids[i], diff - 1)) - - total = len(all_ids) - span = all_ids[-1] - all_ids[0] + 1 - coverage = total / span * 100 if span > 0 else 0 - - logger.info( - f"[{symbol}] 连续性检查: " - f"总条数={total:,}, ID范围={all_ids[0]:,}~{all_ids[-1]:,}, " - f"理论条数={span:,}, 覆盖率={coverage:.2f}%, 缺口数={gaps}" - ) - if gap_ranges: - for start, end, missing in gap_ranges[:5]: - logger.info(f" 缺口: {start} → {end} (缺{missing}条)") - - -# ─── CLI ───────────────────────────────────────────────────────── def main(): - parser = argparse.ArgumentParser(description="Backfill aggTrades from Binance") - parser.add_argument("--days", type=int, default=None, help="回补天数(默认全量)") - parser.add_argument("--symbol", type=str, default=None, help="指定symbol(默认BTC+ETH)") - parser.add_argument("--check", action="store_true", help="仅检查连续性") + parser = argparse.ArgumentParser() + parser.add_argument("--days", type=int, default=None) + parser.add_argument("--symbol", type=str, default=None) + parser.add_argument("--check", action="store_true") args = parser.parse_args() + init_schema() symbols = [args.symbol] if args.symbol else ["BTCUSDT", "ETHUSDT"] if args.check: @@ -336,18 +196,12 @@ def main(): return logger.info(f"=== 开始回补 === symbols={symbols}, days={args.days or '全量'}") - for sym in symbols: - logger.info(f"--- 回补 {sym} ---") - result = backfill_symbol(sym, target_days=args.days) - if result: - inserted, reqs, limits = result - logger.info(f"[{sym}] 结果: 插入{inserted:,}条, {reqs}次请求, {limits}次429") + backfill_symbol(sym, target_days=args.days) - logger.info("=== 回补完成,开始连续性检查 ===") + logger.info("=== 连续性检查 ===") for sym in symbols: 
"""
db.py — PostgreSQL database access layer.

Synchronous connection pool (psycopg2) for the collector/backfill daemons,
lazy asynchronous pool (asyncpg) for the FastAPI app, plus idempotent
schema bootstrap and monthly partition management for the agg_trades table.
"""

import datetime
import os
from contextlib import contextmanager

import asyncpg
import psycopg2
import psycopg2.pool

# ─── Connection parameters (override via environment) ───────────────────
PG_HOST = os.getenv("PG_HOST", "127.0.0.1")
PG_PORT = int(os.getenv("PG_PORT", 5432))
PG_DB = os.getenv("PG_DB", "arb_engine")
PG_USER = os.getenv("PG_USER", "arb")
# NOTE(security): hard-coded fallback credential — always set PG_PASS in the
# environment for anything beyond a throwaway local database.
PG_PASS = os.getenv("PG_PASS", "arb_engine_2026")

PG_DSN = f"postgresql://{PG_USER}:{PG_PASS}@{PG_HOST}:{PG_PORT}/{PG_DB}"

# ─── 同步连接池(psycopg2)───────────────────────────────────────────────

_sync_pool = None


def get_sync_pool() -> psycopg2.pool.ThreadedConnectionPool:
    """Lazily create and return the process-wide psycopg2 connection pool."""
    global _sync_pool
    if _sync_pool is None:
        _sync_pool = psycopg2.pool.ThreadedConnectionPool(
            minconn=1, maxconn=5,
            host=PG_HOST, port=PG_PORT,
            dbname=PG_DB, user=PG_USER, password=PG_PASS,
        )
    return _sync_pool


@contextmanager
def get_sync_conn():
    """Borrow a pooled connection; always return it in a clean state.

    Any transaction still open when the block exits — including an aborted
    one after an error — is rolled back before the connection goes back to
    the pool, so the pool never hands out a connection stuck in a failed or
    idle-in-transaction state. Work already committed by the caller is
    unaffected by this rollback.
    """
    pool = get_sync_pool()
    conn = pool.getconn()
    try:
        yield conn
    finally:
        try:
            conn.rollback()
        except Exception:
            pass  # connection may be broken; putconn() handles disposal
        pool.putconn(conn)


def sync_execute(sql: str, params=None, fetch=False):
    """Execute one statement; with fetch=True return rows as dicts.

    Commits on both paths (the fetch path previously returned before
    commit, leaving the pooled connection idle-in-transaction).
    """
    with get_sync_conn() as conn:
        with conn.cursor() as cur:
            cur.execute(sql, params)
            if fetch:
                cols = [desc[0] for desc in cur.description]
                rows = [dict(zip(cols, row)) for row in cur.fetchall()]
                conn.commit()
                return rows
        conn.commit()


def sync_executemany(sql: str, params_list: list):
    """Execute the same statement for every parameter tuple, then commit."""
    with get_sync_conn() as conn:
        with conn.cursor() as cur:
            cur.executemany(sql, params_list)
        conn.commit()


# ─── 异步连接池(asyncpg)───────────────────────────────────────────────

_async_pool: asyncpg.Pool | None = None


async def get_async_pool() -> asyncpg.Pool:
    """Lazily create the asyncpg pool (FastAPI startup awaits this once)."""
    global _async_pool
    if _async_pool is None:
        _async_pool = await asyncpg.create_pool(
            dsn=PG_DSN, min_size=2, max_size=10,
        )
    return _async_pool


async def async_fetch(sql: str, *args) -> list[dict]:
    """Run a query and return all rows as a list of dicts."""
    pool = await get_async_pool()
    async with pool.acquire() as conn:
        rows = await conn.fetch(sql, *args)
        return [dict(r) for r in rows]


async def async_fetchrow(sql: str, *args) -> dict | None:
    """Run a query and return the first row as a dict, or None."""
    pool = await get_async_pool()
    async with pool.acquire() as conn:
        row = await conn.fetchrow(sql, *args)
        return dict(row) if row else None


async def async_execute(sql: str, *args):
    """Execute a statement without returning rows."""
    pool = await get_async_pool()
    async with pool.acquire() as conn:
        await conn.execute(sql, *args)


async def close_async_pool():
    """Close the asyncpg pool (FastAPI shutdown hook)."""
    global _async_pool
    if _async_pool:
        await _async_pool.close()
        _async_pool = None


# ─── 建表(PG Schema)────────────────────────────────────────────────────

SCHEMA_SQL = """
-- 费率快照
CREATE TABLE IF NOT EXISTS rate_snapshots (
    id BIGSERIAL PRIMARY KEY,
    ts BIGINT NOT NULL,
    btc_rate DOUBLE PRECISION NOT NULL,
    eth_rate DOUBLE PRECISION NOT NULL,
    btc_price DOUBLE PRECISION NOT NULL,
    eth_price DOUBLE PRECISION NOT NULL,
    btc_index_price DOUBLE PRECISION,
    eth_index_price DOUBLE PRECISION
);
CREATE INDEX IF NOT EXISTS idx_rate_snapshots_ts ON rate_snapshots(ts);

-- aggTrades元数据
CREATE TABLE IF NOT EXISTS agg_trades_meta (
    symbol TEXT PRIMARY KEY,
    last_agg_id BIGINT NOT NULL,
    last_time_ms BIGINT,
    earliest_agg_id BIGINT,
    earliest_time_ms BIGINT,
    updated_at TIMESTAMPTZ DEFAULT NOW()
);

-- aggTrades主表(分区父表)
CREATE TABLE IF NOT EXISTS agg_trades (
    agg_id BIGINT NOT NULL,
    symbol TEXT NOT NULL,
    price DOUBLE PRECISION NOT NULL,
    qty DOUBLE PRECISION NOT NULL,
    time_ms BIGINT NOT NULL,
    is_buyer_maker SMALLINT NOT NULL,
    PRIMARY KEY (time_ms, symbol, agg_id)
) PARTITION BY RANGE (time_ms);

CREATE INDEX IF NOT EXISTS idx_agg_trades_sym_time ON agg_trades(symbol, time_ms DESC);
CREATE INDEX IF NOT EXISTS idx_agg_trades_sym_agg ON agg_trades(symbol, agg_id);

-- Signal Engine表
CREATE TABLE IF NOT EXISTS signal_indicators (
    id BIGSERIAL PRIMARY KEY,
    ts BIGINT NOT NULL,
    symbol TEXT NOT NULL,
    cvd_fast DOUBLE PRECISION,
    cvd_mid DOUBLE PRECISION,
    cvd_day DOUBLE PRECISION,
    cvd_fast_slope DOUBLE PRECISION,
    atr_5m DOUBLE PRECISION,
    atr_percentile DOUBLE PRECISION,
    vwap_30m DOUBLE PRECISION,
    price DOUBLE PRECISION,
    p95_qty DOUBLE PRECISION,
    p99_qty DOUBLE PRECISION,
    buy_vol_1m DOUBLE PRECISION,
    sell_vol_1m DOUBLE PRECISION,
    score INTEGER,
    signal TEXT
);
CREATE INDEX IF NOT EXISTS idx_si_ts ON signal_indicators(ts);
CREATE INDEX IF NOT EXISTS idx_si_sym_ts ON signal_indicators(symbol, ts);

CREATE TABLE IF NOT EXISTS signal_indicators_1m (
    id BIGSERIAL PRIMARY KEY,
    ts BIGINT NOT NULL,
    symbol TEXT NOT NULL,
    cvd_fast DOUBLE PRECISION,
    cvd_mid DOUBLE PRECISION,
    cvd_day DOUBLE PRECISION,
    atr_5m DOUBLE PRECISION,
    vwap_30m DOUBLE PRECISION,
    price DOUBLE PRECISION,
    score INTEGER,
    signal TEXT
);
CREATE INDEX IF NOT EXISTS idx_si1m_sym_ts ON signal_indicators_1m(symbol, ts);

CREATE TABLE IF NOT EXISTS signal_trades (
    id BIGSERIAL PRIMARY KEY,
    ts_open BIGINT NOT NULL,
    ts_close BIGINT,
    symbol TEXT NOT NULL,
    direction TEXT NOT NULL,
    entry_price DOUBLE PRECISION,
    exit_price DOUBLE PRECISION,
    qty DOUBLE PRECISION,
    score INTEGER,
    pnl DOUBLE PRECISION,
    sl_price DOUBLE PRECISION,
    tp1_price DOUBLE PRECISION,
    tp2_price DOUBLE PRECISION,
    status TEXT DEFAULT 'open'
);

-- 信号日志(旧表兼容)
CREATE TABLE IF NOT EXISTS signal_logs (
    id BIGSERIAL PRIMARY KEY,
    symbol TEXT,
    rate DOUBLE PRECISION,
    annualized DOUBLE PRECISION,
    sent_at TEXT,
    message TEXT
);

-- 用户表(auth)
CREATE TABLE IF NOT EXISTS users (
    id BIGSERIAL PRIMARY KEY,
    email TEXT UNIQUE NOT NULL,
    password_hash TEXT NOT NULL,
    role TEXT DEFAULT 'user',
    created_at TIMESTAMP DEFAULT NOW()
);

CREATE TABLE IF NOT EXISTS invite_codes (
    id BIGSERIAL PRIMARY KEY,
    code TEXT UNIQUE NOT NULL,
    used_by BIGINT REFERENCES users(id),
    created_at TIMESTAMP DEFAULT NOW()
);
"""


def ensure_partitions():
    """Create agg_trades partitions for the current month and the next two.

    Partition bounds are UTC epoch milliseconds, matching the UTC ``time_ms``
    values stored by the collector. Fixes over the previous version:
      - months are stepped calendar-wise instead of by 30-day jumps (which
        could skip a month entirely, e.g. Jan 31 + 30d = Mar 2 — leaving no
        February partition and causing every February insert to fail);
      - boundaries are computed in UTC rather than naive local time, so they
        no longer depend on the server's timezone;
      - a failed CREATE is rolled back instead of silently leaving the
        transaction aborted (which made every following CREATE fail too).
    """
    today = datetime.datetime.now(datetime.timezone.utc)
    year, month = today.year, today.month

    with get_sync_conn() as conn:
        with conn.cursor() as cur:
            for _ in range(3):  # current month + next two
                if month == 12:
                    next_year, next_month = year + 1, 1
                else:
                    next_year, next_month = year, month + 1
                start = datetime.datetime(year, month, 1, tzinfo=datetime.timezone.utc)
                end = datetime.datetime(next_year, next_month, 1, tzinfo=datetime.timezone.utc)
                part_name = f"agg_trades_{year:04d}{month:02d}"
                try:
                    cur.execute(f"""
                        CREATE TABLE IF NOT EXISTS {part_name}
                        PARTITION OF agg_trades
                        FOR VALUES FROM ({int(start.timestamp() * 1000)}) TO ({int(end.timestamp() * 1000)})
                    """)
                    conn.commit()
                except Exception:
                    # e.g. concurrent creation race; clear the aborted
                    # transaction so the remaining partitions still get made.
                    conn.rollback()
                year, month = next_year, next_month


def init_schema():
    """Create all tables and indexes (idempotent), then ensure partitions.

    Each statement is committed individually so that one failing statement
    (rolled back and skipped) cannot undo earlier successful DDL — the old
    single end-of-loop commit lost everything executed before an error.
    """
    with get_sync_conn() as conn:
        with conn.cursor() as cur:
            # Naive ';' split is safe here: SCHEMA_SQL contains no semicolons
            # inside string literals or function bodies.
            for stmt in SCHEMA_SQL.split(";"):
                stmt = stmt.strip()
                if not stmt:
                    continue
                try:
                    cur.execute(stmt)
                    conn.commit()
                except Exception:
                    conn.rollback()  # object already exists or similar; keep going
                    continue
    ensure_partitions()
["BTCUSDT", "ETHUSDT"] HEADERS = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"} -DB_PATH = os.path.join(os.path.dirname(__file__), "..", "arb.db") # 简单内存缓存(history/stats 60秒,rates 3秒) _cache: dict = {} @@ -36,50 +39,27 @@ def set_cache(key: str, data): _cache[key] = {"ts": time.time(), "data": data} -def init_db(): - conn = sqlite3.connect(DB_PATH) - conn.execute(""" - CREATE TABLE IF NOT EXISTS rate_snapshots ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - ts INTEGER NOT NULL, - btc_rate REAL NOT NULL, - eth_rate REAL NOT NULL, - btc_price REAL NOT NULL, - eth_price REAL NOT NULL, - btc_index_price REAL, - eth_index_price REAL - ) - """) - conn.execute("CREATE INDEX IF NOT EXISTS idx_rate_snapshots_ts ON rate_snapshots(ts)") - conn.commit() - conn.close() - - -def save_snapshot(rates: dict): +async def save_snapshot(rates: dict): try: - conn = sqlite3.connect(DB_PATH) btc = rates.get("BTC", {}) eth = rates.get("ETH", {}) - conn.execute( - "INSERT INTO rate_snapshots (ts, btc_rate, eth_rate, btc_price, eth_price, btc_index_price, eth_index_price) VALUES (?,?,?,?,?,?,?)", - ( - int(time.time()), - float(btc.get("lastFundingRate", 0)), - float(eth.get("lastFundingRate", 0)), - float(btc.get("markPrice", 0)), - float(eth.get("markPrice", 0)), - float(btc.get("indexPrice", 0)), - float(eth.get("indexPrice", 0)), - ) + await async_execute( + "INSERT INTO rate_snapshots (ts, btc_rate, eth_rate, btc_price, eth_price, btc_index_price, eth_index_price) " + "VALUES ($1,$2,$3,$4,$5,$6,$7)", + int(time.time()), + float(btc.get("lastFundingRate", 0)), + float(eth.get("lastFundingRate", 0)), + float(btc.get("markPrice", 0)), + float(eth.get("markPrice", 0)), + float(btc.get("indexPrice", 0)), + float(eth.get("indexPrice", 0)), ) - conn.commit() - conn.close() except Exception as e: pass # 落库失败不影响API响应 async def background_snapshot_loop(): - """后台每2秒自动拉取费率+价格并落库,不依赖前端调用""" + """后台每2秒自动拉取费率+价格并落库""" while True: try: async with 
httpx.AsyncClient(timeout=5, headers=HEADERS) as client: @@ -97,7 +77,7 @@ async def background_snapshot_loop(): "indexPrice": float(data["indexPrice"]), } if result: - save_snapshot(result) + await save_snapshot(result) except Exception: pass await asyncio.sleep(2) @@ -105,11 +85,19 @@ async def background_snapshot_loop(): @app.on_event("startup") async def startup(): - init_db() + # 初始化PG schema + init_schema() ensure_auth_tables() + # 初始化asyncpg池 + await get_async_pool() asyncio.create_task(background_snapshot_loop()) +@app.on_event("shutdown") +async def shutdown(): + await close_async_pool() + + @app.get("/api/health") async def health(): return {"status": "ok", "timestamp": datetime.utcnow().isoformat()} @@ -137,37 +125,23 @@ async def get_rates(): "timestamp": data["time"], } set_cache("rates", result) - # 异步落库(不阻塞响应) - asyncio.create_task(asyncio.to_thread(save_snapshot, result)) + asyncio.create_task(save_snapshot(result)) return result @app.get("/api/snapshots") async def get_snapshots(hours: int = 24, limit: int = 5000, user: dict = Depends(get_current_user)): - """查询本地落库的实时快照数据""" since = int(time.time()) - hours * 3600 - conn = sqlite3.connect(DB_PATH) - conn.row_factory = sqlite3.Row - rows = conn.execute( - "SELECT ts, btc_rate, eth_rate, btc_price, eth_price FROM rate_snapshots WHERE ts >= ? 
ORDER BY ts ASC LIMIT ?", - (since, limit) - ).fetchall() - conn.close() - return { - "count": len(rows), - "hours": hours, - "data": [dict(r) for r in rows] - } + rows = await async_fetch( + "SELECT ts, btc_rate, eth_rate, btc_price, eth_price FROM rate_snapshots " + "WHERE ts >= $1 ORDER BY ts ASC LIMIT $2", + since, limit + ) + return {"count": len(rows), "hours": hours, "data": rows} @app.get("/api/kline") async def get_kline(symbol: str = "BTC", interval: str = "5m", limit: int = 500, user: dict = Depends(get_current_user)): - """ - 从 rate_snapshots 聚合K线数据 - symbol: BTC | ETH - interval: 1m | 5m | 30m | 1h | 4h | 8h | 1d | 1w | 1M - 返回: [{time, open, high, low, close, price_open, price_high, price_low, price_close}] - """ interval_secs = { "1m": 60, "5m": 300, "30m": 1800, "1h": 3600, "4h": 14400, "8h": 28800, @@ -176,22 +150,20 @@ async def get_kline(symbol: str = "BTC", interval: str = "5m", limit: int = 500, bar_secs = interval_secs.get(interval, 300) rate_col = "btc_rate" if symbol.upper() == "BTC" else "eth_rate" price_col = "btc_price" if symbol.upper() == "BTC" else "eth_price" - - # 查询足够多的原始数据(limit根K * bar_secs最多需要的时间范围) since = int(time.time()) - bar_secs * limit - conn = sqlite3.connect(DB_PATH) - rows = conn.execute( - f"SELECT ts, {rate_col} as rate, {price_col} as price FROM rate_snapshots WHERE ts >= ? 
ORDER BY ts ASC", - (since,) - ).fetchall() - conn.close() + + rows = await async_fetch( + f"SELECT ts, {rate_col} as rate, {price_col} as price FROM rate_snapshots " + f"WHERE ts >= $1 ORDER BY ts ASC", + since + ) if not rows: return {"symbol": symbol, "interval": interval, "data": []} - # 按bar_secs分组聚合OHLC bars: dict = {} - for ts, rate, price in rows: + for r in rows: + ts, rate, price = r["ts"], r["rate"], r["price"] bar_ts = (ts // bar_secs) * bar_secs if bar_ts not in bars: bars[bar_ts] = { @@ -209,20 +181,16 @@ async def get_kline(symbol: str = "BTC", interval: str = "5m", limit: int = 500, b["price_close"] = price data = sorted(bars.values(), key=lambda x: x["time"])[-limit:] - # 转换为万分之(费率 × 10000) for b in data: for k in ("open", "high", "low", "close"): b[k] = round(b[k] * 10000, 4) - return {"symbol": symbol, "interval": interval, "count": len(data), "data": data} @app.get("/api/stats/ytd") async def get_stats_ytd(user: dict = Depends(get_current_user)): - """今年以来(YTD)资金费率年化统计""" cached = get_cache("stats_ytd", 3600) if cached: return cached - # 今年1月1日 00:00 UTC import datetime year_start = int(datetime.datetime(datetime.datetime.utcnow().year, 1, 1).timestamp() * 1000) end_time = int(time.time() * 1000) @@ -252,16 +220,12 @@ async def get_stats_ytd(user: dict = Depends(get_current_user)): @app.get("/api/signals/history") async def get_signals_history(limit: int = 100, user: dict = Depends(get_current_user)): - """查询信号推送历史""" try: - conn = sqlite3.connect(DB_PATH) - conn.row_factory = sqlite3.Row - rows = conn.execute( - "SELECT id, symbol, rate, annualized, sent_at, message FROM signal_logs ORDER BY sent_at DESC LIMIT ?", - (limit,) - ).fetchall() - conn.close() - return {"items": [dict(r) for r in rows]} + rows = await async_fetch( + "SELECT id, symbol, rate, annualized, sent_at, message FROM signal_logs ORDER BY sent_at DESC LIMIT $1", + limit + ) + return {"items": rows} except Exception as e: return {"items": [], "error": str(e)} @@ -334,20 +298,11 
@@ async def get_stats(user: dict = Depends(get_current_user)): return stats -# ─── aggTrades 查询接口 ────────────────────────────────────────── +# ─── aggTrades 查询接口(PG版)─────────────────────────────────── @app.get("/api/trades/meta") async def get_trades_meta(user: dict = Depends(get_current_user)): - """aggTrades采集状态:各symbol最新agg_id和时间""" - conn = sqlite3.connect(DB_PATH) - conn.row_factory = sqlite3.Row - try: - rows = conn.execute( - "SELECT symbol, last_agg_id, last_time_ms, updated_at FROM agg_trades_meta" - ).fetchall() - except Exception: - rows = [] - conn.close() + rows = await async_fetch("SELECT symbol, last_agg_id, last_time_ms, updated_at FROM agg_trades_meta") result = {} for r in rows: sym = r["symbol"].replace("USDT", "") @@ -367,46 +322,23 @@ async def get_trades_summary( interval: str = "1m", user: dict = Depends(get_current_user), ): - """分钟级聚合:买卖delta/成交速率/vwap""" if end_ms == 0: end_ms = int(time.time() * 1000) if start_ms == 0: - start_ms = end_ms - 3600 * 1000 # 默认1小时 + start_ms = end_ms - 3600 * 1000 interval_ms = {"1m": 60000, "5m": 300000, "15m": 900000, "1h": 3600000}.get(interval, 60000) sym_full = symbol.upper() + "USDT" - # 确定需要查哪些月表 - start_dt = _dt.datetime.fromtimestamp(start_ms / 1000, tz=_dt.timezone.utc) - end_dt = _dt.datetime.fromtimestamp(end_ms / 1000, tz=_dt.timezone.utc) - months = set() - cur = start_dt.replace(day=1) - while cur <= end_dt: - months.add(cur.strftime("%Y%m")) - if cur.month == 12: - cur = cur.replace(year=cur.year + 1, month=1) - else: - cur = cur.replace(month=cur.month + 1) + # PG分区表自动裁剪,直接查主表 + rows = await async_fetch( + "SELECT agg_id, price, qty, time_ms, is_buyer_maker FROM agg_trades " + "WHERE symbol = $1 AND time_ms >= $2 AND time_ms < $3 ORDER BY time_ms ASC", + sym_full, start_ms, end_ms + ) - conn = sqlite3.connect(DB_PATH) - conn.row_factory = sqlite3.Row - all_rows = [] - for month in sorted(months): - tname = f"agg_trades_{month}" - try: - rows = conn.execute( - f"SELECT agg_id, price, qty, 
time_ms, is_buyer_maker FROM {tname} " - f"WHERE symbol = ? AND time_ms >= ? AND time_ms < ? ORDER BY time_ms ASC", - (sym_full, start_ms, end_ms) - ).fetchall() - all_rows.extend(rows) - except Exception: - pass - conn.close() - - # 按interval聚合 bars: dict = {} - for row in all_rows: + for row in rows: bar_ms = (row["time_ms"] // interval_ms) * interval_ms if bar_ms not in bars: bars[bar_ms] = {"time_ms": bar_ms, "buy_vol": 0.0, "sell_vol": 0.0, @@ -414,7 +346,7 @@ async def get_trades_summary( b = bars[bar_ms] qty = float(row["qty"]) price = float(row["price"]) - if row["is_buyer_maker"] == 0: # 主动买 + if row["is_buyer_maker"] == 0: b["buy_vol"] += qty else: b["sell_vol"] += qty @@ -446,47 +378,28 @@ async def get_trades_latest( limit: int = 30, user: dict = Depends(get_current_user), ): - """查最新N条原始成交记录(从本地DB,实时刷新用)""" sym_full = symbol.upper() + "USDT" - now_month = _dt.datetime.now(_dt.timezone.utc).strftime("%Y%m") - tname = f"agg_trades_{now_month}" - conn = sqlite3.connect(DB_PATH) - conn.row_factory = sqlite3.Row - try: - rows = conn.execute( - f"SELECT agg_id, price, qty, time_ms, is_buyer_maker FROM {tname} " - f"WHERE symbol = ? 
ORDER BY agg_id DESC LIMIT ?", - (sym_full, limit) - ).fetchall() - except Exception: - rows = [] - conn.close() - return { - "symbol": symbol, - "count": len(rows), - "data": [dict(r) for r in rows], - } + rows = await async_fetch( + "SELECT agg_id, price, qty, time_ms, is_buyer_maker FROM agg_trades " + "WHERE symbol = $1 ORDER BY time_ms DESC, agg_id DESC LIMIT $2", + sym_full, limit + ) + return {"symbol": symbol, "count": len(rows), "data": rows} + + +@app.get("/api/collector/health") async def collector_health(user: dict = Depends(get_current_user)): - """采集器健康状态""" - conn = sqlite3.connect(DB_PATH) - conn.row_factory = sqlite3.Row now_ms = int(time.time() * 1000) + rows = await async_fetch("SELECT symbol, last_agg_id, last_time_ms FROM agg_trades_meta") status = {} - try: - rows = conn.execute( - "SELECT symbol, last_agg_id, last_time_ms FROM agg_trades_meta" - ).fetchall() - for r in rows: - sym = r["symbol"].replace("USDT", "") - lag_s = (now_ms - r["last_time_ms"]) / 1000 - status[sym] = { - "last_agg_id": r["last_agg_id"], - "lag_seconds": round(lag_s, 1), - "healthy": lag_s < 30, # 30秒内有数据算健康 - } - except Exception: - pass - conn.close() + for r in rows: + sym = r["symbol"].replace("USDT", "") + lag_s = (now_ms - (r["last_time_ms"] or 0)) / 1000 + status[sym] = { + "last_agg_id": r["last_agg_id"], + "lag_seconds": round(lag_s, 1), + "healthy": lag_s < 30, + } return {"collector": status, "timestamp": now_ms} @@ -498,50 +411,29 @@ async def get_signal_indicators( minutes: int = 60, user: dict = Depends(get_current_user), ): - """获取signal_indicators_1m数据(前端图表用)""" sym_full = symbol.upper() + "USDT" now_ms = int(time.time() * 1000) start_ms = now_ms - minutes * 60 * 1000 - - conn = sqlite3.connect(DB_PATH) - conn.row_factory = sqlite3.Row - try: - rows = conn.execute( - "SELECT ts, cvd_fast, cvd_mid, cvd_day, atr_5m, vwap_30m, price, score, signal " - "FROM signal_indicators_1m WHERE symbol = ? AND ts >= ? 
ORDER BY ts ASC", - (sym_full, start_ms) - ).fetchall() - except Exception: - rows = [] - conn.close() - return { - "symbol": symbol, - "count": len(rows), - "data": [dict(r) for r in rows], - } + rows = await async_fetch( + "SELECT ts, cvd_fast, cvd_mid, cvd_day, atr_5m, vwap_30m, price, score, signal " + "FROM signal_indicators_1m WHERE symbol = $1 AND ts >= $2 ORDER BY ts ASC", + sym_full, start_ms + ) + return {"symbol": symbol, "count": len(rows), "data": rows} @app.get("/api/signals/latest") -async def get_signal_latest( - user: dict = Depends(get_current_user), -): - """获取最新一条各symbol的指标快照""" - conn = sqlite3.connect(DB_PATH) - conn.row_factory = sqlite3.Row +async def get_signal_latest(user: dict = Depends(get_current_user)): result = {} for sym in ["BTCUSDT", "ETHUSDT"]: - try: - row = conn.execute( - "SELECT ts, cvd_fast, cvd_mid, cvd_day, cvd_fast_slope, atr_5m, atr_percentile, " - "vwap_30m, price, p95_qty, p99_qty, score, signal " - "FROM signal_indicators WHERE symbol = ? ORDER BY ts DESC LIMIT 1", - (sym,) - ).fetchone() - if row: - result[sym.replace("USDT", "")] = dict(row) - except Exception: - pass - conn.close() + row = await async_fetchrow( + "SELECT ts, cvd_fast, cvd_mid, cvd_day, cvd_fast_slope, atr_5m, atr_percentile, " + "vwap_30m, price, p95_qty, p99_qty, score, signal " + "FROM signal_indicators WHERE symbol = $1 ORDER BY ts DESC LIMIT 1", + sym + ) + if row: + result[sym.replace("USDT", "")] = row return result @@ -551,20 +443,13 @@ async def get_signal_trades( limit: int = 50, user: dict = Depends(get_current_user), ): - """获取信号交易记录""" - conn = sqlite3.connect(DB_PATH) - conn.row_factory = sqlite3.Row - try: - if status == "all": - rows = conn.execute( - "SELECT * FROM signal_trades ORDER BY ts_open DESC LIMIT ?", (limit,) - ).fetchall() - else: - rows = conn.execute( - "SELECT * FROM signal_trades WHERE status = ? 
ORDER BY ts_open DESC LIMIT ?", - (status, limit) - ).fetchall() - except Exception: - rows = [] - conn.close() - return {"count": len(rows), "data": [dict(r) for r in rows]} + if status == "all": + rows = await async_fetch( + "SELECT * FROM signal_trades ORDER BY ts_open DESC LIMIT $1", limit + ) + else: + rows = await async_fetch( + "SELECT * FROM signal_trades WHERE status = $1 ORDER BY ts_open DESC LIMIT $2", + status, limit + ) + return {"count": len(rows), "data": rows} diff --git a/backend/signal_engine.py b/backend/signal_engine.py index c66b833..26d58f5 100644 --- a/backend/signal_engine.py +++ b/backend/signal_engine.py @@ -1,5 +1,5 @@ """ -signal_engine.py — V5 短线交易信号引擎 +signal_engine.py — V5 短线交易信号引擎(PostgreSQL版) 架构: - 独立PM2进程,每5秒循环 @@ -17,14 +17,13 @@ signal_engine.py — V5 短线交易信号引擎 import logging import os -import sqlite3 import time -import math -import statistics from collections import deque from datetime import datetime, timezone from typing import Optional +from db import get_sync_conn, init_schema + logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", @@ -35,110 +34,33 @@ logging.basicConfig( ) logger = logging.getLogger("signal-engine") -DB_PATH = os.path.join(os.path.dirname(__file__), "..", "arb.db") SYMBOLS = ["BTCUSDT", "ETHUSDT"] LOOP_INTERVAL = 5 # 秒 # 窗口大小(毫秒) WINDOW_FAST = 30 * 60 * 1000 # 30分钟 WINDOW_MID = 4 * 3600 * 1000 # 4小时 -WINDOW_DAY = 24 * 3600 * 1000 # 24小时(用于P95/P99计算) +WINDOW_DAY = 24 * 3600 * 1000 # 24小时 WINDOW_VWAP = 30 * 60 * 1000 # 30分钟 # ATR参数 -ATR_PERIOD_MS = 5 * 60 * 1000 # 5分钟K线 -ATR_LENGTH = 14 # 14根 +ATR_PERIOD_MS = 5 * 60 * 1000 +ATR_LENGTH = 14 # 信号冷却 -COOLDOWN_MS = 10 * 60 * 1000 # 10分钟 - - -# ─── DB helpers ────────────────────────────────────────────────── - -def get_conn() -> sqlite3.Connection: - conn = sqlite3.connect(DB_PATH, timeout=30) - conn.row_factory = sqlite3.Row - conn.execute("PRAGMA journal_mode=WAL") - conn.execute("PRAGMA synchronous=NORMAL") - 
return conn - - -def init_tables(conn: sqlite3.Connection): - conn.execute(""" - CREATE TABLE IF NOT EXISTS signal_indicators ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - ts INTEGER NOT NULL, - symbol TEXT NOT NULL, - cvd_fast REAL, - cvd_mid REAL, - cvd_day REAL, - cvd_fast_slope REAL, - atr_5m REAL, - atr_percentile REAL, - vwap_30m REAL, - price REAL, - p95_qty REAL, - p99_qty REAL, - buy_vol_1m REAL, - sell_vol_1m REAL, - score INTEGER, - signal TEXT - ) - """) - conn.execute("CREATE INDEX IF NOT EXISTS idx_si_ts ON signal_indicators(ts)") - conn.execute("CREATE INDEX IF NOT EXISTS idx_si_sym_ts ON signal_indicators(symbol, ts)") - - conn.execute(""" - CREATE TABLE IF NOT EXISTS signal_indicators_1m ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - ts INTEGER NOT NULL, - symbol TEXT NOT NULL, - cvd_fast REAL, - cvd_mid REAL, - cvd_day REAL, - atr_5m REAL, - vwap_30m REAL, - price REAL, - score INTEGER, - signal TEXT - ) - """) - conn.execute("CREATE INDEX IF NOT EXISTS idx_si1m_sym_ts ON signal_indicators_1m(symbol, ts)") - - conn.execute(""" - CREATE TABLE IF NOT EXISTS signal_trades ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - ts_open INTEGER NOT NULL, - ts_close INTEGER, - symbol TEXT NOT NULL, - direction TEXT NOT NULL, - entry_price REAL, - exit_price REAL, - qty REAL, - score INTEGER, - pnl REAL, - sl_price REAL, - tp1_price REAL, - tp2_price REAL, - status TEXT DEFAULT 'open' - ) - """) - conn.commit() +COOLDOWN_MS = 10 * 60 * 1000 # ─── 滚动窗口 ─────────────────────────────────────────────────── class TradeWindow: - """滚动时间窗口,维护买卖量和价格数据""" - def __init__(self, window_ms: int): self.window_ms = window_ms - self.trades: deque = deque() # (time_ms, qty, price, is_buyer_maker) + self.trades: deque = deque() self.buy_vol = 0.0 self.sell_vol = 0.0 - self.pq_sum = 0.0 # price * qty 累加(VWAP用) - self.q_sum = 0.0 # qty累加 - self.quantities: list = [] # 用于P95/P99(仅24h窗口用) + self.pq_sum = 0.0 + self.q_sum = 0.0 def add(self, time_ms: int, qty: float, price: float, 
is_buyer_maker: int): self.trades.append((time_ms, qty, price, is_buyer_maker)) @@ -154,8 +76,7 @@ class TradeWindow: cutoff = now_ms - self.window_ms while self.trades and self.trades[0][0] < cutoff: t_ms, qty, price, ibm = self.trades.popleft() - pq = price * qty - self.pq_sum -= pq + self.pq_sum -= price * qty self.q_sum -= qty if ibm == 0: self.buy_vol -= qty @@ -170,29 +91,21 @@ class TradeWindow: def vwap(self) -> float: return self.pq_sum / self.q_sum if self.q_sum > 0 else 0.0 - @property - def total_vol(self) -> float: - return self.buy_vol + self.sell_vol - class ATRCalculator: - """5分钟K线ATR计算""" - def __init__(self, period_ms: int = ATR_PERIOD_MS, length: int = ATR_LENGTH): self.period_ms = period_ms self.length = length self.candles: deque = deque(maxlen=length + 1) self.current_candle: Optional[dict] = None - self.atr_history: deque = deque(maxlen=288) # 24h of 5m candles for percentile + self.atr_history: deque = deque(maxlen=288) def update(self, time_ms: int, price: float): bar_ms = (time_ms // self.period_ms) * self.period_ms if self.current_candle is None or self.current_candle["bar"] != bar_ms: if self.current_candle is not None: self.candles.append(self.current_candle) - self.current_candle = { - "bar": bar_ms, "open": price, "high": price, "low": price, "close": price - } + self.current_candle = {"bar": bar_ms, "open": price, "high": price, "low": price, "close": price} else: c = self.current_candle c["high"] = max(c["high"], price) @@ -212,7 +125,6 @@ class ATRCalculator: trs.append(tr) if not trs: return 0.0 - # EMA-style ATR atr_val = trs[0] for tr in trs[1:]: atr_val = (atr_val * (self.length - 1) + tr) / self.length @@ -220,7 +132,6 @@ class ATRCalculator: @property def atr_percentile(self) -> float: - """当前ATR在最近24h中的分位数""" current = self.atr if current == 0: return 50.0 @@ -233,8 +144,6 @@ class ATRCalculator: class SymbolState: - """单个交易对的完整状态""" - def __init__(self, symbol: str): self.symbol = symbol self.win_fast = 
TradeWindow(WINDOW_FAST) @@ -244,12 +153,10 @@ class SymbolState: self.atr_calc = ATRCalculator() self.last_processed_id = 0 self.warmup = True - self.warmup_until = 0 self.prev_cvd_fast = 0.0 self.last_signal_ts = 0 self.last_signal_dir = "" - # P99大单追踪(最近15分钟) - self.recent_large_trades: deque = deque() # (time_ms, qty, is_buyer_maker) + self.recent_large_trades: deque = deque() def process_trade(self, agg_id: int, time_ms: int, price: float, qty: float, is_buyer_maker: int): now_ms = time_ms @@ -258,47 +165,34 @@ class SymbolState: self.win_day.add(time_ms, qty, price, is_buyer_maker) self.win_vwap.add(time_ms, qty, price, is_buyer_maker) self.atr_calc.update(time_ms, price) - self.win_fast.trim(now_ms) self.win_mid.trim(now_ms) self.win_day.trim(now_ms) self.win_vwap.trim(now_ms) - self.last_processed_id = agg_id - def compute_p95_p99(self) -> tuple[float, float]: - """从24h窗口计算大单阈值""" + def compute_p95_p99(self) -> tuple: if len(self.win_day.trades) < 100: - return 5.0, 10.0 # 默认兜底 - - qtys = [t[1] for t in self.win_day.trades] - qtys.sort() + return 5.0, 10.0 + qtys = sorted([t[1] for t in self.win_day.trades]) n = len(qtys) p95 = qtys[int(n * 0.95)] p99 = qtys[int(n * 0.99)] - - # BTC兜底5,ETH兜底50 if "BTC" in self.symbol: - p95 = max(p95, 5.0) - p99 = max(p99, 10.0) + p95 = max(p95, 5.0); p99 = max(p99, 10.0) else: - p95 = max(p95, 50.0) - p99 = max(p99, 100.0) + p95 = max(p95, 50.0); p99 = max(p99, 100.0) return p95, p99 def update_large_trades(self, now_ms: int, p99: float): - """更新最近15分钟的P99大单记录""" cutoff = now_ms - 15 * 60 * 1000 while self.recent_large_trades and self.recent_large_trades[0][0] < cutoff: self.recent_large_trades.popleft() - - # 从最近处理的trades中找大单 for t in self.win_fast.trades: if t[1] >= p99 and t[0] > cutoff: self.recent_large_trades.append((t[0], t[1], t[3])) def evaluate_signal(self, now_ms: int) -> dict: - """评估信号:核心3条件 + 加分3条件""" cvd_fast = self.win_fast.cvd cvd_mid = self.win_mid.cvd vwap = self.win_vwap.vwap @@ -306,198 +200,127 @@ 
class SymbolState: atr_pct = self.atr_calc.atr_percentile p95, p99 = self.compute_p95_p99() self.update_large_trades(now_ms, p99) - - # 当前价格(用VWAP近似) price = vwap if vwap > 0 else 0 - - # CVD_fast斜率 cvd_fast_slope = cvd_fast - self.prev_cvd_fast self.prev_cvd_fast = cvd_fast result = { - "cvd_fast": cvd_fast, - "cvd_mid": cvd_mid, - "cvd_day": self.win_day.cvd, + "cvd_fast": cvd_fast, "cvd_mid": cvd_mid, "cvd_day": self.win_day.cvd, "cvd_fast_slope": cvd_fast_slope, - "atr": atr, - "atr_pct": atr_pct, - "vwap": vwap, - "price": price, - "p95": p95, - "p99": p99, - "signal": None, - "direction": None, - "score": 0, + "atr": atr, "atr_pct": atr_pct, "vwap": vwap, "price": price, + "p95": p95, "p99": p99, "signal": None, "direction": None, "score": 0, } if self.warmup or price == 0 or atr == 0: return result - - # 冷却期检查 if now_ms - self.last_signal_ts < COOLDOWN_MS: return result - # === 核心条件 === - # 做多 - long_core = ( - cvd_fast > 0 and cvd_fast_slope > 0 and # CVD_fast正且上升 - cvd_mid > 0 and # CVD_mid正 - price > vwap # 价格在VWAP上方 - ) - # 做空 - short_core = ( - cvd_fast < 0 and cvd_fast_slope < 0 and - cvd_mid < 0 and - price < vwap - ) + long_core = cvd_fast > 0 and cvd_fast_slope > 0 and cvd_mid > 0 and price > vwap + short_core = cvd_fast < 0 and cvd_fast_slope < 0 and cvd_mid < 0 and price < vwap if not long_core and not short_core: return result direction = "LONG" if long_core else "SHORT" - - # === 加分条件 === score = 0 - - # 1. ATR压缩→扩张 (+25) if atr_pct > 60: score += 25 - - # 2. 无反向P99大单 (+20) - has_adverse_large = False - for lt in self.recent_large_trades: - if direction == "LONG" and lt[2] == 1: # 大卖单对多头不利 - has_adverse_large = True - elif direction == "SHORT" and lt[2] == 0: # 大买单对空头不利 - has_adverse_large = True - if not has_adverse_large: + has_adverse = any( + (direction == "LONG" and lt[2] == 1) or (direction == "SHORT" and lt[2] == 0) + for lt in self.recent_large_trades + ) + if not has_adverse: score += 20 - # 3. 
资金费率配合 (+15) — 从rate_snapshots读取 - # 暂时跳过,后续接入 - # TODO: 接入资金费率条件 - score += 0 # placeholder - result["signal"] = direction result["direction"] = direction result["score"] = score - - # 更新冷却 self.last_signal_ts = now_ms self.last_signal_dir = direction - return result -# ─── 主循环 ────────────────────────────────────────────────────── +# ─── PG DB操作 ─────────────────────────────────────────────────── -def get_month_tables(conn: sqlite3.Connection) -> list[str]: - rows = conn.execute( - "SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'agg_trades_2%' ORDER BY name" - ).fetchall() - return [r["name"] for r in rows] - - -def load_historical(conn: sqlite3.Connection, state: SymbolState, window_ms: int): - """冷启动:回灌历史数据到内存窗口""" +def load_historical(state: SymbolState, window_ms: int): now_ms = int(time.time() * 1000) start_ms = now_ms - window_ms - tables = get_month_tables(conn) - count = 0 - for tname in tables: - try: - rows = conn.execute( - f"SELECT agg_id, price, qty, time_ms, is_buyer_maker FROM {tname} " - f"WHERE symbol = ? AND time_ms >= ? 
ORDER BY agg_id ASC",
+    with get_sync_conn() as conn:
+        with conn.cursor() as cur:
+            cur.execute(
+                "SELECT agg_id, price, qty, time_ms, is_buyer_maker FROM agg_trades "
+                "WHERE symbol = %s AND time_ms >= %s ORDER BY agg_id ASC",
                 (state.symbol, start_ms)
-            ).fetchall()
-            for r in rows:
-                state.process_trade(r["agg_id"], r["time_ms"], r["price"], r["qty"], r["is_buyer_maker"])
-                count += 1
-        except Exception as e:
-            logger.warning(f"Error loading {tname}: {e}")
-
+            )
+            count = 0
+            # stream in batches to bound memory; count must be initialized or logger below raises NameError
+            while rows := cur.fetchmany(5000):
+                count += len(rows)
+                for r in rows:
+                    # tuple order: agg_id, price, qty, time_ms, is_buyer_maker
+                    state.process_trade(r[0], r[3], r[1], r[2], r[4])
     logger.info(f"[{state.symbol}] 冷启动完成: 加载{count:,}条历史数据 (窗口={window_ms//3600000}h)")
     state.warmup = False
 
 
-def fetch_new_trades(conn: sqlite3.Connection, symbol: str, last_id: int) -> list:
-    """增量读取新aggTrades"""
-    tables = get_month_tables(conn)
-    results = []
-    for tname in tables[-2:]:  # 只查最近两个月表
-        try:
-            rows = conn.execute(
-                f"SELECT agg_id, price, qty, time_ms, is_buyer_maker FROM {tname} "
-                f"WHERE symbol = ? AND agg_id > ? ORDER BY agg_id ASC LIMIT 10000",
+def fetch_new_trades(symbol: str, last_id: int) -> list:
+    with get_sync_conn() as conn:
+        with conn.cursor() as cur:
+            cur.execute(
+                "SELECT agg_id, price, qty, time_ms, is_buyer_maker FROM agg_trades "
+                "WHERE symbol = %s AND agg_id > %s ORDER BY agg_id ASC LIMIT 10000",
                 (symbol, last_id)
-            ).fetchall()
-            results.extend(rows)
-        except Exception:
-            pass
-    return results
+            )
+            return [{"agg_id": r[0], "price": r[1], "qty": r[2], "time_ms": r[3], "is_buyer_maker": r[4]}
+                    for r in cur.fetchall()]
 
 
-def save_indicator(conn: sqlite3.Connection, ts: int, symbol: str, result: dict):
-    conn.execute("""
-        INSERT INTO signal_indicators
-        (ts, symbol, cvd_fast, cvd_mid, cvd_day, cvd_fast_slope, atr_5m, atr_percentile,
-         vwap_30m, price, p95_qty, p99_qty, score, signal)
-        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
- """, ( - ts, symbol, - result["cvd_fast"], result["cvd_mid"], result["cvd_day"], result["cvd_fast_slope"], - result["atr"], result["atr_pct"], - result["vwap"], result["price"], - result["p95"], result["p99"], - result["score"], result.get("signal") - )) - conn.commit() +def save_indicator(ts: int, symbol: str, result: dict): + with get_sync_conn() as conn: + with conn.cursor() as cur: + cur.execute( + "INSERT INTO signal_indicators " + "(ts,symbol,cvd_fast,cvd_mid,cvd_day,cvd_fast_slope,atr_5m,atr_percentile,vwap_30m,price,p95_qty,p99_qty,score,signal) " + "VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)", + (ts, symbol, result["cvd_fast"], result["cvd_mid"], result["cvd_day"], result["cvd_fast_slope"], + result["atr"], result["atr_pct"], result["vwap"], result["price"], + result["p95"], result["p99"], result["score"], result.get("signal")) + ) + conn.commit() -def save_indicator_1m(conn: sqlite3.Connection, ts: int, symbol: str, result: dict): - """每分钟聚合保存""" +def save_indicator_1m(ts: int, symbol: str, result: dict): bar_ts = (ts // 60000) * 60000 - existing = conn.execute( - "SELECT id FROM signal_indicators_1m WHERE ts = ? AND symbol = ?", (bar_ts, symbol) - ).fetchone() - if existing: - conn.execute(""" - UPDATE signal_indicators_1m SET - cvd_fast=?, cvd_mid=?, cvd_day=?, atr_5m=?, vwap_30m=?, price=?, score=?, signal=? - WHERE ts=? AND symbol=? - """, ( - result["cvd_fast"], result["cvd_mid"], result["cvd_day"], - result["atr"], result["vwap"], result["price"], - result["score"], result.get("signal"), - bar_ts, symbol - )) - else: - conn.execute(""" - INSERT INTO signal_indicators_1m - (ts, symbol, cvd_fast, cvd_mid, cvd_day, atr_5m, vwap_30m, price, score, signal) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 
- """, ( - bar_ts, symbol, - result["cvd_fast"], result["cvd_mid"], result["cvd_day"], - result["atr"], result["vwap"], result["price"], - result["score"], result.get("signal") - )) - conn.commit() + with get_sync_conn() as conn: + with conn.cursor() as cur: + cur.execute("SELECT id FROM signal_indicators_1m WHERE ts=%s AND symbol=%s", (bar_ts, symbol)) + if cur.fetchone(): + cur.execute( + "UPDATE signal_indicators_1m SET cvd_fast=%s,cvd_mid=%s,cvd_day=%s,atr_5m=%s,vwap_30m=%s,price=%s,score=%s,signal=%s WHERE ts=%s AND symbol=%s", + (result["cvd_fast"], result["cvd_mid"], result["cvd_day"], result["atr"], result["vwap"], + result["price"], result["score"], result.get("signal"), bar_ts, symbol) + ) + else: + cur.execute( + "INSERT INTO signal_indicators_1m (ts,symbol,cvd_fast,cvd_mid,cvd_day,atr_5m,vwap_30m,price,score,signal) VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)", + (bar_ts, symbol, result["cvd_fast"], result["cvd_mid"], result["cvd_day"], result["atr"], + result["vwap"], result["price"], result["score"], result.get("signal")) + ) + conn.commit() +# ─── 主循环 ────────────────────────────────────────────────────── + def main(): - conn = get_conn() - init_tables(conn) - + init_schema() states = {sym: SymbolState(sym) for sym in SYMBOLS} - # 冷启动:回灌4h数据 for sym, state in states.items(): - load_historical(conn, state, WINDOW_MID) + load_historical(state, WINDOW_MID) - logger.info("=== Signal Engine 启动完成 ===") + logger.info("=== Signal Engine (PG) 启动完成 ===") last_1m_save = {} cycle = 0 @@ -505,44 +328,27 @@ def main(): while True: try: now_ms = int(time.time() * 1000) - for sym, state in states.items(): - # 增量读取新数据 - new_trades = fetch_new_trades(conn, sym, state.last_processed_id) + new_trades = fetch_new_trades(sym, state.last_processed_id) for t in new_trades: state.process_trade(t["agg_id"], t["time_ms"], t["price"], t["qty"], t["is_buyer_maker"]) - # 评估信号 result = state.evaluate_signal(now_ms) + save_indicator(now_ms, sym, result) - # 保存5秒指标 - save_indicator(conn, 
now_ms, sym, result) - - # 每分钟保存聚合 bar_1m = (now_ms // 60000) * 60000 if last_1m_save.get(sym) != bar_1m: - save_indicator_1m(conn, now_ms, sym, result) + save_indicator_1m(now_ms, sym, result) last_1m_save[sym] = bar_1m - # 有信号则记录 if result.get("signal"): - logger.info( - f"[{sym}] 🚨 信号: {result['signal']} " - f"score={result['score']} price={result['price']:.1f} " - f"CVD_fast={result['cvd_fast']:.1f} CVD_mid={result['cvd_mid']:.1f}" - ) - # TODO: Discord推送 - # TODO: 写入signal_trades + logger.info(f"[{sym}] 🚨 信号: {result['signal']} score={result['score']} price={result['price']:.1f}") cycle += 1 - if cycle % 60 == 0: # 每5分钟打一次状态 + if cycle % 60 == 0: for sym, state in states.items(): r = state.evaluate_signal(now_ms) - logger.info( - f"[{sym}] 状态: CVD_fast={r['cvd_fast']:.1f} CVD_mid={r['cvd_mid']:.1f} " - f"ATR={r['atr']:.2f}({r['atr_pct']:.0f}%) VWAP={r['vwap']:.1f} " - f"P95={r['p95']:.4f} P99={r['p99']:.4f}" - ) + logger.info(f"[{sym}] 状态: CVD_fast={r['cvd_fast']:.1f} CVD_mid={r['cvd_mid']:.1f} ATR={r['atr']:.2f}({r['atr_pct']:.0f}%) VWAP={r['vwap']:.1f}") except Exception as e: logger.error(f"循环异常: {e}", exc_info=True)