refactor: unify all DB connections to Cloud SQL, remove dual-write and SQLite code

This commit is contained in:
root 2026-03-03 12:30:01 +00:00
parent c0c37c4c7e
commit 91ed44ad9f
8 changed files with 62 additions and 87 deletions

View File

@ -7,6 +7,7 @@ agg_trades_collector.py — aggTrades全量采集守护进程PostgreSQL版
- 每分钟巡检,校验agg_id连续性,发现断档自动补洞
- 批量写入(攒200条或1秒flush一次)
- PG分区表(按月自动分区),MVCC并发无锁冲突
- 统一写入 Cloud SQL(双写机制已移除)
"""
import asyncio
@ -22,7 +23,7 @@ import psycopg2
import psycopg2.extras
import websockets
from db import get_sync_conn, get_sync_pool, get_cloud_sync_conn, ensure_partitions, PG_HOST, PG_PORT, PG_DB, PG_USER, PG_PASS, CLOUD_PG_ENABLED
from db import get_sync_conn, get_sync_pool, ensure_partitions, PG_HOST
logging.basicConfig(
level=logging.INFO,
@ -69,11 +70,10 @@ def update_meta(conn, symbol: str, last_agg_id: int, last_time_ms: int):
def flush_buffer(symbol: str, trades: list) -> int:
"""写入一批trades到PG(本地+Cloud SQL双写),返回实际写入条数"""
"""写入一批trades到Cloud SQL,返回实际写入条数"""
if not trades:
return 0
try:
# 确保分区存在
ensure_partitions()
values = []
@ -98,7 +98,6 @@ def flush_buffer(symbol: str, trades: list) -> int:
ON CONFLICT (time_ms, symbol, agg_id) DO NOTHING"""
insert_template = "(%s, %s, %s, %s, %s, %s)"
# 写本地PG
inserted = 0
with get_sync_conn() as conn:
with conn.cursor() as cur:
@ -111,22 +110,6 @@ def flush_buffer(symbol: str, trades: list) -> int:
update_meta(conn, symbol, last_agg_id, last_time_ms)
conn.commit()
# 双写Cloud SQL(失败不影响主流程)
if CLOUD_PG_ENABLED:
try:
with get_cloud_sync_conn() as cloud_conn:
if cloud_conn:
with cloud_conn.cursor() as cur:
psycopg2.extras.execute_values(
cur, insert_sql, values,
template=insert_template, page_size=1000,
)
if last_agg_id > 0:
update_meta(cloud_conn, symbol, last_agg_id, last_time_ms)
cloud_conn.commit()
except Exception as e:
logger.warning(f"[{symbol}] Cloud SQL write failed (non-fatal): {e}")
return inserted
except Exception as e:
logger.error(f"flush_buffer [{symbol}] error: {e}")

View File

@ -34,7 +34,7 @@ logging.basicConfig(
)
logger = logging.getLogger("backtest")
PG_HOST = os.getenv("PG_HOST", "127.0.0.1")
PG_HOST = os.getenv("PG_HOST", "10.106.0.3")
PG_PORT = int(os.getenv("PG_PORT", "5432"))
PG_DB = os.getenv("PG_DB", "arb_engine")
PG_USER = os.getenv("PG_USER", "arb")

View File

@ -1,5 +1,6 @@
"""
db.py PostgreSQL 数据库连接层
统一连接到 Cloud SQL(PG_HOST 默认 10.106.0.3)
同步连接池psycopg2供脚本类使用
异步连接池asyncpg供FastAPI使用
"""
@ -11,8 +12,8 @@ import psycopg2
import psycopg2.pool
from contextlib import contextmanager
# PG连接参数(本地)
PG_HOST = os.getenv("PG_HOST", "127.0.0.1")
# PG连接参数(统一连接 Cloud SQL)
PG_HOST = os.getenv("PG_HOST", "10.106.0.3")
PG_PORT = int(os.getenv("PG_PORT", 5432))
PG_DB = os.getenv("PG_DB", "arb_engine")
PG_USER = os.getenv("PG_USER", "arb")
@ -20,14 +21,6 @@ PG_PASS = os.getenv("PG_PASS", "arb_engine_2026")
PG_DSN = f"postgresql://{PG_USER}:{PG_PASS}@{PG_HOST}:{PG_PORT}/{PG_DB}"
# Cloud SQL连接参数(双写目标)
CLOUD_PG_HOST = os.getenv("CLOUD_PG_HOST", "10.106.0.3")
CLOUD_PG_PORT = int(os.getenv("CLOUD_PG_PORT", 5432))
CLOUD_PG_DB = os.getenv("CLOUD_PG_DB", "arb_engine")
CLOUD_PG_USER = os.getenv("CLOUD_PG_USER", "arb")
CLOUD_PG_PASS = os.getenv("CLOUD_PG_PASS", "arb_engine_2026")
CLOUD_PG_ENABLED = os.getenv("CLOUD_PG_ENABLED", "true").lower() == "true"
# ─── 同步连接池(psycopg2)─────────────────────────────────────
_sync_pool = None
@ -73,51 +66,6 @@ def sync_executemany(sql: str, params_list: list):
conn.commit()
# ─── Cloud SQL 同步连接池(双写用)───────────────────────────────
_cloud_sync_pool = None
def get_cloud_sync_pool():
global _cloud_sync_pool
if not CLOUD_PG_ENABLED:
return None
if _cloud_sync_pool is None:
try:
_cloud_sync_pool = psycopg2.pool.ThreadedConnectionPool(
minconn=1, maxconn=5,
host=CLOUD_PG_HOST, port=CLOUD_PG_PORT,
dbname=CLOUD_PG_DB, user=CLOUD_PG_USER, password=CLOUD_PG_PASS,
)
except Exception as e:
import logging
logging.getLogger("db").error(f"Cloud SQL pool init failed: {e}")
return None
return _cloud_sync_pool
@contextmanager
def get_cloud_sync_conn():
"""获取Cloud SQL同步连接,失败返回None,不影响主流程"""
pool = get_cloud_sync_pool()
if pool is None:
yield None
return
conn = None
try:
conn = pool.getconn()
yield conn
except Exception as e:
import logging
logging.getLogger("db").error(f"Cloud SQL conn error: {e}")
yield None
finally:
if conn and pool:
try:
pool.putconn(conn)
except Exception:
pass
# ─── 异步连接池(asyncpg)─────────────────────────────────────
_async_pool: asyncpg.Pool | None = None
@ -206,6 +154,7 @@ CREATE TABLE IF NOT EXISTS signal_indicators (
id BIGSERIAL PRIMARY KEY,
ts BIGINT NOT NULL,
symbol TEXT NOT NULL,
strategy TEXT,
cvd_fast DOUBLE PRECISION,
cvd_mid DOUBLE PRECISION,
cvd_day DOUBLE PRECISION,
@ -219,10 +168,12 @@ CREATE TABLE IF NOT EXISTS signal_indicators (
buy_vol_1m DOUBLE PRECISION,
sell_vol_1m DOUBLE PRECISION,
score INTEGER,
signal TEXT
signal TEXT,
factors JSONB
);
CREATE INDEX IF NOT EXISTS idx_si_ts ON signal_indicators(ts);
CREATE INDEX IF NOT EXISTS idx_si_sym_ts ON signal_indicators(symbol, ts);
CREATE INDEX IF NOT EXISTS idx_si_strategy ON signal_indicators(strategy, ts);
CREATE TABLE IF NOT EXISTS signal_indicators_1m (
id BIGSERIAL PRIMARY KEY,
@ -256,7 +207,7 @@ CREATE TABLE IF NOT EXISTS signal_trades (
status TEXT DEFAULT 'open'
);
-- 信号日志(旧表兼容)
-- 信号日志(旧表兼容,保留,不再写入新数据)
CREATE TABLE IF NOT EXISTS signal_logs (
id BIGSERIAL PRIMARY KEY,
symbol TEXT,
@ -266,7 +217,32 @@ CREATE TABLE IF NOT EXISTS signal_logs (
message TEXT
);
-- 用户表auth
-- 市场指标 market_data_collector 写入
CREATE TABLE IF NOT EXISTS market_indicators (
id BIGSERIAL PRIMARY KEY,
ts BIGINT NOT NULL,
symbol TEXT NOT NULL,
indicator_type TEXT NOT NULL,
value DOUBLE PRECISION,
extra JSONB,
created_at TIMESTAMP DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_mi_sym_type_ts ON market_indicators(symbol, indicator_type, ts DESC);
-- 清算数据 liquidation_collector 写入
CREATE TABLE IF NOT EXISTS liquidations (
id BIGSERIAL PRIMARY KEY,
ts BIGINT NOT NULL,
symbol TEXT NOT NULL,
side TEXT NOT NULL,
qty DOUBLE PRECISION,
price DOUBLE PRECISION,
usd_value DOUBLE PRECISION,
created_at TIMESTAMP DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_liq_sym_ts ON liquidations(symbol, ts DESC);
-- 用户表(auth.py 负责完整定义,此处仅兜底)
CREATE TABLE IF NOT EXISTS users (
id BIGSERIAL PRIMARY KEY,
email TEXT UNIQUE NOT NULL,
@ -282,6 +258,13 @@ CREATE TABLE IF NOT EXISTS invite_codes (
created_at TIMESTAMP DEFAULT NOW()
);
CREATE TABLE IF NOT EXISTS invite_usage (
id BIGSERIAL PRIMARY KEY,
code TEXT NOT NULL,
used_by BIGINT REFERENCES users(id),
used_at TIMESTAMP DEFAULT NOW()
);
-- 模拟盘交易表
CREATE TABLE IF NOT EXISTS paper_trades (
id BIGSERIAL PRIMARY KEY,
@ -300,6 +283,7 @@ CREATE TABLE IF NOT EXISTS paper_trades (
status TEXT DEFAULT 'active',
pnl_r DOUBLE PRECISION DEFAULT 0,
atr_at_entry DOUBLE PRECISION DEFAULT 0,
risk_distance DOUBLE PRECISION,
score_factors JSONB,
created_at TIMESTAMP DEFAULT NOW()
);
@ -373,7 +357,6 @@ def ensure_partitions():
for m in set(months):
year = int(m[:4])
month = int(m[4:])
# 计算分区范围(UTC毫秒时间戳)
start = datetime.datetime(year, month, 1, tzinfo=datetime.timezone.utc)
if month == 12:
end = datetime.datetime(year + 1, 1, 1, tzinfo=datetime.timezone.utc)
@ -402,13 +385,22 @@ def init_schema():
if stmt:
try:
cur.execute(stmt)
except Exception as e:
except Exception:
conn.rollback()
# 忽略已存在错误
continue
cur.execute(
"ALTER TABLE paper_trades "
"ADD COLUMN IF NOT EXISTS strategy VARCHAR(32) DEFAULT 'v51_baseline'"
)
# 补全字段(向前兼容旧部署)
migrations = [
"ALTER TABLE paper_trades ADD COLUMN IF NOT EXISTS strategy VARCHAR(32) DEFAULT 'v51_baseline'",
"ALTER TABLE paper_trades ADD COLUMN IF NOT EXISTS risk_distance DOUBLE PRECISION",
"ALTER TABLE signal_indicators ADD COLUMN IF NOT EXISTS strategy TEXT",
"ALTER TABLE signal_indicators ADD COLUMN IF NOT EXISTS factors JSONB",
"ALTER TABLE users ADD COLUMN IF NOT EXISTS discord_id TEXT",
"ALTER TABLE users ADD COLUMN IF NOT EXISTS banned BOOLEAN DEFAULT FALSE",
]
for m in migrations:
try:
cur.execute(m)
except Exception:
conn.rollback()
conn.commit()
ensure_partitions()

View File

@ -12,7 +12,7 @@ from psycopg2.extras import Json
SYMBOLS = ["BTCUSDT", "ETHUSDT", "XRPUSDT", "SOLUSDT"]
INTERVAL_SECONDS = 300
PG_HOST = os.getenv("PG_HOST", "127.0.0.1")
PG_HOST = os.getenv("PG_HOST", "10.106.0.3")
PG_PORT = int(os.getenv("PG_PORT", "5432"))
PG_DB = os.getenv("PG_DB", "arb_engine")
PG_USER = os.getenv("PG_USER", "arb")