#!/usr/bin/env python3
"""
migrate_auth_sqlite_to_pg.py — migrate the auth-related tables from SQLite to PG.

Before running, make sure the PG connection parameters (configured in db.py)
are correct.

Every row is committed on its own: a failing row is rolled back and reported
without discarding the rows copied before it. Every INSERT carries an
ON CONFLICT clause, so re-running does not duplicate rows that conflict on
those keys.
"""
import os
import sqlite3
import sys
from contextlib import closing

sys.path.insert(0, os.path.dirname(__file__))
from db import get_sync_conn  # noqa: E402  (needs the sys.path tweak above)

SQLITE_PATH = os.path.join(os.path.dirname(__file__), "..", "arb.db")

# Auth tables that may not exist in PG yet; each statement runs in its own
# transaction so one failure cannot undo the others.
_AUTH_TABLE_DDL = (
    """CREATE TABLE IF NOT EXISTS subscriptions (
        user_id BIGINT PRIMARY KEY,
        tier TEXT NOT NULL DEFAULT 'free',
        expires_at TEXT
    )""",
    """CREATE TABLE IF NOT EXISTS invite_usage (
        id BIGSERIAL PRIMARY KEY,
        code_id BIGINT NOT NULL,
        user_id BIGINT NOT NULL,
        used_at TEXT
    )""",
    """CREATE TABLE IF NOT EXISTS refresh_tokens (
        id BIGSERIAL PRIMARY KEY,
        user_id BIGINT NOT NULL,
        token TEXT UNIQUE NOT NULL,
        expires_at TEXT NOT NULL,
        revoked INTEGER NOT NULL DEFAULT 0,
        created_at TEXT
    )""",
)

# Tables whose `id` sequence is realigned after rows were copied with explicit ids.
_SEQUENCE_TABLES = ("users", "invite_codes", "invite_usage", "refresh_tokens")


def _optional(row, key, default=None):
    """Return ``row[key]``, or *default* when the SQLite row lacks that column."""
    return row[key] if key in row.keys() else default


def _exec_and_commit(pg, cur, sql, label):
    """Run one standalone statement in its own transaction.

    On failure the error is printed and the transaction is rolled back, so the
    connection stays usable for the next statement.
    """
    try:
        cur.execute(sql)
        pg.commit()
    except Exception as e:
        pg.rollback()
        print(f" {label} error: {e}")


def _add_missing_columns(pg, cur, table, columns):
    """Add each ``(name, definition)`` column to PG *table* unless it already exists."""
    for col, defn in columns:
        _exec_and_commit(
            pg, cur,
            f"ALTER TABLE {table} ADD COLUMN IF NOT EXISTS {col} {defn}",
            f"{table}.{col}",
        )


def _fetch_all(sq, table, label):
    """Return every row of SQLite *table*, or None (after printing) if it can't be read."""
    try:
        return sq.execute(f"SELECT * FROM {table}").fetchall()
    except sqlite3.Error as e:
        print(f" {label} error: {e}")
        return None


def _copy_rows(pg, cur, rows, sql, make_params, describe):
    """Insert every SQLite row into PG with *sql*, one transaction per row.

    ``make_params(row)`` builds the parameter tuple for *sql*; ``describe(row)``
    names the row in error messages. Returns ``(copied, failed)``.
    """
    copied = failed = 0
    for r in rows:
        try:
            cur.execute(sql, make_params(r))
            pg.commit()
        except Exception as e:
            # Only this row is discarded: earlier rows are already committed.
            pg.rollback()
            failed += 1
            print(f" {describe(r)} error: {e}")
        else:
            copied += 1
    return copied, failed


def _report(copied, failed, noun):
    """Print the per-table summary line."""
    suffix = f" ({failed} failed)" if failed else ""
    print(f" Migrated {copied} {noun}{suffix}")


def _create_auth_tables(pg, cur):
    """Step 1: create the auth tables in PG if they are missing."""
    print("Creating auth tables in PG...")
    for stmt in _AUTH_TABLE_DDL:
        _exec_and_commit(pg, cur, stmt, "Create table")


def _migrate_users(pg, cur, sq):
    """Step 2: upsert users (PG may already have an older, possibly empty users table)."""
    print("Migrating users...")
    _add_missing_columns(pg, cur, "users", [
        ("discord_id", "TEXT"),
        ("banned", "INTEGER DEFAULT 0"),
    ])
    rows = sq.execute("SELECT * FROM users").fetchall()
    copied, failed = _copy_rows(
        pg, cur, rows,
        """INSERT INTO users (id, email, password_hash, discord_id, role, banned, created_at)
           VALUES (%s, %s, %s, %s, %s, %s, %s)
           ON CONFLICT (id) DO UPDATE SET
               password_hash = EXCLUDED.password_hash,
               discord_id = EXCLUDED.discord_id,
               role = EXCLUDED.role,
               banned = EXCLUDED.banned""",
        lambda r: (
            r["id"], r["email"], r["password_hash"],
            _optional(r, "discord_id"),
            r["role"],
            _optional(r, "banned", 0),  # same default as the PG column above
            r["created_at"],
        ),
        lambda r: f"User {r['email']}",
    )
    _report(copied, failed, "users")


def _migrate_invite_codes(pg, cur, sq):
    """Step 3: copy invite codes (PG may already have the table but miss columns)."""
    print("Migrating invite_codes...")
    _add_missing_columns(pg, cur, "invite_codes", [
        ("created_by", "INTEGER"),
        ("max_uses", "INTEGER DEFAULT 1"),
        ("used_count", "INTEGER DEFAULT 0"),
        ("status", "TEXT DEFAULT 'active'"),
        ("expires_at", "TEXT"),
    ])
    rows = _fetch_all(sq, "invite_codes", "invite_codes table")
    if rows is None:
        return
    copied, failed = _copy_rows(
        pg, cur, rows,
        """INSERT INTO invite_codes
               (id, code, created_by, max_uses, used_count, status, expires_at, created_at)
           VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
           ON CONFLICT (id) DO NOTHING""",
        # Defaults for absent SQLite columns mirror the PG column DEFAULTs above.
        lambda r: (
            r["id"], r["code"],
            _optional(r, "created_by"),
            _optional(r, "max_uses", 1),
            _optional(r, "used_count", 0),
            _optional(r, "status", "active"),
            _optional(r, "expires_at"),
            _optional(r, "created_at"),
        ),
        lambda r: f"Invite {r['code']}",
    )
    _report(copied, failed, "invite codes")


def _migrate_subscriptions(pg, cur, sq):
    """Step 4: copy subscriptions, keeping any row already present in PG."""
    print("Migrating subscriptions...")
    rows = _fetch_all(sq, "subscriptions", "subscriptions")
    if rows is None:
        return
    copied, failed = _copy_rows(
        pg, cur, rows,
        "INSERT INTO subscriptions (user_id, tier, expires_at) VALUES (%s, %s, %s) "
        "ON CONFLICT(user_id) DO NOTHING",
        lambda r: (r["user_id"], r["tier"], r["expires_at"]),
        lambda r: f"Subscription for user {r['user_id']}",
    )
    _report(copied, failed, "subscriptions")


def _migrate_refresh_tokens(pg, cur, sq):
    """Step 5: copy refresh tokens; PG assigns fresh ids, duplicates by token are skipped."""
    print("Migrating refresh_tokens...")
    rows = _fetch_all(sq, "refresh_tokens", "refresh_tokens")
    if rows is None:
        return
    copied, failed = _copy_rows(
        pg, cur, rows,
        "INSERT INTO refresh_tokens (user_id, token, expires_at, revoked, created_at) "
        "VALUES (%s, %s, %s, %s, %s) ON CONFLICT(token) DO NOTHING",
        lambda r: (
            r["user_id"], r["token"], r["expires_at"], r["revoked"],
            _optional(r, "created_at"),
        ),
        # Never print the token itself: it is a credential.
        lambda r: f"Refresh token for user {r['user_id']}",
    )
    _report(copied, failed, "refresh tokens")


def _reset_sequences(pg, cur):
    """Step 6: point each serial `id` sequence just past the current MAX(id).

    With is_called = (MAX(id) IS NOT NULL), an empty table's next id is 1 and
    a non-empty table's next id is MAX(id) + 1. Tables whose `id` is not
    serial get a NULL sequence name, and setval(NULL, ...) is a no-op.
    """
    print("Resetting sequences...")
    for table in _SEQUENCE_TABLES:
        _exec_and_commit(
            pg, cur,
            f"SELECT setval(pg_get_serial_sequence('{table}', 'id'), "
            f"COALESCE(MAX(id), 1), MAX(id) IS NOT NULL) FROM {table}",
            f"{table} sequence",
        )


def migrate():
    """Copy users, invite codes, subscriptions and refresh tokens from SQLite into PG.

    Prints a message and returns when the SQLite file does not exist. A missing
    optional SQLite table (invite_codes, subscriptions, refresh_tokens) is
    reported and skipped; per-row failures are reported and counted.
    """
    if not os.path.exists(SQLITE_PATH):
        print(f"SQLite DB not found: {SQLITE_PATH}")
        return

    # closing() guarantees the SQLite handle is released even if a step raises.
    with closing(sqlite3.connect(SQLITE_PATH)) as sq:
        sq.row_factory = sqlite3.Row
        with get_sync_conn() as pg:
            cur = pg.cursor()
            _create_auth_tables(pg, cur)
            _migrate_users(pg, cur, sq)
            _migrate_invite_codes(pg, cur, sq)
            _migrate_subscriptions(pg, cur, sq)
            _migrate_refresh_tokens(pg, cur, sq)
            # NOTE(review): invite_usage is created and its sequence reset, but
            # no rows are copied from SQLite — confirm that is intended.
            _reset_sequences(pg, cur)

    print("\nAuth migration complete!")


if __name__ == "__main__":
    migrate()