From 91ed44ad9f714aab61dc65f007690ce5b05141d4 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 3 Mar 2026 12:30:01 +0000 Subject: [PATCH] refactor: unify all DB connections to Cloud SQL, remove dual-write and SQLite code --- .../migrate_auth_sqlite_to_pg.py.archived | 0 .../migrate_sqlite_to_pg.py.archived | 0 .../signal_pusher.py.archived | 0 .../subscriptions.py.archived | 0 backend/agg_trades_collector.py | 23 +--- backend/backtest.py | 2 +- backend/db.py | 122 ++++++++---------- backend/market_data_collector.py | 2 +- 8 files changed, 62 insertions(+), 87 deletions(-) rename backend/migrate_auth_sqlite_to_pg.py => archive/migrate_auth_sqlite_to_pg.py.archived (100%) rename backend/migrate_sqlite_to_pg.py => archive/migrate_sqlite_to_pg.py.archived (100%) rename backend/signal_pusher.py => archive/signal_pusher.py.archived (100%) rename backend/subscriptions.py => archive/subscriptions.py.archived (100%) diff --git a/backend/migrate_auth_sqlite_to_pg.py b/archive/migrate_auth_sqlite_to_pg.py.archived similarity index 100% rename from backend/migrate_auth_sqlite_to_pg.py rename to archive/migrate_auth_sqlite_to_pg.py.archived diff --git a/backend/migrate_sqlite_to_pg.py b/archive/migrate_sqlite_to_pg.py.archived similarity index 100% rename from backend/migrate_sqlite_to_pg.py rename to archive/migrate_sqlite_to_pg.py.archived diff --git a/backend/signal_pusher.py b/archive/signal_pusher.py.archived similarity index 100% rename from backend/signal_pusher.py rename to archive/signal_pusher.py.archived diff --git a/backend/subscriptions.py b/archive/subscriptions.py.archived similarity index 100% rename from backend/subscriptions.py rename to archive/subscriptions.py.archived diff --git a/backend/agg_trades_collector.py b/backend/agg_trades_collector.py index 2ee83c3..570bba0 100644 --- a/backend/agg_trades_collector.py +++ b/backend/agg_trades_collector.py @@ -7,6 +7,7 @@ agg_trades_collector.py — aggTrades全量采集守护进程(PostgreSQL版) - 每分钟巡检:校验agg_id连续性,发现断档自动补洞 - 批量写入:攒200条或1秒flush一次 - PG分区表:按月自动分区,MVCC并发无锁冲突 + - 统一写入 Cloud SQL(双写机制已移除) """ import asyncio @@ -22,7 +23,7 @@ import psycopg2 import psycopg2.extras import websockets -from db import get_sync_conn, get_sync_pool, get_cloud_sync_conn, ensure_partitions, PG_HOST, PG_PORT, PG_DB, PG_USER, PG_PASS, CLOUD_PG_ENABLED +from db import get_sync_conn, get_sync_pool, ensure_partitions, PG_HOST logging.basicConfig( level=logging.INFO, @@ -69,11 +70,10 @@ def update_meta(conn, symbol: str, last_agg_id: int, last_time_ms: int): def flush_buffer(symbol: str, trades: list) -> int: - """写入一批trades到PG(本地+Cloud SQL双写),返回实际写入条数""" + """写入一批trades到Cloud SQL,返回实际写入条数""" if not trades: return 0 try: - # 确保分区存在 ensure_partitions() values = [] @@ -98,7 +98,6 @@ def flush_buffer(symbol: str, trades: list) -> int: ON CONFLICT (time_ms, symbol, agg_id) DO NOTHING""" insert_template = "(%s, %s, %s, %s, %s, %s)" - # 写本地PG inserted = 0 with get_sync_conn() as conn: with conn.cursor() as cur: @@ -111,22 +110,6 @@ def flush_buffer(symbol: str, trades: list) -> int: update_meta(conn, symbol, last_agg_id, last_time_ms) conn.commit() - # 双写Cloud SQL(失败不影响主流程) - if CLOUD_PG_ENABLED: - try: - with get_cloud_sync_conn() as cloud_conn: - if cloud_conn: - with cloud_conn.cursor() as cur: - psycopg2.extras.execute_values( - cur, insert_sql, values, - template=insert_template, page_size=1000, - ) - if last_agg_id > 0: - update_meta(cloud_conn, symbol, last_agg_id, last_time_ms) - cloud_conn.commit() - except Exception as e: - logger.warning(f"[{symbol}] Cloud SQL write failed (non-fatal): {e}") - return inserted except Exception as e: logger.error(f"flush_buffer [{symbol}] error: {e}") diff --git a/backend/backtest.py b/backend/backtest.py index 7968509..b19063d 100644 --- a/backend/backtest.py +++ b/backend/backtest.py @@ -34,7 +34,7 @@ logging.basicConfig( ) logger = logging.getLogger("backtest") -PG_HOST = os.getenv("PG_HOST", "127.0.0.1") +PG_HOST = os.getenv("PG_HOST", "10.106.0.3") PG_PORT = int(os.getenv("PG_PORT", "5432")) PG_DB = os.getenv("PG_DB", "arb_engine") PG_USER = os.getenv("PG_USER", "arb") diff --git a/backend/db.py b/backend/db.py index 844d25b..90805d6 100644 --- a/backend/db.py +++ b/backend/db.py @@ -1,5 +1,6 @@ """ db.py — PostgreSQL 数据库连接层 +统一连接到 Cloud SQL(PG_HOST 默认 10.106.0.3) 同步连接池(psycopg2)供脚本类使用 异步连接池(asyncpg)供FastAPI使用 """ @@ -11,8 +12,8 @@ import psycopg2 import psycopg2.pool from contextlib import contextmanager -# PG连接参数(本地) -PG_HOST = os.getenv("PG_HOST", "127.0.0.1") +# PG连接参数(统一连接 Cloud SQL) +PG_HOST = os.getenv("PG_HOST", "10.106.0.3") PG_PORT = int(os.getenv("PG_PORT", 5432)) PG_DB = os.getenv("PG_DB", "arb_engine") PG_USER = os.getenv("PG_USER", "arb") @@ -20,14 +21,6 @@ PG_PASS = os.getenv("PG_PASS", "arb_engine_2026") PG_DSN = f"postgresql://{PG_USER}:{PG_PASS}@{PG_HOST}:{PG_PORT}/{PG_DB}" -# Cloud SQL连接参数(双写目标) -CLOUD_PG_HOST = os.getenv("CLOUD_PG_HOST", "10.106.0.3") -CLOUD_PG_PORT = int(os.getenv("CLOUD_PG_PORT", 5432)) -CLOUD_PG_DB = os.getenv("CLOUD_PG_DB", "arb_engine") -CLOUD_PG_USER = os.getenv("CLOUD_PG_USER", "arb") -CLOUD_PG_PASS = os.getenv("CLOUD_PG_PASS", "arb_engine_2026") -CLOUD_PG_ENABLED = os.getenv("CLOUD_PG_ENABLED", "true").lower() == "true" - # ─── 同步连接池(psycopg2)───────────────────────────────────── _sync_pool = None @@ -73,51 +66,6 @@ def sync_executemany(sql: str, params_list: list): conn.commit() -# ─── Cloud SQL 同步连接池(双写用)─────────────────────────────── - -_cloud_sync_pool = None - -def get_cloud_sync_pool(): - global _cloud_sync_pool - if not CLOUD_PG_ENABLED: - return None - if _cloud_sync_pool is None: - try: - _cloud_sync_pool = psycopg2.pool.ThreadedConnectionPool( - minconn=1, maxconn=5, - host=CLOUD_PG_HOST, port=CLOUD_PG_PORT, - dbname=CLOUD_PG_DB, user=CLOUD_PG_USER, password=CLOUD_PG_PASS, - ) - except Exception as e: - import logging - logging.getLogger("db").error(f"Cloud SQL pool init failed: {e}") - return None - return _cloud_sync_pool - - -@contextmanager -def get_cloud_sync_conn(): - """获取Cloud SQL同步连接(失败返回None,不影响主流程)""" - pool = get_cloud_sync_pool() - if pool is None: - yield None - return - conn = None - try: - conn = pool.getconn() - yield conn - except Exception as e: - import logging - logging.getLogger("db").error(f"Cloud SQL conn error: {e}") - yield None - finally: - if conn and pool: - try: - pool.putconn(conn) - except Exception: - pass - - # ─── 异步连接池(asyncpg)───────────────────────────────────── _async_pool: asyncpg.Pool | None = None @@ -206,6 +154,7 @@ CREATE TABLE IF NOT EXISTS signal_indicators ( id BIGSERIAL PRIMARY KEY, ts BIGINT NOT NULL, symbol TEXT NOT NULL, + strategy TEXT, cvd_fast DOUBLE PRECISION, cvd_mid DOUBLE PRECISION, cvd_day DOUBLE PRECISION, @@ -219,10 +168,12 @@ CREATE TABLE IF NOT EXISTS signal_indicators ( buy_vol_1m DOUBLE PRECISION, sell_vol_1m DOUBLE PRECISION, score INTEGER, - signal TEXT + signal TEXT, + factors JSONB ); CREATE INDEX IF NOT EXISTS idx_si_ts ON signal_indicators(ts); CREATE INDEX IF NOT EXISTS idx_si_sym_ts ON signal_indicators(symbol, ts); +CREATE INDEX IF NOT EXISTS idx_si_strategy ON signal_indicators(strategy, ts); CREATE TABLE IF NOT EXISTS signal_indicators_1m ( id BIGSERIAL PRIMARY KEY, @@ -256,7 +207,7 @@ CREATE TABLE IF NOT EXISTS signal_trades ( status TEXT DEFAULT 'open' ); --- 信号日志(旧表兼容) +-- 信号日志(旧表兼容保留,不再写入新数据) CREATE TABLE IF NOT EXISTS signal_logs ( id BIGSERIAL PRIMARY KEY, symbol TEXT, @@ -266,7 +217,32 @@ CREATE TABLE IF NOT EXISTS signal_logs ( message TEXT ); --- 用户表(auth) +-- 市场指标(由 market_data_collector 写入) +CREATE TABLE IF NOT EXISTS market_indicators ( + id BIGSERIAL PRIMARY KEY, + ts BIGINT NOT NULL, + symbol TEXT NOT NULL, + indicator_type TEXT NOT NULL, + value DOUBLE PRECISION, + extra JSONB, + created_at TIMESTAMP DEFAULT NOW() +); +CREATE INDEX IF NOT EXISTS idx_mi_sym_type_ts ON market_indicators(symbol, indicator_type, ts DESC); + +-- 清算数据(由 liquidation_collector 写入) +CREATE TABLE IF NOT EXISTS liquidations ( + id BIGSERIAL PRIMARY KEY, + ts BIGINT NOT NULL, + symbol TEXT NOT NULL, + side TEXT NOT NULL, + qty DOUBLE PRECISION, + price DOUBLE PRECISION, + usd_value DOUBLE PRECISION, + created_at TIMESTAMP DEFAULT NOW() +); +CREATE INDEX IF NOT EXISTS idx_liq_sym_ts ON liquidations(symbol, ts DESC); + +-- 用户表(由 auth.py 负责完整定义,此处仅兜底) CREATE TABLE IF NOT EXISTS users ( id BIGSERIAL PRIMARY KEY, email TEXT UNIQUE NOT NULL, @@ -282,6 +258,13 @@ CREATE TABLE IF NOT EXISTS invite_codes ( created_at TIMESTAMP DEFAULT NOW() ); +CREATE TABLE IF NOT EXISTS invite_usage ( + id BIGSERIAL PRIMARY KEY, + code TEXT NOT NULL, + used_by BIGINT REFERENCES users(id), + used_at TIMESTAMP DEFAULT NOW() +); + -- 模拟盘交易表 CREATE TABLE IF NOT EXISTS paper_trades ( id BIGSERIAL PRIMARY KEY, @@ -300,6 +283,7 @@ CREATE TABLE IF NOT EXISTS paper_trades ( status TEXT DEFAULT 'active', pnl_r DOUBLE PRECISION DEFAULT 0, atr_at_entry DOUBLE PRECISION DEFAULT 0, + risk_distance DOUBLE PRECISION, score_factors JSONB, created_at TIMESTAMP DEFAULT NOW() ); @@ -373,7 +357,6 @@ def ensure_partitions(): for m in set(months): year = int(m[:4]) month = int(m[4:]) - # 计算分区范围(UTC毫秒时间戳) start = datetime.datetime(year, month, 1, tzinfo=datetime.timezone.utc) if month == 12: end = datetime.datetime(year + 1, 1, 1, tzinfo=datetime.timezone.utc) @@ -402,13 +385,22 @@ def init_schema(): if stmt: try: cur.execute(stmt) - except Exception as e: + except Exception: conn.rollback() - # 忽略已存在错误 continue - cur.execute( - "ALTER TABLE paper_trades " - "ADD COLUMN IF NOT EXISTS strategy VARCHAR(32) DEFAULT 'v51_baseline'" - ) + # 补全字段(向前兼容旧部署) + migrations = [ + "ALTER TABLE paper_trades ADD COLUMN IF NOT EXISTS strategy VARCHAR(32) DEFAULT 'v51_baseline'", + "ALTER TABLE paper_trades ADD COLUMN IF NOT EXISTS risk_distance DOUBLE PRECISION", + "ALTER TABLE signal_indicators ADD COLUMN IF NOT EXISTS strategy TEXT", + "ALTER TABLE signal_indicators ADD COLUMN IF NOT EXISTS factors JSONB", + "ALTER TABLE users ADD COLUMN IF NOT EXISTS discord_id TEXT", + "ALTER TABLE users ADD COLUMN IF NOT EXISTS banned BOOLEAN DEFAULT FALSE", + ] + for m in migrations: + try: + cur.execute(m) + except Exception: + conn.rollback() conn.commit() ensure_partitions() diff --git a/backend/market_data_collector.py b/backend/market_data_collector.py index ebee07c..14e7851 100644 --- a/backend/market_data_collector.py +++ b/backend/market_data_collector.py @@ -12,7 +12,7 @@ from psycopg2.extras import Json SYMBOLS = ["BTCUSDT", "ETHUSDT", "XRPUSDT", "SOLUSDT"] INTERVAL_SECONDS = 300 -PG_HOST = os.getenv("PG_HOST", "127.0.0.1") +PG_HOST = os.getenv("PG_HOST", "10.106.0.3") PG_PORT = int(os.getenv("PG_PORT", "5432")) PG_DB = os.getenv("PG_DB", "arb_engine") PG_USER = os.getenv("PG_USER", "arb")