arbitrage-engine/backend/db.py
dev-worker 31e6e19ea6 fix: V3全面审阅修复 — 12项问题
P0-1: 风控Fail-Closed(状态文件缺失/过期/异常→拒绝开仓)
P0-2: 1R基准跨模块统一(position_sync+risk_guard从live_config动态读)
P0-3: close_all紧急全平校验返回值+二次验仓
P0-4: Coinbase Premium单位修复(premium_pct/100→比例值)
P1-3: 正向funding计入净PnL(不再只扣负值)
P1-4: 数据新鲜度检查落地(查signal_indicators最新ts)
P1-6: live表DDL补全到SCHEMA_SQL(live_config/live_events/live_trades)
P2-1: _get_risk_usd()加60秒缓存
P2-3: 模拟盘前端*200→从config动态算paper1R
P2-4: XRP/SOL跳过Coinbase Premium采集(无数据源)
P3-2: SQL参数化(fetch_pending_signals用ANY替代f-string)
额外: pnl_r公式修正(gross-fee+funding,funding正负都正确计入)
2026-03-02 17:28:23 +00:00

415 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
db.py — PostgreSQL 数据库连接层
同步连接池psycopg2供脚本类使用
异步连接池asyncpg供FastAPI使用
"""
import os
import asyncio
import asyncpg
import psycopg2
import psycopg2.pool
from contextlib import contextmanager
# ── Local PostgreSQL connection parameters (primary database) ──
PG_HOST = os.getenv("PG_HOST", "127.0.0.1")
PG_PORT = int(os.getenv("PG_PORT", 5432))
PG_DB = os.getenv("PG_DB", "arb_engine")
PG_USER = os.getenv("PG_USER", "arb")
# NOTE(review): hard-coded default credential in source — prefer requiring the
# env var (no default) outside local development.
PG_PASS = os.getenv("PG_PASS", "arb_engine_2026")
# DSN consumed by the asyncpg pool below.
PG_DSN = f"postgresql://{PG_USER}:{PG_PASS}@{PG_HOST}:{PG_PORT}/{PG_DB}"
# ── Cloud SQL connection parameters (secondary, dual-write target) ──
CLOUD_PG_HOST = os.getenv("CLOUD_PG_HOST", "10.106.0.3")
CLOUD_PG_PORT = int(os.getenv("CLOUD_PG_PORT", 5432))
CLOUD_PG_DB = os.getenv("CLOUD_PG_DB", "arb_engine")
CLOUD_PG_USER = os.getenv("CLOUD_PG_USER", "arb")
CLOUD_PG_PASS = os.getenv("CLOUD_PG_PASS", "arb_engine_2026")
# Dual-write is on unless CLOUD_PG_ENABLED is explicitly set to a non-"true" value.
CLOUD_PG_ENABLED = os.getenv("CLOUD_PG_ENABLED", "true").lower() == "true"
# ─── 同步连接池psycopg2─────────────────────────────────────
_sync_pool = None


def get_sync_pool() -> psycopg2.pool.ThreadedConnectionPool:
    """Return the process-wide psycopg2 pool, creating it on first use."""
    global _sync_pool
    if _sync_pool is not None:
        return _sync_pool
    _sync_pool = psycopg2.pool.ThreadedConnectionPool(
        minconn=1,
        maxconn=5,
        host=PG_HOST,
        port=PG_PORT,
        dbname=PG_DB,
        user=PG_USER,
        password=PG_PASS,
    )
    return _sync_pool
@contextmanager
def get_sync_conn():
    """Yield a pooled synchronous PG connection; always return it to the pool."""
    pool = get_sync_pool()
    connection = pool.getconn()
    try:
        yield connection
    finally:
        pool.putconn(connection)
def sync_execute(sql: str, params=None, fetch=False):
    """Execute one SQL statement on the local pool.

    Args:
        sql: statement text (driver-style %s placeholders with ``params``).
        params: optional parameters passed through to psycopg2.
        fetch: when True, return the result set.

    Returns:
        List of row dicts when ``fetch`` is True, otherwise None.

    Fixes vs. original: the fetch path returned before ``conn.commit()``,
    leaving the pooled connection "idle in transaction"; errors left an
    aborted transaction on the connection before it was returned to the pool.
    """
    with get_sync_conn() as conn:
        try:
            with conn.cursor() as cur:
                cur.execute(sql, params)
                rows = None
                if fetch:
                    cols = [desc[0] for desc in cur.description]
                    rows = [dict(zip(cols, row)) for row in cur.fetchall()]
            # Commit on both paths so no transaction stays open on the pooled conn.
            conn.commit()
            return rows
        except Exception:
            conn.rollback()  # never hand a broken transaction back to the pool
            raise
def sync_executemany(sql: str, params_list: list):
    """Execute ``sql`` once per parameter tuple in ``params_list``, then commit.

    Fix vs. original: roll back on failure so the pooled connection is not
    returned with an aborted transaction.
    """
    with get_sync_conn() as conn:
        try:
            with conn.cursor() as cur:
                cur.executemany(sql, params_list)
            conn.commit()
        except Exception:
            conn.rollback()  # keep the pooled connection usable after a failed batch
            raise
# ─── Cloud SQL 同步连接池(双写用)───────────────────────────────
_cloud_sync_pool = None


def get_cloud_sync_pool():
    """Return the Cloud SQL pool, or None when disabled or init fails."""
    global _cloud_sync_pool
    if not CLOUD_PG_ENABLED:
        return None
    if _cloud_sync_pool is not None:
        return _cloud_sync_pool
    try:
        _cloud_sync_pool = psycopg2.pool.ThreadedConnectionPool(
            minconn=1,
            maxconn=5,
            host=CLOUD_PG_HOST,
            port=CLOUD_PG_PORT,
            dbname=CLOUD_PG_DB,
            user=CLOUD_PG_USER,
            password=CLOUD_PG_PASS,
        )
    except Exception as e:
        import logging
        logging.getLogger("db").error(f"Cloud SQL pool init failed: {e}")
        return None
    return _cloud_sync_pool
@contextmanager
def get_cloud_sync_conn():
    """Yield a Cloud SQL connection, or None when the pool is unavailable.

    Connection-acquisition failures are logged and yield None so the caller's
    primary (local-DB) flow is never interrupted. Exceptions raised inside the
    caller's ``with`` body propagate normally.

    Bug fix: the original wrapped ``yield conn`` in a try/except and then did
    ``yield None`` in the except branch. An exception raised in the with-body
    is thrown into the generator at the yield point, so the generator yielded
    a SECOND time — @contextmanager turns that into
    RuntimeError("generator didn't stop after throw()").
    """
    pool = get_cloud_sync_pool()
    if pool is None:
        yield None
        return
    try:
        conn = pool.getconn()
    except Exception as e:
        import logging
        logging.getLogger("db").error(f"Cloud SQL conn error: {e}")
        yield None
        return
    try:
        yield conn
    finally:
        try:
            pool.putconn(conn)
        except Exception:
            pass  # best-effort return; pool may already be closed
# ─── 异步连接池asyncpg─────────────────────────────────────
_async_pool: asyncpg.Pool | None = None


async def get_async_pool() -> asyncpg.Pool:
    """Return the shared asyncpg pool, creating it on first call."""
    global _async_pool
    if _async_pool is None:
        _async_pool = await asyncpg.create_pool(dsn=PG_DSN, min_size=2, max_size=10)
    return _async_pool
async def async_fetch(sql: str, *args) -> list[dict]:
    """Run a query asynchronously; return all rows as a list of dicts."""
    pool = await get_async_pool()
    async with pool.acquire() as connection:
        records = await connection.fetch(sql, *args)
    return [dict(record) for record in records]
async def async_fetchrow(sql: str, *args) -> dict | None:
    """Run a query asynchronously; return the first row as a dict, or None."""
    pool = await get_async_pool()
    async with pool.acquire() as connection:
        record = await connection.fetchrow(sql, *args)
    if record is None:
        return None
    return dict(record)
async def async_execute(sql: str, *args):
    """Execute a statement asynchronously (no result returned)."""
    pool = await get_async_pool()
    async with pool.acquire() as connection:
        await connection.execute(sql, *args)
async def close_async_pool():
    """Close the shared asyncpg pool and reset it for later re-creation."""
    global _async_pool
    if _async_pool is not None:
        await _async_pool.close()
        _async_pool = None
# ─── 建表PG Schema──────────────────────────────────────────
SCHEMA_SQL = """
-- 费率快照
CREATE TABLE IF NOT EXISTS rate_snapshots (
id BIGSERIAL PRIMARY KEY,
ts BIGINT NOT NULL,
btc_rate DOUBLE PRECISION NOT NULL,
eth_rate DOUBLE PRECISION NOT NULL,
btc_price DOUBLE PRECISION NOT NULL,
eth_price DOUBLE PRECISION NOT NULL,
btc_index_price DOUBLE PRECISION,
eth_index_price DOUBLE PRECISION
);
CREATE INDEX IF NOT EXISTS idx_rate_snapshots_ts ON rate_snapshots(ts);
-- aggTrades元数据
CREATE TABLE IF NOT EXISTS agg_trades_meta (
symbol TEXT PRIMARY KEY,
last_agg_id BIGINT NOT NULL,
last_time_ms BIGINT,
earliest_agg_id BIGINT,
earliest_time_ms BIGINT,
updated_at TEXT
);
-- aggTrades主表分区父表
CREATE TABLE IF NOT EXISTS agg_trades (
agg_id BIGINT NOT NULL,
symbol TEXT NOT NULL,
price DOUBLE PRECISION NOT NULL,
qty DOUBLE PRECISION NOT NULL,
time_ms BIGINT NOT NULL,
is_buyer_maker SMALLINT NOT NULL,
PRIMARY KEY (time_ms, symbol, agg_id)
) PARTITION BY RANGE (time_ms);
CREATE INDEX IF NOT EXISTS idx_agg_trades_sym_time ON agg_trades(symbol, time_ms DESC);
CREATE INDEX IF NOT EXISTS idx_agg_trades_sym_agg ON agg_trades(symbol, agg_id);
-- Signal Engine表
CREATE TABLE IF NOT EXISTS signal_indicators (
id BIGSERIAL PRIMARY KEY,
ts BIGINT NOT NULL,
symbol TEXT NOT NULL,
cvd_fast DOUBLE PRECISION,
cvd_mid DOUBLE PRECISION,
cvd_day DOUBLE PRECISION,
cvd_fast_slope DOUBLE PRECISION,
atr_5m DOUBLE PRECISION,
atr_percentile DOUBLE PRECISION,
vwap_30m DOUBLE PRECISION,
price DOUBLE PRECISION,
p95_qty DOUBLE PRECISION,
p99_qty DOUBLE PRECISION,
buy_vol_1m DOUBLE PRECISION,
sell_vol_1m DOUBLE PRECISION,
score INTEGER,
signal TEXT
);
CREATE INDEX IF NOT EXISTS idx_si_ts ON signal_indicators(ts);
CREATE INDEX IF NOT EXISTS idx_si_sym_ts ON signal_indicators(symbol, ts);
CREATE TABLE IF NOT EXISTS signal_indicators_1m (
id BIGSERIAL PRIMARY KEY,
ts BIGINT NOT NULL,
symbol TEXT NOT NULL,
cvd_fast DOUBLE PRECISION,
cvd_mid DOUBLE PRECISION,
cvd_day DOUBLE PRECISION,
atr_5m DOUBLE PRECISION,
vwap_30m DOUBLE PRECISION,
price DOUBLE PRECISION,
score INTEGER,
signal TEXT
);
CREATE INDEX IF NOT EXISTS idx_si1m_sym_ts ON signal_indicators_1m(symbol, ts);
CREATE TABLE IF NOT EXISTS signal_trades (
id BIGSERIAL PRIMARY KEY,
ts_open BIGINT NOT NULL,
ts_close BIGINT,
symbol TEXT NOT NULL,
direction TEXT NOT NULL,
entry_price DOUBLE PRECISION,
exit_price DOUBLE PRECISION,
qty DOUBLE PRECISION,
score INTEGER,
pnl DOUBLE PRECISION,
sl_price DOUBLE PRECISION,
tp1_price DOUBLE PRECISION,
tp2_price DOUBLE PRECISION,
status TEXT DEFAULT 'open'
);
-- 信号日志(旧表兼容)
CREATE TABLE IF NOT EXISTS signal_logs (
id BIGSERIAL PRIMARY KEY,
symbol TEXT,
rate DOUBLE PRECISION,
annualized DOUBLE PRECISION,
sent_at TEXT,
message TEXT
);
-- 用户表auth
CREATE TABLE IF NOT EXISTS users (
id BIGSERIAL PRIMARY KEY,
email TEXT UNIQUE NOT NULL,
password_hash TEXT NOT NULL,
role TEXT DEFAULT 'user',
created_at TIMESTAMP DEFAULT NOW()
);
CREATE TABLE IF NOT EXISTS invite_codes (
id BIGSERIAL PRIMARY KEY,
code TEXT UNIQUE NOT NULL,
used_by BIGINT REFERENCES users(id),
created_at TIMESTAMP DEFAULT NOW()
);
-- 模拟盘交易表
CREATE TABLE IF NOT EXISTS paper_trades (
id BIGSERIAL PRIMARY KEY,
symbol TEXT NOT NULL,
direction TEXT NOT NULL,
score INT NOT NULL,
tier TEXT NOT NULL,
entry_price DOUBLE PRECISION NOT NULL,
entry_ts BIGINT NOT NULL,
exit_price DOUBLE PRECISION,
exit_ts BIGINT,
tp1_price DOUBLE PRECISION NOT NULL,
tp2_price DOUBLE PRECISION NOT NULL,
sl_price DOUBLE PRECISION NOT NULL,
tp1_hit BOOLEAN DEFAULT FALSE,
status TEXT DEFAULT 'active',
pnl_r DOUBLE PRECISION DEFAULT 0,
atr_at_entry DOUBLE PRECISION DEFAULT 0,
score_factors JSONB,
created_at TIMESTAMP DEFAULT NOW()
);
-- Live trading tables
CREATE TABLE IF NOT EXISTS live_config (
key TEXT PRIMARY KEY,
value TEXT NOT NULL,
label TEXT,
updated_at TIMESTAMP DEFAULT NOW()
);
CREATE TABLE IF NOT EXISTS live_events (
id BIGSERIAL PRIMARY KEY,
ts BIGINT DEFAULT (EXTRACT(EPOCH FROM NOW()) * 1000)::BIGINT,
level TEXT,
category TEXT,
symbol TEXT,
message TEXT,
detail JSONB
);
CREATE TABLE IF NOT EXISTS live_trades (
id BIGSERIAL PRIMARY KEY,
symbol TEXT NOT NULL,
strategy TEXT NOT NULL,
direction TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'active',
entry_price DOUBLE PRECISION,
exit_price DOUBLE PRECISION,
entry_ts BIGINT,
exit_ts BIGINT,
sl_price DOUBLE PRECISION,
tp1_price DOUBLE PRECISION,
tp2_price DOUBLE PRECISION,
tp1_hit BOOLEAN DEFAULT FALSE,
score DOUBLE PRECISION,
tier TEXT,
pnl_r DOUBLE PRECISION,
fee_usdt DOUBLE PRECISION DEFAULT 0,
funding_fee_usdt DOUBLE PRECISION DEFAULT 0,
risk_distance DOUBLE PRECISION,
atr_at_entry DOUBLE PRECISION,
score_factors JSONB,
signal_id BIGINT,
binance_order_id TEXT,
fill_price DOUBLE PRECISION,
slippage_bps DOUBLE PRECISION,
protection_gap_ms BIGINT,
signal_to_order_ms BIGINT,
order_to_fill_ms BIGINT,
qty DOUBLE PRECISION,
created_at TIMESTAMP DEFAULT NOW()
);
"""
def ensure_partitions():
    """Create monthly partitions of agg_trades for the current and next two months.

    Partition bounds are UTC millisecond timestamps, matching the
    ``PARTITION BY RANGE (time_ms)`` declaration in SCHEMA_SQL.

    Fixes vs. original: uses timezone-aware now() instead of the deprecated
    ``datetime.utcnow()``; commits/rolls back per partition so one failure no
    longer aborts the transaction and silently skips the remaining partitions;
    drops the string round-trip through "YYYYMM" keys.
    """
    import datetime

    now = datetime.datetime.now(datetime.timezone.utc)
    months = []
    for delta in range(3):  # current month + next two
        m = now.month + delta
        y = now.year + (m - 1) // 12
        m = ((m - 1) % 12) + 1
        months.append((y, m))
    with get_sync_conn() as conn:
        with conn.cursor() as cur:
            for year, month in months:
                # Partition range [first of month, first of next month) in UTC ms.
                start = datetime.datetime(year, month, 1, tzinfo=datetime.timezone.utc)
                if month == 12:
                    end = datetime.datetime(year + 1, 1, 1, tzinfo=datetime.timezone.utc)
                else:
                    end = datetime.datetime(year, month + 1, 1, tzinfo=datetime.timezone.utc)
                start_ms = int(start.timestamp() * 1000)
                end_ms = int(end.timestamp() * 1000)
                part_name = f"agg_trades_{year}{month:02d}"
                try:
                    # part_name/bounds are derived locally (not user input), so
                    # f-string DDL is safe here.
                    cur.execute(f"""
                    CREATE TABLE IF NOT EXISTS {part_name}
                    PARTITION OF agg_trades
                    FOR VALUES FROM ({start_ms}) TO ({end_ms})
                    """)
                    conn.commit()
                except Exception:
                    # Clear the aborted transaction so later partitions can
                    # still be created (the original's bare `pass` left the tx
                    # aborted and doomed every following execute).
                    conn.rollback()
def init_schema():
    """Initialise all PG tables/indexes from SCHEMA_SQL, then ensure partitions.

    Statements are executed one by one (split on ';') so that a failure in one
    — typically "already exists" — does not block the rest. NOTE: this naive
    split would break any statement containing a literal ';' in its body;
    SCHEMA_SQL currently contains none.

    Fixes vs. original: commits after each successful statement — previously a
    single failure triggered ``conn.rollback()`` which discarded ALL earlier
    uncommitted DDL in the loop; failures are now logged instead of silently
    swallowed.
    """
    import logging
    log = logging.getLogger("db")
    with get_sync_conn() as conn:
        with conn.cursor() as cur:
            for stmt in SCHEMA_SQL.split(";"):
                stmt = stmt.strip()
                if not stmt:
                    continue
                try:
                    cur.execute(stmt)
                    conn.commit()  # per-statement commit: one failure can't undo the rest
                except Exception as e:
                    conn.rollback()  # clear the aborted transaction
                    log.debug("init_schema skipped statement: %s", e)
            # Backfill column added after the table's original DDL shipped.
            cur.execute(
                "ALTER TABLE paper_trades "
                "ADD COLUMN IF NOT EXISTS strategy VARCHAR(32) DEFAULT 'v51_baseline'"
            )
            conn.commit()
    ensure_partitions()