arbitrage-engine/archive/migrate_auth_sqlite_to_pg.py.archived

171 lines
6.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
migrate_auth_sqlite_to_pg.py — migrate the auth-related tables from SQLite to PostgreSQL.

Before running, make sure the PG connection parameters (the configuration in db.py) are correct.
"""
import os, sys, sqlite3
# Put this script's own directory first on sys.path so the sibling `db` module
# (which provides get_sync_conn) can be imported. Must run before the import below.
sys.path.insert(0, os.path.dirname(__file__))
from db import get_sync_conn
# Source SQLite database: arb.db located one directory above this script.
SQLITE_PATH = os.path.join(os.path.dirname(__file__), "..", "arb.db")
def migrate():
    """Copy the auth-related tables from the SQLite database into PostgreSQL.

    Steps (each is best-effort: a failure is printed and the next step runs):

    1. ``CREATE TABLE IF NOT EXISTS`` for ``subscriptions``, ``invite_usage``
       and ``refresh_tokens`` in PG.
    2. Add missing PG columns to ``users``, then upsert every SQLite user
       (an existing id gets password_hash/discord_id/role/banned overwritten).
    3. Add missing PG columns to ``invite_codes``, then insert the SQLite
       invite codes, skipping ids that already exist.
    4. Insert ``subscriptions``, skipping user_ids that already exist.
    5. Insert ``refresh_tokens``, skipping tokens that already exist.
    6. Move each table's ``id`` sequence past the highest id now present.

    Every row is committed on its own, so one bad row is reported and skipped
    without discarding the rows copied before it.

    If ``SQLITE_PATH`` does not exist, prints a message and returns without
    connecting to PostgreSQL. Always returns None.
    """
    if not os.path.exists(SQLITE_PATH):
        print(f"SQLite DB not found: {SQLITE_PATH}")
        return
    sq = sqlite3.connect(SQLITE_PATH)
    sq.row_factory = sqlite3.Row  # rows addressable by column name
    try:
        with get_sync_conn() as pg:
            cur = pg.cursor()
            _create_auth_tables(pg, cur)
            _migrate_users(sq, pg, cur)
            _migrate_invite_codes(sq, pg, cur)
            _migrate_subscriptions(sq, pg, cur)
            _migrate_refresh_tokens(sq, pg, cur)
            # NOTE(review): invite_usage is created and its sequence is reset,
            # but no rows are copied from SQLite — confirm that is intentional.
            _reset_sequences(pg, cur)
    finally:
        # Release the SQLite handle even if a PG step raises.
        sq.close()
    print("\nAuth migration complete!")


# DDL for the auth tables that may not exist in PG yet. Kept as separate
# statements so each can be committed (or rolled back) independently.
_AUTH_TABLE_DDL = (
    """CREATE TABLE IF NOT EXISTS subscriptions (
        user_id BIGINT PRIMARY KEY,
        tier TEXT NOT NULL DEFAULT 'free',
        expires_at TEXT
    )""",
    """CREATE TABLE IF NOT EXISTS invite_usage (
        id BIGSERIAL PRIMARY KEY,
        code_id BIGINT NOT NULL,
        user_id BIGINT NOT NULL,
        used_at TEXT
    )""",
    """CREATE TABLE IF NOT EXISTS refresh_tokens (
        id BIGSERIAL PRIMARY KEY,
        user_id BIGINT NOT NULL,
        token TEXT UNIQUE NOT NULL,
        expires_at TEXT NOT NULL,
        revoked INTEGER NOT NULL DEFAULT 0,
        created_at TEXT
    )""",
)

# Tables whose ``id`` sequence is realigned after the copy.
_SERIAL_TABLES = ("users", "invite_codes", "invite_usage", "refresh_tokens")


def _optional(row, key, default=None):
    """Return ``row[key]``, or ``default`` when the SQLite row has no such column."""
    return row[key] if key in row.keys() else default


def _create_auth_tables(pg, cur):
    """Create the auth tables in PG, committing each statement separately."""
    print("Creating auth tables in PG...")
    for ddl in _AUTH_TABLE_DDL:
        try:
            cur.execute(ddl)
            pg.commit()
        except Exception as e:
            # Rolling back here only undoes this statement, because every
            # earlier one has already been committed.
            pg.rollback()
            print(f" Create table error: {e}")


def _ensure_columns(pg, cur, table, columns):
    """Add each ``(name, definition)`` column to PG ``table`` unless it already exists."""
    for col, defn in columns:
        try:
            # IF NOT EXISTS (PostgreSQL 9.6+) makes reruns a no-op instead of an
            # error, so anything that does get caught is worth reporting.
            cur.execute(f"ALTER TABLE {table} ADD COLUMN IF NOT EXISTS {col} {defn}")
            pg.commit()
        except Exception as e:
            pg.rollback()
            print(f" Could not add column {table}.{col}: {e}")


def _copy_rows(sq, pg, cur, table, *, sql, to_params, describe, noun):
    """Copy every row of SQLite ``table`` into PG using the statement ``sql``.

    ``to_params`` maps a SQLite row to ``sql``'s parameter tuple; ``describe``
    maps a row to a short label used in error messages. Each row is committed
    on its own (auth tables are small, so the per-row commit cost is accepted)
    so that rolling back a failed row never discards earlier successful rows.
    Prints how many rows were migrated out of how many were read.
    """
    try:
        rows = sq.execute(f"SELECT * FROM {table}").fetchall()
    except sqlite3.Error as e:
        print(f" {table} table error: {e}")
        return
    migrated = 0
    for r in rows:
        try:
            cur.execute(sql, to_params(r))
            pg.commit()
        except Exception as e:
            pg.rollback()
            print(f" {describe(r)} error: {e}")
        else:
            migrated += 1
    print(f" Migrated {migrated}/{len(rows)} {noun}")


def _migrate_users(sq, pg, cur):
    """Upsert SQLite users into PG, adding the discord_id/banned columns first."""
    print("Migrating users...")
    _ensure_columns(pg, cur, "users", [("discord_id", "TEXT"), ("banned", "INTEGER DEFAULT 0")])
    _copy_rows(
        sq, pg, cur, "users",
        sql="""INSERT INTO users (id, email, password_hash, discord_id, role, banned, created_at)
               VALUES (%s, %s, %s, %s, %s, %s, %s)
               ON CONFLICT (id) DO UPDATE SET
                   password_hash = EXCLUDED.password_hash,
                   discord_id = EXCLUDED.discord_id,
                   role = EXCLUDED.role,
                   banned = EXCLUDED.banned""",
        to_params=lambda r: (
            r["id"], r["email"], r["password_hash"],
            _optional(r, "discord_id"), r["role"],
            _optional(r, "banned", 0), _optional(r, "created_at"),
        ),
        describe=lambda r: f"User {_optional(r, 'email')}",
        noun="users",
    )


def _migrate_invite_codes(sq, pg, cur):
    """Insert SQLite invite codes into PG, adding any missing columns first."""
    print("Migrating invite_codes...")
    _ensure_columns(pg, cur, "invite_codes", [
        ("created_by", "INTEGER"),
        ("max_uses", "INTEGER DEFAULT 1"),
        ("used_count", "INTEGER DEFAULT 0"),
        ("status", "TEXT DEFAULT 'active'"),
        ("expires_at", "TEXT"),
    ])
    _copy_rows(
        sq, pg, cur, "invite_codes",
        sql="""INSERT INTO invite_codes (id, code, created_by, max_uses, used_count, status, expires_at, created_at)
               VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
               ON CONFLICT (id) DO NOTHING""",
        # Missing SQLite columns fall back to the same defaults as the PG columns above.
        to_params=lambda r: (
            r["id"], r["code"], _optional(r, "created_by"),
            _optional(r, "max_uses", 1), _optional(r, "used_count", 0),
            _optional(r, "status", "active"),
            _optional(r, "expires_at"), _optional(r, "created_at"),
        ),
        describe=lambda r: f"Invite {_optional(r, 'code')}",
        noun="invite codes",
    )


def _migrate_subscriptions(sq, pg, cur):
    """Insert SQLite subscriptions into PG, keeping any existing PG row per user."""
    print("Migrating subscriptions...")
    _copy_rows(
        sq, pg, cur, "subscriptions",
        sql="INSERT INTO subscriptions (user_id, tier, expires_at) VALUES (%s, %s, %s) "
            "ON CONFLICT(user_id) DO NOTHING",
        to_params=lambda r: (r["user_id"], r["tier"], _optional(r, "expires_at")),
        describe=lambda r: f"Subscription for user {_optional(r, 'user_id')}",
        noun="subscriptions",
    )


def _migrate_refresh_tokens(sq, pg, cur):
    """Insert SQLite refresh tokens into PG, skipping tokens already present."""
    print("Migrating refresh_tokens...")
    _copy_rows(
        sq, pg, cur, "refresh_tokens",
        sql="INSERT INTO refresh_tokens (user_id, token, expires_at, revoked, created_at) "
            "VALUES (%s, %s, %s, %s, %s) ON CONFLICT(token) DO NOTHING",
        to_params=lambda r: (
            r["user_id"], r["token"], r["expires_at"],
            _optional(r, "revoked", 0), _optional(r, "created_at"),
        ),
        # Label by user only: the token itself is a credential and must not be logged.
        describe=lambda r: f"Refresh token for user {_optional(r, 'user_id')}",
        noun="refresh tokens",
    )


def _reset_sequences(pg, cur):
    """Point each serial ``id`` sequence at MAX(id) + 1 so new inserts don't collide."""
    print("Resetting sequences...")
    for table in _SERIAL_TABLES:
        try:
            # is_called = false means the next nextval() returns exactly this
            # value: MAX(id) + 1, or 1 for an empty table. pg_get_serial_sequence
            # yields NULL (and setval then does nothing) if id has no sequence.
            cur.execute(
                f"SELECT setval(pg_get_serial_sequence('{table}', 'id'), "
                f"COALESCE(MAX(id), 0) + 1, false) FROM {table}"
            )
            pg.commit()
        except Exception as e:
            pg.rollback()
            print(f" {table} sequence reset error: {e}")
if __name__ == "__main__":
    # Run the migration only when executed as a script, never on import.
    migrate()