"""db.py — PostgreSQL database access layer.

* Synchronous connection pool (psycopg2) for scripts and worker threads.
* Asynchronous connection pool (asyncpg) for the FastAPI app.
* Schema bootstrap (``init_schema``) and monthly partition management for
  the ``agg_trades`` table (``ensure_partitions``).
"""

import asyncio
import datetime
import logging
import os
import threading
from contextlib import contextmanager
from urllib.parse import quote

import asyncpg
import psycopg2
import psycopg2.pool

logger = logging.getLogger(__name__)

# PG connection parameters (each overridable via an environment variable).
PG_HOST = os.getenv("PG_HOST", "127.0.0.1")
PG_PORT = int(os.getenv("PG_PORT", 5432))
PG_DB = os.getenv("PG_DB", "arb_engine")
PG_USER = os.getenv("PG_USER", "arb")
# NOTE(review): a fallback password is hard-coded in source; deployed
# environments should always set PG_PASS explicitly.
PG_PASS = os.getenv("PG_PASS", "arb_engine_2026")
# User and password are percent-encoded so credentials containing
# URL-reserved characters (@ : / # ? %) still produce a parseable DSN.
# For plain alphanumeric/underscore values the DSN text is unchanged.
PG_DSN = (
    f"postgresql://{quote(PG_USER, safe='')}:{quote(PG_PASS, safe='')}"
    f"@{PG_HOST}:{PG_PORT}/{PG_DB}"
)


# ─── Synchronous pool (psycopg2) ─────────────────────────────────────────
_sync_pool: psycopg2.pool.ThreadedConnectionPool | None = None
# Guards lazy creation so concurrent first callers share a single pool.
_sync_pool_lock = threading.Lock()


def get_sync_pool() -> psycopg2.pool.ThreadedConnectionPool:
    """Return the process-wide psycopg2 pool, creating it on first use.

    Creation happens under a lock (with a re-check inside it), so several
    threads calling this at the same time cannot each build their own pool.
    """
    global _sync_pool
    if _sync_pool is None:
        with _sync_pool_lock:
            # Re-check: another thread may have created it while we waited.
            if _sync_pool is None:
                _sync_pool = psycopg2.pool.ThreadedConnectionPool(
                    minconn=1,
                    maxconn=5,
                    host=PG_HOST,
                    port=PG_PORT,
                    dbname=PG_DB,
                    user=PG_USER,
                    password=PG_PASS,
                )
    return _sync_pool


@contextmanager
def get_sync_conn():
    """Borrow a connection from the sync pool and always hand it back.

    Yields:
        A psycopg2 connection. Committing or rolling back is the caller's
        responsibility; nothing is committed here.
    """
    pool = get_sync_pool()
    conn = pool.getconn()
    try:
        yield conn
    finally:
        pool.putconn(conn)


def sync_execute(sql: str, params=None, fetch=False):
    """Execute a single statement and commit it.

    Args:
        sql: SQL text using psycopg2 placeholders.
        params: Placeholder values (sequence or mapping), or None.
        fetch: When True, return the result rows as a list of dicts keyed
            by column name. If the statement produces no result set, return
            an empty list. When False, return None.

    The transaction is committed in both modes. This means
    ``INSERT ... RETURNING`` with ``fetch=True`` is persisted rather than
    left uncommitted.
    """
    with get_sync_conn() as conn:
        with conn.cursor() as cur:
            cur.execute(sql, params)
            rows = None
            if fetch:
                if cur.description is None:
                    # Statement returned no result set; fetchall() would raise.
                    rows = []
                else:
                    cols = [desc[0] for desc in cur.description]
                    rows = [dict(zip(cols, row)) for row in cur.fetchall()]
        conn.commit()
    return rows


def sync_executemany(sql: str, params_list: list):
    """Execute ``sql`` once per parameter set in ``params_list``, then commit.

    Args:
        sql: SQL text using psycopg2 placeholders.
        params_list: One parameter sequence/mapping per execution.
    """
    with get_sync_conn() as conn:
        with conn.cursor() as cur:
            cur.executemany(sql, params_list)
        conn.commit()


# ─── Asynchronous pool (asyncpg) ─────────────────────────────────────────
_async_pool: asyncpg.Pool | None = None
# Serialises lazy pool creation between concurrently awaiting coroutines.
_async_pool_lock = asyncio.Lock()


async def get_async_pool() -> asyncpg.Pool:
    """Return the shared asyncpg pool, creating it on first use.

    ``create_pool`` is awaited, so without the lock two coroutines could
    both see ``None`` and each open a pool. The lock plus re-check ensures
    only one is ever created.
    """
    global _async_pool
    if _async_pool is None:
        async with _async_pool_lock:
            if _async_pool is None:
                _async_pool = await asyncpg.create_pool(
                    dsn=PG_DSN,
                    min_size=2,
                    max_size=10,
                )
    return _async_pool


async def async_fetch(sql: str, *args) -> list[dict]:
    """Run a query asynchronously and return all rows as a list of dicts."""
    pool = await get_async_pool()
    async with pool.acquire() as conn:
        rows = await conn.fetch(sql, *args)
        return [dict(r) for r in rows]


async def async_fetchrow(sql: str, *args) -> dict | None:
    """Run a query asynchronously and return the first row as a dict, or None."""
    pool = await get_async_pool()
    async with pool.acquire() as conn:
        row = await conn.fetchrow(sql, *args)
        return dict(row) if row else None


async def async_execute(sql: str, *args):
    """Execute a statement asynchronously, discarding its status result."""
    pool = await get_async_pool()
    async with pool.acquire() as conn:
        await conn.execute(sql, *args)


async def close_async_pool():
    """Close the shared async pool, if one exists.

    The module-level reference is cleared first, so a later
    ``get_async_pool()`` creates a fresh pool.
    """
    global _async_pool
    pool, _async_pool = _async_pool, None
    if pool is not None:
        await pool.close()


# ─── Schema (PG DDL) ─────────────────────────────────────────────────────
# Statements are separated by ';' and executed one at a time by init_schema(),
# so no ';' may appear inside a statement, comment, or literal below.
SCHEMA_SQL = """
-- 费率快照
CREATE TABLE IF NOT EXISTS rate_snapshots (
    id BIGSERIAL PRIMARY KEY,
    ts BIGINT NOT NULL,
    btc_rate DOUBLE PRECISION NOT NULL,
    eth_rate DOUBLE PRECISION NOT NULL,
    btc_price DOUBLE PRECISION NOT NULL,
    eth_price DOUBLE PRECISION NOT NULL,
    btc_index_price DOUBLE PRECISION,
    eth_index_price DOUBLE PRECISION
);
CREATE INDEX IF NOT EXISTS idx_rate_snapshots_ts ON rate_snapshots(ts);

-- aggTrades元数据
CREATE TABLE IF NOT EXISTS agg_trades_meta (
    symbol TEXT PRIMARY KEY,
    last_agg_id BIGINT NOT NULL,
    last_time_ms BIGINT,
    earliest_agg_id BIGINT,
    earliest_time_ms BIGINT,
    updated_at TEXT
);

-- aggTrades主表(分区父表)
CREATE TABLE IF NOT EXISTS agg_trades (
    agg_id BIGINT NOT NULL,
    symbol TEXT NOT NULL,
    price DOUBLE PRECISION NOT NULL,
    qty DOUBLE PRECISION NOT NULL,
    time_ms BIGINT NOT NULL,
    is_buyer_maker SMALLINT NOT NULL,
    PRIMARY KEY (time_ms, symbol, agg_id)
) PARTITION BY RANGE (time_ms);
CREATE INDEX IF NOT EXISTS idx_agg_trades_sym_time ON agg_trades(symbol, time_ms DESC);
CREATE INDEX IF NOT EXISTS idx_agg_trades_sym_agg ON agg_trades(symbol, agg_id);

-- Signal Engine表
CREATE TABLE IF NOT EXISTS signal_indicators (
    id BIGSERIAL PRIMARY KEY,
    ts BIGINT NOT NULL,
    symbol TEXT NOT NULL,
    cvd_fast DOUBLE PRECISION,
    cvd_mid DOUBLE PRECISION,
    cvd_day DOUBLE PRECISION,
    cvd_fast_slope DOUBLE PRECISION,
    atr_5m DOUBLE PRECISION,
    atr_percentile DOUBLE PRECISION,
    vwap_30m DOUBLE PRECISION,
    price DOUBLE PRECISION,
    p95_qty DOUBLE PRECISION,
    p99_qty DOUBLE PRECISION,
    buy_vol_1m DOUBLE PRECISION,
    sell_vol_1m DOUBLE PRECISION,
    score INTEGER,
    signal TEXT
);
CREATE INDEX IF NOT EXISTS idx_si_ts ON signal_indicators(ts);
CREATE INDEX IF NOT EXISTS idx_si_sym_ts ON signal_indicators(symbol, ts);

CREATE TABLE IF NOT EXISTS signal_indicators_1m (
    id BIGSERIAL PRIMARY KEY,
    ts BIGINT NOT NULL,
    symbol TEXT NOT NULL,
    cvd_fast DOUBLE PRECISION,
    cvd_mid DOUBLE PRECISION,
    cvd_day DOUBLE PRECISION,
    atr_5m DOUBLE PRECISION,
    vwap_30m DOUBLE PRECISION,
    price DOUBLE PRECISION,
    score INTEGER,
    signal TEXT
);
CREATE INDEX IF NOT EXISTS idx_si1m_sym_ts ON signal_indicators_1m(symbol, ts);

CREATE TABLE IF NOT EXISTS signal_trades (
    id BIGSERIAL PRIMARY KEY,
    ts_open BIGINT NOT NULL,
    ts_close BIGINT,
    symbol TEXT NOT NULL,
    direction TEXT NOT NULL,
    entry_price DOUBLE PRECISION,
    exit_price DOUBLE PRECISION,
    qty DOUBLE PRECISION,
    score INTEGER,
    pnl DOUBLE PRECISION,
    sl_price DOUBLE PRECISION,
    tp1_price DOUBLE PRECISION,
    tp2_price DOUBLE PRECISION,
    status TEXT DEFAULT 'open'
);

-- 信号日志(旧表兼容)
CREATE TABLE IF NOT EXISTS signal_logs (
    id BIGSERIAL PRIMARY KEY,
    symbol TEXT,
    rate DOUBLE PRECISION,
    annualized DOUBLE PRECISION,
    sent_at TEXT,
    message TEXT
);

-- 用户表(auth)
CREATE TABLE IF NOT EXISTS users (
    id BIGSERIAL PRIMARY KEY,
    email TEXT UNIQUE NOT NULL,
    password_hash TEXT NOT NULL,
    role TEXT DEFAULT 'user',
    created_at TIMESTAMP DEFAULT NOW()
);

CREATE TABLE IF NOT EXISTS invite_codes (
    id BIGSERIAL PRIMARY KEY,
    code TEXT UNIQUE NOT NULL,
    used_by BIGINT REFERENCES users(id),
    created_at TIMESTAMP DEFAULT NOW()
);
"""


def _add_months(year: int, month: int, delta: int) -> tuple[int, int]:
    """Return the (year, month) that is ``delta`` calendar months after (year, month)."""
    index = year * 12 + (month - 1) + delta
    return index // 12, index % 12 + 1


def ensure_partitions(months_ahead: int = 2):
    """Create monthly ``agg_trades`` partitions for the current UTC month and the following ones.

    Args:
        months_ahead: How many months after the current one to pre-create.
            The default of 2 keeps the original behaviour of this month
            plus the next two.

    Each partition ``agg_trades_YYYYMM`` covers
    ``[first day of month 00:00 UTC, first day of next month 00:00 UTC)``
    as epoch milliseconds. Months are stepped by calendar arithmetic, not by
    30-day offsets, so no month (for example February) is ever skipped.

    NOTE(review): earlier versions derived the bounds from host-local time.
    On a host whose TZ is not UTC, partitions created by that code keep
    their old bounds. The first UTC-based partition next to them may then
    overlap (the error is logged) or leave a gap. Check existing bounds when
    migrating.

    Each partition is created and committed on its own. A failure is logged
    and rolled back without discarding the partitions created before it.
    """
    utc = datetime.timezone.utc
    now = datetime.datetime.now(utc)
    with get_sync_conn() as conn:
        with conn.cursor() as cur:
            for offset in range(months_ahead + 1):
                year, month = _add_months(now.year, now.month, offset)
                next_year, next_month = _add_months(year, month, 1)
                start = datetime.datetime(year, month, 1, tzinfo=utc)
                end = datetime.datetime(next_year, next_month, 1, tzinfo=utc)
                start_ms = int(start.timestamp() * 1000)
                end_ms = int(end.timestamp() * 1000)
                # Name and bounds are built from integers only, so the
                # f-string DDL cannot carry injected SQL.
                part_name = f"agg_trades_{year:04d}{month:02d}"
                try:
                    cur.execute(f"""
                        CREATE TABLE IF NOT EXISTS {part_name}
                        PARTITION OF agg_trades
                        FOR VALUES FROM ({start_ms}) TO ({end_ms})
                    """)
                    conn.commit()
                except psycopg2.Error:
                    # A failed statement aborts the transaction; roll back so
                    # the remaining partitions can still be created.
                    conn.rollback()
                    logger.exception("failed to create partition %s", part_name)


def init_schema():
    """Create all tables and indexes in ``SCHEMA_SQL``, then ensure partitions exist.

    Statements run and commit one at a time. A failing statement is logged
    and rolled back on its own, so it cannot undo the statements that
    already succeeded. Initialisation is best-effort: the remaining
    statements still run after a failure.
    """
    with get_sync_conn() as conn:
        with conn.cursor() as cur:
            for stmt in SCHEMA_SQL.split(";"):
                stmt = stmt.strip()
                if not stmt:
                    continue
                try:
                    cur.execute(stmt)
                    conn.commit()
                except psycopg2.Error:
                    conn.rollback()
                    logger.warning(
                        "schema statement failed: %s",
                        stmt.splitlines()[-1] if stmt else stmt,
                        exc_info=True,
                    )
    ensure_partitions()