diff --git a/backend/agg_trades_collector.py b/backend/agg_trades_collector.py index 1232f27..2ee83c3 100644 --- a/backend/agg_trades_collector.py +++ b/backend/agg_trades_collector.py @@ -22,7 +22,7 @@ import psycopg2 import psycopg2.extras import websockets -from db import get_sync_conn, get_sync_pool, ensure_partitions, PG_HOST, PG_PORT, PG_DB, PG_USER, PG_PASS +from db import get_sync_conn, get_sync_pool, get_cloud_sync_conn, ensure_partitions, PG_HOST, PG_PORT, PG_DB, PG_USER, PG_PASS, CLOUD_PG_ENABLED logging.basicConfig( level=logging.INFO, @@ -69,50 +69,65 @@ def update_meta(conn, symbol: str, last_agg_id: int, last_time_ms: int): def flush_buffer(symbol: str, trades: list) -> int: - """写入一批trades到PG,返回实际写入条数""" + """写入一批trades到PG(本地+Cloud SQL双写),返回实际写入条数""" if not trades: return 0 try: # 确保分区存在 ensure_partitions() + values = [] + last_agg_id = 0 + last_time_ms = 0 + + for t in trades: + agg_id = t["a"] + time_ms = t["T"] + values.append(( + agg_id, symbol, + float(t["p"]), float(t["q"]), + time_ms, + 1 if t["m"] else 0, + )) + if agg_id > last_agg_id: + last_agg_id = agg_id + last_time_ms = time_ms + + insert_sql = """INSERT INTO agg_trades (agg_id, symbol, price, qty, time_ms, is_buyer_maker) + VALUES %s + ON CONFLICT (time_ms, symbol, agg_id) DO NOTHING""" + insert_template = "(%s, %s, %s, %s, %s, %s)" + + # 写本地PG + inserted = 0 with get_sync_conn() as conn: with conn.cursor() as cur: - # 批量插入(ON CONFLICT忽略重复) - values = [] - last_agg_id = 0 - last_time_ms = 0 - - for t in trades: - agg_id = t["a"] - time_ms = t["T"] - values.append(( - agg_id, symbol, - float(t["p"]), float(t["q"]), - time_ms, - 1 if t["m"] else 0, - )) - if agg_id > last_agg_id: - last_agg_id = agg_id - last_time_ms = time_ms - - # 批量INSERT psycopg2.extras.execute_values( - cur, - """INSERT INTO agg_trades (agg_id, symbol, price, qty, time_ms, is_buyer_maker) - VALUES %s - ON CONFLICT (time_ms, symbol, agg_id) DO NOTHING""", - values, - template="(%s, %s, %s, %s, %s, %s)", - page_size=1000, + cur, insert_sql, values, + template=insert_template, page_size=1000, ) inserted = cur.rowcount - if last_agg_id > 0: update_meta(conn, symbol, last_agg_id, last_time_ms) - conn.commit() - return inserted + + # 双写Cloud SQL(失败不影响主流程) + if CLOUD_PG_ENABLED: + try: + with get_cloud_sync_conn() as cloud_conn: + if cloud_conn: + with cloud_conn.cursor() as cur: + psycopg2.extras.execute_values( + cur, insert_sql, values, + template=insert_template, page_size=1000, + ) + if last_agg_id > 0: + update_meta(cloud_conn, symbol, last_agg_id, last_time_ms) + cloud_conn.commit() + except Exception as e: + logger.warning(f"[{symbol}] Cloud SQL write failed (non-fatal): {e}") + + return inserted except Exception as e: logger.error(f"flush_buffer [{symbol}] error: {e}") return 0 diff --git a/backend/db.py b/backend/db.py index 326a01c..e9375d1 100644 --- a/backend/db.py +++ b/backend/db.py @@ -11,7 +11,7 @@ import psycopg2 import psycopg2.pool from contextlib import contextmanager -# PG连接参数 +# PG连接参数(本地) PG_HOST = os.getenv("PG_HOST", "127.0.0.1") PG_PORT = int(os.getenv("PG_PORT", 5432)) PG_DB = os.getenv("PG_DB", "arb_engine") @@ -20,6 +20,14 @@ PG_PASS = os.getenv("PG_PASS", "arb_engine_2026") PG_DSN = f"postgresql://{PG_USER}:{PG_PASS}@{PG_HOST}:{PG_PORT}/{PG_DB}" +# Cloud SQL连接参数(双写目标) +CLOUD_PG_HOST = os.getenv("CLOUD_PG_HOST", "10.106.0.3") +CLOUD_PG_PORT = int(os.getenv("CLOUD_PG_PORT", 5432)) +CLOUD_PG_DB = os.getenv("CLOUD_PG_DB", "arb_engine") +CLOUD_PG_USER = os.getenv("CLOUD_PG_USER", "arb") +CLOUD_PG_PASS = os.getenv("CLOUD_PG_PASS", "arb_engine_2026") +CLOUD_PG_ENABLED = os.getenv("CLOUD_PG_ENABLED", "true").lower() == "true" + # ─── 同步连接池(psycopg2)───────────────────────────────────── _sync_pool = None @@ -65,6 +73,51 @@ def sync_executemany(sql: str, params_list: list): conn.commit() +# ─── Cloud SQL 同步连接池(双写用)─────────────────────────────── + +_cloud_sync_pool = None + +def get_cloud_sync_pool(): + global _cloud_sync_pool + if not CLOUD_PG_ENABLED: + return None + if _cloud_sync_pool is None: + try: + _cloud_sync_pool = psycopg2.pool.ThreadedConnectionPool( + minconn=1, maxconn=5, + host=CLOUD_PG_HOST, port=CLOUD_PG_PORT, + dbname=CLOUD_PG_DB, user=CLOUD_PG_USER, password=CLOUD_PG_PASS, + ) + except Exception as e: + import logging + logging.getLogger("db").error(f"Cloud SQL pool init failed: {e}") + return None + return _cloud_sync_pool + + +@contextmanager +def get_cloud_sync_conn(): + """获取Cloud SQL同步连接(失败返回None,不影响主流程)""" + pool = get_cloud_sync_pool() + if pool is None: + yield None + return + conn = None + try: + conn = pool.getconn() + yield conn + except Exception as e: + import logging + logging.getLogger("db").error(f"Cloud SQL conn error: {e}") + yield None + finally: + if conn and pool: + try: + pool.putconn(conn) + except Exception: + pass + + # ─── 异步连接池(asyncpg)───────────────────────────────────── _async_pool: asyncpg.Pool | None = None