diff --git a/backend/admin_cli.py b/backend/admin_cli.py index 5da6d5c..fbaf420 100644 --- a/backend/admin_cli.py +++ b/backend/admin_cli.py @@ -1,82 +1,90 @@ #!/usr/bin/env python3 -"""Admin CLI for Arbitrage Engine""" -import sys, os, sqlite3, secrets, json +"""Admin CLI for Arbitrage Engine (PostgreSQL version)""" +import sys, os, secrets -DB_PATH = os.path.join(os.path.dirname(__file__), "..", "arb.db") +sys.path.insert(0, os.path.dirname(__file__)) +from db import get_sync_conn -def get_conn(): - conn = sqlite3.connect(DB_PATH) - conn.row_factory = sqlite3.Row - return conn def gen_invite(count=1, max_uses=1): - conn = get_conn() - codes = [] - for _ in range(count): - code = secrets.token_urlsafe(6)[:8].upper() - conn.execute( - "INSERT INTO invite_codes (code, created_by, max_uses) VALUES (?, 1, ?)", - (code, max_uses) - ) - codes.append(code) - conn.commit() - conn.close() + with get_sync_conn() as conn: + with conn.cursor() as cur: + codes = [] + for _ in range(count): + code = secrets.token_urlsafe(6)[:8].upper() + cur.execute( + "INSERT INTO invite_codes (code, created_by, max_uses) VALUES (%s, 1, %s)", + (code, max_uses) + ) + codes.append(code) + conn.commit() for c in codes: print(f" {c}") print(f"\nGenerated {len(codes)} invite code(s)") + def list_invites(): - conn = get_conn() - rows = conn.execute("SELECT id, code, max_uses, used_count, status, created_at FROM invite_codes ORDER BY id DESC").fetchall() - conn.close() + with get_sync_conn() as conn: + with conn.cursor() as cur: + cur.execute("SELECT id, code, max_uses, used_count, status, created_at FROM invite_codes ORDER BY id DESC") + cols = [desc[0] for desc in cur.description] + rows = [dict(zip(cols, row)) for row in cur.fetchall()] if not rows: print("No invite codes found") return print(f"{'ID':>4} {'CODE':>10} {'MAX':>4} {'USED':>5} {'STATUS':>10} {'CREATED':>20}") print("-" * 60) for r in rows: - print(f"{r['id']:>4} {r['code']:>10} {r['max_uses']:>4} {r['used_count']:>5} {r['status']:>10} {r['created_at']:>20}") + print(f"{r['id']:>4} {r['code']:>10} {r['max_uses']:>4} {r['used_count']:>5} {r['status']:>10} {str(r['created_at']):>20}") + def disable_invite(code): - conn = get_conn() - conn.execute("UPDATE invite_codes SET status = 'disabled' WHERE code = ?", (code,)) - conn.commit() - conn.close() + with get_sync_conn() as conn: + with conn.cursor() as cur: + cur.execute("UPDATE invite_codes SET status = 'disabled' WHERE code = %s", (code,)) + conn.commit() print(f"Disabled invite code: {code}") + def list_users(): - conn = get_conn() - rows = conn.execute("SELECT id, email, role, banned, created_at FROM users ORDER BY id DESC").fetchall() - conn.close() + with get_sync_conn() as conn: + with conn.cursor() as cur: + cur.execute("SELECT id, email, role, banned, created_at FROM users ORDER BY id DESC") + cols = [desc[0] for desc in cur.description] + rows = [dict(zip(cols, row)) for row in cur.fetchall()] if not rows: print("No users found") return print(f"{'ID':>4} {'EMAIL':>30} {'ROLE':>6} {'BANNED':>7} {'CREATED':>20}") print("-" * 72) for r in rows: - print(f"{r['id']:>4} {r['email']:>30} {r['role']:>6} {r['banned']:>7} {r['created_at']:>20}") + print(f"{r['id']:>4} {r['email']:>30} {r['role']:>6} {r['banned']:>7} {str(r['created_at']):>20}") + def ban_user(user_id): - conn = get_conn() - conn.execute("UPDATE users SET banned = 1 WHERE id = ?", (user_id,)) - conn.commit() - conn.close() + with get_sync_conn() as conn: + with conn.cursor() as cur: + cur.execute("UPDATE users SET banned = 1 WHERE id = %s", (user_id,)) + conn.commit() print(f"Banned user ID: {user_id}") + def unban_user(user_id): - conn = get_conn() - conn.execute("UPDATE users SET banned = 0 WHERE id = ?", (user_id,)) - conn.commit() - conn.close() + with get_sync_conn() as conn: + with conn.cursor() as cur: + cur.execute("UPDATE users SET banned = 0 WHERE id = %s", (user_id,)) + conn.commit() print(f"Unbanned user ID: {user_id}") + def set_admin(user_id): - conn = get_conn() - conn.execute("UPDATE users SET role = 'admin' WHERE id = ?", (user_id,)) - conn.commit() - conn.close() + with get_sync_conn() as conn: + with conn.cursor() as cur: + cur.execute("UPDATE users SET role = 'admin' WHERE id = %s", (user_id,)) + conn.commit() print(f"Set user {user_id} as admin") + def usage(): print("""Usage: python3 admin_cli.py [args] diff --git a/backend/auth.py b/backend/auth.py index 651ea9e..5413a49 100644 --- a/backend/auth.py +++ b/backend/auth.py @@ -1,5 +1,4 @@ import os -import sqlite3 import hashlib import secrets import hmac @@ -7,12 +6,12 @@ import base64 import json from datetime import datetime, timedelta from typing import Optional -from functools import wraps from fastapi import APIRouter, HTTPException, Header, Depends, Request from pydantic import BaseModel, EmailStr -DB_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "arb.db") +from db import get_sync_conn + JWT_SECRET = os.getenv("JWT_SECRET", "arb-engine-jwt-secret-v2-2026") ACCESS_TOKEN_HOURS = 24 REFRESH_TOKEN_DAYS = 7 @@ -20,90 +19,98 @@ REFRESH_TOKEN_DAYS = 7 router = APIRouter(prefix="/api", tags=["auth"]) -# ─── DB helpers ─────────────────────────────────────────────── +# ─── PG Schema ─────────────────────────────────────────────── -def get_conn(): - conn = sqlite3.connect(DB_PATH) - conn.row_factory = sqlite3.Row - return conn +AUTH_SCHEMA = """ +CREATE TABLE IF NOT EXISTS users ( + id BIGSERIAL PRIMARY KEY, + email TEXT UNIQUE NOT NULL, + password_hash TEXT NOT NULL, + discord_id TEXT, + role TEXT NOT NULL DEFAULT 'user', + banned INTEGER NOT NULL DEFAULT 0, + created_at TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS subscriptions ( + user_id BIGINT PRIMARY KEY REFERENCES users(id), + tier TEXT NOT NULL DEFAULT 'free', + expires_at TEXT +); + +CREATE TABLE IF NOT EXISTS invite_codes ( + id BIGSERIAL PRIMARY KEY, + code TEXT UNIQUE NOT NULL, + created_by INTEGER, + max_uses INTEGER NOT NULL DEFAULT 1, + used_count INTEGER NOT NULL DEFAULT 0, + status TEXT NOT NULL DEFAULT 'active', + expires_at TEXT, + created_at TEXT DEFAULT (NOW()::TEXT) +); + +CREATE TABLE IF NOT EXISTS invite_usage ( + id BIGSERIAL PRIMARY KEY, + code_id BIGINT NOT NULL REFERENCES invite_codes(id), + user_id BIGINT NOT NULL REFERENCES users(id), + used_at TEXT DEFAULT (NOW()::TEXT) +); + +CREATE TABLE IF NOT EXISTS refresh_tokens ( + id BIGSERIAL PRIMARY KEY, + user_id BIGINT NOT NULL REFERENCES users(id), + token TEXT UNIQUE NOT NULL, + expires_at TEXT NOT NULL, + revoked INTEGER NOT NULL DEFAULT 0, + created_at TEXT DEFAULT (NOW()::TEXT) +); +""" def ensure_tables(): - conn = get_conn() - cur = conn.cursor() - cur.execute(""" - CREATE TABLE IF NOT EXISTS users ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - email TEXT UNIQUE NOT NULL, - password_hash TEXT NOT NULL, - discord_id TEXT, - role TEXT NOT NULL DEFAULT 'user', - banned INTEGER NOT NULL DEFAULT 0, - created_at TEXT NOT NULL - ) - """) - cur.execute(""" - CREATE TABLE IF NOT EXISTS subscriptions ( - user_id INTEGER PRIMARY KEY, - tier TEXT NOT NULL DEFAULT 'free', - expires_at TEXT, - FOREIGN KEY(user_id) REFERENCES users(id) - ) - """) - cur.execute(""" - CREATE TABLE IF NOT EXISTS signal_logs ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - symbol TEXT NOT NULL, - rate REAL NOT NULL, - annualized REAL NOT NULL, - sent_at TEXT NOT NULL, - message TEXT NOT NULL - ) - """) - cur.execute(""" - CREATE TABLE IF NOT EXISTS invite_codes ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - code TEXT UNIQUE NOT NULL, - created_by INTEGER, - max_uses INTEGER NOT NULL DEFAULT 1, - used_count INTEGER NOT NULL DEFAULT 0, - status TEXT NOT NULL DEFAULT 'active', - expires_at TEXT, - created_at TEXT DEFAULT (datetime('now')) - ) - """) - cur.execute(""" - CREATE TABLE IF NOT EXISTS invite_usage ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - code_id INTEGER NOT NULL, - user_id INTEGER NOT NULL, - used_at TEXT DEFAULT (datetime('now')), - FOREIGN KEY (code_id) REFERENCES invite_codes(id), - FOREIGN KEY (user_id) REFERENCES users(id) - ) - """) - cur.execute(""" - CREATE TABLE IF NOT EXISTS refresh_tokens ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - user_id INTEGER NOT NULL, - token TEXT UNIQUE NOT NULL, - expires_at TEXT NOT NULL, - revoked INTEGER NOT NULL DEFAULT 0, - created_at TEXT DEFAULT (datetime('now')), - FOREIGN KEY (user_id) REFERENCES users(id) - ) - """) - # migrate: add role/banned columns if missing - try: - cur.execute("ALTER TABLE users ADD COLUMN role TEXT NOT NULL DEFAULT 'user'") - except: - pass - try: - cur.execute("ALTER TABLE users ADD COLUMN banned INTEGER NOT NULL DEFAULT 0") - except: - pass - conn.commit() - conn.close() + with get_sync_conn() as conn: + with conn.cursor() as cur: + for stmt in AUTH_SCHEMA.split(";"): + stmt = stmt.strip() + if stmt: + try: + cur.execute(stmt) + except Exception: + conn.rollback() + continue + conn.commit() + + +# ─── DB helper ─────────────────────────────────────────────── + +def _fetchone(sql, params=None): + with get_sync_conn() as conn: + with conn.cursor() as cur: + cur.execute(sql, params or ()) + row = cur.fetchone() + if not row: + return None + cols = [desc[0] for desc in cur.description] + return dict(zip(cols, row)) + + +def _fetchall(sql, params=None): + with get_sync_conn() as conn: + with conn.cursor() as cur: + cur.execute(sql, params or ()) + cols = [desc[0] for desc in cur.description] + return [dict(zip(cols, row)) for row in cur.fetchall()] + + +def _execute(sql, params=None): + with get_sync_conn() as conn: + with conn.cursor() as cur: + cur.execute(sql, params or ()) + conn.commit() + try: + return cur.fetchone() + except Exception: + return None # ─── Password utils ────────────────────────────────────────── @@ -144,13 +151,10 @@ def create_access_token(user_id: int, email: str, role: str) -> str: def create_refresh_token(user_id: int) -> str: token = secrets.token_urlsafe(48) expires_at = (datetime.utcnow() + timedelta(days=REFRESH_TOKEN_DAYS)).isoformat() - conn = get_conn() - conn.execute( - "INSERT INTO refresh_tokens (user_id, token, expires_at) VALUES (?, ?, ?)", + _execute( + "INSERT INTO refresh_tokens (user_id, token, expires_at) VALUES (%s, %s, %s)", (user_id, token, expires_at) ) - conn.commit() - conn.close() return token @@ -173,7 +177,6 @@ def parse_token(token: str) -> Optional[dict]: # ─── Auth dependency ───────────────────────────────────────── def get_current_user(authorization: Optional[str] = Header(default=None)) -> dict: - """FastAPI dependency: returns user dict or raises 401""" if not authorization or not authorization.startswith("Bearer "): raise HTTPException(status_code=401, detail="missing token") token = authorization.split(" ", 1)[1].strip() @@ -182,18 +185,15 @@ def get_current_user(authorization: Optional[str] = Header(default=None)) -> dic raise HTTPException(status_code=401, detail="invalid or expired token") if payload.get("type") != "access": raise HTTPException(status_code=401, detail="invalid token type") - conn = get_conn() - user = conn.execute("SELECT * FROM users WHERE id = ?", (payload["sub"],)).fetchone() - conn.close() + user = _fetchone("SELECT * FROM users WHERE id = %s", (payload["sub"],)) if not user: raise HTTPException(status_code=401, detail="user not found") if user["banned"]: raise HTTPException(status_code=403, detail="account banned") - return dict(user) + return user def require_admin(user: dict = Depends(get_current_user)) -> dict: - """FastAPI dependency: requires admin role""" if user.get("role") != "admin": raise HTTPException(status_code=403, detail="admin required") return user @@ -230,60 +230,50 @@ class BanUserReq(BaseModel): @router.post("/auth/register") def register(body: RegisterReq): ensure_tables() - conn = get_conn() - - # verify invite code - invite = conn.execute( - "SELECT * FROM invite_codes WHERE code = ?", (body.invite_code,) - ).fetchone() + invite = _fetchone("SELECT * FROM invite_codes WHERE code = %s", (body.invite_code,)) if not invite: - conn.close() raise HTTPException(status_code=400, detail="invalid invite code") if invite["status"] != "active": - conn.close() raise HTTPException(status_code=400, detail="invite code disabled") if invite["used_count"] >= invite["max_uses"]: - conn.close() raise HTTPException(status_code=400, detail="invite code exhausted") if invite["expires_at"] and invite["expires_at"] < datetime.utcnow().isoformat(): - conn.close() raise HTTPException(status_code=400, detail="invite code expired") + pwd_hash = hash_password(body.password) try: - pwd_hash = hash_password(body.password) - cur = conn.execute( - "INSERT INTO users (email, password_hash, role, banned, created_at) VALUES (?, ?, 'user', 0, ?)", - (body.email.lower(), pwd_hash, datetime.utcnow().isoformat()), - ) - user_id = cur.lastrowid - conn.execute( - "INSERT OR REPLACE INTO subscriptions (user_id, tier, expires_at) VALUES (?, 'free', NULL)", - (user_id,), - ) - # record invite usage - conn.execute( - "INSERT INTO invite_usage (code_id, user_id) VALUES (?, ?)", - (invite["id"], user_id), - ) - conn.execute( - "UPDATE invite_codes SET used_count = used_count + 1 WHERE id = ?", - (invite["id"],), - ) - # auto-exhaust if max reached - new_count = invite["used_count"] + 1 - if new_count >= invite["max_uses"]: - conn.execute( - "UPDATE invite_codes SET status = 'exhausted' WHERE id = ?", - (invite["id"],), - ) - conn.commit() - except sqlite3.IntegrityError: - conn.close() - raise HTTPException(status_code=400, detail="email already registered") + with get_sync_conn() as conn: + with conn.cursor() as cur: + cur.execute( + "INSERT INTO users (email, password_hash, role, banned, created_at) VALUES (%s, %s, 'user', 0, %s) RETURNING id", + (body.email.lower(), pwd_hash, datetime.utcnow().isoformat()), + ) + user_id = cur.fetchone()[0] + cur.execute( + "INSERT INTO subscriptions (user_id, tier, expires_at) VALUES (%s, 'free', NULL) ON CONFLICT(user_id) DO NOTHING", + (user_id,), + ) + cur.execute( + "INSERT INTO invite_usage (code_id, user_id) VALUES (%s, %s)", + (invite["id"], user_id), + ) + cur.execute( + "UPDATE invite_codes SET used_count = used_count + 1 WHERE id = %s", + (invite["id"],), + ) + new_count = invite["used_count"] + 1 + if new_count >= invite["max_uses"]: + cur.execute( + "UPDATE invite_codes SET status = 'exhausted' WHERE id = %s", + (invite["id"],), + ) + conn.commit() + except Exception as e: + if "unique" in str(e).lower() or "duplicate" in str(e).lower(): + raise HTTPException(status_code=400, detail="email already registered") + raise HTTPException(status_code=500, detail=str(e)) - # issue tokens - user = conn.execute("SELECT * FROM users WHERE id = ?", (user_id,)).fetchone() - conn.close() + user = _fetchone("SELECT * FROM users WHERE id = %s", (user_id,)) access = create_access_token(user["id"], user["email"], user["role"]) refresh = create_refresh_token(user["id"]) return { @@ -298,9 +288,7 @@ def register(body: RegisterReq): @router.post("/auth/login") def login(body: LoginReq): ensure_tables() - conn = get_conn() - user = conn.execute("SELECT * FROM users WHERE email = ?", (body.email.lower(),)).fetchone() - conn.close() + user = _fetchone("SELECT * FROM users WHERE email = %s", (body.email.lower(),)) if not user or not verify_password(body.password, user["password_hash"]): raise HTTPException(status_code=401, detail="invalid credentials") if user["banned"]: @@ -318,24 +306,17 @@ def login(body: LoginReq): @router.post("/auth/refresh") def refresh_token(body: RefreshReq): - conn = get_conn() - row = conn.execute( - "SELECT * FROM refresh_tokens WHERE token = ? AND revoked = 0", (body.refresh_token,) - ).fetchone() + row = _fetchone( + "SELECT * FROM refresh_tokens WHERE token = %s AND revoked = 0", (body.refresh_token,) + ) if not row: - conn.close() raise HTTPException(status_code=401, detail="invalid refresh token") if row["expires_at"] < datetime.utcnow().isoformat(): - conn.close() raise HTTPException(status_code=401, detail="refresh token expired") - user = conn.execute("SELECT * FROM users WHERE id = ?", (row["user_id"],)).fetchone() + user = _fetchone("SELECT * FROM users WHERE id = %s", (row["user_id"],)) if not user or user["banned"]: - conn.close() raise HTTPException(status_code=403, detail="account unavailable") - # revoke old, issue new - conn.execute("UPDATE refresh_tokens SET revoked = 1 WHERE id = ?", (row["id"],)) - conn.commit() - conn.close() + _execute("UPDATE refresh_tokens SET revoked = 1 WHERE id = %s", (row["id"],)) access = create_access_token(user["id"], user["email"], user["role"]) new_refresh = create_refresh_token(user["id"]) return { @@ -348,9 +329,7 @@ def refresh_token(body: RefreshReq): @router.get("/auth/me") def me(user: dict = Depends(get_current_user)): - conn = get_conn() - sub = conn.execute("SELECT tier, expires_at FROM subscriptions WHERE user_id = ?", (user["id"],)).fetchone() - conn.close() + sub = _fetchone("SELECT tier, expires_at FROM subscriptions WHERE user_id = %s", (user["id"],)) return { "id": user["id"], "email": user["email"], "role": user["role"], "discord_id": user.get("discord_id"), @@ -363,53 +342,43 @@ def me(user: dict = Depends(get_current_user)): @router.post("/admin/invite-codes") def gen_invite_codes(body: GenInviteReq, admin: dict = Depends(require_admin)): - conn = get_conn() codes = [] - for _ in range(body.count): - code = secrets.token_urlsafe(6)[:8].upper() - conn.execute( - "INSERT INTO invite_codes (code, created_by, max_uses) VALUES (?, ?, ?)", - (code, admin["id"], body.max_uses), - ) - codes.append(code) - conn.commit() - conn.close() + with get_sync_conn() as conn: + with conn.cursor() as cur: + for _ in range(body.count): + code = secrets.token_urlsafe(6)[:8].upper() + cur.execute( + "INSERT INTO invite_codes (code, created_by, max_uses) VALUES (%s, %s, %s)", + (code, admin["id"], body.max_uses), + ) + codes.append(code) + conn.commit() return {"codes": codes} @router.get("/admin/invite-codes") def list_invite_codes(admin: dict = Depends(require_admin)): - conn = get_conn() - rows = conn.execute( + rows = _fetchall( "SELECT id, code, max_uses, used_count, status, expires_at, created_at FROM invite_codes ORDER BY id DESC" - ).fetchall() - conn.close() - return {"items": [dict(r) for r in rows]} + ) + return {"items": rows} @router.delete("/admin/invite-codes/{code_id}") def disable_invite_code(code_id: int, admin: dict = Depends(require_admin)): - conn = get_conn() - conn.execute("UPDATE invite_codes SET status = 'disabled' WHERE id = ?", (code_id,)) - conn.commit() - conn.close() + _execute("UPDATE invite_codes SET status = 'disabled' WHERE id = %s", (code_id,)) return {"ok": True} @router.get("/admin/users") def list_users(admin: dict = Depends(require_admin)): - conn = get_conn() - rows = conn.execute( + rows = _fetchall( "SELECT id, email, role, banned, discord_id, created_at FROM users ORDER BY id DESC" - ).fetchall() - conn.close() - return {"items": [dict(r) for r in rows]} + ) + return {"items": rows} @router.put("/admin/users/{user_id}/ban") def ban_user(user_id: int, body: BanUserReq, admin: dict = Depends(require_admin)): - conn = get_conn() - conn.execute("UPDATE users SET banned = ? WHERE id = ?", (1 if body.banned else 0, user_id)) - conn.commit() - conn.close() + _execute("UPDATE users SET banned = %s WHERE id = %s", (1 if body.banned else 0, user_id)) return {"ok": True, "user_id": user_id, "banned": body.banned} diff --git a/backend/migrate_auth_sqlite_to_pg.py b/backend/migrate_auth_sqlite_to_pg.py new file mode 100644 index 0000000..ae81da1 --- /dev/null +++ b/backend/migrate_auth_sqlite_to_pg.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python3 +""" +migrate_auth_sqlite_to_pg.py — 将SQLite中的auth相关表迁移到PG +运行前确保PG连接参数正确(db.py中的配置) +""" +import os, sys, sqlite3 + +sys.path.insert(0, os.path.dirname(__file__)) +from db import get_sync_conn + +SQLITE_PATH = os.path.join(os.path.dirname(__file__), "..", "arb.db") + + +def migrate(): + if not os.path.exists(SQLITE_PATH): + print(f"SQLite DB not found: {SQLITE_PATH}") + return + + sq = sqlite3.connect(SQLITE_PATH) + sq.row_factory = sqlite3.Row + + with get_sync_conn() as pg: + cur = pg.cursor() + + # 1. 建auth相关表 + print("Creating auth tables in PG...") + auth_tables = """ + CREATE TABLE IF NOT EXISTS subscriptions ( + user_id BIGINT PRIMARY KEY, + tier TEXT NOT NULL DEFAULT 'free', + expires_at TEXT + ); + CREATE TABLE IF NOT EXISTS invite_usage ( + id BIGSERIAL PRIMARY KEY, + code_id BIGINT NOT NULL, + user_id BIGINT NOT NULL, + used_at TEXT + ); + CREATE TABLE IF NOT EXISTS refresh_tokens ( + id BIGSERIAL PRIMARY KEY, + user_id BIGINT NOT NULL, + token TEXT UNIQUE NOT NULL, + expires_at TEXT NOT NULL, + revoked INTEGER NOT NULL DEFAULT 0, + created_at TEXT + ); + """ + for stmt in auth_tables.split(";"): + stmt = stmt.strip() + if stmt: + try: + cur.execute(stmt) + except Exception as e: + pg.rollback() + pg.commit() + + # 2. 迁移users(PG已有users表但可能是空的旧结构) + print("Migrating users...") + # 先加缺失列 + for col, defn in [("discord_id", "TEXT"), ("banned", "INTEGER DEFAULT 0")]: + try: + cur.execute(f"ALTER TABLE users ADD COLUMN {col} {defn}") + pg.commit() + except: + pg.rollback() + + rows = sq.execute("SELECT * FROM users").fetchall() + for r in rows: + try: + cur.execute( + """INSERT INTO users (id, email, password_hash, discord_id, role, banned, created_at) + VALUES (%s, %s, %s, %s, %s, %s, %s) + ON CONFLICT (id) DO UPDATE SET + password_hash = EXCLUDED.password_hash, + discord_id = EXCLUDED.discord_id, + role = EXCLUDED.role, + banned = EXCLUDED.banned""", + (r["id"], r["email"], r["password_hash"], + r["discord_id"] if "discord_id" in r.keys() else None, + r["role"], r["banned"], r["created_at"]) + ) + except Exception as e: + print(f" User {r['email']} error: {e}") + pg.rollback() + continue + pg.commit() + print(f" Migrated {len(rows)} users") + + # 3. 迁移invite_codes(PG已有但可能缺列) + print("Migrating invite_codes...") + for col, defn in [("created_by", "INTEGER"), ("max_uses", "INTEGER DEFAULT 1"), + ("used_count", "INTEGER DEFAULT 0"), ("status", "TEXT DEFAULT 'active'"), + ("expires_at", "TEXT")]: + try: + cur.execute(f"ALTER TABLE invite_codes ADD COLUMN {col} {defn}") + pg.commit() + except: + pg.rollback() + + try: + rows = sq.execute("SELECT * FROM invite_codes").fetchall() + for r in rows: + try: + cur.execute( + """INSERT INTO invite_codes (id, code, created_by, max_uses, used_count, status, expires_at, created_at) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s) + ON CONFLICT (id) DO NOTHING""", + (r["id"], r["code"], r["created_by"] if "created_by" in r.keys() else None, + r["max_uses"], r["used_count"], r["status"], + r["expires_at"] if "expires_at" in r.keys() else None, + r["created_at"] if "created_at" in r.keys() else None) + ) + except Exception as e: + print(f" Invite {r['code']} error: {e}") + pg.rollback() + continue + pg.commit() + print(f" Migrated {len(rows)} invite codes") + except Exception as e: + print(f" invite_codes table error: {e}") + + # 4. 迁移subscriptions + print("Migrating subscriptions...") + try: + rows = sq.execute("SELECT * FROM subscriptions").fetchall() + for r in rows: + try: + cur.execute( + "INSERT INTO subscriptions (user_id, tier, expires_at) VALUES (%s, %s, %s) ON CONFLICT(user_id) DO NOTHING", + (r["user_id"], r["tier"], r["expires_at"]) + ) + except Exception as e: + pg.rollback() + pg.commit() + print(f" Migrated {len(rows)} subscriptions") + except Exception as e: + print(f" subscriptions error: {e}") + + # 5. 迁移refresh_tokens + print("Migrating refresh_tokens...") + try: + rows = sq.execute("SELECT * FROM refresh_tokens").fetchall() + for r in rows: + try: + cur.execute( + "INSERT INTO refresh_tokens (user_id, token, expires_at, revoked) VALUES (%s, %s, %s, %s) ON CONFLICT(token) DO NOTHING", + (r["user_id"], r["token"], r["expires_at"], r["revoked"]) + ) + except Exception as e: + pg.rollback() + pg.commit() + print(f" Migrated {len(rows)} refresh tokens") + except Exception as e: + print(f" refresh_tokens error: {e}") + + # 6. 重置序列 + print("Resetting sequences...") + for table in ["users", "invite_codes", "invite_usage", "refresh_tokens"]: + try: + cur.execute(f"SELECT setval(pg_get_serial_sequence('{table}', 'id'), COALESCE(MAX(id), 1)) FROM {table}") + pg.commit() + except: + pg.rollback() + + sq.close() + print("\nAuth migration complete!") + + +if __name__ == "__main__": + migrate()