refactor: unify all DB connections to Cloud SQL, remove dual-write and SQLite code
This commit is contained in:
parent
c0c37c4c7e
commit
91ed44ad9f
@ -7,6 +7,7 @@ agg_trades_collector.py — aggTrades全量采集守护进程(PostgreSQL版)
|
|||||||
- 每分钟巡检:校验agg_id连续性,发现断档自动补洞
|
- 每分钟巡检:校验agg_id连续性,发现断档自动补洞
|
||||||
- 批量写入:攒200条或1秒flush一次
|
- 批量写入:攒200条或1秒flush一次
|
||||||
- PG分区表:按月自动分区,MVCC并发无锁冲突
|
- PG分区表:按月自动分区,MVCC并发无锁冲突
|
||||||
|
- 统一写入 Cloud SQL(双写机制已移除)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
@ -22,7 +23,7 @@ import psycopg2
|
|||||||
import psycopg2.extras
|
import psycopg2.extras
|
||||||
import websockets
|
import websockets
|
||||||
|
|
||||||
from db import get_sync_conn, get_sync_pool, get_cloud_sync_conn, ensure_partitions, PG_HOST, PG_PORT, PG_DB, PG_USER, PG_PASS, CLOUD_PG_ENABLED
|
from db import get_sync_conn, get_sync_pool, ensure_partitions, PG_HOST
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO,
|
level=logging.INFO,
|
||||||
@ -69,11 +70,10 @@ def update_meta(conn, symbol: str, last_agg_id: int, last_time_ms: int):
|
|||||||
|
|
||||||
|
|
||||||
def flush_buffer(symbol: str, trades: list) -> int:
|
def flush_buffer(symbol: str, trades: list) -> int:
|
||||||
"""写入一批trades到PG(本地+Cloud SQL双写),返回实际写入条数"""
|
"""写入一批trades到Cloud SQL,返回实际写入条数"""
|
||||||
if not trades:
|
if not trades:
|
||||||
return 0
|
return 0
|
||||||
try:
|
try:
|
||||||
# 确保分区存在
|
|
||||||
ensure_partitions()
|
ensure_partitions()
|
||||||
|
|
||||||
values = []
|
values = []
|
||||||
@ -98,7 +98,6 @@ def flush_buffer(symbol: str, trades: list) -> int:
|
|||||||
ON CONFLICT (time_ms, symbol, agg_id) DO NOTHING"""
|
ON CONFLICT (time_ms, symbol, agg_id) DO NOTHING"""
|
||||||
insert_template = "(%s, %s, %s, %s, %s, %s)"
|
insert_template = "(%s, %s, %s, %s, %s, %s)"
|
||||||
|
|
||||||
# 写本地PG
|
|
||||||
inserted = 0
|
inserted = 0
|
||||||
with get_sync_conn() as conn:
|
with get_sync_conn() as conn:
|
||||||
with conn.cursor() as cur:
|
with conn.cursor() as cur:
|
||||||
@ -111,22 +110,6 @@ def flush_buffer(symbol: str, trades: list) -> int:
|
|||||||
update_meta(conn, symbol, last_agg_id, last_time_ms)
|
update_meta(conn, symbol, last_agg_id, last_time_ms)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
# 双写Cloud SQL(失败不影响主流程)
|
|
||||||
if CLOUD_PG_ENABLED:
|
|
||||||
try:
|
|
||||||
with get_cloud_sync_conn() as cloud_conn:
|
|
||||||
if cloud_conn:
|
|
||||||
with cloud_conn.cursor() as cur:
|
|
||||||
psycopg2.extras.execute_values(
|
|
||||||
cur, insert_sql, values,
|
|
||||||
template=insert_template, page_size=1000,
|
|
||||||
)
|
|
||||||
if last_agg_id > 0:
|
|
||||||
update_meta(cloud_conn, symbol, last_agg_id, last_time_ms)
|
|
||||||
cloud_conn.commit()
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"[{symbol}] Cloud SQL write failed (non-fatal): {e}")
|
|
||||||
|
|
||||||
return inserted
|
return inserted
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"flush_buffer [{symbol}] error: {e}")
|
logger.error(f"flush_buffer [{symbol}] error: {e}")
|
||||||
|
|||||||
@ -34,7 +34,7 @@ logging.basicConfig(
|
|||||||
)
|
)
|
||||||
logger = logging.getLogger("backtest")
|
logger = logging.getLogger("backtest")
|
||||||
|
|
||||||
PG_HOST = os.getenv("PG_HOST", "127.0.0.1")
|
PG_HOST = os.getenv("PG_HOST", "10.106.0.3")
|
||||||
PG_PORT = int(os.getenv("PG_PORT", "5432"))
|
PG_PORT = int(os.getenv("PG_PORT", "5432"))
|
||||||
PG_DB = os.getenv("PG_DB", "arb_engine")
|
PG_DB = os.getenv("PG_DB", "arb_engine")
|
||||||
PG_USER = os.getenv("PG_USER", "arb")
|
PG_USER = os.getenv("PG_USER", "arb")
|
||||||
|
|||||||
122
backend/db.py
122
backend/db.py
@ -1,5 +1,6 @@
|
|||||||
"""
|
"""
|
||||||
db.py — PostgreSQL 数据库连接层
|
db.py — PostgreSQL 数据库连接层
|
||||||
|
统一连接到 Cloud SQL(PG_HOST 默认 10.106.0.3)
|
||||||
同步连接池(psycopg2)供脚本类使用
|
同步连接池(psycopg2)供脚本类使用
|
||||||
异步连接池(asyncpg)供FastAPI使用
|
异步连接池(asyncpg)供FastAPI使用
|
||||||
"""
|
"""
|
||||||
@ -11,8 +12,8 @@ import psycopg2
|
|||||||
import psycopg2.pool
|
import psycopg2.pool
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
|
||||||
# PG连接参数(本地)
|
# PG连接参数(统一连接 Cloud SQL)
|
||||||
PG_HOST = os.getenv("PG_HOST", "127.0.0.1")
|
PG_HOST = os.getenv("PG_HOST", "10.106.0.3")
|
||||||
PG_PORT = int(os.getenv("PG_PORT", 5432))
|
PG_PORT = int(os.getenv("PG_PORT", 5432))
|
||||||
PG_DB = os.getenv("PG_DB", "arb_engine")
|
PG_DB = os.getenv("PG_DB", "arb_engine")
|
||||||
PG_USER = os.getenv("PG_USER", "arb")
|
PG_USER = os.getenv("PG_USER", "arb")
|
||||||
@ -20,14 +21,6 @@ PG_PASS = os.getenv("PG_PASS", "arb_engine_2026")
|
|||||||
|
|
||||||
PG_DSN = f"postgresql://{PG_USER}:{PG_PASS}@{PG_HOST}:{PG_PORT}/{PG_DB}"
|
PG_DSN = f"postgresql://{PG_USER}:{PG_PASS}@{PG_HOST}:{PG_PORT}/{PG_DB}"
|
||||||
|
|
||||||
# Cloud SQL连接参数(双写目标)
|
|
||||||
CLOUD_PG_HOST = os.getenv("CLOUD_PG_HOST", "10.106.0.3")
|
|
||||||
CLOUD_PG_PORT = int(os.getenv("CLOUD_PG_PORT", 5432))
|
|
||||||
CLOUD_PG_DB = os.getenv("CLOUD_PG_DB", "arb_engine")
|
|
||||||
CLOUD_PG_USER = os.getenv("CLOUD_PG_USER", "arb")
|
|
||||||
CLOUD_PG_PASS = os.getenv("CLOUD_PG_PASS", "arb_engine_2026")
|
|
||||||
CLOUD_PG_ENABLED = os.getenv("CLOUD_PG_ENABLED", "true").lower() == "true"
|
|
||||||
|
|
||||||
# ─── 同步连接池(psycopg2)─────────────────────────────────────
|
# ─── 同步连接池(psycopg2)─────────────────────────────────────
|
||||||
|
|
||||||
_sync_pool = None
|
_sync_pool = None
|
||||||
@ -73,51 +66,6 @@ def sync_executemany(sql: str, params_list: list):
|
|||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
|
|
||||||
# ─── Cloud SQL 同步连接池(双写用)───────────────────────────────
|
|
||||||
|
|
||||||
_cloud_sync_pool = None
|
|
||||||
|
|
||||||
def get_cloud_sync_pool():
|
|
||||||
global _cloud_sync_pool
|
|
||||||
if not CLOUD_PG_ENABLED:
|
|
||||||
return None
|
|
||||||
if _cloud_sync_pool is None:
|
|
||||||
try:
|
|
||||||
_cloud_sync_pool = psycopg2.pool.ThreadedConnectionPool(
|
|
||||||
minconn=1, maxconn=5,
|
|
||||||
host=CLOUD_PG_HOST, port=CLOUD_PG_PORT,
|
|
||||||
dbname=CLOUD_PG_DB, user=CLOUD_PG_USER, password=CLOUD_PG_PASS,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
import logging
|
|
||||||
logging.getLogger("db").error(f"Cloud SQL pool init failed: {e}")
|
|
||||||
return None
|
|
||||||
return _cloud_sync_pool
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def get_cloud_sync_conn():
|
|
||||||
"""获取Cloud SQL同步连接(失败返回None,不影响主流程)"""
|
|
||||||
pool = get_cloud_sync_pool()
|
|
||||||
if pool is None:
|
|
||||||
yield None
|
|
||||||
return
|
|
||||||
conn = None
|
|
||||||
try:
|
|
||||||
conn = pool.getconn()
|
|
||||||
yield conn
|
|
||||||
except Exception as e:
|
|
||||||
import logging
|
|
||||||
logging.getLogger("db").error(f"Cloud SQL conn error: {e}")
|
|
||||||
yield None
|
|
||||||
finally:
|
|
||||||
if conn and pool:
|
|
||||||
try:
|
|
||||||
pool.putconn(conn)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
# ─── 异步连接池(asyncpg)─────────────────────────────────────
|
# ─── 异步连接池(asyncpg)─────────────────────────────────────
|
||||||
|
|
||||||
_async_pool: asyncpg.Pool | None = None
|
_async_pool: asyncpg.Pool | None = None
|
||||||
@ -206,6 +154,7 @@ CREATE TABLE IF NOT EXISTS signal_indicators (
|
|||||||
id BIGSERIAL PRIMARY KEY,
|
id BIGSERIAL PRIMARY KEY,
|
||||||
ts BIGINT NOT NULL,
|
ts BIGINT NOT NULL,
|
||||||
symbol TEXT NOT NULL,
|
symbol TEXT NOT NULL,
|
||||||
|
strategy TEXT,
|
||||||
cvd_fast DOUBLE PRECISION,
|
cvd_fast DOUBLE PRECISION,
|
||||||
cvd_mid DOUBLE PRECISION,
|
cvd_mid DOUBLE PRECISION,
|
||||||
cvd_day DOUBLE PRECISION,
|
cvd_day DOUBLE PRECISION,
|
||||||
@ -219,10 +168,12 @@ CREATE TABLE IF NOT EXISTS signal_indicators (
|
|||||||
buy_vol_1m DOUBLE PRECISION,
|
buy_vol_1m DOUBLE PRECISION,
|
||||||
sell_vol_1m DOUBLE PRECISION,
|
sell_vol_1m DOUBLE PRECISION,
|
||||||
score INTEGER,
|
score INTEGER,
|
||||||
signal TEXT
|
signal TEXT,
|
||||||
|
factors JSONB
|
||||||
);
|
);
|
||||||
CREATE INDEX IF NOT EXISTS idx_si_ts ON signal_indicators(ts);
|
CREATE INDEX IF NOT EXISTS idx_si_ts ON signal_indicators(ts);
|
||||||
CREATE INDEX IF NOT EXISTS idx_si_sym_ts ON signal_indicators(symbol, ts);
|
CREATE INDEX IF NOT EXISTS idx_si_sym_ts ON signal_indicators(symbol, ts);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_si_strategy ON signal_indicators(strategy, ts);
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS signal_indicators_1m (
|
CREATE TABLE IF NOT EXISTS signal_indicators_1m (
|
||||||
id BIGSERIAL PRIMARY KEY,
|
id BIGSERIAL PRIMARY KEY,
|
||||||
@ -256,7 +207,7 @@ CREATE TABLE IF NOT EXISTS signal_trades (
|
|||||||
status TEXT DEFAULT 'open'
|
status TEXT DEFAULT 'open'
|
||||||
);
|
);
|
||||||
|
|
||||||
-- 信号日志(旧表兼容)
|
-- 信号日志(旧表兼容保留,不再写入新数据)
|
||||||
CREATE TABLE IF NOT EXISTS signal_logs (
|
CREATE TABLE IF NOT EXISTS signal_logs (
|
||||||
id BIGSERIAL PRIMARY KEY,
|
id BIGSERIAL PRIMARY KEY,
|
||||||
symbol TEXT,
|
symbol TEXT,
|
||||||
@ -266,7 +217,32 @@ CREATE TABLE IF NOT EXISTS signal_logs (
|
|||||||
message TEXT
|
message TEXT
|
||||||
);
|
);
|
||||||
|
|
||||||
-- 用户表(auth)
|
-- 市场指标(由 market_data_collector 写入)
|
||||||
|
CREATE TABLE IF NOT EXISTS market_indicators (
|
||||||
|
id BIGSERIAL PRIMARY KEY,
|
||||||
|
ts BIGINT NOT NULL,
|
||||||
|
symbol TEXT NOT NULL,
|
||||||
|
indicator_type TEXT NOT NULL,
|
||||||
|
value DOUBLE PRECISION,
|
||||||
|
extra JSONB,
|
||||||
|
created_at TIMESTAMP DEFAULT NOW()
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_mi_sym_type_ts ON market_indicators(symbol, indicator_type, ts DESC);
|
||||||
|
|
||||||
|
-- 清算数据(由 liquidation_collector 写入)
|
||||||
|
CREATE TABLE IF NOT EXISTS liquidations (
|
||||||
|
id BIGSERIAL PRIMARY KEY,
|
||||||
|
ts BIGINT NOT NULL,
|
||||||
|
symbol TEXT NOT NULL,
|
||||||
|
side TEXT NOT NULL,
|
||||||
|
qty DOUBLE PRECISION,
|
||||||
|
price DOUBLE PRECISION,
|
||||||
|
usd_value DOUBLE PRECISION,
|
||||||
|
created_at TIMESTAMP DEFAULT NOW()
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_liq_sym_ts ON liquidations(symbol, ts DESC);
|
||||||
|
|
||||||
|
-- 用户表(由 auth.py 负责完整定义,此处仅兜底)
|
||||||
CREATE TABLE IF NOT EXISTS users (
|
CREATE TABLE IF NOT EXISTS users (
|
||||||
id BIGSERIAL PRIMARY KEY,
|
id BIGSERIAL PRIMARY KEY,
|
||||||
email TEXT UNIQUE NOT NULL,
|
email TEXT UNIQUE NOT NULL,
|
||||||
@ -282,6 +258,13 @@ CREATE TABLE IF NOT EXISTS invite_codes (
|
|||||||
created_at TIMESTAMP DEFAULT NOW()
|
created_at TIMESTAMP DEFAULT NOW()
|
||||||
);
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS invite_usage (
|
||||||
|
id BIGSERIAL PRIMARY KEY,
|
||||||
|
code TEXT NOT NULL,
|
||||||
|
used_by BIGINT REFERENCES users(id),
|
||||||
|
used_at TIMESTAMP DEFAULT NOW()
|
||||||
|
);
|
||||||
|
|
||||||
-- 模拟盘交易表
|
-- 模拟盘交易表
|
||||||
CREATE TABLE IF NOT EXISTS paper_trades (
|
CREATE TABLE IF NOT EXISTS paper_trades (
|
||||||
id BIGSERIAL PRIMARY KEY,
|
id BIGSERIAL PRIMARY KEY,
|
||||||
@ -300,6 +283,7 @@ CREATE TABLE IF NOT EXISTS paper_trades (
|
|||||||
status TEXT DEFAULT 'active',
|
status TEXT DEFAULT 'active',
|
||||||
pnl_r DOUBLE PRECISION DEFAULT 0,
|
pnl_r DOUBLE PRECISION DEFAULT 0,
|
||||||
atr_at_entry DOUBLE PRECISION DEFAULT 0,
|
atr_at_entry DOUBLE PRECISION DEFAULT 0,
|
||||||
|
risk_distance DOUBLE PRECISION,
|
||||||
score_factors JSONB,
|
score_factors JSONB,
|
||||||
created_at TIMESTAMP DEFAULT NOW()
|
created_at TIMESTAMP DEFAULT NOW()
|
||||||
);
|
);
|
||||||
@ -373,7 +357,6 @@ def ensure_partitions():
|
|||||||
for m in set(months):
|
for m in set(months):
|
||||||
year = int(m[:4])
|
year = int(m[:4])
|
||||||
month = int(m[4:])
|
month = int(m[4:])
|
||||||
# 计算分区范围(UTC毫秒时间戳)
|
|
||||||
start = datetime.datetime(year, month, 1, tzinfo=datetime.timezone.utc)
|
start = datetime.datetime(year, month, 1, tzinfo=datetime.timezone.utc)
|
||||||
if month == 12:
|
if month == 12:
|
||||||
end = datetime.datetime(year + 1, 1, 1, tzinfo=datetime.timezone.utc)
|
end = datetime.datetime(year + 1, 1, 1, tzinfo=datetime.timezone.utc)
|
||||||
@ -402,13 +385,22 @@ def init_schema():
|
|||||||
if stmt:
|
if stmt:
|
||||||
try:
|
try:
|
||||||
cur.execute(stmt)
|
cur.execute(stmt)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
conn.rollback()
|
conn.rollback()
|
||||||
# 忽略已存在错误
|
|
||||||
continue
|
continue
|
||||||
cur.execute(
|
# 补全字段(向前兼容旧部署)
|
||||||
"ALTER TABLE paper_trades "
|
migrations = [
|
||||||
"ADD COLUMN IF NOT EXISTS strategy VARCHAR(32) DEFAULT 'v51_baseline'"
|
"ALTER TABLE paper_trades ADD COLUMN IF NOT EXISTS strategy VARCHAR(32) DEFAULT 'v51_baseline'",
|
||||||
)
|
"ALTER TABLE paper_trades ADD COLUMN IF NOT EXISTS risk_distance DOUBLE PRECISION",
|
||||||
|
"ALTER TABLE signal_indicators ADD COLUMN IF NOT EXISTS strategy TEXT",
|
||||||
|
"ALTER TABLE signal_indicators ADD COLUMN IF NOT EXISTS factors JSONB",
|
||||||
|
"ALTER TABLE users ADD COLUMN IF NOT EXISTS discord_id TEXT",
|
||||||
|
"ALTER TABLE users ADD COLUMN IF NOT EXISTS banned BOOLEAN DEFAULT FALSE",
|
||||||
|
]
|
||||||
|
for m in migrations:
|
||||||
|
try:
|
||||||
|
cur.execute(m)
|
||||||
|
except Exception:
|
||||||
|
conn.rollback()
|
||||||
conn.commit()
|
conn.commit()
|
||||||
ensure_partitions()
|
ensure_partitions()
|
||||||
|
|||||||
@ -12,7 +12,7 @@ from psycopg2.extras import Json
|
|||||||
SYMBOLS = ["BTCUSDT", "ETHUSDT", "XRPUSDT", "SOLUSDT"]
|
SYMBOLS = ["BTCUSDT", "ETHUSDT", "XRPUSDT", "SOLUSDT"]
|
||||||
INTERVAL_SECONDS = 300
|
INTERVAL_SECONDS = 300
|
||||||
|
|
||||||
PG_HOST = os.getenv("PG_HOST", "127.0.0.1")
|
PG_HOST = os.getenv("PG_HOST", "10.106.0.3")
|
||||||
PG_PORT = int(os.getenv("PG_PORT", "5432"))
|
PG_PORT = int(os.getenv("PG_PORT", "5432"))
|
||||||
PG_DB = os.getenv("PG_DB", "arb_engine")
|
PG_DB = os.getenv("PG_DB", "arb_engine")
|
||||||
PG_USER = os.getenv("PG_USER", "arb")
|
PG_USER = os.getenv("PG_USER", "arb")
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user