- auth.py: rewritten to use PostgreSQL via db.py (previously sqlite3)
- admin_cli.py: rewritten to use PostgreSQL
- migrate_auth_sqlite_to_pg.py: one-time migration script
- The SQLite database arb.db is no longer needed once the migration has run
171 lines
6.4 KiB
Python
171 lines
6.4 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
migrate_auth_sqlite_to_pg.py — 将SQLite中的auth相关表迁移到PG
|
||
运行前确保PG连接参数正确(db.py中的配置)
|
||
"""
|
||
import os, sys, sqlite3
|
||
|
||
sys.path.insert(0, os.path.dirname(__file__))
|
||
from db import get_sync_conn
|
||
|
||
# Legacy SQLite database: "arb.db" in the directory one level above this script.
SQLITE_PATH = os.path.join(os.path.dirname(__file__), os.pardir, "arb.db")
|
||
|
||
|
||
def _opt(row, key, default=None):
    """Return ``row[key]``, or *default* when the SQLite row has no such column."""
    return row[key] if key in row.keys() else default


def _create_auth_tables(pg, cur):
    """Create the auth tables that previously existed only in SQLite.

    Each statement is committed on its own, so a failure in one does not roll back
    tables created before it.
    """
    print("Creating auth tables in PG...")
    statements = (
        """CREATE TABLE IF NOT EXISTS subscriptions (
            user_id BIGINT PRIMARY KEY,
            tier TEXT NOT NULL DEFAULT 'free',
            expires_at TEXT
        )""",
        """CREATE TABLE IF NOT EXISTS invite_usage (
            id BIGSERIAL PRIMARY KEY,
            code_id BIGINT NOT NULL,
            user_id BIGINT NOT NULL,
            used_at TEXT
        )""",
        """CREATE TABLE IF NOT EXISTS refresh_tokens (
            id BIGSERIAL PRIMARY KEY,
            user_id BIGINT NOT NULL,
            token TEXT UNIQUE NOT NULL,
            expires_at TEXT NOT NULL,
            revoked INTEGER NOT NULL DEFAULT 0,
            created_at TEXT
        )""",
    )
    for stmt in statements:
        try:
            cur.execute(stmt)
            pg.commit()
        except Exception as e:
            pg.rollback()
            print(f"  Create table error: {e}")


def _add_missing_columns(pg, cur, table, columns):
    """Add each ``(name, definition)`` column to *table* unless it already exists.

    *table*, names and definitions are trusted constants from this script; they are
    interpolated into the DDL because identifiers cannot be bound as parameters.
    """
    for col, defn in columns:
        try:
            # IF NOT EXISTS (PostgreSQL 9.6+) makes re-runs a no-op, so any error
            # reaching the handler is a real problem and is reported, not swallowed.
            cur.execute(f"ALTER TABLE {table} ADD COLUMN IF NOT EXISTS {col} {defn}")
            pg.commit()
        except Exception as e:
            pg.rollback()
            print(f"  ALTER TABLE {table} ADD COLUMN {col} error: {e}")


def _fetch_sqlite(sq, table):
    """Return all rows of SQLite *table*, or None (after printing why) if it can't be read."""
    try:
        return sq.execute(f"SELECT * FROM {table}").fetchall()
    except sqlite3.Error as e:
        print(f"  {table} table error: {e}")
        return None


def _insert_rows(pg, cur, rows, sql, params_of, label_of):
    """Execute *sql* once per row, committing each row individually.

    A failing row is rolled back on its own, reported as ``<label_of(row)> error: ...``
    and skipped. Rows committed before it are kept; with a single commit at the end of
    the batch, that rollback would have discarded every earlier row.

    Returns the number of rows whose statement ran without error. With
    ``ON CONFLICT ... DO NOTHING`` this includes rows that were already present.
    """
    ok = 0
    for r in rows:
        try:
            cur.execute(sql, params_of(r))
            pg.commit()
        except Exception as e:
            pg.rollback()
            print(f"  {label_of(r)} error: {e}")
        else:
            ok += 1
    return ok


def migrate():
    """Copy the auth-related tables from the legacy SQLite database into PostgreSQL.

    Steps:
      1. Create ``subscriptions``, ``invite_usage`` and ``refresh_tokens`` in PG if missing.
      2. Add missing columns to PG ``users`` and upsert every SQLite user. Existing PG
         rows get password_hash / discord_id / role / banned overwritten.
      3. Add missing columns to PG ``invite_codes`` and insert codes whose id is new.
      4. Insert subscriptions, invite usage records and refresh tokens, skipping
         conflicts.
      5. Reset the ``id`` sequences so new rows do not collide with migrated ids.

    The SQLite ``users`` table is required: an error reading it propagates. The other
    SQLite tables are optional; an unreadable one is reported and skipped. Per-row
    failures are printed and skipped without affecting other rows.

    Returns None. All progress is printed to stdout. The PG connection comes from
    ``db.get_sync_conn()``; the commit/rollback calls here assume it is a
    non-autocommit DB-API connection (NOTE(review): confirm against db.py).
    """
    if not os.path.exists(SQLITE_PATH):
        print(f"SQLite DB not found: {SQLITE_PATH}")
        return

    sq = sqlite3.connect(SQLITE_PATH)
    sq.row_factory = sqlite3.Row
    try:
        with get_sync_conn() as pg:
            cur = pg.cursor()

            # 1. Tables that previously lived only in SQLite.
            _create_auth_tables(pg, cur)

            # 2. users: PG already has the table, possibly with an older layout.
            print("Migrating users...")
            _add_missing_columns(
                pg, cur, "users",
                [("discord_id", "TEXT"), ("banned", "INTEGER DEFAULT 0")],
            )
            rows = sq.execute("SELECT * FROM users").fetchall()
            ok = _insert_rows(
                pg, cur, rows,
                """INSERT INTO users (id, email, password_hash, discord_id, role, banned, created_at)
                   VALUES (%s, %s, %s, %s, %s, %s, %s)
                   ON CONFLICT (id) DO UPDATE SET
                       password_hash = EXCLUDED.password_hash,
                       discord_id = EXCLUDED.discord_id,
                       role = EXCLUDED.role,
                       banned = EXCLUDED.banned""",
                lambda r: (r["id"], r["email"], r["password_hash"], _opt(r, "discord_id"),
                           r["role"], r["banned"], r["created_at"]),
                lambda r: f"User {_opt(r, 'email')}",
            )
            print(f"  Migrated {ok}/{len(rows)} users")

            # 3. invite_codes: PG already has the table but may lack columns.
            print("Migrating invite_codes...")
            _add_missing_columns(
                pg, cur, "invite_codes",
                [("created_by", "INTEGER"), ("max_uses", "INTEGER DEFAULT 1"),
                 ("used_count", "INTEGER DEFAULT 0"), ("status", "TEXT DEFAULT 'active'"),
                 ("expires_at", "TEXT")],
            )
            rows = _fetch_sqlite(sq, "invite_codes")
            if rows is not None:
                ok = _insert_rows(
                    pg, cur, rows,
                    """INSERT INTO invite_codes (id, code, created_by, max_uses, used_count, status, expires_at, created_at)
                       VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                       ON CONFLICT (id) DO NOTHING""",
                    lambda r: (r["id"], r["code"], _opt(r, "created_by"),
                               r["max_uses"], r["used_count"], r["status"],
                               _opt(r, "expires_at"), _opt(r, "created_at")),
                    lambda r: f"Invite {_opt(r, 'code')}",
                )
                print(f"  Migrated {ok}/{len(rows)} invite codes")

            # 4a. invite_usage: created above but previously never populated. Ids are
            # kept so the sequence reset below stays consistent with the copied rows.
            print("Migrating invite_usage...")
            rows = _fetch_sqlite(sq, "invite_usage")
            if rows is not None:
                ok = _insert_rows(
                    pg, cur, rows,
                    "INSERT INTO invite_usage (id, code_id, user_id, used_at) "
                    "VALUES (%s, %s, %s, %s) ON CONFLICT (id) DO NOTHING",
                    lambda r: (r["id"], r["code_id"], r["user_id"], _opt(r, "used_at")),
                    lambda r: f"Invite usage {_opt(r, 'id')}",
                )
                print(f"  Migrated {ok}/{len(rows)} invite usages")

            # 4b. subscriptions
            print("Migrating subscriptions...")
            rows = _fetch_sqlite(sq, "subscriptions")
            if rows is not None:
                ok = _insert_rows(
                    pg, cur, rows,
                    "INSERT INTO subscriptions (user_id, tier, expires_at) VALUES (%s, %s, %s) ON CONFLICT(user_id) DO NOTHING",
                    lambda r: (r["user_id"], r["tier"], r["expires_at"]),
                    lambda r: f"Subscription for user {_opt(r, 'user_id')}",
                )
                print(f"  Migrated {ok}/{len(rows)} subscriptions")

            # 4c. refresh_tokens. Ids are regenerated by BIGSERIAL; the token value
            # is the conflict key. The token itself is never printed: it is a credential.
            print("Migrating refresh_tokens...")
            rows = _fetch_sqlite(sq, "refresh_tokens")
            if rows is not None:
                ok = _insert_rows(
                    pg, cur, rows,
                    "INSERT INTO refresh_tokens (user_id, token, expires_at, revoked, created_at) "
                    "VALUES (%s, %s, %s, %s, %s) ON CONFLICT(token) DO NOTHING",
                    lambda r: (r["user_id"], r["token"], r["expires_at"], r["revoked"],
                               _opt(r, "created_at")),
                    lambda r: f"Refresh token for user {_opt(r, 'user_id')}",
                )
                print(f"  Migrated {ok}/{len(rows)} refresh tokens")

            # 5. Sequences: ids were inserted explicitly, so move each sequence past them.
            print("Resetting sequences...")
            for table in ("users", "invite_codes", "invite_usage", "refresh_tokens"):
                try:
                    # setval(seq, max + 1, false) makes the next nextval() return
                    # max + 1, and 1 for an empty table. The old
                    # COALESCE(MAX(id), 1) with is_called=true skipped id 1.
                    # pg_get_serial_sequence yields NULL for a non-serial id, and
                    # setval(NULL, ...) is then a harmless no-op.
                    cur.execute(
                        "SELECT setval(pg_get_serial_sequence(%s, 'id'), "
                        f"COALESCE(MAX(id), 0) + 1, false) FROM {table}",
                        (table,),
                    )
                    pg.commit()
                except Exception as e:
                    pg.rollback()
                    print(f"  {table} sequence reset error: {e}")
    finally:
        # Close SQLite even if a step raises (previously it leaked on any exception).
        sq.close()

    print("\nAuth migration complete!")
|
||
|
||
|
||
if __name__ == "__main__":
    # Run the one-off migration only when executed as a script, never on import.
    migrate()
|