# arbitrage-engine/backend/migrate_sqlite_to_pg.py
"""
migrate_sqlite_to_pg.py — 将SQLite数据迁移到PostgreSQL
用法:
python3 backend/migrate_sqlite_to_pg.py
迁移顺序:
1. 创建PG schema + 分区
2. 迁移 rate_snapshots
3. 迁移 agg_trades按月表
4. 迁移 agg_trades_meta
5. 校验计数
"""
import os
import sqlite3
import time
import logging
import psycopg2
import psycopg2.extras
from db import get_sync_conn, init_schema, ensure_partitions
# Root logging config: timestamped INFO-level lines for migration progress.
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger("migrate")
# Source SQLite database: ../arb.db relative to this script's directory.
SQLITE_PATH = os.path.join(os.path.dirname(__file__), "..", "arb.db")
def migrate_rate_snapshots(sqlite_conn, pg_conn):
    """Copy every rate_snapshots row from SQLite into PostgreSQL.

    Rows already present in PG are skipped via ON CONFLICT DO NOTHING, so
    the migration is safe to re-run. Commits once after all rows are sent.

    :param sqlite_conn: open ``sqlite3`` connection (row_factory=Row)
    :param pg_conn:     open psycopg2 connection
    """
    logger.info("=== 迁移 rate_snapshots ===")
    rows = sqlite_conn.execute("SELECT ts, btc_rate, eth_rate, btc_price, eth_price, btc_index_price, eth_index_price FROM rate_snapshots ORDER BY ts").fetchall()
    logger.info(f" SQLite: {len(rows)}")
    if not rows:
        return
    values = [(r["ts"], r["btc_rate"], r["eth_rate"], r["btc_price"], r["eth_price"],
               r["btc_index_price"], r["eth_index_price"]) for r in rows]
    page = 5000
    inserted = 0
    with pg_conn.cursor() as cur:
        # BUGFIX: execute_values() only reports cur.rowcount for its *last*
        # internal page, so with >5000 rows the old single call under-counted.
        # Feed it one page at a time and accumulate for an accurate total.
        for i in range(0, len(values), page):
            psycopg2.extras.execute_values(
                cur,
                "INSERT INTO rate_snapshots (ts, btc_rate, eth_rate, btc_price, eth_price, btc_index_price, eth_index_price) VALUES %s ON CONFLICT DO NOTHING",
                values[i:i + page],
                template="(%s, %s, %s, %s, %s, %s, %s)",
                page_size=page,
            )
            inserted += cur.rowcount
    pg_conn.commit()
    logger.info(f" PG: {inserted} 条写入")
def migrate_agg_trades(sqlite_conn, pg_conn):
    """Migrate every monthly SQLite table (agg_trades_YYYYMM) into the
    partitioned PostgreSQL agg_trades table.

    For each month table: ensure a matching PG range partition exists, then
    stream the rows over in batches with ON CONFLICT DO NOTHING so the
    migration is idempotent.

    :param sqlite_conn: open ``sqlite3`` connection (row_factory=Row)
    :param pg_conn:     open psycopg2 connection
    """
    import datetime  # hoisted: was re-imported on every loop iteration

    logger.info("=== 迁移 agg_trades ===")
    # 确保PG分区存在
    ensure_partitions()
    # 找所有SQLite月表
    tables = sqlite_conn.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'agg_trades_2%' ORDER BY name"
    ).fetchall()
    total = 0
    for t in tables:
        tname = t["name"]
        count = sqlite_conn.execute(f"SELECT COUNT(*) as c FROM {tname}").fetchone()["c"]
        logger.info(f" {tname}: {count:,}")
        if count == 0:
            continue
        # Derive the partition's [start, end) epoch-ms bounds from the
        # table-name suffix YYYYMM.
        month = tname.replace("agg_trades_", "")
        year = int(month[:4])
        mon = int(month[4:])
        start = datetime.datetime(year, mon, 1)
        if mon == 12:
            end = datetime.datetime(year + 1, 1, 1)
        else:
            end = datetime.datetime(year, mon + 1, 1)
        # NOTE(review): naive datetime -> .timestamp() uses the *local*
        # timezone; confirm ensure_partitions() computes bounds the same way,
        # otherwise month boundaries may misalign by the UTC offset.
        start_ms = int(start.timestamp() * 1000)
        end_ms = int(end.timestamp() * 1000)
        with pg_conn.cursor() as cur:
            try:
                cur.execute(f"""
                CREATE TABLE IF NOT EXISTS {tname}
                PARTITION OF agg_trades
                FOR VALUES FROM ({start_ms}) TO ({end_ms})
                """)
                pg_conn.commit()
            except Exception:
                # Partition likely exists already (possibly via
                # ensure_partitions with different bounds); roll back and let
                # the partitioned INSERT route rows itself.
                pg_conn.rollback()
        total += _copy_month_table(sqlite_conn, pg_conn, tname, count)
    logger.info(f" agg_trades 总计: {total:,}")


def _copy_month_table(sqlite_conn, pg_conn, tname, count):
    """Copy one SQLite month table into PG in batches; return rows written."""
    offset = 0
    batch_size = 10000
    page = 5000
    batch_total = 0
    while True:
        rows = sqlite_conn.execute(
            f"SELECT agg_id, symbol, price, qty, time_ms, is_buyer_maker FROM {tname} "
            f"ORDER BY agg_id LIMIT {batch_size} OFFSET {offset}"
        ).fetchall()
        if not rows:
            break
        values = [(r["agg_id"], r["symbol"], r["price"], r["qty"], r["time_ms"], r["is_buyer_maker"]) for r in rows]
        with pg_conn.cursor() as cur:
            # BUGFIX: execute_values() only reports cur.rowcount for its
            # *last* page; with batch_size (10000) > page_size (5000) the old
            # code counted at most half of each batch. Page manually and sum.
            for i in range(0, len(values), page):
                psycopg2.extras.execute_values(
                    cur,
                    "INSERT INTO agg_trades (agg_id, symbol, price, qty, time_ms, is_buyer_maker) VALUES %s "
                    "ON CONFLICT (time_ms, symbol, agg_id) DO NOTHING",
                    values[i:i + page],
                    template="(%s, %s, %s, %s, %s, %s)",
                    page_size=page,
                )
                batch_total += cur.rowcount
        pg_conn.commit()
        offset += batch_size
        if offset % 100000 == 0:
            logger.info(f" {tname}: {offset:,}/{count:,} ({offset/count*100:.0f}%)")
    logger.info(f" {tname}: 完成,写入 {batch_total:,}")
    return batch_total
def migrate_meta(sqlite_conn, pg_conn):
    """Upsert per-symbol sync metadata from SQLite agg_trades_meta into PG.

    Missing source table (older SQLite files) is treated as "nothing to
    migrate". Commits once after all rows are upserted.

    :param sqlite_conn: open ``sqlite3`` connection (row_factory=Row)
    :param pg_conn:     open psycopg2 connection
    """
    logger.info("=== 迁移 agg_trades_meta ===")
    try:
        rows = sqlite_conn.execute("SELECT * FROM agg_trades_meta").fetchall()
    except Exception:
        # Best-effort: the meta table may not exist in older databases.
        rows = []
    with pg_conn.cursor() as cur:
        for r in rows:
            # BUGFIX: sqlite3.Row has no .get() method — the old code raised
            # AttributeError on the first row. Convert to a plain dict so
            # optional columns absent from older schemas default to None.
            rec = dict(r)
            cur.execute("""
            INSERT INTO agg_trades_meta (symbol, last_agg_id, last_time_ms, earliest_agg_id, earliest_time_ms, updated_at)
            VALUES (%s, %s, %s, %s, %s, %s)
            ON CONFLICT(symbol) DO UPDATE SET
                last_agg_id = EXCLUDED.last_agg_id,
                last_time_ms = EXCLUDED.last_time_ms,
                earliest_agg_id = EXCLUDED.earliest_agg_id,
                earliest_time_ms = EXCLUDED.earliest_time_ms
            """, (
                rec["symbol"],
                rec["last_agg_id"],
                rec.get("last_time_ms"),
                rec.get("earliest_agg_id"),
                rec.get("earliest_time_ms"),
                rec.get("updated_at"),
            ))
    pg_conn.commit()
    logger.info(f" {len(rows)} 条 meta 迁移完成")
def verify(sqlite_conn, pg_conn):
    """Log SQLite-vs-PG row counts for both migrated tables.

    Flags a warning marker whenever PG holds fewer rows than SQLite
    (PG may legitimately hold more if it already had data).
    """
    logger.info("=== 校验 ===")
    # rate_snapshots: single table on both sides.
    sq_count = sqlite_conn.execute("SELECT COUNT(*) as c FROM rate_snapshots").fetchone()["c"]
    with pg_conn.cursor() as cur:
        cur.execute("SELECT COUNT(*) FROM rate_snapshots")
        pg_count = cur.fetchone()[0]
    logger.info(f" rate_snapshots: SQLite={sq_count:,}, PG={pg_count:,} {'' if pg_count >= sq_count else '⚠️'}")
    # agg_trades: SQLite side is split across monthly tables — sum them all.
    tables = sqlite_conn.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'agg_trades_2%'"
    ).fetchall()
    sq_total = sum(
        sqlite_conn.execute(f"SELECT COUNT(*) as c FROM {t['name']}").fetchone()["c"]
        for t in tables
    )
    with pg_conn.cursor() as cur:
        cur.execute("SELECT COUNT(*) FROM agg_trades")
        pg_total = cur.fetchone()[0]
    logger.info(f" agg_trades: SQLite={sq_total:,}, PG={pg_total:,} {'' if pg_total >= sq_total else '⚠️'}")
def main():
    """Run the full SQLite -> PostgreSQL migration end to end.

    Order matters: schema init, rate snapshots, agg trades, meta, then a
    count-based verification pass.
    """
    logger.info("开始 SQLite → PostgreSQL 迁移")
    start = time.time()
    # 初始化PG
    init_schema()
    # 打开SQLite; Row factory gives name-based column access downstream.
    sqlite_conn = sqlite3.connect(SQLITE_PATH)
    sqlite_conn.row_factory = sqlite3.Row
    try:
        # 获取PG连接
        with get_sync_conn() as pg_conn:
            migrate_rate_snapshots(sqlite_conn, pg_conn)
            migrate_agg_trades(sqlite_conn, pg_conn)
            migrate_meta(sqlite_conn, pg_conn)
            verify(sqlite_conn, pg_conn)
    finally:
        # ROBUSTNESS: previously close() was skipped when any migration step
        # raised, leaking the SQLite handle.
        sqlite_conn.close()
    elapsed = time.time() - start
    logger.info(f"=== 迁移完成,耗时 {elapsed:.0f}秒 ===")


if __name__ == "__main__":
    main()