feat: migrate auth system from SQLite to PostgreSQL

- auth.py: rewrite to use PG via db.py (was sqlite3)
- admin_cli.py: rewrite to use PG
- migrate_auth_sqlite_to_pg.py: one-time migration script
- SQLite arb.db no longer needed after migration
This commit is contained in:
root 2026-03-01 07:29:14 +00:00
parent 4f54e36d1a
commit 8b73500d22
3 changed files with 372 additions and 225 deletions

View File

@ -1,82 +1,90 @@
#!/usr/bin/env python3
"""Admin CLI for Arbitrage Engine"""
import sys, os, sqlite3, secrets, json
"""Admin CLI for Arbitrage Engine (PostgreSQL version)"""
import sys, os, secrets
DB_PATH = os.path.join(os.path.dirname(__file__), "..", "arb.db")
sys.path.insert(0, os.path.dirname(__file__))
from db import get_sync_conn
def get_conn():
conn = sqlite3.connect(DB_PATH)
conn.row_factory = sqlite3.Row
return conn
def gen_invite(count=1, max_uses=1):
conn = get_conn()
with get_sync_conn() as conn:
with conn.cursor() as cur:
codes = []
for _ in range(count):
code = secrets.token_urlsafe(6)[:8].upper()
conn.execute(
"INSERT INTO invite_codes (code, created_by, max_uses) VALUES (?, 1, ?)",
cur.execute(
"INSERT INTO invite_codes (code, created_by, max_uses) VALUES (%s, 1, %s)",
(code, max_uses)
)
codes.append(code)
conn.commit()
conn.close()
for c in codes:
print(f" {c}")
print(f"\nGenerated {len(codes)} invite code(s)")
def list_invites():
conn = get_conn()
rows = conn.execute("SELECT id, code, max_uses, used_count, status, created_at FROM invite_codes ORDER BY id DESC").fetchall()
conn.close()
with get_sync_conn() as conn:
with conn.cursor() as cur:
cur.execute("SELECT id, code, max_uses, used_count, status, created_at FROM invite_codes ORDER BY id DESC")
cols = [desc[0] for desc in cur.description]
rows = [dict(zip(cols, row)) for row in cur.fetchall()]
if not rows:
print("No invite codes found")
return
print(f"{'ID':>4} {'CODE':>10} {'MAX':>4} {'USED':>5} {'STATUS':>10} {'CREATED':>20}")
print("-" * 60)
for r in rows:
print(f"{r['id']:>4} {r['code']:>10} {r['max_uses']:>4} {r['used_count']:>5} {r['status']:>10} {r['created_at']:>20}")
print(f"{r['id']:>4} {r['code']:>10} {r['max_uses']:>4} {r['used_count']:>5} {r['status']:>10} {str(r['created_at']):>20}")
def disable_invite(code):
conn = get_conn()
conn.execute("UPDATE invite_codes SET status = 'disabled' WHERE code = ?", (code,))
with get_sync_conn() as conn:
with conn.cursor() as cur:
cur.execute("UPDATE invite_codes SET status = 'disabled' WHERE code = %s", (code,))
conn.commit()
conn.close()
print(f"Disabled invite code: {code}")
def list_users():
conn = get_conn()
rows = conn.execute("SELECT id, email, role, banned, created_at FROM users ORDER BY id DESC").fetchall()
conn.close()
with get_sync_conn() as conn:
with conn.cursor() as cur:
cur.execute("SELECT id, email, role, banned, created_at FROM users ORDER BY id DESC")
cols = [desc[0] for desc in cur.description]
rows = [dict(zip(cols, row)) for row in cur.fetchall()]
if not rows:
print("No users found")
return
print(f"{'ID':>4} {'EMAIL':>30} {'ROLE':>6} {'BANNED':>7} {'CREATED':>20}")
print("-" * 72)
for r in rows:
print(f"{r['id']:>4} {r['email']:>30} {r['role']:>6} {r['banned']:>7} {r['created_at']:>20}")
print(f"{r['id']:>4} {r['email']:>30} {r['role']:>6} {r['banned']:>7} {str(r['created_at']):>20}")
def ban_user(user_id):
conn = get_conn()
conn.execute("UPDATE users SET banned = 1 WHERE id = ?", (user_id,))
with get_sync_conn() as conn:
with conn.cursor() as cur:
cur.execute("UPDATE users SET banned = 1 WHERE id = %s", (user_id,))
conn.commit()
conn.close()
print(f"Banned user ID: {user_id}")
def unban_user(user_id):
conn = get_conn()
conn.execute("UPDATE users SET banned = 0 WHERE id = ?", (user_id,))
with get_sync_conn() as conn:
with conn.cursor() as cur:
cur.execute("UPDATE users SET banned = 0 WHERE id = %s", (user_id,))
conn.commit()
conn.close()
print(f"Unbanned user ID: {user_id}")
def set_admin(user_id):
conn = get_conn()
conn.execute("UPDATE users SET role = 'admin' WHERE id = ?", (user_id,))
with get_sync_conn() as conn:
with conn.cursor() as cur:
cur.execute("UPDATE users SET role = 'admin' WHERE id = %s", (user_id,))
conn.commit()
conn.close()
print(f"Set user {user_id} as admin")
def usage():
print("""Usage: python3 admin_cli.py <command> [args]

View File

@ -1,5 +1,4 @@
import os
import sqlite3
import hashlib
import secrets
import hmac
@ -7,12 +6,12 @@ import base64
import json
from datetime import datetime, timedelta
from typing import Optional
from functools import wraps
from fastapi import APIRouter, HTTPException, Header, Depends, Request
from pydantic import BaseModel, EmailStr
DB_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "arb.db")
from db import get_sync_conn
JWT_SECRET = os.getenv("JWT_SECRET", "arb-engine-jwt-secret-v2-2026")
ACCESS_TOKEN_HOURS = 24
REFRESH_TOKEN_DAYS = 7
@ -20,90 +19,98 @@ REFRESH_TOKEN_DAYS = 7
router = APIRouter(prefix="/api", tags=["auth"])
# ─── DB helpers ───────────────────────────────────────────────
# ─── PG Schema ───────────────────────────────────────────────
def get_conn():
conn = sqlite3.connect(DB_PATH)
conn.row_factory = sqlite3.Row
return conn
def ensure_tables():
conn = get_conn()
cur = conn.cursor()
cur.execute("""
AUTH_SCHEMA = """
CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
id BIGSERIAL PRIMARY KEY,
email TEXT UNIQUE NOT NULL,
password_hash TEXT NOT NULL,
discord_id TEXT,
role TEXT NOT NULL DEFAULT 'user',
banned INTEGER NOT NULL DEFAULT 0,
created_at TEXT NOT NULL
)
""")
cur.execute("""
);
CREATE TABLE IF NOT EXISTS subscriptions (
user_id INTEGER PRIMARY KEY,
user_id BIGINT PRIMARY KEY REFERENCES users(id),
tier TEXT NOT NULL DEFAULT 'free',
expires_at TEXT,
FOREIGN KEY(user_id) REFERENCES users(id)
)
""")
cur.execute("""
CREATE TABLE IF NOT EXISTS signal_logs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
symbol TEXT NOT NULL,
rate REAL NOT NULL,
annualized REAL NOT NULL,
sent_at TEXT NOT NULL,
message TEXT NOT NULL
)
""")
cur.execute("""
expires_at TEXT
);
CREATE TABLE IF NOT EXISTS invite_codes (
id INTEGER PRIMARY KEY AUTOINCREMENT,
id BIGSERIAL PRIMARY KEY,
code TEXT UNIQUE NOT NULL,
created_by INTEGER,
max_uses INTEGER NOT NULL DEFAULT 1,
used_count INTEGER NOT NULL DEFAULT 0,
status TEXT NOT NULL DEFAULT 'active',
expires_at TEXT,
created_at TEXT DEFAULT (datetime('now'))
)
""")
cur.execute("""
created_at TEXT DEFAULT (NOW()::TEXT)
);
CREATE TABLE IF NOT EXISTS invite_usage (
id INTEGER PRIMARY KEY AUTOINCREMENT,
code_id INTEGER NOT NULL,
user_id INTEGER NOT NULL,
used_at TEXT DEFAULT (datetime('now')),
FOREIGN KEY (code_id) REFERENCES invite_codes(id),
FOREIGN KEY (user_id) REFERENCES users(id)
)
""")
cur.execute("""
id BIGSERIAL PRIMARY KEY,
code_id BIGINT NOT NULL REFERENCES invite_codes(id),
user_id BIGINT NOT NULL REFERENCES users(id),
used_at TEXT DEFAULT (NOW()::TEXT)
);
CREATE TABLE IF NOT EXISTS refresh_tokens (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL REFERENCES users(id),
token TEXT UNIQUE NOT NULL,
expires_at TEXT NOT NULL,
revoked INTEGER NOT NULL DEFAULT 0,
created_at TEXT DEFAULT (datetime('now')),
FOREIGN KEY (user_id) REFERENCES users(id)
)
""")
# migrate: add role/banned columns if missing
created_at TEXT DEFAULT (NOW()::TEXT)
);
"""
def ensure_tables():
with get_sync_conn() as conn:
with conn.cursor() as cur:
for stmt in AUTH_SCHEMA.split(";"):
stmt = stmt.strip()
if stmt:
try:
cur.execute("ALTER TABLE users ADD COLUMN role TEXT NOT NULL DEFAULT 'user'")
except:
pass
try:
cur.execute("ALTER TABLE users ADD COLUMN banned INTEGER NOT NULL DEFAULT 0")
except:
pass
cur.execute(stmt)
except Exception:
conn.rollback()
continue
conn.commit()
conn.close()
# ─── DB helper ───────────────────────────────────────────────
def _fetchone(sql, params=None):
with get_sync_conn() as conn:
with conn.cursor() as cur:
cur.execute(sql, params or ())
row = cur.fetchone()
if not row:
return None
cols = [desc[0] for desc in cur.description]
return dict(zip(cols, row))
def _fetchall(sql, params=None):
with get_sync_conn() as conn:
with conn.cursor() as cur:
cur.execute(sql, params or ())
cols = [desc[0] for desc in cur.description]
return [dict(zip(cols, row)) for row in cur.fetchall()]
def _execute(sql, params=None):
with get_sync_conn() as conn:
with conn.cursor() as cur:
cur.execute(sql, params or ())
conn.commit()
try:
return cur.fetchone()
except Exception:
return None
# ─── Password utils ──────────────────────────────────────────
@ -144,13 +151,10 @@ def create_access_token(user_id: int, email: str, role: str) -> str:
def create_refresh_token(user_id: int) -> str:
token = secrets.token_urlsafe(48)
expires_at = (datetime.utcnow() + timedelta(days=REFRESH_TOKEN_DAYS)).isoformat()
conn = get_conn()
conn.execute(
"INSERT INTO refresh_tokens (user_id, token, expires_at) VALUES (?, ?, ?)",
_execute(
"INSERT INTO refresh_tokens (user_id, token, expires_at) VALUES (%s, %s, %s)",
(user_id, token, expires_at)
)
conn.commit()
conn.close()
return token
@ -173,7 +177,6 @@ def parse_token(token: str) -> Optional[dict]:
# ─── Auth dependency ─────────────────────────────────────────
def get_current_user(authorization: Optional[str] = Header(default=None)) -> dict:
"""FastAPI dependency: returns user dict or raises 401"""
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="missing token")
token = authorization.split(" ", 1)[1].strip()
@ -182,18 +185,15 @@ def get_current_user(authorization: Optional[str] = Header(default=None)) -> dic
raise HTTPException(status_code=401, detail="invalid or expired token")
if payload.get("type") != "access":
raise HTTPException(status_code=401, detail="invalid token type")
conn = get_conn()
user = conn.execute("SELECT * FROM users WHERE id = ?", (payload["sub"],)).fetchone()
conn.close()
user = _fetchone("SELECT * FROM users WHERE id = %s", (payload["sub"],))
if not user:
raise HTTPException(status_code=401, detail="user not found")
if user["banned"]:
raise HTTPException(status_code=403, detail="account banned")
return dict(user)
return user
def require_admin(user: dict = Depends(get_current_user)) -> dict:
"""FastAPI dependency: requires admin role"""
if user.get("role") != "admin":
raise HTTPException(status_code=403, detail="admin required")
return user
@ -230,60 +230,50 @@ class BanUserReq(BaseModel):
@router.post("/auth/register")
def register(body: RegisterReq):
ensure_tables()
conn = get_conn()
# verify invite code
invite = conn.execute(
"SELECT * FROM invite_codes WHERE code = ?", (body.invite_code,)
).fetchone()
invite = _fetchone("SELECT * FROM invite_codes WHERE code = %s", (body.invite_code,))
if not invite:
conn.close()
raise HTTPException(status_code=400, detail="invalid invite code")
if invite["status"] != "active":
conn.close()
raise HTTPException(status_code=400, detail="invite code disabled")
if invite["used_count"] >= invite["max_uses"]:
conn.close()
raise HTTPException(status_code=400, detail="invite code exhausted")
if invite["expires_at"] and invite["expires_at"] < datetime.utcnow().isoformat():
conn.close()
raise HTTPException(status_code=400, detail="invite code expired")
try:
pwd_hash = hash_password(body.password)
cur = conn.execute(
"INSERT INTO users (email, password_hash, role, banned, created_at) VALUES (?, ?, 'user', 0, ?)",
try:
with get_sync_conn() as conn:
with conn.cursor() as cur:
cur.execute(
"INSERT INTO users (email, password_hash, role, banned, created_at) VALUES (%s, %s, 'user', 0, %s) RETURNING id",
(body.email.lower(), pwd_hash, datetime.utcnow().isoformat()),
)
user_id = cur.lastrowid
conn.execute(
"INSERT OR REPLACE INTO subscriptions (user_id, tier, expires_at) VALUES (?, 'free', NULL)",
user_id = cur.fetchone()[0]
cur.execute(
"INSERT INTO subscriptions (user_id, tier, expires_at) VALUES (%s, 'free', NULL) ON CONFLICT(user_id) DO NOTHING",
(user_id,),
)
# record invite usage
conn.execute(
"INSERT INTO invite_usage (code_id, user_id) VALUES (?, ?)",
cur.execute(
"INSERT INTO invite_usage (code_id, user_id) VALUES (%s, %s)",
(invite["id"], user_id),
)
conn.execute(
"UPDATE invite_codes SET used_count = used_count + 1 WHERE id = ?",
cur.execute(
"UPDATE invite_codes SET used_count = used_count + 1 WHERE id = %s",
(invite["id"],),
)
# auto-exhaust if max reached
new_count = invite["used_count"] + 1
if new_count >= invite["max_uses"]:
conn.execute(
"UPDATE invite_codes SET status = 'exhausted' WHERE id = ?",
cur.execute(
"UPDATE invite_codes SET status = 'exhausted' WHERE id = %s",
(invite["id"],),
)
conn.commit()
except sqlite3.IntegrityError:
conn.close()
except Exception as e:
if "unique" in str(e).lower() or "duplicate" in str(e).lower():
raise HTTPException(status_code=400, detail="email already registered")
raise HTTPException(status_code=500, detail=str(e))
# issue tokens
user = conn.execute("SELECT * FROM users WHERE id = ?", (user_id,)).fetchone()
conn.close()
user = _fetchone("SELECT * FROM users WHERE id = %s", (user_id,))
access = create_access_token(user["id"], user["email"], user["role"])
refresh = create_refresh_token(user["id"])
return {
@ -298,9 +288,7 @@ def register(body: RegisterReq):
@router.post("/auth/login")
def login(body: LoginReq):
ensure_tables()
conn = get_conn()
user = conn.execute("SELECT * FROM users WHERE email = ?", (body.email.lower(),)).fetchone()
conn.close()
user = _fetchone("SELECT * FROM users WHERE email = %s", (body.email.lower(),))
if not user or not verify_password(body.password, user["password_hash"]):
raise HTTPException(status_code=401, detail="invalid credentials")
if user["banned"]:
@ -318,24 +306,17 @@ def login(body: LoginReq):
@router.post("/auth/refresh")
def refresh_token(body: RefreshReq):
conn = get_conn()
row = conn.execute(
"SELECT * FROM refresh_tokens WHERE token = ? AND revoked = 0", (body.refresh_token,)
).fetchone()
row = _fetchone(
"SELECT * FROM refresh_tokens WHERE token = %s AND revoked = 0", (body.refresh_token,)
)
if not row:
conn.close()
raise HTTPException(status_code=401, detail="invalid refresh token")
if row["expires_at"] < datetime.utcnow().isoformat():
conn.close()
raise HTTPException(status_code=401, detail="refresh token expired")
user = conn.execute("SELECT * FROM users WHERE id = ?", (row["user_id"],)).fetchone()
user = _fetchone("SELECT * FROM users WHERE id = %s", (row["user_id"],))
if not user or user["banned"]:
conn.close()
raise HTTPException(status_code=403, detail="account unavailable")
# revoke old, issue new
conn.execute("UPDATE refresh_tokens SET revoked = 1 WHERE id = ?", (row["id"],))
conn.commit()
conn.close()
_execute("UPDATE refresh_tokens SET revoked = 1 WHERE id = %s", (row["id"],))
access = create_access_token(user["id"], user["email"], user["role"])
new_refresh = create_refresh_token(user["id"])
return {
@ -348,9 +329,7 @@ def refresh_token(body: RefreshReq):
@router.get("/auth/me")
def me(user: dict = Depends(get_current_user)):
conn = get_conn()
sub = conn.execute("SELECT tier, expires_at FROM subscriptions WHERE user_id = ?", (user["id"],)).fetchone()
conn.close()
sub = _fetchone("SELECT tier, expires_at FROM subscriptions WHERE user_id = %s", (user["id"],))
return {
"id": user["id"], "email": user["email"], "role": user["role"],
"discord_id": user.get("discord_id"),
@ -363,53 +342,43 @@ def me(user: dict = Depends(get_current_user)):
@router.post("/admin/invite-codes")
def gen_invite_codes(body: GenInviteReq, admin: dict = Depends(require_admin)):
conn = get_conn()
codes = []
with get_sync_conn() as conn:
with conn.cursor() as cur:
for _ in range(body.count):
code = secrets.token_urlsafe(6)[:8].upper()
conn.execute(
"INSERT INTO invite_codes (code, created_by, max_uses) VALUES (?, ?, ?)",
cur.execute(
"INSERT INTO invite_codes (code, created_by, max_uses) VALUES (%s, %s, %s)",
(code, admin["id"], body.max_uses),
)
codes.append(code)
conn.commit()
conn.close()
return {"codes": codes}
@router.get("/admin/invite-codes")
def list_invite_codes(admin: dict = Depends(require_admin)):
conn = get_conn()
rows = conn.execute(
rows = _fetchall(
"SELECT id, code, max_uses, used_count, status, expires_at, created_at FROM invite_codes ORDER BY id DESC"
).fetchall()
conn.close()
return {"items": [dict(r) for r in rows]}
)
return {"items": rows}
@router.delete("/admin/invite-codes/{code_id}")
def disable_invite_code(code_id: int, admin: dict = Depends(require_admin)):
conn = get_conn()
conn.execute("UPDATE invite_codes SET status = 'disabled' WHERE id = ?", (code_id,))
conn.commit()
conn.close()
_execute("UPDATE invite_codes SET status = 'disabled' WHERE id = %s", (code_id,))
return {"ok": True}
@router.get("/admin/users")
def list_users(admin: dict = Depends(require_admin)):
conn = get_conn()
rows = conn.execute(
rows = _fetchall(
"SELECT id, email, role, banned, discord_id, created_at FROM users ORDER BY id DESC"
).fetchall()
conn.close()
return {"items": [dict(r) for r in rows]}
)
return {"items": rows}
@router.put("/admin/users/{user_id}/ban")
def ban_user(user_id: int, body: BanUserReq, admin: dict = Depends(require_admin)):
conn = get_conn()
conn.execute("UPDATE users SET banned = ? WHERE id = ?", (1 if body.banned else 0, user_id))
conn.commit()
conn.close()
_execute("UPDATE users SET banned = %s WHERE id = %s", (1 if body.banned else 0, user_id))
return {"ok": True, "user_id": user_id, "banned": body.banned}

View File

@ -0,0 +1,170 @@
#!/usr/bin/env python3
"""
migrate_auth_sqlite_to_pg.py 将SQLite中的auth相关表迁移到PG
运行前确保PG连接参数正确db.py中的配置
"""
import os, sys, sqlite3
sys.path.insert(0, os.path.dirname(__file__))
from db import get_sync_conn
SQLITE_PATH = os.path.join(os.path.dirname(__file__), "..", "arb.db")
def migrate():
if not os.path.exists(SQLITE_PATH):
print(f"SQLite DB not found: {SQLITE_PATH}")
return
sq = sqlite3.connect(SQLITE_PATH)
sq.row_factory = sqlite3.Row
with get_sync_conn() as pg:
cur = pg.cursor()
# 1. 建auth相关表
print("Creating auth tables in PG...")
auth_tables = """
CREATE TABLE IF NOT EXISTS subscriptions (
user_id BIGINT PRIMARY KEY,
tier TEXT NOT NULL DEFAULT 'free',
expires_at TEXT
);
CREATE TABLE IF NOT EXISTS invite_usage (
id BIGSERIAL PRIMARY KEY,
code_id BIGINT NOT NULL,
user_id BIGINT NOT NULL,
used_at TEXT
);
CREATE TABLE IF NOT EXISTS refresh_tokens (
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL,
token TEXT UNIQUE NOT NULL,
expires_at TEXT NOT NULL,
revoked INTEGER NOT NULL DEFAULT 0,
created_at TEXT
);
"""
for stmt in auth_tables.split(";"):
stmt = stmt.strip()
if stmt:
try:
cur.execute(stmt)
except Exception as e:
pg.rollback()
pg.commit()
# 2. 迁移usersPG已有users表但可能是空的旧结构
print("Migrating users...")
# 先加缺失列
for col, defn in [("discord_id", "TEXT"), ("banned", "INTEGER DEFAULT 0")]:
try:
cur.execute(f"ALTER TABLE users ADD COLUMN {col} {defn}")
pg.commit()
except:
pg.rollback()
rows = sq.execute("SELECT * FROM users").fetchall()
for r in rows:
try:
cur.execute(
"""INSERT INTO users (id, email, password_hash, discord_id, role, banned, created_at)
VALUES (%s, %s, %s, %s, %s, %s, %s)
ON CONFLICT (id) DO UPDATE SET
password_hash = EXCLUDED.password_hash,
discord_id = EXCLUDED.discord_id,
role = EXCLUDED.role,
banned = EXCLUDED.banned""",
(r["id"], r["email"], r["password_hash"],
r["discord_id"] if "discord_id" in r.keys() else None,
r["role"], r["banned"], r["created_at"])
)
except Exception as e:
print(f" User {r['email']} error: {e}")
pg.rollback()
continue
pg.commit()
print(f" Migrated {len(rows)} users")
# 3. 迁移invite_codesPG已有但可能缺列
print("Migrating invite_codes...")
for col, defn in [("created_by", "INTEGER"), ("max_uses", "INTEGER DEFAULT 1"),
("used_count", "INTEGER DEFAULT 0"), ("status", "TEXT DEFAULT 'active'"),
("expires_at", "TEXT")]:
try:
cur.execute(f"ALTER TABLE invite_codes ADD COLUMN {col} {defn}")
pg.commit()
except:
pg.rollback()
try:
rows = sq.execute("SELECT * FROM invite_codes").fetchall()
for r in rows:
try:
cur.execute(
"""INSERT INTO invite_codes (id, code, created_by, max_uses, used_count, status, expires_at, created_at)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
ON CONFLICT (id) DO NOTHING""",
(r["id"], r["code"], r["created_by"] if "created_by" in r.keys() else None,
r["max_uses"], r["used_count"], r["status"],
r["expires_at"] if "expires_at" in r.keys() else None,
r["created_at"] if "created_at" in r.keys() else None)
)
except Exception as e:
print(f" Invite {r['code']} error: {e}")
pg.rollback()
continue
pg.commit()
print(f" Migrated {len(rows)} invite codes")
except Exception as e:
print(f" invite_codes table error: {e}")
# 4. 迁移subscriptions
print("Migrating subscriptions...")
try:
rows = sq.execute("SELECT * FROM subscriptions").fetchall()
for r in rows:
try:
cur.execute(
"INSERT INTO subscriptions (user_id, tier, expires_at) VALUES (%s, %s, %s) ON CONFLICT(user_id) DO NOTHING",
(r["user_id"], r["tier"], r["expires_at"])
)
except Exception as e:
pg.rollback()
pg.commit()
print(f" Migrated {len(rows)} subscriptions")
except Exception as e:
print(f" subscriptions error: {e}")
# 5. 迁移refresh_tokens
print("Migrating refresh_tokens...")
try:
rows = sq.execute("SELECT * FROM refresh_tokens").fetchall()
for r in rows:
try:
cur.execute(
"INSERT INTO refresh_tokens (user_id, token, expires_at, revoked) VALUES (%s, %s, %s, %s) ON CONFLICT(token) DO NOTHING",
(r["user_id"], r["token"], r["expires_at"], r["revoked"])
)
except Exception as e:
pg.rollback()
pg.commit()
print(f" Migrated {len(rows)} refresh tokens")
except Exception as e:
print(f" refresh_tokens error: {e}")
# 6. 重置序列
print("Resetting sequences...")
for table in ["users", "invite_codes", "invite_usage", "refresh_tokens"]:
try:
cur.execute(f"SELECT setval(pg_get_serial_sequence('{table}', 'id'), COALESCE(MAX(id), 1)) FROM {table}")
pg.commit()
except:
pg.rollback()
sq.close()
print("\nAuth migration complete!")
if __name__ == "__main__":
migrate()