feat: migrate auth system from SQLite to PostgreSQL

- auth.py: rewrite to use PG via db.py (was sqlite3)
- admin_cli.py: rewrite to use PG
- migrate_auth_sqlite_to_pg.py: one-time migration script
- SQLite arb.db no longer needed after migration
This commit is contained in:
root 2026-03-01 07:29:14 +00:00
parent 4f54e36d1a
commit 8b73500d22
3 changed files with 372 additions and 225 deletions

View File

@ -1,82 +1,90 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
"""Admin CLI for Arbitrage Engine""" """Admin CLI for Arbitrage Engine (PostgreSQL version)"""
import sys, os, sqlite3, secrets, json import sys, os, secrets
DB_PATH = os.path.join(os.path.dirname(__file__), "..", "arb.db") sys.path.insert(0, os.path.dirname(__file__))
from db import get_sync_conn
def get_conn():
conn = sqlite3.connect(DB_PATH)
conn.row_factory = sqlite3.Row
return conn
def gen_invite(count=1, max_uses=1): def gen_invite(count=1, max_uses=1):
conn = get_conn() with get_sync_conn() as conn:
with conn.cursor() as cur:
codes = [] codes = []
for _ in range(count): for _ in range(count):
code = secrets.token_urlsafe(6)[:8].upper() code = secrets.token_urlsafe(6)[:8].upper()
conn.execute( cur.execute(
"INSERT INTO invite_codes (code, created_by, max_uses) VALUES (?, 1, ?)", "INSERT INTO invite_codes (code, created_by, max_uses) VALUES (%s, 1, %s)",
(code, max_uses) (code, max_uses)
) )
codes.append(code) codes.append(code)
conn.commit() conn.commit()
conn.close()
for c in codes: for c in codes:
print(f" {c}") print(f" {c}")
print(f"\nGenerated {len(codes)} invite code(s)") print(f"\nGenerated {len(codes)} invite code(s)")
def list_invites(): def list_invites():
conn = get_conn() with get_sync_conn() as conn:
rows = conn.execute("SELECT id, code, max_uses, used_count, status, created_at FROM invite_codes ORDER BY id DESC").fetchall() with conn.cursor() as cur:
conn.close() cur.execute("SELECT id, code, max_uses, used_count, status, created_at FROM invite_codes ORDER BY id DESC")
cols = [desc[0] for desc in cur.description]
rows = [dict(zip(cols, row)) for row in cur.fetchall()]
if not rows: if not rows:
print("No invite codes found") print("No invite codes found")
return return
print(f"{'ID':>4} {'CODE':>10} {'MAX':>4} {'USED':>5} {'STATUS':>10} {'CREATED':>20}") print(f"{'ID':>4} {'CODE':>10} {'MAX':>4} {'USED':>5} {'STATUS':>10} {'CREATED':>20}")
print("-" * 60) print("-" * 60)
for r in rows: for r in rows:
print(f"{r['id']:>4} {r['code']:>10} {r['max_uses']:>4} {r['used_count']:>5} {r['status']:>10} {r['created_at']:>20}") print(f"{r['id']:>4} {r['code']:>10} {r['max_uses']:>4} {r['used_count']:>5} {r['status']:>10} {str(r['created_at']):>20}")
def disable_invite(code): def disable_invite(code):
conn = get_conn() with get_sync_conn() as conn:
conn.execute("UPDATE invite_codes SET status = 'disabled' WHERE code = ?", (code,)) with conn.cursor() as cur:
cur.execute("UPDATE invite_codes SET status = 'disabled' WHERE code = %s", (code,))
conn.commit() conn.commit()
conn.close()
print(f"Disabled invite code: {code}") print(f"Disabled invite code: {code}")
def list_users(): def list_users():
conn = get_conn() with get_sync_conn() as conn:
rows = conn.execute("SELECT id, email, role, banned, created_at FROM users ORDER BY id DESC").fetchall() with conn.cursor() as cur:
conn.close() cur.execute("SELECT id, email, role, banned, created_at FROM users ORDER BY id DESC")
cols = [desc[0] for desc in cur.description]
rows = [dict(zip(cols, row)) for row in cur.fetchall()]
if not rows: if not rows:
print("No users found") print("No users found")
return return
print(f"{'ID':>4} {'EMAIL':>30} {'ROLE':>6} {'BANNED':>7} {'CREATED':>20}") print(f"{'ID':>4} {'EMAIL':>30} {'ROLE':>6} {'BANNED':>7} {'CREATED':>20}")
print("-" * 72) print("-" * 72)
for r in rows: for r in rows:
print(f"{r['id']:>4} {r['email']:>30} {r['role']:>6} {r['banned']:>7} {r['created_at']:>20}") print(f"{r['id']:>4} {r['email']:>30} {r['role']:>6} {r['banned']:>7} {str(r['created_at']):>20}")
def ban_user(user_id): def ban_user(user_id):
conn = get_conn() with get_sync_conn() as conn:
conn.execute("UPDATE users SET banned = 1 WHERE id = ?", (user_id,)) with conn.cursor() as cur:
cur.execute("UPDATE users SET banned = 1 WHERE id = %s", (user_id,))
conn.commit() conn.commit()
conn.close()
print(f"Banned user ID: {user_id}") print(f"Banned user ID: {user_id}")
def unban_user(user_id): def unban_user(user_id):
conn = get_conn() with get_sync_conn() as conn:
conn.execute("UPDATE users SET banned = 0 WHERE id = ?", (user_id,)) with conn.cursor() as cur:
cur.execute("UPDATE users SET banned = 0 WHERE id = %s", (user_id,))
conn.commit() conn.commit()
conn.close()
print(f"Unbanned user ID: {user_id}") print(f"Unbanned user ID: {user_id}")
def set_admin(user_id): def set_admin(user_id):
conn = get_conn() with get_sync_conn() as conn:
conn.execute("UPDATE users SET role = 'admin' WHERE id = ?", (user_id,)) with conn.cursor() as cur:
cur.execute("UPDATE users SET role = 'admin' WHERE id = %s", (user_id,))
conn.commit() conn.commit()
conn.close()
print(f"Set user {user_id} as admin") print(f"Set user {user_id} as admin")
def usage(): def usage():
print("""Usage: python3 admin_cli.py <command> [args] print("""Usage: python3 admin_cli.py <command> [args]

View File

@ -1,5 +1,4 @@
import os import os
import sqlite3
import hashlib import hashlib
import secrets import secrets
import hmac import hmac
@ -7,12 +6,12 @@ import base64
import json import json
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Optional from typing import Optional
from functools import wraps
from fastapi import APIRouter, HTTPException, Header, Depends, Request from fastapi import APIRouter, HTTPException, Header, Depends, Request
from pydantic import BaseModel, EmailStr from pydantic import BaseModel, EmailStr
DB_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "arb.db") from db import get_sync_conn
JWT_SECRET = os.getenv("JWT_SECRET", "arb-engine-jwt-secret-v2-2026") JWT_SECRET = os.getenv("JWT_SECRET", "arb-engine-jwt-secret-v2-2026")
ACCESS_TOKEN_HOURS = 24 ACCESS_TOKEN_HOURS = 24
REFRESH_TOKEN_DAYS = 7 REFRESH_TOKEN_DAYS = 7
@ -20,90 +19,98 @@ REFRESH_TOKEN_DAYS = 7
router = APIRouter(prefix="/api", tags=["auth"]) router = APIRouter(prefix="/api", tags=["auth"])
# ─── DB helpers ─────────────────────────────────────────────── # ─── PG Schema ───────────────────────────────────────────────
def get_conn(): AUTH_SCHEMA = """
conn = sqlite3.connect(DB_PATH) CREATE TABLE IF NOT EXISTS users (
conn.row_factory = sqlite3.Row id BIGSERIAL PRIMARY KEY,
return conn
def ensure_tables():
conn = get_conn()
cur = conn.cursor()
cur.execute("""
CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
email TEXT UNIQUE NOT NULL, email TEXT UNIQUE NOT NULL,
password_hash TEXT NOT NULL, password_hash TEXT NOT NULL,
discord_id TEXT, discord_id TEXT,
role TEXT NOT NULL DEFAULT 'user', role TEXT NOT NULL DEFAULT 'user',
banned INTEGER NOT NULL DEFAULT 0, banned INTEGER NOT NULL DEFAULT 0,
created_at TEXT NOT NULL created_at TEXT NOT NULL
) );
""")
cur.execute(""" CREATE TABLE IF NOT EXISTS subscriptions (
CREATE TABLE IF NOT EXISTS subscriptions ( user_id BIGINT PRIMARY KEY REFERENCES users(id),
user_id INTEGER PRIMARY KEY,
tier TEXT NOT NULL DEFAULT 'free', tier TEXT NOT NULL DEFAULT 'free',
expires_at TEXT, expires_at TEXT
FOREIGN KEY(user_id) REFERENCES users(id) );
)
""") CREATE TABLE IF NOT EXISTS invite_codes (
cur.execute(""" id BIGSERIAL PRIMARY KEY,
CREATE TABLE IF NOT EXISTS signal_logs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
symbol TEXT NOT NULL,
rate REAL NOT NULL,
annualized REAL NOT NULL,
sent_at TEXT NOT NULL,
message TEXT NOT NULL
)
""")
cur.execute("""
CREATE TABLE IF NOT EXISTS invite_codes (
id INTEGER PRIMARY KEY AUTOINCREMENT,
code TEXT UNIQUE NOT NULL, code TEXT UNIQUE NOT NULL,
created_by INTEGER, created_by INTEGER,
max_uses INTEGER NOT NULL DEFAULT 1, max_uses INTEGER NOT NULL DEFAULT 1,
used_count INTEGER NOT NULL DEFAULT 0, used_count INTEGER NOT NULL DEFAULT 0,
status TEXT NOT NULL DEFAULT 'active', status TEXT NOT NULL DEFAULT 'active',
expires_at TEXT, expires_at TEXT,
created_at TEXT DEFAULT (datetime('now')) created_at TEXT DEFAULT (NOW()::TEXT)
) );
""")
cur.execute(""" CREATE TABLE IF NOT EXISTS invite_usage (
CREATE TABLE IF NOT EXISTS invite_usage ( id BIGSERIAL PRIMARY KEY,
id INTEGER PRIMARY KEY AUTOINCREMENT, code_id BIGINT NOT NULL REFERENCES invite_codes(id),
code_id INTEGER NOT NULL, user_id BIGINT NOT NULL REFERENCES users(id),
user_id INTEGER NOT NULL, used_at TEXT DEFAULT (NOW()::TEXT)
used_at TEXT DEFAULT (datetime('now')), );
FOREIGN KEY (code_id) REFERENCES invite_codes(id),
FOREIGN KEY (user_id) REFERENCES users(id) CREATE TABLE IF NOT EXISTS refresh_tokens (
) id BIGSERIAL PRIMARY KEY,
""") user_id BIGINT NOT NULL REFERENCES users(id),
cur.execute("""
CREATE TABLE IF NOT EXISTS refresh_tokens (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
token TEXT UNIQUE NOT NULL, token TEXT UNIQUE NOT NULL,
expires_at TEXT NOT NULL, expires_at TEXT NOT NULL,
revoked INTEGER NOT NULL DEFAULT 0, revoked INTEGER NOT NULL DEFAULT 0,
created_at TEXT DEFAULT (datetime('now')), created_at TEXT DEFAULT (NOW()::TEXT)
FOREIGN KEY (user_id) REFERENCES users(id) );
) """
""")
# migrate: add role/banned columns if missing
def ensure_tables():
with get_sync_conn() as conn:
with conn.cursor() as cur:
for stmt in AUTH_SCHEMA.split(";"):
stmt = stmt.strip()
if stmt:
try: try:
cur.execute("ALTER TABLE users ADD COLUMN role TEXT NOT NULL DEFAULT 'user'") cur.execute(stmt)
except: except Exception:
pass conn.rollback()
try: continue
cur.execute("ALTER TABLE users ADD COLUMN banned INTEGER NOT NULL DEFAULT 0")
except:
pass
conn.commit() conn.commit()
conn.close()
# ─── DB helper ───────────────────────────────────────────────
def _fetchone(sql, params=None):
with get_sync_conn() as conn:
with conn.cursor() as cur:
cur.execute(sql, params or ())
row = cur.fetchone()
if not row:
return None
cols = [desc[0] for desc in cur.description]
return dict(zip(cols, row))
def _fetchall(sql, params=None):
with get_sync_conn() as conn:
with conn.cursor() as cur:
cur.execute(sql, params or ())
cols = [desc[0] for desc in cur.description]
return [dict(zip(cols, row)) for row in cur.fetchall()]
def _execute(sql, params=None):
with get_sync_conn() as conn:
with conn.cursor() as cur:
cur.execute(sql, params or ())
conn.commit()
try:
return cur.fetchone()
except Exception:
return None
# ─── Password utils ────────────────────────────────────────── # ─── Password utils ──────────────────────────────────────────
@ -144,13 +151,10 @@ def create_access_token(user_id: int, email: str, role: str) -> str:
def create_refresh_token(user_id: int) -> str: def create_refresh_token(user_id: int) -> str:
token = secrets.token_urlsafe(48) token = secrets.token_urlsafe(48)
expires_at = (datetime.utcnow() + timedelta(days=REFRESH_TOKEN_DAYS)).isoformat() expires_at = (datetime.utcnow() + timedelta(days=REFRESH_TOKEN_DAYS)).isoformat()
conn = get_conn() _execute(
conn.execute( "INSERT INTO refresh_tokens (user_id, token, expires_at) VALUES (%s, %s, %s)",
"INSERT INTO refresh_tokens (user_id, token, expires_at) VALUES (?, ?, ?)",
(user_id, token, expires_at) (user_id, token, expires_at)
) )
conn.commit()
conn.close()
return token return token
@ -173,7 +177,6 @@ def parse_token(token: str) -> Optional[dict]:
# ─── Auth dependency ───────────────────────────────────────── # ─── Auth dependency ─────────────────────────────────────────
def get_current_user(authorization: Optional[str] = Header(default=None)) -> dict: def get_current_user(authorization: Optional[str] = Header(default=None)) -> dict:
"""FastAPI dependency: returns user dict or raises 401"""
if not authorization or not authorization.startswith("Bearer "): if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="missing token") raise HTTPException(status_code=401, detail="missing token")
token = authorization.split(" ", 1)[1].strip() token = authorization.split(" ", 1)[1].strip()
@ -182,18 +185,15 @@ def get_current_user(authorization: Optional[str] = Header(default=None)) -> dic
raise HTTPException(status_code=401, detail="invalid or expired token") raise HTTPException(status_code=401, detail="invalid or expired token")
if payload.get("type") != "access": if payload.get("type") != "access":
raise HTTPException(status_code=401, detail="invalid token type") raise HTTPException(status_code=401, detail="invalid token type")
conn = get_conn() user = _fetchone("SELECT * FROM users WHERE id = %s", (payload["sub"],))
user = conn.execute("SELECT * FROM users WHERE id = ?", (payload["sub"],)).fetchone()
conn.close()
if not user: if not user:
raise HTTPException(status_code=401, detail="user not found") raise HTTPException(status_code=401, detail="user not found")
if user["banned"]: if user["banned"]:
raise HTTPException(status_code=403, detail="account banned") raise HTTPException(status_code=403, detail="account banned")
return dict(user) return user
def require_admin(user: dict = Depends(get_current_user)) -> dict: def require_admin(user: dict = Depends(get_current_user)) -> dict:
"""FastAPI dependency: requires admin role"""
if user.get("role") != "admin": if user.get("role") != "admin":
raise HTTPException(status_code=403, detail="admin required") raise HTTPException(status_code=403, detail="admin required")
return user return user
@ -230,60 +230,50 @@ class BanUserReq(BaseModel):
@router.post("/auth/register") @router.post("/auth/register")
def register(body: RegisterReq): def register(body: RegisterReq):
ensure_tables() ensure_tables()
conn = get_conn() invite = _fetchone("SELECT * FROM invite_codes WHERE code = %s", (body.invite_code,))
# verify invite code
invite = conn.execute(
"SELECT * FROM invite_codes WHERE code = ?", (body.invite_code,)
).fetchone()
if not invite: if not invite:
conn.close()
raise HTTPException(status_code=400, detail="invalid invite code") raise HTTPException(status_code=400, detail="invalid invite code")
if invite["status"] != "active": if invite["status"] != "active":
conn.close()
raise HTTPException(status_code=400, detail="invite code disabled") raise HTTPException(status_code=400, detail="invite code disabled")
if invite["used_count"] >= invite["max_uses"]: if invite["used_count"] >= invite["max_uses"]:
conn.close()
raise HTTPException(status_code=400, detail="invite code exhausted") raise HTTPException(status_code=400, detail="invite code exhausted")
if invite["expires_at"] and invite["expires_at"] < datetime.utcnow().isoformat(): if invite["expires_at"] and invite["expires_at"] < datetime.utcnow().isoformat():
conn.close()
raise HTTPException(status_code=400, detail="invite code expired") raise HTTPException(status_code=400, detail="invite code expired")
try:
pwd_hash = hash_password(body.password) pwd_hash = hash_password(body.password)
cur = conn.execute( try:
"INSERT INTO users (email, password_hash, role, banned, created_at) VALUES (?, ?, 'user', 0, ?)", with get_sync_conn() as conn:
with conn.cursor() as cur:
cur.execute(
"INSERT INTO users (email, password_hash, role, banned, created_at) VALUES (%s, %s, 'user', 0, %s) RETURNING id",
(body.email.lower(), pwd_hash, datetime.utcnow().isoformat()), (body.email.lower(), pwd_hash, datetime.utcnow().isoformat()),
) )
user_id = cur.lastrowid user_id = cur.fetchone()[0]
conn.execute( cur.execute(
"INSERT OR REPLACE INTO subscriptions (user_id, tier, expires_at) VALUES (?, 'free', NULL)", "INSERT INTO subscriptions (user_id, tier, expires_at) VALUES (%s, 'free', NULL) ON CONFLICT(user_id) DO NOTHING",
(user_id,), (user_id,),
) )
# record invite usage cur.execute(
conn.execute( "INSERT INTO invite_usage (code_id, user_id) VALUES (%s, %s)",
"INSERT INTO invite_usage (code_id, user_id) VALUES (?, ?)",
(invite["id"], user_id), (invite["id"], user_id),
) )
conn.execute( cur.execute(
"UPDATE invite_codes SET used_count = used_count + 1 WHERE id = ?", "UPDATE invite_codes SET used_count = used_count + 1 WHERE id = %s",
(invite["id"],), (invite["id"],),
) )
# auto-exhaust if max reached
new_count = invite["used_count"] + 1 new_count = invite["used_count"] + 1
if new_count >= invite["max_uses"]: if new_count >= invite["max_uses"]:
conn.execute( cur.execute(
"UPDATE invite_codes SET status = 'exhausted' WHERE id = ?", "UPDATE invite_codes SET status = 'exhausted' WHERE id = %s",
(invite["id"],), (invite["id"],),
) )
conn.commit() conn.commit()
except sqlite3.IntegrityError: except Exception as e:
conn.close() if "unique" in str(e).lower() or "duplicate" in str(e).lower():
raise HTTPException(status_code=400, detail="email already registered") raise HTTPException(status_code=400, detail="email already registered")
raise HTTPException(status_code=500, detail=str(e))
# issue tokens user = _fetchone("SELECT * FROM users WHERE id = %s", (user_id,))
user = conn.execute("SELECT * FROM users WHERE id = ?", (user_id,)).fetchone()
conn.close()
access = create_access_token(user["id"], user["email"], user["role"]) access = create_access_token(user["id"], user["email"], user["role"])
refresh = create_refresh_token(user["id"]) refresh = create_refresh_token(user["id"])
return { return {
@ -298,9 +288,7 @@ def register(body: RegisterReq):
@router.post("/auth/login") @router.post("/auth/login")
def login(body: LoginReq): def login(body: LoginReq):
ensure_tables() ensure_tables()
conn = get_conn() user = _fetchone("SELECT * FROM users WHERE email = %s", (body.email.lower(),))
user = conn.execute("SELECT * FROM users WHERE email = ?", (body.email.lower(),)).fetchone()
conn.close()
if not user or not verify_password(body.password, user["password_hash"]): if not user or not verify_password(body.password, user["password_hash"]):
raise HTTPException(status_code=401, detail="invalid credentials") raise HTTPException(status_code=401, detail="invalid credentials")
if user["banned"]: if user["banned"]:
@ -318,24 +306,17 @@ def login(body: LoginReq):
@router.post("/auth/refresh") @router.post("/auth/refresh")
def refresh_token(body: RefreshReq): def refresh_token(body: RefreshReq):
conn = get_conn() row = _fetchone(
row = conn.execute( "SELECT * FROM refresh_tokens WHERE token = %s AND revoked = 0", (body.refresh_token,)
"SELECT * FROM refresh_tokens WHERE token = ? AND revoked = 0", (body.refresh_token,) )
).fetchone()
if not row: if not row:
conn.close()
raise HTTPException(status_code=401, detail="invalid refresh token") raise HTTPException(status_code=401, detail="invalid refresh token")
if row["expires_at"] < datetime.utcnow().isoformat(): if row["expires_at"] < datetime.utcnow().isoformat():
conn.close()
raise HTTPException(status_code=401, detail="refresh token expired") raise HTTPException(status_code=401, detail="refresh token expired")
user = conn.execute("SELECT * FROM users WHERE id = ?", (row["user_id"],)).fetchone() user = _fetchone("SELECT * FROM users WHERE id = %s", (row["user_id"],))
if not user or user["banned"]: if not user or user["banned"]:
conn.close()
raise HTTPException(status_code=403, detail="account unavailable") raise HTTPException(status_code=403, detail="account unavailable")
# revoke old, issue new _execute("UPDATE refresh_tokens SET revoked = 1 WHERE id = %s", (row["id"],))
conn.execute("UPDATE refresh_tokens SET revoked = 1 WHERE id = ?", (row["id"],))
conn.commit()
conn.close()
access = create_access_token(user["id"], user["email"], user["role"]) access = create_access_token(user["id"], user["email"], user["role"])
new_refresh = create_refresh_token(user["id"]) new_refresh = create_refresh_token(user["id"])
return { return {
@ -348,9 +329,7 @@ def refresh_token(body: RefreshReq):
@router.get("/auth/me") @router.get("/auth/me")
def me(user: dict = Depends(get_current_user)): def me(user: dict = Depends(get_current_user)):
conn = get_conn() sub = _fetchone("SELECT tier, expires_at FROM subscriptions WHERE user_id = %s", (user["id"],))
sub = conn.execute("SELECT tier, expires_at FROM subscriptions WHERE user_id = ?", (user["id"],)).fetchone()
conn.close()
return { return {
"id": user["id"], "email": user["email"], "role": user["role"], "id": user["id"], "email": user["email"], "role": user["role"],
"discord_id": user.get("discord_id"), "discord_id": user.get("discord_id"),
@ -363,53 +342,43 @@ def me(user: dict = Depends(get_current_user)):
@router.post("/admin/invite-codes") @router.post("/admin/invite-codes")
def gen_invite_codes(body: GenInviteReq, admin: dict = Depends(require_admin)): def gen_invite_codes(body: GenInviteReq, admin: dict = Depends(require_admin)):
conn = get_conn()
codes = [] codes = []
with get_sync_conn() as conn:
with conn.cursor() as cur:
for _ in range(body.count): for _ in range(body.count):
code = secrets.token_urlsafe(6)[:8].upper() code = secrets.token_urlsafe(6)[:8].upper()
conn.execute( cur.execute(
"INSERT INTO invite_codes (code, created_by, max_uses) VALUES (?, ?, ?)", "INSERT INTO invite_codes (code, created_by, max_uses) VALUES (%s, %s, %s)",
(code, admin["id"], body.max_uses), (code, admin["id"], body.max_uses),
) )
codes.append(code) codes.append(code)
conn.commit() conn.commit()
conn.close()
return {"codes": codes} return {"codes": codes}
@router.get("/admin/invite-codes") @router.get("/admin/invite-codes")
def list_invite_codes(admin: dict = Depends(require_admin)): def list_invite_codes(admin: dict = Depends(require_admin)):
conn = get_conn() rows = _fetchall(
rows = conn.execute(
"SELECT id, code, max_uses, used_count, status, expires_at, created_at FROM invite_codes ORDER BY id DESC" "SELECT id, code, max_uses, used_count, status, expires_at, created_at FROM invite_codes ORDER BY id DESC"
).fetchall() )
conn.close() return {"items": rows}
return {"items": [dict(r) for r in rows]}
@router.delete("/admin/invite-codes/{code_id}") @router.delete("/admin/invite-codes/{code_id}")
def disable_invite_code(code_id: int, admin: dict = Depends(require_admin)): def disable_invite_code(code_id: int, admin: dict = Depends(require_admin)):
conn = get_conn() _execute("UPDATE invite_codes SET status = 'disabled' WHERE id = %s", (code_id,))
conn.execute("UPDATE invite_codes SET status = 'disabled' WHERE id = ?", (code_id,))
conn.commit()
conn.close()
return {"ok": True} return {"ok": True}
@router.get("/admin/users") @router.get("/admin/users")
def list_users(admin: dict = Depends(require_admin)): def list_users(admin: dict = Depends(require_admin)):
conn = get_conn() rows = _fetchall(
rows = conn.execute(
"SELECT id, email, role, banned, discord_id, created_at FROM users ORDER BY id DESC" "SELECT id, email, role, banned, discord_id, created_at FROM users ORDER BY id DESC"
).fetchall() )
conn.close() return {"items": rows}
return {"items": [dict(r) for r in rows]}
@router.put("/admin/users/{user_id}/ban") @router.put("/admin/users/{user_id}/ban")
def ban_user(user_id: int, body: BanUserReq, admin: dict = Depends(require_admin)): def ban_user(user_id: int, body: BanUserReq, admin: dict = Depends(require_admin)):
conn = get_conn() _execute("UPDATE users SET banned = %s WHERE id = %s", (1 if body.banned else 0, user_id))
conn.execute("UPDATE users SET banned = ? WHERE id = ?", (1 if body.banned else 0, user_id))
conn.commit()
conn.close()
return {"ok": True, "user_id": user_id, "banned": body.banned} return {"ok": True, "user_id": user_id, "banned": body.banned}

View File

@ -0,0 +1,170 @@
#!/usr/bin/env python3
"""
migrate_auth_sqlite_to_pg.py 将SQLite中的auth相关表迁移到PG
运行前确保PG连接参数正确db.py中的配置
"""
import os, sys, sqlite3
sys.path.insert(0, os.path.dirname(__file__))
from db import get_sync_conn
SQLITE_PATH = os.path.join(os.path.dirname(__file__), "..", "arb.db")
def migrate():
if not os.path.exists(SQLITE_PATH):
print(f"SQLite DB not found: {SQLITE_PATH}")
return
sq = sqlite3.connect(SQLITE_PATH)
sq.row_factory = sqlite3.Row
with get_sync_conn() as pg:
cur = pg.cursor()
# 1. 建auth相关表
print("Creating auth tables in PG...")
auth_tables = """
CREATE TABLE IF NOT EXISTS subscriptions (
user_id BIGINT PRIMARY KEY,
tier TEXT NOT NULL DEFAULT 'free',
expires_at TEXT
);
CREATE TABLE IF NOT EXISTS invite_usage (
id BIGSERIAL PRIMARY KEY,
code_id BIGINT NOT NULL,
user_id BIGINT NOT NULL,
used_at TEXT
);
CREATE TABLE IF NOT EXISTS refresh_tokens (
id BIGSERIAL PRIMARY KEY,
user_id BIGINT NOT NULL,
token TEXT UNIQUE NOT NULL,
expires_at TEXT NOT NULL,
revoked INTEGER NOT NULL DEFAULT 0,
created_at TEXT
);
"""
for stmt in auth_tables.split(";"):
stmt = stmt.strip()
if stmt:
try:
cur.execute(stmt)
except Exception as e:
pg.rollback()
pg.commit()
# 2. 迁移usersPG已有users表但可能是空的旧结构
print("Migrating users...")
# 先加缺失列
for col, defn in [("discord_id", "TEXT"), ("banned", "INTEGER DEFAULT 0")]:
try:
cur.execute(f"ALTER TABLE users ADD COLUMN {col} {defn}")
pg.commit()
except:
pg.rollback()
rows = sq.execute("SELECT * FROM users").fetchall()
for r in rows:
try:
cur.execute(
"""INSERT INTO users (id, email, password_hash, discord_id, role, banned, created_at)
VALUES (%s, %s, %s, %s, %s, %s, %s)
ON CONFLICT (id) DO UPDATE SET
password_hash = EXCLUDED.password_hash,
discord_id = EXCLUDED.discord_id,
role = EXCLUDED.role,
banned = EXCLUDED.banned""",
(r["id"], r["email"], r["password_hash"],
r["discord_id"] if "discord_id" in r.keys() else None,
r["role"], r["banned"], r["created_at"])
)
except Exception as e:
print(f" User {r['email']} error: {e}")
pg.rollback()
continue
pg.commit()
print(f" Migrated {len(rows)} users")
# 3. 迁移invite_codesPG已有但可能缺列
print("Migrating invite_codes...")
for col, defn in [("created_by", "INTEGER"), ("max_uses", "INTEGER DEFAULT 1"),
("used_count", "INTEGER DEFAULT 0"), ("status", "TEXT DEFAULT 'active'"),
("expires_at", "TEXT")]:
try:
cur.execute(f"ALTER TABLE invite_codes ADD COLUMN {col} {defn}")
pg.commit()
except:
pg.rollback()
try:
rows = sq.execute("SELECT * FROM invite_codes").fetchall()
for r in rows:
try:
cur.execute(
"""INSERT INTO invite_codes (id, code, created_by, max_uses, used_count, status, expires_at, created_at)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
ON CONFLICT (id) DO NOTHING""",
(r["id"], r["code"], r["created_by"] if "created_by" in r.keys() else None,
r["max_uses"], r["used_count"], r["status"],
r["expires_at"] if "expires_at" in r.keys() else None,
r["created_at"] if "created_at" in r.keys() else None)
)
except Exception as e:
print(f" Invite {r['code']} error: {e}")
pg.rollback()
continue
pg.commit()
print(f" Migrated {len(rows)} invite codes")
except Exception as e:
print(f" invite_codes table error: {e}")
# 4. 迁移subscriptions
print("Migrating subscriptions...")
try:
rows = sq.execute("SELECT * FROM subscriptions").fetchall()
for r in rows:
try:
cur.execute(
"INSERT INTO subscriptions (user_id, tier, expires_at) VALUES (%s, %s, %s) ON CONFLICT(user_id) DO NOTHING",
(r["user_id"], r["tier"], r["expires_at"])
)
except Exception as e:
pg.rollback()
pg.commit()
print(f" Migrated {len(rows)} subscriptions")
except Exception as e:
print(f" subscriptions error: {e}")
# 5. 迁移refresh_tokens
print("Migrating refresh_tokens...")
try:
rows = sq.execute("SELECT * FROM refresh_tokens").fetchall()
for r in rows:
try:
cur.execute(
"INSERT INTO refresh_tokens (user_id, token, expires_at, revoked) VALUES (%s, %s, %s, %s) ON CONFLICT(token) DO NOTHING",
(r["user_id"], r["token"], r["expires_at"], r["revoked"])
)
except Exception as e:
pg.rollback()
pg.commit()
print(f" Migrated {len(rows)} refresh tokens")
except Exception as e:
print(f" refresh_tokens error: {e}")
# 6. 重置序列
print("Resetting sequences...")
for table in ["users", "invite_codes", "invite_usage", "refresh_tokens"]:
try:
cur.execute(f"SELECT setval(pg_get_serial_sequence('{table}', 'id'), COALESCE(MAX(id), 1)) FROM {table}")
pg.commit()
except:
pg.rollback()
sq.close()
print("\nAuth migration complete!")
if __name__ == "__main__":
migrate()