216 lines
6.9 KiB
Plaintext
216 lines
6.9 KiB
Plaintext
"""
|
||
migrate_sqlite_to_pg.py — 将SQLite数据迁移到PostgreSQL
|
||
|
||
用法:
|
||
python3 backend/migrate_sqlite_to_pg.py
|
||
|
||
迁移顺序:
|
||
1. 创建PG schema + 分区
|
||
2. 迁移 rate_snapshots
|
||
3. 迁移 agg_trades(按月表)
|
||
4. 迁移 agg_trades_meta
|
||
5. 校验计数
|
||
"""
|
||
|
||
import os
|
||
import sqlite3
|
||
import time
|
||
import logging
|
||
|
||
import psycopg2
|
||
import psycopg2.extras
|
||
|
||
from db import get_sync_conn, init_schema, ensure_partitions
|
||
|
||
# Root logging: INFO and above, each line timestamped with its level in brackets.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("migrate")

# Source SQLite database: ``arb.db`` one directory above this script's directory.
_SCRIPT_DIR = os.path.dirname(__file__)
SQLITE_PATH = os.path.join(_SCRIPT_DIR, "..", "arb.db")
|
||
|
||
|
||
def migrate_rate_snapshots(sqlite_conn, pg_conn):
    """Copy every row of the SQLite ``rate_snapshots`` table into PostgreSQL.

    Rows that hit a unique-constraint conflict in PG are skipped via
    ``ON CONFLICT DO NOTHING``, so the function can be re-run.

    Args:
        sqlite_conn: source ``sqlite3`` connection; rows are read by column
            name, so it must use ``sqlite3.Row`` as its row factory.
        pg_conn: destination psycopg2 connection; committed once at the end.
    """
    logger.info("=== 迁移 rate_snapshots ===")
    rows = sqlite_conn.execute(
        "SELECT ts, btc_rate, eth_rate, btc_price, eth_price, btc_index_price, eth_index_price "
        "FROM rate_snapshots ORDER BY ts"
    ).fetchall()
    logger.info(f" SQLite: {len(rows)} 条")

    if not rows:
        return

    values = [
        (r["ts"], r["btc_rate"], r["eth_rate"], r["btc_price"], r["eth_price"],
         r["btc_index_price"], r["eth_index_price"])
        for r in rows
    ]

    insert_sql = (
        "INSERT INTO rate_snapshots (ts, btc_rate, eth_rate, btc_price, eth_price, "
        "btc_index_price, eth_index_price) VALUES %s ON CONFLICT DO NOTHING"
    )
    page_size = 5000
    inserted = 0
    with pg_conn.cursor() as cur:
        # execute_values issues one INSERT per page and leaves cur.rowcount with
        # only the last page's count. Feed it one page at a time so each call is
        # a single statement, and sum the counts to get the true total.
        for offset in range(0, len(values), page_size):
            psycopg2.extras.execute_values(
                cur,
                insert_sql,
                values[offset:offset + page_size],
                template="(%s, %s, %s, %s, %s, %s, %s)",
                page_size=page_size,
            )
            inserted += cur.rowcount
    pg_conn.commit()
    logger.info(f" PG: {inserted} 条写入")
|
||
|
||
|
||
def _month_bounds_ms(month):
    """Return the half-open ``[start_ms, end_ms)`` epoch-ms range of month ``YYYYMM``.

    Bounds are computed in UTC so they do not depend on the host's local
    timezone. NOTE(review): confirm ensure_partitions() uses the same UTC
    month boundaries.
    """
    import datetime

    utc = datetime.timezone.utc
    year, mon = int(month[:4]), int(month[4:])
    start = datetime.datetime(year, mon, 1, tzinfo=utc)
    if mon == 12:
        end = datetime.datetime(year + 1, 1, 1, tzinfo=utc)
    else:
        end = datetime.datetime(year, mon + 1, 1, tzinfo=utc)
    return int(start.timestamp() * 1000), int(end.timestamp() * 1000)


def _ensure_month_partition(pg_conn, tname, month):
    """Best-effort ``CREATE TABLE IF NOT EXISTS`` of one month's ``agg_trades`` partition.

    A psycopg2 error (for example, a range overlapping an existing partition)
    is rolled back and logged rather than raised. The copy then proceeds with
    whatever partitions already exist.
    """
    start_ms, end_ms = _month_bounds_ms(month)
    try:
        with pg_conn.cursor() as cur:
            cur.execute(f"""
                CREATE TABLE IF NOT EXISTS {tname}
                PARTITION OF agg_trades
                FOR VALUES FROM ({start_ms}) TO ({end_ms})
            """)
        pg_conn.commit()
    except psycopg2.Error as exc:
        pg_conn.rollback()
        logger.warning(f" {tname}: 创建PG分区失败,继续迁移: {exc}")


def _copy_month_table(sqlite_conn, pg_conn, tname, count, batch_size=10000):
    """Stream one SQLite month table into PG ``agg_trades``; return rows inserted.

    ``count`` is the table's row count and is used only for progress logging.
    Each batch is committed separately.
    """
    insert_sql = (
        "INSERT INTO agg_trades (agg_id, symbol, price, qty, time_ms, is_buyer_maker) VALUES %s "
        "ON CONFLICT (time_ms, symbol, agg_id) DO NOTHING"
    )
    # Read with a single SELECT and fetchmany() instead of LIMIT/OFFSET paging.
    # OFFSET rescans every skipped row on each batch. agg_id alone may also not
    # be unique (the PG key is (time_ms, symbol, agg_id)), so it is not a stable
    # page order across separate queries.
    src = sqlite_conn.execute(
        f"SELECT agg_id, symbol, price, qty, time_ms, is_buyer_maker FROM {tname} "
        f"ORDER BY agg_id"
    )
    report_every = 100000
    next_report = report_every
    read = 0
    written = 0
    while True:
        rows = src.fetchmany(batch_size)
        if not rows:
            break

        values = [
            (r["agg_id"], r["symbol"], r["price"], r["qty"], r["time_ms"], r["is_buyer_maker"])
            for r in rows
        ]
        with pg_conn.cursor() as cur:
            # page_size covers the whole batch, so this runs as one INSERT and
            # cur.rowcount is the batch's full inserted count.
            psycopg2.extras.execute_values(
                cur,
                insert_sql,
                values,
                template="(%s, %s, %s, %s, %s, %s)",
                page_size=len(values),
            )
            written += cur.rowcount
        pg_conn.commit()

        read += len(rows)
        if read >= next_report:
            logger.info(f" {tname}: {read:,}/{count:,} ({read/count*100:.0f}%)")
            next_report += report_every
    return written


def migrate_agg_trades(sqlite_conn, pg_conn):
    """Copy every SQLite monthly ``agg_trades_YYYYMM`` table into PG ``agg_trades``.

    For each non-empty month table, this ensures a matching PG partition
    exists. It then streams the rows across in batches using
    ``ON CONFLICT (time_ms, symbol, agg_id) DO NOTHING``, so reruns skip rows
    already present. Tables whose suffix is not a valid ``YYYYMM`` are logged
    and skipped.

    Args:
        sqlite_conn: source connection using ``sqlite3.Row`` rows.
        pg_conn: destination psycopg2 connection.
    """
    logger.info("=== 迁移 agg_trades ===")

    # Make sure the PG partitions created by db.ensure_partitions() exist.
    ensure_partitions()

    # Find all SQLite month tables.
    tables = sqlite_conn.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'agg_trades_2%' ORDER BY name"
    ).fetchall()

    prefix = "agg_trades_"
    total = 0
    for t in tables:
        tname = t["name"]
        month = tname[len(prefix):]
        # LIKE treats '_' as a wildcard, so odd names can match. Requiring an
        # exact "agg_trades_YYYYMM" name also makes tname safe to interpolate
        # into the SQL used below.
        if not (tname.startswith(prefix) and len(month) == 6 and month.isdecimal()
                and 1 <= int(month[4:]) <= 12):
            logger.warning(f" {tname}: 表名不是 agg_trades_YYYYMM 格式,跳过")
            continue

        count = sqlite_conn.execute(f"SELECT COUNT(*) as c FROM {tname}").fetchone()["c"]
        logger.info(f" {tname}: {count:,} 条")
        if count == 0:
            continue

        _ensure_month_partition(pg_conn, tname, month)
        batch_total = _copy_month_table(sqlite_conn, pg_conn, tname, count)
        total += batch_total
        logger.info(f" {tname}: 完成,写入 {batch_total:,} 条")

    logger.info(f" agg_trades 总计: {total:,} 条")
|
||
|
||
|
||
def migrate_meta(sqlite_conn, pg_conn):
    """Upsert every row of SQLite ``agg_trades_meta`` into PostgreSQL.

    If the SQLite table cannot be read (e.g. it does not exist), a warning is
    logged and nothing is migrated. On a ``symbol`` conflict, the PG row's
    ``last_*`` / ``earliest_*`` columns are overwritten. ``updated_at`` is
    only written for newly inserted rows. All rows are committed together.

    Args:
        sqlite_conn: source connection using ``sqlite3.Row`` rows.
        pg_conn: destination psycopg2 connection.
    """
    logger.info("=== 迁移 agg_trades_meta ===")
    try:
        rows = sqlite_conn.execute("SELECT * FROM agg_trades_meta").fetchall()
    except sqlite3.OperationalError as exc:
        logger.warning(f" 读取 agg_trades_meta 失败,跳过: {exc}")
        rows = []

    with pg_conn.cursor() as cur:
        for row in rows:
            # sqlite3.Row has no .get(). A dict lets columns that may be absent
            # from older databases default to None.
            r = dict(row)
            cur.execute("""
                INSERT INTO agg_trades_meta (symbol, last_agg_id, last_time_ms, earliest_agg_id, earliest_time_ms, updated_at)
                VALUES (%s, %s, %s, %s, %s, %s)
                ON CONFLICT(symbol) DO UPDATE SET
                    last_agg_id = EXCLUDED.last_agg_id,
                    last_time_ms = EXCLUDED.last_time_ms,
                    earliest_agg_id = EXCLUDED.earliest_agg_id,
                    earliest_time_ms = EXCLUDED.earliest_time_ms
            """, (
                r["symbol"],
                r["last_agg_id"],
                r.get("last_time_ms"),
                r.get("earliest_agg_id"),
                r.get("earliest_time_ms"),
                r.get("updated_at"),
            ))
    pg_conn.commit()
    logger.info(f" {len(rows)} 条 meta 迁移完成")
|
||
|
||
|
||
def verify(sqlite_conn, pg_conn):
    """Log SQLite vs PostgreSQL row counts for the migrated tables.

    A table is marked ✅ when PG holds at least as many rows as SQLite, and ⚠️
    otherwise. Nothing is raised on a mismatch; the result is only logged.
    """
    logger.info("=== 校验 ===")

    def sqlite_count(table):
        # Row count of one SQLite table, read through the "c" alias.
        return sqlite_conn.execute(f"SELECT COUNT(*) as c FROM {table}").fetchone()["c"]

    def pg_scalar(sql):
        # First column of the first row of a PG query.
        with pg_conn.cursor() as cur:
            cur.execute(sql)
            return cur.fetchone()[0]

    def mark(pg_n, sq_n):
        return "✅" if pg_n >= sq_n else "⚠️"

    # rate_snapshots: one table on both sides.
    snap_sq = sqlite_count("rate_snapshots")
    snap_pg = pg_scalar("SELECT COUNT(*) FROM rate_snapshots")
    logger.info(f" rate_snapshots: SQLite={snap_sq:,}, PG={snap_pg:,} {mark(snap_pg, snap_sq)}")

    # agg_trades: the SQLite month tables are summed and compared with PG's
    # single agg_trades table.
    month_tables = sqlite_conn.execute(
        "SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'agg_trades_2%'"
    ).fetchall()
    trades_sq = sum(sqlite_count(entry["name"]) for entry in month_tables)
    trades_pg = pg_scalar("SELECT COUNT(*) FROM agg_trades")
    logger.info(f" agg_trades: SQLite={trades_sq:,}, PG={trades_pg:,} {mark(trades_pg, trades_sq)}")
|
||
|
||
|
||
def main():
    """Run the whole SQLite → PostgreSQL migration and log the elapsed time.

    Steps:
        1. Initialise the PG schema via ``init_schema()``.
        2. Migrate rate_snapshots, agg_trades and agg_trades_meta.
        3. Log count comparisons.

    Raises:
        FileNotFoundError: if no SQLite database exists at ``SQLITE_PATH``.
            This is checked up front because ``sqlite3.connect`` would
            otherwise silently create a new, empty database file there.
    """
    logger.info("开始 SQLite → PostgreSQL 迁移")
    start = time.time()

    # Fail before touching PG if the source database is missing.
    if not os.path.isfile(SQLITE_PATH):
        raise FileNotFoundError(f"SQLite 数据库不存在: {SQLITE_PATH}")

    init_schema()

    sqlite_conn = sqlite3.connect(SQLITE_PATH)
    # The migrate_* helpers read rows by column name.
    sqlite_conn.row_factory = sqlite3.Row
    try:
        with get_sync_conn() as pg_conn:
            migrate_rate_snapshots(sqlite_conn, pg_conn)
            migrate_agg_trades(sqlite_conn, pg_conn)
            migrate_meta(sqlite_conn, pg_conn)
            verify(sqlite_conn, pg_conn)
    finally:
        # Release the SQLite handle even if a migration step raises.
        sqlite_conn.close()

    elapsed = time.time() - start
    logger.info(f"=== 迁移完成,耗时 {elapsed:.0f}秒 ===")


if __name__ == "__main__":
    main()
|