arbitrage-engine/backend/db.py

304 lines
8.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
db.py — PostgreSQL 数据库连接层
同步连接池psycopg2供脚本类使用
异步连接池asyncpg供FastAPI使用
"""
import asyncio
import logging
import os
from contextlib import contextmanager
from urllib.parse import quote

import asyncpg
import psycopg2
import psycopg2.pool
# ─── PostgreSQL connection parameters (overridable via environment) ───────
PG_HOST = os.getenv("PG_HOST", "127.0.0.1")
PG_PORT = int(os.getenv("PG_PORT", 5432))
PG_DB = os.getenv("PG_DB", "arb_engine")
PG_USER = os.getenv("PG_USER", "arb")
# NOTE(review): a real-looking default password is committed to source;
# consider requiring PG_PASS to be set explicitly outside development.
PG_PASS = os.getenv("PG_PASS", "arb_engine_2026")
# URI-form DSN passed to asyncpg.create_pool. User, password and database
# name are percent-encoded so that characters such as '@', ':', '/' or '%'
# in credentials cannot break the URI structure (connection URIs allow
# percent-encoded components). Plain alphanumeric values are unchanged.
PG_DSN = (
    f"postgresql://{quote(PG_USER, safe='')}:{quote(PG_PASS, safe='')}"
    f"@{PG_HOST}:{PG_PORT}/{quote(PG_DB, safe='')}"
)
# ─── Synchronous pool (psycopg2) ─────────────────────────────────────────
_sync_pool = None


def get_sync_pool() -> psycopg2.pool.ThreadedConnectionPool:
    """Return the module-wide psycopg2 pool, building it on the first call.

    The pool is configured from the PG_* settings with ``minconn=1`` and
    ``maxconn=5``; later calls return the same object.
    """
    global _sync_pool
    # NOTE(review): this lazy initialisation is not guarded by a lock, so two
    # threads making the very first call at the same time could each build
    # a pool.
    if _sync_pool is not None:
        return _sync_pool
    _sync_pool = psycopg2.pool.ThreadedConnectionPool(
        minconn=1,
        maxconn=5,
        host=PG_HOST,
        port=PG_PORT,
        dbname=PG_DB,
        user=PG_USER,
        password=PG_PASS,
    )
    return _sync_pool
@contextmanager
def get_sync_conn():
    """Lend out a psycopg2 connection from the sync pool for a ``with`` block.

    The connection is handed back to the pool in a ``finally`` clause, so it
    is released even when the caller's block raises. This helper neither
    commits nor rolls back; transaction handling is up to the caller.
    """
    sync_pool = get_sync_pool()
    borrowed = sync_pool.getconn()
    try:
        yield borrowed
    finally:
        sync_pool.putconn(borrowed)
def sync_execute(sql: str, params=None, fetch=False):
    """Execute one SQL statement on a pooled connection and commit it.

    Args:
        sql: SQL text executed through a psycopg2 cursor.
        params: Parameters bound by ``cursor.execute``, or None.
        fetch: When true, return the statement's result rows.

    Returns:
        With ``fetch=True``: a list of ``{column_name: value}`` dicts, or an
        empty list when the statement produced no result set.
        Otherwise: None.
    """
    with get_sync_conn() as conn:
        with conn.cursor() as cur:
            cur.execute(sql, params)
            rows = None
            if fetch:
                if cur.description is None:
                    # Statement returned no result set (e.g. a plain UPDATE);
                    # iterating None would raise TypeError.
                    rows = []
                else:
                    cols = [desc[0] for desc in cur.description]
                    rows = [dict(zip(cols, row)) for row in cur.fetchall()]
        # Commit on the fetch path too. Returning before commit left e.g.
        # "INSERT ... RETURNING" in an open, uncommitted transaction.
        conn.commit()
        return rows
def sync_executemany(sql: str, params_list: list):
    """Run ``sql`` once for each parameter set in ``params_list``.

    Every execution uses the same pooled connection, and the work is
    committed once, after all parameter sets have been executed.
    """
    with get_sync_conn() as connection:
        cursor = connection.cursor()
        try:
            cursor.executemany(sql, params_list)
        finally:
            cursor.close()
        connection.commit()
# ─── Async pool (asyncpg), used by FastAPI ───────────────────────────────
_async_pool: asyncpg.Pool | None = None
# Serialises first-time pool creation. create_pool() is awaited, so without
# this lock two coroutines arriving together could each build a pool and one
# of them would be leaked.
_async_pool_lock = asyncio.Lock()


async def get_async_pool() -> asyncpg.Pool:
    """Return the shared asyncpg pool, creating it lazily on first use.

    The pool connects with ``PG_DSN`` and holds between 2 and 10 connections.
    Concurrent first callers all receive the same pool.
    """
    global _async_pool
    if _async_pool is None:
        async with _async_pool_lock:
            # Re-check: another coroutine may have created the pool while
            # this one was waiting for the lock.
            if _async_pool is None:
                _async_pool = await asyncpg.create_pool(
                    dsn=PG_DSN, min_size=2, max_size=10,
                )
    return _async_pool
async def async_fetch(sql: str, *args) -> list[dict]:
    """Run a query on the async pool and return all rows as plain dicts.

    ``args`` are passed positionally to asyncpg's ``fetch``.
    """
    pool = await get_async_pool()
    async with pool.acquire() as connection:
        records = await connection.fetch(sql, *args)
    return list(map(dict, records))
async def async_fetchrow(sql: str, *args) -> dict | None:
    """Run a query on the async pool and return its first row as a dict.

    ``args`` are passed positionally to asyncpg's ``fetchrow``.
    Returns None only when the query produced no row.
    """
    pool = await get_async_pool()
    async with pool.acquire() as conn:
        row = await conn.fetchrow(sql, *args)
    # Compare with None explicitly. asyncpg Records support len(), so a
    # zero-column row is falsy and "if row" would misreport it as "no row".
    return None if row is None else dict(row)
async def async_execute(sql: str, *args) -> str:
    """Execute a statement on the async pool and return its status string.

    The value is whatever asyncpg's ``Connection.execute`` returns: the
    command status tag, e.g. ``"UPDATE 3"``. Exposing it lets callers see
    affected-row counts; callers that ignore the return value are unaffected.
    """
    pool = await get_async_pool()
    async with pool.acquire() as conn:
        return await conn.execute(sql, *args)
async def close_async_pool():
    """Close the shared asyncpg pool, if one exists, and forget it.

    A later ``get_async_pool()`` call will then create a new pool. Calling
    this when no pool exists does nothing.
    """
    global _async_pool
    if not _async_pool:
        return
    await _async_pool.close()
    _async_pool = None
# ─── Table definitions (PostgreSQL schema) ───────────────────────────────
# DDL run by init_schema(). init_schema() splits this text on ";" and
# executes each piece separately, so a ";" must never appear inside a
# comment or literal here, only as a statement terminator.
# Tables (translated from the SQL comments):
#   rate_snapshots        - funding-rate snapshots
#   agg_trades_meta       - aggTrades metadata, one row per symbol
#   agg_trades            - aggTrades main table; a RANGE-partitioned parent on
#                           time_ms whose partitions are created by
#                           ensure_partitions(), not here
#   signal_indicators*, signal_trades - Signal Engine tables
#   signal_logs           - signal log (legacy table, kept for compatibility)
#   users, invite_codes   - auth tables
#   paper_trades          - paper-trading trades
SCHEMA_SQL = """
-- 费率快照
CREATE TABLE IF NOT EXISTS rate_snapshots (
id BIGSERIAL PRIMARY KEY,
ts BIGINT NOT NULL,
btc_rate DOUBLE PRECISION NOT NULL,
eth_rate DOUBLE PRECISION NOT NULL,
btc_price DOUBLE PRECISION NOT NULL,
eth_price DOUBLE PRECISION NOT NULL,
btc_index_price DOUBLE PRECISION,
eth_index_price DOUBLE PRECISION
);
CREATE INDEX IF NOT EXISTS idx_rate_snapshots_ts ON rate_snapshots(ts);
-- aggTrades元数据
CREATE TABLE IF NOT EXISTS agg_trades_meta (
symbol TEXT PRIMARY KEY,
last_agg_id BIGINT NOT NULL,
last_time_ms BIGINT,
earliest_agg_id BIGINT,
earliest_time_ms BIGINT,
updated_at TEXT
);
-- aggTrades主表分区父表
CREATE TABLE IF NOT EXISTS agg_trades (
agg_id BIGINT NOT NULL,
symbol TEXT NOT NULL,
price DOUBLE PRECISION NOT NULL,
qty DOUBLE PRECISION NOT NULL,
time_ms BIGINT NOT NULL,
is_buyer_maker SMALLINT NOT NULL,
PRIMARY KEY (time_ms, symbol, agg_id)
) PARTITION BY RANGE (time_ms);
CREATE INDEX IF NOT EXISTS idx_agg_trades_sym_time ON agg_trades(symbol, time_ms DESC);
CREATE INDEX IF NOT EXISTS idx_agg_trades_sym_agg ON agg_trades(symbol, agg_id);
-- Signal Engine表
CREATE TABLE IF NOT EXISTS signal_indicators (
id BIGSERIAL PRIMARY KEY,
ts BIGINT NOT NULL,
symbol TEXT NOT NULL,
cvd_fast DOUBLE PRECISION,
cvd_mid DOUBLE PRECISION,
cvd_day DOUBLE PRECISION,
cvd_fast_slope DOUBLE PRECISION,
atr_5m DOUBLE PRECISION,
atr_percentile DOUBLE PRECISION,
vwap_30m DOUBLE PRECISION,
price DOUBLE PRECISION,
p95_qty DOUBLE PRECISION,
p99_qty DOUBLE PRECISION,
buy_vol_1m DOUBLE PRECISION,
sell_vol_1m DOUBLE PRECISION,
score INTEGER,
signal TEXT
);
CREATE INDEX IF NOT EXISTS idx_si_ts ON signal_indicators(ts);
CREATE INDEX IF NOT EXISTS idx_si_sym_ts ON signal_indicators(symbol, ts);
CREATE TABLE IF NOT EXISTS signal_indicators_1m (
id BIGSERIAL PRIMARY KEY,
ts BIGINT NOT NULL,
symbol TEXT NOT NULL,
cvd_fast DOUBLE PRECISION,
cvd_mid DOUBLE PRECISION,
cvd_day DOUBLE PRECISION,
atr_5m DOUBLE PRECISION,
vwap_30m DOUBLE PRECISION,
price DOUBLE PRECISION,
score INTEGER,
signal TEXT
);
CREATE INDEX IF NOT EXISTS idx_si1m_sym_ts ON signal_indicators_1m(symbol, ts);
CREATE TABLE IF NOT EXISTS signal_trades (
id BIGSERIAL PRIMARY KEY,
ts_open BIGINT NOT NULL,
ts_close BIGINT,
symbol TEXT NOT NULL,
direction TEXT NOT NULL,
entry_price DOUBLE PRECISION,
exit_price DOUBLE PRECISION,
qty DOUBLE PRECISION,
score INTEGER,
pnl DOUBLE PRECISION,
sl_price DOUBLE PRECISION,
tp1_price DOUBLE PRECISION,
tp2_price DOUBLE PRECISION,
status TEXT DEFAULT 'open'
);
-- 信号日志(旧表兼容)
CREATE TABLE IF NOT EXISTS signal_logs (
id BIGSERIAL PRIMARY KEY,
symbol TEXT,
rate DOUBLE PRECISION,
annualized DOUBLE PRECISION,
sent_at TEXT,
message TEXT
);
-- 用户表auth
CREATE TABLE IF NOT EXISTS users (
id BIGSERIAL PRIMARY KEY,
email TEXT UNIQUE NOT NULL,
password_hash TEXT NOT NULL,
role TEXT DEFAULT 'user',
created_at TIMESTAMP DEFAULT NOW()
);
CREATE TABLE IF NOT EXISTS invite_codes (
id BIGSERIAL PRIMARY KEY,
code TEXT UNIQUE NOT NULL,
used_by BIGINT REFERENCES users(id),
created_at TIMESTAMP DEFAULT NOW()
);
-- 模拟盘交易表
CREATE TABLE IF NOT EXISTS paper_trades (
id BIGSERIAL PRIMARY KEY,
symbol TEXT NOT NULL,
direction TEXT NOT NULL,
score INT NOT NULL,
tier TEXT NOT NULL,
entry_price DOUBLE PRECISION NOT NULL,
entry_ts BIGINT NOT NULL,
exit_price DOUBLE PRECISION,
exit_ts BIGINT,
tp1_price DOUBLE PRECISION NOT NULL,
tp2_price DOUBLE PRECISION NOT NULL,
sl_price DOUBLE PRECISION NOT NULL,
tp1_hit BOOLEAN DEFAULT FALSE,
status TEXT DEFAULT 'active',
pnl_r DOUBLE PRECISION DEFAULT 0,
atr_at_entry DOUBLE PRECISION DEFAULT 0,
created_at TIMESTAMP DEFAULT NOW()
);
"""
def ensure_partitions(months_ahead: int = 2):
    """Create monthly ``agg_trades`` partitions for now and the near future.

    Ensures a partition ``agg_trades_YYYYMM`` exists for the current UTC
    calendar month and for each of the ``months_ahead`` following months.
    The default of 2 gives three partitions, matching the previous intent of
    "current month + next 2". Each partition covers
    ``[start of month, start of next month)`` as millisecond timestamps on
    ``agg_trades.time_ms``.

    Each CREATE runs in its own savepoint. If one fails (e.g. its range
    overlaps an existing partition), the failure is logged and rolled back
    without aborting the remaining statements. Everything that succeeded is
    committed at the end.

    Args:
        months_ahead: Number of calendar months after the current one to
            pre-create. Negative values are treated as 0.
    """
    import datetime

    log = logging.getLogger(__name__)

    def month_start(month_index: int) -> datetime.datetime:
        # month_index = year * 12 + zero-based month.
        year, month0 = divmod(month_index, 12)
        return datetime.datetime(year, month0 + 1, 1)

    now = datetime.datetime.now(datetime.timezone.utc)
    first = now.year * 12 + (now.month - 1)
    # Step by whole calendar months. The old "now + 30 * delta days" approach
    # could skip a month. From Jan 1 the dates were Jan 1, Jan 31 and Mar 2,
    # so February was never created.
    month_indices = range(first, first + max(months_ahead, 0) + 1)

    with get_sync_conn() as conn:
        with conn.cursor() as cur:
            for idx in month_indices:
                start = month_start(idx)
                end = month_start(idx + 1)
                # NOTE(review): start/end are naive datetimes, so .timestamp()
                # interprets them in the host's local time zone. Boundaries are
                # at UTC midnight only when the host runs in UTC. Left as-is so
                # new partitions stay contiguous with already-created ones.
                start_ms = int(start.timestamp() * 1000)
                end_ms = int(end.timestamp() * 1000)
                part_name = f"agg_trades_{start:%Y%m}"
                cur.execute("SAVEPOINT ensure_partition")
                try:
                    cur.execute(f"""
                        CREATE TABLE IF NOT EXISTS {part_name}
                        PARTITION OF agg_trades
                        FOR VALUES FROM ({start_ms}) TO ({end_ms})
                    """)
                except Exception:
                    # Best-effort as before, but a failure no longer poisons
                    # the transaction for the remaining months.
                    cur.execute("ROLLBACK TO SAVEPOINT ensure_partition")
                    log.warning(
                        "could not create partition %s", part_name,
                        exc_info=True,
                    )
                else:
                    cur.execute("RELEASE SAVEPOINT ensure_partition")
        conn.commit()
def init_schema():
    """Create every table and index in SCHEMA_SQL, then the agg_trades partitions.

    SCHEMA_SQL is split on ";" and each non-empty statement runs inside its
    own savepoint. A failing statement is logged and rolled back to that
    savepoint only. Statements that already succeeded are therefore kept; the
    previous ``conn.rollback()`` discarded the whole uncommitted transaction,
    i.e. every earlier statement as well.

    All successful statements are committed together, and then
    ``ensure_partitions()`` is called.
    """
    log = logging.getLogger(__name__)
    statements = [piece.strip() for piece in SCHEMA_SQL.split(";")]
    with get_sync_conn() as conn:
        with conn.cursor() as cur:
            for stmt in statements:
                if not stmt:
                    continue
                cur.execute("SAVEPOINT init_schema_stmt")
                try:
                    cur.execute(stmt)
                except Exception:
                    # Still best-effort (e.g. objects that already exist),
                    # but logged instead of silently ignored.
                    cur.execute("ROLLBACK TO SAVEPOINT init_schema_stmt")
                    log.warning(
                        "schema statement failed, skipped: %.80s", stmt,
                        exc_info=True,
                    )
                else:
                    cur.execute("RELEASE SAVEPOINT init_schema_stmt")
        conn.commit()
    ensure_partitions()