304 lines
8.9 KiB
Python
304 lines
8.9 KiB
Python
"""
|
||
db.py — PostgreSQL 数据库连接层
|
||
同步连接池(psycopg2)供脚本类使用
|
||
异步连接池(asyncpg)供FastAPI使用
|
||
"""
|
||
|
||
import os
|
||
import asyncio
|
||
import asyncpg
|
||
import psycopg2
|
||
import psycopg2.pool
|
||
from contextlib import contextmanager
|
||
|
||
# PG连接参数
|
||
PG_HOST = os.getenv("PG_HOST", "127.0.0.1")
|
||
PG_PORT = int(os.getenv("PG_PORT", 5432))
|
||
PG_DB = os.getenv("PG_DB", "arb_engine")
|
||
PG_USER = os.getenv("PG_USER", "arb")
|
||
PG_PASS = os.getenv("PG_PASS", "arb_engine_2026")
|
||
|
||
PG_DSN = f"postgresql://{PG_USER}:{PG_PASS}@{PG_HOST}:{PG_PORT}/{PG_DB}"
|
||
|
||
# ─── 同步连接池(psycopg2)─────────────────────────────────────
|
||
|
||
_sync_pool = None
|
||
|
||
def get_sync_pool() -> psycopg2.pool.ThreadedConnectionPool:
    """Return the process-wide psycopg2 pool, creating it on first use."""
    global _sync_pool
    if _sync_pool is not None:
        return _sync_pool
    _sync_pool = psycopg2.pool.ThreadedConnectionPool(
        minconn=1,
        maxconn=5,
        host=PG_HOST,
        port=PG_PORT,
        dbname=PG_DB,
        user=PG_USER,
        password=PG_PASS,
    )
    return _sync_pool
|
||
|
||
|
||
@contextmanager
def get_sync_conn():
    """Yield a pooled synchronous connection; always return it to the pool."""
    pool = get_sync_pool()
    connection = pool.getconn()
    try:
        yield connection
    finally:
        # Runs even if the caller's body raised, so connections never leak.
        pool.putconn(connection)
|
||
|
||
|
||
def sync_execute(sql: str, params=None, fetch=False):
    """Execute one SQL statement on a pooled connection.

    Args:
        sql: statement text with %s placeholders.
        params: optional parameter sequence/mapping for the placeholders.
        fetch: when True, return the result set as a list of dicts
            (column name -> value).

    Returns:
        list[dict] when fetch=True, otherwise None.
    """
    with get_sync_conn() as conn:
        try:
            rows = None
            with conn.cursor() as cur:
                cur.execute(sql, params)
                if fetch:
                    cols = [desc[0] for desc in cur.description]
                    rows = [dict(zip(cols, row)) for row in cur.fetchall()]
            # Commit unconditionally: the previous version returned before
            # committing on the fetch path, handing an idle-in-transaction
            # connection back to the pool.
            conn.commit()
            return rows
        except Exception:
            # Don't return an aborted transaction to the pool.
            conn.rollback()
            raise
|
||
|
||
|
||
def sync_executemany(sql: str, params_list: list):
    """Execute *sql* once per parameter tuple, all in a single transaction.

    Args:
        sql: statement text with %s placeholders.
        params_list: list of parameter sequences, one per execution.
    """
    with get_sync_conn() as conn:
        try:
            with conn.cursor() as cur:
                cur.executemany(sql, params_list)
            conn.commit()
        except Exception:
            # Reset the transaction so the pooled connection stays usable,
            # then surface the error to the caller.
            conn.rollback()
            raise
|
||
|
||
|
||
# ─── Asynchronous connection pool (asyncpg) ───────────────────────────────

# Lazily-created process-wide pool; see get_async_pool().
_async_pool: asyncpg.Pool | None = None
# Serializes first-time pool creation: without it, concurrent coroutines
# could each await create_pool() and one pool would leak.
_async_pool_lock = asyncio.Lock()


async def get_async_pool() -> asyncpg.Pool:
    """Return the process-wide asyncpg pool, creating it on first use."""
    global _async_pool
    if _async_pool is None:
        async with _async_pool_lock:
            # Re-check under the lock: another coroutine may have won the race.
            if _async_pool is None:
                _async_pool = await asyncpg.create_pool(
                    dsn=PG_DSN, min_size=2, max_size=10,
                )
    return _async_pool
|
||
|
||
|
||
async def async_fetch(sql: str, *args) -> list[dict]:
    """Run a query asynchronously and return every row as a dict."""
    conn_pool = await get_async_pool()
    async with conn_pool.acquire() as connection:
        records = await connection.fetch(sql, *args)
    # Records are fully materialized, so converting after release is safe.
    return [dict(record) for record in records]
|
||
|
||
|
||
async def async_fetchrow(sql: str, *args) -> dict | None:
    """Run a query asynchronously and return the first row, or None."""
    conn_pool = await get_async_pool()
    async with conn_pool.acquire() as connection:
        record = await connection.fetchrow(sql, *args)
        return dict(record) if record else None
|
||
|
||
|
||
async def async_execute(sql: str, *args):
    """Run a statement asynchronously; the status string is discarded."""
    conn_pool = await get_async_pool()
    async with conn_pool.acquire() as connection:
        await connection.execute(sql, *args)
|
||
|
||
|
||
async def close_async_pool():
    """Close the shared asyncpg pool (if one exists) and forget it."""
    global _async_pool
    if _async_pool is not None:
        await _async_pool.close()
        _async_pool = None
|
||
|
||
|
||
# ─── Table creation (PG schema) ───────────────────────────────────────────

# Full DDL for the application's tables and indexes.  Executed statement by
# statement by init_schema() (split on ";"), so no statement may contain an
# embedded semicolon.  agg_trades is a declaratively partitioned parent;
# monthly children are created by ensure_partitions().
SCHEMA_SQL = """
-- 费率快照
CREATE TABLE IF NOT EXISTS rate_snapshots (
    id BIGSERIAL PRIMARY KEY,
    ts BIGINT NOT NULL,
    btc_rate DOUBLE PRECISION NOT NULL,
    eth_rate DOUBLE PRECISION NOT NULL,
    btc_price DOUBLE PRECISION NOT NULL,
    eth_price DOUBLE PRECISION NOT NULL,
    btc_index_price DOUBLE PRECISION,
    eth_index_price DOUBLE PRECISION
);
CREATE INDEX IF NOT EXISTS idx_rate_snapshots_ts ON rate_snapshots(ts);

-- aggTrades元数据
CREATE TABLE IF NOT EXISTS agg_trades_meta (
    symbol TEXT PRIMARY KEY,
    last_agg_id BIGINT NOT NULL,
    last_time_ms BIGINT,
    earliest_agg_id BIGINT,
    earliest_time_ms BIGINT,
    updated_at TEXT
);

-- aggTrades主表(分区父表)
CREATE TABLE IF NOT EXISTS agg_trades (
    agg_id BIGINT NOT NULL,
    symbol TEXT NOT NULL,
    price DOUBLE PRECISION NOT NULL,
    qty DOUBLE PRECISION NOT NULL,
    time_ms BIGINT NOT NULL,
    is_buyer_maker SMALLINT NOT NULL,
    PRIMARY KEY (time_ms, symbol, agg_id)
) PARTITION BY RANGE (time_ms);

CREATE INDEX IF NOT EXISTS idx_agg_trades_sym_time ON agg_trades(symbol, time_ms DESC);
CREATE INDEX IF NOT EXISTS idx_agg_trades_sym_agg ON agg_trades(symbol, agg_id);

-- Signal Engine表
CREATE TABLE IF NOT EXISTS signal_indicators (
    id BIGSERIAL PRIMARY KEY,
    ts BIGINT NOT NULL,
    symbol TEXT NOT NULL,
    cvd_fast DOUBLE PRECISION,
    cvd_mid DOUBLE PRECISION,
    cvd_day DOUBLE PRECISION,
    cvd_fast_slope DOUBLE PRECISION,
    atr_5m DOUBLE PRECISION,
    atr_percentile DOUBLE PRECISION,
    vwap_30m DOUBLE PRECISION,
    price DOUBLE PRECISION,
    p95_qty DOUBLE PRECISION,
    p99_qty DOUBLE PRECISION,
    buy_vol_1m DOUBLE PRECISION,
    sell_vol_1m DOUBLE PRECISION,
    score INTEGER,
    signal TEXT
);
CREATE INDEX IF NOT EXISTS idx_si_ts ON signal_indicators(ts);
CREATE INDEX IF NOT EXISTS idx_si_sym_ts ON signal_indicators(symbol, ts);

CREATE TABLE IF NOT EXISTS signal_indicators_1m (
    id BIGSERIAL PRIMARY KEY,
    ts BIGINT NOT NULL,
    symbol TEXT NOT NULL,
    cvd_fast DOUBLE PRECISION,
    cvd_mid DOUBLE PRECISION,
    cvd_day DOUBLE PRECISION,
    atr_5m DOUBLE PRECISION,
    vwap_30m DOUBLE PRECISION,
    price DOUBLE PRECISION,
    score INTEGER,
    signal TEXT
);
CREATE INDEX IF NOT EXISTS idx_si1m_sym_ts ON signal_indicators_1m(symbol, ts);

CREATE TABLE IF NOT EXISTS signal_trades (
    id BIGSERIAL PRIMARY KEY,
    ts_open BIGINT NOT NULL,
    ts_close BIGINT,
    symbol TEXT NOT NULL,
    direction TEXT NOT NULL,
    entry_price DOUBLE PRECISION,
    exit_price DOUBLE PRECISION,
    qty DOUBLE PRECISION,
    score INTEGER,
    pnl DOUBLE PRECISION,
    sl_price DOUBLE PRECISION,
    tp1_price DOUBLE PRECISION,
    tp2_price DOUBLE PRECISION,
    status TEXT DEFAULT 'open'
);

-- 信号日志(旧表兼容)
CREATE TABLE IF NOT EXISTS signal_logs (
    id BIGSERIAL PRIMARY KEY,
    symbol TEXT,
    rate DOUBLE PRECISION,
    annualized DOUBLE PRECISION,
    sent_at TEXT,
    message TEXT
);

-- 用户表(auth)
CREATE TABLE IF NOT EXISTS users (
    id BIGSERIAL PRIMARY KEY,
    email TEXT UNIQUE NOT NULL,
    password_hash TEXT NOT NULL,
    role TEXT DEFAULT 'user',
    created_at TIMESTAMP DEFAULT NOW()
);

CREATE TABLE IF NOT EXISTS invite_codes (
    id BIGSERIAL PRIMARY KEY,
    code TEXT UNIQUE NOT NULL,
    used_by BIGINT REFERENCES users(id),
    created_at TIMESTAMP DEFAULT NOW()
);

-- 模拟盘交易表
CREATE TABLE IF NOT EXISTS paper_trades (
    id BIGSERIAL PRIMARY KEY,
    symbol TEXT NOT NULL,
    direction TEXT NOT NULL,
    score INT NOT NULL,
    tier TEXT NOT NULL,
    entry_price DOUBLE PRECISION NOT NULL,
    entry_ts BIGINT NOT NULL,
    exit_price DOUBLE PRECISION,
    exit_ts BIGINT,
    tp1_price DOUBLE PRECISION NOT NULL,
    tp2_price DOUBLE PRECISION NOT NULL,
    sl_price DOUBLE PRECISION NOT NULL,
    tp1_hit BOOLEAN DEFAULT FALSE,
    status TEXT DEFAULT 'active',
    pnl_r DOUBLE PRECISION DEFAULT 0,
    atr_at_entry DOUBLE PRECISION DEFAULT 0,
    created_at TIMESTAMP DEFAULT NOW()
);
"""
|
||
|
||
|
||
def ensure_partitions():
    """Create monthly agg_trades partitions for the current and next 2 months.

    Partition bounds are UTC epoch-millisecond timestamps aligned to the
    first of each month.
    """
    import datetime

    utc = datetime.timezone.utc
    now = datetime.datetime.now(tz=utc)

    # Walk real calendar months instead of adding 30-day blocks: the old
    # approximation could skip a month entirely (e.g. Jan 31 + 30d = Mar 2,
    # leaving February without a partition → inserts into that range fail).
    months = []
    year, month = now.year, now.month
    for _ in range(3):  # current month + next two
        months.append((year, month))
        month += 1
        if month > 12:
            year, month = year + 1, 1

    with get_sync_conn() as conn:
        with conn.cursor() as cur:
            for year, month in months:
                # UTC-aligned bounds.  (The old code mixed utcnow() with
                # naive .timestamp(), shifting every bound by the server's
                # local UTC offset.)
                start = datetime.datetime(year, month, 1, tzinfo=utc)
                if month == 12:
                    end = datetime.datetime(year + 1, 1, 1, tzinfo=utc)
                else:
                    end = datetime.datetime(year, month + 1, 1, tzinfo=utc)
                start_ms = int(start.timestamp() * 1000)
                end_ms = int(end.timestamp() * 1000)
                part_name = f"agg_trades_{year:04d}{month:02d}"
                try:
                    cur.execute(f"""
                        CREATE TABLE IF NOT EXISTS {part_name}
                        PARTITION OF agg_trades
                        FOR VALUES FROM ({start_ms}) TO ({end_ms})
                    """)
                except Exception:
                    # e.g. an existing partition with an overlapping range.
                    # Roll back so later statements don't run inside an
                    # aborted transaction (the old bare `pass` left the
                    # transaction failed, dooming every subsequent execute
                    # and the final commit).
                    conn.rollback()
            conn.commit()
|
||
|
||
|
||
def init_schema():
    """Initialize the full PG schema, then ensure agg_trades partitions.

    Statements are executed one at a time so that a single failure (e.g. an
    object that already exists with a different definition) does not abort
    the rest of the schema creation.
    """
    with get_sync_conn() as conn:
        with conn.cursor() as cur:
            # Naive split on ";" is safe here: SCHEMA_SQL contains no
            # string literals or function bodies with embedded semicolons.
            for stmt in SCHEMA_SQL.split(";"):
                stmt = stmt.strip()
                if not stmt:
                    continue
                try:
                    cur.execute(stmt)
                except Exception:
                    # Reset the aborted transaction and keep going;
                    # failures here are typically "already exists".
                    conn.rollback()
        conn.commit()
    ensure_partitions()
|